Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: issues with interface/contract namespace collision #751

Merged
merged 6 commits into from
Sep 13, 2020
Merged
35 changes: 25 additions & 10 deletions brownie/project/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,14 @@ class Build:

def __init__(self, sources: Sources) -> None:
self._sources = sources
self._build: Dict = {}
self._contracts: Dict = {}
self._interfaces: Dict = {}

def _add(self, build_json: Dict) -> None:
def _add_contract(self, build_json: Dict) -> None:
contract_name = build_json["contractName"]
self._build[contract_name] = build_json
if contract_name in self._contracts and build_json["type"] == "interface":
return
self._contracts[contract_name] = build_json
if "pcMap" not in build_json:
# no pcMap means build artifact is for an interface
return
Expand All @@ -56,6 +59,10 @@ def _add(self, build_json: Dict) -> None:
build_json["pcMap"], build_json["allSourcePaths"], build_json["language"]
)

def _add_interface(self, build_json: Dict) -> None:
contract_name = build_json["contractName"]
self._interfaces[contract_name] = build_json

def _generate_revert_map(self, pcMap: Dict, source_map: Dict, language: str) -> None:
# Adds a contract's dev revert strings to the revert map and it's pcMap
marker = "//" if language == "Solidity" else "#"
Expand Down Expand Up @@ -96,29 +103,37 @@ def _generate_revert_map(self, pcMap: Dict, source_map: Dict, language: str) ->
continue
_revert_map[pc] = False

def _remove(self, contract_name: str) -> None:
del self._build[self._stem(contract_name)]
def _remove_contract(self, contract_name: str) -> None:
key = self._stem(contract_name)
if key in self._contracts:
del self._contracts[key]

def _remove_interface(self, contract_name: str) -> None:
key = self._stem(contract_name)
if key in self._interfaces:
del self._interfaces[key]

def get(self, contract_name: str) -> Dict:
"""Returns build data for the given contract name."""
return self._build[self._stem(contract_name)]
return self._contracts[self._stem(contract_name)]

def items(self, path: Optional[str] = None) -> Union[ItemsView, List]:
"""Provides an list of tuples as (key,value), similar to calling dict.items.
If a path is given, only contracts derived from that source file are returned."""
items = list(self._contracts.items()) + list(self._interfaces.items())
if path is None:
return self._build.items()
return [(k, v) for k, v in self._build.items() if v.get("sourcePath") == path]
return items
return [(k, v) for k, v in items if v.get("sourcePath") == path]

def contains(self, contract_name: str) -> bool:
"""Checks if the contract name exists in the currently loaded build data."""
return self._stem(contract_name) in self._build
return self._stem(contract_name) in list(self._contracts) + list(self._interfaces)

def get_dependents(self, contract_name: str) -> List:
"""Returns a list of contract names that inherit from or link to the given
contract. Used by the compiler when determining which contracts to recompile
based on a changed source file."""
return [k for k, v in self._build.items() if contract_name in v.get("dependencies", [])]
return [k for k, v in self._contracts.items() if contract_name in v.get("dependencies", [])]

def _stem(self, contract_name: str) -> str:
return contract_name.replace(".json", "")
Expand Down
10 changes: 4 additions & 6 deletions brownie/project/compiler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def generate_build_json(
if compiler_data is None:
compiler_data = {}
compiler_data["evm_version"] = input_json["settings"]["evmVersion"]
build_json = {}
build_json: Dict = {}
path_list = list(input_json["sources"])

if input_json["language"] == "Solidity":
Expand All @@ -299,6 +299,8 @@ def generate_build_json(
output_json["contracts"][path_str][contract_name].get("userdoc", {}),
)
output_evm = output_json["contracts"][path_str][contract_name]["evm"]
if contract_name in build_json and not output_evm["deployedBytecode"]["object"]:
continue

if input_json["language"] == "Solidity":
contract_node = next(
Expand Down Expand Up @@ -429,11 +431,7 @@ def get_abi(
to_compile = {k: v for k, v in contract_sources.items() if k in path_list}

set_solc_version(version)
input_json = generate_input_json(
to_compile,
language="Vyper" if version == "vyper" else "Solidity",
remappings=remappings,
)
input_json = generate_input_json(to_compile, language="Solidity", remappings=remappings)
input_json["settings"]["outputSelection"]["*"] = {"*": ["abi"]}

output_json = compile_from_input_json(input_json, silent, allow_paths)
Expand Down
40 changes: 21 additions & 19 deletions brownie/project/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _compile(self, contract_sources: Dict, compiler_config: Dict, silent: bool)
path = self._build_path.joinpath(f"contracts/{data['contractName']}.json")
with path.open("w") as fp:
json.dump(data, fp, sort_keys=True, indent=2, default=sorted)
self._build._add(data)
self._build._add_contract(data)

def _create_containers(self) -> None:
# create container objects
Expand Down Expand Up @@ -192,7 +192,7 @@ def load(self) -> None:
if not self._path.joinpath(build_json["sourcePath"]).exists():
path.unlink()
continue
self._build._add(build_json)
self._build._add_contract(build_json)

interface_hashes = {}
interface_list = self._sources.get_interface_list()
Expand All @@ -205,7 +205,7 @@ def load(self) -> None:
if not set(INTERFACE_KEYS).issubset(build_json) or path.stem not in interface_list:
path.unlink()
continue
self._build._add(build_json)
self._build._add_interface(build_json)
interface_hashes[path.stem] = build_json["sha1"]

self._compiler_config = _load_project_compiler_config(self._path)
Expand Down Expand Up @@ -238,21 +238,20 @@ def load(self) -> None:
def _get_changed_contracts(self, compiled_hashes: Dict) -> Dict:
# get list of changed interfaces and contracts
new_hashes = self._sources.get_interface_hashes()
interfaces = [k for k, v in new_hashes.items() if compiled_hashes.get(k, None) != v]
contracts = [i for i in self._sources.get_contract_list() if self._compare_build_json(i)]
# remove outdated build artifacts
for name in [k for k, v in new_hashes.items() if compiled_hashes.get(k, None) != v]:
self._build._remove_interface(name)

# get dependents of changed sources
final = set(contracts + interfaces)
for contract_name in list(final):
final.update(self._build.get_dependents(contract_name))
contracts = set(i for i in self._sources.get_contract_list() if self._compare_build_json(i))
for contract_name in list(contracts):
contracts.update(self._build.get_dependents(contract_name))

# remove outdated build artifacts
for name in [i for i in final if self._build.contains(i)]:
self._build._remove(name)
for name in contracts:
self._build._remove_contract(name)

# get final list of changed source paths
final.difference_update(interfaces)
changed_set: Set = set(self._sources.get_source_path(i) for i in final)
changed_set: Set = set(self._sources.get_source_path(i) for i in contracts)
return {i: self._sources.get(i) for i in changed_set}

def _compare_build_json(self, contract_name: str) -> bool:
Expand Down Expand Up @@ -283,7 +282,7 @@ def _compare_build_json(self, contract_name: str) -> bool:
def _compile_interfaces(self, compiled_hashes: Dict) -> None:
new_hashes = self._sources.get_interface_hashes()
changed_paths = [
self._sources.get_source_path(k)
self._sources.get_source_path(k, True)
for k, v in new_hashes.items()
if compiled_hashes.get(k, None) != v
]
Expand All @@ -302,7 +301,7 @@ def _compile_interfaces(self, compiled_hashes: Dict) -> None:

with self._build_path.joinpath(f"interfaces/{name}.json").open("w") as fp:
json.dump(abi, fp, sort_keys=True, indent=2, default=sorted)
self._build._add(abi)
self._build._add_interface(abi)

def _load_deployments(self) -> None:
if CONFIG.network_type != "live" and not CONFIG.settings["dev_deployment_artifacts"]:
Expand Down Expand Up @@ -510,12 +509,15 @@ def check_for_project(path: Union[Path, str] = ".") -> Optional[Path]:
for folder in [path] + list(path.parents):

structure_config = _load_project_structure_config(folder)
contracts_path = folder.joinpath(structure_config["contracts"])
tests_path = folder.joinpath(structure_config["tests"])
contracts = folder.joinpath(structure_config["contracts"])
interfaces = folder.joinpath(structure_config["interfaces"])
tests = folder.joinpath(structure_config["tests"])

if next((i for i in contracts_path.glob("**/*") if i.suffix in (".vy", ".sol")), None):
if next((i for i in contracts.glob("**/*") if i.suffix in (".vy", ".sol")), None):
return folder
if next((i for i in interfaces.glob("**/*") if i.suffix in (".json", ".vy", ".sol")), None):
return folder
if contracts_path.is_dir() and tests_path.is_dir():
if contracts.is_dir() and tests.is_dir():
return folder

return None
Expand Down
28 changes: 16 additions & 12 deletions brownie/project/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ class Sources:
"""Methods for accessing and manipulating a project's contract source files."""

def __init__(self, contract_sources: Dict, interface_sources: Dict) -> None:
self._source: Dict = {}
self._contract_sources: Dict = {}
self._contracts: Dict = {}
self._interface_sources: Dict = {}
self._interfaces: Dict = {}

contracts: Dict = {}
collisions: Dict = {}
for path, source in contract_sources.items():
self._source[path] = source
self._contract_sources[path] = source
if Path(path).suffix != ".sol":
contract_names = [(Path(path).stem, "contract")]
else:
Expand All @@ -42,7 +43,7 @@ def __init__(self, contract_sources: Dict, interface_sources: Dict) -> None:
self._contracts = {k: v[0] for k, v in contracts.items()}

for path, source in interface_sources.items():
self._source[path] = source
self._interface_sources[path] = source

if Path(path).suffix != ".sol":
interface_names = [(Path(path).stem, "interface")]
Expand Down Expand Up @@ -70,19 +71,19 @@ def get(self, key: str) -> str:
key = str(key)

if key in self._contracts:
return self._source[self._contracts[key]]
return self._contract_sources[self._contracts[key]]

if key not in self._source:
if key not in self._contract_sources:
# for sources outside this project (packages, other projects)
with Path(key).open() as fp:
source = fp.read()
self._source[key] = source
self._contract_sources[key] = source

return self._source[key]
return self._contract_sources[key]

def get_path_list(self) -> List:
"""Returns a sorted list of source code file paths for the active project."""
return sorted(self._source.keys())
return sorted(self._contract_sources.keys())

def get_contract_list(self) -> List:
"""Returns a sorted list of contract names for the active project."""
Expand All @@ -94,15 +95,18 @@ def get_interface_list(self) -> List:

def get_interface_hashes(self) -> Dict:
"""Returns a dict of interface hashes in the form of {name: hash}"""
return {k: sha1(self._source[v].encode()).hexdigest() for k, v in self._interfaces.items()}
return {
k: sha1(self._interface_sources[v].encode()).hexdigest()
for k, v in self._interfaces.items()
}

def get_interface_sources(self) -> Dict:
"""Returns a dict of interfaces sources in the form {path: source}"""
return {v: self._source[v] for v in self._interfaces.values()}
return {v: self._interface_sources[v] for v in self._interfaces.values()}

def get_source_path(self, contract_name: str) -> str:
def get_source_path(self, contract_name: str, is_interface: bool = False) -> str:
"""Returns the path to the source file where a contract is located."""
if contract_name in self._contracts:
if contract_name in self._contracts and not is_interface:
return self._contracts[contract_name]
if contract_name in self._interfaces:
return self._interfaces[contract_name]
Expand Down
12 changes: 0 additions & 12 deletions tests/project/main/test_recompile.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,18 +109,6 @@ def test_modify_library(mockproject):
]


# modifying an interface should recompile a dependent contract
def test_modify_interface(mockproject):
code = INTERFACE.split("\n")
code[3] = ""
code = "\n".join(code)
with mockproject._path.joinpath("interfaces/IFoo.sol").open("w") as fp:
fp.write(code)

mockproject.load()
assert sorted(mockproject._compile.call_args[0][0]) == ["contracts/Foo.sol"]


# modifying a base contract should recompile a dependent
def test_modify_base(mockproject):
code = BASE_CONTRACT.split("\n")
Expand Down
2 changes: 1 addition & 1 deletion tests/project/test_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_contract_interface_collisions(solc5source):


def test_get_path_list(sourceobj):
assert sourceobj.get_path_list() == ["interfaces/Baz.vy", "path/to/Foo.sol"]
assert sourceobj.get_path_list() == ["path/to/Foo.sol"]


def test_get_contract_list(sourceobj):
Expand Down