Solve ODE in tensorflow with tensor inputs

426 views Asked by At

I am trying to solve many instances of the same ODE across different constants.

Here is my code:

import numpy as np
import tensorflow as tf
import tensorflow_probability as tfp

class SimpleODEModule(tf.Module):
    def __init__(self, name=None):
        super().__init__(name=name)
        
    def __call__(self, t_initial, x_initial, solution_times, parameters):
        with tf.GradientTape() as tape:
            tape.watch(parameters)
            solution = tfp.math.ode.BDF().solve(
                                    self.ode_system, 
                                    t_initial,
                                    x_initial,
                                    solution_times,
                                    constants={'parameters': parameters})
            tape.gradient(solution.states, parameters)
        return solution.states
    
    def ode_system(self, t, x, parameters):
        a = parameters[:, 0]
        b = parameters[:, 1]
        dx = tf.add(tf.multiply(b, tf.exp(tf.multiply(a, t))), tf.multiply(a, x))
        print(dx)
        return dx

constants = tf.constant([[1.0, 2.0],[3.0, 4.0], [5.0, 6.0]], dtype=tf.float32)
t_initial = tf.reshape(tf.cast(tf.repeat(0.0, constants.shape[0]), dtype=tf.float32), (1,constants.shape[0]))
x_initial = tf.reshape(tf.cast(tf.repeat(0.0, constants.shape[0]), dtype=tf.float32), (1,constants.shape[0]))
solution_times = tf.cast(tf.repeat(1.0, constants.shape[0]), dtype=tf.float32)

simple_ode = SimpleODEModule()

# This causes an error deep down int tfp.ode
# The truth value of an array with more than one element is ambiguous. Use a.any() or a.all()
simple_ode(t_initial, x_initial, solution_times, constants)

# Returns the expected output x(1.0) for each set of constants
simple_ode.ode_system(t_initial, x_initial, constants)

I am new to tensorflow, so I imagine I am not creating the correctly shaped tensors somewhere. I would expect this to "just work", iterating over the dimensions of the tensors to solve the ODE multiple times for each set of constants. Any help is appreciated.

1

There are 1 answers

0
jwilley44 On BEST ANSWER

I found a solution. Although I am not sure it is the best one. Instead of subclassing tf.Module I subclassed tf.keras.layers.Layer and it "just worked". Here is the change in the code:

class ODELayer(tf.keras.layers.Layer):
    def __init__(self, num_outputs, ode_system):
        super(ODELayer, self).__init__()
        self.num_outputs = num_outputs
        self.ode_system = ode_system

    def call(self, input_tensor):
        return tf.map_fn(self.solve_ode, input_tensor)
    
    def solve_ode(self, parameters):
        with tf.GradientTape() as tape:
            tape.watch(parameters)
            solution = tfp.math.ode.BDF().solve(
                    self.ode_system,
                    0.0, 0.0, [1.0],
                    constants={'parameters': parameters}
                )
            tape.gradient(solution.states, parameters)
        return solution.states
    
def simple_ode(t, x, parameters):
    a = parameters[0]
    b = parameters[1]
    dx = tf.add(tf.multiply(b, tf.exp(tf.multiply(a, t))), tf.multiply(a, x))
    return dx

Thanks to anyone who looked at this or attempted a solution.