Error in Keras custom loss function in R for Mixture Density Network


I am trying to fit a Mixture Density Network (a Gaussian mixture model whose parameters are produced by a neural network) with Keras in R, using a custom loss function.

I get the following error:

Epoch 1/25
2023-10-23 11:55:50.657605: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:114] Plugin optimizer for device_type GPU is enabled.
750/750 [==============================] - 3s 4ms/step - loss: 75.5268
 Error in py_call_impl(callable, call_args$unnamed, call_args$named) : 
TypeError: in user code:

File "/Users/ryanbmac/.virtualenvs/r-tensorflow/lib/python3.9/site-packages/keras/src/engine/training.py", line 1972, in test_function *
return step_function(self, iterator)
File "/Library/Frameworks/R.framework/Versions/4.2-arm64/Resources/library/reticulate/python/rpytools/call.py", line 16, in python_function *
raise error
File "/Users/ryanbmac/.virtualenvs/r-tensorflow/lib/python3.9/site-packages/keras/src/backend.py", line 3613, in reshape
return tf.reshape(x, shape)

TypeError: Failed to convert elements of (None, 1) to Tensor. Consider casting elements to a supported type. See https://www.tensorflow.org/api_docs/python/tf/dtypes for supported TF dtypes.

Run `reticulate::py_last_error()` for details.

My code is as follows:

K <- 5 # the number of Gaussian components
input_shape <- ncol(X_train)  # input shape

# Define the shared neural network
model <- keras_model_sequential()
model %>%
  layer_dense(units = 64, activation = "relu", input_shape = input_shape) %>%
  layer_dense(units = 64, activation = "relu") %>%
  layer_dense(units = K * 3)  # K Gaussians: 3 parameters for each (mean, variance, mixing weight)

# Define a custom loss function for the negative log-likelihood
mdn_loss <- function(y_true, y_pred) {
  k <- backend()
  
  # Get components of the MoG from the NN
  mus <- y_pred[, 1:(K * 1)]
  logsigmas <- y_pred[, (K * 1 + 1):(K * 2)]
  alphas_raw <- y_pred[, (K * 2 + 1):(K * 3)]
  alphas <- k$softmax(alphas_raw, axis = 1)
  
  # Split and reshape to match dimensions
  y_true_r <- k_reshape(k_repeat(y_true, K), c(nrow(y_pred)*K, 1))
  mus_r <- k_reshape(mus, dim(y_true_r))
  logsigmas_r <- k_reshape(logsigmas, dim(y_true_r))
  sigs_r <- k$exp(logsigmas_r)
  alphas_r <- k_reshape(alphas, dim(y_true_r))
    
  # calculate the negative log-likelihood of the Gaussian Mixture
  terms1 = 0.5 * k$square((y_true_r - mus_r)/sigs_r) 
  terms2 = logsigmas_r
  const = 0.5*k$log(2*pi)
  terms_nll = terms1+terms2+const
  loss <- k$sum(alphas_r * terms_nll)
  return(loss)
}

# Compile the model with the custom loss function
model %>%
  compile(
    loss = mdn_loss,
    optimizer = 'adam',           
    metrics = list()
  )

# Print the model summary
summary(model)

### train the model
training_history <- 
  model %>%
  fit(
    x = as.matrix(X_train),  # Exclude the target column
    y = y_train,             # Target column
    epochs = 25,                   # Number of training epochs
    batch_size = 32,               # Batch size
    validation_split = 0.2         # Portion of data for validation
  )

Help is appreciated!

Calling print, k_eval(•), or print(k_get_value(•)) inside the loss function doesn't help: apparently because the loss runs as part of the TensorFlow graph, I can't inspect the tensors' values there, which makes this hard to debug.

There is 1 answer

Answered by fryan

This version of the loss function runs without the error (whether it trains well remains to be seen):

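library(keras)
library(tfprobability)  # provides tfd_normal()

# K (the number of mixture components) must be defined before the loss is used
mdn_loss <- function(y_true, y_pred) {
  # Split the network output into the mixture parameters
  mu <- y_pred[, 1:K]
  sigma <- k_exp(y_pred[, (K + 1):(2 * K)])
  # Softmax over the last (component) axis; k_* axes are 1-based in R,
  # so axis = 1 would normalize over the batch instead
  alpha <- k_softmax(y_pred[, (2 * K + 1):(3 * K)], axis = -1)

  # Mixture density: sum over components of alpha_j * N(y | mu_j, sigma_j)
  mix_prob <- 0
  for (j in 1:K) {
    dist_j <- tfd_normal(mu[, j], sigma[, j])
    mix_prob <- mix_prob + alpha[, j] * dist_j$prob(y_true[, 1])
  }

  # Mean negative log-likelihood over the batch
  -k_mean(k_log(mix_prob))
}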
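Since tensors inside the loss are hard to inspect, a plain-R reference value can help when checking the loss on a small batch. This is only a sketch: `mdn_nll_reference` is a hypothetical helper name, and it assumes each row of `y_pred` holds K means, then K log-sigmas, then K raw mixing weights, as in the loss above.

```r
# Hypothetical base-R reference for the mixture negative log-likelihood.
# Assumes each row of y_pred holds K means, K log-sigmas, K raw weights.
mdn_nll_reference <- function(y_true, y_pred, K) {
  y_true <- as.numeric(y_true)
  mu    <- y_pred[, 1:K, drop = FALSE]
  sigma <- exp(y_pred[, (K + 1):(2 * K), drop = FALSE])
  raw   <- y_pred[, (2 * K + 1):(3 * K), drop = FALSE]

  # Row-wise softmax over the K components (shifted by the row max for stability)
  raw   <- raw - apply(raw, 1, max)
  alpha <- exp(raw) / rowSums(exp(raw))

  # Mixture density of each observation, then the mean negative log-likelihood
  dens <- rowSums(alpha * dnorm(y_true, mean = mu, sd = sigma))
  -mean(log(dens))
}

# Example call on a tiny random batch: 4 observations, K = 2 components
set.seed(1)
y_pred_small <- matrix(rnorm(4 * 6), nrow = 4)
y_true_small <- rnorm(4)
mdn_nll_reference(y_true_small, y_pred_small, K = 2)
```

Evaluating the Keras loss on the same small matrices (converted to tensors) should give approximately the same number; a large mismatch would suggest a slicing or axis problem.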

which corresponds to the mean negative log-likelihood of the Gaussian mixture over the batch.
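Read off the code above, the loss can be written as follows. This is a reconstruction from the code, not a copy of the original formula; the symbols N (batch size), x_i (inputs) and y_i (targets) are notation assumed here.

```latex
\mathcal{L}
  = -\frac{1}{N} \sum_{i=1}^{N}
    \log \left( \sum_{j=1}^{K} \alpha_j(x_i)\,
    \mathcal{N}\!\left(y_i \,\middle|\, \mu_j(x_i),\, \sigma_j(x_i)^2\right) \right),
\qquad
\alpha_j(x_i) \ge 0, \quad \sum_{j=1}^{K} \alpha_j(x_i) = 1
```

Here \mu_j, \sigma_j = \exp(\log\sigma_j) and \alpha_j (softmax of the raw weights) are the three slices of the network output for component j.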