tf.keras: оценка разрывов model.updates при использовании tf.data.Dataset в качестве входных данных - PullRequest
0 голосов
/ 10 февраля 2019

Примечание. Весь код для автономного примера, воспроизводящего мою проблему, можно найти ниже.

У меня есть экземпляр tf.keras.models.Model (), и я хотел бы обучить его с помощью пользовательскогонизкоуровневый учебный цикл API TensorFlow.В рамках этого цикла обучения мне нужно убедиться, что мой пользовательский цикл обучения обновляет все переменные с состоянием из типов слоев, таких как tf.keras.layers.BatchNormalization.Чтобы это произошло, я понимаю из этого ответа Франсуа Шоле, что мне нужно оценивать model.updates на каждом этапе обучения.

Проблема в том, что это работает, когда вы кормите своегообучение данных для модели с использованием feed_dict, но это не работает, когда вы используете tf.data.Dataset объект.

Рассмотрим следующий абстрактный пример (вы можете найти конкретный пример, чтобы воспроизвести проблему ниже):

model = tf.keras.models.Model(...) # Some tf.keras model
dataset = tf.data.Dataset.from_tensor_slices(...) # Some tf.data.Dataset
iterator = dataset.make_one_shot_iterator()
features, labels = iterator.get_next()

model_output = model(features)

with tf.Session() as sess:
    ret = sess.run(model.updates)

Этот вызов sess.run() выдает ошибку

InvalidArgumentError: You must feed a value for placeholder tensor 'input_1' with dtype float and shape [?,224,224,3]

Эта ошибка, очевидно, не должна возникать.Мне не нужно указывать значение для заполнителя input_1, потому что я вызываю свою модель для tf.data.Dataset, а не для ввода входных данных в заполнитель через feed_dict.

Что можетЯ делаю, чтобы сделать эту работу?

Вот полностью воспроизводимый пример.Это простой классификатор изображений, обучаемый на Caltech256 (загрузите файлы TFRecord, используя ссылку внизу этого поста):

import tensorflow as tf
from tqdm import trange
import sys
import glob
import os

sess = tf.Session()
tf.keras.backend.set_session(sess)

num_classes = 257
image_size = (224, 224, 3)

# Build a simple CNN with BatchNorm layers.

input_tensor = tf.keras.layers.Input(shape=image_size)
x = tf.keras.layers.Conv2D(64, (3,3), strides=(2,2), kernel_initializer='he_normal')(input_tensor)
x = tf.keras.layers.BatchNormalization(axis=3)(x)
x = tf.keras.layers.Activation('relu')(x)
x = tf.keras.layers.Conv2D(64, (3,3), strides=(2,2), kernel_initializer='he_normal')(x)
x = tf.keras.layers.BatchNormalization(axis=3)(x)
x = tf.keras.layers.Activation('relu')(x)
x = tf.keras.layers.Conv2D(128, (3,3), strides=(2,2), kernel_initializer='he_normal')(x)
x = tf.keras.layers.BatchNormalization(axis=3)(x)
x = tf.keras.layers.Activation('relu')(x)
x = tf.keras.layers.Conv2D(256, (3,3), strides=(2,2), kernel_initializer='he_normal')(x)
x = tf.keras.layers.BatchNormalization(axis=3)(x)
x = tf.keras.layers.Activation('relu')(x)
x = tf.keras.layers.GlobalAveragePooling2D()(x)
x = tf.keras.layers.Dense(num_classes, activation='softmax', kernel_initializer='he_normal')(x)
model = tf.keras.models.Model(input_tensor, x)

# We'll monitor whether the moving mean and moving variance of the first BatchNorm layer is being updated as it should.
moving_mean = tf.reduce_mean(model.layers[2].moving_mean)
moving_variance = tf.reduce_mean(model.layers[2].moving_variance)

# Build a tf.data.Dataset from TFRecords.

tfrecord_directory = '/path/to/the/tfrecord/files/'

tfrecord_filennames = glob.glob(os.path.join(tfrecord_directory, '*.tfrecord'))

feature_schema = {'image': tf.FixedLenFeature([], tf.string),
                  'filename': tf.FixedLenFeature([], tf.string),
                  'label': tf.FixedLenFeature([], tf.int64)}

dataset = tf.data.Dataset.from_tensor_slices(tfrecord_filennames)
dataset = dataset.shuffle(len(tfrecord_filennames)) # Shuffle the TFRecord file names.
dataset = dataset.flat_map(lambda filename: tf.data.TFRecordDataset(filename))
dataset = dataset.map(lambda single_example_proto: tf.parse_single_example(single_example_proto, feature_schema)) # Deserialize tf.Example objects.
dataset = dataset.map(lambda sample: (sample['image'], sample['label']))
dataset = dataset.map(lambda image, label: (tf.image.decode_jpeg(image, channels=3), label)) # Decode JPEG images.
dataset = dataset.map(lambda image, label: (tf.image.resize_image_with_pad(image, target_height=image_size[0], target_width=image_size[1]), label))
dataset = dataset.map(lambda image, label: (tf.image.per_image_standardization(image), label))
dataset = dataset.map(lambda image, label: (image, tf.one_hot(indices=label, depth=num_classes))) # Convert labels to one-hot format.
dataset = dataset.shuffle(buffer_size=10000)
dataset = dataset.repeat()
dataset = dataset.batch(32)

iterator = dataset.make_one_shot_iterator()
batch_features, batch_labels = iterator.get_next()

# Build the training-relevant part of the graph.

model_output = model(batch_features)

loss = tf.reduce_mean(tf.keras.backend.categorical_crossentropy(target=batch_labels, output=model_output, from_logits=False))

train_step = tf.train.AdamOptimizer().minimize(loss)

# The next block is for the metrics.
with tf.variable_scope('metrics') as scope:
    predictions_argmax = tf.argmax(model_output, axis=-1, output_type=tf.int64)
    labels_argmax = tf.argmax(batch_labels, axis=-1, output_type=tf.int64)
    mean_loss_value, mean_loss_update_op = tf.metrics.mean(loss)
    acc_value, acc_update_op = tf.metrics.accuracy(labels=labels_argmax, predictions=predictions_argmax)
    local_metric_vars = tf.contrib.framework.get_variables(scope=scope, collection=tf.GraphKeys.LOCAL_VARIABLES)
    metrics_reset_op = tf.variables_initializer(var_list=local_metric_vars, name='metrics_reset_op')

# Run the training.

epochs = 3
steps_per_epoch = 1000

fetch_list = [mean_loss_value,
              acc_value,
              moving_mean,
              moving_variance,
              train_step,
              mean_loss_update_op,
              acc_update_op] + model.updates

sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())

with sess.as_default():

    for epoch in range(1, epochs+1):

        tr = trange(steps_per_epoch, file=sys.stdout)
        tr.set_description('Epoch {}/{}'.format(epoch, epochs))

        sess.run(metrics_reset_op)

        for train_step in tr:

            ret = sess.run(fetches=fetch_list, feed_dict={tf.keras.backend.learning_phase(): 1})

            tr.set_postfix(ordered_dict={'loss': ret[0],
                                         'accuracy': ret[1],
                                         'bn1 moving mean': ret[2],
                                         'bn1 moving variance': ret[3]})

Запуск этого кода приводит к ошибке, описанной выше:

InvalidArgumentError: You must feed a value for placeholder tensor 'input_1' with dtype float and shape [?,224,224,3]

Очень дерьмовый обходной путь, позволяющий обойти эту проблему, состоит в том, чтобы выбрать следующий пакет с помощью отдельного вызова sess.run() и затем передать извлеченные массивы Numpy во второй вызов sess.run() с помощью feed_dict.Это работает, но, очевидно, частично отрицает цель использования tf.data API:

# Build the training-relevant part of the graph.

labels = tf.placeholder(dtype=tf.float32, shape=(None, num_classes), name='labels')

loss = tf.reduce_mean(tf.keras.backend.categorical_crossentropy(target=labels, output=model.output, from_logits=False))

train_step = tf.train.AdamOptimizer().minimize(loss)

with tf.variable_scope('metrics') as scope:
    predictions_argmax = tf.argmax(model.output, axis=-1, output_type=tf.int64)
    labels_argmax = tf.argmax(labels, axis=-1, output_type=tf.int64)
    mean_loss_value, mean_loss_update_op = tf.metrics.mean(loss)
    acc_value, acc_update_op = tf.metrics.accuracy(labels=labels_argmax, predictions=predictions_argmax)
    local_metric_vars = tf.contrib.framework.get_variables(scope=scope, collection=tf.GraphKeys.LOCAL_VARIABLES)
    metrics_reset_op = tf.variables_initializer(var_list=local_metric_vars, name='metrics_reset_op')

# Run the training. With BatchNorm.

epochs = 3
steps_per_epoch = 1000

fetch_list = [mean_loss_value,
              acc_value,
              moving_mean,
              moving_variance,
              train_step,
              mean_loss_update_op,
              acc_update_op] + model.updates

sess.run(tf.global_variables_initializer())
sess.run(tf.local_variables_initializer())

with sess.as_default():

    for epoch in range(1, epochs+1):

        tr = trange(steps_per_epoch, file=sys.stdout)
        tr.set_description('Epoch {}/{}'.format(epoch, epochs))

        sess.run(metrics_reset_op)

        for train_step in tr:

            b_images, b_labels = sess.run([batch_features, batch_labels])

            ret = sess.run(fetches=fetch_list, feed_dict={tf.keras.backend.learning_phase(): 1,
                                                          model.input: b_images,
                                                          labels: b_labels})

            tr.set_postfix(ordered_dict={'loss': ret[0],
                                         'accuracy': ret[1],
                                         'bn1 moving mean': ret[2],
                                         'bn1 moving variance': ret[3]})

Как уже упоминалось выше, это просто плохой обходной путь.Как я могу заставить это работать должным образом?

Вы можете скачать файлы TFRecord здесь .

1 Ответ

0 голосов
/ 23 февраля 2019

Проблема в этой строке:

model_output = model(batch_features)

Обычно нормально назвать модель на тензоре, но в этом случае это вызывает проблемы.Когда модель была создана, ее входной слой создал тензор-заполнитель, который нужно подавать при вызове model.updates.Вместо того, чтобы вызывать модель для тензора batch_features, вы должны вместо этого установить для входного слоя модели значение batch_features (вместо создания заполнителя) при его создании.То есть вам нужно установить правильный ввод при создании экземпляра модели, после этого будет слишком поздно.Это делается так:

input_tensor = tf.keras.layers.Input(tensor=batch_features)

Теперь работает model.updates отлично работает.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...