diff --git a/src/betterproto/plugin/models.py b/src/betterproto/plugin/models.py index 93b2aa8cf..f0048c830 100644 --- a/src/betterproto/plugin/models.py +++ b/src/betterproto/plugin/models.py @@ -725,6 +725,9 @@ def __post_init__(self) -> None: self.parent.methods.append(self) # Check for imports + if "Optional" in self.py_input_message_type: + self.output_file.typing_imports.add("Optional") + if "Optional" in self.py_output_message_type: self.output_file.typing_imports.add("Optional") diff --git a/tests/inputs/service_complex_structure/payloads/messages.proto b/tests/inputs/service_complex_structure/payloads/messages.proto new file mode 100644 index 000000000..42c101724 --- /dev/null +++ b/tests/inputs/service_complex_structure/payloads/messages.proto @@ -0,0 +1,13 @@ +syntax = "proto3"; + +import "google/protobuf/duration.proto"; +import "google/protobuf/timestamp.proto"; + +package service_complex_structure.payloads; + +message Test { + string name = 1; + repeated string comments = 2; + google.protobuf.Timestamp when = 3; + google.protobuf.Duration duration = 4; +} diff --git a/tests/inputs/service_complex_structure/services/service.proto b/tests/inputs/service_complex_structure/services/service.proto new file mode 100644 index 000000000..d0ea30b3f --- /dev/null +++ b/tests/inputs/service_complex_structure/services/service.proto @@ -0,0 +1,11 @@ +syntax = "proto3"; + +import "google/protobuf/empty.proto"; +import "payloads/messages.proto"; + +package service_complex_structure.service; + + +service Test { + rpc DoThing (service_complex_structure.payloads.Test) returns (google.protobuf.Empty); +} diff --git a/tests/inputs/service_complex_structure/test_service_complex_structure.py b/tests/inputs/service_complex_structure/test_service_complex_structure.py new file mode 100644 index 000000000..c46428e1b --- /dev/null +++ b/tests/inputs/service_complex_structure/test_service_complex_structure.py @@ -0,0 +1,25 @@ +import pytest + + +def test_payloads_import_is_generated(): + try: + from tests.output_betterproto.service_complex_structure.service import ( + _payloads__, + ) + + assert hasattr(_payloads__, "Test") + except ImportError: + pytest.fail("Required import for type annotations is not generated!") + + +def test_payloads_import_is_generated_pydantic(): + try: + from tests.output_betterproto_pydantic.service_complex_structure.service import ( + _payloads__, + ) + + assert hasattr(_payloads__, "Test") + except ImportError: + pytest.fail( + "Required import for type annotations is not generated in pydantic output!" + ) diff --git a/tests/util.py b/tests/util.py index 22c4f9012..78a90a903 100644 --- a/tests/util.py +++ b/tests/util.py @@ -34,9 +34,8 @@ def get_files(path, suffix: str) -> Generator[str, None, None]: yield os.path.join(r, filename) -def get_directories(path): - for root, directories, files in os.walk(path): - yield from directories +def get_directories(path: Path) -> Generator[str, None, None]: + yield from (entry.name for entry in path.iterdir() if entry.is_dir()) async def protoc( @@ -79,7 +78,7 @@ async def protoc( "--custom_opt=pydantic_dataclasses", f"--proto_path={path.as_posix()}", f"--custom_out={output_dir.as_posix()}", - *[p.as_posix() for p in path.glob("*.proto")], + *[p.as_posix() for p in path.glob("**/*.proto")], ] else: command = [ @@ -88,7 +87,7 @@ async def protoc( "grpc.tools.protoc", f"--proto_path={path.as_posix()}", f"--{python_out_option}={output_dir.as_posix()}", - *[p.as_posix() for p in path.glob("*.proto")], + *[p.as_posix() for p in path.glob("**/*.proto")], ] proc = await asyncio.create_subprocess_exec( *command, stdout=asyncio.subprocess.PIPE, stderr=asyncio.subprocess.PIPE