Reading files asynchronously hangs within Pytorch Dataset when DataLoader.num_workers > 1


I'm training a classifier on the raw bytes of binary files. Each file can be several megabytes long, and the dataset contains several million binaries. Loading all of this data at once and keeping it in memory is prohibitively expensive, so I wanted to write an IterableDataset that reads files on the fly during the training loop.

I use asyncio to read thousands of files quickly. Through some painful trial and error, I found that my asynchronous file-reading code can only handle about 500000 files at once, so I added a variable, asynch_chunk_size, that controls how many asyncio tasks are launched at a time. I also included a flag, asynch, which, when False, makes the program read all the files synchronously instead.
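
Stripped of the dataset plumbing, the reading pattern boils down to something like this (just an illustrative sketch; read_one and read_many are made-up names, the real code is further down):

import asyncio
from itertools import islice
from pathlib import Path

def batched(iterable, n):
    it = iter(iterable)
    while chunk := tuple(islice(it, n)):
        yield chunk

async def read_one(path):
    # Offload the blocking file read to the default ThreadPoolExecutor.
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, Path(path).read_bytes)

async def read_many(paths, chunk_size):
    # Read the files chunk by chunk so only chunk_size tasks are in flight at once.
    data = []
    for chunk in batched(paths, chunk_size):
        data.extend(await asyncio.gather(*(read_one(p) for p in chunk)))
    return data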

My IterableDataset works fine when DataLoader.num_workers == 0. It also works fine when DataLoader.num_workers > 0 and asynch is False, i.e., asynchronous file reading is disabled and the program reads files synchronously. However, the program hangs when I try to read files asynchronously with DataLoader.num_workers > 0, and I have no idea why (I'm a mediocre programmer).

FYI, PyTorch uses multiprocessing when DataLoader.num_workers > 0. I think the problem might be related to this issue, but I don't understand it well enough to tell.
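
In case it's relevant, here is how the worker start method in use can be checked (DataLoader also accepts a multiprocessing_context argument, if the start method turns out to matter):

import multiprocessing

# On Linux the default start method is "fork"; the DataLoader workers use it unless a
# different multiprocessing_context is passed to the DataLoader.
print(multiprocessing.get_start_method())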

When I Ctrl-C the hanging program, this is the stack trace that is printed:

 ^CTraceback (most recent call last):
  File "/home/lk3591/Documents/code/RawByteClf/src/data/minimal.py", line 343, in <module>
    main()
  File "/home/lk3591/Documents/code/RawByteClf/src/data/minimal.py", line 339, in main
    test(FILES.copy(), ASYNCH_CHUNK_SIZE, CHUNK_SIZE, asynch=True, num_workers=2)
  File "/home/lk3591/Documents/code/RawByteClf/src/data/minimal.py", line 310, in test
    for i, inputs in enumerate(dataloader):
  File "/home/lk3591/anaconda3/envs/RawByteClf/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
    data = self._next_data()
  File "/home/lk3591/anaconda3/envs/RawByteClf/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1328, in _next_data
    idx, data = self._get_data()
  File "/home/lk3591/anaconda3/envs/RawByteClf/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1294, in _get_data
    success, data = self._try_get_data()
  File "/home/lk3591/anaconda3/envs/RawByteClf/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1132, in _try_get_data
    data = self._data_queue.get(timeout=timeout)
  File "/home/lk3591/anaconda3/envs/RawByteClf/lib/python3.10/multiprocessing/queues.py", line 113, in get
    if not self._poll(timeout):
  File "/home/lk3591/anaconda3/envs/RawByteClf/lib/python3.10/multiprocessing/connection.py", line 257, in poll
    return self._poll(timeout)
  File "/home/lk3591/anaconda3/envs/RawByteClf/lib/python3.10/multiprocessing/connection.py", line 424, in _poll
    r = wait([self], timeout)
  File "/home/lk3591/anaconda3/envs/RawByteClf/lib/python3.10/multiprocessing/connection.py", line 931, in wait
    ready = selector.select(timeout)
  File "/home/lk3591/anaconda3/envs/RawByteClf/lib/python3.10/selectors.py", line 416, in select
    fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt

Below is my program. It's kind of long, but it's fairly well written, so hopefully it's easy to read. To run it, replace the PATH variable at the bottom with a path to a directory that contains around 20 files. It doesn't matter what kind of files they are, since the program only reads their raw bytes. For reference, I'm running Python 3.10.13 and PyTorch 2.0.1 on CentOS 9.
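
If you don't have a convenient directory of ~20 files lying around, a throwaway one like this should work (the location and file names below are arbitrary; point PATH at the directory):

import os
from pathlib import Path

dummy_dir = Path("/tmp/raw_byte_clf_dummy")  # arbitrary location
dummy_dir.mkdir(parents=True, exist_ok=True)
for i in range(20):
    # 8 KiB of random bytes per file; any content works since only raw bytes are read.
    (dummy_dir / f"file_{i:02d}.bin").write_bytes(os.urandom(8192))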

Any help you can provide would be amazing! Thanks!

"""
"""

print(f"Entered {__file__=}")

from abc import ABC
import asyncio
from collections import Counter
from collections.abc import Iterable, Sequence
from functools import partial
import gc
from itertools import islice
import math
import os
import random
from pathlib import Path
from typing import Callable, Literal, Optional

import numpy as np
import pandas as pd
import torch
from torch import ByteTensor, LongTensor, Tensor
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from tqdm import tqdm


DEFAULT_ASYNCH_CHUNK_SIZE = 500000
DEFAULT_IN_MEMORY_DTYPE = "pt"
DEFAULT_DISABLE_TQDM = False


def batched(iterable: Iterable, n: int):
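    """Yield successive n-sized tuples from iterable (itertools.batched in Python 3.12+)."""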
    it = iter(iterable)
    while batch := tuple(islice(it, n)):
        yield batch


def read_binary_file(
    f: Path,
    max_length: Optional[int] = None,
    in_memory_dtype: Literal["bytes", "np", "pt"] = DEFAULT_IN_MEMORY_DTYPE,
) -> bytes | np.ndarray | ByteTensor:
    """
    Args:
        dtype: "UserWarning: The given buffer is not writable..."
    """
    with open(f, "rb") as fp:
        b = fp.read(max_length)

    if in_memory_dtype == "bytes":
        return b
    elif in_memory_dtype == "np":
        return np.frombuffer(b, dtype=np.uint8)
    elif in_memory_dtype == "pt":
        return torch.frombuffer(b, dtype=torch.uint8)

    raise ValueError(f"Unknown {in_memory_dtype=}")


async def read_binary_file_asynch(
    f: Path,
    max_length: Optional[int] = None,
    in_memory_dtype: Literal["bytes", "np", "pt"] = DEFAULT_IN_MEMORY_DTYPE,
) -> bytes | np.ndarray | ByteTensor:
    loop = asyncio.get_running_loop()
    return await loop.run_in_executor(None, read_binary_file, f, max_length, in_memory_dtype)


async def read_binary_files_asynch(
    files: list[str],
    max_length: Optional[int] = None,
    in_memory_dtype: Literal["bytes", "np", "pt"] = DEFAULT_IN_MEMORY_DTYPE,
    disable_tqdm: bool = DEFAULT_DISABLE_TQDM,
    asynch_chunk_size: int = DEFAULT_ASYNCH_CHUNK_SIZE,
) -> list[bytes | np.ndarray | ByteTensor]:
    file_chunks = batched(files, asynch_chunk_size)

    iterable = file_chunks
    if not disable_tqdm:
        n_chunks = math.ceil(len(files) / asynch_chunk_size)
        iterable = tqdm(
            file_chunks,
            desc=f"Asynchronously loading {len(files)} files in {n_chunks} chunks...",
            total=n_chunks,
        )

    x = []
    for file_chunk in iterable:  # renamed to avoid shadowing the `files` argument
        tasks = [read_binary_file_asynch(f, max_length, in_memory_dtype) for f in file_chunk]
        x_i = await asyncio.gather(*tasks)
        x.extend(x_i)
    return x


def read_binary_files(
    files: list[str],
    max_length: Optional[int] = None,
    in_memory_dtype: Literal["bytes", "np", "pt"] = "bytes",
    disable_tqdm: bool = DEFAULT_DISABLE_TQDM,
) -> list[bytes | np.ndarray | ByteTensor]:

    iterable = files
    if not disable_tqdm:
        iterable = tqdm(
            files,
            desc=f"Synchronously loading {len(files)} files...",
        )

    return [read_binary_file(f, max_length, in_memory_dtype) for f in iterable]


def to_long_tensor(x: bytes | np.ndarray | ByteTensor) -> LongTensor:
    if isinstance(x, bytes):
        return torch.frombuffer(x, dtype=torch.uint8).to(torch.long)
    if isinstance(x, np.ndarray):
        return torch.from_numpy(x).to(torch.long)
    if isinstance(x, Tensor):
        return x.to(torch.long)
    raise TypeError(f"Unexpected type: {type(x)=}")


class BinaryDataset(ABC):

    def __init__(
        self,
        files: Sequence[os.PathLike],
        labels: Optional[Sequence[int]] = None,
        max_length: Optional[int] = None,
        preprocess_fn: Callable[[LongTensor], LongTensor] = lambda x: x,
        in_memory_dtype: Literal["bytes", "np", "pt"] = "pt",
        asynch: bool = True,
        asynch_chunk_size: int = 500000,
        id2label: Optional[dict[int, str]] = None,
        label2id: Optional[dict[str, int]] = None,
    ) -> None:

        self.files = list(map(str, files))
        self.labels = torch.tensor(labels, dtype=torch.long) if isinstance(labels, Sequence) else None
        self.max_length = max_length
        self.preprocess_fn = preprocess_fn
        self.asynch = asynch
        self.asynch_chunk_size = asynch_chunk_size
        self.in_memory_dtype = in_memory_dtype
        self._id2label = id2label
        self._label2id = label2id
        self._dist = self.get_dist()

    def __len__(self) -> int:
        return len(self.files)

    def __repr__(self) -> str:
        return str(self)

    def __str__(self) -> str:
        return (
            "BinaryDataset(\n"
            f"\t{len(self)=}\n"
            f"\t{type(self.labels)}\n"
            f"\t{self.max_length=}\n"
            f"\t{self.preprocess_fn=}\n"
            f"\t{self.in_memory_dtype=}\n"
            ")"
        )

    def get_dist(self):
        if self.labels is None:
            return None
        if isinstance(self.labels, (Tensor, np.ndarray)):
            labels = self.labels.tolist()
        return Counter([self.id2label[i] for i in labels])

    @property
    def dist(self) -> Counter[str]:
        return self._dist

    @property
    def id2label(self) -> dict[int, str]:
        if self._id2label is not None:
            return self._id2label
        raise NotImplementedError()

    @property
    def label2id(self) -> dict[str, int]:
        if self._label2id is not None:
            return self._label2id
        raise NotImplementedError()

    @property
    def num_classes(self) -> int:
        return len(self.dist)


class IterableBinaryDataset(IterableDataset, BinaryDataset):

    def __init__(
        self,
        files: Sequence[os.PathLike],
        labels: Optional[Sequence[int]] = None,
        max_length: Optional[int] = None,
        preprocess_fn: Callable[[LongTensor], LongTensor] = lambda x: x,
        in_memory_dtype: Literal["bytes", "np", "pt"] = "pt",
        asynch: bool = True,
        asynch_chunk_size: int = 500000,
        id2label: Optional[dict[int, str]] = None,
        label2id: Optional[dict[str, int]] = None,
        chunk_size: Optional[int] = None
    ) -> None:
        super().__init__(
            files,
            labels,
            max_length,
            preprocess_fn,
            in_memory_dtype,
            asynch,
            asynch_chunk_size,
            id2label,
            label2id,
        )
        self.chunk_size = self.asynch_chunk_size if chunk_size is None else chunk_size

        # Initialized by __iter__. These are unique to each worker process when num_workers > 1.
        # The sequences all have the same length, and my_idx indexes into them to pick out the
        # current sample for this worker.
        self.my_files: list[str] = None
        self.my_labels: Optional[LongTensor] = None
        self.my_x: list[Optional[bytes | np.ndarray | ByteTensor]] = None
        self.my_idx: int = None

    def __iter__(self):
        self.set_my_local_attributes()
        return self

    def set_my_local_attributes(self) -> None:
        worker_info = get_worker_info()

        if worker_info is None:
            start = 0
            end = len(self)
        else:
            per_worker = int(math.ceil((len(self) - 0) / float(worker_info.num_workers)))
            worker_id = worker_info.id
            start = 0 + worker_id * per_worker
            end = min(start + per_worker, len(self))

        self.my_length = end - start
        self.my_files = self.files[start:end]
        self.my_labels = self.labels[start:end] if self.labels is not None else None
        self.my_x = [None for _ in range(self.my_length)]
        self.my_idx = 0

    def __next__(self):
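        """Return the next sample, loading the next `chunk_size` files into memory
        when the current chunk is exhausted."""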
        if self.my_idx >= len(self.my_files):
            raise StopIteration()

        r = {"name": str(self.my_files[self.my_idx]).split("/")[-1]}
        if self.my_labels is not None:
            r["labels"] = self.my_labels[self.my_idx]

        if self.my_x[self.my_idx] is None:
            # Clean up the data from the last chunk
            self.my_x = [None for _ in range(self.my_length)]
            gc.collect()
            # Fetch the data for this chunk
            files = self.my_files[self.my_idx : self.my_idx + self.chunk_size]
            if self.asynch:
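                # As far as I can tell, this is the block the workers never return from
                # when num_workers > 0: the main process then blocks on the data queue
                # (see the traceback above).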
                loop = asyncio.get_event_loop()
                future = read_binary_files_asynch(
                    files,
                    self.max_length,
                    self.in_memory_dtype,
                    DEFAULT_DISABLE_TQDM,
                    self.asynch_chunk_size,
                )
                x = loop.run_until_complete(future)
            else:
                x = read_binary_files(files, self.max_length, self.in_memory_dtype)
            self.my_x[self.my_idx : self.my_idx + self.chunk_size] = x

        x_i = to_long_tensor(self.my_x[self.my_idx])
        x_i = self.preprocess_fn(x_i)[0:self.max_length]
        r["input_ids"] = x_i

        self.my_idx += 1
        return r


def test(
    files: list[os.PathLike],
    asynch_chunk_size: int,
    chunk_size: int,
    asynch: bool,
    num_workers: int,
):

    BATCH_SIZE = 4

    dataset = IterableBinaryDataset(
        files,
        None,
        4096,
        lambda x: x,
        "pt",
        asynch,
        asynch_chunk_size=asynch_chunk_size,
        chunk_size=chunk_size,
    )

    dataloader = DataLoader(dataset, BATCH_SIZE, num_workers=num_workers)
    for i, inputs in enumerate(dataloader):
        for n in inputs["name"]:
            print(i, n)


def main():

    random.seed(0)
    np.random.seed(0)
    torch.random.manual_seed(0)

    PATH = Path("/path/to/somewhere/with/some/files/to/read/bytes/from")
    NUM_SAMPLES = 20
    FILES = list(map(str, islice(PATH.iterdir(), NUM_SAMPLES)))
    ASYNCH_CHUNK_SIZE = 5
    CHUNK_SIZE = 10

    print("Running with asynch==False and num_workers==0.")
    test(FILES.copy(), ASYNCH_CHUNK_SIZE, CHUNK_SIZE, asynch=False, num_workers=0)

    print("Running with asynch==True and num_workers==0.")
    test(FILES.copy(), ASYNCH_CHUNK_SIZE, CHUNK_SIZE, asynch=True, num_workers=0)

    print("Running with asynch==False and num_workers==2.")
    test(FILES.copy(), ASYNCH_CHUNK_SIZE, CHUNK_SIZE, asynch=False, num_workers=2)

    # This hangs :(
    print("Running with asynch==True and num_workers==2.")
    test(FILES.copy(), ASYNCH_CHUNK_SIZE, CHUNK_SIZE, asynch=True, num_workers=2)


if __name__ == "__main__":
    main()

There are 0 answers