diff --git a/lit_nlp/api/model.py b/lit_nlp/api/model.py
index 71adac90..f0533e01 100644
--- a/lit_nlp/api/model.py
+++ b/lit_nlp/api/model.py
@@ -194,19 +194,6 @@ def get_embedding_table(self) -> tuple[list[str], np.ndarray]:
     raise NotImplementedError('get_embedding_table() not implemented for ' +
                               self.__class__.__name__)

-  def fit_transform(
-      self, inputs: Iterable[types.JsonDict]
-  ) -> Iterable[types.JsonDict]:
-    """For internal use by UMAP and other sklearn-based models."""
-    raise NotImplementedError(
-        'fit_transform() not implemented for ' + self.__class__.__name__)
-
-  def fit_transform_with_metadata(
-      self, indexed_inputs: Iterable[types.IndexedInput]
-  ) -> Iterable[types.JsonDict]:
-    """For internal use by UMAP and other sklearn-based models."""
-    return self.fit_transform((ii['data'] for ii in indexed_inputs))
-
   ##
   # Concrete implementations of common functions.
   def predict(self, inputs: Iterable[JsonDict], **kw) -> Iterable[JsonDict]:
@@ -301,11 +288,6 @@ def output_spec(self) -> types.Spec:
   def get_embedding_table(self) -> tuple[list[str], np.ndarray]:
     return self.wrapped.get_embedding_table()

-  def fit_transform_with_metadata(
-      self, indexed_inputs: Iterable[types.IndexedInput]
-  ):
-    return self.wrapped.fit_transform_with_metadata(indexed_inputs)
-

 class BatchedRemoteModel(Model):
   """Generic base class for remotely-hosted models.
@@ -363,3 +345,27 @@ def predict_minibatch(self, inputs: list[JsonDict]) -> list[JsonDict]:
       list of outputs, following model.output_spec()
     """
     return
+
+
+class ProjectorModel(Model, metaclass=abc.ABCMeta):
+  """LIT Model API for dimensionality reduction."""
+
+  ##
+  # Training methods
+  @abc.abstractmethod
+  def fit_transform(self, inputs: Iterable[JsonDict]) -> list[JsonDict]:
+    """For internal use by SciKit Learn-based models."""
+    pass
+
+  ##
+  # LIT model API
+  def input_spec(self):
+    # 'x' denotes input features
+    return {'x': types.Embeddings()}
+
+  def output_spec(self):
+    # 'z' denotes projected embeddings
+    return {'z': types.Embeddings()}
+
+  def max_minibatch_size(self, **unused_kw):
+    return 1000
diff --git a/lit_nlp/components/pca.py b/lit_nlp/components/pca.py
index fe13e919..178a36b5 100644
--- a/lit_nlp/components/pca.py
+++ b/lit_nlp/components/pca.py
@@ -15,12 +15,12 @@
 """Implementation of PCA as a dimensionality reduction model."""

 from absl import logging
-from lit_nlp.components import projection
+from lit_nlp.api import model
 from lit_nlp.lib import utils
 import numpy as np


-class PCAModel(projection.ProjectorModel):
+class PCAModel(model.ProjectorModel):
   """LIT model API implementation for PCA."""

   def __init__(self, **pca_kw):
@@ -57,7 +57,7 @@ def fit_transform(self, inputs):
   # LIT model API
   def predict_minibatch(self, inputs, **unused_kw):
     if not self._fitted:
-      return ({"z": [0, 0, 0]} for i in inputs)
+      return ({"z": [0, 0, 0]} for _ in inputs)
     x = np.stack([i["x"] for i in inputs])
     x = x - self._mean
     zs = np.dot(x, self._evecs)
diff --git a/lit_nlp/components/projection.py b/lit_nlp/components/projection.py
index 92a74a7f..d3ff815b 100644
--- a/lit_nlp/components/projection.py
+++ b/lit_nlp/components/projection.py
@@ -33,8 +33,7 @@
 projections).
 """

-import abc
-from collections.abc import Iterable, Hashable, Sequence
+from collections.abc import Hashable, Sequence
 import threading
 from typing import Optional

@@ -51,35 +50,6 @@
 Spec = types.Spec


-class ProjectorModel(lit_model.Model, metaclass=abc.ABCMeta):
-  """LIT model API implementation for dimensionality reduction."""
-
-  ##
-  # Training methods
-  @abc.abstractmethod
-  def fit_transform(self, inputs: Iterable[JsonDict]) -> list[JsonDict]:
-    pass
-
-  ##
-  # LIT model API
-  def input_spec(self):
-    # 'x' denotes input features
-    return {"x": types.Embeddings()}
-
-  def output_spec(self):
-    # 'z' denotes projected embeddings
-    return {"z": types.Embeddings()}
-
-  @abc.abstractmethod
-  def predict_minibatch(
-      self, inputs: Iterable[JsonDict], **unused_kw
-  ) -> list[JsonDict]:
-    pass
-
-  def max_minibatch_size(self, **unused_kw):
-    return 1000
-
-
 class ProjectionInterpreter(lit_components.Interpreter):
   """Interpreter API implementation for dimensionality reduction model."""

@@ -88,7 +58,7 @@ def __init__(
       self,
       model: lit_model.Model,
       inputs: Sequence[JsonDict],
       model_outputs: Optional[list[JsonDict]],
-      projector: ProjectorModel,
+      projector: lit_model.ProjectorModel,
       field_name: str,
       name: str,
   ):
@@ -166,7 +136,7 @@ class ProjectionManager(lit_components.Interpreter):
   this is not explicitly enforced.
   """

-  def __init__(self, model_class: type[ProjectorModel]):
+  def __init__(self, model_class: type[lit_model.ProjectorModel]):
     self._lock = threading.RLock()
     self._instances: dict[Hashable, ProjectionInterpreter] = {}
     # Used to construct new instances, given config['proj_kw']
diff --git a/lit_nlp/components/umap.py b/lit_nlp/components/umap.py
index 2e40dc95..4e52f156 100644
--- a/lit_nlp/components/umap.py
+++ b/lit_nlp/components/umap.py
@@ -15,13 +15,13 @@
 """Implementation of UMAP as a dimensionality reduction model."""

 from absl import logging
-from lit_nlp.components import projection
+from lit_nlp.api import model
 from lit_nlp.lib import utils
 import numpy as np
 import umap


-class UmapModel(projection.ProjectorModel):
+class UmapModel(model.ProjectorModel):
   """LIT model API implementation for UMAP."""

   def __init__(self, **umap_kw):
diff --git a/lit_nlp/lib/caching.py b/lit_nlp/lib/caching.py
index 05858bd8..d531415e 100644
--- a/lit_nlp/lib/caching.py
+++ b/lit_nlp/lib/caching.py
@@ -227,9 +227,15 @@ def key_fn(self, d) -> CacheKey:
   ##
   # For internal use
   def fit_transform(self, inputs: Iterable[types.JsonDict]):
-    """For use with UMAP and other preprocessing transforms."""
+    """Cache projections from ProjectorModel dimensionality reducers."""
+    wrapped = self.wrapped
+    if not isinstance(wrapped, lit_model.ProjectorModel):
+      raise TypeError(
+          "Attempted to call fit_transform() on a non-ProjectorModel."
+      )
+
     inputs_as_list = list(inputs)
-    outputs = list(self.wrapped.fit_transform(inputs_as_list))
+    outputs = list(wrapped.fit_transform(inputs_as_list))
     with self._cache.lock:
       for inp, output in zip(inputs_as_list, outputs):
         self._cache.put(output, self.key_fn(inp))