Skip to content

Commit

Permalink
fix black lint (fairinternal/xformers#1010)
Browse files Browse the repository at this point in the history
Co-authored-by: amyyang <[email protected]>

__original_commit__ = fairinternal/xformers@de829128fff437b8da7965d856424b385bc686e9
  • Loading branch information
amylittleyang authored and xFormers Bot committed Jan 28, 2024
1 parent d15b727 commit 8c531a4
Showing 1 changed file with 3 additions and 9 deletions.
12 changes: 3 additions & 9 deletions xformers/ops/sequence_parallel_fused_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -849,9 +849,7 @@ def my_matmul(
else:
torch.matmul(inputs[0], w.t(), out=go[src_rank])

_is_regular_matmul = all(
[not _is_fp8_dtype(w.dtype) for w in weights]
)
_is_regular_matmul = all([not _is_fp8_dtype(w.dtype) for w in weights])
fused_allgather_and_anything(
[scattered_input],
my_matmul,
Expand Down Expand Up @@ -1041,9 +1039,7 @@ def fused_linear_and_reducescatter(
scattered_outputs = [
gathered_input.new_empty(
sos,
dtype=out_dtype
if out_dtype is not None
else gathered_input.dtype,
dtype=out_dtype if out_dtype is not None else gathered_input.dtype,
)
for sos in scattered_output_shapes
]
Expand All @@ -1068,9 +1064,7 @@ def my_matmul(
else:
torch.matmul(gathered_input[dst_rank], w.t(), out=o)

_is_regular_matmul = all(
[not _is_fp8_dtype(w.dtype) for w in weights]
)
_is_regular_matmul = all([not _is_fp8_dtype(w.dtype) for w in weights])
fused_anything_and_reducescatter(
my_matmul,
scattered_outputs,
Expand Down

0 comments on commit 8c531a4

Please sign in to comment.