Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Restrict HDBSCAN metric options to L2 #5415 #5492

Merged
merged 9 commits into from
Aug 1, 2023
6 changes: 1 addition & 5 deletions python/cuml/cluster/hdbscan/hdbscan.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -165,12 +165,8 @@ cdef extern from "cuml/cluster/hdbscan.hpp" namespace "ML::HDBSCAN::HELPER":
float cluster_selection_epsilon) except +

_metrics_mapping = {
'l1': DistanceType.L1,
'cityblock': DistanceType.L1,
'manhattan': DistanceType.L1,
'l2': DistanceType.L2SqrtExpanded,
'euclidean': DistanceType.L2SqrtExpanded,
'cosine': DistanceType.CosineExpanded
}


Expand Down Expand Up @@ -838,7 +834,7 @@ class HDBSCAN(UniversalBase, ClusterMixin, CMajorInputTagMixin):
if self.metric in _metrics_mapping:
metric = _metrics_mapping[self.metric]
else:
raise ValueError("'affinity' %s not supported." % self.affinity)
raise ValueError(f"metric '{self.metric}' not supported, only 'l2' and 'euclidean' are currently supported")

cdef uintptr_t core_dists_ptr = self.core_dists.ptr

Expand Down
19 changes: 19 additions & 0 deletions python/cuml/tests/test_hdbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -475,6 +475,25 @@ def test_hdbscan_core_dists_bug_4054():
assert adjusted_rand_score(cu_labels_, sk_labels_) > 0.99


@pytest.mark.parametrize(
"metric, supported",
[("euclidean", True), ("l1", False), ("l2", True), ("abc", False)],
)
def test_hdbscan_metric_parameter_input(metric, supported):
"""
tests how valid and invalid arguments to the metric
parameter are handled
"""
X, y = make_blobs(n_samples=10000, n_features=15, random_state=12)

clf = HDBSCAN(metric=metric)
if supported:
clf.fit(X)
else:
with pytest.raises(ValueError):
clf.fit(X)


def test_hdbscan_empty_cluster_tree():

raw_tree = np.recarray(
Expand Down