Skip to content

Commit

Permalink
Add more types hints
Browse files Browse the repository at this point in the history
  • Loading branch information
montyly committed Feb 16, 2023
1 parent 20a7951 commit 371e3cb
Show file tree
Hide file tree
Showing 49 changed files with 422 additions and 256 deletions.
63 changes: 17 additions & 46 deletions slither/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def process_single(
args: argparse.Namespace,
detector_classes: List[Type[AbstractDetector]],
printer_classes: List[Type[AbstractPrinter]],
) -> Tuple[Slither, List[Dict], List[Dict], int]:
) -> Tuple[Slither, List[Dict], List[Output], int]:
"""
The core high-level code for running Slither static analysis.
Expand All @@ -89,7 +89,7 @@ def process_all(
args: argparse.Namespace,
detector_classes: List[Type[AbstractDetector]],
printer_classes: List[Type[AbstractPrinter]],
) -> Tuple[List[Slither], List[Dict], List[Dict], int]:
) -> Tuple[List[Slither], List[Dict], List[Output], int]:
compilations = compile_all(target, **vars(args))
slither_instances = []
results_detectors = []
Expand Down Expand Up @@ -144,23 +144,6 @@ def _process(
return slither, results_detectors, results_printers, analyzed_contracts_count


# TODO: delete me?
def process_from_asts(
filenames: List[str],
args: argparse.Namespace,
detector_classes: List[Type[AbstractDetector]],
printer_classes: List[Type[AbstractPrinter]],
) -> Tuple[Slither, List[Dict], List[Dict], int]:
all_contracts: List[str] = []

for filename in filenames:
with open(filename, encoding="utf8") as file_open:
contract_loaded = json.load(file_open)
all_contracts.append(contract_loaded["ast"])

return process_single(all_contracts, args, detector_classes, printer_classes)


# endregion
###################################################################################
###################################################################################
Expand Down Expand Up @@ -608,9 +591,6 @@ def parse_args(
default=False,
)

# if the json is splitted in different files
parser.add_argument("--splitted", help=argparse.SUPPRESS, action="store_true", default=False)

# Disable the throw/catch on partial analyses
parser.add_argument(
"--disallow-partial", help=argparse.SUPPRESS, action="store_true", default=False
Expand All @@ -626,7 +606,7 @@ def parse_args(
args.filter_paths = parse_filter_paths(args)

# Verify our json-type output is valid
args.json_types = set(args.json_types.split(","))
args.json_types = set(args.json_types.split(",")) # type:ignore
for json_type in args.json_types:
if json_type not in JSON_OUTPUT_TYPES:
raise Exception(f'Error: "{json_type}" is not a valid JSON result output type.')
Expand Down Expand Up @@ -697,14 +677,14 @@ def __call__(


class FormatterCryticCompile(logging.Formatter):
def format(self, record):
def format(self, record: logging.LogRecord) -> str:
# for i, msg in enumerate(record.msg):
if record.msg.startswith("Compilation warnings/errors on "):
txt = record.args[1]
txt = txt.split("\n")
txt = record.args[1] # type:ignore
txt = txt.split("\n") # type:ignore
txt = [red(x) if "Error" in x else x for x in txt]
txt = "\n".join(txt)
record.args = (record.args[0], txt)
record.args = (record.args[0], txt) # type:ignore
return super().format(record)


Expand Down Expand Up @@ -747,7 +727,7 @@ def main_impl(
set_colorization_enabled(False if args.disable_color else sys.stdout.isatty())

# Define some variables for potential JSON output
json_results = {}
json_results: Dict[str, Any] = {}
output_error = None
outputting_json = args.json is not None
outputting_json_stdout = args.json == "-"
Expand Down Expand Up @@ -796,7 +776,7 @@ def main_impl(
crytic_compile_error.setLevel(logging.INFO)

results_detectors: List[Dict] = []
results_printers: List[Dict] = []
results_printers: List[Output] = []
try:
filename = args.filename

Expand All @@ -809,26 +789,17 @@ def main_impl(
number_contracts = 0

slither_instances = []
if args.splitted:
for filename in filenames:
(
slither_instance,
results_detectors,
results_printers,
number_contracts,
) = process_from_asts(filenames, args, detector_classes, printer_classes)
results_detectors_tmp,
results_printers_tmp,
number_contracts_tmp,
) = process_single(filename, args, detector_classes, printer_classes)
number_contracts += number_contracts_tmp
results_detectors += results_detectors_tmp
results_printers += results_printers_tmp
slither_instances.append(slither_instance)
else:
for filename in filenames:
(
slither_instance,
results_detectors_tmp,
results_printers_tmp,
number_contracts_tmp,
) = process_single(filename, args, detector_classes, printer_classes)
number_contracts += number_contracts_tmp
results_detectors += results_detectors_tmp
results_printers += results_printers_tmp
slither_instances.append(slither_instance)

# Rely on CryticCompile to discern the underlying type of compilations.
else:
Expand Down
88 changes: 70 additions & 18 deletions slither/analyses/data_dependency/data_dependency.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from collections import defaultdict
from typing import Union, Set, Dict, TYPE_CHECKING

from slither.core.cfg.node import Node
from slither.core.declarations import (
Contract,
Enum,
Expand All @@ -12,6 +13,7 @@
SolidityVariable,
SolidityVariableComposed,
Structure,
FunctionContract,
)
from slither.core.declarations.solidity_import_placeholder import SolidityImportPlaceHolder
from slither.core.variables.top_level_variable import TopLevelVariable
Expand Down Expand Up @@ -40,25 +42,37 @@


Variable_types = Union[Variable, SolidityVariable]
# TODO refactor the data deps to be better suited for top level function object
# Right now we allow to pass a node to ease the API, but we need something
# better
# The deps propagation for top level elements is also not working as expected
Context_types_API = Union[Contract, Function, Node]
Context_types = Union[Contract, Function]


def is_dependent(
variable: Variable_types,
source: Variable_types,
context: Context_types,
context: Context_types_API,
only_unprotected: bool = False,
) -> bool:
"""
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
Args:
variable (Variable)
source (Variable)
context (Contract|Function)
context (Contract|Function|Node).
only_unprotected (bool): True only unprotected function are considered
Returns:
bool
"""
assert isinstance(context, (Contract, Function))
assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func

if isinstance(variable, Constant):
return False
if variable == source:
Expand All @@ -76,10 +90,13 @@ def is_dependent(
def is_dependent_ssa(
variable: Variable_types,
source: Variable_types,
context: Context_types,
context: Context_types_API,
only_unprotected: bool = False,
) -> bool:
"""
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
Args:
variable (Variable)
taint (Variable)
Expand All @@ -88,7 +105,10 @@ def is_dependent_ssa(
Returns:
bool
"""
assert isinstance(context, (Contract, Function))
assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
context_dict = context.context
if isinstance(variable, Constant):
return False
Expand All @@ -112,19 +132,25 @@ def is_dependent_ssa(

def is_tainted(
variable: Variable_types,
context: Context_types,
context: Context_types_API,
only_unprotected: bool = False,
ignore_generic_taint: bool = False,
) -> bool:
"""
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
Args:
variable
context (Contract|Function)
only_unprotected (bool): True only unprotected function are considered
Returns:
bool
"""
assert isinstance(context, (Contract, Function))
assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool)
if isinstance(variable, Constant):
return False
Expand All @@ -139,19 +165,25 @@ def is_tainted(

def is_tainted_ssa(
variable: Variable_types,
context: Context_types,
context: Context_types_API,
only_unprotected: bool = False,
ignore_generic_taint: bool = False,
):
) -> bool:
"""
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
Args:
variable
context (Contract|Function)
only_unprotected (bool): True only unprotected function are considered
Returns:
bool
"""
assert isinstance(context, (Contract, Function))
assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool)
if isinstance(variable, Constant):
return False
Expand All @@ -166,35 +198,45 @@ def is_tainted_ssa(

def get_dependencies(
variable: Variable_types,
context: Context_types,
context: Context_types_API,
only_unprotected: bool = False,
) -> Set[Variable]:
"""
Return the variables for which `variable` depends on.
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
:param variable: The target
:param context: Either a function (interprocedural) or a contract (inter transactional)
:param only_unprotected: True if consider only protected functions
:return: set(Variable)
"""
assert isinstance(context, (Contract, Function))
assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool)
if only_unprotected:
return context.context[KEY_NON_SSA_UNPROTECTED].get(variable, set())
return context.context[KEY_NON_SSA].get(variable, set())


def get_all_dependencies(
context: Context_types, only_unprotected: bool = False
context: Context_types_API, only_unprotected: bool = False
) -> Dict[Variable, Set[Variable]]:
"""
Return the dictionary of dependencies.
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
:param context: Either a function (interprocedural) or a contract (inter transactional)
:param only_unprotected: True if consider only protected functions
:return: Dict(Variable, set(Variable))
"""
assert isinstance(context, (Contract, Function))
assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool)
if only_unprotected:
return context.context[KEY_NON_SSA_UNPROTECTED]
Expand All @@ -203,35 +245,45 @@ def get_all_dependencies(

def get_dependencies_ssa(
variable: Variable_types,
context: Context_types,
context: Context_types_API,
only_unprotected: bool = False,
) -> Set[Variable]:
"""
Return the variables for which `variable` depends on (SSA version).
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
:param variable: The target (must be SSA variable)
:param context: Either a function (interprocedural) or a contract (inter transactional)
:param only_unprotected: True if consider only protected functions
:return: set(Variable)
"""
assert isinstance(context, (Contract, Function))
assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool)
if only_unprotected:
return context.context[KEY_SSA_UNPROTECTED].get(variable, set())
return context.context[KEY_SSA].get(variable, set())


def get_all_dependencies_ssa(
context: Context_types, only_unprotected: bool = False
context: Context_types_API, only_unprotected: bool = False
) -> Dict[Variable, Set[Variable]]:
"""
Return the dictionary of dependencies.
If Node is provided as context, the context will be the broader context, either the contract or the function,
depending on if the node is in a top level function or not
:param context: Either a function (interprocedural) or a contract (inter transactional)
:param only_unprotected: True if consider only protected functions
:return: Dict(Variable, set(Variable))
"""
assert isinstance(context, (Contract, Function))
assert isinstance(context, (Contract, Function, Node))
if isinstance(context, Node):
func = context.function
context = func.contract if isinstance(func, FunctionContract) else func
assert isinstance(only_unprotected, bool)
if only_unprotected:
return context.context[KEY_SSA_UNPROTECTED]
Expand Down
Loading

0 comments on commit 371e3cb

Please sign in to comment.