-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathDIBcont.R
99 lines (82 loc) · 3.27 KB
/
DIBcont.R
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
# Cluster continuous data with the Deterministic Information Bottleneck (DIB).
#
# Arguments:
#   X               Data whose columns are all treated as continuous
#                   (anything that supports nrow()/ncol()/as.data.frame()).
#   ncl             Number of clusters; a single whole number greater than 1.
#   randinit        NULL (default) or a numeric vector with one entry per row
#                   of X; passed through to DIBmix_iterate() unchanged.
#   s               Bandwidth: a single value (-1 = automatic selection,
#                   otherwise > 0) or a vector of positive values, one per
#                   column of X.
#   scale           If TRUE (default), every column of X is standardised with
#                   scale() before any other computation.
#   maxiter         Positive whole number passed to DIBmix_iterate().
#   nstart          Positive whole number passed to DIBmix_iterate() as `runs`.
#   select_features If TRUE, the per-column bandwidths are refined by
#                   eigengap() before clustering.
#
# Returns the value of DIBmix_iterate(). A warning is raised when its third
# element equals Inf (treated here as "initial assignment never changed").
DIBcont <- function(X, ncl, randinit = NULL, s = -1, scale = TRUE,
                    maxiter = 100, nstart = 100, select_features = FALSE) {
  # TRUE when x is a single finite whole number (rejects NA, NaN, Inf and
  # vectors, so the comparisons below only ever see a scalar).
  is_whole_scalar <- function(x) {
    is.numeric(x) && length(x) == 1L && is.finite(x) && x == round(x)
  }
  # TRUE when x is a single non-missing logical.
  is_flag <- function(x) {
    is.logical(x) && length(x) == 1L && !is.na(x)
  }

  # Validate inputs ----
  if (!is_whole_scalar(ncl) || ncl <= 1) {
    stop("Input 'ncl' must be a positive integer greater than 1.")
  }
  if (!is_flag(scale)) {
    stop("'scale' must be a logical value (TRUE or FALSE).")
  }
  if (!is_whole_scalar(maxiter) || maxiter <= 0) {
    stop("'maxiter' must be a positive integer.")
  }
  if (!is_whole_scalar(nstart) || nstart <= 0) {
    stop("'nstart' must be a positive integer.")
  }
  if (!is_flag(select_features)) {
    stop("'select_features' must be a logical value (TRUE or FALSE).")
  }
  # isTRUE() guards against nrow(X) being NULL (logical(0) comparison).
  if (!is.null(randinit) &&
      (!is.numeric(randinit) || !isTRUE(length(randinit) == nrow(X)))) {
    stop("'randinit' must be a numeric vector with length equal to the number of rows in 'X', or NULL.")
  }
  # s: either one value (-1 = automatic, otherwise > 0) or one strictly
  # positive value per column. The -1 sentinel is only valid as a scalar.
  s_ok <- is.numeric(s) && length(s) >= 1L && !anyNA(s)
  if (s_ok) {
    s_ok <- if (length(s) == 1L) {
      s > 0 || s == -1
    } else {
      isTRUE(length(s) == ncol(X)) && all(s > 0)
    }
  }
  if (!s_ok) {
    stop("'s' must be either a single numeric value (-1 for automatic selection or a positive value) or a numeric vector with positive values matching the number of columns in 'X'.")
  }

  # Helpers ----
  # Standardise every column; base::scale() returns a matrix. Namespaced
  # because the logical argument `scale` shadows the name in this function.
  preprocess_cont_data <- function(X) {
    base::scale(data.frame(X))
  }

  # Automatic bandwidth search over 0.1, 0.2, ..., 10. For each candidate the
  # ratio of the largest to the second-largest entry of every column of py_x
  # is averaged; at the first candidate whose average drops below 1.1 the
  # previous candidate is returned.
  compute_bandwidth_cont <- function(X) {
    s_seq <- seq(0.1, 10, by = 1e-1)
    X_df <- as.data.frame(X)
    cols <- seq_len(ncol(X))
    for (i in seq_along(s_seq)) {
      pyx_cont <- coord_to_pxy_R(X_df, s = s_seq[i], cat_cols = c(),
                                 cont_cols = cols, lambda = 0)$py_x
      avg_py_x <- mean(apply(pyx_cont, 2,
                             function(x) max(x) / max(x[-which.max(x)])))
      if (avg_py_x < 1.1) {
        # Step back one candidate, but never below the smallest one: a
        # bandwidth of 0 would be rejected by the validation above.
        return(s_seq[max(i - 1L, 1L)])
      }
    }
    # No candidate crossed the threshold: use the largest bandwidth tested
    # rather than returning the -1 "automatic" sentinel.
    s_seq[length(s_seq)]
  }

  # Preprocessing ----
  if (scale) {
    X <- preprocess_cont_data(X)
  }
  cont_cols <- seq_len(ncol(X))

  # Bandwidth selection ----
  if (length(s) == 1L && s == -1) {
    s <- compute_bandwidth_cont(X)
  }

  # Joint probability quantities for the continuous variables ----
  pxy_list <- coord_to_pxy_R(as.data.frame(X), s = s, cat_cols = c(),
                             cont_cols = cont_cols, lambda = 0)

  # One bandwidth per column. rep_len() leaves a per-column vector intact,
  # whereas rep(s, ncol(X)) would repeat it ncol(X) times.
  bws_vec <- rep_len(s, ncol(X))
  if (select_features) {
    # Optional refinement via the eigengap heuristic.
    bws_vec <- eigengap(data = X, contcols = cont_cols, catcols = c(),
                        bw = bws_vec, ncl = ncl)
  }

  # DIB iterations ----
  best_clust <- DIBmix_iterate(X, ncl = ncl, randinit = randinit, tol = 0,
                               py_x = pxy_list$py_x, hy = pxy_list$hy,
                               px = pxy_list$px, maxiter = maxiter,
                               bws_vec = bws_vec, contcols = cont_cols,
                               catcols = c(), runs = nstart)

  # isTRUE() keeps an NA third element from erroring inside `if`.
  if (isTRUE(best_clust[[3]] == Inf)) {
    warning("Initial cluster assignment remained unchanged; use other hyperparameter values for DIBcont to converge.")
  }
  best_clust
}