about summary refs log tree commit diff
path: root/users/grfn/achilles/src/tc/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'users/grfn/achilles/src/tc/mod.rs')
-rw-r--r--users/grfn/achilles/src/tc/mod.rs91
1 files changed, 77 insertions, 14 deletions
diff --git a/users/grfn/achilles/src/tc/mod.rs b/users/grfn/achilles/src/tc/mod.rs
index d27c45075e97..5825bab1fbe9 100644
--- a/users/grfn/achilles/src/tc/mod.rs
+++ b/users/grfn/achilles/src/tc/mod.rs
@@ -8,7 +8,7 @@ use std::fmt::{self, Display};
 use std::{mem, result};
 use thiserror::Error;
 
-use crate::ast::{self, hir, Arg, BinaryOperator, Ident, Literal};
+use crate::ast::{self, hir, Arg, BinaryOperator, Ident, Literal, Pattern};
 use crate::common::env::Env;
 use crate::common::{Namer, NamerOf};
 
@@ -85,6 +85,7 @@ pub enum Type {
     Exist(TyVar),
     Nullary(NullaryType),
     Prim(PrimType),
+    Tuple(Vec<Type>),
     Unit,
     Fun {
         args: Vec<Type>,
@@ -102,6 +103,9 @@ impl<'a> TryFrom<Type> for ast::Type<'a> {
             Type::Exist(_) => Err(value),
             Type::Nullary(_) => todo!(),
             Type::Prim(p) => Ok(p.into()),
+            Type::Tuple(members) => Ok(ast::Type::Tuple(
+                members.into_iter().map(|ty| ty.try_into()).try_collect()?,
+            )),
             Type::Fun { ref args, ref ret } => Ok(ast::Type::Function(ast::FunctionType {
                 args: args
                     .clone()
@@ -128,6 +132,7 @@ impl Display for Type {
             Type::Univ(TyVar(n)) => write!(f, "∀{}", n),
             Type::Exist(TyVar(n)) => write!(f, "∃{}", n),
             Type::Fun { args, ret } => write!(f, "fn {} -> {}", args.iter().join(", "), ret),
+            Type::Tuple(members) => write!(f, "({})", members.iter().join(", ")),
             Type::Unit => write!(f, "()"),
         }
     }
@@ -159,6 +164,31 @@ impl<'ast> Typechecker<'ast> {
         }
     }
 
+    fn bind_pattern(
+        &mut self,
+        pat: Pattern<'ast>,
+        type_: Type,
+    ) -> Result<hir::Pattern<'ast, Type>> {
+        match pat {
+            Pattern::Id(ident) => {
+                self.env.set(ident.clone(), type_.clone());
+                Ok(hir::Pattern::Id(ident, type_))
+            }
+            Pattern::Tuple(members) => {
+                let mut tys = Vec::with_capacity(members.len());
+                let mut hir_members = Vec::with_capacity(members.len());
+                for pat in members {
+                    let ty = self.fresh_ex();
+                    hir_members.push(self.bind_pattern(pat, ty.clone())?);
+                    tys.push(ty);
+                }
+                let tuple_type = Type::Tuple(tys);
+                self.unify(&tuple_type, &type_)?;
+                Ok(hir::Pattern::Tuple(hir_members))
+            }
+        }
+    }
+
     pub(crate) fn tc_expr(&mut self, expr: ast::Expr<'ast>) -> Result<hir::Expr<'ast, Type>> {
         match expr {
             ast::Expr::Ident(ident) => {
@@ -178,6 +208,14 @@ impl<'ast> Typechecker<'ast> {
                 };
                 Ok(hir::Expr::Literal(lit.to_owned(), type_))
             }
+            ast::Expr::Tuple(members) => {
+                let members = members
+                    .into_iter()
+                    .map(|expr| self.tc_expr(expr))
+                    .collect::<Result<Vec<_>>>()?;
+                let type_ = Type::Tuple(members.iter().map(|expr| expr.type_().clone()).collect());
+                Ok(hir::Expr::Tuple(members, type_))
+            }
             ast::Expr::UnaryOp { op, rhs } => todo!(),
             ast::Expr::BinaryOp { lhs, op, rhs } => {
                 let lhs = self.tc_expr(*lhs)?;
@@ -209,18 +247,14 @@ impl<'ast> Typechecker<'ast> {
                 let bindings = bindings
                     .into_iter()
                     .map(
-                        |ast::Binding { ident, type_, body }| -> Result<hir::Binding<Type>> {
+                        |ast::Binding { pat, type_, body }| -> Result<hir::Binding<Type>> {
                             let body = self.tc_expr(body)?;
                             if let Some(type_) = type_ {
                                 let type_ = self.type_from_ast_type(type_);
                                 self.unify(body.type_(), &type_)?;
                             }
-                            self.env.set(ident.clone(), body.type_().clone());
-                            Ok(hir::Binding {
-                                ident,
-                                type_: body.type_().clone(),
-                                body,
-                            })
+                            let pat = self.bind_pattern(pat, body.type_().clone())?;
+                            Ok(hir::Binding { pat, body })
                         },
                     )
                     .collect::<Result<Vec<hir::Binding<Type>>>>()?;
@@ -382,7 +416,7 @@ impl<'ast> Typechecker<'ast> {
     fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> {
         match (ty1, ty2) {
             (Type::Unit, Type::Unit) => Ok(Type::Unit),
-            (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) {
+            (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv)? {
                 Some(existing_ty) if self.types_match(ty, &existing_ty) => Ok(ty.clone()),
                 Some(var @ ast::Type::Var(_)) => {
                     let var = self.type_from_ast_type(var);
@@ -419,6 +453,14 @@ impl<'ast> Typechecker<'ast> {
                 }
             }
             (Type::Prim(p1), Type::Prim(p2)) if p1 == p2 => Ok(ty2.clone()),
+            (Type::Tuple(t1), Type::Tuple(t2)) if t1.len() == t2.len() => {
+                let ts = t1
+                    .iter()
+                    .zip(t2.iter())
+                    .map(|(t1, t2)| self.unify(t1, t2))
+                    .try_collect()?;
+                Ok(Type::Tuple(ts))
+            }
             (
                 Type::Fun {
                     args: args1,
@@ -469,11 +511,17 @@ impl<'ast> Typechecker<'ast> {
 
     fn finalize_type(&self, ty: Type) -> Result<ast::Type<'static>> {
         let ret = match ty {
-            Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)),
+            Type::Exist(tv) => self.resolve_tv(tv)?.ok_or(Error::AmbiguousType(tv)),
             Type::Univ(tv) => Ok(ast::Type::Var(self.name_univ(tv))),
             Type::Unit => Ok(ast::Type::Unit),
             Type::Nullary(_) => todo!(),
             Type::Prim(pr) => Ok(pr.into()),
+            Type::Tuple(members) => Ok(ast::Type::Tuple(
+                members
+                    .into_iter()
+                    .map(|ty| self.finalize_type(ty))
+                    .try_collect()?,
+            )),
             Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType {
                 args: args
                     .into_iter()
@@ -485,12 +533,15 @@ impl<'ast> Typechecker<'ast> {
         ret
     }
 
-    fn resolve_tv(&self, tv: TyVar) -> Option<ast::Type<'static>> {
+    fn resolve_tv(&self, tv: TyVar) -> Result<Option<ast::Type<'static>>> {
         let mut res = &Type::Exist(tv);
-        loop {
+        Ok(loop {
             match res {
                 Type::Exist(tv) => {
-                    res = self.ctx.get(tv)?;
+                    res = match self.ctx.get(tv) {
+                        Some(r) => r,
+                        None => return Ok(None),
+                    };
                 }
                 Type::Univ(tv) => {
                     let ident = self.name_univ(*tv);
@@ -504,8 +555,9 @@ impl<'ast> Typechecker<'ast> {
                 Type::Prim(pr) => break Some((*pr).into()),
                 Type::Unit => break Some(ast::Type::Unit),
                 Type::Fun { args, ret } => todo!(),
+                Type::Tuple(_) => break Some(self.finalize_type(res.clone())?),
             }
-        }
+        })
     }
 
     fn type_from_ast_type(&mut self, ast_type: ast::Type<'ast>) -> Type {
@@ -515,6 +567,12 @@ impl<'ast> Typechecker<'ast> {
             ast::Type::Float => FLOAT,
             ast::Type::Bool => BOOL,
             ast::Type::CString => CSTRING,
+            ast::Type::Tuple(members) => Type::Tuple(
+                members
+                    .into_iter()
+                    .map(|ty| self.type_from_ast_type(ty))
+                    .collect(),
+            ),
             ast::Type::Function(ast::FunctionType { args, ret }) => Type::Fun {
                 args: args
                     .into_iter()
@@ -582,6 +640,11 @@ impl<'ast> Typechecker<'ast> {
             (Type::Unit, _) => false,
             (Type::Nullary(_), _) => todo!(),
             (Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty,
+            (Type::Tuple(members), ast::Type::Tuple(members2)) => members
+                .iter()
+                .zip(members2.iter())
+                .all(|(t1, t2)| self.types_match(t1, t2)),
+            (Type::Tuple(members), _) => false,
             (Type::Fun { args, ret }, ast::Type::Function(ft)) => {
                 args.len() == ft.args.len()
                     && args