Как caffe вычисляет градиент при наличии нескольких ветвей? - PullRequest
1 голос
/ 19 марта 2019

Я сейчас читаю Caffe исходный код, и у меня возник вопрос.

Возьмите, например, caffe/relu_layer.cpp.При вычислении градиента из

void ReLULayer<Dtype>::Backward_cpu(const vector<Blob<Dtype>*>& top,
    const vector<bool>& propagate_down,
    const vector<Blob<Dtype>*>& bottom) {
  if (propagate_down[0]) {
    const Dtype* bottom_data = bottom[0]->cpu_data();
    const Dtype* top_diff = top[0]->cpu_diff();
    Dtype* bottom_diff = bottom[0]->mutable_cpu_diff();
    const int count = bottom[0]->count();
    Dtype negative_slope = this->layer_param_.relu_param().negative_slope();
    for (int i = 0; i < count; ++i) {
      bottom_diff[i] = top_diff[i] * ((bottom_data[i] > 0)
          + negative_slope * (bottom_data[i] <= 0));
    }
  }
}

мы видим, что значение окончательно присвоено bottom_diff, что указывает на то, что значение является градиентом соответствующего нижнего двоичного объекта.

Однако, когда несколько слоевВозьмите один BLOB-объект в качестве входных данных, например, наложение нескольких ReLU слоев на один BLOB-объект. Как Caffe обрабатывает вычисление градиента?Первый слой ReLU изменяет bottom_diff, и кажется, что второй слой ReLU просто переопределяет его, вместо добавления двух градиентов.

Я нигде не видел выполнения градиентного суммирования, и япутает.Пожалуйста, сообщите мне, если я пропустил что-то важное, и большое спасибо.

1 Ответ

0 голосов
/ 30 марта 2019

Caffe автоматически вставляет разделенный слой, когда верхний шарик используется в нескольких нижних слоях.Это делается внутри Net<Dtype>::Init(...) путем вызова InsertSplits(...) из caffe/utils/insert_splits.cpp.

Пример:

Исходная сеть в NetParameter объект protobuf (узлами здесь являются слои):

data ---> conv1 -> conv2 -> ...
      \-> somelayer -> ...

Net Layer с в памяти после Net::Init():

data -> split ---> conv1 -> conv2 -> ...
               \-> somelayer -> ...

(Кстати, интересная деталь: .diff в активации Blobs назначенана Backward(), в то время как .diff в параметрах, доступных для обучения, добавляется к Backward().)

...