Как ограничить переменную тензорного потока значениями int8 или uint8? - PullRequest
0 голосов
/ 12 октября 2019

Вот упрощенная версия моего кода:

import tensorflow as tf
x = tf.zeros((10,10), dtype=tf.dtypes.uint8)
x = tf.Variable(x)

with tf.GradientTape() as t:
    obj = 1- tf.reduce_sum(x) # can be anything

optimizer = tf.optimizers.Adam(0.1)
var_list = [x]

grads = t.gradient(obj, var_list)
optimizer.apply_gradients(zip(grads, var_list))

Я хочу ограничить x значением uint8. Это сделано для того, чтобы убедиться, что когда я вычисляю x и, например, превращаю его в объект Image (Image.fromarray(x.numpy())), вся тензорная информация сохраняется. Я не хочу делать что-то вроде (x.numpy() * 255).astype(np.uint8), так как это приведет к потере информации.

Однако, когда я запускаю приведенный выше код, я получаю

---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-11-02f5e5f2081e> in <module>
     10 
     11 grads = t.gradient(obj, var_list)
---> 12 optimizer.apply_gradients(zip(grads, var_list))

/anaconda3/envs/ml/lib/python3.6/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py in apply_gradients(self, grads_and_vars, name)
    425       ValueError: If none of the variables have gradients.
    426     """
--> 427     grads_and_vars = _filter_grads(grads_and_vars)
    428     var_list = [v for (_, v) in grads_and_vars]
    429 

/anaconda3/envs/ml/lib/python3.6/site-packages/tensorflow_core/python/keras/optimizer_v2/optimizer_v2.py in _filter_grads(grads_and_vars)
   1023   if not filtered:
   1024     raise ValueError("No gradients provided for any variable: %s." %
-> 1025                      ([v.name for _, v in grads_and_vars],))
   1026   if vars_with_empty_grads:
   1027     logging.warning(

ValueError: No gradients provided for any variable: ['patch:0'].

Кажется, чтокогда x равно uint8, я не могу принять его градиент. Как я могу сделать x дифференцируемым и в то же время убедиться, что значения ограничены значениями, разрешенными uint8, чтобы при оценке и некоторой обработке, ожидаемой для тензора uint8, не терялась информация?

...