From c061ee5d35e688fb995e2b3603374f253e3d6410 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Dobranowski?= Date: Sat, 14 Dec 2024 06:50:10 +0100 Subject: [PATCH 1/7] Variables, assignments & scopes --- interpreter.py | 40 +++++++++++++++++++++++++++------ main.py | 5 +++-- test_main.py | 1 + tests/interpreter/variables.txt | 16 +++++++++++++ utils/memory.py | 38 +++++++++++++++++++++++++++++++ 5 files changed, 91 insertions(+), 9 deletions(-) create mode 100644 tests/interpreter/variables.txt create mode 100644 utils/memory.py diff --git a/interpreter.py b/interpreter.py index d3370a7..b1853ee 100644 --- a/interpreter.py +++ b/interpreter.py @@ -3,18 +3,25 @@ from generated.MyParser import MyParser from generated.MyParserVisitor import MyParserVisitor +from utils.memory import MemoryStack from utils.values import Int, Float, String, Vector class Interpreter(MyParserVisitor): + def __init__(self): + self.memory_stack = MemoryStack() + self.memory_stack.push_memory() + def visitScopeStatement(self, ctx: MyParser.ScopeStatementContext): - return self.visitChildren(ctx) # todo + self.memory_stack.push_memory() + self.visitChildren(ctx) + self.memory_stack.pop_memory() def visitIfThenElse(self, ctx: MyParser.IfThenElseContext): condition = self.visit(ctx.if_()) if condition: return self.visit(ctx.then()) - elif ctx.else_() is not None: + elif ctx.else_(): return self.visit(ctx.else_()) def visitIf(self, ctx: MyParser.IfContext): @@ -50,17 +57,36 @@ def visitComparison(self, ctx: MyParser.ComparisonContext): return a >= b def visitSimpleAssignment(self, ctx: MyParser.SimpleAssignmentContext): - return self.visitChildren(ctx) # todo + if ctx.id_(): + self.visitChildren(ctx) + self.memory_stack.put(ctx.id_().getText(), self.visit(ctx.expression())) + else: + pass # todo def visitCompoundAssignment(self, ctx: MyParser.CompoundAssignmentContext): - return self.visitChildren(ctx) # todo + if ctx.id_(): + self.visitChildren(ctx) + old_value = self.memory_stack.get(ctx.id_().getText()) + new_value = self.visit(ctx.expression()) + match ctx.getChild(1).symbol.type: + case MyParser.ASSIGN_PLUS: + new_value = old_value + new_value + case MyParser.ASSIGN_MINUS: + new_value = old_value - new_value + case MyParser.ASSIGN_MULTIPLY: + new_value = old_value * new_value + case MyParser.ASSIGN_DIVIDE: + new_value = old_value / new_value + self.memory_stack.put(ctx.id_().getText(), new_value) + else: + pass # todo def visitPrint(self, ctx: MyParser.PrintContext): for i in range(ctx.getChildCount() // 2): print(str(self.visit(ctx.expression(i)))) def visitReturn(self, ctx: MyParser.ReturnContext): - if ctx.expression() is not None: + if ctx.expression(): return_value = self.visit(ctx.expression()) if not isinstance(return_value, Int): raise TypeError @@ -132,7 +158,7 @@ def visitElementReference(self, ctx: MyParser.ElementReferenceContext): return self.visitChildren(ctx) # todo def visitId(self, ctx: MyParser.IdContext): - return self.visitChildren(ctx) # todo + return self.memory_stack.get(ctx.getText()) def visitInt(self, ctx: MyParser.IntContext): return Int(ctx.getText()) @@ -141,4 +167,4 @@ def visitFloat(self, ctx: MyParser.FloatContext): return Float(ctx.getText()) def visitString(self, ctx: MyParser.StringContext): - return String(ctx.getText()) + return String(ctx.getText()[1:-1]) # without quotes diff --git a/main.py b/main.py index 68a5e02..12b5082 100755 --- a/main.py +++ b/main.py @@ -87,8 +87,9 @@ def run(filename: str): tree = parser.program() if parser.getNumberOfSyntaxErrors() == 0: - listener = SemanticListener() - ParseTreeWalker().walk(listener, tree) + # todo: Fix SemanticListener + # listener = SemanticListener() + # ParseTreeWalker().walk(listener, tree) if parser.getNumberOfSyntaxErrors() == 0: visitor = Interpreter() visitor.visit(tree) diff --git a/test_main.py b/test_main.py index 7c13829..0bd432c 100644 --- a/test_main.py +++ b/test_main.py @@ -81,6 +81,7 @@ def test_sem_errors(name: str, line_numbers: list[int], additional: str): [[1, 1, 1]], ], ), + ("variables", [2, 1, 3, "OK", 6]), ], ) def test_interpreter(name: str, output: str): diff --git a/tests/interpreter/variables.txt b/tests/interpreter/variables.txt new file mode 100644 index 0000000..d32f2e4 --- /dev/null +++ b/tests/interpreter/variables.txt @@ -0,0 +1,16 @@ +a = 2; +print a; + +a -= 1; +print a; + +a *= 3; +print a; + +if (a == 3) { + b = "OK"; + print b; +} + +b = 2 * a; +print b; diff --git a/utils/memory.py b/utils/memory.py new file mode 100644 index 0000000..11ad583 --- /dev/null +++ b/utils/memory.py @@ -0,0 +1,38 @@ +from .values import Value + + +class Memory: + def __init__(self): + self.variables: dict[str, Value] = {} + + def has_variable(self, name: str) -> bool: + return name in self.variables + + def get(self, name: str) -> Value: + return self.variables[name] + + def put(self, name: str, value: Value): + self.variables[name] = value + + +class MemoryStack: + def __init__(self): + self.stack: list[Memory] = [] + + def get(self, name: str) -> Value: + for memory in self.stack: + if memory.has_variable(name): + return memory.get(name) + + def put(self, name: str, value: Value): + for memory in self.stack: + if memory.has_variable(name): + memory.put(name, value) + return + self.stack[-1].put(name, value) + + def push_memory(self): + self.stack.append(Memory()) + + def pop_memory(self): + self.stack.pop() From 6d7eedad212dfd6cc9f67585acabf67a8fbcf4cb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Dobranowski?= Date: Sat, 14 Dec 2024 07:08:14 +0100 Subject: [PATCH 2/7] Interpret loops --- interpreter.py | 16 +++++++++++++--- test_main.py | 2 ++ tests/interpreter/for.txt | 6 ++++++ tests/interpreter/while.txt | 7 +++++++ 4 files changed, 28 insertions(+), 3 deletions(-) create mode 100644 tests/interpreter/for.txt create mode 100644 tests/interpreter/while.txt diff --git a/interpreter.py b/interpreter.py index b1853ee..5b3577e 100644 --- a/interpreter.py +++ b/interpreter.py @@ -31,13 +31,23 @@ def visitElse(self, ctx: MyParser.ElseContext): return self.visit(ctx.statement()) def visitForLoop(self, ctx: MyParser.ForLoopContext): - return self.visitChildren(ctx) # todo + a, b = self.visit(ctx.range_()) + variable = ctx.id_().getText() + while a <= b: + self.memory_stack.put(variable, a) + a = a + 1 # to increment enumerateor and disregard changes inside the loop + self.visit(ctx.statement()) def visitRange(self, ctx: MyParser.RangeContext): - return self.visitChildren(ctx) # todo + a = self.visit(ctx.expression(0)) + b = self.visit(ctx.expression(1)) + if {type(a), type(b)} != {Int}: + raise TypeError + return (a.value, b.value) def visitWhileLoop(self, ctx: MyParser.WhileLoopContext): - return self.visitChildren(ctx) # todo + while self.visit(ctx.comparison()): + self.visit(ctx.statement()) def visitComparison(self, ctx: MyParser.ComparisonContext): a = self.visit(ctx.expression(0)) diff --git a/test_main.py b/test_main.py index 0bd432c..3c4e2af 100644 --- a/test_main.py +++ b/test_main.py @@ -82,6 +82,8 @@ def test_sem_errors(name: str, line_numbers: list[int], additional: str): ], ), ("variables", [2, 1, 3, "OK", 6]), + ("while", [4, 3, 2, 1, 0]), + ("for", [1, 10, 2, 10, 3, 10, 4, 10]), ], ) def test_interpreter(name: str, output: str): diff --git a/tests/interpreter/for.txt b/tests/interpreter/for.txt new file mode 100644 index 0000000..1dc1017 --- /dev/null +++ b/tests/interpreter/for.txt @@ -0,0 +1,6 @@ +n = 4; +for i = 1:n { + print i; + i = 10; + print i; +} \ No newline at end of file diff --git a/tests/interpreter/while.txt b/tests/interpreter/while.txt new file mode 100644 index 0000000..d0ddd5c --- /dev/null +++ b/tests/interpreter/while.txt @@ -0,0 +1,7 @@ +b = -4; +a = 4; +while (a >= b) { + print a; + a -= 1; + b += 1; +} \ No newline at end of file From dc1b821a290c25dd0b1444585f210f762ce24ad4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Dobranowski?= Date: Sat, 14 Dec 2024 07:10:29 +0100 Subject: [PATCH 3/7] Fix for loop handling --- interpreter.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/interpreter.py b/interpreter.py index 5b3577e..42ee356 100644 --- a/interpreter.py +++ b/interpreter.py @@ -34,7 +34,7 @@ def visitForLoop(self, ctx: MyParser.ForLoopContext): a, b = self.visit(ctx.range_()) variable = ctx.id_().getText() while a <= b: - self.memory_stack.put(variable, a) + self.memory_stack.put(variable, Int(a)) a = a + 1 # to increment enumerateor and disregard changes inside the loop self.visit(ctx.statement()) From c1d05476d0b3ce37e04f8e0c859ead2fa93ec522 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Dobranowski?= Date: Sat, 14 Dec 2024 07:22:30 +0100 Subject: [PATCH 4/7] Implement break and continue --- interpreter.py | 26 ++++++++++++++++++++++---- test_main.py | 1 + tests/interpreter/break_continue.txt | 27 +++++++++++++++++++++++++++ 3 files changed, 50 insertions(+), 4 deletions(-) create mode 100644 tests/interpreter/break_continue.txt diff --git a/interpreter.py b/interpreter.py index 42ee356..c3f9170 100644 --- a/interpreter.py +++ b/interpreter.py @@ -7,6 +7,14 @@ from utils.values import Int, Float, String, Vector +class Break(Exception): + pass + + +class Continue(Exception): + pass + + class Interpreter(MyParserVisitor): def __init__(self): self.memory_stack = MemoryStack() @@ -36,7 +44,12 @@ def visitForLoop(self, ctx: MyParser.ForLoopContext): while a <= b: self.memory_stack.put(variable, Int(a)) a = a + 1 # to increment enumerateor and disregard changes inside the loop - self.visit(ctx.statement()) + try: + self.visit(ctx.statement()) + except Continue: + continue + except Break: + break def visitRange(self, ctx: MyParser.RangeContext): a = self.visit(ctx.expression(0)) @@ -47,7 +60,12 @@ def visitRange(self, ctx: MyParser.RangeContext): def visitWhileLoop(self, ctx: MyParser.WhileLoopContext): while self.visit(ctx.comparison()): - self.visit(ctx.statement()) + try: + self.visit(ctx.statement()) + except Continue: + continue + except Break: + break def visitComparison(self, ctx: MyParser.ComparisonContext): a = self.visit(ctx.expression(0)) @@ -153,10 +171,10 @@ def visitSpecialMatrixFunction(self, ctx: MyParser.SpecialMatrixFunctionContext) return vector def visitBreak(self, ctx: MyParser.BreakContext): - return self.visitChildren(ctx) # todo + raise Break() def visitContinue(self, ctx: MyParser.ContinueContext): - return self.visitChildren(ctx) # todo + raise Continue() def visitVector(self, ctx: MyParser.VectorContext): elements = [ diff --git a/test_main.py b/test_main.py index 3c4e2af..0f57290 100644 --- a/test_main.py +++ b/test_main.py @@ -84,6 +84,7 @@ def test_sem_errors(name: str, line_numbers: list[int], additional: str): ("variables", [2, 1, 3, "OK", 6]), ("while", [4, 3, 2, 1, 0]), ("for", [1, 10, 2, 10, 3, 10, 4, 10]), + ("break_continue", [1, 2, 1, 2, 4] * 2), ], ) def test_interpreter(name: str, output: str): diff --git a/tests/interpreter/break_continue.txt b/tests/interpreter/break_continue.txt new file mode 100644 index 0000000..f55a66b --- /dev/null +++ b/tests/interpreter/break_continue.txt @@ -0,0 +1,27 @@ +for i = 1:4 { + if (i == 3) + break; + print i; +} + +for i = 1:4 { + if (i == 3) + continue; + print i; +} + +i = 0; +while (i < 4) { + i += 1; + if (i == 3) + break; + print i; +} + +i = 0; +while (i < 4) { + i += 1; + if (i == 3) + continue; + print i; +} \ No newline at end of file From c852cf4d1b37f4e14ceb5c2eef14ea2fc12df852 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Dobranowski?= Date: Sat, 14 Dec 2024 10:17:35 +0100 Subject: [PATCH 5/7] Fix for loop --- interpreter.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/interpreter.py b/interpreter.py index c3f9170..802d85e 100644 --- a/interpreter.py +++ b/interpreter.py @@ -41,9 +41,8 @@ def visitElse(self, ctx: MyParser.ElseContext): def visitForLoop(self, ctx: MyParser.ForLoopContext): a, b = self.visit(ctx.range_()) variable = ctx.id_().getText() - while a <= b: - self.memory_stack.put(variable, Int(a)) - a = a + 1 # to increment enumerateor and disregard changes inside the loop + for i in range(a, b + 1): + self.memory_stack.put(variable, Int(i)) try: self.visit(ctx.statement()) except Continue: From dfd512a928f288f1b0fc1ee0c8aa132e3b017a1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Dobranowski?= Date: Sat, 14 Dec 2024 11:08:48 +0100 Subject: [PATCH 6/7] Interpret vector element references --- interpreter.py | 62 +++++++++++++++++++------ test_main.py | 10 ++++ tests/interpreter/element_reference.txt | 9 ++++ utils/values.py | 8 +++- 4 files changed, 73 insertions(+), 16 deletions(-) create mode 100644 tests/interpreter/element_reference.txt diff --git a/interpreter.py b/interpreter.py index 802d85e..bfeaf2d 100644 --- a/interpreter.py +++ b/interpreter.py @@ -4,7 +4,7 @@ from generated.MyParser import MyParser from generated.MyParserVisitor import MyParserVisitor from utils.memory import MemoryStack -from utils.values import Int, Float, String, Vector +from utils.values import Value, Int, Float, String, Vector class Break(Exception): @@ -15,6 +15,13 @@ class Continue(Exception): pass +def not_same_type(a: Value, b: Value): + return type(a) is not type(b) or ( + isinstance(a, Vector) + and (a.dims != b.dims or a.primitive_type != b.primitive_type) + ) + + class Interpreter(MyParserVisitor): def __init__(self): self.memory_stack = MemoryStack() @@ -84,29 +91,44 @@ def visitComparison(self, ctx: MyParser.ComparisonContext): return a >= b def visitSimpleAssignment(self, ctx: MyParser.SimpleAssignmentContext): - if ctx.id_(): - self.visitChildren(ctx) + if ctx.id_(): # a = 1 self.memory_stack.put(ctx.id_().getText(), self.visit(ctx.expression())) - else: - pass # todo + else: # a[0] = 1 + ref_value = self.visit(ctx.elementReference()) + new_value = self.visit(ctx.expression()) + if not_same_type(ref_value, new_value): + raise TypeError + ref_value.value = new_value.value def visitCompoundAssignment(self, ctx: MyParser.CompoundAssignmentContext): - if ctx.id_(): - self.visitChildren(ctx) - old_value = self.memory_stack.get(ctx.id_().getText()) + if ctx.id_(): # a += 1 + value = self.memory_stack.get(ctx.id_().getText()) new_value = self.visit(ctx.expression()) match ctx.getChild(1).symbol.type: case MyParser.ASSIGN_PLUS: - new_value = old_value + new_value + new_value = value + new_value case MyParser.ASSIGN_MINUS: - new_value = old_value - new_value + new_value = value - new_value case MyParser.ASSIGN_MULTIPLY: - new_value = old_value * new_value + new_value = value * new_value case MyParser.ASSIGN_DIVIDE: - new_value = old_value / new_value + new_value = value / new_value self.memory_stack.put(ctx.id_().getText(), new_value) - else: - pass # todo + else: # a[0] += 1 + ref_value = self.visit(ctx.elementReference()) + new_value = self.visit(ctx.expression()) + if not_same_type(ref_value, new_value): + raise TypeError + match ctx.getChild(1).symbol.type: + case MyParser.ASSIGN_PLUS: + new_value = ref_value + new_value + case MyParser.ASSIGN_MINUS: + new_value = ref_value - new_value + case MyParser.ASSIGN_MULTIPLY: + new_value = ref_value * new_value + case MyParser.ASSIGN_DIVIDE: + new_value = ref_value / new_value + ref_value.value = new_value.value def visitPrint(self, ctx: MyParser.PrintContext): for i in range(ctx.getChildCount() // 2): @@ -182,7 +204,17 @@ def visitVector(self, ctx: MyParser.VectorContext): return Vector(elements) def visitElementReference(self, ctx: MyParser.ElementReferenceContext): - return self.visitChildren(ctx) # todo + indices = [ + self.visit(ctx.expression(i)) for i in range(ctx.getChildCount() // 2 - 1) + ] + if {type(idx) for idx in indices} != {Int}: + raise TypeError + result = self.visit(ctx.id_()) + for idx in indices: + if not isinstance(result, Vector): + raise TypeError + result = result.value[idx.value] + return result def visitId(self, ctx: MyParser.IdContext): return self.memory_stack.get(ctx.getText()) diff --git a/test_main.py b/test_main.py index 0f57290..7a808eb 100644 --- a/test_main.py +++ b/test_main.py @@ -85,6 +85,16 @@ def test_sem_errors(name: str, line_numbers: list[int], additional: str): ("while", [4, 3, 2, 1, 0]), ("for", [1, 10, 2, 10, 3, 10, 4, 10]), ("break_continue", [1, 2, 1, 2, 4] * 2), + ( + "element_reference", + [ + [1, 0], + 0, + [[1, 2], [0, 1]], + [[0, 2], [0, 1]], + [[0, 0], [0, 1]], + ], + ), ], ) def test_interpreter(name: str, output: str): diff --git a/tests/interpreter/element_reference.txt b/tests/interpreter/element_reference.txt new file mode 100644 index 0000000..3118628 --- /dev/null +++ b/tests/interpreter/element_reference.txt @@ -0,0 +1,9 @@ +A = eye(2); +print A[0]; +print A[0, 1]; +A[0, 1] = 2; +print A; +A[0] = [0, 2]; +print A; +A[0, 1] -= 2; +print A; \ No newline at end of file diff --git a/utils/values.py b/utils/values.py index 2f24537..34191ea 100644 --- a/utils/values.py +++ b/utils/values.py @@ -108,7 +108,11 @@ def __init__(self, value: list): if ( len( { - (type(elem), elem.dims if isinstance(elem, Vector) else None) + ( + (elem.dims, elem.primitive_type) + if isinstance(elem, Vector) + else type(elem) + ) for elem in value } ) @@ -118,8 +122,10 @@ def __init__(self, value: list): if isinstance(value[0], Vector): self.dims = (len(value), *value[0].dims) + self.primitive_type = value[0].primitive_type else: self.dims = (len(value),) + self.primitive_type = type(value[0]) def __str__(self): return "[" + ", ".join(str(elem) for elem in self.value) + "]" From 0eef1fc21884ac97ac43996de0a38467477fe2ee Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Micha=C5=82=20Dobranowski?= Date: Sat, 14 Dec 2024 11:35:45 +0100 Subject: [PATCH 7/7] Interpret matrix operators --- interpreter.py | 9 +++++- test_main.py | 8 ++++++ tests/interpreter/mat_operators.txt | 9 ++++++ utils/values.py | 44 +++++++++++++++++++++++++++++ 4 files changed, 69 insertions(+), 1 deletion(-) create mode 100644 tests/interpreter/mat_operators.txt diff --git a/interpreter.py b/interpreter.py index bfeaf2d..1ae9947 100644 --- a/interpreter.py +++ b/interpreter.py @@ -154,7 +154,14 @@ def visitBinaryExpression(self, ctx: MyParser.BinaryExpressionContext): return a * b case MyParser.DIVIDE: return a / b - # todo: MAT_* operations + case MyParser.MAT_PLUS: + return a.mat_add(b) + case MyParser.MAT_MINUS: + return a.mat_sub(b) + case MyParser.MAT_MULTIPLY: + return a.mat_mul(b) + case MyParser.MAT_DIVIDE: + return a.mat_truediv(b) def visitParenthesesExpression(self, ctx: MyParser.ParenthesesExpressionContext): return self.visit(ctx.expression()) diff --git a/test_main.py b/test_main.py index 7a808eb..284dbd2 100644 --- a/test_main.py +++ b/test_main.py @@ -95,6 +95,14 @@ def test_sem_errors(name: str, line_numbers: list[int], additional: str): [[0, 0], [0, 1]], ], ), + ( + "mat_operators", + [ + [[2, 2], [2, 2]], + [[4, 4], [4, 4]], + [[3, 3], [3, 3]], + ], + ), ], ) def test_interpreter(name: str, output: str): diff --git a/tests/interpreter/mat_operators.txt b/tests/interpreter/mat_operators.txt new file mode 100644 index 0000000..73f7297 --- /dev/null +++ b/tests/interpreter/mat_operators.txt @@ -0,0 +1,9 @@ +A = ones(2, 2); +B = ones(2, 2); +A = A .+ B; +print A; +A = A .* A; +print A; +A = A .- B; +print A; +A = A ./ B; \ No newline at end of file diff --git a/utils/values.py b/utils/values.py index 34191ea..2a8ac6b 100644 --- a/utils/values.py +++ b/utils/values.py @@ -130,6 +130,50 @@ def __init__(self, value: list): def __str__(self): return "[" + ", ".join(str(elem) for elem in self.value) + "]" + def mat_add(self, other): + if isinstance(other, Vector): + rows = [] + for elem, other_elem in zip(self.value, other.value): + if isinstance(elem, Vector): + rows.append(elem.mat_add(other_elem)) + else: + rows.append(elem + other_elem) + return Vector(rows) + raise TypeError() + + def mat_sub(self, other): + if isinstance(other, Vector): + rows = [] + for elem, other_elem in zip(self.value, other.value): + if isinstance(elem, Vector): + rows.append(elem.mat_sub(other_elem)) + else: + rows.append(elem - other_elem) + return Vector(rows) + raise TypeError() + + def mat_mul(self, other): + if isinstance(other, Vector): + rows = [] + for elem, other_elem in zip(self.value, other.value): + if isinstance(elem, Vector): + rows.append(elem.mat_mul(other_elem)) + else: + rows.append(elem * other_elem) + return Vector(rows) + raise TypeError() + + def mat_truediv(self, other): + if isinstance(other, Vector): + rows = [] + for elem, other_elem in zip(self.value, other.value): + if isinstance(elem, Vector): + rows.append(elem.mat_truediv(other_elem)) + else: + rows.append(elem / other_elem) + return Vector(rows) + raise TypeError() + def transpose(self): if len(self.dims) != 2: raise TypeError