Before reading this guide, we strongly recommend becoming familiar with PyTorch and with the official PyTorch documentation on ONNX conversion: guide 1, guide 2.
To convert a PyTorch model to an ONNX model, you need both the PyTorch model and the source code that generates the PyTorch model. Then you can load the model in Python using PyTorch, define dummy input values for all input variables of the model, and run the ONNX exporter to get an ONNX model.
ONNX support is built into PyTorch as a first-class citizen. You don't need to look for third-party converters as you would with TensorFlow. However, even with built-in ONNX conversion capability, some models are still difficult to export. In general, there are three possible roadblocks: unsupported operators, PyTorch internal bugs, and control flow.
For unsupported operators, you can either wait for them to be added to PyTorch or add them yourself. In many cases this is easier than you might think. In the example that follows, we need the bitwise-not operator, but it is not supported in PyTorch 1.4.0. A simple Google search reveals that support for this operator is already in the master branch of PyTorch; it just didn't make it into the latest official release (1.4.0). Adding the following code to the file /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/onnx/symbolic_opset9.py (this path differs across operating systems and Python installations):
@parse_args('v')
def bitwise_not(g, inp):
    if inp.type().scalarType() != 'Bool':
        return _unimplemented("bitwise_not", "non-bool tensor")
    return g.op("Not", inp)
will add support for this operator.
For PyTorch internal bugs, you can either fix them yourself or wait for the PyTorch team to fix them. Fortunately, this case is very rare.
For control flow, we will explain in detail in the following example.
We will look at this example: Text Summarization with Bert. We will convert this particular PyTorch model to ONNX format, completely from scratch.
Intuitively speaking, the PyTorch to ONNX converter is a tracer. It takes a loaded model and a dummy input for the model. It then runs the model on the provided input data, recording what happens internally in the model. Finally, it reconstructs an ONNX model that does exactly the same thing and saves the ONNX model to disk. For many types of models, this method works just fine. However, whenever a model contains control flow, like for loops or if statements, the tracer method will fail, simply because the tracer is never aware of the existence of the control flow statements; it faithfully records the flow taken for the supplied input. For example, if the model contains a for loop that runs max_step times, a tracer based exporter will simply unroll the loop max_step times, using whichever value max_step happens to have in the input supplied to the exporter (let's say the value is a). When we run the exported model with a different value of max_step (let's say the value is now b), the model will ignore it and simply run the loop a times, rendering the result useless in most cases.
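The unrolling behaviour described above can be illustrated with a toy tracer written in plain Python. This is only a sketch of the idea and not how the PyTorch exporter is implemented; the names Traced, trace and replay are made up for this illustration.

```python
# Toy illustration only: this is NOT how torch.onnx works internally.
# It mimics the idea of tracing: run the function once and record the ops.

class Traced:
    """A number that appends every operation applied to it to a shared tape."""

    def __init__(self, value, tape):
        self.value = value
        self.tape = tape

    def __add__(self, other):
        self.tape.append(("add", other))
        return Traced(self.value + other, self.tape)


def model(x, max_step):
    # Control flow: the number of iterations depends on an input.
    for _ in range(max_step):
        x = x + 2
    return x


def trace(fn, x, max_step):
    """Run fn once and return the recorded list of operations."""
    tape = []
    fn(Traced(x, tape), max_step)
    return tape


def replay(tape, x):
    """Re-run the recorded operations; the loop itself was never recorded."""
    for op, operand in tape:
        if op == "add":
            x = x + operand
    return x


tape = trace(model, x=0, max_step=3)  # traced with a = 3
print(len(tape))         # 3 recorded additions
print(replay(tape, 10))  # 16, that is 10 + 3 * 2
print(model(10, 5))      # 20, what the real model returns for b = 5
```

Note that replay has no max_step parameter at all: the recorded tape contains three additions and nothing else, so there is no way to ask it for five iterations.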
To solve this issue, PyTorch offers another method, completely different from the tracer based method, to export models with control flow. It's called the script based method. Intuitively, the model source code is 'compiled' and analyzed. The PyTorch 'compiler' correctly captures any control flow and correctly exports the model to ONNX format. This sounds like a proper solution to the problem. However, the script based method currently has significant limitations on the language features it supports in the model source code, meaning that there are certain Python language features (for example lambda) that you cannot use when defining your model. Unless the model was coded with the mission 'exporting to ONNX' in mind, it is generally very difficult to rewrite the model source code to comply with the requirements of the script based method.
Because MarkLogic is a document database, we naturally want to work with models that handle text. Unfortunately, almost all models that handle text contain control flow (with a small number of exceptions), because most of them construct the output in a recursive/iterative way (for example, for each word in the input document, generate the next output word). This makes exporting these PyTorch models to ONNX more challenging.
Fortunately, with a good understanding of the model, the exporting mechanism and some coding, and ever growing ONNX operator support, we can convert lots of text-handling models to ONNX.
Text summarization is an important task in Natural Language Processing (NLP). The objective is to take a long article and return a short summary. There are plenty of research results on this topic. We pick the most recent one, Text Summarization with Pretrained Encoders, to demonstrate the conversion process from a model produced by PyTorch (with no intention to be converted later) to ONNX. It's worth noting that this model is based on BERT, a highly sophisticated pretrained language model that Google trained on a massive text corpus with a massive amount of computation power, to be used as a bootstrap model for other NLP related tasks. Successfully converting this model to ONNX will demonstrate that the ONNX format is quite capable and that, with ONNX support in MarkLogic, many of your pretrained models can work properly in the MarkLogic database. With this in mind, let's start by preparing the environment.
Install Python 3 if you don't already have it; it almost certainly comes with pip. Note that macOS users and some Linux users need to make sure they are using the correct Python, since the operating system comes with one pre-installed. For this particular task, we need at least Python 3.6.
Clone this git repo for the paper, then install the prerequisites by executing
pip3 install --user -r requirements.txt
Although "torch==1.1.0" is specified, we still want to try the latest PyTorch (1.4.0 as of this writing) first, because of its possibly better ONNX operator coverage and overall improved functionality. If the newest version of PyTorch fails, we can then revert to the version specified in the requirements. You can install the latest PyTorch here.
Now follow the instructions in the git repo to download the pretrained models and the training/testing datasets. We will be using CNN/DM BertExtAbs, the abstractive summarization model based on BERT, trained with the CNN/DM dataset. For datasets, we use the prepared data.
After downloading and decompressing those files, move the model file to the models directory, and move the datasets to the bert_data directory. After those steps, in addition to the cloned source code, your models directory should contain a file model_step_148000.pt, and your bert_data directory should contain lots of files with names similar to cnndm.test.0.bert.pt.
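The layout described above can be checked with a few lines of Python before continuing. This is a minimal sketch: the helper name check_layout is made up for this illustration, and the demonstration runs against a throwaway temporary directory rather than a real clone of the repo.

```python
import tempfile
from pathlib import Path


def check_layout(repo_root):
    """Return a list of problems with the directory layout described above."""
    root = Path(repo_root)
    problems = []
    if not (root / "models" / "model_step_148000.pt").is_file():
        problems.append("missing models/model_step_148000.pt")
    bert_data = root / "bert_data"
    if not (bert_data.is_dir() and any(bert_data.glob("cnndm.*.bert.pt"))):
        problems.append("no cnndm.*.bert.pt files in bert_data")
    return problems


# Demonstration on a throwaway directory instead of a real clone.
with tempfile.TemporaryDirectory() as tmp:
    before = check_layout(tmp)  # nothing is in place yet
    (Path(tmp) / "models").mkdir()
    (Path(tmp) / "models" / "model_step_148000.pt").touch()
    (Path(tmp) / "bert_data").mkdir()
    (Path(tmp) / "bert_data" / "cnndm.test.0.bert.pt").touch()
    after = check_layout(tmp)   # both requirements are now met

print(before)
print(after)
```

On a real checkout you would call check_layout with the path of the cloned repo and expect an empty list.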
We are now ready to edit the source code to add a function to export the model to ONNX format.
At this point, we need to read through the source code that generates the model. Since our goal is to convert this model to ONNX format, load it into MarkLogic and summarize a piece of text, we first need to understand how that is done in PyTorch. Understanding the model is always the most important and most difficult part of the conversion. For this particular model, notice that in order to summarize a raw piece of text the author suggests using -mode test_text -text_src $text_file -test_from $ckpt_file -task abs. Following the code path, we see that the function test_text_abs in file train_abstractive.py is our main guy. The function mostly does the following things:
Let's start by trying to export the loaded model without any post-processing, just to be sure that all operators are supported. We modify the train.py file to add a new mode called onnx_export, and then create a new file onnx_export.py under src. Put the following code in onnx_export.py:
import torch

from models import data_loader, model_builder
from models.data_loader import load_dataset
from models.model_builder import AbsSummarizer

model_flags = ['hidden_size', 'ff_size', 'heads', 'emb_size', 'enc_layers',
               'enc_hidden_size', 'enc_ff_size', 'dec_layers', 'dec_hidden_size',
               'dec_ff_size', 'encoder', 'ff_actv', 'use_interval']


def onnx_export(args):
    device = "cpu"
    # Load the checkpoint on the CPU and restore the flags the model was built with.
    checkpoint = torch.load(
        args.test_from, map_location=lambda storage, loc: storage)
    opt = vars(checkpoint['opt'])
    for k in opt.keys():
        if (k in model_flags):
            setattr(args, k, opt[k])
    model = AbsSummarizer(args, device, checkpoint)
    model.eval()
    test_iter = data_loader.Dataloader(
        args, load_dataset(args, 'test', shuffle=False),
        args.test_batch_size, device, shuffle=False, is_test=True)
    # Use the first example of the first batch as the dummy input, then stop.
    for input_data in test_iter:
        dummy_input = (
            input_data.src.index_select(0, torch.tensor([0])),
            input_data.tgt.index_select(0, torch.tensor([0])),
            input_data.segs.index_select(0, torch.tensor([0])),
            input_data.clss.index_select(0, torch.tensor([0])),
            input_data.mask_src.index_select(0, torch.tensor([0])),
            input_data.mask_tgt.index_select(0, torch.tensor([0])),
            input_data.mask_cls.index_select(0, torch.tensor([0]))
        )
        torch.onnx.export(
            model, dummy_input, "AbsSummarizer.onnx", opset_version=11)
        break
The gist of the above code is to load the model just as when summarizing raw text and, using the first batch of input data as dummy input, export the model to ONNX format. The construction of dummy_input is dictated by the forward function of the AbsSummarizer class. Every PyTorch model has a forward function, the signature of which determines the input and output of the model. We extract the required input data from the first batch, feed it to the ONNX exporter and try to export the model as an ONNX model.
Run the following command from the src directory:
python3 train.py -mode onnx_export -task abs -test_from ../models/model_step_148000.pt -bert_data_path ../bert_data/cnndm
Unsurprisingly, we are greeted with an error:
RuntimeError: Subtraction, the `-` operator, with a bool tensor is not supported. If you are trying to invert a mask, use the `~` or `logical_not()` operator instead.
This is an easy fix. Just do as the error message suggests and fix the code, and try again.
We're now greeted with another error message:
RuntimeError: Only tuples, lists and Variables supported as JIT inputs/outputs. Dictionaries and strings are also accepted but their usage is not recommended. But got unsupported type NoneType
Looking at the definition of the AbsSummarizer class in model_builder.py, you will notice that the model returns two outputs, one of which is None. That's our culprit! Simply delete the None, and let's try again.
This time it's successful! The command finishes without error, and there is a new file AbsSummarizer.onnx, which is 843 MB, in our src directory. However, notice that we do have a couple of warnings:
PreSumm/src/models/encoder.py:42: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  emb = emb + self.pe[:, :emb.size(1)]
PreSumm/src/models/decoder.py:64: TracerWarning: Converting a tensor to a Python index might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
  :tgt_pad_mask.size(1)], 0)
Warnings like these are pretty self-explanatory: a variable is being treated as a constant. When you run the exported model with a different set of inputs, that value will not change; it will still be the value derived from the input we used during exporting. Just as in the case of control flow, this renders the exported model completely useless!
To get around this issue, use torch.index_select instead of converting a torch.tensor to a Python index. Note that different fixes are required for different scenarios; index_select is just one fix that happens to work in this case. So the code in question:
emb = emb + self.pe[:, :emb.size(1)]
becomes:
emb = emb+self.pe.index_select(1, torch.arange(emb.size(1)))
After doing the same for the other warning, we can export the base AbsSummarizer model to ONNX format warning free.
We now know that the base model, without post-processing, can be exported successfully. However, notice that the base model as defined performs only a single round of computation, generating one 'word' of the output summarization. In order to generate the full summarization, we need to imitate the predictor.translate function call to construct a real working ONNX summarization model.
Now we need to look at the translate and _fast_translate_batch functions in predictor.py. Unsurprisingly, in the _fast_translate_batch function, which does the real work of generating the summarization, we see a for loop:
for step in range(max_length):
Here max_length is the maximum length (in words) of the summarization, and step is the length of the current work-in-progress summarization. Recall that to export control flow we can use the script based exporter. However, since this piece of code contains many advanced Python features that are not supported by the script based exporter, this option becomes impractical (but still possible: you can always rewrite the code from scratch).
From here on there is no official way to proceed. In this particular case we choose to export two models, one representing initialization and the first loop iteration, the other representing the loop body. We take the control flow outside of the model, to be handled by application code (in other words, in XQuery or JavaScript in MarkLogic). In this case, the original application (pseudo)code transforms from a single ort.run:
// pseudocode, it doesn't run!
let session = ort.session("text_summarization_all_in_one.onnx")
let input = article("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
return ort.run(session, input)
To a slightly more complicated one with a for loop:
// pseudocode, it doesn't run!
let init_loop = ort.session("init_loop.onnx")
let loop_body = ort.session("loop_body.onnx")
let init_loop_input = article("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
let init_loop_output = ort.run(init_loop, init_loop_input)
let loop_body_input = init_loop_output
let loop_body_output
for step in range(max_step):
    loop_body_output = ort.run(loop_body, loop_body_input)
    loop_body_input = loop_body_output
return loop_body_output
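The same restructuring can be tried out in plain Python with a stand-in for the decoder. The sketch below is only an analogy under stated assumptions: next_token, END and the arithmetic inside them are invented for this illustration and have nothing to do with the real model; the point is only that moving the loop into application code gives the same result as the all-in-one version (for max_step of at least 1).

```python
END = 0  # hypothetical end-of-sequence token for this toy example


def next_token(seq):
    """Stand-in for one decoder step: derive the next token from the sequence."""
    return (seq[-1] * 3 + 1) % 7


def all_in_one(start, max_step):
    """Monolithic version: initialization and the loop live inside the 'model'."""
    seq = [start]
    for _ in range(max_step):
        seq.append(next_token(seq))
        if seq[-1] == END:
            break
    return seq


def init_loop(start):
    """First exported piece: initialization plus the first iteration."""
    seq = [start]
    return seq + [next_token(seq)]


def loop_body(seq):
    """Second exported piece: exactly one further iteration, no control flow."""
    return seq + [next_token(seq)]


def run_split(start, max_step):
    """Application code owns the loop and the stopping condition."""
    seq = init_loop(start)
    step = 1
    while step < max_step and seq[-1] != END:
        seq = loop_body(seq)
        step += 1
    return seq


print(all_in_one(1, 10))  # [1, 4, 6, 5, 2, 0]
print(run_split(1, 10))   # [1, 4, 6, 5, 2, 0]
```

Neither init_loop nor loop_body contains a loop or a branch, which is what makes each of them suitable for a tracer based exporter; the state that the loop carries from one iteration to the next has to become explicit inputs and outputs.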
To do this, we need to analyze what's happening inside the _fast_translate_batch function and define our own two models. This takes quite a while and requires a good understanding of the model building and evaluation process. It also involves many more error and warning messages, whose details are omitted here. Eventually we end up with the following two new model definitions in model_builder.py (this is far from an optimal definition; the objective here is to make as few modifications to the original code as possible to make it work):
class InitLoopModel(nn.Module):
    def __init__(self, args, device, checkpoint):
        super(InitLoopModel, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)
        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(
            self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)
        self.decoder = TransformerDecoder(self.args.dec_layers,
                                          self.args.dec_hidden_size,
                                          heads=self.args.dec_heads,
                                          d_ff=self.args.dec_ff_size,
                                          dropout=self.args.dec_dropout,
                                          embeddings=tgt_embeddings)
        self.generator = get_generator(
            self.vocab_size, self.args.dec_hidden_size, device)
        self.generator[0].weight = self.decoder.embeddings.weight
        self.load_state_dict(checkpoint['model'], strict=False)
        self.to(device)

    def forward(self, src, segs, step):
        min_length = self.args.min_length
        beam_size = self.args.beam_size
        mask_src = ~(src == 0)
        batch_size = src.size(0)
        # Encode the article once and replicate the result for every beam.
        src_features = self.bert(src, segs, mask_src)
        device = src_features.device
        dec_states = self.decoder.init_decoder_state(
            src, src_features, with_cache=False)
        dec_states.src = tile(dec_states.src, beam_size, 0)
        src_features = tile(src_features, beam_size, dim=0)
        beam_offset = torch.arange(
            0, batch_size * beam_size, step=beam_size,
            dtype=torch.long, device=device)
        alive_seq = torch.full([batch_size * beam_size, 1], 1,
                               dtype=torch.long, device=device)
        const_topk_log_probs = torch.tensor(
            [0.0] + [float("-inf")] * (beam_size - 1), device=device)
        topk_log_probs = (const_topk_log_probs.repeat(batch_size))
        # First decoder step.
        decoder_input = alive_seq[:, -1].view(1, -1)
        decoder_input = decoder_input.transpose(0, 1)
        dec_out, dec_states = self.decoder(
            decoder_input, src_features, dec_states, step=step)
        log_probs = self.generator.forward(dec_out.transpose(0, 1).squeeze(0))
        vocab_size = log_probs.size(-1)
        # The end token (index 2) is never allowed in the first step.
        endprob = torch.tensor([-1e20]).repeat(log_probs.size(0))
        new_log_probs = torch.cat(
            [log_probs.index_select(-1, torch.arange(2)),
             endprob.view(-1).unsqueeze(1),
             log_probs.index_select(
                 -1, torch.arange(start=3, end=log_probs.size(1)))],
            -1) + topk_log_probs.view(-1).unsqueeze(1)
        alpha = self.args.alpha
        length_penalty = ((5.0 + (1)) / 6.0) ** alpha
        curr_scores = new_log_probs / length_penalty
        curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
        topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)
        topk_log_probs = topk_scores * length_penalty
        topk_beam_index = topk_ids.div(vocab_size)
        topk_ids = topk_ids.fmod(vocab_size)
        batch_index = (topk_beam_index + beam_offset.index_select(
            0, torch.arange(topk_beam_index.size(0))).unsqueeze(1))
        select_indices = batch_index.view(-1)
        alive_seq = torch.cat([alive_seq.index_select(
            0, select_indices), topk_ids.view(-1, 1)], -1)
        src_features = src_features.index_select(0, select_indices)
        dec_states.src = dec_states.src.index_select(0, select_indices)
        return src_features, dec_states.src, dec_states.previous_input, dec_states.previous_layer_inputs, alive_seq, topk_log_probs


class LoopBodyModel(nn.Module):
    def __init__(self, args, device, checkpoint):
        super(LoopBodyModel, self).__init__()
        self.args = args
        self.device = device
        self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)
        self.vocab_size = self.bert.model.config.vocab_size
        tgt_embeddings = nn.Embedding(
            self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)
        self.decoder = TransformerDecoder(self.args.dec_layers,
                                          self.args.dec_hidden_size,
                                          heads=self.args.dec_heads,
                                          d_ff=self.args.dec_ff_size,
                                          dropout=self.args.dec_dropout,
                                          embeddings=tgt_embeddings)
        self.generator = get_generator(
            self.vocab_size, self.args.dec_hidden_size, device)
        self.generator[0].weight = self.decoder.embeddings.weight
        self.load_state_dict(checkpoint['model'], strict=False)
        self.to(device)

    def forward(self, step, min_length, src_features, decoder_state_src,
                decoder_state_previous_input,
                decoder_state_previous_layer_inputs, alive_seq, topk_log_probs):
        beam_size = self.args.beam_size
        batch_size = src_features.size(0).div(beam_size)
        beam_offset = torch.arange(
            0, batch_size * beam_size, step=beam_size,
            dtype=torch.long, device=self.device)
        decoder_input = alive_seq[:, -1].view(1, -1)
        decoder_input = decoder_input.transpose(0, 1)
        # Rebuild the decoder state from the tensors passed in by the caller.
        dec_states = TransformerDecoderState(decoder_state_src)
        dec_states.previous_input = decoder_state_previous_input
        dec_states.previous_layer_inputs = decoder_state_previous_layer_inputs
        dec_out, dec_states = self.decoder(
            decoder_input, src_features, dec_states, step=step)
        log_probs = self.generator.forward(dec_out.transpose(0, 1).squeeze(0))
        vocab_size = log_probs.size(-1)
        # Branch-free version of "forbid the end token while step < min_length".
        small = torch.tensor([-1e20])
        tooshort = small*torch.lt(step, min_length).float()
        longenough = log_probs[:, 2]*((~step.lt(min_length)).float())
        endprob = tooshort + longenough
        new_log_probs = torch.cat(
            [log_probs.index_select(-1, torch.arange(2)),
             endprob.view(-1).unsqueeze(1),
             log_probs.index_select(
                 -1, torch.arange(start=3, end=log_probs.size(1)))],
            -1) + topk_log_probs.view(-1).unsqueeze(1)
        alpha = self.args.alpha
        length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha
        curr_scores = new_log_probs / length_penalty
        curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
        topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)
        topk_log_probs = topk_scores * length_penalty
        topk_beam_index = topk_ids.div(vocab_size)
        topk_ids = topk_ids.fmod(vocab_size)
        batch_index = topk_beam_index + \
            beam_offset.index_select(0, torch.arange(
                topk_beam_index.size(0))).unsqueeze(1)
        select_indices = batch_index.view(-1)
        alive_seq = torch.cat([alive_seq.index_select(
            0, select_indices), topk_ids.view(-1, 1)], -1)
        src_features = src_features.index_select(0, select_indices)
        dec_states.src = dec_states.src.index_select(0, select_indices)
        results = alive_seq.index_select(
            0, select_indices.index_select(0, torch.tensor(0)))
        return src_features, dec_states.src, dec_states.previous_input, dec_states.previous_layer_inputs, alive_seq, topk_log_probs, results, endprob
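Three small pieces of arithmetic in the two models are easier to follow in isolation: the length penalty, the branch-free masking of the end-token score, and the way a flat top-k index over beam_size * vocab_size scores is split back into a beam index and a token id. The sketch below redoes them on plain Python lists; the function names and the sample scores are made up for this illustration and are not part of the model code.

```python
def length_penalty(step, alpha):
    """Length penalty used in both models: ((5 + (step + 1)) / 6) ** alpha."""
    return ((5.0 + (step + 1)) / 6.0) ** alpha


def mask_end_score(end_score, step, min_length, small=-1e20):
    """Branch-free select: `small` while step < min_length, else end_score."""
    too_short = float(step < min_length)
    return small * too_short + end_score * (1.0 - too_short)


def topk_flat(scores, beam_size):
    """Pick the best beam_size entries across all beams and split each flat
    index into (beam index, token id, score), like topk + div + fmod."""
    vocab_size = len(scores[0])
    flat = [s for beam in scores for s in beam]
    best = sorted(range(len(flat)), key=lambda i: flat[i], reverse=True)
    return [(i // vocab_size, i % vocab_size, flat[i]) for i in best[:beam_size]]


# Two beams over a made-up vocabulary of three tokens.
scores = [[-1.0, -3.0, -0.5],
          [-2.0, -0.1, -4.0]]
print(topk_flat(scores, beam_size=2))  # [(1, 1, -0.1), (0, 2, -0.5)]
print(length_penalty(0, alpha=0.95))   # 1.0, the constant used in InitLoopModel
print(mask_end_score(-0.7, step=3, min_length=20))   # -1e+20
print(mask_end_score(-0.7, step=25, min_length=20))  # -0.7
```

Writing the masking as a multiplication by 0 or 1 instead of an if statement is what keeps the dependence on step visible to the tracer. Also note that the models rely on div performing integer division on integer tensors, which holds for the PyTorch version used here; newer PyTorch releases changed div to true division, so the code would likely need an explicit floor division there.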
And our export code in onnx_export.py becomes:
init_model = InitLoopModel(args, device, checkpoint)
init_model.eval()
test_iter = data_loader.Dataloader(args, load_dataset(args, 'test', shuffle=False),
                                   args.test_batch_size, device,
                                   shuffle=False, is_test=True)
loop_body_model = LoopBodyModel(args, device, checkpoint)
loop_body_model.eval()
for batch in test_iter:
    # Export the initialization model with the first example as dummy input.
    init_inputs = (batch.src.index_select(0, torch.tensor([0])),
                   batch.segs.index_select(0, torch.tensor([0])),
                   torch.tensor([0]))
    torch.onnx.export(
        init_model, init_inputs, "init_loop.onnx", verbose=False,
        input_names=["src", "segs", "step"],
        output_names=["src_features", "decoder_states_src",
                      "decoder_states_previous_input",
                      "decoder_states_previous_layer_inputs",
                      "alive_seq", "topk_log_probs"],
        opset_version=11,
        dynamic_axes={"src": {0: "batch"},
                      "segs": {0: "batch"},
                      "src_features": {0: "batchXbeam"},
                      "decoder_states_src": {0: "batchXbeam"},
                      "decoder_states_previous_input": {0: "batchXbeam"},
                      "decoder_states_previous_layer_inputs": {0: "batch", 1: "batchXbeam"},
                      "alive_seq": {0: "batchXbeam"},
                      "topk_log_probs": {0: "batch"}})
    # Run the initialization model once to obtain dummy inputs for the loop body.
    src_features, decoder_state_src, decoder_state_previous_input, decoder_state_previous_layer_inputs, alive_seq, topk_log_probs = init_model.forward(
        batch.src.index_select(0, torch.tensor([0])),
        batch.segs.index_select(0, torch.tensor([0])),
        torch.tensor([0]))
    loop_inputs = (torch.tensor(1), torch.tensor(20), src_features,
                   decoder_state_src, decoder_state_previous_input,
                   decoder_state_previous_layer_inputs, alive_seq, topk_log_probs)
    torch.onnx.export(
        loop_body_model, loop_inputs, "loop_body.onnx", verbose=False,
        input_names=["step", "min_length", "src_features_in",
                     "decoder_states_src_in",
                     "decoder_states_previous_input_in",
                     "decoder_states_previous_layer_inputs_in",
                     "alive_seq_in", "topk_log_probs_in"],
        output_names=["src_features_out", "decoder_states_src_out",
                      "decoder_states_previous_input_out",
                      "decoder_states_previous_layer_inputs_out",
                      "alive_seq_out", "topk_log_probs_out",
                      "results", "endprob"],
        opset_version=11,
        dynamic_axes={"src_features_in": {0: "batchXbeam"},
                      "decoder_states_src_in": {0: "batchXbeam"},
                      "decoder_states_previous_input_in": {0: "batchXbeam"},
                      "decoder_states_previous_layer_inputs_in": {0: "batch", 1: "batchXbeam", 2: "prev_step"},
                      "alive_seq_in": {0: "batchXbeam", 1: "prev_step"},
                      "topk_log_probs_in": {0: "batch"},
                      "src_features_out": {0: "batchXbeam"},
                      "decoder_states_src_out": {0: "batchXbeam"},
                      "decoder_states_previous_input_out": {0: "batchXbeam", 1: "step"},
                      "decoder_states_previous_layer_inputs_out": {0: "batch", 1: "batchXbeam", 2: "step"},
                      "alive_seq_out": {0: "batchXbeam", 1: "step"},
                      "topk_log_probs_out": {0: "batch"},
                      "results": {0: "batch", 2: "step"}})
    break
After exporting the two models, for them to work properly in MarkLogic, we also need to translate the preprocessing and postprocessing code to XQuery or JavaScript. This is much easier than exporting the model, and a final working example looks like this (again, this is not optimal code; the objective is to faithfully translate the original Python code to JavaScript):
'use strict';

function whitespace_tokenize(s) {
  return s.split(" ")
}

function wordpiece_tokenize(s, vocab) {
  let output = []
  let wstokens = whitespace_tokenize(s)
  for (let i = 0; i < wstokens.length; i++) {
    let token = wstokens[i]
    if (token.length > 100) {
      output.push("[UNK]")
      continue
    }
    let is_bad = false
    let start = 0
    let sub_tokens = []
    while (start < token.length) {
      let end = token.length
      let cur_substr = null
      while (start < end) {
        let substr = token.substr(start, end - start)
        if (start > 0) substr = "##" + substr
        if (vocab.hasOwnProperty(substr)) {
          cur_substr = substr
          break
        }
        end -= 1
      }
      if (cur_substr == null) {
        is_bad = true
        break
      }
      sub_tokens.push(cur_substr)
      start = end
    }
    if (is_bad) {
      output.push("[UNK]")
    } else {
      for (let j = 0; j < sub_tokens.length; j++) {
        output.push(sub_tokens[j])
      }
    }
  }
  return output
}

function tokenize(s, vocab) {
  s = s.trim().toLowerCase()
  let pretokens = s.split(" ")
  let tokens = ["[CLS]"]
  for (let i = 0; i < pretokens.length; i++) {
    let t = pretokens[i]
    let subtokens = wordpiece_tokenize(t, vocab)
    for (let j = 0; j < subtokens.length; j++) {
      let token = subtokens[j]
      tokens.push(token)
      if (tokens.length >= 511) {
        break;
      }
    }
    if (tokens.length >= 511) {
      break;
    }
  }
  tokens.push("[SEP]")
  return tokens
}

function preprocess(s, vocab) {
  var tokens = tokenize(s, vocab)
  var src = []
  var segs = []
  for (var i = 0; i < 512; i++) {
    if (i < tokens.length) {
      src.push(vocab[tokens[i]])
      segs.push(0)
    } else {
      src.push(0)
      segs.push(1)
    }
  }
  return [src, segs]
}

function getSummarization(result, reverse_vocab) {
  let s = ""
  for (let i = 0; i < result.length; i++) {
    s += reverse_vocab[result[i]]
    if (i != result.length - 1) {
      s += " "
    }
  }
  return s
}

function postprocess(s) {
  s = s.replace(/ ##/g, "")
  s = s.replace(/\[unused0\]/g, "")
  s = s.replace(/\[unused1\]/g, "")
  s = s.replace(/\[unused2\]/g, "")
  s = s.replace(/\[unused3\]/g, "")
  s = s.replace(/\[PAD\]/g, "")
  s = s.replace(/ +/g, " ")
  s = s.trim()
  return s
}

let vocab = cts.doc("vocab.json").toObject()
let reverse_vocab = cts.doc("reverse_vocab.json").toObject()
let article = "(CNN) An Iranian chess referee says she is frightened to return home after she was criticized online for not wearing the appropriate headscarf during an international tournament. Currently the chief adjudicator at the Women's World Chess Championship held in Russia and China, Shohreh Bayat says she fears arrest after a photograph of her was taken during the event and was then circulated online in Iran. \"They are very sensitive about the hijab when we are representing Iran in international events and even sometimes they send a person with the team to control our hijab,\" Bayat told CNN Sport in a phone interview Tuesday. The headscarf, or the hijab, has been a mandatory part of women's dress in Iran since the 1979 Islamic revolution but, in recent years, some women have mounted opposition and staged protests about headwear rules. Bayat said she had been wearing a headscarf at the tournament but that certain camera angles had made it look like she was not. \"If I come back to Iran, I think there are a few possibilities. It is highly possible that they arrest me [...] or it is possible that they invalidate my passport,\" added Bayat. \"I think they want to make an example of me.\" The photographs were taken at the first stage of the chess championship in Shanghai, China, but Bayat has since flown to Vladivostok, Russia, for the second leg between Ju Wenjun and Aleksandra Goryachkina. She was left \"panicked and shocked\" when she became aware of the reaction in Iran after checking her phone in the hotel room. The 32-year-old said she felt helpless as websites reportedly condemned her for what some described as protesting the country's compulsory law. Subsequently, Bayat has decided to no longer wear the headscarf. \"I'm not wearing it anymore because what is the point? I was just tolerating it, I don't believe in the hijab,\" she added. \"People must be free to choose to wear what they want, and I was only wearing the hijab because I live in Iran and I had to wear it. I had no other choice.\" Bayat says she sought help from the country's chess federation. She says the federation told her to post an apology on her social media channels. She agreed under the condition that the federation would guarantee her safety but she said they refused. \"My husband is in Iran, my parents are in Iran, all my family members are in Iran. I don't have anyone else outside of Iran. I don't know what to say, this is a very hard situation,\" she said. CNN contacted the Iranian Chess Federation on Tuesday but has yet to receive a response."

let processed = preprocess(article, vocab)
let src = processed[0]
let segs = processed[1]

let initLoop = ort.session(cts.doc("init_loop.onnx"))
let loopBody = ort.session(cts.doc("loop_body.onnx"))

let srcName = "src"
let segsName = "segs"
let stepName = "step"
let batchSize = 1

let inputs = {}
for (let i = 0; i < ort.sessionInputCount(initLoop); i++) {
  let name = ort.sessionInputName(initLoop, i)
  if (name == srcName) {
    let shape = ort.sessionInputType(initLoop, i)["shape"]
    shape[0] = batchSize
    inputs[name] = ort.value(src, shape, ort.sessionInputType(initLoop, i)["tensorType"])
  } else if (name == segsName) {
    let shape = ort.sessionInputType(initLoop, i)["shape"]
    shape[0] = batchSize
    inputs[name] = ort.value(segs, shape, ort.sessionInputType(initLoop, i)["tensorType"])
  } else if (name == stepName) {
    inputs[name] = ort.value([0], [1], "INT64")
  }
}

let initOutputs = ort.run(initLoop, inputs)
let names = []
for (let i = 0; i < ort.sessionOutputCount(initLoop); i++) {
  names.push(ort.sessionOutputName(initLoop, i))
}

let loopBodyInputs = {}
for (let i = 0; i < names.length; i++) {
  loopBodyInputs[names[i] + "_in"] = initOutputs[names[i]]
}

let step = 0
let maxStep = 50
let loopBodyOutputs
let result
let minLengthVal = ort.value([20], [1], "INT64")
while (step < maxStep) {
  let stepVal = ort.value([step], [1], "INT64")
  loopBodyInputs["step"] = stepVal
  loopBodyInputs["min_length"] = minLengthVal
  loopBodyOutputs = ort.run(loopBody, loopBodyInputs)
  for (let i = 0; i < names.length; i++) {
    loopBodyInputs[names[i] + "_in"] = loopBodyOutputs[names[i] + "_out"]
  }
  step++
  let resultVal = loopBodyOutputs["results"]
  result = ort.valueGetArray(resultVal)
  if (result[result.length - 1] == vocab["[unused2]"]) {
    break;
  }
}

let summarization = postprocess(getSummarization(result, reverse_vocab))
summarization
And the summarization looks like this:
shohreh bayat says she fears arrest after a photograph of her was circulated online
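For readers more at home in Python, the greedy longest-match-first splitting performed by wordpiece_tokenize in the JavaScript above, and the " ##" joining undone by postprocess, can be sketched as follows. This is a simplified illustration for a single word, not the tokenizer code from the repo, and toy_vocab is a made-up vocabulary (a set here, whereas the example above uses a token-to-id object).

```python
def wordpiece_tokenize(word, vocab, max_chars=100):
    """Greedy longest-match-first split of one word into vocabulary pieces."""
    if len(word) > max_chars:
        return ["[UNK]"]
    pieces = []
    start = 0
    while start < len(word):
        end = len(word)
        current = None
        # Try the longest remaining substring first, then shrink from the right.
        while start < end:
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate  # continuation pieces carry a prefix
            if candidate in vocab:
                current = candidate
                break
            end -= 1
        if current is None:
            return ["[UNK]"]  # no piece matches: the whole word is unknown
        pieces.append(current)
        start = end
    return pieces


def detokenize(pieces):
    """Join pieces with spaces, then drop the ' ##' joins as postprocess does."""
    return " ".join(pieces).replace(" ##", "")


toy_vocab = {"summar", "##ization", "##ize", "model"}  # made-up vocabulary

print(wordpiece_tokenize("summarization", toy_vocab))  # ['summar', '##ization']
print(wordpiece_tokenize("summarizer", toy_vocab))     # ['[UNK]']
print(detokenize(["summar", "##ization"]))             # summarization
```

The second call shows why a word becomes [UNK] as a whole: "summar" and "##ize" match, but the trailing "r" has no matching piece, so the partial split is discarded.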
Above is just one example of converting a state-of-the-art PyTorch NLP model to ONNX. It is true that the conversion is not a one-click solution; it requires a rather good understanding of PyTorch and of the model itself, as well as some non-trivial problem-solving through debugging and coding. However, this should be expected given the complex nature of the model. BERT is a very significant step forward for NLP and is very widely used. It is actually used in Google search today. Also, this model was not authored with conversion to ONNX in mind, making the job more difficult. Given the deep integration of PyTorch and ONNX, if the author of a model writes code with ONNX in mind, the conversion process will be much smoother.
Again, the code in this example is far from optimal or even idiomatic. This is just one way to make it work, as a proof of concept. With a better understanding of PyTorch and the model, there would definitely be much better solutions.
A summary of the above working code is available as a git patch to the original source code.