I want to define a loss function based on a complex series of transformations of the output of a neural network. These transformations involve logic like the following, which doesn't seem to be expressible without in-place operations (see the comments):
import torch

def get_X_torch(C, c_table):
    """
    _zmat_transformation.py line 57
    C = torch tensor of floats where rows are all bonds, then all angles, then all dihedrals
    c_table = torch tensor of ints where rows are all bond_idx, then all angle_idx, then all dihedral_idx
    Blank c_table indices at the beginning of the z-matrix are labeled as -9223372036854775807
    """
    X = torch.zeros_like(C, device="cuda:0")  # ([b a d], n_atoms)
    n_atoms = X.shape[1]
    # This is some complicated logic - not vectorizable, because the variables
    # all influence each other throughout the loop (it's a nonlinear transformation).
    for j in range(n_atoms):
        B, ref_pos = get_B_torch(X, c_table, j)
        S = get_S_torch(C, j)
        # X[:, j] depends on X's current value as a whole!!! This is the tricky step.
        X[:, j] = torch.mv(B, S) + get_ref_pos_torch(X, c_table[0, j])
    return X.T
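To check that I understand the failure mode, here is a minimal toy example of my own (not the real pipeline) that reproduces the same class of error: a sequential loop in which each in-place write invalidates a value that autograd saved for the backward pass.

import torch

w = torch.randn(4, requires_grad=True)
x = torch.zeros(4)
x[0] = w[0]
for j in range(1, 4):
    # torch.sin saves its input (a view of x) for backward; the next
    # in-place write to x then bumps the shared version counter.
    x[j] = torch.sin(x[j - 1]) * w[j]
x.sum().backward()
# RuntimeError: one of the variables needed for gradient computation
# has been modified by an inplace operation

This seems to be exactly the shape of the problem in get_X_torch: get_B_torch reads X, its backward needs X's values from that moment, and the subsequent X[:, j] = ... write invalidates them.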
The training snippet looks like the code below. I need to build up clash_loss iteratively with my_function_script, a wrapper around the functionality above; since I'm accumulating with clash_loss = clash_loss + ... rather than clash_loss +=, that part should be fine. I suspect the problem is in the complicated loop above: the error message says the gradient can't be computed because of an in-place operation somewhere along the pathway.
reconstructed_angles = torch.atan2(
    internal_data_batch_reconstructed[:, 0:304],
    internal_data_batch_reconstructed[:, 304:],
)
if clash_mode:
    clash_loss = torch.tensor(0.0, requires_grad=True, device="cuda")
    bonds = torch.tensor(init_z_mat["bond"].values, device="cuda", requires_grad=True)
    angles = torch.tensor(init_z_mat["angle"].values * (torch.pi / 180), device="cuda", requires_grad=True)
    for i in range(batch_size):
        print(i + 1, batch_size)
        C = torch.stack((bonds, angles, reconstructed_angles[i]))
        # my_function_script wraps the functionality above; it is a very
        # complicated function, but written in pure PyTorch. It must not
        # involve in-place operations (see the other bit of code).
        xyz = my_function_script(C, construction_table)
        temp_loss = 1.0 if get_clash_loss(xyz) > 0.0 else 0.0
        clash_loss = clash_loss + temp_loss
    total_loss = clash_loss
    total_loss.backward()  # <--- this fails
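In case it matters, this is how I've been trying to localize the offending operation, using PyTorch's built-in anomaly detection (it only changes what the error reports, not the behavior):

with torch.autograd.detect_anomaly():
    # ... the whole batch loop above has to run inside this context so the
    # backward error includes a traceback for the forward op whose saved
    # tensors were later modified in-place ...
    total_loss = clash_loss
    total_loss.backward()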
Is there anything I can do to make this chain of logic differentiable so that clash_loss.backward() works? Deriving the gradient by hand is completely out of the question for such a complex set of functions...
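(One more thing I'm aware of: even once backward() runs, the hard 1.0/0.0 gate on get_clash_loss has zero gradient everywhere, so I'll presumably need a smooth surrogate. Something like the sketch below is what I have in mind; k is a made-up sharpness constant, and I'm assuming get_clash_loss returns a differentiable tensor.)

# Smooth stand-in for the 0/1 clash indicator (hypothetical; k is a
# sharpness hyperparameter, not something from my actual code).
k = 50.0
temp_loss = torch.sigmoid(k * get_clash_loss(xyz))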
I tried rewriting the loop in get_X_torch with copies, avoiding any obvious in-place edits (see below), but this still doesn't work.
Xs = [X]
for j in range(n_atoms):
    B, ref_pos = get_B_torch(Xs[-1], c_table, j)
    S = get_S_torch(C, j)
    first = torch.mv(B, S)
    second = get_ref_pos_torch(Xs[-1], c_table[0, j])
    # Rebuild X out-of-place: columns before j, the new column j, columns after j.
    Xcopy = torch.cat((Xs[-1][:, 0:j], (first + second).reshape((-1, 1)), Xs[-1][:, j + 1:]), -1)
    Xs = Xs + [Xcopy]
return Xs[-1].T
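For completeness, another rewrite I've been considering (untested sketch; it assumes get_B_torch, get_S_torch, and get_ref_pos_torch only ever read from the tensors they're given) keeps the columns in a plain Python list and re-stacks them every iteration, so no tensor is ever written twice:

def get_X_torch_functional(C, c_table):
    # Fully out-of-place rewrite of get_X_torch: the columns live in a Python
    # list, and a fresh X is stacked each iteration, so autograd never sees a
    # tensor being mutated after it was used.
    n_rows, n_atoms = C.shape
    cols = [torch.zeros(n_rows, device=C.device) for _ in range(n_atoms)]
    for j in range(n_atoms):
        X = torch.stack(cols, dim=1)  # new tensor every iteration
        B, ref_pos = get_B_torch(X, c_table, j)
        S = get_S_torch(C, j)
        cols[j] = torch.mv(B, S) + get_ref_pos_torch(X, c_table[0, j])
    return torch.stack(cols, dim=1).T

Would something along these lines be the right direction, or is there a cleaner idiom for sequential loops like this?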