TensorFlow: Как преобразовать DeferredTensor в Tensor во время активного выполнения (для выполнения нормализации группы)? - PullRequest
1 голос
/ 14 июня 2019

В TensorFlow с 1.10 по 1.12 (используя активное выполнение) у меня есть следующий фрагмент кода:

tensor = tf.keras.layers.Conv2D(128, (3, 3), padding='same')(tensor)
tensor = tf.contrib.layers.group_norm(tensor)

Однако вызов tf.contrib.layers.group_norm(tensor) выдает мне следующую ошибку:

ValueError: Attempt to convert a value (<DeferredTensor 'None' shape=(?, 16, 16, 128) dtype=float32>) with an unsupported type (<class 'tensorflow.python.keras.engine.base_layer.DeferredTensor'>) to a Tensor.

Можно ли преобразовать DeferredTensor в Tensor или EagerTensor?Могу ли я выполнить групповую нормализацию другим способом?

1 Ответ

1 голос
/ 14 июня 2019

Вам нужно с самого начала включить активное выполнение. Похоже, у вас есть смесь нетерпеливых и отложенных, вызывающих проблему, поэтому я подозреваю, что один из ваших тензорных операций был создан до вызова tf.enable_eager_execution(), а не что-то конкретное в отношении вызова tf.contrib.layers.group_norm(tensor).

Например, с v 1.12.2 я могу сделать:

import tensorflow as tf

tf.enable_eager_execution()

tensor = tf.random.normal((1, 3,3, 3))
tensor = tf.keras.layers.Conv2D(128, (3, 3), padding='same')(tensor)
tensor = tf.contrib.layers.group_norm(tensor)

print(tf.__version__)
print(tensor)

и получите ожидаемый результат:

1.12.2
tf.Tensor(
[[[[-0.46789345  0.05536499 -0.06537625 ... -0.612622   -0.80556583
     0.39658052]
   [-1.086592   -0.65128946  1.1523774  ...  0.70371515 -0.19514994
     0.6261743 ]
   [-0.68818045  1.0391753   0.61246586 ...  0.49158555 -0.23147273
    -0.40839535]]

  [[-0.27729145 -0.7241349  -0.45006287 ... -1.6836562  -2.0581594
    -0.09571741]
   [-2.7078617   1.6280639   0.29760775 ...  0.48920113 -2.148665
    -0.17309377]
   [ 2.41167    -0.29042014 -0.7241919  ... -0.0780689   1.451448
     2.812067  ]]

  [[ 0.04337802  1.5531337   0.838807   ... -0.164665   -0.28958386
    -1.6659214 ]
   [ 0.38814372 -0.1571713   0.16725369 ...  0.93523234 -0.2039619
     0.6319514 ]
   [ 0.09182647  0.19946824 -0.8600142  ... -0.5493502  -0.68655336
     0.45441204]]]], shape=(1, 3, 3, 128), dtype=float32)
...