about summary refs log tree commit diff
path: root/users/glittershark/achilles/src/tc/mod.rs
diff options
context:
space:
mode:
Diffstat (limited to 'users/glittershark/achilles/src/tc/mod.rs')
-rw-r--r--users/glittershark/achilles/src/tc/mod.rs10
1 files changed, 10 insertions, 0 deletions
diff --git a/users/glittershark/achilles/src/tc/mod.rs b/users/glittershark/achilles/src/tc/mod.rs
index 137561978f00..d27c45075e97 100644
--- a/users/glittershark/achilles/src/tc/mod.rs
+++ b/users/glittershark/achilles/src/tc/mod.rs
@@ -85,6 +85,7 @@ pub enum Type {
     Exist(TyVar),
     Nullary(NullaryType),
     Prim(PrimType),
+    Unit,
     Fun {
         args: Vec<Type>,
         ret: Box<Type>,
@@ -96,6 +97,7 @@ impl<'a> TryFrom<Type> for ast::Type<'a> {
 
     fn try_from(value: Type) -> result::Result<Self, Self::Error> {
         match value {
+            Type::Unit => Ok(ast::Type::Unit),
             Type::Univ(_) => todo!(),
             Type::Exist(_) => Err(value),
             Type::Nullary(_) => todo!(),
@@ -126,6 +128,7 @@ impl Display for Type {
             Type::Univ(TyVar(n)) => write!(f, "∀{}", n),
             Type::Exist(TyVar(n)) => write!(f, "∃{}", n),
             Type::Fun { args, ret } => write!(f, "fn {} -> {}", args.iter().join(", "), ret),
+            Type::Unit => write!(f, "()"),
         }
     }
 }
@@ -171,6 +174,7 @@ impl<'ast> Typechecker<'ast> {
                     Literal::Int(_) => Type::Prim(PrimType::Int),
                     Literal::Bool(_) => Type::Prim(PrimType::Bool),
                     Literal::String(_) => Type::Prim(PrimType::CString),
+                    Literal::Unit => Type::Unit,
                 };
                 Ok(hir::Expr::Literal(lit.to_owned(), type_))
             }
@@ -377,6 +381,7 @@ impl<'ast> Typechecker<'ast> {
 
     fn unify(&mut self, ty1: &Type, ty2: &Type) -> Result<Type> {
         match (ty1, ty2) {
+            (Type::Unit, Type::Unit) => Ok(Type::Unit),
             (Type::Exist(tv), ty) | (ty, Type::Exist(tv)) => match self.resolve_tv(*tv) {
                 Some(existing_ty) if self.types_match(ty, &existing_ty) => Ok(ty.clone()),
                 Some(var @ ast::Type::Var(_)) => {
@@ -466,6 +471,7 @@ impl<'ast> Typechecker<'ast> {
         let ret = match ty {
             Type::Exist(tv) => self.resolve_tv(tv).ok_or(Error::AmbiguousType(tv)),
             Type::Univ(tv) => Ok(ast::Type::Var(self.name_univ(tv))),
+            Type::Unit => Ok(ast::Type::Unit),
             Type::Nullary(_) => todo!(),
             Type::Prim(pr) => Ok(pr.into()),
             Type::Fun { args, ret } => Ok(ast::Type::Function(ast::FunctionType {
@@ -496,6 +502,7 @@ impl<'ast> Typechecker<'ast> {
                 }
                 Type::Nullary(_) => todo!(),
                 Type::Prim(pr) => break Some((*pr).into()),
+                Type::Unit => break Some(ast::Type::Unit),
                 Type::Fun { args, ret } => todo!(),
             }
         }
@@ -503,6 +510,7 @@ impl<'ast> Typechecker<'ast> {
 
     fn type_from_ast_type(&mut self, ast_type: ast::Type<'ast>) -> Type {
         match ast_type {
+            ast::Type::Unit => Type::Unit,
             ast::Type::Int => INT,
             ast::Type::Float => FLOAT,
             ast::Type::Bool => BOOL,
@@ -570,6 +578,8 @@ impl<'ast> Typechecker<'ast> {
             }
             (Type::Univ(_), _) => false,
             (Type::Exist(_), _) => false,
+            (Type::Unit, ast::Type::Unit) => true,
+            (Type::Unit, _) => false,
             (Type::Nullary(_), _) => todo!(),
             (Type::Prim(pr), ty) => ast::Type::from(*pr) == *ty,
             (Type::Fun { args, ret }, ast::Type::Function(ft)) => {