How to get this serialization/deserialization code to type-check?

Suppose I have a module Ndarray and I want to serialize/deserialize its contents using a specific method, like below:

module Ndarray = struct
  type dtype =
    | Int64
    | Float32
    | Int32

  type 'a t =
    {shape : int array
    ;stride : int array
    ;dtype : dtype
    ;data : 'a array}

  let iter f x = Array.iter f x.data

  let dtype_size t = match t.dtype with
    | Int32 -> 4
    | Int64 -> 8
    | Float32 -> 4

  let data_type x = x.dtype
end

let add_float32 buf v = Int32.bits_of_float v |> Buffer.add_int32_le buf

let to_bytes x =
  let buf = Buffer.create @@ (Ndarray.dtype_size x) * (Array.fold_left Int.mul 1 x.Ndarray.shape) in
  match Ndarray.data_type x with
  | Int64 -> Ndarray.iter (Buffer.add_int64_le buf) x; Buffer.contents buf
  | Int32 -> Ndarray.iter (Buffer.add_int32_le buf) x; Buffer.contents buf
  | Float32 -> Ndarray.iter (add_float32 buf) x; Buffer.contents buf

type metadata = {dtype : Ndarray.dtype; shape : int array}

let reconstruct dtype shape f = (* some function to reconstruct an Ndarray.t from dtype and shape, filling element i with f i *)

let of_bytes buf meta =
  match meta.dtype with
  | Int32 ->
    reconstruct meta.dtype meta.shape @@ fun i -> String.get_int32_le buf (4 * i)
  | Int64 ->
    reconstruct meta.dtype meta.shape @@ fun i -> String.get_int64_le buf (8 * i)
  | Float32 ->
    reconstruct meta.dtype meta.shape @@ fun i -> String.get_int32_le buf (4 * i) |> Int32.float_of_bits

I cannot seem to get the to_bytes and of_bytes functions to type-check because of the pattern matching. I get the error "This expression has type int64 Ndarray.t but an expression was expected of type int32 Ndarray.t. Type int64 is not compatible with type int32". If I try to add a type annotation such as
let to_bytes : type a. a Ndarray.t -> string = fun x -> ...
I get "This expression has type a Ndarray.t but an expression was expected of type int64 Ndarray.t. Type a is not compatible with type int64". The same thing happens with the of_bytes function.

Is there something I can do to make the functions type-check? Or am I supposed to implement them differently? If so, how? Or am I stuck writing specialized functions like to_int32_bytes? (That won’t help, since I have more dtype constructors in my real-world code.)

Any help is appreciated.

This is a textbook case for GADTs. You can find some documentation in the manual (see this section in particular). The basic idea is that you want to link the type of the elements of the array with the dtype field, but your current type makes them completely independent.
Here is an alternative definition of dtype that you should be able to adapt to your needs:

type 'elt_type dtype =
| Int64 : Int64.t dtype
| Float32 : Float.t dtype
| Int32 : Int32.t dtype
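As a rough sketch of how this definition might be threaded through the original code: the record stores an 'a dtype whose index matches the element type, and to_bytes is annotated with a locally abstract type so that each branch of the match can refine it. This is an illustration under assumptions, not a drop-in replacement: the stride field is dropped for brevity, dtype_size now takes a dtype rather than the array, and make is a hypothetical helper not present in the original.

```ocaml
(* The type index ties each dtype constructor to an OCaml element type. *)
type _ dtype =
  | Int64 : int64 dtype
  | Float32 : float dtype
  | Int32 : int32 dtype

module Ndarray = struct
  (* The dtype field carries the same 'a as the data array. *)
  type 'a t = { shape : int array; dtype : 'a dtype; data : 'a array }

  (* Hypothetical constructor, added only for illustration. *)
  let make dtype shape data = { shape; dtype; data }

  let iter f x = Array.iter f x.data

  (* The locally abstract type lets each branch refine a on its own. *)
  let dtype_size : type a. a dtype -> int = function
    | Int32 -> 4
    | Int64 -> 8
    | Float32 -> 4
end

(* Store a float as its 32-bit IEEE representation, little-endian. *)
let add_float32 buf v = Buffer.add_int32_le buf (Int32.bits_of_float v)

(* Matching on x.dtype tells the type checker what a is in each branch,
   so each branch may use the writer for that element type. *)
let to_bytes : type a. a Ndarray.t -> string =
 fun x ->
  let n = Array.fold_left ( * ) 1 x.Ndarray.shape in
  let buf = Buffer.create (Ndarray.dtype_size x.Ndarray.dtype * n) in
  (match x.Ndarray.dtype with
   | Int64 -> Ndarray.iter (Buffer.add_int64_le buf) x
   | Int32 -> Ndarray.iter (Buffer.add_int32_le buf) x
   | Float32 -> Ndarray.iter (add_float32 buf) x);
  Buffer.contents buf
```

A side benefit of this design: a mismatched array such as Ndarray.make Int32 [| 2 |] [| 1.0; 2.0 |] is now rejected at compile time instead of being silently accepted.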

Don’t hesitate to ask further questions if there are still things that you do not understand (but don’t forget to look at the section of the manual linked above first).

3 Likes

I have read your link on GADTs but I am not convinced they really help in the situation presented.
Correct me if I’m wrong, but if one replaces the OP’s original non-GADT dtype with your parametrized dtype, it becomes impossible to define an equivalent of the dtype_size function in the OP’s code:

utop #   type 'elt_type dtype =
  | Int64 :   int64 dtype
  | Float32 :   float dtype
  | Int32 :   int32 dtype ;;
type 'elt_type dtype =
    Int64 : int64 dtype
  | Float32 : float dtype
  | Int32 : int32 dtype
utop # let dtype_size dt = match dt with
    | Int32  -> 4
    | Int64  -> 8
    | Float32  -> 4 ;; 
Error: This pattern matches values of type int64 dtype
       but a pattern was expected which matches values of type int32 dtype
       Type int64 is not compatible with type int32

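It’s one of the common issues with GADTs: type inference does not work as well as for regular datatypes. Without an annotation, the first pattern (Int32) fixes the type of dt to int32 dtype, so the Int64 pattern is then rejected; a locally abstract type lets each branch refine the index separately.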
You should be able to type-check the function with a few additional annotations:

let dtype_size : type a. a dtype -> int = function
  | Int32  -> 4
  | Int64  -> 8
  | Float32  -> 4
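The OP’s of_bytes is harder than to_bytes, because its result type depends on a dtype that is only known at run time (for example, read from a file header). One common approach, sketched here under assumptions, is to write a typed decoder whose dtype argument fixes the result type, and to wrap the run-time dtype and the resulting array in existential types. The names packed_dtype, packed and of_bytes_typed, the change of the metadata dtype field to packed_dtype, and the body of reconstruct are all hypothetical; the original left reconstruct unspecified.

```ocaml
type _ dtype =
  | Int64 : int64 dtype
  | Float32 : float dtype
  | Int32 : int32 dtype

module Ndarray = struct
  type 'a t = { shape : int array; dtype : 'a dtype; data : 'a array }
end

(* Existential wrappers: they hide an element type known only at run time. *)
type packed_dtype = Dtype : 'a dtype -> packed_dtype
type packed = Packed : 'a Ndarray.t -> packed

(* Assumption: metadata stores a packed dtype rather than a plain variant. *)
type metadata = { dtype : packed_dtype; shape : int array }

(* Hypothetical helper: builds product(shape) elements, element i = f i. *)
let reconstruct : type a. a dtype -> int array -> (int -> a) -> a Ndarray.t =
 fun dtype shape f ->
  let n = Array.fold_left ( * ) 1 shape in
  { Ndarray.shape = shape; dtype = dtype; data = Array.init n f }

(* Typed decoder: the dtype passed by the caller fixes the result type. *)
let of_bytes_typed : type a. a dtype -> int array -> string -> a Ndarray.t =
 fun dtype shape buf ->
  match dtype with
  | Int32 -> reconstruct dtype shape (fun i -> String.get_int32_le buf (4 * i))
  | Int64 -> reconstruct dtype shape (fun i -> String.get_int64_le buf (8 * i))
  | Float32 ->
      reconstruct dtype shape (fun i ->
          Int32.float_of_bits (String.get_int32_le buf (4 * i)))

(* Untyped entry point: open the packed dtype, decode, repack the result. *)
let of_bytes buf meta =
  match meta.dtype with
  | Dtype dt -> Packed (of_bytes_typed dt meta.shape buf)
```

A consumer of a packed value can match on the array’s dtype field again to recover the concrete element type in each branch, in the same way as dtype_size above.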

This works nicely, thanks a lot!