Как оптимизировать сценарий вывода, чтобы получить более быстрый прогноз классификатора? - PullRequest
0 голосов
/ 28 декабря 2018

Я написал следующий код предсказания, который предсказывает на основе обученной модели классификатора.Теперь время прогнозирования составляет около 40 с, и я хочу максимально сократить его.

Могу ли я выполнить какую-либо оптимизацию для своего сценария вывода или мне нужно искать разработки в сценарии обучения?

import torch
import torch.nn as nn
from torchvision.models import resnet18
from torchvision.transforms import transforms
import matplotlib.pyplot as plt
import numpy as np
from torch.autograd import Variable
import torch.functional as F
from PIL import Image
import os
import sys
import argparse
import time 
import json

parser = argparse.ArgumentParser(description = 'To Predict from a trained model')

parser.add_argument('-i','--image', dest = 'image_name', required = True, help='Path to the image file')
args = parser.parse_args()

def predict_image(image_path):
    print("prediciton in progress")
    image = Image.open(image_path)

    transformation = transforms.Compose([
        transforms.RandomResizedCrop(224),
        transforms.ToTensor(),
        transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])

    image_tensor = transformation(image).float()
    image_tensor = image_tensor.unsqueeze_(0)

    if cuda:
        image_tensor.cuda()

    input = Variable(image_tensor)
    output = model(input)

    index = output.data.numpy().argmax()
    return index

def parameters():
    hyp_param = open('param_predict.txt','r')
    param = {}
    for line in hyp_param:
        l = line.strip('\n').split(':')

def class_mapping(index):
    with open("class_mapping.json") as cm:
        data = json.load(cm)
    if index == -1:
        return len(data)
    else:
        return data[str(index)]

def segregate():
    with open("class_mapping.json") as cm:
        data = json.load(cm)
    try:
        os.mkdir(seg_dir)
        print("Directory " , seg_dir ,  " Created ") 
    except OSError:
        print("Directory " , seg_dir ,  " already created")
    for x in range (0,len(data)):
        dir_path="./"+seg_dir+"/"+data[str(x)]
        try:
            os.mkdir(dir_path)
            print("Directory " , dir_path ,  " Created ") 
        except OSError:
            print("Directory " , dir_path ,  " already created")


path_to_model = "./models/"+'trained.model'
checkpoint = torch.load(path_to_model)
seg_dir="segregation_folder"

cuda = torch.cuda.is_available()

num_class = class_mapping(index=-1)
print num_class
model = resnet18(num_classes = num_class)

if cuda:
    model.load_state_dict(checkpoint)
else:
    model.load_state_dict(checkpoint, map_location = 'cpu')

model.eval()

if __name__ == "__main__":

    imagepath = "./Predict_Image/"+args.image_name
    since = time.time()
    img = Image.open(imagepath)
    prediction = predict_image(imagepath)
    name = class_mapping(prediction)
    print("Time taken = ",time.time()-since)

    print("Predicted Class: ",name)

Весь проект можно найти на https://github.com/amrit-das/custom_image_classifier_pytorch/

1 Ответ

0 голосов
/ 28 декабря 2018

Без вывода из вашего профилировщика трудно сказать, сколько из этого из-за неэффективности в вашем коде.Тем не менее, PyTorch имеет много накладных расходов при запуске - другими словами, он медленно инициализирует библиотеку, модель, веса загрузки и переносит ее в графический процессор по сравнению со временем вывода на одном изображении.Это делает его довольно плохим в качестве утилиты CLI для прогнозирования по одному изображению.

Если ваш вариант использования действительно требует работы с отдельными изображениями вместо пакетной обработки, то для оптимизации не так много возможностей.Я вижу два варианта:

  1. Это может стоить того, чтобы вообще пропустить выполнение GPU и сэкономить на выделении и передаче GPU.
  2. Вы получите лучшую производительность записиэтот код в C ++, используя LibTorch .Хотя это большая работа по разработке.
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...