Skip to content
This repository has been archived by the owner on Dec 1, 2023. It is now read-only.

support mainnet chain_id #275

Merged
merged 3 commits into from
Nov 2, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions src/nile/core/account.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,10 +33,10 @@ def __init__(self, signer, network, predeployed_info=None):
"""Get or deploy an Account contract for the given private key."""
try:
if predeployed_info is None:
self.signer = Signer(normalize_number(os.environ[signer]))
self.signer = Signer(normalize_number(os.environ[signer]), network)
self.alias = signer
else:
self.signer = Signer(signer)
self.signer = Signer(signer, network)
self.alias = predeployed_info["alias"]

self.network = network
Expand Down
17 changes: 12 additions & 5 deletions src/nile/signer.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,15 @@
class Signer:
"""Utility for signing transactions for an Account on Starknet."""

def __init__(self, private_key):
def __init__(self, private_key, network="testnet"):
"""Construct a Signer object. Takes a private key."""
self.private_key = private_key
self.public_key = private_to_stark_key(private_key)
self.chain_id = (
StarknetChainId.MAINNET.value
if network == "mainnet"
else StarknetChainId.TESTNET.value
)

def sign(self, message_hash):
"""Sign a message hash."""
Expand All @@ -34,6 +39,7 @@ def sign_declare(self, sender, contract_class, nonce, max_fee):
contract_class=contract_class,
max_fee=max_fee,
nonce=nonce,
chain_id=self.chain_id,
)

return self.sign(message_hash=transaction_hash)
Expand All @@ -60,6 +66,7 @@ def sign_transaction(
nonce=nonce,
max_fee=max_fee,
version=version,
chain_id=self.chain_id,
)

sig_r, sig_s = self.sign(message_hash=transaction_hash)
Expand All @@ -86,19 +93,19 @@ def from_call_to_call_array(calls):
return (call_array, calldata)


def get_declare_hash(sender, contract_class, max_fee, nonce):
def get_declare_hash(sender, contract_class, max_fee, nonce, chain_id):
"""Compute the hash of a declare transaction."""
return calculate_declare_transaction_hash(
contract_class=contract_class,
chain_id=StarknetChainId.TESTNET.value,
chain_id=chain_id,
sender_address=sender,
max_fee=max_fee,
version=TRANSACTION_VERSION,
nonce=nonce,
)


def get_transaction_hash(prefix, account, calldata, nonce, max_fee, version):
def get_transaction_hash(prefix, account, calldata, nonce, max_fee, version, chain_id):
"""Compute the hash of a transaction."""
return calculate_transaction_hash_common(
tx_hash_prefix=prefix,
Expand All @@ -107,6 +114,6 @@ def get_transaction_hash(prefix, account, calldata, nonce, max_fee, version):
entry_point_selector=0,
calldata=calldata,
max_fee=max_fee,
chain_id=StarknetChainId.TESTNET.value,
chain_id=chain_id,
additional_data=[nonce],
)
16 changes: 15 additions & 1 deletion tests/test_signer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,16 @@

import pytest
from starkware.starknet.business_logic.transaction.objects import InternalTransaction
from starkware.starknet.definitions.general_config import StarknetChainId
from starkware.starknet.services.api.contract_class import ContractClass
from starkware.starknet.services.api.gateway.transaction import InvokeFunction
from starkware.starknet.testing.starknet import Starknet

from nile.common import TRANSACTION_VERSION
from nile.signer import Signer, from_call_to_call_array

SIGNER = Signer(12345678987654321)
PRIVATE_KEY = 12345678987654321
SIGNER = Signer(PRIVATE_KEY)


def get_account_definition():
Expand Down Expand Up @@ -106,6 +108,18 @@ async def test_execute():
assert execution_info.result == (3,)


@pytest.mark.asyncio
async def test_chain_id():
mainnet = Signer(PRIVATE_KEY, "mainnet")
assert mainnet.chain_id == StarknetChainId.MAINNET.value

testnet = Signer(PRIVATE_KEY, "testnet")
assert testnet.chain_id == StarknetChainId.TESTNET.value

no_network = Signer(PRIVATE_KEY)
assert no_network.chain_id == StarknetChainId.TESTNET.value


def get_raw_invoke(sender, calls):
"""Construct and return StarkNet's internal raw_invocation."""
call_array, calldata = from_call_to_call_array(calls)
Expand Down