модель восстановления pytorch с разным размером партии - PullRequest
0 голосов
/ 20 сентября 2018

У меня проблема с тем, как перезагрузить модель pytorch с другим размером пакета.На тренировках размер моей партии равен 64, но на самом деле я хотел бы, чтобы размер партии был 1 (данные фида по одной).Это код, который я использовал для сохранения и восстановления модели:

torch.save(agent.qnetwork_local.state_dict(), './ckpt/checkpoint.pth')
saved_model = QNetwork(state_size=37, action_size=4, seed=0)
saved_model.load_state_dict(torch.load('./ckpt/checkpoint.pth'))

И я получил эту ошибку при запуске модели логического вывода:

RuntimeError: size mismatch, m1: [37 x 1], m2: [37 x 64] at /Users/soumith/code/builder/wheel/pytorch-src/aten/src/TH/generic/THTensorMath.cpp:2070

Эта ошибка означает, что ввод модели должен быть 37x64где 37 - размерность данных, а 64 - размер обучающей партии.Но вход для тестирования равен 37x1, что означает, что размер данных равен 37, а размер пакета равен 1.

Есть ли какое-либо решение для другого размера пакета в модели с повторной загрузкой?Большое спасибо.

Ответы [ 2 ]

0 голосов
/ 23 декабря 2018

Когда вы строите свою модель, вы можете использовать -1 для динамического представления размера пакета.Например, ниже приведен код прямой стадии

def forward(self, x):
     x = self.conv1(x)
     x = self.layer1(x)
     x = self.layer2(x)
     x = self.avgpool(x)
     x = x.view(-1, 37)
 #instead using x.view(64,37) 
     x = self.fc(x)

надеюсь, что он может помочь вам

0 голосов
/ 15 декабря 2018

Мне в итоге удалось сделать это, используя batch_size=1 в DataLoader

import torch
import pandas as pd
from torch.utils.data.dataloader import DataLoader

df = pd.read_csv('data.csv')
df = df.values

# Use CustomDataset class for your data
inference_dataset = CustomDataset(x=df[:1, 0:2])

inference_dataloader = DataLoader(inference_dataset, batch_size=1, shuffle=False, num_workers=4)

# 
model = TheModelClass(*args, **kwargs)
model.load_state_dict(torch.load('./model/model'))
model.eval()

for i, x in enumerate(inference_dataloader):
    x = x.float()
    y_pred = model(x)
    print(y_pred.value)
...