Skip to content

Commit

Permalink
mypy api
Browse files Browse the repository at this point in the history
  • Loading branch information
AlexanderDokuchaev committed Dec 29, 2024
1 parent a25d85e commit b4d6a00
Show file tree
Hide file tree
Showing 2 changed files with 10 additions and 7 deletions.
16 changes: 9 additions & 7 deletions nncf/api/compression.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from abc import ABC
from abc import abstractmethod
Expand All @@ -33,7 +34,7 @@ class CompressionLoss(ABC):
"""

@abstractmethod
def calculate(self, *args, **kwargs) -> Any:
def calculate(self, *args: Any, **kwargs: Any) -> Any:
"""
Calculates and returns the compression loss value.
"""
Expand All @@ -53,7 +54,7 @@ def get_state(self) -> Dict[str, Any]:
Returns the compression loss state.
"""

def __call__(self, *args, **kwargs) -> Any:
def __call__(self, *args: Any, **kwargs: Any) -> Any:
"""
Calculates and returns the compression loss value. Same as `.calculate()`.
"""
Expand Down Expand Up @@ -127,7 +128,7 @@ class CompressionStage(IntEnum):
PARTIALLY_COMPRESSED = 1
FULLY_COMPRESSED = 2

def __add__(self, other: "CompressionStage") -> "CompressionStage":
def __add__(self, other: int) -> CompressionStage:
"""
Defines compression stage of a composite compression controller, consist of
two algorithms, where `self` is the compression stage of the first algorithm
Expand Down Expand Up @@ -162,7 +163,7 @@ def __init__(self, target_model: TModel):
self._model = target_model

@property
def model(self) -> TModel:
def model(self) -> TModel: # type: ignore[type-var]
"""
The compressed model object with which this controller is associated.
"""
Expand Down Expand Up @@ -215,6 +216,7 @@ def get_compression_state(self) -> Dict[str, Any]:
:return: Compression state of the model to resume compression from it.
"""

@abstractmethod
def compression_stage(self) -> CompressionStage:
"""
Returns the compression stage. Should be used on saving best checkpoints
Expand Down Expand Up @@ -254,7 +256,7 @@ def prepare_for_export(self) -> None:
"""
self._model = self.strip_model(self._model)

def strip(self, do_copy: bool = True) -> TModel:
def strip(self, do_copy: bool = True) -> TModel: # type: ignore[type-var]
"""
Returns the model object with as much custom NNCF additions as possible removed
while still preserving the functioning of the model object as a compressed model.
Expand All @@ -263,7 +265,7 @@ def strip(self, do_copy: bool = True) -> TModel:
will return the currently associated model object "stripped" in-place.
:return: The stripped model.
"""
return self.strip_model(self.model, do_copy)
return self.strip_model(self.model, do_copy) # type: ignore

@abstractmethod
def export_model(
Expand Down Expand Up @@ -413,7 +415,7 @@ class CompressionLevel(IntEnum):
FULL = 2

@classmethod
def map_legacy_level_to_stage(cls):
def map_legacy_level_to_stage(cls) -> Dict[CompressionLevel, CompressionStage]:
return {
CompressionLevel.NONE: CompressionStage.UNCOMPRESSED,
CompressionLevel.PARTIAL: CompressionStage.PARTIALLY_COMPRESSED,
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,7 @@ strict = true
# https://github.com/hauntsaninja/no_implicit_optional
implicit_optional = true
files = [
"nncf/api",
"nncf/common/sparsity",
"nncf/common/graph",
"nncf/common/accuracy_aware_training/",
Expand Down

0 comments on commit b4d6a00

Please sign in to comment.