Есть ли способ увеличить размер набора данных с помощью меток с помощью увеличения данных? - PullRequest
0 голосов
/ 12 июля 2020

Я пытаюсь реализовать логистическую c регрессию на Kaggle di git распознавание набор данных . В наборе поездов 42000 строк, и я хочу увеличить количество с помощью увеличения данных.

Я пробовал использовать объект keras ImageDataGenerator

datagen = ImageDataGenerator(  
        rotation_range=30,   
        zoom_range = 0.2,  
        width_shift_range=0.2,         
        height_shift_range=0.2)

datagen.fit(X_train)

, но размер остался прежним, я позже обнаружил, что ImageDataGenerator на самом деле не добавляет строки, а вставляет расширенные данные во время обучения. Есть ли другой инструмент для сохранения или увеличения данных с такими же ярлыками?

1 Ответ

1 голос
/ 12 июля 2020

Вот как я в итоге сохранил расширенные данные с метками. Я отобрал 5 рядов для удовольствия от просмотра. И for l oop может быть не лучшим способом записи в массив, когда рассматривается полный набор данных

#importing data
train = pd.read_csv("train.csv")
X_train = train.drop(labels=["label"], axis=1)
y_train = train.label

#sampling 5 rows and reshaping x to 4D array
x = X_train[0:5].values.reshape(-1,28,28,1)
y = y_train[0:5]

#Augmentation parameters
from keras_preprocessing.image import ImageDataGenerator
datagen = ImageDataGenerator(  
        rotation_range=30,   
        zoom_range = 0.2,  
        width_shift_range=0.2,  
        height_shift_range=0.2,  
        )  

#using .flow instead of .fit to write to an array
augmented_data = []
num_augmented = 0
batch = 5  # for 5*5 = 25 entries
for X_batch, y_batch in datagen.flow(X_2, y, batch_size=batch, shuffle=False,):
    augmented_data.append(X_batch)
    augmented_labels.append(y_batch)
    num_augmented += 1
    if num_augmented == x.shape[0]:
        break
augmented_data = np.concatenate(augmented_data) #final shape = (25,28,28,1)
augmented_labels = np.concatenate(augmented_labels)


#Lets take a look at augmented images
for index, image in enumerate(augmented_data):
    plt.subplot(5, 5, index + 1)
    plt.imshow(np.reshape(image, (28,28)), cmap=plt.cm.gray)


# reshaping and converting to df
augmented_data_reshaped = augmented_data.reshape(25, 784)
augmented_dataframe = pd.DataFrame(augmented_data_reshaped)
# inserting labels in df
augmented_dataframe.insert(0, "label", augmented_labels)
header = list(train.columns.values)
augmented_dataframe.columns = header
# write
augmented_dataframe.to_csv("augmented.csv")

Данные с расширенными цифрами

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...