about summary refs log tree commit diff
path: root/users/grfn/achilles/src/passes/hir
diff options
context:
space:
mode:
Diffstat (limited to 'users/grfn/achilles/src/passes/hir')
-rw-r--r--users/grfn/achilles/src/passes/hir/mod.rs211
-rw-r--r--users/grfn/achilles/src/passes/hir/monomorphize.rs139
-rw-r--r--users/grfn/achilles/src/passes/hir/strip_positive_units.rs191
3 files changed, 541 insertions, 0 deletions
diff --git a/users/grfn/achilles/src/passes/hir/mod.rs b/users/grfn/achilles/src/passes/hir/mod.rs
new file mode 100644
index 000000000000..872c449eb020
--- /dev/null
+++ b/users/grfn/achilles/src/passes/hir/mod.rs
@@ -0,0 +1,211 @@
+use std::collections::HashMap;
+
+use crate::ast::hir::{Binding, Decl, Expr, Pattern};
+use crate::ast::{BinaryOperator, Ident, Literal, UnaryOperator};
+
+pub(crate) mod monomorphize;
+pub(crate) mod strip_positive_units;
+
+pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a {
+    type Error;
+
+    fn visit_type(&mut self, _type: &mut T) -> Result<(), Self::Error> {
+        Ok(())
+    }
+
+    fn visit_ident(&mut self, _ident: &mut Ident<'ast>) -> Result<(), Self::Error> {
+        Ok(())
+    }
+
+    fn visit_literal(&mut self, _literal: &mut Literal<'ast>) -> Result<(), Self::Error> {
+        Ok(())
+    }
+
+    fn visit_unary_operator(&mut self, _op: &mut UnaryOperator) -> Result<(), Self::Error> {
+        Ok(())
+    }
+
+    fn visit_binary_operator(&mut self, _op: &mut BinaryOperator) -> Result<(), Self::Error> {
+        Ok(())
+    }
+
+    fn visit_pattern(&mut self, _pat: &mut Pattern<'ast, T>) -> Result<(), Self::Error> {
+        Ok(())
+    }
+
+    fn visit_binding(&mut self, binding: &mut Binding<'ast, T>) -> Result<(), Self::Error> {
+        self.visit_pattern(&mut binding.pat)?;
+        self.visit_expr(&mut binding.body)?;
+        Ok(())
+    }
+
+    fn post_visit_call(
+        &mut self,
+        _fun: &mut Expr<'ast, T>,
+        _type_args: &mut HashMap<Ident<'ast>, T>,
+        _args: &mut Vec<Expr<'ast, T>>,
+    ) -> Result<(), Self::Error> {
+        Ok(())
+    }
+
+    fn pre_visit_call(
+        &mut self,
+        _fun: &mut Expr<'ast, T>,
+        _type_args: &mut HashMap<Ident<'ast>, T>,
+        _args: &mut Vec<Expr<'ast, T>>,
+    ) -> Result<(), Self::Error> {
+        Ok(())
+    }
+
+    fn visit_tuple(&mut self, members: &mut Vec<Expr<'ast, T>>) -> Result<(), Self::Error> {
+        for expr in members {
+            self.visit_expr(expr)?;
+        }
+        Ok(())
+    }
+
+    fn pre_visit_expr(&mut self, _expr: &mut Expr<'ast, T>) -> Result<(), Self::Error> {
+        Ok(())
+    }
+
+    fn visit_expr(&mut self, expr: &mut Expr<'ast, T>) -> Result<(), Self::Error> {
+        self.pre_visit_expr(expr)?;
+        match expr {
+            Expr::Ident(id, t) => {
+                self.visit_ident(id)?;
+                self.visit_type(t)?;
+            }
+            Expr::Literal(lit, t) => {
+                self.visit_literal(lit)?;
+                self.visit_type(t)?;
+            }
+            Expr::UnaryOp { op, rhs, type_ } => {
+                self.visit_unary_operator(op)?;
+                self.visit_expr(rhs)?;
+                self.visit_type(type_)?;
+            }
+            Expr::BinaryOp {
+                lhs,
+                op,
+                rhs,
+                type_,
+            } => {
+                self.visit_expr(lhs)?;
+                self.visit_binary_operator(op)?;
+                self.visit_expr(rhs)?;
+                self.visit_type(type_)?;
+            }
+            Expr::Let {
+                bindings,
+                body,
+                type_,
+            } => {
+                for binding in bindings.iter_mut() {
+                    self.visit_binding(binding)?;
+                }
+                self.visit_expr(body)?;
+                self.visit_type(type_)?;
+            }
+            Expr::If {
+                condition,
+                then,
+                else_,
+                type_,
+            } => {
+                self.visit_expr(condition)?;
+                self.visit_expr(then)?;
+                self.visit_expr(else_)?;
+                self.visit_type(type_)?;
+            }
+            Expr::Fun {
+                args,
+                body,
+                type_args,
+                type_,
+            } => {
+                for (ident, t) in args {
+                    self.visit_ident(ident)?;
+                    self.visit_type(t)?;
+                }
+                for ta in type_args {
+                    self.visit_ident(ta)?;
+                }
+                self.visit_expr(body)?;
+                self.visit_type(type_)?;
+            }
+            Expr::Call {
+                fun,
+                args,
+                type_args,
+                type_,
+            } => {
+                self.pre_visit_call(fun, type_args, args)?;
+                self.visit_expr(fun)?;
+                for arg in args.iter_mut() {
+                    self.visit_expr(arg)?;
+                }
+                self.visit_type(type_)?;
+                self.post_visit_call(fun, type_args, args)?;
+            }
+            Expr::Tuple(tup, type_) => {
+                self.visit_tuple(tup)?;
+                self.visit_type(type_)?;
+            }
+        }
+
+        Ok(())
+    }
+
+    fn post_visit_decl(&mut self, _decl: &'a Decl<'ast, T>) -> Result<(), Self::Error> {
+        Ok(())
+    }
+
+    fn post_visit_fun_decl(
+        &mut self,
+        _name: &mut Ident<'ast>,
+        _type_args: &mut Vec<Ident>,
+        _args: &mut Vec<(Ident, T)>,
+        _body: &mut Box<Expr<T>>,
+        _type_: &mut T,
+    ) -> Result<(), Self::Error> {
+        Ok(())
+    }
+
+    fn visit_decl(&mut self, decl: &'a mut Decl<'ast, T>) -> Result<(), Self::Error> {
+        match decl {
+            Decl::Fun {
+                name,
+                type_args,
+                args,
+                body,
+                type_,
+            } => {
+                self.visit_ident(name)?;
+                for type_arg in type_args.iter_mut() {
+                    self.visit_ident(type_arg)?;
+                }
+                for (arg, t) in args.iter_mut() {
+                    self.visit_ident(arg)?;
+                    self.visit_type(t)?;
+                }
+                self.visit_expr(body)?;
+                self.visit_type(type_)?;
+                self.post_visit_fun_decl(name, type_args, args, body, type_)?;
+            }
+            Decl::Extern {
+                name,
+                arg_types,
+                ret_type,
+            } => {
+                self.visit_ident(name)?;
+                for arg_t in arg_types {
+                    self.visit_type(arg_t)?;
+                }
+                self.visit_type(ret_type)?;
+            }
+        }
+
+        self.post_visit_decl(decl)?;
+        Ok(())
+    }
+}
diff --git a/users/grfn/achilles/src/passes/hir/monomorphize.rs b/users/grfn/achilles/src/passes/hir/monomorphize.rs
new file mode 100644
index 000000000000..251a988f4f6f
--- /dev/null
+++ b/users/grfn/achilles/src/passes/hir/monomorphize.rs
@@ -0,0 +1,139 @@
+use std::cell::RefCell;
+use std::collections::{HashMap, HashSet};
+use std::convert::TryInto;
+use std::mem;
+
+use void::{ResultVoidExt, Void};
+
+use crate::ast::hir::{Decl, Expr};
+use crate::ast::{self, Ident};
+
+use super::Visitor;
+
+#[derive(Default)]
+pub(crate) struct Monomorphize<'a, 'ast> {
+    decls: HashMap<&'a Ident<'ast>, &'a Decl<'ast, ast::Type<'ast>>>,
+    extra_decls: Vec<Decl<'ast, ast::Type<'ast>>>,
+    remove_decls: HashSet<Ident<'ast>>,
+}
+
+impl<'a, 'ast> Monomorphize<'a, 'ast> {
+    pub(crate) fn new() -> Self {
+        Default::default()
+    }
+}
+
+impl<'a, 'ast> Visitor<'a, 'ast, ast::Type<'ast>> for Monomorphize<'a, 'ast> {
+    type Error = Void;
+
+    fn post_visit_call(
+        &mut self,
+        fun: &mut Expr<'ast, ast::Type<'ast>>,
+        type_args: &mut HashMap<Ident<'ast>, ast::Type<'ast>>,
+        args: &mut Vec<Expr<'ast, ast::Type<'ast>>>,
+    ) -> Result<(), Self::Error> {
+        let new_fun = match fun {
+            Expr::Ident(id, _) => {
+                let decl: Decl<_> = (**self.decls.get(id).unwrap()).clone();
+                let name = RefCell::new(id.to_string());
+                let type_args = mem::take(type_args);
+                let mut monomorphized = decl
+                    .traverse_type(|ty| -> Result<_, Void> {
+                        Ok(ty.clone().traverse_type_vars(|v| {
+                            let concrete = type_args.get(&v).unwrap();
+                            name.borrow_mut().push_str(&concrete.to_string());
+                            concrete.clone()
+                        }))
+                    })
+                    .void_unwrap();
+                let name: Ident = name.into_inner().try_into().unwrap();
+                if name != *id {
+                    self.remove_decls.insert(id.clone());
+                    monomorphized.set_name(name.clone());
+                    let type_ = monomorphized.type_().unwrap().clone();
+                    self.extra_decls.push(monomorphized);
+                    Some(Expr::Ident(name, type_))
+                } else {
+                    None
+                }
+            }
+            _ => todo!(),
+        };
+        if let Some(new_fun) = new_fun {
+            *fun = new_fun;
+        }
+        Ok(())
+    }
+
+    fn post_visit_decl(
+        &mut self,
+        decl: &'a Decl<'ast, ast::Type<'ast>>,
+    ) -> Result<(), Self::Error> {
+        self.decls.insert(decl.name(), decl);
+        Ok(())
+    }
+}
+
+pub(crate) fn run_toplevel<'a>(toplevel: &mut Vec<Decl<'a, ast::Type<'a>>>) {
+    let mut pass = Monomorphize::new();
+    for decl in toplevel.iter_mut() {
+        pass.visit_decl(decl).void_unwrap();
+    }
+    let remove_decls = mem::take(&mut pass.remove_decls);
+    let mut extra_decls = mem::take(&mut pass.extra_decls);
+    toplevel.retain(|decl| !remove_decls.contains(decl.name()));
+    extra_decls.append(toplevel);
+    *toplevel = extra_decls;
+}
+
+#[cfg(test)]
+mod tests {
+    use std::convert::TryFrom;
+
+    use super::*;
+    use crate::parser::toplevel;
+    use crate::tc::typecheck_toplevel;
+
+    #[test]
+    fn call_id_decl() {
+        let (_, program) = toplevel(
+            "ty id : fn a -> a
+             fn id x = x
+
+             ty main : fn -> int
+             fn main = id 0",
+        )
+        .unwrap();
+        let mut program = typecheck_toplevel(program).unwrap();
+        run_toplevel(&mut program);
+
+        let find_decl = |ident: &str| {
+            program.iter().find(|decl| {
+                matches!(decl, Decl::Fun {name, ..} if name == &Ident::try_from(ident).unwrap())
+            }).unwrap()
+        };
+
+        let main = find_decl("main");
+        let body = match main {
+            Decl::Fun { body, .. } => body,
+            _ => unreachable!(),
+        };
+
+        let expected_type = ast::Type::Function(ast::FunctionType {
+            args: vec![ast::Type::Int],
+            ret: Box::new(ast::Type::Int),
+        });
+
+        match &**body {
+            Expr::Call { fun, .. } => {
+                let fun = match &**fun {
+                    Expr::Ident(fun, _) => fun,
+                    _ => unreachable!(),
+                };
+                let called_decl = find_decl(fun.into());
+                assert_eq!(called_decl.type_().unwrap(), &expected_type);
+            }
+            _ => unreachable!(),
+        }
+    }
+}
diff --git a/users/grfn/achilles/src/passes/hir/strip_positive_units.rs b/users/grfn/achilles/src/passes/hir/strip_positive_units.rs
new file mode 100644
index 000000000000..85ee1cce4859
--- /dev/null
+++ b/users/grfn/achilles/src/passes/hir/strip_positive_units.rs
@@ -0,0 +1,191 @@
+use std::collections::HashMap;
+use std::mem;
+
+use ast::hir::{Binding, Pattern};
+use ast::Literal;
+use void::{ResultVoidExt, Void};
+
+use crate::ast::hir::{Decl, Expr};
+use crate::ast::{self, Ident};
+
+use super::Visitor;
+
+/// Strip all values with a unit type in positive (non-return) position
+pub(crate) struct StripPositiveUnits {}
+
+impl<'a, 'ast> Visitor<'a, 'ast, ast::Type<'ast>> for StripPositiveUnits {
+    type Error = Void;
+
+    fn pre_visit_expr(
+        &mut self,
+        expr: &mut Expr<'ast, ast::Type<'ast>>,
+    ) -> Result<(), Self::Error> {
+        let mut extracted = vec![];
+        if let Expr::Call { args, .. } = expr {
+            // TODO(grfn): replace with drain_filter once it's stabilized
+            let mut i = 0;
+            while i != args.len() {
+                if args[i].type_() == &ast::Type::Unit {
+                    let expr = args.remove(i);
+                    if !matches!(expr, Expr::Literal(Literal::Unit, _)) {
+                        extracted.push(expr)
+                    };
+                } else {
+                    i += 1
+                }
+            }
+        }
+
+        if !extracted.is_empty() {
+            let body = mem::replace(expr, Expr::Literal(Literal::Unit, ast::Type::Unit));
+            *expr = Expr::Let {
+                bindings: extracted
+                    .into_iter()
+                    .map(|expr| Binding {
+                        pat: Pattern::Id(
+                            Ident::from_str_unchecked("___discarded"),
+                            expr.type_().clone(),
+                        ),
+                        body: expr,
+                    })
+                    .collect(),
+                type_: body.type_().clone(),
+                body: Box::new(body),
+            };
+        }
+
+        Ok(())
+    }
+
+    fn post_visit_call(
+        &mut self,
+        _fun: &mut Expr<'ast, ast::Type<'ast>>,
+        _type_args: &mut HashMap<Ident<'ast>, ast::Type<'ast>>,
+        args: &mut Vec<Expr<'ast, ast::Type<'ast>>>,
+    ) -> Result<(), Self::Error> {
+        args.retain(|arg| arg.type_() != &ast::Type::Unit);
+        Ok(())
+    }
+
+    fn visit_type(&mut self, type_: &mut ast::Type<'ast>) -> Result<(), Self::Error> {
+        if let ast::Type::Function(ft) = type_ {
+            ft.args.retain(|a| a != &ast::Type::Unit);
+        }
+        Ok(())
+    }
+
+    fn post_visit_fun_decl(
+        &mut self,
+        _name: &mut Ident<'ast>,
+        _type_args: &mut Vec<Ident>,
+        args: &mut Vec<(Ident, ast::Type<'ast>)>,
+        _body: &mut Box<Expr<ast::Type<'ast>>>,
+        _type_: &mut ast::Type<'ast>,
+    ) -> Result<(), Self::Error> {
+        args.retain(|(_, ty)| ty != &ast::Type::Unit);
+        Ok(())
+    }
+}
+
+pub(crate) fn run_toplevel<'a>(toplevel: &mut Vec<Decl<'a, ast::Type<'a>>>) {
+    let mut pass = StripPositiveUnits {};
+    for decl in toplevel.iter_mut() {
+        pass.visit_decl(decl).void_unwrap();
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::parser::toplevel;
+    use crate::tc::typecheck_toplevel;
+    use pretty_assertions::assert_eq;
+
+    #[test]
+    fn unit_only_arg() {
+        let (_, program) = toplevel(
+            "ty f : fn () -> int
+             fn f _ = 1
+
+             ty main : fn -> int
+             fn main = f ()",
+        )
+        .unwrap();
+
+        let (_, expected) = toplevel(
+            "ty f : fn -> int
+             fn f = 1
+
+             ty main : fn -> int
+             fn main = f()",
+        )
+        .unwrap();
+        let expected = typecheck_toplevel(expected).unwrap();
+
+        let mut program = typecheck_toplevel(program).unwrap();
+        run_toplevel(&mut program);
+
+        assert_eq!(program, expected);
+    }
+
+    #[test]
+    fn unit_and_other_arg() {
+        let (_, program) = toplevel(
+            "ty f : fn (), int -> int
+             fn f _ x = x
+
+             ty main : fn -> int
+             fn main = f () 1",
+        )
+        .unwrap();
+
+        let (_, expected) = toplevel(
+            "ty f : fn int -> int
+             fn f x = x
+
+             ty main : fn -> int
+             fn main = f 1",
+        )
+        .unwrap();
+        let expected = typecheck_toplevel(expected).unwrap();
+
+        let mut program = typecheck_toplevel(program).unwrap();
+        run_toplevel(&mut program);
+
+        assert_eq!(program, expected);
+    }
+
+    #[test]
+    fn unit_expr_and_other_arg() {
+        let (_, program) = toplevel(
+            "ty f : fn (), int -> int
+             fn f _ x = x
+
+             ty g : fn int -> ()
+             fn g _ = ()
+
+             ty main : fn -> int
+             fn main = f (g 2) 1",
+        )
+        .unwrap();
+
+        let (_, expected) = toplevel(
+            "ty f : fn int -> int
+             fn f x = x
+
+             ty g : fn int -> ()
+             fn g _ = ()
+
+             ty main : fn -> int
+             fn main = let ___discarded = g 2 in f 1",
+        )
+        .unwrap();
+        assert_eq!(expected.len(), 6);
+        let expected = typecheck_toplevel(expected).unwrap();
+
+        let mut program = typecheck_toplevel(program).unwrap();
+        run_toplevel(&mut program);
+
+        assert_eq!(program, expected);
+    }
+}