From d0268a9aa7462b49881af25ba6bdb17fb3f366d3 Mon Sep 17 00:00:00 2001 From: Maximilian Ludvigsson Date: Fri, 17 Jan 2020 09:30:37 +0800 Subject: [PATCH 1/4] add `--num-threads` argument to train cli --- rasa/cli/arguments/train.py | 15 ++++++++++++ rasa/cli/train.py | 17 ++++++++++--- rasa/train.py | 48 +++++++++++++++++++++++++++---------- 3 files changed, 65 insertions(+), 15 deletions(-) diff --git a/rasa/cli/arguments/train.py b/rasa/cli/arguments/train.py index c3b4741b555e..6ac8fecd2ea2 100644 --- a/rasa/cli/arguments/train.py +++ b/rasa/cli/arguments/train.py @@ -21,6 +21,8 @@ def set_train_arguments(parser: argparse.ArgumentParser): add_debug_plots_param(parser) add_dump_stories_param(parser) + add_num_threads_param(parser) + add_model_name_param(parser) add_persist_nlu_data_param(parser) add_force_param(parser) @@ -50,6 +52,8 @@ def set_train_nlu_arguments(parser: argparse.ArgumentParser): add_nlu_data_param(parser, help_text="File or folder containing your NLU data.") + add_num_threads_param(parser) + add_model_name_param(parser) add_persist_nlu_data_param(parser) @@ -133,6 +137,17 @@ def add_debug_plots_param( ) +def add_num_threads_param( + parser: Union[argparse.ArgumentParser, argparse._ActionsContainer] +): + parser.add_argument( + "--num-threads", + type=int, + default=1, + help="Maximum amount of threads to use when training.", + ) + + def add_model_name_param(parser: argparse.ArgumentParser): parser.add_argument( "--fixed-model-name", diff --git a/rasa/cli/train.py b/rasa/cli/train.py index 0da1b5ebfd9a..e618d95ee923 100644 --- a/rasa/cli/train.py +++ b/rasa/cli/train.py @@ -73,7 +73,8 @@ def train(args: argparse.Namespace) -> Optional[Text]: force_training=args.force, fixed_model_name=args.fixed_model_name, persist_nlu_training_data=args.persist_nlu_data, - additional_arguments=extract_additional_arguments(args), + core_additional_arguments=extract_core_additional_arguments(args), + nlu_additional_arguments=extract_core_additional_arguments(args), ) @@ -92,7 +93,7 @@ def train_core( story_file = get_validated_path( args.stories, "stories", DEFAULT_DATA_PATH, none_is_valid=True ) - additional_arguments = extract_additional_arguments(args) + additional_arguments = extract_core_additional_arguments(args) # Policies might be a list for the compare training. Do normal training # if only list item was passed. @@ -138,10 +139,11 @@ def train_nlu( train_path=train_path, fixed_model_name=args.fixed_model_name, persist_nlu_training_data=args.persist_nlu_data, + additional_arguments=extract_nlu_additional_arguments(args), ) -def extract_additional_arguments(args: argparse.Namespace) -> Dict: +def extract_core_additional_arguments(args: argparse.Namespace) -> Dict: arguments = {} if "augmentation" in args: @@ -154,6 +156,15 @@ def extract_additional_arguments(args: argparse.Namespace) -> Dict: return arguments +def extract_nlu_additional_arguments(args: argparse.Namespace) -> Dict: + arguments = {} + + if "num_threads" in args: + arguments["num_threads"] = args.num_threads + + return arguments + + def _get_valid_config( config: Optional[Text], mandatory_keys: List[Text], diff --git a/rasa/train.py b/rasa/train.py index b18a3018a10d..7ccc69aa0c05 100644 --- a/rasa/train.py +++ b/rasa/train.py @@ -28,7 +28,8 @@ def train( force_training: bool = False, fixed_model_name: Optional[Text] = None, persist_nlu_training_data: bool = False, - additional_arguments: Optional[Dict] = None, + core_additional_arguments: Optional[Dict] = None, + nlu_additional_arguments: Optional[Dict] = None, loop: Optional[asyncio.AbstractEventLoop] = None, ) -> Optional[Text]: if loop is None: @@ -47,7 +48,8 @@ def train( force_training=force_training, fixed_model_name=fixed_model_name, persist_nlu_training_data=persist_nlu_training_data, - additional_arguments=additional_arguments, + core_additional_arguments=core_additional_arguments, + nlu_additional_arguments=nlu_additional_arguments, ) ) @@ -60,7 +62,8 @@ async def train_async( force_training: bool = False, fixed_model_name: Optional[Text] = None, persist_nlu_training_data: bool = False, - additional_arguments: Optional[Dict] = None, + core_additional_arguments: Optional[Dict] = None, + nlu_additional_arguments: Optional[Dict] = None, ) -> Optional[Text]: """Trains a Rasa model (Core and NLU). @@ -73,7 +76,9 @@ async def train_async( fixed_model_name: Name of model to be stored. persist_nlu_training_data: `True` if the NLU training data should be persisted with the model. - additional_arguments: Additional training parameters. + core_additional_arguments: Additional training parameters for core training. + nlu_additional_arguments: Additional training parameters forwarded to training + method of each NLU component. Returns: Path of the trained model archive. @@ -98,7 +103,8 @@ async def train_async( force_training, fixed_model_name, persist_nlu_training_data, - additional_arguments, + core_additional_arguments=core_additional_arguments, + nlu_additional_arguments=nlu_additional_arguments, ) @@ -122,7 +128,8 @@ async def _train_async_internal( force_training: bool, fixed_model_name: Optional[Text], persist_nlu_training_data: bool, - additional_arguments: Optional[Dict], + core_additional_arguments: Optional[Dict] = None, + nlu_additional_arguments: Optional[Dict] = None, ) -> Optional[Text]: """Trains a Rasa model (Core and NLU). Use only from `train_async`. @@ -131,10 +138,12 @@ async def _train_async_internal( train_path: Directory in which to train the model. output_path: Output path. force_training: If `True` retrain model even if data has not changed. + fixed_model_name: Name of model to be stored. persist_nlu_training_data: `True` if the NLU training data should be persisted with the model. - fixed_model_name: Name of model to be stored. - additional_arguments: Additional training parameters. + core_additional_arguments: Additional training parameters for core training. + nlu_additional_arguments: Additional training parameters forwarded to training + method of each NLU component. Returns: Path of the trained model archive. @@ -158,6 +167,7 @@ async def _train_async_internal( output=output_path, fixed_model_name=fixed_model_name, persist_nlu_training_data=persist_nlu_training_data, + additional_arguments=nlu_additional_arguments, ) if nlu_data.is_empty(): @@ -166,7 +176,7 @@ async def _train_async_internal( file_importer, output=output_path, fixed_model_name=fixed_model_name, - additional_arguments=additional_arguments, + additional_arguments=core_additional_arguments, ) new_fingerprint = await model.model_fingerprint(file_importer) @@ -185,7 +195,8 @@ async def _train_async_internal( fingerprint_comparison_result=fingerprint_comparison, fixed_model_name=fixed_model_name, persist_nlu_training_data=persist_nlu_training_data, - additional_arguments=additional_arguments, + core_additional_arguments=core_additional_arguments, + nlu_additional_arguments=nlu_additional_arguments, ) return model.package_model( @@ -209,7 +220,8 @@ async def _do_training( fingerprint_comparison_result: Optional[FingerprintComparisonResult] = None, fixed_model_name: Optional[Text] = None, persist_nlu_training_data: bool = False, - additional_arguments: Optional[Dict] = None, + core_additional_arguments: Optional[Dict] = None, + nlu_additional_arguments: Optional[Dict] = None, ): if not fingerprint_comparison_result: fingerprint_comparison_result = FingerprintComparisonResult() @@ -220,7 +232,7 @@ async def _do_training( output=output_path, train_path=train_path, fixed_model_name=fixed_model_name, - additional_arguments=additional_arguments, + additional_arguments=core_additional_arguments, ) elif fingerprint_comparison_result.should_retrain_nlg(): print_color( @@ -243,6 +255,7 @@ async def _do_training( train_path=train_path, fixed_model_name=fixed_model_name, persist_nlu_training_data=persist_nlu_training_data, + additional_arguments=nlu_additional_arguments, ) else: print_color( @@ -383,6 +396,7 @@ def train_nlu( train_path: Optional[Text] = None, fixed_model_name: Optional[Text] = None, persist_nlu_training_data: bool = False, + additional_arguments: Optional[Dict] = None, ) -> Optional[Text]: """Trains an NLU model. @@ -395,6 +409,8 @@ def train_nlu( fixed_model_name: Name of the model to be stored. persist_nlu_training_data: `True` if the NLU training data should be persisted with the model. + additional_arguments: Additional training parameters which will be passed to + the `train` method of each component. Returns: @@ -412,6 +428,7 @@ def train_nlu( train_path, fixed_model_name, persist_nlu_training_data, + additional_arguments, ) ) @@ -423,6 +440,7 @@ async def _train_nlu_async( train_path: Optional[Text] = None, fixed_model_name: Optional[Text] = None, persist_nlu_training_data: bool = False, + additional_arguments: Optional[Dict] = None, ): # training NLU only hence the training files still have to be selected file_importer = TrainingDataImporter.load_nlu_importer_from_config( @@ -443,6 +461,7 @@ async def _train_nlu_async( train_path=train_path, fixed_model_name=fixed_model_name, persist_nlu_training_data=persist_nlu_training_data, + additional_arguments=additional_arguments, ) @@ -452,11 +471,15 @@ async def _train_nlu_with_validated_data( train_path: Optional[Text] = None, fixed_model_name: Optional[Text] = None, persist_nlu_training_data: bool = False, + additional_arguments: Optional[Dict] = None, ) -> Optional[Text]: """Train NLU with validated training and config data.""" import rasa.nlu.train + if additional_arguments is None: + additional_arguments = {} + with ExitStack() as stack: if train_path: # If the train path was provided, do nothing on exit. @@ -472,6 +495,7 @@ async def _train_nlu_with_validated_data( _train_path, fixed_model_name="nlu", persist_nlu_training_data=persist_nlu_training_data, + **additional_arguments, ) print_color("NLU model training completed.", color=bcolors.OKBLUE) From befe4cd7316f381011adf524b5a650d50aa02ed0 Mon Sep 17 00:00:00 2001 From: Tom Bocklisch Date: Sat, 16 May 2020 00:18:38 +0200 Subject: [PATCH 2/4] Update rasa/cli/train.py --- rasa/cli/train.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/rasa/cli/train.py b/rasa/cli/train.py index e618d95ee923..386258ea7482 100644 --- a/rasa/cli/train.py +++ b/rasa/cli/train.py @@ -74,7 +74,7 @@ def train(args: argparse.Namespace) -> Optional[Text]: fixed_model_name=args.fixed_model_name, persist_nlu_training_data=args.persist_nlu_data, core_additional_arguments=extract_core_additional_arguments(args), - nlu_additional_arguments=extract_core_additional_arguments(args), + nlu_additional_arguments=extract_nlu_additional_arguments(args), ) From 26142cc8d9d3f593b81a4cc05ba3ccb02ed5c956 Mon Sep 17 00:00:00 2001 From: Tom Bocklisch Date: Sat, 16 May 2020 00:22:13 +0200 Subject: [PATCH 3/4] Create 5086.feature.rst --- changelog/5086.feature.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 changelog/5086.feature.rst diff --git a/changelog/5086.feature.rst b/changelog/5086.feature.rst new file mode 100644 index 000000000000..b0b9e47630ed --- /dev/null +++ b/changelog/5086.feature.rst @@ -0,0 +1 @@ +Added a ``--num-threads`` CLI argument that can be passed to ``rasa train`` and will be used to train NLU components. From 987b7b1ea5c4325d5a2f3959efaa4a371382eb64 Mon Sep 17 00:00:00 2001 From: Tom Bocklisch Date: Mon, 18 May 2020 11:23:31 +0200 Subject: [PATCH 4/4] fixed tests --- tests/cli/test_rasa_train.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/tests/cli/test_rasa_train.py b/tests/cli/test_rasa_train.py index d0147a575ea3..20d747c8125d 100644 --- a/tests/cli/test_rasa_train.py +++ b/tests/cli/test_rasa_train.py @@ -326,6 +326,7 @@ def test_train_help(run): help_text = """usage: rasa train [-h] [-v] [-vv] [--quiet] [--data DATA [DATA ...]] [-c CONFIG] [-d DOMAIN] [--out OUT] [--augmentation AUGMENTATION] [--debug-plots] + [--num-threads NUM_THREADS] [--fixed-model-name FIXED_MODEL_NAME] [--persist-nlu-data] [--force] {core,nlu} ...""" @@ -340,7 +341,8 @@ def test_train_nlu_help(run: Callable[..., RunResult]): output = run("train", "nlu", "--help") help_text = """usage: rasa train nlu [-h] [-v] [-vv] [--quiet] [-c CONFIG] [--out OUT] - [-u NLU] [--fixed-model-name FIXED_MODEL_NAME] + [-u NLU] [--num-threads NUM_THREADS] + [--fixed-model-name FIXED_MODEL_NAME] [--persist-nlu-data]""" lines = help_text.split("\n")