Skip to content

Commit

Permalink
implement decorator to register class with recipe
Browse files Browse the repository at this point in the history
  • Loading branch information
wochinge committed Aug 23, 2021
1 parent 8070d6f commit 655ecfb
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 5 deletions.
59 changes: 57 additions & 2 deletions rasa/engine/recipes/default_recipe.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from __future__ import annotations

import copy
import enum
import logging
from typing import Dict, Text, Any, Tuple, Type, Optional, List
from enum import Enum
from typing import Dict, Text, Any, Tuple, Type, Optional, List, Callable

import dataclasses

Expand Down Expand Up @@ -39,11 +41,14 @@
from rasa.nlu.tokenizers.tokenizer import Tokenizer
from rasa.nlu.utils.mitie_utils import MitieNLP
from rasa.nlu.utils.spacy_utils import SpacyNLP
from rasa.shared.exceptions import RasaException
from rasa.shared.importers.autoconfig import TrainingType
from rasa.utils.tensorflow.constants import EPOCHS

# from rasa.utils.tensorflow.constants import EPOCHS

logger = logging.getLogger(__name__)


# TODO: Remove once they are implemented
class ProjectProvider(GraphComponent):
pass
Expand Down Expand Up @@ -143,13 +148,63 @@ class TrackerToMessageConverter(GraphComponent):
}


class DefaultV1RecipeRegisterException(RasaException):
"""If you register a class which is not of type `GraphComponent`."""

pass


class DefaultV1Recipe(Recipe):
"""Recipe which converts the normal model config to train and predict graph."""

@enum.unique
class ComponentType(Enum):
"""Enum to categorize and place custom components correctly in the graph."""

MESSAGE_TOKENIZER = 0
TRAINABLE_MESSAGE_FEATURIZER = 1
PRETRAINED_MESSAGE_FEATURIZER = 2
TRAINABLE_INTENT_CLASSIFIER = 3
PRETRAINED_INTENT_CLASSIFIER = 4
TRAINABLE_ENTITY_EXTRACTOR = 5
PRETRAINED_ENTITY_EXTRACTOR = 6
POLICY_WITH_END_TO_END_SUPPORT = 7
MODEL_LOADER = 8

name = "default.v1"
registered_components: Dict[Type[GraphComponent], ComponentType] = {}

def __init__(self) -> None:
"""Creates recipe."""
self._use_core = True
self._use_nlu = True

@classmethod
def register(
cls, component_type: ComponentType
) -> Callable[[Type[GraphComponent]], Type[GraphComponent]]:
def f(registered_class: Type[GraphComponent]) -> Type[GraphComponent]:
if not issubclass(registered_class, GraphComponent):
raise DefaultV1RecipeRegisterException(
f"Failed to register class '{registered_class.__name__}' with "
f"the recipe '{cls.name}'. The class has to be of type "
f"'{GraphComponent.__name__}'."
)
cls.registered_components[registered_class] = component_type
return registered_class

return f

@classmethod
def _type_for(cls, clazz: Type[GraphComponent]) -> Optional[ComponentType]:
if clazz not in cls.registered_components:
raise DefaultV1RecipeRegisterException(
f"It seems that '{clazz.__name__}' has not been registered with the "
f"'{cls.name}' recipe. Please make sure to register your class by "
f"using the decorator 'DefaultV1Recipe.register(...)' with your class."
)
return cls.registered_components.get(clazz)

def schemas_for_config(
self,
config: Dict,
Expand Down
45 changes: 42 additions & 3 deletions tests/engine/recipes/test_default_recipe.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
from typing import Text
from typing import Text, Dict, Any

import pytest

import rasa.shared.utils.io
from rasa.engine.graph import GraphSchema
from rasa.engine.recipes.default_recipe import DefaultV1Recipe
from rasa.engine.graph import GraphSchema, GraphComponent, ExecutionContext
from rasa.engine.recipes.default_recipe import (
DefaultV1Recipe,
DefaultV1RecipeRegisterException,
)
from rasa.engine.recipes.recipe import Recipe
from rasa.engine.storage.resource import Resource
from rasa.engine.storage.storage import ModelStorage
from rasa.nlu.classifiers.mitie_intent_classifier import MitieIntentClassifier
from rasa.nlu.classifiers.sklearn_intent_classifier import SklearnIntentClassifier
from rasa.nlu.extractors.mitie_entity_extractor import MitieEntityExtractor
Expand Down Expand Up @@ -251,3 +256,37 @@ def test_num_threads_interpolation():
# assert predict_schema.nodes[node_name] == node
#
# assert predict_schema == expected_predict_schema


def test_register_component():
@DefaultV1Recipe.register(DefaultV1Recipe.ComponentType.TOKENIZER)
class MyClass(GraphComponent):
@classmethod
def create(
cls,
config: Dict[Text, Any],
model_storage: ModelStorage,
resource: Resource,
execution_context: ExecutionContext,
) -> GraphComponent:
return cls()

assert DefaultV1Recipe._type_for(MyClass) == DefaultV1Recipe.ComponentType.TOKENIZER
assert MyClass()


def test_register_invalid_component():
with pytest.raises(DefaultV1RecipeRegisterException):

@DefaultV1Recipe.register(DefaultV1Recipe.ComponentType.TOKENIZER)
class MyClass:
pass


def test_retrieve_not_registered_class():
class MyClass:
pass

with pytest.raises(DefaultV1RecipeRegisterException):
# noinspection PyTypeChecker
DefaultV1Recipe._type_for(MyClass)

0 comments on commit 655ecfb

Please sign in to comment.