Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add sqlalchemy db.statement sanitization flag #1701

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0

### Added

- `opentelemetry-instrumentation-sqlalchemy` Add optional db.statement query sanitization.
([#1701](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1701))
- Add connection attributes to sqlalchemy connect span
([#1608](https://github.com/open-telemetry/opentelemetry-python-contrib/pull/1608))
- Add support for enabling Redis sanitization from environment variable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def _instrument(self, **kwargs):
``engine``: a SQLAlchemy engine instance
``engines``: a list of SQLAlchemy engine instances
``tracer_provider``: a TracerProvider, defaults to global
``sanitize_query``: bool to enable/disable db.statement query sanitization, defaults to False

Returns:
An instrumented engine if passed in as an argument or list of instrumented engines, None otherwise.
Expand All @@ -151,16 +152,22 @@ def _instrument(self, **kwargs):
)

enable_commenter = kwargs.get("enable_commenter", False)
sanitize_query = kwargs.get("sanitize_query", False)
commenter_options = kwargs.get("commenter_options", {})

_w(
"sqlalchemy",
"create_engine",
_wrap_create_engine(tracer, connections_usage, enable_commenter),
_wrap_create_engine(
tracer, connections_usage, sanitize_query, enable_commenter
),
)
_w(
"sqlalchemy.engine",
"create_engine",
_wrap_create_engine(tracer, connections_usage, enable_commenter),
_wrap_create_engine(
tracer, connections_usage, sanitize_query, enable_commenter
),
)
_w(
"sqlalchemy.engine.base",
Expand All @@ -180,8 +187,9 @@ def _instrument(self, **kwargs):
tracer,
kwargs.get("engine"),
connections_usage,
kwargs.get("enable_commenter", False),
kwargs.get("commenter_options", {}),
sanitize_query,
enable_commenter,
commenter_options,
)
if kwargs.get("engines") is not None and isinstance(
kwargs.get("engines"), Sequence
Expand All @@ -191,8 +199,9 @@ def _instrument(self, **kwargs):
tracer,
engine,
connections_usage,
kwargs.get("enable_commenter", False),
kwargs.get("commenter_options", {}),
sanitize_query,
enable_commenter,
commenter_options,
)
for engine in kwargs.get("engines")
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,87 @@
from opentelemetry.semconv.trace import NetTransportValues, SpanAttributes
from opentelemetry.trace.status import Status, StatusCode

# Tokens that `_sanitize_query` keeps verbatim (SQL keywords plus the "="
# operator); any other whitespace-separated token is replaced with "?".
# fmt: off
sql_reserved_words = [
    "ADD", "ALL", "ALTER", "AND", "ANY", "AS", "ASC",
    "BACKUP", "BETWEEN", "CASE", "CHECK", "COLUMN", "CONSTRAINT", "CREATE",
    "DATABASE", "DEFAULT", "DELETE", "DESC", "DISTINCT", "DROP",
    "EXEC", "EXISTS", "FOREIGN", "FROM", "FULL", "GROUP", "BY",
    "HAVING", "IN", "INDEX", "INNER", "INSERT", "INTO", "IS",
    "JOIN", "KEY", "LEFT", "LIKE", "LIMIT", "NOT", "NULL",
    "ON", "OR", "ORDER", "OUTER", "PRIMARY", "PROCEDURE",
    "RIGHT", "ROWNUM", "SELECT", "SET", "TABLE", "TOP", "TRUNCATE",
    "UNION", "UNIQUE", "UPDATE", "VALUES", "VIEW", "WHERE",
    "=",
]
# fmt: on

# Same words as dictionary keys (every value is True) so membership checks
# are hash lookups instead of list scans.
sql_reserved_dict = dict.fromkeys(sql_reserved_words, True)


def _sanitize_query(query):
"""Remove query content, replace with sanitization symbol.
For example, `SELECT * FROM users` will sanitize to `SELECT ? FROM ?`.
"""
sanitized_query = ""
if not query:
return sanitized_query

for word in query.split():
if word.upper() not in sql_reserved_dict:
word = "?"
sanitized_query += word + " "
Comment on lines +104 to +107
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

While this is better than the last change, I still find this naive because I can think of some WHERE clause cases this breaks. I wonder if there is a SQLAlchemy utility that does this sanitization work more reliably?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The idea behind this is that anything which isn't a reserved word will be sanitized. Can you clarify with an example of a WHERE clause case that breaks this?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Imagine a scenario with a WHERE clause filter value containing spaces and one of these reserved words. How does it behave for the following query?

SELECT * FROM table WHERE column_name="PRIMARY BANKING STOCK INDEX COLLAPSE". 

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm yes, I see what you're saying. This will result in `SELECT ? FROM table WHERE ? ? ? INDEX ?`. I still think the `SELECT ? ?` option is the safest.

return sanitized_query.strip()


def _normalize_vendor(vendor):
"""Return a canonical name for a type of database."""
Expand All @@ -42,7 +123,7 @@ def _normalize_vendor(vendor):


def _wrap_create_async_engine(
tracer, connections_usage, enable_commenter=False
tracer, connections_usage, sanitize_query, enable_commenter=False
):
# pylint: disable=unused-argument
def _wrap_create_async_engine_internal(func, module, args, kwargs):
Expand All @@ -51,20 +132,28 @@ def _wrap_create_async_engine_internal(func, module, args, kwargs):
"""
engine = func(*args, **kwargs)
EngineTracer(
tracer, engine.sync_engine, connections_usage, enable_commenter
tracer,
engine.sync_engine,
connections_usage,
sanitize_query,
enable_commenter,
)
return engine

return _wrap_create_async_engine_internal


def _wrap_create_engine(tracer, connections_usage, enable_commenter=False):
def _wrap_create_engine(
tracer, connections_usage, sanitize_query, enable_commenter=False
):
def _wrap_create_engine_internal(func, _module, args, kwargs):
"""Trace the SQLAlchemy engine, creating an `EngineTracer`
object that will listen to SQLAlchemy events.
"""
engine = func(*args, **kwargs)
EngineTracer(tracer, engine, connections_usage, enable_commenter)
EngineTracer(
tracer, engine, connections_usage, sanitize_query, enable_commenter
)
return engine

return _wrap_create_engine_internal
Expand Down Expand Up @@ -95,6 +184,7 @@ def __init__(
tracer,
engine,
connections_usage,
sanitize_query=False,
enable_commenter=False,
commenter_options=None,
):
Expand All @@ -104,6 +194,7 @@ def __init__(
self.vendor = _normalize_vendor(engine.name)
self.enable_commenter = enable_commenter
self.commenter_options = commenter_options if commenter_options else {}
self.sanitize_query = sanitize_query
self._leading_comment_remover = re.compile(r"^/\*.*?\*/")

self._register_event_listener(
Expand Down Expand Up @@ -200,8 +291,12 @@ def _before_cur_exec(
)
with trace.use_span(span, end_on_exit=False):
if span.is_recording():
span.set_attribute(SpanAttributes.DB_STATEMENT, statement)
span.set_attribute(SpanAttributes.DB_SYSTEM, self.vendor)
span.set_attribute(SpanAttributes.DB_STATEMENT, statement)
if self.sanitize_query:
span.set_attribute(
SpanAttributes.DB_STATEMENT, _sanitize_query(statement)
)
for key, value in attrs.items():
span.set_attribute(key, value)
if self.enable_commenter:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
# Copyright The OpenTelemetry Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from opentelemetry.instrumentation.sqlalchemy.engine import (
_normalize_vendor,
_sanitize_query,
)
from opentelemetry.test.test_base import TestBase


class TestSQLAlchemyEngine(TestBase):
    """Unit tests for the `_sanitize_query` and `_normalize_vendor` helpers."""

    def test_sql_query_sanitization(self):
        """Non-reserved tokens become `?` regardless of the separating whitespace."""
        sanitized = "SELECT ? FROM ? WHERE ? = ?"
        queries = (
            # single spaces between tokens
            "SELECT * FROM users WHERE name = 'John'",
            # runs of several spaces between tokens
            "SELECT  *   FROM    users  WHERE   name  =    'John'",
            # tab-separated tokens
            "SELECT\t*\tFROM\tusers\tWHERE\tname\t=\t'John'",
        )
        for query in queries:
            with self.subTest(query=query):
                self.assertEqual(_sanitize_query(query), sanitized)

        # Empty and missing statements both sanitize to the empty string.
        self.assertEqual(_sanitize_query(""), "")
        self.assertEqual(_sanitize_query(None), "")

    def test_normalize_vendor(self):
        """Each raw vendor string maps to the expected canonical name."""
        expectations = (
            ("mysql", "mysql"),
            ("sqlite", "sqlite"),
            ("sqlite~12345", "sqlite"),
            ("postgres", "postgresql"),
            ("postgres 12345", "postgresql"),
            ("psycopg2", "postgresql"),
            # no usable vendor information falls back to the generic name
            ("", "db"),
            (None, "db"),
        )
        for vendor, canonical in expectations:
            with self.subTest(vendor=vendor):
                self.assertEqual(_normalize_vendor(vendor), canonical)
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,16 @@ def test_trace_integration(self):
tracer_provider=self.tracer_provider,
)
cnx = engine.connect()
cnx.execute("SELECT 1 + 1;").fetchall()
cnx.execute("/* leading comment */ SELECT 1 + 1;").fetchall()
cnx.execute(
select_no_comment = "SELECT\t1 + 1;"
select_leading_comment = "/* leading comment */ SELECT 1 + 1;"
select_trailing_comment = "SELECT 1 + 1; /* trailing comment */"
select_leading_and_trailing_comment = (
"/* leading comment */ SELECT 1 + 1; /* trailing comment */"
).fetchall()
cnx.execute("SELECT 1 + 1; /* trailing comment */").fetchall()
)
cnx.execute(select_no_comment).fetchall()
cnx.execute(select_leading_comment).fetchall()
cnx.execute(select_leading_and_trailing_comment).fetchall()
cnx.execute(select_trailing_comment).fetchall()
spans = self.memory_exporter.get_finished_spans()

self.assertEqual(len(spans), 5)
Expand All @@ -57,13 +61,28 @@ def test_trace_integration(self):
# second span - the query itself
self.assertEqual(spans[1].name, "SELECT :memory:")
self.assertEqual(spans[1].kind, trace.SpanKind.CLIENT)
self.assertEqual(
spans[1].attributes[SpanAttributes.DB_STATEMENT], select_no_comment
)
# spans for queries with comments
self.assertEqual(spans[2].name, "SELECT :memory:")
self.assertEqual(spans[2].kind, trace.SpanKind.CLIENT)
self.assertEqual(
spans[2].attributes[SpanAttributes.DB_STATEMENT],
select_leading_comment,
)
self.assertEqual(spans[3].name, "SELECT :memory:")
self.assertEqual(spans[3].kind, trace.SpanKind.CLIENT)
self.assertEqual(
spans[3].attributes[SpanAttributes.DB_STATEMENT],
select_leading_and_trailing_comment,
)
self.assertEqual(spans[4].name, "SELECT :memory:")
self.assertEqual(spans[4].kind, trace.SpanKind.CLIENT)
self.assertEqual(
spans[4].attributes[SpanAttributes.DB_STATEMENT],
select_trailing_comment,
)

def test_instrument_two_engines(self):
engine_1 = create_engine("sqlite:///:memory:")
Expand Down Expand Up @@ -150,7 +169,9 @@ def test_not_recording(self):
self.assertFalse(mock_span.set_status.called)

def test_create_engine_wrapper(self):
SQLAlchemyInstrumentor().instrument()
SQLAlchemyInstrumentor().instrument(
sanitize_query=True
) # verify no side effects
from sqlalchemy import create_engine # pylint: disable-all

engine = create_engine("sqlite:///:memory:")
Expand All @@ -164,6 +185,7 @@ def test_create_engine_wrapper(self):
self.assertEqual(
spans[0].attributes[SpanAttributes.DB_NAME], ":memory:"
)

self.assertEqual(
spans[0].attributes[SpanAttributes.DB_SYSTEM], "sqlite"
)
Expand All @@ -176,6 +198,22 @@ def test_create_engine_wrapper(self):
"opentelemetry.instrumentation.sqlalchemy",
)

def test_sanitize_db_statement(self):
    """With `sanitize_query=True` the span's db.statement is sanitized."""
    SQLAlchemyInstrumentor().instrument(sanitize_query=True)
    # Import only after instrumenting, so the name bound here is the one
    # present on the module once instrumentation has run.
    from sqlalchemy import create_engine  # pylint: disable-all

    memory_engine = create_engine("sqlite:///:memory:")
    connection = memory_engine.connect()
    connection.execute("SELECT 1 + 1;").fetchall()

    finished_spans = self.memory_exporter.get_finished_spans()
    self.assertEqual(len(finished_spans), 2)

    # The second finished span belongs to the query; every token except
    # the reserved word SELECT is expected to be masked.
    recorded_statement = finished_spans[1].attributes[
        SpanAttributes.DB_STATEMENT
    ]
    self.assertEqual(recorded_statement, "SELECT ? ? ?")

def test_custom_tracer_provider(self):
provider = TracerProvider(
resource=Resource.create(
Expand Down