Add some tests for SSL_CTX_set_keylog_callback

We actually don't test this at all right now.

Change-Id: Iaac8850da3c012cbd21d0f38b026e7ff14db3650
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/67828
Reviewed-by: Bob Beck <bbe@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/include/openssl/ssl.h b/include/openssl/ssl.h
index 49fb47e..97f1c89 100644
--- a/include/openssl/ssl.h
+++ b/include/openssl/ssl.h
@@ -4375,7 +4375,7 @@
 // access to the log.
 //
 // The format is described in
-// https://developer.mozilla.org/en-US/docs/Mozilla/Projects/NSS/Key_Log_Format.
+// https://www.ietf.org/archive/id/draft-ietf-tls-keylogfile-01.html
 OPENSSL_EXPORT void SSL_CTX_set_keylog_callback(
     SSL_CTX *ctx, void (*cb)(const SSL *ssl, const char *line));
 
diff --git a/ssl/ssl_test.cc b/ssl/ssl_test.cc
index 503ad5f..f39d042 100644
--- a/ssl/ssl_test.cc
+++ b/ssl/ssl_test.cc
@@ -23,6 +23,7 @@
 #include <utility>
 #include <vector>
 
+#include <gmock/gmock.h>
 #include <gtest/gtest.h>
 
 #include <openssl/aead.h>
@@ -60,6 +61,9 @@
 #endif
 
 
+using testing::ElementsAre;
+using testing::Key;
+
 BSSL_NAMESPACE_BEGIN
 
 namespace {
@@ -910,6 +914,35 @@
   return true;
 }
 
+static bool DecodeLowerHex(std::vector<uint8_t> *out,  // receives the bytes
+                           bssl::Span<const char> in) {  // lowercase hex text
+  if (in.size() % 2 != 0) {  // each output byte consumes two hex digits
+    return false;
+  }
+  out->resize(in.size() / 2);
+  for (size_t i = 0; i < out->size(); i++) {
+    char hi = in[2 * i], lo = in[2 * i + 1];  // high nibble comes first
+    uint8_t b = 0;
+    if ('0' <= hi && hi <= '9') {
+      b |= hi - '0';
+    } else if ('a' <= hi && hi <= 'f') {
+      b |= hi - 'a' + 10;
+    } else {
+      return false;  // uppercase and non-hex characters are rejected
+    }
+    b <<= 4;  // move the high digit into the top four bits
+    if ('0' <= lo && lo <= '9') {
+      b |= lo - '0';
+    } else if ('a' <= lo && lo <= 'f') {
+      b |= lo - 'a' + 10;
+    } else {
+      return false;  // |out| stays resized and only partly written here
+    }
+    (*out)[i] = b;
+  }
+  return true;  // all digit pairs decoded; empty input yields empty |out|
+}
+
 TEST(SSLTest, SessionEncoding) {
   for (const char *input_b64 : {
            kOpenSSLSession,
@@ -9332,5 +9365,80 @@
   }
 }
 
+TEST_P(SSLVersionTest, KeyLog) {
+  using KeyLog = std::map<std::string, std::vector<uint8_t>>;  // name -> secret
+  KeyLog client_log, server_log;
+
+  SSL_CTX_set_app_data(client_ctx_.get(), &client_log);
+  SSL_CTX_set_app_data(server_ctx_.get(), &server_log);
+
+  auto keylog_callback = [](const SSL *ssl, const char *line) {
+    SSL_CTX *ctx = SSL_get_SSL_CTX(ssl);
+    KeyLog *log = static_cast<KeyLog *>(SSL_CTX_get_app_data(ctx));
+    ASSERT_TRUE(log);  // no captures, so the log is reached via app data
+
+    const char *space1 = strchr(line, ' ');  // end of the name field
+    ASSERT_TRUE(space1);  // note: ASSERT_* only returns from this lambda
+    std::string name(line, space1 - line);
+    space1++;  // now points at the client random hex
+    const char *space2 = strchr(space1, ' ');
+    ASSERT_TRUE(space2);
+    bssl::Span<const char> client_random_hex(space1, space2 - space1);
+    space2++;  // now points at the secret hex, which runs to the end of line
+    bssl::Span<const char> secret_hex(space2, strlen(space2));
+
+    std::vector<uint8_t> client_random, secret;
+    ASSERT_TRUE(DecodeLowerHex(&client_random, client_random_hex));
+    ASSERT_TRUE(DecodeLowerHex(&secret, secret_hex));
+
+    // The client_random field identifies the connection. Check it matches
+    // the connection.
+    uint8_t expected_random[SSL3_RANDOM_SIZE];
+    ASSERT_EQ(
+        sizeof(expected_random),
+        SSL_get_client_random(ssl, expected_random, sizeof(expected_random)));
+    ASSERT_EQ(Bytes(expected_random), Bytes(client_random));
+
+    ASSERT_EQ(log->count(name), 0u) << "duplicate name " << name;
+    log->emplace(std::move(name), std::move(secret));
+  };
+  SSL_CTX_set_keylog_callback(client_ctx_.get(), keylog_callback);
+  SSL_CTX_set_keylog_callback(server_ctx_.get(), keylog_callback);
+
+  // Connect and capture the various secrets.
+  ASSERT_TRUE(Connect());
+
+  // Check that we logged the secrets we expected to log.
+  if (version() == TLS1_3_VERSION) {  // keys below are in std::map sort order
+    EXPECT_THAT(client_log, ElementsAre(Key("CLIENT_HANDSHAKE_TRAFFIC_SECRET"),
+                                        Key("CLIENT_TRAFFIC_SECRET_0"),
+                                        Key("EXPORTER_SECRET"),
+                                        Key("SERVER_HANDSHAKE_TRAFFIC_SECRET"),
+                                        Key("SERVER_TRAFFIC_SECRET_0")));
+
+    // Ideally we'd check the other values, but those are harder to check
+    // without actually decrypting the records.
+    Span<const uint8_t> read_secret, write_secret;
+    ASSERT_TRUE(bssl::SSL_get_traffic_secrets(client_.get(), &read_secret,
+                                              &write_secret));
+    EXPECT_EQ(Bytes(read_secret), Bytes(client_log["SERVER_TRAFFIC_SECRET_0"]));
+    EXPECT_EQ(Bytes(write_secret),
+              Bytes(client_log["CLIENT_TRAFFIC_SECRET_0"]));
+  } else {
+    EXPECT_THAT(client_log, ElementsAre(Key("CLIENT_RANDOM")));
+
+    size_t len =  // probe call with no buffer; the result sizes |expected|
+        SSL_SESSION_get_master_key(SSL_get0_session(client_.get()), nullptr, 0);
+    std::vector<uint8_t> expected(len);
+    ASSERT_EQ(SSL_SESSION_get_master_key(SSL_get0_session(client_.get()),
+                                         expected.data(), expected.size()),
+              expected.size());
+    EXPECT_EQ(Bytes(expected), Bytes(client_log["CLIENT_RANDOM"]));
+  }
+
+  // The server should have logged the same secrets as the client.
+  EXPECT_EQ(client_log, server_log);
+}
+
 }  // namespace
 BSSL_NAMESPACE_END