From e3ec0e49f4506e232e67131755b012912e35a2a3 Mon Sep 17 00:00:00 2001 From: p1c2u Date: Sat, 24 Sep 2022 22:54:05 +0100 Subject: [PATCH] Add any-of Co-authored-by: Coen van der Kamp Co-authored-by: Sigurd Spieckermann --- openapi_core/schema/schemas.py | 6 - .../unmarshalling/schemas/unmarshallers.py | 125 +++++++++++------ tests/unit/unmarshalling/test_unmarshal.py | 59 ++++++++ tests/unit/unmarshalling/test_validate.py | 126 ++++++++++++++++++ 4 files changed, 266 insertions(+), 50 deletions(-) diff --git a/openapi_core/schema/schemas.py b/openapi_core/schema/schemas.py index b7737374..9cdc2e92 100644 --- a/openapi_core/schema/schemas.py +++ b/openapi_core/schema/schemas.py @@ -1,6 +1,5 @@ from typing import Any from typing import Dict -from typing import Set from openapi_core.spec import Spec @@ -17,8 +16,3 @@ def get_all_properties(schema: Spec) -> Dict[str, Any]: properties_dict.update(subschema_props) return properties_dict - - -def get_all_properties_names(schema: Spec) -> Set[str]: - all_properties = get_all_properties(schema) - return set(all_properties.keys()) diff --git a/openapi_core/unmarshalling/schemas/unmarshallers.py b/openapi_core/unmarshalling/schemas/unmarshallers.py index 6c855cff..2fe4539c 100644 --- a/openapi_core/unmarshalling/schemas/unmarshallers.py +++ b/openapi_core/unmarshalling/schemas/unmarshallers.py @@ -19,7 +19,6 @@ from openapi_core.extensions.models.factories import ModelClassImporter from openapi_core.schema.schemas import get_all_properties -from openapi_core.schema.schemas import get_all_properties_names from openapi_core.spec import Spec from openapi_core.unmarshalling.schemas.datatypes import FormattersDict from openapi_core.unmarshalling.schemas.enums import UnmarshalContext @@ -201,6 +200,15 @@ def object_class_factory(self) -> ModelClassImporter: return ModelClassImporter() def unmarshal(self, value: Any) -> Any: + properties = self.unmarshal_raw(value) + + model = self.schema.getkey("x-model") + fields: Iterable[str] = properties and properties.keys() or [] + object_class = self.object_class_factory.create(fields, model=model) + + return object_class(**properties) + + def unmarshal_raw(self, value: Any) -> Any: try: value = self.formatter.unmarshal(value) except ValueError as exc: @@ -209,65 +217,57 @@ def unmarshal(self, value: Any) -> Any: else: return self._unmarshal_object(value) + def _clone(self, schema: Spec) -> "ObjectUnmarshaller": + return ObjectUnmarshaller( + schema, + self.validator, + self.formatter, + self.unmarshallers_factory, + self.context, + ) + def _unmarshal_object(self, value: Any) -> Any: + properties = {} + if "oneOf" in self.schema: - properties = None + one_of_properties = None for one_of_schema in self.schema / "oneOf": try: - unmarshalled = self._unmarshal_properties( - value, one_of_schema + unmarshalled = self._clone(one_of_schema).unmarshal_raw( + value ) except (UnmarshalError, ValueError): pass else: - if properties is not None: + if one_of_properties is not None: log.warning("multiple valid oneOf schemas found") continue - properties = unmarshalled + one_of_properties = unmarshalled - if properties is None: + if one_of_properties is None: log.warning("valid oneOf schema not found") + else: + properties.update(one_of_properties) - else: - properties = self._unmarshal_properties(value) - - model = self.schema.getkey("x-model") - fields: Iterable[str] = properties and properties.keys() or [] - object_class = self.object_class_factory.create(fields, model=model) - - return object_class(**properties) - - def _unmarshal_properties( - self, value: Any, one_of_schema: Optional[Spec] = None - ) -> Dict[str, Any]: - all_props = get_all_properties(self.schema) - all_props_names = get_all_properties_names(self.schema) - - if one_of_schema is not None: - all_props.update(get_all_properties(one_of_schema)) - all_props_names |= get_all_properties_names(one_of_schema) - - value_props_names = list(value.keys()) - extra_props = set(value_props_names) - set(all_props_names) + elif "anyOf" in self.schema: + any_of_properties = None + for any_of_schema in self.schema / "anyOf": + try: + unmarshalled = self._clone(any_of_schema).unmarshal_raw( + value + ) + except (UnmarshalError, ValueError): + pass + else: + any_of_properties = unmarshalled + break - properties: Dict[str, Any] = {} - additional_properties = self.schema.getkey( - "additionalProperties", True - ) - if additional_properties is not False: - # free-form object - if additional_properties is True: - additional_prop_schema = Spec.from_dict({}) - # defined schema + if any_of_properties is None: + log.warning("valid anyOf schema not found") else: - additional_prop_schema = self.schema / "additionalProperties" - for prop_name in extra_props: - prop_value = value[prop_name] - properties[prop_name] = self.unmarshallers_factory.create( - additional_prop_schema - )(prop_value) + properties.update(any_of_properties) - for prop_name, prop in list(all_props.items()): + for prop_name, prop in get_all_properties(self.schema).items(): read_only = prop.getkey("readOnly", False) if self.context == UnmarshalContext.REQUEST and read_only: continue @@ -285,6 +285,24 @@ def _unmarshal_properties( prop_value ) + additional_properties = self.schema.getkey( + "additionalProperties", True + ) + if additional_properties is not False: + # free-form object + if additional_properties is True: + additional_prop_schema = Spec.from_dict({}) + # defined schema + else: + additional_prop_schema = self.schema / "additionalProperties" + additional_prop_unmarshaler = self.unmarshallers_factory.create( + additional_prop_schema + ) + for prop_name, prop_value in value.items(): + if prop_name in properties: + continue + properties[prop_name] = additional_prop_unmarshaler(prop_value) + return properties @@ -304,6 +322,10 @@ def unmarshal(self, value: Any) -> Any: if one_of_schema: return self.unmarshallers_factory.create(one_of_schema)(value) + any_of_schema = self._get_any_of_schema(value) + if any_of_schema: + return self.unmarshallers_factory.create(any_of_schema)(value) + all_of_schema = self._get_all_of_schema(value) if all_of_schema: return self.unmarshallers_factory.create(all_of_schema)(value) @@ -338,6 +360,21 @@ def _get_one_of_schema(self, value: Any) -> Optional[Spec]: return subschema return None + def _get_any_of_schema(self, value: Any) -> Optional[Spec]: + if "anyOf" not in self.schema: + return None + + any_of_schemas = self.schema / "anyOf" + for subschema in any_of_schemas: + unmarshaller = self.unmarshallers_factory.create(subschema) + try: + unmarshaller.validate(value) + except ValidateError: + continue + else: + return subschema + return None + def _get_all_of_schema(self, value: Any) -> Optional[Spec]: if "allOf" not in self.schema: return None diff --git a/tests/unit/unmarshalling/test_unmarshal.py b/tests/unit/unmarshalling/test_unmarshal.py index 3b33e133..cc332f1b 100644 --- a/tests/unit/unmarshalling/test_unmarshal.py +++ b/tests/unit/unmarshalling/test_unmarshal.py @@ -560,6 +560,65 @@ def test_schema_any_one_of(self, unmarshaller_factory): spec = Spec.from_dict(schema) assert unmarshaller_factory(spec)(["hello"]) == ["hello"] + def test_schema_any_any_of(self, unmarshaller_factory): + schema = { + "anyOf": [ + { + "type": "string", + }, + { + "type": "array", + "items": { + "type": "string", + }, + }, + ], + } + spec = Spec.from_dict(schema) + assert unmarshaller_factory(spec)(["hello"]) == ["hello"] + + def test_schema_object_any_of(self, unmarshaller_factory): + schema = { + "type": "object", + "anyOf": [ + { + "type": "object", + "required": ["someint"], + "properties": {"someint": {"type": "integer"}}, + }, + { + "type": "object", + "required": ["somestr"], + "properties": {"somestr": {"type": "string"}}, + }, + ], + } + spec = Spec.from_dict(schema) + result = unmarshaller_factory(spec)({"someint": 1}) + + assert is_dataclass(result) + assert result.someint == 1 + + def test_schema_object_any_of_invalid(self, unmarshaller_factory): + schema = { + "type": "object", + "anyOf": [ + { + "type": "object", + "required": ["someint"], + "properties": {"someint": {"type": "integer"}}, + }, + { + "type": "object", + "required": ["somestr"], + "properties": {"somestr": {"type": "string"}}, + }, + ], + } + spec = Spec.from_dict(schema) + with pytest.raises(UnmarshalError): + unmarshaller_factory(spec)({"someint": "1"}) + def test_schema_any_all_of(self, unmarshaller_factory): schema = { "allOf": [ diff --git a/tests/unit/unmarshalling/test_validate.py b/tests/unit/unmarshalling/test_validate.py index 07547d10..9ad18fa5 100644 --- a/tests/unit/unmarshalling/test_validate.py +++ b/tests/unit/unmarshalling/test_validate.py @@ -863,6 +863,132 @@ def test_unambiguous_one_of(self, value, validator_factory): assert result is None + @pytest.mark.parametrize( + "value", + [ + {}, + ], + ) + def test_object_multiple_any_of(self, value, validator_factory): + any_of = [ + { + "type": "object", + }, + { + "type": "object", + }, + ] + schema = { + "type": "object", + "anyOf": any_of, + } + spec = Spec.from_dict(schema) + + result = validator_factory(spec).validate(value) + + assert result is None + + @pytest.mark.parametrize( + "value", + [ + {}, + ], + ) + def test_object_different_type_any_of(self, value, validator_factory): + any_of = [{"type": "integer"}, {"type": "string"}] + schema = { + "type": "object", + "anyOf": any_of, + } + spec = Spec.from_dict(schema) + + with pytest.raises(InvalidSchemaValue): + validator_factory(spec).validate(value) + + @pytest.mark.parametrize( + "value", + [ + {}, + ], + ) + def test_object_no_any_of(self, value, validator_factory): + any_of = [ + { + "type": "object", + "required": ["test1"], + "properties": { + "test1": { + "type": "string", + }, + }, + }, + { + "type": "object", + "required": ["test2"], + "properties": { + "test2": { + "type": "string", + }, + }, + }, + ] + schema = { + "type": "object", + "anyOf": any_of, + } + spec = Spec.from_dict(schema) + + with pytest.raises(InvalidSchemaValue): + validator_factory(spec).validate(value) + + @pytest.mark.parametrize( + "value", + [ + { + "foo": "FOO", + }, + { + "foo": "FOO", + "bar": "BAR", + }, + ], + ) + def test_unambiguous_any_of(self, value, validator_factory): + any_of = [ + { + "type": "object", + "required": ["foo"], + "properties": { + "foo": { + "type": "string", + }, + }, + "additionalProperties": False, + }, + { + "type": "object", + "required": ["foo", "bar"], + "properties": { + "foo": { + "type": "string", + }, + "bar": { + "type": "string", + }, + }, + "additionalProperties": False, + }, + ] + schema = { + "type": "object", + "anyOf": any_of, + } + spec = Spec.from_dict(schema) + + result = validator_factory(spec).validate(value) + + assert result is None + @pytest.mark.parametrize( "value", [