Skip to content

Commit

Permalink
Merge branch 'main' into feat/zir-score-samples
Browse files Browse the repository at this point in the history
  • Loading branch information
FBruzzesi authored Jul 3, 2024
2 parents 3fa5eee + 439c9c0 commit 6bf2a6c
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 46 deletions.
53 changes: 37 additions & 16 deletions sklego/linear_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,8 +426,9 @@ class _FairClassifier(BaseEstimator, LinearClassifierMixin):
A list of column names or column indexes in the input data that represent sensitive attributes.
C : float, default=1.0
Inverse of regularization strength; must be a positive float. Smaller values specify stronger regularization.
penalty : Literal["l1", "none"], default="l1"
The type of penalty to apply to the model. "l1" applies L1 regularization, while "none" disables regularization.
penalty : Literal["l1", "l2", "none", None], default="l1"
The type of penalty to apply to the model. "l1" applies L1 regularization, "l2" applies L2 regularization,
while None (or "none") disables regularization.
fit_intercept : bool, default=True
Whether or not to fit an intercept term. If True, an intercept term is added to the model.
max_iter : int, default=100
Expand All @@ -451,6 +452,8 @@ class _FairClassifier(BaseEstimator, LinearClassifierMixin):
`_FairClassifier` should not be used directly; it serves as a base class for fair classification models.
"""

_ALLOWED_PENALTIES = ("l1", "l2", "none", None)

def __init__(
self,
sensitive_cols=None,
Expand Down Expand Up @@ -487,22 +490,38 @@ def fit(self, X, y):
Raises
------
ValueError
If `penalty` is not one of "l1" or "none".
If `penalty` is not one of "l1", "l2", "none" or None.
"""
if self.penalty not in ["l1", "none"]:
raise ValueError(f"penalty should be either 'l1' or 'none', got {self.penalty}")
if self.penalty not in self._ALLOWED_PENALTIES:
raise ValueError(f"penalty should be one of {self._ALLOWED_PENALTIES}, got {self.penalty}")

if self.penalty == "none":
warn(
"Please use `penalty=None` instead of `penalty='none'`, 'none' will be deprecated in future versions",
DeprecationWarning,
)

self.sensitive_col_idx_ = self.sensitive_cols
X = nw.from_native(X, eager_only=True, strict=False)

if isinstance(X, nw.DataFrame):
self.sensitive_col_idx_ = [i for i, name in enumerate(X.columns) if name in self.sensitive_cols]

X, y = check_X_y(X, y, accept_large_sparse=False)
sensitive = X[:, self.sensitive_col_idx_]

if not self.train_sensitive_cols:
X = np.delete(X, self.sensitive_col_idx_, axis=1)

X = self._add_intercept(X)
if self.fit_intercept:
X = np.c_[np.ones(len(X)), X]

if X.shape[1] == 0:
msg = "Cannot fit the model, at least 1 feature(s) is required."
raise ValueError(msg)

column_or_1d(y)

label_encoder = LabelEncoder().fit(y)
y = label_encoder.transform(y)
self.classes_ = label_encoder.classes_
Expand All @@ -529,8 +548,12 @@ def _solve(self, sensitive, X, y):
cp.multiply(y, y_hat)
- cp.log_sum_exp(cp.hstack([np.zeros((n_obs, 1)), cp.reshape(y_hat, (n_obs, 1))]), axis=1)
)

if self.penalty == "l1":
log_likelihood -= cp.sum((1 / self.C) * cp.norm(theta[1:]))
log_likelihood -= cp.norm(theta[int(self.fit_intercept) :], 1) / self.C

elif self.penalty == "l2":
log_likelihood -= cp.norm(theta[int(self.fit_intercept) :], 2) / self.C

constraints = self.constraints(y_hat, y, sensitive, n_obs)

Expand All @@ -540,7 +563,7 @@ def _solve(self, sensitive, X, y):
kwargs = {"max_iters": self.max_iter}
else:
if self.max_iter:
logging.warning("solver does not support `max_iters` and it `self.max_iter` will be ignored")
logging.warning("solver does not support `max_iters` and the argument will be ignored")
kwargs = {}

problem.solve(**kwargs)
Expand Down Expand Up @@ -583,10 +606,6 @@ def decision_function(self, X):
X = np.delete(X, self.sensitive_col_idx_, axis=1)
return super().decision_function(X)

def _add_intercept(self, X):
if self.fit_intercept:
return np.c_[np.ones(len(X)), X]

def _more_tags(self):
return {"poor_score": True}

Expand Down Expand Up @@ -618,8 +637,9 @@ class DemographicParityClassifier(BaseEstimator, LinearClassifierMixin):
C : float, default=1.0
Inverse of regularization strength; must be a positive float. Like in support vector machines, smaller values
specify stronger regularization.
penalty : Literal["l1", "none"], default="l1"
Used to specify the norm used in the penalization.
penalty : Literal["l1", "l2", "none", None], default="l1"
The type of penalty to apply to the model. "l1" applies L1 regularization, "l2" applies L2 regularization,
while None (or "none") disables regularization.
fit_intercept : bool, default=True
Whether or not a constant term (a.k.a. bias or intercept) should be added to the decision function.
max_iter : int, default=100
Expand Down Expand Up @@ -712,8 +732,9 @@ class EqualOpportunityClassifier(BaseEstimator, LinearClassifierMixin):
C : float, default=1.0
Inverse of regularization strength; must be a positive float. Like in support vector machines, smaller values
specify stronger regularization.
penalty : Literal["l1", "none"], default="l1"
Used to specify the norm used in the penalization.
penalty : Literal["l1", "l2", "none", None], default="l1"
The type of penalty to apply to the model. "l1" applies L1 regularization, "l2" applies L2 regularization,
while None (or "none") disables regularization.
fit_intercept : bool, default=True
Whether or not a constant term (a.k.a. bias or intercept) should be added to the decision function.
max_iter : int, default=100
Expand Down
15 changes: 1 addition & 14 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,26 +62,13 @@ def random_xy_dataset_multitarget(request):
return X, y


@pytest.fixture
def sensitive_classification_dataset():
df = pd.DataFrame(
{
"x1": [1, 0, 1, 0, 1, 0, 1, 1],
"x2": [0, 0, 0, 0, 0, 1, 1, 1],
"y": [1, 1, 1, 0, 1, 0, 0, 0],
}
)

return df[["x1", "x2"]], df["y"]


@pytest.fixture(params=[pd.DataFrame, pl.DataFrame])
def funct(request):
return request.param


@pytest.fixture
def sensitive_classification_dataset_equalopportunity(funct):
def sensitive_classification_dataset(funct):
df = funct(
{
"x1": [1, 0, 1, 0, 1, 0, 1, 1],
Expand Down
20 changes: 13 additions & 7 deletions tests/test_estimators/test_demographic_parity.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
train_sensitive_cols=train_sensitive_cols,
)
for train_sensitive_cols in [True, False]
for penalty in ["none", "l1"]
for penalty in ["l1", "l2", None]
]
)
def test_sklearn_compatible_estimator(estimator, check):
Expand Down Expand Up @@ -100,21 +100,27 @@ def test_same_logistic_multiclass(random_xy_dataset_multiclf):
_test_same(random_xy_dataset_multiclf)


def test_regularization(sensitive_classification_dataset_equalopportunity):
@pytest.mark.parametrize("penalty", ["l1", "l2"])
def test_regularization(sensitive_classification_dataset, penalty):
"""Tests whether increasing regularization decreases the norm of the coefficient vector"""
X, y = sensitive_classification_dataset_equalopportunity
X, y = sensitive_classification_dataset

prev_theta_norm = np.inf
for C in [1, 0.5, 0.1, 0.05]:
fair = DemographicParityClassifier(covariance_threshold=None, sensitive_cols=["x1"], C=C).fit(X, y)
theta_norm = np.sum(np.abs(fair.estimators_[0].coef_))
fair = DemographicParityClassifier(
covariance_threshold=None,
sensitive_cols=["x1"],
C=C,
penalty=penalty,
).fit(X, y)
theta_norm = np.linalg.norm(fair.estimators_[0].coef_, ord=int(penalty[-1]))
assert theta_norm < prev_theta_norm
prev_theta_norm = theta_norm


def test_fairness(sensitive_classification_dataset_equalopportunity):
def test_fairness(sensitive_classification_dataset):
"""tests whether fairness (measured by p percent score) increases as we decrease the covariance threshold"""
X, y = sensitive_classification_dataset_equalopportunity
X, y = sensitive_classification_dataset
scorer = p_percent_score("x1")

prev_fairness = -np.inf
Expand Down
19 changes: 12 additions & 7 deletions tests/test_estimators/test_equal_opportunity.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
train_sensitive_cols=train_sensitive_cols,
)
for train_sensitive_cols in [True, False]
for penalty in ["none", "l1"]
for penalty in ["l1", "l2", None]
]
)
def test_sklearn_compatible_estimator(estimator, check):
Expand Down Expand Up @@ -97,23 +97,28 @@ def test_same_logistic_multiclass(random_xy_dataset_multiclf):
_test_same(random_xy_dataset_multiclf)


def test_regularization(sensitive_classification_dataset_equalopportunity):
@pytest.mark.parametrize("penalty", ["l1", "l2"])
def test_regularization(sensitive_classification_dataset, penalty):
"""Tests whether increasing regularization decreases the norm of the coefficient vector"""
X, y = sensitive_classification_dataset_equalopportunity
X, y = sensitive_classification_dataset

prev_theta_norm = np.inf
for C in [1, 0.5, 0.1, 0.05]:
fair = EqualOpportunityClassifier(
covariance_threshold=None, sensitive_cols=["x1"], C=C, positive_target=True
covariance_threshold=None,
sensitive_cols=["x1"],
C=C,
positive_target=True,
penalty=penalty,
).fit(X, y)
theta_norm = np.sum(np.abs(fair.estimators_[0].coef_))
theta_norm = np.linalg.norm(fair.estimators_[0].coef_, ord=int(penalty[-1]))
assert theta_norm < prev_theta_norm
prev_theta_norm = theta_norm


def test_fairness(sensitive_classification_dataset_equalopportunity):
def test_fairness(sensitive_classification_dataset):
"""tests whether fairness (measured by p percent score) increases as we decrease the covariance threshold"""
X, y = sensitive_classification_dataset_equalopportunity
X, y = sensitive_classification_dataset
scorer = equal_opportunity_score("x1")

prev_fairness = -np.inf
Expand Down
2 changes: 1 addition & 1 deletion tests/test_metrics/test_equal_opportunity.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ def test_p_percent_pandas_multiclass():

def test_p_percent_numpy(sensitive_classification_dataset):
X, y = sensitive_classification_dataset
X = X.values
X, y = X.to_numpy(), y.to_numpy()
mod = LogisticRegression().fit(X, y)
assert equal_opportunity_score(1)(mod, X, y) == 0

Expand Down
2 changes: 1 addition & 1 deletion tests/test_metrics/test_p_percent.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def test_p_percent_pandas_multiclass(sensitive_multiclass_classification_dataset

def test_p_percent_numpy(sensitive_classification_dataset):
X, y = sensitive_classification_dataset
X = X.values
X, y = X.to_numpy(), y.to_numpy()
mod = LogisticRegression().fit(X, y)
assert p_percent_score(1)(mod, X) == 0

Expand Down

0 comments on commit 6bf2a6c

Please sign in to comment.