Использование нескольких операторов с перегрузкой операторов дает странную ошибку - PullRequest
0 голосов
/ 13 февраля 2019

У меня есть один класс с именем FloatTensor.Я перегружен операторы для + и * в этом.Вот код.


class FloatTensor {
    public:
    float val; // value of tensor 
    float grad; // value of grad
    Operation *frontOp =NULL, *backOp =NULL;
    FloatTensor* two;
    FloatTensor() {
        // default
    }

    FloatTensor(float val) {
        this->val = val;
    }

    FloatTensor(float val, Operation* op) {
        this->val = val;
        this->backOp = op;
    }

    void backward(float grad) {
        this->grad = grad;
        if(this->backOp != NULL) {
            this->backOp->backward(grad);
        }
    }
    FloatTensor exp() {
        this->frontOp = new ExponentOperation(this);
        return this->frontOp->compute();
    }

    FloatTensor operator * (FloatTensor &two) { 

        this->frontOp = new MultiplyOperation(this, &two);
        return this->frontOp->compute();
    }

    FloatTensor operator + (FloatTensor &two) { 
        this->frontOp = new AddOperation(this, &two);
        return this->frontOp->compute();
    }

    FloatTensor operator / (FloatTensor &two) { 

        this->frontOp = new DivideOperation(this, &two);
        return this->frontOp->compute();
    }

};

В моей основной функции, когда я пытаюсь выполнить простую перегрузку, все работает отлично

int main() {

    // X 
    FloatTensor x1(200); // heap declaration
    FloatTensor x2(300);

    // Weights
    FloatTensor w1(222);
    FloatTensor w2(907);

    FloatTensor temp = (x1*w1);

}

Однако, когда я пытаюсь перегрузить эту формулу с большим количеством операторов, подобных этому

int main() {

    // X 
    FloatTensor x1(200); // heap declaration
    FloatTensor x2(300);

    // Weights
    FloatTensor w1(222);
    FloatTensor w2(907);

    FloatTensor temp = (x1*w1) + (x2*w2);

}

Я получаю эту ошибку:

no operator "+" matches these operands -- operand types are: FloatTensor + FloatTensor

Я был бы очень признателен, если кто-то может объяснить, почему это происходит.Я заметил, что это работает:

x1*w1*x2*x1;
x1*w1 + x2;

Но x1*w1 + x2*w2 нет.

Очень странно ..

1 Ответ

0 голосов
/ 13 февраля 2019

Ваши операторы принимают в качестве аргумента ссылку не-const lvalue.Временные ссылки не привязываются к ссылкам не-const lvalue.Чтобы принять временные, используйте:

FloatTensor operator + (const FloatTensor &two)
...