GADT to encode function arity

Tentative minimal example

Is there a way to create a GADT to encode a function return type and arity, leaving the function argument type as a parameter? I would like to create a type that encodes

  • Equal has two arguments and returns bool
  • Hash has a single argument and returns int

And then use that type to create some call functions with fixed argument type. For example, a call_t function such that call_t Equal has type t -> t -> bool, call_t Hash has type t -> int and a call_u function which is the same but with t replaced by u.

Is such a thing expressible with OCaml GADTs?

Big picture

I’m trying to build a dynamic record type, i.e. a sort of record where the full set of fields is only known at runtime, with new fields being added dynamically via function calls. I would like to define operations on the whole record (compare, equal, hash, clone…) at the field level (so whenever a field is created, it must specify how to implement these operations). Critically, the set of operations is not fixed (some records may have none, some only have equal/compare, some have even more operations…).

The solution I found is to have the user specify a GADT 'a operand when creating the record, along with a merger function combine: 'a operand -> 'a -> 'a -> 'a. Each field creation then also specifies a function operand: 'a operand -> field^n -> 'a. This can be lifted into a record-wide operand: 'a operand -> record^n -> 'a. The problem is how to encode the ^n… Right now all I can think of is to have one version per desired arity (so one GADT/combine function/operand function for unary operands (hash, clone), another one for binary operands (equal, compare)…).
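
To make the duplication concrete, here is a rough sketch of the per-arity version for a two-field "record" of ints. All names (unary, binary, combine_unary, pair_binary, and so on) are hypothetical and only illustrate the shape of the workaround, not an actual implementation:

```ocaml
(* Hypothetical sketch: one GADT per arity, indexed by the result type. *)
type _ unary = Hash : int unary

type _ binary =
  | Equal : bool binary
  | Compare : int binary

(* Merge the results of two fields, one function per arity. *)
let combine_unary : type a. a unary -> a -> a -> a =
  fun op x y -> match op with Hash -> (31 * x) + y

let combine_binary : type a. a binary -> a -> a -> a =
  fun op x y ->
    match op with
    | Equal -> x && y
    | Compare -> if x <> 0 then x else y

(* Field-level operations for an int field, again one per arity. *)
let int_unary : type a. a unary -> int -> a =
  fun op x -> match op with Hash -> Hashtbl.hash x

let int_binary : type a. a binary -> int -> int -> a =
  fun op x y ->
    match op with
    | Equal -> Int.equal x y
    | Compare -> Int.compare x y

(* Lifting to the whole "record" has to be written once per arity. *)
let pair_unary op (x1, x2) =
  combine_unary op (int_unary op x1) (int_unary op x2)

let pair_binary op (x1, x2) (y1, y2) =
  combine_binary op (int_binary op x1 y1) (int_binary op x2 y2)
```

Everything from the GADT down to the lifting function exists twice, and a ternary operation would need a third copy.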

Is there a way to encode that arity directly in the GADT, avoiding the need for this duplication?

If you have only a limited number of argument types, you don’t need GADTs; normal irregular ADTs suffice:

With GADTs, you can define:

type ('elt, 'fn) op = ..
type 'elt implementation = { f: 'fn. ('elt,'fn) op -> 'fn option }

type (_,_) op += 
| Hash: ('elt, 'elt -> int) op
| Eq: ('elt,'elt -> 'elt -> bool) op

let int_f (type f r) (op: (int,f) op) : f option = match op with
| Hash -> Some Fun.id
| Eq -> Some Int.equal
| _ -> None

let int = { f = int_f }
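
As a usage sketch, the definitions above are repeated here so the block stands alone, followed by a lookup of a supported and an unsupported operation. The Compare constructor and the names eq_result and cmp_lookup are hypothetical additions for illustration only:

```ocaml
type ('elt, 'fn) op = ..
type 'elt implementation = { f : 'fn. ('elt, 'fn) op -> 'fn option }

type (_, _) op +=
  | Hash : ('elt, 'elt -> int) op
  | Eq : ('elt, 'elt -> 'elt -> bool) op

let int_f (type f) (op : (int, f) op) : f option =
  match op with
  | Hash -> Some Fun.id
  | Eq -> Some Int.equal
  | _ -> None

let int = { f = int_f }

(* Hypothetical operation added after [int] was defined. *)
type (_, _) op += Compare : ('elt, 'elt -> 'elt -> int) op

(* A supported operation comes back with the right arity and types. *)
let eq_result : bool option = Option.map (fun eq -> eq 1 2) (int.f Eq)

(* An operation the implementation does not know falls through to None. *)
let cmp_lookup : (int -> int -> int) option = int.f Compare
```

The option result is what lets the set of operations stay open: an implementation only answers for the constructors it knows about.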

This works if I use the function individually. The problem is that the record function calls the field function, so it needs to cast (record, f) op into (field, f with record replaced by field) op, I don’t think there is a way to type this cast properly in OCaml sadly.

You could use an existential wrapper:

type 'elt wrapper = Wrap: ('elt, 'fn) op -> 'elt wrapper

let cast: type elt_a elt_b fn. (elt_a, fn) op -> elt_b wrapper = function
  | Hash -> Wrap Hash
  | Eq -> Wrap Eq

However, this doesn’t solve my problem: in order to use this cast function, you must match on the GADT to recover the type information. But I would like my GADT to be user-supplied on record creation as a functor parameter (so the user can create different records with different operation sets). This means I can’t match on the GADT in the implementation:

(** Example of defining the record-wide op from the field operation,
    using a simple tuple to mimic a two field record. 
    Crucially, this can't match on the [op] GADT since it is a functor parameter... *)
let op_f (type f r) (op: (int*int,r,f) op) : f =
    let Wrap op_int = cast op in
    int_f op_int ...

I guess you could also have the user supply a function mapping the GADT to another, known GADT which only encodes arity. But by then we’ve reached the point where the single function solution is more complex than having two separate GADTs for unary and binary operations.

you must match on the GADT to recover the type information

Unfortunately, that is how GADTs work. Once you take a GADT as an abstract type argument to a functor, it no longer matters that it is a GADT – from the point of view of the functor, it is just an abstract type constructor.

it needs to cast (record, f) op into (field, f with record replaced by field) op, I don’t think there is a way to type this cast properly in OCaml sadly.

You can (if you allow yourself pattern-matching, of course). The trick is to duplicate all your GADT parameters (and then a bit of quantification). Complexity and performance are both likely to be worse than using different GADTs for different arities though, so if you only care about unary and binary functions, you should use that solution.

type (_, _, _, _, _) arity =
  | Z : ('r, 'v, 'f, 'v, 'v) arity
  | S : ('r, 'mr, 'f, 'mf, 'v) arity -> ('r, 'r -> 'mr, 'f, 'f -> 'mf, 'v) arity

let rec project : type r mr f mf v. (r -> f) -> (r, mr, f, mf, v) arity -> mf -> mr =
  fun prj arity mf ->
    match arity with
    | Z -> mf
    | S arity -> fun r -> project prj arity (mf (prj r))

type (_, _, _, _, _) op =
  | Equal : ('a, 'a -> 'a -> bool, 'b, 'b -> 'b -> bool, bool) op
  | Hash : ('a, 'a -> int, 'b, 'b -> int, int) op

let arity : type r f mr mf v. (r, mr, f, mf, v) op -> (r, mr, f, mf, v) arity =
  function
  | Equal -> S (S Z)
  | Hash -> S Z

type ('r, 'mr, 'v, 'f) op0 = Op0 : ('r, 'mr, 'f, 'mf, 'v) op -> ('r, 'mr, 'v, 'f) op0

type ('r, 'mr, 'v) op1 = { op0 : 'f. ('r, 'mr, 'v, 'f) op0 }

type 'f field = Field : { call : 'r 'mr 'mf 'v. ('r, 'mr, 'f, 'mf, 'v) op -> 'mf } -> 'f field

let call : type r f mr v. (r -> f) -> f field -> (r, mr, v) op1 -> mr =
  fun get (Field { call }) { op0 = Op0 op } ->
    project get (arity op) (call op)

let rec combine : type r mr f mf v. (v list -> v) -> (r, mr, f, mf, v) arity -> mr list -> mr =
  fun combine_v arity mrs ->
    match arity with
    | Z ->
      combine_v mrs
    | S arity ->
      fun r -> combine combine_v arity (List.map (fun mr -> mr r) mrs)

type 'r rfield = Rfield : ('r -> 'f) * 'f field -> 'r rfield

let combine_call : type r mr v. (v list -> v) -> (r, mr, v) op1 -> r rfield list -> mr =
  fun combine_v ({ op0 = Op0 op } as op1) fields ->
    combine combine_v (arity op)
      (List.map (fun (Rfield (get, field)) -> call get field op1) fields)

let op_equal : 'r. ('r, 'r -> 'r -> bool, bool) op1 = { op0 = Op0 Equal }

let call_int : type r mr mf v. (r, mr, int, mf, v) op -> mf =
  function
  | Equal -> Int.equal
  | Hash -> (Hashtbl.hash : int -> int)

let int_field : int field = Field { call = call_int }

let equal_int_pair : (int * int) -> (int * int) -> bool =
  combine_call (List.for_all (fun b -> b)) op_equal [
    (Rfield (fst, int_field));
    (Rfield (snd, int_field))
  ]
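
To see the arity encoding in isolation, here is a small sketch that repeats only the arity type and project from above, then lifts a binary function on a field to a binary function on a pair. The name equal_fst is hypothetical:

```ocaml
(* Parameters: record type, function over records, field type,
   function over fields, and the final result type. *)
type (_, _, _, _, _) arity =
  | Z : ('r, 'v, 'f, 'v, 'v) arity
  | S : ('r, 'mr, 'f, 'mf, 'v) arity -> ('r, 'r -> 'mr, 'f, 'f -> 'mf, 'v) arity

(* Turn a function over fields into one over records, projecting each
   argument with [prj] as it arrives. *)
let rec project : type r mr f mf v. (r -> f) -> (r, mr, f, mf, v) arity -> mf -> mr =
  fun prj arity mf ->
    match arity with
    | Z -> mf
    | S arity -> fun r -> project prj arity (mf (prj r))

(* [S (S Z)] is arity two: compare two pairs by their first component. *)
let equal_fst : (int * string) -> (int * string) -> bool =
  project fst (S (S Z)) Int.equal
```

Each S adds one 'r argument on the record side and one 'f argument on the field side in lockstep, which is why the same arity value can describe both function types at once.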

Yes, this is what I was trying to achieve! I agree it’s probably too complex to use in practice, but it’s nice to know it’s possible.