diff --git a/config/aaa.py b/config/aaa.py index fb2db721ae..d39b00a20f 100644 --- a/config/aaa.py +++ b/config/aaa.py @@ -1,7 +1,17 @@ import click +import ipaddress +import re from swsscommon.swsscommon import ConfigDBConnector import utilities_common.cli as clicommon +RADIUS_MAXSERVERS = 8 +RADIUS_PASSKEY_MAX_LEN = 65 +VALID_CHARS_MSG = "Valid chars are ASCII printable except SPACE, '#', and ','" + +def is_secret(secret): + return bool(re.match('^' + '[^ #,]*' + '$', secret)) + + def add_table_kv(table, entry, key, val): config_db = ConfigDBConnector() config_db.connect() @@ -61,20 +71,69 @@ def fallback(option): authentication.add_command(fallback) +# cmd: aaa authentication debug +@click.command() +@click.argument('option', type=click.Choice(["enable", "disable", "default"])) +def debug(option): + """AAA debug [enable | disable | default]""" + if option == 'default': + del_table_key('AAA', 'authentication', 'debug') + else: + if option == 'enable': + add_table_kv('AAA', 'authentication', 'debug', True) + elif option == 'disable': + add_table_kv('AAA', 'authentication', 'debug', False) +authentication.add_command(debug) + + +# cmd: aaa authentication trace +@click.command() +@click.argument('option', type=click.Choice(["enable", "disable", "default"])) +def trace(option): + """AAA packet trace [enable | disable | default]""" + if option == 'default': + del_table_key('AAA', 'authentication', 'trace') + else: + if option == 'enable': + add_table_kv('AAA', 'authentication', 'trace', True) + elif option == 'disable': + add_table_kv('AAA', 'authentication', 'trace', False) +authentication.add_command(trace) + + @click.command() -@click.argument('auth_protocol', nargs=-1, type=click.Choice(["tacacs+", "local", "default"])) +@click.argument('auth_protocol', nargs=-1, type=click.Choice(["radius", "tacacs+", "local", "default"])) def login(auth_protocol): - """Switch login authentication [ {tacacs+, local} | default ]""" + """Switch login authentication [ {radius, tacacs+, local} | default ]""" if len(auth_protocol) is 0: click.echo('Argument "auth_protocol" is required') return + elif len(auth_protocol) > 2: + click.echo('Not a valid command.') + return if 'default' in auth_protocol: + if len(auth_protocol) !=1: + click.echo('Not a valid command') + return del_table_key('AAA', 'authentication', 'login') else: val = auth_protocol[0] if len(auth_protocol) == 2: - val += ',' + auth_protocol[1] + val2 = auth_protocol[1] + good_ap = False + if val == 'local': + if val2 == 'radius' or val2 == 'tacacs+': + good_ap = True + elif val == 'radius' or val == 'tacacs+': + if val2 == 'local': + good_ap = True + if good_ap == True: + val += ',' + val2 + else: + click.echo('Not a valid command') + return + add_table_kv('AAA', 'authentication', 'login', val) authentication.add_command(login) @@ -189,3 +248,249 @@ def delete(address): config_db.connect() config_db.set_entry('TACPLUS_SERVER', address, None) tacacs.add_command(delete) + + +@click.group() +def radius(): + """RADIUS server configuration""" + pass + + +@click.group() +@click.pass_context +def default(ctx): + """set its default configuration""" + ctx.obj = 'default' +radius.add_command(default) + + +@click.command() +@click.argument('second', metavar='', type=click.IntRange(1, 60), required=False) +@click.pass_context +def timeout(ctx, second): + """Specify RADIUS server global timeout <1 - 60>""" + if ctx.obj == 'default': + del_table_key('RADIUS', 'global', 'timeout') + elif second: + add_table_kv('RADIUS', 'global', 'timeout', second) + else: + click.echo('Not support empty argument') +radius.add_command(timeout) +default.add_command(timeout) + + +@click.command() +@click.argument('retries', metavar='', type=click.IntRange(0, 10), required=False) +@click.pass_context +def retransmit(ctx, retries): + """Specify RADIUS server global retry attempts <0 - 10>""" + if ctx.obj == 'default': + del_table_key('RADIUS', 'global', 'retransmit') + elif retries != None: + add_table_kv('RADIUS', 'global', 'retransmit', retries) + else: + click.echo('Not support empty argument') +radius.add_command(retransmit) +default.add_command(retransmit) + + +@click.command() +@click.argument('type', metavar='', type=click.Choice(["chap", "pap", "mschapv2"]), required=False) +@click.pass_context +def authtype(ctx, type): + """Specify RADIUS server global auth_type [chap | pap | mschapv2]""" + if ctx.obj == 'default': + del_table_key('RADIUS', 'global', 'auth_type') + elif type: + add_table_kv('RADIUS', 'global', 'auth_type', type) + else: + click.echo('Not support empty argument') +radius.add_command(authtype) +default.add_command(authtype) + + +@click.command() +@click.argument('secret', metavar='', required=False) +@click.pass_context +def passkey(ctx, secret): + """Specify RADIUS server global passkey """ + if ctx.obj == 'default': + del_table_key('RADIUS', 'global', 'passkey') + elif secret: + if len(secret) > RADIUS_PASSKEY_MAX_LEN: + click.echo('Maximum of %d chars can be configured' % RADIUS_PASSKEY_MAX_LEN) + return + elif not is_secret(secret): + click.echo(VALID_CHARS_MSG) + return + add_table_kv('RADIUS', 'global', 'passkey', secret) + else: + click.echo('Not support empty argument') +radius.add_command(passkey) +default.add_command(passkey) + +@click.command() +@click.argument('src_ip', metavar='', required=False) +@click.pass_context +def sourceip(ctx, src_ip): + """Specify RADIUS server global source ip """ + if ctx.obj == 'default': + del_table_key('RADIUS', 'global', 'src_ip') + return + elif not src_ip: + click.echo('Not support empty argument') + return + + if not clicommon.is_ipaddress(src_ip): + click.echo('Invalid ip address') + return + + v6_invalid_list = [ipaddress.IPv6Address(unicode('0::0')), ipaddress.IPv6Address(unicode('0::1'))] + net = ipaddress.ip_network(unicode(src_ip), strict=False) + if (net.version == 4): + if src_ip == "0.0.0.0": + click.echo('enter non-zero ip address') + return + ip = ipaddress.IPv4Address(src_ip) + if ip.is_reserved: + click.echo('Reserved ip is not valid') + return + if ip.is_multicast: + click.echo('Multicast ip is not valid') + return + elif (net.version == 6): + ip = ipaddress.IPv6Address(src_ip) + if (ip.is_multicast): + click.echo('Multicast ip is not valid') + return + if (ip in v6_invalid_list): + click.echo('Invalid ip address') + return + add_table_kv('RADIUS', 'global', 'src_ip', src_ip) +radius.add_command(sourceip) +default.add_command(sourceip) + +@click.command() +@click.argument('nas_ip', metavar='', required=False) +@click.pass_context +def nasip(ctx, nas_ip): + """Specify RADIUS server global NAS-IP|IPV6-Address """ + if ctx.obj == 'default': + del_table_key('RADIUS', 'global', 'nas_ip') + return + elif not nas_ip: + click.echo('Not support empty argument') + return + + if not clicommon.is_ipaddress(nas_ip): + click.echo('Invalid ip address') + return + + v6_invalid_list = [ipaddress.IPv6Address(unicode('0::0')), ipaddress.IPv6Address(unicode('0::1'))] + net = ipaddress.ip_network(unicode(nas_ip), strict=False) + if (net.version == 4): + if nas_ip == "0.0.0.0": + click.echo('enter non-zero ip address') + return + ip = ipaddress.IPv4Address(nas_ip) + if ip.is_reserved: + click.echo('Reserved ip is not valid') + return + if ip.is_multicast: + click.echo('Multicast ip is not valid') + return + elif (net.version == 6): + ip = ipaddress.IPv6Address(nas_ip) + if (ip.is_multicast): + click.echo('Multicast ip is not valid') + return + if (ip in v6_invalid_list): + click.echo('Invalid ip address') + return + add_table_kv('RADIUS', 'global', 'nas_ip', nas_ip) +radius.add_command(nasip) +default.add_command(nasip) + +@click.command() +@click.argument('option', type=click.Choice(["enable", "disable", "default"])) +def statistics(option): + """Specify RADIUS server global statistics [enable | disable | default]""" + if option == 'default': + del_table_key('RADIUS', 'global', 'statistics') + else: + if option == 'enable': + add_table_kv('RADIUS', 'global', 'statistics', True) + elif option == 'disable': + add_table_kv('RADIUS', 'global', 'statistics', False) +radius.add_command(statistics) + + +# cmd: radius add --retransmit COUNT --timeout SECOND --key SECRET --type TYPE --auth-port PORT --pri PRIORITY +@click.command() +@click.argument('address', metavar='') +@click.option('-r', '--retransmit', help='Retransmit attempts, default 3', type=click.IntRange(1, 10)) +@click.option('-t', '--timeout', help='Transmission timeout interval, default 5', type=click.IntRange(1, 60)) +@click.option('-k', '--key', help='Shared secret') +@click.option('-a', '--auth_type', help='Authentication type, default pap', type=click.Choice(["chap", "pap", "mschapv2"])) +@click.option('-o', '--auth-port', help='UDP port range is 1 to 65535, default 1812', type=click.IntRange(1, 65535), default=1812) +@click.option('-p', '--pri', help="Priority, default 1", type=click.IntRange(1, 64), default=1) +@click.option('-m', '--use-mgmt-vrf', help="Management vrf, default is no vrf", is_flag=True) +@click.option('-s', '--source-interface', help='Source Interface') +def add(address, retransmit, timeout, key, auth_type, auth_port, pri, use_mgmt_vrf, source_interface): + """Specify a RADIUS server""" + + if key: + if len(key) > RADIUS_PASSKEY_MAX_LEN: + click.echo('--key: Maximum of %d chars can be configured' % RADIUS_PASSKEY_MAX_LEN) + return + elif not is_secret(key): + click.echo('--key: ' + VALID_CHARS_MSG) + return + + config_db = ConfigDBConnector() + config_db.connect() + old_data = config_db.get_table('RADIUS_SERVER') + if address in old_data : + click.echo('server %s already exists' % address) + return + if len(old_data) == RADIUS_MAXSERVERS: + click.echo('Maximum of %d can be configured' % RADIUS_MAXSERVERS) + else: + data = { + 'auth_port': str(auth_port), + 'priority': pri + } + if auth_type is not None: + data['auth_type'] = auth_type + if retransmit is not None: + data['retransmit'] = str(retransmit) + if timeout is not None: + data['timeout'] = str(timeout) + if key is not None: + data['passkey'] = key + if use_mgmt_vrf : + data['vrf'] = "mgmt" + if source_interface : + if (source_interface.startswith("Ethernet") or \ + source_interface.startswith("PortChannel") or \ + source_interface.startswith("Vlan") or \ + source_interface.startswith("Loopback") or \ + source_interface == "eth0"): + data['src_intf'] = source_interface + else: + click.echo('Not supported interface name (valid interface name: Etherent/PortChannel/Vlan/Loopback/eth0)') + config_db.set_entry('RADIUS_SERVER', address, data) +radius.add_command(add) + + +# cmd: radius delete +# 'del' is keyword, replace with 'delete' +@click.command() +@click.argument('address', metavar='') +def delete(address): + """Delete a RADIUS server""" + + config_db = ConfigDBConnector() + config_db.connect() + config_db.set_entry('RADIUS_SERVER', address, None) +radius.add_command(delete) diff --git a/config/main.py b/config/main.py index d27562bd4e..daeb95549c 100644 --- a/config/main.py +++ b/config/main.py @@ -868,6 +868,7 @@ def config(ctx): # Add groups from other modules config.add_command(aaa.aaa) config.add_command(aaa.tacacs) +config.add_command(aaa.radius) config.add_command(chassis_modules.chassis_modules) config.add_command(console.console) config.add_command(feature.feature) diff --git a/show/main.py b/show/main.py index 5fba9d828a..46afa1b5d6 100644 --- a/show/main.py +++ b/show/main.py @@ -1230,10 +1230,10 @@ def services(): break @cli.command() -def aaa(): +@clicommon.pass_db +def aaa(db): """Show AAA configuration""" - config_db = ConfigDBConnector() - config_db.connect() + config_db = db.cfgdb data = config_db.get_table('AAA') output = '' @@ -1281,6 +1281,58 @@ def tacacs(): output += (' %s %s\n' % (key, str(entry[key]))) click.echo(output) +@cli.command() +@clicommon.pass_db +def radius(db): + """Show RADIUS configuration""" + output = '' + config_db = db.cfgdb + data = config_db.get_table('RADIUS') + + radius = { + 'global': { + 'auth_type': 'pap (default)', + 'retransmit': '3 (default)', + 'timeout': '5 (default)', + 'passkey': ' (default)' + } + } + if 'global' in data: + radius['global'].update(data['global']) + for key in radius['global']: + output += ('RADIUS global %s %s\n' % (str(key), str(radius['global'][key]))) + + data = config_db.get_table('RADIUS_SERVER') + if data != {}: + for row in data: + entry = data[row] + output += ('\nRADIUS_SERVER address %s\n' % row) + for key in entry: + output += (' %s %s\n' % (key, str(entry[key]))) + + counters_db = SonicV2Connector(host='127.0.0.1') + counters_db.connect(counters_db.COUNTERS_DB, retry_on=False) + + if radius['global'].get('statistics', False) and (data != {}): + for row in data: + exists = counters_db.exists(counters_db.COUNTERS_DB, + 'RADIUS_SERVER_STATS:{}'.format(row)) + if not exists: + continue + + counter_entry = counters_db.get_all(counters_db.COUNTERS_DB, + 'RADIUS_SERVER_STATS:{}'.format(row)) + output += ('\nStatistics for RADIUS_SERVER address %s\n' % row) + for key in counter_entry: + if counter_entry[key] != "0": + output += (' %s %s\n' % (key, str(counter_entry[key]))) + try: + counters_db.close(counters_db.COUNTERS_DB) + except Exception as e: + pass + + click.echo(output) + # # 'mirror_session' command ("show mirror_session ...") # diff --git a/tests/aaa_test.py b/tests/aaa_test.py new file mode 100644 index 0000000000..d202b41ad7 --- /dev/null +++ b/tests/aaa_test.py @@ -0,0 +1,138 @@ +import imp +import os +import sys + +from click.testing import CliRunner +from utilities_common.db import Db + +import config.main as config +import show.main as show + +test_path = os.path.dirname(os.path.abspath(__file__)) +modules_path = os.path.dirname(test_path) +sys.path.insert(0, test_path) +sys.path.insert(0, modules_path) + +import mock_tables.dbconnector + +show_aaa_default_output="""\ +AAA authentication login local (default) +AAA authentication failthrough False (default) + +""" + +show_aaa_radius_output="""\ +AAA authentication login radius +AAA authentication failthrough False (default) + +""" + +show_aaa_radius_local_output="""\ +AAA authentication login radius,local +AAA authentication failthrough False (default) + +""" + +config_aaa_empty_output="""\ +""" + +config_aaa_not_a_valid_command_output="""\ +Not a valid command +""" + +class TestAaa(object): + @classmethod + def setup_class(cls): + os.environ['UTILITIES_UNIT_TESTING'] = "1" + print("SETUP") + import config.main + imp.reload(config.main) + import show.main + imp.reload(show.main) + + @classmethod + def teardown_class(cls): + os.environ['UTILITIES_UNIT_TESTING'] = "0" + print("TEARDOWN") + + def test_show_aaa_default(self): + runner = CliRunner() + result = runner.invoke(show.cli.commands["aaa"], []) + print(result.exit_code) + print(result.output) + assert result.exit_code == 0 + assert result.output == show_aaa_default_output + + def test_config_aaa_radius(self, get_cmd_module): + (config, show) = get_cmd_module + runner = CliRunner() + db = Db() + db.cfgdb.delete_table("AAA") + + result = runner.invoke(config.config.commands["aaa"],\ + ["authentication", "login", "radius"], obj=db) + print(result.exit_code) + print(result.output) + assert result.exit_code == 0 + assert result.output == config_aaa_empty_output + + db.cfgdb.mod_entry("AAA", "authentication", {'login' : 'radius'}) + + result = runner.invoke(show.cli.commands["aaa"], [], obj=db) + assert result.exit_code == 0 + assert result.output == show_aaa_radius_output + + result = runner.invoke(config.config.commands["aaa"],\ + ["authentication", "login", "default"], obj=db) + print(result.exit_code) + print(result.output) + assert result.exit_code == 0 + assert result.output == config_aaa_empty_output + + db.cfgdb.delete_table("AAA") + + result = runner.invoke(show.cli.commands["aaa"], [], obj=db) + assert result.exit_code == 0 + assert result.output == show_aaa_default_output + + def test_config_aaa_radius_local(self, get_cmd_module): + (config, show) = get_cmd_module + runner = CliRunner() + db = Db() + db.cfgdb.delete_table("AAA") + + result = runner.invoke(config.config.commands["aaa"],\ + ["authentication", "login", "radius", "local"], obj=db) + print(result.exit_code) + print(result.output) + assert result.exit_code == 0 + assert result.output == config_aaa_empty_output + + db.cfgdb.mod_entry("AAA", "authentication", {'login' : 'radius,local'}) + + result = runner.invoke(show.cli.commands["aaa"], [], obj=db) + assert result.exit_code == 0 + assert result.output == show_aaa_radius_local_output + + result = runner.invoke(config.config.commands["aaa"],\ + ["authentication", "login", "default"], obj=db) + print(result.exit_code) + print(result.output) + assert result.exit_code == 0 + assert result.output == config_aaa_empty_output + + db.cfgdb.delete_table("AAA") + + result = runner.invoke(show.cli.commands["aaa"], [], obj=db) + assert result.exit_code == 0 + assert result.output == show_aaa_default_output + + def test_config_aaa_radius_invalid(self): + runner = CliRunner() + result = runner.invoke(config.config.commands["aaa"],\ + ["authentication", "login", "radius", "tacacs+"]) + print(result.exit_code) + print(result.output) + assert result.exit_code == 0 + assert result.output == config_aaa_not_a_valid_command_output + diff --git a/tests/radius_test.py b/tests/radius_test.py new file mode 100644 index 0000000000..117e19bde8 --- /dev/null +++ b/tests/radius_test.py @@ -0,0 +1,194 @@ +import imp +import os +import sys + +from click.testing import CliRunner +from utilities_common.db import Db + +import config.main as config +import show.main as show + +test_path = os.path.dirname(os.path.abspath(__file__)) +modules_path = os.path.dirname(test_path) +sys.path.insert(0, test_path) +sys.path.insert(0, modules_path) + +import mock_tables.dbconnector + +show_radius_default_output="""\ +RADIUS global auth_type pap (default) +RADIUS global retransmit 3 (default) +RADIUS global timeout 5 (default) +RADIUS global passkey (default) + +""" + +show_radius_server_output="""\ +RADIUS global auth_type pap (default) +RADIUS global retransmit 3 (default) +RADIUS global timeout 5 (default) +RADIUS global passkey (default) + +RADIUS_SERVER address 10.10.10.10 + auth_port 1812 + passkey testing123 + priority 1 + retransmit 1 + src_intf eth0 + timeout 3 + +""" + +show_radius_global_output="""\ +RADIUS global auth_type chap +RADIUS global retransmit 3 (default) +RADIUS global timeout 5 (default) +RADIUS global passkey (default) + +""" + +config_radius_empty_output="""\ +""" + +config_radius_server_invalidkey_output="""\ +--key: Valid chars are ASCII printable except SPACE, '#', and ',' +""" + +config_radius_invalidipaddress_output="""\ +Invalid ip address +""" + +class TestRadius(object): + @classmethod + def setup_class(cls): + os.environ['UTILITIES_UNIT_TESTING'] = "1" + print("SETUP") + import config.main + imp.reload(config.main) + import show.main + imp.reload(show.main) + + @classmethod + def teardown_class(cls): + os.environ['UTILITIES_UNIT_TESTING'] = "0" + print("TEARDOWN") + + def test_show_radius_default(self): + runner = CliRunner() + result = runner.invoke(show.cli.commands["radius"], []) + print(result.exit_code) + print(result.output) + assert result.exit_code == 0 + assert result.output == show_radius_default_output + + def test_config_radius_server(self, get_cmd_module): + (config, show) = get_cmd_module + runner = CliRunner() + db = Db() + db.cfgdb.delete_table("RADIUS_SERVER") + + result = runner.invoke(config.config.commands["radius"],\ + ["add", "10.10.10.10", "-r", "1", "-t", "3",\ + "-k", "testing123", "-s", "eth0"]) + print(result.exit_code) + print(result.output) + assert result.exit_code == 0 + assert result.output == config_radius_empty_output + + db.cfgdb.mod_entry("RADIUS_SERVER", "10.10.10.10", \ + {'auth_port' : '1812', \ + 'passkey' : 'testing123', \ + 'priority' : '1', \ + 'retransmit': '1', \ + 'src_intf' : 'eth0', \ + 'timeout' : '3', \ + } \ + ) + + result = runner.invoke(show.cli.commands["radius"], [], obj=db) + print(result.exit_code) + print(result.output) + assert result.exit_code == 0 + assert result.output == show_radius_server_output + + result = runner.invoke(config.config.commands["radius"],\ + ["delete", "10.10.10.10"]) + print(result.exit_code) + print(result.output) + assert result.exit_code == 0 + assert result.output == config_radius_empty_output + + db.cfgdb.delete_table("RADIUS_SERVER") + + result = runner.invoke(show.cli.commands["radius"], [], obj=db) + print(result.exit_code) + print(result.output) + assert result.exit_code == 0 + assert result.output == show_radius_default_output + + def test_config_radius_server_invalidkey(self): + runner = CliRunner() + result = runner.invoke(config.config.commands["radius"],\ + ["add", "10.10.10.10", "-r", "1", "-t", "3",\ + "-k", "comma,invalid", "-s", "eth0"]) + print(result.exit_code) + print(result.output) + assert result.exit_code == 0 + assert result.output == config_radius_server_invalidkey_output + + def test_config_radius_nasip_invalid(self): + runner = CliRunner() + result = runner.invoke(config.config.commands["radius"],\ + ["nasip", "invalid"]) + print(result.exit_code) + print(result.output) + assert result.exit_code == 0 + assert result.output == config_radius_invalidipaddress_output + + def test_config_radius_sourceip_invalid(self): + runner = CliRunner() + result = runner.invoke(config.config.commands["radius"],\ + ["sourceip", "invalid"]) + print(result.exit_code) + print(result.output) + assert result.exit_code == 0 + assert result.output == config_radius_invalidipaddress_output + + def test_config_radius_authtype(self, get_cmd_module): + (config, show) = get_cmd_module + runner = CliRunner() + db = Db() + db.cfgdb.delete_table("RADIUS") + + result = runner.invoke(config.config.commands["radius"],\ + ["authtype", "chap"]) + print(result.exit_code) + print(result.output) + assert result.exit_code == 0 + assert result.output == config_radius_empty_output + + db.cfgdb.mod_entry("RADIUS", "global", \ + {'auth_type' : 'chap'} \ + ) + + result = runner.invoke(show.cli.commands["radius"], [], obj=db) + print(result.exit_code) + print(result.output) + assert result.exit_code == 0 + assert result.output == show_radius_global_output + + result = runner.invoke(config.config.commands["radius"],\ + ["default", "authtype"]) + print(result.exit_code) + print(result.output) + assert result.exit_code == 0 + assert result.output == config_radius_empty_output + + db.cfgdb.delete_table("RADIUS") + + result = runner.invoke(show.cli.commands["radius"], [], obj=db) + print(result.exit_code) + print(result.output) + assert result.exit_code == 0 + assert result.output == show_radius_default_output +