n00b to this very cool project here. I'm looking to enforce a broadcastability pattern, where a dimension in one tensor either matches a dimension in another tensor or can be broadcast to it (i.e. equals 1).
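Concretely, the constraint being asked for is the per-dimension check below. This is a plain-Python sketch over shape tuples with no torchtyping involved; `dim_matches_or_is_one` and `mwe_output_shape` are hypothetical helpers written only for illustration, and the shapes mirror the `mwe` function in the question.

```python
from typing import Sequence, Tuple


def dim_matches_or_is_one(dim: int, target: int) -> bool:
    """True if `dim` equals `target` or can be broadcast to it (i.e. is 1)."""
    return dim == target or dim == 1


def mwe_output_shape(x_shape: Sequence[int], y_shape: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical shape check for mwe: x is (..., foo, bar or 1), y is (bar,)."""
    if len(x_shape) < 2 or len(y_shape) != 1:
        raise ValueError("expected x of shape (..., foo, bar) and y of shape (bar,)")
    bar = y_shape[0]
    if not dim_matches_or_is_one(x_shape[-1], bar):
        raise ValueError(f"last dim of x is {x_shape[-1]}, expected {bar} or 1")
    # Broadcasting stretches a size-1 'bar' in x to y's size, so the result
    # always ends in (foo, bar), matching the annotated return type.
    return tuple(x_shape[:-1]) + (bar,)


print(mwe_output_shape((4, 3, 5), (5,)))  # (4, 3, 5)
print(mwe_output_shape((4, 3, 1), (5,)))  # (4, 3, 5)
```

A shape such as `(4, 3, 2)` against `(5,)` raises `ValueError`, which is the failure the type annotation should report.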
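```python
import torchtyping
import typeguard

torchtyping.patch_typeguard()  # needed so typeguard checks TensorType dimensions

@typeguard.typechecked
def mwe(
    x: torchtyping.TensorType[
        ...,
        "foo",
        "bar",  # How do we make this "match bar from y or equal 1"?
    ],
    y: torchtyping.TensorType[
        "bar",
    ],
) -> torchtyping.TensorType[..., "foo", "bar"]:
    return x * y
```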
Am I missing an existing way to do this in torchtyping out of the box? Would this need an extension?
Yep, this is possible: it can be done with `Union[TensorType[..., "foo", 1], TensorType[..., "foo", "bar"]]`.
One caveat -- switching the order of the elements of the Union will cause a spurious failure. (The 1 case has to go before the "bar" case.) That's a bug, really, but probably a thorny one to fix.
Incidentally, broadcasting is a common enough operation that I'd be willing to accept a PR making this neater than the `Union` solution. Essentially all that's needed is some syntax like `TensorType["foo": OrOne]`, which `TensorType.__class_getitem__` expands into a `Union` of the form given above.
This should be pretty simple, so it'd be a good first issue for anyone looking to contribute.
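A rough sketch of that expansion, under stated assumptions: `OrOne` and `Dims` are hypothetical names (neither exists in torchtyping), and `Dims` is a stand-in that returns the expanded dimension tuples instead of building real `TensorType`s. In Python, a subscript like `X[..., "bar": OrOne]` passes the tuple `(Ellipsis, slice("bar", OrOne, None))` to `__class_getitem__`, so the expansion only has to look for slices whose stop is `OrOne`.

```python
import itertools
from typing import Any, List, Tuple


class OrOne:
    """Hypothetical marker: "bar": OrOne means "the size bound to bar, or 1"."""


def expand_or_one(dims: Tuple[Any, ...]) -> List[Tuple[Any, ...]]:
    """Expand every `name: OrOne` slice into the two alternatives (1, name).

    The size-1 alternative comes first, mirroring the Union ordering that is
    reported to work. Other entries (names, sizes, ..., other slices) pass
    through unchanged.
    """
    choices = []
    for d in dims:
        if isinstance(d, slice) and d.stop is OrOne:
            choices.append((1, d.start))
        else:
            choices.append((d,))
    # Cartesian product: one alternative tuple per combination of choices.
    return list(itertools.product(*choices))


class Dims:
    """Hypothetical stand-in for TensorType that returns the expanded dims.

    A real implementation would instead wrap the result, e.g. as
    Union[tuple(TensorType[alt] for alt in alternatives)].
    """

    def __class_getitem__(cls, item: Any) -> List[Tuple[Any, ...]]:
        if not isinstance(item, tuple):
            item = (item,)
        return expand_or_one(item)


print(Dims[..., "foo", "bar": OrOne])
# [(Ellipsis, 'foo', 1), (Ellipsis, 'foo', 'bar')]
```

With several `OrOne` dimensions the product still lists each size-1 choice before its named counterpart, but whether that ordering is enough to avoid the spurious failure in every combination would need testing against torchtyping itself.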