Skip to content

Commit

Permalink
Introduce BaseModel that can be included into other models.
Browse files Browse the repository at this point in the history
  • Loading branch information
bojan-karlas committed May 22, 2024
1 parent 8668577 commit 0488016
Showing 1 changed file with 15 additions and 11 deletions.
26 changes: 15 additions & 11 deletions experiments/datascope/experiments/pipelines/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -397,7 +397,11 @@ def __init__(self, **kwargs: Any) -> None:
pass


class LogisticRegressionModel(Model, id="logreg", longname="Logistic Regression"):
class BaseModel(Model, abstract=True, id="base", longname="Base Model"):
pass


class LogisticRegressionModel(BaseModel, id="logreg", longname="Logistic Regression"):

def __init__(self, solver: str = "liblinear", max_iter: int = 5000, **kwargs) -> None:
self._solver = solver
Expand All @@ -416,7 +420,7 @@ def construct(self: "LogisticRegressionModel", dataset: Dataset) -> BaseEstimato
return LogisticRegression(solver=self.solver, max_iter=self.max_iter, random_state=666)


class RandomForestModel(Model, id="randf", longname="Random Forest"):
class RandomForestModel(BaseModel, id="randf", longname="Random Forest"):
def __init__(self, num_estimators: int = 50, **kwargs) -> None:
self._num_estimators = num_estimators

Expand All @@ -429,7 +433,7 @@ def construct(self: "RandomForestModel", dataset: Dataset) -> BaseEstimator:
return RandomForestClassifier(n_estimators=self.num_estimators, random_state=666)


class KNearestNeighborsModel(Model, id="knn", longname="K-Nearest Neighbors"):
class KNearestNeighborsModel(BaseModel, id="knn", longname="K-Nearest Neighbors"):
def __init__(self, num_neighbors: int = 1, **kwargs) -> None:
self._num_neighbors = num_neighbors

Expand Down Expand Up @@ -472,7 +476,7 @@ def __init__(self, **kwargs) -> None:
super().__init__(num_neighbors=100)


class SupportVectorMachineModel(Model, id="svm", longname="Support Vector Machine"):
class SupportVectorMachineModel(BaseModel, id="svm", longname="Support Vector Machine"):
def __init__(self, kernel: str = "rbf", **kwargs) -> None:
self._kernel = kernel

Expand All @@ -485,22 +489,22 @@ def construct(self: "SupportVectorMachineModel", dataset: Dataset) -> BaseEstima
return SVC(kernel=self.kernel, random_state=666)


class LinearSupportVectorMachineModel(Model, id="linsvm", longname="Linear Support Vector Machine"):
class LinearSupportVectorMachineModel(BaseModel, id="linsvm", longname="Linear Support Vector Machine"):
def construct(self: "LinearSupportVectorMachineModel", dataset: Dataset) -> BaseEstimator:
return LinearSVC(loss="hinge", random_state=666)


class GaussianProcessModel(Model, id="gp", longname="Gaussian Process"):
class GaussianProcessModel(BaseModel, id="gp", longname="Gaussian Process"):
def construct(self: "GaussianProcessModel", dataset: Dataset) -> BaseEstimator:
return GaussianProcessClassifier(random_state=666)


class NaiveBayesModel(Model, id="nb", longname="Naive Bayes Classifier"):
class NaiveBayesModel(BaseModel, id="nb", longname="Naive Bayes Classifier"):
def construct(self: "NaiveBayesModel", dataset: Dataset) -> BaseEstimator:
return MultinomialNB()


class MultilevelPerceptronModel(Model, id="mlp", longname="Multilevel Perceptron"):
class MultilevelPerceptronModel(BaseModel, id="mlp", longname="Multilevel Perceptron"):
def __init__(
self,
solver: str = "sgd",
Expand Down Expand Up @@ -567,7 +571,7 @@ def construct(self: "MultilevelPerceptronModel", dataset: Dataset) -> BaseEstima
)


class XGBoostModel(Model, id="xgb", longname="XGBoost"):
class XGBoostModel(BaseModel, id="xgb", longname="XGBoost"):
def __init__(self, num_estimators: int = 100, max_depth: int = 6, subsample: float = 1.0, **kwargs) -> None:
self._num_estimators = num_estimators
self._max_depth = max_depth
Expand Down Expand Up @@ -598,7 +602,7 @@ def construct(self: "XGBoostModel", dataset: Dataset) -> BaseEstimator:
)


class Resnet18Model(Model, id="resnet-18", longname="ResNet-18"):
class Resnet18Model(BaseModel, id="resnet-18", longname="ResNet-18"):
def __init__(
self,
num_epochs: int = 10,
Expand Down Expand Up @@ -634,6 +638,6 @@ def construct(self: "Resnet18Model", dataset: Dataset) -> BaseEstimator:
)


class MatchingNetworkModel(Model, id="matchingnet", longname="Matching Network"):
class MatchingNetworkModel(BaseModel, id="matchingnet", longname="Matching Network"):
def construct(self: "MatchingNetworkModel", dataset: Dataset) -> BaseEstimator:
return MatchingNetworkClassifier()

0 comments on commit 0488016

Please sign in to comment.