Separate reset for KV state and LoRA state in LLMPipeline (openvinotoolkit#1058)

Fixes a bug where the LoRA state was reset on every generate() call, which added
unnecessary overhead to each call even when the LoRA tensors/alphas had not changed.
slyalin authored Oct 24, 2024
1 parent 275729c commit 6a4ba7f
Showing 3 changed files with 46 additions and 40 deletions.
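Before the per-file diff, here is a minimal standalone sketch of the pattern this commit introduces. The free-function name and parameters are illustrative, not part of the commit; in the change itself the same logic lives in `StatefulLLMPipeline::reset_kv_state()`. The idea: clear only the KV-cache variables and skip the state tensors owned by the `AdapterController`, so an unchanged LoRA configuration is not re-uploaded on the next `generate()` call.

```cpp
#include <optional>
#include "openvino/genai/lora_adapter.hpp"
#include "openvino/runtime/infer_request.hpp"

// Reset only the KV-cache variables of a stateful LLM. LoRA state tensors
// (alpha, A, B) created by the AdapterController are skipped, so an unchanged
// adapter config keeps its tensors loaded across generate() calls.
void reset_kv_state_only(ov::InferRequest& request,
                         std::optional<ov::genai::AdapterController>& adapter_controller) {
    if (adapter_controller) {
        for (auto& state : request.query_state()) {
            if (!adapter_controller->has_state_name(state.get_name()))
                state.reset();  // a KV-cache (or other non-LoRA) variable
        }
    } else {
        request.reset_state();  // no LoRA state in the model: a full reset is fine
    }
}
```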
5 changes: 3 additions & 2 deletions src/cpp/include/openvino/genai/lora_adapter.hpp
@@ -190,8 +190,9 @@ class OPENVINO_GENAI_EXPORTS AdapterController {
// Apply adapters configured in the current config set last time, or set and use new config given as optional `config` argument
void apply(ov::InferRequest& request, const std::optional<AdapterConfig>& config = std::nullopt);

// the next call of apply will set all adapter tensors regardless of config change, use this method if full state.reset is called for the controlled model
void force_full_apply(bool full_apply = true);
// Returns true if a given name is one of the state names created by this adapter controller for dynamic LoRA
// Helps to distinguish LoRA states from other states (e.g. KV cache state) in the model for a partial state reset.
bool has_state_name(const std::string& name);

operator bool() const {
return bool(m_pimpl);
23 changes: 15 additions & 8 deletions src/cpp/src/llm_pipeline.cpp
@@ -86,7 +86,6 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
m_adapter_controller = AdapterController(model, m_generation_config.adapters, device); // TODO: Make the prefix name configurable
utils::slice_matmul_statefull_model(model);
m_model_runner = core.compile_model(model, device, compile_plugin_config).create_infer_request();
m_adapter_controller->apply(m_model_runner, m_generation_config.adapters);
} else {
auto [core_plugin_config, compile_plugin_config] = ov::genai::utils::split_core_complile_config(plugin_config);
core.set_property(core_plugin_config);
@@ -179,6 +178,18 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
return decoded_results;
}

void reset_kv_state() {
if(m_adapter_controller) {
for(auto& state: m_model_runner.query_state()) {
if(!m_adapter_controller->has_state_name(state.get_name())) {
state.reset();
}
}
} else {
m_model_runner.reset_state();
}
}

EncodedResults generate(
const EncodedInputs& inputs,
OptionalGenerationConfig generation_config,
@@ -273,11 +284,7 @@ class StatefulLLMPipeline final : public LLMPipelineImplBase {
}

if (!is_chat_conversation) {
// FIXME: Reset only KV cache part of state, there is also can be LoRA applied in the states and full reset will need to reapply LoRA even if the LoRA config is not changed
m_model_runner.reset_state();
if(m_adapter_controller) {
m_adapter_controller->force_full_apply(); // FIXME: Reset only KV cache part to avoid this call
}
reset_kv_state();
m_selected_beam = std::nullopt;
} else {
m_is_cache_empty = false;
@@ -297,7 +304,7 @@
is_chat_conversation = true;
m_selected_beam = std::nullopt;
if (!m_is_cache_empty) {
m_model_runner.reset_state();
reset_kv_state();
m_is_cache_empty = true;
m_history = {};
m_templated_chat_history = "";
@@ -315,7 +322,7 @@
is_chat_conversation = false;
m_selected_beam = std::nullopt;
if (!m_is_cache_empty) {
m_model_runner.reset_state();
reset_kv_state();
m_is_cache_empty = true;
m_history.clear();
m_templated_chat_history.clear();
58 changes: 28 additions & 30 deletions src/cpp/src/lora_adapter.cpp
@@ -367,17 +367,8 @@ struct LoRAParametersByWeightGetter {
};


// TODO: There is possible simplification if a new feature is implemented in OpenVINO:
// move name from LoRAVarIDs to to LoRAIndices when the order of tensors in the model state in OV infer request will
// be the same as the order of variables, remove LoRAVarsIDs in this case.

struct LoRAIndices : public LoRAParts<size_t> {
std::string name;
};

struct LoRAVarIDs : public LoRAParts<ov::op::util::VariableInfo> {
std::string name; // layer name where LoRA with given variables is attached
};
using LoRAIndices = LoRAParts<size_t>;
using LoRAVarIDs = LoRAParts<ov::op::util::VariableInfo>;


// Deduce expected LoRA input and output static dimensions based on a given node where LoRA is applied
@@ -398,15 +389,18 @@ void deduce_input_output_dims(NodePtr node, ov::Dimension& input_dim, ov::Dimens
}


using LoRAVarMap = std::map<std::string, LoRAVarIDs>;


// Creates ReadValue and Assign nodes to inject LoRA tensors as variables for a given node but
// doesn't connect them to the model returning as LoRANode instance.
struct LoRAWeightStateGetter {
LoRAParametersGetter params_getter;
std::shared_ptr<ov::Model> model;
std::vector<LoRAVarIDs>& variable_ids;
LoRAVarMap& variable_ids;
// TODO: Use variable indices instead of variable_id for faster search for a state tensor

LoRAWeightStateGetter (const LoRAParametersGetter& params_getter, std::shared_ptr<ov::Model> model, std::vector<LoRAVarIDs>& variable_ids) :
LoRAWeightStateGetter (const LoRAParametersGetter& params_getter, std::shared_ptr<ov::Model> model, LoRAVarMap& variable_ids) :
params_getter(params_getter), model(model), variable_ids(variable_ids) {}

std::optional<LoRANode> operator() (NodePtr node) const {
@@ -420,7 +414,6 @@ struct LoRAWeightStateGetter {
std::string variable_id_prefix = "lora_state_" + std::to_string(model->get_sinks().size()) + name;
LoRANode result;
LoRAVarIDs var_ids;
var_ids.name = name;

// FIXME: No guarantees on ordering of state in InferRequest makes impossible using indices of variables later, forced to use variable_id instead
//indices.A = model->get_variables().size();
@@ -446,7 +439,7 @@ struct LoRAWeightStateGetter {
variable_id_prefix + ".B"
};
result.B = add_variable(var_ids.B);
variable_ids.emplace_back(var_ids);
variable_ids.emplace(name, var_ids);
return result;
} else {
return std::nullopt;
@@ -815,7 +808,8 @@ bool operator< (const Adapter& a, const Adapter& b) {


struct AdapterControllerImpl {
std::vector<LoRAVarIDs> variable_ids;
LoRAVarMap variable_ids;
std::unordered_set<std::string> variable_names;
AdapterConfig current_config;
bool need_full_apply = true;
InferRequestSignatureCache lora_state_evaluators;
@@ -890,6 +884,13 @@ struct AdapterControllerImpl {

pm.run_passes(model);
model->validate_nodes_and_infer_types(); // FIXME: For debugging purposes only

// Collect all variable names to quickly detect which state tensor belongs to this adapter controller later
for(const auto& var: variable_ids) {
variable_names.insert(var.second.A.variable_id);
variable_names.insert(var.second.B.variable_id);
variable_names.insert(var.second.alpha.variable_id);
}
}

static std::shared_ptr<Adapter::Impl> get_adapter_impl(const Adapter& adapter) {
@@ -945,15 +946,14 @@ struct AdapterControllerImpl {
} else if(diff) {
if(diff.adapter) {
set_new_adapter_tensors(infer_request);
} else {
OPENVINO_ASSERT(diff.alpha);
} else if(diff.alpha) {
set_new_adapter_alphas(infer_request);
}
}
}

void force_full_apply(bool full_apply) {
need_full_apply = full_apply;
bool has_state_name(const std::string& name) {
return variable_names.count(name);
}

void set_new_adapter_alphas (ov::InferRequest& infer_request) {
@@ -988,12 +988,10 @@ struct AdapterControllerImpl {
for(const auto& lora_var_ids : variable_ids) {
// FIXME: Remove this mapping when the order of state will be the same as the order of variables
LoRAIndices lora_indices;
lora_indices.alpha = state_name_to_index.at(lora_var_ids.alpha.variable_id);
lora_indices.A = state_name_to_index.at(lora_var_ids.A.variable_id);
lora_indices.B = state_name_to_index.at(lora_var_ids.B.variable_id);
lora_indices.name = lora_var_ids.name; // TODO: Redundant?

set_lora_tensors(state, lora_var_ids, lora_indices, weight_getters);
lora_indices.alpha = state_name_to_index.at(lora_var_ids.second.alpha.variable_id);
lora_indices.A = state_name_to_index.at(lora_var_ids.second.A.variable_id);
lora_indices.B = state_name_to_index.at(lora_var_ids.second.B.variable_id);
set_lora_tensors(state, lora_var_ids.first, lora_var_ids.second, lora_indices, weight_getters);
}
}

@@ -1191,13 +1189,13 @@ struct AdapterControllerImpl {
return shape;
}

void set_lora_tensors(std::vector<VariableState>& state, const LoRAVarIDs& lora_var_ids, const LoRAIndices& lora_indices, const std::vector<LoRAWeightGetter>& weight_getters) {
void set_lora_tensors(std::vector<VariableState>& state, const std::string& name, const LoRAVarIDs& lora_var_ids, const LoRAIndices& lora_indices, const std::vector<LoRAWeightGetter>& weight_getters) {
LoRAParts<ov::Tensor> lora_state_tensors{
ov::Tensor(lora_var_ids.alpha.data_type, dynamic_to_static(lora_var_ids.alpha.data_shape)),
ov::Tensor(lora_var_ids.A.data_type, dynamic_to_static(lora_var_ids.A.data_shape)),
ov::Tensor(lora_var_ids.B.data_type, dynamic_to_static(lora_var_ids.B.data_shape))
};
auto new_tensors = prepare_lora_tensors(lora_indices.name, weight_getters, lora_state_tensors);
auto new_tensors = prepare_lora_tensors(name, weight_getters, lora_state_tensors);

state[lora_indices.alpha].set_state(new_tensors.alpha);
state[lora_indices.A].set_state(new_tensors.A);
@@ -1269,8 +1267,8 @@ void AdapterController::apply(ov::InferRequest& request, const std::optional<Ada
}


void AdapterController::force_full_apply(bool full_apply) {
return m_pimpl->force_full_apply(full_apply);
bool AdapterController::has_state_name(const std::string& name) {
return m_pimpl->has_state_name(name);
}
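
For context, a usage sketch of the scenario this commit optimizes; the model directory and adapter file paths are placeholders. Two consecutive `generate()` calls with an unchanged adapter configuration now reset only the KV-cache part of the state between calls, so the LoRA tensors already held in the infer-request state are reused rather than re-uploaded.

```cpp
#include "openvino/genai/llm_pipeline.hpp"
#include "openvino/genai/lora_adapter.hpp"

int main() {
    // Placeholder paths; substitute a real model directory and safetensors adapter.
    ov::genai::Adapter adapter("adapter.safetensors");
    ov::genai::LLMPipeline pipe("model_dir", "CPU", ov::genai::adapters(adapter));

    // Both calls use the same adapter config, so after this commit the second call
    // only clears the KV cache and keeps the LoRA state tensors already set.
    auto first  = pipe.generate("First prompt",  ov::genai::max_new_tokens(32));
    auto second = pipe.generate("Second prompt", ov::genai::max_new_tokens(32));
    return 0;
}
```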

