Standardize on Init vs InitForOverwrite for value vs default initialization
C++ is fun and has two notions of "default" initialization, `new T` and
`new T()`. These are default initialization and value initialization,
respectively.
They are identical except that POD types are uninit when
default-initialized and zero when value-initialized. InplaceVector
picked the safer option by default and called the other one
FooMaybeUninit. Array is older and uses the less safe one (it's almost
always the one we want; we usually allocate an array to immediately fill
it).
While MaybeUninit does capture what you do with it, it is slightly
ambiguous, as seen in Array's internal implementation: uninitialized
could also mean we haven't gotten around to initialize it at all. I.e.
we need to use a function like std::uninitialized_value_construct_n
instead of normal functions in <algorithm>.
C++20 has std::make_unique and std::make_unique_for_overwrite to capture
the two. This seems as fine a naming convention as any, so switch to it.
Along the way, make the internal bssl::Array default to the safer one.
This lets us remove a couple of memset(0)'s.
Change-Id: I32cede231da051a854e6251e10b87f8e4dd06ee6
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/72268
Reviewed-by: Nick Harper <nharper@chromium.org>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/ssl/d1_both.cc b/ssl/d1_both.cc
index f8ac34a..e0c2614 100644
--- a/ssl/d1_both.cc
+++ b/ssl/d1_both.cc
@@ -785,7 +785,7 @@
dtls1_update_mtu(ssl);
Array<uint8_t> packet;
- if (!packet.Init(ssl->d1->mtu)) {
+ if (!packet.InitForOverwrite(ssl->d1->mtu)) {
return -1;
}
diff --git a/ssl/encrypted_client_hello.cc b/ssl/encrypted_client_hello.cc
index 8c4a42c..f9686fe 100644
--- a/ssl/encrypted_client_hello.cc
+++ b/ssl/encrypted_client_hello.cc
@@ -303,7 +303,7 @@
return false;
}
#else
- if (!encoded.Init(payload.size())) {
+ if (!encoded.InitForOverwrite(payload.size())) {
*out_alert = SSL_AD_INTERNAL_ERROR;
return false;
}
diff --git a/ssl/extensions.cc b/ssl/extensions.cc
index c9424c9..ad4bf81 100644
--- a/ssl/extensions.cc
+++ b/ssl/extensions.cc
@@ -175,7 +175,7 @@
}
Array<uint16_t> extension_types;
- if (!extension_types.Init(num_extensions)) {
+ if (!extension_types.InitForOverwrite(num_extensions)) {
return false;
}
@@ -2526,7 +2526,7 @@
}
Array<uint16_t> ret;
- if (!ret.Init(CBS_len(©) / 2)) {
+ if (!ret.InitForOverwrite(CBS_len(©) / 2)) {
return false;
}
for (size_t i = 0; i < ret.size(); i++) {
@@ -2878,7 +2878,7 @@
const size_t num_given_alg_ids = CBS_len(&alg_ids) / 2;
Array<uint16_t> given_alg_ids;
- if (!given_alg_ids.Init(num_given_alg_ids)) {
+ if (!given_alg_ids.InitForOverwrite(num_given_alg_ids)) {
return false;
}
@@ -3352,7 +3352,7 @@
uint32_t seeds[kNumExtensions - 1];
Array<uint8_t> permutation;
if (!RAND_bytes(reinterpret_cast<uint8_t *>(seeds), sizeof(seeds)) ||
- !permutation.Init(kNumExtensions)) {
+ !permutation.InitForOverwrite(kNumExtensions)) {
return false;
}
for (size_t i = 0; i < kNumExtensions; i++) {
@@ -3918,7 +3918,7 @@
if (ciphertext.size() >= INT_MAX) {
return ssl_ticket_aead_ignore_ticket;
}
- if (!plaintext.Init(ciphertext.size())) {
+ if (!plaintext.InitForOverwrite(ciphertext.size())) {
return ssl_ticket_aead_error;
}
int len1, len2;
@@ -4006,7 +4006,7 @@
SSL_HANDSHAKE *hs, Array<uint8_t> *out, bool *out_renew_ticket,
Span<const uint8_t> ticket) {
Array<uint8_t> plaintext;
- if (!plaintext.Init(ticket.size())) {
+ if (!plaintext.InitForOverwrite(ticket.size())) {
return ssl_ticket_aead_error;
}
@@ -4115,7 +4115,7 @@
// Envoy's tests expect the session to have a session ID that matches the
// placeholder used by the client. It's unclear whether this is a good idea,
// but we maintain it for now.
- session->session_id.ResizeMaybeUninit(SHA256_DIGEST_LENGTH);
+ session->session_id.ResizeForOverwrite(SHA256_DIGEST_LENGTH);
SHA256(ticket.data(), ticket.size(), session->session_id.data());
*out_session = std::move(session);
@@ -4356,7 +4356,7 @@
}
size_t digest_len;
- hs->new_session->original_handshake_hash.ResizeMaybeUninit(
+ hs->new_session->original_handshake_hash.ResizeForOverwrite(
hs->transcript.DigestLen());
if (!hs->transcript.GetHash(hs->new_session->original_handshake_hash.data(),
&digest_len)) {
diff --git a/ssl/handoff.cc b/ssl/handoff.cc
index 6f5c7e2..d317269 100644
--- a/ssl/handoff.cc
+++ b/ssl/handoff.cc
@@ -176,7 +176,7 @@
return false;
}
Array<uint16_t> supported_groups;
- if (!supported_groups.Init(CBS_len(&groups) / 2)) {
+ if (!supported_groups.InitForOverwrite(CBS_len(&groups) / 2)) {
return false;
}
size_t idx = 0;
@@ -190,7 +190,7 @@
Span<const uint16_t> configured_groups =
tls1_get_grouplist(ssl->s3->hs.get());
Array<uint16_t> new_configured_groups;
- if (!new_configured_groups.Init(configured_groups.size())) {
+ if (!new_configured_groups.InitForOverwrite(configured_groups.size())) {
return false;
}
idx = 0;
diff --git a/ssl/handshake_client.cc b/ssl/handshake_client.cc
index f87e000..031ed81 100644
--- a/ssl/handshake_client.cc
+++ b/ssl/handshake_client.cc
@@ -542,7 +542,7 @@
if (has_id_session) {
hs->session_id = ssl->session->session_id;
} else if (ticket_session_requires_random_id || enable_compatibility_mode) {
- hs->session_id.ResizeMaybeUninit(SSL_MAX_SSL_SESSION_ID_LENGTH);
+ hs->session_id.ResizeForOverwrite(SSL_MAX_SSL_SESSION_ID_LENGTH);
if (!RAND_bytes(hs->session_id.data(), hs->session_id.size())) {
return ssl_hs_error;
}
@@ -1528,16 +1528,15 @@
// Depending on the key exchange method, compute |pms|.
if (alg_k & SSL_kRSA) {
- if (!pms.Init(SSL_MAX_MASTER_KEY_LENGTH)) {
- return ssl_hs_error;
- }
-
RSA *rsa = EVP_PKEY_get0_RSA(hs->peer_pubkey.get());
if (rsa == NULL) {
OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
return ssl_hs_error;
}
+ if (!pms.InitForOverwrite(SSL_MAX_MASTER_KEY_LENGTH)) {
+ return ssl_hs_error;
+ }
pms[0] = hs->client_version >> 8;
pms[1] = hs->client_version & 0xff;
if (!RAND_bytes(&pms[2], SSL_MAX_MASTER_KEY_LENGTH - 2)) {
@@ -1581,7 +1580,6 @@
if (!pms.Init(psk_len)) {
return ssl_hs_error;
}
- OPENSSL_memset(pms.data(), 0, pms.size());
} else {
ssl_send_alert(ssl, SSL3_AL_FATAL, SSL_AD_HANDSHAKE_FAILURE);
OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
@@ -1609,7 +1607,7 @@
return ssl_hs_error;
}
- hs->new_session->secret.ResizeMaybeUninit(SSL3_MASTER_SECRET_SIZE);
+ hs->new_session->secret.ResizeForOverwrite(SSL3_MASTER_SECRET_SIZE);
if (!tls1_generate_master_secret(hs, MakeSpan(hs->new_session->secret),
pms)) {
return ssl_hs_error;
@@ -1850,7 +1848,7 @@
// Historically, OpenSSL filled in fake session IDs for ticket-based sessions.
// TODO(davidben): Are external callers relying on this? Try removing this.
- hs->new_session->session_id.ResizeMaybeUninit(SHA256_DIGEST_LENGTH);
+ hs->new_session->session_id.ResizeForOverwrite(SHA256_DIGEST_LENGTH);
SHA256(CBS_data(&ticket), CBS_len(&ticket),
hs->new_session->session_id.data());
diff --git a/ssl/handshake_server.cc b/ssl/handshake_server.cc
index 0eb037b..1b292a7 100644
--- a/ssl/handshake_server.cc
+++ b/ssl/handshake_server.cc
@@ -941,7 +941,7 @@
// Assign a session ID if not using session tickets.
if (!hs->ticket_expected &&
(ssl->ctx->session_cache_mode & SSL_SESS_CACHE_SERVER)) {
- hs->new_session->session_id.ResizeMaybeUninit(SSL3_SSL_SESSION_ID_LENGTH);
+ hs->new_session->session_id.ResizeForOverwrite(SSL3_SSL_SESSION_ID_LENGTH);
RAND_bytes(hs->new_session->session_id.data(),
hs->new_session->session_id.size());
}
@@ -1464,7 +1464,8 @@
// Allocate a buffer large enough for an RSA decryption.
Array<uint8_t> decrypt_buf;
- if (!decrypt_buf.Init(EVP_PKEY_size(hs->credential->pubkey.get()))) {
+ if (!decrypt_buf.InitForOverwrite(
+ EVP_PKEY_size(hs->credential->pubkey.get()))) {
return ssl_hs_error;
}
@@ -1492,7 +1493,7 @@
// Prepare a random premaster, to be used on invalid padding. See RFC 5246,
// section 7.4.7.1.
- if (!premaster_secret.Init(SSL_MAX_MASTER_KEY_LENGTH) ||
+ if (!premaster_secret.InitForOverwrite(SSL_MAX_MASTER_KEY_LENGTH) ||
!RAND_bytes(premaster_secret.data(), premaster_secret.size())) {
return ssl_hs_error;
}
@@ -1583,7 +1584,6 @@
if (!premaster_secret.Init(psk_len)) {
return ssl_hs_error;
}
- OPENSSL_memset(premaster_secret.data(), 0, premaster_secret.size());
}
ScopedCBB new_premaster;
@@ -1605,7 +1605,7 @@
}
// Compute the master secret.
- hs->new_session->secret.ResizeMaybeUninit(SSL3_MASTER_SECRET_SIZE);
+ hs->new_session->secret.ResizeForOverwrite(SSL3_MASTER_SECRET_SIZE);
if (!tls1_generate_master_secret(hs, MakeSpan(hs->new_session->secret),
premaster_secret)) {
return ssl_hs_error;
diff --git a/ssl/internal.h b/ssl/internal.h
index c4a36ac..4a1f84c 100644
--- a/ssl/internal.h
+++ b/ssl/internal.h
@@ -334,14 +334,24 @@
}
// Init replaces the array with a newly-allocated array of |new_size|
- // default-constructed copies of |T|. It returns true on success and false on
- // error.
- //
- // Note that if |T| is a primitive type like |uint8_t|, it is uninitialized.
+ // value-constructed copies of |T|. It returns true on success and false on
+ // error. If |T| is a primitive type like |uint8_t|, value-construction means
+ // it will be zero-initialized.
bool Init(size_t new_size) {
if (!InitUninitialized(new_size)) {
return false;
}
+ cxx17_uninitialized_value_construct_n(data_, size_);
+ return true;
+ }
+
+ // InitForOverwrite behaves like |Init| but it default-constructs each element
+ // instead. This means that, if |T| is a primitive type, the array will be
+ // uninitialized and thus must be filled in by the caller.
+ bool InitForOverwrite(size_t new_size) {
+ if (!InitUninitialized(new_size)) {
+ return false;
+ }
cxx17_uninitialized_default_construct_n(data_, size_);
return true;
}
@@ -585,10 +595,10 @@
return true;
}
- // TryResizeMaybeUninit behaves like |TryResize|, but newly-added elements are
- // default-initialized, so POD types may contain uninitialized values that the
- // caller is responsible for filling in.
- bool TryResizeMaybeUninit(size_t new_size) {
+ // TryResizeForOverwrite behaves like |TryResize|, but newly-added elements
+ // are default-initialized, so POD types may contain uninitialized values that
+ // the caller is responsible for filling in.
+ bool TryResizeForOverwrite(size_t new_size) {
if (new_size <= size_) {
Shrink(new_size);
return true;
@@ -628,8 +638,8 @@
// The following methods behave like their |Try*| counterparts, but abort the
// program on failure.
void Resize(size_t size) { BSSL_CHECK(TryResize(size)); }
- void ResizeMaybeUninit(size_t size) {
- BSSL_CHECK(TryResizeMaybeUninit(size));
+ void ResizeForOverwrite(size_t size) {
+ BSSL_CHECK(TryResizeForOverwrite(size));
}
void CopyFrom(Span<const T> in) { BSSL_CHECK(TryCopyFrom(in)); }
T &PushBack(T val) {
diff --git a/ssl/ssl_cipher.cc b/ssl/ssl_cipher.cc
index 97e69ff..48f90a1 100644
--- a/ssl/ssl_cipher.cc
+++ b/ssl/ssl_cipher.cc
@@ -919,7 +919,6 @@
if (!number_uses.Init(max_strength_bits + 1)) {
return false;
}
- OPENSSL_memset(number_uses.data(), 0, (max_strength_bits + 1) * sizeof(int));
// Now find the strength_bits values actually used.
curr = *head_p;
@@ -1231,7 +1230,7 @@
UniquePtr<STACK_OF(SSL_CIPHER)> cipherstack(sk_SSL_CIPHER_new_null());
Array<bool> in_group_flags;
if (cipherstack == nullptr ||
- !in_group_flags.Init(OPENSSL_ARRAY_SIZE(kCiphers))) {
+ !in_group_flags.InitForOverwrite(OPENSSL_ARRAY_SIZE(kCiphers))) {
return false;
}
@@ -1246,13 +1245,11 @@
in_group_flags[num_in_group_flags++] = curr->in_group;
}
}
+ in_group_flags.Shrink(num_in_group_flags);
UniquePtr<SSLCipherPreferenceList> pref_list =
MakeUnique<SSLCipherPreferenceList>();
- if (!pref_list ||
- !pref_list->Init(
- std::move(cipherstack),
- MakeConstSpan(in_group_flags).subspan(0, num_in_group_flags))) {
+ if (!pref_list || !pref_list->Init(std::move(cipherstack), in_group_flags)) {
return false;
}
diff --git a/ssl/ssl_credential.cc b/ssl/ssl_credential.cc
index d5df82e..b5282b8 100644
--- a/ssl/ssl_credential.cc
+++ b/ssl/ssl_credential.cc
@@ -48,7 +48,7 @@
num_creds++;
}
- if (!out->Init(num_creds)) {
+ if (!out->InitForOverwrite(num_creds)) {
return false;
}
diff --git a/ssl/ssl_internal_test.cc b/ssl/ssl_internal_test.cc
index 258d39d..f5982a2 100644
--- a/ssl/ssl_internal_test.cc
+++ b/ssl/ssl_internal_test.cc
@@ -23,6 +23,15 @@
BSSL_NAMESPACE_BEGIN
namespace {
+TEST(ArrayTest, InitValueConstructs) {
+ Array<uint8_t> array;
+ ASSERT_TRUE(array.Init(10));
+ EXPECT_EQ(array.size(), 10u);
+ for (size_t i = 0; i < 10u; i++) {
+ EXPECT_EQ(0u, array[i]);
+ }
+}
+
TEST(ArrayDeathTest, BoundsChecks) {
Array<int> array;
const int v[] = {1, 2, 3, 4};
@@ -345,7 +354,7 @@
EXPECT_DEATH_IF_SUPPORTED(vec[1000], "");
// The vector cannot be resized past the capacity.
EXPECT_DEATH_IF_SUPPORTED(vec.Resize(5), "");
- EXPECT_DEATH_IF_SUPPORTED(vec.ResizeMaybeUninit(5), "");
+ EXPECT_DEATH_IF_SUPPORTED(vec.ResizeForOverwrite(5), "");
int too_much_data[] = {1, 2, 3, 4, 5};
EXPECT_DEATH_IF_SUPPORTED(vec.CopyFrom(too_much_data), "");
vec.Resize(4);
diff --git a/ssl/ssl_key_share.cc b/ssl/ssl_key_share.cc
index 88144b0..0342e62 100644
--- a/ssl/ssl_key_share.cc
+++ b/ssl/ssl_key_share.cc
@@ -108,7 +108,7 @@
// Encode the x-coordinate left-padded with zeros.
Array<uint8_t> secret;
- if (!secret.Init((EC_GROUP_get_degree(group_) + 7) / 8) ||
+ if (!secret.InitForOverwrite((EC_GROUP_get_degree(group_) + 7) / 8) ||
!BN_bn2bin_padded(secret.data(), secret.size(), x.get())) {
return false;
}
@@ -162,7 +162,7 @@
*out_alert = SSL_AD_INTERNAL_ERROR;
Array<uint8_t> secret;
- if (!secret.Init(32)) {
+ if (!secret.InitForOverwrite(32)) {
return false;
}
@@ -220,7 +220,7 @@
bool Encap(CBB *out_ciphertext, Array<uint8_t> *out_secret,
uint8_t *out_alert, Span<const uint8_t> peer_key) override {
Array<uint8_t> secret;
- if (!secret.Init(32 + KYBER_SHARED_SECRET_BYTES)) {
+ if (!secret.InitForOverwrite(32 + KYBER_SHARED_SECRET_BYTES)) {
return false;
}
@@ -260,7 +260,7 @@
*out_alert = SSL_AD_INTERNAL_ERROR;
Array<uint8_t> secret;
- if (!secret.Init(32 + KYBER_SHARED_SECRET_BYTES)) {
+ if (!secret.InitForOverwrite(32 + KYBER_SHARED_SECRET_BYTES)) {
return false;
}
@@ -308,7 +308,8 @@
bool Encap(CBB *out_ciphertext, Array<uint8_t> *out_secret,
uint8_t *out_alert, Span<const uint8_t> peer_key) override {
Array<uint8_t> secret;
- if (!secret.Init(MLKEM_SHARED_SECRET_BYTES + X25519_SHARED_KEY_LEN)) {
+ if (!secret.InitForOverwrite(MLKEM_SHARED_SECRET_BYTES +
+ X25519_SHARED_KEY_LEN)) {
return false;
}
@@ -349,7 +350,8 @@
*out_alert = SSL_AD_INTERNAL_ERROR;
Array<uint8_t> secret;
- if (!secret.Init(MLKEM_SHARED_SECRET_BYTES + X25519_SHARED_KEY_LEN)) {
+ if (!secret.InitForOverwrite(MLKEM_SHARED_SECRET_BYTES +
+ X25519_SHARED_KEY_LEN)) {
return false;
}
diff --git a/ssl/ssl_lib.cc b/ssl/ssl_lib.cc
index cbb5cbf..ef4b7d5 100644
--- a/ssl/ssl_lib.cc
+++ b/ssl/ssl_lib.cc
@@ -1976,7 +1976,7 @@
static bool ssl_nids_to_group_ids(Array<uint16_t> *out_group_ids,
Span<const int> nids) {
Array<uint16_t> group_ids;
- if (!group_ids.Init(nids.size())) {
+ if (!group_ids.InitForOverwrite(nids.size())) {
return false;
}
@@ -2018,7 +2018,7 @@
} while (col);
Array<uint16_t> group_ids;
- if (!group_ids.Init(count)) {
+ if (!group_ids.InitForOverwrite(count)) {
return false;
}
diff --git a/ssl/ssl_privkey.cc b/ssl/ssl_privkey.cc
index 76ba084..14a8b67 100644
--- a/ssl/ssl_privkey.cc
+++ b/ssl/ssl_privkey.cc
@@ -603,7 +603,7 @@
// Check for invalid algorithms, and filter out |SSL_SIGN_RSA_PKCS1_MD5_SHA1|.
Array<uint16_t> filtered;
- if (!filtered.Init(prefs.size())) {
+ if (!filtered.InitForOverwrite(prefs.size())) {
return false;
}
size_t added = 0;
@@ -695,7 +695,7 @@
}
const size_t num_pairs = num_values / 2;
- if (!out->Init(num_pairs)) {
+ if (!out->InitForOverwrite(num_pairs)) {
return false;
}
@@ -771,7 +771,7 @@
}
}
- if (!out->Init(num_elements)) {
+ if (!out->InitForOverwrite(num_elements)) {
return false;
}
size_t out_i = 0;
diff --git a/ssl/t1_enc.cc b/ssl/t1_enc.cc
index a13629e..fb9ae89 100644
--- a/ssl/t1_enc.cc
+++ b/ssl/t1_enc.cc
@@ -214,7 +214,7 @@
// Ensure that |key_block_cache| is set up.
const size_t key_block_size = 2 * (mac_secret_len + key_len + iv_len);
if (key_block_cache->empty()) {
- if (!key_block_cache->Init(key_block_size) ||
+ if (!key_block_cache->InitForOverwrite(key_block_size) ||
!generate_key_block(ssl, MakeSpan(*key_block_cache), session)) {
return false;
}
@@ -362,7 +362,7 @@
seed_len += 2 + context_len;
}
Array<uint8_t> seed;
- if (!seed.Init(seed_len)) {
+ if (!seed.InitForOverwrite(seed_len)) {
return 0;
}
diff --git a/ssl/tls13_client.cc b/ssl/tls13_client.cc
index 24da90a..7082d16 100644
--- a/ssl/tls13_client.cc
+++ b/ssl/tls13_client.cc
@@ -1166,7 +1166,7 @@
// Historically, OpenSSL filled in fake session IDs for ticket-based sessions.
// Envoy's tests depend on this, although perhaps they shouldn't.
- session->session_id.ResizeMaybeUninit(SHA256_DIGEST_LENGTH);
+ session->session_id.ResizeForOverwrite(SHA256_DIGEST_LENGTH);
SHA256(CBS_data(&ticket), CBS_len(&ticket), session->session_id.data());
session->ticket_age_add_valid = true;
diff --git a/ssl/tls13_enc.cc b/ssl/tls13_enc.cc
index db59f71..29f07ea 100644
--- a/ssl/tls13_enc.cc
+++ b/ssl/tls13_enc.cc
@@ -168,7 +168,7 @@
return false;
}
- out->ResizeMaybeUninit(transcript.DigestLen());
+ out->ResizeForOverwrite(transcript.DigestLen());
return hkdf_expand_label(MakeSpan(*out), transcript.Digest(), hs->secret,
label, MakeConstSpan(context_hash, context_hash_len),
SSL_is_dtls(hs->ssl));