Controlling mutation with types

A couple of weeks ago I spent some time pair-programming with Penelope Phippen on her Ruby auto-formatter, rubyfmt. This codebase is largely written in Rust and does a lot of work transforming abstract syntax trees to determine how your Ruby code should be formatted. As such, it deals with recursion and stack structures, and this presented an interesting opportunity to get the type system to enforce some business rules for us.

To illustrate the problem we were solving, I’m going to pick an equivalent problem that’s a bit more self-contained and won’t require explaining all the internals of rubyfmt. Imagine we’d like to parse the following string into a data structure of nested lists of digits.

    [3, [1, 4], [1, [5, 9, [2]], 6]]

There are many possible approaches to doing this, and one of them is to track the state using a stack. To parse the string, we use an iterator that yields each char from the string; I'll indicate its current position with a caret (^) below. When the parser sees a left bracket ([), it pushes a new list [] onto the stack, so our state on reading the first character is:

          [3, [1, 4], [1, [5, 9, [2]], 6]]
          ^

    stack: []

When we see a digit, we push that digit onto the list that’s at the top of the stack, so that first 3 is added to the list that we just put on the stack.

          [3, [1, 4], [1, [5, 9, [2]], 6]]
           ^

    stack: [3]

Then we see another left bracket, so we push another list [] on the stack. The stack now has [] at the top, and the list [3] underneath.

          [3, [1, 4], [1, [5, 9, [2]], 6]]
              ^

    stack: []
           [3]

We keep following these rules, processing digits and brackets. We push the 1 and the 4 onto the topmost list:

          [3, [1, 4], [1, [5, 9, [2]], 6]]
                  ^

    stack: [1, 4]
           [3]

And then we see a right bracket (]) denoting the end of the current list. We handle this by popping the topmost list off the stack, and then appending it to the list that’s underneath it. That is, we pop the list [1, 4] off the stack and push this onto the list [3], giving us [3, [1, 4]].

          [3, [1, 4], [1, [5, 9, [2]], 6]]
                   ^

    stack: [3, [1, 4]]

Continuing these rules up to the digit 2 gives us the following stack state:

          [3, [1, 4], [1, [5, 9, [2]], 6]]
                                  ^

    stack: [2]
           [5, 9]
           [1]
           [3, [1, 4]]

We then see two right brackets. Processing the first one means we pop [2] off the stack and append it to [5, 9], giving us [5, 9, [2]] on top of the stack.

          [3, [1, 4], [1, [5, 9, [2]], 6]]
                                   ^

    stack: [5, 9, [2]]
           [1]
           [3, [1, 4]]

Handling the second bracket means we pop the list [5, 9, [2]] and append it to [1], giving [1, [5, 9, [2]]].

          [3, [1, 4], [1, [5, 9, [2]], 6]]
                                    ^

    stack: [1, [5, 9, [2]]]
           [3, [1, 4]]

Following through to the bracket after the digit 6 gets us back to a state where we have a single list on the stack, and hitting the final right bracket gives the final result.

          [3, [1, 4], [1, [5, 9, [2]], 6]]
                                         ^

    result: [3, [1, 4], [1, [5, 9, [2]], 6]]
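Because the walkthrough keeps its own explicit stack, the rules can be written as a single loop with no recursion. The sketch below is one way to do that, assuming the same Expr shape used in the rest of this post; parse_iterative is a hypothetical name, and unlike the walkthrough it gives up (returning None) on unbalanced brackets or on a digit that appears outside any list.

```rust
#[derive(Debug, PartialEq)]
enum Expr {
    List(Vec<Expr>),
    Digit(u32),
}

// Hypothetical loop-based version of the stack rules from the walkthrough.
fn parse_iterative(input: &str) -> Option<Expr> {
    // Each entry is a list that is still being filled in.
    let mut stack: Vec<Vec<Expr>> = Vec::new();

    for c in input.chars() {
        match c {
            // Left bracket: start a new, empty list on top of the stack.
            '[' => stack.push(Vec::new()),
            // Digit: append it to the list on top of the stack
            // (gives up if there is no open list).
            '0'..='9' => {
                let digit = c.to_digit(10)?;
                stack.last_mut()?.push(Expr::Digit(digit));
            }
            // Right bracket: pop the finished list and append it to the list
            // underneath, or return it if it was the outermost list.
            ']' => {
                let list = Expr::List(stack.pop()?);
                match stack.last_mut() {
                    Some(parent) => parent.push(list),
                    None => return Some(list),
                }
            }
            // Commas and spaces are ignored.
            _ => {}
        }
    }
    // The input ended before the outermost list was closed.
    None
}

fn main() {
    let parsed = parse_iterative("[3, [1, 4], [1, [5, 9, [2]], 6]]");
    println!("{:?}", parsed);
}
```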

Having seen the approach in the abstract, let’s turn this into some Rust code. I’ll use this Expr enum to model the generated structure; it’s either a List containing a Vec of nested Expr values, or it’s a Digit holding a number.

enum Expr {
    List(Vec<Expr>),
    Digit(u32),
}
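To make the target structure concrete, here is the example input built by hand as an Expr value. The render function is a hypothetical helper, not part of the original code, that prints an Expr back in the bracket notation used above; it can be handy for eyeballing parser output. #[derive(Debug, PartialEq)] is added only so values can be compared.

```rust
#[derive(Debug, PartialEq)]
enum Expr {
    List(Vec<Expr>),
    Digit(u32),
}

// Hypothetical helper: turns an Expr back into bracket notation, e.g. "[1, [2]]".
fn render(expr: &Expr) -> String {
    match expr {
        Expr::Digit(d) => d.to_string(),
        Expr::List(items) => {
            let parts: Vec<String> = items.iter().map(render).collect();
            format!("[{}]", parts.join(", "))
        }
    }
}

fn main() {
    use Expr::{Digit, List};
    // The example input [3, [1, 4], [1, [5, 9, [2]], 6]] built by hand.
    let example = List(vec![
        Digit(3),
        List(vec![Digit(1), Digit(4)]),
        List(vec![
            Digit(1),
            List(vec![Digit(5), Digit(9), List(vec![Digit(2)])]),
            Digit(6),
        ]),
    ]);
    println!("{}", render(&example)); // prints [3, [1, 4], [1, [5, 9, [2]], 6]]
}
```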

The Expr::from_str function will parse a string and return a Vec of all the expressions it found. We could return a single Expr, but returning a Vec tidies up the design and is actually closer to the rubyfmt problem we were dealing with. It begins by creating the stack, which at first holds a single empty list; the outer Vec in vec![vec![]] is the stack, and the inner Vec is the first list on the stack. We pass this stack and a Chars iterator into parse_expr to process the string, and when that returns we pop the top list off the stack. If all the brackets in the string are balanced, the stack should contain only that one list by this point.

impl Expr {
    fn from_str(string: &str) -> Vec<Expr> {
        let mut stack = vec![vec![]];
        parse_expr(&mut stack, &mut string.chars());
        stack.pop().unwrap()
    }
}

Now we get to the meat of the parser. This is grossly over-simplified just to illustrate the idea – it doesn’t deal with all characters, or handle any syntax errors. It consumes characters from chars; on left brackets it calls push_list, on digits it calls push_digit, and on right brackets it breaks out of the loop and returns. All other characters are ignored.

use std::str::Chars;

fn parse_expr(stack: &mut Vec<Vec<Expr>>, chars: &mut Chars) {
    while let Some(c) = chars.next() {
        match c {
            '[' => push_list(stack, chars),
            '0'..='9' => push_digit(stack, c),
            ']' => break,
            _ => {}
        }
    }
}

push_list is effectively the recursive part of the algorithm, and the root of the design problem we’re concerned with. It starts by pushing a new Vec on the stack, and then recurses into parse_expr with the extended stack. When parse_expr returns – because it hit a right bracket or ran out of characters – we pop the list off the top of the stack and push it onto the list below by calling push_down.

fn push_list(stack: &mut Vec<Vec<Expr>>, chars: &mut Chars) {
    stack.push(vec![]);
    parse_expr(stack, chars);
    let list = stack.pop().unwrap();

    push_down(stack, Expr::List(list));
}

The push_down function deals with pushing an Expr onto the list on the top of the stack. We use Vec::last_mut to get a mutable reference to the top list on the stack, but because a Vec may be empty, this returns Option<&mut Vec<Expr>> and so we need to pattern-match it to get the actual reference out and modify the list it points at.

fn push_down(stack: &mut Vec<Vec<Expr>>, value: Expr) {
    if let Some(list) = stack.last_mut() {
        list.push(value);
    }
}

And to round things off, push_digit uses char::to_digit to convert a character into a number, and pushes it onto the list at the top of the stack.

fn push_digit(stack: &mut Vec<Vec<Expr>>, c: char) {
    let value = Expr::Digit(c.to_digit(10).unwrap());

    push_down(stack, value);
}
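For reference, the functions above can be assembled into a single program. This sketch simply combines them as written, adding #[derive(Debug, PartialEq)] to Expr and a main function so the result can be printed and compared.

```rust
use std::str::Chars;

#[derive(Debug, PartialEq)]
enum Expr {
    List(Vec<Expr>),
    Digit(u32),
}

impl Expr {
    fn from_str(string: &str) -> Vec<Expr> {
        // The outer Vec is the stack; the inner Vec is the first list on it.
        let mut stack = vec![vec![]];
        parse_expr(&mut stack, &mut string.chars());
        stack.pop().unwrap()
    }
}

fn parse_expr(stack: &mut Vec<Vec<Expr>>, chars: &mut Chars) {
    // `while let` rather than `for`, so `chars` can be lent to push_list.
    while let Some(c) = chars.next() {
        match c {
            '[' => push_list(stack, chars),
            '0'..='9' => push_digit(stack, c),
            ']' => break,
            _ => {}
        }
    }
}

fn push_list(stack: &mut Vec<Vec<Expr>>, chars: &mut Chars) {
    stack.push(vec![]);
    parse_expr(stack, chars);
    // Assumes parse_expr left the stack the same height it found it.
    let list = stack.pop().unwrap();
    push_down(stack, Expr::List(list));
}

fn push_down(stack: &mut Vec<Vec<Expr>>, value: Expr) {
    // Silently drops `value` if the stack is empty.
    if let Some(list) = stack.last_mut() {
        list.push(value);
    }
}

fn push_digit(stack: &mut Vec<Vec<Expr>>, c: char) {
    let value = Expr::Digit(c.to_digit(10).unwrap());
    push_down(stack, value);
}

fn main() {
    let exprs = Expr::from_str("[3, [1, 4], [1, [5, 9, [2]], 6]]");
    println!("{:?}", exprs);
}
```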

Now we’ve got a working implementation, but there are some clear problems with it. For one, there are a couple of places where it will panic if the stack ends up in an unexpected state. We’re calling stack.pop().unwrap(), expecting that Vec::pop will always give us a value because the stack should not be empty at this point. In push_down we’re being more careful by using a pattern-match rather than Option::unwrap, but there’s still a clear assumption here that the stack is never empty and that we don’t need to handle the None case.

This becomes very hard to guarantee as the program grows. For example, look again at push_list:

fn push_list(stack: &mut Vec<Vec<Expr>>, chars: &mut Chars) {
    stack.push(vec![]);
    parse_expr(stack, chars);
    let list = stack.pop().unwrap();

    push_down(stack, Expr::List(list));
}

This assumes that the stack is the same size before and after parse_expr(), so that stack.pop() always returns a value and that value is the same list we pushed before calling parse_expr(). In rubyfmt, there are many other types of nested syntax to deal with: arrays, hashes, method and block parameters, and method call arguments, all of which can be nested in each other and are combined in different ways in different contexts. rubyfmt deals with this by composing different processing functions using closures. For example, we can make push_list take a closure that decides how to process the list contents, rather than being hard-coded to call parse_expr.

fn push_list<F>(stack: &mut Vec<Vec<Expr>>, chars: &mut Chars, f: F)
where
    F: FnOnce(&mut Vec<Vec<Expr>>, &mut Chars),
{
    stack.push(vec![]);
    f(stack, chars);
    let list = stack.pop().unwrap();

    push_down(stack, Expr::List(list));
}

parse_expr can retain its existing behaviour by passing a call to itself in the closure to push_list:

fn parse_expr(stack: &mut Vec<Vec<Expr>>, chars: &mut Chars) {
    while let Some(c) = chars.next() {
        match c {
            '[' => push_list(stack, chars, |st, ch| parse_expr(st, ch)),
            '0'..='9' => push_digit(stack, c),
            ']' => break,
            _ => {}
        }
    }
}

The problem with push_list is that it’s now impossible to guarantee the call to unwrap is safe here:

    stack.push(vec![]);
    f(stack, chars);
    let list = stack.pop().unwrap();

Looking just at this function, f could make absolutely any change to the &mut Vec<Vec<Expr>> it’s given, including removing all its elements, so how can we know that stack.pop() will still return what we expect? To answer this we’d have to manually review the entire program. Can we do better? Can we lean on the Rust type system to guarantee statically that the stack is never empty, so that we can safely pass any closure in here knowing it can only make safe changes to the stack?
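To see the risk concretely, here is a sketch with a deliberately misbehaving closure. It compiles without complaint, because clearing the stack is a perfectly valid thing to do with a &mut Vec<Vec<Expr>>, yet it makes the unwrap in push_list panic. rogue_closure_panics is a hypothetical helper that uses std::panic::catch_unwind to catch the panic rather than letting it end the program.

```rust
use std::panic;
use std::str::Chars;

#[allow(dead_code)] // Digit is never constructed in this sketch
enum Expr {
    List(Vec<Expr>),
    Digit(u32),
}

fn push_down(stack: &mut Vec<Vec<Expr>>, value: Expr) {
    if let Some(list) = stack.last_mut() {
        list.push(value);
    }
}

// The closure-based push_list from above, unchanged.
fn push_list<F>(stack: &mut Vec<Vec<Expr>>, chars: &mut Chars, f: F)
where
    F: FnOnce(&mut Vec<Vec<Expr>>, &mut Chars),
{
    stack.push(vec![]);
    f(stack, chars);
    let list = stack.pop().unwrap(); // panics if f emptied the stack
    push_down(stack, Expr::List(list));
}

// Hypothetical helper: runs push_list with a closure that clears the whole
// stack, and reports whether that caused a panic.
fn rogue_closure_panics() -> bool {
    let result = panic::catch_unwind(|| {
        let mut stack = vec![vec![]];
        let mut chars = "".chars();
        // Well-typed, but it removes the list push_list just pushed,
        // along with everything underneath it.
        push_list(&mut stack, &mut chars, |st, _| st.clear());
    });
    result.is_err()
}

fn main() {
    // Silence the default panic message; we only want the yes/no answer.
    panic::set_hook(Box::new(|_| {}));
    println!("rogue closure panics: {}", rogue_closure_panics());
}
```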

The answer to this comes in two parts. First, we can change the type we use to represent stacks so that it’s impossible to construct an empty one. Second, we provide an API for that type that only lets callers use it safely.

So to begin, let’s make a better type that cannot be empty. Stack<T> has a head field which must contain a value (it’s not an Option), and a rest field containing a Vec as we’ve been using it above. Stack::new constructs a stack from an initial head value.

#[derive(Default)]
struct Stack<T> {
    head: T,
    rest: Vec<T>,
}

impl<T> Stack<T> {
    fn new(head: T) -> Stack<T> {
        Stack { head, rest: vec![] }
    }
}

Second, we need to think about what functionality our program actually needs from the stack. We don’t want the closure passed to push_list to be able to make arbitrary changes, but for the parser to work, it needs to be able to do two things. It needs to modify the top item on the stack (the push_down function), and it needs a transaction in which a new item is pushed onto the stack, a closure is run, and then the stack is popped and the result returned to the caller. The closure must not be able to make arbitrary changes like pushing and popping the stack; it should only be able to perform this well-defined transaction, which leaves the stack the same size.

In order to mutate the stack, any closure will at least need to be given a &mut Stack<T>. In fact, if we only ever pass this type to the closure (rather than an owned Stack<T>), we can implement a safe interface on &mut Stack<T> and require an owned Stack<T> for the low-level push and pop operations. We’ll look at those operations first and then use them to build the safe API.

Stack::push is going to take an owned Stack<T> and a T and return a Stack<T>, representing the result of pushing the new item onto the stack. We can implement this by mutating the input and returning it to the caller.

use std::mem;

impl<T> Stack<T> {
    fn push(mut self, item: T) -> Stack<T> {
        let head = mem::replace(&mut self.head, item);
        self.rest.push(head);
        self
    }
}

This function uses mem::replace to grab the current head value and replace it with the new one, before pushing the old value onto the rest vector. Strictly speaking we could have implemented this by taking &mut self, but requiring self puts it off-limits if you only have a &mut Stack<T>, and makes it symmetric with the pop function. It’s also nice if you go on to implement an enum to handle both empty and non-empty stacks, because then push can return an owned non-empty stack for both variants.
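A short sketch of Stack::push in use, assuming the Stack type and push method above. The len method is a hypothetical helper, not in the original, that makes the non-empty invariant visible: the head always counts, so the length can never be zero. #[derive(Debug)] is added only so the stack can be printed.

```rust
use std::mem;

#[derive(Debug)]
struct Stack<T> {
    head: T,
    rest: Vec<T>,
}

impl<T> Stack<T> {
    fn new(head: T) -> Stack<T> {
        Stack { head, rest: vec![] }
    }

    fn push(mut self, item: T) -> Stack<T> {
        // Swap the new item into head; the old head moves onto rest.
        let head = mem::replace(&mut self.head, item);
        self.rest.push(head);
        self
    }

    // Hypothetical helper: the head always counts, so this is never 0.
    fn len(&self) -> usize {
        1 + self.rest.len()
    }
}

fn main() {
    // head holds the newest item; rest holds older items, oldest first.
    let stack = Stack::new("a").push("b").push("c");
    println!("{:?}", stack); // prints Stack { head: "c", rest: ["a", "b"] }
    println!("{}", stack.len()); // prints 3
}
```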

Now for the pop function, which is where things get a bit more interesting. We want this to always return the T that’s in the head field, because this stack is never empty. That means moving the current value out of head, which we can only do if we can put another value in its place. And that means we can only do it if there are still items in the rest slot. So an implementation that takes &mut Stack<T> might look like this:

impl<T> Stack<T> {
    fn pop(&mut self) -> T {
        if let Some(item) = self.rest.pop() {
            mem::replace(&mut self.head, item)
        } else {
            panic!("we emptied the stack somehow");
        }
    }
}

This implementation still has to panic if it can’t put the structure in a valid state. But if we write a consuming implementation, we can do something more interesting: return a T, and an Option of a new stack state, if it’s possible to make a valid one.

impl<T> Stack<T> {
    fn pop(mut self) -> (T, Option<Stack<T>>) {
        let head = self.head;

        if let Some(item) = self.rest.pop() {
            self.head = item;
            (head, Some(self))
        } else {
            (head, None)
        }
    }
}

By taking Stack<T> rather than &mut Stack<T>, we can arrange it so that you can’t pop the stack and keep the stack object you’re currently holding: Stack::pop consumes it so you can’t use it any more. If it’s possible to build a valid stack afterward, then you’ll get one back, otherwise you get nothing. Either way, you get the top stack item, but you can never hold an invalid Stack<T> value.
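A sketch of the consuming pop in use, assuming the Stack type, push and consuming pop shown above. The drain function is a hypothetical helper that empties a stack by popping it repeatedly. Because each pop consumes the stack it is called on, the loop can only carry on with the stack that pop hands back, and once that is None there is nothing left to misuse.

```rust
use std::mem;

struct Stack<T> {
    head: T,
    rest: Vec<T>,
}

impl<T> Stack<T> {
    fn new(head: T) -> Stack<T> {
        Stack { head, rest: vec![] }
    }

    fn push(mut self, item: T) -> Stack<T> {
        let head = mem::replace(&mut self.head, item);
        self.rest.push(head);
        self
    }

    fn pop(mut self) -> (T, Option<Stack<T>>) {
        let head = self.head;
        if let Some(item) = self.rest.pop() {
            self.head = item;
            (head, Some(self))
        } else {
            (head, None)
        }
    }
}

// Hypothetical helper: pops every item, newest first.
fn drain<T>(stack: Stack<T>) -> Vec<T> {
    let mut items = Vec::new();
    let mut current = Some(stack);
    // Each pop consumes `s`; the loop continues only with the stack that
    // pop hands back, and stops once there is none.
    while let Some(s) = current {
        let (item, rest) = s.pop();
        items.push(item);
        current = rest;
    }
    items
}

fn main() {
    let stack = Stack::new(1).push(2).push(3);
    println!("{:?}", drain(stack)); // prints [3, 2, 1]
    // `stack` has been moved into drain, so this would not compile:
    // stack.push(4); // error[E0382]: use of moved value: `stack`
}
```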

So how do we build a safe interface out of these consuming functions, if the caller only has a &mut Stack<T>? We want a safe interface for performing an action in between pushing a new item on the stack and popping it off again, where that action cannot do random pushes/pops. So the signature will need to take and return the temporarily pushed value, and take a closure accepting a mutable reference:

    fn with_item<F>(&mut self, item: T, f: F) -> T
    where
        F: FnOnce(&mut Stack<T>)

If this method only receives &mut self, how can it make use of the push and pop functions that take an owned self? Well, we can cheat a little by using mem::take, which replaces self with Stack::default() and returns the previous value, letting us move the existing stack out of self and take ownership of it.

        let stack = mem::take(self);

Because of the #[derive(Default)], Stack<T> implements Default whenever T: Default, which lets us construct an “empty” stack with T::default() as its head. We want a stack of Vec<Expr>, and Vec<T> implements Default for any T, so this is fine.

stack now contains an owned version of self, but since we’ve replaced self with an empty stack, it’s not actually usable now. However, we will make things right before this method returns so that it’s safe again.

The next thing we do is push the new item onto the stack, and call the closure with a &mut to this extended stack. When the closure returns, we call Stack::pop to get the top value and a new stack. head is the value we can return to the caller now that it’s been moved out of the stack.

        let mut pushed_stack = stack.push(item);
        f(&mut pushed_stack);
        let (head, popped_stack) = pushed_stack.pop();

Finally, we need to restore self to hold the real stack, not the “empty” one we temporarily replaced it with. This is only possible if popped_stack is not None, so we still have to panic if something’s gone wrong.

        if let Some(mut stack) = popped_stack {
            mem::swap(self, &mut stack);
        } else {
            panic!("we emptied the stack somehow");
        }

The full Stack::with_item function is shown below.

impl<T> Stack<T> {
    fn with_item<F>(&mut self, item: T, f: F) -> T
    where
        F: FnOnce(&mut Stack<T>),
        T: Default,
    {
        let stack = mem::take(self);

        let mut pushed_stack = stack.push(item);
        f(&mut pushed_stack);
        let (head, popped_stack) = pushed_stack.pop();

        if let Some(mut stack) = popped_stack {
            mem::swap(self, &mut stack);
        } else {
            panic!("we emptied the stack somehow");
        }

        head
    }
}
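A sketch of Stack::with_item in use, assuming the Stack methods defined above, a u32 payload chosen purely for illustration, and a last_mut method that returns &mut self.head. The closure can modify the temporarily pushed top item and can even run a nested with_item transaction, but it is only ever handed a &mut Stack<T>, so it cannot call the consuming push or pop.

```rust
use std::mem;

#[derive(Default)]
struct Stack<T> {
    head: T,
    rest: Vec<T>,
}

impl<T> Stack<T> {
    fn new(head: T) -> Stack<T> {
        Stack { head, rest: vec![] }
    }

    fn push(mut self, item: T) -> Stack<T> {
        let head = mem::replace(&mut self.head, item);
        self.rest.push(head);
        self
    }

    fn pop(mut self) -> (T, Option<Stack<T>>) {
        let head = self.head;
        if let Some(item) = self.rest.pop() {
            self.head = item;
            (head, Some(self))
        } else {
            (head, None)
        }
    }

    // Mutable access to the top item.
    fn last_mut(&mut self) -> &mut T {
        &mut self.head
    }

    fn with_item<F>(&mut self, item: T, f: F) -> T
    where
        F: FnOnce(&mut Stack<T>),
        T: Default,
    {
        let stack = mem::take(self);
        let mut pushed_stack = stack.push(item);
        f(&mut pushed_stack);
        let (head, popped_stack) = pushed_stack.pop();
        if let Some(mut stack) = popped_stack {
            mem::swap(self, &mut stack);
        } else {
            panic!("we emptied the stack somehow");
        }
        head
    }
}

fn main() {
    let mut stack: Stack<Vec<u32>> = Stack::new(vec![1]);
    let inner = stack.with_item(vec![], |s| {
        // Modify the temporarily pushed top item...
        s.last_mut().push(2);
        // ...or run a nested transaction, whose result comes back here.
        let nested = s.with_item(vec![3], |n| n.last_mut().push(4));
        s.last_mut().extend(nested);
    });
    // The pushed list comes back; the original stack is as it was.
    println!("{:?} {:?}", inner, stack.head); // prints [2, 3, 4] [1]
}
```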

This function lets us refactor push_list to use this safer abstraction, rather than handling pushing and popping the stack itself. It pushes a vec![] onto the stack, runs its closure inside the Stack::with_item block, and pushes the result down.

fn push_list<F>(stack: &mut Stack<Vec<Expr>>, chars: &mut Chars, f: F)
where
    F: FnOnce(&mut Stack<Vec<Expr>>, &mut Chars),
{
    let list = stack.with_item(vec![], |inner_stack| {
        f(inner_stack, chars);
    });

    push_down(stack, Expr::List(list));
}

The other parsing functions change their first argument from &mut Vec<Vec<Expr>> to &mut Stack<Vec<Expr>>, and Expr::from_str can change to use Stack::new and Stack::pop, safe in the knowledge that it will always return a value.

impl Expr {
    fn from_str(string: &str) -> Vec<Expr> {
        let mut stack = Stack::new(vec![]);
        parse_expr(&mut stack, &mut string.chars());
        stack.pop().0
    }
}

The only other change is that we can drop the pattern-matching guard from push_down. Vec::last_mut returns Option<&mut T>, but Stack::last_mut can return &mut T, a mutable reference to head. Instead of silently discarding the value when the stack is empty, push_down is now statically guaranteed to work in all cases.

impl<T> Stack<T> {
    fn last_mut(&mut self) -> &mut T {
        &mut self.head
    }
}

fn push_down(stack: &mut Stack<Vec<Expr>>, value: Expr) {
    stack.last_mut().push(value);
}
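Putting the pieces together, a sketch of the complete Stack-based parser might look like this. It is assembled from the code in this post with one small liberty: with_item restores self with *self = popped_stack.expect(...) instead of mem::swap inside an if let, which behaves the same way, including the panic message.

```rust
use std::mem;
use std::str::Chars;

#[derive(Debug, PartialEq)]
enum Expr {
    List(Vec<Expr>),
    Digit(u32),
}

// A stack that always holds at least one item: `head`.
#[derive(Default)]
struct Stack<T> {
    head: T,
    rest: Vec<T>,
}

impl<T> Stack<T> {
    fn new(head: T) -> Stack<T> {
        Stack { head, rest: vec![] }
    }

    fn push(mut self, item: T) -> Stack<T> {
        let head = mem::replace(&mut self.head, item);
        self.rest.push(head);
        self
    }

    // Always yields the head; the remaining stack only if it is non-empty.
    fn pop(mut self) -> (T, Option<Stack<T>>) {
        let head = self.head;
        if let Some(item) = self.rest.pop() {
            self.head = item;
            (head, Some(self))
        } else {
            (head, None)
        }
    }

    fn last_mut(&mut self) -> &mut T {
        &mut self.head
    }

    // Push `item`, let `f` work on the taller stack, then pop and return it.
    fn with_item<F>(&mut self, item: T, f: F) -> T
    where
        F: FnOnce(&mut Stack<T>),
        T: Default,
    {
        let mut pushed_stack = mem::take(self).push(item);
        f(&mut pushed_stack);
        let (head, popped_stack) = pushed_stack.pop();
        *self = popped_stack.expect("we emptied the stack somehow");
        head
    }
}

impl Expr {
    fn from_str(string: &str) -> Vec<Expr> {
        let mut stack = Stack::new(vec![]);
        parse_expr(&mut stack, &mut string.chars());
        stack.pop().0 // never fails: a Stack always has a head
    }
}

fn parse_expr(stack: &mut Stack<Vec<Expr>>, chars: &mut Chars) {
    while let Some(c) = chars.next() {
        match c {
            '[' => push_list(stack, chars, |st, ch| parse_expr(st, ch)),
            '0'..='9' => push_digit(stack, c),
            ']' => break,
            _ => {}
        }
    }
}

fn push_list<F>(stack: &mut Stack<Vec<Expr>>, chars: &mut Chars, f: F)
where
    F: FnOnce(&mut Stack<Vec<Expr>>, &mut Chars),
{
    let list = stack.with_item(vec![], |inner_stack| f(inner_stack, chars));
    push_down(stack, Expr::List(list));
}

fn push_down(stack: &mut Stack<Vec<Expr>>, value: Expr) {
    stack.last_mut().push(value);
}

fn push_digit(stack: &mut Stack<Vec<Expr>>, c: char) {
    push_down(stack, Expr::Digit(c.to_digit(10).unwrap()));
}

fn main() {
    println!("{:?}", Expr::from_str("[3, [1, 4], [1, [5, 9, [2]], 6]]"));
}
```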

You may have noticed this is still not completely safe. Any closure in receipt of a &mut Stack<T> can simply replace the stack with a brand new one, just as we do in Stack::with_item. But by making a type that can’t represent invalid states, and putting a purpose-built API over the top, it’s much harder to misuse, and there’s only one place the program can still panic, rather than everywhere it was previously using Vec::pop. In particular, it cannot panic by calling pop() after the closure returns, because Stack::pop always returns a value.
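The remaining loophole is easy to demonstrate. In this sketch, which uses hypothetical helper names and the Stack methods from this post (with with_item restoring self via *self = popped_stack.expect(...), equivalent to the mem::swap version), one closure replaces the lent stack with a single-item stack, tripping the one remaining panic in with_item. Another replaces it with a two-item stack, which does not panic but silently swaps out the caller’s data. Both are deliberate, and both stand out in review.

```rust
use std::mem;
use std::panic;

#[derive(Default)]
struct Stack<T> {
    head: T,
    rest: Vec<T>,
}

impl<T> Stack<T> {
    fn new(head: T) -> Stack<T> {
        Stack { head, rest: vec![] }
    }

    fn push(mut self, item: T) -> Stack<T> {
        let head = mem::replace(&mut self.head, item);
        self.rest.push(head);
        self
    }

    fn pop(mut self) -> (T, Option<Stack<T>>) {
        let head = self.head;
        if let Some(item) = self.rest.pop() {
            self.head = item;
            (head, Some(self))
        } else {
            (head, None)
        }
    }

    fn with_item<F>(&mut self, item: T, f: F) -> T
    where
        F: FnOnce(&mut Stack<T>),
        T: Default,
    {
        let mut pushed_stack = mem::take(self).push(item);
        f(&mut pushed_stack);
        let (head, popped_stack) = pushed_stack.pop();
        *self = popped_stack.expect("we emptied the stack somehow");
        head
    }
}

// Hypothetical: the closure swaps in a one-item stack, so with_item has
// nothing left to restore and hits its remaining panic.
fn replacing_closure_panics() -> bool {
    panic::catch_unwind(|| {
        let mut stack = Stack::new(vec![1]);
        stack.with_item(vec![2], |s| *s = Stack::new(vec![99]));
    })
    .is_err()
}

// Hypothetical: the closure swaps in a two-item stack. Nothing panics,
// but the caller's original head ([1]) is silently lost.
fn silently_replaced_head() -> Vec<i32> {
    let mut stack = Stack::new(vec![1]);
    stack.with_item(vec![2], |s| *s = Stack::new(vec![42]).push(vec![]));
    stack.head
}

fn main() {
    panic::set_hook(Box::new(|_| {})); // silence the default panic message
    println!("panics: {}", replacing_closure_panics());
    println!("head afterwards: {:?}", silently_replaced_head());
}
```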

The ergonomics have also improved because we’ve removed a bunch of places where we previously had to handle Option values. Stack::with_item presents an interface that keeps the stack behaving correctly so long as closures use it rather than any mem::replace trickery, and such trickery is easy to spot in review. The simplified interface compared to Vec means callers are much less likely to trigger a panic by accident. You can still cause weird behaviour deliberately, but it’s much harder to make unintentional mistakes.

While there are certainly other approaches to this problem, I think it’s a nice demonstration of the power of static types to make sure certain program states can never happen, because there’s no well-typed value for them. Rust’s type system in particular, with its ownership and borrowing rules, also allows for fine-grained control over letting functions mutate things. The ability for a function to take an owned value and so prevent the caller from using it again feels inconvenient at first, but it’s incredibly useful for stopping invalid behaviour.