Как переставить оси тензора в кератах? - PullRequest
0 голосов
/ 21 января 2019

Мой код похож на 100

 user_item_matrix = K.constant(user_item_matrix)
    # Input variables
    user_input = Input(shape=(1,), dtype='int32', name='user_input')
    item_input = Input(shape=(1,), dtype='int32', name='item_input')
    # Embedding layer
    user_rating = Lambda(lambda x: tf.gather(user_item_matrix, tf.to_int32(x), axis=0))(user_input)
    item_rating = Lambda(lambda x: tf.gather(user_item_matrix, tf.to_int32(x), axis=1))(item_input)

где user_item_matrix - матрица 6040 * 3706. Предполагается, что формами user_rating и item_rating являются (?, 3706) и (?, 6040). Однако реальная ситуация такова:

user_rating:  (?, 1, 3706)
item_rating:  (6040, ?, 1)

Я не понимаю, почему 6040 произошло на оси 0, где это должно быть? (размер партии). Я пытаюсь использовать Permute и Reshape для решения этой проблемы, но все еще не работает. Есть ли хорошее решение для решения такой проблемы? Спасибо.

1 Ответ

0 голосов
/ 21 января 2019

Вы можете видеть документ о tf.gather():

Создает выходной тензор с формой params.shape [: axis] + indices.shape + params.форма [ось + 1:]

Ваша форма параметров (6040,3706) и форма индексов (?,1).

Таким образом, форма вывода params.shape[:0] + indices.shape + params.shape[1:] = () + (?,1) + (3706,) при установке axis=0.

И форма вывода params.shape[:1] + indices.shape + params.shape[2:] = (6040,) + (?,1) + () при установке axis=1.

Вы можете использовать tf.transpose() для перестановки осей.

import tensorflow as tf
import keras.backend as K
from keras.layers import Input,Lambda
import numpy as np

user_item_matrix = K.constant(np.zeros(shape=(6040,3706)))
# Input variables
user_input = Input(shape=(1,), dtype='int32', name='user_input')
item_input = Input(shape=(1,), dtype='int32', name='item_input')
# Embedding layer
user_rating = Lambda(lambda x: tf.gather(user_item_matrix, tf.to_int32(x), axis=0))(K.squeeze(user_input,axis=1))
item_rating = Lambda(lambda x: tf.transpose(tf.gather(user_item_matrix, tf.to_int32(x), axis=1),(1,0)))(K.squeeze(item_input,axis=1))
print(user_rating.shape)
print(item_rating.shape)

# print
(?, 3706)
(?, 6040)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...