Преобразование из pytorch в onnx для определения графа тензорного потока в tflite - TOCO не удалось - проверка типа не удалась - PullRequest
0 голосов
/ 16 ноября 2018

У меня есть простая сеть из двух сверточных слоев и одного полностью связного, который определен в pytorch по существу следующим образом:

def __init__([...])
    [...]
    self.conv1 = nn.Conv1d(1, channels_conv1, width_conv1)
    self.conv2 = nn.Conv1d(channels_conv1, channels_conv2, width_conv2)
    self.fc1 = nn.Linear(hidden_layer_size, 2)

def forward(self, x):
    x = functional.max_pool1d(functional.relu(self.conv1(x)), 2, stride=2)
    x = functional.max_pool1d(functional.relu(self.conv2(x)), 2, stride=2)
    x = x.view(-1, self.num_flat_features(x))
    x = functional.softmax(self.fc1(x))
    return x

Я хочу преобразовать его в tflite.Итак, сначала он конвертируется в onnx

torch.onnx.export(model, input, "net.onnx",
                  export_params=True,
                  input_names=['input'],
                  output_names=['output'],
                  verbose=true)

Затем я конвертирую результат в определение графа тензорного потока с помощью onnx-tf.Результирующий net.pb в порядке, так как он выдает тот же результат, что и оригинал с prepare(onnx.load('net.onnx')).run(...).

Однако у меня есть две проблемы: небольшая проблема в том, что график net.pb не содержит выводаузел больше, и я должен ловить рыбу для выходного узла.Во-вторых, когда я пытаюсь выполнить окончательное преобразование с помощью

tflite_convert --output_file=net.tflite --graph_def_file=net.pb --input_arrays=input --output_arrays=Softmax

, я получаю ошибку TOCO при проверке типа:

tensorflow.lite.python.convert.ConverterError: TOCO failed. See console for info.
2018-11-16 16:11:37.592030: I tensorflow/lite/toco/import_tensorflow.cc:1280] Converting unsupported operation: Where
2018-11-16 16:11:37.601384: I tensorflow/lite/toco/graph_transformations/graph_transformations.cc:39] Before Removing unused ops: 61 operators, 113 arrays (0 quantized)
2018-11-16 16:11:37.602005: I tensorflow/lite/toco/graph_transformations/graph_transformations.cc:39] Before general graph transformations: 61 operators, 113 arrays (0 quantized)
2018-11-16 16:11:37.602311: F tensorflow/lite/toco/graph_transformations/resolve_constant_gather.cc:105] Check failed: coords_array.data_type == ArrayDataType::kInt32 Only int32 indices are supported
Aborted (core dumped)

Я попытался покопаться в сети, но не могуКажется, я не нашел неприятного объекта, и я не нашел проблем, явно связанных с этой проблемой.Любой указатель на точку, где этот процесс мог бы сорваться, был бы отличным!

tf-nightly==1.13.0.dev20181116
onnx==1.3.0
torch-nightly==1.0.0.dev201811

и master (commit b5fef1b) из onnx-tenorflow

...