Telling the type checker that two GADTs are equal

I’ve been working on a side-project in OCaml that binds to the
libmnl library to interact
with netlink sockets to add/remove/update network links. I have been
taking my time and experimenting to learn more about OCaml rather than
just sticking to what I’m comfortable with.

Netlink messages have several trailing “attributes” of the form
len[2] tag[2] payload[len], of various shapes and sizes. For
example:

IFLA_ADDRESS /* 6-byte ethernet address */
IFLA_IFNAME  /* string, e.g. "eth0" */
IFLA_MTU     /* 32-bit unsigned integer (host endian) */

My initial thought was to model these with variants, so I would
be able to use them in pattern matching:

type nlattr =
 | IFLA_IFNAME of string
 | IFLA_MTU of int32
 | IFLA_ADDRESS of bytes

but I didn’t like that the payload was always attached to the tag.
I couldn’t do something like

let ifname = nlattr_find_opt IFLA_IFNAME nlmsg in

I could make the constructors take 0 arguments and group them by
their underlying type. This is what the Unix module does with the
socket_*_option types:

type nlattr_str =
 | IFLA_IFNAME
 | IFLA_IFALIAS

type nlattr_u32 =
 | IFLA_MTU
 | IFLA_LINK_NSID

val nlattr_get_u32 : nlmsg -> nlattr_u32 -> int32 option
val nlattr_get_str : nlmsg -> nlattr_str -> string option
(* and so on *)

But there are dozens of varying payloads, and some of them are only used
for one or two attributes. I tried modelling this using a GADT instead:

type _ nlattr =
 | IFLA_IFNAME: string nlattr
 | IFLA_MTU: int32 nlattr
 | IFLA_ADDRESS: bytes nlattr

After some contortions, I was able to implement an iterator with the
following signature:

type 'a nlattr_fn = { f: 't. 'a -> 't nlattr_type -> 't -> 'a } [@@unboxed]
val nlattr_walk: 'a nlattr_fn -> 'a -> nlmsghdr -> 'a

This function got pretty hairy because attributes can be nested.
Anyway, I wanted to implement the nlattr_find_opt function above,
which gets the payload for one attribute with the given tag, using
the more general nlattr_walk. Here was my initial attempt:

let nlattr_find_opt (type a) nlh (attr: a nlattr_type) =
 let f (type b) (init: a option) (k: b nlattr_type) (v: b) =
   if k = attr
   then Some v else init
 in
 nlattr_walk {f} None nlh

This failed to compile with the error:

727 |     if k = attr
                ^^^^
Error: This expression has type a nlattr_type
      but an expression was expected of type b nlattr_type
      Type a is not compatible with type b
(exit status 1)

After some research, I eventually was able to satisfy the type checker
with a witness:

type (_, _) eq = Eq : ('a, 'a) eq
let nlattr_eq: type a b. a nlattr_type -> b nlattr_type -> (a, b) eq option =
 fun a b ->
   match a, b with
   | IFLA_ADDRESS, IFLA_ADDRESS -> Some Eq
   (* ... snip ... *)
   | _ -> None

let nlattr_find_opt: type a. Nlmsghdr.t ptr -> a nlattr_type -> a option =
 fun nlh attr ->
 let f: type b. a option -> b nlattr_type -> b -> a option =
   fun init k v ->
     match nlattr_eq attr k with
     | Some Eq -> Some v
     | None -> init
 in
 nlattr_walk {f} None nlh
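
For anyone who wants to try this without libmnl, here is a minimal self-contained sketch of the same technique. The packed type and the list-based nlattr_walk are hypothetical stand-ins for the real message and iterator (the real one walks an nlmsghdr), and the tag list is cut down to three constructors:

```ocaml
(* Witness that two types are the same. *)
type (_, _) eq = Eq : ('a, 'a) eq

(* Tags indexed by their payload type. *)
type _ nlattr_type =
  | IFLA_IFNAME : string nlattr_type
  | IFLA_IFALIAS : string nlattr_type
  | IFLA_MTU : int32 nlattr_type

(* Hypothetical stand-in for a parsed message: each tag packed with its payload. *)
type packed = Attr : 'a nlattr_type * 'a -> packed

type 'a nlattr_fn = { f : 't. 'a -> 't nlattr_type -> 't -> 'a }

(* Hypothetical stand-in for nlattr_walk: a left fold over the packed list. *)
let nlattr_walk (fn : 'a nlattr_fn) (init : 'a) (attrs : packed list) : 'a =
  List.fold_left (fun acc (Attr (k, v)) -> fn.f acc k v) init attrs

(* Each successful branch learns a = b from the two constructors matched. *)
let nlattr_eq : type a b. a nlattr_type -> b nlattr_type -> (a, b) eq option =
 fun a b ->
  match a, b with
  | IFLA_IFNAME, IFLA_IFNAME -> Some Eq
  | IFLA_IFALIAS, IFLA_IFALIAS -> Some Eq
  | IFLA_MTU, IFLA_MTU -> Some Eq
  | _ -> None

let nlattr_find_opt : type a. packed list -> a nlattr_type -> a option =
 fun attrs attr ->
  let f : type b. a option -> b nlattr_type -> b -> a option =
   fun init k v ->
    match nlattr_eq attr k with
    | Some Eq -> Some v (* a = b is known here, so v : b is also an a *)
    | None -> init
  in
  nlattr_walk { f } None attrs
```

Because this is a fold, it returns the payload of the last attribute carrying the requested tag if the tag appears more than once.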

I understand that I need to prove to the type checker that the two
types are the same, but I haven’t quite wrapped my head around how the
nlattr_eq function proves that. I did some experimentation to try and
figure out what was allowed and what wasn’t. For example, if I change
this case:

match a, b with
...
| IFLA_IFALIAS, IFLA_IFALIAS -> Some Eq (* string, string *)

to something nonsensical like

| IFLA_IFALIAS, IFLA_NUM_TX_QUEUES -> Some Eq (* string, int32 *)

I get the error

Error: This expression has type (a, a) eq
      but an expression was expected of type (a, b) eq
      Type a = string is not compatible with type b = int32
      Type string is not compatible with type b = int32

but if I change it to this:

| IFLA_IFALIAS, IFLA_IFNAME -> Some Eq (* string, string *)

the program compiles, but as expected those two tags are treated as
equivalent in nlattr_find_opt. I can now see how I’d be able to use
this technique to simplify my program; for example I wouldn’t have
to enumerate every variant in my printing function; I can handle all
variants with the same payload at once.
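
One way this could look, as a rough sketch rather than what my program actually does: map each tag to a small GADT describing its payload shape, then match on the shape instead of the tag. The names payload, payload_of and pp_payload are made up for this example. The tags still have to be enumerated once in payload_of, but every later function only needs one case per payload type:

```ocaml
type _ nlattr_type =
  | IFLA_IFNAME : string nlattr_type
  | IFLA_IFALIAS : string nlattr_type
  | IFLA_MTU : int32 nlattr_type
  | IFLA_NUM_TX_QUEUES : int32 nlattr_type

(* Hypothetical helper type: one constructor per payload shape. *)
type _ payload =
  | Str : string payload
  | U32 : int32 payload

(* The only place where every tag is listed. *)
let payload_of : type a. a nlattr_type -> a payload = function
  | IFLA_IFNAME -> Str
  | IFLA_IFALIAS -> Str
  | IFLA_MTU -> U32
  | IFLA_NUM_TX_QUEUES -> U32

(* One case per payload type, shared by all tags with that payload. *)
let pp_payload : type a. a nlattr_type -> a -> string =
 fun tag v ->
  match payload_of tag with
  | Str -> Printf.sprintf "%S" v
  | U32 -> Int32.to_string v
```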

I started this post before I got a working solution, and the process
of drafting the post helped me figure things out myself (funny how that
works!). But I’m posting anyway in case someone else is struggling with
GADTs, and also because I still have the lingering sensation that I’ve
performed some magical conjuring that I don’t fully understand.

Taking a simplified example

type (_, _) eq = Eq : ('a, 'a) eq

let nlattr_eq: type a b. a nlattr_type -> b nlattr_type -> (a, b) eq option =
 fun a b ->
 match a, b with
 | IFLA_ADDRESS, IFLA_ADDRESS -> Some Eq

The thing that still confuses me is the constructor Eq. Even though
I’m not providing any arguments, I somehow constructed a value of type
(a, b) eq. How did the type checker know that’s what I was doing?
In the pattern match

match a, b with

Is (a, b) not a tuple, but something else? Or is all the “magic” just
occurring in the function’s signature:

val nlattr_eq: type a b. a nlattr_type -> b nlattr_type -> (a, b) eq option

So by emitting Eq when a ≠ b, I’m violating the type signature? That
makes more sense, but if that is the case, is there some more compact
way to show the type checker that two variants are the same, closer to
my initial attempt of just comparing them with =? Something that does
not involve enumerating every possible variant?

I think there’s confusion with the overloaded (a, b) syntax. In a type expression such as (a, b) eq, this is not a tuple but the list of arguments to the two-parameter type constructor eq. (A tuple type would be spelled a * b; I imagine ML uses this different syntax to make the distinction more obvious, but I don’t know the history.)

In any case, the “magic” is entirely in the (_, _) eq type, whose only constructor requires the two parameters to be unified. This requirement is forced on the return type of nlattr_eq, as you see above. If you’re looking for stronger type safety, one common pattern involving GADTs would be to mint new (niladic) types for the different GADT constructors.
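
To make that concrete, here is a small sketch of what a value of the eq type buys you once you have it. The cast function is just an illustrative name; matching on the single constructor is what brings the equation between the two parameters into scope:

```ocaml
type (_, _) eq = Eq : ('a, 'a) eq

(* Matching on Eq tells the type checker that a = b inside the branch,
   so a value of type a can be returned where a b is expected. *)
let cast : type a b. (a, b) eq -> a -> b =
 fun witness x ->
  match witness with
  | Eq -> x

(* The only witnesses we can build relate a type to itself. *)
let refl_int : (int, int) eq = Eq
let n : int = cast refl_int 42
```

There is no way to build, say, an (int, string) eq with this definition, which is why holding a witness is safe evidence that the two types coincide.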

So, for example, instead of the first nlattr constructor being a string nlattr, it would be something like an ifname nlattr. (Note that if you never actually instantiate the ifname type, you can often get away with just declaring the type and not giving any constructors.)
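
A minimal sketch of that idea, with made-up names (ifname, ifalias, tag, tag_eq) and only two tags. It deliberately leaves out how you would recover the payload type from the tag, which would need a second type index or a separate mapping:

```ocaml
type (_, _) eq = Eq : ('a, 'a) eq

(* Index types that are declared but never instantiated. *)
type ifname
type ifalias

type _ tag =
  | Ifname : ifname tag
  | Ifalias : ifalias tag

let tag_eq : type a b. a tag -> b tag -> (a, b) eq option =
 fun a b ->
  match a, b with
  | Ifname, Ifname -> Some Eq
  | Ifalias, Ifalias -> Some Eq
  | Ifname, Ifalias -> None
  | Ifalias, Ifname -> None
```

With this indexing, returning Some Eq in one of the mixed branches no longer type-checks, because ifname and ifalias are distinct types even though both tags carry a string on the wire.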

I have to admit that I didn’t quite follow everything you were doing above and don’t have the context of what you’re “actually” trying to achieve, so it’s not obvious to me whether this GADT representation is actually helpful. It does let you enforce at the type level when two variants are the same, but I’m not sure what that buys you in this case. In your specific case, what do you gain by factoring the tag type separately from the payload itself? Are there cases where you might operate on the same underlying payload type in the “same” way even though the tags are mismatched? (I’m asking because depending on the answer, there may be different ways to design your types.)

I’m not sure what “factoring” means in this context. If you’re asking why I didn’t want to make the payload part of the type constructor, the main reason was because I wanted to be able to use the same name (IFLA_IFNAME, etc) for both get and set functions, but it doesn’t make sense for the get function to include a payload.

Yes, not in the example I posted (of fetching a specific tag), but certainly for pretty-printing, I could use the same printer for all variants that have the same underlying type. Also, I could use it for calculating how much storage I would need to encode a given message.

First, a, b in

match a, b with

is exactly the same thing as (a, b); the parentheses are not needed when the context is unambiguous. For instance:

let number_names = [ 0, "zero"; 1, "one"; 2, "two" ]

Then taking a simpler type as example

type _ typ = 
  | Int: int typ
  | Int8: int typ
  | Float: float typ

when you write

let eq: type a b. a typ -> b typ -> (a,b) Type.eq option = fun x y ->
match x, y with
| Int, Int -> Some Type.Equal
| Int8, Int8 -> Some Type.Equal
| Float, Float -> Some Type.Equal
| _ -> None

what happens is first the compiler adds an annotation to the expression side. In other words, the above can be read as

let eq: type a b. a typ -> b typ -> (a,b) Type.eq option = fun x y ->
match x, y with
| Int, Int -> (Some Type.Equal: (a,b) Type.eq option)
| Int8, Int8 -> (Some Type.Equal: (a,b) Type.eq option)
| Float, Float -> (Some Type.Equal: (a,b) Type.eq option)
| _ -> None

Then in the context of each branch the annotation is valid. For instance in the Int8, Int8 branch

match x, y with
...
| Int8, Int8 -> Some (Type.Equal: (a,b) Type.eq)

matching Int8:int typ for x: a typ introduces the local equation a=int.
Symmetrically, pattern matching Int8 for y gives us the equation b=int.

Then, when we write Type.Equal it has type ('x,'x) Type.eq for some type 'x by the definition of the constructor Type.Equal.

Finally, to check the type annotation, we have to unify ('x, 'x) Type.eq with the type constraint (a,b) Type.eq with the two equations a=int and b=int available.

Starting with the first parameter gives us 'x = a, which reduces the unification to
(a,a) Type.eq : (a,b) Type.eq. For the second parameter, for the type annotation to be valid, we need to check that a = b. At this point, we use the two equations in the local context, a = int = b, to check that Type.Equal: (a,b) Type.eq is indeed well-typed in the local context of the branch.
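
The same reasoning can be checked with a small self-contained sketch. It defines eq locally instead of using the Type module from the standard library, so it does not depend on a recent compiler, and in the Int8 branch it builds the witness at the concrete type (int, int) eq on purpose:

```ocaml
type (_, _) eq = Equal : ('a, 'a) eq

type _ typ =
  | Int : int typ
  | Int8 : int typ
  | Float : float typ

let eq : type a b. a typ -> b typ -> (a, b) eq option =
 fun x y ->
  match x, y with
  | Int, Int -> Some Equal
  | Int8, Int8 ->
    (* Local equations in this branch: a = int and b = int. *)
    let w : (int, int) eq = Equal in
    (* Accepted as (a, b) eq because a = int = b here. *)
    Some w
  | Float, Float -> Some Equal
  | _ -> None
```

Moving the Int8 branch body into the Float branch would be rejected, since there the local equations are a = float and b = float.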

Thanks @octachron ! Your step-by-step explanation really helped dispel the “magic” :smiley:

This annotation in particular was really helpful. Are there any compiler flags or tools that could show output like this?

I fear that there is no such compiler flag (and it would be quite hard to implement in a non-noisy way).

You can use merlin to query the type at any point, but I don’t know if it is possible to query just the expected type.

By “factor” I just meant to remove the “tag” from your original sum and put it into its own type (with no constructor args). You also noted that you do in fact want this separation for reasons not illustrated in the example, so that makes sense.

I think the response by @octachron answered your actual question, but I’ll add a bit more about what I meant by saying that the “magic” is in the eq constructor. This example will not do what you want because it adds a secondary constructor to the eq type which does not require the parameters to be unified; it’s not practical in any way but perhaps gives more insight into what is happening:

type (_, _) eq = Equal : ('a, 'a) eq | NotNecessarilyEqual : ('a, 'b) eq

type _ t =
  | IfName : string t
  | IfMtu : int32 t
  | IfAddr : bytes t

let eq : type a b. a t -> b t -> (a, b) eq option = fun a b -> match (a, b) with
  | IfName, IfName -> Some Equal
  | IfMtu, IfMtu -> Some Equal
  | IfAddr, IfMtu -> Some NotNecessarilyEqual (* <- This lets you construct unequal types *)
  | IfAddr, IfAddr -> Some NotNecessarilyEqual (* But it also allows you to "loosen" constraints on the parameters you want to tie together. *)
  | _ -> None

When you remove the spurious NotNecessarilyEqual constructor, it forces unification since the Equal constructor requires it.

Adding type annotations doesn’t really help in this context because the compiler already gives a very good, focused error message when you try to construct a bad value:

type (_, _) eq = Equal : ('a, 'a) eq

type _ t =
  | IfName : string t
  | IfMtu : int32 t
  | IfAddr : bytes t

let eq : type a b. a t -> b t -> (a, b) eq option = fun a b -> match (a, b) with
  | IfName, IfName -> Some Equal
  | IfMtu, IfMtu -> Some Equal
  | IfAddr, IfAddr -> Some Equal
  | IfAddr, IfMtu -> Some Equal
  | _ -> None

The error message clearly calls out that it wants an (a, a) eq here (though you have to refer to the Equal constructor itself to realize where the constraint is coming from):

28 |   | IfAddr, IfMtu -> Some Equal
                               ^^^^^
Error: This expression has type (a, a) eq
       but an expression was expected of type (a, b) eq
       Type a is not compatible with type b

In any case, you can always hack around less-than-ideal type errors by applying spurious type annotations at the sites where you’re unsure of the inferred type, e.g.:

type (_, _) eq = Equal : ('a, 'a) eq

type _ t =
  | IfName : string t
  | IfMtu : int32 t
  | IfAddr : bytes t

let eq : type a b. a t -> b t -> (a, b) eq option = fun a b -> match (a, b) with
  | IfName, IfName -> Some Equal
  | IfMtu, IfMtu -> Some Equal
  | IfAddr, IfAddr -> Some Equal
  | IfAddr, IfMtu -> Some (Equal : int)
  | _ -> None

Which gives this (nearly identical) error message:

28 |   | IfAddr, IfMtu -> Some (Equal : int)
                                ^^^^^
Error: This expression has type ('a, 'a) eq
       but an expression was expected of type int

Note that I applied the annotation to the innermost (unknown) value. I find that this tends to give me more focused error messages, though I’m not sure exactly why (likely something to do with restricting the type inference surface). This hack is more likely to be useful where you don’t already get a good error message in any case.