I am trying to understand how exactly the tqdm progress bar works. I have some code that looks as follows:
```python
import torch
import torchvision
from torchvision import transforms  # needed for transforms.Compose below

import data_setup, engine  # local modules of the project

print(f"torch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")

load_data()

manual_transforms = transforms.Compose([])

train_dataloader, test_dataloader, class_names = data_setup.create_dataloaders()

# then, within the main function, I have placed the train function that exists
# in the `engine.py` file
def main():
    results = engine.train(model=model,
                           train_dataloader=train_dataloader,
                           test_dataloader=test_dataloader,
                           optimizer=optimizer,
                           loss_fn=loss_fn,
                           epochs=5,
                           device=device)
```
and the `engine.train()` function contains the following loop, which runs the training for each batch and visualizes the progress of the training:

```python
for epoch in tqdm(range(epochs)):
    # ... training for each batch takes place here ...
```

However, on every step of the tqdm loop, the following statements are also printed:

```python
print(f"torch version: {torch.__version__}")
print(f"torchvision version: {torchvision.__version__}")
```
So, finally, my question is why this is happening. How does the `main` function have access to these global statements, and how can I avoid printing everything on each loop iteration?
What you are noticing actually has nothing to do with `tqdm`, but rather with the inner workings of PyTorch (in particular, the `DataLoader`'s `num_workers` attribute) and Python's underlying `multiprocessing` framework. A minimum working example that reproduces your problem is a script that prints the PyTorch version at module level and then iterates over a `DataLoader` created with `num_workers=4`. If you run such a script, you should see your PyTorch version number be printed 4 extra times, messing up your `tqdm` progress bar. It is not a coincidence that this number is the same as `num_workers` (which you can easily check by changing this number).

What happens is the following:

- If `num_workers` is > 0, then subprocesses are launched for the workers.
- Depending on the platform, these subprocesses are started with the `spawn` start method (the default on Windows and macOS; see `set_start_method()`).
- A subprocess started with `spawn` re-imports your script and re-executes everything that is not guarded by the `if __name__ == "__main__":` block. This includes your `print()` calls on top of the script.

The behavior is documented in PyTorch's notes on multi-process data loading, along with potential mitigations. The one that would work for you, I guess, is wrapping all code with side effects inside the `if __name__ == '__main__':` block.
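Since `DataLoader` workers are ordinary `multiprocessing` subprocesses, the re-import behavior can be demonstrated without PyTorch at all. Here is a stdlib-only sketch (the `square` worker function is made up for illustration; `set_start_method("spawn")` forces the Windows/macOS default, so the effect also appears on platforms where it is not the default):

```python
import multiprocessing as mp

# Module-level code: with the "spawn" start method, every worker re-imports
# this file and re-executes this line — just like the print() calls at the
# top of the question's script.
print("module-level print")

def square(x):
    return x * x

if __name__ == "__main__":
    # Force "spawn" so the demonstration does not depend on the platform.
    mp.set_start_method("spawn")
    with mp.Pool(4) as pool:
        print(pool.map(square, range(4)))
```

Running this as a script prints `module-level print` once for the parent process plus once per worker, followed by `[0, 1, 4, 9]` from the pool; moving the top-level `print()` under the `__main__` guard makes the extra copies disappear.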
So, either

- move your `print()` calls to the beginning of your `if __name__ == '__main__':` block,
- move your `print()` calls to the beginning of your `main()` function, or
- remove these `print()` calls altogether.

Alternatively, but this is probably not what you want, you can set `num_workers=0`, which will disable the underlying use of `multiprocessing` altogether (but in this way you will also lose the benefits of parallelization). Note that you should probably also move other function calls (such as `load_data()`) into the `if __name__ == '__main__':` block or into the `main()` function to avoid multiple unintended executions.
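The resulting structure can be sketched as follows (a stdlib-only stand-in, with a made-up `square` worker in place of the actual training code; the same layout applies to the `load_data()`, `create_dataloaders()`, and `engine.train()` calls from the question):

```python
import multiprocessing as mp
import sys

def square(x):
    return x * x

def main():
    # Moved here from module level: executed exactly once, by the parent
    # process, even when workers re-import this file.
    print(f"python version: {sys.version_info.major}.{sys.version_info.minor}")
    with mp.Pool(2) as pool:
        print(pool.map(square, range(4)))

if __name__ == "__main__":
    mp.set_start_method("spawn")  # mimic the Windows/macOS default
    main()
```

With this layout, a spawned worker re-imports a module that only defines functions, so nothing is printed from the workers.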