From 8d5f3029e531d1163668f34e65b73cb6d639767f Mon Sep 17 00:00:00 2001 From: Griffin Smith Date: Sat, 20 Mar 2021 18:14:23 -0400 Subject: feat(gs/achilles): Implement very basic monomorphization Implement very basic monomorphization, by recording type variable instantiations when typechecking Call nodes and then using those in a new hir Visitor trait to copy the body of any generic decls for each possible set of instantiation of the type variables. Change-Id: Iab54030973e5d66e2b8bcd074b4cb6c001a90123 Reviewed-on: https://cl.tvl.fyi/c/depot/+/2617 Reviewed-by: glittershark Tested-by: BuildkiteCI --- users/glittershark/achilles/Cargo.lock | 7 + users/glittershark/achilles/Cargo.toml | 1 + users/glittershark/achilles/src/ast/hir.rs | 59 ++++++- users/glittershark/achilles/src/ast/mod.rs | 20 +++ users/glittershark/achilles/src/compiler.rs | 4 +- users/glittershark/achilles/src/interpreter/mod.rs | 7 +- users/glittershark/achilles/src/main.rs | 1 + users/glittershark/achilles/src/passes/hir/mod.rs | 179 +++++++++++++++++++++ .../achilles/src/passes/hir/monomorphize.rs | 139 ++++++++++++++++ users/glittershark/achilles/src/passes/mod.rs | 1 + users/glittershark/achilles/src/tc/mod.rs | 20 ++- users/glittershark/achilles/tests/compile.rs | 11 +- 12 files changed, 430 insertions(+), 19 deletions(-) create mode 100644 users/glittershark/achilles/src/passes/hir/mod.rs create mode 100644 users/glittershark/achilles/src/passes/hir/monomorphize.rs create mode 100644 users/glittershark/achilles/src/passes/mod.rs (limited to 'users/glittershark') diff --git a/users/glittershark/achilles/Cargo.lock b/users/glittershark/achilles/Cargo.lock index 0c5779135a5f..a5fa644ec631 100644 --- a/users/glittershark/achilles/Cargo.lock +++ b/users/glittershark/achilles/Cargo.lock @@ -19,6 +19,7 @@ dependencies = [ "proptest", "test-strategy", "thiserror", + "void", ] [[package]] @@ -761,6 +762,12 @@ version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5a972e5669d67ba988ce3dc826706fb0a8b01471c088cb0b6110b805cc36aed" +[[package]] +name = "void" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a02e4885ed3bc0f2de90ea6dd45ebcbb66dacffe03547fadbb0eeae2770887d" + [[package]] name = "wait-timeout" version = "0.2.0" diff --git a/users/glittershark/achilles/Cargo.toml b/users/glittershark/achilles/Cargo.toml index c0ba4d137a9f..ad18aaac3a43 100644 --- a/users/glittershark/achilles/Cargo.toml +++ b/users/glittershark/achilles/Cargo.toml @@ -19,6 +19,7 @@ pratt = "0.3.0" proptest = "1.0.0" test-strategy = "0.1.1" thiserror = "1.0.24" +void = "1.0.2" [dev-dependencies] crate-root = "0.1.3" diff --git a/users/glittershark/achilles/src/ast/hir.rs b/users/glittershark/achilles/src/ast/hir.rs index 691b9607e7e6..8726af509388 100644 --- a/users/glittershark/achilles/src/ast/hir.rs +++ b/users/glittershark/achilles/src/ast/hir.rs @@ -1,3 +1,5 @@ +use std::collections::HashMap; + use itertools::Itertools; use super::{BinaryOperator, Ident, Literal, UnaryOperator}; @@ -55,6 +57,7 @@ pub enum Expr<'a, T> { }, Fun { + type_args: Vec>, args: Vec<(Ident<'a>, T)>, body: Box>, type_: T, @@ -62,6 +65,7 @@ pub enum Expr<'a, T> { Call { fun: Box>, + type_args: HashMap, T>, args: Vec>, type_: T, }, @@ -133,16 +137,31 @@ impl<'a, T> Expr<'a, T> { else_: Box::new(else_.traverse_type(f.clone())?), type_: f(type_)?, }), - Expr::Fun { args, body, type_ } => Ok(Expr::Fun { + Expr::Fun { + args, + type_args, + body, + type_, + } => Ok(Expr::Fun { args: args .into_iter() .map(|(id, t)| Ok((id, f.clone()(t)?))) .collect::, E>>()?, + type_args, body: Box::new(body.traverse_type(f.clone())?), type_: f(type_)?, }), - Expr::Call { fun, args, type_ } => Ok(Expr::Call { + Expr::Call { + fun, + type_args, + args, + type_, + } => Ok(Expr::Call { fun: Box::new(fun.traverse_type(f.clone())?), + type_args: type_args + .into_iter() + .map(|(id, ty)| Ok((id, f.clone()(ty)?))) + .collect::, E>>()?, args: args .into_iter() .map(|e| e.traverse_type(f.clone())) @@ -180,7 +199,7 @@ impl<'a, T> Expr<'a, T> { body, type_, } => Expr::Let { - bindings: bindings.into_iter().map(|b| b.to_owned()).collect(), + bindings: bindings.iter().map(|b| b.to_owned()).collect(), body: Box::new((**body).to_owned()), type_: type_.clone(), }, @@ -195,26 +214,43 @@ impl<'a, T> Expr<'a, T> { else_: Box::new((**else_).to_owned()), type_: type_.clone(), }, - Expr::Fun { args, body, type_ } => Expr::Fun { + Expr::Fun { + args, + type_args, + body, + type_, + } => Expr::Fun { args: args - .into_iter() + .iter() .map(|(id, t)| (id.to_owned(), t.clone())) .collect(), + type_args: type_args.iter().map(|arg| arg.to_owned()).collect(), body: Box::new((**body).to_owned()), type_: type_.clone(), }, - Expr::Call { fun, args, type_ } => Expr::Call { + Expr::Call { + fun, + type_args, + args, + type_, + } => Expr::Call { fun: Box::new((**fun).to_owned()), - args: args.into_iter().map(|e| e.to_owned()).collect(), + type_args: type_args + .iter() + .map(|(id, t)| (id.to_owned(), t.clone())) + .collect(), + args: args.iter().map(|e| e.to_owned()).collect(), type_: type_.clone(), }, } } } +#[derive(Debug, Clone)] pub enum Decl<'a, T> { Fun { name: Ident<'a>, + type_args: Vec>, args: Vec<(Ident<'a>, T)>, body: Box>, type_: T, @@ -235,6 +271,13 @@ impl<'a, T> Decl<'a, T> { } } + pub fn set_name(&mut self, new_name: Ident<'a>) { + match self { + Decl::Fun { name, .. } => *name = new_name, + Decl::Extern { name, .. } => *name = new_name, + } + } + pub fn type_(&self) -> Option<&T> { match self { Decl::Fun { type_, .. } => Some(type_), @@ -249,11 +292,13 @@ impl<'a, T> Decl<'a, T> { match self { Decl::Fun { name, + type_args, args, body, type_, } => Ok(Decl::Fun { name, + type_args, args: args .into_iter() .map(|(id, t)| Ok((id, f(t)?))) diff --git a/users/glittershark/achilles/src/ast/mod.rs b/users/glittershark/achilles/src/ast/mod.rs index 22d16c93645c..53f222a6a11a 100644 --- a/users/glittershark/achilles/src/ast/mod.rs +++ b/users/glittershark/achilles/src/ast/mod.rs @@ -356,6 +356,26 @@ impl<'a> Type<'a> { let mut substs = HashMap::new(); do_alpha_equiv(&mut substs, self, other) } + + pub fn traverse_type_vars<'b, F>(self, mut f: F) -> Type<'b> + where + F: FnMut(Ident<'a>) -> Type<'b> + Clone, + { + match self { + Type::Var(tv) => f(tv), + Type::Function(FunctionType { args, ret }) => Type::Function(FunctionType { + args: args + .into_iter() + .map(|t| t.traverse_type_vars(f.clone())) + .collect(), + ret: Box::new(ret.traverse_type_vars(f)), + }), + Type::Int => Type::Int, + Type::Float => Type::Float, + Type::Bool => Type::Bool, + Type::CString => Type::CString, + } + } } impl<'a> Display for Type<'a> { diff --git a/users/glittershark/achilles/src/compiler.rs b/users/glittershark/achilles/src/compiler.rs index f925b267df57..7001e5a9a384 100644 --- a/users/glittershark/achilles/src/compiler.rs +++ b/users/glittershark/achilles/src/compiler.rs @@ -8,6 +8,7 @@ use test_strategy::Arbitrary; use crate::codegen::{self, Codegen}; use crate::common::Result; +use crate::passes::hir::monomorphize; use crate::{parser, tc}; #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Arbitrary)] @@ -55,7 +56,8 @@ pub struct CompilerOptions { pub fn compile_file(input: &Path, output: &Path, options: &CompilerOptions) -> Result<()> { let src = fs::read_to_string(input)?; let (_, decls) = parser::toplevel(&src)?; // TODO: statements - let decls = tc::typecheck_toplevel(decls)?; + let mut decls = tc::typecheck_toplevel(decls)?; + monomorphize::run_toplevel(&mut decls); let context = codegen::Context::create(); let mut codegen = Codegen::new( diff --git a/users/glittershark/achilles/src/interpreter/mod.rs b/users/glittershark/achilles/src/interpreter/mod.rs index 3bfeeb52e85c..bcd474b3abe1 100644 --- a/users/glittershark/achilles/src/interpreter/mod.rs +++ b/users/glittershark/achilles/src/interpreter/mod.rs @@ -96,7 +96,12 @@ impl<'a> Interpreter<'a> { } Ok(Value::from(*interpreter.eval(body)?.as_type::()?)) } - Expr::Fun { args, body, type_ } => { + Expr::Fun { + type_args: _, + args, + body, + type_, + } => { let type_ = match type_ { Type::Function(ft) => ft.clone(), _ => unreachable!("Function expression without function type"), diff --git a/users/glittershark/achilles/src/main.rs b/users/glittershark/achilles/src/main.rs index d5b00d6b6c46..5ae1b59b3a8e 100644 --- a/users/glittershark/achilles/src/main.rs +++ b/users/glittershark/achilles/src/main.rs @@ -6,6 +6,7 @@ pub(crate) mod commands; pub(crate) mod common; pub mod compiler; pub mod interpreter; +pub(crate) mod passes; #[macro_use] pub mod parser; pub mod tc; diff --git a/users/glittershark/achilles/src/passes/hir/mod.rs b/users/glittershark/achilles/src/passes/hir/mod.rs new file mode 100644 index 000000000000..fb2f64e08591 --- /dev/null +++ b/users/glittershark/achilles/src/passes/hir/mod.rs @@ -0,0 +1,179 @@ +use std::collections::HashMap; + +use crate::ast::hir::{Binding, Decl, Expr}; +use crate::ast::{BinaryOperator, Ident, Literal, UnaryOperator}; + +pub(crate) mod monomorphize; + +pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a { + type Error; + + fn visit_type(&mut self, _type: &mut T) -> Result<(), Self::Error> { + Ok(()) + } + + fn visit_ident(&mut self, _ident: &mut Ident<'ast>) -> Result<(), Self::Error> { + Ok(()) + } + + fn visit_literal(&mut self, _literal: &mut Literal<'ast>) -> Result<(), Self::Error> { + Ok(()) + } + + fn visit_unary_operator(&mut self, _op: &mut UnaryOperator) -> Result<(), Self::Error> { + Ok(()) + } + + fn visit_binary_operator(&mut self, _op: &mut BinaryOperator) -> Result<(), Self::Error> { + Ok(()) + } + + fn visit_binding(&mut self, binding: &mut Binding<'ast, T>) -> Result<(), Self::Error> { + self.visit_ident(&mut binding.ident)?; + self.visit_type(&mut binding.type_)?; + self.visit_expr(&mut binding.body)?; + Ok(()) + } + + fn post_visit_call( + &mut self, + _fun: &mut Expr<'ast, T>, + _type_args: &mut HashMap, T>, + _args: &mut Vec>, + ) -> Result<(), Self::Error> { + Ok(()) + } + + fn pre_visit_call( + &mut self, + _fun: &mut Expr<'ast, T>, + _type_args: &mut HashMap, T>, + _args: &mut Vec>, + ) -> Result<(), Self::Error> { + Ok(()) + } + + fn visit_expr(&mut self, expr: &mut Expr<'ast, T>) -> Result<(), Self::Error> { + match expr { + Expr::Ident(id, t) => { + self.visit_ident(id)?; + self.visit_type(t)?; + } + Expr::Literal(lit, t) => { + self.visit_literal(lit)?; + self.visit_type(t)?; + } + Expr::UnaryOp { op, rhs, type_ } => { + self.visit_unary_operator(op)?; + self.visit_expr(rhs)?; + self.visit_type(type_)?; + } + Expr::BinaryOp { + lhs, + op, + rhs, + type_, + } => { + self.visit_expr(lhs)?; + self.visit_binary_operator(op)?; + self.visit_expr(rhs)?; + self.visit_type(type_)?; + } + Expr::Let { + bindings, + body, + type_, + } => { + for binding in bindings.iter_mut() { + self.visit_binding(binding)?; + } + self.visit_expr(body)?; + self.visit_type(type_)?; + } + Expr::If { + condition, + then, + else_, + type_, + } => { + self.visit_expr(condition)?; + self.visit_expr(then)?; + self.visit_expr(else_)?; + self.visit_type(type_)?; + } + Expr::Fun { + args, + body, + type_args, + type_, + } => { + for (ident, t) in args { + self.visit_ident(ident)?; + self.visit_type(t)?; + } + for ta in type_args { + self.visit_ident(ta)?; + } + self.visit_expr(body)?; + self.visit_type(type_)?; + } + Expr::Call { + fun, + args, + type_args, + type_, + } => { + self.pre_visit_call(fun, type_args, args)?; + self.visit_expr(fun)?; + for arg in args.iter_mut() { + self.visit_expr(arg)?; + } + self.visit_type(type_)?; + self.post_visit_call(fun, type_args, args)?; + } + } + + Ok(()) + } + + fn post_visit_decl(&mut self, decl: &'a Decl<'ast, T>) -> Result<(), Self::Error> { + Ok(()) + } + + fn visit_decl(&mut self, decl: &'a mut Decl<'ast, T>) -> Result<(), Self::Error> { + match decl { + Decl::Fun { + name, + type_args, + args, + body, + type_, + } => { + self.visit_ident(name)?; + for type_arg in type_args { + self.visit_ident(type_arg)?; + } + for (arg, t) in args { + self.visit_ident(arg)?; + self.visit_type(t)?; + } + self.visit_expr(body)?; + self.visit_type(type_)?; + } + Decl::Extern { + name, + arg_types, + ret_type, + } => { + self.visit_ident(name)?; + for arg_t in arg_types { + self.visit_type(arg_t)?; + } + self.visit_type(ret_type)?; + } + } + + self.post_visit_decl(decl)?; + Ok(()) + } +} diff --git a/users/glittershark/achilles/src/passes/hir/monomorphize.rs b/users/glittershark/achilles/src/passes/hir/monomorphize.rs new file mode 100644 index 000000000000..251a988f4f6f --- /dev/null +++ b/users/glittershark/achilles/src/passes/hir/monomorphize.rs @@ -0,0 +1,139 @@ +use std::cell::RefCell; +use std::collections::{HashMap, HashSet}; +use std::convert::TryInto; +use std::mem; + +use void::{ResultVoidExt, Void}; + +use crate::ast::hir::{Decl, Expr}; +use crate::ast::{self, Ident}; + +use super::Visitor; + +#[derive(Default)] +pub(crate) struct Monomorphize<'a, 'ast> { + decls: HashMap<&'a Ident<'ast>, &'a Decl<'ast, ast::Type<'ast>>>, + extra_decls: Vec>>, + remove_decls: HashSet>, +} + +impl<'a, 'ast> Monomorphize<'a, 'ast> { + pub(crate) fn new() -> Self { + Default::default() + } +} + +impl<'a, 'ast> Visitor<'a, 'ast, ast::Type<'ast>> for Monomorphize<'a, 'ast> { + type Error = Void; + + fn post_visit_call( + &mut self, + fun: &mut Expr<'ast, ast::Type<'ast>>, + type_args: &mut HashMap, ast::Type<'ast>>, + args: &mut Vec>>, + ) -> Result<(), Self::Error> { + let new_fun = match fun { + Expr::Ident(id, _) => { + let decl: Decl<_> = (**self.decls.get(id).unwrap()).clone(); + let name = RefCell::new(id.to_string()); + let type_args = mem::take(type_args); + let mut monomorphized = decl + .traverse_type(|ty| -> Result<_, Void> { + Ok(ty.clone().traverse_type_vars(|v| { + let concrete = type_args.get(&v).unwrap(); + name.borrow_mut().push_str(&concrete.to_string()); + concrete.clone() + })) + }) + .void_unwrap(); + let name: Ident = name.into_inner().try_into().unwrap(); + if name != *id { + self.remove_decls.insert(id.clone()); + monomorphized.set_name(name.clone()); + let type_ = monomorphized.type_().unwrap().clone(); + self.extra_decls.push(monomorphized); + Some(Expr::Ident(name, type_)) + } else { + None + } + } + _ => todo!(), + }; + if let Some(new_fun) = new_fun { + *fun = new_fun; + } + Ok(()) + } + + fn post_visit_decl( + &mut self, + decl: &'a Decl<'ast, ast::Type<'ast>>, + ) -> Result<(), Self::Error> { + self.decls.insert(decl.name(), decl); + Ok(()) + } +} + +pub(crate) fn run_toplevel<'a>(toplevel: &mut Vec>>) { + let mut pass = Monomorphize::new(); + for decl in toplevel.iter_mut() { + pass.visit_decl(decl).void_unwrap(); + } + let remove_decls = mem::take(&mut pass.remove_decls); + let mut extra_decls = mem::take(&mut pass.extra_decls); + toplevel.retain(|decl| !remove_decls.contains(decl.name())); + extra_decls.append(toplevel); + *toplevel = extra_decls; +} + +#[cfg(test)] +mod tests { + use std::convert::TryFrom; + + use super::*; + use crate::parser::toplevel; + use crate::tc::typecheck_toplevel; + + #[test] + fn call_id_decl() { + let (_, program) = toplevel( + "ty id : fn a -> a + fn id x = x + + ty main : fn -> int + fn main = id 0", + ) + .unwrap(); + let mut program = typecheck_toplevel(program).unwrap(); + run_toplevel(&mut program); + + let find_decl = |ident: &str| { + program.iter().find(|decl| { + matches!(decl, Decl::Fun {name, ..} if name == &Ident::try_from(ident).unwrap()) + }).unwrap() + }; + + let main = find_decl("main"); + let body = match main { + Decl::Fun { body, .. } => body, + _ => unreachable!(), + }; + + let expected_type = ast::Type::Function(ast::FunctionType { + args: vec![ast::Type::Int], + ret: Box::new(ast::Type::Int), + }); + + match &**body { + Expr::Call { fun, .. } => { + let fun = match &**fun { + Expr::Ident(fun, _) => fun, + _ => unreachable!(), + }; + let called_decl = find_decl(fun.into()); + assert_eq!(called_decl.type_().unwrap(), &expected_type); + } + _ => unreachable!(), + } + } +} diff --git a/users/glittershark/achilles/src/passes/mod.rs b/users/glittershark/achilles/src/passes/mod.rs new file mode 100644 index 000000000000..306869bef1d5 --- /dev/null +++ b/users/glittershark/achilles/src/passes/mod.rs @@ -0,0 +1 @@ +pub(crate) mod hir; diff --git a/users/glittershark/achilles/src/tc/mod.rs b/users/glittershark/achilles/src/tc/mod.rs index f3cb40a8b010..137561978f00 100644 --- a/users/glittershark/achilles/src/tc/mod.rs +++ b/users/glittershark/achilles/src/tc/mod.rs @@ -266,6 +266,7 @@ impl<'ast> Typechecker<'ast> { args: args.iter().map(|(_, ty)| ty.clone()).collect(), ret: Box::new(body.type_().clone()), }, + type_args: vec![], // TODO fill in once we do let generalization args, body: Box::new(body), }) @@ -289,9 +290,10 @@ impl<'ast> Typechecker<'ast> { Ok(arg) }) .try_collect()?; - self.commit_instantiations(); + let type_args = self.commit_instantiations(); Ok(hir::Expr::Call { fun: Box::new(fun), + type_args, args, type_: ret_ty, }) @@ -325,8 +327,14 @@ impl<'ast> Typechecker<'ast> { self.env.set(name.clone(), type_); self.env.pop(); match body { - hir::Expr::Fun { args, body, type_ } => Ok(Some(hir::Decl::Fun { + hir::Expr::Fun { + type_args, + args, + body, + type_, + } => Ok(Some(hir::Decl::Fun { name, + type_args, args, body, type_, @@ -538,17 +546,21 @@ impl<'ast> Typechecker<'ast> { }) } - fn commit_instantiations(&mut self) { + fn commit_instantiations(&mut self) -> HashMap, Type> { + let mut res = HashMap::new(); let mut ctx = mem::take(&mut self.ctx); for (_, v) in ctx.iter_mut() { if let Type::Univ(tv) = v { - if let Some(concrete) = self.instantiations.resolve(&self.name_univ(*tv)) { + let tv_name = self.name_univ(*tv); + if let Some(concrete) = self.instantiations.resolve(&tv_name) { + res.insert(tv_name, concrete.clone()); *v = concrete.clone(); } } } self.ctx = ctx; self.instantiations.pop(); + res } fn types_match(&self, type_: &Type, ast_type: &ast::Type<'ast>) -> bool { diff --git a/users/glittershark/achilles/tests/compile.rs b/users/glittershark/achilles/tests/compile.rs index 1b4da463a980..51ffb239b7e4 100644 --- a/users/glittershark/achilles/tests/compile.rs +++ b/users/glittershark/achilles/tests/compile.rs @@ -14,12 +14,11 @@ const FIXTURES: &[Fixture] = &[ exit_code: 5, expected_output: "", }, - // TODO(grfn): needs monomorphization - // Fixture { - // name: "functions", - // exit_code: 9, - // expected_output: "", - // }, + Fixture { + name: "functions", + exit_code: 9, + expected_output: "", + }, Fixture { name: "externs", exit_code: 0, -- cgit 1.4.1