How to convert 4D tensor to Dataset with omegaconf DictConfig

I am trying to use a specific version of a Vision Transformer model I found on GitHub to train on my own dataset. My data is a (400, 3, 224, 224) tensor and my labels are a (400,) tensor in PyTorch. The problem is that the code on GitHub uses cfg: DictConfig as the input to its get_train_loader function, and I frankly have no clue how this works. I have tried to use my tensors as input like this:

# Forward pass
outputs = model(inputs)

But I get the following error:

KeyError: 'channels'

So I assumed I have to somehow convert my data and labels tensors into a dataset that works with this DictConfig setup. I am a beginner in PyTorch and I have never used the omegaconf library, so any help would be really appreciated!
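
For what it's worth, I can wrap my tensors in a plain TensorDataset, which at least gives me batches (this part works, but it bypasses the DictConfig machinery entirely):

import torch
from torch.utils.data import TensorDataset, DataLoader

data = torch.randn(400, 3, 224, 224)   # stand-in for my image tensor
labels = torch.randint(0, 10, (400,))  # stand-in for my label tensor

dataset = TensorDataset(data, labels)
loader = DataLoader(dataset, batch_size=32, shuffle=True)

But this still doesn't produce the extra_tokens["channels"] that the patch encoder below seems to expect.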

Here's the code for the patch encoder:

import random

import torch.nn as nn
from timm.models.layers import trunc_normal_  # torch.nn.init.trunc_normal_ works too


class PatchEmbedPerChannel(nn.Module):
    """Image to Patch Embedding."""

    def __init__(
        self,
        img_size: int = 224,
        patch_size: int = 16,
        in_chans: int = 3,
        embed_dim: int = 768,
        enable_sample: bool = True,
    ):
        super().__init__()
        num_patches = (img_size // patch_size) * (img_size // patch_size) * in_chans
        self.img_size = img_size
        self.patch_size = patch_size
        self.num_patches = num_patches

        self.proj = nn.Conv3d(
            1,
            embed_dim,
            kernel_size=(1, patch_size, patch_size),
            stride=(1, patch_size, patch_size),
        )  # CHANGED

        self.channel_embed = nn.Embedding(in_chans, embed_dim)
        self.enable_sample = enable_sample

        trunc_normal_(self.channel_embed.weight, std=0.02)

    def forward(self, x, extra_tokens={}):
        # # assume all images in the same batch has the same input channels
        # cur_channels = extra_tokens["channels"][0]
        # embedding lookup
        cur_channel_embed = self.channel_embed(
            extra_tokens["channels"]
        )  # B, Cin, embed_dim=Cout
        cur_channel_embed = cur_channel_embed.permute(0, 2, 1)  # B Cout Cin

        B, Cin, H, W = x.shape
        # Note: The current number of channels (Cin) can be smaller or equal to in_chans

        if self.training and self.enable_sample:
            # Per batch channel sampling
            # Note this may be slow
            # Randomly sample the number of channels for this batch
            Cin_new = random.randint(1, Cin)

            # Randomly sample the selected channels
            channels = random.sample(range(Cin), k=Cin_new)
            Cin = Cin_new
            x = x[:, channels, :, :]

            # Update the embedding lookup
            cur_channel_embed = cur_channel_embed[:, :, channels]
            ######

        # shared projection layer across channels
        x = self.proj(x.unsqueeze(1))  # B Cout Cin H W

        # channel specific offsets
        x += cur_channel_embed.unsqueeze(-1).unsqueeze(-1)
        # x += self.channel_embed[:, :, cur_channels, :, :]  # B Cout Cin H W

        # preparing the output sequence
        x = x.flatten(2)  # B Cout CinHW
        x = x.transpose(1, 2)  # B CinHW Cout

        return x, Cin
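
As far as I can tell, the KeyError comes from the extra_tokens["channels"] lookup in the forward above. My guess (which may well be wrong) is that every batch needs a tensor of channel indices alongside the images, and that the top-level model forwards extra_tokens down to this module. A minimal sketch of calling the embedder directly:

import torch

B, Cin = 8, 3  # batch size and number of input channels (RGB)

# One row of channel indices (0, 1, 2) per sample in the batch;
# nn.Embedding expects integer (long) indices, which arange provides
channels = torch.arange(Cin).unsqueeze(0).expand(B, -1)

patch_embed = PatchEmbedPerChannel()
images = torch.randn(B, Cin, 224, 224)
out, cin = patch_embed(images, extra_tokens={"channels": channels})
# out: (B, cin * 14 * 14, 768); cin may be < 3 here because the module
# randomly samples channels while in training mode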

And here is the get_train_loader function mentioned above:

from omegaconf import DictConfig, open_dict
from torch.utils.data import DataLoader

# `data` below refers to the repo's own dataset module
def get_train_loader(cfg: DictConfig):
    # Define the training data loader.
    if len(cfg.train_data) == 1:

        print("There is only one training data")
        train_data_cfg = next(iter(cfg.train_data.values()))
        with open_dict(cfg):
            cfg.train_data = train_data_cfg

        train_data = getattr(data, train_data_cfg.name)(
            is_train=True,
            transform_cfg=cfg.train_transformations,
            **train_data_cfg.args,
        )
        train_loader = DataLoader(
            train_data, **train_data_cfg.loader, collate_fn=train_data.collate_fn
        )

        # We also need to pre-compute the number of batches for each epoch.
        # We will use this information for the learning rate schedule.
        with open_dict(cfg):
            # get number of batches per epoch (many optimizers use this information to schedule
            # the learning rate)
            cfg.train_data.loader.num_batches = (
                len(train_loader) // cfg.trainer.devices + 1
            )

        return train_loader

    else:
        print("There're more than one training data")
        train_loaders = {}
        len_loader = None
        batch_size = 0

        for name, train_data_cfg in cfg.train_data.items():
            print(f"Loading {train_data_cfg.name}")
            train_data = getattr(data, train_data_cfg.name)(
                is_train=True,
                transform_cfg=cfg.train_transformations,
                **train_data_cfg.args,
            )
            train_loader = DataLoader(
                train_data, **train_data_cfg.loader, collate_fn=train_data.collate_fn
            )
            train_loaders[name] = train_loader

            print(f"Dataset {name} has length {len(train_loader)}")

            if len_loader is None:
                len_loader = len(train_loader)
            else:
                len_loader = max(len_loader, len(train_loader))

            # batch_size += train_data_cfg.loader.batch_size
            batch_size = train_data_cfg.loader.batch_size

        with open_dict(cfg):
            cfg.train_data.loader = {}
            cfg.train_data.loader.num_batches = len_loader // cfg.trainer.devices + 1
            cfg.train_data.loader.batch_size = batch_size

        return train_loaders
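
From reading this function, the cfg it expects seems to look roughly like the sketch below. I reconstructed the keys from the attribute accesses in the code (train_data.*.name / args / loader, train_transformations, trainer.devices), so every name and value here is a guess on my part:

from omegaconf import OmegaConf

cfg = OmegaConf.create({
    "train_data": {
        "my_dataset": {                    # one entry per training dataset
            "name": "SomeDatasetClass",    # must match a class in the repo's data module
            "args": {},                    # forwarded as **kwargs to that class
            "loader": {"batch_size": 32, "shuffle": True},  # forwarded to DataLoader
        }
    },
    "train_transformations": {},           # passed to the dataset as transform_cfg
    "trainer": {"devices": 1},
})

train_loader = get_train_loader(cfg)  # hypothetical call

Is something like this the right way to point the code at my tensors, or do I need to write my own dataset class in the repo's data module?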