I created a toy model like the following, and I wanted to wrap it with DistributedDataParallel, but I failed to train it due to tensor.device errors (especially the ones coming from Distribution):
from torch import nn
from torch.distributions import Distribution, Uniform, Normal
from torch import Tensor, diag, zeros, randn


class Opt(nn.Module):
    def __init__(self):
        super().__init__()
        self.I = diag(Tensor([1., 1., 1.]))

    def forward(self, x: Tensor, p: Tensor):
        x = self._some_transforms_1(x, p)
        x = self._some_transforms_2(x, p)
        return x

    def _some_transforms_1(self, x: Tensor, p: Tensor):
        trans_matrix = zeros(3, 3)
        trans_matrix[0, 0:2] = p[0:2]
        trans_matrix += self.I
        return x @ trans_matrix

    def _some_transforms_2(self, x: Tensor, p: Tensor):
        trans_matrix = zeros(3, 3)
        trans_matrix[1, 0:2] = p[2:4]
        trans_matrix -= self.I
        return x @ trans_matrix


class Model(nn.Module):
    def __init__(self, opt: Opt, dist_1: Distribution, dist_2: Distribution):
        super().__init__()
        self.opt = opt
        self.dist_1 = dist_1
        self.dist_2 = dist_2

    def forward(self, x: Tensor, p: Tensor, w: float):
        x = self.opt(x, p)
        lp_1 = self.dist_1.log_prob(p)
        lp_2 = self.dist_2.log_prob(p)
        return x.sum() * lp_1 * lp_2


data = randn(100, 3, 3)
opt = Opt()
dist_1 = Normal(Tensor([0] * 4), Tensor([1] * 4))
dist_2 = Uniform(Tensor([-1] * 4), Tensor([1] * 4))
model = Model(opt, dist_1, dist_2)

for x in data:
    p = dist_2.sample((1,)).squeeze()
    loss = model(x, p, 1.)
To make the model trainable in parallel, I tried a few changes, such as registering the identity with self.I = nn.Parameter(self.I) and moving the transform matrices with trans_matrix = trans_matrix.to(x.device).
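Roughly, the changed Opt and the way I wrap the model look like this (a rough sketch of my attempt, assuming a torchrun launch that sets LOCAL_RANK, one GPU per process, and the usual NCCL process-group setup; it is not a working fix):

class Opt(nn.Module):
    def __init__(self):
        super().__init__()
        self.I = diag(Tensor([1., 1., 1.]))
        # register I so that Module.to() / DDP move it together with the model
        self.I = nn.Parameter(self.I)

    def forward(self, x: Tensor, p: Tensor):
        x = self._some_transforms_1(x, p)
        x = self._some_transforms_2(x, p)
        return x

    def _some_transforms_1(self, x: Tensor, p: Tensor):
        # build the transform matrix on the same device as the input
        trans_matrix = zeros(3, 3).to(x.device)
        trans_matrix[0, 0:2] = p[0:2]
        trans_matrix += self.I
        return x @ trans_matrix

    def _some_transforms_2(self, x: Tensor, p: Tensor):
        trans_matrix = zeros(3, 3).to(x.device)
        trans_matrix[1, 0:2] = p[2:4]
        trans_matrix -= self.I
        return x @ trans_matrix


import os
import torch
from torch.nn.parallel import DistributedDataParallel as DDP

torch.distributed.init_process_group(backend="nccl")
local_rank = int(os.environ["LOCAL_RANK"])
device = torch.device("cuda", local_rank)

model = Model(Opt(), dist_1, dist_2).to(device)
ddp_model = DDP(model, device_ids=[local_rank])

for x in data:
    p = dist_2.sample((1,)).squeeze().to(device)
    loss = ddp_model(x.to(device), p, 1.)  # this is where the device errors show up

But I still have a problem with the distributions: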
(self.lower_bound <= value) & (value <= self.upper_bound)
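For example (assuming a CUDA device is available), the Uniform's bounds stay on the CPU even after the wrapping module is moved to the GPU, so the comparison above mixes devices:

model = Model(Opt(), dist_1, dist_2).cuda()
print(model.dist_2.low.device)   # prints "cpu"
print(model.dist_2.high.device)  # prints "cpu"

p = dist_2.sample((1,)).squeeze().cuda()
model.dist_2.log_prob(p)         # raises a device-mismatch RuntimeError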
The lower_bound and upper_bound in distributions.constraints._Interval are not automatically moved to CUDA devices, since Distribution is not an nn.Module. Could anyone tell me how to make this model trainable in parallel?