83 lines
2.8 KiB
Python
83 lines
2.8 KiB
Python
import copy
|
|
import inspect
|
|
from typing import List, Tuple, Any, Optional
|
|
|
|
import typeguard
|
|
from marshmallow import fields, Schema, ValidationError
|
|
|
|
try:
|
|
from typeguard import TypeCheckError # type: ignore[attr-defined]
|
|
except ImportError:
|
|
# typeguard < 3
|
|
TypeCheckError = TypeError # type: ignore[misc, assignment]
|
|
|
|
if "argname" not in inspect.signature(typeguard.check_type).parameters:
|
|
|
|
def _check_type(value, expected_type, argname: str):
|
|
return typeguard.check_type(value=value, expected_type=expected_type)
|
|
|
|
else:
|
|
# typeguard < 3.0.0rc2
|
|
def _check_type(value, expected_type, argname: str):
|
|
return typeguard.check_type( # type: ignore[call-overload]
|
|
value=value, expected_type=expected_type, argname=argname
|
|
)
|
|
|
|
|
|
class Union(fields.Field):
|
|
"""A union field, composed other `Field` classes or instances.
|
|
This field serializes elements based on their type, with one of its child fields.
|
|
|
|
Example: ::
|
|
|
|
number_or_string = UnionField([
|
|
(float, fields.Float()),
|
|
(str, fields.Str())
|
|
])
|
|
|
|
:param union_fields: A list of types and their associated field instance.
|
|
:param kwargs: The same keyword arguments that :class:`Field` receives.
|
|
"""
|
|
|
|
def __init__(self, union_fields: List[Tuple[type, fields.Field]], **kwargs):
|
|
super().__init__(**kwargs)
|
|
self.union_fields = union_fields
|
|
|
|
def _bind_to_schema(self, field_name: str, schema: Schema) -> None:
|
|
super()._bind_to_schema(field_name, schema)
|
|
new_union_fields = []
|
|
for typ, field in self.union_fields:
|
|
field = copy.deepcopy(field)
|
|
field._bind_to_schema(field_name, self)
|
|
new_union_fields.append((typ, field))
|
|
|
|
self.union_fields = new_union_fields
|
|
|
|
def _serialize(self, value: Any, attr: Optional[str], obj, **kwargs) -> Any:
|
|
errors = []
|
|
if value is None:
|
|
return value
|
|
for typ, field in self.union_fields:
|
|
try:
|
|
_check_type(value=value, expected_type=typ, argname=attr or "anonymous")
|
|
return field._serialize(value, attr, obj, **kwargs)
|
|
except TypeCheckError as e:
|
|
errors.append(e)
|
|
raise TypeError(
|
|
f"Unable to serialize value with any of the fields in the union: {errors}"
|
|
)
|
|
|
|
def _deserialize(self, value: Any, attr: Optional[str], data, **kwargs) -> Any:
|
|
errors = []
|
|
for typ, field in self.union_fields:
|
|
try:
|
|
result = field.deserialize(value, **kwargs)
|
|
_check_type(
|
|
value=result, expected_type=typ, argname=attr or "anonymous"
|
|
)
|
|
return result
|
|
except (TypeCheckError, ValidationError) as e:
|
|
errors.append(e)
|
|
|
|
raise ValidationError(errors)
|