How to do a simple large matrix multiplication on multiple GPUs in PyTorch? I have written some simple code, but it does not work well


I want to split a large matrix multiplication, like torch.mm(a, b), across multiple GPUs to reduce the memory usage on any single GPU.

Here is the code working on a single GPU:

import torch

a = torch.randn(30000, 30000).cuda(1)
b = torch.randn(30000, 30000).cuda(1)
c = torch.mm(a, b)

# during this process, the maximum memory usage is 10491 MB.
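
(I report peak memory per GPU; a programmatic way to check it would be something like the snippet below with torch.cuda.max_memory_allocated, although its figures can differ a bit from what nvidia-smi shows because of PyTorch's caching allocator.)

torch.cuda.reset_peak_memory_stats(1)
c = torch.mm(a, b)
print(torch.cuda.max_memory_allocated(1) / 1024**2, "MB peak allocated on GPU 1")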

Here is the code working on two GPUs:

import torch 

# assuming `a1` and `a2` are parts of a big matrix
a1 = torch.randn(15000, 30000).cuda(0)
a2 = torch.randn(15000, 30000).cuda(1)
b1 = torch.randn(30000, 30000).cuda(0)
b2 = b1.cuda(1)

c1 = torch.mm(a1,b1)
c2 = torch.mm(a2,b2).to(0)
# at this point, the results `c1` and `c2` are both on GPU 0
# the maximum memory usage on GPU 1 is 7059 MB
# the maximum memory usage on GPU 0 is 8777 MB, larger than on GPU 1 because both result halves end up there

c = torch.concat([c1, c2], dim=0)
# OOM: concat is not in-place, so it allocates another full 30000x30000 tensor on GPU 0

Therefore, if the concat operation could be done in-place, it seems this would work as expected? Or should I move c1 and c2 to CPU memory first, cat them there, and then move the concatenated result back to the GPU?
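
For reference, this is roughly the CPU-side variant I have in mind (untested sketch, continuing from the c1 and c2 above):

c1_cpu = c1.cpu()
del c1                                        # free the first half on GPU 0
c2_cpu = c2.cpu()
del c2                                        # free the second half on GPU 0
c_cpu = torch.cat([c1_cpu, c2_cpu], dim=0)    # concatenated in host RAM
c = c_cpu.cuda(0)                             # one full-size copy back to GPU 0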

I have also tried tensor parallelism provided by PyTorch 2.2:

import torch  
import torch.distributed as distributed 
import os
from torch.distributed._tensor import init_device_mesh, Shard, distribute_tensor
from torch.distributed.tensor.parallel import parallelize_module, ColwiseParallel
# from visualize_sharding import visualize_sharding  # local debugging helper, only needed for the commented-out call below

mesh = init_device_mesh("cuda", (2,))
rank = distributed.get_rank()

big_tensor_1 = torch.randn(3, 2)
big_tensor_2 = torch.randn(2, 6)

print("big_tensor_1", big_tensor_1)

my_dtensor_1 = distribute_tensor(big_tensor_1, mesh, [Shard(dim=0)]) 
my_dtensor_2 = distribute_tensor(big_tensor_2, mesh, [Shard(dim=1)]) 

# visualize_sharding(my_dtensor_1, header="my_dtensor_1")

c = torch.mm(my_dtensor_1, my_dtensor_2)
print("c: ", c)

But everything ran twice because the launch command was python -m torch.distributed.launch --nproc_per_node=2 --nnodes=1 tmp.py, so two different big_tensor_1 tensors were generated randomly, one per process. How can I modify the code so that the data is effectively created once while still running with two processes?
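
One workaround I am considering is to give every rank the same seed, so both processes at least generate identical data before it is sharded (a sketch, I am not sure this is the intended pattern):

torch.manual_seed(0)                  # same seed on every rank
big_tensor_1 = torch.randn(3, 2)      # now identical across both processes
big_tensor_2 = torch.randn(2, 6)

my_dtensor_1 = distribute_tensor(big_tensor_1, mesh, [Shard(dim=0)])
my_dtensor_2 = distribute_tensor(big_tensor_2, mesh, [Shard(dim=1)])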

Everything I tried is listed in the problem details above.


There are 2 answers

Answer from zenga:

I tried the following approach, which addresses the first problem in the question to some extent but does not completely resolve it.

import torch

a1 = torch.randn(15000, 30000).cuda(0)
a2 = torch.randn(15000, 30000).cuda(1)
b1 = torch.randn(30000, 30000).cuda(0)
b2 = b1.cuda(1)

# create an empty tensor directly on GPU 0,
# then write the partial results straight into it;
# the peak memory on GPU 0 is still high, because the full 30000x30000
# result and the temporary produced by torch.mm coexist there
c = torch.empty(30000, 30000, device="cuda:0")
c[:15000] = torch.mm(a1, b1)
c[15000:] = torch.mm(a2, b2).to(0)

UPDATE: the following code does reduce the maximum memory usage on a single GPU when using multiple GPUs (two GPUs here):

import torch

# assuming a1 and a2 are parts of a big matrix
a1 = torch.randn(15000, 30000).cuda(0)
a2 = torch.randn(15000, 30000).cuda(1)
b1 = torch.randn(30000, 15000).cuda(0)
b2 = torch.randn(30000, 15000).cuda(1)

c = torch.empty(30000, 30000, device="cuda:0")
c[:15000, :15000] = torch.mm(a1, b1)            # computed on GPU 0
c[:15000, 15000:] = torch.mm(a1.cuda(1), b2)    # computed on GPU 1, copied into c on GPU 0
c[15000:, :15000] = torch.mm(a2, b1.cuda(1))    # computed on GPU 1, copied into c on GPU 0
c[15000:, 15000:] = torch.mm(a2, b2)            # computed on GPU 1, copied into c on GPU 0
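
A rough generalization of the same block idea to an arbitrary number of GPUs might look like this (a sketch, not tested at this scale; it assumes the matrix dimensions divide evenly by the number of GPUs and that the full result still fits on GPU 0):

import torch

def block_matmul_multi_gpu(a, b, n_gpus):
    # a: (M, K), b: (K, N), both on CPU; the result is assembled on GPU 0
    M, K = a.shape
    K2, N = b.shape
    assert K == K2
    assert M % n_gpus == 0 and N % n_gpus == 0   # assume blocks divide evenly
    rb, cb = M // n_gpus, N // n_gpus            # row/column block sizes
    c = torch.empty(M, N, device="cuda:0")
    for i in range(n_gpus):
        a_blk = a[i * rb:(i + 1) * rb]
        for j in range(n_gpus):
            dev = f"cuda:{j}"                    # block (i, j) is computed on GPU j
            b_blk = b[:, j * cb:(j + 1) * cb].to(dev)
            c[i * rb:(i + 1) * rb, j * cb:(j + 1) * cb] = torch.mm(a_blk.to(dev), b_blk)
    return c

Each GPU only ever holds one block of a, one block of b and one result block at a time, but GPU 0 still needs room for the full output, so the peak there stays around the size of c.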

Answer from Karl:

Is it necessary to use multiple GPUs or is that just a workaround for memory constraints?

On a single GPU, you can reduce memory overhead by pre-allocating the output array and breaking the matmul operation into chunks:

def chunked_matmul(a, b, n_rows, n_cols):
    assert a.shape[1] == b.shape[0]
    assert a.dtype == b.dtype
    
    rows = a.shape[0]
    cols = b.shape[1]
    c = torch.zeros(a.shape[0], b.shape[1], dtype=a.dtype, device=a.device)  # keep the output on the same device as the inputs
    
    for row in range(0, rows, n_rows):
        a_chunk = a[row:row+n_rows]
        for col in range(0, cols, n_cols):
            b_chunk = b[:, col:col+n_cols]
            
            result = torch.mm(a_chunk, b_chunk)
            c[row:row+n_rows, col:col+n_cols] += result
            
    return c

Note that PyTorch uses different matmul algorithms under the hood, so there may be small numerical differences depending on the chunk size. For example:

dim1 = 512
dim2 = 512
dim3 = 512

a = torch.randn(dim1, dim2)
b = torch.randn(dim2, dim3)

c1 = torch.mm(a,b)
c2 = chunked_matmul(a, b, 8, 8)
c3 = chunked_matmul(a, b, 128, 128)

(c1 - c2).abs().max()
> tensor(9.1553e-05)

(c1 - c3).abs().max()
> tensor(0.)