Практическое задание AWS Компьютерное зрение: get_Cifar10_dataset - PullRequest
0 голосов
/ 21 июня 2020

У меня проблема с этим методом, который должен возвращать как обучающий, так и проверочный набор данных и проверять его, чтобы вернуть индекс, который соответствует первому вхождению каждого класса в CIFAR10.

это код: def get_cifar10_dataset(): """ Должен создать сеть cifar 10 и определить индекс набора данных при первом появлении каждого нового класса

:return: tuple of training and validation dataset as well as label indices
:rtype: (gluon.data.Dataset, 'dict_values' object is not subscriptable, gluon.data.Dataset, 
 dict[int:int])
"""

train_data = None
val_data = None
# YOUR CODE HERE
train_data = datasets.CIFAR10(train=True, root=M5_IMAGES)
val_data = datasets.CIFAR10(train=False, root=M5_IMAGES)

Ответы [ 2 ]

0 голосов
/ 18 августа 2020

Вы можете сделать это

def get_cifar10_dataset():
    """
    Should create the cifar 10 network and identify the dataset index of the first time each new class appears
    
    :return: tuple of training and validation dataset as well as label indices
    :rtype: (gluon.data.Dataset, dict[int:int], gluon.data.Dataset, dict[int:int])
    """
    train_data = None
    val_data = None
    train_indices = {}
    val_indices = {}
    
    # Use `root=M5_IMAGES` for your dataset
    train_data = gluon.data.vision.datasets.CIFAR10(train=True, root=M5_IMAGES)
    val_data   = gluon.data.vision.datasets.CIFAR10(train=False, root=M5_IMAGES)
    
    #for train
    for i in range(len(train_data)):
        if train_data[i][1] not in train_indices:
            train_indices[train_data[i][1]] = i
    #for valid
    for i in range(len(val_data)):
        if val_data[i][1] not in val_indices:
            val_indices[val_data[i][1]] = i
    
    #raise NotImplementedError()
    
    return train_data, train_indices, val_data, val_indices


0 голосов
/ 29 июня 2020

Вас просят вернуть словарь с метками и соответствующими индексами. Использование следующей функции может решить вашу проблему.

def get_idx_dict(data):

    lis = []
    idx = []
    indices = {}
    
    for i in range(len(data)):
        if data[i][1] not in lis:
            lis.append(data[i][1])
            idx.append(i)
            
    indices = {lis[i]: idx[i] for i in range(len(lis))}
    return indices

Функция возвращает словарь с желаемым результатом. Используйте эту функцию для данных из набора для обучения и проверки.

train_indices = get_idx_dict(train_data)
val_indices = get_idx_dict(val_data)
...