Merge pull request #9 from mbruhns/mps
Integrating MPS backend
yihming authored Jan 7, 2024
2 parents 436e53c + cb3065b commit 769a6a0
Showing 7 changed files with 448 additions and 178 deletions.
154 changes: 148 additions & 6 deletions .gitignore
@@ -1,6 +1,148 @@
__pycache__
*.pdf
build
dist
*.egg-info
.eggs
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
.python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# PEP 582; used by e.g. github.com/David-OConnor/pyflow
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# Sublime workspace
*.sublime-workspace
.DS_Store

#Custom folders
results/
figures/

*.sublime-workspace
*.sublime-project

# Jupyter notebooks
*.ipynb

.idea/

*.h5ad

2 changes: 1 addition & 1 deletion harmony/__init__.py
@@ -6,7 +6,7 @@
from importlib_metadata import version, PackageNotFoundError

try:
__version__ = version('harmony-pytorch')
__version__ = version("harmony-pytorch")
del version
except PackageNotFoundError:
pass
71 changes: 35 additions & 36 deletions harmony/harmony.py
@@ -10,7 +10,6 @@
from .utils import one_hot_tensor, get_batch_codes



def harmonize(
X: np.array,
batch_mat: pd.DataFrame,
@@ -105,13 +104,16 @@ def harmonize(
>>> X_harmony = harmonize(adata.obsm['X_pca'], adata.obs, ['Channel', 'Lab'])
"""

assert(isinstance(X, np.ndarray))
assert isinstance(X, np.ndarray)

if n_jobs < 0:
import psutil
n_jobs = psutil.cpu_count(logical=False) # get physical cores

n_jobs = psutil.cpu_count(logical=False) # get physical cores
if n_jobs is None:
n_jobs = psutil.cpu_count(logical=True) # if undetermined, use logical cores instead
n_jobs = psutil.cpu_count(
logical=True
) # if undetermined, use logical cores instead
torch.set_num_threads(n_jobs)

device_type = "cpu"
@@ -120,9 +122,14 @@ def harmonize(
device_type = "cuda"
if verbose:
print("Use GPU mode.")
else:
elif torch.backends.mps.is_available():
device_type = "mps"
if verbose:
print("CUDA is not available on your machine. Use CPU mode instead.")
print("Use Metal (MPS) mode.")
elif verbose:
print(
"Neither CUDA nor MPS is available on your machine. Use CPU mode instead."
)

(stride_0, stride_1) = X.strides
if stride_0 < 0 or stride_1 < 0:
@@ -156,7 +163,7 @@
theta = theta.view(1, -1)

assert block_proportion > 0 and block_proportion <= 1
assert correction_method in ["fast", "original"]
assert correction_method in {"fast", "original"}

np.random.seed(random_state)

@@ -206,13 +213,10 @@

if is_convergent_harmony(objectives_harmony, tol=tol_harmony):
if verbose:
print("Reach convergence after {} iteration(s).".format(i + 1))
print(f"Reach convergence after {i + 1} iteration(s).")
break

if device_type == "cpu":
return Z_hat.numpy()
else:
return Z_hat.cpu().numpy()
return Z_hat.numpy() if device_type == "cpu" else Z_hat.cpu().numpy()


def initialize_centroids(
@@ -229,17 +233,19 @@ def initialize_centroids(
):
n_cells = Z_norm.shape[0]

kmeans_params = {'n_clusters': n_clusters,
'init': "k-means++",
'n_init': n_init,
'random_state': random_state,
'max_iter': 25,
}
kmeans_params = {
"n_clusters": n_clusters,
"init": "k-means++",
"n_init": n_init,
"random_state": random_state,
"max_iter": 25,
}

kmeans = KMeans(**kmeans_params)

from threadpoolctl import threadpool_limits
with threadpool_limits(limits = n_jobs):

with threadpool_limits(limits=n_jobs):
if device_type == "cpu":
kmeans.fit(Z_norm)
else:
@@ -249,9 +255,7 @@
Y_norm = normalize(Y, p=2, dim=1)

# Initialize R
R = torch.exp(
-2 / sigma * (1 - torch.matmul(Z_norm, Y_norm.t()))
)
R = torch.exp(-2 / sigma * (1 - torch.matmul(Z_norm, Y_norm.t())))
R = normalize(R, p=1, dim=1)

E = torch.matmul(Pr_b, torch.sum(R, dim=0, keepdim=True))
@@ -282,12 +286,11 @@ def clustering(
device_type,
n_init=10,
):

n_cells = Z_norm.shape[0]

objectives_clustering = []

for i in range(max_iter):
for _ in range(max_iter):
# Compute Cluster Centroids
Y = torch.matmul(R.t(), Z_norm)
Y_norm = normalize(Y, p=2, dim=1)
@@ -298,12 +301,8 @@
pos = 0
while pos < len(idx_list):
idx_in = idx_list[pos : (pos + block_size)]
R_in = R[
idx_in,
]
Phi_in = Phi[
idx_in,
]
R_in = R[idx_in,]
Phi_in = Phi[idx_in,]

# Compute O and E on left out data.
O -= torch.matmul(Phi_in.t(), R_in)
@@ -347,14 +346,12 @@ def correction_original(X, R, Phi, ridge_lambda, device_type):
Phi_1 = torch.cat((torch.ones(n_cells, 1, device=device_type), Phi), dim=1)

Z = X.clone()
id_mat = torch.eye(n_batches + 1, n_batches + 1, device = device_type)
id_mat = torch.eye(n_batches + 1, n_batches + 1, device=device_type)
id_mat[0, 0] = 0
Lambda = ridge_lambda * id_mat
for k in range(n_clusters):
Phi_t_diag_R = Phi_1.t() * R[:, k].view(1, -1)
inv_mat = torch.inverse(
torch.matmul(Phi_t_diag_R, Phi_1) + Lambda
)
inv_mat = torch.inverse(torch.matmul(Phi_t_diag_R, Phi_1) + Lambda)
W = torch.matmul(inv_mat, torch.matmul(Phi_t_diag_R, X))
W[0, :] = 0
Z -= torch.matmul(Phi_t_diag_R.t(), W)
@@ -375,7 +372,7 @@ def correction_fast(X, R, Phi, O, ridge_lambda, device_type):
N_k = torch.sum(O_k)

factor = 1 / (O_k + ridge_lambda)
c = N_k + torch.sum(-factor * O_k ** 2)
c = N_k + torch.sum(-factor * O_k**2)
c_inv = 1 / c

P[0, 1:] = -factor * O_k
@@ -401,7 +398,9 @@
Y_norm, Z_norm, R, theta, sigma, O, E, objective_arr, device_type
):
kmeans_error = torch.sum(R * 2 * (1 - torch.matmul(Z_norm, Y_norm.t())))
entropy_term = sigma * torch.sum(-torch.distributions.Categorical(probs=R).entropy())
entropy_term = sigma * torch.sum(
-torch.distributions.Categorical(probs=R).entropy()
)
diversity_penalty = sigma * torch.sum(
torch.matmul(theta, O * torch.log(torch.div(O + 1, E + 1)))
)
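The substantive change in harmonize() is the device-selection block: CUDA is tried first, then Apple's Metal (MPS) backend, and only then does it fall back to CPU. Below is a minimal standalone sketch of that priority order for anyone who wants to test it outside the package; the helper name and flag are illustrative only, not part of the harmony-pytorch API, and torch.backends.mps requires torch >= 1.12 (hence the pin bump in setup.py).

import torch

def pick_device(use_gpu: bool = True, verbose: bool = False) -> torch.device:
    # Same priority order as the patched harmonize(): CUDA, then MPS, then CPU.
    device_type = "cpu"
    if use_gpu and torch.cuda.is_available():
        device_type = "cuda"
    elif use_gpu and torch.backends.mps.is_available():
        device_type = "mps"
    if verbose:
        print(f"Using {device_type} mode.")
    return torch.device(device_type)

# Move data to whichever backend is present on this machine.
Z = torch.randn(1_000, 50).to(pick_device(verbose=True))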
18 changes: 11 additions & 7 deletions harmony/utils.py
@@ -2,24 +2,28 @@


def get_batch_codes(batch_mat, batch_key):
if type(batch_key) is str or len(batch_key) == 1:
if not type(batch_key) is str:
batch_key = batch_key[0]
if type(batch_key) is str:
batch_vec = batch_mat[batch_key]

elif len(batch_key) == 1:
batch_key = batch_key[0]

batch_vec = batch_mat[batch_key]

else:
df = batch_mat[batch_key].astype('str')
batch_vec = df.apply(lambda row: ','.join(row), axis = 1)
df = batch_mat[batch_key].astype("str")
batch_vec = df.apply(lambda row: ",".join(row), axis=1)

return batch_vec.astype("category")


def one_hot_tensor(X, device_type):
ids = torch.as_tensor(X.cat.codes.values.copy(), dtype = torch.long, device = device_type).view(-1, 1)
ids = torch.as_tensor(
X.cat.codes.values.copy(), dtype=torch.long, device=device_type
).view(-1, 1)
n_row = X.size
n_col = X.cat.categories.size
Phi = torch.zeros(n_row, n_col, dtype=torch.float, device = device_type)
Phi = torch.zeros(n_row, n_col, dtype=torch.float, device=device_type)
Phi.scatter_(dim=1, index=ids, value=1.0)

return Phi
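The utils.py edits above are formatting-only, but for orientation, here is a small usage sketch of the two helpers as they behave after this commit. They are internal to the package, so importing them directly as shown is purely illustrative and the import path may change.

import pandas as pd
from harmony.utils import get_batch_codes, one_hot_tensor

# Two batch keys are collapsed into one categorical code per cell,
# then expanded into a one-hot design matrix on the requested device.
batch_mat = pd.DataFrame({
    "Channel": ["c1", "c1", "c2", "c2"],
    "Lab": ["a", "b", "a", "b"],
})
codes = get_batch_codes(batch_mat, ["Channel", "Lab"])  # e.g. "c1,a", "c1,b", ...
Phi = one_hot_tensor(codes, device_type="cpu")          # 4 x 4 float tensor
print(Phi)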
5 changes: 2 additions & 3 deletions setup.py
@@ -7,13 +7,12 @@
long_description = f.read()

requires = [
"torch",
"torch>=1.12",
"numpy",
"pandas",
"psutil",
"threadpoolctl",
"scikit-learn>=0.23",
"importlib_metadata>=0.7; python_version < '3.8'",
"scikit-learn>=0.23"
]

setup(
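setup.py now pins torch >= 1.12, the first release that ships the MPS backend, and drops the importlib_metadata backport from the install requirements. Before relying on MPS, the installed build can be checked with the snippet below (a minimal sketch, assuming torch >= 1.12 is installed).

import torch

# is_built(): the installed wheel was compiled with MPS support.
# is_available(): this machine (Apple Silicon + a recent macOS) can use it now.
print("MPS built:     ", torch.backends.mps.is_built())
print("MPS available: ", torch.backends.mps.is_available())
print("CUDA available:", torch.cuda.is_available())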