Add AbstractInterpreter to parameterize compilation pipeline #33955

Merged 2 commits on May 10, 2020

Commits on May 10, 2020

  1. Add AbstractInterpreter to parameterize compilation pipeline

    This allows selective overriding of the compilation pipeline through
    multiple dispatch, enabling projects like `XLA.jl` to maintain separate
    inference caches, inference algorithms or heuristic algorithms while
    inferring and lowering code.  In particular, it defines a new type,
    `AbstractInterpreter`, that represents an abstract interpretation
    pipeline.  This `AbstractInterpreter` has a single defined concrete
    subtype, `NativeInterpreter`, that represents the native Julia
    compilation pipeline.  The `NativeInterpreter` contains within it all
    the compiler parameters previously contained within `Params`, split into
    two pieces: `InferenceParams` and `OptimizationParams`, used within type
    inference and optimization, respectively.  The interpreter object is
    then threaded throughout most of the type inference pipeline, and allows
    for straightforward prototyping and replacement of the compiler
    internals.
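    
    As a hedged sketch of that split (names follow this commit; `NativeInterpreter`, `InferenceParams`, and `OptimizationParams` live in `Core.Compiler` and are not exported), constructing the native interpreter and querying its two parameter structs looks roughly like:
    
    ```julia
    # Construct the default interpreter for the native Julia compilation pipeline.
    interp = Core.Compiler.NativeInterpreter()
    
    # The old monolithic `Params` is now split into two pieces, each queried
    # through the interpreter object rather than passed around separately:
    inf_params = Core.Compiler.InferenceParams(interp)    # used during type inference
    opt_params = Core.Compiler.OptimizationParams(interp) # used during optimization
    ```
    
    Any custom `AbstractInterpreter` subtype is expected to answer these same queries, which is what lets the rest of the pipeline stay generic over the interpreter type.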
    
    As a simple example of the kind of workflow this enables, I include here
    a short testing script showing how to count the number of times functions
    are inferred during type inference by overriding just two functions within
    the compiler.  First, I will define some helper methods to make working
    with inference a bit easier:
    
    ```julia
    using Core.Compiler
    import Core.Compiler: InferenceParams, OptimizationParams, get_world_counter, get_inference_cache
    
    """
        @infer_function interp foo(1, 2) [show_steps=true] [show_ir=false]
    
    Infer a function call using the given interpreter object and return
    the resulting inference object.  Keyword arguments modify verbosity:
    
    * Set `show_steps` to `true` to see the `InferenceResult` step by step.
    * Set `show_ir` to `true` to see the final type-inferred Julia IR.
    """
    macro infer_function(interp, func_call, kwarg_exs...)
        if !isa(func_call, Expr) || func_call.head != :call
            error("@infer_function requires a function call")
        end
    
        local func = func_call.args[1]
        local args = func_call.args[2:end]
        kwargs = []
        for ex in kwarg_exs
            if ex isa Expr && ex.head === :(=) && ex.args[1] isa Symbol
                push!(kwargs, first(ex.args) => last(ex.args))
            else
                error("Invalid @infer_function kwarg $(ex)")
            end
        end
        return quote
            infer_function($(esc(interp)), $(esc(func)), typeof.(($(args)...,)); $(esc(kwargs))...)
        end
    end
    
    function infer_function(interp, f, tt; show_steps::Bool=false, show_ir::Bool=false)
        # Find all methods that are applicable to these types
        fms = methods(f, tt)
        if length(fms) != 1
            error("Unable to find single applicable method for $f with types $tt")
        end
    
        # Take the first applicable method
        method = first(fms)
    
        # Build argument tuple
        method_args = Tuple{typeof(f), tt...}
    
        # Grab the appropriate method instance for these types
        mi = Core.Compiler.specialize_method(method, method_args, Core.svec())
    
        # Construct an InferenceResult to hold the inference result
        result = Core.Compiler.InferenceResult(mi)
        if show_steps
            @info("Initial result, before inference: ", result)
        end
    
        # Create an InferenceState to begin inference; the interpreter object
        # itself carries the world age in which inference will run.
        frame = Core.Compiler.InferenceState(result, #=cached=# true, interp)
    
        # Run type inference on this frame, passing the interpreter in explicitly.
        Core.Compiler.typeinf_local(interp, frame)
        if show_steps
            @info("Ending result, post-inference: ", result)
        end
        if show_ir
            @info("Inferred source: ", result.result.src)
        end
    
        # Give the result back
        return result
    end
    ```
    
    Next, we define a simple function and pass it through:
    ```julia
    function foo(x, y)
        return x + y * x
    end
    
    native_interpreter = Core.Compiler.NativeInterpreter()
    inferred = @infer_function native_interpreter foo(1.0, 2.0) show_steps=true show_ir=true
    ```
    
    This gives a nice output such as the following:
    ```julia-repl
    ┌ Info: Initial result, before inference:
    └   result = foo(::Float64, ::Float64) => Any
    ┌ Info: Ending result, post-inference:
    └   result = foo(::Float64, ::Float64) => Float64
    ┌ Info: Inferred source:
    │   result.result.src =
    │    CodeInfo(
    │        @ REPL[1]:3 within `foo'
    │    1 ─ %1 = (y * x)::Float64
    │    │   %2 = (x + %1)::Float64
    │    └──      return %2
    └    )
    ```
    
    We can then define a custom `AbstractInterpreter` subtype that overrides
    two specific pieces of the compilation process, both concerned with
    managing the runtime inference cache.  While it transparently passes all
    other information through to a bundled `NativeInterpreter`, it can force
    cache misses in order to re-infer things, so that we can easily see how
    many methods (and which ones) would be inferred to compile a certain
    method:
    
    ```julia
    struct CountingInterpreter <: Compiler.AbstractInterpreter
        visited_methods::Set{Core.Compiler.MethodInstance}
        methods_inferred::Ref{UInt64}
    
        # Keep around a native interpreter so that we can sub off to "super" functions
        native_interpreter::Core.Compiler.NativeInterpreter
    end
    CountingInterpreter() = CountingInterpreter(
        Set{Core.Compiler.MethodInstance}(),
        Ref(UInt64(0)),
        Core.Compiler.NativeInterpreter(),
    )
    
    InferenceParams(ci::CountingInterpreter) = InferenceParams(ci.native_interpreter)
    OptimizationParams(ci::CountingInterpreter) = OptimizationParams(ci.native_interpreter)
    get_world_counter(ci::CountingInterpreter) = get_world_counter(ci.native_interpreter)
    get_inference_cache(ci::CountingInterpreter) = get_inference_cache(ci.native_interpreter)
    
    function Core.Compiler.inf_for_methodinstance(interp::CountingInterpreter, mi::Core.Compiler.MethodInstance, min_world::UInt, max_world::UInt=min_world)
        # Hit our own cache; if it exists, pass on to the main runtime
        if mi in interp.visited_methods
            return Core.Compiler.inf_for_methodinstance(interp.native_interpreter, mi, min_world, max_world)
        end
    
        # Otherwise, we return `nothing`, forcing a cache miss
        return nothing
    end
    
    function Core.Compiler.cache_result(interp::CountingInterpreter, result::Core.Compiler.InferenceResult, min_valid::UInt, max_valid::UInt)
        push!(interp.visited_methods, result.linfo)
        interp.methods_inferred[] += 1
        return Core.Compiler.cache_result(interp.native_interpreter, result, min_valid, max_valid)
    end
    
    function reset!(interp::CountingInterpreter)
        empty!(interp.visited_methods)
        interp.methods_inferred[] = 0
        return nothing
    end
    ```
    
    Running it on our testing function:
    ```julia
    counting_interpreter = CountingInterpreter()
    inferred = @infer_function counting_interpreter foo(1.0, 2.0)
    @info("Cumulative number of methods inferred: $(counting_interpreter.methods_inferred[])")
    inferred = @infer_function counting_interpreter foo(1, 2) show_ir=true
    @info("Cumulative number of methods inferred: $(counting_interpreter.methods_inferred[])")
    
    inferred = @infer_function counting_interpreter foo(1.0, 2.0)
    @info("Cumulative number of methods inferred: $(counting_interpreter.methods_inferred[])")
    reset!(counting_interpreter)
    
    @info("Cumulative number of methods inferred: $(counting_interpreter.methods_inferred[])")
    inferred = @infer_function counting_interpreter foo(1.0, 2.0)
    @info("Cumulative number of methods inferred: $(counting_interpreter.methods_inferred[])")
    ```
    
    This also gives us a nice result:
    ```
    [ Info: Cumulative number of methods inferred: 2
    ┌ Info: Inferred source:
    │   result.result.src =
    │    CodeInfo(
    │        @ /Users/sabae/src/julia-compilerhack/AbstractInterpreterTest.jl:81 within `foo'
    │    1 ─ %1 = (y * x)::Int64
    │    │   %2 = (x + %1)::Int64
    │    └──      return %2
    └    )
    [ Info: Cumulative number of methods inferred: 4
    [ Info: Cumulative number of methods inferred: 4
    [ Info: Cumulative number of methods inferred: 0
    [ Info: Cumulative number of methods inferred: 2
    ```
    staticfloat authored and Keno committed May 10, 2020 (commit c5fcd73)
  2. Rename one typeinf_ext method to typeinf_ext_toplevel

    This disambiguates the two methods, allowing us to eliminate the
    redundant `world::UInt` parameter.
    staticfloat authored and Keno committed May 10, 2020 (commit 10b572c)