Как переместить модель tenorflow.keras в графический процессор - PullRequest
0 голосов
/ 06 января 2020

Допустим, у меня есть модель keras, подобная этой:

with tf.device("/CPU"):
    model = tf.keras.Sequential([
    # Adds a densely-connected layer with 64 units to the model:
    tf.keras.layers.Dense(64, activation='relu', input_shape=(32,)),
    # Add another:
    tf.keras.layers.Dense(64, activation='relu'),
    # Add a softmax layer with 10 output units:
    tf.keras.layers.Dense(10, activation='softmax')])

Я хотел бы переместить эту модель на графический процессор.

Я пытался сделать это:

with tf.device("/GPU:0"):
    gpu_model = tf.keras.models.clone_model(model)

Но проблема в том, что имена переменных меняются. Например:

Имя веса первого слоя model: Получено из model.layers[0].weights[0].name

'density / kernel: 0'

Но имя веса первого слоя gpu_model: Получено из gpu_model.layers[0].weights[0].name

'density_3 / kernel: 0'

Как я могу сделать это преобразование графического процессора, пока также сохраняя имена переменных?

Я не хочу сохранять модель на диск и загружать снова

1 Ответ

0 голосов
/ 06 января 2020

Я отвечаю на свой вопрос. Если у кого-то есть лучшее решение. Пожалуйста, опубликуйте это

Это обходной путь, который я нашел:

  1. Создайте state_dict как PyTorch
  2. Получите архитектуру модели как JSON
  3. Очистите сеанс Keras и удалите экземпляр модели
  4. Создайте новую модель из JSON в tf.device контексте
  5. Загрузите предыдущие веса из state_dict
state_dict = {}
for layer in model.layers:
    for weight in layer.weights:
        state_dict[weight.name] = weight.numpy()

model_json_config = model.to_json()
tf.keras.backend.clear_session() # this is crucial to get previous names again
del model

with tf.device("/GPU:0"):
    new_model = tf.keras.models.model_from_json(model_json_config)

for layer in new_model.layers:
    current_layer_weights = []
    for weight in layer.weights:
        current_layer_weights.append(state_dict[weight.name])
    layer.set_weights(current_layer_weights)
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...