Getting an error in the SVI step due to a multiclass distribution in sample, using Pyro and PyTorch


I'm working on a causal variational autoencoder that takes class segmentation masks, class labels, and causality (0 or 1) as inputs.

I'm getting an error from the SVI step whenever the batch size is greater than 1. I'm using a Bernoulli distribution because I want the model to learn the probability distribution over multiple classes in an image. I think the Categorical distribution would also fit the bill here, but I get the same error with it too.

After narrowing down the lines that cause the problem, I think it's these two in the model function:

one_vec2 = torch.ones([batch_size, self.lbl_shape[0]], **options)
class_labels = pyro.sample('class_labels', dist.Bernoulli(one_vec2*0.5), obs = lbls)      

The error:

ValueError                                Traceback (most recent call last)
<ipython-input-19-8cbc046dd2c1> in <module>()
      6 vae = Vae_Model1(lbl_sz, ch, img_sz).to(device)
      7 svi = SVI(vae.model, vae.guide, optimizer, loss = Trace_ELBO())
----> 8 train(svi, train_loader, USE_CUDA)

/usr/local/lib/python3.6/dist-packages/pyro/util.py in check_site_shape(site, max_plate_nesting)
    320                 '- enclose the batched tensor in a with plate(...): context',
    321                 '- .to_event(...) the distribution being sampled',
--> 322                 '- .permute() data dimensions']))
    323 
    324     # Check parallel dimensions on the left of max_plate_nesting.

ValueError: at site "class_labels", invalid log_prob shape
  Expected [-1], actual [32, 21]
  Try one of the following fixes:
  - enclose the batched tensor in a with plate(...): context
  - .to_event(...) the distribution being sampled
  - .permute() data dimensions

Currently, the batch size is 32 and lbl_shape[0] is 21 (the VOC dataset: background plus the other class labels).

Could someone help me with this? Any help would be much appreciated. Thank you.

