Skip to content

anyOf support #423

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Sep 25, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 0 additions & 6 deletions openapi_core/schema/schemas.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import Any
from typing import Dict
from typing import Set

from openapi_core.spec import Spec

Expand All @@ -17,8 +16,3 @@ def get_all_properties(schema: Spec) -> Dict[str, Any]:
properties_dict.update(subschema_props)

return properties_dict


def get_all_properties_names(schema: Spec) -> Set[str]:
all_properties = get_all_properties(schema)
return set(all_properties.keys())
125 changes: 81 additions & 44 deletions openapi_core/unmarshalling/schemas/unmarshallers.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@

from openapi_core.extensions.models.factories import ModelClassImporter
from openapi_core.schema.schemas import get_all_properties
from openapi_core.schema.schemas import get_all_properties_names
from openapi_core.spec import Spec
from openapi_core.unmarshalling.schemas.datatypes import FormattersDict
from openapi_core.unmarshalling.schemas.enums import UnmarshalContext
Expand Down Expand Up @@ -201,6 +200,15 @@ def object_class_factory(self) -> ModelClassImporter:
return ModelClassImporter()

def unmarshal(self, value: Any) -> Any:
properties = self.unmarshal_raw(value)

model = self.schema.getkey("x-model")
fields: Iterable[str] = properties and properties.keys() or []
object_class = self.object_class_factory.create(fields, model=model)

return object_class(**properties)

def unmarshal_raw(self, value: Any) -> Any:
try:
value = self.formatter.unmarshal(value)
except ValueError as exc:
Expand All @@ -209,65 +217,57 @@ def unmarshal(self, value: Any) -> Any:
else:
return self._unmarshal_object(value)

def _clone(self, schema: Spec) -> "ObjectUnmarshaller":
return ObjectUnmarshaller(
schema,
self.validator,
self.formatter,
self.unmarshallers_factory,
self.context,
)

def _unmarshal_object(self, value: Any) -> Any:
properties = {}

if "oneOf" in self.schema:
properties = None
one_of_properties = None
for one_of_schema in self.schema / "oneOf":
try:
unmarshalled = self._unmarshal_properties(
value, one_of_schema
unmarshalled = self._clone(one_of_schema).unmarshal_raw(
value
)
except (UnmarshalError, ValueError):
pass
else:
if properties is not None:
if one_of_properties is not None:
log.warning("multiple valid oneOf schemas found")
continue
properties = unmarshalled
one_of_properties = unmarshalled

if properties is None:
if one_of_properties is None:
log.warning("valid oneOf schema not found")
else:
properties.update(one_of_properties)

else:
properties = self._unmarshal_properties(value)

model = self.schema.getkey("x-model")
fields: Iterable[str] = properties and properties.keys() or []
object_class = self.object_class_factory.create(fields, model=model)

return object_class(**properties)

def _unmarshal_properties(
self, value: Any, one_of_schema: Optional[Spec] = None
) -> Dict[str, Any]:
all_props = get_all_properties(self.schema)
all_props_names = get_all_properties_names(self.schema)

if one_of_schema is not None:
all_props.update(get_all_properties(one_of_schema))
all_props_names |= get_all_properties_names(one_of_schema)

value_props_names = list(value.keys())
extra_props = set(value_props_names) - set(all_props_names)
elif "anyOf" in self.schema:
any_of_properties = None
for any_of_schema in self.schema / "anyOf":
try:
unmarshalled = self._clone(any_of_schema).unmarshal_raw(
value
)
except (UnmarshalError, ValueError):
pass
else:
any_of_properties = unmarshalled
break

properties: Dict[str, Any] = {}
additional_properties = self.schema.getkey(
"additionalProperties", True
)
if additional_properties is not False:
# free-form object
if additional_properties is True:
additional_prop_schema = Spec.from_dict({})
# defined schema
if any_of_properties is None:
log.warning("valid anyOf schema not found")
else:
additional_prop_schema = self.schema / "additionalProperties"
for prop_name in extra_props:
prop_value = value[prop_name]
properties[prop_name] = self.unmarshallers_factory.create(
additional_prop_schema
)(prop_value)
properties.update(any_of_properties)

for prop_name, prop in list(all_props.items()):
for prop_name, prop in get_all_properties(self.schema).items():
read_only = prop.getkey("readOnly", False)
if self.context == UnmarshalContext.REQUEST and read_only:
continue
Expand All @@ -285,6 +285,24 @@ def _unmarshal_properties(
prop_value
)

additional_properties = self.schema.getkey(
"additionalProperties", True
)
if additional_properties is not False:
# free-form object
if additional_properties is True:
additional_prop_schema = Spec.from_dict({})
# defined schema
else:
additional_prop_schema = self.schema / "additionalProperties"
additional_prop_unmarshaler = self.unmarshallers_factory.create(
additional_prop_schema
)
for prop_name, prop_value in value.items():
if prop_name in properties:
continue
properties[prop_name] = additional_prop_unmarshaler(prop_value)

return properties


Expand All @@ -304,6 +322,10 @@ def unmarshal(self, value: Any) -> Any:
if one_of_schema:
return self.unmarshallers_factory.create(one_of_schema)(value)

any_of_schema = self._get_any_of_schema(value)
if any_of_schema:
return self.unmarshallers_factory.create(any_of_schema)(value)

all_of_schema = self._get_all_of_schema(value)
if all_of_schema:
return self.unmarshallers_factory.create(all_of_schema)(value)
Expand Down Expand Up @@ -338,6 +360,21 @@ def _get_one_of_schema(self, value: Any) -> Optional[Spec]:
return subschema
return None

def _get_any_of_schema(self, value: Any) -> Optional[Spec]:
if "anyOf" not in self.schema:
return None

any_of_schemas = self.schema / "anyOf"
for subschema in any_of_schemas:
unmarshaller = self.unmarshallers_factory.create(subschema)
try:
unmarshaller.validate(value)
except ValidateError:
continue
else:
return subschema
return None

def _get_all_of_schema(self, value: Any) -> Optional[Spec]:
if "allOf" not in self.schema:
return None
Expand Down
59 changes: 59 additions & 0 deletions tests/unit/unmarshalling/test_unmarshal.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,65 @@ def test_schema_any_one_of(self, unmarshaller_factory):
spec = Spec.from_dict(schema)
assert unmarshaller_factory(spec)(["hello"]) == ["hello"]

def test_schema_any_any_of(self, unmarshaller_factory):
schema = {
"anyOf": [
{
"type": "string",
},
{
"type": "array",
"items": {
"type": "string",
},
},
],
}
spec = Spec.from_dict(schema)
assert unmarshaller_factory(spec)(["hello"]) == ["hello"]

def test_schema_object_any_of(self, unmarshaller_factory):
schema = {
"type": "object",
"anyOf": [
{
"type": "object",
"required": ["someint"],
"properties": {"someint": {"type": "integer"}},
},
{
"type": "object",
"required": ["somestr"],
"properties": {"somestr": {"type": "string"}},
},
],
}
spec = Spec.from_dict(schema)
result = unmarshaller_factory(spec)({"someint": 1})

assert is_dataclass(result)
assert result.someint == 1

def test_schema_object_any_of_invalid(self, unmarshaller_factory):
schema = {
"type": "object",
"anyOf": [
{
"type": "object",
"required": ["someint"],
"properties": {"someint": {"type": "integer"}},
},
{
"type": "object",
"required": ["somestr"],
"properties": {"somestr": {"type": "string"}},
},
],
}
spec = Spec.from_dict(schema)
with pytest.raises(UnmarshalError):
unmarshaller_factory(spec)({"someint": "1"})

def test_schema_any_all_of(self, unmarshaller_factory):
schema = {
"allOf": [
Expand Down
126 changes: 126 additions & 0 deletions tests/unit/unmarshalling/test_validate.py
Original file line number Diff line number Diff line change
Expand Up @@ -863,6 +863,132 @@ def test_unambiguous_one_of(self, value, validator_factory):

assert result is None

@pytest.mark.parametrize(
"value",
[
{},
],
)
def test_object_multiple_any_of(self, value, validator_factory):
any_of = [
{
"type": "object",
},
{
"type": "object",
},
]
schema = {
"type": "object",
"anyOf": any_of,
}
spec = Spec.from_dict(schema)

result = validator_factory(spec).validate(value)

assert result is None

@pytest.mark.parametrize(
"value",
[
{},
],
)
def test_object_different_type_any_of(self, value, validator_factory):
any_of = [{"type": "integer"}, {"type": "string"}]
schema = {
"type": "object",
"anyOf": any_of,
}
spec = Spec.from_dict(schema)

with pytest.raises(InvalidSchemaValue):
validator_factory(spec).validate(value)

@pytest.mark.parametrize(
"value",
[
{},
],
)
def test_object_no_any_of(self, value, validator_factory):
any_of = [
{
"type": "object",
"required": ["test1"],
"properties": {
"test1": {
"type": "string",
},
},
},
{
"type": "object",
"required": ["test2"],
"properties": {
"test2": {
"type": "string",
},
},
},
]
schema = {
"type": "object",
"anyOf": any_of,
}
spec = Spec.from_dict(schema)

with pytest.raises(InvalidSchemaValue):
validator_factory(spec).validate(value)

@pytest.mark.parametrize(
"value",
[
{
"foo": "FOO",
},
{
"foo": "FOO",
"bar": "BAR",
},
],
)
def test_unambiguous_any_of(self, value, validator_factory):
any_of = [
{
"type": "object",
"required": ["foo"],
"properties": {
"foo": {
"type": "string",
},
},
"additionalProperties": False,
},
{
"type": "object",
"required": ["foo", "bar"],
"properties": {
"foo": {
"type": "string",
},
"bar": {
"type": "string",
},
},
"additionalProperties": False,
},
]
schema = {
"type": "object",
"anyOf": any_of,
}
spec = Spec.from_dict(schema)

result = validator_factory(spec).validate(value)

assert result is None

@pytest.mark.parametrize(
"value",
[
Expand Down