add
This commit is contained in:
@@ -0,0 +1,82 @@
|
||||
import copy
|
||||
import inspect
|
||||
from typing import List, Tuple, Any, Optional
|
||||
|
||||
import typeguard
|
||||
from marshmallow import fields, Schema, ValidationError
|
||||
|
||||
try:
|
||||
from typeguard import TypeCheckError # type: ignore[attr-defined]
|
||||
except ImportError:
|
||||
# typeguard < 3
|
||||
TypeCheckError = TypeError # type: ignore[misc, assignment]
|
||||
|
||||
if "argname" not in inspect.signature(typeguard.check_type).parameters:
|
||||
|
||||
def _check_type(value, expected_type, argname: str):
|
||||
return typeguard.check_type(value=value, expected_type=expected_type)
|
||||
|
||||
else:
|
||||
# typeguard < 3.0.0rc2
|
||||
def _check_type(value, expected_type, argname: str):
|
||||
return typeguard.check_type( # type: ignore[call-overload]
|
||||
value=value, expected_type=expected_type, argname=argname
|
||||
)
|
||||
|
||||
|
||||
class Union(fields.Field):
|
||||
"""A union field, composed other `Field` classes or instances.
|
||||
This field serializes elements based on their type, with one of its child fields.
|
||||
|
||||
Example: ::
|
||||
|
||||
number_or_string = UnionField([
|
||||
(float, fields.Float()),
|
||||
(str, fields.Str())
|
||||
])
|
||||
|
||||
:param union_fields: A list of types and their associated field instance.
|
||||
:param kwargs: The same keyword arguments that :class:`Field` receives.
|
||||
"""
|
||||
|
||||
def __init__(self, union_fields: List[Tuple[type, fields.Field]], **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
self.union_fields = union_fields
|
||||
|
||||
def _bind_to_schema(self, field_name: str, schema: Schema) -> None:
|
||||
super()._bind_to_schema(field_name, schema)
|
||||
new_union_fields = []
|
||||
for typ, field in self.union_fields:
|
||||
field = copy.deepcopy(field)
|
||||
field._bind_to_schema(field_name, self)
|
||||
new_union_fields.append((typ, field))
|
||||
|
||||
self.union_fields = new_union_fields
|
||||
|
||||
def _serialize(self, value: Any, attr: Optional[str], obj, **kwargs) -> Any:
|
||||
errors = []
|
||||
if value is None:
|
||||
return value
|
||||
for typ, field in self.union_fields:
|
||||
try:
|
||||
_check_type(value=value, expected_type=typ, argname=attr or "anonymous")
|
||||
return field._serialize(value, attr, obj, **kwargs)
|
||||
except TypeCheckError as e:
|
||||
errors.append(e)
|
||||
raise TypeError(
|
||||
f"Unable to serialize value with any of the fields in the union: {errors}"
|
||||
)
|
||||
|
||||
def _deserialize(self, value: Any, attr: Optional[str], data, **kwargs) -> Any:
|
||||
errors = []
|
||||
for typ, field in self.union_fields:
|
||||
try:
|
||||
result = field.deserialize(value, **kwargs)
|
||||
_check_type(
|
||||
value=result, expected_type=typ, argname=attr or "anonymous"
|
||||
)
|
||||
return result
|
||||
except (TypeCheckError, ValidationError) as e:
|
||||
errors.append(e)
|
||||
|
||||
raise ValidationError(errors)
|
||||
Reference in New Issue
Block a user