Keras не использует GPU для K.get_updates и K.function - PullRequest
0 голосов
/ 04 мая 2019

У меня есть ГАН, написанный в керасе.Он имеет 2 сети, в одной из которых используется настраиваемая функция потерь keras и обновления с train_on_batch, это прекрасно работает и использует графический процессор.Вторая сеть обновляется с использованием K.get_updates и K.function, она работает, но, похоже, она тренируется на ЦП, а не на ГП.

Максимальная нагрузка на ГП при обучении в первой сети и затем падает до 0в то время как вторая сеть обучается.

Если я переключаю сеть обратно на обучение с помощью train_on_batch, то она использует графический процессор.Однако мне нужна функциональность get_updates и K.function.

Это моя обучающая функция:

def train_combo(self):
    #input = K.placeholder(shape=[None, 100])
    var = K.placeholder(shape=[None,1])

    loss = K.sqrt(K.square(K.std(self.combo.output)- K.std(var)))
    optimizer = Adam(lr=0.001)
    updates = optimizer.get_updates(self.combo.trainable_weights,[], loss)

    train = K.function([self.combo.input,var], [loss], updates=updates)
    return train

Вот как я ее называю:

train = self.train_combo()
loss = train([vectors ,var])

Я ожидаю, что он будет работать на графическом процессоре, и, кажется, он работает на процессоре

...