Skip to content

Commit

Permalink
Merge pull request #751 from eth-brownie/fix-build-error
Browse files Browse the repository at this point in the history
fix: issues with interface/contract namespace collision
  • Loading branch information
iamdefinitelyahuman authored Sep 13, 2020
2 parents 5eec625 + 26e8456 commit ba14902
Show file tree
Hide file tree
Showing 6 changed files with 67 additions and 60 deletions.
35 changes: 25 additions & 10 deletions brownie/project/build.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,14 @@ class Build:

def __init__(self, sources: Sources) -> None:
self._sources = sources
self._build: Dict = {}
self._contracts: Dict = {}
self._interfaces: Dict = {}

def _add(self, build_json: Dict) -> None:
def _add_contract(self, build_json: Dict) -> None:
contract_name = build_json["contractName"]
self._build[contract_name] = build_json
if contract_name in self._contracts and build_json["type"] == "interface":
return
self._contracts[contract_name] = build_json
if "pcMap" not in build_json:
# no pcMap means build artifact is for an interface
return
Expand All @@ -56,6 +59,10 @@ def _add(self, build_json: Dict) -> None:
build_json["pcMap"], build_json["allSourcePaths"], build_json["language"]
)

def _add_interface(self, build_json: Dict) -> None:
contract_name = build_json["contractName"]
self._interfaces[contract_name] = build_json

def _generate_revert_map(self, pcMap: Dict, source_map: Dict, language: str) -> None:
# Adds a contract's dev revert strings to the revert map and it's pcMap
marker = "//" if language == "Solidity" else "#"
Expand Down Expand Up @@ -96,29 +103,37 @@ def _generate_revert_map(self, pcMap: Dict, source_map: Dict, language: str) ->
continue
_revert_map[pc] = False

def _remove(self, contract_name: str) -> None:
del self._build[self._stem(contract_name)]
def _remove_contract(self, contract_name: str) -> None:
key = self._stem(contract_name)
if key in self._contracts:
del self._contracts[key]

def _remove_interface(self, contract_name: str) -> None:
key = self._stem(contract_name)
if key in self._interfaces:
del self._interfaces[key]

def get(self, contract_name: str) -> Dict:
"""Returns build data for the given contract name."""
return self._build[self._stem(contract_name)]
return self._contracts[self._stem(contract_name)]

def items(self, path: Optional[str] = None) -> Union[ItemsView, List]:
"""Provides an list of tuples as (key,value), similar to calling dict.items.
If a path is given, only contracts derived from that source file are returned."""
items = list(self._contracts.items()) + list(self._interfaces.items())
if path is None:
return self._build.items()
return [(k, v) for k, v in self._build.items() if v.get("sourcePath") == path]
return items
return [(k, v) for k, v in items if v.get("sourcePath") == path]

def contains(self, contract_name: str) -> bool:
"""Checks if the contract name exists in the currently loaded build data."""
return self._stem(contract_name) in self._build
return self._stem(contract_name) in list(self._contracts) + list(self._interfaces)

def get_dependents(self, contract_name: str) -> List:
"""Returns a list of contract names that inherit from or link to the given
contract. Used by the compiler when determining which contracts to recompile
based on a changed source file."""
return [k for k, v in self._build.items() if contract_name in v.get("dependencies", [])]
return [k for k, v in self._contracts.items() if contract_name in v.get("dependencies", [])]

def _stem(self, contract_name: str) -> str:
return contract_name.replace(".json", "")
Expand Down
10 changes: 4 additions & 6 deletions brownie/project/compiler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -279,7 +279,7 @@ def generate_build_json(
if compiler_data is None:
compiler_data = {}
compiler_data["evm_version"] = input_json["settings"]["evmVersion"]
build_json = {}
build_json: Dict = {}
path_list = list(input_json["sources"])

if input_json["language"] == "Solidity":
Expand All @@ -299,6 +299,8 @@ def generate_build_json(
output_json["contracts"][path_str][contract_name].get("userdoc", {}),
)
output_evm = output_json["contracts"][path_str][contract_name]["evm"]
if contract_name in build_json and not output_evm["deployedBytecode"]["object"]:
continue

if input_json["language"] == "Solidity":
contract_node = next(
Expand Down Expand Up @@ -429,11 +431,7 @@ def get_abi(
to_compile = {k: v for k, v in contract_sources.items() if k in path_list}

set_solc_version(version)
input_json = generate_input_json(
to_compile,
language="Vyper" if version == "vyper" else "Solidity",
remappings=remappings,
)
input_json = generate_input_json(to_compile, language="Solidity", remappings=remappings)
input_json["settings"]["outputSelection"]["*"] = {"*": ["abi"]}

output_json = compile_from_input_json(input_json, silent, allow_paths)
Expand Down
40 changes: 21 additions & 19 deletions brownie/project/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ def _compile(self, contract_sources: Dict, compiler_config: Dict, silent: bool)
path = self._build_path.joinpath(f"contracts/{data['contractName']}.json")
with path.open("w") as fp:
json.dump(data, fp, sort_keys=True, indent=2, default=sorted)
self._build._add(data)
self._build._add_contract(data)

def _create_containers(self) -> None:
# create container objects
Expand Down Expand Up @@ -192,7 +192,7 @@ def load(self) -> None:
if not self._path.joinpath(build_json["sourcePath"]).exists():
path.unlink()
continue
self._build._add(build_json)
self._build._add_contract(build_json)

interface_hashes = {}
interface_list = self._sources.get_interface_list()
Expand All @@ -205,7 +205,7 @@ def load(self) -> None:
if not set(INTERFACE_KEYS).issubset(build_json) or path.stem not in interface_list:
path.unlink()
continue
self._build._add(build_json)
self._build._add_interface(build_json)
interface_hashes[path.stem] = build_json["sha1"]

self._compiler_config = _load_project_compiler_config(self._path)
Expand Down Expand Up @@ -238,21 +238,20 @@ def load(self) -> None:
def _get_changed_contracts(self, compiled_hashes: Dict) -> Dict:
# get list of changed interfaces and contracts
new_hashes = self._sources.get_interface_hashes()
interfaces = [k for k, v in new_hashes.items() if compiled_hashes.get(k, None) != v]
contracts = [i for i in self._sources.get_contract_list() if self._compare_build_json(i)]
# remove outdated build artifacts
for name in [k for k, v in new_hashes.items() if compiled_hashes.get(k, None) != v]:
self._build._remove_interface(name)

# get dependents of changed sources
final = set(contracts + interfaces)
for contract_name in list(final):
final.update(self._build.get_dependents(contract_name))
contracts = set(i for i in self._sources.get_contract_list() if self._compare_build_json(i))
for contract_name in list(contracts):
contracts.update(self._build.get_dependents(contract_name))

# remove outdated build artifacts
for name in [i for i in final if self._build.contains(i)]:
self._build._remove(name)
for name in contracts:
self._build._remove_contract(name)

# get final list of changed source paths
final.difference_update(interfaces)
changed_set: Set = set(self._sources.get_source_path(i) for i in final)
changed_set: Set = set(self._sources.get_source_path(i) for i in contracts)
return {i: self._sources.get(i) for i in changed_set}

def _compare_build_json(self, contract_name: str) -> bool:
Expand Down Expand Up @@ -283,7 +282,7 @@ def _compare_build_json(self, contract_name: str) -> bool:
def _compile_interfaces(self, compiled_hashes: Dict) -> None:
new_hashes = self._sources.get_interface_hashes()
changed_paths = [
self._sources.get_source_path(k)
self._sources.get_source_path(k, True)
for k, v in new_hashes.items()
if compiled_hashes.get(k, None) != v
]
Expand All @@ -302,7 +301,7 @@ def _compile_interfaces(self, compiled_hashes: Dict) -> None:

with self._build_path.joinpath(f"interfaces/{name}.json").open("w") as fp:
json.dump(abi, fp, sort_keys=True, indent=2, default=sorted)
self._build._add(abi)
self._build._add_interface(abi)

def _load_deployments(self) -> None:
if CONFIG.network_type != "live" and not CONFIG.settings["dev_deployment_artifacts"]:
Expand Down Expand Up @@ -510,12 +509,15 @@ def check_for_project(path: Union[Path, str] = ".") -> Optional[Path]:
for folder in [path] + list(path.parents):

structure_config = _load_project_structure_config(folder)
contracts_path = folder.joinpath(structure_config["contracts"])
tests_path = folder.joinpath(structure_config["tests"])
contracts = folder.joinpath(structure_config["contracts"])
interfaces = folder.joinpath(structure_config["interfaces"])
tests = folder.joinpath(structure_config["tests"])

if next((i for i in contracts_path.glob("**/*") if i.suffix in (".vy", ".sol")), None):
if next((i for i in contracts.glob("**/*") if i.suffix in (".vy", ".sol")), None):
return folder
if next((i for i in interfaces.glob("**/*") if i.suffix in (".json", ".vy", ".sol")), None):
return folder
if contracts_path.is_dir() and tests_path.is_dir():
if contracts.is_dir() and tests.is_dir():
return folder

return None
Expand Down
28 changes: 16 additions & 12 deletions brownie/project/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,14 +18,15 @@ class Sources:
"""Methods for accessing and manipulating a project's contract source files."""

def __init__(self, contract_sources: Dict, interface_sources: Dict) -> None:
self._source: Dict = {}
self._contract_sources: Dict = {}
self._contracts: Dict = {}
self._interface_sources: Dict = {}
self._interfaces: Dict = {}

contracts: Dict = {}
collisions: Dict = {}
for path, source in contract_sources.items():
self._source[path] = source
self._contract_sources[path] = source
if Path(path).suffix != ".sol":
contract_names = [(Path(path).stem, "contract")]
else:
Expand All @@ -42,7 +43,7 @@ def __init__(self, contract_sources: Dict, interface_sources: Dict) -> None:
self._contracts = {k: v[0] for k, v in contracts.items()}

for path, source in interface_sources.items():
self._source[path] = source
self._interface_sources[path] = source

if Path(path).suffix != ".sol":
interface_names = [(Path(path).stem, "interface")]
Expand Down Expand Up @@ -70,19 +71,19 @@ def get(self, key: str) -> str:
key = str(key)

if key in self._contracts:
return self._source[self._contracts[key]]
return self._contract_sources[self._contracts[key]]

if key not in self._source:
if key not in self._contract_sources:
# for sources outside this project (packages, other projects)
with Path(key).open() as fp:
source = fp.read()
self._source[key] = source
self._contract_sources[key] = source

return self._source[key]
return self._contract_sources[key]

def get_path_list(self) -> List:
"""Returns a sorted list of source code file paths for the active project."""
return sorted(self._source.keys())
return sorted(self._contract_sources.keys())

def get_contract_list(self) -> List:
"""Returns a sorted list of contract names for the active project."""
Expand All @@ -94,15 +95,18 @@ def get_interface_list(self) -> List:

def get_interface_hashes(self) -> Dict:
"""Returns a dict of interface hashes in the form of {name: hash}"""
return {k: sha1(self._source[v].encode()).hexdigest() for k, v in self._interfaces.items()}
return {
k: sha1(self._interface_sources[v].encode()).hexdigest()
for k, v in self._interfaces.items()
}

def get_interface_sources(self) -> Dict:
"""Returns a dict of interfaces sources in the form {path: source}"""
return {v: self._source[v] for v in self._interfaces.values()}
return {v: self._interface_sources[v] for v in self._interfaces.values()}

def get_source_path(self, contract_name: str) -> str:
def get_source_path(self, contract_name: str, is_interface: bool = False) -> str:
"""Returns the path to the source file where a contract is located."""
if contract_name in self._contracts:
if contract_name in self._contracts and not is_interface:
return self._contracts[contract_name]
if contract_name in self._interfaces:
return self._interfaces[contract_name]
Expand Down
12 changes: 0 additions & 12 deletions tests/project/main/test_recompile.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,18 +109,6 @@ def test_modify_library(mockproject):
]


# modifying an interface should recompile a dependent contract
def test_modify_interface(mockproject):
code = INTERFACE.split("\n")
code[3] = ""
code = "\n".join(code)
with mockproject._path.joinpath("interfaces/IFoo.sol").open("w") as fp:
fp.write(code)

mockproject.load()
assert sorted(mockproject._compile.call_args[0][0]) == ["contracts/Foo.sol"]


# modifying a base contract should recompile a dependent
def test_modify_base(mockproject):
code = BASE_CONTRACT.split("\n")
Expand Down
2 changes: 1 addition & 1 deletion tests/project/test_sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def test_contract_interface_collisions(solc5source):


def test_get_path_list(sourceobj):
assert sourceobj.get_path_list() == ["interfaces/Baz.vy", "path/to/Foo.sol"]
assert sourceobj.get_path_list() == ["path/to/Foo.sol"]


def test_get_contract_list(sourceobj):
Expand Down

0 comments on commit ba14902

Please sign in to comment.