about summary refs log tree commit diff
path: root/users/glittershark/achilles/src/passes/hir
diff options
context:
space:
mode:
Diffstat (limited to 'users/glittershark/achilles/src/passes/hir')
-rw-r--r--users/glittershark/achilles/src/passes/hir/mod.rs22
-rw-r--r--users/glittershark/achilles/src/passes/hir/strip_positive_units.rs189
2 files changed, 209 insertions, 2 deletions
diff --git a/users/glittershark/achilles/src/passes/hir/mod.rs b/users/glittershark/achilles/src/passes/hir/mod.rs
index fb2f64e08591..845bfcb7ab6a 100644
--- a/users/glittershark/achilles/src/passes/hir/mod.rs
+++ b/users/glittershark/achilles/src/passes/hir/mod.rs
@@ -4,6 +4,7 @@ use crate::ast::hir::{Binding, Decl, Expr};
 use crate::ast::{BinaryOperator, Ident, Literal, UnaryOperator};
 
 pub(crate) mod monomorphize;
+pub(crate) mod strip_positive_units;
 
 pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a {
     type Error;
@@ -53,7 +54,12 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a {
         Ok(())
     }
 
+    fn pre_visit_expr(&mut self, _expr: &mut Expr<'ast, T>) -> Result<(), Self::Error> {
+        Ok(())
+    }
+
     fn visit_expr(&mut self, expr: &mut Expr<'ast, T>) -> Result<(), Self::Error> {
+        self.pre_visit_expr(expr)?;
         match expr {
             Expr::Ident(id, t) => {
                 self.visit_ident(id)?;
@@ -140,6 +146,17 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a {
         Ok(())
     }
 
+    fn post_visit_fun_decl(
+        &mut self,
+        _name: &mut Ident<'ast>,
+        _type_args: &mut Vec<Ident>,
+        _args: &mut Vec<(Ident, T)>,
+        _body: &mut Box<Expr<T>>,
+        _type_: &mut T,
+    ) -> Result<(), Self::Error> {
+        Ok(())
+    }
+
     fn visit_decl(&mut self, decl: &'a mut Decl<'ast, T>) -> Result<(), Self::Error> {
         match decl {
             Decl::Fun {
@@ -150,15 +167,16 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a {
                 type_,
             } => {
                 self.visit_ident(name)?;
-                for type_arg in type_args {
+                for type_arg in type_args.iter_mut() {
                     self.visit_ident(type_arg)?;
                 }
-                for (arg, t) in args {
+                for (arg, t) in args.iter_mut() {
                     self.visit_ident(arg)?;
                     self.visit_type(t)?;
                 }
                 self.visit_expr(body)?;
                 self.visit_type(type_)?;
+                self.post_visit_fun_decl(name, type_args, args, body, type_)?;
             }
             Decl::Extern {
                 name,
diff --git a/users/glittershark/achilles/src/passes/hir/strip_positive_units.rs b/users/glittershark/achilles/src/passes/hir/strip_positive_units.rs
new file mode 100644
index 000000000000..91b56551c82d
--- /dev/null
+++ b/users/glittershark/achilles/src/passes/hir/strip_positive_units.rs
@@ -0,0 +1,189 @@
+use std::collections::HashMap;
+use std::mem;
+
+use ast::hir::Binding;
+use ast::Literal;
+use void::{ResultVoidExt, Void};
+
+use crate::ast::hir::{Decl, Expr};
+use crate::ast::{self, Ident};
+
+use super::Visitor;
+
+/// Strip all values with a unit type in positive (non-return) position
+pub(crate) struct StripPositiveUnits {}
+
+impl<'a, 'ast> Visitor<'a, 'ast, ast::Type<'ast>> for StripPositiveUnits {
+    type Error = Void;
+
+    fn pre_visit_expr(
+        &mut self,
+        expr: &mut Expr<'ast, ast::Type<'ast>>,
+    ) -> Result<(), Self::Error> {
+        let mut extracted = vec![];
+        if let Expr::Call { args, .. } = expr {
+            // TODO(grfn): replace with drain_filter once it's stabilized
+            let mut i = 0;
+            while i != args.len() {
+                if args[i].type_() == &ast::Type::Unit {
+                    let expr = args.remove(i);
+                    if !matches!(expr, Expr::Literal(Literal::Unit, _)) {
+                        extracted.push(expr)
+                    };
+                } else {
+                    i += 1
+                }
+            }
+        }
+
+        if !extracted.is_empty() {
+            let body = mem::replace(expr, Expr::Literal(Literal::Unit, ast::Type::Unit));
+            *expr = Expr::Let {
+                bindings: extracted
+                    .into_iter()
+                    .map(|expr| Binding {
+                        ident: Ident::from_str_unchecked("___discarded"),
+                        type_: expr.type_().clone(),
+                        body: expr,
+                    })
+                    .collect(),
+                type_: body.type_().clone(),
+                body: Box::new(body),
+            };
+        }
+
+        Ok(())
+    }
+
+    fn post_visit_call(
+        &mut self,
+        _fun: &mut Expr<'ast, ast::Type<'ast>>,
+        _type_args: &mut HashMap<Ident<'ast>, ast::Type<'ast>>,
+        args: &mut Vec<Expr<'ast, ast::Type<'ast>>>,
+    ) -> Result<(), Self::Error> {
+        args.retain(|arg| arg.type_() != &ast::Type::Unit);
+        Ok(())
+    }
+
+    fn visit_type(&mut self, type_: &mut ast::Type<'ast>) -> Result<(), Self::Error> {
+        if let ast::Type::Function(ft) = type_ {
+            ft.args.retain(|a| a != &ast::Type::Unit);
+        }
+        Ok(())
+    }
+
+    fn post_visit_fun_decl(
+        &mut self,
+        _name: &mut Ident<'ast>,
+        _type_args: &mut Vec<Ident>,
+        args: &mut Vec<(Ident, ast::Type<'ast>)>,
+        _body: &mut Box<Expr<ast::Type<'ast>>>,
+        _type_: &mut ast::Type<'ast>,
+    ) -> Result<(), Self::Error> {
+        args.retain(|(_, ty)| ty != &ast::Type::Unit);
+        Ok(())
+    }
+}
+
+pub(crate) fn run_toplevel<'a>(toplevel: &mut Vec<Decl<'a, ast::Type<'a>>>) {
+    let mut pass = StripPositiveUnits {};
+    for decl in toplevel.iter_mut() {
+        pass.visit_decl(decl).void_unwrap();
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::parser::toplevel;
+    use crate::tc::typecheck_toplevel;
+    use pretty_assertions::assert_eq;
+
+    #[test]
+    fn unit_only_arg() {
+        let (_, program) = toplevel(
+            "ty f : fn () -> int
+             fn f _ = 1
+
+             ty main : fn -> int
+             fn main = f ()",
+        )
+        .unwrap();
+
+        let (_, expected) = toplevel(
+            "ty f : fn -> int
+             fn f = 1
+
+             ty main : fn -> int
+             fn main = f()",
+        )
+        .unwrap();
+        let expected = typecheck_toplevel(expected).unwrap();
+
+        let mut program = typecheck_toplevel(program).unwrap();
+        run_toplevel(&mut program);
+
+        assert_eq!(program, expected);
+    }
+
+    #[test]
+    fn unit_and_other_arg() {
+        let (_, program) = toplevel(
+            "ty f : fn (), int -> int
+             fn f _ x = x
+
+             ty main : fn -> int
+             fn main = f () 1",
+        )
+        .unwrap();
+
+        let (_, expected) = toplevel(
+            "ty f : fn int -> int
+             fn f x = x
+
+             ty main : fn -> int
+             fn main = f 1",
+        )
+        .unwrap();
+        let expected = typecheck_toplevel(expected).unwrap();
+
+        let mut program = typecheck_toplevel(program).unwrap();
+        run_toplevel(&mut program);
+
+        assert_eq!(program, expected);
+    }
+
+    #[test]
+    fn unit_expr_and_other_arg() {
+        let (_, program) = toplevel(
+            "ty f : fn (), int -> int
+             fn f _ x = x
+
+             ty g : fn int -> ()
+             fn g _ = ()
+
+             ty main : fn -> int
+             fn main = f (g 2) 1",
+        )
+        .unwrap();
+
+        let (_, expected) = toplevel(
+            "ty f : fn int -> int
+             fn f x = x
+
+             ty g : fn int -> ()
+             fn g _ = ()
+
+             ty main : fn -> int
+             fn main = let ___discarded = g 2 in f 1",
+        )
+        .unwrap();
+        assert_eq!(expected.len(), 6);
+        let expected = typecheck_toplevel(expected).unwrap();
+
+        let mut program = typecheck_toplevel(program).unwrap();
+        run_toplevel(&mut program);
+
+        assert_eq!(program, expected);
+    }
+}