about summary refs log tree commit diff
path: root/users/grfn/achilles/src/tc/mod.rs
diff options
context:
space:
mode:
authorGriffin Smith <grfn@gws.fyi>2021-04-17T06·28+0200
committergrfn <grfn@gws.fyi>2021-04-17T06·33+0000
commit48098f83c1e3943d1fb76aaecdce3b4f56cf4d4a (patch)
treec7a09a380bf66b4b20a46ab60890e2d8c7a2de67 /users/grfn/achilles/src/tc/mod.rs
parente1c45be3f58f5ae139439994b8d72f7d58c5c895 (diff)
feat(grfn/achilles): Implement tuples, and tuple patterns r/2522
Implement tuple expressions, types, and patterns, all the way through
the parser down to the typechecker. In LLVM, these are implemented as
anonymous structs, using an `extract` instruction when they're pattern
matched on to get out the individual fields.

Currently the only limitation here is patterns aren't supported in
function argument position, but you can still do something like

    fn xy = let (x, y) = xy in x + y

Change-Id: I357f17e9d4052e741eda8605b6662822f331efde
Reviewed-on: https://cl.tvl.fyi/c/depot/+/3027
Reviewed-by: grfn <grfn@gws.fyi>
Tested-by: BuildkiteCI
Diffstat (limited to '')
-rw-r--r--users/grfn/achilles/src/tc/mod.rs91
1 files changed, 77 insertions, 14 deletions
diff --git a/users/grfn/achilles/src/tc/mod.rs b/users/grfn/achilles/src/tc/mod.rs
index d27c45075e..5825bab1fb 100644
--- a/users/grfn/achilles/src/tc/mod.rs
+++ b/users/grfn/achilles/src/tc/mod.rs
@@ -8,7 +8,7 @@ use std::fmt::{self, Display};
 use std::{mem, result};
 use thiserror::Error;
 
-use crate::ast::{self, hir, Arg, BinaryOperator, Ident, Literal};
+use crate::ast::{self, hir, Arg, BinaryOperator, Ident, Literal, Pattern};
 use crate::common::env::Env;
 use crate::common::{Namer, NamerOf};
 
@@ -85,6 +85,7 @@ pub enum Type {
     Exist(TyVar),
     Nullary(NullaryType),
     Prim(PrimType),
+    Tuple(Vec<Type>),
     Unit,
     Fun {
         args: Vec<Type>,
@@ -102,6 +103,9 @@ impl<'a> TryFrom<Type> for ast::Type<'a> {
             Type::Exist(_) => Err(value),
             Type::Nullary(_) => todo!(),
             Type::Prim(p) => Ok(p.into()),
+            Type::Tuple(members) => Ok(ast::Type::Tuple(
+                members.into_iter().map(|ty| ty.try_into()).try_collect()?,
+            )),
             Type::Fun { ref args, ref ret } => Ok(ast::Type::Function(ast::FunctionType {
                 args: args
                     .clone()
@@ -128,6 +132,7 @@ impl Display for Type {
             Type::Univ(TyVar(n)) => write!(f, "∀{}", n),
             Type::Exist(TyVar(n)) => write!(f, "∃{}", n),
             Type::Fun { args, ret } => write!(f, "fn {} -> {}", args.iter().join(", "), ret),
+            Type::Tuple(members) => write!(f, "({})", members.iter().join(", ")),
             Type::Unit => write!(f, "()"),
         }
     }
@@ -159,6 +164,31 @@ impl<'ast> Typechecker<'ast> {
         }
     }
 
+    fn bind_pattern(
+        &mut self,
+        pat: Pattern<'ast>,
+        type_: Type,
+    ) -> Result<hir::Pattern<'ast, Type>> {
+        match pat {
+            Pattern::Id(ident) => {
+                self.env.set(ident.clone(), type_.clone());
+                Ok(hir::Pattern::Id(ident, type_))
+            }
+            Pattern::Tuple(members) => {
+                let mut tys = Vec::with_capacity(members.len());
+                let mut hir_members = Vec::with_capacity(members.len());
+                for pat in members {
+                    let ty = self.fresh_ex();
+                    hir_members.push(self.bind_pattern(pat, ty.clone())?);
+                    tys.push(ty);
+                }
+                let tuple_type = Type::Tuple(tys);
+                self.unify(&tuple_type, &type_)?;
+                Ok(hir::Pattern::Tuple(hir_members))
+            }
+        }
+    }
+
     pub(crate) fn tc_expr(&mut self, expr: ast::Expr<'ast>) -> Result<hir::Expr<'ast, Type>> {
         match expr {
             ast::Expr::Ident(ident) => {
@@ -178,6 +208,14 @@ impl<'ast> Typechecker<'ast> {
                 };
                 Ok(hir::Expr::Literal(lit.to_owned(), type_))
             }
+            ast::Expr::Tuple(members) => {
+                let members = members
+                    .into_iter()
+                    .map(|expr| self.tc_expr(expr))
+                    .collect::<Result<Vec<_>>>()?;
+                let type_ = Type::Tuple(members.iter().map(|expr| expr.type_().clone()).collect());
+                Ok(hir::Expr::Tuple(members, type_))
+            }
             ast::Expr::UnaryOp { op, rhs } => todo!(),
             ast::Expr::BinaryOp { lhs, op, rhs } => {
                 let lhs = self.tc_expr(*lhs)?;
@@ -209,18 +247,14 @@ impl<'ast> Typechecker<'ast> {
                 let bindings = bindings
                     .into_iter()
                     .map(
-                        |ast::Binding { ident, type_, body }| -> Result<hir::Binding<Type>> {
+                        |ast::Binding { pat, type_, body }| -> Result<hir::Binding<Type>> {
                             let body = self.tc_expr(body)?;
                             if let Some(type_) = type_ {
                                 let type_ = self.type_from_ast_type(type_);
                                 self.unify(body.type_(), &type_)?;
                             }
-                            self.env.set(ident.clone(), body.type_().clone());
-                            Ok(hir::Binding {
-                                ident,
-                                type_: body.type_().clone(),
-                                body,
-                            })
+                            let pat = self.bind_pattern(pat, body.type_().clone())?;
+                            Ok(hir::Binding { pat, body })
                         },
                     )
                     .collect::<Result<Vec<hir::Binding<Type>>>>()?;
@@ -382,7 +416,7 @@ impl<'ast> Typechecker<'ast> {
     fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> {
         match (ty1, ty2) {
             (Type::Unit, Type::Unit) => Ok(Type::Unit),
-            (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) {
+            (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv)? {
                 Some(existing_ty) if self.types_match(ty, &existing_ty) => Ok(ty.clone()),
                 Some(var @ ast::Type::Var(_)) => {
                     let var = self.type_from_ast_type(var);
@@ -419,6 +453,14 @@ impl<'ast> Typechecker<'ast> {
                 }
             }
             (Type::Prim(p1), Type::Prim(p2)) if p1 == p2 => Ok(ty2.clone()),
+            (Type::Tuple(t1), Type::Tuple(t2)) if t1.len() == t2.len() => {
+                let ts = t1
+                    .iter()
+                    .zip(t2.iter())
+                    .map(|(t1, t2)| self.unify(t1, t2))
+                    .try_collect()?;
+                Ok(Type::Tuple(ts))
+            }
             (
                 Type::Fun {
                     args: args1,
@@ -469,11 +511,17 @@ impl<'ast> Typechecker<'ast> {
 
     fn finalize_type(&self, ty: Type) -> Result<ast::Type<'static>> {
         let ret = match ty {
-            Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)),
+            Type::Exist(tv) => self.resolve_tv(tv)?.ok_or(Error::AmbiguousType(tv)),
             Type::Univ(tv) => Ok(ast::Type::Var(self.name_univ(tv))),
             Type::Unit => Ok(ast::Type::Unit),
             Type::Nullary(_) => todo!(),
             Type::Prim(pr) => Ok(pr.into()),
+            Type::Tuple(members) => Ok(ast::Type::Tuple(
+                members
+                    .into_iter()
+                    .map(|ty| self.finalize_type(ty))
+                    .try_collect()?,
+            )),
             Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType {
                 args: args
                     .into_iter()
@@ -485,12 +533,15 @@ impl<'ast> Typechecker<'ast> {
         ret
     }
 
-    fn resolve_tv(&self, tv: TyVar) -> Option<ast::Type<'static>> {
+    fn resolve_tv(&self, tv: TyVar) -> Result<Option<ast::Type<'static>>> {
         let mut res = &Type::Exist(tv);
-        loop {
+        Ok(loop {
             match res {
                 Type::Exist(tv) => {
-                    res = self.ctx.get(tv)?;
+                    res = match self.ctx.get(tv) {
+                        Some(r) => r,
+                        None => return Ok(None),
+                    };
                 }
                 Type::Univ(tv) => {
                     let ident = self.name_univ(*tv);
@@ -504,8 +555,9 @@ impl<'ast> Typechecker<'ast> {
                 Type::Prim(pr) => break Some((*pr).into()),
                 Type::Unit => break Some(ast::Type::Unit),
                 Type::Fun { args, ret } => todo!(),
+                Type::Tuple(_) => break Some(self.finalize_type(res.clone())?),
             }
-        }
+        })
     }
 
     fn type_from_ast_type(&mut self, ast_type: ast::Type<'ast>) -> Type {
@@ -515,6 +567,12 @@ impl<'ast> Typechecker<'ast> {
             ast::Type::Float => FLOAT,
             ast::Type::Bool => BOOL,
             ast::Type::CString => CSTRING,
+            ast::Type::Tuple(members) => Type::Tuple(
+                members
+                    .into_iter()
+                    .map(|ty| self.type_from_ast_type(ty))
+                    .collect(),
+            ),
             ast::Type::Function(ast::FunctionType { args, ret }) => Type::Fun {
                 args: args
                     .into_iter()
@@ -582,6 +640,11 @@ impl<'ast> Typechecker<'ast> {
             (Type::Unit, _) => false,
             (Type::Nullary(_), _) => todo!(),
             (Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty,
+            (Type::Tuple(members), ast::Type::Tuple(members2)) => members
+                .iter()
+                .zip(members2.iter())
+                .all(|(t1, t2)| self.types_match(t1, t2)),
+            (Type::Tuple(members), _) => false,
             (Type::Fun { args, ret }, ast::Type::Function(ft)) => {
                 args.len() == ft.args.len()
                     && args