diff --git a/CHANGELOG.md b/CHANGELOG.md index 981d2fd232..178adfc272 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -21,6 +21,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 ### Added +- `opentelemetry-instrumentation-sqlalchemy` Add optional db.statement query sanitization. + ([#1701](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1701)) - Add connection attributes to sqlalchemy connect span ([#1608](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1608)) - Add support for enabling Redis sanitization from environment variable diff --git a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py index 77db23b417..c03f51d593 100644 --- a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py +++ b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/__init__.py @@ -134,6 +134,7 @@ def _instrument(self, **kwargs): ``engine``: a SQLAlchemy engine instance ``engines``: a list of SQLAlchemy engine instances ``tracer_provider``: a TracerProvider, defaults to global + ``sanitize_query``: bool to enable/disable db.statement query sanitization, defaults to False Returns: An instrumented engine if passed in as an argument or list of instrumented engines, None otherwise. @@ -151,16 +152,22 @@ def _instrument(self, **kwargs): ) enable_commenter = kwargs.get("enable_commenter", False) + sanitize_query = kwargs.get("sanitize_query", False) + commenter_options = kwargs.get("commenter_options", {}) _w( "sqlalchemy", "create_engine", - _wrap_create_engine(tracer, connections_usage, enable_commenter), + _wrap_create_engine( + tracer, connections_usage, sanitize_query, enable_commenter + ), ) _w( "sqlalchemy.engine", "create_engine", - _wrap_create_engine(tracer, connections_usage, enable_commenter), + _wrap_create_engine( + tracer, connections_usage, sanitize_query, enable_commenter + ), ) _w( "sqlalchemy.engine.base", @@ -180,8 +187,9 @@ def _instrument(self, **kwargs): tracer, kwargs.get("engine"), connections_usage, - kwargs.get("enable_commenter", False), - kwargs.get("commenter_options", {}), + sanitize_query, + enable_commenter, + commenter_options, ) if kwargs.get("engines") is not None and isinstance( kwargs.get("engines"), Sequence @@ -191,8 +199,9 @@ def _instrument(self, **kwargs): tracer, engine, connections_usage, - kwargs.get("enable_commenter", False), - kwargs.get("commenter_options", {}), + sanitize_query, + enable_commenter, + commenter_options, ) for engine in kwargs.get("engines") ] diff --git a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py index ca691fc052..49b8cc0289 100644 --- a/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py +++ b/instrumentation/opentelemetry-instrumentation-sqlalchemy/src/opentelemetry/instrumentation/sqlalchemy/engine.py @@ -26,6 +26,87 @@ from opentelemetry.semconv.trace import NetTransportValues, SpanAttributes from opentelemetry.trace.status import Status, StatusCode +sql_reserved_words = [ + "ADD", + "ALL", + "ALTER", + "AND", + "ANY", + "AS", + "ASC", + "BACKUP", + "BETWEEN", + "CASE", + "CHECK", + "COLUMN", + "CONSTRAINT", + "CREATE", + "DATABASE", + "DEFAULT", + "DELETE", + "DESC", + "DISTINCT", + "DROP", + "EXEC", + "EXISTS", + "FOREIGN", + "FROM", + "FULL", + "GROUP", + "BY", + "HAVING", + "IN", + "INDEX", + "INNER", + "INSERT", + "INTO", + "IS", + "JOIN", + "KEY", + "LEFT", + "LIKE", + "LIMIT", + "NOT", + "NULL", + "ON", + "OR", + "ORDER", + "OUTER", + "PRIMARY", + "PROCEDURE", + "RIGHT", + "ROWNUM", + "SELECT", + "SET", + "TABLE", + "TOP", + "TRUNCATE", + "UNION", + "UNIQUE", + "UPDATE", + "VALUES", + "VIEW", + "WHERE", + "=", +] + +sql_reserved_dict = {word: True for word in sql_reserved_words} + + +def _sanitize_query(query): + """Remove query content, replace with sanitization symbol. + For example `SELECT * FROM table` will sanitize to SELECT ? FROM ?` + """ + sanitized_query = "" + if not query: + return sanitized_query + + for word in query.split(): + if word.upper() not in sql_reserved_dict: + word = "?" + sanitized_query += word + " " + return sanitized_query.strip() + def _normalize_vendor(vendor): """Return a canonical name for a type of database.""" @@ -42,7 +123,7 @@ def _normalize_vendor(vendor): def _wrap_create_async_engine( - tracer, connections_usage, enable_commenter=False + tracer, connections_usage, sanitize_query, enable_commenter=False ): # pylint: disable=unused-argument def _wrap_create_async_engine_internal(func, module, args, kwargs): @@ -51,20 +132,28 @@ def _wrap_create_async_engine_internal(func, module, args, kwargs): """ engine = func(*args, **kwargs) EngineTracer( - tracer, engine.sync_engine, connections_usage, enable_commenter + tracer, + engine.sync_engine, + connections_usage, + sanitize_query, + enable_commenter, ) return engine return _wrap_create_async_engine_internal -def _wrap_create_engine(tracer, connections_usage, enable_commenter=False): +def _wrap_create_engine( + tracer, connections_usage, sanitize_query, enable_commenter=False +): def _wrap_create_engine_internal(func, _module, args, kwargs): """Trace the SQLAlchemy engine, creating an `EngineTracer` object that will listen to SQLAlchemy events. """ engine = func(*args, **kwargs) - EngineTracer(tracer, engine, connections_usage, enable_commenter) + EngineTracer( + tracer, engine, connections_usage, sanitize_query, enable_commenter + ) return engine return _wrap_create_engine_internal @@ -95,6 +184,7 @@ def __init__( tracer, engine, connections_usage, + sanitize_query=False, enable_commenter=False, commenter_options=None, ): @@ -104,6 +194,7 @@ def __init__( self.vendor = _normalize_vendor(engine.name) self.enable_commenter = enable_commenter self.commenter_options = commenter_options if commenter_options else {} + self.sanitize_query = sanitize_query self._leading_comment_remover = re.compile(r"^/\*.*?\*/") self._register_event_listener( @@ -200,8 +291,12 @@ def _before_cur_exec( ) with trace.use_span(span, end_on_exit=False): if span.is_recording(): - span.set_attribute(SpanAttributes.DB_STATEMENT, statement) span.set_attribute(SpanAttributes.DB_SYSTEM, self.vendor) + span.set_attribute(SpanAttributes.DB_STATEMENT, statement) + if self.sanitize_query: + span.set_attribute( + SpanAttributes.DB_STATEMENT, _sanitize_query(statement) + ) for key, value in attrs.items(): span.set_attribute(key, value) if self.enable_commenter: diff --git a/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_engine.py b/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_engine.py new file mode 100644 index 0000000000..fdb1a02935 --- /dev/null +++ b/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_engine.py @@ -0,0 +1,43 @@ +# Copyright The OpenTelemetry Authors +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from opentelemetry.instrumentation.sqlalchemy.engine import ( + _normalize_vendor, + _sanitize_query, +) +from opentelemetry.test.test_base import TestBase + + +class TestSQLAlchemyEngine(TestBase): + def test_sql_query_sanitization(self): + sanitized = "SELECT ? FROM ? WHERE ? = ?" + select1 = "SELECT * FROM users WHERE name = 'John'" + select2 = "SELECT * FROM users WHERE name = 'John'" + select3 = "SELECT\t*\tFROM\tusers\tWHERE\tname\t=\t'John'" + + self.assertEqual(_sanitize_query(select1), sanitized) + self.assertEqual(_sanitize_query(select2), sanitized) + self.assertEqual(_sanitize_query(select3), sanitized) + self.assertEqual(_sanitize_query(""), "") + self.assertEqual(_sanitize_query(None), "") + + def test_normalize_vendor(self): + self.assertEqual(_normalize_vendor("mysql"), "mysql") + self.assertEqual(_normalize_vendor("sqlite"), "sqlite") + self.assertEqual(_normalize_vendor("sqlite~12345"), "sqlite") + self.assertEqual(_normalize_vendor("postgres"), "postgresql") + self.assertEqual(_normalize_vendor("postgres 12345"), "postgresql") + self.assertEqual(_normalize_vendor("psycopg2"), "postgresql") + self.assertEqual(_normalize_vendor(""), "db") + self.assertEqual(_normalize_vendor(None), "db") diff --git a/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py b/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py index 981da107db..bef0035149 100644 --- a/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py +++ b/instrumentation/opentelemetry-instrumentation-sqlalchemy/tests/test_sqlalchemy.py @@ -42,12 +42,16 @@ def test_trace_integration(self): tracer_provider=self.tracer_provider, ) cnx = engine.connect() - cnx.execute("SELECT 1 + 1;").fetchall() - cnx.execute("/* leading comment */ SELECT 1 + 1;").fetchall() - cnx.execute( + select_no_comment = "SELECT\t1 + 1;" + select_leading_comment = "/* leading comment */ SELECT 1 + 1;" + select_trailing_comment = "SELECT 1 + 1; /* trailing comment */" + select_leading_and_trailing_comment = ( "/* leading comment */ SELECT 1 + 1; /* trailing comment */" - ).fetchall() - cnx.execute("SELECT 1 + 1; /* trailing comment */").fetchall() + ) + cnx.execute(select_no_comment).fetchall() + cnx.execute(select_leading_comment).fetchall() + cnx.execute(select_leading_and_trailing_comment).fetchall() + cnx.execute(select_trailing_comment).fetchall() spans = self.memory_exporter.get_finished_spans() self.assertEqual(len(spans), 5) @@ -57,13 +61,28 @@ def test_trace_integration(self): # second span - the query itself self.assertEqual(spans[1].name, "SELECT :memory:") self.assertEqual(spans[1].kind, trace.SpanKind.CLIENT) + self.assertEqual( + spans[1].attributes[SpanAttributes.DB_STATEMENT], select_no_comment + ) # spans for queries with comments self.assertEqual(spans[2].name, "SELECT :memory:") self.assertEqual(spans[2].kind, trace.SpanKind.CLIENT) + self.assertEqual( + spans[2].attributes[SpanAttributes.DB_STATEMENT], + select_leading_comment, + ) self.assertEqual(spans[3].name, "SELECT :memory:") self.assertEqual(spans[3].kind, trace.SpanKind.CLIENT) + self.assertEqual( + spans[3].attributes[SpanAttributes.DB_STATEMENT], + select_leading_and_trailing_comment, + ) self.assertEqual(spans[4].name, "SELECT :memory:") self.assertEqual(spans[4].kind, trace.SpanKind.CLIENT) + self.assertEqual( + spans[4].attributes[SpanAttributes.DB_STATEMENT], + select_trailing_comment, + ) def test_instrument_two_engines(self): engine_1 = create_engine("sqlite:///:memory:") @@ -150,7 +169,9 @@ def test_not_recording(self): self.assertFalse(mock_span.set_status.called) def test_create_engine_wrapper(self): - SQLAlchemyInstrumentor().instrument() + SQLAlchemyInstrumentor().instrument( + sanitize_query=True + ) # verify no side effects from sqlalchemy import create_engine # pylint: disable-all engine = create_engine("sqlite:///:memory:") @@ -164,6 +185,7 @@ def test_create_engine_wrapper(self): self.assertEqual( spans[0].attributes[SpanAttributes.DB_NAME], ":memory:" ) + self.assertEqual( spans[0].attributes[SpanAttributes.DB_SYSTEM], "sqlite" ) @@ -176,6 +198,22 @@ def test_create_engine_wrapper(self): "opentelemetry.instrumentation.sqlalchemy", ) + def test_sanitize_db_statement(self): + SQLAlchemyInstrumentor().instrument(sanitize_query=True) + from sqlalchemy import create_engine # pylint: disable-all + + engine = create_engine("sqlite:///:memory:") + cnx = engine.connect() + cnx.execute("SELECT 1 + 1;").fetchall() + spans = self.memory_exporter.get_finished_spans() + + self.assertEqual(len(spans), 2) + + self.assertEqual( + spans[1].attributes[SpanAttributes.DB_STATEMENT], + "SELECT ? ? ?", + ) + def test_custom_tracer_provider(self): provider = TracerProvider( resource=Resource.create(