diff --git a/config/main.py b/config/main.py index eebf6edd37..b2359539f7 100644 --- a/config/main.py +++ b/config/main.py @@ -4436,16 +4436,20 @@ def interface(ctx): def enable(ctx, ifname): config_db = ctx.obj['db'] if not interface_name_is_valid(config_db, ifname) and ifname != 'all': - click.echo("Invalid interface name") - return + ctx.fail("Invalid interface name") intf_dict = config_db.get_table('SFLOW_SESSION') - if intf_dict and ifname in intf_dict: - intf_dict[ifname]['admin_state'] = 'up' - config_db.mod_entry('SFLOW_SESSION', ifname, intf_dict[ifname]) + if ifname == 'all': + port_dict = config_db.get_table('PORT') + for port in port_dict.keys(): + config_db.mod_entry('SFLOW_SESSION', port, {'admin_state': 'up'}) else: - config_db.mod_entry('SFLOW_SESSION', ifname, {'admin_state': 'up'}) + if intf_dict and ifname in intf_dict: + intf_dict[ifname]['admin_state'] = 'up' + config_db.mod_entry('SFLOW_SESSION', ifname, intf_dict[ifname]) + else: + config_db.mod_entry('SFLOW_SESSION', ifname, {'admin_state': 'up'}) # # 'sflow' command ('config sflow interface disable ...') @@ -4456,17 +4460,21 @@ def enable(ctx, ifname): def disable(ctx, ifname): config_db = ctx.obj['db'] if not interface_name_is_valid(config_db, ifname) and ifname != 'all': - click.echo("Invalid interface name") - return + ctx.fail("Invalid interface name") intf_dict = config_db.get_table('SFLOW_SESSION') - if intf_dict and ifname in intf_dict: - intf_dict[ifname]['admin_state'] = 'down' - config_db.mod_entry('SFLOW_SESSION', ifname, intf_dict[ifname]) + if ifname == 'all': + port_dict = config_db.get_table('PORT') + for port in port_dict.keys(): + config_db.mod_entry('SFLOW_SESSION', port, {'admin_state': 'down'}) else: - config_db.mod_entry('SFLOW_SESSION', ifname, - {'admin_state': 'down'}) + if intf_dict and ifname in intf_dict: + intf_dict[ifname]['admin_state'] = 'down' + config_db.mod_entry('SFLOW_SESSION', ifname, intf_dict[ifname]) + else: + config_db.mod_entry('SFLOW_SESSION', ifname, + {'admin_state': 'down'}) # # 'sflow' command ('config sflow interface sample-rate ...') diff --git a/tests/sflow_test.py b/tests/sflow_test.py index 0e15f1e027..a39f2dc20d 100644 --- a/tests/sflow_test.py +++ b/tests/sflow_test.py @@ -261,6 +261,21 @@ def test_config_sflow_intf_sample_rate(self): return + def verify_all_ports(self, config_db, status): + port_table = config_db.get_table('PORT') + sflow_table = config_db.get_table('SFLOW_SESSION') + + for port in port_table.keys(): + if port not in sflow_table.keys(): + return False + + admin_state = 'up' if status == 'enable' else 'down' + + if sflow_table[port].get('admin_state') != admin_state: + return False + + return True + def test_config_disable_all_intf(self): db = Db() runner = CliRunner() @@ -273,8 +288,7 @@ def test_config_disable_all_intf(self): assert result.exit_code == 0 # verify in configDb - sflowSession = db.cfgdb.get_table('SFLOW_SESSION') - assert sflowSession["all"]["admin_state"] == "down" + assert self.verify_all_ports(db.cfgdb, 'disable') == True def test_config_enable_all_intf(self): db = Db() @@ -287,8 +301,7 @@ def test_config_enable_all_intf(self): assert result.exit_code == 0 # verify in configDb - sflowSession = db.cfgdb.get_table('SFLOW_SESSION') - assert sflowSession["all"]["admin_state"] == "up" + assert self.verify_all_ports(db.cfgdb, 'enable') == True @classmethod def teardown_class(cls):