diff --git a/mreg_cli/api/__init__.py b/mreg_cli/api/__init__.py index 985bfdd3..efb5df77 100644 --- a/mreg_cli/api/__init__.py +++ b/mreg_cli/api/__init__.py @@ -6,17 +6,21 @@ guarantees about the data it is working with. """ +import re from ipaddress import ip_address from typing import Dict, Union from mreg_cli.api.endpoints import Endpoint -from mreg_cli.api.models import Host, HostList, MACAddressField +from mreg_cli.api.models import Host, HostList, IPAddressT, MACAddressField +from mreg_cli.config import MregCliConfig from mreg_cli.log import cli_warning +from mreg_cli.outputmanager import OutputManager from mreg_cli.utilities.api import get, get_item_by_key_value, get_list, post -from mreg_cli.utilities.host import clean_hostname -def get_host(identifier: str, ok404: bool = False) -> Union[None, Host]: +def get_host( + identifier: str, ok404: bool = False, inform_as_cname: bool = False +) -> Union[None, Host]: """Get a host by the given identifier. - If the identifier is numeric, it will be treated as an ID. @@ -28,7 +32,7 @@ def get_host(identifier: str, ok404: bool = False) -> Union[None, Host]: To check if a returned host is a cname, one can do the following: ```python - hostname = "example.com" + hostname = "host.example.com" host = get_host(hostname, ok404=True) if host is None: print("Host not found.") @@ -38,7 +42,12 @@ def get_host(identifier: str, ok404: bool = False) -> Union[None, Host]: print(f"{host.name} is a host.") ``` + Note that get_host will perform a case-insensitive search for a fully qualified version + of the hostname, so the comparison above may fail. + :param identifier: The identifier to search for. + :param ok404: If True, don't raise a CliWarning if the host is not found. + :param inform_as_cname: If True, inform the user if the host is a CNAME. :raises CliWarning: If we don't find the host and `ok404` is False. @@ -62,7 +71,7 @@ def get_host(identifier: str, ok404: bool = False) -> Union[None, Host]: hostname = clean_hostname(identifier) - data = get(Endpoint.Hosts.with_id(hostname), ok404=ok404) + data = get(Endpoint.Hosts.with_id(hostname), ok404=True) if data: data = data.json() @@ -73,7 +82,13 @@ def get_host(identifier: str, ok404: bool = False) -> Union[None, Host]: if data is not None: data = get_item_by_key_value(Endpoint.Hosts, "id", data["host"], ok404=ok404) + if data and inform_as_cname: + OutputManager().add_line(f"{hostname} is a CNAME for {data['name']}") + if data is None: + if not ok404: + cli_warning(f"Host {identifier} not found.") + return None return Host(**data) @@ -102,3 +117,46 @@ def add_host(data: Dict[str, Union[str, None]]) -> bool: return True return False + + +def get_network_by_ip(ip: IPAddressT) -> Union[None, Dict[str, Union[str, int]]]: + """Return a network associated with given IP.""" + return get(Endpoint.NetworksByIP.with_id(str(ip))).json() + + +def clean_hostname(name: Union[str, bytes]) -> str: + """Ensure hostname is fully qualified, lowercase, and has valid characters. + + :param name: The hostname to clean. + + :raises CliWarning: If the hostname is invalid. + + :returns: The cleaned hostname. + """ + # bytes? + if not isinstance(name, (str, bytes)): + cli_warning("Invalid input for hostname: {}".format(name)) + + if isinstance(name, bytes): + name = name.decode() + + name = name.lower() + + # invalid characters? + if re.search(r"^(\*\.)?([a-z0-9_][a-z0-9\-]*\.?)+$", name) is None: + cli_warning("Invalid input for hostname: {}".format(name)) + + # Assume user is happy with domain, but strip the dot. + if name.endswith("."): + return name[:-1] + + # If a dot in name, assume long name. + if "." in name: + return name + + config = MregCliConfig() + default_domain = config.get("domain") + # Append domain name if in config and it does not end with it + if default_domain and not name.endswith(default_domain): + return "{}.{}".format(name, default_domain) + return name diff --git a/mreg_cli/api/models.py b/mreg_cli/api/models.py index b1020d93..a9247fb5 100644 --- a/mreg_cli/api/models.py +++ b/mreg_cli/api/models.py @@ -3,6 +3,7 @@ import ipaddress import re import sys +from datetime import datetime from typing import Any, Dict, List, Optional, Union from pydantic import BaseModel, root_validator, validator @@ -11,7 +12,7 @@ from mreg_cli.api.endpoints import Endpoint from mreg_cli.log import cli_warning from mreg_cli.outputmanager import OutputManager -from mreg_cli.utilities.api import delete, get, get_list, patch +from mreg_cli.utilities.api import delete, get, get_item_by_key_value, get_list, get_list_in, patch IPAddressT = Union[ipaddress.IPv4Address, ipaddress.IPv6Address] @@ -46,7 +47,68 @@ class Config: frozen = True -class Network(FrozenModel): +class FrozenModelWithTimestamps(FrozenModel): + """Model with created_at and updated_at fields.""" + + created_at: datetime + updated_at: datetime + + def output_timestamps(self, padding: int = 14) -> None: + """Output the created and updated timestamps to the console.""" + output_manager = OutputManager() + output_manager.add_line(f"{'Created:':<{padding}}{self.created_at:%c}") + output_manager.add_line(f"{'Updated:':<{padding}}{self.updated_at:%c}") + + +class WithHost(BaseModel): + """Model for an object that has a host element.""" + + host: int + + def resolve_host(self) -> Union["Host", None]: + """Resolve the host ID to a Host object. + + Notes + ----- + - This method will call the API to resolve the host ID to a Host object. + - This assumes that there is a host attribute in the object. + + """ + data = get_item_by_key_value(Endpoint.Hosts, "id", str(self.host)) + + if not data: + return None + + return Host(**data) + + +class NameServer(FrozenModelWithTimestamps): + """Model for representing a nameserver within a DNS zone.""" + + id: int # noqa: A003 + name: str + ttl: Optional[int] = None + + +class Zone(FrozenModelWithTimestamps): + """Model representing a DNS zone with various attributes and related nameservers.""" + + id: int # noqa: A003 + nameservers: List[NameServer] + updated: bool + primary_ns: str + email: str + serialno: int + serialno_updated_at: datetime + refresh: int + retry: int + expire: int + soa_ttl: int + default_ttl: int + name: str + + +class Network(FrozenModelWithTimestamps): """Model for a network.""" id: int # noqa: A003 @@ -121,13 +183,12 @@ def __hash__(self): return hash(self.address) -class IPAddress(FrozenModel): +class IPAddress(FrozenModelWithTimestamps, WithHost): """Represents an IP address with associated details.""" id: int # noqa: A003 macaddress: Optional[MACAddressField] = None ipaddress: IPAddressField - host: int @validator("macaddress", pre=True, allow_reuse=True) def create_valid_macadress_or_none(cls, v: str): @@ -188,40 +249,146 @@ def associate_mac(self, mac: Union[MACAddressField, str], force: bool = False) - patch(Endpoint.Ipaddresses.with_id(self.id), macaddress=mac.address) return self.model_copy(update={"macaddress": mac}) + def output(self, len_ip: int, len_names: int, names: bool = False): + """Output the IP address to the console.""" + ip = self.ipaddress.__str__() + mac = self.macaddress if self.macaddress else "" + name = str(self.host) if names else "" + OutputManager().add_line(f"{name:<{len_names}}{ip:<{len_ip}}{mac}") + + @classmethod + def output_multiple(cls, ips: List["IPAddress"], padding: int = 14, names: bool = False): + """Output IP addresses to the console.""" + output_manager = OutputManager() + len_ip = max(padding, max([len(str(ip.ipaddress)) for ip in ips], default=0) + 2) + + # This seems completely broken, we need to look up all the hosts and get their names. + # This again requires a fetch_hosts() call that takes a series of identifiers using + # id__in. + len_names = ( + padding + if not names + else max(padding, max([len(str(ip.host)) for ip in ips], default=0) + 2) + ) + + # Separate and output A and AAAA records + for record_type, records in ( + ("A_Records", [ip for ip in ips if ip.is_ipv4()]), + ("AAAA_Records", [ip for ip in ips if ip.is_ipv6()]), + ): + if records: + output_manager.add_line(f"{record_type:<{len_names}}IP{' ' * (len_ip - 2)}MAC") + for record in records: + record.output(len_ip=len_ip, len_names=len_names, names=names) + def __hash__(self): """Return a hash of the IP address.""" return hash((self.id, self.ipaddress.address, self.macaddress)) -class CNAME(FrozenModel): +class HInfo(FrozenModelWithTimestamps, WithHost): + """Represents a HINFO record.""" + + cpu: str + os: str + + def output(self, padding: int = 14): + """Output the HINFO record to the console.""" + OutputManager().add_line( + "{1:<{0}}cpu={2} os={3}".format(padding, "Hinfo:", self.cpu, self.os) + ) + + +class CNAME(FrozenModelWithTimestamps, WithHost): """Represents a CNAME record.""" name: HostT ttl: Optional[int] = None zone: int - host: int + def output(self, padding: int = 14) -> None: + """Output the CNAME record to the console. + + :param padding: Number of spaces for left-padding the output. + """ + actual_host = self.resolve_host() + host = actual_host.name if actual_host else "" + + OutputManager().add_line(f"{'Cname:':<{padding}}{self.name} -> {host}") + + @classmethod + def output_multiple(cls, cnames: List["CNAME"], padding: int = 14) -> None: + """Output multiple CNAME records to the console. + + :param cnames: List of CNAME records to output. + :param padding: Number of spaces for left-padding the output. + """ + if not cnames: + return + + for cname in cnames: + cname.output(padding=padding) -class TXT(FrozenModel): + +class TXT(FrozenModelWithTimestamps): """Represents a TXT record.""" txt: str host: int + def output(self, padding: int = 14) -> None: + """Output the TXT record to the console. + + :param padding: Number of spaces for left-padding the output. + """ + OutputManager().add_line(f"{'TXT:':<{padding}}{self.txt}") + + @classmethod + def output_multiple(cls, txts: List["TXT"], padding: int = 14) -> None: + """Output multiple TXT records to the console. + + :param txts: List of TXT records to output. + :param padding: Number of spaces for left-padding the output. + """ + if not txts: + return + + for txt in txts: + txt.output(padding=padding) -class MX(FrozenModel): + +class MX(FrozenModelWithTimestamps, WithHost): """Represents a MX record.""" mx: str priority: int host: int + def output(self, padding: int = 14) -> None: + """Output the MX record to the console. + + :param padding: Number of spaces for left-padding the output. + """ + len_pri = len("Priority") + OutputManager().add_line( + "{1:<{0}}{2:>{3}} {4}".format(padding, "", self.priority, len_pri, self.mx) + ) + + @classmethod + def output_multiple(cls, mxs: List["MX"], padding: int = 14) -> None: + """Output MX records to the console.""" + if not mxs: + return + + OutputManager().add_line("{1:<{0}}{2} {3}".format(padding, "MX:", "Priority", "Server")) + for mx in sorted(mxs, key=lambda i: i.priority): + mx.output(padding=padding) -class NAPTR(FrozenModel): + +class NAPTR(FrozenModelWithTimestamps, WithHost): """Represents a NAPTR record.""" id: int # noqa: A003 - host: int preference: int order: int flag: Optional[str] @@ -229,8 +396,50 @@ class NAPTR(FrozenModel): regex: Optional[str] replacement: str + def output(self, padding: int = 14) -> None: + """Output the NAPTR record to the console. -class Srv(FrozenModel): + :param padding: Number of spaces for left-padding the output. + """ + row_format = f"{{:<{padding}}}" * len(NAPTR.headers()) + OutputManager().add_line( + row_format.format( + "", + self.preference, + self.order, + self.flag, + self.service, + self.regex or '""', + self.replacement, + ) + ) + + @classmethod + def headers(cls) -> List[str]: + """Return the headers for the NAPTR record.""" + return [ + "NAPTRs:", + "Preference", + "Order", + "Flag", + "Service", + "Regex", + "Replacement", + ] + + @classmethod + def output_multiple(cls, naptrs: List["NAPTR"], padding: int = 14) -> None: + """Output multiple NAPTR records to the console.""" + headers = cls.headers() + row_format = f"{{:<{padding}}}" * len(headers) + manager = OutputManager() + if naptrs: + manager.add_line(row_format.format(*headers)) + for naptr in naptrs: + naptr.output(padding=padding) + + +class Srv(FrozenModelWithTimestamps, WithHost): """Represents a SRV record.""" id: int # noqa: A003 @@ -240,18 +449,102 @@ class Srv(FrozenModel): port: int ttl: Optional[int] zone: int - host: int + def output(self, padding: int = 14, host_id_name_map: Optional[Dict[int, str]] = None) -> None: + """Output the SRV record to the console. + + The output will include the record name, priority, weight, port, + and the associated host name. Optionally uses a mapping of host IDs + to host names to avoid repeated lookups. + + :param padding: Number of spaces for left-padding the output. + :param host_names: Optional dictionary mapping host IDs to host names. + """ + host_name = "" + if host_id_name_map and self.host in host_id_name_map: + host_name = host_id_name_map[self.host] + elif not host_id_name_map or self.host not in host_id_name_map: + host = self.resolve_host() + if host: + host_name = host.name + + # Format the output string to include padding and center alignment + # for priority, weight, and port. + output_manager = OutputManager() + format_str = "SRV: {:<{padding}} {:^6} {:^6} {:^6} {}" + output_manager.add_line( + format_str.format( + self.name, + str(self.priority), + str(self.weight), + str(self.port), + host_name, + padding=padding, + ) + ) + + @classmethod + def output_multiple(cls, srvs: List["Srv"], padding: int = 14) -> None: + """Output multiple SRV records. + + This method adjusts the padding dynamically based on the longest record name. + + :param srvs: List of Srv records to output. + :param padding: Minimum number of spaces for left-padding the output. + """ + if not srvs: + return + + host_ids = {srv.host for srv in srvs} + + host_data = get_list_in(Endpoint.Hosts, "id", list(host_ids)) + hosts = [Host(**host) for host in host_data] + + host_id_name_map = {host.id: host.name for host in hosts} + + host_id_name_map.update( + {host_id: host_id_name_map.get(host_id, "") for host_id in host_ids} + ) + + padding = max((len(srv.name) for srv in srvs), default=padding) + + # Output each SRV record with the optimized host name lookup + for srv in srvs: + srv.output(padding=padding, host_id_name_map=host_id_name_map) -class PTR_override(FrozenModel): + +class PTR_override(FrozenModelWithTimestamps, WithHost): """Represents a PTR override record.""" id: int # noqa: A003 host: int ipaddress: str # For now, should be an IP address + def output(self, padding: int = 14): + """Output the PTR override record to the console. + + :param padding: Number of spaces for left-padding the output. + """ + host = self.resolve_host() + hostname = host.name if host else "" + + OutputManager().add_line(f"{'PTR override:':<{padding}}{self.ipaddress} -> {hostname}") + + @classmethod + def output_multiple(cls, ptrs: List["PTR_override"], padding: int = 14): + """Output multiple PTR override records to the console. + + :param ptrs: List of PTR override records to output. + :param padding: Number of spaces for left-padding the output. + """ + if not ptrs: + return + + for ptr in ptrs: + ptr.output(padding=padding) -class Host(FrozenModel): + +class Host(FrozenModelWithTimestamps): """Model for an individual host. This is the endpoint at /api/v1/hosts/. @@ -264,7 +557,7 @@ class Host(FrozenModel): mxs: List[MX] = [] txts: List[TXT] = [] ptr_overrides: List[PTR_override] = [] - hinfo: Optional[str] = None + hinfo: Optional[HInfo] = None loc: Optional[str] = None bacnetid: Optional[str] = None contact: str @@ -389,7 +682,7 @@ def vlans(self) -> Dict[int, List[IPAddress]]: return ret_dict - # This wouold be greatly improved by having a proper error returned to avoid the need for + # This would be greatly improved by having a proper error returned to avoid the need for # manually calling networks() or vlans() to determine the issue. One option is to use # a custom exception, or to return a tuple of (bool, str) where the str is the error message. def all_ips_on_same_vlan(self) -> bool: @@ -424,40 +717,58 @@ def naptrs(self) -> List[NAPTR]: def srvs(self) -> List[Srv]: """Return a list of SRV records.""" - # We should access by ID, but the current tests use host__name, so to reduce - # the number of changes, we'll use name for now. - # srvs = get_list(Endpoint.Srvs, params={"host": self.id}) - srvs = get_list(Endpoint.Srvs, params={"host__name": self.name}) + srvs = get_list(Endpoint.Srvs, params={"host": self.id}) return [Srv(**srv) for srv in srvs] def output_host_info(self, names: bool = False): """Output host information to the console with padding.""" + padding = 14 + output_manager = OutputManager() - output_manager.add_line(f"Name: {self.name}") - output_manager.add_line(f"Contact: {self.contact}") + output_manager.add_line(f"{'Name:':<{padding}}{self.name}") + output_manager.add_line(f"{'Contact:':<{padding}}{self.contact}") - # Calculate padding - len_ip = max( - 14, max([len(ip.ipaddress.__str__()) for ip in self.ipaddresses], default=0) + 1 - ) - len_names = ( - 14 - if not names - else max(14, max([len(str(ip.host)) for ip in self.ipaddresses], default=0) + 1) - ) + if self.comment: + output_manager.add_line(f"{'Comment:':<{padding}}{self.comment}") - # Separate and output A and AAAA records - for record_type, records in ( - ("A_Records", self.ipv4_addresses()), - ("AAAA_Records", self.ipv6_addresses()), - ): - if records: - output_manager.add_line(f"{record_type:<{len_names}}IP{' ' * (len_ip - 2)}MAC") - for record in records: - ip = record.ipaddress.__str__() - mac = record.macaddress if record.macaddress else "" - name = str(record.host) if names else "" - output_manager.add_line(f"{name:<{len_names}}{ip:<{len_ip}}{mac}") + IPAddress.output_multiple(self.ipaddresses, padding=padding, names=names) + PTR_override.output_multiple(self.ptr_overrides, padding=padding) + + output_manager.add_line("{1:<{0}}{2}".format(padding, "TTL:", self.ttl or "(Default)")) + + MX.output_multiple(self.mxs, padding=padding) + + if self.hinfo: + self.hinfo.output(padding=padding) + + if self.loc: + output_manager.add_line(f"{'Loc:':<{padding}}{self.loc}") + + CNAME.output_multiple(self.cnames, padding=padding) + TXT.output_multiple(self.txts, padding=padding) + Srv.output_multiple(self.srvs(), padding=padding) + NAPTR.output_multiple(self.naptrs(), padding=padding) + + # output_hinfo(info["hinfo"]) + + # if info["loc"]: + # output_loc(info["loc"]) + # for cname in info["cnames"]: + # output_cname(cname["name"], info["name"]) + # for txt in info["txts"]: + # output_txt(txt["txt"]) + # output_srv(host_id=info["id"]) + # output_naptr(info) + # output_sshfp(info) + # if "bacnetid" in info: + # output_bacnetid(info.get("bacnetid")) + + # policies = get_list("/api/v1/hostpolicy/roles/", params={"hosts__name": info["name"]}) + # output_policies([p["name"] for p in policies]) + + # cli_info("printed host info for {}".format(info["name"])) + + self.output_timestamps() def __hash__(self): """Return a hash of the host.""" diff --git a/mreg_cli/commands/host_submodules/core.py b/mreg_cli/commands/host_submodules/core.py index d19e2fad..e598cd62 100644 --- a/mreg_cli/commands/host_submodules/core.py +++ b/mreg_cli/commands/host_submodules/core.py @@ -22,14 +22,13 @@ from mreg_cli.utilities.api import get, get_list, patch from mreg_cli.utilities.history import format_history_items, get_history_items from mreg_cli.utilities.host import ( - clean_hostname, cname_exists, get_host_by_name, get_requested_ip, host_info_by_name, ) from mreg_cli.utilities.output import output_host_info, output_ip_info -from mreg_cli.utilities.shared import convert_wildcard_to_regex, format_mac +from mreg_cli.utilities.shared import clean_hostname, convert_wildcard_to_regex, format_mac from mreg_cli.utilities.validators import is_valid_email, is_valid_ip, is_valid_mac from mreg_cli.utilities.zone import zone_check_for_hostname @@ -77,7 +76,7 @@ def add(args: argparse.Namespace) -> None: import mreg_cli.api as api ip = None - name = clean_hostname(args.name) + name = args.name host = api.get_host(name, ok404=True) if host: @@ -160,7 +159,7 @@ def remove(args: argparse.Namespace) -> None: """ import mreg_cli.api as api - hostname = clean_hostname(args.name) + hostname = args.name host = api.get_host(hostname) if host is None: @@ -293,7 +292,7 @@ def host_info_pydantic(args: argparse.Namespace) -> None: """Print information about host.""" import mreg_cli.api as api - host = api.get_host(args.hosts[0]) + host = api.get_host(args.hosts[0], inform_as_cname=True) if host is None: cli_warning(f"Host {args.hosts[0]} not found.") @@ -333,6 +332,8 @@ def host_info(args: argparse.Namespace) -> None: else: cli_warning(f"Found no host with macaddress: {mac}") else: + from mreg_cli.api import clean_hostname + info = host_info_by_name(name_or_ip) name = clean_hostname(name_or_ip) if any(cname["name"] == name for cname in info["cnames"]): diff --git a/mreg_cli/commands/host_submodules/rr.py b/mreg_cli/commands/host_submodules/rr.py index 6fe7a23c..70ca8924 100644 --- a/mreg_cli/commands/host_submodules/rr.py +++ b/mreg_cli/commands/host_submodules/rr.py @@ -45,7 +45,7 @@ from mreg_cli.outputmanager import OutputManager from mreg_cli.types import Flag from mreg_cli.utilities.api import delete, get_list, patch, post -from mreg_cli.utilities.host import clean_hostname, get_info_by_name, host_info_by_name +from mreg_cli.utilities.host import get_info_by_name, host_info_by_name from mreg_cli.utilities.network import get_network_by_ip, get_network_reserved_ips, ip_in_mreg_net from mreg_cli.utilities.output import ( output_hinfo, @@ -57,6 +57,7 @@ output_ttl, output_txt, ) +from mreg_cli.utilities.shared import clean_hostname from mreg_cli.utilities.validators import is_valid_ip, is_valid_ttl from mreg_cli.utilities.zone import zone_check_for_hostname @@ -585,8 +586,13 @@ def ptr_add(args: argparse.Namespace) -> None: if info["zone"] is None and not args.force: cli_warning("{} isn't in a zone controlled by MREG, must force".format(info["name"])) - network = get_network_by_ip(args.ip) - reserved_addresses = get_network_reserved_ips(network["network"]) + import ipaddress + + network = get_network_by_ip(ipaddress.ip_address(args.ip)) + if network is None: + cli_warning("No network found for {}".format(args.ip)) + + reserved_addresses = get_network_reserved_ips(str(network["network"])) if args.ip in reserved_addresses and not args.force: cli_warning("Address is reserved. Requires force") diff --git a/mreg_cli/utilities/api.py b/mreg_cli/utilities/api.py index e5b804df..50c85add 100644 --- a/mreg_cli/utilities/api.py +++ b/mreg_cli/utilities/api.py @@ -290,6 +290,28 @@ def get_list( return ret +def get_list_in( + path: str, + search_field: str, + search_values: List[int], + ok404: bool = False, +) -> List[Dict[str, Any]]: + """Get a list of items by a key value pair. + + :param path: The path to the API endpoint. + :param search_field: The field to search for. + :param search_values: The values to search for. + :param ok404: Whether to allow 404 responses. + + :returns: A list of dictionaries. + """ + return get_list( + path, + params={f"{search_field}__in": ",".join(str(x) for x in search_values)}, + ok404=ok404, + ) + + def get_item_by_key_value( path: str, search_field: str, @@ -370,10 +392,8 @@ def _check_expect_one_result( return {} if len(ret) != 1: raise CliError(f"Expected exactly one result, got {len(ret)}.") - if "results" not in ret[0]: - raise CliError("Expected 'results' in response, got none.") - return ret[0]["results"] + return ret[0] return ret diff --git a/mreg_cli/utilities/host.py b/mreg_cli/utilities/host.py index 84323c2e..947e5515 100644 --- a/mreg_cli/utilities/host.py +++ b/mreg_cli/utilities/host.py @@ -2,11 +2,9 @@ import argparse import ipaddress -import re import urllib.parse from typing import Any, Dict, Optional, Tuple, Union -from mreg_cli.config import MregCliConfig from mreg_cli.exceptions import CliWarning, HostNotFoundWarning from mreg_cli.log import cli_error, cli_info, cli_warning from mreg_cli.types import IP_Version @@ -18,7 +16,7 @@ get_network_reserved_ips, ips_are_in_same_vlan, ) -from mreg_cli.utilities.shared import format_mac +from mreg_cli.utilities.shared import clean_hostname, format_mac from mreg_cli.utilities.validators import ( is_valid_ip, is_valid_ipv4, @@ -28,37 +26,6 @@ ) -def clean_hostname(name: Union[str, bytes]) -> str: - """Convert from short to long hostname, if no domain found.""" - # bytes? - if not isinstance(name, (str, bytes)): - cli_warning("Invalid input for hostname: {}".format(name)) - - if isinstance(name, bytes): - name = name.decode() - - name = name.lower() - - # invalid characters? - if re.search(r"^(\*\.)?([a-z0-9_][a-z0-9\-]*\.?)+$", name) is None: - cli_warning("Invalid input for hostname: {}".format(name)) - - # Assume user is happy with domain, but strip the dot. - if name.endswith("."): - return name[:-1] - - # If a dot in name, assume long name. - if "." in name: - return name - - config = MregCliConfig() - default_domain = config.get("domain") - # Append domain name if in config and it does not end with it - if default_domain and not name.endswith(default_domain): - return "{}.{}".format(name, default_domain) - return name - - def get_unique_ip_by_name_or_ip(arg: str) -> Dict[str, Any]: """Get A/AAAA record by either ip address or host name. @@ -275,7 +242,7 @@ def get_requested_ip(ip: str, force: bool, ipversion: Union[IP_Version, None] = if hosts and not force: hostnames = ",".join([i["name"] for i in hosts]) cli_warning(f"{ip} already in use by: {hostnames}. Must force") - network = get_network_by_ip(ip) + network = get_network_by_ip(ipaddress.ip_address(ip)) if not network: if force: return ip @@ -294,7 +261,7 @@ def get_requested_ip(ip: str, force: bool, ipversion: Union[IP_Version, None] = if network["frozen"] and not force: cli_warning("network {} is frozen, must force".format(network["network"])) # Chat the address given isn't reserved - reserved_addresses = get_network_reserved_ips(network["network"]) + reserved_addresses = get_network_reserved_ips(str(network["network"])) if ip in reserved_addresses and not force: cli_warning("Address is reserved. Requires force") if network_object.num_addresses > 2: diff --git a/mreg_cli/utilities/network.py b/mreg_cli/utilities/network.py index 8ae47b9d..ec424fca 100644 --- a/mreg_cli/utilities/network.py +++ b/mreg_cli/utilities/network.py @@ -9,10 +9,9 @@ import urllib.parse from typing import Any, Dict, Iterable, List +from mreg_cli.api import get_network_by_ip from mreg_cli.log import cli_warning from mreg_cli.types import IP_networkT -from mreg_cli.utilities.api import get -from mreg_cli.utilities.validators import is_valid_ip, is_valid_network def get_network_first_unused_ip(network: Dict[str, Any]) -> str: @@ -30,7 +29,8 @@ def get_network_first_unused_ip(network: Dict[str, Any]) -> str: def ip_in_mreg_net(ip: str) -> bool: """Return true if the ip is in a MREG controlled network.""" - net = get_network_by_ip(ip) + ipt = ipaddress.ip_address(ip) + net = get_network_by_ip(ipt) return bool(net) @@ -44,7 +44,7 @@ def ips_are_in_same_vlan(ips: List[str]) -> bool: # IPs must be in a network, and that network must have a vlan for this to work. last_vlan = "" for ip in ips: - network = get_network_by_ip(ip) + network = get_network_by_ip(ipaddress.ip_address(ip)) if not network: return False @@ -59,26 +59,13 @@ def ips_are_in_same_vlan(ips: List[str]) -> bool: return True -def get_network_by_ip(ip: str) -> Dict[str, Any]: - """Return a network associated with given IP.""" - if is_valid_ip(ip): - path = f"/api/v1/networks/ip/{urllib.parse.quote(ip)}" - net = get(path, ok404=True) - if net: - return net.json() - else: - return {} - else: - cli_warning("Not a valid ip address") - - def get_network(ip: str) -> Dict[str, Any]: """Return a network associated with given range or IP.""" if is_valid_network(ip): path = f"/api/v1/networks/{urllib.parse.quote(ip)}" return get(path).json() elif is_valid_ip(ip): - net = get_network_by_ip(ip) + net = get_network_by_ip(ipaddress.ip_address(ip)) if net: return net cli_warning("ip address exists but is not an address in any existing network") diff --git a/mreg_cli/utilities/shared.py b/mreg_cli/utilities/shared.py index a81dac35..7e752bde 100644 --- a/mreg_cli/utilities/shared.py +++ b/mreg_cli/utilities/shared.py @@ -1,11 +1,55 @@ """Shared utilities for the mreg_cli package.""" import re -from typing import Any, Tuple +from typing import Any, Tuple, Union from mreg_cli.log import cli_warning +# Temporary, to avoid circular imports and to allow old code to remain without +# breaking. This should be removed once the all the old code is refactored. +def clean_hostname(name: Union[str, bytes]) -> str: + """Ensure hostname is fully qualified, lowercase, and has valid characters. + + :param name: The hostname to clean. + + :raises CliWarning: If the hostname is invalid. + + :returns: The cleaned hostname. + """ + import re + + from mreg_cli.config import MregCliConfig + + # bytes? + if not isinstance(name, (str, bytes)): + cli_warning("Invalid input for hostname: {}".format(name)) + + if isinstance(name, bytes): + name = name.decode() + + name = name.lower() + + # invalid characters? + if re.search(r"^(\*\.)?([a-z0-9_][a-z0-9\-]*\.?)+$", name) is None: + cli_warning("Invalid input for hostname: {}".format(name)) + + # Assume user is happy with domain, but strip the dot. + if name.endswith("."): + return name[:-1] + + # If a dot in name, assume long name. + if "." in name: + return name + + config = MregCliConfig() + default_domain = config.get("domain") + # Append domain name if in config and it does not end with it + if default_domain and not name.endswith(default_domain): + return "{}.{}".format(name, default_domain) + return name + + def string_to_int(value: Any, error_tag: str) -> int: """Convert a string to an integer.""" try: