Skip to content

Commit

Permalink
Complete sparse matrix support (per issue #12 and issue #15)
Browse files Browse the repository at this point in the history
  • Loading branch information
lmcinnes committed Nov 26, 2017
1 parent ef4b1d7 commit da04e17
Show file tree
Hide file tree
Showing 3 changed files with 441 additions and 265 deletions.
173 changes: 165 additions & 8 deletions umap/sparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,26 +6,27 @@
import numpy as np
import numba

from .umap_ import (tau_rand_int,
norm,
make_heap,
heap_push,
rejection_sample,
build_candidates)
from umap.utils import (tau_rand_int,
tau_rand,
norm,
make_heap,
heap_push,
rejection_sample,
build_candidates)


# Just reproduce a simpler version of numpy unique (not numba supported yet)
@numba.njit()
def arr_unique(arr):
    """Return the sorted distinct values of a 1-d array.

    A reduced stand-in for ``numpy.unique`` that can be compiled by numba.

    Parameters
    ----------
    arr : 1-d array
        Values to de-duplicate; the input itself is not modified
        (``np.sort`` returns a copy).

    Returns
    -------
    1-d array
        The distinct values of ``arr`` in ascending order. An empty input
        yields an empty result.
    """
    aux = np.sort(arr)
    # Keep the first element unconditionally and every later element only
    # when it differs from its predecessor in sorted order. Building the
    # mask in place (rather than concatenating a literal ``True``) keeps the
    # mask the same length as ``aux`` even when ``aux`` is empty.
    flag = np.ones(aux.shape[0], dtype=np.bool_)
    flag[1:] = aux[1:] != aux[:-1]
    return aux[flag]


# Just reproduce a simpler version of numpy union1d (not numba supported yet)
@numba.njit()
def arr_union(ar1, ar2):
    """Return the sorted distinct values found in either of two 1-d arrays.

    A reduced stand-in for ``numpy.union1d`` that can be compiled by numba.

    Parameters
    ----------
    ar1, ar2 : 1-d arrays
        Arrays whose values are merged.

    Returns
    -------
    1-d array
        ``arr_unique`` applied to the concatenation of ``ar1`` and ``ar2``.
    """
    # np.concatenate takes a single tuple of arrays; passing the arrays as
    # two positional arguments would treat ``ar2`` as the ``axis`` argument.
    return arr_unique(np.concatenate((ar1, ar2)))


# Just reproduce a simpler version of numpy intersect1d (not numba supported
Expand Down Expand Up @@ -536,6 +537,115 @@ def sparse_minkowski(ind1, data1, ind2, data2, p):
return result ** (1.0 / p)


@numba.njit()
def sparse_hamming(ind1, data1, ind2, data2, n_features):
    """Hamming distance between two sparse vectors.

    Parameters
    ----------
    ind1, data1 : arrays
        Indices and values of the stored entries of the first vector.
    ind2, data2 : arrays
        Indices and values of the stored entries of the second vector.
    n_features : int
        Total number of coordinates of the (dense) vectors.

    Returns
    -------
    float
        Number of entries in the sparse difference divided by ``n_features``.
    """
    # sparse_diff returns an (indices, values) pair, so the count has to be
    # taken from one of the arrays rather than from the tuple itself.
    # NOTE(review): this assumes sparse_diff omits coordinates whose
    # difference is exactly zero - confirm against its definition.
    diff_inds, diff_data = sparse_diff(ind1, data1, ind2, data2)
    num_not_equal = diff_inds.shape[0]
    return float(num_not_equal) / n_features


@numba.njit()
def sparse_canberra(ind1, data1, ind2, data2):
    """Canberra distance between two sparse vectors.

    Each vector is given as an array of stored indices and the matching array
    of values. The result is the sum of the entries of the sparse product of
    ``|x - y|`` with ``1 / (|x| + |y|)``, both built with the module's sparse
    helpers (``sparse_diff``, ``sparse_sum``, ``sparse_mul``).

    Returns
    -------
    float
        Sum of the per-coordinate ratios.
    """
    # Numerator terms: absolute value of the sparse difference.
    diff_inds, diff_vals = sparse_diff(ind1, data1, ind2, data2)
    abs_diff = np.abs(diff_vals)

    # Denominator terms: sparse sum of the absolute values, inverted so the
    # division can be carried out as a sparse element-wise product.
    total_inds, total_vals = sparse_sum(ind1, np.abs(data1),
                                        ind2, np.abs(data2))
    inv_total = 1.0 / total_vals

    term_inds, term_vals = sparse_mul(diff_inds, abs_diff,
                                      total_inds, inv_total)
    return np.sum(term_vals)

@numba.njit()
def sparse_bray_curtis(ind1, data1, ind2, data2):
    """Bray-Curtis dissimilarity between two sparse vectors.

    Computed as ``sum(|x - y|) / sum(|x| + |y|)`` using the module's sparse
    helpers. When the sparse sum of absolute values has no entries the
    function returns 0.0 instead of dividing.

    Returns
    -------
    float
        The ratio described above, or 0.0 for the empty case.
    """
    total_inds, total_vals = sparse_sum(ind1, np.abs(data1),
                                        ind2, np.abs(data2))

    # Nothing stored in either vector: report zero distance.
    if total_vals.shape[0] == 0:
        return 0.0

    diff_inds, diff_vals = sparse_diff(ind1, data1, ind2, data2)

    abs_diff_total = np.sum(np.abs(diff_vals))
    abs_sum_total = np.sum(total_vals)

    return float(abs_diff_total) / abs_sum_total


@numba.njit()
def sparse_jaccard(ind1, data1, ind2, data2):
    """Jaccard distance between the supports of two sparse vectors.

    Only the index arrays are used; ``data1`` and ``data2`` are accepted so
    the signature matches the other sparse metrics.

    Returns
    -------
    float
        ``(|union| - |intersection|) / |union|`` of the two index sets, or
        0.0 when neither vector has any stored entry.
    """
    # With no stored entries on either side the union is empty and the ratio
    # below would be 0 / 0; two empty vectors are treated as identical.
    if ind1.shape[0] == 0 and ind2.shape[0] == 0:
        return 0.0

    num_non_zero = arr_union(ind1, ind2).shape[0]
    num_equal = arr_intersect(ind1, ind2).shape[0]

    return float(num_non_zero - num_equal) / num_non_zero


@numba.njit()
def sparse_matching(ind1, data1, ind2, data2, n_features):
    """Matching distance between the supports of two sparse vectors.

    Only the index arrays are used. The result is the size of the union of
    the two index arrays minus the size of their intersection, divided by
    ``n_features``.

    Returns
    -------
    float
        Fraction of the ``n_features`` coordinates counted as mismatched.
    """
    in_both = arr_intersect(ind1, ind2).shape[0]
    in_either = arr_union(ind1, ind2).shape[0]

    mismatches = in_either - in_both
    return float(mismatches) / n_features


@numba.njit()
def sparse_dice(ind1, data1, ind2, data2):
    """Dice dissimilarity between the supports of two sparse vectors.

    Only the index arrays are used; ``data1`` and ``data2`` are accepted so
    the signature matches the other sparse metrics.

    Returns
    -------
    float
        ``num_not_equal / (2 * num_true_true + num_not_equal)`` where
        ``num_true_true`` is the size of the index intersection and
        ``num_not_equal`` is the union size minus the intersection size.
        Returns 0.0 when neither vector has any stored entry.
    """
    # Both counts are zero only when neither vector stores anything; the
    # ratio would then be 0 / 0, so two empty vectors are treated as
    # identical.
    if ind1.shape[0] == 0 and ind2.shape[0] == 0:
        return 0.0

    num_true_true = arr_intersect(ind1, ind2).shape[0]
    num_non_zero = arr_union(ind1, ind2).shape[0]
    num_not_equal = num_non_zero - num_true_true

    return num_not_equal / (2.0 * num_true_true + num_not_equal)


@numba.njit()
def sparse_kulsinski(ind1, data1, ind2, data2, n_features):
    """Kulsinski dissimilarity between the supports of two sparse vectors.

    Only the index arrays are used. With ``shared`` the size of the index
    intersection and ``mismatches`` the union size minus ``shared``, the
    result is ``(mismatches - shared + n_features) / (mismatches + n_features)``.

    Returns
    -------
    float
        The ratio described above.
    """
    shared = arr_intersect(ind1, ind2).shape[0]
    mismatches = arr_union(ind1, ind2).shape[0] - shared

    numerator = float(mismatches - shared + n_features)
    denominator = mismatches + n_features
    return numerator / denominator


@numba.njit()
def sparse_rogers_tanimoto(ind1, data1, ind2, data2, n_features):
    """Rogers-Tanimoto dissimilarity between the supports of two sparse vectors.

    Only the index arrays are used. With ``mismatches`` equal to the size of
    the index union minus the size of the index intersection, the result is
    ``2 * mismatches / (n_features + mismatches)``.

    Returns
    -------
    float
        The ratio described above.
    """
    shared = arr_intersect(ind1, ind2).shape[0]
    mismatches = arr_union(ind1, ind2).shape[0] - shared

    doubled = 2.0 * mismatches
    return doubled / (n_features + mismatches)


@numba.njit()
def sparse_russelrao(ind1, data1, ind2, data2, n_features):
    """Russell-Rao dissimilarity between the supports of two sparse vectors.

    Only the index arrays are used. The result is
    ``(n_features - shared) / n_features`` where ``shared`` is the size of
    the intersection of the two index arrays.

    Returns
    -------
    float
        The ratio described above.
    """
    shared = arr_intersect(ind1, ind2).shape[0]
    not_shared = n_features - shared

    return float(not_shared) / n_features


@numba.njit()
def sparse_sokal_michener(ind1, data1, ind2, data2, n_features):
    """Sokal-Michener dissimilarity between the supports of two sparse vectors.

    Only the index arrays are used. The value is
    ``2 * d / (n_features + d)`` where ``d`` is the size of the index union
    minus the size of the index intersection.

    Returns
    -------
    float
        The ratio described above.
    """
    union_size = arr_union(ind1, ind2).shape[0]
    overlap_size = arr_intersect(ind1, ind2).shape[0]
    differing = union_size - overlap_size

    return (2.0 * differing) / (n_features + differing)


@numba.njit()
def sparse_sokal_sneath(ind1, data1, ind2, data2):
    """Sokal-Sneath dissimilarity between the supports of two sparse vectors.

    Only the index arrays are used; ``data1`` and ``data2`` are accepted so
    the signature matches the other sparse metrics.

    Returns
    -------
    float
        ``num_not_equal / (0.5 * num_true_true + num_not_equal)`` where
        ``num_true_true`` is the size of the index intersection and
        ``num_not_equal`` is the union size minus the intersection size.
        Returns 0.0 when neither vector has any stored entry.
    """
    # Both counts are zero only when neither vector stores anything; the
    # ratio would then be 0 / 0, so two empty vectors are treated as
    # identical.
    if ind1.shape[0] == 0 and ind2.shape[0] == 0:
        return 0.0

    num_true_true = arr_intersect(ind1, ind2).shape[0]
    num_non_zero = arr_union(ind1, ind2).shape[0]
    num_not_equal = num_non_zero - num_true_true

    return num_not_equal / (0.5 * num_true_true + num_not_equal)


@numba.njit()
def sparse_cosine(ind1, data1, ind2, data2):
aux_inds, aux_data = sparse_mul(ind1, data1, ind2, data2)
Expand All @@ -548,6 +658,32 @@ def sparse_cosine(ind1, data1, ind2, data2):

return 1.0 - (result / np.sqrt(norm1 * norm2))


@numba.njit()
def sparse_correlation(ind1, data1, ind2, data2, n_features):
    """Correlation distance (one minus Pearson correlation) of two sparse vectors.

    Parameters
    ----------
    ind1, data1 : arrays
        Indices and values of the stored entries of the first vector.
    ind2, data2 : arrays
        Indices and values of the stored entries of the second vector.
    n_features : int
        Total number of coordinates of the (dense) vectors; coordinates that
        are not stored are treated as zero.

    Returns
    -------
    float
        ``1 - <x - mean(x), y - mean(y)> / (||x - mean(x)|| * ||y - mean(y)||)``
        taken over all ``n_features`` coordinates. If either vector has zero
        variance (including a vector with no stored entries) the correlation
        is undefined and 1.0 is returned.
    """
    mu1 = float(np.sum(data1)) / n_features
    mu2 = float(np.sum(data2)) / n_features

    # Centring turns every implicit zero into -mu, so the centred vectors are
    # dense. Rather than materialise them, use the identities
    #     sum((x - mu_x) ** 2)          = sum(x ** 2) - n * mu_x ** 2
    #     sum((x - mu_x) * (y - mu_y))  = sum(x * y)  - n * mu_x * mu_y
    # whose raw sums only involve the stored entries.
    sq_norm1 = float(np.sum(data1 * data1)) - n_features * mu1 * mu1
    sq_norm2 = float(np.sum(data2 * data2)) - n_features * mu2 * mu2

    # "<=" also absorbs tiny negative values caused by rounding.
    if sq_norm1 <= 0.0 or sq_norm2 <= 0.0:
        return 1.0

    prod_inds, prod_data = sparse_mul(ind1, data1, ind2, data2)
    dot_product = float(np.sum(prod_data)) - n_features * mu1 * mu2

    # The subtraction applies to the normalised dot product, not to the
    # numerator alone.
    return 1.0 - dot_product / np.sqrt(sq_norm1 * sq_norm2)

sparse_named_distances = {
'euclidean' : sparse_euclidean,
'manhattan' : sparse_manhattan,
Expand All @@ -558,5 +694,26 @@ def sparse_cosine(ind1, data1, ind2, data2):
'linfty' : sparse_chebyshev,
'linfinity' : sparse_chebyshev,
'minkowski' : sparse_minkowski,
'hamming' : sparse_hamming,
'canberra' : sparse_canberra,
'bray_curtis' : sparse_bray_curtis,
'jaccard' : sparse_jaccard,
'matching' : sparse_matching,
'kulsinski' : sparse_kulsinski,
'rogers_tanimoto' : sparse_rogers_tanimoto,
'russelrao' : sparse_russelrao,
'sokal_michener' : sparse_sokal_michener,
'sokal_sneath' : sparse_sokal_sneath,
'cosine' : sparse_cosine,
'correlation' : sparse_correlation,
}

# Keys of the sparse metrics whose implementations take a trailing
# ``n_features`` argument in addition to the two (indices, data) pairs.
sparse_need_n_features = (
    'hamming', 'matching', 'kulsinski',
    'rogers_tanimoto', 'russelrao', 'sokal_michener',
    'correlation',
)
Loading

0 comments on commit da04e17

Please sign in to comment.