Make QUIC tests work with early data.

This changes the format of the mock QUIC transport to include an
explicit encryption level, matching real QUIC a bit better. In
particular, we need that extra data to properly skip rejected early data
on the shim side. (On the runner, we manage it by synchronizing with the
TLS stack. Still, the levels make it a bit more accurate.)

Testing sending and receiving of actual early data is not very relevant
in QUIC since application I/O is external, but this allows us to more
easily run the same tests in TLS and QUIC.

Along the way, improve error-reporting in mock_quick_transport.cc so
it's easier to diagnose record-level mismatches.

Change-Id: I96175a4023134b03d61dac089f8e7ff4eb627933
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/44988
Commit-Queue: David Benjamin <davidben@google.com>
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/ssl/test/mock_quic_transport.cc b/ssl/test/mock_quic_transport.cc
index 6a3f0e8..45d664a 100644
--- a/ssl/test/mock_quic_transport.cc
+++ b/ssl/test/mock_quic_transport.cc
@@ -73,47 +73,97 @@
   return true;
 }
 
+const char *LevelToString(ssl_encryption_level_t level) {
+  switch (level) {
+    case ssl_encryption_initial:
+      return "initial";
+    case ssl_encryption_early_data:
+      return "early_data";
+    case ssl_encryption_handshake:
+      return "handshake";
+    case ssl_encryption_application:
+      return "application";
+  }
+  return "";
+}
+
 }  // namespace
 
-bool MockQuicTransport::ReadHeader(uint8_t *out_tag, size_t *out_len) {
-  uint8_t header[7];
-  if (!ReadAll(bio_.get(), header)) {
-    return false;
-  }
-  *out_tag = header[0];
-  uint16_t cipher_suite = header[1] << 8 | header[2];
-  size_t remaining_bytes =
-      header[3] << 24 | header[4] << 16 | header[5] << 8 | header[6];
-
-  enum ssl_encryption_level_t level = SSL_quic_read_level(ssl_);
-  if (*out_tag == kTagApplication) {
-    if (SSL_in_early_data(ssl_)) {
-      level = ssl_encryption_early_data;
-    } else {
-      level = ssl_encryption_application;
+bool MockQuicTransport::ReadHeader(uint8_t *out_tag,
+                                   enum ssl_encryption_level_t *out_level,
+                                   size_t *out_len) {
+  for (;;) {
+    uint8_t header[8];
+    if (!ReadAll(bio_.get(), header)) {
+      // TODO(davidben): Distinguish between errors and EOF. See
+      // ReadApplicationData.
+      return false;
     }
+
+    CBS cbs;
+    uint8_t level_id;
+    uint16_t cipher_suite;
+    uint32_t remaining_bytes;
+    CBS_init(&cbs, header, sizeof(header));
+    if (!CBS_get_u8(&cbs, out_tag) ||
+        !CBS_get_u8(&cbs, &level_id) ||
+        !CBS_get_u16(&cbs, &cipher_suite) ||
+        !CBS_get_u32(&cbs, &remaining_bytes) ||
+        level_id >= read_levels_.size()) {
+      fprintf(stderr, "Error parsing record header.\n");
+      return false;
+    }
+
+    auto level = static_cast<ssl_encryption_level_t>(level_id);
+    // Non-initial levels must be configured before use.
+    uint16_t expect_cipher = read_levels_[level].cipher;
+    if (expect_cipher == 0 && level != ssl_encryption_initial) {
+      if (level == ssl_encryption_early_data) {
+        // If we receive early data records without any early data keys, skip
+        // the record. This means early data was rejected.
+        std::vector<uint8_t> discard(remaining_bytes);
+        if (!ReadAll(bio_.get(), bssl::MakeSpan(discard))) {
+          return false;
+        }
+        continue;
+      }
+      fprintf(stderr,
+              "Got record at level %s, but keys were not configured.\n",
+              LevelToString(level));
+      return false;
+    }
+    if (cipher_suite != expect_cipher) {
+      fprintf(stderr, "Got cipher suite 0x%04x at level %s, wanted 0x%04x.\n",
+              cipher_suite, LevelToString(level), expect_cipher);
+      return false;
+    }
+    const std::vector<uint8_t> &secret = read_levels_[level].secret;
+    std::vector<uint8_t> read_secret(secret.size());
+    if (remaining_bytes < secret.size()) {
+      fprintf(stderr, "Record at level %s too small.\n", LevelToString(level));
+      return false;
+    }
+    remaining_bytes -= secret.size();
+    if (!ReadAll(bio_.get(), bssl::MakeSpan(read_secret))) {
+      fprintf(stderr, "Error reading record secret.\n");
+      return false;
+    }
+    if (read_secret != secret) {
+      fprintf(stderr, "Encryption secret at level %s did not match.\n",
+              LevelToString(level));
+      return false;
+    }
+    *out_level = level;
+    *out_len = remaining_bytes;
+    return true;
   }
-  if (cipher_suite != read_levels_[level].cipher) {
-    return false;
-  }
-  const std::vector<uint8_t> &secret = read_levels_[level].secret;
-  std::vector<uint8_t> read_secret(secret.size());
-  if (remaining_bytes < secret.size()) {
-    return false;
-  }
-  remaining_bytes -= secret.size();
-  if (!ReadAll(bio_.get(), bssl::MakeSpan(read_secret)) ||
-      read_secret != secret) {
-    return false;
-  }
-  *out_len = remaining_bytes;
-  return true;
 }
 
 bool MockQuicTransport::ReadHandshake() {
   uint8_t tag;
+  ssl_encryption_level_t level;
   size_t len;
-  if (!ReadHeader(&tag, &len)) {
+  if (!ReadHeader(&tag, &level, &len)) {
     return false;
   }
   if (tag != kTagHandshake) {
@@ -124,8 +174,7 @@
   if (!ReadAll(bio_.get(), bssl::MakeSpan(buf))) {
     return false;
   }
-  return SSL_provide_quic_data(ssl_, SSL_quic_read_level(ssl_), buf.data(),
-                               buf.size());
+  return SSL_provide_quic_data(ssl_, level, buf.data(), buf.size());
 }
 
 int MockQuicTransport::ReadApplicationData(uint8_t *out, size_t max_out) {
@@ -144,9 +193,10 @@
   }
 
   uint8_t tag = 0;
+  ssl_encryption_level_t level;
   size_t len;
   while (true) {
-    if (!ReadHeader(&tag, &len)) {
+    if (!ReadHeader(&tag, &level, &len)) {
       // Assume that a failure to read the header means there's no more to read,
       // not an error reading.
       return 0;
@@ -162,8 +212,7 @@
     if (!ReadAll(bio_.get(), bssl::MakeSpan(buf))) {
       return -1;
     }
-    if (SSL_provide_quic_data(ssl_, SSL_quic_read_level(ssl_), buf.data(),
-                              buf.size()) != 1) {
+    if (SSL_provide_quic_data(ssl_, level, buf.data(), buf.size()) != 1) {
       return -1;
     }
     if (SSL_in_init(ssl_)) {
@@ -203,14 +252,15 @@
   uint16_t cipher_suite = write_levels_[level].cipher;
   const std::vector<uint8_t> &secret = write_levels_[level].secret;
   size_t tlv_len = secret.size() + len;
-  uint8_t header[7];
+  uint8_t header[8];
   header[0] = tag;
-  header[1] = (cipher_suite >> 8) & 0xff;
-  header[2] = cipher_suite & 0xff;
-  header[3] = (tlv_len >> 24) & 0xff;
-  header[4] = (tlv_len >> 16) & 0xff;
-  header[5] = (tlv_len >> 8) & 0xff;
-  header[6] = tlv_len & 0xff;
+  header[1] = level;
+  header[2] = (cipher_suite >> 8) & 0xff;
+  header[3] = cipher_suite & 0xff;
+  header[4] = (tlv_len >> 24) & 0xff;
+  header[5] = (tlv_len >> 16) & 0xff;
+  header[6] = (tlv_len >> 8) & 0xff;
+  header[7] = tlv_len & 0xff;
   return BIO_write_all(bio_.get(), header, sizeof(header)) &&
          BIO_write_all(bio_.get(), secret.data(), secret.size()) &&
          BIO_write_all(bio_.get(), data, len);
diff --git a/ssl/test/mock_quic_transport.h b/ssl/test/mock_quic_transport.h
index a56652d..114f059 100644
--- a/ssl/test/mock_quic_transport.h
+++ b/ssl/test/mock_quic_transport.h
@@ -45,10 +45,12 @@
   // Reads a record header from |bio_| and returns whether the record was read
   // successfully. As part of reading the header, this function checks that the
   // cipher suite and secret in the header are correct. On success, the tag
-  // indicating the TLS record type is put in  |*out_tag|, the length of the TLS
-  // record is put in |*out_len|, and the next thing to be read from |bio_| is
-  // |*out_len| bytes of the TLS record.
-  bool ReadHeader(uint8_t *out_tag, size_t *out_len);
+  // indicating the TLS record type is put in |*out_tag|, the encryption level
+  // is put in |*out_level|, the length of the TLS record is put in |*out_len|,
+  // and the next thing to be read from |bio_| is |*out_len| bytes of the TLS
+  // record.
+  bool ReadHeader(uint8_t *out_tag, enum ssl_encryption_level_t *out_level,
+                  size_t *out_len);
 
   // Writes a MockQuicTransport record to |bio_| at encryption level |level|
   // with record type |tag| and a TLS record payload of length |len| from
diff --git a/ssl/test/runner/conn.go b/ssl/test/runner/conn.go
index c0c91d2..9fa5c05 100644
--- a/ssl/test/runner/conn.go
+++ b/ssl/test/runner/conn.go
@@ -754,7 +754,7 @@
 	return b, bb
 }
 
-func (c *Conn) useInTrafficSecret(version uint16, suite *cipherSuite, secret []byte) error {
+func (c *Conn) useInTrafficSecret(level encryptionLevel, version uint16, suite *cipherSuite, secret []byte) error {
 	if c.hand.Len() != 0 {
 		return c.in.setErrorLocked(errors.New("tls: buffered handshake messages on cipher change"))
 	}
@@ -763,6 +763,7 @@
 		side = clientWrite
 	}
 	if c.config.Bugs.MockQUICTransport != nil {
+		c.config.Bugs.MockQUICTransport.readLevel = level
 		c.config.Bugs.MockQUICTransport.readSecret = secret
 		c.config.Bugs.MockQUICTransport.readCipherSuite = suite.id
 	}
@@ -771,12 +772,13 @@
 	return nil
 }
 
-func (c *Conn) useOutTrafficSecret(version uint16, suite *cipherSuite, secret []byte) {
+func (c *Conn) useOutTrafficSecret(level encryptionLevel, version uint16, suite *cipherSuite, secret []byte) {
 	side := serverWrite
 	if c.isClient {
 		side = clientWrite
 	}
 	if c.config.Bugs.MockQUICTransport != nil {
+		c.config.Bugs.MockQUICTransport.writeLevel = level
 		c.config.Bugs.MockQUICTransport.writeSecret = secret
 		c.config.Bugs.MockQUICTransport.writeCipherSuite = suite.id
 	}
@@ -1677,7 +1679,7 @@
 		if c.config.Bugs.RejectUnsolicitedKeyUpdate {
 			return errors.New("tls: unexpected KeyUpdate message")
 		}
-		if err := c.useInTrafficSecret(c.in.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.in.trafficSecret)); err != nil {
+		if err := c.useInTrafficSecret(encryptionApplication, c.in.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.in.trafficSecret)); err != nil {
 			return err
 		}
 		if keyUpdate.keyUpdateRequest == keyUpdateRequested {
@@ -1711,7 +1713,7 @@
 		return errors.New("tls: received invalid KeyUpdate message")
 	}
 
-	return c.useInTrafficSecret(c.in.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.in.trafficSecret))
+	return c.useInTrafficSecret(encryptionApplication, c.in.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.in.trafficSecret))
 }
 
 func (c *Conn) Renegotiate() error {
@@ -2065,7 +2067,7 @@
 	if err := c.flushHandshake(); err != nil {
 		return err
 	}
-	c.useOutTrafficSecret(c.out.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.out.trafficSecret))
+	c.useOutTrafficSecret(encryptionApplication, c.out.wireVersion, c.cipherSuite, updateTrafficSecret(c.cipherSuite.hash(), c.wireVersion, c.out.trafficSecret))
 	return nil
 }
 
diff --git a/ssl/test/runner/handshake_client.go b/ssl/test/runner/handshake_client.go
index bf89c01..ad01f1e 100644
--- a/ssl/test/runner/handshake_client.go
+++ b/ssl/test/runner/handshake_client.go
@@ -518,7 +518,7 @@
 		earlyTrafficSecret := finishedHash.deriveSecret(earlyTrafficLabel)
 		c.earlyExporterSecret = finishedHash.deriveSecret(earlyExporterLabel)
 
-		c.useOutTrafficSecret(session.wireVersion, pskCipherSuite, earlyTrafficSecret)
+		c.useOutTrafficSecret(encryptionEarlyData, session.wireVersion, pskCipherSuite, earlyTrafficSecret)
 		for _, earlyData := range c.config.Bugs.SendEarlyData {
 			if _, err := c.writeRecord(recordTypeApplicationData, earlyData); err != nil {
 				return err
@@ -923,7 +923,7 @@
 	// traffic key.
 	clientHandshakeTrafficSecret := hs.finishedHash.deriveSecret(clientHandshakeTrafficLabel)
 	serverHandshakeTrafficSecret := hs.finishedHash.deriveSecret(serverHandshakeTrafficLabel)
-	if err := c.useInTrafficSecret(c.wireVersion, hs.suite, serverHandshakeTrafficSecret); err != nil {
+	if err := c.useInTrafficSecret(encryptionHandshake, c.wireVersion, hs.suite, serverHandshakeTrafficSecret); err != nil {
 		return err
 	}
 
@@ -1098,7 +1098,7 @@
 
 	// Switch to application data keys on read. In particular, any alerts
 	// from the client certificate are read over these keys.
-	if err := c.useInTrafficSecret(c.wireVersion, hs.suite, serverTrafficSecret); err != nil {
+	if err := c.useInTrafficSecret(encryptionApplication, c.wireVersion, hs.suite, serverTrafficSecret); err != nil {
 		return err
 	}
 
@@ -1133,7 +1133,7 @@
 
 	// Send EndOfEarlyData and then switch write key to handshake
 	// traffic key.
-	if encryptedExtensions.extensions.hasEarlyData && c.out.cipher != nil && !c.config.Bugs.SkipEndOfEarlyData {
+	if encryptedExtensions.extensions.hasEarlyData && !c.config.Bugs.SkipEndOfEarlyData && c.config.Bugs.MockQUICTransport == nil {
 		if c.config.Bugs.SendStrayEarlyHandshake {
 			helloRequest := new(helloRequestMsg)
 			c.writeRecord(recordTypeHandshake, helloRequest.marshal())
@@ -1157,7 +1157,7 @@
 		c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
 	}
 
-	c.useOutTrafficSecret(c.wireVersion, hs.suite, clientHandshakeTrafficSecret)
+	c.useOutTrafficSecret(encryptionHandshake, c.wireVersion, hs.suite, clientHandshakeTrafficSecret)
 
 	// The client EncryptedExtensions message is sent if some extension uses it.
 	// (Currently only ALPS does.)
@@ -1263,7 +1263,7 @@
 	c.flushHandshake()
 
 	// Switch to application data keys.
-	c.useOutTrafficSecret(c.wireVersion, hs.suite, clientTrafficSecret)
+	c.useOutTrafficSecret(encryptionApplication, c.wireVersion, hs.suite, clientTrafficSecret)
 	c.resumptionSecret = hs.finishedHash.deriveSecret(resumptionLabel)
 	for _, ticket := range deferredTickets {
 		if err := c.processTLS13NewSessionTicket(ticket, hs.suite); err != nil {
diff --git a/ssl/test/runner/handshake_server.go b/ssl/test/runner/handshake_server.go
index df74ccd..3cdebef 100644
--- a/ssl/test/runner/handshake_server.go
+++ b/ssl/test/runner/handshake_server.go
@@ -747,7 +747,7 @@
 			}
 
 			sessionCipher := cipherSuiteFromID(hs.sessionState.cipherSuite)
-			if err := c.useInTrafficSecret(c.wireVersion, sessionCipher, earlyTrafficSecret); err != nil {
+			if err := c.useInTrafficSecret(encryptionEarlyData, c.wireVersion, sessionCipher, earlyTrafficSecret); err != nil {
 				return err
 			}
 
@@ -854,7 +854,7 @@
 
 	// Switch to handshake traffic keys.
 	serverHandshakeTrafficSecret := hs.finishedHash.deriveSecret(serverHandshakeTrafficLabel)
-	c.useOutTrafficSecret(c.wireVersion, hs.suite, serverHandshakeTrafficSecret)
+	c.useOutTrafficSecret(encryptionHandshake, c.wireVersion, hs.suite, serverHandshakeTrafficSecret)
 	// Derive handshake traffic read key, but don't switch yet.
 	clientHandshakeTrafficSecret := hs.finishedHash.deriveSecret(clientHandshakeTrafficLabel)
 
@@ -1038,7 +1038,7 @@
 
 	// Switch to application data keys on write. In particular, any alerts
 	// from the client certificate are sent over these keys.
-	c.useOutTrafficSecret(c.wireVersion, hs.suite, serverTrafficSecret)
+	c.useOutTrafficSecret(encryptionApplication, c.wireVersion, hs.suite, serverTrafficSecret)
 
 	// Send 0.5-RTT messages.
 	for _, halfRTTMsg := range config.Bugs.SendHalfRTTData {
@@ -1063,7 +1063,7 @@
 	}
 
 	// Switch input stream to handshake traffic keys.
-	if err := c.useInTrafficSecret(c.wireVersion, hs.suite, clientHandshakeTrafficSecret); err != nil {
+	if err := c.useInTrafficSecret(encryptionHandshake, c.wireVersion, hs.suite, clientHandshakeTrafficSecret); err != nil {
 		return err
 	}
 
@@ -1192,7 +1192,7 @@
 	hs.writeClientHash(clientFinished.marshal())
 
 	// Switch to application data keys on read.
-	if err := c.useInTrafficSecret(c.wireVersion, hs.suite, clientTrafficSecret); err != nil {
+	if err := c.useInTrafficSecret(encryptionApplication, c.wireVersion, hs.suite, clientTrafficSecret); err != nil {
 		return err
 	}
 
diff --git a/ssl/test/runner/mock_quic_transport.go b/ssl/test/runner/mock_quic_transport.go
index 27a8bc3..99ce0f7 100644
--- a/ssl/test/runner/mock_quic_transport.go
+++ b/ssl/test/runner/mock_quic_transport.go
@@ -26,6 +26,15 @@
 const tagApplication = byte('A')
 const tagAlert = byte('L')
 
+type encryptionLevel byte
+
+const (
+	encryptionInitial     encryptionLevel = 0
+	encryptionEarlyData   encryptionLevel = 1
+	encryptionHandshake   encryptionLevel = 2
+	encryptionApplication encryptionLevel = 3
+)
+
 // mockQUICTransport provides a record layer for sending/receiving messages
 // when testing TLS over QUIC. It is only intended for testing, as it runs over
 // an in-order reliable transport, looks nothing like the QUIC wire image, and
@@ -43,6 +52,7 @@
 // cipher suite ID or tag.
 type mockQUICTransport struct {
 	net.Conn
+	readLevel, writeLevel             encryptionLevel
 	readSecret, writeSecret           []byte
 	readCipherSuite, writeCipherSuite uint16
 	skipEarlyData                     bool
@@ -54,37 +64,39 @@
 
 func (m *mockQUICTransport) read() (byte, []byte, error) {
 	for {
-		header := make([]byte, 7)
+		header := make([]byte, 8)
 		if _, err := io.ReadFull(m.Conn, header); err != nil {
 			return 0, nil, err
 		}
-		cipherSuite := binary.BigEndian.Uint16(header[1:3])
-		length := binary.BigEndian.Uint32(header[3:])
+		tag := header[0]
+		level := header[1]
+		cipherSuite := binary.BigEndian.Uint16(header[2:4])
+		length := binary.BigEndian.Uint32(header[4:])
 		value := make([]byte, length)
 		if _, err := io.ReadFull(m.Conn, value); err != nil {
-			return 0, nil, fmt.Errorf("Error reading record")
+			return 0, nil, fmt.Errorf("error reading record")
 		}
-		if cipherSuite != m.readCipherSuite {
-			if m.skipEarlyData {
+		if level != byte(m.readLevel) {
+			if m.skipEarlyData && level == byte(encryptionEarlyData) {
 				continue
 			}
-			return 0, nil, fmt.Errorf("Received cipher suite %d does not match expected %d", cipherSuite, m.readCipherSuite)
+			return 0, nil, fmt.Errorf("received level %d does not match expected %d", level, m.readLevel)
+		}
+		if cipherSuite != m.readCipherSuite {
+			return 0, nil, fmt.Errorf("received cipher suite %d does not match expected %d", cipherSuite, m.readCipherSuite)
 		}
 		if len(m.readSecret) > len(value) {
-			return 0, nil, fmt.Errorf("Input length too short")
+			return 0, nil, fmt.Errorf("input length too short")
 		}
 		secret := value[:len(m.readSecret)]
 		out := value[len(m.readSecret):]
 		if !bytes.Equal(secret, m.readSecret) {
-			if m.skipEarlyData {
-				continue
-			}
 			return 0, nil, fmt.Errorf("secrets don't match: got %x but expected %x", secret, m.readSecret)
 		}
-		if m.skipEarlyData && header[0] == tagHandshake {
-			m.skipEarlyData = false
-		}
-		return header[0], out, nil
+		// Although not true for QUIC in general, our transport is ordered, so
+		// we expect to stop skipping early data after a valid record.
+		m.skipEarlyData = false
+		return tag, out, nil
 	}
 }
 
@@ -114,12 +126,13 @@
 		return 0, fmt.Errorf("unsupported record type %d\n", typ)
 	}
 	length := len(m.writeSecret) + len(data)
-	payload := make([]byte, 1+2+4+length)
+	payload := make([]byte, 1+1+2+4+length)
 	payload[0] = tag
-	binary.BigEndian.PutUint16(payload[1:3], m.writeCipherSuite)
-	binary.BigEndian.PutUint32(payload[3:7], uint32(length))
-	copy(payload[7:], m.writeSecret)
-	copy(payload[7+len(m.writeSecret):], data)
+	payload[1] = byte(m.writeLevel)
+	binary.BigEndian.PutUint16(payload[2:4], m.writeCipherSuite)
+	binary.BigEndian.PutUint32(payload[4:8], uint32(length))
+	copy(payload[8:], m.writeSecret)
+	copy(payload[8+len(m.writeSecret):], data)
 	if _, err := m.Conn.Write(payload); err != nil {
 		return 0, err
 	}
diff --git a/ssl/test/runner/runner.go b/ssl/test/runner/runner.go
index f1ec122..f3847d6 100644
--- a/ssl/test/runner/runner.go
+++ b/ssl/test/runner/runner.go
@@ -1350,6 +1350,11 @@
 			flags = append(flags, "-on-resume-expect-accept-early-data")
 		}
 
+		if test.protocol == quic {
+			// QUIC requires an early data context string.
+			flags = append(flags, "-quic-early-data-context", "context")
+		}
+
 		flags = append(flags, "-enable-early-data")
 		if test.testType == clientTest {
 			// Configure the runner with default maximum early data.
diff --git a/ssl/test/test_config.cc b/ssl/test/test_config.cc
index eb863eb..4f84867 100644
--- a/ssl/test/test_config.cc
+++ b/ssl/test/test_config.cc
@@ -183,6 +183,7 @@
     {"-handshaker-path", &TestConfig::handshaker_path},
     {"-delegated-credential", &TestConfig::delegated_credential},
     {"-expect-early-data-reason", &TestConfig::expect_early_data_reason},
+    {"-quic-early-data-context", &TestConfig::quic_early_data_context},
 };
 
 // TODO(davidben): When we can depend on C++17 or Abseil, switch this to
@@ -1797,5 +1798,13 @@
     }
   }
 
+  if (!quic_early_data_context.empty() &&
+      !SSL_set_quic_early_data_context(
+          ssl.get(),
+          reinterpret_cast<const uint8_t *>(quic_early_data_context.data()),
+          quic_early_data_context.size())) {
+    return nullptr;
+  }
+
   return ssl;
 }
diff --git a/ssl/test/test_config.h b/ssl/test/test_config.h
index 67cab95..9279fca 100644
--- a/ssl/test/test_config.h
+++ b/ssl/test/test_config.h
@@ -182,6 +182,7 @@
   bool expect_hrr = false;
   bool expect_no_hrr = false;
   bool wait_for_debugger = false;
+  std::string quic_early_data_context;
 
   int argc;
   char **argv;