I am working with JAX through numpyro. Specifically, I want to use a B-spline function (e.g. as implemented in `scipy.interpolate.BSpline`) to transform different points into a spline, where the input depends on some of the parameters in the model. Thus, I need to be able to differentiate the B-spline in JAX, but only in the input argument and of course not in the knots or the integer order.
I can easily do this with `jax.custom_vjp`, but not when JIT is used, as it is in numpyro. I looked at the following:
- https://github.com/google/jax/issues/1142
- https://jax.readthedocs.io/en/latest/notebooks/Custom_derivative_rules_for_Python_code.html
and it seems like the best hope is to use a callback, though I cannot entirely figure out how that would work. At https://jax.readthedocs.io/en/latest/jax.experimental.host_callback.html#using-call-to-call-a-jax-function-on-another-device-with-reverse-mode-autodiff-support, the TensorFlow example with reverse-mode autodiff does not seem to use JIT.
The example

Here is Python code that works without JIT (see the `b_spline_basis()` function):
```python
from functools import partial

import jax
import numpy as np
from numpy import typing as npt
from scipy.interpolate import BSpline

doubleArray = npt.NDArray[np.double]


# see
#   https://stackoverflow.com/q/74699053/5861244
#   https://en.wikipedia.org/wiki/B-spline#Derivative_expressions
def _b_spline_deriv_inner(spline: BSpline, deriv_basis: doubleArray) -> doubleArray:  # type: ignore[no-any-unimported]
    # apply the B-spline derivative recursion to a basis of order k - 1
    out = np.zeros((deriv_basis.shape[0], deriv_basis.shape[1] - 1))

    for col_index in range(out.shape[1] - 1):
        scale = spline.t[col_index + spline.k + 1] - spline.t[col_index + 1]
        if scale != 0:
            out[:, col_index] = -deriv_basis[:, col_index + 1] / scale

    for col_index in range(1, out.shape[1]):
        scale = spline.t[col_index + spline.k] - spline.t[col_index]
        if scale != 0:
            out[:, col_index] += deriv_basis[:, col_index] / scale

    return float(spline.k) * out


def _b_spline_eval(spline: BSpline, x: doubleArray, deriv: int) -> doubleArray:  # type: ignore[no-any-unimported]
    if deriv == 0:
        # as a plain ndarray, since .todense() may yield a np.matrix
        return np.asarray(spline.design_matrix(x=x, t=spline.t, k=spline.k).todense())
    elif spline.k <= 0:
        return np.zeros((x.shape[0], spline.t.shape[0] - spline.k - 1))

    return _b_spline_deriv_inner(
        spline=spline,
        deriv_basis=_b_spline_eval(
            BSpline(t=spline.t, k=spline.k - 1, c=np.zeros(spline.c.shape[0] + 1)), x=x, deriv=deriv - 1
        ),
    )


@partial(jax.custom_vjp, nondiff_argnums=(0, 1, 2))
def b_spline_basis(knots: doubleArray, order: int, deriv: int, x: doubleArray) -> doubleArray:
    return _b_spline_eval(spline=BSpline(t=knots, k=order, c=np.zeros(order + knots.shape[0] - 1)), x=x, deriv=deriv)[
        :, 1:
    ]


def b_spline_basis_fwd(knots: doubleArray, order: int, deriv: int, x: doubleArray) -> tuple[doubleArray, doubleArray]:
    spline = BSpline(t=knots, k=order, c=np.zeros(order + knots.shape[0] - 1))
    return (
        # the primal output
        _b_spline_eval(spline=spline, x=x, deriv=deriv)[:, 1:],
        # the residual: the next-higher derivative, needed by the backward pass
        _b_spline_eval(spline=spline, x=x, deriv=deriv + 1)[:, 1:],
    )


def b_spline_basis_bwd(
    knots: doubleArray, order: int, deriv: int, partials: doubleArray, grad: doubleArray
) -> tuple[doubleArray]:
    # each row of the basis depends on a single entry of x, so the VJP reduces
    # to a row-wise inner product with the cotangent
    return (jax.numpy.sum(partials * grad, axis=1),)


b_spline_basis.defvjp(b_spline_basis_fwd, b_spline_basis_bwd)


if __name__ == "__main__":
    # tests
    knots = np.array([0, 0, 0, 0, 0.25, 1, 1, 1, 1])
    x = np.array([0.1, 0.5, 0.9])
    order = 3

    def test_jax(basis: doubleArray, partials: doubleArray, deriv: int) -> None:
        weights = jax.numpy.arange(1, basis.shape[1] + 1)

        def test_func(x: doubleArray) -> doubleArray:
            return jax.numpy.sum(jax.numpy.dot(b_spline_basis(knots=knots, order=order, deriv=deriv, x=x), weights))  # type: ignore[no-any-return]

        assert np.allclose(test_func(x), np.sum(np.dot(basis, weights)))
        assert np.allclose(jax.grad(test_func)(x), np.dot(partials, weights))

    deriv0 = np.transpose(
        np.array(
            [
                0.684, 0.166666666666667, 0.00133333333333333,
                0.096, 0.444444444444444, 0.0355555555555555,
                0.004, 0.351851851851852, 0.312148148148148,
                0, 0.037037037037037, 0.650962962962963,
            ]
        ).reshape(-1, 3)
    )
    deriv1 = np.transpose(
        np.array(
            [
                2.52, -1, -0.04,
                1.68, -0.666666666666667, -0.666666666666667,
                0.12, 1.22222222222222, -2.29777777777778,
                0, 0.444444444444444, 3.00444444444444,
            ]
        ).reshape(-1, 3)
    )
    test_jax(deriv0, deriv1, deriv=0)

    deriv2 = np.transpose(
        np.array(
            [
                -69.6, 4, 0.8,
                9.6, -5.33333333333333, 5.33333333333333,
                2.4, -2.22222222222222, -15.3777777777778,
                0, 3.55555555555556, 9.24444444444445,
            ]
        ).reshape(-1, 3)
    )
    test_jax(deriv1, deriv2, deriv=1)

    deriv3 = np.transpose(
        np.array(
            [
                504, -8, -8,
                -144, 26.6666666666667, 26.6666666666667,
                24, -32.8888888888889, -32.8888888888889,
                0, 14.2222222222222, 14.2222222222222,
            ]
        ).reshape(-1, 3)
    )
    test_jax(deriv2, deriv3, deriv=2)
```
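To make the failure concrete, here is a small check using the definitions above: eager `jax.grad` works, but the same computation fails under `jax.jit`, because the scipy code then receives abstract tracers instead of concrete values (the exact exception may vary by JAX version):

```python
import jax
import jax.numpy as jnp

weights = jnp.arange(1.0, 5.0)  # one weight per retained basis column

def loss(x):
    return jnp.sum(jnp.dot(b_spline_basis(knots=knots, order=order, deriv=0, x=x), weights))

print(jax.grad(loss)(x))  # works: custom_vjp covers eager reverse mode

try:
    jax.jit(jax.grad(loss))(x)
except Exception as err:  # typically a TracerArrayConversionError
    print("under jit:", type(err).__name__)
```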
The best way to accomplish this is probably a combination of `custom_jvp` and `jax.pure_callback`. Unfortunately, `pure_callback` is relatively new and does not have great documentation yet, but you can find examples of its use in the JAX user forums (for example here). Copied here for posterity, this is an example of computing the sine and cosine via numpy callbacks in JIT-compatible code, with custom JVP rules for autodiff.
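A minimal sketch of that pattern (assuming plain `numpy` on the host side; `jax.ShapeDtypeStruct` declares the result shape and dtype that `pure_callback` needs at trace time):

```python
import jax
import jax.numpy as jnp
import numpy as np

@jax.custom_jvp
def sin(x):
    # run np.sin on the host; under jit, only the shape/dtype is needed to trace
    return jax.pure_callback(np.sin, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

@jax.custom_jvp
def cos(x):
    return jax.pure_callback(np.cos, jax.ShapeDtypeStruct(x.shape, x.dtype), x)

@sin.defjvp
def sin_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    return sin(x), cos(x) * x_dot

@cos.defjvp
def cos_jvp(primals, tangents):
    (x,), (x_dot,) = primals, tangents
    return cos(x), -sin(x) * x_dot

x = jnp.linspace(0.0, 1.0, 5)
print(jax.jit(sin)(x))                      # works under jit
print(jax.grad(lambda t: sin(t).sum())(x))  # reverse mode via the linear JVP
```

Because the tangent rule is linear in `x_dot`, JAX can transpose it automatically, which is why `jax.grad` works even though only a forward-mode rule was defined.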
Note that since `pure_callback` operates by sending data back to the host, it will generally have a lot of overhead on accelerators like GPU and TPU, although in a single-CPU setting this kind of approach can perform well.
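Applying the same recipe to the question's function is then mostly mechanical. The sketch below uses hypothetical names `b_spline_basis_cb` and `_basis_np`, reuses `_b_spline_eval` from the question's example, and keeps `knots`/`order`/`deriv` non-differentiable, as in the question's `custom_vjp` version; the JVP is the next-higher-derivative basis scaled by the input tangent:

```python
from functools import partial

import jax
import numpy as np
from scipy.interpolate import BSpline


def _basis_np(knots, order, deriv, x):
    # plain numpy/scipy evaluation, executed on the host by the callback
    spline = BSpline(t=knots, k=order, c=np.zeros(order + knots.shape[0] - 1))
    out = _b_spline_eval(spline=spline, x=np.asarray(x, dtype=np.float64), deriv=deriv)[:, 1:]
    return np.asarray(out, dtype=x.dtype)  # match the dtype promised to pure_callback


@partial(jax.custom_jvp, nondiff_argnums=(0, 1, 2))
def b_spline_basis_cb(knots, order, deriv, x):
    # one basis column per knot, minus order + 1, minus the dropped first column
    out_shape = jax.ShapeDtypeStruct((x.shape[0], knots.shape[0] - order - 2), x.dtype)
    return jax.pure_callback(partial(_basis_np, knots, order, deriv), out_shape, x)


@b_spline_basis_cb.defjvp
def b_spline_basis_cb_jvp(knots, order, deriv, primals, tangents):
    (x,), (x_dot,) = primals, tangents
    # each output row depends only on the matching entry of x, so the tangent is
    # the next-higher-derivative basis scaled row-wise by x_dot
    return (
        b_spline_basis_cb(knots, order, deriv, x),
        b_spline_basis_cb(knots, order, deriv + 1, x) * x_dot[:, None],
    )
```

With that in place, `jax.jit(jax.grad(...))` of a function built on `b_spline_basis_cb` should trace, subject to the host-transfer overhead caveat above.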