как правильно использовать _extra_files arg в torch.jit.save - PullRequest
0 голосов
/ 03 июня 2019

Один вариант, который я пробовал, - это выборка Vocab и сохранение с экстрафайлами. Arg

import torch
import pickle

class Vocab(object):
    pass

vocab = Vocab()
pickle.dump(open('path/to/vocab.pkl','w'))

m = torch.jit.ScriptModule()

## I am not sure about the usage of this arg, the docs didn't help me
extra_files = torch._C.ExtraFilesMap()
extra_files['vocab.pkl'] = 'path/to/vocab.pkl'
# I also tried  pickle.dumps(vocab), and directly vocab

torch.jit.save(m, 'scriptmodule.pt', _extra_files=extra_files)

## Load with extra files.
files = {'vocab.pkl': ''}
torch.jit.load('scriptmodule.pt', _extra_files = files)

. Это дает

TypeError: import_ir_module(): incompatible function arguments. The following argument types are supported:
    1. (arg0: Callable[[List[str]], torch._C.ScriptModule], arg1: str, arg2: object, arg3: torch._C.ExtraFilesMap) -> None

. Другой вариант, очевидно, заключается в том, чтобы загружать рассол отдельно, но я искалопция одного файла.

было бы неплохо, если бы можно было просто добавить vocab к torchscript ... также было бы неплохо узнать, если есть какая-то причина, по которой я этого не делаю, о которой я, очевидно, не знаю.

Ответы [ 3 ]

1 голос
/ 17 июля 2019

Я считаю, что документация для torch.jit.load неверна. Вам необходимо создать объект ExtraFilesmap () для загрузки сохраненных файлов.

Ниже приведен пример того, как я начал работать: Шаг 1: Сохранить модель

extra_files = torch._C.ExtraFilesMap()
extra_files['foo.txt'] = 'bar'
traced_script_module.save(serialized_model_path, _extra_files=extra_files)

Шаг 2: Загрузить модель

files = torch._C.ExtraFilesMap()
files['foo.txt'] = ''
loaded_model = torch.jit.load(serialized_model_path, _extra_files=files)
print(files)
0 голосов
/ 02 июля 2019

Предполагая, что vocab является поддерживаемым типом, вы можете добавить его к модели как атрибут TorchScript , чтобы сохранить его вместе с моделью в 1 файле (так что вам не придется иметь дело с * 1004). *).

Тогда ваш код загрузки становится

torch.jit.load('scriptmodule.pt')
0 голосов
/ 03 июня 2019

проблема находится в torch.jit.load.попробуйте проверить ваше map_location

...