diff --git a/lib/jnpr/junos/device.py b/lib/jnpr/junos/device.py index 641806803..6042cd5ef 100644 --- a/lib/jnpr/junos/device.py +++ b/lib/jnpr/junos/device.py @@ -1216,6 +1216,10 @@ def __init__(self, *vargs, **kvargs): *OPTIONAL* To disable public key authentication. default is ``None``. + :param bool hostkey_verify: + *OPTIONAL* To enable ssh_known hostkey verify + default is ``False``. + """ # ---------------------------------------- @@ -1234,6 +1238,7 @@ def __init__(self, *vargs, **kvargs): self._huge_tree = kvargs.get("huge_tree", False) self._conn_open_timeout = kvargs.get("conn_open_timeout", 30) self._look_for_keys = kvargs.get("look_for_keys", None) + self._hostkey_verify = kvargs.get("hostkey_verify", False) if self._fact_style != "new": warnings.warn( "fact-style %s will be removed in a future " @@ -1367,6 +1372,14 @@ def open(self, *vargs, **kvargs): else: look_for_keys = self._look_for_keys + # option to enable ssh_known hosts key verification + # using hostkey_verify=True + # Default is disabled with hostkey_verify=False + if self._hostkey_verify is None: + hostkey_verify = False + else: + hostkey_verify = self._hostkey_verify + # open connection using ncclient transport self._conn = netconf_ssh.connect( host=self._hostname, @@ -1374,7 +1387,7 @@ def open(self, *vargs, **kvargs): sock_fd=self._sock_fd, username=self._auth_user, password=self._auth_password, - hostkey_verify=False, + hostkey_verify=hostkey_verify, key_filename=self._ssh_private_key_file, allow_agent=allow_agent, look_for_keys=look_for_keys, diff --git a/tests/unit/facts/test_domain.py b/tests/unit/facts/test_domain.py index 201d0fca0..99958db83 100644 --- a/tests/unit/facts/test_domain.py +++ b/tests/unit/facts/test_domain.py @@ -25,6 +25,7 @@ def setUp(self, mock_connect): @patch("jnpr.junos.Device.execute") def test_domain_fact_from_config(self, mock_execute): + self.dev.facts._cache["hostname"] = "r0" mock_execute.side_effect = self._mock_manager_domain_config self.assertEqual(self.dev.facts["domain"], "juniper.net") self.assertEqual(self.dev.facts["fqdn"], "r0.juniper.net") diff --git a/tests/unit/test_device.py b/tests/unit/test_device.py index 1f8a97ebc..42d16b559 100644 --- a/tests/unit/test_device.py +++ b/tests/unit/test_device.py @@ -496,6 +496,39 @@ def test_device_open_with_look_for_keys_True(self, mock_connect, mock_execute): ) self.dev2.open() self.assertEqual(self.dev2.connected, True) + @patch("ncclient.manager.connect") + @patch("jnpr.junos.Device.execute") + def test_device_open_with_hostkey_verify_True(self, mock_connect, mock_execute): + with patch("jnpr.junos.utils.fs.FS.cat") as mock_cat: + mock_cat.return_value = """ + + domain jls.net + + """ + mock_connect.side_effect = self._mock_manager + mock_execute.side_effect = self._mock_manager + self.dev2 = Device( + host="2.2.2.2", user="test", password="password123", hostkey_verify=True + ) + self.dev2.open() + self.assertEqual(self.dev2.connected, True) + + @patch("ncclient.manager.connect") + @patch("jnpr.junos.Device.execute") + def test_device_open_with_hostkey_verify_False(self, mock_connect, mock_execute): + with patch("jnpr.junos.utils.fs.FS.cat") as mock_cat: + mock_cat.return_value = """ + + domain jls.net + + """ + mock_connect.side_effect = self._mock_manager + mock_execute.side_effect = self._mock_manager + self.dev2 = Device( + host="2.2.2.2", user="test", password="password123", hostkey_verify=False + ) + self.dev2.open() + self.assertEqual(self.dev2.connected, True) @patch("ncclient.manager.connect") @patch("jnpr.junos.Device.execute") diff --git a/tests/unit/test_factcache.py b/tests/unit/test_factcache.py index 7f63fed45..cfe90e9f9 100644 --- a/tests/unit/test_factcache.py +++ b/tests/unit/test_factcache.py @@ -32,11 +32,13 @@ def test_factcache_fact_loop(self): # Change the callback for the model # fact to be the same as the personality fact # in order to induce a fact loop. + tmp = self.dev.facts._callbacks["model"] self.dev.facts._callbacks["model"] = self.dev.facts._callbacks["personality"] # Now, trying to fetch the personality # fact should cause a FactLoopError with self.assertRaises(FactLoopError): personality = self.dev.facts["personality"] + self.dev.facts._callbacks["model"] = tmp # To clear FactLoopError def test_factcache_return_unexpected_fact(self): # Create a callback for the foo fact.