Skip to content

Commit

Permalink
Initial shot at structural equal and bug fixes in cpp_generator.py (#15)
Browse files Browse the repository at this point in the history
* Get initial version of structural equality working

* Fix typo in objectgen

* Cpp generator bug

* Respond to comments
  • Loading branch information
Lily Orth-Smith authored Jun 17, 2021
1 parent 29d1b7f commit 2300abe
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 17 deletions.
23 changes: 12 additions & 11 deletions objectgen/objectgen/cpp_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,21 +183,22 @@ def generate_equal_and_hash(self, header_buf, object_def):
header_buf.write(f"{4 * ' '}bool SEqualReduce(const {object_def.payload_name()}* other, SEqualReducer equal) const {{\n")
header_buf.write(f"{8 * ' '}return")

if len(object_def.fields):
has_bindings = any([f.is_binding for f in object_def.fields])

check_equal_fields = [f for f in object_def.fields if f.use_in_sequal_reduce]
if len(check_equal_fields):
has_bindings = any([f.is_binding for f in check_equal_fields])
if has_bindings:
raise Exception("add MarkNodeGraph")

for i, field in enumerate(object_def.fields):
if has_bindings and field.is_binding:
equal_method = "DefEqual"
else:
equal_method = "equal"
for i, field in enumerate(check_equal_fields):
if field.use_in_sequal_reduce: # Whether this field should be included in the structural equality
if has_bindings and field.is_binding:
equal_method = "DefEqual"
else:
equal_method = "equal"

header_buf.write(f" #{equal_method}({field.field_name}, other->{field.field_name})")
if i != len(object_def.fields) - 1:
header_buf.write(" && ")
header_buf.write(f" {equal_method}({field.field_name}, other->{field.field_name})")
if i != len(check_equal_fields) - 1:
header_buf.write(" && ")
else:
header_buf.write(" true")

Expand Down
1 change: 1 addition & 0 deletions objectgen/objectgen/object_def.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ class ObjectField:
field_name: str
field_type: Type
is_binding: bool = False
use_in_sequal_reduce: bool = True

@attr.s
class ObjectMethod:
Expand Down
8 changes: 4 additions & 4 deletions objectgen/relax.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,14 +26,14 @@
ObjectDefinition(
name="Type",
fields=[
ObjectField("span", "Span")
ObjectField("span", "Span", use_in_sequal_reduce=False)
],
final = False,
),
ObjectDefinition(
name="Expr",
fields=[
ObjectField("span", "Span")
ObjectField("span", "Span", use_in_sequal_reduce=False)
],
final = False,
),
Expand Down Expand Up @@ -87,7 +87,7 @@
name="Function",
inherits_from="Expr",
fields=[
ObjectField("name", "Optional<runtime::String>"),
ObjectField("name", "Optional<runtime::String>", use_in_sequal_reduce=False),
ObjectField("params", "runtime::Array<Var>"),
ObjectField("body", "Expr"),
ObjectField("ret_type", "Type"),
Expand Down Expand Up @@ -189,7 +189,7 @@
ObjectDefinition(
name="Instruction",
fields=[
ObjectField("span", "Span")
ObjectField("span", "Span", use_in_sequal_reduce=False)
],
final = False,
),
Expand Down
6 changes: 4 additions & 2 deletions python/tvm/relax/frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,11 +51,12 @@ def __init__(self, definition_scope, diag_ctx):
self.module = {}
super().__init__()

def span_to_span(self, span):
def span_to_span(self, span: synr.Span) -> tvm.ir.Span:
src_name = self.diag_ctx.str_to_source_name[span.filename]
tvm_span = tvm.ir.Span(src_name, span.start_line, span.end_line, span.start_column, span.end_column)
return tvm_span


def decl_var(self, name, ty, span=None):
identifier = Id(name)
var = expr.Var(identifier, ty, span)
Expand Down Expand Up @@ -226,7 +227,7 @@ def add_source(self, name: str, source: str) -> None:
src_name = self.tvm_diag_ctx.module.source_map.add(name, source)
self.str_to_source_name[name] = src_name

def emit(self, level: str, message: str, span: Span) -> None:
def emit(self, level: str, message: str, span: tvm.ir.Span) -> None:
"""Called when an error has occured."""

if level == "error":
Expand All @@ -239,6 +240,7 @@ def emit(self, level: str, message: str, span: Span) -> None:
level = "error"

assert span, "Span must not be null"
assert isinstance(span, tvm.ir.span), "Expected tvm.ir.span, but got " + str(type(span))

diag = diagnostics.Diagnostic(level, span, message)

Expand Down
2 changes: 2 additions & 0 deletions python/tvm/script/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def buffer_slice_to_region(

def tvm_span_from_synr(span: synr.ast.Span) -> Span:
"""Convert a synr span to a TVM span"""
assert isinstance(span, synr.ast.Span), "Expected span to be synr.ast.Span, but got " + str(type(span))
return Span(
SourceName(span.filename),
span.start_line,
Expand All @@ -111,6 +112,7 @@ def tvm_span_from_synr(span: synr.ast.Span) -> Span:

def synr_span_from_tvm(span: Span) -> synr.ast.Span:
"""Convert a TVM span to a synr span"""
assert isinstance(span, synr.ast.Span), "Expected span to be tvm.ir.Span, but got " + str(type(span))
return synr.ast.Span(
span.source_name.name,
span.line,
Expand Down
70 changes: 70 additions & 0 deletions tests/python/relax/test_relax_roundtrip.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,70 @@
"""Roundtripping tests for Relay Next (Relax)"""
from __future__ import annotations
from os import X_OK
import tvm
from tvm.relay.base import Id
import tvm.relax.op.operators
from tvm.relax import expr, r2


from typing import TypeVar, Generic, Union
from io import StringIO
import numpy

def assert_structural_equal(lhs, rhs, map_free_vars=False):
lhs = tvm.runtime.convert(lhs)
rhs = tvm.runtime.convert(rhs)
# These are packed funcs here
tvm.runtime._ffi_node_api.StructuralEqual(lhs, rhs, True, map_free_vars)

@r2
def foo(x: Tensor) -> Tensor:
return x

foo1 = foo

@r2
def same_as_foo(x: Tensor) -> Tensor:
return x

@r2
def not_foo(x: Tensor, y: Tensor) -> Tensor:
return x

@r2
def foo(y: Tensor) -> Tensor:
return y

foo2 = foo


# test literally the same object
def test_same():
rlx_program = foo
assert_structural_equal(rlx_program.module['foo'], rlx_program.module['foo'])


# test two fns with the same name but different objects, different variable names
# problem with span
def test_same_name():
assert_structural_equal(foo1.module['foo'], foo2.module['foo'], True)


# test two functions that are the same with different names
def test_same_as_foo():
rlx_program1 = foo
rlx_program2 = same_as_foo
assert_structural_equal(rlx_program1.module['foo'], rlx_program2.module['same_as_foo'], True)

def test_not_foo():
rlx_program1 = foo
rlx_program2 = not_foo
assert_structural_equal(rlx_program1.module['foo'], rlx_program2.module['not_foo'], True)

# Tests that should succeed
test_same()
test_same_name()
test_same_as_foo()

# Tests that should fail
# test_not_foo()

0 comments on commit 2300abe

Please sign in to comment.