The state monad and parameterised state

Functional Programming with OCaml

Module 8 · Lecture 3

KC Sivaramakrishnan
IIT Madras

The previous lectures built monads for failure (option, result) and non-determinism (the list monad). Those are two instances of a much larger pattern. A monad is a reusable way to simulate an effect in otherwise pure code: fix the carrier type 'a t, and the one return / bind / let* interface starts behaving like that effect. State, concurrency, probabilistic choice, logging, even reversible computation can each be packaged as a monad. This lecture builds the canonical one, the state monad, which threads a piece of ambient state through a chain of steps: each step receives the current state and returns a new one, with no mutation. We write it once as a functor parameterised by the state type, behind an interface STATE_MONAD that extends the MONAD of the previous lecture with two operations, get and set. At the end we let the state's type itself change from step to step, the parameterised state monad, which is the first bridge to GADTs.

One pattern, many effects

This lecture

Threading state, not mutating it

A stateful computation producing an 'a is a function from the current state to a pair of (value, new state): state -> 'a * state. This is the carrier 'a t of the state monad, and it plays exactly the role that 'a option played in the option monad and 'a list in the list monad: each monad is "a value of shape 'a t plus return and bind", and here the shape happens to be a function rather than a data structure. There is no mutable cell anywhere; each step takes the state in and hands the next state out, and bind is what lines the output of one step up with the input of the next.
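To make the shape state -> 'a * state concrete before any monad machinery appears, here is a minimal sketch that threads an int counter by hand. The names fresh and three are illustrative assumptions, not part of the lecture's code; the point is only that every step takes a state in and hands a state out.

```ocaml
(* Hand-threaded state: each step has the shape state -> value * state. *)
type state = int

(* One step: produce the current counter, hand on counter + 1. *)
let fresh : state -> int * state = fun s -> (s, s + 1)

(* Chaining by hand: the state leaving one step feeds the next one. *)
let three : state -> (int * int * int) * state = fun s0 ->
  let (a, s1) = fresh s0 in
  let (b, s2) = fresh s1 in
  let (c, s3) = fresh s2 in
  ((a, b, c), s3)

let result = three 1   (* ((1, 2, 3), 4) *)
```

The numbered names s0, s1, s2, s3 are exactly the plumbing that bind takes over: it lines up the state leaving one step with the state entering the next, so the intermediate names disappear from user code.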

State without mutation

We saw in the previous lecture that every monad implements one MONAD interface. The state monad needs two state operations beyond return and bind, get and set, plus a run function that executes a computation from an initial state, so its interface extends MONAD:

STATE_MONAD: the interface

module type MONAD = sig
  type 'a t
  val return : 'a -> 'a t
  val bind : 'a t -> ('a -> 'b t) -> 'b t
end
module type STATE_MONAD = sig
  type state
  include MONAD                 (* type 'a t, return, bind *)
  val get : state t             (* read the state *)
  val set : state -> unit t     (* overwrite the state *)
  val run : 'a t -> state -> 'a * state
end

The state is always some fixed type, but which type depends on the program: an int counter here, a record elsewhere. That is exactly what functors are for. We write the implementation once, parameterised by a module that supplies the state type:

State(S): one functor, any state type

module State (S : sig type t end) : STATE_MONAD with type state = S.t = struct
  type state = S.t
  type 'a t = state -> 'a * state
  let return x = fun s -> (x, s)
  let bind m f = fun s -> let (a, s') = m s in (f a) s'
  let get = fun s -> (s, s)
  let set s' = fun _ -> ((), s')
  let run m s = m s
end

The result signature STATE_MONAD with type state = S.t keeps state visible (callers know it is int, say) while leaving 'a t abstract, just as Map.Make exposed key but hid the tree. That abstraction is the point: outside the functor the representation state -> 'a * state is invisible, and the only handles on the state are get and set.
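The effect of that signature can be checked on a small instance. The sketch below repeats the interface and functor so that it runs on its own, then instantiates State at string; StrState, r1, and r2 are hypothetical names chosen for illustration.

```ocaml
module type MONAD = sig
  type 'a t
  val return : 'a -> 'a t
  val bind : 'a t -> ('a -> 'b t) -> 'b t
end

module type STATE_MONAD = sig
  type state
  include MONAD
  val get : state t
  val set : state -> unit t
  val run : 'a t -> state -> 'a * state
end

module State (S : sig type t end) : STATE_MONAD with type state = S.t = struct
  type state = S.t
  type 'a t = state -> 'a * state
  let return x = fun s -> (x, s)
  let bind m f = fun s -> let (a, s') = m s in (f a) s'
  let get = fun s -> (s, s)
  let set s' = fun _ -> ((), s')
  let run m s = m s
end

(* StrState is a hypothetical instance, used only for illustration. *)
module StrState = State (struct type t = string end)

(* state is visible: run accepts a plain string as the initial state. *)
let r1 = StrState.run StrState.get "hello"         (* ("hello", "hello") *)
let r2 = StrState.run (StrState.set "new") "old"   (* ((), "new") *)

(* The carrier is abstract: writing StrState.get "hello" is rejected by the
   type checker, because StrState.get has type string StrState.t and that
   type is not known to be a function outside the functor. *)
```

Going through run is therefore the only way to execute a computation, and get and set are the only ways to touch the state from inside one.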

A worked example: gensym

A gensym hands out a fresh symbol each time it is called: x_1, x_2, x_3. The state holds "the next number to use". We instantiate State at int, open the result so get, set, return, and run are in scope, and define let* as bind:

gensym

module IntState = State (struct type t = int end)
open IntState
let ( let* ) = bind
let gensym prefix : string t =
  let* n = get in
  let* () = set (n + 1) in
  return (prefix ^ "_" ^ string_of_int n)
let program =
  let* a = gensym "x" in
  let* b = gensym "x" in
  let* c = gensym "y" in
  return (a, b, c)
let _ = run program 1   (* = (("x_1", "x_2", "y_3"), 4) *)

Read the current counter, increment it, return a fresh name.

State starts at 1, ends at 4 (three calls used 1, 2, 3; 4 is the next available). The key thing: the user-facing code never mentions the counter, and it could not, since 'a t is abstract. No ref, no incr; the state is implicit in the let* plumbing, threaded by bind. And gensym is itself the domain helper that hides get and set: callers write gensym "x", never get or set directly. Wrapping the raw state operations in a meaningfully named function (gensym, next_token, read_byte) is the usual way to use the monad.
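The same wrapping idea can be sketched for one of the other helper names mentioned above. The block below is an assumption-laden illustration, not the lecture's code: it takes the state to be a list of unread tokens, and next_token is a hypothetical helper that pops the head of that list. The functor is repeated in compact form (with an inline signature) so the block runs on its own.

```ocaml
module State (S : sig type t end) : sig
  type 'a t
  val return : 'a -> 'a t
  val bind : 'a t -> ('a -> 'b t) -> 'b t
  val get : S.t t
  val set : S.t -> unit t
  val run : 'a t -> S.t -> 'a * S.t
end = struct
  type 'a t = S.t -> 'a * S.t
  let return x = fun s -> (x, s)
  let bind m f = fun s -> let (a, s') = m s in (f a) s'
  let get = fun s -> (s, s)
  let set s' = fun _ -> ((), s')
  let run m s = m s
end

(* Hypothetical instance: the state is the list of unread tokens. *)
module Tokens = State (struct type t = string list end)
open Tokens
let ( let* ) = bind

(* next_token hides get and set: it returns the head of the input,
   if any, and leaves the rest as the new state. *)
let next_token : string option t =
  let* input = get in
  match input with
  | [] -> return None
  | tok :: rest ->
      let* () = set rest in
      return (Some tok)

(* Callers never mention the token list. *)
let first_two =
  let* a = next_token in
  let* b = next_token in
  return (a, b)

let result = run first_two ["let"; "x"; "="]
(* ((Some "let", Some "x"), ["="]) *)
```

As with gensym, the caller of first_two sees only next_token; the list of unread tokens appears once, as the argument to run.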

State monad versus ref

The ref version of gensym is shorter:

What you buy versus ref

let counter = ref 1
let gensym_ref prefix =
  let n = !counter in
  counter := n + 1;
  prefix ^ "_" ^ string_of_int n

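Short and easy to write. But the counter is a single global cell: every caller shares it, nothing in a caller's type reveals that it touches state, and tests cannot reset it without reaching into the cell by hand.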

The "tests cannot reset" point is worth seeing concretely. Take two tests, one that increments once and one that increments twice, each meant to start from 0. In the state monad each test is just a value we run from a fresh 0, so they cannot interfere:

Test isolation: run from a fresh state

let incr =
  let* n = get in
  set (n + 1)
let test1 = incr        (* 1 increment *)
let test2 =
  let* () = incr in
  incr                  (* 2 increments *)
let _ = run test1 0     (* = ((), 1) *)
let _ = run test2 0     (* = ((), 2) *)

The shared-ref version has no such reset. Both tests read and write the same global cell (the hazard gensym_ref already has), so the second test's result depends on whether the first one ran:

...but a shared ref leaks between tests

let counter = ref 0
let incr_ref () =
  counter := !counter + 1;
  !counter
let test1_ref () = incr_ref ()   (* wants 1 *)
let test2_ref () =
  let _ = incr_ref () in
  incr_ref ()                    (* wants 2 *)
let _ = test1_ref ()   (* = 1; counter is now 1 *)
let _ = test2_ref ()   (* = 3, not 2: counter was never reset to 0 *)

Pick by scale. For a one-off counter, use ref. For a whole module of state-threading computations, the monad pays off: the state shows up in the types, and each step can be reasoned about locally. For parallel code, the monad is safer, since there is no shared cell to race on, but it is more work to write. Ask whether you want the state to be a value (visible in types, threaded by the monad) or a side effect (invisible, mutated in place). Both are legitimate.

The same State functor, instantiated at a different state type, becomes a different tool: a PRNG (state is the seed), a parser (state is the unread input), a type checker (state is the environment plus a fresh-variable counter). The interface, return, bind, let*, get, set, run, never changes; only state varies, which is exactly why it is a functor.
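As one illustration of "same functor, different tool", here is a minimal sketch of the PRNG reading, where the state is the seed. Everything specific in it is an assumption made for the example: the module name Rng, the helpers next_int and roll, and the toy linear congruential constants, which are not suitable for real randomness. The functor is repeated in compact form so the block runs on its own.

```ocaml
module State (S : sig type t end) : sig
  type 'a t
  val return : 'a -> 'a t
  val bind : 'a t -> ('a -> 'b t) -> 'b t
  val get : S.t t
  val set : S.t -> unit t
  val run : 'a t -> S.t -> 'a * S.t
end = struct
  type 'a t = S.t -> 'a * S.t
  let return x = fun s -> (x, s)
  let bind m f = fun s -> let (a, s') = m s in (f a) s'
  let get = fun s -> (s, s)
  let set s' = fun _ -> ((), s')
  let run m s = m s
end

(* Hypothetical instance: the state is the generator's seed. *)
module Rng = State (struct type t = int end)
open Rng
let ( let* ) = bind

(* One step of a toy linear congruential generator. The constants are
   illustrative; the mask keeps the seed non-negative. *)
let next_int : int t =
  let* seed = get in
  let seed' = (seed * 214013 + 2531011) land 0x3FFFFFFF in
  let* () = set seed' in
  return seed'

(* A die roll in 1..6, built only from next_int. *)
let roll : int t =
  let* n = next_int in
  return (1 + (n lsr 16) mod 6)

let two_rolls =
  let* a = roll in
  let* b = roll in
  return (a, b)

(* Same seed in, same rolls out: the generator is an ordinary pure value. *)
let first = run two_rolls 42
let again = run two_rolls 42
```

Because the seed is threaded rather than mutated, rerunning two_rolls from the same seed reproduces the same pair, which is the property the test-isolation example relied on.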

When the state's type should change

The State functor fixes one state type for the whole computation: IntState threads an int from start to finish, and 'a t is int -> 'a * int throughout. But sometimes the state's type should change between steps. Imagine a small stack machine: push 5 turns a stack of shape 's into one of shape int * 's; add turns int * (int * 's) into int * 's. The state's type is the running shape of the stack. The fix is to give the carrier two state indices, a 'pre and a 'post, and bundle the result behind its own interface, just as we did for STATE_MONAD:

PSTATE_MONAD: a parameterised monad

module type PSTATE_MONAD = sig
  type ('pre, 'post, 'a) t
  val return : 'a -> ('s, 's, 'a) t
  val bind : ('p, 'q, 'a) t -> ('a -> ('q, 'r, 'b) t) -> ('p, 'r, 'b) t
  val get : ('s, 's, 's) t
  val set : 'post -> ('pre, 'post, unit) t
  val run : ('pre, 'post, 'a) t -> 'pre -> 'a * 'post
end

Read bind's indices the way you would a relay race: the first step carries the state from 'p to 'q, hands the baton to the continuation, which carries it from 'q to 'r, and the composite runs from 'p to 'r. The middle type 'q has to match exactly, and that handover check is what will let the compiler chain preconditions step by step.

The implementation is a plain module, not a functor: there is no single state type to parameterise over, since the type changes per operation. Otherwise it mirrors State exactly, abstract carrier and all: get and set are again the only state-aware operations.

PState: the implementation

module PState : PSTATE_MONAD = struct
  type ('pre, 'post, 'a) t = 'pre -> 'a * 'post
  let return x = fun s -> (x, s)
  let bind m f = fun s -> let (a, s') = m s in (f a) s'
  let get = fun s -> (s, s)
  let set s' = fun _ -> ((), s')
  let run m s = m s
end
open PState
let ( let* ) = bind
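Before the stack machine, the relay-race reading of bind can be tried on a smaller case where the state changes from int to string and back to int. This is a sketch under assumptions: render, measure, and both are hypothetical steps invented for the example, and the module is repeated with an inline signature so the block runs on its own.

```ocaml
module PState : sig
  type ('pre, 'post, 'a) t
  val return : 'a -> ('s, 's, 'a) t
  val bind : ('p, 'q, 'a) t -> ('a -> ('q, 'r, 'b) t) -> ('p, 'r, 'b) t
  val get : ('s, 's, 's) t
  val set : 'post -> ('pre, 'post, unit) t
  val run : ('pre, 'post, 'a) t -> 'pre -> 'a * 'post
end = struct
  type ('pre, 'post, 'a) t = 'pre -> 'a * 'post
  let return x = fun s -> (x, s)
  let bind m f = fun s -> let (a, s') = m s in (f a) s'
  let get = fun s -> (s, s)
  let set s' = fun _ -> ((), s')
  let run m s = m s
end
open PState
let ( let* ) = bind

(* First leg: the state goes from int to string. *)
let render : (int, string, unit) t =
  let* n = get in
  set (string_of_int n)

(* Second leg: the state goes from string to int, and the length is
   also returned as the value. *)
let measure : (string, int, int) t =
  let* s = get in
  let len = String.length s in
  let* () = set len in
  return len

(* The handover type string matches, so the composite runs int to int. *)
let both : (int, int, int) t =
  let* () = render in
  measure

let result = run both 12345   (* (5, 5) *)

(* Chaining render after render is rejected by the compiler: the first
   leg ends at string, but the second leg starts at int. *)
```

The annotation on both is not needed for type checking; it is written out so the 'p and 'r of the composite can be read off directly.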

A well-typed stack machine

Now the payoff use case. We want a stack machine where the type tracks the shape of the stack, so that add can run only when there really are two ints on top, and a malformed program is rejected by the compiler instead of crashing at runtime. We get this directly from the PState monad: take the state to be the stack itself, and each instruction becomes a PState.t whose 'pre and 'post indices record how it reshapes the stack.

Goal: ill-typed programs don't compile

A stack is a nested pair with unit at the bottom, so its type records the whole shape. Each instruction is built from get and set, never by poking the representation:

Stack instructions, typed by shape

let push (x : 'a) : ('s, 'a * 's, unit) PState.t =
  let* s = get in
  set (x, s)
let add : (int * (int * 's), int * 's, unit) PState.t =
  let* (x, (y, s)) = get in
  set (x + y, s)

push x always succeeds: any stack accepts a value on top. add is fussier: its type int * (int * 's) demands at least two ints on top.

One piece of fine print about that 's. push is a function, but add is a bind application, so the value restriction leaves its 's only weakly polymorphic (the toplevel reports '_s). Every program in this lecture uses add on stacks with the same tail type, so nothing notices; if one session reused it at two different stack-tail types, OCaml would complain at the second use. The fix is the usual one, making it a function: let add () = ..., applied as add () (the same applies to any instruction defined this way).
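The function form of add can be sketched as follows. The block repeats PState and push so that it runs on its own; the two programs on_empty and on_string are hypothetical, chosen so that add () is used at two different stack-tail types in the same session.

```ocaml
module PState : sig
  type ('pre, 'post, 'a) t
  val return : 'a -> ('s, 's, 'a) t
  val bind : ('p, 'q, 'a) t -> ('a -> ('q, 'r, 'b) t) -> ('p, 'r, 'b) t
  val get : ('s, 's, 's) t
  val set : 'post -> ('pre, 'post, unit) t
  val run : ('pre, 'post, 'a) t -> 'pre -> 'a * 'post
end = struct
  type ('pre, 'post, 'a) t = 'pre -> 'a * 'post
  let return x = fun s -> (x, s)
  let bind m f = fun s -> let (a, s') = m s in (f a) s'
  let get = fun s -> (s, s)
  let set s' = fun _ -> ((), s')
  let run m s = m s
end
open PState
let ( let* ) = bind

let push (x : 'a) : ('s, 'a * 's, unit) PState.t =
  let* s = get in
  set (x, s)

(* add as a function of unit: it is now a syntactic function, so the
   stack-tail variable is generalised and each call can pick its own tail. *)
let add () : (int * (int * 's), int * 's, unit) PState.t =
  let* (x, (y, s)) = get in
  set (x + y, s)

(* Stack tail is unit. *)
let on_empty : (unit, int * unit, unit) PState.t =
  let* () = push 4 in
  let* () = push 5 in
  add ()

(* Stack tail is string * unit: a second, different instantiation. *)
let on_string : (unit, int * (string * unit), unit) PState.t =
  let* () = push "bottom" in
  let* () = push 1 in
  let* () = push 2 in
  add ()

let r1 = run on_empty ()    (* ((), (9, ())) *)
let r2 = run on_string ()   (* ((), (3, ("bottom", ()))) *)
```

The cost is the extra () at each use; the benefit is that the instruction is polymorphic in the stack tail instead of being fixed by its first use.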

Run two pushes and an add and the types line up:

A well-typed program

let prog =
  let* () = push 4 in
  let* () = push 5 in
  add
let _ = run prog ()   (* = ((), (9, ())) *)

Push a bool instead, and add can no longer apply. The mismatch is a compile error, not a runtime one:

An ill-typed program is rejected at compile time

let bad_prog =
  let* () = push 4 in
  let* () = push true in   (* stack: bool * (int * unit) *)
  add                      (* add wants int * (int * 's) *)
Error: This expression has type
         (bool * (int * unit), 'a, 'b) PState.t
       but an expression was expected of type
         (int * (int * 'c), 'd, 'e) PState.t
       Type bool is not compatible with type int

This is the payoff. "A stack machine that needs two ints on top to add" is a constraint that lives in the type of the operation, and the compiler enforces it before any code runs.

This is not a toy. It is essentially how WebAssembly bytecode is typed. A Wasm module is validated before it runs by tracking the types on the operand stack, instruction by instruction: every instruction is specified with a stack type, and i32.add has exactly the type our add does, [i32 i32] -> [i32], popping two i32s and pushing one. The validator walks the function body keeping the operand-stack type up to date and rejects the module if any instruction does not find the shape it needs, exactly as the OCaml compiler rejects bad_prog. Our 'pre and 'post indices are the operand-stack type before and after a step. The same shape turns up in session types (a client cannot send before it connects) and in typed builders.

Not a toy: this is how WebAssembly is typed

Bridge to GADTs

The stack machine is one step short of a GADT: the state type is a witness for the shape of the stack at this point. The next lecture makes the pattern first-class, with constructors that carry such witnesses inline and pattern matching that refines them.

A quick check

After run program 1 in the gensym example the result is (("x_1", "x_2", "y_3"), 4). Why does the state end at 4 and not 3?

Why: each call reads the current n, sets the state to n + 1, and produces a name using n. After producing y_3 the state was set to 4. The final state is the "next available", not the "last used".

Why does the ill-typed let* () = push true in add fail at compile time rather than runtime?

Why: parameterised state encodes each operation's precondition in its type. add says "I take a state shaped int * (int * 's)", so the compiler rejects any preceding chain that does not produce such a state. No runtime check; the error is caught at compile time.

Activity

Extend the stack machine with dup, which duplicates the top of the stack: it takes an 'a * 's state and produces an 'a * ('a * 's) state. Then write a program that pushes 7, dups, and adds (top becomes 14). Do not use a ref; thread the state with let*.

Activity solution: dup

PState, push, add, run, and let* are from earlier in this lecture. dup is the one new instruction:

let dup : ('a * 's, 'a * ('a * 's), unit) PState.t =
  let* (x, s) = get in
  set (x, (x, s))
  • get reads the top x; set (x, (x, s)) puts two copies back.
  • Input stack 'a * 's, output 'a * ('a * 's): one element in, two out.

Activity solution: the program

let prog =
  let* () = push 7 in
  let* () = dup in
  add
let _ = run prog ()   (* = ((), (14, ())) *)
  • push 7 then dup leaves two 7s on top: int * (int * unit).
  • Both copies are ints, so add applies and gives 14.

A code quiz on the plain state monad:

Write incr_state : unit state that increments the state by 1 and produces (). Use get and set. (This is the bare state monad, before the functor packaging.)

type 'a state = int -> 'a * int
let return x : 'a state = fun s -> (x, s)
let bind (m : 'a state) (f : 'a -> 'b state) : 'b state =
  fun s -> let (a, s') = m s in (f a) s'
let ( let* ) = bind
let get : int state = fun s -> (s, s)
let set new_s : unit state = fun _ -> ((), new_s)
let run (m : 'a state) (s : int) : 'a * int = m s
let incr_state : unit state = fun _ -> failwith "not implemented"

Reference solution:

let incr_state =
  let* n = get in
  set (n + 1)

Read the state into n, then set it to n + 1. The value produced by set (n + 1) is (), which is exactly what incr_state should produce, so no final return is needed.

What is next

Lecture 4: GADTs, the second half of the module.

The next lecture starts the GADT half. The parameterised-state pattern reappears there in rigorous form: GADT constructors carry type witnesses, pattern matching refines them, and the compiler tracks state-like information through expressions naturally.

Reading

Sources

This lecture's prose, worked examples, and quizzes are original to this course. The state-monad and parameterised-state framing draw on the author's CS3100 monads notebook, used here as a private structural reference; the surface code, comments, and explanations are written from scratch. Cornell CS3110 and Real World OCaml are CC BY-NC-ND-licensed and have not been derivatively reused. See LICENSES.md at the repository root for the full source posture.