Getting an Out of Memory Error While Multiplying Two 4D Tensors with Shape (1, 4, 2097152, 32)


I am working on a vision-transformer-based model architecture where Q and K each have shape (1, 4, 2097152, 32) (we are working with 3D images). When I try to compute the product of Q and K, I get an out-of-memory error.

For the multiplication I tried both torch.matmul and torch.bmm. Neither of them avoided the error.
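To clarify what I mean by trying both: as far as I understand, torch.bmm only accepts 3-D inputs, so with 4-D Q and K I either call torch.matmul, which broadcasts over the batch and head dimensions, or flatten those dimensions first for bmm. Either way the score matrix has the same number of elements. A toy sketch with small sizes:

import torch

B, H, N, d = 1, 4, 1024, 32                      # toy sizes; the real N is 2_097_152
q = torch.randn(B, H, N, d)
k = torch.randn(B, H, N, d)

# matmul broadcasts over (B, H): result has shape (B, H, N, N)
scores_matmul = torch.matmul(q, k.transpose(-1, -2))

# bmm needs 3-D inputs, so B and H are folded into one batch dimension first
scores_bmm = torch.bmm(q.reshape(B * H, N, d),
                       k.reshape(B * H, N, d).transpose(-1, -2)).view(B, H, N, N)

assert torch.allclose(scores_matmul, scores_bmm)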

OutOfMemoryError: CUDA out of memory. Tried to allocate 65536.00 GiB. GPU 0 has a total capacity of 15.77 GiB of which 13.47 GiB is free. Process 56962 has 2.30 GiB memory in use. Of the allocated memory 2.00 GiB is allocated by PyTorch, and 0 bytes is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)
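Doing the arithmetic, the 65536 GiB corresponds exactly to materializing the float32 attention score matrix, which with 4 heads and 2,097,152 tokens has shape (1, 4, 2097152, 2097152):

heads, n = 4, 2_097_152          # n = D*H*W tokens per image
elements = heads * n * n         # entries in the Q @ K^T score matrix
bytes_needed = elements * 4      # float32: 4 bytes per entry
print(bytes_needed / 2**30)      # 65536.0 GiB, matching the error message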

This is the code I'm working on. The module is built to perform cross attention.

import torch
import torch.nn as nn
from einops import rearrange


class CPA3D(nn.Module):
    def __init__(self, dim, dim_, heads=2, dim_head=64, dropout=0., p_one=True):
        super().__init__()

        self.dim = dim
        self.heads = heads
        self.scale = (dim_head ** -0.5)

        self.attend = nn.Softmax(dim=-1)
        self.to_q = nn.Linear(dim_, heads * dim_head, bias=False)
        self.to_k = nn.Linear(dim, heads * dim_head, bias=False)
        self.to_v = nn.Linear(dim, heads * dim_head, bias=False)
        self.p_one = p_one
        self.to_out = nn.Sequential(
            nn.Linear(heads * dim_head, dim),
            nn.Dropout(dropout)
        )


    def forward(self, x1, x2):
        B, D, H, W, C = x1.shape

        x1_flat = x1.view(B, -1, C)  # B, D*H*W, C
        x2_flat = x2.view(B, -1, C)  # B, D*H*W, C

        k = self.to_k(x1_flat)

        v = self.to_v(x1_flat)

        q = self.to_q(x2_flat)

        # split the head dimension: (B, N, heads*dim_head) -> (B, heads, N, dim_head)
        q = rearrange(q, 'b n (h d) -> b h n d', h=self.heads)
        k = rearrange(k, 'b n (h d) -> b h n d', h=self.heads)
        v = rearrange(v, 'b n (h d) -> b h n d', h=self.heads)

        # attention scores have shape (B, heads, N, N) with N = D*H*W;
        # this is the allocation that runs out of memory
        dots = torch.matmul(q, k.transpose(-1, -2)) * self.scale
        attn = self.attend(dots)
        out = torch.matmul(attn, v)
        out = rearrange(out, 'b h n d -> b n (h d)')

        if self.p_one:
            out = self.to_out(out)

        out = out.view(B, D, H, W, -1)

        return out
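To show the expected shapes, here is a toy run of the module above (assuming dim == dim_ == C and a small 16x16x16 volume instead of the real one, where D*H*W = 2,097,152):

# toy usage of the CPA3D module defined above
model = CPA3D(dim=64, dim_=64, heads=4, dim_head=32)

x1 = torch.randn(1, 16, 16, 16, 64)   # (B, D, H, W, C) -> 4096 tokens
x2 = torch.randn(1, 16, 16, 16, 64)

out = model(x1, x2)                    # dots inside forward is (1, 4, 4096, 4096)
print(out.shape)                       # torch.Size([1, 16, 16, 16, 64])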

For information, I am working on Colab Pro with an A100 GPU.

Can anyone suggest a solution to resolve this situation?
