Почему tape.gradient не возвращает ничего в моей последовательной модели? - PullRequest
0 голосов
/ 14 июля 2020

Мне нужно вычислить градиенты этой модели:

model=Sequential()
model.add(Dense(40, activation='relu',input_dim=12))
model.add(Dense(60, activation='relu'))
model.add(Dense(units=3, activation='softmax'))
opt=tf.keras.optimizers.Adam(lr=0.001)
model.compile(loss="mse", optimizer=opt)

model_q=Sequential()
model_q.add(Dense(40, activation='relu',input_dim=15))
model_q.add(Dense(60, activation='relu'))
model_q.add(Dense(units=1, activation='linear'))
opt=tf.keras.optimizers.Adam(lr=0.001)
model_q.compile(loss="mse", optimizer=opt)

x=np.random.random(12)
x2=model.predict(x.reshape(-1,12))
with tf.GradientTape() as tape:
            value = model_q([tf.convert_to_tensor(np.append(x,x2).reshape(-1,15))])
            loss = -tf.reduce_mean(value)
grad = tape.gradient(loss, model.trainable_variables)
opt.apply_gradients(zip(grad, model.trainable_variables))

, но grad не возвращает ничего, поэтому opt не может применить градиенты к модели. Почему это происходит? Я знаю, что это довольно странная потеря, но я бы хотел ее вычислить

1 Ответ

1 голос
/ 15 июля 2020

Ваш model не записывается на ленту. Вы должны поместить вычисления в контекст ленты, если хотите получить градиенты.

model=Sequential()
model.add(Dense(40, activation='relu',input_dim=12))
model.add(Dense(60, activation='relu'))
model.add(Dense(units=3, activation='softmax'))
opt=tf.keras.optimizers.Adam(lr=0.001)

model_q=Sequential()
model_q.add(Dense(40, activation='relu',input_dim=15))
model_q.add(Dense(60, activation='relu'))
model_q.add(Dense(units=1, activation='linear'))
opt=tf.keras.optimizers.Adam(lr=0.001)

x=np.random.random(12).reshape(-1,12)
with tf.GradientTape() as tape:
  x2 = model([x])
  value = model_q([tf.concat((x,x2), -1)])
  loss = -tf.reduce_mean(value)
grad = tape.gradient(loss, model.trainable_variables)
opt.apply_gradients(zip(grad, model.trainable_variables))
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...