From 8d5f3029e531d1163668f34e65b73cb6d639767f Mon Sep 17 00:00:00 2001 From: Griffin Smith Date: Sat, 20 Mar 2021 18:14:23 -0400 Subject: feat(gs/achilles): Implement very basic monomorphization Implement very basic monomorphization, by recording type variable instantiations when typechecking Call nodes and then using those in a new hir Visitor trait to copy the body of any generic decls for each possible set of instantiation of the type variables. Change-Id: Iab54030973e5d66e2b8bcd074b4cb6c001a90123 Reviewed-on: https://cl.tvl.fyi/c/depot/+/2617 Reviewed-by: glittershark Tested-by: BuildkiteCI --- users/glittershark/achilles/src/passes/hir/mod.rs | 179 +++++++++++++++++++++ .../achilles/src/passes/hir/monomorphize.rs | 139 ++++++++++++++++ 2 files changed, 318 insertions(+) create mode 100644 users/glittershark/achilles/src/passes/hir/mod.rs create mode 100644 users/glittershark/achilles/src/passes/hir/monomorphize.rs (limited to 'users/glittershark/achilles/src/passes/hir') diff --git a/users/glittershark/achilles/src/passes/hir/mod.rs b/users/glittershark/achilles/src/passes/hir/mod.rs new file mode 100644 index 000000000000..fb2f64e08591 --- /dev/null +++ b/users/glittershark/achilles/src/passes/hir/mod.rs @@ -0,0 +1,179 @@ +use std::collections::HashMap; + +use crate::ast::hir::{Binding, Decl, Expr}; +use crate::ast::{BinaryOperator, Ident, Literal, UnaryOperator}; + +pub(crate) mod monomorphize; + +pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a { + type Error; + + fn visit_type(&mut self, _type: &mut T) -> Result<(), Self::Error> { + Ok(()) + } + + fn visit_ident(&mut self, _ident: &mut Ident<'ast>) -> Result<(), Self::Error> { + Ok(()) + } + + fn visit_literal(&mut self, _literal: &mut Literal<'ast>) -> Result<(), Self::Error> { + Ok(()) + } + + fn visit_unary_operator(&mut self, _op: &mut UnaryOperator) -> Result<(), Self::Error> { + Ok(()) + } + + fn visit_binary_operator(&mut self, _op: &mut BinaryOperator) -> Result<(), Self::Error> { + Ok(()) + } + + fn visit_binding(&mut self, binding: &mut Binding<'ast, T>) -> Result<(), Self::Error> { + self.visit_ident(&mut binding.ident)?; + self.visit_type(&mut binding.type_)?; + self.visit_expr(&mut binding.body)?; + Ok(()) + } + + fn post_visit_call( + &mut self, + _fun: &mut Expr<'ast, T>, + _type_args: &mut HashMap, T>, + _args: &mut Vec>, + ) -> Result<(), Self::Error> { + Ok(()) + } + + fn pre_visit_call( + &mut self, + _fun: &mut Expr<'ast, T>, + _type_args: &mut HashMap, T>, + _args: &mut Vec>, + ) -> Result<(), Self::Error> { + Ok(()) + } + + fn visit_expr(&mut self, expr: &mut Expr<'ast, T>) -> Result<(), Self::Error> { + match expr { + Expr::Ident(id, t) => { + self.visit_ident(id)?; + self.visit_type(t)?; + } + Expr::Literal(lit, t) => { + self.visit_literal(lit)?; + self.visit_type(t)?; + } + Expr::UnaryOp { op, rhs, type_ } => { + self.visit_unary_operator(op)?; + self.visit_expr(rhs)?; + self.visit_type(type_)?; + } + Expr::BinaryOp { + lhs, + op, + rhs, + type_, + } => { + self.visit_expr(lhs)?; + self.visit_binary_operator(op)?; + self.visit_expr(rhs)?; + self.visit_type(type_)?; + } + Expr::Let { + bindings, + body, + type_, + } => { + for binding in bindings.iter_mut() { + self.visit_binding(binding)?; + } + self.visit_expr(body)?; + self.visit_type(type_)?; + } + Expr::If { + condition, + then, + else_, + type_, + } => { + self.visit_expr(condition)?; + self.visit_expr(then)?; + self.visit_expr(else_)?; + self.visit_type(type_)?; + } + Expr::Fun { + args, + body, + type_args, + type_, + } => { + for (ident, t) in args { + self.visit_ident(ident)?; + self.visit_type(t)?; + } + for ta in type_args { + self.visit_ident(ta)?; + } + self.visit_expr(body)?; + self.visit_type(type_)?; + } + Expr::Call { + fun, + args, + type_args, + type_, + } => { + self.pre_visit_call(fun, type_args, args)?; + self.visit_expr(fun)?; + for arg in args.iter_mut() { + self.visit_expr(arg)?; + } + self.visit_type(type_)?; + self.post_visit_call(fun, type_args, args)?; + } + } + + Ok(()) + } + + fn post_visit_decl(&mut self, decl: &'a Decl<'ast, T>) -> Result<(), Self::Error> { + Ok(()) + } + + fn visit_decl(&mut self, decl: &'a mut Decl<'ast, T>) -> Result<(), Self::Error> { + match decl { + Decl::Fun { + name, + type_args, + args, + body, + type_, + } => { + self.visit_ident(name)?; + for type_arg in type_args { + self.visit_ident(type_arg)?; + } + for (arg, t) in args { + self.visit_ident(arg)?; + self.visit_type(t)?; + } + self.visit_expr(body)?; + self.visit_type(type_)?; + } + Decl::Extern { + name, + arg_types, + ret_type, + } => { + self.visit_ident(name)?; + for arg_t in arg_types { + self.visit_type(arg_t)?; + } + self.visit_type(ret_type)?; + } + } + + self.post_visit_decl(decl)?; + Ok(()) + } +} diff --git a/users/glittershark/achilles/src/passes/hir/monomorphize.rs b/users/glittershark/achilles/src/passes/hir/monomorphize.rs new file mode 100644 index 000000000000..251a988f4f6f --- /dev/null +++ b/users/glittershark/achilles/src/passes/hir/monomorphize.rs @@ -0,0 +1,139 @@ +use std::cell::RefCell; +use std::collections::{HashMap, HashSet}; +use std::convert::TryInto; +use std::mem; + +use void::{ResultVoidExt, Void}; + +use crate::ast::hir::{Decl, Expr}; +use crate::ast::{self, Ident}; + +use super::Visitor; + +#[derive(Default)] +pub(crate) struct Monomorphize<'a, 'ast> { + decls: HashMap<&'a Ident<'ast>, &'a Decl<'ast, ast::Type<'ast>>>, + extra_decls: Vec>>, + remove_decls: HashSet>, +} + +impl<'a, 'ast> Monomorphize<'a, 'ast> { + pub(crate) fn new() -> Self { + Default::default() + } +} + +impl<'a, 'ast> Visitor<'a, 'ast, ast::Type<'ast>> for Monomorphize<'a, 'ast> { + type Error = Void; + + fn post_visit_call( + &mut self, + fun: &mut Expr<'ast, ast::Type<'ast>>, + type_args: &mut HashMap, ast::Type<'ast>>, + args: &mut Vec>>, + ) -> Result<(), Self::Error> { + let new_fun = match fun { + Expr::Ident(id, _) => { + let decl: Decl<_> = (**self.decls.get(id).unwrap()).clone(); + let name = RefCell::new(id.to_string()); + let type_args = mem::take(type_args); + let mut monomorphized = decl + .traverse_type(|ty| -> Result<_, Void> { + Ok(ty.clone().traverse_type_vars(|v| { + let concrete = type_args.get(&v).unwrap(); + name.borrow_mut().push_str(&concrete.to_string()); + concrete.clone() + })) + }) + .void_unwrap(); + let name: Ident = name.into_inner().try_into().unwrap(); + if name != *id { + self.remove_decls.insert(id.clone()); + monomorphized.set_name(name.clone()); + let type_ = monomorphized.type_().unwrap().clone(); + self.extra_decls.push(monomorphized); + Some(Expr::Ident(name, type_)) + } else { + None + } + } + _ => todo!(), + }; + if let Some(new_fun) = new_fun { + *fun = new_fun; + } + Ok(()) + } + + fn post_visit_decl( + &mut self, + decl: &'a Decl<'ast, ast::Type<'ast>>, + ) -> Result<(), Self::Error> { + self.decls.insert(decl.name(), decl); + Ok(()) + } +} + +pub(crate) fn run_toplevel<'a>(toplevel: &mut Vec>>) { + let mut pass = Monomorphize::new(); + for decl in toplevel.iter_mut() { + pass.visit_decl(decl).void_unwrap(); + } + let remove_decls = mem::take(&mut pass.remove_decls); + let mut extra_decls = mem::take(&mut pass.extra_decls); + toplevel.retain(|decl| !remove_decls.contains(decl.name())); + extra_decls.append(toplevel); + *toplevel = extra_decls; +} + +#[cfg(test)] +mod tests { + use std::convert::TryFrom; + + use super::*; + use crate::parser::toplevel; + use crate::tc::typecheck_toplevel; + + #[test] + fn call_id_decl() { + let (_, program) = toplevel( + "ty id : fn a -> a + fn id x = x + + ty main : fn -> int + fn main = id 0", + ) + .unwrap(); + let mut program = typecheck_toplevel(program).unwrap(); + run_toplevel(&mut program); + + let find_decl = |ident: &str| { + program.iter().find(|decl| { + matches!(decl, Decl::Fun {name, ..} if name == &Ident::try_from(ident).unwrap()) + }).unwrap() + }; + + let main = find_decl("main"); + let body = match main { + Decl::Fun { body, .. } => body, + _ => unreachable!(), + }; + + let expected_type = ast::Type::Function(ast::FunctionType { + args: vec![ast::Type::Int], + ret: Box::new(ast::Type::Int), + }); + + match &**body { + Expr::Call { fun, .. } => { + let fun = match &**fun { + Expr::Ident(fun, _) => fun, + _ => unreachable!(), + }; + let called_decl = find_decl(fun.into()); + assert_eq!(called_decl.type_().unwrap(), &expected_type); + } + _ => unreachable!(), + } + } +} -- cgit 1.4.1