What is the use of Continuation Passing Style (CPS)?

I wrote a small (classical) example:

(* Length of a list by plain structural recursion. NOT tail-recursive:
   the [+ 1] is performed after the recursive call returns, so every
   element of the list keeps a pending stack frame alive. *)
let rec list_length l =
  match l with
  | [] -> 0
  | _ :: rest -> list_length rest + 1 (* not a tail call *)

(* Tail-recursive length. The number of elements seen so far is threaded
   through an accumulator, so the recursive call is the very last action
   of [count] and no work is left pending on the stack. *)
let list_length_tail l =
  let rec count seen remaining =
    match remaining with
    | [] -> seen
    | _ :: tail -> count (seen + 1) tail (* tail call *)
  in
  count 0 l

(* A polymorphic binary tree: either [Empty], or a [Node (value, left,
   right)] carrying a value and its two subtrees. *)
type 'a binary_tree =
  | Empty
  | Node of 'a * 'a binary_tree * 'a binary_tree

(* Height of a tree, where the empty tree has height 0. NOT tail-recursive:
   both recursive calls must return before [max] and [succ] can run. *)
let rec tree_height t =
  match t with
  | Empty -> 0
  | Node (_, left, right) ->
      let hl = tree_height left and hr = tree_height right in
      succ (max hl hr)

(* a single integer accumulator is not enough to make it tail-recursive:
   after the left subtree, the right one still has to be visited. Try it... *)

(* Tail-recursive height using continuation-passing style (CPS).
   [go tree cont] computes the height of [tree] and hands it to [cont]
   instead of returning it. Every call to [go] and to a continuation is a
   tail call. The work still to be done (visit the right subtree, then
   combine both heights) is captured in the chain of continuation
   closures rather than in pending stack frames. *)
let tree_height_tail t =
  let rec go tree cont =
    match tree with
    | Empty -> cont 0
    | Node (_, left, right) ->
        go left (fun left_height ->
            go right (fun right_height ->
                cont (succ (max left_height right_height))))
  in
  (* the initial continuation just returns the final height *)
  go t (fun h -> h)

(* Demo: compute each result with both the naive and the tail-recursive
   version. For these inputs the output is 4, 4, 3, 3.
   [let () =] (rather than [let _ =]) makes the compiler check that this
   top-level phrase really has type unit, so an accidental partial
   application would be reported instead of silently discarded. *)
let () =
  let l = [1; 2; 3; 4] in
  Format.printf "size of the list is: %d@." (list_length l);
  Format.printf "size of the list is: %d@." (list_length_tail l);
  let t = Node (1, Empty, Node (2, Node (3, Empty, Empty), Empty)) in
  Format.printf "height of the tree is: %d@." (tree_height t);
  Format.printf "height of the tree is: %d@." (tree_height_tail t)

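Then, if you wonder why tail-recursive functions are sometimes needed: when your list is too long (or your tree too deep), a non-tail-recursive function will fail with a stack overflow. The CPS version avoids this because the pending work is stored in continuation closures instead of on the call stack.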

EDIT: also note that going from the non-CPS version to the CPS one is an almost purely syntactic transformation 🙂

(* Starting point: the direct, non-tail-recursive definition. *)
let rec tree_height t = match t with
  | Empty -> 0
  | Node (_, l, r) -> 1 + max (tree_height l) (tree_height r)

(* add intermediate values: bind the result of every recursive call with
   a [let], so that each call sits on its own line. Still not
   tail-recursive. *)
let rec tree_height t = match t with
  | Empty -> 0
  | Node (_, l, r) ->
      let lh = tree_height l in
      let rh = tree_height r in
      1 + max lh rh

(* add a continuation parameter [k], and apply [k] to every value that
   would otherwise be returned: *)
(* this is not valid OCaml: [tree_height] now expects a continuation, so
   [tree_height l] is a partial application (a function), not an int, and
   the type checker rejects [max lh rh]. The recursive calls below still
   have to be converted as well. *)
let rec tree_height t k = match t with
  | Empty -> k 0
  | Node (_, l, r) ->
      let lh = tree_height l in
      let rh = tree_height r in
      k (1 + max lh rh)

(* replace each intermediate `let x = tree_height y in e` by
   `tree_height y (fun x -> e)`. The result is valid OCaml, and every call
   (to [tree_height] or to a continuation) is now in tail position: *)
let rec tree_height t k = match t with
  | Empty -> k 0
  | Node (_, l, r) ->
      tree_height l (fun lh ->
      tree_height r (fun rh ->
      k (1 + max lh rh)))
5 Likes