How to select certain indices for multiple dimensions of a PyTorch tensor?

I need to add one PyTorch tensor to a subset of another tensor, where the subset is selected by index lists along two dimensions. For example:
import torch
x = torch.randn([10, 7, 128, 128]) # [batch, channel, height, width]
# In the actual program, batch_idx and channel_idx are generated dynamically
batch_idx = torch.tensor([1,3], dtype=torch.int64)
channel_idx = torch.tensor([2,3,5], dtype=torch.int64)
y = torch.randn([2, 3, 128, 128]) # [len(batch_idx), len(channel_idx), height, width]
x[batch_idx, channel_idx, :, :] += y
Running this code raises the following error:
IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [2], [3]
How can I perform the desired operation without looping through each index of each dimension?
Answer
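When you index with several tensors at once, PyTorch broadcasts the index tensors against each other and pairs their elements position by position. In your case batch_idx has shape [2] and channel_idx has shape [3], and those shapes cannot be broadcast together, hence the error. Even when the shapes do match, you get one element per position rather than every combination. To select every (batch, channel) combination, you need index tensors that broadcast to shape [2, 3].
One way to build them is torch.meshgrid combined with advanced indexing: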
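A small illustration of this pairing behaviour, using a toy 4 x 6 tensor whose values encode their own position (the tensor and index values here are made up for the demo, not taken from the question): equal-length index tensors select one element per position, while giving one of them a trailing singleton dimension makes them broadcast to a full grid.

```python
import torch

# Toy tensor: 4 "batches" x 6 "channels"; the value at [b, c] is 10*b + c.
t = torch.arange(4)[:, None] * 10 + torch.arange(6)

b = torch.tensor([1, 3])
c = torch.tensor([2, 5])

# Equal-length index tensors are paired position by position:
# this selects t[1, 2] and t[3, 5], not the full 2x2 block.
paired = t[b, c]
print(paired.tolist())  # [12, 35]

# b[:, None] has shape [2, 1]; broadcast against c (shape [2]) it gives
# a [2, 2] index grid, so every (batch, channel) combination is selected.
block = t[b[:, None], c]
print(block.tolist())  # [[12, 15], [32, 35]]
```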
import torch
x = torch.randn([10, 7, 128, 128]) # [batch, channel, height, width]
batch_idx = torch.tensor([1, 3], dtype=torch.int64)
channel_idx = torch.tensor([2, 3, 5], dtype=torch.int64)
y = torch.randn([2, 3, 128, 128]) # [len(batch_idx), len(channel_idx), height, width]
# Create a meshgrid of batch and channel indices.
b_idx, c_idx = torch.meshgrid(batch_idx, channel_idx, indexing='ij') # shapes: [2, 3]
# Use the [2, 3] index grids to select the matching block of x and add y in place.
x[b_idx, c_idx, :, :] += y
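If you would rather skip torch.meshgrid, reshaping batch_idx to [2, 1] lets it broadcast against channel_idx to the same [2, 3] grid. The sketch below uses the question's shapes, adds a fixed seed and an explicit double loop purely to check that the results agree. One caveat: `x[idx] += y` reads a copy, adds, and writes back, so if your index lists could ever contain repeated values the duplicate updates are not summed. `Tensor.index_put_` with `accumulate=True` does sum them, as the last part shows.

```python
import torch

torch.manual_seed(0)
x = torch.randn([10, 7, 128, 128])  # [batch, channel, height, width]
batch_idx = torch.tensor([1, 3], dtype=torch.int64)
channel_idx = torch.tensor([2, 3, 5], dtype=torch.int64)
y = torch.randn([2, 3, 128, 128])   # [len(batch_idx), len(channel_idx), height, width]

# Reference result from an explicit loop, used only for comparison.
expected = x.clone()
for i, b in enumerate(batch_idx.tolist()):
    for j, c in enumerate(channel_idx.tolist()):
        expected[b, c] += y[i, j]

# batch_idx[:, None] has shape [2, 1]; together with channel_idx ([3]) it
# broadcasts to a [2, 3] grid, like meshgrid(..., indexing='ij').
out = x.clone()
out[batch_idx[:, None], channel_idx] += y
print(torch.allclose(out, expected))  # True

# Same update via index_put_; accumulate=True sums contributions from
# repeated indices instead of leaving the result undefined.
out2 = x.clone()
out2.index_put_((batch_idx[:, None], channel_idx), y, accumulate=True)
print(torch.allclose(out2, expected))  # True

# Tiny duplicate-index demo: index 1 appears twice, so it receives 2.0.
dup = torch.zeros(4)
dup.index_put_((torch.tensor([1, 1, 2]),), torch.ones(3), accumulate=True)
print(dup.tolist())  # [0.0, 2.0, 1.0, 0.0]
```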