Skip to content

Commit

Permalink
feat: saved queries in query and compile_sql (#45)
Browse files Browse the repository at this point in the history
* feat: `query` and `compile_sql` with saved queries

This commit adds `saved_query` to `query` and `compile_sql`, so it's now
possible to use them with saved queries. I added some method overloads
so that typecheckers can detect invalid usage of them, and I also
created a method which validates at runtime whether the parameters are
valid before submitting them to the servers.

* docs: update changelog

* fixup! feat: `query` and `compile_sql` with saved queries

* fixup! feat: `query` and `compile_sql` with saved queries
  • Loading branch information
serramatutu authored Sep 20, 2024
1 parent d53a6e1 commit 53aa90e
Show file tree
Hide file tree
Showing 13 changed files with 289 additions and 34 deletions.
3 changes: 3 additions & 0 deletions .changes/unreleased/Features-20240920-201139.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Features
body: Allow saved queries in `query` and `compile_sql`
time: 2024-09-20T20:11:39.216931+02:00
3 changes: 3 additions & 0 deletions .changes/unreleased/Under the Hood-20240920-201151.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
kind: Under the Hood
body: Client-side validation of query parameters
time: 2024-09-20T20:11:51.575942+02:00
6 changes: 2 additions & 4 deletions dbtsl/api/adbc/protocol.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
import json
from typing import Any, FrozenSet, Mapping

from dbtsl.api.shared.query_params import (
DimensionValuesQueryParameters,
QueryParameters,
)
from dbtsl.api.shared.query_params import DimensionValuesQueryParameters, QueryParameters, validate_query_parameters


class ADBCProtocol:
Expand Down Expand Up @@ -36,6 +33,7 @@ def append_param_if_exists(p_str: str, p_name: str) -> str:
@classmethod
def get_query_sql(cls, params: QueryParameters) -> str:
    """Build the SQL statement sent to the server via Arrow Flight for the given query parameters.

    The parameters are passed through `validate_query_parameters` first, so
    any error it raises propagates before the SQL string is built.
    """
    validate_query_parameters(params)
    optional_keys = QueryParameters.__optional_keys__
    rendered_args = cls._serialize_params_dict(params, optional_keys)
    return f"SELECT * FROM {{{{ semantic_layer.query({rendered_args}) }}}}"

Expand Down
42 changes: 40 additions & 2 deletions dbtsl/api/graphql/client/asyncio.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ from contextlib import AbstractAsyncContextManager
from typing import List, Optional, Self

import pyarrow as pa
from typing_extensions import AsyncIterator, Unpack
from typing_extensions import AsyncIterator, Unpack, overload

from dbtsl.api.shared.query_params import QueryParameters
from dbtsl.models import (
Expand Down Expand Up @@ -44,10 +44,48 @@ class AsyncGraphQLClient:
"""Get a list of all available saved queries."""
...

async def compile_sql(self, **params: Unpack[QueryParameters]) -> str:
# Overload 1: ad-hoc query described by `metrics`, optionally with `group_by`.
@overload
async def compile_sql(
self,
metrics: List[str],
group_by: Optional[List[str]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> str: ...
# Overload 2: query referenced by `saved_query`; this shape has no `metrics`/`group_by`.
@overload
async def compile_sql(
self,
saved_query: str,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> str: ...
# Catch-all signature: keyword arguments typed by the keys of `QueryParameters`.
async def compile_sql(self, **query_params: Unpack[QueryParameters]) -> str:
"""Get the compiled SQL that would be sent to the warehouse by a query."""
...

# Overload 1: ad-hoc query described by `metrics`, optionally with `group_by`.
@overload
async def query(
self,
metrics: List[str],
group_by: Optional[List[str]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> "pa.Table": ...
# Overload 2: query referenced by `saved_query`; this shape has no `metrics`/`group_by`.
@overload
async def query(
self,
saved_query: str,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> "pa.Table": ...
# Catch-all signature: keyword arguments typed by the keys of `QueryParameters`.
async def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
"""Query the Semantic Layer."""
...
44 changes: 41 additions & 3 deletions dbtsl/api/graphql/client/sync.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ from contextlib import AbstractContextManager
from typing import Iterator, List, Optional

import pyarrow as pa
from typing_extensions import Self, Unpack
from typing_extensions import Self, Unpack, overload

from dbtsl.api.shared.query_params import QueryParameters
from dbtsl.models import (
Expand Down Expand Up @@ -44,10 +44,48 @@ class SyncGraphQLClient:
"""Get a list of all available saved queries."""
...

def compile_sql(self, **params: Unpack[QueryParameters]) -> str:
# Overload 1: ad-hoc query described by `metrics`, optionally with `group_by`.
@overload
def compile_sql(
self,
metrics: List[str],
group_by: Optional[List[str]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> str: ...
# Overload 2: query referenced by `saved_query`; this shape has no `metrics`/`group_by`.
@overload
def compile_sql(
self,
saved_query: str,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> str: ...
# Catch-all signature: keyword arguments typed by the keys of `QueryParameters`.
def compile_sql(self, **query_params: Unpack[QueryParameters]) -> str:
"""Get the compiled SQL that would be sent to the warehouse by a query."""
...

def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
# Overload 1: ad-hoc query described by `metrics`, optionally with `group_by`.
@overload
def query(
    self,
    metrics: List[str],
    group_by: Optional[List[str]] = None,
    limit: Optional[int] = None,
    order_by: Optional[List[str]] = None,
    where: Optional[List[str]] = None,
    read_cache: bool = True,
) -> "pa.Table": ...
# Overload 2: query referenced by `saved_query`; this shape has no `metrics`/`group_by`.
@overload
def query(
    self,
    saved_query: str,
    limit: Optional[int] = None,
    order_by: Optional[List[str]] = None,
    where: Optional[List[str]] = None,
    read_cache: bool = True,
) -> "pa.Table": ...
# Catch-all signature. It must be a plain `def`, like its overloads: this is
# the synchronous client, so the call returns the table directly rather than
# a coroutine.
def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
    """Query the Semantic Layer.

    Accepts the keys of `QueryParameters` as keyword arguments, in one of
    the two shapes declared by the overloads: `metrics` (optionally with
    `group_by`) or `saved_query`.

    Returns:
        The query result as a `pa.Table`.
    """
    ...
28 changes: 19 additions & 9 deletions dbtsl/api/graphql/protocol.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing_extensions import NotRequired, override

from dbtsl.api.graphql.util import render_query
from dbtsl.api.shared.query_params import QueryParameters
from dbtsl.api.shared.query_params import QueryParameters, validate_query_parameters
from dbtsl.models import Dimension, Entity, Measure, Metric
from dbtsl.models.query import QueryId, QueryResult, QueryStatus
from dbtsl.models.saved_query import SavedQuery
Expand Down Expand Up @@ -200,15 +200,17 @@ def get_request_text(self) -> str:
query = """
mutation createQuery(
$environmentId: BigInt!,
$metrics: [MetricInput!]!,
$groupBy: [GroupByInput!]!,
$savedQuery: String,
$metrics: [MetricInput!],
$groupBy: [GroupByInput!],
$where: [WhereInput!]!,
$orderBy: [OrderByInput!]!,
$limit: Int,
$readCache: Boolean,
) {
createQuery(
environmentId: $environmentId,
savedQuery: $savedQuery,
metrics: $metrics,
groupBy: $groupBy,
where: $where,
Expand All @@ -224,10 +226,13 @@ def get_request_text(self) -> str:

@override
def get_request_variables(self, environment_id: int, **kwargs: QueryParameters) -> Dict[str, Any]:
# TODO: fix typing
validate_query_parameters(kwargs) # type: ignore
return {
"environmentId": environment_id,
"metrics": [{"name": m} for m in kwargs.get("metrics", [])],
"groupBy": [{"name": g} for g in kwargs.get("group_by", [])],
"savedQuery": kwargs.get("saved_query", None),
"metrics": [{"name": m} for m in kwargs["metrics"]] if "metrics" in kwargs else None,
"groupBy": [{"name": g} for g in kwargs["group_by"]] if "group_by" in kwargs else None,
"where": [{"sql": sql} for sql in kwargs.get("where", [])],
"orderBy": [{"name": o} for o in kwargs.get("order_by", [])],
"limit": kwargs.get("limit", None),
Expand Down Expand Up @@ -285,15 +290,17 @@ def get_request_text(self) -> str:
query = """
mutation compileSql(
$environmentId: BigInt!,
$metrics: [MetricInput!]!,
$groupBy: [GroupByInput!]!,
$savedQuery: String,
$metrics: [MetricInput!],
$groupBy: [GroupByInput!],
$where: [WhereInput!]!,
$orderBy: [OrderByInput!]!,
$limit: Int,
$readCache: Boolean,
) {
compileSql(
environmentId: $environmentId,
savedQuery: $savedQuery,
metrics: $metrics,
groupBy: $groupBy,
where: $where,
Expand All @@ -309,10 +316,13 @@ def get_request_text(self) -> str:

@override
def get_request_variables(self, environment_id: int, **kwargs: QueryParameters) -> Dict[str, Any]:
# TODO: fix typing
validate_query_parameters(kwargs) # type: ignore
return {
"environmentId": environment_id,
"metrics": [{"name": m} for m in kwargs.get("metrics", [])],
"groupBy": [{"name": g} for g in kwargs.get("group_by", [])],
"savedQuery": kwargs.get("saved_query", None),
"metrics": [{"name": m} for m in kwargs["metrics"]] if "metrics" in kwargs else None,
"groupBy": [{"name": g} for g in kwargs["group_by"]] if "group_by" in kwargs else None,
"where": [{"sql": sql} for sql in kwargs.get("where", [])],
"orderBy": [{"name": o} for o in kwargs.get("order_by", [])],
"limit": kwargs.get("limit", None),
Expand Down
24 changes: 23 additions & 1 deletion dbtsl/api/shared/query_params.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,12 @@


class QueryParameters(TypedDict, total=False):
"""The parameters of `semantic_layer.query`."""
"""The parameters of `semantic_layer.query`.
metrics/group_by and saved_query are mutually exclusive.
"""

saved_query: str
metrics: List[str]
group_by: List[str]
limit: int
Expand All @@ -12,6 +16,24 @@ class QueryParameters(TypedDict, total=False):
read_cache: bool


def validate_query_parameters(params: QueryParameters) -> None:
    """Validate a dict that should be QueryParameters.

    A key whose value is `None` is treated the same as a missing key. The
    client stubs declare `group_by: Optional[List[str]] = None`, so an
    explicit `group_by=None` must not reach the length checks, where
    `len(None)` would raise an unrelated `TypeError`.

    Args:
        params: the query parameters to check.

    Raises:
        ValueError: if `saved_query` is combined with `metrics` or
            `group_by`, or if `metrics` or `group_by` is an empty list.
    """
    saved_query = params.get("saved_query")
    metrics = params.get("metrics")
    group_by = params.get("group_by")

    is_saved_query = saved_query is not None
    is_adhoc_query = metrics is not None or group_by is not None
    if is_saved_query and is_adhoc_query:
        raise ValueError(
            "metrics/group_by and saved_query are mutually exclusive, "
            "since, by definition, saved queries already include "
            "metrics and group_by."
        )

    if metrics is not None and len(metrics) == 0:
        raise ValueError("You need to specify at least one metric.")

    if group_by is not None and len(group_by) == 0:
        raise ValueError("You need to specify at least one dimension to group by.")


class DimensionValuesQueryParameters(TypedDict, total=False):
"""The parameters of `semantic_layer.dimension_values`."""

Expand Down
46 changes: 42 additions & 4 deletions dbtsl/client/asyncio.pyi
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from contextlib import AbstractAsyncContextManager
from typing import AsyncIterator, List
from typing import AsyncIterator, List, Optional

import pyarrow as pa
from typing_extensions import Self, Unpack
from typing_extensions import Self, Unpack, overload

from dbtsl.api.adbc.protocol import QueryParameters
from dbtsl.models import Dimension, Entity, Measure, Metric, SavedQuery
Expand All @@ -14,12 +14,50 @@ class AsyncSemanticLayerClient:
auth_token: str,
host: str,
) -> None: ...
# Overload 1: ad-hoc query described by `metrics`, optionally with `group_by`.
@overload
async def compile_sql(
self,
metrics: List[str],
group_by: Optional[List[str]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> str: ...
# Overload 2: query referenced by `saved_query`; this shape has no `metrics`/`group_by`.
@overload
async def compile_sql(
self,
saved_query: str,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> str: ...
# Catch-all signature: keyword arguments typed by the keys of `QueryParameters`.
async def compile_sql(self, **query_params: Unpack[QueryParameters]) -> str:
"""Get the compiled SQL that would be sent to the warehouse by a query."""
...

async def query(self, **query_params: Unpack[QueryParameters]) -> "pa.Table":
"""Query the Semantic Layer for a metric data."""
# Overload 1: ad-hoc query described by `metrics`, optionally with `group_by`.
@overload
async def query(
self,
metrics: List[str],
group_by: Optional[List[str]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> "pa.Table": ...
# Overload 2: query referenced by `saved_query`; this shape has no `metrics`/`group_by`.
@overload
async def query(
self,
saved_query: str,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> "pa.Table": ...
# Catch-all signature: keyword arguments typed by the keys of `QueryParameters`.
async def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
"""Query the Semantic Layer."""
...

async def metrics(self) -> List[Metric]:
Expand Down
46 changes: 42 additions & 4 deletions dbtsl/client/sync.pyi
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from contextlib import AbstractContextManager
from typing import Iterator, List
from typing import Iterator, List, Optional

import pyarrow as pa
from typing_extensions import Self, Unpack
from typing_extensions import Self, Unpack, overload

from dbtsl.api.adbc.protocol import QueryParameters
from dbtsl.models import Dimension, Entity, Measure, Metric, SavedQuery
Expand All @@ -14,12 +14,50 @@ class SyncSemanticLayerClient:
auth_token: str,
host: str,
) -> None: ...
# Overload 1: ad-hoc query described by `metrics`, optionally with `group_by`.
@overload
def compile_sql(
self,
metrics: List[str],
group_by: Optional[List[str]] = None,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> str: ...
# Overload 2: query referenced by `saved_query`; this shape has no `metrics`/`group_by`.
@overload
def compile_sql(
self,
saved_query: str,
limit: Optional[int] = None,
order_by: Optional[List[str]] = None,
where: Optional[List[str]] = None,
read_cache: bool = True,
) -> str: ...
# Catch-all signature: keyword arguments typed by the keys of `QueryParameters`.
def compile_sql(self, **query_params: Unpack[QueryParameters]) -> str:
"""Get the compiled SQL that would be sent to the warehouse by a query."""
...

def query(self, **query_params: Unpack[QueryParameters]) -> "pa.Table":
"""Query the Semantic Layer for a metric data."""
# Overload 1: ad-hoc query described by `metrics`, optionally with `group_by`.
@overload
def query(
    self,
    metrics: List[str],
    group_by: Optional[List[str]] = None,
    limit: Optional[int] = None,
    order_by: Optional[List[str]] = None,
    where: Optional[List[str]] = None,
    read_cache: bool = True,
) -> "pa.Table": ...
# Overload 2: query referenced by `saved_query`; this shape has no `metrics`/`group_by`.
@overload
def query(
    self,
    saved_query: str,
    limit: Optional[int] = None,
    order_by: Optional[List[str]] = None,
    where: Optional[List[str]] = None,
    read_cache: bool = True,
) -> "pa.Table": ...
# Catch-all signature. It must be a plain `def`, like its overloads: this is
# the synchronous client, so the call returns the table directly rather than
# a coroutine.
def query(self, **params: Unpack[QueryParameters]) -> "pa.Table":
    """Query the Semantic Layer.

    Accepts the keys of `QueryParameters` as keyword arguments, in one of
    the two shapes declared by the overloads: `metrics` (optionally with
    `group_by`) or `saved_query`.

    Returns:
        The query result as a `pa.Table`.
    """
    ...

def metrics(self) -> List[Metric]:
Expand Down
Loading

0 comments on commit 53aa90e

Please sign in to comment.