about summary refs log tree commit diff
path: root/src/ast/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/ast/mod.rs')
-rw-r--r--src/ast/mod.rs94
1 files changed, 82 insertions, 12 deletions
diff --git a/src/ast/mod.rs b/src/ast/mod.rs
index 2dcf955fe67c..7f3543271db8 100644
--- a/src/ast/mod.rs
+++ b/src/ast/mod.rs
@@ -2,10 +2,12 @@ use std::borrow::Cow;
 use std::convert::TryFrom;
 use std::fmt::{self, Display, Formatter};
 
+use itertools::Itertools;
+
 #[derive(Debug, PartialEq, Eq)]
 pub struct InvalidIdentifier<'a>(Cow<'a, str>);
 
-#[derive(Debug, PartialEq, Eq, Hash)]
+#[derive(Debug, PartialEq, Eq, Hash, Clone)]
 pub struct Ident<'a>(pub Cow<'a, str>);
 
 impl<'a> From<&'a Ident<'a>> for &'a str {
@@ -69,7 +71,7 @@ impl<'a> TryFrom<String> for Ident<'a> {
     }
 }
 
-#[derive(Debug, PartialEq, Eq)]
+#[derive(Debug, PartialEq, Eq, Copy, Clone)]
 pub enum BinaryOperator {
     /// `+`
     Add,
@@ -93,7 +95,7 @@ pub enum BinaryOperator {
     Neq,
 }
 
-#[derive(Debug, PartialEq, Eq)]
+#[derive(Debug, PartialEq, Eq, Copy, Clone)]
 pub enum UnaryOperator {
     /// !
     Not,
@@ -102,12 +104,12 @@ pub enum UnaryOperator {
     Neg,
 }
 
-#[derive(Debug, PartialEq, Eq)]
+#[derive(Debug, PartialEq, Eq, Clone)]
 pub enum Literal {
     Int(u64),
 }
 
-#[derive(Debug, PartialEq, Eq)]
+#[derive(Debug, PartialEq, Eq, Clone)]
 pub enum Expr<'a> {
     Ident(Ident<'a>),
 
@@ -134,33 +136,101 @@ pub enum Expr<'a> {
         then: Box<Expr<'a>>,
         else_: Box<Expr<'a>>,
     },
+
+    Fun(Box<Fun<'a>>),
+
+    Call {
+        fun: Box<Expr<'a>>,
+        args: Vec<Expr<'a>>,
+    },
 }
 
-#[derive(Debug, PartialEq, Eq)]
+impl<'a> Expr<'a> {
+    pub fn to_owned(&self) -> Expr<'static> {
+        match self {
+            Expr::Ident(ref id) => Expr::Ident(id.to_owned()),
+            Expr::Literal(ref lit) => Expr::Literal(lit.clone()),
+            Expr::UnaryOp { op, rhs } => Expr::UnaryOp {
+                op: *op,
+                rhs: Box::new((**rhs).to_owned()),
+            },
+            Expr::BinaryOp { lhs, op, rhs } => Expr::BinaryOp {
+                lhs: Box::new((**lhs).to_owned()),
+                op: *op,
+                rhs: Box::new((**rhs).to_owned()),
+            },
+            Expr::Let { bindings, body } => Expr::Let {
+                bindings: bindings
+                    .iter()
+                    .map(|(id, expr)| (id.to_owned(), expr.to_owned()))
+                    .collect(),
+                body: Box::new((**body).to_owned()),
+            },
+            Expr::If {
+                condition,
+                then,
+                else_,
+            } => Expr::If {
+                condition: Box::new((**condition).to_owned()),
+                then: Box::new((**then).to_owned()),
+                else_: Box::new((**else_).to_owned()),
+            },
+            Expr::Fun(fun) => Expr::Fun(Box::new((**fun).to_owned())),
+            Expr::Call { fun, args } => Expr::Call {
+                fun: Box::new((**fun).to_owned()),
+                args: args.iter().map(|arg| arg.to_owned()).collect(),
+            },
+        }
+    }
+}
+
+#[derive(Debug, PartialEq, Eq, Clone)]
 pub struct Fun<'a> {
-    pub name: Ident<'a>,
     pub args: Vec<Ident<'a>>,
     pub body: Expr<'a>,
 }
 
+impl<'a> Fun<'a> {
+    fn to_owned(&self) -> Fun<'static> {
+        Fun {
+            args: self.args.iter().map(|arg| arg.to_owned()).collect(),
+            body: self.body.to_owned(),
+        }
+    }
+}
+
 #[derive(Debug, PartialEq, Eq)]
 pub enum Decl<'a> {
-    Fun(Fun<'a>),
+    Fun { name: Ident<'a>, body: Fun<'a> },
+}
+
+#[derive(Debug, PartialEq, Eq, Clone)]
+pub struct FunctionType {
+    pub args: Vec<Type>,
+    pub ret: Box<Type>,
+}
+
+impl Display for FunctionType {
+    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
+        write!(f, "fn {} -> {}", self.args.iter().join(", "), self.ret)
+    }
 }
 
-#[derive(Debug, PartialEq, Eq, Clone, Copy)]
+#[derive(Debug, PartialEq, Eq, Clone)]
 pub enum Type {
     Int,
     Float,
     Bool,
+    Function(FunctionType),
 }
 
 impl Display for Type {
     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
         match self {
-            Self::Int => f.write_str("int"),
-            Self::Float => f.write_str("float"),
-            Self::Bool => f.write_str("bool"),
+            Type::Int => f.write_str("int"),
+            Type::Float => f.write_str("float"),
+            Type::Bool => f.write_str("bool"),
+            Type::Function(ft) => ft.fmt(f),
         }
     }
 }