From d60f9b3b47ee585cc1d8a836bb0d7acab81a9b6e Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Sat, 8 Jul 2023 02:52:22 -0400 Subject: [PATCH] Refactor and pass correct interpreter to typeinf finish loop (#50469) When we have an inference loop with different interpreters, the current code was trying to cache everything with the top level interpreter of the loop, yielding some unexpected behavior. I don't think that it's necessarily super well defined what should happen here, as it depends on the interpreters, in question, but I think it's better to try to cache each frame with the interpreter that created it, since they may have different lattices, etc. Doing this fixes an error I saw downstream that had just such a situation. --------- Co-authored-by: Shuhei Kadowaki <40514306+aviatesk@users.noreply.github.com> --- base/compiler/typeinfer.jl | 31 +++++++++++++------------------ 1 file changed, 13 insertions(+), 18 deletions(-) diff --git a/base/compiler/typeinfer.jl b/base/compiler/typeinfer.jl index 77e1fd02de8d0..dfeaba18321d1 100644 --- a/base/compiler/typeinfer.jl +++ b/base/compiler/typeinfer.jl @@ -257,33 +257,28 @@ function _typeinf(interp::AbstractInterpreter, frame::InferenceState) end for caller in frames caller.valid_worlds = valid_worlds - finish(caller, interp) + finish(caller, caller.interp) end - # collect results for the new expanded frame - results = Tuple{InferenceResult, Vector{Any}, Bool}[ - ( frames[i].result, - frames[i].stmt_edges[1]::Vector{Any}, - frames[i].cached ) - for i in 1:length(frames) ] - empty!(frames) - for (caller, _, _) in results - opt = caller.src - if opt isa OptimizationState{typeof(interp)} # implies `may_optimize(interp) === true` - optimize(interp, opt, caller) + for caller in frames + opt = caller.result.src + if opt isa OptimizationState # implies `may_optimize(caller.interp) === true` + optimize(caller.interp, opt, caller.result) end end - for (caller, edges, cached) in results - valid_worlds = caller.valid_worlds + for caller in frames + (; result ) = caller + valid_worlds = result.valid_worlds if last(valid_worlds) >= get_world_counter() # if we aren't cached, we don't need this edge # but our caller might, so let's just make it anyways - store_backedges(caller, edges) + store_backedges(result, caller.stmt_edges[1]) end - if cached - cache_result!(interp, caller) + if caller.cached + cache_result!(caller.interp, result) end - finish!(interp, caller) + finish!(caller.interp, result) end + empty!(frames) return true end