I have the following custom loss with which I want to train a neural network.
```julia
function cuda_euclid(X)
    # Reimplementation of the pairwise Euclidean distance
    # function from Distances.jl, to be compatible with CUDA.
    norm_sq = sum(X .^ 2, dims=1)
    dot_prod = X' * X
    dist_sq = norm_sq' .+ norm_sq .- 2 .* dot_prod
    return sqrt.(max.(0.0, dist_sq))
end
```
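As a sanity check, cuda_euclid can be compared against the reference implementation in Distances.jl on the CPU (a minimal sketch; the sizes here are just illustrative):

```julia
using Distances
X = rand(Float32, 8, 16)                # 8 features, 16 samples
ref = pairwise(Euclidean(), X, dims=2)  # reference: distances between columns
@assert isapprox(cuda_euclid(X), ref; atol=1e-4)
```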
```julia
function corr_loss(z, y)
    # Distance correlation between z and y (columns are samples),
    # computed from U-centered pairwise distance matrices.
    _, m = size(z)
    A = cuda_euclid(z)
    B = cuda_euclid(y)
    A_row_sum = sum(A, dims=1)
    A_col_sum = sum(A, dims=2)
    B_row_sum = sum(B, dims=1)
    B_col_sum = sum(B, dims=2)
    a = A .- A_row_sum ./ (m - 2) .- A_col_sum ./ (m - 2) .+ sum(A) / ((m - 1) * (m - 2))
    b = B .- B_row_sum ./ (m - 2) .- B_col_sum ./ (m - 2) .+ sum(B) / ((m - 1) * (m - 2))
    AB = sum(a .* b) / (m * (m - 3))  # distance covariance
    AA = sum(a .* a) / (m * (m - 3))  # distance variance of z
    BB = sum(b .* b) / (m * (m - 3))  # distance variance of y
    mi = (AB^0.5) / ((AA^0.5) * (BB^0.5))^0.5
    return mi
end
```
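The forward pass itself behaves as expected on plain CPU test matrices; a minimal smoke test (shapes and values are just illustrative):

```julia
z = rand(Float32, 4, 32)                 # 4 outputs, 32 samples
y = z .+ 0.1f0 .* randn(Float32, 4, 32)  # targets correlated with z, so the covariance term AB stays positive
corr_loss(z, y)                          # returns a single finite scalar
```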
My training loop looks like this:
```julia
@showprogress for epoch in 1:1000
    # Training.
    for (x, y) in train_loader
        loss, grads = Flux.withgradient(model) do m
            y_hat = dropdims(m(x), dims=2)
            corr_loss(y_hat, y)
        end
        println(loss, grads)
        Flux.update!(optim, model, grads[1])
        push!(train_losses, loss)  # logging, outside gradient context
    end
    # Some other code ...
end
```
If I use Flux.mse(y_hat, y) instead of corr_loss(y_hat, y), training works well, so the problem only occurs with corr_loss. In the first iteration of the inner loop, corr_loss(y_hat, y) returns a reasonable value, but the gradients are all NaN.
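The problem does not seem to need the GPU or the full model; a minimal sketch of what I see, calling Zygote (the AD backend Flux uses) directly on small CPU matrices:

```julia
using Zygote
z = rand(Float32, 4, 32)
y = z .+ 0.1f0 .* randn(Float32, 4, 32)
corr_loss(z, y)                           # finite, reasonable value
Zygote.gradient(z -> corr_loss(z, y), z)  # comes back all NaN, as in the training loop
```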
I copied the code for the loss from a Python implementation using torch, which works well, but I need this in Julia. I have also tested the outputs of the loss function with test matrices, and I get the same output as the Python implementation, which can be found here.