Below is an example where a function with a custom-defined vector-Jacobian product (`custom_vjp`) is vmapped. Even for a simple function like this, calling `vjp` fails:
```python
from functools import partial
from typing import Callable

import jax.numpy as jnp
from jax import Array, custom_vjp, vjp, vmap


@partial(custom_vjp, nondiff_argnums=(0,))
def test_func(f: Callable[..., float],
              R: Array
              ) -> float:
    return f(jnp.dot(R, R))


def test_func_fwd(f, primal):
    primal_out = test_func(f, primal)
    residual = 2. * primal * primal_out
    return primal_out, residual


def test_func_bwd(f, residual, cotangent):
    cotangent_out = residual * cotangent
    return (cotangent_out,)


test_func.defvjp(test_func_fwd, test_func_bwd)

# Batch over the rows of R; f is shared across the batch.
test_func = vmap(test_func, in_axes=(None, 0))

if __name__ == "__main__":
    def f(x):
        return x

    # vjp
    primal, f_vjp = vjp(partial(test_func, f),
                        jnp.ones((10, 3))
                        )
    cotangent = jnp.ones(10)
    cotangent_out = f_vjp(cotangent)
    print(cotangent_out[0].shape)
```
The error message says:
```
ValueError: Shape of cotangent input to vjp pullback function (10,) must be the same as the shape of corresponding primal input (10, 3).
```
I think this error message is misleading: the cotangent input should have the same shape as the primal output, which should be (10,) in this case. Still, it's not clear to me why the error occurs.
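For comparison, here is a minimal sketch of the shape rule on an ordinary function, using a hypothetical helper `row_sq_norms` that is not from the question. The pullback returned by `jax.vjp` takes a cotangent shaped like the primal output and returns one cotangent per primal input, each shaped like that input. Read this way, the `(10, 3)` in the error above is the shape JAX recorded for the output of `partial(test_func, f)`, which suggests the output was not (10,) as intended.

```python
import jax
import jax.numpy as jnp


def row_sq_norms(R):
    # Hypothetical example: maps a (10, 3) array to a (10,) array of squared row norms.
    return jnp.sum(R * R, axis=1)


R = jnp.ones((10, 3))
out, pullback = jax.vjp(row_sq_norms, R)

# The cotangent must match the primal *output* shape, (10,) ...
(grad_R,) = pullback(jnp.ones(10))

# ... and the pullback returns a cotangent shaped like the primal *input*, (10, 3).
print(out.shape, grad_R.shape)  # (10,) (10, 3)
```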
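The problem is that in `test_func_fwd` you recursively call `test_func`, but you've overwritten `test_func` in the global namespace with its vmapped version. Inside the batched forward rule, `primal` is a single row of shape (3,), so the vmapped version maps over its three elements and returns an array of shape (3,) instead of a scalar; the batched primal output therefore ends up with the shape (10, 3) that the error message reports. If you leave the original `test_func` unchanged in the global namespace, your code will work as expected: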