Skip to content
This repository has been archived by the owner on Jan 3, 2024. It is now read-only.

Commit

Permalink
Misc classification typing (#106)
Browse files Browse the repository at this point in the history
* Misc classification typing

* Use numpy/dask array type

* Use types.array in cube_generator
  • Loading branch information
dstansby authored Mar 30, 2023
1 parent f1188df commit 721ccc7
Show file tree
Hide file tree
Showing 5 changed files with 62 additions and 58 deletions.
12 changes: 6 additions & 6 deletions src/cellfinder_core/classify/classify.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
from typing import Callable, Optional
from typing import Callable, List, Optional

import numpy as np
from imlib.general.system import get_num_processes
from tensorflow import keras

from cellfinder_core import logger
from cellfinder_core import logger, types
from cellfinder_core.classify.cube_generator import CubeGeneratorFromFile
from cellfinder_core.classify.tools import get_model
from cellfinder_core.train.train_yml import models


def main(
points,
signal_array,
background_array,
n_free_cpus,
signal_array: types.array,
background_array: types.array,
n_free_cpus: int,
voxel_sizes,
network_voxel_sizes,
batch_size,
Expand All @@ -27,7 +27,7 @@ def main(
max_workers=3,
*,
callback: Optional[Callable[[int], None]] = None,
):
) -> List:
"""
Parameters
----------
Expand Down
36 changes: 19 additions & 17 deletions src/cellfinder_core/classify/cube_generator.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from random import shuffle
from typing import Optional, Tuple

import numpy as np
import tensorflow as tf
Expand All @@ -8,6 +9,7 @@
from skimage.io import imread
from tensorflow.keras.utils import Sequence

from cellfinder_core import types
from cellfinder_core.classify.augment import AugmentationParameters, augment

# TODO: rename, as now using dask arrays -
Expand All @@ -33,26 +35,26 @@ class CubeGeneratorFromFile(Sequence):
def __init__(
self,
points,
signal_array,
background_array,
signal_array: types.array,
background_array: types.array,
voxel_sizes,
network_voxel_sizes,
batch_size=16,
cube_width=50,
cube_height=50,
cube_depth=20,
channels=2, # No other option currently
classes=2,
extract=False,
train=False,
augment=False,
augment_likelihood=0.1,
flip_axis=[0, 1, 2],
rotate_max_axes=[1, 1, 1], # degrees
batch_size: Optional[int] = 16,
cube_width: Optional[int] = 50,
cube_height: Optional[int] = 50,
cube_depth: Optional[int] = 20,
channels: Optional[int] = 2, # No other option currently
classes: Optional[int] = 2,
extract: Optional[bool] = False,
train: Optional[bool] = False,
augment: Optional[bool] = False,
augment_likelihood: Optional[float] = 0.1,
flip_axis: Tuple[int, int, int] = (0, 1, 2),
rotate_max_axes: Tuple[float, float, float] = (1, 1, 1), # degrees
# scale=[0.5, 2], # min, max
translate=[0.05, 0.05, 0.05],
shuffle=False,
interpolation_order=2,
translate: Tuple[float, float, float] = (0.05, 0.05, 0.05),
shuffle: Optional[bool] = False,
interpolation_order: Optional[int] = 2,
):
self.points = points
self.signal_array = signal_array
Expand Down
59 changes: 26 additions & 33 deletions src/cellfinder_core/classify/resnet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
from typing import Dict, List, Literal, Tuple, Union

from tensorflow.keras import Model
from tensorflow.keras.initializers import Initializer
from tensorflow.keras.layers import (
Activation,
Add,
Expand All @@ -15,15 +18,19 @@
#####################################################################
# Define the types of ResNet

resnet_unit_blocks = {
layer_type = Literal[
"18-layer", "34-layer", "50-layer", "101-layer", "152-layer"
]

resnet_unit_blocks: Dict[layer_type, List[int]] = {
"18-layer": [2, 2, 2, 2],
"34-layer": [3, 4, 6, 3],
"50-layer": [3, 4, 6, 3],
"101-layer": [3, 4, 23, 3],
"152-layer": [3, 6, 36, 3],
}

network_residual_bottleneck = {
network_residual_bottleneck: Dict[layer_type, bool] = {
"18-layer": False,
"34-layer": False,
"50-layer": True,
Expand All @@ -34,31 +41,17 @@


def build_model(
shape=(50, 50, 20, 2),
network_depth="18-layer",
shape: Tuple[int, int, int, int] = (50, 50, 20, 2),
network_depth: layer_type = "18-layer",
optimizer=None,
learning_rate=0.0005, # higher rates don't always converge
learning_rate: float = 0.0005, # higher rates don't always converge
loss="categorical_crossentropy",
metrics=["accuracy"],
number_classes=2,
axis=3,
starting_features=64,
number_classes: int = 2,
axis: int = 3,
starting_features: int = 64,
classification_activation="softmax",
):
"""
:param shape:
:param network_depth:
:param optimizer:
:param learning_rate:
:param loss:
:param metrics:
:param number_classes:
:param int axis: Default: 3. Assumed channels are last
:param starting_features: # increases in each set of residual units
:param classification_activation:
:return:
"""
) -> Model:
blocks, bottleneck = get_resnet_blocks_and_bottleneck(network_depth)

inputs = Input(shape)
Expand Down Expand Up @@ -94,7 +87,7 @@ def build_model(
return model


def get_resnet_blocks_and_bottleneck(network_depth):
def get_resnet_blocks_and_bottleneck(network_depth: layer_type):
"""
Parses dicts, and returns how many resnet blocks are in each unit, along
with whether they are bottlneck blocks or not
Expand Down Expand Up @@ -304,14 +297,14 @@ def f(x):

def get_shortcut(
inputs,
resnet_unit_label,
block_id,
features,
stride,
use_bias=False,
kernel_initializer="he_normal",
bn_epsilon=1e-5,
axis=3,
resnet_unit_label: int,
block_id: int,
features: int,
stride: int,
use_bias: bool = False,
kernel_initializer: Union[str, Initializer] = "he_normal",
bn_epsilon: float = 1e-5,
axis: int = 3,
):
"""
Create shortcut. For none-bottleneck residual units, this is just the
Expand Down Expand Up @@ -351,7 +344,7 @@ def get_shortcut(
return inputs


def get_stride(resnet_unit_id, block_id):
def get_stride(resnet_unit_id: int, block_id: int) -> int:
"""
Determines the convolution stride.
Expand Down
7 changes: 5 additions & 2 deletions src/cellfinder_core/classify/tools.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import os
from typing import Optional

import numpy as np
import tensorflow as tf

Expand All @@ -6,8 +9,8 @@


def get_model(
existing_model=None,
model_weights=None,
existing_model: Optional[os.PathLike] = None,
model_weights: Optional[os.PathLike] = None,
network_depth=None,
learning_rate=0.0001,
inference=False,
Expand Down
6 changes: 6 additions & 0 deletions src/cellfinder_core/types.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
from typing import Union

import dask.array as da
import numpy as np

array = Union[da.Array, np.ndarray]

0 comments on commit 721ccc7

Please sign in to comment.