about summary refs log tree commit diff
path: root/users/glittershark/achilles/src/ast
diff options
context:
space:
mode:
Diffstat (limited to 'users/glittershark/achilles/src/ast')
-rw-r--r--users/glittershark/achilles/src/ast/hir.rs59
-rw-r--r--users/glittershark/achilles/src/ast/mod.rs20
2 files changed, 72 insertions, 7 deletions
diff --git a/users/glittershark/achilles/src/ast/hir.rs b/users/glittershark/achilles/src/ast/hir.rs
index 691b9607e7..8726af5093 100644
--- a/users/glittershark/achilles/src/ast/hir.rs
+++ b/users/glittershark/achilles/src/ast/hir.rs
@@ -1,3 +1,5 @@
+use std::collections::HashMap;
+
 use itertools::Itertools;
 
 use super::{BinaryOperator, Ident, Literal, UnaryOperator};
@@ -55,6 +57,7 @@ pub enum Expr<'a, T> {
     },
 
     Fun {
+        type_args: Vec<Ident<'a>>,
         args: Vec<(Ident<'a>, T)>,
         body: Box<Expr<'a, T>>,
         type_: T,
@@ -62,6 +65,7 @@ pub enum Expr<'a, T> {
 
     Call {
         fun: Box<Expr<'a, T>>,
+        type_args: HashMap<Ident<'a>, T>,
         args: Vec<Expr<'a, T>>,
         type_: T,
     },
@@ -133,16 +137,31 @@ impl<'a, T> Expr<'a, T> {
                 else_: Box::new(else_.traverse_type(f.clone())?),
                 type_: f(type_)?,
             }),
-            Expr::Fun { args, body, type_ } => Ok(Expr::Fun {
+            Expr::Fun {
+                args,
+                type_args,
+                body,
+                type_,
+            } => Ok(Expr::Fun {
                 args: args
                     .into_iter()
                     .map(|(id, t)| Ok((id, f.clone()(t)?)))
                     .collect::<Result<Vec<_>, E>>()?,
+                type_args,
                 body: Box::new(body.traverse_type(f.clone())?),
                 type_: f(type_)?,
             }),
-            Expr::Call { fun, args, type_ } => Ok(Expr::Call {
+            Expr::Call {
+                fun,
+                type_args,
+                args,
+                type_,
+            } => Ok(Expr::Call {
                 fun: Box::new(fun.traverse_type(f.clone())?),
+                type_args: type_args
+                    .into_iter()
+                    .map(|(id, ty)| Ok((id, f.clone()(ty)?)))
+                    .collect::<Result<HashMap<_, _>, E>>()?,
                 args: args
                     .into_iter()
                     .map(|e| e.traverse_type(f.clone()))
@@ -180,7 +199,7 @@ impl<'a, T> Expr<'a, T> {
                 body,
                 type_,
             } => Expr::Let {
-                bindings: bindings.into_iter().map(|b| b.to_owned()).collect(),
+                bindings: bindings.iter().map(|b| b.to_owned()).collect(),
                 body: Box::new((**body).to_owned()),
                 type_: type_.clone(),
             },
@@ -195,26 +214,43 @@ impl<'a, T> Expr<'a, T> {
                 else_: Box::new((**else_).to_owned()),
                 type_: type_.clone(),
             },
-            Expr::Fun { args, body, type_ } => Expr::Fun {
+            Expr::Fun {
+                args,
+                type_args,
+                body,
+                type_,
+            } => Expr::Fun {
                 args: args
-                    .into_iter()
+                    .iter()
                     .map(|(id, t)| (id.to_owned(), t.clone()))
                     .collect(),
+                type_args: type_args.iter().map(|arg| arg.to_owned()).collect(),
                 body: Box::new((**body).to_owned()),
                 type_: type_.clone(),
             },
-            Expr::Call { fun, args, type_ } => Expr::Call {
+            Expr::Call {
+                fun,
+                type_args,
+                args,
+                type_,
+            } => Expr::Call {
                 fun: Box::new((**fun).to_owned()),
-                args: args.into_iter().map(|e| e.to_owned()).collect(),
+                type_args: type_args
+                    .iter()
+                    .map(|(id, t)| (id.to_owned(), t.clone()))
+                    .collect(),
+                args: args.iter().map(|e| e.to_owned()).collect(),
                 type_: type_.clone(),
             },
         }
     }
 }
 
+#[derive(Debug, Clone)]
 pub enum Decl<'a, T> {
     Fun {
         name: Ident<'a>,
+        type_args: Vec<Ident<'a>>,
         args: Vec<(Ident<'a>, T)>,
         body: Box<Expr<'a, T>>,
         type_: T,
@@ -235,6 +271,13 @@ impl<'a, T> Decl<'a, T> {
         }
     }
 
+    pub fn set_name(&mut self, new_name: Ident<'a>) {
+        match self {
+            Decl::Fun { name, .. } => *name = new_name,
+            Decl::Extern { name, .. } => *name = new_name,
+        }
+    }
+
     pub fn type_(&self) -> Option<&T> {
         match self {
             Decl::Fun { type_, .. } => Some(type_),
@@ -249,11 +292,13 @@ impl<'a, T> Decl<'a, T> {
         match self {
             Decl::Fun {
                 name,
+                type_args,
                 args,
                 body,
                 type_,
             } => Ok(Decl::Fun {
                 name,
+                type_args,
                 args: args
                     .into_iter()
                     .map(|(id, t)| Ok((id, f(t)?)))
diff --git a/users/glittershark/achilles/src/ast/mod.rs b/users/glittershark/achilles/src/ast/mod.rs
index 22d16c9364..53f222a6a1 100644
--- a/users/glittershark/achilles/src/ast/mod.rs
+++ b/users/glittershark/achilles/src/ast/mod.rs
@@ -356,6 +356,26 @@ impl<'a> Type<'a> {
         let mut substs = HashMap::new();
         do_alpha_equiv(&mut substs, self, other)
     }
+
+    pub fn traverse_type_vars<'b, F>(self, mut f: F) -> Type<'b>
+    where
+        F: FnMut(Ident<'a>) -> Type<'b> + Clone,
+    {
+        match self {
+            Type::Var(tv) => f(tv),
+            Type::Function(FunctionType { args, ret }) => Type::Function(FunctionType {
+                args: args
+                    .into_iter()
+                    .map(|t| t.traverse_type_vars(f.clone()))
+                    .collect(),
+                ret: Box::new(ret.traverse_type_vars(f)),
+            }),
+            Type::Int => Type::Int,
+            Type::Float => Type::Float,
+            Type::Bool => Type::Bool,
+            Type::CString => Type::CString,
+        }
+    }
 }
 
 impl<'a> Display for Type<'a> {