Предложение 1: Моя попытка заключается в следующем, поскольку он просто использует tf.image.extract_image_patches
и tf.extract_volume_patches
, реализация поддерживает только 2d и 3d изображения.
Предложение 2: можно просто отформатировать данные какшаг предварительной обработки (через tf.data.Dataset.map
), однако это также занимает много времени, пока я не уверен, почему (пример https://gist.github.com/pangyuteng/ca5cb07fe383ebe59b521c832f2e2918).
Предложение 3: использовать сверточные блоки для распараллеливания обработки, см.«Гиперколонки для сегментации объекта и мелкозернистой локализации» https://arxiv.org/abs/1411.5752.
-
Код предложения 1:
import time
import matplotlib.pyplot as plt
%matplotlib inline
import numpy as np
import tensorflow as tf
from tensorflow.contrib.data.python.ops import sliding
from skimage import img_as_float, data
from scipy.signal import medfilt
dtype = 2
if dtype==2:
imgs = img_as_float(data.camera())
elif dtype==3:
imgs = np.random.rand(28,28,28)
imgs = img_as_float(data.camera())
### SCIPY median ###
stime = time.time()
scipysmoothed = medfilt(imgs,(9,9))
etime = time.time()
print('scipy smoothed: {:1.4f} seconds'.format(etime-stime))
### TF median ###
method = 'Tensorflow'
imgs = np.expand_dims(imgs,axis=-1)
imgs = np.expand_dims(imgs,axis=0)
print('imgs.shape:{}'.format(imgs.shape))
imgs = tf.cast(imgs,tf.float32)
stime = time.time()
if len(imgs.shape) == 4:
kernel=(1,9,9,1)
stride=(1,1,1,1)
rates=(1,1,1,1)
padding='SAME'
patches=tf.image.extract_image_patches(
imgs,kernel,stride,rates,padding,
)
_,x,y,n = patches.shape
_,sx,sy,_ = kernel
window_func = lambda x: tf.contrib.distributions.percentile(x, 50.0)
patches = tf.reshape(patches,[x*y,sx,sy])
smoothed = tf.map_fn(lambda x: window_func(patches[x,:,:]), tf.range(x*y), dtype=tf.float32)
smoothed = tf.reshape(smoothed,[x,y])
elif len(imgs.shape) == 5:
kernel=(1,12,12,12,1)
stride=(1,1,1,1,1)
padding='SAME'
patches=tf.extract_volume_patches(
imgs,kernel,stride,padding,
)
_,x,y,z,n = patches.shape
_,sx,sy,sz,_ = kernel
window_func = lambda x: tf.contrib.distributions.percentile(x, 50.0)
patches = tf.reshape(patches,[x*y*z,sx,sy,sz])
smoothed = tf.map_fn(lambda x: window_func(patches[x,:,:]), tf.range(x*y*z), dtype=tf.float32)
smoothed = tf.reshape(smoothed,[x,y,z])
else:
raise NotImplemented()
with tf.Session() as sess:
output = sess.run(smoothed)
etime = time.time()
print('tf smoothed: {:1.4f} seconds'.format(etime-stime))
print(output.shape)
plt.figure(figsize=(20,20))
plt.subplot(131)
imgs = img_as_float(data.camera())
plt.imshow(imgs.squeeze(),cmap='gray',interpolation='none')
plt.title('original')
plt.subplot(132)
plt.imshow(output.squeeze(),cmap='gray',interpolation='none')
plt.title('actual smoothed\nwith {}'.format(method))
plt.subplot(133)
plt.imshow(scipysmoothed,cmap='gray',interpolation='none')
_=plt.title('expected smoothed')