EDIT: I just found ott-jax, which looks like it might be what I need, but if possible I'd still like to know what I did wrong with jaxopt below!
Original: I'm trying to solve an optimal transport problem, and after following this great blog post I have a working version in numpy/scipy (comments removed for brevity).
While trying to get a jax version of this working, I came across this issue and looked at the jaxopt library, but I have not been able to find an implementation of linprog or of linear programming (LP) in general. I believe LP is a special case of quadratic programming (QP), which jaxopt does implement, but I have not been able to replicate the numpy version successfully. Any idea where I am going wrong, or how else I can solve this?
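For reference, the LP I am trying to replicate is the discrete optimal transport problem, which (as I understand it from the blog post) is

    \min_{T \ge 0} \sum_{ij} C_{ij} T_{ij}
    \quad \text{subject to} \quad \sum_j T_{ij} = p_i, \qquad \sum_i T_{ij} = q_j

prep_arrays below builds the flattened cost vector C.ravel() and encodes the marginal constraints as an equality system A @ T.ravel() = b, dropping one redundant row since p and q both sum to 1.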
import jax
import jax.numpy as jnp
import jaxopt
import numpy as np
from scipy.optimize import linprog
from scipy.spatial.distance import pdist, squareform
from scipy.special import softmax
jax.config.update('jax_platform_name', 'cpu')
def prep_arrays(x, p, q):
    n, d = x.shape
    # pairwise squared-Euclidean cost matrix
    C = squareform(pdist(x, metric="sqeuclidean"))
    # build one equality constraint per marginal of the flattened plan T.ravel()
    Ap, Aq = [], []
    z = np.zeros((n, n))
    z[:, 0] = 1
    for i in range(n):
        Ap.append(z.ravel())
        Aq.append(z.transpose().ravel())
        z = np.roll(z, 1, axis=1)
    # drop the last row: it is redundant because p and q both sum to 1
    A = np.row_stack((Ap, Aq))[:-1]
    b = np.concatenate((p, q))[:-1]
    return n, C, A, b
def demo_wasserstein(x, p, q):
    n, C, A, b = prep_arrays(x, p, q)
    # linprog's bounds default to (0, None), so T >= 0 is enforced implicitly
    result = linprog(C.ravel(), A_eq=A, b_eq=b)
    T = result.x.reshape((n, n))
    return np.sqrt(np.sum(T * C)), T
def jax_attempt_1(x, p, q):
    n, C, A, b = prep_arrays(x, p, q)
    C, A, b = jnp.array(C), jnp.array(A), jnp.array(b)

    def matvec_Q(params_Q, u):
        del params_Q
        return jnp.zeros_like(u)  # no quadratic term, so Q is just 0

    def matvec_A(params_A, u):
        return jnp.dot(params_A, u)

    # setting l == u == b turns the box constraint l <= Ax <= u into Ax = b
    hyper_params = dict(params_obj=(None, C.ravel()), params_eq=A, params_ineq=(b, b))
    osqp = jaxopt.BoxOSQP(matvec_Q=matvec_Q, matvec_A=matvec_A)
    sol, state = osqp.run(None, **hyper_params)
    T = sol.primal[0].reshape((n, n))
    return np.sqrt(np.sum(T * C)), np.array(T)
def jax_attempt_2(x, p, q):
    n, C, A, b = prep_arrays(x, p, q)
    C, A, b = jnp.array(C), jnp.array(A), jnp.array(b)

    def fun(T, params_obj):
        _, c = params_obj
        return jnp.sum(T * c)

    def matvec_A(params_A, u):
        return jnp.dot(params_A, u)

    # solver = jaxopt.EqualityConstrainedQP(fun=fun, matvec_A=matvec_A)
    solver = jaxopt.OSQP(fun=fun, matvec_A=matvec_A)
    init_T = jnp.zeros((n, n))
    hyper_params = dict(params_obj=(None, C.ravel()), params_eq=(A, b), params_ineq=None)
    init_params = solver.init_params(init_T.ravel(), **hyper_params)
    sol, state = solver.run(init_params=init_params, **hyper_params)
    T = sol.primal.reshape((n, n))
    return np.sqrt(np.sum(T * C)), np.array(T)
if __name__ == '__main__':
    np.random.seed(0)
    n = 16
    q_values = np.random.normal(size=n)
    p = np.full(n, 1. / n)
    q = softmax(q_values)
    x = np.random.uniform(-1., 1., (n, 1))
    dist_numpy, plan_numpy = demo_wasserstein(x, p, q)
    dist_jax_1, plan_jax_1 = jax_attempt_1(x, p, q)
    dist_jax_2, plan_jax_2 = jax_attempt_2(x, p, q)
    print(f'numpy: dist {dist_numpy}, min {plan_numpy.min()}, max {plan_numpy.max()}')
    print(f'jax_1: dist {dist_jax_1}, min {plan_jax_1.min()}, max {plan_jax_1.max()}')
    print(f'jax_2: dist {dist_jax_2}, min {plan_jax_2.min()}, max {plan_jax_2.max()}')

# numpy: dist 0.18283759367232585, min 0.0, max 0.06250000000000001
# jax_1: dist nan, min -395690848.0, max 453536128.0
# jax_2: dist nan, min -461479360.0, max 528943168.0
ott-jax is just what I needed. While it uses the Sinkhorn algorithm by default, and the result is therefore an approximation, it is more than adequate for my needs. I'm sure I can improve the performance with config changes as well.
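In case it helps anyone else, here is roughly what I ended up with. This is a minimal sketch based on my reading of the ott docs, so treat the module paths, the helper name, and the epsilon value as assumptions that may need adjusting for your ott version:

import jax.numpy as jnp
import numpy as np
from ott.geometry import pointcloud
from ott.problems.linear import linear_problem
from ott.solvers.linear import sinkhorn

def ott_wasserstein(x, p, q, epsilon=1e-2):
    # PointCloud defaults to squared-Euclidean costs, matching pdist above;
    # epsilon is the entropic regularization strength (assumed value)
    geom = pointcloud.PointCloud(jnp.array(x), jnp.array(x), epsilon=epsilon)
    prob = linear_problem.LinearProblem(geom, a=jnp.array(p), b=jnp.array(q))
    out = sinkhorn.Sinkhorn()(prob)
    T = np.array(out.matrix)  # entropy-regularized transport plan
    return np.sqrt(np.sum(T * np.array(geom.cost_matrix))), T

Since Sinkhorn solves an entropy-regularized version of the LP, smaller epsilon values should get closer to the linprog answer, at the cost of more iterations.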