TensorFlow, PyTorch, and JAX: Choosing a deep learning framework

Three widely used frameworks are leading the way in deep learning research and production today. One is celebrated for ease of use, one for features and maturity, and one for immense scalability. Which one should you use?

Deep learning is changing our lives in small and large ways every day. Whether it’s Siri or Alexa following our voice commands, the real-time translation apps on our phones, or the computer vision technology enabling smart tractors, warehouse robots, and self-driving cars, every month seems to bring new advances. 

And almost all of these deep learning applications are written in one of three frameworks: TensorFlow, PyTorch, and JAX.

Which of these deep learning frameworks should you use? In this article, we’ll take a high-level comparative look at TensorFlow, PyTorch, and JAX. We’ll aim to give you some idea of the types of applications that play to their strengths, as well as consider factors like community support and ease of use.

Should you use TensorFlow?

“Nobody ever got fired for buying IBM” was the rallying cry of computing in the 1970s and 1980s, and the same could be said about using TensorFlow in the 2010s for deep learning. But as we all know, IBM fell by the wayside as we came into the 1990s. Is TensorFlow still competitive in this new decade, seven years after its initial release in 2015?

Well, certainly. It’s not like TensorFlow has stood still for all that time. TensorFlow 1.x was all about building static graphs in a very un-Pythonic manner, but in the TensorFlow 2.x line, “eager” mode, which evaluates operations immediately, is the default, making things feel a lot more like PyTorch. 
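To make the contrast concrete, here is a minimal sketch (assuming TensorFlow 2.x is installed) of eager execution next to a function wrapped in tf.function, which is how TensorFlow 2.x still lets you trace code into a graph when you want one. The matrix and the matmul_plus_one function are arbitrary examples.

```python
import tensorflow as tf

# Eager mode (the TensorFlow 2.x default): operations run immediately and
# return concrete values, much like NumPy or PyTorch.
x = tf.constant([[1.0, 2.0], [3.0, 4.0]])
eager_result = tf.matmul(x, x)
print(eager_result.numpy())

# Wrapping a Python function in tf.function traces it into a graph on the
# first call; later calls with compatible inputs reuse that graph.
@tf.function
def matmul_plus_one(a):
    return tf.matmul(a, a) + 1.0

graph_result = matmul_plus_one(x)
print(graph_result.numpy())

# Outside tf.function, code still runs eagerly.
print(tf.executing_eagerly())
```

The appeal of this design is that you can develop and debug eagerly, then add tf.function to the hot paths without restructuring the program.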

At the high level, TensorFlow gives you Keras for easier development, and at the low level, it gives you the XLA (Accelerated Linear Algebra) optimising compiler for speed. 

XLA works wonders for increasing performance on GPUs, and it’s the primary method of tapping the power of Google’s TPUs (Tensor Processing Units), which deliver unparalleled performance for training models at massive scales.
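As a rough sketch of both layers working together (assuming TensorFlow 2.5 or later, where tf.function accepts a jit_compile argument), the snippet below defines a small Keras model and then asks XLA to compile its forward pass. The model, its layer sizes, and the predict helper are arbitrary placeholders.

```python
import numpy as np
import tensorflow as tf

# High level: a small Keras model (layer sizes are arbitrary).
model = tf.keras.Sequential([
    tf.keras.Input(shape=(4,)),
    tf.keras.layers.Dense(8, activation="relu"),
    tf.keras.layers.Dense(1),
])

# Low level: ask tf.function to compile the forward pass with XLA.
@tf.function(jit_compile=True)
def predict(batch):
    return model(batch)

batch = tf.constant(np.random.default_rng(0).normal(size=(3, 4)).astype("float32"))
xla_out = predict(batch)    # runs the XLA-compiled version
eager_out = model(batch)    # runs the same model eagerly

# XLA changes how the computation is executed, not what it computes.
print(np.max(np.abs(np.asarray(xla_out) - np.asarray(eager_out))))
```

Whether XLA actually speeds things up depends on the model and hardware; on TPUs it is the main execution path rather than an optional extra.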

Then there are all the things that TensorFlow has been doing well for years. Do you need to serve models in a well-defined and repeatable manner on a mature platform? TensorFlow Serving is there for you. Do you need to retarget your model deployments for the web, or for low-power compute such as smartphones, or for resource-constrained IoT hardware? 

TensorFlow.js and TensorFlow Lite are both very mature at this point. And obviously, considering Google still runs 100 per cent of its production deployments using TensorFlow, you can be confident that TensorFlow can handle your scale.
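To illustrate that deployment path, here is a minimal sketch (TensorFlow 2.x assumed) that exports a toy model as a SavedModel, the format TensorFlow Serving loads, and then converts the same export into a TensorFlow Lite flatbuffer. The TinyLinear module and its fixed weights are hypothetical, chosen only so the output is easy to check.

```python
import tempfile

import tensorflow as tf


class TinyLinear(tf.Module):
    """Hypothetical toy model: y = x @ w, with w fixed for illustration."""

    def __init__(self):
        super().__init__()
        self.w = tf.Variable([[2.0], [0.5]])

    @tf.function(input_signature=[tf.TensorSpec(shape=[None, 2], dtype=tf.float32)])
    def __call__(self, x):
        return tf.matmul(x, self.w)


module = TinyLinear()
export_dir = tempfile.mkdtemp()

# SavedModel: the on-disk format that TensorFlow Serving loads.
tf.saved_model.save(module, export_dir,
                    signatures=module.__call__.get_concrete_function())

# Convert the same SavedModel into a TensorFlow Lite flatbuffer for mobile/embedded use.
converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)
tflite_model = converter.convert()  # returns bytes
print(f"TFLite model size: {len(tflite_model)} bytes")

# Reload the SavedModel and run it: 1*2.0 + 2*0.5 = 3.0
reloaded = tf.saved_model.load(export_dir)
print(reloaded(tf.constant([[1.0, 2.0]])).numpy())  # [[3.]]
```

In a real deployment the SavedModel directory would be versioned and handed to a TensorFlow Serving instance, while the flatbuffer bytes would be written to a .tflite file and bundled with the app.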

But… well, there has been a certain lack of energy around the project that is a little hard to ignore these days. The upgrade from TensorFlow 1.x to TensorFlow 2.x was, in a word, brutal. 

Some companies looked at the effort required to update their code to work properly on the new major version, and decided instead to port their code to PyTorch. TensorFlow also lost steam in the research community, which a few years ago started to prefer the flexibility PyTorch offered, resulting in a decline in the use of TensorFlow in research papers.

The Keras affair has not helped either. Keras became an integrated part of TensorFlow releases two years ago, but was recently pulled back out into a separate library with its own release schedule once again. Sure, splitting out Keras is not something that affects a developer’s day-to-day life, but such a high-profile reversal in a minor revision of the framework doesn’t inspire confidence.

Having said all that, TensorFlow is a dependable framework and is host to an extensive ecosystem for deep learning. You can build applications and models on TensorFlow that work at all scales, and you will be in plenty of good company if you do so. But TensorFlow might not be your first choice these days.

Should you use PyTorch?

No longer the upstart nipping at TensorFlow’s heels, PyTorch is a major force in the deep learning world today, perhaps primarily for research, but also in production applications more and more. 

And with eager mode having become the default method of developing in TensorFlow as well as PyTorch, the more Pythonic approach offered by PyTorch’s automatic differentiation (autograd) seems to have won the war against static graphs.

Unlike TensorFlow, PyTorch hasn’t experienced any major ruptures in the core code since the deprecation of the Variable API in version 0.4. Previously, Variable was required to use autograd with tensors; now everything is a tensor.
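A minimal sketch of what “everything is a tensor” means in practice: gradients are tracked directly on a tensor created with requires_grad=True, with no Variable wrapper involved.

```python
import torch

# A plain tensor with requires_grad=True records the operations applied to it.
x = torch.tensor(3.0, requires_grad=True)
y = x ** 2 + 2 * x   # y = x^2 + 2x
y.backward()         # autograd computes dy/dx = 2x + 2
print(x.grad)        # tensor(8.)
```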

But that’s not to say there haven’t been a few missteps here and there. For instance, if you’ve been using PyTorch to train across multiple GPUs, you have likely run into the differences between DataParallel and the newer DistributedDataParallel. You should pretty much always use DistributedDataParallel, but DataParallel isn’t actually deprecated.
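For reference, the usual DistributedDataParallel pattern looks roughly like this. It is a minimal single-process sketch: it uses the CPU “gloo” backend and world_size=1 so it runs on one machine without GPUs, and the toy model, data, and train_step helper are hypothetical. A real job would start one process per GPU (for example with torchrun) and typically use the “nccl” backend.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP


def train_step(rank: int, world_size: int, init_file: str) -> float:
    # One process per device; "gloo" works on CPU, "nccl" is the usual GPU choice.
    dist.init_process_group("gloo", init_method=f"file://{init_file}",
                            rank=rank, world_size=world_size)
    torch.manual_seed(0)
    model = nn.Linear(4, 1)
    ddp_model = DDP(model)  # on GPUs you would also pass device_ids=[rank]
    optimizer = torch.optim.SGD(ddp_model.parameters(), lr=0.1)

    inputs = torch.randn(8, 4)
    targets = torch.randn(8, 1)
    loss = nn.functional.mse_loss(ddp_model(inputs), targets)
    loss.backward()   # DDP all-reduces gradients across processes during backward
    optimizer.step()

    dist.destroy_process_group()
    return loss.item()


# world_size=1 keeps the sketch runnable in a single process.
init_path = os.path.join(tempfile.mkdtemp(), "ddp_init")
final_loss = train_step(rank=0, world_size=1, init_file=init_path)
print(final_loss)
```

Unlike DataParallel, which drives several GPUs from one Python process, this model runs one process per device, which avoids contention on Python's global interpreter lock.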

Although PyTorch has lagged behind TensorFlow and JAX in XLA/TPU support, the situation has improved greatly as of 2022. Through the PyTorch/XLA project, PyTorch now supports TPU VMs as well as the older TPU Node architecture, along with easy command-line deployment for running your code on CPUs, GPUs, or TPUs with no code changes. 

And if you don’t want to deal with some of the boilerplate code that PyTorch often makes you write, you can turn to higher-level additions like PyTorch Lightning, which allows you to concentrate on your actual work rather than rewriting training loops. On the minus side, while work continues on PyTorch Mobile, it’s still far less mature than TensorFlow Lite.
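For context, this is the kind of hand-written training loop that plain PyTorch expects. It is a minimal sketch on synthetic data; the model, data, and hyperparameters are arbitrary. PyTorch Lightning moves the per-batch logic into a LightningModule’s training_step method and the optimizer setup into configure_optimizers, and its Trainer runs the loop for you.

```python
import torch
from torch import nn
from torch.utils.data import DataLoader, TensorDataset

torch.manual_seed(0)

# Synthetic regression data: y = 3x + 1 plus a little noise.
xs = torch.linspace(-1, 1, 256).unsqueeze(1)
ys = 3 * xs + 1 + 0.05 * torch.randn_like(xs)
loader = DataLoader(TensorDataset(xs, ys), batch_size=32, shuffle=True)

model = nn.Linear(1, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.MSELoss()

for epoch in range(50):
    model.train()
    for batch_x, batch_y in loader:
        optimizer.zero_grad()                    # clear gradients from the last step
        loss = loss_fn(model(batch_x), batch_y)  # forward pass
        loss.backward()                          # backpropagate
        optimizer.step()                         # update the weights

weight, bias = model.weight.item(), model.bias.item()
print(f"learned weight={weight:.2f}, bias={bias:.2f}")  # should approach 3 and 1
```

Logging, checkpointing, validation, and multi-device handling would all be added by hand here, which is the boilerplate Lightning aims to absorb.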

In terms of production, PyTorch now has integrations with framework-agnostic platforms such as Kubeflow, while the TorchServe project can handle deployment details such as scaling, metrics, and batch inference, giving you all the MLOps goodness in a small package that is maintained by the PyTorch developers themselves. 

Does PyTorch scale? Meta has been running PyTorch in production for years, so anybody that tells you that PyTorch can’t handle workloads at scale is lying to you. Still, there is a case to be made that PyTorch might not be quite as friendly as JAX for the very, very large training runs that require banks upon banks of GPUs or TPUs.

Finally, there’s the elephant in the room. PyTorch’s popularity in the past few years is almost certainly tied to the success of Hugging Face’s Transformers library. Yes, Transformers now supports TensorFlow and JAX too, but it started as a PyTorch project and remains closely wedded to the framework. 

With the rise of the Transformer architecture, the flexibility of PyTorch for research, and the ability to pull in so many new models within mere days or hours of publication via Hugging Face’s model hub, it’s easy to see why PyTorch is catching on everywhere these days.

Should you use JAX?

If you’re not keen on TensorFlow, then Google might have something else for you. Sort of, anyway. JAX is a deep learning framework that is built, maintained, and used by Google, but it isn’t officially a Google product. 

However, if you look at the papers and releases from Google/DeepMind over the past year or so, you can’t help but notice that a lot of Google’s research has moved over to JAX. So JAX is not an “official” Google product, but it’s what Google researchers are using to push the boundaries.

What is JAX, exactly? An easy way to think about JAX is this: Imagine a GPU/TPU-accelerated version of NumPy that can, with a wave of a wand, magically vectorise a Python function and handle all the derivative calculations on said functions. 
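A minimal sketch of that idea, assuming a recent JAX install: jax.grad turns a function into a function that computes its derivative, and jax.vmap maps a per-example function over a batch. The loss function and the data are illustrative only.

```python
import jax
import jax.numpy as jnp


def loss(w, x, y):
    # Squared error of a one-parameter linear model for a single example.
    return (w * x - y) ** 2


# jax.grad returns a new function computing d(loss)/dw (the first argument).
dloss_dw = jax.grad(loss)

# jax.vmap vectorises that per-example gradient over a batch of (x, y) pairs
# without an explicit loop; in_axes=(None, 0, 0) keeps w shared across the batch.
batched_grads = jax.vmap(dloss_dw, in_axes=(None, 0, 0))

xs = jnp.array([1.0, 2.0, 3.0])
ys = jnp.array([2.0, 4.0, 6.0])
print(batched_grads(1.0, xs, ys))  # gradients -2, -8, -18, i.e. 2x(wx - y) at w=1
```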

Finally, it has a JIT (Just-In-Time) component that takes your code and optimises it for the XLA compiler, which can deliver significant performance improvements over TensorFlow and PyTorch. I’ve seen some code run four or five times faster simply by reimplementing it in JAX, without any real optimisation work.
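Applying the JIT is a single function call. This sketch compiles a small computation with jax.jit and times a call made after compilation; the two_layer function and the array shapes are arbitrary choices, and it deliberately makes no claim about speedups, since those depend heavily on the workload and hardware.

```python
import time

import jax
import jax.numpy as jnp


def two_layer(x, w1, w2):
    # A small two-layer computation that XLA can optimise as a whole.
    return jnp.tanh(x @ w1) @ w2


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
x = jax.random.normal(k1, (512, 256))
w1 = jax.random.normal(k2, (256, 256))
w2 = jax.random.normal(k3, (256, 10))

fast_two_layer = jax.jit(two_layer)
fast_two_layer(x, w1, w2).block_until_ready()  # first call traces and compiles

start = time.perf_counter()
fast_two_layer(x, w1, w2).block_until_ready()  # reuses the compiled XLA program
print(f"jitted call: {time.perf_counter() - start:.6f}s")

# jit changes how the function runs, not what it computes.
print(bool(jnp.allclose(two_layer(x, w1, w2), fast_two_layer(x, w1, w2), atol=1e-3)))
```

The block_until_ready calls matter because JAX dispatches work asynchronously; without them, a timer would only measure how long it took to queue the computation.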

Given that JAX works at the NumPy level, JAX code is written at a much lower level than TensorFlow/Keras, and, yes, even PyTorch. Happily, there’s a small but growing ecosystem of surrounding projects that add extra bits. You want neural network libraries? 

There’s Flax from Google, and Haiku from DeepMind (also Google). There’s Optax for all your optimiser needs, and PIX for image processing, and much more besides. Once you’re working with something like Flax, building neural networks becomes relatively easy to get to grips with. 

Just be aware that there are still a few rough edges. Veterans talk a lot about how JAX handles random numbers differently from a lot of other frameworks, for example.
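The main difference is that JAX has no hidden global random state: you pass an explicit key to every random function and split it whenever you need fresh randomness. A minimal sketch, with an arbitrary seed:

```python
import jax

key = jax.random.PRNGKey(42)  # explicit PRNG state, not a hidden global seed

a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))
# Reusing the same key gives the same numbers: JAX random functions are pure.
print(bool((a == b).all()))  # True

# For new randomness, split the key and use each subkey only once.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))
print(bool((a == c).all()))  # False
```

This is more verbose than calling a global RNG, but it keeps random number generation reproducible under jax.jit and jax.vmap.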

Should you convert everything into JAX and ride that cutting edge? Well, maybe, if you’re deep into research involving large-scale models that require enormous resources to train. The advances that JAX makes in areas like deterministic training, and other situations that require thousands of TPU pods, are probably worth the switch all by themselves.

TensorFlow vs. PyTorch vs. JAX

What’s the takeaway, then? Which deep learning framework should you use? Sadly, I don’t think there is a definitive answer. It all depends on the type of problem you’re working on, the scale you plan on deploying your models to handle, and even the compute platforms you’re targeting.

However, I don’t think it’s controversial to say that if you’re working in the text and image domains, and you’re doing small- or medium-scale research with a view to deploying these models in production, then PyTorch is probably your best bet right now. It just hits the sweet spot in that space these days.

If, however, you need to wring out every bit of performance from low-compute devices, then I’d direct you to TensorFlow with its rock-solid TensorFlow Lite package. And at the other end of the scale, if you’re working on training models that are in the tens or hundreds of billions of parameters or more, and you’re mainly training them for research purposes, then maybe it’s time for you to give JAX a whirl.
