148

What does torch.gather() do? This answer is hard to understand.

2 Comments
  • I have never used DQN. Can you try to specify what obs_batch and act_batch are? Commented Jun 24, 2018 at 9:30
  • @McLawrence obs_batch is the batch of observations and act_batch is the batch of actions. From what I understand, it basically means that when I pass a batch of observations to the q function it returns a set of q values corresponding to each observation. Commented Jun 26, 2018 at 6:57

6 Answers

358

torch.gather creates a new tensor from the input tensor by taking values from each 'row' along the dimension dim. The values in the torch.LongTensor passed as index specify which value to take from each 'row'. The shape of the output tensor is the same as the shape of the index tensor. The following illustration from the official docs explains it more clearly:

(Pictorial representation from the docs.)

(Note: In the illustration, indexing starts from 1 and not 0).

In the first example, the given dimension is along rows (top to bottom). For position (1,1) of the result, the value taken from index is 1, which tells us the row of src to read from: the value of src at (1,1) is 1, so the result at (1,1) is 1. Similarly, for (2,2) the value in index is 3; the value of src at (3,2) is 8, hence the result at (2,2) is 8, and so on.

Similarly, in the second example indexing is along columns: at position (2,2) of the result, the value taken from index is 3, which tells us the column of src to read from, so the value at (2,3) of src, which is 6, is written to the result at (2,2).
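To make this concrete, here is a small runnable version of the same two examples, using 0-based indices as real PyTorch code does (essentially the snippet from the first comment below):

import torch

src = torch.arange(9).reshape(3, 3)
# tensor([[0, 1, 2],
#         [3, 4, 5],
#         [6, 7, 8]])
index = torch.tensor([[0, 1, 2], [1, 2, 0]])

# dim=0: the index value picks the row, the position picks the column
torch.gather(src, 0, index)
# tensor([[0, 4, 8],
#         [3, 7, 2]])

# dim=1: the index value picks the column, the position picks the row
torch.gather(src, 1, index.T)
# tensor([[0, 1],
#         [4, 5],
#         [8, 6]])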


6 Comments

Try index = torch.as_tensor([[0,1,2],[1,2,0]]) and src = torch.arange(9).reshape(3,3), then compare torch.gather(src,0,index) and torch.gather(src,1,index.T)
Excellent description. I didn't know how I would grasp the concept without this fantastic illustration! Thanks so much, and keep it up!
The figure only shows the case where index matches the size of the corresponding src dimension. For dimensions other than dim, it can be smaller.
Just to clarify, the answer's dim indexing also starts from 1, not from 0. Dim 2 in this answer would correspond to dim=1 in the documentation: pytorch.org/docs/stable/generated/torch.gather.html Otherwise, thank you so much for explaining this visually so nicely! I finally understood this method thanks to you.
@RuslanMukhamadiarov great point. If you are confused, note that Ruslan is correct here.
100

The torch.gather function (or torch.Tensor.gather) is a multi-index selection method. Look at the following example from the official docs:

t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  1],
#        [ 4,  3]])

Let's start with going through the semantics of the different arguments: The first argument, input, is the source tensor that we want to select elements from. The second, dim, is the dimension (or axis in tensorflow/numpy) that we want to collect along. And finally, index are the indices to index input. As for the semantics of the operation, this is how the official docs explain it:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2
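These formulas hold in any number of dimensions. As a quick sanity check of the dim == 1 line on a 3-D tensor (this small example is not from the docs):

import torch

x = torch.arange(8).reshape(2, 2, 2)
idx = torch.zeros(2, 2, 2, dtype=torch.long)

# out[i][j][k] = x[i][idx[i][j][k]][k], so with all-zero indices
# every output position reads from index 0 of dimension 1:
torch.gather(x, 1, idx)
# tensor([[[0, 1],
#          [0, 1]],
#         [[4, 5],
#          [4, 5]]])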

So let's go through the example.

The input tensor is [[1, 2], [3, 4]], and the dim argument is 1, i.e. we want to collect from the second dimension. The indices for the second dimension are given as [0, 0] and [1, 0].

As we "skip" the first dimension (the dimension we want to collect along is 1), the first dimension of the result is implicitly given by the first dimension of index. That means the indices hold the second dimension, i.e. the column indices, but not the row indices; those are given by the positions within the index tensor itself. For the example, this means that the first row of the output is a selection from the first row of the input, as specified by the first row of the index tensor. Since those column indices are [0, 0], we select the first element of the input's first row twice, resulting in [1, 1]. Similarly, the second row of the result comes from indexing the second row of the input tensor by the elements of the second row of the index tensor, resulting in [4, 3].

To illustrate this even further, let's swap the dimension in the example:

t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 0, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  2],
#        [ 3,  2]])

As you can see, the indices are now collected along the first dimension.

For the example you referred to,

current_Q_values = Q(obs_batch).gather(1, act_batch.unsqueeze(1))

gather will index the rows of the q-values (i.e. the per-sample q-values in a batch of q-values) by the batch-list of actions. The result will be the same as if you had done the following (though it will be much faster than a loop):

q_vals = []
for qv, ac in zip(Q(obs_batch), act_batch):
    q_vals.append(qv[ac])
# qv[ac] is 0-dimensional, so use stack (torch.cat fails on 0-dim tensors):
q_vals = torch.stack(q_vals)
# Note: current_Q_values additionally keeps the singleton dimension
# introduced by unsqueeze(1), i.e. it equals q_vals.unsqueeze(1).
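If Q, obs_batch and act_batch are unfamiliar, here is a self-contained toy version of the same pattern with made-up numbers:

import torch

q_values = torch.tensor([[0.1, 0.9, 0.0],   # per-action q-values for sample 0
                         [0.5, 0.2, 0.3]])  # per-action q-values for sample 1
actions = torch.tensor([1, 0])              # action taken for each sample

q_values.gather(1, actions.unsqueeze(1))
# tensor([[0.9000],
#         [0.5000]])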

Comments

71

@Ritesh and @cleros gave great answers (with lots of upvotes), but after reading them I was still a bit confused, and I know why. This post will perhaps help folks like me.

For these sorts of exercises with rows and columns I think it really helps to use a non-square object, so let's start with a larger 4x3 source (torch.Size([4, 3])) using source = torch.tensor([[1,2,3], [4,5,6], [7,8,9], [10,11,12]]). This will give us

# This is the source tensor
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])

Now let's start indexing along the columns (dim=1) and create index = torch.tensor([[0,0],[1,1],[2,2],[0,1]]), which is a list of lists. Here's the key: since our dimension is columns, and the source has 4 rows, the index must contain 4 lists! We need a list for each row. Running source.gather(dim=1, index=index) will give us

tensor([[ 1,  1],
        [ 5,  5],
        [ 9,  9],
        [10, 11]])

So, each list within index gives us the columns from which to pull the values. The 1st list of the index ([0,0]) tells us to look at the 1st row of the source and take the 1st column of that row (it's zero-indexed) twice, which is [1,1]. The 2nd list of the index ([1,1]) tells us to look at the 2nd row of source and take the 2nd column of that row twice, which is [5,5]. Jumping to the 4th list of the index ([0,1]), which looks at the 4th and final row of the source, we take the 1st column (10) and then the 2nd column (11), which gives us [10,11].

Here's a nifty thing: each list of your index has to be the same length, but they may be as long as you like! For example, with index = torch.tensor([[0,1,2,1,0],[2,1,0,1,2],[1,2,0,2,1],[1,0,2,0,1]]), source.gather(dim=1, index=index) will give us

tensor([[ 1,  2,  3,  2,  1],
        [ 6,  5,  4,  5,  6],
        [ 8,  9,  7,  9,  8],
        [11, 10, 12, 10, 11]])

The output will always have the same number of rows as the source, but the number of columns will equal the length of each list in index. For example, the 2nd list of the index ([2,1,0,1,2]) is going to the 2nd row of the source and pulling, respectively, the 3rd, 2nd, 1st, 2nd and 3rd items, which is [6,5,4,5,6]. Note, the value of every element in index has to be less than the number of columns of source (in this case 3), otherwise you get an out of bounds error.
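For instance (the exact error text may vary between PyTorch versions):

bad_index = torch.tensor([[3, 0], [0, 0], [0, 0], [0, 0]])
source.gather(dim=1, index=bad_index)
# RuntimeError: index 3 is out of bounds for dimension 1 with size 3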

Switching to dim=0, we'll now be using the rows as opposed to the columns. Using the same source, we now need an index where the length of each list equals the number of columns in the source. Why? Because each element in the list represents the row from source as we move column by column.

Therefore, index = torch.tensor([[0,0,0],[0,1,2],[1,2,3],[3,2,0]]) will then have source.gather(dim=0, index=index) give us

tensor([[ 1,  2,  3],
        [ 1,  5,  9],
        [ 4,  8, 12],
        [10,  8,  3]])

Looking at the 1st list in the index ([0,0,0]), we can see that we're moving across the 3 columns of source picking the 1st element (it's zero-indexed) of each column, which is [1,2,3]. The 2nd list in the index ([0,1,2]) tells us to move across the columns taking the 1st, 2nd and 3rd items, respectively, which is [1,5,9]. And so on.

With dim=1 our index had to have a number of lists equal to the number of rows in the source, but each list could be as long, or short, as you like. With dim=0, each list in our index has to be the same length as the number of columns in the source, but we can now have as many lists as we like. Each value in index, however, needs to be less than the number of rows in source (in this case 4).

For example, index = torch.tensor([[0,0,0],[1,1,1],[2,2,2],[3,3,3],[0,1,2],[1,2,3],[3,2,0]]) would have source.gather(dim=0, index=index) give us

tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12],
        [ 1,  5,  9],
        [ 4,  8, 12],
        [10,  8,  3]])

With dim=1 the output always has the same number of rows as the source, although the number of columns will equal the length of the lists in index. The number of lists in index has to equal the number of rows in source. Each value in index, however, needs to be less than the number of columns in source.

With dim=0 the output always has the same number of columns as the source, but the number of rows will equal the number of lists in index. The length of each list in index has to equal the number of columns in source. Each value in index, however, needs to be less than the number of rows in source.

That's it for two dimensions. Moving beyond that will follow the same patterns.
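As a runnable recap (including a demo of the relaxed constraint mentioned in the comments below, where index only needs index.size(i) <= source.size(i) for i != dim):

import torch

source = torch.tensor([[ 1,  2,  3],
                       [ 4,  5,  6],
                       [ 7,  8,  9],
                       [10, 11, 12]])

# dim=1: each row of index holds column indices for the matching row of source
source.gather(dim=1, index=torch.tensor([[0, 0], [1, 1], [2, 2], [0, 1]]))
# tensor([[ 1,  1],
#         [ 5,  5],
#         [ 9,  9],
#         [10, 11]])

# dim=0: each row of index holds row indices, one per column of source
source.gather(dim=0, index=torch.tensor([[0, 0, 0], [3, 2, 0]]))
# tensor([[ 1,  2,  3],
#         [10,  8,  3]])

# In newer PyTorch, index may be smaller than source in every dim except dim:
source.gather(dim=1, index=torch.tensor([[2], [0]]))
# tensor([[3],
#         [4]])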

5 Comments

Fantastic answer. Your description of what we need for each dimension helped me visualize the operation much easier.
Bravo. The key for me was indeed what you pointed out in bold. Also using a non-square matrix was super helpful. Much appreciated!
Just a tiny add-on: for anyone reading this now, there are no constraints of this kind in newer versions of PyTorch: "Here's the key: since our dimension is columns, and the source has 4 rows, the index must contain 4 lists!"
Now, the constraint is index.size(i) <= source.size(i) for all i != dim.
Best answer IMO
30

This is based on @Ritesh's answer (thanks @Ritesh!) with some real code.

The torch.gather API is

torch.gather(input, dim, index, *, sparse_grad=False, out=None) → Tensor

Example 1

When dim = 0,

(Illustration: gathering along dim = 0.)

dim = 0
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1, 2], [1, 2, 0]])

output = torch.gather(input, dim, index)
# tensor([[10, 14, 18],
#         [13, 17, 12]])

Example 2

When dim = 1,

(Illustration: gathering along dim = 1.)

dim = 1
input = torch.tensor([[10, 11, 12], [13, 14, 15], [16, 17, 18]])
index = torch.tensor([[0, 1], [1, 2], [2, 0]])

output = torch.gather(input, dim, index)
# tensor([[10, 11],
#         [14, 15],
#         [18, 16]])

Comments

1

I found torch.gather difficult to understand until I realized how similar it is to torch.index_select, which is much easier to understand.

Let's use the source data and indices from @Mark Cramer's answer.

# Source:
source = torch.tensor([
    [ 1, 2, 3],
    [ 4, 5, 6],
    [ 7, 8, 9],
    [10,11,12],
])

# Index:
index = torch.tensor([
    [0,1,2,1,0],
    [2,1,0,1,2],
    [1,2,0,2,1],
    [1,0,2,0,1],
])

What values will torch.gather(source, 1, index) pull out of source? It turns out the answer is

# Answer:
tensor([[ 1,  2,  3,  2,  1],
        [ 6,  5,  4,  5,  6],
        [ 8,  9,  7,  9,  8],
        [11, 10, 12, 10, 11]])

It looks completely incomprehensible to me. However, let's use torch.index_select as an intermediate step to build some intuition.

torch.index_select(source, dim, index_1d) selects elements along a specific dimension dim based on the 1-D index index_1d you supply. For example, let's say we have a set of indices

index_1d = torch.tensor([0, 1, 2, 1, 0]) # Essentially index[0]

Then torch.index_select(source, 1, index_1d)[0] gives

tensor([1, 2, 3, 2, 1]) # Exactly the same as the first row of the `torch.gather` results.

If I remove the indexing [0] at the end, torch.index_select(source, 1, index_1d) gives

tensor([[ 1,  2,  3,  2,  1],
        [ 4,  5,  6,  5,  4],
        [ 7,  8,  9,  8,  7],
        [10, 11, 12, 11, 10]])

Now, try to guess what's happening here.

Yes, the selection index [0, 1, 2, 1, 0] was applied to every row. Since the number of rows is the same in both the input and the output, let's say this row dimension is preserved.

Let's check out what happens when we loop over every row in index and only pull out the corresponding row from torch.index_select:

for i in range(index.size(0)):  # .size(0) returns the size of dim 0
    print(torch.index_select(source, 1, index[i])[i])

which produces

tensor([1, 2, 3, 2, 1])
tensor([6, 5, 4, 5, 6])
tensor([8, 9, 7, 9, 8])
tensor([11, 10, 12, 10, 11])

This is exactly the same as the torch.gather result, just formatted differently. By the way, to produce exactly the same tensor, we can do

# With stack
torch.stack([torch.index_select(source, 1, index[i])[i] for i in range(index.size(0))])

# With cat
torch.cat([torch.index_select(source, 1, index[i])[i][None,] for i in range(index.size(0))], dim = 0)  # Maybe torch.stack is more intuitive

In summary, torch.gather can be seen as an extension of torch.index_select that enables independent indexing for each element along a preserved dimension, while torch.index_select applies the same set of indices uniformly across the entire preserved dimension.
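That relationship is easy to check directly (reusing the source and index defined above):

same = torch.equal(
    torch.gather(source, 1, index),
    torch.stack([torch.index_select(source, 1, index[i])[i]
                 for i in range(index.size(0))]))
print(same)  # True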

Comments

-2

gather lets you take ordinary tensor indexing

>>> torch.arange(6)[torch.tensor([1,5])]
tensor([1, 5])

and do it in batches

>>> a = torch.stack((torch.arange(6),torch.arange(6)), dim=0)
>>> torch.gather(a, dim=1, index=torch.tensor([[5,1],[5,1]]))
tensor([[5, 1],
        [5, 1]])
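Written out per row, the gather call above is equivalent to the following (a sketch of the batching it performs):

import torch

a = torch.stack((torch.arange(6), torch.arange(6)), dim=0)
idx = torch.tensor([[5, 1], [5, 1]])

# index each row of `a` by the matching row of `idx`, then re-stack
rows = torch.stack([a[i][idx[i]] for i in range(a.size(0))])
assert torch.equal(rows, torch.gather(a, dim=1, index=idx))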

Comments
