I need to calculate the gradient of the validation error w.r.t. the training inputs x. I want to see how much the validation error changes when I perturb one of the training samples.
- The validation error (E) explicitly depends on the model weights (W).
- The model weights explicitly depend on the inputs (x and y).
- Therefore, the validation error implicitly depends on the inputs.
I'm trying to calculate the gradient of E w.r.t. x directly. An alternative would be to calculate the gradient of E w.r.t. W (easy to calculate) and the gradient of W w.r.t. x (which I cannot do at the moment); combining the two would give the gradient of E w.r.t. x.
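The two-step route above is the chain rule. Written out, it would look roughly like this, assuming W can be treated as a differentiable function of the training sample x (an assumption: a model trained through `model.fit` does not track that dependence automatically):

```latex
\frac{\partial E}{\partial x}
  = \frac{\partial E}{\partial W}\,\frac{\partial W}{\partial x}
```

Here E is a scalar, so the first factor is a row vector with one entry per weight, and the second factor is a Jacobian with one row per weight and one column per input feature.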
I have attached a toy example. Thanks in advance!
import numpy as np
import mnist
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense
from tensorflow.keras.utils import to_categorical
import tensorflow as tf
from autograd import grad
train_images = mnist.train_images()
train_labels = mnist.train_labels()
test_images = mnist.test_images()
test_labels = mnist.test_labels()
# Normalize the images.
train_images = (train_images / 255) - 0.5
test_images = (test_images / 255) - 0.5
# Flatten the images.
train_images = train_images.reshape((-1, 784))
test_images = test_images.reshape((-1, 784))
# Build the model.
model = Sequential([
    Dense(64, activation='relu', input_shape=(784,)),
    Dense(64, activation='relu'),
    Dense(10, activation='softmax'),
])
# Compile the model.
model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)
# Train the model.
model.fit(
    train_images,
    to_categorical(train_labels),
    epochs=5,
    batch_size=32,
)
model.save_weights('model.h5')
# Load the model's saved weights.
# model.load_weights('model.h5')
calculate_mse = tf.keras.losses.MeanSquaredError()
test_x = test_images[:5]
test_y = to_categorical(test_labels)[:5]
train_x = train_images[:1]
train_y = to_categorical(train_labels)[:1]
train_y = tf.convert_to_tensor(train_y, np.float32)
train_x = tf.convert_to_tensor(train_x, np.float64)
# approach 1 - tf.GradientTape around model.fit
with tf.GradientTape() as tape:
    tape.watch(train_x)
    model.fit(train_x, train_y, epochs=1, verbose=0)
    valid_y_hat = model(test_x, training=False)
    mse = calculate_mse(test_y, valid_y_hat)
de_dx = tape.gradient(mse, train_x)
print(de_dx)
# approach 2 - does not run
def calculate_validation_mse(x):
    model.fit(x, train_y, epochs=1, verbose=0)
    valid_y_hat = model(test_x, training=False)
    mse = calculate_mse(test_y, valid_y_hat)
    return mse
train_x = train_images[:1]
train_y = to_categorical(train_labels)[:1]
validation_gradient = grad(calculate_validation_mse)
de_dx = validation_gradient(train_x)
print(de_dx)
Here's how you can do this. The derivation is as follows.
A few things to note:
Disclaimer: this derivation is correct to the best of my knowledge. Please do some research and verify that this is the case. You will run into memory issues for larger inputs and layer sizes.
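One way to make the dependence of W on x visible to TensorFlow is to replace `model.fit` with a hand-written update inside nested `tf.GradientTape` contexts. The block below is a minimal sketch of that idea, not necessarily the derivation referred to above. It assumes a single plain SGD step, W' = W - lr * dL/dW, stands in for training (the question uses Adam through `model.fit`). It also assumes a tiny tanh network and random toy data in place of MNIST. The names `forward`, `validation_error_after_step`, `params` and `lr` are hypothetical helpers introduced for illustration.

```python
import numpy as np
import tensorflow as tf

tf.random.set_seed(0)
rng = np.random.default_rng(0)
DT = tf.float64  # float64 keeps the finite-difference check below accurate

# Hypothetical toy data standing in for MNIST (assumption).
train_x = tf.constant(rng.normal(size=(1, 8)), dtype=DT)
train_y = tf.constant([[0.0, 1.0, 0.0]], dtype=DT)
valid_x = tf.constant(rng.normal(size=(5, 8)), dtype=DT)
valid_y = tf.one_hot(rng.integers(0, 3, size=5), depth=3, dtype=DT)

# Weights of a small one-hidden-layer network (assumption, not the Keras model).
params = [
    tf.Variable(tf.random.normal((8, 4), stddev=0.5, dtype=DT)),
    tf.Variable(tf.zeros(4, dtype=DT)),
    tf.Variable(tf.random.normal((4, 3), stddev=0.5, dtype=DT)),
    tf.Variable(tf.zeros(3, dtype=DT)),
]

def forward(x, p):
    """Return class logits for inputs x under the parameter list p."""
    hidden = tf.tanh(x @ p[0] + p[1])
    return hidden @ p[2] + p[3]

def validation_error_after_step(x, lr=0.1):
    """Take one SGD step on training sample x, then return the validation MSE."""
    with tf.GradientTape() as inner:
        log_probs = tf.nn.log_softmax(forward(x, params))
        train_loss = -tf.reduce_mean(tf.reduce_sum(train_y * log_probs, axis=-1))
    grads = inner.gradient(train_loss, params)
    # Updated weights are tensors that still depend on x through grads.
    new_params = [p - lr * g for p, g in zip(params, grads)]
    valid_probs = tf.nn.softmax(forward(valid_x, new_params))
    return tf.reduce_mean(tf.square(valid_y - valid_probs))

# The outer tape records the inner gradient computation, so E can be
# differentiated with respect to the training sample.
with tf.GradientTape() as outer:
    outer.watch(train_x)
    error = validation_error_after_step(train_x)
de_dx = outer.gradient(error, train_x)
print(de_dx)
```

Because the outer tape differentiates the scalar E directly, the full Jacobian of W with respect to x is never materialized. Unrolling more than one update step this way keeps every intermediate weight tensor alive, so memory grows with the number of steps and the layer sizes.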