Почему GradientTape возвращает None, когда я использую numpy math - PullRequest
1 голос
/ 26 января 2020

Почему GradientTape возвращает None, когда я использую numpy math

Я пытаюсь понять вычисление GradientTape tenorflow для функции потерь RL. Когда я вызываю функцию с помощью np.math, GradientTape возвращает None. Если я использую tf.math в функции, она работает нормально. Я смотрел на tf-агентов, таких как ppo и sa c, и они делают именно (?) То, что я пытаюсь сделать (я пробовал в последних 50 других версиях). Что не так в коде ниже? Чего мне не хватает?

окно 10, python 3.6.8, тензор потока 2.0.0 ref: https://github.com/chagmgang/tf2.0_reinforcement_learning/blob/master/policy/ppo.py

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

def my_loss1(x):
    y=tf.sin(x)
    y=tf.abs(y)
    return y

def my_loss2(x):
    y=np.sin(x)
    y=np.abs(y)
    return y

def main(ver):    
    x = np.linspace(0,10,25)
    dsin_dx=np.cos(x)    
    xx = tf.constant(x)
    with tf.GradientTape() as tape:
        tape.watch(xx)
        if ver==0:
            # my_loss1 with tf math 
            loss1=my_loss1(xx)
        if ver==1:
            #my loss with numpy math
            loss1=my_loss2(np.array(xx))    
            loss1 = tf.convert_to_tensor(loss1, dtype=tf.float64)
        print(loss1)
        loss=tf.reduce_sum(loss1)
        print('loss=',loss)
    grads = tape.gradient(loss, xx)

    fig, ax = plt.subplots(2)
    ax[0].plot(x,loss1,'r')
    print('grads', grads)

    if not grads is None:
        ax[1].plot(x, grads)
        ax[1].plot(x,dsin_dx)
    plt.show()

if __name__ == '__main__':
    main(ver=0)  # This works ok
    main(ver=1)  # This returns grads = None 

1 Ответ

0 голосов
/ 28 января 2020

Проблема в том, что лента градиента записывает только тензоры. Numpy переменные не записываются, поэтому градиент не может быть рассчитан в случае ver = 1. Loss1 в ver1 выглядит идентично loss1 в ver = 0, но зависимость от xx разбита на numpy.
Мой исх. имеет эту ошибку, когда вычисление get_gaes () и вычисление градов неверны.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...