Как вы генерируете представление ONNX нейронной сети, прошедшей предварительную подготовку? - PullRequest
0 голосов
/ 16 января 2019

Я пытаюсь создать файл ONNX для примера pytorch-pretrained-bert run_classifier.py.

В этом случае я запускаю его со следующими параметрами согласно основному README.md:

export GLUE_DIR=/tmp/glue_data

python run_classifier.py \
  --task_name MRPC \
  --do_train \
  --do_eval \
  --do_lower_case \
  --data_dir $GLUE_DIR/MRPC/ \
  --bert_model bert-base-uncased \
  --max_seq_length 128 \
  --train_batch_size 32 \
  --learning_rate 2e-5 \
  --num_train_epochs 3.0 \
  --output_dir /tmp/mrpc_output/

Следующий код изменен / добавлен в строке 552:

    # Save a trained model
    model_to_save = model.module if hasattr(model, 'module') else model  # Only save the model it-self
    output_model_file = os.path.join(args.output_dir, "pytorch_model.bin")
    if args.do_train:
        torch.save(model_to_save.state_dict(), output_model_file)

    # Save ONNX
    msl = args.max_seq_length
    dummy_input = torch.randn(1, msl, msl, msl, num_labels, device="cpu")
    output_onnx_file = os.path.join(args.output_dir, "classifier.onnx")
    torch.onnx.export(model, dummy_input, output_onnx_file)

Предполагается, что dummy_input соответствует входным данным модели с предварительной подготовкой. Я думаю, что sample_batch_size 1 подходит для моих нужд.

Некоторые предлагают, чтобы аргументы соответствовали аргументам метода forward () моделей. В этом случае:

class BertForSequenceClassification(PreTrainedBertModel):
    def forward(self, input_ids, token_type_ids=None, attention_mask=None, labels=None):

В этом случае, я считаю, ранги:

input_ids: 1 x 128 <the max_seq_length specified in the args>
token_type_ids: 1 x max_seq_length
attention_mask: 1 x max_seq_length
labels: 1 x 2 <the number of labels for MRPC>

поэтому эффективный вызов:

    dummy_input = torch.randn(1, 128, 128, 128, 2, device="cpu")

К сожалению, это приводит к ошибке:

Exception has occurred: RuntimeError
The expanded size of the tensor (2) must match the existing size (128) at non-singleton dimension 4.  Target sizes: [1, 128, 128, 128, 2].  Tensor sizes: [1, 128]

Кажется вероятным, что это что-то довольно простое. Предложения приветствуются!

...