I’m trying to rewrite, in continuation-passing style, a non-tail-recursive function which ‘flattens’ a multi-way tree into the list of its root-to-leaf paths. My tree data type and original flatten function are:

(** A multi-way tree: every node carries a label and a list of child
    subtrees. A node with an empty child list is a leaf. *)
type 'a t = Node of 'a * 'a t list

(** [branch label children] builds a node holding [label] above [children]. *)
let branch label children = Node (label, children)

(** [leaf label] builds a node holding [label] with no children. *)
let leaf label = branch label []
(** [flatten t] returns every root-to-leaf path of [t] as a list of labels,
    visiting children from left to right. Written in direct style, so the
    recursion is not in tail position. *)
let rec flatten (Node (label, children)) =
  match children with
  | [] -> [ [ label ] ]
  | _ :: _ ->
    (* Gather the paths of all subtrees, then put this label in front. *)
    let child_paths = List.concat (List.map flatten children) in
    List.map (fun path -> label :: path) child_paths

let ts = flatten t
>> val ts : int list list = [[1; 2; 3]; [1; 2; 4]; [1; 5; 6]]

My CPSd implementation is as follows:

(** [flatten_cps t] walks [t] in continuation-passing style and returns a
    list of label paths that all start at the root.

    The empty-forest case of [aux_forest] hands [[ [] ]] (a single empty
    suffix) to its continuation, and that case is reached at the end of
    every child list. The result therefore holds one path per node of [t]:
    the root-to-leaf paths plus the path stopping at each node that has
    children. *)
let flatten_cps t =
  (* Paths starting at this node: its label consed onto each suffix that
     the child forest yields. *)
  let rec aux ~k (Node (label, children)) =
    let prepend suffix = label :: suffix in
    aux_forest children ~k:(fun suffixes -> k (List.map prepend suffixes))
  (* Paths of a list of sibling trees, appended from left to right. *)
  and aux_forest ~k = function
    | [] -> k [ [] ]
    | first :: siblings ->
      aux first ~k:(fun first_paths ->
        aux_forest siblings ~k:(fun sibling_paths ->
          k (first_paths @ sibling_paths)))
  in
  aux ~k:(fun paths -> paths) t

When I apply this function I do get the flattened paths but I also get all of the prefixes:

let ts2 = flatten_cps t
>> val ts2 : int list list = [[1; 2; 3]; [1; 2; 4]; [1; 2]; [1; 5; 6]; [1; 5]; [1]]

I’ve been staring at this for a while but can’t work out what I’m doing wrong - can anyone spot the mistake?

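I’m not sure how you built flatten_cps, but the easiest way to build a working implementation is to mechanically transform the original implementation by CPS-converting it bit-by-bit, including the functions that it calls. (The extra prefixes in your version come from the `[] -> k [ [] ]` case of aux_forest: it is reached at the end of every child list, not only for a leaf, so each node that has children also contributes the path that stops at that node.)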

First, flatten calls List.map, which looks roughly like this:

(* Direct-style map, written out as the starting point for the CPS
   conversion: apply [f] to each element and rebuild the list. *)
let rec map f = function
  | [] -> []
  | hd :: tl -> f hd :: map f tl

Naming every intermediate computation produces something that is closer to CPS:

(* Same map with every intermediate result named. The let-bindings fix the
   order of the steps: [f] is applied to the head, then the tail is mapped,
   then the two results are consed together. *)
let rec map f = function
  | [] -> []
  | hd :: tl ->
    let y = f hd in
    let ys = map f tl in
    y :: ys

The final steps are (1) to replace every let-binding with a construction of a continuation function, and (2) to apply the continuation k to every returned value:

(* CPS map. [f] takes an element and a continuation; [mapk f l k] passes the
   list of everything [f] produced, in the order of [l], to [k]. Each call
   is in tail position, with the pending work held in the continuations. *)
let rec mapk f l k =
  match l with
  | [] -> k []
  | hd :: tl ->
    f hd (fun y ->
      mapk f tl (fun ys ->
        k (y :: ys)))

A similar process for concat produces the following:

(* CPS concat: [concatk l k] passes the concatenation of the lists in [l],
   in order, to [k]. *)
let rec concatk l k =
  match l with
  | [] -> k []
  | first :: others ->
    concatk others (fun joined_rest -> k (first @ joined_rest))

Finally, applying the same mechanical process to your original flatten (and the anonymous function inside) gives:

(* CPS version of flatten: [flattenk t k] passes the list of root-to-leaf
   paths of [t] to [k]. It mirrors the direct-style pipeline step by step:
   map over the children, concatenate, then put this label in front. *)
let rec flattenk (Node (label, children)) k =
  match children with
  | [] -> k [ [ label ] ]
  | _ :: _ ->
    mapk flattenk children (fun per_child ->
      concatk per_child (fun paths ->
        mapk (fun path return -> return (label :: path)) paths k))

and this flatten produces the results you expect:

# flattenk t (fun x -> x);;
- : int list list = [[1; 2; 3]; [1; 2; 4]; [1; 5; 6]]