Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(connectors): BI-5975 Connection export #736

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ def secret_string_field(
required: bool = True,
allow_none: bool = False,
default: Optional[str] = None,
bi_extra: FieldExtra = FieldExtra(editable=True), # noqa: B008
bi_extra: FieldExtra = FieldExtra(editable=True, export_fake=True), # noqa: B008
) -> ma_fields.String:
return ma_fields.String(
attribute=attribute,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ class ClassicSQLConnectionSchema(ConnectionSchema):
host = DBHostField(attribute="data.host", required=True, bi_extra=FieldExtra(editable=True))
port = ma_fields.Integer(attribute="data.port", required=True, bi_extra=FieldExtra(editable=True))
username = ma_fields.String(attribute="data.username", required=True, bi_extra=FieldExtra(editable=True))
password = secret_string_field(attribute="data.password", bi_extra=FieldExtra(editable=True))
password = secret_string_field(attribute="data.password")
db_name = ma_fields.String(
attribute="data.db_name", allow_none=True, bi_extra=FieldExtra(editable=True), validate=db_name_no_query_params
)
Expand Down
5 changes: 5 additions & 0 deletions lib/dl_api_connector/dl_api_connector/api_schema/extras.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ class EditMode(OperationsMode):
test = enum.auto()


class ExportMode(OperationsMode):
export = enum.auto()


class SchemaKWArgs(TypedDict):
only: Optional[Sequence[str]]
partial: Union[Sequence[str], bool]
Expand All @@ -38,3 +42,4 @@ class FieldExtra:
partial_in: Sequence[OperationsMode] = ()
exclude_in: Sequence[OperationsMode] = ()
editable: Union[bool, Sequence[OperationsMode]] = ()
export_fake: Optional[bool] = False
18 changes: 18 additions & 0 deletions lib/dl_api_connector/dl_api_connector/api_schema/top_level.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import copy
from copy import deepcopy
import itertools
import logging
import os
Expand All @@ -19,6 +20,7 @@
import marshmallow
from marshmallow import (
missing,
post_dump,
post_load,
pre_load,
)
Expand All @@ -28,6 +30,7 @@
from dl_api_connector.api_schema.extras import (
CreateMode,
EditMode,
ExportMode,
FieldExtra,
OperationsMode,
SchemaKWArgs,
Expand Down Expand Up @@ -98,6 +101,13 @@ def all_fields_with_extra_info(cls) -> Iterable[tuple[str, ma_fields.Field, Fiel
if extra is not None:
yield field_name, field, extra

@classmethod
def fieldnames_with_extra_export_fake_info(cls) -> Iterable[str]:
for field_name, field in cls.all_fields_dict().items():
extra = cls.get_field_extra(field)
if extra is not None and extra.export_fake is True:
yield field_name

def _refine_init_kwargs(self, kw_args: SchemaKWArgs, operations_mode: Optional[OperationsMode]) -> SchemaKWArgs:
if operations_mode is None:
return kw_args
Expand Down Expand Up @@ -232,6 +242,14 @@ def pre_load(self, data: dict[str, Any], **_: Any) -> dict[str, Any]:
)
return self.handle_unknown_fields(data)

@post_dump(pass_many=False)
def post_dump(self, data: dict[str, Any], **_: Any) -> dict[str, Any]:
if isinstance(self.operations_mode, ExportMode):
data = deepcopy(data)
for secret_field in self.fieldnames_with_extra_export_fake_info():
data[secret_field] = "******"
return data


_US_ENTRY_TV = TypeVar("_US_ENTRY_TV", bound=USEntry)

Expand Down
16 changes: 16 additions & 0 deletions lib/dl_api_lib/dl_api_lib/app/control_api/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Generic,
Optional,
TypeVar,
final,
)

import attr
Expand Down Expand Up @@ -49,6 +50,12 @@
from dl_core.connection_models import ConnectOptions
from dl_core.us_connection_base import ConnectionBase

from dl_api_lib.app.control_api.resources.connections import (
BIResource,
ConnectionExportItem,
)
from dl_api_lib.app.control_api.resources.connections import ns as connections_namespace


@attr.s(frozen=True)
class EnvSetupResult:
Expand All @@ -62,6 +69,14 @@ class EnvSetupResult:
class ControlApiAppFactory(SRFactoryBuilder, Generic[TControlApiAppSettings], abc.ABC):
_settings: TControlApiAppSettings = attr.ib()

def get_connection_export_resource(self) -> type[BIResource]:
return ConnectionExportItem

@final
def register_additional_handlers(self) -> None:
connection_export_resource = self.get_connection_export_resource()
connections_namespace.add_resource(connection_export_resource, "/export/<connection_id>")

@abc.abstractmethod
def set_up_environment(
self,
Expand Down Expand Up @@ -159,6 +174,7 @@ def create_app(
ma = Marshmallow()
ma.init_app(app)

app.before_first_request(self.register_additional_handlers)
init_apis(app)

return app
22 changes: 22 additions & 0 deletions lib/dl_api_lib/dl_api_lib/app/control_api/resources/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from dl_api_connector.api_schema.extras import (
CreateMode,
EditMode,
ExportMode,
)
from dl_api_lib import exc
from dl_api_lib.api_decorators import schematic_request
Expand Down Expand Up @@ -205,6 +206,27 @@ def put(self, connection_id): # type: ignore # TODO: fix
us_manager.save(conn)


class ConnectionExportItem(BIResource):
@put_to_request_context(endpoint_code="ConnectionGet")
@schematic_request(
ns=ns,
responses={
# 200: ('Success', GetConnectionResponseSchema()),
},
)
def get(self, connection_id: str) -> dict:
conn = self.get_us_manager().get_by_id(connection_id, expected_type=ConnectionBase)
need_permission_on_entry(conn, USPermissionKind.read)
assert isinstance(conn, ConnectionBase)

if not conn.allow_export:
raise exc.UnsupportedForEntityType(f"Connector {conn.conn_type.name} does not support export")

result = GenericConnectionSchema(context=self.get_schema_ctx(ExportMode.export)).dump(conn)
result.update(options=ConnectionOptionsSchema().dump(conn.get_options()))
return result


def _dump_source_templates(tpls) -> dict: # type: ignore # TODO: fix
if tpls is None:
return None # type: ignore # TODO: fix
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@

from dl_api_client.dsmaker.api.http_sync_base import SyncHttpClientBase
from dl_api_lib_testing.connection_base import ConnectionTestBase
from dl_core.us_connection_base import ConnectionBase
from dl_core.us_manager.us_manager_sync import SyncUSManager
from dl_testing.regulated_test import RegulatedTestCase


Expand All @@ -23,6 +25,30 @@ def test_create_connection(
)
assert resp.status_code == 200, resp.json

def test_export_connection(
self,
control_api_sync_client: SyncHttpClientBase,
saved_connection_id: str,
bi_headers: Optional[dict[str, str]],
sync_us_manager: SyncUSManager,
) -> None:
conn = sync_us_manager.get_by_id(saved_connection_id, expected_type=ConnectionBase)
assert isinstance(conn, ConnectionBase)

resp = control_api_sync_client.get(
url=f"/api/v1/connections/export/{saved_connection_id}",
headers=bi_headers,
)

if not conn.allow_export:
assert resp.status_code == 400
return

assert resp.status_code == 200, resp.json
if hasattr(conn.data, "password"):
password = resp.json.get("password", None)
assert password == "******"

def test_test_connection(
self,
control_api_sync_client: SyncHttpClientBase,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ class CHYTConnectionSchema(ConnectionMetaMixin, RawSQLLevelMixin, DataExportForb

host = DBHostField(attribute="data.host", required=True, bi_extra=FieldExtra(editable=True))
port = ma.fields.Integer(attribute="data.port", required=True, bi_extra=FieldExtra(editable=True))
token = secret_string_field(attribute="data.token", required=True, bi_extra=FieldExtra(editable=True))
token = secret_string_field(attribute="data.token", required=True)
alias = alias_string_field(attribute="data.alias")
secure = ma.fields.Boolean(attribute="data.secure", bi_extra=FieldExtra(editable=True))
cache_ttl_sec = cache_ttl_field(attribute="data.cache_ttl_sec")
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ class ClickHouseConnectionSchema(
attribute="data.password",
required=False,
allow_none=True,
bi_extra=FieldExtra(editable=True),
)

secure = core_ma_fields.OnOffField(attribute="data.secure", bi_extra=FieldExtra(editable=True))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ class ConnectionClickhouse(ConnectionClickhouseBase):
allowed_source_types = frozenset((SOURCE_TYPE_CH_TABLE, SOURCE_TYPE_CH_SUBSELECT))
allow_dashsql: ClassVar[bool] = True
allow_cache: ClassVar[bool] = True
allow_export: ClassVar[bool] = True
is_always_user_source: ClassVar[bool] = False # TODO: should be `True`, but need some cleanup for that.

def get_data_source_template_templates(self, localizer: Localizer) -> list[DataSourceTemplate]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ class ConnectionMSSQL(ClassicConnectionSQL):
allowed_source_types = frozenset((SOURCE_TYPE_MSSQL_TABLE, SOURCE_TYPE_MSSQL_SUBSELECT))
allow_dashsql: ClassVar[bool] = True
allow_cache: ClassVar[bool] = True
allow_export: ClassVar[bool] = True
is_always_user_source: ClassVar[bool] = True

@attr.s(kw_only=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class ConnectionMySQL(ClassicConnectionSQL):
allowed_source_types = frozenset((SOURCE_TYPE_MYSQL_TABLE, SOURCE_TYPE_MYSQL_SUBSELECT))
allow_dashsql: ClassVar[bool] = True
allow_cache: ClassVar[bool] = True
allow_export: ClassVar[bool] = True
is_always_user_source: ClassVar[bool] = True

@attr.s(kw_only=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ class ConnectionSQLOracle(ClassicConnectionSQL):
allowed_source_types = frozenset((SOURCE_TYPE_ORACLE_TABLE, SOURCE_TYPE_ORACLE_SUBSELECT))
allow_dashsql: ClassVar[bool] = True
allow_cache: ClassVar[bool] = True
allow_export: ClassVar[bool] = True
is_always_user_source: ClassVar[bool] = True

@attr.s(kw_only=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
class ConnectionPostgreSQLBase(ClassicConnectionSQL):
has_schema = True
default_schema_name = "public"
allow_export = True

@attr.s(kw_only=True)
class DataModel(ClassicConnectionSQL.DataModel):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,6 @@ class PromQLConnectionSchema(ConnectionMetaMixin, ClassicSQLConnectionSchema):
attribute="data.password",
required=False,
allow_none=True,
bi_extra=FieldExtra(editable=True),
)
path = DBPathField(
attribute="data.path",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ class PromQLConnection(ClassicConnectionSQL):
allow_cache: ClassVar[bool] = True
is_always_user_source: ClassVar[bool] = True
allow_dashsql: ClassVar[bool] = True
allow_export: ClassVar[bool] = True
source_type = SOURCE_TYPE_PROMQL

@attr.s(kw_only=True)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,6 @@ class SnowFlakeConnectionSchema(ConnectionSchema, RawSQLLevelMixin):
client_secret = secret_string_field(
attribute="data.client_secret",
required=True,
bi_extra=FieldExtra(editable=True),
)
schema = ma_fields.String(
attribute="data.schema",
Expand All @@ -55,7 +54,6 @@ class SnowFlakeConnectionSchema(ConnectionSchema, RawSQLLevelMixin):
refresh_token = secret_string_field(
attribute="data.refresh_token",
required=False,
bi_extra=FieldExtra(editable=True),
)
refresh_token_expire_time = ma_fields.DateTime(
attribute="data.refresh_token_expire_time",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ class YDBConnection(ClassicConnectionSQL):
allow_cache: ClassVar[bool] = True
is_always_user_source: ClassVar[bool] = True
allow_dashsql: ClassVar[bool] = True
allow_export: ClassVar[bool] = True

source_type = SOURCE_TYPE_YDB_TABLE

Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,35 @@
from typing import Optional

from dl_api_client.dsmaker.api.http_sync_base import SyncHttpClientBase
from dl_api_lib_testing.connector.connection_suite import DefaultConnectorConnectionTestSuite
from dl_core.us_connection_base import ConnectionBase
from dl_core.us_manager.us_manager_sync import SyncUSManager

from dl_connector_ydb_tests.db.api.base import YDBConnectionTestBase


class TestYDBConnection(YDBConnectionTestBase, DefaultConnectorConnectionTestSuite):
pass
# a separate test since password=self.data.token
def test_export_connection(
self,
control_api_sync_client: SyncHttpClientBase,
saved_connection_id: str,
bi_headers: Optional[dict[str, str]],
sync_us_manager: SyncUSManager,
) -> None:
conn = sync_us_manager.get_by_id(saved_connection_id, expected_type=ConnectionBase)
assert isinstance(conn, ConnectionBase)

resp = control_api_sync_client.get(
url=f"/api/v1/connections/export/{saved_connection_id}",
headers=bi_headers,
)

if not conn.allow_export:
assert resp.status_code == 400
return

assert resp.status_code == 200, resp.json
if hasattr(conn.data, "token"):
token = resp.json.get("token", None)
assert token == "******"
1 change: 1 addition & 0 deletions lib/dl_core/dl_core/us_connection_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ class ConnectionBase(USEntry, metaclass=abc.ABCMeta):
allowed_source_types: ClassVar[Optional[frozenset[DataSourceType]]] = None
allow_dashsql: ClassVar[bool] = False
allow_cache: ClassVar[bool] = False
allow_export: ClassVar[bool] = False
is_always_internal_source: ClassVar[bool] = False
is_always_user_source: ClassVar[bool] = False

Expand Down
Loading