diff --git a/.gitignore b/.gitignore index 6ed2d28..bb5985c 100644 --- a/.gitignore +++ b/.gitignore @@ -9,6 +9,7 @@ __pycache__/ /build /wheels /wheelhouse +emlangkit.egg-info # PyCharm .idea/ diff --git a/emlangkit/language.py b/emlangkit/language.py index 7037a30..b8ed742 100644 --- a/emlangkit/language.py +++ b/emlangkit/language.py @@ -4,6 +4,7 @@ import numpy as np import emlangkit.metrics as metrics +import emlangkit.utils as utils class Language: @@ -40,6 +41,7 @@ def __init__( observations: Optional[np.ndarray] = None, prev_horizon: int = 8, seed: int = 42, + has_threshold: float = 0.8, ): if not isinstance(messages, np.ndarray): raise ValueError("Language only accepts numpy arrays!") @@ -70,6 +72,22 @@ def __init__( self.__mpn_value = None self.prev_horizon = prev_horizon + # HAS placeholders + self.has_threshold = has_threshold + self.__alpha = None + self.__freq = None + self.__branching_entropy = None + self.__conditional_entropy = None + self.__boundaries = None + self.__segments = None + self.__segment_ids = None + self.__hashed_segments = None + self.__random_boundaries = None + self.__random_segments = None + self.__random_segment_ids = None + self.__random_hashed_segments = None + self.__has_stats = None + def topsim(self) -> tuple[float, float]: """ Calculate the topographic similarity score for the language. @@ -83,6 +101,11 @@ def topsim(self) -> tuple[float, float]: Raises ------ ValueError: If observations are not set. + + Notes + ----- + The result is cached and will only be computed once. + Subsequent calls to this method will return the cached value. """ if self.observations is None: raise ValueError( @@ -109,6 +132,11 @@ def posdis(self): Raises ------ ValueError: If observations are not set. + + Notes + ----- + The result is cached and will only be computed once. + Subsequent calls to this method will return the cached value. 
""" if self.observations is None: raise ValueError( @@ -134,6 +162,11 @@ def bosdis(self): Raises ------ ValueError: If observations are not set. + + Notes + ----- + The result is cached and will only be computed once. + Subsequent calls to this method will return the cached value. """ if self.observations is None: raise ValueError( @@ -159,6 +192,11 @@ def language_entropy(self): Raises ------ ValueError: If observations are not set. + + Notes + ----- + The result is cached and will only be computed once. + Subsequent calls to this method will return the cached value. """ # This may have been calculated previously if self.__langauge_entropy_value is None: @@ -179,6 +217,11 @@ def observation_entropy(self): Raises ------ ValueError: If observations are not set. + + Notes + ----- + The result is cached and will only be computed once. + Subsequent calls to this method will return the cached value. """ if self.observations is None: raise ValueError( @@ -205,16 +248,20 @@ def mutual_information(self): Raises ------ ValueError: If observations are not set. + + Notes + ----- + The result is cached and will only be computed once. + Subsequent calls to this method will return the cached value. """ if self.observations is None: raise ValueError("Observations are needed to calculate mutual information!") - if self.__observation_entropy_value is None: - self.observation_entropy() - if self.__langauge_entropy_value is None: - self.language_entropy() - if self.__mutual_information_value is None: + if self.__observation_entropy_value is None: + self.observation_entropy() + if self.__langauge_entropy_value is None: + self.language_entropy() self.__mutual_information_value = metrics.compute_mutual_information( self.messages, self.observations, @@ -238,6 +285,11 @@ def mpn(self): Raises ------ ValueError: If observations are not set. + + Notes + ----- + The result is cached and will only be computed once. + Subsequent calls to this method will return the cached value. 
# Harris' Articulation Scheme metrics
def branching_entropy(self):
    """
    Calculate the branching entropy for the language.

    Returns
    -------
    dict
        Mapping from each context (a tuple of symbols) to its branching
        entropy, as produced by ``metrics.compute_branching_entropy``.

    Notes
    -----
    The result is cached and will only be computed once.
    Subsequent calls to this method will return the cached value.
    """
    if self.__branching_entropy is None:
        # The alphabet and the subsequence counts are shared by all HAS metrics.
        if self.__freq is None:
            self.__alpha, self.__freq = metrics.has_init(self.messages)
        self.__branching_entropy = metrics.compute_branching_entropy(
            self.__alpha, self.__freq
        )

    return self.__branching_entropy


def conditional_entropy(self):
    """
    Calculate the conditional entropy for the language.

    Returns
    -------
    dict
        Mapping from context length to the conditional entropy at that length,
        as produced by ``metrics.compute_conditional_entropy``.

    Notes
    -----
    The result is cached and will only be computed once.
    Subsequent calls to this method will return the cached value.
    """
    if self.__conditional_entropy is None:
        # branching_entropy() also fills the frequency table used below.
        if self.__branching_entropy is None or self.__freq is None:
            self.branching_entropy()
        self.__conditional_entropy = metrics.compute_conditional_entropy(
            self.__branching_entropy, self.__freq
        )

    return self.__conditional_entropy


def boundaries(self, return_count: bool = False, return_mean: bool = False):
    """
    Calculate the HAS boundaries for the language.

    Parameters
    ----------
    return_count : bool, optional
        If True, also return the number of boundaries found in each message.
        Default is False.
    return_mean : bool, optional
        If True, also return the per-message boundary counts and their mean.
        Takes precedence over `return_count`. Default is False.

    Returns
    -------
    boundaries : list of sets
        One set of boundary positions per message.
    nb : list of int
        Only if `return_count` or `return_mean` is True: the number of
        boundaries in each message.
    mean : float
        Only if `return_mean` is True: the mean of `nb`.

    Notes
    -----
    The boundaries are cached and will only be computed once.
    Subsequent calls to this method will reuse the cached value.
    """
    if self.__boundaries is None:
        if self.__branching_entropy is None:
            self.branching_entropy()
        self.__boundaries = metrics.compute_boundaries(
            self.messages, self.__branching_entropy, threshold=self.has_threshold
        )

    if not (return_count or return_mean):
        return self.__boundaries

    nb = [len(b) for b in self.__boundaries]
    # return_mean is checked first: it is the superset of return_count and was
    # previously unreachable when both flags were set.
    if return_mean:
        return self.__boundaries, nb, np.mean(nb)
    return self.__boundaries, nb


def random_boundaries(
    self,
    return_count: bool = False,
    return_mean: bool = False,
    recompute: bool = False,
):
    """
    Calculate the random HAS boundaries for the language.

    Parameters
    ----------
    return_count : bool, optional
        If True, also return the number of random boundaries in each message.
        Default is False.
    return_mean : bool, optional
        If True, also return the per-message counts and their mean.
        Takes precedence over `return_count`. Default is False.
    recompute : bool, optional
        If True, draw a fresh set of random boundaries even if a cached one
        exists. Default is False.

    Returns
    -------
    boundaries : list of sets
        One set of random boundary positions per message.
    nb : list of int
        Only if `return_count` or `return_mean` is True: the number of
        boundaries in each message.
    mean : float
        Only if `return_mean` is True: the mean of `nb`.

    Notes
    -----
    The result is cached; unless `recompute` is True, subsequent calls return
    the cached value. Recomputing discards the cached random segments and the
    cached HAS statistics, since both are derived from these boundaries.
    """
    if self.__random_boundaries is None or recompute:
        if self.__boundaries is None:
            self.boundaries()
        self.__random_boundaries = metrics.compute_random_boundaries(
            self.messages, self.__boundaries, self.__rng
        )
        # Everything derived from the previous draw is now stale.
        self.__random_segments = None
        self.__random_segment_ids = None
        self.__random_hashed_segments = None
        self.__has_stats = None

    if not (return_count or return_mean):
        return self.__random_boundaries

    nb = [len(b) for b in self.__random_boundaries]
    if return_mean:
        return self.__random_boundaries, nb, np.mean(nb)
    return self.__random_boundaries, nb


def segments(self, return_ids: bool = False, return_hashed_segments: bool = False):
    """
    Calculate the HAS segments for the language.

    Parameters
    ----------
    return_ids : bool, optional
        If True, also return the mapping from segment to segment id.
        Default is False.
    return_hashed_segments : bool, optional
        If True, also return the messages rewritten as sequences of segment ids.
        Default is False.

    Returns
    -------
    segments : list of tuples
        The segmented messages.
    segment_ids : dict
        Only if `return_ids` is True.
    hashed_segments : list of tuples
        Only if `return_hashed_segments` is True. When both flags are set the
        order is ``(segments, segment_ids, hashed_segments)``.

    Notes
    -----
    The result is cached and will only be computed once.
    Subsequent calls to this method will return the cached value.
    """
    if self.__segments is None:
        if self.__boundaries is None:
            self.boundaries()
        (
            self.__segments,
            self.__segment_ids,
            self.__hashed_segments,
        ) = metrics.compute_segments(self.messages, self.__boundaries)

    # The combined case must be tested first, otherwise it can never be reached.
    if return_ids and return_hashed_segments:
        return self.__segments, self.__segment_ids, self.__hashed_segments
    if return_ids:
        return self.__segments, self.__segment_ids
    if return_hashed_segments:
        return self.__segments, self.__hashed_segments
    return self.__segments


def random_segments(
    self,
    return_ids: bool = False,
    return_hashed_segments: bool = False,
    recompute: bool = False,
):
    """
    Calculate the random HAS segments for the language.

    Parameters
    ----------
    return_ids : bool, optional
        If True, also return the mapping from segment to segment id.
        Default is False.
    return_hashed_segments : bool, optional
        If True, also return the messages rewritten as sequences of segment ids.
        Default is False.
    recompute : bool, optional
        If True, rebuild the segments from the current random boundaries even
        if a cached result exists. Default is False.

    Returns
    -------
    segments : list of tuples
        The randomly segmented messages.
    segment_ids : dict
        Only if `return_ids` is True.
    hashed_segments : list of tuples
        Only if `return_hashed_segments` is True. When both flags are set the
        order is ``(segments, segment_ids, hashed_segments)``.

    Notes
    -----
    The result is cached; unless `recompute` is True, subsequent calls return
    the cached value. Call ``random_boundaries(recompute=True)`` first to also
    draw new boundaries.
    """
    if self.__random_segments is None or recompute:
        if self.__random_boundaries is None:
            self.random_boundaries()
        (
            self.__random_segments,
            self.__random_segment_ids,
            self.__random_hashed_segments,
        ) = metrics.compute_segments(self.messages, self.__random_boundaries)
        # Cached statistics were built from the previous segments.
        self.__has_stats = None

    if return_ids and return_hashed_segments:
        return (
            self.__random_segments,
            self.__random_segment_ids,
            self.__random_hashed_segments,
        )
    if return_ids:
        return self.__random_segments, self.__random_segment_ids
    if return_hashed_segments:
        return self.__random_segments, self.__random_hashed_segments
    return self.__random_segments


def has_stats(self, compute_topsim: bool = False) -> dict:
    """
    Calculate the HAS statistics for the language.

    Parameters
    ----------
    compute_topsim : bool, optional
        Whether to also compute the topographic similarity of the (random)
        hashed segments against the observations. Default is False.

    Returns
    -------
    dict
        Statistics for the HAS segmentation and for its random baseline.
        The topographic similarity entries are None until they are requested.

    Raises
    ------
    ValueError
        If observations are None and compute_topsim is True.

    Notes
    -----
    The result is cached. Topographic similarity is added to the cached
    statistics the first time it is requested.
    """
    if compute_topsim and self.observations is None:
        raise ValueError("Observations are needed to calculate topographic similarity.")

    if self.__has_stats is None:
        # Make sure both segmentations exist before reading their caches.
        if self.__segments is None:
            self.segments()
        if self.__random_segments is None:
            self.random_segments()

        zla, freq = metrics.zla(self.__segments)
        random_zla, random_freq = metrics.zla(self.__random_segments)

        self.__has_stats = {
            "vocab_size": len(self.__segment_ids),
            "zla": zla,
            "zipf": freq,
            "topographic_similarity": None,
            "random_vocab_size": len(self.__random_segment_ids),
            "random_zla": random_zla,
            "random_zipf": random_freq,
            "random_topographic_similarity": None,
        }

    if compute_topsim and self.__has_stats["topographic_similarity"] is None:
        # Pad the segments for topsim computation.
        # We use 0 as it is not used in the hash table (ids start at 1)
        # and has no effect on the distance measurement.
        padded_hashed_segments = utils.pad_jagged(self.__hashed_segments)
        padded_random_hashed_segments = utils.pad_jagged(
            self.__random_hashed_segments
        )
        # We use hamming here, as the segments could contain multiple characters,
        # so editdistance would give us a worse estimate.
        self.__has_stats[
            "topographic_similarity"
        ] = metrics.compute_topographic_similarity(
            padded_hashed_segments,
            self.observations,
            message_dist_metric="hamming",
        )
        self.__has_stats[
            "random_topographic_similarity"
        ] = metrics.compute_topographic_similarity(
            padded_random_hashed_segments,
            self.observations,
            message_dist_metric="hamming",
        )

    return self.__has_stats
"""
The Harris' Articulation Scheme based segmentation.

Adapted from https://openreview.net/forum?id=b4t9_XASt6G
"""

import itertools
from collections import Counter
from typing import List, Tuple

import numpy as np


def has_init(messages: np.ndarray) -> Tuple[set, Counter]:
    """
    Compute initial values used by the other HAS functions.

    Parameters
    ----------
    messages : numpy.ndarray
        The array of messages.

    Returns
    -------
    alpha : set
        The set of unique characters present in the messages.
    freq : Counter
        A Counter containing all contiguous subsequences (as tuples) and
        their corresponding frequencies.
    """
    # Create the alphabet
    alpha = set(np.unique(messages))
    # Count all contiguous subsequences of every message
    freq = Counter(
        tuple(message[i:j])
        for message in messages
        for i in range(len(message))
        for j in range(i + 1, len(message) + 1)
    )
    # The frequency of the empty sequence is defined as the total number of
    # symbols. This is just for convenience: it makes the empty context behave
    # like any other context.
    freq[tuple()] = sum(len(message) for message in messages)

    return alpha, freq


def compute_branching_entropy(alpha, freq):
    """
    Calculate the branching entropy of every context.

    Parameters
    ----------
    alpha : set
        The set of unique characters present in the messages.
    freq : Counter
        A mapping from sequences (tuples) to their frequencies.

    Returns
    -------
    branching_entropy : dict
        Dictionary mapping contexts to their corresponding branching entropy.
    """
    branching_entropy = {}
    for context, context_freq in freq.items():
        log_context_freq = np.log2(context_freq)
        # Continuations that never occur have frequency 0 and are skipped.
        successor_freqs = [freq.get(context + (a,), 0) for a in alpha]
        weighted_log_sum = sum(
            succ_freq * (np.log2(succ_freq) - log_context_freq)
            for succ_freq in successor_freqs
            if succ_freq > 0
        )
        branching_entropy[context] = -1 * weighted_log_sum / context_freq
    return branching_entropy


def compute_conditional_entropy(branching_entropy, freq) -> dict:
    """
    Compute the conditional entropy for every context length.

    The conditional entropy at length ``n`` is the frequency-weighted mean of
    the branching entropies of all contexts of length ``n``.

    Parameters
    ----------
    branching_entropy : dict
        A mapping from sequences to their branching entropy.
    freq : dict
        A mapping from sequences to their frequencies.

    Returns
    -------
    dict
        A dictionary containing the conditional entropy for each sequence length.
        The keys are sequence lengths and the values are the corresponding
        conditional entropy values.
    """
    weighted_entropy = {}
    total_freq = {}
    for seq, entropy in branching_entropy.items():
        seq_len = len(seq)
        weighted_entropy[seq_len] = weighted_entropy.get(seq_len, 0) + freq[seq] * entropy
        total_freq[seq_len] = total_freq.get(seq_len, 0) + freq[seq]
    return {
        length: weighted_entropy[length] / total for length, total in total_freq.items()
    }


def compute_boundaries(
    messages: np.ndarray, branching_entropy: dict, threshold: float
) -> List[set]:
    """
    Compute the boundaries of a language from its branching entropy.

    Parameters
    ----------
    messages : numpy.ndarray
        A numpy array containing the input messages.
    branching_entropy : dict
        The branching entropy for each context in the messages.
    threshold : float
        The threshold value used for determining the boundaries.

    Returns
    -------
    boundaries : List[set]
        A list of sets, where each set represents the boundary positions in
        each message.

    Notes
    -----
    For every start position the context is grown one symbol at a time. When
    the branching entropy of a context exceeds that of the context without its
    last symbol by more than the threshold, a boundary is placed right after
    the context. Contexts that would run past the end of the message are not
    evaluated, so no boundary position exceeds the message length.
    """
    boundaries = []
    for message in messages:
        message_boundaries = set()
        start = 0
        # We begin with width=2, while the algorithm in the paper begins with
        # width=1, because the branching entropy is already computed here.
        width = 2
        while start < len(message):
            end = start + width
            # A context cut short by the end of the message is not a real
            # width-`width` context and must not produce a boundary.
            if end <= len(message):
                context = tuple(message[start:end])
                if (
                    branching_entropy[context] - branching_entropy[context[:-1]]
                    > threshold
                ):
                    message_boundaries.add(end)
            if end + 1 < len(message):
                width += 1
            else:
                start += 1
                width = 2
        boundaries.append(message_boundaries)
    return boundaries


def compute_segments(
    messages: np.ndarray, boundaries: List[set]
) -> Tuple[list, dict, list]:
    """
    Compute language segments given the pre-computed boundaries.

    Parameters
    ----------
    messages : numpy.ndarray
        An array containing the messages to be segmented.
    boundaries : List[set]
        One set of split positions per message. Positions outside the open
        interval ``(0, len(message))`` are ignored, as they cannot split
        the message.

    Returns
    -------
    segments : list
        A list of tuples containing the segmented messages. Each tuple holds
        the segments (themselves tuples of symbols) of one message.
    segment_ids : dict
        A dictionary mapping each unique segment to its ID. IDs start at 1
        and follow the order of first occurrence in the segments list.
    hashed_segments : list
        A list of tuples containing the hashed versions of the segmented
        messages. Each element in a tuple is the ID of a segment.
    """
    segments = []
    for data, message_boundaries in zip(messages, boundaries):
        cut_points = {int(b) for b in message_boundaries if 0 < b < len(data)}
        cut_points.add(len(data))
        words = []
        bot = 0
        for top in sorted(cut_points):
            words.append(tuple(data[bot:top]))
            bot = top
        segments.append(tuple(words))

    # dict.fromkeys keeps first-occurrence order, which makes the IDs
    # deterministic and consistent with the documented contract.
    unique_segments = dict.fromkeys(itertools.chain.from_iterable(segments))
    segment_ids = {segment: i + 1 for i, segment in enumerate(unique_segments)}
    hashed_segments = [
        tuple(segment_ids[segment] for segment in message) for message in segments
    ]
    return segments, segment_ids, hashed_segments


def compute_random_boundaries(
    messages: np.ndarray, boundaries, rng: np.random.Generator
) -> List[set]:
    """
    Compute random boundaries matching the number of pre-computed boundaries.

    Parameters
    ----------
    messages : np.ndarray
        The input array of messages.
    boundaries : list
        The input list of boundaries; only the number of boundaries per
        message is used.
    rng : np.random.Generator
        The random number generator object.

    Returns
    -------
    random_boundaries : List[set]
        For every message, a set of positions drawn (with replacement) from
        ``1 .. len(message) - 1``. Messages too short to be split get an
        empty set.
    """
    random_boundaries = []
    for data, message_boundaries in zip(messages, boundaries):
        candidates = np.arange(1, len(data), dtype=np.int32)
        if candidates.size == 0:
            # rng.choice raises on an empty population; nothing can be split.
            random_boundaries.append(set())
            continue
        picks = rng.choice(candidates, size=len(message_boundaries))
        random_boundaries.append(set(picks.tolist()))
    return random_boundaries
"""
Zipf's Law statistics.

Adapted from https://openreview.net/forum?id=b4t9_XASt6G
"""

import itertools
from collections import Counter, defaultdict
from typing import Tuple

import numpy as np


def zla(words: np.ndarray) -> Tuple[list, list]:
    """
    Compute Zipf's Law of Abbreviation (ZLA) statistics.

    Returns the mean word lengths and their frequencies, and just the raw frequencies.

    Parameters
    ----------
    words : numpy.ndarray
        An iterable of messages, where each message is an iterable of words
        (hashable sequences, e.g. tuples of symbols).

    Returns
    -------
    tuple : (list, list)
        The first element contains, for every distinct word in order of
        decreasing frequency, the mean length of all words that share its
        frequency of occurrence.

        The second element contains the frequencies themselves, in the same
        order: each entry is the number of occurrences of one distinct word.
    """
    frequencies = []
    freq_to_lens = defaultdict(list)
    for word, freq in Counter(itertools.chain.from_iterable(words)).most_common():
        frequencies.append(freq)
        # Length of the individual word, not of the whole corpus.
        freq_to_lens[freq].append(len(word))

    # One mean per distinct frequency, shared by all words with that frequency.
    mean_len_by_freq = {freq: np.mean(lens) for freq, lens in freq_to_lens.items()}
    zla_stats = [mean_len_by_freq[freq] for freq in frequencies]

    return zla_stats, frequencies
import numpy as np


def pad_jagged(array: np.ndarray, fill: int = 0) -> np.ndarray:
    """
    Append the minimal required amount of a given integer at the end of each row, such that it loses its jaggedness.

    Parameters
    ----------
    array : np.ndarray
        Input (possibly jagged) sequence of rows to be padded.
    fill : int
        Integer to pad the rows with.

    Returns
    -------
    padded : np.ndarray
        Rectangular array with one row per input row; every row is as long as
        the longest input row. An empty input yields an array of shape (0, 0).
    """
    if len(array) == 0:
        return np.full((0, 0), fill_value=fill)

    maxlen = max(len(row) for row in array)
    padded = np.full((len(array), maxlen), fill_value=fill)
    for index, row in enumerate(array):
        # Assign rather than add: adding would offset every value by `fill`
        # whenever the fill value is not zero.
        padded[index, : len(row)] = row
    return padded
lang.random_segments(recompute=True) diff --git a/tests/test_metrics.py b/tests/test_metrics.py index c1ca26a..b77d255 100644 --- a/tests/test_metrics.py +++ b/tests/test_metrics.py @@ -725,3 +725,69 @@ def test_mpn(): 100, 2, ) + + with_stats = metrics.compute_mpn( + messages=np.array( + [ + [0, 1], + [0, 1], + [0, 1], + [0, 1], + [4, 1], + [2, 1], + [4, 1], + [3, 1], + [1, 1], + ] + ), + observations=np.array( + [ + [1], + [2], + [3], + [4], + [1], + [4], + [3], + [1], + [1], + ] + ), + prev_horizon=8, + return_stats=True, + ) + + assert len(with_stats) == 2 + + +def test_has(): + """Tests to see if HAS metrics are calculated correctly.""" + rng = np.random.default_rng(seed=42) + + messages = np.array( + [ + [0, 1, 1], + [0, 1, 2], + [0, 1, 3], + [0, 1, 1], + [4, 1, 1], + [2, 1, 3], + [4, 1, 2], + [3, 1, 4], + [1, 1, 5], + ] + ) + + alpha, freq = metrics.has_init(messages) + + be = metrics.compute_branching_entropy(alpha, freq) + + metrics.compute_conditional_entropy(be, freq) + + boundaries = metrics.compute_boundaries(messages, be, 0.5) + + metrics.compute_segments(messages, boundaries) + + random_boundaries = metrics.compute_random_boundaries(messages, boundaries, rng) + + metrics.compute_segments(messages, random_boundaries)