Skip to content

Commit

Permalink
chore: re-enable serde tests (#2968)
Browse files Browse the repository at this point in the history
  • Loading branch information
peri044 authored Jul 31, 2024
1 parent 4476792 commit 784fa57
Showing 1 changed file with 53 additions and 58 deletions.
111 changes: 53 additions & 58 deletions tests/py/dynamo/models/test_export_serde.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,21 +47,21 @@ def forward(self, x):
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
torchtrt.save(trt_module, trt_ep_path, inputs=[input])
# TODO: Enable this serialization issues are fixed
# deser_trt_module = torchtrt.load(trt_ep_path).module()

deser_trt_module = torchtrt.load(trt_ep_path).module()
# Check Pyt and TRT exported program outputs
cos_sim = cosine_similarity(model(input), trt_module(input)[0])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)
# TODO: Enable this serialization issues are fixed
# # Check Pyt and deserialized TRT exported program outputs
# cos_sim = cosine_similarity(model(input), deser_trt_module(input)[0])
# assertions.assertTrue(
# cos_sim > COSINE_THRESHOLD,
# msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
# )

# Check Pyt and deserialized TRT exported program outputs
cos_sim = cosine_similarity(model(input), deser_trt_module(input)[0])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


@pytest.mark.unit
Expand Down Expand Up @@ -99,8 +99,8 @@ def forward(self, x):
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
torchtrt.save(trt_module, trt_ep_path, inputs=[input])
# TODO: Enable this serialization issues are fixed
# deser_trt_module = torchtrt.load(trt_ep_path).module()

deser_trt_module = torchtrt.load(trt_ep_path).module()
# Check Pyt and TRT exported program outputs
outputs_pyt = model(input)
outputs_trt = trt_module(input)
Expand All @@ -111,15 +111,14 @@ def forward(self, x):
msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# TODO: Enable this serialization issues are fixed
# # Check Pyt and deserialized TRT exported program outputs
# outputs_trt_deser = deser_trt_module(input)
# for idx in range(len(outputs_pyt)):
# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
# assertions.assertTrue(
# cos_sim > COSINE_THRESHOLD,
# msg=f"test_base_full_compile_multiple_outputs deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
# )
outputs_trt_deser = deser_trt_module(input)
for idx in range(len(outputs_pyt)):
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_base_full_compile_multiple_outputs deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


@pytest.mark.unit
Expand Down Expand Up @@ -156,8 +155,8 @@ def forward(self, x):
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
torchtrt.save(trt_module, trt_ep_path, inputs=[input])
# TODO: Enable this serialization issues are fixed
# deser_trt_module = torchtrt.load(trt_ep_path).module()

deser_trt_module = torchtrt.load(trt_ep_path).module()
# Check Pyt and TRT exported program outputs
outputs_pyt = model(input)
outputs_trt = trt_module(input)
Expand All @@ -168,15 +167,14 @@ def forward(self, x):
msg=f"test_no_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# TODO: Enable this serialization issues are fixed
# # Check Pyt and deserialized TRT exported program outputs
# outputs_trt_deser = deser_trt_module(input)
# for idx in range(len(outputs_pyt)):
# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
# assertions.assertTrue(
# cos_sim > COSINE_THRESHOLD,
# msg=f"test_no_compile deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
# )
outputs_trt_deser = deser_trt_module(input)
for idx in range(len(outputs_pyt)):
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_no_compile deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


@pytest.mark.unit
Expand Down Expand Up @@ -216,8 +214,8 @@ def forward(self, x):
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
torchtrt.save(trt_module, trt_ep_path, inputs=[input])
# TODO: Enable this serialization issues are fixed
# deser_trt_module = torchtrt.load(trt_ep_path).module()

deser_trt_module = torchtrt.load(trt_ep_path).module()
outputs_pyt = model(input)
outputs_trt = trt_module(input)
for idx in range(len(outputs_pyt)):
Expand All @@ -227,14 +225,13 @@ def forward(self, x):
msg=f"test_hybrid_relu_fallback TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# TODO: Enable this serialization issues are fixed
# outputs_trt_deser = deser_trt_module(input)
# for idx in range(len(outputs_pyt)):
# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
# assertions.assertTrue(
# cos_sim > COSINE_THRESHOLD,
# msg=f"test_hybrid_relu_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
# )
outputs_trt_deser = deser_trt_module(input)
for idx in range(len(outputs_pyt)):
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_hybrid_relu_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


@pytest.mark.unit
Expand All @@ -258,8 +255,8 @@ def test_resnet18(ir):
exp_program = torchtrt.dynamo.trace(model, **compile_spec)
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)
torchtrt.save(trt_module, trt_ep_path, inputs=[input])
# TODO: Enable this serialization issues are fixed
# deser_trt_module = torchtrt.load(trt_ep_path).module()

deser_trt_module = torchtrt.load(trt_ep_path).module()
outputs_pyt = model(input)
outputs_trt = trt_module(input)
cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0])
Expand All @@ -268,13 +265,12 @@ def test_resnet18(ir):
msg=f"test_resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# TODO: Enable this serialization issues are fixed
# outputs_trt_deser = deser_trt_module(input)
# cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser[0])
# assertions.assertTrue(
# cos_sim > COSINE_THRESHOLD,
# msg=f"test_resnet18 deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
# )
outputs_trt_deser = deser_trt_module(input)
cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser[0])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_resnet18 deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


@pytest.mark.unit
Expand Down Expand Up @@ -314,8 +310,8 @@ def forward(self, x):
trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec)

torchtrt.save(trt_module, trt_ep_path, inputs=[input])
# TODO: Enable this serialization issues are fixed
# deser_trt_module = torchtrt.load(trt_ep_path).module()

deser_trt_module = torchtrt.load(trt_ep_path).module()
outputs_pyt = model(input)
outputs_trt = trt_module(input)

Expand All @@ -326,14 +322,13 @@ def forward(self, x):
msg=f"test_hybrid_conv_fallback TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)

# TODO: Enable this serialization issues are fixed
# outputs_trt_deser = deser_trt_module(input)
# for idx in range(len(outputs_pyt)):
# cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
# assertions.assertTrue(
# cos_sim > COSINE_THRESHOLD,
# msg=f"test_hybrid_conv_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
# )
outputs_trt_deser = deser_trt_module(input)
for idx in range(len(outputs_pyt)):
cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx])
assertions.assertTrue(
cos_sim > COSINE_THRESHOLD,
msg=f"test_hybrid_conv_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}",
)


@pytest.mark.unit
Expand Down

0 comments on commit 784fa57

Please sign in to comment.