Ниже приведён мой код:
import os
import argparse
import numpy as np
from scipy.misc import imread, imresize, imsave
import torch
from torch.autograd import Variable
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Command-line interface: decode a saved code file into a sequence of PNGs,
# one image per unroll iteration.
parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True, type=str, help='path to model')
parser.add_argument('--input', required=True, type=str, help='input codes')
parser.add_argument('--output', default='.', type=str, help='output folder')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument(
    '--iterations', type=int, default=16, help='unroll iterations')
args = parser.parse_args()

# The archive stores the codes bit-packed next to their original shape;
# unpack the bits, restore the shape and map {0, 1} to {-1.0, +1.0}.
content = np.load(args.input)
codes = np.unpackbits(content['codes'])
codes = np.reshape(codes, content['shape']).astype(np.float32) * 2 - 1
codes = torch.from_numpy(codes)

# The code tensor is 5-D; the image processed below is 16x larger than the
# last two code dimensions.
iters, batch_size, channels, height, width = codes.size()
height = height * 16
width = width * 16

import network  # project-local module that provides DecoderCell

decoder = network.DecoderCell()
decoder.eval()

# Deserialize onto the CPU regardless of the device the checkpoint was saved
# from, so a GPU-trained model can also be decoded on a CPU-only machine.
# The decoder is moved to the GPU afterwards when --cuda is given.
state_dict = torch.load(
    args.model, map_location=lambda storage, location: storage)

# strict=False only tolerates missing/unexpected keys; tensors whose shapes
# differ from the model's still raise RuntimeError ("size mismatch").
# NOTE(review): a size mismatch here presumably means the file passed via
# --model was saved from a different architecture than network.DecoderCell
# (e.g. a checkpoint of another network, or a modified network.py) -- confirm
# that the checkpoint really belongs to this decoder definition.
decoder.load_state_dict(state_dict, strict=False)


def _zero_state(num_channels, rows, cols):
    """Return a pair of zero tensors used as one layer's initial state.

    Both tensors have shape (batch_size, num_channels, rows, cols) and are
    moved to the GPU when --cuda was requested.
    """
    shape = (batch_size, num_channels, rows, cols)
    state = (torch.zeros(*shape), torch.zeros(*shape))
    if args.cuda:
        state = (state[0].cuda(), state[1].cuda())
    return state


if args.cuda:
    decoder = decoder.cuda()
    codes = codes.cuda()

# Inference only: run everything under no_grad so no autograd graph is built
# across the unrolled iterations.
with torch.no_grad():
    decoder_h_1 = _zero_state(512, height // 16, width // 16)
    decoder_h_2 = _zero_state(512, height // 8, width // 8)
    decoder_h_3 = _zero_state(256, height // 4, width // 4)
    decoder_h_4 = _zero_state(128, height // 2, width // 2)

    # The running reconstruction starts at 0.5 everywhere; each iteration's
    # decoder output is added on top of it.
    image = torch.zeros(1, 3, height, width) + 0.5

    # A dedicated loop variable avoids clobbering `iters` unpacked above.
    for step in range(min(args.iterations, codes.size(0))):
        output, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4 = decoder(
            codes[step], decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4)
        image = image + output.cpu()

        # Clamp to [0, 1], scale to 8-bit and reorder the squeezed array
        # with transpose(1, 2, 0) before saving.
        pixels = np.squeeze(image.numpy().clip(0, 1) * 255.0)
        pixels = pixels.astype(np.uint8).transpose(1, 2, 0)
        imsave(os.path.join(args.output, '{:02d}.png'.format(step)), pixels)
После запуска я получаю сообщение об ошибке несоответствия размеров (size mismatch):
Traceback (most recent call last):
File "decoder.py", line 38, in <module>
decoder.load_state_dict(torch.load(args.model), strict=False)
File "C:\Users\User1\Anaconda4\lib\site-packages\torch\nn\modules\module.py", line 719, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DecoderCell:
size mismatch for rnn1.conv_ih.weight: copying a param of torch.Size([2048, 512, 3, 3]) from
checkpoint, where the shape is torch.Size([1024, 64, 3, 3]) in current model.
size mismatch for rnn1.conv_hh.weight: copying a param of torch.Size([2048, 512, 1, 1]) from
checkpoint, where the shape is torch.Size([1024, 256, 1, 1]) in current model.
size mismatch for rnn2.conv_ih.weight: copying a param of torch.Size([2048, 128, 3, 3]) from
checkpoint, where the shape is torch.Size([2048, 256, 3, 3]) in current model.
size mismatch for rnn3.conv_ih.weight: copying a param of torch.Size([1024, 128, 3, 3]) from
checkpoint, where the shape is torch.Size([2048, 512, 3, 3]) in current model.
size mismatch for rnn3.conv_hh.weight: copying a param of torch.Size([1024, 256, 3, 3]) from
checkpoint, where the shape is torch.Size([2048, 512, 1, 1]) in current model.
Я получаю указанную выше ошибку. Как её устранить? Я обучаю модель на изображениях в формате PNG и по-прежнему получаю эту ошибку. Есть ли конкретное решение для этой ошибки?
Я использую PyTorch 0.4 и Python 3.6; размер изображений — 64 × 64.