tenorflow 2.0, model.fit (): в вашем вводе закончились данные - PullRequest
1 голос
/ 14 февраля 2020

Я абсолютно новичок в TensorFlow и Keras, и я пытаюсь найти способ пробовать некоторый код, который я нахожу в Интернете.

В частности, я использую fashion-MNIST, состоящий из 60000 примеров и набор тестов из 10000 примеров. Каждый из них представляет собой изображение в градациях серого 28x28.

Я следую этому уроку "https://towardsdatascience.com/building-your-first-neural-network-in-tensorflow-2-tensorflow-for-hackers-part-i-e1e2f1dfe7a0", и у меня нет проблем до определения

history = model.fit(
train_dataset.repeat(), 
epochs=10, 
steps_per_epoch=500,
validation_data=val_dataset.repeat(), 
validation_steps=2)

Пока я понял, мне нужно используйте train_dataset.repeat () в качестве входного набора данных, потому что в противном случае мне не хватит учебного примера, использующего эти значения для гиперпараметров (epochs, steps_per_epochs).

Мой вопрос: как мне избежать использования .repeat () ? Как мне нужно изменить гиперпараметры?

Я копирую код здесь, для простоты:

def preprocess(x,y):

    x = tf.cast(x,tf.float32) / 255.0
    y = tf.cast(y, tf.float32)

    return x,y 

def create_dataset(xs, ys, n_classes=10):

    ys = tf.one_hot(ys, depth=n_classes)

    return tf.data.Dataset.from_tensor_slices((xs, ys)).map(preprocess).shuffle(len(ys)).batch(128)


model.compile(optimizer = 'adam', loss =tf.losses.CategoricalCrossentropy(from_logits= True), metrics =['accuracy'])

history1 = model.fit(train_dataset.repeat(), 
                    epochs=10, 
                    steps_per_epoch=500,
                    validation_data=val_dataset.repeat(), 
                    validation_steps=2)

Спасибо!

1 Ответ

0 голосов
/ 14 февраля 2020

Если вы не хотите использовать .repeat (), вам нужно, чтобы ваша модель передавала все ваши данные только один раз за эпоху.

Для этого вам нужно рассчитать, сколько шагов Ваша модель пройдет через весь набор данных, вычисление будет простым:

steps_per_epoch = len(train_dataset) // batch_size

Таким образом, с набором train_datase из 60 000 сэмплов и размером batch_size из 128 вам нужно иметь 468 шагов за эпоху.

Устанавливая этот параметр таким образом, вы убедитесь, что вы не превышаете размер своего набора данных.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...