У меня есть ГАН, написанный в керасе.Он имеет 2 сети, в одной из которых используется настраиваемая функция потерь keras и обновления с train_on_batch, это прекрасно работает и использует графический процессор.Вторая сеть обновляется с использованием K.get_updates и K.function, она работает, но, похоже, она тренируется на ЦП, а не на ГП.
Максимальная нагрузка на ГП при обучении в первой сети и затем падает до 0в то время как вторая сеть обучается.
Если я переключаю сеть обратно на обучение с помощью train_on_batch, то она использует графический процессор.Однако мне нужна функциональность get_updates и K.function.
Это моя обучающая функция:
def train_combo(self):
#input = K.placeholder(shape=[None, 100])
var = K.placeholder(shape=[None,1])
loss = K.sqrt(K.square(K.std(self.combo.output)- K.std(var)))
optimizer = Adam(lr=0.001)
updates = optimizer.get_updates(self.combo.trainable_weights,[], loss)
train = K.function([self.combo.input,var], [loss], updates=updates)
return train
Вот как я ее называю:
train = self.train_combo()
loss = train([vectors ,var])
Я ожидаю, что он будет работать на графическом процессоре, и, кажется, он работает на процессоре