
MarkLogic Server 11.0 Product Documentation
Application Developer's Guide
— Chapter 21

Convert PyTorch Model to ONNX Model


Before reading this chapter, we strongly recommend that you become familiar with PyTorch and with the official PyTorch documentation on ONNX conversion: guide 1, guide 2.

General Steps

To convert a PyTorch model to an ONNX model, you need both the PyTorch model and the source code that generates the PyTorch model. Then you can load the model in Python using PyTorch, define dummy input values for all input variables of the model, and run the ONNX exporter to get an ONNX model.

Case Study: Text Summarization with Bert

ONNX support is built into PyTorch as a first-class citizen, so you don't need to look for third-party converters as you would with TensorFlow. However, even with built-in ONNX conversion, some models are still difficult to export. In general, there are three possible roadblocks:

  • unsupported operators
  • control flow
  • PyTorch internal bugs

For unsupported operators, you can either wait for support to be added to PyTorch or add it yourself. In many cases this is easier than you might think. For example, the model in this chapter needs the bitwise_not operator (which ~ on a boolean tensor maps to), but the ONNX exporter in PyTorch 1.4.0 does not support it. A quick search shows that support for this operator is already in the master branch of PyTorch; it just didn't make it into the latest official release (1.4.0). Adding the following code to the file /Library/Frameworks/Python.framework/Versions/3.7/lib/python3.7/site-packages/torch/onnx/symbolic_opset9.py (this path differs between operating systems and Python installations) adds support for this operator:

@parse_args('v')
def bitwise_not(g, inp):
  # Only boolean tensors map directly onto the ONNX Not operator
  if inp.type().scalarType() != 'Bool':
    return _unimplemented("bitwise_not", "non-bool tensor")
  return g.op("Not", inp)

For PyTorch internal bugs, you can either fix them yourself or wait for the PyTorch team to fix them. Fortunately, this case is very rare.

Control flow is explained in detail in the following example.

We will look at this example: Text Summarization with Bert. We will convert this particular PyTorch model to ONNX format, completely from scratch.

How does the Converter Work?

Intuitively speaking, the PyTorch-to-ONNX converter is a tracer. It takes a loaded model and a dummy input for the model, runs the model on that input while recording what happens internally, reconstructs an ONNX model that performs exactly the same operations, and saves the ONNX model to disk. For many types of models, this method works just fine. However, whenever a model contains control flow, such as for loops or if statements, the tracer-based method breaks down, because the tracer is never aware that the control flow statements exist: it faithfully records only the path taken for the supplied input. For example, if the model contains a for loop that runs max_step times, a tracer-based exporter simply unrolls the loop max_step times, using whatever value max_step has in the input supplied to the exporter (say, a). When the exported model is later run with a different value of max_step (say, b), it ignores that value and still runs the loop a times, which in most cases makes the result useless.

To solve this issue, PyTorch offers a second export method, completely different from the tracer-based one, that can handle control flow: the script-based method. Intuitively, the model source code is 'compiled' and analyzed, and the PyTorch 'compiler' correctly captures any control flow and exports it to ONNX. This sounds like a proper solution, but the script-based method currently places significant limitations on the Python language features the model source code may use; some features (lambda, for example) cannot be used when defining the model. Unless the model was written with export to ONNX in mind, it is generally very difficult to rewrite its source code to meet the requirements of the script-based method.
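The following toy sketch in plain Python illustrates the difference between the two methods described above. No PyTorch is involved: trace, run_traced, and run_scripted are hypothetical stand-ins, not PyTorch APIs. The tracer keeps only the operations that actually ran, so the loop is unrolled for the max_step value seen at export time, while the script-style version keeps the loop and honors a new max_step.

```python
from typing import List, Tuple

Op = Tuple[str, int]  # (operation name, constant operand)


def model(x: int, max_step: int, record: List[Op]) -> int:
    """A toy 'model' with a for loop; every executed operation is recorded."""
    for _ in range(max_step):
        record.append(("add", 2))
        x = x + 2
    return x


def trace(max_step: int) -> List[Op]:
    """Tracer-style export: run once on a dummy input and keep only the op list."""
    ops: List[Op] = []
    model(0, max_step, ops)
    return ops  # the loop itself is gone; its body is unrolled max_step times


def run_traced(ops: List[Op], x: int, max_step: int) -> int:
    """Replay the recorded ops. max_step is accepted but has no effect."""
    for name, operand in ops:
        if name == "add":
            x = x + operand
    return x


def run_scripted(x: int, max_step: int) -> int:
    """Script-style export keeps the loop, so the trip count is read at run time."""
    return model(x, max_step, [])


traced = trace(max_step=3)       # exported with max_step = a = 3
print(run_traced(traced, 0, 5))  # 6: still loops 3 times, ignoring b = 5
print(run_scripted(0, 5))        # 10: honors the new max_step
```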

MarkLogic is a document database, so we naturally want to work with models that handle text. Unfortunately, almost all models that handle text (with a small number of exceptions) contain control flow, because most of them construct their output iteratively: they generate one output word at a time, feeding the words generated so far back into the model to produce the next one. This makes exporting these PyTorch models to ONNX more challenging.
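A minimal sketch of that iterative pattern in plain Python: next_token is a hypothetical stand-in for one forward pass of a real model, and END is a made-up end-of-sequence token. The number of loop iterations depends on the data (on when END appears), which is exactly what a tracer cannot record.

```python
from typing import Callable, Dict, List

END = "[END]"  # hypothetical end-of-sequence token


def generate(next_token: Callable[[List[str]], str], max_length: int) -> List[str]:
    """Greedy autoregressive decoding: each step sees everything generated so far."""
    output: List[str] = []
    for _ in range(max_length):
        token = next_token(output)
        if token == END:  # data-dependent exit: the loop length is not fixed
            break
        output.append(token)
    return output


# Toy stand-in for a model: chooses the next word from the previous one.
BIGRAMS: Dict[str, str] = {"": "bayat", "bayat": "fears", "fears": "arrest", "arrest": END}


def toy_next_token(prefix: List[str]) -> str:
    return BIGRAMS[prefix[-1] if prefix else ""]


print(generate(toy_next_token, max_length=10))  # ['bayat', 'fears', 'arrest']
```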

Fortunately, with a good understanding of the model and the export mechanism, some coding, and ever-growing ONNX operator support, we can convert many text-handling models to ONNX.

Let's now look at the example.

Prepare the Environment

Text summarization is an important task in Natural Language Processing (NLP): the objective is to take a long article and return a short summary. There are plenty of research results on this topic. We picked a recent one, Text Summarization with Pretrained Encoders, to demonstrate the conversion to ONNX of a model produced with PyTorch that was never intended to be converted. It is worth noting that this model is based on BERT, a highly sophisticated language model that Google pretrained on a massive text corpus using a massive amount of computing power, intended as a starting point for other NLP tasks. Successfully converting this model to ONNX demonstrates that the ONNX format is quite capable and that, with ONNX support in MarkLogic, many of your pretrained models can work properly in the MarkLogic database. With this in mind, let's start by preparing the environment.

Install Python 3 if you don't already have it; it almost certainly comes with pip. Note that macOS and some Linux distributions come with a preinstalled Python, so make sure you are using the correct one. This task requires at least Python 3.6.

Clone this git repo for the paper, then install the prerequisites by executing

pip3 install --user -r requirements.txt

Although the requirements specify "torch==1.1.0", we first try the latest PyTorch (1.4.0 as of this writing), because it likely has better ONNX operator coverage and improved overall functionality. If the newest version of PyTorch fails, we revert to the version specified in the requirements. You can install the latest PyTorch here.

Now follow the instructions in the git repo to download the pretrained models and the training/testing datasets. We use CNN/DM BertExtAbs, the abstractive summarization model based on BERT and trained on the CNN/DM dataset. For the datasets, we use the prepared data.

After downloading and decompressing those files, move the model file to the models directory and the datasets to the bert_data directory. After these steps, in addition to the cloned source code, your models directory should contain the file model_step_148000.pt, and your bert_data directory should contain many files with names similar to cnndm.test.0.bert.pt.
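Before editing any code, it can help to confirm this layout. The helper below is a hypothetical convenience, not part of the repo; it assumes root is the top-level directory of the cloned repository (the export command later in this chapter reaches it as .. from src) and checks only the two names mentioned above.

```python
from pathlib import Path
from typing import List


def check_layout(root: str) -> List[str]:
    """Report what is missing from the models/ and bert_data/ layout."""
    base = Path(root)
    problems: List[str] = []
    if not (base / "models" / "model_step_148000.pt").is_file():
        problems.append("models/model_step_148000.pt not found")
    # glob() on a missing directory simply yields nothing
    if not any((base / "bert_data").glob("cnndm.test.*.bert.pt")):
        problems.append("no bert_data/cnndm.test.*.bert.pt files found")
    return problems


print(check_layout(".") or "layout looks complete")
```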

We are now ready to edit the source code to add a function to export the model to ONNX format.

Export the Model to ONNX

At this point, we need to read through the source code that generates the model. Since our goal is to convert this model to ONNX, load it into MarkLogic, and use it to summarize an article, we first need to understand how summarization is done in PyTorch. Understanding the model is always the most important and most difficult part of the conversion. For this particular model, notice that to summarize raw text the author suggests using -mode test_text -text_src $text_file -test_from $ckpt_file -task abs. Following the code path, we find that the function test_text_abs in the file train_abstractive.py is the entry point we need. The function mostly does the following:

  • construct and load the pretrained model
  • load and preprocess input text data
  • run the model based on the input
  • post-process the output to generate the summarization.

Let's start by trying to export the loaded model without any post-processing first, just to be sure that all operators are supported. We modify the train.py file to add a new mode called onnx_export, and then create a new file onnx_export.py under src. Put the following code in onnx_export.py:

import torch
from models import data_loader, model_builder
from models.data_loader import load_dataset
from models.model_builder import AbsSummarizer

model_flags = ['hidden_size', 'ff_size', 'heads', 'emb_size',
  'enc_layers', 'enc_hidden_size', 'enc_ff_size',
  'dec_layers', 'dec_hidden_size', 'dec_ff_size', 'encoder',
  'ff_actv', 'use_interval']

def onnx_export(args):
  device = "cpu"
  checkpoint = torch.load(
    args.test_from, map_location=lambda storage, loc: storage)
  opt = vars(checkpoint['opt'])
  for k in opt.keys():
    if (k in model_flags):
      setattr(args, k, opt[k])

  model = AbsSummarizer(args, device, checkpoint)
  model.eval()

  test_iter = data_loader.Dataloader(
    args,
    load_dataset(args, 'test', shuffle=False),
    args.test_batch_size,
    device,
    shuffle=False,
    is_test=True)
  # Use the first example of the first test batch as the dummy input; the
  # tuple order must match the parameters of AbsSummarizer.forward().
  for input_data in test_iter:
    dummy_input = (
      input_data.src.index_select(0, torch.tensor([0])),
      input_data.tgt.index_select(0, torch.tensor([0])),
      input_data.segs.index_select(0, torch.tensor([0])),
      input_data.clss.index_select(0, torch.tensor([0])),
      input_data.mask_src.index_select(0, torch.tensor([0])),
      input_data.mask_tgt.index_select(0, torch.tensor([0])),
      input_data.mask_cls.index_select(0, torch.tensor([0]))
    )
    torch.onnx.export(
      model,
      dummy_input,
      "AbsSummarizer.onnx",
      opset_version=11)
    break  # one batch is enough for tracing

The gist of the above code is to load the model just as when summarizing raw text, then use the first batch of input data as the dummy input and export the model to ONNX. The construction of dummy_input is dictated by the forward function of the AbsSummarizer class. Every PyTorch model has a forward function, whose signature determines the model's inputs (and whose return value determines its outputs). We extract the required input data from the first batch and feed it to the ONNX exporter to export the model.
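Because the exporter passes the dummy input tuple to forward positionally, the tuple order must match the signature exactly. One way to avoid ordering mistakes is to build the tuple from the signature itself. In this sketch, forward is a stand-in whose parameter names mirror the tuple built above (an assumption; check the real AbsSummarizer.forward), dummy_input_from_signature is a hypothetical helper, and plain values stand in for tensors.

```python
import inspect
from typing import Any, Callable, Dict, Tuple


def forward(src, tgt, segs, clss, mask_src, mask_tgt, mask_cls):
    """Stand-in with the same parameter order as the dummy input above."""
    return None


def dummy_input_from_signature(fn: Callable, batch: Dict[str, Any]) -> Tuple[Any, ...]:
    """Order the batch fields to match fn's parameters (self is skipped)."""
    params = [p for p in inspect.signature(fn).parameters if p != "self"]
    missing = [p for p in params if p not in batch]
    if missing:
        raise KeyError(f"batch is missing inputs for: {missing}")
    return tuple(batch[p] for p in params)
```

With a real model, passing the bound method model.forward works the same way, since inspect.signature already omits self for bound methods.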

Run

python3 train.py -mode onnx_export -task abs -test_from ../models/model_step_148000.pt -bert_data_path ../bert_data/cnndm

from the src directory. Unsurprisingly, we are greeted with an error:

RuntimeError: Subtraction, the `-` operator, with a bool tensor is not
supported. If you are trying to invert a mask, use the `~` or
`logical_not()` operator instead.

This is an easy fix: do as the error message suggests (invert the boolean mask with ~ or logical_not() instead of subtracting it), then try again.

We're now greeted with another error message:

RuntimeError: Only tuples, lists and Variables supported as JIT inputs
/outputs. Dictionaries and strings are also accepted but their usage is
not recommended. But got unsupported type NoneType

Looking at the definition of the AbsSummarizer class in model_builder.py, you will notice that the model returns two outputs, one of which is None. That's our culprit! Delete the None from the returned values and try again.

This time the export succeeds: the command finishes without errors, and there is a new 843 MB file, AbsSummarizer.onnx, in our src directory. However, notice that there are a couple of warnings:

PreSumm/src/models/encoder.py:42: TracerWarning: Converting a tensor to
a Python index might cause the trace to be incorrect. We can't record
the data flow of Python values, so this value will be treated as a
constant in the future. This means that the trace might not generalize
to other inputs!
    emb = emb + self.pe[:, :emb.size(1)]

PreSumm/src/models/decoder.py:64: TracerWarning: Converting a tensor to
a Python index might cause the trace to be incorrect. We can't record
the data flow of Python values, so this value will be treated as a
constant in the future. This means that the trace might not generalize
to other inputs!
    :tgt_pad_mask.size(1)], 0)

Warnings like these are fairly self-explanatory: a value derived from a tensor is being treated as a constant. When you run the exported model with a different set of inputs, that value does not change; it stays whatever it was for the input used during export, just as in the control-flow case, which can make the exported model completely useless.

To get around this issue, use torch.index_select instead of converting a tensor to a Python index. Note that different scenarios require different fixes; index_select is just one fix that works in this case. So the code in question:

emb = emb + self.pe[:, :emb.size(1)]

becomes

emb = emb + self.pe.index_select(1, torch.arange(emb.size(1)))

After applying the same kind of fix for the other warning, we can export the base AbsSummarizer model to ONNX without warnings.
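The rewrite changes only how the indexing is expressed, not which values are selected. The pure-Python sketch below checks this, with nested lists standing in for a [1, max_len, dim] positional-encoding tensor and two hypothetical helpers mimicking pe[:, :n] and index_select(1, arange(n)). As the warnings above indicate, what matters to the exporter is that the slice form turns a tensor-derived bound into a Python integer, while the index_select form avoids that conversion.

```python
from typing import List, Sequence

Tensor3 = List[List[List[float]]]  # nested lists standing in for a 3-D tensor


def index_select_dim1(t: Tensor3, indices: Sequence[int]) -> Tensor3:
    """Mimic t.index_select(1, indices) on a [d0, d1, d2] nested list."""
    return [[row[i] for i in indices] for row in t]


def slice_dim1(t: Tensor3, n: int) -> Tensor3:
    """Mimic t[:, :n]."""
    return [row[:n] for row in t]


# A tiny positional-encoding table of shape [1, 5, 2].
pe: Tensor3 = [[[float(p), float(p) * 10] for p in range(5)]]
n = 3  # plays the role of emb.size(1)
assert slice_dim1(pe, n) == index_select_dim1(pe, range(n))
print(index_select_dim1(pe, range(n)))  # [[[0.0, 0.0], [1.0, 10.0], [2.0, 20.0]]]
```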

We now know that the base model, without post-processing, can be exported successfully. Notice, however, that the base model performs only a single round of computation, generating one 'word' of the output summary. To generate the full summary, we need to imitate the predictor.translate function call and construct a real, working ONNX summarization model.

Now we need to look at the translate and _fast_translate_batch functions in predictor.py. Unsurprisingly, in the _fast_translate_batch function, which does the real work of generating the summary, we see a for loop:

for step in range(max_length):

Here max_length is the maximum length (in words) of the summary, and step is the length of the current work-in-progress summary. Recall that control flow can be exported with the script-based exporter, but because this code uses many advanced Python features that the script-based exporter does not support, that option is impractical (though still possible: you can always rewrite the code from scratch).

From here on there is no official way to proceed. In this particular case, we choose to export two models: one representing initialization plus the first loop iteration, and the other representing the loop body. We take the control flow out of the model and leave it to the application code (in other words, to XQuery or JavaScript in MarkLogic). The original application (pseudo)code changes from a single ort.run:

// pseudocode, it doesn't run!
let session = ort.session("text_summarization_all_in_one.onnx")
let input = article("Lorem ipsum dolor sit amet, consectetur adipiscing
  elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
return ort.run(session, input)

To a slightly more complicated one with a for loop:

// pseudocode, it doesn't run!
let init_loop = ort.session("init_loop.onnx")
let loop_body = ort.session("loop_body.onnx")
let init_loop_input = article("Lorem ipsum dolor sit amet, consectetur adipiscing elit, sed do eiusmod tempor incididunt ut labore et dolore magna aliqua.")
let init_loop_output = ort.run(init_loop, init_loop_input)
let loop_body_input = init_loop_output
let loop_body_output
for step in range(max_step):
  loop_body_output = ort.run(loop_body, loop_body_input)
  loop_body_input = loop_body_output
return loop_body_output
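The same split can be sketched in plain Python. Here init_loop and loop_body are hypothetical callables standing in for the two ONNX sessions; they take and return dictionaries of named values, and the driver renames each "<name>_out" output to "<name>_in" before the next iteration, mirroring the naming used in the export code later in this chapter. The toy models only append tokens so the control flow can be checked.

```python
from typing import Any, Callable, Dict, List

Values = Dict[str, Any]
Model = Callable[[Values], Values]


def run_split_model(init_loop: Model, loop_body: Model, inputs: Values,
                    state_names: List[str], max_step: int, end_token: int) -> List[int]:
    """Run init_loop once, then loop_body up to max_step times in application code."""
    outputs = init_loop(inputs)
    # The first loop-body call consumes the init model's outputs as "<name>_in".
    body_inputs = {name + "_in": outputs[name] for name in state_names}
    result: List[int] = []
    for step in range(max_step):
        body_inputs["step"] = step
        outputs = loop_body(body_inputs)
        # Feed every "<name>_out" back in as "<name>_in" for the next step.
        body_inputs = {name + "_in": outputs[name + "_out"] for name in state_names}
        result = outputs["results"]
        if result and result[-1] == end_token:  # stop early on the end token
            break
    return result


# Toy stand-ins: the only state is the sequence generated so far.
def toy_init(inputs: Values) -> Values:
    return {"alive_seq": [inputs["start"]]}


def toy_body(inputs: Values) -> Values:
    token = 99 if inputs["step"] == 2 else inputs["step"] + 10  # 99 = end token
    seq = inputs["alive_seq_in"] + [token]
    return {"alive_seq_out": seq, "results": seq}


print(run_split_model(toy_init, toy_body, {"start": 1}, ["alive_seq"],
                      max_step=50, end_token=99))  # [1, 10, 11, 99]
```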

To do this, we need to analyze what happens inside the _fast_translate_batch function and define our own two models. This takes quite a while and requires a good understanding of the model building and evaluation process, involving many more error and warning messages whose details are omitted here. Eventually we end up with the following two new model definitions in model_builder.py (this is far from an optimal definition; the objective here is to make it work with as few modifications to the original code as possible):

class InitLoopModel(nn.Module):
  def __init__(self, args, device, checkpoint):
    super(InitLoopModel, self).__init__()
    self.args = args
    self.device = device
    self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)
    self.vocab_size = self.bert.model.config.vocab_size
    tgt_embeddings = nn.Embedding(
      self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)
    self.decoder = TransformerDecoder(self.args.dec_layers,
                                      self.args.dec_hidden_size,
                                      heads=self.args.dec_heads,
                                      d_ff=self.args.dec_ff_size,
                                      dropout=self.args.dec_dropout,
                                      embeddings=tgt_embeddings)

    self.generator = get_generator(
      self.vocab_size, self.args.dec_hidden_size, device)
    self.generator[0].weight = self.decoder.embeddings.weight
    self.load_state_dict(checkpoint['model'], strict=False)
    self.to(device)

  def forward(self, src, segs, step):
    min_length = self.args.min_length
    beam_size = self.args.beam_size
    mask_src = ~(src == 0)
    batch_size = src.size(0)
    src_features = self.bert(src, segs, mask_src)
    device = src_features.device
    dec_states = self.decoder.init_decoder_state(
      src, src_features, with_cache=False)
    dec_states.src = tile(dec_states.src, beam_size, 0)
    src_features = tile(src_features, beam_size, dim=0)
    beam_offset = torch.arange(
      0, batch_size * beam_size, step=beam_size, dtype=torch.long,
      device=device)
    alive_seq = torch.full([batch_size * beam_size, 1],
                            1, dtype=torch.long, device=device)
    const_topk_log_probs = torch.tensor(
      [0.0] + [float("-inf")] * (beam_size - 1), device=device)
    topk_log_probs = (const_topk_log_probs.repeat(batch_size))
    decoder_input = alive_seq[:, -1].view(1, -1)
    decoder_input = decoder_input.transpose(0, 1)
    dec_out, dec_states = self.decoder(
      decoder_input, src_features, dec_states, step=step)
    log_probs = self.generator.forward(dec_out.transpose(0, 1).squeeze(0))
    vocab_size = log_probs.size(-1)
    endprob = torch.tensor([-1e20]).repeat(log_probs.size(0))
    new_log_probs = torch.cat([log_probs.index_select(-1,
      torch.arange(2)), endprob.view(-1).unsqueeze(
      1), log_probs.index_select(-1, torch.arange(start=3,
      end=log_probs.size(1)))], -1) + topk_log_probs.view(
      -1).unsqueeze(1)
    alpha = self.args.alpha
    length_penalty = ((5.0 + (1)) / 6.0) ** alpha
    curr_scores = new_log_probs / length_penalty
    curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
    topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)
    topk_log_probs = topk_scores * length_penalty
    topk_beam_index = topk_ids.div(vocab_size)
    topk_ids = topk_ids.fmod(vocab_size)
    batch_index = (topk_beam_index + beam_offset.index_select(0,
              torch.arange(topk_beam_index.size(0))).unsqueeze(1))
    select_indices = batch_index.view(-1)
    alive_seq = torch.cat([alive_seq.index_select(
      0, select_indices), topk_ids.view(-1, 1)], -1)
    src_features = src_features.index_select(0, select_indices)
    dec_states.src = dec_states.src.index_select(0, select_indices)
    return (src_features, dec_states.src, dec_states.previous_input,
            dec_states.previous_layer_inputs, alive_seq, topk_log_probs)

class LoopBodyModel(nn.Module):
  def __init__(self, args, device, checkpoint):
    super(LoopBodyModel, self).__init__()
    self.args = args
    self.device = device
    self.bert = Bert(args.large, args.temp_dir, args.finetune_bert)
    self.vocab_size = self.bert.model.config.vocab_size
    tgt_embeddings = nn.Embedding(
      self.vocab_size, self.bert.model.config.hidden_size, padding_idx=0)
    self.decoder = TransformerDecoder(self.args.dec_layers,
                                      self.args.dec_hidden_size,
                                      heads=self.args.dec_heads,
                                      d_ff=self.args.dec_ff_size,
                                      dropout=self.args.dec_dropout,
                                      embeddings=tgt_embeddings)
    self.generator = get_generator(
      self.vocab_size, self.args.dec_hidden_size, device)
    self.generator[0].weight = self.decoder.embeddings.weight
    self.load_state_dict(checkpoint['model'], strict=False)
    self.to(device)

  def forward(self, step, min_length, src_features, decoder_state_src,
              decoder_state_previous_input,
              decoder_state_previous_layer_inputs, alive_seq,
              topk_log_probs):
    beam_size = self.args.beam_size
    batch_size = src_features.size(0).div(beam_size)
    beam_offset = torch.arange(
      0, batch_size * beam_size, step=beam_size, dtype=torch.long,
      device=self.device)
    decoder_input = alive_seq[:, -1].view(1, -1)
    decoder_input = decoder_input.transpose(0, 1)
    dec_states = TransformerDecoderState(decoder_state_src)
    dec_states.previous_input = decoder_state_previous_input
    dec_states.previous_layer_inputs = \
        decoder_state_previous_layer_inputs
    dec_out, dec_states = self.decoder(
      decoder_input, src_features, dec_states, step=step)
    log_probs = self.generator.forward(dec_out.transpose(0,
                                                        1).squeeze(0))
    vocab_size = log_probs.size(-1)
    small = torch.tensor([-1e20])
    tooshort = small*torch.lt(step, min_length).float()
    longenough = log_probs[:, 2]*((~step.lt(min_length)).float())
    endprob = tooshort + longenough
    new_log_probs = torch.cat([log_probs.index_select(-1, torch.arange(2)),
                               endprob.view(-1).unsqueeze(1),
                               log_probs.index_select(
                                   -1, torch.arange(start=3,
                                                    end=log_probs.size(1)))],
                              -1) + topk_log_probs.view(-1).unsqueeze(1)
    alpha = self.args.alpha
    length_penalty = ((5.0 + (step + 1)) / 6.0) ** alpha
    curr_scores = new_log_probs / length_penalty
    curr_scores = curr_scores.reshape(-1, beam_size * vocab_size)
    topk_scores, topk_ids = curr_scores.topk(beam_size, dim=-1)
    topk_log_probs = topk_scores * length_penalty
    topk_beam_index = topk_ids.div(vocab_size)
    topk_ids = topk_ids.fmod(vocab_size)
    batch_index = topk_beam_index + \
      beam_offset.index_select(0, torch.arange(
        topk_beam_index.size(0))).unsqueeze(1)
    select_indices = batch_index.view(-1)
    alive_seq = torch.cat([alive_seq.index_select(
      0, select_indices), topk_ids.view(-1, 1)], -1)
    src_features = src_features.index_select(0, select_indices)
    dec_states.src = dec_states.src.index_select(0, select_indices)
    results = alive_seq.index_select(
      0, select_indices.index_select(0, torch.tensor(0)))
    return (src_features, dec_states.src, dec_states.previous_input,
            dec_states.previous_layer_inputs, alive_seq, topk_log_probs,
            results, endprob)
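The index arithmetic at the end of both forward functions follows the usual beam-search pattern. Scores for every (beam, token) pair are divided by the length penalty ((5 + length) / 6) ** alpha (InitLoopModel uses length 1, where the penalty is exactly 1 for any alpha; LoopBodyModel uses step + 1), flattened into one row of beam_size * vocab_size entries, and the best beam_size entries are kept. Each flat index is then split into a beam index (integer division) and a token id (remainder), and adding beam_offset turns the beam index into a row of the batch * beam tensors. The plain-Python sketch below shows these steps for a single example; the function names are ours and lists stand in for tensors. One hedged caveat: newer PyTorch releases changed Tensor.div on integer tensors to true division, so topk_ids.div(vocab_size) may need explicit floor division there.

```python
import heapq
from typing import List, Tuple


def length_penalty(length: int, alpha: float) -> float:
    """((5 + length) / 6) ** alpha, the expression used in both forward functions."""
    return ((5.0 + length) / 6.0) ** alpha


def select_beams(log_probs: List[List[float]], length: int, alpha: float,
                 beam_size: int) -> List[Tuple[int, int, float]]:
    """Pick the top beam_size (beam, token) pairs from a [beam, vocab] table.

    Returns (beam_index, token_id, normalized_score) triples, best first.
    """
    vocab_size = len(log_probs[0])
    penalty = length_penalty(length, alpha)
    flat = [lp / penalty for row in log_probs for lp in row]  # [beam * vocab]
    top = heapq.nlargest(beam_size, range(len(flat)), key=flat.__getitem__)
    # Integer division recovers the beam; the remainder is the token id.
    return [(i // vocab_size, i % vocab_size, flat[i]) for i in top]


# Two beams over a toy vocabulary of four tokens, at the first step (length 1).
log_probs = [
    [-1.0, -0.2, -3.0, -2.0],  # beam 0
    [-0.1, -4.0, -0.5, -5.0],  # beam 1
]
print(length_penalty(1, 0.95))              # 1.0: the first step is never rescaled
print(select_beams(log_probs, 1, 0.95, 2))  # [(1, 0, -0.1), (0, 1, -0.2)]
```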

And our export code in onnx_export.py becomes:

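# Runs inside onnx_export(args), after checkpoint and the model flags are
# loaded as before; it also needs
# "from models.model_builder import InitLoopModel, LoopBodyModel".
init_model = InitLoopModel(args, device, checkpoint)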
init_model.eval()
test_iter = data_loader.Dataloader(args, load_dataset(args, 'test',
                                                      shuffle=False),
                                   args.test_batch_size, device,
                                    shuffle=False, is_test=True)
loop_body_model = LoopBodyModel(args, device, checkpoint)
loop_body_model.eval()
for batch in test_iter:
  init_inputs = (batch.src.index_select(0, torch.tensor([0])),
              batch.segs.index_select(0, torch.tensor([0])),
              torch.tensor([0]))
  torch.onnx.export(init_model, init_inputs, "init_loop.onnx",
                    verbose=False,
                    input_names=["src", "segs", "step"],
                    output_names=["src_features",
                                  "decoder_states_src",
                                  "decoder_states_previous_input",
                               "decoder_states_previous_layer_inputs",
                                  "alive_seq", "topk_log_probs"],
                    opset_version=11,
                    dynamic_axes={"src": {0: "batch"}, "segs": {0:
                      "batch"}, "src_features": {0: "batchXbeam"},
                      "decoder_states_src": {0: "batchXbeam"},
                      "decoder_states_previous_input": {0:
                      "batchXbeam"},
                      "decoder_states_previous_layer_inputs": {0:
                      "batch", 1: "batchXbeam"}, "alive_seq": {0:
                      "batchXbeam"}, "topk_log_probs": {0: "batch"}})

  # Run the initialization model once to obtain realistic loop-body inputs.
  (src_features, decoder_state_src, decoder_state_previous_input,
   decoder_state_previous_layer_inputs, alive_seq,
   topk_log_probs) = init_model.forward(
    batch.src.index_select(0, torch.tensor([0])),
    batch.segs.index_select(0, torch.tensor([0])),
    torch.tensor([0]))
  loop_inputs = (torch.tensor(1), torch.tensor(20), src_features,
    decoder_state_src,
    decoder_state_previous_input, decoder_state_previous_layer_inputs,
    alive_seq, topk_log_probs)
  torch.onnx.export(loop_body_model, loop_inputs, "loop_body.onnx",
    verbose=False,
    input_names=["step", "min_length", "src_features_in",
      "decoder_states_src_in", "decoder_states_previous_input_in",
      "decoder_states_previous_layer_inputs_in", "alive_seq_in",
      "topk_log_probs_in"],
    output_names=["src_features_out",
      "decoder_states_src_out", "decoder_states_previous_input_out",
      "decoder_states_previous_layer_inputs_out", "alive_seq_out",
      "topk_log_probs_out", "results", "endprob"],
    opset_version=11,
    dynamic_axes={"src_features_in": {0: "batchXbeam"},
      "decoder_states_src_in": {0: "batchXbeam"},
      "decoder_states_previous_input_in": {0: "batchXbeam"},
      "decoder_states_previous_layer_inputs_in": {0: "batch", 1:
      "batchXbeam", 2: "prev_step"}, "alive_seq_in": {0: "batchXbeam",
      1: "prev_step"}, "topk_log_probs_in": {0: "batch"},
      "src_features_out": {0: "batchXbeam"}, "decoder_states_src_out":
      {0: "batchXbeam"}, "decoder_states_previous_input_out": {0:
      "batchXbeam", 1: "step"},
      "decoder_states_previous_layer_inputs_out": {0: "batch", 1:
      "batchXbeam", 2: "step"}, "alive_seq_out": {0: "batchXbeam", 1:
      "step"}, "topk_log_probs_out": {0: "batch"}, "results": {0:
      "batch", 2: "step"}})
  break
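The JavaScript driver in the next section relies on a naming convention: each output of init_loop.onnx, suffixed with _in, must be an input of loop_body.onnx, and each such input must have a matching _out output. A small check like the one below, run against the name lists passed to torch.onnx.export, can catch a typo before the models are loaded into MarkLogic. The helper is our own (it is not part of PyTorch); the lists reproduce the names used in the export code above.

```python
from typing import Dict, List


def check_loop_names(init_outputs: List[str], body_inputs: List[str],
                     body_outputs: List[str],
                     dynamic_axes: Dict[str, Dict[int, str]]) -> List[str]:
    """List violations of the init/loop-body naming convention (empty if none)."""
    problems: List[str] = []
    for name in init_outputs:
        if name + "_in" not in body_inputs:
            problems.append(f"loop body has no input {name}_in")
        if name + "_out" not in body_outputs:
            problems.append(f"loop body has no output {name}_out")
    known = set(body_inputs) | set(body_outputs)
    for key in dynamic_axes:
        if key not in known:
            problems.append(f"dynamic_axes key {key!r} is not an input or output name")
    return problems


# Names used in the init_loop.onnx and loop_body.onnx exports above.
INIT_OUTPUTS = ["src_features", "decoder_states_src",
                "decoder_states_previous_input",
                "decoder_states_previous_layer_inputs", "alive_seq",
                "topk_log_probs"]
BODY_INPUTS = ["step", "min_length"] + [n + "_in" for n in INIT_OUTPUTS]
BODY_OUTPUTS = [n + "_out" for n in INIT_OUTPUTS] + ["results", "endprob"]

print(check_loop_names(INIT_OUTPUTS, BODY_INPUTS, BODY_OUTPUTS,
                       {"results": {0: "batch", 2: "step"}}))  # []
```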

Running the Model in MarkLogic using JavaScript

After exporting the two models, to make them work properly in MarkLogic we also need to port the preprocessing and postprocessing code to XQuery or JavaScript. This is much easier than exporting the model. A final working example looks like this (again, not the most optimal code; the objective is to faithfully translate the original Python code to JavaScript):

'use strict';
function whitespace_tokenize(s) {
  return s.split(" ")
}

function wordpiece_tokenize(s, vocab) {
  let output = []
  let wstokens = whitespace_tokenize(s)
  for (let i = 0; i < wstokens.length; i++) {
    let token = wstokens[i]
    if (token.length > 100) {
      output.push("[UNK]")
      continue
    }
    let is_bad = false
    let start = 0
    let sub_tokens = []
    while (start < token.length) {
      let end = token.length
      let cur_substr = null
      while (start < end) {
        let substr = token.substr(start, end - start)
        if (start > 0)
          substr = "##" + substr
        if (vocab.hasOwnProperty(substr)) {
          cur_substr = substr
          break
        }
        end -= 1
      }
      if (cur_substr == null) {
        is_bad = true
        break
      }
      sub_tokens.push(cur_substr)
      start = end
    }
    if (is_bad) {
      output.push("[UNK]")
    }
    else {
      for (let j = 0; j < sub_tokens.length; j++) {
        output.push(sub_tokens[j])
      }
    }
  }
  return output
}

function tokenize(s, vocab) {
  s = s.trim().toLowerCase()
  let pretokens = s.split(" ")
  let tokens = ["[CLS]"]
  for (let i = 0; i < pretokens.length; i++) {
    let t = pretokens[i]
    let subtokens = wordpiece_tokenize(t, vocab)
    for (let j = 0; j < subtokens.length; j++) {
      let token = subtokens[j]
      tokens.push(token)
      if (tokens.length >= 511) {
        break;
      }
    }
    if (tokens.length >= 511) {
      break;
    }
  }
  tokens.push("[SEP]")
  return tokens
}

function preprocess(s, vocab) {
  var tokens = tokenize(s, vocab)
  var src = []
  var segs = []
  for (var i = 0; i < 512; i++) {
    if (i < tokens.length) {
      src.push(vocab[tokens[i]])
      segs.push(0)
    } else {
      src.push(0)
      segs.push(1)
    }
  }
  return [src, segs]
}

function getSummarization(result, reverse_vocab) {
  let s = ""
  for (let i = 0; i < result.length; i++) {
    s += reverse_vocab[result[i]]
    if (i != result.length - 1) {
      s += " "
    }
  }
  return s
}

function postprocess(s) {
  s = s.replace(/ ##/g, "")
  s = s.replace(/\[unused0\]/g, "")
  s = s.replace(/\[unused1\]/g, "")
  s = s.replace(/\[unused2\]/g, "")
  s = s.replace(/\[unused3\]/g, "")
  s = s.replace(/\[PAD\]/g, "")
  s = s.replace(/ +/g, " ")
  s = s.trim()
  return s
}

let vocab = cts.doc("vocab.json").toObject()
let reverse_vocab = cts.doc("reverse_vocab.json").toObject()
// A template literal lets the article span several lines; the line breaks
// are collapsed to single spaces because tokenize() splits on " " only.
let article = `(CNN) An Iranian chess referee says she is frightened to
  return home after she was criticized online for not wearing the
  appropriate headscarf during an international tournament. Currently
  the chief adjudicator at the Women's World Chess Championship held in
  Russia and China, Shohreh Bayat says she fears arrest after a
  photograph of her was taken during the event and was then circulated
  online in Iran. "They are very sensitive about the hijab when we are
  representing Iran in international events and even sometimes they
  send a person with the team to control our hijab," Bayat told CNN
  Sport in a phone interview Tuesday. The headscarf, or the hijab, has
  been a mandatory part of women's dress in Iran since the 1979 Islamic
  revolution but, in recent years, some women have mounted opposition
  and staged protests about headwear rules. Bayat said she had been
  wearing a headscarf at the tournament but that certain camera angles
  had made it look like she was not. "If I come back to Iran, I think
  there are a few possibilities. It is highly possible that they arrest
  me [...] or it is possible that they invalidate my passport," added
  Bayat. "I think they want to make an example of me." The
  photographs were taken at the first stage of the chess championship
  in Shanghai, China, but Bayat has since flown to Vladivostok, Russia,
  for the second leg between Ju Wenjun and Aleksandra Goryachkina. She
  was left "panicked and shocked" when she became aware of the
  reaction in Iran after checking her phone in the hotel room. The
  32-year-old said she felt helpless as websites reportedly condemned her
  for what some described as protesting the country's compulsory law.
  Subsequently, Bayat has decided to no longer wear the headscarf.
  "I'm not wearing it anymore because what is the point? I was just
  tolerating it, I don't believe in the hijab," she added. "People
  must be free to choose to wear what they want, and I was only wearing
  the hijab because I live in Iran and I had to wear it. I had no other
  choice." Bayat says she sought help from the country's chess
  federation. She says the federation told her to post an apology on
  her social media channels. She agreed under the condition that the
  federation would guarantee her safety but she said they refused. "My
  husband is in Iran, my parents are in Iran, all my family members are
  in Iran. I don't have anyone else outside of Iran. I don't know what
  to say, this is a very hard situation," she said. CNN contacted the
  Iranian Chess Federation on Tuesday but has yet to receive a
  response.`.replace(/\s+/g, " ")
let processed = preprocess(article, vocab)
let src = processed[0]
let segs = processed[1]

let initLoop = ort.session(cts.doc("init_loop.onnx"))
let loopBody = ort.session(cts.doc("loop_body.onnx"))

let srcName = "src"
let segsName = "segs"
let stepName = "step"
let batchSize = 1
let inputs = {}

for (let i = 0; i < ort.sessionInputCount(initLoop); i++) {
  let name = ort.sessionInputName(initLoop, i)
  if (name == srcName) {
    let shape = ort.sessionInputType(initLoop, i)["shape"]
    shape[0] = batchSize
    inputs[name] = ort.value(src, shape, ort.sessionInputType(initLoop, i)["tensorType"])
  } else if (name == segsName) {
    let shape = ort.sessionInputType(initLoop, i)["shape"]
    shape[0] = batchSize
    inputs[name] = ort.value(segs, shape, ort.sessionInputType(initLoop, i)["tensorType"])
  } else if (name == stepName) {
    inputs[name] = ort.value([0], [1], "INT64")
  }
}
let initOutputs = ort.run(initLoop, inputs)
let names = []
for (let i = 0; i < ort.sessionOutputCount(initLoop); i++) {
  names.push(ort.sessionOutputName(initLoop, i))
}

let loopBodyInputs = {}
for (let i = 0; i < names.length; i++) {
  loopBodyInputs[names[i] + "_in"] = initOutputs[names[i]]
}
let step = 0
let maxStep = 50
let loopBodyOutputs
let result
let minLengthVal = ort.value([20], [1], "INT64")
while (step < maxStep) {
  let stepVal = ort.value([step], [1], "INT64")
  loopBodyInputs["step"] = stepVal
  loopBodyInputs["min_length"] = minLengthVal
  loopBodyOutputs = ort.run(loopBody, loopBodyInputs)
  for (let i = 0; i < names.length; i++) {
    loopBodyInputs[names[i] + "_in"] = loopBodyOutputs[names[i] + "_out"]
  }
  step++
  let resultVal = loopBodyOutputs["results"]
  result = ort.valueGetArray(resultVal)
  if (result[result.length - 1] == vocab["[unused2]"]) {
    break;
  }
}
let summarization = postprocess(getSummarization(result, reverse_vocab))
summarization

And the summarization looks like this:

shohreh bayat says she fears arrest after a photograph of her was circulated online
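The tokenizer in the JavaScript above is a greedy longest-match-first WordPiece tokenizer: for each word it repeatedly takes the longest prefix found in the vocabulary, marks continuation pieces with ##, and maps the whole word to [UNK] if some remainder cannot be matched; postprocess later removes the ## marks. For testing that logic outside MarkLogic, here is a Python sketch of the same two steps. The function names are ours, and the tiny vocabulary is made up (the real one comes from vocab.json).

```python
from typing import Dict, List


def wordpiece(word: str, vocab: Dict[str, int], max_chars: int = 100) -> List[str]:
    """Greedy longest-match-first WordPiece, as in wordpiece_tokenize above."""
    if len(word) > max_chars:
        return ["[UNK]"]
    pieces: List[str] = []
    start = 0
    while start < len(word):
        end = len(word)
        match = None
        while start < end:  # shrink the candidate until it is in the vocabulary
            candidate = word[start:end]
            if start > 0:
                candidate = "##" + candidate
            if candidate in vocab:
                match = candidate
                break
            end -= 1
        if match is None:  # some remainder cannot be matched
            return ["[UNK]"]
        pieces.append(match)
        start = end
    return pieces


def join_pieces(pieces: List[str]) -> str:
    """Undo the ## continuation marks, like postprocess() above."""
    return " ".join(pieces).replace(" ##", "")


VOCAB = {w: i for i, w in enumerate(["[UNK]", "bay", "##at", "says", "she"])}
tokens = [p for w in "bayat says she".split() for p in wordpiece(w, VOCAB)]
print(tokens)                   # ['bay', '##at', 'says', 'she']
print(join_pieces(tokens))      # bayat says she
print(wordpiece("xyz", VOCAB))  # ['[UNK]']
```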

Conclusion

Above is just one example of converting a state-of-the-art PyTorch NLP model to ONNX. The conversion is not a one-click solution: it requires a rather good understanding of PyTorch and of the model itself, as well as some non-trivial problem-solving through debugging and coding. However, this should be expected given the complexity of the model. BERT is a very significant step forward for NLP and is very widely used; it is even used in Google Search today. Also, this model was not written with conversion to ONNX in mind, which makes the job more difficult. Given the deep integration of PyTorch and ONNX, if the author of a model writes the code with ONNX in mind, the conversion process should be much smoother.

Again, the code in this example is far from optimal or even idiomatic. This is just one way to make it work, as a proof of concept. With a better understanding of PyTorch and the model, there would definitely be much better solutions.

A summary of the above working code is available as a git patch to the original source code.
