Skip to content

Commit

Permalink
fix: command line params parser error (#165)
Browse files Browse the repository at this point in the history
  • Loading branch information
zhijianma authored Jan 3, 2024
1 parent b8dadc7 commit 95ca8b0
Showing 1 changed file with 54 additions and 52 deletions.
106 changes: 54 additions & 52 deletions data_juicer/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,58 +244,7 @@ def init_configs(args=None):

try:
cfg = parser.parse_args(args=args)
option_in_commands = [
''.join(arg.split('--')[1].split('.')[0]) for arg in parser.args
if '--' in arg and 'config' not in arg
]

full_option_in_commands = list(
set([
''.join(arg.split('--')[1].split('=')[0])
for arg in parser.args if '--' in arg and 'config' not in arg
]))

if cfg.process is None:
cfg.process = []

# check and update every op params in `cfg.process`
# e.g.
# `python demo.py --config demo.yaml
# --language_id_score_filter.lang en`
for i, op_in_process in enumerate(cfg.process):
op_in_process_name = list(op_in_process.keys())[0]

temp_cfg = cfg
if op_in_process_name not in option_in_commands:

# update op params to temp cfg if set
if op_in_process[op_in_process_name]:
temp_cfg = parser.merge_config(
dict_to_namespace(op_in_process), cfg)
else:

# args in the command line override the ones in `cfg.process`
for full_option_in_command in full_option_in_commands:

key = full_option_in_command.split('.')[1]
if op_in_process[
op_in_process_name] and key in op_in_process[
op_in_process_name].keys():
op_in_process[op_in_process_name].pop(key)

if op_in_process[op_in_process_name]:
temp_cfg = parser.merge_config(
dict_to_namespace(op_in_process), temp_cfg)

# update op params of cfg.process
internal_op_para = temp_cfg.get(op_in_process_name)

cfg.process[i] = {
op_in_process_name:
None if internal_op_para is None else
namespace_to_dict(internal_op_para)
}

cfg = update_op_process(cfg, parser)
cfg = init_setup_from_cfg(cfg)

# copy the config file into the work directory
Expand Down Expand Up @@ -498,6 +447,59 @@ def sort_op_by_types_and_names(op_name_classes):
return ops_sorted_by_types


def update_op_process(cfg, parser):
op_keys = list(OPERATORS.modules.keys())
args = [
arg.split('--')[1] for arg in parser.args
if arg.startswith('--') and arg.split('--')[1].split('.')[0] in op_keys
]
option_in_commands = list(set([''.join(arg.split('.')[0])
for arg in args]))
full_option_in_commands = list(
set([''.join(arg.split('=')[0]) for arg in args]))

if cfg.process is None:
cfg.process = []

# check and update every op params in `cfg.process`
# e.g.
# `python demo.py --config demo.yaml
# --language_id_score_filter.lang en`
for i, op_in_process in enumerate(cfg.process):
op_in_process_name = list(op_in_process.keys())[0]

temp_cfg = cfg
if op_in_process_name not in option_in_commands:

# update op params to temp cfg if set
if op_in_process[op_in_process_name]:
temp_cfg = parser.merge_config(
dict_to_namespace(op_in_process), cfg)
else:

# args in the command line override the ones in `cfg.process`
for full_option_in_command in full_option_in_commands:

key = full_option_in_command.split('.')[1]
if op_in_process[op_in_process_name] and key in op_in_process[
op_in_process_name].keys():
op_in_process[op_in_process_name].pop(key)

if op_in_process[op_in_process_name]:
temp_cfg = parser.merge_config(
dict_to_namespace(op_in_process), temp_cfg)

# update op params of cfg.process
internal_op_para = temp_cfg.get(op_in_process_name)

cfg.process[i] = {
op_in_process_name:
None if internal_op_para is None else
namespace_to_dict(internal_op_para)
}
return cfg


def config_backup(cfg):
cfg_path = cfg.config[0].absolute
work_dir = cfg.work_dir
Expand Down

0 comments on commit 95ca8b0

Please sign in to comment.