about summary refs log tree commit diff
path: root/src/codegen/llvm.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/codegen/llvm.rs')
-rw-r--r--src/codegen/llvm.rs76
1 files changed, 48 insertions, 28 deletions
diff --git a/src/codegen/llvm.rs b/src/codegen/llvm.rs
index 1f4a457cd8..5b5db90a1a 100644
--- a/src/codegen/llvm.rs
+++ b/src/codegen/llvm.rs
@@ -7,12 +7,13 @@ use inkwell::builder::Builder;
 pub use inkwell::context::Context;
 use inkwell::module::Module;
 use inkwell::support::LLVMString;
-use inkwell::types::FunctionType;
+use inkwell::types::{BasicType, BasicTypeEnum, FunctionType, IntType};
 use inkwell::values::{AnyValueEnum, BasicValueEnum, FunctionValue};
 use inkwell::IntPredicate;
 use thiserror::Error;
 
-use crate::ast::{BinaryOperator, Binding, Decl, Expr, Fun, Ident, Literal, UnaryOperator};
+use crate::ast::hir::{Binding, Decl, Expr};
+use crate::ast::{BinaryOperator, Ident, Literal, Type, UnaryOperator};
 use crate::common::env::Env;
 
 #[derive(Debug, PartialEq, Eq, Error)]
@@ -36,7 +37,7 @@ pub struct Codegen<'ctx, 'ast> {
     context: &'ctx Context,
     pub module: Module<'ctx>,
     builder: Builder<'ctx>,
-    env: Env<'ast, AnyValueEnum<'ctx>>,
+    env: Env<&'ast Ident<'ast>, AnyValueEnum<'ctx>>,
     function_stack: Vec<FunctionValue<'ctx>>,
     identifier_counter: u32,
 }
@@ -77,18 +78,23 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
             .append_basic_block(*self.function_stack.last().unwrap(), name)
     }
 
-    pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast>) -> Result<AnyValueEnum<'ctx>> {
+    pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast, Type>) -> Result<AnyValueEnum<'ctx>> {
         match expr {
-            Expr::Ident(id) => self
+            Expr::Ident(id, _) => self
                 .env
                 .resolve(id)
                 .cloned()
                 .ok_or_else(|| Error::UndefinedVariable(id.to_owned())),
-            Expr::Literal(Literal::Int(i)) => {
-                let ty = self.context.i64_type();
-                Ok(AnyValueEnum::IntValue(ty.const_int(*i, false)))
+            Expr::Literal(lit, ty) => {
+                let ty = self.codegen_int_type(ty);
+                match lit {
+                    Literal::Int(i) => Ok(AnyValueEnum::IntValue(ty.const_int(*i, false))),
+                    Literal::Bool(b) => Ok(AnyValueEnum::IntValue(
+                        ty.const_int(if *b { 1 } else { 0 }, false),
+                    )),
+                }
             }
-            Expr::UnaryOp { op, rhs } => {
+            Expr::UnaryOp { op, rhs, .. } => {
                 let rhs = self.codegen_expr(rhs)?;
                 match op {
                     UnaryOperator::Not => unimplemented!(),
@@ -97,7 +103,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                     )),
                 }
             }
-            Expr::BinaryOp { lhs, op, rhs } => {
+            Expr::BinaryOp { lhs, op, rhs, .. } => {
                 let lhs = self.codegen_expr(lhs)?;
                 let rhs = self.codegen_expr(rhs)?;
                 match op {
@@ -135,7 +141,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                     BinaryOperator::Neq => todo!(),
                 }
             }
-            Expr::Let { bindings, body } => {
+            Expr::Let { bindings, body, .. } => {
                 self.env.push();
                 for Binding { ident, body, .. } in bindings {
                     let val = self.codegen_expr(body)?;
@@ -149,6 +155,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                 condition,
                 then,
                 else_,
+                type_,
             } => {
                 let then_block = self.append_basic_block("then");
                 let else_block = self.append_basic_block("else");
@@ -168,15 +175,15 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                 self.builder.build_unconditional_branch(join_block);
 
                 self.builder.position_at_end(join_block);
-                let phi = self.builder.build_phi(self.context.i64_type(), "join");
+                let phi = self.builder.build_phi(self.codegen_type(type_), "join");
                 phi.add_incoming(&[
                     (&BasicValueEnum::try_from(then_res).unwrap(), then_block),
                     (&BasicValueEnum::try_from(else_res).unwrap(), else_block),
                 ]);
                 Ok(phi.as_basic_value().into())
             }
-            Expr::Call { fun, args } => {
-                if let Expr::Ident(id) = &**fun {
+            Expr::Call { fun, args, .. } => {
+                if let Expr::Ident(id, _) = &**fun {
                     let function = self
                         .module
                         .get_function(id.into())
@@ -197,8 +204,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                     todo!()
                 }
             }
-            Expr::Fun(fun) => {
-                let Fun { args, body } = &**fun;
+            Expr::Fun { args, body, .. } => {
                 let fname = self.fresh_ident("f");
                 let cur_block = self.builder.get_insert_block().unwrap();
                 let env = self.env.save(); // TODO: closures
@@ -207,29 +213,27 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
                 self.env.restore(env);
                 Ok(function.into())
             }
-            Expr::Ascription { expr, .. } => self.codegen_expr(expr),
         }
     }
 
     pub fn codegen_function(
         &mut self,
         name: &str,
-        args: &'ast [Ident<'ast>],
-        body: &'ast Expr<'ast>,
+        args: &'ast [(Ident<'ast>, Type)],
+        body: &'ast Expr<'ast, Type>,
     ) -> Result<FunctionValue<'ctx>> {
-        let i64_type = self.context.i64_type();
         self.new_function(
             name,
-            i64_type.fn_type(
+            self.codegen_type(body.type_()).fn_type(
                 args.iter()
-                    .map(|_| i64_type.into())
+                    .map(|(_, at)| self.codegen_type(at))
                     .collect::<Vec<_>>()
                     .as_slice(),
                 false,
             ),
         );
         self.env.push();
-        for (i, arg) in args.iter().enumerate() {
+        for (i, (arg, _)) in args.iter().enumerate() {
             self.env.set(
                 arg,
                 self.cur_function().get_nth_param(i as u32).unwrap().into(),
@@ -240,11 +244,10 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
         Ok(self.finish_function(&res))
     }
 
-    pub fn codegen_decl(&mut self, decl: &'ast Decl<'ast>) -> Result<()> {
+    pub fn codegen_decl(&mut self, decl: &'ast Decl<'ast, Type>) -> Result<()> {
         match decl {
             Decl::Fun {
-                name,
-                body: Fun { args, body },
+                name, args, body, ..
             } => {
                 self.codegen_function(name.into(), args, body)?;
                 Ok(())
@@ -252,13 +255,28 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
         }
     }
 
-    pub fn codegen_main(&mut self, expr: &'ast Expr<'ast>) -> Result<()> {
+    pub fn codegen_main(&mut self, expr: &'ast Expr<'ast, Type>) -> Result<()> {
         self.new_function("main", self.context.i64_type().fn_type(&[], false));
         let res = self.codegen_expr(expr)?.try_into().unwrap();
-        self.finish_function(&res);
+        if *expr.type_() != Type::Int {
+            self.builder
+                .build_return(Some(&self.context.i64_type().const_int(0, false)));
+        } else {
+            self.finish_function(&res);
+        }
         Ok(())
     }
 
+    fn codegen_type(&self, type_: &'ast Type) -> BasicTypeEnum<'ctx> {
+        // TODO
+        self.context.i64_type().into()
+    }
+
+    fn codegen_int_type(&self, type_: &'ast Type) -> IntType<'ctx> {
+        // TODO
+        self.context.i64_type()
+    }
+
     pub fn print_to_file<P>(&self, path: P) -> Result<()>
     where
         P: AsRef<Path>,
@@ -299,6 +317,8 @@ mod tests {
     fn jit_eval<T>(expr: &str) -> anyhow::Result<T> {
         let expr = crate::parser::expr(expr).unwrap().1;
 
+        let expr = crate::tc::typecheck_expr(expr).unwrap();
+
         let context = Context::create();
         let mut codegen = Codegen::new(&context, "test");
         let execution_engine = codegen