Вот мой код. Я использую TensorFlow v2.1.0 и TensorFlow Federated (tff) 0.12.0.
# Image geometry and batching shared by the Keras generator and the tf.data
# shape declarations below.
img_height = 200
img_width = 200
num_classes = 2
batch_size = 10
input_shape = (img_height, img_width, 3)

# Keras directory iterator yielding (images, labels) batches.
img_gen = tf.keras.preprocessing.image.ImageDataGenerator()
gen0 = img_gen.flow_from_directory(
    par1_train_data_dir,
    target_size=(img_height, img_width),
    color_mode='rgb',
    batch_size=batch_size,
)


def _make_client_dataset(directory_iterator):
    """Wrap a Keras directory iterator in a finite, batched tf.data.Dataset.

    Args:
        directory_iterator: iterator returned by
            ``ImageDataGenerator.flow_from_directory``.

    Returns:
        A ``tf.data.Dataset`` of ``(images, labels)`` float32 batches that
        covers one pass (``len(directory_iterator)`` batches) over the
        directory.
    """
    dataset = tf.data.Dataset.from_generator(
        # from_generator requires a callable that returns an iterable, not
        # the iterator object itself.
        lambda: directory_iterator,
        output_types=(tf.float32, tf.float32),
        output_shapes=(
            [None, img_height, img_width, 3],
            [None, num_classes],
        ),
    )
    # Keras directory iterators loop forever; cap the dataset at one epoch so
    # that consumers iterating it to exhaustion terminate.
    return dataset.take(len(directory_iterator))


# NOTE(review): both clients are built from the same iterator (gen0), so they
# read the same directory. Presumably client '2' should get its own
# flow_from_directory(...) iterator; no second directory is visible here.
ds_par1 = _make_client_dataset(gen0)
ds_par2 = _make_client_dataset(gen0)

# Client id -> dataset lookup used by create_tf_dataset_for_client_fn.
dataset_dict = {}
dataset_dict['1'] = ds_par1
dataset_dict['2'] = ds_par2
def create_tf_dataset_for_client_fn(client_id):
    """Return the dataset registered for ``client_id`` in ``dataset_dict``.

    Args:
        client_id: key into the module-level ``dataset_dict``.

    Returns:
        The value stored under ``client_id``.

    Raises:
        KeyError: if ``client_id`` is not a key of ``dataset_dict``.
    """
    return dataset_dict[client_id]
# Expose the per-client datasets through TFF's ClientData interface; the ids
# listed here are the keys looked up by create_tf_dataset_for_client_fn.
source = tff.simulation.ClientData.from_clients_and_fn(
    ['1', '2'],
    create_tf_dataset_for_client_fn,
)
def client_data(n):
    """Return the dataset of the ``n``-th client in ``source.client_ids``.

    Args:
        n: index into ``source.client_ids``.

    Returns:
        Whatever ``source.create_tf_dataset_for_client`` returns for that id.
    """
    ds = source.create_tf_dataset_for_client(source.client_ids[n])
    return ds
# Datasets for the clients taking part in the round (only index 0 here).
train_data = [client_data(n) for n in range(1)]

# Pull one (images, labels) batch from a fresh directory iterator; it is passed
# to tff.learning.from_compiled_keras_model in model_fn.
images, labels = next(
    img_gen.flow_from_directory(
        par1_train_data_dir,
        batch_size=batch_size,
        target_size=(img_height, img_width),
    )
)
sample_batch = (images, labels)
# Called by model_fn(), which passes the result to
# tff.learning.from_compiled_keras_model.
# NOTE(review): the body was elided in this snippet; the "....." line below is
# a placeholder, not valid Python. Restore the real (indented) body before
# running.
def create_compiled_keras_model():
.....
def model_fn():
    """Build a fresh compiled Keras model and wrap it for TFF.

    Returns:
        The result of ``tff.learning.from_compiled_keras_model`` applied to a
        new model from ``create_compiled_keras_model()`` and the module-level
        ``sample_batch``.
    """
    keras_model = create_compiled_keras_model()
    return tff.learning.from_compiled_keras_model(keras_model, sample_batch)
# Build the Federated Averaging process from model_fn and create its initial
# server state.
iterative_process = tff.learning.build_federated_averaging_process(model_fn)
state = iterative_process.initialize()

# Run a single round over the selected clients' datasets.
state, metrics = iterative_process.next(state, train_data)

# Only `metrics` is formatted: the template has one placeholder, and the
# previously passed `round_num` was never defined.
print('round 1, metrics={}'.format(metrics))