Commit

add working control method
KaiWaldrant committed Jul 10, 2024
1 parent 862d48d commit 2c81f68
Showing 2 changed files with 16 additions and 23 deletions.
@@ -8,28 +8,21 @@ __merge__: ../../api/comp_control_method.yaml

  # A unique identifier for your component (required).
  # Can contain only lowercase letters or underscores.
- name: logistic_regression
+ name: true_labels

  # Metadata for your component
  info:
  # A relatively short label, used when rendering visualisations (required)
- label: Logistic Regression
+ label: True Labels
  # A one sentence summary of how this method works (required). Used when
  # rendering summary tables.
- summary: "Logistic Regression with 100-dimensional PCA coordinates estimates parameters for multivariate classification by minimizing cross entropy loss over cell type classes."
+ summary: "a positive control, solution labels are copied 1 to 1 to the predicted data."
  # A multi-line description of how this component works (required). Used
  # when rendering reference documentation.
  description: |
- Logistic Regression estimates parameters of a logistic function for
- multivariate classification tasks. Here, we use 100-dimensional whitened PCA
- coordinates as independent variables, and the model minimises the cross
- entropy loss over all cell type classes.
- reference: "hosmer2013applied"
- repository_url: https://github.com/scikit-learn/scikit-learn
- documentation_url: "https://scikit-learn.org/stable/modules/generated/sklearn.linear_model.LogisticRegression.html"
- # Which normalization method this component prefers to use (required).
- preferred_normalization: log_cp10k
+ A positive control, where the solution labels are copied 1 to 1 to the predicted data.
+ # Which normalisation method this component prefers to use (required).
+ preferred_normalization: counts

  # Component-specific parameters (optional)
  # arguments:
@@ -53,14 +46,14 @@ engines:
  image: ghcr.io/openproblems-bio/base_images/python:1.1.0
  # Add custom dependencies here (optional). For more information, see
  # https://viash.io/reference/config/engines/docker/#setup .
- setup:
- - type: python
- packages: scikit-learn
+ # setup:
+ # - type: python
+ # packages: scib==1.1.5

  runners:
  # This platform allows running the component natively
  - type: executable
  # Allows turning the component into a Nextflow module / pipeline.
  - type: nextflow
  directives:
- label: [midtime,midmem,lowcpu]
+ label: [midtime,lowmem,lowcpu]
@@ -1,43 +1,43 @@
  import anndata as ad
- import sklearn.linear_model

  ## VIASH START
  # Note: this section is auto-generated by viash at runtime. To edit it, make changes
  # in config.vsh.yaml and then run `viash config inject config.vsh.yaml`.
  par = {
  'input_train': 'resources_test/task_template/pancreas/train.h5ad',
  'input_test': 'resources_test/task_template/pancreas/test.h5ad',
  'input_solution': 'resources_test/task_template/pancreas/solution.h5ad',
  'output': 'output.h5ad'
  }
  meta = {
- 'name': 'logistic_regression'
+ 'name': 'true_labels'
  }
  ## VIASH END

  print('Reading input files', flush=True)
  input_train = ad.read_h5ad(par['input_train'])
  input_test = ad.read_h5ad(par['input_test'])
  input_solution = ad.read_h5ad(par['input_solution'])

  print('Preprocess data', flush=True)
  # ... preprocessing ...

  print('Train model', flush=True)
  # ... train model ...
- classifier = sklearn.linear_model.LogisticRegression()
- classifier.fit(input_train.obsm["X_pca"], input_train.obs["label"].astype(str))

  print('Generate predictions', flush=True)
  # ... generate predictions ...
- obs = classifier.predict(input_test.obsm["X_pca"])
+ obs_label_pred = input_solution.obs["label"]

  print("Write output AnnData to file", flush=True)
  output = ad.AnnData(
  uns={
  'dataset_id': input_train.uns['dataset_id'],
  'normalization_id': input_train.uns['normalization_id'],
  'method_id': meta['name']
  },
  obs={
- 'label_pred': obs
+ 'label_pred': obs_label_pred
  }
  )
  output.write_h5ad(par['output'], compression='gzip')
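
Note on the control method: the change above replaces a trained classifier with a positive control that copies the solution labels into the predictions. The following is a minimal sketch, not code from this commit, of why such a control should score perfectly under a simple accuracy metric. It uses plain Python lists rather than AnnData objects; the function names `true_labels_control` and `accuracy` and the example labels are assumptions made for illustration only.

```python
# Illustrative sketch only: these names are hypothetical and not part of the commit.

def true_labels_control(solution_labels):
    """Positive control: return an exact copy of the ground-truth labels."""
    return list(solution_labels)


def accuracy(predicted, truth):
    """Fraction of positions where the predicted label equals the true label."""
    if len(predicted) != len(truth):
        raise ValueError("predicted and truth must have the same length")
    if not truth:
        raise ValueError("cannot score an empty label list")
    return sum(p == t for p, t in zip(predicted, truth)) / len(truth)


solution = ["alpha", "beta", "delta", "beta"]  # made-up example labels
predicted = true_labels_control(solution)
score = accuracy(predicted, solution)
print(score)  # 1.0
```

A control like this gives an upper reference point: if a metric reports less than a perfect score for copied labels, the metric or the data alignment is probably at fault rather than the method.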
