diff options
Diffstat (limited to 'src/ast')
-rw-r--r-- | src/ast/hir.rs | 6 | ||||
-rw-r--r-- | src/ast/mod.rs | 162 |
2 files changed, 154 insertions, 14 deletions
diff --git a/src/ast/hir.rs b/src/ast/hir.rs index 9db6919f6f53..6859174a2dd0 100644 --- a/src/ast/hir.rs +++ b/src/ast/hir.rs @@ -222,6 +222,12 @@ pub enum Decl<'a, T> { } impl<'a, T> Decl<'a, T> { + pub fn type_(&self) -> &T { + match self { + Decl::Fun { type_, .. } => type_, + } + } + pub fn traverse_type<F, U, E>(self, f: F) -> Result<Decl<'a, U>, E> where F: Fn(T) -> Result<U, E> + Clone, diff --git a/src/ast/mod.rs b/src/ast/mod.rs index 5526c5348350..1884ba69f43c 100644 --- a/src/ast/mod.rs +++ b/src/ast/mod.rs @@ -1,6 +1,7 @@ pub(crate) mod hir; use std::borrow::Cow; +use std::collections::HashMap; use std::convert::TryFrom; use std::fmt::{self, Display, Formatter}; @@ -126,7 +127,7 @@ impl<'a> Literal<'a> { #[derive(Debug, PartialEq, Eq, Clone)] pub struct Binding<'a> { pub ident: Ident<'a>, - pub type_: Option<Type>, + pub type_: Option<Type<'a>>, pub body: Expr<'a>, } @@ -134,7 +135,7 @@ impl<'a> Binding<'a> { fn to_owned(&self) -> Binding<'static> { Binding { ident: self.ident.to_owned(), - type_: self.type_.clone(), + type_: self.type_.as_ref().map(|t| t.to_owned()), body: self.body.to_owned(), } } @@ -177,7 +178,7 @@ pub enum Expr<'a> { Ascription { expr: Box<Expr<'a>>, - type_: Type, + type_: Type<'a>, }, } @@ -215,20 +216,46 @@ impl<'a> Expr<'a> { }, Expr::Ascription { expr, type_ } => Expr::Ascription { expr: Box::new((**expr).to_owned()), - type_: type_.clone(), + type_: type_.to_owned(), }, } } } #[derive(Debug, PartialEq, Eq, Clone)] +pub struct Arg<'a> { + pub ident: Ident<'a>, + pub type_: Option<Type<'a>>, +} + +impl<'a> Arg<'a> { + pub fn to_owned(&self) -> Arg<'static> { + Arg { + ident: self.ident.to_owned(), + type_: self.type_.as_ref().map(Type::to_owned), + } + } +} + +impl<'a> TryFrom<&'a str> for Arg<'a> { + type Error = <Ident<'a> as TryFrom<&'a str>>::Error; + + fn try_from(value: &'a str) -> Result<Self, Self::Error> { + Ok(Arg { + ident: Ident::try_from(value)?, + type_: None, + }) + } +} + +#[derive(Debug, PartialEq, Eq, Clone)] pub struct Fun<'a> { - pub args: Vec<Ident<'a>>, + pub args: Vec<Arg<'a>>, pub body: Expr<'a>, } impl<'a> Fun<'a> { - fn to_owned(&self) -> Fun<'static> { + pub fn to_owned(&self) -> Fun<'static> { Fun { args: self.args.iter().map(|arg| arg.to_owned()).collect(), body: self.body.to_owned(), @@ -236,40 +263,147 @@ impl<'a> Fun<'a> { } } -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Clone)] pub enum Decl<'a> { Fun { name: Ident<'a>, body: Fun<'a> }, } +//// + #[derive(Debug, PartialEq, Eq, Clone)] -pub struct FunctionType { - pub args: Vec<Type>, - pub ret: Box<Type>, +pub struct FunctionType<'a> { + pub args: Vec<Type<'a>>, + pub ret: Box<Type<'a>>, } -impl Display for FunctionType { +impl<'a> FunctionType<'a> { + pub fn to_owned(&self) -> FunctionType<'static> { + FunctionType { + args: self.args.iter().map(|a| a.to_owned()).collect(), + ret: Box::new((*self.ret).to_owned()), + } + } +} + +impl<'a> Display for FunctionType<'a> { fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { write!(f, "fn {} -> {}", self.args.iter().join(", "), self.ret) } } #[derive(Debug, PartialEq, Eq, Clone)] -pub enum Type { +pub enum Type<'a> { Int, Float, Bool, CString, - Function(FunctionType), + Var(Ident<'a>), + Function(FunctionType<'a>), } -impl Display for Type { +impl<'a> Type<'a> { + pub fn to_owned(&self) -> Type<'static> { + match self { + Type::Int => Type::Int, + Type::Float => Type::Float, + Type::Bool => Type::Bool, + Type::CString => Type::CString, + Type::Var(v) => Type::Var(v.to_owned()), + Type::Function(f) => Type::Function(f.to_owned()), + } + } + + pub fn alpha_equiv(&self, other: &Self) -> bool { + fn do_alpha_equiv<'a>( + substs: &mut HashMap<&'a Ident<'a>, &'a Ident<'a>>, + lhs: &'a Type, + rhs: &'a Type, + ) -> bool { + match (lhs, rhs) { + (Type::Var(v1), Type::Var(v2)) => substs.entry(v1).or_insert(v2) == &v2, + ( + Type::Function(FunctionType { + args: args1, + ret: ret1, + }), + Type::Function(FunctionType { + args: args2, + ret: ret2, + }), + ) => { + args1.len() == args2.len() + && args1 + .iter() + .zip(args2) + .all(|(a1, a2)| do_alpha_equiv(substs, a1, a2)) + && do_alpha_equiv(substs, ret1, ret2) + } + _ => lhs == rhs, + } + } + + let mut substs = HashMap::new(); + do_alpha_equiv(&mut substs, self, other) + } +} + +impl<'a> Display for Type<'a> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Type::Int => f.write_str("int"), Type::Float => f.write_str("float"), Type::Bool => f.write_str("bool"), Type::CString => f.write_str("cstring"), + Type::Var(v) => v.fmt(f), Type::Function(ft) => ft.fmt(f), } } } + +#[cfg(test)] +mod tests { + use super::*; + + fn type_var(n: &str) -> Type<'static> { + Type::Var(Ident::try_from(n.to_owned()).unwrap()) + } + + mod alpha_equiv { + use super::*; + + #[test] + fn trivial() { + assert!(Type::Int.alpha_equiv(&Type::Int)); + assert!(!Type::Int.alpha_equiv(&Type::Bool)); + } + + #[test] + fn simple_type_var() { + assert!(type_var("a").alpha_equiv(&type_var("b"))); + } + + #[test] + fn function_with_type_vars_equiv() { + assert!(Type::Function(FunctionType { + args: vec![type_var("a")], + ret: Box::new(type_var("b")), + }) + .alpha_equiv(&Type::Function(FunctionType { + args: vec![type_var("b")], + ret: Box::new(type_var("a")), + }))) + } + + #[test] + fn function_with_type_vars_non_equiv() { + assert!(!Type::Function(FunctionType { + args: vec![type_var("a")], + ret: Box::new(type_var("a")), + }) + .alpha_equiv(&Type::Function(FunctionType { + args: vec![type_var("b")], + ret: Box::new(type_var("a")), + }))) + } + } +} |