about summary refs log tree commit diff
path: root/src/codegen/llvm.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/codegen/llvm.rs')
-rw-r--r--src/codegen/llvm.rs181
1 files changed, 121 insertions, 60 deletions
diff --git a/src/codegen/llvm.rs b/src/codegen/llvm.rs
index ff92c33731..1d1c742a94 100644
--- a/src/codegen/llvm.rs
+++ b/src/codegen/llvm.rs
@@ -1,3 +1,4 @@
+use std::convert::{TryFrom, TryInto};
 use std::path::Path;
 use std::result;
 
@@ -7,7 +8,7 @@ pub use inkwell::context::Context;
 use inkwell::module::Module;
 use inkwell::support::LLVMString;
 use inkwell::types::FunctionType;
-use inkwell::values::{BasicValueEnum, FunctionValue};
+use inkwell::values::{AnyValueEnum, BasicValueEnum, FunctionValue};
 use inkwell::IntPredicate;
 use thiserror::Error;
 
@@ -35,8 +36,9 @@ pub struct Codegen<'ctx, 'ast> {
     context: &'ctx Context,
     pub module: Module<'ctx>,
     builder: Builder<'ctx>,
-    env: Env<'ast, BasicValueEnum<'ctx>>,
-    function: Option<FunctionValue<'ctx>>,
+    env: Env<'ast, AnyValueEnum<'ctx>>,
+    function_stack: Vec<FunctionValue<'ctx>>,
+    identifier_counter: u32,
 }
 
 impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
@@ -48,7 +50,8 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
             module,
             builder,
             env: Default::default(),
-            function: None,
+            function_stack: Default::default(),
+            identifier_counter: 0,
         }
     }
 
@@ -57,22 +60,24 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
         name: &str,
         ty: FunctionType<'ctx>,
     ) -> &'a FunctionValue<'ctx> {
-        self.function = Some(self.module.add_function(name, ty, None));
+        self.function_stack
+            .push(self.module.add_function(name, ty, None));
         let basic_block = self.append_basic_block("entry");
         self.builder.position_at_end(basic_block);
-        self.function.as_ref().unwrap()
+        self.function_stack.last().unwrap()
     }
 
-    pub fn finish_function(&self, res: &BasicValueEnum<'ctx>) {
+    pub fn finish_function(&mut self, res: &BasicValueEnum<'ctx>) -> FunctionValue<'ctx> {
         self.builder.build_return(Some(res));
+        self.function_stack.pop().unwrap()
     }
 
     pub fn append_basic_block(&self, name: &str) -> BasicBlock<'ctx> {
         self.context
-            .append_basic_block(self.function.unwrap(), name)
+            .append_basic_block(*self.function_stack.last().unwrap(), name)
     }
 
-    pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast>) -> Result<BasicValueEnum<'ctx>> {
+    pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast>) -> Result<AnyValueEnum<'ctx>> {
         match expr {
             Expr::Ident(id) => self
                 .env
@@ -81,13 +86,13 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                 .ok_or_else(|| Error::UndefinedVariable(id.to_owned())),
             Expr::Literal(Literal::Int(i)) => {
                 let ty = self.context.i64_type();
-                Ok(BasicValueEnum::IntValue(ty.const_int(*i, false)))
+                Ok(AnyValueEnum::IntValue(ty.const_int(*i, false)))
             }
             Expr::UnaryOp { op, rhs } => {
                 let rhs = self.codegen_expr(rhs)?;
                 match op {
                     UnaryOperator::Not => unimplemented!(),
-                    UnaryOperator::Neg => Ok(BasicValueEnum::IntValue(
+                    UnaryOperator::Neg => Ok(AnyValueEnum::IntValue(
                         self.builder.build_int_neg(rhs.into_int_value(), "neg"),
                     )),
                 }
@@ -96,29 +101,23 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                 let lhs = self.codegen_expr(lhs)?;
                 let rhs = self.codegen_expr(rhs)?;
                 match op {
-                    BinaryOperator::Add => {
-                        Ok(BasicValueEnum::IntValue(self.builder.build_int_add(
-                            lhs.into_int_value(),
-                            rhs.into_int_value(),
-                            "add",
-                        )))
-                    }
-                    BinaryOperator::Sub => {
-                        Ok(BasicValueEnum::IntValue(self.builder.build_int_sub(
-                            lhs.into_int_value(),
-                            rhs.into_int_value(),
-                            "add",
-                        )))
-                    }
-                    BinaryOperator::Mul => {
-                        Ok(BasicValueEnum::IntValue(self.builder.build_int_sub(
-                            lhs.into_int_value(),
-                            rhs.into_int_value(),
-                            "add",
-                        )))
-                    }
+                    BinaryOperator::Add => Ok(AnyValueEnum::IntValue(self.builder.build_int_add(
+                        lhs.into_int_value(),
+                        rhs.into_int_value(),
+                        "add",
+                    ))),
+                    BinaryOperator::Sub => Ok(AnyValueEnum::IntValue(self.builder.build_int_sub(
+                        lhs.into_int_value(),
+                        rhs.into_int_value(),
+                        "add",
+                    ))),
+                    BinaryOperator::Mul => Ok(AnyValueEnum::IntValue(self.builder.build_int_sub(
+                        lhs.into_int_value(),
+                        rhs.into_int_value(),
+                        "add",
+                    ))),
                     BinaryOperator::Div => {
-                        Ok(BasicValueEnum::IntValue(self.builder.build_int_signed_div(
+                        Ok(AnyValueEnum::IntValue(self.builder.build_int_signed_div(
                             lhs.into_int_value(),
                             rhs.into_int_value(),
                             "add",
@@ -126,7 +125,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                     }
                     BinaryOperator::Pow => unimplemented!(),
                     BinaryOperator::Equ => {
-                        Ok(BasicValueEnum::IntValue(self.builder.build_int_compare(
+                        Ok(AnyValueEnum::IntValue(self.builder.build_int_compare(
                             IntPredicate::EQ,
                             lhs.into_int_value(),
                             rhs.into_int_value(),
@@ -170,34 +169,83 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
 
                 self.builder.position_at_end(join_block);
                 let phi = self.builder.build_phi(self.context.i64_type(), "join");
-                phi.add_incoming(&[(&then_res, then_block), (&else_res, else_block)]);
-                Ok(phi.as_basic_value())
+                phi.add_incoming(&[
+                    (&BasicValueEnum::try_from(then_res).unwrap(), then_block),
+                    (&BasicValueEnum::try_from(else_res).unwrap(), else_block),
+                ]);
+                Ok(phi.as_basic_value().into())
+            }
+            Expr::Call { fun, args } => {
+                if let Expr::Ident(id) = &**fun {
+                    let function = self
+                        .module
+                        .get_function(id.into())
+                        .or_else(|| self.env.resolve(id)?.clone().try_into().ok())
+                        .ok_or_else(|| Error::UndefinedVariable(id.to_owned()))?;
+                    let args = args
+                        .iter()
+                        .map(|arg| Ok(self.codegen_expr(arg)?.try_into().unwrap()))
+                        .collect::<Result<Vec<_>>>()?;
+                    Ok(self
+                        .builder
+                        .build_call(function, &args, "call")
+                        .try_as_basic_value()
+                        .left()
+                        .unwrap()
+                        .into())
+                } else {
+                    todo!()
+                }
+            }
+            Expr::Fun(fun) => {
+                let Fun { args, body } = &**fun;
+                let fname = self.fresh_ident("f");
+                let cur_block = self.builder.get_insert_block().unwrap();
+                let env = self.env.save(); // TODO: closures
+                let function = self.codegen_function(&fname, args, body)?;
+                self.builder.position_at_end(cur_block);
+                self.env.restore(env);
+                Ok(function.into())
             }
         }
     }
 
+    pub fn codegen_function(
+        &mut self,
+        name: &str,
+        args: &'ast [Ident<'ast>],
+        body: &'ast Expr<'ast>,
+    ) -> Result<FunctionValue<'ctx>> {
+        let i64_type = self.context.i64_type();
+        self.new_function(
+            name,
+            i64_type.fn_type(
+                args.iter()
+                    .map(|_| i64_type.into())
+                    .collect::<Vec<_>>()
+                    .as_slice(),
+                false,
+            ),
+        );
+        self.env.push();
+        for (i, arg) in args.iter().enumerate() {
+            self.env.set(
+                arg,
+                self.cur_function().get_nth_param(i as u32).unwrap().into(),
+            );
+        }
+        let res = self.codegen_expr(body)?.try_into().unwrap();
+        self.env.pop();
+        Ok(self.finish_function(&res))
+    }
+
     pub fn codegen_decl(&mut self, decl: &'ast Decl<'ast>) -> Result<()> {
         match decl {
-            Decl::Fun(Fun { name, args, body }) => {
-                let i64_type = self.context.i64_type();
-                self.new_function(
-                    name.into(),
-                    i64_type.fn_type(
-                        args.iter()
-                            .map(|_| i64_type.into())
-                            .collect::<Vec<_>>()
-                            .as_slice(),
-                        false,
-                    ),
-                );
-                self.env.push();
-                for (i, arg) in args.iter().enumerate() {
-                    self.env
-                        .set(arg, self.function.unwrap().get_nth_param(i as u32).unwrap());
-                }
-                let res = self.codegen_expr(body)?;
-                self.env.pop();
-                self.finish_function(&res);
+            Decl::Fun {
+                name,
+                body: Fun { args, body },
+            } => {
+                self.codegen_function(name.into(), args, body)?;
                 Ok(())
             }
         }
@@ -205,7 +253,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
 
     pub fn codegen_main(&mut self, expr: &'ast Expr<'ast>) -> Result<()> {
         self.new_function("main", self.context.i64_type().fn_type(&[], false));
-        let res = self.codegen_expr(expr)?;
+        let res = self.codegen_expr(expr)?.try_into().unwrap();
         self.finish_function(&res);
         Ok(())
     }
@@ -229,6 +277,15 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
             ))
         }
     }
+
+    fn fresh_ident(&mut self, prefix: &str) -> String {
+        self.identifier_counter += 1;
+        format!("{}{}", prefix, self.identifier_counter)
+    }
+
+    fn cur_function(&self) -> &FunctionValue<'ctx> {
+        self.function_stack.last().unwrap()
+    }
 }
 
 #[cfg(test)]
@@ -248,9 +305,7 @@ mod tests {
             .create_jit_execution_engine(OptimizationLevel::None)
             .unwrap();
 
-        codegen.new_function("test", context.i64_type().fn_type(&[], false));
-        let res = codegen.codegen_expr(&expr)?;
-        codegen.finish_function(&res);
+        codegen.codegen_function("test", &[], &expr)?;
 
         unsafe {
             let fun: JitFunction<unsafe extern "C" fn() -> T> =
@@ -279,4 +334,10 @@ mod tests {
             2
         );
     }
+
+    #[test]
+    fn function_call() {
+        let res = jit_eval::<i64>("let id = fn x = x in id 1").unwrap();
+        assert_eq!(res, 1);
+    }
 }