Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

ext: Expect tracer provider instead of tracer in integrations #602

Merged
merged 14 commits into from
Apr 23, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
trace.get_tracer_provider().add_span_processor(
SimpleExportSpanProcessor(ConsoleSpanExporter())
)
tracer = trace.get_tracer(__name__)


def run():
Expand All @@ -73,7 +72,7 @@ def run():
# of the code.
with grpc.insecure_channel("localhost:50051") as channel:

channel = intercept_channel(channel, client_interceptor(tracer))
channel = intercept_channel(channel, client_interceptor())

stub = helloworld_pb2_grpc.GreeterStub(channel)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,6 @@
trace.get_tracer_provider().add_span_processor(
SimpleExportSpanProcessor(ConsoleSpanExporter())
)
tracer = trace.get_tracer(__name__)


class Greeter(helloworld_pb2_grpc.GreeterServicer):
Expand All @@ -75,7 +74,7 @@ def SayHello(self, request, context):
def serve():

server = grpc.server(futures.ThreadPoolExecutor())
server = intercept_server(server, server_interceptor(tracer))
server = intercept_server(server, server_interceptor())

helloworld_pb2_grpc.add_GreeterServicer_to_server(Greeter(), server)
server.add_insecure_port("[::]:50051")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,6 @@
trace.get_tracer_provider().add_span_processor(
SimpleExportSpanProcessor(ConsoleSpanExporter())
)
tracer = trace.get_tracer(__name__)


def make_route_note(message, latitude, longitude):
Expand Down Expand Up @@ -154,7 +153,7 @@ def run():
# used in circumstances in which the with statement does not fit the needs
# of the code.
with grpc.insecure_channel("localhost:50051") as channel:
channel = intercept_channel(channel, client_interceptor(tracer))
channel = intercept_channel(channel, client_interceptor())

stub = route_guide_pb2_grpc.RouteGuideStub(channel)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@
trace.get_tracer_provider().add_span_processor(
SimpleExportSpanProcessor(ConsoleSpanExporter())
)
tracer = trace.get_tracer(__name__)


def get_feature(feature_db, point):
Expand Down Expand Up @@ -164,7 +163,7 @@ def RouteChat(self, request_iterator, context):

def serve():
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
server = intercept_server(server, server_interceptor(tracer))
server = intercept_server(server, server_interceptor())

route_guide_pb2_grpc.add_RouteGuideServicer_to_server(
RouteGuideServicer(), server
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,15 +24,16 @@
import mysql.connector
import pyodbc

from opentelemetry import trace
from opentelemetry.ext.dbapi import trace_integration
from opentelemetry.trace import TracerProvider

trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer(__name__)

# Ex: mysql.connector
trace_integration(tracer, mysql.connector, "connect", "mysql", "sql")
trace_integration(mysql.connector, "connect", "mysql", "sql")
# Ex: pyodbc
trace_integration(tracer, pyodbc, "Connection", "odbc", "sql")
trace_integration(pyodbc, "Connection", "odbc", "sql")

API
---
Expand All @@ -44,13 +45,44 @@

import wrapt

from opentelemetry.trace import SpanKind, Tracer
from opentelemetry.ext.dbapi.version import __version__
from opentelemetry.trace import SpanKind, Tracer, TracerProvider, get_tracer
from opentelemetry.trace.status import Status, StatusCanonicalCode

logger = logging.getLogger(__name__)


def trace_integration(
connect_module: typing.Callable[..., any],
connect_method_name: str,
database_component: str,
database_type: str = "",
connection_attributes: typing.Dict = None,
tracer_provider: typing.Optional[TracerProvider] = None,
):
"""Integrate with DB API library.
https://www.python.org/dev/peps/pep-0249/

Args:
connect_module: Module name where connect method is available.
connect_method_name: The connect method name.
database_component: Database driver name or database name "JDBI", "jdbc", "odbc", "postgreSQL".
database_type: The Database type. For any SQL database, "sql".
connection_attributes: Attribute names for database, port, host and user in Connection object.
tracer_provider: The :class:`TracerProvider` to use. If ommited the current configured one is used.
"""
tracer = get_tracer(__name__, __version__, tracer_provider)
wrap_connect(
tracer,
connect_module,
connect_method_name,
database_component,
database_type,
connection_attributes,
)


def wrap_connect(
tracer: Tracer,
connect_module: typing.Callable[..., any],
connect_method_name: str,
Expand All @@ -71,7 +103,7 @@ def trace_integration(
"""

# pylint: disable=unused-argument
def wrap_connect(
def wrap_connect_(
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

maybe a slight style thing, but it's a little harder for me to spot a trailing underscore rather than a leading underscore.

wrapped: typing.Callable[..., any],
instance: typing.Any,
args: typing.Tuple[any, any],
Expand All @@ -87,7 +119,7 @@ def wrap_connect(

try:
wrapt.wrap_function_wrapper(
connect_module, connect_method_name, wrap_connect
connect_module, connect_method_name, wrap_connect_
)
except Exception as ex: # pylint: disable=broad-except
logger.warning("Failed to integrate with DB API. %s", str(ex))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def setUpClass(cls):
cls._connection = None
cls._cursor = None
cls._tracer = cls.tracer_provider.get_tracer(__name__)
trace_integration(cls._tracer)
trace_integration(cls.tracer_provider)
cls._connection = mysql.connector.connect(
user=MYSQL_USER,
password=MYSQL_PASSWORD,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def setUpClass(cls):
cls._connection = None
cls._cursor = None
cls._tracer = cls.tracer_provider.get_tracer(__name__)
trace_integration(cls._tracer)
trace_integration(cls.tracer_provider)
cls._connection = psycopg2.connect(
dbname=POSTGRES_DB_NAME,
user=POSTGRES_USER,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class TestFunctionalPymongo(TestBase):
def setUpClass(cls):
super().setUpClass()
cls._tracer = cls.tracer_provider.get_tracer(__name__)
trace_integration(cls._tracer)
trace_integration(cls.tracer_provider)
client = MongoClient(
MONGODB_HOST, MONGODB_PORT, serverSelectionTimeoutMS=2000
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
# pylint:disable=no-name-in-module
# pylint:disable=relative-beyond-top-level

from opentelemetry import trace
from opentelemetry.ext.grpc.version import __version__

def client_interceptor(tracer):

def client_interceptor(tracer_provider=None):
"""Create a gRPC client channel interceptor.

Args:
Expand All @@ -29,10 +32,12 @@ def client_interceptor(tracer):
"""
from . import _client

tracer = trace.get_tracer(__name__, __version__, tracer_provider)

return _client.OpenTelemetryClientInterceptor(tracer)


def server_interceptor(tracer):
def server_interceptor(tracer_provider=None):
"""Create a gRPC server interceptor.

Args:
Expand All @@ -43,4 +48,6 @@ def server_interceptor(tracer):
"""
from . import _server

tracer = trace.get_tracer(__name__, __version__, tracer_provider)

return _server.OpenTelemetryServerInterceptor(tracer)
26 changes: 12 additions & 14 deletions ext/opentelemetry-ext-grpc/tests/test_server_interceptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import grpc

import opentelemetry.ext.grpc
from opentelemetry import trace
from opentelemetry.ext.grpc import server_interceptor
from opentelemetry.ext.grpc.grpcext import intercept_server
Expand Down Expand Up @@ -48,15 +49,11 @@ def service(self, handler_call_details):


class TestOpenTelemetryServerInterceptor(TestBase):
def setUp(self):
super().setUp()
self.tracer = self.tracer_provider.get_tracer(__name__)

def test_create_span(self):
"""Check that the interceptor wraps calls with spans server-side."""

# Intercept gRPC calls...
interceptor = server_interceptor(self.tracer)
interceptor = server_interceptor()

# No-op RPC handler
def handler(request, context):
Expand Down Expand Up @@ -87,18 +84,21 @@ def handler(request, context):
self.assertEqual(span.name, "")
self.assertIs(span.kind, trace.SpanKind.SERVER)

# Check version and name in span's instrumentation info
self.check_span_instrumentation_info(span, opentelemetry.ext.grpc)

def test_span_lifetime(self):
"""Check that the span is active for the duration of the call."""

tracer_provider = trace_sdk.TracerProvider()
tracer = tracer_provider.get_tracer(__name__)
interceptor = server_interceptor(tracer)
interceptor = server_interceptor()
tracer = self.tracer_provider.get_tracer(__name__)

# To capture the current span at the time the handler is called
active_span_in_handler = None

def handler(request, context):
nonlocal active_span_in_handler
# The current span is shared among all the tracers.
active_span_in_handler = tracer.get_current_span()
return b""

Expand Down Expand Up @@ -128,10 +128,9 @@ def handler(request, context):
def test_sequential_server_spans(self):
"""Check that sequential RPCs get separate server spans."""

tracer_provider = trace_sdk.TracerProvider()
tracer = tracer_provider.get_tracer(__name__)
tracer = self.tracer_provider.get_tracer(__name__)

interceptor = server_interceptor(tracer)
interceptor = server_interceptor()

# Capture the currently active span in each thread
active_spans_in_handler = []
Expand Down Expand Up @@ -176,10 +175,9 @@ def test_concurrent_server_spans(self):
context.
"""

tracer_provider = trace_sdk.TracerProvider()
tracer = tracer_provider.get_tracer(__name__)
tracer = self.tracer_provider.get_tracer(__name__)

interceptor = server_interceptor(tracer)
interceptor = server_interceptor()

# Capture the currently active span in each thread
active_spans_in_handler = []
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,8 @@
from opentelemetry.ext.mysql import trace_integration

trace.set_tracer_provider(TracerProvider())
tracer = trace.get_tracer(__name__)

trace_integration(tracer)
trace_integration()
cnx = mysql.connector.connect(database='MySQL_Database')
cursor = cnx.cursor()
cursor.execute("INSERT INTO test (testField) VALUES (123)"
Expand All @@ -42,23 +41,29 @@
---
"""

import typing

import mysql.connector

from opentelemetry.ext.dbapi import trace_integration as db_integration
from opentelemetry.trace import Tracer
from opentelemetry.ext.dbapi import wrap_connect
from opentelemetry.ext.mysql.version import __version__
from opentelemetry.trace import TracerProvider, get_tracer


def trace_integration(tracer: Tracer):
def trace_integration(tracer_provider: typing.Optional[TracerProvider] = None):
"""Integrate with MySQL Connector/Python library.
https://dev.mysql.com/doc/connector-python/en/
"""

tracer = get_tracer(__name__, __version__, tracer_provider)

connection_attributes = {
"database": "database",
"port": "server_port",
"host": "server_host",
"user": "user",
}
db_integration(
wrap_connect(
tracer,
mysql.connector,
"connect",
Expand Down
33 changes: 28 additions & 5 deletions ext/opentelemetry-ext-mysql/tests/test_mysql_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,43 @@

import mysql.connector

from opentelemetry.ext.mysql import trace_integration
import opentelemetry.ext.mysql
from opentelemetry.sdk import resources
from opentelemetry.test.test_base import TestBase


class TestMysqlIntegration(TestBase):
def test_trace_integration(self):
tracer = self.tracer_provider.get_tracer(__name__)
with mock.patch("mysql.connector.connect") as mock_connect:
mock_connect.get.side_effect = mysql.connector.MySQLConnection()
opentelemetry.ext.mysql.trace_integration()
cnx = mysql.connector.connect(database="test")
cursor = cnx.cursor()
query = "SELECT * FROM test"
cursor.execute(query)

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

# Check version and name in span's instrumentation info
self.check_span_instrumentation_info(span, opentelemetry.ext.mysql)

def test_custom_tracer_provider(self):
resource = resources.Resource.create({})
result = self.create_tracer_provider(resource=resource)
tracer_provider, exporter = result

with mock.patch("mysql.connector.connect") as mock_connect:
mock_connect.get.side_effect = mysql.connector.MySQLConnection()
trace_integration(tracer)
opentelemetry.ext.mysql.trace_integration(tracer_provider)
cnx = mysql.connector.connect(database="test")
cursor = cnx.cursor()
query = "SELECT * FROM test"
cursor.execute(query)
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

span_list = exporter.get_finished_spans()
self.assertEqual(len(span_list), 1)
span = span_list[0]

self.assertIs(span.resource, resource)
Loading