Add tests for |BN_mod_inverse| with modulus 1. Zero is only a valid input to or output of |BN_mod_inverse| when the modulus is one. |BN_MONT_CTX_set| actually depends on this, so test that this works. Change-Id: Ic18f1fe786f668394951d4309020c6ead95e5e28 Reviewed-on: https://boringssl-review.googlesource.com/8922 Reviewed-by: Adam Langley <agl@google.com> Commit-Queue: Adam Langley <agl@google.com> CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/crypto/bn/bn_test.cc b/crypto/bn/bn_test.cc index 3405cbd..20f64d6 100644 --- a/crypto/bn/bn_test.cc +++ b/crypto/bn/bn_test.cc
@@ -584,6 +584,33 @@ return true; } +static bool TestModInv(FileTest *t, BN_CTX *ctx) { + ScopedBIGNUM a = GetBIGNUM(t, "A"); + ScopedBIGNUM m = GetBIGNUM(t, "M"); + ScopedBIGNUM mod_inv = GetBIGNUM(t, "ModInv"); + if (!a || !m || !mod_inv) { + return false; + } + + ScopedBIGNUM ret(BN_new()); + if (!ret || + !BN_mod_inverse(ret.get(), a.get(), m.get(), ctx) || + !ExpectBIGNUMsEqual(t, "inv(A) (mod M)", mod_inv.get(), ret.get())) { + return false; + } + + BN_set_flags(a.get(), BN_FLG_CONSTTIME); + + if (!ret || + !BN_mod_inverse(ret.get(), a.get(), m.get(), ctx) || + !ExpectBIGNUMsEqual(t, "inv(A) (mod M) (constant-time)", mod_inv.get(), + ret.get())) { + return false; + } + + return true; +} + struct Test { const char *name; bool (*func)(FileTest *t, BN_CTX *ctx); @@ -601,6 +628,7 @@ {"ModExp", TestModExp}, {"Exp", TestExp}, {"ModSqrt", TestModSqrt}, + {"ModInv", TestModInv}, }; static bool RunTest(FileTest *t, void *arg) {
diff --git a/crypto/bn/bn_tests.txt b/crypto/bn/bn_tests.txt index 6bb688b..2c09520 100644 --- a/crypto/bn/bn_tests.txt +++ b/crypto/bn/bn_tests.txt
@@ -10689,3 +10689,19 @@ ModSqrt = a1d52989f12f204d3d2167d9b1e6c8a6174c0c786a979a5952383b7b8bd186 A = 2eee37cf06228a387788188e650bc6d8a2ff402931443f69156a29155eca07dcb45f3aac238d92943c0c25c896098716baa433f25bd696a142f5a69d5d937e81 P = 9df9d6cc20b8540411af4e5357ef2b0353cb1f2ab5ffc3e246b41c32f71e951f + +ModInv = 00 +A = 00 +M = 01 + +ModInv = 00 +A = 01 +M = 01 + +ModInv = 00 +A = 02 +M = 01 + +ModInv = 00 +A = 03 +M = 01
diff --git a/crypto/bn/check_bn_tests.go b/crypto/bn/check_bn_tests.go index 9a1b65e..0d2042e 100644 --- a/crypto/bn/check_bn_tests.go +++ b/crypto/bn/check_bn_tests.go
@@ -247,6 +247,11 @@ } } } + case "ModInv": + if checkKeys(test, "A", "M", "ModInv") { + r := new(big.Int).ModInverse(test.Values["A"], test.Values["M"]) + checkResult(test, "A ^ -1 (mod M)", "ModInv", r) + } default: fmt.Fprintf(os.Stderr, "Line %d: unknown test type %q.\n", test.LineNumber, test.Type) }