Я работал над этим https://www.kaggle.com/gti-upm/leapgestrecog набором данных в последнее время. Это набор данных жестов рук, и я пытался сделать классификатор. Из-за изображений, доступных в разных типах папок, я загрузил их на загрузчик данных. Вот это
class DatasetLoader(Dataset):
def __init__(self,path):
self.path_list = path
self.labels = []
self.to_tensor = transforms.ToTensor()
self.resize = transforms.Resize((120,320))
self.gray = transforms.Grayscale(num_output_channels = 1)
self._init_dataset()
def _init_dataset(self):
labels = set()
for diro in os.listdir("/kaggle/input/leapgestrecog/leapGestRecog"):
for d in os.listdir(os.path.join("/kaggle/input/leapgestrecog/leapGestRecog",diro)):
if len(d.split('_'))>2:
labels.add("_".join(d.split("_")[-2:]))
else:
labels.add(d.split("_")[-1])
labels = list(labels)
## help me on this line with some codes
def __getitem__(self,idx):
if torch.is_tensor(idx):
idx = idx.tolist()
img_name = self.path_list[idx]
img = Image.open(img_name)
img = self.resize(img)
img = self.gray(img)
img = self.to_tensor(img)
if len(img_name.split('/')[-2].split('_')) > 2:
label = "_".join(img_name.split('/')[-2].split('_')[-2:])
else:
label = img_name.split('/')[-2].split('_')[-1]
label = ## Here also
return img,label
def __len__(self):
return len(self.path_list)
У меня проблема с меткой, которую я получаю от этого загрузчика набора данных. Поскольку я создал модель, которая принимает n пакетов данных с 10 классами, поэтому при расчете потерь мне нужны метки размера (n, 10). Я не знаю что делать. Вот мой дизайн сети:
class Net(nn.Module):
def __init__(self):
super(Net,self).__init__()
self.conv1 = nn.Conv2d(1,32,5)
self.pool = nn.MaxPool2d(2,2)
self.conv2 = nn.Conv2d(32,64,3)
self.conv3 = nn.Conv2d(64,64,3)
self.fc1 = nn.Linear(64*38*13,128)
self.fc2 = nn.Linear(128,10)
def forward(self,x):
x = self.pool(F.relu(self.conv1(x)))
x = self.pool(F.relu(self.conv2(x)))
x = self.pool(F.relu(self.conv3(x)))
x = x.view(64,64*38*13)
x = F.relu(self.fc1(x))
return F.log_softmax(self.fc2(x),dim = 1)
Если у - метка изображения. Для обучения нашей сети мы кормим функцию потерь с помощью y и вывода. Но вывод, который мы получаем, имеет размер (64,10), поэтому мне нужна помощь с label
в загрузчике данных