Создание предсказания Sagemaker Pytorch - PullRequest
0 голосов
/ 05 июня 2019

Я обучил и развернул модель в Pytorch с Sagemaker.Я могу позвонить на конечную точку и получить прогноз.Я использую функцию input_fn () по умолчанию (то есть не определенную в моем serve.py).

model = PyTorchModel(model_data=trained_model_location,
                     role=role,
                     framework_version='1.0.0',
                     entry_point='serve.py',
                     source_dir='source')
predictor = model.deploy(initial_instance_count=1, instance_type='ml.m4.xlarge')

Прогноз может быть сделан следующим образом:

input ="0.12787057,  1.0612601,  -1.1081504"
predictor.predict(np.genfromtxt(StringIO(input), delimiter=",").reshape(1,3) )

Я хочу бытьможет обслуживать модель с помощью REST API и использовать HTTP POST с использованием лямбда и шлюза API.Я смог использовать invoke_endpoint () для этого с моделью XGBOOST в Sagemaker таким образом.Я не уверен, что отправлять в тело для Pytorch.

client = boto3.client('sagemaker-runtime')
response = client.invoke_endpoint(EndpointName=ENDPOINT  ,
ContentType='text/csv',
Body=???)

Мне кажется, мне нужно понять, как написать клиенту input_fn, чтобы он принимал и обрабатывал данные, которые я могу отправить через invoke_client.Я на правильном пути, и если да, то как можно написать input_fn для приема csv из invoke_endpoint?

1 Ответ

1 голос
/ 06 июня 2019

Да, вы на правильном пути.Вы можете отправить csv-сериализованный ввод в конечную точку, не используя predictor из SageMaker SDK и не используя другие SDK, такие как boto3, который установлен в лямбда-выражении:

import boto3
runtime = boto3.client('sagemaker-runtime')

payload = '0.12787057,  1.0612601,  -1.1081504'

response = runtime.invoke_endpoint(
    EndpointName=ENDPOINT_NAME,
    ContentType='text/csv',
    Body=payload.encode('utf-8'))

result = json.loads(response['Body'].read().decode()) 

Это будет переданоконечная точка ввод в формате csv, который вам может понадобиться изменить обратно в input_fn, чтобы вставить соответствующее измерение, ожидаемое моделью.

, например:

def input_fn(request_body, request_content_type):
    if request_content_type == 'text/csv':
        return torch.from_numpy(
            np.genfromtxt(StringIO(request_body), delimiter=',').reshape(1,3))

Примечание : я не смог протестировать конкретный input_fn выше с вашим входным содержимым и формой, но пару раз использовал подход на Sklearn RandomForest и смотрел на Pytorch SageMaker, обслуживающий документ приведенное выше обоснование должно работать.

Не стесняйтесь использовать журналы конечных точек в Cloudwatch для диагностики любой ошибки вывода (доступной из пользовательского интерфейса конечной точки в консоли), эти журналы обычно гораздо более многословны что высокоуровневые журналы возвращаются SDK логического вывода

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...