Skip to content

Commit

Permalink
add dagma and poly loss multi-device training fix, correct tests to u…
Browse files Browse the repository at this point in the history
…se default arch params, allow no-val-split training with easy modules
  • Loading branch information
cnellington committed Nov 3, 2024
1 parent 4527298 commit 630f43b
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 32 deletions.
4 changes: 2 additions & 2 deletions contextualized/dags/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@


def dag_loss_dagma_indiv(w, s=1):
M = s * torch.eye(w.shape[-1]) - w * w
M = s * torch.eye(w.shape[-1]).to(w.device) - w * w
return w.shape[-1] * np.log(s) - torch.slogdet(M)[1]


Expand All @@ -18,7 +18,7 @@ def dag_loss_dagma(W, s=1, alpha=0.0, **kwargs):

def dag_loss_poly_indiv(w):
d = w.shape[-1]
return torch.trace((torch.eye(d) + (1 / d) * torch.matmul(w, w))^d) - d
return torch.trace((torch.eye(d).to(w.device) + (1 / d) * torch.matmul(w, w))^d) - d


def dag_loss_poly(W, **kwargs):
Expand Down
30 changes: 7 additions & 23 deletions contextualized/dags/tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from pytorch_lightning.callbacks import LearningRateFinder


from contextualized.dags.lightning_modules import NOTMAD
from contextualized.dags.lightning_modules import NOTMAD, DEFAULT_SS_PARAMS, DEFAULT_ARCH_PARAMS
from contextualized.dags import graph_utils
from contextualized.dags.trainers import GraphTrainer
from contextualized.dags.losses import mse_loss as mse
Expand Down Expand Up @@ -37,26 +37,18 @@ def _train(self, model_args, n_epochs):
model = NOTMAD(
self.C.shape[-1],
self.X.shape[-1],
archetype_params={
archetype_loss_params={
"l1": 0.0,
"dag": model_args.get(
"dag",
{
"loss_type": "NOTEARS",
"params": {
"alpha": 1e-1,
"rho": 1e-2,
"h_old": 0.0,
"tol": 0.25,
"use_dynamic_alpha_rho": True,
},
},
DEFAULT_ARCH_PARAMS["dag"],
),
"init_mat": INIT_MAT,
"num_factors": model_args.get("num_factors", 0),
"factor_mat_l1": 0.0,
"num_archetypes": model_args.get("num_archetypes", k),
},
# Todo: add sample-specific params
)
dataloader = model.dataloader(self.C, self.X, batch_size=1, num_workers=0)
trainer = GraphTrainer(
Expand Down Expand Up @@ -181,26 +173,18 @@ def _train(self, model_args, n_epochs):
model = NOTMAD(
self.C.shape[-1],
self.X.shape[-1],
archetype_params={
archetype_loss_params={
"l1": 0.0,
"dag": model_args.get(
"dag",
{
"loss_type": "NOTEARS",
"params": {
"alpha": 1e-1,
"rho": 1e-2,
"h_old": 0.0,
"tol": 0.25,
"use_dynamic_alpha_rho": True,
},
},
DEFAULT_ARCH_PARAMS["dag"],
),
"init_mat": INIT_MAT,
"num_factors": model_args.get("num_factors", 0),
"factor_mat_l1": 0.0,
"num_archetypes": model_args.get("num_archetypes", k),
},
# TODO: Add sample-specific params
)
train_dataloader = model.dataloader(
self.C_train, self.X_train, batch_size=1, num_workers=0
Expand Down
22 changes: 15 additions & 7 deletions contextualized/easy/wrappers/SKLearnWrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -335,7 +335,7 @@ def _split_train_data(self, C, X, Y=None, Y_required=False, **kwargs):
else:
print("X_val not provided, not using the provided C_val.")
if "val_split" in kwargs:
if 0 < kwargs["val_split"] < 1:
if 0 <= kwargs["val_split"] < 1:
val_split = kwargs["val_split"]
else:
print(
Expand All @@ -346,15 +346,23 @@ def _split_train_data(self, C, X, Y=None, Y_required=False, **kwargs):
else:
val_split = self.default_val_split
if Y is None:
C_train, C_val, X_train, X_val = train_test_split(
C, X, test_size=val_split, shuffle=True
)
if val_split > 0:
C_train, C_val, X_train, X_val = train_test_split(
C, X, test_size=val_split, shuffle=True
)
else:
C_train, X_train = C, X
C_val, X_val = C, X
train_data = [C_train, X_train]
val_data = [C_val, X_val]
else:
C_train, C_val, X_train, X_val, Y_train, Y_val = train_test_split(
C, X, Y, test_size=val_split, shuffle=True
)
if val_split > 0:
C_train, C_val, X_train, X_val, Y_train, Y_val = train_test_split(
C, X, Y, test_size=val_split, shuffle=True
)
else:
C_train, X_train, Y_train = C, X, Y
C_val, X_val, Y_val = C, X, Y
train_data = [C_train, X_train, Y_train]
val_data = [C_val, X_val, Y_val]
return train_data, val_data
Expand Down

0 comments on commit 630f43b

Please sign in to comment.