Как мне «горячо закодировать» набор данных Tensorflow? - PullRequest
0 голосов
/ 30 ноября 2018

Сейчас здесь ... Я загрузил набор данных TF следующим образом:

dataset = tf.data.TFRecordDataset(files)
dataset.map(extract_fn)

Набор данных содержит «строковый столбец» с некоторыми значениями, и я хочу «один горячий» кодировать их.Я мог бы сделать это в записи extract_fn по записи, если бы у меня были индексы и глубина (у меня есть только строковое значение на данный момент).Однако есть ли функция TF, которая могла бы сделать это для меня?т.е.

  • Подсчет количества различных значений
  • Сопоставить каждое значение с индексом
  • Создать для этого столбца с горячим кодированием

1 Ответ

0 голосов
/ 30 ноября 2018

Я думаю, что это то, что вы хотите:

import tensorflow as tf

def one_hot_any(a):
    # Save original shape
    s = tf.shape(a)
    # Find unique values
    values, idx = tf.unique(tf.reshape(a, [-1]))
    # One-hot encoding
    n = tf.size(values)
    a_1h_flat = tf.one_hot(idx, n)
    # Reshape to original shape
    a_1h = tf.reshape(a_1h_flat, tf.concat([s, [n]], axis=0))
    return a_1h, values

# Test
x = tf.constant([['a', 'b'], ['a', 'd'], ['c', 'd'], ['b', 'd']])
x_1h, x_vals = one_hot_any(x)
with tf.Session() as sess:
    print(*sess.run([x_1h, x_vals]), sep='\n')

Вывод:

[[[1. 0. 0. 0.]
  [0. 1. 0. 0.]]

 [[1. 0. 0. 0.]
  [0. 0. 1. 0.]]

 [[0. 0. 0. 1.]
  [0. 0. 1. 0.]]

 [[0. 1. 0. 0.]
  [0. 0. 1. 0.]]]
[b'a' b'b' b'd' b'c']

Проблема, однако, заключается в том, что разные входы будут давать несогласованные выходы с разными порядками значений илидаже разную глубину, поэтому я не уверен, что это действительно полезно.

...