обучение pytorch l oop заканчивается на '' int '' объект не имеет исключения атрибута 'size' - PullRequest
0 голосов
/ 13 июля 2020

Код, который я публикую ниже, является лишь небольшой частью приложения:

def train(self, training_reviews, training_labels):
        
        # make sure out we have a matching number of reviews and labels
        assert(len(training_reviews) == len(training_labels))
        
        # Keep track of correct predictions to display accuracy during training 
        correct_so_far = 0
        
        # Remember when we started for printing time statistics
        start = time.time()
        
        
        criterion = nn.CrossEntropyLoss()
        optimizer =  torch.optim.SGD(self.parameters(), lr=self.learning_rate)

        # loop through all the given reviews and run a forward and backward pass,
        # updating weights for every item
        for i in range(len(training_reviews)):
            
            # TODO: Get the next review and its correct label
            review = training_reviews[i]
            label = training_labels[i]
            print('processing item ',i)
            self.update_input_layer(review)
            output = self.forward(torch.from_numpy(self.layer_0).float()) 
            target = self.get_target_for_label(label)
            print('output ',output)
            print('target ',target)
            loss = criterion(output, target)

...
mlp = SentimentNetwork(reviews[:-1000],labels[:-1000], learning_rate=0.1)
mlp.train(reviews[:-1000],labels[:-1000])

и заканчивается исключением в строке заголовка при оценке:

loss = criterion(output, target)

ранее для этого переменные следующие:

output  tensor([[0.5803]], grad_fn=<SigmoidBackward>)
target  1

1 Ответ

1 голос
/ 13 июля 2020

Цель должна быть torch.Tensor переменной. Используйте torch.tensor([target]).

Кроме того, вы можете захотеть использовать партии (так что есть N образцов, а форма torch.tensor равна (N,), то же самое для target).

Также см. вводное руководство о PyTorch, поскольку вы не используете пакеты, не запускаете оптимизатор или не используете torch.utils.data.Dataset и torch.utils.data.DataLoader, как вам, вероятно, следует.

...