diff options
Diffstat (limited to 'users/grfn/achilles/src/tc/mod.rs')
-rw-r--r-- | users/grfn/achilles/src/tc/mod.rs | 91 |
1 files changed, 77 insertions, 14 deletions
diff --git a/users/grfn/achilles/src/tc/mod.rs b/users/grfn/achilles/src/tc/mod.rs index d27c45075e97..5825bab1fbe9 100644 --- a/users/grfn/achilles/src/tc/mod.rs +++ b/users/grfn/achilles/src/tc/mod.rs @@ -8,7 +8,7 @@ use std::fmt::{self, Display}; use std::{mem, result}; use thiserror::Error; -use crate::ast::{self, hir, Arg, BinaryOperator, Ident, Literal}; +use crate::ast::{self, hir, Arg, BinaryOperator, Ident, Literal, Pattern}; use crate::common::env::Env; use crate::common::{Namer, NamerOf}; @@ -85,6 +85,7 @@ pub enum Type { Exist(TyVar), Nullary(NullaryType), Prim(PrimType), + Tuple(Vec<Type>), Unit, Fun { args: Vec<Type>, @@ -102,6 +103,9 @@ impl<'a> TryFrom<Type> for ast::Type<'a> { Type::Exist(_) => Err(value), Type::Nullary(_) => todo!(), Type::Prim(p) => Ok(p.into()), + Type::Tuple(members) => Ok(ast::Type::Tuple( + members.into_iter().map(|ty| ty.try_into()).try_collect()?, + )), Type::Fun { ref args, ref ret } => Ok(ast::Type::Function(ast::FunctionType { args: args .clone() @@ -128,6 +132,7 @@ impl Display for Type { Type::Univ(TyVar(n)) => write!(f, "∀{}", n), Type::Exist(TyVar(n)) => write!(f, "∃{}", n), Type::Fun { args, ret } => write!(f, "fn {} -> {}", args.iter().join(", "), ret), + Type::Tuple(members) => write!(f, "({})", members.iter().join(", ")), Type::Unit => write!(f, "()"), } } @@ -159,6 +164,31 @@ impl<'ast> Typechecker<'ast> { } } + fn bind_pattern( + &mut self, + pat: Pattern<'ast>, + type_: Type, + ) -> Result<hir::Pattern<'ast, Type>> { + match pat { + Pattern::Id(ident) => { + self.env.set(ident.clone(), type_.clone()); + Ok(hir::Pattern::Id(ident, type_)) + } + Pattern::Tuple(members) => { + let mut tys = Vec::with_capacity(members.len()); + let mut hir_members = Vec::with_capacity(members.len()); + for pat in members { + let ty = self.fresh_ex(); + hir_members.push(self.bind_pattern(pat, ty.clone())?); + tys.push(ty); + } + let tuple_type = Type::Tuple(tys); + self.unify(&tuple_type, &type_)?; + Ok(hir::Pattern::Tuple(hir_members)) + } + } + } + pub(crate) fn tc_expr(&mut self, expr: ast::Expr<'ast>) -> Result<hir::Expr<'ast, Type>> { match expr { ast::Expr::Ident(ident) => { @@ -178,6 +208,14 @@ impl<'ast> Typechecker<'ast> { }; Ok(hir::Expr::Literal(lit.to_owned(), type_)) } + ast::Expr::Tuple(members) => { + let members = members + .into_iter() + .map(|expr| self.tc_expr(expr)) + .collect::<Result<Vec<_>>>()?; + let type_ = Type::Tuple(members.iter().map(|expr| expr.type_().clone()).collect()); + Ok(hir::Expr::Tuple(members, type_)) + } ast::Expr::UnaryOp { op, rhs } => todo!(), ast::Expr::BinaryOp { lhs, op, rhs } => { let lhs = self.tc_expr(*lhs)?; @@ -209,18 +247,14 @@ impl<'ast> Typechecker<'ast> { let bindings = bindings .into_iter() .map( - |ast::Binding { ident, type_, body }| -> Result<hir::Binding<Type>> { + |ast::Binding { pat, type_, body }| -> Result<hir::Binding<Type>> { let body = self.tc_expr(body)?; if let Some(type_) = type_ { let type_ = self.type_from_ast_type(type_); self.unify(body.type_(), &type_)?; } - self.env.set(ident.clone(), body.type_().clone()); - Ok(hir::Binding { - ident, - type_: body.type_().clone(), - body, - }) + let pat = self.bind_pattern(pat, body.type_().clone())?; + Ok(hir::Binding { pat, body }) }, ) .collect::<Result<Vec<hir::Binding<Type>>>>()?; @@ -382,7 +416,7 @@ impl<'ast> Typechecker<'ast> { fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> { match (ty1, ty2) { (Type::Unit, Type::Unit) => Ok(Type::Unit), - (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) { + (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv)? { Some(existing_ty) if self.types_match(ty, &existing_ty) => Ok(ty.clone()), Some(var @ ast::Type::Var(_)) => { let var = self.type_from_ast_type(var); @@ -419,6 +453,14 @@ impl<'ast> Typechecker<'ast> { } } (Type::Prim(p1), Type::Prim(p2)) if p1 == p2 => Ok(ty2.clone()), + (Type::Tuple(t1), Type::Tuple(t2)) if t1.len() == t2.len() => { + let ts = t1 + .iter() + .zip(t2.iter()) + .map(|(t1, t2)| self.unify(t1, t2)) + .try_collect()?; + Ok(Type::Tuple(ts)) + } ( Type::Fun { args: args1, @@ -469,11 +511,17 @@ impl<'ast> Typechecker<'ast> { fn finalize_type(&self, ty: Type) -> Result<ast::Type<'static>> { let ret = match ty { - Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)), + Type::Exist(tv) => self.resolve_tv(tv)?.ok_or(Error::AmbiguousType(tv)), Type::Univ(tv) => Ok(ast::Type::Var(self.name_univ(tv))), Type::Unit => Ok(ast::Type::Unit), Type::Nullary(_) => todo!(), Type::Prim(pr) => Ok(pr.into()), + Type::Tuple(members) => Ok(ast::Type::Tuple( + members + .into_iter() + .map(|ty| self.finalize_type(ty)) + .try_collect()?, + )), Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType { args: args .into_iter() @@ -485,12 +533,15 @@ impl<'ast> Typechecker<'ast> { ret } - fn resolve_tv(&self, tv: TyVar) -> Option<ast::Type<'static>> { + fn resolve_tv(&self, tv: TyVar) -> Result<Option<ast::Type<'static>>> { let mut res = &Type::Exist(tv); - loop { + Ok(loop { match res { Type::Exist(tv) => { - res = self.ctx.get(tv)?; + res = match self.ctx.get(tv) { + Some(r) => r, + None => return Ok(None), + }; } Type::Univ(tv) => { let ident = self.name_univ(*tv); @@ -504,8 +555,9 @@ impl<'ast> Typechecker<'ast> { Type::Prim(pr) => break Some((*pr).into()), Type::Unit => break Some(ast::Type::Unit), Type::Fun { args, ret } => todo!(), + Type::Tuple(_) => break Some(self.finalize_type(res.clone())?), } - } + }) } fn type_from_ast_type(&mut self, ast_type: ast::Type<'ast>) -> Type { @@ -515,6 +567,12 @@ impl<'ast> Typechecker<'ast> { ast::Type::Float => FLOAT, ast::Type::Bool => BOOL, ast::Type::CString => CSTRING, + ast::Type::Tuple(members) => Type::Tuple( + members + .into_iter() + .map(|ty| self.type_from_ast_type(ty)) + .collect(), + ), ast::Type::Function(ast::FunctionType { args, ret }) => Type::Fun { args: args .into_iter() @@ -582,6 +640,11 @@ impl<'ast> Typechecker<'ast> { (Type::Unit, _) => false, (Type::Nullary(_), _) => todo!(), (Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty, + (Type::Tuple(members), ast::Type::Tuple(members2)) => members + .iter() + .zip(members2.iter()) + .all(|(t1, t2)| self.types_match(t1, t2)), + (Type::Tuple(members), _) => false, (Type::Fun { args, ret }, ast::Type::Function(ft)) => { args.len() == ft.args.len() && args |