Как избежать использования большой памяти в пользовательской функции потерь в keras - PullRequest
1 голос
/ 06 мая 2020

Я определил пользовательскую функцию потерь в keras. В этой пользовательской функции потерь я извлекаю несмежные значения из y_pred следующим образом:

sel_row = tf.constant([[2],[5],[8]])
row_tmp = y_pred
selected = tf.transpose(tf.gather_nd(tf.transpose(row_tmp), sel_row))

С помощью этого я просто выбираю столбец из тензора. Теперь, если я сделаю то же самое, но для смежных столбцов, т.е. row_tmp[:, 2:5], у меня нет проблем, но с не непрерывными столбцами я получу:

/tensorflow/lib/python3.7/site-packages/tensorflow_core/python/framework/indexed_slices.py:424: 
UserWarning: Converting sparse IndexedSlices to a dense Tensor of unknown shape. 
This may consume a large amount of memory.
  "Converting sparse IndexedSlices to a dense Tensor of unknown shape. "

Все работает, но было бы неплохо иметь лучший способ не использовать слишком много памяти.

Я попытался заменить tf.constant на tf.Variable, но возникла эта ошибка:

ValueError: tf.function-decorated function tried to create variables on non-first call.

Есть какие-нибудь советы?

1 Ответ

1 голос
/ 06 мая 2020

Вы можете просто сделать:

selected = tf.gather(row_tmp, tf.squeeze(sel_row, axis=1), axis=1)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...