I am trying to implement early stopping in the TF Object Detection API. I used this code.
Here is my EarlyStoppingHook (it is essentially just a copy of the code above):
import logging

import numpy as np
import tensorflow as tf
from tensorflow.python.training import session_run_hook


class EarlyStoppingHook(session_run_hook.SessionRunHook):
    """Hook that requests stop when the monitored quantity stops improving."""

    def __init__(self, monitor='val_loss', min_delta=0, patience=0,
                 mode='auto'):
        """
        Args:
            monitor: name of the tensor (graph element) to monitor.
            min_delta: minimum change to count as an improvement.
            patience: number of steps without improvement before stopping.
            mode: one of 'auto', 'min', 'max'.
        """
        self.monitor = monitor
        self.patience = patience
        self.min_delta = min_delta
        self.wait = 0
        self.max_wait = 0
        self.ind = 0
        if mode not in ['auto', 'min', 'max']:
            logging.warning('EarlyStopping mode %s is unknown, '
                            'fallback to auto mode.', mode)
            mode = 'auto'
        if mode == 'min':
            self.monitor_op = np.less
        elif mode == 'max':
            self.monitor_op = np.greater
        else:
            # In 'auto' mode, infer the direction from the metric name.
            if 'acc' in self.monitor:
                self.monitor_op = np.greater
            else:
                self.monitor_op = np.less
        # Orient min_delta in the direction of improvement.
        if self.monitor_op == np.greater:
            self.min_delta *= 1
        else:
            self.min_delta *= -1
        self.best = np.Inf if self.monitor_op == np.less else -np.Inf

    def begin(self):
        # Convert the monitored name to a tensor if a name was given.
        graph = tf.get_default_graph()
        self.monitor = graph.as_graph_element(self.monitor)
        if isinstance(self.monitor, tf.Operation):
            self.monitor = self.monitor.outputs[0]

    def before_run(self, run_context):  # pylint: disable=unused-argument
        return session_run_hook.SessionRunArgs(self.monitor)

    def after_run(self, run_context, run_values):
        self.ind += 1
        current = run_values.results
        if self.ind % 200 == 0:
            print(f"loss value (inside hook!!! ): {current}, best: {self.best}, "
                  f"wait: {self.wait}, max_wait: {self.max_wait}")
        if self.monitor_op(current - self.min_delta, self.best):
            # Improvement: remember the best value and reset the wait counter.
            self.best = current
            if self.max_wait < self.wait:
                self.max_wait = self.wait
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                run_context.request_stop()
And I use the class like this:
early_stopping_hook = EarlyStoppingHook(
    monitor='total_loss',
    patience=2000)

train_spec = tf.estimator.TrainSpec(
    input_fn=train_input_fn, max_steps=train_steps, hooks=[early_stopping_hook])
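For context, the train_spec is then passed to tf.estimator.train_and_evaluate together with an eval spec, roughly like this (estimator and eval_input_fn stand in for whatever the TF OD API model_lib setup provides):

eval_spec = tf.estimator.EvalSpec(
    input_fn=eval_input_fn,  # placeholder: the eval input function from the OD API setup
    steps=None)

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)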
What I don't understand is what total_loss is. Is it the validation loss or the training loss? I also don't understand where these losses ('total_loss', 'loss_1', 'loss_2') are defined.
So, here is what worked for me.
The read_eval_metrics function returns a dictionary whose keys are step numbers and whose values are the metrics and losses computed at that evaluation step. You can also use the same function for the train event files; you just need to change the path.
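A minimal sketch of how it can be called (in TF 1.x it is available as tf.contrib.estimator.read_eval_metrics; the exact import path and the directory name below may differ in your setup):

import tensorflow as tf

# Placeholder path: point this at the eval (or train) event-file directory
# inside your model_dir.
eval_dir = 'path/to/model_dir/eval_0'

for step, metrics in tf.contrib.estimator.read_eval_metrics(eval_dir).items():
    # Print the available metric/loss names so you can pick one to monitor.
    print(step, sorted(metrics.keys()))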
Example of one key-value pair from the returned dictionary:
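(The step number and values below are placeholders just to show the shape of an entry; the exact keys depend on your model and on whether you read eval or train event files.)

{
    2000: {                          # key: global step of the evaluation
        'loss': 0.235,               # placeholder value
        'Loss/total_loss': 0.235,    # placeholder value
        'loss_1': 0.11,              # placeholder value
        'loss_2': 0.125,             # placeholder value
        # ... other losses/metrics logged for your model ...
    }
}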
So I ended up setting the monitor argument to 'Loss/total_loss' instead of 'total_loss' in EarlyStoppingHook.
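With that, the hook from the question is created like this:

early_stopping_hook = EarlyStoppingHook(
    monitor='Loss/total_loss',
    patience=2000)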