TensorBoard with Trax


Has anyone managed to log the loss with TensorBoard? I am using the Trax ML library and I am getting this error: TypeError: 'SummaryWriter' object is not callable.

I am using the SummaryWriter from jaxboard and then adding it to callbacks within training.Loop.

import datetime

import trax
from trax import jaxboard
from trax.supervised import training

my_dir = "/some_dir" + datetime.datetime.now().strftime("%Y%m%d-%H%M%S")
summary_writer = jaxboard.SummaryWriter(log_dir=my_dir)

def train_model(model, batch_size=batch_size, n_steps=1000, output_dir=output_dir):
    '''
    Input:
        model - the model to train
        batch_size - batch size passed to the data generators
        n_steps - number of training steps to run
        output_dir - folder to save your files
    Output:
        training_loop - the trax training.Loop
    '''
    train_task = training.TrainTask(
        labeled_data=train_generator(batch_size=batch_size, shuffle=True),
        loss_layer=TripletLoss(),
        optimizer=trax.optimizers.Adam(learning_rate=0.001),
        n_steps_per_checkpoint=1000,
    )

    eval_tasks = training.EvalTask(
        labeled_data=val_generator(batch_size=batch_size, shuffle=True),
        metrics=[TripletLoss()],
        n_eval_batches=10,
    )

    training_loop = training.Loop(
        model,  # The learning model
        train_task,  # The training task
        eval_tasks=eval_tasks,  # The evaluation task
        # random_seed=35,
        output_dir=output_dir,  # The output directory
        callbacks=[summary_writer],  # Logging
    )
    
    training_loop.run(n_steps = n_steps)
    
    # Return the training_loop, since it has the model.
    return training_loop

The error appears when I run the training loop:

training_loop = train_model(my_model())

There is 1 answer

Exa (best answer)

It worked when I removed the callbacks line and the summary_writer, and instead ran this in Google Colab:

%load_ext tensorboard
%cd '/content/drive/MyDrive/path_to_the_notebook_/'
# The *train* folder is created by trax and holds the logs for the train run.
# To view the eval run instead, change to --logdir eval.
%tensorboard --logdir train
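
As for the TypeError itself: as far as I can tell, training.Loop expects each entry in callbacks to be a callback class and calls it with the loop to build an instance, so an already constructed SummaryWriter fails at that call. I have not confirmed this against every Trax version, so treat it as an assumption. The plain-Python sketch below reproduces the same kind of message; FakeSummaryWriter and build_callbacks are hypothetical stand-ins, not Trax code.

```python
class FakeSummaryWriter:
    """Hypothetical stand-in for jaxboard.SummaryWriter (an instance, not a class)."""

    def __init__(self, log_dir):
        self.log_dir = log_dir


def build_callbacks(loop, callbacks):
    """Assumed behaviour of training.Loop: call each entry with the loop."""
    return [callback(loop) for callback in callbacks]


def reproduce_error():
    """Pass a writer instance where a callable is expected and return the error text."""
    writer = FakeSummaryWriter(log_dir="/some_dir")
    try:
        build_callbacks(loop=None, callbacks=[writer])
    except TypeError as err:
        return str(err)
    return None


message = reproduce_error()
print(message)
```

Since the train and eval folders are already written by trax under output_dir, no extra writer is needed just to see the loss curves.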