Как реализовать набор данных PyTorch для использования с AWS SageMaker? - PullRequest
0 голосов
/ 02 января 2019

Я реализовал PyTorch Dataset, который работает локально (на моем собственном рабочем столе), но при запуске в AWS SageMaker он ломается.Моя реализация Dataset выглядит следующим образом.

class ImageDataset(Dataset):
    def __init__(self, path='./images', transform=None):
        self.path = path
        self.files = [join(path, f) for f in listdir(path) if isfile(join(path, f)) and f.endswith('.jpg')]
        self.transform = transform
        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((128, 128)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])

    def __len__(self):
        return len(files)

    def __getitem__(self, idx):
        img_name = self.files[idx]

        # we may infer the label from the filename
        dash_idx = img_name.rfind('-')
        dot_idx = img_name.rfind('.')
        label = int(img_name[dash_idx + 1:dot_idx])

        image = Image.open(img_name)

        if self.transform:
            image = self.transform(image)

        return image, label

Я следую этому примеру и этому одному , и я запускаю estimator следующим образом.

inputs = {
 'train': 'file://images',
 'eval': 'file://images'
}
estimator = PyTorch(entry_point='pytorch-train.py',
                            role=role,
                            framework_version='1.0.0',
                            train_instance_count=1,
                            train_instance_type=instance_type)
estimator.fit(inputs)

Я получаю следующую ошибку.

FileNotFoundError: [Errno 2] Нет такого файла или каталога: './images'

В следующем примере они загружают набор данных CFAIR (который загружается локально) в S3.

inputs = sagemaker_session.upload_data(path='data', bucket=bucket, key_prefix='data/cifar10')

Если я взгляну на inputs, это просто строковый литерал s3://sagemaker-us-east-3-184838577132/data/cifar10,Код для создания Dataset и DataLoader показан здесь здесь , что не поможет, если я не отыщу источник и не прохожу логику.

Я думаю, что нужнов моем ImageDataset происходит указание пути S3 и использование AWS CLI или чего-то еще для запроса файлов и получения их содержимого.Я не думаю, что AWS CLI является правильным подходом, поскольку он опирается на консоль, и мне придется выполнить некоторые команды подпроцесса, а затем проанализировать их.

Должен быть рецепт или что-то для создания пользовательского Dataset, подкрепленного S3 файлами, верно?

1 Ответ

0 голосов
/ 08 января 2019

Мне удалось создать PyTorch Dataset, подкрепленный данными S3, используя boto3. Вот фрагмент, если кому-то интересно.

class ImageDataset(Dataset):
    def __init__(self, path='./images', transform=None):
        self.path = path
        self.s3 = boto3.resource('s3')
        self.bucket = self.s3.Bucket(path)
        self.files = [obj.key for obj in self.bucket.objects.all()]
        self.transform = transform
        if transform is None:
            self.transform = transforms.Compose([
                transforms.Resize((128, 128)),
                transforms.ToTensor(),
                transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
            ])

    def __len__(self):
        return len(files)

    def __getitem__(self, idx):
        img_name = self.files[idx]

        # we may infer the label from the filename
        dash_idx = img_name.rfind('-')
        dot_idx = img_name.rfind('.')
        label = int(img_name[dash_idx + 1:dot_idx])

        # we need to download the file from S3 to a temporary file locally
        # we need to create the local file name
        obj = self.bucket.Object(img_name)
        tmp = tempfile.NamedTemporaryFile()
        tmp_name = '{}.jpg'.format(tmp.name)

        # now we can actually download from S3 to a local place
        with open(tmp_name, 'wb') as f:
            obj.download_fileobj(f)
            f.flush()
            f.close()
            image = Image.open(tmp_name)

        if self.transform:
            image = self.transform(image)

        return image, label
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...