Skip to content

Commit

Permalink
feat: improve ast output (#2824)
Browse files Browse the repository at this point in the history
Modify AST output format to be partially folded (only performs folding
of builtin constants and builtin functions) and after type annotation
and validation. The purpose is to expose type information on expressions
to downstream tooling.

Partial fix for #2276
  • Loading branch information
tserg authored May 5, 2022
1 parent df2da62 commit 3cbdf35
Show file tree
Hide file tree
Showing 9 changed files with 66 additions and 14 deletions.
5 changes: 4 additions & 1 deletion tests/cli/vyper_json/test_output_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,10 @@ def test_keys():
assert sorted(output_json.keys()) == ["compiler", "contracts", "sources"]
assert output_json["compiler"] == f"vyper-{vyper.__version__}"
data = compiler_data["foo.vy"]
assert output_json["sources"]["foo.vy"] == {"id": 0, "ast": data["ast_dict"]["ast"]}
assert output_json["sources"]["foo.vy"] == {
"id": 0,
"ast": data["ast_dict"]["ast"],
}
assert output_json["contracts"]["foo.vy"]["foo"] == {
"abi": data["abi"],
"devdoc": data["devdoc"],
Expand Down
2 changes: 2 additions & 0 deletions tests/parser/ast_utils/test_ast_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,9 @@ def test_basic_ast():
"lineno": 2,
"node_id": 2,
"src": "1:1:0",
"type": "int128",
},
"type": "int128",
"value": None,
}

Expand Down
4 changes: 4 additions & 0 deletions vyper/ast/nodes.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,6 +389,10 @@ def to_dict(self) -> dict:
ast_dict[key] = [_to_dict(i) for i in value]
else:
ast_dict[key] = _to_dict(value)

if "type" in self._metadata:
ast_dict["type"] = str(self._metadata["type"])

return ast_dict

def get_ancestor(self, node_type: Union["VyperNode", tuple, None] = None) -> "VyperNode":
Expand Down
28 changes: 18 additions & 10 deletions vyper/builtin_functions/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,12 @@
SHA256_PER_WORD_GAS = 12


class _SimpleBuiltinFunction:
class _BuiltinFunction:
def __repr__(self):
return f"builtin function {self._id}"


class _SimpleBuiltinFunction(_BuiltinFunction):
def fetch_call_return(self, node):
validate_call_args(node, len(self._inputs), getattr(self, "_kwargs", []))
for arg, (_, expected) in zip(node.args, self._inputs):
Expand Down Expand Up @@ -168,7 +173,7 @@ def build_IR(self, expr, args, kwargs, context):
)


class Convert:
class Convert(_BuiltinFunction):

_id = "convert"

Expand Down Expand Up @@ -246,7 +251,7 @@ def _build_adhoc_slice_node(sub: IRnode, start: IRnode, length: IRnode, context:
return IRnode.from_list(node, typ=ByteArrayType(length.value), location=MEMORY)


class Slice:
class Slice(_BuiltinFunction):

_id = "slice"
_inputs = [("b", ("Bytes", "bytes32", "String")), ("start", "uint256"), ("length", "uint256")]
Expand Down Expand Up @@ -451,7 +456,7 @@ def build_IR(self, node, context):
return get_bytearray_length(arg)


class Concat:
class Concat(_BuiltinFunction):

_id = "concat"

Expand Down Expand Up @@ -685,7 +690,7 @@ def build_IR(self, expr, args, kwargs, context):
)


class MethodID:
class MethodID(_BuiltinFunction):

_id = "method_id"

Expand Down Expand Up @@ -948,7 +953,7 @@ def build_IR(self, expr, args, kwargs, context):
)


class AsWeiValue:
class AsWeiValue(_BuiltinFunction):

_id = "as_wei_value"
_inputs = [("value", NumericAbstractType()), ("unit", "str_literal")]
Expand Down Expand Up @@ -1248,7 +1253,7 @@ def build_IR(self, expr, args, kwargs, contact):
)


class RawLog:
class RawLog(_BuiltinFunction):

_id = "raw_log"
_inputs = [("topics", "*"), ("data", ("bytes32", "Bytes"))]
Expand Down Expand Up @@ -1649,11 +1654,14 @@ def build_IR(self, expr, args, kwargs, context):
)


class _UnsafeMath:
class _UnsafeMath(_BuiltinFunction):

# TODO add unsafe math for `decimal`s
_inputs = [("a", IntegerAbstractType()), ("b", IntegerAbstractType())]

def __repr__(self):
return f"builtin function unsafe_{self.op}"

def fetch_call_return(self, node):
validate_call_args(node, 2)

Expand Down Expand Up @@ -1711,7 +1719,7 @@ class UnsafeDiv(_UnsafeMath):
op = "div"


class _MinMax:
class _MinMax(_BuiltinFunction):

_inputs = [("a", NumericAbstractType()), ("b", NumericAbstractType())]

Expand Down Expand Up @@ -1852,7 +1860,7 @@ def build_IR(self, expr, args, kwargs, context):
)


class Empty:
class Empty(_BuiltinFunction):

_id = "empty"
_inputs = [("typename", "*")]
Expand Down
7 changes: 6 additions & 1 deletion vyper/cli/vyper_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -302,7 +302,12 @@ def compile_files(
output_formats = combined_json_outputs
show_version = True

translate_map = {"abi_python": "abi", "json": "abi", "ast": "ast_dict", "ir_json": "ir_dict"}
translate_map = {
"abi_python": "abi",
"json": "abi",
"ast": "ast_dict",
"ir_json": "ir_dict",
}
final_formats = [translate_map.get(i, i) for i in output_formats]

compiler_data = vyper.compile_codes(
Expand Down
2 changes: 1 addition & 1 deletion vyper/compiler/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
def build_ast_dict(compiler_data: CompilerData) -> dict:
ast_dict = {
"contract_name": compiler_data.contract_name,
"ast": ast_to_dict(compiler_data.vyper_module),
"ast": ast_to_dict(compiler_data.vyper_module_unfolded),
}
return ast_dict

Expand Down
25 changes: 25 additions & 0 deletions vyper/compiler/phases.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,17 @@ def vyper_module(self) -> vy_ast.Module:

return self._vyper_module

@property
def vyper_module_unfolded(self) -> vy_ast.Module:
# This phase is intended to generate an AST for tooling use, and is not
# used in the compilation process.
if not hasattr(self, "_vyper_module_unfolded"):
self._vyper_module_unfolded = generate_unfolded_ast(
self.vyper_module, self.interface_codes
)

return self._vyper_module_unfolded

@property
def vyper_module_folded(self) -> vy_ast.Module:
if not hasattr(self, "_vyper_module_folded"):
Expand Down Expand Up @@ -184,6 +195,20 @@ def generate_ast(source_code: str, source_id: int, contract_name: str) -> vy_ast
return vy_ast.parse_to_ast(source_code, source_id, contract_name)


def generate_unfolded_ast(
vyper_module: vy_ast.Module,
interface_codes: Optional[InterfaceImports],
) -> vy_ast.Module:

vy_ast.validation.validate_literal_nodes(vyper_module)
vy_ast.folding.replace_builtin_constants(vyper_module)
vy_ast.folding.replace_builtin_functions(vyper_module)
# note: validate_semantics does type inference on the AST
validate_semantics(vyper_module, interface_codes)

return vyper_module


def generate_folded_ast(
vyper_module: vy_ast.Module,
interface_codes: Optional[InterfaceImports],
Expand Down
3 changes: 2 additions & 1 deletion vyper/semantics/types/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,7 +123,8 @@ def __init__(
self.nonreentrant = nonreentrant

def __repr__(self):
return f"contract function '{self.name}'"
arg_types = ",".join(repr(a) for a in self.arguments.values())
return f"contract function {self.name}({arg_types})"

@classmethod
def from_abi(cls, abi: Dict) -> "ContractFunction":
Expand Down
4 changes: 4 additions & 0 deletions vyper/semantics/types/user/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def __init__(self, name: str, arguments: OrderedDict, indexed: List) -> None:
self.indexed = indexed
self.event_id = int(keccak256(self.signature.encode()).hex(), 16)

def __repr__(self):
arg_types = ",".join(repr(a) for a in self.arguments.values())
return f"event {self.name}({arg_types})"

@property
def signature(self):
return f"{self.name}({','.join(v.canonical_abi_type for v in self.arguments.values())})"
Expand Down

0 comments on commit 3cbdf35

Please sign in to comment.