Я пытаюсь обновить tf 1.15 с помощью keras до tenorflow 2, т. Е. Tf.keras ...
tf 1.15 с использованием keras работает нормально.
При вызове модели .fit (), я получаю ошибку значения (показано ниже).
#Train the model
import tensorflow as tf
model.fit(data, [labels, data], batch_size=1, epochs=1, verbose=1)
Входные и выходные данные: данные - это dtype ('float32'), метка - dtype ('uint8')
В конечном счете код завершается ошибкой при операции умножения numpy: TypeError: У ввода 'y' из 'Mul' Op есть тип float32, который не соответствует типу uint8 аргумента 'x'.
У меня есть попытался изменить массив меток np на tf.float32, приведя метки к tf.float32. Я также пробовал более простые функции потерь.
Любое направление будет оценено. Спасибо, Джей.
model.fit() output:
Train on 4 samples
1/4 [======>.......................] - ETA: 3s
ValueError Traceback (most recent call last)
~/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(op_type_name, name, **keywords)
469 as_ref=input_arg.is_ref,
--> 470 preferred_dtype=default_dtype)
471 except TypeError as err:
~/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/framework/ops.py in convert_to_tensor(value, dtype, name, as_ref, preferred_dtype, dtype_hint, ctx, accepted_result_types)
1316 "Tensor conversion requested dtype %s for Tensor with dtype %s: %r" %
-> 1317 (dtype.name, value.dtype.name, value))
1318 return value
ValueError: Tensor conversion requested dtype uint8 for Tensor with dtype float32: <tf.Tensor 'model/Dec_GT_Output/Sigmoid:0' shape=(1, 3, 80, 96, 64) dtype=float32>
During handling of the above exception, another exception occurred:
TypeError Traceback (most recent call last)
in
2 # .astype("float32").values
3
----> 4 model.fit(data, [labels, data], batch_size=1, epochs=1, verbose=1)
...
~/anaconda3/envs/tf2/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py in _apply_op_helper(op_type_name, name, **keywords)
504 "%s type %s of argument '%s'." %
505 (prefix, dtypes.as_dtype(attrs[input_arg.type_attr]).name,
--> 506 inferred_from[input_arg.type_attr]))
507
508 types = [values.dtype]
TypeError: Input 'y' of 'Mul' Op has type float32 that does not match type uint8 of argument 'x'.