A short GADT puzzle

You may know of the newish function List.partition_map with the following type:

List.partition_map : ('a -> ('b, 'c) Either.t) -> 'a list -> 'b list * 'c list

If you tell partition_map how to map an 'a to either a 'b or a 'c, it will partition a list of 'a into a list of 'b and a list of 'c.

Puzzle: write an “n-ary” version of partition_map that supports splitting a value into N possible variants of different types, and returns a product of N different lists.

My solution lets you write the following, for example:

(* Map a string to either an int, or a bool, or just the same string.
   The nested [result] encodes a three-way sum: [Ok n] for an integer,
   [Error (Ok b)] for a boolean and [Error (Error str)] for anything else.
   The [_opt] parsers are used so that only a parse failure selects the next
   alternative; a catch-all handler would also swallow unrelated exceptions. *)
let classify str =
  match int_of_string_opt str with
  | Some n -> Ok n
  | None ->
    Error
      (match bool_of_string_opt str with
       | Some b -> Ok b
       | None -> Error str)

(* split a list three ways based on this *)
let test =
  (* [Cons (Cons One)] asks for three output lists: one per [Ok] layer of
     [classify], plus a final list for the innermost [Error] payload. *)
  [ "foo"; "42"; "true"; "bar"; "0"; "false" ]
  |> partition_split (Cons (Cons One)) classify

The result of evaluating test is

val test : int list * (bool list * string list) =
  ([42; 0], ([true; false], ["foo"; "bar"]))

(Solutions may use a different type to describe the splitting choices, and to group the partitioned lists. The point is that the choice happens between an arbitrary number of possibilities, possibly of different types, so the result of the splitting function looks like a sum 'b + 'c + 'd + ....)

1 Like

I couldn’t resist giving it a try:

(* [(a, x) split] describes how a classifier result of type [a] is spread
   over a nest of lists of type [x]:
   - [One] is the last alternative: a value of type ['a] goes to an ['a list];
   - [Cons s] puts one more alternative in front: [Ok b] goes to the leading
     ['b list], while the payload of [Error] is dispatched according to [s]. *)
type (_, _) split =
  | One : ('a, 'a list) split
  | Cons : ('a, 'x) split -> (('b, 'a) result, 'b list * 'x) split

(* [nil split] is the nest of empty lists whose shape is given by [split]. *)
let rec nil : type a x. (a, x) split -> x = function
  | One -> []
  | Cons rest -> ([], nil rest)

(* [cons split y ys] prepends [y] to the one list of [ys] that corresponds to
   the alternative [y] belongs to; the other lists are returned unchanged. *)
let rec cons : type a x. (a, x) split -> a -> x -> x =
 fun split y ys ->
  match split, y, ys with
  | One, _, _ -> y :: ys
  | Cons _, Ok b, (bs, rest) -> (b :: bs, rest)
  | Cons split, Error a, (bs, rest) -> (bs, cons split a rest)

(* [rev split l] reverses every list of the nest [l]. *)
let rec rev : type a x. (a, x) split -> x -> x =
 fun split l ->
  match split, l with
  | One, ys -> List.rev ys
  | Cons split, (ys, rest) -> (List.rev ys, rev split rest)

(* [partition_split split f l] classifies every element of [l] with [f] and
   returns one list per alternative described by [split]; each list keeps the
   relative order of [l]. The accumulator is built by prepending and is
   reversed once at the end, so the traversal is tail-recursive. The function
   is not itself recursive, hence no [rec]. *)
let partition_split : type a x b. (a, x) split -> (b -> a) -> b list -> x =
 fun split f l ->
  l
  |> List.fold_left (fun accu x -> cons split (f x) accu) (nil split)
  |> rev split

Cheers,
Nicolas

Here’s one version using a tree data structure that does not require the user to specify the structure of the result when calling the partition function (instead, it is inferred from the result of the classify function).
There’s definitely a lot of opportunity for improvement:

  • maybe there’s a less verbose way to write the annotations for the polymorphic recursive functions
  • node elements are lists (because our accumulator is a list), meaning that update_at has to use lists as well, and we have to hardcode 'v list into the tree. A more useful version would allow higher-kinded types, so that we can separate the structure of the tree from its contents.
  • the returned lists are reversed, because I didn’t add a polymorphic map function that we could use to reverse all the lists
  • some way of writing generic printing and iterator functions would be useful
open! Core

(* [('tree, 't) Treepath.t] is a path into a tree whose shape is given by the
   object-type index ['tree]; the path ends at a leaf of ['t] values.
   The module is recursive and defined by its signature alone
   ([end = Treepath]); presumably this is what lets [With_value] mention
   [Treepath.t] while defining its own type named [t]. *)
module rec Treepath : sig
  type (_,_) t =
    (* End of the path: the tree is a leaf holding ['t] values. *)
    | T : (< v : 't >, 't) t
    (* Continue into the left subtree; the right subtree's shape is left open. *)
    | L : ('l, 't) t -> (< l : 'l; r : _ >, 't) t 
    (* Continue into the right subtree; the left subtree's shape is left open. *)
    | R : ('r, 't) t -> (< r : 'r; l : _ >, 't) t

  module With_value : sig
    (* A path packed together with a list of values for the leaf it reaches;
       the value type ['v] does not appear in the result type. *)
    type _ t = V : ('t, 'v) Treepath.t * 'v list -> 't t
  end
end = Treepath

(* A tree of lists whose shape is tracked by the same object-type index that
   [Treepath.t] uses. *)
module Treenode = struct
  type _ t  =
    (* Leaf holding a list of ['t] values. *)
    | T : 't list -> < v : 't > t
    (* Inner node with a left and a right subtree. *)
    | LR : 'l t * 'r t -> < l : 'l; r : 'r > t
    (* Placeholder that fits any shape; [select_tree] reads it as an empty
       list and [update_at] expands it on demand. *)
    | Default : _ t
end

(* [select_tree t path] returns the list stored at the leaf of [t] that
   [path] designates. Meeting a [Default] node anywhere along the way yields
   the empty list. *)
let rec select_tree
  : type tree result.
    tree Treenode.t -> (tree, result) Treepath.t -> result list
  =
 fun node path ->
  match node, path with
  | Treenode.Default, _ -> []
  | Treenode.T values, Treepath.T -> values
  | Treenode.LR (left, _), Treepath.L rest -> select_tree left rest
  | Treenode.LR (_, right), Treepath.R rest -> select_tree right rest

(* [update_at t path ~f] rebuilds [t] with [f] applied to the list stored at
   the leaf designated by [path]. [Default] nodes met on the way are
   expanded: a [Default] leaf is treated as the empty list, and the sibling
   of a freshly created branch stays [Default]. *)
let rec update_at
  : type tree v.
    tree Treenode.t
    -> (tree, v) Treepath.t
    -> f:(v list -> v list)
    -> tree Treenode.t
  =
 fun node path ~f ->
  match node, path with
  (* Leaf reached. *)
  | Treenode.T values, Treepath.T -> Treenode.T (f values)
  | Treenode.Default, Treepath.T -> Treenode.T (f [])
  (* Descend to the left. *)
  | Treenode.LR (left, right), Treepath.L rest ->
    Treenode.LR (update_at left rest ~f, right)
  | Treenode.Default, Treepath.L rest ->
    Treenode.LR (update_at Treenode.Default rest ~f, Treenode.Default)
  (* Descend to the right. *)
  | Treenode.LR (left, right), Treepath.R rest ->
    Treenode.LR (left, update_at right rest ~f)
  | Treenode.Default, Treepath.R rest ->
    Treenode.LR (Treenode.Default, update_at Treenode.Default rest ~f)

(* [partition_list_tree list ~f] classifies every element of [list] with [f],
   which returns a path together with the values to store there, and
   collects those values in a tree that starts out as [Default].
   New values are put in front of what a leaf already holds, so each leaf
   ends up in reverse input order. *)
let partition_list_tree list ~f =
  let rec insert_all
    : type a tree.
      tree Treenode.t
      -> (a -> tree Treepath.With_value.t)
      -> a list
      -> tree Treenode.t
    =
   fun acc classify -> function
    | [] -> acc
    | x :: rest ->
      (match classify x with
       | Treepath.With_value.V (path, values) ->
         let acc =
           update_at acc path ~f:(fun existing -> List.append values existing)
         in
         insert_all acc classify rest)
  in
  insert_all Treenode.Default f list

(* Classify a string as an int, a bool or a plain string, returning the tree
   path of the matching leaf together with the singleton list to store
   there: ints at [L (L T)], bools at [L (R T)], other strings at [R T]. *)
let classify str =
  let open Treepath in
  match int_of_string str with
  | n -> With_value.V (L (L T), [ n ])
  | exception _ ->
    (match bool_of_string str with
     | b -> With_value.V (L (R T), [ b ])
     | exception _ -> With_value.V (R T, [ str ]))

(* Expect test: partitions the sample strings with [classify] and prints the
   three leaves of the resulting tree. Each list is printed in reverse input
   order because [partition_list_tree] puts every new value in front of the
   values its leaf already holds. *)
let%expect_test "partition" =
  let test =
    partition_list_tree
      ["foo"; "42"; "true"; "bar"; "0"; "false"]
      ~f:classify
  in
  print_s
    [%message ""
        (select_tree test (L (L T)) : int list)
        (select_tree test (L (R T)) : bool list)
        (select_tree test (R T) : string list)
    ];
  [%expect {|
    (("select_tree test (L (L T))" (0 42))
     ("select_tree test (L (R T))" (false true))
     ("select_tree test (R T)" (bar foo))) |}]


1 Like

I also went down the rabbithole of representing generalized sums as products with the requirement that it would not allow more values than nested eithers:

let test =
  (* no double [some] allowed! at least one [some] required! *)
  let classify x =
    match x mod 2, x mod 3 with
    | 0, _ -> some (x * x) <+> none <+> none
    | _, 0 -> none <+> some (-x) <+> none
    | _, _ -> none <+> none <+> some (string_of_int x)
  in
  partition classify (List.init 10 (fun i -> i))
(* Type-level Peano naturals, used to count the [some]s in a sum. *)
type z = ZERO
type 'a s = SUCC of 'a

(* [(a, m, n) t] is one classified value presented as a product: [a] is the
   product of lists it contributes to, and the size indices go from [m] to
   [n] while counting the [ESome]s it contains.
   - [ENone] contributes nothing and leaves the count unchanged;
   - [ESome x] contributes [x] and increments the count;
   - [EPair] chains the counts of its two components. *)
type (_, _, _) t =
  | ENone : ('a list, 'size, 'size) t
  | ESome : 'a -> ('a list, 'size, 'size s) t
  | EPair : ('a, 'size0, 'size1) t * ('b, 'size1, 'size2) t -> ('a * 'b, 'size0, 'size2) t

let none = ENone
let some x = ESome x
let ( <+> ) a b = EPair (a, b)

(* [template e] is the product of empty lists with the same shape as [e]. *)
let rec template : type a b c. (a, b, c) t -> a = function
  | ENone -> []
  | ESome _ -> []
  | EPair (a, b) -> (template a, template b)

(* [partition ?tail fn l] classifies each element of [l] with [fn], whose
   result type [(b, z, z s) t] demands exactly one [some], and returns the
   product of lists [b] with every value put in its own list, in the order
   of [l].

   [tail] is the product the values are prepended to; when absent it is the
   [template] of the first classification result.

   [fn] is applied to the elements of [l] from first to last. The traversal
   is tail-recursive: the classification results are collected in reverse
   and then prepended one by one, so long inputs do not grow the stack.

   Raises [Failure "oups"] when [l] is empty and [tail] is not given, since
   the shape of the result cannot be recovered from zero elements. *)
let partition : type a b. ?tail:b -> (a -> (b, z, z s) t) -> a list -> b =
 fun ?tail fn l ->
  (* [push e acc] prepends the value carried by [e], if any, to the matching
     list of [acc]. *)
  let rec push : type e m n. (e, m, n) t -> e -> e =
   fun e acc ->
    match e with
    | ENone -> acc
    | ESome x -> x :: acc
    | EPair (l, r) ->
      let xs, ys = acc in
      (push l xs, push r ys)
  in
  match l with
  | [] -> (match tail with None -> failwith "oups" | Some base -> base)
  | x :: xs ->
    let y = fn x in
    let base = match tail with None -> template y | Some base -> base in
    (* [rev_map] calls [fn] in list order and yields the results last-first,
       which is the order in which they must be prepended. *)
    let rev_rest = List.rev_map fn xs in
    let rest = List.fold_left (fun acc e -> push e acc) base rev_rest in
    push y rest

This is a pretty standard use of peano numbers, but I like the following version which allows for at most one some (rather than exactly one), because it required me to think about “how to convince the typechecker that two identical expressions should have types that don’t unify”:

(* Wrappers used to build a type-level location: the two components of an
   [EPair] at location ['loc] sit at ['loc left] and ['loc right]. *)
type 'a left  = LEFT  of 'a
type 'a right = RIGHT of 'a

(* [(a, loc, r) t]: [a] is the product of lists being described, [loc] the
   type-level location of this node and [r] a polymorphic-variant row that
   records what the tree contains. *)
type (_, _, _) t =
  (* Empty component: only asks for the [`none] tag in the row. *)
  | ENone : ('a list,  'loc,  [> `none]) t
  (* A value: records its own location in the row as [`some of 'loc]. *)
  | ESome : 'a -> ('a list,  'loc,  [> `some of 'loc]) t
  (* Both components share the row ['r] but have distinct locations;
     presumably two [ESome]s then carry [`some] payloads that cannot unify,
     which gives the "at most one some" property described in the text. *)
  | EPair : ('a, 'loc left, 'r) t * ('b, 'loc right, 'r) t -> ('a * 'b, 'loc, 'r) t

(* in practice it requires a wrapper to hide the [loc]ation of the optional [ESome] *)
type 'a maybe = Maybe : ('a, _, _) t -> 'a maybe

(* same code as before, modulo type annotations *)

The sum-as-product presentation is fun, but I find it fishy that your function fails on the empty list. I think it is more reasonable to give the typeful arity to the function (or @smuenzel’s approach of returning a queriable datatype). Once the arity is given, it may be possible to simplify the indices of the sum-as-option-product datatype.

:smiley: Yeah I got lazy and didn’t want to define another “shape” GADT for the empty list case, but I missed that the type itself could be used for this purpose by restricting the product to only nones in the ~shape argument:

(* Variant of [partition] that receives the arity as a [~shape] argument.
   Its type [(b, z, z) t] has equal size indices, so with the Peano-indexed
   [t] it cannot contain an [ESome]; [template shape] then supplies the
   result for the empty list, which removes the failure case.
   The end of the body is elided in this snippet. *)
let rec partition
: type a b. shape:(b, z, z) t -> (a -> (b, z, z s) t) -> a list -> b
= fun ~shape fn -> function
    | [] -> template shape
    | x::xs ->
        let y = fn x in
        let ys = partition ~shape fn xs in
        (* same as before *)

Anyway I wouldn’t recommend this “sum as product” in practice :stuck_out_tongue:

Here’s my (GADT-free) solution:

(* An [('i, 'o) shape] is a "generalized list" of outputs ['o] fed with
   inputs ['i]: an empty value paired with a function that adds one input. *)
type ('i, 'o) shape = 'o * ('i -> 'o -> 'o)

(* The ordinary list: a single output list receiving every input. *)
let one : ('a, 'a list) shape = ([], List.cons)

(* [cons s] puts one more alternative in front of [s]: [Ok b] is added to a
   new leading list, while the payload of [Error] is handed over to [s]. *)
let cons ((nil, push) : ('a, 'o) shape) : (('b, 'a) result, 'b list * 'o) shape =
  ( ([], nil),
    fun x (bs, rest) ->
      match x with
      | Ok b -> (b :: bs, rest)
      | Error a -> (bs, push a rest) )

(* [partition_split shape f l] classifies every element of [l] with [f] and
   adds it to the output described by [shape]; each output list keeps the
   relative order of [l].
   Elements are processed from last to first, exactly as with
   [List.fold_right], but through [List.rev] and [List.fold_left] so that
   long inputs do not grow the stack. *)
let partition_split ((nil, push) : ('i, 'o) shape) f l =
  List.fold_left (fun acc y -> push (f y) acc) nil (List.rev l)

With this code @gasche’s test case:

let test =
  (* [cons (cons one)]: two [Ok] alternatives in front of a final list. *)
  let three_way = cons (cons one) in
  partition_split three_way classify
    [ "foo"; "42"; "true"; "bar"; "0"; "false" ]

produces the expected result:

val test : int list * (bool list * string list) =
  ([42; 0], ([true; false], ["foo"; "bar"]))
6 Likes

Here’s my solution with GADTs that almost matches @gasche’s code. The only difference is that I chose to use Either.t instead of Result.t.

(* Classify a string as an int, a bool or a plain string, encoded with
   nested [Either.t]: [Left n] for an integer, [Right (Left b)] for a
   boolean and [Right (Right str)] otherwise.
   The [_opt] parsers are used so that only a parse failure selects the next
   alternative; a catch-all handler would also swallow unrelated exceptions. *)
let classify str =
  match int_of_string_opt str with
  | Some n -> Either.Left n
  | None ->
    Either.Right
      (match bool_of_string_opt str with
       | Some b -> Either.Left b
       | None -> Either.Right str)

(* [(out, inp) shape] relates a nest of output lists [out] to the nested
   [Either.t] type [inp] that selects one of them:
   - [One]: a single list receiving every value;
   - [Cons s]: a leading list for [Left] values, with [Right] payloads
     dispatched according to [s]. *)
type (_, _) shape =
  | One : ('a list, 'a) shape
  | Cons : ('a, 'e) shape -> ('b list * 'a, ('b, 'e) Either.t) shape

(* [partition_split shape classify l] classifies every element of [l] and
   returns the nest of lists described by [shape], each list keeping the
   relative order of [l]. *)
let partition_split
  : type out inp elt. (out, inp) shape -> (elt -> inp) -> elt list -> out
  =
 fun shape classify items ->
  (* The nest of empty lists for a given shape. *)
  let rec empty : type a b. (a, b) shape -> a = function
    | One -> []
    | Cons rest -> ([], empty rest)
  in
  (* Prepend one classified value to the list it selects. *)
  let rec push : type a b. (a, b) shape -> b -> a -> a =
   fun shape item acc ->
    match shape with
    | One -> item :: acc
    | Cons rest ->
      let here, there = acc in
      (match item with
       | Either.Left x -> (x :: here, there)
       | Either.Right y -> (here, push rest y there))
  in
  (* Reverse every list of the nest, undoing the prepend order. *)
  let rec restore : type a b. (a, b) shape -> a -> a =
   fun shape acc ->
    match shape with
    | One -> List.rev acc
    | Cons rest ->
      let here, there = acc in
      (List.rev here, restore rest there)
  in
  items
  |> List.fold_left (fun acc item -> push shape (classify item) acc) (empty shape)
  |> restore shape

let test =
  (* Three-way split of the sample strings: ints, bools, other strings. *)
  [ "foo"; "42"; "true"; "bar"; "0"; "false" ]
  |> partition_split (Cons (Cons One)) classify

It was really fun to do! Is there a place where such riddles are collected?

@yallop’s code is beautiful, but I was a bit perplexed by it. It works, but how do you get the idea? (What’s a generic process to derive this sort of solutions?)

I played with the code and made everything more verbose to understand it better. I’m posting the result here in case it could help. The key idea is that the ('i, 'o) shape is a generalized list, a final encoding of lists if you want, or just a “type of list objects” in OO parlance. I made this explicit by using a heavy-handed record.

(* A generalized list: an output of type ['o] that can be empty ([nil]) or
   extended with one input of type ['i] ([cons]). *)
type ('i, 'o) glist = { nil : 'o; cons : 'i -> 'o -> 'o }

(* The ordinary list seen as a generalized list. *)
let list = { nil = []; cons = (fun x xs -> x :: xs) }

(* [succ gli] puts one more alternative in front of [gli]: [Ok b] is added
   to a new leading list, and the payload of [Error] is passed on to [gli]. *)
let succ (type i b o) (gli : (i, o) glist)
  : ((b, i) result, b list * o) glist
  =
  let nil = ([], gli.nil) in
  let cons x (bs, c) =
    match x with
    | Ok b -> (b :: bs, c)
    | Error i -> (bs, gli.cons i c)
  in
  { nil; cons }

Now the code of partition_split is clearer to me:

(* [partition_split gli f l] maps [f] over [l] and adds each result to the
   generalized list [gli], starting from [gli.nil].
   Elements are processed from last to first, exactly as with
   [List.fold_right], but through [List.rev] and [List.fold_left] so that
   long inputs do not grow the stack. *)
let partition_split gli f l =
  List.fold_left (fun acc y -> gli.cons (f y) acc) gli.nil (List.rev l)

This is just a “map” function over lists, except that the input list is a normal list but the result is a generalized list. You could build filter or filter_map in the same way, by defining a ('a option, 'a list) glist structure in the natural way. What other List functions fit this pattern?

@yallop followed my unary notation, but it is clearly clearer to use an explicit product of generalized lists:

(* [either gli1 gli2] is the product of two generalized lists: a [Left]
   input is added to [gli1]'s output, a [Right] input to [gli2]'s output,
   and the other output is returned unchanged. *)
let either (type i1 o1 i2 o2) (gli1 : (i1, o1) glist) (gli2 : (i2, o2) glist)
  : ((i1, i2) Either.t, o1 * o2) glist
  =
  let nil = (gli1.nil, gli2.nil) in
  let cons x (xs1, xs2) =
    match x with
    | Either.Left x1 -> (gli1.cons x1 xs1, xs2)
    | Either.Right x2 -> (xs1, gli2.cons x2 xs2)
  in
  { nil; cons }

(Instead of succ (succ list), you then write either list (either list list).)

1 Like

I originally saw this approach in

FUNCTIONAL PEARL: Do we need dependent types?
Daniel Fridlender and Mia Indrika
J. Functional Programming 10 (4): 409–415, July 2000

where it was used to define a generalized zipWith.

More generalized-list constructors:

val either : ('a, 'la) glist -> ('b, 'lb) glist ->
             (('a, 'b) Either.t, 'la * 'lb) glist
(* partition: each input goes to exactly one of the two generalized lists *)

val prod : ('a, 'la) glist -> ('b, 'lb) glist ->
           (('a * 'b), 'la * 'lb) glist
(* split: presumably each component of an input pair goes to its own
   generalized list (no implementation is shown) *)

val option : ('a, 'la) glist -> ('a option, 'la) glist
(* filter: presumably [None] inputs are dropped (no implementation is shown) *)

val list : ('a, 'la) glist -> ('a list, 'la) glist
(* concat: presumably every element of an input list is added
   (no implementation is shown) *)

val comap : ('a -> 'b) -> ('b, 'lb) glist -> ('a, 'lb) glist
(* changes the input type by applying a function first; this is what turns
   partition -> partition_map,
   filter -> filter_map,
   concat -> concat_map... *)

(I’m not sure how to do zip/combine; the functional pearl mentioned uses a different type schema for the combinators that is not obviously a special case of this glist definition.)

@yallop’s construction looks very similar to the one described in this old Fold page on mlton.org.

I don’t think that @yallop’s code gets much out of using the framework there.

But the general zip_with and taut problems mentioned in the paper he linked to are rather trivial to encode in this framework.

Given:

(* A fold state is a pair [(acc, finish)] of an accumulator and a finishing
   function. Steps receive the state followed by the next step. *)
module Fold = struct
  (* [fold state next] hands the state over to [next]. *)
  let fold ((_, _) as state) next = next state

  (* [stop] ends a fold by applying the finishing function to the
     accumulator. *)
  let stop (acc, finish) = finish acc

  (* [step0 h] is a step that updates the accumulator with [h] and then
     carries on with whatever step follows. *)
  let step0 h = fun (acc, finish) -> fold (h acc, finish)
end

The zip_with problem is:

(* [zip_with f z] starts a fold whose accumulator is an endless sequence of
   [f] and whose finishing function is [List.of_seq], then hands it to [z]. *)
let zip_with f z =
  let initial = (Seq.repeat f, List.of_seq) in
  Fold.fold initial z

(* [args vs] is a fold step that applies each function of the accumulated
   sequence to the element of [vs] at the same position. *)
let args vs z =
  let apply_all fs = Seq.map2 ( @@ ) fs (List.to_seq vs) in
  Fold.step0 apply_all z

(* Test *)

(* [labeled_sum x y label] pairs the sum of [x] and [y] with [label]. *)
let labeled_sum x y label =
  let total = x + y in
  (total, label)
let test =
  (* One [args] step per argument list of [labeled_sum]. *)
  let xs = args [ 0; 1; 2 ] in
  let ys = args [ 1; 2; 3 ] in
  let labels = args [ "a"; "b"; "c" ] in
  zip_with labeled_sum xs ys labels Fold.stop

test is:

val test : (int * string) list = [(1, "a"); (3, "b"); (5, "c")]

and the taut problem is:

(* [taut z] starts a fold whose accumulator and finishing function are both
   the identity, then hands it to [z]. *)
let taut z =
  let initial = (Fun.id, Fun.id) in
  Fold.fold initial z

(* [arg] is a fold step for one boolean argument: the new accumulator checks
   the previous one on both [f true] and [f false], in that order, stopping
   at the first failure. *)
let arg z =
  let on_both_values k f = List.for_all (fun b -> k (f b)) [ true; false ] in
  Fold.step0 on_both_values z

(* Tautology checkers for predicates of one and two boolean arguments:
   one [arg] step per argument. *)
let taut1 = Fold.stop |> taut arg
let taut2 = Fold.stop |> taut arg arg

(* Test *)

(* Sample predicates fed to the tautology checkers. *)

(* The identity. *)
let p0 x = x

(* Excluded middle on one argument. *)
let p1 x = if x then true else not x

(* Excluded middle on the conjunction of two arguments. *)
let p2 x y =
  let both = x && y in
  both || not both

(* Plain conjunction. *)
let p3 x y = if x then y else false

(* Run each sample predicate through the checker of matching arity. *)
let p0_is_taut = p0 |> taut1
let p1_is_taut = p1 |> taut1
let p2_is_taut = p2 |> taut2
let p3_is_taut = p3 |> taut2

with results:

val p0_is_taut : bool = false
val p1_is_taut : bool = true
val p2_is_taut : bool = true
val p3_is_taut : bool = false