Предоставляет ли TensorFlow наполовину нормальный инициализатор? - PullRequest
0 голосов
/ 21 апреля 2020

TensorFlow обеспечивает random_normal_initializer. Однако мне нужен инициализатор, который выдает числа от 0 до N, но при этом увеличивается плотность вокруг определенного значения x (выбранного пользователем) в диапазоне [0, N] (где N может быть 1), поэтому я не могу использовать унифицированный инициализатор (потому что он одинаково помещает массу на все значения).

Я думаю, что инициализатор, который производит HalfNormal, будет в порядке.

TF уже предоставляет это или мне нужно реализовать собственный инициализатор?

Если мне нужно реализовать собственный инициализатор, каков типичный способ сделать это? Я предполагаю, что могу унаследовать от класса инициализатора , но я не знаю, является ли это обычным способом ведения дел.

Эта проблема также была поднята в проблеме TF трекер *. * 1025

1 Ответ

0 голосов
/ 22 апреля 2020

Согласно этой странице документации , мы можем реализовать пользовательский инициализатор, определив функцию, которая возвращает начальное значение. Затем мы передаем этот объект функции (т.е. вы не вызываете функцию) инициализатору.

Вот пример (в TensorFlow 2.1), который делает то, что я хочу.

import tensorflow as tf


def random_half_normal(shape, **kwargs):
    return tf.abs(tf.keras.backend.random_normal(shape, **kwargs))


class MyLayer(tf.keras.layers.Layer):
    def build(self, input_shape):
        self.my_var = self.add_weight(initializer=random_half_normal, 
                                      trainable=False)

    def call(self, inputs):
        tf.print("\nself.my_var =", self.my_var)
        return inputs


def get_model():
    inp = tf.keras.layers.Input(shape=(1,))
    out = MyLayer(8)(inp)
    model = tf.keras.Model(inputs=inp, outputs=out)
    model.summary()
    return model


def train():
    model = get_model()
    model.compile(optimizer="adam", loss="mae")
    x_train = [2, 3, 4, 1, 2, 6]
    y_train = [1, 0, 1, 0, 1, 1]
    model.fit(x_train, y_train)


if __name__ == '__main__':
    train()
...