Как использовать несколько графических процессоров с сохраненной и восстановленной моделью keras - PullRequest
0 голосов
/ 21 мая 2019

Я тренируюсь, чтобы обучать и распределять свою функциональную модель keras на различные графические процессоры в моей системе. Это работает на первом этапе, после проверки моего использования GPU (команда: watch -n0.5 nvidia-smi). После восстановления функционала распараллеливание для разных ветвей модели все равно не работает в процессе обучения.

import tensorflow as tf
import numpy as np
from tensorflow.keras.layers import Input, Dense, Concatenate
from tensorflow.keras.models import Model, load_model, Sequential
with tf.device('/gpu:0'):
    Input_1=Input(shape=(256,),name="Input_1")
    Dense_1=Dense(128)(Input_1)
with tf.device('/gpu:1'):
    Input_2=Input(shape=(256,),name="Input_2")
    Dense_2=Dense(128)(Input_2)
with tf.device('/gpu:2'):
    Input_3=Input(shape=(256,),name="Input_3")
    Dense_3=Dense(128)(Input_3)
with tf.device('/cpu:0'):
    Concatenate_=Concatenate()([Dense_1,Dense_2,Dense_3])
    output=Dense(1)(Concatenate_)
model=Model(inputs=[Input_1,Input_2,Input_3],outputs=[output])
model=tf.keras.utils.multi_gpu_model(model,gpus=3)
model.compile(optimizer="sgd",loss="mean_squared_error")
model.save("./test.h5")
input_dict={}
for k in ["Input_1","Input_2","Input_3"]:
    input_dict.update({k: np.random.standard_normal((10000,256))})
    output_dict={model.layers[-1].name: np.random.standard_normal((10000,1))}
test_0=model.fit(x=input_dict,y=output_dict,batch_size=128,epochs=20)
del model

model=load_model("./test.h5")
test_1=model.fit(x=input_dict,y=output_dict,batch_size=128,epochs=20)

После восстановления модели в последней части кода модель использует только один графический процессор. Удаление строки "model = tf.keras.utils.multi_gpu_model (model, gpus = 3)" не помогает в конце концов. Мои 3 графических процессора - GEFORCE RTX 2080 Ti, и я использую Anacondaa в Ubuntu. Как заставить модель работать с распараллеливанием на GPU, когда я восстанавливаю ее через «load_model»?

...