TypeError: cannot pickle '_thread.lock' object when wrapping PyTorch Distributed Data Parallel in a class


Here is a minimal example:

import threading
import torch

class DistributedMission:
    def __init__(self):
        self.lock = threading.Lock()
        self.world_size = 8
    def start(self):
        torch.multiprocessing.spawn(self.worker, args=(), nprocs=self.world_size, join=True)
    def worker(self, rank:int):
        with self.lock:
            print(f'{rank} working...')
            print(f'{rank} done.')

if __name__ == '__main__':
    mission = DistributedMission()
    mission.start()
    print('All Done.')

This fails at the line torch.multiprocessing.spawn(self.worker, args=(), nprocs=self.world_size, join=True) with: TypeError: cannot pickle '_thread.lock' object
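
The same error reproduces with plain pickle and no torch at all, so the lock itself appears to be the unpicklable part:

import pickle
import threading

# A threading.Lock wraps an OS-level primitive and cannot be serialized:
pickle.dumps(threading.Lock())
# TypeError: cannot pickle '_thread.lock' object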

If I remove self.lock, the error disappears:

import threading
import torch

class DistributedMission:
    def __init__(self):
        self.world_size = 8
    def start(self):
        torch.multiprocessing.spawn(self.worker, args=(), nprocs=self.world_size, join=True)
    def worker(self, rank:int):
        print(f'{rank} working...')
        print(f'{rank} done.')

if __name__ == '__main__':
    mission = DistributedMission()
    mission.start()
    print('All Done.')

Alternatively, removing the wrapper class also avoids the error (though I suspect that with the spawn start method each child process re-imports the module and gets its own fresh lock, so I am not sure this version actually synchronizes anything):

import threading
import torch

lock = threading.Lock()
world_size = 8

def start():
    torch.multiprocessing.spawn(worker, args=(), nprocs=world_size, join=True)

def worker(rank:int):
    with lock:
        print(f'{rank} working...')
        print(f'{rank} done.')

if __name__ == '__main__':
    start()
    print('All Done.')

Why does this happen? Why do the attributes of the class that contains the torch.multiprocessing.spawn() target matter?
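
My current guess, which I would like confirmed: spawn has to pickle its target to send it to the child processes, and since self.worker is a bound method, pickling it pickles self and therefore every attribute, including the lock. A small sketch of that behaviour without torch:

import pickle
import threading

class WithLock:
    def __init__(self):
        self.lock = threading.Lock()
    def worker(self, rank: int):
        pass

class WithoutLock:
    def worker(self, rank: int):
        pass

# Pickling a bound method pickles the instance it is bound to,
# so it fails exactly when the instance holds a lock:
pickle.dumps(WithoutLock().worker)  # works
pickle.dumps(WithLock().worker)     # TypeError: cannot pickle '_thread.lock' object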

I really need the wrapper class, and I also need a lock so that the processes do not run part of my code in parallel.
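
For reference, this is the direction I am considering, assuming a lock created from the spawn context can be transferred to the spawned children (unlike a threading.Lock, which as I understand it only synchronizes threads within a single process anyway). Is this the right approach?

import torch.multiprocessing as mp

class DistributedMission:
    def __init__(self):
        # Assumption: a lock from the spawn context survives the pickling
        # that mp.spawn performs, because multiprocessing registers special
        # reducers for its own synchronization primitives.
        self.lock = mp.get_context('spawn').Lock()
        self.world_size = 8
    def start(self):
        mp.spawn(self.worker, args=(), nprocs=self.world_size, join=True)
    def worker(self, rank: int):
        with self.lock:
            print(f'{rank} working...')
            print(f'{rank} done.')

if __name__ == '__main__':
    DistributedMission().start()
    print('All Done.')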
