Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

test: create base class for mockserver tests #1255

Merged
merged 5 commits into from
Dec 5, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
139 changes: 139 additions & 0 deletions tests/mockserver_tests/mock_server_test_base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright 2024 Google LLC All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer
from google.cloud.spanner_v1.testing.mock_spanner import (
start_mock_server,
SpannerServicer,
)
import google.cloud.spanner_v1.types.type as spanner_type
import google.cloud.spanner_v1.types.result_set as result_set
from google.api_core.client_options import ClientOptions
from google.auth.credentials import AnonymousCredentials
from google.cloud.spanner_v1 import Client, TypeCode, FixedSizePool
from google.cloud.spanner_v1.database import Database
from google.cloud.spanner_v1.instance import Instance
import grpc


def add_result(sql: str, result: result_set.ResultSet):
MockServerTestBase.spanner_service.mock_spanner.add_result(sql, result)


def add_update_count(
sql: str, count: int, dml_mode: AutocommitDmlMode = AutocommitDmlMode.TRANSACTIONAL
):
if dml_mode == AutocommitDmlMode.PARTITIONED_NON_ATOMIC:
stats = dict(row_count_lower_bound=count)
else:
stats = dict(row_count_exact=count)
result = result_set.ResultSet(dict(stats=result_set.ResultSetStats(stats)))
add_result(sql, result)


def add_select1_result():
add_single_result("select 1", "c", TypeCode.INT64, [("1",)])


def add_single_result(
sql: str, column_name: str, type_code: spanner_type.TypeCode, row
):
result = result_set.ResultSet(
dict(
metadata=result_set.ResultSetMetadata(
dict(
row_type=spanner_type.StructType(
dict(
fields=[
spanner_type.StructType.Field(
dict(
name=column_name,
type=spanner_type.Type(dict(code=type_code)),
)
)
]
)
)
)
),
)
)
result.rows.extend(row)
MockServerTestBase.spanner_service.mock_spanner.add_result(sql, result)


class MockServerTestBase(unittest.TestCase):
server: grpc.Server = None
spanner_service: SpannerServicer = None
database_admin_service: DatabaseAdminServicer = None
port: int = None

def __init__(self, *args, **kwargs):
super(MockServerTestBase, self).__init__(*args, **kwargs)
self._client = None
self._instance = None
self._database = None

@classmethod
def setup_class(cls):
(
MockServerTestBase.server,
MockServerTestBase.spanner_service,
MockServerTestBase.database_admin_service,
MockServerTestBase.port,
) = start_mock_server()

@classmethod
def teardown_class(cls):
if MockServerTestBase.server is not None:
MockServerTestBase.server.stop(grace=None)
MockServerTestBase.server = None

def setup_method(self, *args, **kwargs):
self._client = None
self._instance = None
self._database = None

def teardown_method(self, *args, **kwargs):
MockServerTestBase.spanner_service.clear_requests()
MockServerTestBase.database_admin_service.clear_requests()

@property
def client(self) -> Client:
if self._client is None:
self._client = Client(
project="p",
credentials=AnonymousCredentials(),
client_options=ClientOptions(
api_endpoint="localhost:" + str(MockServerTestBase.port),
),
)
return self._client

@property
def instance(self) -> Instance:
if self._instance is None:
self._instance = self.client.instance("test-instance")
return self._instance

@property
def database(self) -> Database:
if self._database is None:
self._database = self.instance.database(
"test-database", pool=FixedSizePool(size=10)
)
return self._database
121 changes: 8 additions & 113 deletions tests/mockserver_tests/test_basics.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,131 +12,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

from google.cloud.spanner_admin_database_v1.types import spanner_database_admin
from google.cloud.spanner_dbapi import Connection
from google.cloud.spanner_dbapi.parsed_statement import AutocommitDmlMode
from google.cloud.spanner_v1.testing.mock_database_admin import DatabaseAdminServicer
from google.cloud.spanner_v1.testing.mock_spanner import (
start_mock_server,
SpannerServicer,
)
import google.cloud.spanner_v1.types.type as spanner_type
import google.cloud.spanner_v1.types.result_set as result_set
from google.api_core.client_options import ClientOptions
from google.auth.credentials import AnonymousCredentials
from google.cloud.spanner_v1 import (
Client,
FixedSizePool,
BatchCreateSessionsRequest,
ExecuteSqlRequest,
BeginTransactionRequest,
TransactionOptions,
)
from google.cloud.spanner_v1.database import Database
from google.cloud.spanner_v1.instance import Instance
import grpc


class TestBasics(unittest.TestCase):
server: grpc.Server = None
spanner_service: SpannerServicer = None
database_admin_service: DatabaseAdminServicer = None
port: int = None

def __init__(self, *args, **kwargs):
super(TestBasics, self).__init__(*args, **kwargs)
self._client = None
self._instance = None
self._database = None

@classmethod
def setUpClass(cls):
(
TestBasics.server,
TestBasics.spanner_service,
TestBasics.database_admin_service,
TestBasics.port,
) = start_mock_server()

@classmethod
def tearDownClass(cls):
if TestBasics.server is not None:
TestBasics.server.stop(grace=None)
TestBasics.server = None

def teardown_method(self, *args, **kwargs):
TestBasics.spanner_service.clear_requests()
TestBasics.database_admin_service.clear_requests()

def _add_select1_result(self):
result = result_set.ResultSet(
dict(
metadata=result_set.ResultSetMetadata(
dict(
row_type=spanner_type.StructType(
dict(
fields=[
spanner_type.StructType.Field(
dict(
name="c",
type=spanner_type.Type(
dict(code=spanner_type.TypeCode.INT64)
),
)
)
]
)
)
)
),
)
)
result.rows.extend(["1"])
TestBasics.spanner_service.mock_spanner.add_result("select 1", result)

def add_update_count(
self,
sql: str,
count: int,
dml_mode: AutocommitDmlMode = AutocommitDmlMode.TRANSACTIONAL,
):
if dml_mode == AutocommitDmlMode.PARTITIONED_NON_ATOMIC:
stats = dict(row_count_lower_bound=count)
else:
stats = dict(row_count_exact=count)
result = result_set.ResultSet(dict(stats=result_set.ResultSetStats(stats)))
TestBasics.spanner_service.mock_spanner.add_result(sql, result)

@property
def client(self) -> Client:
if self._client is None:
self._client = Client(
project="test-project",
credentials=AnonymousCredentials(),
client_options=ClientOptions(
api_endpoint="localhost:" + str(TestBasics.port),
),
)
return self._client

@property
def instance(self) -> Instance:
if self._instance is None:
self._instance = self.client.instance("test-instance")
return self._instance
from tests.mockserver_tests.mock_server_test_base import (
MockServerTestBase,
add_select1_result,
add_update_count,
)

@property
def database(self) -> Database:
if self._database is None:
self._database = self.instance.database(
"test-database", pool=FixedSizePool(size=10)
)
return self._database

class TestBasics(MockServerTestBase):
def test_select1(self):
self._add_select1_result()
add_select1_result()
with self.database.snapshot() as snapshot:
results = snapshot.execute_sql("select 1")
result_list = []
Expand Down Expand Up @@ -171,7 +66,7 @@ def test_create_table(self):
# been re-factored to use a base class for the boiler plate code.
def test_dbapi_partitioned_dml(self):
sql = "UPDATE singers SET foo='bar' WHERE active = true"
self.add_update_count(sql, 100, AutocommitDmlMode.PARTITIONED_NON_ATOMIC)
add_update_count(sql, 100, AutocommitDmlMode.PARTITIONED_NON_ATOMIC)
connection = Connection(self.instance, self.database)
connection.autocommit = True
connection.set_autocommit_dml_mode(AutocommitDmlMode.PARTITIONED_NON_ATOMIC)
Expand Down
Loading