Given JAX’s integration with / reliance on Python, it’s unclear to me what the benefit would be. I’m curious if you feel the need to do the iterative development / debugging in Python, and/or whether pyml provides enough of a bridge to blur the boundary. Do you use the Python debugger to step through JAX computing the tracer objects?
Indeed, this should make it possible to evaluate onnx models without using python and still benefit from the speed of onnxruntime.
When it comes to using jax, I’ve been working on Rust bindings for XLA (the compiler that underpins jax) over the last week, see xla-rs. These are very experimental but, as of a couple of hours ago, should be complete enough to evaluate a GPT-2 model. I plan on polishing the Rust version a bit and, if there is some interest, may well craft an OCaml version too.
I want the benefits of XLA / engineering work that went into generating optimized CUDA kernels on demand.
I might try to do some type of “interactive” development, but certainly not via python. One example: an OCaml server running jax, holding all the variable bindings, with jsoo/react providing a “REPL” to type expressions over the variables and evaluate them.
OCaml bindings for XLA would be fantastic! I’ll send in bogus bug reports like I did with tch-rs.
There is currently no support for the training part in the ocaml bindings (I actually didn’t even know that they existed). From what I can see of the training C api, this could be added, though it would be a bit of work to get there as there are a bunch of functions to support (list on github).
Good to see that there is some interest in XLA bindings, I’ll have a look once the Rust ones are a bit more stable, hopefully most of the work that wraps the C++ api in a C api can be shared between the Rust and OCaml versions.
Quick update on the XLA bindings: I’ve put up an initial version, ocaml-xla. It’s pretty sketchy at the moment and more of a proof of concept than an actual library; the main example is some text generation using a pre-trained GPT-2 model.
I plan on polishing this a bit over the next few days, adding some tests, documentation, and some more involved examples (hopefully getting something like llama to work). If this goes well I’ll make a proper release and announce it in a separate thread. Any feedback is very welcome; feel free to open issues on things you’re missing. If you start playing with it, though, please expect the api to change in breaking ways.
@laurent : Thanks for your work! I have a dumb question: why are we pulling binaries from elixir-nx? Is it (1) that elixir-nx somehow patches xla, making it easier to generate bindings for, or (2) that for some reason no one else provides binaries for a certain stage of the xla compilation process?
@zeroexcuses It’s only (2): it’s a convenient way to get the xla shared library. No relevant patch gets applied; you can also compile the xla shared library directly from the google repo, copy the headers around, and this will work fine (this is what I used to do before learning about the nx effort on the elixir side).
As far as I know there is no official binary distribution that contains only the xla bits. The closest thing is probably the jaxlib python package: if you unzip one of its wheel files, you get the shared library, but sadly these don’t package the header files, so we would have to vendor those. Hopefully the recent push on openxla will result in binary packages being released at some point.
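To make the “unzip a wheel” point concrete: a `.whl` file is just a zip archive, so the shared library can be pulled out with any zip tool. The sketch below builds a tiny fake wheel and extracts a library entry from it with the standard `zipfile` module; the file names are made up for the demo and are not the real jaxlib layout.

```python
# Demonstration that a wheel (.whl) is a plain zip archive.
# The package/library names here are placeholders, not jaxlib's actual contents.
import os
import tempfile
import zipfile

tmp = tempfile.mkdtemp()
whl = os.path.join(tmp, "fakelib-0.1-py3-none-any.whl")

# Build a fake wheel containing a stand-in shared library.
with zipfile.ZipFile(whl, "w") as z:
    z.writestr("fakelib/libfake.so", b"\x7fELF...")  # fake ELF header bytes

# Open it back up as a zip and extract the .so, as one would with jaxlib.
with zipfile.ZipFile(whl) as z:
    names = z.namelist()
    z.extract("fakelib/libfake.so", tmp)

print(names)
print(os.path.exists(os.path.join(tmp, "fakelib", "libfake.so")))
```

The same `unzip`-style extraction works on a real jaxlib wheel; the missing piece, as noted above, is the headers, which the wheel does not ship.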
I’m very excited about those XLA bindings (thanks @laurent !). It’s something I have long been waiting for to support my research work in comp. neuro / ML / AI, and I think many people would be interested in having XLA added as an acceleration backend in the great owl library. I will definitely look into this (and @tachukao has expressed interest too, offline). Owl already supports nested forward/backward modes of automatic differentiation at just the right level of abstraction for research/prototyping purposes, but lacks GPU support. I have to look into ocaml-xla a bit more, but my feeling is that it should be reasonably straightforward to extend owl’s Algodiff module so it can operate on ocaml-xla’s Op.t type. Anybody interested in brainstorming/strategising/helping?
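For readers less familiar with what “nested forward/backward AD” buys you: forward mode can be implemented by threading dual numbers through the computation, and nesting means the derivative values can themselves carry derivatives. The sketch below is a concept illustration only, not Owl’s actual Algodiff API nor ocaml-xla’s Op.t; it just shows the mechanism the backend ops would have to support.

```python
# Minimal forward-mode AD via dual numbers (concept sketch, not Owl's API).
# A Dual carries a value and its derivative; overloaded + and * propagate
# derivatives by the sum and product rules.
class Dual:
    def __init__(self, val, dot=0.0):
        self.val, self.dot = val, dot

    def __add__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val + other.val, self.dot + other.dot)
    __radd__ = __add__

    def __mul__(self, other):
        other = other if isinstance(other, Dual) else Dual(other)
        return Dual(self.val * other.val,
                    self.val * other.dot + self.dot * other.val)
    __rmul__ = __mul__

def deriv(f, x):
    """Derivative of f at x: seed the dual part with 1 and read it back."""
    return f(Dual(x, 1.0)).dot

# d/dx (x*x + 3x) at x = 2 is 2*2 + 3 = 7
print(deriv(lambda x: x * x + 3 * x, 2.0))  # 7.0
```

Because `Dual.val` can itself hold a `Dual`, the same overloads support nesting, which is the property being asked about here; making this work over a GPU-backed op type is exactly the integration question.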
I’m working on a numerical / deep learning library that is an alternative to Owl, based on a different set of design choices.
1. It is not functorized, more compact, with two to three abstraction layers:
   a. it is centered around two syntaxes, one for computations and one for differentiable tensor operations,
   b. the aim is to be able to express computations / models very concisely,
   c. tensors are Bigarray.Genarray, so they are easily usable from the outside, except for being tagged/boxed with their numerical precision.
2. It has strong built-in shape inference logic (it provides “generalized matrix multiplication”, an “extended einsum notation”, dynamic indexing integrated with shape inference…).
3. I’m prioritizing low-level backends, lower-level than XLA. I just finished a C backend (via gccjit), and I’ll work on a CUDA backend in the coming months.
4. It is limited to backpropagation (first-order reverse-mode autodiff).
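For readers unfamiliar with what “generalized matrix multiplication” and einsum-style notation mean in practice, here is a concept illustration using numpy’s `einsum` (the library described above has its own notation; this only shows the idea that an index string fully determines the contraction and the output shape):

```python
# Concept illustration with numpy.einsum, not this library's own syntax.
import numpy as np

a = np.arange(6).reshape(2, 3)        # shape (2, 3)
b = np.arange(12).reshape(3, 4)       # shape (3, 4)

# Ordinary matmul: contract the shared index j.
plain = np.einsum("ij,jk->ik", a, b)
assert np.array_equal(plain, a @ b)

# "Generalized" matmul: a batch axis b is carried through unchanged.
batched = np.einsum("bij,bjk->bik",
                    np.ones((5, 2, 3)), np.ones((5, 3, 4)))

print(plain.shape, batched.shape)     # (2, 4) (5, 2, 4)
```

Shape inference in such a system amounts to solving for the output axes from the index string and the operand shapes, which is what lets dynamic indexing integrate with it.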
Sounds interesting! Point 4 is a showstopper for my needs though… is nested forward/backward something that will be outright impossible to incorporate in your framework by design, or is it just that you haven’t had the chance to do it yet? Also, re 1.c, owl tensors are also Bigarray.Genarray. FWIW, I find I am able to express computations and models very concisely in owl; but most of my owl code so far has been for custom, highly specific models, so I haven’t had to care much for a higher-level deep learning interface that would let me write a transformer in 10 lines of code, and in this respect I think I am not very representative of the numerical-ocaml community (though tbh I have no idea). I’ll definitely keep an eye on your work though, it sounds promising!