diff --git a/contextualized/dags/losses.py b/contextualized/dags/losses.py
index f149815..e2837da 100644
--- a/contextualized/dags/losses.py
+++ b/contextualized/dags/losses.py
@@ -4,7 +4,7 @@


 def dag_loss_dagma_indiv(w, s=1):
-    M = s * torch.eye(w.shape[-1]) - w * w
+    M = s * torch.eye(w.shape[-1]).to(w.device) - w * w
     return w.shape[-1] * np.log(s) - torch.slogdet(M)[1]
@@ -18,7 +18,7 @@ def dag_loss_dagma(W, s=1, alpha=0.0, **kwargs):


 def dag_loss_poly_indiv(w):
     d = w.shape[-1]
-    return torch.trace((torch.eye(d) + (1 / d) * torch.matmul(w, w))^d) - d
+    return torch.trace((torch.eye(d).to(w.device) + (1 / d) * torch.matmul(w, w)) ** d) - d


 def dag_loss_poly(W, **kwargs):
diff --git a/contextualized/dags/tests.py b/contextualized/dags/tests.py
index ad62257..2f6d8fd 100644
--- a/contextualized/dags/tests.py
+++ b/contextualized/dags/tests.py
@@ -9,7 +9,7 @@
 from pytorch_lightning.callbacks import LearningRateFinder

-from contextualized.dags.lightning_modules import NOTMAD
+from contextualized.dags.lightning_modules import NOTMAD, DEFAULT_SS_PARAMS, DEFAULT_ARCH_PARAMS
 from contextualized.dags import graph_utils
 from contextualized.dags.trainers import GraphTrainer
 from contextualized.dags.losses import mse_loss as mse
@@ -37,26 +37,21 @@ def _train(self, model_args, n_epochs):
         model = NOTMAD(
             self.C.shape[-1],
             self.X.shape[-1],
-            archetype_params={
+            archetype_loss_params={
                 "l1": 0.0,
                 "dag": model_args.get(
                     "dag",
-                    {
-                        "loss_type": "NOTEARS",
-                        "params": {
-                            "alpha": 1e-1,
-                            "rho": 1e-2,
-                            "h_old": 0.0,
-                            "tol": 0.25,
-                            "use_dynamic_alpha_rho": True,
-                        },
-                    },
+                    DEFAULT_ARCH_PARAMS["dag"],
                 ),
                 "init_mat": INIT_MAT,
                 "num_factors": model_args.get("num_factors", 0),
                 "factor_mat_l1": 0.0,
                 "num_archetypes": model_args.get("num_archetypes", k),
             },
+            sample_specific_loss_params={
+                "l1": 0.0,
+                "dag": DEFAULT_SS_PARAMS["dag"],
+            },
         )
         dataloader = model.dataloader(self.C, self.X, batch_size=1, num_workers=0)
         trainer = GraphTrainer(
@@ -181,26 +176,21 @@ def _train(self, model_args, n_epochs):
         model = NOTMAD(
             self.C.shape[-1],
             self.X.shape[-1],
-            archetype_params={
+            archetype_loss_params={
                 "l1": 0.0,
                 "dag": model_args.get(
                     "dag",
-                    {
-                        "loss_type": "NOTEARS",
-                        "params": {
-                            "alpha": 1e-1,
-                            "rho": 1e-2,
-                            "h_old": 0.0,
-                            "tol": 0.25,
-                            "use_dynamic_alpha_rho": True,
-                        },
-                    },
+                    DEFAULT_ARCH_PARAMS["dag"],
                 ),
                 "init_mat": INIT_MAT,
                 "num_factors": model_args.get("num_factors", 0),
                 "factor_mat_l1": 0.0,
                 "num_archetypes": model_args.get("num_archetypes", k),
             },
+            sample_specific_loss_params={
+                "l1": 0.0,
+                "dag": DEFAULT_SS_PARAMS["dag"],
+            },
         )
         train_dataloader = model.dataloader(
             self.C_train, self.X_train, batch_size=1, num_workers=0
         )
diff --git a/contextualized/easy/wrappers/SKLearnWrapper.py b/contextualized/easy/wrappers/SKLearnWrapper.py
index 96fc5fc..607ab0e 100644
--- a/contextualized/easy/wrappers/SKLearnWrapper.py
+++ b/contextualized/easy/wrappers/SKLearnWrapper.py
@@ -335,7 +335,7 @@ def _split_train_data(self, C, X, Y=None, Y_required=False, **kwargs):
             else:
                 print("X_val not provided, not using the provided C_val.")
         if "val_split" in kwargs:
-            if 0 < kwargs["val_split"] < 1:
+            if 0 <= kwargs["val_split"] < 1:
                 val_split = kwargs["val_split"]
             else:
                 print(
@@ -346,15 +346,23 @@
             else:
                 val_split = self.default_val_split
         if Y is None:
-            C_train, C_val, X_train, X_val = train_test_split(
-                C, X, test_size=val_split, shuffle=True
-            )
+            if val_split > 0:
+                C_train, C_val, X_train, X_val = train_test_split(
+                    C, X, test_size=val_split, shuffle=True
+                )
+            else:
+                C_train, X_train = C, X
+                C_val, X_val = C, X
             train_data = [C_train, X_train]
             val_data = [C_val, X_val]
         else:
-            C_train, C_val, X_train, X_val, Y_train, Y_val = train_test_split(
-                C, X, Y, test_size=val_split, shuffle=True
-            )
+            if val_split > 0:
+                C_train, C_val, X_train, X_val, Y_train, Y_val = train_test_split(
+                    C, X, Y, test_size=val_split, shuffle=True
+                )
+            else:
+                C_train, X_train, Y_train = C, X, Y
+                C_val, X_val, Y_val = C, X, Y
             train_data = [C_train, X_train, Y_train]
             val_data = [C_val, X_val, Y_val]
         return train_data, val_data
diff --git a/pyproject.toml b/pyproject.toml
index ee912bf..c01aaff 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,7 +26,7 @@ keywords = [
 ]
 dependencies = [
     'lightning>=2.0.0',
-    'torch>=2.0.0,<2.2.0',
+    'torch>=2.0.0',
     'torchvision>=0.8.0',
     'numpy>=1.19.0',
     'pandas>=2.0.0',
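
Two notes on the losses.py hunks. The `^` in the old `dag_loss_poly_indiv` was Python's bitwise XOR, which is not defined for float tensors, so that line could never execute; `** d` is an elementwise power (if the polynomial DAG penalty is intended as a matrix power, `torch.linalg.matrix_power` would be the alternative call). The `.to(w.device)` changes only matter on GPU, as in this minimal sketch (hypothetical shapes, assumes a CUDA build of torch):

import torch

# A 3-node weight matrix resident on the GPU.
w = torch.randn(3, 3, device="cuda")
s = 1.0

# Old line: torch.eye allocates on the CPU, so combining it with the
# GPU-resident w * w raises a device-mismatch RuntimeError:
#   M = s * torch.eye(w.shape[-1]) - w * w
# New line: the identity follows w onto its device:
M = s * torch.eye(w.shape[-1]).to(w.device) - w * w
print(M.device)  # cuda:0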
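
In the tests.py hunks, the inline NOTEARS dict deleted from the tests is presumably what `DEFAULT_ARCH_PARAMS["dag"]` now centralizes in `lightning_modules` (the authoritative default values live in that module). A sketch of inspecting the new defaults:

from contextualized.dags.lightning_modules import (
    DEFAULT_ARCH_PARAMS,
    DEFAULT_SS_PARAMS,
)

# The literal removed from the tests, for comparison:
#   {"loss_type": "NOTEARS",
#    "params": {"alpha": 1e-1, "rho": 1e-2, "h_old": 0.0,
#               "tol": 0.25, "use_dynamic_alpha_rho": True}}
print(DEFAULT_ARCH_PARAMS["dag"])
print(DEFAULT_SS_PARAMS["dag"])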
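
The SKLearnWrapper change permits `val_split=0`: instead of calling `train_test_split`, the full arrays are reused as both the training and validation sets. A hedged usage sketch (assumes the easy-wrapper `ContextualizedRegressor`, whose `fit` forwards keyword arguments to `_split_train_data`):

import numpy as np
from contextualized.easy import ContextualizedRegressor

C = np.random.normal(size=(100, 4))   # contexts
X = np.random.normal(size=(100, 10))  # predictors
Y = np.random.normal(size=(100, 1))   # targets

model = ContextualizedRegressor()
# val_split=0 now bypasses train_test_split; (C, X, Y) serve as both
# the training and the validation data.
model.fit(C, X, Y, val_split=0, max_epochs=1)

Note the trade-off: with `val_split=0` the validation metrics are computed on the training data, so they no longer estimate held-out performance.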