Fix SSL_get_{read,write}_sequence.

I switched up the endianness. Add some tests to make sure those work right.

Also tweak the DTLS semantics. SSL_get_read_sequence should return the highest
sequence number received so far. Include the epoch number in both so we don't
need a second API for it.

Change-Id: I9901a1665b41224c46fadb7ce0b0881dcb466bcc
Reviewed-on: https://boringssl-review.googlesource.com/7141
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index 54c5a22..41813b2 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -2600,12 +2600,15 @@
 OPENSSL_EXPORT int SSL_generate_key_block(const SSL *ssl, uint8_t *out,
                                           size_t out_len);
 
-/* SSL_get_read_sequence returns the expected sequence number of the next
- * incoming record in the current epoch. It always returns zero in DTLS. */
+/* SSL_get_read_sequence returns, in TLS, the expected sequence number of the
+ * next incoming record in the current epoch. In DTLS, it returns the maximum
+ * sequence number received in the current epoch and includes the epoch number
+ * in the two most significant bytes. */
 OPENSSL_EXPORT uint64_t SSL_get_read_sequence(const SSL *ssl);
 
 /* SSL_get_write_sequence returns the sequence number of the next outgoing
- * record in the current epoch. */
+ * record in the current epoch. In DTLS, it includes the epoch number in the
+ * two most significant bytes. */
 OPENSSL_EXPORT uint64_t SSL_get_write_sequence(const SSL *ssl);
 
 
diff --git a/ssl/ssl_lib.c b/ssl/ssl_lib.c
index 542dc17..0a3c8f7 100644
--- a/ssl/ssl_lib.c
+++ b/ssl/ssl_lib.c
@@ -2545,19 +2545,29 @@
 }
 
 static uint64_t be_to_u64(const uint8_t in[8]) {
-  return (((uint64_t)in[7]) << 56) | (((uint64_t)in[6]) << 48) |
-         (((uint64_t)in[5]) << 40) | (((uint64_t)in[4]) << 32) |
-         (((uint64_t)in[3]) << 24) | (((uint64_t)in[2]) << 16) |
-         (((uint64_t)in[1]) << 8) | ((uint64_t)in[0]);
+  return (((uint64_t)in[0]) << 56) | (((uint64_t)in[1]) << 48) |
+         (((uint64_t)in[2]) << 40) | (((uint64_t)in[3]) << 32) |
+         (((uint64_t)in[4]) << 24) | (((uint64_t)in[5]) << 16) |
+         (((uint64_t)in[6]) << 8) | ((uint64_t)in[7]);
 }
 
 uint64_t SSL_get_read_sequence(const SSL *ssl) {
   /* TODO(davidben): Internally represent sequence numbers as uint64_t. */
+  if (SSL_IS_DTLS(ssl)) {
+    /* max_seq_num already includes the epoch. */
+    assert(ssl->d1->r_epoch == (ssl->d1->bitmap.max_seq_num >> 48));
+    return ssl->d1->bitmap.max_seq_num;
+  }
   return be_to_u64(ssl->s3->read_sequence);
 }
 
 uint64_t SSL_get_write_sequence(const SSL *ssl) {
-  return be_to_u64(ssl->s3->write_sequence);
+  uint64_t ret = be_to_u64(ssl->s3->write_sequence);
+  if (SSL_IS_DTLS(ssl)) {
+    assert((ret >> 48) == 0);
+    ret |= ((uint64_t)ssl->d1->w_epoch) << 48;
+  }
+  return ret;
 }
 
 uint8_t SSL_get_server_key_exchange_hash(const SSL *ssl) {
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index f91542d..ed1d3c7 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -25,7 +25,9 @@
 #include <openssl/bio.h>
 #include <openssl/crypto.h>
 #include <openssl/err.h>
+#include <openssl/pem.h>
 #include <openssl/ssl.h>
+#include <openssl/x509.h>
 
 #include "test/scoped_types.h"
 #include "../crypto/test/test_util.h"
@@ -978,6 +980,158 @@
   return true;
 }
 
+static uint16_t EpochFromSequence(uint64_t seq) {
+  return static_cast<uint16_t>(seq >> 48);
+}
+
+static ScopedX509 GetTestCertificate() {
+  static const char kCertPEM[] =
+      "-----BEGIN CERTIFICATE-----\n"
+      "MIICWDCCAcGgAwIBAgIJAPuwTC6rEJsMMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV\n"
+      "BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX\n"
+      "aWRnaXRzIFB0eSBMdGQwHhcNMTQwNDIzMjA1MDQwWhcNMTcwNDIyMjA1MDQwWjBF\n"
+      "MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50\n"
+      "ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB\n"
+      "gQDYK8imMuRi/03z0K1Zi0WnvfFHvwlYeyK9Na6XJYaUoIDAtB92kWdGMdAQhLci\n"
+      "HnAjkXLI6W15OoV3gA/ElRZ1xUpxTMhjP6PyY5wqT5r6y8FxbiiFKKAnHmUcrgfV\n"
+      "W28tQ+0rkLGMryRtrukXOgXBv7gcrmU7G1jC2a7WqmeI8QIDAQABo1AwTjAdBgNV\n"
+      "HQ4EFgQUi3XVrMsIvg4fZbf6Vr5sp3Xaha8wHwYDVR0jBBgwFoAUi3XVrMsIvg4f\n"
+      "Zbf6Vr5sp3Xaha8wDAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQUFAAOBgQA76Hht\n"
+      "ldY9avcTGSwbwoiuIqv0jTL1fHFnzy3RHMLDh+Lpvolc5DSrSJHCP5WuK0eeJXhr\n"
+      "T5oQpHL9z/cCDLAKCKRa4uV0fhEdOWBqyR9p8y5jJtye72t6CuFUV5iqcpF4BH4f\n"
+      "j2VNHwsSrJwkD4QUGlUtH7vwnQmyCFxZMmWAJg==\n"
+      "-----END CERTIFICATE-----\n";
+  ScopedBIO bio(
+      BIO_new_mem_buf(const_cast<char *>(kCertPEM), strlen(kCertPEM)));
+  return ScopedX509(PEM_read_bio_X509(bio.get(), nullptr, nullptr, nullptr));
+}
+
+static ScopedEVP_PKEY GetTestKey() {
+  static const char kKeyPEM[] =
+      "-----BEGIN RSA PRIVATE KEY-----\n"
+      "MIICXgIBAAKBgQDYK8imMuRi/03z0K1Zi0WnvfFHvwlYeyK9Na6XJYaUoIDAtB92\n"
+      "kWdGMdAQhLciHnAjkXLI6W15OoV3gA/ElRZ1xUpxTMhjP6PyY5wqT5r6y8FxbiiF\n"
+      "KKAnHmUcrgfVW28tQ+0rkLGMryRtrukXOgXBv7gcrmU7G1jC2a7WqmeI8QIDAQAB\n"
+      "AoGBAIBy09Fd4DOq/Ijp8HeKuCMKTHqTW1xGHshLQ6jwVV2vWZIn9aIgmDsvkjCe\n"
+      "i6ssZvnbjVcwzSoByhjN8ZCf/i15HECWDFFh6gt0P5z0MnChwzZmvatV/FXCT0j+\n"
+      "WmGNB/gkehKjGXLLcjTb6dRYVJSCZhVuOLLcbWIV10gggJQBAkEA8S8sGe4ezyyZ\n"
+      "m4e9r95g6s43kPqtj5rewTsUxt+2n4eVodD+ZUlCULWVNAFLkYRTBCASlSrm9Xhj\n"
+      "QpmWAHJUkQJBAOVzQdFUaewLtdOJoPCtpYoY1zd22eae8TQEmpGOR11L6kbxLQsk\n"
+      "aMly/DOnOaa82tqAGTdqDEZgSNmCeKKknmECQAvpnY8GUOVAubGR6c+W90iBuQLj\n"
+      "LtFp/9ihd2w/PoDwrHZaoUYVcT4VSfJQog/k7kjE4MYXYWL8eEKg3WTWQNECQQDk\n"
+      "104Wi91Umd1PzF0ijd2jXOERJU1wEKe6XLkYYNHWQAe5l4J4MWj9OdxFXAxIuuR/\n"
+      "tfDwbqkta4xcux67//khAkEAvvRXLHTaa6VFzTaiiO8SaFsHV3lQyXOtMrBpB5jd\n"
+      "moZWgjHvB2W9Ckn7sDqsPB+U2tyX0joDdQEyuiMECDY8oQ==\n"
+      "-----END RSA PRIVATE KEY-----\n";
+  ScopedBIO bio(BIO_new_mem_buf(const_cast<char *>(kKeyPEM), strlen(kKeyPEM)));
+  return ScopedEVP_PKEY(
+      PEM_read_bio_PrivateKey(bio.get(), nullptr, nullptr, nullptr));
+}
+
+static bool TestSequenceNumber(bool dtls) {
+  ScopedSSL_CTX client_ctx(SSL_CTX_new(dtls ? DTLS_method() : TLS_method()));
+  ScopedSSL_CTX server_ctx(SSL_CTX_new(dtls ? DTLS_method() : TLS_method()));
+  if (!client_ctx || !server_ctx) {
+    return false;
+  }
+
+  ScopedX509 cert = GetTestCertificate();
+  ScopedEVP_PKEY key = GetTestKey();
+  if (!cert || !key ||
+      !SSL_CTX_use_certificate(server_ctx.get(), cert.get()) ||
+      !SSL_CTX_use_PrivateKey(server_ctx.get(), key.get())) {
+    return false;
+  }
+
+  // Create a client and server connected to each other.
+  ScopedSSL client(SSL_new(client_ctx.get())), server(SSL_new(server_ctx.get()));
+  if (!client || !server) {
+    return false;
+  }
+  SSL_set_connect_state(client.get());
+  SSL_set_accept_state(server.get());
+
+  BIO *bio1, *bio2;
+  if (!BIO_new_bio_pair(&bio1, 0, &bio2, 0)) {
+    return false;
+  }
+  // SSL_set_bio takes ownership.
+  SSL_set_bio(client.get(), bio1, bio1);
+  SSL_set_bio(server.get(), bio2, bio2);
+
+  // Drive both their handshakes to completion.
+  for (;;) {
+    int client_ret = SSL_do_handshake(client.get());
+    int client_err = SSL_get_error(client.get(), client_ret);
+    if (client_err != SSL_ERROR_NONE &&
+        client_err != SSL_ERROR_WANT_READ &&
+        client_err != SSL_ERROR_WANT_WRITE) {
+      fprintf(stderr, "Client error: %d\n", client_err);
+      return false;
+    }
+
+    int server_ret = SSL_do_handshake(server.get());
+    int server_err = SSL_get_error(server.get(), server_ret);
+    if (server_err != SSL_ERROR_NONE &&
+        server_err != SSL_ERROR_WANT_READ &&
+        server_err != SSL_ERROR_WANT_WRITE) {
+      fprintf(stderr, "Server error: %d\n", server_err);
+      return false;
+    }
+
+    if (client_ret == 1 && server_ret == 1) {
+      break;
+    }
+  }
+
+  uint64_t client_read_seq = SSL_get_read_sequence(client.get());
+  uint64_t client_write_seq = SSL_get_write_sequence(client.get());
+  uint64_t server_read_seq = SSL_get_read_sequence(server.get());
+  uint64_t server_write_seq = SSL_get_write_sequence(server.get());
+
+  if (dtls) {
+    // Both client and server must be at epoch 1.
+    if (EpochFromSequence(client_read_seq) != 1 ||
+        EpochFromSequence(client_write_seq) != 1 ||
+        EpochFromSequence(server_read_seq) != 1 ||
+        EpochFromSequence(server_write_seq) != 1) {
+      fprintf(stderr, "Bad epochs.\n");
+      return false;
+    }
+
+    // The next record to be written should exceed the largest received.
+    if (client_write_seq <= server_read_seq ||
+        server_write_seq <= client_read_seq) {
+      fprintf(stderr, "Inconsistent sequence numbers.\n");
+      return false;
+    }
+  } else {
+    // The next record to be written should equal the next to be received.
+    if (client_write_seq != server_read_seq ||
+        server_write_seq != client_write_seq) {
+      fprintf(stderr, "Inconsistent sequence numbers.\n");
+      return false;
+    }
+  }
+
+  // Send a record from client to server.
+  uint8_t byte = 0;
+  if (SSL_write(client.get(), &byte, 1) != 1 ||
+      SSL_read(server.get(), &byte, 1) != 1) {
+    fprintf(stderr, "Could not send byte.\n");
+    return false;
+  }
+
+  // The client write and server read sequence numbers should have incremented.
+  if (client_write_seq + 1 != SSL_get_write_sequence(client.get()) ||
+      server_read_seq + 1 != SSL_get_read_sequence(server.get())) {
+    fprintf(stderr, "Sequence numbers did not increment.\n");\
+    return false;
+  }
+
+  return true;
+}
+
 int main() {
   CRYPTO_library_init();
 
@@ -999,7 +1153,9 @@
       !TestCipherGetRFCName() ||
       !TestPaddingExtension() ||
       !TestClientCAList() ||
-      !TestInternalSessionCache()) {
+      !TestInternalSessionCache() ||
+      !TestSequenceNumber(false /* TLS */) ||
+      !TestSequenceNumber(true /* DTLS */)) {
     ERR_print_errors_fp(stderr);
     return 1;
   }