Skip to content

Commit

Permalink
feat(//py): Allow example tensors from torch to set shape
Browse files Browse the repository at this point in the history
Signed-off-by: Naren Dasan <[email protected]>
Signed-off-by: Naren Dasan <[email protected]>
  • Loading branch information
narendasan committed Aug 31, 2021
1 parent 15e6863 commit 01d525d
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 1 deletion.
7 changes: 7 additions & 0 deletions py/trtorch/Input.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,10 @@ def _parse_format(format: Any) -> _types.TensorFormat:
else:
raise TypeError(
"Tensor format needs to be specified with either torch.memory_format or trtorch.TensorFormat")

@classmethod
def _from_tensor(cls, t: torch.Tensor):
    """Build an Input spec from an example ``torch.Tensor``.

    The tensor's shape, dtype and memory format are used to populate the
    new Input; the tensor's data is not retained.

    Args:
        t: Example tensor. Must be contiguous in either
            ``torch.contiguous_format`` or ``torch.channels_last``.

    Returns:
        A new ``Input`` (``cls``) describing ``t``.

    Raises:
        ValueError: If ``t`` is contiguous in neither supported format.
    """
    # Check each supported layout once; prefer contiguous_format when a
    # tensor satisfies both (e.g. size-1 channel dims).
    if t.is_contiguous(memory_format=torch.contiguous_format):
        frmt = torch.contiguous_format
    elif t.is_contiguous(memory_format=torch.channels_last):
        frmt = torch.channels_last
    else:
        raise ValueError(
            "Tensor does not have a supported contiguous memory format, "
            "supported formats are contiguous or channels_last")
    return cls(shape=t.shape, dtype=t.dtype, format=frmt)
6 changes: 5 additions & 1 deletion py/trtorch/_compile_spec.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,11 @@ def _parse_compile_spec(compile_spec: Dict[str, Any]) -> trtorch._C.CompileSpec:
info.inputs = _parse_input_ranges(compile_spec["input_shapes"])

if "inputs" in compile_spec:
info.inputs = [i._to_internal() for i in compile_spec["inputs"]]
if not all([isinstance(i, torch.Tensor) or isinstance(i, trtorch.Input) for i in compile_spec["inputs"]]):
raise KeyError("Input specs should be either trtorch.Input or torch.Tensor, found types: {}".format([typeof(i) for i in compile_spec["inputs"]]))

inputs = [trtorch.Input._from_tensor(i) if isinstance(i, torch.Tensor) else i for i in compile_spec["inputs"]]
info.inputs = [i._to_internal() for i in inputs]

if "op_precision" in compile_spec and "enabled_precisions" in compile_spec:
raise KeyError(
Expand Down
24 changes: 24 additions & 0 deletions tests/py/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,30 @@ def test_compile_script(self):
same = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
self.assertTrue(same < 2e-2)

def test_from_torch_tensor(self):
    # Supply a raw torch.Tensor (not a trtorch.Input) as the input spec;
    # compilation should derive shape/dtype/format from the example tensor.
    spec = {
        "inputs": [self.input],
        "device": {
            "device_type": trtorch.DeviceType.GPU,
            "gpu_id": 0,
        },
        "enabled_precisions": {torch.float}
    }

    trt_mod = trtorch.compile(self.scripted_model, spec)
    max_abs_diff = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
    self.assertTrue(max_abs_diff < 2e-2)

def test_device(self):
    # Target device given as a trtorch.Device object instead of a dict.
    spec = {
        "inputs": [self.input],
        "device": trtorch.Device("gpu:0"),
        "enabled_precisions": {torch.float}
    }

    trt_mod = trtorch.compile(self.scripted_model, spec)
    max_abs_diff = (trt_mod(self.input) - self.scripted_model(self.input)).abs().max()
    self.assertTrue(max_abs_diff < 2e-2)

class TestCompileHalf(ModelTestCase):

Expand Down

0 comments on commit 01d525d

Please sign in to comment.