Hi,
I’m trying to rewrite, in continuation-passing style, a non-tail-recursive function that ‘flattens’ a multi-way tree into the list of its root-to-leaf paths. My tree data type and original `flatten` function are:
```ocaml
type 'a t = Node of 'a * 'a t list

let branch x xs = Node (x, xs)
let leaf x = Node (x, [])

let rec flatten (Node (x, ts)) =
  match ts with
  | [] -> [ [ x ] ]
  | ts ->
      List.map (fun xs -> x :: xs) @@
      List.concat @@
      List.map flatten ts
```
Given the following tree:
```ocaml
let t =
  branch 1
    [ branch 2
        [ leaf 3
        ; leaf 4
        ]
    ; branch 5
        [ leaf 6 ]
    ]
```
I get the correct ‘flattened’ representation:
```ocaml
let ts = flatten t
(* val ts : int list list = [[1; 2; 3]; [1; 2; 4]; [1; 5; 6]] *)
```
My CPS-converted implementation is as follows:
```ocaml
let flatten_cps t =
  let rec aux ~k (Node (x, ts)) =
    aux_forest ts ~k:(fun sfxs ->
        k @@ List.map (fun sfx -> x :: sfx) sfxs)
  and aux_forest ~k = function
    | [] -> k [ [] ]
    | next :: rest ->
        aux next ~k:(fun next' ->
            aux_forest rest ~k:(fun rest' ->
                k @@ next' @ rest'))
  in
  aux ~k:(fun x -> x) t
```
When I apply this function, I do get the flattened paths, but I also get all of their prefixes:
```ocaml
let ts2 = flatten_cps t
(* val ts2 : int list list =
     [[1; 2; 3]; [1; 2; 4]; [1; 2]; [1; 5; 6]; [1; 5]; [1]] *)
```
I’ve been staring at this for a while but can’t see what I’m doing wrong. Can anyone spot the mistake?
Michael
I haven’t analyzed your code completely, but it seems that the `| [] -> k [ [] ]` branch in `aux_forest` is what’s doing the damage. This check should be in `aux`.
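A minimal sketch of that suggestion, keeping the structure of `flatten_cps`: the empty-forest case of `aux_forest` now passes `[]` (no paths) to its continuation instead of `[ [] ]` (one empty path, which is what produced the extra prefixes), and the leaf case moves into `aux`, which passes `[ [ x ] ]`. This is an illustration of the suggested fix, not code posted in the thread.

```ocaml
type 'a t = Node of 'a * 'a t list

let branch x xs = Node (x, xs)
let leaf x = Node (x, [])

let flatten_cps t =
  (* aux passes every root-to-leaf path of a single tree to k *)
  let rec aux ~k (Node (x, ts)) =
    match ts with
    | [] -> k [ [ x ] ] (* a leaf has exactly one path: [x] *)
    | ts ->
        aux_forest ts ~k:(fun sfxs ->
            k @@ List.map (fun sfx -> x :: sfx) sfxs)
  (* aux_forest passes the concatenated paths of a whole forest to k *)
  and aux_forest ~k = function
    | [] -> k [] (* an empty forest contributes no paths, not one empty path *)
    | next :: rest ->
        aux next ~k:(fun next' ->
            aux_forest rest ~k:(fun rest' -> k @@ next' @ rest'))
  in
  aux ~k:(fun x -> x) t

let t = branch 1 [ branch 2 [ leaf 3; leaf 4 ]; branch 5 [ leaf 6 ] ]

(* [[1; 2; 3]; [1; 2; 4]; [1; 5; 6]] *)
let ts2 = flatten_cps t
```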
I’m not sure how you built `flatten_cps`, but the easiest way to build a working implementation is to mechanically transform the original implementation by CPS-converting it bit by bit, including the functions that it calls.
First, `flatten` calls `List.map`, which looks roughly like this:
```ocaml
let rec map f l =
  match l with
  | [] -> []
  | x :: xs -> f x :: map f xs
```
Naming every intermediate computation produces something that is closer to CPS:
```ocaml
let rec map f l =
  match l with
  | [] -> []
  | x :: xs ->
      let y = f x in
      let ys = map f xs in
      y :: ys
```
The final steps are (1) to replace every `let`-binding with the construction of a continuation function, and (2) to apply the continuation `k` to every returned value:
```ocaml
let rec mapk f l k =
  match l with
  | [] -> k []
  | x :: xs ->
      f x @@ fun y ->
      mapk f xs @@ fun ys ->
      k (y :: ys)
```
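To see how `mapk` is meant to be called, here is a small sketch: the mapped function is itself in CPS (it takes its own continuation), and the final continuation decides what happens to the mapped list. The CPS function `double_k` is a hypothetical example, not part of the thread.

```ocaml
(* mapk from the post, repeated so this block is self-contained *)
let rec mapk f l k =
  match l with
  | [] -> k []
  | x :: xs ->
      f x @@ fun y ->
      mapk f xs @@ fun ys ->
      k (y :: ys)

(* hypothetical CPS function: hands x * 2 to its continuation *)
let double_k x k = k (x * 2)

(* identity continuation: just return the mapped list, [2; 4; 6] *)
let doubled = mapk double_k [ 1; 2; 3 ] (fun ys -> ys)

(* any other continuation consumes the result instead: here the sum, 12 *)
let total = mapk double_k [ 1; 2; 3 ] (List.fold_left ( + ) 0)
```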
A similar process for `concat` produces the following:
```ocaml
let rec concatk l k =
  match l with
  | [] -> k []
  | l :: r -> concatk r @@ fun r' ->
      k (l @ r')
```
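The intermediate steps for `concat`, which the reply skips, might look like the sketch below; the names `concat` and `concat_named` are illustrative only. As in `concatk` above, `@` is treated as a primitive and is not itself CPS-converted.

```ocaml
(* direct style, essentially how List.concat is defined *)
let rec concat l =
  match l with
  | [] -> []
  | l :: r -> l @ concat r

(* every intermediate computation named *)
let rec concat_named l =
  match l with
  | [] -> []
  | l :: r ->
      let r' = concat_named r in
      l @ r'

(* CPS: the let-binding becomes a continuation, and k receives each result *)
let rec concatk l k =
  match l with
  | [] -> k []
  | l :: r -> concatk r @@ fun r' -> k (l @ r')

(* [1; 2; 3; 4] *)
let flat = concatk [ [ 1; 2 ]; [ 3 ]; []; [ 4 ] ] (fun x -> x)
```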
Finally, applying the same mechanical process to your original `flatten` (and the anonymous function inside it) gives:
```ocaml
let rec flattenk (Node (x, ts)) k =
  match ts with
  | [] -> k [ [ x ] ]
  | ts ->
      mapk flattenk ts @@ fun x1 ->
      concatk x1 @@ fun x2 ->
      mapk (fun xs k -> k (x :: xs)) x2 k
```
and this `flattenk` produces the results you expect:
```ocaml
# flattenk t (fun x -> x);;
- : int list list = [[1; 2; 3]; [1; 2; 4]; [1; 5; 6]]
```
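For comparison, applying only the naming step to the original `flatten`, before any continuations are introduced, gives something like the sketch below. The name `flatten_named` is illustrative; the intermediate names `x1` and `x2` are chosen to match the continuation parameters in `flattenk`, so each `let` here lines up with one `fun ... ->` there.

```ocaml
type 'a t = Node of 'a * 'a t list

let branch x xs = Node (x, xs)
let leaf x = Node (x, [])

(* flatten with every intermediate result named; in flattenk each of
   these lets becomes a continuation taking x1 or x2 *)
let rec flatten_named (Node (x, ts)) =
  match ts with
  | [] -> [ [ x ] ]
  | ts ->
      let x1 = List.map flatten_named ts in
      let x2 = List.concat x1 in
      List.map (fun xs -> x :: xs) x2

let t = branch 1 [ branch 2 [ leaf 3; leaf 4 ]; branch 5 [ leaf 6 ] ]

(* [[1; 2; 3]; [1; 2; 4]; [1; 5; 6]] *)
let paths = flatten_named t
```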
Thanks Jeremy! Am I right in thinking the intermediate step, where all intermediate computations are named, is ‘administrative normal form’?
Michael