用外行术语在pytorch中收集功能有什么作用?
pytorch
5
0

我已经通过官方文档却是很难理解到底是怎么回事。

我试图理解DQN源代码,并且它使用了197行上的collect函数。

有人可以简单地解释一下collect函数的作用吗?该功能的目的是什么?

参考资料:
Stack Overflow
收藏
评论
共 2 个回答
高赞 时间 活跃

torch.gather通过沿输入维dim取每一行的值,从输入张量中创建一个新张量。作为index传递的torch.LongTensor的值指定要从每个“行”中获取的值。输出张量的尺寸与索引张量的尺寸相同。下图来自官方文档,对其进行了更清晰的说明: 来自文档的图形表示

(注意:在图中,索引从1开始,而不是0)。

在第一个例子中,给定尺寸为沿行(从上到下),所以对(1,1)的位置result ,它需要从该行值indexsrc1 。在源的值(1,1)是1左右,输出1在(1,1)中result 。类似地,对于(2,2),来自src的索引的行值为3 。在(3,2), src值为8 ,因此输出为8 ,依此类推。

类似地,对于第二个例子,索引是沿着列,并因此在所述的(2,2)位置result ,从用于索引的列值src3 ,所以在(2,3)从src6取并输出到result在(2,2)

收藏
评论

torch.gather函数(或torch.Tensor.gather )是一种多索引选择方法。查看官方文档中的以下示例:

t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 1, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  1],
#        [ 4,  3]])

让我们开始研究不同参数的语义:第一个参数input是我们要从中选择元素的源张量。第二个参数dim是我们要收集的尺寸(或以tensorflow / numpy表示的轴)。最后, index是索引input索引。至于操作的语义,这是官方文档对其进行解释的方式:

out[i][j][k] = input[index[i][j][k]][j][k]  # if dim == 0
out[i][j][k] = input[i][index[i][j][k]][k]  # if dim == 1
out[i][j][k] = input[i][j][index[i][j][k]]  # if dim == 2

因此,让我们来看一个例子。

输入张量为[[1, 2], [3, 4]] ,并且dim参数为1 ,即我们要从第二维进行收集。第二维的索引为[0, 0][1, 0]

当我们“跳过”第一个维度(要收集的维度为1 )时,结果的第一个维度被隐式地指定为index的第一个维度。这意味着索引保留第二维或列索引,但不保留行索引。这些由index张量本身的index给出。对于示例,这意味着输出将在其第一行中也具有对input张量的第一行元素的选择,如index张量的第一行的第一行所给定。由于列索引由[0, 0] ,因此我们两次选择了输入的第一行的第一个元素,结果为[1, 1] 。类似地,结果的第二行的元素是通过index张量的第二行的元素索引input张量的第二行的结果,结果为[4, 3] 4,3 [4, 3]

为了进一步说明这一点,让我们在示例中交换尺寸:

t = torch.tensor([[1,2],[3,4]])
r = torch.gather(t, 0, torch.tensor([[0,0],[1,0]]))
# r now holds:
# tensor([[ 1,  2],
#        [ 3,  2]])

如您所见,索引现在沿第一维收集。

对于您引用的示例,

current_Q_values = Q(obs_batch).gather(1, act_batch.unsqueeze(1))

gather将按操作的批处理列表索引q值的行(即,一组q值中的每个样本q值)。结果将与您执行以下操作相同(尽管它比循环快得多):

q_vals = []
for qv, ac in zip(Q(obs_batch), act_batch):
    q_vals.append(qv[ac])
q_vals = torch.cat(q_vals, dim=0)
收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号