Pytorch изображения декодер дает несоответствие размера - PullRequest
0 голосов
/ 13 января 2020

Ниже приведен мой код

import os
import argparse

import numpy as np
from scipy.misc import imread, imresize, imsave

import torch
from torch.autograd import Variable
# device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
parser = argparse.ArgumentParser()
parser.add_argument('--model', required=True, type=str, help='path to model')
parser.add_argument('--input', required=True, type=str, help='input codes')
parser.add_argument('--output', default='.', type=str, help='output folder')
parser.add_argument('--cuda', action='store_true', help='enables cuda')
parser.add_argument(
'--iterations', type=int, default=16, help='unroll iterations')
args = parser.parse_args()

content = np.load(args.input)
codes = np.unpackbits(content['codes'])
codes = np.reshape(codes, content['shape']).astype(np.float32) * 2 - 1

codes = torch.from_numpy(codes)
iters, batch_size, channels, height, width = codes.size()
height = height * 16
# print(height)
width = width * 16
# print(width)

with torch.no_grad():
     codes = Variable(codes, requires_grad=True)

import network

decoder = network.DecoderCell()
decoder.eval()

decoder.load_state_dict(torch.load(args.model), strict=False)
# decoder.load_state_dict(torch.load(args.model, map_location={'cuda:0': 'cpu'}))
# decoder.load_state_dict(torch.load(args.model, map_location=lambda storage, location: storage))

decoder_h_1 = (Variable(
    torch.zeros(batch_size, 512, height // 16, width // 16), requires_grad=True),
            Variable(
               torch.zeros(batch_size, 512, height // 16, width // 16),
               requires_grad=True))
decoder_h_2 = (Variable(
   torch.zeros(batch_size, 512, height // 8, width // 8), requires_grad=True),
           Variable(
               torch.zeros(batch_size, 512, height // 8, width // 8),
               requires_grad=True))
decoder_h_3 = (Variable(
   torch.zeros(batch_size, 256, height // 4, width // 4), requires_grad=True),
           Variable(
               torch.zeros(batch_size, 256, height // 4, width // 4),
               requires_grad=True))
decoder_h_4 = (Variable(
   torch.zeros(batch_size, 128, height // 2, width // 2), requires_grad=True),
           Variable(
               torch.zeros(batch_size, 128, height // 2, width // 2),
               requires_grad=True))

if args.cuda:
   decoder = decoder.cuda()

   codes = codes.cuda()

    decoder_h_1 = (decoder_h_1[0].cuda(), decoder_h_1[1].cuda())
    decoder_h_2 = (decoder_h_2[0].cuda(), decoder_h_2[1].cuda())
    decoder_h_3 = (decoder_h_3[0].cuda(), decoder_h_3[1].cuda())
    decoder_h_4 = (decoder_h_4[0].cuda(), decoder_h_4[1].cuda())

image = torch.zeros(1, 3, height, width) + 0.5
for iters in range(min(args.iterations, codes.size(0))):

    output, decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4 = decoder(
    codes[iters], decoder_h_1, decoder_h_2, decoder_h_3, decoder_h_4)
    image = image + output.data.cpu()

    imsave(os.path.join(args.output, '{:02d}.png'.format(iters)),np.squeeze(image.numpy().clip(0, 1) * 255.0).astype(np.uint8).transpose(1, 2, 0))

После я получаю сообщение об ошибке типа несоответствие размера

Traceback (most recent call last):
File "decoder.py", line 38, in <module>
decoder.load_state_dict(torch.load(args.model), strict=False)
File "C:\Users\User1\Anaconda4\lib\site-packages\torch\nn\modules\module.py", line 719, in load_state_dict
self.__class__.__name__, "\n\t".join(error_msgs)))
RuntimeError: Error(s) in loading state_dict for DecoderCell:
    size mismatch for rnn1.conv_ih.weight: copying a param of torch.Size([2048, 512, 3, 3]) from 
checkpoint, where the shape is torch.Size([1024, 64, 3, 3]) in current model.
    size mismatch for rnn1.conv_hh.weight: copying a param of torch.Size([2048, 512, 1, 1]) from 
checkpoint, where the shape is torch.Size([1024, 256, 1, 1]) in current model.
    size mismatch for rnn2.conv_ih.weight: copying a param of torch.Size([2048, 128, 3, 3]) from 
checkpoint, where the shape is torch.Size([2048, 256, 3, 3]) in current model.
    size mismatch for rnn3.conv_ih.weight: copying a param of torch.Size([1024, 128, 3, 3]) from 
checkpoint, where the shape is torch.Size([2048, 512, 3, 3]) in current model.
    size mismatch for rnn3.conv_hh.weight: copying a param of torch.Size([1024, 256, 3, 3]) from 
checkpoint, where the shape is torch.Size([2048, 512, 1, 1]) in current model.

Я получаю вышеуказанную ошибку, Как устранить ошибку, потому что я тренируюсь изображения для PNG и до сих пор получаю вышеуказанную ошибку. есть ли конкретное c решение для этой ошибки.

Я использую PyTorch 0.4. python 3,6 Использование размера изображения с 64 X 64

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...