Skip to content

Commit

Permalink
Use type hinting generics [1] everywhere
Browse files Browse the repository at this point in the history
And use postponed evaluation of annotations [2] to avoid syntax errors
on Python 3.7 and 3.8. See [3] for why it works.

[1] https://peps.python.org/pep-0585/
[2] https://peps.python.org/pep-0563/
[3] netromdk/vermin#66 (comment)
  • Loading branch information
Chih-Hsuan Yen committed Sep 2, 2022
1 parent 5c27908 commit 1726f6d
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 4 deletions.
2 changes: 2 additions & 0 deletions dnn-models/layer_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import logging
import math
from typing import Any
Expand Down
2 changes: 2 additions & 0 deletions dnn-models/model_utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import dataclasses
import logging
import math
Expand Down
10 changes: 6 additions & 4 deletions dnn-models/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from __future__ import annotations

import enum
import itertools
import logging
Expand All @@ -6,7 +8,7 @@
import sys
import tarfile
import zipfile
from typing import Callable, Dict, Iterable, List, NamedTuple, Optional
from typing import Callable, Iterable, NamedTuple, Optional
from urllib.request import urlretrieve

import filelock
Expand Down Expand Up @@ -98,13 +100,13 @@ def find_tensor_value_info(onnx_model: onnx.ModelProto, name: str) -> onnx.Value
return value_info
raise ValueError(f'No value_info found for {name}')

def find_node_by_output(nodes: List[onnx.NodeProto], output_name: str) -> onnx.NodeProto:
def find_node_by_output(nodes: list[onnx.NodeProto], output_name: str) -> onnx.NodeProto:
for node in nodes:
for output in node.output:
if output == output_name:
return node

def find_node_by_input(nodes: List[onnx.NodeProto], input_name: str) -> onnx.NodeProto:
def find_node_by_input(nodes: list[onnx.NodeProto], input_name: str) -> onnx.NodeProto:
for node in nodes:
for input_ in node.input:
if input_ == input_name:
Expand Down Expand Up @@ -404,7 +406,7 @@ def run_model(model, model_data, limit, verbose=True, save_file=None):
print(f'correct={correct} total={total} rate={accuracy}')
return accuracy

def remap_inputs(model: onnx.ModelProto, input_mapping: Dict[str, str]):
def remap_inputs(model: onnx.ModelProto, input_mapping: dict[str, str]):
new_inputs = list(input_mapping.values())
for new_input in new_inputs:
model.graph.input.append(onnx.ValueInfoProto(name=new_input))
Expand Down

0 comments on commit 1726f6d

Please sign in to comment.