Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rasa train nlu accepts domain #6913

Merged
merged 15 commits into from
Oct 7, 2020
Merged
19 changes: 19 additions & 0 deletions data/test_config/config_defaults.yml
Original file line number Diff line number Diff line change
@@ -1,4 +1,23 @@
language: en
pipeline: []
# # No configuration for the NLU pipeline was provided. The following default pipeline was used to train your model.
# # If you'd like to customize it, uncomment and adjust the pipeline.
# # See https://rasa.com/docs/rasa/tuning-your-model for more information.
# - name: WhitespaceTokenizer
# - name: RegexFeaturizer
# - name: LexicalSyntacticFeaturizer
# - name: CountVectorsFeaturizer
# - name: CountVectorsFeaturizer
# analyzer: char_wb
# min_ngram: 1
# max_ngram: 4
# - name: DIETClassifier
# epochs: 100
# - name: EntitySynonymMapper
# - name: ResponseSelector
# epochs: 100
# - name: FallbackClassifier
# threshold: 0.3
# ambiguity_threshold: 0.1

data:
6 changes: 6 additions & 0 deletions data/test_config/config_response_selector_minimal.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
language: en
pipeline:
- name: WhitespaceTokenizer
- name: CountVectorsFeaturizer
- name: ResponseSelector
epochs: 1
degiz marked this conversation as resolved.
Show resolved Hide resolved
10 changes: 10 additions & 0 deletions data/test_nlu_no_responses/domain_with_only_responses.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
responses:
utter_chitchat/ask_name:
- image: "https://i.imgur.com/zTvA58i.jpeg"
text: hello, my name is retrieval bot.
- text: Oh yeah, I am called the retrieval bot.

utter_chitchat/ask_weather:
- text: Oh, it does look sunny right now in Berlin.
image: "https://i.imgur.com/vwv7aHN.png"
degiz marked this conversation as resolved.
Show resolved Hide resolved
- text: I am not sure of the whole week but I can see the sun is out today.
16 changes: 16 additions & 0 deletions data/test_nlu_no_responses/nlu_no_responses.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
version: "2.0"

nlu:
- intent: chitchat/ask_name
examples: |
- What is your name?
- May I know your name?
- What do people call you?
- Do you have a name for yourself?

- intent: chitchat/ask_weather
examples: |
- What's the weather like today?
- Does it look sunny outside today?
- Oh, do you mind checking the weather for me please?
- I like sunny days in Berlin.
5 changes: 3 additions & 2 deletions rasa/cli/arguments/default_arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,13 +49,14 @@ def add_nlu_data_param(


def add_domain_param(
parser: Union[argparse.ArgumentParser, argparse._ActionsContainer]
parser: Union[argparse.ArgumentParser, argparse._ActionsContainer],
default: Optional[Text] = DEFAULT_DOMAIN_PATH,
) -> None:
parser.add_argument(
"-d",
"--domain",
type=str,
default=DEFAULT_DOMAIN_PATH,
default=default,
help="Domain specification. This can be a single YAML file, or a directory "
"that contains several files with domain specifications in it. The content "
"of these files will be read and merged together.",
Expand Down
1 change: 1 addition & 0 deletions rasa/cli/arguments/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def set_train_core_arguments(parser: argparse.ArgumentParser):

def set_train_nlu_arguments(parser: argparse.ArgumentParser):
add_config_param(parser)
add_domain_param(parser, default=None)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why making the default None? get_validated_path does a lot of magic and would automatically set it to None in case the default path does not exist.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So that it's None in rasa train nlu --help, because we don't want to change the default behaviour.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why don't we wanna join the default behavior? And it doesn't change the default behavior anyway, right?

add_out_param(parser, help_text="Directory where your models should be stored.")

add_nlu_data_param(parser, help_text="File or folder containing your NLU data.")
Expand Down
6 changes: 6 additions & 0 deletions rasa/cli/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -143,6 +143,11 @@ def train_nlu(
args.nlu, "nlu", DEFAULT_DATA_PATH, none_is_valid=True
)

if args.domain:
args.domain = rasa.cli.utils.get_validated_path(
args.domain, "domain", DEFAULT_DOMAIN_PATH, none_is_valid=True
)

return train_nlu(
config=config,
nlu_data=nlu_data,
Expand All @@ -151,6 +156,7 @@ def train_nlu(
fixed_model_name=args.fixed_model_name,
persist_nlu_training_data=args.persist_nlu_data,
additional_arguments=extract_nlu_additional_arguments(args),
domain=args.domain,
)
wochinge marked this conversation as resolved.
Show resolved Hide resolved


Expand Down
6 changes: 5 additions & 1 deletion rasa/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,6 +434,7 @@ def train_nlu(
fixed_model_name: Optional[Text] = None,
persist_nlu_training_data: bool = False,
additional_arguments: Optional[Dict] = None,
domain: Optional[Union[Domain, Text]] = None,
) -> Optional[Text]:
"""Trains an NLU model.

Expand All @@ -448,6 +449,7 @@ def train_nlu(
with the model.
additional_arguments: Additional training parameters which will be passed to
the `train` method of each component.
domain: Path to the optional domain file/Domain object.


Returns:
Expand All @@ -465,6 +467,7 @@ def train_nlu(
fixed_model_name,
persist_nlu_training_data,
additional_arguments,
domain=domain,
)
)

Expand All @@ -477,6 +480,7 @@ async def _train_nlu_async(
fixed_model_name: Optional[Text] = None,
persist_nlu_training_data: bool = False,
additional_arguments: Optional[Dict] = None,
domain: Optional[Union[Domain, Text]] = None,
):
degiz marked this conversation as resolved.
Show resolved Hide resolved
if not nlu_data:
print_error(
Expand All @@ -487,7 +491,7 @@ async def _train_nlu_async(

# training NLU only hence the training files still have to be selected
file_importer = TrainingDataImporter.load_nlu_importer_from_config(
config, training_data_paths=[nlu_data]
config, domain, training_data_paths=[nlu_data]
)

training_data = await file_importer.get_nlu_data()
Expand Down
48 changes: 48 additions & 0 deletions tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@
import rasa.core
import rasa.shared.importers.autoconfig as autoconfig
from rasa.core.interpreter import RasaNLUInterpreter
from rasa.shared.core.domain import Domain
from rasa.shared.importers.importer import TrainingDataImporter

from rasa.train import train_core, train_nlu, train
from tests.conftest import DEFAULT_CONFIG_PATH, DEFAULT_NLU_DATA
Expand Down Expand Up @@ -164,6 +166,52 @@ def test_train_nlu_wrong_format_error_message(
assert "Please verify the data format" in captured.out


def test_train_nlu_with_responses_no_domain_warns(
tmp_path: Path, monkeypatch: MonkeyPatch,
degiz marked this conversation as resolved.
Show resolved Hide resolved
):
(tmp_path / "training").mkdir()
degiz marked this conversation as resolved.
Show resolved Hide resolved
(tmp_path / "models").mkdir()
degiz marked this conversation as resolved.
Show resolved Hide resolved

data_path = "data/test_nlu_no_responses/nlu_no_responses.yml"

with pytest.warns(UserWarning) as records:
train_nlu(
"data/test_config/config_response_selector_minimal.yml",
data_path,
output=str(tmp_path / "models"),
)

assert any(
"You either need to add a response phrase or correct the intent"
in record.message.args[0]
for record in records
)


def test_train_nlu_with_responses_and_domain_no_warns(
tmp_path: Path, monkeypatch: MonkeyPatch,
degiz marked this conversation as resolved.
Show resolved Hide resolved
):
(tmp_path / "training").mkdir()
(tmp_path / "models").mkdir()

data_path = "data/test_nlu_no_responses/nlu_no_responses.yml"
domain_path = "data/test_nlu_no_responses/domain_with_only_responses.yml"

with pytest.warns(None) as records:
train_nlu(
"data/test_config/config_response_selector_minimal.yml",
data_path,
output=str(tmp_path / "models"),
domain=domain_path,
)

assert not any(
"You either need to add a response phrase or correct the intent"
in record.message.args[0]
for record in records
)


def test_train_nlu_no_nlu_file_error_message(
capsys: CaptureFixture,
tmp_path: Text,
Expand Down