diff --git a/rasa/core/policies/ensemble.py b/rasa/core/policies/ensemble.py index 332101b51880..21ca0c7c77f0 100644 --- a/rasa/core/policies/ensemble.py +++ b/rasa/core/policies/ensemble.py @@ -29,9 +29,7 @@ ACTION_BACK_NAME, ) from rasa.shared.core.domain import InvalidDomain, Domain -from rasa.shared.core.events import ( - DefinePrevUserUtteredFeaturization, -) +from rasa.shared.core.events import DefinePrevUserUtteredFeaturization from rasa.shared.core.events import ActionExecutionRejected from rasa.core.exceptions import UnsupportedDialogueModelError from rasa.core.featurizers.tracker_featurizers import MaxHistoryTrackerFeaturizer diff --git a/rasa/core/policies/ted_policy.py b/rasa/core/policies/ted_policy.py index 9cf7be7a65f5..e66da9bec607 100644 --- a/rasa/core/policies/ted_policy.py +++ b/rasa/core/policies/ted_policy.py @@ -26,7 +26,11 @@ from rasa.shared.core.generator import TrackerWithCachedStates from rasa.utils import train_utils from rasa.utils.tensorflow.models import RasaModel, TransformerRasaModel -from rasa.utils.tensorflow.model_data import RasaModelData, FeatureSignature +from rasa.utils.tensorflow.model_data import ( + RasaModelData, + FeatureSignature, + FeatureArray, +) from rasa.utils.tensorflow.model_data_utils import convert_to_data_format from rasa.utils.tensorflow.constants import ( LABEL, @@ -264,7 +268,9 @@ def _create_label_data( label_ids = np.arange(domain.num_actions) label_data.add_features( - LABEL_KEY, LABEL_SUB_KEY, [np.expand_dims(label_ids, -1)] + LABEL_KEY, + LABEL_SUB_KEY, + [FeatureArray(np.expand_dims(label_ids, -1), number_of_dimensions=2)], ) return label_data, encoded_all_labels @@ -295,7 +301,11 @@ def _create_model_data( label_ids = np.array( [np.expand_dims(seq_label_ids, -1) for seq_label_ids in label_ids] ) - model_data.add_features(LABEL_KEY, LABEL_SUB_KEY, [label_ids]) + model_data.add_features( + LABEL_KEY, + LABEL_SUB_KEY, + [FeatureArray(label_ids, number_of_dimensions=3)], + ) attribute_data, 
self.zero_state_features = convert_to_data_format( tracker_state_features diff --git a/rasa/nlu/classifiers/diet_classifier.py b/rasa/nlu/classifiers/diet_classifier.py index 83c5ab87fe7c..2414c439755f 100644 --- a/rasa/nlu/classifiers/diet_classifier.py +++ b/rasa/nlu/classifiers/diet_classifier.py @@ -24,7 +24,11 @@ from rasa.utils import train_utils from rasa.utils.tensorflow import layers from rasa.utils.tensorflow.models import RasaModel, TransformerRasaModel -from rasa.utils.tensorflow.model_data import RasaModelData, FeatureSignature +from rasa.utils.tensorflow.model_data import ( + RasaModelData, + FeatureSignature, + FeatureArray, +) from rasa.nlu.constants import TOKENS_NAMES from rasa.shared.nlu.constants import ( TEXT, @@ -325,7 +329,7 @@ def __init__( self.model = model self._label_data: Optional[RasaModelData] = None - self._data_example: Optional[Dict[Text, List[np.ndarray]]] = None + self._data_example: Optional[Dict[Text, List[FeatureArray]]] = None @property def label_key(self) -> Optional[Text]: @@ -492,10 +496,10 @@ def _check_input_dimension_consistency(self, model_data: RasaModelData) -> None: """Checks if features have same dimensionality if hidden layers are shared.""" if self.component_config.get(SHARE_HIDDEN_LAYERS): - num_text_sentence_features = model_data.feature_dimension(TEXT, SENTENCE) - num_label_sentence_features = model_data.feature_dimension(LABEL, SENTENCE) - num_text_sequence_features = model_data.feature_dimension(TEXT, SEQUENCE) - num_label_sequence_features = model_data.feature_dimension(LABEL, SEQUENCE) + num_text_sentence_features = model_data.number_of_units(TEXT, SENTENCE) + num_label_sentence_features = model_data.number_of_units(LABEL, SENTENCE) + num_text_sequence_features = model_data.number_of_units(TEXT, SEQUENCE) + num_label_sequence_features = model_data.number_of_units(LABEL, SEQUENCE) if (0 < num_text_sentence_features != num_label_sentence_features > 0) or ( 0 < num_text_sequence_features != 
num_label_sequence_features > 0 @@ -507,7 +511,7 @@ def _check_input_dimension_consistency(self, model_data: RasaModelData) -> None: def _extract_labels_precomputed_features( self, label_examples: List[Message], attribute: Text = INTENT - ) -> Tuple[List[np.ndarray], List[np.ndarray]]: + ) -> Tuple[List[FeatureArray], List[FeatureArray]]: """Collects precomputed encodings.""" features = defaultdict(list) @@ -521,23 +525,32 @@ def _extract_labels_precomputed_features( sentence_features = [] for feature_name, feature_value in features.items(): if SEQUENCE in feature_name: - sequence_features.append(np.array(features[feature_name])) + sequence_features.append( + FeatureArray(np.array(feature_value), number_of_dimensions=3) + ) else: - sentence_features.append(np.array(features[feature_name])) + sentence_features.append( + FeatureArray(np.array(feature_value), number_of_dimensions=3) + ) - return (sequence_features, sentence_features) + return sequence_features, sentence_features @staticmethod def _compute_default_label_features( labels_example: List[Message], - ) -> List[np.ndarray]: + ) -> List[FeatureArray]: """Computes one-hot representation for the labels.""" logger.debug("No label features found. 
Computing default label features.") eye_matrix = np.eye(len(labels_example), dtype=np.float32) # add sequence dimension to one-hot labels - return [np.array([np.expand_dims(a, 0) for a in eye_matrix])] + return [ + FeatureArray( + np.array([np.expand_dims(a, 0) for a in eye_matrix]), + number_of_dimensions=3, + ) + ] def _create_label_data( self, @@ -590,16 +603,23 @@ def _create_label_data( # explicitly add last dimension to label_ids # to track correctly dynamic sequences label_data.add_features( - LABEL_KEY, LABEL_SUB_KEY, [np.expand_dims(label_ids, -1)] + LABEL_KEY, + LABEL_SUB_KEY, + [FeatureArray(np.expand_dims(label_ids, -1), number_of_dimensions=2)], ) label_data.add_lengths(LABEL, SEQUENCE_LENGTH, LABEL, SEQUENCE) return label_data - def _use_default_label_features(self, label_ids: np.ndarray) -> List[np.ndarray]: + def _use_default_label_features(self, label_ids: np.ndarray) -> List[FeatureArray]: all_label_features = self._label_data.get(LABEL, SENTENCE)[0] - return [np.array([all_label_features[label_id] for label_id in label_ids])] + return [ + FeatureArray( + np.array([all_label_features[label_id] for label_id in label_ids]), + number_of_dimensions=2, + ) + ] def _create_model_data( self, @@ -645,7 +665,11 @@ def _create_model_data( for key, attribute_features in features.items(): for sub_key, _features in attribute_features.items(): sub_key = sub_key.replace(f"{SPARSE}_", "").replace(f"{DENSE}_", "") - model_data.add_features(key, sub_key, [np.array(_features)]) + model_data.add_features( + key, + sub_key, + [FeatureArray(np.array(_features), number_of_dimensions=3)], + ) if ( label_attribute @@ -660,7 +684,9 @@ def _create_model_data( # explicitly add last dimension to label_ids # to track correctly dynamic sequences model_data.add_features( - LABEL_KEY, LABEL_SUB_KEY, [np.expand_dims(label_ids, -1)] + LABEL_KEY, + LABEL_SUB_KEY, + [FeatureArray(np.expand_dims(label_ids, -1), number_of_dimensions=2)], ) model_data.add_lengths(TEXT, SEQUENCE_LENGTH, 
TEXT, SEQUENCE) @@ -1028,7 +1054,7 @@ def _load_model( entity_tag_specs: List[EntityTagSpec], label_data: RasaModelData, meta: Dict[Text, Any], - data_example: Dict[Text, Dict[Text, List[np.ndarray]]], + data_example: Dict[Text, Dict[Text, List[FeatureArray]]], model_dir: Text, ) -> "RasaModel": file_name = meta.get("file") diff --git a/rasa/utils/tensorflow/model_data.py b/rasa/utils/tensorflow/model_data.py index 2a3cd16811ab..8bd3b0ac2721 100644 --- a/rasa/utils/tensorflow/model_data.py +++ b/rasa/utils/tensorflow/model_data.py @@ -24,6 +24,193 @@ logger = logging.getLogger(__name__) +class FeatureArray(np.ndarray): + """Stores any kind of features ready to be used by a RasaModel. + + Next to the input numpy array of features, it also receives the number of dimensions of the features. + As our features can have 1 to 4 dimensions we might have a different number of numpy arrays stacked. + The number of dimensions helps us to figure out how to handle this particular feature array. + Also, it is automatically determined whether the feature array is sparse or not and the number of units + is determined as well. 
+ + Subclassing np.array: https://numpy.org/doc/stable/user/basics.subclassing.html + """ + + def __new__( + cls, input_array: np.ndarray, number_of_dimensions: int, + ) -> "FeatureArray": + FeatureArray._validate_number_of_dimensions(number_of_dimensions, input_array) + + feature_array = np.asarray(input_array).view(cls) + + if number_of_dimensions <= 2: + feature_array.units = input_array.shape[-1] + feature_array.is_sparse = isinstance(input_array[0], scipy.sparse.spmatrix) + elif number_of_dimensions == 3: + feature_array.units = input_array[0].shape[-1] + feature_array.is_sparse = isinstance(input_array[0], scipy.sparse.spmatrix) + elif number_of_dimensions == 4: + feature_array.units = input_array[0][0].shape[-1] + feature_array.is_sparse = isinstance( + input_array[0][0], scipy.sparse.spmatrix + ) + else: + raise ValueError( + f"Number of dimensions '{number_of_dimensions}' currently not supported." + ) + + feature_array.number_of_dimensions = number_of_dimensions + + return feature_array + + def __init__(self, input_array: Any, number_of_dimensions: int, **kwargs): + # Needed in order to avoid 'Invalid keyword argument number_of_dimensions to function FeatureArray.__init__ ' + super().__init__(**kwargs) + self.number_of_dimensions = number_of_dimensions + + def __array_finalize__(self, obj: Any) -> None: + if obj is None: + return + + self.units = getattr(obj, "units", None) + self.number_of_dimensions = getattr(obj, "number_of_dimensions", None) + self.is_sparse = getattr(obj, "is_sparse", None) + + default_attributes = { + "units": self.units, + "number_of_dimensions": self.number_of_dimensions, + "is_spare": self.is_sparse, + } + self.__dict__.update(default_attributes) + + # pytype: disable=attribute-error + def __array_ufunc__(self, ufunc, method, *inputs, **kwargs): + f = { + "reduce": ufunc.reduce, + "accumulate": ufunc.accumulate, + "reduceat": ufunc.reduceat, + "outer": ufunc.outer, + "at": ufunc.at, + "__call__": ufunc, + } + # convert the inputs 
to np.ndarray to prevent recursion, call the function, then cast it back as FeatureArray + output = FeatureArray( + f[method](*(i.view(np.ndarray) for i in inputs), **kwargs), + number_of_dimensions=kwargs["number_of_dimensions"], + ) + output.__dict__ = self.__dict__ # carry forward attributes + return output + + def __reduce__(self): + # Needed in order to pickle this object + pickled_state = super(FeatureArray, self).__reduce__() + new_state = pickled_state[2] + ( + self.number_of_dimensions, + self.is_sparse, + self.units, + ) + return pickled_state[0], pickled_state[1], new_state + + def __setstate__(self, state, **kwargs): + # Needed in order to load the object + self.number_of_dimensions = state[-3] + self.is_sparse = state[-2] + self.units = state[-1] + super(FeatureArray, self).__setstate__(state[0:-3], **kwargs) + + # pytype: enable=attribute-error + + @staticmethod + def _validate_number_of_dimensions( + number_of_dimensions: int, input_array: np.ndarray + ) -> None: + """Validates if the given number of dimensions matches the dimensions of the input array. + + Args: + number_of_dimensions: number of dimensions + input_array: input array + + Raises: ValueError in case the dimensions do not match + """ + _sub_array = input_array + dim = 0 + # Go number_of_dimensions into the given input_array + for i in range(1, number_of_dimensions + 1): + _sub_array = _sub_array[0] + if isinstance(_sub_array, scipy.sparse.spmatrix): + dim = i + break + + # If the resulting sub_array is sparse, the remaining number of dimensions should be at least 2 + if isinstance(_sub_array, scipy.sparse.spmatrix): + if dim > 2: + raise ValueError( + f"Given number of dimensions '{number_of_dimensions}' does not match dimensiona of given input " + f"array: {input_array}." 
+ ) + # If the resulting sub_array is dense, the sub_array should be a single number + elif not np.issubdtype(type(_sub_array), np.integer) and not isinstance( + _sub_array, (np.float32, np.float64) + ): + raise ValueError( + f"Given number of dimensions '{number_of_dimensions}' does not match dimensiona of given input " + f"array: {input_array}." + ) + + def get_shape_type_info( + self, + ) -> Tuple[ + List[ + Union[ + int, + Tuple[None], + Tuple[None, int], + Tuple[None, None, int], + Tuple[None, None, None, int], + ] + ], + List[int], + ]: + """Returns the shape and type information needed to convert this feature array into tensors. + + Returns: + A list of shape tuples. + A list of type tuples. + """ + if self.is_sparse: + # scipy matrix is converted into indices, data, shape + return ( + [ + (None, self.number_of_dimensions), + (None,), + (self.number_of_dimensions), + ], + [tf.int64, tf.float32, tf.int64], + ) + + if self.number_of_dimensions == 1: + return [(None,)], [tf.float32] + + if self.number_of_dimensions == 2: + return [(None, self.units)], [tf.float32] + + if self.number_of_dimensions == 3: + return [(None, None, self.units)], [tf.float32] + + if self.number_of_dimensions == 4: + return [(None, None, None, self.units)], [tf.float32] + + return [], [] + + +class FeatureSignature(NamedTuple): + """Stores the shape, the type (sparse vs dense), and the number of dimensions of features.""" + + is_sparse: bool + units: Optional[int] + number_of_dimensions: int + + # Mapping of attribute name and feature name to a list of numpy arrays representing # the actual features # For example: @@ -31,14 +218,7 @@ # "numpy array containing dense features for every training example", # "numpy array containing sparse features for every training example" # ]} -Data = Dict[Text, Dict[Text, List[np.ndarray]]] - - -class FeatureSignature(NamedTuple): - """Stores the shape and the type (sparse vs dense) of features.""" - - is_sparse: bool - feature_dimension: Optional[int] 
+Data = Dict[Text, Dict[Text, List[FeatureArray]]] class RasaModelData: @@ -70,7 +250,9 @@ def __init__( def get( self, key: Text, sub_key: Optional[Text] = None - ) -> Union[Dict[Text, List[np.ndarray]], List[np.ndarray]]: + ) -> Union[ + Dict[Text, List[FeatureArray]], List[FeatureArray], + ]: """Get the data under the given keys. Args: @@ -96,7 +278,7 @@ def items(self) -> ItemsView: """ return self.data.items() - def values(self) -> ValuesView[Dict[Text, List[np.ndarray]]]: + def values(self) -> ValuesView[Dict[Text, List[FeatureArray]]]: """Return the values of the data attribute. Returns: @@ -177,7 +359,7 @@ def number_of_examples(self, data: Optional[Data] = None) -> int: return 0 example_lengths = [ - f.shape[0] + len(f) for attribute_data in data.values() for features in attribute_data.values() for f in features @@ -195,25 +377,25 @@ def number_of_examples(self, data: Optional[Data] = None) -> int: return example_lengths[0] - def feature_dimension(self, key: Text, sub_key: Text) -> int: - """Get the feature dimension of the given key. + def number_of_units(self, key: Text, sub_key: Text) -> int: + """Get the number of units of the given key. Args: key: The key. sub_key: The optional sub-key. Returns: - The feature dimension. + The number of units. """ if key not in self.data or sub_key not in self.data[key]: return 0 - number_of_features = 0 - for data in self.data[key][sub_key]: - if data.size > 0: - number_of_features += data[0].shape[-1] + units = 0 + for features in self.data[key][sub_key]: + if len(features) > 0: + units += features.units - return number_of_features + return units def add_data(self, data: Data, key_prefix: Optional[Text] = None) -> None: """Add incoming data to data. 
@@ -230,7 +412,7 @@ def add_data(self, data: Data, key_prefix: Optional[Text] = None) -> None: self.add_features(key, sub_key, features) def add_features( - self, key: Text, sub_key: Text, features: Optional[List[np.ndarray]] + self, key: Text, sub_key: Text, features: Optional[List[FeatureArray]], ) -> None: """Add list of features to data under specified key. @@ -245,7 +427,7 @@ def add_features( return for data in features: - if data.size > 0: + if len(data) > 0: self.data[key][sub_key].append(data) if not self.data[key][sub_key]: @@ -273,9 +455,11 @@ def add_lengths( self.data[key][sub_key] = [] for data in self.data[from_key][from_sub_key]: - if data.size > 0: + if len(data) > 0: lengths = np.array([x.shape[0] for x in data]) - self.data[key][sub_key].extend([lengths]) + self.data[key][sub_key].extend( + [FeatureArray(lengths, number_of_dimensions=1)] + ) break def split( @@ -347,7 +531,9 @@ def split( return self._convert_train_test_split(output_values, solo_values) - def get_signature(self) -> Dict[Text, Dict[Text, List[FeatureSignature]]]: + def get_signature( + self, data: Optional[Data] = None + ) -> Dict[Text, Dict[Text, List[FeatureSignature]]]: """Get signature of RasaModelData. Signature stores the shape and whether features are sparse or not for every key. @@ -356,19 +542,18 @@ def get_signature(self) -> Dict[Text, Dict[Text, List[FeatureSignature]]]: A dictionary of key and sub-key to a list of feature signatures (same structure as the data attribute). 
""" + if not data: + data = self.data return { key: { sub_key: [ - FeatureSignature( - True if isinstance(f[0], scipy.sparse.spmatrix) else False, - f[0].shape[-1] if f[0].shape else None, - ) + FeatureSignature(f.is_sparse, f.units, f.number_of_dimensions) for f in features ] for sub_key, features in attribute_data.items() } - for key, attribute_data in self.data.items() + for key, attribute_data in data.items() } def as_tf_dataset( @@ -440,7 +625,7 @@ def prepare_batch( else: _data = v[:] - if isinstance(_data[0], scipy.sparse.spmatrix): + if _data.is_sparse: batch_data.extend(self._scipy_matrix_to_values(_data)) else: batch_data.append(self._pad_dense_data(_data)) @@ -454,37 +639,15 @@ def _get_shapes_types(self) -> Tuple: Returns: A tuple of shapes and a tuple of types. """ - types = [] shapes = [] - def append_shape(features: np.ndarray) -> None: - if isinstance(features[0], scipy.sparse.spmatrix): - # scipy matrix is converted into indices, data, shape - shapes.append((None, features[0].ndim + 1)) - shapes.append((None,)) - shapes.append((features[0].ndim + 1)) - elif features[0].ndim == 0: - shapes.append((None,)) - elif features[0].ndim == 1: - shapes.append((None, features[0].shape[-1])) - else: - shapes.append((None, None, features[0].shape[-1])) - - def append_type(features: np.ndarray) -> None: - if isinstance(features[0], scipy.sparse.spmatrix): - # scipy matrix is converted into indices, data, shape - types.append(tf.int64) - types.append(tf.float32) - types.append(tf.int64) - else: - types.append(tf.float32) - for attribute_data in self.data.values(): for features in attribute_data.values(): for f in features: - append_shape(f) - append_type(f) + _shapes, _types = f.get_shape_type_info() + shapes.extend(_shapes) + types.extend(_types) return tuple(shapes), tuple(types) @@ -584,7 +747,12 @@ def _balanced_data(self, data: Data, batch_size: int, shuffle: bool) -> Data: for key, attribute_data in new_data.items(): for sub_key, features in 
attribute_data.items(): for f in features: - final_data[key][sub_key].append(np.concatenate(np.array(f))) + final_data[key][sub_key].append( + FeatureArray( + np.concatenate(np.array(f)), + number_of_dimensions=f[0].number_of_dimensions, + ) + ) return final_data @@ -685,7 +853,7 @@ def _split_by_label_ids( label_data = [] for label_id in unique_label_ids: - matching_ids = label_ids == label_id + matching_ids = np.array(label_ids) == label_id label_data.append( RasaModelData( self.label_key, @@ -785,7 +953,7 @@ def _combine_features( return np.concatenate([feature_1, feature_2]) @staticmethod - def _create_label_ids(label_ids: np.ndarray) -> np.ndarray: + def _create_label_ids(label_ids: FeatureArray) -> np.ndarray: """Convert various size label_ids into single dim array. For multi-label y, map each distinct row to a string representation @@ -795,10 +963,12 @@ def _create_label_ids(label_ids: np.ndarray) -> np.ndarray: Args: label_ids: The label ids. + Raises: + ValueError if dimensionality of label ids is not supported + Returns: The single dim label array. """ - if label_ids.ndim == 1: return label_ids @@ -814,7 +984,7 @@ def _create_label_ids(label_ids: np.ndarray) -> np.ndarray: raise ValueError("Unsupported label_ids dimensions") @staticmethod - def _pad_dense_data(array_of_dense: np.ndarray) -> np.ndarray: + def _pad_dense_data(array_of_dense: FeatureArray) -> np.ndarray: """Pad data of different lengths. Sequential data is padded with zeros. Zeros are added to the end of data. @@ -825,6 +995,8 @@ def _pad_dense_data(array_of_dense: np.ndarray) -> np.ndarray: Returns: The padded array. 
""" + if array_of_dense.number_of_dimensions == 4: + return RasaModelData._pad_4d_dense_data(array_of_dense) if array_of_dense[0].ndim < 2: # data doesn't contain a sequence @@ -843,7 +1015,39 @@ def _pad_dense_data(array_of_dense: np.ndarray) -> np.ndarray: return data_padded.astype(np.float32) @staticmethod - def _scipy_matrix_to_values(array_of_sparse: np.ndarray) -> List[np.ndarray]: + def _pad_4d_dense_data(array_of_array_of_dense: FeatureArray) -> np.ndarray: + # in case of dialogue data we may have 4 dimensions + # batch size x dialogue history length x sequence length x number of features + data_size = len(array_of_array_of_dense) + max_dialogue_len = max( + len(array_of_dense) for array_of_dense in array_of_array_of_dense + ) + max_seq_len = max( + [ + x.shape[0] + for array_of_dense in array_of_array_of_dense + for x in array_of_dense + ] + ) + + data_padded = np.zeros( + [ + data_size, + max_dialogue_len, + max_seq_len, + array_of_array_of_dense[0][0].shape[-1], + ], + dtype=array_of_array_of_dense[0][0].dtype, + ) + + for i, array_of_dense in enumerate(array_of_array_of_dense): + for j, dense in enumerate(array_of_dense): + data_padded[i, j, : dense.shape[0], :] = dense + + return data_padded.astype(np.float32) + + @staticmethod + def _scipy_matrix_to_values(array_of_sparse: FeatureArray) -> List[np.ndarray]: """Convert a scipy matrix into indices, data, and shape. Args: @@ -852,6 +1056,8 @@ def _scipy_matrix_to_values(array_of_sparse: np.ndarray) -> List[np.ndarray]: Returns: A list of dense numpy arrays representing the sparse data. """ + if array_of_sparse.number_of_dimensions == 4: + return RasaModelData._4d_scipy_matrix_to_values(array_of_sparse) # we need to make sure that the matrices are coo_matrices otherwise the # transformation does not work (e.g. 
you cannot access x.row, x.col) @@ -878,3 +1084,61 @@ def _scipy_matrix_to_values(array_of_sparse: np.ndarray) -> List[np.ndarray]: data.astype(np.float32), shape.astype(np.int64), ] + + @staticmethod + def _4d_scipy_matrix_to_values(array_of_array_of_sparse: FeatureArray): + # in case of dialogue data we may have 4 dimensions + # batch size x dialogue history length x sequence length x number of features + + # we need to make sure that the matrices are coo_matrices otherwise the + # transformation does not work (e.g. you cannot access x.row, x.col) + if not isinstance(array_of_array_of_sparse[0][0], scipy.sparse.coo_matrix): + array_of_array_of_sparse = [ + [x.tocoo() for x in array_of_sparse] + for array_of_sparse in array_of_array_of_sparse + ] + + max_dialogue_len = max( + [len(array_of_sparse) for array_of_sparse in array_of_array_of_sparse] + ) + max_seq_len = max( + [ + x.shape[0] + for array_of_sparse in array_of_array_of_sparse + for x in array_of_sparse + ] + ) + # get the indices of values + indices = np.hstack( + [ + np.vstack( + [i * np.ones_like(x.row), j * np.ones_like(x.row), x.row, x.col,] + ) + for i, array_of_sparse in enumerate(array_of_array_of_sparse) + for j, x in enumerate(array_of_sparse) + ] + ).T + + data = np.hstack( + [ + x.data + for array_of_sparse in array_of_array_of_sparse + for x in array_of_sparse + ] + ) + + number_of_features = array_of_array_of_sparse[0][0].shape[-1] + shape = np.array( + ( + len(array_of_array_of_sparse), + max_dialogue_len, + max_seq_len, + number_of_features, + ) + ) + + return [ + indices.astype(np.int64), + data.astype(np.float32), + shape.astype(np.int64), + ] diff --git a/rasa/utils/tensorflow/model_data_utils.py b/rasa/utils/tensorflow/model_data_utils.py index 9b8b3094a893..68e621b839d3 100644 --- a/rasa/utils/tensorflow/model_data_utils.py +++ b/rasa/utils/tensorflow/model_data_utils.py @@ -5,7 +5,7 @@ from collections import defaultdict, OrderedDict import scipy.sparse -from 
rasa.utils.tensorflow.model_data import Data +from rasa.utils.tensorflow.model_data import Data, FeatureArray from rasa.utils.tensorflow.constants import SEQUENCE if typing.TYPE_CHECKING: @@ -158,7 +158,7 @@ def _features_for_attribute( state_to_tracker_features: Dict[Text, List[List[List["Features"]]]], training: bool, zero_state_features: Dict[Text, List["Features"]], -) -> Dict[Text, List[np.ndarray]]: +) -> Dict[Text, List[FeatureArray]]: """Create the features for the given attribute from the tracker features. Args: @@ -186,37 +186,43 @@ def _features_for_attribute( tracker_features, zero_state_features[attribute] ) - sparse_features = defaultdict(list) - dense_features = defaultdict(list) + sparse_features = {} + dense_features = {} - # vstack serves as removing dimension - # TODO check vstack for sequence features + # vstack serves as removing dimension in case we are not dealing with a sequence for key, values in _sparse_features.items(): - sparse_features[key] = [scipy.sparse.vstack(value) for value in values] + if key == SEQUENCE: + sparse_features[key] = FeatureArray( + np.array(values), number_of_dimensions=4 + ) + else: + features = [scipy.sparse.vstack(value) for value in values] + sparse_features[key] = FeatureArray( + np.array(features), number_of_dimensions=3 + ) for key, values in _dense_features.items(): - dense_features[key] = [np.vstack(value) for value in values] + if key == SEQUENCE: + dense_features[key] = FeatureArray(np.array(values), number_of_dimensions=4) + else: + features = [np.vstack(value) for value in values] + dense_features[key] = FeatureArray( + np.array(features), number_of_dimensions=3 + ) - attribute_features = {MASK: [np.array(attribute_masks)]} + attribute_features = { + MASK: [FeatureArray(np.array(attribute_masks), number_of_dimensions=3)] + } feature_types = set() feature_types.update(list(dense_features.keys())) feature_types.update(list(sparse_features.keys())) for feature_type in feature_types: - if feature_type == 
SEQUENCE: - # TODO we don't take sequence features because that makes us deal - # with 4D sparse tensors - continue - attribute_features[feature_type] = [] if feature_type in sparse_features: - attribute_features[feature_type].append( - np.array(sparse_features[feature_type]) - ) + attribute_features[feature_type].append(sparse_features[feature_type]) if feature_type in dense_features: - attribute_features[feature_type].append( - np.array(dense_features[feature_type]) - ) + attribute_features[feature_type].append(dense_features[feature_type]) return attribute_features diff --git a/rasa/utils/tensorflow/models.py b/rasa/utils/tensorflow/models.py index 18cb971eaa6b..d26e77efe5ef 100644 --- a/rasa/utils/tensorflow/models.py +++ b/rasa/utils/tensorflow/models.py @@ -525,20 +525,15 @@ def batch_to_model_data_format( idx = 0 for key, values in data_signature.items(): for sub_key, signature in values.items(): - for is_sparse, feature_dimension in signature: + for is_sparse, feature_dimension, number_of_dimensions in signature: if is_sparse: # explicitly substitute last dimension in shape with known # static value + shape = [ + batch[idx + 2][i] for i in range(number_of_dimensions - 1) + ] + [feature_dimension] batch_data[key][sub_key].append( - tf.SparseTensor( - batch[idx], - batch[idx + 1], - [ - batch[idx + 2][0], - batch[idx + 2][1], - feature_dimension, - ], - ) + tf.SparseTensor(batch[idx], batch[idx + 1], shape,) ) idx += 3 else: @@ -754,7 +749,7 @@ def _prepare_sparse_dense_layers( ) -> None: sparse = False dense = False - for is_sparse, _ in data_signature: + for is_sparse, _, _ in data_signature: if is_sparse: sparse = True else: diff --git a/tests/utils/tensorflow/test_model_data.py b/tests/utils/tensorflow/test_model_data.py index 26cf4d2d1781..3a373ac2dd73 100644 --- a/tests/utils/tensorflow/test_model_data.py +++ b/tests/utils/tensorflow/test_model_data.py @@ -1,64 +1,178 @@ import copy +from typing import Union, List import pytest import scipy.sparse 
import numpy as np -from rasa.utils.tensorflow.model_data import RasaModelData +from rasa.utils.tensorflow.model_data import RasaModelData, FeatureArray @pytest.fixture async def model_data() -> RasaModelData: return RasaModelData( - label_key="intent", + label_key="label", label_sub_key="ids", data={ - "text_features": { + "text": { "sentence": [ - np.array( - [ - np.random.rand(5, 14), - np.random.rand(2, 14), - np.random.rand(3, 14), - np.random.rand(1, 14), - np.random.rand(3, 14), - ] + FeatureArray( + np.array( + [ + np.random.rand(5, 14), + np.random.rand(2, 14), + np.random.rand(3, 14), + np.random.rand(1, 14), + np.random.rand(3, 14), + ] + ), + number_of_dimensions=3, ), - np.array( - [ - scipy.sparse.csr_matrix(np.random.randint(5, size=(5, 10))), - scipy.sparse.csr_matrix(np.random.randint(5, size=(2, 10))), - scipy.sparse.csr_matrix(np.random.randint(5, size=(3, 10))), - scipy.sparse.csr_matrix(np.random.randint(5, size=(1, 10))), - scipy.sparse.csr_matrix(np.random.randint(5, size=(3, 10))), - ] + FeatureArray( + np.array( + [ + scipy.sparse.csr_matrix( + np.random.randint(5, size=(5, 10)) + ), + scipy.sparse.csr_matrix( + np.random.randint(5, size=(2, 10)) + ), + scipy.sparse.csr_matrix( + np.random.randint(5, size=(3, 10)) + ), + scipy.sparse.csr_matrix( + np.random.randint(5, size=(1, 10)) + ), + scipy.sparse.csr_matrix( + np.random.randint(5, size=(3, 10)) + ), + ] + ), + number_of_dimensions=3, ), ] }, - "intent_features": { + "action_text": { + "sequence": [ + FeatureArray( + np.array( + [ + [ + scipy.sparse.csr_matrix( + np.random.randint(5, size=(5, 10)) + ), + scipy.sparse.csr_matrix( + np.random.randint(5, size=(2, 10)) + ), + scipy.sparse.csr_matrix( + np.random.randint(5, size=(3, 10)) + ), + scipy.sparse.csr_matrix( + np.random.randint(5, size=(1, 10)) + ), + scipy.sparse.csr_matrix( + np.random.randint(5, size=(3, 10)) + ), + ], + [ + scipy.sparse.csr_matrix( + np.random.randint(5, size=(5, 10)) + ), + scipy.sparse.csr_matrix( + 
np.random.randint(5, size=(2, 10)) + ), + ], + [ + scipy.sparse.csr_matrix( + np.random.randint(5, size=(5, 10)) + ), + scipy.sparse.csr_matrix( + np.random.randint(5, size=(1, 10)) + ), + scipy.sparse.csr_matrix( + np.random.randint(5, size=(3, 10)) + ), + ], + [ + scipy.sparse.csr_matrix( + np.random.randint(5, size=(3, 10)) + ), + ], + [ + scipy.sparse.csr_matrix( + np.random.randint(5, size=(3, 10)) + ), + scipy.sparse.csr_matrix( + np.random.randint(5, size=(1, 10)) + ), + scipy.sparse.csr_matrix( + np.random.randint(5, size=(7, 10)) + ), + ], + ] + ), + number_of_dimensions=4, + ), + FeatureArray( + np.array( + [ + [ + np.random.rand(5, 14), + np.random.rand(2, 14), + np.random.rand(3, 14), + np.random.rand(1, 14), + np.random.rand(3, 14), + ], + [np.random.rand(5, 14), np.random.rand(2, 14),], + [ + np.random.rand(5, 14), + np.random.rand(1, 14), + np.random.rand(3, 14), + ], + [np.random.rand(3, 14),], + [ + np.random.rand(3, 14), + np.random.rand(1, 14), + np.random.rand(7, 14), + ], + ] + ), + number_of_dimensions=4, + ), + ] + }, + "dialogue": { "sentence": [ - np.array( - [ - np.random.randint(2, size=(5, 10)), - np.random.randint(2, size=(2, 10)), - np.random.randint(2, size=(3, 10)), - np.random.randint(2, size=(1, 10)), - np.random.randint(2, size=(3, 10)), - ] + FeatureArray( + np.array( + [ + np.random.randint(2, size=(5, 10)), + np.random.randint(2, size=(2, 10)), + np.random.randint(2, size=(3, 10)), + np.random.randint(2, size=(1, 10)), + np.random.randint(2, size=(3, 10)), + ] + ), + number_of_dimensions=3, ) ] }, - "intent": {"ids": [np.array([0, 1, 0, 1, 1])]}, + "label": { + "ids": [FeatureArray(np.array([0, 1, 0, 1, 1]), number_of_dimensions=1)] + }, "entities": { "tag_ids": [ - np.array( - [ - np.array([[0], [1], [1], [0], [2]]), - np.array([[2], [0]]), - np.array([[0], [1], [1]]), - np.array([[0], [1]]), - np.array([[0], [0], [0]]), - ] + FeatureArray( + np.array( + [ + np.array([[0], [1], [1], [0], [2]]), + np.array([[2], [0]]), + 
np.array([[0], [1], [1]]), + np.array([[0], [1]]), + np.array([[0], [0], [0]]), + ] + ), + number_of_dimensions=3, ) ] }, @@ -86,12 +200,17 @@ def test_shuffle_session_data(model_data: RasaModelData): def test_split_data_by_label(model_data: RasaModelData): split_model_data = model_data._split_by_label_ids( - model_data.data, model_data.get("intent", "ids")[0], np.array([0, 1]) + model_data.data, model_data.get("label", "ids")[0], np.array([0, 1]) ) assert len(split_model_data) == 2 for s in split_model_data: - assert len(set(s.get("intent", "ids")[0])) == 1 + assert len(set(s.get("label", "ids")[0])) == 1 + + for key, attribute_data in split_model_data[0].items(): + for sub_key, features in attribute_data.items(): + assert len(features) == len(model_data.data[key][sub_key]) + assert len(features[0]) == 2 def test_split_data_by_none_label(model_data: RasaModelData): @@ -106,9 +225,9 @@ def test_split_data_by_none_label(model_data: RasaModelData): test_data = split_model_data[1] # train data should have 3 examples - assert len(train_data.get("intent", "ids")[0]) == 3 + assert len(train_data.get("label", "ids")[0]) == 3 # test data should have 2 examples - assert len(test_data.get("intent", "ids")[0]) == 2 + assert len(test_data.get("label", "ids")[0]) == 2 def test_train_val_split(model_data: RasaModelData): @@ -121,17 +240,23 @@ def test_train_val_split(model_data: RasaModelData): assert len(data) == len(train_model_data.get(key, sub_key)) assert len(data) == len(test_model_data.get(key, sub_key)) for i, v in enumerate(data): - assert v[0].dtype == train_model_data.get(key, sub_key)[i][0].dtype + if isinstance(v[0], list): + assert ( + v[0][0].dtype + == train_model_data.get(key, sub_key)[i][0][0].dtype + ) + else: + assert v[0].dtype == train_model_data.get(key, sub_key)[i][0].dtype for values in train_model_data.values(): for data in values.values(): for v in data: - assert v.shape[0] == 3 + assert np.array(v).shape[0] == 3 for values in test_model_data.values(): 
for data in values.values(): for v in data: - assert v.shape[0] == 2 + assert np.array(v).shape[0] == 2 @pytest.mark.parametrize("size", [0, 1, 5]) @@ -146,7 +271,7 @@ def test_session_data_for_ids(model_data: RasaModelData): for values in filtered_data.values(): for data in values.values(): for v in data: - assert v.shape[0] == 2 + assert np.array(v).shape[0] == 2 key = model_data.keys()[0] sub_key = model_data.keys(key)[0] @@ -174,27 +299,32 @@ def test_get_number_of_examples_raises_value_error(model_data: RasaModelData): def test_gen_batch(model_data: RasaModelData): iterator = model_data._gen_batch(2, shuffle=True, batch_strategy="balanced") - print(model_data.data["entities"]["tag_ids"][0]) + batch = next(iterator) - assert len(batch) == 7 + assert len(batch) == 11 assert len(batch[0]) == 2 batch = next(iterator) - assert len(batch) == 7 + assert len(batch) == 11 assert len(batch[0]) == 2 batch = next(iterator) - assert len(batch) == 7 + assert len(batch) == 11 assert len(batch[0]) == 1 with pytest.raises(StopIteration): next(iterator) +def test_is_in_4d_format(model_data: RasaModelData): + assert model_data.data["action_text"]["sequence"][0].number_of_dimensions == 4 + assert model_data.data["text"]["sentence"][0].number_of_dimensions == 3 + + def test_balance_model_data(model_data: RasaModelData): data = model_data._balanced_data(model_data.data, 2, False) - assert np.all(data["intent"]["ids"][0] == np.array([0, 1, 1, 0, 1])) + assert np.all(np.array(data["label"]["ids"][0]) == np.array([0, 1, 1, 0, 1])) def test_not_balance_model_data(model_data: RasaModelData): @@ -210,6 +340,146 @@ def test_not_balance_model_data(model_data: RasaModelData): def test_get_num_of_features(model_data: RasaModelData): - num_features = model_data.feature_dimension("text_features", "sentence") + num_features = model_data.number_of_units("text", "sentence") assert num_features == 24 + + +@pytest.mark.parametrize( + "incoming_data, expected_shape", + [ + 
(FeatureArray(np.random.rand(7, 12), number_of_dimensions=2), (7, 12)), + (FeatureArray(np.random.rand(7), number_of_dimensions=1), (7,)), + ( + FeatureArray( + np.array( + [ + np.random.rand(1, 10), + np.random.rand(3, 10), + np.random.rand(7, 10), + np.random.rand(1, 10), + ] + ), + number_of_dimensions=3, + ), + (4, 7, 10), + ), + ( + FeatureArray( + np.array( + [ + np.array( + [ + np.random.rand(1, 10), + np.random.rand(5, 10), + np.random.rand(7, 10), + ] + ), + np.array( + [ + np.random.rand(1, 10), + np.random.rand(3, 10), + np.random.rand(3, 10), + np.random.rand(7, 10), + ] + ), + np.array([np.random.rand(2, 10),]), + ] + ), + number_of_dimensions=4, + ), + (3, 4, 7, 10), + ), + ], +) +def test_pad_dense_data(incoming_data: FeatureArray, expected_shape: np.ndarray): + padded_data = RasaModelData._pad_dense_data(incoming_data) + + assert padded_data.shape == expected_shape + + +@pytest.mark.parametrize( + "incoming_data, expected_shape", + [ + ( + FeatureArray( + np.array([scipy.sparse.csr_matrix(np.random.randint(5, size=(7, 12)))]), + number_of_dimensions=3, + ), + [1, 7, 12], + ), + ( + FeatureArray( + np.array([scipy.sparse.csr_matrix(np.random.randint(5, size=(7,)))]), + number_of_dimensions=2, + ), + [1, 1, 7], + ), + ( + FeatureArray( + np.array( + [ + scipy.sparse.csr_matrix(np.random.randint(10, size=(1, 10))), + scipy.sparse.csr_matrix(np.random.randint(10, size=(3, 10))), + scipy.sparse.csr_matrix(np.random.randint(10, size=(7, 10))), + scipy.sparse.csr_matrix(np.random.randint(10, size=(1, 10))), + ] + ), + number_of_dimensions=3, + ), + (4, 7, 10), + ), + ( + FeatureArray( + np.array( + [ + np.array( + [ + scipy.sparse.csr_matrix( + np.random.randint(10, size=(1, 10)) + ), + scipy.sparse.csr_matrix( + np.random.randint(10, size=(5, 10)) + ), + scipy.sparse.csr_matrix( + np.random.randint(10, size=(7, 10)) + ), + ] + ), + np.array( + [ + scipy.sparse.csr_matrix( + np.random.randint(10, size=(1, 10)) + ), + scipy.sparse.csr_matrix( + 
np.random.randint(10, size=(3, 10)) + ), + scipy.sparse.csr_matrix( + np.random.randint(10, size=(1, 10)) + ), + scipy.sparse.csr_matrix( + np.random.randint(10, size=(7, 10)) + ), + ] + ), + np.array( + [ + scipy.sparse.csr_matrix( + np.random.randint(10, size=(2, 10)) + ), + ] + ), + ] + ), + number_of_dimensions=4, + ), + (3, 4, 7, 10), + ), + ], +) +def test_scipy_matrix_to_values( + incoming_data: FeatureArray, expected_shape: np.ndarray +): + indices, data, shape = RasaModelData._scipy_matrix_to_values(incoming_data) + + assert np.all(shape == expected_shape)