Как открыть предварительно обученные модели в Python - PullRequest
0 голосов
/ 19 марта 2019

Привет. Я пытаюсь загрузить некоторые предварительно обученные модели из файлов .sav, и пока ничего не работает. Модели были изначально сделаны в pytorch, и когда я открываю необработанный файл в vs-коде, я вижу, что вся соответствующая информация была сохранена правильно.

Я пробовал следующие библиотеки:

sklearn.externals.joblib

pickle

scipy.io

pyreadstat

Каждая библиотека либо выдавала мне ошибку (например, wrong timestamp или signature mismatch), либо просто возвращала int вместо объекта python.

Модели можно скачать с по этой ссылке .

1 Ответ

2 голосов
/ 19 марта 2019

Вам нужно использовать PyTorch для загрузки моделей. Кроме того, вам также необходимо оригинальное определение модели, поэтому вам нужен клон авторов репозитория. В вашем примере это репо:

git clone https://github.com/tbepler/protein-sequence-embedding-iclr2019.git

Затем вы можете открыть модель с помощью torch.load(). Обратите внимание, что вам нужно определение модели на вашем пути (вы можете просто запустить python из каталога репо).

Тогда открыть файл просто:

import torch
model = torch.load('<downloaded models>/<model name>.sav')
print(model)

Последняя строка печатает определение модели. Например, me_L1_100d_lstm3x512_lm_i512_mb64_tau0.5_p0.05_epoch100.sav выдает следующий вывод:

OrdinalRegression(
  (embedding): StackedRNN(
    (embed): LMEmbed(
      (lm): BiLM(
        (embed): Embedding(22, 21, padding_idx=21)
        (dropout): Dropout(p=0)
        (rnn): ModuleList(
          (0): LSTM(21, 1024, batch_first=True)
          (1): LSTM(1024, 1024, batch_first=True)
        )
        (linear): Linear(in_features=1024, out_features=21, bias=True)
      )
      (embed): Embedding(21, 512, padding_idx=20)
      (proj): Linear(in_features=4096, out_features=512, bias=True)
      (transform): ReLU()
    )
    (dropout): Dropout(p=0)
    (rnn): LSTM(512, 512, num_layers=3, batch_first=True, bidirectional=True)
    (proj): Linear(in_features=1024, out_features=100, bias=True)
  )
  (compare): L1()
)
...