Generalizing types for monadic map

I need to extend Stdlib.Map or BatMap with monad-friendly functions, and I am hoping someone can suggest a way to generalize the polymorphic types of the bind and return functions in the following:

type ('k, 'v) map =
  | Empty
  | Node of ('k, 'v) map * 'k * 'v * ('k, 'v) map * int

let map_m : 
     return: ('a -> 'a_m)
  -> bind: ('a_m -> ('a -> 'a_m) -> 'a_m)
  -> ('b -> 'c)
  -> 'd
  -> 'e
  = fun ~return ~bind f map ->
    let ( let* ) = bind in
    let rec loop map =
      match map with
      | Empty -> return Empty
      | Node (l, k, v, r, h) ->
        let* l = loop l in
        let* v = f v in
        let* r = loop r in
        return (Node (l, k, v, r, h))
    in loop map

This fails with:

# #mod_use "";;
File "", line 20, characters 28-29:
20 |         return (Node (l, k, v, r, h))
Error: This expression has type ('a, 'b) map
       but an expression was expected of type 'b
       The type variable 'b occurs inside ('a, 'b) map

I was hoping the Polymorphic recursion section would provide an answer, but so far I’ve had no luck adding explicit polymorphic annotations. (Note: I have deliberately relaxed the other types in the signature.)

There is a polymorphism issue, but it does not arise from the recursion. For a given monad 'a m, the types of return and bind are:

val bind: 'a 'b. 'a m -> ('a -> 'b m) -> 'b m
val return: 'a. 'a -> 'a m
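To make these two signatures concrete, here is a minimal sketch assuming the monad is option (the module name Opt_monad is illustrative, not from the thread). The thing to notice is that one polymorphic bind can be used at several instantiations of 'a and 'b, which is exactly what map_m needs.

```ocaml
(* Sketch assuming the monad is [option], i.e. 'a m = 'a option.
   In a signature, the type variables of a [val] are implicitly
   universally quantified; the ['a 'b.] notation just makes that explicit. *)
module Opt_monad : sig
  type 'a m = 'a option
  val bind : 'a m -> ('a -> 'b m) -> 'b m
  val return : 'a -> 'a m
end = struct
  type 'a m = 'a option
  let bind x f = match x with None -> None | Some v -> f v
  let return x = Some x
end

(* Because [bind] is polymorphic, the same definition is used here
   with 'b = int and with 'b = string. *)
let () =
  let open Opt_monad in
  let a = bind (Some 2) (fun x -> return (x * 10)) in
  let b = bind (Some 3) (fun x -> return (string_of_int x)) in
  assert (a = Some 20);
  assert (b = Some "3")
```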

The function map_m uses bind at two distinct types (once to sequence the recursive calls on the subtrees, once to sequence the results of f), so it requires either two bind functions:

let map_m return (let**) (let*) f map =
  let rec loop map =
    match map with
    | Empty -> return Empty
    | Node (l, k, v, r, h) ->
      let** l = loop l in
      let* v = f v in
      let** r = loop r in
      return (Node (l, k, v, r, h))
  in loop map

or higher-rank polymorphism. Higher-rank polymorphism in OCaml requires using either a record or an object with explicit polymorphic annotations. For instance, with the record option:

type monad = {
  return:'a. 'a -> 'a m;
  bind: 'a 'b. 'a m -> ('a -> 'b m) -> 'b m
}

let map_m { return; bind=(let*) } f map =
  let rec loop map =
    match map with
    | Empty -> return Empty
    | Node (l, k, v, r, h) ->
      let* l = loop l in
      let* v = f v in
      let* r = loop r in
      return (Node (l, k, v, r, h))
  in loop map

works as expected.
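As a usage sketch, assume the monad is option (the thread leaves 'a m abstract). The record can then be filled with Option.some and Option.bind from the standard library (OCaml 4.08 or later). The names option_monad, show_positive and sample are illustrative only.

```ocaml
(* Assumption: the monad is [option]; the answer keeps ['a m] abstract. *)
type 'a m = 'a option

type monad = {
  return : 'a. 'a -> 'a m;
  bind : 'a 'b. 'a m -> ('a -> 'b m) -> 'b m;
}

type ('k, 'v) map =
  | Empty
  | Node of ('k, 'v) map * 'k * 'v * ('k, 'v) map * int

(* Record version from the answer: the fields are polymorphic, so the
   single [( let* )] can be used at two different types. *)
let map_m { return; bind = ( let* ) } f map =
  let rec loop map =
    match map with
    | Empty -> return Empty
    | Node (l, k, v, r, h) ->
      let* l = loop l in
      let* v = f v in
      let* r = loop r in
      return (Node (l, k, v, r, h))
  in
  loop map

(* Option.some and Option.bind are polymorphic enough to fill the fields. *)
let option_monad = { return = Option.some; bind = Option.bind }

(* Hypothetical [f]: render positive values as strings, fail otherwise. *)
let show_positive x = if x > 0 then Some (string_of_int x) else None

let sample = Node (Node (Empty, 1, 10, Empty, 1), 2, 20, Empty, 2)
let shown = map_m option_monad show_positive sample
```

The drawback is that the record type fixes m once and for all: a second monad needs a second record type and a second copy of map_m.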
Now, to have a function that works for any monad without duplicating the bind function, we would need it to be polymorphic over the type constructor m itself. Since m is a type constructor (its kind is * -> *, or more generally * -> ... -> *, rather than *), abstracting over it requires higher-kinded polymorphism.
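For comparison, the two-bind version is already generic over the monad; the cost is that bind has to be supplied twice at every call site. A sketch using the option monad and a hand-written list monad (list_return, list_bind and sample are illustrative names; List.concat_map needs OCaml 4.10 or later):

```ocaml
type ('k, 'v) map =
  | Empty
  | Node of ('k, 'v) map * 'k * 'v * ('k, 'v) map * int

(* The two-bind version from the answer. *)
let map_m return ( let** ) ( let* ) f map =
  let rec loop map =
    match map with
    | Empty -> return Empty
    | Node (l, k, v, r, h) ->
      let** l = loop l in
      let* v = f v in
      let** r = loop r in
      return (Node (l, k, v, r, h))
  in
  loop map

(* Illustrative list monad operations. *)
let list_return x = [ x ]
let list_bind l f = List.concat_map f l

let sample = Node (Node (Empty, "a", 2, Empty, 1), "b", 3, Empty, 2)

(* Option: the same bind is passed twice, instantiated at two types. *)
let halved =
  map_m Option.some Option.bind Option.bind
    (fun x -> if x mod 2 = 0 then Some (x / 2) else None)
    sample

(* List: again the same bind is passed twice. *)
let signs = map_m list_return list_bind list_bind (fun x -> [ x; -x ]) sample
```

This works because each occurrence of a let-bound polymorphic function at the call site is instantiated separately; inside map_m the two parameters stay monomorphic.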

Higher-kinded polymorphism is the domain of functors:

module type monad = sig
  type 'a t
  val (let*) : 'a t -> ('a -> 'b t) -> 'b t
  val return: 'a -> 'a t
end

module Map_m(M:monad) = struct
  open M
  let map_m f map =
    let rec loop map =
      match map with
      | Empty -> return Empty
      | Node (l, k, v, r, h) ->
        let* l = loop l in
        let* v = f v in
        let* r = loop r in
        return (Node (l, k, v, r, h))
    in loop map
end
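To see the higher-kinded abstraction at work, here is a sketch that instantiates the functor with two different type constructors, option and list. Option_monad, List_monad, Option_map_m, List_map_m and sample are illustrative names, not part of any library; List.concat_map requires OCaml 4.10 or later.

```ocaml
type ('k, 'v) map =
  | Empty
  | Node of ('k, 'v) map * 'k * 'v * ('k, 'v) map * int

module type monad = sig
  type 'a t
  val ( let* ) : 'a t -> ('a -> 'b t) -> 'b t
  val return : 'a -> 'a t
end

(* Functor from the answer, with its closing [end]s. *)
module Map_m (M : monad) = struct
  open M

  let map_m f map =
    let rec loop map =
      match map with
      | Empty -> return Empty
      | Node (l, k, v, r, h) ->
        let* l = loop l in
        let* v = f v in
        let* r = loop r in
        return (Node (l, k, v, r, h))
    in
    loop map
end

(* Two illustrative monad instances. *)
module Option_monad = struct
  type 'a t = 'a option
  let ( let* ) = Option.bind
  let return = Option.some
end

module List_monad = struct
  type 'a t = 'a list
  let ( let* ) l f = List.concat_map f l
  let return x = [ x ]
end

(* One functor body, two different type constructors. *)
module Option_map_m = Map_m (Option_monad)
module List_map_m = Map_m (List_monad)

let sample = Node (Node (Empty, 1, 10, Empty, 1), 2, 20, Empty, 2)
```

Option_map_m.map_m stops at the first None returned by f, while List_map_m.map_m produces one output map per combination of choices.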

Superb, comprehensive answer. Thanks! I was aware of the functorial solution but, subjectively, it felt more heavyweight in this context. Looks like the best way to go.