diff --git a/README.rst b/README.rst index ed745305..13d74816 100644 --- a/README.rst +++ b/README.rst @@ -57,22 +57,20 @@ Alternatively you can download the code and install from the repository: First steps ########### -Firstly create your specification object. +Firstly create your OpenAPI object. .. code-block:: python - from jsonschema_path import SchemaPath + from openapi_core import OpenAPI - spec = SchemaPath.from_file_path('openapi.json') + openapi = OpenAPI.from_file_path('openapi.json') Now you can use it to validate and unmarshal against requests and/or responses. .. code-block:: python - from openapi_core import unmarshal_request - # raises error if request is invalid - result = unmarshal_request(request, spec=spec) + result = openapi.unmarshal_request(request) Retrieve validated and unmarshalled request data diff --git a/docs/customizations.rst b/docs/customizations.rst index 059cc745..a2019fbf 100644 --- a/docs/customizations.rst +++ b/docs/customizations.rst @@ -9,13 +9,15 @@ By default, the specified specification is also validated. If you know you have a valid specification already, disabling the validator can improve the performance. .. code-block:: python - :emphasize-lines: 4 + :emphasize-lines: 1,4,6 - validate_request( - request, - spec=spec, + from openapi_core import Config + + config = Config( spec_validator_cls=None, ) + openapi = OpenAPI.from_file_path('openapi.json', config=config) + openapi.validate_request(request) Media type deserializers ------------------------ @@ -25,7 +27,7 @@ OpenAPI comes with a set of built-in media type deserializers such as: ``applica You can also define your own ones. Pass custom defined media type deserializers dictionary with supported mimetypes as a key to `unmarshal_response` function: .. code-block:: python - :emphasize-lines: 13 + :emphasize-lines: 11 def protobuf_deserializer(message): feature = route_guide_pb2.Feature() @@ -36,11 +38,12 @@ You can also define your own ones. Pass custom defined media type deserializers 'application/protobuf': protobuf_deserializer, } - result = unmarshal_response( - request, response, - spec=spec, + config = Config( extra_media_type_deserializers=extra_media_type_deserializers, ) + openapi = OpenAPI.from_file_path('openapi.json', config=config) + + result = openapi.unmarshal_response(request, response) Format validators ----------------- @@ -52,7 +55,7 @@ OpenAPI comes with a set of built-in format validators, but it's also possible t Here's how you could add support for a ``usdate`` format that handles dates of the form MM/DD/YYYY: .. code-block:: python - :emphasize-lines: 13 + :emphasize-lines: 11 import re @@ -63,11 +66,12 @@ Here's how you could add support for a ``usdate`` format that handles dates of t 'usdate': validate_usdate, } - validate_response( - request, response, - spec=spec, + config = Config( extra_format_validators=extra_format_validators, ) + openapi = OpenAPI.from_file_path('openapi.json', config=config) + + openapi.validate_response(request, response) Format unmarshallers -------------------- @@ -79,7 +83,7 @@ Openapi-core comes with a set of built-in format unmarshallers, but it's also po Here's an example with the ``usdate`` format that converts a value to date object: .. code-block:: python - :emphasize-lines: 13 + :emphasize-lines: 11 from datetime import datetime @@ -90,8 +94,9 @@ Here's an example with the ``usdate`` format that converts a value to date objec 'usdate': unmarshal_usdate, } - result = unmarshal_response( - request, response, - spec=spec, + config = Config( extra_format_unmarshallers=extra_format_unmarshallers, ) + openapi = OpenAPI.from_file_path('openapi.json', config=config) + + result = openapi.unmarshal_response(request, response) diff --git a/docs/index.rst b/docs/index.rst index 9d309ad8..f2defc02 100644 --- a/docs/index.rst +++ b/docs/index.rst @@ -45,22 +45,20 @@ Installation First steps ----------- -Firstly create your specification object. +Firstly create your OpenAPI object. .. code-block:: python - from jsonschema_path import SchemaPath + from openapi_core import OpenAPI - spec = SchemaPath.from_file_path('openapi.json') + openapi = OpenAPI.from_file_path('openapi.json') Now you can use it to validate and unmarshal your requests and/or responses. .. code-block:: python - from openapi_core import unmarshal_request - # raises error if request is invalid - result = unmarshal_request(request, spec=spec) + result = openapi.unmarshal_request(request) Retrieve validated and unmarshalled request data diff --git a/docs/integrations.rst b/docs/integrations.rst index cf057c9e..c20247c1 100644 --- a/docs/integrations.rst +++ b/docs/integrations.rst @@ -48,20 +48,20 @@ The integration supports Django from version 3.0 and above. Middleware ~~~~~~~~~~ -Django can be integrated by middleware. Add ``DjangoOpenAPIMiddleware`` to your ``MIDDLEWARE`` list and define ``OPENAPI_SPEC``. +Django can be integrated by middleware. Add ``DjangoOpenAPIMiddleware`` to your ``MIDDLEWARE`` list and define ``OPENAPI``. .. code-block:: python :emphasize-lines: 6,9 # settings.py - from jsonschema_path import SchemaPath + from openapi_core import OpenAPI MIDDLEWARE = [ # ... 'openapi_core.contrib.django.middlewares.DjangoOpenAPIMiddleware', ] - OPENAPI_SPEC = SchemaPath.from_dict(spec_dict) + OPENAPI = OpenAPI.from_dict(spec_dict) You can skip response validation process: by setting ``OPENAPI_RESPONSE_CLS`` to ``None`` @@ -69,14 +69,14 @@ You can skip response validation process: by setting ``OPENAPI_RESPONSE_CLS`` to :emphasize-lines: 10 # settings.py - from jsonschema_path import SchemaPath + from openapi_core import OpenAPI MIDDLEWARE = [ # ... 'openapi_core.contrib.django.middlewares.DjangoOpenAPIMiddleware', ] - OPENAPI_SPEC = SchemaPath.from_dict(spec_dict) + OPENAPI = OpenAPI.from_dict(spec_dict) OPENAPI_RESPONSE_CLS = None After that you have access to unmarshal result object with all validated request data from Django view through request object. diff --git a/openapi_core/__init__.py b/openapi_core/__init__.py index ccb5b2d6..10c5dca3 100644 --- a/openapi_core/__init__.py +++ b/openapi_core/__init__.py @@ -1,4 +1,6 @@ """OpenAPI core module""" +from openapi_core.app import OpenAPI +from openapi_core.configurations import Config from openapi_core.shortcuts import unmarshal_apicall_request from openapi_core.shortcuts import unmarshal_apicall_response from openapi_core.shortcuts import unmarshal_request @@ -11,7 +13,7 @@ from openapi_core.shortcuts import validate_response from openapi_core.shortcuts import validate_webhook_request from openapi_core.shortcuts import validate_webhook_response -from openapi_core.spec import Spec +from openapi_core.spec.paths import Spec from openapi_core.unmarshalling.request import V3RequestUnmarshaller from openapi_core.unmarshalling.request import V3WebhookRequestUnmarshaller from openapi_core.unmarshalling.request import V30RequestUnmarshaller @@ -40,6 +42,8 @@ __license__ = "BSD 3-Clause License" __all__ = [ + "OpenAPI", + "Config", "Spec", "unmarshal_request", "unmarshal_response", diff --git a/openapi_core/app.py b/openapi_core/app.py new file mode 100644 index 00000000..bc13e9b4 --- /dev/null +++ b/openapi_core/app.py @@ -0,0 +1,363 @@ +"""OpenAPI core app module""" +import warnings +from dataclasses import dataclass +from dataclasses import field +from functools import lru_cache +from pathlib import Path +from typing import Any +from typing import Hashable +from typing import Mapping +from typing import Optional +from typing import Type +from typing import TypeVar +from typing import Union + +from jsonschema._utils import Unset +from jsonschema.validators import _UNSET +from jsonschema_path import SchemaPath +from jsonschema_path.handlers.protocols import SupportsRead +from jsonschema_path.typing import Schema +from openapi_spec_validator import validate +from openapi_spec_validator.validation.exceptions import ValidatorDetectError +from openapi_spec_validator.validation.types import SpecValidatorType +from openapi_spec_validator.versions.datatypes import SpecVersion +from openapi_spec_validator.versions.exceptions import OpenAPIVersionNotFound +from openapi_spec_validator.versions.shortcuts import get_spec_version + +from openapi_core.configurations import Config +from openapi_core.exceptions import SpecError +from openapi_core.protocols import Request +from openapi_core.protocols import Response +from openapi_core.protocols import WebhookRequest +from openapi_core.types import AnyRequest +from openapi_core.unmarshalling.request import ( + UNMARSHALLERS as REQUEST_UNMARSHALLERS, +) +from openapi_core.unmarshalling.request import ( + WEBHOOK_UNMARSHALLERS as WEBHOOK_REQUEST_UNMARSHALLERS, +) +from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult +from openapi_core.unmarshalling.request.protocols import RequestUnmarshaller +from openapi_core.unmarshalling.request.protocols import ( + WebhookRequestUnmarshaller, +) +from openapi_core.unmarshalling.request.types import RequestUnmarshallerType +from openapi_core.unmarshalling.request.types import ( + WebhookRequestUnmarshallerType, +) +from openapi_core.unmarshalling.response import ( + UNMARSHALLERS as RESPONSE_UNMARSHALLERS, +) +from openapi_core.unmarshalling.response import ( + WEBHOOK_UNMARSHALLERS as WEBHOOK_RESPONSE_UNMARSHALLERS, +) +from openapi_core.unmarshalling.response.datatypes import ( + ResponseUnmarshalResult, +) +from openapi_core.unmarshalling.response.protocols import ResponseUnmarshaller +from openapi_core.unmarshalling.response.protocols import ( + WebhookResponseUnmarshaller, +) +from openapi_core.unmarshalling.response.types import ResponseUnmarshallerType +from openapi_core.unmarshalling.response.types import ( + WebhookResponseUnmarshallerType, +) +from openapi_core.validation.request import VALIDATORS as REQUEST_VALIDATORS +from openapi_core.validation.request import ( + WEBHOOK_VALIDATORS as WEBHOOK_REQUEST_VALIDATORS, +) +from openapi_core.validation.request.protocols import RequestValidator +from openapi_core.validation.request.protocols import WebhookRequestValidator +from openapi_core.validation.request.types import RequestValidatorType +from openapi_core.validation.request.types import WebhookRequestValidatorType +from openapi_core.validation.response import VALIDATORS as RESPONSE_VALIDATORS +from openapi_core.validation.response import ( + WEBHOOK_VALIDATORS as WEBHOOK_RESPONSE_VALIDATORS, +) +from openapi_core.validation.response.protocols import ResponseValidator +from openapi_core.validation.response.protocols import WebhookResponseValidator +from openapi_core.validation.response.types import ResponseValidatorType +from openapi_core.validation.response.types import WebhookResponseValidatorType + + +class OpenAPI: + """OpenAPI class.""" + + def __init__( + self, + spec: SchemaPath, + config: Optional[Config] = None, + ): + if not isinstance(spec, SchemaPath): + raise TypeError("'spec' argument is not type of SchemaPath") + + self.spec = spec + self.config = config or Config() + + self.check_spec() + + @classmethod + def from_dict( + cls, data: Schema, config: Optional[Config] = None + ) -> "OpenAPI": + sp = SchemaPath.from_dict(data) + return cls(sp, config=config) + + @classmethod + def from_path( + cls, path: Path, config: Optional[Config] = None + ) -> "OpenAPI": + sp = SchemaPath.from_path(path) + return cls(sp, config=config) + + @classmethod + def from_file_path( + cls, file_path: str, config: Optional[Config] = None + ) -> "OpenAPI": + sp = SchemaPath.from_file_path(file_path) + return cls(sp, config=config) + + @classmethod + def from_file( + cls, fileobj: SupportsRead, config: Optional[Config] = None + ) -> "OpenAPI": + sp = SchemaPath.from_file(fileobj) + return cls(sp, config=config) + + def _get_version(self) -> SpecVersion: + try: + return get_spec_version(self.spec.contents()) + # backward compatibility + except OpenAPIVersionNotFound: + raise SpecError("Spec schema version not detected") + + def check_spec(self) -> None: + if self.config.spec_validator_cls is None: + return + + cls = None + if self.config.spec_validator_cls is not _UNSET: + cls = self.config.spec_validator_cls + + try: + validate( + self.spec.contents(), + base_uri=self.config.spec_base_uri, + cls=cls, + ) + except ValidatorDetectError: + raise SpecError("spec not detected") + + @property + def version(self) -> SpecVersion: + return self._get_version() + + @property + def request_validator_cls(self) -> Optional[RequestValidatorType]: + if not isinstance(self.config.request_validator_cls, Unset): + return self.config.request_validator_cls + return REQUEST_VALIDATORS.get(self.version) + + @property + def response_validator_cls(self) -> Optional[ResponseValidatorType]: + if not isinstance(self.config.response_validator_cls, Unset): + return self.config.response_validator_cls + return RESPONSE_VALIDATORS.get(self.version) + + @property + def webhook_request_validator_cls( + self, + ) -> Optional[WebhookRequestValidatorType]: + if not isinstance(self.config.webhook_request_validator_cls, Unset): + return self.config.webhook_request_validator_cls + return WEBHOOK_REQUEST_VALIDATORS.get(self.version) + + @property + def webhook_response_validator_cls( + self, + ) -> Optional[WebhookResponseValidatorType]: + if not isinstance(self.config.webhook_response_validator_cls, Unset): + return self.config.webhook_response_validator_cls + return WEBHOOK_RESPONSE_VALIDATORS.get(self.version) + + @property + def request_unmarshaller_cls(self) -> Optional[RequestUnmarshallerType]: + if not isinstance(self.config.request_unmarshaller_cls, Unset): + return self.config.request_unmarshaller_cls + return REQUEST_UNMARSHALLERS.get(self.version) + + @property + def response_unmarshaller_cls(self) -> Optional[ResponseUnmarshallerType]: + if not isinstance(self.config.response_unmarshaller_cls, Unset): + return self.config.response_unmarshaller_cls + return RESPONSE_UNMARSHALLERS.get(self.version) + + @property + def webhook_request_unmarshaller_cls( + self, + ) -> Optional[WebhookRequestUnmarshallerType]: + if not isinstance(self.config.webhook_request_unmarshaller_cls, Unset): + return self.config.webhook_request_unmarshaller_cls + return WEBHOOK_REQUEST_UNMARSHALLERS.get(self.version) + + @property + def webhook_response_unmarshaller_cls( + self, + ) -> Optional[WebhookResponseUnmarshallerType]: + if not isinstance( + self.config.webhook_response_unmarshaller_cls, Unset + ): + return self.config.webhook_response_unmarshaller_cls + return WEBHOOK_RESPONSE_UNMARSHALLERS.get(self.version) + + @property + def request_validator(self) -> RequestValidator: + if self.request_validator_cls is None: + raise SpecError("Validator class not found") + return self.request_validator_cls( + self.spec, base_url=self.config.server_base_url + ) + + @property + def response_validator(self) -> ResponseValidator: + if self.response_validator_cls is None: + raise SpecError("Validator class not found") + return self.response_validator_cls( + self.spec, base_url=self.config.server_base_url + ) + + @property + def webhook_request_validator(self) -> WebhookRequestValidator: + if self.webhook_request_validator_cls is None: + raise SpecError("Validator class not found") + return self.webhook_request_validator_cls( + self.spec, base_url=self.config.server_base_url + ) + + @property + def webhook_response_validator(self) -> WebhookResponseValidator: + if self.webhook_response_validator_cls is None: + raise SpecError("Validator class not found") + return self.webhook_response_validator_cls( + self.spec, base_url=self.config.server_base_url + ) + + @property + def request_unmarshaller(self) -> RequestUnmarshaller: + if self.request_unmarshaller_cls is None: + raise SpecError("Unmarshaller class not found") + return self.request_unmarshaller_cls( + self.spec, base_url=self.config.server_base_url + ) + + @property + def response_unmarshaller(self) -> ResponseUnmarshaller: + if self.response_unmarshaller_cls is None: + raise SpecError("Unmarshaller class not found") + return self.response_unmarshaller_cls( + self.spec, base_url=self.config.server_base_url + ) + + @property + def webhook_request_unmarshaller(self) -> WebhookRequestUnmarshaller: + if self.webhook_request_unmarshaller_cls is None: + raise SpecError("Unmarshaller class not found") + return self.webhook_request_unmarshaller_cls( + self.spec, base_url=self.config.server_base_url + ) + + @property + def webhook_response_unmarshaller(self) -> WebhookResponseUnmarshaller: + if self.webhook_response_unmarshaller_cls is None: + raise SpecError("Unmarshaller class not found") + return self.webhook_response_unmarshaller_cls( + self.spec, base_url=self.config.server_base_url + ) + + def validate_request(self, request: AnyRequest) -> None: + if isinstance(request, WebhookRequest): + self.validate_webhook_request(request) + else: + self.validate_apicall_request(request) + + def validate_response( + self, request: AnyRequest, response: Response + ) -> None: + if isinstance(request, WebhookRequest): + self.validate_webhook_response(request, response) + else: + self.validate_apicall_response(request, response) + + def validate_apicall_request(self, request: Request) -> None: + if not isinstance(request, Request): + raise TypeError("'request' argument is not type of Request") + self.request_validator.validate(request) + + def validate_apicall_response( + self, request: Request, response: Response + ) -> None: + if not isinstance(request, Request): + raise TypeError("'request' argument is not type of Request") + if not isinstance(response, Response): + raise TypeError("'response' argument is not type of Response") + self.response_validator.validate(request, response) + + def validate_webhook_request(self, request: WebhookRequest) -> None: + if not isinstance(request, WebhookRequest): + raise TypeError("'request' argument is not type of WebhookRequest") + self.webhook_request_validator.validate(request) + + def validate_webhook_response( + self, request: WebhookRequest, response: Response + ) -> None: + if not isinstance(request, WebhookRequest): + raise TypeError("'request' argument is not type of WebhookRequest") + if not isinstance(response, Response): + raise TypeError("'response' argument is not type of Response") + self.webhook_response_validator.validate(request, response) + + def unmarshal_request(self, request: AnyRequest) -> RequestUnmarshalResult: + if isinstance(request, WebhookRequest): + return self.unmarshal_webhook_request(request) + else: + return self.unmarshal_apicall_request(request) + + def unmarshal_response( + self, request: AnyRequest, response: Response + ) -> ResponseUnmarshalResult: + if isinstance(request, WebhookRequest): + return self.unmarshal_webhook_response(request, response) + else: + return self.unmarshal_apicall_response(request, response) + + def unmarshal_apicall_request( + self, request: Request + ) -> RequestUnmarshalResult: + if not isinstance(request, Request): + raise TypeError("'request' argument is not type of Request") + return self.request_unmarshaller.unmarshal(request) + + def unmarshal_apicall_response( + self, request: Request, response: Response + ) -> ResponseUnmarshalResult: + if not isinstance(request, Request): + raise TypeError("'request' argument is not type of Request") + if not isinstance(response, Response): + raise TypeError("'response' argument is not type of Response") + return self.response_unmarshaller.unmarshal(request, response) + + def unmarshal_webhook_request( + self, request: WebhookRequest + ) -> RequestUnmarshalResult: + if not isinstance(request, WebhookRequest): + raise TypeError("'request' argument is not type of WebhookRequest") + return self.webhook_request_unmarshaller.unmarshal(request) + + def unmarshal_webhook_response( + self, request: WebhookRequest, response: Response + ) -> ResponseUnmarshalResult: + if not isinstance(request, WebhookRequest): + raise TypeError("'request' argument is not type of WebhookRequest") + if not isinstance(response, Response): + raise TypeError("'response' argument is not type of Response") + return self.webhook_response_unmarshaller.unmarshal(request, response) diff --git a/openapi_core/configurations.py b/openapi_core/configurations.py new file mode 100644 index 00000000..a348de21 --- /dev/null +++ b/openapi_core/configurations.py @@ -0,0 +1,126 @@ +import warnings +from dataclasses import dataclass +from dataclasses import field +from functools import lru_cache +from pathlib import Path +from typing import Any +from typing import Hashable +from typing import Mapping +from typing import Optional +from typing import Type +from typing import TypeVar +from typing import Union + +from jsonschema._utils import Unset +from jsonschema.validators import _UNSET +from jsonschema_path import SchemaPath +from jsonschema_path.handlers.protocols import SupportsRead +from jsonschema_path.typing import Schema +from openapi_spec_validator import validate +from openapi_spec_validator.validation.types import SpecValidatorType +from openapi_spec_validator.versions.datatypes import SpecVersion +from openapi_spec_validator.versions.exceptions import OpenAPIVersionNotFound +from openapi_spec_validator.versions.shortcuts import get_spec_version + +from openapi_core.exceptions import SpecError +from openapi_core.protocols import Request +from openapi_core.protocols import Response +from openapi_core.protocols import WebhookRequest +from openapi_core.types import AnyRequest +from openapi_core.unmarshalling.configurations import UnmarshallerConfig +from openapi_core.unmarshalling.request import ( + UNMARSHALLERS as REQUEST_UNMARSHALLERS, +) +from openapi_core.unmarshalling.request import ( + WEBHOOK_UNMARSHALLERS as WEBHOOK_REQUEST_UNMARSHALLERS, +) +from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult +from openapi_core.unmarshalling.request.protocols import RequestUnmarshaller +from openapi_core.unmarshalling.request.protocols import ( + WebhookRequestUnmarshaller, +) +from openapi_core.unmarshalling.request.types import RequestUnmarshallerType +from openapi_core.unmarshalling.request.types import ( + WebhookRequestUnmarshallerType, +) +from openapi_core.unmarshalling.response import ( + UNMARSHALLERS as RESPONSE_UNMARSHALLERS, +) +from openapi_core.unmarshalling.response import ( + WEBHOOK_UNMARSHALLERS as WEBHOOK_RESPONSE_UNMARSHALLERS, +) +from openapi_core.unmarshalling.response.datatypes import ( + ResponseUnmarshalResult, +) +from openapi_core.unmarshalling.response.protocols import ResponseUnmarshaller +from openapi_core.unmarshalling.response.protocols import ( + WebhookResponseUnmarshaller, +) +from openapi_core.unmarshalling.response.types import ResponseUnmarshallerType +from openapi_core.unmarshalling.response.types import ( + WebhookResponseUnmarshallerType, +) +from openapi_core.validation.request import VALIDATORS as REQUEST_VALIDATORS +from openapi_core.validation.request import ( + WEBHOOK_VALIDATORS as WEBHOOK_REQUEST_VALIDATORS, +) +from openapi_core.validation.request.protocols import RequestValidator +from openapi_core.validation.request.protocols import WebhookRequestValidator +from openapi_core.validation.request.types import RequestValidatorType +from openapi_core.validation.request.types import WebhookRequestValidatorType +from openapi_core.validation.response import VALIDATORS as RESPONSE_VALIDATORS +from openapi_core.validation.response import ( + WEBHOOK_VALIDATORS as WEBHOOK_RESPONSE_VALIDATORS, +) +from openapi_core.validation.response.protocols import ResponseValidator +from openapi_core.validation.response.protocols import WebhookResponseValidator +from openapi_core.validation.response.types import ResponseValidatorType +from openapi_core.validation.response.types import WebhookResponseValidatorType + + +@dataclass +class Config(UnmarshallerConfig): + """OpenAPI configuration dataclass. + + Attributes: + spec_validator_cls + Specifincation validator class. + spec_base_uri + Specification base uri. + request_validator_cls + Request validator class. + response_validator_cls + Response validator class. + webhook_request_validator_cls + Webhook request validator class. + webhook_response_validator_cls + Webhook response validator class. + request_unmarshaller_cls + Request unmarshaller class. + response_unmarshaller_cls + Response unmarshaller class. + webhook_request_unmarshaller_cls + Webhook request unmarshaller class. + webhook_response_unmarshaller_cls + Webhook response unmarshaller class. + """ + + spec_validator_cls: Union[SpecValidatorType, Unset] = _UNSET + spec_base_uri: str = "" + + request_validator_cls: Union[RequestValidatorType, Unset] = _UNSET + response_validator_cls: Union[ResponseValidatorType, Unset] = _UNSET + webhook_request_validator_cls: Union[ + WebhookRequestValidatorType, Unset + ] = _UNSET + webhook_response_validator_cls: Union[ + WebhookResponseValidatorType, Unset + ] = _UNSET + request_unmarshaller_cls: Union[RequestUnmarshallerType, Unset] = _UNSET + response_unmarshaller_cls: Union[ResponseUnmarshallerType, Unset] = _UNSET + webhook_request_unmarshaller_cls: Union[ + WebhookRequestUnmarshallerType, Unset + ] = _UNSET + webhook_response_unmarshaller_cls: Union[ + WebhookResponseUnmarshallerType, Unset + ] = _UNSET diff --git a/openapi_core/contrib/django/integrations.py b/openapi_core/contrib/django/integrations.py new file mode 100644 index 00000000..520aa7a6 --- /dev/null +++ b/openapi_core/contrib/django/integrations.py @@ -0,0 +1,36 @@ +from django.http.request import HttpRequest +from django.http.response import HttpResponse + +from openapi_core.contrib.django.requests import DjangoOpenAPIRequest +from openapi_core.contrib.django.responses import DjangoOpenAPIResponse +from openapi_core.unmarshalling.processors import UnmarshallingProcessor +from openapi_core.unmarshalling.typing import ErrorsHandlerCallable + + +class DjangoIntegration(UnmarshallingProcessor[HttpRequest, HttpResponse]): + request_cls = DjangoOpenAPIRequest + response_cls = DjangoOpenAPIResponse + + def get_openapi_request( + self, request: HttpRequest + ) -> DjangoOpenAPIRequest: + return self.request_cls(request) + + def get_openapi_response( + self, response: HttpResponse + ) -> DjangoOpenAPIResponse: + assert self.response_cls is not None + return self.response_cls(response) + + def should_validate_response(self) -> bool: + return self.response_cls is not None + + def handle_response( + self, + request: HttpRequest, + response: HttpResponse, + errors_handler: ErrorsHandlerCallable[HttpResponse], + ) -> HttpResponse: + if not self.should_validate_response(): + return response + return super().handle_response(request, response, errors_handler) diff --git a/openapi_core/contrib/django/middlewares.py b/openapi_core/contrib/django/middlewares.py index db87751f..aa410c57 100644 --- a/openapi_core/contrib/django/middlewares.py +++ b/openapi_core/contrib/django/middlewares.py @@ -1,4 +1,5 @@ """OpenAPI core contrib django middlewares module""" +import warnings from typing import Callable from django.conf import settings @@ -6,33 +7,42 @@ from django.http.request import HttpRequest from django.http.response import HttpResponse +from openapi_core import OpenAPI from openapi_core.contrib.django.handlers import DjangoOpenAPIErrorsHandler from openapi_core.contrib.django.handlers import ( DjangoOpenAPIValidRequestHandler, ) +from openapi_core.contrib.django.integrations import DjangoIntegration from openapi_core.contrib.django.requests import DjangoOpenAPIRequest from openapi_core.contrib.django.responses import DjangoOpenAPIResponse from openapi_core.unmarshalling.processors import UnmarshallingProcessor -class DjangoOpenAPIMiddleware( - UnmarshallingProcessor[HttpRequest, HttpResponse] -): - request_cls = DjangoOpenAPIRequest - response_cls = DjangoOpenAPIResponse +class DjangoOpenAPIMiddleware(DjangoIntegration): valid_request_handler_cls = DjangoOpenAPIValidRequestHandler errors_handler = DjangoOpenAPIErrorsHandler() def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]): self.get_response = get_response - if not hasattr(settings, "OPENAPI_SPEC"): - raise ImproperlyConfigured("OPENAPI_SPEC not defined in settings") - if hasattr(settings, "OPENAPI_RESPONSE_CLS"): self.response_cls = settings.OPENAPI_RESPONSE_CLS - super().__init__(settings.OPENAPI_SPEC) + if not hasattr(settings, "OPENAPI"): + if not hasattr(settings, "OPENAPI_SPEC"): + raise ImproperlyConfigured( + "OPENAPI_SPEC not defined in settings" + ) + else: + warnings.warn( + "OPENAPI_SPEC is deprecated. Use OPENAPI instead.", + DeprecationWarning, + ) + openapi = OpenAPI(settings.OPENAPI_SPEC) + else: + openapi = settings.OPENAPI + + super().__init__(openapi) def __call__(self, request: HttpRequest) -> HttpResponse: valid_request_handler = self.valid_request_handler_cls( @@ -43,17 +53,3 @@ def __call__(self, request: HttpRequest) -> HttpResponse: ) return self.handle_response(request, response, self.errors_handler) - - def _get_openapi_request( - self, request: HttpRequest - ) -> DjangoOpenAPIRequest: - return self.request_cls(request) - - def _get_openapi_response( - self, response: HttpResponse - ) -> DjangoOpenAPIResponse: - assert self.response_cls is not None - return self.response_cls(response) - - def _validate_response(self) -> bool: - return self.response_cls is not None diff --git a/openapi_core/contrib/falcon/integrations.py b/openapi_core/contrib/falcon/integrations.py new file mode 100644 index 00000000..8c3fa544 --- /dev/null +++ b/openapi_core/contrib/falcon/integrations.py @@ -0,0 +1,34 @@ +from falcon.request import Request +from falcon.response import Response + +from openapi_core.contrib.falcon.requests import FalconOpenAPIRequest +from openapi_core.contrib.falcon.responses import FalconOpenAPIResponse +from openapi_core.unmarshalling.processors import UnmarshallingProcessor +from openapi_core.unmarshalling.typing import ErrorsHandlerCallable + + +class FalconIntegration(UnmarshallingProcessor[Request, Response]): + request_cls = FalconOpenAPIRequest + response_cls = FalconOpenAPIResponse + + def get_openapi_request(self, request: Request) -> FalconOpenAPIRequest: + return self.request_cls(request) + + def get_openapi_response( + self, response: Response + ) -> FalconOpenAPIResponse: + assert self.response_cls is not None + return self.response_cls(response) + + def should_validate_response(self) -> bool: + return self.response_cls is not None + + def handle_response( + self, + request: Request, + response: Response, + errors_handler: ErrorsHandlerCallable[Response], + ) -> Response: + if not self.should_validate_response(): + return response + return super().handle_response(request, response, errors_handler) diff --git a/openapi_core/contrib/falcon/middlewares.py b/openapi_core/contrib/falcon/middlewares.py index 4fc71661..29b8bfba 100644 --- a/openapi_core/contrib/falcon/middlewares.py +++ b/openapi_core/contrib/falcon/middlewares.py @@ -2,15 +2,21 @@ from typing import Any from typing import Optional from typing import Type +from typing import Union from falcon.request import Request from falcon.response import Response +from jsonschema._utils import Unset +from jsonschema.validators import _UNSET from jsonschema_path import SchemaPath +from openapi_core import Config +from openapi_core import OpenAPI from openapi_core.contrib.falcon.handlers import FalconOpenAPIErrorsHandler from openapi_core.contrib.falcon.handlers import ( FalconOpenAPIValidRequestHandler, ) +from openapi_core.contrib.falcon.integrations import FalconIntegration from openapi_core.contrib.falcon.requests import FalconOpenAPIRequest from openapi_core.contrib.falcon.responses import FalconOpenAPIResponse from openapi_core.unmarshalling.processors import UnmarshallingProcessor @@ -18,9 +24,7 @@ from openapi_core.unmarshalling.response.types import ResponseUnmarshallerType -class FalconOpenAPIMiddleware(UnmarshallingProcessor[Request, Response]): - request_cls = FalconOpenAPIRequest - response_cls = FalconOpenAPIResponse +class FalconOpenAPIMiddleware(FalconIntegration): valid_request_handler_cls = FalconOpenAPIValidRequestHandler errors_handler_cls: Type[ FalconOpenAPIErrorsHandler @@ -28,9 +32,7 @@ class FalconOpenAPIMiddleware(UnmarshallingProcessor[Request, Response]): def __init__( self, - spec: SchemaPath, - request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None, - response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None, + openapi: OpenAPI, request_cls: Type[FalconOpenAPIRequest] = FalconOpenAPIRequest, response_cls: Type[FalconOpenAPIResponse] = FalconOpenAPIResponse, errors_handler_cls: Type[ @@ -38,12 +40,7 @@ def __init__( ] = FalconOpenAPIErrorsHandler, **unmarshaller_kwargs: Any, ): - super().__init__( - spec, - request_unmarshaller_cls=request_unmarshaller_cls, - response_unmarshaller_cls=response_unmarshaller_cls, - **unmarshaller_kwargs, - ) + super().__init__(openapi) self.request_cls = request_cls or self.request_cls self.response_cls = response_cls or self.response_cls self.errors_handler_cls = errors_handler_cls or self.errors_handler_cls @@ -52,8 +49,12 @@ def __init__( def from_spec( cls, spec: SchemaPath, - request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None, - response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None, + request_unmarshaller_cls: Union[ + RequestUnmarshallerType, Unset + ] = _UNSET, + response_unmarshaller_cls: Union[ + ResponseUnmarshallerType, Unset + ] = _UNSET, request_cls: Type[FalconOpenAPIRequest] = FalconOpenAPIRequest, response_cls: Type[FalconOpenAPIResponse] = FalconOpenAPIResponse, errors_handler_cls: Type[ @@ -61,8 +62,13 @@ def from_spec( ] = FalconOpenAPIErrorsHandler, **unmarshaller_kwargs: Any, ) -> "FalconOpenAPIMiddleware": + config = Config( + request_unmarshaller_cls=request_unmarshaller_cls, + response_unmarshaller_cls=response_unmarshaller_cls, + ) + openapi = OpenAPI(spec, config=config) return cls( - spec, + openapi, request_unmarshaller_cls=request_unmarshaller_cls, response_unmarshaller_cls=response_unmarshaller_cls, request_cls=request_cls, @@ -81,15 +87,3 @@ def process_response( ) -> None: errors_handler = self.errors_handler_cls(req, resp) self.handle_response(req, resp, errors_handler) - - def _get_openapi_request(self, request: Request) -> FalconOpenAPIRequest: - return self.request_cls(request) - - def _get_openapi_response( - self, response: Response - ) -> FalconOpenAPIResponse: - assert self.response_cls is not None - return self.response_cls(response) - - def _validate_response(self) -> bool: - return self.response_cls is not None diff --git a/openapi_core/contrib/flask/decorators.py b/openapi_core/contrib/flask/decorators.py index a379d136..497b60d8 100644 --- a/openapi_core/contrib/flask/decorators.py +++ b/openapi_core/contrib/flask/decorators.py @@ -10,8 +10,10 @@ from flask.wrappers import Response from jsonschema_path import SchemaPath +from openapi_core import OpenAPI from openapi_core.contrib.flask.handlers import FlaskOpenAPIErrorsHandler from openapi_core.contrib.flask.handlers import FlaskOpenAPIValidRequestHandler +from openapi_core.contrib.flask.integrations import FlaskIntegration from openapi_core.contrib.flask.providers import FlaskRequestProvider from openapi_core.contrib.flask.requests import FlaskOpenAPIRequest from openapi_core.contrib.flask.responses import FlaskOpenAPIResponse @@ -20,7 +22,7 @@ from openapi_core.unmarshalling.response.types import ResponseUnmarshallerType -class FlaskOpenAPIViewDecorator(UnmarshallingProcessor[Request, Response]): +class FlaskOpenAPIViewDecorator(FlaskIntegration): valid_request_handler_cls = FlaskOpenAPIValidRequestHandler errors_handler_cls: Type[ FlaskOpenAPIErrorsHandler @@ -28,25 +30,15 @@ class FlaskOpenAPIViewDecorator(UnmarshallingProcessor[Request, Response]): def __init__( self, - spec: SchemaPath, - request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None, - response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None, + openapi: OpenAPI, request_cls: Type[FlaskOpenAPIRequest] = FlaskOpenAPIRequest, - response_cls: Optional[ - Type[FlaskOpenAPIResponse] - ] = FlaskOpenAPIResponse, + response_cls: Type[FlaskOpenAPIResponse] = FlaskOpenAPIResponse, request_provider: Type[FlaskRequestProvider] = FlaskRequestProvider, errors_handler_cls: Type[ FlaskOpenAPIErrorsHandler ] = FlaskOpenAPIErrorsHandler, - **unmarshaller_kwargs: Any, ): - super().__init__( - spec, - request_unmarshaller_cls=request_unmarshaller_cls, - response_unmarshaller_cls=response_unmarshaller_cls, - **unmarshaller_kwargs, - ) + super().__init__(openapi) self.request_cls = request_cls self.response_cls = response_cls self.request_provider = request_provider @@ -55,7 +47,7 @@ def __init__( def __call__(self, view: Callable[..., Any]) -> Callable[..., Any]: @wraps(view) def decorated(*args: Any, **kwargs: Any) -> Response: - request = self._get_request() + request = self.get_request() valid_request_handler = self.valid_request_handler_cls( request, view, *args, **kwargs ) @@ -67,42 +59,25 @@ def decorated(*args: Any, **kwargs: Any) -> Response: return decorated - def _get_request(self) -> Request: + def get_request(self) -> Request: return request - def _get_openapi_request(self, request: Request) -> FlaskOpenAPIRequest: - return self.request_cls(request) - - def _get_openapi_response( - self, response: Response - ) -> FlaskOpenAPIResponse: - assert self.response_cls is not None - return self.response_cls(response) - - def _validate_response(self) -> bool: - return self.response_cls is not None - @classmethod def from_spec( cls, spec: SchemaPath, - request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None, - response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None, request_cls: Type[FlaskOpenAPIRequest] = FlaskOpenAPIRequest, response_cls: Type[FlaskOpenAPIResponse] = FlaskOpenAPIResponse, request_provider: Type[FlaskRequestProvider] = FlaskRequestProvider, errors_handler_cls: Type[ FlaskOpenAPIErrorsHandler ] = FlaskOpenAPIErrorsHandler, - **unmarshaller_kwargs: Any, ) -> "FlaskOpenAPIViewDecorator": + openapi = OpenAPI(spec) return cls( - spec, - request_unmarshaller_cls=request_unmarshaller_cls, - response_unmarshaller_cls=response_unmarshaller_cls, + openapi, request_cls=request_cls, response_cls=response_cls, request_provider=request_provider, errors_handler_cls=errors_handler_cls, - **unmarshaller_kwargs, ) diff --git a/openapi_core/contrib/flask/integrations.py b/openapi_core/contrib/flask/integrations.py new file mode 100644 index 00000000..49f7009e --- /dev/null +++ b/openapi_core/contrib/flask/integrations.py @@ -0,0 +1,32 @@ +from flask.wrappers import Request +from flask.wrappers import Response + +from openapi_core.contrib.flask.requests import FlaskOpenAPIRequest +from openapi_core.contrib.flask.responses import FlaskOpenAPIResponse +from openapi_core.unmarshalling.processors import UnmarshallingProcessor +from openapi_core.unmarshalling.typing import ErrorsHandlerCallable + + +class FlaskIntegration(UnmarshallingProcessor[Request, Response]): + request_cls = FlaskOpenAPIRequest + response_cls = FlaskOpenAPIResponse + + def get_openapi_request(self, request: Request) -> FlaskOpenAPIRequest: + return self.request_cls(request) + + def get_openapi_response(self, response: Response) -> FlaskOpenAPIResponse: + assert self.response_cls is not None + return self.response_cls(response) + + def should_validate_response(self) -> bool: + return self.response_cls is not None + + def handle_response( + self, + request: Request, + response: Response, + errors_handler: ErrorsHandlerCallable[Response], + ) -> Response: + if not self.should_validate_response(): + return response + return super().handle_response(request, response, errors_handler) diff --git a/openapi_core/contrib/flask/views.py b/openapi_core/contrib/flask/views.py index 5fc233b4..0f72a018 100644 --- a/openapi_core/contrib/flask/views.py +++ b/openapi_core/contrib/flask/views.py @@ -2,8 +2,8 @@ from typing import Any from flask.views import MethodView -from jsonschema_path import SchemaPath +from openapi_core import OpenAPI from openapi_core.contrib.flask.decorators import FlaskOpenAPIViewDecorator from openapi_core.contrib.flask.handlers import FlaskOpenAPIErrorsHandler @@ -13,13 +13,12 @@ class FlaskOpenAPIView(MethodView): openapi_errors_handler = FlaskOpenAPIErrorsHandler - def __init__(self, spec: SchemaPath, **unmarshaller_kwargs: Any): + def __init__(self, openapi: OpenAPI): super().__init__() self.decorator = FlaskOpenAPIViewDecorator( - spec, + openapi, errors_handler_cls=self.openapi_errors_handler, - **unmarshaller_kwargs, ) def dispatch_request(self, *args: Any, **kwargs: Any) -> Any: diff --git a/openapi_core/contrib/starlette/integrations.py b/openapi_core/contrib/starlette/integrations.py new file mode 100644 index 00000000..3f30c969 --- /dev/null +++ b/openapi_core/contrib/starlette/integrations.py @@ -0,0 +1,54 @@ +from typing import Callable + +from aioitertools.builtins import list as alist +from aioitertools.itertools import tee as atee +from starlette.requests import Request +from starlette.responses import Response +from starlette.responses import StreamingResponse + +from openapi_core.contrib.starlette.requests import StarletteOpenAPIRequest +from openapi_core.contrib.starlette.responses import StarletteOpenAPIResponse +from openapi_core.unmarshalling.processors import AsyncUnmarshallingProcessor +from openapi_core.unmarshalling.typing import ErrorsHandlerCallable + + +class StarletteIntegration(AsyncUnmarshallingProcessor[Request, Response]): + request_cls = StarletteOpenAPIRequest + response_cls = StarletteOpenAPIResponse + + async def get_openapi_request( + self, request: Request + ) -> StarletteOpenAPIRequest: + body = await request.body() + return self.request_cls(request, body) + + async def get_openapi_response( + self, response: Response + ) -> StarletteOpenAPIResponse: + assert self.response_cls is not None + data = None + if isinstance(response, StreamingResponse): + body_iter1, body_iter2 = atee(response.body_iterator) + response.body_iterator = body_iter2 + data = b"".join( + [ + chunk.encode(response.charset) + if not isinstance(chunk, bytes) + else chunk + async for chunk in body_iter1 + ] + ) + return self.response_cls(response, data=data) + + def should_validate_response(self) -> bool: + return self.response_cls is not None + + async def handle_response( + self, + request: Request, + response: Response, + errors_handler: ErrorsHandlerCallable[Response], + ) -> Response: + if not self.should_validate_response(): + return response + return await super().handle_response(request, response, errors_handler) diff --git a/openapi_core/contrib/starlette/middlewares.py b/openapi_core/contrib/starlette/middlewares.py index f9bfb779..9bea9066 100644 --- a/openapi_core/contrib/starlette/middlewares.py +++ b/openapi_core/contrib/starlette/middlewares.py @@ -1,9 +1,4 @@ """OpenAPI core contrib starlette middlewares module""" -from typing import Callable - -from aioitertools.builtins import list as alist -from aioitertools.itertools import tee as atee -from jsonschema_path import SchemaPath from starlette.middleware.base import BaseHTTPMiddleware from starlette.middleware.base import RequestResponseEndpoint from starlette.requests import Request @@ -11,28 +6,26 @@ from starlette.responses import StreamingResponse from starlette.types import ASGIApp +from openapi_core import OpenAPI from openapi_core.contrib.starlette.handlers import ( StarletteOpenAPIErrorsHandler, ) from openapi_core.contrib.starlette.handlers import ( StarletteOpenAPIValidRequestHandler, ) +from openapi_core.contrib.starlette.integrations import StarletteIntegration from openapi_core.contrib.starlette.requests import StarletteOpenAPIRequest from openapi_core.contrib.starlette.responses import StarletteOpenAPIResponse from openapi_core.unmarshalling.processors import AsyncUnmarshallingProcessor -class StarletteOpenAPIMiddleware( - BaseHTTPMiddleware, AsyncUnmarshallingProcessor[Request, Response] -): - request_cls = StarletteOpenAPIRequest - response_cls = StarletteOpenAPIResponse +class StarletteOpenAPIMiddleware(StarletteIntegration, BaseHTTPMiddleware): valid_request_handler_cls = StarletteOpenAPIValidRequestHandler errors_handler = StarletteOpenAPIErrorsHandler() - def __init__(self, app: ASGIApp, spec: SchemaPath): + def __init__(self, app: ASGIApp, openapi: OpenAPI): + super().__init__(openapi) BaseHTTPMiddleware.__init__(self, app) - AsyncUnmarshallingProcessor.__init__(self, spec) async def dispatch( self, request: Request, call_next: RequestResponseEndpoint @@ -46,30 +39,3 @@ async def dispatch( return await self.handle_response( request, response, self.errors_handler ) - - async def _get_openapi_request( - self, request: Request - ) -> StarletteOpenAPIRequest: - body = await request.body() - return self.request_cls(request, body) - - async def _get_openapi_response( - self, response: Response - ) -> StarletteOpenAPIResponse: - assert self.response_cls is not None - data = None - if isinstance(response, StreamingResponse): - body_iter1, body_iter2 = atee(response.body_iterator) - response.body_iterator = body_iter2 - data = b"".join( - [ - chunk.encode(response.charset) - if not isinstance(chunk, bytes) - else chunk - async for chunk in body_iter1 - ] - ) - return self.response_cls(response, data=data) - - def _validate_response(self) -> bool: - return self.response_cls is not None diff --git a/openapi_core/protocols.py b/openapi_core/protocols.py index 82bf1532..338225c9 100644 --- a/openapi_core/protocols.py +++ b/openapi_core/protocols.py @@ -6,6 +6,8 @@ from typing import runtime_checkable from openapi_core.datatypes import RequestParameters +from openapi_core.typing import RequestType +from openapi_core.typing import ResponseType @runtime_checkable diff --git a/openapi_core/shortcuts.py b/openapi_core/shortcuts.py index 00717ffa..34a149e8 100644 --- a/openapi_core/shortcuts.py +++ b/openapi_core/shortcuts.py @@ -4,18 +4,20 @@ from typing import Optional from typing import Union +from jsonschema.validators import _UNSET from jsonschema_path import SchemaPath from openapi_spec_validator.versions import consts as versions from openapi_spec_validator.versions.datatypes import SpecVersion from openapi_spec_validator.versions.exceptions import OpenAPIVersionNotFound from openapi_spec_validator.versions.shortcuts import get_spec_version +from openapi_core.app import OpenAPI +from openapi_core.configurations import Config from openapi_core.exceptions import SpecError from openapi_core.protocols import Request from openapi_core.protocols import Response from openapi_core.protocols import WebhookRequest -from openapi_core.spec import Spec -from openapi_core.types import SpecClasses +from openapi_core.types import AnyRequest from openapi_core.unmarshalling.request import V30RequestUnmarshaller from openapi_core.unmarshalling.request import V31RequestUnmarshaller from openapi_core.unmarshalling.request import V31WebhookRequestUnmarshaller @@ -63,43 +65,6 @@ from openapi_core.validation.response.types import ResponseValidatorType from openapi_core.validation.response.types import WebhookResponseValidatorType -AnyRequest = Union[Request, WebhookRequest] - -SPEC2CLASSES: Dict[SpecVersion, SpecClasses] = { - versions.OPENAPIV30: SpecClasses( - V30RequestValidator, - V30ResponseValidator, - None, - None, - V30RequestUnmarshaller, - V30ResponseUnmarshaller, - None, - None, - ), - versions.OPENAPIV31: SpecClasses( - V31RequestValidator, - V31ResponseValidator, - V31WebhookRequestValidator, - V31WebhookResponseValidator, - V31RequestUnmarshaller, - V31ResponseUnmarshaller, - V31WebhookRequestUnmarshaller, - V31WebhookResponseUnmarshaller, - ), -} - - -def get_classes(spec: SchemaPath) -> SpecClasses: - try: - spec_version = get_spec_version(spec.contents()) - # backward compatibility - except OpenAPIVersionNotFound: - raise SpecError("Spec schema version not detected") - try: - return SPEC2CLASSES[spec_version] - except KeyError: - raise SpecError("Spec schema version not supported") - def unmarshal_apicall_request( request: Request, @@ -108,18 +73,12 @@ def unmarshal_apicall_request( cls: Optional[RequestUnmarshallerType] = None, **unmarshaller_kwargs: Any, ) -> RequestUnmarshalResult: - if not isinstance(request, Request): - raise TypeError("'request' argument is not type of Request") - if not isinstance(spec, SchemaPath): - raise TypeError("'spec' argument is not type of SchemaPath") - if cls is None: - classes = get_classes(spec) - cls = classes.request_unmarshaller_cls - if not issubclass(cls, RequestUnmarshaller): - raise TypeError("'cls' argument is not type of RequestUnmarshaller") - v = cls(spec, base_url=base_url, **unmarshaller_kwargs) - v.check_spec(spec) - result = v.unmarshal(request) + config = Config( + server_base_url=base_url, + request_unmarshaller_cls=cls or _UNSET, + **unmarshaller_kwargs, + ) + result = OpenAPI(spec, config=config).unmarshal_apicall_request(request) result.raise_for_errors() return result @@ -131,22 +90,12 @@ def unmarshal_webhook_request( cls: Optional[WebhookRequestUnmarshallerType] = None, **unmarshaller_kwargs: Any, ) -> RequestUnmarshalResult: - if not isinstance(request, WebhookRequest): - raise TypeError("'request' argument is not type of WebhookRequest") - if not isinstance(spec, SchemaPath): - raise TypeError("'spec' argument is not type of SchemaPath") - if cls is None: - classes = get_classes(spec) - cls = classes.webhook_request_unmarshaller_cls - if cls is None: - raise SpecError("Unmarshaller class not found") - if not issubclass(cls, WebhookRequestUnmarshaller): - raise TypeError( - "'cls' argument is not type of WebhookRequestUnmarshaller" - ) - v = cls(spec, base_url=base_url, **unmarshaller_kwargs) - v.check_spec(spec) - result = v.unmarshal(request) + config = Config( + server_base_url=base_url, + webhook_request_unmarshaller_cls=cls or _UNSET, + **unmarshaller_kwargs, + ) + result = OpenAPI(spec, config=config).unmarshal_webhook_request(request) result.raise_for_errors() return result @@ -158,36 +107,15 @@ def unmarshal_request( cls: Optional[AnyRequestUnmarshallerType] = None, **unmarshaller_kwargs: Any, ) -> RequestUnmarshalResult: - if not isinstance(request, (Request, WebhookRequest)): - raise TypeError("'request' argument is not type of (Webhook)Request") - if not isinstance(spec, SchemaPath): - raise TypeError("'spec' argument is not type of SchemaPath") - if isinstance(request, WebhookRequest): - if cls is None or issubclass(cls, WebhookRequestUnmarshaller): - return unmarshal_webhook_request( - request, - spec, - base_url=base_url, - cls=cls, - **unmarshaller_kwargs, - ) - else: - raise TypeError( - "'cls' argument is not type of WebhookRequestUnmarshaller" - ) - else: - if cls is None or issubclass(cls, RequestUnmarshaller): - return unmarshal_apicall_request( - request, - spec, - base_url=base_url, - cls=cls, - **unmarshaller_kwargs, - ) - else: - raise TypeError( - "'cls' argument is not type of RequestUnmarshaller" - ) + config = Config( + server_base_url=base_url, + request_unmarshaller_cls=cls or _UNSET, + webhook_request_unmarshaller_cls=cls or _UNSET, + **unmarshaller_kwargs, + ) + result = OpenAPI(spec, config=config).unmarshal_request(request) + result.raise_for_errors() + return result def unmarshal_apicall_response( @@ -198,20 +126,14 @@ def unmarshal_apicall_response( cls: Optional[ResponseUnmarshallerType] = None, **unmarshaller_kwargs: Any, ) -> ResponseUnmarshalResult: - if not isinstance(request, Request): - raise TypeError("'request' argument is not type of Request") - if not isinstance(response, Response): - raise TypeError("'response' argument is not type of Response") - if not isinstance(spec, SchemaPath): - raise TypeError("'spec' argument is not type of SchemaPath") - if cls is None: - classes = get_classes(spec) - cls = classes.response_unmarshaller_cls - if not issubclass(cls, ResponseUnmarshaller): - raise TypeError("'cls' argument is not type of ResponseUnmarshaller") - v = cls(spec, base_url=base_url, **unmarshaller_kwargs) - v.check_spec(spec) - result = v.unmarshal(request, response) + config = Config( + server_base_url=base_url, + response_unmarshaller_cls=cls or _UNSET, + **unmarshaller_kwargs, + ) + result = OpenAPI(spec, config=config).unmarshal_apicall_response( + request, response + ) result.raise_for_errors() return result @@ -224,24 +146,14 @@ def unmarshal_webhook_response( cls: Optional[WebhookResponseUnmarshallerType] = None, **unmarshaller_kwargs: Any, ) -> ResponseUnmarshalResult: - if not isinstance(request, WebhookRequest): - raise TypeError("'request' argument is not type of WebhookRequest") - if not isinstance(response, Response): - raise TypeError("'response' argument is not type of Response") - if not isinstance(spec, SchemaPath): - raise TypeError("'spec' argument is not type of SchemaPath") - if cls is None: - classes = get_classes(spec) - cls = classes.webhook_response_unmarshaller_cls - if cls is None: - raise SpecError("Unmarshaller class not found") - if not issubclass(cls, WebhookResponseUnmarshaller): - raise TypeError( - "'cls' argument is not type of WebhookResponseUnmarshaller" - ) - v = cls(spec, base_url=base_url, **unmarshaller_kwargs) - v.check_spec(spec) - result = v.unmarshal(request, response) + config = Config( + server_base_url=base_url, + webhook_response_unmarshaller_cls=cls or _UNSET, + **unmarshaller_kwargs, + ) + result = OpenAPI(spec, config=config).unmarshal_webhook_response( + request, response + ) result.raise_for_errors() return result @@ -254,40 +166,15 @@ def unmarshal_response( cls: Optional[AnyResponseUnmarshallerType] = None, **unmarshaller_kwargs: Any, ) -> ResponseUnmarshalResult: - if not isinstance(request, (Request, WebhookRequest)): - raise TypeError("'request' argument is not type of (Webhook)Request") - if not isinstance(response, Response): - raise TypeError("'response' argument is not type of Response") - if not isinstance(spec, SchemaPath): - raise TypeError("'spec' argument is not type of SchemaPath") - if isinstance(request, WebhookRequest): - if cls is None or issubclass(cls, WebhookResponseUnmarshaller): - return unmarshal_webhook_response( - request, - response, - spec, - base_url=base_url, - cls=cls, - **unmarshaller_kwargs, - ) - else: - raise TypeError( - "'cls' argument is not type of WebhookResponseUnmarshaller" - ) - else: - if cls is None or issubclass(cls, ResponseUnmarshaller): - return unmarshal_apicall_response( - request, - response, - spec, - base_url=base_url, - cls=cls, - **unmarshaller_kwargs, - ) - else: - raise TypeError( - "'cls' argument is not type of ResponseUnmarshaller" - ) + config = Config( + server_base_url=base_url, + response_unmarshaller_cls=cls or _UNSET, + webhook_response_unmarshaller_cls=cls or _UNSET, + **unmarshaller_kwargs, + ) + result = OpenAPI(spec, config=config).unmarshal_response(request, response) + result.raise_for_errors() + return result def validate_request( @@ -296,83 +183,31 @@ def validate_request( base_url: Optional[str] = None, cls: Optional[AnyRequestValidatorType] = None, **validator_kwargs: Any, -) -> Optional[RequestUnmarshalResult]: - if not isinstance(request, (Request, WebhookRequest)): - raise TypeError("'request' argument is not type of (Webhook)Request") - if not isinstance(spec, SchemaPath): - raise TypeError("'spec' argument is not type of SchemaPath") - - if isinstance(request, WebhookRequest): - if cls is None or issubclass(cls, WebhookRequestValidator): - validate_webhook_request( - request, - spec, - base_url=base_url, - cls=cls, - **validator_kwargs, - ) - return None - else: - raise TypeError( - "'cls' argument is not type of WebhookRequestValidator" - ) - else: - if cls is None or issubclass(cls, RequestValidator): - validate_apicall_request( - request, - spec, - base_url=base_url, - cls=cls, - **validator_kwargs, - ) - return None - else: - raise TypeError("'cls' argument is not type of RequestValidator") +) -> None: + config = Config( + server_base_url=base_url, + request_validator_cls=cls or _UNSET, + webhook_request_validator_cls=cls or _UNSET, + **validator_kwargs, + ) + return OpenAPI(spec, config=config).validate_request(request) def validate_response( - request: Union[Request, WebhookRequest, Spec], - response: Union[Response, Request, WebhookRequest], - spec: Union[SchemaPath, Response], + request: Union[Request, WebhookRequest], + response: Response, + spec: SchemaPath, base_url: Optional[str] = None, cls: Optional[AnyResponseValidatorType] = None, **validator_kwargs: Any, -) -> Optional[ResponseUnmarshalResult]: - if not isinstance(request, (Request, WebhookRequest)): - raise TypeError("'request' argument is not type of (Webhook)Request") - if not isinstance(response, Response): - raise TypeError("'response' argument is not type of Response") - if not isinstance(spec, SchemaPath): - raise TypeError("'spec' argument is not type of SchemaPath") - - if isinstance(request, WebhookRequest): - if cls is None or issubclass(cls, WebhookResponseValidator): - validate_webhook_response( - request, - response, - spec, - base_url=base_url, - cls=cls, - **validator_kwargs, - ) - return None - else: - raise TypeError( - "'cls' argument is not type of WebhookResponseValidator" - ) - else: - if cls is None or issubclass(cls, ResponseValidator): - validate_apicall_response( - request, - response, - spec, - base_url=base_url, - cls=cls, - **validator_kwargs, - ) - return None - else: - raise TypeError("'cls' argument is not type of ResponseValidator") +) -> None: + config = Config( + server_base_url=base_url, + response_validator_cls=cls or _UNSET, + webhook_response_validator_cls=cls or _UNSET, + **validator_kwargs, + ) + return OpenAPI(spec, config=config).validate_response(request, response) def validate_apicall_request( @@ -382,18 +217,12 @@ def validate_apicall_request( cls: Optional[RequestValidatorType] = None, **validator_kwargs: Any, ) -> None: - if not isinstance(request, Request): - raise TypeError("'request' argument is not type of Request") - if not isinstance(spec, SchemaPath): - raise TypeError("'spec' argument is not type of SchemaPath") - if cls is None: - classes = get_classes(spec) - cls = classes.request_validator_cls - if not issubclass(cls, RequestValidator): - raise TypeError("'cls' argument is not type of RequestValidator") - v = cls(spec, base_url=base_url, **validator_kwargs) - v.check_spec(spec) - return v.validate(request) + config = Config( + server_base_url=base_url, + request_validator_cls=cls or _UNSET, + **validator_kwargs, + ) + return OpenAPI(spec, config=config).validate_apicall_request(request) def validate_webhook_request( @@ -403,22 +232,12 @@ def validate_webhook_request( cls: Optional[WebhookRequestValidatorType] = None, **validator_kwargs: Any, ) -> None: - if not isinstance(request, WebhookRequest): - raise TypeError("'request' argument is not type of WebhookRequest") - if not isinstance(spec, SchemaPath): - raise TypeError("'spec' argument is not type of SchemaPath") - if cls is None: - classes = get_classes(spec) - cls = classes.webhook_request_validator_cls - if cls is None: - raise SpecError("Validator class not found") - if not issubclass(cls, WebhookRequestValidator): - raise TypeError( - "'cls' argument is not type of WebhookRequestValidator" - ) - v = cls(spec, base_url=base_url, **validator_kwargs) - v.check_spec(spec) - return v.validate(request) + config = Config( + server_base_url=base_url, + webhook_request_validator_cls=cls or _UNSET, + **validator_kwargs, + ) + return OpenAPI(spec, config=config).validate_webhook_request(request) def validate_apicall_response( @@ -429,20 +248,14 @@ def validate_apicall_response( cls: Optional[ResponseValidatorType] = None, **validator_kwargs: Any, ) -> None: - if not isinstance(request, Request): - raise TypeError("'request' argument is not type of Request") - if not isinstance(response, Response): - raise TypeError("'response' argument is not type of Response") - if not isinstance(spec, SchemaPath): - raise TypeError("'spec' argument is not type of SchemaPath") - if cls is None: - classes = get_classes(spec) - cls = classes.response_validator_cls - if not issubclass(cls, ResponseValidator): - raise TypeError("'cls' argument is not type of ResponseValidator") - v = cls(spec, base_url=base_url, **validator_kwargs) - v.check_spec(spec) - return v.validate(request, response) + config = Config( + server_base_url=base_url, + response_validator_cls=cls or _UNSET, + **validator_kwargs, + ) + return OpenAPI(spec, config=config).validate_apicall_response( + request, response + ) def validate_webhook_response( @@ -453,21 +266,11 @@ def validate_webhook_response( cls: Optional[WebhookResponseValidatorType] = None, **validator_kwargs: Any, ) -> None: - if not isinstance(request, WebhookRequest): - raise TypeError("'request' argument is not type of WebhookRequest") - if not isinstance(response, Response): - raise TypeError("'response' argument is not type of Response") - if not isinstance(spec, SchemaPath): - raise TypeError("'spec' argument is not type of SchemaPath") - if cls is None: - classes = get_classes(spec) - cls = classes.webhook_response_validator_cls - if cls is None: - raise SpecError("Validator class not found") - if not issubclass(cls, WebhookResponseValidator): - raise TypeError( - "'cls' argument is not type of WebhookResponseValidator" - ) - v = cls(spec, base_url=base_url, **validator_kwargs) - v.check_spec(spec) - return v.validate(request, response) + config = Config( + server_base_url=base_url, + webhook_response_validator_cls=cls or _UNSET, + **validator_kwargs, + ) + return OpenAPI(spec, config=config).validate_webhook_response( + request, response + ) diff --git a/openapi_core/spec/__init__.py b/openapi_core/spec/__init__.py index 6ab17b89..e69de29b 100644 --- a/openapi_core/spec/__init__.py +++ b/openapi_core/spec/__init__.py @@ -1,3 +0,0 @@ -from openapi_core.spec.paths import Spec - -__all__ = ["Spec"] diff --git a/openapi_core/spec/paths.py b/openapi_core/spec/paths.py index f4e940e3..a1846ee0 100644 --- a/openapi_core/spec/paths.py +++ b/openapi_core/spec/paths.py @@ -1,46 +1,13 @@ import warnings from typing import Any -from typing import Hashable -from typing import Mapping -from typing import Type -from typing import TypeVar -from jsonschema.validators import _UNSET from jsonschema_path import SchemaPath -from openapi_spec_validator import validate - -TSpec = TypeVar("TSpec", bound="Spec") - -SPEC_SEPARATOR = "#" class Spec(SchemaPath): - @classmethod - def from_dict( - cls: Type[TSpec], - data: Mapping[Hashable, Any], - *args: Any, - **kwargs: Any, - ) -> TSpec: + def __init__(self, *args: Any, **kwargs: Any): warnings.warn( "Spec is deprecated. Use SchemaPath from jsonschema-path package.", DeprecationWarning, ) - if "validator" in kwargs: - warnings.warn( - "validator parameter is deprecated. Use spec_validator_cls instead.", - DeprecationWarning, - ) - validator = kwargs.pop("validator", _UNSET) - spec_validator_cls = kwargs.pop("spec_validator_cls", _UNSET) - base_uri = kwargs.get("base_uri", "") - spec_url = kwargs.get("spec_url") - if spec_validator_cls is not None: - if spec_validator_cls is not _UNSET: - validate(data, base_uri=base_uri, cls=spec_validator_cls) - elif validator is _UNSET: - validate(data, base_uri=base_uri) - elif validator is not None: - validator.validate(data, base_uri=base_uri, spec_url=spec_url) - - return super().from_dict(data, *args, **kwargs) + super().__init__(*args, **kwargs) diff --git a/openapi_core/types.py b/openapi_core/types.py index 9d9b1bc8..2a1934ad 100644 --- a/openapi_core/types.py +++ b/openapi_core/types.py @@ -1,36 +1,6 @@ -from dataclasses import dataclass -from typing import Mapping -from typing import NamedTuple -from typing import Optional -from typing import Type +from typing import Union -from jsonschema_path import SchemaPath +from openapi_core.protocols import Request +from openapi_core.protocols import WebhookRequest -from openapi_core.exceptions import SpecError -from openapi_core.unmarshalling.request.types import RequestUnmarshallerType -from openapi_core.unmarshalling.request.types import ( - WebhookRequestUnmarshallerType, -) -from openapi_core.unmarshalling.response.types import ResponseUnmarshallerType -from openapi_core.unmarshalling.response.types import ( - WebhookResponseUnmarshallerType, -) -from openapi_core.validation.request.types import RequestValidatorType -from openapi_core.validation.request.types import WebhookRequestValidatorType -from openapi_core.validation.response.types import ResponseValidatorType -from openapi_core.validation.response.types import WebhookResponseValidatorType -from openapi_core.validation.validators import BaseValidator - - -@dataclass -class SpecClasses: - request_validator_cls: RequestValidatorType - response_validator_cls: ResponseValidatorType - webhook_request_validator_cls: Optional[WebhookRequestValidatorType] - webhook_response_validator_cls: Optional[WebhookResponseValidatorType] - request_unmarshaller_cls: RequestUnmarshallerType - response_unmarshaller_cls: ResponseUnmarshallerType - webhook_request_unmarshaller_cls: Optional[WebhookRequestUnmarshallerType] - webhook_response_unmarshaller_cls: Optional[ - WebhookResponseUnmarshallerType - ] +AnyRequest = Union[Request, WebhookRequest] diff --git a/openapi_core/typing.py b/openapi_core/typing.py index ed682913..7cb12f9d 100644 --- a/openapi_core/typing.py +++ b/openapi_core/typing.py @@ -1,17 +1,6 @@ -from typing import Awaitable -from typing import Callable -from typing import Iterable from typing import TypeVar -from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult - #: The type of request within an integration. RequestType = TypeVar("RequestType") #: The type of response within an integration. ResponseType = TypeVar("ResponseType") - -ErrorsHandlerCallable = Callable[[Iterable[Exception]], ResponseType] -ValidRequestHandlerCallable = Callable[[RequestUnmarshalResult], ResponseType] -AsyncValidRequestHandlerCallable = Callable[ - [RequestUnmarshalResult], Awaitable[ResponseType] -] diff --git a/openapi_core/unmarshalling/configurations.py b/openapi_core/unmarshalling/configurations.py new file mode 100644 index 00000000..27cdccd7 --- /dev/null +++ b/openapi_core/unmarshalling/configurations.py @@ -0,0 +1,25 @@ +from dataclasses import dataclass +from typing import Optional + +from openapi_core.unmarshalling.schemas.datatypes import ( + FormatUnmarshallersDict, +) +from openapi_core.unmarshalling.schemas.factories import ( + SchemaUnmarshallersFactory, +) +from openapi_core.validation.configurations import ValidatorConfig + + +@dataclass +class UnmarshallerConfig(ValidatorConfig): + """Unmarshaller configuration dataclass. + + Attributes: + schema_unmarshallers_factory + Schema unmarshallers factory. + extra_format_unmarshallers + Extra format unmarshallers. + """ + + schema_unmarshallers_factory: Optional[SchemaUnmarshallersFactory] = None + extra_format_unmarshallers: Optional[FormatUnmarshallersDict] = None diff --git a/openapi_core/unmarshalling/integrations.py b/openapi_core/unmarshalling/integrations.py new file mode 100644 index 00000000..d3f4b708 --- /dev/null +++ b/openapi_core/unmarshalling/integrations.py @@ -0,0 +1,83 @@ +"""OpenAPI core unmarshalling processors module""" +from typing import Any +from typing import Generic +from typing import Optional + +from jsonschema_path import SchemaPath + +from openapi_core.app import OpenAPI +from openapi_core.protocols import Request +from openapi_core.protocols import Response +from openapi_core.typing import RequestType +from openapi_core.typing import ResponseType +from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult +from openapi_core.unmarshalling.request.processors import ( + RequestUnmarshallingProcessor, +) +from openapi_core.unmarshalling.request.types import RequestUnmarshallerType +from openapi_core.unmarshalling.response.datatypes import ( + ResponseUnmarshalResult, +) +from openapi_core.unmarshalling.response.processors import ( + ResponseUnmarshallingProcessor, +) +from openapi_core.unmarshalling.response.types import ResponseUnmarshallerType +from openapi_core.unmarshalling.typing import AsyncValidRequestHandlerCallable +from openapi_core.unmarshalling.typing import ErrorsHandlerCallable +from openapi_core.unmarshalling.typing import ValidRequestHandlerCallable +from openapi_core.validation.integrations import ValidationIntegration + + +class UnmarshallingIntegration( + ValidationIntegration[RequestType, ResponseType] +): + def unmarshal_request( + self, request: RequestType + ) -> RequestUnmarshalResult: + openapi_request = self.get_openapi_request(request) + return self.openapi.unmarshal_request( + openapi_request, + ) + + def unmarshal_response( + self, + request: RequestType, + response: ResponseType, + ) -> ResponseUnmarshalResult: + openapi_request = self.get_openapi_request(request) + openapi_response = self.get_openapi_response(response) + return self.openapi.unmarshal_response( + openapi_request, openapi_response + ) + + +class AsyncUnmarshallingIntegration(Generic[RequestType, ResponseType]): + def __init__( + self, + openapi: OpenAPI, + ): + self.openapi = openapi + + async def get_openapi_request(self, request: RequestType) -> Request: + raise NotImplementedError + + async def get_openapi_response(self, response: ResponseType) -> Response: + raise NotImplementedError + + async def unmarshal_request( + self, + request: RequestType, + ) -> RequestUnmarshalResult: + openapi_request = await self.get_openapi_request(request) + return self.openapi.unmarshal_request(openapi_request) + + async def unmarshal_response( + self, + request: RequestType, + response: ResponseType, + ) -> ResponseUnmarshalResult: + openapi_request = await self.get_openapi_request(request) + openapi_response = await self.get_openapi_response(response) + return self.openapi.unmarshal_response( + openapi_request, openapi_response + ) diff --git a/openapi_core/unmarshalling/processors.py b/openapi_core/unmarshalling/processors.py index 6a1945f9..7470ee2b 100644 --- a/openapi_core/unmarshalling/processors.py +++ b/openapi_core/unmarshalling/processors.py @@ -7,12 +7,12 @@ from openapi_core.protocols import Request from openapi_core.protocols import Response -from openapi_core.shortcuts import get_classes -from openapi_core.typing import AsyncValidRequestHandlerCallable -from openapi_core.typing import ErrorsHandlerCallable from openapi_core.typing import RequestType from openapi_core.typing import ResponseType -from openapi_core.typing import ValidRequestHandlerCallable +from openapi_core.unmarshalling.integrations import ( + AsyncUnmarshallingIntegration, +) +from openapi_core.unmarshalling.integrations import UnmarshallingIntegration from openapi_core.unmarshalling.request.processors import ( RequestUnmarshallingProcessor, ) @@ -21,55 +21,22 @@ ResponseUnmarshallingProcessor, ) from openapi_core.unmarshalling.response.types import ResponseUnmarshallerType +from openapi_core.unmarshalling.typing import AsyncValidRequestHandlerCallable +from openapi_core.unmarshalling.typing import ErrorsHandlerCallable +from openapi_core.unmarshalling.typing import ValidRequestHandlerCallable -class UnmarshallingProcessor(Generic[RequestType, ResponseType]): - def __init__( - self, - spec: SchemaPath, - request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None, - response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None, - **unmarshaller_kwargs: Any, - ): - if ( - request_unmarshaller_cls is None - or response_unmarshaller_cls is None - ): - classes = get_classes(spec) - if request_unmarshaller_cls is None: - request_unmarshaller_cls = classes.request_unmarshaller_cls - if response_unmarshaller_cls is None: - response_unmarshaller_cls = classes.response_unmarshaller_cls - - self.request_processor = RequestUnmarshallingProcessor( - spec, - request_unmarshaller_cls, - **unmarshaller_kwargs, - ) - self.response_processor = ResponseUnmarshallingProcessor( - spec, - response_unmarshaller_cls, - **unmarshaller_kwargs, - ) - - def _get_openapi_request(self, request: RequestType) -> Request: - raise NotImplementedError - - def _get_openapi_response(self, response: ResponseType) -> Response: - raise NotImplementedError - - def _validate_response(self) -> bool: - raise NotImplementedError - +class UnmarshallingProcessor( + UnmarshallingIntegration[RequestType, ResponseType] +): def handle_request( self, request: RequestType, valid_handler: ValidRequestHandlerCallable[ResponseType], errors_handler: ErrorsHandlerCallable[ResponseType], ) -> ResponseType: - openapi_request = self._get_openapi_request(request) - request_unmarshal_result = self.request_processor.process( - openapi_request + request_unmarshal_result = self.unmarshal_request( + request, ) if request_unmarshal_result.errors: return errors_handler(request_unmarshal_result.errors) @@ -81,66 +48,22 @@ def handle_response( response: ResponseType, errors_handler: ErrorsHandlerCallable[ResponseType], ) -> ResponseType: - if not self._validate_response(): - return response - openapi_request = self._get_openapi_request(request) - openapi_response = self._get_openapi_response(response) - response_unmarshal_result = self.response_processor.process( - openapi_request, openapi_response - ) + response_unmarshal_result = self.unmarshal_response(request, response) if response_unmarshal_result.errors: return errors_handler(response_unmarshal_result.errors) return response -class AsyncUnmarshallingProcessor(Generic[RequestType, ResponseType]): - def __init__( - self, - spec: SchemaPath, - request_unmarshaller_cls: Optional[RequestUnmarshallerType] = None, - response_unmarshaller_cls: Optional[ResponseUnmarshallerType] = None, - **unmarshaller_kwargs: Any, - ): - if ( - request_unmarshaller_cls is None - or response_unmarshaller_cls is None - ): - classes = get_classes(spec) - if request_unmarshaller_cls is None: - request_unmarshaller_cls = classes.request_unmarshaller_cls - if response_unmarshaller_cls is None: - response_unmarshaller_cls = classes.response_unmarshaller_cls - - self.request_processor = RequestUnmarshallingProcessor( - spec, - request_unmarshaller_cls, - **unmarshaller_kwargs, - ) - self.response_processor = ResponseUnmarshallingProcessor( - spec, - response_unmarshaller_cls, - **unmarshaller_kwargs, - ) - - async def _get_openapi_request(self, request: RequestType) -> Request: - raise NotImplementedError - - async def _get_openapi_response(self, response: ResponseType) -> Response: - raise NotImplementedError - - def _validate_response(self) -> bool: - raise NotImplementedError - +class AsyncUnmarshallingProcessor( + AsyncUnmarshallingIntegration[RequestType, ResponseType] +): async def handle_request( self, request: RequestType, valid_handler: AsyncValidRequestHandlerCallable[ResponseType], errors_handler: ErrorsHandlerCallable[ResponseType], ) -> ResponseType: - openapi_request = await self._get_openapi_request(request) - request_unmarshal_result = self.request_processor.process( - openapi_request - ) + request_unmarshal_result = await self.unmarshal_request(request) if request_unmarshal_result.errors: return errors_handler(request_unmarshal_result.errors) result = await valid_handler(request_unmarshal_result) @@ -152,12 +75,8 @@ async def handle_response( response: ResponseType, errors_handler: ErrorsHandlerCallable[ResponseType], ) -> ResponseType: - if not self._validate_response(): - return response - openapi_request = await self._get_openapi_request(request) - openapi_response = await self._get_openapi_response(response) - response_unmarshal_result = self.response_processor.process( - openapi_request, openapi_response + response_unmarshal_result = await self.unmarshal_response( + request, response ) if response_unmarshal_result.errors: return errors_handler(response_unmarshal_result.errors) diff --git a/openapi_core/unmarshalling/request/__init__.py b/openapi_core/unmarshalling/request/__init__.py index ddf7207a..fc2a08a4 100644 --- a/openapi_core/unmarshalling/request/__init__.py +++ b/openapi_core/unmarshalling/request/__init__.py @@ -1,4 +1,13 @@ """OpenAPI core unmarshalling request module""" +from typing import Mapping + +from openapi_spec_validator.versions import consts as versions +from openapi_spec_validator.versions.datatypes import SpecVersion + +from openapi_core.unmarshalling.request.types import RequestUnmarshallerType +from openapi_core.unmarshalling.request.types import ( + WebhookRequestUnmarshallerType, +) from openapi_core.unmarshalling.request.unmarshallers import ( V30RequestUnmarshaller, ) @@ -10,6 +19,8 @@ ) __all__ = [ + "UNMARSHALLERS", + "WEBHOOK_UNMARSHALLERS", "V3RequestUnmarshaller", "V3WebhookRequestUnmarshaller", "V30RequestUnmarshaller", @@ -17,6 +28,15 @@ "V31WebhookRequestUnmarshaller", ] +# versions mapping +UNMARSHALLERS: Mapping[SpecVersion, RequestUnmarshallerType] = { + versions.OPENAPIV30: V30RequestUnmarshaller, + versions.OPENAPIV31: V31RequestUnmarshaller, +} +WEBHOOK_UNMARSHALLERS: Mapping[SpecVersion, WebhookRequestUnmarshallerType] = { + versions.OPENAPIV31: V31WebhookRequestUnmarshaller, +} + # alias to the latest v3 version V3RequestUnmarshaller = V31RequestUnmarshaller V3WebhookRequestUnmarshaller = V31WebhookRequestUnmarshaller diff --git a/openapi_core/unmarshalling/request/protocols.py b/openapi_core/unmarshalling/request/protocols.py index c6d0b057..388f13c8 100644 --- a/openapi_core/unmarshalling/request/protocols.py +++ b/openapi_core/unmarshalling/request/protocols.py @@ -15,9 +15,6 @@ class RequestUnmarshaller(Protocol): def __init__(self, spec: SchemaPath, base_url: Optional[str] = None): ... - def check_spec(self, spec: SchemaPath) -> None: - ... - def unmarshal( self, request: Request, @@ -30,9 +27,6 @@ class WebhookRequestUnmarshaller(Protocol): def __init__(self, spec: SchemaPath, base_url: Optional[str] = None): ... - def check_spec(self, spec: SchemaPath) -> None: - ... - def unmarshal( self, request: WebhookRequest, diff --git a/openapi_core/unmarshalling/response/__init__.py b/openapi_core/unmarshalling/response/__init__.py index 998b202c..2c7094f1 100644 --- a/openapi_core/unmarshalling/response/__init__.py +++ b/openapi_core/unmarshalling/response/__init__.py @@ -1,4 +1,13 @@ """OpenAPI core unmarshalling response module""" +from typing import Mapping + +from openapi_spec_validator.versions import consts as versions +from openapi_spec_validator.versions.datatypes import SpecVersion + +from openapi_core.unmarshalling.response.types import ResponseUnmarshallerType +from openapi_core.unmarshalling.response.types import ( + WebhookResponseUnmarshallerType, +) from openapi_core.unmarshalling.response.unmarshallers import ( V30ResponseUnmarshaller, ) @@ -10,6 +19,8 @@ ) __all__ = [ + "UNMARSHALLERS", + "WEBHOOK_UNMARSHALLERS", "V3ResponseUnmarshaller", "V3WebhookResponseUnmarshaller", "V30ResponseUnmarshaller", @@ -17,6 +28,17 @@ "V31WebhookResponseUnmarshaller", ] +# versions mapping +UNMARSHALLERS: Mapping[SpecVersion, ResponseUnmarshallerType] = { + versions.OPENAPIV30: V30ResponseUnmarshaller, + versions.OPENAPIV31: V31ResponseUnmarshaller, +} +WEBHOOK_UNMARSHALLERS: Mapping[ + SpecVersion, WebhookResponseUnmarshallerType +] = { + versions.OPENAPIV31: V31WebhookResponseUnmarshaller, +} + # alias to the latest v3 version V3ResponseUnmarshaller = V31ResponseUnmarshaller V3WebhookResponseUnmarshaller = V31WebhookResponseUnmarshaller diff --git a/openapi_core/unmarshalling/response/protocols.py b/openapi_core/unmarshalling/response/protocols.py index 08c79e9d..8666e84d 100644 --- a/openapi_core/unmarshalling/response/protocols.py +++ b/openapi_core/unmarshalling/response/protocols.py @@ -20,9 +20,6 @@ class ResponseUnmarshaller(Protocol): def __init__(self, spec: SchemaPath, base_url: Optional[str] = None): ... - def check_spec(self, spec: SchemaPath) -> None: - ... - def unmarshal( self, request: Request, @@ -36,9 +33,6 @@ class WebhookResponseUnmarshaller(Protocol): def __init__(self, spec: SchemaPath, base_url: Optional[str] = None): ... - def check_spec(self, spec: SchemaPath) -> None: - ... - def unmarshal( self, request: WebhookRequest, diff --git a/openapi_core/unmarshalling/typing.py b/openapi_core/unmarshalling/typing.py new file mode 100644 index 00000000..587b977c --- /dev/null +++ b/openapi_core/unmarshalling/typing.py @@ -0,0 +1,12 @@ +from typing import Awaitable +from typing import Callable +from typing import Iterable + +from openapi_core.typing import ResponseType +from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult + +ErrorsHandlerCallable = Callable[[Iterable[Exception]], ResponseType] +ValidRequestHandlerCallable = Callable[[RequestUnmarshalResult], ResponseType] +AsyncValidRequestHandlerCallable = Callable[ + [RequestUnmarshalResult], Awaitable[ResponseType] +] diff --git a/openapi_core/validation/configurations.py b/openapi_core/validation/configurations.py new file mode 100644 index 00000000..60eb1fb4 --- /dev/null +++ b/openapi_core/validation/configurations.py @@ -0,0 +1,61 @@ +from dataclasses import dataclass +from typing import Optional + +from openapi_core.casting.schemas.factories import SchemaCastersFactory +from openapi_core.deserializing.media_types import ( + media_type_deserializers_factory, +) +from openapi_core.deserializing.media_types.datatypes import ( + MediaTypeDeserializersDict, +) +from openapi_core.deserializing.media_types.factories import ( + MediaTypeDeserializersFactory, +) +from openapi_core.deserializing.styles import style_deserializers_factory +from openapi_core.deserializing.styles.factories import ( + StyleDeserializersFactory, +) +from openapi_core.unmarshalling.schemas.datatypes import ( + FormatUnmarshallersDict, +) +from openapi_core.unmarshalling.schemas.factories import ( + SchemaUnmarshallersFactory, +) +from openapi_core.validation.schemas.datatypes import FormatValidatorsDict +from openapi_core.validation.schemas.factories import SchemaValidatorsFactory + + +@dataclass +class ValidatorConfig: + """Validator configuration dataclass. + + Attributes: + server_base_url + Server base URI. + style_deserializers_factory + Style deserializers factory. + media_type_deserializers_factory + Media type deserializers factory. + schema_casters_factory + Schema casters factory. + schema_validators_factory + Schema validators factory. + extra_format_validators + Extra format validators. + extra_media_type_deserializers + Extra media type deserializers. + """ + + server_base_url: Optional[str] = None + + style_deserializers_factory: StyleDeserializersFactory = ( + style_deserializers_factory + ) + media_type_deserializers_factory: MediaTypeDeserializersFactory = ( + media_type_deserializers_factory + ) + schema_casters_factory: Optional[SchemaCastersFactory] = None + schema_validators_factory: Optional[SchemaValidatorsFactory] = None + + extra_format_validators: Optional[FormatValidatorsDict] = None + extra_media_type_deserializers: Optional[MediaTypeDeserializersDict] = None diff --git a/openapi_core/validation/integrations.py b/openapi_core/validation/integrations.py new file mode 100644 index 00000000..d16ecdb6 --- /dev/null +++ b/openapi_core/validation/integrations.py @@ -0,0 +1,56 @@ +"""OpenAPI core unmarshalling processors module""" +from typing import Any +from typing import Generic +from typing import Optional + +from jsonschema_path import SchemaPath + +from openapi_core.app import OpenAPI +from openapi_core.protocols import Request +from openapi_core.protocols import Response +from openapi_core.typing import RequestType +from openapi_core.typing import ResponseType +from openapi_core.unmarshalling.request.datatypes import RequestUnmarshalResult +from openapi_core.unmarshalling.request.processors import ( + RequestUnmarshallingProcessor, +) +from openapi_core.unmarshalling.request.types import RequestUnmarshallerType +from openapi_core.unmarshalling.response.datatypes import ( + ResponseUnmarshalResult, +) +from openapi_core.unmarshalling.response.processors import ( + ResponseUnmarshallingProcessor, +) +from openapi_core.unmarshalling.response.types import ResponseUnmarshallerType +from openapi_core.unmarshalling.typing import AsyncValidRequestHandlerCallable +from openapi_core.unmarshalling.typing import ErrorsHandlerCallable +from openapi_core.unmarshalling.typing import ValidRequestHandlerCallable + + +class ValidationIntegration(Generic[RequestType, ResponseType]): + def __init__( + self, + openapi: OpenAPI, + ): + self.openapi = openapi + + def get_openapi_request(self, request: RequestType) -> Request: + raise NotImplementedError + + def get_openapi_response(self, response: ResponseType) -> Response: + raise NotImplementedError + + def validate_request(self, request: RequestType) -> None: + openapi_request = self.get_openapi_request(request) + self.openapi.validate_request( + openapi_request, + ) + + def validate_response( + self, + request: RequestType, + response: ResponseType, + ) -> None: + openapi_request = self.get_openapi_request(request) + openapi_response = self.get_openapi_response(response) + self.openapi.validate_response(openapi_request, openapi_response) diff --git a/openapi_core/validation/processors.py b/openapi_core/validation/processors.py index 711b5225..08f1f41a 100644 --- a/openapi_core/validation/processors.py +++ b/openapi_core/validation/processors.py @@ -6,35 +6,18 @@ from openapi_core.protocols import Request from openapi_core.protocols import Response -from openapi_core.shortcuts import get_classes +from openapi_core.typing import RequestType +from openapi_core.typing import ResponseType +from openapi_core.validation.integrations import ValidationIntegration from openapi_core.validation.request.types import RequestValidatorType from openapi_core.validation.response.types import ResponseValidatorType -class ValidationProcessor: - def __init__( - self, - spec: SchemaPath, - request_validator_cls: Optional[RequestValidatorType] = None, - response_validator_cls: Optional[ResponseValidatorType] = None, - **unmarshaller_kwargs: Any, - ): - self.spec = spec - if request_validator_cls is None or response_validator_cls is None: - classes = get_classes(self.spec) - if request_validator_cls is None: - request_validator_cls = classes.request_validator_cls - if response_validator_cls is None: - response_validator_cls = classes.response_validator_cls - self.request_validator = request_validator_cls( - self.spec, **unmarshaller_kwargs - ) - self.response_validator = response_validator_cls( - self.spec, **unmarshaller_kwargs - ) +class ValidationProcessor(ValidationIntegration[RequestType, ResponseType]): + def handle_request(self, request: RequestType) -> None: + self.validate_request(request) - def process_request(self, request: Request) -> None: - self.request_validator.validate(request) - - def process_response(self, request: Request, response: Response) -> None: - self.response_validator.validate(request, response) + def handle_response( + self, request: RequestType, response: ResponseType + ) -> None: + self.validate_response(request, response) diff --git a/openapi_core/validation/request/__init__.py b/openapi_core/validation/request/__init__.py index d79102cc..e94adeda 100644 --- a/openapi_core/validation/request/__init__.py +++ b/openapi_core/validation/request/__init__.py @@ -1,4 +1,11 @@ """OpenAPI core validation request module""" +from typing import Mapping + +from openapi_spec_validator.versions import consts as versions +from openapi_spec_validator.versions.datatypes import SpecVersion + +from openapi_core.validation.request.types import RequestValidatorType +from openapi_core.validation.request.types import WebhookRequestValidatorType from openapi_core.validation.request.validators import V30RequestBodyValidator from openapi_core.validation.request.validators import ( V30RequestParametersValidator, @@ -29,6 +36,8 @@ ) __all__ = [ + "VALIDATORS", + "WEBHOOK_VALIDATORS", "V30RequestBodyValidator", "V30RequestParametersValidator", "V30RequestSecurityValidator", @@ -45,6 +54,15 @@ "V3WebhookRequestValidator", ] +# versions mapping +VALIDATORS: Mapping[SpecVersion, RequestValidatorType] = { + versions.OPENAPIV30: V30RequestValidator, + versions.OPENAPIV31: V31RequestValidator, +} +WEBHOOK_VALIDATORS: Mapping[SpecVersion, WebhookRequestValidatorType] = { + versions.OPENAPIV31: V31WebhookRequestValidator, +} + # alias to the latest v3 version V3RequestValidator = V31RequestValidator V3WebhookRequestValidator = V31WebhookRequestValidator diff --git a/openapi_core/validation/request/protocols.py b/openapi_core/validation/request/protocols.py index 8009c50a..e27f5863 100644 --- a/openapi_core/validation/request/protocols.py +++ b/openapi_core/validation/request/protocols.py @@ -15,9 +15,6 @@ class RequestValidator(Protocol): def __init__(self, spec: SchemaPath, base_url: Optional[str] = None): ... - def check_spec(self, spec: SchemaPath) -> None: - ... - def iter_errors( self, request: Request, @@ -36,9 +33,6 @@ class WebhookRequestValidator(Protocol): def __init__(self, spec: SchemaPath, base_url: Optional[str] = None): ... - def check_spec(self, spec: SchemaPath) -> None: - ... - def iter_errors( self, request: WebhookRequest, diff --git a/openapi_core/validation/response/__init__.py b/openapi_core/validation/response/__init__.py index 5c62af3f..2210d613 100644 --- a/openapi_core/validation/response/__init__.py +++ b/openapi_core/validation/response/__init__.py @@ -1,4 +1,11 @@ """OpenAPI core validation response module""" +from typing import Mapping + +from openapi_spec_validator.versions import consts as versions +from openapi_spec_validator.versions.datatypes import SpecVersion + +from openapi_core.validation.response.types import ResponseValidatorType +from openapi_core.validation.response.types import WebhookResponseValidatorType from openapi_core.validation.response.validators import ( V30ResponseDataValidator, ) @@ -24,6 +31,8 @@ ) __all__ = [ + "VALIDATORS", + "WEBHOOK_VALIDATORS", "V30ResponseDataValidator", "V30ResponseHeadersValidator", "V30ResponseValidator", @@ -37,6 +46,15 @@ "V3WebhookResponseValidator", ] +# versions mapping +VALIDATORS: Mapping[SpecVersion, ResponseValidatorType] = { + versions.OPENAPIV30: V30ResponseValidator, + versions.OPENAPIV31: V31ResponseValidator, +} +WEBHOOK_VALIDATORS: Mapping[SpecVersion, WebhookResponseValidatorType] = { + versions.OPENAPIV31: V31WebhookResponseValidator, +} + # alias to the latest v3 version V3ResponseValidator = V31ResponseValidator V3WebhookResponseValidator = V31WebhookResponseValidator diff --git a/openapi_core/validation/response/protocols.py b/openapi_core/validation/response/protocols.py index 95c4a83d..7a403d3e 100644 --- a/openapi_core/validation/response/protocols.py +++ b/openapi_core/validation/response/protocols.py @@ -16,9 +16,6 @@ class ResponseValidator(Protocol): def __init__(self, spec: SchemaPath, base_url: Optional[str] = None): ... - def check_spec(self, spec: SchemaPath) -> None: - ... - def iter_errors( self, request: Request, @@ -39,9 +36,6 @@ class WebhookResponseValidator(Protocol): def __init__(self, spec: SchemaPath, base_url: Optional[str] = None): ... - def check_spec(self, spec: SchemaPath) -> None: - ... - def iter_errors( self, request: WebhookRequest, diff --git a/pyproject.toml b/pyproject.toml index e473f45c..25c922f8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -128,6 +128,7 @@ filterwarnings = [ "error", # falcon.media.handlers uses cgi to parse data "ignore:'cgi' is deprecated and slated for removal in Python 3.13:DeprecationWarning", + "ignore:co_lnotab is deprecated, use co_lines instead:DeprecationWarning", ] [tool.black] diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 00dc26b6..cea4a154 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -7,6 +7,8 @@ from openapi_spec_validator.readers import read_from_filename from yaml import safe_load +from openapi_core import Spec + def content_from_file(spec_file): directory = path.abspath(path.dirname(__file__)) @@ -14,17 +16,27 @@ def content_from_file(spec_file): return read_from_filename(path_full) -def spec_from_file(spec_file): +def schema_path_from_file(spec_file): spec_dict, base_uri = content_from_file(spec_file) return SchemaPath.from_dict(spec_dict, base_uri=base_uri) -def spec_from_url(base_uri): +def schema_path_from_url(base_uri): content = request.urlopen(base_uri) spec_dict = safe_load(content) return SchemaPath.from_dict(spec_dict, base_uri=base_uri) +def spec_from_file(spec_file): + schema_path = schema_path_from_file(spec_file) + return Spec(schema_path) + + +def spec_from_url(base_uri): + schema_path = schema_path_from_url(base_uri) + return Spec(schema_path) + + @pytest.fixture(scope="session") def data_gif(): return b64decode( @@ -44,17 +56,31 @@ class Factory(dict): @pytest.fixture(scope="session") -def factory(): +def content_factory(): + return Factory( + from_file=content_from_file, + ) + + +@pytest.fixture(scope="session") +def schema_path_factory(): + return Factory( + from_file=schema_path_from_file, + from_url=schema_path_from_url, + ) + + +@pytest.fixture(scope="session") +def spec_factory(schema_path_factory): return Factory( - content_from_file=content_from_file, - spec_from_file=spec_from_file, - spec_from_url=spec_from_url, + from_file=spec_from_file, + from_url=spec_from_url, ) @pytest.fixture(scope="session") -def v30_petstore_content(factory): - content, _ = factory.content_from_file("data/v3.0/petstore.yaml") +def v30_petstore_content(content_factory): + content, _ = content_factory.from_file("data/v3.0/petstore.yaml") return content diff --git a/tests/integration/contrib/aiohttp/conftest.py b/tests/integration/contrib/aiohttp/conftest.py index a76607a3..ce299473 100644 --- a/tests/integration/contrib/aiohttp/conftest.py +++ b/tests/integration/contrib/aiohttp/conftest.py @@ -14,10 +14,10 @@ @pytest.fixture -def spec(factory): +def schema_path(schema_path_factory): directory = pathlib.Path(__file__).parent specfile = directory / "data" / "v3.0" / "aiohttp_factory.yaml" - return factory.spec_from_file(str(specfile)) + return schema_path_factory.from_file(str(specfile)) @pytest.fixture @@ -41,11 +41,11 @@ async def test_route(request: web.Request) -> web.Response: @pytest.fixture -def request_validation(spec, response_getter): +def request_validation(schema_path, response_getter): async def test_route(request: web.Request) -> web.Response: request_body = await request.text() openapi_request = AIOHTTPOpenAPIWebRequest(request, body=request_body) - unmarshaller = V30RequestUnmarshaller(spec) + unmarshaller = V30RequestUnmarshaller(schema_path) result = unmarshaller.unmarshal(openapi_request) response: dict[str, Any] = response_getter() status = 200 @@ -62,7 +62,7 @@ async def test_route(request: web.Request) -> web.Response: @pytest.fixture -def response_validation(spec, response_getter): +def response_validation(schema_path, response_getter): async def test_route(request: web.Request) -> web.Response: request_body = await request.text() openapi_request = AIOHTTPOpenAPIWebRequest(request, body=request_body) @@ -73,7 +73,7 @@ async def test_route(request: web.Request) -> web.Response: status=200, ) openapi_response = AIOHTTPOpenAPIWebResponse(response) - unmarshaller = V30ResponseUnmarshaller(spec) + unmarshaller = V30ResponseUnmarshaller(schema_path) result = unmarshaller.unmarshal(openapi_request, openapi_response) if result.errors: response = web.json_response( diff --git a/tests/integration/contrib/aiohttp/data/v3.0/aiohttpproject/openapi.py b/tests/integration/contrib/aiohttp/data/v3.0/aiohttpproject/openapi.py index ac65a703..4ca6d9fa 100644 --- a/tests/integration/contrib/aiohttp/data/v3.0/aiohttpproject/openapi.py +++ b/tests/integration/contrib/aiohttp/data/v3.0/aiohttpproject/openapi.py @@ -1,8 +1,9 @@ from pathlib import Path import yaml -from jsonschema_path import SchemaPath + +from openapi_core import OpenAPI openapi_spec_path = Path("tests/integration/data/v3.0/petstore.yaml") spec_dict = yaml.load(openapi_spec_path.read_text(), yaml.Loader) -spec = SchemaPath.from_dict(spec_dict) +openapi = OpenAPI.from_dict(spec_dict) diff --git a/tests/integration/contrib/aiohttp/data/v3.0/aiohttpproject/pets/views.py b/tests/integration/contrib/aiohttp/data/v3.0/aiohttpproject/pets/views.py index fea3545e..c9130b58 100644 --- a/tests/integration/contrib/aiohttp/data/v3.0/aiohttpproject/pets/views.py +++ b/tests/integration/contrib/aiohttp/data/v3.0/aiohttpproject/pets/views.py @@ -2,7 +2,7 @@ from io import BytesIO from aiohttp import web -from aiohttpproject.openapi import spec +from aiohttpproject.openapi import openapi from multidict import MultiDict from openapi_core import unmarshal_request @@ -27,14 +27,14 @@ async def get(self): openapi_request = AIOHTTPOpenAPIWebRequest( self.request, body=request_body ) - request_unmarshalled = unmarshal_request(openapi_request, spec=spec) + request_unmarshalled = openapi.unmarshal_request(openapi_request) response = web.Response( body=self.OPENID_LOGO, content_type="image/gif", ) openapi_response = AIOHTTPOpenAPIWebResponse(response) - response_unmarshalled = unmarshal_response( - openapi_request, openapi_response, spec=spec + response_unmarshalled = openapi.unmarshal_response( + openapi_request, openapi_response ) return response @@ -43,10 +43,10 @@ async def post(self): openapi_request = AIOHTTPOpenAPIWebRequest( self.request, body=request_body ) - request_unmarshalled = unmarshal_request(openapi_request, spec=spec) + request_unmarshalled = openapi.unmarshal_request(openapi_request) response = web.Response(status=201) openapi_response = AIOHTTPOpenAPIWebResponse(response) - response_unmarshalled = unmarshal_response( - openapi_request, openapi_response, spec=spec + response_unmarshalled = openapi.unmarshal_response( + openapi_request, openapi_response ) return response diff --git a/tests/integration/contrib/django/data/v3.0/djangoproject/settings.py b/tests/integration/contrib/django/data/v3.0/djangoproject/settings.py index b5ccdaa3..b50d4884 100644 --- a/tests/integration/contrib/django/data/v3.0/djangoproject/settings.py +++ b/tests/integration/contrib/django/data/v3.0/djangoproject/settings.py @@ -16,6 +16,8 @@ import yaml from jsonschema_path import SchemaPath +from openapi_core import OpenAPI + # Build paths inside the project like this: os.path.join(BASE_DIR, ...) BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) @@ -121,3 +123,5 @@ OPENAPI_SPEC_DICT = yaml.load(OPENAPI_SPEC_PATH.read_text(), yaml.Loader) OPENAPI_SPEC = SchemaPath.from_dict(OPENAPI_SPEC_DICT) + +OPENAPI = OpenAPI(OPENAPI_SPEC) diff --git a/tests/integration/contrib/flask/conftest.py b/tests/integration/contrib/flask/conftest.py index 400b1cf7..80e8579c 100644 --- a/tests/integration/contrib/flask/conftest.py +++ b/tests/integration/contrib/flask/conftest.py @@ -1,11 +1,12 @@ import pytest from flask import Flask +from jsonschema_path import SchemaPath @pytest.fixture(scope="session") -def spec(factory): +def schema_path(schema_path_factory): specfile = "contrib/flask/data/v3.0/flask_factory.yaml" - return factory.spec_from_file(specfile) + return schema_path_factory.from_file(specfile) @pytest.fixture diff --git a/tests/integration/contrib/flask/data/v3.0/flaskproject/__main__.py b/tests/integration/contrib/flask/data/v3.0/flaskproject/__main__.py index 530264fc..dc95cdc8 100644 --- a/tests/integration/contrib/flask/data/v3.0/flaskproject/__main__.py +++ b/tests/integration/contrib/flask/data/v3.0/flaskproject/__main__.py @@ -1,11 +1,11 @@ from flask import Flask -from flaskproject.openapi import spec +from flaskproject.openapi import openapi from flaskproject.pets.views import PetPhotoView app = Flask(__name__) app.add_url_rule( "/v1/pets//photo", - view_func=PetPhotoView.as_view("pet_photo", spec), + view_func=PetPhotoView.as_view("pet_photo", openapi), methods=["GET", "POST"], ) diff --git a/tests/integration/contrib/flask/data/v3.0/flaskproject/openapi.py b/tests/integration/contrib/flask/data/v3.0/flaskproject/openapi.py index ac65a703..4ca6d9fa 100644 --- a/tests/integration/contrib/flask/data/v3.0/flaskproject/openapi.py +++ b/tests/integration/contrib/flask/data/v3.0/flaskproject/openapi.py @@ -1,8 +1,9 @@ from pathlib import Path import yaml -from jsonschema_path import SchemaPath + +from openapi_core import OpenAPI openapi_spec_path = Path("tests/integration/data/v3.0/petstore.yaml") spec_dict = yaml.load(openapi_spec_path.read_text(), yaml.Loader) -spec = SchemaPath.from_dict(spec_dict) +openapi = OpenAPI.from_dict(spec_dict) diff --git a/tests/integration/contrib/flask/test_flask_decorator.py b/tests/integration/contrib/flask/test_flask_decorator.py index 9dcf8093..cda6cd09 100644 --- a/tests/integration/contrib/flask/test_flask_decorator.py +++ b/tests/integration/contrib/flask/test_flask_decorator.py @@ -8,9 +8,9 @@ @pytest.fixture(scope="session") -def decorator_factory(spec): +def decorator_factory(schema_path): def create(**kwargs): - return FlaskOpenAPIViewDecorator.from_spec(spec, **kwargs) + return FlaskOpenAPIViewDecorator.from_spec(schema_path, **kwargs) return create diff --git a/tests/integration/contrib/flask/test_flask_validator.py b/tests/integration/contrib/flask/test_flask_validator.py index a2fd4332..45773c39 100644 --- a/tests/integration/contrib/flask/test_flask_validator.py +++ b/tests/integration/contrib/flask/test_flask_validator.py @@ -10,12 +10,12 @@ class TestFlaskOpenAPIValidation: - def test_request_validator_root_path(self, spec, app_factory): + def test_request_validator_root_path(self, schema_path, app_factory): def details_view_func(id): from flask import request openapi_request = FlaskOpenAPIRequest(request) - unmarshaller = V30RequestUnmarshaller(spec) + unmarshaller = V30RequestUnmarshaller(schema_path) result = unmarshaller.unmarshal(openapi_request) assert not result.errors diff --git a/tests/integration/contrib/flask/test_flask_views.py b/tests/integration/contrib/flask/test_flask_views.py index 2d786e88..a1caa2c7 100644 --- a/tests/integration/contrib/flask/test_flask_views.py +++ b/tests/integration/contrib/flask/test_flask_views.py @@ -3,13 +3,14 @@ from flask import jsonify from flask import make_response +from openapi_core import Config +from openapi_core import OpenAPI from openapi_core.contrib.flask.views import FlaskOpenAPIView @pytest.fixture(scope="session") -def view_factory(): +def view_factory(schema_path): def create( - spec, methods=None, extra_media_type_deserializers=None, extra_format_validators=None, @@ -25,12 +26,15 @@ def get(view, id): MyView = type("MyView", (FlaskOpenAPIView,), methods) extra_media_type_deserializers = extra_media_type_deserializers or {} extra_format_validators = extra_format_validators or {} - return MyView.as_view( - "myview", - spec, + config = Config( extra_media_type_deserializers=extra_media_type_deserializers, extra_format_validators=extra_format_validators, ) + openapi = OpenAPI(schema_path, config=config) + return MyView.as_view( + "myview", + openapi, + ) return create @@ -42,13 +46,13 @@ def client(self, client_factory, app): with app.app_context(): yield client - def test_invalid_content_type(self, client, app, spec, view_factory): + def test_invalid_content_type(self, client, app, view_factory): def get(view, id): view_response = make_response("success", 200) view_response.headers["X-Rate-Limit"] = "12" return view_response - view_func = view_factory(spec, {"get": get}) + view_func = view_factory({"get": get}) app.add_url_rule("/browse//", view_func=view_func) result = client.get("/browse/12/") @@ -70,8 +74,8 @@ def get(view, id): ] } - def test_server_error(self, client, app, spec, view_factory): - view_func = view_factory(spec) + def test_server_error(self, client, app, view_factory): + view_func = view_factory() app.add_url_rule("/browse//", view_func=view_func) result = client.get("/browse/12/", base_url="https://localhost") @@ -94,11 +98,11 @@ def test_server_error(self, client, app, spec, view_factory): assert result.status_code == 400 assert result.json == expected_data - def test_operation_error(self, client, app, spec, view_factory): + def test_operation_error(self, client, app, view_factory): def put(view, id): return make_response("success", 200) - view_func = view_factory(spec, {"put": put}) + view_func = view_factory({"put": put}) app.add_url_rule("/browse//", view_func=view_func) result = client.put("/browse/12/") @@ -121,8 +125,8 @@ def put(view, id): assert result.status_code == 405 assert result.json == expected_data - def test_path_error(self, client, app, spec, view_factory): - view_func = view_factory(spec) + def test_path_error(self, client, app, view_factory): + view_func = view_factory() app.add_url_rule("/browse/", view_func=view_func) result = client.get("/browse/") @@ -144,8 +148,8 @@ def test_path_error(self, client, app, spec, view_factory): assert result.status_code == 404 assert result.json == expected_data - def test_endpoint_error(self, client, app, spec, view_factory): - view_func = view_factory(spec) + def test_endpoint_error(self, client, app, view_factory): + view_func = view_factory() app.add_url_rule("/browse//", view_func=view_func) result = client.get("/browse/invalidparameter/") @@ -168,11 +172,11 @@ def test_endpoint_error(self, client, app, spec, view_factory): assert result.status_code == 400 assert result.json == expected_data - def test_missing_required_header(self, client, app, spec, view_factory): + def test_missing_required_header(self, client, app, view_factory): def get(view, id): return jsonify(data="data") - view_func = view_factory(spec, {"get": get}) + view_func = view_factory({"get": get}) app.add_url_rule("/browse//", view_func=view_func) result = client.get("/browse/12/") @@ -192,13 +196,13 @@ def get(view, id): assert result.status_code == 400 assert result.json == expected_data - def test_valid(self, client, app, spec, view_factory): + def test_valid(self, client, app, view_factory): def get(view, id): resp = jsonify(data="data") resp.headers["X-Rate-Limit"] = "12" return resp - view_func = view_factory(spec, {"get": get}) + view_func = view_factory({"get": get}) app.add_url_rule("/browse//", view_func=view_func) result = client.get("/browse/12/") diff --git a/tests/integration/contrib/requests/test_requests_validation.py b/tests/integration/contrib/requests/test_requests_validation.py index 9647edc5..69aa1c34 100644 --- a/tests/integration/contrib/requests/test_requests_validation.py +++ b/tests/integration/contrib/requests/test_requests_validation.py @@ -17,25 +17,25 @@ class TestV31RequestsFactory: @pytest.fixture - def spec(self, factory): + def schema_path(self, schema_path_factory): specfile = "contrib/requests/data/v3.1/requests_factory.yaml" - return factory.spec_from_file(specfile) + return schema_path_factory.from_file(specfile) @pytest.fixture - def request_unmarshaller(self, spec): - return V31RequestUnmarshaller(spec) + def request_unmarshaller(self, schema_path): + return V31RequestUnmarshaller(schema_path) @pytest.fixture - def response_unmarshaller(self, spec): - return V31ResponseUnmarshaller(spec) + def response_unmarshaller(self, schema_path): + return V31ResponseUnmarshaller(schema_path) @pytest.fixture - def webhook_request_unmarshaller(self, spec): - return V31WebhookRequestUnmarshaller(spec) + def webhook_request_unmarshaller(self, schema_path): + return V31WebhookRequestUnmarshaller(schema_path) @pytest.fixture - def webhook_response_unmarshaller(self, spec): - return V31WebhookResponseUnmarshaller(spec) + def webhook_response_unmarshaller(self, schema_path): + return V31WebhookResponseUnmarshaller(schema_path) @responses.activate def test_response_validator_path_pattern(self, response_unmarshaller): @@ -152,17 +152,17 @@ def api_key_encoded(self): class TestPetstore(BaseTestPetstore): @pytest.fixture - def spec(self, factory): + def schema_path(self, schema_path_factory): specfile = "data/v3.0/petstore.yaml" - return factory.spec_from_file(specfile) + return schema_path_factory.from_file(specfile) @pytest.fixture - def request_unmarshaller(self, spec): - return V30RequestUnmarshaller(spec) + def request_unmarshaller(self, schema_path): + return V30RequestUnmarshaller(schema_path) @pytest.fixture - def response_unmarshaller(self, spec): - return V30ResponseUnmarshaller(spec) + def response_unmarshaller(self, schema_path): + return V30ResponseUnmarshaller(schema_path) @responses.activate def test_response_binary_valid(self, response_unmarshaller, data_gif): diff --git a/tests/integration/contrib/starlette/data/v3.0/starletteproject/__main__.py b/tests/integration/contrib/starlette/data/v3.0/starletteproject/__main__.py index ee16e9c7..79b47802 100644 --- a/tests/integration/contrib/starlette/data/v3.0/starletteproject/__main__.py +++ b/tests/integration/contrib/starlette/data/v3.0/starletteproject/__main__.py @@ -1,7 +1,7 @@ from starlette.applications import Starlette from starlette.middleware import Middleware from starlette.routing import Route -from starletteproject.openapi import spec +from starletteproject.openapi import openapi from starletteproject.pets.endpoints import pet_detail_endpoint from starletteproject.pets.endpoints import pet_list_endpoint from starletteproject.pets.endpoints import pet_photo_endpoint @@ -13,7 +13,7 @@ middleware = [ Middleware( StarletteOpenAPIMiddleware, - spec=spec, + openapi=openapi, ), ] diff --git a/tests/integration/contrib/starlette/data/v3.0/starletteproject/openapi.py b/tests/integration/contrib/starlette/data/v3.0/starletteproject/openapi.py index ac65a703..4ca6d9fa 100644 --- a/tests/integration/contrib/starlette/data/v3.0/starletteproject/openapi.py +++ b/tests/integration/contrib/starlette/data/v3.0/starletteproject/openapi.py @@ -1,8 +1,9 @@ from pathlib import Path import yaml -from jsonschema_path import SchemaPath + +from openapi_core import OpenAPI openapi_spec_path = Path("tests/integration/data/v3.0/petstore.yaml") spec_dict = yaml.load(openapi_spec_path.read_text(), yaml.Loader) -spec = SchemaPath.from_dict(spec_dict) +openapi = OpenAPI.from_dict(spec_dict) diff --git a/tests/integration/contrib/starlette/data/v3.0/starletteproject/pets/endpoints.py b/tests/integration/contrib/starlette/data/v3.0/starletteproject/pets/endpoints.py index c569cad2..1ec8e17b 100644 --- a/tests/integration/contrib/starlette/data/v3.0/starletteproject/pets/endpoints.py +++ b/tests/integration/contrib/starlette/data/v3.0/starletteproject/pets/endpoints.py @@ -3,7 +3,6 @@ from starlette.responses import JSONResponse from starlette.responses import Response from starlette.responses import StreamingResponse -from starletteproject.openapi import spec from openapi_core import unmarshal_request from openapi_core import unmarshal_response diff --git a/tests/integration/contrib/starlette/test_starlette_validation.py b/tests/integration/contrib/starlette/test_starlette_validation.py index 09f4a96b..6bebcfbb 100644 --- a/tests/integration/contrib/starlette/test_starlette_validation.py +++ b/tests/integration/contrib/starlette/test_starlette_validation.py @@ -16,9 +16,9 @@ class TestV30StarletteFactory: @pytest.fixture - def spec(self, factory): + def schema_path(self, schema_path_factory): specfile = "contrib/starlette/data/v3.0/starlette_factory.yaml" - return factory.spec_from_file(specfile) + return schema_path_factory.from_file(specfile) @pytest.fixture def app(self): @@ -45,13 +45,13 @@ async def test_route(scope, receive, send): def client(self, app): return TestClient(app, base_url="http://localhost") - def test_request_validator_path_pattern(self, client, spec): + def test_request_validator_path_pattern(self, client, schema_path): response_data = {"data": "data"} async def test_route(request): body = await request.body() openapi_request = StarletteOpenAPIRequest(request, body) - result = unmarshal_request(openapi_request, spec) + result = unmarshal_request(openapi_request, schema_path) assert not result.errors return JSONResponse( response_data, @@ -81,7 +81,7 @@ async def test_route(request): assert response.status_code == 200 assert response.json() == response_data - def test_response_validator_path_pattern(self, client, spec): + def test_response_validator_path_pattern(self, client, schema_path): response_data = {"data": "data"} def test_route(request): @@ -94,7 +94,7 @@ def test_route(request): openapi_request = StarletteOpenAPIRequest(request) openapi_response = StarletteOpenAPIResponse(response) result = unmarshal_response( - openapi_request, openapi_response, spec + openapi_request, openapi_response, schema_path ) assert not result.errors return response diff --git a/tests/integration/contrib/werkzeug/test_werkzeug_validation.py b/tests/integration/contrib/werkzeug/test_werkzeug_validation.py index f2b36ec6..a2641ca8 100644 --- a/tests/integration/contrib/werkzeug/test_werkzeug_validation.py +++ b/tests/integration/contrib/werkzeug/test_werkzeug_validation.py @@ -14,9 +14,9 @@ class TestWerkzeugOpenAPIValidation: @pytest.fixture - def spec(self, factory): + def schema_path(self, schema_path_factory): specfile = "contrib/requests/data/v3.1/requests_factory.yaml" - return factory.spec_from_file(specfile) + return schema_path_factory.from_file(specfile) @pytest.fixture def app(self): @@ -39,7 +39,7 @@ def test_app(environ, start_response): def client(self, app): return Client(app) - def test_request_validator_root_path(self, client, spec): + def test_request_validator_root_path(self, client, schema_path): query_string = { "q": "string", } @@ -53,11 +53,11 @@ def test_request_validator_root_path(self, client, spec): headers=headers, ) openapi_request = WerkzeugOpenAPIRequest(response.request) - unmarshaller = V30RequestUnmarshaller(spec) + unmarshaller = V30RequestUnmarshaller(schema_path) result = unmarshaller.unmarshal(openapi_request) assert not result.errors - def test_request_validator_path_pattern(self, client, spec): + def test_request_validator_path_pattern(self, client, schema_path): query_string = { "q": "string", } @@ -71,12 +71,12 @@ def test_request_validator_path_pattern(self, client, spec): headers=headers, ) openapi_request = WerkzeugOpenAPIRequest(response.request) - unmarshaller = V30RequestUnmarshaller(spec) + unmarshaller = V30RequestUnmarshaller(schema_path) result = unmarshaller.unmarshal(openapi_request) assert not result.errors @responses.activate - def test_response_validator_path_pattern(self, client, spec): + def test_response_validator_path_pattern(self, client, schema_path): query_string = { "q": "string", } @@ -91,6 +91,6 @@ def test_response_validator_path_pattern(self, client, spec): ) openapi_request = WerkzeugOpenAPIRequest(response.request) openapi_response = WerkzeugOpenAPIResponse(response) - unmarshaller = V30ResponseUnmarshaller(spec) + unmarshaller = V30ResponseUnmarshaller(schema_path) result = unmarshaller.unmarshal(openapi_request, openapi_response) assert not result.errors diff --git a/tests/integration/schema/test_empty.py b/tests/integration/schema/test_empty.py deleted file mode 100644 index bf2c3132..00000000 --- a/tests/integration/schema/test_empty.py +++ /dev/null @@ -1,11 +0,0 @@ -import pytest -from openapi_spec_validator.validation.exceptions import ValidatorDetectError - -from openapi_core import Spec - - -class TestEmpty: - def test_raises_on_invalid(self): - with pytest.warns(DeprecationWarning): - with pytest.raises(ValidatorDetectError): - Spec.from_dict("") diff --git a/tests/integration/schema/test_link_spec.py b/tests/integration/schema/test_link_spec.py index 7e519f9b..2abb5b75 100644 --- a/tests/integration/schema/test_link_spec.py +++ b/tests/integration/schema/test_link_spec.py @@ -9,9 +9,9 @@ class TestLinkSpec: "data/v3.1/links.yaml", ], ) - def test_no_param(self, spec_file, factory): - spec = factory.spec_from_file(spec_file) - resp = spec / "paths#/status#get#responses#default" + def test_no_param(self, spec_file, schema_path_factory): + schema_path = schema_path_factory.from_file(spec_file) + resp = schema_path / "paths#/status#get#responses#default" links = resp / "links" assert len(links) == 1 @@ -29,9 +29,9 @@ def test_no_param(self, spec_file, factory): "data/v3.1/links.yaml", ], ) - def test_param(self, spec_file, factory): - spec = factory.spec_from_file(spec_file) - resp = spec / "paths#/status/{resourceId}#get#responses#default" + def test_param(self, spec_file, schema_path_factory): + schema_path = schema_path_factory.from_file(spec_file) + resp = schema_path / "paths#/status/{resourceId}#get#responses#default" links = resp / "links" assert len(links) == 1 diff --git a/tests/integration/schema/test_path_params.py b/tests/integration/schema/test_path_params.py index 34ed7d05..20d3e6d9 100644 --- a/tests/integration/schema/test_path_params.py +++ b/tests/integration/schema/test_path_params.py @@ -9,10 +9,10 @@ class TestMinimal: "data/v3.1/path_param.yaml", ], ) - def test_param_present(self, spec_file, factory): - spec = factory.spec_from_file(spec_file) + def test_param_present(self, spec_file, schema_path_factory): + schema_path = schema_path_factory.from_file(spec_file) - path = spec / "paths#/resource/{resId}" + path = schema_path / "paths#/resource/{resId}" parameters = path / "parameters" assert len(parameters) == 1 diff --git a/tests/integration/schema/test_spec.py b/tests/integration/schema/test_spec.py index 60eff027..56f14c29 100644 --- a/tests/integration/schema/test_spec.py +++ b/tests/integration/schema/test_spec.py @@ -23,26 +23,26 @@ def base_uri(self): return "file://tests/integration/data/v3.0/petstore.yaml" @pytest.fixture - def spec_dict(self, factory): - content, _ = factory.content_from_file("data/v3.0/petstore.yaml") + def spec_dict(self, content_factory): + content, _ = content_factory.from_file("data/v3.0/petstore.yaml") return content @pytest.fixture - def spec(self, spec_dict, base_uri): + def schema_path(self, spec_dict, base_uri): return SchemaPath.from_dict(spec_dict, base_uri=base_uri) @pytest.fixture - def request_validator(self, spec): - return V30RequestValidator(spec) + def request_validator(self, schema_path): + return V30RequestValidator(schema_path) @pytest.fixture - def response_validator(self, spec): - return V30ResponseValidator(spec) + def response_validator(self, schema_path): + return V30ResponseValidator(schema_path) - def test_spec(self, spec, spec_dict): + def test_spec(self, schema_path, spec_dict): url = "http://petstore.swagger.io/v1" - info = spec / "info" + info = schema_path / "info" info_spec = spec_dict["info"] assert info["title"] == info_spec["title"] assert info["description"] == info_spec["description"] @@ -60,16 +60,16 @@ def test_spec(self, spec, spec_dict): assert license["name"] == license_spec["name"] assert license["url"] == license_spec["url"] - security = spec / "security" + security = schema_path / "security" security_spec = spec_dict.get("security", []) for idx, security_reqs in enumerate(security): security_reqs_spec = security_spec[idx] for scheme_name, security_req in security_reqs.items(): security_req == security_reqs_spec[scheme_name] - assert get_spec_url(spec) == url + assert get_spec_url(schema_path) == url - servers = spec / "servers" + servers = schema_path / "servers" for idx, server in enumerate(servers): server_spec = spec_dict["servers"][idx] assert server["url"] == server_spec["url"] @@ -81,7 +81,7 @@ def test_spec(self, spec, spec_dict): assert variable["default"] == variable_spec["default"] assert variable["enum"] == variable_spec.get("enum") - paths = spec / "paths" + paths = schema_path / "paths" for path_name, path in paths.items(): path_spec = spec_dict["paths"][path_name] assert path.getkey("summary") == path_spec.get("summary") @@ -287,7 +287,7 @@ def test_spec(self, spec, spec_dict): "required" ) - components = spec.get("components") + components = schema_path.get("components") if not components: return @@ -312,14 +312,14 @@ def base_uri(self): return "file://tests/integration/data/v3.1/webhook-example.yaml" @pytest.fixture - def spec_dict(self, factory): - content, _ = factory.content_from_file( + def spec_dict(self, content_factory): + content, _ = content_factory.from_file( "data/v3.1/webhook-example.yaml" ) return content @pytest.fixture - def spec(self, spec_dict, base_uri): + def schema_path(self, spec_dict, base_uri): return SchemaPath.from_dict( spec_dict, base_uri=base_uri, @@ -333,17 +333,17 @@ def request_validator(self, spec): def response_validator(self, spec): return ResponseValidator(spec) - def test_spec(self, spec, spec_dict): - info = spec / "info" + def test_spec(self, schema_path, spec_dict): + info = schema_path / "info" info_spec = spec_dict["info"] assert info["title"] == info_spec["title"] assert info["version"] == info_spec["version"] - webhooks = spec / "webhooks" + webhooks = schema_path / "webhooks" webhooks_spec = spec_dict["webhooks"] assert webhooks["newPet"] == webhooks_spec["newPet"] - components = spec.get("components") + components = schema_path.get("components") if not components: return diff --git a/tests/integration/test_minimal.py b/tests/integration/test_minimal.py index 6575e06a..8d80c3d2 100644 --- a/tests/integration/test_minimal.py +++ b/tests/integration/test_minimal.py @@ -25,8 +25,8 @@ class TestMinimal: @pytest.mark.parametrize("server", servers) @pytest.mark.parametrize("spec_path", spec_paths) - def test_hosts(self, factory, server, spec_path): - spec = factory.spec_from_file(spec_path) + def test_hosts(self, schema_path_factory, server, spec_path): + spec = schema_path_factory.from_file(spec_path) request = MockRequest(server, "get", "/status") result = unmarshal_request(request, spec=spec) @@ -35,8 +35,8 @@ def test_hosts(self, factory, server, spec_path): @pytest.mark.parametrize("server", servers) @pytest.mark.parametrize("spec_path", spec_paths) - def test_invalid_operation(self, factory, server, spec_path): - spec = factory.spec_from_file(spec_path) + def test_invalid_operation(self, schema_path_factory, server, spec_path): + spec = schema_path_factory.from_file(spec_path) request = MockRequest(server, "post", "/status") with pytest.raises(OperationNotFound): @@ -44,8 +44,8 @@ def test_invalid_operation(self, factory, server, spec_path): @pytest.mark.parametrize("server", servers) @pytest.mark.parametrize("spec_path", spec_paths) - def test_invalid_path(self, factory, server, spec_path): - spec = factory.spec_from_file(spec_path) + def test_invalid_path(self, schema_path_factory, server, spec_path): + spec = schema_path_factory.from_file(spec_path) request = MockRequest(server, "get", "/nonexistent") with pytest.raises(PathNotFound): diff --git a/tests/integration/unmarshalling/test_read_only_write_only.py b/tests/integration/unmarshalling/test_read_only_write_only.py index d8727cac..6297654e 100644 --- a/tests/integration/unmarshalling/test_read_only_write_only.py +++ b/tests/integration/unmarshalling/test_read_only_write_only.py @@ -16,18 +16,18 @@ @pytest.fixture(scope="class") -def spec(factory): - return factory.spec_from_file("data/v3.0/read_only_write_only.yaml") +def schema_path(schema_path_factory): + return schema_path_factory.from_file("data/v3.0/read_only_write_only.yaml") @pytest.fixture(scope="class") -def request_unmarshaller(spec): - return V30RequestUnmarshaller(spec) +def request_unmarshaller(schema_path): + return V30RequestUnmarshaller(schema_path) @pytest.fixture(scope="class") -def response_unmarshaller(spec): - return V30ResponseUnmarshaller(spec) +def response_unmarshaller(schema_path): + return V30ResponseUnmarshaller(schema_path) class TestReadOnly: diff --git a/tests/integration/unmarshalling/test_security_override.py b/tests/integration/unmarshalling/test_security_override.py index 40efa6d1..8e549d6a 100644 --- a/tests/integration/unmarshalling/test_security_override.py +++ b/tests/integration/unmarshalling/test_security_override.py @@ -11,13 +11,13 @@ @pytest.fixture(scope="class") -def spec(factory): - return factory.spec_from_file("data/v3.0/security_override.yaml") +def schema_path(schema_path_factory): + return schema_path_factory.from_file("data/v3.0/security_override.yaml") @pytest.fixture(scope="class") -def request_unmarshaller(spec): - return V30RequestUnmarshaller(spec) +def request_unmarshaller(schema_path): + return V30RequestUnmarshaller(schema_path) class TestSecurityOverride: diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 63fad9df..cb19dafb 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,15 +1,37 @@ +from json import dumps +from os import unlink +from tempfile import NamedTemporaryFile + import pytest from jsonschema_path import SchemaPath @pytest.fixture def spec_v20(): - return SchemaPath.from_dict({"swagger": "2.0"}) + return SchemaPath.from_dict( + { + "swagger": "2.0", + "info": { + "title": "Spec", + "version": "0.0.1", + }, + "paths": {}, + } + ) @pytest.fixture def spec_v30(): - return SchemaPath.from_dict({"openapi": "3.0.0"}) + return SchemaPath.from_dict( + { + "openapi": "3.0.0", + "info": { + "title": "Spec", + "version": "0.0.1", + }, + "paths": {}, + } + ) @pytest.fixture @@ -29,3 +51,19 @@ def spec_v31(): @pytest.fixture def spec_invalid(): return SchemaPath.from_dict({}) + + +@pytest.fixture +def create_file(): + files = [] + + def create(schema): + contents = dumps(schema).encode("utf-8") + with NamedTemporaryFile(delete=False) as tf: + files.append(tf) + tf.write(contents) + return tf.name + + yield create + for tf in files: + unlink(tf.name) diff --git a/tests/unit/test_app.py b/tests/unit/test_app.py new file mode 100644 index 00000000..a98f7a8b --- /dev/null +++ b/tests/unit/test_app.py @@ -0,0 +1,77 @@ +from pathlib import Path + +import pytest + +from openapi_core import Config +from openapi_core import OpenAPI +from openapi_core.exceptions import SpecError + + +class TestOpenAPIFromPath: + def test_valid(self, create_file): + spec_dict = { + "openapi": "3.1.0", + "info": { + "title": "Spec", + "version": "0.0.1", + }, + "paths": {}, + } + file_path = create_file(spec_dict) + path = Path(file_path) + result = OpenAPI.from_path(path) + + assert type(result) == OpenAPI + assert result.spec.contents() == spec_dict + + +class TestOpenAPIFromFilePath: + def test_valid(self, create_file): + spec_dict = { + "openapi": "3.1.0", + "info": { + "title": "Spec", + "version": "0.0.1", + }, + "paths": {}, + } + file_path = create_file(spec_dict) + result = OpenAPI.from_file_path(file_path) + + assert type(result) == OpenAPI + assert result.spec.contents() == spec_dict + + +class TestOpenAPIFromFile: + def test_valid(self, create_file): + spec_dict = { + "openapi": "3.1.0", + "info": { + "title": "Spec", + "version": "0.0.1", + }, + "paths": {}, + } + file_path = create_file(spec_dict) + with open(file_path) as f: + result = OpenAPI.from_file(f) + + assert type(result) == OpenAPI + assert result.spec.contents() == spec_dict + + +class TestOpenAPIFromDict: + def test_spec_error(self): + spec_dict = {} + + with pytest.raises(SpecError): + OpenAPI.from_dict(spec_dict) + + def test_check_skipped(self): + spec_dict = {} + config = Config(spec_validator_cls=None) + + result = OpenAPI.from_dict(spec_dict, config=config) + + assert type(result) == OpenAPI + assert result.spec.contents() == spec_dict diff --git a/tests/unit/test_paths_spec.py b/tests/unit/test_paths_spec.py index f93dae47..8167abf3 100644 --- a/tests/unit/test_paths_spec.py +++ b/tests/unit/test_paths_spec.py @@ -1,26 +1,11 @@ import pytest -from openapi_spec_validator import openapi_v31_spec_validator -from openapi_spec_validator.validation.exceptions import OpenAPIValidationError from openapi_core import Spec class TestSpecFromDict: - def test_validator(self): + def test_deprecated(self): schema = {} with pytest.warns(DeprecationWarning): - with pytest.raises(OpenAPIValidationError): - Spec.from_dict(schema, validator=openapi_v31_spec_validator) - - def test_validator_none(self): - schema = {} - - with pytest.warns(DeprecationWarning): - Spec.from_dict(schema, validator=None) - - def test_spec_validator_cls_none(self): - schema = {} - - with pytest.warns(DeprecationWarning): - Spec.from_dict(schema, spec_validator_cls=None) + Spec.from_dict(schema) diff --git a/tests/unit/test_shortcuts.py b/tests/unit/test_shortcuts.py index 1d69c69e..963c4658 100644 --- a/tests/unit/test_shortcuts.py +++ b/tests/unit/test_shortcuts.py @@ -537,7 +537,7 @@ def test_request(self, mock_validate, spec_v31): result = validate_apicall_request(request, spec=spec_v31) - assert result == mock_validate.return_value + assert result is None mock_validate.assert_called_once_with(request) @@ -588,7 +588,7 @@ def test_request(self, mock_validate, spec_v31): result = validate_webhook_request(request, spec=spec_v31) - assert result == mock_validate.return_value + assert result is None mock_validate.assert_called_once_with(request) @@ -812,7 +812,7 @@ def test_request_response(self, mock_validate, spec_v31): result = validate_apicall_response(request, response, spec=spec_v31) - assert result == mock_validate.return_value + assert result is None mock_validate.assert_called_once_with(request, response) @@ -879,7 +879,7 @@ def test_request_response(self, mock_validate, spec_v31): result = validate_webhook_response(request, response, spec=spec_v31) - assert result == mock_validate.return_value + assert result is None mock_validate.assert_called_once_with(request, response)