Jimmy Miller

The Easiest Way to Build a Type Checker

Type checkers are a piece of software that feels incredibly simple, yet incredibly complex. Seeing Hindley-Milner written in a logic programming language is almost magical, but it never helped me understand how it was implemented. Nor did trying to read anything about Algorithm W or any academic paper explaining a type system. But thanks to David Christiansen, I have discovered a setup for type checking that is so conceptually simple it demystified the whole thing for me. It goes by the name Bidirectional Type Checking.

Bidirectional Type Checking

The two directions in this type checker are inferring types and checking types. Unlike Hindley-Milner, we do need some type annotations, but these are typically at function definitions. So code like the sillyExample below is completely valid and fully type checks, despite having no annotations on its local variables. How far can we take this? I'm not a type theory person. Reading papers in type theory takes me a while, and my comprehension is always lacking, but this paper seems like a good starting point for answering that question.

function sillyExample(x: number): number {
  let a = 10;
  let b = 20;
  let e = a;
  let f = b;
  let q = a + e;
  let g = "hello";
  let h = "world";
  let i = 100 + q;
  return x;
}

So, how do we actually create a bidirectional type checker? I think the easiest way to understand it is to see a full working implementation, so that's what I have below for a very simple language. To understand it, start by looking at the types to figure out what the language supports, then look at each of the infer cases. Don't worry if it doesn't make sense yet; I will explain it in more detail below.

export type Type =
  | { kind: "number" }
  | { kind: "string" }
  | { kind: "function"; arg: Type; returnType: Type };

export type Expr =
  | { kind: "number"; value: number }
  | { kind: "string"; value: string }
  | { kind: "varLookup"; name: string }
  | { kind: "function"; param: string; body: Expr }
  | { kind: "call"; fn: Expr; arg: Expr }
  | { kind: "let"; name: string; value: Expr; type?: Type }
  | { kind: "block"; statements: Expr[]; return: Expr };

export type Context = Map<string, Type>;

export function infer(ctx: Context, expr: Expr): Type {
  switch (expr.kind) {
    case "number":
      return { kind: "number" };

    case "string":
      return { kind: "string" };

    case "varLookup": {
      const type = ctx.get(expr.name);
      if (!type) {
        throw new Error(`Unbound variable: ${expr.name}`);
      }
      return type;
    }

    case "call": {
      const fnType = infer(ctx, expr.fn);
      if (fnType.kind !== "function") {
        throw new Error("Cannot call non-function");
      }
      // The callee's type tells us what the argument should be, so check it.
      check(ctx, expr.arg, fnType.arg);
      return fnType.returnType;
    }

    case "function":
      // Without an expected type, we have no way to know the parameter's type.
      throw new Error("Cannot infer type for function without annotation");

    case "let": {
      let valueType: Type;
      if (expr.type) {
        // With an annotation, check the value against it. This also lets an
        // annotated function literal type check, which infer alone cannot do.
        check(ctx, expr.value, expr.type);
        valueType = expr.type;
      } else {
        valueType = infer(ctx, expr.value);
      }
      ctx.set(expr.name, valueType);
      return valueType;
    }

    case "block": {
      // Copy the context so bindings made inside the block don't leak out.
      const blockCtx = new Map(ctx);
      for (const stmt of expr.statements) {
        infer(blockCtx, stmt);
      }
      return infer(blockCtx, expr.return);
    }
  }
}

export function check(ctx: Context, expr: Expr, expected: Type): void {
  switch (expr.kind) {
    case "function": {
      if (expected.kind !== "function") {
        throw new Error("Function must have function type");
      }
      // The expected type tells us the parameter's type.
      const newCtx = new Map(ctx);
      newCtx.set(expr.param, expected.arg);
      check(newCtx, expr.body, expected.returnType);
      break;
    }

    case "block": {
      const blockCtx = new Map(ctx);
      for (const stmt of expr.statements) {
        infer(blockCtx, stmt);
      }
      check(blockCtx, expr.return, expected);
      break;
    }

    default: {
      // No special rule: infer the type, then compare it with the expected one.
      const actual = infer(ctx, expr);
      if (!typesEqual(actual, expected)) {
        throw new Error(
          `Type mismatch: expected ${JSON.stringify(expected)}, got ${JSON.stringify(actual)}`
        );
      }
    }
  }
}

export function typesEqual(a: Type, b: Type): boolean {
  if (a.kind !== b.kind) return false;
  if (a.kind === "function" && b.kind === "function") {
    return typesEqual(a.arg, b.arg) && typesEqual(a.returnType, b.returnType);
  }
  return true;
}

Here we have, in ~100 lines, a fully functional type checker for a small language. Is it without flaw? Is it feature complete? Not at all. In a real type checker, you usually want to know more than whether something type checks; you might want to decorate each part of the tree with its type. We don't do that here. We don't do a lot of things. But I've found that this tiny bit of code is enough of a foundation to extend to much larger, more complicated code examples.
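As a rough idea of what "decorating" could look like, here is a minimal sketch, assuming we store types in a side table instead of mutating the tree. It uses a cut-down version of the language (literals, variables, let, and blocks, with no annotations). The `TypeTable` name and the `infer`/`inferNode` split are our own hypothetical choices: `infer` records the type of every node it visits, keyed by the node object itself.

```typescript
type Type = { kind: "number" } | { kind: "string" };

type Expr =
  | { kind: "number"; value: number }
  | { kind: "string"; value: string }
  | { kind: "varLookup"; name: string }
  | { kind: "let"; name: string; value: Expr }
  | { kind: "block"; statements: Expr[]; return: Expr };

type Context = Map<string, Type>;
// Hypothetical side table: AST node (by object identity) -> its inferred type.
type TypeTable = Map<Expr, Type>;

// Wrapper that records each node's type on the way back up the recursion.
function infer(ctx: Context, expr: Expr, types: TypeTable): Type {
  const type = inferNode(ctx, expr, types);
  types.set(expr, type);
  return type;
}

// The same rules as before; recursive calls go through infer so they get recorded too.
function inferNode(ctx: Context, expr: Expr, types: TypeTable): Type {
  switch (expr.kind) {
    case "number":
      return { kind: "number" };
    case "string":
      return { kind: "string" };
    case "varLookup": {
      const type = ctx.get(expr.name);
      if (!type) throw new Error(`Unbound variable: ${expr.name}`);
      return type;
    }
    case "let": {
      const valueType = infer(ctx, expr.value, types);
      ctx.set(expr.name, valueType);
      return valueType;
    }
    case "block": {
      const blockCtx: Context = new Map(ctx);
      for (const stmt of expr.statements) infer(blockCtx, stmt, types);
      return infer(blockCtx, expr.return, types);
    }
  }
}

// { let a = 10; let g = "hello"; a }
const lookupA: Expr = { kind: "varLookup", name: "a" };
const program: Expr = {
  kind: "block",
  statements: [
    { kind: "let", name: "a", value: { kind: "number", value: 10 } },
    { kind: "let", name: "g", value: { kind: "string", value: "hello" } },
  ],
  return: lookupA,
};

const types: TypeTable = new Map();
infer(new Map(), program, types);
console.log(types.get(lookupA)); // { kind: 'number' }
console.log(types.size); // 6: the block, both lets, both literals, and the lookup
```

Keying by object identity means two identical-looking nodes at different places in the tree get separate entries, which is usually what you want when you later ask "what is the type of this exact node?"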

Explanation

If you aren't super familiar with the implementation of programming languages, some of this code might strike you as a bit odd, so let me very quickly walk through the implementation. First, we have our data structures for representing our code:

export type Type =
  | { kind: 'number' }
  | { kind: 'string' }
  | { kind: 'function', arg: Type, returnType: Type }

export type Expr =
  | { kind: 'number', value: number }
  | { kind: 'string', value: string }
  | { kind: 'varLookup', name: string }
  | { kind: 'function', param: string, body: Expr }
  | { kind: 'call', fn: Expr, arg: Expr }
  | { kind: 'let', name: string, value: Expr, type?: Type }
  | { kind: 'block', statements: Expr[], return: Expr }

Using this data structure, we can work with code in a form that is much easier to handle than the raw string of source text. This kind of structure is called an "abstract syntax tree" (AST). For example:

// double(5)
const doubleFive: Expr = {
  kind: 'call',
  fn: { kind: 'varLookup', name: 'double' },
  arg: { kind: 'number', value: 5 }
}

This structure makes it easy to walk through our program and check things bit by bit.
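To make "walking through our program" concrete, here's a small sketch of a hypothetical `show` function that turns an `Expr` back into source-like text. It has the same shape as `infer`: a `switch` on `kind` that recurses into each node's children. The printed syntax is made up for illustration, and it ignores `let` annotations to stay short.

```typescript
type Type =
  | { kind: "number" }
  | { kind: "string" }
  | { kind: "function"; arg: Type; returnType: Type };

type Expr =
  | { kind: "number"; value: number }
  | { kind: "string"; value: string }
  | { kind: "varLookup"; name: string }
  | { kind: "function"; param: string; body: Expr }
  | { kind: "call"; fn: Expr; arg: Expr }
  | { kind: "let"; name: string; value: Expr; type?: Type }
  | { kind: "block"; statements: Expr[]; return: Expr };

// Hypothetical helper: render an Expr as text by recursing into its children.
function show(expr: Expr): string {
  switch (expr.kind) {
    case "number":
      return String(expr.value);
    case "string":
      return JSON.stringify(expr.value);
    case "varLookup":
      return expr.name;
    case "function":
      return `(${expr.param}) => ${show(expr.body)}`;
    case "call":
      return `${show(expr.fn)}(${show(expr.arg)})`;
    case "let":
      // Annotations are omitted for brevity.
      return `let ${expr.name} = ${show(expr.value)}`;
    case "block":
      return `{ ${[...expr.statements.map(show), show(expr.return)].join("; ")} }`;
  }
}

const doubleFive: Expr = {
  kind: "call",
  fn: { kind: "varLookup", name: "double" },
  arg: { kind: "number", value: 5 },
};
console.log(show(doubleFive)); // double(5)
```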

Context

export type Context = Map<string, Type>

This simple line of code is the key to how all variables, functions, etc., work. When we enter a function or a block, we make a new Map to hold the local variables and their types. We pass this map around, so at any point we know the types of everything defined before it. If we wanted to let you define functions out of order, we'd simply need to do two passes over the tree: the first to gather up the top-level functions, and the second to type-check the whole program. (This code gets more complicated with nested function definitions, but we'll ignore that here.)
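Here is a minimal sketch of that two-pass idea, assuming a hypothetical `FnDecl` shape for top-level, annotated, one-argument functions and a stripped-down expression language (numbers, variables, and calls). Pass one puts every declared signature into a global context; pass two checks each body against a copy of it, so `main` can call `helper` even though `helper` is declared later.

```typescript
type Type =
  | { kind: "number" }
  | { kind: "function"; arg: Type; returnType: Type };

type Expr =
  | { kind: "number"; value: number }
  | { kind: "varLookup"; name: string }
  | { kind: "call"; fn: Expr; arg: Expr };

type FnType = Extract<Type, { kind: "function" }>;

// Hypothetical top-level declaration: an annotated, one-argument function.
type FnDecl = { name: string; param: string; type: FnType; body: Expr };

type Context = Map<string, Type>;

function typesEqual(a: Type, b: Type): boolean {
  if (a.kind === "function" && b.kind === "function") {
    return typesEqual(a.arg, b.arg) && typesEqual(a.returnType, b.returnType);
  }
  return a.kind === b.kind;
}

function infer(ctx: Context, expr: Expr): Type {
  switch (expr.kind) {
    case "number":
      return { kind: "number" };
    case "varLookup": {
      const type = ctx.get(expr.name);
      if (!type) throw new Error(`Unbound variable: ${expr.name}`);
      return type;
    }
    case "call": {
      const fnType = infer(ctx, expr.fn);
      if (fnType.kind !== "function") throw new Error("Cannot call non-function");
      // Simplification: infer the argument and compare, instead of a separate check.
      if (!typesEqual(infer(ctx, expr.arg), fnType.arg)) {
        throw new Error("Argument type mismatch");
      }
      return fnType.returnType;
    }
  }
}

function checkProgram(decls: FnDecl[]): void {
  // Pass 1: collect every top-level signature, so declaration order doesn't matter.
  const globals: Context = new Map();
  for (const decl of decls) globals.set(decl.name, decl.type);

  // Pass 2: check each body in a copy of the globals plus its own parameter.
  for (const decl of decls) {
    const local: Context = new Map(globals);
    local.set(decl.param, decl.type.arg);
    if (!typesEqual(infer(local, decl.body), decl.type.returnType)) {
      throw new Error(`Return type mismatch in ${decl.name}`);
    }
  }
}

const numToNum: FnType = {
  kind: "function",
  arg: { kind: "number" },
  returnType: { kind: "number" },
};

// main calls helper, which is only declared after it.
const program: FnDecl[] = [
  {
    name: "main",
    param: "x",
    type: numToNum,
    body: {
      kind: "call",
      fn: { kind: "varLookup", name: "helper" },
      arg: { kind: "varLookup", name: "x" },
    },
  },
  { name: "helper", param: "y", type: numToNum, body: { kind: "varLookup", name: "y" } },
];

checkProgram(program); // no error: helper's signature was collected in pass 1
```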

Inference

Each case of infer may seem a bit trivial on its own. So, to explain how they fit together, let's add a new feature: addition.

// add this into our Expr type
| { kind: 'add', left: Expr, right: Expr }

Now we have something just a bit more complicated, so how would we write our inference for this? We are going to do the simple case: we are only allowed to add numbers together. Given that, our code would look something like this:

case 'add':
  // Both operands must be numbers. check returns nothing; it only throws on a mismatch.
  check(ctx, expr.left, { kind: 'number' })
  check(ctx, expr.right, { kind: 'number' })
  return { kind: 'number' }

This may seem a bit magical. How does check make this just work? Imagine that we have the following expression:

// 2 + 3 + 4, which parses as (2 + 3) + 4
const expr: Expr = {
  kind: 'add',
  left: {
    kind: 'add',
    left: { kind: 'number', value: 2 },
    right: { kind: 'number', value: 3 }
  },
  right: { kind: 'number', value: 4 }
}

There is no special handling for add in check, so we end up in the default case:

default:
  const actual = infer(ctx, expr)
  if (!typesEqual(actual, expected)) {
    throw new Error(`Type mismatch: expected ${JSON.stringify(expected)}, got ${JSON.stringify(actual)}`)
  }

If you trace out the recursion (once you get used to recursion, you don't actually need to do this, but I've found it helps people who aren't used to it), we get something like

infer(2 + 3 + 4)
  check(2 + 3, number)
    infer(2 + 3)
      check(2, number)
        infer(2) → number
      check(3, number)
        infer(3) → number
  check(4, number)
    infer(4) → number

So, for the left operand, we recurse back into infer, then into check, and finally bottom out at something simple we know how to infer. This is the beauty of our bidirectional checker: we can interleave these infer and check calls at will!
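If you'd rather see the trace than work it out by hand, here is a minimal sketch that cuts the language down to numbers and `add` and records each `infer`/`check` call, indented by recursion depth. The `trace` array, `log` helper, and `depth` parameter are illustration-only additions, not part of the checker above; the type rules are the same as the add case.

```typescript
type Type = { kind: "number" } | { kind: "string" };

type Expr =
  | { kind: "number"; value: number }
  | { kind: "add"; left: Expr; right: Expr };

// Illustration only: every infer/check call appends a line here.
const trace: string[] = [];

function show(expr: Expr): string {
  return expr.kind === "number"
    ? String(expr.value)
    : `${show(expr.left)} + ${show(expr.right)}`;
}

function log(depth: number, line: string): void {
  trace.push("  ".repeat(depth) + line);
}

function infer(expr: Expr, depth = 0): Type {
  switch (expr.kind) {
    case "number":
      log(depth, `infer(${show(expr)}) → number`);
      return { kind: "number" };
    case "add":
      log(depth, `infer(${show(expr)})`);
      check(expr.left, { kind: "number" }, depth + 1);
      check(expr.right, { kind: "number" }, depth + 1);
      return { kind: "number" };
  }
}

function check(expr: Expr, expected: Type, depth: number): void {
  log(depth, `check(${show(expr)}, ${expected.kind})`);
  // No special case for add: fall back to infer, then compare.
  const actual = infer(expr, depth + 1);
  if (actual.kind !== expected.kind) {
    throw new Error(`Type mismatch: expected ${expected.kind}, got ${actual.kind}`);
  }
}

// 2 + 3 + 4, parsed as (2 + 3) + 4
infer({
  kind: "add",
  left: {
    kind: "add",
    left: { kind: "number", value: 2 },
    right: { kind: "number", value: 3 },
  },
  right: { kind: "number", value: 4 },
});
console.log(trace.join("\n"));
```

Running it prints the same nesting as the trace above, one extra level of indentation per recursive call.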

How would we change our add to work with strings? Or to coerce between number and string? I leave that as an exercise for the reader. It only takes a little more code.
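If you want to compare notes on the exercise, here is one possible approach, sketched as a standalone checker over just number literals, string literals, and `add`. Instead of checking both operands against number, `add` infers both sides: matching types produce that type, and mixed operands are either rejected or, when the hypothetical `allowCoercion` flag is on, treated as a string, similar to JavaScript's `+`.

```typescript
type Type = { kind: "number" } | { kind: "string" };

type Expr =
  | { kind: "number"; value: number }
  | { kind: "string"; value: string }
  | { kind: "add"; left: Expr; right: Expr };

// allowCoercion is our own knob: false = same-type addition only,
// true = a string on either side makes the result a string.
function infer(expr: Expr, allowCoercion: boolean): Type {
  switch (expr.kind) {
    case "number":
      return { kind: "number" };
    case "string":
      return { kind: "string" };
    case "add": {
      // Infer both operands instead of checking them against number.
      const left = infer(expr.left, allowCoercion);
      const right = infer(expr.right, allowCoercion);
      if (left.kind === right.kind) return left; // number + number, string + string
      if (!allowCoercion) {
        throw new Error(`Cannot add ${left.kind} and ${right.kind}`);
      }
      return { kind: "string" }; // mixed operands, so one side must be a string
    }
  }
}

// Small constructors to keep the examples readable.
const num = (value: number): Expr => ({ kind: "number", value });
const str = (value: string): Expr => ({ kind: "string", value });
const add = (left: Expr, right: Expr): Expr => ({ kind: "add", left, right });

console.log(infer(add(str("a"), str("b")), false).kind); // string
console.log(infer(add(num(1), str("b")), true).kind); // string
```

Switching from check to infer for the operands is the key move: once the result type depends on what the operands are, we can no longer push a single expected type down into them.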

Making it Feel Real

I know for a lot of people this might all seem a bit abstract. So here is a very quick, simple proof of concept that uses the same strategy as above for a subset of TypeScript syntax (it does not try to recreate the TypeScript semantics for types).

If you play with this, I'm sure you will find bugs. You will find features that aren't supported. But you will also see the beginnings of a reasonable type checker. (It does a bit more than the one above, mainly multiple arguments and binary operators, because otherwise the demos would be lame.)
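The demo's source isn't reproduced here, so this is only a hedged sketch of how multiple arguments might fit the same pattern. The `args` field on function types, `params` on function expressions, and `args` on calls are our own assumptions. A call infers the callee, checks the arity, then checks each argument against its parameter type; a function literal is still only checked, never inferred.

```typescript
type Type =
  | { kind: "number" }
  | { kind: "string" }
  | { kind: "function"; args: Type[]; returnType: Type };

type Expr =
  | { kind: "number"; value: number }
  | { kind: "string"; value: string }
  | { kind: "varLookup"; name: string }
  | { kind: "function"; params: string[]; body: Expr }
  | { kind: "call"; fn: Expr; args: Expr[] };

type Context = Map<string, Type>;

function typesEqual(a: Type, b: Type): boolean {
  if (a.kind === "function" && b.kind === "function") {
    const bArgs = b.args;
    return (
      a.args.length === bArgs.length &&
      a.args.every((arg, i) => typesEqual(arg, bArgs[i])) &&
      typesEqual(a.returnType, b.returnType)
    );
  }
  return a.kind === b.kind;
}

function infer(ctx: Context, expr: Expr): Type {
  switch (expr.kind) {
    case "number":
      return { kind: "number" };
    case "string":
      return { kind: "string" };
    case "varLookup": {
      const type = ctx.get(expr.name);
      if (!type) throw new Error(`Unbound variable: ${expr.name}`);
      return type;
    }
    case "call": {
      const fnType = infer(ctx, expr.fn);
      if (fnType.kind !== "function") throw new Error("Cannot call non-function");
      if (fnType.args.length !== expr.args.length) {
        throw new Error(`Expected ${fnType.args.length} arguments, got ${expr.args.length}`);
      }
      // Each argument is checked against the matching parameter type.
      for (let i = 0; i < expr.args.length; i++) check(ctx, expr.args[i], fnType.args[i]);
      return fnType.returnType;
    }
    case "function":
      throw new Error("Cannot infer type for function without annotation");
  }
}

function check(ctx: Context, expr: Expr, expected: Type): void {
  if (expr.kind === "function") {
    if (expected.kind !== "function" || expected.args.length !== expr.params.length) {
      throw new Error("Function does not match expected function type");
    }
    // Bind every parameter to its expected type, then check the body.
    const newCtx: Context = new Map(ctx);
    for (let i = 0; i < expr.params.length; i++) newCtx.set(expr.params[i], expected.args[i]);
    check(newCtx, expr.body, expected.returnType);
    return;
  }
  const actual = infer(ctx, expr);
  if (!typesEqual(actual, expected)) {
    throw new Error(`Type mismatch: expected ${JSON.stringify(expected)}, got ${JSON.stringify(actual)}`);
  }
}

const numT: Type = { kind: "number" };
const strT: Type = { kind: "string" };
// Hypothetical built-in: pick(n: number, s: string): number
const ctx: Context = new Map<string, Type>([
  ["pick", { kind: "function", args: [numT, strT], returnType: numT }],
]);

const pickCall: Expr = {
  kind: "call",
  fn: { kind: "varLookup", name: "pick" },
  args: [{ kind: "number", value: 1 }, { kind: "string", value: "a" }],
};
console.log(infer(ctx, pickCall).kind); // number
```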

But the real takeaway here, I hope, is just how straightforward type checking can be. If you see some literal, you can infer its type. If you have a variable, you can look up its type. If you have a type annotation, you can infer the type of the value and check it against that annotation. I have found that following this formula makes it quite easy to add more and more features.