Проблемы с преобразованием весов PyTorch (файл .pth) в ONNX - PullRequest
0 голосов
/ 29 февраля 2020

Я пытаюсь объединить различные учебники, которые я прочитал в Интернете. Моя цель - иметь трекер людей и использовать его с разными предварительно обученными моделями. Я использовал здесь код для трекера / счетчика: https://www.pyimagesearch.com/2018/08/13/opencv-people-counter/

Они используют модель Caffe MobileNetSSD. Я хочу использовать Re sNet, Yolov3 и в основном любую модель, которую стоит попробовать, чтобы получить лучшие результаты обнаружения. Поэтому я нашел AlignedReId и репо с моделью: https://github.com/huanghoujing/AlignedReID-Re-Production-Pytorch

Я пытался читать напрямую с весов Re sNet в репозитории AlignedReId, используя cv2.dnn.readNetFromTorch(model_weight.pth), но я нашел из-за того, что OpenCV не поддерживает PyTorch напрямую, поэтому я должен преобразовать модель (ONNX был наиболее распространенным решением). Проблема в том, что я не знаю, как это сделать, я застрял на этапе, когда мне нужно вызвать ModelClass из этого урока: https://michhar.github.io/convert-pytorch-onnx/

I ' Я использую pytorch-cpu, и в моем случае это будет что-то вроде:

import torch
import torch.onnx

# A model class instance (class not shown)
model = ResNet50()

# Load the weights from a file (.pth usually)
state_dict = torch.load('model_weight.pth')

# Load the weights now into a model net architecture defined by our class
model.load_state_dict(state_dict)

# Create the right input shape (e.g. for an image)
dummy_input = torch.randn(1, 3, 256, 128)

torch.onnx.export(model, dummy_input, "onnx_resnet50.onnx")

Но я не знаю, как получить этот класс ResNet50 из репозитория AlignedReId и использовать его. Кроме того, в части torch.randn(1, 3, 256, 128) я использую эти числа, потому что в репозитории AlignedReId они заявляют, что использовали этот размер изображения, но я не уверен, что sample_batch_size равен 1 (я нашел это число во многих вопросах о преобразовании pth моделей в файлы onnx).

Как преобразовать предварительно обученную модель AlignedReId в ONNX, чтобы читать ее с OpenCV и использовать для отслеживания людей в первой ссылке?

...