How to select certain indices for multiple dimensions of a PyTorch tensor?

I have a situation where I need to add one PyTorch tensor to parts of another tensor. Here is an example:

import torch

x = torch.randn([10, 7, 128, 128])  # [batch, channel, height, width]

# In the actual program, batch_idx and channel_idx are generated dynamically
batch_idx = torch.tensor([1,3], dtype=torch.int64)
channel_idx = torch.tensor([2,3,5], dtype=torch.int64)

y = torch.randn([2, 3, 128, 128])  # [len(batch_idx), len(channel_idx), height, width]

x[batch_idx, channel_idx, :, :] += y

Running this code raises the following error:

IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [2], [3]

How can I perform the desired operation without looping through each index of each dimension?

Answer

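When several index tensors are used together, PyTorch (like NumPy) broadcasts them against each other and pairs up their elements. In your case batch_idx has shape [2] and channel_idx has shape [3], and these shapes cannot be broadcast together. Even if the shapes matched, the indices would be paired element-wise rather than combined into every (batch, channel) pair, which is what you want here.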
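To see this rule in isolation, here is a minimal sketch on a small tensor. The tensor `t` and the index values are illustrative assumptions, not taken from the question. Same-shaped index tensors select element-wise pairs, shapes [2] and [3] fail with the same IndexError, and reshaping one index to a column makes the two broadcast to the full grid of combinations.

```python
import torch

# A small 4x5 tensor whose values encode their position: t[i, j] == 10 * i + j
t = torch.arange(4)[:, None] * 10 + torch.arange(5)[None, :]

rows = torch.tensor([1, 3])
cols = torch.tensor([2, 4])

# Same-shaped index tensors are paired element-wise: selects t[1, 2] and t[3, 4]
paired = t[rows, cols]
print(paired.tolist())  # [12, 34]

# Shapes [2] and [3] cannot be broadcast together, so this raises IndexError
raised = False
try:
    t[rows, torch.tensor([0, 2, 4])]
except IndexError as err:
    raised = True
    print("IndexError:", err)

# rows[:, None] has shape [2, 1]; it broadcasts against a [3] index to [2, 3],
# which selects every (row, col) combination
grid = t[rows[:, None], torch.tensor([0, 2, 4])]
print(grid.tolist())  # [[10, 12, 14], [30, 32, 34]]
```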

To select every combination, you can use torch.meshgrid to build [2, 3] index grids and then use them for advanced indexing:

import torch

x = torch.randn([10, 7, 128, 128])  # [batch, channel, height, width]
batch_idx = torch.tensor([1, 3], dtype=torch.int64)
channel_idx = torch.tensor([2, 3, 5], dtype=torch.int64)
y = torch.randn([2, 3, 128, 128])  # [len(batch_idx), len(channel_idx), height, width]

# Create a meshgrid of batch and channel indices; both grids have shape [2, 3]
b_idx, c_idx = torch.meshgrid(batch_idx, channel_idx, indexing='ij')

# Index x with the grids (the selection has shape [2, 3, 128, 128]) and add y
x[b_idx, c_idx, :, :] += y
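As a sanity check, the sketch below compares the meshgrid version against an explicit double loop on cloned copies of the same data. It also shows two alternatives under the same shapes: a broadcasting-only variant (`batch_idx[:, None]` with shape [2, 1] broadcasts against `channel_idx` with shape [3], which is what meshgrid builds explicitly) and `index_put_` with `accumulate=True`. The seed and the variable names `expected`, `x_mesh`, `x_bcast` and `x_put` are illustrative assumptions.

```python
import torch

torch.manual_seed(0)
x = torch.randn([10, 7, 128, 128])  # [batch, channel, height, width]
batch_idx = torch.tensor([1, 3], dtype=torch.int64)
channel_idx = torch.tensor([2, 3, 5], dtype=torch.int64)
y = torch.randn([2, 3, 128, 128])  # [len(batch_idx), len(channel_idx), height, width]

# Reference result computed with an explicit double loop
expected = x.clone()
for i, b in enumerate(batch_idx.tolist()):
    for j, c in enumerate(channel_idx.tolist()):
        expected[b, c] += y[i, j]

# meshgrid version from the answer
x_mesh = x.clone()
b_idx, c_idx = torch.meshgrid(batch_idx, channel_idx, indexing='ij')
x_mesh[b_idx, c_idx, :, :] += y

# Broadcasting-only version: [2, 1] and [3] index tensors broadcast to [2, 3]
x_bcast = x.clone()
x_bcast[batch_idx[:, None], channel_idx, :, :] += y

# index_put_ with accumulate=True sums contributions into the selected positions
x_put = x.clone()
x_put.index_put_((b_idx, c_idx), y, accumulate=True)

print(torch.allclose(x_mesh, expected),
      torch.allclose(x_bcast, expected),
      torch.allclose(x_put, expected))
```

The `+=` form gathers the selected slices, adds y, and writes them back, so it assumes batch_idx and channel_idx contain no repeated values; if a (batch, channel) pair appeared twice, only one of its contributions would survive. Since the indices are generated dynamically in the question, `index_put_(..., accumulate=True)` may be the safer choice when repeats are possible, because it sums every contribution.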
