如何在PyTorch中做矩阵乘积
python
pytorch
12
0

在numpy中,我可以像这样做一个简单的矩阵乘法:

a = numpy.arange(2*3).reshape(3,2)
b = numpy.arange(2).reshape(2,1)
print(a)
print(b)
print(a.dot(b))

但是,当我使用PyTorch Tensors尝试此操作时,这不起作用:

a = torch.Tensor([[1, 2, 3], [1, 2, 3]]).view(-1, 2)
b = torch.Tensor([[2, 1]]).view(2, -1)
print(a)
print(a.size())

print(b)
print(b.size())

print(torch.dot(a, b))

此代码引发以下错误:

RuntimeError:/Users/soumith/code/builder/wheel/pytorch-src/torch/lib/TH/generic/THTensorMath.c:503上的张量大小不一致

有什么想法可以在PyTorch中进行矩阵乘法吗?

参考资料:
Stack Overflow
收藏
评论
共 3 个回答
高赞 时间 活跃

使用torch.mm(a, b)torch.matmul(a, b)
两者都一样。

>>> torch.mm
<built-in method mm of type object at 0x11712a870>
>>> torch.matmul
<built-in method matmul of type object at 0x11712a870>

还有另一种可能很了解的选项。那是@运算符。 @西蒙·H。

>>> a = torch.randn(2, 3)
>>> b = torch.randn(3, 4)
>>> a@b
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
        [ 0.8699, -0.3445,  1.4122, -0.5826]])
>>> a.mm(b)
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
        [ 0.8699, -0.3445,  1.4122, -0.5826]])
>>> a.matmul(b)
tensor([[ 0.6176, -0.6743,  0.5989, -0.1390],
        [ 0.8699, -0.3445,  1.4122, -0.5826]])    

这三个给出相同的结果。

相关链接:
矩阵乘法运算符
PEP 465-用于矩阵乘法的专用中缀运算符

收藏
评论

您正在寻找

torch.mm(a,b)

需要注意的是torch.dot()的行为不同,以np.dot()有人讨论了这里需要什么。具体来说, torch.dot()ab视为一维向量(与它们的原始形状torch.dot() ),并计算其内积。引发错误,因为此行为使您a a为长度6的向量,而b为长度2的向量;因此无法计算其内积。对于PyTorch中的矩阵乘法,请使用torch.mm() 。相反,Numpy的np.dot()更灵活;它计算1D数组的内积,并对2D数组执行矩阵乘法。

应广大用户要求,功能torch.matmul执行矩阵乘法,如果两个参数都是2D和计算它们的点积,如果两个参数都是1D 。对于此类尺寸的输入,其行为与np.dot相同。它还允许您批量广播或matrix x matrixmatrix x vectorvector x vector操作。有关更多信息,请参见其文档

# 1D inputs, same as torch.dot
a = torch.rand(n)
b = torch.rand(n)
torch.matmul(a, b) # torch.Size([])

# 2D inputs, same as torch.mm
a = torch.rand(m, k)
b = torch.rand(k, j)
torch.matmul(a, b) # torch.Size([m, j])
收藏
评论

如果要进行矩阵(2级张量)相乘,可以用四种等效的方式进行:

AB = A.mm(B) # computes A.B (matrix multiplication)
# or
AB = torch.mm(A, B)
# or
AB = torch.matmul(A, B)
# or, even simpler
AB = A @ B # Python 3.5+

有一些细微之处。从PyTorch文档中

torch.mm不广播。有关广播矩阵产品,请参见torch.matmul()。

例如,您不能将两个一维矢量与torch.mm相乘,也不能与批处理矩阵相乘(等级3)。为此,您应该使用功能更强大的torch.matmul 。有关torch.matmul的广播行为的详细torch.matmul ,请参见文档

对于逐元素乘法,您可以简单地做(如果A和B具有相同的形状)

A * B # element-wise matrix multiplication (Hadamard product)
收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号