How do I compare a 2-D tensor with a 1-D tensor in PyTorch?


Here is an example of the two tensors I want to compare:

torch.tensor([[1,2],[1,2],[1,3]]) == torch.tensor([1,2])

I want one result per row, True only when the whole row equals [1, 2]:

[True, True, False]

But broadcasting instead gives me an element-wise result:

tensor([[ True,  True],
        [ True,  True],
        [ True, False]])

Answer by Karl:

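The == operator is always element-wise, even when broadcasting: the 1-D tensor is compared against each row of the 2-D tensor one element at a time. To get one result per row, reduce the boolean result with all() along the last dimension (dim=-1), so a row counts as a match only if every element in it matched.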
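To see why == produces a 3x2 result rather than one value per row, here is a small illustrative sketch (not part of the original answer) that makes the broadcast explicit with expand_as(), which returns a view of b repeated along a new leading dimension without copying data.

```python
import torch

a = torch.tensor([[1, 2], [1, 2], [1, 3]])  # shape (3, 2)
b = torch.tensor([1, 2])                    # shape (2,)

# Broadcasting aligns trailing dimensions, so b behaves as if it were
# repeated once per row of a. expand_as makes that repetition explicit.
b_expanded = b.expand_as(a)
print(b_expanded)
# tensor([[1, 2],
#         [1, 2],
#         [1, 2]])

# Comparing against the explicit expansion gives the same element-wise result.
print(torch.equal(a == b, a == b_expanded))  # True
```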

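import torch

a = torch.tensor([[1, 2], [1, 2], [1, 3]])
b = torch.tensor([1, 2])

output1 = a == b  # element-wise; b is broadcast across each row of a
# tensor([[ True,  True],
#         [ True,  True],
#         [ True, False]])

output2 = output1.all(-1)  # True only where every element in the row matched
# tensor([ True,  True, False])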
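If you need this in more than one place, it can be wrapped in a small function. This is a minimal sketch assuming exact (integer or boolean) comparison; rows_equal is a hypothetical name, not a PyTorch API. It also shows using the resulting mask to select the matching rows with boolean indexing.

```python
import torch

def rows_equal(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: True for each row of `a` that equals `b` exactly.

    Assumes `b` has the same size as the last dimension of `a`.
    """
    if b.shape != a.shape[-1:]:
        raise ValueError(
            f"expected b with shape {tuple(a.shape[-1:])}, got {tuple(b.shape)}"
        )
    # Element-wise comparison (b is broadcast over the rows of a), then
    # reduce over the last dimension: a row matches only if all entries match.
    return (a == b).all(dim=-1)

a = torch.tensor([[1, 2], [1, 2], [1, 3]])
b = torch.tensor([1, 2])

mask = rows_equal(a, b)
print(mask)     # tensor([ True,  True, False])
print(a[mask])  # the rows of a that match b
# tensor([[1, 2],
#         [1, 2]])
```

The shape check is a design choice: without it, a b of shape (1,) would broadcast silently and compare every column against the same value.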

Also, if it matters for your use case, reductions such as all() accept keepdim=True, which keeps the reduced dimension with size 1 so the result has the same number of dimensions as the input.

output2 = output1.all(-1)  # shape (3,)
# tensor([ True,  True, False])

output2 = output1.all(-1, keepdim=True)  # shape (3, 1)
# tensor([[ True],
#         [ True],
#         [False]])
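One case where keepdim=True can matter: a (3, 1) mask broadcasts against the original (3, 2) tensor, while a (3,) mask does not. The illustrative sketch below uses torch.where and torch.zeros_like to zero out the rows that do not match; it also shows that unsqueeze(-1) on the flat mask gives the same (3, 1) shape.

```python
import torch

a = torch.tensor([[1, 2], [1, 2], [1, 3]])
b = torch.tensor([1, 2])

mask_flat = (a == b).all(-1)                # shape (3,)
mask_kept = (a == b).all(-1, keepdim=True)  # shape (3, 1)

# unsqueeze(-1) on the flat mask produces the same (3, 1) result.
print(torch.equal(mask_flat.unsqueeze(-1), mask_kept))  # True

# The (3, 1) mask broadcasts against the (3, 2) tensor: keep matching
# rows and replace the others with zeros.
zeroed = torch.where(mask_kept, a, torch.zeros_like(a))
print(zeroed)
# tensor([[1, 2],
#         [1, 2],
#         [0, 0]])

# The (3,) mask does not broadcast against (3, 2) (trailing sizes 3 vs 2),
# so the same call raises an error.
try:
    torch.where(mask_flat, a, torch.zeros_like(a))
except RuntimeError:
    print("a (3,) mask cannot broadcast against a (3, 2) tensor")
```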