diff options
Diffstat (limited to 'users/glittershark/achilles/src/tc')
-rw-r--r-- | users/glittershark/achilles/src/tc/mod.rs | 745 |
1 files changed, 0 insertions, 745 deletions
diff --git a/users/glittershark/achilles/src/tc/mod.rs b/users/glittershark/achilles/src/tc/mod.rs deleted file mode 100644 index d27c45075e97..000000000000 --- a/users/glittershark/achilles/src/tc/mod.rs +++ /dev/null @@ -1,745 +0,0 @@ -use bimap::BiMap; -use derive_more::From; -use itertools::Itertools; -use std::cell::RefCell; -use std::collections::HashMap; -use std::convert::{TryFrom, TryInto}; -use std::fmt::{self, Display}; -use std::{mem, result}; -use thiserror::Error; - -use crate::ast::{self, hir, Arg, BinaryOperator, Ident, Literal}; -use crate::common::env::Env; -use crate::common::{Namer, NamerOf}; - -#[derive(Debug, Error)] -pub enum Error { - #[error("Undefined variable {0}")] - UndefinedVariable(Ident<'static>), - - #[error("Mismatched types: expected {expected}, but got {actual}")] - TypeMismatch { expected: Type, actual: Type }, - - #[error("Mismatched types, expected numeric type, but got {0}")] - NonNumeric(Type), - - #[error("Ambiguous type {0}")] - AmbiguousType(TyVar), -} - -pub type Result<T> = result::Result<T, Error>; - -#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] -pub struct TyVar(u64); - -impl Display for TyVar { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "t{}", self.0) - } -} - -#[derive(Debug, PartialEq, Eq, Clone, Hash)] -pub struct NullaryType(String); - -impl Display for NullaryType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(&self.0) - } -} - -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum PrimType { - Int, - Float, - Bool, - CString, -} - -impl<'a> From<PrimType> for ast::Type<'a> { - fn from(pr: PrimType) -> Self { - match pr { - PrimType::Int => ast::Type::Int, - PrimType::Float => ast::Type::Float, - PrimType::Bool => ast::Type::Bool, - PrimType::CString => ast::Type::CString, - } - } -} - -impl Display for PrimType { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - PrimType::Int => f.write_str("int"), - PrimType::Float => f.write_str("float"), - PrimType::Bool => f.write_str("bool"), - PrimType::CString => f.write_str("cstring"), - } - } -} - -#[derive(Debug, PartialEq, Eq, Clone, From)] -pub enum Type { - #[from(ignore)] - Univ(TyVar), - #[from(ignore)] - Exist(TyVar), - Nullary(NullaryType), - Prim(PrimType), - Unit, - Fun { - args: Vec<Type>, - ret: Box<Type>, - }, -} - -impl<'a> TryFrom<Type> for ast::Type<'a> { - type Error = Type; - - fn try_from(value: Type) -> result::Result<Self, Self::Error> { - match value { - Type::Unit => Ok(ast::Type::Unit), - Type::Univ(_) => todo!(), - Type::Exist(_) => Err(value), - Type::Nullary(_) => todo!(), - Type::Prim(p) => Ok(p.into()), - Type::Fun { ref args, ref ret } => Ok(ast::Type::Function(ast::FunctionType { - args: args - .clone() - .into_iter() - .map(Self::try_from) - .try_collect() - .map_err(|_| value.clone())?, - ret: Box::new((*ret.clone()).try_into().map_err(|_| value.clone())?), - })), - } - } -} - -const INT: Type = Type::Prim(PrimType::Int); -const FLOAT: Type = Type::Prim(PrimType::Float); -const BOOL: Type = Type::Prim(PrimType::Bool); -const CSTRING: Type = Type::Prim(PrimType::CString); - -impl Display for Type { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Type::Nullary(nt) => nt.fmt(f), - Type::Prim(p) => p.fmt(f), - Type::Univ(TyVar(n)) => write!(f, "∀{}", n), - Type::Exist(TyVar(n)) => write!(f, "∃{}", n), - Type::Fun { args, ret } => write!(f, "fn {} -> {}", args.iter().join(", "), ret), - Type::Unit => write!(f, "()"), - } - } -} - -struct Typechecker<'ast> { - ty_var_namer: NamerOf<TyVar>, - ctx: HashMap<TyVar, Type>, - env: Env<Ident<'ast>, Type>, - - /// AST type var -> type - instantiations: Env<Ident<'ast>, Type>, - - /// AST type-var -> universal TyVar - type_vars: RefCell<(BiMap<Ident<'ast>, TyVar>, NamerOf<Ident<'static>>)>, -} - -impl<'ast> Typechecker<'ast> { - fn new() -> Self { - Self { - ty_var_namer: Namer::new(TyVar).boxed(), - type_vars: RefCell::new(( - Default::default(), - Namer::alphabetic().map(|n| Ident::try_from(n).unwrap()), - )), - ctx: Default::default(), - env: Default::default(), - instantiations: Default::default(), - } - } - - pub(crate) fn tc_expr(&mut self, expr: ast::Expr<'ast>) -> Result<hir::Expr<'ast, Type>> { - match expr { - ast::Expr::Ident(ident) => { - let type_ = self - .env - .resolve(&ident) - .ok_or_else(|| Error::UndefinedVariable(ident.to_owned()))? - .clone(); - Ok(hir::Expr::Ident(ident, type_)) - } - ast::Expr::Literal(lit) => { - let type_ = match lit { - Literal::Int(_) => Type::Prim(PrimType::Int), - Literal::Bool(_) => Type::Prim(PrimType::Bool), - Literal::String(_) => Type::Prim(PrimType::CString), - Literal::Unit => Type::Unit, - }; - Ok(hir::Expr::Literal(lit.to_owned(), type_)) - } - ast::Expr::UnaryOp { op, rhs } => todo!(), - ast::Expr::BinaryOp { lhs, op, rhs } => { - let lhs = self.tc_expr(*lhs)?; - let rhs = self.tc_expr(*rhs)?; - let type_ = match op { - BinaryOperator::Equ | BinaryOperator::Neq => { - self.unify(lhs.type_(), rhs.type_())?; - Type::Prim(PrimType::Bool) - } - BinaryOperator::Add | BinaryOperator::Sub | BinaryOperator::Mul => { - let ty = self.unify(lhs.type_(), rhs.type_())?; - // if !matches!(ty, Type::Int | Type::Float) { - // return Err(Error::NonNumeric(ty)); - // } - ty - } - BinaryOperator::Div => todo!(), - BinaryOperator::Pow => todo!(), - }; - Ok(hir::Expr::BinaryOp { - lhs: Box::new(lhs), - op, - rhs: Box::new(rhs), - type_, - }) - } - ast::Expr::Let { bindings, body } => { - self.env.push(); - let bindings = bindings - .into_iter() - .map( - |ast::Binding { ident, type_, body }| -> Result<hir::Binding<Type>> { - let body = self.tc_expr(body)?; - if let Some(type_) = type_ { - let type_ = self.type_from_ast_type(type_); - self.unify(body.type_(), &type_)?; - } - self.env.set(ident.clone(), body.type_().clone()); - Ok(hir::Binding { - ident, - type_: body.type_().clone(), - body, - }) - }, - ) - .collect::<Result<Vec<hir::Binding<Type>>>>()?; - let body = self.tc_expr(*body)?; - self.env.pop(); - Ok(hir::Expr::Let { - bindings, - type_: body.type_().clone(), - body: Box::new(body), - }) - } - ast::Expr::If { - condition, - then, - else_, - } => { - let condition = self.tc_expr(*condition)?; - self.unify(&Type::Prim(PrimType::Bool), condition.type_())?; - let then = self.tc_expr(*then)?; - let else_ = self.tc_expr(*else_)?; - let type_ = self.unify(then.type_(), else_.type_())?; - Ok(hir::Expr::If { - condition: Box::new(condition), - then: Box::new(then), - else_: Box::new(else_), - type_, - }) - } - ast::Expr::Fun(f) => { - let ast::Fun { args, body } = *f; - self.env.push(); - let args: Vec<_> = args - .into_iter() - .map(|Arg { ident, type_ }| { - let ty = match type_ { - Some(t) => self.type_from_ast_type(t), - None => self.fresh_ex(), - }; - self.env.set(ident.clone(), ty.clone()); - (ident, ty) - }) - .collect(); - let body = self.tc_expr(body)?; - self.env.pop(); - Ok(hir::Expr::Fun { - type_: Type::Fun { - args: args.iter().map(|(_, ty)| ty.clone()).collect(), - ret: Box::new(body.type_().clone()), - }, - type_args: vec![], // TODO fill in once we do let generalization - args, - body: Box::new(body), - }) - } - ast::Expr::Call { fun, args } => { - let ret_ty = self.fresh_ex(); - let arg_tys = args.iter().map(|_| self.fresh_ex()).collect::<Vec<_>>(); - let ft = Type::Fun { - args: arg_tys.clone(), - ret: Box::new(ret_ty.clone()), - }; - let fun = self.tc_expr(*fun)?; - self.instantiations.push(); - self.unify(&ft, fun.type_())?; - let args = args - .into_iter() - .zip(arg_tys) - .map(|(arg, ty)| { - let arg = self.tc_expr(arg)?; - self.unify(&ty, arg.type_())?; - Ok(arg) - }) - .try_collect()?; - let type_args = self.commit_instantiations(); - Ok(hir::Expr::Call { - fun: Box::new(fun), - type_args, - args, - type_: ret_ty, - }) - } - ast::Expr::Ascription { expr, type_ } => { - let expr = self.tc_expr(*expr)?; - let type_ = self.type_from_ast_type(type_); - self.unify(expr.type_(), &type_)?; - Ok(expr) - } - } - } - - pub(crate) fn tc_decl( - &mut self, - decl: ast::Decl<'ast>, - ) -> Result<Option<hir::Decl<'ast, Type>>> { - match decl { - ast::Decl::Fun { name, body } => { - let mut expr = ast::Expr::Fun(Box::new(body)); - if let Some(type_) = self.env.resolve(&name) { - expr = ast::Expr::Ascription { - expr: Box::new(expr), - type_: self.finalize_type(type_.clone())?, - }; - } - - self.env.push(); - let body = self.tc_expr(expr)?; - let type_ = body.type_().clone(); - self.env.set(name.clone(), type_); - self.env.pop(); - match body { - hir::Expr::Fun { - type_args, - args, - body, - type_, - } => Ok(Some(hir::Decl::Fun { - name, - type_args, - args, - body, - type_, - })), - _ => unreachable!(), - } - } - ast::Decl::Ascription { name, type_ } => { - let type_ = self.type_from_ast_type(type_); - self.env.set(name.clone(), type_); - Ok(None) - } - ast::Decl::Extern { name, type_ } => { - let type_ = self.type_from_ast_type(ast::Type::Function(type_)); - self.env.set(name.clone(), type_.clone()); - let (arg_types, ret_type) = match type_ { - Type::Fun { args, ret } => (args, *ret), - _ => unreachable!(), - }; - Ok(Some(hir::Decl::Extern { - name, - arg_types, - ret_type, - })) - } - } - } - - fn fresh_tv(&mut self) -> TyVar { - self.ty_var_namer.make_name() - } - - fn fresh_ex(&mut self) -> Type { - Type::Exist(self.fresh_tv()) - } - - fn fresh_univ(&mut self) -> Type { - Type::Univ(self.fresh_tv()) - } - - fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> { - match (ty1, ty2) { - (Type::Unit, Type::Unit) => Ok(Type::Unit), - (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) { - Some(existing_ty) if self.types_match(ty, &existing_ty) => Ok(ty.clone()), - Some(var @ ast::Type::Var(_)) => { - let var = self.type_from_ast_type(var); - self.unify(&var, ty) - } - Some(existing_ty) => match ty { - Type::Exist(_) => { - let rhs = self.type_from_ast_type(existing_ty); - self.unify(ty, &rhs) - } - _ => Err(Error::TypeMismatch { - expected: ty.clone(), - actual: self.type_from_ast_type(existing_ty), - }), - }, - None => match self.ctx.insert(*tv, ty.clone()) { - Some(existing) => self.unify(&existing, ty), - None => Ok(ty.clone()), - }, - }, - (Type::Univ(u1), Type::Univ(u2)) if u1 == u2 => Ok(ty2.clone()), - (Type::Univ(u), ty) | (ty, Type::Univ(u)) => { - let ident = self.name_univ(*u); - match self.instantiations.resolve(&ident) { - Some(existing_ty) if ty == existing_ty => Ok(ty.clone()), - Some(existing_ty) => Err(Error::TypeMismatch { - expected: ty.clone(), - actual: existing_ty.clone(), - }), - None => { - self.instantiations.set(ident, ty.clone()); - Ok(ty.clone()) - } - } - } - (Type::Prim(p1), Type::Prim(p2)) if p1 == p2 => Ok(ty2.clone()), - ( - Type::Fun { - args: args1, - ret: ret1, - }, - Type::Fun { - args: args2, - ret: ret2, - }, - ) => { - let args = args1 - .iter() - .zip(args2) - .map(|(t1, t2)| self.unify(t1, t2)) - .try_collect()?; - let ret = self.unify(ret1, ret2)?; - Ok(Type::Fun { - args, - ret: Box::new(ret), - }) - } - (Type::Nullary(_), _) | (_, Type::Nullary(_)) => todo!(), - _ => Err(Error::TypeMismatch { - expected: ty1.clone(), - actual: ty2.clone(), - }), - } - } - - fn finalize_expr( - &self, - expr: hir::Expr<'ast, Type>, - ) -> Result<hir::Expr<'ast, ast::Type<'ast>>> { - expr.traverse_type(|ty| self.finalize_type(ty)) - } - - fn finalize_decl( - &mut self, - decl: hir::Decl<'ast, Type>, - ) -> Result<hir::Decl<'ast, ast::Type<'ast>>> { - let res = decl.traverse_type(|ty| self.finalize_type(ty))?; - if let Some(type_) = res.type_() { - let ty = self.type_from_ast_type(type_.clone()); - self.env.set(res.name().clone(), ty); - } - Ok(res) - } - - fn finalize_type(&self, ty: Type) -> Result<ast::Type<'static>> { - let ret = match ty { - Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)), - Type::Univ(tv) => Ok(ast::Type::Var(self.name_univ(tv))), - Type::Unit => Ok(ast::Type::Unit), - Type::Nullary(_) => todo!(), - Type::Prim(pr) => Ok(pr.into()), - Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType { - args: args - .into_iter() - .map(|ty| self.finalize_type(ty)) - .try_collect()?, - ret: Box::new(self.finalize_type(*ret)?), - })), - }; - ret - } - - fn resolve_tv(&self, tv: TyVar) -> Option<ast::Type<'static>> { - let mut res = &Type::Exist(tv); - loop { - match res { - Type::Exist(tv) => { - res = self.ctx.get(tv)?; - } - Type::Univ(tv) => { - let ident = self.name_univ(*tv); - if let Some(r) = self.instantiations.resolve(&ident) { - res = r; - } else { - break Some(ast::Type::Var(ident)); - } - } - Type::Nullary(_) => todo!(), - Type::Prim(pr) => break Some((*pr).into()), - Type::Unit => break Some(ast::Type::Unit), - Type::Fun { args, ret } => todo!(), - } - } - } - - fn type_from_ast_type(&mut self, ast_type: ast::Type<'ast>) -> Type { - match ast_type { - ast::Type::Unit => Type::Unit, - ast::Type::Int => INT, - ast::Type::Float => FLOAT, - ast::Type::Bool => BOOL, - ast::Type::CString => CSTRING, - ast::Type::Function(ast::FunctionType { args, ret }) => Type::Fun { - args: args - .into_iter() - .map(|t| self.type_from_ast_type(t)) - .collect(), - ret: Box::new(self.type_from_ast_type(*ret)), - }, - ast::Type::Var(id) => Type::Univ({ - let opt_tv = { self.type_vars.borrow_mut().0.get_by_left(&id).copied() }; - opt_tv.unwrap_or_else(|| { - let tv = self.fresh_tv(); - self.type_vars - .borrow_mut() - .0 - .insert_no_overwrite(id, tv) - .unwrap(); - tv - }) - }), - } - } - - fn name_univ(&self, tv: TyVar) -> Ident<'static> { - let mut vars = self.type_vars.borrow_mut(); - vars.0 - .get_by_right(&tv) - .map(Ident::to_owned) - .unwrap_or_else(|| { - let name = loop { - let name = vars.1.make_name(); - if !vars.0.contains_left(&name) { - break name; - } - }; - vars.0.insert_no_overwrite(name.clone(), tv).unwrap(); - name - }) - } - - fn commit_instantiations(&mut self) -> HashMap<Ident<'ast>, Type> { - let mut res = HashMap::new(); - let mut ctx = mem::take(&mut self.ctx); - for (_, v) in ctx.iter_mut() { - if let Type::Univ(tv) = v { - let tv_name = self.name_univ(*tv); - if let Some(concrete) = self.instantiations.resolve(&tv_name) { - res.insert(tv_name, concrete.clone()); - *v = concrete.clone(); - } - } - } - self.ctx = ctx; - self.instantiations.pop(); - res - } - - fn types_match(&self, type_: &Type, ast_type: &ast::Type<'ast>) -> bool { - match (type_, ast_type) { - (Type::Univ(u), ast::Type::Var(v)) => { - Some(u) == self.type_vars.borrow().0.get_by_left(v) - } - (Type::Univ(_), _) => false, - (Type::Exist(_), _) => false, - (Type::Unit, ast::Type::Unit) => true, - (Type::Unit, _) => false, - (Type::Nullary(_), _) => todo!(), - (Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty, - (Type::Fun { args, ret }, ast::Type::Function(ft)) => { - args.len() == ft.args.len() - && args - .iter() - .zip(&ft.args) - .all(|(a1, a2)| self.types_match(a1, &a2)) - && self.types_match(&*ret, &*ft.ret) - } - (Type::Fun { .. }, _) => false, - } - } -} - -pub fn typecheck_expr(expr: ast::Expr) -> Result<hir::Expr<ast::Type>> { - let mut typechecker = Typechecker::new(); - let typechecked = typechecker.tc_expr(expr)?; - typechecker.finalize_expr(typechecked) -} - -pub fn typecheck_toplevel(decls: Vec<ast::Decl>) -> Result<Vec<hir::Decl<ast::Type>>> { - let mut typechecker = Typechecker::new(); - let mut res = Vec::with_capacity(decls.len()); - for decl in decls { - if let Some(hir_decl) = typechecker.tc_decl(decl)? { - let hir_decl = typechecker.finalize_decl(hir_decl)?; - res.push(hir_decl); - } - typechecker.ctx.clear(); - } - Ok(res) -} - -#[cfg(test)] -mod tests { - use super::*; - - macro_rules! assert_type { - ($expr: expr, $type: expr) => { - use crate::parser::{expr, type_}; - let parsed_expr = test_parse!(expr, $expr); - let parsed_type = test_parse!(type_, $type); - let res = typecheck_expr(parsed_expr).unwrap_or_else(|e| panic!("{}", e)); - assert!( - res.type_().alpha_equiv(&parsed_type), - "{} inferred type {}, but expected {}", - $expr, - res.type_(), - $type - ); - }; - - (toplevel($program: expr), $($decl: ident => $type: expr),+ $(,)?) => {{ - use crate::parser::{toplevel, type_}; - let program = test_parse!(toplevel, $program); - let res = typecheck_toplevel(program).unwrap_or_else(|e| panic!("{}", e)); - $( - let parsed_type = test_parse!(type_, $type); - let ident = Ident::try_from(::std::stringify!($decl)).unwrap(); - let decl = res.iter().find(|decl| { - matches!(decl, crate::ast::hir::Decl::Fun { name, .. } if name == &ident) - }).unwrap_or_else(|| panic!("Could not find declaration for {}", ident)); - assert!( - decl.type_().unwrap().alpha_equiv(&parsed_type), - "inferred type {} for {}, but expected {}", - decl.type_().unwrap(), - ident, - $type - ); - )+ - }}; - } - - macro_rules! assert_type_error { - ($expr: expr) => { - use crate::parser::expr; - let parsed_expr = test_parse!(expr, $expr); - let res = typecheck_expr(parsed_expr); - assert!( - res.is_err(), - "Expected type error, but got type: {}", - res.unwrap().type_() - ); - }; - } - - #[test] - fn literal_int() { - assert_type!("1", "int"); - } - - #[test] - fn conditional() { - assert_type!("if 1 == 2 then 3 else 4", "int"); - } - - #[test] - #[ignore] - fn add_bools() { - assert_type_error!("true + false"); - } - - #[test] - fn call_generic_function() { - assert_type!("(fn x = x) 1", "int"); - } - - #[test] - fn call_let_bound_generic() { - assert_type!("let id = fn x = x in id 1", "int"); - } - - #[test] - fn universal_ascripted_let() { - assert_type!("let id: fn a -> a = fn x = x in id 1", "int"); - } - - #[test] - fn call_generic_function_toplevel() { - assert_type!( - toplevel( - "ty id : fn a -> a - fn id x = x - - fn main = id 0" - ), - main => "fn -> int", - id => "fn a -> a", - ); - } - - #[test] - #[ignore] - fn let_generalization() { - assert_type!("let id = fn x = x in if id true then id 1 else 2", "int"); - } - - #[test] - fn concrete_function() { - assert_type!("fn x = x + 1", "fn int -> int"); - } - - #[test] - fn arg_ascriptions() { - assert_type!("fn (x: int) = x", "fn int -> int"); - } - - #[test] - fn call_concrete_function() { - assert_type!("(fn x = x + 1) 2", "int"); - } - - #[test] - fn conditional_non_bool() { - assert_type_error!("if 3 then true else false"); - } - - #[test] - fn let_int() { - assert_type!("let x = 1 in x", "int"); - } -} |