Skip to content

Commit

Permalink
Fix CR comments and lint test failures
Browse files Browse the repository at this point in the history
  • Loading branch information
oxeye-nikolay committed Sep 19, 2021
1 parent 12a4746 commit fedd768
Show file tree
Hide file tree
Showing 7 changed files with 71 additions and 73 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- `opentelemetry-sdk-extension-aws` Add AWS resource detectors to extension package
([#586](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/586))
- `opentelemetry-instrumentation-asgi`, `opentelemetry-instrumentation-aiohttp-client`, `openetelemetry-instrumentation-fastapi`,
`opentelemetry-instrumentation-starlette`, `opentelemetry-instrumentation-urllib`, `opentelemetry-instrumentation-urllib3` Added `request_hook` and `response_hook` callbacks
`opentelemetry-instrumentation-starlette`, `opentelemetry-instrumentation-urllib`, `opentelemetry-instrumentation-urllib3`,
`opentelemetry-instrumentation-pika` Added `request_hook` and `response_hook` callbacks
([#576](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/576))

### Changed
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,6 @@
---
"""


# pylint: disable=unused-argument
from .pika_instrumentor import PikaInstrumentor
from .version import __version__
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from logging import getLogger
from typing import Any, Callable, Collection, Dict, Optional

from pika.adapters import BlockingConnection
from pika.channel import Channel

from opentelemetry import trace
Expand All @@ -24,25 +25,28 @@
from opentelemetry.trace import Tracer, TracerProvider

_LOG = getLogger(__name__)
CTX_KEY = "__otel_task_span"
_CTX_KEY = "__otel_task_span"

FUNCTIONS_TO_UNINSTRUMENT = ["basic_publish"]
_FUNCTIONS_TO_UNINSTRUMENT = ["basic_publish"]


class PikaInstrumentor(BaseInstrumentor): # type: ignore
# pylint: disable=attribute-defined-outside-init
@staticmethod
def _instrument_consumers(
consumers_dict: Dict[str, Callable[..., Any]], tracer: Tracer
) -> Any:
for key, callback in consumers_dict.items():
decorated_callback = utils.decorate_callback(callback, tracer, key)
decorated_callback = utils._decorate_callback(
callback, tracer, key
)
setattr(decorated_callback, "_original_callback", callback)
consumers_dict[key] = decorated_callback

@staticmethod
def _instrument_basic_publish(channel: Channel, tracer: Tracer) -> None:
original_function = getattr(channel, "basic_publish")
decorated_function = utils.decorate_basic_publish(
decorated_function = utils._decorate_basic_publish(
original_function, channel, tracer
)
setattr(decorated_function, "_original_function", original_function)
Expand All @@ -58,7 +62,7 @@ def _instrument_channel_functions(

@staticmethod
def _uninstrument_channel_functions(channel: Channel) -> None:
for function_name in FUNCTIONS_TO_UNINSTRUMENT:
for function_name in _FUNCTIONS_TO_UNINSTRUMENT:
if not hasattr(channel, function_name):
continue
function = getattr(channel, function_name)
Expand All @@ -69,30 +73,19 @@ def _uninstrument_channel_functions(channel: Channel) -> None:
def instrument_channel(
channel: Channel, tracer_provider: Optional[TracerProvider] = None,
) -> None:
tracer = trace.get_tracer(__name__, __version__, tracer_provider)
channel.__setattr__("__opentelemetry_tracer", tracer)
if not hasattr(channel, "_impl"):
_LOG.error("Could not find implementation for provided channel!")
return
tracer = trace.get_tracer(__name__, __version__, tracer_provider)
channel.__setattr__("__opentelemetry_tracer", tracer)
if channel._impl._consumers:
PikaInstrumentor._instrument_consumers(
channel._impl._consumers, tracer
)
PikaInstrumentor._instrument_channel_functions(channel, tracer)

def _instrument(self, **kwargs: Dict[str, Any]) -> None:
channel: Channel = kwargs.get("channel", None)
if not channel or not isinstance(channel, Channel):
return
tracer_provider: TracerProvider = kwargs.get("tracer_provider", None)
PikaInstrumentor.instrument_channel(
channel, tracer_provider=tracer_provider
)

def _uninstrument(self, **kwargs: Dict[str, Any]) -> None:
channel: Channel = kwargs.get("channel", None)
if not channel or not isinstance(channel, Channel):
return
@staticmethod
def uninstrument_channel(channel: Channel) -> None:
if not hasattr(channel, "_impl"):
_LOG.error("Could not find implementation for provided channel!")
return
Expand All @@ -101,5 +94,24 @@ def _uninstrument(self, **kwargs: Dict[str, Any]) -> None:
channel._impl._consumers[key] = callback._original_callback
PikaInstrumentor._uninstrument_channel_functions(channel)

def _decorate_channel_function(
self, tracer_provider: Optional[TracerProvider]
) -> None:
self.original_channel_func = BlockingConnection.channel

def _wrapper(*args, **kwargs):
channel = self.original_channel_func(*args, **kwargs)
self.instrument_channel(channel, tracer_provider=tracer_provider)
return channel

BlockingConnection.channel = _wrapper

def _instrument(self, **kwargs: Dict[str, Any]) -> None:
tracer_provider: TracerProvider = kwargs.get("tracer_provider", None)
self._decorate_channel_function(tracer_provider)

def _uninstrument(self, **kwargs: Dict[str, Any]) -> None:
BlockingConnection.channel = self.original_channel_func

def instrumentation_dependencies(self) -> Collection[str]:
return _instruments
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from opentelemetry.trace.span import Span


class PikaGetter(Getter): # type: ignore
class _PikaGetter(Getter): # type: ignore
def get(self, carrier: CarrierT, key: str) -> Optional[List[str]]:
value = carrier.get(key, None)
if value is None:
Expand All @@ -25,10 +25,10 @@ def keys(self, carrier: CarrierT) -> List[str]:
return []


pika_getter = PikaGetter()
_pika_getter = _PikaGetter()


def decorate_callback(
def _decorate_callback(
callback: Callable[[Channel, Basic.Deliver, BasicProperties, bytes], Any],
tracer: Tracer,
task_name: str,
Expand All @@ -41,7 +41,7 @@ def decorated_callback(
) -> Any:
if not properties:
properties = BasicProperties()
span = get_span(
span = _get_span(
tracer,
channel,
properties,
Expand All @@ -56,7 +56,7 @@ def decorated_callback(
return decorated_callback


def decorate_basic_publish(
def _decorate_basic_publish(
original_function: Callable[[str, str, bytes, BasicProperties, bool], Any],
channel: Channel,
tracer: Tracer,
Expand All @@ -70,7 +70,7 @@ def decorated_function(
) -> Any:
if not properties:
properties = BasicProperties()
span = get_span(
span = _get_span(
tracer,
channel,
properties,
Expand All @@ -92,7 +92,7 @@ def decorated_function(
return decorated_function


def get_span(
def _get_span(
tracer: Tracer,
channel: Channel,
properties: BasicProperties,
Expand All @@ -101,29 +101,29 @@ def get_span(
) -> Optional[Span]:
if properties.headers is None:
properties.headers = {}
ctx = propagate.extract(properties.headers, getter=pika_getter)
ctx = propagate.extract(properties.headers, getter=_pika_getter)
if context.get_value("suppress_instrumentation") or context.get_value(
_SUPPRESS_INSTRUMENTATION_KEY
):
print("Suppressing instrumentation!")
return None
task_name = properties.type if properties.type else task_name
span = tracer.start_span(
context=ctx, name=generate_span_name(task_name, operation)
context=ctx, name=_generate_span_name(task_name, operation)
)
enrich_span(span, channel, properties, task_name, operation)
_enrich_span(span, channel, properties, task_name, operation)
return span


def generate_span_name(
def _generate_span_name(
task_name: str, operation: Optional[MessagingOperationValues]
) -> str:
if not operation:
return f"{task_name} send"
return f"{task_name} {operation.value}"


def enrich_span(
def _enrich_span(
span: Span,
channel: Channel,
properties: BasicProperties,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,12 @@
# limitations under the License.
from unittest import TestCase

from opentelemetry.instrumentation.pika.utils import PikaGetter
from opentelemetry.instrumentation.pika.utils import _PikaGetter


class TestPikaGetter(TestCase):
def setUp(self) -> None:
self.getter = PikaGetter()
self.getter = _PikaGetter()

def test_get_none(self) -> None:
carrier = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.
from unittest import TestCase, mock

from pika.adapters import BaseConnection
from pika.adapters import BaseConnection, BlockingConnection
from pika.channel import Channel

from opentelemetry.instrumentation.pika import PikaInstrumentor
Expand All @@ -27,31 +27,16 @@ def setUp(self) -> None:
self.mock_callback = mock.MagicMock()
self.channel._impl._consumers = {"mock_key": self.mock_callback}

@mock.patch(
"opentelemetry.instrumentation.pika.PikaInstrumentor.instrument_channel"
)
@mock.patch(
"opentelemetry.instrumentation.pika.PikaInstrumentor._uninstrument_channel_functions"
)
def test_instrument_api(
self,
uninstrument_channel_functions: mock.MagicMock,
instrument_channel: mock.MagicMock,
) -> None:
def test_instrument_api(self) -> None:
original_channel = BlockingConnection.channel
instrumentation = PikaInstrumentor()
instrumentation.instrument(channel=self.channel)
instrument_channel.assert_called_once_with(
self.channel, tracer_provider=None
)
self.channel._impl._consumers = {"mock_key": mock.MagicMock()}
self.channel._impl._consumers[
"mock_key"
]._original_callback = self.mock_callback
instrumentation.uninstrument(channel=self.channel)
uninstrument_channel_functions.assert_called_once()
instrumentation.instrument()
self.assertTrue(hasattr(instrumentation, "original_channel_func"))
self.assertEqual(
self.channel._impl._consumers["mock_key"], self.mock_callback
original_channel, instrumentation.original_channel_func
)
instrumentation.uninstrument(channel=self.channel)
self.assertEqual(original_channel, BlockingConnection.channel)

@mock.patch(
"opentelemetry.instrumentation.pika.PikaInstrumentor._instrument_channel_functions"
Expand All @@ -71,7 +56,7 @@ def test_instrument(
instrument_consumers.assert_called_once()
instrument_channel_functions.assert_called_once()

@mock.patch("opentelemetry.instrumentation.pika.utils.decorate_callback")
@mock.patch("opentelemetry.instrumentation.pika.utils._decorate_callback")
def test_instrument_consumers(
self, decorate_callback: mock.MagicMock
) -> None:
Expand All @@ -92,7 +77,7 @@ def test_instrument_consumers(
)

@mock.patch(
"opentelemetry.instrumentation.pika.utils.decorate_basic_publish"
"opentelemetry.instrumentation.pika.utils._decorate_basic_publish"
)
def test_instrument_basic_publish(
self, decorate_basic_publish: mock.MagicMock
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@
class TestUtils(TestCase):
@staticmethod
@mock.patch("opentelemetry.context.get_value")
@mock.patch("opentelemetry.instrumentation.pika.utils.generate_span_name")
@mock.patch("opentelemetry.instrumentation.pika.utils.enrich_span")
@mock.patch("opentelemetry.instrumentation.pika.utils._generate_span_name")
@mock.patch("opentelemetry.instrumentation.pika.utils._enrich_span")
@mock.patch("opentelemetry.propagate.extract")
def test_get_span(
extract: mock.MagicMock,
Expand All @@ -35,7 +35,7 @@ def test_get_span(
properties = mock.MagicMock()
task_name = "test.test"
get_value.return_value = None
span = utils.get_span(tracer, channel, properties, task_name)
span = utils._get_span(tracer, channel, properties, task_name)
extract.assert_called_once()
generate_span_name.assert_called_once()
tracer.start_span.assert_called_once_with(
Expand All @@ -48,8 +48,8 @@ def test_get_span(
), "The returned span was not enriched using enrich_span!"

@mock.patch("opentelemetry.context.get_value")
@mock.patch("opentelemetry.instrumentation.pika.utils.generate_span_name")
@mock.patch("opentelemetry.instrumentation.pika.utils.enrich_span")
@mock.patch("opentelemetry.instrumentation.pika.utils._generate_span_name")
@mock.patch("opentelemetry.instrumentation.pika.utils._enrich_span")
@mock.patch("opentelemetry.propagate.extract")
def test_get_span_suppressed(
self,
Expand All @@ -63,22 +63,22 @@ def test_get_span_suppressed(
properties = mock.MagicMock()
task_name = "test.test"
get_value.return_value = True
span = utils.get_span(tracer, channel, properties, task_name)
span = utils._get_span(tracer, channel, properties, task_name)
self.assertEqual(span, None)
extract.assert_called_once()
generate_span_name.assert_not_called()

def test_generate_span_name_no_operation(self) -> None:
task_name = "test.test"
operation = None
span_name = utils.generate_span_name(task_name, operation)
span_name = utils._generate_span_name(task_name, operation)
self.assertEqual(span_name, f"{task_name} send")

def test_generate_span_name_with_operation(self) -> None:
task_name = "test.test"
operation = mock.MagicMock()
operation.value = "process"
span_name = utils.generate_span_name(task_name, operation)
span_name = utils._generate_span_name(task_name, operation)
self.assertEqual(span_name, f"{task_name} {operation.value}")

@staticmethod
Expand All @@ -87,7 +87,7 @@ def test_enrich_span_basic_values() -> None:
properties = mock.MagicMock()
task_destination = "test.test"
span = mock.MagicMock(spec=Span)
utils.enrich_span(span, channel, properties, task_destination)
utils._enrich_span(span, channel, properties, task_destination)
span.set_attribute.assert_has_calls(
any_order=True,
calls=[
Expand Down Expand Up @@ -121,7 +121,7 @@ def test_enrich_span_with_operation() -> None:
task_destination = "test.test"
operation = mock.MagicMock()
span = mock.MagicMock(spec=Span)
utils.enrich_span(
utils._enrich_span(
span, channel, properties, task_destination, operation
)
span.set_attribute.assert_has_calls(
Expand All @@ -137,7 +137,7 @@ def test_enrich_span_without_operation() -> None:
properties = mock.MagicMock()
task_destination = "test.test"
span = mock.MagicMock(spec=Span)
utils.enrich_span(span, channel, properties, task_destination)
utils._enrich_span(span, channel, properties, task_destination)
span.set_attribute.assert_has_calls(
any_order=True,
calls=[mock.call(SpanAttributes.MESSAGING_TEMP_DESTINATION, True)],
Expand All @@ -151,7 +151,7 @@ def test_enrich_span_unique_connection() -> None:
span = mock.MagicMock(spec=Span)
# We do this to create the behaviour of hasattr(channel.connection, "params") == False
del channel.connection.params
utils.enrich_span(span, channel, properties, task_destination)
utils._enrich_span(span, channel, properties, task_destination)
span.set_attribute.assert_has_calls(
any_order=True,
calls=[
Expand Down

0 comments on commit fedd768

Please sign in to comment.