Skip to content
This repository has been archived by the owner on Dec 15, 2023. It is now read-only.

Commit

Permalink
Adapt to breaking changes from cairo-lang 0.7.1 (#38)
Browse files Browse the repository at this point in the history
  • Loading branch information
FabijanC authored Feb 18, 2022
1 parent 5cde962 commit 50b60cf
Show file tree
Hide file tree
Showing 7 changed files with 132 additions and 50 deletions.
13 changes: 12 additions & 1 deletion starknet_devnet/origin.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ def get_transaction(self, transaction_hash: str):
"""Returns the transaction object."""
raise NotImplementedError

def get_transaction_receipt(self, transaction_hash: str):
"""Returns the transaction receipt object."""
raise NotImplementedError

def get_block_by_hash(self, block_hash: str):
"""Returns the block identified with either its hash."""
raise NotImplementedError
Expand Down Expand Up @@ -51,8 +55,15 @@ def get_transaction_status(self, transaction_hash: str):

def get_transaction(self, transaction_hash: str):
return {
"status": TxStatus.NOT_RECEIVED.name
}

def get_transaction_receipt(self, transaction_hash: str):
return {
"l2_to_l1_messages": [],
"status": TxStatus.NOT_RECEIVED.name,
"transaction_hash": transaction_hash
"transaction_hash": transaction_hash,
"events": []
}

def get_block_by_hash(self, block_hash: str):
Expand Down
10 changes: 3 additions & 7 deletions starknet_devnet/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ async def add_transaction():
"""Endpoint for accepting DEPLOY and INVOKE_FUNCTION transactions."""

transaction = validate_transaction(request.data)

tx_type = transaction.tx_type.name

if tx_type == TransactionType.DEPLOY.name:
Expand All @@ -48,20 +47,18 @@ async def add_transaction():

return jsonify({
"code": StarkErrorCode.TRANSACTION_RECEIVED.name,
"transaction_hash": fixed_length_hex(transaction_hash),
"transaction_hash": hex(transaction_hash),
"address": fixed_length_hex(contract_address),
**result_dict
})

def validate_transaction(data: bytes):
def validate_transaction(data: bytes, loader: Transaction=Transaction):
"""Ensure `data` is a valid Starknet transaction. Returns the parsed `Transaction`."""

try:
transaction = Transaction.loads(data)
transaction = loader.loads(data)
except (TypeError, ValidationError) as err:
msg = f"Invalid tx: {err}\nBe sure to use the correct compilation (json) artifact. Devnet-compatible cairo-lang version: {CAIRO_LANG_VERSION}"
abort(Response(msg, 400))

return transaction

@app.route("/feeder_gateway/get_contract_addresses", methods=["GET"])
Expand Down Expand Up @@ -105,7 +102,6 @@ def _check_block_hash(request_args: MultiDict):
@app.route("/feeder_gateway/get_block", methods=["GET"])
async def get_block():
"""Endpoint for retrieving a block identified by its hash or number."""

block_hash = request.args.get("blockHash")
block_number = request.args.get("blockNumber", type=custom_int)

Expand Down
61 changes: 38 additions & 23 deletions starknet_devnet/starknet_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from copy import deepcopy
from typing import Dict

from starkware.starknet.business_logic.internal_transaction import InternalDeploy, InternalInvokeFunction, InternalTransaction
from starkware.starknet.business_logic.internal_transaction import InternalInvokeFunction
from starkware.starknet.business_logic.state import CarriedState
from starkware.starknet.definitions.transaction_type import TransactionType
from starkware.starknet.services.api.gateway.transaction import InvokeFunction, Transaction
from starkware.starknet.services.api.gateway.contract_address import calculate_contract_address
from starkware.starknet.services.api.gateway.transaction import InvokeFunction, Deploy, Transaction
from starkware.starknet.testing.starknet import Starknet
from starkware.starknet.testing.objects import StarknetTransactionExecutionInfo
from starkware.starkware_utils.error_handling import StarkException
Expand Down Expand Up @@ -105,43 +106,52 @@ def set_origin(self, origin: Origin):
"""Set the origin chain."""
self.__origin = origin

async def deploy(self, transaction: Transaction):
async def deploy(self, deploy_transaction: Deploy):
"""
Deploys the contract specified with `transaction`.
Returns (contract_address, transaction_hash).
"""

state = await self.__get_state()
deploy_transaction: InternalDeploy = InternalDeploy.from_external(transaction, state.general_config)
contract_definition = deploy_transaction.contract_definition
tx_hash = deploy_transaction.calculate_hash(state.general_config)
contract_address = calculate_contract_address(
caller_address=0,
constructor_calldata=deploy_transaction.constructor_calldata,
salt=deploy_transaction.contract_address_salt,
contract_definition=deploy_transaction.contract_definition
)

starknet = await self.get_starknet()

if deploy_transaction.contract_address not in self.__address2contract_wrapper:
if contract_address not in self.__address2contract_wrapper:
try:
contract = await starknet.deploy(
contract_def=deploy_transaction.contract_definition,
contract_def=contract_definition,
constructor_calldata=deploy_transaction.constructor_calldata,
contract_address_salt=deploy_transaction.contract_address_salt
)
execution_info = contract.deploy_execution_info
error_message = None
status = TxStatus.ACCEPTED_ON_L2

self.__address2contract_wrapper[contract.contract_address] = ContractWrapper(contract, deploy_transaction.contract_definition)
self.__address2contract_wrapper[contract.contract_address] = ContractWrapper(contract, contract_definition)
await self.__update_state()
except StarkException as err:
error_message = err.message
status = TxStatus.REJECTED
execution_info = DummyExecutionInfo()

await self.__store_transaction(
internal_tx=deploy_transaction,
transaction=deploy_transaction,
contract_address=contract_address,
tx_hash=tx_hash,
status=status,
execution_info=execution_info,
error_message=error_message
)

return deploy_transaction.contract_address, deploy_transaction.hash_value
return contract_address, tx_hash

async def invoke(self, transaction: InvokeFunction):
"""Perform invoke according to specifications in `transaction`."""
Expand All @@ -166,7 +176,9 @@ async def invoke(self, transaction: InvokeFunction):
adapted_result = []

await self.__store_transaction(
internal_tx=invoke_transaction,
transaction=invoke_transaction,
contract_address=transaction.contract_address,
tx_hash=invoke_transaction.hash_value,
status=status,
execution_info=execution_info,
error_message=error_message
Expand Down Expand Up @@ -227,15 +239,11 @@ def get_transaction(self, transaction_hash: str):
def get_transaction_receipt(self, transaction_hash: str):
"""Returns the transaction receipt of the transaction identified by `transaction_hash`."""

tx_hash_int = int(transaction_hash,16)
tx_hash_int = int(transaction_hash, 16)
if tx_hash_int in self.__transaction_wrappers:
return self.__transaction_wrappers[tx_hash_int].receipt

return {
"l2_to_l1_messages": [],
"status": TxStatus.NOT_RECEIVED.name,
"transaction_hash": transaction_hash
}
return self.__origin.get_transaction_receipt(transaction_hash)

def get_number_of_blocks(self) -> int:
"""Returns the number of blocks stored so far."""
Expand Down Expand Up @@ -277,7 +285,7 @@ async def __generate_block(self, tx_wrapper: TransactionWrapper):
"state_root": state_root.hex(),
"status": TxStatus.ACCEPTED_ON_L2.name,
"timestamp": timestamp,
"transaction_receipts": [tx_wrapper.receipt],
"transaction_receipts": [tx_wrapper.get_receipt_block_variant()],
"transactions": [tx_wrapper.transaction["transaction"]],
}

Expand Down Expand Up @@ -318,16 +326,23 @@ def get_block_by_number(self, block_number: int):

return self.__origin.get_block_by_number(block_number)

async def __store_transaction(self, internal_tx: InternalTransaction, status: TxStatus,
# pylint: disable=too-many-arguments
async def __store_transaction(self, transaction: Transaction, contract_address: int, tx_hash: int, status: TxStatus,
execution_info: StarknetTransactionExecutionInfo, error_message: str=None
):
"""Stores the provided data as a deploy transaction in `self.transactions`."""
if internal_tx.tx_type == TransactionType.DEPLOY:
tx_wrapper = DeployTransactionWrapper(internal_tx, status, execution_info)
elif internal_tx.tx_type == TransactionType.INVOKE_FUNCTION:
tx_wrapper = InvokeTransactionWrapper(internal_tx, status, execution_info)
if transaction.tx_type == TransactionType.DEPLOY:
tx_wrapper = DeployTransactionWrapper(
transaction=transaction,
contract_address=contract_address,
tx_hash=tx_hash,
status=status,
execution_info=execution_info
)
elif transaction.tx_type == TransactionType.INVOKE_FUNCTION:
tx_wrapper = InvokeTransactionWrapper(transaction, status, execution_info)
else:
raise StarknetDevnetException(message=f"Illegal tx_type: {internal_tx.tx_type}")
raise StarknetDevnetException(message=f"Illegal tx_type: {transaction.tx_type}")

if status == TxStatus.REJECTED:
assert error_message, "error_message must be present if tx rejected"
Expand Down
45 changes: 35 additions & 10 deletions starknet_devnet/transaction_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
"""

from abc import ABC, abstractmethod
from copy import deepcopy
from dataclasses import dataclass
from typing import List

from starkware.starknet.business_logic.internal_transaction import InternalDeploy, InternalInvokeFunction
from starkware.starknet.business_logic.internal_transaction import InternalInvokeFunction
from starkware.starknet.services.api.gateway.transaction import Deploy
from starkware.starknet.definitions.error_codes import StarknetErrorCode
from starkware.starknet.definitions.transaction_type import TransactionType
from starkware.starknet.testing.objects import StarknetTransactionExecutionInfo
Expand Down Expand Up @@ -38,6 +40,20 @@ class InvokeTransactionDetails(TransactionDetails):
calldata: List[str]
signature: List[str]
entry_point_selector: str
entry_point_type: str

def get_events(execution_info: StarknetTransactionExecutionInfo):
"""Extract events if any; stringify content."""
if not hasattr(execution_info, "raw_events"):
return []
events = []
for event in execution_info.raw_events:
events.append({
"from_address": hex(event.from_address),
"data": [str(d) for d in event.data],
"keys": [str(key) for key in event.keys]
})
return events

class TransactionWrapper(ABC):
"""Transaction Wrapper base class."""
Expand All @@ -48,9 +64,7 @@ def __init__(
):
self.transaction_hash = tx_details.transaction_hash

events = []
if hasattr(execution_info, "raw_events"):
events = execution_info.raw_events
events = get_events(execution_info)

self.transaction = {
"status": status.name,
Expand All @@ -72,6 +86,15 @@ def set_block_data(self, block_hash: str, block_number: int):
self.transaction["block_hash"] = self.receipt["block_hash"] = block_hash
self.transaction["block_number"] = self.receipt["block_number"] = block_number

def get_receipt_block_variant(self):
"""
Receipt is a part of get_block response, but somewhat modified.
This method returns the modified version.
"""
receipt = deepcopy(self.receipt)
del receipt["status"]
return receipt

def set_failure_reason(self, error_message: str):
"""Sets the failure reason to transaction and receipt dicts."""
assert error_message
Expand All @@ -88,16 +111,17 @@ def set_failure_reason(self, error_message: str):
class DeployTransactionWrapper(TransactionWrapper):
"""Wrapper of Deploy Transaction."""

def __init__(self, internal_tx: InternalDeploy, status: TxStatus, execution_info: StarknetTransactionExecutionInfo):
# pylint: disable=too-many-arguments
def __init__(self, transaction: Deploy, contract_address: int, tx_hash: int, status: TxStatus, execution_info: StarknetTransactionExecutionInfo):
super().__init__(
status,
execution_info,
DeployTransactionDetails(
TransactionType.DEPLOY.name,
contract_address=fixed_length_hex(internal_tx.contract_address),
transaction_hash=fixed_length_hex(internal_tx.hash_value),
constructor_calldata=[str(arg) for arg in internal_tx.constructor_calldata],
contract_address_salt=hex(internal_tx.contract_address_salt)
contract_address=fixed_length_hex(contract_address),
transaction_hash=fixed_length_hex(tx_hash),
constructor_calldata=[str(arg) for arg in transaction.constructor_calldata],
contract_address_salt=hex(transaction.contract_address_salt)
)
)

Expand All @@ -114,7 +138,8 @@ def __init__(self, internal_tx: InternalInvokeFunction, status: TxStatus, execut
contract_address=fixed_length_hex(internal_tx.contract_address),
transaction_hash=fixed_length_hex(internal_tx.hash_value),
calldata=[str(arg) for arg in internal_tx.calldata],
entry_point_selector=str(internal_tx.entry_point_selector),
entry_point_selector=fixed_length_hex(internal_tx.entry_point_selector),
entry_point_type=internal_tx.entry_point_type.name,
signature=[str(sig_part) for sig_part in internal_tx.signature]
)
)
8 changes: 4 additions & 4 deletions test/expected/invoke_receipt_event.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,12 @@
"events": [
{
"data": [
0,
10
"0",
"10"
],
"from_address": 3588522635734560402301506863768658247518049453423086523878030385873504815898,
"from_address": "0x6eceb2feb8e5c474e78f114e589cf089329de072898053d2685f64fd38c39e6",
"keys": [
1744303484486821561902174603220722448499782664094942993128426674277214273437
"1744303484486821561902174603220722448499782664094942993128426674277214273437"
]
}
]
Expand Down
12 changes: 9 additions & 3 deletions test/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@

from .util import (
assert_negative_block_input,
assert_transaction_not_received,
assert_transaction_receipt_not_received,
run_devnet_in_background,
assert_block, assert_contract_code, assert_equal, assert_failing_deploy, assert_receipt, assert_salty_deploy,
assert_storage, assert_transaction, assert_tx_status, assert_events,
Expand All @@ -17,14 +19,17 @@
EVENTS_ABI_PATH = f"{ARTIFACTS_PATH}/events.cairo/events_abi.json"
FAILING_CONTRACT_PATH = f"{ARTIFACTS_PATH}/always_fail.cairo/always_fail.json"

EXPECTED_SALTY_DEPLOY_ADDRESS="0x07ef082652cf5e336e98971981f2ef9a32d5673c822898c344b213f51449cb1a"
EXPECTED_SALTY_DEPLOY_ADDRESS = "0x06eceb2feb8e5c474e78f114e589cf089329de072898053d2685f64fd38c39e6"
EXPECTED_SALTY_DEPLOY_HASH = "0x44f6722fc95c1512a183759502ab4d5b9bae15fb26f3f957db486f79abdd829"
NONEXISTENT_TX_HASH = "0x1"

run_devnet_in_background(sleep_seconds=1)
deploy_info = deploy(CONTRACT_PATH, ["0"])
print("Deployment:", deploy_info)

assert_tx_status(deploy_info["tx_hash"], "ACCEPTED_ON_L2")
assert_transaction(deploy_info["tx_hash"], "ACCEPTED_ON_L2")
assert_transaction_not_received(NONEXISTENT_TX_HASH)

# check storage after deployment
BALANCE_KEY = "916907772491729262376534102982219947830828984996257231353398618781993312401"
Expand All @@ -34,6 +39,7 @@
assert_negative_block_input()
assert_block(0, deploy_info["tx_hash"])
assert_receipt(deploy_info["tx_hash"], "test/expected/deploy_receipt.json")
assert_transaction_receipt_not_received(NONEXISTENT_TX_HASH)

# check code
assert_contract_code(deploy_info["address"])
Expand Down Expand Up @@ -72,7 +78,7 @@
salt="0x99",
inputs=None,
expected_address=EXPECTED_SALTY_DEPLOY_ADDRESS,
expected_tx_hash="0x03e3c1a20f6b175b812bb14df175fb8a9e352ea7b38d7b942489968a8a4a9dd0"
expected_tx_hash=EXPECTED_SALTY_DEPLOY_HASH
)

salty_invoke_tx_hash = invoke(
Expand All @@ -82,6 +88,6 @@
inputs=["10"]
)

assert_events(salty_invoke_tx_hash,"test/expected/invoke_receipt_event.json")
assert_events(salty_invoke_tx_hash, "test/expected/invoke_receipt_event.json")

assert_failing_deploy(contract_path=FAILING_CONTRACT_PATH)
Loading

0 comments on commit 50b60cf

Please sign in to comment.