
mlir: Func call reverse diff #2127

Merged
merged 7 commits into EnzymeAD:main on Oct 23, 2024
Conversation

Pangoraw (Contributor) commented on Oct 21, 2024:

Prompted by EnzymeAD/Reactant.jl#171 (comment).

We can now run enzyme-batch -> enzyme-wrap without inlining.

// ./bazel-bin/enzymexlamlir-opt test/lit_tests/batchtests/autodiff.mlir --enzyme-batch --enzyme-wrap="infn=main retTys=enzyme_active argTys=enzyme_active mode=ReverseModeCombined" --canonicalize --remove-unnecessary-enzyme-ops --arith-raise --cse --canonicalize --arith-raise --enzyme-simplify-math --mlir-print-ir-before=enzyme-batch
// -----// IR Dump Before BatchPass (enzyme-batch) //----- //
module {
  func.func @f(%arg0: tensor<f32>) -> tensor<f32> {
    %0 = stablehlo.multiply %arg0, %arg0 : tensor<f32>
    return %0 : tensor<f32>
  }
  func.func @main(%arg0: tensor<10xf32>) -> tensor<10xf32> {
    %0 = enzyme.batch @f(%arg0) {batch_shape = array<i64: 10>} : (tensor<10xf32>) -> tensor<10xf32>
    return %0 : tensor<10xf32>
  }
}

Running that pipeline produces the following output:

module {
  func.func @f(%arg0: tensor<f32>) -> tensor<f32> {
    %0 = stablehlo.multiply %arg0, %arg0 : tensor<f32>
    return %0 : tensor<f32>
  }
  func.func private @batched_f(%arg0: tensor<10xf32>) -> tensor<10xf32> {
    %0 = stablehlo.multiply %arg0, %arg0 : tensor<10xf32>
    return %0 : tensor<10xf32>
  }
  func.func @main(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> {
    %cst = arith.constant dense<0.000000e+00> : tensor<10xf32>
    %0 = call @batched_f(%arg0) : (tensor<10xf32>) -> tensor<10xf32>
    %1 = stablehlo.add %arg1, %cst : tensor<10xf32>
    %2 = call @diffebatched_f(%arg0, %1) : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32>
    %3 = stablehlo.add %2, %cst : tensor<10xf32>
    return %3 : tensor<10xf32>
  }
  func.func private @diffebatched_f(%arg0: tensor<10xf32>, %arg1: tensor<10xf32>) -> tensor<10xf32> {
    %cst = arith.constant dense<0.000000e+00> : tensor<10xf32>
    %0 = stablehlo.add %arg1, %cst : tensor<10xf32>
    %1 = stablehlo.multiply %0, %arg0 : tensor<10xf32>
    %2 = stablehlo.add %1, %cst : tensor<10xf32>
    %3 = stablehlo.add %2, %1 : tensor<10xf32>
    return %3 : tensor<10xf32>
  }
}

Note that @batched_f is differentiated into a separate @diffebatched_f, which the rewritten @main calls directly; no inlining is required.

std::vector<bool> volatile_args(narg, true);  // arguments may change between forward and reverse, so cache them
std::vector<bool> returnShadow(narg, false);  // do not return shadow values
std::vector<bool> returnPrimal(nret, false);  // do not return the primal results

Member:

We should probably default to always returning the primal, no?

Contributor (author):

I don't really understand how the primal would be used here.

Member:

Yeah, we would only need the primal for fused forward-and-reverse passes or for handling mutation; we can defer that for later here.

bool freeMemory = true; // generated code frees its own tape allocations
size_t width = 1;       // scalar mode, no vectorized derivatives

auto revFn = gutils->Logic.CreateReverseDiff(

Member:

Long term, we probably need to do the same augmented forward pass and separate reverse pass that Enzyme's LLVM mode does at the moment.

If we assume everything in the function is read-only and we only return active (not duplicated) results, this is fine.

Contributor (author):

Ah right, the caching at the primal call site is not enough if the arguments are mutable. I could not construct an example input with memrefs (something about invertPointerM), so I added an error if any result or argument of the initial call is mutable.
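
A minimal sketch of what such a guard could look like (the helper name and the memref-based mutability test are assumptions here, not the PR's exact code):

#include "mlir/Dialect/Func/IR/FuncOps.h"
#include "mlir/IR/BuiltinTypes.h"
#include "llvm/ADT/STLExtras.h"
using namespace mlir;

// Reject calls whose operands or results have mutable (memref-like) types,
// since the combined reverse pass here assumes value semantics.
static LogicalResult checkNoMutableOperandsOrResults(func::CallOp callOp) {
  auto isMutable = [](Type t) { return isa<BaseMemRefType>(t); };
  if (llvm::any_of(callOp->getOperandTypes(), isMutable) ||
      llvm::any_of(callOp->getResultTypes(), isMutable))
    return callOp.emitError("combined reverse mode does not yet support "
                            "calls with mutable (memref) arguments or results");
  return success();
}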

wsmoses (Member) left a comment:

LGTM, but with some misc fix-ups; also, the current implementation should throw errors when its assumptions are not met (we have an isReadOnly, for example).

@@ -29,9 +29,139 @@ namespace {
#include "Implementations/FuncDerivatives.inc"
} // namespace

static std::optional<func::FuncOp> getContainingFunction(Operation *orig) {
  Operation *parent;

Member:

Can this be a function interface instead?
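
For reference, a hedged sketch of that suggestion, assuming the helper only needs the nearest enclosing function (the call-site name comparison would then go through the symbol interface):

#include <optional>
#include "mlir/Interfaces/FunctionInterfaces.h"
using namespace mlir;

// Walk up to the nearest enclosing function without hard-coding func::FuncOp.
static std::optional<FunctionOpInterface> getContainingFunction(Operation *orig) {
  if (auto fn = orig->getParentOfType<FunctionOpInterface>())
    return fn;
  return std::nullopt;
}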


auto parent = getContainingFunction(orig);
if (parent.has_value() &&
    callOp.getCallee() == parent.value().getNameAttr()) {

Member:

Ah, I see this use; it is super fragile. For example, it won't work if a calls b and b calls a.

I think this is okay to merge as-is, but adding proper recursion support shouldn't be bad. We already have it for forward-mode Enzyme-MLIR, which basically just requires a cache in EnzymeLogic recording whether we've already made this derivative function.
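
A hypothetical sketch of such a cache (all names here are illustrative, not Enzyme's actual API): memoize generated derivatives keyed on the function and its activity signature, and register the new derivative before differentiating its body, so a recursive a -> b -> a chain finds the stub and emits a call to it instead of recursing into derivative generation forever.

#include <map>
#include <string>
#include <tuple>
#include <vector>
#include "mlir/Interfaces/FunctionInterfaces.h"

// Illustrative cache key: the callee plus a simplified stand-in for its
// per-argument activity information.
struct ReverseCacheKey {
  std::string fnName;
  std::vector<bool> argActivity;
  bool operator<(const ReverseCacheKey &o) const {
    return std::tie(fnName, argActivity) < std::tie(o.fnName, o.argActivity);
  }
};

// Lookup happens before generating a derivative; insertion happens before
// differentiating the body, so recursive requests hit the cache.
static std::map<ReverseCacheKey, mlir::FunctionOpInterface> cachedReverseFuncs;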

Pangoraw (Contributor, author) commented on Oct 23, 2024:

Yeah, any cycle in the call graph will not be caught here.

Member:

Yeah, EnzymeLogic should already do this for forward mode (and LLVM does, too).

Contributor (author):

I will look at implementing something like ForwardCachedFunctions for reverse in a follow-up.

wsmoses merged commit 11e9b9d into EnzymeAD:main on Oct 23, 2024.
12 of 27 checks passed
jedbrown added a commit that referenced this pull request Nov 1, 2024
* main: (49 commits)
  Fix iv of constant (#2141)
  Update benchmarks (#2035)
  Implement tgamma derivative (#2140)
  tgamma error improvement (#2139)
  Improve cache index error message (#2138)
  Fixes warnings and adds missing header guards (#2124)
  mlir: cache and reuse reverse funcs (#2133)
  mlir: implement forward mode for func.call (#2134)
  mlir: Func call reverse diff (#2127)
  Update build_tarballs.jl
  Fix combined temp cache for reverse (#2131)
  Improve runtime activity err message (#2132)
  Fix undef value storage (#2129)
  Adapt to const tblgen (#2128)
  Add gcloaded TT (#2125)
  Fix blas decl updater indexing (#2123)
  Add header files to ClangEnzyme target (#2062)
  Improve unknown function error messages (#2120)
  Fix handle sync (#2122)
  Support more Julia 1.11 functions (#2121)
  ...