Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP Add custom KL loss layer HLS implementation #606

Merged
merged 20 commits into from
Feb 10, 2023

Conversation

katyagovorkova
Copy link
Contributor

Adds an implementation of the KL loss layer used for CMS Anomaly detection at L1.
Adds an example of usage of KL layer is in contrib/kl_layer.py and the HLS part is in hls4ml/templates/vivado/nnet_utils/nnet_distance.h.
The original implementation of the KL layer is available on the AE_L1_paper branch, this PR updates the implementation for the new layer API.

Type of change

  • Documentation update
  • New feature (non-breaking change which adds functionality)

Tests

The test creates a dummy Keras model which includes the KL loss layer, converts it to an hls4ml model
and synthesises it.

Test Configuration:
To run the test do python contrib/kl_layer.py

Checklist

  • I have read the guidelines for contributing.
  • I have commented my code, particularly in hard-to-understand areas.
  • I have made corresponding changes to the documentation.
  • My changes generate no new warnings.
  • I have added tests that prove my fix is effective or that my feature works.

@jmitrevs jmitrevs requested a review from gabhijith August 4, 2022 22:36
@jmitrevs jmitrevs added the please test Trigger testing by creating local PR branch label Jan 18, 2023
@jmitrevs
Copy link
Contributor

jmitrevs commented Jan 18, 2023

I think maybe all of this should go into contrib, since it's not used directly but using the extensions API. Potentially make a kl_loss directory in there with these two files + a readme to explain how to use it. I think that would be useful.

@jmitrevs readme updated!
remove trailing whitespace
@jmitrevs
Copy link
Contributor

When I run python kl_layer.py, I get:

Interpreting Model
Traceback (most recent call last):
  File "/Users/jmitrevs/work/hls4ml/contrib/kl_layer/kl_layer.py", line 189, in <module>
    test_extensions(test_root_path)
  File "/Users/jmitrevs/work/hls4ml/contrib/kl_layer/kl_layer.py", line 168, in test_extensions
    hmodel = hls4ml.converters.convert_from_keras_model(
  File "/Users/jmitrevs/work/hls4ml/hls4ml/converters/__init__.py", line 241, in convert_from_keras_model
    return keras_to_hls(config)
  File "/Users/jmitrevs/work/hls4ml/hls4ml/converters/keras_to_hls.py", line 384, in keras_to_hls
    layer_list, input_layers, output_layers = parse_keras_model(model_arch, reader)
  File "/Users/jmitrevs/work/hls4ml/hls4ml/converters/keras_to_hls.py", line 298, in parse_keras_model
    raise Exception('ERROR: Unsupported layer type: {}'.format(keras_layer['class_name']))
Exception: ERROR: Unsupported layer type: TFOpLambda

Does it succeed for you?

@vloncar
Copy link
Contributor

vloncar commented Jan 23, 2023

@jmitrevs You're probably using TF 2.8 or newer where the information about the custom layer is not embedded in the model when saving it to disk, but rather its computation graph is embedded. So when loading the model back you get these lambda ops. You can try saving and then loading the model back and printing its summary() to see if this is the cause. I found no solution for this and was hoping TF reverts to the old functionality (because this one is of dubious utility) in newer release but I didn't check if the latest TF did so.

@jmitrevs
Copy link
Contributor

The model that the test seems to parse is:

 Layer (type)                   Output Shape         Param #     Connected to                     
==================================================================================================
 input_1 (InputLayer)           [(None, 19, 3, 1)]   0           []                               
                                                                                                  
 dense_1 (Dense)                (None, 19, 3, 10)    20          ['input_1[0][0]']                
                                                                                                  
 dense (Dense)                  (None, 19, 3, 10)    20          ['input_1[0][0]']                
                                                                                                  
 tf.__operators__.add (TFOpLamb  (None, 19, 3, 10)   0           ['dense_1[0][0]']                
 da)                                                                                              
                                                                                                  
 tf.math.square (TFOpLambda)    (None, 19, 3, 10)    0           ['dense[0][0]']                  
                                                                                                  
 tf.math.subtract (TFOpLambda)  (None, 19, 3, 10)    0           ['tf.__operators__.add[0][0]',   
                                                                  'tf.math.square[0][0]']         
                                                                                                  
 tf.math.exp (TFOpLambda)       (None, 19, 3, 10)    0           ['dense_1[0][0]']                
                                                                                                  
 tf.math.subtract_1 (TFOpLambda  (None, 19, 3, 10)   0           ['tf.math.subtract[0][0]',       
 )                                                                'tf.math.exp[0][0]']            
                                                                                                  
 tf.math.reduce_mean (TFOpLambd  (None, 19, 3, 1)    0           ['tf.math.subtract_1[0][0]']     
 a)                                                                                               
                                                                                                  
 tf.math.multiply (TFOpLambda)  (None, 19, 3, 1)     0           ['tf.math.reduce_mean[0][0]']    
                                                                                                  
==================================================================================================

@jmitrevs
Copy link
Contributor

So I think it is as @vloncar says. Let's see what options we have to proceed.

@vloncar
Copy link
Contributor

vloncar commented Jan 23, 2023

Maybe add get_config() to the Keras implementation and try like that? Also try decorating the function with tf.keras.utils.register_keras_serializable

@jmitrevs
Copy link
Contributor

Is there anything we can get from the test_extensions.py, which still works?

@jmitrevs
Copy link
Contributor

jmitrevs commented Jan 23, 2023

My quick attempts with get_config and tf.keras.utils.register_keras_serializable didn't really do anything, though don't really know what I am doing, so I could have easily done it wrong.

@vloncar
Copy link
Contributor

vloncar commented Jan 23, 2023

Upon digging a bit more, turns out this implementation is problematic, not the TF version itself. Apparently we shouldn't use tensorflow.python.keras.* anymore and go with tensorflow.keras.*. I don't even remember why we used the former. With that change the issue is gone. But now the issue is how to import the required base class _Merge which has changed locations sometime between 2.7 and the current version (too lazy to investigate exactly when). So I propose we do something like:

try:
    from keras.layers.merge import _Merge as Merge
except Exception:
    from keras.layers.merging.base_merge import _Merge

contrib/kl_layer/README.md Outdated Show resolved Hide resolved
contrib/kl_layer/kl_layer.py Outdated Show resolved Hide resolved
contrib/kl_layer/kl_layer.py Outdated Show resolved Hide resolved
contrib/kl_layer/kl_layer.py Outdated Show resolved Hide resolved

// Internal info
static const unsigned table_size = 1024;
static constexpr float exp_range = 1024;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Does this really need to be float? It would likely have bad QoR if it is not a power-of-two integer.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what should it be instead of float?

@katyagovorkova katyagovorkova changed the title Add custom KL loss layer HLS implementation WIP Add custom KL loss layer HLS implementation Jan 25, 2023
@jmitrevs
Copy link
Contributor

I think you need these changes:

(fastml39) mac-137349:hls4ml jmitrevs$ git diff
diff --git a/contrib/kl_layer/kl_layer.py b/contrib/kl_layer/kl_layer.py
index 198fb012..318c2f46 100644
--- a/contrib/kl_layer/kl_layer.py
+++ b/contrib/kl_layer/kl_layer.py
@@ -14,7 +14,7 @@ import tensorflow as tf
 try:
     from keras.layers.merge import _Merge as Merge
 except Exception:
-    from keras.layers.merging.base_merge import _Merge
+    from keras.layers.merging.base_merge import _Merge as Merge
     
 from tensorflow.python.keras.utils import tf_utils
 from tensorflow.python.ops import math_ops
@@ -110,7 +110,7 @@ class HKLLossFunctionTemplate(hls4ml.backends.template.FunctionCallTemplate):
 
 
 # Parser for converter
-def parse_klloss_layer(keras_layer, input_names, input_shapes, data_reader, config):
+def parse_klloss_layer(keras_layer, input_names, input_shapes, data_reader):
     assert 'KLLoss' in keras_layer['class_name']
 
     layer = parse_default_keras_layer(keras_layer, input_names)

However, I still get a KeyError: 'accum_t' at the format in:

    def format(self, node):
        params = self._default_config_params(node)
        params['n_in'] = node.get_input_variable(node.inputs[0]).shape[0]
        params['n_out'] = 1
        return self.template.format(**params)

around line 90 of kl_layer.py.

@katyagovorkova
Copy link
Contributor Author

I think you need these changes:

Fixed, thanks!

However, I still get a KeyError: 'accum_t' at the format in:

    def format(self, node):
        params = self._default_config_params(node)
        params['n_in'] = node.get_input_variable(node.inputs[0]).shape[0]
        params['n_out'] = 1
        return self.template.format(**params)

around line 90 of kl_layer.py.

Ah I see, that's most likely because I have removed the Distance class as Vladimir suggested. But seems like it required more changes than just removing the class.. Also I can not test it locally since I have different error when running: Exception: Optimization pass vivado:clone_output already registered

@vloncar
Copy link
Contributor

vloncar commented Feb 10, 2023

I fixed the outstanding issues. Unfortunately, pre-commit broke for me so I couldn't run it, hence the noise until I manually made it compliant. So annoying...

Anyway, it is ready now.

Copy link
Contributor

@vloncar vloncar left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, thanks Katya!

@vloncar vloncar merged commit 85b9531 into fastmachinelearning:main Feb 10, 2023
calad0i pushed a commit to calad0i/hls4ml that referenced this pull request Jul 1, 2023
)

* add kl layer

* separate hls part; clean up and add docs

* creeate KL layer folder in contrib and move the files there

* pass pre-commit check

* README and fix pre-commit issue

* update readme

* fix formatting

* add readme

* Update README.md

@jmitrevs readme updated!

* Update README.md

remove trailing whitespace

* Update kl_layer.py

* Rename nnet_distance.h to kl_layer.h

* Update README.md

* Update kl_layer.py

* Update kl_layer.h

* fix pre-commit

* Fix KLLoss layer example

---------

Co-authored-by: Jovan Mitrevski <[email protected]>
Co-authored-by: Vladimir Loncar <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
please test Trigger testing by creating local PR branch
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants