
Stream aware outputs #5684

Open · wants to merge 17 commits into base: main

Conversation

@mzient (Contributor) commented Oct 21, 2024

Category:

New feature (non-breaking change which adds functionality)
Refactoring (Redesign of existing code that doesn't affect functionality)

Description:

This PR adds support for returning the pipeline outputs as DLPack without copying.

Additional information:

Affected modules and functionalities:

This PR adds the following features:

  • stream parameters for Pipeline Run/Outputs
  • __dlpack__ interface for Tensors
  • removal of _expose_dlpack_capsule, which was incomplete
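For context, the `__dlpack__` protocol lets a consumer framework wrap the producer's memory without copying it. A minimal, framework-neutral sketch of the mechanism, using NumPy (≥ 1.23) in place of DALI tensors — the DALI-side objects are not shown here:

```python
import numpy as np

# NumPy arrays implement __dlpack__ / __dlpack_device__, standing in
# for DALI pipeline outputs in this sketch.
src = np.arange(6, dtype=np.float32).reshape(2, 3)

# np.from_dlpack consumes the capsule produced by src.__dlpack__();
# the result aliases the original buffer, no copy is made.
view = np.from_dlpack(src)

src[0, 0] = 42.0  # mutate the producer...
assert view[0, 0] == 42.0  # ...and the consumer sees the change
```

With this PR, `torch.from_dlpack(dali_tensor)` follows the same zero-copy path, with the added stream parameters controlling synchronization.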

Key points relevant for the review:

Tests:

  • Existing tests apply
  • New tests added
    • Python tests
    • GTests
    • Benchmark
    • Other
  • N/A

Checklist

Documentation

  • Existing documentation applies
  • Documentation updated
    • Docstring
    • Doxygen
    • RST
    • Jupyter
    • Other
  • N/A

DALI team only

Requirements

  • Implements new requirements
  • Affects existing requirements
  • N/A

REQ IDs: N/A

JIRA TASK: DALI-4075

mzient and others added 14 commits October 17, 2024 17:19
* Add output order handling to exec2
* Add CUDA stream to Outputs and SharedOutputs in Python bindings for
  Pipeline.
* Refactor stream pointer handling in Python

Signed-off-by: Michal Zientkiewicz <[email protected]>
mzient (Contributor, Author) commented:

TODO: Remove this class entirely.

Fixed annotations in:
  • dali/test/python/dlpack/test_torch.py
  • dali/test/python/dlpack/test_torch_perf.py
@dali-automaton (Collaborator):

CI MESSAGE: [19571767]: BUILD STARTED

Signed-off-by: Michal Zientkiewicz <[email protected]>
@dali-automaton (Collaborator):

CI MESSAGE: [19572115]: BUILD STARTED

@dali-automaton (Collaborator):

CI MESSAGE: [19572115]: BUILD FAILED

Signed-off-by: Michal Zientkiewicz <[email protected]>
@dali-automaton (Collaborator):

CI MESSAGE: [19600864]: BUILD STARTED

@dali-automaton (Collaborator):

CI MESSAGE: [19600864]: BUILD FAILED

Comment on lines +32 to +33
# convert the tensors in the batch to DLPack
batch = [torch.from_dlpack(t) for t in out]
Collaborator:

Nit: this comment is slightly confusing: it says "convert to DLPack", but the function is from_dlpack. Maybe rephrase it to indicate that we're not really converting to DLPack, but going DALI → DLPack → Torch (without a copy).

for t in batch:
    means[flat_idx] = torch.mean(t)
    flat_idx += 1
# those are meant to overwrite the results if synchronization fails
Collaborator:

maybe check that we're actually sharing the memory:

batch_a = [torch.from_dlpack(t) for t in out]
batch_b = [torch.from_dlpack(t) for t in out]
# now change batch_b and make sure that batch_a is changed as well

- assert jax_array.device() == jax.devices()[0]
+ assert jax_array.device == jax.devices()[0]
Collaborator:

A breaking change?

mzient (Contributor, Author):

Slipped in... That's a breaking change... in JAX 0.4.31 :\
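If both JAX versions need to be supported, a small version-tolerant helper (hypothetical name `get_device`) can bridge the method-to-property change; the stand-in classes below mock the two API shapes so no JAX install is assumed:

```python
def get_device(array):
    """Return the device of `array`, whether `.device` is a property
    (JAX >= 0.4.31) or a method (older JAX releases)."""
    dev = array.device
    return dev() if callable(dev) else dev


# Minimal mocks of both API shapes (illustrative, not real JAX types):
class NewStyleArray:
    device = "gpu:0"          # plain attribute, like the new property


class OldStyleArray:
    def device(self):          # method, like the pre-0.4.31 API
        return "gpu:0"


assert get_device(NewStyleArray()) == "gpu:0"
assert get_device(OldStyleArray()) == "gpu:0"
```

The `callable()` probe would misfire only if a device object were itself callable, which is not the case for JAX device objects.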

#include "dali/core/static_switch.h"

namespace dali {

class DLTensorGraveyard {
Collaborator:

Maybe add some docs why we need this and how it works?

mzient (Contributor, Author):

Will do.
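Until those docs land, here is a heavily hedged sketch of the general "graveyard" pattern such a class usually implements — deletion of buffers exported via DLPack is deferred until outstanding work that may still read them has finished. This is an illustration of the idea only, not DALI's actual implementation, which would key the purge off CUDA streams/events rather than an explicit call:

```python
class TensorGraveyard:
    """Illustrative only: parks buffer deleters until a safe point."""

    def __init__(self):
        self._pending = []  # deleters awaiting a safe point

    def bury(self, deleter):
        # Instead of freeing the exported buffer immediately when the
        # consumer drops it, park its deleter.
        self._pending.append(deleter)

    def purge(self):
        # Invoke the parked deleters once it is safe to free, e.g.
        # after the consuming stream's work has completed.
        for deleter in self._pending:
            deleter()
        self._pending.clear()


freed = []
graveyard = TensorGraveyard()
graveyard.bury(lambda: freed.append("buf0"))
assert freed == []        # nothing freed yet
graveyard.purge()
assert freed == ["buf0"]  # freed only at the safe point
```

The point of the pattern is that a DLPack capsule's deleter may run on a thread or at a time when freeing GPU memory in use by a stream would be unsafe, so destruction is staged and drained later.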

Labels: None yet
Projects: None yet
4 participants