Как построить и повторно использовать сети через две входные ветви сети? - PullRequest
0 голосов
/ 03 апреля 2019

Как сделать что-то подобное?

nn = get_networks()
A = nn(X_input)
B = nn(X_other_input)
C = A + B
model = ... 

Так что все тензоры в nn одинаковы, отличаются только ветки ввода-обучения?

В чистом тензорном потоке вы делаете это с

tf.variable_scope('something', reuse=tf.AUTO_REUSE):
       define stuff here

и осторожно присваивая названия слоям.

Но, в основном, вы можете создать nn в первую очередь, потому что вы не можете передать невызванный слой в слой call !

Например:

In [21]: tf.keras.layers.Dense(16)(tf.keras.layers.Dense(8))
...
AttributeError: 'Dense' object has no attribute 'shape'

UPDATE:

Я достиг этого, создав нескомпилированную модель в качестве подсети. Эта «модель» затем может быть передана другим функциям создания сети. Например, если у вас есть функциональное уравнение, которое вы хотите решить, вы можете аппроксимировать функцию сетью, а затем передать сеть функции, которая сама является сетью.

1 Ответ

0 голосов
/ 03 апреля 2019

Зависит от того, как вы хотите использовать его повторно, но идея состоит в том, чтобы сохранить ваши слои после инициализации и использовать их несколько раз позже.

import tensorflow as tf
import tensorflow.keras as keras
import tensorflow.keras.layers as layers
import numpy as np

layers = {}

def net1(input):
    layers["l1"] = keras.layers.Dense(10)
    layers["l2"] = keras.layers.Dense(10)

    return layers["l1"](layers["l2"](keras.layers.Flatten()(input)))

def net2(input):
    return layers["l1"](layers["l2"](keras.layers.Flatten()(input)))

input1 = keras.layers.Input((2, 2))
input2 = keras.layers.Input((2, 2))

model1 = keras.Model(inputs=input1, outputs=net1(input1)) 
model1.compile(loss=keras.losses.mean_squared_error, optimizer=keras.optimizers.Adam())

model2 = keras.Model(inputs=input2, outputs=net2(input2)) 
model2.compile(loss=keras.losses.mean_squared_error, optimizer=keras.optimizers.Adam())

x = np.random.randint(0, 100, (50, 2, 2))
m1 = model1.predict(x)
m2 = model2.predict(x)
print(x[0])
print(m1[0])
print(m2[0])

Выходы идентичны:

[ 10.114908  -13.074531   -8.671929  -59.03201    55.389366    1.3610549
 -38.051434    8.355987    7.5310936 -27.717983 ]
[ 10.114908  -13.074531   -8.671929  -59.03201    55.389366    1.3610549
 -38.051434    8.355987    7.5310936 -27.717983 ]
Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...