Tensorflow Federated: почему мой итеративный процесс не может тренировать раунды - PullRequest
1 голос
/ 10 февраля 2020

Я пишу код с TFF из моего собственного набора данных, весь код выполняется правильно, кроме этой строки

В train_data я делаю 4 набора данных, загруженных с помощью tf.data.Dataset, они имеют тип "DatasetV1Adapter "

def client_data(n):
  ds = source.create_tf_dataset_for_client(source.client_ids[n])
  return ds.repeat(10).map(map_fn).shuffle(500).batch(20)

federated_train_data = [client_data(n) for n in range(4)]

batch = tf.nest.map_structure(lambda x: x.numpy(), iter(train_data[0]).next())

def model_fn():
  model = tf.keras.models.Sequential([
    .........
  return tff.learning.from_compiled_keras_model(model, batch)   

все это работает правильно, и я получаю тренер и заявляю:

trainer = tff.learning.build_federated_averaging_process(model_fn)

За исключением случаев, когда я хотел бы начать тренировку и пройти с этим кодом:

state, metrics = iterative_process.next(state, federated_train_data) 
print('round  1, metrics={}'.format(metrics))

Я не могу. ошибка приходит! Итак, откуда может быть ошибка? из типа набора данных? или как мне сделать мои данные объединенными?

Ответы [ 2 ]

0 голосов
/ 04 марта 2020

Как проверено в комментариях выше, добавление вызова take(N) для некоторого конечного целого числа N в функции client_data должно решить эту проблему. Проблема в том, что TFF сократит всех элементов в наборе данных, который вы передаете ему . Если у вас есть бесконечный набор данных, это означает, что «продолжайте сокращение всегда». N здесь должно отображать «сколько данных имеет отдельный клиент», и может быть действительно любым по вашему выбору.

0 голосов
/ 02 марта 2020

вот мой код, я использую Tensorflow v2.1.0 и tff 0.12.0

img_height = 200
img_width = 200
num_classes = 2
batch_size = 10

input_shape = (img_height, img_width, 3)

img_gen = tf.keras.preprocessing.image.ImageDataGenerator()
gen0 = img_gen.flow_from_directory(par1_train_data_dir,(200, 200),'rgb', batch_size=10)
ds_par1 = tf.data.Dataset.from_generator(gen
    output_types=(tf.float32, tf.float32),
    output_shapes=([None,img_height,img_width,3], [None,num_classes])
)
ds_par2 = tf.data.Dataset.from_generator(gen0 
    output_types=(tf.float32, tf.float32),
    output_shapes=([None,img_height,img_width,3], [None,num_classes])
)

dataset_dict={}
dataset_dict['1'] = ds_par1
dataset_dict['2'] = ds_par2

def create_tf_dataset_for_client_fn(client_id):
    return dataset_dict[client_id]

source = tff.simulation.ClientData.from_clients_and_fn(['1','2'],create_tf_dataset_for_client_fn)

def client_data(n):
  ds = source.create_tf_dataset_for_client(source.client_ids[n])
  return ds


train_data = [client_data(n) for n in range(1)]

images, labels = next(img_gen.flow_from_directory(par1_train_data_dir,batch_size=batch_size,target_size=(img_height,img_width)))
sample_batch = (images,labels)

def create_compiled_keras_model():
  .....

def model_fn():
    keras_model = create_compiled_keras_model()
    return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

iterative_process = tff.learning.build_federated_averaging_process(model_fn)
state = iterative_process.initialize()


state, metrics = iterative_process.next(state, train_data)
print('round 1, metrics={}'.format(round_num, metrics))
...