(керас) введите тензор и индекс, получите тензор, который я хочу - PullRequest
0 голосов
/ 27 сентября 2019

Мой вход подобен массиву (3,3,2) и массиву (3,3):

img = np.array([[[1,1],[2,2],[3,3]],
                [[4,4],[5,5],[6,6]],
                [[7,7],[8,8],[9,9]]])

idx = np.array([[1,0,0],
                [0,0,1],
                [1,1,0]])

Мой идеальный вывод должен быть:

[[1 1]
 [6 6]
 [7 7]
 [8 8]]

Iхочу сделать это с помощью пользовательского слоя:

  1. сделать слой:
def extract_layer(data, idx):

    idx = tf.where(idx)
    data = tf.gather_nd(data,idx)
    data = tf.reshape(data,[-1,2])

    return data
сделать в модель:
input_data = kl.Input(shape=(3,3,2))
input_idxs = kl.Input(shape=(3,3))
extraction = kl.Lambda(lambda x:extract_layer(*x),name='extraction')([input_data,input_idxs])

Я могу построить модель, и я могу увидеть сводку керас модели, на выходе будет

model = Model(inputs=([input_data,input_idxs]), outputs=extraction)
model.summary()

...
input_1 (InputLayer)            (None, 3, 3, 2) 
input_2 (InputLayer)            (None, 3, 3) 
extraction (Lambda)             (None, 2)
Total params: 0
...

, но когдая начинаю предсказывать как:

'i have already made the two inputs into (1,3,3,2) and (1,3,3) shape'
result = model.predict(x=([img,idx]))

он получает ошибку:

'ValueError: could not broadcast input array from shape (4,2) into shape (1,2)'

я думаю, что тензор формы (4,2) - это значение, которое я хочу, но я не знаюпочему keras передает его (1,2)

, есть кто-нибудь, кто может мне помочь ??

большое спасибо!

1 Ответ

0 голосов
/ 27 сентября 2019

В вашей функции extract_layer() data - тензор двух димсов.Но model.predict должен возвращать результаты с дополнительным затемнением партии.Просто разверните dim, когда return data in extract_layer() может исправить эту ошибку.

def extract_layer(data, idx):

    idx = tf.where(idx)
    data = tf.gather_nd(data,idx)
    data = tf.reshape(data,[-1,2])

    return tf.expand_dims(data, axis=0)

Примечание : так как результаты, возвращаемые tf.gather_nd, могут иметь различную длину, я думаюразмер партии будет только 1. Пожалуйста, исправьте меня, если я ошибаюсь.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...