diff options
Diffstat (limited to 'src/codegen/llvm.rs')
-rw-r--r-- | src/codegen/llvm.rs | 181 |
1 files changed, 121 insertions, 60 deletions
diff --git a/src/codegen/llvm.rs b/src/codegen/llvm.rs index ff92c3373176..1d1c742a9434 100644 --- a/src/codegen/llvm.rs +++ b/src/codegen/llvm.rs @@ -1,3 +1,4 @@ +use std::convert::{TryFrom, TryInto}; use std::path::Path; use std::result; @@ -7,7 +8,7 @@ pub use inkwell::context::Context; use inkwell::module::Module; use inkwell::support::LLVMString; use inkwell::types::FunctionType; -use inkwell::values::{BasicValueEnum, FunctionValue}; +use inkwell::values::{AnyValueEnum, BasicValueEnum, FunctionValue}; use inkwell::IntPredicate; use thiserror::Error; @@ -35,8 +36,9 @@ pub struct Codegen<'ctx, 'ast> { context: &'ctx Context, pub module: Module<'ctx>, builder: Builder<'ctx>, - env: Env<'ast, BasicValueEnum<'ctx>>, - function: Option<FunctionValue<'ctx>>, + env: Env<'ast, AnyValueEnum<'ctx>>, + function_stack: Vec<FunctionValue<'ctx>>, + identifier_counter: u32, } impl<'ctx, 'ast> Codegen<'ctx, 'ast> { @@ -48,7 +50,8 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { module, builder, env: Default::default(), - function: None, + function_stack: Default::default(), + identifier_counter: 0, } } @@ -57,22 +60,24 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { name: &str, ty: FunctionType<'ctx>, ) -> &'a FunctionValue<'ctx> { - self.function = Some(self.module.add_function(name, ty, None)); + self.function_stack + .push(self.module.add_function(name, ty, None)); let basic_block = self.append_basic_block("entry"); self.builder.position_at_end(basic_block); - self.function.as_ref().unwrap() + self.function_stack.last().unwrap() } - pub fn finish_function(&self, res: &BasicValueEnum<'ctx>) { + pub fn finish_function(&mut self, res: &BasicValueEnum<'ctx>) -> FunctionValue<'ctx> { self.builder.build_return(Some(res)); + self.function_stack.pop().unwrap() } pub fn append_basic_block(&self, name: &str) -> BasicBlock<'ctx> { self.context - .append_basic_block(self.function.unwrap(), name) + .append_basic_block(*self.function_stack.last().unwrap(), name) } - pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast>) -> Result<BasicValueEnum<'ctx>> { + pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast>) -> Result<AnyValueEnum<'ctx>> { match expr { Expr::Ident(id) => self .env @@ -81,13 +86,13 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { .ok_or_else(|| Error::UndefinedVariable(id.to_owned())), Expr::Literal(Literal::Int(i)) => { let ty = self.context.i64_type(); - Ok(BasicValueEnum::IntValue(ty.const_int(*i, false))) + Ok(AnyValueEnum::IntValue(ty.const_int(*i, false))) } Expr::UnaryOp { op, rhs } => { let rhs = self.codegen_expr(rhs)?; match op { UnaryOperator::Not => unimplemented!(), - UnaryOperator::Neg => Ok(BasicValueEnum::IntValue( + UnaryOperator::Neg => Ok(AnyValueEnum::IntValue( self.builder.build_int_neg(rhs.into_int_value(), "neg"), )), } @@ -96,29 +101,23 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { let lhs = self.codegen_expr(lhs)?; let rhs = self.codegen_expr(rhs)?; match op { - BinaryOperator::Add => { - Ok(BasicValueEnum::IntValue(self.builder.build_int_add( - lhs.into_int_value(), - rhs.into_int_value(), - "add", - ))) - } - BinaryOperator::Sub => { - Ok(BasicValueEnum::IntValue(self.builder.build_int_sub( - lhs.into_int_value(), - rhs.into_int_value(), - "add", - ))) - } - BinaryOperator::Mul => { - Ok(BasicValueEnum::IntValue(self.builder.build_int_sub( - lhs.into_int_value(), - rhs.into_int_value(), - "add", - ))) - } + BinaryOperator::Add => Ok(AnyValueEnum::IntValue(self.builder.build_int_add( + lhs.into_int_value(), + rhs.into_int_value(), + "add", + ))), + BinaryOperator::Sub => Ok(AnyValueEnum::IntValue(self.builder.build_int_sub( + lhs.into_int_value(), + rhs.into_int_value(), + "add", + ))), + BinaryOperator::Mul => Ok(AnyValueEnum::IntValue(self.builder.build_int_sub( + lhs.into_int_value(), + rhs.into_int_value(), + "add", + ))), BinaryOperator::Div => { - Ok(BasicValueEnum::IntValue(self.builder.build_int_signed_div( + Ok(AnyValueEnum::IntValue(self.builder.build_int_signed_div( lhs.into_int_value(), rhs.into_int_value(), "add", @@ -126,7 +125,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { } BinaryOperator::Pow => unimplemented!(), BinaryOperator::Equ => { - Ok(BasicValueEnum::IntValue(self.builder.build_int_compare( + Ok(AnyValueEnum::IntValue(self.builder.build_int_compare( IntPredicate::EQ, lhs.into_int_value(), rhs.into_int_value(), @@ -170,34 +169,83 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { self.builder.position_at_end(join_block); let phi = self.builder.build_phi(self.context.i64_type(), "join"); - phi.add_incoming(&[(&then_res, then_block), (&else_res, else_block)]); - Ok(phi.as_basic_value()) + phi.add_incoming(&[ + (&BasicValueEnum::try_from(then_res).unwrap(), then_block), + (&BasicValueEnum::try_from(else_res).unwrap(), else_block), + ]); + Ok(phi.as_basic_value().into()) + } + Expr::Call { fun, args } => { + if let Expr::Ident(id) = &**fun { + let function = self + .module + .get_function(id.into()) + .or_else(|| self.env.resolve(id)?.clone().try_into().ok()) + .ok_or_else(|| Error::UndefinedVariable(id.to_owned()))?; + let args = args + .iter() + .map(|arg| Ok(self.codegen_expr(arg)?.try_into().unwrap())) + .collect::<Result<Vec<_>>>()?; + Ok(self + .builder + .build_call(function, &args, "call") + .try_as_basic_value() + .left() + .unwrap() + .into()) + } else { + todo!() + } + } + Expr::Fun(fun) => { + let Fun { args, body } = &**fun; + let fname = self.fresh_ident("f"); + let cur_block = self.builder.get_insert_block().unwrap(); + let env = self.env.save(); // TODO: closures + let function = self.codegen_function(&fname, args, body)?; + self.builder.position_at_end(cur_block); + self.env.restore(env); + Ok(function.into()) } } } + pub fn codegen_function( + &mut self, + name: &str, + args: &'ast [Ident<'ast>], + body: &'ast Expr<'ast>, + ) -> Result<FunctionValue<'ctx>> { + let i64_type = self.context.i64_type(); + self.new_function( + name, + i64_type.fn_type( + args.iter() + .map(|_| i64_type.into()) + .collect::<Vec<_>>() + .as_slice(), + false, + ), + ); + self.env.push(); + for (i, arg) in args.iter().enumerate() { + self.env.set( + arg, + self.cur_function().get_nth_param(i as u32).unwrap().into(), + ); + } + let res = self.codegen_expr(body)?.try_into().unwrap(); + self.env.pop(); + Ok(self.finish_function(&res)) + } + pub fn codegen_decl(&mut self, decl: &'ast Decl<'ast>) -> Result<()> { match decl { - Decl::Fun(Fun { name, args, body }) => { - let i64_type = self.context.i64_type(); - self.new_function( - name.into(), - i64_type.fn_type( - args.iter() - .map(|_| i64_type.into()) - .collect::<Vec<_>>() - .as_slice(), - false, - ), - ); - self.env.push(); - for (i, arg) in args.iter().enumerate() { - self.env - .set(arg, self.function.unwrap().get_nth_param(i as u32).unwrap()); - } - let res = self.codegen_expr(body)?; - self.env.pop(); - self.finish_function(&res); + Decl::Fun { + name, + body: Fun { args, body }, + } => { + self.codegen_function(name.into(), args, body)?; Ok(()) } } @@ -205,7 +253,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { pub fn codegen_main(&mut self, expr: &'ast Expr<'ast>) -> Result<()> { self.new_function("main", self.context.i64_type().fn_type(&[], false)); - let res = self.codegen_expr(expr)?; + let res = self.codegen_expr(expr)?.try_into().unwrap(); self.finish_function(&res); Ok(()) } @@ -229,6 +277,15 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { )) } } + + fn fresh_ident(&mut self, prefix: &str) -> String { + self.identifier_counter += 1; + format!("{}{}", prefix, self.identifier_counter) + } + + fn cur_function(&self) -> &FunctionValue<'ctx> { + self.function_stack.last().unwrap() + } } #[cfg(test)] @@ -248,9 +305,7 @@ mod tests { .create_jit_execution_engine(OptimizationLevel::None) .unwrap(); - codegen.new_function("test", context.i64_type().fn_type(&[], false)); - let res = codegen.codegen_expr(&expr)?; - codegen.finish_function(&res); + codegen.codegen_function("test", &[], &expr)?; unsafe { let fun: JitFunction<unsafe extern "C" fn() -> T> = @@ -279,4 +334,10 @@ mod tests { 2 ); } + + #[test] + fn function_call() { + let res = jit_eval::<i64>("let id = fn x = x in id 1").unwrap(); + assert_eq!(res, 1); + } } |