Тензорное вычитание
Чтобы вычесть среднее значение изображения из пакета данных изображения, вы можете просто использовать оператор минус (который является синтаксическим сахаром для tf.subtract
):
In [28]: x = tf.zeros((2, 3, 3))
In [29]: x
Out[29]:
<tf.Tensor: id=38, shape=(2, 3, 3), dtype=float32, numpy=
array([[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]],
[[0., 0., 0.],
[0., 0., 0.],
[0., 0., 0.]]], dtype=float32)>
In [30]: mean = tf.eye(3)
In [31]: mean
Out[31]:
<tf.Tensor: id=42, shape=(3, 3), dtype=float32, numpy=
array([[1., 0., 0.],
[0., 1., 0.],
[0., 0., 1.]], dtype=float32)>
In [32]: x - mean
Out[32]:
<tf.Tensor: id=44, shape=(2, 3, 3), dtype=float32, numpy=
array([[[-1., 0., 0.],
[ 0., -1., 0.],
[ 0., 0., -1.]],
[[-1., 0., 0.],
[ 0., -1., 0.],
[ 0., 0., -1.]]], dtype=float32)>
Чтение изображения в тензор
Чтобы получить PNG-изображение в качестве тензора TensorFlow, просто оберните массив numpy с помощью tf.constant
:
import cv2
mean_img = cv2.imread('/path/to/the/image')
mean_img_tensor = tf.constant(mean_img)
Обратите внимание, что OpenCV по умолчанию считывает изображение в цветовое пространство BGR.Вы можете преобразовать его в RGB:
mean_img = cv2.cvtColor(mean_img, cv2.COLOR_BGR2RGB))
Или использовать библиотеку изображений Python:
from PIL import Image
import numpy as np
mean_img = Image.open('/path/to/image')
mean_img_tensor = tf.constant(np.array(mean_img))
Собрать все вместе
Поскольку вы используете TF Dataset API, я считаю, что map_and_batch
должно быть лучшим решением для повышения производительности:
def datasetLoader(dataSetPath, batchSize, mean_image_path):
dataset = tf.data.TFRecordDataset(dataSetPath)
mean_img = cv2.cvtColor(cv2.imread(mean_image_path), cv2.COLOR_BGR2RGB)
mean = tf.constant(mean_img)
dataset = dataset.map(_ds_parser, num_parallel_calls=8)
# This dataset will go on forever
dataset = dataset.repeat()
def preprocess(X, Y):
# Bring the date back in shape
X = tf.reshape(X, [-1, 66, 198, 3])
Y = tf.reshape(Y,[-1,1])
X = X - mean
return X, Y
# Set the batchsize
dataset = dataset.apply(tf.contrib.data.map_and_batch(map_func=preprocess, batch_size=batchSize, num_parallel_calls=8))
return dataset.make_one_shot_iterator().get_next()