Skip to content

Commit

Permalink
Add trunc to z3 validator (pytorch#140886)
Browse files Browse the repository at this point in the history
Fixes vision_maskrcnn benchmark when validation is turned on

Pull Request resolved: pytorch#140886
Approved by: https://github.com/ezyang
ghstack dependencies: pytorch#140830, pytorch#140832, pytorch#140828
  • Loading branch information
bobrenjc93 authored and youssef62 committed Nov 23, 2024
1 parent 0f70dde commit d492213
Showing 1 changed file with 4 additions and 0 deletions.
4 changes: 4 additions & 0 deletions torch/fx/experimental/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,9 @@ def ceil(self, number: z3.ArithRef) -> z3.ArithRef:
self.floor(number) < number, self.floor(number + 1), number
) # type: ignore[return-value]

def trunc(self, number: z3.ArithRef) -> z3.ArithRef:
return z3.If(number >= 0, self.floor(number), self.ceil(number)) # type: ignore[return-value]

def max(self, a: z3.ArithRef, b: z3.ArithRef) -> z3.ArithRef:
return z3.If(a > b, a, b) # type: ignore[return-value]

Expand Down Expand Up @@ -291,6 +294,7 @@ def wrapper(*args):
# Math module.
math.ceil: lift(ops.ceil),
math.floor: lift(ops.floor),
math.trunc: lift(ops.trunc),
# Torch module.
torch.sym_float: lift(ops.to_real),
torch.sym_max: lift(ops.max),
Expand Down

0 comments on commit d492213

Please sign in to comment.