How do I activate the adaptive solver in torchdiffeq?

27 views Asked by At

In the documentation of the well-known library torchdiffeq, the authors provide a user guide explaining how to use the adaptive solver. However, when I follow the instructions (even when passing "options"), the step size in my experiment does not change over time. I wonder whether I did something wrong, and I hope someone can help me fix my code.

Here is my code:

# Install the dependency beforehand, e.g. `pip install torchdiffeq` in a shell
# or `%pip install torchdiffeq` in a Jupyter/Colab cell; a bare shell command
# is not valid Python syntax in a .py file.
import numpy as np
import torch
# NOTE(review): odeint_adjoint, img and Image appear unused in the code shown.
from torchdiffeq import odeint, odeint_adjoint
from matplotlib import pyplot as plt
import matplotlib.image as img
from IPython.display import Image


def odefunc(t, z):
    """Right-hand side of dz/dt = -2 t^3 + 12 t^2 - 20 t + 8.5.

    The derivative depends on time only: ``z`` is accepted because ``odeint``
    passes the current state, but it is never read. The arithmetic is
    element-wise, so ``t`` may be a Python number, a NumPy array or a torch
    tensor.

    Returns:
        The derivative evaluated at ``t`` (same shape as ``t``).
    """
    rate = -2 * t**3
    rate = rate + 12 * t**2
    rate = rate - 20 * t
    return rate + 8.5

# Exact solution for comparison
def exact_solution(t, z0):
    """Closed-form solution of dz/dt = -2 t^3 + 12 t^2 - 20 t + 8.5.

    Integrating the right-hand side gives
    z(t) = -0.5 t^4 + 4 t^3 - 10 t^2 + 8.5 t + z0,
    which is only valid when ``z0`` is the value of z at t = 0.

    Args:
        t: Time(s) at which to evaluate the solution (number, array or tensor).
        z0: Initial value z(0); broadcast against ``t``.

    Returns:
        The exact solution at ``t``.
    """
    # Use z0 rather than a hard-coded 1 so the reference curve stays correct
    # when the initial condition is changed.
    return -0.5*t**4 + 4*t**3 - 10*t**2 + 8.5*t + z0

# Set up initial condition
z0 = torch.tensor([1.0])  # Initial value z(0) = 1.0

# Time points at which the solution is *reported*. For an adaptive solver such
# as dopri5 these are NOT its integration steps: the solver chooses its own
# internal step sizes and interpolates the solution onto these times. Plotting
# only at `t` therefore always shows evenly spaced markers, which is why the
# adaptive step size appears never to change.
t = torch.linspace(0., 4., 9)

# Solve the ODE with a fixed step-size solver. With no `step_size` option,
# torchdiffeq's fixed-grid solvers step along the times in `t`.
z_fixed = odeint(odefunc, z0, t, method='euler')

# Solve the ODE with the adaptive solver, recording every time at which it
# evaluates the right-hand side so the chosen steps become visible. These
# times include intermediate Runge-Kutta stage times (and possibly extra
# evaluations used to pick the first step), not only accepted step boundaries.
# This RHS depends on t only and is a cubic, so dopri5's embedded error
# estimate is zero up to rounding: the solver has little reason to shrink its
# steps.
# NOTE(review): z0 is float32 while torchdiffeq's default tolerances
# (rtol=1e-7, atol=1e-9) are near float32 precision, so rounding noise may
# influence the step size; consider a float64 z0 or looser rtol/atol.
adaptive_eval_times = []


def recording_odefunc(t_now, z):
    """Record the evaluation time ``t_now``, then delegate to ``odefunc``."""
    adaptive_eval_times.append(float(t_now))
    return odefunc(t_now, z)


z_adaptive = odeint(recording_odefunc, z0, t, method='dopri5')
eval_t = np.array(sorted(set(adaptive_eval_times)))
print(f"dopri5 made {len(adaptive_eval_times)} RHS evaluations at t = {eval_t}")

exact_z = exact_solution(t, z0)

# Direction field of dz/dt = f(t) on a (t, z) grid, scaled to unit-length arrows
nx, ny = .25, .4
x = np.arange(0, 4.2, nx)
y = np.arange(1, 7.2, ny)
X, Y = np.meshgrid(x, y)

dy = odefunc(X, z0)  # odefunc ignores its state argument; z0 is a placeholder
dx = np.ones(dy.shape)

# Compute the arrow length once and reuse it for both components.
length = np.sqrt(dx**2 + dy**2)
dyu = dy / length
dxu = dx / length


plt.quiver(X, Y, dxu, dyu, color='grey')  # with arrows
#plt.quiver(X,Y,dxu,dyu, color='grey', headaxislength=3, headlength=0, pivot='middle', units='xy', scale=5, linewidth=.2, width=.02, headwidth=1)  # without arrows
# Empty scatter whose only purpose is a legend entry for the arrows.
plt.scatter([], [], marker=r'$\longrightarrow$', c="grey", s=200, label="Vectorized gradient vector field")

# Plot the results
plt.plot(t.numpy(), z_fixed.numpy(), marker='o', label='Euler')
plt.plot(t.numpy(), z_adaptive.numpy(), marker='o', label='Adaptive solver')
plt.plot(t.numpy(), exact_z.numpy(), label='Exact solution')
# Rug of the dopri5 evaluation times along the bottom edge of the direction
# field; their uneven spacing shows the step sizes the solver actually chose.
plt.plot(eval_t, np.full_like(eval_t, y[0]), linestyle='none', marker='|',
         markersize=15, color='red', label='dopri5 RHS evaluation times')
plt.xlabel('t')
plt.ylabel('z(t)')
plt.title('Comparison of Numerical and Exact Solutions')
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.grid(True)
plt.show()

Result: [plot image not included — the original image link was lost]

0

There are 0 answers