Skip to content

Commit

Permalink
Add int4 quantize kernel
Browse files Browse the repository at this point in the history
add

int4_1

int4_2

FLAGS_logging_pir_py_code (PaddlePaddle#63981)

* FLAGS_logging_pir_py_code

* FLAGS_logging_pir_py_code_dir

---------

Co-authored-by: jiahy0825 <[email protected]>

[Cleanup] Remove Flake8 config in `.editorconfig` (PaddlePaddle#64027)

【PIR Dist Op Reg No.19】 reg pull_box_sparse (PaddlePaddle#62982)

* fix

* fix

* fix

* fix

* fix

* fix

* add test

* add

* fix

* fix

* add out

* fix

* codestyle

* fix

* fix backward

* merge

[Dy2St][PIR] Hold backward program in GradNode (PaddlePaddle#63694)

Co-authored-by: xiongkun <[email protected]>
Co-authored-by: Nyakku Shigure <[email protected]>

split test.cmake: add new test_cases.cmake (PaddlePaddle#64007)

[PIR] Support sparse_slice and sparse_sum in pt (PaddlePaddle#64009)

* support sparse_slice and sparse_sum in pt

* support sparse_slice and sparse_sum in pt

* support sparse_slice and sparse_sum in pt

option for WITH_CPP_TEST (PaddlePaddle#63896)

* option for WITH_CPP_TEST

* fix

* Fix

* Fix

[PIR] Fix `attributes_num` of `SliceArrayOp` (PaddlePaddle#64013)

[Dy2St] Use `full_graph=True` outside dy2st uts (part1) (PaddlePaddle#64058)

[Dy2St] Use `full_graph=True` outside dy2st uts (part2) (PaddlePaddle#64059)

fix typo (PaddlePaddle#64060)

Co-authored-by: jiahy0825 <[email protected]>

update (PaddlePaddle#64042)

Replace paddle/fluid/platform/device/gpu/gpu_dnn.h (PaddlePaddle#63819)

* Fix

* Fix

* Fix

Clean lookup_table_v2_op.h lookup_table_v2_op.cu (PaddlePaddle#64020)

* Fix

* ci

refine GetTensorListFromArgs (PaddlePaddle#64045)

Revert "【Hackathon 6th Fundable Projects 3 No.60】Remove fluid operator chunk_…" (PaddlePaddle#64050)

This reverts commit 88b1a6e.

[Prim][PIR] support floor_divide op forward in prim pir (PaddlePaddle#64023)

* floor-div-dev

* update test

[CINN] Reconstruct shape_analysis (PaddlePaddle#63790)

* reconstruct shape_analysis

* fix input value shape infer

* fix merge bugs

* fix concat and gather op InferSymbolicShape

* fix merge bug

* fix value_to_shape_or_data hash error and add some checks

* fix set shape for null value

* fix group op lazy infer

* add IsStaticShape check

* fix merge bug

* support static dim check and set for VectorType

* change auto to detail type

[XPU] fix bugs in processing of attention_mask and fix_seed_offset on XPU (PaddlePaddle#64003)

* [XPU] fix segmentfault caused by setting fix_seed_offset on XPU

* cast attention_mask to float32 when necessary

fix merge bug (PaddlePaddle#64069)

【Fix PIR Unittest No.125、147、481】Fix some 0D uts in PIR mode (part1) (PaddlePaddle#64064)

[Prim][VJP]support autogen to remove unused composite in .yaml (PaddlePaddle#64054)

* support autogen to remove unused composite in .yaml

* fix bug

[PIR] Fix typo `set_pit_tests_properties` -> `set_pir_tests_properties` (PaddlePaddle#64063)

[Dy2St] Use `full_graph=True` outside dy2st uts (part3) (PaddlePaddle#64066)

[PIR save/load] Open more tests for paddle.save and paddle.load (PaddlePaddle#64044)

* open more tests for paddle.save and paddle.load

* fix

API Improvement for paddle.nn.functional.group_norm and paddle.nn.GroupNorm (PaddlePaddle#63881)

* update group_norm

* update trt plugin

* update trt plugin

* fix trt plugin

* fix trt plugin

* fix test

* fix test

* fix ci windows inference

* update kernel function names and add v2 test

* fix

* fix fp16 test

Revert "【Hackathon 6th Fundable Projects 3 No.81】Remove fluid operators ctc_a…" (PaddlePaddle#64049)

This reverts commit 2134ead.

Clean paddle/fluid/operators/fused/attention_layer_norm.h (PaddlePaddle#64051)

* Fix

* Fix

 Replace operators::math to phi::math in fluid/operators (PaddlePaddle#63854)

[CINN]Clean usless loop_reorder_aligment tactic (PaddlePaddle#63998)

* [CINN]Clean usless loop_reorder_aligment tactic

* fix source

【Hackathon 6th Fundable Projects 3 No.396】fluid operator yolo_box_head (PaddlePaddle#63783)

* Fix

* Fix

* Fix

* Fix

* Fix

【Hackathon 6th Fundable Projects 3 No.240】fluid operator moe (PaddlePaddle#63929)

【Hackathon 6th Fundable Projects 3 No.82】fluid operator cudnn_lstm (PaddlePaddle#63936)

* Fix

* Fix

* Fix

* Fix

[CINN] Remove useless log (PaddlePaddle#64052)

[pir_save_load] add pir for test_jit_save_load.py (PaddlePaddle#63958)

* add jit load.train

* modify backward program lost

* modify

* combine eval and train

* modify 8 case of jit.save.load

* modify jit_save_load case

* rename jit_save_load

* change name all

* modify timeout

* modify new case

* modify TestJitSaveLoadMultiMethods

* modify cpu tensor no holder bug

Flashattention support qkvpacked and varlen (PaddlePaddle#63289)

* Flashattention support qkvpacked and varlen

* fix codestyle

* fix codestyle

* FlashAttention kvReduceGQA Performance Optimization

* Fix problem with windows

* code clean

* update third_party/flashattn

* update errormsg and docs

* update api

* update doc

* update doctest

* update doc, test=document_fix

* update doc, test=document_fix

* Update python/paddle/nn/functional/flash_attention.py

Co-authored-by: zachary sun <[email protected]>

* Update python/paddle/nn/functional/flash_attention.py

Co-authored-by: zachary sun <[email protected]>

* update doc

---------

Co-authored-by: zachary sun <[email protected]>

【PIR Dist Op Reg No.20】 reg global_gather (PaddlePaddle#63867)

* reg global_gather

* reg global_gather

* reg_global_gather

* fix

* fix

* fix

* fix conflict

* fix conflict

* Update ops_api_gen.py

* Update ops_api_gen.py

Fix backward program kwargs error when process inplace value (PaddlePaddle#63939)

【Hackathon 6th No.35】support kwargs for recompute when use_reentrant == True fix (PaddlePaddle#63880)

* support kwargs for recompute when use_reentrant == True

* recover third party

merge main

lint

delete printf

change flash attn version
  • Loading branch information
Your Name authored and yinfan98 committed May 7, 2024
1 parent eb7c5d1 commit 0becc4a
Show file tree
Hide file tree
Showing 264 changed files with 7,793 additions and 5,088 deletions.
5 changes: 1 addition & 4 deletions .editorconfig
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,12 @@ insert_final_newline = true
[*.{c,cc,cxx,cpp,cu,cuh,h,hpp,hxx,kps}]
indent_size = 2

[*.{py,java,r}]
[*.{py,pyi,java,r,toml}]
indent_size = 4

[Dockerfile.*]
indent_size = 4

[.flake8]
indent_size = 4

[*.go]
indent_style = tab
indent_size = 4
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ option(WITH_PIP_CUDA_LIBRARIES
"Paddle uses the CUDA library provided by NVIDIA" OFF)
option(WITH_NIGHTLY_BUILD
"Compile nightly paddle whl package of the develop branch" OFF)
option(WITH_CPP_TEST "Compile PaddlePaddle skip cpp test" ON)
find_package(Git REQUIRED)

# config GIT_URL with github mirrors to speed up dependent repos clone
Expand Down
24 changes: 12 additions & 12 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -116,14 +116,14 @@ void GroupOp::Print(pir::IrPrinter& printer) {
}

bool GroupOp::InferSymbolicShape(
::pir::ShapeConstraintIRAnalysis* shape_analysis) {
::pir::InferSymExprForBlock(*block(), shape_analysis);
::pir::InferSymbolicShapeContext* infer_context) {
::pir::InferSymExprForBlock(*block(), infer_context);

for (uint32_t rst_idx = 0; rst_idx < num_results(); rst_idx++) {
auto inner_yield_value = block()->back().operand_source(rst_idx);
const auto& shape =
shape_analysis->GetShapeOrDataForValue(inner_yield_value);
shape_analysis->SetShapeOrDataForValue(result(rst_idx), shape);
infer_context->GetShapeOrDataForValue(inner_yield_value);
infer_context->SetShapeOrDataForValue(result(rst_idx), shape);
}

if (VLOG_IS_ON(4)) {
Expand Down Expand Up @@ -204,16 +204,16 @@ void YieldStoreOp::Build(pir::Builder& builder,
void YieldStoreOp::VerifySig() {}

bool YieldStoreOp::InferSymbolicShape(
pir::ShapeConstraintIRAnalysis* shape_analysis) {
shape_analysis->SetShapeOrDataForValue(
result(0), shape_analysis->GetShapeOrDataForValue(operand_source(0)));
pir::InferSymbolicShapeContext* infer_context) {
infer_context->SetShapeOrDataForValue(
result(0), infer_context->GetShapeOrDataForValue(operand_source(0)));
return true;
}

bool ConcatOp::InferSymbolicShape(
pir::ShapeConstraintIRAnalysis* shape_analysis) {
pir::InferSymbolicShapeContext* infer_context) {
VLOG(4) << "Infer symbolic shape for cinn_op.concat";
return ConcatOpInferSymbolicShape(this->operation(), shape_analysis);
return ConcatOpInferSymbolicShape(this->operation(), infer_context);
}

void ConcatOp::Build(pir::Builder& builder, // NOLINT
Expand Down Expand Up @@ -476,7 +476,7 @@ GenerateShapeOp::ConvertAttributeToSymbolBindings(
}

bool GenerateShapeOp::InferSymbolicShape(
pir::ShapeConstraintIRAnalysis* shape_analysis) {
pir::InferSymbolicShapeContext* infer_context) {
const auto attr_dim_exprs = [&] {
std::vector<symbol::DimExpr> dim_exprs{};
pir::Attribute dim_expr_attr = this->attributes().at("output_dim_exprs");
Expand Down Expand Up @@ -505,7 +505,7 @@ bool GenerateShapeOp::InferSymbolicShape(
}();
auto DimExprs4InputDim =
[&](int input_idx) -> const symbol::ShapeOrDataDimExprs& {
return shape_analysis->GetShapeOrDataForValue(
return infer_context->GetShapeOrDataForValue(
this->operand_source(input_idx));
};
auto DimExprs4SymbolName =
Expand All @@ -527,7 +527,7 @@ bool GenerateShapeOp::InferSymbolicShape(
symbol::ShapeOrDataDimExprs shape_or_data_dim_exprs{
symbol::TensorShapeOrDataDimExprs(shape, substituted_dim_exprs)};

shape_analysis->SetShapeOrDataForValue(this->out(), shape_or_data_dim_exprs);
infer_context->SetShapeOrDataForValue(this->out(), shape_or_data_dim_exprs);

return true;
}
Expand Down
8 changes: 4 additions & 4 deletions paddle/cinn/hlir/dialect/operator/ir/manual_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ class IR_API GroupOp
pir::Block *block() const;
std::vector<pir::Operation *> GetOperators() const;

bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);
bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);

void VerifySig();
void Print(pir::IrPrinter &printer); // NOLINT
Expand Down Expand Up @@ -102,7 +102,7 @@ class IR_API YieldStoreOp

void VerifySig();

bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);
bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
};

class IR_API ConcatOp
Expand All @@ -123,7 +123,7 @@ class IR_API ConcatOp

void VerifySig() const {}

bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);
bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);
};

class IR_API SplitOp : public pir::Op<SplitOp> {
Expand Down Expand Up @@ -177,7 +177,7 @@ class IR_API GenerateShapeOp

pir::Value out() { return result(0); }

bool InferSymbolicShape(pir::ShapeConstraintIRAnalysis *shape_analysis);
bool InferSymbolicShape(pir::InferSymbolicShapeContext *infer_context);

static pir::Attribute ConvertSymbolBindingsToAttribute(
pir::Builder &builder, const SymbolBindings &symbol_bindings); // NOLINT
Expand Down
8 changes: 2 additions & 6 deletions paddle/cinn/hlir/dialect/operator/transforms/add_cinn_pass.cc
Original file line number Diff line number Diff line change
Expand Up @@ -220,14 +220,10 @@ void ApplyCinnPass(::pir::Program* program,
ApplyPdToCinnPass(program, CreatePassManager);
ApplyCinnPreprocessPass(program, CreatePassManager);
ApplyBuildGroupOpPass(program, CreatePassManager);
LOG(INFO) << "====[pir-to-py-code group-ops begin]===" << std::endl
<< PirToPyCodeConverter().Convert(*program);
LOG(INFO) << "====[pir-to-py-code group-ops end]===";
PirToPyCodeConverter().SaveIfFlagEnabled("group_op_programs", *program);
ApplyGroupOpPass(program, CreatePassManager);
ApplyDivideGroupOpToFusionOpPass(program, CreatePassManager);
LOG(INFO) << "====[pir-to-py-code fusion-ops begin]===" << std::endl
<< PirToPyCodeConverter().Convert(*program);
LOG(INFO) << "====[pir-to-py-code fusion-ops end]===";
PirToPyCodeConverter().SaveIfFlagEnabled("fusion_op_programs", *program);
LOG(INFO) << "FusionOp count before lowering : *****[ "
<< GetOpCount<cinn::dialect::FusionOp>(program->module_op())
<< " ]*****";
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,11 +45,9 @@ class AddYieldStoreInFusionOpPattern
auto orignal_base = op->operand_source(i);
op->operand(i).set_source(store_op.result(0));

if (shape_analysis.HasShapeOrDataForValue(orignal_base)) {
shape_analysis.SetShapeOrDataForValue(
store_op.result(0),
shape_analysis.GetShapeOrDataForValue(orignal_base));
}
shape_analysis.SetShapeOrDataForValue(
store_op.result(0),
shape_analysis.GetShapeOrDataForValue(orignal_base));
}

return true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,10 @@ class BlockDimExprsAsserter {
PADDLE_THROW(phi::errors::Unimplemented(
op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
} else {
bool infer_result = interface.InferSymbolicShape(shape_analysis.get());
// TODO(Hongqing-work): delete this after the shape analysis reconstruct
// is done.
bool infer_result = interface.InferSymbolicShape(
shape_analysis->GetInferSymbolicShapeContext());
PADDLE_ENFORCE_EQ(infer_result,
true,
::common::errors::PreconditionNotMet(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -182,11 +182,9 @@ ::pir::GroupOpsVec CloneOps(
pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());

for (size_t i = 0; i < op->num_results(); ++i) {
if (shape_analysis.HasShapeOrDataForValue(op->result(i))) {
shape_analysis.SetShapeOrDataForValue(
new_op->result(i),
shape_analysis.GetShapeOrDataForValue(op->result(i)));
}
shape_analysis.SetShapeOrDataForValue(
new_op->result(i),
shape_analysis.GetShapeOrDataForValue(op->result(i)));
}

vec_new_op_list.push_back(new_op);
Expand Down Expand Up @@ -357,11 +355,9 @@ class CinnGroupClusterPattern
// update ir mapping
for (size_t i = 0; i < output_values.size(); ++i) {
ir_mapping.Add(output_values[i], new_group_op->result(i));
if (shape_analysis.HasShapeOrDataForValue(output_values[i])) {
shape_analysis.SetShapeOrDataForValue(
new_group_op->result(i),
shape_analysis.GetShapeOrDataForValue(output_values[i]));
}
shape_analysis.SetShapeOrDataForValue(
new_group_op->result(i),
shape_analysis.GetShapeOrDataForValue(output_values[i]));
}
for (size_t i = 0; i < output_values.size(); ++i) {
auto find_it = all_output_values.find(output_values[i]);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,18 +33,16 @@ bool ReplaceOpWithReshapeOp(pir::Operation* op,
std::vector<int> shape = phi::vectorize<int>(
output.type().dyn_cast<pir::DenseTensorType>().dims());

if (shape_analysis->HasShapeOrDataForValue(op->result(0))) {
const auto& shape_info =
shape_analysis->GetShapeOrDataForValue(op->result(0)).shape();
int temp_dim = -1;

for (size_t i = 0; i < shape_info.size(); ++i) {
if (shape_info[i].isa<int64_t>()) {
shape[i] = shape_info[i].Get<int64_t>();
} else {
shape[i] = temp_dim;
temp_dim = 1;
}
const auto& shape_info =
shape_analysis->GetShapeOrDataForValue(op->result(0)).shape();
int temp_dim = -1;

for (size_t i = 0; i < shape_info.size(); ++i) {
if (shape_info[i].isa<int64_t>()) {
shape[i] = shape_info[i].Get<int64_t>();
} else {
shape[i] = temp_dim;
temp_dim = 1;
}
}
return shape;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,11 @@ bool RemoveOp(pir::Operation* op, pir::PatternRewriter* rewriter) {
if (has_dynamic_shape) {
auto& shape_analysis =
pir::ShapeAnalysisManager::Instance().Get(op->GetParentProgram());
if (shape_analysis.HasShapeOrDataForValue(input) &&
shape_analysis.HasShapeOrDataForValue(output)) {
auto input_sym_shape =
GetExprVecFromShape(shape_analysis.GetShapeOrDataForValue(input));
auto output_sym_shape =
GetExprVecFromShape(shape_analysis.GetShapeOrDataForValue(output));
return input_sym_shape == output_sym_shape;
}
return false;
auto input_sym_shape =
GetExprVecFromShape(shape_analysis.GetShapeOrDataForValue(input));
auto output_sym_shape =
GetExprVecFromShape(shape_analysis.GetShapeOrDataForValue(output));
return input_sym_shape == output_sym_shape;
}
return GetDims(input) == GetDims(output);
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,10 @@ void InferSymbolicShapeForSubgraph(
auto infer_symbolic_shape_interface =
op->dyn_cast<paddle::dialect::InferSymbolicShapeInterface>();
if (infer_symbolic_shape_interface) {
infer_symbolic_shape_interface.InferSymbolicShape(shape_analysis);
// TODO(Hongqing-work): delete this after the shape analysis reconstruct
// is done.
infer_symbolic_shape_interface.InferSymbolicShape(
shape_analysis->GetInferSymbolicShapeContext());
} else {
PADDLE_THROW(phi::errors::Unimplemented(
op->name() + " DOES NOT have InferSymbolicShapeInterface!"));
Expand Down Expand Up @@ -348,7 +351,6 @@ bool ReplaceShapeOpsToGenerateShape(
auto ShapeOrDataDimExprs4Value =
[&shape_analysis](
pir::Value value) -> const symbol::ShapeOrDataDimExprs& {
CHECK(shape_analysis->HasShapeOrDataForValue(value));
return shape_analysis->GetShapeOrDataForValue(value);
};
std::optional<pir::Value> opt_generated_shape =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -104,9 +104,6 @@ class DynamicToStaticConverter {
}

bool Convert() {
if (!IsSymbolFullyInfered()) {
return false;
}
bool updated = false;
VisitEachValue(fusion_op_, [&](pir::Value value) {
updated |= UpdateValueShape(value);
Expand All @@ -116,16 +113,6 @@ class DynamicToStaticConverter {
}

private:
bool IsSymbolFullyInfered() {
bool is_infered = true;
VisitEachValue(fusion_op_, [&](pir::Value value) {
if (!shape_analysis_->HasShapeOrDataForValue(value)) {
is_infered = false;
}
});
return is_infered;
}

DimExpr4SymbolName InitDimExpr4SymbolName() {
const auto* map = GetGlobalDynamicToStaticDimMap();
CHECK(map->has_value());
Expand Down Expand Up @@ -178,7 +165,6 @@ class DynamicToStaticConverter {

bool UpdateValueShape(pir::Value value) {
bool update = false;
CHECK(shape_analysis_->HasShapeOrDataForValue(value));
const auto& origin_shape = GetOriginValueShape(value);
const auto& target_shape = GetTargetValueShape(value);
PADDLE_ENFORCE_EQ(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ struct StaticDimToDynamicConverter {
&pir::ShapeAnalysisManager::Instance().Get(
this->fusion_op->GetParentProgram());
ForEachValue([&](pir::Value value) {
CHECK(shape_analysis->HasShapeOrDataForValue(value));
const auto& origin_shape = GetOriginValueShape(value);
const auto& target_shape = GetTargetValueShape(
shape_analysis->GetShapeOrDataForValue(value).shape());
Expand Down Expand Up @@ -369,26 +368,8 @@ struct StaticDimToDynamicConverter {
pir::Value value,
int64_t constant,
const std::string& symbol) {
if (shape_analysis->HasShapeOrDataForValue(value)) {
const auto& old = shape_analysis->GetShapeOrDataForValue(value).shape();
return ConvertShapeOrDataDimExprs(Converter, old, constant, symbol);
} else {
auto& dims = value.type().dyn_cast<::pir::DenseTensorType>().dims();
const auto& int_dims = ::common::vectorize<int>(dims);
std::vector<symbol::DimExpr> old{};
for (int dim : int_dims) {
old.emplace_back(static_cast<std::int64_t>(dim));
}
const auto& opt_exprs =
ConvertShapeOrDataDimExprs(Converter, old, constant, symbol);
if (opt_exprs.has_value()) {
return opt_exprs.value();
} else {
return symbol::ShapeOrDataDimExprs{
symbol::TensorShapeOrDataDimExprs(old)};
}
}
PADDLE_THROW(phi::errors::Fatal("Dead code"));
const auto& old = shape_analysis->GetShapeOrDataForValue(value).shape();
return ConvertShapeOrDataDimExprs(Converter, old, constant, symbol);
}

template <typename ConverterT>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,20 +101,14 @@ void SimplifyDimExpr(pir::Operation* module_op) {

VisitEachOp(module_op, [&](pir::Operation& op) {
VisitEachValue(op, [&](pir::Value value) {
if (!shape_analysis->HasShapeOrDataForValue(value)) {
VLOG(4) << "SimplifyDimExpr: shape_analysis can't find ShapeOrData for "
"value of the op:"
<< op.name();
} else {
const symbol::ShapeOrDataDimExprs& shape_or_data =
shape_analysis->GetShapeOrDataForValue(value);
VLOG(8) << op.name() << " origin_shape_or_data: " << shape_or_data;
symbol::ShapeOrDataDimExprs simplified_shape_or_data =
SimplifyShapeOrData(shape_or_data);
VLOG(8) << op.name()
<< " simplified_shape_or_data: " << simplified_shape_or_data;
shape_analysis->SetShapeOrDataForValue(value, simplified_shape_or_data);
}
const symbol::ShapeOrDataDimExprs& shape_or_data =
shape_analysis->GetShapeOrDataForValue(value);
VLOG(8) << op.name() << " origin_shape_or_data: " << shape_or_data;
symbol::ShapeOrDataDimExprs simplified_shape_or_data =
SimplifyShapeOrData(shape_or_data);
VLOG(8) << op.name()
<< " simplified_shape_or_data: " << simplified_shape_or_data;
shape_analysis->SetShapeOrDataForValue(value, simplified_shape_or_data);
});
if (op.num_results() > 0) {
pir::shape::SetShapeAttrForOp(
Expand Down
Loading

0 comments on commit 0becc4a

Please sign in to comment.