Skip to content

Commit

Permalink
update import mean_per_channel
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Oct 2, 2023
1 parent 19640ad commit 3ac8523
Showing 1 changed file with 2 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from nncf.common.tensor_statistics.statistic_point import StatisticPointsContainer
from nncf.common.utils.backend import BackendType
from nncf.common.utils.backend import get_backend
from nncf.experimental.common.tensor_statistics import statistical_functions as s_fns
from nncf.experimental.common.tensor_statistics.statistical_functions import mean_per_channel
from nncf.experimental.tensor import Tensor
from nncf.experimental.tensor import functions as fns
from nncf.quantization.algorithms.algorithm import Algorithm
Expand Down Expand Up @@ -318,7 +318,7 @@ def _get_bias_shift(
engine = EngineFactory.create(model)
raw_output = engine.infer(input_blob)
q_outputs = self._backend_entity.process_model_output(raw_output, output_name)
q_outputs = s_fns.mean_per_channel(q_outputs, channel_axis)
q_outputs = mean_per_channel(q_outputs, channel_axis)
bias_shift = fns.stack(output_fp) - q_outputs
return bias_shift

Expand Down

0 comments on commit 3ac8523

Please sign in to comment.