Skip to content

Commit

Permalink
Interpret matrix operators
Browse files Browse the repository at this point in the history
  • Loading branch information
mdbrnowski committed Dec 14, 2024
1 parent dfd512a commit 0eef1fc
Show file tree
Hide file tree
Showing 4 changed files with 69 additions and 1 deletion.
9 changes: 8 additions & 1 deletion interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,14 @@ def visitBinaryExpression(self, ctx: MyParser.BinaryExpressionContext):
return a * b
case MyParser.DIVIDE:
return a / b
# todo: MAT_* operations
case MyParser.MAT_PLUS:
return a.mat_add(b)
case MyParser.MAT_MINUS:
return a.mat_sub(b)
case MyParser.MAT_MULTIPLY:
return a.mat_mul(b)
case MyParser.MAT_DIVIDE:
return a.mat_truediv(b)

def visitParenthesesExpression(self, ctx: MyParser.ParenthesesExpressionContext):
return self.visit(ctx.expression())
Expand Down
8 changes: 8 additions & 0 deletions test_main.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,14 @@ def test_sem_errors(name: str, line_numbers: list[int], additional: str):
[[0, 0], [0, 1]],
],
),
(
"mat_operators",
[
[[2, 2], [2, 2]],
[[4, 4], [4, 4]],
[[3, 3], [3, 3]],
],
),
],
)
def test_interpreter(name: str, output: str):
Expand Down
9 changes: 9 additions & 0 deletions tests/interpreter/mat_operators.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
A = ones(2, 2);
B = ones(2, 2);
A = A .+ B;
print A;
A = A .* A;
print A;
A = A .- B;
print A;
A = A ./ B;
44 changes: 44 additions & 0 deletions utils/values.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,6 +130,50 @@ def __init__(self, value: list):
def __str__(self):
return "[" + ", ".join(str(elem) for elem in self.value) + "]"

def mat_add(self, other):
if isinstance(other, Vector):
rows = []
for elem, other_elem in zip(self.value, other.value):
if isinstance(elem, Vector):
rows.append(elem.mat_add(other_elem))
else:
rows.append(elem + other_elem)
return Vector(rows)
raise TypeError()

def mat_sub(self, other):
if isinstance(other, Vector):
rows = []
for elem, other_elem in zip(self.value, other.value):
if isinstance(elem, Vector):
rows.append(elem.mat_sub(other_elem))
else:
rows.append(elem - other_elem)
return Vector(rows)
raise TypeError()

def mat_mul(self, other):
if isinstance(other, Vector):
rows = []
for elem, other_elem in zip(self.value, other.value):
if isinstance(elem, Vector):
rows.append(elem.mat_mul(other_elem))
else:
rows.append(elem * other_elem)
return Vector(rows)
raise TypeError()

def mat_truediv(self, other):
if isinstance(other, Vector):
rows = []
for elem, other_elem in zip(self.value, other.value):
if isinstance(elem, Vector):
rows.append(elem.mat_truediv(other_elem))
else:
rows.append(elem / other_elem)
return Vector(rows)
raise TypeError()

def transpose(self):
if len(self.dims) != 2:
raise TypeError
Expand Down

0 comments on commit 0eef1fc

Please sign in to comment.