I am iterating over each attention head and applying either function `f1` or `f2` to it, depending on the value of the parameter `self.alpha`.
I only want to evaluate one of `f1` or `f2`, not both. I do not want to compute both and then select one output with a conditional.
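
    import jax
    import jax.numpy as jnp
    import flax.linen as nn  # assumed: `nn` is flax.linen (jax.nn.relu behaves the same)

    def f1(x):
        print('f1')
        return x / x.shape[2]

    def f2(x):
        print('f2')
        temp = nn.relu(x)
        return temp / (jnp.sum(temp, axis=-1, keepdims=True) + 1e-5)

    def choose_attention(alpha, x):
        # lax.cond needs a boolean or integer predicate: f2 if alpha != 0, else f1
        return jax.lax.cond(alpha[0, 0, 0, 0].astype(int) != 0, f2, f1, x)

    # Inside the module method (self.alpha and attn_weights come from the surrounding code):
    results = []
    func = [f1, f2]
    for i in range(self.alpha.shape[1]):  # one iteration per attention head
        print(i)
        alpha_i = self.alpha[:, i:i+1, :, :]
        x_i = attn_weights[:, i:i+1, :, :]
        # index 0 -> f1, index 1 -> f2, chosen from this head's alpha
        result_i = jax.lax.switch(alpha_i[0, 0, 0, 0].astype(int), func, x_i)
        results.append(result_i)
    final_result = jnp.concatenate(results, axis=1)
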
My print statements output `0 f1 f2 1 2 3 4 5 6 7 8 9 10 11`, so it looks as though both `f1` and `f2` are being evaluated.
`jax.lax.switch` does what you want: it chooses between `f1` and `f2` based on a runtime value, and only the selected branch is executed. Your `print` statements are misleading you: Python `print` runs at trace time rather than at runtime, and JAX traces every branch, even the ones that are not eventually executed.

For some background on how to think about the execution model of JAX programs, I would suggest the *How to think in JAX* guide in the JAX documentation.
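To make the trace-time vs. runtime distinction visible, here is a minimal sketch, assuming a JAX version recent enough to provide `jax.debug.print` and `jax.effects_barrier`. The list `trace_log` is a hypothetical helper that records Python-level calls: it should end up containing both names, because `jax.lax.switch` traces both branches. `jax.debug.print`, by contrast, is staged into the computation and should only fire for the branch that actually runs.

```python
import jax
import jax.numpy as jnp

trace_log = []  # hypothetical helper: records Python-level (trace-time) calls

def f1(x):
    trace_log.append("f1")          # Python side effect: happens while tracing
    jax.debug.print("f1 executed")  # staged into the computation: runtime only
    return x / x.shape[2]

def f2(x):
    trace_log.append("f2")
    jax.debug.print("f2 executed")
    temp = jax.nn.relu(x)
    return temp / (jnp.sum(temp, axis=-1, keepdims=True) + 1e-5)

x = jnp.arange(6.0).reshape(1, 1, 2, 3) - 2.0  # 4-D like one head's slice
index = jnp.array(1)                            # runtime value: selects f2
out = jax.lax.switch(index, [f1, f2], x)
jax.effects_barrier()                           # wait for pending debug prints

print(trace_log)  # contains both "f1" and "f2": both branches were traced
```

You should see `f2 executed` but not `f1 executed`. When `switch` is called again with the same function objects and input shapes, JAX may reuse its cached trace, so the Python-level prints are likely not repeated; that would be consistent with `f1` and `f2` appearing only once in your output.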
Side note: for better performance, I would also suggest avoiding Python `for` loops over slices of an array, and instead expressing your algorithm using either NumPy-style explicit vectorization or `jax.vmap` to vectorize your code automatically.
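A minimal sketch of two loop-free options, assuming `attn_weights` has shape `(batch, heads, q_len, k_len)` and using a hypothetical per-head selector `head_choice` of shape `(heads,)` (0 for `f1`, 1 for `f2`) in place of slicing `self.alpha`. If one index applies to every head, a single `switch` on the whole tensor replaces the loop, because `f1` and `f2` treat each head independently. For a per-head choice, `jax.vmap` over the head axis replaces the loop. Here `f1` divides by `x.shape[-2]`, which equals `x.shape[2]` for the 4-D slices in the question but stays correct once `vmap` removes the head axis.

```python
import jax
import jax.numpy as jnp

def f1(x):
    # Same as x / x.shape[2] for 4-D (batch, heads, q, k) inputs, and still
    # correct when vmap strips the head axis and x is (batch, q, k).
    return x / x.shape[-2]

def f2(x):
    temp = jax.nn.relu(x)
    return temp / (jnp.sum(temp, axis=-1, keepdims=True) + 1e-5)

branches = [f1, f2]

def loop_version(head_choice, attn_weights):
    # Reference: the per-head Python loop from the question.
    results = []
    for i in range(attn_weights.shape[1]):
        x_i = attn_weights[:, i:i + 1, :, :]
        results.append(jax.lax.switch(head_choice[i], branches, x_i))
    return jnp.concatenate(results, axis=1)

def single_index_version(index, attn_weights):
    # One index for all heads: a single switch on the whole tensor.
    return jax.lax.switch(index, branches, attn_weights)

def vmap_version(head_choice, attn_weights):
    # Map over the head axis (axis 1 of attn_weights, axis 0 of head_choice).
    # The index is batched, so vmap turns switch into a select: both
    # branches are computed for every head and the chosen one is kept.
    per_head = lambda idx, x: jax.lax.switch(idx, branches, x)
    return jax.vmap(per_head, in_axes=(0, 1), out_axes=1)(head_choice, attn_weights)

key = jax.random.PRNGKey(0)
attn_weights = jax.random.normal(key, (2, 4, 3, 5))  # (batch, heads, q_len, k_len)
head_choice = jnp.array([0, 1, 1, 0])                # hypothetical per-head choice

uniform_out = single_index_version(1, attn_weights)
per_head_out = vmap_version(head_choice, attn_weights)
print(jnp.allclose(uniform_out, loop_version(jnp.full(4, 1), attn_weights)))  # True
print(jnp.allclose(per_head_out, loop_version(head_choice, attn_weights)))    # True
```

The trade-off: with a batched (per-head) index, the `vmap` version computes both `f1` and `f2` for every head. If skipping the unused function matters more than vectorization, keep the index unbatched, for example one index shared by all heads as in `single_index_version`.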