// about summary refs log tree commit diff
// path: root/users/glittershark/achilles/src/codegen/llvm.rs
use std::convert::{TryFrom, TryInto};
use std::path::Path;
use std::result;

use inkwell::basic_block::BasicBlock;
use inkwell::builder::Builder;
pub use inkwell::context::Context;
use inkwell::module::Module;
use inkwell::support::LLVMString;
use inkwell::types::{BasicType, BasicTypeEnum, FunctionType, IntType};
use inkwell::values::{AnyValueEnum, BasicValueEnum, FunctionValue};
use inkwell::{AddressSpace, IntPredicate};
use thiserror::Error;

use crate::ast::hir::{Binding, Decl, Expr};
use crate::ast::{BinaryOperator, Ident, Literal, Type, UnaryOperator};
use crate::common::env::Env;

/// Errors that code generation in this module can report.
#[derive(Debug, PartialEq, Eq, Error)]
pub enum Error {
    /// An identifier was referenced that is bound neither in the codegen
    /// environment nor (for calls) as a function in the module.
    #[error("Undefined variable {0}")]
    UndefinedVariable(Ident<'static>),

    /// An error message reported by LLVM, or an LLVM operation that failed.
    #[error("LLVM Error: {0}")]
    LLVMError(String),
}

impl From<LLVMString> for Error {
    fn from(s: LLVMString) -> Self {
        Self::LLVMError(s.to_string())
    }
}

/// Result type whose error is this module's [`Error`].
pub type Result<T> = result::Result<T, Error>;

/// State for lowering typed HIR expressions and declarations into a single
/// LLVM module.
pub struct Codegen<'ctx, 'ast> {
    /// LLVM context used to create types, basic blocks, the module and the
    /// builder.
    context: &'ctx Context,
    /// The module that generated functions are added to.
    pub module: Module<'ctx>,
    /// Instruction builder; repositioned as blocks and functions are entered.
    builder: Builder<'ctx>,
    /// Maps identifiers to the LLVM values currently bound to them.
    env: Env<&'ast Ident<'ast>, AnyValueEnum<'ctx>>,
    /// Functions currently being generated; `new_function` pushes and
    /// `finish_function` pops, so the last entry is the current function.
    function_stack: Vec<FunctionValue<'ctx>>,
    /// Counter incremented by `fresh_ident` to produce unique names.
    identifier_counter: u32,
}

impl<'ctx, 'ast> Codegen<'ctx, 'ast> {
    pub fn new(context: &'ctx Context, module_name: &str) -> Self {
        let module = context.create_module(module_name);
        let builder = context.create_builder();
        Self {
            context,
            module,
            builder,
            env: Default::default(),
            function_stack: Default::default(),
            identifier_counter: 0,
        }
    }

    /// Adds a function `name` of type `ty` to the module, makes it the current
    /// function, and positions the builder at the end of its new `entry`
    /// block.
    ///
    /// Returns a reference to the newly pushed function value.
    pub fn new_function<'a>(
        &'a mut self,
        name: &str,
        ty: FunctionType<'ctx>,
    ) -> &'a FunctionValue<'ctx> {
        let function = self.module.add_function(name, ty, None);
        self.function_stack.push(function);

        // The function just pushed is the top of the stack, so the entry block
        // is appended directly to it.
        let entry = self.context.append_basic_block(function, "entry");
        self.builder.position_at_end(entry);

        self.function_stack.last().unwrap()
    }

    /// Emits `ret res` at the builder's current position and pops the current
    /// function off the function stack, returning it.
    ///
    /// # Panics
    ///
    /// Panics if no function is in progress.
    pub fn finish_function(&mut self, res: &BasicValueEnum<'ctx>) -> FunctionValue<'ctx> {
        self.builder.build_return(Some(res));
        let finished = self.function_stack.pop();
        finished.unwrap()
    }

    /// Appends a new basic block called `name` to the function currently
    /// being generated.
    ///
    /// # Panics
    ///
    /// Panics if no function is in progress.
    pub fn append_basic_block(&self, name: &str) -> BasicBlock<'ctx> {
        let current = *self.function_stack.last().unwrap();
        self.context.append_basic_block(current, name)
    }

    /// Generates LLVM IR for `expr` at the builder's current position and
    /// returns the value the expression evaluates to.
    ///
    /// # Errors
    ///
    /// Returns [`Error::UndefinedVariable`] when an identifier (or the callee
    /// of a call) cannot be resolved.
    ///
    /// # Panics
    ///
    /// Panics on constructs that are not implemented yet (`Not`, `Pow`, calls
    /// whose callee is not a plain identifier), and when a sub-expression
    /// yields a value of an unexpected kind (e.g. a non-integer operand to an
    /// arithmetic operator).
    pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast, Type>) -> Result<AnyValueEnum<'ctx>> {
        match expr {
            Expr::Ident(id, _) => self
                .env
                .resolve(id)
                .cloned()
                .ok_or_else(|| Error::UndefinedVariable(id.to_owned())),
            Expr::Literal(lit, ty) => {
                let ty = self.codegen_int_type(ty);
                match lit {
                    Literal::Int(i) => Ok(AnyValueEnum::IntValue(ty.const_int(*i, false))),
                    Literal::Bool(b) => Ok(AnyValueEnum::IntValue(
                        ty.const_int(if *b { 1 } else { 0 }, false),
                    )),
                    Literal::String(s) => Ok(self
                        .builder
                        .build_global_string_ptr(s, "s")
                        .as_pointer_value()
                        .into()),
                }
            }
            Expr::UnaryOp { op, rhs, .. } => {
                let rhs = self.codegen_expr(rhs)?;
                match op {
                    UnaryOperator::Not => unimplemented!(),
                    UnaryOperator::Neg => Ok(AnyValueEnum::IntValue(
                        self.builder.build_int_neg(rhs.into_int_value(), "neg"),
                    )),
                }
            }
            Expr::BinaryOp { lhs, op, rhs, .. } => {
                let lhs = self.codegen_expr(lhs)?;
                let rhs = self.codegen_expr(rhs)?;
                match op {
                    BinaryOperator::Add => Ok(AnyValueEnum::IntValue(self.builder.build_int_add(
                        lhs.into_int_value(),
                        rhs.into_int_value(),
                        "add",
                    ))),
                    BinaryOperator::Sub => Ok(AnyValueEnum::IntValue(self.builder.build_int_sub(
                        lhs.into_int_value(),
                        rhs.into_int_value(),
                        "sub",
                    ))),
                    // Multiplication must emit `mul`; it previously emitted a
                    // subtraction.
                    BinaryOperator::Mul => Ok(AnyValueEnum::IntValue(self.builder.build_int_mul(
                        lhs.into_int_value(),
                        rhs.into_int_value(),
                        "mul",
                    ))),
                    BinaryOperator::Div => {
                        Ok(AnyValueEnum::IntValue(self.builder.build_int_signed_div(
                            lhs.into_int_value(),
                            rhs.into_int_value(),
                            "div",
                        )))
                    }
                    BinaryOperator::Pow => unimplemented!(),
                    BinaryOperator::Equ => {
                        Ok(AnyValueEnum::IntValue(self.builder.build_int_compare(
                            IntPredicate::EQ,
                            lhs.into_int_value(),
                            rhs.into_int_value(),
                            "eq",
                        )))
                    }
                    BinaryOperator::Neq => {
                        Ok(AnyValueEnum::IntValue(self.builder.build_int_compare(
                            IntPredicate::NE,
                            lhs.into_int_value(),
                            rhs.into_int_value(),
                            "neq",
                        )))
                    }
                }
            }
            Expr::Let { bindings, body, .. } => {
                self.env.push();
                // Collect the first binding error instead of returning early,
                // so the scope pushed above is always popped again.
                let mut bound = Ok(());
                for Binding { ident, body, .. } in bindings {
                    match self.codegen_expr(body) {
                        Ok(val) => {
                            self.env.set(ident, val);
                        }
                        Err(err) => {
                            bound = Err(err);
                            break;
                        }
                    }
                }
                let res = bound.and_then(|()| self.codegen_expr(body));
                self.env.pop();
                res
            }
            Expr::If {
                condition,
                then,
                else_,
                type_,
            } => {
                let then_block = self.append_basic_block("then");
                let else_block = self.append_basic_block("else");
                let join_block = self.append_basic_block("join");
                let condition = self.codegen_expr(condition)?;
                self.builder.build_conditional_branch(
                    condition.into_int_value(),
                    then_block,
                    else_block,
                );

                self.builder.position_at_end(then_block);
                let then_res = self.codegen_expr(then)?;
                // Generating the branch may itself have created new blocks
                // (e.g. a nested `if`), so the predecessor of `join` is the
                // block the builder ended up in, not necessarily `then_block`.
                let then_end = self.builder.get_insert_block().unwrap();
                self.builder.build_unconditional_branch(join_block);

                self.builder.position_at_end(else_block);
                let else_res = self.codegen_expr(else_)?;
                let else_end = self.builder.get_insert_block().unwrap();
                self.builder.build_unconditional_branch(join_block);

                self.builder.position_at_end(join_block);
                let phi = self.builder.build_phi(self.codegen_type(type_), "join");
                phi.add_incoming(&[
                    (&BasicValueEnum::try_from(then_res).unwrap(), then_end),
                    (&BasicValueEnum::try_from(else_res).unwrap(), else_end),
                ]);
                Ok(phi.as_basic_value().into())
            }
            Expr::Call { fun, args, .. } => {
                if let Expr::Ident(id, _) = &**fun {
                    // Prefer a module-level function of that name, then fall
                    // back to a function value bound in the environment.
                    let function = self
                        .module
                        .get_function(id.into())
                        .or_else(|| self.env.resolve(id)?.clone().try_into().ok())
                        .ok_or_else(|| Error::UndefinedVariable(id.to_owned()))?;
                    let args = args
                        .iter()
                        .map(|arg| Ok(self.codegen_expr(arg)?.try_into().unwrap()))
                        .collect::<Result<Vec<_>>>()?;
                    Ok(self
                        .builder
                        .build_call(function, &args, "call")
                        .try_as_basic_value()
                        .left()
                        .unwrap()
                        .into())
                } else {
                    todo!()
                }
            }
            Expr::Fun { args, body, .. } => {
                let fname = self.fresh_ident("f");
                let cur_block = self.builder.get_insert_block().unwrap();
                let env = self.env.save(); // TODO: closures
                let function = self.codegen_function(&fname, args, body);
                // Restore the builder position and environment before
                // propagating any error from the nested function.
                self.builder.position_at_end(cur_block);
                self.env.restore(env);
                Ok(function?.into())
            }
        }
    }

    pub fn codegen_function(
        &mut self,
        name: &str,
        args: &'ast [(Ident<'ast>, Type)],
        body: &'ast Expr<'ast, Type>,
    ) -> Result<FunctionValue<'ctx>> {
        self.new_function(
            name,
            self.codegen_type(body.type_()).fn_type(
                args.iter()
                    .map(|(_, at)| self.codegen_type(at))
                    .collect::<Vec<_>>()
                    .as_slice(),
                false,
            ),
        );
        self.env.push();
        for (i, (arg, _)) in args.iter().enumerate() {
            self.env.set(
                arg,
                self.cur_function().get_nth_param(i as u32).unwrap().into(),
            );
        }
        let res = self.codegen_expr(body)?.try_into().unwrap();
        self.env.pop();
        Ok(self.finish_function(&res))
    }

    /// Declares an external function `name` in the module, with parameter
    /// types `args` and return type `ret`. No body is generated.
    pub fn codegen_extern(
        &mut self,
        name: &str,
        args: &'ast [Type],
        ret: &'ast Type,
    ) -> Result<()> {
        let ret_type = self.codegen_type(ret);
        let arg_types: Vec<_> = args.iter().map(|t| self.codegen_type(t)).collect();
        let fn_type = ret_type.fn_type(&arg_types, false);
        self.module.add_function(name, fn_type, None);
        Ok(())
    }

    /// Generates code for a top-level declaration: a function definition is
    /// compiled with `codegen_function`, an extern is declared with
    /// `codegen_extern`.
    pub fn codegen_decl(&mut self, decl: &'ast Decl<'ast, Type>) -> Result<()> {
        match decl {
            Decl::Fun {
                name, args, body, ..
            } => self
                .codegen_function(name.into(), args, body)
                .map(|_| ()),
            Decl::Extern {
                name,
                arg_types,
                ret_type,
            } => self.codegen_extern(name.into(), arg_types, ret_type),
        }
    }

    pub fn codegen_main(&mut self, expr: &'ast Expr<'ast, Type>) -> Result<()> {
        self.new_function("main", self.context.i64_type().fn_type(&[], false));
        let res = self.codegen_expr(expr)?.try_into().unwrap();
        if *expr.type_() != Type::Int {
            self.builder
                .build_return(Some(&self.context.i64_type().const_int(0, false)));
        } else {
            self.finish_function(&res);
        }
        Ok(())
    }

    fn codegen_type(&self, type_: &'ast Type) -> BasicTypeEnum<'ctx> {
        // TODO
        match type_ {
            Type::Int => self.context.i64_type().into(),
            Type::Float => self.context.f64_type().into(),
            Type::Bool => self.context.bool_type().into(),
            Type::CString => self
                .context
                .i8_type()
                .ptr_type(AddressSpace::Generic)
                .into(),
            Type::Function(_) => todo!(),
            Type::Var(_) => unreachable!(),
        }
    }

    fn codegen_int_type(&self, type_: &'ast Type) -> IntType<'ctx> {
        // TODO
        self.context.i64_type()
    }

    pub fn print_to_file<P>(&self, path: P) -> Result<()>
    where
        P: AsRef<Path>,
    {
        Ok(self.module.print_to_file(path)?)
    }

    pub fn binary_to_file<P>(&self, path: P) -> Result<()>
    where
        P: AsRef<Path>,
    {
        if self.module.write_bitcode_to_path(path.as_ref()) {
            Ok(())
        } else {
            Err(Error::LLVMError(
                "Error writing bitcode to output path".to_owned(),
            ))
        }
    }

    /// Returns a new name made of `prefix` followed by the next value of the
    /// identifier counter (the first name produced ends in `1`).
    fn fresh_ident(&mut self, prefix: &str) -> String {
        let next = self.identifier_counter + 1;
        self.identifier_counter = next;
        format!("{}{}", prefix, next)
    }

    /// Returns the function currently being generated (the top of the
    /// function stack).
    ///
    /// # Panics
    ///
    /// Panics if no function is in progress.
    fn cur_function(&self) -> &FunctionValue<'ctx> {
        let top = self.function_stack.last();
        top.unwrap()
    }
}

#[cfg(test)]
mod tests {
    use inkwell::execution_engine::JitFunction;
    use inkwell::OptimizationLevel;

    use super::*;

    /// Parses and typechecks `expr`, compiles it as a zero-argument function
    /// named `test` in a fresh module, and runs that function through a JIT
    /// execution engine, returning its result as a `T`.
    ///
    /// Parse and typecheck failures panic; codegen and JIT lookup failures
    /// are returned as errors.
    fn jit_eval<T>(expr: &str) -> anyhow::Result<T> {
        let expr = crate::parser::expr(expr).unwrap().1;

        let expr = crate::tc::typecheck_expr(expr).unwrap();

        let context = Context::create();
        let mut codegen = Codegen::new(&context, "test");
        // The execution engine is created from the (still empty) module
        // before the function is generated into it.
        let execution_engine = codegen
            .module
            .create_jit_execution_engine(OptimizationLevel::None)
            .unwrap();

        codegen.codegen_function("test", &[], &expr)?;

        // SAFETY: `test` was generated just above with no parameters, which
        // matches the `fn() -> T` signature requested here.
        // NOTE(review): nothing checks that `T` matches the compiled return
        // type; each caller must pick a `T` that agrees with the expression.
        unsafe {
            let fun: JitFunction<unsafe extern "C" fn() -> T> =
                execution_engine.get_function("test")?;
            Ok(fun.call())
        }
    }

    /// Integer addition of two literals.
    #[test]
    fn add_literals() {
        assert_eq!(jit_eval::<i64>("1 + 2").unwrap(), 3);
    }

    /// An inner `let` shadows the outer binding only within its own body.
    #[test]
    fn variable_shadowing() {
        assert_eq!(
            jit_eval::<i64>("let x = 1 in (let x = 2 in x) + x").unwrap(),
            3
        );
    }

    /// An equality comparison drives an `if` expression.
    #[test]
    fn eq() {
        assert_eq!(
            jit_eval::<i64>("let x = 1 in if x == 1 then 2 else 4").unwrap(),
            2
        );
    }

    /// A let-bound anonymous function can be called by name.
    #[test]
    fn function_call() {
        let res = jit_eval::<i64>("let id = fn x = x in id 1").unwrap();
        assert_eq!(res, 1);
    }
}