49
49
#include "mbedtls/debug.h"
50
50
#include "mbedtls/error.h"
51
51
52
+ #define MP_STREAM_POLL_RDWR (MP_STREAM_POLL_RD | MP_STREAM_POLL_WR)
53
+
52
54
typedef struct _mp_obj_ssl_socket_t {
53
55
mp_obj_base_t base ;
54
56
mp_obj_t sock ;
@@ -59,6 +61,9 @@ typedef struct _mp_obj_ssl_socket_t {
59
61
mbedtls_x509_crt cacert ;
60
62
mbedtls_x509_crt cert ;
61
63
mbedtls_pk_context pkey ;
64
+
65
+ uintptr_t poll_mask ; // Indicates which read or write operations the protocol needs next
66
+
62
67
#ifdef MBEDTLS_SSL_PROTO_DTLS
63
68
mp_uint_t timer_start_ms ;
64
69
mp_int_t timer_fin_ms ;
@@ -210,6 +215,7 @@ STATIC mp_obj_ssl_socket_t *socket_new(mp_obj_t sock, struct ssl_args *args) {
210
215
#endif
211
216
o -> base .type = & ussl_socket_type ;
212
217
o -> sock = sock ;
218
+ o -> poll_mask = 0 ;
213
219
214
220
int ret ;
215
221
mbedtls_ssl_init (& o -> ssl );
@@ -373,6 +379,7 @@ STATIC void socket_print(const mp_print_t *print, mp_obj_t self_in, mp_print_kin
373
379
374
380
STATIC mp_uint_t socket_read (mp_obj_t o_in , void * buf , mp_uint_t size , int * errcode ) {
375
381
mp_obj_ssl_socket_t * o = MP_OBJ_TO_PTR (o_in );
382
+ o -> poll_mask = 0 ;
376
383
377
384
int ret = mbedtls_ssl_read (& o -> ssl , buf , size );
378
385
if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY ) {
@@ -389,13 +396,15 @@ STATIC mp_uint_t socket_read(mp_obj_t o_in, void *buf, mp_uint_t size, int *errc
389
396
// wanting to write next handshake message. The same may happen with
390
397
// renegotation.
391
398
ret = MP_EWOULDBLOCK ;
399
+ o -> poll_mask = MP_STREAM_POLL_WR ;
392
400
}
393
401
* errcode = ret ;
394
402
return MP_STREAM_ERROR ;
395
403
}
396
404
397
405
STATIC mp_uint_t socket_write (mp_obj_t o_in , const void * buf , mp_uint_t size , int * errcode ) {
398
406
mp_obj_ssl_socket_t * o = MP_OBJ_TO_PTR (o_in );
407
+ o -> poll_mask = 0 ;
399
408
400
409
int ret = mbedtls_ssl_write (& o -> ssl , buf , size );
401
410
if (ret >= 0 ) {
@@ -408,6 +417,7 @@ STATIC mp_uint_t socket_write(mp_obj_t o_in, const void *buf, mp_uint_t size, in
408
417
// wanting to read next handshake message. The same may happen with
409
418
// renegotation.
410
419
ret = MP_EWOULDBLOCK ;
420
+ o -> poll_mask = MP_STREAM_POLL_RD ;
411
421
}
412
422
* errcode = ret ;
413
423
return MP_STREAM_ERROR ;
@@ -433,8 +443,44 @@ STATIC mp_uint_t socket_ioctl(mp_obj_t o_in, mp_uint_t request, uintptr_t arg, i
433
443
mbedtls_ssl_config_free (& self -> conf );
434
444
mbedtls_ctr_drbg_free (& self -> ctr_drbg );
435
445
mbedtls_entropy_free (& self -> entropy );
446
+ } else if (request == MP_STREAM_POLL ) {
447
+ if ((arg & MP_STREAM_POLL_RDWR ) && (_mbedtls_timing_get_delay (self ) >= 1 )) {
448
+ return arg & MP_STREAM_POLL_RDWR ;
449
+ }
450
+
451
+ // If the library signaled us that it needs reading or writing, only check that direction,
452
+ // but save what the caller asked because we need to restore it later
453
+ int saved_arg = 0 ;
454
+ if (self -> poll_mask && (arg & MP_STREAM_POLL_RDWR )) {
455
+ saved_arg = arg & MP_STREAM_POLL_RDWR ;
456
+ arg = ~saved_arg | self -> poll_mask ;
457
+ }
458
+
459
+ // Take into account that the library might have buffered data already
460
+ int has_pending = 0 ;
461
+ if (arg & MP_STREAM_POLL_RD ) {
462
+ has_pending = mbedtls_ssl_check_pending (& self -> ssl );
463
+ if (has_pending && arg == MP_STREAM_POLL_RD ) {
464
+ // Shortcut if we only need to read and we have buffered data, no need to go to the underlying socket
465
+ return MP_STREAM_POLL_RD ;
466
+ }
467
+ }
468
+
469
+ // ...otherwise fall through to pass request to underlying socket
470
+ mp_uint_t ret = mp_get_stream (self -> sock )-> ioctl (self -> sock , request , arg , errcode );
471
+
472
+ if (has_pending ) {
473
+ ret |= MP_STREAM_POLL_RD ;
474
+ }
475
+
476
+ // The direction the library needed is available, return a fake result to the caller so that
477
+ // it reenters a read or a write to allow the handshake to progress
478
+ if (ret & self -> poll_mask ) {
479
+ ret |= saved_arg ;
480
+ }
481
+ return ret ;
436
482
}
437
- // Pass all requests down to the underlying socket
483
+ // Pass all other requests down to the underlying socket
438
484
return mp_get_stream (self -> sock )-> ioctl (self -> sock , request , arg , errcode );
439
485
}
440
486
0 commit comments