How to efficiently compute the Hessian matrix of a Deep Neural Network?


I am using TF2.11.

In order to gain a deeper understanding of PINNs, I want to compute the Hessian matrix of the loss with respect to my PINN parameters. My toy case is the 2D Poisson equation $\Delta u = f$.

I have three versions, but only two of them work, and only for small models. With bigger models the code crashes.

General code to set up the problem

The commented code is below:

# # [2D Poisson](2d-poisson.ipynb)
# 
# This notebook will show new users how to:
# - use Rectangle
# - use Fully connected Neural Network (FNN)
# - define a boundary condition
# - use a custom implemented PDE
# - use PINNacle's default training loop by using the Trainer object
# 
# ## Problem setup
# 
# $\forall (x,y)\in [-1,1]^2$:
# $$\Delta u = \partial_{xx}u + \partial_{yy}u = f$$
# 
# with $f(x,y) = -\frac{\pi^2}{2}\sin(\pi \frac{x+1}{2})\sin(\pi \frac{y+1}{2})$.
# 
# The exact solution is $u(x,y) = \sin(\pi \frac{x+1}{2})\sin(\pi \frac{y+1}{2})$.

# In[1]:


import os
os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'

# tensorflow imports
import tensorflow as tf
from tensorflow import keras
from copy import deepcopy as dcp
import numpy as np
import math as m
import matplotlib.pyplot as plt
from tensorflow.python.ops.parallel_for.gradients import jacobian
pi = np.pi

#--------------------------- seed init ----------------------------------
# for reproducibility
seed = 1234
keras.utils.set_random_seed(seed)


# ## Geometry object definition

# The space domain is the square $[-1,1]^2$. In PINNacle it would be sampled with the Rectangle class (whose sample method takes args_interior and args_boundary, p1 and p2 being the lower/upper corners of the rectangle); here the interior and boundary points are simply drawn with tf.random.uniform.

# In[2]:


Xpde = tf.random.uniform(shape=[1000,2])
Xpde = 2*Xpde-1

Xdirichlet = 2*tf.random.uniform(shape=[1000,1])-1   # rescaled to [-1,1], like Xpde
Xdirichlet = tf.concat([Xdirichlet, tf.ones_like(Xdirichlet)], axis=1)  # top edge, y = 1
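
# Note: Xdirichlet above only covers the top edge y = 1. A sketch (not in the
# original; only needed if the full boundary is wanted) sampling all four edges:
t = 2*tf.random.uniform(shape=[250,1]) - 1
ones = tf.ones_like(t)
Xdirichlet_full = tf.concat([
    tf.concat([t,  ones], axis=1),   # top    edge, y =  1
    tf.concat([t, -ones], axis=1),   # bottom edge, y = -1
    tf.concat([ ones, t], axis=1),   # right  edge, x =  1
    tf.concat([-ones, t], axis=1),   # left   edge, x = -1
], axis=0)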


# ## Model definition

# In[3]:


# works if 50 is replaced by 10
model = keras.Sequential(
    [
        keras.layers.Dense(50, activation="tanh", name="layer1"),
        keras.layers.Dense(50, activation="tanh", name="layer2"),
        keras.layers.Dense(50, activation="tanh", name="layer4"),
        keras.layers.Dense(1, name="layer3"),
    ]
)
model(keras.Input(shape=(2,)))
model.summary()
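
# (Added check, not in the original) the Hessian has n_params**2 entries, which
# is worth estimating before trying to materialise it: here 5301 parameters,
# i.e. ~2.8e7 entries (~112 MB in float32), and the intermediate backward-pass
# computations typically cost considerably more memory than the result itself.
n_params = sum(int(tf.size(w)) for w in model.trainable_weights)
print(n_params, n_params**2)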


# ## Custom PDE definition

# In[4]:


@tf.function
def f(X):
    """RHS aka source term"""
    x , y = X[:,0:1] , X[:,1:2]
    return -pi**2/2*tf.sin(pi*(x+1)/2)*tf.sin(pi*(y+1)/2)


# The custom PDE will be defined as a class having a residuals function.

# In[5]:


class MyPDE:
    @tf.function
    def residuals(self, X ,model):
        """returns the residuals of the PDE at each data point"""
        
        x , y = X[:,0:1] , X[:,1:2]
        
        # making predictions 
        u = model(tf.concat([x,y],axis=1))
        
        # first derivatives
        u_x = tf.gradients(u,x)[0]
        u_y = tf.gradients(u,y)[0]
        
        # second derivatives
        u_xx = tf.gradients(u_x,x)[0]
        u_yy = tf.gradients(u_y,y)[0]
        
        # values of source term
        fvals = f(X)
    
        return u_xx + u_yy - fvals    

pde = MyPDE()
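
# (Added sanity check, not in the original) the residuals should vanish for the
# exact solution, up to float32 round-off:
def u_exact(X):
    x, y = X[:, 0:1], X[:, 1:2]
    return tf.sin(pi*(x+1)/2) * tf.sin(pi*(y+1)/2)

print(float(tf.reduce_max(tf.abs(pde.residuals(Xpde, u_exact)))))  # small, ~1e-6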


# ## Boundary condition

# Let's define the Dirichlet condition:

# In[6]:


def on_boundary(X):
    """returns True if the point is on the boundary"""
    l1 = np.isclose(np.abs(X[:,0]), 1.0)
    l2 = np.isclose(np.abs(X[:,1]), 1.0)
    return np.logical_or(l1,l2)

def func(X):
    """returns the value of the solution on the boundary"""
    return tf.zeros([X.shape[0],1])

@tf.function
def dirichlet_residuals(X, model):
    return func(X) - model(X)


# The function *on_boundary* will be used later to have access to points on the boundary so that they can be separated from the points inside the domain.
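
# For instance (a quick check, not in the original notebook), applying it to the
# points sampled above separates the two sets:
X = tf.concat([Xpde, Xdirichlet], axis=0).numpy()
mask = on_boundary(X)
X_interior, X_boundary = X[~mask], X[mask]
print(X_interior.shape, X_boundary.shape)  # ~(1000, 2) and ~(1000, 2)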

# In[8]:


mse = keras.losses.MeanSquaredError()
@tf.function
def get_losses(model, Xpde, pde, Xdirichlet, dirichlet_residuals):
    """compute loss of PDE + Dirichlet """
    loss_pde = mse(0, pde.residuals(Xpde, model))
    loss_dirichlet = mse(0, dirichlet_residuals(Xdirichlet, model))
    return loss_pde, loss_dirichlet
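
A quick smoke test (not in the original) that the loss pipeline runs before attempting any Hessian:

loss_pde, loss_dirichlet = get_losses(model, Xpde, pde, Xdirichlet, dirichlet_residuals)
print(float(loss_pde), float(loss_dirichlet))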

Functions to compute the Hessian

Version 1

def get_hessian(model, Xpde, pde, Xdirichlet, dirichlet_residuals):
    """First version: nested GradientTapes, one tape.gradient call per Hessian row.

    Note: pretty slow, and it does not work for big neural networks.

    Possible upgrades:
    - use tf.function (graph mode)
    - optimize memory
    - use tape.jacobian instead of the Python loop

    Error: crashes the notebook without any error message.
    """
    with tf.GradientTape(persistent=True) as tape:
        with tf.GradientTape() as tape2:
            losses = get_losses(model, Xpde, pde, Xdirichlet, dirichlet_residuals)
            loss_total = tf.reduce_sum(losses)

        # first-order gradients, flattened into a single column vector; taken
        # outside tape2's context but inside tape's, so the outer tape records
        # the backward pass for the second differentiation
        grads = tape2.gradient(loss_total, model.trainable_weights,
                               unconnected_gradients=tf.UnconnectedGradients.ZERO)
        flat_grads = tf.concat([tf.reshape(g, (-1, 1)) for g in grads], axis=0)

    # one Hessian row per scalar gradient entry, computed outside the tape
    # context (calling tape.gradient inside its own context is much slower)
    hessian = []
    for g in flat_grads:
        row = tape.gradient(g, model.trainable_weights,
                            unconnected_gradients=tf.UnconnectedGradients.ZERO)
        hessian.append(tf.concat([tf.reshape(gg, (1, -1)) for gg in row], axis=1))

    del tape  # release the persistent tape's memory
    return tf.concat(hessian, axis=0)
res = get_hessian(model, Xpde, pde, Xdirichlet, dirichlet_residuals)
print(res.shape)
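
As the docstring already hints, the row-by-row Python loop can be replaced by a single tape.jacobian call, which vectorises the second backward pass with parallel_for. A sketch (get_hessian_jac is a hypothetical name; the same O(P^2) memory caveats apply):

def get_hessian_jac(model, Xpde, pde, Xdirichlet, dirichlet_residuals):
    with tf.GradientTape() as tape:
        with tf.GradientTape() as tape2:
            loss = tf.reduce_sum(get_losses(model, Xpde, pde, Xdirichlet, dirichlet_residuals))
        grads = tape2.gradient(loss, model.trainable_weights,
                               unconnected_gradients=tf.UnconnectedGradients.ZERO)
        flat_grads = tf.concat([tf.reshape(g, (-1,)) for g in grads], axis=0)
    # one (n_params, ...) Jacobian block per weight tensor
    blocks = tape.jacobian(flat_grads, model.trainable_weights,
                           unconnected_gradients=tf.UnconnectedGradients.ZERO)
    n = flat_grads.shape[0]
    return tf.concat([tf.reshape(b, (n, -1)) for b in blocks], axis=1)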

Version 2


@tf.function
def get_hessian(model, Xpde, pde, Xdirichlet, dirichlet_residuals):
    """Second version: tf.gradients, then the (private) jacobian helper from
    tensorflow.python.ops.parallel_for.gradients.

    Note: pretty slow, and it does not work for big neural networks.

    Possible upgrades / error: same as the previous version.
    """
    loss = tf.reduce_sum(get_losses(model, Xpde, pde, Xdirichlet, dirichlet_residuals))

    # first-order gradients, flattened into a single column vector
    grads = tf.gradients(loss, model.trainable_weights,
                         unconnected_gradients=tf.UnconnectedGradients.ZERO)
    grads = tf.concat([tf.reshape(gg, (-1, 1)) for gg in grads], axis=0)

    # one Jacobian block per weight tensor; reshape each to (n_params, block_size)
    hessian = jacobian(grads, model.trainable_weights)
    hessian = tf.concat([tf.reshape(h_, (tf.shape(grads)[0], -1)) for h_ in hessian], axis=1)
    return hessian
res = get_hessian(model, Xpde, pde, Xdirichlet, dirichlet_residuals)
print(res.shape)

Version 3

@tf.function
def get_hessian(model, Xpde, pde, Xdirichlet, dirichlet_residuals):
    """Third try, not working, with tf.hessians.

    Error: ValueError: None values not supported. (Likely because tf.hessians
    differentiates twice with plain tf.gradients and has no
    unconnected_gradients option, so an unconnected gradient comes back as
    None. Note also that tf.hessians only returns the per-variable diagonal
    blocks d^2 loss / dw_i^2, not the full cross-variable Hessian.)

    Possible upgrades:
    - a TF example of tf.hessians with neural network models
    """
    loss = tf.reduce_sum(get_losses(model, Xpde, pde, Xdirichlet, dirichlet_residuals))
    return tf.hessians(loss, model.trainable_weights)

res = get_hessian(model, Xpde, pde, Xdirichlet, dirichlet_residuals)
print([h.shape for h in res])  # tf.hessians returns a list of blocks, not a tensor

Possible problems

  • insufficient RAM
  • tracing/taping not optimized in the code (graph vs. eager mode)

Possible fixes

  • upgrade memory
  • run on GPUs
  • optimize the code (cf. tracing/taping above)
  • optimize memory
  • try to "batch" the gradients/jacobians (see the sketch after this list)
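
For the batching idea, a sketch building on Version 1 (get_hessian_batched is a hypothetical helper): computing the Hessian a chunk of rows at a time bounds the peak memory of each jacobian call, although the final (P, P) matrix is still materialised at the end.

def get_hessian_batched(model, Xpde, pde, Xdirichlet, dirichlet_residuals, chunk=256):
    with tf.GradientTape(persistent=True) as tape:
        with tf.GradientTape() as tape2:
            loss = tf.reduce_sum(get_losses(model, Xpde, pde, Xdirichlet, dirichlet_residuals))
        grads = tape2.gradient(loss, model.trainable_weights,
                               unconnected_gradients=tf.UnconnectedGradients.ZERO)
        flat_grads = tf.concat([tf.reshape(g, (-1,)) for g in grads], axis=0)
        n = int(flat_grads.shape[0])
        # slice inside the tape context so the slices stay differentiable
        chunks = [flat_grads[i:i + chunk] for i in range(0, n, chunk)]
    rows = []
    for c in chunks:
        blocks = tape.jacobian(c, model.trainable_weights,
                               unconnected_gradients=tf.UnconnectedGradients.ZERO)
        rows.append(tf.concat([tf.reshape(b, (int(c.shape[0]), -1)) for b in blocks], axis=1))
    del tape  # release the persistent tape's memory
    return tf.concat(rows, axis=0)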

Approximations

If the exact Hessian is too expensive, approximating it with lower-order derivatives may be an option.
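
Relatedly, many uses of the Hessian (eigenvalue estimates via power iteration or Lanczos, Newton-CG, ...) only need Hessian-vector products, which cost one extra backward pass and never materialise the (P, P) matrix. A sketch (hvp is a hypothetical helper; vector is a list of tensors shaped like the trainable weights):

@tf.function
def hvp(model, Xpde, pde, Xdirichlet, dirichlet_residuals, vector):
    with tf.GradientTape() as outer:
        with tf.GradientTape() as inner:
            loss = tf.reduce_sum(get_losses(model, Xpde, pde, Xdirichlet, dirichlet_residuals))
        grads = inner.gradient(loss, model.trainable_weights,
                               unconnected_gradients=tf.UnconnectedGradients.ZERO)
        # scalar <grad, v>; differentiating it once more yields H @ v
        gv = tf.add_n([tf.reduce_sum(g * v) for g, v in zip(grads, vector)])
    return outer.gradient(gv, model.trainable_weights,
                          unconnected_gradients=tf.UnconnectedGradients.ZERO)

v = [tf.random.normal(w.shape) for w in model.trainable_weights]
Hv = hvp(model, Xpde, pde, Xdirichlet, dirichlet_residuals, v)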

Try other Deep Learning Frameworks

I am sure this has been implemented for small models in TF, PyTorch, JAX... But is there an example with bigger models?
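
For comparison, a minimal JAX sketch (a toy MLP standing in for the PINN; jax.hessian gives the dense matrix in one call, but memory still scales as O(P^2)):

import jax
import jax.numpy as jnp
from jax.flatten_util import ravel_pytree

# toy 2-16-1 MLP; parameters flattened to a single vector
key = jax.random.PRNGKey(0)
k1, k2 = jax.random.split(key)
params = {"W1": jax.random.normal(k1, (2, 16)), "b1": jnp.zeros(16),
          "W2": jax.random.normal(k2, (16, 1)), "b2": jnp.zeros(1)}
theta, unravel = ravel_pytree(params)

def loss(theta, X):
    p = unravel(theta)
    u = jnp.tanh(X @ p["W1"] + p["b1"]) @ p["W2"] + p["b2"]
    return jnp.mean(u ** 2)   # stand-in for the PINN loss

X = 2 * jax.random.uniform(key, (100, 2)) - 1
H = jax.hessian(loss)(theta, X)   # dense (P, P) matrix
print(H.shape)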


1 Answer

Answered by Rathod:

You can try the built-in TensorFlow function tf.hessians(), which is documented in TensorFlow 2.14:

https://www.tensorflow.org/api_docs/python/tf/hessians
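
For reference, a minimal self-contained tf.hessians example (not from the answer): a quadratic with a known Hessian. It is tf.gradients-based, so it needs graph mode via tf.function, and it returns one block per source tensor:

@tf.function
def quadratic_hessian():
    x = tf.constant([1.0, 2.0, 3.0])
    a = tf.constant([1.0, 2.0, 3.0])
    loss = tf.reduce_sum(a * x * x)   # Hessian should be diag(2a)
    return tf.hessians(loss, x)[0]

print(quadratic_hessian())  # [[2. 0. 0.], [0. 4. 0.], [0. 0. 6.]]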

Another way is described in the paper below, which says it is faster than tf.hessians(), using the pyhessian library for TensorFlow:

https://arxiv.org/pdf/1905.05559.pdf