Кроме предоставления ответа на ваш вопрос, я сделаю код более TF2.0
-подобным.Если у вас есть какие-либо вопросы / требуется разъяснение, пожалуйста, оставьте комментарий ниже.
1.Загрузка данных
Я бы посоветовал использовать библиотеку Tensorflow Datasets .Нет абсолютно никакой необходимости загружать данные в numpy
и преобразовывать их в tf.data.Dataset
, если это можно сделать одной строкой:
import tensorflow_datasets as tfds
dataset = tfds.load("mnist", as_supervised=True, split=tfds.Split.TRAIN)
Строка выше вернет только TRAIN
split (подробнеео тех здесь ).
2.Определить дополнения и сводки
Для сохранения изображений необходимо сохранять объект tf.summar.SummaryWriter на каждом проходе.
Я создал удобный класс переноса с__call__
метод для легкого использования с возможностями tf.data.Dataset
map
:
import tensorflow as tf
class ExampleAugmentation:
def __init__(self, logdir: str, max_images: int, name: str):
self.file_writer = tf.summary.create_file_writer(logdir)
self.max_images: int = max_images
self.name: str = name
self._counter: int = 0
def __call__(self, image, label):
augmented_image = tf.image.random_flip_left_right(
tf.image.random_flip_up_down(image)
)
with self.file_writer.as_default():
tf.summary.image(
self.name,
augmented_image,
step=self._counter,
max_outputs=self.max_images,
)
self._counter += 1
return augmented_image, label
name
будет именем, под которым будет сохраняться каждая часть изображений.Какую часть вы можете спросить - часть, определенная max_outputs
.
Скажем, image
в __call__
будет иметь форму (32, 28, 28, 1)
, где первое измерение - пакетный, вторая ширина, третья высота и последние каналы (в случае MNIST только один, но это измерение необходимо в tf.image
дополнений).Кроме того, допустим, что max_outputs
указано как 4
.В этом случае будут сохранены только 4 первых изображения из серии.Значение по умолчанию 3
, поэтому вы можете установить его как BATCH_SIZE
, чтобы сохранить каждое изображение.
В Tensorboard
каждое изображение будет отдельным образцом, который вы можете повторить в конце.
_counter
необходимо, чтобы изображения не были перезаписан (думаю, не совсем уверен, было бы неплохо получить разъяснения от кого-то другого).
Важно: Возможно, вы захотите переименовать этот класс в что-то вроде ImageSaver
при выполнении более серьезных дели перемещаем увеличение для разделения функторов / лямбда-функций.Я думаю, этого достаточно для презентаций.
3.Установка глобальных переменных
Пожалуйста, не смешивайте объявление функций, глобальные переменные, загрузку данных и другие (например, загрузка данных и создание функции впоследствии).Я знаю, TF1.0
поощрял этот тип программирования, но они пытаются уйти от него, и вы, возможно, захотите следовать тенденции.
Ниже я определил некоторые глобальные переменные, которые будут использоваться в следующих частях, довольноСамо собой разумеется, я думаю:
BATCH_SIZE = 32
DATASET_SIZE = 60000
EPOCHS = 5
LOG_DIR = "/logs/images"
AUGMENTATION = ExampleAugmentation(LOG_DIR, max_images=4, name="Images")
4.Дополнение набора данных
Аналогично вашему, но с небольшим поворотом:
dataset = (
dataset.map(
lambda image, label: (
tf.image.convert_image_dtype(image, dtype=tf.float32),
label,
)
)
.batch(BATCH_SIZE)
.map(AUGMENTATION)
.repeat(EPOCHS)
)
repeat
необходимо, поскольку загруженный набор данных является генератором tf.image.convert_image_dtype
- лучший и более читаемый вариант, чем явный tf.cast
, смешанный с делением на 255
(и обеспечивает надлежащий формат изображения) - пакетирование выполняется перед увеличением только ради представления
5.Определите модель, скомпилируйте, обучите
Почти так же, как вы это сделали в своем примере, но я предоставил дополнительные steps_per_epoch
, поэтому fit
знает, сколько партий составляют эпоху:
model = tf.keras.models.Sequential(
[
tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
tf.keras.layers.Dense(128, activation="relu"),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation="softmax"),
]
)
model.compile(
optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)
model.fit(
dataset,
epochs=EPOCHS,
steps_per_epoch=DATASET_SIZE // BATCH_SIZE,
callbacks=[tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR)],
)
Не так много, чтобы объяснить, кроме этого, я думаю.
6.Запустите Tensorboard
Поскольку TF2.0
можно делать это внутри colab, используя %tensorboard --logdir /logs/images
, просто хотел добавить это для тех, кто может посетить эту проблему.Делайте это как хотите, в любом случае вы точно знаете, как это сделать.
Изображения должны быть внутри IMAGES
, и каждый образец с именем name
предоставляется объекту AUGMENTATION
.
7.Весь код (чтобы облегчить жизнь каждому)
import tensorflow as tf
import tensorflow_datasets as tfds
class ExampleAugmentation:
def __init__(self, logdir: str, max_images: int, name: str):
self.file_writer = tf.summary.create_file_writer(logdir)
self.max_images: int = max_images
self.name: str = name
self._counter: int = 0
def __call__(self, image, label):
augmented_image = tf.image.random_flip_left_right(
tf.image.random_flip_up_down(image)
)
with self.file_writer.as_default():
tf.summary.image(
self.name,
augmented_image,
step=self._counter,
max_outputs=self.max_images,
)
self._counter += 1
return augmented_image, label
if __name__ == "__main__":
# Global settings
BATCH_SIZE = 32
DATASET_SIZE = 60000
EPOCHS = 5
LOG_DIR = "/logs/images"
AUGMENTATION = ExampleAugmentation(LOG_DIR, max_images=4, name="Images")
# Dataset
dataset = tfds.load("mnist", as_supervised=True, split=tfds.Split.TRAIN)
dataset = (
dataset.map(
lambda image, label: (
tf.image.convert_image_dtype(image, dtype=tf.float32),
label,
)
)
.batch(BATCH_SIZE)
.map(AUGMENTATION)
.repeat(EPOCHS)
)
# Model and training
model = tf.keras.models.Sequential(
[
tf.keras.layers.Flatten(input_shape=(28, 28, 1)),
tf.keras.layers.Dense(128, activation="relu"),
tf.keras.layers.Dropout(0.2),
tf.keras.layers.Dense(10, activation="softmax"),
]
)
model.compile(
optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"]
)
model.fit(
dataset,
epochs=EPOCHS,
steps_per_epoch=DATASET_SIZE // BATCH_SIZE,
callbacks=[tf.keras.callbacks.TensorBoard(log_dir=LOG_DIR)],
)