How to stop training based on the loss value when using a pre-trained model and a configuration file?


I am training an object detector with a Faster R-CNN model, set up through a pipeline configuration file. I know that training can be stopped manually (Ctrl+C), but I want it to stop automatically based on the loss value. How can this be done? I am aware that Keras callbacks can do this when training is tracked in epochs. Is there a similar option when using configuration files and pre-trained models, where training is tracked in steps? Thanks.

Answer by Ameya Manas

It might just be a hack, but I found a solution to my question. The object detection code depends on the tf_slim package, and tf_slim contains a module called learning.py. Its full path might look something like /usr/local/lib/python3.6/site-packages/tf_slim/learning.py. In learning.py, starting at line 764, the training loop looks something like this:

try:
  while not sv.should_stop():
    total_loss, should_stop = train_step_fn(sess, train_op, global_step,
                                            train_step_kwargs)
    if should_stop:
      logging.info('Stopping Training.')
      sv.request_stop()
      break
except errors.OutOfRangeError as e:
  # OutOfRangeError is thrown when epoch limit per
  # tf.compat.v1.train.limit_epochs is reached.
  logging.info('Caught OutOfRangeError. Stopping Training. %s', e)

I added a small if statement that checks the maximum of the last five total_loss values and, if it is below a threshold (3 in this case), sets should_stop to True:

try:
  # Every loss value seen so far, used by the threshold check below.
  total_loss_list = []
  while not sv.should_stop():
    total_loss, should_stop = train_step_fn(sess, train_op, global_step,
                                            train_step_kwargs)
    total_loss_list.append(total_loss)
    # Stop once the last five losses are all below 3 (hard-coded threshold).
    # ">= 5" lets the check fire as soon as five steps have run.
    if len(total_loss_list) >= 5 and max(total_loss_list[-5:]) < 3:
      should_stop = True
    if should_stop:
      logging.info('Stopping Training.')
      sv.request_stop()
      break
except errors.OutOfRangeError as e:
  # OutOfRangeError is thrown when epoch limit per
  # tf.compat.v1.train.limit_epochs is reached.
  logging.info('Caught OutOfRangeError. Stopping Training. %s', e)
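The same rule could also be kept out of the installed package. This is a minimal sketch that assumes you can edit the script that calls tf_slim.learning.train. That function accepts a train_step_fn argument (its default is tf_slim.learning.train_step), and the loop quoted above calls it as train_step_fn(sess, train_op, global_step, train_step_kwargs), expecting (total_loss, should_stop) back. The helper make_loss_threshold_step_fn is hypothetical: it wraps any function with that signature and adds the five-step threshold check. The Object Detection trainer may not pass train_step_fn through, in which case its call to tf_slim.learning.train would still need a small edit.

```python
from collections import deque


def make_loss_threshold_step_fn(inner_step_fn, threshold=3.0, window=5):
    """Wrap a tf_slim-style train step function with a loss-threshold stop rule.

    Hypothetical helper. inner_step_fn must have the signature
    (sess, train_op, global_step, train_step_kwargs) -> (total_loss, should_stop),
    like tf_slim.learning.train_step.
    """
    # Keeps only the most recent `window` loss values.
    recent_losses = deque(maxlen=window)

    def step_fn(sess, train_op, global_step, train_step_kwargs):
        total_loss, should_stop = inner_step_fn(
            sess, train_op, global_step, train_step_kwargs)
        recent_losses.append(float(total_loss))
        # Request a stop once `window` consecutive losses are all below threshold.
        if len(recent_losses) == window and max(recent_losses) < threshold:
            should_stop = True
        return total_loss, should_stop

    return step_fn


# Assumed usage (not taken from the Object Detection code):
# import tf_slim as slim
# step_fn = make_loss_threshold_step_fn(slim.learning.train_step, threshold=3.0)
# slim.learning.train(train_op, logdir, train_step_fn=step_fn)
```

Using a deque with maxlen keeps memory constant no matter how many steps run, and any stop request from the wrapped function is still honored.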

If the loss stays below 3 for five consecutive steps, training stops. The downside is that the installed tf_slim package has to be modified, and the threshold will likely need to change for every new object detection problem. A better approach would be to supply the threshold through a configuration file, but I'm stopping here for now. If anyone has a better solution, please share it. I hope this helps someone. Thank you!
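A minimal sketch of that configuration-file idea, assuming a small standalone JSON file rather than a field in the pipeline configuration file. The file name early_stop.json, the keys loss_threshold and window, and the function load_early_stop_config are all hypothetical.

```python
import json
import os

# Fallback values matching the hard-coded ones in the patched loop.
DEFAULTS = {"loss_threshold": 3.0, "window": 5}


def load_early_stop_config(path="early_stop.json"):
    """Read the stop threshold and window size from a JSON file.

    Hypothetical format: {"loss_threshold": 1.5, "window": 5}.
    A missing file or missing keys fall back to DEFAULTS.
    """
    config = dict(DEFAULTS)
    if os.path.exists(path):
        with open(path) as f:
            config.update(json.load(f))
    threshold = float(config["loss_threshold"])
    window = int(config["window"])
    if window < 1:
        raise ValueError("window must be at least 1")
    return threshold, window
```

The patched loop in learning.py could call this once before the while loop and use the returned values in place of the hard-coded 3 and 5, so each project only needs its own JSON file instead of another edit to tf_slim.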