What is the function of default_loader in torch?

import os
from torchvision.datasets.folder import default_loader
from torch.utils.data import Dataset

class Sample_Class(Dataset):
    def __init__(self, root, train=True, transform=None, loader=default_loader):
        self.root = os.path.expanduser(root)
        self.transform = transform
        # Store the loader that was passed in (defaults to default_loader)
        self.loader = loader

In the code snippet above, what is the significance of loader=default_loader? What exactly does it do?

Best answer, by jodag:

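This Sample_Class is likely imitating the behavior of ImageFolder, DatasetFolder, and ImageNet, which take a similar loader argument. The loader is a function that takes a file path as input and returns the loaded image, either a PIL.Image or an accimage.Image depending on the selected image backend. The dataset calls it on each file path when a sample is requested.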
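To make the role of the loader argument concrete, here is a minimal sketch (not torchvision's actual implementation) of a dataset that stores only file paths and calls whatever loader it was given inside __getitem__. To keep it runnable without torchvision or Pillow, SampleFolder and bytes_loader are hypothetical names: bytes_loader returns raw file bytes where default_loader would return a PIL.Image, and the class does not subclass torch.utils.data.Dataset, although a real one would.

```python
import os
from typing import Any, Callable, List, Optional


def bytes_loader(path: str) -> bytes:
    # Hypothetical stand-in for default_loader: returns the raw file bytes
    # instead of a PIL.Image, so this sketch runs without torchvision/Pillow.
    with open(path, "rb") as f:
        return f.read()


class SampleFolder:
    """Hypothetical map-style dataset; a real one would subclass torch.utils.data.Dataset."""

    def __init__(self, root: str, transform: Optional[Callable] = None,
                 loader: Callable[[str], Any] = bytes_loader):
        self.root = os.path.expanduser(root)
        self.transform = transform
        self.loader = loader  # the injected "path -> sample" function
        # Index file paths once; files are only read lazily in __getitem__.
        self.paths: List[str] = sorted(
            os.path.join(self.root, name)
            for name in os.listdir(self.root)
            if os.path.isfile(os.path.join(self.root, name))
        )

    def __len__(self) -> int:
        return len(self.paths)

    def __getitem__(self, index: int) -> Any:
        # The loader alone decides how a file on disk becomes a sample.
        sample = self.loader(self.paths[index])
        if self.transform is not None:
            sample = self.transform(sample)
        return sample
```

With torchvision installed you would pass loader=default_loader (or omit it in a class that uses it as the default) to get RGB PIL images, or substitute any other function with the same path-in, sample-out shape.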

The default_loader function is defined in torchvision/datasets/folder.py:

from PIL import Image


def pil_loader(path):
    # open path as file to avoid ResourceWarning (https://github.com/python-pillow/Pillow/issues/835)
    with open(path, 'rb') as f:
        img = Image.open(f)
        return img.convert('RGB')


def accimage_loader(path):
    import accimage
    try:
        return accimage.Image(path)
    except IOError:
        # Potentially a decoding problem, fall back to PIL.Image
        return pil_loader(path)


def default_loader(path):
    from torchvision import get_image_backend
    if get_image_backend() == 'accimage':
        return accimage_loader(path)
    else:
        return pil_loader(path)

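Note: unless the image backend has been switched with torchvision.set_image_backend('accimage'), default_loader uses pil_loader, so it returns an RGB PIL.Image.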
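The structure of default_loader is a small dispatch on a global backend setting plus a fallback. The sketch below reproduces that pattern with stdlib-only stand-ins; set_backend, get_backend, fast_decode, fast_loader, safe_loader and dispatching_loader are hypothetical names standing in for torchvision.set_image_backend, get_image_backend, accimage.Image, accimage_loader and pil_loader. It shows that the backend is checked on every call, and that a decode error in the fast path falls back to the PIL-style loader.

```python
# Hypothetical module-level setting mirroring torchvision's
# set_image_backend()/get_image_backend() pair.
_backend = "PIL"


def set_backend(name: str) -> None:
    global _backend
    if name not in ("PIL", "accimage"):
        raise ValueError(f"unknown backend {name!r}")
    _backend = name


def get_backend() -> str:
    return _backend


def fast_decode(path: str) -> str:
    # Hypothetical stand-in for accimage.Image(path); pretend it cannot decode .png.
    if path.endswith(".png"):
        raise OSError("decode failed")
    return f"fast:{path}"


def safe_loader(path: str) -> str:
    # Hypothetical stand-in for pil_loader.
    return f"pil:{path}"


def fast_loader(path: str) -> str:
    # Mirrors accimage_loader: try the fast decoder, fall back on OSError.
    try:
        return fast_decode(path)
    except OSError:
        return safe_loader(path)


def dispatching_loader(path: str) -> str:
    # Mirrors default_loader: the backend is read on every call.
    if get_backend() == "accimage":
        return fast_loader(path)
    return safe_loader(path)
```

Because the check happens at call time, changing the backend after a dataset has been constructed still changes what its loader returns.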