Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add python bindings for writing Frames #447

Merged
merged 13 commits into from
Jul 25, 2023
Merged
25 changes: 25 additions & 0 deletions python/podio/base_writer.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
#!/usr/bin/env python3
"""Python module for defining the basic writer interface that is used by the
backend specific bindings"""


class BaseWriterMixin:
"""Mixin class that defines the base interface of the writers.

The backend specific writers inherit from here and have to initialize the
following members:
- _writer: The actual writer that is able to write frames
"""

def write_frame(self, frame, category, collections=None):
"""Write the given frame under the passed category, optionally limiting the
collections that are written.

Args:
frame (podio.frame.Frame): The Frame to write
category (str): The category name
collections (optional, default=None): The subset of collections to
write. If None, all collections are written
"""
# pylint: disable-next=protected-access
self._writer.writeFrame(frame._frame, category, collections or frame.collections)
114 changes: 105 additions & 9 deletions python/podio/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,16 +48,38 @@ def _determine_cpp_type(idx_and_type):
SUPPORTED_PARAMETER_TYPES = _determine_supported_parameter_types()


def _get_cpp_vector_types(type_str):
"""Get the possible std::vector<cpp_type> from the passed py_type string."""
# Gather a list of all types that match the type_str (c++ or python)
def _get_cpp_types(type_str):
"""Get all possible c++ types from the passed py_type string."""
types = list(filter(lambda t: type_str in t, SUPPORTED_PARAMETER_TYPES))
tmadlener marked this conversation as resolved.
Show resolved Hide resolved
if not types:
raise ValueError(f'{type_str} cannot be mapped to a valid parameter type')

return types


def _get_cpp_vector_types(type_str):
"""Get the possible std::vector<cpp_type> from the passed py_type string."""
# Gather a list of all types that match the type_str (c++ or python)
types = _get_cpp_types(type_str)
return [f'std::vector<{t}>' for t in map(lambda x: x[0], types)]
tmadlener marked this conversation as resolved.
Show resolved Hide resolved


def _is_collection_base(thing):
"""Check whether the passed thing is a podio::CollectionBase

Args:
thing (any): any object

Returns:
bool: True if thing is a base of podio::CollectionBase, False otherwise
"""
# Make sure to only instantiate the template with things that cppyy
# understands
if "cppyy" in repr(thing):
return cppyy.gbl.std.is_base_of[cppyy.gbl.podio.CollectionBase, type(thing)].value
return False


class Frame:
"""Frame class that serves as a container of collection and meta data."""

Expand All @@ -78,17 +100,16 @@ def __init__(self, data=None):
else:
self._frame = podio.Frame()

self._collections = tuple(str(s) for s in self._frame.getAvailableCollections())
self._param_key_types = self._init_param_keys()
self._param_key_types = self._get_param_keys_types()

@property
def collections(self):
"""Get the available collection (names) from this Frame.
"""Get the currently available collection (names) from this Frame.

Returns:
tuple(str): The names of the available collections from this Frame.
"""
return self._collections
return tuple(str(s) for s in self._frame.getAvailableCollections())

def get(self, name):
"""Get a collection from the Frame by name.
Expand All @@ -107,9 +128,32 @@ def get(self, name):
raise KeyError(f"Collection '{name}' is not available")
return collection

def put(self, collection, name):
"""Put the collection into the frame

The passed collectoin is "moved" into the Frame, i.e. it cannot be used any
longer after a call to this function. This also means that only objects that
were in the collection at the time of calling this function will be
available afterwards.

Args:
collection (podio.CollectionBase): The collection to put into the Frame
name (str): The name of the collection

Returns:
podio.CollectionBase: The reference to the collection that has been put
into the Frame. NOTE: That mutating this collection is not allowed.

Raises:
ValueError: If collection is not actually a podio.CollectionBase
"""
if not _is_collection_base(collection):
raise ValueError("Can only put podio collections into a Frame")
return self._frame.put(cppyy.gbl.std.move(collection), name)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't tried this but after this using collection probably crashes right? It's not trivial to see this if you don't have the C++ background

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I just tried it and indeed it crashes, for this PR is fine but there should be something to prevent those...

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it depends on how you want to use the collection. Indexing into it will crash, because it is effectively empty. Adding new things to it might work, but I haven't tried it.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

After a bit of playing things seem to work until you try to write a collection a second time, i.e.

event = Frame()
hits = ExampleHitCollection()
# add hits
frame.put(hits, "Hits")

len(hits)  # == 0
hits.create()
len(hits)  # == 1

frame.put(hits, "moreHits")  # seg-fault

I will update the documentation that collections should not be used after calling put.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I feel like here going out of range may be more of a problem than in C++ (this is not related to this PR but just indexing out of range a std::deque):

event = Frame()
hits = ExampleHitCollection()
hits[0] # seg fault
hits[1] # <cppyy.gbl.MutableExampleHit object at 0x55dcd5d30d50>
hits[1].x() # seg fault
hits.create(0, 1, 2. 3, 4)
hits[0] # <cppyy.gbl.MutableExampleHit object at 0x55dcc4324040>
hits[1] # <cppyy.gbl.MutableExampleHit object at 0x55dcd5d572b0>
hits[2] # seg fault

It's inconsistent since the behavior is undefined, could be a bit hard to debug when you go out of range and it may seem to work but then it doesn't later on. I'm not sure if there are options, not checking in C++ I think is a design choice since there is also .at, but in python no one is going to use that probably

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah in the current form there are not that many handrails in python. Could probably be fixed by another thin python shim that introduces the range check. That could then also be used to raise an error in the use-after-move case.

Created #447 to keep track of this.


@property
def parameters(self):
"""Get the available parameter names from this Frame.
"""Get the currently available parameter names from this Frame.

Returns:
tuple (str): The names of the available parameters from this Frame.
Expand Down Expand Up @@ -163,6 +207,58 @@ def _get_param_value(par_type, name):

return _get_param_value(vec_types[0], name)

def put_parameter(self, key, value, as_type=None):
"""Put a parameter into the Frame.

Puts a parameter into the Frame after doing some (incomplete) type checks.
If a list is passed the parameter type is determined from looking at the
first element of the list only. Additionally, since python doesn't
differentiate between floats and doubles, floats will always be stored as
doubles by default, use the as_type argument to change this if necessary.

Args:
key (str): The name of the parameter
value (int, float, str or list of these): The parameter value
as_type (str, optional): Explicitly specify the type that should be used
to put the parameter into the Frame. Python types (e.g. "str") will
be converted to c++ types. This will override any automatic type
deduction that happens otherwise. Note that this will be taken at
pretty much face-value and there are only limited checks for this.

Raises:
ValueError: If a non-supported parameter type is passed
"""
# For lists we determine the c++ vector type and use that to call the
# correct template overload explicitly
if isinstance(value, (list, tuple)):
type_name = as_type or type(value[0]).__name__
vec_types = _get_cpp_vector_types(type_name)
if len(vec_types) == 0:
raise ValueError(f"Cannot put a parameter of type {type_name} into a Frame")

par_type = vec_types[0]
if isinstance(value[0], float):
# Always store floats as doubles from the python side
par_type = par_type.replace("float", "double")

self._frame.putParameter[par_type](key, value)
else:
if as_type is not None:
tmadlener marked this conversation as resolved.
Show resolved Hide resolved
cpp_types = _get_cpp_types(as_type)
if len(cpp_types) == 0:
raise ValueError(f"Cannot put a parameter of type {as_type} into a Frame")
self._frame.putParameter[cpp_types[0]](key, value)

# If we have a single integer, a std::string overload kicks in with higher
# priority than the template for some reason. So we explicitly select the
# correct template here
elif isinstance(value, int):
self._frame.putParameter["int"](key, value)
else:
self._frame.putParameter(key, value)

self._param_key_types = self._get_param_keys_types() # refresh the cache

def get_parameters(self):
"""Get the complete podio::GenericParameters object stored in this Frame.

Expand Down Expand Up @@ -200,7 +296,7 @@ def get_param_info(self, name):

return par_infos

def _init_param_keys(self):
def _get_param_keys_types(self):
"""Initialize the param keys dict for easier lookup of the available parameters.

Returns:
Expand Down
14 changes: 12 additions & 2 deletions python/podio/root_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from ROOT import podio # noqa: E402 # pylint: disable=wrong-import-position

from podio.base_reader import BaseReaderMixin # pylint: disable=wrong-import-position

Writer = podio.ROOTFrameWriter
from podio.base_writer import BaseWriterMixin # pylint: disable=wrong-import-position


class Reader(BaseReaderMixin):
Expand Down Expand Up @@ -49,3 +48,14 @@ def __init__(self, filenames):
self._is_legacy = True

super().__init__()


class Writer(BaseWriterMixin):
"""Writer class for writing podio root files"""
def __init__(self, filename):
"""Create a writer for writing files

Args:
filename (str): The name of the output file
"""
self._writer = podio.ROOTFrameWriter(filename)
14 changes: 12 additions & 2 deletions python/podio/sio_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from ROOT import podio # noqa: 402 # pylint: disable=wrong-import-position

from podio.base_reader import BaseReaderMixin # pylint: disable=wrong-import-position

Writer = podio.SIOFrameWriter
from podio.base_writer import BaseWriterMixin # pylint: disable=wrong-import-position


class Reader(BaseReaderMixin):
Expand Down Expand Up @@ -46,3 +45,14 @@ def __init__(self, filename):
self._is_legacy = True

super().__init__()


class Writer(BaseWriterMixin):
"""Writer class for writing podio root files"""
def __init__(self, filename):
"""Create a writer for writing files

Args:
filename (str): The name of the output file
"""
self._writer = podio.SIOFrameWriter(filename)
59 changes: 59 additions & 0 deletions python/podio/test_Frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
# using root_io as that should always be present regardless of which backends are built
from podio.root_io import Reader

from podio.test_utils import ExampleHitCollection

# The expected collections in each frame
EXPECTED_COLL_NAMES = {
'arrays', 'WithVectorMember', 'info', 'fixedWidthInts', 'mcparticles',
Expand Down Expand Up @@ -34,6 +36,63 @@ def test_frame_invalid_access(self):
with self.assertRaises(KeyError):
_ = frame.get_parameter('NonExistantParameter')

with self.assertRaises(ValueError):
collection = [1, 2, 4]
_ = frame.put(collection, "invalid_collection_type")

def test_frame_put_collection(self):
"""Check that putting a collection works as expected"""
frame = Frame()
self.assertEqual(frame.collections, tuple())

hits = ExampleHitCollection()
hits.create()
hits2 = frame.put(hits, "hits_from_python")
self.assertEqual(frame.collections, tuple(["hits_from_python"]))
# The original collection is gone at this point, and ideally just leaves an
# empty shell
self.assertEqual(len(hits), 0)
# On the other hand the return value of put has the original content
self.assertEqual(len(hits2), 1)

def test_frame_put_parameters(self):
"""Check that putting a parameter works as expected"""
frame = Frame()
self.assertEqual(frame.parameters, tuple())

frame.put_parameter("a_string_param", "a string")
self.assertEqual(frame.parameters, tuple(["a_string_param"]))
self.assertEqual(frame.get_parameter("a_string_param"), "a string")

frame.put_parameter("float_param", 3.14)
self.assertEqual(frame.get_parameter("float_param"), 3.14)

frame.put_parameter("int", 42)
self.assertEqual(frame.get_parameter("int"), 42)

frame.put_parameter("string_vec", ["a", "b", "cd"])
str_vec = frame.get_parameter("string_vec")
self.assertEqual(len(str_vec), 3)
self.assertEqual(str_vec, ["a", "b", "cd"])

frame.put_parameter("more_ints", [1, 2345])
int_vec = frame.get_parameter("more_ints")
self.assertEqual(len(int_vec), 2)
self.assertEqual(int_vec, [1, 2345])

frame.put_parameter("float_vec", [1.23, 4.56, 7.89])
vec = frame.get_parameter("float_vec", as_type="double")
self.assertEqual(len(vec), 3)
self.assertEqual(vec, [1.23, 4.56, 7.89])

frame.put_parameter("real_float_vec", [1.23, 4.56, 7.89], as_type="float")
f_vec = frame.get_parameter("real_float_vec", as_type="float")
self.assertEqual(len(f_vec), 3)
self.assertEqual(vec, [1.23, 4.56, 7.89])

frame.put_parameter("float_as_float", 3.14, as_type="float")
self.assertAlmostEqual(frame.get_parameter("float_as_float"), 3.14, places=5)


class FrameReadTest(unittest.TestCase):
"""Unit tests for the Frame python bindings for Frames read from file.
Expand Down
54 changes: 53 additions & 1 deletion python/podio/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,57 @@
"""Utilities for python unittests"""

import os
import ROOT
ROOT.gSystem.Load("libTestDataModelDict.so") # noqa: E402
from ROOT import ExampleHitCollection, ExampleClusterCollection # noqa: E402 # pylint: disable=wrong-import-position

SKIP_SIO_TESTS = os.environ.get('SKIP_SIO_TESTS', '1') == '1'
from podio.frame import Frame # pylint: disable=wrong-import-position


SKIP_SIO_TESTS = os.environ.get("SKIP_SIO_TESTS", "1") == "1"


def create_hit_collection():
"""Create a simple hit collection with two hits for testing"""
hits = ExampleHitCollection()
hits.create(0xBAD, 0.0, 0.0, 0.0, 23.0)
hits.create(0xCAFFEE, 1.0, 0.0, 0.0, 12.0)

return hits


def create_cluster_collection():
"""Create a simple cluster collection with two clusters"""
clusters = ExampleClusterCollection()
clu0 = clusters.create()
clu0.energy(3.14)
clu1 = clusters.create()
clu1.energy(1.23)

return clusters


def create_frame():
"""Create a frame with an ExampleHit and an ExampleCluster collection"""
frame = Frame()
hits = create_hit_collection()
frame.put(hits, "hits_from_python")
clusters = create_cluster_collection()
frame.put(clusters, "clusters_from_python")

frame.put_parameter("an_int", 42)
frame.put_parameter("some_floats", [1.23, 7.89, 3.14])
frame.put_parameter("greetings", ["from", "python"])
frame.put_parameter("real_float", 3.14, as_type="float")
frame.put_parameter("more_real_floats", [1.23, 4.56, 7.89], as_type="float")

return frame


def write_file(writer_type, filename):
"""Write a file using the given Writer type and put one Frame into it under
the events category
"""
writer = writer_type(filename)
event = create_frame()
writer.write_frame(event, "events")
4 changes: 4 additions & 0 deletions tests/CTestCustom.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ if ((NOT "@FORCE_RUN_ALL_TESTS@" STREQUAL "ON") AND (NOT "@USE_SANITIZER@" STREQ
read-legacy-files-root_v00-13
read_frame_legacy_root
read_frame_root_multiple
write_python_frame_root
read_python_frame_root

write_frame_root
read_frame_root
Expand All @@ -35,6 +37,8 @@ if ((NOT "@FORCE_RUN_ALL_TESTS@" STREQUAL "ON") AND (NOT "@USE_SANITIZER@" STREQ
write_frame_sio
read_frame_sio
read_frame_legacy_sio
write_python_frame_sio
read_python_frame_sio

write_ascii

Expand Down
Loading
Loading