Skip to content

Commit 7285ac3

Browse files
Added '.run()' method to TelnetServer.
This is a better than having a separate .start() and .stop() when it comes to cancellation.
1 parent 2d397fb commit 7285ac3

File tree

3 files changed

+54
-35
lines changed

3 files changed

+54
-35
lines changed

examples/telnet/hello-world.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,7 @@ async def interact(connection):
3232

3333
async def main():
3434
server = TelnetServer(interact=interact, port=2323)
35-
server.start()
36-
37-
# Run forever.
38-
await Future()
35+
await server.run()
3936

4037

4138
if __name__ == "__main__":

examples/telnet/toolbar.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -37,10 +37,7 @@ def get_toolbar():
3737

3838
async def main():
3939
server = TelnetServer(interact=interact, port=2323)
40-
server.start()
41-
42-
# Run forever.
43-
await Future()
40+
await server.run()
4441

4542

4643
if __name__ == "__main__":

src/prompt_toolkit/contrib/telnet/server.py

Lines changed: 52 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -283,10 +283,11 @@ def __init__(
283283
self.encoding = encoding
284284
self.style = style
285285
self.enable_cpr = enable_cpr
286+
287+
self._run_task: asyncio.Task[None] | None = None
286288
self._application_tasks: list[asyncio.Task[None]] = []
287289

288290
self.connections: set[TelnetConnection] = set()
289-
self._listen_socket: socket.socket | None = None
290291

291292
@classmethod
292293
def _create_socket(cls, host: str, port: int) -> socket.socket:
@@ -298,44 +299,68 @@ def _create_socket(cls, host: str, port: int) -> socket.socket:
298299
s.listen(4)
299300
return s
300301

301-
def start(self) -> None:
302+
async def run(self) -> None:
302303
"""
303-
Start the telnet server.
304-
Don't forget to call `loop.run_forever()` after doing this.
304+
Run the telnet server, until this gets cancelled.
305305
"""
306-
self._listen_socket = self._create_socket(self.host, self.port)
306+
socket = self._create_socket(self.host, self.port)
307307
logger.info(
308308
"Listening for telnet connections on %s port %r", self.host, self.port
309309
)
310310

311-
get_running_loop().add_reader(self._listen_socket, self._accept)
311+
get_running_loop().add_reader(socket, lambda: self._accept(socket))
312+
313+
try:
314+
# Run forever, until cancelled.
315+
await asyncio.Future()
316+
finally:
317+
get_running_loop().remove_reader(socket)
318+
socket.close()
319+
320+
# Wait for all applications to finish.
321+
for t in self._application_tasks:
322+
t.cancel()
323+
324+
# (This is similar to
325+
# `Application.cancel_and_wait_for_background_tasks`. We wait for the
326+
# background tasks to complete, but don't propagate exceptions, because
327+
# we can't use `ExceptionGroup` yet.)
328+
if len(self._application_tasks) > 0:
329+
await asyncio.wait(
330+
self._application_tasks,
331+
timeout=None,
332+
return_when=asyncio.ALL_COMPLETED,
333+
)
334+
335+
def start(self) -> None:
336+
"""
337+
Start the telnet server (stop by calling and awaiting `stop()`).
338+
339+
Note: When possible, it's better to call `.run()` instead.
340+
"""
341+
if self._run_task is not None:
342+
# Already running.
343+
return
344+
345+
self._run_task = get_running_loop().create_task(self.run())
312346

313347
async def stop(self) -> None:
314-
if self._listen_socket:
315-
get_running_loop().remove_reader(self._listen_socket)
316-
self._listen_socket.close()
317-
318-
# Wait for all applications to finish.
319-
for t in self._application_tasks:
320-
t.cancel()
321-
322-
# (This is similar to
323-
# `Application.cancel_and_wait_for_background_tasks`. We wait for the
324-
# background tasks to complete, but don't propagate exceptions, because
325-
# we can't use `ExceptionGroup` yet.)
326-
if len(self._application_tasks) > 0:
327-
await asyncio.wait(
328-
self._application_tasks, timeout=None, return_when=asyncio.ALL_COMPLETED
329-
)
348+
"""
349+
Stop a telnet server that was started using `.start()` and wait for the
350+
cancellation to complete.
351+
"""
352+
if self._run_task is not None:
353+
self._run_task.cancel()
354+
try:
355+
await self._run_task
356+
except asyncio.CancelledError:
357+
pass
330358

331-
def _accept(self) -> None:
359+
def _accept(self, listen_socket: socket.socket) -> None:
332360
"""
333361
Accept new incoming connection.
334362
"""
335-
if self._listen_socket is None:
336-
return # Should not happen. `_accept` is called after `start`.
337-
338-
conn, addr = self._listen_socket.accept()
363+
conn, addr = listen_socket.accept()
339364
logger.info("New connection %r %r", *addr)
340365

341366
# Run application for this connection.

0 commit comments

Comments
 (0)