A cute challenge with OCaml's polymorphic variants

While refactoring an OCaml project, I wanted to deduplicate some code, although it was not obvious whether that would be possible. To illustrate the situation, I have extracted the following cute challenge from the project.

Setup

The code in question implements bits of some mathematical structures. Understanding the mathematical side is not strictly necessary to approach the challenge: you could just look at the OCaml code. Nevertheless, I will present the mathematical details to give some context.

Partial orders

Consider the partial order depicted by the following Hasse diagram:

[Hasse diagram: Warning at the top, with Safe and Error directly below it; Safe and Error are incomparable.]

In OCaml, we can define the ordered elements using a polymorphic variant type:

type reachable = [`Warning | `Safe | `Error]

Now, consider the following extension of the partial order:

[Hasse diagram: Warning at the top, Safe and Error directly below it, and Unreachable at the bottom, directly below both Safe and Error.]

In OCaml, we can define the elements using another polymorphic variant:

type t = [reachable | `Unreachable]

which is just equivalent to

type t = [`Warning | `Safe | `Error | `Unreachable]
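Because t only adds a tag to reachable, reachable is a subtype of t. The following minimal sketch shows both directions. The function names widen and narrow are made up for illustration. Widening is an explicit coercion. Narrowing needs a match, because `Unreachable has no counterpart in reachable.

```ocaml
type reachable = [ `Warning | `Safe | `Error ]
type t = [ reachable | `Unreachable ]

(* Widening: every tag of reachable is also a tag of t, so a coercion suffices. *)
let widen (r : reachable) : t = (r :> t)

(* Narrowing: the #reachable pattern matches exactly the tags of reachable,
   and the alias r can then be used at type reachable. *)
let narrow (x : t) : reachable option =
  match x with
  | #reachable as r -> Some r
  | `Unreachable -> None
```

The #reachable pattern used in narrow is the same syntax that the addendum relies on.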

Least upper bounds

The pairwise least upper bounds (aka joins) of elements in the latter partial order (on t) are given by the table:

              Warning   Safe      Error     Unreachable
Warning       Warning   Warning   Warning   Warning
Safe          Warning   Safe      Warning   Safe
Error         Warning   Warning   Error     Error
Unreachable   Warning   Safe      Error     Unreachable
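Each entry in the table is the least element among the common upper bounds of its row and column in the Hasse diagram. As a cross-check, the following sketch derives joins from the order itself rather than from the table. The helpers leq, all and join_via_order are hypothetical and not part of the original project.

```ocaml
type t = [ `Warning | `Safe | `Error | `Unreachable ]

(* All elements of the partial order. *)
let all : t list = [ `Warning; `Safe; `Error; `Unreachable ]

(* The order from the Hasse diagram: Unreachable is the bottom,
   Warning is the top, and Safe and Error are incomparable. *)
let leq (a : t) (b : t) : bool =
  match a, b with
  | `Unreachable, _ -> true
  | _, `Warning -> true
  | `Safe, `Safe | `Error, `Error -> true
  | _ -> false

(* Join by definition: the least element among all upper bounds of x and y.
   List.find cannot fail here, because every pair has a least upper bound. *)
let join_via_order (x : t) (y : t) : t =
  let upper_bounds = List.filter (fun z -> leq x z && leq y z) all in
  List.find (fun z -> List.for_all (leq z) upper_bounds) upper_bounds
```

This version searches the order on every call. It is meant for checking the table, not as a replacement for the direct pattern match.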

In OCaml, we can implement this table directly using pattern matching:

let join x y =
  match x, y with
  | `Warning,     `Warning     -> `Warning
  | `Warning,     `Safe        -> `Warning
  | `Warning,     `Error       -> `Warning
  | `Warning,     `Unreachable -> `Warning
  | `Safe,        `Warning     -> `Warning
  | `Safe,        `Safe        -> `Safe
  | `Safe,        `Error       -> `Warning
  | `Safe,        `Unreachable -> `Safe
  | `Error,       `Warning     -> `Warning
  | `Error,       `Safe        -> `Warning
  | `Error,       `Error       -> `Error
  | `Error,       `Unreachable -> `Error
  | `Unreachable, `Warning     -> `Warning
  | `Unreachable, `Safe        -> `Safe
  | `Unreachable, `Error       -> `Error
  | `Unreachable, `Unreachable -> `Unreachable

(This is far from the most compact way to do so.)
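For comparison, here is one more compact way to write the same table, using or-patterns and wildcards. The name join_compact is hypothetical. Note that its type annotations pin the type to t -> t -> t, so this version is not an answer to the challenge that follows.

```ocaml
type t = [ `Warning | `Safe | `Error | `Unreachable ]

let join_compact (x : t) (y : t) : t =
  match x, y with
  (* Unreachable is the bottom: the join is the other argument. *)
  | `Unreachable, z | z, `Unreachable -> z
  (* Warning is the top: it absorbs everything. *)
  | `Warning, _ | _, `Warning -> `Warning
  (* Only pairs of Safe and Error remain. *)
  | `Safe, `Safe -> `Safe
  | `Error, `Error -> `Error
  | `Safe, `Error | `Error, `Safe -> `Warning
```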

The pairwise joins of elements in the former partial order (on reachable) are given by the top left 3×3 subtable above. In OCaml, we can implement it analogously:

let join_reachable x y =
  match x, y with
  | `Warning, `Warning -> `Warning
  | `Warning, `Safe    -> `Warning
  | `Warning, `Error   -> `Warning
  | `Safe,    `Warning -> `Warning
  | `Safe,    `Safe    -> `Safe
  | `Safe,    `Error   -> `Warning
  | `Error,   `Warning -> `Warning
  | `Error,   `Safe    -> `Warning
  | `Error,   `Error   -> `Error

(Again, this is far from the most compact way to do so.)

Challenge

The code duplication between join and join_reachable is not very satisfactory. Instead, we would simply like to define join_reachable as follows:

let join_reachable: reachable -> reachable -> reachable = join

However, this does not work as is. Instead, it gives a typing error regarding the return type of the function:

Type [> `Error | `Safe | `Unreachable | `Warning ]
is not compatible with type reachable = [ `Error | `Safe | `Warning ]
The second variant type does not allow tag(s) `Unreachable

Even though we can see that join on two reachable values always returns a result in reachable (and never `Unreachable), the type checker cannot.

The challenge is the following:

Rewrite join such that the desired one-line definition of join_reachable type checks.

(Of course, without altering the behavior of join (on t) or sidestepping the type system with Obj.)

You can easily try this on the OCaml Playground.

Hint: This would work if join had the following type [1]:

([< `Error | `Safe | `Unreachable | `Warning > `Error `Safe `Warning ] as 'a) -> 'a -> 'a

It is actually not necessary to add any type annotations. To my surprise, the OCaml compiler manages to infer such a type on its own (Garrigue, 2004).
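The following tiny example uses made-up tags `A and `B, and made-up names rebuild, keep and keep_a. It illustrates one way a match can tie its result type to its argument type, which is the kind of sharing the hinted type expresses. It deliberately stops short of a full solution.

```ocaml
(* Returning fresh tag literals: the result type is an open variant
   [> `A | `B ] that is unrelated to the argument type [< `A | `B ]. *)
let rebuild x = match x with `A -> `A | `B -> `B

(* Returning the scrutinee itself: the result type is the argument type,
   i.e. ([< `A | `B ] as 'a) -> 'a. *)
let keep x = match x with `A -> x | `B -> x

(* Hence keep can be specialized to a smaller input and output type... *)
let keep_a : [ `A ] -> [ `A ] = keep

(* ...whereas the analogous definition for rebuild is rejected,
   because its result type still requires `B:
   let rebuild_a : [ `A ] -> [ `A ] = rebuild *)
```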

Addendum

This challenge may seem a bit backwards: if t is defined using reachable, then it would make more sense to define join using join_reachable, not the other way around. So suppose that join_reachable is already defined directly using pattern matching.

For defining join, it is then tempting to put all the cases involving `Unreachable first and delegate the rest to join_reachable:

let join x y =
  match x, y with
  | `Warning,     `Unreachable -> `Warning
  | `Safe,        `Unreachable -> `Safe
  | `Error,       `Unreachable -> `Error
  | `Unreachable, `Warning     -> `Warning
  | `Unreachable, `Safe        -> `Safe
  | `Unreachable, `Error       -> `Error
  | `Unreachable, `Unreachable -> `Unreachable
  | x',           y'           -> join_reachable x' y'

However, this gives a typing error regarding the arguments passed to join_reachable:

This expression has type [> `Error | `Safe | `Unreachable | `Warning ]
but an expression was expected of type [< `Error | `Safe | `Warning ]
The second variant type does not allow tag(s) `Unreachable

We may try to help the type checker by adding type annotations to the arguments:

let join (x: t) (y: t): t =
  match x, y with
  (* same `Unreachable cases *)
  | x',           y'           -> join_reachable x' y'

However, this gives a very similar typing error:

This expression has type t but an expression was expected of type
[< `Error | `Safe | `Warning ]
The second variant type does not allow tag(s) `Unreachable

Even though we can see that x' and y' cannot be `Unreachable, the type checker cannot. The OCaml compiler does check pattern-matching exhaustiveness, which involves exactly this kind of knowledge, but that check takes place after type checking and does not feed back to refine the types of x' and y' (Garrigue, 2004).

We can actually help the type checker by explicitly listing the possible values of x' and y', instead of leaving them implicit by exhaustiveness. We can even use some obscure OCaml syntax to include all the variants in reachable automatically:

let join x y =
  match x, y with
  (* same `Unreachable cases *)
  | (#reachable as x'), (#reachable as y') -> join_reachable x' y'

And this type checks, finally!
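For reference, here is the complete delegating version assembled from the fragments above. The elided `Unreachable cases are filled in from the earlier attempt, and join_reachable is defined directly, as assumed at the start of the addendum. The final join_t line is an added check (not from the original) that join can still be used at type t.

```ocaml
type reachable = [ `Warning | `Safe | `Error ]
type t = [ reachable | `Unreachable ]

(* join_reachable, defined directly by exhaustive enumeration. *)
let join_reachable x y =
  match x, y with
  | `Warning, `Warning -> `Warning
  | `Warning, `Safe    -> `Warning
  | `Warning, `Error   -> `Warning
  | `Safe,    `Warning -> `Warning
  | `Safe,    `Safe    -> `Safe
  | `Safe,    `Error   -> `Warning
  | `Error,   `Warning -> `Warning
  | `Error,   `Safe    -> `Warning
  | `Error,   `Error   -> `Error

(* join handles every case involving `Unreachable itself and delegates the rest.
   The #reachable patterns give x' and y' types that join_reachable accepts. *)
let join x y =
  match x, y with
  | `Warning,     `Unreachable -> `Warning
  | `Safe,        `Unreachable -> `Safe
  | `Error,       `Unreachable -> `Error
  | `Unreachable, `Warning     -> `Warning
  | `Unreachable, `Safe        -> `Safe
  | `Unreachable, `Error       -> `Error
  | `Unreachable, `Unreachable -> `Unreachable
  | (#reachable as x'), (#reachable as y') -> join_reachable x' y'

(* join can still be used at the full type t. *)
let join_t : t -> t -> t = join
```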

On performance

Given that the latter, more intuitive approach works, why might anyone want the challenging approach from above? Well, there might (or might not) be performance reasons. I have not actually benchmarked them, so take the following with a grain of salt.

The intuitive approach involves a function call (join calling join_reachable), whereas the challenging one does not (there is just one compiled function!). This is irrelevant, though, if join_reachable is inlined into join. Even then, however, there could still be a slight inefficiency: the outer match has to match the arguments against #reachable, and then the inner match matches them again against the particular join_reachable cases. In theory, this is avoidable if the two matches are fused appropriately, but I don’t know whether the OCaml compiler would actually achieve this, especially after inlining.

This does not necessarily mean that the challenging approach of defining join_reachable as an alias for join is better performance-wise overall. Although join itself is entirely self-contained and possibly very efficient, join_reachable may have some overhead. Since there is just one compiled function, calls to join_reachable go through the larger pattern matching, which also handles `Unreachable. This may mean unnecessary branching on `Unreachable for arguments that can never take this value. It is the cost of reusing a more general function for less general inputs.

Above I presented both functions in both approaches using naïve exhaustive enumeration of all the cases. Besides not being the most compact code-wise, this may not be the most efficient either. Some cases could be combined using wildcard patterns (_), which may reduce the amount of branching in the compiled code. I’m not sure whether OCaml’s pattern-matching compiler can arrive at such optimizations by itself. This massaging of cases might even eliminate the performance differences between the approaches. But if we actually care about performance, then it is probably best not to reuse code between the functions at all, bypassing the whole ordeal.
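As an illustration of such merging, the nine cases of join_reachable collapse into three once a wildcard is allowed. This is a sketch: the name join_reachable_compact is hypothetical, the annotations pin its type to reachable, and whether it actually changes the compiled code is untested.

```ocaml
type reachable = [ `Warning | `Safe | `Error ]

let join_reachable_compact (x : reachable) (y : reachable) : reachable =
  match x, y with
  | `Safe, `Safe -> `Safe
  | `Error, `Error -> `Error
  (* Every remaining pair involves Warning or mixes Safe with Error. *)
  | _, _ -> `Warning
```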

References

  Jacques Garrigue. Typing deep pattern-matching in presence of polymorphic variants. In JSSST Workshop on Programming and Programming Languages (PPL), 2004.

Footnotes

  [1] I wonder if it is possible to achieve the even more general type

      ([< `Error | `Safe | `Unreachable | `Warning > `Warning ] as 'a) -> 'a -> 'a

      If you know, then comment below!