Skip to content

Commit d9a50bb

Browse files
committed
refactor: refactor base contract creation methods
1 parent 31bc80a commit d9a50bb

File tree

7 files changed

+36
-51
lines changed

7 files changed

+36
-51
lines changed

src/ape/api/accounts.py

+9-6
Original file line numberDiff line numberDiff line change
@@ -173,14 +173,17 @@ def deploy(self, contract: "ContractContainer", *args, **kwargs) -> "ContractIns
173173
if not receipt.contract_address:
174174
raise AccountsError(f"'{receipt.txn_hash}' did not create a contract.")
175175

176-
address = click.style(receipt.contract_address, bold=True)
176+
address = receipt.contract_address
177+
styled_address = click.style(receipt.contract_address, bold=True)
177178
contract_name = contract.contract_type.name or "<Unnamed Contract>"
178-
logger.success(f"Contract '{contract_name}' deployed to: {address}")
179-
180-
contract_instance = self.get_contract_instance(
181-
address=receipt.contract_address, # type: ignore
182-
contract_type=contract.contract_type,
179+
logger.success(f"Contract '{contract_name}' deployed to: {styled_address}")
180+
contract_instance = self.chain_manager.contracts.instance_at(
181+
address, contract.contract_type
183182
)
183+
184+
if not isinstance(contract_instance, ContractInstance):
185+
raise ValueError("Failed to deploy contract.")
186+
184187
self.chain_manager.contracts[contract_instance.address] = contract_instance.contract_type
185188
return contract_instance
186189

src/ape/api/projects.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,8 @@ def __getattr__(self, contract_name: str) -> "ContractContainer":
235235
def get(self, contract_name: str) -> Optional["ContractContainer"]:
236236
manifest = self.extract_manifest()
237237
if hasattr(manifest, contract_name):
238-
return self.create_contract_container(contract_type=getattr(manifest, contract_name))
238+
contract_type = getattr(manifest, contract_name)
239+
return self.chain_manager.contracts.get_container(contract_type)
239240

240241
return None
241242

src/ape/contracts/base.py

+6-4
Original file line numberDiff line numberDiff line change
@@ -674,10 +674,12 @@ def at(self, address: AddressType) -> ContractInstance:
674674
:class:`~ape.contracts.ContractInstance`
675675
"""
676676

677-
return self.get_contract_instance(
678-
address=address,
679-
contract_type=self.contract_type,
680-
)
677+
contract = self.chain_manager.contracts.instance_at(address, self.contract_type)
678+
679+
if not isinstance(contract, ContractInstance):
680+
raise ValueError(f"Address '{address}' is not a contract.")
681+
682+
return contract
681683

682684
def __call__(self, *args, **kwargs) -> TransactionAPI:
683685
args = self.conversion_manager.convert(args, tuple)

src/ape/managers/chain.py

+15-1
Original file line numberDiff line numberDiff line change
@@ -576,6 +576,19 @@ def get(
576576

577577
return contract_type
578578

579+
def get_container(self, contract_type: ContractType) -> ContractContainer:
580+
"""
581+
Get a contract container for the given contract type.
582+
583+
Args:
584+
contract_type (ContractType): The contract type to wrap.
585+
586+
Returns:
587+
ContractContainer: A container object you can deploy.
588+
"""
589+
590+
return ContractContainer(contract_type)
591+
579592
def instance_at(
580593
self, address: Union[str, "AddressType"], contract_type: Optional[ContractType] = None
581594
) -> BaseAddress:
@@ -618,8 +631,9 @@ def instance_at(
618631
f"Expected type '{ContractType.__name__}' for argument 'contract_type'."
619632
)
620633

621-
return self.get_contract_instance(address, contract_type)
634+
return ContractInstance(address, contract_type)
622635

636+
logger.warning(f"Failed to find contract type at address '{address}'.")
623637
return Address(address)
624638

625639
def get_deployments(self, contract_container: ContractContainer) -> List[ContractInstance]:

src/ape/managers/project/manager.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -440,9 +440,7 @@ def _load_dependencies(self) -> Dict[str, Dict[str, DependencyAPI]]:
440440

441441
def _get_contract(self, name: str) -> Optional[ContractContainer]:
442442
if name in self.contracts:
443-
return self.create_contract_container(
444-
contract_type=self.contracts[name],
445-
)
443+
return self.chain_manager.contracts.get_container(self.contracts[name])
446444

447445
return None
448446

src/ape/utils/basemodel.py

-34
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,13 @@
11
from abc import ABC
22
from typing import TYPE_CHECKING, ClassVar, Dict, List, cast
33

4-
from ethpm_types import ContractType
54
from pydantic import BaseModel
65

76
from ape.exceptions import ProviderNotConnectedError
8-
from ape.types import AddressType
97
from ape.utils.misc import cached_property, singledispatchmethod
108

119
if TYPE_CHECKING:
1210
from ape.api.providers import ProviderAPI
13-
from ape.contracts.base import ContractContainer, ContractInstance
1411
from ape.managers.accounts import AccountManager
1512
from ape.managers.chain import ChainManager
1613
from ape.managers.compilers import CompilerManager
@@ -75,37 +72,6 @@ def provider(self) -> "ProviderAPI":
7572
raise ProviderNotConnectedError()
7673
return self.network_manager.active_provider
7774

78-
def create_contract_container(self, contract_type: ContractType) -> "ContractContainer":
79-
"""
80-
Helper method for creating a ``ContractContainer``.
81-
82-
Args:
83-
contract_type (``ContractType``): Type of contract for the container
84-
85-
Returns:
86-
:class:`~ape.contracts.ContractContainer`
87-
"""
88-
from ape.contracts.base import ContractContainer
89-
90-
return ContractContainer(contract_type=contract_type)
91-
92-
def get_contract_instance(
93-
self, address: "AddressType", contract_type: "ContractType"
94-
) -> "ContractInstance":
95-
"""
96-
Helper method for creating a ``ContractInstance``.
97-
98-
Args:
99-
address (``AddressType``): Address of contract
100-
contract_type (``ContractType``): Type of contract
101-
102-
Returns:
103-
:class:`~ape.contracts.ContractInstance`
104-
"""
105-
from ape.contracts.base import ContractInstance
106-
107-
return ContractInstance(address=address, contract_type=contract_type)
108-
10975

11076
class BaseInterface(ManagerAccessMixin, ABC):
11177
"""

src/ape/utils/trace.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -132,10 +132,11 @@ def _dim_default_gas(call_sig: str) -> str:
132132
method = None
133133
contract_name = contract_type.name
134134
if "symbol" in contract_type.view_methods:
135-
contract = self._receipt.get_contract_instance(address, contract_type)
135+
# Type ignores below because we know it's a ContractInstance at this point.
136+
contract = self._receipt.chain_manager.contracts.instance_at(address, contract_type)
136137

137138
try:
138-
contract_name = contract.symbol() or contract_name
139+
contract_name = contract.symbol() or contract_name # type: ignore
139140
except ContractError:
140141
contract_name = contract_type.name
141142

0 commit comments

Comments
 (0)