Взвешенная сумма входа внутри сети - PullRequest
0 голосов
/ 03 декабря 2018

У меня есть сеть с несколькими входами, и я разделил первые 10 входов и вычислил взвешенную сумму, а затем объединил ее с остальными входными данными:

first = Lambda(lambda z: z[:, 0:11])(d_inputs)
wsum_first = Lambda(calcWSumF)(first )
d_input = concatenate([d_inputs, wsum_first], axis=-1)

с функцией, определенной как:

w_vec = K.constant(np.array([range(10)]*64).reshape(10, 64)) # batch size is 64
def calcWSumF(x):
    y = K.dot(w_vec, x)
    y = K.expand_dims(y, -1)       
    return y

Я хочу использовать постоянный вектор для расчета взвешенной суммы первой части входных данных.Конкатенация не работает, потому что формы не совпадают.Как я могу реализовать это правильно?

1 Ответ

0 голосов
/ 03 декабря 2018

Вы можете написать это намного лучше, используя K.sum и только вектор, содержащий коэффициенты.Кроме того, нет необходимости использовать фиксированный размер пакета (это может быть любое число):

def calcWSumF(x, idx):
    w_vec = K.constant(np.arange(idx))
    y = K.sum(x[:, 0:idx] * w_vec, axis=-1, keepdims=True)
    return y

d_inputs = Input((15,))
wsum_first = Lambda(calcWSumF, arguments={'idx': 10})(d_inputs)
d_input = concatenate([d_inputs, wsum_first], axis=-1)

model = Model(d_inputs, d_input)
model.predict(np.arange(15).reshape(1, 15))

# output:
array([[  0.,   1.,   2.,   3.,   4.,   5.,   6.,   7.,   8.,   9.,  10.,
         11.,  12.,  13.,  14., 285.]], dtype=float32)

# Note: 0*0 + 1*1 + 2*2 + ... + 9*9 = 285

Обратите внимание, что, чтобы сделать его более общим, мы добавили еще один аргумент (idx) кЛямбда-функция, которая определяет, сколько элементов с начала мы бы хотели рассмотреть.

...