TFF: Как моделировать обучение на случайных выборках пользователей в каждом раунде - PullRequest
1 голос
/ 02 апреля 2020

Я хотел бы смоделировать этот код федеративного обучения для классификации изображений со случайными выборками пользователей в каждом раунде. В этом учебном пособии используются все клиенты для обучения, я уверен, что я хотел бы изменить этот код таким образом, в каждом раунде случайные образцы клиентов выбраны. Так что мы можем изменить в этом коде, чтобы заставить его выбирать клиента случайным образом

import collections
import time

import tensorflow as tf
tf.compat.v1.enable_v2_behavior()

import tensorflow_federated as tff

source, _ = tff.simulation.datasets.emnist.load_data()


def map_fn(example):
  return collections.OrderedDict(
      x=tf.reshape(example['pixels'], [-1, 784]), y=example['label'])


def client_data(n):
  ds = source.create_tf_dataset_for_client(source.client_ids[n])
  return ds.repeat(10).shuffle(500).batch(20).map(map_fn)


train_data = [client_data(n) for n in range(10)]
element_spec = train_data[0].element_spec

def model_fn():
  model = tf.keras.models.Sequential([
      tf.keras.layers.Input(shape=(784,)),
      tf.keras.layers.Dense(units=10, kernel_initializer='zeros'),
      tf.keras.layers.Softmax(),
  ])
  return tff.learning.from_keras_model(
      model,
      input_spec=element_spec,
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])


trainer = tff.learning.build_federated_averaging_process(
    model_fn, client_optimizer_fn=lambda: tf.keras.optimizers.SGD(0.02))

....
NUM_ROUNDS = 11
for round_num in range(2, NUM_ROUNDS):
  state, metrics = trainer.next(state, federated_train_data)
  print('round {:2d}, metrics={}'.format(round_num, metrics))

1 Ответ

1 голос
/ 04 апреля 2020

tff.simulation.ClientData объекты предоставляют атрибут client_ids, который представляет список строк, идентифицирующих пользователей в этом наборе данных.

Таким образом, вы можете напрямую выбирать из этого списка, и используйте метод create_tf_dataset_for_client для того же объекта, чтобы создать набор данных этого пользователя. Предполагая tff.simulation.ClientData объект client_data, псевдокод будет выглядеть следующим образом:

import random
...

for round_num in range(2, NUM_ROUNDS):
  selected_clients = random.sample(client_data.client_ids, USERS_PER_ROUND)
  federated_data = [
      client_data.create_tf_dataset_for_client(n) for n in selected_clients]
  state, metrics = iterative_process.next(state, federated_data)

Большая часть исследовательского кода, включенного в TFF, несколько отделяет задачу выбора клиентов от проведения обучения l oop, поэтому Я не могу указать на хороший пример этого паттерна, но TFF, я думаю, был бы рад принять участие в обновлении учебников, чтобы использовать такой паттерн, чтобы помочь немного продемонстрировать гибкость ClientData API. лучше.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...