Skip to content

Commit

Permalink
Add refcounting and debug code
Browse files Browse the repository at this point in the history
  • Loading branch information
icexelloss committed Jun 6, 2023
1 parent 89667fb commit 7ba9cc8
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 9 deletions.
12 changes: 12 additions & 0 deletions python/pyarrow/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -269,6 +269,11 @@ def unary_func_fixture():
def unary_function(ctx, x):
return pc.call_function("add", [x, 1],
memory_pool=ctx.memory_pool)

print("unary_func")
import sys
print(sys.getrefcount(unary_function))

func_name = "y=x+1"
unary_doc = {"summary": "add function",
"description": "test add function"}
Expand All @@ -277,6 +282,8 @@ def unary_function(ctx, x):
unary_doc,
{"array": pa.int64()},
pa.int64())

print(sys.getrefcount(unary_function))
return unary_function, func_name


Expand All @@ -295,6 +302,10 @@ def func(ctx, x):
func_doc = {"summary": "y=avg(x)",
"description": "find mean of x"}

print("unary_agg")
import sys
print(sys.getrefcount(func))

pc.register_aggregate_function(func,
func_name,
func_doc,
Expand All @@ -303,6 +314,7 @@ def func(ctx, x):
},
pa.float64()
)
print(sys.getrefcount(func))
return func, func_name


Expand Down
1 change: 1 addition & 0 deletions python/pyarrow/src/arrow/python/udf.cc
Original file line number Diff line number Diff line change
Expand Up @@ -131,6 +131,7 @@ struct PythonUdfScalarAggregatorImpl : public ScalarUdfAggregator {
std::vector<std::shared_ptr<DataType>> input_types,
std::shared_ptr<DataType> output_type)
: agg_cb(agg_cb), agg_function(agg_function), output_type(output_type) {
Py_INCREF(agg_function->obj());
std::vector<std::shared_ptr<Field>> fields;
for (size_t i = 0; i < input_types.size(); i++) {
fields.push_back(std::move(field("", input_types[i])));
Expand Down
33 changes: 24 additions & 9 deletions python/pyarrow/tests/test_udf.py
Original file line number Diff line number Diff line change
Expand Up @@ -292,7 +292,12 @@ def check_scalar_function(func_fixture,
func = pc.get_function(name)
assert func.name == name

import sys
print("Before call_function: ", sys.getrefcount(function))
result = pc.call_function(name, inputs, length=batch_length)
print("After call_function: ", sys.getrefcount(function))
pc.call_function(name, inputs, length=batch_length)
print("After second call: ", sys.getrefcount(function))
expected_output = function(mock_udf_context(batch_length), *inputs)
assert result == expected_output
# At the moment there is an issue when handling nullary functions.
Expand All @@ -308,11 +313,14 @@ def check_scalar_function(func_fixture,


def test_udf_array_unary(unary_func_fixture):
import sys
print(sys.getrefcount(unary_func_fixture[0]))
check_scalar_function(unary_func_fixture,
[
pa.array([10, 20], pa.int64())
]
)
print(sys.getrefcount(unary_func_fixture[0]))


def test_udf_array_binary(binary_func_fixture):
Expand Down Expand Up @@ -656,11 +664,16 @@ def test_udt_datasource1_exception():
_test_datasource1_udt(datasource1_exception)


# def test_agg_basic(unary_agg_func_fixture):
# arr = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64())
# result = pc.call_function("y=avg(x)", [arr])
# expected = pa.scalar(30.0)
# assert result == expected
def test_agg_basic(unary_agg_func_fixture):
import sys
arr = pa.array([10.0, 20.0, 30.0, 40.0, 50.0], pa.float64())
print("Before calling agg: ", sys.getrefcount(unary_agg_func_fixture[0]))
result = pc.call_function("y=avg(x)", [arr])
print("After calling agg: ", sys.getrefcount(unary_agg_func_fixture[0]))
pc.call_function("y=avg(x)", [arr])
print("After second call agg: ", sys.getrefcount(unary_agg_func_fixture[0]))
expected = pa.scalar(30.0)
assert result == expected


# def test_agg_empty(unary_agg_func_fixture):
Expand Down Expand Up @@ -694,7 +707,9 @@ def test_udt_datasource1_exception():


# def test_agg_exception(exception_agg_func_fixture):
# arr = pa.array([10, 20, 30, 40, 50, 60], pa.int64())

# with pytest.raises(RuntimeError, match='Oops'):
# pc.call_function("y=exception_len(x)", [arr])
# arr = pa.array([10, 20, 30, 40, 50, 60], pa.int64())
# import sys
# print("Before calling agg exception: ", sys.getrefcount(exception_agg_func_fixture[0]))
# with pytest.raises(RuntimeError, match='Oops'):
# pc.call_function("y=exception_len(x)", [arr])
# print("After calling agg exception: ", sys.getrefcount(exception_agg_func_fixture[0]))

0 comments on commit 7ba9cc8

Please sign in to comment.