Я пачкаю руки, используя TensorFlow 2.0 для обучения моей модели. Новая функция итерации в tf.data
API довольно крута. Однако, когда я выполнял следующие коды, я обнаружил, что, в отличие от функций итерации в torch.utils.data.DataLoader
, он не перетасовывает данные автоматически в каждую эпоху. Как мне добиться этого с помощью TF2.0?
import numpy as np
import tensorflow as tf
def sample_data():
...
data = sample_data()
NUM_EPOCHS = 10
BATCH_SIZE = 128
# Subsample the data
mask = range(int(data.shape[0]*0.8), data.shape[0])
data_val = data[mask]
mask = range(int(data.shape[0]*0.8))
data_train = data[mask]
train_dset = tf.data.Dataset.from_tensor_slices(data_train).\
shuffle(buffer_size=10000).\
repeat(1).batch(BATCH_SIZE)
val_dset = tf.data.Dataset.from_tensor_slices(data_val).\
batch(BATCH_SIZE)
loss_metric = tf.keras.metrics.Mean(name='train_loss')
optimizer = tf.keras.optimizers.Adam(0.001)
@tf.function
def train_step(inputs):
...
for epoch in range(NUM_EPOCHS):
# Reset the metrics
loss_metric.reset_states()
for inputs in train_dset:
train_step(inputs)
...