diff options
Diffstat (limited to 'users/glittershark/achilles/src')
-rw-r--r-- | users/glittershark/achilles/src/ast/hir.rs | 7 | ||||
-rw-r--r-- | users/glittershark/achilles/src/tc/mod.rs | 90 |
2 files changed, 67 insertions, 30 deletions
diff --git a/users/glittershark/achilles/src/ast/hir.rs b/users/glittershark/achilles/src/ast/hir.rs index 0212b3dbcdbb..691b9607e7e6 100644 --- a/users/glittershark/achilles/src/ast/hir.rs +++ b/users/glittershark/achilles/src/ast/hir.rs @@ -228,6 +228,13 @@ pub enum Decl<'a, T> { } impl<'a, T> Decl<'a, T> { + pub fn name(&self) -> &Ident<'a> { + match self { + Decl::Fun { name, .. } => name, + Decl::Extern { name, .. } => name, + } + } + pub fn type_(&self) -> Option<&T> { match self { Decl::Fun { type_, .. } => Some(type_), diff --git a/users/glittershark/achilles/src/tc/mod.rs b/users/glittershark/achilles/src/tc/mod.rs index 4bca52733bc8..f3cb40a8b010 100644 --- a/users/glittershark/achilles/src/tc/mod.rs +++ b/users/glittershark/achilles/src/tc/mod.rs @@ -262,10 +262,10 @@ impl<'ast> Typechecker<'ast> { let body = self.tc_expr(body)?; self.env.pop(); Ok(hir::Expr::Fun { - type_: self.universalize( - args.iter().map(|(_, ty)| ty.clone()).collect(), - body.type_().clone(), - ), + type_: Type::Fun { + args: args.iter().map(|(_, ty)| ty.clone()).collect(), + ret: Box::new(body.type_().clone()), + }, args, body: Box::new(body), }) @@ -319,9 +319,11 @@ impl<'ast> Typechecker<'ast> { }; } + self.env.push(); let body = self.tc_expr(expr)?; let type_ = body.type_().clone(); self.env.set(name.clone(), type_); + self.env.pop(); match body { hir::Expr::Fun { args, body, type_ } => Ok(Some(hir::Decl::Fun { name, @@ -365,27 +367,6 @@ impl<'ast> Typechecker<'ast> { Type::Univ(self.fresh_tv()) } - #[allow(clippy::redundant_closure)] // https://github.com/rust-lang/rust-clippy/issues/6903 - fn universalize(&mut self, args: Vec<Type>, ret: Type) -> Type { - let mut vars = HashMap::new(); - let mut universalize_type = move |ty| match ty { - Type::Exist(tv) if self.resolve_tv(tv).is_none() => vars - .entry(tv) - .or_insert_with(|| { - let ty = self.fresh_univ(); - self.ctx.insert(tv, ty.clone()); - ty - }) - .clone(), - _ => ty, - }; - - Type::Fun { - args: args.into_iter().map(|t| universalize_type(t)).collect(), - ret: Box::new(universalize_type(ret)), - } - } - fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> { match (ty1, ty2) { (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) { @@ -462,10 +443,15 @@ impl<'ast> Typechecker<'ast> { } fn finalize_decl( - &self, + &mut self, decl: hir::Decl<'ast, Type>, ) -> Result<hir::Decl<'ast, ast::Type<'ast>>> { - decl.traverse_type(|ty| self.finalize_type(ty)) + let res = decl.traverse_type(|ty| self.finalize_type(ty))?; + if let Some(type_) = res.type_() { + let ty = self.type_from_ast_type(type_.clone()); + self.env.set(res.name().clone(), ty); + } + Ok(res) } fn finalize_type(&self, ty: Type) -> Result<ast::Type<'static>> { @@ -541,7 +527,12 @@ impl<'ast> Typechecker<'ast> { .get_by_right(&tv) .map(Ident::to_owned) .unwrap_or_else(|| { - let name = vars.1.make_name(); + let name = loop { + let name = vars.1.make_name(); + if !vars.0.contains_left(&name) { + break name; + } + }; vars.0.insert_no_overwrite(name.clone(), tv).unwrap(); name }) @@ -619,6 +610,26 @@ mod tests { $type ); }; + + (toplevel($program: expr), $($decl: ident => $type: expr),+ $(,)?) => {{ + use crate::parser::{toplevel, type_}; + let program = test_parse!(toplevel, $program); + let res = typecheck_toplevel(program).unwrap_or_else(|e| panic!("{}", e)); + $( + let parsed_type = test_parse!(type_, $type); + let ident = Ident::try_from(::std::stringify!($decl)).unwrap(); + let decl = res.iter().find(|decl| { + matches!(decl, crate::ast::hir::Decl::Fun { name, .. } if name == &ident) + }).unwrap_or_else(|| panic!("Could not find declaration for {}", ident)); + assert!( + decl.type_().unwrap().alpha_equiv(&parsed_type), + "inferred type {} for {}, but expected {}", + decl.type_().unwrap(), + ident, + $type + ); + )+ + }}; } macro_rules! assert_type_error { @@ -656,8 +667,27 @@ mod tests { } #[test] - fn generic_function() { - assert_type!("fn x = x", "fn x -> x"); + fn call_let_bound_generic() { + assert_type!("let id = fn x = x in id 1", "int"); + } + + #[test] + fn universal_ascripted_let() { + assert_type!("let id: fn a -> a = fn x = x in id 1", "int"); + } + + #[test] + fn call_generic_function_toplevel() { + assert_type!( + toplevel( + "ty id : fn a -> a + fn id x = x + + fn main = id 0" + ), + main => "fn -> int", + id => "fn a -> a", + ); } #[test] |