Testing out many ideas, not all successful
blefaudeux committed Oct 20, 2021
1 parent 347745d commit 618697b
Showing 4 changed files with 52 additions and 42 deletions.
Binary file modified docs/plots/FusedLinear_fp16_FW.png
Binary file modified docs/plots/FusedLinear_fp16_FW_BW.png
4 changes: 2 additions & 2 deletions xformers/benchmarks/utils.py
@@ -69,8 +69,8 @@ def pretty_plot(results, title, units: str, filename=None, dash_key=""):

     # Make sure that the plot is big enough
     f = plt.figure()
-    f.set_figwidth(6)
-    f.set_figheight(6)
+    f.set_figwidth(10)
+    f.set_figheight(10)

     # Display the collections
     for k, v in workloads.items():
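The only change here is the plot canvas, grown from 6x6 to 10x10 inches, presumably to give the curves and legends more room. A minimal standalone sketch of the effect (matplotlib only; the plotted data is made up):

import matplotlib.pyplot as plt

f = plt.figure()
f.set_figwidth(10)   # was 6: wider canvas for the workload curves
f.set_figheight(10)  # was 6: taller canvas, more room for the legend
plt.plot([8, 16, 32], [1.2, 2.3, 4.1], label="example workload (made-up data)")
plt.legend()
plt.savefig("example_plot.png")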
90 changes: 50 additions & 40 deletions xformers/triton/k_fused_matmul.py
@@ -43,9 +43,9 @@ def kernel_fma(
     # The stride variables represent how much to increase the ptr by when moving by 1
     # element in a particular dimension. E.g. stride_am is how much to increase a_ptr
     # by to get the element one row down (A has M rows)
-    stride_dm, stride_am,
+    stride_m,
+    stride_om,
     stride_wn, stride_wk,
-    stride_im,
     # Meta-parameters
     **META,
 ):
@@ -90,16 +90,16 @@
     # for rows (resp. col) of C
     rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
     rn = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)

     # rk denotes a range of indices for columns
     # (resp. rows) of A (resp. B)
     rk = tl.arange(0, BLOCK_K)

     # the memory addresses of elements in the first block of
     # A and W can be computed using numpy-style broadcasting
-    D += rm[:, None] * stride_dm + rn[None, :]
-    A += rm[:, None] * stride_am + rk[None, :]
-    W += rn[None, :] * stride_wn + rk[:, None] * stride_wk
+    a_ptrs = A + rm[:, None] * stride_m + rk[None, :]
+    w_ptrs = W + rn[None, :] * stride_wn + rk[:, None] * stride_wk
+
+    mask_mk = (rm[:, None] < M) & (rk[None, :] < K)
+    mask_kn = (rk[:, None] < K) & (rn[None, :] < N)
+    mask_mn = (rm[:, None] < M) & (rn[None, :] < N)

     # initialize and iteratively update accumulator
     acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
@@ -111,31 +111,32 @@
     for _ in range(K, 0, -BLOCK_K):
         # load then increment pointers so that the next blocks of A and B
         # are loaded during the next iteration
-        a = tl.load(A, mask=(rm[:, None] < M) & (rk[None, :] < K))
-        A += BLOCK_K
+        a = tl.load(a_ptrs, mask=mask_mk, other=0.0)
+        a_ptrs += BLOCK_K

-        w = tl.load(W, mask=(rn[None, :] < N) & (rk[:, None] < K))
-        W += BLOCK_K * stride_wk
+        w = tl.load(w_ptrs, mask=mask_kn, other=0.0)
+        w_ptrs += BLOCK_K * stride_wk

         # block level matrix multiplication
         acc += tl.dot(a, w)

     # optional: save the activation inputs
     if META["SAVE_ACT_INPUTS"]:
-        In += rm[:, None] * stride_im + rn[None, :]
-        tl.store(In, acc, mask=(rm[:, None] < M) & (rn[None, :] < N))
+        In += rm[:, None] * stride_om + rn[None, :]
+        tl.store(In, acc, mask=mask_mn)

     # optional: fused activation (while the data is in shared memory)
     if META["ACTIVATION"]:
         acc = META["ACTIVATION"](acc)

     # write back result
-    tl.store(D, acc, mask=(rm[:, None] < M) & (rn[None, :] < N))
+    d_ptrs = D + rm[:, None] * stride_om + rn[None, :]
+    tl.store(d_ptrs, acc, mask=mask_mn)
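For intuition, kernel_fma computes activation(x @ weight.T + bias) tile by tile, optionally stashing the pre-activation values for the backward pass. A rough dense reference of its semantics, not the launch path (the bias add sits in lines elided from this hunk, so its exact placement here is an assumption):

import torch

def kernel_fma_reference(x, weight, bias=None, activation=None, save_act_inputs=False):
    # D = activation(x @ W^T + bias); x is (M, K), W is (N, K), D is (M, N)
    act_in = x @ weight.t()
    if bias is not None:
        act_in = act_in + bias  # assumed placement: the bias add is elided above
    out = activation(act_in) if activation is not None else act_in
    return out, (act_in if save_act_inputs else None)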


 def _sanitize_inputs(x, weight, bias):
     assert (
-        x.shape[1] == weight.shape[1]
+        x.shape[-1] == weight.shape[1]
     ), f"Incompatible dimensions in between inputs and weight, {x.shape} - {weight.shape}"
     assert bias is None or bias.is_contiguous()
     assert (
@@ -174,6 +175,8 @@ def grid(META):
             triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(N, META["BLOCK_N"]),
         )

+    assert outputs.stride(0) == act_inputs.stride(0), f"{outputs.stride()} - {act_inputs.stride()}"
+
     # fmt: off
     kernel_fma[grid](
         # data ptrs
@@ -183,10 +186,8 @@ def grid(META):
         # shapes
         M, N, K,
         # strides
-        outputs.stride(0),
-        x_.stride(0),
+        x_.stride(0), outputs.stride(0),
         weight.stride(0), weight.stride(1),
-        act_inputs.stride(0),
         # optional fused activation
         ACTIVATION=activation,
         # optional fused bias
@@ -200,7 +201,7 @@ def grid(META):
     # fmt: on

     if x.ndim == 3:
-        outputs = outputs.reshape(x.shape[0], x.shape[1], N)
+        outputs = outputs.reshape((x.shape[0], x.shape[1], -1))

     return (outputs, act_inputs) if save_inputs else (outputs, None)
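Usage-wise, the wrapper around kernel_fma flattens 3D inputs to 2D for the kernel and reshapes back afterwards, which is what the reshape fix above tidies up. A sketch of a call, with the caveat that the wrapper's def is outside these hunks, so the name fused_matmul and the argument order are assumed from the locals visible above:

import torch
from xformers.triton.k_fused_matmul import fused_matmul  # name assumed, defined outside this hunk

x = torch.randn(4, 64, 512, device="cuda", dtype=torch.float16)      # (B, S, K)
weight = torch.randn(1024, 512, device="cuda", dtype=torch.float16)  # (N, K)
bias = torch.randn(1024, device="cuda", dtype=torch.float16)

# 3D inputs are flattened to (B*S, K) for the kernel, then reshaped back to (B, S, N)
outputs, act_inputs = fused_matmul(x, weight, bias)
assert outputs.shape == (4, 64, 1024)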

@@ -228,9 +229,9 @@ def kernel_grad_inputs(
     # Tensor dimensions
     M, N, K,
     # strides for all the gradients
-    stride_gim, stride_gam, stride_gom,
+    stride_m, stride_im,
     # strides for the extra data
-    stride_aim, stride_wn, stride_wk,
+    stride_wn, stride_wk,
     # Meta-parameters
     **META,
 ):
@@ -279,30 +280,31 @@
     rn = tl.arange(0, BLOCK_N)

     # memory blocks can be computed using numpy-style broadcasting
-    grad_out_ptrs = GRAD_OUT + rm[:, None] * stride_gom + rn[None, :]
-    grad_act_ptrs = GRAD_ACT + rm[:, None] * stride_gam + rn[None, :]
-    act_in_ptrs = ACT_IN + rm[:, None] * stride_aim + rn[None, :]
-    grad_in_ptrs = GRAD_IN + rm[:, None] * stride_gim + rk[None, :]
+    grad_out_ptrs = GRAD_OUT + rm[:, None] * stride_m + rn[None, :]
+    grad_act_ptrs = GRAD_ACT + rm[:, None] * stride_m + rn[None, :]
+    act_in_ptrs = ACT_IN + rm[:, None] * stride_im + rn[None, :]
     w_ptrs = WEIGHT + rn[:, None] * stride_wn + rk[None, :] * stride_wk

     # initialize and iteratively update accumulator
     grad_in_acc = tl.zeros((BLOCK_M, BLOCK_K), dtype=tl.float32)
     act_grad_fn = META["ACTIVATION_GRAD"]

     mask_mk = (rm[:, None] < M) & (rk[None, :] < K)
+    mask_mn = (rm[:, None] < M) & (rn[None, :] < N)
+    mask_nk = (rn[:, None] < N) & (rk[None, :] < K)

     for _ in range(N, 0, -BLOCK_N):
-        grad_out = tl.load(grad_out_ptrs)
+        grad_out = tl.load(grad_out_ptrs, mask=mask_mn, other=0.0)
         grad_out_ptrs += BLOCK_N

-        w = tl.load(w_ptrs)
+        w = tl.load(w_ptrs, mask=mask_nk, other=0.0)
         w_ptrs += BLOCK_N * stride_wn

         # optional fused activation gradient (while the data is in shared memory)
         if META["ACTIVATION_GRAD"]:
             if META["ACTIVATION_GRAD_REQ_INPUTS"]:
                 # This activation requires its inputs
-                act_input = tl.load(act_in_ptrs)
+                act_input = tl.load(act_in_ptrs, mask=mask_mn, other=0.0)
                 act_in_ptrs += BLOCK_N

                 grad_act = act_grad_fn(act_input)
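The recurring change in this kernel is that every tl.load is now masked with other=0.0: without an explicit fill value, out-of-range lanes feed garbage into the tl.dot accumulation whenever M, N, or K is not a multiple of the block size. Zero-filling is safe because padded rows and columns contribute nothing to a matmul; a quick torch check of that identity:

import torch

a = torch.randn(5, 7)  # M=5, K=7: deliberately not multiples of a block size of 8
b = torch.randn(7, 3)

a_pad = torch.zeros(8, 8)  # one (BLOCK_M, BLOCK_K) tile, zero-filled like other=0.0
a_pad[:5, :7] = a
b_pad = torch.zeros(8, 3)
b_pad[:7] = b

# the zero padding does not change the valid part of the product
assert torch.allclose(a @ b, (a_pad @ b_pad)[:5])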
@@ -323,6 +325,7 @@ def kernel_grad_inputs(

     # write back result
     # automatic type promotion/downgrade
+    grad_in_ptrs = GRAD_IN + rm[:, None] * stride_im + rk[None, :]
     tl.store(grad_in_ptrs, grad_in_acc, mask=mask_mk)


@@ -343,38 +346,40 @@ def fused_matmul_backward(
     .. note: The weight buffer is transposed on the fly
     """

-    grad_out = grad_out.flatten(0, 1) if grad_out.ndim == 3 else grad_out
-    if grad_out.stride(1) != 1:
-        grad_out.contiguous()
+    grad_out_ = grad_out.flatten(0, 1) if grad_out.ndim == 3 else grad_out
+    if grad_out_.stride(-1) != 1:
+        grad_out_ = grad_out_.contiguous()

     assert (
-        grad_out.shape[1] == weight.shape[0]
+        grad_out_.shape[-1] == weight.shape[0]
     ), "Incompatible dimensions in between grad_out and weight"

-    M, N = grad_out.shape
+    M, N = grad_out_.shape
     N, K = weight.shape

     grad_in = torch.empty((M, K), device=grad_out.device, dtype=grad_out.dtype)
-    grad_act = torch.empty_like(grad_out)
+    grad_act = torch.empty_like(grad_out_)
     activation_inputs = grad_out if activation_inputs is None else activation_inputs

     def grid(META):
         return (
             triton.cdiv(M, META["BLOCK_M"]) * triton.cdiv(K, META["BLOCK_K"]),
         )

+    assert grad_out_.stride(0) == grad_act.stride(0) and grad_out_.stride(1) == grad_act.stride(1)
+
     # Compute the gradient for the inputs + partial reduction for grad bias
     # fmt: off
     # (M, K) <- (M, N) x (N, K)
     kernel_grad_inputs[grid](
         # data ptrs
-        grad_in, grad_act, grad_out,
+        grad_in, grad_act, grad_out_,
         activation_inputs, weight,
         # shapes
         M, N, K,
         # strides
-        grad_in.stride(0), grad_act.stride(0), grad_out.stride(0),
-        activation_inputs.stride(0), weight.stride(0), weight.stride(1),
+        grad_out_.stride(0), grad_in.stride(0),
+        weight.stride(0), weight.stride(1),
         # optional fused activation
         ACTIVATION_GRAD=activation_grad,
         GROUP_M=8, # L2 data reuse optimization
@@ -386,8 +391,13 @@ def grid(META):
     # fmt: on

     grad_bias = torch.sum(grad_act, dim=0) if trainable_bias else None
-    inputs_ = inputs.flatten(0, 1) if inputs.ndim == 3 else inputs
-    grad_weight = triton.ops.matmul(grad_act.transpose(0, 1), inputs_) if trainable_weight else None
+
+    if trainable_weight:
+        inputs_ = inputs.flatten(0, 1) if inputs.ndim == 3 else inputs
+        grad_weight = triton.ops.matmul(grad_act.transpose(0, 1), inputs_)
+        del inputs_
+    else:
+        grad_weight = None

     del grad_act
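Put together, this backward pass is the standard set of linear-layer gradients: the activation gradient is fused into kernel_grad_inputs' N-loop above, while the weight gradient falls out of a plain matmul and the bias gradient out of a row reduction. A dense torch sketch of what the pieces compute, for intuition rather than the dispatch path (names mirror the wrapper's):

import torch

def fused_matmul_backward_reference(grad_out, inputs, weight, act_inputs, activation_grad=None):
    # grad_act = grad_out * act'(act_in): what kernel_grad_inputs fuses into its loop
    grad_act = grad_out * activation_grad(act_inputs) if activation_grad is not None else grad_out
    grad_in = grad_act @ weight          # (M, K) <- (M, N) x (N, K)
    grad_weight = grad_act.t() @ inputs  # (N, K) <- (N, M) x (M, K)
    grad_bias = grad_act.sum(dim=0)      # (N,): partial per-row reduction in the kernel
    return grad_in, grad_weight, grad_bias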

