Я пытаюсь добавить набор вручную сконструированных (handcrafted) каналов как часть графа вычислений: они вычисляются из входного изображения/тензора и присоединяются к нему до того, как он пройдёт через остальную сеть.
# Stem of the model: the handcrafted channels are computed from the raw input
# and concatenated to it along the last (channel) axis, so conv1 sees the
# original channels followed by the derived ones.
# NOTE(review): `input_image` is used as the Input *shape*; presumably an
# (H, W, 3) tuple, since `handcrafted` splits the last axis into 3 parts —
# confirm.
input_tensor = KL.Input(shape=input_image, name="input")
# NOTE(review): `handcrafted` is declared as `handcrafted(self, input_tensor)`.
# Called as a plain function with one positional argument, the tensor binds to
# `self` and the call fails with a missing-argument TypeError. Call it as a
# bound method (e.g. `self.handcrafted(...)` inside the owning class) or drop
# the `self` parameter.
handcrafted_channels = handcrafted(input_tensor)
x = KL.concatenate([input_tensor, handcrafted_channels], axis=-1)
# Pad 3 pixels on each side ahead of the 7x7, stride-2, "valid" conv1, which
# yields a ceil(H/2) x ceil(W/2) output.
x = KL.ZeroPadding2D((3, 3))(x)
# The former MaxPooling2D(pool_size=(1, 1), strides=(1, 1), padding="same")
# was an identity op (1x1 window, stride 1) and has been removed.
# `input_shape` is also dropped: it only takes effect on the first layer of a
# Sequential model and is ignored when a layer is called on an existing
# tensor, whose shape is inferred from `x`.
x = KL.Conv2D(
    64,
    (7, 7),
    strides=(2, 2),
    name="conv1",
    use_bias=True,
    data_format="channels_last",
)(x)
# ... continue with the standard ResNet body.
def handcrafted(self, input_tensor):
    """Build extra handcrafted channels from an RGB-like input tensor.

    The last axis of ``input_tensor`` is split into three equal parts. They
    are treated as red, green and blue, in that order. Two derived channels
    are built from them:

    * ``red + green``
    * ``green - blue``

    Args:
        self: Unused by the body. NOTE(review): if this is called as a plain
            function with a single positional argument, the tensor binds to
            ``self`` and ``input_tensor`` is missing. Call it as a method or
            drop ``self``.
        input_tensor: Keras tensor, presumably channels-last with 3 channels
            (RGB). ``tf.split(..., 3)`` requires the last-axis size to be
            divisible by 3.

    Returns:
        A tensor with the same leading dimensions as ``input_tensor``. Its
        last axis holds ``[red + green, green - blue]``.
    """
    # Split along the last axis rather than a hard-coded axis 3. This matches
    # the channels-last concatenation below and works for inputs of any rank.
    red, green, blue = tf.split(input_tensor, 3, axis=-1)
    # This could be any sort of equation; a deliberately simple example.
    # NOTE(review): the derived channels have a different value range from
    # the raw ones. A sum of two channels spans twice the range, and a
    # difference can be negative. With unnormalized inputs this scale mismatch
    # may contribute to unstable training (e.g. NaN loss); consider
    # normalizing the input or the derived channels.
    channel_a = KL.add([red, green])
    channel_b = KL.subtract([green, blue])
    return KL.concatenate([channel_a, channel_b], axis=-1)
Когда я обучаю модель с функцией потерь sparse_categorical_crossentropy и оптимизатором SGD (learning_rate=0.01, momentum=0.9, clipnorm=5.0), значение функции потерь становится NaN. Остальная сеть при этом в порядке: если убрать handcrafted_channels и первую конкатенацию (concatenate), обучение проходит успешно.
Ход обучения с включёнными handcrafted-каналами:
Ход обучения без handcrafted-каналов:
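Сообщение об ошибке, которое появляется в конце первой эпохи. Исключение выбрасывается колбэком TensorBoard при записи гистограмм весов (histogram_freq), то есть является следствием появления NaN в весах, а не его причиной: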
---------------------------------------------------------------------------
InvalidArgumentError Traceback (most recent call last)
<ipython-input-6-77a19ea1ade6> in <module>
48 reduce_lr,
---> 49 early_stopping,
50 ])
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py in fit(self, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
817 max_queue_size=max_queue_size,
818 workers=workers,
--> 819 use_multiprocessing=use_multiprocessing)
820
821 def evaluate(self,
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_v2.py in fit(self, model, x, y, batch_size, epochs, verbose, callbacks, validation_split, validation_data, shuffle, class_weight, sample_weight, initial_epoch, steps_per_epoch, validation_steps, validation_freq, max_queue_size, workers, use_multiprocessing, **kwargs)
395 total_epochs=1)
396 cbks.make_logs(model, epoch_logs, eval_result, ModeKeys.TEST,
--> 397 prefix='val_')
398
399 return model.history
/usr/lib/python3.6/contextlib.py in __exit__(self, type, value, traceback)
86 if type is None:
87 try:
---> 88 next(self.gen)
89 except StopIteration:
90 return False
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_v2.py in on_epoch(self, epoch, mode)
769 if mode == ModeKeys.TRAIN:
770 # Epochs only apply to `fit`.
--> 771 self.callbacks.on_epoch_end(epoch, epoch_logs)
772 self.progbar.on_epoch_end(epoch, epoch_logs)
773
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/callbacks.py in on_epoch_end(self, epoch, logs)
300 logs = logs or {}
301 for callback in self.callbacks:
--> 302 callback.on_epoch_end(epoch, logs)
303
304 def on_train_batch_begin(self, batch, logs=None):
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/callbacks.py in on_epoch_end(self, epoch, logs)
1711
1712 if self.histogram_freq and epoch % self.histogram_freq == 0:
-> 1713 self._log_weights(epoch)
1714
1715 if self.embeddings_freq and epoch % self.embeddings_freq == 0:
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/callbacks.py in _log_weights(self, epoch)
1802 with ops.init_scope():
1803 weight = K.get_value(weight)
-> 1804 summary_ops_v2.histogram(weight_name, weight, step=epoch)
1805 if self.write_images:
1806 self._log_weight_as_image(weight, weight_name, epoch)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/summary_ops_v2.py in histogram(name, tensor, family, step)
821 name=scope)
822
--> 823 return summary_writer_function(name, tensor, function, family=family)
824
825
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/summary_ops_v2.py in summary_writer_function(name, tensor, function, family)
750 with ops.device("cpu:0"):
751 op = smart_cond.smart_cond(
--> 752 should_record_summaries(), record, _nothing, name="")
753 if not context.executing_eagerly():
754 ops.add_to_collection(ops.GraphKeys._SUMMARY_COLLECTION, op) # pylint: disable=protected-access
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/smart_cond.py in smart_cond(pred, true_fn, false_fn, name)
52 if pred_value is not None:
53 if pred_value:
---> 54 return true_fn()
55 else:
56 return false_fn()
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/summary_ops_v2.py in record()
743 with ops.name_scope(name_scope), summary_op_util.summary_scope(
744 name, family, values=[tensor]) as (tag, scope):
--> 745 with ops.control_dependencies([function(tag, scope)]):
746 return constant_op.constant(True)
747
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/summary_ops_v2.py in function(tag, scope)
819 tag,
820 array_ops.identity(tensor),
--> 821 name=scope)
822
823 return summary_writer_function(name, tensor, function, family=family)
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/gen_summary_ops.py in write_histogram_summary(writer, step, tag, values, name)
467 try:
468 return write_histogram_summary_eager_fallback(
--> 469 writer, step, tag, values, name=name, ctx=_ctx)
470 except _core._SymbolicException:
471 pass # Add nodes to the TensorFlow graph.
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/ops/gen_summary_ops.py in write_histogram_summary_eager_fallback(writer, step, tag, values, name, ctx)
488 _attrs = ("T", _attr_T)
489 _result = _execute.execute(b"WriteHistogramSummary", 0, inputs=_inputs_flat,
--> 490 attrs=_attrs, ctx=ctx, name=name)
491 _result = None
492 return _result
/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/execute.py in quick_execute(op_name, num_outputs, inputs, attrs, ctx, name)
65 else:
66 message = e.message
---> 67 six.raise_from(core._status_to_exception(e.code, message), None)
68 except TypeError as e:
69 keras_symbolic_tensors = [
/usr/local/lib/python3.6/dist-packages/six.py in raise_from(value, from_value)
Буду очень признателен за любые идеи о том, где искать причину этой проблемы.