diff --git a/python/tvm/relay/frontend/pytorch.py b/python/tvm/relay/frontend/pytorch.py
index 48104e570..37f17057b 100644
--- a/python/tvm/relay/frontend/pytorch.py
+++ b/python/tvm/relay/frontend/pytorch.py
@@ -4620,6 +4620,113 @@ def scaled_dot_product_attention(self, inputs, input_types):
             attn_weight = _op.reshape(attn_weight, newshape=[-4, batch_size, -1, -2])
 
         return attn_weight
+
+    def nan_to_num(self, inputs, input_types):
+        """Convert ``aten::nan_to_num`` into Relay ops.
+
+        Mimics the behavior of torch.nan_to_num
+        (https://pytorch.org/docs/stable/generated/torch.nan_to_num.html):
+        NaN becomes ``nan`` (default 0.0), +inf becomes ``posinf`` (default: the
+        greatest finite value of the dtype) and -inf becomes ``neginf`` (default:
+        the least finite value of the dtype). Finite values pass through unchanged.
+
+        Parameters
+        ----------
+        inputs : list
+            ``[data, nan, posinf, neginf]``; each replacement may be None.
+        input_types : list
+            ``input_types[0]`` is the dtype of ``data``.
+
+        Returns
+        -------
+        result : relay.Expr
+            ``data`` with NaN and infinite entries replaced.
+        """
+        data = inputs[0]
+        nan_value = inputs[1]
+        posinf = inputs[2]
+        neginf = inputs[3]
+
+        # Ensure the data type is one of the supported floating-point types
+        dtype = input_types[0]
+        assert dtype in ["float16", "float32"], (
+            f"Unsupported dtype: {dtype}. Supported types are ['float16', 'float32']."
+        )
+
+        # Greatest and least finite values representable by the data type
+        dtype_max = np.finfo(dtype).max
+        dtype_min = np.finfo(dtype).min
+
+        # Replacement values, falling back to the torch defaults
+        nan_tensor = tvm.relay.const(nan_value if nan_value is not None else 0.0, dtype)
+        posinf_tensor = tvm.relay.const(posinf if posinf is not None else dtype_max, dtype)
+        neginf_tensor = tvm.relay.const(neginf if neginf is not None else dtype_min, dtype)
+
+        # Build every mask from the original input. Infinities are detected by
+        # comparing against the finite range of the dtype, never against the
+        # replacement values: a user-supplied posinf/neginf must not clamp finite
+        # entries, and a large `nan` replacement must not be rewritten again.
+        # Comparisons with NaN are False, so NaN entries match neither inf mask.
+        is_nan = _op.isnan(data)
+        is_posinf = tvm.relay.greater(data, tvm.relay.const(dtype_max, dtype))
+        is_neginf = tvm.relay.less(data, tvm.relay.const(dtype_min, dtype))
+
+        result = tvm.relay.where(is_nan, nan_tensor, data)
+        result = tvm.relay.where(is_posinf, posinf_tensor, result)
+        result = tvm.relay.where(is_neginf, neginf_tensor, result)
+
+        return result
+
+    def atan2(self, inputs, input_types):
+        """Convert ``aten::atan2`` into Relay ops.
+
+        Mimics the behavior of torch.atan2
+        (https://pytorch.org/docs/stable/generated/torch.atan2.html) by taking
+        ``atan(y / x)`` and shifting the result into the correct quadrant.
+
+        Parameters
+        ----------
+        inputs : list
+            ``[y, x]``: the numerator (``input``) and denominator (``other``).
+        input_types : list
+            ``input_types[0]`` selects the dtype of the helper constants.
+
+        Returns
+        -------
+        result : relay.Expr
+            Element-wise angle of the point ``(x, y)``, in radians.
+        """
+        y = inputs[0]  # numerator
+        x = inputs[1]  # denominator
+
+        # Give the helper constants the dtype of the operands so float16 and
+        # float64 inputs type-check; anything else keeps the former float32.
+        dtype = input_types[0]
+        if dtype not in ("float16", "float32", "float64"):
+            dtype = "float32"
+        pi = tvm.relay.const(np.pi, dtype)
+        neg_pi = tvm.relay.const(-np.pi, dtype)
+        zero = tvm.relay.const(0.0, dtype)
+
+        # atan(y / x) lies in [-pi/2, pi/2]; x == 0 with y != 0 gives atan(+-inf).
+        atan_res = tvm.relay.atan(tvm.relay.divide(y, x))
+
+        # Quadrant correction, needed only when x < 0:
+        # - y >= 0: atan(y / x) <= 0, add pi to land in the 2nd quadrant.
+        # - y < 0: atan(y / x) > 0, subtract pi to land in the 3rd quadrant.
+        correction = tvm.relay.where(
+            tvm.relay.less(x, zero),
+            tvm.relay.where(tvm.relay.greater_equal(y, zero), pi, neg_pi),
+            zero,
+        )
+        result = tvm.relay.add(atan_res, correction)
+
+        # y / x is NaN when both operands are zero; torch.atan2(0, 0) is 0.
+        both_zero = tvm.relay.logical_and(
+            tvm.relay.equal(x, zero), tvm.relay.equal(y, zero)
+        )
+        return tvm.relay.where(both_zero, zero, result)
 
     # Operator mappings
     def create_convert_map(self):
@@ -4893,6 +5000,7 @@ def create_convert_map(self):
             "aten::mv": self.mv,
             "aten::grid_sampler": self.grid_sampler,
             "aten::__ior__": self.make_elemwise("bitwise_or"),
+            "aten::bitwise_or_": self.make_elemwise("bitwise_or"),
             "aten::__iand__": self.make_elemwise("bitwise_and"),
             "aten::__ixor__": self.make_elemwise("bitwise_xor"),
             "aten::__lshift__": self.make_elemwise("left_shift"),
@@ -4920,6 +5028,8 @@ def create_convert_map(self):
             "aten::linalg_vector_norm": self.linalg_vector_norm,
             "aten::scaled_dot_product_attention": self.scaled_dot_product_attention,
             "aten::lift_fresh": self.identity,
+            "aten::nan_to_num": self.nan_to_num,
+            "aten::atan2": self.atan2,
         }
 
     def update_convert_map(self, custom_map):
diff --git a/python/tvm/relay/op/contrib/forge/forge_passes.py b/python/tvm/relay/op/contrib/forge/forge_passes.py
index bb42ebe21..635e7226c 100644
--- a/python/tvm/relay/op/contrib/forge/forge_passes.py
+++ b/python/tvm/relay/op/contrib/forge/forge_passes.py
@@ -3257,10 +3257,13 @@ def callback(self, pre, post, node_map):
 
         data = pre_node_map[self.data][0]
 
-        cond = tvm.relay.equal(data, tvm.relay.const(np.nan, dtype="float32"))
-        where = tvm.relay.where(cond, tvm.relay.const(True), tvm.relay.const(False))
+        # NaN (Not a Number) is the only value in floating-point arithmetic that is not equal to itself.
+        # So, comparing data with itself will return True if data is NaN, and False otherwise.
+        # This condition is used to identify NaN values in the data tensor.
 
-        return where
+        cond = tvm.relay.not_equal(data, data)
+
+        return cond
 
 
 class RemoveRedundantBinaryStacks(DFPatternCallback):