Я хотел бы создать новый тензор в методе validation_epoch_end
из LightningModule
. В официальной документации (стр. 48) сказано, что нам следует избегать прямых вызовов .cuda()
или .to(device)
:
.cuda () или .to () звонки. . . Lightning сделает это за вас.
, и мы рекомендуем использовать метод type_as
для передачи на нужное устройство.
new_x = new_x.type_as(x.type())
Однако в шаг validation_epoch_end
У меня нет никакого тензора для копирования устройства (методом type_as
) чистым способом.
Мой вопрос: что мне делать, если я хочу создать новый тензор в этом метод и передать его на устройство, где находится модель?
Единственное, что я могу придумать, - это найти тензор в словаре outputs
, но это выглядит как-то беспорядочно:
avg_loss = torch.stack([x['val_loss'] for x in outputs]).mean()
output = self(self.__test_input.type_as(avg_loss))
Есть ли какой-нибудь чистый способ добиться этого?