I'm trying to build an LSTM network to classify sentences and provide an explanation for the classification using saliency. This network must learn from the true class y_true as well as from a binary mask Z that tells it which words it should not pay attention to (in the code below, Z is 1 for important words).
This paper inspired us to come up with our loss function. Here's what I'd like my loss function to look like:
Coût de classification (classification cost) translates to classification_loss and Coût d'explication (saillance) (explanation cost, i.e. saliency) translates to saliency_loss (which is the same as the gradient of the output wrt the input) in the code below. I tried to implement this with a custom Model in Keras, with TensorFlow as backend:
import tensorflow as tf
from tensorflow import GradientTape
from tensorflow.keras import backend as K
from tensorflow.keras import metrics
from tensorflow.keras.losses import categorical_crossentropy
from tensorflow.keras.models import Sequential

loss_tracker = metrics.Mean(name="loss")
classification_loss_tracker = metrics.Mean(name="classification_loss")
saliency_loss_tracker = metrics.Mean(name="saliency_loss")
accuracy_tracker = metrics.CategoricalAccuracy(name="accuracy")


class CustomSequentialModel(Sequential):

    def _train_test_step(self, data, training):
        # Unpack the data
        X = data[0]["X"]
        Z = data[0]["Z"]  # binary mask (1 for important words)
        y_true = data[1]

        # gradient tape requires "float32" instead of "int32"
        # X.shape = (None, MAX_SEQUENCE_LENGTH, EMBEDDING_DIM)
        X = tf.cast(X, tf.float32)

        # persistent=True because we call `gradient` more than once
        with GradientTape(persistent=True) as tape:
            # The tape will record everything that happens to X
            # for automatic differentiation later on (used to compute saliency)
            tape.watch(X)
            # Forward pass
            y_pred = self(X, training=training)

            # (1) Compute the classification_loss
            classification_loss = K.mean(
                categorical_crossentropy(y_true, y_pred)
            )

            # (2) Compute the saliency loss
            # (2.1) Compute the log of the maximum predicted probability
            log_prediction_proba = K.log(K.max(y_pred))

        # (2.2) Compute the gradient of the output wrt the input
        # saliency.shape is (None, MAX_SEQUENCE_LENGTH, None)
        # why isn't it (None, MAX_SEQUENCE_LENGTH, EMBEDDING_DIM) ?!
        saliency = tape.gradient(log_prediction_proba, X)
        # (2.3) Sum along the embedding dimension
        saliency = K.sum(saliency, axis=2)
        # (2.4) Multiply by the inverted binary mask and sum
        saliency_loss = K.sum(K.square(saliency) * (1 - Z))

        # => ValueError: No gradients provided for any variable
        loss = classification_loss + saliency_loss
        trainable_vars = self.trainable_variables
        # ValueError caused by the '+ saliency_loss'
        gradients = tape.gradient(loss, trainable_vars)
        del tape  # garbage collection

        if training:
            # Update weights
            self.optimizer.apply_gradients(zip(gradients, trainable_vars))

        # Update metrics
        saliency_loss_tracker.update_state(saliency_loss)
        classification_loss_tracker.update_state(classification_loss)
        loss_tracker.update_state(loss)
        accuracy_tracker.update_state(y_true, y_pred)

        # Return a dict mapping metric names to current value
        return {m.name: m.result() for m in self.metrics}

    def train_step(self, data):
        return self._train_test_step(data, True)

    def test_step(self, data):
        return self._train_test_step(data, False)

    @property
    def metrics(self):
        return [
            loss_tracker,
            classification_loss_tracker,
            saliency_loss_tracker,
            accuracy_tracker
        ]

I manage to compute classification_loss as well as saliency_loss, and I get a scalar value for each. However, this works: tape.gradient(classification_loss, trainable_vars), but this doesn't: tape.gradient(classification_loss + saliency_loss, trainable_vars). It throws ValueError: No gradients provided for any variable.
You are doing computations outside the tape context (after the first gradient call) and are then trying to take more gradients afterwards. This doesn't work; all operations to differentiate need to happen inside the context manager. I would suggest restructuring your code to use two nested tapes.

With two nested tapes, the inner tape is responsible for computing the gradients wrt the input for the saliency. The outer tape around it tracks those operations and can later compute the gradient of the gradient (i.e. gradient of the saliency). The outer tape also computes gradients for the classification loss. The classification loss goes in the outer tape context because the inner tape doesn't need it. Note also that even the addition of the two losses has to be inside the context of the outer tape: everything has to happen in there, else the computation graph is lost/incomplete and gradients cannot be computed.
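
A minimal sketch of the nested-tape structure described above, written as a standalone function rather than a train_step. Several things here are assumptions for illustration and not taken from the question: the small dense stand-in model (used instead of the LSTM to keep the example short), the toy sizes, the helper name compute_losses_and_gradients, and the choice to take the maximum probability per sample rather than over the whole batch.

```python
import tensorflow as tf

# Hypothetical toy sizes, for illustration only.
BATCH, SEQ_LEN, EMB_DIM, N_CLASSES = 4, 5, 8, 3

# Stand-in classifier (assumption): a small dense network instead of the LSTM.
model = tf.keras.Sequential([
    tf.keras.layers.Flatten(),
    tf.keras.layers.Dense(16, activation="tanh"),
    tf.keras.layers.Dense(N_CLASSES, activation="softmax"),
])


def compute_losses_and_gradients(model, X, Z, y_true, training=True):
    """Hypothetical helper: returns the losses and the weight gradients."""
    X = tf.cast(X, tf.float32)
    Z = tf.cast(Z, tf.float32)

    # The outer tape records everything, including the inner gradient call.
    with tf.GradientTape() as outer_tape:
        # The inner tape is only used to differentiate wrt the input X.
        with tf.GradientTape() as inner_tape:
            inner_tape.watch(X)
            y_pred = model(X, training=training)
            # Per-sample log of the maximum probability (assumption:
            # the question takes the maximum over the whole batch).
            log_prediction_proba = tf.math.log(tf.reduce_max(y_pred, axis=-1))

        # Still inside the outer tape: this gradient is itself recorded.
        saliency = inner_tape.gradient(log_prediction_proba, X)
        saliency = tf.reduce_sum(saliency, axis=2)  # sum over embedding dim
        saliency_loss = tf.reduce_sum(tf.square(saliency) * (1.0 - Z))

        classification_loss = tf.reduce_mean(
            tf.keras.losses.categorical_crossentropy(y_true, y_pred)
        )
        # The addition also happens inside the outer tape.
        loss = classification_loss + saliency_loss

    gradients = outer_tape.gradient(loss, model.trainable_variables)
    return loss, classification_loss, saliency_loss, gradients


# Usage with random toy data.
X = tf.random.normal((BATCH, SEQ_LEN, EMB_DIM))
Z = tf.cast(tf.random.uniform((BATCH, SEQ_LEN)) > 0.5, tf.float32)
labels = tf.random.uniform((BATCH,), maxval=N_CLASSES, dtype=tf.int32)
y_true = tf.one_hot(labels, N_CLASSES)

loss, classification_loss, saliency_loss, gradients = (
    compute_losses_and_gradients(model, X, Z, y_true)
)
print(float(loss), [g.shape for g in gradients])
```

Inside a custom train_step, the same body would be followed by self.optimizer.apply_gradients(zip(gradients, trainable_vars)) as in the question. Neither tape needs persistent=True here, since each one is asked for a gradient only once.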