Прекрасно работает в Eager. Код + ошибка ниже; вставленные выходные данные взяты из экземпляра Colab с tf.__version__ == 2.3.0-dev20200526
(tf-nightly), также воспроизведены в 2.2.0 и на Windows 10. В графике TF 1.14.0 ошибок нет.
Есть подсказки?
Попытка отладки : Размещение ниже здесь (при локальной установке)
print(sample_weight)
try: print("sample_weight =", K.eval(sample_weight))
except: pass
print(loss, '\n')
дает:
# EAGER
Tensor("ExpandDims:0", shape=(32, 1), dtype=float32)
Tensor("mean_squared_error/weighted_loss/value:0", shape=(), dtype=float32)
Tensor("ExpandDims:0", shape=(32, 1), dtype=float32)
Tensor("mean_squared_error/weighted_loss/value:0", shape=(), dtype=float32)
# GRAPH
1.0
sample_weight = 1.0
Tensor("loss/conv2d_loss/weighted_loss/Mul:0", shape=(32, 28, 28), dtype=float32)
Tensor("conv2d_sample_weights:0", shape=(None,), dtype=float32)
sample_weight = [1.]
Tensor("loss_1/conv2d_loss/weighted_loss/Mul:0", shape=(32, 28, 28), dtype=float32)
График обрабатывает тензор sample_weight
иначе. И Eager, и Graph отлично работают с 2D-формой вывода и loss='categorical_crossentropy'
, хотя, возможно, неправильно; код внизу здесь .
Воспроизводимый код + Ошибка :
import numpy as np
import tensorflow as tf
from tensorflow.keras.layers import Input, Conv2D
from tensorflow.keras.models import Model
tf.compat.v1.disable_eager_execution()
batch_shape = (32, 28, 28, 1)
ipt = Input(batch_shape=batch_shape)
out = Conv2D(filters=1, kernel_size=(1, 1))(ipt)
model = Model(ipt, out)
model.compile('adam', 'mse')
x = y = np.random.randn(*batch_shape)
sw = np.ones(len(x))
model.train_on_batch(x, y, sw)
/usr/local/lib/python3.6/dist-packages/tensorflow/python/client/session.py in __call__(self, *args, **kwargs)
1470 ret = tf_session.TF_SessionRunCallable(self._session._session,
1471 self._handle, args,
-> 1472 run_metadata_ptr)
1473 if run_metadata:
1474 proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)
InvalidArgumentError: Incompatible shapes: [32] vs. [32,28,28]
[[{{node training/Adam/gradients/gradients/loss_1/conv2d_loss/weighted_loss/Mul_grad/Mul}}]]