diff options
author | Griffin Smith <root@gws.fyi> | 2021-03-14T20·43-0400 |
---|---|---|
committer | Griffin Smith <root@gws.fyi> | 2021-03-14T20·43-0400 |
commit | ecb4c0f803e9b408e4fd21c475769eb4dc649d14 (patch) | |
tree | 80390b00a6009cea21fbb68cbf56e6a193b478a2 /src/tc | |
parent | 7960c3270e1a338f4da40d044a6896df96d82c79 (diff) |
Universally quantified type variables
Implement universally quantified type variables, both explicitly given by the user and inferred by the type inference algorithm.
Diffstat (limited to 'src/tc')
-rw-r--r-- | src/tc/mod.rs | 274 |
1 file changed, 202 insertions, 72 deletions
diff --git a/src/tc/mod.rs b/src/tc/mod.rs index 2c40a02bf7c6..4c088c885749 100644 --- a/src/tc/mod.rs +++ b/src/tc/mod.rs @@ -1,13 +1,16 @@ +use bimap::BiMap; use derive_more::From; use itertools::Itertools; +use std::cell::RefCell; use std::collections::HashMap; use std::convert::{TryFrom, TryInto}; use std::fmt::{self, Display}; -use std::result; +use std::{mem, result}; use thiserror::Error; -use crate::ast::{self, hir, BinaryOperator, Ident, Literal}; +use crate::ast::{self, hir, Arg, BinaryOperator, Ident, Literal}; use crate::common::env::Env; +use crate::common::{Namer, NamerOf}; #[derive(Debug, Error)] pub enum Error { @@ -52,7 +55,7 @@ pub enum PrimType { CString, } -impl From<PrimType> for ast::Type { +impl<'a> From<PrimType> for ast::Type<'a> { fn from(pr: PrimType) -> Self { match pr { PrimType::Int => ast::Type::Int, @@ -88,22 +91,7 @@ pub enum Type { }, } -impl PartialEq<ast::Type> for Type { - fn eq(&self, other: &ast::Type) -> bool { - match (self, other) { - (Type::Univ(_), _) => todo!(), - (Type::Exist(_), _) => false, - (Type::Nullary(_), _) => todo!(), - (Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty, - (Type::Fun { args, ret }, ast::Type::Function(ft)) => { - *args == ft.args && (**ret).eq(&*ft.ret) - } - (Type::Fun { .. 
}, _) => false, - } - } -} - -impl TryFrom<Type> for ast::Type { +impl<'a> TryFrom<Type> for ast::Type<'a> { type Error = Type; fn try_from(value: Type) -> result::Result<Self, Self::Error> { @@ -142,33 +130,29 @@ impl Display for Type { } } -impl From<ast::Type> for Type { - fn from(type_: ast::Type) -> Self { - match type_ { - ast::Type::Int => INT, - ast::Type::Float => FLOAT, - ast::Type::Bool => BOOL, - ast::Type::CString => CSTRING, - ast::Type::Function(ast::FunctionType { args, ret }) => Type::Fun { - args: args.into_iter().map(Self::from).collect(), - ret: Box::new(Self::from(*ret)), - }, - } - } -} - struct Typechecker<'ast> { - ty_var_counter: u64, + ty_var_namer: NamerOf<TyVar>, ctx: HashMap<TyVar, Type>, env: Env<Ident<'ast>, Type>, + + /// AST type var -> type + instantiations: Env<Ident<'ast>, Type>, + + /// AST type-var -> universal TyVar + type_vars: RefCell<(BiMap<Ident<'ast>, TyVar>, NamerOf<Ident<'static>>)>, } impl<'ast> Typechecker<'ast> { fn new() -> Self { Self { - ty_var_counter: 0, + ty_var_namer: Namer::new(TyVar).boxed(), + type_vars: RefCell::new(( + Default::default(), + Namer::alphabetic().map(|n| Ident::try_from(n).unwrap()), + )), ctx: Default::default(), env: Default::default(), + instantiations: Default::default(), } } @@ -224,7 +208,8 @@ impl<'ast> Typechecker<'ast> { |ast::Binding { ident, type_, body }| -> Result<hir::Binding<Type>> { let body = self.tc_expr(body)?; if let Some(type_) = type_ { - self.unify(body.type_(), &type_.into())?; + let type_ = self.type_from_ast_type(type_); + self.unify(body.type_(), &type_)?; } self.env.set(ident.clone(), body.type_().clone()); Ok(hir::Binding { @@ -265,19 +250,22 @@ impl<'ast> Typechecker<'ast> { self.env.push(); let args: Vec<_> = args .into_iter() - .map(|id| { - let ty = self.fresh_ex(); - self.env.set(id.clone(), ty.clone()); - (id, ty) + .map(|Arg { ident, type_ }| { + let ty = match type_ { + Some(t) => self.type_from_ast_type(t), + None => self.fresh_ex(), + }; + 
self.env.set(ident.clone(), ty.clone()); + (ident, ty) }) .collect(); let body = self.tc_expr(body)?; self.env.pop(); Ok(hir::Expr::Fun { - type_: Type::Fun { - args: args.iter().map(|(_, ty)| ty.clone()).collect(), - ret: Box::new(body.type_().clone()), - }, + type_: self.universalize( + args.iter().map(|(_, ty)| ty.clone()).collect(), + body.type_().clone(), + ), args, body: Box::new(body), }) @@ -290,6 +278,7 @@ impl<'ast> Typechecker<'ast> { ret: Box::new(ret_ty.clone()), }; let fun = self.tc_expr(*fun)?; + self.instantiations.push(); self.unify(&ft, fun.type_())?; let args = args .into_iter() @@ -300,6 +289,7 @@ impl<'ast> Typechecker<'ast> { Ok(arg) }) .try_collect()?; + self.commit_instantiations(); Ok(hir::Expr::Call { fun: Box::new(fun), args, @@ -308,7 +298,8 @@ impl<'ast> Typechecker<'ast> { } ast::Expr::Ascription { expr, type_ } => { let expr = self.tc_expr(*expr)?; - self.unify(expr.type_(), &type_.into())?; + let type_ = self.type_from_ast_type(type_); + self.unify(expr.type_(), &type_)?; Ok(expr) } } @@ -334,8 +325,7 @@ impl<'ast> Typechecker<'ast> { } fn fresh_tv(&mut self) -> TyVar { - self.ty_var_counter += 1; - TyVar(self.ty_var_counter) + self.ty_var_namer.make_name() } fn fresh_ex(&mut self) -> Type { @@ -343,29 +333,69 @@ impl<'ast> Typechecker<'ast> { } fn fresh_univ(&mut self) -> Type { - Type::Exist(self.fresh_tv()) - } + Type::Univ(self.fresh_tv()) + } + + #[allow(clippy::redundant_closure)] // https://github.com/rust-lang/rust-clippy/issues/6903 + fn universalize(&mut self, args: Vec<Type>, ret: Type) -> Type { + let mut vars = HashMap::new(); + let mut universalize_type = move |ty| match ty { + Type::Exist(tv) if self.resolve_tv(tv).is_none() => vars + .entry(tv) + .or_insert_with_key(|tv| { + let ty = self.fresh_univ(); + self.ctx.insert(*tv, ty.clone()); + ty + }) + .clone(), + _ => ty, + }; - fn universalize<'a>(&mut self, expr: hir::Expr<'a, Type>) -> hir::Expr<'a, Type> { - // TODO - expr + Type::Fun { + args: 
args.into_iter().map(|t| universalize_type(t)).collect(), + ret: Box::new(universalize_type(ret)), + } } fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> { match (ty1, ty2) { - (Type::Prim(p1), Type::Prim(p2)) if p1 == p2 => Ok(ty2.clone()), (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) { - Some(existing_ty) if *ty == existing_ty => Ok(ty.clone()), - Some(existing_ty) => Err(Error::TypeMismatch { - expected: ty.clone(), - actual: existing_ty.into(), - }), + Some(existing_ty) if self.types_match(ty, &existing_ty) => Ok(ty.clone()), + Some(var @ ast::Type::Var(_)) => { + let var = self.type_from_ast_type(var); + self.unify(&var, ty) + } + Some(existing_ty) => match ty { + Type::Exist(_) => { + let rhs = self.type_from_ast_type(existing_ty); + self.unify(ty, &rhs) + } + _ => Err(Error::TypeMismatch { + expected: ty.clone(), + actual: self.type_from_ast_type(existing_ty), + }), + }, None => match self.ctx.insert(*tv, ty.clone()) { Some(existing) => self.unify(&existing, ty), None => Ok(ty.clone()), }, }, (Type::Univ(u1), Type::Univ(u2)) if u1 == u2 => Ok(ty2.clone()), + (Type::Univ(u), ty) | (ty, Type::Univ(u)) => { + let ident = self.name_univ(*u); + match self.instantiations.resolve(&ident) { + Some(existing_ty) if ty == existing_ty => Ok(ty.clone()), + Some(existing_ty) => Err(Error::TypeMismatch { + expected: ty.clone(), + actual: existing_ty.clone(), + }), + None => { + self.instantiations.set(ident, ty.clone()); + Ok(ty.clone()) + } + } + } + (Type::Prim(p1), Type::Prim(p2)) if p1 == p2 => Ok(ty2.clone()), ( Type::Fun { args: args1, @@ -395,18 +425,24 @@ impl<'ast> Typechecker<'ast> { } } - fn finalize_expr(&self, expr: hir::Expr<'ast, Type>) -> Result<hir::Expr<'ast, ast::Type>> { + fn finalize_expr( + &self, + expr: hir::Expr<'ast, Type>, + ) -> Result<hir::Expr<'ast, ast::Type<'ast>>> { expr.traverse_type(|ty| self.finalize_type(ty)) } - fn finalize_decl(&self, decl: hir::Decl<'ast, Type>) -> Result<hir::Decl<'ast, 
ast::Type>> { + fn finalize_decl( + &self, + decl: hir::Decl<'ast, Type>, + ) -> Result<hir::Decl<'ast, ast::Type<'ast>>> { decl.traverse_type(|ty| self.finalize_type(ty)) } - fn finalize_type(&self, ty: Type) -> Result<ast::Type> { - match ty { + fn finalize_type(&self, ty: Type) -> Result<ast::Type<'static>> { + let ret = match ty { Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)), - Type::Univ(tv) => todo!(), + Type::Univ(tv) => Ok(ast::Type::Var(self.name_univ(tv))), Type::Nullary(_) => todo!(), Type::Prim(pr) => Ok(pr.into()), Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType { @@ -416,23 +452,105 @@ impl<'ast> Typechecker<'ast> { .try_collect()?, ret: Box::new(self.finalize_type(*ret)?), })), - } + }; + ret } - fn resolve_tv(&self, tv: TyVar) -> Option<ast::Type> { + fn resolve_tv(&self, tv: TyVar) -> Option<ast::Type<'static>> { let mut res = &Type::Exist(tv); loop { match res { Type::Exist(tv) => { res = self.ctx.get(tv)?; } - Type::Univ(_) => todo!(), + Type::Univ(tv) => { + let ident = self.name_univ(*tv); + if let Some(r) = self.instantiations.resolve(&ident) { + res = r; + } else { + break Some(ast::Type::Var(ident)); + } + } Type::Nullary(_) => todo!(), Type::Prim(pr) => break Some((*pr).into()), Type::Fun { args, ret } => todo!(), } } } + + fn type_from_ast_type(&mut self, ast_type: ast::Type<'ast>) -> Type { + match ast_type { + ast::Type::Int => INT, + ast::Type::Float => FLOAT, + ast::Type::Bool => BOOL, + ast::Type::CString => CSTRING, + ast::Type::Function(ast::FunctionType { args, ret }) => Type::Fun { + args: args + .into_iter() + .map(|t| self.type_from_ast_type(t)) + .collect(), + ret: Box::new(self.type_from_ast_type(*ret)), + }, + ast::Type::Var(id) => Type::Univ({ + let opt_tv = { self.type_vars.borrow_mut().0.get_by_left(&id).copied() }; + opt_tv.unwrap_or_else(|| { + let tv = self.fresh_tv(); + self.type_vars + .borrow_mut() + .0 + .insert_no_overwrite(id, tv) + .unwrap(); + tv + }) + }), + } + } + 
+ fn name_univ(&self, tv: TyVar) -> Ident<'static> { + let mut vars = self.type_vars.borrow_mut(); + vars.0 + .get_by_right(&tv) + .map(Ident::to_owned) + .unwrap_or_else(|| { + let name = vars.1.make_name(); + vars.0.insert_no_overwrite(name.clone(), tv).unwrap(); + name + }) + } + + fn commit_instantiations(&mut self) { + let mut ctx = mem::take(&mut self.ctx); + for (_, v) in ctx.iter_mut() { + if let Type::Univ(tv) = v { + if let Some(concrete) = self.instantiations.resolve(&self.name_univ(*tv)) { + *v = concrete.clone(); + } + } + } + self.ctx = ctx; + self.instantiations.pop(); + } + + fn types_match(&self, type_: &Type, ast_type: &ast::Type<'ast>) -> bool { + match (type_, ast_type) { + (Type::Univ(u), ast::Type::Var(v)) => { + Some(u) == self.type_vars.borrow().0.get_by_left(v) + } + (Type::Univ(_), _) => false, + (Type::Exist(_), _) => false, + (Type::Nullary(_), _) => todo!(), + (Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty, + (Type::Fun { args, ret }, ast::Type::Function(ft)) => { + args.len() == ft.args.len() + && args + .iter() + .zip(&ft.args) + .all(|(a1, a2)| self.types_match(a1, &a2)) + && self.types_match(&*ret, &*ft.ret) + } + (Type::Fun { .. 
}, _) => false, + } + } } pub fn typecheck_expr(expr: ast::Expr) -> Result<hir::Expr<ast::Type>> { @@ -446,8 +564,10 @@ pub fn typecheck_toplevel(decls: Vec<ast::Decl>) -> Result<Vec<hir::Decl<ast::Ty decls .into_iter() .map(|decl| { - let decl = typechecker.tc_decl(decl)?; - typechecker.finalize_decl(decl) + let hir_decl = typechecker.tc_decl(decl)?; + let res = typechecker.finalize_decl(hir_decl)?; + typechecker.ctx.clear(); + Ok(res) }) .try_collect() } @@ -462,7 +582,13 @@ mod tests { let parsed_expr = test_parse!(expr, $expr); let parsed_type = test_parse!(type_, $type); let res = typecheck_expr(parsed_expr).unwrap_or_else(|e| panic!("{}", e)); - assert_eq!(res.type_(), &parsed_type); + assert!( + res.type_().alpha_equiv(&parsed_type), + "{} inferred type {}, but expected {}", + $expr, + res.type_(), + $type + ); }; } @@ -501,9 +627,8 @@ mod tests { } #[test] - #[ignore] fn generic_function() { - assert_type!("fn x = x", "fn x, y -> x"); + assert_type!("fn x = x", "fn x -> x"); } #[test] @@ -518,6 +643,11 @@ mod tests { } #[test] + fn arg_ascriptions() { + assert_type!("fn (x: int) = x", "fn int -> int"); + } + + #[test] fn call_concrete_function() { assert_type!("(fn x = x + 1) 2", "int"); } |