Add bssl::SealRecord and bssl::OpenRecord.

This is a C++ interface for encrypting and decrypting TLS application
data records in-place, wrapping the existing C API in tls_record.cc.

Also add bssl::Span, a non-owning reference to a contiguous array of
elements that can be used as a common interface over contiguous
container types (like std::vector), pointer-length pairs, arrays, etc.

Change-Id: Iaa2ca4957cde511cb734b997db38f54e103b0d92
Reviewed-on: https://boringssl-review.googlesource.com/18104
Commit-Queue: Martin Kreichgauer <martinkr@google.com>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/CMakeLists.txt b/ssl/CMakeLists.txt
index c228f4a..d9d2eb5 100644
--- a/ssl/CMakeLists.txt
+++ b/ssl/CMakeLists.txt
@@ -45,6 +45,7 @@
 add_executable(
   ssl_test
 
+  span_test.cc
   ssl_test.cc
 
   $<TARGET_OBJECTS:gtest_main>
diff --git a/ssl/span_test.cc b/ssl/span_test.cc
new file mode 100644
index 0000000..0aa7f3d
--- /dev/null
+++ b/ssl/span_test.cc
@@ -0,0 +1,90 @@
+/* Copyright (c) 2017, Google Inc.
+ *
+ * Permission to use, copy, modify, and/or distribute this software for any
+ * purpose with or without fee is hereby granted, provided that the above
+ * copyright notice and this permission notice appear in all copies.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
+ * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
+ * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
+ * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
+
+#include <stdio.h>
+#include <vector>
+
+#include <gtest/gtest.h>
+
+#include <openssl/ssl.h>
+
+namespace bssl {
+namespace {
+
+static void TestCtor(Span<int> s, const int *ptr, size_t size) {
+  EXPECT_EQ(s.data(), ptr);   // |s| must point at the expected buffer...
+  EXPECT_EQ(s.size(), size);  // ...and cover exactly |size| elements.
+}
+
+static void TestConstCtor(Span<const int> s, const int *ptr, size_t size) {
+  EXPECT_EQ(s.data(), ptr);   // Same checks as |TestCtor|, for a const view.
+  EXPECT_EQ(s.size(), size);
+}
+
+TEST(SpanTest, CtorEmpty) {
+  Span<int> s;  // Default-constructed: expected to be null and empty.
+  TestCtor(s, nullptr, 0);
+}
+
+TEST(SpanTest, CtorFromPtrAndSize) {
+  std::vector<int> v = {7, 8, 9, 10};
+  Span<int> s(v.data(), v.size());
+  TestCtor(s, v.data(), v.size());
+}
+
+TEST(SpanTest, CtorFromVector) {
+  std::vector<int> v = {1, 2};
+  // Const ctor is implicit: |v| converts to a span at the call site.
+  TestConstCtor(v, v.data(), v.size());
+  // Mutable is explicit, so the span has to be constructed by name first.
+  Span<int> s(v);
+  TestCtor(s, v.data(), v.size());
+}
+
+TEST(SpanTest, CtorConstFromArray) {
+  int v[] = {10, 11};
+  // Array ctor is implicit for const and mutable T; the size is the array's.
+  TestConstCtor(v, v, 2);
+  TestCtor(v, v, 2);
+}
+
+TEST(SpanTest, MakeSpan) {
+  std::vector<int> v = {100, 200, 300};
+  TestCtor(MakeSpan(v), v.data(), v.size());
+  TestCtor(MakeSpan(v.data(), v.size()), v.data(), v.size());
+  TestConstCtor(MakeSpan(v.data(), v.size()), v.data(), v.size());
+  TestConstCtor(MakeSpan(v), v.data(), v.size());
+}
+
+TEST(SpanTest, MakeConstSpan) {
+  std::vector<int> v = {100, 200, 300};
+  TestConstCtor(MakeConstSpan(v), v.data(), v.size());
+  TestConstCtor(MakeConstSpan(v.data(), v.size()), v.data(), v.size());
+  // The reverse is presumably rejected (MakeSpan -> const works, see MakeSpan):
+  // TestCtor(MakeConstSpan(v), v.data(), v.size());
+}
+
+TEST(SpanTest, Accessor) {
+  std::vector<int> v({42, 23, 5, 101, 80});
+  Span<int> s(v);
+  for (size_t i = 0; i < s.size(); ++i) {
+    EXPECT_EQ(s[i], v[i]);
+    EXPECT_EQ(s.at(i), v.at(i));
+  }
+  EXPECT_EQ(s.begin(), v.data());  // Iterators are compared to raw pointers.
+  EXPECT_EQ(s.end(), v.data() + v.size());
+}
+
+}  // namespace
+}  // namespace bssl
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index a57298f..b410175 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -3693,6 +3693,194 @@
   EXPECT_EQ(Bytes("x"), Bytes(result, result_len));
 }
 
+TEST(SSLTest, SealRecord) {  // Seal into separate buffers, then open.
+  bssl::UniquePtr<SSL_CTX> client_ctx(SSL_CTX_new(TLS_method())),
+      server_ctx(SSL_CTX_new(TLS_method()));
+  ASSERT_TRUE(client_ctx);
+  ASSERT_TRUE(server_ctx);
+
+  bssl::UniquePtr<X509> cert = GetTestCertificate();
+  bssl::UniquePtr<EVP_PKEY> key = GetTestKey();
+  ASSERT_TRUE(cert);
+  ASSERT_TRUE(key);
+  ASSERT_TRUE(SSL_CTX_use_certificate(server_ctx.get(), cert.get()));
+  ASSERT_TRUE(SSL_CTX_use_PrivateKey(server_ctx.get(), key.get()));
+
+  bssl::UniquePtr<SSL> client, server;
+  ASSERT_TRUE(ConnectClientAndServer(&client, &server, client_ctx.get(),
+                                     server_ctx.get(),
+                                     nullptr /* no session */));
+
+  const std::vector<uint8_t> record = {1, 2, 3, 4, 5};
+  std::vector<uint8_t> prefix(
+      bssl::SealRecordPrefixLen(client.get(), record.size()));
+  std::vector<uint8_t> body(record.size());
+  std::vector<uint8_t> suffix(bssl::SealRecordMaxSuffixLen(client.get()));
+  size_t suffix_size;
+  ASSERT_TRUE(bssl::SealRecord(client.get(), bssl::MakeSpan(prefix),
+                               bssl::MakeSpan(body), bssl::MakeSpan(suffix),
+                               &suffix_size, record));
+  suffix.resize(suffix_size);  // Keep only the bytes actually written.
+
+  std::vector<uint8_t> sealed;
+  sealed.insert(sealed.end(), prefix.begin(), prefix.end());
+  sealed.insert(sealed.end(), body.begin(), body.end());
+  sealed.insert(sealed.end(), suffix.begin(), suffix.end());
+
+  bssl::Span<uint8_t> plaintext;
+  size_t record_len;
+  uint8_t alert = 255;  // Sentinel: a successful open must not touch it.
+  EXPECT_EQ(bssl::OpenRecord(server.get(), &plaintext, &record_len, &alert,
+                             bssl::MakeSpan(sealed)),
+            bssl::OpenRecordResult::kOK);
+  EXPECT_EQ(record_len, sealed.size());
+  EXPECT_EQ(plaintext, record);
+  EXPECT_EQ(255, alert);
+}
+
+TEST(SSLTest, SealRecordInPlace) {  // |in| and |out| are the same buffer.
+  bssl::UniquePtr<SSL_CTX> client_ctx(SSL_CTX_new(TLS_method())),
+      server_ctx(SSL_CTX_new(TLS_method()));
+  ASSERT_TRUE(client_ctx);
+  ASSERT_TRUE(server_ctx);
+
+  bssl::UniquePtr<X509> cert = GetTestCertificate();
+  bssl::UniquePtr<EVP_PKEY> key = GetTestKey();
+  ASSERT_TRUE(cert);
+  ASSERT_TRUE(key);
+  ASSERT_TRUE(SSL_CTX_use_certificate(server_ctx.get(), cert.get()));
+  ASSERT_TRUE(SSL_CTX_use_PrivateKey(server_ctx.get(), key.get()));
+
+  bssl::UniquePtr<SSL> client, server;
+  ASSERT_TRUE(ConnectClientAndServer(&client, &server, client_ctx.get(),
+                                     server_ctx.get(),
+                                     nullptr /* no session */));
+
+  const std::vector<uint8_t> plaintext = {1, 2, 3, 4, 5};
+  std::vector<uint8_t> record = plaintext;  // Sealed over itself below.
+  std::vector<uint8_t> prefix(
+      bssl::SealRecordPrefixLen(client.get(), record.size())),
+      suffix(bssl::SealRecordMaxSuffixLen(client.get()));
+  size_t suffix_size;
+  ASSERT_TRUE(bssl::SealRecord(client.get(), bssl::MakeSpan(prefix),
+                               bssl::MakeSpan(record), bssl::MakeSpan(suffix),
+                               &suffix_size, record));
+  suffix.resize(suffix_size);  // Keep only the bytes actually written.
+  record.insert(record.begin(), prefix.begin(), prefix.end());
+  record.insert(record.end(), suffix.begin(), suffix.end());
+
+  bssl::Span<uint8_t> result;
+  size_t record_len;
+  uint8_t alert;
+  EXPECT_EQ(bssl::OpenRecord(server.get(), &result, &record_len, &alert,
+                             bssl::MakeSpan(record)),
+            bssl::OpenRecordResult::kOK);
+  EXPECT_EQ(record_len, record.size());
+  EXPECT_EQ(plaintext, result);  // Round trip must recover the input bytes.
+}
+
+TEST(SSLTest, SealRecordTrailingData) {  // Extra bytes follow the record.
+  bssl::UniquePtr<SSL_CTX> client_ctx(SSL_CTX_new(TLS_method())),
+      server_ctx(SSL_CTX_new(TLS_method()));
+  ASSERT_TRUE(client_ctx);
+  ASSERT_TRUE(server_ctx);
+
+  bssl::UniquePtr<X509> cert = GetTestCertificate();
+  bssl::UniquePtr<EVP_PKEY> key = GetTestKey();
+  ASSERT_TRUE(cert);
+  ASSERT_TRUE(key);
+  ASSERT_TRUE(SSL_CTX_use_certificate(server_ctx.get(), cert.get()));
+  ASSERT_TRUE(SSL_CTX_use_PrivateKey(server_ctx.get(), key.get()));
+
+  bssl::UniquePtr<SSL> client, server;
+  ASSERT_TRUE(ConnectClientAndServer(&client, &server, client_ctx.get(),
+                                     server_ctx.get(),
+                                     nullptr /* no session */));
+
+  const std::vector<uint8_t> plaintext = {1, 2, 3, 4, 5};
+  std::vector<uint8_t> record = plaintext;  // Sealed over itself below.
+  std::vector<uint8_t> prefix(
+      bssl::SealRecordPrefixLen(client.get(), record.size())),
+      suffix(bssl::SealRecordMaxSuffixLen(client.get()));
+  size_t suffix_size;
+  ASSERT_TRUE(bssl::SealRecord(client.get(), bssl::MakeSpan(prefix),
+                               bssl::MakeSpan(record), bssl::MakeSpan(suffix),
+                               &suffix_size, record));
+  suffix.resize(suffix_size);  // Keep only the bytes actually written.
+  record.insert(record.begin(), prefix.begin(), prefix.end());
+  record.insert(record.end(), suffix.begin(), suffix.end());
+  record.insert(record.end(), {5, 4, 3, 2, 1});  // Five bytes past the record.
+
+  bssl::Span<uint8_t> result;
+  size_t record_len;
+  uint8_t alert;
+  EXPECT_EQ(bssl::OpenRecord(server.get(), &result, &record_len, &alert,
+                             bssl::MakeSpan(record)),
+            bssl::OpenRecordResult::kOK);
+  EXPECT_EQ(record_len, record.size() - 5);  // Trailing bytes are not counted.
+  EXPECT_EQ(plaintext, result);
+}
+
+TEST(SSLTest, SealRecordInvalidSpanSize) {  // Every span size must be exact.
+  bssl::UniquePtr<SSL_CTX> client_ctx(SSL_CTX_new(TLS_method())),
+      server_ctx(SSL_CTX_new(TLS_method()));
+  ASSERT_TRUE(client_ctx);
+  ASSERT_TRUE(server_ctx);
+
+  bssl::UniquePtr<X509> cert = GetTestCertificate();
+  bssl::UniquePtr<EVP_PKEY> key = GetTestKey();
+  ASSERT_TRUE(cert);
+  ASSERT_TRUE(key);
+  ASSERT_TRUE(SSL_CTX_use_certificate(server_ctx.get(), cert.get()));
+  ASSERT_TRUE(SSL_CTX_use_PrivateKey(server_ctx.get(), key.get()));
+
+  bssl::UniquePtr<SSL> client, server;
+  ASSERT_TRUE(ConnectClientAndServer(&client, &server, client_ctx.get(),
+                                     server_ctx.get(),
+                                     nullptr /* no session */));
+
+  std::vector<uint8_t> record = {1, 2, 3, 4, 5};
+  std::vector<uint8_t> prefix(
+      bssl::SealRecordPrefixLen(client.get(), record.size())),
+      suffix(bssl::SealRecordMaxSuffixLen(client.get()));
+  size_t suffix_size;
+
+  auto expect_err = []() {
+    const uint32_t err = ERR_get_error();  // Packed error codes are unsigned.
+    EXPECT_EQ(ERR_GET_LIB(err), ERR_LIB_SSL);
+    EXPECT_EQ(ERR_GET_REASON(err), SSL_R_BUFFER_TOO_SMALL);
+    ERR_clear_error();  // Start each case from an empty error queue.
+  };
+  EXPECT_FALSE(bssl::SealRecord(
+      client.get(), bssl::MakeSpan(prefix.data(), prefix.size() - 1),
+      bssl::MakeSpan(record), bssl::MakeSpan(suffix), &suffix_size, record));
+  expect_err();  // Prefix one byte too short.
+  EXPECT_FALSE(bssl::SealRecord(
+      client.get(), bssl::MakeSpan(prefix.data(), prefix.size() + 1),
+      bssl::MakeSpan(record), bssl::MakeSpan(suffix), &suffix_size, record));
+  expect_err();  // Prefix one byte too long.
+
+  EXPECT_FALSE(
+      bssl::SealRecord(client.get(), bssl::MakeSpan(prefix),
+                       bssl::MakeSpan(record.data(), record.size() - 1),
+                       bssl::MakeSpan(suffix), &suffix_size, record));
+  expect_err();  // Body one byte shorter than the input.
+  EXPECT_FALSE(
+      bssl::SealRecord(client.get(), bssl::MakeSpan(prefix),
+                       bssl::MakeSpan(record.data(), record.size() + 1),
+                       bssl::MakeSpan(suffix), &suffix_size, record));
+  expect_err();  // Body one byte longer than the input.
+
+  EXPECT_FALSE(bssl::SealRecord(
+      client.get(), bssl::MakeSpan(prefix), bssl::MakeSpan(record),
+      bssl::MakeSpan(suffix.data(), suffix.size() - 1), &suffix_size, record));
+  expect_err();  // Suffix one byte too short.
+  EXPECT_FALSE(bssl::SealRecord(
+      client.get(), bssl::MakeSpan(prefix), bssl::MakeSpan(record),
+      bssl::MakeSpan(suffix.data(), suffix.size() + 1), &suffix_size, record));
+  expect_err();  // Suffix one byte too long.
+}
+
 // TODO(davidben): Convert this file to GTest properly.
 TEST(SSLTest, AllTests) {
   if (!TestSSL_SESSIONEncoding(kOpenSSLSession) ||
diff --git a/ssl/tls_record.cc b/ssl/tls_record.cc
index 437d02f..46132e1 100644
--- a/ssl/tls_record.cc
+++ b/ssl/tls_record.cc
@@ -403,7 +403,7 @@
 }
 
 static size_t tls_seal_scatter_prefix_len(const SSL *ssl, uint8_t type,
-                                          size_t in_len) {
+                                   size_t in_len) {
   size_t ret = SSL3_RT_HEADER_LENGTH;
   if (type == SSL3_RT_APPLICATION_DATA && in_len > 1 &&
       ssl_needs_record_splitting(ssl)) {
@@ -419,10 +419,20 @@
   return ret;
 }
 
+/* tls_seal_scatter_max_suffix_len returns the write AEAD's |MaxOverhead|, plus
+ * one byte when the cipher is non-null and the version is TLS 1.3 or later:
+ * only then is there an encrypted record type that needs the extra byte. */
+static size_t tls_seal_scatter_max_suffix_len(const SSL *ssl) {
+  const auto &aead = ssl->s3->aead_write_ctx;
+  const bool has_encrypted_type =
+      !aead->is_null_cipher() && aead->version() >= TLS1_3_VERSION;
+  return aead->MaxOverhead() + (has_encrypted_type ? 1 : 0);
+}
+
 /* tls_seal_scatter_record seals a new record of type |type| and body |in| and
  * splits it between |out_prefix|, |out|, and |out_suffix|. Exactly
  * |tls_seal_scatter_prefix_len| bytes are written to |out_prefix|, |in_len|
- * bytes to |out|, and up to 1 + |SSLAEADContext::MaxOverhead| bytes to
+ * bytes to |out|, and up to |tls_seal_scatter_max_suffix_len| bytes to
  * |out_suffix|. |*out_suffix_len| is set to the actual number of bytes written
  * to |out_suffix|. It returns one on success and zero on error. If enabled,
  * |tls_seal_scatter_record| implements TLS 1.0 CBC 1/n-1 record splitting and
@@ -567,6 +577,91 @@
   return ssl_open_record_error;
 }
 
+// OpenRecord opens the record at the start of |in| with |tls_open_record| and
+// maps the result to an |OpenRecordResult|. |*out| is set only for |kOK|;
+// |*out_record_len| gets that call's length unless an early |kError| is hit.
+OpenRecordResult OpenRecord(SSL *ssl, Span<uint8_t> *out,
+                            size_t *out_record_len, uint8_t *out_alert,
+                            const Span<uint8_t> in) {
+  // This API is a work in progress and currently only works for TLS 1.2 servers
+  // and below.
+  if (SSL_in_init(ssl) ||
+      SSL_is_dtls(ssl) ||
+      ssl3_protocol_version(ssl) > TLS1_2_VERSION) {
+    assert(false);
+    *out_alert = SSL_AD_INTERNAL_ERROR;
+    return OpenRecordResult::kError;
+  }
+
+  *out = Span<uint8_t>();
+  *out_record_len = 0;
+
+  CBS plaintext;
+  uint8_t type = 0;
+  size_t record_len = 0;
+  const ssl_open_record_t result = tls_open_record(
+      ssl, &type, &plaintext, &record_len, out_alert, in.data(), in.size());
+
+  // |type| and |plaintext| are only read after a successful open; they are not
+  // assumed to have been written by |tls_open_record| for any other result.
+  if (result == ssl_open_record_success) {
+    if (type != SSL3_RT_APPLICATION_DATA && type != SSL3_RT_ALERT) {
+      *out_alert = SSL_AD_UNEXPECTED_MESSAGE;
+      return OpenRecordResult::kError;
+    }
+    *out = MakeSpan(const_cast<uint8_t *>(CBS_data(&plaintext)),
+                    CBS_len(&plaintext));
+  }
+  *out_record_len = record_len;
+  switch (result) {
+    case ssl_open_record_success:
+      return OpenRecordResult::kOK;
+    case ssl_open_record_discard:
+      return OpenRecordResult::kDiscard;
+    case ssl_open_record_partial:
+      return OpenRecordResult::kIncompleteRecord;
+    case ssl_open_record_close_notify:
+      return OpenRecordResult::kAlertCloseNotify;
+    case ssl_open_record_fatal_alert:
+      return OpenRecordResult::kAlertFatal;
+    case ssl_open_record_error:
+      return OpenRecordResult::kError;
+  }
+  return OpenRecordResult::kError;  // Unknown result value: treat as error.
+}
+
+size_t SealRecordPrefixLen(SSL *ssl, size_t record_len) {  // App data only.
+  return tls_seal_scatter_prefix_len(ssl, SSL3_RT_APPLICATION_DATA, record_len);
+}
+
+size_t SealRecordMaxSuffixLen(SSL *ssl) {  // Wrapper over the static helper.
+  return tls_seal_scatter_max_suffix_len(ssl);
+}
+
+bool SealRecord(SSL *ssl, const Span<uint8_t> out_prefix,
+                const Span<uint8_t> out, Span<uint8_t> out_suffix,
+                size_t *out_suffix_len, const Span<const uint8_t> in) {
+  // Work in progress: supported only once |SSL_in_init| is false, for non-DTLS
+  // connections at TLS 1.2 or below. Anything else asserts and is rejected.
+  if (SSL_in_init(ssl) ||
+      SSL_is_dtls(ssl) ||
+      ssl3_protocol_version(ssl) > TLS1_2_VERSION) {
+    assert(false);
+    OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
+    return false;
+  }
+
+  if (out_prefix.size() != SealRecordPrefixLen(ssl, in.size()) ||
+      out.size() != in.size() ||  // Exact sizes are required, not minimums.
+      out_suffix.size() != SealRecordMaxSuffixLen(ssl)) {
+    OPENSSL_PUT_ERROR(SSL, SSL_R_BUFFER_TOO_SMALL);  // Also when too large.
+    return false;
+  }
+  return tls_seal_scatter_record(
+      ssl, out_prefix.data(), out.data(), out_suffix.data(), out_suffix_len,
+      out_suffix.size(), SSL3_RT_APPLICATION_DATA, in.data(), in.size());
+}
+
 }  // namespace bssl
 
 using namespace bssl;