Skip to content

Commit

Permalink
Update on "Implement aten::index | feat(torchlib) (#862)"
Browse files Browse the repository at this point in the history
  • Loading branch information
justinchuby committed Jul 17, 2023
1 parent 4e54582 commit ce12f0e
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 8 deletions.
2 changes: 0 additions & 2 deletions onnxscript/function_libs/torch_lib/ops/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -3064,8 +3064,6 @@ def aten_index(self: TensorType, indices: Sequence[Optional[INT64]]) -> TensorTy

self_rank = len(self.shape)
index_ranks = [len(index.shape) for index in indices if index is not None]
print("index_ranks: ", index_ranks)
print("indices: ", indices)
advanced_indexing_rank = max(index_ranks)

# reordered_positions is the permutation of the index positions where
Expand Down
16 changes: 10 additions & 6 deletions onnxscript/tests/function_libs/torch_lib/ops_test_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,8 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,
input.value = subarg
sequence_input.append(input)
ort_inputs[input_name] = subarg
else:
sequence_input.append(subarg)
onnxscript_args.append(sequence_input)
else:
onnxscript_args.append(arg)
Expand Down Expand Up @@ -511,12 +513,14 @@ def _capture_graph_and_evaluate_torch_script_evaluator(function: Callable, args,

onnx_model = onnxscript_graph.to_model_proto(TEST_OPSET_VERSION)
# Make sure the model is valid
try:
onnx.checker.check_model(onnx_model, full_check=True)
except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
raise AssertionError(
f"ONNX model is invalid. Model:\n{onnxscript.proto2text(onnx_model)}"
) from e
# try:
# onnx.checker.check_model(onnx_model, full_check=True)
# except (onnx.checker.ValidationError, onnx.shape_inference.InferenceError) as e:
# raise AssertionError(
# f"ONNX model is invalid. Model:\n{onnxscript.proto2text(onnx_model)}"
# ) from e

print(onnxscript.proto2text(onnx_model))

try:
if os.environ.get("CATCH_ORT_SEGFAULT") == "1":
Expand Down

0 comments on commit ce12f0e

Please sign in to comment.