diff --git a/tests/py/dynamo/models/test_export_serde.py b/tests/py/dynamo/models/test_export_serde.py index 58e0115886..e139b7109e 100644 --- a/tests/py/dynamo/models/test_export_serde.py +++ b/tests/py/dynamo/models/test_export_serde.py @@ -47,21 +47,21 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) torchtrt.save(trt_module, trt_ep_path, inputs=[input]) - # TODO: Enable this serialization issues are fixed - # deser_trt_module = torchtrt.load(trt_ep_path).module() + + deser_trt_module = torchtrt.load(trt_ep_path).module() # Check Pyt and TRT exported program outputs cos_sim = cosine_similarity(model(input), trt_module(input)[0]) assertions.assertTrue( cos_sim > COSINE_THRESHOLD, msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) - # TODO: Enable this serialization issues are fixed - # # Check Pyt and deserialized TRT exported program outputs - # cos_sim = cosine_similarity(model(input), deser_trt_module(input)[0]) - # assertions.assertTrue( - # cos_sim > COSINE_THRESHOLD, - # msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - # ) + + # Check Pyt and deserialized TRT exported program outputs + cos_sim = cosine_similarity(model(input), deser_trt_module(input)[0]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_base_model_full_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) @pytest.mark.unit @@ -99,8 +99,8 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) torchtrt.save(trt_module, trt_ep_path, inputs=[input]) - # TODO: Enable this serialization issues are fixed - # deser_trt_module = torchtrt.load(trt_ep_path).module() + + deser_trt_module = torchtrt.load(trt_ep_path).module() # Check Pyt and TRT exported program outputs outputs_pyt = model(input) outputs_trt = trt_module(input) @@ -111,15 +111,14 @@ def forward(self, x): msg=f"test_base_full_compile_multiple_outputs TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) - # TODO: Enable this serialization issues are fixed # # Check Pyt and deserialized TRT exported program outputs - # outputs_trt_deser = deser_trt_module(input) - # for idx in range(len(outputs_pyt)): - # cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) - # assertions.assertTrue( - # cos_sim > COSINE_THRESHOLD, - # msg=f"test_base_full_compile_multiple_outputs deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - # ) + outputs_trt_deser = deser_trt_module(input) + for idx in range(len(outputs_pyt)): + cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_base_full_compile_multiple_outputs deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) @pytest.mark.unit @@ -156,8 +155,8 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) torchtrt.save(trt_module, trt_ep_path, inputs=[input]) - # TODO: Enable this serialization issues are fixed - # deser_trt_module = torchtrt.load(trt_ep_path).module() + + deser_trt_module = torchtrt.load(trt_ep_path).module() # Check Pyt and TRT exported program outputs outputs_pyt = model(input) outputs_trt = trt_module(input) @@ -168,15 +167,14 @@ def forward(self, x): msg=f"test_no_compile TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) - # TODO: Enable this serialization issues are fixed # # Check Pyt and deserialized TRT exported program outputs - # outputs_trt_deser = deser_trt_module(input) - # for idx in range(len(outputs_pyt)): - # cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) - # assertions.assertTrue( - # cos_sim > COSINE_THRESHOLD, - # msg=f"test_no_compile deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - # ) + outputs_trt_deser = deser_trt_module(input) + for idx in range(len(outputs_pyt)): + cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_no_compile deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) @pytest.mark.unit @@ -216,8 +214,8 @@ def forward(self, x): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) torchtrt.save(trt_module, trt_ep_path, inputs=[input]) - # TODO: Enable this serialization issues are fixed - # deser_trt_module = torchtrt.load(trt_ep_path).module() + + deser_trt_module = torchtrt.load(trt_ep_path).module() outputs_pyt = model(input) outputs_trt = trt_module(input) for idx in range(len(outputs_pyt)): @@ -227,14 +225,13 @@ def forward(self, x): msg=f"test_hybrid_relu_fallback TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) - # TODO: Enable this serialization issues are fixed - # outputs_trt_deser = deser_trt_module(input) - # for idx in range(len(outputs_pyt)): - # cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) - # assertions.assertTrue( - # cos_sim > COSINE_THRESHOLD, - # msg=f"test_hybrid_relu_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - # ) + outputs_trt_deser = deser_trt_module(input) + for idx in range(len(outputs_pyt)): + cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_hybrid_relu_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) @pytest.mark.unit @@ -258,8 +255,8 @@ def test_resnet18(ir): exp_program = torchtrt.dynamo.trace(model, **compile_spec) trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) torchtrt.save(trt_module, trt_ep_path, inputs=[input]) - # TODO: Enable this serialization issues are fixed - # deser_trt_module = torchtrt.load(trt_ep_path).module() + + deser_trt_module = torchtrt.load(trt_ep_path).module() outputs_pyt = model(input) outputs_trt = trt_module(input) cos_sim = cosine_similarity(outputs_pyt, outputs_trt[0]) @@ -268,13 +265,12 @@ def test_resnet18(ir): msg=f"test_resnet18 TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) - # TODO: Enable this serialization issues are fixed - # outputs_trt_deser = deser_trt_module(input) - # cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser[0]) - # assertions.assertTrue( - # cos_sim > COSINE_THRESHOLD, - # msg=f"test_resnet18 deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - # ) + outputs_trt_deser = deser_trt_module(input) + cos_sim = cosine_similarity(outputs_pyt, outputs_trt_deser[0]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_resnet18 deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) @pytest.mark.unit @@ -314,8 +310,8 @@ def forward(self, x): trt_module = torchtrt.dynamo.compile(exp_program, **compile_spec) torchtrt.save(trt_module, trt_ep_path, inputs=[input]) - # TODO: Enable this serialization issues are fixed - # deser_trt_module = torchtrt.load(trt_ep_path).module() + + deser_trt_module = torchtrt.load(trt_ep_path).module() outputs_pyt = model(input) outputs_trt = trt_module(input) @@ -326,14 +322,13 @@ def forward(self, x): msg=f"test_hybrid_conv_fallback TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", ) - # TODO: Enable this serialization issues are fixed - # outputs_trt_deser = deser_trt_module(input) - # for idx in range(len(outputs_pyt)): - # cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) - # assertions.assertTrue( - # cos_sim > COSINE_THRESHOLD, - # msg=f"test_hybrid_conv_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", - # ) + outputs_trt_deser = deser_trt_module(input) + for idx in range(len(outputs_pyt)): + cos_sim = cosine_similarity(outputs_pyt[idx], outputs_trt_deser[idx]) + assertions.assertTrue( + cos_sim > COSINE_THRESHOLD, + msg=f"test_hybrid_conv_fallback deserialized TRT outputs don't match with the original model. Cosine sim score: {cos_sim} Threshold: {COSINE_THRESHOLD}", + ) @pytest.mark.unit