GPU accelerated numerical ops on OCaml arrays via pyml + jax?

I recently started playing with pyml + jax.

I am wondering if anyone has taken the following approach:

  1. write AST describing numerical op in OCaml DSL
  2. transpile this to Python
  3. feed it to pyml, calling jax.jit
  4. get back CPU / GPU optimized code for the functions
  5. invoke the functions on OCaml arrays


My understanding is: pyml does 3, Jax does 4, pyml w/ numpy support does 5; so the only missing parts are 1 & 2.
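
To make the missing parts 1 & 2 concrete, here is a minimal sketch using only the OCaml standard library. The `expr` type, its constructors, and the `to_python` / `to_python_def` helpers are all hypothetical names invented for illustration; nothing here is part of pyml or JAX.

```ocaml
(* Hypothetical mini-DSL for numerical expressions (step 1).
   All names are illustrative, not taken from any existing library. *)
type expr =
  | Var of string
  | Const of float
  | Add of expr * expr
  | Mul of expr * expr
  | Call of string * expr list (* e.g. Call ("jnp.tanh", [ x ]) *)

(* Render an expression as Python source (step 2). Every binary node is
   parenthesised, so operator precedence never matters.
   Assumption: constants are finite; nan/inf are not handled. *)
let rec to_python = function
  | Var v -> v
  | Const c -> Printf.sprintf "%.17g" c
  | Add (a, b) -> Printf.sprintf "(%s + %s)" (to_python a) (to_python b)
  | Mul (a, b) -> Printf.sprintf "(%s * %s)" (to_python a) (to_python b)
  | Call (f, args) ->
    Printf.sprintf "%s(%s)" f (String.concat ", " (List.map to_python args))

(* Wrap the expression in a Python function definition decorated with
   jax.jit, ready to be handed to an embedded interpreter. *)
let to_python_def ~name ~params body =
  Printf.sprintf
    "import jax\nimport jax.numpy as jnp\n\n@jax.jit\ndef %s(%s):\n    return %s\n"
    name
    (String.concat ", " params)
    (to_python body)

(* tanh (w * x + 1) as a jitted Python function of x and w. *)
let example =
  to_python_def ~name:"f" ~params:[ "x"; "w" ]
    (Call ("jnp.tanh", [ Add (Mul (Var "w", Var "x"), Const 1.0) ]))

let () = print_string example
```

The generated string would then be passed to the Python interpreter embedded through pyml (step 3); that hand-off is deliberately left out so the sketch stays self-contained.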

For those more versed in OCaml / Python-fu, can you let me know if this could work or if I am missing something glaringly obvious?

EDIT: ping @lukstafi

I think one alternative would be to write the numerical code using Owl, use owl-symbolic to export it to ONNX, and then execute the ONNX model on the appropriate hardware via pyml.

Given JAX’s integration with / reliance on Python, it’s unclear to me what the benefit would be. I’m curious if you feel the need to do the iterative development / debugging in Python, and/or whether pyml provides enough of a bridge to blur the boundary. Do you use the Python debugger to step through JAX computing the tracer objects?

Can ocaml-onnx (LaurentMazare/ocaml-onnx on GitHub, an OCaml ONNX runtime powered by onnxruntime) run it directly, without pyml?

Indeed, this should make it possible to evaluate onnx models without using python and still benefit from the speed of onnxruntime.

When it comes to using jax, I’ve been working on Rust bindings for XLA (the compiler that underpins jax) over the last week; see xla-rs. These are very experimental, but as of a couple of hours ago they should be complete enough to evaluate a GPT-2 model. I plan to polish the Rust version a bit and, if there is some interest, may well craft an OCaml version too.


For some reason I thought only the Python ONNX runtime supports training. Is it possible to run training via ONNX using your library as well?

I want the benefits of XLA / engineering work that went into generating optimized CUDA kernels on demand.

I might try to do some type of “interactive” development, but certainly not via python. One example: OCaml server running jax holding all the variable bindings, jsoo/react providing a “REPL” to type expressions over the variables and evaluating it.

OCaml bindings for XLA would be fantastic! I’ll send in bogus bug reports like I did with tch-rs. :slight_smile:

OCaml XLA bindings would indeed be fantastic.

Oh nice! I had missed this

(But XLA also underpins TensorFlow, it’s not JAX-specific.)

There is currently no support for the training part in the OCaml bindings (I actually didn’t even know that it existed). From what I can see of the training C API, this could be added, though it would be a bit of work to get there, as there are a bunch of functions to support (list on GitHub).

Good to see that there is some interest in XLA bindings. I’ll have a look once the Rust ones are a bit more stable; hopefully most of the work that wraps the C++ API in a C API can be shared between the Rust and OCaml versions.


I am not sure if we are discussing from the same context, so I’m going to take a step back.

The problem that I care about is:

  1. expressing numerical algorithms in OCaml
  2. executing them on CUDA/GPU

My current choices appear to be:

  1. use OCaml to emit CUDA kernels
  2. target some external format (say ONNX)
  3. pyml → ??? → CUDA kernels

My current understanding of the above is: #1 seems like a lot of work; #2 seems very “batchy” (i.e. not interactive). As for #3, I have not used TensorFlow in recent years, but my impression is:

  • JAX: give me a Python function, and I output a CUDA kernel
  • TensorFlow: set up a computation graph

In my current understanding, and with no expertise in the Python bindings for either, JAX seems much more direct for the problem of “ocaml expr → cuda kernel”.


Quick update on the XLA bindings: I’ve put up an initial version, ocaml-xla. It’s pretty sketchy at the moment and more of a proof of concept than an actual library; the main example is some text generation using a pre-trained GPT-2 model.
I plan on polishing this a bit over the next few days, adding some tests, documentation, and some more involved examples (hopefully getting something like llama to work). If this goes well, I’ll make a proper release and announce it in a separate thread. Any feedback is very welcome; feel free to open issues for things you’re missing. If you start playing with it, though, please expect the API to change in breaking ways.


@laurent: Thanks for your work! I have a dumb question: why are we pulling binaries from elixir-nx? Is it because (1) elixir-nx somehow patches XLA, making it easier to generate bindings for, or (2) for some reason, no one else provides binaries for a certain stage of the XLA compilation process?


One possibility could be to use futhark from ocaml. The bindgen project can create the glue between the two languages (I haven’t tested it yet).

I have never used Futhark, and cannot comment on its merits. However, for this particular problem, in my opinion, the best solution is probably to get as many people as possible to throw their efforts behind @laurent’s ocaml-xla project (LaurentMazare/ocaml-xla on GitHub: XLA (Accelerated Linear Algebra) bindings for OCaml).

  1. I think this is a very direct way to get a REPL where we can write OCaml symbolic exprs, get them compiled to fast CUDA kernels, and executed.

  2. @laurent has a track record of delivering on these projects; see the amazing demos in the examples directories of LaurentMazare/tch-rs and LaurentMazare/ocaml-torch on GitHub.

@zeroexcuses It’s only (2): it’s a convenient way to get the XLA shared library. No relevant patch gets applied. You can also compile the XLA shared library directly from the Google repo and copy the headers around, and this will work fine (this is what I used to do before learning about the nx effort on the Elixir side).
As far as I know there is no official binary distribution that contains only the XLA bits. The closest thing is probably the jaxlib Python package: if you unzip one of the wheel files there, you get the shared library, but sadly the wheels don’t package the header files, so we would have to vendor these. Hopefully the recent push on OpenXLA will result in binary packages being released at some point.


I’m very excited about those XLA bindings (thanks @laurent!). It’s something that I have long been waiting for to support my research work in comp. neuro / ML / AI, and I think many people would be interested to have XLA added as an acceleration backend in the great owl library. I will definitely look into this (and @tachukao has expressed interest too, offline). Owl already supports nested forward/backward modes of automatic differentiation at just the right level of abstraction for research/prototyping purposes, but lacks GPU support. I have to look into ocaml-xla a bit more, but my feeling is that it should be reasonably straightforward to extend owl’s Algodiff module so it can operate on ocaml-xla’s Op.t type. Anybody interested in brainstorming/strategising/helping?
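
As a rough illustration of why that might be feasible, here is a minimal sketch of differentiation code written against an abstract numeric signature rather than against floats. It is not Owl's Algodiff and does not use ocaml-xla; the `NUM` signature, the `Dual` functor, and `cube` are hypothetical names, and only forward mode is shown. The assumption is that a symbolic op type could implement the same small signature that plain floats implement here.

```ocaml
(* Minimal numeric interface. Hypothetical stand-in for either plain floats
   or a symbolic op type supplied by some backend. *)
module type NUM = sig
  type t
  val of_float : float -> t
  val add : t -> t -> t
  val mul : t -> t -> t
end

(* Forward-mode dual numbers over any NUM. The result again satisfies NUM,
   so the functor can be applied to its own output to nest derivatives. *)
module Dual (N : NUM) = struct
  type t = { v : N.t; d : N.t } (* value and tangent *)
  let of_float x = { v = N.of_float x; d = N.of_float 0. }
  let add a b = { v = N.add a.v b.v; d = N.add a.d b.d }
  (* Product rule: (ab)' = a'b + ab' *)
  let mul a b =
    { v = N.mul a.v b.v; d = N.add (N.mul a.d b.v) (N.mul a.v b.d) }
  (* Seed an input with tangent 1. *)
  let var x = { v = x; d = N.of_float 1. }
end

(* Plain floats as the base backend. *)
module F = struct
  type t = float
  let of_float x = x
  let add = ( +. )
  let mul = ( *. )
end

module D1 = Dual (F)  (* first derivatives *)
module D2 = Dual (D1) (* derivatives of derivatives *)

(* f x = x^3, written once against any NUM. *)
let cube (type a) (module N : NUM with type t = a) (x : a) : a =
  N.mul x (N.mul x x)

(* f'(3) = 27 and f''(3) = 18 *)
let first = (cube (module D1) (D1.var 3.0)).D1.d
let second = (cube (module D2) (D2.var (D1.var 3.0))).D2.d.D1.d

let () = Printf.printf "f'(3) = %g, f''(3) = %g\n" first second
```

Swapping `F` for a module whose `t` is a symbolic op would, under that assumption, build the derivative graph instead of computing numbers; reverse mode and the actual Algodiff internals are a separate matter.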

I’m working on a numerical / deep learning library that is an alternative to Owl, based on a different set of design choices.

  1. It is not functorized and is more compact, with two or three abstraction layers:
    a. it is centered around two syntaxes, for computations and for differentiable tensor operations,
    b. the aim is to be able to express computations / models very concisely,
    c. tensors are Bigarray.Genarray so easily usable from the outside, except for being tagged/boxed with numerical precision.
  2. It has a strong built-in shape inference logic (it provides “generalized matrix multiplication”; “extended einsum notation”; dynamic indexing integrated with shape inference…)
  3. I’m prioritizing low-level backends, lower-level than XLA. I just finished a C backend (via gccjit), I’ll work on a CUDA backend in the coming months.
  4. It is limited to backpropagation (first order backward mode autodiff).
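
To give a flavour of what shape inference for einsum-style operations (point 2) involves, here is a generic sketch in the NumPy-style `"ij,jk->ik"` notation. It is not this library's notation or API; `infer_shape` is a hypothetical helper, and it assumes single-character axis labels and an explicit `->` in the spec.

```ocaml
(* Infer the output shape of an einsum-style spec such as "ij,jk->ik".
   Hypothetical helper for illustration; assumes single-character axis
   labels and an explicit "->". Raises Invalid_argument on inconsistency. *)
let infer_shape (spec : string) (shapes : int list list) : int list =
  let lhs, rhs =
    match String.index_opt spec '-' with
    | Some i when i + 1 < String.length spec && spec.[i + 1] = '>' ->
      ( String.sub spec 0 i,
        String.sub spec (i + 2) (String.length spec - i - 2) )
    | _ -> invalid_arg "spec must contain \"->\""
  in
  let operands = String.split_on_char ',' lhs in
  if List.length operands <> List.length shapes then
    invalid_arg "operand count mismatch";
  (* Map each axis label to its size, checking repeated labels agree. *)
  let env = Hashtbl.create 8 in
  List.iter2
    (fun labels shape ->
      if String.length labels <> List.length shape then
        invalid_arg "rank mismatch";
      List.iteri
        (fun i dim ->
          let c = labels.[i] in
          match Hashtbl.find_opt env c with
          | Some d when d <> dim ->
            invalid_arg (Printf.sprintf "axis %c: %d vs %d" c d dim)
          | Some _ -> ()
          | None -> Hashtbl.add env c dim)
        shape)
    operands shapes;
  (* The output shape is read off the labels on the right of "->". *)
  List.init (String.length rhs) (fun i ->
    match Hashtbl.find_opt env rhs.[i] with
    | Some d -> d
    | None -> invalid_arg "output label not bound by any input")

(* Plain and batched matrix multiplication. *)
let matmul_shape = infer_shape "ij,jk->ik" [ [ 2; 3 ]; [ 3; 4 ] ]
let batched_shape = infer_shape "bij,bjk->bik" [ [ 5; 2; 3 ]; [ 5; 3; 4 ] ]
```

Here the contracted axis `j` has to agree across both operands, which is the kind of constraint a shape-inference pass propagates instead of asking the user to spell out every dimension.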

Sounds interesting! Point 4 is a showstopper for my needs though… is nested forward/backward something that will be outright impossible to incorporate in your framework by design, or is it just that you haven’t had the chance to do it yet? Also, re 1.c, owl tensors are also Bigarray.Genarray. FWIW, I find I am able to express computations and models very concisely in owl; but most of my owl code so far has been for custom, highly specific models and so I haven’t had to care much for a higher-level deep learning interface that would let me write a transformer in 10 lines of code, and in this respect I think I am not very representative of the numerical-ocaml community (though tbh I have no idea :smile:). I’ll definitely keep an eye on your work though, it sounds promising!