# Source code for invenio_records_rest.loaders.marshmallow

# -*- coding: utf-8 -*-
#
# This file is part of Invenio.
# Copyright (C) 2016-2018 CERN.
#
# Invenio is free software; you can redistribute it and/or modify it
# under the terms of the MIT License; see LICENSE file for more details.

"""Marshmallow loader for record deserialization.

Use marshmallow schema to transform a JSON sent via the REST API from an
external to an internal JSON presentation. The marshmallow schema further
allows for advanced data validation.
"""

from __future__ import absolute_import, print_function

import json

from flask import request
from invenio_rest.errors import RESTValidationError
from marshmallow import ValidationError
from marshmallow import __version_info__ as marshmallow_version


def _flatten_marshmallow_errors(errors, parents=()):
    """Flatten marshmallow errors."""
    res = []
    for field, error in errors.items():
        if isinstance(error, list):
            res.append(
                dict(
                    parents=parents,
                    field=field,
                    message=' '.join(str(x) for x in error)
                )
            )
        elif isinstance(error, dict):
            res.extend(_flatten_marshmallow_errors(
                error,
                parents=parents + (field,)
            ))
    return res


class MarshmallowErrors(RESTValidationError):
    """Marshmallow validation errors.

    Responsible for formatting a JSON response to a user when a validation
    error happens. The marshmallow error mapping passed to the constructor is
    flattened into ``self.errors``; iterating over the exception yields those
    flattened entries.
    """

    def __init__(self, errors):
        """Store marshmallow errors.

        :param errors: Marshmallow error mapping (field name to messages or
            nested mappings), flattened into ``self.errors``.
        """
        self._it = None
        self.errors = _flatten_marshmallow_errors(errors)
        super(MarshmallowErrors, self).__init__()

    def __str__(self):
        """Print exception with errors."""
        # NOTE(review): the MRO lookup starts *after* RESTValidationError,
        # so any __str__ defined on RESTValidationError itself is bypassed;
        # confirm this is intentional.
        return "{base}. Encountered errors: {errors}".format(
            base=super(RESTValidationError, self).__str__(),
            errors=self.errors)

    def __iter__(self):
        """Reset the cursor and return this exception as an iterator.

        The cursor over ``self.errors`` is stored on the instance, so nested
        or concurrent iterations over the same exception share and reset it.
        """
        self._it = iter(self.errors)
        return self

    def next(self):
        """Python 2.7 compatibility."""
        return self.__next__()  # pragma: no cover

    def __next__(self):
        """Get the next flattened error entry."""
        return next(self._it)

    def get_body(self, environ=None, scope=None):
        """Get the request body.

        :param environ: WSGI environment, forwarded to ``get_description``.
        :param scope: Accepted so this override stays call-compatible with
            Werkzeug's ``HTTPException.get_body(environ, scope)``; unused.
        :returns: JSON string containing ``status`` and ``message``, plus
            ``errors`` when any flattened errors were collected.
        """
        body = dict(
            status=self.code,
            message=self.get_description(environ),
        )

        if self.errors:
            body['errors'] = self.errors

        return json.dumps(body)


def marshmallow_loader(schema_class):
    """Marshmallow loader for JSON requests.

    :param schema_class: Marshmallow schema class used to deserialize the
        request body. It is instantiated per request with a ``context``
        holding ``pid`` and ``record`` when the URL carries a ``pid_value``.
    :returns: A zero-argument loader that reads the current request's JSON
        body, validates/deserializes it with the schema and returns the
        loaded data.
    """
    def json_loader():
        """Load and validate the current request's JSON body.

        :raises MarshmallowErrors: If the schema reports validation errors.
        """
        request_json = request.get_json()

        context = {}
        # ``view_args`` is None when no URL rule matched; treat that the
        # same as a URL without a ``pid_value``.
        pid_data = (request.view_args or {}).get('pid_value')
        if pid_data:
            # The resolved value's ``data`` attribute is a (pid, record) pair.
            pid, record = pid_data.data
            context['pid'] = pid
            context['record'] = record

        schema = schema_class(context=context)
        try:
            result = schema.load(request_json)
        except ValidationError as error:
            # Marshmallow 3 (and strict Marshmallow 2 schemas) raise on
            # invalid input instead of returning the errors; convert them to
            # our REST error type.
            raise MarshmallowErrors(error.messages)

        if marshmallow_version[0] < 3:
            # Marshmallow 2 returns an (data, errors) result object.
            if result.errors:
                raise MarshmallowErrors(result.errors)
            return result.data
        # Marshmallow 3 returns the deserialized data directly.
        return result
    return json_loader


def json_patch_loader():
    """Return the current request body parsed as JSON.

    No schema validation is applied. Parsing is forced, so the body is decoded
    as JSON regardless of the request's Content-Type header.
    """
    patch_document = request.get_json(force=True)
    return patch_document