Skip to content
This repository has been archived by the owner on Dec 15, 2023. It is now read-only.

Adapt to breaking changes from cairo-lang 0.7.1 #38

Merged
merged 2 commits into from
Feb 18, 2022
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions starknet_devnet/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ async def add_transaction():
"""Endpoint for accepting DEPLOY and INVOKE_FUNCTION transactions."""

transaction = validate_transaction(request.data)

tx_type = transaction.tx_type.name

if tx_type == TransactionType.DEPLOY.name:
Expand All @@ -53,15 +52,13 @@ async def add_transaction():
**result_dict
})

def validate_transaction(data: bytes):
def validate_transaction(data: bytes, loader: Transaction=Transaction):
"""Ensure `data` is a valid Starknet transaction. Returns the parsed `Transaction`."""

try:
transaction = Transaction.loads(data)
transaction = loader.loads(data)
except (TypeError, ValidationError) as err:
msg = f"Invalid tx: {err}\nBe sure to use the correct compilation (json) artifact. Devnet-compatible cairo-lang version: {CAIRO_LANG_VERSION}"
abort(Response(msg, 400))

return transaction

@app.route("/feeder_gateway/get_contract_addresses", methods=["GET"])
Expand Down
51 changes: 35 additions & 16 deletions starknet_devnet/starknet_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,10 +7,11 @@
from copy import deepcopy
from typing import Dict

from starkware.starknet.business_logic.internal_transaction import InternalDeploy, InternalInvokeFunction, InternalTransaction
from starkware.starknet.business_logic.internal_transaction import InternalInvokeFunction
from starkware.starknet.business_logic.state import CarriedState
from starkware.starknet.definitions.transaction_type import TransactionType
from starkware.starknet.services.api.gateway.transaction import InvokeFunction, Transaction
from starkware.starknet.services.api.gateway.contract_address import calculate_contract_address
from starkware.starknet.services.api.gateway.transaction import InvokeFunction, Deploy, Transaction
from starkware.starknet.testing.starknet import Starknet
from starkware.starknet.testing.objects import StarknetTransactionExecutionInfo
from starkware.starkware_utils.error_handling import StarkException
Expand Down Expand Up @@ -105,43 +106,52 @@ def set_origin(self, origin: Origin):
"""Set the origin chain."""
self.__origin = origin

async def deploy(self, transaction: Transaction):
async def deploy(self, deploy_transaction: Deploy):
"""
Deploys the contract specified with `transaction`.
Returns (contract_address, transaction_hash).
"""

state = await self.__get_state()
deploy_transaction: InternalDeploy = InternalDeploy.from_external(transaction, state.general_config)
contract_definition = deploy_transaction.contract_definition
tx_hash = deploy_transaction.calculate_hash(state.general_config)
contract_address = calculate_contract_address(
caller_address=0,
constructor_calldata=deploy_transaction.constructor_calldata,
salt=deploy_transaction.contract_address_salt,
contract_definition=deploy_transaction.contract_definition
)

starknet = await self.get_starknet()

if deploy_transaction.contract_address not in self.__address2contract_wrapper:
if contract_address not in self.__address2contract_wrapper:
try:
contract = await starknet.deploy(
contract_def=deploy_transaction.contract_definition,
contract_def=contract_definition,
constructor_calldata=deploy_transaction.constructor_calldata,
contract_address_salt=deploy_transaction.contract_address_salt
)
execution_info = contract.deploy_execution_info
error_message = None
status = TxStatus.ACCEPTED_ON_L2

self.__address2contract_wrapper[contract.contract_address] = ContractWrapper(contract, deploy_transaction.contract_definition)
self.__address2contract_wrapper[contract.contract_address] = ContractWrapper(contract, contract_definition)
await self.__update_state()
except StarkException as err:
error_message = err.message
status = TxStatus.REJECTED
execution_info = DummyExecutionInfo()

await self.__store_transaction(
internal_tx=deploy_transaction,
transaction=deploy_transaction,
contract_address=contract_address,
tx_hash=tx_hash,
status=status,
execution_info=execution_info,
error_message=error_message
)

return deploy_transaction.contract_address, deploy_transaction.hash_value
return contract_address, tx_hash

async def invoke(self, transaction: InvokeFunction):
"""Perform invoke according to specifications in `transaction`."""
Expand All @@ -166,7 +176,9 @@ async def invoke(self, transaction: InvokeFunction):
adapted_result = []

await self.__store_transaction(
internal_tx=invoke_transaction,
transaction=invoke_transaction,
contract_address=transaction.contract_address,
tx_hash=invoke_transaction.hash_value,
status=status,
execution_info=execution_info,
error_message=error_message
Expand Down Expand Up @@ -318,16 +330,23 @@ def get_block_by_number(self, block_number: int):

return self.__origin.get_block_by_number(block_number)

async def __store_transaction(self, internal_tx: InternalTransaction, status: TxStatus,
# pylint: disable=too-many-arguments
async def __store_transaction(self, transaction: Transaction, contract_address: int, tx_hash: int, status: TxStatus,
execution_info: StarknetTransactionExecutionInfo, error_message: str=None
):
"""Stores the provided data as a deploy transaction in `self.transactions`."""
if internal_tx.tx_type == TransactionType.DEPLOY:
tx_wrapper = DeployTransactionWrapper(internal_tx, status, execution_info)
elif internal_tx.tx_type == TransactionType.INVOKE_FUNCTION:
tx_wrapper = InvokeTransactionWrapper(internal_tx, status, execution_info)
if transaction.tx_type == TransactionType.DEPLOY:
tx_wrapper = DeployTransactionWrapper(
transaction=transaction,
contract_address=contract_address,
tx_hash=tx_hash,
status=status,
execution_info=execution_info
)
elif transaction.tx_type == TransactionType.INVOKE_FUNCTION:
tx_wrapper = InvokeTransactionWrapper(transaction, status, execution_info)
else:
raise StarknetDevnetException(message=f"Illegal tx_type: {internal_tx.tx_type}")
raise StarknetDevnetException(message=f"Illegal tx_type: {transaction.tx_type}")

if status == TxStatus.REJECTED:
assert error_message, "error_message must be present if tx rejected"
Expand Down
14 changes: 8 additions & 6 deletions starknet_devnet/transaction_wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from dataclasses import dataclass
from typing import List

from starkware.starknet.business_logic.internal_transaction import InternalDeploy, InternalInvokeFunction
from starkware.starknet.business_logic.internal_transaction import InternalInvokeFunction
from starkware.starknet.services.api.gateway.transaction import Deploy
from starkware.starknet.definitions.error_codes import StarknetErrorCode
from starkware.starknet.definitions.transaction_type import TransactionType
from starkware.starknet.testing.objects import StarknetTransactionExecutionInfo
Expand Down Expand Up @@ -88,16 +89,17 @@ def set_failure_reason(self, error_message: str):
class DeployTransactionWrapper(TransactionWrapper):
"""Wrapper of Deploy Transaction."""

def __init__(self, internal_tx: InternalDeploy, status: TxStatus, execution_info: StarknetTransactionExecutionInfo):
# pylint: disable=too-many-arguments
def __init__(self, transaction: Deploy, contract_address: int, tx_hash: int, status: TxStatus, execution_info: StarknetTransactionExecutionInfo):
super().__init__(
status,
execution_info,
DeployTransactionDetails(
TransactionType.DEPLOY.name,
contract_address=fixed_length_hex(internal_tx.contract_address),
transaction_hash=fixed_length_hex(internal_tx.hash_value),
constructor_calldata=[str(arg) for arg in internal_tx.constructor_calldata],
contract_address_salt=hex(internal_tx.contract_address_salt)
contract_address=fixed_length_hex(contract_address),
transaction_hash=fixed_length_hex(tx_hash),
constructor_calldata=[str(arg) for arg in transaction.constructor_calldata],
contract_address_salt=hex(transaction.contract_address_salt)
)
)

Expand Down