Почему пространственная трансформаторная сеть (STN) не работает на изображении - PullRequest
0 голосов
/ 04 октября 2018

Я пытаюсь подключить один массив изображений к STN, используя код из https://github.com/kevinzakka/spatial-transformer-network:

def STNfn(x): 
    import tensorflow as tf
    print(x.shape)
    B,W,H,C = x.shape
    # identity transform
    initial = np.array([[1., 0, 0], [0, 1., 0]])
    initial = initial.astype('float32').flatten()
    # localization network  
    n_fc = 6
    W_fc1 = tf.Variable(tf.zeros([H*W*C, n_fc]), name='W_fc1')
    b_fc1 = tf.Variable(initial_value=initial, name='b_fc1')
    h_fc1 = tf.matmul(tf.zeros([B, H*W*C]), W_fc1) + b_fc1 

    # spatial transformer layer
    from stn import spatial_transformer_network as transformer
    h_trans = transformer(x, h_fc1) 
    return h_trans

fname = 'testimage.jpg'
img = plt.imread(fname)
img = STNfn(np.array([img]))

Однако я получаю следующую ошибку:

TypeError: Input 'y' of 'Mul' Op has type uint8 
that does not match type float32 of argument 'x'.

Я пытался заменить float32 на np.uint8, но это не помогает.

Где проблема и как ее можно решить?

1 Ответ

0 голосов
/ 15 октября 2018

n_fc = 6 должен быть float32 может быть?Не знаком с Python, в Java он равен 6.0f для float, и только 6 - целое число.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...