PyTorch DataLoader возвращает пакет в виде списка с пакетом в качестве единственной записи. Как лучше всего получить тензор от моего DataLoader - PullRequest
1 голос
/ 29 октября 2019

В настоящее время у меня есть следующая ситуация, когда я хочу использовать DataLoader для пакетной обработки массива:

import numpy as np
import torch
import torch.utils.data as data_utils

# Create toy data
x = np.linspace(start=1, stop=10, num=10)
x = np.array([np.random.normal(size=len(x)) for i in range(100)])
print(x.shape)
# >> (100,10)

# Create DataLoader
input_as_tensor = torch.from_numpy(x).float()
dataset = data_utils.TensorDataset(input_as_tensor)
dataloader = data_utils.DataLoader(dataset,
                                   batch_size=100,
                                  )
batch = next(iter(dataloader))

print(type(batch))
# >> <class 'list'>

print(len(batch))
# >> 1

print(type(batch[0]))
# >> class 'torch.Tensor'>

Я ожидаю, что batch уже будет torch.Tensor. На данный момент я индексирую партию примерно так: batch[0] чтобы получить Tensor, но я чувствую, что это не очень красиво и делает код труднее для чтения.

Я обнаружил, что DataLoader требует пакетной обработкифункция называется collate_fn. Однако установка data_utils.DataLoader(..., collage_fn=lambda batch: batch[0]) только изменяет список на кортеж (tensor([ 0.8454, ..., -0.5863]),), где единственной записью является партия в качестве тензора.

Вы бы мне очень помогли, если бы я выяснил, как элегантно преобразовать партиютензор (даже если это будет включать в себя сообщение мне, что индексация одной записи в пакете в порядке).

1 Ответ

1 голос
/ 29 октября 2019

Приносим извинения за неудобства с моим ответом.

На самом деле вам не нужно создавать Dataset из своего тензора, вы можете передать torch.Tensor напрямую, так как он реализует __getitem__ и __len__,так что этого достаточно:

import numpy as np
import torch
import torch.utils.data as data_utils

# Create toy data
x = np.linspace(start=1, stop=10, num=10)
x = np.array([np.random.normal(size=len(x)) for i in range(100)])

# Create DataLoader
dataset = torch.from_numpy(x).float()
dataloader = data_utils.DataLoader(dataset, batch_size=100)
batch = next(iter(dataloader))
...