Ошибка при попытке использования метода "tff.learning.assign_weights_to_keras_model" - PullRequest
0 голосов
/ 08 апреля 2020

Я хотел бы попробовать этот метод TFF с этим учебником , но я обнаружил ошибку, которую не могу понять. Я использую assign_weight и после этого оцениваю свою модель. Вот мой код:

emnist_train, emnist_test = tff.simulation.datasets.emnist.load_data()
...
def create_compiled_keras_model():
    model = tf.keras.models.Sequential([
    tf.keras.layers.Dense(
          10, activation=tf.nn.softmax, kernel_initializer='zeros', input_shape=(784,))])

    model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
     optimizer=tf.keras.optimizers.SGD(learning_rate=0.02),
     metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
    return model

def model_fn():
    keras_model = create_compiled_keras_model()
    return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

iterative_process = tff.learning.build_federated_averaging_process(model_fn)
state = iterative_process.initialize()
state, metrics = iterative_process.next(state, train_data)

NUM_ROUNDS = 11
for round_num in range(1, NUM_ROUNDS):
    state, metrics = iterative_process.next(state, train_data)
    print('round {:2d}, metrics={}'.format(round_num, metrics))

evaluation = tff.learning.build_federated_evaluation(model_fn)
train_metrics = evaluation(state.model, train_data)


keras_model = create_compiled_keras_model()
keras_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(),
     optimizer=tf.keras.optimizers.SGD(learning_rate=0.02),
     metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
tff.learning.assign_weights_to_keras_model(keras_model, state.model)
centralized_data= emnist_test.create_tf_dataset_from_all_clients()
loss, accuracy = keras_model.evaluate(centralized_data, verbose =1)
print('loss={}, accuracy={}'.format(loss, accuracy))

Сообщение об ошибке:

loss, accuracy = keras_model.evaluate(centralized_data, verbose =1)
ValueError: No data provided for "dense_input". Need data for each key in: ['dense_input']

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...