Как отладить сохранение модели TypeError: невозможно выбрать объекты SwigPyObject? - PullRequest
0 голосов
/ 10 января 2020

Я пытаюсь сохранить модель, которая является объектом класса, который наследуется от nn.Module. Он отлично работает без проблем, но когда я пытаюсь запустить код:

    torch.save(
        obj=model,
        f=os.path.join(tensorboard_writer.get_logdir(), 'model.ckpt'))

, я получаю сообщение об ошибке: TypeError: can't pickle SwigPyObject objects

Я понятия не имею, что такое объект SwigPyObject. Как отладить эту ошибку, чтобы сохранить мою модель?

Полная трассировка:

Traceback (most recent call last):
  File "/usr/local/lib/python3.6/dist-packages/torch/serialization.py", line 149, in _with_file_like
    return body(f)
  File "/usr/local/lib/python3.6/dist-packages/torch/serialization.py", line 224, in <lambda>
    return _with_file_like(f, "wb", lambda f: _save(obj, f, pickle_module, pickle_protocol))
  File "/usr/local/lib/python3.6/dist-packages/torch/serialization.py", line 296, in _save
    pickler.dump(obj)
TypeError: can't pickle SwigPyObject objects

Если это поможет, моя модель является объектом следующего класса:

class RecurrentModel(nn.Module):

    def __init__(self,
                 core_str,
                 core_kwargs,
                 tensorboard_writer=None,
                 input_size=1,
                 hidden_size=32,
                 output_size=2):

        super(RecurrentModel, self).__init__()
        self.tensorboard_writer = tensorboard_writer
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.core = self._create_core(core_str=core_str, core_kwargs=core_kwargs)
        self.core_hidden = None
        self.linear = nn.Linear(
            in_features=hidden_size,
            out_features=output_size,
            bias=True)
        self.softmax = nn.Softmax(dim=-1)

        # converts all weights into doubles i.e. float64
        # this prevents PyTorch from breaking when multiplying float32 * flaot64
        self.double()

        # TODO figure out why writing the model to tensorboard doesn't work
        # dummy_input = torch.zeros(size=(10, 1, 1), dtype=torch.double)
        # tensorboard_writer.add_graph(
        #     model=self,
        #     input_to_model=dict(stimulus=dummy_input))

    def _create_core(self, core_str, core_kwargs):
        if core_str == 'lstm':
            core_constructor = nn.LSTM
        elif core_str == 'rnn':
            core_constructor = nn.RNN
        elif core_str == 'gru':
            core_constructor = nn.GRU
        else:
            raise ValueError('Unknown core string')

        core = core_constructor(
            input_size=self.input_size,
            hidden_size=self.hidden_size,
            batch_first=True,
            **core_kwargs)
        return core

    def forward(self, model_input):

        if self.core_hidden is None:
            core_output, self.core_hidden = self.core(
                model_input['stimulus'])
        else:
            core_output, self.core_hidden = self.core(
                model_input['stimulus'],
                self.core_hidden)

        linear_output = self.linear(core_output)

        softmax_output = self.softmax(linear_output)

        forward_output = dict(
            core_output=core_output,
            core_hidden=self.core_hidden,
            linear_output=linear_output,
            softmax_output=softmax_output)

        return forward_output

    def reset_core_hidden(self):
        self.core_hidden = None

Затем я создаю и сохраняю модель:

    tensorboard_writer = SummaryWriter()

    model = RecurrentModel(
        core_str='rnn',
        core_kwargs={},
        tensorboard_writer=tensorboard_writer)

    torch.save(
        obj=model,
        f=os.path.join(tensorboard_writer.get_logdir(), 'model.ckpt')
    )

1 Ответ

0 голосов
/ 10 января 2020

Кто-то на форуме PyTorch уточнил. SummaryWriter имеет дескриптор файла, который обычно не может сериализоваться. Линия проблемных c:

self.tensorboard_writer = tensorboard_writer

https://github.com/pytorch/pytorch/issues/32046#issuecomment -573151743

...