Модель Re sNet в Tensorflow Federated - PullRequest
1 голос
/ 07 января 2020

Я пытался настроить модель в учебнике «Классификация изображений» в Tensorflow Federated. (Первоначально использовалась последовательная модель) Я использую Keras ResNet50, но когда он начал тренироваться, всегда возникает ошибка «Несовместимые формы»

Вот мои коды:

NUM_CLIENTS = 4
NUM_EPOCHS = 10
BATCH_SIZE = 2
SHUFFLE_BUFFER = 5

def create_compiled_keras_model():
  model = tf.keras.applications.resnet.ResNet50(include_top=False, weights='imagenet', 
                                                input_tensor=tf.keras.layers.Input(shape=(100, 
                                                300, 3)), pooling=None)

  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      optimizer=tf.keras.optimizers.SGD(learning_rate=0.02),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])
  return model


def model_fn():
  keras_model = create_compiled_keras_model()
  return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

iterative_process = tff.learning.build_federated_averaging_process(model_fn)

Информация об ошибке : введите описание изображения здесь

Мне кажется, что форма несовместима, потому что информация об эпохе и клиентах почему-то отсутствовала. Был бы очень благодарен, если бы кто-то мог дать мне подсказку.

Обновления:

Ошибка подтверждения произошла во время tff.learning.build_federated_averaging_process

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-164-dac26193d9d8> in <module>()
----> 1 iterative_process = tff.learning.build_federated_averaging_process(model_fn)
      2 
      3 # iterative_process = build_federated_averaging_process(model_fn)

13 frames
/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/learning/federated_averaging.py in build_federated_averaging_process(model_fn, server_optimizer_fn, client_weight_fn, stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
    165   return optimizer_utils.build_model_delta_optimizer_process(
    166       model_fn, client_fed_avg, server_optimizer_fn,
--> 167       stateful_delta_aggregate_fn, stateful_model_broadcast_fn)

/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/learning/framework/optimizer_utils.py in build_model_delta_optimizer_process(model_fn, model_to_client_delta_fn, server_optimizer_fn, stateful_delta_aggregate_fn, stateful_model_broadcast_fn)
    349   # still need this.
    350   with tf.Graph().as_default():
--> 351     dummy_model_for_metadata = model_utils.enhance(model_fn())
    352 
    353   # ===========================================================================

<ipython-input-159-b2763ace8e5b> in model_fn()
      1 def model_fn():
      2   keras_model = model
----> 3   return tff.learning.from_compiled_keras_model(keras_model, sample_batch)

/usr/local/lib/python3.6/dist-packages/tensorflow_federated/python/learning/keras_utils.py in from_compiled_keras_model(keras_model, dummy_batch)
    211   # Model.test_on_batch() once before asking for metrics.
    212   if isinstance(dummy_tensors, collections.Mapping):
--> 213     keras_model.test_on_batch(**dummy_tensors)
    214   else:
    215     keras_model.test_on_batch(*dummy_tensors)

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py in test_on_batch(self, x, y, sample_weight, reset_metrics)
   1007         sample_weight=sample_weight,
   1008         reset_metrics=reset_metrics,
-> 1009         standalone=True)
   1010     outputs = (
   1011         outputs['total_loss'] + outputs['output_losses'] + outputs['metrics'])

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_v2_utils.py in test_on_batch(model, x, y, sample_weight, reset_metrics, standalone)
    503       y,
    504       sample_weights=sample_weights,
--> 505       output_loss_metrics=model._output_loss_metrics)
    506 
    507   if reset_metrics:

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in __call__(self, *args, **kwds)
    568         xla_context.Exit()
    569     else:
--> 570       result = self._call(*args, **kwds)
    571 
    572     if tracing_count == self._get_tracing_count():

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in _call(self, *args, **kwds)
    606       # In this case we have not created variables on the first call. So we can
    607       # run the first trace but we should fail if variables are created.
--> 608       results = self._stateful_fn(*args, **kwds)
    609       if self._created_variables:
    610         raise ValueError("Creating variables on a non-first call to a function"

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in __call__(self, *args, **kwargs)
   2407     """Calls a graph function specialized to the inputs."""
   2408     with self._lock:
-> 2409       graph_function, args, kwargs = self._maybe_define_function(args, kwargs)
   2410     return graph_function._filtered_call(args, kwargs)  # pylint: disable=protected-access
   2411 

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in _maybe_define_function(self, args, kwargs)
   2765 
   2766       self._function_cache.missed.add(call_context_key)
-> 2767       graph_function = self._create_graph_function(args, kwargs)
   2768       self._function_cache.primary[cache_key] = graph_function
   2769       return graph_function, args, kwargs

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py in _create_graph_function(self, args, kwargs, override_flat_arg_shapes)
   2655             arg_names=arg_names,
   2656             override_flat_arg_shapes=override_flat_arg_shapes,
-> 2657             capture_by_value=self._capture_by_value),
   2658         self._function_attributes,
   2659         # Tell the ConcreteFunction to clean up its graph once it goes out of

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/func_graph.py in func_graph_from_py_func(name, python_func, args, kwargs, signature, func_graph, autograph, autograph_options, add_control_dependencies, arg_names, op_return_value, collections, capture_by_value, override_flat_arg_shapes)
    979         _, original_func = tf_decorator.unwrap(python_func)
    980 
--> 981       func_outputs = python_func(*func_args, **func_kwargs)
    982 
    983       # invariant: `func_outputs` contains only Tensors, CompositeTensors,

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/def_function.py in wrapped_fn(*args, **kwds)
    437         # __wrapped__ allows AutoGraph to swap in a converted function. We give
    438         # the function a weak reference to itself to avoid a reference cycle.
--> 439         return weak_wrapped_fn().__wrapped__(*args, **kwds)
    440     weak_wrapped_fn = weakref.ref(wrapped_fn)
    441 

/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/func_graph.py in wrapper(*args, **kwargs)
    966           except Exception as e:  # pylint:disable=broad-except
    967             if hasattr(e, "ag_error_metadata"):
--> 968               raise e.ag_error_metadata.to_exception(e)
    969             else:
    970               raise

AssertionError: in user code:

    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_eager.py:345 test_on_batch  *
        with backend.eager_learning_phase_scope(0):
    /usr/lib/python3.6/contextlib.py:81 __enter__
        return next(self.gen)
    /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/backend.py:425 eager_learning_phase_scope
        assert ops.executing_eagerly_outside_functions()

    AssertionError: 

Ответы [ 2 ]

0 голосов
/ 30 января 2020

У меня та же проблема: если я выполню это состояние строки, metrics = iterative_process.next (state, federated_train_data) print ('round 1, metrics = {}'. Format (metrics))

Я нахожу эта ошибка InvalidArgumentError: 2 root ошибок найдено. (0) Неверный аргумент: по умолчанию MaxPoolingOp поддерживает только NHW C на типе устройства CPU [[{{node StatefulPartitionedCall / StatefulPartitionedCall / sequential / vgg16 / block1_pool / MaxPool}}]] [[subcomputation / StatefulPartitionedCall_1 / ReduceDataset] / StatefulPartitionedCall_1 / ReduceDataset / _140]] (1) Недопустимый аргумент: по умолчанию MaxPoolingOp поддерживает только NHW C на устройстве типа CPU [[{{узел StatefulPartitionedCall / StatefulPartitionedCall / sequential / vgg16 / block1_pool / MaxPool }computation]] StatefulPartitionedCall_1 / ReduceDataset]] 0 успешных операций. 0 полученных ошибок игнорируется.

знаю, что я, работающий в VGG16, знаете, что это за ошибка

0 голосов
/ 08 января 2020

Ах, я полагаю, что эта проблема происходит из-за несоответствующих ожиданий sample_batch TFF передает sample_batch в Keras, что вызывает прямой проход с этим образцом пакета для инициализации различных атрибутов модели keras. sample_batch должен быть либо образцом литеральных данных, которые вы собираетесь передавать модели, как на стороне сервера, либо пакетом поддельных данных, соответствующих форме и типу данных, которые вы будете передавать.

Пример первого можно найти здесь (здесь используется tf.data.Dataset), и есть несколько примеров последнего в тестовом коде, например здесь .

Из того, что я вижу в определении модели, вероятно, x элемент вашей sample_batch должен быть ndarray формы [2, 100, 300, 3] (где 2 для размера партии, но технически это может быть любой ненулевой размерности), а элемент y также должен соответствовать ожидаемой структуре y в данных, которые вы используете.

Надеюсь, это поможет, просто отзовитесь, если возникнут проблемы!

Стоит отметить, что это может быть полезно при размышлении о TFF - TFF строит синтаксическое дерево, представляющее распределенные вычисления, которые вы определяете с помощью build_federated_averaging_process. Эта ошибка действительно возникает во время строительства этого объекта. TFF должен отслеживать вычисления, которые вы передаете, чтобы знать, какую структуру генерировать, и это то, что здесь поднимается. Фактическое обучение модели происходит, когда вы звоните next на возвращенном IterativeProcess.

...