Check key sizes in AES_set_*_key.

AES_set_*_key used to call directly into aes_nohw_set_*_key which
gracefully handles some NULL parameters and invalid bit sizes. However,
we now enable optimized assembly implementations, not all of which
perform these checks. (vpaes does not.)

This is fine for the internal assembly functions themselves. Such checks
are better written in C than assembly, and the calling C code usually
already knows the key size. (Indeed aes_ctr_set_key already assumes the
assembly functions are infallible.) AES_set_*_key are public APIs,
however. The NULL check is silly, but we should handle length-like
checks in public APIs.

Change-Id: I259ae6b9811ceaa9dc5bd7173d5754ca7079cff8
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/35564
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/fipsmodule/aes/aes.c b/crypto/fipsmodule/aes/aes.c
index 8a1ca31..48d60ee 100644
--- a/crypto/fipsmodule/aes/aes.c
+++ b/crypto/fipsmodule/aes/aes.c
@@ -834,6 +834,9 @@
 }
 
 int AES_set_encrypt_key(const uint8_t *key, unsigned bits, AES_KEY *aeskey) {
+  if (bits != 128 && bits != 192 && bits != 256) {
+    return -2;
+  }
   if (hwaes_capable()) {
     return aes_hw_set_encrypt_key(key, bits, aeskey);
   } else if (vpaes_capable()) {
@@ -844,6 +847,9 @@
 }
 
 int AES_set_decrypt_key(const uint8_t *key, unsigned bits, AES_KEY *aeskey) {
+  if (bits != 128 && bits != 192 && bits != 256) {
+    return -2;
+  }
   if (hwaes_capable()) {
     return aes_hw_set_decrypt_key(key, bits, aeskey);
   } else if (vpaes_capable()) {
diff --git a/crypto/fipsmodule/aes/aes_test.cc b/crypto/fipsmodule/aes/aes_test.cc
index 2222b63..1f9a491 100644
--- a/crypto/fipsmodule/aes/aes_test.cc
+++ b/crypto/fipsmodule/aes/aes_test.cc
@@ -189,6 +189,13 @@
   }
 }
 
+TEST(AESTest, InvalidKeySize) {
+  static const uint8_t kZero[8] = {0};
+  AES_KEY key;
+  EXPECT_LT(AES_set_encrypt_key(kZero, 42, &key), 0);
+  EXPECT_LT(AES_set_decrypt_key(kZero, 42, &key), 0);
+}
+
 #if defined(SUPPORTS_ABI_TEST)
 TEST(AESTest, ABI) {
   for (int bits : {128, 192, 256}) {
diff --git a/include/openssl/aes.h b/include/openssl/aes.h
index 1156585..3606bfc 100644
--- a/include/openssl/aes.h
+++ b/include/openssl/aes.h
@@ -76,18 +76,18 @@
 typedef struct aes_key_st AES_KEY;
 
 // AES_set_encrypt_key configures |aeskey| to encrypt with the |bits|-bit key,
-// |key|.
+// |key|. |key| must point to |bits|/8 bytes. It returns zero on success and a
+// negative number if |bits| is an invalid AES key size.
 //
-// WARNING: unlike other OpenSSL functions, this returns zero on success and a
-// negative number on error.
+// WARNING: this function breaks the usual return value convention.
 OPENSSL_EXPORT int AES_set_encrypt_key(const uint8_t *key, unsigned bits,
                                        AES_KEY *aeskey);
 
 // AES_set_decrypt_key configures |aeskey| to decrypt with the |bits|-bit key,
-// |key|.
+// |key|. |key| must point to |bits|/8 bytes. It returns zero on success and a
+// negative number if |bits| is an invalid AES key size.
 //
-// WARNING: unlike other OpenSSL functions, this returns zero on success and a
-// negative number on error.
+// WARNING: this function breaks the usual return value convention.
 OPENSSL_EXPORT int AES_set_decrypt_key(const uint8_t *key, unsigned bits,
                                        AES_KEY *aeskey);