about summary refs log tree commit diff
path: root/src
diff options
context:
space:
mode:
authorGriffin Smith <root@gws.fyi>2021-03-14T02·57-0500
committerGriffin Smith <root@gws.fyi>2021-03-14T03·07-0500
commit32a5c0ff0fc58aa6721c1e0ad41950bde2d66744 (patch)
treeef5dcf5234c2a86607ee2f8f30db73bad016e075 /src
parentf8beda81fbe8d04883aee71ff4ea078f897c6de4 (diff)
Add the start of a hindley-milner typechecker
The beginning of a parse-don't-validate-based hindley-milner
typechecker, which returns on success an IR where every AST node
trivially knows its own type, and using those types to determine LLVM
types in codegen.
Diffstat (limited to 'src')
-rw-r--r--src/ast/hir.rs246
-rw-r--r--src/ast/mod.rs3
-rw-r--r--src/codegen/llvm.rs76
-rw-r--r--src/codegen/mod.rs5
-rw-r--r--src/commands/check.rs39
-rw-r--r--src/commands/eval.rs6
-rw-r--r--src/commands/mod.rs2
-rw-r--r--src/common/env.rs24
-rw-r--r--src/common/error.rs11
-rw-r--r--src/compiler.rs4
-rw-r--r--src/interpreter/mod.rs70
-rw-r--r--src/interpreter/value.rs5
-rw-r--r--src/main.rs3
-rw-r--r--src/parser/expr.rs25
-rw-r--r--src/parser/macros.rs1
-rw-r--r--src/parser/mod.rs5
-rw-r--r--src/tc/mod.rs528
17 files changed, 975 insertions, 78 deletions
diff --git a/src/ast/hir.rs b/src/ast/hir.rs
new file mode 100644
index 0000000000..151ddd5298
--- /dev/null
+++ b/src/ast/hir.rs
@@ -0,0 +1,246 @@
+use itertools::Itertools;
+
+use super::{BinaryOperator, Ident, Literal, UnaryOperator};
+
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub struct Binding<'a, T> {
+    pub ident: Ident<'a>,
+    pub type_: T,
+    pub body: Expr<'a, T>,
+}
+
+impl<'a, T> Binding<'a, T> {
+    fn to_owned(&self) -> Binding<'static, T>
+    where
+        T: Clone,
+    {
+        Binding {
+            ident: self.ident.to_owned(),
+            type_: self.type_.clone(),
+            body: self.body.to_owned(),
+        }
+    }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub enum Expr<'a, T> {
+    Ident(Ident<'a>, T),
+
+    Literal(Literal, T),
+
+    UnaryOp {
+        op: UnaryOperator,
+        rhs: Box<Expr<'a, T>>,
+        type_: T,
+    },
+
+    BinaryOp {
+        lhs: Box<Expr<'a, T>>,
+        op: BinaryOperator,
+        rhs: Box<Expr<'a, T>>,
+        type_: T,
+    },
+
+    Let {
+        bindings: Vec<Binding<'a, T>>,
+        body: Box<Expr<'a, T>>,
+        type_: T,
+    },
+
+    If {
+        condition: Box<Expr<'a, T>>,
+        then: Box<Expr<'a, T>>,
+        else_: Box<Expr<'a, T>>,
+        type_: T,
+    },
+
+    Fun {
+        args: Vec<(Ident<'a>, T)>,
+        body: Box<Expr<'a, T>>,
+        type_: T,
+    },
+
+    Call {
+        fun: Box<Expr<'a, T>>,
+        args: Vec<Expr<'a, T>>,
+        type_: T,
+    },
+}
+
+impl<'a, T> Expr<'a, T> {
+    pub fn type_(&self) -> &T {
+        match self {
+            Expr::Ident(_, t) => t,
+            Expr::Literal(_, t) => t,
+            Expr::UnaryOp { type_, .. } => type_,
+            Expr::BinaryOp { type_, .. } => type_,
+            Expr::Let { type_, .. } => type_,
+            Expr::If { type_, .. } => type_,
+            Expr::Fun { type_, .. } => type_,
+            Expr::Call { type_, .. } => type_,
+        }
+    }
+
+    pub fn traverse_type<F, U, E>(self, f: F) -> Result<Expr<'a, U>, E>
+    where
+        F: Fn(T) -> Result<U, E> + Clone,
+    {
+        match self {
+            Expr::Ident(id, t) => Ok(Expr::Ident(id, f(t)?)),
+            Expr::Literal(lit, t) => Ok(Expr::Literal(lit, f(t)?)),
+            Expr::UnaryOp { op, rhs, type_ } => Ok(Expr::UnaryOp {
+                op,
+                rhs: Box::new(rhs.traverse_type(f.clone())?),
+                type_: f(type_)?,
+            }),
+            Expr::BinaryOp {
+                lhs,
+                op,
+                rhs,
+                type_,
+            } => Ok(Expr::BinaryOp {
+                lhs: Box::new(lhs.traverse_type(f.clone())?),
+                op,
+                rhs: Box::new(rhs.traverse_type(f.clone())?),
+                type_: f(type_)?,
+            }),
+            Expr::Let {
+                bindings,
+                body,
+                type_,
+            } => Ok(Expr::Let {
+                bindings: bindings
+                    .into_iter()
+                    .map(|Binding { ident, type_, body }| {
+                        Ok(Binding {
+                            ident,
+                            type_: f(type_)?,
+                            body: body.traverse_type(f.clone())?,
+                        })
+                    })
+                    .collect::<Result<Vec<_>, E>>()?,
+                body: Box::new(body.traverse_type(f.clone())?),
+                type_: f(type_)?,
+            }),
+            Expr::If {
+                condition,
+                then,
+                else_,
+                type_,
+            } => Ok(Expr::If {
+                condition: Box::new(condition.traverse_type(f.clone())?),
+                then: Box::new(then.traverse_type(f.clone())?),
+                else_: Box::new(else_.traverse_type(f.clone())?),
+                type_: f(type_)?,
+            }),
+            Expr::Fun { args, body, type_ } => Ok(Expr::Fun {
+                args: args
+                    .into_iter()
+                    .map(|(id, t)| Ok((id, f.clone()(t)?)))
+                    .collect::<Result<Vec<_>, E>>()?,
+                body: Box::new(body.traverse_type(f.clone())?),
+                type_: f(type_)?,
+            }),
+            Expr::Call { fun, args, type_ } => Ok(Expr::Call {
+                fun: Box::new(fun.traverse_type(f.clone())?),
+                args: args
+                    .into_iter()
+                    .map(|e| e.traverse_type(f.clone()))
+                    .collect::<Result<Vec<_>, E>>()?,
+                type_: f(type_)?,
+            }),
+        }
+    }
+
+    pub fn to_owned(&self) -> Expr<'static, T>
+    where
+        T: Clone,
+    {
+        match self {
+            Expr::Ident(id, t) => Expr::Ident(id.to_owned(), t.clone()),
+            Expr::Literal(lit, t) => Expr::Literal(lit.clone(), t.clone()),
+            Expr::UnaryOp { op, rhs, type_ } => Expr::UnaryOp {
+                op: *op,
+                rhs: Box::new((**rhs).to_owned()),
+                type_: type_.clone(),
+            },
+            Expr::BinaryOp {
+                lhs,
+                op,
+                rhs,
+                type_,
+            } => Expr::BinaryOp {
+                lhs: Box::new((**lhs).to_owned()),
+                op: *op,
+                rhs: Box::new((**rhs).to_owned()),
+                type_: type_.clone(),
+            },
+            Expr::Let {
+                bindings,
+                body,
+                type_,
+            } => Expr::Let {
+                bindings: bindings.into_iter().map(|b| b.to_owned()).collect(),
+                body: Box::new((**body).to_owned()),
+                type_: type_.clone(),
+            },
+            Expr::If {
+                condition,
+                then,
+                else_,
+                type_,
+            } => Expr::If {
+                condition: Box::new((**condition).to_owned()),
+                then: Box::new((**then).to_owned()),
+                else_: Box::new((**else_).to_owned()),
+                type_: type_.clone(),
+            },
+            Expr::Fun { args, body, type_ } => Expr::Fun {
+                args: args
+                    .into_iter()
+                    .map(|(id, t)| (id.to_owned(), t.clone()))
+                    .collect(),
+                body: Box::new((**body).to_owned()),
+                type_: type_.clone(),
+            },
+            Expr::Call { fun, args, type_ } => Expr::Call {
+                fun: Box::new((**fun).to_owned()),
+                args: args.into_iter().map(|e| e.to_owned()).collect(),
+                type_: type_.clone(),
+            },
+        }
+    }
+}
+
+pub enum Decl<'a, T> {
+    Fun {
+        name: Ident<'a>,
+        args: Vec<(Ident<'a>, T)>,
+        body: Box<Expr<'a, T>>,
+        type_: T,
+    },
+}
+
+impl<'a, T> Decl<'a, T> {
+    pub fn traverse_type<F, U, E>(self, f: F) -> Result<Decl<'a, U>, E>
+    where
+        F: Fn(T) -> Result<U, E> + Clone,
+    {
+        match self {
+            Decl::Fun {
+                name,
+                args,
+                body,
+                type_,
+            } => Ok(Decl::Fun {
+                name,
+                args: args
+                    .into_iter()
+                    .map(|(id, t)| Ok((id, f(t)?)))
+                    .try_collect()?,
+                body: Box::new(body.traverse_type(f.clone())?),
+                type_: f(type_)?,
+            }),
+        }
+    }
+}
diff --git a/src/ast/mod.rs b/src/ast/mod.rs
index dc22ac3cdb..cef366d16e 100644
--- a/src/ast/mod.rs
+++ b/src/ast/mod.rs
@@ -1,3 +1,5 @@
+pub(crate) mod hir;
+
 use std::borrow::Cow;
 use std::convert::TryFrom;
 use std::fmt::{self, Display, Formatter};
@@ -107,6 +109,7 @@ pub enum UnaryOperator {
 #[derive(Debug, PartialEq, Eq, Clone)]
 pub enum Literal {
     Int(u64),
+    Bool(bool),
 }
 
 #[derive(Debug, PartialEq, Eq, Clone)]
diff --git a/src/codegen/llvm.rs b/src/codegen/llvm.rs
index 1f4a457cd8..5b5db90a1a 100644
--- a/src/codegen/llvm.rs
+++ b/src/codegen/llvm.rs
@@ -7,12 +7,13 @@ use inkwell::builder::Builder;
 pub use inkwell::context::Context;
 use inkwell::module::Module;
 use inkwell::support::LLVMString;
-use inkwell::types::FunctionType;
+use inkwell::types::{BasicType, BasicTypeEnum, FunctionType, IntType};
 use inkwell::values::{AnyValueEnum, BasicValueEnum, FunctionValue};
 use inkwell::IntPredicate;
 use thiserror::Error;
 
-use crate::ast::{BinaryOperator, Binding, Decl, Expr, Fun, Ident, Literal, UnaryOperator};
+use crate::ast::hir::{Binding, Decl, Expr};
+use crate::ast::{BinaryOperator, Ident, Literal, Type, UnaryOperator};
 use crate::common::env::Env;
 
 #[derive(Debug, PartialEq, Eq, Error)]
@@ -36,7 +37,7 @@ pub struct Codegen<'ctx, 'ast> {
     context: &'ctx Context,
     pub module: Module<'ctx>,
     builder: Builder<'ctx>,
-    env: Env<'ast, AnyValueEnum<'ctx>>,
+    env: Env<&'ast Ident<'ast>, AnyValueEnum<'ctx>>,
     function_stack: Vec<FunctionValue<'ctx>>,
     identifier_counter: u32,
 }
@@ -77,18 +78,23 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
             .append_basic_block(*self.function_stack.last().unwrap(), name)
     }
 
-    pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast>) -> Result<AnyValueEnum<'ctx>> {
+    pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast, Type>) -> Result<AnyValueEnum<'ctx>> {
         match expr {
-            Expr::Ident(id) => self
+            Expr::Ident(id, _) => self
                 .env
                 .resolve(id)
                 .cloned()
                 .ok_or_else(|| Error::UndefinedVariable(id.to_owned())),
-            Expr::Literal(Literal::Int(i)) => {
-                let ty = self.context.i64_type();
-                Ok(AnyValueEnum::IntValue(ty.const_int(*i, false)))
+            Expr::Literal(lit, ty) => {
+                let ty = self.codegen_int_type(ty);
+                match lit {
+                    Literal::Int(i) => Ok(AnyValueEnum::IntValue(ty.const_int(*i, false))),
+                    Literal::Bool(b) => Ok(AnyValueEnum::IntValue(
+                        ty.const_int(if *b { 1 } else { 0 }, false),
+                    )),
+                }
             }
-            Expr::UnaryOp { op, rhs } => {
+            Expr::UnaryOp { op, rhs, .. } => {
                 let rhs = self.codegen_expr(rhs)?;
                 match op {
                     UnaryOperator::Not => unimplemented!(),
@@ -97,7 +103,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                     )),
                 }
             }
-            Expr::BinaryOp { lhs, op, rhs } => {
+            Expr::BinaryOp { lhs, op, rhs, .. } => {
                 let lhs = self.codegen_expr(lhs)?;
                 let rhs = self.codegen_expr(rhs)?;
                 match op {
@@ -135,7 +141,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                     BinaryOperator::Neq => todo!(),
                 }
             }
-            Expr::Let { bindings, body } => {
+            Expr::Let { bindings, body, .. } => {
                 self.env.push();
                 for Binding { ident, body, .. } in bindings {
                     let val = self.codegen_expr(body)?;
@@ -149,6 +155,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                 condition,
                 then,
                 else_,
+                type_,
             } => {
                 let then_block = self.append_basic_block("then");
                 let else_block = self.append_basic_block("else");
@@ -168,15 +175,15 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                 self.builder.build_unconditional_branch(join_block);
 
                 self.builder.position_at_end(join_block);
-                let phi = self.builder.build_phi(self.context.i64_type(), "join");
+                let phi = self.builder.build_phi(self.codegen_type(type_), "join");
                 phi.add_incoming(&[
                     (&BasicValueEnum::try_from(then_res).unwrap(), then_block),
                     (&BasicValueEnum::try_from(else_res).unwrap(), else_block),
                 ]);
                 Ok(phi.as_basic_value().into())
             }
-            Expr::Call { fun, args } => {
-                if let Expr::Ident(id) = &**fun {
+            Expr::Call { fun, args, .. } => {
+                if let Expr::Ident(id, _) = &**fun {
                     let function = self
                         .module
                         .get_function(id.into())
@@ -197,8 +204,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                     todo!()
                 }
             }
-            Expr::Fun(fun) => {
-                let Fun { args, body } = &**fun;
+            Expr::Fun { args, body, .. } => {
                 let fname = self.fresh_ident("f");
                 let cur_block = self.builder.get_insert_block().unwrap();
                 let env = self.env.save(); // TODO: closures
@@ -207,29 +213,27 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                 self.env.restore(env);
                 Ok(function.into())
             }
-            Expr::Ascription { expr, .. } => self.codegen_expr(expr),
         }
     }
 
     pub fn codegen_function(
         &mut self,
         name: &str,
-        args: &'ast [Ident<'ast>],
-        body: &'ast Expr<'ast>,
+        args: &'ast [(Ident<'ast>, Type)],
+        body: &'ast Expr<'ast, Type>,
     ) -> Result<FunctionValue<'ctx>> {
-        let i64_type = self.context.i64_type();
         self.new_function(
             name,
-            i64_type.fn_type(
+            self.codegen_type(body.type_()).fn_type(
                 args.iter()
-                    .map(|_| i64_type.into())
+                    .map(|(_, at)| self.codegen_type(at))
                     .collect::<Vec<_>>()
                     .as_slice(),
                 false,
             ),
         );
         self.env.push();
-        for (i, arg) in args.iter().enumerate() {
+        for (i, (arg, _)) in args.iter().enumerate() {
             self.env.set(
                 arg,
                 self.cur_function().get_nth_param(i as u32).unwrap().into(),
@@ -240,11 +244,10 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
         Ok(self.finish_function(&res))
     }
 
-    pub fn codegen_decl(&mut self, decl: &'ast Decl<'ast>) -> Result<()> {
+    pub fn codegen_decl(&mut self, decl: &'ast Decl<'ast, Type>) -> Result<()> {
         match decl {
             Decl::Fun {
-                name,
-                body: Fun { args, body },
+                name, args, body, ..
             } => {
                 self.codegen_function(name.into(), args, body)?;
                 Ok(())
@@ -252,13 +255,28 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
         }
     }
 
-    pub fn codegen_main(&mut self, expr: &'ast Expr<'ast>) -> Result<()> {
+    pub fn codegen_main(&mut self, expr: &'ast Expr<'ast, Type>) -> Result<()> {
         self.new_function("main", self.context.i64_type().fn_type(&[], false));
         let res = self.codegen_expr(expr)?.try_into().unwrap();
-        self.finish_function(&res);
+        if *expr.type_() != Type::Int {
+            self.builder
+                .build_return(Some(&self.context.i64_type().const_int(0, false)));
+        } else {
+            self.finish_function(&res);
+        }
         Ok(())
     }
 
+    fn codegen_type(&self, type_: &'ast Type) -> BasicTypeEnum<'ctx> {
+        // TODO
+        self.context.i64_type().into()
+    }
+
+    fn codegen_int_type(&self, type_: &'ast Type) -> IntType<'ctx> {
+        // TODO
+        self.context.i64_type()
+    }
+
     pub fn print_to_file<P>(&self, path: P) -> Result<()>
     where
         P: AsRef<Path>,
@@ -299,6 +317,8 @@ mod tests {
     fn jit_eval<T>(expr: &str) -> anyhow::Result<T> {
         let expr = crate::parser::expr(expr).unwrap().1;
 
+        let expr = crate::tc::typecheck_expr(expr).unwrap();
+
         let context = Context::create();
         let mut codegen = Codegen::new(&context, "test");
         let execution_engine = codegen
diff --git a/src/codegen/mod.rs b/src/codegen/mod.rs
index 6f95d90b45..8ef057dba0 100644
--- a/src/codegen/mod.rs
+++ b/src/codegen/mod.rs
@@ -4,10 +4,11 @@ use inkwell::execution_engine::JitFunction;
 use inkwell::OptimizationLevel;
 pub use llvm::*;
 
-use crate::ast::Expr;
+use crate::ast::hir::Expr;
+use crate::ast::Type;
 use crate::common::Result;
 
-pub fn jit_eval<T>(expr: &Expr) -> Result<T> {
+pub fn jit_eval<T>(expr: &Expr<Type>) -> Result<T> {
     let context = Context::create();
     let mut codegen = Codegen::new(&context, "eval");
     let execution_engine = codegen
diff --git a/src/commands/check.rs b/src/commands/check.rs
new file mode 100644
index 0000000000..40de288a28
--- /dev/null
+++ b/src/commands/check.rs
@@ -0,0 +1,39 @@
+use clap::Clap;
+use std::path::PathBuf;
+
+use crate::ast::Type;
+use crate::{parser, tc, Result};
+
+/// Typecheck a file or expression
+#[derive(Clap)]
+pub struct Check {
+    /// File to check
+    path: Option<PathBuf>,
+
+    /// Expression to check
+    #[clap(long, short = 'e')]
+    expr: Option<String>,
+}
+
+fn run_expr(expr: String) -> Result<Type> {
+    let (_, parsed) = parser::expr(&expr)?;
+    let hir_expr = tc::typecheck_expr(parsed)?;
+    Ok(hir_expr.type_().clone())
+}
+
+fn run_path(path: PathBuf) -> Result<Type> {
+    todo!()
+}
+
+impl Check {
+    pub fn run(self) -> Result<()> {
+        let type_ = match (self.path, self.expr) {
+            (None, None) => Err("Must specify either a file or expression to check".into()),
+            (Some(_), Some(_)) => Err("Cannot specify both a file and expression to check".into()),
+            (None, Some(expr)) => run_expr(expr),
+            (Some(path), None) => run_path(path),
+        }?;
+        println!("type: {}", type_);
+        Ok(())
+    }
+}
diff --git a/src/commands/eval.rs b/src/commands/eval.rs
index 112bee6462..61a712c08a 100644
--- a/src/commands/eval.rs
+++ b/src/commands/eval.rs
@@ -3,6 +3,7 @@ use clap::Clap;
 use crate::codegen;
 use crate::interpreter;
 use crate::parser;
+use crate::tc;
 use crate::Result;
 
 /// Evaluate an expression and print its result
@@ -19,10 +20,11 @@ pub struct Eval {
 impl Eval {
     pub fn run(self) -> Result<()> {
         let (_, parsed) = parser::expr(&self.expr)?;
+        let hir = tc::typecheck_expr(parsed)?;
         let result = if self.jit {
-            codegen::jit_eval::<i64>(&parsed)?.into()
+            codegen::jit_eval::<i64>(&hir)?.into()
         } else {
-            interpreter::eval(&parsed)?
+            interpreter::eval(&hir)?
         };
         println!("{}", result);
         Ok(())
diff --git a/src/commands/mod.rs b/src/commands/mod.rs
index 9c0038dabf..fd0a822708 100644
--- a/src/commands/mod.rs
+++ b/src/commands/mod.rs
@@ -1,5 +1,7 @@
+pub mod check;
 pub mod compile;
 pub mod eval;
 
+pub use check::Check;
 pub use compile::Compile;
 pub use eval::Eval;
diff --git a/src/common/env.rs b/src/common/env.rs
index f499323639..59a5e46c46 100644
--- a/src/common/env.rs
+++ b/src/common/env.rs
@@ -1,19 +1,25 @@
+use std::borrow::Borrow;
 use std::collections::HashMap;
+use std::hash::Hash;
 use std::mem;
 
-use crate::ast::Ident;
-
 /// A lexical environment
 #[derive(Debug, PartialEq, Eq)]
-pub struct Env<'ast, V>(Vec<HashMap<&'ast Ident<'ast>, V>>);
+pub struct Env<K: Eq + Hash, V>(Vec<HashMap<K, V>>);
 
-impl<'ast, V> Default for Env<'ast, V> {
+impl<K, V> Default for Env<K, V>
+where
+    K: Eq + Hash,
+{
     fn default() -> Self {
         Self::new()
     }
 }
 
-impl<'ast, V> Env<'ast, V> {
+impl<K, V> Env<K, V>
+where
+    K: Eq + Hash,
+{
     pub fn new() -> Self {
         Self(vec![Default::default()])
     }
@@ -34,11 +40,15 @@ impl<'ast, V> Env<'ast, V> {
         *self = saved;
     }
 
-    pub fn set(&mut self, k: &'ast Ident<'ast>, v: V) {
+    pub fn set(&mut self, k: K, v: V) {
         self.0.last_mut().unwrap().insert(k, v);
     }
 
-    pub fn resolve<'a>(&'a self, k: &'ast Ident<'ast>) -> Option<&'a V> {
+    pub fn resolve<'a, Q>(&'a self, k: &Q) -> Option<&'a V>
+    where
+        K: Borrow<Q>,
+        Q: Hash + Eq + ?Sized,
+    {
         for ctx in self.0.iter().rev() {
             if let Some(res) = ctx.get(k) {
                 return Some(res);
diff --git a/src/common/error.rs b/src/common/error.rs
index f3f3023cea..51575a895e 100644
--- a/src/common/error.rs
+++ b/src/common/error.rs
@@ -2,7 +2,7 @@ use std::{io, result};
 
 use thiserror::Error;
 
-use crate::{codegen, interpreter, parser};
+use crate::{codegen, interpreter, parser, tc};
 
 #[derive(Error, Debug)]
 pub enum Error {
@@ -18,6 +18,9 @@ pub enum Error {
     #[error("Compile error: {0}")]
     CodegenError(#[from] codegen::Error),
 
+    #[error("Type error: {0}")]
+    TypeError(#[from] tc::Error),
+
     #[error("{0}")]
     Message(String),
 }
@@ -28,6 +31,12 @@ impl From<String> for Error {
     }
 }
 
+impl<'a> From<&'a str> for Error {
+    fn from(s: &'a str) -> Self {
+        Self::Message(s.to_owned())
+    }
+}
+
 impl<'a> From<nom::Err<nom::error::Error<&'a str>>> for Error {
     fn from(e: nom::Err<nom::error::Error<&'a str>>) -> Self {
         use nom::error::Error as NomError;
diff --git a/src/compiler.rs b/src/compiler.rs
index 5f8e1ef4fa..f925b267df 100644
--- a/src/compiler.rs
+++ b/src/compiler.rs
@@ -8,7 +8,7 @@ use test_strategy::Arbitrary;
 
 use crate::codegen::{self, Codegen};
 use crate::common::Result;
-use crate::parser;
+use crate::{parser, tc};
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Arbitrary)]
 pub enum OutputFormat {
@@ -55,6 +55,8 @@ pub struct CompilerOptions {
 pub fn compile_file(input: &Path, output: &Path, options: &CompilerOptions) -> Result<()> {
     let src = fs::read_to_string(input)?;
     let (_, decls) = parser::toplevel(&src)?; // TODO: statements
+    let decls = tc::typecheck_toplevel(decls)?;
+
     let context = codegen::Context::create();
     let mut codegen = Codegen::new(
         &context,
diff --git a/src/interpreter/mod.rs b/src/interpreter/mod.rs
index 00421ee90d..85a8928cbf 100644
--- a/src/interpreter/mod.rs
+++ b/src/interpreter/mod.rs
@@ -3,14 +3,13 @@ mod value;
 
 pub use self::error::{Error, Result};
 pub use self::value::{Function, Value};
-use crate::ast::{
-    BinaryOperator, Binding, Expr, FunctionType, Ident, Literal, Type, UnaryOperator,
-};
+use crate::ast::hir::{Binding, Expr};
+use crate::ast::{BinaryOperator, FunctionType, Ident, Literal, Type, UnaryOperator};
 use crate::common::env::Env;
 
 #[derive(Debug, Default)]
 pub struct Interpreter<'a> {
-    env: Env<'a, Value<'a>>,
+    env: Env<&'a Ident<'a>, Value<'a>>,
 }
 
 impl<'a> Interpreter<'a> {
@@ -25,18 +24,19 @@ impl<'a> Interpreter<'a> {
             .ok_or_else(|| Error::UndefinedVariable(var.to_owned()))
     }
 
-    pub fn eval(&mut self, expr: &'a Expr<'a>) -> Result<Value<'a>> {
-        match expr {
-            Expr::Ident(id) => self.resolve(id),
-            Expr::Literal(Literal::Int(i)) => Ok((*i).into()),
-            Expr::UnaryOp { op, rhs } => {
+    pub fn eval(&mut self, expr: &'a Expr<'a, Type>) -> Result<Value<'a>> {
+        let res = match expr {
+            Expr::Ident(id, _) => self.resolve(id),
+            Expr::Literal(Literal::Int(i), _) => Ok((*i).into()),
+            Expr::Literal(Literal::Bool(b), _) => Ok((*b).into()),
+            Expr::UnaryOp { op, rhs, .. } => {
                 let rhs = self.eval(rhs)?;
                 match op {
                     UnaryOperator::Neg => -rhs,
                     _ => unimplemented!(),
                 }
             }
-            Expr::BinaryOp { lhs, op, rhs } => {
+            Expr::BinaryOp { lhs, op, rhs, .. } => {
                 let lhs = self.eval(lhs)?;
                 let rhs = self.eval(rhs)?;
                 match op {
@@ -49,7 +49,7 @@ impl<'a> Interpreter<'a> {
                     BinaryOperator::Neq => todo!(),
                 }
             }
-            Expr::Let { bindings, body } => {
+            Expr::Let { bindings, body, .. } => {
                 self.env.push();
                 for Binding { ident, body, .. } in bindings {
                     let val = self.eval(body)?;
@@ -63,6 +63,7 @@ impl<'a> Interpreter<'a> {
                 condition,
                 then,
                 else_,
+                ..
             } => {
                 let condition = self.eval(condition)?;
                 if *(condition.as_type::<bool>()?) {
@@ -71,7 +72,7 @@ impl<'a> Interpreter<'a> {
                     self.eval(else_)
                 }
             }
-            Expr::Call { ref fun, args } => {
+            Expr::Call { ref fun, args, .. } => {
                 let fun = self.eval(fun)?;
                 let expected_type = FunctionType {
                     args: args.iter().map(|_| Type::Int).collect(),
@@ -94,21 +95,26 @@ impl<'a> Interpreter<'a> {
                 }
                 Ok(Value::from(*interpreter.eval(body)?.as_type::<i64>()?))
             }
-            Expr::Fun(fun) => Ok(Value::from(value::Function {
-                // TODO
-                type_: FunctionType {
-                    args: fun.args.iter().map(|_| Type::Int).collect(),
-                    ret: Box::new(Type::Int),
-                },
-                args: fun.args.iter().map(|arg| arg.to_owned()).collect(),
-                body: fun.body.to_owned(),
-            })),
-            Expr::Ascription { expr, .. } => self.eval(expr),
-        }
+            Expr::Fun { args, body, type_ } => {
+                let type_ = match type_ {
+                    Type::Function(ft) => ft.clone(),
+                    _ => unreachable!("Function expression without function type"),
+                };
+
+                Ok(Value::from(value::Function {
+                    // TODO
+                    type_,
+                    args: args.iter().map(|(arg, _)| arg.to_owned()).collect(),
+                    body: (**body).to_owned(),
+                }))
+            }
+        }?;
+        debug_assert_eq!(&res.type_(), expr.type_());
+        Ok(res)
     }
 }
 
-pub fn eval<'a>(expr: &'a Expr<'a>) -> Result<Value> {
+pub fn eval<'a>(expr: &'a Expr<'a, Type>) -> Result<Value> {
     let mut interpreter = Interpreter::new();
     interpreter.eval(expr)
 }
@@ -121,17 +127,18 @@ mod tests {
     use super::*;
     use BinaryOperator::*;
 
-    fn int_lit(i: u64) -> Box<Expr<'static>> {
-        Box::new(Expr::Literal(Literal::Int(i)))
+    fn int_lit(i: u64) -> Box<Expr<'static, Type>> {
+        Box::new(Expr::Literal(Literal::Int(i), Type::Int))
     }
 
-    fn parse_eval<T>(src: &str) -> T
+    fn do_eval<T>(src: &str) -> T
     where
         for<'a> &'a T: TryFrom<&'a Val<'a>>,
         T: Clone + TypeOf,
     {
         let expr = crate::parser::expr(src).unwrap().1;
-        let res = eval(&expr).unwrap();
+        let hir = crate::tc::typecheck_expr(expr).unwrap();
+        let res = eval(&hir).unwrap();
         res.as_type::<T>().unwrap().clone()
     }
 
@@ -141,6 +148,7 @@ mod tests {
             lhs: int_lit(1),
             op: Mul,
             rhs: int_lit(2),
+            type_: Type::Int,
         };
         let res = eval(&expr).unwrap();
         assert_eq!(*res.as_type::<i64>().unwrap(), 2);
@@ -148,19 +156,19 @@ mod tests {
 
     #[test]
     fn variable_shadowing() {
-        let res = parse_eval::<i64>("let x = 1 in (let x = 2 in x) + x");
+        let res = do_eval::<i64>("let x = 1 in (let x = 2 in x) + x");
         assert_eq!(res, 3);
     }
 
     #[test]
     fn conditional_with_equals() {
-        let res = parse_eval::<i64>("let x = 1 in if x == 1 then 2 else 4");
+        let res = do_eval::<i64>("let x = 1 in if x == 1 then 2 else 4");
         assert_eq!(res, 2);
     }
 
     #[test]
     fn function_call() {
-        let res = parse_eval::<i64>("let id = fn x = x in id 1");
+        let res = do_eval::<i64>("let id = fn x = x in id 1");
         assert_eq!(res, 1);
     }
 }
diff --git a/src/interpreter/value.rs b/src/interpreter/value.rs
index 496e9c4230..5e55825160 100644
--- a/src/interpreter/value.rs
+++ b/src/interpreter/value.rs
@@ -6,13 +6,14 @@ use std::rc::Rc;
 use derive_more::{Deref, From, TryInto};
 
 use super::{Error, Result};
-use crate::ast::{Expr, FunctionType, Ident, Type};
+use crate::ast::hir::Expr;
+use crate::ast::{FunctionType, Ident, Type};
 
 #[derive(Debug, Clone)]
 pub struct Function<'a> {
     pub type_: FunctionType,
     pub args: Vec<Ident<'a>>,
-    pub body: Expr<'a>,
+    pub body: Expr<'a, Type>,
 }
 
 #[derive(From, TryInto)]
diff --git a/src/main.rs b/src/main.rs
index b539ebbb3d..d5b00d6b6c 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -8,6 +8,7 @@ pub mod compiler;
 pub mod interpreter;
 #[macro_use]
 pub mod parser;
+pub mod tc;
 
 pub use common::{Error, Result};
 
@@ -21,6 +22,7 @@ struct Opts {
 enum Command {
     Eval(commands::Eval),
     Compile(commands::Compile),
+    Check(commands::Check),
 }
 
 fn main() -> anyhow::Result<()> {
@@ -28,5 +30,6 @@ fn main() -> anyhow::Result<()> {
     match opts.subcommand {
         Command::Eval(eval) => Ok(eval.run()?),
         Command::Compile(compile) => Ok(compile.run()?),
+        Command::Check(check) => Ok(check.run()?),
     }
 }
diff --git a/src/parser/expr.rs b/src/parser/expr.rs
index 2fda3e93fa..73c873b5b3 100644
--- a/src/parser/expr.rs
+++ b/src/parser/expr.rs
@@ -156,7 +156,14 @@ where
 
 named!(int(&str) -> Literal, map!(flat_map!(digit1, parse_to!(u64)), Literal::Int));
 
-named!(literal(&str) -> Expr, map!(alt!(int), Expr::Literal));
+named!(bool_(&str) -> Literal, alt!(
+    tag!("true") => { |_| Literal::Bool(true) } |
+    tag!("false") => { |_| Literal::Bool(false) }
+));
+
+named!(literal(&str) -> Literal, alt!(int | bool_));
+
+named!(literal_expr(&str) -> Expr, map!(literal, Expr::Literal));
 
 named!(binding(&str) -> Binding, do_parse!(
     multispace0
@@ -262,7 +269,7 @@ named!(fun_expr(&str) -> Expr, do_parse!(
 
 named!(arg(&str) -> Expr, alt!(
     ident_expr |
-    literal |
+    literal_expr |
     paren_expr
 ));
 
@@ -280,7 +287,7 @@ named!(simple_expr_unascripted(&str) -> Expr, alt!(
     let_ |
     if_ |
     fun_expr |
-    literal |
+    literal_expr |
     ident_expr
 ));
 
@@ -400,6 +407,18 @@ pub(crate) mod tests {
     }
 
     #[test]
+    fn bools() {
+        assert_eq!(
+            test_parse!(expr, "true"),
+            Expr::Literal(Literal::Bool(true))
+        );
+        assert_eq!(
+            test_parse!(expr, "false"),
+            Expr::Literal(Literal::Bool(false))
+        );
+    }
+
+    #[test]
     fn let_complex() {
         let res = test_parse!(expr, "let x = 1; y = x * 7 in (x + y) * 4");
         assert_eq!(
diff --git a/src/parser/macros.rs b/src/parser/macros.rs
index 60db5133dc..406e5c0e69 100644
--- a/src/parser/macros.rs
+++ b/src/parser/macros.rs
@@ -1,3 +1,4 @@
+#[cfg(test)]
 #[macro_use]
 macro_rules! test_parse {
     ($parser: ident, $src: expr) => {{
diff --git a/src/parser/mod.rs b/src/parser/mod.rs
index af7dff6ff2..0251d02df4 100644
--- a/src/parser/mod.rs
+++ b/src/parser/mod.rs
@@ -14,7 +14,10 @@ pub use type_::type_;
 pub type Error = nom::Err<nom::error::Error<String>>;
 
 pub(crate) fn is_reserved(s: &str) -> bool {
-    matches!(s, "if" | "then" | "else" | "let" | "in" | "fn")
+    matches!(
+        s,
+        "if" | "then" | "else" | "let" | "in" | "fn" | "int" | "float" | "bool" | "true" | "false"
+    )
 }
 
 pub(crate) fn ident<'a, E>(i: &'a str) -> nom::IResult<&'a str, Ident, E>
diff --git a/src/tc/mod.rs b/src/tc/mod.rs
new file mode 100644
index 0000000000..b5acfac2b4
--- /dev/null
+++ b/src/tc/mod.rs
@@ -0,0 +1,528 @@
+use derive_more::From;
+use itertools::Itertools;
+use std::collections::HashMap;
+use std::convert::{TryFrom, TryInto};
+use std::fmt::{self, Display};
+use std::result;
+use thiserror::Error;
+
+use crate::ast::{self, hir, BinaryOperator, Ident, Literal};
+use crate::common::env::Env;
+
+#[derive(Debug, Error)]
+pub enum Error {
+    #[error("Undefined variable {0}")]
+    UndefinedVariable(Ident<'static>),
+
+    #[error("Mismatched types: expected {expected}, but got {actual}")]
+    TypeMismatch { expected: Type, actual: Type },
+
+    #[error("Mismatched types, expected numeric type, but got {0}")]
+    NonNumeric(Type),
+
+    #[error("Ambiguous type {0}")]
+    AmbiguousType(TyVar),
+}
+
+pub type Result<T> = result::Result<T, Error>;
+
+#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
+pub struct TyVar(u64);
+
+impl Display for TyVar {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "t{}", self.0)
+    }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone, Hash)]
+pub struct NullaryType(String);
+
+impl Display for NullaryType {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.write_str(&self.0)
+    }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub enum PrimType {
+    Int,
+    Float,
+    Bool,
+}
+
+impl From<PrimType> for ast::Type {
+    fn from(pr: PrimType) -> Self {
+        match pr {
+            PrimType::Int => ast::Type::Int,
+            PrimType::Float => ast::Type::Float,
+            PrimType::Bool => ast::Type::Bool,
+        }
+    }
+}
+
+impl Display for PrimType {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            PrimType::Int => f.write_str("int"),
+            PrimType::Float => f.write_str("float"),
+            PrimType::Bool => f.write_str("bool"),
+        }
+    }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone, From)]
+pub enum Type {
+    #[from(ignore)]
+    Univ(TyVar),
+    #[from(ignore)]
+    Exist(TyVar),
+    Nullary(NullaryType),
+    Prim(PrimType),
+    Fun {
+        args: Vec<Type>,
+        ret: Box<Type>,
+    },
+}
+
+impl PartialEq<ast::Type> for Type {
+    fn eq(&self, other: &ast::Type) -> bool {
+        match (self, other) {
+            (Type::Univ(_), _) => todo!(),
+            (Type::Exist(_), _) => false,
+            (Type::Nullary(_), _) => todo!(),
+            (Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty,
+            (Type::Fun { args, ret }, ast::Type::Function(ft)) => {
+                *args == ft.args && (**ret).eq(&*ft.ret)
+            }
+            (Type::Fun { .. }, _) => false,
+        }
+    }
+}
+
+impl TryFrom<Type> for ast::Type {
+    type Error = Type;
+
+    fn try_from(value: Type) -> result::Result<Self, Self::Error> {
+        match value {
+            Type::Univ(_) => todo!(),
+            Type::Exist(_) => Err(value),
+            Type::Nullary(_) => todo!(),
+            Type::Prim(p) => Ok(p.into()),
+            Type::Fun { ref args, ref ret } => Ok(ast::Type::Function(ast::FunctionType {
+                args: args
+                    .clone()
+                    .into_iter()
+                    .map(Self::try_from)
+                    .try_collect()
+                    .map_err(|_| value.clone())?,
+                ret: Box::new((*ret.clone()).try_into().map_err(|_| value.clone())?),
+            })),
+        }
+    }
+}
+
+const INT: Type = Type::Prim(PrimType::Int);
+const FLOAT: Type = Type::Prim(PrimType::Float);
+const BOOL: Type = Type::Prim(PrimType::Bool);
+
+impl Display for Type {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Type::Nullary(nt) => nt.fmt(f),
+            Type::Prim(p) => p.fmt(f),
+            Type::Univ(TyVar(n)) => write!(f, "∀{}", n),
+            Type::Exist(TyVar(n)) => write!(f, "∃{}", n),
+            Type::Fun { args, ret } => write!(f, "fn {} -> {}", args.iter().join(", "), ret),
+        }
+    }
+}
+
+impl From<ast::Type> for Type {
+    fn from(type_: ast::Type) -> Self {
+        match type_ {
+            ast::Type::Int => INT,
+            ast::Type::Float => FLOAT,
+            ast::Type::Bool => BOOL,
+            ast::Type::Function(ast::FunctionType { args, ret }) => Type::Fun {
+                args: args.into_iter().map(Self::from).collect(),
+                ret: Box::new(Self::from(*ret)),
+            },
+        }
+    }
+}
+
+struct Typechecker<'ast> {
+    ty_var_counter: u64,
+    ctx: HashMap<TyVar, Type>,
+    env: Env<Ident<'ast>, Type>,
+}
+
+impl<'ast> Typechecker<'ast> {
+    fn new() -> Self {
+        Self {
+            ty_var_counter: 0,
+            ctx: Default::default(),
+            env: Default::default(),
+        }
+    }
+
+    pub(crate) fn tc_expr(&mut self, expr: ast::Expr<'ast>) -> Result<hir::Expr<'ast, Type>> {
+        match expr {
+            ast::Expr::Ident(ident) => {
+                let type_ = self
+                    .env
+                    .resolve(&ident)
+                    .ok_or_else(|| Error::UndefinedVariable(ident.to_owned()))?
+                    .clone();
+                Ok(hir::Expr::Ident(ident, type_))
+            }
+            ast::Expr::Literal(lit) => {
+                let type_ = match lit {
+                    Literal::Int(_) => Type::Prim(PrimType::Int),
+                    Literal::Bool(_) => Type::Prim(PrimType::Bool),
+                };
+                Ok(hir::Expr::Literal(lit, type_))
+            }
+            ast::Expr::UnaryOp { op, rhs } => todo!(),
+            ast::Expr::BinaryOp { lhs, op, rhs } => {
+                let lhs = self.tc_expr(*lhs)?;
+                let rhs = self.tc_expr(*rhs)?;
+                let type_ = match op {
+                    BinaryOperator::Equ | BinaryOperator::Neq => {
+                        self.unify(lhs.type_(), rhs.type_())?;
+                        Type::Prim(PrimType::Bool)
+                    }
+                    BinaryOperator::Add | BinaryOperator::Sub | BinaryOperator::Mul => {
+                        let ty = self.unify(lhs.type_(), rhs.type_())?;
+                        // if !matches!(ty, Type::Int | Type::Float) {
+                        //     return Err(Error::NonNumeric(ty));
+                        // }
+                        ty
+                    }
+                    BinaryOperator::Div => todo!(),
+                    BinaryOperator::Pow => todo!(),
+                };
+                Ok(hir::Expr::BinaryOp {
+                    lhs: Box::new(lhs),
+                    op,
+                    rhs: Box::new(rhs),
+                    type_,
+                })
+            }
+            ast::Expr::Let { bindings, body } => {
+                self.env.push();
+                let bindings = bindings
+                    .into_iter()
+                    .map(
+                        |ast::Binding { ident, type_, body }| -> Result<hir::Binding<Type>> {
+                            let body = self.tc_expr(body)?;
+                            if let Some(type_) = type_ {
+                                self.unify(body.type_(), &type_.into())?;
+                            }
+                            self.env.set(ident.clone(), body.type_().clone());
+                            Ok(hir::Binding {
+                                ident,
+                                type_: body.type_().clone(),
+                                body,
+                            })
+                        },
+                    )
+                    .collect::<Result<Vec<hir::Binding<Type>>>>()?;
+                let body = self.tc_expr(*body)?;
+                self.env.pop();
+                Ok(hir::Expr::Let {
+                    bindings,
+                    type_: body.type_().clone(),
+                    body: Box::new(body),
+                })
+            }
+            ast::Expr::If {
+                condition,
+                then,
+                else_,
+            } => {
+                let condition = self.tc_expr(*condition)?;
+                self.unify(&Type::Prim(PrimType::Bool), condition.type_())?;
+                let then = self.tc_expr(*then)?;
+                let else_ = self.tc_expr(*else_)?;
+                let type_ = self.unify(then.type_(), else_.type_())?;
+                Ok(hir::Expr::If {
+                    condition: Box::new(condition),
+                    then: Box::new(then),
+                    else_: Box::new(else_),
+                    type_,
+                })
+            }
+            ast::Expr::Fun(f) => {
+                let ast::Fun { args, body } = *f;
+                self.env.push();
+                let args: Vec<_> = args
+                    .into_iter()
+                    .map(|id| {
+                        let ty = self.fresh_ex();
+                        self.env.set(id.clone(), ty.clone());
+                        (id, ty)
+                    })
+                    .collect();
+                let body = self.tc_expr(body)?;
+                self.env.pop();
+                Ok(hir::Expr::Fun {
+                    type_: Type::Fun {
+                        args: args.iter().map(|(_, ty)| ty.clone()).collect(),
+                        ret: Box::new(body.type_().clone()),
+                    },
+                    args,
+                    body: Box::new(body),
+                })
+            }
+            ast::Expr::Call { fun, args } => {
+                let ret_ty = self.fresh_ex();
+                let arg_tys = args.iter().map(|_| self.fresh_ex()).collect::<Vec<_>>();
+                let ft = Type::Fun {
+                    args: arg_tys.clone(),
+                    ret: Box::new(ret_ty.clone()),
+                };
+                let fun = self.tc_expr(*fun)?;
+                self.unify(&ft, fun.type_())?;
+                let args = args
+                    .into_iter()
+                    .zip(arg_tys)
+                    .map(|(arg, ty)| {
+                        let arg = self.tc_expr(arg)?;
+                        self.unify(&ty, arg.type_())?;
+                        Ok(arg)
+                    })
+                    .try_collect()?;
+                Ok(hir::Expr::Call {
+                    fun: Box::new(fun),
+                    args,
+                    type_: ret_ty,
+                })
+            }
+            ast::Expr::Ascription { expr, type_ } => {
+                let expr = self.tc_expr(*expr)?;
+                self.unify(expr.type_(), &type_.into())?;
+                Ok(expr)
+            }
+        }
+    }
+
+    pub(crate) fn tc_decl(&mut self, decl: ast::Decl<'ast>) -> Result<hir::Decl<'ast, Type>> {
+        match decl {
+            ast::Decl::Fun { name, body } => {
+                let body = self.tc_expr(ast::Expr::Fun(Box::new(body)))?;
+                let type_ = body.type_().clone();
+                self.env.set(name.clone(), type_);
+                match body {
+                    hir::Expr::Fun { args, body, type_ } => Ok(hir::Decl::Fun {
+                        name,
+                        args,
+                        body,
+                        type_,
+                    }),
+                    _ => unreachable!(),
+                }
+            }
+        }
+    }
+
+    fn fresh_tv(&mut self) -> TyVar {
+        self.ty_var_counter += 1;
+        TyVar(self.ty_var_counter)
+    }
+
+    fn fresh_ex(&mut self) -> Type {
+        Type::Exist(self.fresh_tv())
+    }
+
+    fn fresh_univ(&mut self) -> Type {
+        Type::Exist(self.fresh_tv())
+    }
+
+    fn universalize<'a>(&mut self, expr: hir::Expr<'a, Type>) -> hir::Expr<'a, Type> {
+        // TODO
+        expr
+    }
+
+    fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> {
+        match (ty1, ty2) {
+            (Type::Prim(p1), Type::Prim(p2)) if p1 == p2 => Ok(ty2.clone()),
+            (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) {
+                Some(existing_ty) if *ty == existing_ty => Ok(ty.clone()),
+                Some(existing_ty) => Err(Error::TypeMismatch {
+                    expected: ty.clone(),
+                    actual: existing_ty.into(),
+                }),
+                None => match self.ctx.insert(*tv, ty.clone()) {
+                    Some(existing) => self.unify(&existing, ty),
+                    None => Ok(ty.clone()),
+                },
+            },
+            (Type::Univ(u1), Type::Univ(u2)) if u1 == u2 => Ok(ty2.clone()),
+            (
+                Type::Fun {
+                    args: args1,
+                    ret: ret1,
+                },
+                Type::Fun {
+                    args: args2,
+                    ret: ret2,
+                },
+            ) => {
+                let args = args1
+                    .iter()
+                    .zip(args2)
+                    .map(|(t1, t2)| self.unify(t1, t2))
+                    .try_collect()?;
+                let ret = self.unify(ret1, ret2)?;
+                Ok(Type::Fun {
+                    args,
+                    ret: Box::new(ret),
+                })
+            }
+            (Type::Nullary(_), _) | (_, Type::Nullary(_)) => todo!(),
+            _ => Err(Error::TypeMismatch {
+                expected: ty1.clone(),
+                actual: ty2.clone(),
+            }),
+        }
+    }
+
+    fn finalize_expr(&self, expr: hir::Expr<'ast, Type>) -> Result<hir::Expr<'ast, ast::Type>> {
+        expr.traverse_type(|ty| self.finalize_type(ty))
+    }
+
+    fn finalize_decl(&self, decl: hir::Decl<'ast, Type>) -> Result<hir::Decl<'ast, ast::Type>> {
+        decl.traverse_type(|ty| self.finalize_type(ty))
+    }
+
+    fn finalize_type(&self, ty: Type) -> Result<ast::Type> {
+        match ty {
+            Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)),
+            Type::Univ(tv) => todo!(),
+            Type::Nullary(_) => todo!(),
+            Type::Prim(pr) => Ok(pr.into()),
+            Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType {
+                args: args
+                    .into_iter()
+                    .map(|ty| self.finalize_type(ty))
+                    .try_collect()?,
+                ret: Box::new(self.finalize_type(*ret)?),
+            })),
+        }
+    }
+
+    fn resolve_tv(&self, tv: TyVar) -> Option<ast::Type> {
+        let mut res = &Type::Exist(tv);
+        loop {
+            match res {
+                Type::Exist(tv) => {
+                    res = self.ctx.get(tv)?;
+                }
+                Type::Univ(_) => todo!(),
+                Type::Nullary(_) => todo!(),
+                Type::Prim(pr) => break Some((*pr).into()),
+                Type::Fun { args, ret } => todo!(),
+            }
+        }
+    }
+}
+
+pub fn typecheck_expr(expr: ast::Expr) -> Result<hir::Expr<ast::Type>> {
+    let mut typechecker = Typechecker::new();
+    let typechecked = typechecker.tc_expr(expr)?;
+    typechecker.finalize_expr(typechecked)
+}
+
+pub fn typecheck_toplevel(decls: Vec<ast::Decl>) -> Result<Vec<hir::Decl<ast::Type>>> {
+    let mut typechecker = Typechecker::new();
+    decls
+        .into_iter()
+        .map(|decl| {
+            let decl = typechecker.tc_decl(decl)?;
+            typechecker.finalize_decl(decl)
+        })
+        .try_collect()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    macro_rules! assert_type {
+        ($expr: expr, $type: expr) => {
+            use crate::parser::{expr, type_};
+            let parsed_expr = test_parse!(expr, $expr);
+            let parsed_type = test_parse!(type_, $type);
+            let res = typecheck_expr(parsed_expr).unwrap_or_else(|e| panic!("{}", e));
+            assert_eq!(res.type_(), &parsed_type);
+        };
+    }
+
+    macro_rules! assert_type_error {
+        ($expr: expr) => {
+            use crate::parser::expr;
+            let parsed_expr = test_parse!(expr, $expr);
+            let res = typecheck_expr(parsed_expr);
+            assert!(
+                res.is_err(),
+                "Expected type error, but got type: {}",
+                res.unwrap().type_()
+            );
+        };
+    }
+
+    #[test]
+    fn literal_int() {
+        assert_type!("1", "int");
+    }
+
+    #[test]
+    fn conditional() {
+        assert_type!("if 1 == 2 then 3 else 4", "int");
+    }
+
+    #[test]
+    #[ignore]
+    fn add_bools() {
+        assert_type_error!("true + false");
+    }
+
+    #[test]
+    fn call_generic_function() {
+        assert_type!("(fn x = x) 1", "int");
+    }
+
+    #[test]
+    #[ignore]
+    fn generic_function() {
+        assert_type!("fn x = x", "fn x, y -> x");
+    }
+
+    #[test]
+    #[ignore]
+    fn let_generalization() {
+        assert_type!("let id = fn x = x in if id true then id 1 else 2", "int");
+    }
+
+    #[test]
+    fn concrete_function() {
+        assert_type!("fn x = x + 1", "fn int -> int");
+    }
+
+    #[test]
+    fn call_concrete_function() {
+        assert_type!("(fn x = x + 1) 2", "int");
+    }
+
+    #[test]
+    fn conditional_non_bool() {
+        assert_type_error!("if 3 then true else false");
+    }
+
+    #[test]
+    fn let_int() {
+        assert_type!("let x = 1 in x", "int");
+    }
+}