如何将numpy数组列表加载到pytorch数据集加载器?
numpy
python
pytorch
5
0

我有一个庞大的numpy数组列表,其中每个数组代表一个图像,我想使用torch.utils.data.Dataloader对象加载它。但是torch.utils.data.Dataloader的文档提到它直接从文件夹加载数据。如何为我的原因修改它?我是pytorch的新手,任何帮助将不胜感激。我的单个图像的numpy数组看起来像这样。该图像是RBG图像。

`[[[ 70  82  94]
  [ 67  81  93]
  [ 66  82  94]
  ..., 
  [182 182 188]
  [183 183 189]
  [188 186 192]]

 [[ 66  80  92]
  [ 62  78  91]
  [ 64  79  95]
  ..., 
  [176 176 182]
  [178 178 184]
  [180 180 186]]

 [[ 62  82  93]
  [ 62  81  96]
  [ 65  80  99]
  ..., 
  [169 172 177]
  [173 173 179]
  [172 172 178]]

 ..., 
`
参考资料:
Stack Overflow
收藏
评论
共 2 个回答
高赞 时间 活跃

PyTorch DataLoader需要一个DataSet ,你可以在查看文档 。正确的方法是使用:

torch.utils.data.TensorDataset(*tensors)

这是用于包装张量的数据集,其中每个样本将通过沿第一维索引张量来检索。参数*tensors量表示具有与第一维相同大小的张量。

另一个class torch.utils.data.Dataset是一个抽象类。

这是将numpy数组转换为张量的方法:

import torch
import numpy as np
n = np.arange(10)
print(n) #[0 1 2 3 4 5 6 7 8 9]
t1 = torch.Tensor(n)  # as torch.float32
print(t1) #tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.])
t2 = torch.from_numpy(n)  # as torch.int32
print(t2) #tensor([0, 1, 2, 3, 4, 5, 6, 7, 8, 9], dtype=torch.int32)

可接受的答案使用了torch.Tensor构造。如果您的图像像素在0-255之间,则可以使用以下方法:

timg = torch.from_numpy(img).float()

或torchvision to_tensor方法,该方法将PIL图像或numpy.ndarray转换为张量。


但是,这里有一个小技巧,您可以直接放置numpy数组。

x1 = np.array([1,2,3])
d1 = DataLoader( x1, batch_size=3)

这也可以,但是如果您打印d1.dataset类型:

print(type(d1.dataset)) # <class 'numpy.ndarray'>

虽然我们实际上需要Tensors才能使用CUDA,所以最好使用Tensors来DataLoader

收藏
评论

我认为DataLoader实际需要的是对Dataset进行子类化的输入。你可以写一个继承自己的数据集类Dataset或使用TensorDataset因为我已经做了如下:

import torch
import numpy as np
from torch.utils import data

my_x = [np.array([[1.0,2],[3,4]]),np.array([[5.,6],[7,8]])] # a list of numpy arrays
my_y = [np.array([4.]), np.array([2.])] # another list of numpy arrays (targets)

tensor_x = torch.Tensor(my_x) # transform to torch tensor
tensor_y = torch.Tensor(my_y)

my_dataset = data.TensorDataset(tensor_x,tensor_y) # create your datset
my_dataloader = data.DataLoader(my_dataset) # create your dataloader

为我工作。希望对您有帮助。

收藏
评论
新手导航
  • 社区规范
  • 提出问题
  • 进行投票
  • 个人资料
  • 优化问题
  • 回答问题

关于我们

常见问题

内容许可

联系我们

@2020 AskGo
京ICP备20001863号