У меня есть работающая нейронная сеть (встроенная в Tensorflow 2.0 с API Keras), которую я обучал с точностью float32
(точность по умолчанию). Теперь я хочу тренироваться с точностью float64. Я включаю его с tensorflow.keras.backend.set_floatx('float64)
перед началом выполнения нейронной сети. Обучение начинается, но в последней партии первой эпохи я получаю следующую ошибку:
File "Z:\Z_MASTER\DL_Reconstruction\train_stage_1.py", line 49, in train_vae
validation_split=1/19, callbacks=callbacks) # CHANGE val split
File "Z:\Z_MASTER\Envs\p37_new_clone\lib\site-packages\tensorflow_core\python\keras\engine\training.py", line 728, in fit
use_multiprocessing=use_multiprocessing)
File "Z:\Z_MASTER\Envs\p37_new_clone\lib\site-packages\tensorflow_core\python\keras\engine\training_arrays.py", line 674, in fit
steps_name='steps_per_epoch')
File "Z:\Z_MASTER\Envs\p37_new_clone\lib\site-packages\tensorflow_core\python\keras\engine\training_arrays.py", line 449, in model_iteration
callbacks.on_epoch_end(epoch, epoch_logs)
File "Z:\Z_MASTER\Envs\p37_new_clone\lib\site-packages\tensorflow_core\python\keras\callbacks.py", line 298, in on_epoch_end
callback.on_epoch_end(epoch, logs)
File "Z:\Z_MASTER\Envs\p37_new_clone\lib\site-packages\tensorflow_core\python\keras\callbacks.py", line 1614, in on_epoch_end
self._log_weights(epoch)
File "Z:\Z_MASTER\Envs\p37_new_clone\lib\site-packages\tensorflow_core\python\keras\callbacks.py", line 1696, in _log_weights
self._log_weight_as_image(weight, weight_name, epoch)
File "Z:\Z_MASTER\Envs\p37_new_clone\lib\site-packages\tensorflow_core\python\keras\callbacks.py", line 1721, in _log_weight_as_image
summary_ops_v2.image(weight_name, w_img, step=epoch)
File "Z:\Z_MASTER\Envs\p37_new_clone\lib\site-packages\tensorflow_core\python\ops\summary_ops_v2.py", line 820, in image
return summary_writer_function(name, tensor, function, family=family)
File "Z:\Z_MASTER\Envs\p37_new_clone\lib\site-packages\tensorflow_core\python\ops\summary_ops_v2.py", line 730, in summary_writer_function
should_record_summaries(), record, _nothing, name="")
File "Z:\Z_MASTER\Envs\p37_new_clone\lib\site-packages\tensorflow_core\python\framework\smart_cond.py", line 54, in smart_cond
return true_fn()
File "Z:\Z_MASTER\Envs\p37_new_clone\lib\site-packages\tensorflow_core\python\ops\summary_ops_v2.py", line 723, in record
with ops.control_dependencies([function(tag, scope)]):
File "Z:\Z_MASTER\Envs\p37_new_clone\lib\site-packages\tensorflow_core\python\ops\summary_ops_v2.py", line 818, in function
name=scope)
File "Z:\Z_MASTER\Envs\p37_new_clone\lib\site-packages\tensorflow_core\python\ops\gen_summary_ops.py", line 654, in write_image_summary
name=name, ctx=_ctx)
File "Z:\Z_MASTER\Envs\p37_new_clone\lib\site-packages\tensorflow_core\python\ops\gen_summary_ops.py", line 698, in write_image_summary_eager_fallback
attrs=_attrs, ctx=_ctx, name=name)
File "Z:\Z_MASTER\Envs\p37_new_clone\lib\site-packages\tensorflow_core\python\eager\execute.py", line 67, in quick_execute
six.raise_from(core._status_to_exception(e.code, message), None)
File "<string>", line 3, in raise_from
tensorflow.python.framework.errors_impl.InvalidArgumentError: Value for attr 'T' of double is not in the list of allowed values: uint8, float, half
; NodeDef: {{node WriteImageSummary}}; Op<name=WriteImageSummary; signature=writer:resource, step:int64, tag:string, tensor:T, bad_color:uint8 -> ; attr=max_images:int,default=3,min=1; attr=T:type,default=DT_FLOAT,allowed=[DT_UINT8, DT_FLOAT, DT_HALF]; is_stateful=true> [Op:WriteImageSummary] name: enc_0_conv/kernel_0/
Process finished with exit code 1
Короче говоря, я думаю, что последняя строка сообщения об ошибке была бы наиболее полезной в обнаружение ошибки:
tensorflow.python.framework.errors_impl.InvalidArgumentError: Value for attr 'T' of double is not in the list of allowed values: uint8, float, half
Я попытался решить эту проблему, изменив параметр dtype
слоев на float64
следующим образом (только фрагменты):
conv = Conv2D(..., dtype='float64')(input)
...
output = ReLU(dtype='float64')(input)
...
lat_var = Lambda(... dtype='float64')([z_mean, z_log_var])
...
Код вылетает в этой строке:
history = model.fit(x=images, y=images, epochs=200, batch_size=32,
validation_split=1/19, callbacks=callbacks)
с images
массивом numpy типа float64
, который был достигнут с помощью images = images.astype('float64')
.
Кто-нибудь знает, как Я могу тренироваться с float64
точностью?