diff --git a/hyperglass/execution/drivers/_construct.py b/hyperglass/execution/drivers/_construct.py index 81015edd..c9961a23 100644 --- a/hyperglass/execution/drivers/_construct.py +++ b/hyperglass/execution/drivers/_construct.py @@ -12,11 +12,7 @@ # Project from hyperglass.log import log -from hyperglass.constants import ( # APPEND_NEWLINE, - TRANSPORT_REST, - TARGET_FORMAT_SPACE, - TARGET_JUNIPER_ASPATH, -) +from hyperglass.constants import TRANSPORT_REST, TARGET_FORMAT_SPACE from hyperglass.configuration import commands @@ -24,12 +20,7 @@ class Construct: """Construct SSH commands/REST API parameters from validated query data.""" def __init__(self, device, query_data): - """Initialize command construction. - - Arguments: - device {object} -- Device object - query_data {object} -- Validated query object - """ + """Initialize command construction.""" log.debug( "Constructing {} query for '{}'", query_data.query_type, @@ -61,8 +52,8 @@ def __init__(self, device, query_data): ) if v is not None and self.query_data.query_target.version == v.version ] - elif self.query_data.query_type in ("bgp_aspath", "bgp_community"): + with Formatter(self.device.nos, self.query_data.query_type) as formatter: # For AS Path/Community queries, AFIs are just enabled VRF -> AFI # definitions, no IP version checking is performed (since there is no IP). self.afis = [ @@ -73,39 +64,10 @@ def __init__(self, device, query_data): ) if v is not None ] - - # For devices that follow Juniper's AS_PATH regex standards, - # filter out Cisco-style special characters. - - if ( - self.device.nos in TARGET_JUNIPER_ASPATH - and self.query_data.query_type in ("bgp_aspath",) - ): - query = str(self.query_data.query_target) - asns = re.findall(r"\d+", query) - was_modified = False - if bool(re.match(r"^\_", query)): - # Replace `_65000` with `.* 65000` - asns.insert(0, r".*") - was_modified = True - if bool(re.match(r".*(\_)$", query)): - # Replace `65000_` with `65000 .*` - asns.append(r".*") - was_modified = True - if was_modified: - self.target = " ".join(asns) - else: - self.target = query + self.target = formatter(self.query_data.query_target) def json(self, afi): - """Return JSON version of validated query for REST devices. - - Arguments: - afi {object} -- AFI object - - Returns: - {str} -- JSON query string - """ + """Return JSON version of validated query for REST devices.""" log.debug("Building JSON query for {q}", q=repr(self.query_data)) return _json.dumps( { @@ -118,14 +80,7 @@ def json(self, afi): ) def scrape(self, afi): - """Return formatted command for 'Scrape' endpoints (SSH). - - Arguments: - afi {object} -- AFI object - - Returns: - {str} -- Command string - """ + """Return formatted command for 'Scrape' endpoints (SSH).""" if self.device.structured_output: cmd_paths = ( self.device.nos, @@ -144,11 +99,7 @@ def scrape(self, afi): ) def queries(self): - """Return queries for each enabled AFI. - - Returns: - {list} -- List of queries to run - """ + """Return queries for each enabled AFI.""" query = [] for afi in self.afis: @@ -159,3 +110,92 @@ def queries(self): log.debug("Constructed query: {}", query) return query + + +class Formatter: + """Modify query target based on the device's NOS requirements and the query type.""" + + def __init__(self, nos: str, query_type: str) -> None: + """Initialize target formatting.""" + self.nos = nos + self.query_type = query_type + + def __enter__(self): + """Get the relevant formatter.""" + return self._get_formatter() + + def __exit__(self, exc_type, exc_value, exc_traceback): + """Handle context exit.""" + if exc_type is not None: + log.error(exc_traceback) + pass + + def _get_formatter(self): + if self.nos in ("juniper", "juniper_junos"): + if self.query_type == "bgp_aspath": + return self._juniper_bgp_aspath + if self.nos in ("bird", "bird_ssh"): + if self.query_type == "bgp_aspath": + return self._bird_bgp_aspath + elif self.query_type == "bgp_community": + return self._bird_bgp_community + return self._default + + def _default(self, target: str) -> str: + """Don't format targets by default.""" + return target + + def _juniper_bgp_aspath(self, target: str) -> str: + """Convert from Cisco AS_PATH format to Juniper format.""" + query = str(target) + asns = re.findall(r"\d+", query) + was_modified = False + + if bool(re.match(r"^\_", query)): + # Replace `_65000` with `.* 65000` + asns.insert(0, r".*") + was_modified = True + + if bool(re.match(r".*(\_)$", query)): + # Replace `65000_` with `65000 .*` + asns.append(r".*") + was_modified = True + + if was_modified: + modified = " ".join(asns) + log.debug("Modified target '{}' to '{}'", target, modified) + return modified + + return query + + def _bird_bgp_aspath(self, target: str) -> str: + """Convert from Cisco AS_PATH format to BIRD format.""" + + # Extract ASNs from query target string + asns = re.findall(r"\d+", target) + was_modified = False + + if bool(re.match(r"^\_", target)): + # Replace `_65000` with `.* 65000` + asns.insert(0, "*") + was_modified = True + + if bool(re.match(r".*(\_)$", target)): + # Replace `65000_` with `65000 .*` + asns.append("*") + was_modified = True + + asns.insert(0, "[=") + asns.append("=]") + + result = " ".join(asns) + + if was_modified: + log.debug("Modified target '{}' to '{}'", target, result) + + return result + + def _bird_bgp_community(self, target: str) -> str: + """Convert from standard community format to BIRD format.""" + parts = target.split(":") + return f'({",".join(parts)})'