tenensflow.keras оценщик train_and_evaluate выдает RuntimeError при оценке - PullRequest
0 голосов
/ 25 апреля 2019

Моя версия тензорного потока: 1.12.0

Я пытаюсь переобучить предварительно обученную сеть Keras ResNet-50 в моем настраиваемом наборе данных (сохраненном как tfrecords).Ниже приведен мой код:

import tensorflow as tf
from tensorflow import keras
import numpy as np

from tensorflow.keras.applications import resnet50

from tensorflow.keras import models, layers, optimizers

def dummy_model(features, num_classes):

    conv1 = layers.Conv2D(32, kernel_size=4, activation='relu')(features)
    pool1 = layers.MaxPooling2D(pool_size=(2, 2))(conv1)
    conv2 = layers.Conv2D(16, kernel_size=4, activation='relu')(pool1)
    pool2 = layers.MaxPooling2D(pool_size=(2, 2))(conv2)
    flat = layers.Flatten()(pool2)
    hidden_layer_1 = layers.Dense(10, activation='relu')(flat)
    logits = layers.Dense(num_classes)(hidden_layer_1) 

    return logits



def resnet50_network(features, num_classes):

    base = resnet50.ResNet50(weights='imagenet', input_tensor=features,include_top=False)
    base_features = base.output    
    flat = layers.Flatten()(base_features)    
    hidden_layer_1 = layers.Dense(1024, activation='relu')(flat)    
    logits = layers.Dense(num_classes)(hidden_layer_1)

    return logits



def model_fn(features, labels, mode, params):

    num_classes = params['num_classes']
    #tf.keras.backend.set_learning_phase(mode==tf.estimator.ModeKeys.TRAIN)
    #logits = dummy_model(features, num_classes)
    logits = resnet50_network(features, num_classes)

    loss_tensor = tf.nn.softmax_cross_entropy_with_logits_v2(
        logits=logits,
        labels=labels)

    loss = tf.reduce_mean(loss_tensor, name='loss')

    if mode==tf.estimator.ModeKeys.EVAL:
        prec, prec_update_op = tf.metrics.precision(labels=tf.argmax(labels), predictions=logits, name='precision_op')
        recall, recall_update_op = tf.metrics.recall(labels=tf.argmax(labels), predictions=logits, name='recall_op')

        metrics = { \
        'recall':(recall, recall_update_op), 
        'precision':(prec, prec_update_op) }

        return tf.estimator.EstimatorSpec(mode, loss=loss, eval_metric_ops=metrics)

    assert mode==tf.estimator.ModeKeys.TRAIN

    optimizer = tf.train.AdamOptimizer(
                 learning_rate=params['learning_rate'],
                 name='Adam')

    global_step = tf.train.get_global_step()

    update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
    with tf.control_dependencies(update_ops):
        train_op = optimizer.minimize(loss, global_step = global_step)

    return tf.estimator.EstimatorSpec(mode, loss=loss, train_op=train_op)

configuration = tf.estimator.RunConfig(
      model_dir = save_directory,
      keep_checkpoint_max=10,
      session_config = sess_config,
      save_checkpoints_steps=1000,
      log_step_count_steps=100)  

classifier = tf.estimator.Estimator(
      model_fn=model_fn, 
      params={ \
      'learning_rate':1e-3,
      'num_classes':20,
      'decay':0.995}, 
      config=configuration
      )


train_spec = tf.estimator.TrainSpec(
      input_fn=lambda:imgs_input_fn(filenames=train_outpath), 
      max_steps=10000) 

eval_spec = tf.estimator.EvalSpec(
      input_fn=lambda:imgs_input_fn(filenames=test_outpath), 
      steps=100)

tf.estimator.train_and_evaluate(classifier, train_spec, eval_spec)

Приведенный выше код работает нормально, когда я использую logits = dummy_model(features, num_classes) вместо logits = resnet50_network(features, num_classes).Когда я запускаю с более поздним (resnet50_network), он выдает

RuntimeError: График завершен и не может быть изменен.

Предполагается более поздняя функция ie resnet50_networkчтобы получить предварительно обученный resnet-50 на imagenet, добавьте плотный слой и верните логин вызывающей стороне.

Первый, т. е. dummy_model создает небольшой CNN для обучения с нуля.

выше только выдает ошибку, если я пытаюсь использовать предварительно обученную модель.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...