Skip to content

Commit

Permalink
Merge pull request #10 from mdbrnowski/interpretation
Browse files Browse the repository at this point in the history
Complete interpretation
  • Loading branch information
mdbrnowski authored Dec 14, 2024
2 parents ef43a2f + 0eef1fc commit ff83d67
Show file tree
Hide file tree
Showing 11 changed files with 295 additions and 18 deletions.
122 changes: 107 additions & 15 deletions interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,40 @@

from generated.MyParser import MyParser
from generated.MyParserVisitor import MyParserVisitor
from utils.values import Int, Float, String, Vector
from utils.memory import MemoryStack
from utils.values import Value, Int, Float, String, Vector


class Break(Exception):
pass


class Continue(Exception):
pass


def not_same_type(a: Value, b: Value):
return type(a) is not type(b) or (
isinstance(a, Vector)
and (a.dims != b.dims or a.primitive_type != b.primitive_type)
)


class Interpreter(MyParserVisitor):
def __init__(self):
self.memory_stack = MemoryStack()
self.memory_stack.push_memory()

def visitScopeStatement(self, ctx: MyParser.ScopeStatementContext):
return self.visitChildren(ctx) # todo
self.memory_stack.push_memory()
self.visitChildren(ctx)
self.memory_stack.pop_memory()

def visitIfThenElse(self, ctx: MyParser.IfThenElseContext):
condition = self.visit(ctx.if_())
if condition:
return self.visit(ctx.then())
elif ctx.else_() is not None:
elif ctx.else_():
return self.visit(ctx.else_())

def visitIf(self, ctx: MyParser.IfContext):
Expand All @@ -24,13 +46,32 @@ def visitElse(self, ctx: MyParser.ElseContext):
return self.visit(ctx.statement())

def visitForLoop(self, ctx: MyParser.ForLoopContext):
return self.visitChildren(ctx) # todo
a, b = self.visit(ctx.range_())
variable = ctx.id_().getText()
for i in range(a, b + 1):
self.memory_stack.put(variable, Int(i))
try:
self.visit(ctx.statement())
except Continue:
continue
except Break:
break

def visitRange(self, ctx: MyParser.RangeContext):
return self.visitChildren(ctx) # todo
a = self.visit(ctx.expression(0))
b = self.visit(ctx.expression(1))
if {type(a), type(b)} != {Int}:
raise TypeError
return (a.value, b.value)

def visitWhileLoop(self, ctx: MyParser.WhileLoopContext):
return self.visitChildren(ctx) # todo
while self.visit(ctx.comparison()):
try:
self.visit(ctx.statement())
except Continue:
continue
except Break:
break

def visitComparison(self, ctx: MyParser.ComparisonContext):
a = self.visit(ctx.expression(0))
Expand All @@ -50,17 +91,51 @@ def visitComparison(self, ctx: MyParser.ComparisonContext):
return a >= b

def visitSimpleAssignment(self, ctx: MyParser.SimpleAssignmentContext):
return self.visitChildren(ctx) # todo
if ctx.id_(): # a = 1
self.memory_stack.put(ctx.id_().getText(), self.visit(ctx.expression()))
else: # a[0] = 1
ref_value = self.visit(ctx.elementReference())
new_value = self.visit(ctx.expression())
if not_same_type(ref_value, new_value):
raise TypeError
ref_value.value = new_value.value

def visitCompoundAssignment(self, ctx: MyParser.CompoundAssignmentContext):
return self.visitChildren(ctx) # todo
if ctx.id_(): # a += 1
value = self.memory_stack.get(ctx.id_().getText())
new_value = self.visit(ctx.expression())
match ctx.getChild(1).symbol.type:
case MyParser.ASSIGN_PLUS:
new_value = value + new_value
case MyParser.ASSIGN_MINUS:
new_value = value - new_value
case MyParser.ASSIGN_MULTIPLY:
new_value = value * new_value
case MyParser.ASSIGN_DIVIDE:
new_value = value / new_value
self.memory_stack.put(ctx.id_().getText(), new_value)
else: # a[0] += 1
ref_value = self.visit(ctx.elementReference())
new_value = self.visit(ctx.expression())
if not_same_type(ref_value, new_value):
raise TypeError
match ctx.getChild(1).symbol.type:
case MyParser.ASSIGN_PLUS:
new_value = ref_value + new_value
case MyParser.ASSIGN_MINUS:
new_value = ref_value - new_value
case MyParser.ASSIGN_MULTIPLY:
new_value = ref_value * new_value
case MyParser.ASSIGN_DIVIDE:
new_value = ref_value / new_value
ref_value.value = new_value.value

def visitPrint(self, ctx: MyParser.PrintContext):
for i in range(ctx.getChildCount() // 2):
print(str(self.visit(ctx.expression(i))))

def visitReturn(self, ctx: MyParser.ReturnContext):
if ctx.expression() is not None:
if ctx.expression():
return_value = self.visit(ctx.expression())
if not isinstance(return_value, Int):
raise TypeError
Expand All @@ -79,7 +154,14 @@ def visitBinaryExpression(self, ctx: MyParser.BinaryExpressionContext):
return a * b
case MyParser.DIVIDE:
return a / b
# todo: MAT_* operations
case MyParser.MAT_PLUS:
return a.mat_add(b)
case MyParser.MAT_MINUS:
return a.mat_sub(b)
case MyParser.MAT_MULTIPLY:
return a.mat_mul(b)
case MyParser.MAT_DIVIDE:
return a.mat_truediv(b)

def visitParenthesesExpression(self, ctx: MyParser.ParenthesesExpressionContext):
return self.visit(ctx.expression())
Expand Down Expand Up @@ -117,10 +199,10 @@ def visitSpecialMatrixFunction(self, ctx: MyParser.SpecialMatrixFunctionContext)
return vector

def visitBreak(self, ctx: MyParser.BreakContext):
return self.visitChildren(ctx) # todo
raise Break()

def visitContinue(self, ctx: MyParser.ContinueContext):
return self.visitChildren(ctx) # todo
raise Continue()

def visitVector(self, ctx: MyParser.VectorContext):
elements = [
Expand All @@ -129,10 +211,20 @@ def visitVector(self, ctx: MyParser.VectorContext):
return Vector(elements)

def visitElementReference(self, ctx: MyParser.ElementReferenceContext):
return self.visitChildren(ctx) # todo
indices = [
self.visit(ctx.expression(i)) for i in range(ctx.getChildCount() // 2 - 1)
]
if {type(idx) for idx in indices} != {Int}:
raise TypeError
result = self.visit(ctx.id_())
for idx in indices:
if not isinstance(result, Vector):
raise TypeError
result = result.value[idx.value]
return result

def visitId(self, ctx: MyParser.IdContext):
return self.visitChildren(ctx) # todo
return self.memory_stack.get(ctx.getText())

def visitInt(self, ctx: MyParser.IntContext):
return Int(ctx.getText())
Expand All @@ -141,4 +233,4 @@ def visitFloat(self, ctx: MyParser.FloatContext):
return Float(ctx.getText())

def visitString(self, ctx: MyParser.StringContext):
return String(ctx.getText())
return String(ctx.getText()[1:-1]) # without quotes
5 changes: 3 additions & 2 deletions main.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,8 +87,9 @@ def run(filename: str):

tree = parser.program()
if parser.getNumberOfSyntaxErrors() == 0:
listener = SemanticListener()
ParseTreeWalker().walk(listener, tree)
# todo: Fix SemanticListener
# listener = SemanticListener()
# ParseTreeWalker().walk(listener, tree)
if parser.getNumberOfSyntaxErrors() == 0:
visitor = Interpreter()
visitor.visit(tree)
Expand Down
22 changes: 22 additions & 0 deletions test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,28 @@ def test_sem_errors(name: str, line_numbers: list[int], additional: str):
[[1, 1, 1]],
],
),
("variables", [2, 1, 3, "OK", 6]),
("while", [4, 3, 2, 1, 0]),
("for", [1, 10, 2, 10, 3, 10, 4, 10]),
("break_continue", [1, 2, 1, 2, 4] * 2),
(
"element_reference",
[
[1, 0],
0,
[[1, 2], [0, 1]],
[[0, 2], [0, 1]],
[[0, 0], [0, 1]],
],
),
(
"mat_operators",
[
[[2, 2], [2, 2]],
[[4, 4], [4, 4]],
[[3, 3], [3, 3]],
],
),
],
)
def test_interpreter(name: str, output: str):
Expand Down
27 changes: 27 additions & 0 deletions tests/interpreter/break_continue.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
for i = 1:4 {
if (i == 3)
break;
print i;
}

for i = 1:4 {
if (i == 3)
continue;
print i;
}

i = 0;
while (i < 4) {
i += 1;
if (i == 3)
break;
print i;
}

i = 0;
while (i < 4) {
i += 1;
if (i == 3)
continue;
print i;
}
9 changes: 9 additions & 0 deletions tests/interpreter/element_reference.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
A = eye(2);
print A[0];
print A[0, 1];
A[0, 1] = 2;
print A;
A[0] = [0, 2];
print A;
A[0, 1] -= 2;
print A;
6 changes: 6 additions & 0 deletions tests/interpreter/for.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
n = 4;
for i = 1:n {
print i;
i = 10;
print i;
}
9 changes: 9 additions & 0 deletions tests/interpreter/mat_operators.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
A = ones(2, 2);
B = ones(2, 2);
A = A .+ B;
print A;
A = A .* A;
print A;
A = A .- B;
print A;
A = A ./ B;
16 changes: 16 additions & 0 deletions tests/interpreter/variables.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
a = 2;
print a;

a -= 1;
print a;

a *= 3;
print a;

if (a == 3) {
b = "OK";
print b;
}

b = 2 * a;
print b;
7 changes: 7 additions & 0 deletions tests/interpreter/while.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
b = -4;
a = 4;
while (a >= b) {
print a;
a -= 1;
b += 1;
}
38 changes: 38 additions & 0 deletions utils/memory.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from .values import Value


class Memory:
def __init__(self):
self.variables: dict[str, Value] = {}

def has_variable(self, name: str) -> bool:
return name in self.variables

def get(self, name: str) -> Value:
return self.variables[name]

def put(self, name: str, value: Value):
self.variables[name] = value


class MemoryStack:
def __init__(self):
self.stack: list[Memory] = []

def get(self, name: str) -> Value:
for memory in self.stack:
if memory.has_variable(name):
return memory.get(name)

def put(self, name: str, value: Value):
for memory in self.stack:
if memory.has_variable(name):
memory.put(name, value)
return
self.stack[-1].put(name, value)

def push_memory(self):
self.stack.append(Memory())

def pop_memory(self):
self.stack.pop()
Loading

0 comments on commit ff83d67

Please sign in to comment.