diff --git a/samples/configs/parameterized.json b/samples/configs/parameterized.json index c3cfa1c2..8e7200ed 100644 --- a/samples/configs/parameterized.json +++ b/samples/configs/parameterized.json @@ -11,7 +11,10 @@ { "name": "Simple Int", "param": "--simple_int", - "type": "int" + "type": "int", + "values_ui_mapping": { + "One": "1" + } }, { "name": "Simple Boolean", @@ -30,15 +33,16 @@ "param": "--simple_list=", "same_arg_param": true, "type": "list", - "default": { - "script": "echo ${Simple Text}" - }, "description": "Parameter Five", "values": [ "val1", "val3", "some long value" - ] + ], + "values_ui_mapping": { + "val1": "Value 1", + "val3": "Value 3" + } }, { "name": "File upload", @@ -183,6 +187,9 @@ "type": "int", "description": "Parameter Nine", "secure": true, + "values_ui_mapping": { + "qwerty": "2232" + }, "ui": { "separator_before": { "type": "line" diff --git a/src/config/script/list_values.py b/src/config/script/list_values.py index e97be453..e4c4d0c6 100644 --- a/src/config/script/list_values.py +++ b/src/config/script/list_values.py @@ -18,9 +18,6 @@ def get_required_parameters(self): def get_values(self, parameter_values): pass - def map_value(self, user_value): - return user_value - class EmptyValuesProvider(ValuesProvider): @@ -88,8 +85,8 @@ def get_required_parameters(self): def get_values(self, parameter_values): for param_name in self._required_parameters: - value = parameter_values.get(param_name) - if is_empty(value): + value_wrapper = parameter_values.get(param_name) + if (value_wrapper is None) or is_empty(value_wrapper.mapped_script_value): return [] parameters = self._parameters_supplier() diff --git a/src/execution/execution_service.py b/src/execution/execution_service.py index 88e53a94..4118a574 100644 --- a/src/execution/execution_service.py +++ b/src/execution/execution_service.py @@ -43,13 +43,10 @@ def get_active_executor(self, execution_id, user): return self._executors.get(execution_id) - def start_script(self, config, values, user: User): + def start_script(self, config, user: User): audit_name = user.get_audit_name() - config.set_all_param_values(values) - normalized_values = dict(config.parameter_values) - - executor = ScriptExecutor(config, normalized_values, self._env_vars) + executor = ScriptExecutor(config, self._env_vars) execution_id = self._id_generator.next_id() audit_command = executor.get_secure_command() diff --git a/src/execution/executor.py b/src/execution/executor.py index 6dfdb334..03f6007f 100644 --- a/src/execution/executor.py +++ b/src/execution/executor.py @@ -7,6 +7,7 @@ from model import model_helper from model.model_helper import read_bool from model.parameter_config import ParameterModel +from model.script_config import ConfigModel from react.observable import ObservableBase from utils import file_utils, process_utils, os_utils, string_utils from utils.env_utils import EnvVariables @@ -44,40 +45,11 @@ def _normalize_working_dir(working_directory): return file_utils.normalize_path(working_directory) -def _wrap_values(user_values, parameters): - result = {} - for parameter in parameters: - name = parameter.name - - if parameter.constant: - value = parameter.default - result[name] = _Value(None, value, value, parameter.value_to_str(value)) - continue - - if name in user_values: - user_value = user_values[name] - - if parameter.no_value: - bool_value = model_helper.read_bool(user_value) - result[name] = _Value(user_value, bool_value, bool_value) - continue - - elif user_value: - mapped_value = parameter.map_to_script(user_value) - script_arg = parameter.to_script_args(mapped_value) - secure_value = parameter.get_secured_value(script_arg) - result[name] = _Value(user_value, mapped_value, script_arg, secure_value) - else: - result[name] = _Value(None, None, None) - - return result - - class ScriptExecutor: - def __init__(self, config, parameter_values, env_vars: EnvVariables): + def __init__(self, config: ConfigModel, env_vars: EnvVariables): self.config = config self._env_vars = env_vars - self._parameter_values = _wrap_values(parameter_values, config.parameters) + self._parameter_values = dict(config.parameter_values) self._working_directory = _normalize_working_dir(config.working_directory) self.script_base_command = process_utils.split_command( @@ -356,25 +328,6 @@ def send_stdin_parameters( lambda closed_value=value: process_wrapper.write_to_input(closed_value))) -class _Value: - def __init__(self, user_value, mapped_script_value, script_arg, secure_value=None): - self.user_value = user_value - self.mapped_script_value = mapped_script_value - self.script_arg = script_arg - self.secure_value = secure_value - - def get_secure_value(self): - if self.secure_value is not None: - return self.secure_value - return self.script_arg - - def __str__(self) -> str: - if self.secure_value is not None: - return str(self.secure_value) - - return str(self.script_arg) - - class _ExpectedTextListener: def __init__(self, expected_text, callback): self.expected_text = expected_text diff --git a/src/execution/logging.py b/src/execution/logging.py index fe25b643..8331a9c9 100644 --- a/src/execution/logging.py +++ b/src/execution/logging.py @@ -130,7 +130,7 @@ def start_logging(self, execution_id, output_stream, all_audit_names, script_config, - parameter_values, + parameter_value_wrappers, start_time_millis=None): script_name = str(script_config.name) @@ -145,7 +145,7 @@ def start_logging(self, execution_id, start_time_millis, script_config.logging_config, script_config.parameters, - parameter_values) + parameter_value_wrappers) log_file_path = os.path.join(self._output_folder, log_filename) log_file_path = file_utils.create_unique_filename(log_file_path) @@ -367,7 +367,7 @@ def create_filename(self, start_time, custom_logging_config: Optional[LoggingConfig], parameter_configs, - parameter_values): + parameter_value_wrappers): audit_name = get_audit_name(all_audit_names) audit_name = file_utils.to_filename(audit_name) @@ -388,7 +388,7 @@ def create_filename(self, } filename = self._resolve_filename_template(custom_logging_config).safe_substitute(mapping) - filename = model_helper.fill_parameter_values(parameter_configs, filename, parameter_values) + filename = model_helper.fill_parameter_values(parameter_configs, filename, parameter_value_wrappers) if not filename.lower().endswith('.log'): filename += '.log' @@ -423,7 +423,7 @@ def started(execution_id, user): all_audit_names = user.audit_names output_stream = execution_service.get_anonymized_output_stream(execution_id) audit_command = execution_service.get_audit_command(execution_id) - parameter_values = execution_service.get_user_parameter_values(execution_id) + parameter_value_wrappers = script_config.parameter_values logging_service.start_logging( execution_id, @@ -433,7 +433,7 @@ def started(execution_id, user): output_stream, all_audit_names, script_config, - parameter_values) + parameter_value_wrappers) def finished(execution_id, user): exit_code = execution_service.get_exit_code(execution_id) diff --git a/src/features/file_download_feature.py b/src/features/file_download_feature.py index bd254fed..108c76a9 100644 --- a/src/features/file_download_feature.py +++ b/src/features/file_download_feature.py @@ -101,7 +101,7 @@ def _get_paths(self, execution_id, predicate): paths = [_extract_path(f) for f in config.output_files if predicate(f)] paths = [p for p in paths if p] - parameter_values = self.execution_service.get_script_parameter_values(execution_id) + parameter_value_wrappers = config.parameter_values all_audit_names = self.execution_service.get_all_audit_names(execution_id) audit_name = audit_utils.get_audit_name(all_audit_names) @@ -110,7 +110,7 @@ def _get_paths(self, execution_id, predicate): return substitute_variable_values( config.parameters, paths, - parameter_values, + parameter_value_wrappers, audit_name, username) @@ -238,10 +238,10 @@ def _add_inline_image(self, original_path, download_path): LOGGER.error('Failed to notify image listener') -def substitute_variable_values(parameter_configs, output_files, values, audit_name, username): +def substitute_variable_values(parameter_configs, output_files, value_wrappers, audit_name, username): output_file_parsed = [] for _, output_file in enumerate(output_files): - substituted_file = fill_parameter_values(parameter_configs, output_file, values) + substituted_file = fill_parameter_values(parameter_configs, output_file, value_wrappers) substituted_file = replace_auth_vars(substituted_file, username, audit_name) output_file_parsed.append(substituted_file) diff --git a/src/model/external_model.py b/src/model/external_model.py index 22e8d7df..14597070 100644 --- a/src/model/external_model.py +++ b/src/model/external_model.py @@ -39,13 +39,13 @@ def parameter_to_external(parameter): 'description': parameter.description, 'withoutValue': parameter.no_value, 'required': parameter.required, - 'default': parameter.default, + 'default': parameter.create_value_wrapper_for_default().user_value, 'type': parameter.type, 'min': parameter.min, 'max': parameter.max, 'max_length': parameter.max_length, 'regex': parameter.regex, - 'values': parameter.values, + 'values': parameter.get_ui_values(), 'secure': parameter.secure, 'fileRecursive': parameter.file_recursive, 'fileType': parameter.file_type, diff --git a/src/model/model_helper.py b/src/model/model_helper.py index 7839b75e..36b14a06 100644 --- a/src/model/model_helper.py +++ b/src/model/model_helper.py @@ -188,7 +188,7 @@ def is_empty(value): return (not value) and (value != 0) and (value is not False) -def fill_parameter_values(parameter_configs, template, values): +def fill_parameter_values(parameter_configs, template, value_wrappers): result = template for parameter_config in parameter_configs: @@ -196,15 +196,12 @@ def fill_parameter_values(parameter_configs, template, values): continue parameter_name = parameter_config.name - value = values.get(parameter_name) + value_wrapper = value_wrappers.get(parameter_name) + value = value_wrapper.mapped_script_value if value_wrapper else None if value is None: value = '' - if not isinstance(value, str): - mapped_value = parameter_config.map_to_script(value) - value = parameter_config.to_script_args(mapped_value) - result = result.replace('${' + parameter_name + '}', str(value)) return result diff --git a/src/model/parameter_config.py b/src/model/parameter_config.py index 6914e227..67f8a232 100644 --- a/src/model/parameter_config.py +++ b/src/model/parameter_config.py @@ -12,6 +12,8 @@ from model.model_helper import resolve_env_vars, replace_auth_vars, is_empty, SECURE_MASK, \ normalize_extension, read_bool_from_config, InvalidValueException, read_str_from_config, read_int_from_config from model.template_property import TemplateProperty +from model.value_mappers import create_ui_value_mapper +from model.value_wrapper import ScriptValueWrapper from react.properties import ObservableDict, observable_fields from utils import file_utils, string_utils from utils.file_utils import FileMatcher @@ -30,7 +32,7 @@ 'no_value', 'description', 'required', - 'default', + '_default', 'type', 'min', 'max', @@ -65,7 +67,7 @@ def __init__(self, parameter_config, username, audit_name, self.stdin_expected_text = parameter_config.get('stdin_expected_text') self._original_config = parameter_config - self._parameter_values = other_param_values + self._parameter_value_wrappers = other_param_values self._setup() @@ -102,7 +104,7 @@ def _setup(self): self._working_dir, self.type, self._parameters_supplier(), - self._parameter_values, + self._parameter_value_wrappers, self._process_invoker) self.file_dir = _resolve_file_dir(config, 'file_dir') self._list_files_dir = _resolve_list_files_dir(self.file_dir, self._working_dir) @@ -122,6 +124,8 @@ def _setup(self): else: self.ui_separator = None + self._ui_value_mapper = create_ui_value_mapper(config) + self._validate_config() values_provider = self._create_values_provider( @@ -134,7 +138,7 @@ def _setup(self): def _validate_config(self): param_log_name = self.str_name() - if self.constant and not self.default: + if self.constant and not self._default: message = 'Constant should have default value specified' raise Exception('Failed to set parameter "' + param_log_name + '" to constant: ' + message) @@ -172,6 +176,12 @@ def validate_parameter_dependencies(self, all_parameters): + '" of type "' + unsupported_type + '" in values.script! ') + def get_ui_values(self): + if self.values is None: + return None + + return [self._ui_value_mapper.map_to_ui_value(v) for v in self.values] + def _read_type(self, config): type = config.get('type', 'text') @@ -196,7 +206,7 @@ def _reload_values(self): self.values = None return - values = values_provider.get_values(self._parameter_values) + values = values_provider.get_values(self._parameter_value_wrappers) self.values = values def _create_values_provider(self, values_config, type, constant): @@ -249,6 +259,27 @@ def normalize_user_value(self, value): return value + def create_value_wrapper(self, user_value) -> ScriptValueWrapper: + if self.constant: + value = self._default + return ScriptValueWrapper(None, value, value, self.value_to_str(value)) + + if user_value is None: + return ScriptValueWrapper(None, None, None) + + if self.no_value: + bool_value = model_helper.read_bool(user_value) + return ScriptValueWrapper(user_value, bool_value, bool_value) + + mapped_value = self.map_to_script(user_value) + script_arg = self.to_script_args(mapped_value) + secure_value = self.get_secured_value(script_arg) + return ScriptValueWrapper(user_value, mapped_value, script_arg, secure_value) + + def create_value_wrapper_for_default(self): + ui_value = self._ui_value_mapper.map_to_ui_value(self._default) + return self.create_value_wrapper(ui_value) + def value_to_str(self, value): if self.secure: return SECURE_MASK @@ -271,15 +302,10 @@ def get_secured_value(self, value): return self.value_to_str(value) def map_to_script(self, user_value): - def map_single_value(user_value): - if self._values_provider: - return self._values_provider.map_value(user_value) - return user_value - - if self.type == PARAM_TYPE_MULTISELECT: - return [map_single_value(v) for v in user_value] + if user_value is None: + return None - elif self._is_recursive_server_file(): + if self._is_recursive_server_file(): if user_value: return os.path.join(self.file_dir, *user_value) else: @@ -290,7 +316,10 @@ def map_single_value(user_value): else: return None - return map_single_value(user_value) + if isinstance(user_value, list): + return [self._ui_value_mapper.map_to_script_value(single_value) for single_value in user_value] + else: + return self._ui_value_mapper.map_to_script_value(user_value) def to_script_args(self, script_value): if self.type == PARAM_TYPE_MULTISELECT: @@ -301,21 +330,23 @@ def to_script_args(self, script_value): return script_value - def validate_value(self, value, *, ignore_required=False): + def validate_value(self, value_wrapper: ScriptValueWrapper, *, ignore_required=False): if self.constant: return None - if is_empty(value): + user_value = value_wrapper.user_value + + if is_empty(user_value): if self.required and not ignore_required: return 'is not specified' return None - value_string = self.value_to_repr(value) + value_string = self.value_to_repr(user_value) if self.no_value: - if isinstance(value, bool): + if isinstance(user_value, bool): return None - if isinstance(value, str) and value.lower() in ['true', 'false']: + if isinstance(user_value, str) and user_value.lower() in ['true', 'false']: return None return 'should be boolean, but has value ' + value_string @@ -323,25 +354,26 @@ def validate_value(self, value, *, ignore_required=False): if self.regex is not None: regex_pattern = self.regex.get('pattern', None) if not is_empty(regex_pattern): - regex_matched = re.fullmatch(regex_pattern, value) + regex_matched = re.fullmatch(regex_pattern, user_value) if not regex_matched: description = self.regex.get('description') or regex_pattern return 'does not match regex pattern: ' + description - if (not is_empty(self.max_length)) and (len(value) > int(self.max_length)): + if (not is_empty(self.max_length)) and (len(user_value) > int(self.max_length)): return 'is longer than allowed char length (' \ - + str(len(value)) + ' > ' + str(self.max_length) + ')' + + str(len(user_value)) + ' > ' + str(self.max_length) + ')' return None if self.type == 'file_upload': - if not os.path.exists(value): - return 'Cannot find file ' + value + if not os.path.exists(user_value): + return 'Cannot find file ' + user_value return None if self.type == 'int': - if not (isinstance(value, int) or (isinstance(value, str) and string_utils.is_integer(value))): + if not (isinstance(user_value, int) or ( + isinstance(user_value, str) and string_utils.is_integer(user_value))): return 'should be integer, but has value ' + value_string - int_value = int(value) + int_value = int(user_value) if (not is_empty(self.max)) and (int_value > int(self.max)): return 'is greater than allowed value (' \ @@ -354,7 +386,7 @@ def validate_value(self, value, *, ignore_required=False): if self.type in ('ip', 'ip4', 'ip6'): try: - address = ip_address(value.strip()) + address = ip_address(user_value.strip()) if self.type == 'ip4': if not isinstance(address, IPv4Address): return value_string + ' is not an IPv4 address' @@ -364,18 +396,18 @@ def validate_value(self, value, *, ignore_required=False): except ValueError: return 'wrong IP address ' + value_string - allowed_values = self.values + allowed_values = self.get_ui_values() if (self.type == 'list') or (self._is_plain_server_file()): - if value not in allowed_values: + if user_value not in allowed_values: return 'has value ' + value_string \ + ', but should be in ' + repr(allowed_values) return None if self.type == PARAM_TYPE_MULTISELECT: - if not isinstance(value, list): - return 'should be a list, but was: ' + value_string + '(' + str(type(value)) + ')' - for value_element in value: + if not isinstance(user_value, list): + return 'should be a list, but was: ' + value_string + '(' + str(type(user_value)) + ')' + for value_element in user_value: if value_element not in allowed_values: element_str = self.value_to_repr(value_element) return 'has value ' + element_str \ @@ -383,7 +415,7 @@ def validate_value(self, value, *, ignore_required=False): return None if self._is_recursive_server_file(): - return self._validate_recursive_path(value, intermediate=False) + return self._validate_recursive_path(user_value, intermediate=False) return None @@ -471,10 +503,10 @@ def _set_default_value( working_dir, type, parameters, - parameter_values, + parameter_value_wrappers, process_invoker: ProcessInvoker): if is_empty(default_config): - self.default = default_config + self._default = default_config return script = False @@ -484,7 +516,7 @@ def _set_default_value( elif isinstance(default_config, str): string_value = default_config else: - self.default = default_config + self._default = default_config return resolved_string_value = resolve_env_vars(string_value, full_match=True) @@ -492,10 +524,10 @@ def _set_default_value( resolved_string_value = replace_auth_vars(string_value, username, audit_name) if not script: - self.default = resolved_string_value + self._default = resolved_string_value return - template_property = TemplateProperty(resolved_string_value, parameters, parameter_values) + template_property = TemplateProperty(resolved_string_value, parameters, parameter_value_wrappers) shell = read_bool_from_config('shell', default_config, default=is_empty(template_property.required_parameters)) def get_script_output(script): @@ -508,13 +540,13 @@ def get_script_output(script): return stripped_output if not template_property.required_parameters: - self.default = get_script_output(resolved_string_value) + self._default = get_script_output(resolved_string_value) else: def update_default(_, new): if new is None: - self.default = None + self._default = None else: - self.default = get_script_output(new) + self._default = get_script_output(new) template_property.subscribe(update_default) update_default(None, template_property.value) @@ -591,6 +623,7 @@ def get_sorted_config(param_config): 'type', 'no_value', 'default', 'constant', 'description', 'secure', 'values', + 'values_ui_mapping', 'min', 'max', 'max_length', diff --git a/src/model/script_config.py b/src/model/script_config.py index 8f854db5..8475654a 100644 --- a/src/model/script_config.py +++ b/src/model/script_config.py @@ -2,7 +2,7 @@ import os from collections import OrderedDict from dataclasses import dataclass, field -from typing import List +from typing import List, Optional from auth.authorization import ANY_USER from config.exceptions import InvalidConfigException @@ -83,7 +83,7 @@ def __init__(self, self._original_config = config_object self._included_config_paths = TemplateProperty(read_list(config_object, 'include'), parameters=self.parameters, - values=self.parameter_values) + value_wrappers=self.parameter_values) self._included_config_prop.bind(self._included_config_paths, self._read_and_merge_included_paths) self._reload_config() @@ -93,7 +93,7 @@ def __init__(self, self._init_parameters(username, audit_name) for parameter in self.parameters: - self.parameter_values[parameter.name] = parameter.default + self.parameter_values[parameter.name] = parameter.create_value_wrapper_for_default() self._reload_parameters({}) @@ -104,13 +104,15 @@ def set_param_value(self, param_name, value): if parameter is None: LOGGER.warning('Parameter ' + param_name + ' does not exist in ' + self.name) return - validation_error = parameter.validate_value(value, ignore_required=True) + normalized_value = parameter.normalize_user_value(value) + value_wrapper = parameter.create_value_wrapper(normalized_value) + validation_error = parameter.validate_value(value_wrapper, ignore_required=True) if validation_error is not None: - self.parameter_values[param_name] = None + self.parameter_values[param_name] = parameter.create_value_wrapper(None) raise InvalidValueException(param_name, validation_error) - self.parameter_values[param_name] = value + self.parameter_values[param_name] = value_wrapper def set_all_param_values(self, param_values, skip_invalid_parameters=False): original_values = dict(self.parameter_values) @@ -138,20 +140,21 @@ def get_sort_key(parameter): continue if parameter.constant: - value = parameter.default + value_wrapper = parameter.create_value_wrapper_for_default() else: value = parameter.normalize_user_value(param_values.get(parameter.name)) + value_wrapper = parameter.create_value_wrapper(value) - validation_error = parameter.validate_value(value) + validation_error = parameter.validate_value(value_wrapper) if validation_error: if skip_invalid_parameters: logging.warning('Parameter ' + parameter.name + ' has invalid value, skipping') - value = parameter.normalize_user_value(None) + value_wrapper = parameter.create_value_wrapper(parameter.normalize_user_value(None)) else: self.parameter_values.set(original_values) raise InvalidValueException(parameter.name, validation_error) - self.parameter_values[parameter.name] = value + self.parameter_values[parameter.name] = value_wrapper processed[parameter.name] = parameter anything_changed = True @@ -245,14 +248,14 @@ def _reload_parameters(self, old_included_config): self.parameters.append(parameter) if parameter.name not in self.parameter_values: - self.parameter_values[parameter.name] = parameter.default + self.parameter_values[parameter.name] = parameter.create_value_wrapper_for_default() continue else: LOGGER.warning('Parameter ' + parameter_name + ' exists in original and included file. ' + 'This is now allowed! Included parameter is ignored') continue - def find_parameter(self, param_name): + def find_parameter(self, param_name) -> Optional[ParameterModel]: for parameter in self.parameters: if parameter.name == param_name: return parameter diff --git a/src/model/template_property.py b/src/model/template_property.py index 9242f5fe..6b151008 100644 --- a/src/model/template_property.py +++ b/src/model/template_property.py @@ -5,10 +5,10 @@ class TemplateProperty: - def __init__(self, template_config, parameters: ObservableList, values: ObservableDict, empty=None) -> None: + def __init__(self, template_config, parameters: ObservableList, value_wrappers: ObservableDict, empty=None) -> None: self._value_property = Property(None) self._template_config = template_config - self._values = values + self._values = value_wrappers self._empty = empty self._parameters = parameters @@ -41,7 +41,7 @@ def __init__(self, template_config, parameters: ObservableList, values: Observab self._reload() if self.required_parameters: - values.subscribe(self._value_changed) + value_wrappers.subscribe(self._value_changed) parameters.subscribe(self) def _value_changed(self, parameter, old, new): @@ -59,8 +59,8 @@ def on_remove(self, parameter): def _reload(self): values_filled = True for param_name in self.required_parameters: - value = self._values.get(param_name) - if is_empty(value): + value_wrapper = self._values.get(param_name) + if value_wrapper is None or is_empty(value_wrapper.mapped_script_value): values_filled = False break diff --git a/src/model/value_mappers.py b/src/model/value_mappers.py new file mode 100644 index 00000000..c6e06e4a --- /dev/null +++ b/src/model/value_mappers.py @@ -0,0 +1,54 @@ +import abc + +from model.model_helper import read_dict + + +class ValueMapper(metaclass=abc.ABCMeta): + @abc.abstractmethod + def map_to_script_value(self, user_value): + pass + + @abc.abstractmethod + def map_to_ui_value(self, script_value): + pass + + +class DictBasedValueMapper(ValueMapper): + + def __init__(self, mappings) -> None: + self._mappings = mappings + + def map_to_script_value(self, user_value): + if user_value is None: + return None + + if not self._mappings: + return user_value + + str_user_value = str(user_value) + + for script_value, mapped_user_value in self._mappings.items(): + if mapped_user_value == str_user_value: + return script_value + + return user_value + + def map_to_ui_value(self, script_value): + if script_value is None: + return None + + if not self._mappings: + return script_value + + str_value = str(script_value) + + if str_value in self._mappings: + return self._mappings[str_value] + + return script_value + + +def create_ui_value_mapper(config) -> ValueMapper: + mappings_config = read_dict(config, 'values_ui_mapping', {}) + + return DictBasedValueMapper(mappings_config) diff --git a/src/model/value_wrapper.py b/src/model/value_wrapper.py new file mode 100644 index 00000000..ae37cc56 --- /dev/null +++ b/src/model/value_wrapper.py @@ -0,0 +1,23 @@ +class ScriptValueWrapper: + def __init__(self, user_value, mapped_script_value, script_arg, secure_value=None): + self.user_value = user_value + self.mapped_script_value = mapped_script_value + self.script_arg = script_arg + self.secure_value = secure_value + + def get_secure_value(self): + if self.secure_value is not None: + return self.secure_value + return self.script_arg + + def __str__(self) -> str: + if self.secure_value is not None: + return str(self.secure_value) + + return str(self.script_arg) + + def __eq__(self, o: object) -> bool: + return isinstance(o, ScriptValueWrapper) and (self.mapped_script_value == o.mapped_script_value) + + def __hash__(self) -> int: + return hash(self.mapped_script_value) diff --git a/src/scheduling/schedule_service.py b/src/scheduling/schedule_service.py index 2a80b4fb..b7a26a52 100644 --- a/src/scheduling/schedule_service.py +++ b/src/scheduling/schedule_service.py @@ -85,7 +85,11 @@ def create_job(self, script_name, parameter_values, incoming_schedule_config, us id = self._id_generator.next_id() - normalized_values = dict(config_model.parameter_values) + normalized_values = {} + for parameter_name, value_wrapper in config_model.parameter_values.items(): + if value_wrapper.user_value is not None: + normalized_values[parameter_name] = value_wrapper.user_value + job = SchedulingJob(id, user, schedule_config, script_name, normalized_values) job_path = self.save_job(job) @@ -131,7 +135,7 @@ def _execute_job(self, job: SchedulingJob, job_path): config = self._config_service.load_config_model(script_name, user, parameter_values) self.validate_script_config(config) - execution_id = self._execution_service.start_script(config, parameter_values, user) + execution_id = self._execution_service.start_script(config, user) LOGGER.info('Started script #' + str(execution_id) + ' for ' + job.get_log_name()) if config.scheduling_auto_cleanup: diff --git a/src/tests/execution_logging_test.py b/src/tests/execution_logging_test.py index b564970e..738b4da9 100644 --- a/src/tests/execution_logging_test.py +++ b/src/tests/execution_logging_test.py @@ -525,7 +525,6 @@ class ExecutionLoggingInitiatorTest(unittest.TestCase): def test_start_logging_on_execution_start(self): execution_id = self.executor_service.start_script( create_config_model('my_script'), - {}, User('userX', create_audit_names(ip='localhost'))) executor = self.executor_service.get_active_executor(execution_id, USER_X) @@ -544,10 +543,10 @@ def test_logging_values(self): script_command='echo', parameters=[param1, param2, param3, param4], logging_config=LoggingConfig('test-${SCRIPT}-${p1}')) + config_model.set_all_param_values({'p1': 'abc', 'p3': True, 'p4': 987}) execution_id = self.executor_service.start_script( config_model, - {'p1': 'abc', 'p3': True, 'p4': 987}, User('userX', create_audit_names(ip='localhost', auth_username='sandy'))) executor = self.executor_service.get_active_executor(execution_id, USER_X) @@ -577,7 +576,6 @@ def test_exit_code(self): execution_id = self.executor_service.start_script( config_model, - {}, User('userX', create_audit_names(ip='localhost'))) executor = self.executor_service.get_active_executor(execution_id, USER_X) diff --git a/src/tests/execution_service_test.py b/src/tests/execution_service_test.py index 887d0da0..cbda159e 100644 --- a/src/tests/execution_service_test.py +++ b/src/tests/execution_service_test.py @@ -170,7 +170,8 @@ def test_get_user_parameter_values(self): 'test_get_user_parameter_values', username=DEFAULT_USER_ID, parameters=parameters.values()) - execution_id = self._start_with_config(execution_service, config_model, parameter_values) + config_model.set_all_param_values(parameter_values) + execution_id = self._start_with_config(execution_service, config_model) self.assertEqual(parameter_values, execution_service.get_user_parameter_values(execution_id)) @@ -187,7 +188,8 @@ def test_get_script_parameter_values(self): 'test_get_user_parameter_values', username=DEFAULT_USER_ID, parameters=parameters.values()) - execution_id = self._start_with_config(execution_service, config_model, parameter_values) + config_model.set_all_param_values(parameter_values) + execution_id = self._start_with_config(execution_service, config_model) self.assertEqual({'x': 1, 'y': '2', 'z': True, 'const': 'abc'}, execution_service.get_script_parameter_values(execution_id)) @@ -237,15 +239,9 @@ def test_finish_listener_by_id(self): def _start(self, execution_service, user_id=DEFAULT_USER_ID): return _start(execution_service, user_id) - def _start_with_config(self, execution_service, config, parameter_values=None, user_id=DEFAULT_USER_ID): - if parameter_values is None: - parameter_values = {} - + def _start_with_config(self, execution_service, config, user_id=DEFAULT_USER_ID): user = User(user_id, DEFAULT_AUDIT_NAMES) - execution_id = execution_service.start_script( - config, - parameter_values, - user) + execution_id = execution_service.start_script(config, user) return execution_id def create_execution_service(self): @@ -416,10 +412,11 @@ def _start_with_config(execution_service, config, parameter_values=None, user_id if parameter_values is None: parameter_values = {} + config.set_all_param_values(parameter_values) + user = User(user_id, DEFAULT_AUDIT_NAMES) execution_id = execution_service.start_script( config, - parameter_values, user) execution_owners[execution_id] = user return execution_id diff --git a/src/tests/executor_test.py b/src/tests/executor_test.py index 9a010707..1245c910 100644 --- a/src/tests/executor_test.py +++ b/src/tests/executor_test.py @@ -22,7 +22,8 @@ def parse_env_variables(output): class TestScriptExecutor(unittest.TestCase): def test_start_without_values(self): - self.create_executor(create_config_model('config_x'), {}) + config = create_config_model('config_x', parameter_values={}) + self.create_executor(config) self.executor.start(123) process_wrapper = self.executor.process_wrapper @@ -35,8 +36,12 @@ def test_start_without_values(self): self.assertEqual(expected_values, process_wrapper.all_env_variables) def test_start_with_one_value(self): - config = create_config_model('config_x', parameters=[create_script_param_config('id')]) - self.create_executor(config, {'id': 918273}) + config = create_config_model( + 'config_x', + parameters=[create_script_param_config('id')], + parameter_values={'id': 918273}) + + self.create_executor(config) self.executor.start(123) process_wrapper = self.executor.process_wrapper @@ -46,12 +51,15 @@ def test_start_with_one_value(self): {'PARAM_ID': '918273', 'EXECUTION_ID': '123'}) def test_start_with_multiple_values(self): - config = create_config_model('config_x', parameters=[ - create_script_param_config('id'), - create_script_param_config('name', env_var='My_Name', param='-n'), - create_script_param_config('verbose', param='--verbose', no_value=True), - ]) - self.create_executor(config, {'id': 918273, 'name': 'UserX', 'verbose': True}) + config = create_config_model( + 'config_x', + parameters=[ + create_script_param_config('id'), + create_script_param_config('name', env_var='My_Name', param='-n'), + create_script_param_config('verbose', param='--verbose', no_value=True), + ], + parameter_values={'id': 918273, 'name': 'UserX', 'verbose': True}) + self.create_executor(config) self.executor.start(123) process_wrapper = self.executor.process_wrapper @@ -73,10 +81,11 @@ def test_env_variables_when_pty(self): create_script_param_config('id'), create_script_param_config('name', env_var='My_Name', param='-n'), create_script_param_config('verbose', param='--verbose', no_value=True), - ]) + ], + parameter_values={'id': '918273', 'name': 'UserX', 'verbose': True}) executor._process_creator = create_process_wrapper - self.create_executor(config, {'id': '918273', 'name': 'UserX', 'verbose': True}) + self.create_executor(config) self.executor.start(123) data = read_until_closed(self.executor.get_raw_output_stream(), 100) @@ -99,10 +108,11 @@ def test_env_variables_when_popen(self): create_script_param_config('id'), create_script_param_config('name', env_var='My_Name', param='-n'), create_script_param_config('verbose', param='--verbose', no_value=True), - ]) + ], + parameter_values={'id': '918273', 'name': 'UserX', 'verbose': True}) executor._process_creator = create_process_wrapper - self.create_executor(config, {'id': '918273', 'name': 'UserX', 'verbose': True}) + self.create_executor(config) self.executor.start(123) data = read_until_closed(self.executor.get_raw_output_stream(), 100) @@ -116,11 +126,14 @@ def test_env_variables_when_popen(self): self.assertEqual('123', variables.get('EXECUTION_ID')) def test_start_with_multiple_values_when_one_not_exist(self): - config = create_config_model('config_x', parameters=[ - create_script_param_config('id'), - create_script_param_config('verbose', param='--verbose', no_value=True), - ]) - self.create_executor(config, {'id': 918273, 'name': 'UserX', 'verbose': True}) + config = create_config_model( + 'config_x', + parameters=[ + create_script_param_config('id'), + create_script_param_config('verbose', param='--verbose', no_value=True), + ], + parameter_values={'id': 918273, 'name': 'UserX', 'verbose': True}) + self.create_executor(config) self.executor.start(123) process_wrapper = self.executor.process_wrapper @@ -137,10 +150,11 @@ def test_pass_as(self): create_script_param_config('p1', pass_as='argument'), create_script_param_config('p2', pass_as='env_variable'), create_script_param_config('p3', pass_as='stdin'), - ]) + ], + parameter_values={'p1': 'abc', 'p2': 'def', 'p3': 'xyz'}) executor._process_creator = create_process_wrapper - self.create_executor(config, {'p1': 'abc', 'p2': 'def', 'p3': 'xyz'}) + self.create_executor(config) self.executor.start(123) data = read_until_closed(self.executor.get_raw_output_stream(), 200) @@ -180,10 +194,11 @@ def test_pass_as_stdin(self): create_script_param_config('p5', pass_as='stdin', no_value=True), create_script_param_config('p6', pass_as='stdin', no_value=True), create_script_param_config('p7', pass_as='stdin', stdin_expected_text='b'), - ]) + ], + parameter_values={'p1': 'xxx', 'p2': 'yyy', 'p3': [1, 3, 7], 'p5': True, 'p6': False, 'p7': 'zzz'}) executor._process_creator = create_process_wrapper - self.create_executor(config, {'p1': 'xxx', 'p2': 'yyy', 'p3': [1, 3, 7], 'p5': True, 'p6': False, 'p7': 'zzz'}) + self.create_executor(config) self.executor.start(123) data = read_until_closed(self.executor.get_raw_output_stream(), 1000) @@ -200,8 +215,29 @@ def test_pass_as_stdin(self): inputs: 'xxx' '1,3,7' 'true' 'false' 'zzz' 'yyy' '''), output) - def create_executor(self, config, parameter_values): - self.executor = ScriptExecutor(config, parameter_values, test_utils.env_variables) + def test_values_ui_mapping(self): + config = create_config_model( + 'config_x', + script_command='echo ', + parameters=[ + create_script_param_config('p1', type='int', values_ui_mapping={'One': '1'}), + create_script_param_config('p2', type='list', + allowed_values=['abc'], + values_ui_mapping={'abc': 'qwerty'}), + ], + parameter_values={'p1': '1', 'p2': 'qwerty'}) + + executor._process_creator = create_process_wrapper + self.create_executor(config) + self.executor.start(123) + + data = read_until_closed(self.executor.get_raw_output_stream(), 1000) + output = ''.join(data) + + self.assertEqual('One abc\n', output) + + def create_executor(self, config): + self.executor = ScriptExecutor(config, test_utils.env_variables) def setUp(self): executor._process_creator = _MockProcessWrapper @@ -307,7 +343,11 @@ def test_parameter_multiselect_when_empty_list(self): self.assertEqual([], args_list) def test_parameter_multiselect_when_single_list(self): - parameter = create_script_param_config('p1', param='-p1', type=PARAM_TYPE_MULTISELECT) + parameter = create_script_param_config( + 'p1', + param='-p1', + type=PARAM_TYPE_MULTISELECT, + allowed_values=['val1', 'val2', 'hello world']) config = create_config_model('config_x', parameters=[parameter]) args_list = self.build_command_args({'p1': ['val1']}, config) @@ -315,7 +355,11 @@ def test_parameter_multiselect_when_single_list(self): self.assertEqual(['-p1', 'val1'], args_list) def test_parameter_multiselect_when_single_list_as_multiarg(self): - parameter = create_script_param_config('p1', param='-p1', type=PARAM_TYPE_MULTISELECT) + parameter = create_script_param_config( + 'p1', + param='-p1', + type=PARAM_TYPE_MULTISELECT, + allowed_values=['val1', 'val2', 'hello world']) config = create_config_model('config_x', parameters=[parameter]) args_list = self.build_command_args({'p1': ['val1']}, config) @@ -323,7 +367,10 @@ def test_parameter_multiselect_when_single_list_as_multiarg(self): self.assertEqual(['-p1', 'val1'], args_list) def test_parameter_multiselect_when_multiple_list(self): - parameter = create_script_param_config('p1', type=PARAM_TYPE_MULTISELECT) + parameter = create_script_param_config( + 'p1', + type=PARAM_TYPE_MULTISELECT, + allowed_values=['val1', 'val2', 'hello world']) config = create_config_model('config_x', parameters=[parameter]) args_list = self.build_command_args({'p1': ['val1', 'val2', 'hello world']}, config) @@ -331,7 +378,11 @@ def test_parameter_multiselect_when_multiple_list(self): self.assertEqual(['val1,val2,hello world'], args_list) def test_parameter_multiselect_when_multiple_list_and_custom_separator(self): - parameter = create_script_param_config('p1', type=PARAM_TYPE_MULTISELECT, multiselect_separator='; ') + parameter = create_script_param_config( + 'p1', + type=PARAM_TYPE_MULTISELECT, + multiselect_separator='; ', + allowed_values=['val1', 'val2', 'hello world']) config = create_config_model('config_x', parameters=[parameter]) args_list = self.build_command_args({'p1': ['val1', 'val2', 'hello world']}, config) @@ -339,9 +390,11 @@ def test_parameter_multiselect_when_multiple_list_and_custom_separator(self): self.assertEqual(['val1; val2; hello world'], args_list) def test_parameter_multiselect_when_multiple_list_as_multiarg(self): - parameter = create_script_param_config('p1', - type=PARAM_TYPE_MULTISELECT, - multiselect_argument_type='argument_per_value') + parameter = create_script_param_config( + 'p1', + type=PARAM_TYPE_MULTISELECT, + multiselect_argument_type='argument_per_value', + allowed_values=['val1', 'val2', 'hello world']) config = create_config_model('config_x', parameters=[parameter]) args_list = self.build_command_args({'p1': ['val1', 'val2', 'hello world']}, config) @@ -365,7 +418,12 @@ def test_parameter_int_without_space(self): self.assertEqual(['-p1=10'], args_string) def test_parameter_multiselect_when_multiple_list_without_space(self): - parameter = create_script_param_config('p1', param='--p1=', type=PARAM_TYPE_MULTISELECT, same_arg_param=True) + parameter = create_script_param_config( + 'p1', + param='--p1=', + type=PARAM_TYPE_MULTISELECT, + same_arg_param=True, + allowed_values=['val1', 'val2', 'hello world']) config = create_config_model('config_x', parameters=[parameter]) args_list = self.build_command_args({'p1': ['val1', 'val2', 'hello world']}, config) @@ -373,10 +431,13 @@ def test_parameter_multiselect_when_multiple_list_without_space(self): self.assertEqual(['--p1=val1,val2,hello world'], args_list) def test_parameter_multiselect_when_multiple_list_and_argument_per_value_without_space(self): - parameter = create_script_param_config('p1', param='--p1=', - type=PARAM_TYPE_MULTISELECT, - multiselect_argument_type='argument_per_value', - same_arg_param=True) + parameter = create_script_param_config( + 'p1', + param='--p1=', + type=PARAM_TYPE_MULTISELECT, + multiselect_argument_type='argument_per_value', + same_arg_param=True, + allowed_values=['val1', 'val2', 'hello world']) config = create_config_model('config_x', parameters=[parameter]) args_list = self.build_command_args({'p1': ['val1', 'val2', 'hello world']}, config) @@ -384,9 +445,12 @@ def test_parameter_multiselect_when_multiple_list_and_argument_per_value_without self.assertEqual(['--p1=val1', 'val2', 'hello world'], args_list) def test_parameter_multiselect_when_multiple_list_as_multiarg_repeat_param(self): - parameter = create_script_param_config('p1', param='-p1', - type=PARAM_TYPE_MULTISELECT, - multiselect_argument_type='repeat_param_value') + parameter = create_script_param_config( + 'p1', + param='-p1', + type=PARAM_TYPE_MULTISELECT, + multiselect_argument_type='repeat_param_value', + allowed_values=['val1', 'val2', 'hello world']) config = create_config_model('config_x', parameters=[parameter]) args_list = self.build_command_args({'p1': ['val1', 'val2', 'hello world']}, config) @@ -394,10 +458,13 @@ def test_parameter_multiselect_when_multiple_list_as_multiarg_repeat_param(self) self.assertEqual(['-p1', 'val1', '-p1', 'val2', '-p1', 'hello world'], args_list) def test_parameter_multiselect_when_multiple_list_as_multiarg_repeat_param_without_space(self): - parameter = create_script_param_config('p1', param='--p1=', - type=PARAM_TYPE_MULTISELECT, - multiselect_argument_type='repeat_param_value', - same_arg_param=True) + parameter = create_script_param_config( + 'p1', + param='--p1=', + type=PARAM_TYPE_MULTISELECT, + multiselect_argument_type='repeat_param_value', + same_arg_param=True, + allowed_values=['val1', 'val2', 'hello world']) config = create_config_model('config_x', parameters=[parameter]) args_list = self.build_command_args({'p1': ['val1', 'val2', 'hello world']}, config) @@ -425,7 +492,9 @@ def build_command_args(self, param_values, config): if config.script_command is None: config.script_command = 'ping' - script_executor = ScriptExecutor(config, param_values, test_utils.env_variables) + config.set_all_param_values(param_values) + + script_executor = ScriptExecutor(config, test_utils.env_variables) args_string = executor.build_command_args(script_executor.get_script_parameter_values(), config) return args_string @@ -490,6 +559,25 @@ def test_log_with_secure(self): output = self.get_finish_output() self.assertEqual(output, '******| some text\nand ****** new line with some long long text |******') + def test_log_with_secure_when_ui_mapping(self): + parameter = create_script_param_config( + 'p1', + secure=True, + type='list', + allowed_values=['abc', 'def', 'xyz'], + values_ui_mapping={'abc': 'qwerty'} + ) + config = self._create_config(parameters=[parameter]) + + self.create_and_start_executor(config, {'p1': 'qwerty'}) + + self.write_process_output(' qwerty def abc xyz ') + + self.finish_process() + + output = self.get_finish_output() + self.assertEqual(output, ' qwerty def ****** xyz ') + def test_log_with_secure_ignore_whitespaces(self): parameter = create_script_param_config('p1', secure=True) config = self._create_config(parameters=[parameter]) @@ -519,7 +607,11 @@ def test_log_with_secure_ignore_inside_word(self): self.assertEqual(output, '******\n-******-\nbobcat\ncatty\n1cat\nmy ****** is cute') def test_log_with_secure_when_multiselect(self): - parameter = create_script_param_config('p1', secure=True, type=PARAM_TYPE_MULTISELECT) + parameter = create_script_param_config( + 'p1', + secure=True, + type=PARAM_TYPE_MULTISELECT, + allowed_values=['123', '456', 'password']) config = self._create_config(parameters=[parameter]) self.create_and_start_executor(config, {'p1': ['123', 'password']}) @@ -588,7 +680,9 @@ def create_and_start_executor(self, config, parameter_values=None): if parameter_values is None: parameter_values = {} - self.executor = ScriptExecutor(config, parameter_values, test_utils.env_variables) + config.set_all_param_values(parameter_values) + + self.executor = ScriptExecutor(config, test_utils.env_variables) self.executor.start(123) return self.executor @@ -615,7 +709,12 @@ def test_parameter_secure_value_and_same_unsecure(self): self.assertEqual('ls -p1 ****** -p2 value', secure_command) def test_parameter_secure_multiselect(self): - parameter = create_script_param_config('p1', param='-p1', secure=True, type=PARAM_TYPE_MULTISELECT) + parameter = create_script_param_config( + 'p1', + param='-p1', + secure=True, + type=PARAM_TYPE_MULTISELECT, + allowed_values=['one', 'two', 'three']) secure_command = self.get_secure_command([parameter], {'p1': ['one', 'two', 'three']}) @@ -623,7 +722,12 @@ def test_parameter_secure_multiselect(self): def test_parameter_secure_multiselect_as_multiarg(self): parameter = create_script_param_config( - 'p1', param='-p1', secure=True, type=PARAM_TYPE_MULTISELECT, multiselect_argument_type='argument_per_value') + 'p1', + param='-p1', + secure=True, + type=PARAM_TYPE_MULTISELECT, + multiselect_argument_type='argument_per_value', + allowed_values=['one', 'two', 'three']) secure_command = self.get_secure_command([parameter], {'p1': ['one', 'two', 'three']}) @@ -639,7 +743,10 @@ def test_parameter_no_value(self): def test_parameter_multiselect_and_argument_per_value(self): parameter = create_script_param_config( - 'p1', param='-p1', type=PARAM_TYPE_MULTISELECT, multiselect_argument_type='argument_per_value') + 'p1', param='-p1', + type=PARAM_TYPE_MULTISELECT, + multiselect_argument_type='argument_per_value', + allowed_values=['abc', 'xyz', 'def']) secure_command = self.get_secure_command([parameter], {'p1': ['abc', 'def']}) @@ -647,7 +754,10 @@ def test_parameter_multiselect_and_argument_per_value(self): def test_when_parameter_multiselect_and_comma_separated(self): parameter = create_script_param_config( - 'p1', param='-p1', type=PARAM_TYPE_MULTISELECT) + 'p1', + param='-p1', + type=PARAM_TYPE_MULTISELECT, + allowed_values=['abc', 'def']) secure_command = self.get_secure_command([parameter], {'p1': ['abc', 'def']}) @@ -677,8 +787,8 @@ def test_secure_parameter_int(self): self.assertEqual('ls -p1 ******', secure_command) def get_secure_command(self, parameters, values): - config = create_config_model('config_x', parameters=parameters) - executor = ScriptExecutor(config, values, test_utils.env_variables) + config = create_config_model('config_x', parameters=parameters, parameter_values=values) + executor = ScriptExecutor(config, test_utils.env_variables) return executor.get_secure_command() diff --git a/src/tests/file_download_feature_test.py b/src/tests/file_download_feature_test.py index 63abbe47..d95f1679 100644 --- a/src/tests/file_download_feature_test.py +++ b/src/tests/file_download_feature_test.py @@ -10,7 +10,7 @@ from files.user_file_storage import UserFileStorage from tests import test_utils from tests.test_utils import create_parameter_model, _MockProcessWrapper, _IdGeneratorMock, create_config_model, \ - create_audit_names, create_script_param_config, AnyUserAuthorizer + create_audit_names, create_script_param_config, AnyUserAuthorizer, wrap_values from utils import file_utils, os_utils from utils.file_utils import normalize_path @@ -210,10 +210,12 @@ def test_no_parameters(self): def test_single_replace(self): parameter = create_parameter_model('param1') + parameters = [parameter] + value_wrappers = wrap_values(parameters, {'param1': 'val1'}) files = file_download_feature.substitute_variable_values( - [parameter], + parameters, ['/home/user/${param1}.txt'], - {'param1': 'val1'}, + value_wrappers, '127.0.0.1', 'user_X') @@ -224,10 +226,11 @@ def test_two_replaces(self): parameters.append(create_parameter_model('param1', all_parameters=parameters)) parameters.append(create_parameter_model('param2', all_parameters=parameters)) + value_wrappers = wrap_values(parameters, {'param1': 'val1', 'param2': 'val2'}) files = file_download_feature.substitute_variable_values( parameters, ['/home/${param2}/${param1}.txt'], - {'param1': 'val1', 'param2': 'val2'}, + value_wrappers, '127.0.0.1', 'user_X') @@ -238,10 +241,11 @@ def test_two_replaces_in_two_files(self): parameters.append(create_parameter_model('param1', all_parameters=parameters)) parameters.append(create_parameter_model('param2', all_parameters=parameters)) + value_wrappers = wrap_values(parameters, {'param1': 'val1', 'param2': 'val2'}) files = file_download_feature.substitute_variable_values( parameters, ['/home/${param2}/${param1}.txt', '/tmp/${param2}.txt', '/${param1}'], - {'param1': 'val1', 'param2': 'val2'}, + value_wrappers, '127.0.0.1', 'user_X') @@ -250,10 +254,11 @@ def test_two_replaces_in_two_files(self): def test_no_pattern_match(self): param1 = create_parameter_model('param1') + parameters = [param1] files = file_download_feature.substitute_variable_values( - [param1], + parameters, ['/home/user/${paramX}.txt'], - {'param1': 'val1'}, + wrap_values(parameters, {'param1': 'val1'}), '127.0.0.1', 'user_X') @@ -262,10 +267,11 @@ def test_no_pattern_match(self): def test_skip_secure_replace(self): param1 = create_parameter_model('param1', secure=True) + parameters = [param1] files = file_download_feature.substitute_variable_values( - [param1], + parameters, ['/home/user/${param1}.txt'], - {'param1': 'val1'}, + wrap_values(parameters, {'param1': 'val1'}), '127.0.0.1', 'user_X') @@ -274,10 +280,11 @@ def test_skip_secure_replace(self): def test_skip_flag_replace(self): param1 = create_parameter_model('param1', no_value=True) + parameters = [param1] files = file_download_feature.substitute_variable_values( - [param1], + parameters, ['/home/user/${param1}.txt'], - {'param1': 'val1'}, + wrap_values(parameters, {'param1': 'val1'}), '127.0.0.1', 'user_X') @@ -286,10 +293,11 @@ def test_skip_flag_replace(self): def test_replace_audit_name(self): param1 = create_parameter_model('param1', no_value=True) + parameters = [param1] files = file_download_feature.substitute_variable_values( - [param1], + parameters, ['/home/user/${auth.audit_name}.txt'], - {'param1': 'val1'}, + wrap_values(parameters, {'param1': 'val1'}), '127.0.0.1', 'user_X') @@ -298,10 +306,11 @@ def test_replace_audit_name(self): def test_replace_username(self): param1 = create_parameter_model('param1', no_value=True) + parameters = [param1] files = file_download_feature.substitute_variable_values( - [param1], + parameters, ['/home/user/${auth.username}.txt'], - {'param1': 'val1'}, + wrap_values(parameters, {'param1': 'val1'}), '127.0.0.1', 'user_X') @@ -310,10 +319,11 @@ def test_replace_username(self): def test_replace_username_and_param(self): param1 = create_parameter_model('param1') + parameters = [param1] files = file_download_feature.substitute_variable_values( - [param1], + parameters, ['/home/${auth.username}/${param1}.txt'], - {'param1': 'val1'}, + wrap_values(parameters, {'param1': 'val1'}), '127.0.0.1', 'user_X') @@ -370,10 +380,11 @@ def perform_execution(self, output_files, parameter_values=None, parameters=None parameters = [create_script_param_config(key) for key in parameter_values.keys()] config_model = create_config_model('my_script', output_files=output_files, parameters=parameters) + config_model.set_all_param_values(parameter_values) user = User('userX', create_audit_names(ip='127.0.0.1')) execution_id = self.executor_service.start_script( - config_model, parameter_values, user) + config_model, user) self.executor_service.stop_script(execution_id, user) finish_condition = threading.Event() @@ -613,7 +624,7 @@ def write_output(self, execution_id, output): process_wrapper.write_output(output) def start_execution(self, config): - execution_id = self.executor_service.start_script(config, {}, DEFAULT_USER) + execution_id = self.executor_service.start_script(config, DEFAULT_USER) self.file_download_feature.subscribe_on_inline_images(execution_id, self._add_image) return execution_id diff --git a/src/tests/list_values_test.py b/src/tests/list_values_test.py index 2e44a738..81e02d98 100644 --- a/src/tests/list_values_test.py +++ b/src/tests/list_values_test.py @@ -5,7 +5,7 @@ from config.script.list_values import DependantScriptValuesProvider, FilesProvider, ScriptValuesProvider from tests import test_utils -from tests.test_utils import create_parameter_model +from tests.test_utils import create_parameter_model, wrap_values from utils import file_utils from utils.process_utils import ExecutionException @@ -93,12 +93,16 @@ def test_get_values_when_no_values(self, shell): @parameterized.expand([(True,), (False,)]) def test_get_values_when_single_parameter(self, shell): + parameters_supplier = self.create_parameters_supplier('param1') + values_provider = DependantScriptValuesProvider( "echo '_${param1}_'", - self.create_parameters_supplier('param1'), + parameters_supplier, shell=shell, process_invoker=test_utils.process_invoker) - self.assertEqual(['_hello world_'], values_provider.get_values({'param1': 'hello world'})) + + value_wrappers = wrap_values(parameters_supplier(), {'param1': 'hello world'}) + self.assertEqual(['_hello world_'], values_provider.get_values(value_wrappers)) @parameterized.expand([(True,), (False,)]) def test_get_values_when_multiple_parameters(self, shell): @@ -106,69 +110,91 @@ def test_get_values_when_multiple_parameters(self, shell): for i in range(0, 5): file_utils.write_file(os.path.join(files_path, 'f' + str(i) + '.txt'), 'test') + parameters_supplier = self.create_parameters_supplier('param1', 'param2') values_provider = DependantScriptValuesProvider( 'ls ' + test_utils.temp_folder + '/${param1}/${param2}', - self.create_parameters_supplier('param1', 'param2'), + parameters_supplier, shell=shell, process_invoker=test_utils.process_invoker) + + value_wrappers = wrap_values(parameters_supplier(), {'param1': 'path1', 'param2': 'path2'}) + self.assertEqual(['f0.txt', 'f1.txt', 'f2.txt', 'f3.txt', 'f4.txt'], - values_provider.get_values({'param1': 'path1', 'param2': 'path2'})) + values_provider.get_values(value_wrappers)) @parameterized.expand([(True,), (False,)]) def test_get_values_when_parameter_repeats(self, shell): + parameters_supplier = self.create_parameters_supplier('param1') values_provider = DependantScriptValuesProvider( "echo '_${param1}_\n' 'test\n' '+${param1}+'", - self.create_parameters_supplier('param1'), + parameters_supplier, shell=shell, process_invoker=test_utils.process_invoker) - self.assertEqual(['_123_', ' test', ' +123+'], values_provider.get_values({'param1': '123'})) + + value_wrappers = wrap_values(parameters_supplier(), {'param1': '123'}) + self.assertEqual(['_123_', ' test', ' +123+'], values_provider.get_values(value_wrappers)) @parameterized.expand([(True,), (False,)]) def test_get_values_when_numeric_parameter(self, shell): + parameters_supplier = self.create_parameters_supplier('param1') values_provider = DependantScriptValuesProvider( "echo '_${param1}_'", - self.create_parameters_supplier('param1'), + parameters_supplier, shell=shell, process_invoker=test_utils.process_invoker) - self.assertEqual(['_123_'], values_provider.get_values({'param1': 123})) + + value_wrappers = wrap_values(parameters_supplier(), {'param1': 123}) + self.assertEqual(['_123_'], values_provider.get_values(value_wrappers)) @parameterized.expand([(True,), (False,)]) def test_get_values_when_newline_response(self, shell): + parameters_supplier = self.create_parameters_supplier('param1') values_provider = DependantScriptValuesProvider( "ls '${param1}'", - self.create_parameters_supplier('param1'), + parameters_supplier, shell=shell, process_invoker=test_utils.process_invoker) - self.assertEqual([], values_provider.get_values({'param1': test_utils.temp_folder})) + + value_wrappers = wrap_values(parameters_supplier(), {'param1': test_utils.temp_folder}) + self.assertEqual([], values_provider.get_values(value_wrappers)) @parameterized.expand([(True, ['1', '2']), (False, ['1 && echo 2'])]) def test_no_code_injection_for_and_operator(self, shell, expected_values): + parameters_supplier = self.create_parameters_supplier('param1') values_provider = DependantScriptValuesProvider( "echo ${param1}", - self.create_parameters_supplier('param1'), + parameters_supplier, shell=shell, process_invoker=test_utils.process_invoker) - self.assertEqual(expected_values, values_provider.get_values({'param1': '1 && echo 2'})) + + value_wrappers = wrap_values(parameters_supplier(), {'param1': '1 && echo 2'}) + self.assertEqual(expected_values, values_provider.get_values(value_wrappers)) @parameterized.expand([(True, ['y2', 'y3']), (False, [])]) def test_no_code_injection_for_pipe_operator(self, shell, expected_values): test_utils.create_files(['x1', 'y2', 'y3']) + parameters_supplier = self.create_parameters_supplier('param1') values_provider = DependantScriptValuesProvider( "ls ${param1}", - self.create_parameters_supplier('param1'), + parameters_supplier, shell=shell, process_invoker=test_utils.process_invoker) - self.assertEqual(expected_values, values_provider.get_values({'param1': test_utils.temp_folder + ' | grep y'})) + + value_wrappers = wrap_values(parameters_supplier(), {'param1': test_utils.temp_folder + ' | grep y'}) + self.assertEqual(expected_values, values_provider.get_values(value_wrappers)) @parameterized.expand([(True,), (False,)]) def test_script_fails(self, shell): + parameters_supplier = self.create_parameters_supplier('param1') values_provider = DependantScriptValuesProvider( "echo2 ${param1}", - self.create_parameters_supplier('param1'), + parameters_supplier, shell=shell, process_invoker=test_utils.process_invoker) - self.assertEqual([], values_provider.get_values({'param1': 'abc'})) + + value_wrappers = wrap_values(parameters_supplier(), {'param1': 'abc'}) + self.assertEqual([], values_provider.get_values(value_wrappers)) def setUp(self): test_utils.setup() diff --git a/src/tests/model_helper_test.py b/src/tests/model_helper_test.py index 7df23c85..94a22ce7 100644 --- a/src/tests/model_helper_test.py +++ b/src/tests/model_helper_test.py @@ -7,7 +7,7 @@ from model.model_helper import read_list, read_dict, fill_parameter_values, resolve_env_vars, \ InvalidFileException, read_bool_from_config, InvalidValueException, InvalidValueTypeException, read_str_from_config from tests import test_utils -from tests.test_utils import create_parameter_model, set_os_environ_value +from tests.test_utils import create_parameter_model, set_os_environ_value, wrap_values from utils import file_utils from utils.file_utils import FileMatcher @@ -121,39 +121,62 @@ def test_unsupported_type(self): class TestFillParameterValues(unittest.TestCase): def test_fill_single_parameter(self): - result = fill_parameter_values(self.create_parameters('p1'), 'Hello, ${p1}!', {'p1': 'world'}) + parameters = self.create_parameters('p1') + value_wrappers = wrap_values(parameters, {'p1': 'world'}) + + result = fill_parameter_values(parameters, 'Hello, ${p1}!', value_wrappers) self.assertEqual('Hello, world!', result) def test_fill_single_parameter_multiple_times(self): - result = fill_parameter_values(self.create_parameters('p1'), 'Ho${p1}-${p1}${p1}!', {'p1': 'ho'}) + parameters = self.create_parameters('p1') + value_wrappers = wrap_values(parameters, {'p1': 'ho'}) + + result = fill_parameter_values(parameters, 'Ho${p1}-${p1}${p1}!', value_wrappers) self.assertEqual('Hoho-hoho!', result) def test_fill_multiple_parameters(self): - result = fill_parameter_values(self.create_parameters('p1', 'p2', 'p3'), - 'Some ${p2} text, which is ${p3} by ${p1}.', - {'p1': 'script-server', 'p2': 'small', 'p3': 'generated'}) + parameters = self.create_parameters('p1', 'p2', 'p3') + value_wrappers = wrap_values(parameters, {'p1': 'script-server', 'p2': 'small', 'p3': 'generated'}) + + result = fill_parameter_values( + parameters, + 'Some ${p2} text, which is ${p3} by ${p1}.', + value_wrappers) self.assertEqual('Some small text, which is generated by script-server.', result) def test_fill_multiple_parameters_when_one_without_value(self): - result = fill_parameter_values(self.create_parameters('p1', 'p2'), - '${p1} vs ${p2}', - {'p1': 'ABC'}) + parameters = self.create_parameters('p1', 'p2') + value_wrappers = wrap_values(parameters, {'p1': 'ABC'}) + + result = fill_parameter_values( + parameters, + '${p1} vs ${p2}', + value_wrappers) self.assertEqual('ABC vs ', result) def test_fill_multiple_parameters_when_one_secure(self): parameters = self.create_parameters('p1', 'p2') parameters[1].secure = True - result = fill_parameter_values(parameters, - '${p1} vs ${p2}', - {'p1': 'ABC', 'p2': 'XYZ'}) + value_wrappers = wrap_values(parameters, {'p1': 'ABC', 'p2': 'XYZ'}) + + result = fill_parameter_values( + parameters, + '${p1} vs ${p2}', + value_wrappers) self.assertEqual('ABC vs ${p2}', result) def test_fill_non_string_value(self): - result = fill_parameter_values(self.create_parameters('p1'), 'Value = ${p1}', {'p1': 5}) + parameters = self.create_parameters('p1') + value_wrappers = wrap_values(parameters, {'p1': 5}) + + result = fill_parameter_values(parameters, 'Value = ${p1}', value_wrappers) self.assertEqual('Value = 5', result) def test_fill_when_no_parameter_for_pattern(self): - result = fill_parameter_values(self.create_parameters('p1'), 'Value = ${xyz}', {'p1': '12345'}) + parameters = self.create_parameters('p1') + value_wrappers = wrap_values(parameters, {'p1': '12345'}) + + result = fill_parameter_values(parameters, 'Value = ${xyz}', value_wrappers) self.assertEqual('Value = ${xyz}', result) def test_fill_when_server_file_recursive_and_one_level(self): @@ -162,8 +185,9 @@ def test_fill_when_server_file_recursive_and_one_level(self): type='server_file', file_dir=test_utils.temp_folder, file_recursive=True)] + value_wrappers = wrap_values(parameters, {'p1': ['folder']}) - result = fill_parameter_values(parameters, 'Value = ${p1}', {'p1': ['folder']}) + result = fill_parameter_values(parameters, 'Value = ${p1}', value_wrappers) expected_value = os.path.join(test_utils.temp_folder, 'folder') self.assertEqual('Value = ' + expected_value, result) @@ -173,8 +197,9 @@ def test_fill_when_server_file_recursive_and_multiple_levels(self): type='server_file', file_dir=test_utils.temp_folder, file_recursive=True)] + value_wrappers = wrap_values(parameters, {'p1': ['folder', 'sub', 'log.txt']}) - result = fill_parameter_values(parameters, 'Value = ${p1}', {'p1': ['folder', 'sub', 'log.txt']}) + result = fill_parameter_values(parameters, 'Value = ${p1}', value_wrappers) expected_value = os.path.join(test_utils.temp_folder, 'folder', 'sub', 'log.txt') self.assertEqual('Value = ' + expected_value, result) @@ -185,8 +210,9 @@ def test_fill_when_server_file_plain(self): file_dir=test_utils.temp_folder, file_recursive=True)] - result = fill_parameter_values(parameters, 'Value = ${p1}', {'p1': 'folder'}) - self.assertEqual('Value = folder', result) + value_wrappers = wrap_values(parameters, {'p1': ['folder']}) + result = fill_parameter_values(parameters, 'Value = ${p1}', value_wrappers) + self.assertEqual('Value = tests_temp/folder', result) def create_parameters(self, *names): result = [] diff --git a/src/tests/parameter_config_test.py b/src/tests/parameter_config_test.py index c249d22e..0089346d 100644 --- a/src/tests/parameter_config_test.py +++ b/src/tests/parameter_config_test.py @@ -9,9 +9,11 @@ from model import parameter_config from model.model_helper import InvalidValueException from model.parameter_config import get_sorted_config, ParameterUiSeparator +from model.value_wrapper import ScriptValueWrapper from react.properties import ObservableDict, ObservableList from tests import test_utils -from tests.test_utils import create_parameter_model, create_parameter_model_from_config +from tests.test_utils import create_parameter_model, create_parameter_model_from_config, validate_value, \ + get_default_value, wrap_values from utils.process_utils import ExecutionException from utils.string_utils import is_blank @@ -40,7 +42,7 @@ def test_create_full_parameter(self): 'name': name, 'param': param, 'env_var': 'my_Param', - 'no_value': 'true', + 'no_value': 'false', 'description': description, 'required': required, 'min': min, @@ -63,14 +65,14 @@ def test_create_full_parameter(self): self.assertEqual(name, parameter_model.name) self.assertEqual(param, parameter_model.param) self.assertEqual('my_Param', parameter_model.env_var) - self.assertEqual(True, parameter_model.no_value) + self.assertEqual(False, parameter_model.no_value) self.assertEqual(description, parameter_model.description) self.assertEqual(required, parameter_model.required) self.assertEqual(min, parameter_model.min) self.assertEqual(max, parameter_model.max) self.assertEqual(separator, parameter_model.separator) self.assertEqual('argument_per_value', parameter_model.multiselect_argument_type) - self.assertEqual(default, parameter_model.default) + self.assertEqual(default, parameter_model.create_value_wrapper_for_default().mapped_script_value) self.assertEqual(type, parameter_model.type) self.assertEqual(False, parameter_model.constant) self.assertEqual(3, parameter_model.ui_width_weight) @@ -95,7 +97,7 @@ def test_default_value_from_env(self): parameter_model = _create_parameter_model({ 'name': 'def_param', 'default': '$$my_env_var'}) - self.assertEqual('sky', parameter_model.default) + self.assertEqual('sky', parameter_model.create_value_wrapper_for_default().mapped_script_value) def test_default_value_from_env_when_missing(self): test_utils.set_os_environ_value('my_env_var', 'earth') @@ -108,7 +110,8 @@ def test_default_value_from_env_when_missing(self): def test_default_value_from_auth(self): parameter_model = _create_parameter_model({'name': 'def_param', 'default': 'X${auth.username}X'}) - self.assertEqual('X' + DEF_USERNAME + 'X', parameter_model.default) + default_wrapper = parameter_model.create_value_wrapper_for_default() + self.assertEqual('X' + DEF_USERNAME + 'X', default_wrapper.mapped_script_value) def test_prohibit_constant_without_default(self): self.assertRaisesRegex(Exception, 'Constant should have default value specified', @@ -380,17 +383,17 @@ def test_script_value_when_dynamic_script(self, script, values, expected_value): parameter_values=parameter_values ) - self.assertEqual(None, parameter_model.default) + self.assertEqual(None, get_default_value(parameter_model)) - for (key, value) in values.items(): - parameter_values[key] = value + parameter_values.set(wrap_values(parameters, values)) - self.assertEqual(expected_value, parameter_model.default) + self.assertEqual(expected_value, get_default_value(parameter_model)) def test_script_value_when_dynamic_script_value_change_to_null(self): - parameters = ObservableList([_create_parameter_model({'name': 'param1'})]) + param1 = _create_parameter_model({'name': 'param1'}) + parameters = ObservableList([param1]) - parameter_values = ObservableDict({'param1': 'abc'}) + parameter_values = ObservableDict({'param1': param1.create_value_wrapper('abc')}) parameter_model = self.prepare_parameter_model( {'script': 'echo ${param1}'}, @@ -399,14 +402,15 @@ def test_script_value_when_dynamic_script_value_change_to_null(self): parameter_values=parameter_values ) - self.assertEqual('abc', parameter_model.default) + self.assertEqual('abc', get_default_value(parameter_model)) - parameter_values['param1'] = None + parameter_values['param1'] = param1.create_value_wrapper(None) - self.assertEqual(None, parameter_model.default) + self.assertEqual(None, get_default_value(parameter_model)) def test_script_value_when_dynamic_script_shell_true(self): - parameters = ObservableList([_create_parameter_model({'name': 'param1'})]) + param1 = _create_parameter_model({'name': 'param1'}) + parameters = ObservableList([param1]) parameter_values = ObservableDict() @@ -416,16 +420,17 @@ def test_script_value_when_dynamic_script_shell_true(self): parameter_values=parameter_values ) - parameter_values['param1'] = 'abc\ndef\ncat' + parameter_values['param1'] = param1.create_value_wrapper('abc\ndef\ncat') - self.assertEqual('abc\ncat', parameter_model.default) + self.assertEqual('abc\ncat', get_default_value(parameter_model)) @parameterized.expand([ (False,), (None,), ]) def test_script_value_when_dynamic_script_shell_false(self, shell): - parameters = ObservableList([_create_parameter_model({'name': 'param1'})]) + param1 = _create_parameter_model({'name': 'param1'}) + parameters = ObservableList([param1]) parameter_values = ObservableDict() @@ -435,9 +440,25 @@ def test_script_value_when_dynamic_script_shell_false(self, shell): parameter_values=parameter_values ) - parameter_values['param1'] = 'abc\ndef\ncat' + parameter_values['param1'] = param1.create_value_wrapper('abc\ndef\ncat') + + self.assertEqual('abc\ndef\ncat | grep a', get_default_value(parameter_model)) + + def test_script_value_when_dynamic_script_and_ui_mapping(self): + param1 = _create_parameter_model({'name': 'param1', 'values_ui_mapping': {'abc': 'qwerty'}}) + parameters = ObservableList([param1]) + + parameter_values = ObservableDict() + + parameter_model = self.prepare_parameter_model( + {'script': 'echo "${param1}"'}, + other_parameters=parameters, + parameter_values=parameter_values + ) + + parameter_values['param1'] = param1.create_value_wrapper('qwerty') - self.assertEqual('abc\ndef\ncat | grep a', parameter_model.default) + self.assertEqual('abc', get_default_value(parameter_model)) @staticmethod def resolve_default(value, *, username=None, audit_name=None, working_dir=None, type=None): @@ -448,7 +469,7 @@ def resolve_default(value, *, username=None, audit_name=None, working_dir=None, working_dir=working_dir, type=type) - return model.default + return get_default_value(model) @staticmethod def prepare_parameter_model( @@ -496,261 +517,261 @@ class TestSingleParameterValidation(unittest.TestCase): def test_string_parameter_when_none(self): parameter = create_parameter_model('param') - error = parameter.validate_value(None) + error = validate_value(parameter, None) self.assertIsNone(error) @parameterized.expand([(None,), ('multiline_text',)]) def test_text_parameter_when_value(self, param_type): parameter = create_parameter_model('param', type=param_type) - error = parameter.validate_value('val') + error = validate_value(parameter, 'val') self.assertIsNone(error) def test_required_parameter_when_none(self): parameter = create_parameter_model('param', required=True) - error = parameter.validate_value({}) + error = validate_value(parameter, {}) self.assert_error(error) @parameterized.expand([(None,), ('multiline_text',)]) def test_required_parameter_when_empty(self, param_type): parameter = create_parameter_model('param', type=param_type, required=True) - error = parameter.validate_value('') + error = validate_value(parameter, '') self.assert_error(error) def test_required_parameter_when_value(self): parameter = create_parameter_model('param', required=True) - error = parameter.validate_value('val') + error = validate_value(parameter, 'val') self.assertIsNone(error) def test_required_parameter_when_constant(self): parameter = create_parameter_model('param', required=True, constant=True, default='123') - error = parameter.validate_value(None) + error = validate_value(parameter, None) self.assertIsNone(error) def test_flag_parameter_when_true_bool(self): parameter = create_parameter_model('param', no_value=True) - error = parameter.validate_value(True) + error = validate_value(parameter, True) self.assertIsNone(error) def test_flag_parameter_when_false_bool(self): parameter = create_parameter_model('param', no_value=True) - error = parameter.validate_value(False) + error = validate_value(parameter, False) self.assertIsNone(error) def test_flag_parameter_when_true_string(self): parameter = create_parameter_model('param', no_value=True) - error = parameter.validate_value('true') + error = validate_value(parameter, 'true') self.assertIsNone(error) def test_flag_parameter_when_false_string(self): parameter = create_parameter_model('param', no_value=True) - error = parameter.validate_value('false') + error = validate_value(parameter, 'false') self.assertIsNone(error) def test_flag_parameter_when_some_string(self): parameter = create_parameter_model('param', no_value=True) - error = parameter.validate_value('no') + error = validate_value(parameter, 'no') self.assert_error(error) def test_required_flag_parameter_when_true_boolean(self): parameter = create_parameter_model('param', no_value=True, required=True) - error = parameter.validate_value(True) + error = validate_value(parameter, True) self.assertIsNone(error) def test_required_flag_parameter_when_false_boolean(self): parameter = create_parameter_model('param', no_value=True, required=True) - error = parameter.validate_value(False) + error = validate_value(parameter, False) self.assertIsNone(error) def test_int_parameter_when_negative_int(self): parameter = create_parameter_model('param', type='int') - error = parameter.validate_value(-100) + error = validate_value(parameter, -100) self.assertIsNone(error) def test_int_parameter_when_large_positive_int(self): parameter = create_parameter_model('param', type='int') - error = parameter.validate_value(1234567890987654321) + error = validate_value(parameter, 1234567890987654321) self.assertIsNone(error) def test_int_parameter_when_zero_int_string(self): parameter = create_parameter_model('param', type='int') - error = parameter.validate_value('0') + error = validate_value(parameter, '0') self.assertIsNone(error) def test_int_parameter_when_large_negative_int_string(self): parameter = create_parameter_model('param', type='int') - error = parameter.validate_value('-1234567890987654321') + error = validate_value(parameter, '-1234567890987654321') self.assertIsNone(error) def test_int_parameter_when_not_int_string(self): parameter = create_parameter_model('param', type='int') - error = parameter.validate_value('v123') + error = validate_value(parameter, 'v123') self.assert_error(error) def test_int_parameter_when_float(self): parameter = create_parameter_model('param', type='int') - error = parameter.validate_value(1.2) + error = validate_value(parameter, 1.2) self.assert_error(error) def test_int_parameter_when_float_string(self): parameter = create_parameter_model('param', type='int') - error = parameter.validate_value('1.0') + error = validate_value(parameter, '1.0') self.assert_error(error) def test_int_parameter_when_lower_than_max(self): parameter = create_parameter_model('param', type='int', max=100) - error = parameter.validate_value(9) + error = validate_value(parameter, 9) self.assertIsNone(error) def test_int_parameter_when_equal_to_max(self): parameter = create_parameter_model('param', type='int', max=5) - error = parameter.validate_value(5) + error = validate_value(parameter, 5) self.assertIsNone(error) def test_int_parameter_when_larger_than_max(self): parameter = create_parameter_model('param', type='int', max=0) - error = parameter.validate_value(100) + error = validate_value(parameter, 100) self.assert_error(error) def test_int_parameter_when_lower_than_min(self): parameter = create_parameter_model('param', type='int', min=100) - error = parameter.validate_value(0) + error = validate_value(parameter, 0) self.assert_error(error) def test_int_parameter_when_equal_to_min(self): parameter = create_parameter_model('param', type='int', min=-100) - error = parameter.validate_value(-100) + error = validate_value(parameter, -100) self.assertIsNone(error) def test_int_parameter_when_larger_than_min(self): parameter = create_parameter_model('param', type='int', min=100) - error = parameter.validate_value(0) + error = validate_value(parameter, 0) self.assert_error(error) def test_required_int_parameter_when_zero(self): parameter = create_parameter_model('param', type='int', required=True) - error = parameter.validate_value(0) + error = validate_value(parameter, 0) self.assertIsNone(error) def test_file_upload_parameter_when_valid(self): parameter = create_parameter_model('param', type='file_upload') uploaded_file = test_utils.create_file('test.xml') - error = parameter.validate_value(uploaded_file) + error = validate_value(parameter, uploaded_file) self.assertIsNone(error) def test_file_upload_parameter_when_not_exists(self): parameter = create_parameter_model('param', type='file_upload') uploaded_file = test_utils.create_file('test.xml') - error = parameter.validate_value(uploaded_file + '_') + error = validate_value(parameter, uploaded_file + '_') self.assert_error(error) def test_list_parameter_when_matches(self): parameter = create_parameter_model( 'param', type='list', allowed_values=['val1', 'val2', 'val3']) - error = parameter.validate_value('val2') + error = validate_value(parameter, 'val2') self.assertIsNone(error) def test_list_parameter_when_not_matches(self): parameter = create_parameter_model( 'param', type='list', allowed_values=['val1', 'val2', 'val3']) - error = parameter.validate_value('val4') + error = validate_value(parameter, 'val4') self.assert_error(error) def test_editable_list_parameter_when_not_matches(self): parameter = create_parameter_model( 'param', type='editable_list', allowed_values=['val1', 'val2', 'val3']) - error = parameter.validate_value('val4') + error = validate_value(parameter, 'val4') self.assertIsNone(error) def test_multiselect_when_empty_string(self): parameter = create_parameter_model( 'param', type=PARAM_TYPE_MULTISELECT, allowed_values=['val1', 'val2', 'val3']) - error = parameter.validate_value('') + error = validate_value(parameter, '') self.assertIsNone(error) def test_multiselect_when_empty_list(self): parameter = create_parameter_model( 'param', type=PARAM_TYPE_MULTISELECT, allowed_values=['val1', 'val2', 'val3']) - error = parameter.validate_value([]) + error = validate_value(parameter, []) self.assertIsNone(error) def test_multiselect_when_single_matching_element(self): parameter = create_parameter_model( 'param', type=PARAM_TYPE_MULTISELECT, allowed_values=['val1', 'val2', 'val3']) - error = parameter.validate_value(['val2']) + error = validate_value(parameter, ['val2']) self.assertIsNone(error) def test_multiselect_when_multiple_matching_elements(self): parameter = create_parameter_model( 'param', type=PARAM_TYPE_MULTISELECT, allowed_values=['val1', 'val2', 'val3']) - error = parameter.validate_value(['val2', 'val1']) + error = validate_value(parameter, ['val2', 'val1']) self.assertIsNone(error) def test_multiselect_when_multiple_elements_one_not_matching(self): parameter = create_parameter_model( 'param', type=PARAM_TYPE_MULTISELECT, allowed_values=['val1', 'val2', 'val3']) - error = parameter.validate_value(['val2', 'val1', 'X']) + error = validate_value(parameter, ['val2', 'val1', 'X']) self.assert_error(error) def test_multiselect_when_not_list_value(self): parameter = create_parameter_model( 'param', type=PARAM_TYPE_MULTISELECT, allowed_values=['val1', 'val2', 'val3']) - error = parameter.validate_value('val1') + error = validate_value(parameter, 'val1') self.assert_error(error) def test_multiselect_when_single_not_matching_element(self): parameter = create_parameter_model( 'param', type=PARAM_TYPE_MULTISELECT, allowed_values=['val1', 'val2', 'val3']) - error = parameter.validate_value(['X']) + error = validate_value(parameter, ['X']) self.assert_error(error) def test_list_with_script_when_matches(self): parameter = create_parameter_model('param', type='list', values_script="echo '123\n' 'abc'") - error = parameter.validate_value('123') + error = validate_value(parameter, '123') self.assertIsNone(error) def test_list_with_script_when_matches_and_win_newline(self): parameter = create_parameter_model('param', type='list', values_script="echo '123\r\n' 'abc'") - error = parameter.validate_value('123') + error = validate_value(parameter, '123') self.assertIsNone(error) @parameterized.expand([ @@ -763,7 +784,7 @@ def test_list_with_script_when_matches_and_win_newline(self): def test_regex_validation_when_fail_with_description(self, regex, value, description, expected_description): parameter = create_parameter_model('param', regex={'pattern': regex, 'description': description}) - error = parameter.validate_value(value) + error = validate_value(parameter, value) self.assert_error(error) self.assertEqual(error, "does not match regex pattern: " + expected_description) @@ -777,7 +798,7 @@ def test_regex_validation_when_fail_with_description(self, regex, value, descrip def test_regex_validation_when_success(self, regex, value): parameter = create_parameter_model('param', regex={'pattern': regex}) - error = parameter.validate_value(value) + error = validate_value(parameter, value) self.assertIsNone(error) @parameterized.expand([(False,), (True,), (None,)]) @@ -793,48 +814,48 @@ def test_list_with_dependency_when_matches(self, shell): values_script_shell=shell) parameters.extend([dep_param, parameter]) - values['dep_param'] = 'abc' - error = parameter.validate_value(' _abc_') + values['dep_param'] = ScriptValueWrapper('abc', 'abc', 'abc') + error = validate_value(parameter, ' _abc_') self.assertIsNone(error) def test_any_ip_when_ip4(self): parameter = create_parameter_model('param', type='ip') - error = parameter.validate_value('127.0.0.1') + error = validate_value(parameter, '127.0.0.1') self.assertIsNone(error) def test_any_ip_when_ip6(self): parameter = create_parameter_model('param', type='ip') - error = parameter.validate_value('ABCD::6789') + error = validate_value(parameter, 'ABCD::6789') self.assertIsNone(error) def test_any_ip_when_wrong(self): parameter = create_parameter_model('param', type='ip') - error = parameter.validate_value('127.abcd.1') + error = validate_value(parameter, '127.abcd.1') self.assert_error(error) def test_ip4_when_valid(self): parameter = create_parameter_model('param', type='ip4') - error = parameter.validate_value('192.168.0.13') + error = validate_value(parameter, '192.168.0.13') self.assertIsNone(error) def test_ip4_when_ip6(self): parameter = create_parameter_model('param', type='ip4') - error = parameter.validate_value('ABCD::1234') + error = validate_value(parameter, 'ABCD::1234') self.assert_error(error) def test_ip6_when_valid(self): parameter = create_parameter_model('param', type='ip6') - error = parameter.validate_value('1:2:3:4:5:6:7:8') + error = validate_value(parameter, '1:2:3:4:5:6:7:8') self.assertIsNone(error) def test_ip6_when_ip4(self): parameter = create_parameter_model('param', type='ip6') - error = parameter.validate_value('172.13.0.15') + error = validate_value(parameter, '172.13.0.15') self.assert_error(error) def test_ip6_when_complex_valid(self): parameter = create_parameter_model('param', type='ip6') - error = parameter.validate_value('AbC:0::13:127.0.0.1') + error = validate_value(parameter, 'AbC:0::13:127.0.0.1') self.assertIsNone(error) def test_server_file_when_valid(self): @@ -843,28 +864,28 @@ def test_server_file_when_valid(self): test_utils.create_file(filename) parameter = create_parameter_model('param', type=PARAM_TYPE_SERVER_FILE, file_dir=test_utils.temp_folder) - error = parameter.validate_value(filename) + error = validate_value(parameter, filename) self.assertIsNone(error) def test_server_file_when_wrong(self): test_utils.create_file('file1.txt') parameter = create_parameter_model('param', type=PARAM_TYPE_SERVER_FILE, file_dir=test_utils.temp_folder) - error = parameter.validate_value('my.dat') + error = validate_value(parameter, 'my.dat') self.assert_error(error) @parameterized.expand([(None,), ('multiline_text',)]) def test_text_parameter_when_max_length_ok(self, param_type): parameter = create_parameter_model('param', type=param_type, max_length=10) - error = parameter.validate_value('012345678\n') + error = validate_value(parameter, '012345678\n') self.assertIsNone(error) @parameterized.expand([(None,), ('multiline_text',)]) def test_text_parameter_when_max_length_violated(self, param_type): parameter = create_parameter_model('param', type=param_type, max_length=10) - error = parameter.validate_value('0123456789\n') + error = validate_value(parameter, '0123456789\n') self.assert_error(error) def assert_error(self, error): @@ -1039,6 +1060,25 @@ def test_get_sorted_when_unknown_fields(self): self.assertCountEqual(expected.items(), config.items()) +class TestUiValueMapping(unittest.TestCase): + def test_no_mapping(self): + parameter_model = create_parameter_model(type='list', allowed_values=['abc', 'def', 'xyz']) + + self.assertEqual(['abc', 'def', 'xyz'], parameter_model.get_ui_values()) + + def test_mapping(self): + parameter_model = create_parameter_model( + type='list', + allowed_values=['abc', 'def', 'xyz'], + values_ui_mapping={ + 'abc': 'ABC', + 'def': 'qwerty' + } + ) + + self.assertEqual(['ABC', 'qwerty', 'xyz'], parameter_model.get_ui_values()) + + def _create_parameter_model(config, *, username=DEF_USERNAME, audit_name=DEF_AUDIT_NAME, all_parameters=None): return create_parameter_model_from_config(config, username=username, diff --git a/src/tests/parameter_server_file_test.py b/src/tests/parameter_server_file_test.py index 966013da..c4b1173c 100644 --- a/src/tests/parameter_server_file_test.py +++ b/src/tests/parameter_server_file_test.py @@ -6,7 +6,7 @@ from model.parameter_config import WrongParameterUsageException from model.script_config import InvalidValueException from tests import test_utils -from tests.test_utils import create_script_param_config, create_files +from tests.test_utils import create_script_param_config, create_files, validate_value class ServerFileConfigTest(unittest.TestCase): @@ -119,19 +119,19 @@ def test_validate_success_when_working_dir(self): create_files(['abc', 'def'], os.path.join(working_dir, file_dir)) working_dir_path = os.path.join(test_utils.temp_folder, 'work', 'dir') config = _create_parameter_model(recursive=False, file_dir=file_dir, working_dir=working_dir_path) - self.assertIsNone(config.validate_value('def')) + self.assertIsNone(validate_value(config, 'def')) def test_validate_failure_when_working_dir(self): file_dir = 'inner' create_files(['abc', 'def'], file_dir) working_dir_path = os.path.join(test_utils.temp_folder, 'work', 'dir') config = _create_parameter_model(recursive=False, file_dir=file_dir, working_dir=working_dir_path) - self.assertRegex(config.validate_value('def'), '.+ but should be in \[\]') + self.assertRegex(validate_value(config, 'def'), '.+ but should be in \[\]') def test_validate_failure_when_excluded_file(self): create_files(['abc', 'def']) config = _create_parameter_model(recursive=False, file_dir=test_utils.temp_folder, excluded_files=['abc']) - self.assertRegex(config.validate_value('abc'), '.+ but should be in \[\'def\'\]') + self.assertRegex(validate_value(config, 'abc'), '.+ but should be in \[\'def\'\]') def setUp(self): test_utils.setup() @@ -366,50 +366,50 @@ def test_list_files_when_excluded_recursive_glob(self): def test_validate_missing_value(self): config = _create_parameter_model(recursive=True) - self.assertIsNone(config.validate_value([])) + self.assertIsNone(validate_value(config, [])) def test_validate_existing_top_file(self): create_files(['abc']) config = _create_parameter_model(recursive=True) - self.assertIsNone(config.validate_value(['abc'])) + self.assertIsNone(validate_value(config, ['abc'])) def test_validate_existing_nested_file(self): create_files(['abc.txt'], os.path.join('my', 'nested', 'folder')) config = _create_parameter_model(recursive=True) - self.assertIsNone(config.validate_value(['my', 'nested', 'folder', 'abc.txt'])) + self.assertIsNone(validate_value(config, ['my', 'nested', 'folder', 'abc.txt'])) def test_validate_missing_top_file(self): config = _create_parameter_model(recursive=True) - error = config.validate_value(['abc']) + error = validate_value(config, ['abc']) self.assertRegex(error, '.+ does not exist') def test_validate_missing_nested_file(self): config = _create_parameter_model(recursive=True) - error = config.validate_value(['my', 'nested', 'folder', 'abc.txt']) + error = validate_value(config, ['my', 'nested', 'folder', 'abc.txt']) self.assertRegex(error, '.+ does not exist') def test_validate_fail_when_relative_reference(self): create_files(['abc.txt']) config = _create_parameter_model(recursive=True) - error = config.validate_value(['..', test_utils.temp_folder, 'abc.txt']) + error = validate_value(config, ['..', test_utils.temp_folder, 'abc.txt']) self.assertEqual('Relative path references are not allowed', error) def test_validate_success_when_extensions(self): create_files(['abc.txt', 'admin.log', 'doc.pdf'], 'home') config = _create_parameter_model(recursive=True, extensions='txt') - self.assertIsNone(config.validate_value(['home', 'abc.txt'])) + self.assertIsNone(validate_value(config, ['home', 'abc.txt'])) def test_validate_fail_when_extensions(self): create_files(['abc.txt', 'admin.log', 'doc.pdf'], 'home') config = _create_parameter_model(recursive=True, extensions='log') - error = config.validate_value(['home', 'abc.txt']) + error = validate_value(config, ['home', 'abc.txt']) self.assertRegex(error, '.+ is not allowed') def test_validate_success_when_file_type_file(self): @@ -417,14 +417,14 @@ def test_validate_success_when_file_type_file(self): create_files(['.passwords'], os.path.join('home', 'private')) config = _create_parameter_model(recursive=True, file_type=FILE_TYPE_FILE) - self.assertIsNone(config.validate_value(['home', 'abc.txt'])) + self.assertIsNone(validate_value(config, ['home', 'abc.txt'])) def test_validate_fail_when_file_type_file(self): create_files(['abc.txt', 'admin.log'], 'home') create_files(['.passwords'], os.path.join('home', 'private')) config = _create_parameter_model(recursive=True, file_type=FILE_TYPE_FILE) - error = config.validate_value(['home', 'private']) + error = validate_value(config, ['home', 'private']) self.assertRegex(error, '.+ is not allowed') def test_validate_success_when_file_type_dir(self): @@ -432,14 +432,14 @@ def test_validate_success_when_file_type_dir(self): create_files(['tasks.list']) config = _create_parameter_model(recursive=True, file_type=FILE_TYPE_DIR) - self.assertIsNone(config.validate_value(['private'])) + self.assertIsNone(validate_value(config, ['private'])) def test_validate_fail_when_file_type_dir(self): create_files(['.passwords'], 'private') create_files(['tasks.list']) config = _create_parameter_model(recursive=True, file_type=FILE_TYPE_DIR) - error = config.validate_value(['tasks.list']) + error = validate_value(config, ['tasks.list']) self.assertRegex(error, '.+ is not allowed') def test_validate_success_when_file_type_dir_and_extensions(self): @@ -447,14 +447,14 @@ def test_validate_success_when_file_type_dir_and_extensions(self): create_files(['admin.log', 'file.txt', 'print.pdf']) config = _create_parameter_model(recursive=True, file_type=FILE_TYPE_DIR, extensions=['txt']) - self.assertIsNone(config.validate_value(['file.txt'])) + self.assertIsNone(validate_value(config, ['file.txt'])) def test_validate_fail_on_dir_when_file_type_dir_and_extensions(self): create_files(['.passwords'], 'private') create_files(['admin.log', 'file.txt', 'print.pdf']) config = _create_parameter_model(recursive=True, file_type=FILE_TYPE_DIR, extensions=['txt']) - error = config.validate_value(['private']) + error = validate_value(config, ['private']) self.assertRegex(error, '.+ is not allowed') def test_validate_fail_on_extension_when_file_type_dir_and_extensions(self): @@ -462,7 +462,7 @@ def test_validate_fail_on_extension_when_file_type_dir_and_extensions(self): create_files(['admin.log', 'file.txt', 'print.pdf']) config = _create_parameter_model(recursive=True, file_type=FILE_TYPE_DIR, extensions=['txt']) - error = config.validate_value(['print.pdf']) + error = validate_value(config, ['print.pdf']) self.assertRegex(error, '.+ is not allowed') def test_validate_fail_on_excluded_file(self): @@ -470,7 +470,7 @@ def test_validate_fail_on_excluded_file(self): create_files(['xyz', 'abc'], subfolder) config = _create_parameter_model(recursive=True, excluded_files=[os.path.join(subfolder, 'abc')]) - error = config.validate_value(['work', 'another', 'abc']) + error = validate_value(config, ['work', 'another', 'abc']) self.assertRegex(error, '.+ is excluded') def test_validate_fail_on_list_excluded_subfolder(self): diff --git a/src/tests/scheduling/schedule_service_test.py b/src/tests/scheduling/schedule_service_test.py index 844acb01..90e49df3 100644 --- a/src/tests/scheduling/schedule_service_test.py +++ b/src/tests/scheduling/schedule_service_test.py @@ -66,7 +66,7 @@ def setUp(self) -> None: self.create_config('unschedulable-script', scheduling_enabled=False) self.execution_service = MagicMock() - self.execution_service.start_script.side_effect = lambda config, values, user: time.time_ns() + self.execution_service.start_script.side_effect = lambda config, user: time.time_ns() self.schedule_service = ScheduleService(self.config_service, self.execution_service, test_utils.temp_folder) @@ -75,7 +75,7 @@ def setUp(self) -> None: def create_config(self, name, scheduling_enabled=True, parameters=None, auto_cleanup=False): if parameters is None: parameters = [ - {'name': 'p1'}, + {'name': 'p1', 'values_ui_mapping': {'bingo!': 'mpd'}}, {'name': 'param_2', 'type': 'multiselect', 'values': ['hello', 'world', '1', '2', '3']}, ] @@ -183,6 +183,15 @@ def test_create_job_verify_timer_call_when_repeatable(self): self.assert_schedule_calls([(job_prototype, get_job_path(job_prototype), mocked_now_epoch + 1468703)]) + def test_create_job_when_ui_values_mapping(self): + job_prototype = create_job( + id='1', + parameter_values={'p1': 'mpd', 'param_2': []}, + repeatable=False) + self.call_create_job(job_prototype) + + self.assert_schedule_calls([(job_prototype, get_job_path(job_prototype), mocked_now_epoch + 5)]) + def call_create_job(self, job: SchedulingJob): return self.schedule_service.create_job( job.script_name, @@ -292,13 +301,18 @@ def test_scheduler_runner_when_stopped(self): class TestScheduleServiceExecuteJob(ScheduleServiceTestCase): def test_execute_simple_job(self): - job = create_job(id=1, repeatable=False, start_datetime=mocked_now - timedelta(seconds=1)) + job = create_job( + id=1, + repeatable=False, + start_datetime=mocked_now - timedelta(seconds=1), + parameter_values={'p1': 'mpd', 'param_2': ['hello', '3']}) + job_path = save_job(job) self.schedule_service._execute_job(job, job_path) - self.execution_service.start_script.assert_called_once_with( - ANY, job.parameter_values, job.user) + self.verify_start_script_call({'p1': 'bingo!', 'param_2': ['hello', '3']}, job.user) + self.execution_service.add_finish_listener.assert_not_called() self.assert_schedule_calls([]) @@ -312,8 +326,7 @@ def test_execute_repeatable_job(self): self.schedule_service._execute_job(job, job_path) - self.execution_service.start_script.assert_called_once_with( - ANY, job.parameter_values, job.user) + self.verify_start_script_call(job.parameter_values, job.user) self.execution_service.add_finish_listener.assert_not_called() self.assert_schedule_calls([(job, job_path, mocked_now_epoch + 86399)]) @@ -394,8 +407,7 @@ def add_finish_listener(callback_param, execution_id): self.schedule_service._execute_job(job, job_path) - self.execution_service.start_script.assert_called_once_with( - ANY, job.parameter_values, job.user) + self.verify_start_script_call(job.parameter_values, job.user) self.execution_service.cleanup_execution.assert_not_called() self.assertIsNotNone(finish_callback) @@ -404,6 +416,12 @@ def add_finish_listener(callback_param, execution_id): self.execution_service.cleanup_execution.assert_called_once_with(ANY, job.user) + def verify_start_script_call(self, expected_values, expected_user): + start_args = self.execution_service.start_script.call_args.args + self.assertEqual(expected_user, start_args[1]) + actual_values = {name: value.mapped_script_value for name, value in start_args[0].parameter_values.items()} + self.assertEqual(expected_values, actual_values) + def create_job(id=None, user_id='UserX', diff --git a/src/tests/script_config_test.py b/src/tests/script_config_test.py index d8856988..1acfa8c6 100644 --- a/src/tests/script_config_test.py +++ b/src/tests/script_config_test.py @@ -8,9 +8,10 @@ from config.exceptions import InvalidConfigException from model.script_config import ConfigModel, InvalidValueException, TemplateProperty, ParameterNotFoundException, \ get_sorted_config +from model.value_wrapper import ScriptValueWrapper from react.properties import ObservableDict, ObservableList from tests import test_utils -from tests.test_utils import create_script_param_config, create_parameter_model, create_files +from tests.test_utils import create_script_param_config, create_parameter_model, create_files, wrap_values from utils import file_utils, custom_json from utils.process_utils import ExecutionException @@ -58,14 +59,15 @@ def test_create_with_parameter(self): def test_create_with_parameters_and_default_values(self): parameters = [create_script_param_config('param1', default='123'), create_script_param_config('param2'), - create_script_param_config('param3', default='A')] + create_script_param_config('param3', default='A', values_ui_mapping={'A': 'Value 1'})] config_model = _create_config_model('conf_with_defaults', parameters=parameters) self.assertEqual(3, len(config_model.parameters)) values = config_model.parameter_values - self.assertEqual('123', values.get('param1')) - self.assertIsNone(values.get('param2')) - self.assertEqual('A', values.get('param3')) + self.assertEqual('123', values.get('param1').mapped_script_value) + self.assertIsNone(values.get('param2').mapped_script_value) + self.assertEqual('A', values.get('param3').mapped_script_value) + self.assertEqual('Value 1', values.get('param3').user_value) def test_create_with_parameters_and_custom_values(self): parameters = [create_script_param_config('param1', default='def1'), @@ -77,9 +79,9 @@ def test_create_with_parameters_and_custom_values(self): self.assertEqual(3, len(config_model.parameters)) values = config_model.parameter_values - self.assertEqual('123', values.get('param1')) - self.assertIsNone(values.get('param2')) - self.assertEqual(True, values.get('param3')) + self.assertEqual('123', values.get('param1').mapped_script_value) + self.assertIsNone(values.get('param2').mapped_script_value) + self.assertEqual(True, values.get('param3').mapped_script_value) def test_create_with_missing_dependant_parameter(self): parameters = [create_script_param_config('param1', type='list', values_script='echo ${p2}')] @@ -126,7 +128,8 @@ def test_set_value(self): config_model = _create_config_model('conf_x', parameters=[param1]) config_model.set_param_value('param1', 'abc') - self.assertEqual({'param1': 'abc'}, config_model.parameter_values) + expected_values = wrap_values(config_model.parameters, {'param1': 'abc'}) + self.assertEqual(expected_values, config_model.parameter_values) def test_set_value_for_unknown_parameter(self): param1 = create_script_param_config('param1') @@ -136,6 +139,45 @@ def test_set_value_for_unknown_parameter(self): self.assertNotIn('PAR_2', config_model.parameter_values) + @parameterized.expand([ + ('list', 'ABC', 'abc'), + ('list', 'qwerty', 'def'), + ('list', 'xyz', 'xyz'), + ('PARAM_TYPE_MULTISELECT', ['ABC', 'qwerty', 'xyz'], ['abc', 'def', 'xyz']), + ('PARAM_TYPE_MULTISELECT', ['ABC'], ['abc']), + ('PARAM_TYPE_MULTISELECT', ['xyz'], ['xyz']), + ('PARAM_TYPE_MULTISELECT', [], []), + ('int', 1, 'One'), + ('int', 2, 2), + ]) + def test_set_value_when_mapping(self, param_type, user_value, expected_script_value): + parameters = [ + create_script_param_config('p1', + type=param_type, + allowed_values=['abc', 'def', 'ghi', 'xyz'], + values_ui_mapping={'abc': 'ABC', 'def': 'qwerty', 'One': '1'})] + + config_model = _create_config_model('config', parameters=parameters) + config_model.set_param_value('p1', user_value) + + value = config_model.parameter_values['p1'] + self.assertEqual( + (user_value, expected_script_value), + (value.user_value, value.mapped_script_value)) + + def test_set_value_when_mapping_and_wrong_value(self): + parameters = [ + create_script_param_config('p1', + type='list', + allowed_values=['abc', 'def', 'ghi', 'xyz'], + values_ui_mapping={'abc': 'ABC', 'def': 'qwerty', 'One': '1'})] + + config_model = _create_config_model('config', parameters=parameters) + self.assertRaises(InvalidValueException, config_model.set_param_value, 'p1', 'abc') + + value = config_model.parameter_values['p1'] + self.assertEqual(ScriptValueWrapper(None, None, None), value) + def test_set_all_values_when_dependant_before_required(self): parameters = [ create_script_param_config('dep_p2', type='list', values_script='echo "X${p1}X"'), @@ -146,9 +188,10 @@ def test_set_all_values_when_dependant_before_required(self): values = {'dep_p2': 'XabcX', 'p1': 'abc'} config_model.set_all_param_values(values) - self.assertEqual(values, config_model.parameter_values) + expected_values = wrap_values(config_model.parameters, values) + self.assertEqual(expected_values, config_model.parameter_values) - def test_set_all_values_when_dependants_cylce(self): + def test_set_all_values_when_dependants_cycle(self): parameters = [ create_script_param_config('p1', type='list', values_script='echo "X${p2}X"'), create_script_param_config('p2', type='list', values_script='echo "X${p1}X"')] @@ -168,7 +211,47 @@ def test_set_all_values_with_normalization(self): config_model = _create_config_model('config', parameters=parameters) config_model.set_all_param_values({'p1': '', 'p2': ['def'], 'p3': 'abc'}) - self.assertEqual({'p1': [], 'p2': ['def'], 'p3': ['abc']}, config_model.parameter_values) + expected_values = wrap_values(config_model.parameters, {'p1': [], 'p2': ['def'], 'p3': ['abc']}) + self.assertEqual(expected_values, config_model.parameter_values) + + def test_set_all_values_when_dependant_and_ui_mapping(self): + parameters = [ + create_script_param_config('dep_p2', type='list', values_script='echo "X${p1}X"'), + create_script_param_config('p1', values_ui_mapping={'abc': 'qwerty'})] + + config_model = _create_config_model('main_conf', parameters=parameters) + + values = {'dep_p2': 'XabcX', 'p1': 'qwerty'} + config_model.set_all_param_values(values) + + expected_values = wrap_values(config_model.parameters, values) + self.assertEqual(expected_values, config_model.parameter_values) + + @parameterized.expand([ + ('list', 'ABC', 'abc'), + ('list', 'qwerty', 'def'), + ('list', 'xyz', 'xyz'), + ('PARAM_TYPE_MULTISELECT', ['ABC', 'qwerty', 'xyz'], ['abc', 'def', 'xyz']), + ('PARAM_TYPE_MULTISELECT', ['ABC'], ['abc']), + ('PARAM_TYPE_MULTISELECT', ['xyz'], ['xyz']), + ('PARAM_TYPE_MULTISELECT', [], []), + ('int', 1, 'One'), + ('int', 2, 2), + ]) + def test_set_all_values_when_mapping(self, param_type, user_value, expected_script_value): + parameters = [ + create_script_param_config('p1', + type=param_type, + allowed_values=['abc', 'def', 'ghi', 'xyz'], + values_ui_mapping={'abc': 'ABC', 'def': 'qwerty', 'One': '1'})] + + config_model = _create_config_model('config', parameters=parameters) + config_model.set_all_param_values({'p1': user_value}) + + value = config_model.parameter_values['p1'] + self.assertEqual( + (user_value, expected_script_value), + (value.user_value, value.mapped_script_value)) class ConfigModelListFilesTest(unittest.TestCase): @@ -514,7 +597,8 @@ def test_set_all_values_for_included(self): values = {'p1': included_path, 'included_param1': 'X', 'included_param2': 123} config_model.set_all_param_values(values) - self.assertEqual(values, config_model.parameter_values) + expected_values = wrap_values(config_model.parameters, values) + self.assertEqual(expected_values, config_model.parameter_values) def test_set_all_values_for_dependant_on_constant(self): included_path = test_utils.write_script_config({'parameters': [ @@ -529,14 +613,15 @@ def test_set_all_values_for_dependant_on_constant(self): values = {'included_param1': 't2'} config_model.set_all_param_values(values) - self.assertEqual({'included_param1': 't2', 'p1': 't1\nt2\nt3'}, config_model.parameter_values) + expected_values = wrap_values(config_model.parameters, {'included_param1': 't2', 'p1': 't1\nt2\nt3'}) + self.assertEqual(expected_values, config_model.parameter_values) def test_dynamic_include_add_parameter_with_default(self): (config_model, included_path) = self.prepare_config_model_with_included([ create_script_param_config('included_param', default='abc 123') ], 'p1') - self.assertEqual('abc 123', config_model.parameter_values.get('included_param')) + self.assertEqual('abc 123', config_model.parameter_values.get('included_param').mapped_script_value) def test_dynamic_include_add_parameter_with_default_when_value_exist(self): (config_model, included_path) = self.prepare_config_model_with_included([ @@ -546,10 +631,10 @@ def test_dynamic_include_add_parameter_with_default_when_value_exist(self): config_model.set_param_value('included_param', 'def 456') config_model.set_param_value('p1', 'random value') - self.assertEqual('def 456', config_model.parameter_values.get('included_param')) + self.assertEqual('def 456', config_model.parameter_values.get('included_param').mapped_script_value) config_model.set_param_value('p1', included_path) - self.assertEqual('def 456', config_model.parameter_values.get('included_param')) + self.assertEqual('def 456', config_model.parameter_values.get('included_param').mapped_script_value) def test_dynamic_include_add_2_parameters_with_default_when_one_dependant(self): (config_model, included_path) = self.prepare_config_model_with_included([ @@ -558,8 +643,8 @@ def test_dynamic_include_add_2_parameters_with_default_when_one_dependant(self): values_script='echo x${included_param1}x'), ], 'p1') - self.assertEqual('ABC', config_model.parameter_values.get('included_param1')) - self.assertEqual('xABCx', config_model.parameter_values.get('included_param2')) + self.assertEqual('ABC', config_model.parameter_values.get('included_param1').mapped_script_value) + self.assertEqual('xABCx', config_model.parameter_values.get('included_param2').mapped_script_value) dependant_parameter = config_model.find_parameter('included_param2') self.assertEqual(['xABCx'], dependant_parameter.values) @@ -667,13 +752,15 @@ def test_multiple_required_parameters_when_one_missing_and_skip_invalid(self): valid = self._validate(script_config, values, skip_invalid_parameters=True) self.assertTrue(valid) - self.assertEqual({ - 'param0': '0', - 'param1': '1', - 'param2': '2', - 'param3': None, - 'param4': '4'}, - script_config.parameter_values) + + expected_values = wrap_values( + script_config.parameters, + {'param0': '0', + 'param1': '1', + 'param2': '2', + 'param3': None, + 'param4': '4'}) + self.assertEqual(expected_values, script_config.parameter_values) def test_multiple_parameters_when_all_defined(self): values = {} @@ -868,7 +955,7 @@ def add_parameter(self, config): self.parameters.append(config) def set_value(self, name, value): - self.values[name] = value + self.values[name] = ScriptValueWrapper(value, value, value) class GetSortedConfigTest(unittest.TestCase): diff --git a/src/tests/test_utils.py b/src/tests/test_utils.py index 0cc3f26e..db70bb96 100644 --- a/src/tests/test_utils.py +++ b/src/tests/test_utils.py @@ -14,6 +14,7 @@ from execution.process_base import ProcessWrapper from model.script_config import ConfigModel, ParameterModel from model.server_conf import LoggingConfig +from model.value_wrapper import ScriptValueWrapper from react.observable import read_until_closed from react.properties import ObservableDict, ObservableList from utils import audit_utils @@ -156,83 +157,45 @@ def create_script_param_config( pass_as=None, stdin_expected_text=None, ui_separator_type=None, - ui_separator_title=None): + ui_separator_title=None, + values_ui_mapping=None): + method_params = dict(locals()) conf = {'name': param_name} - if type is not None: - conf['type'] = type + simple_options = { + 'type': 'type', + 'default': 'default', + 'required': 'required', + 'secure': 'secure', + 'param': 'param', + 'env_var': 'env_var', + 'no_value': 'no_value', + 'constant': 'constant', + 'multiselect_separator': 'separator', + 'multiselect_argument_type': 'multiselect_argument_type', + 'min': 'min', + 'max': 'max', + 'file_dir': 'file_dir', + 'file_recursive': 'file_recursive', + 'file_extensions': 'file_extensions', + 'file_type': 'file_type', + 'excluded_files': 'excluded_files', + 'same_arg_param': 'same_arg_param', + 'regex': 'regex', + 'max_length': 'max_length', + 'pass_as': 'pass_as', + 'stdin_expected_text': 'stdin_expected_text', + 'values_ui_mapping': 'values_ui_mapping', + } if values_script is not None: conf['values'] = {'script': values_script} if values_script_shell is not None: conf['values']['shell'] = values_script_shell - if default is not None: - conf['default'] = default - - if required is not None: - conf['required'] = required - - if secure is not None: - conf['secure'] = secure - - if param is not None: - conf['param'] = param - - if env_var is not None: - conf['env_var'] = env_var - - if no_value is not None: - conf['no_value'] = no_value - - if constant is not None: - conf['constant'] = constant - - if multiselect_separator is not None: - conf['separator'] = multiselect_separator - - if multiselect_argument_type is not None: - conf['multiselect_argument_type'] = multiselect_argument_type - - if min is not None: - conf['min'] = min - - if max is not None: - conf['max'] = max - if allowed_values is not None: conf['values'] = list(allowed_values) - if file_dir is not None: - conf['file_dir'] = file_dir - - if file_recursive is not None: - conf['file_recursive'] = file_recursive - - if file_extensions is not None: - conf['file_extensions'] = file_extensions - - if file_type is not None: - conf['file_type'] = file_type - - if excluded_files is not None: - conf['excluded_files'] = excluded_files - - if same_arg_param is not None: - conf['same_arg_param'] = same_arg_param - - if regex is not None: - conf['regex'] = regex - - if max_length is not None: - conf['max_length'] = max_length - - if pass_as is not None: - conf['pass_as'] = pass_as - - if stdin_expected_text is not None: - conf['stdin_expected_text'] = stdin_expected_text - if ui_separator_type or ui_separator_title: separator_conf = {} conf['ui'] = {'separator_before': separator_conf} @@ -241,6 +204,11 @@ def create_script_param_config( if ui_separator_title: separator_conf['title'] = ui_separator_title + for param_name, conf_name in simple_options.items(): + value = method_params[param_name] + if value is not None: + conf[conf_name] = value + return conf @@ -291,7 +259,7 @@ def create_config_model(name, *, model = ConfigModel(result_config, path, username, audit_name, process_invoker) if parameter_values is not None: - model.set_all_param_values(model) + model.set_all_param_values(parameter_values) return model @@ -324,7 +292,8 @@ def create_parameter_model(name=None, pass_as=None, stdin_expected_text=None, ui_separator_type=None, - ui_separator_title=None): + ui_separator_title=None, + values_ui_mapping=None): config = create_script_param_config( name, type=type, @@ -349,7 +318,8 @@ def create_parameter_model(name=None, pass_as=pass_as, stdin_expected_text=stdin_expected_text, ui_separator_type=ui_separator_type, - ui_separator_title=ui_separator_title) + ui_separator_title=ui_separator_title, + values_ui_mapping=values_ui_mapping) if all_parameters is None: all_parameters = [] @@ -528,6 +498,29 @@ def wait_and_read(process_wrapper): return ''.join(read_until_closed(process_wrapper.output_stream)) +def wrap_values(parameters, values): + parameters_dict = {p.name: p for p in parameters} + + result = {} + for name, value in values.items(): + parameter = parameters_dict.get(name) + if parameter: + value_wrapper = parameter.create_value_wrapper(value) + else: + value_wrapper = ScriptValueWrapper(value, value, value) + result[name] = value_wrapper + + return result + + +def validate_value(parameter, value): + return parameter.validate_value(parameter.create_value_wrapper(value)) + + +def get_default_value(parameter_model): + return parameter_model.create_value_wrapper_for_default().mapped_script_value + + class _MockProcessWrapper(ProcessWrapper): def __init__(self, executor, command, working_directory, env_variables): super().__init__(command, working_directory, env_variables) diff --git a/src/tests/web/script_config_socket_test.py b/src/tests/web/script_config_socket_test.py index 65639d08..74a4be46 100644 --- a/src/tests/web/script_config_socket_test.py +++ b/src/tests/web/script_config_socket_test.py @@ -54,7 +54,7 @@ def test_reload_model(self): self.socket.write_message(json.dumps({ 'event': 'reloadModelValues', 'data': {'clientModelId': 'abcd', - 'parameterValues': {'list 1': 'A', 'file 1': 'z', 'list 2': 'z1.txt'}}})) + 'parameterValues': {'list 1': 'Value a', 'file 1': 'z', 'list 2': 'z1.txt'}}})) response = yield self.socket.read_message() self.assertIsNone(self.socket.close_reason) @@ -106,7 +106,7 @@ def test_client_version(self): self.socket.write_message(json.dumps({ 'event': 'reloadModelValues', 'data': {'clientModelId': 'abcd', - 'parameterValues': {'list 1': 'A', 'file 1': 'y', 'list 2': 'y1.txt'}, + 'parameterValues': {'list 1': 'Value a', 'file 1': 'y', 'list 2': 'y1.txt'}, 'clientStateVersion': 7}})) self._assert_parameter_change((yield self.socket.read_message()), @@ -258,13 +258,15 @@ def setUp(self): 'parameters': [ test_utils.create_script_param_config('text 1', required=True), test_utils.create_script_param_config('list 1', type='list', - allowed_values=['A', 'B', 'C']), + allowed_values=['A', 'B', 'C'], + values_ui_mapping={'A': 'Value a', 'C': 'Customer'}, + default='C'), test_utils.create_script_param_config('file 1', type='server_file', file_dir=test1_files_path, ui_separator_type='line', ui_separator_title='Some title'), test_utils.create_script_param_config('list 2', type='list', - values_script='ls ' + test1_files_path + '/${file 1}') + values_script='ls ${file 1}') ]}, 'test_script_1') @@ -301,8 +303,8 @@ def _text1(): def _list1(): - return {'name': 'list 1', 'description': None, 'withoutValue': False, 'required': False, 'default': None, - 'type': 'list', 'min': None, 'max': None, 'max_length': None, 'values': ['A', 'B', 'C'], + return {'name': 'list 1', 'description': None, 'withoutValue': False, 'required': False, 'default': 'Customer', + 'type': 'list', 'min': None, 'max': None, 'max_length': None, 'values': ['Value a', 'B', 'Customer'], 'secure': False, 'fileRecursive': False, 'fileType': None, 'requiredParameters': [], 'regex': None, diff --git a/src/web/server.py b/src/web/server.py index 160b4318..82fd1848 100755 --- a/src/web/server.py +++ b/src/web/server.py @@ -409,10 +409,9 @@ def post(self, user): all_audit_names = user.audit_names LOGGER.info('Calling script %s. User %s', script_name, all_audit_names) - execution_id = self.application.execution_service.start_script( - config_model, - parameter_values, - user) + config_model.set_all_param_values(parameter_values) + + execution_id = self.application.execution_service.start_script(config_model, user) self.write(str(execution_id)) diff --git a/web-src/src/admin/components/scripts-config/ParameterConfigForm.vue b/web-src/src/admin/components/scripts-config/ParameterConfigForm.vue index 27c397a6..0fb93114 100644 --- a/web-src/src/admin/components/scripts-config/ParameterConfigForm.vue +++ b/web-src/src/admin/components/scripts-config/ParameterConfigForm.vue @@ -124,16 +124,23 @@ @error="handleError(uiSeparatorTitleField, $event)"/> + +