Skip to content

Commit

Permalink
Refactor and pass correct interpreter to typeinf finish loop (JuliaLa…
Browse files Browse the repository at this point in the history
…ng#50469)

When we have an inference loop with different interpreters,
the current code was trying to cache everything with the top
level interpreter of the loop, yielding some unexpected behavior.
I don't think that it's necessarily super well defined what should
happen here, as it depends on the interpreters, in question, but
I think it's better to try to cache each frame with the interpreter
that created it, since they may have different lattices, etc.
Doing this fixes an error I saw downstream that had just such
a situation.

---------

Co-authored-by: Shuhei Kadowaki <[email protected]>
  • Loading branch information
Keno and aviatesk authored Jul 8, 2023
1 parent e20274f commit d60f9b3
Showing 1 changed file with 13 additions and 18 deletions.
31 changes: 13 additions & 18 deletions base/compiler/typeinfer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -257,33 +257,28 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState)
end
for caller in frames
caller.valid_worlds = valid_worlds
finish(caller, interp)
finish(caller, caller.interp)
end
# collect results for the new expanded frame
results = Tuple{InferenceResult, Vector{Any}, Bool}[
( frames[i].result,
frames[i].stmt_edges[1]::Vector{Any},
frames[i].cached )
for i in 1:length(frames) ]
empty!(frames)
for (caller, _, _) in results
opt = caller.src
if opt isa OptimizationState{typeof(interp)} # implies `may_optimize(interp) === true`
optimize(interp, opt, caller)
for caller in frames
opt = caller.result.src
if opt isa OptimizationState # implies `may_optimize(caller.interp) === true`
optimize(caller.interp, opt, caller.result)
end
end
for (caller, edges, cached) in results
valid_worlds = caller.valid_worlds
for caller in frames
(; result ) = caller
valid_worlds = result.valid_worlds
if last(valid_worlds) >= get_world_counter()
# if we aren't cached, we don't need this edge
# but our caller might, so let's just make it anyways
store_backedges(caller, edges)
store_backedges(result, caller.stmt_edges[1])
end
if cached
cache_result!(interp, caller)
if caller.cached
cache_result!(caller.interp, result)
end
finish!(interp, caller)
finish!(caller.interp, result)
end
empty!(frames)
return true
end

Expand Down

0 comments on commit d60f9b3

Please sign in to comment.