Exploding gradient with custom loss function in Flux

66 views Asked by At

I have the following custom loss with which I want to train a neural network with.

function cuda_euclid(X)
    # Reimplementation of the Euclidean distance
    # function from Distances.jl, to be compatible with CUDA.
    norm_sq = sum(X .^ 2, dims=1)
    dot_prod = X' * X
    dist_sq = norm_sq' .+ norm_sq .- 2 .* dot_prod
    return sqrt.(max.(0.0, dist_sq))
end

function corr_loss(z, y)
    _, m = size(z)
    A = cuda_euclid(z)
    B = cuda_euclid(y)
    A_row_sum = sum(A, dims=1)
    A_col_sum = sum(A, dims=2)
    B_row_sum = sum(B, dims=1)
    B_col_sum = sum(B, dims=2)
    a = A .- A_row_sum ./ (m - 2) .- A_col_sum ./ (m - 2) .+ sum(A) / ((m - 1) * (m - 2))
    b = B .- B_row_sum ./ (m - 2) .- B_col_sum ./ (m - 2) .+ sum(B) / ((m - 1) * (m - 2))
    AB = sum(a .* b) / (m * (m - 3))
    AA = sum(a .* a) / (m * (m - 3))
    BB = sum(b .* b) / (m * (m - 3))
    mi = (AB^0.5) / ((AA^0.5) * (BB^0.5))^0.5
    return mi
end

My training loop looks like this:

    @showprogress for epoch in 1:1000

        # Training.
        for (x, y) in train_loader
            loss, grads = Flux.withgradient(model) do m
                y_hat = dropdims(m(x), dims=2)
                corr_loss(y_hat, y)
            end
            println(loss, grads)
            Flux.update!(optim, model, grads[1])
            push!(train_losses, loss)  # logging, outside gradient context        
        end
        # Some other code ...
    end

If I use Flux.mse(y_hat, y) instead of corr_loss(y_hat, y) the training works well, so the problem only occurs when I use corr_loss. In the first iteration of the inner loop, corr_loss(y_hat, y) returns a reasonable value, but the gradients are all NaN.

I copied the code for the loss from a python implementation with torch, which works well, but I need this in Julia. I have also tested the outputs of the loss function with test matrices and I get the same output as the python implementation, which can be found here

0

There are 0 answers