Pytorch - TypeError: объект «torch.Size» не может быть интерпретирован как целое число - PullRequest
0 голосов
/ 03 декабря 2018

Привет Я тренирую модель PyTorch, и произошла эта ошибка:

----> 5 for i, data in enumerate(trainloader, 0):

TypeError: 'torch.Size' object cannot be interpreted as an integer

Не уверен, что означает эта ошибка.

Вы можете найти мой код здесь:

model.train()
for epoch in range(10):
    running_loss = 0

    for i, data in enumerate(trainloader, 0):

        inputs, labels = data

        optimizer.zero_grad()

        outputs = model(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()

        if i % 2000 == 0:
          print (loss.item())
        running_loss += loss.item()
        if i % 1000 == 0:
            print ('[%d, %5d] loss: %.3f' % (epoch, i, running_loss/ 1000))
            running_loss = 0

torch.save(model, 'FeatureNet.pkl')

Обновление

Это кодовый блок для DataLoader.Я использую настроенный загрузчик данных и наборы данных, которые x являются изображениями с размером (1025, 16), а y являются закодированными векторами для горячего кодирования для классификации.

x_train.shape = (1100, 1025,16)

y_train.shape = (1100, 10)

clean_dir = '/home/tk/Documents/clean/' 
mix_dir = '/home/tk/Documents/mix/' 
clean_label_dir = '/home/tk/Documents/clean_labels/' 
mix_label_dir = '/home/tk/Documents/mix_labels/' 

class MSourceDataSet(Dataset):

    def __init__(self, clean_dir, mix_dir, clean_label_dir, mix_label_dir):

        with open(clean_dir + 'clean0.json') as f:
            clean0 = torch.Tensor(json.load(f))

        with open(mix_dir + 'mix0.json') as f:
            mix0 = torch.Tensor(json.load(f))

        with open(clean_label_dir + 'clean_label0.json') as f:
            clean_label0 = torch.Tensor(json.load(f))


        with open(mix_label_dir + 'mix_label0.json') as f:
            mix_label0 = torch.Tensor(json.load(f))


        self.spec = torch.cat([clean0, mix0], 0)
        self.label = torch.cat([clean_label0, mix_label0], 0)

    def __len__(self):
        return self.spec.shape


    def __getitem__(self, index): 

        spec = self.spec[index]
        label = self.label[index]
        return spec, label

getitem

a, b = trainset.__getitem__(1000)
print (a.shape)
print (b.shape)

a.shape = torch.Size([1025, 16]);b.shape = torch.Size([10])

Сообщение об ошибке

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-9-3bd71e5c00e1> in <module>()
      3     running_loss = 0
      4 
----> 5     for i, data in enumerate(trainloader, 0):
      6 
      7         inputs, labels = data

~/anaconda3/lib/python3.7/site-packages/torch/utils/data/dataloader.py in __next__(self)
    311     def __next__(self):
    312         if self.num_workers == 0:  # same-process loading
--> 313             indices = next(self.sample_iter)  # may raise StopIteration
    314             batch = self.collate_fn([self.dataset[i] for i in indices])
    315             if self.pin_memory:

~/anaconda3/lib/python3.7/site-packages/torch/utils/data/sampler.py in __iter__(self)
    136     def __iter__(self):
    137         batch = []
--> 138         for idx in self.sampler:
    139             batch.append(idx)
    140             if len(batch) == self.batch_size:

~/anaconda3/lib/python3.7/site-packages/torch/utils/data/sampler.py in __iter__(self)
     32 
     33     def __iter__(self):
---> 34         return iter(range(len(self.data_source)))
     35 
     36     def __len__(self):

TypeError: 'torch.Size' object cannot be interpreted as an integer

1 Ответ

0 голосов
/ 03 декабря 2018

Ваша проблема в функции __len__.Вы не можете использовать shape в качестве возвращаемого значения.

Вот пример для иллюстрации:

import torch
class Foo:
    def __init__(self, data):
        self.data = data
    def __len__(self):
        return self.data.shape

myFoo = Foo(data=torch.rand(10, 20))
print(len(myFoo))

Возникнет точно такая же ошибка:

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-285-e97aace2f622> in <module>
      7 
      8 myFoo = Foo(data=torch.rand(10, 20))
----> 9 print(len(myFoo))

TypeError: 'torch.Size' object cannot be interpreted as an integer

Так какshape представляет собой torch.Size кортеж:

print(myFoo.data.shape)

Вывод:

torch.Size([10, 20])

Таким образом, вы должны решить, какое измерение вы хотите передать __len__, например, первое измерение:

import torch
class Foo:
    def __init__(self, data):
        self.data = data
    def __len__(self):
        return self.data.shape[0] # choosing first dimension for len

myFoo = Foo(data=torch.rand(10, 20))
print(len(myFoo))
# prints 10

Работает нормально и возвращает 10.Конечно, вы также можете выбрать любое другое измерение ввода, но вы должны выбрать его.

Так что в вашем коде MSourceDataSet вы должны изменить свою функцию __len__, например:

def __len__(self):
    return self.spec.shape[0] # as said of course you can also choose other dimensions

Это должно решить вашу проблему.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...