From 8e13b1303a0d152c2f3b68f2421163e94fdf226c Mon Sep 17 00:00:00 2001 From: Griffin Smith Date: Sun, 28 Mar 2021 13:28:49 -0400 Subject: feat(achilles): Implement a Unit type Add support for a zero-sized Unit type. This requires some special at the codegen level because LLVM (unsurprisingly) only allows Void types in function return position - to make that a little easier to handle there's a new pass that strips any unit-only expressions and pulls unit-only function arguments up to new `let` bindings, so we never have to actually pass around unit values. Change-Id: I0fc18a516821f2d69172c42a6a5d246b23471e38 Reviewed-on: https://cl.tvl.fyi/c/depot/+/2695 Reviewed-by: glittershark Tested-by: BuildkiteCI --- users/glittershark/achilles/src/ast/hir.rs | 2 +- users/glittershark/achilles/src/ast/mod.rs | 7 + users/glittershark/achilles/src/codegen/llvm.rs | 187 +++++++++++--------- users/glittershark/achilles/src/compiler.rs | 5 +- users/glittershark/achilles/src/interpreter/mod.rs | 1 + users/glittershark/achilles/src/parser/expr.rs | 22 ++- users/glittershark/achilles/src/parser/mod.rs | 27 ++- users/glittershark/achilles/src/parser/type_.rs | 2 + users/glittershark/achilles/src/passes/hir/mod.rs | 22 ++- .../src/passes/hir/strip_positive_units.rs | 189 +++++++++++++++++++++ users/glittershark/achilles/src/tc/mod.rs | 10 ++ 11 files changed, 386 insertions(+), 88 deletions(-) create mode 100644 users/glittershark/achilles/src/passes/hir/strip_positive_units.rs (limited to 'users/glittershark/achilles/src') diff --git a/users/glittershark/achilles/src/ast/hir.rs b/users/glittershark/achilles/src/ast/hir.rs index 8726af509388..0d145d620bef 100644 --- a/users/glittershark/achilles/src/ast/hir.rs +++ b/users/glittershark/achilles/src/ast/hir.rs @@ -246,7 +246,7 @@ impl<'a, T> Expr<'a, T> { } } -#[derive(Debug, Clone)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum Decl<'a, T> { Fun { name: Ident<'a>, diff --git a/users/glittershark/achilles/src/ast/mod.rs b/users/glittershark/achilles/src/ast/mod.rs index 53f222a6a11a..7dc2de895709 100644 --- a/users/glittershark/achilles/src/ast/mod.rs +++ b/users/glittershark/achilles/src/ast/mod.rs @@ -30,6 +30,7 @@ impl<'a> Ident<'a> { Ident(Cow::Owned(self.0.clone().into_owned())) } + /// Construct an identifier from a &str without checking that it's a valid identifier pub fn from_str_unchecked(s: &'a str) -> Self { debug_assert!(is_valid_identifier(s)); Self(Cow::Borrowed(s)) @@ -109,6 +110,7 @@ pub enum UnaryOperator { #[derive(Debug, PartialEq, Eq, Clone)] pub enum Literal<'a> { + Unit, Int(u64), Bool(bool), String(Cow<'a, str>), @@ -120,6 +122,7 @@ impl<'a> Literal<'a> { Literal::Int(i) => Literal::Int(*i), Literal::Bool(b) => Literal::Bool(*b), Literal::String(s) => Literal::String(Cow::Owned(s.clone().into_owned())), + Literal::Unit => Literal::Unit, } } } @@ -308,6 +311,7 @@ pub enum Type<'a> { Float, Bool, CString, + Unit, Var(Ident<'a>), Function(FunctionType<'a>), } @@ -319,6 +323,7 @@ impl<'a> Type<'a> { Type::Float => Type::Float, Type::Bool => Type::Bool, Type::CString => Type::CString, + Type::Unit => Type::Unit, Type::Var(v) => Type::Var(v.to_owned()), Type::Function(f) => Type::Function(f.to_owned()), } @@ -374,6 +379,7 @@ impl<'a> Type<'a> { Type::Float => Type::Float, Type::Bool => Type::Bool, Type::CString => Type::CString, + Type::Unit => Type::Unit, } } } @@ -385,6 +391,7 @@ impl<'a> Display for Type<'a> { Type::Float => f.write_str("float"), Type::Bool => f.write_str("bool"), Type::CString => f.write_str("cstring"), + Type::Unit => f.write_str("()"), Type::Var(v) => v.fmt(f), Type::Function(ft) => ft.fmt(f), } diff --git a/users/glittershark/achilles/src/codegen/llvm.rs b/users/glittershark/achilles/src/codegen/llvm.rs index f49e084a8174..17dec58b5ff7 100644 --- a/users/glittershark/achilles/src/codegen/llvm.rs +++ b/users/glittershark/achilles/src/codegen/llvm.rs @@ -68,8 +68,12 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { self.function_stack.last().unwrap() } - pub fn finish_function(&mut self, res: &BasicValueEnum<'ctx>) -> FunctionValue<'ctx> { - self.builder.build_return(Some(res)); + pub fn finish_function(&mut self, res: Option<&BasicValueEnum<'ctx>>) -> FunctionValue<'ctx> { + self.builder.build_return(match res { + // lol + Some(val) => Some(val), + None => None, + }); self.function_stack.pop().unwrap() } @@ -78,79 +82,92 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { .append_basic_block(*self.function_stack.last().unwrap(), name) } - pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast, Type>) -> Result> { + pub fn codegen_expr( + &mut self, + expr: &'ast Expr<'ast, Type>, + ) -> Result>> { match expr { Expr::Ident(id, _) => self .env .resolve(id) .cloned() - .ok_or_else(|| Error::UndefinedVariable(id.to_owned())), + .ok_or_else(|| Error::UndefinedVariable(id.to_owned())) + .map(Some), Expr::Literal(lit, ty) => { let ty = self.codegen_int_type(ty); match lit { - Literal::Int(i) => Ok(AnyValueEnum::IntValue(ty.const_int(*i, false))), - Literal::Bool(b) => Ok(AnyValueEnum::IntValue( + Literal::Int(i) => Ok(Some(AnyValueEnum::IntValue(ty.const_int(*i, false)))), + Literal::Bool(b) => Ok(Some(AnyValueEnum::IntValue( ty.const_int(if *b { 1 } else { 0 }, false), + ))), + Literal::String(s) => Ok(Some( + self.builder + .build_global_string_ptr(s, "s") + .as_pointer_value() + .into(), )), - Literal::String(s) => Ok(self - .builder - .build_global_string_ptr(s, "s") - .as_pointer_value() - .into()), + Literal::Unit => Ok(None), } } Expr::UnaryOp { op, rhs, .. } => { - let rhs = self.codegen_expr(rhs)?; + let rhs = self.codegen_expr(rhs)?.unwrap(); match op { UnaryOperator::Not => unimplemented!(), - UnaryOperator::Neg => Ok(AnyValueEnum::IntValue( + UnaryOperator::Neg => Ok(Some(AnyValueEnum::IntValue( self.builder.build_int_neg(rhs.into_int_value(), "neg"), - )), + ))), } } Expr::BinaryOp { lhs, op, rhs, .. } => { - let lhs = self.codegen_expr(lhs)?; - let rhs = self.codegen_expr(rhs)?; + let lhs = self.codegen_expr(lhs)?.unwrap(); + let rhs = self.codegen_expr(rhs)?.unwrap(); match op { - BinaryOperator::Add => Ok(AnyValueEnum::IntValue(self.builder.build_int_add( - lhs.into_int_value(), - rhs.into_int_value(), - "add", - ))), - BinaryOperator::Sub => Ok(AnyValueEnum::IntValue(self.builder.build_int_sub( - lhs.into_int_value(), - rhs.into_int_value(), - "add", - ))), - BinaryOperator::Mul => Ok(AnyValueEnum::IntValue(self.builder.build_int_sub( - lhs.into_int_value(), - rhs.into_int_value(), - "add", - ))), - BinaryOperator::Div => { - Ok(AnyValueEnum::IntValue(self.builder.build_int_signed_div( + BinaryOperator::Add => { + Ok(Some(AnyValueEnum::IntValue(self.builder.build_int_add( lhs.into_int_value(), rhs.into_int_value(), "add", - ))) + )))) } + BinaryOperator::Sub => { + Ok(Some(AnyValueEnum::IntValue(self.builder.build_int_sub( + lhs.into_int_value(), + rhs.into_int_value(), + "add", + )))) + } + BinaryOperator::Mul => { + Ok(Some(AnyValueEnum::IntValue(self.builder.build_int_sub( + lhs.into_int_value(), + rhs.into_int_value(), + "add", + )))) + } + BinaryOperator::Div => Ok(Some(AnyValueEnum::IntValue( + self.builder.build_int_signed_div( + lhs.into_int_value(), + rhs.into_int_value(), + "add", + ), + ))), BinaryOperator::Pow => unimplemented!(), - BinaryOperator::Equ => { - Ok(AnyValueEnum::IntValue(self.builder.build_int_compare( + BinaryOperator::Equ => Ok(Some(AnyValueEnum::IntValue( + self.builder.build_int_compare( IntPredicate::EQ, lhs.into_int_value(), rhs.into_int_value(), "eq", - ))) - } + ), + ))), BinaryOperator::Neq => todo!(), } } Expr::Let { bindings, body, .. } => { self.env.push(); for Binding { ident, body, .. } in bindings { - let val = self.codegen_expr(body)?; - self.env.set(ident, val); + if let Some(val) = self.codegen_expr(body)? { + self.env.set(ident, val); + } } let res = self.codegen_expr(body); self.env.pop(); @@ -165,7 +182,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { let then_block = self.append_basic_block("then"); let else_block = self.append_basic_block("else"); let join_block = self.append_basic_block("join"); - let condition = self.codegen_expr(condition)?; + let condition = self.codegen_expr(condition)?.unwrap(); self.builder.build_conditional_branch( condition.into_int_value(), then_block, @@ -180,12 +197,22 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { self.builder.build_unconditional_branch(join_block); self.builder.position_at_end(join_block); - let phi = self.builder.build_phi(self.codegen_type(type_), "join"); - phi.add_incoming(&[ - (&BasicValueEnum::try_from(then_res).unwrap(), then_block), - (&BasicValueEnum::try_from(else_res).unwrap(), else_block), - ]); - Ok(phi.as_basic_value().into()) + if let Some(phi_type) = self.codegen_type(type_) { + let phi = self.builder.build_phi(phi_type, "join"); + phi.add_incoming(&[ + ( + &BasicValueEnum::try_from(then_res.unwrap()).unwrap(), + then_block, + ), + ( + &BasicValueEnum::try_from(else_res.unwrap()).unwrap(), + else_block, + ), + ]); + Ok(Some(phi.as_basic_value().into())) + } else { + Ok(None) + } } Expr::Call { fun, args, .. } => { if let Expr::Ident(id, _) = &**fun { @@ -196,15 +223,14 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { .ok_or_else(|| Error::UndefinedVariable(id.to_owned()))?; let args = args .iter() - .map(|arg| Ok(self.codegen_expr(arg)?.try_into().unwrap())) + .map(|arg| Ok(self.codegen_expr(arg)?.unwrap().try_into().unwrap())) .collect::>>()?; Ok(self .builder .build_call(function, &args, "call") .try_as_basic_value() .left() - .unwrap() - .into()) + .map(|val| val.into())) } else { todo!() } @@ -216,7 +242,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { let function = self.codegen_function(&fname, args, body)?; self.builder.position_at_end(cur_block); self.env.restore(env); - Ok(function.into()) + Ok(Some(function.into())) } } } @@ -227,15 +253,17 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { args: &'ast [(Ident<'ast>, Type)], body: &'ast Expr<'ast, Type>, ) -> Result> { + let arg_types = args + .iter() + .filter_map(|(_, at)| self.codegen_type(at)) + .collect::>(); + self.new_function( name, - self.codegen_type(body.type_()).fn_type( - args.iter() - .map(|(_, at)| self.codegen_type(at)) - .collect::>() - .as_slice(), - false, - ), + match self.codegen_type(body.type_()) { + Some(ret_ty) => ret_ty.fn_type(&arg_types, false), + None => self.context.void_type().fn_type(&arg_types, false), + }, ); self.env.push(); for (i, (arg, _)) in args.iter().enumerate() { @@ -244,9 +272,9 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { self.cur_function().get_nth_param(i as u32).unwrap().into(), ); } - let res = self.codegen_expr(body)?.try_into().unwrap(); + let res = self.codegen_expr(body)?; self.env.pop(); - Ok(self.finish_function(&res)) + Ok(self.finish_function(res.map(|av| av.try_into().unwrap()).as_ref())) } pub fn codegen_extern( @@ -255,15 +283,16 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { args: &'ast [Type], ret: &'ast Type, ) -> Result<()> { + let arg_types = args + .iter() + .map(|t| self.codegen_type(t).unwrap()) + .collect::>(); self.module.add_function( name, - self.codegen_type(ret).fn_type( - &args - .iter() - .map(|t| self.codegen_type(t)) - .collect::>(), - false, - ), + match self.codegen_type(ret) { + Some(ret_ty) => ret_ty.fn_type(&arg_types, false), + None => self.context.void_type().fn_type(&arg_types, false), + }, None, ); Ok(()) @@ -287,29 +316,31 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { pub fn codegen_main(&mut self, expr: &'ast Expr<'ast, Type>) -> Result<()> { self.new_function("main", self.context.i64_type().fn_type(&[], false)); - let res = self.codegen_expr(expr)?.try_into().unwrap(); + let res = self.codegen_expr(expr)?; if *expr.type_() != Type::Int { self.builder .build_return(Some(&self.context.i64_type().const_int(0, false))); } else { - self.finish_function(&res); + self.finish_function(res.map(|r| r.try_into().unwrap()).as_ref()); } Ok(()) } - fn codegen_type(&self, type_: &'ast Type) -> BasicTypeEnum<'ctx> { + fn codegen_type(&self, type_: &'ast Type) -> Option> { // TODO match type_ { - Type::Int => self.context.i64_type().into(), - Type::Float => self.context.f64_type().into(), - Type::Bool => self.context.bool_type().into(), - Type::CString => self - .context - .i8_type() - .ptr_type(AddressSpace::Generic) - .into(), + Type::Int => Some(self.context.i64_type().into()), + Type::Float => Some(self.context.f64_type().into()), + Type::Bool => Some(self.context.bool_type().into()), + Type::CString => Some( + self.context + .i8_type() + .ptr_type(AddressSpace::Generic) + .into(), + ), Type::Function(_) => todo!(), Type::Var(_) => unreachable!(), + Type::Unit => None, } } diff --git a/users/glittershark/achilles/src/compiler.rs b/users/glittershark/achilles/src/compiler.rs index 7001e5a9a384..45b215473d7f 100644 --- a/users/glittershark/achilles/src/compiler.rs +++ b/users/glittershark/achilles/src/compiler.rs @@ -8,7 +8,7 @@ use test_strategy::Arbitrary; use crate::codegen::{self, Codegen}; use crate::common::Result; -use crate::passes::hir::monomorphize; +use crate::passes::hir::{monomorphize, strip_positive_units}; use crate::{parser, tc}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Arbitrary)] @@ -55,9 +55,10 @@ pub struct CompilerOptions { pub fn compile_file(input: &Path, output: &Path, options: &CompilerOptions) -> Result<()> { let src = fs::read_to_string(input)?; - let (_, decls) = parser::toplevel(&src)?; // TODO: statements + let (_, decls) = parser::toplevel(&src)?; let mut decls = tc::typecheck_toplevel(decls)?; monomorphize::run_toplevel(&mut decls); + strip_positive_units::run_toplevel(&mut decls); let context = codegen::Context::create(); let mut codegen = Codegen::new( diff --git a/users/glittershark/achilles/src/interpreter/mod.rs b/users/glittershark/achilles/src/interpreter/mod.rs index bcd474b3abe1..a8ba2dd3acdc 100644 --- a/users/glittershark/achilles/src/interpreter/mod.rs +++ b/users/glittershark/achilles/src/interpreter/mod.rs @@ -30,6 +30,7 @@ impl<'a> Interpreter<'a> { Expr::Literal(Literal::Int(i), _) => Ok((*i).into()), Expr::Literal(Literal::Bool(b), _) => Ok((*b).into()), Expr::Literal(Literal::String(s), _) => Ok(s.clone().into()), + Expr::Literal(Literal::Unit, _) => unreachable!(), Expr::UnaryOp { op, rhs, .. } => { let rhs = self.eval(rhs)?; match op { diff --git a/users/glittershark/achilles/src/parser/expr.rs b/users/glittershark/achilles/src/parser/expr.rs index 99c8018fd00c..8a28d00984c9 100644 --- a/users/glittershark/achilles/src/parser/expr.rs +++ b/users/glittershark/achilles/src/parser/expr.rs @@ -186,7 +186,9 @@ named!(string(&str) -> Literal, preceded!( ) )); -named!(literal(&str) -> Literal, alt!(int | bool_ | string)); +named!(unit(&str) -> Literal, map!(complete!(tag!("()")), |_| Literal::Unit)); + +named!(literal(&str) -> Literal, alt!(int | bool_ | string | unit)); named!(literal_expr(&str) -> Expr, map!(literal, Expr::Literal)); @@ -270,7 +272,6 @@ named!(funcref(&str) -> Expr, alt!( named!(no_arg_call(&str) -> Expr, do_parse!( fun: funcref - >> multispace0 >> complete!(tag!("()")) >> (Expr::Call { fun: Box::new(fun), @@ -431,6 +432,11 @@ pub(crate) mod tests { } } + #[test] + fn unit() { + assert_eq!(test_parse!(expr, "()"), Expr::Literal(Literal::Unit)); + } + #[test] fn bools() { assert_eq!( @@ -515,6 +521,18 @@ pub(crate) mod tests { ); } + #[test] + fn unit_call() { + let res = test_parse!(expr, "f ()"); + assert_eq!( + res, + Expr::Call { + fun: ident_expr("f"), + args: vec![Expr::Literal(Literal::Unit)] + } + ) + } + #[test] fn call_with_args() { let res = test_parse!(expr, "f x 1"); diff --git a/users/glittershark/achilles/src/parser/mod.rs b/users/glittershark/achilles/src/parser/mod.rs index 652b083fdae5..3e0081bd391d 100644 --- a/users/glittershark/achilles/src/parser/mod.rs +++ b/users/glittershark/achilles/src/parser/mod.rs @@ -1,9 +1,9 @@ use nom::character::complete::{multispace0, multispace1}; use nom::error::{ErrorKind, ParseError}; -use nom::{alt, char, complete, do_parse, many0, named, separated_list0, tag, terminated}; +use nom::{alt, char, complete, do_parse, eof, many0, named, separated_list0, tag, terminated}; #[macro_use] -mod macros; +pub(crate) mod macros; mod expr; mod type_; @@ -136,7 +136,11 @@ named!(pub decl(&str) -> Decl, alt!( extern_decl )); -named!(pub toplevel(&str) -> Vec, terminated!(many0!(decl), multispace0)); +named!(pub toplevel(&str) -> Vec, do_parse!( + decls: many0!(decl) + >> multispace0 + >> eof!() + >> (decls))); #[cfg(test)] mod tests { @@ -215,4 +219,21 @@ mod tests { }] ) } + + #[test] + fn return_unit() { + assert_eq!( + test_parse!(decl, "fn g _ = ()"), + Decl::Fun { + name: "g".try_into().unwrap(), + body: Fun { + args: vec![Arg { + ident: "_".try_into().unwrap(), + type_: None, + }], + body: Expr::Literal(Literal::Unit), + }, + } + ) + } } diff --git a/users/glittershark/achilles/src/parser/type_.rs b/users/glittershark/achilles/src/parser/type_.rs index 1e6e380bb823..8a1081e2521f 100644 --- a/users/glittershark/achilles/src/parser/type_.rs +++ b/users/glittershark/achilles/src/parser/type_.rs @@ -29,6 +29,7 @@ named!(pub type_(&str) -> Type, alt!( tag!("float") => { |_| Type::Float } | tag!("bool") => { |_| Type::Bool } | tag!("cstring") => { |_| Type::CString } | + tag!("()") => { |_| Type::Unit } | function_type => { |ft| Type::Function(ft) }| ident => { |id| Type::Var(id) } | delimited!( @@ -51,6 +52,7 @@ mod tests { assert_eq!(test_parse!(type_, "float"), Type::Float); assert_eq!(test_parse!(type_, "bool"), Type::Bool); assert_eq!(test_parse!(type_, "cstring"), Type::CString); + assert_eq!(test_parse!(type_, "()"), Type::Unit); } #[test] diff --git a/users/glittershark/achilles/src/passes/hir/mod.rs b/users/glittershark/achilles/src/passes/hir/mod.rs index fb2f64e08591..845bfcb7ab6a 100644 --- a/users/glittershark/achilles/src/passes/hir/mod.rs +++ b/users/glittershark/achilles/src/passes/hir/mod.rs @@ -4,6 +4,7 @@ use crate::ast::hir::{Binding, Decl, Expr}; use crate::ast::{BinaryOperator, Ident, Literal, UnaryOperator}; pub(crate) mod monomorphize; +pub(crate) mod strip_positive_units; pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a { type Error; @@ -53,7 +54,12 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a { Ok(()) } + fn pre_visit_expr(&mut self, _expr: &mut Expr<'ast, T>) -> Result<(), Self::Error> { + Ok(()) + } + fn visit_expr(&mut self, expr: &mut Expr<'ast, T>) -> Result<(), Self::Error> { + self.pre_visit_expr(expr)?; match expr { Expr::Ident(id, t) => { self.visit_ident(id)?; @@ -140,6 +146,17 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a { Ok(()) } + fn post_visit_fun_decl( + &mut self, + _name: &mut Ident<'ast>, + _type_args: &mut Vec, + _args: &mut Vec<(Ident, T)>, + _body: &mut Box>, + _type_: &mut T, + ) -> Result<(), Self::Error> { + Ok(()) + } + fn visit_decl(&mut self, decl: &'a mut Decl<'ast, T>) -> Result<(), Self::Error> { match decl { Decl::Fun { @@ -150,15 +167,16 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a { type_, } => { self.visit_ident(name)?; - for type_arg in type_args { + for type_arg in type_args.iter_mut() { self.visit_ident(type_arg)?; } - for (arg, t) in args { + for (arg, t) in args.iter_mut() { self.visit_ident(arg)?; self.visit_type(t)?; } self.visit_expr(body)?; self.visit_type(type_)?; + self.post_visit_fun_decl(name, type_args, args, body, type_)?; } Decl::Extern { name, diff --git a/users/glittershark/achilles/src/passes/hir/strip_positive_units.rs b/users/glittershark/achilles/src/passes/hir/strip_positive_units.rs new file mode 100644 index 000000000000..91b56551c82d --- /dev/null +++ b/users/glittershark/achilles/src/passes/hir/strip_positive_units.rs @@ -0,0 +1,189 @@ +use std::collections::HashMap; +use std::mem; + +use ast::hir::Binding; +use ast::Literal; +use void::{ResultVoidExt, Void}; + +use crate::ast::hir::{Decl, Expr}; +use crate::ast::{self, Ident}; + +use super::Visitor; + +/// Strip all values with a unit type in positive (non-return) position +pub(crate) struct StripPositiveUnits {} + +impl<'a, 'ast> Visitor<'a, 'ast, ast::Type<'ast>> for StripPositiveUnits { + type Error = Void; + + fn pre_visit_expr( + &mut self, + expr: &mut Expr<'ast, ast::Type<'ast>>, + ) -> Result<(), Self::Error> { + let mut extracted = vec![]; + if let Expr::Call { args, .. } = expr { + // TODO(grfn): replace with drain_filter once it's stabilized + let mut i = 0; + while i != args.len() { + if args[i].type_() == &ast::Type::Unit { + let expr = args.remove(i); + if !matches!(expr, Expr::Literal(Literal::Unit, _)) { + extracted.push(expr) + }; + } else { + i += 1 + } + } + } + + if !extracted.is_empty() { + let body = mem::replace(expr, Expr::Literal(Literal::Unit, ast::Type::Unit)); + *expr = Expr::Let { + bindings: extracted + .into_iter() + .map(|expr| Binding { + ident: Ident::from_str_unchecked("___discarded"), + type_: expr.type_().clone(), + body: expr, + }) + .collect(), + type_: body.type_().clone(), + body: Box::new(body), + }; + } + + Ok(()) + } + + fn post_visit_call( + &mut self, + _fun: &mut Expr<'ast, ast::Type<'ast>>, + _type_args: &mut HashMap, ast::Type<'ast>>, + args: &mut Vec>>, + ) -> Result<(), Self::Error> { + args.retain(|arg| arg.type_() != &ast::Type::Unit); + Ok(()) + } + + fn visit_type(&mut self, type_: &mut ast::Type<'ast>) -> Result<(), Self::Error> { + if let ast::Type::Function(ft) = type_ { + ft.args.retain(|a| a != &ast::Type::Unit); + } + Ok(()) + } + + fn post_visit_fun_decl( + &mut self, + _name: &mut Ident<'ast>, + _type_args: &mut Vec, + args: &mut Vec<(Ident, ast::Type<'ast>)>, + _body: &mut Box>>, + _type_: &mut ast::Type<'ast>, + ) -> Result<(), Self::Error> { + args.retain(|(_, ty)| ty != &ast::Type::Unit); + Ok(()) + } +} + +pub(crate) fn run_toplevel<'a>(toplevel: &mut Vec>>) { + let mut pass = StripPositiveUnits {}; + for decl in toplevel.iter_mut() { + pass.visit_decl(decl).void_unwrap(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::parser::toplevel; + use crate::tc::typecheck_toplevel; + use pretty_assertions::assert_eq; + + #[test] + fn unit_only_arg() { + let (_, program) = toplevel( + "ty f : fn () -> int + fn f _ = 1 + + ty main : fn -> int + fn main = f ()", + ) + .unwrap(); + + let (_, expected) = toplevel( + "ty f : fn -> int + fn f = 1 + + ty main : fn -> int + fn main = f()", + ) + .unwrap(); + let expected = typecheck_toplevel(expected).unwrap(); + + let mut program = typecheck_toplevel(program).unwrap(); + run_toplevel(&mut program); + + assert_eq!(program, expected); + } + + #[test] + fn unit_and_other_arg() { + let (_, program) = toplevel( + "ty f : fn (), int -> int + fn f _ x = x + + ty main : fn -> int + fn main = f () 1", + ) + .unwrap(); + + let (_, expected) = toplevel( + "ty f : fn int -> int + fn f x = x + + ty main : fn -> int + fn main = f 1", + ) + .unwrap(); + let expected = typecheck_toplevel(expected).unwrap(); + + let mut program = typecheck_toplevel(program).unwrap(); + run_toplevel(&mut program); + + assert_eq!(program, expected); + } + + #[test] + fn unit_expr_and_other_arg() { + let (_, program) = toplevel( + "ty f : fn (), int -> int + fn f _ x = x + + ty g : fn int -> () + fn g _ = () + + ty main : fn -> int + fn main = f (g 2) 1", + ) + .unwrap(); + + let (_, expected) = toplevel( + "ty f : fn int -> int + fn f x = x + + ty g : fn int -> () + fn g _ = () + + ty main : fn -> int + fn main = let ___discarded = g 2 in f 1", + ) + .unwrap(); + assert_eq!(expected.len(), 6); + let expected = typecheck_toplevel(expected).unwrap(); + + let mut program = typecheck_toplevel(program).unwrap(); + run_toplevel(&mut program); + + assert_eq!(program, expected); + } +} diff --git a/users/glittershark/achilles/src/tc/mod.rs b/users/glittershark/achilles/src/tc/mod.rs index 137561978f00..d27c45075e97 100644 --- a/users/glittershark/achilles/src/tc/mod.rs +++ b/users/glittershark/achilles/src/tc/mod.rs @@ -85,6 +85,7 @@ pub enum Type { Exist(TyVar), Nullary(NullaryType), Prim(PrimType), + Unit, Fun { args: Vec, ret: Box, @@ -96,6 +97,7 @@ impl<'a> TryFrom for ast::Type<'a> { fn try_from(value: Type) -> result::Result { match value { + Type::Unit => Ok(ast::Type::Unit), Type::Univ(_) => todo!(), Type::Exist(_) => Err(value), Type::Nullary(_) => todo!(), @@ -126,6 +128,7 @@ impl Display for Type { Type::Univ(TyVar(n)) => write!(f, "∀{}", n), Type::Exist(TyVar(n)) => write!(f, "∃{}", n), Type::Fun { args, ret } => write!(f, "fn {} -> {}", args.iter().join(", "), ret), + Type::Unit => write!(f, "()"), } } } @@ -171,6 +174,7 @@ impl<'ast> Typechecker<'ast> { Literal::Int(_) => Type::Prim(PrimType::Int), Literal::Bool(_) => Type::Prim(PrimType::Bool), Literal::String(_) => Type::Prim(PrimType::CString), + Literal::Unit => Type::Unit, }; Ok(hir::Expr::Literal(lit.to_owned(), type_)) } @@ -377,6 +381,7 @@ impl<'ast> Typechecker<'ast> { fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result { match (ty1, ty2) { + (Type::Unit, Type::Unit) => Ok(Type::Unit), (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) { Some(existing_ty) if self.types_match(ty, &existing_ty) => Ok(ty.clone()), Some(var @ ast::Type::Var(_)) => { @@ -466,6 +471,7 @@ impl<'ast> Typechecker<'ast> { let ret = match ty { Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)), Type::Univ(tv) => Ok(ast::Type::Var(self.name_univ(tv))), + Type::Unit => Ok(ast::Type::Unit), Type::Nullary(_) => todo!(), Type::Prim(pr) => Ok(pr.into()), Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType { @@ -496,6 +502,7 @@ impl<'ast> Typechecker<'ast> { } Type::Nullary(_) => todo!(), Type::Prim(pr) => break Some((*pr).into()), + Type::Unit => break Some(ast::Type::Unit), Type::Fun { args, ret } => todo!(), } } @@ -503,6 +510,7 @@ impl<'ast> Typechecker<'ast> { fn type_from_ast_type(&mut self, ast_type: ast::Type<'ast>) -> Type { match ast_type { + ast::Type::Unit => Type::Unit, ast::Type::Int => INT, ast::Type::Float => FLOAT, ast::Type::Bool => BOOL, @@ -570,6 +578,8 @@ impl<'ast> Typechecker<'ast> { } (Type::Univ(_), _) => false, (Type::Exist(_), _) => false, + (Type::Unit, ast::Type::Unit) => true, + (Type::Unit, _) => false, (Type::Nullary(_), _) => todo!(), (Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty, (Type::Fun { args, ret }, ast::Type::Function(ft)) => { -- cgit 1.4.1