Skip to content

Commit

Permalink
psycopg: fix running async tests
Browse files Browse the repository at this point in the history
Tests were not run because the test case did not await on async tests.
Split tests in two testcases: one for sync tests and one for async tests.
  • Loading branch information
xrmx committed May 27, 2024
1 parent fc04f5d commit 61f94fe
Showing 1 changed file with 115 additions and 104 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import types
from unittest import mock
from unittest import IsolatedAsyncioTestCase, mock

import psycopg

Expand Down Expand Up @@ -114,6 +113,10 @@ def cursor(self):
return cur
return MockAsyncCursor()

def execute(self, query, params=None, *, prepare=None, binary=False):
cur = self.cursor()
return cur.execute(query, params, prepare=prepare)

def get_dsn_parameters(self): # pylint: disable=no-self-use
return {"dbname": "test"}

Expand All @@ -124,7 +127,8 @@ async def __aexit__(self, *args):
return mock.MagicMock(spec=types.MethodType)


class TestPostgresqlIntegration(TestBase):
class PostgresqlIntegrationTestMixin:
# pylint: disable=invalid-name
def setUp(self):
super().setUp()
self.cursor_mock = mock.patch(
Expand All @@ -148,6 +152,7 @@ def setUp(self):
self.connection_sync_mock.start()
self.connection_async_mock.start()

# pylint: disable=invalid-name
def tearDown(self):
super().tearDown()
self.memory_exporter.clear()
Expand All @@ -159,6 +164,8 @@ def tearDown(self):
with self.disable_logging():
PsycopgInstrumentor().uninstrument()


class TestPostgresqlIntegration(PostgresqlIntegrationTestMixin, TestBase):
# pylint: disable=unused-argument
def test_instrumentor(self):
PsycopgInstrumentor().instrument()
Expand Down Expand Up @@ -221,60 +228,6 @@ def test_instrumentor_with_connection_class(self):
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

async def test_wrap_async_connection_class_with_cursor(self):
PsycopgInstrumentor().instrument()

async def test_async_connection():
acnx = await psycopg.AsyncConnection.connect(database="test")
async with acnx as cnx:
async with cnx.cursor() as cursor:
await cursor.execute("SELECT * FROM test")

asyncio.run(test_async_connection())
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

# Check version and name in span's instrumentation info
self.assertEqualSpanInstrumentationInfo(
span, opentelemetry.instrumentation.psycopg
)

# check that no spans are generated after uninstrument
PsycopgInstrumentor().uninstrument()

asyncio.run(test_async_connection())

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

# pylint: disable=unused-argument
async def test_instrumentor_with_async_connection_class(self):
PsycopgInstrumentor().instrument()

async def test_async_connection():
acnx = await psycopg.AsyncConnection.connect(database="test")
async with acnx as cnx:
await cnx.execute("SELECT * FROM test")

asyncio.run(test_async_connection())

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

# Check version and name in span's instrumentation info
self.assertEqualSpanInstrumentationInfo(
span, opentelemetry.instrumentation.psycopg
)

# check that no spans are generated after uninstrument
PsycopgInstrumentor().uninstrument()
asyncio.run(test_async_connection())

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

def test_span_name(self):
PsycopgInstrumentor().instrument()

Expand All @@ -301,33 +254,6 @@ def test_span_name(self):
self.assertEqual(spans_list[4].name, "query")
self.assertEqual(spans_list[5].name, "query")

async def test_span_name_async(self):
PsycopgInstrumentor().instrument()

cnx = psycopg.AsyncConnection.connect(database="test")
async with cnx.cursor() as cursor:
await cursor.execute("Test query", ("param1Value", False))
await cursor.execute(
"""multi
line
query"""
)
await cursor.execute("tab\tseparated query")
await cursor.execute("/* leading comment */ query")
await cursor.execute(
"/* leading comment */ query /* trailing comment */"
)
await cursor.execute("query /* trailing comment */")

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 6)
self.assertEqual(spans_list[0].name, "Test")
self.assertEqual(spans_list[1].name, "multi")
self.assertEqual(spans_list[2].name, "tab")
self.assertEqual(spans_list[3].name, "query")
self.assertEqual(spans_list[4].name, "query")
self.assertEqual(spans_list[5].name, "query")

# pylint: disable=unused-argument
def test_not_recording(self):
mock_tracer = mock.Mock()
Expand All @@ -348,26 +274,6 @@ def test_not_recording(self):

PsycopgInstrumentor().uninstrument()

# pylint: disable=unused-argument
async def test_not_recording_async(self):
mock_tracer = mock.Mock()
mock_span = mock.Mock()
mock_span.is_recording.return_value = False
mock_tracer.start_span.return_value = mock_span
PsycopgInstrumentor().instrument()
with mock.patch("opentelemetry.trace.get_tracer") as tracer:
tracer.return_value = mock_tracer
cnx = psycopg.AsyncConnection.connect(database="test")
async with cnx.cursor() as cursor:
query = "SELECT * FROM test"
cursor.execute(query)
self.assertFalse(mock_span.is_recording())
self.assertTrue(mock_span.is_recording.called)
self.assertFalse(mock_span.set_attribute.called)
self.assertFalse(mock_span.set_status.called)

PsycopgInstrumentor().uninstrument()

# pylint: disable=unused-argument
def test_custom_tracer_provider(self):
resource = resources.Resource.create({})
Expand Down Expand Up @@ -477,3 +383,108 @@ def test_sqlcommenter_disabled(self, event_mocked):
cursor.execute(query)
kwargs = event_mocked.call_args[1]
self.assertEqual(kwargs["enable_commenter"], False)


class TestPostgresqlIntegrationAsync(
PostgresqlIntegrationTestMixin, TestBase, IsolatedAsyncioTestCase
):
async def test_wrap_async_connection_class_with_cursor(self):
PsycopgInstrumentor().instrument()

async def test_async_connection():
acnx = await psycopg.AsyncConnection.connect("test")
async with acnx as cnx:
async with cnx.cursor() as cursor:
await cursor.execute("SELECT * FROM test")

await test_async_connection()
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

# Check version and name in span's instrumentation info
self.assertEqualSpanInstrumentationInfo(
span, opentelemetry.instrumentation.psycopg
)

# check that no spans are generated after uninstrument
PsycopgInstrumentor().uninstrument()

await test_async_connection()

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

# pylint: disable=unused-argument
async def test_instrumentor_with_async_connection_class(self):
PsycopgInstrumentor().instrument()

async def test_async_connection():
acnx = await psycopg.AsyncConnection.connect("test")
async with acnx as cnx:
await cnx.execute("SELECT * FROM test")

await test_async_connection()

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

# Check version and name in span's instrumentation info
self.assertEqualSpanInstrumentationInfo(
span, opentelemetry.instrumentation.psycopg
)

# check that no spans are generated after uninstrument
PsycopgInstrumentor().uninstrument()
await test_async_connection()

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)

async def test_span_name_async(self):
PsycopgInstrumentor().instrument()

cnx = await psycopg.AsyncConnection.connect("test")
async with cnx.cursor() as cursor:
await cursor.execute("Test query", ("param1Value", False))
await cursor.execute(
"""multi
line
query"""
)
await cursor.execute("tab\tseparated query")
await cursor.execute("/* leading comment */ query")
await cursor.execute(
"/* leading comment */ query /* trailing comment */"
)
await cursor.execute("query /* trailing comment */")

spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 6)
self.assertEqual(spans_list[0].name, "Test")
self.assertEqual(spans_list[1].name, "multi")
self.assertEqual(spans_list[2].name, "tab")
self.assertEqual(spans_list[3].name, "query")
self.assertEqual(spans_list[4].name, "query")
self.assertEqual(spans_list[5].name, "query")

# pylint: disable=unused-argument
async def test_not_recording_async(self):
mock_tracer = mock.Mock()
mock_span = mock.Mock()
mock_span.is_recording.return_value = False
mock_tracer.start_span.return_value = mock_span
PsycopgInstrumentor().instrument()
with mock.patch("opentelemetry.trace.get_tracer") as tracer:
tracer.return_value = mock_tracer
cnx = await psycopg.AsyncConnection.connect("test")
async with cnx.cursor() as cursor:
query = "SELECT * FROM test"
await cursor.execute(query)
self.assertFalse(mock_span.is_recording())
self.assertTrue(mock_span.is_recording.called)
self.assertFalse(mock_span.set_attribute.called)
self.assertFalse(mock_span.set_status.called)

PsycopgInstrumentor().uninstrument()

0 comments on commit 61f94fe

Please sign in to comment.