Почему get_tensor_by_name не может правильно получить веса слоев, определенные tf.keras.layers - PullRequest
1 голос
/ 06 мая 2019

Я пытаюсь получить веса слоев, определенных как tf.keras.layers, используя get_tensor_by_name в tensorflow.Код представлен следующим образом:

# encoding: utf-8
import tensorflow as tf

x = tf.placeholder(tf.float32, (None,3))
h = tf.keras.layers.dense(3)(x)
y = tf.keras.layers.dense(1)(h)

for tn in tf.trainable_variables():
    print(tn.name)

sess = tf.Session()
sess.run(tf.global_variables_initializer())
w = tf.get_default_graph().get_tensor_by_name("dense/kernel:0")
print(sess.run(w))

Название веса: dense/kernel:0.Однако вывод sess.run(w) является странным

[( 10,) ( 44,) ( 47,) (106,) (111,) ( 98,) ( 58,) (108,) (111,) ( 99,)
 ( 97,) (108,) (104,) (111,) (115,) (116,) ( 47,) (114,) (101,) 
 ... ]

, который не является массивом с плавающей точкой.На самом деле, если я использую tf.layers.dense для определения сети, все идет хорошо.Поэтому мой вопрос заключается в том, как я могу получить веса слоев, определенные tf.keras.layers, правильно используя тензорное имя.

1 Ответ

0 голосов
/ 06 мая 2019

Вы можете использовать get_weights() для слоев, чтобы получить значения веса отдельных слоев. Вот пример кода для вашего случая:

import tensorflow as tf

input_x = tf.placeholder(tf.float32, [None, 3], name='x')    
dense1 = tf.keras.Dense(3, activation='relu')
l1 = dense1(input_x)
dense2 = tf.keras.Dense(1)
y = dense2(l1)

weights = dense1.get_weights()

С помощью Keras API это можно сделать еще проще:

def mymodel():
    i = Input(shape=(3, ))
    x = Dense(3, activation='relu')(i)
    o = Dense(1)(x)

    model = Model(input=i, output=o)
    return model


model = mymodel()

names = [weight.name for layer in model.layers for weight in layer.weights]
weights = model.get_weights()

for name, weight in zip(names, weights):
    print(name, weight.shape)

Этот пример получает весовые матрицы для каждого слоя вашей модели.

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...