Есть ли способ найти пакетное средство в керасе? - PullRequest
0 голосов
/ 01 апреля 2019

Я пытаюсь реализовать прототип сети в керасе.

Я хотел написать Prototype Layer для расчета прототипов, а затем Distance Layer для вычисления расстояний между объектами запроса и прототипами. В слое поддержки я попытался использовать объекты и метки в качестве входных данных, затем применить маску к тензорным объектам и найти среднее значение по партии. Это, однако, не представляется возможным и не является правильным способом реализации архитектуры в целом.

Вот мой код для слоя расстояния:


def euclidean_distance(x):
    a, b = x
    N, D = K.shape(a)[0], K.shape(a)[1]
    M = K.shape(b)[0]
    a = K.tile(K.expand_dims(a, axis=1), (1, M, 1))
    b = K.tile(K.expand_dims(b, axis=0), (N, 1, 1))
    return K.mean(K.square(a - b), axis=2)


class EuclidianLayer(Lambda):

    def __init__(self):
        super(EuclidianLayer, self).__init__(euclidean_distance)

    def build(self, input_shape):
        super(EuclidianLayer, self).build(input_shape)

    def call(self, inputs, mask=None):
        return Lambda.call(self, inputs)

и для поддерживающего слоя:


class SupportLayer(Layer):

    def __init__(self,  n_way, **kwargs):
        self.n_way = n_way
        super(SupportLayer, self).__init__(**kwargs)

    def build(self, input_shape):
        assert isinstance(input_shape, list)
        self.embedding_length = input_shape[0][1]
        super(SupportLayer, self).build(input_shape)  # Be sure to call this at the end

    def call(self, x):
        assert isinstance(x, list)
        X, y = x
        supports = []
        for i in range(self.n_way):
            y_mask = K.cast(K.equal(y, i), 'float32')
            X_masked = multiply([X, K.tile(y_mask, [1, self.embedding_length])])
            supports.append(K.mean(X_masked, axis=0))
        return K.stack(supports)

    def compute_output_shape(self, input_shape):
        return self.n_way, self.embedding_length

Правильно ли я все делаю, и возможно ли вообще реализовать такую ​​сеть в керасе?

Добро пожаловать на сайт PullRequest, где вы можете задавать вопросы и получать ответы от других членов сообщества.
...