мой набор поездов содержит 51000 (3 канала) изображений. Я пытаюсь получить эти изображения для обучения из файла / файла данных csv, содержащего 2 столбца изображения и метку. например: traindataset.loc [0] [0] означает «/kaggle/input/alaska2-image-steganalysis/UERD/00155.jpg», который является первым каталогом изображений, а traindataset.loc [0] [1] означает «1» метка этого изображения, существует 2 метки (1,0), так что это проблема двоичной классификации, но я не смог выяснить, есть ли какая-либо ошибка в моем коде или нет. вот мой код:
class decode_images(Dataset):
def __init__(self, file):
self.data = file
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
#print(idx)
img_name = self.data.loc[idx][0]
image = Image.open(img_name)
image = image.resize((512, 512), resample=Image.BILINEAR)
label = self.data.loc[idx][1] #torch.tensor(self.data.loc[idx, 'label'])
return {'image': transforms.ToTensor()(image),
'label': label
}
train_dataset = decode_images(traindataset) #traindataset is a dataframe containing images and labels(0,1)
# simple model
model = torchvision.models.resnet101(pretrained=False)
model.load_state_dict(torch.load("../input/pytorch-pretrained-models/resnet101-5d3b4d8f.pth"))
num_features = model.fc.in_features
model.fc = nn.Linear(2048, 1)
device = torch.device("cuda:0")
#device = torch.device("cpu")
model = model.to(device)
data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)
plist = [
{'params': model.layer4.parameters(), 'lr': 1e-4, 'weight': 0.001},
{'params': model.fc.parameters(), 'lr': 1e-3}
]
optimizer = optim.Adam(plist, lr=0.001)
scheduler = lr_scheduler.StepLR(optimizer, step_size=10)
since = time.time()
criterion = torch.nn.CrossEntropyLoss()
num_epochs = 1
for epoch in range(num_epochs):
print('Epoch {}/{}'.format(epoch, num_epochs - 1))
print('-' * 10)
scheduler.step()
model.train()
running_loss = 0.0
tk0 = tqdm(data_loader, total=int(len(data_loader)))
counter = 0
for bi, d in enumerate(tk0):
inputs = d["image"]
labels = d["label"].view(-1, 1)
inputs = inputs.to(device, dtype=torch.float)
labels = labels.to(device, dtype=torch.float)
optimizer.zero_grad()
with torch.set_grad_enabled(True):
outputs = model(inputs)
#loss = criterion(outputs, labels)
loss = criterion(outputs, torch.max(labels, 1)[1])
#print(loss)
loss.backward()
optimizer.step()
running_loss += loss.item() * inputs.size(0)
counter += 1
tk0.set_postfix(loss=(running_loss / (counter * data_loader.batch_size)))
epoch_loss = running_loss / len(data_loader)
print('Training Loss: {:.4f}'.format(epoch_loss))
time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
torch.save(model.state_dict(), "model.bin")
с кодом выше, я получаю этот вывод:
Потеря обучения: 0,0000 Обучение завершено за 34 м 53 с
поэтому мой вопрос: почему потеря: 0,0000
, тогда я попытался сделать прогноз набора тестов следующим образом:
class decode_images(Dataset):
def __init__(self, csv_file):
self.data = csv_file#pd.read_csv(csv_file)
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
#print(idx)
img_name = self.data.loc[idx][0]
image = Image.open(img_name)
image = image.resize((512, 512), resample=Image.BILINEAR)
#label = self.data.loc[idx][1] #torch.tensor(self.data.loc[idx, 'label'])
#image = self.transform(image)
return {'image': image}
test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=False) # test_set contains only images directory
for param in model.parameters():
param.requires_grad = False
tk0 = tqdm(test_loader)
for i, x_batch in enumerate(tk0):
print(i)
print(x_batch)
x_batch = x_batch["image"]
pred = model(x_batch.to(device))
sub.Label[i] = pred
и теперь я получаю эту ошибку:
---------------------------------------------------------------------------
KeyError Traceback (most recent call last)
/opt/conda/lib/python3.7/site-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
2645 try:
-> 2646 return self._engine.get_loc(key)
2647 except KeyError:
pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()
pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()
pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()
pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()
KeyError: 0
During handling of the above exception, another exception occurred:
KeyError Traceback (most recent call last)
<ipython-input-41-fd6bbd63a0bb> in <module>
1 tk0 = tqdm(test_loader)
----> 2 for i, x_batch in enumerate(tk0):
3 print(i)
4 print(x_batch)
5 x_batch = x_batch["image"]
/opt/conda/lib/python3.7/site-packages/tqdm/notebook.py in __iter__(self, *args, **kwargs)
216 def __iter__(self, *args, **kwargs):
217 try:
--> 218 for obj in super(tqdm_notebook, self).__iter__(*args, **kwargs):
219 # return super(tqdm...) will not catch exception
220 yield obj
/opt/conda/lib/python3.7/site-packages/tqdm/std.py in __iter__(self)
1106 fp_write=getattr(self.fp, 'write', sys.stderr.write))
1107
-> 1108 for obj in iterable:
1109 yield obj
1110 # Update and possibly print the progressbar.
/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
343
344 def __next__(self):
--> 345 data = self._next_data()
346 self._num_yielded += 1
347 if self._dataset_kind == _DatasetKind.Iterable and \
/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
383 def _next_data(self):
384 index = self._next_index() # may raise StopIteration
--> 385 data = self._dataset_fetcher.fetch(index) # may raise StopIteration
386 if self._pin_memory:
387 data = _utils.pin_memory.pin_memory(data)
/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
42 def fetch(self, possibly_batched_index):
43 if self.auto_collation:
---> 44 data = [self.dataset[idx] for idx in possibly_batched_index]
45 else:
46 data = self.dataset[possibly_batched_index]
/opt/conda/lib/python3.7/site-packages/pandas/core/frame.py in __getitem__(self, key)
2798 if self.columns.nlevels > 1:
2799 return self._getitem_multilevel(key)
-> 2800 indexer = self.columns.get_loc(key)
2801 if is_integer(indexer):
2802 indexer = [indexer]
/opt/conda/lib/python3.7/site-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
2646 return self._engine.get_loc(key)
2647 except KeyError:
-> 2648 return self._engine.get_loc(self._maybe_cast_indexer(key))
2649 indexer = self.get_indexer([key], method=method, tolerance=tolerance)
2650 if indexer.ndim > 1 or indexer.size > 1:
pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()
pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()
pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()
pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()
KeyError: 0