about summary refs log tree commit diff
diff options
context:
space:
mode:
authorGriffin Smith <root@gws.fyi>2021-03-14T02·57-0500
committerGriffin Smith <root@gws.fyi>2021-03-14T03·07-0500
commit32a5c0ff0fc58aa6721c1e0ad41950bde2d66744 (patch)
treeef5dcf5234c2a86607ee2f8f30db73bad016e075
parentf8beda81fbe8d04883aee71ff4ea078f897c6de4 (diff)
Add the start of a hindley-milner typechecker
The beginning of a parse-don't-validate-based hindley-milner
typechecker, which returns on success an IR where every AST node
trivially knows its own type, and using those types to determine LLVM
types in codegen.
-rw-r--r--Cargo.lock1
-rw-r--r--Cargo.toml1
-rw-r--r--ach/.gitignore3
-rw-r--r--src/ast/hir.rs246
-rw-r--r--src/ast/mod.rs3
-rw-r--r--src/codegen/llvm.rs76
-rw-r--r--src/codegen/mod.rs5
-rw-r--r--src/commands/check.rs39
-rw-r--r--src/commands/eval.rs6
-rw-r--r--src/commands/mod.rs2
-rw-r--r--src/common/env.rs24
-rw-r--r--src/common/error.rs11
-rw-r--r--src/compiler.rs4
-rw-r--r--src/interpreter/mod.rs70
-rw-r--r--src/interpreter/value.rs5
-rw-r--r--src/main.rs3
-rw-r--r--src/parser/expr.rs25
-rw-r--r--src/parser/macros.rs1
-rw-r--r--src/parser/mod.rs5
-rw-r--r--src/tc/mod.rs528
20 files changed, 980 insertions, 78 deletions
diff --git a/Cargo.lock b/Cargo.lock
index d8eaedeca1..8ec5ad6cf9 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -9,6 +9,7 @@ dependencies = [
  "derive_more",
  "inkwell",
  "itertools",
+ "lazy_static",
  "llvm-sys",
  "nom",
  "nom-trace",
diff --git a/Cargo.toml b/Cargo.toml
index c9796a8215..2ac7d25409 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -10,6 +10,7 @@ clap = "3.0.0-beta.2"
 derive_more = "0.99.11"
 inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm11-0"] }
 itertools = "0.10.0"
+lazy_static = "1.4.0"
 llvm-sys = "110.0.1"
 nom = "6.1.2"
 nom-trace = { git = "https://github.com/glittershark/nom-trace", branch = "nom-6" }
diff --git a/ach/.gitignore b/ach/.gitignore
index e8423ae351..683a53a01f 100644
--- a/ach/.gitignore
+++ b/ach/.gitignore
@@ -1,2 +1,5 @@
 *.ll
 *.o
+
+functions
+simple
diff --git a/src/ast/hir.rs b/src/ast/hir.rs
new file mode 100644
index 0000000000..151ddd5298
--- /dev/null
+++ b/src/ast/hir.rs
@@ -0,0 +1,246 @@
+use itertools::Itertools;
+
+use super::{BinaryOperator, Ident, Literal, UnaryOperator};
+
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub struct Binding<'a, T> {
+    pub ident: Ident<'a>,
+    pub type_: T,
+    pub body: Expr<'a, T>,
+}
+
+impl<'a, T> Binding<'a, T> {
+    fn to_owned(&self) -> Binding<'static, T>
+    where
+        T: Clone,
+    {
+        Binding {
+            ident: self.ident.to_owned(),
+            type_: self.type_.clone(),
+            body: self.body.to_owned(),
+        }
+    }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub enum Expr<'a, T> {
+    Ident(Ident<'a>, T),
+
+    Literal(Literal, T),
+
+    UnaryOp {
+        op: UnaryOperator,
+        rhs: Box<Expr<'a, T>>,
+        type_: T,
+    },
+
+    BinaryOp {
+        lhs: Box<Expr<'a, T>>,
+        op: BinaryOperator,
+        rhs: Box<Expr<'a, T>>,
+        type_: T,
+    },
+
+    Let {
+        bindings: Vec<Binding<'a, T>>,
+        body: Box<Expr<'a, T>>,
+        type_: T,
+    },
+
+    If {
+        condition: Box<Expr<'a, T>>,
+        then: Box<Expr<'a, T>>,
+        else_: Box<Expr<'a, T>>,
+        type_: T,
+    },
+
+    Fun {
+        args: Vec<(Ident<'a>, T)>,
+        body: Box<Expr<'a, T>>,
+        type_: T,
+    },
+
+    Call {
+        fun: Box<Expr<'a, T>>,
+        args: Vec<Expr<'a, T>>,
+        type_: T,
+    },
+}
+
+impl<'a, T> Expr<'a, T> {
+    pub fn type_(&self) -> &T {
+        match self {
+            Expr::Ident(_, t) => t,
+            Expr::Literal(_, t) => t,
+            Expr::UnaryOp { type_, .. } => type_,
+            Expr::BinaryOp { type_, .. } => type_,
+            Expr::Let { type_, .. } => type_,
+            Expr::If { type_, .. } => type_,
+            Expr::Fun { type_, .. } => type_,
+            Expr::Call { type_, .. } => type_,
+        }
+    }
+
+    pub fn traverse_type<F, U, E>(self, f: F) -> Result<Expr<'a, U>, E>
+    where
+        F: Fn(T) -> Result<U, E> + Clone,
+    {
+        match self {
+            Expr::Ident(id, t) => Ok(Expr::Ident(id, f(t)?)),
+            Expr::Literal(lit, t) => Ok(Expr::Literal(lit, f(t)?)),
+            Expr::UnaryOp { op, rhs, type_ } => Ok(Expr::UnaryOp {
+                op,
+                rhs: Box::new(rhs.traverse_type(f.clone())?),
+                type_: f(type_)?,
+            }),
+            Expr::BinaryOp {
+                lhs,
+                op,
+                rhs,
+                type_,
+            } => Ok(Expr::BinaryOp {
+                lhs: Box::new(lhs.traverse_type(f.clone())?),
+                op,
+                rhs: Box::new(rhs.traverse_type(f.clone())?),
+                type_: f(type_)?,
+            }),
+            Expr::Let {
+                bindings,
+                body,
+                type_,
+            } => Ok(Expr::Let {
+                bindings: bindings
+                    .into_iter()
+                    .map(|Binding { ident, type_, body }| {
+                        Ok(Binding {
+                            ident,
+                            type_: f(type_)?,
+                            body: body.traverse_type(f.clone())?,
+                        })
+                    })
+                    .collect::<Result<Vec<_>, E>>()?,
+                body: Box::new(body.traverse_type(f.clone())?),
+                type_: f(type_)?,
+            }),
+            Expr::If {
+                condition,
+                then,
+                else_,
+                type_,
+            } => Ok(Expr::If {
+                condition: Box::new(condition.traverse_type(f.clone())?),
+                then: Box::new(then.traverse_type(f.clone())?),
+                else_: Box::new(else_.traverse_type(f.clone())?),
+                type_: f(type_)?,
+            }),
+            Expr::Fun { args, body, type_ } => Ok(Expr::Fun {
+                args: args
+                    .into_iter()
+                    .map(|(id, t)| Ok((id, f.clone()(t)?)))
+                    .collect::<Result<Vec<_>, E>>()?,
+                body: Box::new(body.traverse_type(f.clone())?),
+                type_: f(type_)?,
+            }),
+            Expr::Call { fun, args, type_ } => Ok(Expr::Call {
+                fun: Box::new(fun.traverse_type(f.clone())?),
+                args: args
+                    .into_iter()
+                    .map(|e| e.traverse_type(f.clone()))
+                    .collect::<Result<Vec<_>, E>>()?,
+                type_: f(type_)?,
+            }),
+        }
+    }
+
+    pub fn to_owned(&self) -> Expr<'static, T>
+    where
+        T: Clone,
+    {
+        match self {
+            Expr::Ident(id, t) => Expr::Ident(id.to_owned(), t.clone()),
+            Expr::Literal(lit, t) => Expr::Literal(lit.clone(), t.clone()),
+            Expr::UnaryOp { op, rhs, type_ } => Expr::UnaryOp {
+                op: *op,
+                rhs: Box::new((**rhs).to_owned()),
+                type_: type_.clone(),
+            },
+            Expr::BinaryOp {
+                lhs,
+                op,
+                rhs,
+                type_,
+            } => Expr::BinaryOp {
+                lhs: Box::new((**lhs).to_owned()),
+                op: *op,
+                rhs: Box::new((**rhs).to_owned()),
+                type_: type_.clone(),
+            },
+            Expr::Let {
+                bindings,
+                body,
+                type_,
+            } => Expr::Let {
+                bindings: bindings.into_iter().map(|b| b.to_owned()).collect(),
+                body: Box::new((**body).to_owned()),
+                type_: type_.clone(),
+            },
+            Expr::If {
+                condition,
+                then,
+                else_,
+                type_,
+            } => Expr::If {
+                condition: Box::new((**condition).to_owned()),
+                then: Box::new((**then).to_owned()),
+                else_: Box::new((**else_).to_owned()),
+                type_: type_.clone(),
+            },
+            Expr::Fun { args, body, type_ } => Expr::Fun {
+                args: args
+                    .into_iter()
+                    .map(|(id, t)| (id.to_owned(), t.clone()))
+                    .collect(),
+                body: Box::new((**body).to_owned()),
+                type_: type_.clone(),
+            },
+            Expr::Call { fun, args, type_ } => Expr::Call {
+                fun: Box::new((**fun).to_owned()),
+                args: args.into_iter().map(|e| e.to_owned()).collect(),
+                type_: type_.clone(),
+            },
+        }
+    }
+}
+
+pub enum Decl<'a, T> {
+    Fun {
+        name: Ident<'a>,
+        args: Vec<(Ident<'a>, T)>,
+        body: Box<Expr<'a, T>>,
+        type_: T,
+    },
+}
+
+impl<'a, T> Decl<'a, T> {
+    pub fn traverse_type<F, U, E>(self, f: F) -> Result<Decl<'a, U>, E>
+    where
+        F: Fn(T) -> Result<U, E> + Clone,
+    {
+        match self {
+            Decl::Fun {
+                name,
+                args,
+                body,
+                type_,
+            } => Ok(Decl::Fun {
+                name,
+                args: args
+                    .into_iter()
+                    .map(|(id, t)| Ok((id, f(t)?)))
+                    .try_collect()?,
+                body: Box::new(body.traverse_type(f.clone())?),
+                type_: f(type_)?,
+            }),
+        }
+    }
+}
diff --git a/src/ast/mod.rs b/src/ast/mod.rs
index dc22ac3cdb..cef366d16e 100644
--- a/src/ast/mod.rs
+++ b/src/ast/mod.rs
@@ -1,3 +1,5 @@
+pub(crate) mod hir;
+
 use std::borrow::Cow;
 use std::convert::TryFrom;
 use std::fmt::{self, Display, Formatter};
@@ -107,6 +109,7 @@ pub enum UnaryOperator {
 #[derive(Debug, PartialEq, Eq, Clone)]
 pub enum Literal {
     Int(u64),
+    Bool(bool),
 }
 
 #[derive(Debug, PartialEq, Eq, Clone)]
diff --git a/src/codegen/llvm.rs b/src/codegen/llvm.rs
index 1f4a457cd8..5b5db90a1a 100644
--- a/src/codegen/llvm.rs
+++ b/src/codegen/llvm.rs
@@ -7,12 +7,13 @@ use inkwell::builder::Builder;
 pub use inkwell::context::Context;
 use inkwell::module::Module;
 use inkwell::support::LLVMString;
-use inkwell::types::FunctionType;
+use inkwell::types::{BasicType, BasicTypeEnum, FunctionType, IntType};
 use inkwell::values::{AnyValueEnum, BasicValueEnum, FunctionValue};
 use inkwell::IntPredicate;
 use thiserror::Error;
 
-use crate::ast::{BinaryOperator, Binding, Decl, Expr, Fun, Ident, Literal, UnaryOperator};
+use crate::ast::hir::{Binding, Decl, Expr};
+use crate::ast::{BinaryOperator, Ident, Literal, Type, UnaryOperator};
 use crate::common::env::Env;
 
 #[derive(Debug, PartialEq, Eq, Error)]
@@ -36,7 +37,7 @@ pub struct Codegen<'ctx, 'ast> {
     context: &'ctx Context,
     pub module: Module<'ctx>,
     builder: Builder<'ctx>,
-    env: Env<'ast, AnyValueEnum<'ctx>>,
+    env: Env<&'ast Ident<'ast>, AnyValueEnum<'ctx>>,
     function_stack: Vec<FunctionValue<'ctx>>,
     identifier_counter: u32,
 }
@@ -77,18 +78,23 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
             .append_basic_block(*self.function_stack.last().unwrap(), name)
     }
 
-    pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast>) -> Result<AnyValueEnum<'ctx>> {
+    pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast, Type>) -> Result<AnyValueEnum<'ctx>> {
         match expr {
-            Expr::Ident(id) => self
+            Expr::Ident(id, _) => self
                 .env
                 .resolve(id)
                 .cloned()
                 .ok_or_else(|| Error::UndefinedVariable(id.to_owned())),
-            Expr::Literal(Literal::Int(i)) => {
-                let ty = self.context.i64_type();
-                Ok(AnyValueEnum::IntValue(ty.const_int(*i, false)))
+            Expr::Literal(lit, ty) => {
+                let ty = self.codegen_int_type(ty);
+                match lit {
+                    Literal::Int(i) => Ok(AnyValueEnum::IntValue(ty.const_int(*i, false))),
+                    Literal::Bool(b) => Ok(AnyValueEnum::IntValue(
+                        ty.const_int(if *b { 1 } else { 0 }, false),
+                    )),
+                }
             }
-            Expr::UnaryOp { op, rhs } => {
+            Expr::UnaryOp { op, rhs, .. } => {
                 let rhs = self.codegen_expr(rhs)?;
                 match op {
                     UnaryOperator::Not => unimplemented!(),
@@ -97,7 +103,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                     )),
                 }
             }
-            Expr::BinaryOp { lhs, op, rhs } => {
+            Expr::BinaryOp { lhs, op, rhs, .. } => {
                 let lhs = self.codegen_expr(lhs)?;
                 let rhs = self.codegen_expr(rhs)?;
                 match op {
@@ -135,7 +141,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                     BinaryOperator::Neq => todo!(),
                 }
             }
-            Expr::Let { bindings, body } => {
+            Expr::Let { bindings, body, .. } => {
                 self.env.push();
                 for Binding { ident, body, .. } in bindings {
                     let val = self.codegen_expr(body)?;
@@ -149,6 +155,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                 condition,
                 then,
                 else_,
+                type_,
             } => {
                 let then_block = self.append_basic_block("then");
                 let else_block = self.append_basic_block("else");
@@ -168,15 +175,15 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                 self.builder.build_unconditional_branch(join_block);
 
                 self.builder.position_at_end(join_block);
-                let phi = self.builder.build_phi(self.context.i64_type(), "join");
+                let phi = self.builder.build_phi(self.codegen_type(type_), "join");
                 phi.add_incoming(&[
                     (&BasicValueEnum::try_from(then_res).unwrap(), then_block),
                     (&BasicValueEnum::try_from(else_res).unwrap(), else_block),
                 ]);
                 Ok(phi.as_basic_value().into())
             }
-            Expr::Call { fun, args } => {
-                if let Expr::Ident(id) = &**fun {
+            Expr::Call { fun, args, .. } => {
+                if let Expr::Ident(id, _) = &**fun {
                     let function = self
                         .module
                         .get_function(id.into())
@@ -197,8 +204,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                     todo!()
                 }
             }
-            Expr::Fun(fun) => {
-                let Fun { args, body } = &**fun;
+            Expr::Fun { args, body, .. } => {
                 let fname = self.fresh_ident("f");
                 let cur_block = self.builder.get_insert_block().unwrap();
                 let env = self.env.save(); // TODO: closures
@@ -207,29 +213,27 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                 self.env.restore(env);
                 Ok(function.into())
             }
-            Expr::Ascription { expr, .. } => self.codegen_expr(expr),
         }
     }
 
     pub fn codegen_function(
         &mut self,
         name: &str,
-        args: &'ast [Ident<'ast>],
-        body: &'ast Expr<'ast>,
+        args: &'ast [(Ident<'ast>, Type)],
+        body: &'ast Expr<'ast, Type>,
     ) -> Result<FunctionValue<'ctx>> {
-        let i64_type = self.context.i64_type();
         self.new_function(
             name,
-            i64_type.fn_type(
+            self.codegen_type(body.type_()).fn_type(
                 args.iter()
-                    .map(|_| i64_type.into())
+                    .map(|(_, at)| self.codegen_type(at))
                     .collect::<Vec<_>>()
                     .as_slice(),
                 false,
             ),
         );
         self.env.push();
-        for (i, arg) in args.iter().enumerate() {
+        for (i, (arg, _)) in args.iter().enumerate() {
             self.env.set(
                 arg,
                 self.cur_function().get_nth_param(i as u32).unwrap().into(),
@@ -240,11 +244,10 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
         Ok(self.finish_function(&res))
     }
 
-    pub fn codegen_decl(&mut self, decl: &'ast Decl<'ast>) -> Result<()> {
+    pub fn codegen_decl(&mut self, decl: &'ast Decl<'ast, Type>) -> Result<()> {
         match decl {
             Decl::Fun {
-                name,
-                body: Fun { args, body },
+                name, args, body, ..
             } => {
                 self.codegen_function(name.into(), args, body)?;
                 Ok(())
@@ -252,13 +255,28 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
         }
     }
 
-    pub fn codegen_main(&mut self, expr: &'ast Expr<'ast>) -> Result<()> {
+    pub fn codegen_main(&mut self, expr: &'ast Expr<'ast, Type>) -> Result<()> {
         self.new_function("main", self.context.i64_type().fn_type(&[], false));
         let res = self.codegen_expr(expr)?.try_into().unwrap();
-        self.finish_function(&res);
+        if *expr.type_() != Type::Int {
+            self.builder
+                .build_return(Some(&self.context.i64_type().const_int(0, false)));
+        } else {
+            self.finish_function(&res);
+        }
         Ok(())
     }
 
+    fn codegen_type(&self, type_: &'ast Type) -> BasicTypeEnum<'ctx> {
+        // TODO
+        self.context.i64_type().into()
+    }
+
+    fn codegen_int_type(&self, type_: &'ast Type) -> IntType<'ctx> {
+        // TODO
+        self.context.i64_type()
+    }
+
     pub fn print_to_file<P>(&self, path: P) -> Result<()>
     where
         P: AsRef<Path>,
@@ -299,6 +317,8 @@ mod tests {
     fn jit_eval<T>(expr: &str) -> anyhow::Result<T> {
         let expr = crate::parser::expr(expr).unwrap().1;
 
+        let expr = crate::tc::typecheck_expr(expr).unwrap();
+
         let context = Context::create();
         let mut codegen = Codegen::new(&context, "test");
         let execution_engine = codegen
diff --git a/src/codegen/mod.rs b/src/codegen/mod.rs
index 6f95d90b45..8ef057dba0 100644
--- a/src/codegen/mod.rs
+++ b/src/codegen/mod.rs
@@ -4,10 +4,11 @@ use inkwell::execution_engine::JitFunction;
 use inkwell::OptimizationLevel;
 pub use llvm::*;
 
-use crate::ast::Expr;
+use crate::ast::hir::Expr;
+use crate::ast::Type;
 use crate::common::Result;
 
-pub fn jit_eval<T>(expr: &Expr) -> Result<T> {
+pub fn jit_eval<T>(expr: &Expr<Type>) -> Result<T> {
     let context = Context::create();
     let mut codegen = Codegen::new(&context, "eval");
     let execution_engine = codegen
diff --git a/src/commands/check.rs b/src/commands/check.rs
new file mode 100644
index 0000000000..40de288a28
--- /dev/null
+++ b/src/commands/check.rs
@@ -0,0 +1,39 @@
+use clap::Clap;
+use std::path::PathBuf;
+
+use crate::ast::Type;
+use crate::{parser, tc, Result};
+
+/// Typecheck a file or expression
+#[derive(Clap)]
+pub struct Check {
+    /// File to check
+    path: Option<PathBuf>,
+
+    /// Expression to check
+    #[clap(long, short = 'e')]
+    expr: Option<String>,
+}
+
+fn run_expr(expr: String) -> Result<Type> {
+    let (_, parsed) = parser::expr(&expr)?;
+    let hir_expr = tc::typecheck_expr(parsed)?;
+    Ok(hir_expr.type_().clone())
+}
+
+fn run_path(path: PathBuf) -> Result<Type> {
+    todo!()
+}
+
+impl Check {
+    pub fn run(self) -> Result<()> {
+        let type_ = match (self.path, self.expr) {
+            (None, None) => Err("Must specify either a file or expression to check".into()),
+            (Some(_), Some(_)) => Err("Cannot specify both a file and expression to check".into()),
+            (None, Some(expr)) => run_expr(expr),
+            (Some(path), None) => run_path(path),
+        }?;
+        println!("type: {}", type_);
+        Ok(())
+    }
+}
diff --git a/src/commands/eval.rs b/src/commands/eval.rs
index 112bee6462..61a712c08a 100644
--- a/src/commands/eval.rs
+++ b/src/commands/eval.rs
@@ -3,6 +3,7 @@ use clap::Clap;
 use crate::codegen;
 use crate::interpreter;
 use crate::parser;
+use crate::tc;
 use crate::Result;
 
 /// Evaluate an expression and print its result
@@ -19,10 +20,11 @@ pub struct Eval {
 impl Eval {
     pub fn run(self) -> Result<()> {
         let (_, parsed) = parser::expr(&self.expr)?;
+        let hir = tc::typecheck_expr(parsed)?;
         let result = if self.jit {
-            codegen::jit_eval::<i64>(&parsed)?.into()
+            codegen::jit_eval::<i64>(&hir)?.into()
         } else {
-            interpreter::eval(&parsed)?
+            interpreter::eval(&hir)?
         };
         println!("{}", result);
         Ok(())
diff --git a/src/commands/mod.rs b/src/commands/mod.rs
index 9c0038dabf..fd0a822708 100644
--- a/src/commands/mod.rs
+++ b/src/commands/mod.rs
@@ -1,5 +1,7 @@
+pub mod check;
 pub mod compile;
 pub mod eval;
 
+pub use check::Check;
 pub use compile::Compile;
 pub use eval::Eval;
diff --git a/src/common/env.rs b/src/common/env.rs
index f499323639..59a5e46c46 100644
--- a/src/common/env.rs
+++ b/src/common/env.rs
@@ -1,19 +1,25 @@
+use std::borrow::Borrow;
 use std::collections::HashMap;
+use std::hash::Hash;
 use std::mem;
 
-use crate::ast::Ident;
-
 /// A lexical environment
 #[derive(Debug, PartialEq, Eq)]
-pub struct Env<'ast, V>(Vec<HashMap<&'ast Ident<'ast>, V>>);
+pub struct Env<K: Eq + Hash, V>(Vec<HashMap<K, V>>);
 
-impl<'ast, V> Default for Env<'ast, V> {
+impl<K, V> Default for Env<K, V>
+where
+    K: Eq + Hash,
+{
     fn default() -> Self {
         Self::new()
     }
 }
 
-impl<'ast, V> Env<'ast, V> {
+impl<K, V> Env<K, V>
+where
+    K: Eq + Hash,
+{
     pub fn new() -> Self {
         Self(vec![Default::default()])
     }
@@ -34,11 +40,15 @@ impl<'ast, V> Env<'ast, V> {
         *self = saved;
     }
 
-    pub fn set(&mut self, k: &'ast Ident<'ast>, v: V) {
+    pub fn set(&mut self, k: K, v: V) {
         self.0.last_mut().unwrap().insert(k, v);
     }
 
-    pub fn resolve<'a>(&'a self, k: &'ast Ident<'ast>) -> Option<&'a V> {
+    pub fn resolve<'a, Q>(&'a self, k: &Q) -> Option<&'a V>
+    where
+        K: Borrow<Q>,
+        Q: Hash + Eq + ?Sized,
+    {
         for ctx in self.0.iter().rev() {
             if let Some(res) = ctx.get(k) {
                 return Some(res);
diff --git a/src/common/error.rs b/src/common/error.rs
index f3f3023cea..51575a895e 100644
--- a/src/common/error.rs
+++ b/src/common/error.rs
@@ -2,7 +2,7 @@ use std::{io, result};
 
 use thiserror::Error;
 
-use crate::{codegen, interpreter, parser};
+use crate::{codegen, interpreter, parser, tc};
 
 #[derive(Error, Debug)]
 pub enum Error {
@@ -18,6 +18,9 @@ pub enum Error {
     #[error("Compile error: {0}")]
     CodegenError(#[from] codegen::Error),
 
+    #[error("Type error: {0}")]
+    TypeError(#[from] tc::Error),
+
     #[error("{0}")]
     Message(String),
 }
@@ -28,6 +31,12 @@ impl From<String> for Error {
     }
 }
 
+impl<'a> From<&'a str> for Error {
+    fn from(s: &'a str) -> Self {
+        Self::Message(s.to_owned())
+    }
+}
+
 impl<'a> From<nom::Err<nom::error::Error<&'a str>>> for Error {
     fn from(e: nom::Err<nom::error::Error<&'a str>>) -> Self {
         use nom::error::Error as NomError;
diff --git a/src/compiler.rs b/src/compiler.rs
index 5f8e1ef4fa..f925b267df 100644
--- a/src/compiler.rs
+++ b/src/compiler.rs
@@ -8,7 +8,7 @@ use test_strategy::Arbitrary;
 
 use crate::codegen::{self, Codegen};
 use crate::common::Result;
-use crate::parser;
+use crate::{parser, tc};
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Arbitrary)]
 pub enum OutputFormat {
@@ -55,6 +55,8 @@ pub struct CompilerOptions {
 pub fn compile_file(input: &Path, output: &Path, options: &CompilerOptions) -> Result<()> {
     let src = fs::read_to_string(input)?;
     let (_, decls) = parser::toplevel(&src)?; // TODO: statements
+    let decls = tc::typecheck_toplevel(decls)?;
+
     let context = codegen::Context::create();
     let mut codegen = Codegen::new(
         &context,
diff --git a/src/interpreter/mod.rs b/src/interpreter/mod.rs
index 00421ee90d..85a8928cbf 100644
--- a/src/interpreter/mod.rs
+++ b/src/interpreter/mod.rs
@@ -3,14 +3,13 @@ mod value;
 
 pub use self::error::{Error, Result};
 pub use self::value::{Function, Value};
-use crate::ast::{
-    BinaryOperator, Binding, Expr, FunctionType, Ident, Literal, Type, UnaryOperator,
-};
+use crate::ast::hir::{Binding, Expr};
+use crate::ast::{BinaryOperator, FunctionType, Ident, Literal, Type, UnaryOperator};
 use crate::common::env::Env;
 
 #[derive(Debug, Default)]
 pub struct Interpreter<'a> {
-    env: Env<'a, Value<'a>>,
+    env: Env<&'a Ident<'a>, Value<'a>>,
 }
 
 impl<'a> Interpreter<'a> {
@@ -25,18 +24,19 @@ impl<'a> Interpreter<'a> {
             .ok_or_else(|| Error::UndefinedVariable(var.to_owned()))
     }
 
-    pub fn eval(&mut self, expr: &'a Expr<'a>) -> Result<Value<'a>> {
-        match expr {
-            Expr::Ident(id) => self.resolve(id),
-            Expr::Literal(Literal::Int(i)) => Ok((*i).into()),
-            Expr::UnaryOp { op, rhs } => {
+    pub fn eval(&mut self, expr: &'a Expr<'a, Type>) -> Result<Value<'a>> {
+        let res = match expr {
+            Expr::Ident(id, _) => self.resolve(id),
+            Expr::Literal(Literal::Int(i), _) => Ok((*i).into()),
+            Expr::Literal(Literal::Bool(b), _) => Ok((*b).into()),
+            Expr::UnaryOp { op, rhs, .. } => {
                 let rhs = self.eval(rhs)?;
                 match op {
                     UnaryOperator::Neg => -rhs,
                     _ => unimplemented!(),
                 }
             }
-            Expr::BinaryOp { lhs, op, rhs } => {
+            Expr::BinaryOp { lhs, op, rhs, .. } => {
                 let lhs = self.eval(lhs)?;
                 let rhs = self.eval(rhs)?;
                 match op {
@@ -49,7 +49,7 @@ impl<'a> Interpreter<'a> {
                     BinaryOperator::Neq => todo!(),
                 }
             }
-            Expr::Let { bindings, body } => {
+            Expr::Let { bindings, body, .. } => {
                 self.env.push();
                 for Binding { ident, body, .. } in bindings {
                     let val = self.eval(body)?;
@@ -63,6 +63,7 @@ impl<'a> Interpreter<'a> {
                 condition,
                 then,
                 else_,
+                ..
             } => {
                 let condition = self.eval(condition)?;
                 if *(condition.as_type::<bool>()?) {
@@ -71,7 +72,7 @@ impl<'a> Interpreter<'a> {
                     self.eval(else_)
                 }
             }
-            Expr::Call { ref fun, args } => {
+            Expr::Call { ref fun, args, .. } => {
                 let fun = self.eval(fun)?;
                 let expected_type = FunctionType {
                     args: args.iter().map(|_| Type::Int).collect(),
@@ -94,21 +95,26 @@ impl<'a> Interpreter<'a> {
                 }
                 Ok(Value::from(*interpreter.eval(body)?.as_type::<i64>()?))
             }
-            Expr::Fun(fun) => Ok(Value::from(value::Function {
-                // TODO
-                type_: FunctionType {
-                    args: fun.args.iter().map(|_| Type::Int).collect(),
-                    ret: Box::new(Type::Int),
-                },
-                args: fun.args.iter().map(|arg| arg.to_owned()).collect(),
-                body: fun.body.to_owned(),
-            })),
-            Expr::Ascription { expr, .. } => self.eval(expr),
-        }
+            Expr::Fun { args, body, type_ } => {
+                let type_ = match type_ {
+                    Type::Function(ft) => ft.clone(),
+                    _ => unreachable!("Function expression without function type"),
+                };
+
+                Ok(Value::from(value::Function {
+                    // TODO
+                    type_,
+                    args: args.iter().map(|(arg, _)| arg.to_owned()).collect(),
+                    body: (**body).to_owned(),
+                }))
+            }
+        }?;
+        debug_assert_eq!(&res.type_(), expr.type_());
+        Ok(res)
     }
 }
 
-pub fn eval<'a>(expr: &'a Expr<'a>) -> Result<Value> {
+pub fn eval<'a>(expr: &'a Expr<'a, Type>) -> Result<Value> {
     let mut interpreter = Interpreter::new();
     interpreter.eval(expr)
 }
@@ -121,17 +127,18 @@ mod tests {
     use super::*;
     use BinaryOperator::*;
 
-    fn int_lit(i: u64) -> Box<Expr<'static>> {
-        Box::new(Expr::Literal(Literal::Int(i)))
+    fn int_lit(i: u64) -> Box<Expr<'static, Type>> {
+        Box::new(Expr::Literal(Literal::Int(i), Type::Int))
     }
 
-    fn parse_eval<T>(src: &str) -> T
+    fn do_eval<T>(src: &str) -> T
     where
         for<'a> &'a T: TryFrom<&'a Val<'a>>,
         T: Clone + TypeOf,
     {
         let expr = crate::parser::expr(src).unwrap().1;
-        let res = eval(&expr).unwrap();
+        let hir = crate::tc::typecheck_expr(expr).unwrap();
+        let res = eval(&hir).unwrap();
         res.as_type::<T>().unwrap().clone()
     }
 
@@ -141,6 +148,7 @@ mod tests {
             lhs: int_lit(1),
             op: Mul,
             rhs: int_lit(2),
+            type_: Type::Int,
         };
         let res = eval(&expr).unwrap();
         assert_eq!(*res.as_type::<i64>().unwrap(), 2);
@@ -148,19 +156,19 @@ mod tests {
 
     #[test]
     fn variable_shadowing() {
-        let res = parse_eval::<i64>("let x = 1 in (let x = 2 in x) + x");
+        let res = do_eval::<i64>("let x = 1 in (let x = 2 in x) + x");
         assert_eq!(res, 3);
     }
 
     #[test]
     fn conditional_with_equals() {
-        let res = parse_eval::<i64>("let x = 1 in if x == 1 then 2 else 4");
+        let res = do_eval::<i64>("let x = 1 in if x == 1 then 2 else 4");
         assert_eq!(res, 2);
     }
 
     #[test]
     fn function_call() {
-        let res = parse_eval::<i64>("let id = fn x = x in id 1");
+        let res = do_eval::<i64>("let id = fn x = x in id 1");
         assert_eq!(res, 1);
     }
 }
diff --git a/src/interpreter/value.rs b/src/interpreter/value.rs
index 496e9c4230..5e55825160 100644
--- a/src/interpreter/value.rs
+++ b/src/interpreter/value.rs
@@ -6,13 +6,14 @@ use std::rc::Rc;
 use derive_more::{Deref, From, TryInto};
 
 use super::{Error, Result};
-use crate::ast::{Expr, FunctionType, Ident, Type};
+use crate::ast::hir::Expr;
+use crate::ast::{FunctionType, Ident, Type};
 
 #[derive(Debug, Clone)]
 pub struct Function<'a> {
     pub type_: FunctionType,
     pub args: Vec<Ident<'a>>,
-    pub body: Expr<'a>,
+    pub body: Expr<'a, Type>,
 }
 
 #[derive(From, TryInto)]
diff --git a/src/main.rs b/src/main.rs
index b539ebbb3d..d5b00d6b6c 100644
--- a/src/main.rs
+++ b/src/main.rs
@@ -8,6 +8,7 @@ pub mod compiler;
 pub mod interpreter;
 #[macro_use]
 pub mod parser;
+pub mod tc;
 
 pub use common::{Error, Result};
 
@@ -21,6 +22,7 @@ struct Opts {
 enum Command {
     Eval(commands::Eval),
     Compile(commands::Compile),
+    Check(commands::Check),
 }
 
 fn main() -> anyhow::Result<()> {
@@ -28,5 +30,6 @@ fn main() -> anyhow::Result<()> {
     match opts.subcommand {
         Command::Eval(eval) => Ok(eval.run()?),
         Command::Compile(compile) => Ok(compile.run()?),
+        Command::Check(check) => Ok(check.run()?),
     }
 }
diff --git a/src/parser/expr.rs b/src/parser/expr.rs
index 2fda3e93fa..73c873b5b3 100644
--- a/src/parser/expr.rs
+++ b/src/parser/expr.rs
@@ -156,7 +156,14 @@ where
 
 named!(int(&str) -> Literal, map!(flat_map!(digit1, parse_to!(u64)), Literal::Int));
 
-named!(literal(&str) -> Expr, map!(alt!(int), Expr::Literal));
+named!(bool_(&str) -> Literal, alt!(
+    tag!("true") => { |_| Literal::Bool(true) } |
+    tag!("false") => { |_| Literal::Bool(false) }
+));
+
+named!(literal(&str) -> Literal, alt!(int | bool_));
+
+named!(literal_expr(&str) -> Expr, map!(literal, Expr::Literal));
 
 named!(binding(&str) -> Binding, do_parse!(
     multispace0
@@ -262,7 +269,7 @@ named!(fun_expr(&str) -> Expr, do_parse!(
 
 named!(arg(&str) -> Expr, alt!(
     ident_expr |
-    literal |
+    literal_expr |
     paren_expr
 ));
 
@@ -280,7 +287,7 @@ named!(simple_expr_unascripted(&str) -> Expr, alt!(
     let_ |
     if_ |
     fun_expr |
-    literal |
+    literal_expr |
     ident_expr
 ));
 
@@ -400,6 +407,18 @@ pub(crate) mod tests {
     }
 
     #[test]
+    fn bools() {
+        assert_eq!(
+            test_parse!(expr, "true"),
+            Expr::Literal(Literal::Bool(true))
+        );
+        assert_eq!(
+            test_parse!(expr, "false"),
+            Expr::Literal(Literal::Bool(false))
+        );
+    }
+
+    #[test]
     fn let_complex() {
         let res = test_parse!(expr, "let x = 1; y = x * 7 in (x + y) * 4");
         assert_eq!(
diff --git a/src/parser/macros.rs b/src/parser/macros.rs
index 60db5133dc..406e5c0e69 100644
--- a/src/parser/macros.rs
+++ b/src/parser/macros.rs
@@ -1,3 +1,4 @@
+#[cfg(test)]
 #[macro_use]
 macro_rules! test_parse {
     ($parser: ident, $src: expr) => {{
diff --git a/src/parser/mod.rs b/src/parser/mod.rs
index af7dff6ff2..0251d02df4 100644
--- a/src/parser/mod.rs
+++ b/src/parser/mod.rs
@@ -14,7 +14,10 @@ pub use type_::type_;
 pub type Error = nom::Err<nom::error::Error<String>>;
 
 pub(crate) fn is_reserved(s: &str) -> bool {
-    matches!(s, "if" | "then" | "else" | "let" | "in" | "fn")
+    matches!(
+        s,
+        "if" | "then" | "else" | "let" | "in" | "fn" | "int" | "float" | "bool" | "true" | "false"
+    )
 }
 
 pub(crate) fn ident<'a, E>(i: &'a str) -> nom::IResult<&'a str, Ident, E>
diff --git a/src/tc/mod.rs b/src/tc/mod.rs
new file mode 100644
index 0000000000..b5acfac2b4
--- /dev/null
+++ b/src/tc/mod.rs
@@ -0,0 +1,528 @@
+use derive_more::From;
+use itertools::Itertools;
+use std::collections::HashMap;
+use std::convert::{TryFrom, TryInto};
+use std::fmt::{self, Display};
+use std::result;
+use thiserror::Error;
+
+use crate::ast::{self, hir, BinaryOperator, Ident, Literal};
+use crate::common::env::Env;
+
+#[derive(Debug, Error)]
+pub enum Error {
+    #[error("Undefined variable {0}")]
+    UndefinedVariable(Ident<'static>),
+
+    #[error("Mismatched types: expected {expected}, but got {actual}")]
+    TypeMismatch { expected: Type, actual: Type },
+
+    #[error("Mismatched types, expected numeric type, but got {0}")]
+    NonNumeric(Type),
+
+    #[error("Ambiguous type {0}")]
+    AmbiguousType(TyVar),
+}
+
+pub type Result<T> = result::Result<T, Error>;
+
+#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)]
+pub struct TyVar(u64);
+
+impl Display for TyVar {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        write!(f, "t{}", self.0)
+    }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone, Hash)]
+pub struct NullaryType(String);
+
+impl Display for NullaryType {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        f.write_str(&self.0)
+    }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+pub enum PrimType {
+    Int,
+    Float,
+    Bool,
+}
+
+impl From<PrimType> for ast::Type {
+    fn from(pr: PrimType) -> Self {
+        match pr {
+            PrimType::Int => ast::Type::Int,
+            PrimType::Float => ast::Type::Float,
+            PrimType::Bool => ast::Type::Bool,
+        }
+    }
+}
+
+impl Display for PrimType {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            PrimType::Int => f.write_str("int"),
+            PrimType::Float => f.write_str("float"),
+            PrimType::Bool => f.write_str("bool"),
+        }
+    }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone, From)]
+pub enum Type {
+    #[from(ignore)]
+    Univ(TyVar),
+    #[from(ignore)]
+    Exist(TyVar),
+    Nullary(NullaryType),
+    Prim(PrimType),
+    Fun {
+        args: Vec<Type>,
+        ret: Box<Type>,
+    },
+}
+
+impl PartialEq<ast::Type> for Type {
+    fn eq(&self, other: &ast::Type) -> bool {
+        match (self, other) {
+            (Type::Univ(_), _) => todo!(),
+            (Type::Exist(_), _) => false,
+            (Type::Nullary(_), _) => todo!(),
+            (Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty,
+            (Type::Fun { args, ret }, ast::Type::Function(ft)) => {
+                *args == ft.args && (**ret).eq(&*ft.ret)
+            }
+            (Type::Fun { .. }, _) => false,
+        }
+    }
+}
+
+impl TryFrom<Type> for ast::Type {
+    type Error = Type;
+
+    fn try_from(value: Type) -> result::Result<Self, Self::Error> {
+        match value {
+            Type::Univ(_) => todo!(),
+            Type::Exist(_) => Err(value),
+            Type::Nullary(_) => todo!(),
+            Type::Prim(p) => Ok(p.into()),
+            Type::Fun { ref args, ref ret } => Ok(ast::Type::Function(ast::FunctionType {
+                args: args
+                    .clone()
+                    .into_iter()
+                    .map(Self::try_from)
+                    .try_collect()
+                    .map_err(|_| value.clone())?,
+                ret: Box::new((*ret.clone()).try_into().map_err(|_| value.clone())?),
+            })),
+        }
+    }
+}
+
+const INT: Type = Type::Prim(PrimType::Int);
+const FLOAT: Type = Type::Prim(PrimType::Float);
+const BOOL: Type = Type::Prim(PrimType::Bool);
+
+impl Display for Type {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+        match self {
+            Type::Nullary(nt) => nt.fmt(f),
+            Type::Prim(p) => p.fmt(f),
+            Type::Univ(TyVar(n)) => write!(f, "∀{}", n),
+            Type::Exist(TyVar(n)) => write!(f, "∃{}", n),
+            Type::Fun { args, ret } => write!(f, "fn {} -> {}", args.iter().join(", "), ret),
+        }
+    }
+}
+
+impl From<ast::Type> for Type {
+    fn from(type_: ast::Type) -> Self {
+        match type_ {
+            ast::Type::Int => INT,
+            ast::Type::Float => FLOAT,
+            ast::Type::Bool => BOOL,
+            ast::Type::Function(ast::FunctionType { args, ret }) => Type::Fun {
+                args: args.into_iter().map(Self::from).collect(),
+                ret: Box::new(Self::from(*ret)),
+            },
+        }
+    }
+}
+
+struct Typechecker<'ast> {
+    ty_var_counter: u64,
+    ctx: HashMap<TyVar, Type>,
+    env: Env<Ident<'ast>, Type>,
+}
+
+impl<'ast> Typechecker<'ast> {
+    fn new() -> Self {
+        Self {
+            ty_var_counter: 0,
+            ctx: Default::default(),
+            env: Default::default(),
+        }
+    }
+
+    pub(crate) fn tc_expr(&mut self, expr: ast::Expr<'ast>) -> Result<hir::Expr<'ast, Type>> {
+        match expr {
+            ast::Expr::Ident(ident) => {
+                let type_ = self
+                    .env
+                    .resolve(&ident)
+                    .ok_or_else(|| Error::UndefinedVariable(ident.to_owned()))?
+                    .clone();
+                Ok(hir::Expr::Ident(ident, type_))
+            }
+            ast::Expr::Literal(lit) => {
+                let type_ = match lit {
+                    Literal::Int(_) => Type::Prim(PrimType::Int),
+                    Literal::Bool(_) => Type::Prim(PrimType::Bool),
+                };
+                Ok(hir::Expr::Literal(lit, type_))
+            }
+            ast::Expr::UnaryOp { op, rhs } => todo!(),
+            ast::Expr::BinaryOp { lhs, op, rhs } => {
+                let lhs = self.tc_expr(*lhs)?;
+                let rhs = self.tc_expr(*rhs)?;
+                let type_ = match op {
+                    BinaryOperator::Equ | BinaryOperator::Neq => {
+                        self.unify(lhs.type_(), rhs.type_())?;
+                        Type::Prim(PrimType::Bool)
+                    }
+                    BinaryOperator::Add | BinaryOperator::Sub | BinaryOperator::Mul => {
+                        let ty = self.unify(lhs.type_(), rhs.type_())?;
+                        // if !matches!(ty, Type::Int | Type::Float) {
+                        //     return Err(Error::NonNumeric(ty));
+                        // }
+                        ty
+                    }
+                    BinaryOperator::Div => todo!(),
+                    BinaryOperator::Pow => todo!(),
+                };
+                Ok(hir::Expr::BinaryOp {
+                    lhs: Box::new(lhs),
+                    op,
+                    rhs: Box::new(rhs),
+                    type_,
+                })
+            }
+            ast::Expr::Let { bindings, body } => {
+                self.env.push();
+                let bindings = bindings
+                    .into_iter()
+                    .map(
+                        |ast::Binding { ident, type_, body }| -> Result<hir::Binding<Type>> {
+                            let body = self.tc_expr(body)?;
+                            if let Some(type_) = type_ {
+                                self.unify(body.type_(), &type_.into())?;
+                            }
+                            self.env.set(ident.clone(), body.type_().clone());
+                            Ok(hir::Binding {
+                                ident,
+                                type_: body.type_().clone(),
+                                body,
+                            })
+                        },
+                    )
+                    .collect::<Result<Vec<hir::Binding<Type>>>>()?;
+                let body = self.tc_expr(*body)?;
+                self.env.pop();
+                Ok(hir::Expr::Let {
+                    bindings,
+                    type_: body.type_().clone(),
+                    body: Box::new(body),
+                })
+            }
+            ast::Expr::If {
+                condition,
+                then,
+                else_,
+            } => {
+                let condition = self.tc_expr(*condition)?;
+                self.unify(&Type::Prim(PrimType::Bool), condition.type_())?;
+                let then = self.tc_expr(*then)?;
+                let else_ = self.tc_expr(*else_)?;
+                let type_ = self.unify(then.type_(), else_.type_())?;
+                Ok(hir::Expr::If {
+                    condition: Box::new(condition),
+                    then: Box::new(then),
+                    else_: Box::new(else_),
+                    type_,
+                })
+            }
+            ast::Expr::Fun(f) => {
+                let ast::Fun { args, body } = *f;
+                self.env.push();
+                let args: Vec<_> = args
+                    .into_iter()
+                    .map(|id| {
+                        let ty = self.fresh_ex();
+                        self.env.set(id.clone(), ty.clone());
+                        (id, ty)
+                    })
+                    .collect();
+                let body = self.tc_expr(body)?;
+                self.env.pop();
+                Ok(hir::Expr::Fun {
+                    type_: Type::Fun {
+                        args: args.iter().map(|(_, ty)| ty.clone()).collect(),
+                        ret: Box::new(body.type_().clone()),
+                    },
+                    args,
+                    body: Box::new(body),
+                })
+            }
+            ast::Expr::Call { fun, args } => {
+                let ret_ty = self.fresh_ex();
+                let arg_tys = args.iter().map(|_| self.fresh_ex()).collect::<Vec<_>>();
+                let ft = Type::Fun {
+                    args: arg_tys.clone(),
+                    ret: Box::new(ret_ty.clone()),
+                };
+                let fun = self.tc_expr(*fun)?;
+                self.unify(&ft, fun.type_())?;
+                let args = args
+                    .into_iter()
+                    .zip(arg_tys)
+                    .map(|(arg, ty)| {
+                        let arg = self.tc_expr(arg)?;
+                        self.unify(&ty, arg.type_())?;
+                        Ok(arg)
+                    })
+                    .try_collect()?;
+                Ok(hir::Expr::Call {
+                    fun: Box::new(fun),
+                    args,
+                    type_: ret_ty,
+                })
+            }
+            ast::Expr::Ascription { expr, type_ } => {
+                let expr = self.tc_expr(*expr)?;
+                self.unify(expr.type_(), &type_.into())?;
+                Ok(expr)
+            }
+        }
+    }
+
+    pub(crate) fn tc_decl(&mut self, decl: ast::Decl<'ast>) -> Result<hir::Decl<'ast, Type>> {
+        match decl {
+            ast::Decl::Fun { name, body } => {
+                let body = self.tc_expr(ast::Expr::Fun(Box::new(body)))?;
+                let type_ = body.type_().clone();
+                self.env.set(name.clone(), type_);
+                match body {
+                    hir::Expr::Fun { args, body, type_ } => Ok(hir::Decl::Fun {
+                        name,
+                        args,
+                        body,
+                        type_,
+                    }),
+                    _ => unreachable!(),
+                }
+            }
+        }
+    }
+
+    fn fresh_tv(&mut self) -> TyVar {
+        self.ty_var_counter += 1;
+        TyVar(self.ty_var_counter)
+    }
+
+    fn fresh_ex(&mut self) -> Type {
+        Type::Exist(self.fresh_tv())
+    }
+
+    fn fresh_univ(&mut self) -> Type {
+        Type::Exist(self.fresh_tv())
+    }
+
+    fn universalize<'a>(&mut self, expr: hir::Expr<'a, Type>) -> hir::Expr<'a, Type> {
+        // TODO
+        expr
+    }
+
+    fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> {
+        match (ty1, ty2) {
+            (Type::Prim(p1), Type::Prim(p2)) if p1 == p2 => Ok(ty2.clone()),
+            (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) {
+                Some(existing_ty) if *ty == existing_ty => Ok(ty.clone()),
+                Some(existing_ty) => Err(Error::TypeMismatch {
+                    expected: ty.clone(),
+                    actual: existing_ty.into(),
+                }),
+                None => match self.ctx.insert(*tv, ty.clone()) {
+                    Some(existing) => self.unify(&existing, ty),
+                    None => Ok(ty.clone()),
+                },
+            },
+            (Type::Univ(u1), Type::Univ(u2)) if u1 == u2 => Ok(ty2.clone()),
+            (
+                Type::Fun {
+                    args: args1,
+                    ret: ret1,
+                },
+                Type::Fun {
+                    args: args2,
+                    ret: ret2,
+                },
+            ) => {
+                let args = args1
+                    .iter()
+                    .zip(args2)
+                    .map(|(t1, t2)| self.unify(t1, t2))
+                    .try_collect()?;
+                let ret = self.unify(ret1, ret2)?;
+                Ok(Type::Fun {
+                    args,
+                    ret: Box::new(ret),
+                })
+            }
+            (Type::Nullary(_), _) | (_, Type::Nullary(_)) => todo!(),
+            _ => Err(Error::TypeMismatch {
+                expected: ty1.clone(),
+                actual: ty2.clone(),
+            }),
+        }
+    }
+
+    fn finalize_expr(&self, expr: hir::Expr<'ast, Type>) -> Result<hir::Expr<'ast, ast::Type>> {
+        expr.traverse_type(|ty| self.finalize_type(ty))
+    }
+
+    fn finalize_decl(&self, decl: hir::Decl<'ast, Type>) -> Result<hir::Decl<'ast, ast::Type>> {
+        decl.traverse_type(|ty| self.finalize_type(ty))
+    }
+
+    fn finalize_type(&self, ty: Type) -> Result<ast::Type> {
+        match ty {
+            Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)),
+            Type::Univ(tv) => todo!(),
+            Type::Nullary(_) => todo!(),
+            Type::Prim(pr) => Ok(pr.into()),
+            Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType {
+                args: args
+                    .into_iter()
+                    .map(|ty| self.finalize_type(ty))
+                    .try_collect()?,
+                ret: Box::new(self.finalize_type(*ret)?),
+            })),
+        }
+    }
+
+    fn resolve_tv(&self, tv: TyVar) -> Option<ast::Type> {
+        let mut res = &Type::Exist(tv);
+        loop {
+            match res {
+                Type::Exist(tv) => {
+                    res = self.ctx.get(tv)?;
+                }
+                Type::Univ(_) => todo!(),
+                Type::Nullary(_) => todo!(),
+                Type::Prim(pr) => break Some((*pr).into()),
+                Type::Fun { args, ret } => todo!(),
+            }
+        }
+    }
+}
+
+pub fn typecheck_expr(expr: ast::Expr) -> Result<hir::Expr<ast::Type>> {
+    let mut typechecker = Typechecker::new();
+    let typechecked = typechecker.tc_expr(expr)?;
+    typechecker.finalize_expr(typechecked)
+}
+
+pub fn typecheck_toplevel(decls: Vec<ast::Decl>) -> Result<Vec<hir::Decl<ast::Type>>> {
+    let mut typechecker = Typechecker::new();
+    decls
+        .into_iter()
+        .map(|decl| {
+            let decl = typechecker.tc_decl(decl)?;
+            typechecker.finalize_decl(decl)
+        })
+        .try_collect()
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    macro_rules! assert_type {
+        ($expr: expr, $type: expr) => {
+            use crate::parser::{expr, type_};
+            let parsed_expr = test_parse!(expr, $expr);
+            let parsed_type = test_parse!(type_, $type);
+            let res = typecheck_expr(parsed_expr).unwrap_or_else(|e| panic!("{}", e));
+            assert_eq!(res.type_(), &parsed_type);
+        };
+    }
+
+    macro_rules! assert_type_error {
+        ($expr: expr) => {
+            use crate::parser::expr;
+            let parsed_expr = test_parse!(expr, $expr);
+            let res = typecheck_expr(parsed_expr);
+            assert!(
+                res.is_err(),
+                "Expected type error, but got type: {}",
+                res.unwrap().type_()
+            );
+        };
+    }
+
+    #[test]
+    fn literal_int() {
+        assert_type!("1", "int");
+    }
+
+    #[test]
+    fn conditional() {
+        assert_type!("if 1 == 2 then 3 else 4", "int");
+    }
+
+    #[test]
+    #[ignore]
+    fn add_bools() {
+        assert_type_error!("true + false");
+    }
+
+    #[test]
+    fn call_generic_function() {
+        assert_type!("(fn x = x) 1", "int");
+    }
+
+    #[test]
+    #[ignore]
+    fn generic_function() {
+        assert_type!("fn x = x", "fn x, y -> x");
+    }
+
+    #[test]
+    #[ignore]
+    fn let_generalization() {
+        assert_type!("let id = fn x = x in if id true then id 1 else 2", "int");
+    }
+
+    #[test]
+    fn concrete_function() {
+        assert_type!("fn x = x + 1", "fn int -> int");
+    }
+
+    #[test]
+    fn call_concrete_function() {
+        assert_type!("(fn x = x + 1) 2", "int");
+    }
+
+    #[test]
+    fn conditional_non_bool() {
+        assert_type_error!("if 3 then true else false");
+    }
+
+    #[test]
+    fn let_int() {
+        assert_type!("let x = 1 in x", "int");
+    }
+}