Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add QueryParams type #273

Merged
merged 2 commits into from
Jun 24, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 7 additions & 7 deletions mreg_cli/api/abstracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
PatchError,
)
from mreg_cli.outputmanager import OutputManager
from mreg_cli.types import JsonMapping
from mreg_cli.types import JsonMapping, QueryParams
from mreg_cli.utilities.api import (
delete,
get,
Expand Down Expand Up @@ -308,15 +308,15 @@ def get_list_by_field(

:returns: A list of objects if found, an empty list otherwise.
"""
params = {field: value}
params: QueryParams = {field: value}
if ordering:
params["ordering"] = ordering

return get_typed(cls.endpoint(), list[cls], params=params, limit=limit)

@classmethod
def get_by_query(
cls, query: dict[str, str], ordering: str | None = None, limit: int | None = 500
cls, query: QueryParams, ordering: str | None = None, limit: int | None = 500
) -> list[Self]:
"""Get a list of objects by a query.

Expand All @@ -329,12 +329,12 @@ def get_by_query(
if ordering:
query["ordering"] = ordering

return get_typed(cls.endpoint().with_query(query), list[cls], limit=limit)
return get_typed(cls.endpoint(), list[cls], query, limit=limit)

@classmethod
def get_by_query_unique_or_raise(
cls,
query: dict[str, str],
query: QueryParams,
exc_type: type[Exception] = EntityNotFound,
exc_message: str | None = None,
) -> Self:
Expand All @@ -358,7 +358,7 @@ def get_by_query_unique_or_raise(
@classmethod
def get_by_query_unique_and_raise(
cls,
query: dict[str, str],
query: QueryParams,
exc_type: type[Exception] = EntityAlreadyExists,
exc_message: str | None = None,
) -> None:
Expand All @@ -380,7 +380,7 @@ def get_by_query_unique_and_raise(
return None

@classmethod
def get_by_query_unique(cls, data: dict[str, str]) -> Self | None:
def get_by_query_unique(cls, data: QueryParams) -> Self | None:
"""Get an object with the given data.

:param data: The data to search for.
Expand Down
9 changes: 0 additions & 9 deletions mreg_cli/api/endpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,12 +126,3 @@ def with_params(self, *params: str | int) -> str:
)
encoded_params = (quote(str(param)) for param in params)
return self.value.format(*encoded_params)

def with_query(self, query: dict[str, str]) -> str:
"""Construct and return an endpoint URL with a query string.

:param query: A dictionary of query parameters.
:returns: A fully constructed endpoint URL with a query string.
"""
query_string = "&".join(f"{quote(key)}={quote(value)}" for key, value in query.items())
return f"{self.value}?{query_string}"
3 changes: 2 additions & 1 deletion mreg_cli/api/history.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from mreg_cli.api.endpoints import Endpoint
from mreg_cli.exceptions import EntityNotFound, InternalError
from mreg_cli.outputmanager import OutputManager
from mreg_cli.types import QueryParams
from mreg_cli.utilities.api import get_typed


Expand Down Expand Up @@ -125,7 +126,7 @@ def output_multiple(cls, basename: str, items: list[HistoryItem]) -> None:
@classmethod
def get(cls, name: str, resource: HistoryResource) -> list[Self]:
"""Get history items for a resource."""
params: dict[str, str] = {"resource": resource.resource(), "name": name}
params: QueryParams = {"resource": resource.resource(), "name": name}
ret = get_typed(Endpoint.History, list[cls], params=params)
if len(ret) == 0:
raise EntityNotFound(f"No history found for {name}")
Expand Down
8 changes: 4 additions & 4 deletions mreg_cli/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
ValidationError,
)
from mreg_cli.outputmanager import OutputManager
from mreg_cli.types import IP_AddressT, IP_NetworkT, IP_Version
from mreg_cli.types import IP_AddressT, IP_NetworkT, IP_Version, QueryParams
from mreg_cli.utilities.api import (
delete,
get,
Expand Down Expand Up @@ -717,7 +717,7 @@ def update_soa(
:param expire: The expire interval for the zone.
:param soa_ttl: The TTL for the zone.
"""
params: dict[str, str | int | None] = {
params: QueryParams = {
"primary_ns": primary_ns,
"email": email,
"serialno": serialno,
Expand Down Expand Up @@ -2691,7 +2691,7 @@ def add_ip(self, ip: IP_AddressT, mac: MACAddressField | None = None) -> Host:

:returns: A new Host object fetched from the API with the updated IP address.
"""
params: dict[str, str | None] = {"ipaddress": str(ip), "host": str(self.id)}
params: QueryParams = {"ipaddress": str(ip), "host": str(self.id)}
if mac:
params["macaddress"] = mac.address

Expand Down Expand Up @@ -3109,7 +3109,7 @@ def endpoint(cls) -> Endpoint:
return Endpoint.Hosts

@classmethod
def get(cls, params: dict[str, Any] | None = None) -> HostList:
def get(cls, params: QueryParams | None = None) -> HostList:
"""Get a list of hosts.

:param params: Optional parameters to pass to the API.
Expand Down
6 changes: 3 additions & 3 deletions mreg_cli/commands/host_submodules/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
PatchError,
)
from mreg_cli.outputmanager import OutputManager
from mreg_cli.types import Flag
from mreg_cli.types import Flag, JsonMapping, QueryParams
from mreg_cli.utilities.shared import convert_wildcard_to_regex


Expand Down Expand Up @@ -118,7 +118,7 @@ def add(args: argparse.Namespace) -> None:
except ValueError as e:
raise InputFailure(f"invalid MAC address: {macaddress}") from e

data: dict[str, str | None] = {
data: JsonMapping = {
"name": hname.hostname,
"contact": args.contact or None,
"comment": args.comment or None,
Expand Down Expand Up @@ -377,7 +377,7 @@ def _add_param(param: str, value: str) -> None:
if not any([args.name, args.comment, args.contact]):
raise InputFailure("Need at least one search critera")

params: dict[str, str | int] = {
params: QueryParams = {
"ordering": "name",
}

Expand Down
28 changes: 10 additions & 18 deletions mreg_cli/commands/host_submodules/rr.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@
PatchError,
)
from mreg_cli.outputmanager import OutputManager
from mreg_cli.types import Flag
from mreg_cli.types import Flag, QueryParams


@command_registry.register_command(
Expand Down Expand Up @@ -340,7 +340,7 @@ def naptr_add(args: argparse.Namespace) -> None:
:param args: argparse.Namespace (name, preference, order, flag, service, regex, replacement)
"""
host = Host.get_by_any_means_or_raise(args.name)
arg_data: dict[str, str | int] = {
params: QueryParams = {
"preference": args.preference,
"order": args.order,
"flag": args.flag,
Expand All @@ -349,15 +349,10 @@ def naptr_add(args: argparse.Namespace) -> None:
"replacement": args.replacement,
"host": host.id,
}

# Query parameters must be strings...
search_data: dict[str, str] = {k: str(v) for k, v in arg_data.items()}

existing_naptr = NAPTR.get_by_query_unique(search_data)
existing_naptr = NAPTR.get_by_query_unique(params)
if existing_naptr:
raise EntityAlreadyExists(f"{host} already has that NAPTR defined.")

NAPTR.create(params=arg_data)
NAPTR.create(params=params)
OutputManager().add_ok(f"Added NAPTR record to {host.name.hostname}.")


Expand Down Expand Up @@ -622,7 +617,7 @@ def srv_add(args: argparse.Namespace) -> None:
if not hzone:
raise EntityNotFound(f"{host} isn't in a zone controlled by MREG")

data: dict[str, str] = {
data: QueryParams = {
"name": sname.hostname,
"priority": args.priority,
"weight": args.weight,
Expand All @@ -634,10 +629,7 @@ def srv_add(args: argparse.Namespace) -> None:
existing_srv = Srv.get_by_query_unique(data)
if existing_srv:
raise EntityAlreadyExists(f"{sname} already has that SRV defined.")

arg_data: dict[str, str | None] = {k: v for k, v in data.items()}

Srv.create(arg_data)
Srv.create(data)
OutputManager().add_ok(f"Added SRV record {sname} with target {host}.")


Expand Down Expand Up @@ -679,7 +671,7 @@ def srv_remove(args: argparse.Namespace) -> None:
host = Host.get_by_any_means_or_raise(args.host)
sname = HostT(hostname=args.name)

data: dict[str, str] = {
data: QueryParams = {
"name": sname.hostname,
"host": str(host.id),
"priority": args.priority,
Expand Down Expand Up @@ -739,18 +731,18 @@ def sshfp_add(args: argparse.Namespace) -> None:
"""
host = Host.get_by_any_means_or_raise(args.name)

data: dict[str, str] = {
data: QueryParams = {
"algorithm": args.algorithm,
"hash_type": args.hash_type,
"fingerprint": args.fingerprint,
"host": str(host.id),
"host": host.id,
}

existing_sshfp = SSHFP.get_by_query_unique(data)
if existing_sshfp:
raise EntityAlreadyExists(f"{host} already has that SSHFP defined.")

arg_data: dict[str, str | None] = {k: v for k, v in data.items()}
arg_data = {k: v for k, v in data.items()}
SSHFP.create(arg_data)
OutputManager().add_ok(f"Added SSHFP record for {host.name.hostname}.")

Expand Down
4 changes: 2 additions & 2 deletions mreg_cli/commands/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
NetworkOverlap,
)
from mreg_cli.outputmanager import OutputManager
from mreg_cli.types import Flag
from mreg_cli.types import Flag, QueryParams
from mreg_cli.utilities.shared import convert_wildcard_to_regex, string_to_int
from mreg_cli.utilities.validators import is_valid_category_tag, is_valid_location_tag

Expand Down Expand Up @@ -191,7 +191,7 @@ def find(args: argparse.Namespace) -> None:
addr = IPAddressField(address=ip_arg)
networks = [Network.get_by_ip_or_raise(addr.address)]
else:
params: dict[str, str] = {}
params: QueryParams = {}
param_names = [
"network",
"description",
Expand Down
4 changes: 2 additions & 2 deletions mreg_cli/commands/permission.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from mreg_cli.commands.registry import CommandRegistry
from mreg_cli.exceptions import DeleteError, EntityNotFound
from mreg_cli.outputmanager import OutputManager
from mreg_cli.types import Flag
from mreg_cli.types import Flag, QueryParams
from mreg_cli.utilities.shared import convert_wildcard_to_regex

command_registry = CommandRegistry()
Expand Down Expand Up @@ -46,7 +46,7 @@ def network_list(args: argparse.Namespace) -> None:
"""
permission_list: list[Permission] = []

params: dict[str, str] = {}
params: QueryParams = {}
if args.group is not None:
param, value = convert_wildcard_to_regex("group", args.group)
params[param] = value
Expand Down
2 changes: 2 additions & 0 deletions mreg_cli/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
Any,
Literal,
Mapping,
MutableMapping,
NamedTuple,
Sequence,
TypeAlias,
Expand Down Expand Up @@ -90,6 +91,7 @@ def json_custom_error_validator(
],
)
JsonMapping = Mapping[str, Json]
QueryParams = MutableMapping[str, str | int | None]
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, MutableMapping. Neat!



class Flag:
Expand Down
Loading
Loading