about summary refs log tree commit diff
path: root/users/tazjin/rlox/src/bytecode/compiler.rs
diff options
context:
space:
mode:
Diffstat (limited to 'users/tazjin/rlox/src/bytecode/compiler.rs')
-rw-r--r--users/tazjin/rlox/src/bytecode/compiler.rs65
1 files changed, 51 insertions, 14 deletions
diff --git a/users/tazjin/rlox/src/bytecode/compiler.rs b/users/tazjin/rlox/src/bytecode/compiler.rs
index 392fc9b72d..1b87e94a55 100644
--- a/users/tazjin/rlox/src/bytecode/compiler.rs
+++ b/users/tazjin/rlox/src/bytecode/compiler.rs
@@ -1,7 +1,7 @@
 use super::chunk::Chunk;
 use super::errors::{Error, ErrorKind, LoxResult};
 use super::interner::{InternedStr, Interner};
-use super::opcode::{CodeIdx, ConstantIdx, OpCode, StackIdx};
+use super::opcode::{CodeIdx, CodeOffset, ConstantIdx, OpCode, StackIdx};
 use super::value::Value;
 use crate::scanner::{self, Token, TokenKind};
 
@@ -236,9 +236,13 @@ impl<T: Iterator<Item = Token>> Compiler<T> {
 
     fn define_variable(&mut self, var: Option<ConstantIdx>) -> LoxResult<()> {
         if self.locals.scope_depth == 0 {
-            self.emit_op(OpCode::OpDefineGlobal(var.expect("should be global")));
+            self.emit_op(OpCode::OpDefineGlobal(
+                var.expect("should be global"),
+            ));
         } else {
-            self.locals.locals.last_mut()
+            self.locals
+                .locals
+                .last_mut()
                 .expect("fatal: variable not yet added at definition")
                 .depth = Depth::At(self.locals.scope_depth);
         }
@@ -263,6 +267,8 @@ impl<T: Iterator<Item = Token>> Compiler<T> {
     fn statement(&mut self) -> LoxResult<()> {
         if self.match_token(&TokenKind::Print) {
             self.print_statement()
+        } else if self.match_token(&TokenKind::If) {
+            self.if_statement()
         } else if self.match_token(&TokenKind::LeftBrace) {
             self.begin_scope();
             self.block()?;
@@ -289,7 +295,9 @@ impl<T: Iterator<Item = Token>> Compiler<T> {
         self.locals.scope_depth -= 1;
 
         while self.locals.locals.len() > 0
-            && self.locals.locals[self.locals.locals.len() - 1].depth.above(self.locals.scope_depth)
+            && self.locals.locals[self.locals.locals.len() - 1]
+                .depth
+                .above(self.locals.scope_depth)
         {
             self.emit_op(OpCode::OpPop);
             self.locals.locals.remove(self.locals.locals.len() - 1);
@@ -319,6 +327,28 @@ impl<T: Iterator<Item = Token>> Compiler<T> {
         Ok(())
     }
 
+    fn if_statement(&mut self) -> LoxResult<()> {
+        consume!(
+            self,
+            TokenKind::LeftParen,
+            ErrorKind::ExpectedToken("Expected '(' after 'if'")
+        );
+
+        self.expression()?;
+
+        consume!(
+            self,
+            TokenKind::RightParen,
+            ErrorKind::ExpectedToken("Expected ')' after condition")
+        );
+
+        let then_jump = self.emit_op(OpCode::OpJumpPlaceholder(false));
+        self.statement()?;
+        self.patch_jump(then_jump);
+
+        Ok(())
+    }
+
     fn number(&mut self) -> LoxResult<()> {
         if let TokenKind::Number(num) = self.previous().kind {
             self.emit_constant(Value::Number(num), true);
@@ -431,16 +461,12 @@ impl<T: Iterator<Item = Token>> Compiler<T> {
             self.expression()?;
             match local_idx {
                 Some(idx) => self.emit_op(OpCode::OpSetLocal(idx)),
-                None => {
-                    self.emit_op(OpCode::OpSetGlobal(ident.unwrap()))
-                }
+                None => self.emit_op(OpCode::OpSetGlobal(ident.unwrap())),
             };
         } else {
             match local_idx {
                 Some(idx) => self.emit_op(OpCode::OpGetLocal(idx)),
-                None => {
-                    self.emit_op(OpCode::OpGetGlobal(ident.unwrap()))
-                }
+                None => self.emit_op(OpCode::OpGetGlobal(ident.unwrap())),
             };
         }
 
@@ -477,10 +503,7 @@ impl<T: Iterator<Item = Token>> Compiler<T> {
         Ok(())
     }
 
-    fn identifier_str(
-        &mut self,
-        token: &Token,
-    ) -> LoxResult<InternedStr> {
+    fn identifier_str(&mut self, token: &Token) -> LoxResult<InternedStr> {
         let ident = match &token.kind {
             TokenKind::Identifier(ident) => ident.to_string(),
             _ => {
@@ -594,6 +617,20 @@ impl<T: Iterator<Item = Token>> Compiler<T> {
         idx
     }
 
+    fn patch_jump(&mut self, idx: CodeIdx) {
+        let offset = CodeOffset(self.chunk.code.len() - idx.0 - 1);
+
+        if let OpCode::OpJumpPlaceholder(false) = self.chunk.code[idx.0] {
+            self.chunk.code[idx.0] = OpCode::OpJumpIfFalse(offset);
+            return;
+        }
+
+        panic!(
+            "attempted to patch unsupported op: {:?}",
+            self.chunk.code[idx.0]
+        );
+    }
+
     fn previous(&self) -> &Token {
         self.previous
             .as_ref()