Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: implement a separate TracedRNumber #161

Merged
merged 35 commits into main from ap/scalar
Oct 6, 2024
Merged
Show file tree
Hide file tree
Changes from 33 commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
4a1be10
feat: TracedRScalar
avik-pal Oct 2, 2024
978fdd9
feat: partial progress on getting scalars to work
avik-pal Oct 3, 2024
32af332
refactor: Scalar --> Number
avik-pal Oct 5, 2024
8c213cb
fix: batching
avik-pal Oct 5, 2024
49e124a
fix: promote_rule and introduce union over primitive types
avik-pal Oct 5, 2024
6dc36f1
chore: apply formatting
avik-pal Oct 5, 2024
20de817
feat: type-restrict arrays
avik-pal Oct 5, 2024
e04e3b6
refactor: move scalar ops to a separate file
avik-pal Oct 5, 2024
2ca8c68
feat: support Base.float
avik-pal Oct 5, 2024
2975953
fix: import ordering
avik-pal Oct 5, 2024
8da3d20
feat: handle `broadcast_preserving_zero_d` in a generic fashion
avik-pal Oct 5, 2024
194ee65
refactor: move code a bit
avik-pal Oct 5, 2024
db5565b
test: more test fixes
avik-pal Oct 5, 2024
d18aff8
chore: apply formatting
avik-pal Oct 5, 2024
7fd269d
fix: setindex with scalars
avik-pal Oct 5, 2024
91a4a00
fix: scalar broadcasting case
avik-pal Oct 5, 2024
d82fb52
feat: support BFloat16 from Core (if available)
avik-pal Oct 5, 2024
45158bb
test: more native lux functionality unblocked
avik-pal Oct 5, 2024
4757cf9
refactor: use a union type for traced types
avik-pal Oct 5, 2024
c85d3a1
fix: check for reactant primitives
avik-pal Oct 5, 2024
d9cf498
fix: missing import
avik-pal Oct 5, 2024
0d7ad84
fix: correct semantics for Colon mapreduce
avik-pal Oct 5, 2024
d7337c9
fix: trace_type
avik-pal Oct 5, 2024
6aab7f7
fix: minor fixes
avik-pal Oct 5, 2024
abc6a9e
feat: support logsoftmax
avik-pal Oct 5, 2024
841376d
fix: bool promote rule
avik-pal Oct 5, 2024
eb3d1db
fix: broadcasting of closures
avik-pal Oct 5, 2024
944dca8
refactor: use TracedTypes
avik-pal Oct 6, 2024
500d12f
Merge branch 'main' into ap/scalar
mofeing Oct 6, 2024
3ecafef
Fix type of `preserved_args`
mofeing Oct 6, 2024
c03b5e0
Rename `TracedTypes` to `TracedType`
mofeing Oct 6, 2024
60b614b
small testset rename
mofeing Oct 6, 2024
8a9f06c
fix: special handling for concatenation of numbers
avik-pal Oct 6, 2024
a35d7b7
Reenable tests
mofeing Oct 6, 2024
4a81556
Rename `ReactantPrimitives` to `ReactantPrimitive`
mofeing Oct 6, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 18 additions & 11 deletions ext/ReactantNNlibExt.jl
Original file line number Diff line number Diff line change
@@ -1,31 +1,24 @@
module ReactantNNlibExt

using NNlib
using Reactant: Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR
using Reactant:
Reactant, TracedRArray, AnyTracedRArray, materialize_traced_array, MLIR, TracedRNumber

# Map a small set of NNlib activations directly onto their StableHLO ops by
# generating one method per (Julia function, HLO op) pair.
# NOTE(review): this span is a rendered diff — both the old signature
# (`TracedRArray{T,0}`) and the new one (`TracedRNumber{T}`) appear below;
# only the `TracedRNumber` version exists in the merged source.
for (jlop, hloop) in (
(:(NNlib.tanh_fast), :tanh),
(:(NNlib.sigmoid_fast), :logistic),
(:(NNlib.sigmoid), :logistic),
)
@eval function $(jlop)(x::TracedRArray{T,0}) where {T}
return TracedRArray{T,0}(
@eval function $(jlop)(x::TracedRNumber{T}) where {T}
return TracedRNumber{T}(
(),
# Emit the corresponding stablehlo op on the traced value's MLIR data and
# wrap its first result.
Reactant.MLIR.IR.result(
Reactant.MLIR.Dialects.stablehlo.$(hloop)(x.mlir_data), 1
),
(),
)
end
end

# For every NNlib activation not special-cased above, route the 0-d traced
# array through the generic `Any` method via `invoke`, so these scalars are
# handled by NNlib's generic fallback rather than a numeric method.
# NOTE(review): this block appears on the deletion side of the diff — it was
# removed when scalars became `TracedRNumber`.
for nnlib_op in setdiff(Tuple(NNlib.ACTIVATIONS), (:tanh_fast, :sigmoid_fast, :sigmoid, :σ))
@eval function NNlib.$(nnlib_op)(x::TracedRArray{T,0}) where {T}
return invoke(NNlib.$(nnlib_op), Tuple{Any}, x)
end
end

# TODO handle non finite cases
function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where {T,N}
max_ = NNlib.fast_maximum(x; dims)
Expand All @@ -39,6 +32,20 @@ function NNlib.softmax!(out::TracedRArray{T,N}, x::AbstractArray; dims=1) where
return out ./= tmp
end

# In-place log-softmax over `dims`:
#   out = (x - max) - log(sum(exp(x - max)))
# Subtracting the per-slice maximum before exponentiating keeps `exp` from
# overflowing. The disabled branch below would special-case non-finite maxima
# (see the TODO above about non-finite inputs); it is kept for a future fix.
function NNlib.logsoftmax!(out::TracedRArray{T}, x::AbstractArray; dims=1) where {T}
max_ = NNlib.fast_maximum(x; dims)
# if all(isfinite, max_)
@fastmath out .= x .- max_
# else
# _zero, _minf, _inf = T(0), T(-Inf), T(Inf)
# @. out = ifelse(
# isequal(max_, _inf), ifelse(isequal(x, _inf), _zero, _minf), x - max_
# )
# end
# Normalize by the log of the per-slice partition function, in place.
@fastmath log_ = log.(sum(exp, out; dims))
return out .-= log_
end

function NNlib.conv(
x::AnyTracedRArray{T,N}, W::AnyTracedRArray{T}, cdims::DenseConvDims
) where {T,N}
Expand Down
10 changes: 6 additions & 4 deletions src/Compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,12 @@ import ..Reactant:
XLA,
ConcreteRArray,
TracedRArray,
TracedRNumber,
OrderedIdDict,
make_tracer,
TracedToConcrete,
append_path
append_path,
TracedType

# Generic fallback for field access during result reconstruction: simply
# forwards to `Base.getfield`. `@nospecialize` avoids compiling a separate
# specialization per object type.
@inline function traced_getfield(@nospecialize(obj), field)
    return Base.getfield(obj, field)
end

Expand Down Expand Up @@ -286,10 +288,10 @@ function compile_mlir!(mod, f, args; optimize=true)
)
end

preserved_args = Tuple{TracedRArray,Int}[]
preserved_args = Tuple{TracedType,Int}[]
results = [MLIR.IR.operand(ret, i) for i in 1:MLIR.IR.noperands(ret)]
nresults = MLIR.IR.Value[]
linear_results2 = TracedRArray[]
linear_results2 = TracedType[]
for (i, op) in enumerate(results)
if !MLIR.IR.is_block_arg(op)
push!(nresults, op)
Expand Down Expand Up @@ -573,7 +575,7 @@ end
function compile_xla(f, args; client=nothing)
# register MLIR dialects
ctx = MLIR.IR.Context()
Base.append!(Reactant.registry[]; context=ctx)
append!(Reactant.registry[]; context=ctx)
@ccall MLIR.API.mlir_c.RegisterDialects(ctx::MLIR.API.MlirContext)::Cvoid

return MLIR.IR.context!(ctx) do
Expand Down
6 changes: 1 addition & 5 deletions src/ConcreteRArray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -74,10 +74,6 @@ function Base.convert(::Type{T}, x::ConcreteRArray{T,0}) where {T}
return to_float(x)
end

# A 0-d `RArray` promotes like its element type: delegate to the scalar
# promotion rule for `T` and the other type.
function Base.promote_rule(::Type{<:RArray{T,0}}, ::Type{S}) where {T,S}
    return Base.promote_rule(T, S)
end

for jlop in (:(Base.isless), :(Base.:+), :(Base.:-), :(Base.:*), :(Base.:/), :(Base.:^))
@eval begin
function $jlop(x::ConcreteRArray{T,0}, y::ConcreteRArray{U,0}) where {T,U}
Expand Down Expand Up @@ -158,7 +154,7 @@ function Base.getindex(a::ConcreteRArray{T}, args::Vararg{Int,N}) where {T,N}
end

# Store `v` at integer indices `args` in `a`, returning `nothing` so callers
# can discard the result.
# NOTE(review): this span is a rendered diff — both the old qualified
# `Base.setindex!` call and the new unqualified `setindex!` call appear;
# only one exists in the merged file.
function mysetindex!(a, v, args::Vararg{Int,N}) where {N}
Base.setindex!(a, v, args...)
setindex!(a, v, args...)
return nothing
end

Expand Down
45 changes: 44 additions & 1 deletion src/Reactant.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,45 @@ include("OrderedIdDict.jl")

using Enzyme

abstract type RArray{T,N} <: AbstractArray{T,N} end
# Scalar types Reactant can represent. `Core.BFloat16` joins the union only
# on Julia builds that define it, which is why the definition is split by a
# compile-time `@static` check.
@static if isdefined(Core, :BFloat16)
    const ReactantPrimitives = Union{
        Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,
        Float16,Core.BFloat16,Float32,Float64,
        Complex{Float32},Complex{Float64},
    }
else
    const ReactantPrimitives = Union{
        Bool,Int8,UInt8,Int16,UInt16,Int32,UInt32,Int64,UInt64,
        Float16,Float32,Float64,
        Complex{Float32},Complex{Float64},
    }
end

# Abstract supertypes for Reactant arrays and numbers. Element/value types are
# restricted to `ReactantPrimitives`, the scalars Reactant can represent.
abstract type RArray{T<:ReactantPrimitives,N} <: AbstractArray{T,N} end
abstract type RNumber{T<:ReactantPrimitives} <: Number end

function Base.reshape(A::RArray, dims::Tuple{Vararg{Union{Int,Colon}}})
return reshape(A, Base._reshape_uncolon(A, dims))
Expand Down Expand Up @@ -45,8 +83,13 @@ include("mlir/MLIR.jl")
include("XLA.jl")
include("Interpreter.jl")
include("utils.jl")

include("ConcreteRArray.jl")
include("TracedRNumber.jl")
include("TracedRArray.jl")

# Convenience union covering both traced value kinds (arrays and numbers).
const TracedType = Union{TracedRArray,TracedRNumber}

include("Tracing.jl")
include("Compiler.jl")

Expand Down
Loading
Loading