Simplify constant-time RSA padding check.

(Imported form upstream's 455b65dfab0de51c9f67b3c909311770f2b3f801 and
0d6a11a91f4de238ce533c40bd9507fe5d95f288)

Change-Id: Ia195c7fe753cfa3a7f8c91d2d7b2cd40a547be43
diff --git a/crypto/internal.h b/crypto/internal.h
index f32d369..e32f460 100644
--- a/crypto/internal.h
+++ b/crypto/internal.h
@@ -225,11 +225,22 @@
   return constant_time_is_zero(a ^ b);
 }
 
-/* constant_time_eq_8 acts like constant_time_eq but returns an 8-bit mask. */
+/* constant_time_eq_8 acts like |constant_time_eq| but returns an 8-bit mask. */
 static inline uint8_t constant_time_eq_8(unsigned int a, unsigned int b) {
   return (uint8_t)(constant_time_eq(a, b));
 }
 
+/* constant_time_eq_int acts like |constant_time_eq| but works on int values. */
+static inline unsigned int constant_time_eq_int(int a, int b) {
+  return constant_time_eq((unsigned)(a), (unsigned)(b));
+}
+
+/* constant_time_eq_int_8 acts like |constant_time_eq_int| but returns an 8-bit
+ * mask. */
+static inline uint8_t constant_time_eq_int_8(int a, int b) {
+  return constant_time_eq_8((unsigned)(a), (unsigned)(b));
+}
+
 /* constant_time_select returns (mask & a) | (~mask & b). When |mask| is all 1s
  * or all 0s (as returned by the methods above), the select methods return
  * either |a| (if |mask| is nonzero) or |b| (if |mask| is zero). */
diff --git a/ssl/s3_srvr.c b/ssl/s3_srvr.c
index 3fc40d7..da1dc9f 100644
--- a/ssl/s3_srvr.c
+++ b/ssl/s3_srvr.c
@@ -169,6 +169,7 @@
 #include <openssl/x509.h>
 
 #include "ssl_locl.h"
+#include "../crypto/internal.h"
 #include "../crypto/dh/internal.h"
 
 static const SSL_METHOD *ssl3_get_server_method(int ver)
@@ -1846,8 +1847,7 @@
 		{
 		CBS encrypted_premaster_secret;
 		uint8_t rand_premaster_secret[SSL_MAX_MASTER_KEY_LENGTH];
-		int decrypt_good_mask;
-		uint8_t version_good;
+		uint8_t good;
 		size_t rsa_size, decrypt_len, premaster_index, j;
 
 		pkey=s->cert->pkeys[SSL_PKEY_RSA_ENC].privatekey;
@@ -1930,16 +1930,15 @@
 			goto err;
 			}
 
-		/* Remove the PKCS#1 padding and adjust decrypt_len as
-		 * appropriate. decrypt_good_mask will be zero if the premaster
-		 * if good and non-zero otherwise. */
-		decrypt_good_mask = RSA_message_index_PKCS1_type_2(
-			decrypt_buf, decrypt_len, &premaster_index);
-		decrypt_good_mask--;
+		/* Remove the PKCS#1 padding and adjust |decrypt_len| as
+		 * appropriate. |good| will be 0xff if the premaster is
+		 * acceptable and zero otherwise. */
+		good = constant_time_eq_int_8(
+		    RSA_message_index_PKCS1_type_2(decrypt_buf, decrypt_len, &premaster_index), 1);
 		decrypt_len = decrypt_len - premaster_index;
 
 		/* decrypt_len should be SSL_MAX_MASTER_KEY_LENGTH. */
-		decrypt_good_mask |= decrypt_len ^ SSL_MAX_MASTER_KEY_LENGTH;
+		good &= constant_time_eq_8(decrypt_len, SSL_MAX_MASTER_KEY_LENGTH);
 
 		/* Copy over the unpadded premaster. Whatever the value of
 		 * |decrypt_good_mask|, copy as if the premaster were the right
@@ -1957,43 +1956,20 @@
 		decrypt_buf = NULL;
 
 		/* If the version in the decrypted pre-master secret is correct
-		 * then version_good will be zero. The Klima-Pokorny-Rosa
-		 * extension of Bleichenbacher's attack
+		 * then version_good will be 0xff, otherwise it'll be zero. The
+		 * Klima-Pokorny-Rosa extension of Bleichenbacher's attack
 		 * (http://eprint.iacr.org/2003/052/) exploits the version
 		 * number check as a "bad version oracle". Thus version checks
 		 * are done in constant time and are treated like any other
 		 * decryption error. */
-		version_good = premaster_secret[0] ^ (s->client_version>>8);
-		version_good |= premaster_secret[1] ^ (s->client_version&0xff);
-
-		/* If any bits in version_good are set then they'll poision
-		 * decrypt_good_mask and cause rand_premaster_secret to be
-		 * used. */
-		decrypt_good_mask |= version_good;
-
-		/* decrypt_good_mask will be zero iff decrypt_len ==
-		 * SSL_MAX_MASTER_KEY_LENGTH and the version check passed. We
-		 * fold the bottom 32 bits of it with an OR so that the LSB
-		 * will be zero iff everything is good. This assumes that we'll
-		 * never decrypt a value > 2**31 bytes, which seems safe. */
-		decrypt_good_mask |= decrypt_good_mask >> 16;
-		decrypt_good_mask |= decrypt_good_mask >> 8;
-		decrypt_good_mask |= decrypt_good_mask >> 4;
-		decrypt_good_mask |= decrypt_good_mask >> 2;
-		decrypt_good_mask |= decrypt_good_mask >> 1;
-		/* Now select only the LSB and subtract one. If decrypt_len ==
-		 * SSL_MAX_MASTER_KEY_LENGTH and the version check passed then
-		 * decrypt_good_mask will be all ones. Otherwise it'll be all
-		 * zeros. */
-		decrypt_good_mask &= 1;
-		decrypt_good_mask--;
+		good &= constant_time_eq_8(premaster_secret[0], (unsigned)(s->client_version>>8));
+		good &= constant_time_eq_8(premaster_secret[1], (unsigned)(s->client_version&0xff));
 
 		/* Now copy rand_premaster_secret over premaster_secret using
 		 * decrypt_good_mask. */
 		for (j = 0; j < sizeof(rand_premaster_secret); j++)
 			{
-			premaster_secret[j] = (premaster_secret[j] & decrypt_good_mask) |
-			       (rand_premaster_secret[j] & ~decrypt_good_mask);
+			premaster_secret[j] = constant_time_select_8(good, premaster_secret[j], rand_premaster_secret[j]);
 			}
 
 		premaster_secret_len = sizeof(rand_premaster_secret);