Polly is Elucidata’s data-centric MLOps platform that hosts millions of biomolecular datasets, and PollyBERT is our suite of transformer-based language models. The "datasets" we ingest usually come in the form of experimental measurements, accompanying textual metadata, and/or scientific papers describing the experiments' results. With this many datasets, our users need to be able to search and browse the catalog easily, and we make that possible by annotating each dataset with metadata. The annotation process, which would be extremely time-consuming if done manually, is aided by PollyBERT.

PollyBERT models are trained to perform information extraction tasks like sequence tagging, sequence classification, entity normalization, relation extraction, and slot-filling. They scan the text and papers and annotate them with human-level accuracy.

If you've been following NLP news lately, you'll know that large language models are expensive to train and to run inference on. The rate at which we ingest new data into Polly only makes this worse. To give some perspective, one forward pass through BERT takes 8-50 ms on an AWS g4dn.xlarge instance (which comes with an NVIDIA T4 GPU). At that rate, running one BERT model on 1 million inputs takes about 9 hours. Since we have 12 different BERT models performing different tasks, running them sequentially on one machine would take about 5 days. We often need to run these models on several million inputs every week, which can end up costing thousands of dollars every month.

Last quarter, we set aside time to optimize our models so they could handle more requests. We summarized our problem statement as:

*Increase model inference throughput without lowering model accuracy or increasing costs.*

In this article, I describe some of the optimizations we tried and how effective they were at helping us reach this goal. In the end, we picked the methods that gave the biggest ROI.
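The back-of-the-envelope estimate above is easy to reproduce. The sketch below assumes a flat 30 ms per forward pass (the GPU average we measured on our corpus, quoted further down), batch size 1, and strictly sequential execution; all three are simplifying assumptions.

```python
# Back-of-the-envelope estimate of sequential inference time.
# Assumptions: a fixed 30 ms per forward pass, batch size 1, no parallelism.
SECONDS_PER_PASS = 0.030
NUM_INPUTS = 1_000_000
NUM_MODELS = 12


def sequential_hours(seconds_per_pass, num_inputs, num_models=1):
    """Wall-clock hours to run `num_models` models over `num_inputs`, one input at a time."""
    return seconds_per_pass * num_inputs * num_models / 3600


one_model_hours = sequential_hours(SECONDS_PER_PASS, NUM_INPUTS)
all_models_days = sequential_hours(SECONDS_PER_PASS, NUM_INPUTS, NUM_MODELS) / 24

print(f"1 model:   {one_model_hours:.1f} h")
print(f"{NUM_MODELS} models: {all_models_days:.1f} days")
```

At exactly 30 ms this gives about 8.3 hours per model and about 4.2 days for all twelve. The rounder figures above (~9 hours, ~5 days) correspond to a somewhat slower average pass, which is plausible given the 8-50 ms range.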

Note that some of these optimizations were specific to our use case, so they might not be universally applicable. Also, we used the BERT architecture for most of our experiments, so you can assume that these results are for BERT unless stated otherwise.

Before we delve into the fancy model compression techniques, let’s talk about what it means to make models faster. When it comes to deep learning models, there are generally two quantities that people optimize for before deployment.

These are:

1. Size:

This is the space taken up by the model's weights, on disk and in RAM.

Vanilla BERT has a size of 400 MB.

2. Inference Time:

This is the time it takes for one prediction.

For our corpus, vanilla BERT took 30ms (avg) on GPU and 430ms (avg) on CPU.

Model size tends to matter more when models are deployed on edge devices and mobile phones. There, size directly affects user experience, since it determines download time and the space the model takes up on the device. If your model runs on servers, its size can affect how quickly you can spin up new instances when load increases. For us, neither was a concern: our models run server-side, and our load doesn't change fast enough to require quick spin-up times. So we did not take model size into account when choosing which optimizations to go with.

We were mainly concerned with how many inputs our servers could handle per second (throughput), which dictates how many datasets we can annotate. **Since inference time directly impacts throughput, that’s the quantity we're implicitly optimizing for.**

Several factors affect BERT's inference time, including batch size, sequence length, choice of deep-learning framework, and hardware. We used PyTorch and a batch size of 1 for all our experiments, with a g4dn.xlarge instance as our GPU machine and a laptop with an Intel Core i5-7300U as our CPU machine. We also fixed the corpus used to test the different model variants, both to control for sequence length and to make sure it resembled production workloads. The per-input averages for vanilla BERT quoted above (30ms on GPU, 430ms on CPU) were measured under this setup.
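The measurement protocol above (batch size 1, a fixed corpus, an average per input) can be sketched as a small timing harness. This is a minimal sketch, not our actual benchmarking code: `predict` and `dummy_predict` are hypothetical stand-ins for a model's forward pass, and the corpus is a placeholder. It also shows why latency maps directly onto throughput when batch size is 1.

```python
import statistics
import time


def mean_latency_ms(predict, corpus, warmup=3):
    """Average per-input latency in ms, batch size 1, over a fixed corpus.

    `predict` is any single-input callable standing in for a model's forward pass.
    """
    for text in corpus[:warmup]:  # warm-up calls are not timed (lazy init, caches)
        predict(text)
    timings_ms = []
    for text in corpus:
        start = time.perf_counter()
        predict(text)
        timings_ms.append((time.perf_counter() - start) * 1000)
    return statistics.mean(timings_ms)


def throughput_per_second(latency_ms):
    """With batch size 1 and one worker, throughput is the inverse of latency."""
    return 1000.0 / latency_ms


# Hypothetical stand-in for a model: sleeps about 2 ms per input.
def dummy_predict(text):
    time.sleep(0.002)


corpus = ["placeholder sentence about a biomedical dataset"] * 20
latency = mean_latency_ms(dummy_predict, corpus)
print(f"dummy model: {latency:.1f} ms/input")
print(f"GPU at 30 ms: {throughput_per_second(30):.1f} inputs/s")
print(f"CPU at 430 ms: {throughput_per_second(430):.1f} inputs/s")
```

At 30 ms per input a single GPU worker handles about 33 inputs per second, versus about 2 per second on the CPU. If `predict` runs on a GPU through PyTorch, it should call `torch.cuda.synchronize()` before returning; CUDA kernels launch asynchronously, so otherwise the timer can stop before the work is done.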

Model parameters and activations are typically stored as float32. BERT has 110 million float32 parameters, which take up 410MB of memory. The idea behind quantization is to use float16 or int8 to store these parameters, thereby reducing the computation time and the model size.
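The storage arithmetic behind quantization is simple: bytes per parameter times parameter count. The sketch below uses the round figure of 110 million parameters and decimal megabytes (10^6 bytes); both are simplifications, which is why float32 comes out at 440 MB rather than the roughly 400-410 MB quoted above (the exact figure depends on BERT's precise parameter count and on whether "MB" means 10^6 or 2^20 bytes).

```python
# Approximate weight storage for a 110M-parameter model at different precisions.
NUM_PARAMS = 110_000_000
BYTES_PER_PARAM = {"float32": 4, "float16": 2, "int8": 1}


def weight_megabytes(num_params, dtype):
    """Raw weight storage in MB (10**6 bytes), ignoring file-format overhead."""
    return num_params * BYTES_PER_PARAM[dtype] / 1e6


for dtype in BYTES_PER_PARAM:
    print(f"{dtype:>8}: {weight_megabytes(NUM_PARAMS, dtype):.0f} MB")
```

Storing every weight in int8 would ideally cut size by 4x (75%). In practice, PyTorch's dynamic quantization converts only the layer types you ask it to (typically `nn.Linear`) and leaves embeddings in float32, which is one plausible reason a measured reduction such as the 60% we report below falls short of that ideal.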

There are different ways to do this quantization. One of them is called *dynamic quantization*. Here the mapping from float32 to int8 happens *after training*: the model's weights are scaled and rounded to int8 before inference. This part is straightforward but not enough on its own. If we leave it there, quantization errors can propagate through the layers and grow larger. To keep them small, the *activations* are also quantized using an affine transformation (y = Ax + b), chosen so that the model's activations deviate as little as possible from their original (pre-quantization) values. The *A* and *b* for the activations are not learned in advance; they are computed dynamically during inference from the range of values each activation tensor actually takes, which is where the name comes from (see here for details). Also, not all operations are performed using int8. Some use int16 or int32 to avoid issues like integer overflow.
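To make the mechanics concrete, here is a minimal pure-Python sketch of one common dynamic-quantization scheme for a single linear layer: weights get a symmetric int8 scale once, ahead of time, while each input's activation scale and zero point (the *A* and *b* above) are computed on the fly from that input's range. The scheme, the function names, and the toy numbers are illustrative assumptions, not PyTorch internals; in PyTorch itself, `torch.quantization.quantize_dynamic` does this for you on the layer types you pass it.

```python
def quantize_weights(weights):
    """Symmetric per-tensor int8 quantization, done once before inference."""
    max_abs = max(abs(w) for row in weights for w in row) or 1.0
    scale = max_abs / 127.0
    q = [[max(-127, min(127, round(w / scale))) for w in row] for row in weights]
    return q, scale


def quantize_activations(x):
    """Asymmetric uint8 quantization whose scale and zero point come from this input's range."""
    lo, hi = min(min(x), 0.0), max(max(x), 0.0)  # include 0.0 so it maps exactly to zero_point
    scale = (hi - lo) / 255.0 or 1.0
    zero_point = round(-lo / scale)
    q = [max(0, min(255, round(v / scale) + zero_point)) for v in x]
    return q, scale, zero_point


def dynamic_quantized_linear(q_weights, w_scale, bias, x):
    """y = Wx + b computed with integer products, then mapped back to float."""
    q_x, x_scale, x_zero_point = quantize_activations(x)  # recomputed on every call
    out = []
    for q_row, b in zip(q_weights, bias):
        # Python ints never overflow; real kernels accumulate in int32 for that reason.
        acc = sum(qw * (qx - x_zero_point) for qw, qx in zip(q_row, q_x))
        out.append(acc * w_scale * x_scale + b)
    return out


# Toy layer and input (made-up numbers): compare float and quantized outputs.
W = [[0.5, -1.2, 0.3], [2.0, 0.1, -0.7]]
bias = [0.1, -0.2]
x = [0.9, -0.4, 1.5]

q_W, w_scale = quantize_weights(W)
approx = dynamic_quantized_linear(q_W, w_scale, bias, x)
exact = [sum(w * v for w, v in zip(row, x)) + b for row, b in zip(W, bias)]
print("float:", exact)
print("int8: ", approx)
```

The quantized outputs land within a few hundredths of the float ones on this toy layer; that small, bounded drift is the "decrease in floating-point precision" the diagram below illustrates.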

Here’s a diagram to show what the weights look like before and after quantization. You can clearly see the decrease in floating-point precision.

Using dynamic quantization, we got a 60% reduction in model size and a 1.6x speedup on CPU. Unfortunately, if you're using PyTorch (as we do), these quantization benefits are not available on GPU.

There are two other well-established ways of quantizing a network, static quantization and quantization-aware training, each with its own advantages. We stuck with dynamic quantization because it is super easy to implement: in PyTorch it takes only a few lines, and its inference-time reduction isn't very different from that of the other two.

As the name suggests, pruning involves removing (or masking) model parameters that contribute little to accuracy but still cost computation. Pruning can be done by zeroing individual parameters that are already close to zero (*unstructured pruning*) or more systematically, by removing entire layers, reshaping weight matrices and, in some cases, re-training the network afterward (*structured pruning*).

Pruning is an interesting technique in itself. It can be used both as a way of making your model run faster and as a means to explore what your model is learning.

Structured pruning can be quite effective at speeding up inference, but it is also a fairly elaborate process involving a lot of trial and error. Since we were short on time, we decided to table this for later.

We did try unstructured pruning. We could zero as many as 30% of our model's parameters without any drop in F1. Unfortunately, this gave us no inference speedup. The reason is that, while unstructured pruning increases the sparsity of a network, most deep-learning frameworks (including PyTorch) are not currently designed to exploit this sparsity to speed up computation (see here). They perform the same number of operations regardless of how sparse the network is. Unstructured pruning can help reduce model size, but only if you store the network weights in a sparse format.
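The point about sparsity not translating into speed can be shown in a few lines. This pure-Python sketch (toy 10x10 layer with made-up values; all function names are illustrative) zeros the 30% smallest-magnitude weights, then shows that a dense matrix-vector product still does one multiply-add per weight, and that only a sparse format such as compressed sparse row (CSR) actually drops the zeros.

```python
def magnitude_prune(weights, fraction):
    """Unstructured pruning: zero the `fraction` of weights with the smallest |value|."""
    ranked = sorted((abs(w), i, j) for i, row in enumerate(weights) for j, w in enumerate(row))
    to_zero = {(i, j) for _, i, j in ranked[: int(len(ranked) * fraction)]}
    return [[0.0 if (i, j) in to_zero else w for j, w in enumerate(row)]
            for i, row in enumerate(weights)]


def dense_matvec(weights, x):
    """A dense kernel performs one multiply-add per weight, whether or not it is zero."""
    out, ops = [], 0
    for row in weights:
        acc = 0.0
        for w, v in zip(row, x):
            acc += w * v
            ops += 1
        out.append(acc)
    return out, ops


def to_csr(weights):
    """Compressed sparse row (CSR) storage keeps only the non-zero entries."""
    values, col_idx, row_ptr = [], [], [0]
    for row in weights:
        for j, w in enumerate(row):
            if w != 0.0:
                values.append(w)
                col_idx.append(j)
        row_ptr.append(len(values))
    return values, col_idx, row_ptr


# Toy 10x10 layer with no zeros (made-up values), pruned to 30% sparsity.
W = [[((i * 10 + j) % 17 + 1) / 17 * (1 if (i + j) % 2 else -1) for j in range(10)]
     for i in range(10)]
x = [1.0] * 10

W_pruned = magnitude_prune(W, 0.30)
_, ops_before = dense_matvec(W, x)
_, ops_after = dense_matvec(W_pruned, x)
values, col_idx, row_ptr = to_csr(W_pruned)
print(ops_before, ops_after, len(values))
```

Both dense runs do 100 multiply-adds. Note also that with 4-byte values and 4-byte column indices, this CSR copy needs 70 × 8 + 11 × 4 = 604 bytes versus 400 bytes for the dense float32 matrix, so at 30% sparsity a sparse format would actually be larger; with these index sizes it only pays off once well over half the weights are zero.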

Knowledge distillation (KD) is a way of training a smaller *student* model from a larger *teacher* model. The key idea is that we teach the student to mimic the teacher's output distribution (soft labels) instead of training it directly on the task's hard labels. Empirical results have shown that a student trained this way performs better than the same small model trained directly on the task. There are a few different ways BERT can be made "smaller", but decreasing the number of layers is the most effective.
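A minimal sketch of the classic Hinton-style distillation loss shows what "mimicking soft labels" means in practice: soften both output distributions with a temperature, penalize their KL divergence, and mix in the ordinary hard-label loss. This is a common textbook formulation, not necessarily the exact loss we used; DistilBERT's own pre-training combines a loss like this with other terms. The `temperature` and `alpha` values and the logits are illustrative assumptions.

```python
import math


def softmax(logits, temperature=1.0):
    """Softmax with temperature; higher temperatures give softer distributions."""
    scaled = [z / temperature for z in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def kl_divergence(p, q):
    """KL(p || q) for discrete distributions."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def distillation_loss(student_logits, teacher_logits, label, temperature=2.0, alpha=0.5):
    """Match the teacher's soft labels, plus the usual hard-label cross-entropy.

    `temperature` and `alpha` are illustrative defaults, not tuned values.
    """
    soft_teacher = softmax(teacher_logits, temperature)
    soft_student = softmax(student_logits, temperature)
    # The T^2 factor keeps soft-target gradient magnitudes comparable across temperatures.
    soft_loss = kl_divergence(soft_teacher, soft_student) * temperature ** 2
    hard_loss = -math.log(softmax(student_logits)[label])
    return alpha * soft_loss + (1 - alpha) * hard_loss


# Made-up logits for one 3-class token-classification step; class 0 is correct.
teacher = [4.0, 1.5, 0.5]
student = [2.0, 1.8, 0.3]
print(softmax(teacher), softmax(teacher, temperature=2.0))
print(distillation_loss(student, teacher, label=0))
```

The softened teacher distribution still ranks class 0 first but puts noticeably more mass on the other classes; that extra information about which wrong answers are "less wrong" is what the student learns from.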

For our experiments, we used the DistilBERT architecture, trained with a fine-tuned BERT as the teacher. DistilBERT gave us a 1.5x speedup on GPU (and 2.4x on CPU) but had a 5-point lower F1 score. We suspect the lower F1 is because DistilBERT was pre-trained on general-domain text, whereas our best-performing models are initialized from BioBERT and PubMedBERT, both of which are pre-trained on biomedical text. There doesn't seem to be any distil-*bio*-bert available in the public domain. Perhaps in the near future we'll invest some time in training one ourselves.

*Having read about pruning and KD, you might be wondering: if we can reach the same accuracy using a smaller network, why train the bigger network in the first place? Why not just train the smaller network directly? The short answer is that bigger, over-parameterized models converge more easily to an optimal solution.*

ONNX Runtime applies hardware-specific optimizations that make inference faster without changing the network's outputs, so there is no drop in model accuracy. It optimizes certain operations in the model's computation graph to use hardware-specific accelerators that might be available on a given machine. All of this happens in a single "compilation" step before inference. From a user's point of view it works out of the box, and it gave us a 2x speedup on GPU (and 1.2x on CPU). We're considering automating this step for future deployments by making it part of our deployment pipeline.
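ONNX Runtime's optimizer works on real ONNX graphs with many rewrite rules, and none of the names below are its API. As a toy illustration of one kind of rewrite such optimizers apply, a semantics-preserving node fusion, this sketch fuses each matrix-vector product with the bias addition that follows it, then checks that the output is bit-for-bit unchanged while the number of kernel calls drops. On a GPU, fewer kernel launches and fewer intermediate tensors written to memory are a typical source of speedups like this.

```python
def matvec(W, x):
    return [sum(w * v for w, v in zip(row, x)) for row in W]


def add(a, b):
    return [u + v for u, v in zip(a, b)]


def relu(a):
    return [max(0.0, u) for u in a]


def matvec_add(W, x, b):
    """Fused kernel: the same arithmetic as matvec followed by add, in one call."""
    return [sum(w * v for w, v in zip(row, x)) + bi for row, bi in zip(W, b)]


OPS = {"MatVec": matvec, "Add": add, "Relu": relu, "MatVecAdd": matvec_add}


def run(graph, inputs):
    """Execute (op, input_names, output_name) nodes in order; count kernel calls."""
    env, calls = dict(inputs), 0
    for op, args, out in graph:
        env[out] = OPS[op](*(env[a] for a in args))
        calls += 1
    return env, calls


def fuse_matvec_add(graph):
    """Replace MatVec -> Add pairs with one MatVecAdd node if the MatVec result has no other use."""
    fused, i = [], 0
    while i < len(graph):
        op, args, out = graph[i]
        if op == "MatVec" and i + 1 < len(graph):
            next_op, next_args, next_out = graph[i + 1]
            uses = sum(a == out for _, node_args, _ in graph for a in node_args)
            if next_op == "Add" and next_args[0] == out and uses == 1:
                fused.append(("MatVecAdd", [args[0], args[1], next_args[1]], next_out))
                i += 2
                continue
        fused.append(graph[i])
        i += 1
    return fused


# A tiny two-layer network as a computation graph (hypothetical weights).
graph = [
    ("MatVec", ["W1", "x"], "h"),
    ("Add", ["h", "b1"], "h_b"),
    ("Relu", ["h_b"], "a"),
    ("MatVec", ["W2", "a"], "o"),
    ("Add", ["o", "b2"], "y"),
]
inputs = {
    "x": [1.0, -2.0], "W1": [[0.5, 0.1], [-0.3, 0.8]], "b1": [0.2, 0.4],
    "W2": [[1.0, -1.0]], "b2": [0.05],
}
env_before, calls_before = run(graph, inputs)
env_after, calls_after = run(fuse_matvec_add(graph), inputs)
print(env_before["y"], env_after["y"], calls_before, calls_after)
```

The fused graph runs 3 kernels instead of 5 and produces the same `y`, because the fused kernel performs exactly the same floating-point operations in the same order.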

Multi-task learning (MTL) turned out to be the most effective at increasing throughput, giving us a massive speedup without compromising accuracy. MTL generally refers to training one network to do multiple complementary tasks. This is done with a shared base model whose outputs feed into distinct task-specific layers. MTL is commonly used as a way of improving model accuracy: by training a shared model on different but complementary tasks, you hope the model learns representations from one task that are useful in the others.

MTL's use as an inference-speedup technique is less common, but it was perfect for our use case. We have 12 different PollyBERT models for different information extraction tasks, and they are almost always run on the same batch of inputs. Eight of these are NER models, each responsible for detecting a different entity type (disease, cell line, tissue, etc.). All of them have the same BERT architecture but distinct sets of weights, so running all of them on one input means doing forward passes through 8 BERT models (super slow). We decided to train a single MTL model for these NER tasks so that one forward pass through the base model serves all eight. Our model consisted of a base model initialized from PubMedBERT and 8 token-classification (NER) heads, each consisting of a fully connected 2-layer network followed by a softmax. We trained the multi-task model without any drop in accuracy, and while it didn't really decrease the inference time per task, it increased the effective throughput by 5x.
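The source of the speedup is easiest to see by counting encoder passes. Below is a toy, pure-Python sketch of the shared-body-plus-heads layout described above: a stand-in encoder that counts its forward passes, and eight 2-layer heads with a softmax per token. Everything here is hypothetical (`CountingEncoder`, `make_ner_head`, the task names `ner_0` to `ner_7`, the tokens, the random weights); only the pass counts are meaningful. Eight separate models cost eight encoder passes per input, while the multi-task model costs one. Our measured gain was 5x rather than this idealized 8x, presumably because the heads and the rest of the pipeline still take time.

```python
import math
import random


def softmax(logits):
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


class CountingEncoder:
    """Stand-in for a BERT body: maps tokens to vectors and counts forward passes."""

    def __init__(self):
        self.calls = 0

    def __call__(self, tokens):
        self.calls += 1
        # Toy 3-d "embedding" per token; a real encoder is a deep transformer.
        return [[len(t) / 10, ord(t[0]) / 128, 1.0] for t in tokens]


def make_ner_head(in_dim, hidden, n_labels, seed):
    """2-layer fully connected head + softmax per token (random toy weights, no biases)."""
    rng = random.Random(seed)
    w1 = [[rng.uniform(-1, 1) for _ in range(in_dim)] for _ in range(hidden)]
    w2 = [[rng.uniform(-1, 1) for _ in range(hidden)] for _ in range(n_labels)]

    def head(token_vectors):
        labels = []
        for v in token_vectors:
            h = [max(0.0, sum(w * x for w, x in zip(row, v))) for row in w1]
            probs = softmax([sum(w * x for w, x in zip(row, h)) for row in w2])
            labels.append(max(range(n_labels), key=probs.__getitem__))
        return labels

    return head


TASKS = [f"ner_{i}" for i in range(8)]  # hypothetical names for the 8 entity types
heads = {task: make_ner_head(3, 4, 3, seed=i) for i, task in enumerate(TASKS)}
tokens = ["BRCA1", "expression", "in", "HeLa", "cells"]  # made-up input

# Before: eight separate models, i.e. eight encoder passes per input.
separate_encoders = {task: CountingEncoder() for task in TASKS}
before = {task: heads[task](separate_encoders[task](tokens)) for task in TASKS}

# After: one shared encoder pass whose output feeds all eight heads.
shared_encoder = CountingEncoder()
shared = shared_encoder(tokens)
after = {task: head(shared) for task, head in heads.items()}

print(sum(e.calls for e in separate_encoders.values()), shared_encoder.calls)
```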

Since our existing deployment was on GPU machines, we went with ONNX Runtime and MTL, as they both work well on GPU. Moreover, since they are independent, their benefits compounded to give us a **10x increase in overall throughput**.
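"Compounded" here means the gains multiply rather than add: ONNX Runtime makes each forward pass faster, while MTL reduces how many passes each input needs, so (assuming, as above, that the two effects are independent) the throughput multipliers combine as

$$S_{\text{total}} = S_{\text{ONNX}} \times S_{\text{MTL}} = 2 \times 5 = 10$$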

We realized during our inference optimization sprint that the BioNLP community lacked a distilled BERT model pre-trained on biomedical text. That’s something we’re considering training and releasing publicly.

Also, there are plenty of developments happening in this space. Here are a few we're particularly excited about. We are hoping that they’ll make our model deployment work easier.

1. Hugging Face Optimum: Hugging Face recently released this extension to their transformers library for accelerating inference. The library doesn't implement any novel optimization techniques. Instead, it wraps existing optimizations (including quantization and ONNX Runtime) and exposes them through a simple and familiar API.

2. SageMaker Serverless Inference: As the name suggests, this is AWS Lambda + SageMaker. It doesn't support GPUs yet, but we're excited to try it out for workloads that don't need GPU-level performance. We’re hoping it will take the headache out of scaling model inference up and down.
