Implement custom extensions. This change mirrors upstream's custom extension API because we have some internal users that depend on it. Change-Id: I408e442de0a55df7b05c872c953ff048cd406513 Reviewed-on: https://boringssl-review.googlesource.com/5471 Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/err/ssl.errordata b/crypto/err/ssl.errordata index 4fc89d4..7825cce 100644 --- a/crypto/err/ssl.errordata +++ b/crypto/err/ssl.errordata
@@ -37,6 +37,8 @@ SSL,135,CONNECTION_REJECTED SSL,136,CONNECTION_TYPE_NOT_SET SSL,137,COOKIE_MISMATCH +SSL,284,CUSTOM_EXTENSION_CONTENTS_TOO_LARGE +SSL,285,CUSTOM_EXTENSION_ERROR SSL,138,D2I_ECDSA_SIG SSL,139,DATA_BETWEEN_CCS_AND_FINISHED SSL,140,DATA_LENGTH_TOO_LONG
diff --git a/include/openssl/base.h b/include/openssl/base.h index 84400ae..4618ed2 100644 --- a/include/openssl/base.h +++ b/include/openssl/base.h
@@ -214,6 +214,7 @@ typedef struct sha512_state_st SHA512_CTX; typedef struct sha_state_st SHA_CTX; typedef struct ssl_ctx_st SSL_CTX; +typedef struct ssl_custom_extension SSL_CUSTOM_EXTENSION; typedef struct ssl_st SSL; typedef struct st_ERR_FNS ERR_FNS; typedef struct v3_ext_ctx X509V3_CTX;
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h index 656a901..02445bf 100644 --- a/include/openssl/ssl.h +++ b/include/openssl/ssl.h
@@ -768,6 +768,92 @@ size_t *out_len, size_t max_out); +/* Custom extensions. + * + * The custom extension functions allow TLS extensions to be added to + * ClientHello and ServerHello messages. */ + +/* SSL_custom_ext_add_cb is a callback function that is called when the + * ClientHello (for clients) or ServerHello (for servers) is constructed. In + * the case of a server, this callback will only be called for a given + * extension if the ClientHello contained that extension – it's not possible to + * inject extensions into a ServerHello that the client didn't request. + * + * When called, |extension_value| will contain the extension number that is + * being considered for addition (so that a single callback can handle multiple + * extensions). If the callback wishes to include the extension, it must set + * |*out| to point to |*out_len| bytes of extension contents and return one. In + * this case, the corresponding |SSL_custom_ext_free_cb| callback will later be + * called with the value of |*out| once that data has been copied. + * + * If the callback does not wish to add an extension it must return zero. + * + * Alternatively, the callback can abort the connection by setting + * |*out_alert_value| to a TLS alert number and returning -1. */ +typedef int (*SSL_custom_ext_add_cb)(SSL *ssl, unsigned extension_value, + const uint8_t **out, size_t *out_len, + int *out_alert_value, void *add_arg); + +/* SSL_custom_ext_free_cb is a callback function that is called by OpenSSL iff + * an |SSL_custom_ext_add_cb| callback previously returned one. In that case, + * this callback is called and passed the |out| pointer that was returned by + * the add callback. This is to free any dynamically allocated data created by + * the add callback. */ +typedef void (*SSL_custom_ext_free_cb)(SSL *ssl, unsigned extension_value, + const uint8_t *out, void *add_arg); + +/* SSL_custom_ext_parse_cb is a callback function that is called by OpenSSL to + * parse an extension from the peer: that is from the ServerHello for a client + * and from the ClientHello for a server. + * + * When called, |extension_value| will contain the extension number and the + * contents of the extension are |contents_len| bytes at |contents|. + * + * The callback must return one to continue the handshake. Otherwise, if it + * returns zero, a fatal alert with value |*out_alert_value| is sent and the + * handshake is aborted. */ +typedef int (*SSL_custom_ext_parse_cb)(SSL *ssl, unsigned extension_value, + const uint8_t *contents, + size_t contents_len, + int *out_alert_value, void *parse_arg); + +/* SSL_extension_supported returns one iff OpenSSL internally handles + * extensions of type |extension_value|. This can be used to avoid registering + * custom extension handlers for extensions that a future version of OpenSSL + * may handle internally. */ +OPENSSL_EXPORT int SSL_extension_supported(unsigned extension_value); + +/* SSL_CTX_add_client_custom_ext registers callback functions for handling + * custom TLS extensions for client connections. + * + * If |add_cb| is NULL then an empty extension will be added in each + * ClientHello. Otherwise, see the comment for |SSL_custom_ext_add_cb| about + * this callback. + * + * The |free_cb| may be NULL if |add_cb| doesn't dynamically allocate data that + * needs to be freed. + * + * It returns one on success or zero on error. It's always an error to register + * callbacks for the same extension twice, or to register callbacks for an + * extension that OpenSSL handles internally. See |SSL_extension_supported| to + * discover, at runtime, which extensions OpenSSL handles internally. */ +OPENSSL_EXPORT int SSL_CTX_add_client_custom_ext( + SSL_CTX *ctx, unsigned extension_value, SSL_custom_ext_add_cb add_cb, + SSL_custom_ext_free_cb free_cb, void *add_arg, + SSL_custom_ext_parse_cb parse_cb, void *parse_arg); + +/* SSL_CTX_add_server_custom_ext is the same as + * |SSL_CTX_add_client_custom_ext|, but for server connections. + * + * Unlike on the client side, if |add_cb| is NULL no extension will be added. + * The |add_cb|, if any, will only be called if the ClientHello contained a + * matching extension. */ +OPENSSL_EXPORT int SSL_CTX_add_server_custom_ext( + SSL_CTX *ctx, unsigned extension_value, SSL_custom_ext_add_cb add_cb, + SSL_custom_ext_free_cb free_cb, void *add_arg, + SSL_custom_ext_parse_cb parse_cb, void *parse_arg); + + /* Session tickets. */ /* SSL_CTX_get_tlsext_ticket_keys writes |ctx|'s session ticket key material to @@ -1221,6 +1307,10 @@ CRYPTO_EX_DATA ex_data; + /* custom_*_extensions stores any callback sets for custom extensions. Note + * that these pointers will be NULL if the stack would otherwise be empty. */ + STACK_OF(SSL_CUSTOM_EXTENSION) *client_custom_extensions; + STACK_OF(SSL_CUSTOM_EXTENSION) *server_custom_extensions; /* Default values used when no per-SSL value is defined follow */ @@ -2945,6 +3035,8 @@ #define SSL_R_ERROR_ADDING_EXTENSION 281 #define SSL_R_ERROR_PARSING_EXTENSION 282 #define SSL_R_MISSING_EXTENSION 283 +#define SSL_R_CUSTOM_EXTENSION_CONTENTS_TOO_LARGE 284 +#define SSL_R_CUSTOM_EXTENSION_ERROR 285 #define SSL_R_SSLV3_ALERT_CLOSE_NOTIFY 1000 #define SSL_R_SSLV3_ALERT_UNEXPECTED_MESSAGE 1010 #define SSL_R_SSLV3_ALERT_BAD_RECORD_MAC 1020
diff --git a/include/openssl/ssl3.h b/include/openssl/ssl3.h index 541b039..faf69ab 100644 --- a/include/openssl/ssl3.h +++ b/include/openssl/ssl3.h
@@ -453,6 +453,16 @@ uint32_t received; } extensions; + union { + /* sent is a bitset where the bits correspond to elements of + * |client_custom_extensions| in the |SSL_CTX|. Each bit is set if that + * extension was sent in a ClientHello. It's not used by servers. */ + uint16_t sent; + /* received is a bitset, like |sent|, but is used by servers to record + * which custom extensions were received from a client. The bits here + * correspond to |server_custom_extensions|. */ + uint16_t received; + } custom_extensions; /* SNI extension */
diff --git a/include/openssl/stack.h b/include/openssl/stack.h index 5f9b683..36a0397 100644 --- a/include/openssl/stack.h +++ b/include/openssl/stack.h
@@ -147,6 +147,7 @@ * STACK_OF:POLICY_MAPPING * STACK_OF:RSA_additional_prime * STACK_OF:SSL_COMP + * STACK_OF:SSL_CUSTOM_EXTENSION * STACK_OF:STACK_OF_X509_NAME_ENTRY * STACK_OF:SXNETID * STACK_OF:X509
diff --git a/include/openssl/stack_macros.h b/include/openssl/stack_macros.h index a4dc926..08097af 100644 --- a/include/openssl/stack_macros.h +++ b/include/openssl/stack_macros.h
@@ -2106,6 +2106,99 @@ CHECKED_CAST(void *(*)(void *), SSL_COMP *(*)(SSL_COMP *), copy_func), \ CHECKED_CAST(void (*)(void *), void (*)(SSL_COMP *), free_func))) +/* SSL_CUSTOM_EXTENSION */ +#define sk_SSL_CUSTOM_EXTENSION_new(comp) \ + ((STACK_OF(SSL_CUSTOM_EXTENSION) *)sk_new(CHECKED_CAST( \ + stack_cmp_func, \ + int (*)(const SSL_CUSTOM_EXTENSION **a, const SSL_CUSTOM_EXTENSION **b), \ + comp))) + +#define sk_SSL_CUSTOM_EXTENSION_new_null() \ + ((STACK_OF(SSL_CUSTOM_EXTENSION) *)sk_new_null()) + +#define sk_SSL_CUSTOM_EXTENSION_num(sk) \ + sk_num(CHECKED_CAST(_STACK *, STACK_OF(SSL_CUSTOM_EXTENSION) *, sk)) + +#define sk_SSL_CUSTOM_EXTENSION_zero(sk) \ + sk_zero(CHECKED_CAST(_STACK *, STACK_OF(SSL_CUSTOM_EXTENSION) *, sk)); + +#define sk_SSL_CUSTOM_EXTENSION_value(sk, i) \ + ((SSL_CUSTOM_EXTENSION *)sk_value( \ + CHECKED_CAST(_STACK *, const STACK_OF(SSL_CUSTOM_EXTENSION) *, sk), \ + (i))) + +#define sk_SSL_CUSTOM_EXTENSION_set(sk, i, p) \ + ((SSL_CUSTOM_EXTENSION *)sk_set( \ + CHECKED_CAST(_STACK *, STACK_OF(SSL_CUSTOM_EXTENSION) *, sk), (i), \ + CHECKED_CAST(void *, SSL_CUSTOM_EXTENSION *, p))) + +#define sk_SSL_CUSTOM_EXTENSION_free(sk) \ + sk_free(CHECKED_CAST(_STACK *, STACK_OF(SSL_CUSTOM_EXTENSION) *, sk)) + +#define sk_SSL_CUSTOM_EXTENSION_pop_free(sk, free_func) \ + sk_pop_free(CHECKED_CAST(_STACK *, STACK_OF(SSL_CUSTOM_EXTENSION) *, sk), \ + CHECKED_CAST(void (*)(void *), void (*)(SSL_CUSTOM_EXTENSION *), \ + free_func)) + +#define sk_SSL_CUSTOM_EXTENSION_insert(sk, p, where) \ + sk_insert(CHECKED_CAST(_STACK *, STACK_OF(SSL_CUSTOM_EXTENSION) *, sk), \ + CHECKED_CAST(void *, SSL_CUSTOM_EXTENSION *, p), (where)) + +#define sk_SSL_CUSTOM_EXTENSION_delete(sk, where) \ + ((SSL_CUSTOM_EXTENSION *)sk_delete( \ + CHECKED_CAST(_STACK *, STACK_OF(SSL_CUSTOM_EXTENSION) *, sk), (where))) + +#define sk_SSL_CUSTOM_EXTENSION_delete_ptr(sk, p) \ + ((SSL_CUSTOM_EXTENSION *)sk_delete_ptr( \ + CHECKED_CAST(_STACK *, STACK_OF(SSL_CUSTOM_EXTENSION) *, sk), \ + CHECKED_CAST(void *, SSL_CUSTOM_EXTENSION *, p))) + +#define sk_SSL_CUSTOM_EXTENSION_find(sk, out_index, p) \ + sk_find(CHECKED_CAST(_STACK *, STACK_OF(SSL_CUSTOM_EXTENSION) *, sk), \ + (out_index), CHECKED_CAST(void *, SSL_CUSTOM_EXTENSION *, p)) + +#define sk_SSL_CUSTOM_EXTENSION_shift(sk) \ + ((SSL_CUSTOM_EXTENSION *)sk_shift( \ + CHECKED_CAST(_STACK *, STACK_OF(SSL_CUSTOM_EXTENSION) *, sk))) + +#define sk_SSL_CUSTOM_EXTENSION_push(sk, p) \ + sk_push(CHECKED_CAST(_STACK *, STACK_OF(SSL_CUSTOM_EXTENSION) *, sk), \ + CHECKED_CAST(void *, SSL_CUSTOM_EXTENSION *, p)) + +#define sk_SSL_CUSTOM_EXTENSION_pop(sk) \ + ((SSL_CUSTOM_EXTENSION *)sk_pop( \ + CHECKED_CAST(_STACK *, STACK_OF(SSL_CUSTOM_EXTENSION) *, sk))) + +#define sk_SSL_CUSTOM_EXTENSION_dup(sk) \ + ((STACK_OF(SSL_CUSTOM_EXTENSION) *)sk_dup( \ + CHECKED_CAST(_STACK *, const STACK_OF(SSL_CUSTOM_EXTENSION) *, sk))) + +#define sk_SSL_CUSTOM_EXTENSION_sort(sk) \ + sk_sort(CHECKED_CAST(_STACK *, STACK_OF(SSL_CUSTOM_EXTENSION) *, sk)) + +#define sk_SSL_CUSTOM_EXTENSION_is_sorted(sk) \ + sk_is_sorted( \ + CHECKED_CAST(_STACK *, const STACK_OF(SSL_CUSTOM_EXTENSION) *, sk)) + +#define sk_SSL_CUSTOM_EXTENSION_set_cmp_func(sk, comp) \ + ((int (*)(const SSL_CUSTOM_EXTENSION **a, const SSL_CUSTOM_EXTENSION **b)) \ + sk_set_cmp_func( \ + CHECKED_CAST(_STACK *, STACK_OF(SSL_CUSTOM_EXTENSION) *, sk), \ + CHECKED_CAST(stack_cmp_func, \ + int (*)(const SSL_CUSTOM_EXTENSION **a, \ + const SSL_CUSTOM_EXTENSION **b), \ + comp))) + +#define sk_SSL_CUSTOM_EXTENSION_deep_copy(sk, copy_func, free_func) \ + ((STACK_OF(SSL_CUSTOM_EXTENSION) *)sk_deep_copy( \ + CHECKED_CAST(const _STACK *, const STACK_OF(SSL_CUSTOM_EXTENSION) *, \ + sk), \ + CHECKED_CAST(void *(*)(void *), \ + SSL_CUSTOM_EXTENSION *(*)(SSL_CUSTOM_EXTENSION *), \ + copy_func), \ + CHECKED_CAST(void (*)(void *), void (*)(SSL_CUSTOM_EXTENSION *), \ + free_func))) + /* STACK_OF_X509_NAME_ENTRY */ #define sk_STACK_OF_X509_NAME_ENTRY_new(comp) \ ((STACK_OF(STACK_OF_X509_NAME_ENTRY) *)sk_new(CHECKED_CAST( \
diff --git a/ssl/CMakeLists.txt b/ssl/CMakeLists.txt index 4379060..ae241ae 100644 --- a/ssl/CMakeLists.txt +++ b/ssl/CMakeLists.txt
@@ -5,6 +5,7 @@ add_library( ssl + custom_extensions.c d1_both.c d1_clnt.c d1_lib.c
diff --git a/ssl/custom_extensions.c b/ssl/custom_extensions.c new file mode 100644 index 0000000..d0bc257 --- /dev/null +++ b/ssl/custom_extensions.c
@@ -0,0 +1,252 @@ +/* Copyright (c) 2014, Google Inc. + * + * Permission to use, copy, modify, and/or distribute this software for any + * purpose with or without fee is hereby granted, provided that the above + * copyright notice and this permission notice appear in all copies. + * + * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES + * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF + * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY + * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES + * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION + * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN + * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */ + +#include <assert.h> +#include <string.h> + +#include <openssl/ssl.h> + +#include "internal.h" + + +void SSL_CUSTOM_EXTENSION_free(SSL_CUSTOM_EXTENSION *custom_extension) { + OPENSSL_free(custom_extension); +} + +static const SSL_CUSTOM_EXTENSION *custom_ext_find( + STACK_OF(SSL_CUSTOM_EXTENSION) *stack, + unsigned *out_index, uint16_t value) { + size_t i; + for (i = 0; i < sk_SSL_CUSTOM_EXTENSION_num(stack); i++) { + const SSL_CUSTOM_EXTENSION *ext = sk_SSL_CUSTOM_EXTENSION_value(stack, i); + if (ext->value == value) { + if (out_index != NULL) { + *out_index = i; + } + return ext; + } + } + + return NULL; +} + +/* default_add_callback is used as the |add_callback| when the user doesn't + * provide one. For servers, it does nothing while, for clients, it causes an + * empty extension to be included. */ +static int default_add_callback(SSL *ssl, unsigned extension_value, + const uint8_t **out, size_t *out_len, + int *out_alert_value, void *add_arg) { + if (ssl->server) { + return 0; + } + *out_len = 0; + return 1; +} + +static int custom_ext_add_hello(SSL *ssl, CBB *extensions) { + STACK_OF(SSL_CUSTOM_EXTENSION) *stack = ssl->ctx->client_custom_extensions; + if (ssl->server) { + stack = ssl->ctx->server_custom_extensions; + } + + if (stack == NULL) { + return 1; + } + + size_t i; + for (i = 0; i < sk_SSL_CUSTOM_EXTENSION_num(stack); i++) { + const SSL_CUSTOM_EXTENSION *ext = sk_SSL_CUSTOM_EXTENSION_value(stack, i); + + if (ssl->server && + !(ssl->s3->tmp.custom_extensions.received & (1u << i))) { + /* Servers cannot echo extensions that the client didn't send. */ + continue; + } + + const uint8_t *contents; + size_t contents_len; + int alert = SSL_AD_DECODE_ERROR; + CBB contents_cbb; + + switch (ext->add_callback(ssl, ext->value, &contents, &contents_len, &alert, + ext->add_arg)) { + case 1: + if (!CBB_add_u16(extensions, ext->value) || + !CBB_add_u16_length_prefixed(extensions, &contents_cbb) || + !CBB_add_bytes(&contents_cbb, contents, contents_len) || + !CBB_flush(extensions)) { + OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR); + ERR_add_error_dataf("extension: %u", (unsigned) ext->value); + if (ext->free_callback && 0 < contents_len) { + ext->free_callback(ssl, ext->value, contents, ext->add_arg); + } + return 0; + } + + if (ext->free_callback && 0 < contents_len) { + ext->free_callback(ssl, ext->value, contents, ext->add_arg); + } + + if (!ssl->server) { + assert((ssl->s3->tmp.custom_extensions.sent & (1u << i)) == 0); + ssl->s3->tmp.custom_extensions.sent |= (1u << i); + } + break; + + case 0: + break; + + default: + ssl3_send_alert(ssl, SSL3_AL_FATAL, alert); + OPENSSL_PUT_ERROR(SSL, SSL_R_CUSTOM_EXTENSION_ERROR); + ERR_add_error_dataf("extension: %u", (unsigned) ext->value); + return 0; + } + } + + return 1; +} + +int custom_ext_add_clienthello(SSL *ssl, CBB *extensions) { + return custom_ext_add_hello(ssl, extensions); +} + +int custom_ext_parse_serverhello(SSL *ssl, int *out_alert, uint16_t value, + const CBS *extension) { + unsigned index; + const SSL_CUSTOM_EXTENSION *ext = + custom_ext_find(ssl->ctx->client_custom_extensions, &index, value); + + if (/* Unknown extensions are not allowed in a ServerHello. */ + ext == NULL || + /* Also, if we didn't send the extension, that's also unacceptable. */ + !(ssl->s3->tmp.custom_extensions.sent & (1u << index))) { + OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_EXTENSION); + ERR_add_error_dataf("extension: %u", (unsigned)ext->value); + *out_alert = SSL_AD_DECODE_ERROR; + return 0; + } + + if (ext->parse_callback != NULL && + !ext->parse_callback(ssl, value, CBS_data(extension), CBS_len(extension), + out_alert, ext->parse_arg)) { + OPENSSL_PUT_ERROR(SSL, SSL_R_CUSTOM_EXTENSION_ERROR); + ERR_add_error_dataf("extension: %u", (unsigned)ext->value); + return 0; + } + + return 1; +} + +int custom_ext_parse_clienthello(SSL *ssl, int *out_alert, uint16_t value, + const CBS *extension) { + unsigned index; + const SSL_CUSTOM_EXTENSION *ext = + custom_ext_find(ssl->ctx->server_custom_extensions, &index, value); + + if (ext == NULL) { + return 1; + } + + assert((ssl->s3->tmp.custom_extensions.received & (1u << index)) == 0); + ssl->s3->tmp.custom_extensions.received |= (1u << index); + + if (ext->parse_callback && + !ext->parse_callback(ssl, value, CBS_data(extension), CBS_len(extension), + out_alert, ext->parse_arg)) { + OPENSSL_PUT_ERROR(SSL, SSL_R_CUSTOM_EXTENSION_ERROR); + ERR_add_error_dataf("extension: %u", (unsigned)ext->value); + return 0; + } + + return 1; +} + +int custom_ext_add_serverhello(SSL *ssl, CBB *extensions) { + return custom_ext_add_hello(ssl, extensions); +} + +/* MAX_NUM_CUSTOM_EXTENSIONS is the maximum number of custom extensions that + * can be set on an |SSL_CTX|. It's determined by the size of the bitset used + * to track when an extension has been sent. */ +#define MAX_NUM_CUSTOM_EXTENSIONS \ + (sizeof(((struct ssl3_state_st *)NULL)->tmp.custom_extensions.sent) * 8) + +static int custom_ext_append(STACK_OF(SSL_CUSTOM_EXTENSION) **stack, + unsigned extension_value, + SSL_custom_ext_add_cb add_cb, + SSL_custom_ext_free_cb free_cb, void *add_arg, + SSL_custom_ext_parse_cb parse_cb, + void *parse_arg) { + if (add_cb == NULL || + 0xffff < extension_value || + SSL_extension_supported(extension_value) || + /* Specifying a free callback without an add callback is nonsensical + * and an error. */ + (*stack != NULL && + (MAX_NUM_CUSTOM_EXTENSIONS <= sk_SSL_CUSTOM_EXTENSION_num(*stack) || + custom_ext_find(*stack, NULL, extension_value) != NULL))) { + return 0; + } + + SSL_CUSTOM_EXTENSION *ext = OPENSSL_malloc(sizeof(SSL_CUSTOM_EXTENSION)); + if (ext == NULL) { + return 0; + } + ext->add_callback = add_cb; + ext->add_arg = add_arg; + ext->free_callback = free_cb; + ext->parse_callback = parse_cb; + ext->parse_arg = parse_arg; + ext->value = extension_value; + + if (*stack == NULL) { + *stack = sk_SSL_CUSTOM_EXTENSION_new_null(); + if (*stack == NULL) { + SSL_CUSTOM_EXTENSION_free(ext); + return 0; + } + } + + if (!sk_SSL_CUSTOM_EXTENSION_push(*stack, ext)) { + SSL_CUSTOM_EXTENSION_free(ext); + if (sk_SSL_CUSTOM_EXTENSION_num(*stack) == 0) { + sk_SSL_CUSTOM_EXTENSION_free(*stack); + *stack = NULL; + } + return 0; + } + + return 1; +} + +int SSL_CTX_add_client_custom_ext(SSL_CTX *ctx, unsigned extension_value, + SSL_custom_ext_add_cb add_cb, + SSL_custom_ext_free_cb free_cb, void *add_arg, + SSL_custom_ext_parse_cb parse_cb, + void *parse_arg) { + return custom_ext_append(&ctx->client_custom_extensions, extension_value, + add_cb ? add_cb : default_add_callback, free_cb, + add_arg, parse_cb, parse_arg); +} + +int SSL_CTX_add_server_custom_ext(SSL_CTX *ctx, unsigned extension_value, + SSL_custom_ext_add_cb add_cb, + SSL_custom_ext_free_cb free_cb, void *add_arg, + SSL_custom_ext_parse_cb parse_cb, + void *parse_arg) { + return custom_ext_append(&ctx->server_custom_extensions, extension_value, + add_cb ? add_cb : default_add_callback, free_cb, + add_arg, parse_cb, parse_arg); +}
diff --git a/ssl/internal.h b/ssl/internal.h index a63c0cd..f19a265 100644 --- a/ssl/internal.h +++ b/ssl/internal.h
@@ -364,6 +364,29 @@ SSL *ssl, uint8_t *out, size_t *out_len, size_t max_out); +/* Custom extensions */ + +/* ssl_custom_extension (a.k.a. SSL_CUSTOM_EXTENSION) is a structure that + * contains information about custom-extension callbacks. */ +struct ssl_custom_extension { + SSL_custom_ext_add_cb add_callback; + void *add_arg; + SSL_custom_ext_free_cb free_callback; + SSL_custom_ext_parse_cb parse_callback; + void *parse_arg; + uint16_t value; +}; + +void SSL_CUSTOM_EXTENSION_free(SSL_CUSTOM_EXTENSION *custom_extension); + +int custom_ext_add_clienthello(SSL *ssl, CBB *extensions); +int custom_ext_parse_serverhello(SSL *ssl, int *out_alert, uint16_t value, + const CBS *extension); +int custom_ext_parse_clienthello(SSL *ssl, int *out_alert, uint16_t value, + const CBS *extension); +int custom_ext_add_serverhello(SSL *ssl, CBB *extensions); + + /* Underdocumented functions. * * Functions below here haven't been touched up and may be underdocumented. */
diff --git a/ssl/ssl_lib.c b/ssl/ssl_lib.c index 61f0626..a7d5d5b 100644 --- a/ssl/ssl_lib.c +++ b/ssl/ssl_lib.c
@@ -1767,6 +1767,10 @@ sk_SSL_CIPHER_free(ctx->cipher_list_by_id); ssl_cipher_preference_list_free(ctx->cipher_list_tls11); ssl_cert_free(ctx->cert); + sk_SSL_CUSTOM_EXTENSION_pop_free(ctx->client_custom_extensions, + SSL_CUSTOM_EXTENSION_free); + sk_SSL_CUSTOM_EXTENSION_pop_free(ctx->server_custom_extensions, + SSL_CUSTOM_EXTENSION_free); sk_X509_NAME_pop_free(ctx->client_CA, X509_NAME_free); sk_SRTP_PROTECTION_PROFILE_free(ctx->srtp_profiles); OPENSSL_free(ctx->psk_identity_hint);
diff --git a/ssl/t1_lib.c b/ssl/t1_lib.c index bc3e556..296ab80 100644 --- a/ssl/t1_lib.c +++ b/ssl/t1_lib.c
@@ -2235,6 +2235,12 @@ return NULL; } +int SSL_extension_supported(unsigned extension_value) { + uint32_t index; + return extension_value == TLSEXT_TYPE_padding || + tls_extension_find(&index, extension_value) != NULL; +} + /* header_len is the length of the ClientHello header written so far, used to * compute padding. It does not include the record header. Pass 0 if no padding * is to be done. */ @@ -2253,6 +2259,7 @@ } s->s3->tmp.extensions.sent = 0; + s->s3->tmp.custom_extensions.sent = 0; size_t i; for (i = 0; i < kNumExtensions; i++) { @@ -2274,6 +2281,10 @@ } } + if (!custom_ext_add_clienthello(s, &extensions)) { + goto err; + } + if (header_len > 0) { header_len += CBB_len(&extensions); if (header_len > 0xff && header_len < 0x200) { @@ -2353,6 +2364,10 @@ } } + if (!custom_ext_add_serverhello(s, &extensions)) { + goto err; + } + if (!CBB_flush(&cbb)) { goto err; } @@ -2384,6 +2399,7 @@ } s->s3->tmp.extensions.received = 0; + s->s3->tmp.custom_extensions.received = 0; /* The renegotiation extension must always be at index zero because the * |received| and |sent| bitsets need to be tweaked when the "extension" is * sent as an SCSV. */ @@ -2415,6 +2431,10 @@ tls_extension_find(&ext_index, type); if (ext == NULL) { + if (!custom_ext_parse_clienthello(s, out_alert, type, &extension)) { + OPENSSL_PUT_ERROR(SSL, SSL_R_ERROR_PARSING_EXTENSION); + return 0; + } continue; } @@ -2490,11 +2510,15 @@ const struct tls_extension *const ext = tls_extension_find(&ext_index, type); - if (/* If ext == NULL then an unknown extension was received. Since we - * cannot have sent an unknown extension, this is illegal. */ - ext == NULL || - /* If the extension was never sent then it is also illegal. */ - !(s->s3->tmp.extensions.sent & (1u << ext_index))) { + if (ext == NULL) { + if (!custom_ext_parse_serverhello(s, out_alert, type, &extension)) { + return 0; + } + continue; + } + + if (!(s->s3->tmp.extensions.sent & (1u << ext_index))) { + /* If the extension was never sent then it is illegal. */ OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_EXTENSION); ERR_add_error_dataf("extension :%u", (unsigned)type); *out_alert = SSL_AD_DECODE_ERROR;
diff --git a/ssl/test/bssl_shim.cc b/ssl/test/bssl_shim.cc index add45ae..571680b 100644 --- a/ssl/test/bssl_shim.cc +++ b/ssl/test/bssl_shim.cc
@@ -103,7 +103,7 @@ }; static void TestStateExFree(void *parent, void *ptr, CRYPTO_EX_DATA *ad, - int index, long argl, void *argp) { + int index, long argl, void *argp) { delete ((TestState *)ptr); } @@ -470,6 +470,63 @@ return 1; } +// kCustomExtensionValue is the extension value that the custom extension +// callbacks will add. +constexpr uint16_t kCustomExtensionValue = 1234; +static void *const kCustomExtensionAddArg = + reinterpret_cast<void *>(kCustomExtensionValue); +static void *const kCustomExtensionParseArg = + reinterpret_cast<void *>(kCustomExtensionValue + 1); +static const char kCustomExtensionContents[] = "custom extension"; + +static int CustomExtensionAddCallback(SSL *ssl, unsigned extension_value, + const uint8_t **out, size_t *out_len, + int *out_alert_value, void *add_arg) { + if (extension_value != kCustomExtensionValue || + add_arg != kCustomExtensionAddArg) { + abort(); + } + + if (GetConfigPtr(ssl)->custom_extension_skip) { + return 0; + } + if (GetConfigPtr(ssl)->custom_extension_fail_add) { + return -1; + } + + *out = reinterpret_cast<const uint8_t*>(kCustomExtensionContents); + *out_len = sizeof(kCustomExtensionContents) - 1; + + return 1; +} + +static void CustomExtensionFreeCallback(SSL *ssl, unsigned extension_value, + const uint8_t *out, void *add_arg) { + if (extension_value != kCustomExtensionValue || + add_arg != kCustomExtensionAddArg || + out != reinterpret_cast<const uint8_t *>(kCustomExtensionContents)) { + abort(); + } +} + +static int CustomExtensionParseCallback(SSL *ssl, unsigned extension_value, + const uint8_t *contents, + size_t contents_len, + int *out_alert_value, void *parse_arg) { + if (extension_value != kCustomExtensionValue || + parse_arg != kCustomExtensionParseArg) { + abort(); + } + + if (contents_len != sizeof(kCustomExtensionContents) - 1 || + memcmp(contents, kCustomExtensionContents, contents_len) != 0) { + *out_alert_value = SSL_AD_DECODE_ERROR; + return 0; + } + + return 1; +} + // Connect returns a new socket connected to localhost on |port| or -1 on // error. static int Connect(uint16_t port) { @@ -579,6 +636,22 @@ SSL_CTX_set_tlsext_ticket_key_cb(ssl_ctx.get(), TicketKeyCallback); } + if (config->enable_client_custom_extension && + !SSL_CTX_add_client_custom_ext( + ssl_ctx.get(), kCustomExtensionValue, CustomExtensionAddCallback, + CustomExtensionFreeCallback, kCustomExtensionAddArg, + CustomExtensionParseCallback, kCustomExtensionParseArg)) { + return nullptr; + } + + if (config->enable_server_custom_extension && + !SSL_CTX_add_server_custom_ext( + ssl_ctx.get(), kCustomExtensionValue, CustomExtensionAddCallback, + CustomExtensionFreeCallback, kCustomExtensionAddArg, + CustomExtensionParseCallback, kCustomExtensionParseArg)) { + return nullptr; + } + return ssl_ctx; }
diff --git a/ssl/test/runner/common.go b/ssl/test/runner/common.go index fb78ef1..7b7a35b 100644 --- a/ssl/test/runner/common.go +++ b/ssl/test/runner/common.go
@@ -82,6 +82,7 @@ extensionSignedCertificateTimestamp uint16 = 18 extensionExtendedMasterSecret uint16 = 23 extensionSessionTicket uint16 = 35 + extensionCustom uint16 = 1234 // not IANA assigned extensionNextProtoNeg uint16 = 13172 // not IANA assigned extensionRenegotiationInfo uint16 = 0xff01 extensionChannelID uint16 = 30032 // not IANA assigned @@ -736,6 +737,14 @@ // RequireClientHelloSize, if not zero, is the required length in bytes // of the ClientHello /record/. This is checked by the server. RequireClientHelloSize int + + // CustomExtension, if not empty, contains the contents of an extension + // that will be added to client/server hellos. + CustomExtension string + + // ExpectedCustomExtension, if not nil, contains the expected contents + // of a custom extension. + ExpectedCustomExtension *string } func (c *Config) serverInit() {
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go index bc10fe7..a96cd9c 100644 --- a/ssl/test/runner/handshake_client.go +++ b/ssl/test/runner/handshake_client.go
@@ -73,6 +73,7 @@ extendedMasterSecret: c.config.maxVersion() >= VersionTLS10, srtpProtectionProfiles: c.config.SRTPProtectionProfiles, srtpMasterKeyIdentifier: c.config.Bugs.SRTPMasterKeyIdentifer, + customExtension: c.config.Bugs.CustomExtension, } if c.config.Bugs.SendClientVersion != 0 { @@ -290,6 +291,12 @@ } } + if expected := c.config.Bugs.ExpectedCustomExtension; expected != nil { + if serverHello.customExtension != *expected { + return fmt.Errorf("tls: bad custom extension contents %q", serverHello.customExtension) + } + } + hs := &clientHandshakeState{ c: c, serverHello: serverHello,
diff --git a/ssl/test/runner/handshake_messages.go b/ssl/test/runner/handshake_messages.go index 46ff2fd..92f603a 100644 --- a/ssl/test/runner/handshake_messages.go +++ b/ssl/test/runner/handshake_messages.go
@@ -32,6 +32,7 @@ srtpProtectionProfiles []uint16 srtpMasterKeyIdentifier string sctListSupported bool + customExtension string } func (m *clientHelloMsg) equal(i interface{}) bool { @@ -65,7 +66,8 @@ m.extendedMasterSecret == m1.extendedMasterSecret && eqUint16s(m.srtpProtectionProfiles, m1.srtpProtectionProfiles) && m.srtpMasterKeyIdentifier == m1.srtpMasterKeyIdentifier && - m.sctListSupported == m1.sctListSupported + m.sctListSupported == m1.sctListSupported && + m.customExtension == m1.customExtension } func (m *clientHelloMsg) marshal() []byte { @@ -138,6 +140,10 @@ if m.sctListSupported { numExtensions++ } + if l := len(m.customExtension); l > 0 { + extensionsLength += l + numExtensions++ + } if numExtensions > 0 { extensionsLength += 4 * numExtensions length += 2 + extensionsLength @@ -376,6 +382,14 @@ z[1] = byte(extensionSignedCertificateTimestamp & 0xff) z = z[4:] } + if l := len(m.customExtension); l > 0 { + z[0] = byte(extensionCustom >> 8) + z[1] = byte(extensionCustom & 0xff) + z[2] = byte(l >> 8) + z[3] = byte(l & 0xff) + copy(z[4:], []byte(m.customExtension)) + z = z[4 + l:] + } m.raw = x @@ -443,6 +457,7 @@ m.signatureAndHashes = nil m.alpnProtocols = nil m.extendedMasterSecret = false + m.customExtension = "" if len(data) == 0 { // ClientHello is optionally followed by extension data @@ -604,6 +619,8 @@ return false } m.sctListSupported = true + case extensionCustom: + m.customExtension = string(data[:length]) } data = data[length:] } @@ -632,6 +649,7 @@ srtpProtectionProfile uint16 srtpMasterKeyIdentifier string sctList []byte + customExtension string } func (m *serverHelloMsg) marshal() []byte { @@ -686,6 +704,10 @@ extensionsLength += len(m.sctList) numExtensions++ } + if l := len(m.customExtension); l > 0 { + extensionsLength += l + numExtensions++ + } if numExtensions > 0 { extensionsLength += 4 * numExtensions @@ -811,6 +833,14 @@ copy(z[4:], m.sctList) z = z[4+l:] } + if l := len(m.customExtension); l > 0 { + z[0] = byte(extensionCustom >> 8) + z[1] = byte(extensionCustom & 0xff) + z[2] = byte(l >> 8) + z[3] = byte(l & 0xff) + copy(z[4:], []byte(m.customExtension)) + z = z[4 + l:] + } m.raw = x @@ -844,6 +874,7 @@ m.alpnProtocol = "" m.alpnProtocolEmpty = false m.extendedMasterSecret = false + m.customExtension = "" if len(data) == 0 { // ServerHello is optionally followed by extension data @@ -948,6 +979,8 @@ return false } m.sctList = data[2:length] + case extensionCustom: + m.customExtension = string(data[:length]) } data = data[length:] }
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go index 5d37674..7686402 100644 --- a/ssl/test/runner/handshake_server.go +++ b/ssl/test/runner/handshake_server.go
@@ -210,8 +210,10 @@ } c.haveVers = true - hs.hello = new(serverHelloMsg) - hs.hello.isDTLS = c.isDTLS + hs.hello = &serverHelloMsg { + isDTLS: c.isDTLS, + customExtension: config.Bugs.CustomExtension, + } supportedCurve := false preferredCurves := config.curvePreferences() @@ -340,6 +342,12 @@ hs.hello.srtpProtectionProfile = c.config.Bugs.SendSRTPProtectionProfile } + if expected := c.config.Bugs.ExpectedCustomExtension; expected != nil { + if hs.clientHello.customExtension != *expected { + return false, fmt.Errorf("tls: bad custom extension contents %q", hs.clientHello.customExtension) + } + } + _, hs.ecdsaOk = hs.cert.PrivateKey.(*ecdsa.PrivateKey) // For test purposes, check that the peer never offers a session when
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go index ff10c05..d66dc74 100644 --- a/ssl/test/runner/runner.go +++ b/ssl/test/runner/runner.go
@@ -3627,6 +3627,97 @@ } } +func addCustomExtensionTests() { + expectedContents := "custom extension" + emptyString := "" + + for _, isClient := range []bool{false, true} { + suffix := "Server" + flag := "-enable-server-custom-extension" + testType := serverTest + if isClient { + suffix = "Client" + flag = "-enable-client-custom-extension" + testType = clientTest + } + + testCases = append(testCases, testCase{ + testType: testType, + name: "CustomExtensions-" + suffix, + config: Config{ + Bugs: ProtocolBugs { + CustomExtension: expectedContents, + ExpectedCustomExtension: &expectedContents, + }, + }, + flags: []string{flag}, + }) + + // If the parse callback fails, the handshake should also fail. + testCases = append(testCases, testCase{ + testType: testType, + name: "CustomExtensions-ParseError-" + suffix, + config: Config{ + Bugs: ProtocolBugs { + CustomExtension: expectedContents + "foo", + ExpectedCustomExtension: &expectedContents, + }, + }, + flags: []string{flag}, + shouldFail: true, + expectedError: ":CUSTOM_EXTENSION_ERROR:", + }) + + // If the add callback fails, the handshake should also fail. + testCases = append(testCases, testCase{ + testType: testType, + name: "CustomExtensions-FailAdd-" + suffix, + config: Config{ + Bugs: ProtocolBugs { + CustomExtension: expectedContents, + ExpectedCustomExtension: &expectedContents, + }, + }, + flags: []string{flag, "-custom-extension-fail-add"}, + shouldFail: true, + expectedError: ":CUSTOM_EXTENSION_ERROR:", + }) + + // If the add callback returns zero, no extension should be + // added. + skipCustomExtension := expectedContents + if isClient { + // For the case where the client skips sending the + // custom extension, the server must not “echo” it. + skipCustomExtension = "" + } + testCases = append(testCases, testCase{ + testType: testType, + name: "CustomExtensions-Skip-" + suffix, + config: Config{ + Bugs: ProtocolBugs { + CustomExtension: skipCustomExtension, + ExpectedCustomExtension: &emptyString, + }, + }, + flags: []string{flag, "-custom-extension-skip"}, + }) + } + + // The custom extension add callback should not be called if the client + // doesn't send the extension. + testCases = append(testCases, testCase{ + testType: serverTest, + name: "CustomExtensions-NotCalled-Server", + config: Config{ + Bugs: ProtocolBugs { + ExpectedCustomExtension: &emptyString, + }, + }, + flags: []string{"-enable-server-custom-extension", "-custom-extension-fail-add"}, + }) +} + func worker(statusChan chan statusMsg, c chan *testCase, shimPath string, wg *sync.WaitGroup) { defer wg.Done() @@ -3723,6 +3814,7 @@ addDTLSRetransmitTests() addExportKeyingMaterialTests() addTLSUniqueTests() + addCustomExtensionTests() for _, async := range []bool{false, true} { for _, splitHandshake := range []bool{false, true} { for _, protocol := range []protocol{tls, dtls} {
diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc index fef51ed..d873d8e 100644 --- a/ssl/test/test_config.cc +++ b/ssl/test/test_config.cc
@@ -84,6 +84,12 @@ { "-expect-no-session", &TestConfig::expect_no_session }, { "-use-ticket-callback", &TestConfig::use_ticket_callback }, { "-renew-ticket", &TestConfig::renew_ticket }, + { "-enable-client-custom-extension", + &TestConfig::enable_client_custom_extension }, + { "-enable-server-custom-extension", + &TestConfig::enable_server_custom_extension }, + { "-custom-extension-skip", &TestConfig::custom_extension_skip }, + { "-custom-extension-fail-add", &TestConfig::custom_extension_fail_add }, }; const Flag<std::string> kStringFlags[] = {
diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h index 67655f4..29a1c77 100644 --- a/ssl/test/test_config.h +++ b/ssl/test/test_config.h
@@ -82,6 +82,10 @@ bool expect_no_session = false; bool use_ticket_callback = false; bool renew_ticket = false; + bool enable_client_custom_extension = false; + bool enable_server_custom_extension = false; + bool custom_extension_skip = false; + bool custom_extension_fail_add = false; }; bool ParseConfig(int argc, char **argv, TestConfig *out_config);