Несовместимые формы с использованием `sample_weight` в исполнении Graph - PullRequest
1 голос
/ 27 мая 2020

Прекрасно работает в Eager. Код + ошибка ниже; вставленные выходные данные взяты из экземпляра Colab с tf.__version__ == 2.3.0-dev20200526 (tf-nightly), также воспроизведены в 2.2.0 и на Windows 10. В графике TF 1.14.0 ошибок нет.

Есть подсказки?


Попытка отладки : Размещение ниже здесь (при локальной установке)

print(sample_weight)
try: print("sample_weight =", K.eval(sample_weight))
except: pass
print(loss, '\n')

дает:

# EAGER
Tensor("ExpandDims:0", shape=(32, 1), dtype=float32)
Tensor("mean_squared_error/weighted_loss/value:0", shape=(), dtype=float32)

Tensor("ExpandDims:0", shape=(32, 1), dtype=float32)
Tensor("mean_squared_error/weighted_loss/value:0", shape=(), dtype=float32)

# GRAPH
1.0
sample_weight = 1.0
Tensor("loss/conv2d_loss/weighted_loss/Mul:0", shape=(32, 28, 28), dtype=float32)

Tensor("conv2d_sample_weights:0", shape=(None,), dtype=float32)
sample_weight = [1.]
Tensor("loss_1/conv2d_loss/weighted_loss/Mul:0", shape=(32, 28, 28), dtype=float32)

График обрабатывает тензор sample_weight иначе. И Eager, и Graph отлично работают с 2D-формой вывода и loss='categorical_crossentropy', хотя, возможно, неправильно; код внизу здесь .


Воспроизводимый код + Ошибка :

import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D
from tensorflow.keras.models import Model
tf.compat.v1.disable_eager_execution()

batch_shape = (32, 28, 28, 1)

ipt = Input(batch_shape=batch_shape)
out = Conv2D(filters=1, kernel_size=(1, 1))(ipt)
model = Model(ipt, out)
model.compile('adam', 'mse')

x = y = np.random.randn(*batch_shape)
sw = np.ones(len(x))

model.train_on_batch(x, y, sw)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in __call__(self, *args, **kwargs)
   1470         ret = tf_session.TF_SessionRunCallable(self._session._session,
   1471                                                self._handle, args,
-> 1472                                                run_metadata_ptr)
   1473         if run_metadata:
   1474           proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
InvalidArgumentError: Incompatible shapes: [32] vs. [32,28,28]
     [[{{node training/Adam/gradients/gradients/loss_1/conv2d_loss/weighted_loss/Mul_grad/Mul}}]]
...