Как проанализировать AWS Sagemaker SM_USER_ARGS с argparse в пространстве имен argparse? - PullRequest
0 голосов
/ 12 марта 2020

AWS Sagemaker использует SM_USER_ARGS (как задокументировано здесь ) в качестве переменной среды, в которой он содержит строку (список) аргументов, передаваемых пользователем. Таким образом, значение переменной среды выглядит следующим образом: '["--test_size","0.2","--random_seed","42", "--not_optmize"]'.

С json.loads() я могу преобразовать эту строку в список python. Хотя я хочу создать абстрактный модуль, который возвращает argparse Namespace таким образом, чтобы остальная часть кода оставалась неизменной, независимо от того, запускаю ли я его локально или в службе AWS Sagemaker.

Итак, по сути, мне нужен код, который получает входные данные ["--test_size","0.2","--random_seed","42", "--not_optmize"] и выходные данные Namespace(test_size=0.2, random_seed='42', not_optmize=True, <other_arguments>... ]).

Помогает ли мне пакет python argparse ? Я пытаюсь выяснить способ, которым мне не нужно повторно реализовывать анализатор argparse.

Вот пример, у меня есть этот файл config.ini:

[Docker]
home_dir = /opt
SM_MODEL_DIR = %(home_dir)s/ml/model
SM_CHANNELS = ["training"]
SM_NUM_GPUS = 1
SM_NUM_CPUS =
SM_LOG_LEVEL = 20
SM_USER_ARGS = ["--test_size","0.2","--random_seed","42"]
SM_INPUT_DIR = %(home_dir)s/ml/input
SM_INPUT_CONFIG_DIR = %(home_dir)s/ml/input/config
SM_OUTPUT_DIR = %(home_dir)s/ml/output
SM_OUTPUT_INTERMEDIATE_DIR = %(home_dir)s/ml/output/intermediate

У меня есть этот класс Argparser:

import argparse
import configparser
import datetime
import json
import multiprocessing
import os
import time
from pathlib import Path
from typing import Any, Dict

from .files import JsonFile, YAMLFile


class ArgParser(ABC):

    @abstractmethod
    def get_arguments(self) -> Dict[str, Any]:
        pass


class AWSArgParser(ArgParser):

    def __init__(self):
        configuration_file_path = 'config.ini'

        self.environment = "Sagemaker" \
            if os.environ.get("SM_MODEL_DIR", False) \
            else os.environ.get("ENVIRON", "Default")

        config = configparser.ConfigParser()
        config.read(configuration_file_path)
        if self.environment == "Local":
            config[self.environment]["home_dir"] = str(pathlib.Path(__file__).parent.absolute())
        if self.environment != 'Sagemaker':
            config[self.environment]["SM_NUM_CPUS"] = str(multiprocessing.cpu_count())

        for key, value in config[self.environment].items():
            os.environ[key.upper()] = value

        self.parser = argparse.ArgumentParser()
        # AWS Sagemaker default environmental arguments
        self.parser.add_argument(
            '--model_dir',
            type=str,
            default=os.environ['SM_MODEL_DIR'],
        )
        self.parser.add_argument(
            '--channel_names',
            default=json.loads(os.environ['SM_CHANNELS']),
        )
        self.parser.add_argument(
            '--num_gpus',
            type=int,
            default=os.environ['SM_NUM_GPUS'],
        )
        self.parser.add_argument(
            '--num_cpus',
            type=int,
            default=os.environ['SM_NUM_CPUS'],
        )
        self.parser.add_argument(
            '--user_args',
            default=json.loads(os.environ['SM_USER_ARGS']),
        )
        self.parser.add_argument(
            '--input_dir',
            type=str,
            default=os.environ['SM_INPUT_DIR'],
        )
        self.parser.add_argument(
            '--input_config_dir',
            type=Path,
            default=os.environ['SM_INPUT_CONFIG_DIR'],
        )
        self.parser.add_argument(
            '--output_dir',
            type=Path,
            default=os.environ['SM_OUTPUT_DIR'],
        )

        # Extra arguments
        self.run_tag = datetime.datetime \
            .fromtimestamp(time.time()) \
            .strftime('%Y-%m-%d-%H-%M-%S')
        self.parser.add_argument(
            '--run_tag',
            default=self.run_tag,
            type=str,
            help=f"Run tag (default: 'datetime.fromtimestamp')",
        )
        self.parser.add_argument(
            '--environment',
            type=str,
            default=self.environment,
        )

        self.args = self.parser.parse_args()

    def get_arguments(self) -> Dict[str, Any]:
        <parse self.args.user_args>

        return self.args

, тогда у меня есть сценарий train:

from utils.arg_parser import AWSArgParser

if __name__ == '__main__':
    logger.info(f"Begin train.py")

    if os.environ["ENVIRON"] == "Sagemaker":
        arg_parser = AWSArgParser()
        args = arg_parser.get_arguments()
    else:
        args = <normal local parse>

1 Ответ

0 голосов
/ 12 марта 2020

После комментария @ chepner пример решения будет выглядеть примерно так:

config.ini file:

[Docker]
home_dir = /opt
SM_MODEL_DIR = %(home_dir)s/ml/model
SM_CHANNELS = ["training"]
SM_NUM_GPUS = 1
SM_NUM_CPUS =
SM_LOG_LEVEL = 20
SM_USER_ARGS = ["--test_size","0.2","--random_seed","42", "--not_optimize"]
SM_INPUT_DIR = %(home_dir)s/ml/input
SM_INPUT_CONFIG_DIR = %(home_dir)s/ml/input/config
SM_OUTPUT_DIR = %(home_dir)s/ml/output
SM_OUTPUT_INTERMEDIATE_DIR = %(home_dir)s/ml/output/intermediate

A TrainArgParser class like this:

class ArgParser(ABC):

    @abstractmethod
    def get_arguments(self) -> Dict[str, Any]:
        pass


class TrainArgParser(ArgParser):

    def __init__(self):
        configuration_file_path = 'config.ini'

        self.environment = "Sagemaker" \
            if os.environ.get("SM_MODEL_DIR", False) \
            else os.environ.get("ENVIRON", "Default")

        config = configparser.ConfigParser()
        config.read(configuration_file_path)
        if self.environment == "Local":
            config[self.environment]["home_dir"] = str(pathlib.Path(__file__).parent.absolute())
        if self.environment != 'Sagemaker':
            config[self.environment]["SM_NUM_CPUS"] = str(multiprocessing.cpu_count())

        for key, value in config[self.environment].items():
            os.environ[key.upper()] = value

        self.parser = argparse.ArgumentParser()
        # AWS Sagemaker default environmental arguments
        self.parser.add_argument(
            '--model_dir',
            type=str,
            default=os.environ['SM_MODEL_DIR'],
        )
        self.parser.add_argument(
            '--channel_names',
            default=json.loads(os.environ['SM_CHANNELS']),
        )
        self.parser.add_argument(
            '--num_gpus',
            type=int,
            default=os.environ['SM_NUM_GPUS'],
        )
        self.parser.add_argument(
            '--num_cpus',
            type=int,
            default=os.environ['SM_NUM_CPUS'],
        )
        self.parser.add_argument(
            '--user_args',
            default=json.loads(os.environ['SM_USER_ARGS']),
        )
        self.parser.add_argument(
            '--input_dir',
            type=str,
            default=os.environ['SM_INPUT_DIR'],
        )
        self.parser.add_argument(
            '--input_config_dir',
            type=Path,
            default=os.environ['SM_INPUT_CONFIG_DIR'],
        )
        self.parser.add_argument(
            '--output_dir',
            type=Path,
            default=os.environ['SM_OUTPUT_DIR'],
        )

        # Extra arguments
        self.run_tag = datetime.datetime \
            .fromtimestamp(time.time()) \
            .strftime('%Y-%m-%d-%H-%M-%S')
        self.parser.add_argument(
            '--run_tag',
            default=self.run_tag,
            type=str,
            help=f"Run tag (default: 'datetime.fromtimestamp')",
        )
        self.parser.add_argument(
            '--environment',
            type=str,
            default=self.environment,
        )

        self.args = self.parser.parse_args()

    def get_arguments(self) -> Dict[str, Any]:
        # Not in AWS Sagemaker arguments
        self.parser.add_argument(
            '--test_size',
            default=0.2,
            type=float,
            help="Test dataset size (default: '0.2')",
        )
        self.parser.add_argument(
            '--random_seed',
            default=42,
            type=int,
            help="Random number for initialization (default: '42')",
        )
        self.parser.add_argument(
            '--secrets',
            type=YAMLFile.parse_string,
            default='',
            help="An yaml formated string (default: '')"
        )
        self.parser.add_argument(
            '--bucket_name',
            type=str,
            default='',
            help="Bucket name of a remote storage (default: '')"
        )
        self.args = self.parser.parse_args(self.args.user_args)

        return self.args

и запись_текста для train будет начинаться так:

#!/usr/bin/env python

from utils.arg_parser import TrainArgParser

if __name__ == '__main__':
    logger.info(f"Begin train.py")

    arg_parser = TrainArgParser()
    args = arg_parser.get_arguments()
    print(args)

Это должно вывести что-то вроде этого:

Namespace(bucket_name='', channel_names=['training'], environment='Docker', input_config_dir=PosixPath('/opt/ml/input/config'), input_dir='/opt/ml/input', model_dir='/opt/ml/model', num_cpus=8, num_gpus=1, output_dir=PosixPath('/opt/ml/output'), random_seed=42, run_tag='2020-03-11-22-18-21', secrets={}, test_size=0.2, user_args=['--test_size', '0.2', '--random_seed', '42'])

Но это бесполезно, если AWS Sagemaker рассматривает SM_USER_ARGS и SM_HPS как одно и то же. (

...