Implementing a Gumbel sigmoid to restructure the data tensor


Suppose we have a tensor of logits with shape (B, W, 1), where each value represents a binary prediction that needs to be sampled. Based on the sampled values, I want to add extra dimensions to the data representation fed into the network (which is again a discrete operation). The restructured data tensor is then passed to the next component of the network, and so on. For example, if the samples are 0, 1, 0, the input to the next layer would be d1, d2, x.d3, where . denotes concatenation, d1, d2, d3 are the initial tensors, and x is the extra tensor introduced (to expand the representation) based on the sampling. Is there a simple way to apply the Gumbel trick (or something similar) in this use case? A solution in PyTorch would be great!
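A common way to apply the Gumbel trick to binary decisions is the Gumbel-sigmoid (also called the binary Concrete or relaxed Bernoulli) combined with a straight-through estimator. The sketch below assumes logits of shape (B, W, 1) as in the question; the function name gumbel_sigmoid and the temperature tau are illustrative choices, not a PyTorch API. PyTorch does provide torch.distributions.RelaxedBernoulli for the soft relaxation alone, and torch.nn.functional.gumbel_softmax for the categorical case.

```python
import torch


def gumbel_sigmoid(logits: torch.Tensor, tau: float = 1.0, hard: bool = True) -> torch.Tensor:
    """Hypothetical helper: relaxed Bernoulli sample with optional straight-through."""
    # Logistic(0, 1) noise (the difference of two independent Gumbel(0, 1)
    # variables), sampled as log(u) - log(1 - u) with u ~ Uniform(0, 1).
    u = torch.rand_like(logits).clamp(1e-6, 1.0 - 1e-6)
    noise = torch.log(u) - torch.log1p(-u)
    # Relaxed sample in (0, 1); a lower tau pushes it closer to 0 or 1.
    soft = torch.sigmoid((logits + noise) / tau)
    if not hard:
        return soft
    # Straight-through: the forward pass uses the hard 0/1 value,
    # the backward pass uses the gradient of the soft sample.
    hard_sample = (soft > 0.5).to(logits.dtype)
    return hard_sample - soft.detach() + soft


# Usage with the shapes from the question: logits of shape (B, W, 1).
B, W = 4, 5
logits = torch.randn(B, W, 1, requires_grad=True)
z = gumbel_sigmoid(logits, tau=0.5)  # (B, W, 1), forward values are (numerically) 0 or 1
z.sum().backward()                   # gradients reach the logits through the soft sample
```

Thresholding at 0.5 is equivalent to checking logits + noise > 0, which fires with probability sigmoid(logits), so the hard forward value is a genuine Bernoulli sample; only the gradient is the biased straight-through approximation.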

Answer by KonstantinosKokos:

I fail to see what the utility of this would be, but here goes:

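import torch

b, w, num_samples = 3, 5, 7
thresholds = torch.rand(b, w)  # per-position thresholds in [0, 1)
# one standard-normal noise vector of length num_samples, shared by every (b, w) position
noise = torch.randn(num_samples).view(1, 1, -1).expand(b, w, num_samples)  # (b, w, num_samples)
samples = (thresholds[:, :, None] > noise).long()  # (b, w, num_samples), values in {0, 1}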
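For the restructuring step in the question, one way to keep tensor shapes fixed, assuming the intent is to append x along the feature axis at the positions sampled as 1, is to always concatenate x and multiply it by the binary sample. This is a masking workaround rather than a literal variable-length concatenation; the helper name append_if_selected and the shapes of d and x are hypothetical. If z comes from a straight-through Gumbel-sigmoid, gradients reach the logits through the product z * x.

```python
import torch


def append_if_selected(d: torch.Tensor, x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: concatenate x to d along the feature axis, masked by z.

    d: (B, W, D) data tensor, x: (B, W, E) extra features,
    z: (B, W, 1) samples in {0, 1}.
    Returns a (B, W, D + E) tensor whose extra features are zero where z == 0.
    """
    # z broadcasts over the feature axis of x
    return torch.cat([d, z * x], dim=-1)


# Example with the sampling pattern 0, 1, 0 from the question (B=1, W=3).
d = torch.ones(1, 3, 2)
x = torch.full((1, 3, 2), 5.0)
z = torch.tensor([0.0, 1.0, 0.0]).view(1, 3, 1)
out = append_if_selected(d, x, z)
print(out.shape)  # torch.Size([1, 3, 4])
```

The next layer then always sees D + E features per position, which avoids ragged batches; if truly variable-length inputs are required, the masking would have to be replaced by padding plus an attention or sequence mask.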