Skip to content

Commit

Permalink
feat(frontend): Added kaiser_bessel_derived_window to tensorflow fron…
Browse files Browse the repository at this point in the history
…tend (ivy-llc#22517)

Co-authored-by: Saeed Ashraf <[email protected]>
  • Loading branch information
2 people authored and druvdub committed Oct 14, 2023
1 parent af59d4a commit 422a308
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 0 deletions.
17 changes: 17 additions & 0 deletions ivy/functional/frontends/tensorflow/signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,15 @@ def idct(input, type=2, n=None, axis=-1, norm=None, name=None):
return ivy.dct(input, type=inverse_type, n=n, axis=axis, norm=norm)


# kaiser_bessel_derived_window
@handle_tf_dtype
@to_ivy_arrays_and_back
def kaiser_bessel_derived_window(
window_length, beta=12.0, dtype=ivy.float32, name=None
):
return ivy.kaiser_bessel_derived_window(window_length, beta=beta, dtype=dtype)


@with_supported_dtypes(
{"2.13.0 and below": ("float32", "float64", "float16", "bfloat16")},
"tensorflow",
Expand All @@ -36,3 +45,11 @@ def kaiser_window(window_length, beta=12.0, dtype=ivy.float32, name=None):
@to_ivy_arrays_and_back
def vorbis_window(window_length, dtype=ivy.float32, name=None):
return ivy.vorbis_window(window_length, dtype=dtype, out=None)


kaiser_bessel_derived_window.supported_dtypes = (
"float32",
"float64",
"float16",
"bfloat16",
)
37 changes: 37 additions & 0 deletions ivy_tests/test_ivy/test_frontends/test_tensorflow/test_signal.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,43 @@ def test_tensorflow_idct(
)


# kaiser_bessel_derived_window
@handle_frontend_test(
fn_tree="tensorflow.signal.kaiser_bessel_derived_window",
dtype_and_x=helpers.dtype_and_values(
available_dtypes=helpers.get_dtypes("valid"),
max_num_dims=0,
min_value=1,
max_value=10,
),
beta=st.floats(min_value=1, max_value=5),
# dtype=helpers.get_dtypes("float", full=False),
test_with_out=st.just(False),
)
def test_tensorflow_kaiser_bessel_derived_window(
*,
dtype_and_x,
beta,
test_flags,
backend_fw,
fn_tree,
on_device,
frontend, # dtype
):
input_dtype, x = dtype_and_x
helpers.test_frontend_function(
input_dtypes=input_dtype,
test_flags=test_flags,
backend_to_test=backend_fw,
frontend=frontend,
fn_tree=fn_tree,
on_device=on_device,
window_length=int(x[0]),
beta=beta,
# dtype=dtype[0],
)


# kaiser_window
@handle_frontend_test(
fn_tree="tensorflow.signal.kaiser_window",
Expand Down

0 comments on commit 422a308

Please sign in to comment.