Я пытаюсь подключить один массив изображений к STN, используя код из https://github.com/kevinzakka/spatial-transformer-network:
def STNfn(x):
import tensorflow as tf
print(x.shape)
B,W,H,C = x.shape
# identity transform
initial = np.array([[1., 0, 0], [0, 1., 0]])
initial = initial.astype('float32').flatten()
# localization network
n_fc = 6
W_fc1 = tf.Variable(tf.zeros([H*W*C, n_fc]), name='W_fc1')
b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
h_fc1 = tf.matmul(tf.zeros([B, H*W*C]), W_fc1) + b_fc1
# spatial transformer layer
from stn import spatial_transformer_network as transformer
h_trans = transformer(x, h_fc1)
return h_trans
fname = 'testimage.jpg'
img = plt.imread(fname)
img = STNfn(np.array([img]))
Однако я получаю следующую ошибку:
TypeError: Input 'y' of 'Mul' Op has type uint8
that does not match type float32 of argument 'x'.
Я пытался заменить float32 на np.uint8
, но это не помогает.
Где проблема и как ее можно решить?