diff --git a/CHANGELOG.md b/CHANGELOG.md
index 5f4c305ae74..55dd541b19d 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -16,6 +16,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   ([#2153](https://github.com/open-telemetry/opentelemetry-python/pull/2153))
 - Add metrics API
   ([#1887](https://github.com/open-telemetry/opentelemetry-python/pull/1887))
+- Make batch processor fork aware and reinit when needed
+  ([#2242](https://github.com/open-telemetry/opentelemetry-python/pull/2242))

 ## [1.6.2-0.25b2](https://github.com/open-telemetry/opentelemetry-python/releases/tag/v1.6.2-0.25b2) - 2021-10-19

diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/_logs/export/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/_logs/export/__init__.py
index f65c967534b..c705c2b2497 100644
--- a/opentelemetry-sdk/src/opentelemetry/sdk/_logs/export/__init__.py
+++ b/opentelemetry-sdk/src/opentelemetry/sdk/_logs/export/__init__.py
@@ -16,6 +16,7 @@
 import collections
 import enum
 import logging
+import os
 import sys
 import threading
 from os import linesep
@@ -154,6 +155,17 @@ def __init__(
             None
         ] * self._max_export_batch_size  # type: List[Optional[LogData]]
         self._worker_thread.start()
+        # Only available in *nix since py37.
+        if hasattr(os, "register_at_fork"):
+            os.register_at_fork(
+                after_in_child=self._at_fork_reinit
+            )  # pylint: disable=protected-access
+
+    def _at_fork_reinit(self):
+        self._condition = threading.Condition(threading.Lock())
+        self._queue.clear()
+        self._worker_thread = threading.Thread(target=self.worker, daemon=True)
+        self._worker_thread.start()

     def worker(self):
         timeout = self._schedule_delay_millis / 1e3
diff --git a/opentelemetry-sdk/src/opentelemetry/sdk/trace/export/__init__.py b/opentelemetry-sdk/src/opentelemetry/sdk/trace/export/__init__.py
index 4f0cc817c9f..d40bb4968c0 100644
--- a/opentelemetry-sdk/src/opentelemetry/sdk/trace/export/__init__.py
+++ b/opentelemetry-sdk/src/opentelemetry/sdk/trace/export/__init__.py
@@ -14,6 +14,7 @@

 import collections
 import logging
+import os
 import sys
 import threading
 import typing
@@ -197,6 +198,11 @@ def __init__(
             None
         ] * self.max_export_batch_size  # type: typing.List[typing.Optional[Span]]
         self.worker_thread.start()
+        # Only available in *nix since py37.
+        if hasattr(os, "register_at_fork"):
+            os.register_at_fork(
+                after_in_child=self._at_fork_reinit
+            )  # pylint: disable=protected-access

     def on_start(
         self, span: Span, parent_context: typing.Optional[Context] = None
@@ -220,6 +226,17 @@ def on_end(self, span: ReadableSpan) -> None:
             with self.condition:
                 self.condition.notify()

+    def _at_fork_reinit(self):
+        self.condition = threading.Condition(threading.Lock())
+        self.queue.clear()
+
+        # worker_thread is local to a process, only the thread that issued fork continues
+        # to exist. A new worker thread must be started in child process.
+        self.worker_thread = threading.Thread(
+            name="OtelBatchSpanProcessor", target=self.worker, daemon=True
+        )
+        self.worker_thread.start()
+
     def worker(self):
         timeout = self.schedule_delay_millis / 1e3
         flush_request = None  # type: typing.Optional[_FlushRequest]
diff --git a/opentelemetry-sdk/tests/logs/test_export.py b/opentelemetry-sdk/tests/logs/test_export.py
index 964b44f7694..45b83358f93 100644
--- a/opentelemetry-sdk/tests/logs/test_export.py
+++ b/opentelemetry-sdk/tests/logs/test_export.py
@@ -14,7 +14,9 @@

 # pylint: disable=protected-access
 import logging
+import multiprocessing
 import os
+import sys
 import time
 import unittest
 from concurrent.futures import ThreadPoolExecutor
@@ -38,6 +40,7 @@
 from opentelemetry.sdk._logs.severity import SeverityNumber
 from opentelemetry.sdk.resources import Resource as SDKResource
 from opentelemetry.sdk.util.instrumentation import InstrumentationInfo
+from opentelemetry.test.concurrency_test import ConcurrencyTestBase
 from opentelemetry.trace import TraceFlags
 from opentelemetry.trace.span import INVALID_SPAN_CONTEXT

@@ -158,7 +161,7 @@ def test_simple_log_processor_shutdown(self):
         self.assertEqual(len(finished_logs), 0)


-class TestBatchLogProcessor(unittest.TestCase):
+class TestBatchLogProcessor(ConcurrencyTestBase):
     def test_emit_call_log_record(self):
         exporter = InMemoryLogExporter()
         log_processor = Mock(wraps=BatchLogProcessor(exporter))
@@ -269,6 +272,54 @@ def bulk_log_and_flush(num_logs):
         finished_logs = exporter.get_finished_logs()
         self.assertEqual(len(finished_logs), 2415)

+    @unittest.skipUnless(
+        hasattr(os, "fork") and sys.version_info >= (3, 7),
+        "needs *nix and minor version 7 or later",
+    )
+    def test_batch_log_processor_fork(self):
+        # pylint: disable=invalid-name
+        exporter = InMemoryLogExporter()
+        log_processor = BatchLogProcessor(
+            exporter,
+            max_export_batch_size=64,
+            schedule_delay_millis=10,
+        )
+        provider = LogEmitterProvider()
+        provider.add_log_processor(log_processor)
+
+        emitter = provider.get_log_emitter(__name__)
+        logger = logging.getLogger("test-fork")
+        logger.addHandler(OTLPHandler(log_emitter=emitter))
+
+        logger.critical("yolo")
+        time.sleep(0.5)  # give some time for the exporter to upload
+
+        self.assertTrue(log_processor.force_flush())
+        self.assertEqual(len(exporter.get_finished_logs()), 1)
+        exporter.clear()
+
+        multiprocessing.set_start_method("fork")
+
+        def child(conn):
+            def _target():
+                logger.critical("Critical message child")
+
+            self.run_with_many_threads(_target, 100)
+
+            time.sleep(0.5)
+
+            logs = exporter.get_finished_logs()
+            conn.send(len(logs) == 100)
+            conn.close()
+
+        parent_conn, child_conn = multiprocessing.Pipe()
+        p = multiprocessing.Process(target=child, args=(child_conn,))
+        p.start()
+        self.assertTrue(parent_conn.recv())
+        p.join()
+
+        log_processor.shutdown()
+

 class TestConsoleExporter(unittest.TestCase):
     def test_export(self):  # pylint: disable=no-self-use
diff --git a/opentelemetry-sdk/tests/trace/export/test_export.py b/opentelemetry-sdk/tests/trace/export/test_export.py
index 2e4672af268..00ccfe44d38 100644
--- a/opentelemetry-sdk/tests/trace/export/test_export.py
+++ b/opentelemetry-sdk/tests/trace/export/test_export.py
@@ -12,7 +12,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import multiprocessing
 import os
+import sys
 import threading
 import time
 import unittest
@@ -30,6 +32,10 @@
     OTEL_BSP_SCHEDULE_DELAY,
 )
 from opentelemetry.sdk.trace import export
+from opentelemetry.sdk.trace.export.in_memory_span_exporter import (
+    InMemorySpanExporter,
+)
+from opentelemetry.test.concurrency_test import ConcurrencyTestBase


 class MySpanExporter(export.SpanExporter):
@@ -157,7 +163,7 @@ def _create_start_and_end_span(name, span_processor):
     span.end()


-class TestBatchSpanProcessor(unittest.TestCase):
+class TestBatchSpanProcessor(ConcurrencyTestBase):
     @mock.patch.dict(
         "os.environ",
         {
@@ -356,6 +362,60 @@ def test_batch_span_processor_not_sampled(self):
         self.assertEqual(len(spans_names_list), 0)
         span_processor.shutdown()

+    def _check_fork_trace(self, exporter, expected):
+        time.sleep(0.5)  # give some time for the exporter to upload spans
+        spans = exporter.get_finished_spans()
+        for span in spans:
+            self.assertIn(span.name, expected)
+
+    @unittest.skipUnless(
+        hasattr(os, "fork") and sys.version_info >= (3, 7),
+        "needs *nix and minor version 7 or later",
+    )
+    def test_batch_span_processor_fork(self):
+        # pylint: disable=invalid-name
+        tracer_provider = trace.TracerProvider()
+        tracer = tracer_provider.get_tracer(__name__)
+
+        exporter = InMemorySpanExporter()
+        span_processor = export.BatchSpanProcessor(
+            exporter,
+            max_queue_size=256,
+            max_export_batch_size=64,
+            schedule_delay_millis=10,
+        )
+        tracer_provider.add_span_processor(span_processor)
+        with tracer.start_as_current_span("foo"):
+            pass
+        time.sleep(0.5)  # give some time for the exporter to upload spans
+
+        self.assertTrue(span_processor.force_flush())
+        self.assertEqual(len(exporter.get_finished_spans()), 1)
+        exporter.clear()
+
+        def child(conn):
+            def _target():
+                with tracer.start_as_current_span("span") as s:
+                    s.set_attribute("i", "1")
+                    with tracer.start_as_current_span("temp"):
+                        pass
+
+            self.run_with_many_threads(_target, 100)
+
+            time.sleep(0.5)
+
+            spans = exporter.get_finished_spans()
+            conn.send(len(spans) == 200)
+            conn.close()
+
+        parent_conn, child_conn = multiprocessing.Pipe()
+        p = multiprocessing.Process(target=child, args=(child_conn,))
+        p.start()
+        self.assertTrue(parent_conn.recv())
+        p.join()
+
+        span_processor.shutdown()
+
     def test_batch_span_processor_scheduled_delay(self):
         """Test that spans are exported each schedule_delay_millis"""
         spans_names_list = []
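
Note for reviewers: the pattern this patch relies on is os.register_at_fork(after_in_child=...). After a fork() only the thread that called fork() exists in the child, and any lock another thread held at that moment stays acquired forever, so the processor rebuilds its condition variable and restarts its daemon worker thread in the child. The sketch below is a minimal, self-contained illustration of that pattern; it is not part of the diff, the ForkAwareWorker class and all of its names are hypothetical, and it assumes CPython 3.7+ on a *nix platform.

# Hypothetical, standalone illustration of the fork-reinit pattern used by the
# patch above; it only loosely mirrors BatchSpanProcessor/BatchLogProcessor.
import os
import threading


class ForkAwareWorker:
    def __init__(self):
        self.condition = threading.Condition(threading.Lock())
        self.worker_thread = threading.Thread(target=self._work, daemon=True)
        self.worker_thread.start()
        # os.register_at_fork exists only on *nix with Python 3.7+.
        if hasattr(os, "register_at_fork"):
            os.register_at_fork(after_in_child=self._at_fork_reinit)

    def _at_fork_reinit(self):
        # The child inherits only the forking thread, and an inherited lock may
        # be stuck in the acquired state, so rebuild both after fork().
        self.condition = threading.Condition(threading.Lock())
        self.worker_thread = threading.Thread(target=self._work, daemon=True)
        self.worker_thread.start()

    def _work(self):
        # Stand-in for the batch processor's export loop: wait until notified.
        with self.condition:
            self.condition.wait()


if __name__ == "__main__" and hasattr(os, "fork"):
    worker = ForkAwareWorker()
    pid = os.fork()
    if pid == 0:
        # In the child, the after_in_child hook has already rebuilt the lock
        # and restarted the worker thread.
        print("child worker alive:", worker.worker_thread.is_alive())
        os._exit(0)
    os.waitpid(pid, 0)
    print("parent worker alive:", worker.worker_thread.is_alive())

Run on *nix, the sketch reports a live worker thread in both the parent and the forked child, which is the same property the new test_batch_span_processor_fork and test_batch_log_processor_fork tests verify indirectly by asserting that spans and logs are still exported from the child process.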