This commit is contained in:
lz_db
2025-11-16 12:31:03 +08:00
commit 0fab423a18
1451 changed files with 743213 additions and 0 deletions

View File

@@ -0,0 +1,86 @@
from typing import List, Optional, Union
from ..abi.v2 import shape as ShapeV2
from ..abi.v0 import AbiParser as AbiParserV0
from ..abi.v1 import AbiParser as AbiParserV1
from ..abi.v2 import AbiParser as AbiParserV2
from ..serialization import (
FunctionSerializationAdapter,
serializer_for_function,
)
from ..serialization.factory import (
serializer_for_constructor_v2,
serializer_for_function_v1,
)
def translate_constructor_args(
abi: List, constructor_args: Optional[Union[List, dict]], *, cairo_version: int = 1
) -> List[int]:
serializer = (
_get_constructor_serializer_v1(abi)
if cairo_version == 1
else _get_constructor_serializer_v0(abi)
)
if serializer is None or len(serializer.inputs_serializer.serializers) == 0:
return []
if not constructor_args:
raise ValueError(
"Provided contract has a constructor and no arguments were provided."
)
args, kwargs = (
([], constructor_args)
if isinstance(constructor_args, dict)
else (constructor_args, {})
)
return serializer.serialize(*args, **kwargs)
def _get_constructor_serializer_v1(abi: List) -> Optional[FunctionSerializationAdapter]:
if _is_abi_v2(abi):
parsed = AbiParserV2(abi).parse()
constructor = parsed.constructor
if constructor is None or not constructor.inputs:
return None
return serializer_for_constructor_v2(constructor)
parsed = AbiParserV1(abi).parse()
constructor = parsed.functions.get("constructor", None)
# Constructor might not accept any arguments
if constructor is None or not constructor.inputs:
return None
return serializer_for_function_v1(constructor)
def _is_abi_v2(abi: List) -> bool:
for entry in abi:
if entry["type"] in [
ShapeV2.CONSTRUCTOR_ENTRY,
ShapeV2.L1_HANDLER_ENTRY,
ShapeV2.INTERFACE_ENTRY,
ShapeV2.IMPL_ENTRY,
]:
return True
if entry["type"] == ShapeV2.EVENT_ENTRY:
if "inputs" in entry:
return False
if "kind" in entry:
return True
return False
def _get_constructor_serializer_v0(abi: List) -> Optional[FunctionSerializationAdapter]:
parsed = AbiParserV0(abi).parse()
# Constructor might not accept any arguments
if not parsed.constructor or not parsed.constructor.inputs:
return None
return serializer_for_function(parsed.constructor)

View File

@@ -0,0 +1,13 @@
from typing import Iterable, TypeVar, Union
T = TypeVar("T")
# pyright: reportGeneralTypeIssues=false
def ensure_iterable(value: Union[T, Iterable[T]]) -> Iterable[T]:
try:
iter(value)
# Now we now it is iterable
return value
except TypeError:
return [value]

View File

@@ -0,0 +1,13 @@
import os
from ..marshmallow import EXCLUDE, RAISE
from ..marshmallow import Schema as MarshmallowSchema
MARSHMALLOW_UKNOWN_EXCLUDE = os.environ.get("STARKNET_PY_MARSHMALLOW_UKNOWN_EXCLUDE")
class Schema(MarshmallowSchema):
class Meta:
unknown = (
EXCLUDE if (MARSHMALLOW_UKNOWN_EXCLUDE or "").lower() == "true" else RAISE
)

View File

@@ -0,0 +1,182 @@
from dataclasses import dataclass
from typing import Dict, List, Union, cast
from ...marshmallow import Schema, fields, post_load
from ..cairo.felt import encode_shortstring
from ..hash.selector import get_selector_from_name
from ..hash.utils import compute_hash_on_elements
from ..models.typed_data import StarkNetDomainDict, TypedDataDict
@dataclass(frozen=True)
class Parameter:
"""
Dataclass representing a Parameter object
"""
name: str
type: str
@dataclass(frozen=True)
class TypedData:
"""
Dataclass representing a TypedData object
"""
types: Dict[str, List[Parameter]]
primary_type: str
domain: StarkNetDomainDict
message: dict
@staticmethod
def from_dict(data: TypedDataDict) -> "TypedData":
"""
Create TypedData dataclass from dictionary.
:param data: TypedData dictionary.
:return: TypedData dataclass instance.
"""
return cast(TypedData, TypedDataSchema().load(data))
def _is_struct(self, type_name: str) -> bool:
return type_name in self.types
def _encode_value(self, type_name: str, value: Union[int, str, dict, list]) -> int:
if is_pointer(type_name) and isinstance(value, list):
type_name = strip_pointer(type_name)
if self._is_struct(type_name):
return compute_hash_on_elements(
[self.struct_hash(type_name, data) for data in value]
)
return compute_hash_on_elements([int(get_hex(val), 16) for val in value])
if self._is_struct(type_name) and isinstance(value, dict):
return self.struct_hash(type_name, value)
value = cast(Union[int, str], value)
return int(get_hex(value), 16)
def _encode_data(self, type_name: str, data: dict) -> List[int]:
values = []
for param in self.types[type_name]:
encoded_value = self._encode_value(param.type, data[param.name])
values.append(encoded_value)
return values
def _get_dependencies(self, type_name: str) -> List[str]:
if type_name not in self.types:
# type_name is a primitive type, has no dependencies
return []
dependencies = set()
def collect_deps(type_name: str) -> None:
for param in self.types[type_name]:
fixed_type = strip_pointer(param.type)
if fixed_type in self.types and fixed_type not in dependencies:
dependencies.add(fixed_type)
# recursive call
collect_deps(fixed_type)
# collect dependencies into a set
collect_deps(type_name)
return [type_name, *list(dependencies)]
def _encode_type(self, type_name: str) -> str:
primary, *dependencies = self._get_dependencies(type_name)
types = [primary, *sorted(dependencies)]
def make_dependency_str(dependency):
lst = [f"{t.name}:{t.type}" for t in self.types[dependency]]
return f"{dependency}({','.join(lst)})"
return "".join([make_dependency_str(x) for x in types])
def type_hash(self, type_name: str) -> int:
"""
Calculate the hash of a type name.
:param type_name: Name of the type.
:return: Hash of the type name.
"""
return get_selector_from_name(self._encode_type(type_name))
def struct_hash(self, type_name: str, data: dict) -> int:
"""
Calculate the hash of a struct.
:param type_name: Name of the type.
:param data: Data defining the struct.
:return: Hash of the struct.
"""
return compute_hash_on_elements(
[self.type_hash(type_name), *self._encode_data(type_name, data)]
)
def message_hash(self, account_address: int) -> int:
"""
Calculate the hash of the message.
:param account_address: Address of an account.
:return: Hash of the message.
"""
message = [
encode_shortstring("StarkNet Message"),
self.struct_hash("StarkNetDomain", cast(dict, self.domain)),
account_address,
self.struct_hash(self.primary_type, self.message),
]
return compute_hash_on_elements(message)
def get_hex(value: Union[int, str]) -> str:
if isinstance(value, int):
return hex(value)
if value[:2] == "0x":
return value
if value.isnumeric():
return hex(int(value))
return hex(encode_shortstring(value))
def is_pointer(value: str) -> bool:
return len(value) > 0 and value[-1] == "*"
def strip_pointer(value: str) -> str:
if is_pointer(value):
return value[:-1]
return value
# pylint: disable=unused-argument
# pylint: disable=no-self-use
class ParameterSchema(Schema):
name = fields.String(data_key="name", required=True)
type = fields.String(data_key="type", required=True)
@post_load
def make_dataclass(self, data, **kwargs) -> Parameter:
return Parameter(**data)
class TypedDataSchema(Schema):
types = fields.Dict(
data_key="types",
keys=fields.Str(),
values=fields.List(fields.Nested(ParameterSchema())),
)
primary_type = fields.String(data_key="primaryType", required=True)
domain = fields.Dict(data_key="domain", required=True)
message = fields.Dict(data_key="message", required=True)
@post_load
def make_dataclass(self, data, **kwargs) -> TypedData:
return TypedData(**data)