tls_mbed.c 7.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221
  1. #include "log.h"
  2. #include "tls.h"
  3. #if MG_TLS == MG_TLS_MBED
  4. #if defined(MBEDTLS_VERSION_NUMBER) && MBEDTLS_VERSION_NUMBER >= 0x03000000
  5. #define MG_MBEDTLS_RNG_GET , mg_mbed_rng, NULL
  6. #else
  7. #define MG_MBEDTLS_RNG_GET
  8. #endif
  9. static int mg_mbed_rng(void *ctx, unsigned char *buf, size_t len) {
  10. mg_random(buf, len);
  11. (void) ctx;
  12. return 0;
  13. }
  14. static bool mg_load_cert(struct mg_str str, mbedtls_x509_crt *p) {
  15. int rc;
  16. if (str.ptr == NULL || str.ptr[0] == '\0' || str.ptr[0] == '*') return true;
  17. if (str.ptr[0] == '-') str.len++; // PEM, include trailing NUL
  18. if ((rc = mbedtls_x509_crt_parse(p, (uint8_t *) str.ptr, str.len)) != 0) {
  19. MG_ERROR(("cert err %#x", -rc));
  20. return false;
  21. }
  22. return true;
  23. }
  24. static bool mg_load_key(struct mg_str str, mbedtls_pk_context *p) {
  25. int rc;
  26. if (str.ptr == NULL || str.ptr[0] == '\0' || str.ptr[0] == '*') return true;
  27. if (str.ptr[0] == '-') str.len++; // PEM, include trailing NUL
  28. if ((rc = mbedtls_pk_parse_key(p, (uint8_t *) str.ptr, str.len, NULL,
  29. 0 MG_MBEDTLS_RNG_GET)) != 0) {
  30. MG_ERROR(("key err %#x", -rc));
  31. return false;
  32. }
  33. return true;
  34. }
  35. void mg_tls_free(struct mg_connection *c) {
  36. struct mg_tls *tls = (struct mg_tls *) c->tls;
  37. if (tls != NULL) {
  38. mbedtls_ssl_free(&tls->ssl);
  39. mbedtls_pk_free(&tls->pk);
  40. mbedtls_x509_crt_free(&tls->ca);
  41. mbedtls_x509_crt_free(&tls->cert);
  42. mbedtls_ssl_config_free(&tls->conf);
  43. #ifdef MBEDTLS_SSL_SESSION_TICKETS
  44. mbedtls_ssl_ticket_free(&tls->ticket);
  45. #endif
  46. free(tls);
  47. c->tls = NULL;
  48. }
  49. }
  50. static int mg_net_send(void *ctx, const unsigned char *buf, size_t len) {
  51. long n = mg_io_send((struct mg_connection *) ctx, buf, len);
  52. MG_VERBOSE(("%lu n=%ld e=%d", ((struct mg_connection *) ctx)->id, n, errno));
  53. if (n == MG_IO_WAIT) return MBEDTLS_ERR_SSL_WANT_WRITE;
  54. if (n == MG_IO_RESET) return MBEDTLS_ERR_NET_CONN_RESET;
  55. if (n == MG_IO_ERR) return MBEDTLS_ERR_NET_SEND_FAILED;
  56. return (int) n;
  57. }
  58. static int mg_net_recv(void *ctx, unsigned char *buf, size_t len) {
  59. long n = mg_io_recv((struct mg_connection *) ctx, buf, len);
  60. MG_VERBOSE(("%lu n=%ld", ((struct mg_connection *) ctx)->id, n));
  61. if (n == MG_IO_WAIT) return MBEDTLS_ERR_SSL_WANT_WRITE;
  62. if (n == MG_IO_RESET) return MBEDTLS_ERR_NET_CONN_RESET;
  63. if (n == MG_IO_ERR) return MBEDTLS_ERR_NET_RECV_FAILED;
  64. return (int) n;
  65. }
  66. void mg_tls_handshake(struct mg_connection *c) {
  67. struct mg_tls *tls = (struct mg_tls *) c->tls;
  68. int rc = mbedtls_ssl_handshake(&tls->ssl);
  69. if (rc == 0) { // Success
  70. MG_DEBUG(("%lu success", c->id));
  71. c->is_tls_hs = 0;
  72. mg_call(c, MG_EV_TLS_HS, NULL);
  73. } else if (rc == MBEDTLS_ERR_SSL_WANT_READ ||
  74. rc == MBEDTLS_ERR_SSL_WANT_WRITE) { // Still pending
  75. MG_VERBOSE(("%lu pending, %d%d %d (-%#x)", c->id, c->is_connecting,
  76. c->is_tls_hs, rc, -rc));
  77. } else {
  78. mg_error(c, "TLS handshake: -%#x", -rc); // Error
  79. }
  80. }
  81. static void debug_cb(void *c, int lev, const char *s, int n, const char *s2) {
  82. n = (int) strlen(s2) - 1;
  83. MG_INFO(("%lu %d %.*s", ((struct mg_connection *) c)->id, lev, n, s2));
  84. (void) s;
  85. }
  86. void mg_tls_init(struct mg_connection *c, const struct mg_tls_opts *opts) {
  87. struct mg_tls *tls = (struct mg_tls *) calloc(1, sizeof(*tls));
  88. int rc = 0;
  89. c->tls = tls;
  90. if (c->tls == NULL) {
  91. mg_error(c, "TLS OOM");
  92. goto fail;
  93. }
  94. if (c->is_listening) goto fail;
  95. MG_DEBUG(("%lu Setting TLS", c->id));
  96. MG_PROF_ADD(c, "mbedtls_init_start");
  97. mbedtls_ssl_init(&tls->ssl);
  98. mbedtls_ssl_config_init(&tls->conf);
  99. mbedtls_x509_crt_init(&tls->ca);
  100. mbedtls_x509_crt_init(&tls->cert);
  101. mbedtls_pk_init(&tls->pk);
  102. mbedtls_ssl_conf_dbg(&tls->conf, debug_cb, c);
  103. #if defined(MG_MBEDTLS_DEBUG_LEVEL)
  104. mbedtls_debug_set_threshold(MG_MBEDTLS_DEBUG_LEVEL);
  105. #endif
  106. if ((rc = mbedtls_ssl_config_defaults(
  107. &tls->conf,
  108. c->is_client ? MBEDTLS_SSL_IS_CLIENT : MBEDTLS_SSL_IS_SERVER,
  109. MBEDTLS_SSL_TRANSPORT_STREAM, MBEDTLS_SSL_PRESET_DEFAULT)) != 0) {
  110. mg_error(c, "tls defaults %#x", -rc);
  111. goto fail;
  112. }
  113. mbedtls_ssl_conf_rng(&tls->conf, mg_mbed_rng, c);
  114. if (opts->ca.len == 0 || mg_vcmp(&opts->ca, "*") == 0) {
  115. mbedtls_ssl_conf_authmode(&tls->conf, MBEDTLS_SSL_VERIFY_NONE);
  116. } else {
  117. if (mg_load_cert(opts->ca, &tls->ca) == false) goto fail;
  118. mbedtls_ssl_conf_ca_chain(&tls->conf, &tls->ca, NULL);
  119. if (c->is_client && opts->name.ptr != NULL && opts->name.ptr[0] != '\0') {
  120. char *host = mg_mprintf("%.*s", opts->name.len, opts->name.ptr);
  121. mbedtls_ssl_set_hostname(&tls->ssl, host);
  122. MG_DEBUG(("%lu hostname verification: %s", c->id, host));
  123. free(host);
  124. }
  125. mbedtls_ssl_conf_authmode(&tls->conf, MBEDTLS_SSL_VERIFY_REQUIRED);
  126. }
  127. if (!mg_load_cert(opts->cert, &tls->cert)) goto fail;
  128. if (!mg_load_key(opts->key, &tls->pk)) goto fail;
  129. if (tls->cert.version &&
  130. (rc = mbedtls_ssl_conf_own_cert(&tls->conf, &tls->cert, &tls->pk)) != 0) {
  131. mg_error(c, "own cert %#x", -rc);
  132. goto fail;
  133. }
  134. #ifdef MBEDTLS_SSL_SESSION_TICKETS
  135. mbedtls_ssl_conf_session_tickets_cb(
  136. &tls->conf, mbedtls_ssl_ticket_write, mbedtls_ssl_ticket_parse,
  137. &((struct mg_tls_ctx *) c->mgr->tls_ctx)->tickets);
  138. #endif
  139. if ((rc = mbedtls_ssl_setup(&tls->ssl, &tls->conf)) != 0) {
  140. mg_error(c, "setup err %#x", -rc);
  141. goto fail;
  142. }
  143. c->is_tls = 1;
  144. c->is_tls_hs = 1;
  145. mbedtls_ssl_set_bio(&tls->ssl, c, mg_net_send, mg_net_recv, 0);
  146. MG_PROF_ADD(c, "mbedtls_init_end");
  147. if (c->is_client && c->is_resolving == 0 && c->is_connecting == 0) {
  148. mg_tls_handshake(c);
  149. }
  150. return;
  151. fail:
  152. mg_tls_free(c);
  153. }
  154. size_t mg_tls_pending(struct mg_connection *c) {
  155. struct mg_tls *tls = (struct mg_tls *) c->tls;
  156. return tls == NULL ? 0 : mbedtls_ssl_get_bytes_avail(&tls->ssl);
  157. }
  158. long mg_tls_recv(struct mg_connection *c, void *buf, size_t len) {
  159. struct mg_tls *tls = (struct mg_tls *) c->tls;
  160. long n = mbedtls_ssl_read(&tls->ssl, (unsigned char *) buf, len);
  161. if (n == MBEDTLS_ERR_SSL_WANT_READ || n == MBEDTLS_ERR_SSL_WANT_WRITE)
  162. return MG_IO_WAIT;
  163. if (n <= 0) return MG_IO_ERR;
  164. return n;
  165. }
  166. long mg_tls_send(struct mg_connection *c, const void *buf, size_t len) {
  167. struct mg_tls *tls = (struct mg_tls *) c->tls;
  168. long n = mbedtls_ssl_write(&tls->ssl, (unsigned char *) buf, len);
  169. if (n == MBEDTLS_ERR_SSL_WANT_READ || n == MBEDTLS_ERR_SSL_WANT_WRITE)
  170. return MG_IO_WAIT;
  171. if (n <= 0) return MG_IO_ERR;
  172. return n;
  173. }
  174. void mg_tls_ctx_init(struct mg_mgr *mgr) {
  175. struct mg_tls_ctx *ctx = (struct mg_tls_ctx *) calloc(1, sizeof(*ctx));
  176. if (ctx == NULL) {
  177. MG_ERROR(("TLS context init OOM"));
  178. } else {
  179. #ifdef MBEDTLS_SSL_SESSION_TICKETS
  180. int rc;
  181. mbedtls_ssl_ticket_init(&ctx->tickets);
  182. if ((rc = mbedtls_ssl_ticket_setup(&ctx->tickets, mg_mbed_rng, NULL,
  183. MBEDTLS_CIPHER_AES_128_GCM, 86400)) !=
  184. 0) {
  185. MG_ERROR((" mbedtls_ssl_ticket_setup %#x", -rc));
  186. }
  187. #endif
  188. mgr->tls_ctx = ctx;
  189. }
  190. }
  191. void mg_tls_ctx_free(struct mg_mgr *mgr) {
  192. struct mg_tls_ctx *ctx = (struct mg_tls_ctx *) mgr->tls_ctx;
  193. if (ctx != NULL) {
  194. #ifdef MBEDTLS_SSL_SESSION_TICKETS
  195. mbedtls_ssl_ticket_free(&ctx->tickets);
  196. #endif
  197. free(ctx);
  198. mgr->tls_ctx = NULL;
  199. }
  200. }
  201. #endif