Иногда я получаю странные результаты с TensorFlow после слоя Conv2D. Следующая программа иллюстрирует проблему:
import numpy as np
import tensorflow as tf
# Kernel
k = np.empty((3, 3, 1, 1), dtype=np.float32)
k[:, :, 0, 0] = np.array([
[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
], dtype=np.float32)
# Input
x = np.empty((1, 6, 6, 1), dtype=np.float32)
x[0, :, :, 0] = np.array([
[1, 0, 0, 1, 0, 0],
[0, 1, 0, 0, 1, 0],
[0, 0, 1, 0, 0, 1],
[0, 0, 0, 0, 0, 0],
[1, 1, 1, 1, 1, 1],
[0, 0, 0, 0, 0, 0],
], dtype=np.float32)
# Create model
model = tf.keras.models.Sequential([
tf.keras.layers.Conv2D(1, 3, use_bias=False, input_shape=(6, 6, 1)),
])
model.set_weights([k])
# Evaluate model
y = model(x).numpy()
print('kernel')
print(model.get_weights()[0][:, :, 0, 0])
print('input')
print(x[0, :, :, 0])
print('output')
print(y[0, :, :, 0])
Правильный вывод создается в большинстве случаев:
kernel
[[1. 0. 0.]
[0. 1. 0.]
[0. 0. 1.]]
input
[[1. 0. 0. 1. 0. 0.]
[0. 1. 0. 0. 1. 0.]
[0. 0. 1. 0. 0. 1.]
[0. 0. 0. 0. 0. 0.]
[1. 1. 1. 1. 1. 1.]
[0. 0. 0. 0. 0. 0.]]
output
[[3. 0. 0. 3.]
[0. 2. 0. 0.]
[1. 1. 2. 1.]
[1. 1. 1. 1.]]
, но после нескольких прогонов, только когда вы думаете, что все хорошо, один получает следующий вывод:
output
[[ 2.9999976e+00 2.3841858e-07 -2.3841858e-07 3.0000002e+00]
[-4.7683716e-07 2.0000000e+00 -1.1920929e-07 -7.4505806e-08]
[ 9.9999952e-01 1.0000002e+00 2.0000000e+00 1.0000001e+00]
[ 9.9999988e-01 9.9999994e-01 9.9999994e-01 9.9999994e-01]]
Я бы принял небольшую ошибку округления, если бы TF использовал 16 или 8-битные числа с плавающей запятой внутри, но смотрите, например, третью запись в верхнем ряду. Сложение с плавающей точкой и умножение на ноль не приводит к ошибкам округления, но число равно 2.4E-7 от нуля. Это не имеет смысла. Это не частое явление, и иногда может потребоваться от 10 до 20 прогонов, чтобы увидеть.
Я использую пакет TensorFlow 2.1.0, загруженный из PyPI. Компьютер работает под управлением Ubuntu 18.4 Linux и оснащен графическим процессором Nvidia Titan RTX.
Это ожидаемый результат? Я не смог найти ничего в документации, относящейся к этому. Или, скорее, об этой ошибке я должен сообщить разработчикам?