Can a state monad be optimized out with flambda?

I am trying to get flambda to optimize code that uses what is essentially a simple state monad, with let* and let+ operators, into the code one would write by breaking the abstraction and using state-passing style directly. I have not been able to find attributes or compilation flags that let the abstraction-preserving version compile to code close to as good as the abstraction-breaking version, and I wonder if anyone has ideas.

Here is a concrete example:

module State = struct
  type s
  type 'a m = s -> 'a * s

  let return a s = (a, s)

  let bind m k s =
    let a, s = m s in
    k a s

  let map f m = bind m (fun a -> return (f a))

  module Import = struct
    let ( let* ) = bind
    let ( let+ ) m f = map f m
  end

  let read s = (s, s)
  let run s m = m s

  let mfold ps a ~f s =
    let f (a, s) p = f p a s in
    List.fold_left f (a, s) ps
end

open State.Import

type t =
  | True
  | Not of t
  | And of t list
  | Or of t list
  | If of {cnd: t; pos: t; neg: t}

let iter_dnf1 ~meet ~top fml ~f =
  let rec add_conjunct fml (cjn, splits) =
    match fml with
    | True | Not _ ->
        let+ cjn = meet fml cjn in
        (cjn, splits)
    | And ps -> State.mfold ~f:add_conjunct ps (cjn, splits)
    | Or ps -> State.return (cjn, ps :: splits)
    | If {cnd; pos; neg} ->
        add_conjunct (Or [And [cnd; pos]; And [Not cnd; neg]]) (cjn, splits)
  in
  let rec add_disjunct (cjn, splits) fml =
    let* cjn, splits = add_conjunct fml (cjn, splits) in
    let+ s = State.read in
    match splits with
    | ps :: splits ->
        List.iter
          (fun fml -> fst (State.run s (add_disjunct (cjn, splits) fml)))
          ps
    | [] -> f (cjn, s)
  in
  add_disjunct (top, []) fml

let dnf1 ~meet ~top fml =
  let+ s = State.read in
  fun f -> fst (State.run s (iter_dnf1 ~meet ~top fml ~f))

let iter_dnf2 ~meet ~top fml ~f s =
  let rec add_conjunct fml (cjn, splits) s =
    match fml with
    | True | Not _ ->
        let cjn, s = meet fml cjn s in
        ((cjn, splits), s)
    | And ps -> State.mfold ~f:add_conjunct ps (cjn, splits) s
    | Or ps -> ((cjn, ps :: splits), s)
    | If {cnd; pos; neg} ->
        add_conjunct
          (Or [And [cnd; pos]; And [Not cnd; neg]])
          (cjn, splits) s
  in
  let rec add_disjunct (cjn, splits) fml s =
    let (cjn, splits), s = add_conjunct fml (cjn, splits) s in
    match splits with
    | ps :: splits ->
        (List.iter (fun fml -> fst (add_disjunct (cjn, splits) fml s)) ps, s)
    | [] -> (f (cjn, s), s)
  in
  add_disjunct (top, []) fml s

let dnf2 ~meet ~top fml s =
  ((fun f -> fst (iter_dnf2 ~meet ~top fml ~f s)), s)

The iter_dnf1 and dnf1 functions preserve the abstraction while the iter_dnf2 and dnf2 functions break it, but are otherwise equivalent. The code generated for the latter version is rather better, with the former in particular allocating many closures. See this example at Compiler Explorer. That is for 4.12, and the relative situation is similar with 4.14.

It seems difficult for flambda to eliminate the closures for the anonymous functions that result from the elaboration of the let* and let+ operators, but maybe something else is going on.
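
For reference, the elaboration in question: a binding such as

let* x = m in body

is rewritten by the compiler front end into

( let* ) m (fun x -> body)

i.e. State.bind m (fun x -> body), so every evaluation allocates a closure for the continuation (and let+ likewise goes through map).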

Another transformation that seems difficult is needed in the add_conjunct function. The bodies of some cases need to be eta-expanded with the state argument, and then fun needs to be hoisted over match, such as transforming:

match e with A -> fun s -> a | B -> fun s -> b | C -> fun s -> c

to

fun s -> match e with A -> a | B -> b | C -> c
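
Applied to add_conjunct, the intermediate stage, after eta-expanding each case with the state argument but before hoisting, would look schematically like this (a sketch):

let rec add_conjunct fml (cjn, splits) =
  match fml with
  | True | Not _ ->
      fun s -> let cjn, s = meet fml cjn s in ((cjn, splits), s)
  | And ps -> fun s -> State.mfold ~f:add_conjunct ps (cjn, splits) s
  | Or ps -> fun s -> ((cjn, ps :: splits), s)
  | If {cnd; pos; neg} ->
      fun s ->
        add_conjunct (Or [And [cnd; pos]; And [Not cnd; neg]]) (cjn, splits) s

after which hoisting fun s -> over the match yields the hand-written add_conjunct of iter_dnf2.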

Any suggestions would be much appreciated!

I’m afraid this is going to be complicated.
Let’s start with the issue with add_conjunct:

match e with A -> fun s -> a | B -> fun s -> b | C -> fun s -> c

is not equivalent to

fun s -> match e with A -> a | B -> b | C -> c

In the general case, if e contains side effects then the second version may duplicate those effects. Even if e is pure, the second version still pushes the evaluation of e and the pattern-matching under the abstraction, which from a performance point of view is not great.
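
To make the first point concrete, a contrived sketch:

(* match outside: the print runs once, when the closure is built *)
let f_once =
  match (print_endline "effect"; true) with
  | true -> fun s -> s
  | false -> fun s -> s + 1

(* match pushed under fun: the print runs again on every application *)
let f_each s =
  match (print_endline "effect"; true) with
  | true -> s
  | false -> s + 1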

For the rest of the code, the main issue seems to be that you’d like flambda (or the compiler in general) to be able to eta-expand functions that immediately return a function. This might be something we could implement in flambda, but we haven’t done it yet, and if we do it will have to be opt-in (otherwise it could lead to unpleasant surprises, like code duplication).

It might be possible to write a ppx that translates the first version into the second version automatically though.

Thanks @vlaviron, it is helpful to know what to potentially expect from the optimizer.

For the transformation hoisting fun over match, would there be performance concerns in the case where the matched expression is an immediate value? In this case it is even an identifier (a formal parameter), so even considerations like live ranges ought to be fine, IIUC.

Is there an easy explanation or example of how eta-expanding functions that immediately return functions can cause code duplication? I ask because, when manually adjusting code, my experience has been that eta-expansion (of expressions that just build a closure by partial application) is never bad and sometimes good for performance, and your comment indicates that generalizing in this way isn't right.

Yes. Take this example:

let f x =
  match x with
  | 0 -> fun y -> x + y
  | _ -> fun y -> x - y

let g x n =
  let f' = f x in
  for i = 1 to n do
    ignore (f' n)
  done

Here, the code does one match on x and n arithmetic operations. If you lift y out of the match, then no matter what your inliner does, you still have to do n matches on x.
If x were a known constant instead, then flambda would be able to simplify the body and you would end up with the same code in both cases, but in general pushing even pure code inside a function is not such a good idea.
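
For reference, the hoisted version of f that the transformation would produce (a sketch):

(* the match on x is now re-executed on every application *)
let f_hoisted x = fun y ->
  match x with
  | 0 -> x + y
  | _ -> x - y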

It’s almost always a good idea to eta-expand, but there are a few corner cases.
For example:

let f1 _ x = body
let f2 _ = (); (fun x -> body)

Assuming that body doesn't have any free variables except x, partial applications of f1 will create a new wrapper each time, while applications of f2 will always return the same constant closure. It's not a big deal, as the only difference is in the size of the executable (the partial-application wrappers will also end up simplified to constant closures, hopefully), but it's the kind of thing we tend to be cautious about. Also, if body is too big to be inlined, then the extra wrapper will keep the useless argument alive for longer than necessary.
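
To spell that out with a concrete instantiation (mine, taking body = x + 1):

let f1 _ x = x + 1
let f2 _ = (); (fun x -> x + 1)

(* each partial application of f1 builds a fresh wrapper closure *)
let g1 = f1 ()

(* f2 returns the same constant closure every time, since
   fun x -> x + 1 has no free variables *)
let g2 = f2 ()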

But I think there is a strong argument for automatic eta-expansion in flambda for simple cases, which is that the compiler passes before flambda already do it as much as possible. For example, in the following code f1, f2 and f3 will be simplified to fun x y -> x + y, but not f4. If we did the same optimisations in flambda, we could also handle f4 in the same way.

let f1 x =
  fun y -> x + y

let f2 x =
  let g y = x + y in
  g

let f3 x =
  let g y = x + y in
  (fun x -> x) g

let id x = x

let f4 x =
  let g y = x + y in
  id g

Maybe I’ll find an intern that can be tricked into working on that.

You can play with [@@inline] and [@@specialise] annotations and see if they bring any improvements, e.g.,

module State = struct
  type s
  type 'a m = s -> 'a * s

  let return a s = (a, s)
  [@@inline]

  let bind m k s =
    let a, s = m s in
    k a s
  [@@specialise]
  [@@inline]

  let map f m = bind m (fun a -> return (f a))
  [@@specialise]
  [@@inline]

  module Import = struct
    let ( let* ) = bind
    let ( let+ ) m f = map f m
    [@@specialise]
    [@@inline]
  end

  let read s = (s, s) [@@inline]
  let run s m = m s

  let mfold ps a ~f s =
    let f (a, s) p = f p a s in
    List.fold_left f (a, s) ps
  [@@specialise]
end

This might help with this microbenchmark. However, my experience using bap, a project that uses state monads a lot, as the macrobenchmark shows that representing a state monad as an unboxed function offers worse performance than representing it as a boxed value, e.g.,

type ('a, 's) state = {run : 's -> 'a * 's}

Adding the [@@unboxed] annotation will effectively turn this representation into the same one you're using, and our benchmarks show that the unboxed representation is about 10 to 20% slower. This is a fairly well-known issue; see the corresponding angstrom discussion (I saw a deeper explanation somewhere, but can't find it anymore).
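
For concreteness, the two shapes being compared look like this (a sketch; the type and operation names are mine):

(* boxed: each monadic value is a heap-allocated record wrapping the function *)
type ('a, 's) boxed = { run : 's -> 'a * 's }

(* with [@@unboxed] the record vanishes at runtime, leaving the bare
   function representation from the original question *)
type ('a, 's) unboxed = { run_u : 's -> 'a * 's } [@@unboxed]

let return a = { run = (fun s -> (a, s)) }
let bind m k = { run = (fun s -> let a, s = m.run s in (k a).run s) }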

Even more, our experiments show that the following state monad is more efficient,

(* [state] is the application's underlying state type, defined elsewhere *)
type 'a m = {
    run : 'r. ('a -> state -> 'r) -> state -> 'r
  }
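
For reference, return and bind for this shape look as follows (a sketch of mine, assuming the declarations above):

let return a = { run = (fun k s -> k a s) }

(* the continuation receives the value and the state as separate
   arguments, so no intermediate pair is ever allocated *)
let bind m f =
  { run = (fun k s -> m.run (fun a s' -> (f a).run k s') s) }

let read = { run = (fun k s -> k s s) }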

This representation has two benefits: it doesn't require an extra allocation for the state/value pair, and it combines very nicely with other monads, such as the Error monad and the Continuation monad. So you can have a monad stack without the overhead of a monad transformer; e.g., the following representation is suitable for implementing the state, continuation, and error monad interfaces (with the extra bonus that the error monad doesn't require boxing values in a Result.t-style structure):

  (* [error] and [state] are application-specific types defined elsewhere *)
  type 'a fmonad = {
    run : reject:(error -> unit) -> accept:('a -> state -> unit) -> state -> unit
  }

You can refer to the code of the linked PR for the implementation details and more discussion. Or ask here, if anything is unclear.

With all that said, I don't think there is that much of a need to avoid short-lived allocations at all costs. With a properly tuned GC, such allocations will not be promoted and have a very small impact on performance. And as with all optimizations, it is always better to benchmark before optimizing.
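
(One such knob, for example, is the minor heap size; the figure below is purely illustrative:)

let () =
  (* enlarge the minor heap (the size is in words) so that short-lived
     allocations die young instead of being promoted *)
  Gc.set { (Gc.get ()) with Gc.minor_heap_size = 8 * 1024 * 1024 }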

More eta-expansion in flambda makes sense to me. Just as a data point: I find it surprising that f4 does not get compiled the same as f1, f2, and f3. In particular, my expectation would be that id gets inlined into f4, and compilation then continues exactly as for f3. (I also don't understand how this case requires eta-expansion, but I'm not familiar with the internals of the transformations flambda performs.)

Thanks @ivg, these are interesting alternatives. I experimented a bit with my use case, and the boxed representation does indeed perform better, which is counterintuitive to me. I see how the CPSed representation combines better with other monads, but in my use case it is used alone, and the boxed direct-style representation is faster than the boxed CPS representation. But even with optimization settings beyond -O3 and judicious use of annotations, there is still a significant cost to any of these monad abstractions. To be fair, in the actual application the code involved is as hot as it gets.

And yes, the short-lived allocations are not themselves what I am worried about, but they indicate that the compiler has not seen how to eliminate the intermediate functions and thereby enable further simplifications and optimizations. I started this investigation while experimenting with adding a monad abstraction, then saw a significant end-to-end performance hit, and proceeded to go hunting with the profiler.
