# What does the gather function do in pytorch in layman terms?


`torch.gather` creates a new tensor from the input tensor by taking the values from each row along the input dimension `dim`. The values in a `torch.LongTensor`, passed as `index`, specify which value to take from each row. The dimension of the output tensor is the same as the dimension of the index tensor. The following illustration from the official docs explains it more clearly:

(Note: In the illustration, indexing starts from 1 and not 0).

In the first example, the given dimension is along rows (top to bottom), so for the (1,1) position of `result`, it takes the row value from the `index` for `src`, which is `1`. At (1,1) the value in the source is `1`, so it outputs `1` at (1,1) in `result`.

Similarly, for (2,2) the row value from the index for `src` is `3`. At (3,2) the value in `src` is `8`, and hence it outputs `8`, and so on.

Similarly, for the second example, indexing is along columns, and hence at the (2,2) position of `result`, the column value from the index for `src` is `3`; so `6` is taken from (2,3) of `src` and output to `result` at (2,2).
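The illustration itself isn't reproduced here, but a sketch consistent with the walkthrough above (assuming the classic 3×3 `src = [[1,2,3],[4,5,6],[7,8,9]]` from the old docs figure, and remembering that code is 0-indexed while the illustration counts from 1) would be:

```python
import torch

# Assumed 3x3 source tensor from the old docs illustration.
src = torch.tensor([[1, 2, 3],
                    [4, 5, 6],
                    [7, 8, 9]])

# First example: gather along rows (dim=0). The entries below are
# 0-based, while the illustration counts from 1.
idx_rows = torch.tensor([[0, 1, 2],
                         [2, 2, 2],
                         [1, 0, 1]])
out_rows = torch.gather(src, 0, idx_rows)
# out_rows[0][0] == src[0][0] == 1  (the (1,1) case in the text)
# out_rows[1][1] == src[2][1] == 8  (the (2,2) case in the text)

# Second example: gather along columns (dim=1).
idx_cols = torch.tensor([[0, 1, 2],
                         [0, 2, 1],
                         [2, 0, 1]])
out_cols = torch.gather(src, 1, idx_cols)
# out_cols[1][1] == src[1][2] == 6  (the (2,2) case in the text)
```

The index tensors here are made up for illustration; only the entries called out in the comments are taken from the walkthrough.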

The `torch.gather` function (or `torch.Tensor.gather`) is a multi-index selection method. Look at the following example from the official docs:

```
t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  1],
#         [ 4,  3]])
```

Let's start by going through the semantics of the different arguments: the first argument, `input`, is the source tensor that we want to select elements from. The second, `dim`, is the dimension (or axis in tensorflow/numpy) that we want to collect along. And finally, `index` holds the indices with which to index `input`.

As for the semantics of the operation, this is how the official docs explain it:

```
out[i][j][k] = input[index[i][j][k]][j][k] # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k] # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]] # if dim == 2
```
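To make the `dim == 1` rule concrete, here is a minimal sketch (not how PyTorch implements it) that applies the formula with explicit loops to a 2-D tensor and checks it against `torch.gather`:

```python
import torch

def gather_dim1_loops(inp, index):
    # out[i][j] = inp[i][index[i][j]], exactly the dim == 1 rule above
    out = torch.empty_like(index)
    for i in range(index.size(0)):
        for j in range(index.size(1)):
            out[i][j] = inp[i][index[i][j]]
    return out

t = torch.tensor([[1, 2], [3, 4]])
idx = torch.tensor([[0, 0], [1, 0]])
assert torch.equal(gather_dim1_loops(t, idx), torch.gather(t, 1, idx))
```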

So let's go through the example. The input tensor is `[[1, 2], [3, 4]]`, and the dim argument is `1`, i.e. we want to collect from the second dimension. The indices for the second dimension are given as `[0, 0]` and `[1, 0]`.

As we skip the first dimension (the dimension we want to collect along is `1`), the first dimension of the result is implicitly given as the first dimension of `index`. That means that the indices hold the second dimension, or the column indices, but not the row indices. Those are given by the indices of the `index` tensor itself.

For the example, this means that the output will have, in its first row, a selection of the elements of the `input` tensor's first row, as given by the first row of the `index` tensor. As the column indices are given by `[0, 0]`, we therefore select the first element of the first row of the input twice, resulting in `[1, 1]`. Similarly, the elements of the second row of the result come from indexing the second row of the `input` tensor by the elements of the second row of the `index` tensor, resulting in `[4, 3]`.

To illustrate this even further, let's swap the dimension in the example:

```
t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 0, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  2],
#         [ 3,  2]])
```

As you can see, the indices are now collected along the first dimension.

For the example you referred to,

```
current_Q_values = Q(obs_batch).gather(1, act_batch.unsqueeze(1))
```

`gather` will index the rows of the q-values (i.e. the per-sample q-values in a batch of q-values) by the batch-list of actions. The result will be the same as if you had done the following (though it will be much faster than a loop):

```
q_vals = []
for qv, ac in zip(Q(obs_batch), act_batch):
    q_vals.append(qv[ac])
# each qv[ac] is a 0-dim tensor, so stack (not cat) is needed to combine them
q_vals = torch.stack(q_vals, dim=0)
```

#### What does the gather function do in pytorch in layman terms?

@Ritesh and @cleros gave great answers (with *lots* of upvotes), but after reading them I was still a bit confused, and I know why. This post will perhaps help folks like me.

For these sorts of exercises with rows and columns, I think it *really* helps to use a non-square object, so let's start with a larger 4×3 `source` (`torch.Size([4, 3])`) using `source = torch.tensor([[1,2,3], [4,5,6], [7,8,9], [10,11,12]])`. This will give us

```
# This is the source tensor
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12]])
```

Now let's start indexing along the columns (`dim=1`) and create `index = torch.tensor([[0,0],[1,1],[2,2],[0,1]])`, which is a list of lists. Here's the **key**: since our dimension is columns, and the source has `4` rows, the `index` must contain `4` lists! We need a list for each row. Running `source.gather(dim=1, index=index)` will give us

```
tensor([[ 1,  1],
        [ 5,  5],
        [ 9,  9],
        [10, 11]])
```

So, each list within `index` gives us the columns from which to pull the values. The 1st list of the `index` (`[0,0]`) is telling us to look at the 1st row of the `source` and take the 1st column of that row (it's zero-indexed) twice, which is `[1,1]`. The 2nd list of the `index` (`[1,1]`) is telling us to look at the 2nd row of `source` and take the 2nd column of that row twice, which is `[5,5]`. Jumping to the 4th list of the `index` (`[0,1]`), which asks us to look at the 4th and final row of the `source`, we take the 1st column (`10`) and then the 2nd column (`11`), which gives us `[10,11]`.
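If it helps, this `dim=1` behaviour can also be written with plain advanced indexing: the row comes from each entry's position inside `index`, and the column from its value. A small sketch:

```python
import torch

source = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
index = torch.tensor([[0, 0], [1, 1], [2, 2], [0, 1]])

gathered = source.gather(dim=1, index=index)

# Equivalent advanced indexing: rows come from each entry's position in
# `index` (broadcast down the columns), columns from its value.
rows = torch.arange(source.size(0)).unsqueeze(1)  # shape (4, 1)
same = source[rows, index]

assert torch.equal(gathered, same)
```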

Here's a nifty thing: each list of your `index` has to be the same length, but they may be as long as you like! For example, with `index = torch.tensor([[0,1,2,1,0],[2,1,0,1,2],[1,2,0,2,1],[1,0,2,0,1]])`, `source.gather(dim=1, index=index)` will give us

```
tensor([[ 1,  2,  3,  2,  1],
        [ 6,  5,  4,  5,  6],
        [ 8,  9,  7,  9,  8],
        [11, 10, 12, 10, 11]])
```

The output will always have the same number of rows as the `source`, but the number of columns will equal the length of each list in `index`. For example, the 2nd list of the `index` (`[2,1,0,1,2]`) goes to the 2nd row of the `source` and pulls, respectively, the 3rd, 2nd, 1st, 2nd and 3rd items, which is `[6,5,4,5,6]`. Note, the value of every element in `index` has to be less than the number of columns of `source` (in this case `3`), otherwise you get an `out of bounds` error.
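A quick sketch of that failure mode: `gather` raises a `RuntimeError` rather than wrapping around when an index value is too large:

```python
import torch

source = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])

# With dim=1, every value in index must be < source.size(1) (here 3);
# the 3 below is deliberately out of range.
bad_index = torch.tensor([[0, 3], [1, 1], [2, 2], [0, 1]])
try:
    source.gather(dim=1, index=bad_index)
except RuntimeError as e:
    print("out of bounds:", e)
```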

Switching to `dim=0`, we'll now be using the rows as opposed to the columns. Using the same `source`, we now need an `index` where the length of each list equals the number of columns in the `source`. Why? Because each element in the list represents the row from `source` as we move column by column.

Therefore, `index = torch.tensor([[0,0,0],[0,1,2],[1,2,3],[3,2,0]])` with `source.gather(dim=0, index=index)` will give us

```
tensor([[ 1,  2,  3],
        [ 1,  5,  9],
        [ 4,  8, 12],
        [10,  8,  3]])
```

Looking at the 1st list in the `index` (`[0,0,0]`), we can see that we're moving across the 3 columns of `source`, picking the 1st element (it's zero-indexed) of each column, which is `[1,2,3]`. The 2nd list in the `index` (`[0,1,2]`) tells us to move across the columns taking the 1st, 2nd and 3rd items, respectively, which is `[1,5,9]`. And so on.
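Mirroring the `dim=1` case, `dim=0` gathering is equivalent to advanced indexing where the rows come from the values in `index` and the columns from each entry's column position. A small sketch:

```python
import torch

source = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9], [10, 11, 12]])
index = torch.tensor([[0, 0, 0], [0, 1, 2], [1, 2, 3], [3, 2, 0]])

gathered = source.gather(dim=0, index=index)

# Equivalent advanced indexing: rows come from the values in `index`,
# columns from each entry's column position (broadcast across rows).
cols = torch.arange(source.size(1))  # shape (3,)
same = source[index, cols]

assert torch.equal(gathered, same)
```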

With `dim=1` our `index` had to have a number of lists equal to the number of rows in the `source`, but each list could be as long, or short, as you like. With `dim=0`, each list in our `index` has to be the same length as the number of columns in the `source`, but we can now have as many lists as we like. Each value in `index`, however, needs to be less than the number of rows in `source` (in this case `4`).

For example, `index = torch.tensor([[0,0,0],[1,1,1],[2,2,2],[3,3,3],[0,1,2],[1,2,3],[3,2,0]])` would have `source.gather(dim=0, index=index)` give us

```
tensor([[ 1,  2,  3],
        [ 4,  5,  6],
        [ 7,  8,  9],
        [10, 11, 12],
        [ 1,  5,  9],
        [ 4,  8, 12],
        [10,  8,  3]])
```

With `dim=1` the output always has the same number of rows as the `source`, although the number of columns will equal the length of the lists in `index`. The number of lists in `index` has to equal the number of rows in `source`. Each value in `index`, however, needs to be less than the number of columns in `source`.

With `dim=0` the output always has the same number of columns as the `source`, but the number of rows will equal the number of lists in `index`. The length of each list in `index` has to equal the number of columns in `source`. Each value in `index`, however, needs to be less than the number of rows in `source`.
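The shape rules above can be sketched as a few assertions (the example shapes here are chosen just for illustration):

```python
import torch

source = torch.arange(1, 13).reshape(4, 3)  # 4 rows, 3 columns

# dim=1: index needs source.size(0) rows; the number of columns is free.
idx1 = torch.randint(0, source.size(1), (4, 7))
assert source.gather(1, idx1).shape == (4, 7)

# dim=0: index needs source.size(1) columns; the number of rows is free.
idx0 = torch.randint(0, source.size(0), (9, 3))
assert source.gather(0, idx0).shape == (9, 3)
```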

That's it for two dimensions. Moving beyond that follows the same patterns.