Skip to content

Commit

Permalink
Add support for Keras >=2.2.0
Browse files Browse the repository at this point in the history
  • Loading branch information
cangermueller committed Dec 3, 2018
1 parent b50cc27 commit 8490679
Show file tree
Hide file tree
Showing 8 changed files with 42 additions and 25 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
__pycache__/*
.cache/*
.ipynb_checkpoints
.pytest_cache

# Project files
.ropeproject
Expand Down
5 changes: 5 additions & 0 deletions README.rst
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ Table of contents
News
====

* **181201**: DeepCpG 1.0.7 released!
* **180224**: DeepCpG 1.0.6 released!
* **171112**: Keras 2 is now the main Keras version (release 1.0.5).
* **170412**: New `notebook <./examples/notebooks/stats/index.ipynb>`_ on predicting inter-cell statistics!
Expand Down Expand Up @@ -213,6 +214,10 @@ Content
Changelog
=========

1.0.7
-----
* Add support for Keras >=2.2.0.

1.0.6
-----
* Add support for Keras 2.1.4 and Tensorflow 1.5.0
Expand Down
2 changes: 1 addition & 1 deletion deepcpg/__init__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.0.6'
__version__ = '1.0.7'
13 changes: 12 additions & 1 deletion deepcpg/models/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from os import path as pt

import keras
from keras import backend as K
from keras import models as km
from keras import layers as kl
Expand Down Expand Up @@ -410,6 +411,16 @@ def copy_weights(src_model, dst_model, must_exist=True):
return copied


def is_input_layer(layer):
"""Test if `layer` is an input layer."""
return isinstance(layer, keras.engine.input_layer.InputLayer)


def is_output_layer(layer, model):
"""Test if `layer` is an output layer."""
return layer.name in model.output_names


class Model(object):
"""Abstract model call.
Expand Down Expand Up @@ -445,7 +456,7 @@ def _build(self, input, output):
model = km.Model(input, output, name=self.name)
if self.scope:
for layer in model.layers:
if layer not in model.input_layers:
if not is_input_layer(layer):
layer.name = '%s/%s' % (self.scope, layer.name)
return model

Expand Down
4 changes: 2 additions & 2 deletions docs/source/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,9 @@
# built documents.
#
# The short X.Y version.
version = '1.0.6'
version = '1.0.7'
# The full version, including alpha/beta/rc tags.
release = '1.0.6'
release = '1.0.7'

# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
Expand Down
6 changes: 3 additions & 3 deletions scripts/dcpg_snp.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,9 +144,9 @@ def main(self, name, opts):

# Get DNA layer.
dna_layer = None
for i, name in enumerate(model.input_names):
if name == 'dna':
dna_layer = model.input_layers[i]
for layer in model.layers:
if layer.name == 'dna':
dna_layer = layer
break
if not dna_layer:
raise ValueError('The provided model is not a DNA model!')
Expand Down
34 changes: 17 additions & 17 deletions scripts/dcpg_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,7 @@
from deepcpg import data as dat
from deepcpg import metrics as met
from deepcpg import models as mod
from deepcpg.models.utils import is_input_layer, is_output_layer
from deepcpg.data import hdf, OUTPUT_SEP
from deepcpg.utils import format_table, make_dir, EPS

Expand All @@ -86,7 +87,7 @@


def remove_outputs(model):
while model.layers[-1] in model.output_layers:
while is_output_layer(model.layers[-1], model):
model.layers.pop()
model.outputs = [model.layers[-1].output]
model.layers[-1].outbound_nodes = []
Expand All @@ -97,7 +98,7 @@ def rename_layers(model, scope=None):
if not scope:
scope = model.scope
for layer in model.layers:
if layer in model.input_layers or layer.name.startswith(scope):
if is_input_layer(layer) or layer.name.startswith(scope):
continue
layer.name = '%s/%s' % (scope, layer.name)

Expand Down Expand Up @@ -595,7 +596,7 @@ def build_model(self):
remove_outputs(stem)

outputs = mod.add_output_layers(stem.outputs[0], output_names)
model = Model(stem.inputs, outputs, stem.name)
model = Model(inputs=stem.inputs, outputs=outputs, name=stem.name)
return model

def set_trainability(self, model):
Expand All @@ -622,17 +623,18 @@ def set_trainability(self, model):
table['layer'] = []
table['trainable'] = []
for layer in model.layers:
if layer not in model.input_layers + model.output_layers:
if not hasattr(layer, 'trainable'):
continue
for regex in not_trainable:
if re.match(regex, layer.name):
layer.trainable = False
for regex in trainable:
if re.match(regex, layer.name):
layer.trainable = True
table['layer'].append(layer.name)
table['trainable'].append(layer.trainable)
if is_input_layer(layer) or is_output_layer(layer, model):
continue
if not hasattr(layer, 'trainable'):
continue
for regex in not_trainable:
if re.match(regex, layer.name):
layer.trainable = False
for regex in trainable:
if re.match(regex, layer.name):
layer.trainable = True
table['layer'].append(layer.name)
table['trainable'].append(layer.trainable)
print('Layer trainability:')
print(format_table(table))
print()
Expand Down Expand Up @@ -713,9 +715,7 @@ def main(self, name, opts):
mod.save_model(model, os.path.join(opts.out_dir, 'model.json'))

log.info('Computing output statistics ...')
output_names = []
for output_layer in model.output_layers:
output_names.append(output_layer.name)
output_names = model.output_names

output_stats = OrderedDict()

Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ def read(fname):


setup(name='deepcpg',
version='1.0.6',
version='1.0.7',
description='Deep learning for predicting CpG methylation',
long_description=read('README.rst'),
author='Christof Angermueller',
Expand Down

0 comments on commit 8490679

Please sign in to comment.