Я хочу сохранить федеративную модель TensorFlow, которая была обучена с помощью алгоритма FedAvg, в качестве модели Keras / .h5. Я не смог найти документы по этому вопросу и хотел бы знать, как это можно сделать. Также, если возможно, я хотел бы иметь доступ как к модели агрегированного сервера, так и к моделям клиентов.
Код, который я использую для обучения федеративной модели, приведен ниже:
def model_fn():
model = tf.keras.models.Sequential([
tf.keras.layers.Input(shape=(segment_size,num_input_channels)),
tf.keras.layers.Flatten(),
tf.keras.layers.Dense(units=400, activation='relu'),
tf.keras.layers.Dropout(dropout_rate),
tf.keras.layers.Dense(units=100, activation='relu'),
tf.keras.layers.Dropout(dropout_rate),
tf.keras.layers.Dense(activityCount, activation='softmax'),
])
return tff.learning.from_keras_model(
model,
dummy_batch=batch,
loss=tf.keras.losses.SparseCategoricalCrossentropy(),
metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
trainer = tff.learning.build_federated_averaging_process(
model_fn, client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learningRate))
def evaluate(num_rounds=communicationRound):
state = trainer.initialize()
roundMetrics = []
evaluation = tff.learning.build_federated_evaluation(model_fn)
for round_num in range(num_rounds):
t1 = time.time()
state, metrics = trainer.next(state, train_data)
t2 = time.time()
test_metrics = evaluation(state.model, train_data)
roundMetrics.append('round {:2d}, metrics={}, loss={}'.format(round_num, metrics.sparse_categorical_accuracy , metrics.loss))
roundMetrics.append("The test accuracy is " + str(test_metrics.sparse_categorical_accuracy))
roundMetrics.append('round time={}'.format(t2 - t1))
print('round {:2d}, accuracy={}, loss={}'.format(round_num, metrics.sparse_categorical_accuracy , metrics.loss))
print("The test accuracy is " + str(test_metrics.sparse_categorical_accuracy))
print('round time={}'.format(t2 - t1))
outF = open(filepath+'stats'+architectureType+'.txt', "w")
for line in roundMetrics:
outF.write(line)
outF.write("\n")
outF.close()