Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Permission commands migrated. #228

Merged
merged 6 commits into from
May 21, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 59 additions & 6 deletions mreg_cli/api/abstracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -303,47 +303,100 @@ def get_by_field_and_raise(

@classmethod
def get_list_by_field(
cls, field: str, value: str | int, ordering: str | None = None
cls, field: str, value: str | int, ordering: str | None = None, limit: int = 500
) -> list[Self]:
"""Get a list of objects by a field.

:param field: The field to search by.
:param value: The value to search for.
:param ordering: The ordering to use when fetching the list.
:param limit: The maximum number of hits to allow (default 500)

:returns: A list of objects if found, an empty list otherwise.
"""
params = {field: value}
if ordering:
params["ordering"] = ordering

data = get_list(cls.endpoint(), params=params)
data = get_list(cls.endpoint(), params=params, max_hits_to_allow=limit)
return [cls(**item) for item in data]

@classmethod
def get_by_query(cls, query: dict[str, str], ordering: str | None = None) -> list[Self]:
def get_by_query(
cls, query: dict[str, str], ordering: str | None = None, limit: int = 500
) -> list[Self]:
"""Get a list of objects by a query.

:param query: The query to search by.
:param ordering: The ordering to use when fetching the list.
:param limit: The maximum number of hits to allow (default 500)

:returns: A list of objects if found, an empty list otherwise.
"""
if ordering:
query["ordering"] = ordering

data = get_list(cls.endpoint().with_query(query))
data = get_list(cls.endpoint().with_query(query), max_hits_to_allow=limit)
return [cls(**item) for item in data]

@classmethod
def get_by_query_unique(cls, data: dict[str, str]) -> Self:
def get_by_query_unique_or_raise(
cls,
query: dict[str, str],
exc_type: type[Exception] = EntityNotFound,
exc_message: str | None = None,
) -> Self:
"""Get an object by a query and raise if not found.

Used for cases where the object must exist for the operation to continue.

:param query: The query to search by.
:param exc_type: The exception type to raise.
:param exc_message: The exception message. Overrides the default message.

:returns: The object if found.
"""
obj = cls.get_by_query_unique(query)
if not obj:
if not exc_message:
exc_message = f"{cls.__name__} with query {query} not found."
raise exc_type(exc_message)
return obj

@classmethod
def get_by_query_unique_and_raise(
cls,
query: dict[str, str],
exc_type: type[Exception] = EntityAlreadyExists,
exc_message: str | None = None,
) -> None:
"""Get an object by a query and raise if found.

Used for cases where the object must NOT exist for the operation to continue.

:param query: The query to search by.
:param exc_type: The exception type to raise.
:param exc_message: The exception message. Overrides the default message.

:raises Exception: If the object is found.
"""
obj = cls.get_by_query_unique(query)
if obj:
if not exc_message:
exc_message = f"{cls.__name__} with query {query} already exists."
raise exc_type(exc_message)
return None

@classmethod
def get_by_query_unique(cls, data: dict[str, str]) -> Self | None:
"""Get an object with the given data.

:param data: The data to search for.
:returns: The object if found, None otherwise.
"""
obj_dict = get_list_unique(cls.endpoint(), params=data)
if not obj_dict:
raise EntityNotFound(f"{cls.__name__} record for {data} not found.")
return None
return cls(**obj_dict)

def refetch(self) -> Self:
Expand Down
39 changes: 38 additions & 1 deletion mreg_cli/api/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,10 +342,19 @@ class Permission(FrozenModelWithTimestamps, APIMixin):

id: int # noqa: A003
group: str
range: str # noqa: A003
range: IP_NetworkT # noqa: A003
regex: str
labels: list[int]

@field_validator("range", mode="before")
@classmethod
def validate_ip_or_network(cls, value: str) -> IP_NetworkT:
"""Validate and convert the input to a network."""
try:
return ipaddress.ip_network(value)
except ValueError as e:
raise InputFailure(f"Invalid input for network: {value}") from e

@classmethod
def endpoint(cls) -> Endpoint:
"""Return the endpoint for the class."""
Expand All @@ -364,6 +373,34 @@ def output_multiple(cls, permissions: list[Permission], indent: int = 4) -> None
indent=indent,
)

def add_label(self, label_name: str) -> Self:
"""Add a label to the permission.

:param label_name: The name of the label to add.
:returns: The updated Permission object.
"""
label = Label.get_by_name_or_raise(label_name)
if label.id in self.labels:
raise EntityAlreadyExists(f"The permission already has the label {label_name!r}")

label_ids = self.labels.copy()
label_ids.append(label.id)
return self.patch({"labels": label_ids})

def remove_label(self, label_name: str) -> Self:
"""Remove a label from the permission.

:param label_name: The name of the label to remove.
:returns: The updated Permission object.
"""
label = Label.get_by_name_or_raise(label_name)
if label.id not in self.labels:
raise EntityNotFound(f"The permission does not have the label {label_name!r}")

label_ids = self.labels.copy()
label_ids.remove(label.id)
return self.patch({"labels": label_ids})


class Zone(FrozenModelWithTimestamps, WithTTL):
"""Model representing a DNS zone with various attributes and related nameservers."""
Expand Down
154 changes: 53 additions & 101 deletions mreg_cli/commands/permission.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,16 @@
from __future__ import annotations

import argparse
import ipaddress
from typing import Any

from mreg_cli.api.models import Label, NetworkOrIP, Permission
from mreg_cli.commands.base import BaseCommand
from mreg_cli.commands.registry import CommandRegistry
from mreg_cli.log import cli_info, cli_warning
from mreg_cli.exceptions import DeleteError, EntityNotFound
from mreg_cli.log import cli_info
from mreg_cli.outputmanager import OutputManager
from mreg_cli.types import Flag
from mreg_cli.utilities.api import delete, get, get_list, patch, post
from mreg_cli.utilities.network import network_is_supernet
from mreg_cli.utilities.shared import convert_wildcard_to_regex
from mreg_cli.utilities.validators import is_valid_network

command_registry = CommandRegistry()

Expand Down Expand Up @@ -47,45 +45,52 @@ def network_list(args: argparse.Namespace) -> None:

:param args: argparse.Namespace (group, range)
"""
params = {
"ordering": "range,group",
}
permission_list: list[Permission] = []

params: dict[str, str] = {}
if args.group is not None:
param, value = convert_wildcard_to_regex("group", args.group)
params[param] = value
permissions = get_list("/api/v1/permissions/netgroupregex/", params=params)

data = []
# Well, this is effin' awful. We have to fetch all permissions, but the API wants to limit
# the number of results. We should probably fix this in the API.
permissions = Permission.get_by_query(query=params, ordering="range,group", limit=10000)

if args.range is not None:
argnetwork = ipaddress.ip_network(args.range)
for i in permissions:
permnet = ipaddress.ip_network(i["range"])
argnetwork = NetworkOrIP(ip_or_network=args.range).as_network()

for permission in permissions:
permnet = permission.range
if permnet.version != argnetwork.version:
continue # no warning if the networks are not comparable
if network_is_supernet(argnetwork, permnet): # type: ignore # guaranteed to be the same version
data.append(i)
if argnetwork.supernet_of(permnet): # type: ignore # guaranteed to be the same version
permission_list.append(permission)
else:
data = permissions
permission_list = permissions

if not permission_list:
raise EntityNotFound("No permissions found")

output: list[dict[str, str]] = []
labelnames: dict[int, str] = {}

if not data:
cli_info("No permissions found", True)
return
for label in Label.get_all():
labelnames[label.id] = label.name

# Add label names to the result
labelnames = {}
info = get_list("/api/v1/labels/")
if info:
for i in info:
labelnames[i["id"]] = i["name"]
for row in data:
labels = []
for j in row["labels"]:
labels.append(labelnames[j])
row["labels"] = ", ".join(labels)
for permission in permission_list:
perm_data: dict[str, str] = {}
row_labels: list[str] = [labelnames[label] for label in permission.labels]
perm_data["labels"] = ", ".join(row_labels)
perm_data["range"] = str(permission.range)
perm_data["group"] = permission.group
perm_data["regex"] = permission.regex
output.append(perm_data)

headers = ("Range", "Group", "Regex", "Labels")
keys = ("range", "group", "regex", "labels")
OutputManager().add_formatted_table(headers, keys, data)
OutputManager().add_formatted_table(
("Range", "Group", "Regex", "Labels"),
("range", "group", "regex", "labels"),
output,
)


@command_registry.register_command(
Expand All @@ -103,16 +108,16 @@ def network_add(args: argparse.Namespace) -> None:

:param args: argparse.Namespace (range, group, regex)
"""
if not is_valid_network(args.range):
cli_warning(f"Invalid range: {args.range}")
NetworkOrIP(ip_or_network=args.range).as_network()

data = {
query = {
"range": args.range,
"group": args.group,
"regex": args.regex,
}
path = "/api/v1/permissions/netgroupregex/"
post(path, **data)

Permission.get_by_query_unique_and_raise(query)
Permission.create(params=query)
cli_info(f"Added permission to {args.range}", True)


Expand All @@ -131,22 +136,17 @@ def network_remove(args: argparse.Namespace) -> None:

:param args: argparse.Namespace (range, group, regex)
"""
params = {
query = {
"group": args.group,
"range": args.range,
"regex": args.regex,
}
permissions = get_list("/api/v1/permissions/netgroupregex/", params=params)

if not permissions:
cli_warning("No matching permission found", True)
return

assert len(permissions) == 1, "Should only match one permission"
identifier = permissions[0]["id"]
path = f"/api/v1/permissions/netgroupregex/{identifier}"
delete(path)
cli_info(f"Removed permission for {args.range}", True)
permission = Permission.get_by_query_unique_or_raise(query)
if permission.delete():
cli_info(f"Removed permission for {args.range}", True)
else:
raise DeleteError(f"Failed to remove permission for {args.range}")


@command_registry.register_command(
Expand All @@ -171,32 +171,8 @@ def add_label_to_permission(args: argparse.Namespace) -> None:
"range": args.range,
"regex": args.regex,
}
permissions = get_list("/api/v1/permissions/netgroupregex/", params=query)

if not permissions:
cli_warning("No matching permission found", True)
return

assert len(permissions) == 1, "Should only match one permission"
identifier = permissions[0]["id"]
path = f"/api/v1/permissions/netgroupregex/{identifier}"

# find the label
labelpath = f"/api/v1/labels/name/{args.label}"
res = get(labelpath, ok404=True)
if not res:
cli_warning(f"Could not find a label with name {args.label!r}")
label = res.json()

# check if the permission object already has the label
perm = get(path).json()
if label["id"] in perm["labels"]:
cli_warning(f"The permission already has the label {args.label!r}")

# patch the permission
ar = perm["labels"]
ar.append(label["id"])
patch(path, labels=ar)
permission = Permission.get_by_query_unique_or_raise(query)
permission.add_label(args.label)
cli_info(f"Added the label {args.label!r} to the permission.", print_msg=True)


Expand All @@ -222,30 +198,6 @@ def remove_label_from_permission(args: argparse.Namespace) -> None:
"range": args.range,
"regex": args.regex,
}
permissions = get_list("/api/v1/permissions/netgroupregex/", params=query)

if not permissions:
cli_warning("No matching permission found", True)
return

assert len(permissions) == 1, "Should only match one permission"
identifier = permissions[0]["id"]
path = f"/api/v1/permissions/netgroupregex/{identifier}"

# find the label
labelpath = f"/api/v1/labels/name/{args.label}"
res = get(labelpath, ok404=True)
if not res:
cli_warning(f"Could not find a label with name {args.label!r}")
label = res.json()

# check if the permission object has the label
perm = get(path).json()
if label["id"] not in perm["labels"]:
cli_warning(f"The permission doesn't have the label {args.label!r}")

# patch the permission
ar = perm["labels"]
ar.remove(label["id"])
patch(path, params={"labels": ar}, use_json=True)
permission = Permission.get_by_query_unique_or_raise(query)
permission.remove_label(args.label)
terjekv marked this conversation as resolved.
Show resolved Hide resolved
cli_info(f"Removed the label {args.label!r} from the permission.", print_msg=True)
2 changes: 1 addition & 1 deletion mreg_cli/outputmanager.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,7 +428,7 @@ def add_formatted_table(
# Add headers and rows to the output
self.add_line(raw_format.format(*headers))
for d in output_data:
self.add_line(raw_format.format(*[d[key] for key in keys]))
self.add_line(raw_format.format(*[str(d[key]) for key in keys]))

def filtered_output(self) -> list[str]:
"""Return the lines of output.
Expand Down
Loading