AddSymbolicGradients терпит неудачу на рекурсивной реализации - PullRequest
0 голосов
/ 31 декабря 2018

Я реализую простую RNN с API-интерфейсами Tensorflow C ++.

В настоящее время использую 1.13 из github.

Вызов AddSymbolicGradients (последняя строка) завершается неудачнов Ошибка сегментации .GDB говорит мне, что ошибка происходит внутри SymbolicGradientBuilder :: Initialize

auto input_slices = Split(scope, 1, x, window_size);

auto initial_state = Fill(scope, {batch_size, state_size}, 0);

vector<Output> states;
states.reserve(window_size+1);
states.push_back(initial_state);

for (int i=0; i!=window_size; i++) {
    auto concat = Concat(scope, InputList(initializer_list<Input>{input_slices[i], states[i]}), 1);
    auto new_state = Tanh(scope, Add(scope, MatMul(scope, concat, w_rnn), b_rnn));
    states.push_back(new_state);
}

// dense output
auto out = Tanh(scope, Add(scope, MatMul(scope, states[window_size], w_dense), b_dense));

// loss function
auto loss = ReduceMean(scope, Square(scope, Sub(scope, out, y)), {0, 1});

vector<Output> grad_outputs;
TF_CHECK_OK(AddSymbolicGradients(scope, {loss}, {w_rnn, w_dense, b_rnn, b_dense}, &grad_outputs));

Есть идеи?Спасибо

РЕДАКТИРОВАТЬ: GDB, где полный вывод

(gdb) where
#0  0x00005555559b45aa in tensorflow::(anonymous namespace)::SymbolicGradientBuilder::Initialize() ()
#1  0x00005555559b6f1c in tensorflow::AddSymbolicGradients(tensorflow::Scope const&, std::vector<tensorflow::Output, std::allocator<tensorflow::O
utput> > const&, std::vector<tensorflow::Output, std::allocator<tensorflow::Output> > const&, std::vector<tensorflow::Output, std::allocator<tens
orflow::Output> > const&, std::vector<tensorflow::Output, std::allocator<tensorflow::Output> >*) ()
#2  0x00005555559b9bb2 in tensorflow::AddSymbolicGradients(tensorflow::Scope const&, std::vector<tensorflow::Output, std::allocator<tensorflow::O
utput> > const&, std::vector<tensorflow::Output, std::allocator<tensorflow::Output> > const&, std::vector<tensorflow::Output, std::allocator<tens
orflow::Output> >*) ()
#3  0x000055555597ce6c in Rnn::train(Dataset, int, int, int) ()
#4  0x00005555558410d6 in main ()
...