You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
306 lines
12 KiB
306 lines
12 KiB
# -*- coding: utf-8 -*-
|
|
"""Utility classes and values used for marshalling and unmarshalling objects to
|
|
and from primitive types.
|
|
|
|
.. warning::
|
|
|
|
This module is treated as private API.
|
|
Users should not need to use this module directly.
|
|
"""
|
|
|
|
from __future__ import unicode_literals
|
|
|
|
from marshmallow.utils import is_collection, missing, set_value
|
|
from marshmallow.compat import text_type, iteritems, Mapping
|
|
from marshmallow.exceptions import (
|
|
ValidationError,
|
|
)
|
|
|
|
# Names exported by ``from <module> import *``.
__all__ = [
    'Marshaller',
    'Unmarshaller',
]
|
|
|
|
# Key used for field-level validation errors on nested fields
FIELD = '_field'
|
|
|
|
|
|
class ErrorStore(object):
    """Base class that records `ValidationError`s raised by field getters.

    ``call_and_store`` catches the error, files its messages in
    ``self.errors`` and returns a value instead of letting the exception
    propagate.
    """

    def __init__(self):
        #: Dictionary of errors stored during serialization
        self.errors = {}
        #: List of `Field` objects which have validation errors
        self.error_fields = []
        #: List of field_names which have validation errors
        self.error_field_names = []
        #: True while (de)serializing a collection
        self._pending = False
        #: Dictionary of extra kwargs from user raised exception
        self.error_kwargs = {}

    def get_errors(self, index=None):
        """Return the dictionary that error messages should be written to.

        :param int index: Index of the item being processed, or `None`.
        :return: ``self.errors`` itself when ``index`` is `None`; otherwise
            the dictionary stored under ``self.errors[index]``, which is
            created (and stored) if it does not exist yet.
        """
        if index is None:
            return self.errors
        return self.errors.setdefault(index, {})

    def call_and_store(self, getter_func, data, field_name, field_obj, index=None):
        """Call ``getter_func`` with ``data`` as its argument, and store any `ValidationErrors`.

        :param callable getter_func: Function for getting the serialized/deserialized
            value from ``data``.
        :param data: The data passed to ``getter_func``.
        :param str field_name: Field name.
        :param FieldABC field_obj: Field object that performs the
            serialization/deserialization behavior.
        :param int index: Index of the item being validated, if validating a collection,
            otherwise `None`.
        :return: The result of ``getter_func(data)``. If a `ValidationError`
            was raised, the error's ``data`` attribute, or ``missing`` when
            that attribute is falsy.
        """
        try:
            value = getter_func(data)
        except ValidationError as err:  # Store validation errors
            self.error_kwargs.update(err.kwargs)
            self.error_fields.append(field_obj)
            self.error_field_names.append(field_name)
            errors = self.get_errors(index=index)
            # Warning: Mutation!
            if isinstance(err.messages, dict):
                # Dict messages replace whatever was stored for this field.
                errors[field_name] = err.messages
            elif isinstance(errors.get(field_name), dict):
                # A dict is already stored for this field: file the
                # field-level messages under the FIELD key inside it.
                errors[field_name].setdefault(FIELD, []).extend(err.messages)
            else:
                errors.setdefault(field_name, []).extend(err.messages)
            # When a Nested field fails validation, the marshalled data is stored
            # on the ValidationError's data attribute
            value = err.data or missing
        return value
|
|
|
|
|
|
class Marshaller(ErrorStore):
    """Callable class responsible for serializing data and storing errors.

    :param str prefix: Optional prefix that will be prepended to all the
        serialized field names.
    """
    def __init__(self, prefix=''):
        self.prefix = prefix
        ErrorStore.__init__(self)

    def serialize(self, obj, fields_dict, many=False,
                  accessor=None, dict_class=dict, index_errors=True, index=None):
        """Takes raw data (a dict, list, or other object) and a dict of
        fields to output and serializes the data based on those fields.

        :param obj: The actual object(s) from which the fields are taken from
        :param dict fields_dict: Mapping of field names to :class:`Field` objects.
        :param bool many: Set to `True` if ``data`` should be serialized as
            a collection.
        :param callable accessor: Function to use for getting values from ``obj``.
        :param type dict_class: Dictionary class used to construct the output.
        :param bool index_errors: Whether to store the index of invalid items in
            ``self.errors`` when ``many=True``.
        :param int index: Index of the item being serialized (for storing errors) if
            serializing a collection, otherwise `None`.
        :return: A dictionary of the marshalled data
        :raises ValidationError: If any errors were stored; the partially
            marshalled result is attached as the error's ``data``.

        .. versionchanged:: 1.0.0
            Renamed from ``marshal``.
        """
        if many and obj is not None:
            # While pending, the per-item calls below store their errors
            # without raising; a single error is raised after the loop.
            self._pending = True
            try:
                ret = [self.serialize(d, fields_dict, many=False,
                                      dict_class=dict_class, accessor=accessor,
                                      index=idx, index_errors=index_errors)
                       for idx, d in enumerate(obj)]
            finally:
                # Reset even if a non-validation exception escapes, otherwise
                # later calls on this instance would never raise.
                self._pending = False
            if self.errors:
                raise ValidationError(
                    self.errors,
                    field_names=self.error_field_names,
                    fields=self.error_fields,
                    data=ret,
                )
            return ret
        items = []
        for attr_name, field_obj in iteritems(fields_dict):
            if getattr(field_obj, 'load_only', False):
                continue

            key = ''.join([self.prefix or '', field_obj.dump_to or attr_name])

            # The getter is called immediately by ``call_and_store``, so
            # closing over the loop variables is safe here.
            getter = lambda d: field_obj.serialize(attr_name, d, accessor=accessor)
            value = self.call_and_store(
                getter_func=getter,
                data=obj,
                field_name=key,
                field_obj=field_obj,
                index=(index if index_errors else None)
            )
            if value is missing:
                continue
            items.append((key, value))
        ret = dict_class(items)
        if self.errors and not self._pending:
            raise ValidationError(
                self.errors,
                field_names=self.error_field_names,
                fields=self.error_fields,
                data=ret
            )
        return ret

    # Make an instance callable
    __call__ = serialize
|
|
|
|
|
|
# Key used for schema-level validation errors
SCHEMA = '_schema'
|
|
|
|
|
|
class Unmarshaller(ErrorStore):
    """Callable class responsible for deserializing data and storing errors.

    .. versionadded:: 1.0.0
    """

    default_schema_validation_error = 'Invalid data.'

    def run_validator(self, validator_func, output,
                      original_data, fields_dict, index=None,
                      many=False, pass_original=False):
        """Run ``validator_func`` and store any `ValidationError` it raises.

        :param callable validator_func: Called as ``validator_func(output)``,
            or ``validator_func(output, original_data)`` when
            ``pass_original`` is true. A return value of `False` is treated
            as a failure with ``default_schema_validation_error``.
        :param output: Data handed to the validator.
        :param original_data: Raw data, forwarded only if ``pass_original``.
        :param dict fields_dict: Mapping of field names to :class:`Field`
            objects, used to look up the fields named by the error.
        :param int index: Index of the item being validated, if validating a
            collection, otherwise `None`.
        :param bool many: Not used by this method.
        :param bool pass_original: Whether to pass ``original_data`` to the
            validator.
        """
        try:
            if pass_original:  # Pass original, raw data (before unmarshalling)
                res = validator_func(output, original_data)
            else:
                res = validator_func(output)
            if res is False:
                raise ValidationError(self.default_schema_validation_error)
        except ValidationError as err:
            errors = self.get_errors(index=index)
            self.error_kwargs.update(err.kwargs)
            # Store or reraise errors
            if err.field_names:
                field_names = err.field_names
                # Names that are not declared fields map to None.
                field_objs = [fields_dict.get(each) for each in field_names]
            else:
                field_names = [SCHEMA]
                field_objs = []
            self.error_field_names = field_names
            self.error_fields = field_objs
            for field_name in field_names:
                if isinstance(err.messages, (list, tuple)):
                    # self.errors[field_name] may be a dict if schemas are nested
                    if isinstance(errors.get(field_name), dict):
                        errors[field_name].setdefault(
                            SCHEMA, []
                        ).extend(err.messages)
                    else:
                        errors.setdefault(field_name, []).extend(err.messages)
                elif isinstance(err.messages, dict):
                    errors.setdefault(field_name, []).append(err.messages)
                else:
                    errors.setdefault(field_name, []).append(text_type(err))

    def deserialize(self, data, fields_dict, many=False, partial=False,
                    dict_class=dict, index_errors=True, index=None):
        """Deserialize ``data`` based on the schema defined by ``fields_dict``.

        :param dict data: The data to deserialize.
        :param dict fields_dict: Mapping of field names to :class:`Field` objects.
        :param bool many: Set to `True` if ``data`` should be deserialized as
            a collection.
        :param bool|tuple partial: Whether to ignore missing fields. If its
            value is an iterable, only missing fields listed in that iterable
            will be ignored.
        :param type dict_class: Dictionary class used to construct the output.
        :param bool index_errors: Whether to store the index of invalid items in
            ``self.errors`` when ``many=True``.
        :param int index: Index of the item being deserialized (for storing errors) if
            deserializing a collection, otherwise `None`.
        :return: A dictionary of the deserialized data. `None` is returned
            for a single item that is not a mapping.
        :raises ValidationError: If any errors were stored and this call is
            not part of an enclosing collection pass.
        """
        if many and data is not None:
            if not is_collection(data):
                errors = self.get_errors(index=index)
                self.error_field_names.append(SCHEMA)
                errors[SCHEMA] = ['Invalid input type.']
                ret = []
            else:
                # While pending, the per-item calls below store their errors
                # without raising; a single error is raised after the loop.
                self._pending = True
                try:
                    ret = [self.deserialize(d, fields_dict, many=False,
                                            partial=partial, dict_class=dict_class,
                                            index=idx, index_errors=index_errors)
                           for idx, d in enumerate(data)]
                finally:
                    # Reset even if a non-validation exception escapes,
                    # otherwise later calls on this instance would never raise.
                    self._pending = False
            if self.errors:
                raise ValidationError(
                    self.errors,
                    field_names=self.error_field_names,
                    fields=self.error_fields,
                    data=ret,
                )
            return ret

        ret = dict_class()

        if not isinstance(data, Mapping):
            # Kept for its side effect: when ``index`` is given, this
            # registers an (empty) error dict under that index.
            self.get_errors(index=index)
            msg = 'Invalid input type.'
            self.error_field_names = [SCHEMA]
            # The message itself is stored at the top level of self.errors.
            errors = self.get_errors()
            errors.setdefault(SCHEMA, []).append(msg)
            return None

        partial_is_collection = is_collection(partial)
        for attr_name, field_obj in iteritems(fields_dict):
            if field_obj.dump_only:
                continue
            raw_value = data.get(attr_name, missing)
            field_name = attr_name
            if raw_value is missing and field_obj.load_from:
                # Fall back to the alternate input key.
                field_name = field_obj.load_from
                raw_value = data.get(field_obj.load_from, missing)
            if raw_value is missing:
                # Ignore missing field if we're allowed to.
                if (
                    partial is True or
                    (partial_is_collection and attr_name in partial)
                ):
                    continue
                _miss = field_obj.missing
                raw_value = _miss() if callable(_miss) else _miss
            if raw_value is missing and not field_obj.required:
                continue

            # The getter is called immediately by ``call_and_store``, so
            # closing over the loop variables is safe here.
            getter = lambda val: field_obj.deserialize(
                val,
                field_obj.load_from or attr_name,
                data
            )
            value = self.call_and_store(
                getter_func=getter,
                data=raw_value,
                field_name=field_name,
                field_obj=field_obj,
                index=(index if index_errors else None)
            )
            if value is not missing:
                key = field_obj.attribute or attr_name
                set_value(ret, key, value)

        if self.errors and not self._pending:
            raise ValidationError(
                self.errors,
                field_names=self.error_field_names,
                fields=self.error_fields,
                data=ret,
            )
        return ret

    # Make an instance callable
    __call__ = deserialize
|
|
|