This commit is contained in:
lz_db
2025-11-16 12:31:03 +08:00
commit 0fab423a18
1451 changed files with 743213 additions and 0 deletions

View File

@@ -0,0 +1,2 @@
from .model import Abi
from .parser import AbiParser, AbiParsingError

View File

@@ -0,0 +1,44 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Optional, OrderedDict
from ...cairo.data_types import CairoType, StructType
@dataclass
class Abi:
    """
    Dataclass representing class abi. Contains parsed functions, events and structures.
    """

    @dataclass
    class Function:
        """
        Dataclass representing function's abi.
        """

        name: str  # Function name as declared in the ABI.
        inputs: OrderedDict[str, CairoType]  # Parameter name -> parsed Cairo type, in declaration order.
        outputs: OrderedDict[str, CairoType]  # Output name -> parsed Cairo type, in declaration order.

    @dataclass
    class Event:
        """
        Dataclass representing event's abi.
        """

        name: str  # Event name as declared in the ABI.
        data: OrderedDict[str, CairoType]  # Event data field name -> parsed Cairo type.

    defined_structures: Dict[
        str, StructType
    ]  #: Abi of structures defined by the class.
    functions: Dict[str, Function]  #: Functions defined by the class.
    constructor: Optional[
        Function
    ]  #: Contract's constructor. It is None if class doesn't define one.
    l1_handler: Optional[
        Function
    ]  #: Handler of L1 messages. It is None if class doesn't define one.
    events: Dict[str, Event]  #: Events defined by the class

View File

@@ -0,0 +1,216 @@
from __future__ import annotations
import dataclasses
import json
from collections import OrderedDict, defaultdict
from typing import DefaultDict, Dict, List, Optional, cast
from ....marshmallow import EXCLUDE
from .model import Abi
from .schemas import ContractAbiEntrySchema
from .shape import (
CONSTRUCTOR_ENTRY,
EVENT_ENTRY,
FUNCTION_ENTRY,
L1_HANDLER_ENTRY,
STRUCT_ENTRY,
EventDict,
FunctionDict,
StructMemberDict,
TypedMemberDict,
)
from ...cairo.data_types import CairoType, StructType
from ...cairo.type_parser import TypeParser
class AbiParsingError(ValueError):
    """
    Error raised when something goes wrong during ABI parsing.
    """
class AbiParser:
    """
    Utility class for parsing a Cairo 0 contract ABI into an :class:`Abi` dataclass.
    """

    # Entries from ABI grouped by entry type (function/struct/event/...).
    _grouped: DefaultDict[str, List[Dict]]
    # Lazily initialized in _parse_structures(); read through the `type_parser` property.
    _type_parser: Optional[TypeParser] = None

    def __init__(self, abi_list: List[Dict]):
        """
        Abi parser constructor. Ensures that abi satisfies the abi schema.

        :param abi_list: Contract's ABI as a list of dictionaries.
        """
        abi = [
            ContractAbiEntrySchema().load(entry, unknown=EXCLUDE) for entry in abi_list
        ]
        grouped = defaultdict(list)
        for entry in abi:
            assert isinstance(entry, dict)
            grouped[entry["type"]].append(entry)
        self._grouped = grouped

    def parse(self) -> Abi:
        """
        Parse abi provided to constructor and return it as a dataclass. Ensures that there are no cycles in the abi.

        :raises AbiParsingError: on any parsing error.
        :return: Abi dataclass.
        """
        structures = self._parse_structures()
        functions_dict = cast(
            Dict[str, FunctionDict],
            AbiParser._group_by_entry_name(
                self._grouped[FUNCTION_ENTRY], "defined functions"
            ),
        )
        events_dict = cast(
            Dict[str, EventDict],
            AbiParser._group_by_entry_name(
                self._grouped[EVENT_ENTRY], "defined events"
            ),
        )
        constructors = cast(List[FunctionDict], self._grouped[CONSTRUCTOR_ENTRY])
        l1_handlers = cast(List[FunctionDict], self._grouped[L1_HANDLER_ENTRY])

        # Cairo 0 classes may declare at most one constructor and one L1 handler.
        if len(l1_handlers) > 1:
            raise AbiParsingError("L1 handler in ABI must be defined at most once.")
        if len(constructors) > 1:
            raise AbiParsingError("Constructor in ABI must be defined at most once.")

        return Abi(
            defined_structures=structures,
            constructor=(
                self._parse_function(constructors[0]) if constructors else None
            ),
            l1_handler=(self._parse_function(l1_handlers[0]) if l1_handlers else None),
            functions={
                name: self._parse_function(entry)
                for name, entry in functions_dict.items()
            },
            events={
                name: self._parse_event(entry) for name, entry in events_dict.items()
            },
        )

    @property
    def type_parser(self) -> TypeParser:
        """Type parser over defined structures; set during _parse_structures()."""
        if self._type_parser:
            return self._type_parser
        raise RuntimeError("Tried to get type_parser before it was set.")

    def _parse_structures(self) -> Dict[str, StructType]:
        """Parse all struct entries, resolving member types and checking for cycles."""
        structs_dict = AbiParser._group_by_entry_name(
            self._grouped[STRUCT_ENTRY], "defined structures"
        )

        # Contains sorted members of the struct
        struct_members: Dict[str, List[StructMemberDict]] = {}
        structs: Dict[str, StructType] = {}

        # Example problem (with a simplified json structure):
        # [{name: User, fields: {id: Uint256}}, {name: "Uint256", ...}]
        # User refers to Uint256 even though it is not known yet (will be parsed next).
        # This is why it is important to create the structure types first. This way other types can already refer to
        # them when parsing types, even thought their fields are not filled yet.
        # At the end we will mutate those structures to contain the right fields. An alternative would be to use
        # topological sorting with an additional "unresolved type", so this flow is much easier.
        for name, struct in structs_dict.items():
            structs[name] = StructType(name, OrderedDict())
            # Partition members by offset presence directly. (The previous
            # `member not in without_offset` check was O(n^2) and compared
            # dicts by equality, misclassifying members that happen to be
            # equal as dicts.)
            with_offset = [
                member
                for member in struct["members"]
                if member.get("offset") is not None
            ]
            without_offset = [
                member for member in struct["members"] if member.get("offset") is None
            ]
            struct_members[name] = sorted(
                with_offset, key=lambda member: member["offset"]  # pyright: ignore
            )
            # Members without an explicit offset are placed after the last
            # known offset, preserving their original relative order.
            for member in without_offset:
                member["offset"] = (
                    struct_members[name][-1].get("offset", 0) + 1
                    if struct_members[name]
                    else 0
                )
                struct_members[name].append(member)

        # Now parse the types of members and save them.
        self._type_parser = TypeParser(structs)
        for name, struct in structs.items():
            members = self._parse_members(
                cast(List[TypedMemberDict], struct_members[name]),
                f"members of structure '{name}'",
            )
            struct.types.update(members)

        # All types have their members assigned now
        self._check_for_cycles(structs)
        return structs

    @staticmethod
    def _check_for_cycles(structs: Dict[str, StructType]):
        """Raise AbiParsingError if struct definitions reference each other cyclically."""
        # We want to avoid creating our own cycle checker as it would make it more complex. json module has a built-in
        # checker for cycles.
        try:
            _to_json(structs)
        except ValueError as err:
            # Chain the caught exception instance (was `from ValueError`, which
            # attached the exception *class* and dropped the original context).
            raise AbiParsingError(err) from err

    def _parse_function(self, function: FunctionDict) -> Abi.Function:
        """Build an Abi.Function from a function-like ABI entry."""
        return Abi.Function(
            name=function["name"],
            inputs=self._parse_members(function["inputs"], function["name"]),
            outputs=self._parse_members(function["outputs"], function["name"]),
        )

    def _parse_event(self, event: EventDict) -> Abi.Event:
        """Build an Abi.Event from an event ABI entry."""
        return Abi.Event(
            name=event["name"],
            data=self._parse_members(event["data"], event["name"]),
        )

    def _parse_members(
        self, params: List[TypedMemberDict], entity_name: str
    ) -> OrderedDict[str, CairoType]:
        """Parse a list of typed members into an ordered name -> CairoType mapping."""
        # Without cast, it complains that 'Type "TypedMemberDict" cannot be assigned to type "T@_group_by_name"'
        members = AbiParser._group_by_entry_name(cast(List[Dict], params), entity_name)
        return OrderedDict(
            (name, self.type_parser.parse_inline_type(param["type"]))
            for name, param in members.items()
        )

    @staticmethod
    def _group_by_entry_name(
        dicts: List[Dict], entity_name: str
    ) -> OrderedDict[str, Dict]:
        """Index entries by "name", raising AbiParsingError on duplicates."""
        grouped = OrderedDict()
        for entry in dicts:
            name = entry["name"]
            if name in grouped:
                raise AbiParsingError(
                    f"Name '{name}' was used more than once in {entity_name}."
                )
            grouped[name] = entry
        return grouped
def _to_json(value):
    """Serialize `value` to JSON, flattening dataclasses one level per step."""

    class DataclassSupportingEncoder(json.JSONEncoder):
        # json does not know dataclasses, and dataclasses.asdict() recurses
        # without cycle detection. Exposing each dataclass as a tuple of its
        # field values (ONE level only) lets the encoder recurse itself,
        # keeping its built-in cycle check effective.
        def default(self, o):
            if not dataclasses.is_dataclass(o):
                return super().default(o)
            return tuple(getattr(o, f.name) for f in dataclasses.fields(o))

    return json.dumps(value, cls=DataclassSupportingEncoder)

View File

@@ -0,0 +1,72 @@
from ....marshmallow import Schema, fields
from ....marshmallow_oneofschema import OneOfSchema
from .shape import (
CONSTRUCTOR_ENTRY,
EVENT_ENTRY,
FUNCTION_ENTRY,
L1_HANDLER_ENTRY,
STRUCT_ENTRY,
)
class TypedParameterSchema(Schema):
    """Schema of a named, typed ABI parameter (function input/output, event data)."""

    name = fields.String(data_key="name", required=True)
    type = fields.String(data_key="type", required=True)
class StructMemberSchema(TypedParameterSchema):
    """Schema of a struct member; "offset" may be absent (parser assigns one)."""

    offset = fields.Integer(data_key="offset", required=False)
class FunctionBaseSchema(Schema):
    """Fields shared by function-like entries (function/constructor/L1 handler)."""

    name = fields.String(data_key="name", required=True)
    inputs = fields.List(
        fields.Nested(TypedParameterSchema()), data_key="inputs", required=True
    )
    outputs = fields.List(
        fields.Nested(TypedParameterSchema()), data_key="outputs", required=True
    )
class FunctionAbiEntrySchema(FunctionBaseSchema):
    """Schema of a "function" ABI entry."""

    type = fields.Constant(FUNCTION_ENTRY, data_key="type", required=True)
class ConstructorAbiEntrySchema(FunctionBaseSchema):
    """Schema of a "constructor" ABI entry."""

    type = fields.Constant(CONSTRUCTOR_ENTRY, data_key="type", required=True)
class L1HandlerAbiEntrySchema(FunctionBaseSchema):
    """Schema of an "l1_handler" ABI entry."""

    type = fields.Constant(L1_HANDLER_ENTRY, data_key="type", required=True)
class EventAbiEntrySchema(Schema):
    """Schema of an "event" ABI entry (Cairo 0 event with keys and data)."""

    type = fields.Constant(EVENT_ENTRY, data_key="type", required=True)
    name = fields.String(data_key="name", required=True)
    keys = fields.List(
        fields.Nested(TypedParameterSchema()), data_key="keys", required=True
    )
    data = fields.List(
        fields.Nested(TypedParameterSchema()), data_key="data", required=True
    )
class StructAbiEntrySchema(Schema):
    """Schema of a "struct" ABI entry (name, size and members)."""

    type = fields.Constant(STRUCT_ENTRY, data_key="type", required=True)
    name = fields.String(data_key="name", required=True)
    size = fields.Integer(data_key="size", required=True)
    members = fields.List(
        fields.Nested(StructMemberSchema()), data_key="members", required=True
    )
class ContractAbiEntrySchema(OneOfSchema):
    """Dispatches each ABI entry to the matching schema based on its "type" field."""

    # Keep the "type" key in loaded dicts - the parser groups entries by it.
    type_field_remove = False
    type_schemas = {
        FUNCTION_ENTRY: FunctionAbiEntrySchema,
        L1_HANDLER_ENTRY: L1HandlerAbiEntrySchema,
        CONSTRUCTOR_ENTRY: ConstructorAbiEntrySchema,
        EVENT_ENTRY: EventAbiEntrySchema,
        STRUCT_ENTRY: StructAbiEntrySchema,
    }

View File

@@ -0,0 +1,63 @@
# TODO (#1260): update pylint to 3.1.0 and remove pylint disable
# pylint: disable=too-many-ancestors
import sys
from typing import List, Literal, Union
if sys.version_info < (3, 11):
from typing_extensions import NotRequired, TypedDict
else:
from typing import NotRequired, TypedDict
# ABI entry type discriminators (values of the "type" field in Cairo 0 ABIs).
STRUCT_ENTRY = "struct"
FUNCTION_ENTRY = "function"
CONSTRUCTOR_ENTRY = "constructor"
L1_HANDLER_ENTRY = "l1_handler"
EVENT_ENTRY = "event"
class TypedMemberDict(TypedDict):
    """Shape of a named, typed ABI member (function input/output, event data)."""

    name: str
    type: str
class StructMemberDict(TypedMemberDict):
    """Struct member; "offset" may be absent (assigned during parsing)."""

    offset: NotRequired[int]
class StructDict(TypedDict):
    """Shape of a "struct" ABI entry."""

    type: Literal["struct"]
    name: str
    size: int
    members: List[StructMemberDict]
class FunctionBaseDict(TypedDict):
    """Fields shared by function-like ABI entries."""

    name: str
    inputs: List[TypedMemberDict]
    outputs: List[TypedMemberDict]
    stateMutability: NotRequired[Literal["view"]]
class FunctionDict(FunctionBaseDict):
    """Shape of a "function" ABI entry."""

    type: Literal["function"]
class ConstructorDict(FunctionBaseDict):
    """Shape of a "constructor" ABI entry."""

    type: Literal["constructor"]
class L1HandlerDict(FunctionBaseDict):
    """Shape of an "l1_handler" ABI entry."""

    type: Literal["l1_handler"]
class EventDict(TypedDict):
    """Shape of an "event" ABI entry (keys and data members)."""

    name: str
    type: Literal["event"]
    data: List[TypedMemberDict]
    keys: List[TypedMemberDict]
# Any single ABI entry, discriminated by its "type" field.
AbiDictEntry = Union[
    StructDict, FunctionDict, ConstructorDict, L1HandlerDict, EventDict
]
# Shape of a full contract ABI: a list of entries.
AbiDictList = List[AbiDictEntry]

View File

@@ -0,0 +1,2 @@
from .model import Abi
from .parser import AbiParser, AbiParsingError

View File

@@ -0,0 +1,14 @@
{
"abi": [
{
"type": "struct",
"name": "core::starknet::eth_address::EthAddress",
"members": [
{
"name": "address",
"type": "core::felt252"
}
]
}
]
}

View File

@@ -0,0 +1,39 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, OrderedDict
from ...cairo.data_types import CairoType, EnumType, StructType
@dataclass
class Abi:
    """
    Dataclass representing class abi. Contains parsed functions, enums, events and structures.
    """

    @dataclass
    class Function:
        """
        Dataclass representing function's abi.
        """

        name: str  # Function name as declared in the ABI.
        inputs: OrderedDict[str, CairoType]  # Parameter name -> parsed Cairo type, in order.
        outputs: List[CairoType]  # Output types, in order (outputs carry no names in this ABI version).

    @dataclass
    class Event:
        """
        Dataclass representing event's abi.
        """

        name: str  # Event name as declared in the ABI.
        inputs: OrderedDict[str, CairoType]  # Event input name -> parsed Cairo type.

    defined_structures: Dict[
        str, StructType
    ]  #: Abi of structures defined by the class.
    defined_enums: Dict[str, EnumType]  #: Abi of enums defined by the class.
    functions: Dict[str, Function]  #: Functions defined by the class.
    events: Dict[str, Event]  #: Events defined by the class

View File

@@ -0,0 +1,220 @@
from __future__ import annotations
import dataclasses
import json
import os
from collections import OrderedDict, defaultdict
from pathlib import Path
from typing import DefaultDict, Dict, List, Optional, Tuple, Union, cast
from ....marshmallow import EXCLUDE
from .model import Abi
from .schemas import ContractAbiEntrySchema
from .shape import (
ENUM_ENTRY,
EVENT_ENTRY,
FUNCTION_ENTRY,
STRUCT_ENTRY,
EventDict,
FunctionDict,
TypedParameterDict,
)
from ...cairo.data_types import CairoType, EnumType, StructType
from ...cairo.v1.type_parser import TypeParser
class AbiParsingError(ValueError):
    """
    Error raised when something goes wrong during ABI parsing.
    """
class AbiParser:
    """
    Utility class for parsing a Cairo 1 contract ABI into an :class:`Abi` dataclass.
    """

    # Entries from ABI grouped by entry type (function/struct/enum/event).
    _grouped: DefaultDict[str, List[Dict]]
    # Lazily initialized in _parse_structures_and_enums(); read via `type_parser`.
    _type_parser: Optional[TypeParser] = None

    def __init__(self, abi_list: List[Dict]):
        """
        Abi parser constructor. Ensures that abi satisfies the abi schema.

        :param abi_list: Contract's ABI as a list of dictionaries.
        """
        # Prepend abi with core structures shipped next to this module, so
        # contracts may reference them without defining them.
        core_structures = (
            Path(os.path.dirname(__file__)) / "core_structures.json"
        ).read_text("utf-8")
        abi_list = json.loads(core_structures)["abi"] + abi_list

        abi = [
            ContractAbiEntrySchema().load(entry, unknown=EXCLUDE) for entry in abi_list
        ]
        grouped = defaultdict(list)
        for entry in abi:
            assert isinstance(entry, dict)
            grouped[entry["type"]].append(entry)
        self._grouped = grouped

    def parse(self) -> Abi:
        """
        Parse abi provided to constructor and return it as a dataclass. Ensures that there are no cycles in the abi.

        :raises AbiParsingError: on any parsing error.
        :return: Abi dataclass.
        """
        structures, enums = self._parse_structures_and_enums()
        functions_dict = cast(
            Dict[str, FunctionDict],
            AbiParser._group_by_entry_name(
                self._grouped[FUNCTION_ENTRY], "defined functions"
            ),
        )
        events_dict = cast(
            Dict[str, EventDict],
            AbiParser._group_by_entry_name(
                self._grouped[EVENT_ENTRY], "defined events"
            ),
        )
        return Abi(
            defined_structures=structures,
            defined_enums=enums,
            functions={
                name: self._parse_function(entry)
                for name, entry in functions_dict.items()
            },
            events={
                name: self._parse_event(entry) for name, entry in events_dict.items()
            },
        )

    @property
    def type_parser(self) -> TypeParser:
        """Type parser over defined structs/enums; set during parsing."""
        if self._type_parser:
            return self._type_parser
        raise RuntimeError("Tried to get type_parser before it was set.")

    def _parse_structures_and_enums(
        self,
    ) -> Tuple[Dict[str, StructType], Dict[str, EnumType]]:
        """Parse struct and enum entries, resolve member types, check for cycles."""
        structs_dict = AbiParser._group_by_entry_name(
            self._grouped[STRUCT_ENTRY], "defined structures"
        )
        enums_dict = AbiParser._group_by_entry_name(
            self._grouped[ENUM_ENTRY], "defined enums"
        )

        # Contains sorted members of the struct
        struct_members: Dict[str, List[TypedParameterDict]] = {}
        structs: Dict[str, StructType] = {}

        # Contains sorted members of the enum
        enum_members: Dict[str, List[TypedParameterDict]] = {}
        enums: Dict[str, EnumType] = {}

        # Example problem (with a simplified json structure):
        # [{name: User, fields: {id: Uint256}}, {name: "Uint256", ...}]
        # User refers to Uint256 even though it is not known yet (will be parsed next).
        # This is why it is important to create the structure types first. This way other types can already refer to
        # them when parsing types, even thought their fields are not filled yet.
        # At the end we will mutate those structures to contain the right fields. An alternative would be to use
        # topological sorting with an additional "unresolved type", so this flow is much easier.
        for name, struct in structs_dict.items():
            structs[name] = StructType(name, OrderedDict())
            struct_members[name] = struct["members"]
        for name, enum in enums_dict.items():
            enums[name] = EnumType(name, OrderedDict())
            enum_members[name] = enum["variants"]

        # Now parse the types of members and save them.
        defined_structs_enums: Dict[str, Union[StructType, EnumType]] = dict(structs)
        defined_structs_enums.update(enums)
        self._type_parser = TypeParser(defined_structs_enums)
        for name, struct in structs.items():
            members = self._parse_members(
                cast(List[TypedParameterDict], struct_members[name]),
                f"members of structure '{name}'",
            )
            struct.types.update(members)
        for name, enum in enums.items():
            members = self._parse_members(
                cast(List[TypedParameterDict], enum_members[name]),
                f"members of enum '{name}'",
            )
            enum.variants.update(members)

        # All types have their members assigned now
        self._check_for_cycles(defined_structs_enums)
        return structs, enums

    @staticmethod
    def _check_for_cycles(structs: Dict[str, Union[StructType, EnumType]]):
        """Raise AbiParsingError if type definitions reference each other cyclically."""
        # We want to avoid creating our own cycle checker as it would make it more complex. json module has a built-in
        # checker for cycles.
        try:
            _to_json(structs)
        except ValueError as err:
            # Chain the caught exception instance (was `from ValueError`, which
            # attached the exception *class* and dropped the original context).
            raise AbiParsingError(err) from err

    def _parse_function(self, function: FunctionDict) -> Abi.Function:
        """Build an Abi.Function; outputs are unnamed, so they form a plain list."""
        return Abi.Function(
            name=function["name"],
            inputs=self._parse_members(function["inputs"], function["name"]),
            outputs=[
                self.type_parser.parse_inline_type(param["type"])
                for param in function["outputs"]
            ],
        )

    def _parse_event(self, event: EventDict) -> Abi.Event:
        """Build an Abi.Event from an event ABI entry."""
        return Abi.Event(
            name=event["name"],
            inputs=self._parse_members(event["inputs"], event["name"]),
        )

    def _parse_members(
        self, params: List[TypedParameterDict], entity_name: str
    ) -> OrderedDict[str, CairoType]:
        """Parse a list of typed members into an ordered name -> CairoType mapping."""
        # Without cast, it complains that 'Type "TypedParameterDict" cannot be assigned to type "T@_group_by_name"'
        members = AbiParser._group_by_entry_name(cast(List[Dict], params), entity_name)
        return OrderedDict(
            (name, self.type_parser.parse_inline_type(param["type"]))
            for name, param in members.items()
        )

    @staticmethod
    def _group_by_entry_name(
        dicts: List[Dict], entity_name: str
    ) -> OrderedDict[str, Dict]:
        """Index entries by "name", raising AbiParsingError on duplicates."""
        grouped = OrderedDict()
        for entry in dicts:
            name = entry["name"]
            if name in grouped:
                raise AbiParsingError(
                    f"Name '{name}' was used more than once in {entity_name}."
                )
            grouped[name] = entry
        return grouped
def _to_json(value):
    """Serialize `value` to JSON, flattening dataclasses one level per step."""

    class DataclassSupportingEncoder(json.JSONEncoder):
        # json does not know dataclasses, and dataclasses.asdict() recurses
        # without cycle detection. Exposing each dataclass as a tuple of its
        # field values (ONE level only) lets the encoder recurse itself,
        # keeping its built-in cycle check effective.
        def default(self, o):
            if not dataclasses.is_dataclass(o):
                return super().default(o)
            return tuple(getattr(o, f.name) for f in dataclasses.fields(o))

    return json.dumps(value, cls=DataclassSupportingEncoder)

View File

@@ -0,0 +1,179 @@
from typing import Any, List, Optional
from ....lark import *
from ....lark import Token, Transformer
from ...cairo.data_types import (
ArrayType,
BoolType,
CairoType,
FeltType,
OptionType,
TupleType,
TypeIdentifier,
UintType,
UnitType,
)
ABI_EBNF = """
IDENTIFIER: /[a-zA-Z_][a-zA-Z_0-9]*/
type: type_unit
| type_bool
| type_felt
| type_uint
| type_contract_address
| type_class_hash
| type_storage_address
| type_option
| type_array
| type_span
| tuple
| type_identifier
type_unit: "()"
type_felt: "core::felt252"
type_bool: "core::bool"
type_uint: "core::integer::u" INT
type_contract_address: "core::starknet::contract_address::ContractAddress"
type_class_hash: "core::starknet::class_hash::ClassHash"
type_storage_address: "core::starknet::storage_access::StorageAddress"
type_option: "core::option::Option::<" (type | type_identifier) ">"
type_array: "core::array::Array::<" (type | type_identifier) ">"
type_span: "core::array::Span::<" (type | type_identifier) ">"
tuple: "(" type? ("," type?)* ")"
type_identifier: (IDENTIFIER | "::")+ ("<" (type | ",")+ ">")?
%import common.INT
%import common.WS
%ignore WS
"""
class ParserTransformer(Transformer):
    """
    Transforms the lark tree into CairoTypes.

    Each `type_*` method below handles the grammar rule of the same name from
    ABI_EBNF; lark calls them bottom-up with the already-transformed children.
    """

    def __init__(self, type_identifiers: Optional[dict] = None) -> None:
        # Mapping of known identifier names to already-parsed types; names not
        # found here are wrapped in a bare TypeIdentifier (see type_identifier).
        if type_identifiers is None:
            type_identifiers = {}
        self.type_identifiers = type_identifiers
        # NOTE(review): `super(Transformer, self)` skips Transformer's own
        # __init__ and calls its parent's - confirm this is intended rather
        # than `super().__init__()`.
        super(Transformer, self).__init__()

    # pylint: disable=no-self-use

    def __default__(self, data: str, children, meta):
        # Fallback for grammar rules without a dedicated handler.
        raise TypeError(f"Unable to parse tree node of type {data}.")

    def type(self, value: List[Optional[CairoType]]) -> Optional[CairoType]:
        """
        Tokens are read bottom-up, so here all of them are parsed and should be just returned.
        `Optional` is added in case of the unit type.
        """
        assert len(value) == 1
        return value[0]

    def type_felt(self, _value: List[Any]) -> FeltType:
        """
        Felt does not contain any additional arguments, so `_value` is just an empty list.
        """
        return FeltType()

    def type_bool(self, _value: List[Any]) -> BoolType:
        """
        Bool does not contain any additional arguments, so `_value` is just an empty list.
        """
        return BoolType()

    def type_uint(self, value: List[Token]) -> UintType:
        """
        Uint type contains information about its size. It is present in the value[0].
        """
        return UintType(int(value[0]))

    def type_unit(self, _value: List[Any]) -> UnitType:
        """
        `()` type.
        """
        return UnitType()

    def type_option(self, value: List[CairoType]) -> OptionType:
        """
        Option includes an information about which type it eventually represents.
        `Optional` is added in case of the unit type.
        """
        return OptionType(value[0])

    def type_array(self, value: List[CairoType]) -> ArrayType:
        """
        Array contains values of type under `value[0]`.
        """
        return ArrayType(value[0])

    def type_span(self, value: List[CairoType]) -> ArrayType:
        """
        Span contains values of type under `value[0]`.
        """
        # Spans are modeled with the same ArrayType as arrays.
        return ArrayType(value[0])

    def type_identifier(self, tokens: List[Token]) -> TypeIdentifier:
        """
        Structs and enums are defined as follows: (IDENTIFIER | "::")+ [some not important info]
        where IDENTIFIER is a string.

        Tokens would contain strings and types (if it is present).
        We are interested only in the strings because a structure (or enum) name can be built from them.
        """
        name = "::".join(token for token in tokens if isinstance(token, str))
        # Prefer an already-known type of the same name over a bare identifier.
        if name in self.type_identifiers:
            return self.type_identifiers[name]
        return TypeIdentifier(name)

    def type_contract_address(self, _value: List[Any]) -> FeltType:
        """
        ContractAddress is represented by the felt252.
        """
        return FeltType()

    def type_class_hash(self, _value: List[Any]) -> FeltType:
        """
        ClassHash is represented by the felt252.
        """
        return FeltType()

    def type_storage_address(self, _value: List[Any]) -> FeltType:
        """
        StorageAddress is represented by the felt252.
        """
        return FeltType()

    def tuple(self, types: List[CairoType]) -> TupleType:
        """
        Tuple contains values defined in the `types` argument.
        """
        return TupleType(types)
def parse(
    code: str,
    type_identifiers: dict,
) -> CairoType:
    """
    Parse the given string and return a CairoType.

    :param code: ABI type string, e.g. "core::array::Array::<core::felt252>".
    :param type_identifiers: Mapping of known identifier names to already-parsed
        types; unknown names become bare TypeIdentifier instances.
    :return: The parsed Cairo type.
    """
    # NOTE(review): the `lark` name must come from the star import above
    # (`from ....lark import *`) - confirm the vendored package exports it.
    grammar_parser = lark.Lark(
        grammar=ABI_EBNF,
        start="type",
        parser="earley",
    )
    parsed = grammar_parser.parse(code)
    parser_transformer = ParserTransformer(type_identifiers)
    cairo_type = parser_transformer.transform(parsed)
    return cairo_type

View File

@@ -0,0 +1,66 @@
from ....marshmallow import Schema, fields
from ....marshmallow_oneofschema import OneOfSchema
from .shape import (
ENUM_ENTRY,
EVENT_ENTRY,
FUNCTION_ENTRY,
STRUCT_ENTRY,
)
class TypeSchema(Schema):
    """Schema of a bare type reference (used for unnamed function outputs)."""

    type = fields.String(data_key="type", required=True)
class TypedParameterSchema(TypeSchema):
    """Schema of a named, typed parameter (inputs, struct members, enum variants)."""

    name = fields.String(data_key="name", required=True)
class FunctionBaseSchema(Schema):
    """Common fields of function entries: name, named inputs, type-only outputs."""

    name = fields.String(data_key="name", required=True)
    inputs = fields.List(
        fields.Nested(TypedParameterSchema()), data_key="inputs", required=True
    )
    outputs = fields.List(
        fields.Nested(TypeSchema()), data_key="outputs", required=True
    )
    # NOTE(review): marshmallow 3 renamed `default` to `load_default`/`dump_default`;
    # confirm the vendored marshmallow still honors `default=` here.
    state_mutability = fields.String(data_key="state_mutability", default=None)
class FunctionAbiEntrySchema(FunctionBaseSchema):
    """Schema of a "function" ABI entry."""

    type = fields.Constant(FUNCTION_ENTRY, data_key="type", required=True)
class EventAbiEntrySchema(Schema):
    """Schema of an "event" ABI entry (name plus typed inputs)."""

    type = fields.Constant(EVENT_ENTRY, data_key="type", required=True)
    name = fields.String(data_key="name", required=True)
    inputs = fields.List(
        fields.Nested(TypedParameterSchema()), data_key="inputs", required=True
    )
class StructAbiEntrySchema(Schema):
    """Schema of a "struct" ABI entry (name plus typed members)."""

    type = fields.Constant(STRUCT_ENTRY, data_key="type", required=True)
    name = fields.String(data_key="name", required=True)
    members = fields.List(
        fields.Nested(TypedParameterSchema()), data_key="members", required=True
    )
class EnumAbiEntrySchema(Schema):
    """Schema of an "enum" ABI entry (name plus typed variants)."""

    type = fields.Constant(ENUM_ENTRY, data_key="type", required=True)
    name = fields.String(data_key="name", required=True)
    # Bug fix: `data_key="variants"` and `required=True` were previously passed
    # to fields.Nested(...) instead of fields.List(...) (misplaced parenthesis),
    # so the variants list itself was neither bound to its key nor required.
    # Compare StructAbiEntrySchema.members above for the intended form.
    variants = fields.List(
        fields.Nested(TypedParameterSchema()), data_key="variants", required=True
    )
class ContractAbiEntrySchema(OneOfSchema):
    """Dispatches each ABI entry to the matching schema based on its "type" field."""

    # Keep the "type" key in loaded dicts - the parser groups entries by it.
    type_field_remove = False
    type_schemas = {
        FUNCTION_ENTRY: FunctionAbiEntrySchema,
        EVENT_ENTRY: EventAbiEntrySchema,
        STRUCT_ENTRY: StructAbiEntrySchema,
        ENUM_ENTRY: EnumAbiEntrySchema,
    }

View File

@@ -0,0 +1,47 @@
from typing import List, Literal, Optional, TypedDict, Union
# ABI entry type discriminators (values of the "type" field in Cairo 1 ABIs).
ENUM_ENTRY = "enum"
STRUCT_ENTRY = "struct"
FUNCTION_ENTRY = "function"
EVENT_ENTRY = "event"
class TypeDict(TypedDict):
    """Shape of a bare type reference (used for unnamed function outputs)."""

    type: str
class TypedParameterDict(TypeDict):
    """Shape of a named, typed parameter (inputs, struct members, enum variants)."""

    name: str
class StructDict(TypedDict):
    """Shape of a "struct" ABI entry."""

    type: Literal["struct"]
    name: str
    members: List[TypedParameterDict]
class FunctionBaseDict(TypedDict):
    """Fields shared by function-like ABI entries."""

    name: str
    inputs: List[TypedParameterDict]
    outputs: List[TypeDict]
    # NOTE(review): Optional[...] still marks the key as required with a
    # nullable value; if the key may be absent entirely, NotRequired would be
    # the accurate annotation (as in the v0 shape module) - confirm intent.
    state_mutability: Optional[Literal["external", "view"]]
class FunctionDict(FunctionBaseDict):
    """Shape of a "function" ABI entry."""

    type: Literal["function"]
class EventDict(TypedDict):
    """Shape of an "event" ABI entry."""

    name: str
    type: Literal["event"]
    inputs: List[TypedParameterDict]
class EnumDict(TypedDict):
    """Shape of an "enum" ABI entry."""

    type: Literal["enum"]
    name: str
    variants: List[TypedParameterDict]
# Any single ABI entry, discriminated by its "type" field.
AbiDictEntry = Union[StructDict, FunctionDict, EventDict, EnumDict]
# Shape of a full contract ABI: a list of entries.
AbiDictList = List[AbiDictEntry]

View File

@@ -0,0 +1,2 @@
from .model import Abi
from .parser import AbiParser, AbiParsingError

View File

@@ -0,0 +1,89 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, List, Optional, OrderedDict, Union
from ...cairo.data_types import CairoType, EnumType, EventType, StructType
@dataclass
class Abi:
    """
    Dataclass representing class abi. Contains parsed functions, enums, events and structures.
    """

    # pylint: disable=too-many-instance-attributes

    @dataclass
    class Function:
        """
        Dataclass representing function's abi.
        """

        name: str  # Function name as declared in the ABI.
        inputs: OrderedDict[str, CairoType]  # Parameter name -> parsed Cairo type, in order.
        outputs: List[CairoType]  # Output types, in declaration order.

    @dataclass
    class Constructor:
        """
        Dataclass representing constructor's abi.
        """

        name: str  # Constructor name as declared in the ABI.
        inputs: OrderedDict[str, CairoType]  # Parameter name -> parsed Cairo type.

    @dataclass
    class EventStruct:
        """
        Dataclass representing struct event's abi.
        """

        name: str  # Event name as declared in the ABI.
        members: OrderedDict[str, CairoType]  # Member name -> parsed Cairo type.

    @dataclass
    class EventEnum:
        """
        Dataclass representing enum event's abi.
        """

        name: str  # Event name as declared in the ABI.
        variants: OrderedDict[str, CairoType]  # Variant name -> parsed Cairo type.

    # Either flavor of event definition.
    Event = Union[EventStruct, EventEnum]

    @dataclass
    class Interface:
        """
        Dataclass representing an interface.
        """

        name: str  # Interface name as declared in the ABI.
        items: OrderedDict[
            str, Abi.Function
        ]  # Only functions can be defined in the interface

    @dataclass
    class Impl:
        """
        Dataclass representing an impl.
        """

        name: str  # Impl name as declared in the ABI.
        interface_name: str  # Name of the interface this impl implements.

    defined_structures: Dict[
        str, StructType
    ]  #: Abi of structures defined by the class.
    defined_enums: Dict[str, EnumType]  #: Abi of enums defined by the class.
    functions: Dict[str, Function]  #: Functions defined by the class.
    events: Dict[str, EventType]  #: Events defined by the class
    constructor: Optional[
        Constructor
    ]  #: Contract's constructor. It is None if class doesn't define one.
    l1_handler: Optional[
        Dict[str, Function]
    ]  #: Handlers of L1 messages keyed by name. NOTE(review): the parser always passes a (possibly empty) dict - confirm whether None is ever used.
    interfaces: Dict[str, Interface]  #: Interfaces defined by the class.
    implementations: Dict[str, Impl]  #: Impls defined by the class.

View File

@@ -0,0 +1,293 @@
from __future__ import annotations
import dataclasses
import json
from collections import OrderedDict, defaultdict
from typing import DefaultDict, Dict, List, Optional, Tuple, TypeVar, Union, cast
from ....marshmallow import EXCLUDE
from .model import Abi
from .schemas import ContractAbiEntrySchema
from .shape import (
CONSTRUCTOR_ENTRY,
ENUM_ENTRY,
EVENT_ENTRY,
FUNCTION_ENTRY,
IMPL_ENTRY,
INTERFACE_ENTRY,
L1_HANDLER_ENTRY,
STRUCT_ENTRY,
ConstructorDict,
EventDict,
EventEnumVariantDict,
EventStructMemberDict,
FunctionDict,
ImplDict,
InterfaceDict,
TypedParameterDict,
)
from ...cairo.data_types import CairoType, EnumType, EventType, StructType
from ...cairo.v2.type_parser import TypeParser
class AbiParsingError(ValueError):
    """
    Error raised when something goes wrong during ABI parsing.
    """
class AbiParser:
    """
    Utility class for parsing abi into a dataclass.
    """

    # Entries from ABI grouped by entry type (function/struct/enum/event/...).
    _grouped: DefaultDict[str, List[Dict]]
    # lazy init property - set during parsing, read via `type_parser`.
    _type_parser: Optional[TypeParser] = None
    def __init__(self, abi_list: List[Dict]):
        """
        Abi parser constructor. Ensures that abi satisfies the abi schema.

        :param abi_list: Contract's ABI as a list of dictionaries.
        """
        abi = [
            ContractAbiEntrySchema().load(entry, unknown=EXCLUDE) for entry in abi_list
        ]
        # Group entries by their "type" discriminator for easy per-kind access.
        grouped = defaultdict(list)
        for entry in abi:
            assert isinstance(entry, dict)
            grouped[entry["type"]].append(entry)
        self._grouped = grouped
    def parse(self) -> Abi:
        """
        Parse abi provided to constructor and return it as a dataclass. Ensures that there are no cycles in the abi.

        :raises AbiParsingError: on any parsing error.
        :return: Abi dataclass.
        """
        structures, enums = self._parse_structures_and_enums()
        events_dict = cast(
            Dict[str, EventDict],
            AbiParser._group_by_entry_name(
                self._grouped[EVENT_ENTRY], "defined events"
            ),
        )
        # Events are registered with the type parser as they are parsed, so
        # later entries can reference already-parsed event types.
        events: Dict[str, EventType] = {}
        for name, event in events_dict.items():
            events[name] = self._parse_event(event)
            assert self._type_parser is not None
            self._type_parser.add_defined_type(events[name])
        functions_dict = cast(
            Dict[str, FunctionDict],
            AbiParser._group_by_entry_name(
                self._grouped[FUNCTION_ENTRY], "defined functions"
            ),
        )
        interfaces_dict = cast(
            Dict[str, InterfaceDict],
            AbiParser._group_by_entry_name(
                self._grouped[INTERFACE_ENTRY], "defined interfaces"
            ),
        )
        impls_dict = cast(
            Dict[str, ImplDict],
            AbiParser._group_by_entry_name(self._grouped[IMPL_ENTRY], "defined impls"),
        )
        l1_handlers_dict = cast(
            Dict[str, FunctionDict],
            AbiParser._group_by_entry_name(
                self._grouped[L1_HANDLER_ENTRY], "defined L1 handlers"
            ),
        )
        # Only the constructor is limited to a single occurrence.
        constructors = self._grouped[CONSTRUCTOR_ENTRY]
        if len(constructors) > 1:
            raise AbiParsingError("Constructor in ABI must be defined at most once.")
        return Abi(
            defined_structures=structures,
            defined_enums=enums,
            constructor=(
                self._parse_constructor(cast(ConstructorDict, constructors[0]))
                if constructors
                else None
            ),
            l1_handler={
                name: self._parse_function(entry)
                for name, entry in l1_handlers_dict.items()
            },
            functions={
                name: self._parse_function(entry)
                for name, entry in functions_dict.items()
            },
            events=events,
            interfaces={
                name: self._parse_interface(entry)
                for name, entry in interfaces_dict.items()
            },
            implementations={
                name: self._parse_impl(entry) for name, entry in impls_dict.items()
            },
        )
@property
def type_parser(self) -> TypeParser:
if self._type_parser:
return self._type_parser
raise RuntimeError("Tried to get type_parser before it was set.")
    def _parse_structures_and_enums(
        self,
    ) -> Tuple[Dict[str, StructType], Dict[str, EnumType]]:
        """Parse struct and enum entries, resolve member types and check for cycles."""
        structs_dict = AbiParser._group_by_entry_name(
            self._grouped[STRUCT_ENTRY], "defined structures"
        )
        enums_dict = AbiParser._group_by_entry_name(
            self._grouped[ENUM_ENTRY], "defined enums"
        )

        # Contains sorted members of the struct
        struct_members: Dict[str, List[TypedParameterDict]] = {}
        structs: Dict[str, StructType] = {}

        # Contains sorted members of the enum
        enum_members: Dict[str, List[TypedParameterDict]] = {}
        enums: Dict[str, EnumType] = {}

        # Example problem (with a simplified json structure):
        # [{name: User, fields: {id: Uint256}}, {name: "Uint256", ...}]
        # User refers to Uint256 even though it is not known yet (will be parsed next).
        # This is why it is important to create the structure types first. This way other types can already refer to
        # them when parsing types, even thought their fields are not filled yet.
        # At the end we will mutate those structures to contain the right fields. An alternative would be to use
        # topological sorting with an additional "unresolved type", so this flow is much easier.
        for name, struct in structs_dict.items():
            structs[name] = StructType(name, OrderedDict())
            struct_members[name] = struct["members"]
        for name, enum in enums_dict.items():
            enums[name] = EnumType(name, OrderedDict())
            enum_members[name] = enum["variants"]

        # Now parse the types of members and save them.
        defined_structs_enums: Dict[str, Union[StructType, EnumType]] = dict(structs)
        defined_structs_enums.update(enums)
        self._type_parser = TypeParser(defined_structs_enums)  # pyright: ignore
        for name, struct in structs.items():
            members = self._parse_members(
                cast(List[TypedParameterDict], struct_members[name]),
                f"members of structure '{name}'",
            )
            struct.types.update(members)
        for name, enum in enums.items():
            members = self._parse_members(
                cast(List[TypedParameterDict], enum_members[name]),
                f"members of enum '{name}'",
            )
            enum.variants.update(members)

        # All types have their members assigned now
        self._check_for_cycles(defined_structs_enums)
        return structs, enums
@staticmethod
def _check_for_cycles(structs: Dict[str, Union[StructType, EnumType]]):
# We want to avoid creating our own cycle checker as it would make it more complex. json module has a built-in
# checker for cycles.
try:
_to_json(structs)
except ValueError as err:
raise AbiParsingError(err) from ValueError
def _parse_function(self, function: FunctionDict) -> Abi.Function:
return Abi.Function(
name=function["name"],
inputs=self._parse_members(function["inputs"], function["name"]),
outputs=list(
self.type_parser.parse_inline_type(param["type"])
for param in function["outputs"]
),
)
def _parse_constructor(self, constructor: ConstructorDict) -> Abi.Constructor:
return Abi.Constructor(
name=constructor["name"],
inputs=self._parse_members(constructor["inputs"], constructor["name"]),
)
def _parse_event(self, event: EventDict) -> EventType:
members_ = event.get("members", event.get("variants"))
assert isinstance(members_, list)
return EventType(
name=event["name"],
types=self._parse_members(
cast(List[TypedParameterDict], members_), event["name"]
),
)
TypedParam = TypeVar(
"TypedParam", TypedParameterDict, EventStructMemberDict, EventEnumVariantDict
)
def _parse_members(
self, params: List[TypedParam], entity_name: str
) -> OrderedDict[str, CairoType]:
# Without cast, it complains that 'Type "TypedParameterDict" cannot be assigned to type "T@_group_by_name"'
members = AbiParser._group_by_entry_name(cast(List[Dict], params), entity_name)
return OrderedDict(
(name, self.type_parser.parse_inline_type(param["type"]))
for name, param in members.items()
)
def _parse_interface(self, interface: InterfaceDict) -> Abi.Interface:
return Abi.Interface(
name=interface["name"],
items=OrderedDict(
(entry["name"], self._parse_function(entry))
for entry in interface["items"]
),
)
@staticmethod
def _parse_impl(impl: ImplDict) -> Abi.Impl:
return Abi.Impl(
name=impl["name"],
interface_name=impl["interface_name"],
)
@staticmethod
def _group_by_entry_name(
dicts: List[Dict], entity_name: str
) -> OrderedDict[str, Dict]:
grouped = OrderedDict()
for entry in dicts:
name = entry["name"]
if name in grouped:
raise AbiParsingError(
f"Name '{name}' was used more than once in {entity_name}."
)
grouped[name] = entry
return grouped
def _to_json(value):
class DataclassSupportingEncoder(json.JSONEncoder):
def default(self, o):
# Dataclasses are not supported by json. Additionally, dataclasses.asdict() works recursively and doesn't
# check for cycles, so we need to flatten dataclasses (by ONE LEVEL) ourselves.
if dataclasses.is_dataclass(o):
return tuple(getattr(o, field.name) for field in dataclasses.fields(o))
return super().default(o)
return json.dumps(value, cls=DataclassSupportingEncoder)

View File

@@ -0,0 +1,192 @@
from typing import Any, List, Optional

from ....lark import *
from ....lark import Lark, Token, Transformer

from ...cairo.data_types import (
    ArrayType,
    BoolType,
    CairoType,
    FeltType,
    OptionType,
    TupleType,
    TypeIdentifier,
    UintType,
    UnitType,
)
ABI_EBNF = """
IDENTIFIER: /[a-zA-Z_][a-zA-Z_0-9]*/
type: "@"? actual_type
actual_type: type_unit
| type_bool
| type_felt
| type_bytes
| type_uint
| type_contract_address
| type_class_hash
| type_storage_address
| type_option
| type_array
| type_span
| tuple
| type_identifier
type_unit: "()"
type_felt: "core::felt252"
type_bytes: "core::bytes_31::bytes31"
type_bool: "core::bool"
type_uint: "core::integer::u" INT
type_contract_address: "core::starknet::contract_address::ContractAddress"
type_class_hash: "core::starknet::class_hash::ClassHash"
type_storage_address: "core::starknet::storage_access::StorageAddress"
type_option: "core::option::Option::<" (type | type_identifier) ">"
type_array: "core::array::Array::<" (type | type_identifier) ">"
type_span: "core::array::Span::<" (type | type_identifier) ">"
tuple: "(" type? ("," type?)* ")"
type_identifier: (IDENTIFIER | "::")+ ("<" (type | ",")+ ">")?
%import common.INT
%import common.WS
%ignore WS
"""
class ParserTransformer(Transformer):
    """
    Transforms the lark tree into CairoTypes.

    Method names correspond to rule names in ABI_EBNF; lark dispatches each
    parsed rule to the method of the same name (bottom-up).
    """

    def __init__(self, type_identifiers: Optional[dict] = None) -> None:
        """
        :param type_identifiers: mapping of already-known type names; identifiers
            found here are substituted instead of returning a bare TypeIdentifier.
        """
        if type_identifiers is None:
            type_identifiers = {}
        self.type_identifiers = type_identifiers
        # Bug fix: was `super(Transformer, self).__init__()`, which skips
        # Transformer's own initializer in the MRO; call it properly.
        super().__init__()

    # pylint: disable=no-self-use
    def __default__(self, data: str, children, meta):
        # Fail loudly on grammar rules without a dedicated handler.
        raise TypeError(f"Unable to parse tree node of type {data}.")

    def type(self, value: List[Optional[CairoType]]) -> Optional[CairoType]:
        """
        Tokens are read bottom-up, so here all of them are parsed and should be just returned.
        `Optional` is added in case of the unit type.
        """
        assert len(value) == 1
        return value[0]

    def actual_type(self, value) -> Optional[CairoType]:
        # Unwrap the single alternative chosen by the `actual_type` rule.
        return value[0]

    def type_felt(self, _value: List[Any]) -> FeltType:
        """
        Felt does not contain any additional arguments, so `_value` is just an empty list.
        """
        return FeltType()

    def type_bytes(self, _value: List[Any]) -> FeltType:
        """
        bytes31 is represented as a felt and carries no additional arguments,
        so `_value` is just an empty list.
        """
        return FeltType()

    def type_bool(self, _value: List[Any]) -> BoolType:
        """
        Bool does not contain any additional arguments, so `_value` is just an empty list.
        """
        return BoolType()

    def type_uint(self, value: List[Token]) -> UintType:
        """
        Uint type contains information about its size. It is present in the value[0].
        """
        return UintType(int(value[0]))

    def type_unit(self, _value: List[Any]) -> UnitType:
        """
        `()` type.
        """
        return UnitType()

    def type_option(self, value: List[CairoType]) -> OptionType:
        """
        Option includes an information about which type it eventually represents.
        `Optional` is added in case of the unit type.
        """
        return OptionType(value[0])

    def type_array(self, value: List[CairoType]) -> ArrayType:
        """
        Array contains values of type under `value[0]`.
        """
        return ArrayType(value[0])

    def type_span(self, value: List[CairoType]) -> ArrayType:
        """
        Span contains values of type under `value[0]`.
        """
        return ArrayType(value[0])

    def type_identifier(self, tokens: List[Token]) -> TypeIdentifier:
        """
        Structs and enums are defined as follows: (IDENTIFIER | "::")+ [some not important info]
        where IDENTIFIER is a string.

        Tokens would contain strings and types (if it is present).
        We are interested only in the strings because a structure (or enum) name can be built from them.
        """
        name = "::".join(token for token in tokens if isinstance(token, str))
        if name in self.type_identifiers:
            # Substitute an already-resolved type when one is known by this name.
            return self.type_identifiers[name]
        return TypeIdentifier(name)

    def type_contract_address(self, _value: List[Any]) -> FeltType:
        """
        ContractAddress is represented by the felt252.
        """
        return FeltType()

    def type_class_hash(self, _value: List[Any]) -> FeltType:
        """
        ClassHash is represented by the felt252.
        """
        return FeltType()

    def type_storage_address(self, _value: List[Any]) -> FeltType:
        """
        StorageAddress is represented by the felt252.
        """
        return FeltType()

    def tuple(self, types: List[CairoType]) -> TupleType:
        """
        Tuple contains values defined in the `types` argument.

        (Shadows the builtin `tuple` on purpose: the name must match the
        grammar rule for lark's dispatch.)
        """
        return TupleType(types)
def parse(
    code: str,
    type_identifiers,
) -> CairoType:
    """
    Parse the given string and return a CairoType.

    :param code: Cairo ABI type string, e.g. "core::felt252".
    :param type_identifiers: mapping of already-known type names, forwarded
        to ParserTransformer.
    :return: parsed CairoType.
    """
    # Bug fix: the module never imports the name `lark` (only names *from* it),
    # so `lark.Lark` raised NameError; use the imported `Lark` class directly.
    grammar_parser = Lark(
        grammar=ABI_EBNF,
        start="type",
        parser="earley",
    )

    parsed_lark_tree = grammar_parser.parse(code)

    parser_transformer = ParserTransformer(type_identifiers)
    cairo_type = parser_transformer.transform(parsed_lark_tree)

    return cairo_type

View File

@@ -0,0 +1,132 @@
from ....marshmallow import Schema, fields
from ....marshmallow_oneofschema import OneOfSchema
from .shape import (
CONSTRUCTOR_ENTRY,
DATA_KIND,
ENUM_ENTRY,
EVENT_ENTRY,
FUNCTION_ENTRY,
IMPL_ENTRY,
INTERFACE_ENTRY,
L1_HANDLER_ENTRY,
NESTED_KIND,
STRUCT_ENTRY,
)
class TypeSchema(Schema):
    """Schema for a bare type reference: {"type": "<cairo type string>"}."""

    type = fields.String(data_key="type", required=True)
class TypedParameterSchema(TypeSchema):
    """Schema for a named, typed parameter: {"name": ..., "type": ...}."""

    name = fields.String(data_key="name", required=True)
class FunctionBaseSchema(Schema):
    """Common fields of function-like ABI entries (functions, L1 handlers)."""

    name = fields.String(data_key="name", required=True)
    inputs = fields.List(
        fields.Nested(TypedParameterSchema()), data_key="inputs", required=True
    )
    outputs = fields.List(
        fields.Nested(TypeSchema()), data_key="outputs", required=True
    )
    # NOTE(review): marshmallow 3 renamed `default` to `dump_default`/`load_default`
    # — confirm the vendored marshmallow version still accepts `default`.
    state_mutability = fields.String(data_key="state_mutability", default=None)
class FunctionAbiEntrySchema(FunctionBaseSchema):
    """Schema for a top-level "function" ABI entry."""

    type = fields.Constant(FUNCTION_ENTRY, data_key="type", required=True)
class ConstructorAbiEntrySchema(Schema):
    """Schema for a "constructor" ABI entry (inputs only, no outputs)."""

    type = fields.Constant(CONSTRUCTOR_ENTRY, data_key="type", required=True)
    name = fields.String(data_key="name", required=True)
    inputs = fields.List(
        fields.Nested(TypedParameterSchema()), data_key="inputs", required=True
    )
class L1HandlerAbiEntrySchema(FunctionBaseSchema):
    """Schema for an "l1_handler" ABI entry; shaped like a function."""

    type = fields.Constant(L1_HANDLER_ENTRY, data_key="type", required=True)
class EventStructMemberSchema(TypedParameterSchema):
    """Schema for a struct-event member; `kind` is always "data"."""

    kind = fields.Constant(DATA_KIND, data_key="kind", required=True)
class EventStructAbiEntrySchema(Schema):
    """Schema for an "event" ABI entry of kind "struct"."""

    type = fields.Constant(EVENT_ENTRY, data_key="type", required=True)
    name = fields.String(data_key="name", required=True)
    kind = fields.Constant(STRUCT_ENTRY, data_key="kind", required=True)
    members = fields.List(
        fields.Nested(EventStructMemberSchema()), data_key="members", required=True
    )
class EventEnumVariantSchema(TypedParameterSchema):
    """Schema for an enum-event variant; `kind` is always "nested"."""

    kind = fields.Constant(NESTED_KIND, data_key="kind", required=True)
class EventEnumAbiEntrySchema(Schema):
    """Schema for an "event" ABI entry of kind "enum"."""

    type = fields.Constant(EVENT_ENTRY, data_key="type", required=True)
    name = fields.String(data_key="name", required=True)
    kind = fields.Constant(ENUM_ENTRY, data_key="kind", required=True)
    variants = fields.List(
        fields.Nested(EventEnumVariantSchema()), data_key="variants", required=True
    )
class EventAbiEntrySchema(OneOfSchema):
    """Dispatches event entries to struct/enum schemas based on "kind"."""

    type_field = "kind"
    # Keep "kind" in the deserialized output instead of stripping it.
    type_field_remove = False
    type_schemas = {
        STRUCT_ENTRY: EventStructAbiEntrySchema,
        ENUM_ENTRY: EventEnumAbiEntrySchema,
    }
class StructAbiEntrySchema(Schema):
    """Schema for a top-level "struct" ABI entry."""

    type = fields.Constant(STRUCT_ENTRY, data_key="type", required=True)
    name = fields.String(data_key="name", required=True)
    members = fields.List(
        fields.Nested(TypedParameterSchema()), data_key="members", required=True
    )
class EnumAbiEntrySchema(Schema):
    """Schema for a top-level "enum" ABI entry."""

    type = fields.Constant(ENUM_ENTRY, data_key="type", required=True)
    name = fields.String(data_key="name", required=True)
    # Bug fix: `data_key`/`required` were passed to fields.Nested instead of
    # fields.List (misplaced parenthesis), so "variants" was neither mapped
    # nor required. Matches the pattern used by StructAbiEntrySchema.
    variants = fields.List(
        fields.Nested(TypedParameterSchema()), data_key="variants", required=True
    )
class ImplAbiEntrySchema(Schema):
    """Schema for an "impl" ABI entry linking an implementation to its interface."""

    type = fields.Constant(IMPL_ENTRY, data_key="type", required=True)
    name = fields.String(data_key="name", required=True)
    interface_name = fields.String(data_key="interface_name", required=True)
class InterfaceAbiEntrySchema(Schema):
    """Schema for an "interface" ABI entry."""

    type = fields.Constant(INTERFACE_ENTRY, data_key="type", required=True)
    name = fields.String(data_key="name", required=True)
    # Bug fix: `data_key`/`required` were passed to fields.Nested instead of
    # fields.List (misplaced parenthesis), so "items" was neither mapped nor
    # required. For now only functions can be defined here.
    items = fields.List(
        fields.Nested(FunctionAbiEntrySchema()), data_key="items", required=True
    )
class ContractAbiEntrySchema(OneOfSchema):
    """Dispatches any top-level ABI entry to its schema based on "type"."""

    # Keep the discriminator field in the deserialized output.
    type_field_remove = False
    type_schemas = {
        FUNCTION_ENTRY: FunctionAbiEntrySchema,
        EVENT_ENTRY: EventAbiEntrySchema,
        STRUCT_ENTRY: StructAbiEntrySchema,
        ENUM_ENTRY: EnumAbiEntrySchema,
        CONSTRUCTOR_ENTRY: ConstructorAbiEntrySchema,
        L1_HANDLER_ENTRY: L1HandlerAbiEntrySchema,
        IMPL_ENTRY: ImplAbiEntrySchema,
        INTERFACE_ENTRY: InterfaceAbiEntrySchema,
    }

View File

@@ -0,0 +1,107 @@
from __future__ import annotations
from typing import List, Literal, Optional, TypedDict, Union
# Discriminator values for the "type" key of top-level ABI entries.
STRUCT_ENTRY = "struct"
EVENT_ENTRY = "event"
FUNCTION_ENTRY = "function"
ENUM_ENTRY = "enum"
CONSTRUCTOR_ENTRY = "constructor"
L1_HANDLER_ENTRY = "l1_handler"
IMPL_ENTRY = "impl"
INTERFACE_ENTRY = "interface"

# Discriminator values for the "kind" key of event members/variants.
DATA_KIND = "data"
NESTED_KIND = "nested"
class TypeDict(TypedDict):
    # Raw Cairo type string as it appears in the ABI.
    type: str
class TypedParameterDict(TypeDict):
    # A named, typed parameter (function input, struct member, ...).
    name: str
class StructDict(TypedDict):
    # Top-level "struct" ABI entry.
    type: Literal["struct"]
    name: str
    members: List[TypedParameterDict]
class FunctionBaseDict(TypedDict):
    # Common shape of function-like entries (functions and L1 handlers).
    name: str
    inputs: List[TypedParameterDict]
    outputs: List[TypeDict]
    # NOTE(review): Optional means the VALUE may be None, not that the key may
    # be absent — entries lacking the key would need typing.NotRequired; confirm.
    state_mutability: Optional[Literal["external", "view"]]
class FunctionDict(FunctionBaseDict):
    # Top-level "function" ABI entry.
    type: Literal["function"]
class ConstructorDict(TypedDict):
    # "constructor" ABI entry; has inputs but no outputs.
    type: Literal["constructor"]
    name: str
    inputs: List[TypedParameterDict]
class L1HandlerDict(FunctionBaseDict):
    # "l1_handler" ABI entry; shaped like a function.
    type: Literal["l1_handler"]
class EventBaseDict(TypedDict):
    # Common shape of "event" ABI entries (struct and enum kinds).
    type: Literal["event"]
    name: str
class EventStructMemberDict(TypedParameterDict):
    # Member of a struct-kind event; "kind" is always "data".
    kind: Literal["data"]
class EventStructDict(EventBaseDict):
    # Event entry of kind "struct"; fields live under "members".
    kind: Literal["struct"]
    members: List[EventStructMemberDict]
class EventEnumVariantDict(TypedParameterDict):
    # Variant of an enum-kind event; "kind" is always "nested".
    kind: Literal["nested"]
class EventEnumDict(EventBaseDict):
    # Event entry of kind "enum"; fields live under "variants".
    kind: Literal["enum"]
    variants: List[EventEnumVariantDict]
# Discriminated union over the two event entry shapes (by "kind").
EventDict = Union[EventStructDict, EventEnumDict]
class EnumDict(TypedDict):
    # Top-level "enum" ABI entry.
    type: Literal["enum"]
    name: str
    variants: List[TypedParameterDict]
class ImplDict(TypedDict):
    # "impl" ABI entry linking an implementation to its interface.
    type: Literal["impl"]
    name: str
    interface_name: str
class InterfaceDict(TypedDict):
    # "interface" ABI entry.
    type: Literal["interface"]
    name: str
    items: List[FunctionDict]  # for now only functions can be defined here
# Union of every possible top-level ABI entry shape.
AbiDictEntry = Union[
    StructDict,
    FunctionDict,
    EventDict,
    EnumDict,
    ConstructorDict,
    L1HandlerDict,
    ImplDict,
    InterfaceDict,
]
# A whole ABI is a list of such entries.
AbiDictList = List[AbiDictEntry]

View File

@@ -0,0 +1,123 @@
from __future__ import annotations
from abc import ABC
from collections import OrderedDict
from dataclasses import dataclass
from typing import List
class CairoType(ABC):
    """
    Base type for all Cairo type representations. All types extend it.

    Concrete subclasses are (frozen-by-convention) dataclasses holding only
    the data needed to describe a type.
    """
@dataclass
class FeltType(CairoType):
    """
    Type representation of Cairo field element.
    """
@dataclass
class BoolType(CairoType):
    """
    Type representation of Cairo boolean.
    """
@dataclass
class TupleType(CairoType):
    """
    Type representation of Cairo tuples without named fields.
    """

    types: List[CairoType]  #: types of every tuple element.
@dataclass
class NamedTupleType(CairoType):
    """
    Type representation of Cairo tuples with named fields.
    """

    types: OrderedDict[str, CairoType]  #: types of every tuple member.
@dataclass
class ArrayType(CairoType):
    """
    Type representation of Cairo arrays.
    """

    inner_type: CairoType  #: type of element inside array.
@dataclass
class StructType(CairoType):
    """
    Type representation of Cairo structures.
    """

    name: str  #: Structure name
    # We need ordered dict, because it is important in serialization
    types: OrderedDict[str, CairoType]  #: types of every structure member.
@dataclass
class EnumType(CairoType):
    """
    Type representation of Cairo enums.
    """

    name: str  # enum name
    # Ordered: variant order matters for serialization.
    variants: OrderedDict[str, CairoType]
@dataclass
class OptionType(CairoType):
    """
    Type representation of Cairo options.
    """

    type: CairoType  # the wrapped (Some) type
@dataclass
class UintType(CairoType):
    """
    Type representation of Cairo unsigned integers.
    """

    bits: int  # bit width of the integer type (e.g. 8, 32, 128, 256)

    def check_range(self, value: int):
        """
        Utility method checking if the `value` is in range.

        :raises ValueError: when ``value`` is outside ``[0, 2**bits)``.
        """
        # Bug fix: the body was empty (docstring only), silently accepting
        # any value; enforce the documented range.
        if not 0 <= value < 2**self.bits:
            raise ValueError(
                f"Uint{self.bits} is expected to be in range [0;2**{self.bits}), got: {value}."
            )
@dataclass
class TypeIdentifier(CairoType):
    """
    Type representation of Cairo identifiers (not yet resolved to a concrete type).
    """

    name: str  # identifier as written in the ABI
@dataclass
class UnitType(CairoType):
    """
    Type representation of Cairo unit `()`.
    """
@dataclass
class EventType(CairoType):
    """
    Type representation of Cairo Event.
    """

    name: str  # event name
    # Ordered: field order matters for serialization.
    types: OrderedDict[str, CairoType]

View File

@@ -0,0 +1,77 @@
import dataclasses
from typing import List, Optional
class CairoType:
    """
    Base class for cairo types.

    (Deprecated Cairo 0 type model, mirroring the cairo-lang AST classes.)
    """
@dataclasses.dataclass
class TypeFelt(CairoType):
    # The `felt` primitive type.
    pass
@dataclasses.dataclass
class TypeCodeoffset(CairoType):
    # The `codeoffset` primitive type.
    pass
@dataclasses.dataclass
class TypePointer(CairoType):
    # A pointer type, e.g. "felt*"; `pointee` is the pointed-to type.
    pointee: CairoType
@dataclasses.dataclass
class TypeIdentifier(CairoType):
    """
    Represents a name of an unresolved type.
    This type can be resolved to TypeStruct or TypeDefinition.
    """

    name: str  # dotted identifier as written in the source
@dataclasses.dataclass
class TypeStruct(CairoType):
    # A resolved structure type; `scope` is its fully-qualified name.
    scope: str
@dataclasses.dataclass
class TypeFunction(CairoType):
    """
    Represents a type of a function.
    """

    scope: str  # fully-qualified name of the function
@dataclasses.dataclass
class TypeTuple(CairoType):
    """
    Represents a type of a named or unnamed tuple.
    For example, "(felt, felt*)" or "(a: felt, b: felt*)".
    """

    @dataclasses.dataclass
    class Item(CairoType):
        """
        Represents a possibly named type item of a TypeTuple.
        For example: "felt" or "a: felt".
        """

        name: Optional[str]  # None for an unnamed item
        typ: CairoType

    members: List["TypeTuple.Item"]  # tuple items, in declaration order
    # Excluded from hash/equality: a trailing comma is formatting, not identity.
    has_trailing_comma: bool = dataclasses.field(hash=False, compare=False)

    @property
    def is_named(self) -> bool:
        # True when every member carries a name (vacuously True when empty).
        return all(member.name is not None for member in self.members)
@dataclasses.dataclass
class ExprIdentifier(CairoType):
    # A (possibly dotted) identifier expression produced by the parser.
    name: str

View File

@@ -0,0 +1,46 @@
from ....lark import Lark
from .cairo_types import CairoType
from .parser_transformer import ParserTransformer
CAIRO_EBNF = """
%import common.WS_INLINE
%ignore WS_INLINE
IDENTIFIER: /[a-zA-Z_][a-zA-Z_0-9]*/
_DBL_STAR: "**"
COMMA: ","
?type: non_identifier_type
| identifier -> type_struct
comma_separated{item}: item? (COMMA item)* COMMA?
named_type: identifier (":" type)? | non_identifier_type
non_identifier_type: "felt" -> type_felt
| "codeoffset" -> type_codeoffset
| type "*" -> type_pointer
| type _DBL_STAR -> type_pointer2
| "(" comma_separated{named_type} ")" -> type_tuple
identifier: IDENTIFIER ("." IDENTIFIER)*
"""
def parse(code: str) -> CairoType:
    """
    Parses the given string and returns a CairoType.

    :param code: Cairo 0 type string, e.g. "felt*".
    :return: parsed CairoType.
    """
    lalr_parser = Lark(
        grammar=CAIRO_EBNF,
        start=["type"],
        parser="lalr",
    )
    tree = lalr_parser.parse(code)
    return ParserTransformer().transform(tree)

View File

@@ -0,0 +1,138 @@
import dataclasses
from typing import Optional, Tuple
from ....lark import Token, Transformer, v_args
from .cairo_types import (
CairoType,
ExprIdentifier,
TypeCodeoffset,
TypeFelt,
TypeIdentifier,
TypePointer,
TypeStruct,
TypeTuple,
)
@dataclasses.dataclass
class ParserContext:
    """
    Represents information that affects the parsing process.
    """

    # If True, treat type identifiers as resolved.
    resolved_types: bool = False
class ParserError(Exception):
    """
    Base exception for parsing process.
    """
@dataclasses.dataclass
class CommaSeparated:
    """
    Represents a list of comma separated values, such as expressions or types.
    """

    args: list  # the parsed values, commas stripped
    has_trailing_comma: bool  # True when the list ended with a comma
class ParserTransformer(Transformer):
    """
    Transforms the lark tree into an AST based on the classes defined in cairo_types.py.

    Method names correspond to grammar rule names; lark dispatches each parsed
    rule to the method of the same name.
    """

    # pylint: disable=unused-argument, no-self-use

    def __init__(self):
        super().__init__()
        self.parser_context = ParserContext()

    def __default__(self, data: str, children, meta):
        # Fail loudly on grammar rules without a dedicated handler.
        raise TypeError(f"Unable to parse tree node of type {data}")

    def comma_separated(self, value) -> CommaSeparated:
        """
        Collect values of a `comma_separated{...}` rule, enforcing strict
        value/comma alternation.
        """
        # Tri-state: None = nothing consumed yet, True = just saw a comma,
        # False = just saw a value.
        saw_comma = None
        args: list = []
        for v in value:
            if isinstance(v, Token) and v.type == "COMMA":
                if saw_comma is not False:
                    # Comma with no value before it (leading or doubled comma).
                    raise ParserError("Unexpected comma.")
                saw_comma = True
            else:
                if saw_comma is False:
                    # Two consecutive values with no separating comma.
                    raise ParserError("Expected a comma before this expression.")
                args.append(v)
                # Reset state.
                saw_comma = False
        if saw_comma is None:
            # Empty list: no trailing comma by definition.
            saw_comma = False
        return CommaSeparated(args=args, has_trailing_comma=saw_comma)

    # Types.
    @v_args(meta=True)
    def named_type(self, meta, value) -> TypeTuple.Item:
        """Build a (possibly named) tuple item from `identifier (":" type)?`."""
        name: Optional[str]
        if len(value) == 1:
            # Unnamed type.
            (typ,) = value
            name = None
            if isinstance(typ, ExprIdentifier):
                # A bare identifier used as a type, e.g. "(MyStruct,)".
                typ = self.type_struct([typ])
        elif len(value) == 2:
            # Named type.
            identifier, typ = value
            assert isinstance(identifier, ExprIdentifier)
            assert isinstance(typ, CairoType)
            if "." in identifier.name:
                raise ParserError("Unexpected . in name.")
            name = identifier.name
        else:
            raise NotImplementedError(f"Unexpected number of values. {value}")
        return TypeTuple.Item(name=name, typ=typ)

    @v_args(meta=True)
    def type_felt(self, meta, value):
        return TypeFelt()

    @v_args(meta=True)
    def type_codeoffset(self, meta, value):
        return TypeCodeoffset()

    def type_struct(self, value):
        assert len(value) == 1 and isinstance(value[0], ExprIdentifier)
        if self.parser_context.resolved_types:
            # If parser_context.resolved_types is True, assume that the type is a struct.
            return TypeStruct(scope=value[0].name)
        return TypeIdentifier(name=value[0].name)

    @v_args(meta=True)
    def type_pointer(self, meta, value):
        # "T*" -> pointer to T.
        return TypePointer(pointee=value[0])

    @v_args(meta=True)
    def type_pointer2(self, meta, value):
        # "T**" -> pointer to pointer to T.
        return TypePointer(pointee=TypePointer(pointee=value[0]))

    @v_args(meta=True)
    def type_tuple(self, meta, value: Tuple[CommaSeparated]):
        (lst,) = value
        return TypeTuple(members=lst.args, has_trailing_comma=lst.has_trailing_comma)

    @v_args(meta=True)
    def identifier(self, meta, value):
        # Join dotted parts back into a single name.
        return ExprIdentifier(name=".".join(x.value for x in value))

    @v_args(meta=True)
    def identifier_def(self, meta, value):
        return ExprIdentifier(name=value[0].value)

View File

@@ -0,0 +1,64 @@
from typing import List
from ..constants import FIELD_PRIME
CairoData = List[int]

MAX_UINT256 = (1 << 256) - 1
MIN_UINT256 = 0


def uint256_range_check(value: int):
    """Raise ValueError when *value* does not fit in an unsigned 256-bit integer."""
    if value < MIN_UINT256 or value > MAX_UINT256:
        raise ValueError(
            f"Uint256 is expected to be in range [0;2**256), got: {value}."
        )
MIN_FELT = -FIELD_PRIME // 2
MAX_FELT = FIELD_PRIME // 2


def is_in_felt_range(value: int) -> bool:
    """Return True when *value* lies in the field range [0, FIELD_PRIME)."""
    return 0 <= value < FIELD_PRIME


def cairo_vm_range_check(value: int):
    """Raise ValueError when *value* is outside the Cairo VM field range."""
    if is_in_felt_range(value):
        return
    raise ValueError(
        f"Felt is expected to be in range [0; {FIELD_PRIME}), got: {value}."
    )
def encode_shortstring(text: str) -> int:
    """
    A function which encodes short string value (at most 31 characters) into cairo felt (MSB as first character)

    :param text: A short string value in python
    :return: Short string value encoded into felt
    """
    if len(text) > 31:
        raise ValueError(
            f"Shortstring cannot be longer than 31 characters, got: {len(text)}."
        )

    try:
        encoded = text.encode("ascii")
    except UnicodeEncodeError as u_err:
        raise ValueError(f"Expected an ascii string. Found: {repr(text)}.") from u_err

    result = int.from_bytes(encoded, "big")
    cairo_vm_range_check(result)
    return result
def decode_shortstring(value: int) -> str:
    """
    A function which decodes a felt value to short string (at most 31 characters)

    :param value: A felt value
    :return: Decoded string which is corresponds to that felt
    """
    cairo_vm_range_check(value)
    raw = value.to_bytes(31, byteorder="big")
    return "".join(map(chr, raw)).lstrip("\x00")

View File

@@ -0,0 +1,121 @@
from __future__ import annotations
from collections import OrderedDict
from typing import Dict, cast
from .deprecated_parse import cairo_types as cairo_lang_types
from .data_types import (
ArrayType,
CairoType,
FeltType,
NamedTupleType,
StructType,
TupleType,
)
from .deprecated_parse.parser import parse
class UnknownCairoTypeError(ValueError):
"""
Error thrown when TypeParser finds type that was not declared prior to parsing.
"""
type_name: str
def __init__(self, type_name: str):
super().__init__(f"Type '{type_name}' is not defined")
self.type_name = type_name
class TypeParser:
"""
Low level utility class for parsing Cairo types that can be used in external methods.
"""
defined_types: Dict[str, StructType]
def __init__(self, defined_types: Dict[str, StructType]):
"""
TypeParser constructor.
:param defined_types: dictionary containing all defined types. For now, they can only be structures.
"""
self.defined_types = defined_types
for name, struct in defined_types.items():
if name != struct.name:
raise ValueError(
f"Keys must match name of type, '{name}' != '{struct.name}'."
)
def parse_inline_type(self, type_string: str) -> CairoType:
"""
Inline type is one that can be used inline, for instance as return type. For instance
(a: Uint256, b: felt*, c: (felt, felt)). Structure can only be referenced in inline type, can't be defined
this way.
:param type_string: type to parse.
"""
parsed = parse(type_string)
return self._transform_cairo_lang_type(parsed)
def _transform_cairo_lang_type(
self, cairo_type: cairo_lang_types.CairoType
) -> CairoType:
"""
For now, we use parse function from cairo-lang package. It will be replaced in the future, but we need to hide
it from the users.
This function takes types returned by cairo-lang package and maps them to our type classes.
:param cairo_type: type returned from parse_type function.
:return: CairoType defined by our package.
"""
if isinstance(cairo_type, cairo_lang_types.TypeFelt):
return FeltType()
if isinstance(cairo_type, cairo_lang_types.TypePointer):
return ArrayType(self._transform_cairo_lang_type(cairo_type.pointee))
if isinstance(cairo_type, cairo_lang_types.TypeIdentifier):
return self._get_struct(str(cairo_type.name))
if isinstance(cairo_type, cairo_lang_types.TypeTuple):
# Cairo returns is_named when there are no members
if cairo_type.is_named and len(cairo_type.members) != 0:
assert all(member.name is not None for member in cairo_type.members)
return NamedTupleType(
OrderedDict(
(
cast(
str, member.name
), # without that pyright is complaining
self._transform_cairo_lang_type(member.typ),
)
for member in cairo_type.members
)
)
return TupleType(
[
self._transform_cairo_lang_type(member.typ)
for member in cairo_type.members
]
)
# Contracts don't support codeoffset as input/output type, user can only use it if it was defined in types
if isinstance(cairo_type, cairo_lang_types.TypeCodeoffset):
return self._get_struct("codeoffset")
# Other options are: TypeFunction, TypeStruct
# Neither of them are possible. In particular TypeStruct is not possible because we parse structs without
# info about other structs, so they will be just TypeIdentifier (structure that was not parsed).
# This is an error of our logic, so we throw a RuntimeError.
raise RuntimeError(
f"Received unknown type '{cairo_type}' from parser."
) # pragma: no cover
def _get_struct(self, name: str):
if name not in self.defined_types:
raise UnknownCairoTypeError(name)
return self.defined_types[name]

View File

@@ -0,0 +1,59 @@
from __future__ import annotations
from typing import Dict, Union
from ...abi.v1.parser_transformer import parse
from ..data_types import CairoType, EnumType, StructType, TypeIdentifier
class UnknownCairoTypeError(ValueError):
"""
Error thrown when TypeParser finds type that was not declared prior to parsing.
"""
type_name: str
def __init__(self, type_name: str):
super().__init__(
# pylint: disable=line-too-long
f"Type '{type_name}' is not defined. Please report this issue at https://github.com/software-mansion/starknet.py/issues"
)
self.type_name = type_name
class TypeParser:
"""
Low level utility class for parsing Cairo types that can be used in external methods.
"""
defined_types: Dict[str, Union[StructType, EnumType]]
def __init__(self, defined_types: Dict[str, Union[StructType, EnumType]]):
"""
TypeParser constructor.
:param defined_types: dictionary containing all defined types. For now, they can only be structures.
"""
self.defined_types = defined_types
for name, defined_type in defined_types.items():
if name != defined_type.name:
raise ValueError(
f"Keys must match name of type, '{name}' != '{defined_type.name}'."
)
def parse_inline_type(self, type_string: str) -> CairoType:
"""
Inline type is one that can be used inline, for instance as return type. For instance
(core::felt252, (), (core::felt252,)). Structure can only be referenced in inline type, can't be defined
this way.
:param type_string: type to parse.
"""
parsed = parse(type_string, self.defined_types)
if isinstance(parsed, TypeIdentifier):
for defined_name in self.defined_types.keys():
if parsed.name == defined_name.split("<")[0].strip(":"):
return self.defined_types[defined_name]
raise UnknownCairoTypeError(parsed.name)
return parsed

View File

@@ -0,0 +1,77 @@
from __future__ import annotations
from typing import Dict, Union
from ...abi.v2.parser_transformer import parse
from ..data_types import (
CairoType,
EnumType,
EventType,
StructType,
TypeIdentifier,
)
class UnknownCairoTypeError(ValueError):
"""
Error thrown when TypeParser finds type that was not declared prior to parsing.
"""
type_name: str
def __init__(self, type_name: str):
super().__init__(
# pylint: disable=line-too-long
f"Type '{type_name}' is not defined. Please report this issue at https://github.com/software-mansion/starknet.py/issues"
)
self.type_name = type_name
class TypeParser:
"""
Low level utility class for parsing Cairo types that can be used in external methods.
"""
defined_types: Dict[str, Union[StructType, EnumType, EventType]]
def __init__(
self, defined_types: Dict[str, Union[StructType, EnumType, EventType]]
):
"""
TypeParser constructor.
:param defined_types: dictionary containing all defined types. For now, they can only be structures.
"""
self.defined_types = defined_types
for name, defined_type in defined_types.items():
if name != defined_type.name:
raise ValueError(
f"Keys must match name of type, '{name}' != '{defined_type.name}'."
)
def update_defined_types(
self, defined_types: Dict[str, Union[StructType, EnumType, EventType]]
) -> None:
self.defined_types.update(defined_types)
def add_defined_type(
self, defined_type: Union[StructType, EnumType, EventType]
) -> None:
self.defined_types.update({defined_type.name: defined_type})
def parse_inline_type(self, type_string: str) -> CairoType:
"""
Inline type is one that can be used inline, for instance as return type. For instance
(core::felt252, (), (core::felt252,)). Structure can only be referenced in inline type, can't be defined
this way.
:param type_string: type to parse.
"""
parsed = parse(type_string, self.defined_types)
if isinstance(parsed, TypeIdentifier):
for defined_name in self.defined_types.keys():
if parsed.name == defined_name.split("<")[0].strip(":"):
return self.defined_types[defined_name]
raise UnknownCairoTypeError(parsed.name)
return parsed

View File

@@ -0,0 +1,7 @@
# utils to use starknet library in ccxt
from .constants import EC_ORDER
from ..starkware.crypto.signature import grind_key
def get_private_key_from_eth_signature(eth_signature_hex: str) -> int:
    """Derive a Stark private key from the r-component of an Ethereum signature."""
    signature = eth_signature_hex
    if signature[0:2] == '0x':
        signature = signature[2:]
    # First 64 hex chars are the signature's r value.
    return grind_key(int(signature[0:64], 16), EC_ORDER)

View File

@@ -0,0 +1,15 @@
from typing import Literal, Union
def int_from_hex(number: Union[str, int]) -> int:
    """Return *number* unchanged if already an int, otherwise parse it as base-16."""
    if isinstance(number, int):
        return number
    return int(number, 16)
def int_from_bytes(
    value: bytes,
    byte_order: Literal["big", "little"] = "big",
    signed: bool = False,
) -> int:
    """
    Converts the given bytes object (parsed according to the given byte order) to an integer.
    """
    return int.from_bytes(value, byte_order, signed=signed)

View File

@@ -0,0 +1,39 @@
from pathlib import Path
# Address came from starkware-libs/starknet-addresses repository: https://github.com/starkware-libs/starknet-addresses
FEE_CONTRACT_ADDRESS = (
    "0x049d36570d4e46f48e99674bd3fcc84644ddd6b96f7c741b1562b82f9e004dc7"
)

DEFAULT_DEPLOYER_ADDRESS = (
    "0x041a78e741e5aF2fEc34B695679bC6891742439f7AFB8484Ecd7766661aD02BF"
)

API_VERSION = 0

# JSON-RPC error codes returned by Starknet nodes.
RPC_CONTRACT_NOT_FOUND_ERROR = 20
RPC_INVALID_MESSAGE_SELECTOR_ERROR = 21
RPC_CLASS_HASH_NOT_FOUND_ERROR = 28
RPC_CONTRACT_ERROR = 40

DEFAULT_ENTRY_POINT_NAME = "__default__"
DEFAULT_L1_ENTRY_POINT_NAME = "__l1_default__"
DEFAULT_ENTRY_POINT_SELECTOR = 0

DEFAULT_DECLARE_SENDER_ADDRESS = 1

# MAX_STORAGE_ITEM_SIZE and ADDR_BOUND must be consistent with the corresponding constant in
# starkware/starknet/common/storage.cairo.
MAX_STORAGE_ITEM_SIZE = 256
ADDR_BOUND = 2**251 - MAX_STORAGE_ITEM_SIZE

# Stark field prime and EC group order.
FIELD_PRIME = 0x800000000000011000000000000000000000000000000000000000000000001
EC_ORDER = 0x800000000000010FFFFFFFFFFFFFFFFB781126DCAE7B2321E66A241ADC64D2F

# From cairo-lang
# int_from_bytes(b"STARKNET_CONTRACT_ADDRESS")
CONTRACT_ADDRESS_PREFIX = 523065374597054866729014270389667305596563390979550329787219

L2_ADDRESS_UPPER_BOUND = 2**251 - 256

QUERY_VERSION_BASE = 2**128

ROOT_PATH = Path(__file__).parent

View File

@@ -0,0 +1,79 @@
from typing import Sequence
from ..constants import CONTRACT_ADDRESS_PREFIX, L2_ADDRESS_UPPER_BOUND
from .utils import (
HEX_PREFIX,
_starknet_keccak,
compute_hash_on_elements,
encode_uint,
get_bytes_length,
)
def compute_address(
    *,
    class_hash: int,
    constructor_calldata: Sequence[int],
    salt: int,
    deployer_address: int = 0,
) -> int:
    """
    Computes the contract address in the Starknet network - a unique identifier of the contract.

    :param class_hash: class hash of the contract
    :param constructor_calldata: calldata for the contract constructor
    :param salt: salt used to calculate contract address
    :param deployer_address: address of the deployer (if not provided default 0 is used)
    :return: Contract's address
    """
    # Hash chain over the address components, in the order mandated by the protocol,
    # reduced modulo the L2 address bound.
    components = [
        CONTRACT_ADDRESS_PREFIX,
        deployer_address,
        salt,
        class_hash,
        compute_hash_on_elements(data=constructor_calldata),
    ]
    return compute_hash_on_elements(data=components) % L2_ADDRESS_UPPER_BOUND
def get_checksum_address(address: str) -> str:
    """
    Outputs formatted checksum address.

    Follows implementation of starknet.js. It is not compatible with EIP55 as it treats hex string as encoded number,
    instead of encoding it as ASCII string.

    :param address: Address to encode
    :return: Checksum address
    """
    if not address.lower().startswith(HEX_PREFIX):
        raise ValueError(f"{address} is not a valid hexadecimal address.")

    int_address = int(address, 16)
    # Hash the address bytes; individual digest bits decide the casing below.
    address_hash = _starknet_keccak(
        encode_uint(int_address, get_bytes_length(int_address))
    )

    checksummed_chars = []
    for i, char in enumerate(address[2:].zfill(64)):
        # Uppercase a hex letter exactly when the corresponding digest bit is set.
        bit_is_set = (address_hash >> 256 - 4 * i - 1) & 1
        checksummed_chars.append(
            char.upper() if char.isalpha() and bit_is_set else char
        )

    return HEX_PREFIX + "".join(checksummed_chars)
def is_checksum_address(address: str) -> bool:
    """
    Checks if provided string is in a checksum address format.
    """
    expected = get_checksum_address(address)
    return expected == address

View File

@@ -0,0 +1,111 @@
# File is copied from
# https://github.com/starkware-libs/cairo-lang/blob/v0.13.1/src/starkware/starknet/core/os/contract_class/compiled_class_hash_objects.py
import dataclasses
import itertools
from abc import ABC, abstractmethod
from typing import Any, List, Union
from poseidon_py.poseidon_hash import poseidon_hash_many
class BytecodeSegmentStructure(ABC):
    """
    Represents the structure of the bytecode to allow loading it partially into the OS memory.
    See the documentation of the OS function `bytecode_hash_node` in `compiled_class.cairo`
    for more details.
    """

    @abstractmethod
    def hash(self) -> int:
        """
        Computes the hash of the node.
        """

    def bytecode_with_skipped_segments(self):
        """
        Returns the bytecode of the node.
        Skipped segments are replaced with [-1, -2, -2, -2, ...].
        """
        # Collects into a fresh list via the subclass-provided appender.
        res: List[int] = []
        self.add_bytecode_with_skipped_segments(res)
        return res

    @abstractmethod
    def add_bytecode_with_skipped_segments(self, data: List[int]):
        """
        Same as bytecode_with_skipped_segments, but appends the result to the given list.
        """
@dataclasses.dataclass
class BytecodeLeaf(BytecodeSegmentStructure):
    """
    Represents a leaf in the bytecode segment tree.
    """

    # Raw bytecode words held by this leaf.
    data: List[int]

    def hash(self) -> int:
        # Leaf hash is the Poseidon hash of the raw words.
        return poseidon_hash_many(self.data)

    def add_bytecode_with_skipped_segments(self, data: List[int]):
        # Leaves are never skipped at this level; append all words.
        data.extend(self.data)
@dataclasses.dataclass
class BytecodeSegmentedNode(BytecodeSegmentStructure):
    """
    Represents an internal node in the bytecode segment tree.
    Each child can be loaded into memory or skipped.
    """

    segments: List["BytecodeSegment"]

    def hash(self) -> int:
        # Poseidon hash over the flattened (segment_length, child_hash) pairs.
        # NOTE(review): the trailing +1 presumably domain-separates internal-node
        # hashes from leaf hashes - confirm against `compiled_class.cairo`.
        return (
            poseidon_hash_many(
                itertools.chain(  # pyright: ignore
                    *[
                        (node.segment_length, node.inner_structure.hash())
                        for node in self.segments
                    ]
                )
            )
            + 1
        )

    def add_bytecode_with_skipped_segments(self, data: List[int]):
        for segment in self.segments:
            if segment.is_used:
                segment.inner_structure.add_bytecode_with_skipped_segments(data)
            else:
                # Skipped segment: one -1 marker then -2 fillers, preserving offsets.
                data.append(-1)
                data.extend(-2 for _ in range(segment.segment_length - 1))
@dataclasses.dataclass
class BytecodeSegment:
    """
    Represents a child of BytecodeSegmentedNode.

    :raises AssertionError: on construction if segment_length is not positive.
    """

    # The length of the segment.
    segment_length: int
    # Should the segment (or part of it) be loaded to memory.
    # In other words, is the segment used during the execution.
    # Note that if is_used is False, the entire segment is not loaded to memory.
    # If is_used is True, it is possible that part of the segment will be skipped (according
    # to the "is_used" field of the child segments).
    is_used: bool
    # The inner structure of the segment.
    inner_structure: BytecodeSegmentStructure

    def __post_init__(self):
        # Zero- or negative-length segments are invalid by construction.
        assert (
            self.segment_length > 0
        ), f"Invalid segment length: {self.segment_length}."
# Represents a nested list of integers. E.g., [1, [2, [3], 4], 5, 6].
# Leaves are plain ints; sub-lists may nest to arbitrary depth.
NestedIntList = Union[int, List[Any]]

View File

@@ -0,0 +1,16 @@
from ..constants import (
DEFAULT_ENTRY_POINT_NAME,
DEFAULT_ENTRY_POINT_SELECTOR,
DEFAULT_L1_ENTRY_POINT_NAME,
)
from ..hash.utils import _starknet_keccak
def get_selector_from_name(func_name: str) -> int:
    """
    Returns the selector of a contract's function name.

    Default entry points share a fixed selector; anything else is keccak-derived.
    """
    special_names = (DEFAULT_ENTRY_POINT_NAME, DEFAULT_L1_ENTRY_POINT_NAME)
    if func_name in special_names:
        return DEFAULT_ENTRY_POINT_SELECTOR
    return _starknet_keccak(data=func_name.encode("ascii"))

View File

@@ -0,0 +1,12 @@
from functools import reduce
from constants import ADDR_BOUND
from hash.utils import _starknet_keccak, pedersen_hash
def get_storage_var_address(var_name: str, *args: int) -> int:
    """
    Returns the storage address of a Starknet storage variable given its name and arguments.
    """
    # Start from the keccak of the variable name, then fold each argument
    # into a pedersen hash chain; finally reduce modulo the address bound.
    address = _starknet_keccak(var_name.encode("ascii"))
    for arg in args:
        address = pedersen_hash(address, arg)
    return address % ADDR_BOUND

View File

@@ -0,0 +1,78 @@
import functools
from typing import List, Optional, Sequence
from ... import keccak
from ..common import int_from_bytes
from ..constants import EC_ORDER
from ...starkware.crypto.signature import (
ECSignature,
private_to_stark_key,
sign
# verify
)
from ...starkware.crypto.fast_pedersen_hash import (
pedersen_hash
)
# Mask keeping the low 250 bits of a digest (used by _starknet_keccak below).
MASK_250 = 2**250 - 1
HEX_PREFIX = "0x"
def _starknet_keccak(data: bytes) -> int:
    """
    A variant of eth-keccak that computes a value that fits in a Starknet field element.
    """
    # Keep only the low 250 bits of the digest so the result fits in a felt.
    digest = keccak.SHA3(data)
    return int_from_bytes(digest) & MASK_250
# def pedersen_hash(left: int, right: int) -> int:
# """
# One of two hash functions (along with _starknet_keccak) used throughout Starknet.
# """
# return cpp_hash(left, right)
def compute_hash_on_elements(data: Sequence) -> int:
    """
    Computes a hash chain over the data, in the following order:
        h(h(h(h(0, data[0]), data[1]), ...), data[n-1]), n).

    The hash is initialized with 0 and ends with the data length appended.
    The length is appended in order to avoid collisions of the following kind:
    H([x,y,z]) = h(h(x,y),z) = H([w, z]) where w = h(x,y).
    """
    chained = 0
    for element in [*data, len(data)]:
        chained = pedersen_hash(chained, element)
    return chained
def message_signature(
    msg_hash: int, priv_key: int, seed: Optional[int] = 32
) -> ECSignature:
    """
    Signs the message with private key.

    :param msg_hash: hash of the message to sign.
    :param priv_key: signer's private key.
    :param seed: seed forwarded to the underlying `sign` implementation.
        NOTE(review): a constant default of 32 for a signing seed looks unusual -
        confirm against the `sign` implementation that this is intentional.
    :return: the resulting ECSignature.
    """
    return sign(msg_hash, priv_key, seed)
# def verify_message_signature(
# msg_hash: int, signature: List[int], public_key: int
# ) -> bool:
# """
# Verifies ECDSA signature of a given message hash with a given public key.
# Returns true if public_key signs the message.
# """
# sig_r, sig_s = signature
# # sig_w = pow(sig_s, -1, EC_ORDER)
# return verify(msg_hash=msg_hash, r=sig_r, s=sig_s, public_key=public_key)
def encode_uint(value: int, bytes_length: int = 32) -> bytes:
    """Encode a non-negative integer as big-endian bytes of the given length."""
    return value.to_bytes(length=bytes_length, byteorder="big")
def encode_uint_list(data: List[int]) -> bytes:
    """Concatenate the 32-byte big-endian encodings of all integers in *data*."""
    encoded_parts = [encode_uint(element) for element in data]
    return b"".join(encoded_parts)
def get_bytes_length(value: int) -> int:
    """Return the minimal number of bytes needed to represent *value* (0 for 0)."""
    # Ceiling division of the bit length by 8.
    return -(-value.bit_length() // 8)

View File

@@ -0,0 +1,45 @@
"""
TypedDict structures for TypedData
"""
from enum import Enum
from typing import Any, Dict, List, Optional, TypedDict
class Revision(Enum):
    """
    Enum representing the revision of the TypedData specification to be used.
    """

    V0 = 0
    V1 = 1
class ParameterDict(TypedDict):
    """
    TypedDict representing a Parameter object - a single named, typed field
    of a TypedData type definition.
    """

    name: str
    type: str
class StarkNetDomainDict(TypedDict):
    """
    TypedDict representing a domain object (both StarkNetDomain, StarknetDomain).
    """

    name: str
    version: str
    chainId: str
    # NOTE(review): TypedDict keys are required by default; Optional here only
    # allows a None value. If the key itself may be absent, total=False (or a
    # NotRequired annotation) would be needed - confirm intended semantics.
    revision: Optional[Revision]
class TypedDataDict(TypedDict):
    """
    TypedDict representing a TypedData object.
    """

    # Mapping from a type name to the list of its parameters.
    types: Dict[str, List[ParameterDict]]
    primaryType: str
    domain: StarkNetDomainDict
    message: Dict[str, Any]

View File

@@ -0,0 +1,24 @@
# PayloadSerializer and FunctionSerializationAdapter would mostly be used by users
from .data_serializers import (
ArraySerializer,
CairoDataSerializer,
FeltSerializer,
NamedTupleSerializer,
PayloadSerializer,
StructSerializer,
TupleSerializer,
Uint256Serializer,
)
from .errors import (
CairoSerializerException,
InvalidTypeException,
InvalidValueException,
)
from .factory import (
serializer_for_event,
serializer_for_function,
serializer_for_payload,
serializer_for_type,
)
from .function_serialization_adapter import FunctionSerializationAdapter
from .tuple_dataclass import TupleDataclass

View File

@@ -0,0 +1,40 @@
from typing import List
from ..cairo.felt import CairoData
class OutOfBoundsError(Exception):
    """Raised when a read past the end of the available calldata is requested."""

    def __init__(self, position: int, requested_size: int, remaining_size: int):
        message = f"Requested {requested_size} elements, {remaining_size} available."
        super().__init__(message)
        self.position = position
        self.requested_size = requested_size
        # Attribute is deliberately named `remaining_len`; error-wrapping code reads it.
        self.remaining_len = remaining_size
class CalldataReader:
    """Sequential reader over calldata that tracks how much has been consumed."""

    _data: List[int]
    _position: int

    def __init__(self, data: List[int]):
        self._data = data
        self._position = 0

    @property
    def remaining_len(self) -> int:
        """Number of values that have not been read yet."""
        return len(self._data) - self._position

    def read(self, size: int) -> CairoData:
        """
        Return the next *size* values and advance the read position.

        :raises ValueError: if size is not positive.
        :raises OutOfBoundsError: if fewer than *size* values remain.
        """
        if size < 1:
            raise ValueError("size must be greater than 0")
        if size > self.remaining_len:
            raise OutOfBoundsError(
                position=self._position,
                requested_size=size,
                remaining_size=self.remaining_len,
            )
        start = self._position
        self._position = start + size
        return self._data[start : self._position]

View File

@@ -0,0 +1,142 @@
from __future__ import annotations
from abc import ABC
from contextlib import contextmanager
from typing import Any, Generator, Iterator, List
from ._calldata_reader import (
CairoData,
CalldataReader,
OutOfBoundsError,
)
from .errors import InvalidTypeException, InvalidValueException
class Context(ABC):
    """
    Holds information about context when (de)serializing data. This is needed to inform what and where went
    wrong during processing. Every separate (de)serialization should have its own context.
    """

    # Stack of entity names leading to the value being processed, e.g. ["outer", "[0]"].
    _namespace_stack: List[str]

    def __init__(self):
        self._namespace_stack = []

    @property
    def current_entity(self):
        """
        Name of currently processed entity.

        :return: transformed path.
        """
        return ".".join(self._namespace_stack)

    @contextmanager
    def push_entity(self, name: str) -> Generator:
        """
        Manager used for maintaining information about names of (de)serialized types. Wraps some errors with
        custom errors, adding information about the context.

        :param name: name of (de)serialized entity.
        """
        # This ensures the name will be popped if everything is ok. In case an exception is raised we want the stack to
        # be filled to wrap the error at the end.
        self._namespace_stack.append(name)
        yield
        self._namespace_stack.pop()

    def ensure_valid_value(self, valid: bool, text: str):
        # Raise an InvalidValueException prefixed with the current path unless `valid` holds.
        if not valid:
            raise InvalidValueException(f"{self._error_prefix}: {text}.")

    def ensure_valid_type(self, value: Any, valid: bool, expected_type: str):
        # Raise an InvalidTypeException describing the offending value unless `valid` holds.
        if not valid:
            raise InvalidTypeException(
                f"{self._error_prefix}: expected {expected_type}, "
                f"received '{value}' of type '{type(value)}'."
            )

    @contextmanager
    def _wrap_errors(self):
        # Translate low-level errors raised inside the managed block into this
        # package's exceptions, with the current path attached.
        try:
            yield
        except OutOfBoundsError as err:
            action_name = (
                f"deserialize '{self.current_entity}'"
                if self._namespace_stack
                else "deserialize"
            )
            # This way we can precisely inform user what's wrong when reading calldata.
            raise InvalidValueException(
                f"Not enough data to {action_name}. "
                f"Can't read {err.requested_size} values at position {err.position}, {err.remaining_len} available."
            ) from err
        # Those two are based on ValueError and TypeError, we have to catch them early
        except (InvalidValueException, InvalidTypeException) as err:
            raise err
        except ValueError as err:
            raise InvalidValueException(f"{self._error_prefix}: {err}") from err
        except TypeError as err:
            raise InvalidTypeException(f"{self._error_prefix}: {err}") from err

    @property
    def _error_prefix(self):
        # Human-readable location prefix used in error messages.
        if not self._namespace_stack:
            return "Error"
        return f"Error at path '{self.current_entity}'"
class SerializationContext(Context):
    """
    Context used during serialization.
    """

    # Type is iterator, because ContextManager doesn't work with pyright :|
    # https://github.com/microsoft/pyright/issues/476
    @classmethod
    @contextmanager
    def create(cls) -> Iterator[SerializationContext]:
        # Entry point: wraps the whole serialization in error translation.
        context = cls()
        with context._wrap_errors():
            yield context
class DeserializationContext(Context):
    """
    Context used during deserialization.
    """

    # Reader over the calldata being deserialized.
    reader: CalldataReader

    def __init__(self, calldata: CairoData):
        """
        Don't use default constructor. Use DeserializationContext.create context manager.
        """
        super().__init__()
        # NOTE(review): redundant - super().__init__() already initializes the stack.
        self._namespace_stack = []
        self.reader = CalldataReader(calldata)

    @classmethod
    @contextmanager
    def create(cls, data: CairoData) -> Iterator[DeserializationContext]:
        # Entry point: wraps deserialization in error translation and, on a
        # successful exit, verifies all calldata values were consumed.
        context = cls(data)
        with context._wrap_errors():
            yield context
            context._ensure_all_values_read(len(data))

    def _ensure_all_values_read(self, total_len: int):
        # Guard against silently ignoring trailing calldata.
        values_not_used = self.reader.remaining_len
        if values_not_used != 0:
            # We want to output up to 3 values. If there are more they will be truncated like "0x1,0x1,0x1..."
            max_values_to_show = 3
            values_to_show = min(values_not_used, max_values_to_show)
            example = ",".join(hex(v) for v in self.reader.read(values_to_show))
            suffix = "..." if values_not_used > max_values_to_show else ""
            raise InvalidValueException(
                f"Last {values_not_used} values '{example}{suffix}' out of total {total_len} "
                "values were not used during deserialization."
            )

View File

@@ -0,0 +1,10 @@
from .array_serializer import ArraySerializer
from .bool_serializer import BoolSerializer
from .byte_array_serializer import ByteArraySerializer
from .cairo_data_serializer import CairoDataSerializer
from .felt_serializer import FeltSerializer
from .named_tuple_serializer import NamedTupleSerializer
from .payload_serializer import PayloadSerializer
from .struct_serializer import StructSerializer
from .tuple_serializer import TupleSerializer
from .uint256_serializer import Uint256Serializer

View File

@@ -0,0 +1,82 @@
# We have to use parametrised type from typing
from collections import OrderedDict as _OrderedDict
from typing import Dict, Generator, List, OrderedDict
from .._context import (
DeserializationContext,
SerializationContext,
)
from .cairo_data_serializer import (
CairoDataSerializer,
)
# The actual serialization logic is very similar among all serializers: they either serialize data based on
# position or their name. Having this logic reused adds indirection, but makes sure proper logic is used everywhere.
def deserialize_to_list(
    deserializers: List[CairoDataSerializer], context: DeserializationContext
) -> List:
    """
    Deserializes data from context to list. This logic is used in every sequential type (arrays and tuples).
    """
    values = []
    for index, deserializer in enumerate(deserializers):
        # Record the position in the entity path so errors point at "[index]".
        with context.push_entity(f"[{index}]"):
            values.append(deserializer.deserialize_with_context(context))
    return values
def deserialize_to_dict(
    deserializers: OrderedDict[str, CairoDataSerializer],
    context: DeserializationContext,
) -> OrderedDict:
    """
    Deserializes data from context to dictionary. This logic is used in every type with named fields (structs,
    named tuples and payloads).
    """
    as_dict: OrderedDict = _OrderedDict()
    for name in deserializers:
        # Record the field name in the entity path for precise error messages.
        with context.push_entity(name):
            as_dict[name] = deserializers[name].deserialize_with_context(context)
    return as_dict
def serialize_from_list(
    serializers: List[CairoDataSerializer], context: SerializationContext, values: List
) -> Generator[int, None, None]:
    """
    Serializes data from list. This logic is used in every sequential type (arrays and tuples).
    """
    # The number of provided values must match the number of serializers exactly.
    context.ensure_valid_value(
        len(serializers) == len(values),
        f"expected {len(serializers)} elements, {len(values)} provided",
    )
    for position, (serializer, element) in enumerate(zip(serializers, values)):
        with context.push_entity(f"[{position}]"):
            yield from serializer.serialize_with_context(context, element)
def serialize_from_dict(
    serializers: OrderedDict[str, CairoDataSerializer],
    context: SerializationContext,
    values: Dict,
) -> Generator[int, None, None]:
    """
    Serializes data from dict. This logic is used in every type with named fields (structs, named tuples and payloads).
    """
    # Reject keys that have no matching serializer.
    unknown_keys = set(values.keys()).difference(serializers.keys())
    context.ensure_valid_value(
        not unknown_keys,
        f"unexpected keys '{','.join(unknown_keys)}' were provided",
    )
    for name, serializer in serializers.items():
        with context.push_entity(name):
            # Every declared field must be present in the input.
            context.ensure_valid_value(name in values, f"key '{name}' is missing")
            yield from serializer.serialize_with_context(context, values[name])

View File

@@ -0,0 +1,43 @@
from dataclasses import dataclass
from typing import Generator, Iterable, List
from .._context import (
DeserializationContext,
SerializationContext,
)
from ..data_serializers._common import (
deserialize_to_list,
serialize_from_list,
)
from ..data_serializers.cairo_data_serializer import (
CairoDataSerializer,
)
@dataclass
class ArraySerializer(CairoDataSerializer[Iterable, List]):
    """
    Serializer for arrays. In abi they are represented as a pointer to a type.
    Can serialize any iterable and prepends its length to resulting list.
    Deserializes data to a list.

    Examples:
    [1,2,3] => [3,1,2,3]
    [] => [0]
    """

    inner_serializer: CairoDataSerializer

    def deserialize_with_context(self, context: DeserializationContext) -> List:
        # The first felt carries the number of elements that follow.
        with context.push_entity("len"):
            [size] = context.reader.read(1)
        element_serializers = [self.inner_serializer] * size
        return deserialize_to_list(element_serializers, context)

    def serialize_with_context(
        self, context: SerializationContext, value: List
    ) -> Generator[int, None, None]:
        # Emit the length prefix first, then every element in order.
        count = len(value)
        yield count
        yield from serialize_from_list(
            [self.inner_serializer] * count, context, value
        )

View File

@@ -0,0 +1,37 @@
from dataclasses import dataclass
from typing import Generator
from .._context import (
Context,
DeserializationContext,
SerializationContext,
)
from .cairo_data_serializer import (
CairoDataSerializer,
)
@dataclass
class BoolSerializer(CairoDataSerializer[bool, int]):
    """
    Serializer for boolean. A Cairo boolean is a single felt restricted to 0 or 1.
    """

    def deserialize_with_context(self, context: DeserializationContext) -> bool:
        [raw] = context.reader.read(1)
        self._ensure_bool(context, raw)
        return bool(raw)

    def serialize_with_context(
        self, context: SerializationContext, value: bool
    ) -> Generator[int, None, None]:
        # Reject non-bool inputs, then emit a single felt.
        context.ensure_valid_type(value, isinstance(value, bool), "bool")
        self._ensure_bool(context, value)
        yield 1 if value else 0

    @staticmethod
    def _ensure_bool(context: Context, value: int):
        is_binary = value in [0, 1]
        context.ensure_valid_value(
            is_binary,
            f"invalid value '{value}' - must be in [0, 2) range",
        )

View File

@@ -0,0 +1,66 @@
from dataclasses import dataclass
from typing import Generator
from ...cairo.felt import decode_shortstring, encode_shortstring
from .._context import (
DeserializationContext,
SerializationContext,
)
from ._common import (
deserialize_to_list,
serialize_from_list,
)
from .cairo_data_serializer import (
CairoDataSerializer,
)
from .felt_serializer import FeltSerializer
BYTES_31_SIZE = 31
@dataclass
class ByteArraySerializer(CairoDataSerializer[str, str]):
    """
    Serializer for ByteArrays. Serializes to and deserializes from str values.

    Examples:
    "" => [0,0,0]
    "hello" => [0,448378203247,5]
    """

    def deserialize_with_context(self, context: DeserializationContext) -> str:
        # Calldata layout: [word count, full 31-char words..., pending word, pending word length].
        with context.push_entity("data_array_len"):
            [size] = context.reader.read(1)
        data = deserialize_to_list([FeltSerializer()] * size, context)
        with context.push_entity("pending_word"):
            [pending_word] = context.reader.read(1)
        with context.push_entity("pending_word_len"):
            [pending_word_len] = context.reader.read(1)
        pending_word = decode_shortstring(pending_word)
        # The declared length must match the decoded pending word.
        context.ensure_valid_value(
            len(pending_word) == pending_word_len,
            f"Invalid length {pending_word_len} for pending word {pending_word}",
        )
        data_joined = "".join(map(decode_shortstring, data))
        return data_joined + pending_word

    def serialize_with_context(
        self, context: SerializationContext, value: str
    ) -> Generator[int, None, None]:
        context.ensure_valid_type(value, isinstance(value, str), "str")
        # Split the string into 31-character chunks.
        data = [
            value[i : i + BYTES_31_SIZE] for i in range(0, len(value), BYTES_31_SIZE)
        ]
        # The last chunk becomes the "pending word" unless it is exactly 31 chars.
        pending_word = (
            "" if len(data) == 0 or len(data[-1]) == BYTES_31_SIZE else data.pop(-1)
        )
        yield len(data)
        yield from serialize_from_list([FeltSerializer()] * len(data), context, data)
        yield encode_shortstring(pending_word)
        yield len(pending_word)

View File

@@ -0,0 +1,71 @@
from abc import ABC, abstractmethod
from typing import Generator, Generic, List, TypeVar
from .._calldata_reader import CairoData
from .._context import (
DeserializationContext,
SerializationContext,
)
# Python type that is accepted by a serializer
# pylint: disable=invalid-name
SerializationType = TypeVar("SerializationType")
# Python type that will be returned from a serializer. Often same as SerializationType.
# pylint: disable=invalid-name
DeserializationType = TypeVar("DeserializationType")
class CairoDataSerializer(ABC, Generic[SerializationType, DeserializationType]):
    """
    Base class for serializing/deserializing data to/from calldata.
    """

    def deserialize(self, data: List[int]) -> DeserializationType:
        """
        Transform calldata into python value.

        :param data: calldata to deserialize.
        :return: defined DeserializationType.
        """
        # The context wraps low-level errors and verifies all values are consumed.
        with DeserializationContext.create(data) as context:
            return self.deserialize_with_context(context)

    def serialize(self, data: SerializationType) -> CairoData:
        """
        Transform python data into calldata.

        :param data: data to serialize.
        :return: calldata.
        """
        with SerializationContext.create() as context:
            serialized_data = list(self.serialize_with_context(context, data))
            return self.remove_units_from_serialized_data(serialized_data)

    @abstractmethod
    def deserialize_with_context(
        self, context: DeserializationContext
    ) -> DeserializationType:
        """
        Transform calldata into python value.

        :param context: context of this deserialization.
        :return: defined DeserializationType.
        """

    @abstractmethod
    def serialize_with_context(
        self, context: SerializationContext, value: SerializationType
    ) -> Generator[int, None, None]:
        """
        Transform python value into calldata.

        :param context: context of this serialization.
        :param value: python value to serialize.
        :return: defined SerializationType.
        """

    @staticmethod
    def remove_units_from_serialized_data(serialized_data: List) -> List:
        # Drops None entries - presumably yielded for Cairo unit values - from
        # the output; confirm against the serializers that emit None.
        return [x for x in serialized_data if x is not None]

View File

@@ -0,0 +1,71 @@
from dataclasses import dataclass
from typing import Dict, Generator, OrderedDict, Tuple, Union
from .._context import (
DeserializationContext,
SerializationContext,
)
from .cairo_data_serializer import (
CairoDataSerializer,
)
from ..tuple_dataclass import TupleDataclass
@dataclass
class EnumSerializer(CairoDataSerializer[Union[Dict, TupleDataclass], TupleDataclass]):
    """
    Serializer of enums.
    Can serialize a dictionary and TupleDataclass.
    Deserializes data to a TupleDataclass.

    Example:
    enum MyEnum {
        a: u128,
        b: u128
    }
    {"a": 1} => [0, 1]
    {"b": 100} => [1, 100]
    TupleDataclass(variant='a', value=100) => [0, 100]
    """

    # Maps variant name to the serializer of its payload; insertion order
    # defines the numeric variant indices used on the wire.
    serializers: OrderedDict[str, CairoDataSerializer]

    def deserialize_with_context(
        self, context: DeserializationContext
    ) -> TupleDataclass:
        # The first felt selects the variant; its payload follows.
        [variant_index] = context.reader.read(1)
        variant_name, serializer = self._get_variant(variant_index)
        with context.push_entity("enum.variant: " + variant_name):
            result_dict = {
                "variant": variant_name,
                "value": serializer.deserialize_with_context(context),
            }
        return TupleDataclass.from_dict(result_dict)

    def serialize_with_context(
        self, context: SerializationContext, value: Union[Dict, TupleDataclass]
    ) -> Generator[int, None, None]:
        if isinstance(value, Dict):
            # A dict input must contain exactly one (variant, payload) pair.
            items = list(value.items())
            if len(items) != 1:
                raise ValueError(
                    "Can serialize only one enum variant, got: " + str(len(items))
                )
            variant_name, variant_value = items[0]
        else:
            # TupleDataclass input unpacks to (variant, value).
            variant_name, variant_value = value
        yield self._get_variant_index(variant_name)
        yield from self.serializers[variant_name].serialize_with_context(
            context, variant_value
        )

    def _get_variant(self, variant_index: int) -> Tuple[str, CairoDataSerializer]:
        # Index into the ordered (name, serializer) pairs.
        return list(self.serializers.items())[variant_index]

    def _get_variant_index(self, variant_name: str) -> int:
        # Position of the variant name within the ordered keys.
        return list(self.serializers.keys()).index(variant_name)

View File

@@ -0,0 +1,50 @@
import warnings
from dataclasses import dataclass
from typing import Generator
from ...cairo.felt import encode_shortstring, is_in_felt_range
from ...constants import FIELD_PRIME
from .._context import (
Context,
DeserializationContext,
SerializationContext,
)
from .cairo_data_serializer import (
CairoDataSerializer,
)
@dataclass
class FeltSerializer(CairoDataSerializer[int, int]):
    """
    Serializer for field element. At the time of writing it is the only existing numeric type.
    """

    def deserialize_with_context(self, context: DeserializationContext) -> int:
        [val] = context.reader.read(1)
        self._ensure_felt(context, val)
        return val

    def serialize_with_context(
        self, context: SerializationContext, value: int
    ) -> Generator[int, None, None]:
        # Legacy path: shortstrings are still accepted, but emit a deprecation warning.
        if isinstance(value, str):
            warnings.warn(
                "Serializing shortstrings in FeltSerializer is deprecated. "
                "Use starknet_py.cairo.felt.encode_shortstring instead.",
                category=DeprecationWarning,
            )
            value = encode_shortstring(value)
            yield value
            return

        context.ensure_valid_type(value, isinstance(value, int), "int")
        self._ensure_felt(context, value)
        yield value

    @staticmethod
    def _ensure_felt(context: Context, value: int):
        # Felts must lie in the prime field [0, FIELD_PRIME).
        context.ensure_valid_value(
            is_in_felt_range(value),
            f"invalid value '{value}' - must be in [0, {FIELD_PRIME}) range",
        )

View File

@@ -0,0 +1,58 @@
from dataclasses import dataclass
from typing import Dict, Generator, NamedTuple, OrderedDict, Union
from .._context import (
DeserializationContext,
SerializationContext,
)
from ._common import (
deserialize_to_dict,
serialize_from_dict,
)
from .cairo_data_serializer import (
CairoDataSerializer,
)
from ..tuple_dataclass import TupleDataclass
@dataclass
class NamedTupleSerializer(
    CairoDataSerializer[Union[Dict, NamedTuple, TupleDataclass], TupleDataclass]
):
    """
    Serializer for tuples with named fields.
    Can serialize a dictionary, a named tuple and TupleDataclass.
    Deserializes data to a TupleDataclass.

    Example:
    {"a": 1, "b": 2} => [1,2]
    """

    # Maps field name to its serializer; order defines the calldata layout.
    serializers: OrderedDict[str, CairoDataSerializer]

    def deserialize_with_context(
        self, context: DeserializationContext
    ) -> TupleDataclass:
        as_dictionary = deserialize_to_dict(self.serializers, context)
        return TupleDataclass.from_dict(as_dictionary)

    def serialize_with_context(
        self,
        context: SerializationContext,
        value: Union[Dict, NamedTuple, TupleDataclass],
    ) -> Generator[int, None, None]:
        # We can't use isinstance(value, NamedTuple), because there is no NamedTuple type.
        context.ensure_valid_type(
            value,
            isinstance(value, (dict, TupleDataclass)) or self._is_namedtuple(value),
            "dict, NamedTuple or TupleDataclass",
        )
        # Non-dict inputs (namedtuple / TupleDataclass) are converted via _asdict();
        # the code assumes both expose that method.
        # noinspection PyUnresolvedReferences, PyProtectedMember
        values: Dict = value if isinstance(value, dict) else value._asdict()
        yield from serialize_from_dict(self.serializers, context, values)

    @staticmethod
    def _is_namedtuple(value) -> bool:
        # Heuristic: namedtuples are tuple subclasses carrying a `_fields` attribute.
        return isinstance(value, tuple) and hasattr(value, "_fields")

View File

@@ -0,0 +1,43 @@
from dataclasses import dataclass
from typing import Any, Generator, Optional
from .._context import (
DeserializationContext,
SerializationContext,
)
from .cairo_data_serializer import (
CairoDataSerializer,
)
@dataclass
class OptionSerializer(CairoDataSerializer[Optional[Any], Optional[Any]]):
    """
    Serializer for Option type.
    Can serialize None and common CairoTypes.
    Deserializes data to None or CairoType.

    Example:
    None => [1]
    {"option1": 123, "option2": None} => [0, 123, 1]
    """

    serializer: CairoDataSerializer

    def deserialize_with_context(
        self, context: DeserializationContext
    ) -> Optional[Any]:
        # Leading felt is the "is None" discriminant: 1 => None, 0 => value follows.
        (is_none,) = context.reader.read(1)
        if is_none == 1:
            return None
        return self.serializer.deserialize_with_context(context)

    def serialize_with_context(
        self, context: SerializationContext, value: Optional[Any]
    ) -> Generator[int, None, None]:
        # Emit the discriminant, then the payload only when a value is present.
        if value is None:
            yield 1
            return
        yield 0
        yield from self.serializer.serialize_with_context(context, value)

View File

@@ -0,0 +1,40 @@
from dataclasses import dataclass, field
from typing import Dict, Generator, List, Tuple
from .._context import (
DeserializationContext,
SerializationContext,
)
from .cairo_data_serializer import (
CairoDataSerializer,
)
@dataclass
class OutputSerializer(CairoDataSerializer[List, Tuple]):
    """
    Serializer for function output.
    Can't serialize anything.
    Deserializes data to a Tuple.

    Example:
    [1, 1, 1] => (340282366920938463463374607431768211457,)
    """

    # One serializer per output value, in order.
    serializers: List[CairoDataSerializer]

    def deserialize_with_context(self, context: DeserializationContext) -> Tuple:
        values = []
        for position, serializer in enumerate(self.serializers):
            with context.push_entity("output[" + str(position) + "]"):
                values.append(serializer.deserialize_with_context(context))
        return tuple(values)

    def serialize_with_context(
        self, context: SerializationContext, value: Dict
    ) -> Generator[int, None, None]:
        # Function outputs are read-only; serialization is intentionally unsupported.
        raise ValueError(
            "Output serializer can't be used to transform python data into calldata."
        )

View File

@@ -0,0 +1,72 @@
from collections import OrderedDict as _OrderedDict
from dataclasses import InitVar, dataclass, field
from typing import Dict, Generator, OrderedDict
from .._context import (
DeserializationContext,
SerializationContext,
)
from ._common import (
deserialize_to_dict,
serialize_from_dict,
)
from .array_serializer import ArraySerializer
from .cairo_data_serializer import (
CairoDataSerializer,
)
from .felt_serializer import FeltSerializer
from ..tuple_dataclass import TupleDataclass
# Suffix the ABI appends to the synthetic length argument of an array.
SIZE_SUFFIX = "_len"
SIZE_SUFFIX_LEN = len(SIZE_SUFFIX)


@dataclass
class PayloadSerializer(CairoDataSerializer[Dict, TupleDataclass]):
    """
    Serializer for payloads like function arguments/function outputs/events.
    Can serialize a dictionary.
    Deserializes data to a TupleDataclass.

    Example:
    {"a": 1, "b": 2} => [1,2]
    """

    # Value present only in constructor.
    # We don't want to mutate the serializers received in constructor.
    input_serializers: InitVar[OrderedDict[str, CairoDataSerializer]]
    serializers: OrderedDict[str, CairoDataSerializer] = field(init=False)

    def __post_init__(self, input_serializers):
        """
        ABI adds ARG_len for every argument ARG that is an array. We parse length as a part of ArraySerializer, so we
        need to remove those lengths from args.
        """
        self.serializers = _OrderedDict(
            (key, serializer)
            for key, serializer in input_serializers.items()
            if not self._is_len_arg(key, input_serializers)
        )

    def deserialize_with_context(
        self, context: DeserializationContext
    ) -> TupleDataclass:
        as_dictionary = deserialize_to_dict(self.serializers, context)
        return TupleDataclass.from_dict(as_dictionary)

    def serialize_with_context(
        self, context: SerializationContext, value: Dict
    ) -> Generator[int, None, None]:
        yield from serialize_from_dict(self.serializers, context, value)

    @staticmethod
    def _is_len_arg(arg_name: str, serializers: Dict[str, CairoDataSerializer]) -> bool:
        # An argument is a synthetic length when it is a felt named "<name>_len"
        # and "<name>" itself is an array within the same payload.
        return (
            arg_name.endswith(SIZE_SUFFIX)
            and isinstance(serializers[arg_name], FeltSerializer)
            # There is an ArraySerializer under key that is arg_name without the size suffix
            and isinstance(
                serializers.get(arg_name[:-SIZE_SUFFIX_LEN]), ArraySerializer
            )
        )

View File

@@ -0,0 +1,36 @@
from dataclasses import dataclass
from typing import Dict, Generator, OrderedDict
from .._context import (
DeserializationContext,
SerializationContext,
)
from ._common import (
deserialize_to_dict,
serialize_from_dict,
)
from .cairo_data_serializer import (
CairoDataSerializer,
)
@dataclass
class StructSerializer(CairoDataSerializer[Dict, Dict]):
    """
    Serializer of custom structures.
    Can serialize a dictionary.
    Deserializes data to a dictionary.

    Example:
    {"a": 1, "b": 2} => [1,2]
    """

    # Maps field name to its serializer; order defines the calldata layout.
    serializers: OrderedDict[str, CairoDataSerializer]

    def deserialize_with_context(self, context: DeserializationContext) -> Dict:
        # Field-by-field deserialization in declaration order.
        as_dict = deserialize_to_dict(self.serializers, context)
        return as_dict

    def serialize_with_context(
        self, context: SerializationContext, value: Dict
    ) -> Generator[int, None, None]:
        # Field-by-field serialization in declaration order.
        yield from serialize_from_dict(self.serializers, context, value)

View File

@@ -0,0 +1,36 @@
from dataclasses import dataclass
from typing import Generator, Iterable, List, Tuple
from .._context import (
DeserializationContext,
SerializationContext,
)
from ._common import (
deserialize_to_list,
serialize_from_list,
)
from .cairo_data_serializer import (
CairoDataSerializer,
)
@dataclass
class TupleSerializer(CairoDataSerializer[Iterable, Tuple]):
    """
    Serializer for tuples without named fields.
    Can serialize any iterable.
    Deserializes data to a python tuple.
    Example:
    (1,2,(3,4)) => [1,2,3,4]
    """

    # One serializer per tuple element, in positional order.
    serializers: List[CairoDataSerializer]

    def deserialize_with_context(self, context: DeserializationContext) -> Tuple:
        # Each element serializer consumes its felts in order.
        elements = deserialize_to_list(self.serializers, context)
        return tuple(elements)

    def serialize_with_context(
        self, context: SerializationContext, value: Iterable
    ) -> Generator[int, None, None]:
        # Materialize the iterable first: it may be a one-shot generator.
        items = list(value)
        yield from serialize_from_list(self.serializers, context, items)

View File

@@ -0,0 +1,76 @@
from dataclasses import dataclass
from typing import Generator, TypedDict, Union
from ...cairo.felt import uint256_range_check
from .._context import (
Context,
DeserializationContext,
SerializationContext,
)
from .cairo_data_serializer import (
CairoDataSerializer,
)
# Exclusive upper bound for a single uint128 half of a Uint256.
U128_UPPER_BOUND = 2**128


class Uint256Dict(TypedDict):
    """Dict form of a Uint256 value: the two uint128 halves."""

    low: int
    high: int
@dataclass
class Uint256Serializer(CairoDataSerializer[Union[int, Uint256Dict], int]):
    """
    Serializer of Uint256, represented in Cairo by the structure
    {low: Uint128, high: Uint128}.
    Can serialize an int or a {"low": ..., "high": ...} dict.
    Deserializes data to an int.
    Examples:
    0 => [0,0]
    1 => [1,0]
    2**128 => [0,1]
    3 + 2**128 => [3,1]
    """

    def deserialize_with_context(self, context: DeserializationContext) -> int:
        low, high = context.reader.read(2)
        # A [0, 2**256) check on the combined value is not enough: each half
        # must independently be a valid uint128.
        with context.push_entity("low"):
            self._ensure_valid_uint128(low, context)
        with context.push_entity("high"):
            self._ensure_valid_uint128(high, context)
        return low + (high << 128)

    def serialize_with_context(
        self, context: SerializationContext, value: Union[int, Uint256Dict]
    ) -> Generator[int, None, None]:
        context.ensure_valid_type(value, isinstance(value, (int, dict)), "int or dict")
        if isinstance(value, dict):
            yield from self._serialize_from_dict(context, value)
        else:
            yield from self._serialize_from_int(value)

    @staticmethod
    def _serialize_from_int(value: int) -> Generator[int, None, None]:
        uint256_range_check(value)
        # Low 128 bits first, then the high 128 bits.
        yield value % 2**128
        yield value // 2**128

    def _serialize_from_dict(
        self, context: SerializationContext, value: Uint256Dict
    ) -> Generator[int, None, None]:
        # Validate and emit the halves in the fixed low/high order.
        for part in ("low", "high"):
            with context.push_entity(part):
                self._ensure_valid_uint128(value[part], context)
                yield value[part]

    @staticmethod
    def _ensure_valid_uint128(value: int, context: Context):
        context.ensure_valid_value(
            0 <= value < U128_UPPER_BOUND, "expected value in range [0;2**128)"
        )

View File

@@ -0,0 +1,100 @@
from dataclasses import dataclass
from typing import Generator, TypedDict, Union
from ...cairo.felt import uint256_range_check
from .._context import (
Context,
DeserializationContext,
SerializationContext,
)
from .cairo_data_serializer import (
CairoDataSerializer,
)
class Uint256Dict(TypedDict):
    """Dict form of a u256 value: the two u128 halves."""

    low: int
    high: int
@dataclass
class UintSerializer(CairoDataSerializer[Union[int, Uint256Dict], int]):
    """
    Serializer of uint. In Cairo there are few uints (u8, ..., u128 and u256).
    u256 is represented by structure {low: u128, high: u128}.
    Can serialize an int and dict.
    Deserializes data to an int.
    Examples:
    if bits < 256:
    0 => [0]
    1 => [1]
    2**128-1 => [2**128-1]
    else:
    0 => [0,0]
    1 => [1,0]
    2**128 => [0,1]
    3 + 2**128 => [3,1]
    """

    # Width of the uint in bits; 256 selects the two-felt (low, high) layout.
    bits: int

    def deserialize_with_context(self, context: DeserializationContext) -> int:
        # Narrow uints occupy a single felt.
        if self.bits < 256:
            (uint,) = context.reader.read(1)
            with context.push_entity("uint" + str(self.bits)):
                self._ensure_valid_uint(uint, context, self.bits)
            return uint

        low, high = context.reader.read(2)
        # A [0, 2**256) check on the combined value is not enough: each half
        # must independently be a valid uint128.
        with context.push_entity("low"):
            self._ensure_valid_uint(low, context, bits=128)
        with context.push_entity("high"):
            self._ensure_valid_uint(high, context, bits=128)
        return low + (high << 128)

    def serialize_with_context(
        self, context: SerializationContext, value: Union[int, Uint256Dict]
    ) -> Generator[int, None, None]:
        context.ensure_valid_type(value, isinstance(value, (int, dict)), "int or dict")
        if isinstance(value, dict):
            yield from self._serialize_from_dict(context, value)
        else:
            yield from self._serialize_from_int(value, context, self.bits)

    @staticmethod
    def _serialize_from_int(
        value: int, context: SerializationContext, bits: int
    ) -> Generator[int, None, None]:
        if bits < 256:
            UintSerializer._ensure_valid_uint(value, context, bits)
            yield value
            return
        uint256_range_check(value)
        # Low 128 bits first, then the high 128 bits.
        yield value % 2**128
        yield value >> 128

    def _serialize_from_dict(
        self, context: SerializationContext, value: Uint256Dict
    ) -> Generator[int, None, None]:
        # Validate and emit the halves in the fixed low/high order.
        for part in ("low", "high"):
            with context.push_entity(part):
                self._ensure_valid_uint(value[part], context, bits=128)
                yield value[part]

    @staticmethod
    def _ensure_valid_uint(value: int, context: Context, bits: int):
        """
        Ensures that value is a valid uint on `bits` bits.
        """
        context.ensure_valid_value(
            0 <= value < 2**bits, f"expected value in range [0;2**{bits})"
        )

View File

@@ -0,0 +1,32 @@
from dataclasses import dataclass
from typing import Any, Generator, Optional
from .._context import (
DeserializationContext,
SerializationContext,
)
from .cairo_data_serializer import (
CairoDataSerializer,
)
@dataclass
class UnitSerializer(CairoDataSerializer[None, None]):
    """
    Serializer for unit type.
    Can only serialize None.
    Deserializes data to None.
    Example:
    [] => None
    """

    def deserialize_with_context(self, context: DeserializationContext) -> None:
        # The unit type carries no data; nothing is consumed from the reader.
        return None

    def serialize_with_context(
        self, context: SerializationContext, value: Optional[Any]
    ) -> Generator[None, None, None]:
        if value is None:
            yield None
        else:
            raise ValueError("Can only serialize `None`.")

View File

@@ -0,0 +1,10 @@
class CairoSerializerException(Exception):
    """Base class for all errors raised by CairoSerializer."""
class InvalidTypeException(CairoSerializerException, TypeError):
    """Raised when a value of an invalid type was provided."""
class InvalidValueException(CairoSerializerException, ValueError):
    """Raised when a value of a valid type but invalid content was provided."""

View File

@@ -0,0 +1,229 @@
from __future__ import annotations
from collections import OrderedDict
from typing import Dict, List, Union
from ..abi.v0 import Abi as AbiV0
from ..abi.v1 import Abi as AbiV1
from ..abi.v2 import Abi as AbiV2
from ..cairo.data_types import (
ArrayType,
BoolType,
CairoType,
EnumType,
EventType,
FeltType,
NamedTupleType,
OptionType,
StructType,
TupleType,
UintType,
UnitType,
)
from .data_serializers import (
BoolSerializer,
ByteArraySerializer,
)
from .data_serializers.array_serializer import ArraySerializer
from .data_serializers.cairo_data_serializer import (
CairoDataSerializer,
)
from .data_serializers.enum_serializer import EnumSerializer
from .data_serializers.felt_serializer import FeltSerializer
from .data_serializers.named_tuple_serializer import (
NamedTupleSerializer,
)
from .data_serializers.option_serializer import (
OptionSerializer,
)
from .data_serializers.output_serializer import (
OutputSerializer,
)
from .data_serializers.payload_serializer import (
PayloadSerializer,
)
from .data_serializers.struct_serializer import (
StructSerializer,
)
from .data_serializers.tuple_serializer import TupleSerializer
from .data_serializers.uint256_serializer import (
Uint256Serializer,
)
from .data_serializers.uint_serializer import UintSerializer
from .data_serializers.unit_serializer import UnitSerializer
from .errors import InvalidTypeException
from .function_serialization_adapter import (
FunctionSerializationAdapter,
FunctionSerializationAdapterV1,
)
# Cairo 0 Uint256 struct layout; serializer_for_type special-cases it to Uint256Serializer.
_uint256_type = StructType("Uint256", OrderedDict(low=FeltType(), high=FeltType()))
# Cairo 1 ByteArray struct layout; serializer_for_type special-cases it to ByteArraySerializer.
_byte_array_type = StructType(
    "core::byte_array::ByteArray",
    OrderedDict(
        data=ArrayType(FeltType()),
        pending_word=FeltType(),
        pending_word_len=UintType(bits=32),
    ),
)
def serializer_for_type(cairo_type: CairoType) -> CairoDataSerializer:
    """
    Create a serializer for cairo type.

    Recursively builds serializers for container types (structs, arrays,
    tuples, enums, options, events).

    :param cairo_type: CairoType.
    :return: CairoDataSerializer.
    :raises InvalidTypeException: when the type is not recognized.
    """
    # pylint: disable=too-many-return-statements, too-many-branches
    # NOTE(review): dispatch relies on isinstance order — the struct special
    # cases must run before the generic StructSerializer, and TupleType is
    # tested before NamedTupleType. Assumes these CairoType classes do not
    # subclass one another in a way that changes matching — verify in data_types.
    if isinstance(cairo_type, FeltType):
        return FeltSerializer()
    if isinstance(cairo_type, BoolType):
        return BoolSerializer()
    if isinstance(cairo_type, StructType):
        # Special case: Uint256 is represented as struct
        if cairo_type == _uint256_type:
            return Uint256Serializer()
        # Special case: ByteArray is represented as struct
        if cairo_type == _byte_array_type:
            return ByteArraySerializer()
        return StructSerializer(
            OrderedDict(
                (name, serializer_for_type(member_type))
                for name, member_type in cairo_type.types.items()
            )
        )
    if isinstance(cairo_type, ArrayType):
        return ArraySerializer(serializer_for_type(cairo_type.inner_type))
    if isinstance(cairo_type, TupleType):
        return TupleSerializer(
            [serializer_for_type(member) for member in cairo_type.types]
        )
    if isinstance(cairo_type, NamedTupleType):
        return NamedTupleSerializer(
            OrderedDict(
                (name, serializer_for_type(member_type))
                for name, member_type in cairo_type.types.items()
            )
        )
    if isinstance(cairo_type, UintType):
        return UintSerializer(bits=cairo_type.bits)
    if isinstance(cairo_type, OptionType):
        return OptionSerializer(serializer_for_type(cairo_type.type))
    if isinstance(cairo_type, UnitType):
        return UnitSerializer()
    if isinstance(cairo_type, EnumType):
        return EnumSerializer(
            OrderedDict(
                (name, serializer_for_type(variant_type))
                for name, variant_type in cairo_type.variants.items()
            )
        )
    if isinstance(cairo_type, EventType):
        # Events serialize like a payload of their member types.
        return serializer_for_payload(cairo_type.types)
    raise InvalidTypeException(f"Received unknown Cairo type '{cairo_type}'.")
# We don't want to require users to use OrderedDict. Regular python requires order since python 3.7.
def serializer_for_payload(payload: Dict[str, CairoType]) -> PayloadSerializer:
    """
    Create PayloadSerializer for types listed in a dictionary. Please note that the order of fields in the dict is
    very important. Make sure the keys are provided in the right order.

    :param payload: dictionary with cairo types.
    :return: PayloadSerializer that can be used to (de)serialize events/function calls.
    """
    field_serializers = OrderedDict()
    for name, cairo_type in payload.items():
        field_serializers[name] = serializer_for_type(cairo_type)
    return PayloadSerializer(field_serializers)
def serializer_for_outputs(payload: List[CairoType]) -> OutputSerializer:
    """
    Create OutputSerializer for types in list. Please note that the order of fields in the list is
    very important. Make sure the types are provided in the right order.

    :param payload: list with cairo types.
    :return: OutputSerializer that can be used to deserialize function outputs.
    """
    output_serializers = [serializer_for_type(item) for item in payload]
    return OutputSerializer(serializers=output_serializers)
# Event abi models differ between abi versions; aliased here for isinstance dispatch.
EventV0 = AbiV0.Event
EventV1 = AbiV1.Event
EventV2 = EventType
def serializer_for_event(event: EventV0 | EventV1 | EventV2) -> PayloadSerializer:
    """
    Create serializer for an event.

    :param event: parsed event.
    :return: PayloadSerializer that can be used to (de)serialize events.
    """
    # Each abi version stores the event members under a different attribute.
    if isinstance(event, EventV0):
        members = event.data
    elif isinstance(event, EventV1):
        members = event.inputs
    else:
        members = event.types
    return serializer_for_payload(members)
def serializer_for_function(
    abi_function: AbiV0.Function,
) -> FunctionSerializationAdapter:
    """
    Create FunctionSerializationAdapter for serializing function inputs and deserializing function outputs.

    :param abi_function: parsed function's abi.
    :return: FunctionSerializationAdapter.
    """
    inputs = serializer_for_payload(abi_function.inputs)
    outputs = serializer_for_payload(abi_function.outputs)
    return FunctionSerializationAdapter(
        inputs_serializer=inputs,
        outputs_deserializer=outputs,
    )
def serializer_for_function_v1(
    abi_function: Union[AbiV1.Function, AbiV2.Function],
) -> FunctionSerializationAdapter:
    """
    Create FunctionSerializationAdapter for serializing function inputs and deserializing function outputs.

    :param abi_function: parsed function's abi.
    :return: FunctionSerializationAdapter.
    """
    inputs = serializer_for_payload(abi_function.inputs)
    outputs = serializer_for_outputs(abi_function.outputs)
    return FunctionSerializationAdapterV1(
        inputs_serializer=inputs,
        outputs_deserializer=outputs,
    )
def serializer_for_constructor_v2(
    abi_function: AbiV2.Constructor,
) -> FunctionSerializationAdapter:
    """
    Create FunctionSerializationAdapter for serializing constructor inputs.

    :param abi_function: parsed constructor's abi.
    :return: FunctionSerializationAdapter.
    """
    inputs = serializer_for_payload(abi_function.inputs)
    # Constructors produce no outputs; an empty output serializer keeps the adapter uniform.
    return FunctionSerializationAdapterV1(
        inputs_serializer=inputs,
        outputs_deserializer=serializer_for_outputs([]),
    )

View File

@@ -0,0 +1,110 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Dict, List, Set, Tuple
from ..cairo.felt import CairoData
from .data_serializers.output_serializer import (
OutputSerializer,
)
from .data_serializers.payload_serializer import (
PayloadSerializer,
)
from .errors import InvalidTypeException
from .tuple_dataclass import TupleDataclass
@dataclass
class FunctionSerializationAdapter:
    """
    Class serializing ``*args`` and ``**kwargs`` by adapting them to function inputs.
    """

    # Serializer for the function's inputs; its key order defines the
    # positional-argument order.
    inputs_serializer: PayloadSerializer
    # Deserializer for the function's outputs.
    outputs_deserializer: PayloadSerializer
    # Expected argument names, derived from inputs_serializer in __post_init__.
    expected_args: Tuple[str, ...] = field(init=False)

    def __post_init__(self):
        self.expected_args = tuple(
            self.inputs_serializer.serializers.keys()
        )  # pyright: ignore

    def serialize(self, *args, **kwargs) -> CairoData:
        """
        Method using args and kwargs to match members and serialize them separately.

        :return: Members serialized separately in SerializedPayload.
        """
        named_arguments = self._merge_arguments(args, kwargs)
        return self.inputs_serializer.serialize(named_arguments)

    def deserialize(self, data: List[int]) -> TupleDataclass:
        """
        Deserializes data into TupleDataclass containing python representations.

        :return: cairo data.
        """
        return self.outputs_deserializer.deserialize(data)

    def _merge_arguments(self, args: Tuple, kwargs: Dict) -> Dict:
        """
        Merges positional and keyed arguments.

        :raises InvalidTypeException: on duplicate, excessive or missing arguments.
        """
        # After this line we know that len(args) <= len(self.expected_args)
        self._ensure_no_unnecessary_positional_args(args)

        named_arguments = dict(kwargs)

        # Map positional args onto input names in declaration order.
        for arg, input_name in zip(args, self.expected_args):
            if input_name in kwargs:
                raise InvalidTypeException(
                    f"Both positional and named argument provided for '{input_name}'."
                )
            named_arguments[input_name] = arg

        expected_args = set(self.expected_args)
        provided_args = set(named_arguments.keys())

        # named_arguments might have unnecessary arguments coming from kwargs (we ensure that
        # len(args) <= len(self.expected_args) above)
        self._ensure_no_unnecessary_args(expected_args, provided_args)

        # there might be some argument missing (not provided)
        self._ensure_no_missing_args(expected_args, provided_args)

        return named_arguments

    def _ensure_no_unnecessary_positional_args(self, args: Tuple):
        # Reject more positional arguments than the function has inputs.
        if len(args) > len(self.expected_args):
            raise InvalidTypeException(
                f"Provided {len(args)} positional arguments, {len(self.expected_args)} possible."
            )

    @staticmethod
    def _ensure_no_unnecessary_args(expected_args: Set[str], provided_args: Set[str]):
        # Reject named arguments the function does not declare.
        excessive_arguments = provided_args - expected_args
        if excessive_arguments:
            raise InvalidTypeException(
                f"Unnecessary named arguments provided: '{', '.join(excessive_arguments)}'."
            )

    @staticmethod
    def _ensure_no_missing_args(expected_args: Set[str], provided_args: Set[str]):
        # Reject calls that leave declared inputs unassigned.
        missing_arguments = expected_args - provided_args
        if missing_arguments:
            raise InvalidTypeException(
                f"Missing arguments: '{', '.join(missing_arguments)}'."
            )
@dataclass
class FunctionSerializationAdapterV1(FunctionSerializationAdapter):
    # Cairo 1 functions deserialize outputs into a plain tuple instead of TupleDataclass.
    outputs_deserializer: OutputSerializer

    def deserialize(self, data: List[int]) -> Tuple:
        """
        Deserializes data into a tuple containing python representations.

        :return: cairo data.
        """
        return self.outputs_deserializer.deserialize(data)

View File

@@ -0,0 +1,59 @@
from __future__ import annotations
from dataclasses import dataclass, fields, make_dataclass
from typing import Dict, Optional, Tuple
@dataclass(frozen=True, eq=False)
class TupleDataclass:
    """
    Dataclass that behaves like a tuple at the same time. Used when data has defined order and names.
    For instance in case of named tuples or function responses.
    """

    # getattr is called when attribute is not found in object. For instance when using object.unknown_attribute.
    # This way pyright will know that there might be some arguments it doesn't know about and will stop complaining
    # about some fields that don't exist statically.
    def __getattr__(self, item):
        # This should always fail - only attributes that don't exist end up in here.
        # We use __getattribute__ to get the native error.
        return super().__getattribute__(item)

    def __getitem__(self, item: int):
        # Tuple-style positional access: index into the declared fields.
        field = fields(self)[item]
        return getattr(self, field.name)

    def __iter__(self):
        # Iterate values in field-declaration order, like a tuple.
        return (getattr(self, field.name) for field in fields(self))

    def as_tuple(self) -> Tuple:
        """
        Creates a regular tuple from TupleDataclass.
        """
        return tuple(self)

    def as_dict(self) -> Dict:
        """
        Creates a regular dict from TupleDataclass.
        """
        return {field.name: getattr(self, field.name) for field in fields(self)}

    # Added for backward compatibility with previous implementation based on NamedTuple
    def _asdict(self):
        return self.as_dict()

    def __eq__(self, other):
        # Compares equal to other TupleDataclass instances and to plain tuples.
        # NOTE(review): defining __eq__ without __hash__ makes instances
        # unhashable (Python sets __hash__ to None) — confirm this is intended.
        if isinstance(other, TupleDataclass):
            return self.as_tuple() == other.as_tuple()
        return self.as_tuple() == other

    @staticmethod
    def from_dict(data: Dict, *, name: Optional[str] = None) -> TupleDataclass:
        """
        Dynamically build a frozen dataclass with one field per key of ``data``
        (in insertion order) and instantiate it with those values.

        :param data: mapping of field name to value.
        :param name: optional name for the generated class.
        :return: instance of the generated TupleDataclass subclass.
        """
        # Field types come from the runtime values; they are informational only.
        result_class = make_dataclass(
            name or "TupleDataclass",
            fields=[(key, type(value)) for key, value in data.items()],
            bases=(TupleDataclass,),
            frozen=True,
            eq=False,
        )
        return result_class(**data)

View File

@@ -0,0 +1,86 @@
from typing import List, Optional, Union
from ..abi.v2 import shape as ShapeV2
from ..abi.v0 import AbiParser as AbiParserV0
from ..abi.v1 import AbiParser as AbiParserV1
from ..abi.v2 import AbiParser as AbiParserV2
from ..serialization import (
FunctionSerializationAdapter,
serializer_for_function,
)
from ..serialization.factory import (
serializer_for_constructor_v2,
serializer_for_function_v1,
)
def translate_constructor_args(
    abi: List, constructor_args: Optional[Union[List, dict]], *, cairo_version: int = 1
) -> List[int]:
    """
    Translate constructor arguments into serialized calldata.

    :param abi: contract abi (list of entries).
    :param constructor_args: positional (list) or named (dict) arguments; may
        be None when the constructor takes no arguments.
    :param cairo_version: Cairo version of the contract.
    :return: serialized calldata; empty when there is no constructor or it
        takes no arguments.
    :raises ValueError: when the constructor expects arguments but none were given.
    """
    if cairo_version == 1:
        serializer = _get_constructor_serializer_v1(abi)
    else:
        serializer = _get_constructor_serializer_v0(abi)

    # No constructor, or a constructor without inputs: nothing to serialize.
    if serializer is None or not serializer.inputs_serializer.serializers:
        return []

    if not constructor_args:
        raise ValueError(
            "Provided contract has a constructor and no arguments were provided."
        )

    if isinstance(constructor_args, dict):
        return serializer.serialize(**constructor_args)
    return serializer.serialize(*constructor_args)
def _get_constructor_serializer_v1(abi: List) -> Optional[FunctionSerializationAdapter]:
    """
    Build a serializer for a Cairo 1 constructor, or return None when the
    contract has no constructor taking arguments.
    """
    if _is_abi_v2(abi):
        constructor = AbiParserV2(abi).parse().constructor
        # Constructor might be absent or accept no arguments
        if constructor is None or not constructor.inputs:
            return None
        return serializer_for_constructor_v2(constructor)

    constructor = AbiParserV1(abi).parse().functions.get("constructor", None)
    # Constructor might not accept any arguments
    if constructor is None or not constructor.inputs:
        return None
    return serializer_for_function_v1(constructor)
def _is_abi_v2(abi: List) -> bool:
    """
    Heuristically detect a v2 abi by scanning its entries.
    """
    v2_markers = (
        ShapeV2.CONSTRUCTOR_ENTRY,
        ShapeV2.L1_HANDLER_ENTRY,
        ShapeV2.INTERFACE_ENTRY,
        ShapeV2.IMPL_ENTRY,
    )
    for entry in abi:
        entry_type = entry["type"]
        # These entry kinds only exist in v2 abis.
        if entry_type in v2_markers:
            return True
        if entry_type == ShapeV2.EVENT_ENTRY:
            # v0/v1 events carry "inputs"; v2 events carry "kind".
            if "inputs" in entry:
                return False
            if "kind" in entry:
                return True
    return False
def _get_constructor_serializer_v0(abi: List) -> Optional[FunctionSerializationAdapter]:
    """
    Build a serializer for a Cairo 0 constructor, or return None when the
    contract has no constructor taking arguments.
    """
    constructor = AbiParserV0(abi).parse().constructor
    # Constructor might not accept any arguments
    if not constructor or not constructor.inputs:
        return None
    return serializer_for_function(constructor)

View File

@@ -0,0 +1,13 @@
from typing import Iterable, TypeVar, Union
T = TypeVar("T")
# pyright: reportGeneralTypeIssues=false
def ensure_iterable(value: Union[T, Iterable[T]]) -> Iterable[T]:
    """
    Return ``value`` unchanged when it is already iterable, otherwise wrap it
    in a single-element list.
    """
    try:
        iter(value)
    except TypeError:
        # Not iterable — wrap it.
        return [value]
    # Now we know it is iterable.
    return value

View File

@@ -0,0 +1,13 @@
import os
from ..marshmallow import EXCLUDE, RAISE
from ..marshmallow import Schema as MarshmallowSchema
# NOTE(review): "UKNOWN" is misspelled, but the env var name is part of the
# public configuration surface — renaming it would break existing setups.
MARSHMALLOW_UKNOWN_EXCLUDE = os.environ.get("STARKNET_PY_MARSHMALLOW_UKNOWN_EXCLUDE")


class Schema(MarshmallowSchema):
    """Base marshmallow schema; unknown-field handling is driven by the env var above."""

    class Meta:
        # Set STARKNET_PY_MARSHMALLOW_UKNOWN_EXCLUDE=true to silently drop
        # unknown fields instead of raising a validation error.
        unknown = (
            EXCLUDE if (MARSHMALLOW_UKNOWN_EXCLUDE or "").lower() == "true" else RAISE
        )

View File

@@ -0,0 +1,182 @@
from dataclasses import dataclass
from typing import Dict, List, Union, cast
from ...marshmallow import Schema, fields, post_load
from ..cairo.felt import encode_shortstring
from ..hash.selector import get_selector_from_name
from ..hash.utils import compute_hash_on_elements
from ..models.typed_data import StarkNetDomainDict, TypedDataDict
@dataclass(frozen=True)
class Parameter:
    """
    Dataclass representing a Parameter object.

    A single field of a typed-data struct definition: its name and the name
    of its type (possibly a pointer type such as ``"felt*"``).
    """

    name: str
    type: str
@dataclass(frozen=True)
class TypedData:
    """
    Dataclass representing a TypedData object.

    Holds a registry of struct type definitions, the primary type describing
    ``message``, a domain and the message itself; provides the hashing used
    for off-chain message signing.
    """

    # Registry of struct definitions: type name -> ordered list of parameters.
    types: Dict[str, List[Parameter]]
    # Name of the struct type that describes ``message``.
    primary_type: str
    domain: StarkNetDomainDict
    message: dict

    @staticmethod
    def from_dict(data: TypedDataDict) -> "TypedData":
        """
        Create TypedData dataclass from dictionary.

        :param data: TypedData dictionary.
        :return: TypedData dataclass instance.
        """
        return cast(TypedData, TypedDataSchema().load(data))

    def _is_struct(self, type_name: str) -> bool:
        # A name is a struct iff it appears in the type registry.
        return type_name in self.types

    def _encode_value(self, type_name: str, value: Union[int, str, dict, list]) -> int:
        # Pointer type with a list value: hash each element, then hash the hashes.
        if is_pointer(type_name) and isinstance(value, list):
            type_name = strip_pointer(type_name)

            if self._is_struct(type_name):
                return compute_hash_on_elements(
                    [self.struct_hash(type_name, data) for data in value]
                )
            return compute_hash_on_elements([int(get_hex(val), 16) for val in value])

        # Nested struct: encode recursively.
        if self._is_struct(type_name) and isinstance(value, dict):
            return self.struct_hash(type_name, value)

        # Primitive value: felt-encode via its hex representation.
        value = cast(Union[int, str], value)
        return int(get_hex(value), 16)

    def _encode_data(self, type_name: str, data: dict) -> List[int]:
        # Encode fields in the order declared by the type definition.
        values = []
        for param in self.types[type_name]:
            encoded_value = self._encode_value(param.type, data[param.name])
            values.append(encoded_value)

        return values

    def _get_dependencies(self, type_name: str) -> List[str]:
        if type_name not in self.types:
            # type_name is a primitive type, has no dependencies
            return []

        dependencies = set()

        def collect_deps(type_name: str) -> None:
            for param in self.types[type_name]:
                # Pointer types depend on their base struct.
                fixed_type = strip_pointer(param.type)
                if fixed_type in self.types and fixed_type not in dependencies:
                    dependencies.add(fixed_type)
                    # recursive call
                    collect_deps(fixed_type)

        # collect dependencies into a set
        collect_deps(type_name)
        # Set order is unspecified here; _encode_type sorts the tail alphabetically.
        return [type_name, *list(dependencies)]

    def _encode_type(self, type_name: str) -> str:
        # The primary type comes first; its dependencies follow sorted by name.
        primary, *dependencies = self._get_dependencies(type_name)
        types = [primary, *sorted(dependencies)]

        def make_dependency_str(dependency):
            lst = [f"{t.name}:{t.type}" for t in self.types[dependency]]
            return f"{dependency}({','.join(lst)})"

        return "".join([make_dependency_str(x) for x in types])

    def type_hash(self, type_name: str) -> int:
        """
        Calculate the hash of a type name.

        :param type_name: Name of the type.
        :return: Hash of the type name.
        """
        return get_selector_from_name(self._encode_type(type_name))

    def struct_hash(self, type_name: str, data: dict) -> int:
        """
        Calculate the hash of a struct.

        :param type_name: Name of the type.
        :param data: Data defining the struct.
        :return: Hash of the struct.
        """
        return compute_hash_on_elements(
            [self.type_hash(type_name), *self._encode_data(type_name, data)]
        )

    def message_hash(self, account_address: int) -> int:
        """
        Calculate the hash of the message.

        :param account_address: Address of an account.
        :return: Hash of the message.
        """
        # Fixed order: prefix shortstring, domain struct hash, signer address,
        # message struct hash.
        message = [
            encode_shortstring("StarkNet Message"),
            self.struct_hash("StarkNetDomain", cast(dict, self.domain)),
            account_address,
            self.struct_hash(self.primary_type, self.message),
        ]
        return compute_hash_on_elements(message)
def get_hex(value: Union[int, str]) -> str:
    """
    Normalize *value* to a hex string: ints are hex-formatted, "0x..." strings
    pass through, decimal strings are converted, and any other string is
    short-string encoded first.
    """
    if isinstance(value, str):
        if value[:2] == "0x":
            return value
        if value.isnumeric():
            return hex(int(value))
        return hex(encode_shortstring(value))
    return hex(value)
def is_pointer(value: str) -> bool:
    """Check whether *value* names a pointer type (i.e. ends with ``*``)."""
    return value.endswith("*")
def strip_pointer(value: str) -> str:
    """Return *value* without its trailing ``*``, if it has one."""
    if value.endswith("*"):
        return value[:-1]
    return value
# pylint: disable=unused-argument
# pylint: disable=no-self-use
class ParameterSchema(Schema):
    """Marshmallow schema loading a :class:`Parameter` from its dict form."""

    name = fields.String(data_key="name", required=True)
    type = fields.String(data_key="type", required=True)

    @post_load
    def make_dataclass(self, data, **kwargs) -> Parameter:
        # Convert the validated dict into the frozen dataclass.
        return Parameter(**data)
class TypedDataSchema(Schema):
    """Marshmallow schema loading a :class:`TypedData` from its dict form."""

    # Mapping of type name -> list of Parameter definitions.
    types = fields.Dict(
        data_key="types",
        keys=fields.Str(),
        values=fields.List(fields.Nested(ParameterSchema())),
    )
    primary_type = fields.String(data_key="primaryType", required=True)
    domain = fields.Dict(data_key="domain", required=True)
    message = fields.Dict(data_key="message", required=True)

    @post_load
    def make_dataclass(self, data, **kwargs) -> TypedData:
        # Convert the validated dict into the frozen dataclass.
        return TypedData(**data)