I've recently been working on implementing the Annealed Flow Transport method as described in https://arxiv.org/abs/2102.07501. At one point, the task is to minimize a given loss function by learning a normalizing flow using SGD. I have studied several papers on the topics this problem touches, but I can't figure out how to connect the ideas. So, here's the problem:
Suppose we are given a sample (x_1,...,x_N) from a distribution p. We now want to learn a normalizing flow T that transports each particle so that (T(x_1),...,T(x_N)) is an appropriate sample from the target distribution q. As described in the source mentioned above, this is done by minimizing the Kullback-Leibler divergence between the pushforward T(p) and q. The resulting loss function (the one we want to minimize) is denoted L or L(T).
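To make the objective concrete, here is how I understand the Monte Carlo estimate of this KL divergence from a sample (the density and helper names here are my own placeholders, not from the paper; f_initial and f_target below play the roles of the densities of p and q):

    import numpy as np

    def kl_estimate(x, T, log_det_jac, log_p, log_q):
        # KL(T(p) || q) estimated over a sample x ~ p as the mean of
        #   log p(x) - log|det J_T(x)| - log q(T(x))
        return np.mean(log_p(x) - log_det_jac(x) - log_q(T(x)))

    # Toy check: p = N(0,1) and T(x) = 2x + 1 pushes p exactly onto
    # q = N(1,4), so the KL estimate should be (numerically) zero.
    rng = np.random.default_rng(0)
    x = rng.standard_normal(10_000)

    log_p = lambda x: -0.5 * x**2 - 0.5 * np.log(2 * np.pi)
    T = lambda x: 2.0 * x + 1.0
    log_det_jac = lambda x: np.full_like(x, np.log(2.0))
    log_q = lambda y: -0.5 * ((y - 1.0) / 2.0) ** 2 - np.log(2.0) - 0.5 * np.log(2 * np.pi)

    kl = kl_estimate(x, T, log_det_jac, log_p, log_q)

So the loss evaluated in my code below is meant to be exactly this quantity, computed from f_target(T(x)) on the one hand and f_initial(x)/jacobian_det(T,x) on the other.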
The authors describe their algorithm in considerable detail, but at this point they just say "Learn T using SGD to minimize L".
My intention was to use TensorFlow and Keras, with L as a custom loss function and, as the authors suggest, the Adam optimizer. Here is my code as it stands:
def LearnFlow_Test(train_iters, x_train, W_train, x_val, W_val):
    # Initialize with the identity flow
    identity = lambda x: x
    flows = [identity]  # list of candidate flows (plain list, since the entries are callables)
    y_true = np.array([f_target(identity(x)) for x in x_val])
    y_pred = np.array([f_initial(x) / jacobian_det(identity, x) for x in x_val])
    val_losses = [loss_function(y_true, y_pred)]

    # Learn
    for j in range(train_iters):
        # Compute the training loss
        y_true = np.array([f_target(identity(x)) for x in x_train])
        y_pred = np.array([f_initial(x) / jacobian_det(identity, x) for x in x_train])
        train_loss = loss_function(y_true, y_pred)

        """
        Update flow using SGD to minimize train_loss
        minimizing_flow =
        """

        # Update the list of flows
        flows.append(minimizing_flow)

        # Compute the new validation loss and update the list
        y_true = np.array([f_target(minimizing_flow(x)) for x in x_val])
        y_pred = np.array([f_initial(x) / jacobian_det(minimizing_flow, x) for x in x_val])
        val_losses.append(loss_function(y_true, y_pred))

    return flows[np.argmin(val_losses)]  # Return the flow with the smallest validation loss
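For illustration, here is a minimal sketch of the kind of gradient step I imagine should fill the gap above, using tf.GradientTape and Adam on a toy 1-D affine flow T(x) = a*x + b (so log|det J_T| = log|a|). The densities and names are placeholders of my own, not the actual flow from the paper:

    import tensorflow as tf

    # Toy setup: initial p = N(0,1), target q = N(1,4), trainable affine flow
    a = tf.Variable(1.0)
    b = tf.Variable(0.0)
    optimizer = tf.keras.optimizers.Adam(learning_rate=0.05)

    LOG_2PI = tf.math.log(2.0 * 3.141592653589793)
    def log_p(x):
        return -0.5 * x ** 2 - 0.5 * LOG_2PI
    def log_q(y):
        return -0.5 * ((y - 1.0) / 2.0) ** 2 - tf.math.log(2.0) - 0.5 * LOG_2PI

    x_train = tf.random.normal([2048], seed=0)

    for _ in range(800):
        with tf.GradientTape() as tape:
            # Monte Carlo estimate of KL(T(p) || q)
            loss = tf.reduce_mean(log_p(x_train)
                                  - tf.math.log(tf.abs(a))
                                  - log_q(a * x_train + b))
        grads = tape.gradient(loss, [a, b])
        optimizer.apply_gradients(zip(grads, [a, b]))
    # After training, (a, b) should end up near (2, 1), since that transport
    # pushes p exactly onto q.

What I can't see is how to lift this from a toy affine map to a full normalizing flow in Keras, with L as a custom loss.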
I would be grateful for any advice, as my search for existing code was not successful.
Many thanks, Christian