Как применить операцию "Собери", как NumPy в Caffe2? - PullRequest
0 голосов
/ 17 декабря 2018

Я новичок в Caffe2 и хочу написать такую ​​операцию:

  • Numpy way

пример кода

  • pytoch way

пример кода

Мой вопрос заключается в том, как составить операторы Caffe2 для создания тех же операторов, что и выше?Я попробовал несколько композиций, но все равно не смог найти подходящую.Если кто-нибудь знает композицию, пожалуйста, помогите, я буду очень признателен за это.

1 Ответ

0 голосов
/ 06 мая 2019

В Caffe2 есть оператор Gather.Основная проблема с этим оператором заключается в том, что вы не можете установить ось (она всегда равна 0).Итак, если мы запустим этот код:

model = ModelHelper(name="test")

s = np.arange(20).reshape(4, 5)
y = np.asarray([0, 1, 2])

workspace.FeedBlob('s', s.astype(np.float32))
workspace.FeedBlob('y', y.astype(np.int32))

model.net.Gather(['s', 'y'], ['out'])

workspace.RunNetOnce(model.net)

out = workspace.FetchBlob('out')
print(out)

Мы получим:

[[  0.   1.   2.   3.   4.]
 [  5.   6.   7.   8.   9.]
 [ 10.  11.  12.  13.  14.]]

Одним из решений может быть преобразование s в одномерный массив и преобразование у таким же образом.Прежде всего, мы должны реализовать оператор для преобразования y .В этом случае мы будем использовать непрямую функцию ravel_multi_index:

class RavelMultiIndexOp(object):
    def forward(self, inputs, outputs):
        blob_out = outputs[0]

        index = np.ravel_multi_index(inputs[0].data, inputs[1].shape)

        blob_out.reshape(index.shape)
        blob_out.data[...] = index

Теперь мы можем переопределить наш оригинальный код:

model = ModelHelper(name="test")

s = np.arange(20).reshape(4, 5)
y = np.asarray([[0, 1, 2],[0, 1, 2]])

workspace.FeedBlob('s', s.astype(np.float32))
workspace.FeedBlob('y', y.astype(np.int32))

model.net.Python(RavelMultiIndexOp().forward)(
    ['y', 's'], ['y'], name='RavelMultiIndex'
)
model.net.Reshape('s', ['s_reshaped', 's_old'], shape=(-1, 1))

model.net.Gather(['s_reshaped', 'y'], ['out'])

workspace.RunNetOnce(model.net)

out = workspace.FetchBlob('out')
print(out)

Вывод:

[[  0.]
 [  6.]
 [ 12.]]

Вы можете изменить его на (1, -1).

...