Skip to content
This repository has been archived by the owner on Nov 4, 2024. It is now read-only.

Commit

Permalink
fix: relax cublaslt types
Browse files Browse the repository at this point in the history
  • Loading branch information
avik-pal committed Oct 10, 2024
1 parent 2d7533c commit 3c9ff76
Show file tree
Hide file tree
Showing 3 changed files with 8 additions and 9 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "LuxLib"
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
authors = ["Avik Pal <[email protected]> and contributors"]
version = "1.3.2"
version = "1.3.3"

[deps]
ArrayInterface = "4fba245c-0d91-5ea0-9b3e-6abc04ee57a9"
Expand Down
3 changes: 1 addition & 2 deletions ext/LuxLibCUDAExt/LuxLibCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
module LuxLibCUDAExt

# This file only wraps functionality part of CUDA like CUBLAS
using CUDA: CUDA, CUBLAS, StridedCuMatrix, StridedCuVector, CuPtr, AnyCuMatrix, AnyCuVector
using CUDA: CUDA, CUBLAS, StridedCuMatrix, StridedCuVector, CuPtr
using LinearAlgebra: LinearAlgebra, Transpose, Adjoint
using LuxLib: LuxLib, Optional
using LuxLib.Utils: ofeltype_array
Expand Down
12 changes: 6 additions & 6 deletions ext/LuxLibCUDAExt/cublaslt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,16 +170,16 @@ end
len(x) = length(x)
len(::Nothing) = nothing

function LuxLib.Impl.cublasLt_fused_dense(act::F, weight::AnyCuMatrix, x::AnyCuMatrix,
b::Optional{<:AnyCuVector}, ::False) where {F}
function LuxLib.Impl.cublasLt_fused_dense(act::F, weight::AbstractMatrix,
x::AbstractMatrix, b::Optional{<:AbstractVector}, ::False) where {F}
z = similar(x, LuxLib.concrete_fba_output_eltype(act, weight, x, b),
size(weight, 1), size(x, 2))
LuxLib.cublasLt_fused_dense!(z, act, weight, x, b)
return z, nothing
end

function LuxLib.Impl.cublasLt_fused_dense(act::F, weight::AnyCuMatrix, x::AnyCuMatrix,
b::Optional{<:AnyCuVector}, ::True) where {F}
function LuxLib.Impl.cublasLt_fused_dense(act::F, weight::AbstractMatrix,
x::AbstractMatrix, b::Optional{<:AbstractVector}, ::True) where {F}
z = similar(x, LuxLib.concrete_fba_output_eltype(act, weight, x, b),
size(weight, 1), size(x, 2))
y = similar(z)
Expand All @@ -188,8 +188,8 @@ function LuxLib.Impl.cublasLt_fused_dense(act::F, weight::AnyCuMatrix, x::AnyCuM
end

function LuxLib.Impl.cublasLt_fused_dense!(
z::AbstractMatrix, act::F, weight::AnyCuMatrix, x::AnyCuMatrix,
b::Optional{<:AnyCuVector}, y::Optional{<:AbstractMatrix}=nothing) where {F}
z::AbstractMatrix, act::F, weight::AbstractMatrix, x::AbstractMatrix,
b::Optional{<:AbstractVector}, y::Optional{<:AbstractMatrix}=nothing) where {F}
if hasmethod(cublaslt_matmul_fused!,
(typeof(z), typeof(act), typeof(weight), typeof(x), typeof(b), typeof(y)))
retcode = cublaslt_matmul_fused!(z, act, weight, x, b, y)
Expand Down

0 comments on commit 3c9ff76

Please sign in to comment.