| // Copyright (c) 2020, Google Inc. |
| // |
| // Permission to use, copy, modify, and/or distribute this software for any |
| // purpose with or without fee is hereby granted, provided that the above |
| // copyright notice and this permission notice appear in all copies. |
| // |
| // THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES |
| // WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF |
| // MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY |
| // SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES |
| // WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION |
| // OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN |
| // CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. |
| |
| package hpke |
| |
| import ( |
| "bytes" |
| _ "crypto/sha256" |
| _ "crypto/sha512" |
| "encoding/hex" |
| "encoding/json" |
| "errors" |
| "flag" |
| "fmt" |
| "io/ioutil" |
| "path/filepath" |
| "testing" |
| ) |
| |
// exportOnlyAEAD is the AEAD identifier that TestVectors treats as
// unsupported: any test vector whose aead_id equals this value is skipped.
const exportOnlyAEAD uint16 = 0xffff
| |
// testDataDir is the directory that TestVectors reads test-vectors.json from.
// It defaults to "testdata" and can be overridden with the -testdata flag.
var testDataDir = flag.String("testdata", "testdata", "The path to the directory containing the test vector JSON file.")
| |
| // Simple round-trip test for fixed inputs. |
| func TestRoundTrip(t *testing.T) { |
| publicKeyR, secretKeyR, err := GenerateKeyPairX25519() |
| if err != nil { |
| t.Errorf("failed to generate key pair: %s", err) |
| return |
| } |
| |
| // Set up the sender and receiver contexts. |
| senderContext, enc, err := SetupBaseSenderX25519(HKDFSHA256, AES256GCM, publicKeyR, nil, nil) |
| if err != nil { |
| t.Errorf("failed to set up sender: %s", err) |
| return |
| } |
| receiverContext, err := SetupBaseReceiverX25519(HKDFSHA256, AES256GCM, enc, secretKeyR, nil) |
| if err != nil { |
| t.Errorf("failed to set up receiver: %s", err) |
| return |
| } |
| |
| // Seal() our plaintext with the sender context, then Open() the |
| // ciphertext with the receiver context. |
| plaintext := []byte("foobar") |
| ciphertext := senderContext.Seal(plaintext, nil) |
| decrypted, err := receiverContext.Open(ciphertext, nil) |
| if err != nil { |
| t.Errorf("encryption round trip failed: %s", err) |
| return |
| } |
| checkBytesEqual(t, "decrypted", decrypted, plaintext) |
| } |
| |
| // HpkeTestVector defines the subset of test-vectors.json that we read. |
| type HpkeTestVector struct { |
| KEM uint16 `json:"kem_id"` |
| Mode uint8 `json:"mode"` |
| KDF uint16 `json:"kdf_id"` |
| AEAD uint16 `json:"aead_id"` |
| Info HexString `json:"info"` |
| PSK HexString `json:"psk"` |
| PSKID HexString `json:"psk_id"` |
| SecretKeyR HexString `json:"skRm"` |
| SecretKeyE HexString `json:"skEm"` |
| PublicKeyR HexString `json:"pkRm"` |
| PublicKeyE HexString `json:"pkEm"` |
| Enc HexString `json:"enc"` |
| Encryptions []EncryptionTestVector `json:"encryptions"` |
| Exports []ExportTestVector `json:"exports"` |
| } |
| type EncryptionTestVector struct { |
| Plaintext HexString `json:"plaintext"` |
| AdditionalData HexString `json:"aad"` |
| Ciphertext HexString `json:"ciphertext"` |
| } |
| type ExportTestVector struct { |
| ExportContext HexString `json:"exporter_context"` |
| ExportLength int `json:"L"` |
| ExportValue HexString `json:"exported_value"` |
| } |
| |
| // TestVectors checks all relevant test vectors in test-vectors.json. |
| func TestVectors(t *testing.T) { |
| jsonStr, err := ioutil.ReadFile(filepath.Join(*testDataDir, "test-vectors.json")) |
| if err != nil { |
| t.Errorf("error reading test vectors: %s", err) |
| return |
| } |
| |
| var testVectors []HpkeTestVector |
| err = json.Unmarshal(jsonStr, &testVectors) |
| if err != nil { |
| t.Errorf("error parsing test vectors: %s", err) |
| return |
| } |
| |
| var numSkippedTests = 0 |
| |
| for testNum, testVec := range testVectors { |
| // Skip this vector if it specifies an unsupported parameter. |
| if testVec.KEM != X25519WithHKDFSHA256 || |
| (testVec.Mode != hpkeModeBase && testVec.Mode != hpkeModePSK) || |
| testVec.AEAD == exportOnlyAEAD { |
| numSkippedTests++ |
| continue |
| } |
| |
| testVec := testVec // capture the range variable |
| t.Run(fmt.Sprintf("test%d,Mode=%d,KDF=%d,AEAD=%d", testNum, testVec.Mode, testVec.KDF, testVec.AEAD), func(t *testing.T) { |
| var senderContext *Context |
| var receiverContext *Context |
| var enc []byte |
| var err error |
| |
| switch testVec.Mode { |
| case hpkeModeBase: |
| senderContext, enc, err = SetupBaseSenderX25519(testVec.KDF, testVec.AEAD, testVec.PublicKeyR, testVec.Info, |
| func() ([]byte, []byte, error) { |
| return testVec.PublicKeyE, testVec.SecretKeyE, nil |
| }) |
| if err != nil { |
| t.Errorf("failed to set up sender: %s", err) |
| return |
| } |
| checkBytesEqual(t, "sender enc", enc, testVec.Enc) |
| |
| receiverContext, err = SetupBaseReceiverX25519(testVec.KDF, testVec.AEAD, enc, testVec.SecretKeyR, testVec.Info) |
| if err != nil { |
| t.Errorf("failed to set up receiver: %s", err) |
| return |
| } |
| case hpkeModePSK: |
| senderContext, enc, err = SetupPSKSenderX25519(testVec.KDF, testVec.AEAD, testVec.PublicKeyR, testVec.Info, testVec.PSK, testVec.PSKID, |
| func() ([]byte, []byte, error) { |
| return testVec.PublicKeyE, testVec.SecretKeyE, nil |
| }) |
| if err != nil { |
| t.Errorf("failed to set up sender: %s", err) |
| return |
| } |
| checkBytesEqual(t, "sender enc", enc, testVec.Enc) |
| |
| receiverContext, err = SetupPSKReceiverX25519(testVec.KDF, testVec.AEAD, enc, testVec.SecretKeyR, testVec.Info, testVec.PSK, testVec.PSKID) |
| if err != nil { |
| t.Errorf("failed to set up receiver: %s", err) |
| return |
| } |
| default: |
| panic("unsupported mode") |
| } |
| |
| for encryptionNum, e := range testVec.Encryptions { |
| ciphertext := senderContext.Seal(e.Plaintext, e.AdditionalData) |
| checkBytesEqual(t, "ciphertext", ciphertext, e.Ciphertext) |
| |
| decrypted, err := receiverContext.Open(ciphertext, e.AdditionalData) |
| if err != nil { |
| t.Errorf("decryption %d failed: %s", encryptionNum, err) |
| return |
| } |
| checkBytesEqual(t, "decrypted plaintext", decrypted, e.Plaintext) |
| } |
| |
| for _, ex := range testVec.Exports { |
| exportValue := senderContext.Export(ex.ExportContext, ex.ExportLength) |
| checkBytesEqual(t, "exportValue", exportValue, ex.ExportValue) |
| |
| exportValue = receiverContext.Export(ex.ExportContext, ex.ExportLength) |
| checkBytesEqual(t, "exportValue", exportValue, ex.ExportValue) |
| } |
| }) |
| } |
| |
| if numSkippedTests == len(testVectors) { |
| panic("no test vectors were used") |
| } |
| } |
| |
// HexString enables us to unmarshal JSON strings containing hex byte strings.
type HexString []byte

// UnmarshalJSON decodes data, which must be a double-quoted string of hex
// digits, and stores the resulting bytes in h.
//
// It returns an error if data is not wrapped in double quotes or if the
// quoted text is not valid hex. On error h is left unchanged.
func (h *HexString) UnmarshalJSON(data []byte) error {
	if len(data) < 2 || data[0] != '"' || data[len(data)-1] != '"' {
		return errors.New("missing double quotes")
	}
	decoded, err := hex.DecodeString(string(data[1 : len(data)-1]))
	if err != nil {
		// hex.DecodeString also returns the bytes it decoded before the
		// error; do not store that partial result.
		return err
	}
	*h = decoded
	return nil
}
| |
// checkBytesEqual reports a test error on t, labelled with name, when actual
// and expected do not hold the same bytes. It does not stop the test. Nil and
// empty slices compare as equal.
func checkBytesEqual(t *testing.T, name string, actual, expected []byte) {
	// Attribute a failure to the calling line, not to this helper.
	t.Helper()
	if bytes.Equal(actual, expected) {
		return
	}
	t.Errorf("%s = %x; want %x", name, actual, expected)
}