From 48098f83c1e3943d1fb76aaecdce3b4f56cf4d4a Mon Sep 17 00:00:00 2001 From: Griffin Smith Date: Sat, 17 Apr 2021 08:28:24 +0200 Subject: feat(grfn/achilles): Implement tuples, and tuple patterns Implement tuple expressions, types, and patterns, all the way through the parser down to the typechecker. In LLVM, these are implemented as anonymous structs, using an `extract` instruction when they're pattern matched on to get out the individual fields. Currently the only limitation here is patterns aren't supported in function argument position, but you can still do something like fn xy = let (x, y) = xy in x + y Change-Id: I357f17e9d4052e741eda8605b6662822f331efde Reviewed-on: https://cl.tvl.fyi/c/depot/+/3027 Reviewed-by: grfn Tested-by: BuildkiteCI --- users/grfn/achilles/src/ast/hir.rs | 58 +++++++++++-- users/grfn/achilles/src/ast/mod.rs | 41 +++++++++- users/grfn/achilles/src/codegen/llvm.rs | 60 ++++++++++++-- users/grfn/achilles/src/interpreter/mod.rs | 27 +++++- users/grfn/achilles/src/interpreter/value.rs | 21 +++++ users/grfn/achilles/src/parser/expr.rs | 95 +++++++++++++++++++--- users/grfn/achilles/src/parser/mod.rs | 1 + users/grfn/achilles/src/parser/type_.rs | 35 ++++++-- users/grfn/achilles/src/parser/util.rs | 8 ++ users/grfn/achilles/src/passes/hir/mod.rs | 22 ++++- .../src/passes/hir/strip_positive_units.rs | 8 +- users/grfn/achilles/src/tc/mod.rs | 91 +++++++++++++++++---- 12 files changed, 413 insertions(+), 54 deletions(-) create mode 100644 users/grfn/achilles/src/parser/util.rs (limited to 'users') diff --git a/users/grfn/achilles/src/ast/hir.rs b/users/grfn/achilles/src/ast/hir.rs index 0d145d620b..cdfaef567d 100644 --- a/users/grfn/achilles/src/ast/hir.rs +++ b/users/grfn/achilles/src/ast/hir.rs @@ -4,10 +4,43 @@ use itertools::Itertools; use super::{BinaryOperator, Ident, Literal, UnaryOperator}; +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum Pattern<'a, T> { + Id(Ident<'a>, T), + Tuple(Vec>), +} + +impl<'a, T> Pattern<'a, T> { + pub fn to_owned(&self) -> Pattern<'static, T> + where + T: Clone, + { + match self { + Pattern::Id(id, t) => Pattern::Id(id.to_owned(), t.clone()), + Pattern::Tuple(pats) => { + Pattern::Tuple(pats.into_iter().map(Pattern::to_owned).collect()) + } + } + } + + pub fn traverse_type(self, f: F) -> Result, E> + where + F: Fn(T) -> Result + Clone, + { + match self { + Pattern::Id(id, t) => Ok(Pattern::Id(id, f(t)?)), + Pattern::Tuple(pats) => Ok(Pattern::Tuple( + pats.into_iter() + .map(|pat| pat.traverse_type(f.clone())) + .collect::, _>>()?, + )), + } + } +} + #[derive(Debug, PartialEq, Eq, Clone)] pub struct Binding<'a, T> { - pub ident: Ident<'a>, - pub type_: T, + pub pat: Pattern<'a, T>, pub body: Expr<'a, T>, } @@ -17,8 +50,7 @@ impl<'a, T> Binding<'a, T> { T: Clone, { Binding { - ident: self.ident.to_owned(), - type_: self.type_.clone(), + pat: self.pat.to_owned(), body: self.body.to_owned(), } } @@ -30,6 +62,8 @@ pub enum Expr<'a, T> { Literal(Literal<'a>, T), + Tuple(Vec>, T), + UnaryOp { op: UnaryOperator, rhs: Box>, @@ -76,6 +110,7 @@ impl<'a, T> Expr<'a, T> { match self { Expr::Ident(_, t) => t, Expr::Literal(_, t) => t, + Expr::Tuple(_, t) => t, Expr::UnaryOp { type_, .. } => type_, Expr::BinaryOp { type_, .. } => type_, Expr::Let { type_, .. } => type_, @@ -115,10 +150,9 @@ impl<'a, T> Expr<'a, T> { } => Ok(Expr::Let { bindings: bindings .into_iter() - .map(|Binding { ident, type_, body }| { + .map(|Binding { pat, body }| { Ok(Binding { - ident, - type_: f(type_)?, + pat: pat.traverse_type(f.clone())?, body: body.traverse_type(f.clone())?, }) }) @@ -168,6 +202,13 @@ impl<'a, T> Expr<'a, T> { .collect::, E>>()?, type_: f(type_)?, }), + Expr::Tuple(members, t) => Ok(Expr::Tuple( + members + .into_iter() + .map(|t| t.traverse_type(f.clone())) + .try_collect()?, + f(t)?, + )), } } @@ -242,6 +283,9 @@ impl<'a, T> Expr<'a, T> { args: args.iter().map(|e| e.to_owned()).collect(), type_: type_.clone(), }, + Expr::Tuple(members, t) => { + Expr::Tuple(members.into_iter().map(Expr::to_owned).collect(), t.clone()) + } } } } diff --git a/users/grfn/achilles/src/ast/mod.rs b/users/grfn/achilles/src/ast/mod.rs index 7dc2de8957..5438d29d2c 100644 --- a/users/grfn/achilles/src/ast/mod.rs +++ b/users/grfn/achilles/src/ast/mod.rs @@ -127,9 +127,24 @@ impl<'a> Literal<'a> { } } +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum Pattern<'a> { + Id(Ident<'a>), + Tuple(Vec>), +} + +impl<'a> Pattern<'a> { + pub fn to_owned(&self) -> Pattern<'static> { + match self { + Pattern::Id(id) => Pattern::Id(id.to_owned()), + Pattern::Tuple(pats) => Pattern::Tuple(pats.iter().map(Pattern::to_owned).collect()), + } + } +} + #[derive(Debug, PartialEq, Eq, Clone)] pub struct Binding<'a> { - pub ident: Ident<'a>, + pub pat: Pattern<'a>, pub type_: Option>, pub body: Expr<'a>, } @@ -137,7 +152,7 @@ pub struct Binding<'a> { impl<'a> Binding<'a> { fn to_owned(&self) -> Binding<'static> { Binding { - ident: self.ident.to_owned(), + pat: self.pat.to_owned(), type_: self.type_.as_ref().map(|t| t.to_owned()), body: self.body.to_owned(), } @@ -179,6 +194,8 @@ pub enum Expr<'a> { args: Vec>, }, + Tuple(Vec>), + Ascription { expr: Box>, type_: Type<'a>, @@ -190,6 +207,9 @@ impl<'a> Expr<'a> { match self { Expr::Ident(ref id) => Expr::Ident(id.to_owned()), Expr::Literal(ref lit) => Expr::Literal(lit.to_owned()), + Expr::Tuple(ref members) => { + Expr::Tuple(members.into_iter().map(Expr::to_owned).collect()) + } Expr::UnaryOp { op, rhs } => Expr::UnaryOp { op: *op, rhs: Box::new((**rhs).to_owned()), @@ -312,6 +332,7 @@ pub enum Type<'a> { Bool, CString, Unit, + Tuple(Vec>), Var(Ident<'a>), Function(FunctionType<'a>), } @@ -326,6 +347,7 @@ impl<'a> Type<'a> { Type::Unit => Type::Unit, Type::Var(v) => Type::Var(v.to_owned()), Type::Function(f) => Type::Function(f.to_owned()), + Type::Tuple(members) => Type::Tuple(members.iter().map(Type::to_owned).collect()), } } @@ -379,9 +401,23 @@ impl<'a> Type<'a> { Type::Float => Type::Float, Type::Bool => Type::Bool, Type::CString => Type::CString, + Type::Tuple(members) => Type::Tuple( + members + .into_iter() + .map(|t| t.traverse_type_vars(f.clone())) + .collect(), + ), Type::Unit => Type::Unit, } } + + pub fn as_tuple(&self) -> Option<&Vec>> { + if let Self::Tuple(v) = self { + Some(v) + } else { + None + } + } } impl<'a> Display for Type<'a> { @@ -394,6 +430,7 @@ impl<'a> Display for Type<'a> { Type::Unit => f.write_str("()"), Type::Var(v) => v.fmt(f), Type::Function(ft) => ft.fmt(f), + Type::Tuple(ms) => write!(f, "({})", ms.iter().join(", ")), } } } diff --git a/users/grfn/achilles/src/codegen/llvm.rs b/users/grfn/achilles/src/codegen/llvm.rs index 17dec58b5f..9a71ac954e 100644 --- a/users/grfn/achilles/src/codegen/llvm.rs +++ b/users/grfn/achilles/src/codegen/llvm.rs @@ -7,12 +7,13 @@ use inkwell::builder::Builder; pub use inkwell::context::Context; use inkwell::module::Module; use inkwell::support::LLVMString; -use inkwell::types::{BasicType, BasicTypeEnum, FunctionType, IntType}; -use inkwell::values::{AnyValueEnum, BasicValueEnum, FunctionValue}; +use inkwell::types::{BasicType, BasicTypeEnum, FunctionType, IntType, StructType}; +use inkwell::values::{AnyValueEnum, BasicValueEnum, FunctionValue, StructValue}; use inkwell::{AddressSpace, IntPredicate}; +use itertools::Itertools; use thiserror::Error; -use crate::ast::hir::{Binding, Decl, Expr}; +use crate::ast::hir::{Binding, Decl, Expr, Pattern}; use crate::ast::{BinaryOperator, Ident, Literal, Type, UnaryOperator}; use crate::common::env::Env; @@ -82,6 +83,25 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { .append_basic_block(*self.function_stack.last().unwrap(), name) } + fn bind_pattern(&mut self, pat: &'ast Pattern<'ast, Type>, val: AnyValueEnum<'ctx>) { + match pat { + Pattern::Id(id, _) => self.env.set(id, val), + Pattern::Tuple(pats) => { + for (i, pat) in pats.iter().enumerate() { + let member = self + .builder + .build_extract_value( + StructValue::try_from(val).unwrap(), + i as _, + "pat_bind", + ) + .unwrap(); + self.bind_pattern(pat, member.into()); + } + } + } + } + pub fn codegen_expr( &mut self, expr: &'ast Expr<'ast, Type>, @@ -164,9 +184,9 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { } Expr::Let { bindings, body, .. } => { self.env.push(); - for Binding { ident, body, .. } in bindings { + for Binding { pat, body, .. } in bindings { if let Some(val) = self.codegen_expr(body)? { - self.env.set(ident, val); + self.bind_pattern(pat, val); } } let res = self.codegen_expr(body); @@ -244,6 +264,19 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { self.env.restore(env); Ok(Some(function.into())) } + Expr::Tuple(members, ty) => { + let values = members + .into_iter() + .map(|expr| self.codegen_expr(expr)) + .collect::>>()? + .into_iter() + .filter_map(|x| x) + .map(|x| x.try_into().unwrap()) + .collect_vec(); + let field_types = ty.as_tuple().unwrap(); + let tuple_type = self.codegen_tuple_type(field_types); + Ok(Some(tuple_type.const_named_struct(&values).into())) + } } } @@ -341,9 +374,20 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { Type::Function(_) => todo!(), Type::Var(_) => unreachable!(), Type::Unit => None, + Type::Tuple(ts) => Some(self.codegen_tuple_type(ts).into()), } } + fn codegen_tuple_type(&self, ts: &'ast [Type]) -> StructType<'ctx> { + self.context.struct_type( + ts.iter() + .filter_map(|t| self.codegen_type(t)) + .collect_vec() + .as_slice(), + false, + ) + } + fn codegen_int_type(&self, type_: &'ast Type) -> IntType<'ctx> { // TODO self.context.i64_type() @@ -433,4 +477,10 @@ mod tests { let res = jit_eval::("let id = fn x = x in id 1").unwrap(); assert_eq!(res, 1); } + + #[test] + fn bind_tuple_pattern() { + let res = jit_eval::("let (x, y) = (1, 2) in x + y").unwrap(); + assert_eq!(res, 3); + } } diff --git a/users/grfn/achilles/src/interpreter/mod.rs b/users/grfn/achilles/src/interpreter/mod.rs index a8ba2dd3ac..70df7a0724 100644 --- a/users/grfn/achilles/src/interpreter/mod.rs +++ b/users/grfn/achilles/src/interpreter/mod.rs @@ -1,9 +1,12 @@ mod error; mod value; +use itertools::Itertools; +use value::Val; + pub use self::error::{Error, Result}; pub use self::value::{Function, Value}; -use crate::ast::hir::{Binding, Expr}; +use crate::ast::hir::{Binding, Expr, Pattern}; use crate::ast::{BinaryOperator, FunctionType, Ident, Literal, Type, UnaryOperator}; use crate::common::env::Env; @@ -24,6 +27,17 @@ impl<'a> Interpreter<'a> { .ok_or_else(|| Error::UndefinedVariable(var.to_owned())) } + fn bind_pattern(&mut self, pattern: &'a Pattern<'a, Type>, value: Value<'a>) { + match pattern { + Pattern::Id(id, _) => self.env.set(id, value), + Pattern::Tuple(pats) => { + for (pat, val) in pats.iter().zip(value.as_tuple().unwrap().clone()) { + self.bind_pattern(pat, val); + } + } + } + } + pub fn eval(&mut self, expr: &'a Expr<'a, Type>) -> Result> { let res = match expr { Expr::Ident(id, _) => self.resolve(id), @@ -53,9 +67,9 @@ impl<'a> Interpreter<'a> { } Expr::Let { bindings, body, .. } => { self.env.push(); - for Binding { ident, body, .. } in bindings { + for Binding { pat, body, .. } in bindings { let val = self.eval(body)?; - self.env.set(ident, val); + self.bind_pattern(pat, val); } let res = self.eval(body)?; self.env.pop(); @@ -115,6 +129,13 @@ impl<'a> Interpreter<'a> { body: (**body).to_owned(), })) } + Expr::Tuple(members, _) => Ok(Val::Tuple( + members + .into_iter() + .map(|expr| self.eval(expr)) + .try_collect()?, + ) + .into()), }?; debug_assert_eq!(&res.type_(), expr.type_()); Ok(res) diff --git a/users/grfn/achilles/src/interpreter/value.rs b/users/grfn/achilles/src/interpreter/value.rs index 55ba42f9de..272d1167a3 100644 --- a/users/grfn/achilles/src/interpreter/value.rs +++ b/users/grfn/achilles/src/interpreter/value.rs @@ -6,6 +6,7 @@ use std::rc::Rc; use std::result; use derive_more::{Deref, From, TryInto}; +use itertools::Itertools; use super::{Error, Result}; use crate::ast::hir::Expr; @@ -25,6 +26,7 @@ pub enum Val<'a> { Float(f64), Bool(bool), String(Cow<'a, str>), + Tuple(Vec>), Function(Function<'a>), } @@ -49,6 +51,7 @@ impl<'a> fmt::Debug for Val<'a> { Val::Function(Function { type_, .. }) => { f.debug_struct("Function").field("type_", type_).finish() } + Val::Tuple(members) => f.debug_tuple("Tuple").field(members).finish(), } } } @@ -79,6 +82,7 @@ impl<'a> Display for Val<'a> { Val::Bool(x) => x.fmt(f), Val::String(s) => write!(f, "{:?}", s), Val::Function(Function { type_, .. }) => write!(f, "<{}>", type_), + Val::Tuple(members) => write!(f, "({})", members.iter().join(", ")), } } } @@ -91,6 +95,7 @@ impl<'a> Val<'a> { Val::Bool(_) => Type::Bool, Val::String(_) => Type::CString, Val::Function(Function { type_, .. }) => Type::Function(type_.clone()), + Val::Tuple(members) => Type::Tuple(members.iter().map(|expr| expr.type_()).collect()), } } @@ -114,6 +119,22 @@ impl<'a> Val<'a> { }), } } + + pub fn as_tuple(&self) -> Option<&Vec>> { + if let Self::Tuple(v) = self { + Some(v) + } else { + None + } + } + + pub fn try_into_tuple(self) -> result::Result>, Self> { + if let Self::Tuple(v) = self { + Ok(v) + } else { + Err(self) + } + } } #[derive(Debug, PartialEq, Clone, Deref)] diff --git a/users/grfn/achilles/src/parser/expr.rs b/users/grfn/achilles/src/parser/expr.rs index 8a28d00984..f596b18970 100644 --- a/users/grfn/achilles/src/parser/expr.rs +++ b/users/grfn/achilles/src/parser/expr.rs @@ -8,7 +8,8 @@ use nom::{ }; use pratt::{Affix, Associativity, PrattParser, Precedence}; -use crate::ast::{BinaryOperator, Binding, Expr, Fun, Literal, UnaryOperator}; +use super::util::comma; +use crate::ast::{BinaryOperator, Binding, Expr, Fun, Literal, Pattern, UnaryOperator}; use crate::parser::{arg, ident, type_}; #[derive(Debug)] @@ -192,9 +193,45 @@ named!(literal(&str) -> Literal, alt!(int | bool_ | string | unit)); named!(literal_expr(&str) -> Expr, map!(literal, Expr::Literal)); +named!(tuple(&str) -> Expr, do_parse!( + complete!(tag!("(")) + >> multispace0 + >> fst: expr + >> comma + >> rest: separated_list0!( + comma, + expr + ) + >> multispace0 + >> tag!(")") + >> ({ + let mut members = Vec::with_capacity(rest.len() + 1); + members.push(fst); + members.append(&mut rest.clone()); + Expr::Tuple(members) + }) +)); + +named!(tuple_pattern(&str) -> Pattern, do_parse!( + complete!(tag!("(")) + >> multispace0 + >> pats: separated_list0!( + comma, + pattern + ) + >> multispace0 + >> tag!(")") + >> (Pattern::Tuple(pats)) +)); + +named!(pattern(&str) -> Pattern, alt!( + ident => { |id| Pattern::Id(id) } | + tuple_pattern +)); + named!(binding(&str) -> Binding, do_parse!( multispace0 - >> ident: ident + >> pat: pattern >> multispace0 >> type_: opt!(preceded!(tuple!(tag!(":"), multispace0), type_)) >> multispace0 @@ -202,7 +239,7 @@ named!(binding(&str) -> Binding, do_parse!( >> multispace0 >> body: expr >> (Binding { - ident, + pat, type_, body }) @@ -267,6 +304,7 @@ named!(paren_expr(&str) -> Expr, named!(funcref(&str) -> Expr, alt!( ident_expr | + tuple | paren_expr )); @@ -296,6 +334,7 @@ named!(fun_expr(&str) -> Expr, do_parse!( named!(fn_arg(&str) -> Expr, alt!( ident_expr | literal_expr | + tuple | paren_expr )); @@ -314,7 +353,8 @@ named!(simple_expr_unascripted(&str) -> Expr, alt!( if_ | fun_expr | literal_expr | - ident_expr + ident_expr | + tuple )); named!(simple_expr(&str) -> Expr, alt!( @@ -334,7 +374,7 @@ named!(pub expr(&str) -> Expr, alt!( #[cfg(test)] pub(crate) mod tests { use super::*; - use crate::ast::{Arg, Ident, Type}; + use crate::ast::{Arg, Ident, Pattern, Type}; use std::convert::TryFrom; use BinaryOperator::*; use Expr::{BinaryOp, If, Let, UnaryOp}; @@ -449,6 +489,17 @@ pub(crate) mod tests { ); } + #[test] + fn tuple() { + assert_eq!( + test_parse!(expr, "(1, \"seven\")"), + Expr::Tuple(vec![ + Expr::Literal(Literal::Int(1)), + Expr::Literal(Literal::String(Cow::Borrowed("seven"))) + ]) + ) + } + #[test] fn simple_string_lit() { assert_eq!( @@ -465,12 +516,12 @@ pub(crate) mod tests { Let { bindings: vec![ Binding { - ident: Ident::try_from("x").unwrap(), + pat: Pattern::Id(Ident::try_from("x").unwrap()), type_: None, body: Expr::Literal(Literal::Int(1)) }, Binding { - ident: Ident::try_from("y").unwrap(), + pat: Pattern::Id(Ident::try_from("y").unwrap()), type_: None, body: Expr::BinaryOp { lhs: ident_expr("x"), @@ -553,7 +604,7 @@ pub(crate) mod tests { Expr::Call { fun: Box::new(Expr::Let { bindings: vec![Binding { - ident: Ident::try_from("x").unwrap(), + pat: Pattern::Id(Ident::try_from("x").unwrap()), type_: None, body: Expr::Literal(Literal::Int(1)) }], @@ -571,7 +622,7 @@ pub(crate) mod tests { res, Expr::Let { bindings: vec![Binding { - ident: Ident::try_from("id").unwrap(), + pat: Pattern::Id(Ident::try_from("id").unwrap()), type_: None, body: Expr::Fun(Box::new(Fun { args: vec![Arg::try_from("x").unwrap()], @@ -586,6 +637,28 @@ pub(crate) mod tests { ); } + #[test] + fn tuple_binding() { + let res = test_parse!(expr, "let (x, y) = (1, 2) in x"); + assert_eq!( + res, + Expr::Let { + bindings: vec![Binding { + pat: Pattern::Tuple(vec![ + Pattern::Id(Ident::from_str_unchecked("x")), + Pattern::Id(Ident::from_str_unchecked("y")) + ]), + body: Expr::Tuple(vec![ + Expr::Literal(Literal::Int(1)), + Expr::Literal(Literal::Int(2)) + ]), + type_: None + }], + body: Box::new(Expr::Ident(Ident::from_str_unchecked("x"))) + } + ) + } + mod ascriptions { use super::*; @@ -608,7 +681,7 @@ pub(crate) mod tests { res, Expr::Let { bindings: vec![Binding { - ident: Ident::try_from("const_1").unwrap(), + pat: Pattern::Id(Ident::try_from("const_1").unwrap()), type_: None, body: Expr::Fun(Box::new(Fun { args: vec![Arg::try_from("x").unwrap()], @@ -633,7 +706,7 @@ pub(crate) mod tests { res, Expr::Let { bindings: vec![Binding { - ident: Ident::try_from("x").unwrap(), + pat: Pattern::Id(Ident::try_from("x").unwrap()), type_: Some(Type::Int), body: Expr::Literal(Literal::Int(1)) }], diff --git a/users/grfn/achilles/src/parser/mod.rs b/users/grfn/achilles/src/parser/mod.rs index 3e0081bd39..e088cbca10 100644 --- a/users/grfn/achilles/src/parser/mod.rs +++ b/users/grfn/achilles/src/parser/mod.rs @@ -6,6 +6,7 @@ use nom::{alt, char, complete, do_parse, eof, many0, named, separated_list0, tag pub(crate) mod macros; mod expr; mod type_; +mod util; use crate::ast::{Arg, Decl, Fun, Ident}; pub use expr::expr; diff --git a/users/grfn/achilles/src/parser/type_.rs b/users/grfn/achilles/src/parser/type_.rs index 8a1081e252..b80f0e0860 100644 --- a/users/grfn/achilles/src/parser/type_.rs +++ b/users/grfn/achilles/src/parser/type_.rs @@ -2,17 +2,14 @@ use nom::character::complete::{multispace0, multispace1}; use nom::{alt, delimited, do_parse, map, named, opt, separated_list0, tag, terminated, tuple}; use super::ident; +use super::util::comma; use crate::ast::{FunctionType, Type}; named!(pub function_type(&str) -> FunctionType, do_parse!( tag!("fn") >> multispace1 >> args: map!(opt!(terminated!(separated_list0!( - tuple!( - multispace0, - tag!(","), - multispace0 - ), + comma, type_ ), multispace1)), |args| args.unwrap_or_default()) >> tag!("->") @@ -24,12 +21,32 @@ named!(pub function_type(&str) -> FunctionType, do_parse!( }) )); +named!(tuple_type(&str) -> Type, do_parse!( + tag!("(") + >> multispace0 + >> fst: type_ + >> comma + >> rest: separated_list0!( + comma, + type_ + ) + >> multispace0 + >> tag!(")") + >> ({ + let mut members = Vec::with_capacity(rest.len() + 1); + members.push(fst); + members.append(&mut rest.clone()); + Type::Tuple(members) + }) +)); + named!(pub type_(&str) -> Type, alt!( tag!("int") => { |_| Type::Int } | tag!("float") => { |_| Type::Float } | tag!("bool") => { |_| Type::Bool } | tag!("cstring") => { |_| Type::CString } | tag!("()") => { |_| Type::Unit } | + tuple_type | function_type => { |ft| Type::Function(ft) }| ident => { |id| Type::Var(id) } | delimited!( @@ -111,6 +128,14 @@ mod tests { ) } + #[test] + fn tuple() { + assert_eq!( + test_parse!(type_, "(int, int)"), + Type::Tuple(vec![Type::Int, Type::Int]) + ) + } + #[test] fn type_vars() { assert_eq!( diff --git a/users/grfn/achilles/src/parser/util.rs b/users/grfn/achilles/src/parser/util.rs new file mode 100644 index 0000000000..bb53fb7fff --- /dev/null +++ b/users/grfn/achilles/src/parser/util.rs @@ -0,0 +1,8 @@ +use nom::character::complete::multispace0; +use nom::{complete, map, named, tag, tuple}; + +named!(pub(crate) comma(&str) -> (), map!(tuple!( + multispace0, + complete!(tag!(",")), + multispace0 +) ,|_| ())); diff --git a/users/grfn/achilles/src/passes/hir/mod.rs b/users/grfn/achilles/src/passes/hir/mod.rs index 845bfcb7ab..872c449eb0 100644 --- a/users/grfn/achilles/src/passes/hir/mod.rs +++ b/users/grfn/achilles/src/passes/hir/mod.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use crate::ast::hir::{Binding, Decl, Expr}; +use crate::ast::hir::{Binding, Decl, Expr, Pattern}; use crate::ast::{BinaryOperator, Ident, Literal, UnaryOperator}; pub(crate) mod monomorphize; @@ -29,9 +29,12 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a { Ok(()) } + fn visit_pattern(&mut self, _pat: &mut Pattern<'ast, T>) -> Result<(), Self::Error> { + Ok(()) + } + fn visit_binding(&mut self, binding: &mut Binding<'ast, T>) -> Result<(), Self::Error> { - self.visit_ident(&mut binding.ident)?; - self.visit_type(&mut binding.type_)?; + self.visit_pattern(&mut binding.pat)?; self.visit_expr(&mut binding.body)?; Ok(()) } @@ -54,6 +57,13 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a { Ok(()) } + fn visit_tuple(&mut self, members: &mut Vec>) -> Result<(), Self::Error> { + for expr in members { + self.visit_expr(expr)?; + } + Ok(()) + } + fn pre_visit_expr(&mut self, _expr: &mut Expr<'ast, T>) -> Result<(), Self::Error> { Ok(()) } @@ -137,12 +147,16 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a { self.visit_type(type_)?; self.post_visit_call(fun, type_args, args)?; } + Expr::Tuple(tup, type_) => { + self.visit_tuple(tup)?; + self.visit_type(type_)?; + } } Ok(()) } - fn post_visit_decl(&mut self, decl: &'a Decl<'ast, T>) -> Result<(), Self::Error> { + fn post_visit_decl(&mut self, _decl: &'a Decl<'ast, T>) -> Result<(), Self::Error> { Ok(()) } diff --git a/users/grfn/achilles/src/passes/hir/strip_positive_units.rs b/users/grfn/achilles/src/passes/hir/strip_positive_units.rs index 91b56551c8..85ee1cce48 100644 --- a/users/grfn/achilles/src/passes/hir/strip_positive_units.rs +++ b/users/grfn/achilles/src/passes/hir/strip_positive_units.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use std::mem; -use ast::hir::Binding; +use ast::hir::{Binding, Pattern}; use ast::Literal; use void::{ResultVoidExt, Void}; @@ -42,8 +42,10 @@ impl<'a, 'ast> Visitor<'a, 'ast, ast::Type<'ast>> for StripPositiveUnits { bindings: extracted .into_iter() .map(|expr| Binding { - ident: Ident::from_str_unchecked("___discarded"), - type_: expr.type_().clone(), + pat: Pattern::Id( + Ident::from_str_unchecked("___discarded"), + expr.type_().clone(), + ), body: expr, }) .collect(), diff --git a/users/grfn/achilles/src/tc/mod.rs b/users/grfn/achilles/src/tc/mod.rs index d27c45075e..5825bab1fb 100644 --- a/users/grfn/achilles/src/tc/mod.rs +++ b/users/grfn/achilles/src/tc/mod.rs @@ -8,7 +8,7 @@ use std::fmt::{self, Display}; use std::{mem, result}; use thiserror::Error; -use crate::ast::{self, hir, Arg, BinaryOperator, Ident, Literal}; +use crate::ast::{self, hir, Arg, BinaryOperator, Ident, Literal, Pattern}; use crate::common::env::Env; use crate::common::{Namer, NamerOf}; @@ -85,6 +85,7 @@ pub enum Type { Exist(TyVar), Nullary(NullaryType), Prim(PrimType), + Tuple(Vec), Unit, Fun { args: Vec, @@ -102,6 +103,9 @@ impl<'a> TryFrom for ast::Type<'a> { Type::Exist(_) => Err(value), Type::Nullary(_) => todo!(), Type::Prim(p) => Ok(p.into()), + Type::Tuple(members) => Ok(ast::Type::Tuple( + members.into_iter().map(|ty| ty.try_into()).try_collect()?, + )), Type::Fun { ref args, ref ret } => Ok(ast::Type::Function(ast::FunctionType { args: args .clone() @@ -128,6 +132,7 @@ impl Display for Type { Type::Univ(TyVar(n)) => write!(f, "∀{}", n), Type::Exist(TyVar(n)) => write!(f, "∃{}", n), Type::Fun { args, ret } => write!(f, "fn {} -> {}", args.iter().join(", "), ret), + Type::Tuple(members) => write!(f, "({})", members.iter().join(", ")), Type::Unit => write!(f, "()"), } } @@ -159,6 +164,31 @@ impl<'ast> Typechecker<'ast> { } } + fn bind_pattern( + &mut self, + pat: Pattern<'ast>, + type_: Type, + ) -> Result> { + match pat { + Pattern::Id(ident) => { + self.env.set(ident.clone(), type_.clone()); + Ok(hir::Pattern::Id(ident, type_)) + } + Pattern::Tuple(members) => { + let mut tys = Vec::with_capacity(members.len()); + let mut hir_members = Vec::with_capacity(members.len()); + for pat in members { + let ty = self.fresh_ex(); + hir_members.push(self.bind_pattern(pat, ty.clone())?); + tys.push(ty); + } + let tuple_type = Type::Tuple(tys); + self.unify(&tuple_type, &type_)?; + Ok(hir::Pattern::Tuple(hir_members)) + } + } + } + pub(crate) fn tc_expr(&mut self, expr: ast::Expr<'ast>) -> Result> { match expr { ast::Expr::Ident(ident) => { @@ -178,6 +208,14 @@ impl<'ast> Typechecker<'ast> { }; Ok(hir::Expr::Literal(lit.to_owned(), type_)) } + ast::Expr::Tuple(members) => { + let members = members + .into_iter() + .map(|expr| self.tc_expr(expr)) + .collect::>>()?; + let type_ = Type::Tuple(members.iter().map(|expr| expr.type_().clone()).collect()); + Ok(hir::Expr::Tuple(members, type_)) + } ast::Expr::UnaryOp { op, rhs } => todo!(), ast::Expr::BinaryOp { lhs, op, rhs } => { let lhs = self.tc_expr(*lhs)?; @@ -209,18 +247,14 @@ impl<'ast> Typechecker<'ast> { let bindings = bindings .into_iter() .map( - |ast::Binding { ident, type_, body }| -> Result> { + |ast::Binding { pat, type_, body }| -> Result> { let body = self.tc_expr(body)?; if let Some(type_) = type_ { let type_ = self.type_from_ast_type(type_); self.unify(body.type_(), &type_)?; } - self.env.set(ident.clone(), body.type_().clone()); - Ok(hir::Binding { - ident, - type_: body.type_().clone(), - body, - }) + let pat = self.bind_pattern(pat, body.type_().clone())?; + Ok(hir::Binding { pat, body }) }, ) .collect::>>>()?; @@ -382,7 +416,7 @@ impl<'ast> Typechecker<'ast> { fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result { match (ty1, ty2) { (Type::Unit, Type::Unit) => Ok(Type::Unit), - (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) { + (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv)? { Some(existing_ty) if self.types_match(ty, &existing_ty) => Ok(ty.clone()), Some(var @ ast::Type::Var(_)) => { let var = self.type_from_ast_type(var); @@ -419,6 +453,14 @@ impl<'ast> Typechecker<'ast> { } } (Type::Prim(p1), Type::Prim(p2)) if p1 == p2 => Ok(ty2.clone()), + (Type::Tuple(t1), Type::Tuple(t2)) if t1.len() == t2.len() => { + let ts = t1 + .iter() + .zip(t2.iter()) + .map(|(t1, t2)| self.unify(t1, t2)) + .try_collect()?; + Ok(Type::Tuple(ts)) + } ( Type::Fun { args: args1, @@ -469,11 +511,17 @@ impl<'ast> Typechecker<'ast> { fn finalize_type(&self, ty: Type) -> Result> { let ret = match ty { - Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)), + Type::Exist(tv) => self.resolve_tv(tv)?.ok_or(Error::AmbiguousType(tv)), Type::Univ(tv) => Ok(ast::Type::Var(self.name_univ(tv))), Type::Unit => Ok(ast::Type::Unit), Type::Nullary(_) => todo!(), Type::Prim(pr) => Ok(pr.into()), + Type::Tuple(members) => Ok(ast::Type::Tuple( + members + .into_iter() + .map(|ty| self.finalize_type(ty)) + .try_collect()?, + )), Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType { args: args .into_iter() @@ -485,12 +533,15 @@ impl<'ast> Typechecker<'ast> { ret } - fn resolve_tv(&self, tv: TyVar) -> Option> { + fn resolve_tv(&self, tv: TyVar) -> Result>> { let mut res = &Type::Exist(tv); - loop { + Ok(loop { match res { Type::Exist(tv) => { - res = self.ctx.get(tv)?; + res = match self.ctx.get(tv) { + Some(r) => r, + None => return Ok(None), + }; } Type::Univ(tv) => { let ident = self.name_univ(*tv); @@ -504,8 +555,9 @@ impl<'ast> Typechecker<'ast> { Type::Prim(pr) => break Some((*pr).into()), Type::Unit => break Some(ast::Type::Unit), Type::Fun { args, ret } => todo!(), + Type::Tuple(_) => break Some(self.finalize_type(res.clone())?), } - } + }) } fn type_from_ast_type(&mut self, ast_type: ast::Type<'ast>) -> Type { @@ -515,6 +567,12 @@ impl<'ast> Typechecker<'ast> { ast::Type::Float => FLOAT, ast::Type::Bool => BOOL, ast::Type::CString => CSTRING, + ast::Type::Tuple(members) => Type::Tuple( + members + .into_iter() + .map(|ty| self.type_from_ast_type(ty)) + .collect(), + ), ast::Type::Function(ast::FunctionType { args, ret }) => Type::Fun { args: args .into_iter() @@ -582,6 +640,11 @@ impl<'ast> Typechecker<'ast> { (Type::Unit, _) => false, (Type::Nullary(_), _) => todo!(), (Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty, + (Type::Tuple(members), ast::Type::Tuple(members2)) => members + .iter() + .zip(members2.iter()) + .all(|(t1, t2)| self.types_match(t1, t2)), + (Type::Tuple(members), _) => false, (Type::Fun { args, ret }, ast::Type::Function(ft)) => { args.len() == ft.args.len() && args -- cgit 1.4.1