diff --git a/pyroute2/ndb/objects/__init__.py b/pyroute2/ndb/objects/__init__.py index 8afa026dc..4d9c8e62e 100644 --- a/pyroute2/ndb/objects/__init__.py +++ b/pyroute2/ndb/objects/__init__.py @@ -302,6 +302,9 @@ def __init__( self.knorm = self.schema.compiled[self.table]['norm_idx'] self.spec = self.schema.compiled[self.table]['all_names'] self.names = self.schema.compiled[self.table]['norm_names'] + self.lookup_fallbacks = self.schema.compiled[self.table][ + 'lookup_fallbacks' + ] self.names_count = [self.names.count(x) for x in self.names] self.last_save = None if self.event_map is None: diff --git a/pyroute2/ndb/objects/interface.py b/pyroute2/ndb/objects/interface.py index b01af5581..a995a19ff 100644 --- a/pyroute2/ndb/objects/interface.py +++ b/pyroute2/ndb/objects/interface.py @@ -654,6 +654,21 @@ def do_add_vlan(self, mode, spec): self._apply_script.append((do_add_vlan, {'spec': spec})) return self + @check_auth('obj:modify') + def ensure_vlan(self, spec): + def do_ensure_vlan(self, mode, spec): + try: + method = getattr(self.vlan.create(spec), mode) + return [method()] + except KeyError: + return [] + except Exception as e_s: + e_s.trace = traceback.format_stack() + return [e_s] + + self._apply_script.append((do_ensure_vlan, {'spec': spec})) + return self + @check_auth('obj:modify') def del_vlan(self, spec): def do_del_vlan(self, mode, spec): @@ -682,6 +697,23 @@ def do_add_neighbour(self, mode, spec): self._apply_script.append((do_add_neighbour, {'spec': spec})) return self + @check_auth('obj:modify') + def ensure_neighbour(self, spec=None, **kwarg): + spec = spec or kwarg + + def do_ensure_neighbour(self, mode, spec): + try: + method = getattr(self.neighbours.create(spec), mode) + return [method()] + except KeyError: + return [] + except Exception as e_s: + e_s.trace = traceback.format_stack() + return [e_s] + + self._apply_script.append((do_ensure_neighbour, {'spec': spec})) + return self + @check_auth('obj:modify') def del_neighbour(self, spec=None, **kwarg): spec = spec or dict(kwarg) @@ -727,6 +759,23 @@ def do_add_ip(self, mode, spec): self._apply_script.append((do_add_ip, {'spec': spec})) return self + @check_auth('obj:modify') + def ensure_ip(self, spec=None, **kwarg): + spec = spec or kwarg + + def do_ensure_ip(self, mode, spec): + try: + method = getattr(self.ipaddr.create(spec), mode) + return [method()] + except KeyError: + return [] + except Exception as e_s: + e_s.trace = traceback.format_stack() + return [e_s] + + self._apply_script.append((do_ensure_ip, {'spec': spec})) + return self + @check_auth('obj:modify') def del_ip(self, spec=None, **kwarg): spec = spec or kwarg diff --git a/pyroute2/ndb/schema.py b/pyroute2/ndb/schema.py index 418b9391a..97bf83ec6 100644 --- a/pyroute2/ndb/schema.py +++ b/pyroute2/ndb/schema.py @@ -461,6 +461,7 @@ def compile_spec(self, table, schema_names, schema_idx): 'fset': ','.join(f_set), 'knames': ','.join(f_idx), 'fidx': ' AND '.join(f_idx_match), + 'lookup_fallbacks': iclass.lookup_fallbacks, } @publish diff --git a/pyroute2/ndb/view.py b/pyroute2/ndb/view.py index eb152dfe6..09da32ffb 100644 --- a/pyroute2/ndb/view.py +++ b/pyroute2/ndb/view.py @@ -189,6 +189,17 @@ def create(self, *argspec, **kwspec): spec['create'] = True return self[spec] + @cli.change_pointer + @check_auth('obj:modify') + def ensure(self, *argspec, **kwspec): + try: + obj = self.locate(**kwspec) + except KeyError: + obj = self.create(**kwspec) + for key, value in kwspec.items(): + obj[key] = value + return obj + @cli.change_pointer @check_auth('obj:modify') def add(self, *argspec, **kwspec): @@ -266,11 +277,17 @@ def locate(self, spec=None, table=None, **kwarg): iclass = self.classes[table] spec = iclass.new_spec(spec) kspec = self.ndb.schema.compiled[table]['norm_idx'] + lookup_fallbacks = self.ndb.schema.compiled[table]['lookup_fallbacks'] request = {} for name in kspec: name = iclass.nla2name(name) if name in spec: request[name] = spec[name] + elif name in lookup_fallbacks: + fallback = lookup_fallbacks[name] + if fallback in spec: + request[fallback] = spec[fallback] + if not request: raise KeyError('got an empty key') return self[request] diff --git a/pyroute2/netlink/__init__.py b/pyroute2/netlink/__init__.py index 56e587670..4ea5bd659 100644 --- a/pyroute2/netlink/__init__.py +++ b/pyroute2/netlink/__init__.py @@ -834,6 +834,7 @@ class nlmsg_base(dict): sql_constraints = {} sql_extra_fields = () sql_extend = () + lookup_fallbacks = {} nla_flags = 0 # NLA flags value_map = {} is_nla = False diff --git a/pyroute2/netlink/rtnl/ifinfmsg/__init__.py b/pyroute2/netlink/rtnl/ifinfmsg/__init__.py index fb2cfd698..e4f9e5744 100644 --- a/pyroute2/netlink/rtnl/ifinfmsg/__init__.py +++ b/pyroute2/netlink/rtnl/ifinfmsg/__init__.py @@ -450,6 +450,7 @@ class ifinfbase(object): # sql_constraints = {'index': 'NOT NULL'} sql_extra_fields = (('state', 'TEXT'),) + lookup_fallbacks = {'index': 'ifname'} fields = ( ('family', 'B'), diff --git a/tests/test_linux/test_ndb/test_ensure.py b/tests/test_linux/test_ndb/test_ensure.py new file mode 100644 index 000000000..621bab521 --- /dev/null +++ b/tests/test_linux/test_ndb/test_ensure.py @@ -0,0 +1,140 @@ +import pytest +from pr2test.context_manager import make_test_matrix +from pr2test.marks import require_root +from pr2test.tools import address_exists, interface_exists + +pytestmark = [require_root()] +test_matrix = make_test_matrix( + targets=['local', 'netns'], dbs=['sqlite3/:memory:', 'postgres/pr2test'] +) + + +@pytest.mark.parametrize( + 'host_link_attr,create,ensure', + ( + ( + None, + ( + {'kind': 'dummy', 'state': 'down'}, + {'kind': 'dummy', 'state': 'up'}, + None, + ), + ( + {'kind': 'dummy', 'state': 'up'}, + {'kind': 'dummy', 'state': 'up'}, + {'kind': 'dummy', 'state': 'up'}, + ), + ), + ( + 'link', + ( + { + 'kind': 'vlan', + 'link': None, + 'vlan_id': 1010, + 'state': 'down', + }, + {'kind': 'vlan', 'link': None, 'vlan_id': 1011, 'state': 'up'}, + None, + ), + ( + {'kind': 'vlan', 'link': None, 'vlan_id': 1010, 'state': 'up'}, + {'kind': 'vlan', 'link': None, 'vlan_id': 1011, 'state': 'up'}, + {'kind': 'vlan', 'link': None, 'vlan_id': 1012, 'state': 'up'}, + ), + ), + ( + 'vxlan_link', + ( + { + 'kind': 'vxlan', + 'vxlan_link': None, + 'vxlan_id': 2020, + 'state': 'down', + }, + { + 'kind': 'vxlan', + 'vxlan_link': None, + 'vxlan_id': 2021, + 'state': 'up', + }, + None, + ), + ( + { + 'kind': 'vxlan', + 'vxlan_link': None, + 'vxlan_id': 2020, + 'state': 'up', + }, + { + 'kind': 'vxlan', + 'vxlan_link': None, + 'vxlan_id': 2021, + 'state': 'up', + }, + { + 'kind': 'vxlan', + 'vxlan_link': None, + 'vxlan_id': 2022, + 'state': 'up', + }, + ), + ), + ), + ids=('dummy', 'vlan', 'vxlan'), +) +@pytest.mark.parametrize('context', test_matrix, indirect=True) +def test_ensure_interface_simple(context, host_link_attr, create, ensure): + # if we need a host interface + if host_link_attr is not None: + host_ifname = context.new_ifname + host_nic = context.ndb.interfaces.create( + ifname=host_ifname, kind='dummy', state='up' + ) + host_nic.commit() + for spec in create + ensure: + if spec is not None: + spec[host_link_attr] = host_nic['index'] + + # patch interface specs + for spec_create, spec_ensure in zip(create, ensure): + ifname = context.new_ifname + if spec_create is not None: + spec_create['ifname'] = ifname + if spec_ensure is not None: + spec_ensure['ifname'] = ifname + + # create interfaces + for spec in create: + if spec is not None: + context.ndb.interfaces.create(**spec).commit() + + # ensure interfaces + for spec in ensure: + if spec is not None: + context.ndb.interfaces.ensure(**spec).commit() + assert interface_exists(context.netns, **spec) + + +@pytest.mark.parametrize('context', test_matrix, indirect=True) +def test_ensure_ensure_ip(context): + ifname = context.new_ifname + ipaddr1 = context.new_ipaddr + ipaddr2 = context.new_ipaddr + + nic = context.ndb.interfaces.create( + ifname=ifname, kind='dummy', state='down' + ) + nic.add_ip(address=ipaddr1, prefixlen=24) + nic.commit() + + ( + context.ndb.interfaces.ensure(ifname=ifname, kind='dummy', state='up') + .ensure_ip(address=ipaddr1, prefixlen=24) + .ensure_ip(address=ipaddr2, prefixlen=24) + .commit() + ) + assert interface_exists(context.netns, ifname=ifname) + assert address_exists(context.netns, ifname=ifname, address=ipaddr1) + assert address_exists(context.netns, ifname=ifname, address=ipaddr2)