From 6266c5d32f9ff651fcfc3a4cc0c68e89da56ca65 Mon Sep 17 00:00:00 2001 From: Griffin Smith Date: Sun, 11 Apr 2021 17:53:27 -0400 Subject: refactor(users/glittershark): Rename to grfn Rename my //users directory and all places that refer to glittershark to grfn, including nix references and documentation. This may require some extra attention inside of gerrit's database after it lands to allow me to actually push things. Change-Id: I4728b7ec2c60024392c1c1fa6e0d4a59b3e266fa Reviewed-on: https://cl.tvl.fyi/c/depot/+/2933 Tested-by: BuildkiteCI Reviewed-by: tazjin Reviewed-by: lukegb Reviewed-by: glittershark --- users/grfn/achilles/src/ast/hir.rs | 320 +++++++++ users/grfn/achilles/src/ast/mod.rs | 447 +++++++++++++ users/grfn/achilles/src/codegen/llvm.rs | 436 ++++++++++++ users/grfn/achilles/src/codegen/mod.rs | 25 + users/grfn/achilles/src/commands/check.rs | 39 ++ users/grfn/achilles/src/commands/compile.rs | 31 + users/grfn/achilles/src/commands/eval.rs | 32 + users/grfn/achilles/src/commands/mod.rs | 7 + users/grfn/achilles/src/common/env.rs | 59 ++ users/grfn/achilles/src/common/error.rs | 59 ++ users/grfn/achilles/src/common/mod.rs | 6 + users/grfn/achilles/src/common/namer.rs | 122 ++++ users/grfn/achilles/src/compiler.rs | 89 +++ users/grfn/achilles/src/interpreter/error.rs | 19 + users/grfn/achilles/src/interpreter/mod.rs | 182 +++++ users/grfn/achilles/src/interpreter/value.rs | 203 ++++++ users/grfn/achilles/src/main.rs | 36 + users/grfn/achilles/src/parser/expr.rs | 645 ++++++++++++++++++ users/grfn/achilles/src/parser/macros.rs | 16 + users/grfn/achilles/src/parser/mod.rs | 239 +++++++ users/grfn/achilles/src/parser/type_.rs | 127 ++++ users/grfn/achilles/src/passes/hir/mod.rs | 197 ++++++ users/grfn/achilles/src/passes/hir/monomorphize.rs | 139 ++++ .../src/passes/hir/strip_positive_units.rs | 189 ++++++ users/grfn/achilles/src/passes/mod.rs | 1 + users/grfn/achilles/src/tc/mod.rs | 745 +++++++++++++++++++++ 26 files changed, 4410 insertions(+) create mode 100644 users/grfn/achilles/src/ast/hir.rs create mode 100644 users/grfn/achilles/src/ast/mod.rs create mode 100644 users/grfn/achilles/src/codegen/llvm.rs create mode 100644 users/grfn/achilles/src/codegen/mod.rs create mode 100644 users/grfn/achilles/src/commands/check.rs create mode 100644 users/grfn/achilles/src/commands/compile.rs create mode 100644 users/grfn/achilles/src/commands/eval.rs create mode 100644 users/grfn/achilles/src/commands/mod.rs create mode 100644 users/grfn/achilles/src/common/env.rs create mode 100644 users/grfn/achilles/src/common/error.rs create mode 100644 users/grfn/achilles/src/common/mod.rs create mode 100644 users/grfn/achilles/src/common/namer.rs create mode 100644 users/grfn/achilles/src/compiler.rs create mode 100644 users/grfn/achilles/src/interpreter/error.rs create mode 100644 users/grfn/achilles/src/interpreter/mod.rs create mode 100644 users/grfn/achilles/src/interpreter/value.rs create mode 100644 users/grfn/achilles/src/main.rs create mode 100644 users/grfn/achilles/src/parser/expr.rs create mode 100644 users/grfn/achilles/src/parser/macros.rs create mode 100644 users/grfn/achilles/src/parser/mod.rs create mode 100644 users/grfn/achilles/src/parser/type_.rs create mode 100644 users/grfn/achilles/src/passes/hir/mod.rs create mode 100644 users/grfn/achilles/src/passes/hir/monomorphize.rs create mode 100644 users/grfn/achilles/src/passes/hir/strip_positive_units.rs create mode 100644 users/grfn/achilles/src/passes/mod.rs create mode 100644 users/grfn/achilles/src/tc/mod.rs (limited to 'users/grfn/achilles/src') diff --git a/users/grfn/achilles/src/ast/hir.rs b/users/grfn/achilles/src/ast/hir.rs new file mode 100644 index 000000000000..0d145d620bef --- /dev/null +++ b/users/grfn/achilles/src/ast/hir.rs @@ -0,0 +1,320 @@ +use std::collections::HashMap; + +use itertools::Itertools; + +use super::{BinaryOperator, Ident, Literal, UnaryOperator}; + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Binding<'a, T> { + pub ident: Ident<'a>, + pub type_: T, + pub body: Expr<'a, T>, +} + +impl<'a, T> Binding<'a, T> { + fn to_owned(&self) -> Binding<'static, T> + where + T: Clone, + { + Binding { + ident: self.ident.to_owned(), + type_: self.type_.clone(), + body: self.body.to_owned(), + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum Expr<'a, T> { + Ident(Ident<'a>, T), + + Literal(Literal<'a>, T), + + UnaryOp { + op: UnaryOperator, + rhs: Box>, + type_: T, + }, + + BinaryOp { + lhs: Box>, + op: BinaryOperator, + rhs: Box>, + type_: T, + }, + + Let { + bindings: Vec>, + body: Box>, + type_: T, + }, + + If { + condition: Box>, + then: Box>, + else_: Box>, + type_: T, + }, + + Fun { + type_args: Vec>, + args: Vec<(Ident<'a>, T)>, + body: Box>, + type_: T, + }, + + Call { + fun: Box>, + type_args: HashMap, T>, + args: Vec>, + type_: T, + }, +} + +impl<'a, T> Expr<'a, T> { + pub fn type_(&self) -> &T { + match self { + Expr::Ident(_, t) => t, + Expr::Literal(_, t) => t, + Expr::UnaryOp { type_, .. } => type_, + Expr::BinaryOp { type_, .. } => type_, + Expr::Let { type_, .. } => type_, + Expr::If { type_, .. } => type_, + Expr::Fun { type_, .. } => type_, + Expr::Call { type_, .. } => type_, + } + } + + pub fn traverse_type(self, f: F) -> Result, E> + where + F: Fn(T) -> Result + Clone, + { + match self { + Expr::Ident(id, t) => Ok(Expr::Ident(id, f(t)?)), + Expr::Literal(lit, t) => Ok(Expr::Literal(lit, f(t)?)), + Expr::UnaryOp { op, rhs, type_ } => Ok(Expr::UnaryOp { + op, + rhs: Box::new(rhs.traverse_type(f.clone())?), + type_: f(type_)?, + }), + Expr::BinaryOp { + lhs, + op, + rhs, + type_, + } => Ok(Expr::BinaryOp { + lhs: Box::new(lhs.traverse_type(f.clone())?), + op, + rhs: Box::new(rhs.traverse_type(f.clone())?), + type_: f(type_)?, + }), + Expr::Let { + bindings, + body, + type_, + } => Ok(Expr::Let { + bindings: bindings + .into_iter() + .map(|Binding { ident, type_, body }| { + Ok(Binding { + ident, + type_: f(type_)?, + body: body.traverse_type(f.clone())?, + }) + }) + .collect::, E>>()?, + body: Box::new(body.traverse_type(f.clone())?), + type_: f(type_)?, + }), + Expr::If { + condition, + then, + else_, + type_, + } => Ok(Expr::If { + condition: Box::new(condition.traverse_type(f.clone())?), + then: Box::new(then.traverse_type(f.clone())?), + else_: Box::new(else_.traverse_type(f.clone())?), + type_: f(type_)?, + }), + Expr::Fun { + args, + type_args, + body, + type_, + } => Ok(Expr::Fun { + args: args + .into_iter() + .map(|(id, t)| Ok((id, f.clone()(t)?))) + .collect::, E>>()?, + type_args, + body: Box::new(body.traverse_type(f.clone())?), + type_: f(type_)?, + }), + Expr::Call { + fun, + type_args, + args, + type_, + } => Ok(Expr::Call { + fun: Box::new(fun.traverse_type(f.clone())?), + type_args: type_args + .into_iter() + .map(|(id, ty)| Ok((id, f.clone()(ty)?))) + .collect::, E>>()?, + args: args + .into_iter() + .map(|e| e.traverse_type(f.clone())) + .collect::, E>>()?, + type_: f(type_)?, + }), + } + } + + pub fn to_owned(&self) -> Expr<'static, T> + where + T: Clone, + { + match self { + Expr::Ident(id, t) => Expr::Ident(id.to_owned(), t.clone()), + Expr::Literal(lit, t) => Expr::Literal(lit.to_owned(), t.clone()), + Expr::UnaryOp { op, rhs, type_ } => Expr::UnaryOp { + op: *op, + rhs: Box::new((**rhs).to_owned()), + type_: type_.clone(), + }, + Expr::BinaryOp { + lhs, + op, + rhs, + type_, + } => Expr::BinaryOp { + lhs: Box::new((**lhs).to_owned()), + op: *op, + rhs: Box::new((**rhs).to_owned()), + type_: type_.clone(), + }, + Expr::Let { + bindings, + body, + type_, + } => Expr::Let { + bindings: bindings.iter().map(|b| b.to_owned()).collect(), + body: Box::new((**body).to_owned()), + type_: type_.clone(), + }, + Expr::If { + condition, + then, + else_, + type_, + } => Expr::If { + condition: Box::new((**condition).to_owned()), + then: Box::new((**then).to_owned()), + else_: Box::new((**else_).to_owned()), + type_: type_.clone(), + }, + Expr::Fun { + args, + type_args, + body, + type_, + } => Expr::Fun { + args: args + .iter() + .map(|(id, t)| (id.to_owned(), t.clone())) + .collect(), + type_args: type_args.iter().map(|arg| arg.to_owned()).collect(), + body: Box::new((**body).to_owned()), + type_: type_.clone(), + }, + Expr::Call { + fun, + type_args, + args, + type_, + } => Expr::Call { + fun: Box::new((**fun).to_owned()), + type_args: type_args + .iter() + .map(|(id, t)| (id.to_owned(), t.clone())) + .collect(), + args: args.iter().map(|e| e.to_owned()).collect(), + type_: type_.clone(), + }, + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Decl<'a, T> { + Fun { + name: Ident<'a>, + type_args: Vec>, + args: Vec<(Ident<'a>, T)>, + body: Box>, + type_: T, + }, + + Extern { + name: Ident<'a>, + arg_types: Vec, + ret_type: T, + }, +} + +impl<'a, T> Decl<'a, T> { + pub fn name(&self) -> &Ident<'a> { + match self { + Decl::Fun { name, .. } => name, + Decl::Extern { name, .. } => name, + } + } + + pub fn set_name(&mut self, new_name: Ident<'a>) { + match self { + Decl::Fun { name, .. } => *name = new_name, + Decl::Extern { name, .. } => *name = new_name, + } + } + + pub fn type_(&self) -> Option<&T> { + match self { + Decl::Fun { type_, .. } => Some(type_), + Decl::Extern { .. } => None, + } + } + + pub fn traverse_type(self, f: F) -> Result, E> + where + F: Fn(T) -> Result + Clone, + { + match self { + Decl::Fun { + name, + type_args, + args, + body, + type_, + } => Ok(Decl::Fun { + name, + type_args, + args: args + .into_iter() + .map(|(id, t)| Ok((id, f(t)?))) + .try_collect()?, + body: Box::new(body.traverse_type(f.clone())?), + type_: f(type_)?, + }), + Decl::Extern { + name, + arg_types, + ret_type, + } => Ok(Decl::Extern { + name, + arg_types: arg_types.into_iter().map(f.clone()).try_collect()?, + ret_type: f(ret_type)?, + }), + } + } +} diff --git a/users/grfn/achilles/src/ast/mod.rs b/users/grfn/achilles/src/ast/mod.rs new file mode 100644 index 000000000000..7dc2de895709 --- /dev/null +++ b/users/grfn/achilles/src/ast/mod.rs @@ -0,0 +1,447 @@ +pub(crate) mod hir; + +use std::borrow::Cow; +use std::collections::HashMap; +use std::convert::TryFrom; +use std::fmt::{self, Display, Formatter}; + +use itertools::Itertools; + +#[derive(Debug, PartialEq, Eq)] +pub struct InvalidIdentifier<'a>(Cow<'a, str>); + +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +pub struct Ident<'a>(pub Cow<'a, str>); + +impl<'a> From<&'a Ident<'a>> for &'a str { + fn from(id: &'a Ident<'a>) -> Self { + id.0.as_ref() + } +} + +impl<'a> Display for Ident<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl<'a> Ident<'a> { + pub fn to_owned(&self) -> Ident<'static> { + Ident(Cow::Owned(self.0.clone().into_owned())) + } + + /// Construct an identifier from a &str without checking that it's a valid identifier + pub fn from_str_unchecked(s: &'a str) -> Self { + debug_assert!(is_valid_identifier(s)); + Self(Cow::Borrowed(s)) + } + + pub fn from_string_unchecked(s: String) -> Self { + debug_assert!(is_valid_identifier(&s)); + Self(Cow::Owned(s)) + } +} + +pub fn is_valid_identifier(s: &S) -> bool +where + S: AsRef + ?Sized, +{ + s.as_ref() + .chars() + .any(|c| !c.is_alphanumeric() || !"_".contains(c)) +} + +impl<'a> TryFrom<&'a str> for Ident<'a> { + type Error = InvalidIdentifier<'a>; + + fn try_from(s: &'a str) -> Result { + if is_valid_identifier(s) { + Ok(Ident(Cow::Borrowed(s))) + } else { + Err(InvalidIdentifier(Cow::Borrowed(s))) + } + } +} + +impl<'a> TryFrom for Ident<'a> { + type Error = InvalidIdentifier<'static>; + + fn try_from(s: String) -> Result { + if is_valid_identifier(&s) { + Ok(Ident(Cow::Owned(s))) + } else { + Err(InvalidIdentifier(Cow::Owned(s))) + } + } +} + +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub enum BinaryOperator { + /// `+` + Add, + + /// `-` + Sub, + + /// `*` + Mul, + + /// `/` + Div, + + /// `^` + Pow, + + /// `==` + Equ, + + /// `!=` + Neq, +} + +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub enum UnaryOperator { + /// ! + Not, + + /// - + Neg, +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum Literal<'a> { + Unit, + Int(u64), + Bool(bool), + String(Cow<'a, str>), +} + +impl<'a> Literal<'a> { + pub fn to_owned(&self) -> Literal<'static> { + match self { + Literal::Int(i) => Literal::Int(*i), + Literal::Bool(b) => Literal::Bool(*b), + Literal::String(s) => Literal::String(Cow::Owned(s.clone().into_owned())), + Literal::Unit => Literal::Unit, + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Binding<'a> { + pub ident: Ident<'a>, + pub type_: Option>, + pub body: Expr<'a>, +} + +impl<'a> Binding<'a> { + fn to_owned(&self) -> Binding<'static> { + Binding { + ident: self.ident.to_owned(), + type_: self.type_.as_ref().map(|t| t.to_owned()), + body: self.body.to_owned(), + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum Expr<'a> { + Ident(Ident<'a>), + + Literal(Literal<'a>), + + UnaryOp { + op: UnaryOperator, + rhs: Box>, + }, + + BinaryOp { + lhs: Box>, + op: BinaryOperator, + rhs: Box>, + }, + + Let { + bindings: Vec>, + body: Box>, + }, + + If { + condition: Box>, + then: Box>, + else_: Box>, + }, + + Fun(Box>), + + Call { + fun: Box>, + args: Vec>, + }, + + Ascription { + expr: Box>, + type_: Type<'a>, + }, +} + +impl<'a> Expr<'a> { + pub fn to_owned(&self) -> Expr<'static> { + match self { + Expr::Ident(ref id) => Expr::Ident(id.to_owned()), + Expr::Literal(ref lit) => Expr::Literal(lit.to_owned()), + Expr::UnaryOp { op, rhs } => Expr::UnaryOp { + op: *op, + rhs: Box::new((**rhs).to_owned()), + }, + Expr::BinaryOp { lhs, op, rhs } => Expr::BinaryOp { + lhs: Box::new((**lhs).to_owned()), + op: *op, + rhs: Box::new((**rhs).to_owned()), + }, + Expr::Let { bindings, body } => Expr::Let { + bindings: bindings.iter().map(|binding| binding.to_owned()).collect(), + body: Box::new((**body).to_owned()), + }, + Expr::If { + condition, + then, + else_, + } => Expr::If { + condition: Box::new((**condition).to_owned()), + then: Box::new((**then).to_owned()), + else_: Box::new((**else_).to_owned()), + }, + Expr::Fun(fun) => Expr::Fun(Box::new((**fun).to_owned())), + Expr::Call { fun, args } => Expr::Call { + fun: Box::new((**fun).to_owned()), + args: args.iter().map(|arg| arg.to_owned()).collect(), + }, + Expr::Ascription { expr, type_ } => Expr::Ascription { + expr: Box::new((**expr).to_owned()), + type_: type_.to_owned(), + }, + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Arg<'a> { + pub ident: Ident<'a>, + pub type_: Option>, +} + +impl<'a> Arg<'a> { + pub fn to_owned(&self) -> Arg<'static> { + Arg { + ident: self.ident.to_owned(), + type_: self.type_.as_ref().map(Type::to_owned), + } + } +} + +impl<'a> TryFrom<&'a str> for Arg<'a> { + type Error = as TryFrom<&'a str>>::Error; + + fn try_from(value: &'a str) -> Result { + Ok(Arg { + ident: Ident::try_from(value)?, + type_: None, + }) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Fun<'a> { + pub args: Vec>, + pub body: Expr<'a>, +} + +impl<'a> Fun<'a> { + pub fn to_owned(&self) -> Fun<'static> { + Fun { + args: self.args.iter().map(|arg| arg.to_owned()).collect(), + body: self.body.to_owned(), + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum Decl<'a> { + Fun { + name: Ident<'a>, + body: Fun<'a>, + }, + Ascription { + name: Ident<'a>, + type_: Type<'a>, + }, + Extern { + name: Ident<'a>, + type_: FunctionType<'a>, + }, +} + +//// + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct FunctionType<'a> { + pub args: Vec>, + pub ret: Box>, +} + +impl<'a> FunctionType<'a> { + pub fn to_owned(&self) -> FunctionType<'static> { + FunctionType { + args: self.args.iter().map(|a| a.to_owned()).collect(), + ret: Box::new((*self.ret).to_owned()), + } + } +} + +impl<'a> Display for FunctionType<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "fn {} -> {}", self.args.iter().join(", "), self.ret) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum Type<'a> { + Int, + Float, + Bool, + CString, + Unit, + Var(Ident<'a>), + Function(FunctionType<'a>), +} + +impl<'a> Type<'a> { + pub fn to_owned(&self) -> Type<'static> { + match self { + Type::Int => Type::Int, + Type::Float => Type::Float, + Type::Bool => Type::Bool, + Type::CString => Type::CString, + Type::Unit => Type::Unit, + Type::Var(v) => Type::Var(v.to_owned()), + Type::Function(f) => Type::Function(f.to_owned()), + } + } + + pub fn alpha_equiv(&self, other: &Self) -> bool { + fn do_alpha_equiv<'a>( + substs: &mut HashMap<&'a Ident<'a>, &'a Ident<'a>>, + lhs: &'a Type, + rhs: &'a Type, + ) -> bool { + match (lhs, rhs) { + (Type::Var(v1), Type::Var(v2)) => substs.entry(v1).or_insert(v2) == &v2, + ( + Type::Function(FunctionType { + args: args1, + ret: ret1, + }), + Type::Function(FunctionType { + args: args2, + ret: ret2, + }), + ) => { + args1.len() == args2.len() + && args1 + .iter() + .zip(args2) + .all(|(a1, a2)| do_alpha_equiv(substs, a1, a2)) + && do_alpha_equiv(substs, ret1, ret2) + } + _ => lhs == rhs, + } + } + + let mut substs = HashMap::new(); + do_alpha_equiv(&mut substs, self, other) + } + + pub fn traverse_type_vars<'b, F>(self, mut f: F) -> Type<'b> + where + F: FnMut(Ident<'a>) -> Type<'b> + Clone, + { + match self { + Type::Var(tv) => f(tv), + Type::Function(FunctionType { args, ret }) => Type::Function(FunctionType { + args: args + .into_iter() + .map(|t| t.traverse_type_vars(f.clone())) + .collect(), + ret: Box::new(ret.traverse_type_vars(f)), + }), + Type::Int => Type::Int, + Type::Float => Type::Float, + Type::Bool => Type::Bool, + Type::CString => Type::CString, + Type::Unit => Type::Unit, + } + } +} + +impl<'a> Display for Type<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Type::Int => f.write_str("int"), + Type::Float => f.write_str("float"), + Type::Bool => f.write_str("bool"), + Type::CString => f.write_str("cstring"), + Type::Unit => f.write_str("()"), + Type::Var(v) => v.fmt(f), + Type::Function(ft) => ft.fmt(f), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn type_var(n: &str) -> Type<'static> { + Type::Var(Ident::try_from(n.to_owned()).unwrap()) + } + + mod alpha_equiv { + use super::*; + + #[test] + fn trivial() { + assert!(Type::Int.alpha_equiv(&Type::Int)); + assert!(!Type::Int.alpha_equiv(&Type::Bool)); + } + + #[test] + fn simple_type_var() { + assert!(type_var("a").alpha_equiv(&type_var("b"))); + } + + #[test] + fn function_with_type_vars_equiv() { + assert!(Type::Function(FunctionType { + args: vec![type_var("a")], + ret: Box::new(type_var("b")), + }) + .alpha_equiv(&Type::Function(FunctionType { + args: vec![type_var("b")], + ret: Box::new(type_var("a")), + }))) + } + + #[test] + fn function_with_type_vars_non_equiv() { + assert!(!Type::Function(FunctionType { + args: vec![type_var("a")], + ret: Box::new(type_var("a")), + }) + .alpha_equiv(&Type::Function(FunctionType { + args: vec![type_var("b")], + ret: Box::new(type_var("a")), + }))) + } + } +} diff --git a/users/grfn/achilles/src/codegen/llvm.rs b/users/grfn/achilles/src/codegen/llvm.rs new file mode 100644 index 000000000000..17dec58b5ff7 --- /dev/null +++ b/users/grfn/achilles/src/codegen/llvm.rs @@ -0,0 +1,436 @@ +use std::convert::{TryFrom, TryInto}; +use std::path::Path; +use std::result; + +use inkwell::basic_block::BasicBlock; +use inkwell::builder::Builder; +pub use inkwell::context::Context; +use inkwell::module::Module; +use inkwell::support::LLVMString; +use inkwell::types::{BasicType, BasicTypeEnum, FunctionType, IntType}; +use inkwell::values::{AnyValueEnum, BasicValueEnum, FunctionValue}; +use inkwell::{AddressSpace, IntPredicate}; +use thiserror::Error; + +use crate::ast::hir::{Binding, Decl, Expr}; +use crate::ast::{BinaryOperator, Ident, Literal, Type, UnaryOperator}; +use crate::common::env::Env; + +#[derive(Debug, PartialEq, Eq, Error)] +pub enum Error { + #[error("Undefined variable {0}")] + UndefinedVariable(Ident<'static>), + + #[error("LLVM Error: {0}")] + LLVMError(String), +} + +impl From for Error { + fn from(s: LLVMString) -> Self { + Self::LLVMError(s.to_string()) + } +} + +pub type Result = result::Result; + +pub struct Codegen<'ctx, 'ast> { + context: &'ctx Context, + pub module: Module<'ctx>, + builder: Builder<'ctx>, + env: Env<&'ast Ident<'ast>, AnyValueEnum<'ctx>>, + function_stack: Vec>, + identifier_counter: u32, +} + +impl<'ctx, 'ast> Codegen<'ctx, 'ast> { + pub fn new(context: &'ctx Context, module_name: &str) -> Self { + let module = context.create_module(module_name); + let builder = context.create_builder(); + Self { + context, + module, + builder, + env: Default::default(), + function_stack: Default::default(), + identifier_counter: 0, + } + } + + pub fn new_function<'a>( + &'a mut self, + name: &str, + ty: FunctionType<'ctx>, + ) -> &'a FunctionValue<'ctx> { + self.function_stack + .push(self.module.add_function(name, ty, None)); + let basic_block = self.append_basic_block("entry"); + self.builder.position_at_end(basic_block); + self.function_stack.last().unwrap() + } + + pub fn finish_function(&mut self, res: Option<&BasicValueEnum<'ctx>>) -> FunctionValue<'ctx> { + self.builder.build_return(match res { + // lol + Some(val) => Some(val), + None => None, + }); + self.function_stack.pop().unwrap() + } + + pub fn append_basic_block(&self, name: &str) -> BasicBlock<'ctx> { + self.context + .append_basic_block(*self.function_stack.last().unwrap(), name) + } + + pub fn codegen_expr( + &mut self, + expr: &'ast Expr<'ast, Type>, + ) -> Result>> { + match expr { + Expr::Ident(id, _) => self + .env + .resolve(id) + .cloned() + .ok_or_else(|| Error::UndefinedVariable(id.to_owned())) + .map(Some), + Expr::Literal(lit, ty) => { + let ty = self.codegen_int_type(ty); + match lit { + Literal::Int(i) => Ok(Some(AnyValueEnum::IntValue(ty.const_int(*i, false)))), + Literal::Bool(b) => Ok(Some(AnyValueEnum::IntValue( + ty.const_int(if *b { 1 } else { 0 }, false), + ))), + Literal::String(s) => Ok(Some( + self.builder + .build_global_string_ptr(s, "s") + .as_pointer_value() + .into(), + )), + Literal::Unit => Ok(None), + } + } + Expr::UnaryOp { op, rhs, .. } => { + let rhs = self.codegen_expr(rhs)?.unwrap(); + match op { + UnaryOperator::Not => unimplemented!(), + UnaryOperator::Neg => Ok(Some(AnyValueEnum::IntValue( + self.builder.build_int_neg(rhs.into_int_value(), "neg"), + ))), + } + } + Expr::BinaryOp { lhs, op, rhs, .. } => { + let lhs = self.codegen_expr(lhs)?.unwrap(); + let rhs = self.codegen_expr(rhs)?.unwrap(); + match op { + BinaryOperator::Add => { + Ok(Some(AnyValueEnum::IntValue(self.builder.build_int_add( + lhs.into_int_value(), + rhs.into_int_value(), + "add", + )))) + } + BinaryOperator::Sub => { + Ok(Some(AnyValueEnum::IntValue(self.builder.build_int_sub( + lhs.into_int_value(), + rhs.into_int_value(), + "add", + )))) + } + BinaryOperator::Mul => { + Ok(Some(AnyValueEnum::IntValue(self.builder.build_int_sub( + lhs.into_int_value(), + rhs.into_int_value(), + "add", + )))) + } + BinaryOperator::Div => Ok(Some(AnyValueEnum::IntValue( + self.builder.build_int_signed_div( + lhs.into_int_value(), + rhs.into_int_value(), + "add", + ), + ))), + BinaryOperator::Pow => unimplemented!(), + BinaryOperator::Equ => Ok(Some(AnyValueEnum::IntValue( + self.builder.build_int_compare( + IntPredicate::EQ, + lhs.into_int_value(), + rhs.into_int_value(), + "eq", + ), + ))), + BinaryOperator::Neq => todo!(), + } + } + Expr::Let { bindings, body, .. } => { + self.env.push(); + for Binding { ident, body, .. } in bindings { + if let Some(val) = self.codegen_expr(body)? { + self.env.set(ident, val); + } + } + let res = self.codegen_expr(body); + self.env.pop(); + res + } + Expr::If { + condition, + then, + else_, + type_, + } => { + let then_block = self.append_basic_block("then"); + let else_block = self.append_basic_block("else"); + let join_block = self.append_basic_block("join"); + let condition = self.codegen_expr(condition)?.unwrap(); + self.builder.build_conditional_branch( + condition.into_int_value(), + then_block, + else_block, + ); + self.builder.position_at_end(then_block); + let then_res = self.codegen_expr(then)?; + self.builder.build_unconditional_branch(join_block); + + self.builder.position_at_end(else_block); + let else_res = self.codegen_expr(else_)?; + self.builder.build_unconditional_branch(join_block); + + self.builder.position_at_end(join_block); + if let Some(phi_type) = self.codegen_type(type_) { + let phi = self.builder.build_phi(phi_type, "join"); + phi.add_incoming(&[ + ( + &BasicValueEnum::try_from(then_res.unwrap()).unwrap(), + then_block, + ), + ( + &BasicValueEnum::try_from(else_res.unwrap()).unwrap(), + else_block, + ), + ]); + Ok(Some(phi.as_basic_value().into())) + } else { + Ok(None) + } + } + Expr::Call { fun, args, .. } => { + if let Expr::Ident(id, _) = &**fun { + let function = self + .module + .get_function(id.into()) + .or_else(|| self.env.resolve(id)?.clone().try_into().ok()) + .ok_or_else(|| Error::UndefinedVariable(id.to_owned()))?; + let args = args + .iter() + .map(|arg| Ok(self.codegen_expr(arg)?.unwrap().try_into().unwrap())) + .collect::>>()?; + Ok(self + .builder + .build_call(function, &args, "call") + .try_as_basic_value() + .left() + .map(|val| val.into())) + } else { + todo!() + } + } + Expr::Fun { args, body, .. } => { + let fname = self.fresh_ident("f"); + let cur_block = self.builder.get_insert_block().unwrap(); + let env = self.env.save(); // TODO: closures + let function = self.codegen_function(&fname, args, body)?; + self.builder.position_at_end(cur_block); + self.env.restore(env); + Ok(Some(function.into())) + } + } + } + + pub fn codegen_function( + &mut self, + name: &str, + args: &'ast [(Ident<'ast>, Type)], + body: &'ast Expr<'ast, Type>, + ) -> Result> { + let arg_types = args + .iter() + .filter_map(|(_, at)| self.codegen_type(at)) + .collect::>(); + + self.new_function( + name, + match self.codegen_type(body.type_()) { + Some(ret_ty) => ret_ty.fn_type(&arg_types, false), + None => self.context.void_type().fn_type(&arg_types, false), + }, + ); + self.env.push(); + for (i, (arg, _)) in args.iter().enumerate() { + self.env.set( + arg, + self.cur_function().get_nth_param(i as u32).unwrap().into(), + ); + } + let res = self.codegen_expr(body)?; + self.env.pop(); + Ok(self.finish_function(res.map(|av| av.try_into().unwrap()).as_ref())) + } + + pub fn codegen_extern( + &mut self, + name: &str, + args: &'ast [Type], + ret: &'ast Type, + ) -> Result<()> { + let arg_types = args + .iter() + .map(|t| self.codegen_type(t).unwrap()) + .collect::>(); + self.module.add_function( + name, + match self.codegen_type(ret) { + Some(ret_ty) => ret_ty.fn_type(&arg_types, false), + None => self.context.void_type().fn_type(&arg_types, false), + }, + None, + ); + Ok(()) + } + + pub fn codegen_decl(&mut self, decl: &'ast Decl<'ast, Type>) -> Result<()> { + match decl { + Decl::Fun { + name, args, body, .. + } => { + self.codegen_function(name.into(), args, body)?; + Ok(()) + } + Decl::Extern { + name, + arg_types, + ret_type, + } => self.codegen_extern(name.into(), arg_types, ret_type), + } + } + + pub fn codegen_main(&mut self, expr: &'ast Expr<'ast, Type>) -> Result<()> { + self.new_function("main", self.context.i64_type().fn_type(&[], false)); + let res = self.codegen_expr(expr)?; + if *expr.type_() != Type::Int { + self.builder + .build_return(Some(&self.context.i64_type().const_int(0, false))); + } else { + self.finish_function(res.map(|r| r.try_into().unwrap()).as_ref()); + } + Ok(()) + } + + fn codegen_type(&self, type_: &'ast Type) -> Option> { + // TODO + match type_ { + Type::Int => Some(self.context.i64_type().into()), + Type::Float => Some(self.context.f64_type().into()), + Type::Bool => Some(self.context.bool_type().into()), + Type::CString => Some( + self.context + .i8_type() + .ptr_type(AddressSpace::Generic) + .into(), + ), + Type::Function(_) => todo!(), + Type::Var(_) => unreachable!(), + Type::Unit => None, + } + } + + fn codegen_int_type(&self, type_: &'ast Type) -> IntType<'ctx> { + // TODO + self.context.i64_type() + } + + pub fn print_to_file

(&self, path: P) -> Result<()> + where + P: AsRef, + { + Ok(self.module.print_to_file(path)?) + } + + pub fn binary_to_file

(&self, path: P) -> Result<()> + where + P: AsRef, + { + if self.module.write_bitcode_to_path(path.as_ref()) { + Ok(()) + } else { + Err(Error::LLVMError( + "Error writing bitcode to output path".to_owned(), + )) + } + } + + fn fresh_ident(&mut self, prefix: &str) -> String { + self.identifier_counter += 1; + format!("{}{}", prefix, self.identifier_counter) + } + + fn cur_function(&self) -> &FunctionValue<'ctx> { + self.function_stack.last().unwrap() + } +} + +#[cfg(test)] +mod tests { + use inkwell::execution_engine::JitFunction; + use inkwell::OptimizationLevel; + + use super::*; + + fn jit_eval(expr: &str) -> anyhow::Result { + let expr = crate::parser::expr(expr).unwrap().1; + + let expr = crate::tc::typecheck_expr(expr).unwrap(); + + let context = Context::create(); + let mut codegen = Codegen::new(&context, "test"); + let execution_engine = codegen + .module + .create_jit_execution_engine(OptimizationLevel::None) + .unwrap(); + + codegen.codegen_function("test", &[], &expr)?; + + unsafe { + let fun: JitFunction T> = + execution_engine.get_function("test")?; + Ok(fun.call()) + } + } + + #[test] + fn add_literals() { + assert_eq!(jit_eval::("1 + 2").unwrap(), 3); + } + + #[test] + fn variable_shadowing() { + assert_eq!( + jit_eval::("let x = 1 in (let x = 2 in x) + x").unwrap(), + 3 + ); + } + + #[test] + fn eq() { + assert_eq!( + jit_eval::("let x = 1 in if x == 1 then 2 else 4").unwrap(), + 2 + ); + } + + #[test] + fn function_call() { + let res = jit_eval::("let id = fn x = x in id 1").unwrap(); + assert_eq!(res, 1); + } +} diff --git a/users/grfn/achilles/src/codegen/mod.rs b/users/grfn/achilles/src/codegen/mod.rs new file mode 100644 index 000000000000..8ef057dba04f --- /dev/null +++ b/users/grfn/achilles/src/codegen/mod.rs @@ -0,0 +1,25 @@ +pub mod llvm; + +use inkwell::execution_engine::JitFunction; +use inkwell::OptimizationLevel; +pub use llvm::*; + +use crate::ast::hir::Expr; +use crate::ast::Type; +use crate::common::Result; + +pub fn jit_eval(expr: &Expr) -> Result { + let context = Context::create(); + let mut codegen = Codegen::new(&context, "eval"); + let execution_engine = codegen + .module + .create_jit_execution_engine(OptimizationLevel::None) + .map_err(Error::from)?; + codegen.codegen_function("test", &[], &expr)?; + + unsafe { + let fun: JitFunction T> = + execution_engine.get_function("eval").unwrap(); + Ok(fun.call()) + } +} diff --git a/users/grfn/achilles/src/commands/check.rs b/users/grfn/achilles/src/commands/check.rs new file mode 100644 index 000000000000..0bea482c1478 --- /dev/null +++ b/users/grfn/achilles/src/commands/check.rs @@ -0,0 +1,39 @@ +use clap::Clap; +use std::path::PathBuf; + +use crate::ast::Type; +use crate::{parser, tc, Result}; + +/// Typecheck a file or expression +#[derive(Clap)] +pub struct Check { + /// File to check + path: Option, + + /// Expression to check + #[clap(long, short = 'e')] + expr: Option, +} + +fn run_expr(expr: String) -> Result> { + let (_, parsed) = parser::expr(&expr)?; + let hir_expr = tc::typecheck_expr(parsed)?; + Ok(hir_expr.type_().to_owned()) +} + +fn run_path(path: PathBuf) -> Result> { + todo!() +} + +impl Check { + pub fn run(self) -> Result<()> { + let type_ = match (self.path, self.expr) { + (None, None) => Err("Must specify either a file or expression to check".into()), + (Some(_), Some(_)) => Err("Cannot specify both a file and expression to check".into()), + (None, Some(expr)) => run_expr(expr), + (Some(path), None) => run_path(path), + }?; + println!("type: {}", type_); + Ok(()) + } +} diff --git a/users/grfn/achilles/src/commands/compile.rs b/users/grfn/achilles/src/commands/compile.rs new file mode 100644 index 000000000000..be8767575ab5 --- /dev/null +++ b/users/grfn/achilles/src/commands/compile.rs @@ -0,0 +1,31 @@ +use std::path::PathBuf; + +use clap::Clap; + +use crate::common::Result; +use crate::compiler::{self, CompilerOptions}; + +/// Compile a source file +#[derive(Clap)] +pub struct Compile { + /// File to compile + file: PathBuf, + + /// Output file + #[clap(short = 'o')] + out_file: PathBuf, + + #[clap(flatten)] + options: CompilerOptions, +} + +impl Compile { + pub fn run(self) -> Result<()> { + eprintln!( + ">>> {} -> {}", + &self.file.to_string_lossy(), + self.out_file.to_string_lossy() + ); + compiler::compile_file(&self.file, &self.out_file, &self.options) + } +} diff --git a/users/grfn/achilles/src/commands/eval.rs b/users/grfn/achilles/src/commands/eval.rs new file mode 100644 index 000000000000..61a712c08a8e --- /dev/null +++ b/users/grfn/achilles/src/commands/eval.rs @@ -0,0 +1,32 @@ +use clap::Clap; + +use crate::codegen; +use crate::interpreter; +use crate::parser; +use crate::tc; +use crate::Result; + +/// Evaluate an expression and print its result +#[derive(Clap)] +pub struct Eval { + /// JIT-compile with LLVM instead of interpreting + #[clap(long)] + jit: bool, + + /// Expression to evaluate + expr: String, +} + +impl Eval { + pub fn run(self) -> Result<()> { + let (_, parsed) = parser::expr(&self.expr)?; + let hir = tc::typecheck_expr(parsed)?; + let result = if self.jit { + codegen::jit_eval::(&hir)?.into() + } else { + interpreter::eval(&hir)? + }; + println!("{}", result); + Ok(()) + } +} diff --git a/users/grfn/achilles/src/commands/mod.rs b/users/grfn/achilles/src/commands/mod.rs new file mode 100644 index 000000000000..fd0a822708c2 --- /dev/null +++ b/users/grfn/achilles/src/commands/mod.rs @@ -0,0 +1,7 @@ +pub mod check; +pub mod compile; +pub mod eval; + +pub use check::Check; +pub use compile::Compile; +pub use eval::Eval; diff --git a/users/grfn/achilles/src/common/env.rs b/users/grfn/achilles/src/common/env.rs new file mode 100644 index 000000000000..59a5e46c466f --- /dev/null +++ b/users/grfn/achilles/src/common/env.rs @@ -0,0 +1,59 @@ +use std::borrow::Borrow; +use std::collections::HashMap; +use std::hash::Hash; +use std::mem; + +/// A lexical environment +#[derive(Debug, PartialEq, Eq)] +pub struct Env(Vec>); + +impl Default for Env +where + K: Eq + Hash, +{ + fn default() -> Self { + Self::new() + } +} + +impl Env +where + K: Eq + Hash, +{ + pub fn new() -> Self { + Self(vec![Default::default()]) + } + + pub fn push(&mut self) { + self.0.push(Default::default()); + } + + pub fn pop(&mut self) { + self.0.pop(); + } + + pub fn save(&mut self) -> Self { + mem::take(self) + } + + pub fn restore(&mut self, saved: Self) { + *self = saved; + } + + pub fn set(&mut self, k: K, v: V) { + self.0.last_mut().unwrap().insert(k, v); + } + + pub fn resolve<'a, Q>(&'a self, k: &Q) -> Option<&'a V> + where + K: Borrow, + Q: Hash + Eq + ?Sized, + { + for ctx in self.0.iter().rev() { + if let Some(res) = ctx.get(k) { + return Some(res); + } + } + None + } +} diff --git a/users/grfn/achilles/src/common/error.rs b/users/grfn/achilles/src/common/error.rs new file mode 100644 index 000000000000..51575a895e91 --- /dev/null +++ b/users/grfn/achilles/src/common/error.rs @@ -0,0 +1,59 @@ +use std::{io, result}; + +use thiserror::Error; + +use crate::{codegen, interpreter, parser, tc}; + +#[derive(Error, Debug)] +pub enum Error { + #[error(transparent)] + IOError(#[from] io::Error), + + #[error("Error parsing input: {0}")] + ParseError(#[from] parser::Error), + + #[error("Error evaluating expression: {0}")] + EvalError(#[from] interpreter::Error), + + #[error("Compile error: {0}")] + CodegenError(#[from] codegen::Error), + + #[error("Type error: {0}")] + TypeError(#[from] tc::Error), + + #[error("{0}")] + Message(String), +} + +impl From for Error { + fn from(s: String) -> Self { + Self::Message(s) + } +} + +impl<'a> From<&'a str> for Error { + fn from(s: &'a str) -> Self { + Self::Message(s.to_owned()) + } +} + +impl<'a> From>> for Error { + fn from(e: nom::Err>) -> Self { + use nom::error::Error as NomError; + use nom::Err::*; + + Self::ParseError(match e { + Incomplete(i) => Incomplete(i), + Error(NomError { input, code }) => Error(NomError { + input: input.to_owned(), + code, + }), + Failure(NomError { input, code }) => Failure(NomError { + input: input.to_owned(), + code, + }), + }) + } +} + +pub type Result = result::Result; diff --git a/users/grfn/achilles/src/common/mod.rs b/users/grfn/achilles/src/common/mod.rs new file mode 100644 index 000000000000..8368a6dd180f --- /dev/null +++ b/users/grfn/achilles/src/common/mod.rs @@ -0,0 +1,6 @@ +pub(crate) mod env; +pub(crate) mod error; +pub(crate) mod namer; + +pub use error::{Error, Result}; +pub use namer::{Namer, NamerOf}; diff --git a/users/grfn/achilles/src/common/namer.rs b/users/grfn/achilles/src/common/namer.rs new file mode 100644 index 000000000000..016e9f6ed99a --- /dev/null +++ b/users/grfn/achilles/src/common/namer.rs @@ -0,0 +1,122 @@ +use std::fmt::Display; +use std::marker::PhantomData; + +pub struct Namer { + make_name: F, + counter: u64, + _phantom: PhantomData, +} + +impl Namer { + pub fn new(make_name: F) -> Self { + Namer { + make_name, + counter: 0, + _phantom: PhantomData, + } + } +} + +impl Namer String>> { + pub fn with_prefix(prefix: T) -> Self + where + T: Display + 'static, + { + Namer::new(move |i| format!("{}{}", prefix, i)).boxed() + } + + pub fn with_suffix(suffix: T) -> Self + where + T: Display + 'static, + { + Namer::new(move |i| format!("{}{}", i, suffix)).boxed() + } + + pub fn alphabetic() -> Self { + Namer::new(|i| { + if i <= 26 { + std::char::from_u32((i + 96) as u32).unwrap().to_string() + } else { + format!( + "{}{}", + std::char::from_u32(((i % 26) + 96) as u32).unwrap(), + i - 26 + ) + } + }) + .boxed() + } +} + +impl Namer +where + F: Fn(u64) -> T, +{ + pub fn make_name(&mut self) -> T { + self.counter += 1; + (self.make_name)(self.counter) + } + + pub fn boxed(self) -> NamerOf + where + F: 'static, + { + Namer { + make_name: Box::new(self.make_name), + counter: self.counter, + _phantom: self._phantom, + } + } + + pub fn map(self, f: G) -> NamerOf + where + G: Fn(T) -> U + 'static, + T: 'static, + F: 'static, + { + Namer { + counter: self.counter, + make_name: Box::new(move |x| f((self.make_name)(x))), + _phantom: PhantomData, + } + } +} + +pub type NamerOf = Namer T>>; + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn prefix() { + let mut namer = Namer::with_prefix("t"); + assert_eq!(namer.make_name(), "t1"); + assert_eq!(namer.make_name(), "t2"); + } + + #[test] + fn suffix() { + let mut namer = Namer::with_suffix("t"); + assert_eq!(namer.make_name(), "1t"); + assert_eq!(namer.make_name(), "2t"); + } + + #[test] + fn alphabetic() { + let mut namer = Namer::alphabetic(); + assert_eq!(namer.make_name(), "a"); + assert_eq!(namer.make_name(), "b"); + (0..25).for_each(|_| { + namer.make_name(); + }); + assert_eq!(namer.make_name(), "b2"); + } + + #[test] + fn custom_callback() { + let mut namer = Namer::new(|n| n + 1); + assert_eq!(namer.make_name(), 2); + assert_eq!(namer.make_name(), 3); + } +} diff --git a/users/grfn/achilles/src/compiler.rs b/users/grfn/achilles/src/compiler.rs new file mode 100644 index 000000000000..45b215473d7f --- /dev/null +++ b/users/grfn/achilles/src/compiler.rs @@ -0,0 +1,89 @@ +use std::fmt::{self, Display}; +use std::path::Path; +use std::str::FromStr; +use std::{fs, result}; + +use clap::Clap; +use test_strategy::Arbitrary; + +use crate::codegen::{self, Codegen}; +use crate::common::Result; +use crate::passes::hir::{monomorphize, strip_positive_units}; +use crate::{parser, tc}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Arbitrary)] +pub enum OutputFormat { + LLVM, + Bitcode, +} + +impl Default for OutputFormat { + fn default() -> Self { + Self::Bitcode + } +} + +impl FromStr for OutputFormat { + type Err = String; + + fn from_str(s: &str) -> result::Result { + match s { + "llvm" => Ok(Self::LLVM), + "binary" => Ok(Self::Bitcode), + _ => Err(format!( + "Invalid output format {}, expected one of {{llvm, binary}}", + s + )), + } + } +} + +impl Display for OutputFormat { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + OutputFormat::LLVM => f.write_str("llvm"), + OutputFormat::Bitcode => f.write_str("binary"), + } + } +} + +#[derive(Clap, Debug, PartialEq, Eq, Default)] +pub struct CompilerOptions { + #[clap(long, short = 'f', default_value)] + format: OutputFormat, +} + +pub fn compile_file(input: &Path, output: &Path, options: &CompilerOptions) -> Result<()> { + let src = fs::read_to_string(input)?; + let (_, decls) = parser::toplevel(&src)?; + let mut decls = tc::typecheck_toplevel(decls)?; + monomorphize::run_toplevel(&mut decls); + strip_positive_units::run_toplevel(&mut decls); + + let context = codegen::Context::create(); + let mut codegen = Codegen::new( + &context, + &input + .file_stem() + .map_or("UNKNOWN".to_owned(), |s| s.to_string_lossy().into_owned()), + ); + for decl in &decls { + codegen.codegen_decl(decl)?; + } + match options.format { + OutputFormat::LLVM => codegen.print_to_file(output)?, + OutputFormat::Bitcode => codegen.binary_to_file(output)?, + } + Ok(()) +} + +#[cfg(test)] +mod tests { + use super::*; + use test_strategy::proptest; + + #[proptest] + fn output_format_display_from_str_round_trip(of: OutputFormat) { + assert_eq!(OutputFormat::from_str(&of.to_string()), Ok(of)); + } +} diff --git a/users/grfn/achilles/src/interpreter/error.rs b/users/grfn/achilles/src/interpreter/error.rs new file mode 100644 index 000000000000..268d6f479a1e --- /dev/null +++ b/users/grfn/achilles/src/interpreter/error.rs @@ -0,0 +1,19 @@ +use std::result; + +use thiserror::Error; + +use crate::ast::{Ident, Type}; + +#[derive(Debug, PartialEq, Eq, Error)] +pub enum Error { + #[error("Undefined variable {0}")] + UndefinedVariable(Ident<'static>), + + #[error("Unexpected type {actual}, expected type {expected}")] + InvalidType { + actual: Type<'static>, + expected: Type<'static>, + }, +} + +pub type Result = result::Result; diff --git a/users/grfn/achilles/src/interpreter/mod.rs b/users/grfn/achilles/src/interpreter/mod.rs new file mode 100644 index 000000000000..a8ba2dd3acdc --- /dev/null +++ b/users/grfn/achilles/src/interpreter/mod.rs @@ -0,0 +1,182 @@ +mod error; +mod value; + +pub use self::error::{Error, Result}; +pub use self::value::{Function, Value}; +use crate::ast::hir::{Binding, Expr}; +use crate::ast::{BinaryOperator, FunctionType, Ident, Literal, Type, UnaryOperator}; +use crate::common::env::Env; + +#[derive(Debug, Default)] +pub struct Interpreter<'a> { + env: Env<&'a Ident<'a>, Value<'a>>, +} + +impl<'a> Interpreter<'a> { + pub fn new() -> Self { + Self::default() + } + + fn resolve(&self, var: &'a Ident<'a>) -> Result> { + self.env + .resolve(var) + .cloned() + .ok_or_else(|| Error::UndefinedVariable(var.to_owned())) + } + + pub fn eval(&mut self, expr: &'a Expr<'a, Type>) -> Result> { + let res = match expr { + Expr::Ident(id, _) => self.resolve(id), + Expr::Literal(Literal::Int(i), _) => Ok((*i).into()), + Expr::Literal(Literal::Bool(b), _) => Ok((*b).into()), + Expr::Literal(Literal::String(s), _) => Ok(s.clone().into()), + Expr::Literal(Literal::Unit, _) => unreachable!(), + Expr::UnaryOp { op, rhs, .. } => { + let rhs = self.eval(rhs)?; + match op { + UnaryOperator::Neg => -rhs, + _ => unimplemented!(), + } + } + Expr::BinaryOp { lhs, op, rhs, .. } => { + let lhs = self.eval(lhs)?; + let rhs = self.eval(rhs)?; + match op { + BinaryOperator::Add => lhs + rhs, + BinaryOperator::Sub => lhs - rhs, + BinaryOperator::Mul => lhs * rhs, + BinaryOperator::Div => lhs / rhs, + BinaryOperator::Pow => todo!(), + BinaryOperator::Equ => Ok(lhs.eq(&rhs).into()), + BinaryOperator::Neq => todo!(), + } + } + Expr::Let { bindings, body, .. } => { + self.env.push(); + for Binding { ident, body, .. } in bindings { + let val = self.eval(body)?; + self.env.set(ident, val); + } + let res = self.eval(body)?; + self.env.pop(); + Ok(res) + } + Expr::If { + condition, + then, + else_, + .. + } => { + let condition = self.eval(condition)?; + if *(condition.as_type::()?) { + self.eval(then) + } else { + self.eval(else_) + } + } + Expr::Call { ref fun, args, .. } => { + let fun = self.eval(fun)?; + let expected_type = FunctionType { + args: args.iter().map(|_| Type::Int).collect(), + ret: Box::new(Type::Int), + }; + + let Function { + args: function_args, + body, + .. + } = fun.as_function(expected_type)?; + let arg_values = function_args.iter().zip( + args.iter() + .map(|v| self.eval(v)) + .collect::>>()?, + ); + let mut interpreter = Interpreter::new(); + for (arg_name, arg_value) in arg_values { + interpreter.env.set(arg_name, arg_value); + } + Ok(Value::from(*interpreter.eval(body)?.as_type::()?)) + } + Expr::Fun { + type_args: _, + args, + body, + type_, + } => { + let type_ = match type_ { + Type::Function(ft) => ft.clone(), + _ => unreachable!("Function expression without function type"), + }; + + Ok(Value::from(value::Function { + // TODO + type_, + args: args.iter().map(|(arg, _)| arg.to_owned()).collect(), + body: (**body).to_owned(), + })) + } + }?; + debug_assert_eq!(&res.type_(), expr.type_()); + Ok(res) + } +} + +pub fn eval<'a>(expr: &'a Expr<'a, Type>) -> Result> { + let mut interpreter = Interpreter::new(); + interpreter.eval(expr) +} + +#[cfg(test)] +mod tests { + use std::convert::TryFrom; + + use super::value::{TypeOf, Val}; + use super::*; + use BinaryOperator::*; + + fn int_lit(i: u64) -> Box>> { + Box::new(Expr::Literal(Literal::Int(i), Type::Int)) + } + + fn do_eval(src: &str) -> T + where + for<'a> &'a T: TryFrom<&'a Val<'a>>, + T: Clone + TypeOf, + { + let expr = crate::parser::expr(src).unwrap().1; + let hir = crate::tc::typecheck_expr(expr).unwrap(); + let res = eval(&hir).unwrap(); + res.as_type::().unwrap().clone() + } + + #[test] + fn simple_addition() { + let expr = Expr::BinaryOp { + lhs: int_lit(1), + op: Mul, + rhs: int_lit(2), + type_: Type::Int, + }; + let res = eval(&expr).unwrap(); + assert_eq!(*res.as_type::().unwrap(), 2); + } + + #[test] + fn variable_shadowing() { + let res = do_eval::("let x = 1 in (let x = 2 in x) + x"); + assert_eq!(res, 3); + } + + #[test] + fn conditional_with_equals() { + let res = do_eval::("let x = 1 in if x == 1 then 2 else 4"); + assert_eq!(res, 2); + } + + #[test] + #[ignore] + fn function_call() { + let res = do_eval::("let id = fn x = x in id 1"); + assert_eq!(res, 1); + } +} diff --git a/users/grfn/achilles/src/interpreter/value.rs b/users/grfn/achilles/src/interpreter/value.rs new file mode 100644 index 000000000000..55ba42f9de58 --- /dev/null +++ b/users/grfn/achilles/src/interpreter/value.rs @@ -0,0 +1,203 @@ +use std::borrow::Cow; +use std::convert::TryFrom; +use std::fmt::{self, Display}; +use std::ops::{Add, Div, Mul, Neg, Sub}; +use std::rc::Rc; +use std::result; + +use derive_more::{Deref, From, TryInto}; + +use super::{Error, Result}; +use crate::ast::hir::Expr; +use crate::ast::{FunctionType, Ident, Type}; + +#[derive(Debug, Clone)] +pub struct Function<'a> { + pub type_: FunctionType<'a>, + pub args: Vec>, + pub body: Expr<'a, Type<'a>>, +} + +#[derive(From, TryInto)] +#[try_into(owned, ref)] +pub enum Val<'a> { + Int(i64), + Float(f64), + Bool(bool), + String(Cow<'a, str>), + Function(Function<'a>), +} + +impl<'a> TryFrom> for String { + type Error = (); + + fn try_from(value: Val<'a>) -> result::Result { + match value { + Val::String(s) => Ok(s.into_owned()), + _ => Err(()), + } + } +} + +impl<'a> fmt::Debug for Val<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Val::Int(x) => f.debug_tuple("Int").field(x).finish(), + Val::Float(x) => f.debug_tuple("Float").field(x).finish(), + Val::Bool(x) => f.debug_tuple("Bool").field(x).finish(), + Val::String(s) => f.debug_tuple("String").field(s).finish(), + Val::Function(Function { type_, .. }) => { + f.debug_struct("Function").field("type_", type_).finish() + } + } + } +} + +impl<'a> PartialEq for Val<'a> { + fn eq(&self, other: &Self) -> bool { + match (self, other) { + (Val::Int(x), Val::Int(y)) => x == y, + (Val::Float(x), Val::Float(y)) => x == y, + (Val::Bool(x), Val::Bool(y)) => x == y, + (Val::Function(_), Val::Function(_)) => false, + (_, _) => false, + } + } +} + +impl<'a> From for Val<'a> { + fn from(i: u64) -> Self { + Self::from(i as i64) + } +} + +impl<'a> Display for Val<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Val::Int(x) => x.fmt(f), + Val::Float(x) => x.fmt(f), + Val::Bool(x) => x.fmt(f), + Val::String(s) => write!(f, "{:?}", s), + Val::Function(Function { type_, .. }) => write!(f, "<{}>", type_), + } + } +} + +impl<'a> Val<'a> { + pub fn type_(&self) -> Type { + match self { + Val::Int(_) => Type::Int, + Val::Float(_) => Type::Float, + Val::Bool(_) => Type::Bool, + Val::String(_) => Type::CString, + Val::Function(Function { type_, .. }) => Type::Function(type_.clone()), + } + } + + pub fn as_type<'b, T>(&'b self) -> Result<&'b T> + where + T: TypeOf + 'b + Clone, + &'b T: TryFrom<&'b Self>, + { + <&T>::try_from(self).map_err(|_| Error::InvalidType { + actual: self.type_().to_owned(), + expected: ::type_of(), + }) + } + + pub fn as_function<'b>(&'b self, function_type: FunctionType) -> Result<&'b Function<'a>> { + match self { + Val::Function(f) if f.type_ == function_type => Ok(&f), + _ => Err(Error::InvalidType { + actual: self.type_().to_owned(), + expected: Type::Function(function_type.to_owned()), + }), + } + } +} + +#[derive(Debug, PartialEq, Clone, Deref)] +pub struct Value<'a>(Rc>); + +impl<'a> Display for Value<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + self.0.fmt(f) + } +} + +impl<'a, T> From for Value<'a> +where + Val<'a>: From, +{ + fn from(x: T) -> Self { + Self(Rc::new(x.into())) + } +} + +impl<'a> Neg for Value<'a> { + type Output = Result>; + + fn neg(self) -> Self::Output { + Ok((-self.as_type::()?).into()) + } +} + +impl<'a> Add for Value<'a> { + type Output = Result>; + + fn add(self, rhs: Self) -> Self::Output { + Ok((self.as_type::()? + rhs.as_type::()?).into()) + } +} + +impl<'a> Sub for Value<'a> { + type Output = Result>; + + fn sub(self, rhs: Self) -> Self::Output { + Ok((self.as_type::()? - rhs.as_type::()?).into()) + } +} + +impl<'a> Mul for Value<'a> { + type Output = Result>; + + fn mul(self, rhs: Self) -> Self::Output { + Ok((self.as_type::()? * rhs.as_type::()?).into()) + } +} + +impl<'a> Div for Value<'a> { + type Output = Result>; + + fn div(self, rhs: Self) -> Self::Output { + Ok((self.as_type::()? / rhs.as_type::()?).into()) + } +} + +pub trait TypeOf { + fn type_of() -> Type<'static>; +} + +impl TypeOf for i64 { + fn type_of() -> Type<'static> { + Type::Int + } +} + +impl TypeOf for bool { + fn type_of() -> Type<'static> { + Type::Bool + } +} + +impl TypeOf for f64 { + fn type_of() -> Type<'static> { + Type::Float + } +} + +impl TypeOf for String { + fn type_of() -> Type<'static> { + Type::CString + } +} diff --git a/users/grfn/achilles/src/main.rs b/users/grfn/achilles/src/main.rs new file mode 100644 index 000000000000..5ae1b59b3a8e --- /dev/null +++ b/users/grfn/achilles/src/main.rs @@ -0,0 +1,36 @@ +use clap::Clap; + +pub mod ast; +pub mod codegen; +pub(crate) mod commands; +pub(crate) mod common; +pub mod compiler; +pub mod interpreter; +pub(crate) mod passes; +#[macro_use] +pub mod parser; +pub mod tc; + +pub use common::{Error, Result}; + +#[derive(Clap)] +struct Opts { + #[clap(subcommand)] + subcommand: Command, +} + +#[derive(Clap)] +enum Command { + Eval(commands::Eval), + Compile(commands::Compile), + Check(commands::Check), +} + +fn main() -> anyhow::Result<()> { + let opts = Opts::parse(); + match opts.subcommand { + Command::Eval(eval) => Ok(eval.run()?), + Command::Compile(compile) => Ok(compile.run()?), + Command::Check(check) => Ok(check.run()?), + } +} diff --git a/users/grfn/achilles/src/parser/expr.rs b/users/grfn/achilles/src/parser/expr.rs new file mode 100644 index 000000000000..8a28d00984c9 --- /dev/null +++ b/users/grfn/achilles/src/parser/expr.rs @@ -0,0 +1,645 @@ +use std::borrow::Cow; + +use nom::alt; +use nom::character::complete::{digit1, multispace0, multispace1}; +use nom::{ + call, char, complete, delimited, do_parse, flat_map, many0, map, named, opt, parse_to, + preceded, separated_list0, separated_list1, tag, tuple, +}; +use pratt::{Affix, Associativity, PrattParser, Precedence}; + +use crate::ast::{BinaryOperator, Binding, Expr, Fun, Literal, UnaryOperator}; +use crate::parser::{arg, ident, type_}; + +#[derive(Debug)] +enum TokenTree<'a> { + Prefix(UnaryOperator), + // Postfix(char), + Infix(BinaryOperator), + Primary(Expr<'a>), + Group(Vec>), +} + +named!(prefix(&str) -> TokenTree, map!(alt!( + complete!(char!('-')) => { |_| UnaryOperator::Neg } | + complete!(char!('!')) => { |_| UnaryOperator::Not } +), TokenTree::Prefix)); + +named!(infix(&str) -> TokenTree, map!(alt!( + complete!(tag!("==")) => { |_| BinaryOperator::Equ } | + complete!(tag!("!=")) => { |_| BinaryOperator::Neq } | + complete!(char!('+')) => { |_| BinaryOperator::Add } | + complete!(char!('-')) => { |_| BinaryOperator::Sub } | + complete!(char!('*')) => { |_| BinaryOperator::Mul } | + complete!(char!('/')) => { |_| BinaryOperator::Div } | + complete!(char!('^')) => { |_| BinaryOperator::Pow } +), TokenTree::Infix)); + +named!(primary(&str) -> TokenTree, alt!( + do_parse!( + multispace0 >> + char!('(') >> + multispace0 >> + group: group >> + multispace0 >> + char!(')') >> + multispace0 >> + (TokenTree::Group(group)) + ) | + delimited!(multispace0, simple_expr, multispace0) => { |s| TokenTree::Primary(s) } +)); + +named!( + rest(&str) -> Vec<(TokenTree, Vec, TokenTree)>, + many0!(tuple!( + infix, + delimited!(multispace0, many0!(prefix), multispace0), + primary + // many0!(postfix) + )) +); + +named!(group(&str) -> Vec, do_parse!( + prefix: many0!(prefix) + >> primary: primary + // >> postfix: many0!(postfix) + >> rest: rest + >> ({ + let mut res = prefix; + res.push(primary); + // res.append(&mut postfix); + for (infix, mut prefix, primary/*, mut postfix*/) in rest { + res.push(infix); + res.append(&mut prefix); + res.push(primary); + // res.append(&mut postfix); + } + res + }) +)); + +fn token_tree(i: &str) -> nom::IResult<&str, Vec> { + group(i) +} + +struct ExprParser; + +impl<'a, I> PrattParser for ExprParser +where + I: Iterator>, +{ + type Error = pratt::NoError; + type Input = TokenTree<'a>; + type Output = Expr<'a>; + + fn query(&mut self, input: &Self::Input) -> Result { + use BinaryOperator::*; + use UnaryOperator::*; + + Ok(match input { + TokenTree::Infix(Add) => Affix::Infix(Precedence(6), Associativity::Left), + TokenTree::Infix(Sub) => Affix::Infix(Precedence(6), Associativity::Left), + TokenTree::Infix(Mul) => Affix::Infix(Precedence(7), Associativity::Left), + TokenTree::Infix(Div) => Affix::Infix(Precedence(7), Associativity::Left), + TokenTree::Infix(Pow) => Affix::Infix(Precedence(8), Associativity::Right), + TokenTree::Infix(Equ) => Affix::Infix(Precedence(4), Associativity::Right), + TokenTree::Infix(Neq) => Affix::Infix(Precedence(4), Associativity::Right), + TokenTree::Prefix(Neg) => Affix::Prefix(Precedence(6)), + TokenTree::Prefix(Not) => Affix::Prefix(Precedence(6)), + TokenTree::Primary(_) => Affix::Nilfix, + TokenTree::Group(_) => Affix::Nilfix, + }) + } + + fn primary(&mut self, input: Self::Input) -> Result { + Ok(match input { + TokenTree::Primary(expr) => expr, + TokenTree::Group(group) => self.parse(&mut group.into_iter()).unwrap(), + _ => unreachable!(), + }) + } + + fn infix( + &mut self, + lhs: Self::Output, + op: Self::Input, + rhs: Self::Output, + ) -> Result { + let op = match op { + TokenTree::Infix(op) => op, + _ => unreachable!(), + }; + Ok(Expr::BinaryOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + }) + } + + fn prefix(&mut self, op: Self::Input, rhs: Self::Output) -> Result { + let op = match op { + TokenTree::Prefix(op) => op, + _ => unreachable!(), + }; + + Ok(Expr::UnaryOp { + op, + rhs: Box::new(rhs), + }) + } + + fn postfix( + &mut self, + _lhs: Self::Output, + _op: Self::Input, + ) -> Result { + unreachable!() + } +} + +named!(int(&str) -> Literal, map!(flat_map!(digit1, parse_to!(u64)), Literal::Int)); + +named!(bool_(&str) -> Literal, alt!( + complete!(tag!("true")) => { |_| Literal::Bool(true) } | + complete!(tag!("false")) => { |_| Literal::Bool(false) } +)); + +fn string_internal(i: &str) -> nom::IResult<&str, Cow<'_, str>, nom::error::Error<&str>> { + // TODO(grfn): use String::split_once when that's stable + let (s, rem) = if let Some(pos) = i.find('"') { + (&i[..pos], &i[(pos + 1)..]) + } else { + return Err(nom::Err::Error(nom::error::Error::new( + i, + nom::error::ErrorKind::Tag, + ))); + }; + + Ok((rem, Cow::Borrowed(s))) +} + +named!(string(&str) -> Literal, preceded!( + complete!(char!('"')), + map!( + string_internal, + |s| Literal::String(s) + ) +)); + +named!(unit(&str) -> Literal, map!(complete!(tag!("()")), |_| Literal::Unit)); + +named!(literal(&str) -> Literal, alt!(int | bool_ | string | unit)); + +named!(literal_expr(&str) -> Expr, map!(literal, Expr::Literal)); + +named!(binding(&str) -> Binding, do_parse!( + multispace0 + >> ident: ident + >> multispace0 + >> type_: opt!(preceded!(tuple!(tag!(":"), multispace0), type_)) + >> multispace0 + >> char!('=') + >> multispace0 + >> body: expr + >> (Binding { + ident, + type_, + body + }) +)); + +named!(let_(&str) -> Expr, do_parse!( + tag!("let") + >> multispace0 + >> bindings: separated_list1!(alt!(char!(';') | char!('\n')), binding) + >> multispace0 + >> tag!("in") + >> multispace0 + >> body: expr + >> (Expr::Let { + bindings, + body: Box::new(body) + }) +)); + +named!(if_(&str) -> Expr, do_parse! ( + tag!("if") + >> multispace0 + >> condition: expr + >> multispace0 + >> tag!("then") + >> multispace0 + >> then: expr + >> multispace0 + >> tag!("else") + >> multispace0 + >> else_: expr + >> (Expr::If { + condition: Box::new(condition), + then: Box::new(then), + else_: Box::new(else_) + }) +)); + +named!(ident_expr(&str) -> Expr, map!(ident, Expr::Ident)); + +fn ascripted<'a>( + p: impl Fn(&'a str) -> nom::IResult<&'a str, Expr, nom::error::Error<&'a str>> + 'a, +) -> impl Fn(&'a str) -> nom::IResult<&str, Expr, nom::error::Error<&'a str>> { + move |i| { + do_parse!( + i, + expr: p + >> multispace0 + >> complete!(tag!(":")) + >> multispace0 + >> type_: type_ + >> (Expr::Ascription { + expr: Box::new(expr), + type_ + }) + ) + } +} + +named!(paren_expr(&str) -> Expr, + delimited!(complete!(tag!("(")), expr, complete!(tag!(")")))); + +named!(funcref(&str) -> Expr, alt!( + ident_expr | + paren_expr +)); + +named!(no_arg_call(&str) -> Expr, do_parse!( + fun: funcref + >> complete!(tag!("()")) + >> (Expr::Call { + fun: Box::new(fun), + args: vec![], + }) +)); + +named!(fun_expr(&str) -> Expr, do_parse!( + tag!("fn") + >> multispace1 + >> args: separated_list0!(multispace1, arg) + >> multispace0 + >> char!('=') + >> multispace0 + >> body: expr + >> (Expr::Fun(Box::new(Fun { + args, + body + }))) +)); + +named!(fn_arg(&str) -> Expr, alt!( + ident_expr | + literal_expr | + paren_expr +)); + +named!(call_with_args(&str) -> Expr, do_parse!( + fun: funcref + >> multispace1 + >> args: separated_list1!(multispace1, fn_arg) + >> (Expr::Call { + fun: Box::new(fun), + args + }) +)); + +named!(simple_expr_unascripted(&str) -> Expr, alt!( + let_ | + if_ | + fun_expr | + literal_expr | + ident_expr +)); + +named!(simple_expr(&str) -> Expr, alt!( + call!(ascripted(simple_expr_unascripted)) | + simple_expr_unascripted +)); + +named!(pub expr(&str) -> Expr, alt!( + no_arg_call | + call_with_args | + map!(token_tree, |tt| { + ExprParser.parse(&mut tt.into_iter()).unwrap() + }) | + simple_expr +)); + +#[cfg(test)] +pub(crate) mod tests { + use super::*; + use crate::ast::{Arg, Ident, Type}; + use std::convert::TryFrom; + use BinaryOperator::*; + use Expr::{BinaryOp, If, Let, UnaryOp}; + use UnaryOperator::*; + + pub(crate) fn ident_expr(s: &str) -> Box { + Box::new(Expr::Ident(Ident::try_from(s).unwrap())) + } + + mod operators { + use super::*; + + #[test] + fn mul_plus() { + let (rem, res) = expr("x*y+z").unwrap(); + assert!(rem.is_empty()); + assert_eq!( + res, + BinaryOp { + lhs: Box::new(BinaryOp { + lhs: ident_expr("x"), + op: Mul, + rhs: ident_expr("y") + }), + op: Add, + rhs: ident_expr("z") + } + ) + } + + #[test] + fn mul_plus_ws() { + let (rem, res) = expr("x * y + z").unwrap(); + assert!(rem.is_empty(), "non-empty remainder: \"{}\"", rem); + assert_eq!( + res, + BinaryOp { + lhs: Box::new(BinaryOp { + lhs: ident_expr("x"), + op: Mul, + rhs: ident_expr("y") + }), + op: Add, + rhs: ident_expr("z") + } + ) + } + + #[test] + fn unary() { + let (rem, res) = expr("x * -z").unwrap(); + assert!(rem.is_empty(), "non-empty remainder: \"{}\"", rem); + assert_eq!( + res, + BinaryOp { + lhs: ident_expr("x"), + op: Mul, + rhs: Box::new(UnaryOp { + op: Neg, + rhs: ident_expr("z"), + }) + } + ) + } + + #[test] + fn mul_literal() { + let (rem, res) = expr("x * 3").unwrap(); + assert!(rem.is_empty()); + assert_eq!( + res, + BinaryOp { + lhs: ident_expr("x"), + op: Mul, + rhs: Box::new(Expr::Literal(Literal::Int(3))), + } + ) + } + + #[test] + fn equ() { + let res = test_parse!(expr, "x * 7 == 7"); + assert_eq!( + res, + BinaryOp { + lhs: Box::new(BinaryOp { + lhs: ident_expr("x"), + op: Mul, + rhs: Box::new(Expr::Literal(Literal::Int(7))) + }), + op: Equ, + rhs: Box::new(Expr::Literal(Literal::Int(7))) + } + ) + } + } + + #[test] + fn unit() { + assert_eq!(test_parse!(expr, "()"), Expr::Literal(Literal::Unit)); + } + + #[test] + fn bools() { + assert_eq!( + test_parse!(expr, "true"), + Expr::Literal(Literal::Bool(true)) + ); + assert_eq!( + test_parse!(expr, "false"), + Expr::Literal(Literal::Bool(false)) + ); + } + + #[test] + fn simple_string_lit() { + assert_eq!( + test_parse!(expr, "\"foobar\""), + Expr::Literal(Literal::String(Cow::Borrowed("foobar"))) + ) + } + + #[test] + fn let_complex() { + let res = test_parse!(expr, "let x = 1; y = x * 7 in (x + y) * 4"); + assert_eq!( + res, + Let { + bindings: vec![ + Binding { + ident: Ident::try_from("x").unwrap(), + type_: None, + body: Expr::Literal(Literal::Int(1)) + }, + Binding { + ident: Ident::try_from("y").unwrap(), + type_: None, + body: Expr::BinaryOp { + lhs: ident_expr("x"), + op: Mul, + rhs: Box::new(Expr::Literal(Literal::Int(7))) + } + } + ], + body: Box::new(Expr::BinaryOp { + lhs: Box::new(Expr::BinaryOp { + lhs: ident_expr("x"), + op: Add, + rhs: ident_expr("y"), + }), + op: Mul, + rhs: Box::new(Expr::Literal(Literal::Int(4))), + }) + } + ) + } + + #[test] + fn if_simple() { + let res = test_parse!(expr, "if x == 8 then 9 else 20"); + assert_eq!( + res, + If { + condition: Box::new(BinaryOp { + lhs: ident_expr("x"), + op: Equ, + rhs: Box::new(Expr::Literal(Literal::Int(8))), + }), + then: Box::new(Expr::Literal(Literal::Int(9))), + else_: Box::new(Expr::Literal(Literal::Int(20))) + } + ) + } + + #[test] + fn no_arg_call() { + let res = test_parse!(expr, "f()"); + assert_eq!( + res, + Expr::Call { + fun: ident_expr("f"), + args: vec![] + } + ); + } + + #[test] + fn unit_call() { + let res = test_parse!(expr, "f ()"); + assert_eq!( + res, + Expr::Call { + fun: ident_expr("f"), + args: vec![Expr::Literal(Literal::Unit)] + } + ) + } + + #[test] + fn call_with_args() { + let res = test_parse!(expr, "f x 1"); + assert_eq!( + res, + Expr::Call { + fun: ident_expr("f"), + args: vec![*ident_expr("x"), Expr::Literal(Literal::Int(1))] + } + ) + } + + #[test] + fn call_funcref() { + let res = test_parse!(expr, "(let x = 1 in x) 2"); + assert_eq!( + res, + Expr::Call { + fun: Box::new(Expr::Let { + bindings: vec![Binding { + ident: Ident::try_from("x").unwrap(), + type_: None, + body: Expr::Literal(Literal::Int(1)) + }], + body: ident_expr("x") + }), + args: vec![Expr::Literal(Literal::Int(2))] + } + ) + } + + #[test] + fn anon_function() { + let res = test_parse!(expr, "let id = fn x = x in id 1"); + assert_eq!( + res, + Expr::Let { + bindings: vec![Binding { + ident: Ident::try_from("id").unwrap(), + type_: None, + body: Expr::Fun(Box::new(Fun { + args: vec![Arg::try_from("x").unwrap()], + body: *ident_expr("x") + })) + }], + body: Box::new(Expr::Call { + fun: ident_expr("id"), + args: vec![Expr::Literal(Literal::Int(1))], + }) + } + ); + } + + mod ascriptions { + use super::*; + + #[test] + fn bare_ascription() { + let res = test_parse!(expr, "1: float"); + assert_eq!( + res, + Expr::Ascription { + expr: Box::new(Expr::Literal(Literal::Int(1))), + type_: Type::Float + } + ) + } + + #[test] + fn fn_body_ascription() { + let res = test_parse!(expr, "let const_1 = fn x = 1: int in const_1 2"); + assert_eq!( + res, + Expr::Let { + bindings: vec![Binding { + ident: Ident::try_from("const_1").unwrap(), + type_: None, + body: Expr::Fun(Box::new(Fun { + args: vec![Arg::try_from("x").unwrap()], + body: Expr::Ascription { + expr: Box::new(Expr::Literal(Literal::Int(1))), + type_: Type::Int, + } + })) + }], + body: Box::new(Expr::Call { + fun: ident_expr("const_1"), + args: vec![Expr::Literal(Literal::Int(2))] + }) + } + ) + } + + #[test] + fn let_binding_ascripted() { + let res = test_parse!(expr, "let x: int = 1 in x"); + assert_eq!( + res, + Expr::Let { + bindings: vec![Binding { + ident: Ident::try_from("x").unwrap(), + type_: Some(Type::Int), + body: Expr::Literal(Literal::Int(1)) + }], + body: ident_expr("x") + } + ) + } + } +} diff --git a/users/grfn/achilles/src/parser/macros.rs b/users/grfn/achilles/src/parser/macros.rs new file mode 100644 index 000000000000..406e5c0e699e --- /dev/null +++ b/users/grfn/achilles/src/parser/macros.rs @@ -0,0 +1,16 @@ +#[cfg(test)] +#[macro_use] +macro_rules! test_parse { + ($parser: ident, $src: expr) => {{ + let res = $parser($src); + nom_trace::print_trace!(); + let (rem, res) = res.unwrap(); + assert!( + rem.is_empty(), + "non-empty remainder: \"{}\", parsed: {:?}", + rem, + res + ); + res + }}; +} diff --git a/users/grfn/achilles/src/parser/mod.rs b/users/grfn/achilles/src/parser/mod.rs new file mode 100644 index 000000000000..3e0081bd391d --- /dev/null +++ b/users/grfn/achilles/src/parser/mod.rs @@ -0,0 +1,239 @@ +use nom::character::complete::{multispace0, multispace1}; +use nom::error::{ErrorKind, ParseError}; +use nom::{alt, char, complete, do_parse, eof, many0, named, separated_list0, tag, terminated}; + +#[macro_use] +pub(crate) mod macros; +mod expr; +mod type_; + +use crate::ast::{Arg, Decl, Fun, Ident}; +pub use expr::expr; +use type_::function_type; +pub use type_::type_; + +pub type Error = nom::Err>; + +pub(crate) fn is_reserved(s: &str) -> bool { + matches!( + s, + "if" | "then" + | "else" + | "let" + | "in" + | "fn" + | "ty" + | "int" + | "float" + | "bool" + | "true" + | "false" + | "cstring" + ) +} + +pub(crate) fn ident<'a, E>(i: &'a str) -> nom::IResult<&'a str, Ident, E> +where + E: ParseError<&'a str>, +{ + let mut chars = i.chars(); + if let Some(f) = chars.next() { + if f.is_alphabetic() || f == '_' { + let mut idx = 1; + for c in chars { + if !(c.is_alphanumeric() || c == '_') { + break; + } + idx += 1; + } + let id = &i[..idx]; + if is_reserved(id) { + Err(nom::Err::Error(E::from_error_kind(i, ErrorKind::Satisfy))) + } else { + Ok((&i[idx..], Ident::from_str_unchecked(id))) + } + } else { + Err(nom::Err::Error(E::from_error_kind(i, ErrorKind::Satisfy))) + } + } else { + Err(nom::Err::Error(E::from_error_kind(i, ErrorKind::Eof))) + } +} + +named!(ascripted_arg(&str) -> Arg, do_parse!( + complete!(char!('(')) >> + multispace0 >> + ident: ident >> + multispace0 >> + complete!(char!(':')) >> + multispace0 >> + type_: type_ >> + multispace0 >> + complete!(char!(')')) >> + (Arg { + ident, + type_: Some(type_) + }) +)); + +named!(arg(&str) -> Arg, alt!( + ident => { |ident| Arg {ident, type_: None}} | + ascripted_arg +)); + +named!(extern_decl(&str) -> Decl, do_parse!( + complete!(tag!("extern")) + >> multispace1 + >> name: ident + >> multispace0 + >> char!(':') + >> multispace0 + >> type_: function_type + >> multispace0 + >> (Decl::Extern { + name, + type_ + }) +)); + +named!(fun_decl(&str) -> Decl, do_parse!( + complete!(tag!("fn")) + >> multispace1 + >> name: ident + >> multispace1 + >> args: separated_list0!(multispace1, arg) + >> multispace0 + >> char!('=') + >> multispace0 + >> body: expr + >> (Decl::Fun { + name, + body: Fun { + args, + body + } + }) +)); + +named!(ascription_decl(&str) -> Decl, do_parse!( + complete!(tag!("ty")) + >> multispace1 + >> name: ident + >> multispace0 + >> complete!(char!(':')) + >> multispace0 + >> type_: type_ + >> multispace0 + >> (Decl::Ascription { + name, + type_ + }) +)); + +named!(pub decl(&str) -> Decl, alt!( + ascription_decl | + fun_decl | + extern_decl +)); + +named!(pub toplevel(&str) -> Vec, do_parse!( + decls: many0!(decl) + >> multispace0 + >> eof!() + >> (decls))); + +#[cfg(test)] +mod tests { + use std::convert::TryInto; + + use crate::ast::{BinaryOperator, Expr, FunctionType, Literal, Type}; + + use super::*; + use expr::tests::ident_expr; + + #[test] + fn fn_decl() { + let res = test_parse!(decl, "fn id x = x"); + assert_eq!( + res, + Decl::Fun { + name: "id".try_into().unwrap(), + body: Fun { + args: vec!["x".try_into().unwrap()], + body: *ident_expr("x"), + } + } + ) + } + + #[test] + fn ascripted_fn_args() { + test_parse!(ascripted_arg, "(x : int)"); + let res = test_parse!(decl, "fn plus1 (x : int) = x + 1"); + assert_eq!( + res, + Decl::Fun { + name: "plus1".try_into().unwrap(), + body: Fun { + args: vec![Arg { + ident: "x".try_into().unwrap(), + type_: Some(Type::Int), + }], + body: Expr::BinaryOp { + lhs: ident_expr("x"), + op: BinaryOperator::Add, + rhs: Box::new(Expr::Literal(Literal::Int(1))), + } + } + } + ); + } + + #[test] + fn multiple_decls() { + let res = test_parse!( + toplevel, + "fn id x = x + fn plus x y = x + y + fn main = plus (id 2) 7" + ); + assert_eq!(res.len(), 3); + let res = test_parse!( + toplevel, + "fn id x = x\nfn plus x y = x + y\nfn main = plus (id 2) 7\n" + ); + assert_eq!(res.len(), 3); + } + + #[test] + fn top_level_ascription() { + let res = test_parse!(toplevel, "ty id : fn a -> a"); + assert_eq!( + res, + vec![Decl::Ascription { + name: "id".try_into().unwrap(), + type_: Type::Function(FunctionType { + args: vec![Type::Var("a".try_into().unwrap())], + ret: Box::new(Type::Var("a".try_into().unwrap())) + }) + }] + ) + } + + #[test] + fn return_unit() { + assert_eq!( + test_parse!(decl, "fn g _ = ()"), + Decl::Fun { + name: "g".try_into().unwrap(), + body: Fun { + args: vec![Arg { + ident: "_".try_into().unwrap(), + type_: None, + }], + body: Expr::Literal(Literal::Unit), + }, + } + ) + } +} diff --git a/users/grfn/achilles/src/parser/type_.rs b/users/grfn/achilles/src/parser/type_.rs new file mode 100644 index 000000000000..8a1081e2521f --- /dev/null +++ b/users/grfn/achilles/src/parser/type_.rs @@ -0,0 +1,127 @@ +use nom::character::complete::{multispace0, multispace1}; +use nom::{alt, delimited, do_parse, map, named, opt, separated_list0, tag, terminated, tuple}; + +use super::ident; +use crate::ast::{FunctionType, Type}; + +named!(pub function_type(&str) -> FunctionType, do_parse!( + tag!("fn") + >> multispace1 + >> args: map!(opt!(terminated!(separated_list0!( + tuple!( + multispace0, + tag!(","), + multispace0 + ), + type_ + ), multispace1)), |args| args.unwrap_or_default()) + >> tag!("->") + >> multispace1 + >> ret: type_ + >> (FunctionType { + args, + ret: Box::new(ret) + }) +)); + +named!(pub type_(&str) -> Type, alt!( + tag!("int") => { |_| Type::Int } | + tag!("float") => { |_| Type::Float } | + tag!("bool") => { |_| Type::Bool } | + tag!("cstring") => { |_| Type::CString } | + tag!("()") => { |_| Type::Unit } | + function_type => { |ft| Type::Function(ft) }| + ident => { |id| Type::Var(id) } | + delimited!( + tuple!(tag!("("), multispace0), + type_, + tuple!(tag!(")"), multispace0) + ) +)); + +#[cfg(test)] +mod tests { + use std::convert::TryFrom; + + use super::*; + use crate::ast::Ident; + + #[test] + fn simple_types() { + assert_eq!(test_parse!(type_, "int"), Type::Int); + assert_eq!(test_parse!(type_, "float"), Type::Float); + assert_eq!(test_parse!(type_, "bool"), Type::Bool); + assert_eq!(test_parse!(type_, "cstring"), Type::CString); + assert_eq!(test_parse!(type_, "()"), Type::Unit); + } + + #[test] + fn no_arg_fn_type() { + assert_eq!( + test_parse!(type_, "fn -> int"), + Type::Function(FunctionType { + args: vec![], + ret: Box::new(Type::Int) + }) + ); + } + + #[test] + fn fn_type_with_args() { + assert_eq!( + test_parse!(type_, "fn int, bool -> int"), + Type::Function(FunctionType { + args: vec![Type::Int, Type::Bool], + ret: Box::new(Type::Int) + }) + ); + } + + #[test] + fn fn_taking_fn() { + assert_eq!( + test_parse!(type_, "fn fn int, bool -> bool, float -> float"), + Type::Function(FunctionType { + args: vec![ + Type::Function(FunctionType { + args: vec![Type::Int, Type::Bool], + ret: Box::new(Type::Bool) + }), + Type::Float + ], + ret: Box::new(Type::Float) + }) + ) + } + + #[test] + fn parenthesized() { + assert_eq!( + test_parse!(type_, "fn (fn int, bool -> bool), float -> float"), + Type::Function(FunctionType { + args: vec![ + Type::Function(FunctionType { + args: vec![Type::Int, Type::Bool], + ret: Box::new(Type::Bool) + }), + Type::Float + ], + ret: Box::new(Type::Float) + }) + ) + } + + #[test] + fn type_vars() { + assert_eq!( + test_parse!(type_, "fn x, y -> x"), + Type::Function(FunctionType { + args: vec![ + Type::Var(Ident::try_from("x").unwrap()), + Type::Var(Ident::try_from("y").unwrap()), + ], + ret: Box::new(Type::Var(Ident::try_from("x").unwrap())), + }) + ) + } +} diff --git a/users/grfn/achilles/src/passes/hir/mod.rs b/users/grfn/achilles/src/passes/hir/mod.rs new file mode 100644 index 000000000000..845bfcb7ab6a --- /dev/null +++ b/users/grfn/achilles/src/passes/hir/mod.rs @@ -0,0 +1,197 @@ +use std::collections::HashMap; + +use crate::ast::hir::{Binding, Decl, Expr}; +use crate::ast::{BinaryOperator, Ident, Literal, UnaryOperator}; + +pub(crate) mod monomorphize; +pub(crate) mod strip_positive_units; + +pub(crate) trait Visitor<'a, 'ast, T: 'ast>: Sized + 'a { + type Error; + + fn visit_type(&mut self, _type: &mut T) -> Result<(), Self::Error> { + Ok(()) + } + + fn visit_ident(&mut self, _ident: &mut Ident<'ast>) -> Result<(), Self::Error> { + Ok(()) + } + + fn visit_literal(&mut self, _literal: &mut Literal<'ast>) -> Result<(), Self::Error> { + Ok(()) + } + + fn visit_unary_operator(&mut self, _op: &mut UnaryOperator) -> Result<(), Self::Error> { + Ok(()) + } + + fn visit_binary_operator(&mut self, _op: &mut BinaryOperator) -> Result<(), Self::Error> { + Ok(()) + } + + fn visit_binding(&mut self, binding: &mut Binding<'ast, T>) -> Result<(), Self::Error> { + self.visit_ident(&mut binding.ident)?; + self.visit_type(&mut binding.type_)?; + self.visit_expr(&mut binding.body)?; + Ok(()) + } + + fn post_visit_call( + &mut self, + _fun: &mut Expr<'ast, T>, + _type_args: &mut HashMap, T>, + _args: &mut Vec>, + ) -> Result<(), Self::Error> { + Ok(()) + } + + fn pre_visit_call( + &mut self, + _fun: &mut Expr<'ast, T>, + _type_args: &mut HashMap, T>, + _args: &mut Vec>, + ) -> Result<(), Self::Error> { + Ok(()) + } + + fn pre_visit_expr(&mut self, _expr: &mut Expr<'ast, T>) -> Result<(), Self::Error> { + Ok(()) + } + + fn visit_expr(&mut self, expr: &mut Expr<'ast, T>) -> Result<(), Self::Error> { + self.pre_visit_expr(expr)?; + match expr { + Expr::Ident(id, t) => { + self.visit_ident(id)?; + self.visit_type(t)?; + } + Expr::Literal(lit, t) => { + self.visit_literal(lit)?; + self.visit_type(t)?; + } + Expr::UnaryOp { op, rhs, type_ } => { + self.visit_unary_operator(op)?; + self.visit_expr(rhs)?; + self.visit_type(type_)?; + } + Expr::BinaryOp { + lhs, + op, + rhs, + type_, + } => { + self.visit_expr(lhs)?; + self.visit_binary_operator(op)?; + self.visit_expr(rhs)?; + self.visit_type(type_)?; + } + Expr::Let { + bindings, + body, + type_, + } => { + for binding in bindings.iter_mut() { + self.visit_binding(binding)?; + } + self.visit_expr(body)?; + self.visit_type(type_)?; + } + Expr::If { + condition, + then, + else_, + type_, + } => { + self.visit_expr(condition)?; + self.visit_expr(then)?; + self.visit_expr(else_)?; + self.visit_type(type_)?; + } + Expr::Fun { + args, + body, + type_args, + type_, + } => { + for (ident, t) in args { + self.visit_ident(ident)?; + self.visit_type(t)?; + } + for ta in type_args { + self.visit_ident(ta)?; + } + self.visit_expr(body)?; + self.visit_type(type_)?; + } + Expr::Call { + fun, + args, + type_args, + type_, + } => { + self.pre_visit_call(fun, type_args, args)?; + self.visit_expr(fun)?; + for arg in args.iter_mut() { + self.visit_expr(arg)?; + } + self.visit_type(type_)?; + self.post_visit_call(fun, type_args, args)?; + } + } + + Ok(()) + } + + fn post_visit_decl(&mut self, decl: &'a Decl<'ast, T>) -> Result<(), Self::Error> { + Ok(()) + } + + fn post_visit_fun_decl( + &mut self, + _name: &mut Ident<'ast>, + _type_args: &mut Vec, + _args: &mut Vec<(Ident, T)>, + _body: &mut Box>, + _type_: &mut T, + ) -> Result<(), Self::Error> { + Ok(()) + } + + fn visit_decl(&mut self, decl: &'a mut Decl<'ast, T>) -> Result<(), Self::Error> { + match decl { + Decl::Fun { + name, + type_args, + args, + body, + type_, + } => { + self.visit_ident(name)?; + for type_arg in type_args.iter_mut() { + self.visit_ident(type_arg)?; + } + for (arg, t) in args.iter_mut() { + self.visit_ident(arg)?; + self.visit_type(t)?; + } + self.visit_expr(body)?; + self.visit_type(type_)?; + self.post_visit_fun_decl(name, type_args, args, body, type_)?; + } + Decl::Extern { + name, + arg_types, + ret_type, + } => { + self.visit_ident(name)?; + for arg_t in arg_types { + self.visit_type(arg_t)?; + } + self.visit_type(ret_type)?; + } + } + + self.post_visit_decl(decl)?; + Ok(()) + } +} diff --git a/users/grfn/achilles/src/passes/hir/monomorphize.rs b/users/grfn/achilles/src/passes/hir/monomorphize.rs new file mode 100644 index 000000000000..251a988f4f6f --- /dev/null +++ b/users/grfn/achilles/src/passes/hir/monomorphize.rs @@ -0,0 +1,139 @@ +use std::cell::RefCell; +use std::collections::{HashMap, HashSet}; +use std::convert::TryInto; +use std::mem; + +use void::{ResultVoidExt, Void}; + +use crate::ast::hir::{Decl, Expr}; +use crate::ast::{self, Ident}; + +use super::Visitor; + +#[derive(Default)] +pub(crate) struct Monomorphize<'a, 'ast> { + decls: HashMap<&'a Ident<'ast>, &'a Decl<'ast, ast::Type<'ast>>>, + extra_decls: Vec>>, + remove_decls: HashSet>, +} + +impl<'a, 'ast> Monomorphize<'a, 'ast> { + pub(crate) fn new() -> Self { + Default::default() + } +} + +impl<'a, 'ast> Visitor<'a, 'ast, ast::Type<'ast>> for Monomorphize<'a, 'ast> { + type Error = Void; + + fn post_visit_call( + &mut self, + fun: &mut Expr<'ast, ast::Type<'ast>>, + type_args: &mut HashMap, ast::Type<'ast>>, + args: &mut Vec>>, + ) -> Result<(), Self::Error> { + let new_fun = match fun { + Expr::Ident(id, _) => { + let decl: Decl<_> = (**self.decls.get(id).unwrap()).clone(); + let name = RefCell::new(id.to_string()); + let type_args = mem::take(type_args); + let mut monomorphized = decl + .traverse_type(|ty| -> Result<_, Void> { + Ok(ty.clone().traverse_type_vars(|v| { + let concrete = type_args.get(&v).unwrap(); + name.borrow_mut().push_str(&concrete.to_string()); + concrete.clone() + })) + }) + .void_unwrap(); + let name: Ident = name.into_inner().try_into().unwrap(); + if name != *id { + self.remove_decls.insert(id.clone()); + monomorphized.set_name(name.clone()); + let type_ = monomorphized.type_().unwrap().clone(); + self.extra_decls.push(monomorphized); + Some(Expr::Ident(name, type_)) + } else { + None + } + } + _ => todo!(), + }; + if let Some(new_fun) = new_fun { + *fun = new_fun; + } + Ok(()) + } + + fn post_visit_decl( + &mut self, + decl: &'a Decl<'ast, ast::Type<'ast>>, + ) -> Result<(), Self::Error> { + self.decls.insert(decl.name(), decl); + Ok(()) + } +} + +pub(crate) fn run_toplevel<'a>(toplevel: &mut Vec>>) { + let mut pass = Monomorphize::new(); + for decl in toplevel.iter_mut() { + pass.visit_decl(decl).void_unwrap(); + } + let remove_decls = mem::take(&mut pass.remove_decls); + let mut extra_decls = mem::take(&mut pass.extra_decls); + toplevel.retain(|decl| !remove_decls.contains(decl.name())); + extra_decls.append(toplevel); + *toplevel = extra_decls; +} + +#[cfg(test)] +mod tests { + use std::convert::TryFrom; + + use super::*; + use crate::parser::toplevel; + use crate::tc::typecheck_toplevel; + + #[test] + fn call_id_decl() { + let (_, program) = toplevel( + "ty id : fn a -> a + fn id x = x + + ty main : fn -> int + fn main = id 0", + ) + .unwrap(); + let mut program = typecheck_toplevel(program).unwrap(); + run_toplevel(&mut program); + + let find_decl = |ident: &str| { + program.iter().find(|decl| { + matches!(decl, Decl::Fun {name, ..} if name == &Ident::try_from(ident).unwrap()) + }).unwrap() + }; + + let main = find_decl("main"); + let body = match main { + Decl::Fun { body, .. } => body, + _ => unreachable!(), + }; + + let expected_type = ast::Type::Function(ast::FunctionType { + args: vec![ast::Type::Int], + ret: Box::new(ast::Type::Int), + }); + + match &**body { + Expr::Call { fun, .. } => { + let fun = match &**fun { + Expr::Ident(fun, _) => fun, + _ => unreachable!(), + }; + let called_decl = find_decl(fun.into()); + assert_eq!(called_decl.type_().unwrap(), &expected_type); + } + _ => unreachable!(), + } + } +} diff --git a/users/grfn/achilles/src/passes/hir/strip_positive_units.rs b/users/grfn/achilles/src/passes/hir/strip_positive_units.rs new file mode 100644 index 000000000000..91b56551c82d --- /dev/null +++ b/users/grfn/achilles/src/passes/hir/strip_positive_units.rs @@ -0,0 +1,189 @@ +use std::collections::HashMap; +use std::mem; + +use ast::hir::Binding; +use ast::Literal; +use void::{ResultVoidExt, Void}; + +use crate::ast::hir::{Decl, Expr}; +use crate::ast::{self, Ident}; + +use super::Visitor; + +/// Strip all values with a unit type in positive (non-return) position +pub(crate) struct StripPositiveUnits {} + +impl<'a, 'ast> Visitor<'a, 'ast, ast::Type<'ast>> for StripPositiveUnits { + type Error = Void; + + fn pre_visit_expr( + &mut self, + expr: &mut Expr<'ast, ast::Type<'ast>>, + ) -> Result<(), Self::Error> { + let mut extracted = vec![]; + if let Expr::Call { args, .. } = expr { + // TODO(grfn): replace with drain_filter once it's stabilized + let mut i = 0; + while i != args.len() { + if args[i].type_() == &ast::Type::Unit { + let expr = args.remove(i); + if !matches!(expr, Expr::Literal(Literal::Unit, _)) { + extracted.push(expr) + }; + } else { + i += 1 + } + } + } + + if !extracted.is_empty() { + let body = mem::replace(expr, Expr::Literal(Literal::Unit, ast::Type::Unit)); + *expr = Expr::Let { + bindings: extracted + .into_iter() + .map(|expr| Binding { + ident: Ident::from_str_unchecked("___discarded"), + type_: expr.type_().clone(), + body: expr, + }) + .collect(), + type_: body.type_().clone(), + body: Box::new(body), + }; + } + + Ok(()) + } + + fn post_visit_call( + &mut self, + _fun: &mut Expr<'ast, ast::Type<'ast>>, + _type_args: &mut HashMap, ast::Type<'ast>>, + args: &mut Vec>>, + ) -> Result<(), Self::Error> { + args.retain(|arg| arg.type_() != &ast::Type::Unit); + Ok(()) + } + + fn visit_type(&mut self, type_: &mut ast::Type<'ast>) -> Result<(), Self::Error> { + if let ast::Type::Function(ft) = type_ { + ft.args.retain(|a| a != &ast::Type::Unit); + } + Ok(()) + } + + fn post_visit_fun_decl( + &mut self, + _name: &mut Ident<'ast>, + _type_args: &mut Vec, + args: &mut Vec<(Ident, ast::Type<'ast>)>, + _body: &mut Box>>, + _type_: &mut ast::Type<'ast>, + ) -> Result<(), Self::Error> { + args.retain(|(_, ty)| ty != &ast::Type::Unit); + Ok(()) + } +} + +pub(crate) fn run_toplevel<'a>(toplevel: &mut Vec>>) { + let mut pass = StripPositiveUnits {}; + for decl in toplevel.iter_mut() { + pass.visit_decl(decl).void_unwrap(); + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::parser::toplevel; + use crate::tc::typecheck_toplevel; + use pretty_assertions::assert_eq; + + #[test] + fn unit_only_arg() { + let (_, program) = toplevel( + "ty f : fn () -> int + fn f _ = 1 + + ty main : fn -> int + fn main = f ()", + ) + .unwrap(); + + let (_, expected) = toplevel( + "ty f : fn -> int + fn f = 1 + + ty main : fn -> int + fn main = f()", + ) + .unwrap(); + let expected = typecheck_toplevel(expected).unwrap(); + + let mut program = typecheck_toplevel(program).unwrap(); + run_toplevel(&mut program); + + assert_eq!(program, expected); + } + + #[test] + fn unit_and_other_arg() { + let (_, program) = toplevel( + "ty f : fn (), int -> int + fn f _ x = x + + ty main : fn -> int + fn main = f () 1", + ) + .unwrap(); + + let (_, expected) = toplevel( + "ty f : fn int -> int + fn f x = x + + ty main : fn -> int + fn main = f 1", + ) + .unwrap(); + let expected = typecheck_toplevel(expected).unwrap(); + + let mut program = typecheck_toplevel(program).unwrap(); + run_toplevel(&mut program); + + assert_eq!(program, expected); + } + + #[test] + fn unit_expr_and_other_arg() { + let (_, program) = toplevel( + "ty f : fn (), int -> int + fn f _ x = x + + ty g : fn int -> () + fn g _ = () + + ty main : fn -> int + fn main = f (g 2) 1", + ) + .unwrap(); + + let (_, expected) = toplevel( + "ty f : fn int -> int + fn f x = x + + ty g : fn int -> () + fn g _ = () + + ty main : fn -> int + fn main = let ___discarded = g 2 in f 1", + ) + .unwrap(); + assert_eq!(expected.len(), 6); + let expected = typecheck_toplevel(expected).unwrap(); + + let mut program = typecheck_toplevel(program).unwrap(); + run_toplevel(&mut program); + + assert_eq!(program, expected); + } +} diff --git a/users/grfn/achilles/src/passes/mod.rs b/users/grfn/achilles/src/passes/mod.rs new file mode 100644 index 000000000000..306869bef1d5 --- /dev/null +++ b/users/grfn/achilles/src/passes/mod.rs @@ -0,0 +1 @@ +pub(crate) mod hir; diff --git a/users/grfn/achilles/src/tc/mod.rs b/users/grfn/achilles/src/tc/mod.rs new file mode 100644 index 000000000000..d27c45075e97 --- /dev/null +++ b/users/grfn/achilles/src/tc/mod.rs @@ -0,0 +1,745 @@ +use bimap::BiMap; +use derive_more::From; +use itertools::Itertools; +use std::cell::RefCell; +use std::collections::HashMap; +use std::convert::{TryFrom, TryInto}; +use std::fmt::{self, Display}; +use std::{mem, result}; +use thiserror::Error; + +use crate::ast::{self, hir, Arg, BinaryOperator, Ident, Literal}; +use crate::common::env::Env; +use crate::common::{Namer, NamerOf}; + +#[derive(Debug, Error)] +pub enum Error { + #[error("Undefined variable {0}")] + UndefinedVariable(Ident<'static>), + + #[error("Mismatched types: expected {expected}, but got {actual}")] + TypeMismatch { expected: Type, actual: Type }, + + #[error("Mismatched types, expected numeric type, but got {0}")] + NonNumeric(Type), + + #[error("Ambiguous type {0}")] + AmbiguousType(TyVar), +} + +pub type Result = result::Result; + +#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash)] +pub struct TyVar(u64); + +impl Display for TyVar { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "t{}", self.0) + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Hash)] +pub struct NullaryType(String); + +impl Display for NullaryType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(&self.0) + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum PrimType { + Int, + Float, + Bool, + CString, +} + +impl<'a> From for ast::Type<'a> { + fn from(pr: PrimType) -> Self { + match pr { + PrimType::Int => ast::Type::Int, + PrimType::Float => ast::Type::Float, + PrimType::Bool => ast::Type::Bool, + PrimType::CString => ast::Type::CString, + } + } +} + +impl Display for PrimType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + PrimType::Int => f.write_str("int"), + PrimType::Float => f.write_str("float"), + PrimType::Bool => f.write_str("bool"), + PrimType::CString => f.write_str("cstring"), + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone, From)] +pub enum Type { + #[from(ignore)] + Univ(TyVar), + #[from(ignore)] + Exist(TyVar), + Nullary(NullaryType), + Prim(PrimType), + Unit, + Fun { + args: Vec, + ret: Box, + }, +} + +impl<'a> TryFrom for ast::Type<'a> { + type Error = Type; + + fn try_from(value: Type) -> result::Result { + match value { + Type::Unit => Ok(ast::Type::Unit), + Type::Univ(_) => todo!(), + Type::Exist(_) => Err(value), + Type::Nullary(_) => todo!(), + Type::Prim(p) => Ok(p.into()), + Type::Fun { ref args, ref ret } => Ok(ast::Type::Function(ast::FunctionType { + args: args + .clone() + .into_iter() + .map(Self::try_from) + .try_collect() + .map_err(|_| value.clone())?, + ret: Box::new((*ret.clone()).try_into().map_err(|_| value.clone())?), + })), + } + } +} + +const INT: Type = Type::Prim(PrimType::Int); +const FLOAT: Type = Type::Prim(PrimType::Float); +const BOOL: Type = Type::Prim(PrimType::Bool); +const CSTRING: Type = Type::Prim(PrimType::CString); + +impl Display for Type { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Type::Nullary(nt) => nt.fmt(f), + Type::Prim(p) => p.fmt(f), + Type::Univ(TyVar(n)) => write!(f, "∀{}", n), + Type::Exist(TyVar(n)) => write!(f, "∃{}", n), + Type::Fun { args, ret } => write!(f, "fn {} -> {}", args.iter().join(", "), ret), + Type::Unit => write!(f, "()"), + } + } +} + +struct Typechecker<'ast> { + ty_var_namer: NamerOf, + ctx: HashMap, + env: Env, Type>, + + /// AST type var -> type + instantiations: Env, Type>, + + /// AST type-var -> universal TyVar + type_vars: RefCell<(BiMap, TyVar>, NamerOf>)>, +} + +impl<'ast> Typechecker<'ast> { + fn new() -> Self { + Self { + ty_var_namer: Namer::new(TyVar).boxed(), + type_vars: RefCell::new(( + Default::default(), + Namer::alphabetic().map(|n| Ident::try_from(n).unwrap()), + )), + ctx: Default::default(), + env: Default::default(), + instantiations: Default::default(), + } + } + + pub(crate) fn tc_expr(&mut self, expr: ast::Expr<'ast>) -> Result> { + match expr { + ast::Expr::Ident(ident) => { + let type_ = self + .env + .resolve(&ident) + .ok_or_else(|| Error::UndefinedVariable(ident.to_owned()))? + .clone(); + Ok(hir::Expr::Ident(ident, type_)) + } + ast::Expr::Literal(lit) => { + let type_ = match lit { + Literal::Int(_) => Type::Prim(PrimType::Int), + Literal::Bool(_) => Type::Prim(PrimType::Bool), + Literal::String(_) => Type::Prim(PrimType::CString), + Literal::Unit => Type::Unit, + }; + Ok(hir::Expr::Literal(lit.to_owned(), type_)) + } + ast::Expr::UnaryOp { op, rhs } => todo!(), + ast::Expr::BinaryOp { lhs, op, rhs } => { + let lhs = self.tc_expr(*lhs)?; + let rhs = self.tc_expr(*rhs)?; + let type_ = match op { + BinaryOperator::Equ | BinaryOperator::Neq => { + self.unify(lhs.type_(), rhs.type_())?; + Type::Prim(PrimType::Bool) + } + BinaryOperator::Add | BinaryOperator::Sub | BinaryOperator::Mul => { + let ty = self.unify(lhs.type_(), rhs.type_())?; + // if !matches!(ty, Type::Int | Type::Float) { + // return Err(Error::NonNumeric(ty)); + // } + ty + } + BinaryOperator::Div => todo!(), + BinaryOperator::Pow => todo!(), + }; + Ok(hir::Expr::BinaryOp { + lhs: Box::new(lhs), + op, + rhs: Box::new(rhs), + type_, + }) + } + ast::Expr::Let { bindings, body } => { + self.env.push(); + let bindings = bindings + .into_iter() + .map( + |ast::Binding { ident, type_, body }| -> Result> { + let body = self.tc_expr(body)?; + if let Some(type_) = type_ { + let type_ = self.type_from_ast_type(type_); + self.unify(body.type_(), &type_)?; + } + self.env.set(ident.clone(), body.type_().clone()); + Ok(hir::Binding { + ident, + type_: body.type_().clone(), + body, + }) + }, + ) + .collect::>>>()?; + let body = self.tc_expr(*body)?; + self.env.pop(); + Ok(hir::Expr::Let { + bindings, + type_: body.type_().clone(), + body: Box::new(body), + }) + } + ast::Expr::If { + condition, + then, + else_, + } => { + let condition = self.tc_expr(*condition)?; + self.unify(&Type::Prim(PrimType::Bool), condition.type_())?; + let then = self.tc_expr(*then)?; + let else_ = self.tc_expr(*else_)?; + let type_ = self.unify(then.type_(), else_.type_())?; + Ok(hir::Expr::If { + condition: Box::new(condition), + then: Box::new(then), + else_: Box::new(else_), + type_, + }) + } + ast::Expr::Fun(f) => { + let ast::Fun { args, body } = *f; + self.env.push(); + let args: Vec<_> = args + .into_iter() + .map(|Arg { ident, type_ }| { + let ty = match type_ { + Some(t) => self.type_from_ast_type(t), + None => self.fresh_ex(), + }; + self.env.set(ident.clone(), ty.clone()); + (ident, ty) + }) + .collect(); + let body = self.tc_expr(body)?; + self.env.pop(); + Ok(hir::Expr::Fun { + type_: Type::Fun { + args: args.iter().map(|(_, ty)| ty.clone()).collect(), + ret: Box::new(body.type_().clone()), + }, + type_args: vec![], // TODO fill in once we do let generalization + args, + body: Box::new(body), + }) + } + ast::Expr::Call { fun, args } => { + let ret_ty = self.fresh_ex(); + let arg_tys = args.iter().map(|_| self.fresh_ex()).collect::>(); + let ft = Type::Fun { + args: arg_tys.clone(), + ret: Box::new(ret_ty.clone()), + }; + let fun = self.tc_expr(*fun)?; + self.instantiations.push(); + self.unify(&ft, fun.type_())?; + let args = args + .into_iter() + .zip(arg_tys) + .map(|(arg, ty)| { + let arg = self.tc_expr(arg)?; + self.unify(&ty, arg.type_())?; + Ok(arg) + }) + .try_collect()?; + let type_args = self.commit_instantiations(); + Ok(hir::Expr::Call { + fun: Box::new(fun), + type_args, + args, + type_: ret_ty, + }) + } + ast::Expr::Ascription { expr, type_ } => { + let expr = self.tc_expr(*expr)?; + let type_ = self.type_from_ast_type(type_); + self.unify(expr.type_(), &type_)?; + Ok(expr) + } + } + } + + pub(crate) fn tc_decl( + &mut self, + decl: ast::Decl<'ast>, + ) -> Result>> { + match decl { + ast::Decl::Fun { name, body } => { + let mut expr = ast::Expr::Fun(Box::new(body)); + if let Some(type_) = self.env.resolve(&name) { + expr = ast::Expr::Ascription { + expr: Box::new(expr), + type_: self.finalize_type(type_.clone())?, + }; + } + + self.env.push(); + let body = self.tc_expr(expr)?; + let type_ = body.type_().clone(); + self.env.set(name.clone(), type_); + self.env.pop(); + match body { + hir::Expr::Fun { + type_args, + args, + body, + type_, + } => Ok(Some(hir::Decl::Fun { + name, + type_args, + args, + body, + type_, + })), + _ => unreachable!(), + } + } + ast::Decl::Ascription { name, type_ } => { + let type_ = self.type_from_ast_type(type_); + self.env.set(name.clone(), type_); + Ok(None) + } + ast::Decl::Extern { name, type_ } => { + let type_ = self.type_from_ast_type(ast::Type::Function(type_)); + self.env.set(name.clone(), type_.clone()); + let (arg_types, ret_type) = match type_ { + Type::Fun { args, ret } => (args, *ret), + _ => unreachable!(), + }; + Ok(Some(hir::Decl::Extern { + name, + arg_types, + ret_type, + })) + } + } + } + + fn fresh_tv(&mut self) -> TyVar { + self.ty_var_namer.make_name() + } + + fn fresh_ex(&mut self) -> Type { + Type::Exist(self.fresh_tv()) + } + + fn fresh_univ(&mut self) -> Type { + Type::Univ(self.fresh_tv()) + } + + fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result { + match (ty1, ty2) { + (Type::Unit, Type::Unit) => Ok(Type::Unit), + (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) { + Some(existing_ty) if self.types_match(ty, &existing_ty) => Ok(ty.clone()), + Some(var @ ast::Type::Var(_)) => { + let var = self.type_from_ast_type(var); + self.unify(&var, ty) + } + Some(existing_ty) => match ty { + Type::Exist(_) => { + let rhs = self.type_from_ast_type(existing_ty); + self.unify(ty, &rhs) + } + _ => Err(Error::TypeMismatch { + expected: ty.clone(), + actual: self.type_from_ast_type(existing_ty), + }), + }, + None => match self.ctx.insert(*tv, ty.clone()) { + Some(existing) => self.unify(&existing, ty), + None => Ok(ty.clone()), + }, + }, + (Type::Univ(u1), Type::Univ(u2)) if u1 == u2 => Ok(ty2.clone()), + (Type::Univ(u), ty) | (ty, Type::Univ(u)) => { + let ident = self.name_univ(*u); + match self.instantiations.resolve(&ident) { + Some(existing_ty) if ty == existing_ty => Ok(ty.clone()), + Some(existing_ty) => Err(Error::TypeMismatch { + expected: ty.clone(), + actual: existing_ty.clone(), + }), + None => { + self.instantiations.set(ident, ty.clone()); + Ok(ty.clone()) + } + } + } + (Type::Prim(p1), Type::Prim(p2)) if p1 == p2 => Ok(ty2.clone()), + ( + Type::Fun { + args: args1, + ret: ret1, + }, + Type::Fun { + args: args2, + ret: ret2, + }, + ) => { + let args = args1 + .iter() + .zip(args2) + .map(|(t1, t2)| self.unify(t1, t2)) + .try_collect()?; + let ret = self.unify(ret1, ret2)?; + Ok(Type::Fun { + args, + ret: Box::new(ret), + }) + } + (Type::Nullary(_), _) | (_, Type::Nullary(_)) => todo!(), + _ => Err(Error::TypeMismatch { + expected: ty1.clone(), + actual: ty2.clone(), + }), + } + } + + fn finalize_expr( + &self, + expr: hir::Expr<'ast, Type>, + ) -> Result>> { + expr.traverse_type(|ty| self.finalize_type(ty)) + } + + fn finalize_decl( + &mut self, + decl: hir::Decl<'ast, Type>, + ) -> Result>> { + let res = decl.traverse_type(|ty| self.finalize_type(ty))?; + if let Some(type_) = res.type_() { + let ty = self.type_from_ast_type(type_.clone()); + self.env.set(res.name().clone(), ty); + } + Ok(res) + } + + fn finalize_type(&self, ty: Type) -> Result> { + let ret = match ty { + Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)), + Type::Univ(tv) => Ok(ast::Type::Var(self.name_univ(tv))), + Type::Unit => Ok(ast::Type::Unit), + Type::Nullary(_) => todo!(), + Type::Prim(pr) => Ok(pr.into()), + Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType { + args: args + .into_iter() + .map(|ty| self.finalize_type(ty)) + .try_collect()?, + ret: Box::new(self.finalize_type(*ret)?), + })), + }; + ret + } + + fn resolve_tv(&self, tv: TyVar) -> Option> { + let mut res = &Type::Exist(tv); + loop { + match res { + Type::Exist(tv) => { + res = self.ctx.get(tv)?; + } + Type::Univ(tv) => { + let ident = self.name_univ(*tv); + if let Some(r) = self.instantiations.resolve(&ident) { + res = r; + } else { + break Some(ast::Type::Var(ident)); + } + } + Type::Nullary(_) => todo!(), + Type::Prim(pr) => break Some((*pr).into()), + Type::Unit => break Some(ast::Type::Unit), + Type::Fun { args, ret } => todo!(), + } + } + } + + fn type_from_ast_type(&mut self, ast_type: ast::Type<'ast>) -> Type { + match ast_type { + ast::Type::Unit => Type::Unit, + ast::Type::Int => INT, + ast::Type::Float => FLOAT, + ast::Type::Bool => BOOL, + ast::Type::CString => CSTRING, + ast::Type::Function(ast::FunctionType { args, ret }) => Type::Fun { + args: args + .into_iter() + .map(|t| self.type_from_ast_type(t)) + .collect(), + ret: Box::new(self.type_from_ast_type(*ret)), + }, + ast::Type::Var(id) => Type::Univ({ + let opt_tv = { self.type_vars.borrow_mut().0.get_by_left(&id).copied() }; + opt_tv.unwrap_or_else(|| { + let tv = self.fresh_tv(); + self.type_vars + .borrow_mut() + .0 + .insert_no_overwrite(id, tv) + .unwrap(); + tv + }) + }), + } + } + + fn name_univ(&self, tv: TyVar) -> Ident<'static> { + let mut vars = self.type_vars.borrow_mut(); + vars.0 + .get_by_right(&tv) + .map(Ident::to_owned) + .unwrap_or_else(|| { + let name = loop { + let name = vars.1.make_name(); + if !vars.0.contains_left(&name) { + break name; + } + }; + vars.0.insert_no_overwrite(name.clone(), tv).unwrap(); + name + }) + } + + fn commit_instantiations(&mut self) -> HashMap, Type> { + let mut res = HashMap::new(); + let mut ctx = mem::take(&mut self.ctx); + for (_, v) in ctx.iter_mut() { + if let Type::Univ(tv) = v { + let tv_name = self.name_univ(*tv); + if let Some(concrete) = self.instantiations.resolve(&tv_name) { + res.insert(tv_name, concrete.clone()); + *v = concrete.clone(); + } + } + } + self.ctx = ctx; + self.instantiations.pop(); + res + } + + fn types_match(&self, type_: &Type, ast_type: &ast::Type<'ast>) -> bool { + match (type_, ast_type) { + (Type::Univ(u), ast::Type::Var(v)) => { + Some(u) == self.type_vars.borrow().0.get_by_left(v) + } + (Type::Univ(_), _) => false, + (Type::Exist(_), _) => false, + (Type::Unit, ast::Type::Unit) => true, + (Type::Unit, _) => false, + (Type::Nullary(_), _) => todo!(), + (Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty, + (Type::Fun { args, ret }, ast::Type::Function(ft)) => { + args.len() == ft.args.len() + && args + .iter() + .zip(&ft.args) + .all(|(a1, a2)| self.types_match(a1, &a2)) + && self.types_match(&*ret, &*ft.ret) + } + (Type::Fun { .. }, _) => false, + } + } +} + +pub fn typecheck_expr(expr: ast::Expr) -> Result> { + let mut typechecker = Typechecker::new(); + let typechecked = typechecker.tc_expr(expr)?; + typechecker.finalize_expr(typechecked) +} + +pub fn typecheck_toplevel(decls: Vec) -> Result>> { + let mut typechecker = Typechecker::new(); + let mut res = Vec::with_capacity(decls.len()); + for decl in decls { + if let Some(hir_decl) = typechecker.tc_decl(decl)? { + let hir_decl = typechecker.finalize_decl(hir_decl)?; + res.push(hir_decl); + } + typechecker.ctx.clear(); + } + Ok(res) +} + +#[cfg(test)] +mod tests { + use super::*; + + macro_rules! assert_type { + ($expr: expr, $type: expr) => { + use crate::parser::{expr, type_}; + let parsed_expr = test_parse!(expr, $expr); + let parsed_type = test_parse!(type_, $type); + let res = typecheck_expr(parsed_expr).unwrap_or_else(|e| panic!("{}", e)); + assert!( + res.type_().alpha_equiv(&parsed_type), + "{} inferred type {}, but expected {}", + $expr, + res.type_(), + $type + ); + }; + + (toplevel($program: expr), $($decl: ident => $type: expr),+ $(,)?) => {{ + use crate::parser::{toplevel, type_}; + let program = test_parse!(toplevel, $program); + let res = typecheck_toplevel(program).unwrap_or_else(|e| panic!("{}", e)); + $( + let parsed_type = test_parse!(type_, $type); + let ident = Ident::try_from(::std::stringify!($decl)).unwrap(); + let decl = res.iter().find(|decl| { + matches!(decl, crate::ast::hir::Decl::Fun { name, .. } if name == &ident) + }).unwrap_or_else(|| panic!("Could not find declaration for {}", ident)); + assert!( + decl.type_().unwrap().alpha_equiv(&parsed_type), + "inferred type {} for {}, but expected {}", + decl.type_().unwrap(), + ident, + $type + ); + )+ + }}; + } + + macro_rules! assert_type_error { + ($expr: expr) => { + use crate::parser::expr; + let parsed_expr = test_parse!(expr, $expr); + let res = typecheck_expr(parsed_expr); + assert!( + res.is_err(), + "Expected type error, but got type: {}", + res.unwrap().type_() + ); + }; + } + + #[test] + fn literal_int() { + assert_type!("1", "int"); + } + + #[test] + fn conditional() { + assert_type!("if 1 == 2 then 3 else 4", "int"); + } + + #[test] + #[ignore] + fn add_bools() { + assert_type_error!("true + false"); + } + + #[test] + fn call_generic_function() { + assert_type!("(fn x = x) 1", "int"); + } + + #[test] + fn call_let_bound_generic() { + assert_type!("let id = fn x = x in id 1", "int"); + } + + #[test] + fn universal_ascripted_let() { + assert_type!("let id: fn a -> a = fn x = x in id 1", "int"); + } + + #[test] + fn call_generic_function_toplevel() { + assert_type!( + toplevel( + "ty id : fn a -> a + fn id x = x + + fn main = id 0" + ), + main => "fn -> int", + id => "fn a -> a", + ); + } + + #[test] + #[ignore] + fn let_generalization() { + assert_type!("let id = fn x = x in if id true then id 1 else 2", "int"); + } + + #[test] + fn concrete_function() { + assert_type!("fn x = x + 1", "fn int -> int"); + } + + #[test] + fn arg_ascriptions() { + assert_type!("fn (x: int) = x", "fn int -> int"); + } + + #[test] + fn call_concrete_function() { + assert_type!("(fn x = x + 1) 2", "int"); + } + + #[test] + fn conditional_non_bool() { + assert_type_error!("if 3 then true else false"); + } + + #[test] + fn let_int() { + assert_type!("let x = 1 in x", "int"); + } +} -- cgit 1.4.1