diff --git a/control/cli.py b/control/cli.py index 71ecd702..217c4382 100644 --- a/control/cli.py +++ b/control/cli.py @@ -1054,25 +1054,16 @@ def host_add(self, args): out_func, err_func = self.get_output_functions(args) if args.psk: - if len(args.psk) > len(args.host_nqn): - err_func("There are more PSK values than hosts, will ignore redundant values") - elif len(args.psk) < len(args.host_nqn): - err_func("There are more hosts than PSK values, will assume empty PSK values") + if len(args.host_nqn) > 1: + self.cli.parser.error(f"Can't have more than one host NQN when PSK keys are used") for i in range(len(args.host_nqn)): one_host_nqn = args.host_nqn[i] - one_host_psk = None - if args.psk: - try: - one_host_psk = args.psk[i] - except IndexError: - pass - if one_host_nqn == "*" and one_host_psk: - err_func(f"PSK is only allowed for specific hosts, ignoring PSK value \"{one_host_psk}\"") - one_host_psk = None + if one_host_nqn == "*" and args.psk: + self.cli.parser.error(f"PSK is only allowed for specific hosts") - req = pb2.add_host_req(subsystem_nqn=args.subsystem, host_nqn=one_host_nqn, psk=one_host_psk) + req = pb2.add_host_req(subsystem_nqn=args.subsystem, host_nqn=one_host_nqn, psk=args.psk) try: ret = self.stub.add_host(req) except Exception as ex: @@ -1219,7 +1210,7 @@ def host_list(self, args): ] host_add_args = host_common_args + [ argument("--host-nqn", "-t", help="Host NQN list", nargs="+", required=True), - argument("--psk", help="Hosts PSK key list", nargs="+", required=False), + argument("--psk", help="Hosts PSK key list", required=False), ] host_del_args = host_common_args + [ argument("--host-nqn", "-t", help="Host NQN list", nargs="+", required=True), diff --git a/tests/test_psk.py b/tests/test_psk.py index f5ed5306..9ee2e05e 100644 --- a/tests/test_psk.py +++ b/tests/test_psk.py @@ -25,8 +25,6 @@ hostnqn8 = "nqn.2014-08.org.nvmexpress:uuid:22207d09-d8af-4ed2-84ec-a6d80b0cf7f2" hostnqn9 = "nqn.2014-08.org.nvmexpress:uuid:22207d09-d8af-4ed2-84ec-a6d80b0cf7f3" hostnqn10 = "nqn.2014-08.org.nvmexpress:uuid:22207d09-d8af-4ed2-84ec-a6d80b0cf7f4" -hostnqn11 = "nqn.2014-08.org.nvmexpress:uuid:22207d09-d8af-4ed2-84ec-a6d80b0cf7f5" -hostnqn12 = "nqn.2014-08.org.nvmexpress:uuid:22207d09-d8af-4ed2-84ec-a6d80b0cf7f6" hostpsk = "NVMeTLSkey-1:01:YzrPElk4OYy1uUERriPwiiyEJE/+J5ckYpLB+5NHMsR2iBuT:" hostpsk2 = "NVMeTLSkey-1:02:FTFds4vH4utVcfrOforxbrWIgv+Qq4GQHgMdWwzDdDxE1bAqK2mOoyXxmbJxGeueEVVa/Q==:" @@ -123,18 +121,14 @@ def test_create_not_secure(caplog, gateway): def test_create_secure_list(caplog, gateway): caplog.clear() - cli(["host", "add", "--subsystem", subsystem, "--host-nqn", hostnqn8, hostnqn9, hostnqn10, "--psk", hostpsk5, hostpsk6, hostpsk7, hostpsk]) - assert f"There are more PSK values than hosts, will ignore redundant values" in caplog.text - assert f"Adding host {hostnqn8} to {subsystem}: Successful" in caplog.text - assert f"Adding host {hostnqn9} to {subsystem}: Successful" in caplog.text - assert f"Adding host {hostnqn10} to {subsystem}: Successful" in caplog.text - -def test_create_secure_list_missing_psk(caplog, gateway): - caplog.clear() - cli(["host", "add", "--subsystem", subsystem, "--host-nqn", hostnqn11, hostnqn12, "--psk", hostpsk8]) - assert f"Adding host {hostnqn11} to {subsystem}: Successful" in caplog.text - assert f"Adding host {hostnqn12} to {subsystem}: Successful" in caplog.text - assert f"There are more hosts than PSK values, will assume empty PSK values" in caplog.text + rc = 0 + try: + cli(["host", "add", "--subsystem", subsystem, "--host-nqn", hostnqn8, hostnqn9, hostnqn10, "--psk", hostpsk]) + except SystemExit as sysex: + rc = int(str(sysex)) + pass + assert rc == 2 + assert f"error: Can't have more than one host NQN when PSK keys are used" in caplog.text def test_create_secure_junk_key(caplog, gateway): caplog.clear() @@ -150,16 +144,14 @@ def test_create_secure_no_key(caplog, gateway): rc = int(str(sysex)) pass assert rc == 2 - assert f"error: argument --psk: expected at least one argument" in caplog.text + assert f"error: argument --psk: expected one argument" in caplog.text def test_list_psk_hosts(caplog, gateway): caplog.clear() hosts = cli_test(["host", "list", "--subsystem", subsystem]) found = 0 - assert len(hosts.hosts) == 10 + assert len(hosts.hosts) == 5 for h in hosts.hosts: - assert h.nqn != hostnqn3 - assert h.nqn != hostnqn5 if h.nqn == hostnqn: found += 1 assert h.use_psk @@ -175,29 +167,20 @@ def test_list_psk_hosts(caplog, gateway): elif h.nqn == hostnqn7: found += 1 assert not h.use_psk - elif h.nqn == hostnqn8: - found += 1 - assert h.use_psk - elif h.nqn == hostnqn9: - found += 1 - assert h.use_psk - elif h.nqn == hostnqn10: - found += 1 - assert h.use_psk - elif h.nqn == hostnqn11: - found += 1 - assert h.use_psk - elif h.nqn == hostnqn12: - found += 1 - assert not h.use_psk else: assert False - assert found == 10 + assert found == 5 def test_allow_any_host_with_psk(caplog, gateway): caplog.clear() - cli(["host", "add", "--subsystem", subsystem, "--host-nqn", "*", "--psk", hostpsk]) - assert f"PSK is only allowed for specific hosts, ignoring PSK value \"{hostpsk}\"" in caplog.text + rc = 0 + try: + cli(["host", "add", "--subsystem", subsystem, "--host-nqn", "*", "--psk", hostpsk]) + except SystemExit as sysex: + rc = int(str(sysex)) + pass + assert rc == 2 + assert f"error: PSK is only allowed for specific hosts" in caplog.text def test_list_listeners(caplog, gateway): caplog.clear()