Как выполнять линейные алгебраические функции с символами mxnet, чтобы написать пользовательскую функцию потерь (например, онлайн триплет майнинг)?)) - PullRequest
0 голосов
/ 31 декабря 2018

Я ссылался на эту реализацию в тензорном потоке.Требуется форма выходных пакетных вложений, но я не могу получить реальную форму символа mxnet.Есть идеи как?

1 Ответ

0 голосов
/ 04 января 2019

Вы можете использовать infer_shape (), чтобы получить форму символа.Для реализации потери триплета в MXNet вы можете проверить эту тему: https://github.com/apache/incubator-mxnet/issues/6909 Соответственно, вы можете реализовать ее следующим образом:

kernels = [(1, feature_size), (2, feature_size), (3, feature_size)]
for i in range(len(kernels)):
    conv_weight.append(mx.sym.Variable('conv' + str(i) + '_weight'))
    conv_bias.append(mx.sym.Variable('conv' + str(i) + '_bias'))

fa = get_conv(data=anchor,
              kernels=kernels, conv_weight=conv_weight, conv_bias=conv_bias,
              entity_weight=entity_weight, entity_bias=entity_bias,
              feature_name='fa')  # share weight.
fs = get_conv(data=same,
              kernels=kernels, conv_weight=conv_weight, conv_bias=conv_bias,
              entity_weight=entity_weight, entity_bias=entity_bias,
              feature_name='fs')
fd = get_conv(data=diff, 
              kernels=kernels, conv_weight=conv_weight, conv_bias=conv_bias,
              entity_weight=entity_weight, entity_bias=entity_bias,
              feature_name='fd') 

"""
triple-loss
"""
fs = fa - fs
fd = fa - fd
fs = fs * fs
fd = fd * fd
fs = mx.sym.sum(fs, axis=1, keepdims=1)
fd = mx.sym.sum(fd, axis=1, keepdims=1)
loss = fd - fs
loss = one - loss  # a scalar
loss = mx.sym.Activation(data=loss, act_type='relu')  # acts like a norm.
triple_loss = mx.sym.MakeLoss(loss) 
...