I'm training a classifier on raw-bytes in binary files. Each file can be several megabytes long. I have a dataset with several millions of binaries in it. Loading all this data at once and keeping it all in memory is very expensive, so I wanted to write an IterableDataset to do it on-the-fly during the training loop.
I use asyncio to read thousands of files quickly. Through some painful trial and error, I found that my asynchronous file reading code can only process 500000 files at once, so I have a variable asynchronous_chunk_size that controls how many asyncio tasks are launched at once. I also included a flag, asynch, which when False, causes the program to read all the files synchronously.
My IterableDataset works fine when DataLoader.num_workers==0. It also works fine when DataLoader.num_workers > 0 and asynch is False, i.e., the asynchronous file-reading is disabled and the program reads files synchronously. However, the program hangs when I try to read files asynchronously with DataLoader.num_workers > 0, and I have no idea why (I'm a mediocre programmer).
FYI, pytorch uses multiprocessing when DataLoader.num_workers > 0. I think the problem might be related to this issue, but I can't understand it.
When I press Ctrl-C on the code that hangs, this is the stack trace that is printed:
^CTraceback (most recent call last):
File "/home/lk3591/Documents/code/RawByteClf/src/data/minimal.py", line 343, in <module>
main()
File "/home/lk3591/Documents/code/RawByteClf/src/data/minimal.py", line 339, in main
test(FILES.copy(), ASYNCH_CHUNK_SIZE, CHUNK_SIZE, asynch=True, num_workers=2)
File "/home/lk3591/Documents/code/RawByteClf/src/data/minimal.py", line 310, in test
for i, inputs in enumerate(dataloader):
File "/home/lk3591/anaconda3/envs/RawByteClf/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 633, in __next__
data = self._next_data()
File "/home/lk3591/anaconda3/envs/RawByteClf/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1328, in _next_data
idx, data = self._get_data()
File "/home/lk3591/anaconda3/envs/RawByteClf/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1294, in _get_data
success, data = self._try_get_data()
File "/home/lk3591/anaconda3/envs/RawByteClf/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1132, in _try_get_data
data = self._data_queue.get(timeout=timeout)
File "/home/lk3591/anaconda3/envs/RawByteClf/lib/python3.10/multiprocessing/queues.py", line 113, in get
if not self._poll(timeout):
File "/home/lk3591/anaconda3/envs/RawByteClf/lib/python3.10/multiprocessing/connection.py", line 257, in poll
return self._poll(timeout)
File "/home/lk3591/anaconda3/envs/RawByteClf/lib/python3.10/multiprocessing/connection.py", line 424, in _poll
r = wait([self], timeout)
File "/home/lk3591/anaconda3/envs/RawByteClf/lib/python3.10/multiprocessing/connection.py", line 931, in wait
ready = selector.select(timeout)
File "/home/lk3591/anaconda3/envs/RawByteClf/lib/python3.10/selectors.py", line 416, in select
fd_event_list = self._selector.poll(timeout)
KeyboardInterrupt
Below is my program. It's kind of long, but is fairly well-written, so hopefully it's easy to read. To run this program, replace the PATH variable at the bottom with a path to a directory that contains around 20 files. It doesn't matter what kind of files they are, as my program will only read their raw bytes. For reference, I'm running Python 3.10.13 and pytorch 2.0.1 on CentOS 9.
Any help you can provide would be amazing! Thanks!
"""
"""
# Module-level trace: runs whenever this file is executed or imported and
# prints the path of the file being loaded.
print(f"Entered {__file__=}")
from abc import ABC
import asyncio
from collections import Counter
from collections.abc import Iterable, Sequence
from functools import partial
import gc
from itertools import islice
import math
import os
import random
from pathlib import Path
from typing import Callable, Literal, Optional
import numpy as np
import pandas as pd
import torch
from torch import ByteTensor, LongTensor, Tensor
from torch.utils.data import DataLoader, IterableDataset, get_worker_info
from tqdm import tqdm
# Default number of files whose read tasks are gathered by asyncio at once.
DEFAULT_ASYNCH_CHUNK_SIZE = 500000
# Default in-memory container for file contents: "bytes", "np" or "pt".
DEFAULT_IN_MEMORY_DTYPE = "pt"
# Default for the `disable_tqdm` flags; False means progress bars are shown.
DEFAULT_DISABLE_TQDM = False
def batched(iterable: Iterable, n: int):
    """Yield consecutive tuples of at most ``n`` items taken from ``iterable``.

    The final tuple may be shorter than ``n``. Nothing is yielded for an
    empty iterable.
    """
    source = iter(iterable)
    while True:
        group = tuple(islice(source, n))
        if not group:
            # The underlying iterator is exhausted.
            return
        yield group
def read_binary_file(
    f: Path,
    max_length: Optional[int] = None,
    in_memory_dtype: Literal["bytes", "np", "pt"] = DEFAULT_IN_MEMORY_DTYPE,
) -> bytes | np.ndarray | ByteTensor:
    """Read the raw bytes of one file and return them in the requested container.

    Args:
        f: Path of the file to read.
        max_length: Maximum number of bytes to read; ``None`` reads the whole
            file.
        in_memory_dtype: ``"bytes"`` returns the ``bytes`` object as read,
            ``"np"`` returns a ``uint8`` array built with ``np.frombuffer``
            and ``"pt"`` returns a ``uint8`` tensor built with
            ``torch.frombuffer``. The array/tensor wraps the immutable
            ``bytes`` buffer, so torch may emit
            "UserWarning: The given buffer is not writable...".

    Returns:
        The file contents (truncated to ``max_length`` bytes) in the
        requested container.

    Raises:
        ValueError: If ``in_memory_dtype`` is not one of the supported names.
    """
    with open(f, "rb") as fp:
        b = fp.read(max_length)
    if in_memory_dtype == "bytes":
        return b
    if in_memory_dtype == "np":
        return np.frombuffer(b, dtype=np.uint8)
    if in_memory_dtype == "pt":
        if not b:
            # torch.frombuffer refuses zero-length buffers, so an empty file
            # (or max_length == 0) needs an explicitly constructed tensor.
            return torch.empty(0, dtype=torch.uint8)
        return torch.frombuffer(b, dtype=torch.uint8)
    raise ValueError(f"Unknown {in_memory_dtype=}")
async def read_binary_file_asynch(
    f: Path,
    max_length: Optional[int] = None,
    in_memory_dtype: Literal["bytes", "np", "pt"] = DEFAULT_IN_MEMORY_DTYPE,
) -> bytes | np.ndarray | ByteTensor:
    """Awaitable version of ``read_binary_file``.

    The blocking read is handed to the running event loop's default executor
    and its result is returned once the executor finishes.
    """
    event_loop = asyncio.get_running_loop()
    blocking_read = partial(read_binary_file, f, max_length, in_memory_dtype)
    contents = await event_loop.run_in_executor(None, blocking_read)
    return contents
async def read_binary_files_asynch(
    files: list[str],
    max_length: Optional[int] = None,
    in_memory_dtype: Literal["bytes", "np", "pt"] = DEFAULT_IN_MEMORY_DTYPE,
    disable_tqdm: bool = DEFAULT_DISABLE_TQDM,
    asynch_chunk_size: int = DEFAULT_ASYNCH_CHUNK_SIZE,
) -> list[bytes | np.ndarray | ByteTensor]:
    """Read many files concurrently, ``asynch_chunk_size`` files at a time.

    Args:
        files: Paths of the files to read.
        max_length: Maximum number of bytes to read from each file.
        in_memory_dtype: Container for each file's contents, forwarded to
            ``read_binary_file_asynch``.
        disable_tqdm: When False, a progress bar over the chunks is shown.
        asynch_chunk_size: Number of read tasks gathered at once.

    Returns:
        One entry per file, in the same order as ``files``.

    Raises:
        ValueError: If ``asynch_chunk_size`` is smaller than 1.
    """
    if asynch_chunk_size < 1:
        # A zero chunk size would otherwise divide by zero below or, with the
        # progress bar disabled, silently return no data at all.
        raise ValueError(f"asynch_chunk_size must be at least 1, got {asynch_chunk_size}")
    file_chunks = batched(files, asynch_chunk_size)
    if not disable_tqdm:
        n_chunks = math.ceil(len(files) / asynch_chunk_size)
        file_chunks = tqdm(
            file_chunks,
            desc=f"Asynchronously loading {len(files)} files in {n_chunks} chunks...",
            total=n_chunks,
        )
    x = []
    # `chunk` deliberately does not reuse the name `files`, so the parameter
    # stays intact for the whole function.
    for chunk in file_chunks:
        tasks = [read_binary_file_asynch(f, max_length, in_memory_dtype) for f in chunk]
        # gather() returns results in the order the awaitables were passed.
        x.extend(await asyncio.gather(*tasks))
    return x
def read_binary_files(
    files: list[str],
    max_length: Optional[int] = None,
    in_memory_dtype: Literal["bytes", "np", "pt"] = "bytes",
    disable_tqdm: bool = DEFAULT_DISABLE_TQDM,
) -> list[bytes | np.ndarray | ByteTensor]:
    """Read every file in ``files`` one after another.

    Each path is passed to ``read_binary_file`` together with ``max_length``
    and ``in_memory_dtype``; the results are returned in input order. A tqdm
    progress bar wraps the iteration unless ``disable_tqdm`` is true.
    """
    if disable_tqdm:
        source = files
    else:
        source = tqdm(
            files,
            desc=f"Synchronously loading {len(files)} files...",
        )
    contents = []
    for path in source:
        contents.append(read_binary_file(path, max_length, in_memory_dtype))
    return contents
def to_long_tensor(x: bytes | np.ndarray | ByteTensor) -> LongTensor:
if isinstance(x, bytes):
return torch.frombuffer(x, dtype=torch.uint8).to(torch.long)
if isinstance(x, np.ndarray):
return torch.from_numpy(x).to(torch.long)
if isinstance(x, Tensor):
return x.to(torch.long)
raise TypeError(f"Unexpected type: {type(x)=}")
class BinaryDataset(ABC):
    """Base container pairing binary file paths with optional integer labels.

    Stores the reading configuration (``max_length``, ``in_memory_dtype``,
    ``asynch``, ``asynch_chunk_size``) and, when labels are given, the
    distribution of label names over the dataset.
    """

    def __init__(
        self,
        files: Sequence[os.PathLike],
        labels: Optional[Sequence[int]] = None,
        max_length: Optional[int] = None,
        preprocess_fn: Callable[[LongTensor], LongTensor] = lambda x: x,
        in_memory_dtype: Literal["bytes", "np", "pt"] = "pt",
        asynch: bool = True,
        asynch_chunk_size: int = 500000,
        id2label: Optional[dict[int, str]] = None,
        label2id: Optional[dict[str, int]] = None,
    ) -> None:
        """Store the file list, labels and reading configuration.

        Args:
            files: Paths of the binary files; each is converted with ``str``.
            labels: Integer label per file, or ``None`` for an unlabeled
                dataset. Lists/tuples, numpy arrays and torch tensors are all
                accepted and stored as a ``torch.long`` tensor.
            max_length: Maximum number of bytes kept per file.
            preprocess_fn: Function applied to each sample's long tensor.
            in_memory_dtype: Container used for file contents in memory.
            asynch: Whether files should be read asynchronously.
            asynch_chunk_size: Number of asynchronous read tasks run at once.
            id2label: Mapping from label id to label name.
            label2id: Mapping from label name to label id.

        Raises:
            NotImplementedError: If ``labels`` are given but no ``id2label``
                mapping is available (raised by the ``id2label`` property).
        """
        self.files = list(map(str, files))
        self.labels = self._as_label_tensor(labels)
        self.max_length = max_length
        self.preprocess_fn = preprocess_fn
        self.asynch = asynch
        self.asynch_chunk_size = asynch_chunk_size
        self.in_memory_dtype = in_memory_dtype
        self._id2label = id2label
        self._label2id = label2id
        self._dist = self.get_dist()

    @staticmethod
    def _as_label_tensor(labels) -> Optional[LongTensor]:
        """Convert ``labels`` to a ``torch.long`` tensor, or None if absent."""
        if labels is None:
            return None
        if isinstance(labels, Tensor):
            # Copy so later changes to the caller's tensor do not leak in.
            return labels.detach().clone().to(torch.long)
        # Testing only for collections.abc.Sequence would silently drop numpy
        # arrays, which are not registered as Sequence.
        if not isinstance(labels, (Sequence, np.ndarray)):
            labels = list(labels)
        return torch.tensor(labels, dtype=torch.long)

    def __len__(self) -> int:
        """Return the number of files in the dataset."""
        return len(self.files)

    def __repr__(self) -> str:
        return str(self)

    def __str__(self) -> str:
        """Return a multi-line summary of the dataset configuration."""
        return (
            "BinaryDataset(\n"
            f"\t{len(self)=}\n"
            f"\t{type(self.labels)}\n"
            f"\t{self.max_length=}\n"
            f"\t{self.preprocess_fn=}\n"
            f"\t{self.in_memory_dtype=}\n"
            ")"
        )

    def get_dist(self):
        """Count how often each label name occurs; None when unlabeled."""
        if self.labels is None:
            return None
        # self.labels is always a tensor here, so tolist() is safe.
        return Counter([self.id2label[i] for i in self.labels.tolist()])

    @property
    def dist(self) -> Optional[Counter[str]]:
        """Label-name counts computed at construction (None when unlabeled)."""
        return self._dist

    @property
    def id2label(self) -> dict[int, str]:
        """Mapping from label id to name; subclasses may override."""
        if self._id2label is not None:
            return self._id2label
        raise NotImplementedError()

    @property
    def label2id(self) -> dict[str, int]:
        """Mapping from label name to id; subclasses may override."""
        if self._label2id is not None:
            return self._label2id
        raise NotImplementedError()

    @property
    def num_classes(self) -> int:
        """Number of distinct label names seen in the dataset.

        Raises TypeError for an unlabeled dataset, because ``dist`` is None.
        """
        return len(self.dist)
class IterableBinaryDataset(IterableDataset, BinaryDataset):
    """Iterable dataset that loads file contents lazily, one chunk at a time.

    The instance is its own iterator. ``__iter__`` assigns the calling process
    its share of the files (all of them without DataLoader workers, otherwise
    one contiguous slice per worker) and ``__next__`` yields one sample dict
    at a time, reading ``chunk_size`` files whenever the next sample's
    contents are not in memory yet.
    """

    def __init__(
        self,
        files: Sequence[os.PathLike],
        labels: Optional[Sequence[int]] = None,
        max_length: Optional[int] = None,
        preprocess_fn: Callable[[LongTensor], LongTensor] = lambda x: x,
        in_memory_dtype: Literal["bytes", "np", "pt"] = "pt",
        asynch: bool = True,
        asynch_chunk_size: int = 500000,
        id2label: Optional[dict[int, str]] = None,
        label2id: Optional[dict[str, int]] = None,
        chunk_size: Optional[int] = None
    ) -> None:
        """Configure the dataset; see ``BinaryDataset`` for the shared arguments.

        Args:
            chunk_size: Number of files read into memory at once. Defaults to
                ``asynch_chunk_size`` when None.

        Raises:
            ValueError: If the effective chunk size is smaller than 1.
        """
        super().__init__(
            files,
            labels,
            max_length,
            preprocess_fn,
            in_memory_dtype,
            asynch,
            asynch_chunk_size,
            id2label,
            label2id,
        )
        self.chunk_size = self.asynch_chunk_size if chunk_size is None else chunk_size
        if self.chunk_size < 1:
            # A non-positive chunk would load nothing and fail later with an
            # unrelated TypeError inside __next__.
            raise ValueError(f"chunk_size must be at least 1, got {self.chunk_size}")
        # Initialized by set_my_local_attributes(). These are unique to each
        # process when num_workers > 0 and all have the same length; my_idx
        # is the position of the next sample within them.
        self.my_files: Optional[list[str]] = None
        self.my_labels: Optional[LongTensor] = None
        self.my_x: Optional[list[Optional[bytes | np.ndarray | ByteTensor]]] = None
        self.my_idx: Optional[int] = None
        self.my_length: Optional[int] = None

    def __iter__(self):
        """Reset this process's shard and return the dataset as the iterator."""
        self.set_my_local_attributes()
        return self

    def set_my_local_attributes(self) -> None:
        """Select the slice of files/labels this process iterates over."""
        worker_info = get_worker_info()
        if worker_info is None:
            # Single-process loading: this process handles every file.
            start = 0
            end = len(self)
        else:
            # Split the files into contiguous, near-equal slices per worker.
            per_worker = int(math.ceil(len(self) / float(worker_info.num_workers)))
            start = worker_info.id * per_worker
            end = min(start + per_worker, len(self))
        self.my_length = end - start
        self.my_files = self.files[start:end]
        self.my_labels = self.labels[start:end] if self.labels is not None else None
        self.my_x = [None for _ in range(self.my_length)]
        self.my_idx = 0

    def __next__(self):
        """Return the next sample as a dict.

        The dict holds ``"name"`` (last path component), ``"input_ids"``
        (preprocessed long tensor cut to ``max_length``) and, for labeled
        datasets, ``"labels"``.

        Raises:
            StopIteration: When this process's shard is exhausted.
        """
        if self.my_files is None:
            # next() was called without iter(); set the shard up on demand.
            self.set_my_local_attributes()
        if self.my_idx >= len(self.my_files):
            raise StopIteration()
        r = {"name": str(self.my_files[self.my_idx]).split("/")[-1]}
        if self.my_labels is not None:
            r["labels"] = self.my_labels[self.my_idx]
        if self.my_x[self.my_idx] is None:
            self._load_chunk()
        x_i = to_long_tensor(self.my_x[self.my_idx])
        x_i = self.preprocess_fn(x_i)[0:self.max_length]
        r["input_ids"] = x_i
        self.my_idx += 1
        return r

    def _load_chunk(self) -> None:
        """Drop the previous chunk and read the chunk starting at ``my_idx``."""
        # Clean up the data from the last chunk.
        self.my_x = [None for _ in range(self.my_length)]
        gc.collect()
        files = self.my_files[self.my_idx : self.my_idx + self.chunk_size]
        if self.asynch:
            # asyncio.run() builds a fresh event loop (with a fresh default
            # executor) for every chunk and closes it afterwards. Reusing the
            # loop from asyncio.get_event_loop() hangs in forked DataLoader
            # workers: the child inherits the parent's loop object, but the
            # executor threads that loop relies on do not survive fork(), so
            # the queued file reads are never executed.
            x = asyncio.run(
                read_binary_files_asynch(
                    files,
                    self.max_length,
                    self.in_memory_dtype,
                    DEFAULT_DISABLE_TQDM,
                    self.asynch_chunk_size,
                )
            )
        else:
            x = read_binary_files(files, self.max_length, self.in_memory_dtype)
        self.my_x[self.my_idx : self.my_idx + self.chunk_size] = x
def test(
    files: list[os.PathLike],
    asynch_chunk_size: int,
    chunk_size: int,
    asynch: bool,
    num_workers: int,
):
    """Iterate an ``IterableBinaryDataset`` through a DataLoader and print names.

    Builds an unlabeled dataset over ``files`` (4096-byte limit, identity
    preprocessing, "pt" storage), loads it in batches of four with
    ``num_workers`` workers and prints the batch index next to each file name.
    """
    BATCH_SIZE = 4
    dataset = IterableBinaryDataset(
        files=files,
        labels=None,
        max_length=4096,
        preprocess_fn=lambda x: x,
        in_memory_dtype="pt",
        asynch=asynch,
        asynch_chunk_size=asynch_chunk_size,
        chunk_size=chunk_size,
    )
    dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, num_workers=num_workers)
    for batch_index, batch in enumerate(dataloader):
        for name in batch["name"]:
            # Each name is printed as a one-element tuple.
            print(batch_index, (name,))
def main():
    """Run the loading test for every asynch / num_workers combination.

    Seeds the random generators, collects up to ``NUM_SAMPLES`` files from
    ``PATH`` and calls ``test`` once per configuration, announcing each one.
    """
    for seed_fn in (random.seed, np.random.seed, torch.random.manual_seed):
        seed_fn(0)
    PATH = Path("/path/to/somewhere/with/some/files/to/read/bytes/from")
    NUM_SAMPLES = 20
    FILES = [str(p) for p in islice(PATH.iterdir(), NUM_SAMPLES)]
    ASYNCH_CHUNK_SIZE = 5
    CHUNK_SIZE = 10
    # The last configuration (asynch=True with two workers) is the one that
    # was reported to hang.
    runs = (
        ("Running with asynch==False and num_workers==0.", False, 0),
        ("Running with asynch==True and num_workers==0.", True, 0),
        ("Running with asynch==False and num_workers==2.", False, 2),
        ("Running with asynch==True and num_workers==2.", True, 2),
    )
    for banner, use_asynch, worker_count in runs:
        print(banner)
        test(
            FILES.copy(),
            ASYNCH_CHUNK_SIZE,
            CHUNK_SIZE,
            asynch=use_asynch,
            num_workers=worker_count,
        )
# Run the demonstration only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    main()