К сожалению, tf.contrib.lookup.HashTable
работает только с одномерными тензорами. Вот реализация с tf.SparseTensor
s, которая, конечно, работает, только если ваши ключи являются целочисленными (int32 или int64) тензорами.
Для значений я храню два столбца в двух отдельных тензорах, но если у вас много столбцов, вы можете просто сохранить их в большом тензоре и сохранить индексы в виде значений в одном tf.SparseTensor
.
Этот код (проверено):
import tensorflow as tf
lookup = tf.placeholder( shape = ( 2, ), dtype = tf.int64 )
default_value = tf.constant( [ 1, 1 ], dtype = tf.int64 )
input_tensor = tf.constant( [ 1, 1 ], dtype=tf.int64)
keys = tf.constant( [ [ 1, 2 ], [ 3, 4 ], [ 5, 6 ] ], dtype=tf.int64 )
values = tf.constant( [ [ 4, 1 ], [ 5, 1 ], [ 6, 1 ] ], dtype=tf.int64 )
val0 = values[ :, 0 ]
val1 = values[ :, 1 ]
st0 = tf.SparseTensor( keys, val0, dense_shape = ( 7, 7 ) )
st1 = tf.SparseTensor( keys, val1, dense_shape = ( 7, 7 ) )
x0 = tf.sparse_slice( st0, lookup, [ 1, 1 ] )
y0 = tf.reshape( tf.sparse_tensor_to_dense( x0, default_value = default_value[ 0 ] ), () )
x1 = tf.sparse_slice( st1, lookup, [ 1, 1 ] )
y1 = tf.reshape( tf.sparse_tensor_to_dense( x1, default_value = default_value[ 1 ] ), () )
y = tf.stack( [ y0, y1 ], axis = 0 )
with tf.Session() as sess:
print( sess.run( y, feed_dict = { lookup : [ 1, 2 ] } ) )
print( sess.run( y, feed_dict = { lookup : [ 1, 1 ] } ) )
выведет:
[4 1]
[1 1]
по желанию (ищет значение [4, 1] для клавиши [1, 2] и значение по умолчанию значение [1, 1] для [1, 1] , которое указывает на несуществующую запись.)