-
Notifications
You must be signed in to change notification settings - Fork 30
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Clean up model input signature handling for tuples #152
Conversation
Documentation preview |
Click to view CI ResultsGitHub pull request #152 of commit 41bd6d7e424ed099d0922a5c35eb7bc98a2ef63d, no merge conflicts. Running as SYSTEM Setting status of 41bd6d7e424ed099d0922a5c35eb7bc98a2ef63d to PENDING with url https://10.20.13.93:8080/job/merlin_systems/162/console and message: 'Pending' Using context: Jenkins Building on master in workspace /var/jenkins_home/workspace/merlin_systems using credential fce1c729-5d7c-48e8-90cb-b0c314b1076e > git rev-parse --is-inside-work-tree # timeout=10 Fetching changes from the remote Git repository > git config remote.origin.url https://github.com/NVIDIA-Merlin/systems # timeout=10 Fetching upstream changes from https://github.com/NVIDIA-Merlin/systems > git --version # timeout=10 using GIT_ASKPASS to set credentials login for merlin-systems user + githubtoken > git fetch --tags --force --progress -- https://github.com/NVIDIA-Merlin/systems +refs/pull/152/*:refs/remotes/origin/pr/152/* # timeout=10 > git rev-parse 41bd6d7e424ed099d0922a5c35eb7bc98a2ef63d^{commit} # timeout=10 Checking out Revision 41bd6d7e424ed099d0922a5c35eb7bc98a2ef63d (detached) > git config core.sparsecheckout # timeout=10 > git checkout -f 41bd6d7e424ed099d0922a5c35eb7bc98a2ef63d # timeout=10 Commit message: "Clean up model input signature handling for tuples" > git rev-list --no-walk ca721c01210ebf35bfc39cb058192a3c08fff8ac # timeout=10 [merlin_systems] $ /bin/bash /tmp/jenkins15849634275113900591.sh PYTHONPATH=:/usr/local/lib/python3.8/dist-packages/:/usr/local/hugectr/lib:/var/jenkins_home/workspace/merlin_systems/systems ============================= test session starts ============================== platform linux -- Python 3.8.10, pytest-7.1.2, pluggy-1.0.0 rootdir: /var/jenkins_home/workspace/merlin_systems/systems, configfile: pyproject.toml plugins: anyio-3.6.1, xdist-2.5.0, forked-1.4.0, cov-3.0.0 collected 50 items |
Click to view CI ResultsGitHub pull request #152 of commit a1c662b2c7770cbeba84910f9c75c5abce939a70, no merge conflicts. Running as SYSTEM Setting status of a1c662b2c7770cbeba84910f9c75c5abce939a70 to PENDING with url https://10.20.13.93:8080/job/merlin_systems/165/console and message: 'Pending' Using context: Jenkins Building on master in workspace /var/jenkins_home/workspace/merlin_systems using credential fce1c729-5d7c-48e8-90cb-b0c314b1076e > git rev-parse --is-inside-work-tree # timeout=10 Fetching changes from the remote Git repository > git config remote.origin.url https://github.com/NVIDIA-Merlin/systems # timeout=10 Fetching upstream changes from https://github.com/NVIDIA-Merlin/systems > git --version # timeout=10 using GIT_ASKPASS to set credentials login for merlin-systems user + githubtoken > git fetch --tags --force --progress -- https://github.com/NVIDIA-Merlin/systems +refs/pull/152/*:refs/remotes/origin/pr/152/* # timeout=10 > git rev-parse a1c662b2c7770cbeba84910f9c75c5abce939a70^{commit} # timeout=10 Checking out Revision a1c662b2c7770cbeba84910f9c75c5abce939a70 (detached) > git config core.sparsecheckout # timeout=10 > git checkout -f a1c662b2c7770cbeba84910f9c75c5abce939a70 # timeout=10 Commit message: "Fix TF output dimensions for fixed size embedding outputs" > git rev-list --no-walk 48ddd15ba3d24196c2987b2ae8cadb2e0856f11c # timeout=10 [merlin_systems] $ /bin/bash /tmp/jenkins8635281584176530108.sh PYTHONPATH=:/usr/local/lib/python3.8/dist-packages/:/usr/local/hugectr/lib:/var/jenkins_home/workspace/merlin_systems/systems ============================= test session starts ============================== platform linux -- Python 3.8.10, pytest-7.1.2, pluggy-1.0.0 rootdir: /var/jenkins_home/workspace/merlin_systems/systems, configfile: pyproject.toml plugins: anyio-3.6.1, xdist-2.5.0, forked-1.4.0, cov-3.0.0 collected 50 items |
|
||
def _add_model_param(params, paramclass, col_schema, dims=None): | ||
dims = dims if dims is not None else [-1, 1] | ||
if col_schema.is_list and col_schema.is_ragged: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could we update test_tf_op_exports_own_config
to give it a model with a tupled input as well so this new logic gets covered?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added a new test to cover the schema inference from the model file
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
lgtm aside from the one useful regression test.
I assume this _add_model_param
function is living in operator.py
instead of tensorflow.py
because we expect to use it for other model frameworks as well (besides FIL), is that right?
for _, col_schema in self.input_schema.column_schemas.items(): | ||
_add_model_param(config.input, model_config.ModelInput, col_schema) | ||
|
||
for _, col_schema in self.output_schema.column_schemas.items(): | ||
_add_model_param(config.output, model_config.ModelOutput, col_schema) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This logic would likely also need to be used in PipelineableInferenceOperator.export
systems/merlin/systems/dag/ops/operator.py
Lines 259 to 274 in 572dbf1
for col_name, col_dict in _schema_to_dict(input_schema).items(): | |
config.input.append( | |
model_config.ModelInput( | |
name=col_name, data_type=_convert_dtype(col_dict["dtype"]), dims=[-1, -1] | |
) | |
) | |
for col_name, col_dict in _schema_to_dict(output_schema).items(): | |
# this assumes the list columns are 1D tensors both for cats and conts | |
config.output.append( | |
model_config.ModelOutput( | |
name=col_name, | |
data_type=_convert_dtype(col_dict["dtype"]), | |
dims=[-1, -1], | |
) | |
) |
Click to view CI ResultsGitHub pull request #152 of commit 19430bb0a1e1c3b03a075f6bf3b9721d5abb99c7, no merge conflicts. Running as SYSTEM Setting status of 19430bb0a1e1c3b03a075f6bf3b9721d5abb99c7 to PENDING with url https://10.20.13.93:8080/job/merlin_systems/170/console and message: 'Pending' Using context: Jenkins Building on master in workspace /var/jenkins_home/workspace/merlin_systems using credential fce1c729-5d7c-48e8-90cb-b0c314b1076e > git rev-parse --is-inside-work-tree # timeout=10 Fetching changes from the remote Git repository > git config remote.origin.url https://github.com/NVIDIA-Merlin/systems # timeout=10 Fetching upstream changes from https://github.com/NVIDIA-Merlin/systems > git --version # timeout=10 using GIT_ASKPASS to set credentials login for merlin-systems user + githubtoken > git fetch --tags --force --progress -- https://github.com/NVIDIA-Merlin/systems +refs/pull/152/*:refs/remotes/origin/pr/152/* # timeout=10 > git rev-parse 19430bb0a1e1c3b03a075f6bf3b9721d5abb99c7^{commit} # timeout=10 Checking out Revision 19430bb0a1e1c3b03a075f6bf3b9721d5abb99c7 (detached) > git config core.sparsecheckout # timeout=10 > git checkout -f 19430bb0a1e1c3b03a075f6bf3b9721d5abb99c7 # timeout=10 Commit message: "Add a test for TF tuples of tensor specs" > git rev-list --no-walk 1f34a8c694bb69a73def361db2a91b7dda9bceef # timeout=10 [merlin_systems] $ /bin/bash /tmp/jenkins11668797913824841673.sh PYTHONPATH=:/usr/local/lib/python3.8/dist-packages/:/usr/local/hugectr/lib:/var/jenkins_home/workspace/merlin_systems/systems ============================= test session starts ============================== platform linux -- Python 3.8.10, pytest-7.1.2, pluggy-1.0.0 rootdir: /var/jenkins_home/workspace/merlin_systems/systems, configfile: pyproject.toml plugins: anyio-3.6.1, xdist-2.5.0, forked-1.4.0, cov-3.0.0 collected 51 items |
Click to view CI ResultsGitHub pull request #152 of commit 7adeae1aad75a1f00a0ca875243b34b1fa7b8df0, no merge conflicts. Running as SYSTEM Setting status of 7adeae1aad75a1f00a0ca875243b34b1fa7b8df0 to PENDING with url https://10.20.13.93:8080/job/merlin_systems/171/console and message: 'Pending' Using context: Jenkins Building on master in workspace /var/jenkins_home/workspace/merlin_systems using credential fce1c729-5d7c-48e8-90cb-b0c314b1076e > git rev-parse --is-inside-work-tree # timeout=10 Fetching changes from the remote Git repository > git config remote.origin.url https://github.com/NVIDIA-Merlin/systems # timeout=10 Fetching upstream changes from https://github.com/NVIDIA-Merlin/systems > git --version # timeout=10 using GIT_ASKPASS to set credentials login for merlin-systems user + githubtoken > git fetch --tags --force --progress -- https://github.com/NVIDIA-Merlin/systems +refs/pull/152/*:refs/remotes/origin/pr/152/* # timeout=10 > git rev-parse 7adeae1aad75a1f00a0ca875243b34b1fa7b8df0^{commit} # timeout=10 Checking out Revision 7adeae1aad75a1f00a0ca875243b34b1fa7b8df0 (detached) > git config core.sparsecheckout # timeout=10 > git checkout -f 7adeae1aad75a1f00a0ca875243b34b1fa7b8df0 # timeout=10 Commit message: "Merge branch 'main' into fix/tf-input-shapes" > git rev-list --no-walk 19430bb0a1e1c3b03a075f6bf3b9721d5abb99c7 # timeout=10 [merlin_systems] $ /bin/bash /tmp/jenkins5669849440324707847.sh PYTHONPATH=:/usr/local/lib/python3.8/dist-packages/:/usr/local/hugectr/lib:/var/jenkins_home/workspace/merlin_systems/systems ============================= test session starts ============================== platform linux -- Python 3.8.10, pytest-7.1.2, pluggy-1.0.0 rootdir: /var/jenkins_home/workspace/merlin_systems/systems, configfile: pyproject.toml plugins: anyio-3.6.1, xdist-2.5.0, forked-1.4.0, cov-3.0.0 collected 52 items |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
🚀
Partial fix for NVIDIA-Merlin/Merlin#361 (comment)