diff options
-rw-r--r-- | users/grfn/achilles/src/ast/hir.rs | 58 | ||||
-rw-r--r-- | users/grfn/achilles/src/ast/mod.rs | 41 | ||||
-rw-r--r-- | users/grfn/achilles/src/codegen/llvm.rs | 60 | ||||
-rw-r--r-- | users/grfn/achilles/src/interpreter/mod.rs | 27 | ||||
-rw-r--r-- | users/grfn/achilles/src/interpreter/value.rs | 21 | ||||
-rw-r--r-- | users/grfn/achilles/src/parser/expr.rs | 95 | ||||
-rw-r--r-- | users/grfn/achilles/src/parser/mod.rs | 1 | ||||
-rw-r--r-- | users/grfn/achilles/src/parser/type_.rs | 35 | ||||
-rw-r--r-- | users/grfn/achilles/src/parser/util.rs | 8 | ||||
-rw-r--r-- | users/grfn/achilles/src/passes/hir/mod.rs | 22 | ||||
-rw-r--r-- | users/grfn/achilles/src/passes/hir/strip_positive_units.rs | 8 | ||||
-rw-r--r-- | users/grfn/achilles/src/tc/mod.rs | 91 |
12 files changed, 413 insertions, 54 deletions
diff --git a/users/grfn/achilles/src/ast/hir.rs b/users/grfn/achilles/src/ast/hir.rs index 0d145d620bef..cdfaef567d7a 100644 --- a/users/grfn/achilles/src/ast/hir.rs +++ b/users/grfn/achilles/src/ast/hir.rs @@ -5,9 +5,42 @@ use itertools::Itertools; use super::{BinaryOperator, Ident, Literal, UnaryOperator}; #[derive(Debug, PartialEq, Eq, Clone)] +pub enum Pattern<'a, T> { + Id(Ident<'a>, T), + Tuple(Vec<Pattern<'a, T>>), +} + +impl<'a, T> Pattern<'a, T> { + pub fn to_owned(&self) -> Pattern<'static, T> + where + T: Clone, + { + match self { + Pattern::Id(id, t) => Pattern::Id(id.to_owned(), t.clone()), + Pattern::Tuple(pats) => { + Pattern::Tuple(pats.into_iter().map(Pattern::to_owned).collect()) + } + } + } + + pub fn traverse_type<F, U, E>(self, f: F) -> Result<Pattern<'a, U>, E> + where + F: Fn(T) -> Result<U, E> + Clone, + { + match self { + Pattern::Id(id, t) => Ok(Pattern::Id(id, f(t)?)), + Pattern::Tuple(pats) => Ok(Pattern::Tuple( + pats.into_iter() + .map(|pat| pat.traverse_type(f.clone())) + .collect::<Result<Vec<_>, _>>()?, + )), + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] pub struct Binding<'a, T> { - pub ident: Ident<'a>, - pub type_: T, + pub pat: Pattern<'a, T>, pub body: Expr<'a, T>, } @@ -17,8 +50,7 @@ impl<'a, T> Binding<'a, T> { T: Clone, { Binding { - ident: self.ident.to_owned(), - type_: self.type_.clone(), + pat: self.pat.to_owned(), body: self.body.to_owned(), } } @@ -30,6 +62,8 @@ pub enum Expr<'a, T> { Literal(Literal<'a>, T), + Tuple(Vec<Expr<'a, T>>, T), + UnaryOp { op: UnaryOperator, rhs: Box<Expr<'a, T>>, @@ -76,6 +110,7 @@ impl<'a, T> Expr<'a, T> { match self { Expr::Ident(_, t) => t, Expr::Literal(_, t) => t, + Expr::Tuple(_, t) => t, Expr::UnaryOp { type_, .. } => type_, Expr::BinaryOp { type_, .. } => type_, Expr::Let { type_, .. } => type_, @@ -115,10 +150,9 @@ impl<'a, T> Expr<'a, T> { } => Ok(Expr::Let { bindings: bindings .into_iter() - .map(|Binding { ident, type_, body }| { + .map(|Binding { pat, body }| { Ok(Binding { - ident, - type_: f(type_)?, + pat: pat.traverse_type(f.clone())?, body: body.traverse_type(f.clone())?, }) }) @@ -168,6 +202,13 @@ impl<'a, T> Expr<'a, T> { .collect::<Result<Vec<_>, E>>()?, type_: f(type_)?, }), + Expr::Tuple(members, t) => Ok(Expr::Tuple( + members + .into_iter() + .map(|t| t.traverse_type(f.clone())) + .try_collect()?, + f(t)?, + )), } } @@ -242,6 +283,9 @@ impl<'a, T> Expr<'a, T> { args: args.iter().map(|e| e.to_owned()).collect(), type_: type_.clone(), }, + Expr::Tuple(members, t) => { + Expr::Tuple(members.into_iter().map(Expr::to_owned).collect(), t.clone()) + } } } } diff --git a/users/grfn/achilles/src/ast/mod.rs b/users/grfn/achilles/src/ast/mod.rs index 7dc2de895709..5438d29d2cf7 100644 --- a/users/grfn/achilles/src/ast/mod.rs +++ b/users/grfn/achilles/src/ast/mod.rs @@ -128,8 +128,23 @@ impl<'a> Literal<'a> { } #[derive(Debug, PartialEq, Eq, Clone)] +pub enum Pattern<'a> { + Id(Ident<'a>), + Tuple(Vec<Pattern<'a>>), +} + +impl<'a> Pattern<'a> { + pub fn to_owned(&self) -> Pattern<'static> { + match self { + Pattern::Id(id) => Pattern::Id(id.to_owned()), + Pattern::Tuple(pats) => Pattern::Tuple(pats.iter().map(Pattern::to_owned).collect()), + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] pub struct Binding<'a> { - pub ident: Ident<'a>, + pub pat: Pattern<'a>, pub type_: Option<Type<'a>>, pub body: Expr<'a>, } @@ -137,7 +152,7 @@ pub struct Binding<'a> { impl<'a> Binding<'a> { fn to_owned(&self) -> Binding<'static> { Binding { - ident: self.ident.to_owned(), + pat: self.pat.to_owned(), type_: self.type_.as_ref().map(|t| t.to_owned()), body: self.body.to_owned(), } @@ -179,6 +194,8 @@ pub enum Expr<'a> { args: Vec<Expr<'a>>, }, + Tuple(Vec<Expr<'a>>), + Ascription { expr: Box<Expr<'a>>, type_: Type<'a>, @@ -190,6 +207,9 @@ impl<'a> Expr<'a> { match self { Expr::Ident(ref id) => Expr::Ident(id.to_owned()), Expr::Literal(ref lit) => Expr::Literal(lit.to_owned()), + Expr::Tuple(ref members) => { + Expr::Tuple(members.into_iter().map(Expr::to_owned).collect()) + } Expr::UnaryOp { op, rhs } => Expr::UnaryOp { op: *op, rhs: Box::new((**rhs).to_owned()), @@ -312,6 +332,7 @@ pub enum Type<'a> { Bool, CString, Unit, + Tuple(Vec<Type<'a>>), Var(Ident<'a>), Function(FunctionType<'a>), } @@ -326,6 +347,7 @@ impl<'a> Type<'a> { Type::Unit => Type::Unit, Type::Var(v) => Type::Var(v.to_owned()), Type::Function(f) => Type::Function(f.to_owned()), + Type::Tuple(members) => Type::Tuple(members.iter().map(Type::to_owned).collect()), } } @@ -379,9 +401,23 @@ impl<'a> Type<'a> { Type::Float => Type::Float, Type::Bool => Type::Bool, Type::CString => Type::CString, + Type::Tuple(members) => Type::Tuple( + members + .into_iter() + .map(|t| t.traverse_type_vars(f.clone())) + .collect(), + ), Type::Unit => Type::Unit, } } + + pub fn as_tuple(&self) -> Option<&Vec<Type<'a>>> { + if let Self::Tuple(v) = self { + Some(v) + } else { + None + } + } } impl<'a> Display for Type<'a> { @@ -394,6 +430,7 @@ impl<'a> Display for Type<'a> { Type::Unit => f.write_str("()"), Type::Var(v) => v.fmt(f), Type::Function(ft) => ft.fmt(f), + Type::Tuple(ms) => write!(f, "({})", ms.iter().join(", ")), } } } diff --git a/users/grfn/achilles/src/codegen/llvm.rs b/users/grfn/achilles/src/codegen/llvm.rs index 17dec58b5ff7..9a71ac954e00 100644 --- a/users/grfn/achilles/src/codegen/llvm.rs +++ b/users/grfn/achilles/src/codegen/llvm.rs @@ -7,12 +7,13 @@ use inkwell::builder::Builder; pub use inkwell::context::Context; use inkwell::module::Module; use inkwell::support::LLVMString; -use inkwell::types::{BasicType, BasicTypeEnum, FunctionType, IntType}; -use inkwell::values::{AnyValueEnum, BasicValueEnum, FunctionValue}; +use inkwell::types::{BasicType, BasicTypeEnum, FunctionType, IntType, StructType}; +use inkwell::values::{AnyValueEnum, BasicValueEnum, FunctionValue, StructValue}; use inkwell::{AddressSpace, IntPredicate}; +use itertools::Itertools; use thiserror::Error; -use crate::ast::hir::{Binding, Decl, Expr}; +use crate::ast::hir::{Binding, Decl, Expr, Pattern}; use crate::ast::{BinaryOperator, Ident, Literal, Type, UnaryOperator}; use crate::common::env::Env; @@ -82,6 +83,25 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { .append_basic_block(*self.function_stack.last().unwrap(), name) } + fn bind_pattern(&mut self, pat: &'ast Pattern<'ast, Type>, val: AnyValueEnum<'ctx>) { + match pat { + Pattern::Id(id, _) => self.env.set(id, val), + Pattern::Tuple(pats) => { + for (i, pat) in pats.iter().enumerate() { + let member = self + .builder + .build_extract_value( + StructValue::try_from(val).unwrap(), + i as _, + "pat_bind", + ) + .unwrap(); + self.bind_pattern(pat, member.into()); + } + } + } + } + pub fn codegen_expr( &mut self, expr: &'ast Expr<'ast, Type>, @@ -164,9 +184,9 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { } Expr::Let { bindings, body, .. } => { self.env.push(); - for Binding { ident, body, .. } in bindings { + for Binding { pat, body, .. } in bindings { if let Some(val) = self.codegen_expr(body)? { - self.env.set(ident, val); + self.bind_pattern(pat, val); } } let res = self.codegen_expr(body); @@ -244,6 +264,19 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { self.env.restore(env); Ok(Some(function.into())) } + Expr::Tuple(members, ty) => { + let values = members + .into_iter() + .map(|expr| self.codegen_expr(expr)) + .collect::<Result<Vec<_>>>()? + .into_iter() + .filter_map(|x| x) + .map(|x| x.try_into().unwrap()) + .collect_vec(); + let field_types = ty.as_tuple().unwrap(); + let tuple_type = self.codegen_tuple_type(field_types); + Ok(Some(tuple_type.const_named_struct(&values).into())) + } } } @@ -341,9 +374,20 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { Type::Function(_) => todo!(), Type::Var(_) => unreachable!(), Type::Unit => None, + Type::Tuple(ts) => Some(self.codegen_tuple_type(ts).into()), } } + fn codegen_tuple_type(&self, ts: &'ast [Type]) -> StructType<'ctx> { + self.context.struct_type( + ts.iter() + .filter_map(|t| self.codegen_type(t)) + .collect_vec() + .as_slice(), + false, + ) + } + fn codegen_int_type(&self, type_: &'ast Type) -> IntType<'ctx> { // TODO self.context.i64_type() @@ -433,4 +477,10 @@ mod tests { let res = jit_eval::<i64>("let id = fn x = x in id 1").unwrap(); assert_eq!(res, 1); } + + #[test] + fn bind_tuple_pattern() { + let res = jit_eval::<i64>("let (x, y) = (1, 2) in x + y").unwrap(); + assert_eq!(res, 3); + } } diff --git a/users/grfn/achilles/src/interpreter/mod.rs b/users/grfn/achilles/src/interpreter/mod.rs index a8ba2dd3acdc..70df7a0724a5 100644 --- a/users/grfn/achilles/src/interpreter/mod.rs +++ b/users/grfn/achilles/src/interpreter/mod.rs @@ -1,9 +1,12 @@ mod error; mod value; +use itertools::Itertools; +use value::Val; + pub use self::error::{Error, Result}; pub use self::value::{Function, Value}; -use crate::ast::hir::{Binding, Expr}; +use crate::ast::hir::{Binding, Expr, Pattern}; use crate::ast::{BinaryOperator, FunctionType, Ident, Literal, Type, UnaryOperator}; use crate::common::env::Env; @@ -24,6 +27,17 @@ impl<'a> Interpreter<'a> { .ok_or_else(|| Error::UndefinedVariable(var.to_owned())) } + fn bind_pattern(&mut self, pattern: &'a Pattern<'a, Type>, value: Value<'a>) { + match pattern { + Pattern::Id(id, _) => self.env.set(id, value), + Pattern::Tuple(pats) => { + for (pat, val) in pats.iter().zip(value.as_tuple().unwrap().clone()) { + self.bind_pattern(pat, val); + } + } + } + } + pub fn eval(&mut self, expr: &'a Expr<'a, Type>) -> Result<Value<'a>> { let res = match expr { Expr::Ident(id, _) => self.resolve(id), @@ -53,9 +67,9 @@ impl<'a> Interpreter<'a> { } Expr::Let { bindings, body, .. } => { self.env.push(); - for Binding { ident, body, .. } in bindings { + for Binding { pat, body, .. } in bindings { let val = self.eval(body)?; - self.env.set(ident, val); + self.bind_pattern(pat, val); } let res = self.eval(body)?; self.env.pop(); @@ -115,6 +129,13 @@ impl<'a> Interpreter<'a> { body: (**body).to_owned(), })) } + Expr::Tuple(members, _) => Ok(Val::Tuple( + members + .into_iter() + .map(|expr| self.eval(expr)) + .try_collect()?, + ) + .into()), }?; debug_assert_eq!(&res.type_(), expr.type_()); Ok(res) diff --git a/users/grfn/achilles/src/interpreter/value.rs b/users/grfn/achilles/src/interpreter/value.rs index 55ba42f9de58..272d1167a33c 100644 --- a/users/grfn/achilles/src/interpreter/value.rs +++ b/users/grfn/achilles/src/interpreter/value.rs @@ -6,6 +6,7 @@ use std::rc::Rc; use std::result; use derive_more::{Deref, From, TryInto}; +use itertools::Itertools; use super::{Error, Result}; use crate::ast::hir::Expr; @@ -25,6 +26,7 @@ pub enum Val<'a> { Float(f64), Bool(bool), String(Cow<'a, str>), + Tuple(Vec<Value<'a>>), Function(Function<'a>), } @@ -49,6 +51,7 @@ impl<'a> fmt::Debug for Val<'a> { Val::Function(Function { type_, .. }) => { f.debug_struct("Function").field("type_", type_).finish() } + Val::Tuple(members) => f.debug_tuple("Tuple").field(members).finish(), } } } @@ -79,6 +82,7 @@ impl<'a> Display for Val<'a> { Val::Bool(x) => x.fmt(f), Val::String(s) => write!(f, "{:?}", s), Val::Function(Function { type_, .. }) => write!(f, "<{}>", type_), + Val::Tuple(members) => write!(f, "({})", members.iter().join(", ")), } } } @@ -91,6 +95,7 @@ impl<'a> Val<'a> { Val::Bool(_) => Type::Bool, Val::String(_) => Type::CString, Val::Function(Function { type_, .. }) => Type::Function(type_.clone()), + Val::Tuple(members) => Type::Tuple(members.iter().map(|expr| expr.type_()).collect()), } } @@ -114,6 +119,22 @@ impl<'a> Val<'a> { }), } } + + pub fn as_tuple(&self) -> Option<&Vec<Value<'a>>> { + if let Self::Tuple(v) = self { + Some(v) + } else { + None + } + } + + pub fn try_into_tuple(self) -> result::Result<Vec<Value<'a>>, Self> { + if let Self::Tuple(v) = self { + Ok(v) + } else { + Err(self) + } + } } #[derive(Debug, PartialEq, Clone, Deref)] diff --git a/users/grfn/achilles/src/parser/expr.rs b/users/grfn/achilles/src/parser/expr.rs index 8a28d00984c9..f596b18970aa 100644 --- a/users/grfn/achilles/src/parser/expr.rs +++ b/users/grfn/achilles/src/parser/expr.rs @@ -8,7 +8,8 @@ use nom::{ }; use pratt::{Affix, Associativity, PrattParser, Precedence}; -use crate::ast::{BinaryOperator, Binding, Expr, Fun, Literal, UnaryOperator}; +use super::util::comma; +use crate::ast::{BinaryOperator, Binding, Expr, Fun, Literal, Pattern, UnaryOperator}; use crate::parser::{arg, ident, type_}; #[derive(Debug)] @@ -192,9 +193,45 @@ named!(literal(&str) -> Literal, alt!(int | bool_ | string | unit)); named!(literal_expr(&str) -> Expr, map!(literal, Expr::Literal)); +named!(tuple(&str) -> Expr, do_parse!( + complete!(tag!("(")) + >> multispace0 + >> fst: expr + >> comma + >> rest: separated_list0!( + comma, + expr + ) + >> multispace0 + >> tag!(")") + >> ({ + let mut members = Vec::with_capacity(rest.len() + 1); + members.push(fst); + members.append(&mut rest.clone()); + Expr::Tuple(members) + }) +)); + +named!(tuple_pattern(&str) -> Pattern, do_parse!( + complete!(tag!("(")) + >> multispace0 + >> pats: separated_list0!( + comma, + pattern + ) + >> multispace0 + >> tag!(")") + >> (Pattern::Tuple(pats)) +)); + +named!(pattern(&str) -> Pattern, alt!( + ident => { |id| Pattern::Id(id) } | + tuple_pattern +)); + named!(binding(&str) -> Binding, do_parse!( multispace0 - >> ident: ident + >> pat: pattern >> multispace0 >> type_: opt!(preceded!(tuple!(tag!(":"), multispace0), type_)) >> multispace0 @@ -202,7 +239,7 @@ named!(binding(&str) -> Binding, do_parse!( >> multispace0 >> body: expr >> (Binding { - ident, + pat, type_, body }) @@ -267,6 +304,7 @@ named!(paren_expr(&str) -> Expr, named!(funcref(&str) -> Expr, alt!( ident_expr | + tuple | paren_expr )); @@ -296,6 +334,7 @@ named!(fun_expr(&str) -> Expr, do_parse!( named!(fn_arg(&str) -> Expr, alt!( ident_expr | literal_expr | + tuple | paren_expr )); @@ -314,7 +353,8 @@ named!(simple_expr_unascripted(&str) -> Expr, alt!( if_ | fun_expr | literal_expr | - ident_expr + ident_expr | + tuple )); named!(simple_expr(&str) -> Expr, alt!( @@ -334,7 +374,7 @@ named!(pub expr(&str) -> Expr, alt!( #[cfg(test)] pub(crate) mod tests { use super::*; - use crate::ast::{Arg, Ident, Type}; + use crate::ast::{Arg, Ident, Pattern, Type}; use std::convert::TryFrom; use BinaryOperator::*; use Expr::{BinaryOp, If, Let, UnaryOp}; @@ -450,6 +490,17 @@ pub(crate) mod tests { } #[test] + fn tuple() { + assert_eq!( + test_parse!(expr, "(1, \"seven\")"), + Expr::Tuple(vec![ + Expr::Literal(Literal::Int(1)), + Expr::Literal(Literal::String(Cow::Borrowed("seven"))) + ]) + ) + } + + #[test] fn simple_string_lit() { assert_eq!( test_parse!(expr, "\"foobar\""), @@ -465,12 +516,12 @@ pub(crate) mod tests { Let { bindings: vec![ Binding { - ident: Ident::try_from("x").unwrap(), + pat: Pattern::Id(Ident::try_from("x").unwrap()), type_: None, body: Expr::Literal(Literal::Int(1)) }, Binding { - ident: Ident::try_from("y").unwrap(), + pat: Pattern::Id(Ident::try_from("y").unwrap()), type_: None, body: Expr::BinaryOp { lhs: ident_expr("x"), @@ -553,7 +604,7 @@ pub(crate) mod tests { Expr::Call { fun: Box::new(Expr::Let { bindings: vec![Binding { - ident: Ident::try_from("x").unwrap(), + pat: Pattern::Id(Ident::try_from("x").unwrap()), type_: None, body: Expr::Literal(Literal::Int(1)) }], @@ -571,7 +622,7 @@ pub(crate) mod tests { res, Expr::Let { bindings: vec![Binding { - ident: Ident::try_from("id").unwrap(), + pat: Pattern::Id(Ident::try_from("id").unwrap()), type_: None, body: Expr::Fun(Box::new(Fun { args: vec![Arg::try_from("x").unwrap()], @@ -586,6 +637,28 @@ pub(crate) mod tests { ); } + #[test] + fn tuple_binding() { + let res = test_parse!(expr, "let (x, y) = (1, 2) in x"); + assert_eq!( + res, + Expr::Let { + bindings: vec![Binding { + pat: Pattern::Tuple(vec![ + Pattern::Id(Ident::from_str_unchecked("x")), + Pattern::Id(Ident::from_str_unchecked("y")) + ]), + body: Expr::Tuple(vec![ + Expr::Literal(Literal::Int(1)), + Expr::Literal(Literal::Int(2)) + ]), + type_: None + }], + body: Box::new(Expr::Ident(Ident::from_str_unchecked("x"))) + } + ) + } + mod ascriptions { use super::*; @@ -608,7 +681,7 @@ pub(crate) mod tests { res, Expr::Let { bindings: vec![Binding { - ident: Ident::try_from("const_1").unwrap(), + pat: Pattern::Id(Ident::try_from("const_1").unwrap()), type_: None, body: Expr::Fun(Box::new(Fun { args: vec![Arg::try_from("x").unwrap()], @@ -633,7 +706,7 @@ pub(crate) mod tests { res, Expr::Let { bindings: vec![Binding { - ident: Ident::try_from("x").unwrap(), + pat: Pattern::Id(Ident::try_from("x").unwrap()), type_: Some(Type::Int), body: Expr::Literal(Literal::Int(1)) }], diff --git a/users/grfn/achilles/src/parser/mod.rs b/users/grfn/achilles/src/parser/mod.rs index 3e0081bd391d..e088cbca10a5 100644 --- a/users/grfn/achilles/src/parser/mod.rs +++ b/users/grfn/achilles/src/parser/mod.rs @@ -6,6 +6,7 @@ use nom::{alt, char, complete, do_parse, eof, many0, named, separated_list0, tag pub(crate) mod macros; mod expr; mod type_; +mod util; use crate::ast::{Arg, Decl, Fun, Ident}; pub use expr::expr; diff --git a/users/grfn/achilles/src/parser/type_.rs b/users/grfn/achilles/src/parser/type_.rs index 8a1081e2521f..b80f0e0860a1 100644 --- a/users/grfn/achilles/src/parser/type_.rs +++ b/users/grfn/achilles/src/parser/type_.rs @@ -2,17 +2,14 @@ use nom::character::complete::{multispace0, multispace1}; use nom::{alt, delimited, do_parse, map, named, opt, separated_list0, tag, terminated, tuple}; use super::ident; +use super::util::comma; use crate::ast::{FunctionType, Type}; named!(pub function_type(&str) -> FunctionType, do_parse!( tag!("fn") >> multispace1 >> args: map!(opt!(terminated!(separated_list0!( - tuple!( - multispace0, - tag!(","), - multispace0 - ), + comma, type_ ), multispace1)), |args| args.unwrap_or_default()) >> tag!("->") @@ -24,12 +21,32 @@ named!(pub function_type(&str) -> FunctionType, do_parse!( }) )); +named!(tuple_type(&str) -> Type, do_parse!( + tag!("(") + >> multispace0 + >> fst: type_ + >> comma + >> rest: separated_list0!( + comma, + type_ + ) + >> multispace0 + >> tag!(")") + >> ({ + let mut members = Vec::with_capacity(rest.len() + 1); + members.push(fst); + members.append(&mut rest.clone()); + Type::Tuple(members) + }) +)); + named!(pub type_(&str) -> Type, alt!( tag!("int") => { |_| Type::Int } | tag!("float") => { |_| Type::Float } | tag!("bool") => { |_| Type::Bool } | tag!("cstring") => { |_| Type::CString } | tag!("()") => { |_| Type::Unit } | + tuple_type | function_type => { |ft| Type::Function(ft) }| ident => { |id| Type::Var(id) } | delimited!( @@ -112,6 +129,14 @@ mod tests { } #[test] + fn tuple() { + assert_eq!( + test_parse!(type_, "(int, int)"), + Type::Tuple(vec![Type::Int, Type::Int]) + ) + } + + #[test] fn type_vars() { assert_eq!( test_parse!(type_, "fn x, y -> x"), diff --git a/users/grfn/achilles/src/parser/util.rs b/users/grfn/achilles/src/parser/util.rs new file mode 100644 index 000000000000..bb53fb7fff50 --- /dev/null +++ b/users/grfn/achilles/src/parser/util.rs @@ -0,0 +1,8 @@ +use nom::character::complete::multispace0; +use nom::{complete, map, named, tag, tuple}; + +named!(pub(crate) comma(&str) -> (), map!(tuple!( + multispace0, + complete!(tag!(",")), + multispace0 +) ,|_| ())); diff --git a/users/grfn/achilles/src/passes/hir/mod.rs b/users/grfn/achilles/src/passes/hir/mod.rs index 845bfcb7ab6a..872c449eb020 100644 --- a/users/grfn/achilles/src/passes/hir/mod.rs +++ b/users/grfn/achilles/src/passes/hir/mod.rs @@ -1,6 +1,6 @@ use std::collections::HashMap; -use crate::ast::hir::{Binding, Decl, Expr}; +use crate::ast::hir::{Binding, Decl, Expr, Pattern}; use crate::ast::{BinaryOperator, Ident, Literal, UnaryOperator}; pub(crate) mod monomorphize; @@ -29,9 +29,12 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a { Ok(()) } + fn visit_pattern(&mut self, _pat: &mut Pattern<'ast, T>) -> Result<(), Self::Error> { + Ok(()) + } + fn visit_binding(&mut self, binding: &mut Binding<'ast, T>) -> Result<(), Self::Error> { - self.visit_ident(&mut binding.ident)?; - self.visit_type(&mut binding.type_)?; + self.visit_pattern(&mut binding.pat)?; self.visit_expr(&mut binding.body)?; Ok(()) } @@ -54,6 +57,13 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a { Ok(()) } + fn visit_tuple(&mut self, members: &mut Vec<Expr<'ast, T>>) -> Result<(), Self::Error> { + for expr in members { + self.visit_expr(expr)?; + } + Ok(()) + } + fn pre_visit_expr(&mut self, _expr: &mut Expr<'ast, T>) -> Result<(), Self::Error> { Ok(()) } @@ -137,12 +147,16 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a { self.visit_type(type_)?; self.post_visit_call(fun, type_args, args)?; } + Expr::Tuple(tup, type_) => { + self.visit_tuple(tup)?; + self.visit_type(type_)?; + } } Ok(()) } - fn post_visit_decl(&mut self, decl: &'a Decl<'ast, T>) -> Result<(), Self::Error> { + fn post_visit_decl(&mut self, _decl: &'a Decl<'ast, T>) -> Result<(), Self::Error> { Ok(()) } diff --git a/users/grfn/achilles/src/passes/hir/strip_positive_units.rs b/users/grfn/achilles/src/passes/hir/strip_positive_units.rs index 91b56551c82d..85ee1cce4859 100644 --- a/users/grfn/achilles/src/passes/hir/strip_positive_units.rs +++ b/users/grfn/achilles/src/passes/hir/strip_positive_units.rs @@ -1,7 +1,7 @@ use std::collections::HashMap; use std::mem; -use ast::hir::Binding; +use ast::hir::{Binding, Pattern}; use ast::Literal; use void::{ResultVoidExt, Void}; @@ -42,8 +42,10 @@ impl<'a, 'ast> Visitor<'a, 'ast, ast::Type<'ast>> for StripPositiveUnits { bindings: extracted .into_iter() .map(|expr| Binding { - ident: Ident::from_str_unchecked("___discarded"), - type_: expr.type_().clone(), + pat: Pattern::Id( + Ident::from_str_unchecked("___discarded"), + expr.type_().clone(), + ), body: expr, }) .collect(), diff --git a/users/grfn/achilles/src/tc/mod.rs b/users/grfn/achilles/src/tc/mod.rs index d27c45075e97..5825bab1fbe9 100644 --- a/users/grfn/achilles/src/tc/mod.rs +++ b/users/grfn/achilles/src/tc/mod.rs @@ -8,7 +8,7 @@ use std::fmt::{self, Display}; use std::{mem, result}; use thiserror::Error; -use crate::ast::{self, hir, Arg, BinaryOperator, Ident, Literal}; +use crate::ast::{self, hir, Arg, BinaryOperator, Ident, Literal, Pattern}; use crate::common::env::Env; use crate::common::{Namer, NamerOf}; @@ -85,6 +85,7 @@ pub enum Type { Exist(TyVar), Nullary(NullaryType), Prim(PrimType), + Tuple(Vec<Type>), Unit, Fun { args: Vec<Type>, @@ -102,6 +103,9 @@ impl<'a> TryFrom<Type> for ast::Type<'a> { Type::Exist(_) => Err(value), Type::Nullary(_) => todo!(), Type::Prim(p) => Ok(p.into()), + Type::Tuple(members) => Ok(ast::Type::Tuple( + members.into_iter().map(|ty| ty.try_into()).try_collect()?, + )), Type::Fun { ref args, ref ret } => Ok(ast::Type::Function(ast::FunctionType { args: args .clone() @@ -128,6 +132,7 @@ impl Display for Type { Type::Univ(TyVar(n)) => write!(f, "∀{}", n), Type::Exist(TyVar(n)) => write!(f, "∃{}", n), Type::Fun { args, ret } => write!(f, "fn {} -> {}", args.iter().join(", "), ret), + Type::Tuple(members) => write!(f, "({})", members.iter().join(", ")), Type::Unit => write!(f, "()"), } } @@ -159,6 +164,31 @@ impl<'ast> Typechecker<'ast> { } } + fn bind_pattern( + &mut self, + pat: Pattern<'ast>, + type_: Type, + ) -> Result<hir::Pattern<'ast, Type>> { + match pat { + Pattern::Id(ident) => { + self.env.set(ident.clone(), type_.clone()); + Ok(hir::Pattern::Id(ident, type_)) + } + Pattern::Tuple(members) => { + let mut tys = Vec::with_capacity(members.len()); + let mut hir_members = Vec::with_capacity(members.len()); + for pat in members { + let ty = self.fresh_ex(); + hir_members.push(self.bind_pattern(pat, ty.clone())?); + tys.push(ty); + } + let tuple_type = Type::Tuple(tys); + self.unify(&tuple_type, &type_)?; + Ok(hir::Pattern::Tuple(hir_members)) + } + } + } + pub(crate) fn tc_expr(&mut self, expr: ast::Expr<'ast>) -> Result<hir::Expr<'ast, Type>> { match expr { ast::Expr::Ident(ident) => { @@ -178,6 +208,14 @@ impl<'ast> Typechecker<'ast> { }; Ok(hir::Expr::Literal(lit.to_owned(), type_)) } + ast::Expr::Tuple(members) => { + let members = members + .into_iter() + .map(|expr| self.tc_expr(expr)) + .collect::<Result<Vec<_>>>()?; + let type_ = Type::Tuple(members.iter().map(|expr| expr.type_().clone()).collect()); + Ok(hir::Expr::Tuple(members, type_)) + } ast::Expr::UnaryOp { op, rhs } => todo!(), ast::Expr::BinaryOp { lhs, op, rhs } => { let lhs = self.tc_expr(*lhs)?; @@ -209,18 +247,14 @@ impl<'ast> Typechecker<'ast> { let bindings = bindings .into_iter() .map( - |ast::Binding { ident, type_, body }| -> Result<hir::Binding<Type>> { + |ast::Binding { pat, type_, body }| -> Result<hir::Binding<Type>> { let body = self.tc_expr(body)?; if let Some(type_) = type_ { let type_ = self.type_from_ast_type(type_); self.unify(body.type_(), &type_)?; } - self.env.set(ident.clone(), body.type_().clone()); - Ok(hir::Binding { - ident, - type_: body.type_().clone(), - body, - }) + let pat = self.bind_pattern(pat, body.type_().clone())?; + Ok(hir::Binding { pat, body }) }, ) .collect::<Result<Vec<hir::Binding<Type>>>>()?; @@ -382,7 +416,7 @@ impl<'ast> Typechecker<'ast> { fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> { match (ty1, ty2) { (Type::Unit, Type::Unit) => Ok(Type::Unit), - (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) { + (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv)? { Some(existing_ty) if self.types_match(ty, &existing_ty) => Ok(ty.clone()), Some(var @ ast::Type::Var(_)) => { let var = self.type_from_ast_type(var); @@ -419,6 +453,14 @@ impl<'ast> Typechecker<'ast> { } } (Type::Prim(p1), Type::Prim(p2)) if p1 == p2 => Ok(ty2.clone()), + (Type::Tuple(t1), Type::Tuple(t2)) if t1.len() == t2.len() => { + let ts = t1 + .iter() + .zip(t2.iter()) + .map(|(t1, t2)| self.unify(t1, t2)) + .try_collect()?; + Ok(Type::Tuple(ts)) + } ( Type::Fun { args: args1, @@ -469,11 +511,17 @@ impl<'ast> Typechecker<'ast> { fn finalize_type(&self, ty: Type) -> Result<ast::Type<'static>> { let ret = match ty { - Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)), + Type::Exist(tv) => self.resolve_tv(tv)?.ok_or(Error::AmbiguousType(tv)), Type::Univ(tv) => Ok(ast::Type::Var(self.name_univ(tv))), Type::Unit => Ok(ast::Type::Unit), Type::Nullary(_) => todo!(), Type::Prim(pr) => Ok(pr.into()), + Type::Tuple(members) => Ok(ast::Type::Tuple( + members + .into_iter() + .map(|ty| self.finalize_type(ty)) + .try_collect()?, + )), Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType { args: args .into_iter() @@ -485,12 +533,15 @@ impl<'ast> Typechecker<'ast> { ret } - fn resolve_tv(&self, tv: TyVar) -> Option<ast::Type<'static>> { + fn resolve_tv(&self, tv: TyVar) -> Result<Option<ast::Type<'static>>> { let mut res = &Type::Exist(tv); - loop { + Ok(loop { match res { Type::Exist(tv) => { - res = self.ctx.get(tv)?; + res = match self.ctx.get(tv) { + Some(r) => r, + None => return Ok(None), + }; } Type::Univ(tv) => { let ident = self.name_univ(*tv); @@ -504,8 +555,9 @@ impl<'ast> Typechecker<'ast> { Type::Prim(pr) => break Some((*pr).into()), Type::Unit => break Some(ast::Type::Unit), Type::Fun { args, ret } => todo!(), + Type::Tuple(_) => break Some(self.finalize_type(res.clone())?), } - } + }) } fn type_from_ast_type(&mut self, ast_type: ast::Type<'ast>) -> Type { @@ -515,6 +567,12 @@ impl<'ast> Typechecker<'ast> { ast::Type::Float => FLOAT, ast::Type::Bool => BOOL, ast::Type::CString => CSTRING, + ast::Type::Tuple(members) => Type::Tuple( + members + .into_iter() + .map(|ty| self.type_from_ast_type(ty)) + .collect(), + ), ast::Type::Function(ast::FunctionType { args, ret }) => Type::Fun { args: args .into_iter() @@ -582,6 +640,11 @@ impl<'ast> Typechecker<'ast> { (Type::Unit, _) => false, (Type::Nullary(_), _) => todo!(), (Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty, + (Type::Tuple(members), ast::Type::Tuple(members2)) => members + .iter() + .zip(members2.iter()) + .all(|(t1, t2)| self.types_match(t1, t2)), + (Type::Tuple(members), _) => false, (Type::Fun { args, ret }, ast::Type::Function(ft)) => { args.len() == ft.args.len() && args |