Skip to content

Commit

Permalink
Add support for nan_to_num and atan2 op
Browse files Browse the repository at this point in the history
  • Loading branch information
kamalrajkannan78 committed Jan 7, 2025
1 parent e405246 commit e66c611
Showing 1 changed file with 43 additions and 0 deletions.
43 changes: 43 additions & 0 deletions python/tvm/relay/frontend/pytorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -4620,6 +4620,47 @@ def scaled_dot_product_attention(self, inputs, input_types):
attn_weight = _op.reshape(attn_weight, newshape=[-4, batch_size, -1, -2])

return attn_weight


def nan_to_num(self, inputs, input_types):

data = inputs[0]
nan_value = inputs[1]
posinf = inputs[2]
neginf = inputs[3]

dtype = input_types[0]

assert dtype == "float32", f"Expected dtype to be float32, but got {dtype}. Support for {dtype} is not added yet."

dtype_max = np.finfo(dtype).max
dtype_min = np.finfo(dtype).min

nan_tensor = tvm.relay.const(nan_value if nan_value is not None else 0.0, dtype)
posinf_tensor = tvm.relay.const(posinf if posinf is not None else dtype_max, dtype)
neginf_tensor = tvm.relay.const(neginf if neginf is not None else dtype_min, dtype)

result = tvm.relay.where(tvm.relay.isnan(data), nan_tensor, data)
result = tvm.relay.where(tvm.relay.equal(data, tvm.relay.const(np.inf, dtype)), posinf_tensor, result)
result = tvm.relay.where(tvm.relay.equal(data, tvm.relay.const(-np.inf, dtype)), neginf_tensor, result)

return result

def atan2(self, inputs, input_types):

data_1 = inputs[1]
data_2 = inputs[0]

ratio = tvm.relay.divide(data_2, data_1)
atan_res = tvm.relay.atan(ratio)

pi = tvm.relay.const(np.pi, "float32")
zero = tvm.relay.const(0.0, "float32")

correction = tvm.relay.where(tvm.relay.less(data_1, zero), tvm.relay.where(tvm.relay.greater_equal(data_2, zero), pi, -pi), zero)

result = tvm.relay.add(atan_res, correction)
return result

# Operator mappings
def create_convert_map(self):
Expand Down Expand Up @@ -4920,6 +4961,8 @@ def create_convert_map(self):
"aten::linalg_vector_norm": self.linalg_vector_norm,
"aten::scaled_dot_product_attention": self.scaled_dot_product_attention,
"aten::lift_fresh": self.identity,
"aten::nan_to_num":self.nan_to_num,
"aten::atan2":self.atan2,
}

def update_convert_map(self, custom_map):
Expand Down

0 comments on commit e66c611

Please sign in to comment.