Skip to content

Commit

Permalink
add override for all instance of virtual, inherited functions
Browse files Browse the repository at this point in the history
  • Loading branch information
0xalpharush committed Mar 28, 2024
1 parent 759a4fc commit e7edac5
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 25 deletions.
4 changes: 3 additions & 1 deletion slither/core/declarations/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,14 +473,16 @@ def is_override(self) -> bool:
@property
def overridden_by(self) -> List["FunctionContract"]:
"""
List["FunctionContract"]: List offunctions in child contracts that override this function
List["FunctionContract"]: List of functions in child contracts that override this function
This may include distinct instances of the same function due to inheritance
"""
return self._overridden_by

@property
def overrides(self) -> List["FunctionContract"]:
"""
List["FunctionContract"]: List of functions in parent contracts that this function overrides
This may include distinct instances of the same function due to inheritance
"""
return self._overrides

Expand Down
24 changes: 11 additions & 13 deletions slither/solc_parsing/declarations/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,23 +242,21 @@ def _analyze_attributes(self) -> None:
self._function.payable = attributes["payable"]

if "baseFunctions" in attributes:
overrides_ids = []
for o_id in attributes["baseFunctions"]:
overrides_ids.append(o_id)
overrides_ids = attributes["baseFunctions"]
if len(overrides_ids) > 0:
found = 0
for c in self.contract_parser.underlying_contract.immediate_inheritance:
for f in c.functions_declared:
if f.id in overrides_ids:
for f_id in overrides_ids:
funcs = self.slither_parser.functions_by_id[f_id]
for f in funcs:
# Do not consider leaf contracts as overrides.
# B is A { function a() override {} } and C is A { function a() override {} } override A.a(), not each other.
if (
f.contract == self._function.contract
or f.contract in self._function.contract.inheritance
):
self._function.overrides.append(f)
f.overridden_by.append(self._function)
found += 1
# Search next parent if already found overridden func in this parent
continue
# Stop searching if we found all the overrides
if len(overrides_ids) == found:
break

# Attaches reference to override specifier e.g. X is referenced by `function a() override(X)`
if "overrides" in attributes and isinstance(attributes["overrides"], dict):
for override in attributes["overrides"].get("overrides", []):
refId = override["referencedDeclaration"]
Expand Down
9 changes: 8 additions & 1 deletion slither/solc_parsing/slither_compilation_unit_solc.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
import json
import logging
import os
Expand All @@ -7,7 +8,7 @@

from slither.analyses.data_dependency.data_dependency import compute_dependency
from slither.core.compilation_unit import SlitherCompilationUnit
from slither.core.declarations import Contract
from slither.core.declarations import Contract, Function
from slither.core.declarations.custom_error_top_level import CustomErrorTopLevel
from slither.core.declarations.enum_top_level import EnumTopLevel
from slither.core.declarations.event_top_level import EventTopLevel
Expand Down Expand Up @@ -80,6 +81,7 @@ def __init__(self, compilation_unit: SlitherCompilationUnit) -> None:
self._compilation_unit: SlitherCompilationUnit = compilation_unit

self._contracts_by_id: Dict[int, Contract] = {}
self._functions_by_id: Dict[int, List[Function]] = defaultdict(list)
self._parsed = False
self._analyzed = False
self._is_compact_ast = False
Expand All @@ -104,6 +106,7 @@ def all_functions_and_modifiers_parser(self) -> List[FunctionSolc]:

def add_function_or_modifier_parser(self, f: FunctionSolc) -> None:
self._all_functions_and_modifier_parser.append(f)
self._functions_by_id[f.underlying_function.id].append(f.underlying_function)

@property
def underlying_contract_to_parser(self) -> Dict[Contract, ContractSolc]:
Expand All @@ -117,6 +120,10 @@ def slither_parser(self) -> "SlitherCompilationUnitSolc":
def contracts_by_id(self) -> Dict[int, Contract]:
return self._contracts_by_id

@property
def functions_by_id(self) -> Dict[int, List[Function]]:
return self._functions_by_id

###################################################################################
###################################################################################
# region AST
Expand Down
34 changes: 24 additions & 10 deletions tests/unit/core/test_virtual_overrides.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def test_overrides(solc_binary_path) -> None:
x = test.get_functions_overridden_by(test_virtual_func)
assert len(x) == 0
x = test_virtual_func.overridden_by
assert len(x) == 3, [i.canonical_name for i in x]
assert set([i.canonical_name for i in x]) == set(
assert len(x) == 5
assert set(i.canonical_name for i in x) == set(
["A.myVirtualFunction()", "C.myVirtualFunction()", "X.myVirtualFunction()"]
)

Expand All @@ -25,18 +25,27 @@ def test_overrides(solc_binary_path) -> None:
assert a_virtual_func.is_virtual
assert a_virtual_func.is_override
x = a.get_functions_overridden_by(a_virtual_func)
assert len(x) == 1
assert x[0].canonical_name == "Test.myVirtualFunction()"
assert len(x) == 2
assert set(i.canonical_name for i in x) == set(["Test.myVirtualFunction()"])

b = slither.get_contract_from_name("B")[0]
b_virtual_func = b.get_function_from_full_name("myVirtualFunction()")
assert not b_virtual_func.is_virtual
assert b_virtual_func.is_override
x = b.get_functions_overridden_by(b_virtual_func)
assert len(x) == 1
assert x[0].canonical_name == "A.myVirtualFunction()"
assert len(x) == 2
assert set(i.canonical_name for i in x) == set(["A.myVirtualFunction()"])
assert len(b_virtual_func.overridden_by) == 0

c = slither.get_contract_from_name("C")[0]
c_virtual_func = c.get_function_from_full_name("myVirtualFunction()")
assert not c_virtual_func.is_virtual
assert c_virtual_func.is_override
x = c.get_functions_overridden_by(c_virtual_func)
assert len(x) == 2
# C should not override B as they are distinct leaves in the inheritance tree
assert set(i.canonical_name for i in x) == set(["Test.myVirtualFunction()"])

y = slither.get_contract_from_name("Y")[0]
y_virtual_func = y.get_function_from_full_name("myVirtualFunction()")
assert y_virtual_func.is_virtual
Expand All @@ -50,23 +59,28 @@ def test_overrides(solc_binary_path) -> None:
assert z_virtual_func.is_virtual
assert z_virtual_func.is_override
x = z.get_functions_overridden_by(z_virtual_func)
assert len(x) == 2
assert set([i.canonical_name for i in x]) == set(
assert len(x) == 4
assert set(i.canonical_name for i in x) == set(
["Y.myVirtualFunction()", "X.myVirtualFunction()"]
)

k = slither.get_contract_from_name("K")[0]
k_virtual_func = k.get_function_from_full_name("a()")
assert not k_virtual_func.is_virtual
assert k_virtual_func.is_override
assert len(k_virtual_func.overrides) == 1
assert len(k_virtual_func.overrides) == 3
x = k_virtual_func.overrides
assert set(i.canonical_name for i in x) == set(["I.a()"])

i = slither.get_contract_from_name("I")[0]
i_virtual_func = i.get_function_from_full_name("a()")
assert i_virtual_func.is_virtual
assert not i_virtual_func.is_override
assert len(i_virtual_func.overrides) == 0
assert len(i_virtual_func.overridden_by) == 1
x = i_virtual_func.overridden_by
assert len(x) == 1
assert x[0].canonical_name == "K.a()"


def test_virtual_override_references_and_implementations(solc_binary_path) -> None:
solc_path = solc_binary_path("0.8.15")
Expand Down

0 comments on commit e7edac5

Please sign in to comment.