Skip to content

Commit

Permalink
Add a nondet inv op (#6)
Browse files Browse the repository at this point in the history
  • Loading branch information
tzerrell authored Aug 20, 2024
1 parent 39fb52b commit 8a29839
Show file tree
Hide file tree
Showing 12 changed files with 159 additions and 3 deletions.
32 changes: 32 additions & 0 deletions zirgen/Dialect/BigInt/IR/Eval.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,32 @@ BytePoly nondetRem(const BytePoly& lhs, const BytePoly& rhs, size_t coeffs) {
return fromAPInt(rem, coeffs);
}

BytePoly nondetInvMod(const BytePoly& lhs, const BytePoly& rhs, size_t coeffs) {
// Uses the formula n^(p-2) * n = 1 (mod p) to invert `lhs` (mod `rhs`)
// (via the square and multiply technique)
auto lhsInt = toAPInt(lhs);
auto rhsInt = toAPInt(rhs);
size_t maxSize = rhsInt.getBitWidth();
APInt inv(2 * maxSize,
1); // Initialize inverse to zero, twice the width of `prime` to allow multiplication
APInt sqr(lhsInt); // Will be repeatedly squared
APInt position(2 * maxSize, 1); // Bit at `idx` will be 1, other bits will be 0
sqr = sqr.zext(2 * maxSize);
rhsInt = rhsInt.zext(2 * maxSize);
APInt exp = rhsInt - 2;
for (size_t idx = 0; idx < maxSize; idx++) {
if (exp.intersects(position)) {
// multiply in the current power of n (i.e., n^(2^idx))
inv = (inv * sqr).urem(rhsInt);
}
position <<= 1; // increment the bit position to test in `exp`
sqr = (sqr * sqr).urem(rhsInt); // square `sqr` to increment to `n^(2^(idx+1))`
}
inv = inv.trunc(maxSize); // We don't need the extra space used as multiply buffer
LLVM_DEBUG({ dbgs() << "inv (mod " << rhsInt << "): " << inv << "\n"; });
return fromAPInt(inv, coeffs);
}

void printEval(const std::string& message, BytePoly poly) {
risc0::FpExt tot(0);
risc0::FpExt mul(1);
Expand Down Expand Up @@ -190,6 +216,12 @@ EvalOutput eval(func::FuncOp inFunc, ArrayRef<APInt> witnessValues) {
polys[op.getOut()] = poly;
ret.privateWitness.push_back(poly);
})
.Case<NondetInvModOp>([&](auto op) {
uint32_t coeffs = op.getOut().getType().getCoeffs();
auto poly = nondetInvMod(polys[op.getLhs()], polys[op.getRhs()], coeffs);
polys[op.getOut()] = poly;
ret.privateWitness.push_back(poly);
})
.Case<EqualZeroOp>([&](auto op) {
auto poly = polys[op.getIn()];
if (toAPInt(poly) != 0) {
Expand Down
34 changes: 34 additions & 0 deletions zirgen/Dialect/BigInt/IR/Ops.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,34 @@ LogicalResult NondetQuotOp::inferReturnTypes(MLIRContext* ctx,
return success();
}

LogicalResult NondetInvModOp::inferReturnTypes(MLIRContext* ctx,
std::optional<Location> loc,
Adaptor adaptor,
SmallVectorImpl<Type>& out) {
auto rhsType = adaptor.getRhs().getType().cast<BigIntType>();
size_t coeffsWidth = ceilDiv(rhsType.getMaxBits(), kBitsPerCoeff);
out.push_back(BigIntType::get(ctx,
/*coeffs=*/coeffsWidth,
/*maxPos=*/(1 << kBitsPerCoeff) - 1,
/*maxNeg=*/0,
/*minBits=*/0));
return success();
}

LogicalResult ModularInvOp::inferReturnTypes(MLIRContext* ctx,
std::optional<Location> loc,
Adaptor adaptor,
SmallVectorImpl<Type>& out) {
auto rhsType = adaptor.getRhs().getType().cast<BigIntType>();
size_t coeffsWidth = ceilDiv(rhsType.getMaxBits(), kBitsPerCoeff);
out.push_back(BigIntType::get(ctx,
/*coeffs=*/coeffsWidth,
/*maxPos=*/(1 << kBitsPerCoeff) - 1,
/*maxNeg=*/0,
/*minBits=*/0));
return success();
}

LogicalResult ReduceOp::inferReturnTypes(MLIRContext* ctx,
std::optional<Location> loc,
Adaptor adaptor,
Expand Down Expand Up @@ -187,6 +215,12 @@ void NondetQuotOp::emitExpr(codegen::CodegenEmitter& cg) {
{getLhs(), getRhs(), toConstantValue(cg, getContext(), getType().getCoeffs())});
}

void NondetInvModOp::emitExpr(codegen::CodegenEmitter& cg) {
cg.emitFuncCall(cg.getStringAttr("nondet_inv"),
/*contextArgs=*/{"ctx"},
{getLhs(), getRhs(), toConstantValue(cg, getContext(), getType().getCoeffs())});
}

void ConstOp::emitExpr(codegen::CodegenEmitter& cg) {
auto bytePoly = fromAPInt(getValue(), getType().getCoeffs());
SmallVector<codegen::EmitPart> macroArgs;
Expand Down
2 changes: 2 additions & 0 deletions zirgen/Dialect/BigInt/IR/Ops.td
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def SubOp : BinaryOp<"sub", [Pure, ]> {}
def MulOp : BinaryOp<"mul", [Pure, Commutative]> {}
def NondetRemOp : BinaryOp<"nondet_rem", [DeclareOpInterfaceMethods<CodegenExprOpInterface>]> {}
def NondetQuotOp : BinaryOp<"nondet_quot", [DeclareOpInterfaceMethods<CodegenExprOpInterface>]> {}
def NondetInvModOp : BinaryOp<"nondet_invmod", [DeclareOpInterfaceMethods<CodegenExprOpInterface>]> {}
def ModularInvOp : BinaryOp<"inv", []> {}
def ReduceOp : BinaryOp<"reduce", []> {}

def EqualZeroOp : BigIntOp<"eqz", [DeclareOpInterfaceMethods<CodegenExprOpInterface>]> {
Expand Down
1 change: 1 addition & 0 deletions zirgen/Dialect/BigInt/Transforms/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ gentbl_cc_library(
cc_library(
name = "Transforms",
srcs = [
"LowerModularInv.cpp",
"LowerReduce.cpp",
"LowerZll.cpp",
"PassDetail.h",
Expand Down
55 changes: 55 additions & 0 deletions zirgen/Dialect/BigInt/Transforms/LowerModularInv.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
// Copyright (c) 2024 RISC Zero, Inc.
//
// All rights reserved.

#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#include "zirgen/Dialect/BigInt/IR/BigInt.h"
#include "zirgen/Dialect/BigInt/Transforms/PassDetail.h"
#include "zirgen/Dialect/BigInt/Transforms/Passes.h"

using namespace mlir;

namespace zirgen::BigInt {

namespace {

struct ReplaceModularInv : public OpRewritePattern<ModularInvOp> {
using OpRewritePattern::OpRewritePattern;
LogicalResult matchAndRewrite(ModularInvOp op, PatternRewriter& rewriter) const override {
// Construct the constant 1
mlir::Type oneType = rewriter.getIntegerType(1); // a `1` is bitwidth 1
auto oneAttr = rewriter.getIntegerAttr(oneType, 1); // value 1
auto one = rewriter.create<ConstOp>(op.getLoc(), oneAttr);

auto inv = rewriter.create<NondetInvModOp>(op.getLoc(), op.getLhs(), op.getRhs());
auto remult = rewriter.create<MulOp>(op.getLoc(), op.getLhs(), inv);
auto reduced = rewriter.create<ReduceOp>(op.getLoc(), remult, op.getRhs());
auto diff = rewriter.create<SubOp>(op.getLoc(), reduced, one);
rewriter.create<EqualZeroOp>(op.getLoc(), diff);
rewriter.replaceOp(op, inv);
return success();
}
};

struct LowerModularInvPass : public LowerModularInvBase<LowerModularInvPass> {
void runOnOperation() override {
auto ctx = &getContext();
RewritePatternSet patterns(ctx);
patterns.insert<ReplaceModularInv>(ctx);
if (applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)).failed()) {
return signalPassFailure();
}
}
};

} // End namespace

std::unique_ptr<OperationPass<mlir::ModuleOp>> createLowerModularInvPass() {
return std::make_unique<LowerModularInvPass>();
}

} // namespace zirgen::BigInt
4 changes: 2 additions & 2 deletions zirgen/Dialect/BigInt/Transforms/LowerZll.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ void lower(func::FuncOp inFunc) {
}
})
.Case<ConstOp>([&](auto op) { countConst += op.getOut().getType().getNormalWitnessSize(); })
.Case<NondetRemOp, NondetQuotOp>(
.Case<NondetRemOp, NondetQuotOp, NondetInvModOp>(
[&](auto op) { countPrivate += op.getOut().getType().getNormalWitnessSize(); })
.Case<EqualZeroOp>(
[&](auto op) { countPrivate += op.getIn().getType().getCarryWitnessSize(); });
Expand Down Expand Up @@ -158,7 +158,7 @@ void lower(func::FuncOp inFunc) {
valMap[op.getOut()] =
builder.create<Zll::MulOp>(loc, valMap[op.getLhs()], valMap[op.getRhs()]);
})
.Case<NondetRemOp, NondetQuotOp>([&](auto op) {
.Case<NondetRemOp, NondetQuotOp, NondetInvModOp>([&](auto op) {
valMap[op.getOut()] =
extractPoly(cbPrivate.getEvaluations(), curPrivate, op.getOut().getType());
})
Expand Down
1 change: 1 addition & 0 deletions zirgen/Dialect/BigInt/Transforms/Passes.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
namespace zirgen::BigInt {

// Pass constructors
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createLowerModularInvPass();
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createLowerReducePass();
std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>> createLowerZllPass();

Expand Down
5 changes: 5 additions & 0 deletions zirgen/Dialect/BigInt/Transforms/Passes.td
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,11 @@
include "mlir/Pass/PassBase.td"
include "mlir/Rewrite/PassUtil.td"

def LowerModularInv : Pass<"lower-modular-inv", "mlir::ModuleOp"> {
let summary = "Remove BigInt::ModularInvOp by lowering it to other ops";
let constructor = "zirgen::BigInt::createLowerModularInvPass()";
}

def LowerReduce : Pass<"lower-reduce", "mlir::ModuleOp"> {
let summary = "Remove BigInt::ReduceOp by lowering it to other ops";
let constructor = "zirgen::BigInt::createLowerReducePass()";
Expand Down
1 change: 1 addition & 0 deletions zirgen/circuit/bigint/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ ZKRS = [
"mul_test_128",
"reduce_test_8",
"reduce_test_128",
"nondet_inv_test_8",
]

build_circuit(
Expand Down
6 changes: 6 additions & 0 deletions zirgen/circuit/bigint/gen_bigint.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,12 @@ int main(int argc, char* argv[]) {
BigInt::setIterationCount(funcOp, rsa.iters);
}
// TODO: More bitwidth coverage?
for (size_t numBits : {8}) {
module.addFunc<0>("nondet_inv_test_" + std::to_string(numBits), {}, [&]() {
auto& builder = Module::getCurModule()->getBuilder();
zirgen::BigInt::makeNondetInvTest(builder, builder.getUnknownLoc(), numBits);
});
}
for (size_t numBits : {8}) {
module.addFunc<0>("const_add_test_" + std::to_string(numBits), {}, [&]() {
auto& builder = Module::getCurModule()->getBuilder();
Expand Down
19 changes: 19 additions & 0 deletions zirgen/circuit/bigint/op_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,4 +128,23 @@ void makeReduceTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits) {
builder.create<BigInt::EqualZeroOp>(loc, diff);
}

void makeNondetInvTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits) {
auto inp = builder.create<BigInt::DefOp>(loc, bits, 0, true);
auto prime = builder.create<BigInt::DefOp>(loc, bits, 1, true, bits - 1);
auto expected = builder.create<BigInt::DefOp>(loc, bits, 2, true);

// Construct constants
mlir::Type oneType = builder.getIntegerType(1); // a `1` is bitwidth 1
auto oneAttr = builder.getIntegerAttr(oneType, 1); // value 1
auto one = builder.create<BigInt::ConstOp>(loc, oneAttr);

auto inv = builder.create<BigInt::NondetInvModOp>(loc, inp, prime);
auto prod = builder.create<BigInt::MulOp>(loc, inp, inv);
auto reduced = builder.create<BigInt::ReduceOp>(loc, prod, prime);
auto expect_zero = builder.create<BigInt::SubOp>(loc, reduced, one);
builder.create<BigInt::EqualZeroOp>(loc, expect_zero);
auto result_match = builder.create<BigInt::SubOp>(loc, inv, expected);
builder.create<BigInt::EqualZeroOp>(loc, result_match);
}

} // namespace zirgen::BigInt
2 changes: 1 addition & 1 deletion zirgen/circuit/bigint/op_tests.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ using namespace mlir;

namespace zirgen::BigInt {

void makeIsOddTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits);
void makeConstAddTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits);
void makeConstAddAltTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits);
void makeConstMulTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits);
Expand All @@ -21,5 +20,6 @@ void makeConstTwoByteTest(mlir::OpBuilder builder, mlir::Location loc, size_t bi
void makeSubTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits);
void makeMulTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits);
void makeReduceTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits);
void makeNondetInvTest(mlir::OpBuilder builder, mlir::Location loc, size_t bits);

} // namespace zirgen::BigInt

0 comments on commit 8a29839

Please sign in to comment.