Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[impellerc] sort uniforms on metal backend #39366

Merged
merged 5 commits into from
Feb 3, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions ci/licenses_golden/licenses_flutter
Original file line number Diff line number Diff line change
Expand Up @@ -1226,6 +1226,8 @@ ORIGIN: ../../../flutter/impeller/compiler/switches.cc + ../../../flutter/LICENS
ORIGIN: ../../../flutter/impeller/compiler/switches.h + ../../../flutter/LICENSE
ORIGIN: ../../../flutter/impeller/compiler/types.cc + ../../../flutter/LICENSE
ORIGIN: ../../../flutter/impeller/compiler/types.h + ../../../flutter/LICENSE
ORIGIN: ../../../flutter/impeller/compiler/uniform_sorter.cc + ../../../flutter/LICENSE
ORIGIN: ../../../flutter/impeller/compiler/uniform_sorter.h + ../../../flutter/LICENSE
ORIGIN: ../../../flutter/impeller/compiler/utilities.cc + ../../../flutter/LICENSE
ORIGIN: ../../../flutter/impeller/compiler/utilities.h + ../../../flutter/LICENSE
ORIGIN: ../../../flutter/impeller/display_list/display_list_dispatcher.cc + ../../../flutter/LICENSE
Expand Down Expand Up @@ -3705,6 +3707,8 @@ FILE: ../../../flutter/impeller/compiler/switches.cc
FILE: ../../../flutter/impeller/compiler/switches.h
FILE: ../../../flutter/impeller/compiler/types.cc
FILE: ../../../flutter/impeller/compiler/types.h
FILE: ../../../flutter/impeller/compiler/uniform_sorter.cc
FILE: ../../../flutter/impeller/compiler/uniform_sorter.h
FILE: ../../../flutter/impeller/compiler/utilities.cc
FILE: ../../../flutter/impeller/compiler/utilities.h
FILE: ../../../flutter/impeller/display_list/display_list_dispatcher.cc
Expand Down
2 changes: 2 additions & 0 deletions impeller/compiler/BUILD.gn
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,8 @@ impeller_component("compiler_lib") {
"switches.h",
"types.cc",
"types.h",
"uniform_sorter.cc",
"uniform_sorter.h",
]

public_deps = [
Expand Down
46 changes: 46 additions & 0 deletions impeller/compiler/compiler.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "impeller/compiler/includer.h"
#include "impeller/compiler/logger.h"
#include "impeller/compiler/types.h"
#include "impeller/compiler/uniform_sorter.h"

namespace impeller {
namespace compiler {
Expand All @@ -36,6 +37,51 @@ static CompilerBackend CreateMSLCompiler(const spirv_cross::ParsedIR& ir,
spirv_cross::CompilerMSL::Options::make_msl_version(1, 2);
sl_compiler->set_msl_options(sl_options);

// Sort the float and sampler uniforms according to their declared/decorated
// order. For user authored fragment shaders, the API for setting uniform
// values uses the index of the uniform in the declared order. By default, the
// metal backend of spirv-cross will order uniforms according to usage. To fix
// this, we use the sorted order and the add_msl_resource_binding API to force
// the ordering to match the declared order. Note that while this code runs
// for all compiled shaders, it will only affect fragment shaders due to the
// specified stage.
auto floats =
SortUniforms(&ir, sl_compiler.get(), spirv_cross::SPIRType::Float);
auto images =
SortUniforms(&ir, sl_compiler.get(), spirv_cross::SPIRType::SampledImage);

uint32_t buffer_offset = 0;
uint32_t sampler_offset = 0;
for (auto& float_id : floats) {
sl_compiler->add_msl_resource_binding(
{.stage = spv::ExecutionModel::ExecutionModelFragment,
.basetype = spirv_cross::SPIRType::BaseType::Float,
.desc_set = sl_compiler->get_decoration(float_id,
spv::DecorationDescriptorSet),
.binding =
sl_compiler->get_decoration(float_id, spv::DecorationBinding),
.count = 1u,
.msl_buffer = buffer_offset});
buffer_offset++;
}
for (auto& image_id : images) {
sl_compiler->add_msl_resource_binding({
.stage = spv::ExecutionModel::ExecutionModelFragment,
.basetype = spirv_cross::SPIRType::BaseType::SampledImage,
.desc_set =
sl_compiler->get_decoration(image_id, spv::DecorationDescriptorSet),
.binding =
sl_compiler->get_decoration(image_id, spv::DecorationBinding),
.count = 1u,
// A sampled image is both an image and a sampler, so both
// offsets need to be set or depending on the partiular shader
// the bindings may be incorrect.
.msl_texture = sampler_offset,
.msl_sampler = sampler_offset,
});
sampler_offset++;
}

return CompilerBackend(sl_compiler);
}

Expand Down
4 changes: 2 additions & 2 deletions impeller/compiler/compiler_backend.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@ struct CompilerBackend {

const spirv_cross::Compiler* operator->() const;

spirv_cross::Compiler* GetCompiler();

operator bool() const;

enum class ExtendedResourceIndex {
Expand All @@ -55,8 +57,6 @@ struct CompilerBackend {

const spirv_cross::Compiler* GetCompiler() const;

spirv_cross::Compiler* GetCompiler();

private:
Type type_ = Type::kMSL;
Compiler compiler_;
Expand Down
37 changes: 20 additions & 17 deletions impeller/compiler/reflector.cc
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include "impeller/base/strings.h"
#include "impeller/base/validation.h"
#include "impeller/compiler/code_gen_template.h"
#include "impeller/compiler/uniform_sorter.h"
#include "impeller/compiler/utilities.h"
#include "impeller/geometry/matrix.h"
#include "impeller/geometry/scalar.h"
Expand Down Expand Up @@ -359,23 +360,25 @@ std::shared_ptr<RuntimeStageData> Reflector::GenerateRuntimeStageData() const {
if (sksl_data_) {
data->SetSkSLData(sksl_data_);
}
ir_->for_each_typed_id<spirv_cross::SPIRVariable>(
[&](uint32_t, const spirv_cross::SPIRVariable& var) {
if (var.storage != spv::StorageClassUniformConstant) {
return;
}
const auto spir_type = compiler_->get_type(var.basetype);
UniformDescription uniform_description;
uniform_description.name = compiler_->get_name(var.self);
uniform_description.location = compiler_->get_decoration(
var.self, spv::Decoration::DecorationLocation);
uniform_description.type = spir_type.basetype;
uniform_description.rows = spir_type.vecsize;
uniform_description.columns = spir_type.columns;
uniform_description.bit_width = spir_type.width;
uniform_description.array_elements = GetArrayElements(spir_type);
data->AddUniformDescription(std::move(uniform_description));
});

// Sort the IR so that the uniforms are in declaration order.
std::vector<spirv_cross::ID> uniforms =
SortUniforms(ir_.get(), compiler_.GetCompiler());

for (auto& sorted_id : uniforms) {
auto var = ir_->ids[sorted_id].get<spirv_cross::SPIRVariable>();
const auto spir_type = compiler_->get_type(var.basetype);
UniformDescription uniform_description;
uniform_description.name = compiler_->get_name(var.self);
uniform_description.location = compiler_->get_decoration(
var.self, spv::Decoration::DecorationLocation);
uniform_description.type = spir_type.basetype;
uniform_description.rows = spir_type.vecsize;
uniform_description.columns = spir_type.columns;
uniform_description.bit_width = spir_type.width;
uniform_description.array_elements = GetArrayElements(spir_type);
data->AddUniformDescription(std::move(uniform_description));
}
return data;
}

Expand Down
46 changes: 7 additions & 39 deletions impeller/compiler/spirv_sksl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
// found in the LICENSE file.

#include "impeller/compiler/spirv_sksl.h"
#include "impeller/compiler/uniform_sorter.h"

using namespace spv;
using namespace SPIRV_CROSS_NAMESPACE;
Expand Down Expand Up @@ -219,47 +220,14 @@ bool CompilerSkSL::emit_uniform_resources() {
bool emitted = false;

// Output Uniform Constants (values, samplers, images, etc).
std::vector<ID> regular_uniforms;
std::vector<ID> shader_uniforms;
for (auto& id : ir.ids) {
if (id.get_type() == TypeVariable) {
auto& var = id.get<SPIRVariable>();
auto& type = get<SPIRType>(var.basetype);
if (var.storage != StorageClassFunction && !is_hidden_variable(var) &&
type.pointer &&
(type.storage == StorageClassUniformConstant ||
type.storage == StorageClassAtomicCounter)) {
// Separate out the uniforms that will be of SkSL 'shader' type since
// we need to make sure they are emitted only after the other uniforms.
if (type.basetype == SPIRType::SampledImage) {
shader_uniforms.push_back(var.self);
} else {
regular_uniforms.push_back(var.self);
}
emitted = true;
}
}
std::vector<ID> regular_uniforms =
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The revert was required because this change was filtering out non float uniforms, which we would never emit, and therefore never error check. We could move the error checking earlier in the pipeline, now that I'm more familiar with the API that is probably a better place to do it.

That said for now, the minimal change is to sort all unfiorms except sampled images instead of all float uniforms.

SortUniforms(&ir, this, SPIRType::SampledImage, /*include=*/false);
std::vector<ID> shader_uniforms =
SortUniforms(&ir, this, SPIRType::SampledImage);
if (regular_uniforms.size() > 0 || shader_uniforms.size() > 0) {
emitted = true;
}

// Sort uniforms by location.
auto compare_locations = [this](ID id1, ID id2) {
auto& flags1 = get_decoration_bitset(id1);
auto& flags2 = get_decoration_bitset(id2);
// Put the uniforms with no location after the ones that have a location.
if (!flags1.get(DecorationLocation)) {
return false;
}
if (!flags2.get(DecorationLocation)) {
return true;
}
// Sort in increasing order of location.
return get_decoration(id1, DecorationLocation) <
get_decoration(id2, DecorationLocation);
};
std::sort(regular_uniforms.begin(), regular_uniforms.end(),
compare_locations);
std::sort(shader_uniforms.begin(), shader_uniforms.end(), compare_locations);

for (const auto& id : regular_uniforms) {
auto& var = get<SPIRVariable>(id);
emit_uniform(var);
Expand Down
48 changes: 48 additions & 0 deletions impeller/compiler/uniform_sorter.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
// Copyright 2013 The Flutter Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#include "impeller/compiler/uniform_sorter.h"

namespace impeller {

std::vector<spirv_cross::ID> SortUniforms(
const spirv_cross::ParsedIR* ir,
const spirv_cross::Compiler* compiler,
std::optional<spirv_cross::SPIRType::BaseType> type_filter,
bool include) {
// Sort the IR so that the uniforms are in declaration order.
std::vector<spirv_cross::ID> uniforms;
ir->for_each_typed_id<spirv_cross::SPIRVariable>(
[&](uint32_t, const spirv_cross::SPIRVariable& var) {
if (var.storage != spv::StorageClassUniformConstant) {
return;
}
const auto type = compiler->get_type(var.basetype);
if (!type_filter.has_value() ||
(include && type_filter.value() == type.basetype) ||
(!include && type_filter.value() != type.basetype)) {
uniforms.push_back(var.self);
}
});

auto compare_locations = [&ir](spirv_cross::ID id1, spirv_cross::ID id2) {
auto& flags1 = ir->get_decoration_bitset(id1);
auto& flags2 = ir->get_decoration_bitset(id2);
// Put the uniforms with no location after the ones that have a location.
if (!flags1.get(spv::Decoration::DecorationLocation)) {
return false;
}
if (!flags2.get(spv::Decoration::DecorationLocation)) {
return true;
}
// Sort in increasing order of location.
return ir->get_decoration(id1, spv::Decoration::DecorationLocation) <
ir->get_decoration(id2, spv::Decoration::DecorationLocation);
};
std::sort(uniforms.begin(), uniforms.end(), compare_locations);

return uniforms;
}

} // namespace impeller
27 changes: 27 additions & 0 deletions impeller/compiler/uniform_sorter.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
// Copyright 2013 The Flutter Authors. All rights reserved.
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

#pragma once

#include <optional>

#include "impeller/compiler/compiler_backend.h"

#include "spirv_msl.hpp"
#include "spirv_parser.hpp"

namespace impeller {

/// @brief Sorts uniform declarations in an IR according to decoration order.
///
/// The [type_filter] may be optionally supplied to limit which types are
/// returned The [include] value can be set to false change this filter to
/// exclude instead of include.
std::vector<spirv_cross::ID> SortUniforms(
const spirv_cross::ParsedIR* ir,
const spirv_cross::Compiler* compiler,
std::optional<spirv_cross::SPIRType::BaseType> type_filter = std::nullopt,
bool include = true);

} // namespace impeller
2 changes: 1 addition & 1 deletion impeller/entity/contents/runtime_effect_contents.cc
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ bool RuntimeEffectContents::Render(const ContentContext& renderer,

ShaderUniformSlot uniform_slot;
uniform_slot.name = uniform.name.c_str();
uniform_slot.ext_res_0 = buffer_index;
uniform_slot.ext_res_0 = uniform.location;
cmd.BindResource(ShaderStage::kFragment, uniform_slot, metadata,
buffer_view);
buffer_index++;
Expand Down