diff --git a/src/main/java/com/compiler/StmtParser.java b/src/main/java/com/compiler/StmtParser.java index 5af42fd..334964b 100644 --- a/src/main/java/com/compiler/StmtParser.java +++ b/src/main/java/com/compiler/StmtParser.java @@ -42,7 +42,7 @@ public ASTStmtListNode parseStmtlist() throws Exception { Token curToken = m_lexer.lookAhead(); final List stmtList = new ArrayList<>(); - while (curToken.m_type != com.compiler.TokenIntf.Type.RBRACE) { + while (curToken.m_type != com.compiler.TokenIntf.Type.RBRACE && curToken.m_type != Type.CASE) { // RBrace and Case in Follow set of Statementlist stmtList.add(parseStmt()); curToken = m_lexer.lookAhead(); } @@ -83,6 +83,10 @@ public ASTStmtNode parseStmt() throws Exception { return parseNumericIfStmt(); } + if (type == Type.SWITCH) { + return parseSwitchStmt(); + } + if (type == Type.IF) { return parseIfElseStmt(); } @@ -326,4 +330,39 @@ ASTStmtNode parseExecuteNTimesStmt() throws Exception { return new ASTExecuteNTimesNode(count, stmtlistNode); } + + private ASTStmtNode parseSwitchStmt() throws Exception { + // switch_statement -> SWITCH LPAREN expression RPAREN LBRACE case_list RBRACE + m_lexer.expect(Type.SWITCH); + m_lexer.expect(Type.LPAREN); + ASTExprNode expression = m_exprParser.getQuestionMarkExpr(); + m_lexer.expect(Type.RPAREN); + m_lexer.expect(Type.LBRACE); + ASTCaseListNode caseList = parseCaseList(); + m_lexer.expect(Type.RBRACE); + return new ASTSwitchStmtNode(expression, caseList); + } + + private ASTCaseListNode parseCaseList() throws Exception { + // case_list -> case_item case_list | epsilon + Token curToken = m_lexer.lookAhead(); + final List caseList = new ArrayList<>(); + + while (curToken.m_type == Type.CASE) { + caseList.add(parseCaseStmt()); + curToken = m_lexer.lookAhead(); + } + return new ASTCaseListNode(caseList); + } + + private ASTCaseNode parseCaseStmt() throws Exception { + // case_item -> CASE LITERAL COLON statement_list + m_lexer.expect(Type.CASE); + Token curToken = m_lexer.lookAhead(); + m_lexer.expect(Type.INTEGER); + ASTIntegerLiteralNode value = new ASTIntegerLiteralNode(curToken.m_value); //Integer + m_lexer.expect(Type.DOUBLECOLON); // should be renamed to COLON as double colon is "::" + ASTStmtListNode stmtList = parseStmtlist(); + return new ASTCaseNode(value, stmtList); + } } \ No newline at end of file diff --git a/src/main/java/com/compiler/ast/ASTAssignStmtNode.java b/src/main/java/com/compiler/ast/ASTAssignStmtNode.java index d7440e5..e2a0828 100644 --- a/src/main/java/com/compiler/ast/ASTAssignStmtNode.java +++ b/src/main/java/com/compiler/ast/ASTAssignStmtNode.java @@ -31,6 +31,12 @@ public void codegen(com.compiler.CompileEnvIntf env) { } @Override - public void print(OutputStreamWriter outStream, String indent) throws Exception {} + public void print(OutputStreamWriter outStream, String indent) throws Exception { + outStream.write(indent); + outStream.write("ASTAssignStmtNode "); + outStream.write(identifier.m_name); + outStream.write("\n"); + expr.print(outStream, indent + " "); + } } diff --git a/src/main/java/com/compiler/ast/ASTCaseListNode.java b/src/main/java/com/compiler/ast/ASTCaseListNode.java new file mode 100644 index 0000000..64c1b24 --- /dev/null +++ b/src/main/java/com/compiler/ast/ASTCaseListNode.java @@ -0,0 +1,33 @@ +package com.compiler.ast; + +import com.compiler.CompileEnvIntf; + +import java.io.OutputStreamWriter; +import java.util.List; + +public class ASTCaseListNode extends ASTStmtNode { + + List m_caseList; + public ASTCaseListNode(List caseList) { + m_caseList = caseList; + } + + @Override + public void execute(OutputStreamWriter out) { + m_caseList.forEach(caseItem->caseItem.execute(out)); + } + + @Override + public void print(OutputStreamWriter outStream, String indent) throws Exception { + outStream.write(indent); + outStream.write("ASTCaseListNode\n"); + m_caseList.forEach(caseItem -> { + try { + caseItem.print(outStream, indent + " "); + outStream.write("\n"); + } catch (final Exception e) { + throw new RuntimeException(e); + } + }); + } +} diff --git a/src/main/java/com/compiler/ast/ASTCaseNode.java b/src/main/java/com/compiler/ast/ASTCaseNode.java new file mode 100644 index 0000000..428109f --- /dev/null +++ b/src/main/java/com/compiler/ast/ASTCaseNode.java @@ -0,0 +1,31 @@ +package com.compiler.ast; + +import com.compiler.CompileEnvIntf; + +import java.io.OutputStreamWriter; + +public class ASTCaseNode extends ASTStmtNode { + + ASTIntegerLiteralNode m_value; + ASTStmtListNode m_stmtList; + + int expressionValue; // Value gets written on Execution of Switch Block + + public ASTCaseNode(ASTIntegerLiteralNode value, ASTStmtListNode stmtList) { + m_stmtList = stmtList; + m_value = value; + + } + + @Override + public void execute(OutputStreamWriter out) { + if(m_value.eval()==expressionValue) m_stmtList.execute(out); + } + + @Override + public void print(OutputStreamWriter outStream, String indent) throws Exception { + outStream.write(indent); + outStream.write("ASTCaseNode\n"); + m_stmtList.print(outStream, indent + " "); + } +} diff --git a/src/main/java/com/compiler/ast/ASTSwitchStmtNode.java b/src/main/java/com/compiler/ast/ASTSwitchStmtNode.java new file mode 100644 index 0000000..a1d224d --- /dev/null +++ b/src/main/java/com/compiler/ast/ASTSwitchStmtNode.java @@ -0,0 +1,93 @@ +package com.compiler.ast; + +import com.compiler.*; +import com.compiler.instr.*; + +import java.io.OutputStreamWriter; +import java.util.ArrayList; +import java.util.List; + +public class ASTSwitchStmtNode extends ASTStmtNode { + + ASTExprNode m_expression; + ASTCaseListNode m_caseList; + + int evaluatedExpression; + + public ASTSwitchStmtNode(ASTExprNode expression, ASTCaseListNode caseList) { + m_expression = expression; + m_caseList = caseList; + } + + @Override + public void execute(OutputStreamWriter out) { + evaluatedExpression = m_expression.eval(); + m_caseList.m_caseList.forEach(caseItem-> caseItem.expressionValue = evaluatedExpression); + m_caseList.execute(out); + } + + @Override + public void codegen(CompileEnvIntf env) { + InstrBlock headBlock = env.createBlock(env.createUniqueSymbol("Switch_Head").m_name); + env.addInstr(new InstrJump(headBlock)); + env.setCurrentBlock(headBlock); + InstrBlock exitBlock = env.createBlock(env.createUniqueSymbol("Switch_Exit").m_name); + + if(m_caseList.m_caseList.isEmpty()){ + env.addInstr(new InstrJump(exitBlock)); + env.setCurrentBlock(exitBlock); + return; + } + + InstrIntf switchExpression = m_expression.codegen(env); + Symbol switchExpressionSymbol = env.createUniqueSymbol("switch"); + InstrIntf assignSwitchExpression = new InstrAssign(switchExpressionSymbol, switchExpression); + env.addInstr(assignSwitchExpression); + + //Body Block Loop + List bodyBlocks = new ArrayList<>(); + for (int i = 0; i < m_caseList.m_caseList.size(); i++) { + InstrBlock bodyBlock = env.createBlock(env.createUniqueSymbol("CaseExecute" + i).m_name); + bodyBlocks.add(bodyBlock); + env.setCurrentBlock(bodyBlock); + m_caseList.m_caseList.get(i).m_stmtList.codegen(env); + env.addInstr(new InstrJump(exitBlock)); + } + + //Check Block Loop + List checkBlocks = new ArrayList<>(); + List equalInstrs = new ArrayList<>(); + for (int i = 0; i < m_caseList.m_caseList.size(); i++) { + InstrBlock caseBlock = env.createBlock(env.createUniqueSymbol("CaseCheck" + i).m_name); + checkBlocks.add(caseBlock); + env.setCurrentBlock(caseBlock); + InstrIntf switchValue = new InstrVariableExpr(switchExpressionSymbol); + env.addInstr(switchValue); + InstrIntf caseLiteral = m_caseList.m_caseList.get(i).m_value.codegen(env); + InstrIntf equals = new InstrCompare(TokenIntf.Type.EQUAL, switchValue, caseLiteral); + equalInstrs.add(equals); + env.addInstr(equals); + } + + for (int i = 0; i < checkBlocks.size() - 1; i++) { + env.setCurrentBlock(checkBlocks.get(i)); + env.addInstr(new InstrCondJump(equalInstrs.get(i), bodyBlocks.get(i), checkBlocks.get(i+1))); + } + + env.setCurrentBlock(checkBlocks.get(checkBlocks.size()-1)); + env.addInstr(new InstrCondJump(equalInstrs.get(checkBlocks.size()-1), bodyBlocks.get(checkBlocks.size()-1), exitBlock)); + + env.setCurrentBlock(headBlock); + env.addInstr(new InstrJump(checkBlocks.get(0))); + + env.setCurrentBlock(exitBlock); + } + + @Override + public void print(OutputStreamWriter outStream, String indent) throws Exception { + outStream.write(indent); + outStream.write("ASTSwitchStmtNode\n"); + m_expression.print(outStream, indent + " "); + m_caseList.print(outStream, indent + " "); + } +} diff --git a/src/test/java/com/compiler/InterpreterSwitchStmtTest.java b/src/test/java/com/compiler/InterpreterSwitchStmtTest.java new file mode 100644 index 0000000..961f02f --- /dev/null +++ b/src/test/java/com/compiler/InterpreterSwitchStmtTest.java @@ -0,0 +1,81 @@ +package com.compiler; + +import org.junit.Test; + +public class InterpreterSwitchStmtTest extends InterpreterTestBase{ + @Test + public void testSwitchProgram01() throws Exception { + String program = """ + { + DECLARE in; + DECLARE out; + in = 2; + out = 0; + SWITCH(in) { + CASE 1: + out = 2; + CASE 2: + out = 3; + } + PRINT out; + } + """; + testInterpreter(program, "3\n"); + } + + + @Test + public void testSwitchProgram02() throws Exception { + String program = """ + { + DECLARE in; + DECLARE out; + in = 3; + out = 0; + SWITCH(in) { + CASE 1: + out = 2; + CASE 2: + out = 3; + CASE 4: + out = 3; + CASE 3: + out = 5; + CASE 5: + out = 2; + } + PRINT out; + } + """; + testInterpreter(program, "5\n"); + } + + + @Test + public void testSwitchProgram03() throws Exception { + String program = """ + { + DECLARE in; + DECLARE out; + in = 3; + out = 0; + SWITCH(in + 1) { + CASE 1: + out = 2; + CASE 2: + out = 3; + CASE 4: + out = 3; + in = 4; + CASE 3: + out = 5; + CASE 5: + out = 2; + } + PRINT in; + PRINT out; + } + """; + testInterpreter(program, "4\n3\n"); + } +} diff --git a/src/test/java/com/compiler/StmtSwitchStmtParserTest.java b/src/test/java/com/compiler/StmtSwitchStmtParserTest.java new file mode 100644 index 0000000..2e5ab52 --- /dev/null +++ b/src/test/java/com/compiler/StmtSwitchStmtParserTest.java @@ -0,0 +1,82 @@ +package com.compiler; + +import org.junit.Test; + +public class StmtSwitchStmtParserTest extends StmtParserTestBase{ + + @Test + public void testSwitchProgram01() throws Exception { + String program = """ + { + DECLARE in; + DECLARE out; + in = 2; + out = 0; + SWITCH(in) { + CASE 1: + out = 2; + CASE 2: + out = 3; + } + PRINT out; + } + """; + testParser(program, "3\n"); + } + + + @Test + public void testSwitchProgram02() throws Exception { + String program = """ + { + DECLARE in; + DECLARE out; + in = 3; + out = 0; + SWITCH(in) { + CASE 1: + out = 2; + CASE 2: + out = 3; + CASE 4: + out = 3; + CASE 3: + out = 5; + CASE 5: + out = 2; + } + PRINT out; + } + """; + testParser(program, "5\n"); + } + + + @Test + public void testSwitchProgram03() throws Exception { + String program = """ + { + DECLARE in; + DECLARE out; + in = 3; + out = 0; + SWITCH(in + 1) { + CASE 1: + out = 2; + CASE 2: + out = 3; + CASE 4: + out = 3; + in = 4; + CASE 3: + out = 5; + CASE 5: + out = 2; + } + PRINT in; + PRINT out; + } + """; + testParser(program, "4\n3\n"); + } +}