Я хочу преобразовать тензор в рваный тензор на моем графике, используя Keras. Тем не менее, функция RaggedTensor.from_row_lengths
не работает в моем графике.
Версия Tensorflow: tf-nightly 2.1.0.dev20191203
Вот пример кода:
import tensorflow as tf
import numpy as np
input_sequence = np.reshape(
np.array([1, 2, 3, 4, 5, 6, 7, 8], dtype=np.int32),
(2, 4))
labels = np.reshape(
np.array([1.0, 0.0, ], dtype=np.float32),
(2, 1))
dataset = tf.data.Dataset.from_tensor_slices((input_sequence, labels)).batch(1)
sequence_in = tf.keras.layers.Input(shape=(None,), dtype=tf.int32)
# Failing line, the rest works without the line below
ragged_in = tf.RaggedTensor.from_row_lengths(sequence_in, [2, 2])
embedded_tensor = tf.keras.layers.Embedding(9, 4)(sequence_in)
flat_tensor = tf.reshape(embedded_tensor, [-1, 16])
prediction = tf.keras.layers.Dense(2)(flat_tensor)
model = tf.keras.Model(inputs=sequence_in, outputs=prediction)
model.compile(
tf.keras.optimizers.Adam(),
loss=tf.keras.losses.CategoricalCrossentropy(),
metrics=['acc'])
model.fit(dataset, steps_per_epoch=1)
Кажется, ошибка связана с проверкой, примененной для проверки форм Tensor:
Traceback (most recent call last):
File "myscript.py", line 18, in <module>
ragged_in = tf.RaggedTensor.from_row_lengths(sequence_in, [4, 1])
File "python3.6/site-packages/tensorflow_core/python/ops/ragged/ragged_tensor.py", line 510, in from_row_lengths
check_ops.assert_equal(nvals1, nvals2, message=msg)
File "python3.6/site-packages/tensorflow_core/python/ops/check_ops.py", line 506, in assert_equal
if not condition:
File "python3.6/site-packages/tensorflow_core/python/framework/ops.py", line 765, in __bool__
self._disallow_bool_casting()
File "python3.6/site-packages/tensorflow_core/python/framework/ops.py", line 534, in _disallow_bool_casting
self._disallow_in_graph_mode("using a `tf.Tensor` as a Python `bool`")
File "python3.6/site-packages/tensorflow_core/python/framework/ops.py", line 523, in _disallow_in_graph_mode
" this function with @tf.function.".format(task))
tensorflow.python.framework.errors_impl.OperatorNotAllowedInGraphError: using a `tf.Tensor` as a Python `bool` is not allowed in Graph execution. Use Eager execution or decorate this function with @tf.function
Я могу проигнорировать ошибку, используя validate=False
, но затем произойдет ошибка на следующем уровне:
ragged_in = tf.RaggedTensor.from_row_lengths(sequence_in, [2, 2], validate=False)
embedded_ragged = tf.keras.layers.Embedding(9, 4)(ragged_in)
Интересно, связано ли это с тем, что размер пакета, и Tensor 'sequence_in' не исправлены. Поэтому я также попытался преобразовать только первое наблюдение в Ragged Tensor, но такая же ошибка сохраняется.
ragged_in = tf.RaggedTensor.from_row_lengths(sequence_in[0], [2, 2])