I am trying to use a specific version of a Vision Transformer model I found on GitHub to train on my own dataset. My data is a (400, 3, 224, 224) tensor and my labels are a (400,) tensor in PyTorch. The problem is that the code on GitHub uses cfg: DictConfig as the input to its get_train_loader function, and I have no idea how that works. I have tried to feed my tensors in directly like this:
# Forward pass
outputs = model(inputs)
But I get the following error:
KeyError: 'channels'
So I assumed I have to somehow convert my data and label tensors into a dataset with a DictConfig. I am a beginner in PyTorch and I have never used the omegaconf library, so any help would be really appreciated!
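From skimming the omegaconf docs, a DictConfig seems to just be a nested, dot-accessible dictionary. This toy example of mine works, but I still don't see how to get from my two tensors to the full config structure this repo expects:

from omegaconf import OmegaConf

# Toy example of my own, just to see what a DictConfig is
cfg = OmegaConf.create({"trainer": {"devices": 1}})
print(cfg.trainer.devices)  # attribute-style access -> 1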
Here's the code for the patch encoder:
import random

import torch.nn as nn
from timm.models.layers import trunc_normal_


class PatchEmbedPerChannel(nn.Module):
    """Image to Patch Embedding."""

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        enable_sample: bool = True,
    ):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size) * in_chans
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv3d(
            1,
            embed_dim,
            kernel_size=(1, patch_size, patch_size),
            stride=(1, patch_size, patch_size),
        )  # CHANGED

        self.channel_embed = nn.Embedding(in_chans, embed_dim)
        self.enable_sample = enable_sample

        trunc_normal_(self.channel_embed.weight, std=0.02)

    def forward(self, x, extra_tokens={}):
        # assume all images in the same batch have the same input channels
        # cur_channels = extra_tokens["channels"][0]

        # embedding lookup
        cur_channel_embed = self.channel_embed(
            extra_tokens["channels"]
        )  # B, Cin, embed_dim=Cout
        cur_channel_embed = cur_channel_embed.permute(0, 2, 1)  # B Cout Cin

        B, Cin, H, W = x.shape
        # Note: The current number of channels (Cin) can be smaller or equal to in_chans

        if self.training and self.enable_sample:
            # Per-batch channel sampling. Note this may be slow.
            # Randomly sample the number of channels for this batch
            Cin_new = random.randint(1, Cin)

            # Randomly sample the selected channels
            channels = random.sample(range(Cin), k=Cin_new)
            Cin = Cin_new
            x = x[:, channels, :, :]

            # Update the embedding lookup
            cur_channel_embed = cur_channel_embed[:, :, channels]

        # shared projection layer across channels
        x = self.proj(x.unsqueeze(1))  # B Cout Cin H W

        # channel specific offsets
        x += cur_channel_embed.unsqueeze(-1).unsqueeze(-1)
        # x += self.channel_embed[:, :, cur_channels, :, :]  # B Cout Cin H W

        # preparing the output sequence
        x = x.flatten(2)  # B Cout CinHW
        x = x.transpose(1, 2)  # B CinHW Cout

        return x, Cin
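For what it's worth, I can get this patch encoder to run on its own if I hand-build the extra_tokens dict with per-sample channel indices. This is just my own experiment, so I'm not sure this is how the authors intend it to be used:

import torch

patcher = PatchEmbedPerChannel(img_size=224, patch_size=16, in_chans=3)
x = torch.randn(8, 3, 224, 224)  # small dummy batch
# one row of channel indices (0, 1, 2) per sample -- my guess at the expected format
channels = torch.arange(3).unsqueeze(0).repeat(8, 1)
tokens, cin = patcher(x, extra_tokens={"channels": channels})
print(tokens.shape)  # (8, cin * 14 * 14, 768) after random channel sampling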
And here is the get_train_loader mentioned above:
from omegaconf import DictConfig, open_dict
from torch.utils.data import DataLoader

import channelvit.data as data  # the repo's dataset module (assuming the original layout)


def get_train_loader(cfg: DictConfig):
    # Define the training data loader.
    if len(cfg.train_data) == 1:
        print("There is only one training data")
        train_data_cfg = next(iter(cfg.train_data.values()))
        with open_dict(cfg):
            cfg.train_data = train_data_cfg

        train_data = getattr(data, train_data_cfg.name)(
            is_train=True,
            transform_cfg=cfg.train_transformations,
            **train_data_cfg.args,
        )

        train_loader = DataLoader(
            train_data, **train_data_cfg.loader, collate_fn=train_data.collate_fn
        )

        # We also need to pre-compute the number of batches for each epoch.
        # We will use this information for the learning rate schedule.
        with open_dict(cfg):
            # get number of batches per epoch (many optimizers use this
            # information to schedule the learning rate)
            cfg.train_data.loader.num_batches = (
                len(train_loader) // cfg.trainer.devices + 1
            )

        return train_loader
    else:
        print("There're more than one training data")
        train_loaders = {}
        len_loader = None
        batch_size = 0

        for name, train_data_cfg in cfg.train_data.items():
            print(f"Loading {train_data_cfg.name}")
            train_data = getattr(data, train_data_cfg.name)(
                is_train=True,
                transform_cfg=cfg.train_transformations,
                **train_data_cfg.args,
            )
            train_loader = DataLoader(
                train_data, **train_data_cfg.loader, collate_fn=train_data.collate_fn
            )
            train_loaders[name] = train_loader

            print(f"Dataset {name} has length {len(train_loader)}")
            if len_loader is None:
                len_loader = len(train_loader)
            else:
                len_loader = max(len_loader, len(train_loader))

            # batch_size += train_data_cfg.loader.batch_size
            batch_size = train_data_cfg.loader.batch_size

        with open_dict(cfg):
            cfg.train_data.loader = {}
            cfg.train_data.loader.num_batches = len_loader // cfg.trainer.devices + 1
            cfg.train_data.loader.batch_size = batch_size

        return train_loaders
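In the meantime, my best guess at a workaround is to skip get_train_loader entirely and wrap my tensors in a plain TensorDataset, injecting the channel indices myself in the training loop. Here my_data and my_labels are my two tensors from above, model is the instantiated ViT, and I'm assuming its forward passes the same extra_tokens dict down to the patch encoder. Is something like this a reasonable direction?

import torch
from torch.utils.data import TensorDataset, DataLoader

dataset = TensorDataset(my_data, my_labels)  # (400, 3, 224, 224) and (400,)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

for inputs, targets in loader:
    B, C = inputs.shape[0], inputs.shape[1]
    # channel indices for every sample in the batch -- assuming the same
    # format the patch encoder expects above
    channels = torch.arange(C).unsqueeze(0).repeat(B, 1)
    outputs = model(inputs, extra_tokens={"channels": channels})
    # ... compute loss against targets, backprop, etc.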