fix: catch all exceptions in forked process
If a BaseException is raised in the forked process, the communication
queue blocks in the calling process and causes timeouts. This is
prevalent in the protobuf serialization unit tests in GitHub Actions,
where it is triggered by the protobuf compiler work directory already
existing.
The fix is to create the directory with pathlib.Path.mkdir(exist_ok=True,
parents=True), which removes the FileExistsError, and to change the
error handling logic to pass BaseExceptions back through the
multiprocessing queue as well.
jjaakola-aiven authored and eliax1996 committed Jul 24, 2024
1 parent 83e329a commit 8f0a350
Showing 1 changed file with 61 additions and 50 deletions.
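
In essence the commit moves to the pattern sketched below: the forked worker catches BaseException and passes it back over the multiprocessing queue, and the calling process reads with a bounded timeout, so a dying child can no longer leave the parent blocked forever. The sketch uses illustrative names (_do_work, _worker, run_isolated) and is not the Karapace code shown in the diff below.

from multiprocessing import Process, Queue


def _do_work() -> dict:
    # Stand-in for the real work, e.g. protobuf deserialization.
    return {"ok": True}


def _worker(queue: "Queue[object]") -> None:
    try:
        queue.put(_do_work())
    # Deliberately broad: even SystemExit or KeyboardInterrupt in the child
    # must reach the parent instead of leaving it blocked on queue.get().
    except BaseException as exc:  # pylint: disable=broad-except
        queue.put(exc)


def run_isolated(timeout_seconds: int = 10) -> object:
    queue: "Queue[object]" = Queue()
    process = Process(target=_worker, args=(queue,))
    process.start()
    try:
        # A bounded get() means the parent cannot hang if the child dies silently.
        result = queue.get(True, timeout_seconds)
    finally:
        process.join()
        queue.close()
    if isinstance(result, BaseException):
        raise result
    return result


if __name__ == "__main__":
    print(run_isolated())

The diff applies the same idea separately to the reader and writer paths.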
111 changes: 61 additions & 50 deletions karapace/protobuf/io.py
@@ -13,13 +13,13 @@
 from karapace.protobuf.schema import ProtobufSchema
 from karapace.protobuf.type_element import TypeElement
 from multiprocessing import Process, Queue
+from pathlib import Path
 from typing import Dict, Final, Generator, Iterable, Protocol
 from typing_extensions import Self, TypeAlias

 import hashlib
 import importlib
 import importlib.util
-import os
 import subprocess
 import sys
@@ -96,29 +96,27 @@ def get_protobuf_class_instance(
     class_name: str,
     cfg: Config,
 ) -> _ProtobufModel:
-    directory = cfg["protobuf_runtime_directory"]
+    directory = Path(cfg["protobuf_runtime_directory"])
     deps_list = crawl_dependencies(schema)
     root_class_name = ""
     for value in deps_list.values():
         root_class_name = root_class_name + value["unique_class_name"]
     root_class_name = root_class_name + str(schema)
     proto_name = calculate_class_name(root_class_name)

-    proto_path = f"{proto_name}.proto"
-    work_dir = f"{directory}/{proto_name}"
-    if not os.path.isdir(directory):
-        os.mkdir(directory)
-    if not os.path.isdir(work_dir):
-        os.mkdir(work_dir)
-    class_path = f"{directory}/{proto_name}/{proto_name}_pb2.py"
-    if not os.path.exists(class_path):
+    main_proto_filename = f"{proto_name}.proto"
+    work_dir = directory / Path(proto_name)
+    work_dir.mkdir(exist_ok=True, parents=True)
+    class_path = work_dir / Path(f"{proto_name}_pb2.py")
+
+    if not class_path.exists():
         with open(f"{directory}/{proto_name}/{proto_name}.proto", mode="w", encoding="utf8") as proto_text:
             proto_text.write(replace_imports(str(schema), deps_list))

         protoc_arguments = [
             "protoc",
             "--python_out=./",
-            proto_path,
+            main_proto_filename,
         ]
         for value in deps_list.values():
             proto_file_name = value["unique_class_name"] + ".proto"
@@ -127,16 +125,18 @@ def get_protobuf_class_instance(
             with open(dependency_path, mode="w", encoding="utf8") as proto_text:
                 proto_text.write(replace_imports(value["schema"], deps_list))

-        if not os.path.isfile(class_path):
+        if not class_path.is_file():
             subprocess.run(
                 protoc_arguments,
                 check=True,
                 cwd=work_dir,
             )

-    # todo: This will leave residues on sys.path in case of exceptions. If really must
-    #   mutate sys.path, we should at least wrap in try-finally.
-    sys.path.append(f"./runtime/{proto_name}")
+    runtime_proto_path = f"./runtime/{proto_name}"
+    if runtime_proto_path not in sys.path:
+        # todo: This will leave residues on sys.path in case of exceptions. If really must
+        #   mutate sys.path, we should at least wrap in try-finally.
+        sys.path.append(runtime_proto_path)
     spec = importlib.util.spec_from_file_location(f"{proto_name}_pb2", class_path)
     # This is reasonable to assert because we just created this file.
     assert spec is not None
@@ -168,25 +168,25 @@ def read_data(
     return class_instance


-_ReaderQueue: TypeAlias = "Queue[dict[object, object] | Exception]"
+_ReaderQueue: TypeAlias = "Queue[dict[object, object] | BaseException]"


 def reader_process(
-    queue: _ReaderQueue,
+    reader_queue: _ReaderQueue,
     config: Config,
     writer_schema: ProtobufSchema,
     reader_schema: ProtobufSchema,
     bio: BytesIO,
 ) -> None:
     try:
-        queue.put(protobuf_to_dict(read_data(config, writer_schema, reader_schema, bio), True))
-    # todo: This lint ignore does not look reasonable. If it is, reasoning should be
-    #   documented.
-    except Exception as e:  # pylint: disable=broad-except
-        queue.put(e)
+        reader_queue.put(protobuf_to_dict(read_data(config, writer_schema, reader_schema, bio), True))
+    # Reading happens in the forked process, catch is broad so exception will get communicated
+    # back to calling process.
+    except BaseException as base_exception:  # pylint: disable=broad-except
+        reader_queue.put(base_exception)


-def reader_mp(
+def read_in_forked_multiprocess_process(
     config: Config,
     writer_schema: ProtobufSchema,
     reader_schema: ProtobufSchema,
@@ -200,14 +200,18 @@ def reader_mp(
     # To avoid problem with enum values for basic SerDe support we
     # will isolate work with call protobuf libraries in child process.
     if __name__ == "karapace.protobuf.io":
-        queue: _ReaderQueue = Queue()
-        p = Process(target=reader_process, args=(queue, config, writer_schema, reader_schema, bio))
+        reader_queue: _ReaderQueue = Queue()
+        p = Process(target=reader_process, args=(reader_queue, config, writer_schema, reader_schema, bio))
         p.start()
-        result = queue.get()
-        p.join()
+        TEN_SECONDS_WAIT = 10
+        try:
+            result = reader_queue.get(True, TEN_SECONDS_WAIT)
+        finally:
+            p.join()
+            reader_queue.close()
         if isinstance(result, Dict):
             return result
-        if isinstance(result, Exception):
+        if isinstance(result, BaseException):
             raise result
         raise IllegalArgumentException()
     return {"Error": "This never must be returned"}
@@ -233,34 +237,37 @@ def __init__(
     def read(self, bio: BytesIO) -> dict:
         if self._reader_schema is None:
             self._reader_schema = self._writer_schema
-        return reader_mp(self.config, self._writer_schema, self._reader_schema, bio)
+        return read_in_forked_multiprocess_process(self.config, self._writer_schema, self._reader_schema, bio)


-_WriterQueue: TypeAlias = "Queue[bytes | Exception]"
+_WriterQueue: TypeAlias = "Queue[bytes | str | BaseException]"


 def writer_process(
-    queue: _WriterQueue,
+    writer_queue: _WriterQueue,
     config: Config,
     writer_schema: ProtobufSchema,
     message_name: str,
     datum: dict,
 ) -> None:
-    class_instance = get_protobuf_class_instance(writer_schema, message_name, config)
     try:
+        class_instance = get_protobuf_class_instance(writer_schema, message_name, config)
         dict_to_protobuf(class_instance, datum)
-    # todo: This does not look like a reasonable place to catch any exception,
-    #   especially since we're effectively silencing them.
-    except Exception:
-        # pylint: disable=raise-missing-from
-        e = ProtobufTypeException(writer_schema, datum)
-        queue.put(e)
-        raise e
-    queue.put(class_instance.SerializeToString())
-
-
-# todo: What is mp? Expand the abbreviation or add an explaining comment.
-def writer_mp(
+        result = class_instance.SerializeToString()
+        writer_queue.put(result)
+    # Writing happens in the forked process, catch is broad so exception will get communicated
+    # back to calling process.
+    except Exception as bare_exception:  # pylint: disable=broad-exception-caught
+        try:
+            raise ProtobufTypeException(writer_schema, datum) from bare_exception
+        except ProtobufTypeException as protobuf_exception:
+            writer_queue.put(protobuf_exception)
+            raise protobuf_exception
+    except BaseException as base_exception:  # pylint: disable=broad-exception-caught
+        writer_queue.put(base_exception)
+
+
+def write_in_forked_multiprocess_process(
     config: Config,
     writer_schema: ProtobufSchema,
     message_name: str,
@@ -274,14 +281,18 @@ def writer_mp(
     # To avoid problem with enum values for basic SerDe support we
     # will isolate work with call protobuf libraries in child process.
    if __name__ == "karapace.protobuf.io":
-        queue: _WriterQueue = Queue()
-        p = Process(target=writer_process, args=(queue, config, writer_schema, message_name, datum))
+        writer_queue: _WriterQueue = Queue(1)
+        p = Process(target=writer_process, args=(writer_queue, config, writer_schema, message_name, datum))
         p.start()
-        result = queue.get()
-        p.join()
+        TEN_SECONDS_WAIT = 10
+        try:
+            result = writer_queue.get(True, TEN_SECONDS_WAIT)  # Block for ten seconds
+        finally:
+            p.join()
+            writer_queue.close()
         if isinstance(result, bytes):
             return result
-        if isinstance(result, Exception):
+        if isinstance(result, BaseException):
             raise result
         raise IllegalArgumentException()
     raise NotImplementedError("Error: Reached unreachable code")
@@ -309,4 +320,4 @@ def write_index(self, writer: BytesIO) -> None:
         write_indexes(writer, [self._message_index])

     def write(self, datum: dict[object, object], writer: BytesIO) -> None:
-        writer.write(writer_mp(self.config, self._writer_schema, self._message_name, datum))
+        writer.write(write_in_forked_multiprocess_process(self.config, self._writer_schema, self._message_name, datum))
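
For completeness, the directory-handling half of the fix from get_protobuf_class_instance, reduced to a standalone sketch (ensure_work_dir is an illustrative helper name, not part of the diff): Path.mkdir(parents=True, exist_ok=True) replaces the check-then-create sequence, so an already existing protobuf compiler work directory no longer raises FileExistsError.

from pathlib import Path


def ensure_work_dir(runtime_directory: str, proto_name: str) -> Path:
    # Creates missing parent directories and tolerates an existing target,
    # unlike the old isdir()/mkdir() sequence, which could race and fail.
    work_dir = Path(runtime_directory) / proto_name
    work_dir.mkdir(parents=True, exist_ok=True)
    return work_dir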
