about summary refs log tree commit diff
path: root/users/glittershark/achilles/src
diff options
context:
space:
mode:
Diffstat (limited to 'users/glittershark/achilles/src')
-rw-r--r--users/glittershark/achilles/src/ast/hir.rs2
-rw-r--r--users/glittershark/achilles/src/ast/mod.rs7
-rw-r--r--users/glittershark/achilles/src/codegen/llvm.rs187
-rw-r--r--users/glittershark/achilles/src/compiler.rs5
-rw-r--r--users/glittershark/achilles/src/interpreter/mod.rs1
-rw-r--r--users/glittershark/achilles/src/parser/expr.rs22
-rw-r--r--users/glittershark/achilles/src/parser/mod.rs27
-rw-r--r--users/glittershark/achilles/src/parser/type_.rs2
-rw-r--r--users/glittershark/achilles/src/passes/hir/mod.rs22
-rw-r--r--users/glittershark/achilles/src/passes/hir/strip_positive_units.rs189
-rw-r--r--users/glittershark/achilles/src/tc/mod.rs10
11 files changed, 386 insertions, 88 deletions
diff --git a/users/glittershark/achilles/src/ast/hir.rs b/users/glittershark/achilles/src/ast/hir.rs
index 8726af509388..0d145d620bef 100644
--- a/users/glittershark/achilles/src/ast/hir.rs
+++ b/users/glittershark/achilles/src/ast/hir.rs
@@ -246,7 +246,7 @@ impl<'a, T> Expr<'a, T> {
     }
 }
 
-#[derive(Debug, Clone)]
+#[derive(Debug, Clone, PartialEq, Eq)]
 pub enum Decl<'a, T> {
     Fun {
         name: Ident<'a>,
diff --git a/users/glittershark/achilles/src/ast/mod.rs b/users/glittershark/achilles/src/ast/mod.rs
index 53f222a6a11a..7dc2de895709 100644
--- a/users/glittershark/achilles/src/ast/mod.rs
+++ b/users/glittershark/achilles/src/ast/mod.rs
@@ -30,6 +30,7 @@ impl<'a> Ident<'a> {
         Ident(Cow::Owned(self.0.clone().into_owned()))
     }
 
+    /// Construct an identifier from a &str without checking that it's a valid identifier
     pub fn from_str_unchecked(s: &'a str) -> Self {
         debug_assert!(is_valid_identifier(s));
         Self(Cow::Borrowed(s))
@@ -109,6 +110,7 @@ pub enum UnaryOperator {
 
 #[derive(Debug, PartialEq, Eq, Clone)]
 pub enum Literal<'a> {
+    Unit,
     Int(u64),
     Bool(bool),
     String(Cow<'a, str>),
@@ -120,6 +122,7 @@ impl<'a> Literal<'a> {
             Literal::Int(i) => Literal::Int(*i),
             Literal::Bool(b) => Literal::Bool(*b),
             Literal::String(s) => Literal::String(Cow::Owned(s.clone().into_owned())),
+            Literal::Unit => Literal::Unit,
         }
     }
 }
@@ -308,6 +311,7 @@ pub enum Type<'a> {
     Float,
     Bool,
     CString,
+    Unit,
     Var(Ident<'a>),
     Function(FunctionType<'a>),
 }
@@ -319,6 +323,7 @@ impl<'a> Type<'a> {
             Type::Float => Type::Float,
             Type::Bool => Type::Bool,
             Type::CString => Type::CString,
+            Type::Unit => Type::Unit,
             Type::Var(v) => Type::Var(v.to_owned()),
             Type::Function(f) => Type::Function(f.to_owned()),
         }
@@ -374,6 +379,7 @@ impl<'a> Type<'a> {
             Type::Float => Type::Float,
             Type::Bool => Type::Bool,
             Type::CString => Type::CString,
+            Type::Unit => Type::Unit,
         }
     }
 }
@@ -385,6 +391,7 @@ impl<'a> Display for Type<'a> {
             Type::Float => f.write_str("float"),
             Type::Bool => f.write_str("bool"),
             Type::CString => f.write_str("cstring"),
+            Type::Unit => f.write_str("()"),
             Type::Var(v) => v.fmt(f),
             Type::Function(ft) => ft.fmt(f),
         }
diff --git a/users/glittershark/achilles/src/codegen/llvm.rs b/users/glittershark/achilles/src/codegen/llvm.rs
index f49e084a8174..17dec58b5ff7 100644
--- a/users/glittershark/achilles/src/codegen/llvm.rs
+++ b/users/glittershark/achilles/src/codegen/llvm.rs
@@ -68,8 +68,12 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
         self.function_stack.last().unwrap()
     }
 
-    pub fn finish_function(&mut self, res: &BasicValueEnum<'ctx>) -> FunctionValue<'ctx> {
-        self.builder.build_return(Some(res));
+    pub fn finish_function(&mut self, res: Option<&BasicValueEnum<'ctx>>) -> FunctionValue<'ctx> {
+        self.builder.build_return(match res {
+            // lol
+            Some(val) => Some(val),
+            None => None,
+        });
         self.function_stack.pop().unwrap()
     }
 
@@ -78,79 +82,92 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
             .append_basic_block(*self.function_stack.last().unwrap(), name)
     }
 
-    pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast, Type>) -> Result<AnyValueEnum<'ctx>> {
+    pub fn codegen_expr(
+        &mut self,
+        expr: &'ast Expr<'ast, Type>,
+    ) -> Result<Option<AnyValueEnum<'ctx>>> {
         match expr {
             Expr::Ident(id, _) => self
                 .env
                 .resolve(id)
                 .cloned()
-                .ok_or_else(|| Error::UndefinedVariable(id.to_owned())),
+                .ok_or_else(|| Error::UndefinedVariable(id.to_owned()))
+                .map(Some),
             Expr::Literal(lit, ty) => {
                 let ty = self.codegen_int_type(ty);
                 match lit {
-                    Literal::Int(i) => Ok(AnyValueEnum::IntValue(ty.const_int(*i, false))),
-                    Literal::Bool(b) => Ok(AnyValueEnum::IntValue(
+                    Literal::Int(i) => Ok(Some(AnyValueEnum::IntValue(ty.const_int(*i, false)))),
+                    Literal::Bool(b) => Ok(Some(AnyValueEnum::IntValue(
                         ty.const_int(if *b { 1 } else { 0 }, false),
+                    ))),
+                    Literal::String(s) => Ok(Some(
+                        self.builder
+                            .build_global_string_ptr(s, "s")
+                            .as_pointer_value()
+                            .into(),
                     )),
-                    Literal::String(s) => Ok(self
-                        .builder
-                        .build_global_string_ptr(s, "s")
-                        .as_pointer_value()
-                        .into()),
+                    Literal::Unit => Ok(None),
                 }
             }
             Expr::UnaryOp { op, rhs, .. } => {
-                let rhs = self.codegen_expr(rhs)?;
+                let rhs = self.codegen_expr(rhs)?.unwrap();
                 match op {
                     UnaryOperator::Not => unimplemented!(),
-                    UnaryOperator::Neg => Ok(AnyValueEnum::IntValue(
+                    UnaryOperator::Neg => Ok(Some(AnyValueEnum::IntValue(
                         self.builder.build_int_neg(rhs.into_int_value(), "neg"),
-                    )),
+                    ))),
                 }
             }
             Expr::BinaryOp { lhs, op, rhs, .. } => {
-                let lhs = self.codegen_expr(lhs)?;
-                let rhs = self.codegen_expr(rhs)?;
+                let lhs = self.codegen_expr(lhs)?.unwrap();
+                let rhs = self.codegen_expr(rhs)?.unwrap();
                 match op {
-                    BinaryOperator::Add => Ok(AnyValueEnum::IntValue(self.builder.build_int_add(
-                        lhs.into_int_value(),
-                        rhs.into_int_value(),
-                        "add",
-                    ))),
-                    BinaryOperator::Sub => Ok(AnyValueEnum::IntValue(self.builder.build_int_sub(
-                        lhs.into_int_value(),
-                        rhs.into_int_value(),
-                        "add",
-                    ))),
-                    BinaryOperator::Mul => Ok(AnyValueEnum::IntValue(self.builder.build_int_sub(
-                        lhs.into_int_value(),
-                        rhs.into_int_value(),
-                        "add",
-                    ))),
-                    BinaryOperator::Div => {
-                        Ok(AnyValueEnum::IntValue(self.builder.build_int_signed_div(
+                    BinaryOperator::Add => {
+                        Ok(Some(AnyValueEnum::IntValue(self.builder.build_int_add(
                             lhs.into_int_value(),
                             rhs.into_int_value(),
                             "add",
-                        )))
+                        ))))
                     }
+                    BinaryOperator::Sub => {
+                        Ok(Some(AnyValueEnum::IntValue(self.builder.build_int_sub(
+                            lhs.into_int_value(),
+                            rhs.into_int_value(),
+                            "add",
+                        ))))
+                    }
+                    BinaryOperator::Mul => {
+                        Ok(Some(AnyValueEnum::IntValue(self.builder.build_int_sub(
+                            lhs.into_int_value(),
+                            rhs.into_int_value(),
+                            "add",
+                        ))))
+                    }
+                    BinaryOperator::Div => Ok(Some(AnyValueEnum::IntValue(
+                        self.builder.build_int_signed_div(
+                            lhs.into_int_value(),
+                            rhs.into_int_value(),
+                            "add",
+                        ),
+                    ))),
                     BinaryOperator::Pow => unimplemented!(),
-                    BinaryOperator::Equ => {
-                        Ok(AnyValueEnum::IntValue(self.builder.build_int_compare(
+                    BinaryOperator::Equ => Ok(Some(AnyValueEnum::IntValue(
+                        self.builder.build_int_compare(
                             IntPredicate::EQ,
                             lhs.into_int_value(),
                             rhs.into_int_value(),
                             "eq",
-                        )))
-                    }
+                        ),
+                    ))),
                     BinaryOperator::Neq => todo!(),
                 }
             }
             Expr::Let { bindings, body, .. } => {
                 self.env.push();
                 for Binding { ident, body, .. } in bindings {
-                    let val = self.codegen_expr(body)?;
-                    self.env.set(ident, val);
+                    if let Some(val) = self.codegen_expr(body)? {
+                        self.env.set(ident, val);
+                    }
                 }
                 let res = self.codegen_expr(body);
                 self.env.pop();
@@ -165,7 +182,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                 let then_block = self.append_basic_block("then");
                 let else_block = self.append_basic_block("else");
                 let join_block = self.append_basic_block("join");
-                let condition = self.codegen_expr(condition)?;
+                let condition = self.codegen_expr(condition)?.unwrap();
                 self.builder.build_conditional_branch(
                     condition.into_int_value(),
                     then_block,
@@ -180,12 +197,22 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                 self.builder.build_unconditional_branch(join_block);
 
                 self.builder.position_at_end(join_block);
-                let phi = self.builder.build_phi(self.codegen_type(type_), "join");
-                phi.add_incoming(&[
-                    (&BasicValueEnum::try_from(then_res).unwrap(), then_block),
-                    (&BasicValueEnum::try_from(else_res).unwrap(), else_block),
-                ]);
-                Ok(phi.as_basic_value().into())
+                if let Some(phi_type) = self.codegen_type(type_) {
+                    let phi = self.builder.build_phi(phi_type, "join");
+                    phi.add_incoming(&[
+                        (
+                            &BasicValueEnum::try_from(then_res.unwrap()).unwrap(),
+                            then_block,
+                        ),
+                        (
+                            &BasicValueEnum::try_from(else_res.unwrap()).unwrap(),
+                            else_block,
+                        ),
+                    ]);
+                    Ok(Some(phi.as_basic_value().into()))
+                } else {
+                    Ok(None)
+                }
             }
             Expr::Call { fun, args, .. } => {
                 if let Expr::Ident(id, _) = &**fun {
@@ -196,15 +223,14 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                         .ok_or_else(|| Error::UndefinedVariable(id.to_owned()))?;
                     let args = args
                         .iter()
-                        .map(|arg| Ok(self.codegen_expr(arg)?.try_into().unwrap()))
+                        .map(|arg| Ok(self.codegen_expr(arg)?.unwrap().try_into().unwrap()))
                         .collect::<Result<Vec<_>>>()?;
                     Ok(self
                         .builder
                         .build_call(function, &args, "call")
                         .try_as_basic_value()
                         .left()
-                        .unwrap()
-                        .into())
+                        .map(|val| val.into()))
                 } else {
                     todo!()
                 }
@@ -216,7 +242,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                 let function = self.codegen_function(&fname, args, body)?;
                 self.builder.position_at_end(cur_block);
                 self.env.restore(env);
-                Ok(function.into())
+                Ok(Some(function.into()))
             }
         }
     }
@@ -227,15 +253,17 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
         args: &'ast [(Ident<'ast>, Type)],
         body: &'ast Expr<'ast, Type>,
     ) -> Result<FunctionValue<'ctx>> {
+        let arg_types = args
+            .iter()
+            .filter_map(|(_, at)| self.codegen_type(at))
+            .collect::<Vec<_>>();
+
         self.new_function(
             name,
-            self.codegen_type(body.type_()).fn_type(
-                args.iter()
-                    .map(|(_, at)| self.codegen_type(at))
-                    .collect::<Vec<_>>()
-                    .as_slice(),
-                false,
-            ),
+            match self.codegen_type(body.type_()) {
+                Some(ret_ty) => ret_ty.fn_type(&arg_types, false),
+                None => self.context.void_type().fn_type(&arg_types, false),
+            },
         );
         self.env.push();
         for (i, (arg, _)) in args.iter().enumerate() {
@@ -244,9 +272,9 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                 self.cur_function().get_nth_param(i as u32).unwrap().into(),
             );
         }
-        let res = self.codegen_expr(body)?.try_into().unwrap();
+        let res = self.codegen_expr(body)?;
         self.env.pop();
-        Ok(self.finish_function(&res))
+        Ok(self.finish_function(res.map(|av| av.try_into().unwrap()).as_ref()))
     }
 
     pub fn codegen_extern(
@@ -255,15 +283,16 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
         args: &'ast [Type],
         ret: &'ast Type,
     ) -> Result<()> {
+        let arg_types = args
+            .iter()
+            .map(|t| self.codegen_type(t).unwrap())
+            .collect::<Vec<_>>();
         self.module.add_function(
             name,
-            self.codegen_type(ret).fn_type(
-                &args
-                    .iter()
-                    .map(|t| self.codegen_type(t))
-                    .collect::<Vec<_>>(),
-                false,
-            ),
+            match self.codegen_type(ret) {
+                Some(ret_ty) => ret_ty.fn_type(&arg_types, false),
+                None => self.context.void_type().fn_type(&arg_types, false),
+            },
             None,
         );
         Ok(())
@@ -287,29 +316,31 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
 
     pub fn codegen_main(&mut self, expr: &'ast Expr<'ast, Type>) -> Result<()> {
         self.new_function("main", self.context.i64_type().fn_type(&[], false));
-        let res = self.codegen_expr(expr)?.try_into().unwrap();
+        let res = self.codegen_expr(expr)?;
         if *expr.type_() != Type::Int {
             self.builder
                 .build_return(Some(&self.context.i64_type().const_int(0, false)));
         } else {
-            self.finish_function(&res);
+            self.finish_function(res.map(|r| r.try_into().unwrap()).as_ref());
         }
         Ok(())
     }
 
-    fn codegen_type(&self, type_: &'ast Type) -> BasicTypeEnum<'ctx> {
+    fn codegen_type(&self, type_: &'ast Type) -> Option<BasicTypeEnum<'ctx>> {
         // TODO
         match type_ {
-            Type::Int => self.context.i64_type().into(),
-            Type::Float => self.context.f64_type().into(),
-            Type::Bool => self.context.bool_type().into(),
-            Type::CString => self
-                .context
-                .i8_type()
-                .ptr_type(AddressSpace::Generic)
-                .into(),
+            Type::Int => Some(self.context.i64_type().into()),
+            Type::Float => Some(self.context.f64_type().into()),
+            Type::Bool => Some(self.context.bool_type().into()),
+            Type::CString => Some(
+                self.context
+                    .i8_type()
+                    .ptr_type(AddressSpace::Generic)
+                    .into(),
+            ),
             Type::Function(_) => todo!(),
             Type::Var(_) => unreachable!(),
+            Type::Unit => None,
         }
     }
 
diff --git a/users/glittershark/achilles/src/compiler.rs b/users/glittershark/achilles/src/compiler.rs
index 7001e5a9a384..45b215473d7f 100644
--- a/users/glittershark/achilles/src/compiler.rs
+++ b/users/glittershark/achilles/src/compiler.rs
@@ -8,7 +8,7 @@ use test_strategy::Arbitrary;
 
 use crate::codegen::{self, Codegen};
 use crate::common::Result;
-use crate::passes::hir::monomorphize;
+use crate::passes::hir::{monomorphize, strip_positive_units};
 use crate::{parser, tc};
 
 #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Arbitrary)]
@@ -55,9 +55,10 @@ pub struct CompilerOptions {
 
 pub fn compile_file(input: &Path, output: &Path, options: &CompilerOptions) -> Result<()> {
     let src = fs::read_to_string(input)?;
-    let (_, decls) = parser::toplevel(&src)?; // TODO: statements
+    let (_, decls) = parser::toplevel(&src)?;
     let mut decls = tc::typecheck_toplevel(decls)?;
     monomorphize::run_toplevel(&mut decls);
+    strip_positive_units::run_toplevel(&mut decls);
 
     let context = codegen::Context::create();
     let mut codegen = Codegen::new(
diff --git a/users/glittershark/achilles/src/interpreter/mod.rs b/users/glittershark/achilles/src/interpreter/mod.rs
index bcd474b3abe1..a8ba2dd3acdc 100644
--- a/users/glittershark/achilles/src/interpreter/mod.rs
+++ b/users/glittershark/achilles/src/interpreter/mod.rs
@@ -30,6 +30,7 @@ impl<'a> Interpreter<'a> {
             Expr::Literal(Literal::Int(i), _) => Ok((*i).into()),
             Expr::Literal(Literal::Bool(b), _) => Ok((*b).into()),
             Expr::Literal(Literal::String(s), _) => Ok(s.clone().into()),
+            Expr::Literal(Literal::Unit, _) => unreachable!(),
             Expr::UnaryOp { op, rhs, .. } => {
                 let rhs = self.eval(rhs)?;
                 match op {
diff --git a/users/glittershark/achilles/src/parser/expr.rs b/users/glittershark/achilles/src/parser/expr.rs
index 99c8018fd00c..8a28d00984c9 100644
--- a/users/glittershark/achilles/src/parser/expr.rs
+++ b/users/glittershark/achilles/src/parser/expr.rs
@@ -186,7 +186,9 @@ named!(string(&str) -> Literal, preceded!(
     )
 ));
 
-named!(literal(&str) -> Literal, alt!(int | bool_ | string));
+named!(unit(&str) -> Literal, map!(complete!(tag!("()")), |_| Literal::Unit));
+
+named!(literal(&str) -> Literal, alt!(int | bool_ | string | unit));
 
 named!(literal_expr(&str) -> Expr, map!(literal, Expr::Literal));
 
@@ -270,7 +272,6 @@ named!(funcref(&str) -> Expr, alt!(
 
 named!(no_arg_call(&str) -> Expr, do_parse!(
     fun: funcref
-        >> multispace0
         >> complete!(tag!("()"))
         >> (Expr::Call {
             fun: Box::new(fun),
@@ -432,6 +433,11 @@ pub(crate) mod tests {
     }
 
     #[test]
+    fn unit() {
+        assert_eq!(test_parse!(expr, "()"), Expr::Literal(Literal::Unit));
+    }
+
+    #[test]
     fn bools() {
         assert_eq!(
             test_parse!(expr, "true"),
@@ -516,6 +522,18 @@ pub(crate) mod tests {
     }
 
     #[test]
+    fn unit_call() {
+        let res = test_parse!(expr, "f ()");
+        assert_eq!(
+            res,
+            Expr::Call {
+                fun: ident_expr("f"),
+                args: vec![Expr::Literal(Literal::Unit)]
+            }
+        )
+    }
+
+    #[test]
     fn call_with_args() {
         let res = test_parse!(expr, "f x 1");
         assert_eq!(
diff --git a/users/glittershark/achilles/src/parser/mod.rs b/users/glittershark/achilles/src/parser/mod.rs
index 652b083fdae5..3e0081bd391d 100644
--- a/users/glittershark/achilles/src/parser/mod.rs
+++ b/users/glittershark/achilles/src/parser/mod.rs
@@ -1,9 +1,9 @@
 use nom::character::complete::{multispace0, multispace1};
 use nom::error::{ErrorKind, ParseError};
-use nom::{alt, char, complete, do_parse, many0, named, separated_list0, tag, terminated};
+use nom::{alt, char, complete, do_parse, eof, many0, named, separated_list0, tag, terminated};
 
 #[macro_use]
-mod macros;
+pub(crate) mod macros;
 mod expr;
 mod type_;
 
@@ -136,7 +136,11 @@ named!(pub decl(&str) -> Decl, alt!(
     extern_decl
 ));
 
-named!(pub toplevel(&str) -> Vec<Decl>, terminated!(many0!(decl), multispace0));
+named!(pub toplevel(&str) -> Vec<Decl>, do_parse!(
+    decls: many0!(decl)
+        >> multispace0
+        >> eof!()
+        >> (decls)));
 
 #[cfg(test)]
 mod tests {
@@ -215,4 +219,21 @@ mod tests {
             }]
         )
     }
+
+    #[test]
+    fn return_unit() {
+        assert_eq!(
+            test_parse!(decl, "fn g _ = ()"),
+            Decl::Fun {
+                name: "g".try_into().unwrap(),
+                body: Fun {
+                    args: vec![Arg {
+                        ident: "_".try_into().unwrap(),
+                        type_: None,
+                    }],
+                    body: Expr::Literal(Literal::Unit),
+                },
+            }
+        )
+    }
 }
diff --git a/users/glittershark/achilles/src/parser/type_.rs b/users/glittershark/achilles/src/parser/type_.rs
index 1e6e380bb823..8a1081e2521f 100644
--- a/users/glittershark/achilles/src/parser/type_.rs
+++ b/users/glittershark/achilles/src/parser/type_.rs
@@ -29,6 +29,7 @@ named!(pub type_(&str) -> Type, alt!(
     tag!("float") => { |_| Type::Float } |
     tag!("bool") => { |_| Type::Bool } |
     tag!("cstring") => { |_| Type::CString } |
+    tag!("()") => { |_| Type::Unit } |
     function_type => { |ft| Type::Function(ft) }|
     ident => { |id| Type::Var(id) } |
     delimited!(
@@ -51,6 +52,7 @@ mod tests {
         assert_eq!(test_parse!(type_, "float"), Type::Float);
         assert_eq!(test_parse!(type_, "bool"), Type::Bool);
         assert_eq!(test_parse!(type_, "cstring"), Type::CString);
+        assert_eq!(test_parse!(type_, "()"), Type::Unit);
     }
 
     #[test]
diff --git a/users/glittershark/achilles/src/passes/hir/mod.rs b/users/glittershark/achilles/src/passes/hir/mod.rs
index fb2f64e08591..845bfcb7ab6a 100644
--- a/users/glittershark/achilles/src/passes/hir/mod.rs
+++ b/users/glittershark/achilles/src/passes/hir/mod.rs
@@ -4,6 +4,7 @@ use crate::ast::hir::{Binding, Decl, Expr};
 use crate::ast::{BinaryOperator, Ident, Literal, UnaryOperator};
 
 pub(crate) mod monomorphize;
+pub(crate) mod strip_positive_units;
 
 pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a {
     type Error;
@@ -53,7 +54,12 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a {
         Ok(())
     }
 
+    fn pre_visit_expr(&mut self, _expr: &mut Expr<'ast, T>) -> Result<(), Self::Error> {
+        Ok(())
+    }
+
     fn visit_expr(&mut self, expr: &mut Expr<'ast, T>) -> Result<(), Self::Error> {
+        self.pre_visit_expr(expr)?;
         match expr {
             Expr::Ident(id, t) => {
                 self.visit_ident(id)?;
@@ -140,6 +146,17 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a {
         Ok(())
     }
 
+    fn post_visit_fun_decl(
+        &mut self,
+        _name: &mut Ident<'ast>,
+        _type_args: &mut Vec<Ident>,
+        _args: &mut Vec<(Ident, T)>,
+        _body: &mut Box<Expr<T>>,
+        _type_: &mut T,
+    ) -> Result<(), Self::Error> {
+        Ok(())
+    }
+
     fn visit_decl(&mut self, decl: &'a mut Decl<'ast, T>) -> Result<(), Self::Error> {
         match decl {
             Decl::Fun {
@@ -150,15 +167,16 @@ pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a {
                 type_,
             } => {
                 self.visit_ident(name)?;
-                for type_arg in type_args {
+                for type_arg in type_args.iter_mut() {
                     self.visit_ident(type_arg)?;
                 }
-                for (arg, t) in args {
+                for (arg, t) in args.iter_mut() {
                     self.visit_ident(arg)?;
                     self.visit_type(t)?;
                 }
                 self.visit_expr(body)?;
                 self.visit_type(type_)?;
+                self.post_visit_fun_decl(name, type_args, args, body, type_)?;
             }
             Decl::Extern {
                 name,
diff --git a/users/glittershark/achilles/src/passes/hir/strip_positive_units.rs b/users/glittershark/achilles/src/passes/hir/strip_positive_units.rs
new file mode 100644
index 000000000000..91b56551c82d
--- /dev/null
+++ b/users/glittershark/achilles/src/passes/hir/strip_positive_units.rs
@@ -0,0 +1,189 @@
+use std::collections::HashMap;
+use std::mem;
+
+use ast::hir::Binding;
+use ast::Literal;
+use void::{ResultVoidExt, Void};
+
+use crate::ast::hir::{Decl, Expr};
+use crate::ast::{self, Ident};
+
+use super::Visitor;
+
+/// Strip all values with a unit type in positive (non-return) position
+pub(crate) struct StripPositiveUnits {}
+
+impl<'a, 'ast> Visitor<'a, 'ast, ast::Type<'ast>> for StripPositiveUnits {
+    type Error = Void;
+
+    fn pre_visit_expr(
+        &mut self,
+        expr: &mut Expr<'ast, ast::Type<'ast>>,
+    ) -> Result<(), Self::Error> {
+        let mut extracted = vec![];
+        if let Expr::Call { args, .. } = expr {
+            // TODO(grfn): replace with drain_filter once it's stabilized
+            let mut i = 0;
+            while i != args.len() {
+                if args[i].type_() == &ast::Type::Unit {
+                    let expr = args.remove(i);
+                    if !matches!(expr, Expr::Literal(Literal::Unit, _)) {
+                        extracted.push(expr)
+                    };
+                } else {
+                    i += 1
+                }
+            }
+        }
+
+        if !extracted.is_empty() {
+            let body = mem::replace(expr, Expr::Literal(Literal::Unit, ast::Type::Unit));
+            *expr = Expr::Let {
+                bindings: extracted
+                    .into_iter()
+                    .map(|expr| Binding {
+                        ident: Ident::from_str_unchecked("___discarded"),
+                        type_: expr.type_().clone(),
+                        body: expr,
+                    })
+                    .collect(),
+                type_: body.type_().clone(),
+                body: Box::new(body),
+            };
+        }
+
+        Ok(())
+    }
+
+    fn post_visit_call(
+        &mut self,
+        _fun: &mut Expr<'ast, ast::Type<'ast>>,
+        _type_args: &mut HashMap<Ident<'ast>, ast::Type<'ast>>,
+        args: &mut Vec<Expr<'ast, ast::Type<'ast>>>,
+    ) -> Result<(), Self::Error> {
+        args.retain(|arg| arg.type_() != &ast::Type::Unit);
+        Ok(())
+    }
+
+    fn visit_type(&mut self, type_: &mut ast::Type<'ast>) -> Result<(), Self::Error> {
+        if let ast::Type::Function(ft) = type_ {
+            ft.args.retain(|a| a != &ast::Type::Unit);
+        }
+        Ok(())
+    }
+
+    fn post_visit_fun_decl(
+        &mut self,
+        _name: &mut Ident<'ast>,
+        _type_args: &mut Vec<Ident>,
+        args: &mut Vec<(Ident, ast::Type<'ast>)>,
+        _body: &mut Box<Expr<ast::Type<'ast>>>,
+        _type_: &mut ast::Type<'ast>,
+    ) -> Result<(), Self::Error> {
+        args.retain(|(_, ty)| ty != &ast::Type::Unit);
+        Ok(())
+    }
+}
+
+pub(crate) fn run_toplevel<'a>(toplevel: &mut Vec<Decl<'a, ast::Type<'a>>>) {
+    let mut pass = StripPositiveUnits {};
+    for decl in toplevel.iter_mut() {
+        pass.visit_decl(decl).void_unwrap();
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use crate::parser::toplevel;
+    use crate::tc::typecheck_toplevel;
+    use pretty_assertions::assert_eq;
+
+    #[test]
+    fn unit_only_arg() {
+        let (_, program) = toplevel(
+            "ty f : fn () -> int
+             fn f _ = 1
+
+             ty main : fn -> int
+             fn main = f ()",
+        )
+        .unwrap();
+
+        let (_, expected) = toplevel(
+            "ty f : fn -> int
+             fn f = 1
+
+             ty main : fn -> int
+             fn main = f()",
+        )
+        .unwrap();
+        let expected = typecheck_toplevel(expected).unwrap();
+
+        let mut program = typecheck_toplevel(program).unwrap();
+        run_toplevel(&mut program);
+
+        assert_eq!(program, expected);
+    }
+
+    #[test]
+    fn unit_and_other_arg() {
+        let (_, program) = toplevel(
+            "ty f : fn (), int -> int
+             fn f _ x = x
+
+             ty main : fn -> int
+             fn main = f () 1",
+        )
+        .unwrap();
+
+        let (_, expected) = toplevel(
+            "ty f : fn int -> int
+             fn f x = x
+
+             ty main : fn -> int
+             fn main = f 1",
+        )
+        .unwrap();
+        let expected = typecheck_toplevel(expected).unwrap();
+
+        let mut program = typecheck_toplevel(program).unwrap();
+        run_toplevel(&mut program);
+
+        assert_eq!(program, expected);
+    }
+
+    #[test]
+    fn unit_expr_and_other_arg() {
+        let (_, program) = toplevel(
+            "ty f : fn (), int -> int
+             fn f _ x = x
+
+             ty g : fn int -> ()
+             fn g _ = ()
+
+             ty main : fn -> int
+             fn main = f (g 2) 1",
+        )
+        .unwrap();
+
+        let (_, expected) = toplevel(
+            "ty f : fn int -> int
+             fn f x = x
+
+             ty g : fn int -> ()
+             fn g _ = ()
+
+             ty main : fn -> int
+             fn main = let ___discarded = g 2 in f 1",
+        )
+        .unwrap();
+        assert_eq!(expected.len(), 6);
+        let expected = typecheck_toplevel(expected).unwrap();
+
+        let mut program = typecheck_toplevel(program).unwrap();
+        run_toplevel(&mut program);
+
+        assert_eq!(program, expected);
+    }
+}
diff --git a/users/glittershark/achilles/src/tc/mod.rs b/users/glittershark/achilles/src/tc/mod.rs
index 137561978f00..d27c45075e97 100644
--- a/users/glittershark/achilles/src/tc/mod.rs
+++ b/users/glittershark/achilles/src/tc/mod.rs
@@ -85,6 +85,7 @@ pub enum Type {
     Exist(TyVar),
     Nullary(NullaryType),
     Prim(PrimType),
+    Unit,
     Fun {
         args: Vec<Type>,
         ret: Box<Type>,
@@ -96,6 +97,7 @@ impl<'a> TryFrom<Type> for ast::Type<'a> {
 
     fn try_from(value: Type) -> result::Result<Self, Self::Error> {
         match value {
+            Type::Unit => Ok(ast::Type::Unit),
             Type::Univ(_) => todo!(),
             Type::Exist(_) => Err(value),
             Type::Nullary(_) => todo!(),
@@ -126,6 +128,7 @@ impl Display for Type {
             Type::Univ(TyVar(n)) => write!(f, "∀{}", n),
             Type::Exist(TyVar(n)) => write!(f, "∃{}", n),
             Type::Fun { args, ret } => write!(f, "fn {} -> {}", args.iter().join(", "), ret),
+            Type::Unit => write!(f, "()"),
         }
     }
 }
@@ -171,6 +174,7 @@ impl<'ast> Typechecker<'ast> {
                     Literal::Int(_) => Type::Prim(PrimType::Int),
                     Literal::Bool(_) => Type::Prim(PrimType::Bool),
                     Literal::String(_) => Type::Prim(PrimType::CString),
+                    Literal::Unit => Type::Unit,
                 };
                 Ok(hir::Expr::Literal(lit.to_owned(), type_))
             }
@@ -377,6 +381,7 @@ impl<'ast> Typechecker<'ast> {
 
     fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> {
         match (ty1, ty2) {
+            (Type::Unit, Type::Unit) => Ok(Type::Unit),
             (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) {
                 Some(existing_ty) if self.types_match(ty, &existing_ty) => Ok(ty.clone()),
                 Some(var @ ast::Type::Var(_)) => {
@@ -466,6 +471,7 @@ impl<'ast> Typechecker<'ast> {
         let ret = match ty {
             Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)),
             Type::Univ(tv) => Ok(ast::Type::Var(self.name_univ(tv))),
+            Type::Unit => Ok(ast::Type::Unit),
             Type::Nullary(_) => todo!(),
             Type::Prim(pr) => Ok(pr.into()),
             Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType {
@@ -496,6 +502,7 @@ impl<'ast> Typechecker<'ast> {
                 }
                 Type::Nullary(_) => todo!(),
                 Type::Prim(pr) => break Some((*pr).into()),
+                Type::Unit => break Some(ast::Type::Unit),
                 Type::Fun { args, ret } => todo!(),
             }
         }
@@ -503,6 +510,7 @@ impl<'ast> Typechecker<'ast> {
 
     fn type_from_ast_type(&mut self, ast_type: ast::Type<'ast>) -> Type {
         match ast_type {
+            ast::Type::Unit => Type::Unit,
             ast::Type::Int => INT,
             ast::Type::Float => FLOAT,
             ast::Type::Bool => BOOL,
@@ -570,6 +578,8 @@ impl<'ast> Typechecker<'ast> {
             }
             (Type::Univ(_), _) => false,
             (Type::Exist(_), _) => false,
+            (Type::Unit, ast::Type::Unit) => true,
+            (Type::Unit, _) => false,
             (Type::Nullary(_), _) => todo!(),
             (Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty,
             (Type::Fun { args, ret }, ast::Type::Function(ft)) => {