Dataset from_dict() can't load columns with list values


I'm trying to load multi-hot encoded labels into a Dataset (from the Hugging Face datasets library) using from_dict(), but when I iterate over it with a PyTorch DataLoader, each sample's label comes back with length batch_size instead of the number of label classes.

A short example:

from datasets import Dataset
from torch.utils.data import DataLoader

texts = ['a', 'b', 'c', 'd', 'e']
multihot_labels = [[0.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
dataset = Dataset.from_dict({'text': texts, 'label': multihot_labels})
loader = DataLoader(dataset, batch_size=5)
for batch in loader:
    print(batch['text'])
    print(batch['label'])
    break

and I got the following output:

['a', 'b', 'c', 'd', 'e']
[tensor([0., 0., 0., 0., 0.], dtype=torch.float64), tensor([0., 1., 1., 1., 0.], dtype=torch.float64), tensor([1., 1., 0., 0., 1.], dtype=torch.float64)]

The text output is as expected, but the labels come back as a list of 3 tensors of length 5 (3 by 5) rather than one 5 by 3 tensor with a row per sample. I'm confused by this and found nothing after a day of searching. Any idea?
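The swap can be reproduced with PyTorch alone, which suggests it comes from the DataLoader's default collate function rather than from from_dict(). When each row's label is a plain Python list, default_collate treats it as a sequence and zips the rows position by position, giving one tensor per label position. A minimal sketch, assuming the rows look like what the Dataset returns in its default Python format (a list of dicts stands in for the Dataset here):

```python
import torch
from torch.utils.data.dataloader import default_collate

# Five rows shaped like Dataset rows in the default (Python) format:
# the label is a plain list of floats, not a tensor.
samples = [
    {"text": "a", "label": [0.0, 0.0, 1.0]},
    {"text": "b", "label": [0.0, 1.0, 1.0]},
    {"text": "c", "label": [0.0, 1.0, 0.0]},
    {"text": "d", "label": [0.0, 1.0, 0.0]},
    {"text": "e", "label": [0.0, 0.0, 1.0]},
]

# default_collate zips the per-row lists element by element, so 'label'
# becomes a list of 3 float64 tensors, each holding one position for all 5 rows.
batch = default_collate(samples)
print(batch["text"])
print(len(batch["label"]), batch["label"][0].shape)

# If each label is already a tensor, default_collate stacks the rows
# into a single (5, 3) tensor instead.
stacked = default_collate([{**s, "label": torch.tensor(s["label"])} for s in samples])
print(stacked["label"].shape)
```

So the fix is either to have each row's label arrive as a tensor, or to replace the collate function.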

Thanks!

Answer by CDubyuh

Edit: Explicitly declaring columns

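from datasets import Dataset, Features, Sequence, Value
from torch.utils.data import DataLoader

texts = ['a', 'b', 'c', 'd', 'e']
multihot_labels = [[0.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

# Declare every column: 'label' is a list of floats (multi-hot), not a
# ClassLabel, which encodes a single integer class per row.
features = Features({'text': Value('string'), 'label': Sequence(Value('float32'))})
dataset = Dataset.from_dict({'text': texts, 'label': multihot_labels}, features=features)

# Return each row's label as a torch tensor, so the default collate
# function stacks rows into a (batch_size, 3) tensor instead of zipping lists.
dataset = dataset.with_format('torch')

loader = DataLoader(dataset, batch_size=5)
for batch in loader:
    print(batch['text'])   # list of 5 strings
    print(batch['label'])  # float32 tensor of shape (5, 3)
    break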
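If you would rather keep the dataset in its default Python format, another option is to pass DataLoader a custom collate_fn that builds the label tensor itself. A minimal sketch: collate_multihot is a hypothetical helper name, and a plain list of row dicts stands in for the Dataset so the example needs only PyTorch. The same function should also work with the original Dataset, since collate_fn receives the batch as a list of row dicts.

```python
import torch
from torch.utils.data import DataLoader

texts = ['a', 'b', 'c', 'd', 'e']
multihot_labels = [[0.0, 0.0, 1.0], [0.0, 1.0, 1.0], [0.0, 1.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]

# Stand-in for the Dataset: a list of row dicts is a valid map-style dataset.
rows = [{"text": t, "label": l} for t, l in zip(texts, multihot_labels)]

def collate_multihot(examples):
    """Hypothetical collate function: keep the texts as a list and build one
    (batch_size, num_labels) float32 tensor from the per-row label lists."""
    return {
        "text": [ex["text"] for ex in examples],
        "label": torch.tensor([ex["label"] for ex in examples], dtype=torch.float32),
    }

loader = DataLoader(rows, batch_size=5, collate_fn=collate_multihot)
batch = next(iter(loader))
print(batch["text"])
print(batch["label"].shape)
```

Building the tensor in one torch.tensor call keeps row order intact, so each row of batch["label"] lines up with the matching entry in batch["text"].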