Best approach for implementing open recursion over extensible types

I have a datatype, designed to be user-extensible, with particular constructors defined in different sub-modules:

(* Open, user-extensible expression type.  The index is the type of the
   value an expression denotes (e.g. [Int] builds an [int expr]); other
   modules contribute constructors with [type _ expr += ...]. *)
type _ expr = ..

(* Core arithmetic constructors.
   NOTE(review): [Add]/[Sub] accept operands of any index, so e.g. adding two
   [bool expr]s type-checks; consider [int expr] if that is unintended. *)
type _ expr +=
  | Int : int -> int expr
  | Add : 'a expr * 'a expr -> 'a expr
  | Sub : 'a expr * 'a expr -> 'a expr

(* Boolean literals and connectives, added to [expr] from their own module. *)
module Bool = struct
  type _ expr +=
    | Bool : bool -> bool expr
    | And : bool expr * bool expr -> bool expr
    | Or : bool expr * bool expr -> bool expr
end

(* Integer comparison, plus [UNKNOWN], a constant that inhabits every index
   type ([UNKNOWN : 'a expr]). *)
module Arith = struct
  type _ expr +=
    | Lt : int expr * int expr -> bool expr
    | UNKNOWN : 'a expr
end

I’d like to define various functions that recurse over this data-structure, but should also incorporate some kind of “open” recursion, so that users can add new constructors to the type and add suitable behaviours.

As an example, let’s consider a function to pretty print expressions: pp_expr: Format.formatter -> 'a expr -> unit.

What I’ve come up with is the following, using mutable references to tie the knot:

(** A printer layer.  [pp_expr fmt e] should print [e] and return [true] when
    it recognises [e]'s constructor, or return [false] without printing so
    the next registered layer can be tried. *)
type printer = { pp_expr : 'a. Format.formatter -> 'a expr -> bool }

(** Registered printer layers, most recently added first. *)
let pp_fns : printer list ref = ref []

(** [pp_expr fmt e] prints [e] with the first registered layer that accepts
    it.  Layers print sub-expressions by calling [pp_expr] again, which is
    what ties the open-recursion knot.
    @raise Failure if no registered layer accepts [e]. *)
let pp_expr : 'a. Format.formatter -> 'a expr -> unit =
 fun fmt expr ->
  (* We only care whether some layer succeeded, not which one, and
     [List.exists] stops at the first success just as [find_opt] did. *)
  if not (List.exists (fun layer -> layer.pp_expr fmt expr) !pp_fns) then
    failwith "could not resolve pretty printer"

(** [add_printer p] registers layer [p]; later registrations take priority
    because they are consulted first. *)
let add_printer p = pp_fns := p :: !pp_fns

An example usage:

(* Registers the printer layer for [Bool], [And] and [Or].  Anything else is
   answered with [false] so [pp_expr] moves on to the next layer; operands
   are printed through [pp_expr] so any registered layer can handle them. *)
let () =
  let pp_bool (type a) fmt (expr : a expr) : bool =
    match expr with
    | Bool b ->
      Format.fprintf fmt "%b" b;
      true
    | And (lhs, rhs) ->
      Format.fprintf fmt "(%a) && (%a)" pp_expr lhs pp_expr rhs;
      true
    | Or (lhs, rhs) ->
      Format.fprintf fmt "(%a) || (%a)" pp_expr lhs pp_expr rhs;
      true
    | _ -> false
  in
  add_printer { pp_expr = pp_bool }

This works well enough:

(* Demo of the registry-based printer: each line is printed by [pp_expr],
   which dispatches to whichever registered layer accepts the constructor. *)
let () =
  let ten = Int 10 in
  let sum = Add (ten, ten) in
  let less = Arith.Lt (ten, sum) in
  let both = Bool.And (less, less) in
  let either = Bool.Or (both, Arith.UNKNOWN) in
  Format.printf "x is %a\n%!" pp_expr ten;
  Format.printf "y is %a\n%!" pp_expr sum;
  Format.printf "z is %a\n%!" pp_expr less;
  Format.printf "a is %a\n%!" pp_expr both;
  (* No registered layer accepts [Arith.UNKNOWN], so this final call fails
     part-way through the output (see the transcript below). *)
  Format.printf "b is %a\n%!" pp_expr either

(* ==>
x is 10
y is 10 + 10
z is 10 <= 10 + 10
a is (10 <= 10 + 10) && (10 <= 10 + 10)
b is ((10 <= 10 + 10) && (10 <= 10 + 10)) || (Fatal error: exception Failure("could not resolve pretty printer")

NOTE(review): [Lt] is rendered as "<="; the Arith printer (not shown here)
may be emitting the wrong operator.
*)

Is this the best way of implementing this? What other implementations would you suggest?

2 Likes

Of course, my initial guess was to try the standard approach to open recursion, using a record of functions:

(** Record-of-functions encoding of open recursion: the [expr] field receives
    the iterator itself as its first argument (the "self" parameter). *)
type iterator = { expr : 'a. iterator -> 'a expr -> unit }

but one thing I couldn’t work out was how to compose them – I’d like to define the iterators separately per sub-module, e.g., so that the code for handling all the constructors in the Arith module occurs within the Arith module itself.

Funny you should suggest that. I’ve just come up with something somewhat similar:

(** A composable pretty-printer: [pp_expr fmt e] prints [e] on [fmt] for an
    expression of any index type. *)
class type printer =
  object
    method pp_expr : 'a. Format.formatter -> 'a expr -> unit
  end

Defining printers is quite straightforward:

(** [printer next] is the printer layer for [Int], [Add] and [Sub].
    Operands are printed through this layer's own [pp_expr], so constructors
    it does not know fall through to [next]; any such constructor at the top
    level is likewise handed to [next] unchanged. *)
let printer (next : printer) : printer =
  object (this)
    method pp_expr : 'a. Format.formatter -> 'a expr -> unit =
      fun (type a) (fmt : Format.formatter) (e : a expr) : unit ->
        match e with
        | Int i -> Format.fprintf fmt "%d" i
        | Add (lhs, rhs) ->
          Format.fprintf fmt "%a + %a" this#pp_expr lhs this#pp_expr rhs
        | Sub (lhs, rhs) ->
          Format.fprintf fmt "%a - %a" this#pp_expr lhs this#pp_expr rhs
        | other -> next#pp_expr fmt other
  end

Finally, we need a compose function to tie things together:

(** [build_printer layers] composes printer layers into a single printer.

    Each element of [layers] takes the printer built so far (its fallback)
    and returns a printer that handles its own constructors, delegating the
    rest to that fallback.  Layers are applied left to right with
    [List.fold_left], so the last element of [layers] is consulted first.

    The innermost fallback re-dispatches to the fully composed printer.  This
    is what lets a layer's [self#pp_expr] print sub-expressions that only
    another layer understands.  Without a guard, an expression that no layer
    handles would cycle through this knot forever.  Instead, the fallback
    notices when it is handed back the very value it is currently
    re-dispatching, and fails.

    @raise Failure ["could not resolve pretty printer"] from the result's
    [pp_expr] when an expression or sub-expression has no printer. *)
let build_printer (printers : (printer -> printer) list) : printer =
  let base_printer =
    object
      (* The fully composed printer; set once, right after composition. *)
      val mutable printer : printer option = None

      (* Expression this fallback is currently re-dispatching, if any.  It is
         stored as [Obj.t] only so that physical equality can compare
         expressions of different index types.  It is never inspected or
         cast back. *)
      val mutable in_flight : Obj.t option = None

      method set_printer p = printer <- Some p

      method pp_expr : 'a. Format.formatter -> 'a expr -> unit =
        fun fmt expr ->
          match printer with
          | None -> assert false
          | Some p ->
            let key = Obj.repr expr in
            (match in_flight with
             | Some k when k == key ->
               (* Every layer passed [expr] down unchanged: nobody prints it.
                  A legitimate sub-expression is never physically equal to
                  the expression it is part of, so this is not a false alarm
                  for acyclic values. *)
               failwith "could not resolve pretty printer"
             | _ -> ());
            let saved = in_flight in
            in_flight <- Some key;
            (* Restore the previous marker on both normal and exceptional
               exit so the printer stays usable after a failure. *)
            (match p#pp_expr fmt expr with
             | () -> in_flight <- saved
             | exception exn ->
               in_flight <- saved;
               raise exn)
    end
  in
  let printer =
    List.fold_left
      (fun inner make_printer -> make_printer inner)
      (base_printer :> printer)
      printers
  in
  base_printer#set_printer printer;
  (base_printer :> printer)

And finally, the usage:

(* Demo of the object-based printer.  The layers are composed once by
   [build_printer]; the last layer in the list is consulted first. *)
let () =
  let p = build_printer [ Arith.printer; Bool.printer; printer ] in
  let ten = Int 10 in
  let sum = Add (ten, ten) in
  let less = Arith.Lt (ten, sum) in
  let both = Bool.And (less, less) in
  let either = Bool.Or (both, Arith.UNKNOWN) in
  Format.printf "x is %a\n%!" p#pp_expr ten;
  Format.printf "y is %a\n%!" p#pp_expr sum;
  Format.printf "z is %a\n%!" p#pp_expr less;
  Format.printf "a is %a\n%!" p#pp_expr both;
  Format.printf "b is %a\n%!" p#pp_expr either

I don’t think I’m really using any object-specific features in the code, so it could probably be replaced with a record.

The nice thing about this implementation is that the code for the printers can be written in an idiomatic way and doesn’t need to think about the failure case (no need to wrap things in optional values).

The problem with this design is that if no case can handle the constructor, then it loops.

Apologies for the self-promotion and repetitiveness, but here are my slides on the expression problem. However, I wrote them before GADTs were introduced.

Nice slides!

I think I recall Jacques Garrigue also has a few papers/slides on solving the expression problem in OCaml: https://www.cs.ox.ac.uk/ralf.hinze/WG2.8/22/slides/jacques.pdf, and https://www.math.nagoya-u.ac.jp/~garrigue/papers/variant-reuse.pdf were the ones I remember reading most recently.

These are pretty nice solutions – I especially like using polymorphic variants for extensible syntax, but the problem is that I need GADTs + extensible types, which seems to be a distinct element of the problem space that isn’t quite so easily tackled with the approaches suggested above.

Curious to hear what you think.

1 Like

I don’t know if what you’ve chosen is the best way of doing this, but the standard PPX rewriters (e.g. show) generate code that follows pretty much this pattern (at least, that’s my memory).

For what it’s worth, your approach seems pretty similar to how adding printers for custom exceptions works, i.e. if I define some new exception e, then I can “register” a printer for e using Printexc.register_printer. Exceptions use extensible variants internally, so maybe something in the Printexc module would be helpful?