Here is an example. I want to compare these two tensors:
torch.tensor([[1,2],[1,2],[1,3]]) == torch.tensor([1,2])
I want this output:
[True, True, False]
But instead the broadcasting gets me:
tensor([[ True,  True],
        [ True,  True],
        [ True, False]])
The == operator is always element-wise, even when broadcasting. You can get your desired result by aggregating with all along the last axis.
Also, if it matters for your use case, you can use reductions like all with keepdim=True to retain the number of dims.
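A minimal sketch of that suggestion, using the tensors from the question. The variable names (a, b, elementwise, row_match, row_match_keepdim) are only illustrative and do not come from the original post:

```python
import torch

a = torch.tensor([[1, 2], [1, 2], [1, 3]])
b = torch.tensor([1, 2])

# b is broadcast against every row of a; the result is still element-wise, shape (3, 2)
elementwise = a == b

# Reduce along the last axis: a row matches only if all of its elements match
row_match = elementwise.all(dim=-1)
print(row_match)  # tensor([ True,  True, False])

# Same reduction, but keep the reduced axis as size 1, giving shape (3, 1)
row_match_keepdim = elementwise.all(dim=-1, keepdim=True)
print(row_match_keepdim.shape)  # torch.Size([3, 1])
```

The keepdim=True variant is mostly useful when the result has to broadcast back against a, for example as a mask over whole rows.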