diff --git a/integration_tests/CMakeLists.txt b/integration_tests/CMakeLists.txt index 1177bb5266..ede6d6b69c 100644 --- a/integration_tests/CMakeLists.txt +++ b/integration_tests/CMakeLists.txt @@ -626,6 +626,7 @@ RUN(NAME test_builtin_sum LABELS cpython llvm c) RUN(NAME test_math1 LABELS cpython llvm c) RUN(NAME test_math_02 LABELS cpython llvm NOFAST) RUN(NAME test_math_03 LABELS llvm) #1595: TODO: Test using CPython (3.11 recommended) +RUN(NAME test_math_04 LABELS llvm) #TODO: Test using CPython (as above) RUN(NAME test_pass_compare LABELS cpython llvm c) RUN(NAME test_c_interop_01 LABELS cpython llvm c) RUN(NAME test_c_interop_02 LABELS cpython llvm c diff --git a/integration_tests/test_math_04.py b/integration_tests/test_math_04.py new file mode 100644 index 0000000000..cf4f6d5974 --- /dev/null +++ b/integration_tests/test_math_04.py @@ -0,0 +1,36 @@ +from math import (cbrt, sqrt) +from lpython import f32, f64, i32, i64 + +eps: f64 +eps = 1e-12 + +def test_cbrt(): + eps: f64 = 1e-12 + a : i32 = 64 + b : i64 = i64(64) + c : f32 = f32(64.0) + d : f64 = f64(64.0) + assert abs(cbrt(124.0) - 4.986630952238646) < eps + assert abs(cbrt(39.0) - 3.3912114430141664) < eps + assert abs(cbrt(39) - 3.3912114430141664) < eps + assert abs(cbrt(a) - 4.0) < eps + assert abs(cbrt(b) - 4.0) < eps + assert abs(cbrt(c) - 4.0) < eps + assert abs(cbrt(d) - 4.0) < eps + +def test_sqrt(): + eps: f64 = 1e-12 + a : i32 = 64 + b : i64 = i64(64) + c : f32 = f32(64.0) + d : f64 = f64(64.0) + assert abs(sqrt(a) - 8.0) < eps + assert abs(sqrt(b) - 8.0) < eps + assert abs(sqrt(c) - 8.0) < eps + assert abs(sqrt(d) - 8.0) < eps + +def check(): + test_cbrt() + test_sqrt() + +check() \ No newline at end of file diff --git a/src/runtime/math.py b/src/runtime/math.py index 8655201892..df8a6e2139 100644 --- a/src/runtime/math.py +++ b/src/runtime/math.py @@ -539,18 +539,74 @@ def trunc(x: f32) -> i32: else: return ceil(x) +@overload +def sqrt(x: f32) -> f64: + """ + Returns square root of a number x + """ + y : f64 + y = f64(x) + return y**(1/2) + +@overload def sqrt(x: f64) -> f64: """ Returns square root of a number x """ return x**(1/2) +@overload +def sqrt(x: i32) -> f64: + """ + Returns square root of a number x + """ + y : f64 + y = float(x) + return y**(1/2) + +@overload +def sqrt(x: i64) -> f64: + """ + Returns square root of a number x + """ + y : f64 + y = float(x) + return y**(1/2) + +@overload +def cbrt(x: f32) -> f64: + """ + Returns cube root of a number x + """ + y : f64 + y = f64(x) + return y**(1/3) + +@overload def cbrt(x: f64) -> f64: """ Returns cube root of a number x """ return x**(1/3) +@overload +def cbrt(x: i32) -> f64: + """ + Returns cube root of a number x + """ + y : f64 + y = float(x) + return y**(1/3) + +@overload +def cbrt(x: i64) -> f64: + """ + Returns cube root of a number x + """ + y : f64 + y = float(x) + return y**(1/3) + @ccall def _lfortran_dsin(x: f64) -> f64: pass