Я работаю с фрагментом кода, написанным кем-то другим для обобщения домена, и как часть этого, у меня настроен загрузчик данных для загрузки данных обучения, проверки и тестирования для одного из моих наборов данных. Код работает нормально, когда я загружаю в поезд или тестовые данные, но когда я пытаюсь загрузить данные val, я получаю Значение Ошибка: не удалось передать входной массив из формы (320,371) в форму (320) в функции load_samples в строке images = np.asarray (images). Я понимаю, что говорит эта ошибка, но я не могу понять, почему она говорит это. Код для секции val идентичен кодам для секций train и test, и файл csv, из которого я читаю, точно такой же формат, как и два других файла csv. Я также вызываю функцию get_chexpert для каждого из них точно таким же образом. Кроме того, загрузчик данных для моего другого набора данных имеет почти такой же код, что и этот, и может отлично создать набор проверки. Я попытался проверить, если это был файл CSV, заменив val CSV на тест CSV, но я все еще получаю ту же ошибку. Может кто-нибудь указать мне, что я делаю не так? Я чувствую, что это, должно быть, какая-то глупо очевидная ошибка, но я просто не вижу ее.
import os
import csv
from PIL import Image
import numpy as np
import torch
import torch.utils.data as data
from torchvision import datasets, transforms
import params
class Chexpert(data.Dataset):
def __init__(self, root, train=True, val=False, transform=None):
"""Init chexpert dataset."""
# init params
self.root = os.path.expanduser(root)
self.train = train
self.val = val
self.transform = transform
self.dataset_size = None
self.train_data, self.train_labels = self.load_samples()
if self.train:
total_num_samples = self.train_labels.shape[0]
indices = np.arange(total_num_samples)
np.random.shuffle(indices)
self.train_data = self.train_data[indices[0:self.dataset_size]]
self.train_labels = self.train_labels[indices[0:self.dataset_size]]
def __getitem__(self, index):
"""Get images and target for data loader.
Args:
index (int): Index
Returns:
tuple: (image, target) where target is index of the target class.
"""
img, label = self.train_data[index], self.train_labels[index]
if self.transform is not None:
img = self.transform(img)
label = torch.LongTensor([np.int64(label).item()])
return img, label
def __len__(self):
"""Return size of dataset."""
return self.dataset_size
def load_samples(self):
"""Load sample images from dataset."""
# some arbitrary limits so I'm not loading 100,000 images while debugging
numtr = 50
numts = 20
numvl = 10
data_root = os.path.join(self.root, 'CheXpert-v1.0-small')
images = []
labels = []
if self.val:
val_info = csv.reader(open(os.path.join(data_root, 'effusion-val-split.csv'), 'r'))
for count, row in enumerate(val_info):
if count == numvl:
break
image = np.array(Image.open(os.path.join(self.root, row[0])))
images.append(image)
labels.append(row[1])
elif self.train:
train_info = csv.reader(open(os.path.join(data_root, 'effusion-train-split.csv'), 'r'))
for count, row in enumerate(train_info):
if count == numtr:
break
image = np.array(Image.open(os.path.join(self.root, row[0])))
images.append(image)
labels.append(row[1])
elif not self.val and not self.train:
test_info = csv.reader(open(os.path.join(data_root, 'effusion-test-split.csv'), 'r'))
for count, row in enumerate(test_info):
if count == numts:
break
image = np.array(Image.open(os.path.join(self.root, row[0])))
images.append(image)
labels.append(row[1])
images = np.asarray(images)
labels = np.asarray(labels)
self.dataset_size = labels.shape[0]
return images, labels
def get_chexpert(train, val):
"""Get chexpert dataset loader."""
# image pre-processing
pre_process = transforms.Compose([transforms.ToPILImage(),
transforms.Resize((224, 224)),
transforms.ToTensor(),
#transforms.Normalize(
#mean=params.dataset_mean,
#std=params.dataset_std)])
])
# dataset and data loader
chexpert_dataset = Chexpert(root=params.data_root,
train=train,
val=val,
transform=pre_process)
chexpert_data_loader = torch.utils.data.DataLoader(
dataset=chexpert_dataset,
batch_size=params.batch_size,
shuffle=True)
return chexpert_data_loader
if __name__ == '__main__':
# load dataset
print("Loading Source Train Data")
src_data_loader = get_chexpert()
print("Loading Source Validation Data")
src_data_loader_val = get_chexpert(train=False, val=True)
print("Loading Source Test Data")
src_data_loader_eval = get_chexpert(train=False)
print("Loading Target Train Data")
tgt_data_loader = get_nih()
print("Loading Target Validation Data")
tgt_data_loader_val = get_nih(train=False, val=True)
print("Loading Target Test Data")
tgt_data_loader_eval = get_nih(train=False)