Manual CPS is faster than CPS monad. Is it expected?

I was experimenting with monads and discovered that a CPS monad performs slower than a manual CPS transformation. This seems to be the case for OCaml, Racket, and Chicken Scheme. In OCaml the difference can be 2-3x. In GHC the two perform exactly the same.

Is there a reasonable explanation? Is it a missing optimization, or an artifact of the compiler lacking a CPS-based intermediate representation? Or maybe my code is wrong, or I am measuring the wrong thing?

Manual CPS
(** [fib_cps n] computes a Fibonacci number in hand-written
    continuation-passing style. Every [n <= 1] (negative values included)
    yields [1]; otherwise the result is the sum of the results for [n - 1]
    and [n - 2]. *)
let fib_cps =
  (* [go i return] hands the result for [i] to [return]. The continuation
     is an ordinary second argument, received before [i] is examined. *)
  let rec go i return =
    if i <= 1 then return 1
    else
      (* Result for [i - 1] first, then for [i - 2], then pass on the sum. *)
      let after_first a =
        let after_second b = return (a + b) in
        go (i - 2) after_second
      in
      go (i - 1) after_first
  in
  (* Start the recursion with the identity continuation. *)
  fun n -> go n (fun result -> result)
CPS monad
(** A continuation monad: a computation is a function that, given a
    continuation for its result, produces the final answer. *)
module Cont =
struct
  (** A continuation consuming an ['a] and producing an answer of type ['b]. *)
  type ('a, 'b) cont = 'a -> 'b

  (** A computation yielding an ['a]. The field is polymorphic in the answer
      type ['b], so one computation can be run with continuations of any
      answer type. [[@@unboxed]] asks the compiler to represent the
      single-field record as the field itself. *)
  type 'a t = { cont : 'b. ('a, 'b) cont -> 'b } [@@unboxed]

  (** [return x] is the computation that immediately passes [x] to its
      continuation. *)
  let return (x : 'a) =
    let feed k = k x in
    { cont = feed }
    [@@inline always]

  (** [m >>= f] runs [m], applies [f] to its result, and runs the resulting
      computation with the continuation given to the whole expression. *)
  let ( >>= ) (m : 'a t) (f : 'a -> 'b t) : 'b t =
    (* Field access (not pattern destructuring) keeps [m.cont] polymorphic. *)
    let chained k = m.cont (fun v -> (f v).cont k) in
    { cont = chained }
    [@@inline always]

  (** [error msg] raises [Failure msg]. *)
  let error msg = failwith msg

  (** [run_cont f m] runs the computation [m] with [f] as its final
      continuation. *)
  let run_cont f m = m.cont f [@@inline always]

  (** Binding-operator syntax for the monad. *)
  module Syntax = struct
    (** [let* x = m in e] is [m >>= fun x -> e]. *)
    let ( let* ) m f = m >>= f
  end
end

(** [fib_cps_m n] computes the same Fibonacci number as the manual CPS
    version, but expressed through the [Cont] monad. Every [n <= 1] yields
    [1]. *)
let fib_cps_m n =
  let open Cont in
  (* [fib i] builds the monadic computation for [i]; nothing is run until a
     continuation is supplied by [run_cont] below. *)
  let rec fib i =
    if i <= 1 then return 1
    else
      fib (i - 1) >>= fun a ->
      fib (i - 2) >>= fun b ->
      return (a + b)
  in
  (* Run the whole computation with the identity continuation. *)
  run_cont (fun v -> v) (fib n)

Update: the compiler is OCaml 4.14.1+flambda.

-dlambda output
   fib_cps/298 =
     (letrec
       (helper/299
          (function n/300[int] k/301
            (if (<= n/300 1) (apply k/301 1)
              (apply helper/299 (- n/300 1)
                (function p1/302[int]
                  (apply helper/299 (- n/300 2)
                    (function p2/303[int] (apply k/301 (+ p1/302 p2/303)))))))))
       (function n/304[int] : int
         (apply helper/299 n/304 (function prim/410 stub prim/410))))


   fib_cps_m/311 =
     (function n/313[int] : int
       (letrec
         (helper/314
            (function n/315[int]
              (if (<= n/315 1) (apply (field 0 Cont/297) 1)
                (apply (field 0 (field 4 Cont/297))
                  (apply helper/314 (- n/315 1))
                  (function l/316[int]
                    (apply (field 0 (field 4 Cont/297))
                      (apply helper/314 (- n/315 2))
                      (function r/317[int]
                        (apply (field 0 Cont/297) (+ l/316 r/317)))))))))
         (apply (field 3 Cont/297) (function prim/411 stub prim/411)
           (apply helper/314 n/313))))

Update: flambda output in the repo

1 Like

Since the code is the same up to inlining I would suspect missed inlining opportunities. Have you tried compiling with an flambda-enabled compiler?

It is already flambda.

The -dlambda output will not be very useful as it’s an IR generated prior to flambda optimization passes. Maybe -dflambda-verbose?

Since the code is the same up to inlining

Actually, I was wrong. In the manually translated CPS, helper takes its continuation argument before testing n, whereas in the monadic version it first tests n, and each branch of the conditional returns a closure. I suspect this is the reason for the difference.

I tried to implement your idea, and it looks like the new “bad” implementation falls in between the manual and the monadic ones.

(** [fib_cps2_bad n] is a manual CPS Fibonacci in which the helper examines
    its integer argument first and only then returns a function awaiting
    the continuation, so each branch of the conditional yields a closure.
    Every [n <= 1] yields [1]. *)
let fib_cps2_bad =
  let rec go i =
    match i <= 1 with
    | true -> (fun return -> return 1)
    | false ->
      (fun return ->
        (* Both recursive calls happen only once [return] is known. *)
        let with_first a = go (i - 2) (fun b -> return (a + b)) in
        go (i - 1) with_first)
  in
  (* Start the recursion with the identity continuation. *)
  fun n -> go n (fun v -> v)
                        Rate fib monadic cps fib Bad manual cps   fib manual cps
   fib monadic cps  647082/s              --               -48%             -69%
fib Bad manual cps 1241157/s             92%                 --             -40%
    fib manual cps 2084202/s            222%                68%               --

I think the monadic implementation is more like

(** [fib_cps3 n] is a manual CPS Fibonacci in which the recursive call for
    [n - 1] is evaluated before the continuation is received. Every
    [n <= 1] yields [1]. *)
let fib_cps3 =
  let rec go i =
    if i <= 1 then (fun return -> return 1)
    else begin
      (* Evaluated eagerly, before any continuation has been supplied. *)
      let first = go (i - 1) in
      fun return ->
        (* The call for [i - 2] is deferred until [first] delivers. *)
        let with_first a = go (i - 2) (fun b -> return (a + b)) in
        first with_first
    end
  in
  (* Start the recursion with the identity continuation. *)
  fun n -> go n (fun v -> v)

i.e., helper (n - 1) is computed before the continuation is received.

I have a feeling that to perform these kinds of CPS optimizations, the compiler needs to be aware of each function’s purity/impurity. Typed algebraic effects to the rescue?

                          Rate fib monadic cps fib cps 3 SkySkimmer fib Bad manual cps fib manual cps
     fib monadic cps  649477/s              --                 -56%               -61%           -72%
fib cps 3 SkySkimmer 1471670/s            127%                   --               -11%           -36%
  fib Bad manual cps 1650165/s            154%                  12%                 --           -28%
      fib manual cps 2294104/s            253%                  56%                39%             --