chore: remove PyTorch 2.5.0 checks #1877
Changes from all commits: b91b1f1, da5178f, e2cfa3d, 4408138, 9970128, 4027c2e, 37d5a01
```diff
@@ -84,7 +84,7 @@ def test_packed_block_causal_mask_sdpa(self, seq_lens):

     @pytest.mark.skipif(
         not _SUPPORTS_FLEX_ATTENTION,
-        reason="Please install a nightly build of torch (>=2.5.0) to run this test.",
+        reason="Hardware does not support Flex Attention.",
     )
     @gpu_test(gpu_count=1)
     def test_packed_block_causal_mask_flex(self):

@@ -100,7 +100,7 @@ def test_packed_block_causal_mask_flex(self):
 class TestSDPAOrFlexAttention:
     @pytest.mark.skipif(
         not _SUPPORTS_FLEX_ATTENTION,
-        reason="Please install a nightly build of torch (>=2.5.0) to run this test.",
+        reason="Hardware does not support Flex Attention.",
     )
     @mock.patch("torchtune.modules.attention_utils.compile_friendly_flex_attention")
     @mock.patch(
```

Review comment on the `skipif` check: OK, it looks like we need to keep this check, in case the hardware that runs the GPU tests on GitHub CI does not support Flex Attention.
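For context, capability flags like `_SUPPORTS_FLEX_ATTENTION` typically gate on both the installed torch version and the GPU's compute capability, which is why the skip must stay even once torch >= 2.5.0 is guaranteed: the CI runner's GPU itself may be too old. Below is a minimal sketch of such a flag; it is illustrative only, and torchtune's actual definition and capability threshold may differ:

```python
import torch

# Illustrative sketch, not torchtune's exact code: flex attention needs a
# recent torch build *and* a CUDA GPU that torch.compile can target, so the
# flag combines a version gate with a device-capability gate.
_SUPPORTS_FLEX_ATTENTION = (
    torch.__version__ >= "2.5.0"  # TorchVersion compares version-aware, not lexicographically
    and torch.cuda.is_available()
    and torch.cuda.get_device_capability() >= (7, 5)  # assumed threshold
)
```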
```diff
@@ -5,7 +5,7 @@
 # LICENSE file in the root directory of this source tree.

 import contextlib
-from typing import Optional, Union
+from typing import Union
 from warnings import warn

 import psutil
```
```diff
@@ -38,9 +38,9 @@ class OffloadActivations(saved_tensors_hooks):
             memory on the CPU. Pinned memory allows the Tensor to be moved back onto GPU more quickly
             but is a limited resource. Default: True.

-        use_streams (Optional[bool]): Whether or not to use streams for performance optimization where
+        use_streams (bool): Whether or not to use streams for performance optimization where
             the communications get overlapped with the computation. Requires a torch build
-            after torch-2.5.0.dev20240907. Default: True if a later torch build is found, else False.
+            after torch-2.5.0.]. Default: True.

         max_fwd_stash_size (int): The maximum size of the forward stash, or the maximum number of
             consecutive activations to keep alive during the forward pass. This number must be at
```

Review comment on the new docstring line: Suggested change

Review comment: bumping this
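For readers unfamiliar with the API being discussed: `OffloadActivations` subclasses `saved_tensors_hooks`, so it is used as a context manager around the forward pass, and offloaded activations are fetched back to GPU during backward. A minimal usage sketch, assuming the class is exported from `torchtune.training`:

```python
import torch
from torchtune.training import OffloadActivations  # assumed import path

model = torch.nn.Linear(4096, 4096, device="cuda")
inputs = torch.randn(8, 4096, device="cuda")

# Tensors saved for backward are moved to (pinned) CPU memory inside this
# context; with use_streams=True the copies overlap with ongoing compute.
with OffloadActivations(use_streams=True):
    loss = model(inputs).sum()
loss.backward()  # activations are brought back onto the GPU as needed
```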
```diff
@@ -67,15 +67,12 @@ class OffloadActivations(saved_tensors_hooks):
     def __init__(
         self,
         use_pin_memory: bool = True,
-        use_streams: Optional[bool] = None,
+        use_streams: bool = True,
         max_fwd_stash_size: int = 5,
         min_offload_size: int = 1024,
     ) -> None:
-        if use_streams is None:
-            # Default to True if an acceptable torch is installed (later nightly/version or from source)
-            self.use_streams = torch.__version__ >= "2.5.0.dev20240907"
-        else:
-            self.use_streams = use_streams
+        self.use_streams: bool = use_streams

         self.min_tensor_size_bytes = (
             min_offload_size  # we don't want to bother with small tensors
```

Review thread on `__init__`:

Reviewer: Let's remove the check below and make `use_streams` default to `True`.

Author: Are you referring to this:

```python
if use_streams is None:
    # Default to True if an acceptable torch is installed (later nightly/version or from source)
    self.use_streams = torch.__version__ >= "2.5.0.dev20240907"
else:
    self.use_streams = use_streams
```

or should it be changed like this:

```python
if use_streams is False:
    # Default to True if an acceptable torch is installed (later nightly/version or from source)
    self.use_streams = torch.__version__ >= "2.5.0.dev20240907"
else:
    self.use_streams = use_streams
```

Reviewer: Yep! It would just be:

```python
self.use_streams = use_streams
```

Author: I have noticed this check as well:

```python
# for streaming
if self.use_streams:
    if torch.__version__ < "2.5.0.dev20240907":
        raise RuntimeError(
            "OffloadActivations with use_streams=True requires PyTorch 2.5.0.dev20240907 or later."
        )
```

Reviewer: Yes, good catch. Let's remove that as well. I believe these may have been added after I put the issue up, or maybe I just missed it.

Author: Sounds good, will update all.
```diff
@@ -98,10 +95,6 @@ def __init__(

         # for streaming
         if self.use_streams:
-            if torch.__version__ < "2.5.0.dev20240907":
-                raise RuntimeError(
-                    "OffloadActivations with use_streams=True requires PyTorch 2.5.0.dev20240907 or later."
-                )
             self.s1 = torch.cuda.Stream()  # comms stream
             self.fwd_stash = {}  # tensor_id => (activation, ev1)
             if max_fwd_stash_size < 1:
```
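For background on what the comms stream (`self.s1`) enables: issuing the device-to-host copy on a side stream lets it overlap with compute running on the default stream, and an event records when the source tensor may be released, which is the lifetime bookkeeping that `fwd_stash` and `max_fwd_stash_size` manage. Below is a minimal sketch of that pattern with illustrative names; it is not torchtune's actual implementation:

```python
import torch

def offload_async(
    gpu_tensor: torch.Tensor, comms: torch.cuda.Stream
) -> tuple[torch.Tensor, torch.cuda.Event]:
    # Pinned CPU memory is required for the D2H copy to be truly asynchronous.
    cpu_buf = torch.empty_like(gpu_tensor, device="cpu", pin_memory=True)
    # The copy must wait for whatever kernel produced gpu_tensor.
    comms.wait_stream(torch.cuda.current_stream())
    with torch.cuda.stream(comms):
        cpu_buf.copy_(gpu_tensor, non_blocking=True)
        done = torch.cuda.Event()
        done.record(comms)
    # gpu_tensor must be kept alive (e.g. in a stash) until `done` fires;
    # releasing it earlier would let the allocator reuse its memory mid-copy.
    return cpu_buf, done
```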
Review comment: I think we need a merge here? The docs in main read:

Review comment: I believe we can just revert the changes here to leave the file as-is, since it's been updated in another PR. @JP-sDEV