tf.keras.backend.clip не дает правильных результатов - PullRequest
0 голосов
/ 16 июня 2020

tf.keras.backend.clip не обрезает тензоры

Когда я использую tf.keras.backend.clip внутри этой функции

def grads_ds(model_ds, ds_inputs,y_true,cw):
    print(y_true)
    with tf.GradientTape() as ds_tape:
        y_pred = model_ds(ds_inputs)
        print(y_pred.numpy())
        logits_1 = -1*y_true*K.log(y_pred)*cw[:,0]
        logits_0 = -1*(1-y_true)*K.log(1-y_pred)*cw[:,1]
        loss = logits_1 + logits_0
        loss_value_ds = K.sum(loss)

    ds_grads = ds_tape.gradient(loss_value_ds,model_ds.trainable_variables,unconnected_gradients=tf.UnconnectedGradients.NONE)
    for g in ds_grads:
        g = tf.keras.backend.clip(g,min_grad,max_grad)
    return loss_value_ds, ds_grads

Значение градиентов остается прежним (не обрезано).

Когда я использую tf.keras.backend.clip внутри пользовательского обучения l oop, точно так же

for g in ds_grads:
    g = tf.keras.backend.clip(g,min_grad,max_grad)

это не работает. Градиент, применяемый к переменным, не обрезается.

Однако, если я напечатаю g в пределах l oop, тогда будет показано обрезанное значение.

Не могу понять, где находится проблема есть.

1 Ответ

1 голос
/ 17 июня 2020

Это потому, что g в вашем примере является ссылкой на значение в списке. Когда вы назначаете ему, вы просто меняете значение, на которое он указывает (ie вы не изменяете текущее значение, на которое он указывает). Рассмотрим этот пример, я хочу установить все значения в lst на 5. Угадайте, что происходит, когда вы запускаете этот образец кода?

lst = [1,2,3,4]
for ele in lst:
    ele = 5
print(lst)

Ничего! Вы получите тот же список обратно. Однако внутри l oop вы увидите, что ele теперь 5, как вы уже выяснили в вашем случае. Это был тот случай, когда значения в списке неизменяемы (тензоры неизменны).

Однако вы можете изменять изменяемые объекты на месте:

lst = [[2], [2], [2]]
for ele in lst:
    ele.append(3)
print(lst)

Приведенный выше код сделает каждый элемент [2, 3] как и ожидалось.

Один из способов решения вашей проблемы:

lst = [1,2,3,4]
for itr in range(len(lst)):
    lst[itr] = 5
print(lst)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...