Skip to content

Support for recursive messages #130

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Aug 30, 2020
60 changes: 51 additions & 9 deletions src/betterproto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -428,6 +428,7 @@ class ProtoClassMetadata:
"cls_by_field",
"field_name_by_number",
"meta_by_field_name",
"sorted_field_names",
)

def __init__(self, cls: Type["Message"]):
Expand All @@ -453,6 +454,9 @@ def __init__(self, cls: Type["Message"]):
self.oneof_field_by_group = by_group
self.field_name_by_number = by_field_number
self.meta_by_field_name = by_field_name
self.sorted_field_names = tuple(
by_field_number[number] for number in sorted(by_field_number.keys())
)

self.default_gen = self._get_default_gen(cls, fields)
self.cls_by_field = self._get_cls_by_field(cls, fields)
Expand Down Expand Up @@ -513,23 +517,63 @@ def __post_init__(self) -> None:
if meta.group:
group_current.setdefault(meta.group)

if getattr(self, field_name) != PLACEHOLDER:
# Skip anything not set to the sentinel value
if self.__raw_get(field_name) != PLACEHOLDER:
# Found a non-sentinel value
all_sentinel = False

if meta.group:
# This was set, so make it the selected value of the one-of.
group_current[meta.group] = field_name

continue

setattr(self, field_name, self._get_field_default(field_name))

# Now that all the defaults are set, reset it!
self.__dict__["_serialized_on_wire"] = not all_sentinel
self.__dict__["_unknown_fields"] = b""
self.__dict__["_group_current"] = group_current

def __raw_get(self, name: str) -> Any:
return super().__getattribute__(name)

def __eq__(self, other) -> bool:
if type(self) is not type(other):
return False

for field_name in self._betterproto.meta_by_field_name:
self_val = self.__raw_get(field_name)
other_val = other.__raw_get(field_name)
if self_val is PLACEHOLDER:
if other_val is PLACEHOLDER:
continue
self_val = self._get_field_default(field_name)
elif other_val is PLACEHOLDER:
other_val = other._get_field_default(field_name)

if self_val != other_val:
return False

return True

def __repr__(self) -> str:
parts = [
f"{field_name}={value!r}"
for field_name in self._betterproto.sorted_field_names
for value in (self.__raw_get(field_name),)
if value is not PLACEHOLDER
]
return f"{self.__class__.__name__}({', '.join(parts)})"

def __getattribute__(self, name: str) -> Any:
"""
Lazily initialize default values to avoid infinite recursion for recursive
message types
"""
value = super().__getattribute__(name)
if value is not PLACEHOLDER:
return value

value = self._get_field_default(name)
super().__setattr__(name, value)
return value

def __setattr__(self, attr: str, value: Any) -> None:
if attr != "_serialized_on_wire":
# Track when a field has been set.
Expand All @@ -542,9 +586,7 @@ def __setattr__(self, attr: str, value: Any) -> None:
if field.name == attr:
self._group_current[group] = field.name
else:
super().__setattr__(
field.name, self._get_field_default(field.name)
)
super().__setattr__(field.name, PLACEHOLDER)

super().__setattr__(attr, value)

Expand Down
2 changes: 1 addition & 1 deletion src/betterproto/templates/template.py.j2
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ class {{ enum.py_name }}(betterproto.Enum):
{% endfor %}
{% endif %}
{% for message in output_file.messages %}
@dataclass
@dataclass(eq=False, repr=False)
class {{ message.py_name }}(betterproto.Message):
{% if message.comment %}
{{ message.comment }}
Expand Down
12 changes: 12 additions & 0 deletions tests/inputs/recursivemessage/recursivemessage.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
{
"name": "Zues",
"child": {
"name": "Hercules"
},
"intermediate": {
"child": {
"name": "Douglas Adams"
},
"number": 42
}
}
13 changes: 13 additions & 0 deletions tests/inputs/recursivemessage/recursivemessage.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
syntax = "proto3";

message Test {
string name = 1;
Test child = 2;
Intermediate intermediate = 3;
}


message Intermediate {
int32 number = 1;
Test child = 2;
}
48 changes: 48 additions & 0 deletions tests/test_features.py
Original file line number Diff line number Diff line change
Expand Up @@ -317,3 +317,51 @@ def _round_trip_serialization(foo: Foo) -> Foo:
== betterproto.which_one_of(_round_trip_serialization(foo3), "group1")
== ("", None)
)


def test_recursive_message():
from tests.output_betterproto.recursivemessage import Test as RecursiveMessage

msg = RecursiveMessage()

assert msg.child == RecursiveMessage()

# Lazily-created zero-value children must not affect equality.
assert msg == RecursiveMessage()

# Lazily-created zero-value children must not affect serialization.
assert bytes(msg) == b""


def test_recursive_message_defaults():
from tests.output_betterproto.recursivemessage import (
Test as RecursiveMessage,
Intermediate,
)

msg = RecursiveMessage(name="bob", intermediate=Intermediate(42))

# set values are as expected
assert msg == RecursiveMessage(name="bob", intermediate=Intermediate(42))

# lazy initialized works modifies the message
assert msg != RecursiveMessage(
name="bob", intermediate=Intermediate(42), child=RecursiveMessage(name="jude")
)
msg.child.child.name = "jude"
assert msg == RecursiveMessage(
name="bob",
intermediate=Intermediate(42),
child=RecursiveMessage(child=RecursiveMessage(name="jude")),
)

# lazily initialization recurses as needed
assert msg.child.child.child.child.child.child.child == RecursiveMessage()
assert msg.intermediate.child.intermediate == Intermediate()


def test_message_repr():
from tests.output_betterproto.recursivemessage import Test

assert repr(Test(name="Loki")) == "Test(name='Loki')"
assert repr(Test(child=Test(), name="Loki")) == "Test(name='Loki', child=Test())"