TL; DR : используйте tf.nn.convolution()
вместо tf.nn.atrous_conv2d()
и короткое замыкание tf.nn._nn_ops._get_strides_and_dilation_rate()
. (См. Пример кода внизу ответа.)
TS; WM
Реализация tf.nn.atrous_conv2d()
в основном просто вызывает tf.nn.convolution()
с dilation_rate
, установленным на [rate, rate]
. Это текущая позиция в источнике, но я копирую ее здесь (без 130 строк комментариев), потому что она может измениться, так что ссылка может устареть:
def atrous_conv2d(value, filters, rate, padding, name=None):
return convolution(
input=value,
filter=filters,
padding=padding,
dilation_rate=np.broadcast_to(rate, (2,)),
name=name)
Здесь ясно, что np.broadcast_to()
делает невозможным использование тензора для dilation_rate
. Может быть, в tf.nn.convolution()
тогда?
Ну, теоретически, да, до вызова _get_strides_and_dilation_rate()
, который бесполезно имеет такие вещи, как
dilation_rate = np.array(dilation_rate, dtype=np.int32)
После этого классу Convolution
все равно, и _WithSpaceToBatch.__init__()
сразу же преобразует его в тензор:
dilation_rate = ops.convert_to_tensor(
dilation_rate, dtypes.int32, name="dilation_rate")
Так что, если вы уверены, что с вашими параметрами все в порядке, вы можете коротко замкнуть tf.nn._get_strides_and_dilation_rate()
, а затем вызвать tf.nn.convolution()
напрямую с двухкратной степенью расширения, как этот код (проверено):
import tensorflow as tf
square_size = 5
dr = tf.placeholder( shape = ( 2, ), dtype = tf.int32 )
inp = tf.reshape( tf.constant( range( square_size * square_size ), dtype = tf.float32 ), ( 1, square_size, square_size, 1 ) )
fltr = tf.reshape( tf.ones( ( 3, 3 ) ), ( 3, 3, 1, 1 ))
_original = tf.nn._nn_ops._get_strides_and_dilation_rate
tf.nn._nn_ops._get_strides_and_dilation_rate = lambda a, b, c : ( b, c )
state = tf.nn.convolution( inp, fltr, "SAME", strides = [ 1, 1 ], dilation_rate = dr )
tf.nn._nn_ops._get_strides_and_dilation_rate = _original
with tf.Session() as sess:
print( sess.run( tf.squeeze( inp ) ) )
print
print( sess.run( tf.squeeze( state ), feed_dict = { dr : [ 2, 2 ] } ) )
print
print( sess.run( tf.squeeze( state ), feed_dict = { dr : [ 3, 3 ] } ) )
выведет результаты с динамически изменяемой степенью расширения (tf.squeeze
только для разборчивости):
[[0. 1. 2. 3. 4.]
[5. 6. 7. 8. 9.]
[10. 11. 12. 13. 14.]
[15. 16. 17. 18. 19.]
[20. 21. 22. 23. 24.]]
[[24. 28. 42. 28. 32.]
[44. 48. 72. 48. 52.]
[66. 72. 108. 72. 78.]
[44. 48. 72. 48. 52.]
[64. 68. 102. 68. 72.]]
[[36. 40. 19. 36. 40.]
[56. 60. 29. 56. 60.]
[23. 25. 12. 23. 25.]
[36. 40. 19. 36. 40.]
[56. 60. 29. 56. 60.]]