From 9e94e6ef71eb1c51658e389b8ce0c38fbab2bb67 Mon Sep 17 00:00:00 2001 From: Sergei Lebedev Date: Mon, 5 Feb 2024 13:58:01 -0800 Subject: [PATCH] Fixed a typo in min/max Triton lowering rules PiperOrigin-RevId: 604424404 --- jaxlib/triton/compat.py | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/jaxlib/triton/compat.py b/jaxlib/triton/compat.py index 3c2125042636..0e5d948263ad 100644 --- a/jaxlib/triton/compat.py +++ b/jaxlib/triton/compat.py @@ -672,13 +672,13 @@ def max(x: tensor, y: tensor) -> tensor: assert x.shape == y.shape if x.dtype.is_floating(): # TODO(slebedev): Triton promotes bfloat16 to float32 and back here. - return tensor(arith_dialect.maxnumf(x.handle, y.handle), x.dtype) + return tensor(arith_dialect.maxnumf(x.handle, y.handle), x.type) if not x.dtype.is_int(): raise NotImplementedError(f"unsupported dtypes: {x.dtype} and {y.dtype}") elif x.dtype.is_int_signed(): - return tensor(arith_dialect.maxsi(x.handle, y.handle), x.dtype) + return tensor(arith_dialect.maxsi(x.handle, y.handle), x.type) else: - return tensor(arith_dialect.maxui(x.handle, y.handle), x.dtype) + return tensor(arith_dialect.maxui(x.handle, y.handle), x.type) @staticmethod def min(x: tensor, y: tensor) -> tensor: @@ -686,13 +686,13 @@ def min(x: tensor, y: tensor) -> tensor: assert x.shape == y.shape if x.dtype.is_floating(): # TODO(slebedev): Triton promotes bfloat16 to float32 and back here. - return tensor(arith_dialect.minnumf(x.handle, y.handle), x.dtype) + return tensor(arith_dialect.minnumf(x.handle, y.handle), x.type) if not x.dtype.is_int(): raise NotImplementedError(f"unsupported dtypes: {x.dtype} and {y.dtype}") elif x.dtype.is_int_signed(): - return tensor(arith_dialect.minsi(x.handle, y.handle), x.dtype) + return tensor(arith_dialect.minsi(x.handle, y.handle), x.type) else: - return tensor(arith_dialect.minui(x.handle, y.handle), x.dtype) + return tensor(arith_dialect.minui(x.handle, y.handle), x.type) sin = libdevice_extern_elementwise({ (float32,): ("__nv_sinf", float32),