Как импортировать модель Pytorch в MATLAB - PullRequest
0 голосов
/ 26 июня 2018

Я создал модель в Pytorch и хочу перенести ее в MATLAB, показан минимальный пример

import torch.nn as nn
import torch
class cnn(nn.Module):
    def __init__(self):
        super(cnn, self).__init__()
        self.fc1 = nn.Sequential(
            nn.Linear(10, 1),
            nn.ReLU(True)
        )

    def forward(self, x):
        out = self.fc1(x)
        return out
the_net = cnn()
torch.save(the_net,'desperation.h5')

В MATLAB я звоню

net = importKerasLayers('desperation.h5')

Это дает сообщение об ошибке

Error using importKerasLayers (line 104)
Unable to read HDF5 file 'desperation.h5'. The error message was: 'The filename specified was either
not found on the MATLAB path or it contains unsupported characters.''

Файл находится на пути, и я могу загрузить модель обратно в Python. Что я действительно хочу, так это любое решение, которое позволяет мне переносить модель из Pytorch в MATLAB без ручного копирования всех весов.

Я использую MATLAB 2018b, Python 3.6 и Pytorch 0.4.0

1 Ответ

0 голосов
/ 04 июля 2018

Я использовал этот инструмент в прошлом с некоторым успехом: https://github.com/albanie/mcnPyTorch, чтобы перейти от Pytorch к MatConvNet.

...