Skip to content

Commit

Permalink
More precise import detection + better type handling
Browse files Browse the repository at this point in the history
  • Loading branch information
bart1e committed Feb 15, 2023
1 parent d342d8c commit 2268bb0
Showing 1 changed file with 63 additions and 33 deletions.
96 changes: 63 additions & 33 deletions slither/detectors/statements/unused_imports.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
from logging import Logger
from typing import Dict, Set

from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.declarations import Contract
Expand Down Expand Up @@ -28,18 +29,18 @@ class UnusedImports(AbstractDetector):
)
WIKI_RECOMMENDATION = "Remove unused imports"

analysed_files = set()
analysed_files: Set[str] = set()

def __init__(
self, compilation_unit: SlitherCompilationUnit, slither: "Slither", logger: Logger
):
super().__init__(compilation_unit, slither, logger)
self.needed_imports = self._new_dict()
self.actual_imports = self._new_dict()
self.absolute_path_to_imp_filename = {}
self.absolute_path_to_imp_filename: Dict[str, str] = {}
self.imports_cycle_detected = False

def _dfs(self, graph: dict, color: dict, x: str):
def _dfs(self, graph: Dict[str, Set[str]], color: Dict[str, int], x: str) -> None:
"""
Simple DFS that checks for a cycle in a given graph.
"""
Expand All @@ -51,11 +52,11 @@ def _dfs(self, graph: dict, color: dict, x: str):
self.imports_cycle_detected = True
color[x] = 2

def _detect_cycles(self, graph: dict):
def _detect_cycles(self, graph: Dict[str, Set[str]]) -> None:
"""
Detects cycle in a directed graph.
"""
vertex_color = {}
vertex_color: Dict[str, int] = {}
for k in graph.keys():
vertex_color[k] = 0
for k in graph.keys():
Expand All @@ -65,24 +66,24 @@ def _detect_cycles(self, graph: dict):
if self.imports_cycle_detected:
return

def _new_dict(self):
def _new_dict(self) -> Dict[str, Set[str]]:
"""
Helper method. Creates a dictionary with all input files as keys and empty sets as values.
"""
dictionary = {}
dictionary: Dict[str, Set[str]] = {}
for file in self.compilation_unit.scopes:
dictionary[file.absolute] = set()
return dictionary

def _import_to_absolute_path(self, imp: str):
def _import_to_absolute_path(self, imp: str) -> str:
"""
Converts an import to the absolute path of that import.
Useful for cases like "import @openzeppelin/...".
"""
return self.compilation_unit.crytic_compile.filename_lookup(imp).absolute

@staticmethod
def _add_import(imps: dict, key: str, value: str):
def _add_import(imps: Dict[str, Set[str]], key: str, value: str) -> None:
"""
Adds import entries to imps dict.
Keys are absolute paths to files (as strings) and each key points to a set of absolute paths (as strings).
Expand All @@ -91,7 +92,7 @@ def _add_import(imps: dict, key: str, value: str):
return
imps[key].add(value)

def _add_custom_type(self, v: Variable):
def _add_custom_type(self, v: Variable) -> None:
"""
Adds user defined types to self.needed_imports.
"""
Expand All @@ -102,7 +103,7 @@ def _add_custom_type(self, v: Variable):
v.type.source_mapping.filename.absolute,
)

def _add_item_by_references(self, item: SourceMapping):
def _add_item_by_references(self, item: SourceMapping) -> None:
"""
Adds all uses of item to self.needed_imports by its references.
"""
Expand All @@ -111,62 +112,62 @@ def _add_item_by_references(self, item: SourceMapping):
self.needed_imports, ref.filename.absolute, item.source_mapping.filename.absolute
)

def _initialise_actual_imports(self):
def _initialise_actual_imports(self) -> None:
"""
After running this function, self.actual_imports contains, for each file, a set of files imported by it.
"""
for imp in self.compilation_unit.import_directives:
import_path = self._import_to_absolute_path(imp.filename)
self._add_import(self.actual_imports, imp.source_mapping.filename.absolute, import_path)

def _initialise_absolute_path_to_imp_filename(self):
def _initialise_absolute_path_to_imp_filename(self) -> None:
for imp in self.compilation_unit.import_directives:
import_path = self._import_to_absolute_path(imp.filename)
self.absolute_path_to_imp_filename[import_path] = imp.filename

def _find_top_level_structs_uses(self):
def _find_top_level_structs_uses(self) -> None:
"""
For each top level struct, analyse all of its uses.
"""
for st in self.compilation_unit.structures_top_level:
self._add_item_by_references(st)

def _find_top_level_custom_types_uses(self):
def _find_top_level_custom_types_uses(self) -> None:
"""
For each top level user defined type, analyse all of its uses.
"""
for _, ct in self.compilation_unit.user_defined_value_types.items():
self._add_item_by_references(ct)

def _find_top_level_enums_uses(self):
def _find_top_level_enums_uses(self) -> None:
"""
For each top level enum, analyse all of its uses.
"""
for en in self.compilation_unit.enums_top_level:
self._add_item_by_references(en)

def _find_top_level_constants_uses(self):
def _find_top_level_constants_uses(self) -> None:
"""
For each top level constant, analyse all of its uses.
"""
for var in self.compilation_unit.variables_top_level:
self._add_item_by_references(var)

def _find_top_level_custom_errors_uses(self):
def _find_top_level_custom_errors_uses(self) -> None:
"""
For each top level custom error, analyse all of its uses.
"""
for err in self.compilation_unit.custom_errors:
self._add_item_by_references(err)

def _find_top_level_functions_uses(self):
def _find_top_level_functions_uses(self) -> None:
"""
For each top level function, analyse all of its uses.
"""
for f in self.compilation_unit.functions_top_level:
self._add_item_by_references(f)

def _find_top_level_items_uses(self):
def _find_top_level_items_uses(self) -> None:
"""
Finds all uses of top level items, excluding contracts, libraries and interfaces.
These include:
Expand All @@ -185,42 +186,42 @@ def _find_top_level_items_uses(self):
self._find_top_level_custom_errors_uses()
self._find_top_level_functions_uses()

def _find_contract_level_structs_uses(self, c: Contract):
def _find_contract_level_structs_uses(self, c: Contract) -> None:
"""
For each contract level struct, analyse all of its uses.
"""
for st in c.structures:
self._add_item_by_references(st)

def _find_contract_level_custom_types_uses(self, c: Contract):
def _find_contract_level_custom_types_uses(self, c: Contract) -> None:
"""
For each contract level user defined type, analyse all of its uses.
"""
for _, ct in c.file_scope.user_defined_types.items():
self._add_item_by_references(ct)

def _find_contract_level_enums_uses(self, c: Contract):
def _find_contract_level_enums_uses(self, c: Contract) -> None:
"""
For each contract level enum, analyse all of its uses.
"""
for en in c.enums:
self._add_item_by_references(en)

def _find_contract_level_variables_uses(self, c: Contract):
def _find_contract_level_variables_uses(self, c: Contract) -> None:
"""
For each contract level variable, analyse all of its uses.
"""
for var in c.variables:
self._add_item_by_references(var)

def _find_contract_level_custom_errors_uses(self, c: Contract):
def _find_contract_level_custom_errors_uses(self, c: Contract) -> None:
"""
For each contract level custom error, analyse all of its uses.
"""
for err in c.custom_errors:
self._add_item_by_references(err)

def _find_contract_level_functions_and_modifiers_uses(self, c: Contract):
def _find_contract_level_functions_and_modifiers_uses(self, c: Contract) -> None:
"""
For each contract level function / modifier, analyse all of its uses.
"""
Expand All @@ -234,14 +235,14 @@ def _find_contract_level_functions_and_modifiers_uses(self, c: Contract):
)
self._add_item_by_references(fm)

def _find_contract_level_custom_events_uses(self, c: Contract):
def _find_contract_level_custom_events_uses(self, c: Contract) -> None:
"""
For each custom event (cannot be top level at the moment), analyse all of its uses.
"""
for ev in c.events:
self._add_item_by_references(ev)

def _find_contract_level_items_uses(self):
def _find_contract_level_items_uses(self) -> None:
"""
Finds all uses of items in contracts, libraries and interfaces.
These include:
Expand Down Expand Up @@ -276,7 +277,9 @@ def _find_contract_level_items_uses(self):
)

@staticmethod
def _add_all_imports_for_file(actual_imports: dict, all_imports: dict, file: str):
def _add_all_imports_for_file(
actual_imports: Dict[str, Set[str]], all_imports: Dict[str, Set[str]], file: str
) -> None:
"""
For a certain file, adds all files from its import graph to all_imports, including files imported indirectly.
For instance, if we have:
Expand All @@ -298,7 +301,7 @@ def _add_all_imports_for_file(actual_imports: dict, all_imports: dict, file: str
for imp in actual_imports[file]:
UnusedImports._add_import(all_imports, file, imp)

def _get_all_imports(self) -> dict:
def _get_all_imports(self) -> Dict[str, Set[str]]:
"""
Returns a dict, that for each file contains all files from its import graph, including files imported
indirectly.
Expand All @@ -324,7 +327,7 @@ def _get_all_imports(self) -> dict:
UnusedImports._add_all_imports_for_file(self.actual_imports, all_imports, k)
return all_imports

def _get_all_imports_in_needed(self, all_imports: dict) -> dict:
def _get_all_imports_in_needed(self, all_imports: Dict[str, Set[str]]) -> Dict[str, Set[str]]:
"""
Returns a dict, that for each file, holds all files imported directly or indirectly by any of imports that this
file needs.
Expand All @@ -348,7 +351,33 @@ def _get_all_imports_in_needed(self, all_imports: dict) -> dict:
all_imports_in_needed[k] |= all_imports[imp]
return all_imports_in_needed

def _get_needed_but_not_imported_directly(self) -> dict:
def _remove_redundant_imports_from_needed(self, all_imports: Dict[str, Set[str]]) -> None:
"""
Removes some imports from the needed_imports dict. The resulting dict, for each key k, contains the smallest
number of items such that the following equality holds:
all_imports[old_needed_imports[k]] = all_imports[new_needed_imports[k]],
where old_needed_imports is needed_imports before the function call and new_needed_imports - after the call.
The reasoning behind this is, that as long as the above equality holds, we can remove some imports from
needed_imports, while still importing the same set of files.
For instance, consider the following example:
C.sol:
import B.sol // B.sol needed
import A.sol // A.sol needed
B.sol:
import A.sol
Before the function call, needed_imports[C.sol] = {A.sol, B.sol}, since both files are used in C.sol.
However, by importing B.sol, C.sol automatically imports A.sol as well, so it doesn't need to import A.sol
directly. So, after the call: needed_imports[C.sol] = {B.sol}.
"""
for _, vs in self.needed_imports.items():
to_remove: Set[str] = set()
for v in vs:
to_remove |= all_imports[v]
for imp in to_remove:
if imp in vs:
vs.remove(imp)

def _get_needed_but_not_imported_directly(self) -> Dict[str, Set[str]]:
"""
Returns a dict, that for each file, contains files that are needed by it, but not imported directly (via
"import" statement). These files can, however, be imported indirectly.
Expand All @@ -365,7 +394,7 @@ def _get_needed_but_not_imported_directly(self) -> dict:
needed_but_not_imported_directly[k] = v - self.actual_imports[k]
return needed_but_not_imported_directly

def _get_imported_but_unneeded(self) -> dict:
def _get_imported_but_unneeded(self) -> Dict[str, Set[str]]:
"""
Returns a dict, that for each file, contains files that are directly imported by it (via "import" statement),
but are not needed (even, if their own imports are needed).
Expand Down Expand Up @@ -397,6 +426,7 @@ def _detect(self):

all_imports = self._get_all_imports()
all_imports_in_needed = self._get_all_imports_in_needed(all_imports)
self._remove_redundant_imports_from_needed(all_imports)

needed_but_not_imported_directly = self._get_needed_but_not_imported_directly()
imported_but_unneeded = self._get_imported_but_unneeded()
Expand Down

0 comments on commit 2268bb0

Please sign in to comment.