From 7ba9cc846e629557412207563c210030249ffe86 Mon Sep 17 00:00:00 2001 From: Li Jin Date: Tue, 6 Jun 2023 06:35:54 -0400 Subject: [PATCH] Add refcounting and debug code --- python/pyarrow/conftest.py | 12 ++++++++++ python/pyarrow/src/arrow/python/udf.cc | 1 + python/pyarrow/tests/test_udf.py | 33 +++++++++++++++++++------- 3 files changed, 37 insertions(+), 9 deletions(-) diff --git a/python/pyarrow/conftest.py b/python/pyarrow/conftest.py index f32cbf01efcd6..1914d892e508c 100644 --- a/python/pyarrow/conftest.py +++ b/python/pyarrow/conftest.py @@ -269,6 +269,11 @@ def unary_func_fixture(): def unary_function(ctx, x): return pc.call_function("add", [x, 1], memory_pool=ctx.memory_pool) + + print("unary_func") + import sys + print(sys.getrefcount(unary_function)) + func_name = "y=x+1" unary_doc = {"summary": "add function", "description": "test add function"} @@ -277,6 +282,8 @@ def unary_function(ctx, x): unary_doc, {"array": pa.int64()}, pa.int64()) + + print(sys.getrefcount(unary_function)) return unary_function, func_name @@ -295,6 +302,10 @@ def func(ctx, x): func_doc = {"summary": "y=avg(x)", "description": "find mean of x"} + print("unary_agg") + import sys + print(sys.getrefcount(func)) + pc.register_aggregate_function(func, func_name, func_doc, @@ -303,6 +314,7 @@ def func(ctx, x): }, pa.float64() ) + print(sys.getrefcount(func)) return func, func_name diff --git a/python/pyarrow/src/arrow/python/udf.cc b/python/pyarrow/src/arrow/python/udf.cc index a856f0e398b07..e7ef154fabb9f 100644 --- a/python/pyarrow/src/arrow/python/udf.cc +++ b/python/pyarrow/src/arrow/python/udf.cc @@ -131,6 +131,7 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator { std::vector<std::shared_ptr<DataType>> input_types, std::shared_ptr<DataType> output_type) : agg_cb(agg_cb), agg_function(agg_function), output_type(output_type) { + Py_INCREF(agg_function->obj()); std::vector<std::shared_ptr<Field>> fields; for (size_t i = 0; i < input_types.size(); i++) { fields.push_back(std::move(field("", input_types[i]))); diff --git 
a/python/pyarrow/tests/test_udf.py b/python/pyarrow/tests/test_udf.py index aee2e4e27fbcf..055af2f651e66 100644 --- a/python/pyarrow/tests/test_udf.py +++ b/python/pyarrow/tests/test_udf.py @@ -292,7 +292,12 @@ def check_scalar_function(func_fixture, func = pc.get_function(name) assert func.name == name + import sys + print("Before call_function: ", sys.getrefcount(function)) result = pc.call_function(name, inputs, length=batch_length) + print("After call_function: ", sys.getrefcount(function)) + pc.call_function(name, inputs, length=batch_length) + print("After second call: ", sys.getrefcount(function)) expected_output = function(mock_udf_context(batch_length), *inputs) assert result == expected_output # At the moment there is an issue when handling nullary functions. @@ -308,11 +313,14 @@ def check_scalar_function(func_fixture, def test_udf_array_unary(unary_func_fixture): + import sys + print(sys.getrefcount(unary_func_fixture[0])) check_scalar_function(unary_func_fixture, [ pa.array([10, 20], pa.int64()) ] ) + print(sys.getrefcount(unary_func_fixture[0])) def test_udf_array_binary(binary_func_fixture): @@ -656,11 +664,16 @@ def test_udt_datasource1_exception(): _test_datasource1_udt(datasource1_exception) -# def test_agg_basic(unary_agg_func_fixture): -# arr = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64()) -# result = pc.call_function("y=avg(x)", [arr]) -# expected = pa.scalar(30.0) -# assert result == expected +def test_agg_basic(unary_agg_func_fixture): + import sys + arr = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64()) + print("Before calling agg: ", sys.getrefcount(unary_agg_func_fixture[0])) + result = pc.call_function("y=avg(x)", [arr]) + print("After calling agg: ", sys.getrefcount(unary_agg_func_fixture[0])) + pc.call_function("y=avg(x)", [arr]) + print("After second call agg: ", sys.getrefcount(unary_agg_func_fixture[0])) + expected = pa.scalar(30.0) + assert result == expected # def test_agg_empty(unary_agg_func_fixture): @@ -694,7 
+707,9 @@ def test_udt_datasource1_exception(): # def test_agg_exception(exception_agg_func_fixture): -# arr = pa.array([10, 20, 30, 40, 50, 60], pa.int64()) - -# with pytest.raises(RuntimeError, match='Oops'): -# pc.call_function("y=exception_len(x)", [arr]) +# arr = pa.array([10, 20, 30, 40, 50, 60], pa.int64()) +# import sys +# print("Before calling agg exception: ", sys.getrefcount(exception_agg_func_fixture[0])) +# with pytest.raises(RuntimeError, match='Oops'): +# pc.call_function("y=exception_len(x)", [arr]) +# print("After calling agg exception: ", sys.getrefcount(exception_agg_func_fixture[0]))