I am trying to solve many instances of the same ODE across different constants.
Here is my code:
import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp
class SimpleODEModule(tf.Module):
def __init__(self, name=None):
super().__init__(name=name)
def __call__(self, t_initial, x_initial, solution_times, parameters):
with tf.GradientTape() as tape:
tape.watch(parameters)
solution = tfp.math.ode.BDF().solve(
self.ode_system,
t_initial,
x_initial,
solution_times,
constants={'parameters': parameters})
tape.gradient(solution.states, parameters)
return solution.states
def ode_system(self, t, x, parameters):
a = parameters[:, 0]
b = parameters[:, 1]
dx = tf.add(tf.multiply(b, tf.exp(tf.multiply(a, t))), tf.multiply(a, x))
print(dx)
return dx
constants = tf.constant([[1.0, 2.0],[3.0, 4.0], [5.0, 6.0]], dtype=tf.float32)
t_initial = tf.reshape(tf.cast(tf.repeat(0.0, constants.shape[0]), dtype=tf.float32), (1,constants.shape[0]))
x_initial = tf.reshape(tf.cast(tf.repeat(0.0, constants.shape[0]), dtype=tf.float32), (1,constants.shape[0]))
solution_times = tf.cast(tf.repeat(1.0, constants.shape[0]), dtype=tf.float32)
simple_ode = SimpleODEModule()
# This causes an error deep down int tfp.ode
# The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
simple_ode(t_initial, x_initial, solution_times, constants)
# Returns the expected output x(1.0) for each set of constants
simple_ode.ode_system(t_initial, x_initial, constants)
I am new to tensorflow, so I imagine I am not creating the correctly shaped tensors somewhere. I would expect this to "just work", iterating over the dimensions of the tensors to solve the ODE multiple times for each set of constants. Any help is appreciated.
I found a solution. Although I am not sure it is the best one. Instead of subclassing
tf.Module
I subclassedtf.keras.layers.Layer
and it "just worked". Here is the change in the code:Thanks to anyone who looked at this or attempted a solution.