Skip to content

Commit 6e089ef

Browse files
committed
extmod/ussl: Improve support for polling
1 parent f1d2b81 commit 6e089ef

File tree

1 file changed

+47
-1
lines changed

1 file changed

+47
-1
lines changed

extmod/modussl_mbedtls.c

Lines changed: 47 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@
4949
#include "mbedtls/debug.h"
5050
#include "mbedtls/error.h"
5151

52+
#define MP_STREAM_POLL_RDWR (MP_STREAM_POLL_RD | MP_STREAM_POLL_WR)
53+
5254
typedef struct _mp_obj_ssl_socket_t {
5355
mp_obj_base_t base;
5456
mp_obj_t sock;
@@ -59,6 +61,9 @@ typedef struct _mp_obj_ssl_socket_t {
5961
mbedtls_x509_crt cacert;
6062
mbedtls_x509_crt cert;
6163
mbedtls_pk_context pkey;
64+
65+
uintptr_t poll_mask; // Indicates which read or write operations the protocol needs next
66+
6267
#ifdef MBEDTLS_SSL_PROTO_DTLS
6368
mp_uint_t timer_start_ms;
6469
mp_int_t timer_fin_ms;
@@ -210,6 +215,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
210215
#endif
211216
o->base.type = &ussl_socket_type;
212217
o->sock = sock;
218+
o->poll_mask = 0;
213219

214220
int ret;
215221
mbedtls_ssl_init(&o->ssl);
@@ -373,6 +379,7 @@ STATIC void socket_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kin
373379

374380
STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errcode) {
375381
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);
382+
o->poll_mask = 0;
376383

377384
int ret = mbedtls_ssl_read(&o->ssl, buf, size);
378385
if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) {
@@ -389,13 +396,15 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc
389396
// wanting to write next handshake message. The same may happen with
390397
// renegotation.
391398
ret = MP_EWOULDBLOCK;
399+
o->poll_mask = MP_STREAM_POLL_WR;
392400
}
393401
*errcode = ret;
394402
return MP_STREAM_ERROR;
395403
}
396404

397405
STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, int *errcode) {
398406
mp_obj_ssl_socket_t *o = MP_OBJ_TO_PTR(o_in);
407+
o->poll_mask = 0;
399408

400409
int ret = mbedtls_ssl_write(&o->ssl, buf, size);
401410
if (ret >= 0) {
@@ -408,6 +417,7 @@ STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, in
408417
// wanting to read next handshake message. The same may happen with
409418
// renegotation.
410419
ret = MP_EWOULDBLOCK;
420+
o->poll_mask = MP_STREAM_POLL_RD;
411421
}
412422
*errcode = ret;
413423
return MP_STREAM_ERROR;
@@ -433,8 +443,44 @@ STATIC mp_uint_t socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, i
433443
mbedtls_ssl_config_free(&self->conf);
434444
mbedtls_ctr_drbg_free(&self->ctr_drbg);
435445
mbedtls_entropy_free(&self->entropy);
446+
} else if (request == MP_STREAM_POLL) {
447+
if ((arg & MP_STREAM_POLL_RDWR) && (_mbedtls_timing_get_delay(self) >= 1)) {
448+
return arg & MP_STREAM_POLL_RDWR;
449+
}
450+
451+
// If the library signaled us that it needs reading or writing, only check that direction,
452+
// but save what the caller asked because we need to restore it later
453+
int saved_arg = 0;
454+
if (self->poll_mask && (arg & MP_STREAM_POLL_RDWR)) {
455+
saved_arg = arg & MP_STREAM_POLL_RDWR;
456+
arg = ~saved_arg | self->poll_mask;
457+
}
458+
459+
// Take into account that the library might have buffered data already
460+
int has_pending = 0;
461+
if (arg & MP_STREAM_POLL_RD) {
462+
has_pending = mbedtls_ssl_check_pending(&self->ssl);
463+
if (has_pending && arg == MP_STREAM_POLL_RD) {
464+
// Shortcut if we only need to read and we have buffered data, no need to go to the underlying socket
465+
return MP_STREAM_POLL_RD;
466+
}
467+
}
468+
469+
// ...otherwise fall through to pass request to underlying socket
470+
mp_uint_t ret = mp_get_stream(self->sock)->ioctl(self->sock, request, arg, errcode);
471+
472+
if (has_pending) {
473+
ret |= MP_STREAM_POLL_RD;
474+
}
475+
476+
// The direction the library needed is available, return a fake result to the caller so that
477+
// it reenters a read or a write to allow the handshake to progress
478+
if (ret & self->poll_mask) {
479+
ret |= saved_arg;
480+
}
481+
return ret;
436482
}
437-
// Pass all requests down to the underlying socket
483+
// Pass all other requests down to the underlying socket
438484
return mp_get_stream(self->sock)->ioctl(self->sock, request, arg, errcode);
439485
}
440486

0 commit comments

Comments
 (0)