Я пытаюсь преобразовать реализацию BERT PyTorch отсюда (https://github.com/codertimo/BERT-pytorch) в ONNX (и, надеюсь, в coreml), но реализация блока Transformer:
class TransformerBlock(nn.Module):
"""
Bidirectional Encoder = Transformer (self-attention)
Transformer = MultiHead_Attention + Feed_Forward with sublayer connection
"""
def __init__(self, hidden, attn_heads, feed_forward_hidden, dropout):
"""
:param hidden: hidden size of transformer
:param attn_heads: head sizes of multi-head attention
:param feed_forward_hidden: feed_forward_hidden, usually 4*hidden_size
:param dropout: dropout rate
"""
super().__init__()
self.attention = MultiHeadedAttention(h=attn_heads, d_model=hidden)
self.feed_forward = PositionwiseFeedForward(d_model=hidden, d_ff=feed_forward_hidden, dropout=dropout)
self.input_sublayer = SublayerConnection(size=hidden, dropout=dropout)
self.output_sublayer = SublayerConnection(size=hidden, dropout=dropout)
self.dropout = nn.Dropout(p=dropout)
def forward(self, x, mask):
x = self.input_sublayer(x, lambda _x: self.attention.forward(_x, _x, _x, mask=mask)) // <-- Error!
x = self.output_sublayer(x, self.feed_forward)
return self.dropout(x)
вызывает ошибку:
builtins.ValueError: Auto nesting doesn't know how to process an input object of type bert_pytorch.model.transformer.TransformerBlock.forward.<locals>.<lambda>. Accepted types: Tensors, or lists/tuples of them
Я понимаю, что лямбда вызывает ошибку (или, по крайней мере, это то, что я думаю, что происходит ), но я не уверен, как это исправить - если, действительно, исправление возможно, с текущим ONNX. Есть ли способ переписать это без лямбды, например? И может ли это обойти проблему, не ломая модель? (Я все еще довольно новичок в Python.)
Мое преобразование довольно простое:
export_model = bert
model_name = "bert.onnx"
dummy_input = (torch.randn(1, 40).long().cuda(), torch.randn(1, 40).long().cuda())
torch.onnx.export(export_model,
dummy_input,
model_name,
input_names=['query_sequence'],
output_names=['token_prediction'],
operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
Строки torch.randn(1, 40).long().cuda()
были получены методом проб и ошибок, а использование ONNX_ATEN_FALLBACK
было найдено в поиске Google о преобразовании трансформаторов.
В отношении связанной заметки мне также любопытно, если кто-нибудь знает, является ли полное преобразование BERT-PyTorch> ONNX> CoreML просто невозможным на этом этапе (по крайней мере, пока CoreML 3 не будет готов к работе). Если это произойдет абсолютно , а не , я спасу себя от ударов головой на стол!
Любые мысли приветствуются.