Pytorch - получая потери 0,000 и keyerror при выводе - PullRequest
0 голосов
/ 30 апреля 2020

мой набор поездов содержит 51000 (3 канала) изображений. Я пытаюсь получить эти изображения для обучения из файла / файла данных csv, содержащего 2 столбца изображения и метку. например: traindataset.loc [0] [0] означает «/kaggle/input/alaska2-image-steganalysis/UERD/00155.jpg», который является первым каталогом изображений, а traindataset.loc [0] [1] означает «1» метка этого изображения, существует 2 метки (1,0), так что это проблема двоичной классификации, но я не смог выяснить, есть ли какая-либо ошибка в моем коде или нет. вот мой код:

class decode_images(Dataset):

    def __init__(self, file):

        self.data = file

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        #print(idx)
        img_name =  self.data.loc[idx][0]
        image = Image.open(img_name)
        image = image.resize((512, 512), resample=Image.BILINEAR)
        label = self.data.loc[idx][1] #torch.tensor(self.data.loc[idx, 'label'])
        return {'image': transforms.ToTensor()(image),
                'label': label
                }

train_dataset = decode_images(traindataset)  #traindataset is a dataframe containing images and labels(0,1)

# simple model

model = torchvision.models.resnet101(pretrained=False)
model.load_state_dict(torch.load("../input/pytorch-pretrained-models/resnet101-5d3b4d8f.pth"))
num_features = model.fc.in_features
model.fc = nn.Linear(2048, 1)
device = torch.device("cuda:0")
#device = torch.device("cpu")
model = model.to(device)

data_loader = torch.utils.data.DataLoader(train_dataset, batch_size=16, shuffle=True, num_workers=4)

plist = [
         {'params': model.layer4.parameters(), 'lr': 1e-4, 'weight': 0.001},
         {'params': model.fc.parameters(), 'lr': 1e-3}
         ]

optimizer = optim.Adam(plist, lr=0.001)
scheduler = lr_scheduler.StepLR(optimizer, step_size=10)


since = time.time()
criterion = torch.nn.CrossEntropyLoss()
num_epochs = 1
for epoch in range(num_epochs):
    print('Epoch {}/{}'.format(epoch, num_epochs - 1))
    print('-' * 10)
    scheduler.step()
    model.train()
    running_loss = 0.0
    tk0 = tqdm(data_loader, total=int(len(data_loader)))
    counter = 0
    for bi, d in enumerate(tk0):
        inputs = d["image"]
        labels = d["label"].view(-1, 1)
        inputs = inputs.to(device, dtype=torch.float)
        labels = labels.to(device, dtype=torch.float)
        optimizer.zero_grad()
        with torch.set_grad_enabled(True):
            outputs = model(inputs)
            #loss = criterion(outputs, labels)
            loss = criterion(outputs, torch.max(labels, 1)[1])
            #print(loss)
            loss.backward()
            optimizer.step()
        running_loss += loss.item() * inputs.size(0)
        counter += 1
        tk0.set_postfix(loss=(running_loss / (counter * data_loader.batch_size)))
    epoch_loss = running_loss / len(data_loader)
    print('Training Loss: {:.4f}'.format(epoch_loss))

time_elapsed = time.time() - since
print('Training complete in {:.0f}m {:.0f}s'.format(time_elapsed // 60, time_elapsed % 60))
torch.save(model.state_dict(), "model.bin")

с кодом выше, я получаю этот вывод:

Потеря обучения: 0,0000 Обучение завершено за 34 м 53 с

поэтому мой вопрос: почему потеря: 0,0000

, тогда я попытался сделать прогноз набора тестов следующим образом:


class decode_images(Dataset):

    def __init__(self, csv_file):

        self.data = csv_file#pd.read_csv(csv_file)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        #print(idx)
        img_name =  self.data.loc[idx][0]
        image = Image.open(img_name)
        image = image.resize((512, 512), resample=Image.BILINEAR)
        #label = self.data.loc[idx][1] #torch.tensor(self.data.loc[idx, 'label'])
        #image = self.transform(image)
        return {'image': image}

test_loader = torch.utils.data.DataLoader(test_set, batch_size=1, shuffle=False) # test_set contains only images directory

for param in model.parameters():
    param.requires_grad = False


tk0 = tqdm(test_loader)
for i, x_batch in enumerate(tk0):
    print(i)
    print(x_batch)
    x_batch = x_batch["image"]
    pred = model(x_batch.to(device))
    sub.Label[i] = pred

и теперь я получаю эту ошибку:

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
/opt/conda/lib/python3.7/site-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
   2645             try:
-> 2646                 return self._engine.get_loc(key)
   2647             except KeyError:

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

KeyError: 0

During handling of the above exception, another exception occurred:

KeyError                                  Traceback (most recent call last)
<ipython-input-41-fd6bbd63a0bb> in <module>
      1 tk0 = tqdm(test_loader)
----> 2 for i, x_batch in enumerate(tk0):
      3     print(i)
      4     print(x_batch)
      5     x_batch = x_batch["image"]

/opt/conda/lib/python3.7/site-packages/tqdm/notebook.py in __iter__(self, *args, **kwargs)
    216     def __iter__(self, *args, **kwargs):
    217         try:
--> 218             for obj in super(tqdm_notebook, self).__iter__(*args, **kwargs):
    219                 # return super(tqdm...) will not catch exception
    220                 yield obj

/opt/conda/lib/python3.7/site-packages/tqdm/std.py in __iter__(self)
   1106                 fp_write=getattr(self.fp, 'write', sys.stderr.write))
   1107 
-> 1108         for obj in iterable:
   1109             yield obj
   1110             # Update and possibly print the progressbar.

/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    343 
    344     def __next__(self):
--> 345         data = self._next_data()
    346         self._num_yielded += 1
    347         if self._dataset_kind == _DatasetKind.Iterable and \

/opt/conda/lib/python3.7/site-packages/torch/utils/data/dataloader.py in _next_data(self)
    383     def _next_data(self):
    384         index = self._next_index()  # may raise StopIteration
--> 385         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
    386         if self._pin_memory:
    387             data = _utils.pin_memory.pin_memory(data)

/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/opt/conda/lib/python3.7/site-packages/torch/utils/data/_utils/fetch.py in <listcomp>(.0)
     42     def fetch(self, possibly_batched_index):
     43         if self.auto_collation:
---> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
     45         else:
     46             data = self.dataset[possibly_batched_index]

/opt/conda/lib/python3.7/site-packages/pandas/core/frame.py in __getitem__(self, key)
   2798             if self.columns.nlevels > 1:
   2799                 return self._getitem_multilevel(key)
-> 2800             indexer = self.columns.get_loc(key)
   2801             if is_integer(indexer):
   2802                 indexer = [indexer]

/opt/conda/lib/python3.7/site-packages/pandas/core/indexes/base.py in get_loc(self, key, method, tolerance)
   2646                 return self._engine.get_loc(key)
   2647             except KeyError:
-> 2648                 return self._engine.get_loc(self._maybe_cast_indexer(key))
   2649         indexer = self.get_indexer([key], method=method, tolerance=tolerance)
   2650         if indexer.ndim > 1 or indexer.size > 1:

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/index.pyx in pandas._libs.index.IndexEngine.get_loc()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

pandas/_libs/hashtable_class_helper.pxi in pandas._libs.hashtable.PyObjectHashTable.get_item()

KeyError: 0
...