8. Trees

Trees are everywhere: there isn’t just one type of tree, but many datatypes that can be subsumed under the general concept of a tree. Indeed, the types we have seen in the previous chapters (natural numbers and lists) are special instances of trees.

Trees are also the most dangerously underused concept in programming. Bad programmers tend to try to do everything with strings, unaware of how much easier it would be with trees. You don’t have to be a functional programmer to use trees: they can be defined just as easily in languages like Python (see my Computerphile video).

In this chapter we are going to look at two instances of trees: expression trees and sorted trees.

Expression trees are used to encode the syntax of expressions. We are going to define an interpreter and a compiler, which compiles expressions into code for a simple stack machine, and then show that the compiler is correct with respect to the interpreter. This is an extended version of Hutton’s razor, an example invented by Prof. Hutton.

The second example is tree sort: an efficient alternative to insertion sort. From a list we produce a sorted tree, and then we turn this into a sorted list. Actually, tree sort is just quicksort in disguise.

8.1. Expression trees

I am trying to hammer this in: when you see an expression like x * (y + 2) or (x * y) + 2, try to see the tree. Expressions are really two-dimensional structures, but we use a one-dimensional notation with brackets to write them.

[Figures: the trees of x * (y + 2) and (x * y) + 2]

But it isn’t only about seeing the tree: we can turn it into a datatype, indeed an inductive datatype.

inductive Expr : Type
| const : ℕ → Expr
| var : string → Expr
| plus : Expr → Expr → Expr
| times : Expr → Expr → Expr

open Expr

def e1 : Expr
  := times (var "x") (plus (var "y") (const 2))

def e2 : Expr
  := plus (times (var "x") (var "y")) (const 2)

An expression is either a constant, a variable, a plus-expression or a times-expression. To construct a constant we need a number, for a variable we need a string, and for both plus and times we need two expressions, which serve as the first and second arguments. The last two constructors show that expression trees are recursive.

I am not going to waste any time proving no-confusion and injectivity, e.g.

theorem no_conf : ∀ n : ℕ, ∀ l r : Expr, const n ≠ plus l r :=
begin
  sorry,
end

theorem inj_plus_l : ∀ l r l' r' : Expr, plus l r = plus l' r' → l = l' :=
begin
  sorry,
end

By the way, the name of the theorem inj_plus_l is a bit misleading: it is the tree constructor plus that is injective, not the operation +. Actually, is + injective?

8.1.1. Evaluating expressions

Instead, let’s evaluate expressions! To do this we need an assignment from variable names (i.e. strings) to numbers. For this purpose I introduce a type of environments; I am using functions to represent them.

def Env : Type
  := string → ℕ

def my_env : Env
| "x" := 3
| "y" := 5
| _ := 0

#reduce my_env "y"

The environment my_env assigns the number 3 to "x", the number 5 to "y" and 0 to all other variables. Really, I should have introduced some error handling for undefined variables, but for the sake of brevity I am going to ignore this. To look up a variable name we just have to apply the function, e.g. my_env "y".
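Because Env is just a function type, any function from strings to numbers will do as an environment, not only one defined by pattern matching. A small sketch; env2 is a hypothetical name that is not used anywhere else in the text:

```lean
-- env2 is a hypothetical second environment, written as a lambda:
-- it maps "x" to 10 and every other variable name to 1.
def env2 : Env := λ s, if s = "x" then 10 else 1

#eval env2 "x"  -- 10
#eval env2 "z"  -- 1
```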

OK, we are ready to write the evaluator for expressions, which takes an expression and an environment and returns a number. And it uses (oh horror!) recursion on trees.

def eval : Expr → Env → ℕ
| (const n) env := n
| (var s) env := env s
| (plus l r) env := (eval l env) + (eval r env)
| (times l r) env := (eval l env) * (eval r env)

#reduce eval e1 my_env

#reduce eval e2 my_env

eval looks at the expression. If it is a constant, it just returns the numerical value of the constant. If it is a variable, it looks it up in the environment. To evaluate a plus or a times, it first recursively evaluates the subexpressions and then adds or multiplies the results.

I hope you are able to evaluate the two examples e1 and e2 in your head before checking whether you got it right.
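To check your answers, here is the unfolding of eval on e1 written out by hand, followed by #eval, which runs the code instead of reducing it symbolically:

```lean
-- Unfolding eval on e1 = times (var "x") (plus (var "y") (const 2)):
--   eval e1 my_env
--   = eval (var "x") my_env * eval (plus (var "y") (const 2)) my_env
--   = my_env "x" * (eval (var "y") my_env + eval (const 2) my_env)
--   = 3 * (my_env "y" + 2)
--   = 3 * (5 + 2)
--   = 21
-- and similarly eval e2 my_env = 3 * 5 + 2 = 17.
#eval eval e1 my_env  -- 21
#eval eval e2 my_env  -- 17
```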

8.1.2. A simple Compiler

To prove something interesting, let’s implement a simple compiler. Our target is a little stack machine. We first define the instructions:

inductive Instr : Type
| pushC : ℕ → Instr
| pushV : string → Instr
| add : Instr
| mult : Instr

open Instr

def Code : Type
  := list Instr

We can push a constant, or the value of a variable from the environment, onto the stack, and we can add or multiply the top two items of the stack, which removes them and replaces them with their sum or product. The machine code is just a sequence of instructions.

We define a function run that executes a piece of code on a stack, which is represented as a list of numbers with its top at the head. When the code is finished, run returns the number on the stack, provided exactly one number is left.

def Stack : Type
  := list ℕ

def run : Code → Stack → Env → ℕ
| [] [n] env := n
| (pushC n ::c) s env := run c (n :: s) env
| (pushV x ::c) s env := run c (env x :: s) env
| (add :: c) (m :: n :: s) env := run c ((n + m) :: s) env
| (mult :: c) (m :: n :: s) env := run c ((n * m) :: s) env
| _ _ _ := 0

The function run analyzes the first instruction (if there is one), modifies the stack accordingly and then runs the remaining instructions. Again there is no error handling: if something is wrong, I return 0. This calls for some serious refactoring! But not now.
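To see the missing error handling in action, here are a few runs. The comments give the values that follow from the equations of run above:

```lean
-- add needs two numbers on the stack but finds only one,
-- so only the catch-all equation matches:
#eval run [add] [1] my_env      -- 0
-- the code is finished but two numbers are left on the stack:
#eval run [] [1, 2] my_env      -- 0
-- a perfectly good program can also return 0,
-- so the result alone cannot tell us whether something went wrong:
#eval run [pushC 0] [] my_env   -- 0
```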

As an example, let’s run some code that computes the value of e1:

def c1 : Code
:= [pushV "x",pushV "y",pushC 2,add,mult]

#eval run c1 [] my_env
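It may help to trace the stack step by step, writing the top of the stack on the left as in the patterns of run. The last line restarts the machine in the state just before add:

```lean
-- Trace of run c1 [] my_env (top of the stack on the left):
--   pushV "x"   []         becomes  [3]
--   pushV "y"   [3]        becomes  [5, 3]
--   pushC 2     [5, 3]     becomes  [2, 5, 3]
--   add         [2, 5, 3]  becomes  [7, 3]    (5 + 2)
--   mult        [7, 3]     becomes  [21]      (3 * 7)
--   no code left, exactly one number on the stack: the result is 21
#eval run [add, mult] [2, 5, 3] my_env  -- 21
```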

I compiled e1 to c1 by hand, but surely we can do this automatically. Here is a first version of a compiler:

def compile_naive : Expr → Code
| (const n) := [pushC n]
| (var x) := [pushV x]
| (plus l r) := (compile_naive l) ++ (compile_naive r) ++ [add]
| (times l r) := (compile_naive l) ++ (compile_naive r) ++ [mult]

The compiler translates const and var into the corresponding push instructions. For plus and times, it creates code for both subexpressions and inserts an add or a mult instruction afterwards.
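For the examples we can check that compile_naive does what we want. The equation for e1 should be accepted by rfl, because both sides compute to the same list of instructions:

```lean
-- compile_naive produces exactly the hand-written code c1 for e1
example : compile_naive e1 = c1 := rfl

-- and the compiled code computes the same values as eval
#eval run (compile_naive e1) [] my_env  -- 21
#eval run (compile_naive e2) [] my_env  -- 17
```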

The naive compiler is inefficient because ++ has to traverse the code already created, each time it is used. It is also a bit harder to verify, because we need to exploit the fact that lists form a monoid. However, there is a nice trick to make the compiler both more efficient and easier to verify. We add an extra argument to the compiler: the code that should come after the code for the expression we are currently compiling (the continuation). We end up with:

def compile_aux : Expr → Code → Code
| (const n) c := pushC n :: c
| (var x) c := pushV x :: c
| (plus l r) c := compile_aux l (compile_aux r (add :: c))
| (times l r) c := compile_aux l (compile_aux r (mult :: c))

def compile (e :Expr) : Code
  := compile_aux e []

#reduce run (compile e1) [] my_env

#reduce run (compile e2) [] my_env

This version of the compiler is more efficient because it doesn’t need to traverse the code it has already produced. Indeed, this is basically the same issue as with rev vs fastrev. The other advantage is that we don’t need to use any properties of ++ in the proof, because we aren’t using it!
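Again we can compare the output with hand-written code. This is a sketch, assuming rfl can compute both sides of each equation:

```lean
-- compile also reproduces the hand-written code c1 for e1 ...
example : compile e1 = c1 := rfl
-- ... and for e2 it multiplies x and y before pushing the constant 2
example : compile e2 = [pushV "x", pushV "y", mult, pushC 2, add] := rfl
```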

8.1.3. Compiler correctness

Looking at the examples, we can see that compiling and then running produces the same results as eval. We would like to prove this, i.e. the correctness of the compiler.

theorem compile_ok : ∀ e : Expr, ∀ env : Env,
              run (compile e) [] env = eval e env :=
begin
  sorry,
end

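However, we won’t be able to prove this directly by induction. The statement only talks about running the code on the empty stack with nothing following it. In the recursive cases, the code for a subexpression runs with other values already on the stack and more code after it, so the induction hypothesis is too weak.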

This means we have to find a stronger proposition for which the induction goes through and which implies the proposition we actually want to prove. This is known as induction loading.

Clearly we need to prove a statement about compile_aux. Running the code compile_aux e c on a stack s should be the same as evaluating e, pushing the result onto s and then running c. This implies the statement we want to prove in the special case where both the remaining code c and the stack s are empty.

lemma compile_aux_ok :
      ∀ e : Expr, ∀ c : Code, ∀ s : Stack, ∀ env : Env,
      run (compile_aux e c) s env = run c ((eval e env) :: s) env :=
begin
  assume e,
  induction e,
  sorry,
end

I have already started the proof. We are going to do induction over the expression e. After this we are in the following state:

e : ℕ
⊢ ∀ (c : Code) (s : Stack) (env : Env), run (compile_aux (const e) c) s env = run c (eval (const e) env :: s) env

case Expr.var
e : string
⊢ ∀ (c : Code) (s : Stack) (env : Env), run (compile_aux (var e) c) s env = run c (eval (var e) env :: s) env

case Expr.plus
e_a e_a_1 : Expr,
e_ih_a : ∀ (c : Code) (s : Stack) (env : Env), run (compile_aux e_a c) s env = run c (eval e_a env :: s) env,
e_ih_a_1 : ∀ (c : Code) (s : Stack) (env : Env), run (compile_aux e_a_1 c) s env = run c (eval e_a_1 env :: s) env
⊢ ∀ (c : Code) (s : Stack) (env : Env),
    run (compile_aux (plus e_a e_a_1) c) s env = run c (eval (plus e_a e_a_1) env :: s) env

case Expr.times
e_a e_a_1 : Expr,
e_ih_a : ∀ (c : Code) (s : Stack) (env : Env), run (compile_aux e_a c) s env = run c (eval e_a env :: s) env,
e_ih_a_1 : ∀ (c : Code) (s : Stack) (env : Env), run (compile_aux e_a_1 c) s env = run c (eval e_a_1 env :: s) env
⊢ ∀ (c : Code) (s : Stack) (env : Env),
    run (compile_aux (times e_a e_a_1) c) s env = run c (eval (times e_a e_a_1) env :: s) env

We see that we have four cases, one for each constructor. In the recursive cases for plus and times, we have induction hypotheses which say that the theorem holds for the left and right subexpressions.

I don’t want to use the names generated by Lean (like e_ih_a_1), but using with here would get a bit complicated, given all the names involved. Hence I am going to use case for each of the cases, which allows me to introduce the variables separately. I also have to use { } to turn the proof of each case into a block.

Hence our proof is going to look like this:
lemma compile_aux_ok : ∀ e : Expr, ∀ c : Code, ∀ s : Stack, ∀ env : Env,
      run (compile_aux e c) s env = run c ((eval e env) :: s) env
      :=
begin
  assume e,
  induction e,
  case const : n {
    sorry, },
  case var : name {
    sorry, },
  case plus : l r ih_l ih_r {
    sorry, },
  case times : l r ih_l ih_r {
    sorry, }
end

The cases for const and var are easy: we just need to appeal to reflexivity. The cases for plus and times are virtually identical (another candidate for refactoring, this time of a proof). Let’s have a look at plus:

l r : Expr,
ih_l : ∀ (c : Code) (s : Stack) (env : Env), run (compile_aux l c) s env = run c (eval l env :: s) env,
ih_r : ∀ (c : Code) (s : Stack) (env : Env), run (compile_aux r c) s env = run c (eval r env :: s) env,
c : Code,
s : Stack,
env : Env
⊢ run (compile_aux (plus l r) c) s env = run c (eval (plus l r) env :: s) env

By unfolding the definition of compile_aux (plus l r) c we obtain:

run (compile_aux (plus l r) c) s env
= run (compile_aux l (compile_aux r (add :: c))) s env

And now we can use ih_l to push the value of l onto the stack:

... = run (compile_aux r (add :: c)) ((eval l env) :: s) env

So indeed the value of l has landed on the stack, and now we can use ih_r:

... = run (add :: c) ((eval r env) :: (eval l env) :: s) env

We are almost done: we can run run for one step:

... = run c (((eval l env) + (eval r env)) :: s) env

And now, working backwards from the definition of eval, we arrive at the other side of the equation:

...  = run c (eval (plus l r) env :: s) env

OK, here is the proof so far, using calc for the reasoning above:

lemma compile_aux_ok : ∀ e : Expr, ∀ c : Code, ∀ s : Stack, ∀ env : Env,
      run (compile_aux e c) s env = run c ((eval e env) :: s) env
      :=
begin
  assume e,
  induction e,
  case const : n {
    assume c s env,
    reflexivity, },
  case var : name {
    assume c s env,
    reflexivity, },
  case plus : l r ih_l ih_r {
    assume c s env,
    dsimp [compile_aux],
    calc
      run (compile_aux (plus l r) c) s env
      = run (compile_aux l (compile_aux r (add :: c))) s env : by refl
      ... = run (compile_aux r (add :: c)) ((eval l env) :: s) env
                 : by rw ih_l
      ... = run (add :: c) ((eval r env) :: (eval l env) :: s) env
                 : by rw ih_r
      ... = run c (((eval l env) + (eval r env)) :: s) env : by refl
      ... = run c (eval (plus l r) env :: s) env : by refl,
  },
  case times : l r ih_l ih_r {
    sorry, }
end

I leave it to you to fill in the case for times.

You may have noticed that I didn’t introduce all the assumptions in the beginning. Could I have done the proof starting with:

lemma compile_aux_ok : ∀ e : Expr, ∀ c : Code, ∀ s : Stack, ∀ env : Env,
      run (compile_aux e c) s env = run c ((eval e env) :: s) env
      :=
begin
  assume e c s env,
  induction e,
  sorry,
end

It seems that this would have avoided the repeated assume c s env. Or is there a problem? Try it out.
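Once compile_aux_ok has been completed (including the times case), the statement compile_ok that we actually wanted is its special case with the empty continuation and the empty stack. Here is a sketch using calc. I use example rather than theorem because compile_ok has already been declared above:

```lean
-- Assumes compile_aux_ok (with the times case filled in) is in scope.
-- First step: compile e is compile_aux e [], so this is compile_aux_ok
-- with the empty continuation and the empty stack.
-- Second step: running no code on a one-element stack returns that element.
example : ∀ e : Expr, ∀ env : Env,
            run (compile e) [] env = eval e env :=
begin
  assume e env,
  calc run (compile e) [] env
        = run [] (eval e env :: []) env : compile_aux_ok e [] [] env
    ... = eval e env : by refl,
end
```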

8.2. Tree sort

Finally, we will look at another application of trees: sorting. The algorithm I am describing is tree sort and, as I already said, it is a variation of quicksort.

The idea is that we turn a list like our favorite example [6,3,8,2,3] into a tree like this one:

[Figure: a sorted tree built from the list [6,3,8,2,3]]

The nodes of the tree are labelled with numbers. The tree is sorted in the sense that at each node, all the labels in the left subtree are less than or equal to the number at the current node, and all the ones in the right subtree are greater than or equal to it. Once we flatten this tree into a list, we get the sorted list [2,3,3,6,8].

8.2.1. Implementing tree sort

First of all we need to define the type of binary trees with nodes labelled with natural numbers:

inductive Tree : Type
| leaf : Tree
| node : Tree → ℕ → Tree → Tree

open Tree

To build a sorted tree from a list we need to write a function that inserts an element into a sorted tree, preserving sortedness. We define this function by recursion over trees:

def ins : ℕ → Tree → Tree
| n leaf := node leaf n leaf
| n (node l m r) :=
              if ble n m
              then node (ins n l) m r
              else node l m (ins n r)
 

Here we use the function ble, which we have already seen earlier, to decide whether to recursively insert the number into the left or the right subtree.

To turn a list into a sorted tree we need to fold ins over the list, mapping the empty list to a leaf.

def list2tree : list ℕ → Tree
| [] := leaf
| (n :: ns) := ins n (list2tree ns)

#reduce list2tree [6,3,8,2,3]
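Since list2tree (n :: ns) is ins n (list2tree ns), the list is processed from the right: the numbers are inserted in the order 3, 2, 8, 3, 6. Assume that ble n m is tt exactly when n ≤ m, which is how ins uses it. Then the resulting tree should be the one below, and rfl can check this by computation:

```lean
-- 3 becomes the root; 2 goes left of it; 8 goes right of it;
-- the second 3 goes left of the root (3 ≤ 3) and then right of 2;
-- 6 goes right of the root and then left of 8.
example : list2tree [6,3,8,2,3]
        = node (node leaf 2 (node leaf 3 leaf)) 3
               (node (node leaf 6 leaf) 8 leaf) := rfl
```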
 

Now all that is left to do is to implement a function that flattens a tree into a list:

def tree2list : Tree → list ℕ
| leaf := []
| (node l m r) := tree2list l ++ m :: tree2list r
 

Putting both together, we have constructed a sorting function on lists, tree sort:

def sort (ns : list ℕ) : list ℕ
  := tree2list (list2tree ns)

#reduce (sort [6,3,8,2,3])
 

8.2.2. Verifying tree sort

To verify that tree sort returns a sorted list, we have to specify what a sorted tree is. To do this, we need to be able to say things like "all the labels in a tree are at most (or at least) a given number". We can do this even more generally by defining a higher-order predicate that applies a given predicate to all nodes of a tree:

inductive AllTree (P : ℕ → Prop) : Tree → Prop
| allLeaf : AllTree leaf
| allNode : ∀ l r : Tree, ∀ m : ℕ,
              AllTree l → P m → AllTree r → AllTree (node l m r)
 

That is, AllTree P t holds if the predicate P holds for all the numbers in the nodes of t.
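As a small example (my own, not from the text), here is a proof that the only label of a one-node tree satisfies the predicate λ x, x ≤ 5. Each constructor argument becomes one goal:

```lean
example : AllTree (λ x : ℕ, x ≤ 5) (node leaf 3 leaf) :=
begin
  apply AllTree.allNode,
  apply AllTree.allLeaf,          -- the left subtree
  show 3 ≤ 5, exact dec_trivial,  -- the label at the node
  apply AllTree.allLeaf,          -- the right subtree
end
```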

Using AllTree we can define SortedTree:

inductive SortedTree : Tree → Prop
| sortedLeaf : SortedTree leaf
| sortedNode : ∀ l r : Tree, ∀ m : ℕ,
               SortedTree l → AllTree (λ x : ℕ, x ≤ m) l
             → SortedTree r → AllTree (λ x : ℕ, m ≤ x) r
             → SortedTree (node l m r)
 

We are now ready to state the correctness of sort. It is the same as for insertion sort and uses the predicate Sorted on lists that we defined in the previous chapter:

theorem tree_sort_sorts : ∀ ns : list ℕ, Sorted (sort ns) :=
begin
sorry,
end
 

It is not difficult to identify the two lemmas we need to show:

  • list2tree produces a sorted tree (list2tree_lem)
  • tree2list maps a sorted tree into a sorted list (tree2list_lem)

Hence the top-level structure of our proof looks like this:

lemma list2tree_lem : ∀ l : list ℕ, SortedTree (list2tree l) :=
begin
  sorry,
end

lemma tree2list_lem : ∀ t : Tree, SortedTree t → Sorted (tree2list t) :=
begin
  sorry
end

theorem tree_sort_sorts : ∀ ns : list ℕ, Sorted (sort ns) :=
begin
  assume ns,
  dsimp [sort],
  apply tree2list_lem,
  apply list2tree_lem,
end
 

Since you have now seen enough proofs, I will omit the gory details and only tell you the lemmas (stepping stones). First of all, we want to prove list2tree_lem by induction over lists. Hence another lemma pops up:

lemma ins_lem : ∀ t : Tree, ∀ n : ℕ, SortedTree t → SortedTree (ins n t) :=
begin
  sorry,
end
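To show how ins_lem is used, here is a sketch of list2tree_lem along the lines just described. I state it as an example because list2tree_lem has already been declared:

```lean
example : ∀ l : list ℕ, SortedTree (list2tree l) :=
begin
  assume l,
  induction l with n ns ih,
  -- list2tree [] is leaf
  apply SortedTree.sortedLeaf,
  -- list2tree (n :: ns) is ins n (list2tree ns)
  dsimp [list2tree],
  apply ins_lem,
  exact ih,
end
```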
 

This we need to prove by induction over trees. At some point we need a lemma about the interaction of ins with AllTree; I used the following:

lemma insAllLem : ∀ P : ℕ → Prop, ∀ t : Tree, ∀ n : ℕ,
              AllTree P t → P n → AllTree P (ins n t) :=
begin
  sorry,
end
 

Again, this just requires tree induction. To prove the other lemma, tree2list_lem, it is helpful to also introduce a higher-order predicate for lists:

inductive AllList (P : ℕ → Prop) : list ℕ → Prop
| allListNil : AllList []
| allListCons : ∀ n : ℕ, ∀ ns : list ℕ, P n → AllList ns → AllList (n :: ns)
 

And then I prove a lemma:

lemma AllTree2list : ∀ P : ℕ → Prop, ∀ t : Tree,
              AllTree P t → AllList P (tree2list t) :=
begin
  sorry,
end
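When proving AllTree2list by tree induction, the node case produces tree2list l ++ m :: tree2list r. One way forward is therefore a lemma saying that AllList is preserved by ++. A sketch; AllList_append is my own hypothetical name, not a lemma from the text:

```lean
-- AllList_append is a hypothetical helper, not from the text.
lemma AllList_append : ∀ P : ℕ → Prop, ∀ ms ns : list ℕ,
      AllList P ms → AllList P ns → AllList P (ms ++ ns) :=
begin
  assume P ms ns hms hns,
  induction hms with m ms' hm hms' ih,
  -- [] ++ ns computes to ns
  exact hns,
  -- (m :: ms') ++ ns computes to m :: (ms' ++ ns)
  apply AllList.allListCons,
  exact hm,
  exact ih,
end
```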
   

To complete the proof of tree2list_lem I needed some additional lemmas about Sorted and Le_list, but you may find a different path.

8.2.3. Tree sort and permutation

As for insertion sort, we also need to show that tree sort permutes its input. The proof is very similar to the one for insertion sort; we just need to adapt the lemma for ins:

lemma ins_inserts : ∀ n : ℕ, ∀ t : Tree,
              Insert n (tree2list t) (tree2list (ins n t)) :=
begin
  sorry
end

theorem sort_perm : ∀ ns : list ℕ, Perm ns (sort ns) :=
begin
  assume ns,
  induction ns,
  apply perm_nil,
  apply perm_cons,
  apply ns_ih,
  apply ins_inserts,
end

To show ins_inserts I needed two lemmas about Insert and ++:

lemma insert_appl : ∀ n : ℕ, ∀ ms nms is : list ℕ,
      Insert n ms nms → Insert n (is ++ ms) (is ++ nms) :=
begin
  sorry,
end

lemma insert_appr : ∀ n : ℕ, ∀ ms is nms : list ℕ,
      Insert n ms nms → Insert n (ms ++ is) (nms ++ is) :=
begin
   sorry,
end

Both can be shown by induction over lists, but the choice of which list to do induction over is crucial.

8.2.4. Relation to quicksort

I have already mentioned that tree sort is basically quicksort. How can this be, you ask, since quicksort doesn’t actually use any trees? Here is quicksort in Lean:

-- split n ns partitions ns into the elements ≤ n and the elements > n
def split : ℕ → list ℕ → list ℕ × list ℕ
| n [] := ([] , [])
| n (m :: ns) := match split n ns with (l,r) :=
                       if ble m n
                       then (m :: l, r)
                       else (l, m::r) end

def qsort : list ℕ → list ℕ
| [] := []
| (n :: ms) := match split n ms with (l,r) :=
                     (qsort l)++(n::(qsort r)) end


#eval (qsort [6,3,8,2,3])
  

The program uses × (pairs) and match, which I haven’t explained, but I hope their meaning is obvious. Lean isn’t happy about the recursive definition of qsort, because the lists in the recursive calls are not structurally smaller than the input. This can be fixed by using well-founded recursion, but that is beyond the scope of this course. However, Lean is happy to run the program using #eval.

We can get from tree sort to quicksort by a process called program fusion. In a nutshell, the function list2tree produces a tree which is consumed by tree2list. We can avoid creating this intermediate tree by fusing the two functions together, and so we arrive at quicksort.
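Here is a rough sketch of the idea, using a hypothetical function build that is not from the text. It builds a tree by partitioning around the head of the list, as quicksort does; note that this is a different tree from the one list2tree produces. It is declared meta to sidestep the termination problem mentioned above, so it can be run with #eval but not used in proofs:

```lean
-- Hypothetical, not from the text: partition around the head and
-- build the two subtrees recursively.
meta def build : list ℕ → Tree
| [] := leaf
| (n :: ms) := node (build (list.filter (λ m, m ≤ n) ms))
                    n
                    (build (list.filter (λ m, n < m) ms))

-- tree2list (build (n :: ms)) unfolds to
--   tree2list (build (list.filter (λ m, m ≤ n) ms))
--     ++ n :: tree2list (build (list.filter (λ m, n < m) ms))
-- which has exactly the shape of qsort's recursive equation:
-- fusing tree2list with build removes the intermediate tree.
#eval tree2list (build [6,3,8,2,3])  -- [2, 3, 3, 6, 8]
```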

Can you see how to tree-ify merge sort? Hint: in this case you need to use trees where the leaves contain the data, not the nodes.