Я строю простую модель в TensorFlow (v 2.1) и сталкиваюсь со странным поведением с tf.gather
- Возможно, я не понимаю, что он делает.
Я рассматривая модель, которая может иметь несколько перехватов (т.е. y = a[i] + X@b
). Я определяю новый слой, как показано ниже,
class GroupedInterceptLinearCoeffs_gather(tf.keras.layers.Layer):
"""
"""
def __init__(self, ngroup=1, **kwargs):
super(GroupedInterceptLinearCoeffs_gather, self).__init__()
self.ngroup = ngroup
def build(self, input_shape):
self.a = self.add_weight(
shape=(self.ngroup,), dtype="float32",
initializer="random_normal", trainable=True
)
self.b = self.add_weight(
shape=(input_shape[1][-1],), dtype="float32",
initializer="random_normal", trainable=True
)
@tf.function()
def call(self, inputs):
out = tf.gather(self.a, inputs[0], axis=0, batch_dims=0) + tf.linalg.matvec(inputs[1], self.b)
return out
, а затем проверяю, что он выполняет то, что я ожидаю с
import numpy as np
import tensorflow as tf
nobs = 100
alpha = 1.0 # To keep things simple, we'll only have one intercept here
beta = np.array([0.0, 0.5, 0.25])
L = np.array([[1.0, 0.0, 0.0], [0.25, 1.1, 0.0], [0.2, 0.2, 1.25]])
X = np.random.randn(nobs, 3) @ L
y = alpha + X@beta
. Проверка того, может ли модель воспроизводить мои данные, показывает, что существует (эффективно) 0 ошибка
gilc_g = GroupedInterceptLinearCoeffs_gather(ngroup=1)
gilc_g([np.zeros((X.shape[0],), dtype=np.int32), X.astype(np.float32)])
gilc_g.set_weights([np.array([alpha], dtype=np.float32), beta.astype(np.float32)])
np.max(
np.abs(
gilc_g(
[np.zeros((X.shape[0],), dtype=np.int32), X.astype(np.float32)]
).numpy() - (alpha + X@beta)
)
)
, но когда я пытаюсь подогнать модель к ней, она быстро перестает делать успехи.
class OLS_gather(tf.keras.Model):
"""
"""
def __init__(self, ngroups=1, name="ols", **kwargs):
super(OLS_gather, self).__init__(name=name, **kwargs)
self.lm = GroupedInterceptLinearCoeffs_gather(ngroups)
def call(self, inputs):
print(inputs[0].shape)
print(inputs[1].shape)
out = self.lm(inputs)
return out
olsmodel_g = OLS_gather(ngroups=1)
olsmodel_g.compile(optimizer, loss=tf.keras.losses.MeanSquaredError())
olsmodel_g.fit([np.zeros((X.shape[0],), dtype=np.int32), X.astype(np.float32)], y.astype(np.float32), epochs=50)
Проверка весов b
показывает, что это не перемещая веса в правильном направлении, но похожая модель (без сбора) быстро сходится (см. этот gist для всего кода). Я неправильно использую tf.gather
? Если да, есть ли другой способ «переиндексировать» массив, подобный этому, для создания дубликатов в определенном порядке?
(Кроме того, я знаю, что мне не обязательно создавать свои собственные слои / модели, но Мой пример немного сложнее, и мне нужна пользовательская функция потерь et c ...)