diff options
Diffstat (limited to 'users/grfn/achilles/src/ast')
-rw-r--r-- | users/grfn/achilles/src/ast/hir.rs | 364 | ||||
-rw-r--r-- | users/grfn/achilles/src/ast/mod.rs | 484 |
2 files changed, 848 insertions, 0 deletions
diff --git a/users/grfn/achilles/src/ast/hir.rs b/users/grfn/achilles/src/ast/hir.rs new file mode 100644 index 000000000000..cdfaef567d7a --- /dev/null +++ b/users/grfn/achilles/src/ast/hir.rs @@ -0,0 +1,364 @@ +use std::collections::HashMap; + +use itertools::Itertools; + +use super::{BinaryOperator, Ident, Literal, UnaryOperator}; + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum Pattern<'a, T> { + Id(Ident<'a>, T), + Tuple(Vec<Pattern<'a, T>>), +} + +impl<'a, T> Pattern<'a, T> { + pub fn to_owned(&self) -> Pattern<'static, T> + where + T: Clone, + { + match self { + Pattern::Id(id, t) => Pattern::Id(id.to_owned(), t.clone()), + Pattern::Tuple(pats) => { + Pattern::Tuple(pats.into_iter().map(Pattern::to_owned).collect()) + } + } + } + + pub fn traverse_type<F, U, E>(self, f: F) -> Result<Pattern<'a, U>, E> + where + F: Fn(T) -> Result<U, E> + Clone, + { + match self { + Pattern::Id(id, t) => Ok(Pattern::Id(id, f(t)?)), + Pattern::Tuple(pats) => Ok(Pattern::Tuple( + pats.into_iter() + .map(|pat| pat.traverse_type(f.clone())) + .collect::<Result<Vec<_>, _>>()?, + )), + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Binding<'a, T> { + pub pat: Pattern<'a, T>, + pub body: Expr<'a, T>, +} + +impl<'a, T> Binding<'a, T> { + fn to_owned(&self) -> Binding<'static, T> + where + T: Clone, + { + Binding { + pat: self.pat.to_owned(), + body: self.body.to_owned(), + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum Expr<'a, T> { + Ident(Ident<'a>, T), + + Literal(Literal<'a>, T), + + Tuple(Vec<Expr<'a, T>>, T), + + UnaryOp { + op: UnaryOperator, + rhs: Box<Expr<'a, T>>, + type_: T, + }, + + BinaryOp { + lhs: Box<Expr<'a, T>>, + op: BinaryOperator, + rhs: Box<Expr<'a, T>>, + type_: T, + }, + + Let { + bindings: Vec<Binding<'a, T>>, + body: Box<Expr<'a, T>>, + type_: T, + }, + + If { + condition: Box<Expr<'a, T>>, + then: Box<Expr<'a, T>>, + else_: Box<Expr<'a, T>>, + type_: T, + }, + + Fun { + type_args: Vec<Ident<'a>>, + args: Vec<(Ident<'a>, T)>, + body: Box<Expr<'a, T>>, + type_: T, + }, + + Call { + fun: Box<Expr<'a, T>>, + type_args: HashMap<Ident<'a>, T>, + args: Vec<Expr<'a, T>>, + type_: T, + }, +} + +impl<'a, T> Expr<'a, T> { + pub fn type_(&self) -> &T { + match self { + Expr::Ident(_, t) => t, + Expr::Literal(_, t) => t, + Expr::Tuple(_, t) => t, + Expr::UnaryOp { type_, .. } => type_, + Expr::BinaryOp { type_, .. } => type_, + Expr::Let { type_, .. } => type_, + Expr::If { type_, .. } => type_, + Expr::Fun { type_, .. } => type_, + Expr::Call { type_, .. } => type_, + } + } + + pub fn traverse_type<F, U, E>(self, f: F) -> Result<Expr<'a, U>, E> + where + F: Fn(T) -> Result<U, E> + Clone, + { + match self { + Expr::Ident(id, t) => Ok(Expr::Ident(id, f(t)?)), + Expr::Literal(lit, t) => Ok(Expr::Literal(lit, f(t)?)), + Expr::UnaryOp { op, rhs, type_ } => Ok(Expr::UnaryOp { + op, + rhs: Box::new(rhs.traverse_type(f.clone())?), + type_: f(type_)?, + }), + Expr::BinaryOp { + lhs, + op, + rhs, + type_, + } => Ok(Expr::BinaryOp { + lhs: Box::new(lhs.traverse_type(f.clone())?), + op, + rhs: Box::new(rhs.traverse_type(f.clone())?), + type_: f(type_)?, + }), + Expr::Let { + bindings, + body, + type_, + } => Ok(Expr::Let { + bindings: bindings + .into_iter() + .map(|Binding { pat, body }| { + Ok(Binding { + pat: pat.traverse_type(f.clone())?, + body: body.traverse_type(f.clone())?, + }) + }) + .collect::<Result<Vec<_>, E>>()?, + body: Box::new(body.traverse_type(f.clone())?), + type_: f(type_)?, + }), + Expr::If { + condition, + then, + else_, + type_, + } => Ok(Expr::If { + condition: Box::new(condition.traverse_type(f.clone())?), + then: Box::new(then.traverse_type(f.clone())?), + else_: Box::new(else_.traverse_type(f.clone())?), + type_: f(type_)?, + }), + Expr::Fun { + args, + type_args, + body, + type_, + } => Ok(Expr::Fun { + args: args + .into_iter() + .map(|(id, t)| Ok((id, f.clone()(t)?))) + .collect::<Result<Vec<_>, E>>()?, + type_args, + body: Box::new(body.traverse_type(f.clone())?), + type_: f(type_)?, + }), + Expr::Call { + fun, + type_args, + args, + type_, + } => Ok(Expr::Call { + fun: Box::new(fun.traverse_type(f.clone())?), + type_args: type_args + .into_iter() + .map(|(id, ty)| Ok((id, f.clone()(ty)?))) + .collect::<Result<HashMap<_, _>, E>>()?, + args: args + .into_iter() + .map(|e| e.traverse_type(f.clone())) + .collect::<Result<Vec<_>, E>>()?, + type_: f(type_)?, + }), + Expr::Tuple(members, t) => Ok(Expr::Tuple( + members + .into_iter() + .map(|t| t.traverse_type(f.clone())) + .try_collect()?, + f(t)?, + )), + } + } + + pub fn to_owned(&self) -> Expr<'static, T> + where + T: Clone, + { + match self { + Expr::Ident(id, t) => Expr::Ident(id.to_owned(), t.clone()), + Expr::Literal(lit, t) => Expr::Literal(lit.to_owned(), t.clone()), + Expr::UnaryOp { op, rhs, type_ } => Expr::UnaryOp { + op: *op, + rhs: Box::new((**rhs).to_owned()), + type_: type_.clone(), + }, + Expr::BinaryOp { + lhs, + op, + rhs, + type_, + } => Expr::BinaryOp { + lhs: Box::new((**lhs).to_owned()), + op: *op, + rhs: Box::new((**rhs).to_owned()), + type_: type_.clone(), + }, + Expr::Let { + bindings, + body, + type_, + } => Expr::Let { + bindings: bindings.iter().map(|b| b.to_owned()).collect(), + body: Box::new((**body).to_owned()), + type_: type_.clone(), + }, + Expr::If { + condition, + then, + else_, + type_, + } => Expr::If { + condition: Box::new((**condition).to_owned()), + then: Box::new((**then).to_owned()), + else_: Box::new((**else_).to_owned()), + type_: type_.clone(), + }, + Expr::Fun { + args, + type_args, + body, + type_, + } => Expr::Fun { + args: args + .iter() + .map(|(id, t)| (id.to_owned(), t.clone())) + .collect(), + type_args: type_args.iter().map(|arg| arg.to_owned()).collect(), + body: Box::new((**body).to_owned()), + type_: type_.clone(), + }, + Expr::Call { + fun, + type_args, + args, + type_, + } => Expr::Call { + fun: Box::new((**fun).to_owned()), + type_args: type_args + .iter() + .map(|(id, t)| (id.to_owned(), t.clone())) + .collect(), + args: args.iter().map(|e| e.to_owned()).collect(), + type_: type_.clone(), + }, + Expr::Tuple(members, t) => { + Expr::Tuple(members.into_iter().map(Expr::to_owned).collect(), t.clone()) + } + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum Decl<'a, T> { + Fun { + name: Ident<'a>, + type_args: Vec<Ident<'a>>, + args: Vec<(Ident<'a>, T)>, + body: Box<Expr<'a, T>>, + type_: T, + }, + + Extern { + name: Ident<'a>, + arg_types: Vec<T>, + ret_type: T, + }, +} + +impl<'a, T> Decl<'a, T> { + pub fn name(&self) -> &Ident<'a> { + match self { + Decl::Fun { name, .. } => name, + Decl::Extern { name, .. } => name, + } + } + + pub fn set_name(&mut self, new_name: Ident<'a>) { + match self { + Decl::Fun { name, .. } => *name = new_name, + Decl::Extern { name, .. } => *name = new_name, + } + } + + pub fn type_(&self) -> Option<&T> { + match self { + Decl::Fun { type_, .. } => Some(type_), + Decl::Extern { .. } => None, + } + } + + pub fn traverse_type<F, U, E>(self, f: F) -> Result<Decl<'a, U>, E> + where + F: Fn(T) -> Result<U, E> + Clone, + { + match self { + Decl::Fun { + name, + type_args, + args, + body, + type_, + } => Ok(Decl::Fun { + name, + type_args, + args: args + .into_iter() + .map(|(id, t)| Ok((id, f(t)?))) + .try_collect()?, + body: Box::new(body.traverse_type(f.clone())?), + type_: f(type_)?, + }), + Decl::Extern { + name, + arg_types, + ret_type, + } => Ok(Decl::Extern { + name, + arg_types: arg_types.into_iter().map(f.clone()).try_collect()?, + ret_type: f(ret_type)?, + }), + } + } +} diff --git a/users/grfn/achilles/src/ast/mod.rs b/users/grfn/achilles/src/ast/mod.rs new file mode 100644 index 000000000000..5438d29d2cf7 --- /dev/null +++ b/users/grfn/achilles/src/ast/mod.rs @@ -0,0 +1,484 @@ +pub(crate) mod hir; + +use std::borrow::Cow; +use std::collections::HashMap; +use std::convert::TryFrom; +use std::fmt::{self, Display, Formatter}; + +use itertools::Itertools; + +#[derive(Debug, PartialEq, Eq)] +pub struct InvalidIdentifier<'a>(Cow<'a, str>); + +#[derive(Debug, PartialEq, Eq, Hash, Clone)] +pub struct Ident<'a>(pub Cow<'a, str>); + +impl<'a> From<&'a Ident<'a>> for &'a str { + fn from(id: &'a Ident<'a>) -> Self { + id.0.as_ref() + } +} + +impl<'a> Display for Ident<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "{}", self.0) + } +} + +impl<'a> Ident<'a> { + pub fn to_owned(&self) -> Ident<'static> { + Ident(Cow::Owned(self.0.clone().into_owned())) + } + + /// Construct an identifier from a &str without checking that it's a valid identifier + pub fn from_str_unchecked(s: &'a str) -> Self { + debug_assert!(is_valid_identifier(s)); + Self(Cow::Borrowed(s)) + } + + pub fn from_string_unchecked(s: String) -> Self { + debug_assert!(is_valid_identifier(&s)); + Self(Cow::Owned(s)) + } +} + +pub fn is_valid_identifier<S>(s: &S) -> bool +where + S: AsRef<str> + ?Sized, +{ + s.as_ref() + .chars() + .any(|c| !c.is_alphanumeric() || !"_".contains(c)) +} + +impl<'a> TryFrom<&'a str> for Ident<'a> { + type Error = InvalidIdentifier<'a>; + + fn try_from(s: &'a str) -> Result<Self, Self::Error> { + if is_valid_identifier(s) { + Ok(Ident(Cow::Borrowed(s))) + } else { + Err(InvalidIdentifier(Cow::Borrowed(s))) + } + } +} + +impl<'a> TryFrom<String> for Ident<'a> { + type Error = InvalidIdentifier<'static>; + + fn try_from(s: String) -> Result<Self, Self::Error> { + if is_valid_identifier(&s) { + Ok(Ident(Cow::Owned(s))) + } else { + Err(InvalidIdentifier(Cow::Owned(s))) + } + } +} + +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub enum BinaryOperator { + /// `+` + Add, + + /// `-` + Sub, + + /// `*` + Mul, + + /// `/` + Div, + + /// `^` + Pow, + + /// `==` + Equ, + + /// `!=` + Neq, +} + +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +pub enum UnaryOperator { + /// ! + Not, + + /// - + Neg, +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum Literal<'a> { + Unit, + Int(u64), + Bool(bool), + String(Cow<'a, str>), +} + +impl<'a> Literal<'a> { + pub fn to_owned(&self) -> Literal<'static> { + match self { + Literal::Int(i) => Literal::Int(*i), + Literal::Bool(b) => Literal::Bool(*b), + Literal::String(s) => Literal::String(Cow::Owned(s.clone().into_owned())), + Literal::Unit => Literal::Unit, + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum Pattern<'a> { + Id(Ident<'a>), + Tuple(Vec<Pattern<'a>>), +} + +impl<'a> Pattern<'a> { + pub fn to_owned(&self) -> Pattern<'static> { + match self { + Pattern::Id(id) => Pattern::Id(id.to_owned()), + Pattern::Tuple(pats) => Pattern::Tuple(pats.iter().map(Pattern::to_owned).collect()), + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Binding<'a> { + pub pat: Pattern<'a>, + pub type_: Option<Type<'a>>, + pub body: Expr<'a>, +} + +impl<'a> Binding<'a> { + fn to_owned(&self) -> Binding<'static> { + Binding { + pat: self.pat.to_owned(), + type_: self.type_.as_ref().map(|t| t.to_owned()), + body: self.body.to_owned(), + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum Expr<'a> { + Ident(Ident<'a>), + + Literal(Literal<'a>), + + UnaryOp { + op: UnaryOperator, + rhs: Box<Expr<'a>>, + }, + + BinaryOp { + lhs: Box<Expr<'a>>, + op: BinaryOperator, + rhs: Box<Expr<'a>>, + }, + + Let { + bindings: Vec<Binding<'a>>, + body: Box<Expr<'a>>, + }, + + If { + condition: Box<Expr<'a>>, + then: Box<Expr<'a>>, + else_: Box<Expr<'a>>, + }, + + Fun(Box<Fun<'a>>), + + Call { + fun: Box<Expr<'a>>, + args: Vec<Expr<'a>>, + }, + + Tuple(Vec<Expr<'a>>), + + Ascription { + expr: Box<Expr<'a>>, + type_: Type<'a>, + }, +} + +impl<'a> Expr<'a> { + pub fn to_owned(&self) -> Expr<'static> { + match self { + Expr::Ident(ref id) => Expr::Ident(id.to_owned()), + Expr::Literal(ref lit) => Expr::Literal(lit.to_owned()), + Expr::Tuple(ref members) => { + Expr::Tuple(members.into_iter().map(Expr::to_owned).collect()) + } + Expr::UnaryOp { op, rhs } => Expr::UnaryOp { + op: *op, + rhs: Box::new((**rhs).to_owned()), + }, + Expr::BinaryOp { lhs, op, rhs } => Expr::BinaryOp { + lhs: Box::new((**lhs).to_owned()), + op: *op, + rhs: Box::new((**rhs).to_owned()), + }, + Expr::Let { bindings, body } => Expr::Let { + bindings: bindings.iter().map(|binding| binding.to_owned()).collect(), + body: Box::new((**body).to_owned()), + }, + Expr::If { + condition, + then, + else_, + } => Expr::If { + condition: Box::new((**condition).to_owned()), + then: Box::new((**then).to_owned()), + else_: Box::new((**else_).to_owned()), + }, + Expr::Fun(fun) => Expr::Fun(Box::new((**fun).to_owned())), + Expr::Call { fun, args } => Expr::Call { + fun: Box::new((**fun).to_owned()), + args: args.iter().map(|arg| arg.to_owned()).collect(), + }, + Expr::Ascription { expr, type_ } => Expr::Ascription { + expr: Box::new((**expr).to_owned()), + type_: type_.to_owned(), + }, + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Arg<'a> { + pub ident: Ident<'a>, + pub type_: Option<Type<'a>>, +} + +impl<'a> Arg<'a> { + pub fn to_owned(&self) -> Arg<'static> { + Arg { + ident: self.ident.to_owned(), + type_: self.type_.as_ref().map(Type::to_owned), + } + } +} + +impl<'a> TryFrom<&'a str> for Arg<'a> { + type Error = <Ident<'a> as TryFrom<&'a str>>::Error; + + fn try_from(value: &'a str) -> Result<Self, Self::Error> { + Ok(Arg { + ident: Ident::try_from(value)?, + type_: None, + }) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct Fun<'a> { + pub args: Vec<Arg<'a>>, + pub body: Expr<'a>, +} + +impl<'a> Fun<'a> { + pub fn to_owned(&self) -> Fun<'static> { + Fun { + args: self.args.iter().map(|arg| arg.to_owned()).collect(), + body: self.body.to_owned(), + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum Decl<'a> { + Fun { + name: Ident<'a>, + body: Fun<'a>, + }, + Ascription { + name: Ident<'a>, + type_: Type<'a>, + }, + Extern { + name: Ident<'a>, + type_: FunctionType<'a>, + }, +} + +//// + +#[derive(Debug, PartialEq, Eq, Clone)] +pub struct FunctionType<'a> { + pub args: Vec<Type<'a>>, + pub ret: Box<Type<'a>>, +} + +impl<'a> FunctionType<'a> { + pub fn to_owned(&self) -> FunctionType<'static> { + FunctionType { + args: self.args.iter().map(|a| a.to_owned()).collect(), + ret: Box::new((*self.ret).to_owned()), + } + } +} + +impl<'a> Display for FunctionType<'a> { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { + write!(f, "fn {} -> {}", self.args.iter().join(", "), self.ret) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] +pub enum Type<'a> { + Int, + Float, + Bool, + CString, + Unit, + Tuple(Vec<Type<'a>>), + Var(Ident<'a>), + Function(FunctionType<'a>), +} + +impl<'a> Type<'a> { + pub fn to_owned(&self) -> Type<'static> { + match self { + Type::Int => Type::Int, + Type::Float => Type::Float, + Type::Bool => Type::Bool, + Type::CString => Type::CString, + Type::Unit => Type::Unit, + Type::Var(v) => Type::Var(v.to_owned()), + Type::Function(f) => Type::Function(f.to_owned()), + Type::Tuple(members) => Type::Tuple(members.iter().map(Type::to_owned).collect()), + } + } + + pub fn alpha_equiv(&self, other: &Self) -> bool { + fn do_alpha_equiv<'a>( + substs: &mut HashMap<&'a Ident<'a>, &'a Ident<'a>>, + lhs: &'a Type, + rhs: &'a Type, + ) -> bool { + match (lhs, rhs) { + (Type::Var(v1), Type::Var(v2)) => substs.entry(v1).or_insert(v2) == &v2, + ( + Type::Function(FunctionType { + args: args1, + ret: ret1, + }), + Type::Function(FunctionType { + args: args2, + ret: ret2, + }), + ) => { + args1.len() == args2.len() + && args1 + .iter() + .zip(args2) + .all(|(a1, a2)| do_alpha_equiv(substs, a1, a2)) + && do_alpha_equiv(substs, ret1, ret2) + } + _ => lhs == rhs, + } + } + + let mut substs = HashMap::new(); + do_alpha_equiv(&mut substs, self, other) + } + + pub fn traverse_type_vars<'b, F>(self, mut f: F) -> Type<'b> + where + F: FnMut(Ident<'a>) -> Type<'b> + Clone, + { + match self { + Type::Var(tv) => f(tv), + Type::Function(FunctionType { args, ret }) => Type::Function(FunctionType { + args: args + .into_iter() + .map(|t| t.traverse_type_vars(f.clone())) + .collect(), + ret: Box::new(ret.traverse_type_vars(f)), + }), + Type::Int => Type::Int, + Type::Float => Type::Float, + Type::Bool => Type::Bool, + Type::CString => Type::CString, + Type::Tuple(members) => Type::Tuple( + members + .into_iter() + .map(|t| t.traverse_type_vars(f.clone())) + .collect(), + ), + Type::Unit => Type::Unit, + } + } + + pub fn as_tuple(&self) -> Option<&Vec<Type<'a>>> { + if let Self::Tuple(v) = self { + Some(v) + } else { + None + } + } +} + +impl<'a> Display for Type<'a> { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Type::Int => f.write_str("int"), + Type::Float => f.write_str("float"), + Type::Bool => f.write_str("bool"), + Type::CString => f.write_str("cstring"), + Type::Unit => f.write_str("()"), + Type::Var(v) => v.fmt(f), + Type::Function(ft) => ft.fmt(f), + Type::Tuple(ms) => write!(f, "({})", ms.iter().join(", ")), + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + fn type_var(n: &str) -> Type<'static> { + Type::Var(Ident::try_from(n.to_owned()).unwrap()) + } + + mod alpha_equiv { + use super::*; + + #[test] + fn trivial() { + assert!(Type::Int.alpha_equiv(&Type::Int)); + assert!(!Type::Int.alpha_equiv(&Type::Bool)); + } + + #[test] + fn simple_type_var() { + assert!(type_var("a").alpha_equiv(&type_var("b"))); + } + + #[test] + fn function_with_type_vars_equiv() { + assert!(Type::Function(FunctionType { + args: vec![type_var("a")], + ret: Box::new(type_var("b")), + }) + .alpha_equiv(&Type::Function(FunctionType { + args: vec![type_var("b")], + ret: Box::new(type_var("a")), + }))) + } + + #[test] + fn function_with_type_vars_non_equiv() { + assert!(!Type::Function(FunctionType { + args: vec![type_var("a")], + ret: Box::new(type_var("a")), + }) + .alpha_equiv(&Type::Function(FunctionType { + args: vec![type_var("b")], + ret: Box::new(type_var("a")), + }))) + } + } +} |