Fix undefined behavior in shifts.
Td4 is an array of u8. A u8 << int promotes the u8 to an int first then shifts.
If the mathematical result of a shift (as modelled by lhs * 2^{rhs}) is not
representable in an integer, behaviour is undefined. In other words, you can't
shift into the sign bit of a signed integer. Fix this by casting to u32
whenever we're shifting left by 24.
(For consistency, cast other shifts, too.)
Caught by -fsanitize=shift
Submitted by Nick Lewycky (Google)
(Imported from upstream's 8b37e5c14f0eddb10c7f91ef91004622d90ef361.)
Change-Id: Id0f98d1d65738533c6ddcc3c21bc38b569d74793
Reviewed-on: https://boringssl-review.googlesource.com/4040
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/crypto/aes/aes.c b/crypto/aes/aes.c
index 97b4fbd..933aa07 100644
--- a/crypto/aes/aes.c
+++ b/crypto/aes/aes.c
@@ -1033,17 +1033,25 @@
#endif /* ?FULL_UNROLL */
/* apply last round and
* map cipher state to byte array block: */
- s0 = (Td4[(t0 >> 24)] << 24) ^ (Td4[(t3 >> 16) & 0xff] << 16) ^
- (Td4[(t2 >> 8) & 0xff] << 8) ^ (Td4[(t1) & 0xff]) ^ rk[0];
+ s0 = ((uint32_t)Td4[(t0 >> 24)] << 24) ^
+ ((uint32_t)Td4[(t3 >> 16) & 0xff] << 16) ^
+ ((uint32_t)Td4[(t2 >> 8) & 0xff] << 8) ^
+ ((uint32_t)Td4[(t1) & 0xff]) ^ rk[0];
PUTU32(out, s0);
- s1 = (Td4[(t1 >> 24)] << 24) ^ (Td4[(t0 >> 16) & 0xff] << 16) ^
- (Td4[(t3 >> 8) & 0xff] << 8) ^ (Td4[(t2) & 0xff]) ^ rk[1];
+ s1 = ((uint32_t)Td4[(t1 >> 24)] << 24) ^
+ ((uint32_t)Td4[(t0 >> 16) & 0xff] << 16) ^
+ ((uint32_t)Td4[(t3 >> 8) & 0xff] << 8) ^
+ ((uint32_t)Td4[(t2) & 0xff]) ^ rk[1];
PUTU32(out + 4, s1);
- s2 = (Td4[(t2 >> 24)] << 24) ^ (Td4[(t1 >> 16) & 0xff] << 16) ^
- (Td4[(t0 >> 8) & 0xff] << 8) ^ (Td4[(t3) & 0xff]) ^ rk[2];
+ s2 = ((uint32_t)Td4[(t2 >> 24)] << 24) ^
+ ((uint32_t)Td4[(t1 >> 16) & 0xff] << 16) ^
+ ((uint32_t)Td4[(t0 >> 8) & 0xff] << 8) ^
+ ((uint32_t)Td4[(t3) & 0xff]) ^ rk[2];
PUTU32(out + 8, s2);
- s3 = (Td4[(t3 >> 24)] << 24) ^ (Td4[(t2 >> 16) & 0xff] << 16) ^
- (Td4[(t1 >> 8) & 0xff] << 8) ^ (Td4[(t0) & 0xff]) ^ rk[3];
+ s3 = ((uint32_t)Td4[(t3 >> 24)] << 24) ^
+ ((uint32_t)Td4[(t2 >> 16) & 0xff] << 16) ^
+ ((uint32_t)Td4[(t1 >> 8) & 0xff] << 8) ^
+ ((uint32_t)Td4[(t0) & 0xff]) ^ rk[3];
PUTU32(out + 12, s3);
}