Flattening a multiway tree in continuation passing style

Hi,

I’m trying to rewrite a non-tail-recursive function that ‘flattens’ a multi-way tree into continuation-passing style. My tree data type and original flatten function are:

type 'a t = Node of 'a * 'a t list

let branch x xs = Node(x,xs)

let leaf x = Node(x,[])

let rec flatten (Node (x, ts)) =
  match ts with
  | [] -> [ [ x ] ]
  | ts -> 
    List.map (fun xs -> x :: xs) @@ 
    List.concat @@ 
    List.map flatten ts

Given the following tree:

let t = 
  branch 1 
    [branch 2 
        [ leaf 3
        ; leaf 4
        ]
    ; branch 5 
        [ leaf 6 ]
    ]

I get the correct ‘flattened’ representation:

let ts = flatten t 
>> val ts : int list list = [[1; 2; 3]; [1; 2; 4]; [1; 5; 6]]

My CPSd implementation is as follows:

let flatten_cps t = 
  let rec aux ~k (Node (x, ts)) =
    aux_forest ts ~k:(fun sfxs -> 
    k @@ List.map (fun sfx -> x :: sfx) sfxs
    )
  and aux_forest ~k = function
    | [] -> k [ [] ]
    | next :: rest ->
      aux next ~k:(fun next' -> 
      aux_forest rest ~k:(fun rest' -> 
      k @@ next' @ rest')
      )
  in
  aux ~k:(fun x -> x) t

When I apply this function I do get the flattened paths but I also get all of the prefixes:

let ts2 = flatten_cps t
>> val ts2 : int list list = [[1; 2; 3]; [1; 2; 4]; [1; 2]; [1; 5; 6]; [1; 5]; [1]]

I’ve been staring at this for a while but can’t work out where the extra prefixes come from. Can anyone spot what I’m doing wrong?

Michael

I haven’t analyzed your code completely, but it seems that the branch

    | [] -> k [ [] ]

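in aux_forest is what’s doing the damage. Every forest ends in this case, not only the empty forest of a leaf, so each internal node also picks up an empty suffix, and that suffix turns into a prefix path such as [1; 2]. The check for an empty list of children should be in aux, and aux_forest should contribute no paths when it reaches the empty list.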
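A minimal sketch of that change, keeping the shape of the posted flatten_cps. The name flatten_cps_fixed is only for illustration, and the tree type is repeated so the snippet stands alone:

```ocaml
type 'a t = Node of 'a * 'a t list

(* Hypothetical name: a repaired variant of the posted flatten_cps. *)
let flatten_cps_fixed t =
  let rec aux ~k (Node (x, ts)) =
    match ts with
    | [] -> k [ [ x ] ] (* a leaf contributes exactly one path *)
    | _ ->
      aux_forest ~k:(fun sfxs ->
        k (List.map (fun sfx -> x :: sfx) sfxs)) ts
  and aux_forest ~k = function
    | [] -> k [] (* the end of a forest contributes no paths *)
    | next :: rest ->
      aux ~k:(fun next' ->
        aux_forest ~k:(fun rest' ->
          k (next' @ rest')) rest) next
  in
  aux ~k:(fun x -> x) t
```

With the empty case returning no paths, the only place a path can end is at a leaf, so the prefixes should no longer appear.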

I’m not sure how you built flatten_cps, but the easiest way to build a working implementation is to mechanically transform the original implementation by CPS-converting it bit-by-bit, including the functions that it calls.

First, flatten calls List.map, which looks roughly like this:

let rec map f l =
  match l with
   | [] -> []
   | x :: xs -> f x :: map f xs

Naming every intermediate computation produces something that is closer to CPS:

let rec map f l =
  match l with
   | [] -> []
   | x :: xs -> let y = f x in
                    let ys = map f xs in
                      y :: ys

The final steps are (1) to replace every let-binding with a construction of a continuation function, and (2) to apply the continuation k to every returned value:

let rec mapk f l k =
  match l with
  | [] -> k []
  | x :: xs -> f x @@ fun y ->
               mapk f xs @@ fun ys ->
               k (y :: ys)

A similar process for concat produces the following:

let rec concatk l k =
  match l with
  | [] -> k []
  | l :: r -> concatk r @@ fun r' ->
              k (l @ r')
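As a quick sanity check, both helpers can be run with the identity continuation and compared against List.map and List.concat. This is only a sketch: the definitions are repeated so it is self-contained, and the names id, doubled and joined are made up for the example.

```ocaml
(* CPS map: f itself takes a continuation as its last argument. *)
let rec mapk f l k =
  match l with
  | [] -> k []
  | x :: xs -> f x @@ fun y ->
               mapk f xs @@ fun ys ->
               k (y :: ys)

(* CPS concat: the recursive call is in tail position. *)
let rec concatk l k =
  match l with
  | [] -> k []
  | l :: r -> concatk r @@ fun r' ->
              k (l @ r')

(* Identity continuation, used to get the final result back out. *)
let id x = x

(* Same result as List.map (fun x -> 2 * x) [1; 2; 3] *)
let doubled = mapk (fun x k -> k (2 * x)) [ 1; 2; 3 ] id

(* Same result as List.concat [[1]; [2; 3]; []] *)
let joined = concatk [ [ 1 ]; [ 2; 3 ]; [] ] id
```

Note that the function passed to mapk has to be in CPS as well, which is why the doubling function takes k and calls it instead of returning directly.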

Finally, applying the same mechanical process to your original flatten (and the anonymous function inside) gives:

let rec flattenk (Node (x, ts)) k =
  match ts with
  | [] -> k [ [ x ] ]
  | ts -> 
     mapk flattenk ts @@ fun x1 ->
     concatk x1 @@ fun x2 ->
     mapk (fun xs k -> k (x :: xs)) x2 k

and this flatten produces the results you expect:

# flattenk t (fun x -> x);;
- : int list list = [[1; 2; 3]; [1; 2; 4]; [1; 5; 6]]

Thanks Jeremy! Am I right in thinking the intermediate step, where all intermediate computations are named, is ‘administrative normal form’?

Michael