Как перетасовать данные в каждую эпоху с помощью API tf.data в TensorFlow 2.0? - PullRequest
1 голос
/ 25 апреля 2019

Я пачкаю руки, используя TensorFlow 2.0 для обучения моей модели. Новая функция итерации в tf.data API довольно крута. Однако, когда я выполнял следующие коды, я обнаружил, что, в отличие от функций итерации в torch.utils.data.DataLoader, он не перетасовывает данные автоматически в каждую эпоху. Как мне добиться этого с помощью TF2.0?

import numpy as np
import tensorflow as tf
def sample_data():
    ...

data = sample_data()

NUM_EPOCHS = 10
BATCH_SIZE = 128

# Subsample the data
mask = range(int(data.shape[0]*0.8), data.shape[0])
data_val = data[mask]
mask = range(int(data.shape[0]*0.8))
data_train = data[mask]

train_dset = tf.data.Dataset.from_tensor_slices(data_train).\
                                 shuffle(buffer_size=10000).\
                                repeat(1).batch(BATCH_SIZE)
val_dset = tf.data.Dataset.from_tensor_slices(data_val).\
                                 batch(BATCH_SIZE)


loss_metric = tf.keras.metrics.Mean(name='train_loss')
optimizer = tf.keras.optimizers.Adam(0.001)

@tf.function
def train_step(inputs):
    ...

for epoch in range(NUM_EPOCHS):
    # Reset the metrics
    loss_metric.reset_states()
    for inputs in train_dset:
        train_step(inputs)
    ...

1 Ответ

1 голос
/ 25 апреля 2019

партия должна быть перетасована:

train_dset = tf.data.Dataset.from_tensor_slices(data_train).\
                                repeat(1).batch(BATCH_SIZE)

train_dset = train_dset.shuffle(buffer_size=buffer_size)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...