diff options
-rw-r--r-- | Cargo.lock | 10 | ||||
-rw-r--r-- | Cargo.toml | 1 | ||||
-rw-r--r-- | src/ast/mod.rs | 94 | ||||
-rw-r--r-- | src/codegen/llvm.rs | 181 | ||||
-rw-r--r-- | src/codegen/mod.rs | 4 | ||||
-rw-r--r-- | src/commands/compile.rs | 3 | ||||
-rw-r--r-- | src/common/env.rs | 9 | ||||
-rw-r--r-- | src/interpreter/mod.rs | 56 | ||||
-rw-r--r-- | src/interpreter/value.rs | 101 | ||||
-rw-r--r-- | src/parser/mod.rs | 167 |
10 files changed, 501 insertions, 125 deletions
diff --git a/Cargo.lock b/Cargo.lock index 485e9f4cdfcb..d8eaedeca181 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -8,6 +8,7 @@ dependencies = [ "clap", "derive_more", "inkwell", + "itertools", "llvm-sys", "nom", "nom-trace", @@ -246,6 +247,15 @@ dependencies = [ ] [[package]] +name = "itertools" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37d572918e350e82412fe766d24b15e6682fb2ed2bbe018280caa810397cb319" +dependencies = [ + "either", +] + +[[package]] name = "lazy_static" version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" diff --git a/Cargo.toml b/Cargo.toml index eda15e554661..c9796a821586 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,6 +9,7 @@ anyhow = "1.0.38" clap = "3.0.0-beta.2" derive_more = "0.99.11" inkwell = { git = "https://github.com/TheDan64/inkwell", branch = "master", features = ["llvm11-0"] } +itertools = "0.10.0" llvm-sys = "110.0.1" nom = "6.1.2" nom-trace = { git = "https://github.com/glittershark/nom-trace", branch = "nom-6" } diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 2dcf955fe67c..7f3543271db8 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -2,10 +2,12 @@ use std::borrow::Cow; use std::convert::TryFrom; use std::fmt::{self, Display, Formatter}; +use itertools::Itertools; + #[derive(Debug, PartialEq, Eq)] pub struct InvalidIdentifier<'a>(Cow<'a, str>); -#[derive(Debug, PartialEq, Eq, Hash)] +#[derive(Debug, PartialEq, Eq, Hash, Clone)] pub struct Ident<'a>(pub Cow<'a, str>); impl<'a> From<&'a Ident<'a>> for &'a str { @@ -69,7 +71,7 @@ impl<'a> TryFrom<String> for Ident<'a> { } } -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Copy, Clone)] pub enum BinaryOperator { /// `+` Add, @@ -93,7 +95,7 @@ pub enum BinaryOperator { Neq, } -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Copy, Clone)] pub enum UnaryOperator { /// ! Not, @@ -102,12 +104,12 @@ pub enum UnaryOperator { Neg, } -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Clone)] pub enum Literal { Int(u64), } -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Clone)] pub enum Expr<'a> { Ident(Ident<'a>), @@ -134,33 +136,101 @@ pub enum Expr<'a> { then: Box<Expr<'a>>, else_: Box<Expr<'a>>, }, + + Fun(Box<Fun<'a>>), + + Call { + fun: Box<Expr<'a>>, + args: Vec<Expr<'a>>, + }, } -#[derive(Debug, PartialEq, Eq)] +impl<'a> Expr<'a> { + pub fn to_owned(&self) -> Expr<'static> { + match self { + Expr::Ident(ref id) => Expr::Ident(id.to_owned()), + Expr::Literal(ref lit) => Expr::Literal(lit.clone()), + Expr::UnaryOp { op, rhs } => Expr::UnaryOp { + op: *op, + rhs: Box::new((**rhs).to_owned()), + }, + Expr::BinaryOp { lhs, op, rhs } => Expr::BinaryOp { + lhs: Box::new((**lhs).to_owned()), + op: *op, + rhs: Box::new((**rhs).to_owned()), + }, + Expr::Let { bindings, body } => Expr::Let { + bindings: bindings + .iter() + .map(|(id, expr)| (id.to_owned(), expr.to_owned())) + .collect(), + body: Box::new((**body).to_owned()), + }, + Expr::If { + condition, + then, + else_, + } => Expr::If { + condition: Box::new((**condition).to_owned()), + then: Box::new((**then).to_owned()), + else_: Box::new((**else_).to_owned()), + }, + Expr::Fun(fun) => Expr::Fun(Box::new((**fun).to_owned())), + Expr::Call { fun, args } => Expr::Call { + fun: Box::new((**fun).to_owned()), + args: args.iter().map(|arg| arg.to_owned()).collect(), + }, + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] pub struct Fun<'a> { - pub name: Ident<'a>, pub args: Vec<Ident<'a>>, pub body: Expr<'a>, } +impl<'a> Fun<'a> { + fn to_owned(&self) -> Fun<'static> { + Fun { + args: self.args.iter().map(|arg| arg.to_owned()).collect(), + body: self.body.to_owned(), + } + } +} + #[derive(Debug, PartialEq, Eq)] pub enum Decl<'a> { - Fun(Fun<'a>), + Fun { name: Ident<'a>, body: Fun<'a> }, +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct FunctionType { + pub args: Vec<Type>, + pub ret: Box<Type>, +} + +impl Display for FunctionType { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "fn {} -> {}", self.args.iter().join(", "), self.ret) + } } -#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[derive(Debug, PartialEq, Eq, Clone)] pub enum Type { Int, Float, Bool, + Function(FunctionType), } impl Display for Type { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { - Self::Int => f.write_str("int"), - Self::Float => f.write_str("float"), - Self::Bool => f.write_str("bool"), + Type::Int => f.write_str("int"), + Type::Float => f.write_str("float"), + Type::Bool => f.write_str("bool"), + Type::Function(ft) => ft.fmt(f), } } } diff --git a/src/codegen/llvm.rs b/src/codegen/llvm.rs index ff92c3373176..1d1c742a9434 100644 --- a/src/codegen/llvm.rs +++ b/src/codegen/llvm.rs @@ -1,3 +1,4 @@ +use std::convert::{TryFrom, TryInto}; use std::path::Path; use std::result; @@ -7,7 +8,7 @@ pub use inkwell::context::Context; use inkwell::module::Module; use inkwell::support::LLVMString; use inkwell::types::FunctionType; -use inkwell::values::{BasicValueEnum, FunctionValue}; +use inkwell::values::{AnyValueEnum, BasicValueEnum, FunctionValue}; use inkwell::IntPredicate; use thiserror::Error; @@ -35,8 +36,9 @@ pub struct Codegen<'ctx, 'ast> { context: &'ctx Context, pub module: Module<'ctx>, builder: Builder<'ctx>, - env: Env<'ast, BasicValueEnum<'ctx>>, - function: Option<FunctionValue<'ctx>>, + env: Env<'ast, AnyValueEnum<'ctx>>, + function_stack: Vec<FunctionValue<'ctx>>, + identifier_counter: u32, } impl<'ctx, 'ast> Codegen<'ctx, 'ast> { @@ -48,7 +50,8 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { module, builder, env: Default::default(), - function: None, + function_stack: Default::default(), + identifier_counter: 0, } } @@ -57,22 +60,24 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { name: &str, ty: FunctionType<'ctx>, ) -> &'a FunctionValue<'ctx> { - self.function = Some(self.module.add_function(name, ty, None)); + self.function_stack + .push(self.module.add_function(name, ty, None)); let basic_block = self.append_basic_block("entry"); self.builder.position_at_end(basic_block); - self.function.as_ref().unwrap() + self.function_stack.last().unwrap() } - pub fn finish_function(&self, res: &BasicValueEnum<'ctx>) { + pub fn finish_function(&mut self, res: &BasicValueEnum<'ctx>) -> FunctionValue<'ctx> { self.builder.build_return(Some(res)); + self.function_stack.pop().unwrap() } pub fn append_basic_block(&self, name: &str) -> BasicBlock<'ctx> { self.context - .append_basic_block(self.function.unwrap(), name) + .append_basic_block(*self.function_stack.last().unwrap(), name) } - pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast>) -> Result<BasicValueEnum<'ctx>> { + pub fn codegen_expr(&mut self, expr: &'ast Expr<'ast>) -> Result<AnyValueEnum<'ctx>> { match expr { Expr::Ident(id) => self .env @@ -81,13 +86,13 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { .ok_or_else(|| Error::UndefinedVariable(id.to_owned())), Expr::Literal(Literal::Int(i)) => { let ty = self.context.i64_type(); - Ok(BasicValueEnum::IntValue(ty.const_int(*i, false))) + Ok(AnyValueEnum::IntValue(ty.const_int(*i, false))) } Expr::UnaryOp { op, rhs } => { let rhs = self.codegen_expr(rhs)?; match op { UnaryOperator::Not => unimplemented!(), - UnaryOperator::Neg => Ok(BasicValueEnum::IntValue( + UnaryOperator::Neg => Ok(AnyValueEnum::IntValue( self.builder.build_int_neg(rhs.into_int_value(), "neg"), )), } @@ -96,29 +101,23 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { let lhs = self.codegen_expr(lhs)?; let rhs = self.codegen_expr(rhs)?; match op { - BinaryOperator::Add => { - Ok(BasicValueEnum::IntValue(self.builder.build_int_add( - lhs.into_int_value(), - rhs.into_int_value(), - "add", - ))) - } - BinaryOperator::Sub => { - Ok(BasicValueEnum::IntValue(self.builder.build_int_sub( - lhs.into_int_value(), - rhs.into_int_value(), - "add", - ))) - } - BinaryOperator::Mul => { - Ok(BasicValueEnum::IntValue(self.builder.build_int_sub( - lhs.into_int_value(), - rhs.into_int_value(), - "add", - ))) - } + BinaryOperator::Add => Ok(AnyValueEnum::IntValue(self.builder.build_int_add( + lhs.into_int_value(), + rhs.into_int_value(), + "add", + ))), + BinaryOperator::Sub => Ok(AnyValueEnum::IntValue(self.builder.build_int_sub( + lhs.into_int_value(), + rhs.into_int_value(), + "add", + ))), + BinaryOperator::Mul => Ok(AnyValueEnum::IntValue(self.builder.build_int_sub( + lhs.into_int_value(), + rhs.into_int_value(), + "add", + ))), BinaryOperator::Div => { - Ok(BasicValueEnum::IntValue(self.builder.build_int_signed_div( + Ok(AnyValueEnum::IntValue(self.builder.build_int_signed_div( lhs.into_int_value(), rhs.into_int_value(), "add", @@ -126,7 +125,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { } BinaryOperator::Pow => unimplemented!(), BinaryOperator::Equ => { - Ok(BasicValueEnum::IntValue(self.builder.build_int_compare( + Ok(AnyValueEnum::IntValue(self.builder.build_int_compare( IntPredicate::EQ, lhs.into_int_value(), rhs.into_int_value(), @@ -170,34 +169,83 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { self.builder.position_at_end(join_block); let phi = self.builder.build_phi(self.context.i64_type(), "join"); - phi.add_incoming(&[(&then_res, then_block), (&else_res, else_block)]); - Ok(phi.as_basic_value()) + phi.add_incoming(&[ + (&BasicValueEnum::try_from(then_res).unwrap(), then_block), + (&BasicValueEnum::try_from(else_res).unwrap(), else_block), + ]); + Ok(phi.as_basic_value().into()) + } + Expr::Call { fun, args } => { + if let Expr::Ident(id) = &**fun { + let function = self + .module + .get_function(id.into()) + .or_else(|| self.env.resolve(id)?.clone().try_into().ok()) + .ok_or_else(|| Error::UndefinedVariable(id.to_owned()))?; + let args = args + .iter() + .map(|arg| Ok(self.codegen_expr(arg)?.try_into().unwrap())) + .collect::<Result<Vec<_>>>()?; + Ok(self + .builder + .build_call(function, &args, "call") + .try_as_basic_value() + .left() + .unwrap() + .into()) + } else { + todo!() + } + } + Expr::Fun(fun) => { + let Fun { args, body } = &**fun; + let fname = self.fresh_ident("f"); + let cur_block = self.builder.get_insert_block().unwrap(); + let env = self.env.save(); // TODO: closures + let function = self.codegen_function(&fname, args, body)?; + self.builder.position_at_end(cur_block); + self.env.restore(env); + Ok(function.into()) } } } + pub fn codegen_function( + &mut self, + name: &str, + args: &'ast [Ident<'ast>], + body: &'ast Expr<'ast>, + ) -> Result<FunctionValue<'ctx>> { + let i64_type = self.context.i64_type(); + self.new_function( + name, + i64_type.fn_type( + args.iter() + .map(|_| i64_type.into()) + .collect::<Vec<_>>() + .as_slice(), + false, + ), + ); + self.env.push(); + for (i, arg) in args.iter().enumerate() { + self.env.set( + arg, + self.cur_function().get_nth_param(i as u32).unwrap().into(), + ); + } + let res = self.codegen_expr(body)?.try_into().unwrap(); + self.env.pop(); + Ok(self.finish_function(&res)) + } + pub fn codegen_decl(&mut self, decl: &'ast Decl<'ast>) -> Result<()> { match decl { - Decl::Fun(Fun { name, args, body }) => { - let i64_type = self.context.i64_type(); - self.new_function( - name.into(), - i64_type.fn_type( - args.iter() - .map(|_| i64_type.into()) - .collect::<Vec<_>>() - .as_slice(), - false, - ), - ); - self.env.push(); - for (i, arg) in args.iter().enumerate() { - self.env - .set(arg, self.function.unwrap().get_nth_param(i as u32).unwrap()); - } - let res = self.codegen_expr(body)?; - self.env.pop(); - self.finish_function(&res); + Decl::Fun { + name, + body: Fun { args, body }, + } => { + self.codegen_function(name.into(), args, body)?; Ok(()) } } @@ -205,7 +253,7 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { pub fn codegen_main(&mut self, expr: &'ast Expr<'ast>) -> Result<()> { self.new_function("main", self.context.i64_type().fn_type(&[], false)); - let res = self.codegen_expr(expr)?; + let res = self.codegen_expr(expr)?.try_into().unwrap(); self.finish_function(&res); Ok(()) } @@ -229,6 +277,15 @@ impl<'ctx, 'ast> Codegen<'ctx, 'ast> { )) } } + + fn fresh_ident(&mut self, prefix: &str) -> String { + self.identifier_counter += 1; + format!("{}{}", prefix, self.identifier_counter) + } + + fn cur_function(&self) -> &FunctionValue<'ctx> { + self.function_stack.last().unwrap() + } } #[cfg(test)] @@ -248,9 +305,7 @@ mod tests { .create_jit_execution_engine(OptimizationLevel::None) .unwrap(); - codegen.new_function("test", context.i64_type().fn_type(&[], false)); - let res = codegen.codegen_expr(&expr)?; - codegen.finish_function(&res); + codegen.codegen_function("test", &[], &expr)?; unsafe { let fun: JitFunction<unsafe extern "C" fn() -> T> = @@ -279,4 +334,10 @@ mod tests { 2 ); } + + #[test] + fn function_call() { + let res = jit_eval::<i64>("let id = fn x = x in id 1").unwrap(); + assert_eq!(res, 1); + } } diff --git a/src/codegen/mod.rs b/src/codegen/mod.rs index 4620b6b48e84..6f95d90b45a1 100644 --- a/src/codegen/mod.rs +++ b/src/codegen/mod.rs @@ -14,9 +14,7 @@ pub fn jit_eval<T>(expr: &Expr) -> Result<T> { .module .create_jit_execution_engine(OptimizationLevel::None) .map_err(Error::from)?; - codegen.new_function("eval", context.i64_type().fn_type(&[], false)); - let res = codegen.codegen_expr(&expr)?; - codegen.finish_function(&res); + codegen.codegen_function("test", &[], &expr)?; unsafe { let fun: JitFunction<unsafe extern "C" fn() -> T> = diff --git a/src/commands/compile.rs b/src/commands/compile.rs index e16b8c87a659..be8767575ab5 100644 --- a/src/commands/compile.rs +++ b/src/commands/compile.rs @@ -5,10 +5,13 @@ use clap::Clap; use crate::common::Result; use crate::compiler::{self, CompilerOptions}; +/// Compile a source file #[derive(Clap)] pub struct Compile { + /// File to compile file: PathBuf, + /// Output file #[clap(short = 'o')] out_file: PathBuf, diff --git a/src/common/env.rs b/src/common/env.rs index 8b5cde49e9e4..f499323639e3 100644 --- a/src/common/env.rs +++ b/src/common/env.rs @@ -1,4 +1,5 @@ use std::collections::HashMap; +use std::mem; use crate::ast::Ident; @@ -25,6 +26,14 @@ impl<'ast, V> Env<'ast, V> { self.0.pop(); } + pub fn save(&mut self) -> Self { + mem::take(self) + } + + pub fn restore(&mut self, saved: Self) { + *self = saved; + } + pub fn set(&mut self, k: &'ast Ident<'ast>, v: V) { self.0.last_mut().unwrap().insert(k, v); } diff --git a/src/interpreter/mod.rs b/src/interpreter/mod.rs index adff3568c2c3..fc1556d1c292 100644 --- a/src/interpreter/mod.rs +++ b/src/interpreter/mod.rs @@ -2,13 +2,13 @@ mod error; mod value; pub use self::error::{Error, Result}; -pub use self::value::Value; -use crate::ast::{BinaryOperator, Expr, Ident, Literal, UnaryOperator}; +pub use self::value::{Function, Value}; +use crate::ast::{BinaryOperator, Expr, FunctionType, Ident, Literal, Type, UnaryOperator}; use crate::common::env::Env; #[derive(Debug, Default)] pub struct Interpreter<'a> { - env: Env<'a, Value>, + env: Env<'a, Value<'a>>, } impl<'a> Interpreter<'a> { @@ -16,14 +16,14 @@ impl<'a> Interpreter<'a> { Self::default() } - fn resolve(&self, var: &'a Ident<'a>) -> Result<Value> { + fn resolve(&self, var: &'a Ident<'a>) -> Result<Value<'a>> { self.env .resolve(var) .cloned() .ok_or_else(|| Error::UndefinedVariable(var.to_owned())) } - pub fn eval(&mut self, expr: &'a Expr<'a>) -> Result<Value> { + pub fn eval(&mut self, expr: &'a Expr<'a>) -> Result<Value<'a>> { match expr { Expr::Ident(id) => self.resolve(id), Expr::Literal(Literal::Int(i)) => Ok((*i).into()), @@ -63,12 +63,44 @@ impl<'a> Interpreter<'a> { else_, } => { let condition = self.eval(condition)?; - if *(condition.into_type::<bool>()?) { + if *(condition.as_type::<bool>()?) { self.eval(then) } else { self.eval(else_) } } + Expr::Call { ref fun, args } => { + let fun = self.eval(fun)?; + let expected_type = FunctionType { + args: args.iter().map(|_| Type::Int).collect(), + ret: Box::new(Type::Int), + }; + + let Function { + args: function_args, + body, + .. + } = fun.as_function(expected_type)?; + let arg_values = function_args.iter().zip( + args.iter() + .map(|v| self.eval(v)) + .collect::<Result<Vec<_>>>()?, + ); + let mut interpreter = Interpreter::new(); + for (arg_name, arg_value) in arg_values { + interpreter.env.set(arg_name, arg_value); + } + Ok(Value::from(*interpreter.eval(body)?.as_type::<i64>()?)) + } + Expr::Fun(fun) => Ok(Value::from(value::Function { + // TODO + type_: FunctionType { + args: fun.args.iter().map(|_| Type::Int).collect(), + ret: Box::new(Type::Int), + }, + args: fun.args.iter().map(|arg| arg.to_owned()).collect(), + body: fun.body.to_owned(), + })), } } } @@ -92,12 +124,12 @@ mod tests { fn parse_eval<T>(src: &str) -> T where - for<'a> &'a T: TryFrom<&'a Val>, + for<'a> &'a T: TryFrom<&'a Val<'a>>, T: Clone + TypeOf, { let expr = crate::parser::expr(src).unwrap().1; let res = eval(&expr).unwrap(); - res.into_type::<T>().unwrap().clone() + res.as_type::<T>().unwrap().clone() } #[test] @@ -108,7 +140,7 @@ mod tests { rhs: int_lit(2), }; let res = eval(&expr).unwrap(); - assert_eq!(*res.into_type::<i64>().unwrap(), 2); + assert_eq!(*res.as_type::<i64>().unwrap(), 2); } #[test] @@ -122,4 +154,10 @@ mod tests { let res = parse_eval::<i64>("let x = 1 in if x == 1 then 2 else 4"); assert_eq!(res, 2); } + + #[test] + fn function_call() { + let res = parse_eval::<i64>("let id = fn x = x in id 1"); + assert_eq!(res, 1); + } } diff --git a/src/interpreter/value.rs b/src/interpreter/value.rs index 69e4d4ffeb96..496e9c4230de 100644 --- a/src/interpreter/value.rs +++ b/src/interpreter/value.rs @@ -6,108 +6,153 @@ use std::rc::Rc; use derive_more::{Deref, From, TryInto}; use super::{Error, Result}; -use crate::ast::Type; +use crate::ast::{Expr, FunctionType, Ident, Type}; -#[derive(Debug, PartialEq, From, TryInto)] +#[derive(Debug, Clone)] +pub struct Function<'a> { + pub type_: FunctionType, + pub args: Vec<Ident<'a>>, + pub body: Expr<'a>, +} + +#[derive(From, TryInto)] #[try_into(owned, ref)] -pub enum Val { +pub enum Val<'a> { Int(i64), Float(f64), Bool(bool), + Function(Function<'a>), +} + +impl<'a> fmt::Debug for Val<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Val::Int(x) => f.debug_tuple("Int").field(x).finish(), + Val::Float(x) => f.debug_tuple("Float").field(x).finish(), + Val::Bool(x) => f.debug_tuple("Bool").field(x).finish(), + Val::Function(Function { type_, .. }) => { + f.debug_struct("Function").field("type_", type_).finish() + } + } + } } -impl From<u64> for Val { +impl<'a> PartialEq for Val<'a> { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Val::Int(x), Val::Int(y)) => x == y, + (Val::Float(x), Val::Float(y)) => x == y, + (Val::Bool(x), Val::Bool(y)) => x == y, + (Val::Function(_), Val::Function(_)) => false, + (_, _) => false, + } + } +} + +impl<'a> From<u64> for Val<'a> { fn from(i: u64) -> Self { Self::from(i as i64) } } -impl Display for Val { +impl<'a> Display for Val<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Val::Int(x) => x.fmt(f), Val::Float(x) => x.fmt(f), Val::Bool(x) => x.fmt(f), + Val::Function(Function { type_, .. }) => write!(f, "<{}>", type_), } } } -impl Val { +impl<'a> Val<'a> { pub fn type_(&self) -> Type { match self { Val::Int(_) => Type::Int, Val::Float(_) => Type::Float, Val::Bool(_) => Type::Bool, + Val::Function(Function { type_, .. }) => Type::Function(type_.clone()), } } - pub fn into_type<'a, T>(&'a self) -> Result<&'a T> + pub fn as_type<'b, T>(&'b self) -> Result<&'b T> where - T: TypeOf + 'a + Clone, - &'a T: TryFrom<&'a Self>, + T: TypeOf + 'b + Clone, + &'b T: TryFrom<&'b Self>, { <&T>::try_from(self).map_err(|_| Error::InvalidType { actual: self.type_(), expected: <T as TypeOf>::type_of(), }) } + + pub fn as_function<'b>(&'b self, function_type: FunctionType) -> Result<&'b Function<'a>> { + match self { + Val::Function(f) if f.type_ == function_type => Ok(&f), + _ => Err(Error::InvalidType { + actual: self.type_(), + expected: Type::Function(function_type), + }), + } + } } #[derive(Debug, PartialEq, Clone, Deref)] -pub struct Value(Rc<Val>); +pub struct Value<'a>(Rc<Val<'a>>); -impl Display for Value { +impl<'a> Display for Value<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { self.0.fmt(f) } } -impl<T> From<T> for Value +impl<'a, T> From<T> for Value<'a> where - Val: From<T>, + Val<'a>: From<T>, { fn from(x: T) -> Self { Self(Rc::new(x.into())) } } -impl Neg for Value { - type Output = Result<Value>; +impl<'a> Neg for Value<'a> { + type Output = Result<Value<'a>>; fn neg(self) -> Self::Output { - Ok((-self.into_type::<i64>()?).into()) + Ok((-self.as_type::<i64>()?).into()) } } -impl Add for Value { - type Output = Result<Value>; +impl<'a> Add for Value<'a> { + type Output = Result<Value<'a>>; fn add(self, rhs: Self) -> Self::Output { - Ok((self.into_type::<i64>()? + rhs.into_type::<i64>()?).into()) + Ok((self.as_type::<i64>()? + rhs.as_type::<i64>()?).into()) } } -impl Sub for Value { - type Output = Result<Value>; +impl<'a> Sub for Value<'a> { + type Output = Result<Value<'a>>; fn sub(self, rhs: Self) -> Self::Output { - Ok((self.into_type::<i64>()? - rhs.into_type::<i64>()?).into()) + Ok((self.as_type::<i64>()? - rhs.as_type::<i64>()?).into()) } } -impl Mul for Value { - type Output = Result<Value>; +impl<'a> Mul for Value<'a> { + type Output = Result<Value<'a>>; fn mul(self, rhs: Self) -> Self::Output { - Ok((self.into_type::<i64>()? * rhs.into_type::<i64>()?).into()) + Ok((self.as_type::<i64>()? * rhs.as_type::<i64>()?).into()) } } -impl Div for Value { - type Output = Result<Value>; +impl<'a> Div for Value<'a> { + type Output = Result<Value<'a>>; fn div(self, rhs: Self) -> Self::Output { - Ok((self.into_type::<f64>()? / rhs.into_type::<f64>()?).into()) + Ok((self.as_type::<f64>()? / rhs.as_type::<f64>()?).into()) } } diff --git a/src/parser/mod.rs b/src/parser/mod.rs index 811450da0d61..be432b8adf56 100644 --- a/src/parser/mod.rs +++ b/src/parser/mod.rs @@ -156,6 +156,10 @@ where } } +fn is_reserved(s: &str) -> bool { + matches!(s, "if" | "then" | "else" | "let" | "in" | "fn") +} + fn ident<'a, E>(i: &'a str) -> nom::IResult<&'a str, Ident, E> where E: ParseError<&'a str>, @@ -170,7 +174,12 @@ where } idx += 1; } - Ok((&i[idx..], Ident::from_str_unchecked(&i[..idx]))) + let id = &i[..idx]; + if is_reserved(id) { + Err(nom::Err::Error(E::from_error_kind(i, ErrorKind::Satisfy))) + } else { + Ok((&i[idx..], Ident::from_str_unchecked(id))) + } } else { Err(nom::Err::Error(E::from_error_kind(i, ErrorKind::Satisfy))) } @@ -228,14 +237,65 @@ named!(if_(&str) -> Expr, do_parse! ( named!(ident_expr(&str) -> Expr, map!(ident, Expr::Ident)); +named!(paren_expr(&str) -> Expr, + delimited!(complete!(tag!("(")), expr, complete!(tag!(")")))); + +named!(funcref(&str) -> Expr, alt!( + ident_expr | + paren_expr +)); + +named!(no_arg_call(&str) -> Expr, do_parse!( + fun: funcref + >> multispace0 + >> complete!(tag!("()")) + >> (Expr::Call { + fun: Box::new(fun), + args: vec![], + }) +)); + +named!(fun_expr(&str) -> Expr, do_parse!( + tag!("fn") + >> multispace1 + >> args: separated_list0!(multispace1, ident) + >> multispace0 + >> char!('=') + >> multispace0 + >> body: expr + >> (Expr::Fun(Box::new(Fun { + args, + body + }))) +)); + +named!(arg(&str) -> Expr, alt!( + ident_expr | + literal | + paren_expr +)); + +named!(call_with_args(&str) -> Expr, do_parse!( + fun: funcref + >> multispace1 + >> args: separated_list1!(multispace1, arg) + >> (Expr::Call { + fun: Box::new(fun), + args + }) +)); + named!(simple_expr(&str) -> Expr, alt!( let_ | if_ | + fun_expr | literal | ident_expr )); named!(pub expr(&str) -> Expr, alt!( + no_arg_call | + call_with_args | map!(token_tree, |tt| { ExprParser.parse(&mut tt.into_iter()).unwrap() }) | @@ -243,8 +303,8 @@ named!(pub expr(&str) -> Expr, alt!( ////// -named!(fun(&str) -> Fun, do_parse!( - tag!("fn") +named!(fun_decl(&str) -> Decl, do_parse!( + complete!(tag!("fn")) >> multispace0 >> name: ident >> multispace1 @@ -253,21 +313,24 @@ named!(fun(&str) -> Fun, do_parse!( >> char!('=') >> multispace0 >> body: expr - >> (Fun { + >> (Decl::Fun { name, - args, - body + body: Fun { + args, + body + } }) )); named!(pub decl(&str) -> Decl, alt!( - fun => { |f| Decl::Fun(f) } + fun_decl )); -named!(pub toplevel(&str) -> Vec<Decl>, separated_list0!(multispace1, decl)); +named!(pub toplevel(&str) -> Vec<Decl>, many0!(decl)); #[cfg(test)] mod tests { + use nom_trace::print_trace; use std::convert::{TryFrom, TryInto}; use super::*; @@ -281,7 +344,9 @@ mod tests { macro_rules! test_parse { ($parser: ident, $src: expr) => {{ - let (rem, res) = $parser($src).unwrap(); + let res = $parser($src); + print_trace!(); + let (rem, res) = res.unwrap(); assert!( rem.is_empty(), "non-empty remainder: \"{}\", parsed: {:?}", @@ -435,11 +500,87 @@ mod tests { let res = test_parse!(decl, "fn id x = x"); assert_eq!( res, - Decl::Fun(Fun { + Decl::Fun { name: "id".try_into().unwrap(), - args: vec!["x".try_into().unwrap()], - body: *ident_expr("x"), - }) + body: Fun { + args: vec!["x".try_into().unwrap()], + body: *ident_expr("x"), + } + } ) } + + #[test] + fn no_arg_call() { + let res = test_parse!(expr, "f()"); + assert_eq!( + res, + Expr::Call { + fun: ident_expr("f"), + args: vec![] + } + ); + } + + #[test] + fn call_with_args() { + let res = test_parse!(expr, "f x 1"); + assert_eq!( + res, + Expr::Call { + fun: ident_expr("f"), + args: vec![*ident_expr("x"), Expr::Literal(Literal::Int(1))] + } + ) + } + + #[test] + fn call_funcref() { + let res = test_parse!(expr, "(let x = 1 in x) 2"); + assert_eq!( + res, + Expr::Call { + fun: Box::new(Expr::Let { + bindings: vec![( + Ident::try_from("x").unwrap(), + Expr::Literal(Literal::Int(1)) + )], + body: ident_expr("x") + }), + args: vec![Expr::Literal(Literal::Int(2))] + } + ) + } + + #[test] + fn anon_function() { + let res = test_parse!(expr, "let id = fn x = x in id 1"); + assert_eq!( + res, + Expr::Let { + bindings: vec![( + Ident::try_from("id").unwrap(), + Expr::Fun(Box::new(Fun { + args: vec![Ident::try_from("x").unwrap()], + body: *ident_expr("x") + })) + )], + body: Box::new(Expr::Call { + fun: ident_expr("id"), + args: vec![Expr::Literal(Literal::Int(1))], + }) + } + ); + } + + #[test] + fn multiple_decls() { + let res = test_parse!( + toplevel, + "fn id x = x + fn plus x y = x + y + fn main = plus (id 2) 7" + ); + assert_eq!(res.len(), 3); + } } |