about summary refs log tree commit diff
path: root/src/ast
diff options
context:
space:
mode:
Diffstat (limited to 'src/ast')
-rw-r--r--src/ast/hir.rs6
-rw-r--r--src/ast/mod.rs162
2 files changed, 154 insertions, 14 deletions
diff --git a/src/ast/hir.rs b/src/ast/hir.rs
index 9db6919f6f53..6859174a2dd0 100644
--- a/src/ast/hir.rs
+++ b/src/ast/hir.rs
@@ -222,6 +222,12 @@ pub enum Decl<'a, T> {
 }
 
 impl<'a, T> Decl<'a, T> {
+    pub fn type_(&self) -> &T {
+        match self {
+            Decl::Fun { type_, .. } => type_,
+        }
+    }
+
     pub fn traverse_type<F, U, E>(self, f: F) -> Result<Decl<'a, U>, E>
     where
         F: Fn(T) -> Result<U, E> + Clone,
diff --git a/src/ast/mod.rs b/src/ast/mod.rs
index 5526c5348350..1884ba69f43c 100644
--- a/src/ast/mod.rs
+++ b/src/ast/mod.rs
@@ -1,6 +1,7 @@
 pub(crate) mod hir;
 
 use std::borrow::Cow;
+use std::collections::HashMap;
 use std::convert::TryFrom;
 use std::fmt::{self, Display, Formatter};
 
@@ -126,7 +127,7 @@ impl<'a> Literal<'a> {
 #[derive(Debug, PartialEq, Eq, Clone)]
 pub struct Binding<'a> {
     pub ident: Ident<'a>,
-    pub type_: Option<Type>,
+    pub type_: Option<Type<'a>>,
     pub body: Expr<'a>,
 }
 
@@ -134,7 +135,7 @@ impl<'a> Binding<'a> {
     fn to_owned(&self) -> Binding<'static> {
         Binding {
             ident: self.ident.to_owned(),
-            type_: self.type_.clone(),
+            type_: self.type_.as_ref().map(|t| t.to_owned()),
             body: self.body.to_owned(),
         }
     }
@@ -177,7 +178,7 @@ pub enum Expr<'a> {
 
     Ascription {
         expr: Box<Expr<'a>>,
-        type_: Type,
+        type_: Type<'a>,
     },
 }
 
@@ -215,20 +216,46 @@ impl<'a> Expr<'a> {
             },
             Expr::Ascription { expr, type_ } => Expr::Ascription {
                 expr: Box::new((**expr).to_owned()),
-                type_: type_.clone(),
+                type_: type_.to_owned(),
             },
         }
     }
 }
 
 #[derive(Debug, PartialEq, Eq, Clone)]
+pub struct Arg<'a> {
+    pub ident: Ident<'a>,
+    pub type_: Option<Type<'a>>,
+}
+
+impl<'a> Arg<'a> {
+    pub fn to_owned(&self) -> Arg<'static> {
+        Arg {
+            ident: self.ident.to_owned(),
+            type_: self.type_.as_ref().map(Type::to_owned),
+        }
+    }
+}
+
+impl<'a> TryFrom<&'a str> for Arg<'a> {
+    type Error = <Ident<'a> as TryFrom<&'a str>>::Error;
+
+    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
+        Ok(Arg {
+            ident: Ident::try_from(value)?,
+            type_: None,
+        })
+    }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone)]
 pub struct Fun<'a> {
-    pub args: Vec<Ident<'a>>,
+    pub args: Vec<Arg<'a>>,
     pub body: Expr<'a>,
 }
 
 impl<'a> Fun<'a> {
-    fn to_owned(&self) -> Fun<'static> {
+    pub fn to_owned(&self) -> Fun<'static> {
         Fun {
             args: self.args.iter().map(|arg| arg.to_owned()).collect(),
             body: self.body.to_owned(),
@@ -236,40 +263,147 @@ impl<'a> Fun<'a> {
     }
 }
 
-#[derive(Debug, PartialEq, Eq)]
+#[derive(Debug, PartialEq, Eq, Clone)]
 pub enum Decl<'a> {
     Fun { name: Ident<'a>, body: Fun<'a> },
 }
 
+////
+
 #[derive(Debug, PartialEq, Eq, Clone)]
-pub struct FunctionType {
-    pub args: Vec<Type>,
-    pub ret: Box<Type>,
+pub struct FunctionType<'a> {
+    pub args: Vec<Type<'a>>,
+    pub ret: Box<Type<'a>>,
 }
 
-impl Display for FunctionType {
+impl<'a> FunctionType<'a> {
+    pub fn to_owned(&self) -> FunctionType<'static> {
+        FunctionType {
+            args: self.args.iter().map(|a| a.to_owned()).collect(),
+            ret: Box::new((*self.ret).to_owned()),
+        }
+    }
+}
+
+impl<'a> Display for FunctionType<'a> {
     fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
         write!(f, "fn {} -> {}", self.args.iter().join(", "), self.ret)
     }
 }
 
 #[derive(Debug, PartialEq, Eq, Clone)]
-pub enum Type {
+pub enum Type<'a> {
     Int,
     Float,
     Bool,
     CString,
-    Function(FunctionType),
+    Var(Ident<'a>),
+    Function(FunctionType<'a>),
 }
 
-impl Display for Type {
+impl<'a> Type<'a> {
+    pub fn to_owned(&self) -> Type<'static> {
+        match self {
+            Type::Int => Type::Int,
+            Type::Float => Type::Float,
+            Type::Bool => Type::Bool,
+            Type::CString => Type::CString,
+            Type::Var(v) => Type::Var(v.to_owned()),
+            Type::Function(f) => Type::Function(f.to_owned()),
+        }
+    }
+
+    pub fn alpha_equiv(&self, other: &Self) -> bool {
+        fn do_alpha_equiv<'a>(
+            substs: &mut HashMap<&'a Ident<'a>, &'a Ident<'a>>,
+            lhs: &'a Type,
+            rhs: &'a Type,
+        ) -> bool {
+            match (lhs, rhs) {
+                (Type::Var(v1), Type::Var(v2)) => substs.entry(v1).or_insert(v2) == &v2,
+                (
+                    Type::Function(FunctionType {
+                        args: args1,
+                        ret: ret1,
+                    }),
+                    Type::Function(FunctionType {
+                        args: args2,
+                        ret: ret2,
+                    }),
+                ) => {
+                    args1.len() == args2.len()
+                        && args1
+                            .iter()
+                            .zip(args2)
+                            .all(|(a1, a2)| do_alpha_equiv(substs, a1, a2))
+                        && do_alpha_equiv(substs, ret1, ret2)
+                }
+                _ => lhs == rhs,
+            }
+        }
+
+        let mut substs = HashMap::new();
+        do_alpha_equiv(&mut substs, self, other)
+    }
+}
+
+impl<'a> Display for Type<'a> {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         match self {
             Type::Int => f.write_str("int"),
             Type::Float => f.write_str("float"),
             Type::Bool => f.write_str("bool"),
             Type::CString => f.write_str("cstring"),
+            Type::Var(v) => v.fmt(f),
             Type::Function(ft) => ft.fmt(f),
         }
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    fn type_var(n: &str) -> Type<'static> {
+        Type::Var(Ident::try_from(n.to_owned()).unwrap())
+    }
+
+    mod alpha_equiv {
+        use super::*;
+
+        #[test]
+        fn trivial() {
+            assert!(Type::Int.alpha_equiv(&Type::Int));
+            assert!(!Type::Int.alpha_equiv(&Type::Bool));
+        }
+
+        #[test]
+        fn simple_type_var() {
+            assert!(type_var("a").alpha_equiv(&type_var("b")));
+        }
+
+        #[test]
+        fn function_with_type_vars_equiv() {
+            assert!(Type::Function(FunctionType {
+                args: vec![type_var("a")],
+                ret: Box::new(type_var("b")),
+            })
+            .alpha_equiv(&Type::Function(FunctionType {
+                args: vec![type_var("b")],
+                ret: Box::new(type_var("a")),
+            })))
+        }
+
+        #[test]
+        fn function_with_type_vars_non_equiv() {
+            assert!(!Type::Function(FunctionType {
+                args: vec![type_var("a")],
+                ret: Box::new(type_var("a")),
+            })
+            .alpha_equiv(&Type::Function(FunctionType {
+                args: vec![type_var("b")],
+                ret: Box::new(type_var("a")),
+            })))
+        }
+    }
+}