Skip to content

Commit

Permalink
minor improvements
Browse files Browse the repository at this point in the history
  • Loading branch information
montyly committed Sep 15, 2023
1 parent 628f723 commit 4053c9b
Show file tree
Hide file tree
Showing 9 changed files with 252 additions and 262 deletions.
6 changes: 6 additions & 0 deletions slither/core/variables/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,8 @@
from .state_variable import StateVariable
from .variable import Variable
from .local_variable_init_from_tuple import LocalVariableInitFromTuple
from .local_variable import LocalVariable
from .top_level_variable import TopLevelVariable
from .event_variable import EventVariable
from .function_type_variable import FunctionTypeVariable
from .structure_variable import StructureVariable
3 changes: 2 additions & 1 deletion slither/printers/summary/ck.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
"""
from slither.printers.abstract_printer import AbstractPrinter
from slither.utils.ck import CKMetrics
from slither.utils.output import Output


class CK(AbstractPrinter):
Expand All @@ -40,7 +41,7 @@ class CK(AbstractPrinter):

WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#ck"

def output(self, _filename):
def output(self, _filename: str) -> Output:
if len(self.contracts) == 0:
return self.generate_output("No contract found")

Expand Down
3 changes: 2 additions & 1 deletion slither/printers/summary/halstead.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
"""
from slither.printers.abstract_printer import AbstractPrinter
from slither.utils.halstead import HalsteadMetrics
from slither.utils.output import Output


class Halstead(AbstractPrinter):
Expand All @@ -33,7 +34,7 @@ class Halstead(AbstractPrinter):

WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#halstead"

def output(self, _filename):
def output(self, _filename: str) -> Output:
if len(self.contracts) == 0:
return self.generate_output("No contract found")

Expand Down
3 changes: 2 additions & 1 deletion slither/printers/summary/martin.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
"""
from slither.printers.abstract_printer import AbstractPrinter
from slither.utils.martin import MartinMetrics
from slither.utils.output import Output


class Martin(AbstractPrinter):
Expand All @@ -19,7 +20,7 @@ class Martin(AbstractPrinter):

WIKI = "https://github.com/trailofbits/slither/wiki/Printer-documentation#martin"

def output(self, _filename):
def output(self, _filename: str) -> Output:
if len(self.contracts) == 0:
return self.generate_output("No contract found")

Expand Down
9 changes: 5 additions & 4 deletions slither/utils/ck.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,7 +110,7 @@ class CKContractMetrics:
dit: int = 0
cbo: int = 0

def __post_init__(self):
def __post_init__(self) -> None:
if not hasattr(self.contract, "functions"):
return
self.count_variables()
Expand All @@ -123,7 +123,7 @@ def __post_init__(self):

# pylint: disable=too-many-locals
# pylint: disable=too-many-branches
def calculate_metrics(self):
def calculate_metrics(self) -> None:
"""Calculate the metrics for a contract"""
rfc = self.public # initialize with public getter count
for func in self.contract.functions:
Expand Down Expand Up @@ -186,7 +186,7 @@ def calculate_metrics(self):
self.ext_calls += len(external_calls)
self.rfc = rfc

def count_variables(self):
def count_variables(self) -> None:
"""Count the number of variables in a contract"""
state_variable_count = 0
constant_count = 0
Expand Down Expand Up @@ -302,7 +302,7 @@ class CKMetrics:
("Core", "core", CORE_KEYS),
)

def __post_init__(self):
def __post_init__(self) -> None:
martin_metrics = MartinMetrics(self.contracts).contract_metrics
dependents = {
inherited.name: {
Expand All @@ -323,6 +323,7 @@ def __post_init__(self):
for contract in self.contracts
}

subtitle = ""
# Update each section
for (title, attr, keys) in self.SECTIONS:
if attr == "core":
Expand Down
202 changes: 202 additions & 0 deletions slither/utils/encoding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,202 @@
from typing import Union

from slither.core import variables
from slither.core.declarations import (
SolidityVariable,
SolidityVariableComposed,
Structure,
Enum,
Contract,
)
from slither.core import solidity_types
from slither.slithir import operations
from slither.slithir import variables as SlitherIRVariable


# pylint: disable=too-many-branches
def ntype(_type: Union[solidity_types.Type, str]) -> str:
if isinstance(_type, solidity_types.ElementaryType):
_type = str(_type)
elif isinstance(_type, solidity_types.ArrayType):
if isinstance(_type.type, solidity_types.ElementaryType):
_type = str(_type)
else:
_type = "user_defined_array"
elif isinstance(_type, Structure):
_type = str(_type)
elif isinstance(_type, Enum):
_type = str(_type)
elif isinstance(_type, solidity_types.MappingType):
_type = str(_type)
elif isinstance(_type, solidity_types.UserDefinedType):
if isinstance(_type.type, Contract):
_type = f"contract({_type.type.name})"
elif isinstance(_type.type, Structure):
_type = f"struct({_type.type.name})"
elif isinstance(_type.type, Enum):
_type = f"enum({_type.type.name})"
else:
_type = str(_type)

_type = _type.replace(" memory", "")
_type = _type.replace(" storage ref", "")

if "struct" in _type:
return "struct"
if "enum" in _type:
return "enum"
if "tuple" in _type:
return "tuple"
if "contract" in _type:
return "contract"
if "mapping" in _type:
return "mapping"
return _type.replace(" ", "_")


# pylint: disable=too-many-branches
def encode_var_for_compare(var: Union[variables.Variable, SolidityVariable]) -> str:

# variables
if isinstance(var, SlitherIRVariable.Constant):
return f"constant({ntype(var.type)},{var.value})"
if isinstance(var, SolidityVariableComposed):
return f"solidity_variable_composed({var.name})"
if isinstance(var, SolidityVariable):
return f"solidity_variable{var.name}"
if isinstance(var, SlitherIRVariable.TemporaryVariable):
return "temporary_variable"
if isinstance(var, SlitherIRVariable.ReferenceVariable):
return f"reference({ntype(var.type)})"
if isinstance(var, variables.LocalVariable):
return f"local_solc_variable({ntype(var.type)},{var.location})"
if isinstance(var, variables.StateVariable):
if not (var.is_constant or var.is_immutable):
try:
slot, _ = var.contract.compilation_unit.storage_layout_of(var.contract, var)
except KeyError:
slot = var.name
else:
slot = var.name
return f"state_solc_variable({ntype(var.type)},{slot})"
if isinstance(var, variables.LocalVariableInitFromTuple):
return "local_variable_init_tuple"
if isinstance(var, SlitherIRVariable.TupleVariable):
return "tuple_variable"

# default
return ""


# pylint: disable=too-many-branches
def encode_ir_for_upgradeability_compare(ir: operations.Operation) -> str:
# operations
if isinstance(ir, operations.Assignment):
return f"({encode_var_for_compare(ir.lvalue)}):=({encode_var_for_compare(ir.rvalue)})"
if isinstance(ir, operations.Index):
return f"index({ntype(ir.variable_right.type)})"
if isinstance(ir, operations.Member):
return "member" # .format(ntype(ir._type))
if isinstance(ir, operations.Length):
return "length"
if isinstance(ir, operations.Binary):
return f"binary({encode_var_for_compare(ir.variable_left)}{ir.type}{encode_var_for_compare(ir.variable_right)})"
if isinstance(ir, operations.Unary):
return f"unary({str(ir.type)})"
if isinstance(ir, operations.Condition):
return f"condition({encode_var_for_compare(ir.value)})"
if isinstance(ir, operations.NewStructure):
return "new_structure"
if isinstance(ir, operations.NewContract):
return "new_contract"
if isinstance(ir, operations.NewArray):
return f"new_array({ntype(ir.array_type)})"
if isinstance(ir, operations.NewElementaryType):
return f"new_elementary({ntype(ir.type)})"
if isinstance(ir, operations.Delete):
return f"delete({encode_var_for_compare(ir.lvalue)},{encode_var_for_compare(ir.variable)})"
if isinstance(ir, operations.SolidityCall):
return f"solidity_call({ir.function.full_name})"
if isinstance(ir, operations.InternalCall):
return f"internal_call({ntype(ir.type_call)})"
if isinstance(ir, operations.EventCall): # is this useful?
return "event"
if isinstance(ir, operations.LibraryCall):
return "library_call"
if isinstance(ir, operations.InternalDynamicCall):
return "internal_dynamic_call"
if isinstance(ir, operations.HighLevelCall): # TODO: improve
return "high_level_call"
if isinstance(ir, operations.LowLevelCall): # TODO: improve
return "low_level_call"
if isinstance(ir, operations.TypeConversion):
return f"type_conversion({ntype(ir.type)})"
if isinstance(ir, operations.Return): # this can be improved using values
return "return" # .format(ntype(ir.type))
if isinstance(ir, operations.Transfer):
return f"transfer({encode_var_for_compare(ir.call_value)})"
if isinstance(ir, operations.Send):
return f"send({encode_var_for_compare(ir.call_value)})"
if isinstance(ir, operations.Unpack): # TODO: improve
return "unpack"
if isinstance(ir, operations.InitArray): # TODO: improve
return "init_array"

# default
return ""


def encode_ir_for_halstead(ir: operations.Operation) -> str:
# operations
if isinstance(ir, operations.Assignment):
return "assignment"
if isinstance(ir, operations.Index):
return "index"
if isinstance(ir, operations.Member):
return "member" # .format(ntype(ir._type))
if isinstance(ir, operations.Length):
return "length"
if isinstance(ir, operations.Binary):
return f"binary({str(ir.type)})"
if isinstance(ir, operations.Unary):
return f"unary({str(ir.type)})"
if isinstance(ir, operations.Condition):
return f"condition({encode_var_for_compare(ir.value)})"
if isinstance(ir, operations.NewStructure):
return "new_structure"
if isinstance(ir, operations.NewContract):
return "new_contract"
if isinstance(ir, operations.NewArray):
return f"new_array({ntype(ir.array_type)})"
if isinstance(ir, operations.NewElementaryType):
return f"new_elementary({ntype(ir.type)})"
if isinstance(ir, operations.Delete):
return "delete"
if isinstance(ir, operations.SolidityCall):
return f"solidity_call({ir.function.full_name})"
if isinstance(ir, operations.InternalCall):
return f"internal_call({ntype(ir.type_call)})"
if isinstance(ir, operations.EventCall): # is this useful?
return "event"
if isinstance(ir, operations.LibraryCall):
return "library_call"
if isinstance(ir, operations.InternalDynamicCall):
return "internal_dynamic_call"
if isinstance(ir, operations.HighLevelCall): # TODO: improve
return "high_level_call"
if isinstance(ir, operations.LowLevelCall): # TODO: improve
return "low_level_call"
if isinstance(ir, operations.TypeConversion):
return f"type_conversion({ntype(ir.type)})"
if isinstance(ir, operations.Return): # this can be improved using values
return "return" # .format(ntype(ir.type))
if isinstance(ir, operations.Transfer):
return "transfer"
if isinstance(ir, operations.Send):
return "send"
if isinstance(ir, operations.Unpack): # TODO: improve
return "unpack"
if isinstance(ir, operations.InitArray): # TODO: improve
return "init_array"
# default
raise NotImplementedError(f"encode_ir_for_halstead: {ir}")
22 changes: 13 additions & 9 deletions slither/utils/halstead.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,17 @@
"""
import math
from collections import OrderedDict
from dataclasses import dataclass, field
from typing import Tuple, List, Dict
from collections import OrderedDict

from slither.core.declarations import Contract
from slither.slithir.variables.temporary import TemporaryVariable
from slither.utils.encoding import encode_ir_for_halstead
from slither.utils.myprettytable import make_pretty_table, MyPrettyTable
from slither.utils.upgradeability import encode_ir_for_halstead


# pylint: disable=too-many-branches


@dataclass
Expand All @@ -55,7 +59,7 @@ class HalsteadContractMetrics:
T: float = 0
B: float = 0

def __post_init__(self):
def __post_init__(self) -> None:
"""Operators and operands can be passed in as constructor args to avoid computing
them based on the contract. Useful for computing metrics for ALL_CONTRACTS"""

Expand Down Expand Up @@ -85,7 +89,7 @@ def to_dict(self) -> Dict[str, float]:
}
)

def populate_operators_and_operands(self):
def populate_operators_and_operands(self) -> None:
"""Populate the operators and operands lists."""
operators = []
operands = []
Expand All @@ -104,7 +108,7 @@ def populate_operators_and_operands(self):
self.all_operators.extend(operators)
self.all_operands.extend(operands)

def compute_metrics(self, all_operators=None, all_operands=None):
def compute_metrics(self, all_operators=None, all_operands=None) -> None:
"""Compute the Halstead metrics."""
if all_operators is None:
all_operators = self.all_operators
Expand Down Expand Up @@ -183,17 +187,17 @@ class HalsteadMetrics:
("Extended 2/2", "extended2", EXTENDED2_KEYS),
)

def __post_init__(self):
def __post_init__(self) -> None:
# Compute the metrics for each contract and for all contracts.
self.update_contract_metrics()
self.add_all_contracts_metrics()
self.update_reporting_sections()

def update_contract_metrics(self):
def update_contract_metrics(self) -> None:
for contract in self.contracts:
self.contract_metrics[contract.name] = HalsteadContractMetrics(contract=contract)

def add_all_contracts_metrics(self):
def add_all_contracts_metrics(self) -> None:
# If there are more than 1 contract, compute the metrics for all contracts.
if len(self.contracts) <= 1:
return
Expand All @@ -211,7 +215,7 @@ def add_all_contracts_metrics(self):
None, all_operators=all_operators, all_operands=all_operands
)

def update_reporting_sections(self):
def update_reporting_sections(self) -> None:
# Create the table and text for each section.
data = {
contract.name: self.contract_metrics[contract.name].to_dict()
Expand Down
Loading

0 comments on commit 4053c9b

Please sign in to comment.