Как я могу сохранить обученную модель TensorFlow Federated в качестве модели .h5? - PullRequest
0 голосов
/ 27 марта 2020

Я хочу сохранить федеративную модель TensorFlow, которая была обучена с помощью алгоритма FedAvg, в качестве модели Keras / .h5. Я не смог найти документы по этому вопросу и хотел бы знать, как это можно сделать. Также, если возможно, я хотел бы иметь доступ как к модели агрегированного сервера, так и к моделям клиентов.

Код, который я использую для обучения федеративной модели, приведен ниже:

def model_fn():
    model = tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(segment_size,num_input_channels)),
      tf.keras.layers.Flatten(), 
      tf.keras.layers.Dense(units=400, activation='relu'),
      tf.keras.layers.Dropout(dropout_rate),
      tf.keras.layers.Dense(units=100, activation='relu'),
      tf.keras.layers.Dropout(dropout_rate),
      tf.keras.layers.Dense(activityCount, activation='softmax'),
    ])
    return tff.learning.from_keras_model(
      model,
      dummy_batch=batch,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
trainer = tff.learning.build_federated_averaging_process(
    model_fn, client_optimizer_fn=lambda: tf.keras.optimizers.SGD(learningRate))

def evaluate(num_rounds=communicationRound):
  state = trainer.initialize()
  roundMetrics = []
  evaluation = tff.learning.build_federated_evaluation(model_fn)

  for round_num in range(num_rounds):
    t1 = time.time()
    state, metrics = trainer.next(state, train_data)
    t2 = time.time()
    test_metrics = evaluation(state.model, train_data)

    roundMetrics.append('round {:2d}, metrics={}, loss={}'.format(round_num, metrics.sparse_categorical_accuracy , metrics.loss))
    roundMetrics.append("The test accuracy is " + str(test_metrics.sparse_categorical_accuracy))
    roundMetrics.append('round time={}'.format(t2 - t1))
    print('round {:2d}, accuracy={}, loss={}'.format(round_num, metrics.sparse_categorical_accuracy , metrics.loss))
    print("The test accuracy is " + str(test_metrics.sparse_categorical_accuracy))
    print('round time={}'.format(t2 - t1))
  outF = open(filepath+'stats'+architectureType+'.txt', "w")
  for line in roundMetrics:
    outF.write(line)
    outF.write("\n")
  outF.close()

1 Ответ

2 голосов
/ 28 марта 2020

Грубо говоря, мы будем использовать методы save_checkpoint / load_checkpoint. В частности, вы можете создать экземпляр FileCheckpointManager и попросить его сохранить состояние (почти) напрямую.

Состояние в вашем примере является экземпляром tff. python .common_libs.anonymous_tuple.AnonymousTuple (IIR C) ), который не совместим с tf.convert_to_tensor, что необходимо для save_checkpoint и объявлено в его строке документации. Общее решение, часто используемое в исследовательском коде TFF, состоит в том, чтобы ввести класс Python attrs для преобразования из анонимного кортежа, как только возвращается состояние -

Предполагая, что вышеизложенное, следующий эскиз должен работать:

# state assumed an anonymous tuple, previously created
# N some integer 

ckpt_manager = FileCheckpointManager(...)
ckpt_manager.save_checkpoint(ServerState.from_anon_tuple(state), round_num=N)

И для восстановления с этой контрольной точки в любое время вы можете позвонить:

state = iterative_process.initialize()
ckpt_manager = FileCheckpointManager(...)
restored_state = ckpt_manager.load_latest_checkpoint(
    ServerState.from_anon_tuple(state))

Одна вещь, на которую стоит обратить внимание: указатели кода, связанные выше, обычно находятся в tff. python. исследование ..., которое не входит в комплект поставки; поэтому предпочтительный способ получить их - это либо вставить код в ваш собственный проект, либо снять репозиторий и собрать его из исходного кода.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...