Как добавить мои параметры (вес, смещение) к аргументам в символе? - PullRequest
0 голосов
/ 13 ноября 2018

Я пытаюсь изменить вес свертки, как это. Для этого я делаю, инициализирую свои параметры (вес, уклон), сверяю входное изображение, используя их. Но это показывает ошибку, потому что мои параметры не являются аргументами в символе.

Как добавить мои параметры к аргументам в символе? Если бы вы дали мне знать, я был бы очень признателен.

1 Ответ

0 голосов
/ 21 ноября 2018

Если вы хотите передать аргументы пользовательскому оператору, вы должны сделать это с помощью метода init.

С https://github.com/apache/incubator-mxnet/issues/5580 вот фрагмент, иллюстрирующий, что вам нужно:

class Softmax(mx.operator.CustomOp):

    def __init__(self, xxx, yyy):  # arguments xxx, and yyy
        self.xxx = xxx
        self.yyy = yyy
    def forward(self, is_train, req, in_data, out_data, aux):
        x = in_data[0].asnumpy()
        y = np.exp(x - x.max(axis=1).reshape((x.shape[0], 1)))
        y /= y.sum(axis=1).reshape((x.shape[0], 1))
        print self.xxx, self.yyy
        self.assign(out_data[0], req[0], mx.nd.array(y))

    def backward(self, req, out_grad, in_data, out_data, in_grad, aux):
        l = in_data[1].asnumpy().ravel().astype(np.int)
        y = out_data[0].asnumpy()
        y[np.arange(l.shape[0]), l] -= 1.0
        self.assign(in_grad[0], req[0], mx.nd.array(y))

@mx.operator.register("softmax")
class SoftmaxProp(mx.operator.CustomOpProp):
    def __init__(self, xxx, yyy):
        super(SoftmaxProp, self).__init__(need_top_grad=False)

        # add parameter
        self.xxx = xxx
        self.yyy = yyy
    def list_arguments(self):
        return ['data', 'label', 'xxx', 'yyy']

    def list_outputs(self):
        return ['output']

    def infer_shape(self, in_shape):
        data_shape = in_shape[0]
        label_shape = (in_shape[0][0],)
        output_shape = in_shape[0]
        return [data_shape, label_shape], [output_shape], []

    def create_operator(self, ctx, shapes, dtypes):
        return Softmax(xxx=self.xxx, yyy=self.yyy)

Для получения полной информации посмотрите https://mxnet.incubator.apache.org/faq/new_op.html.

Vishaal

...