Intro

We have packaged our ML model, added tests, and handed it over to the DevOps team to set up the continuous integration and continuous deployment pipelines. Because of the packaging and tests we added, the continuous integration step should be fairly standard, assuming your DevOps team is used to dealing with Python packages.

However, progress has somehow stalled again, and we are back in a meeting with Jane, the DevOps lead:

   Frank (the ML person):  Hey Jane, what's holding up the deployment of
                           that ML project of mine?

Jane (the DevOps person):  Oh yeah, that. So the tests are passing, but your
                           project is dependent on PyTorch, which isn't a
                           sanctioned library yet, so we did a workaround
                           to get the tests to pass. But we can't deploy
                           it until we sort out whether we are going to
                           add support for PyTorch.

                   Frank:  (surprised) But PyTorch is what the ML community
                           is using for model development and training these
                           days. I wish someone had told me that this needs
                           approval first!

                    Jane:  (shrugs) I don't know what to tell you. We currently
                           support <insert other ML libraries here>, and to
                           take on a new one, it needs to be reviewed first.

                   Frank:  (frustrated) How long would that take? And what if
                           PyTorch is not "approved" for whatever reason? My
                           model would never get launched?

                    Jane:  (cagey) Let's take this offline. Let us review the
                           situation and get back to you.

                   Frank:  (speechless)...

Is Frank’s ML model destined for the scrapyard because external factors prevent it from being deployed? What can be done to salvage the situation?

The Lowdown

Whether or not the reasoning is sound or justified, there will be scenarios where we as ML folks are limited in the libraries, tools, and languages we get to work with. The reasons may be technical, political, or even monetary, but the reality is that such limits can greatly delay, if not jeopardize, the deployment of an ML project. It’s best to anticipate such potential pitfalls and, even better, to know the tools to navigate around them and get your ML projects into production.

In this post, we will go through a tactical solution for working around the PyTorch dependency so Frank’s model can be deployed. The larger picture is beyond the scope of this post, but it’s worth finding out, from the get-go, which production environments (e.g., cloud, hardware, operating system, mobile devices, IoT) your team’s ML projects will be deployed into. More often than not, each environment will have constraints and idiosyncrasies that dictate which ML technologies will work within it. This is a big topic on its own, but it’s good to be aware of it and start planning as early as possible.

Let’s get started.

Step 0: The Plan

The good news is, there are multiple routes to get around Frank’s issue. The bad news is, the process is still somewhat messy. The situation is steadily improving, though, and a common format appears to be emerging as a viable path for converting ML models between ML libraries. That format is the Open Neural Network eXchange (ONNX), for which PyTorch has been steadily improving its support. We will convert Frank’s ML model into this format, so that the ONNX version can either be used for inferencing directly or be converted into another format that Jane’s team supports. This achieves the goal of removing the PyTorch dependency and using whichever ML library the external environment dictates for inferencing.

In a way, ONNX provides the flexibility to train models in one library and deploy them for inferencing in another (or in multiple targets). In theory, at least. The actual process can get complicated quickly, so a word of advice: if there’s a chance that model conversion is needed in your deployment workflow, test out the conversion process early on, so you’re not caught with a model that fails to convert and potentially scuttles the entire project. The usual culprit is a custom layer or activation function that ONNX doesn’t support; you’d then need to resolve it yourself or find an alternative, neither of which is ideal.

For this tutorial, we will walk through a complication-free example so we can see what a successful conversion process looks like. The plan in pseudocode is:

import franks-ml-model

instantiate franks-ml-model
export franks-ml-model into ONNX format

load the ONNX model
run inference with ONNX model using the same test image
ensure ONNX's output is the same as PyTorch's

Step 1: Export to ONNX

Let’s start by test-converting the base model, EfficientNet in our example, to ONNX. We can then add Frank’s wrapper around it afterwards. This is the better sequence, since the ONNX export process can fail. To see whether this is the case for EfficientNet, we can run the conversion with a short script:

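import torch
from efficientnet_pytorch import EfficientNet

if __name__ == '__main__':
    name = 'efficientnet-b0'
    # dummy input in NCHW layout: batch of 1, 3 color channels, 224x224 pixels
    dummy_input = torch.randn(1, 3, 224, 224)
    # load the model with its pretrained weights
    model = EfficientNet.from_pretrained(name)
    # trace the model with the dummy input and save the graph in ONNX format
    torch.onnx.export(model, dummy_input, "/tmp/%s.onnx" % name )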

PyTorch makes the process very simple: we create a dummy input of the shape the network expects (a batch of 1, 3 channels, and a 224x224 image in this example model), load the model, and call torch.onnx.export with the dummy input. PyTorch runs the dummy input through the model to trace its dynamic graph, then converts the graph and saves it in ONNX format.

Let’s see what happens when we run this script:

$ python export_model_onnx.py
Loaded pretrained weights for efficientnet-b0
Traceback (most recent call last):
  File "export_model_onnx.py", line 18, in <module>
    torch.onnx.export(model, dummy_input, "/tmp/%s.onnx" % name )
  File ".../torch/onnx/__init__.py", line 148, in export
    strip_doc_string, dynamic_axes, keep_initializers_as_inputs)
  File ".../torch/onnx/utils.py", line 66, in export
    dynamic_axes=dynamic_axes, keep_initializers_as_inputs=keep_initializers_as_inputs)
  File ".../torch/onnx/utils.py", line 428, in _export
    operator_export_type, strip_doc_string, val_keep_init_as_ip)
RuntimeError: ONNX export failed:
  Couldn't export Python operator SwishImplementation

Oops, something in the model is not exportable to ONNX. This is why it’s good to run the conversion early, to check whether the model uses operators that ONNX doesn’t support natively. In this case, the culprit is SwishImplementation, the custom autograd function behind EfficientNet’s MemoryEfficientSwish activation.

Whether custom operators can be supported in ONNX is beyond the scope of this post. Here, the author, Luke, added support for ONNX export by providing two implementations of the swish activation function, so we dodged a bullet. We just need to add this line to our export script after loading the model:

    model = EfficientNet.from_pretrained(name)
    model.set_swish(memory_efficient=False)

    torch.onnx.export(...)

This time there should be no errors, and the .onnx file is generated:

$ ls -lh /tmp/efficientnet-b0.onnx
-rw-r--r--  1 gerald  wheel    20M Mar 15 07:52 /tmp/efficientnet-b0.onnx
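For context on why one implementation exports and the other doesn’t: the plain swish activation is x * sigmoid(x), built entirely from standard operations that map onto ONNX’s Sigmoid and Mul operators, while the memory-efficient variant computes the same values inside the custom SwishImplementation autograd function, which the exporter cannot see into. Here is a minimal numpy sketch of the math, assuming the standard swish definition:

```python
import numpy as np

def sigmoid(x):
    # logistic function, written with standard elementwise operations
    return 1.0 / (1.0 + np.exp(-x))

def swish(x):
    # swish(x) = x * sigmoid(x); in ONNX terms, a Sigmoid followed by a Mul
    return x * sigmoid(x)

x = np.array([-4.0, -1.0, 0.0, 1.0, 4.0])
print(np.round(swish(x), 4))
```

Since both variants compute the same function, switching to the exportable one with set_swish(memory_efficient=False) changes how the activation is implemented, not what the model predicts.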

Step 2: Removing PyTorch dependencies

With the model converted to ONNX, we can run inference in a variety of runtime environments of our choosing, including ONNX Runtime, TensorFlow, TensorRT, MXNet, etc., independently of PyTorch. To keep the steps simple, we will use ONNX Runtime to validate that the model was converted correctly, by running the same dog image through it and checking that we get the same class probabilities. We will base our test script on this example ONNX code.

Similarly, the inference script has three blocks:

  1. Load model
  2. Load image and preprocess
  3. Inference and post-process results

The first block is quite simple:

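import numpy as np
import onnxruntime
from PIL import Image

fn_onnx = '/tmp/efficientnet-b0.onnx'
# load the exported model; listing the provider explicitly avoids errors on
# onnxruntime builds that ship with more than one execution provider
session = onnxruntime.InferenceSession(fn_onnx, providers=['CPUExecutionProvider'])

# name of the model's (single) input tensor, needed when calling session.run
input_name = session.get_inputs()[0].name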

The second block is more complicated. We want to resize and normalize the image using the same process the model was trained with, but without depending on PyTorch routines for the transforms. That means we need to reimplement them, using PIL and numpy in this tutorial. It’s possible to use other image libraries such as OpenCV, but that would introduce more variables, since torchvision uses PIL under the hood for image transforms (at least for now). We simplified the example ONNX code a bit so the preprocessing block is closer to what Frank did in his PyTorch version:

def preprocess(image):
    # convert the input data into the float32 input
    img_data = np.array(image).astype('float32')

    # normalize
    mean_vec = np.array([0.485, 0.456, 0.406])
    stddev_vec = np.array([0.229, 0.224, 0.225])

    norm_img_data = (img_data/255. - mean_vec) / stddev_vec
    norm_img_data = norm_img_data.astype('float32')

    # change channel order
    norm_img_data = norm_img_data.transpose(2, 0, 1)
    # add batch dimension
    norm_img_data = np.expand_dims(norm_img_data, axis=0)

    return norm_img_data

For the last block, we implement the post-processing of the model outputs, again using numpy instead of PyTorch; the inference call itself comes in the next step:

def softmax(x):
    x = x.reshape(-1)
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum(axis=0)

def postprocess(result):
    return softmax(np.array(result)).tolist()
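The np.max subtraction in softmax is there for numerical stability: softmax is unchanged when the same constant is subtracted from every logit, and shifting by the maximum keeps every exponent at or below zero, so np.exp cannot overflow. A small sketch with made-up logits (not real model outputs) shows the difference:

```python
import numpy as np

def softmax_naive(x):
    # direct formula; np.exp overflows to inf for large inputs
    e_x = np.exp(x)
    return e_x / e_x.sum()

def softmax_stable(x):
    # subtracting the max leaves the result unchanged mathematically,
    # but keeps every exponent <= 0 so np.exp never overflows
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()

# made-up logits, for illustration only
small = np.array([1.0, 2.0, 3.0])
large = np.array([1000.0, 1001.0, 1002.0])

# both versions agree when the logits are small
print(np.allclose(softmax_naive(small), softmax_stable(small)))

with np.errstate(over="ignore", invalid="ignore"):
    # exp(1000) is inf, and inf / inf is nan
    print(softmax_naive(large))

# same relative logits, so the same probabilities as for `small`
print(softmax_stable(large))
```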

Step 3: Inferencing using the ONNX model

We are now ready to do inferencing:

# load the dog image from our tests
image = Image.open('../mlillustrated_franks_ml_model/tests/test_files/dog.jpg')
# resize to the 224x224 input that the exported EfficientNet model expects
image = image.resize( (224, 224) )
# preprocess the image
input_data = preprocess(image)

# run inference via the ONNX Runtime session (an empty list returns all outputs)
raw_result = session.run([], {input_name: input_data})
# post-process the predictions
res = postprocess(raw_result)

# sort and print the top 5 classes
sort_idx = np.flip(np.squeeze(np.argsort(res)))
print('============ Top 5 classes are: ============================')
for idx in sort_idx[:5]:
    print( idx, res[idx] )
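The argsort-and-flip line sorts all class probabilities just to print five of them. A slightly leaner alternative is sketched below; top_k is a hypothetical helper, and the probabilities are made up. It uses np.argpartition to pick the k largest entries and then sorts only those, producing the same descending order:

```python
import numpy as np

def top_k(probs, k=5):
    """Return (class_index, probability) pairs for the k largest probabilities.

    Hypothetical helper: np.argpartition selects the k largest entries without
    fully sorting the array, then only those k entries are sorted, descending.
    """
    probs = np.asarray(probs, dtype=np.float64).reshape(-1)
    k = max(0, min(k, probs.size))
    if k == 0:
        return []
    idx = np.argpartition(probs, -k)[-k:]      # k largest, in arbitrary order
    idx = idx[np.argsort(probs[idx])[::-1]]    # sort just those k, descending
    return [(int(i), float(probs[i])) for i in idx]

# made-up probabilities over 6 classes, for illustration only
probs = [0.05, 0.40, 0.12, 0.30, 0.03, 0.10]
for class_idx, p in top_k(probs, k=3):
    print(class_idx, p)
```

For 1000 ImageNet classes the speed difference is negligible; the main benefit is that the intent (top k, descending) is stated in one place.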

When we run this, the output looks like this:

============ Top 5 classes are: ===========================
207 0.457427978515625
213 0.32097241282463074
219 0.03438207134604454
220 0.014274176210165024
211 0.01017140131443739
===========================================================

Hmm… On one hand, it’s encouraging that the top class is 207, as we expect, but why is the probability lower? It would have failed our test, which requires it to be over 0.5. Is this a bug in the conversion process, in ONNX Runtime, or something else? Or should Frank give up, go back to Jane, and say that PyTorch is a hard requirement because he doesn’t want to deal with unknown behaviors in other runtimes?

Step 4: Some digging

Being curious engineers, and knowing that ONNX has been around long enough that it shouldn’t be that unstable, let’s investigate what might be going on here. There are three areas we can validate by feeding the same input into both models and comparing the outputs block by block:

  1. Is the preprocessing block the same?
  2. Is the inferencing block the same?
  3. Is the post-processing block the same?
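Comparing blocks is easier with a small diagnostic that checks shapes before values, since element-wise differences mean nothing if the shapes disagree. The sketch below assumes both outputs have been converted to numpy arrays (for a PyTorch tensor, .numpy() does this); compare_arrays is a hypothetical helper name, and the arrays are synthetic:

```python
import numpy as np

def compare_arrays(reference, candidate, atol=1e-5):
    """Report how a candidate array differs from a reference array.

    Hypothetical helper: a shape mismatch is reported first, because
    element-wise differences are meaningless if the shapes disagree.
    """
    reference = np.asarray(reference, dtype=np.float64)
    candidate = np.asarray(candidate, dtype=np.float64)
    if reference.shape != candidate.shape:
        return {"match": False,
                "reason": f"shape mismatch: {reference.shape} vs {candidate.shape}"}
    max_abs_diff = float(np.max(np.abs(reference - candidate)))
    return {"match": max_abs_diff <= atol,
            "reason": f"max abs diff {max_abs_diff:.3g} (atol={atol})"}

# synthetic NCHW array standing in for a preprocessed image
rng = np.random.default_rng(0)
reference = rng.standard_normal((1, 3, 224, 224)).astype(np.float32)

# same shape, tiny perturbation: should match
print(compare_arrays(reference, reference + 1e-7))
# different spatial size: reported as a shape mismatch
print(compare_arrays(reference, reference[:, :, :, :200]))
```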

Let’s start with the preprocessing block. We used the same dog image as input, so we can focus on comparing the outputs of the PyTorch transforms with those of our numpy reimplementation.

We can first verify the shapes of the pre-processed image. In the PyTorch version, we print the shape of the tensor after passing in the dog image through self.tfms:

def load_and_transform_image( self, fn_image ):
    img_tensor = self.tfms(Image.open( fn_image )).unsqueeze(0)
    print(img_tensor.shape)
    ...

To our surprise, the image tensor is not 224x224: its shape is torch.Size([1, 3, 224, 336]). Digging into this further, we find that the documentation for transforms.Resize states:

“If size is an int, smaller edge of the image will be matched to this number. i.e, if height > width, then image will be rescaled to (size * height / width, size)”

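Let’s change the image resize and see if the probability goes up. We need to re-export the ONNX model with the updated dummy input and change our image.resize step accordingly. Note that the tensor shape is (batch, channels, height, width), whereas PIL’s resize takes (width, height):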

# in export_model_onnx.py
dummy_input = torch.randn(1, 3, 224, 336)

# in the onnx_infer.py script
image = image.resize( (336, 224) ) # instead of (224,224)

With this update, the top probability rises above 0.5, accounting for most of the difference between the original PyTorch and ONNX versions:

207 0.5564291477203369

Step 5: Further digging

We could have stopped here, but the transforms.Resize() documentation also mentions that its default interpolation mode is PIL.Image.BILINEAR, whereas the default resampling filter for PIL’s Image.resize() was PIL.Image.NEAREST (Pillow 7.0 and later default to PIL.Image.BICUBIC for most image modes, which still doesn’t match). We can make the filter explicit like so:

image = image.resize( (336, 224), resample=Image.BILINEAR)
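To see that the resampling filter alone changes the pixels the model receives, the following sketch resizes one synthetic noise image (not the dog photo) with both filters and compares the results. It assumes Pillow and numpy, which the inference script already uses:

```python
import numpy as np
from PIL import Image

# synthetic 60x40 RGB noise image, for illustration only
rng = np.random.default_rng(0)
pixels = rng.integers(0, 256, size=(40, 60, 3), dtype=np.uint8)  # (height, width, channels)
image = Image.fromarray(pixels)

# PIL's resize takes (width, height); pass the filter explicitly rather than
# relying on the default, which has changed across Pillow versions
nearest = np.asarray(image.resize((30, 20), resample=Image.NEAREST), dtype=np.float32)
bilinear = np.asarray(image.resize((30, 20), resample=Image.BILINEAR), dtype=np.float32)

print(nearest.shape)  # numpy layout is (height, width, channels)
print("mean abs pixel difference:", np.abs(nearest - bilinear).mean())
```

Nearest-neighbor picks a single source pixel per output pixel, while bilinear blends neighboring pixels, so the two produce noticeably different inputs even at the same output size.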

Nice: the top probability rises slightly to 0.561, quite close to the original PyTorch output:

207 0.5610848665237427

At this point the probabilities are close enough that we can attribute the remaining differences to numerical effects (e.g., floating-point rounding) between the two pipelines. You could also spend some time validating the post-processing code, but we will leave that out here.
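"Close enough" can be written down as an explicit check rather than eyeballed. The sketch below is one way to do it: check_predictions_match is a hypothetical helper, and its default atol of 0.01 is an assumed threshold to tune for your model, not a standard value. It requires the same top-k classes in the same order and every probability within atol of the reference:

```python
import numpy as np

def check_predictions_match(reference_probs, candidate_probs, k=5, atol=0.01):
    """Return True if two probability vectors agree closely enough.

    Hypothetical helper: requires the same top-k classes in the same order,
    and every probability within `atol` of the reference.
    """
    ref = np.asarray(reference_probs, dtype=np.float64).reshape(-1)
    cand = np.asarray(candidate_probs, dtype=np.float64).reshape(-1)
    if ref.shape != cand.shape:
        return False
    same_top_k = np.array_equal(np.argsort(ref)[::-1][:k], np.argsort(cand)[::-1][:k])
    return bool(same_top_k and np.allclose(ref, cand, rtol=0.0, atol=atol))

# made-up probability vectors, for illustration only
reference = np.array([0.02, 0.60, 0.05, 0.25, 0.08])
close = np.array([0.02, 0.595, 0.05, 0.255, 0.08])
swapped = np.array([0.02, 0.25, 0.05, 0.60, 0.08])

print(check_predictions_match(reference, close, k=3))
print(check_predictions_match(reference, swapped, k=3))
```

In practice, reference_probs would come from the original PyTorch model and candidate_probs from the ONNX Runtime session, both run on the same test image.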

We went through this process to show that models can be quite sensitive to how their inputs are preprocessed. Ideally, we want exactly the same implementation for both training images and inference images, but when that’s not possible, it’s worth the effort to validate that the preprocessed inputs are as close to the original as possible. That’s also why we chose PIL over other image libraries: any other library would have required further validation.

Step 6: Make the input shapes flexible

You may have noticed that we changed the input shape to make this particular image work correctly. That’s not a good solution if the model needs to accept arbitrarily shaped images, such as portrait or square ones. The better solution is to resize images the same way the PyTorch transform does and feed them to the model exactly as the original PyTorch pipeline does, so we preserve the same inference probabilities and accuracy.

Fortunately, ONNX supports dynamic input shapes, which we enable by passing the right parameters at export time. We can then re-implement the PyTorch transform’s image resizing logic, so we fully support the behavior of the original PyTorch model.

The first step is to tell torch.onnx.export which input dimensions should be dynamic, in our case axes 2 and 3 (zero-indexed), i.e., the height and width:

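# name the model input so that dynamic_axes can refer to it
input_names = ["input.1"]
# axes 2 and 3 (height, width) of the NCHW input may vary at inference time
dynamic_axes = {"input.1": {2: "h", 3: "w"}}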
...

torch.onnx.export(
    model,
    dummy_input,
    "/tmp/%s.onnx" % name,
    dynamic_axes=dynamic_axes,
    input_names=input_names
)

The second step is to re-implement (really, copy and paste) the logic that computes the resize dimensions in torchvision.transforms.functional.resize. It’s not rocket science, but in such cases it’s better to take the code as-is to avoid any discrepancies between the implementations:

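def calc_resize_size( size, img_size ):
    """Return the (width, height) that torchvision's Resize(size) would produce.

    Mirrors the int-size branch of torchvision.transforms.functional.resize:
    the shorter edge is scaled to `size`, preserving the aspect ratio.
    """
    w, h = img_size  # PIL's Image.size is (width, height)
    if (w <= h and w == size) or (h <= w and h == size):
        return (w, h)  # shorter edge already matches; no resize needed
    if w < h:
        ow = size
        oh = int(size * h / w)
    else:
        oh = size
        ow = int(size * w / h)
    return (ow, oh)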


resize_size = calc_resize_size( 224, image.size )
image = image.resize(resize_size, resample=Image.BILINEAR)
...

If you test this with the dynamic_axes ONNX model, you will get the same probability of just over 0.56 for the test image. This means that the ONNX model now supports the same behavior as the original PyTorch model, completing our validation that, indeed, ONNX can be used in place of PyTorch for inferencing.
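Putting the pieces together, here is a sketch of a complete preprocessing function for arbitrarily shaped images: the torchvision-style shorter-edge resize, bilinear resampling, and the same normalization as preprocess(). It assumes Frank’s transforms are only Resize(224) followed by conversion to a tensor and normalization with the ImageNet mean and standard deviation used above; if his pipeline also crops, that step would need to be added. The name preprocess_any_size is hypothetical, and the test images are synthetic:

```python
import numpy as np
from PIL import Image

MEAN = np.array([0.485, 0.456, 0.406], dtype=np.float32)
STD = np.array([0.229, 0.224, 0.225], dtype=np.float32)

def calc_resize_size(size, img_size):
    # torchvision-style rule: scale the shorter edge to `size`, keep aspect ratio
    w, h = img_size
    if (w <= h and w == size) or (h <= w and h == size):
        return (w, h)
    if w < h:
        return (size, int(size * h / w))
    return (int(size * w / h), size)

def preprocess_any_size(image, size=224):
    """Hypothetical helper: PIL image of any shape -> float32 NCHW array."""
    image = image.convert("RGB")  # guard against grayscale or RGBA inputs
    image = image.resize(calc_resize_size(size, image.size), resample=Image.BILINEAR)
    data = np.asarray(image, dtype=np.float32) / 255.0  # (H, W, 3) in [0, 1]
    data = (data - MEAN) / STD                           # per-channel normalization
    return data.transpose(2, 0, 1)[np.newaxis, ...]      # -> (1, 3, H, W)

# synthetic landscape, portrait, and square images (PIL sizes are width x height)
for size in [(600, 400), (400, 600), (224, 224)]:
    img = Image.new("RGB", size, color=(124, 116, 104))
    print(size, "->", preprocess_any_size(img).shape)
```

With the dynamic-axes export, each of these differently shaped arrays can be fed to the same ONNX Runtime session.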

Conclusion

We have completed the first part of the two-part process of converting the PyTorch model to ONNX. The process can be smooth, especially if the original model uses operators supported by ONNX. The gotchas tend to creep in around the preprocessing and post-processing phases, as we saw with the image resizing logic in this example. Fortunately, the PyTorch and torchvision source code is well organized and easy to read, which made finding out what to re-implement relatively trouble-free.

What about bleeding-edge models with custom operators or layers? That really depends on operator support in ONNX, and potentially in the runtime you ultimately use, if it’s not ONNX Runtime. This process is not for the faint of heart, so if ONNX support is missing, it’s best to assume the conversion will not succeed.

Frankly, if you run into such situations, unless the gains from the exotic operators are significant, my take is to look instead for comparable models that use standard operators, even if they are slightly less accurate. Get that model ready for deployment, and you can swap in the more accurate model later, perhaps through reimplementation or once ONNX adds support for the new operators.

So the take-home message is: if ONNX conversion is part of your production pipeline, test this step as early as possible so you find potential issues from the start. There are other ways to work around such limitations, such as teacher-student training (knowledge distillation), but they have their own limitations as well. So it’s best to verify early on that whatever model you choose doesn’t have conversion issues.

At this point, we are not yet ready to go back to Jane and tell her of our success. In the next post, we will create a new Python package for Frank’s ONNX model. We will also address the issue of distributing the model files themselves (e.g., the file /tmp/efficientnet-b0.onnx). Up to this point, the pretrained model files have been kindly hosted (and paid for) by Luke of EfficientNet_Pytorch, but we can’t rely on that for our ONNX model files, so we will explore ways to package and distribute model files in addition to the code we went through earlier. And of course, add tests. 👍

References