Tidy up primality code.

We BN_cmp with 1 at the top, so the absolute value code never runs.
This simplifies the BN_CTX business considerably. Also add a test for
negative prime numbers.

Change-Id: I500a56bc285c2f75576947cfb518e75c9e6861ce
Reviewed-on: https://boringssl-review.googlesource.com/15367
Commit-Queue: Steven Valdez <svaldez@google.com>
Reviewed-by: Steven Valdez <svaldez@google.com>
CQ-Verified: CQ bot account: commit-bot@chromium.org <commit-bot@chromium.org>
diff --git a/crypto/bn/bn_test.cc b/crypto/bn/bn_test.cc
index fc48339..2d92276 100644
--- a/crypto/bn/bn_test.cc
+++ b/crypto/bn/bn_test.cc
@@ -2739,6 +2739,26 @@
     }
   }
 
+  // Negative numbers are not prime.
+  if (!BN_set_word(p.get(), 7)) {
+    return false;
+  }
+  BN_set_negative(p.get(), 1);
+  if (!BN_primality_test(&is_probably_prime_1, p.get(), BN_prime_checks, ctx,
+                         false /* do_trial_division */,
+                         nullptr /* callback */) ||
+      is_probably_prime_1 != 0 ||
+      !BN_primality_test(&is_probably_prime_2, p.get(), BN_prime_checks, ctx,
+                         true /* do_trial_division */,
+                         nullptr /* callback */) ||
+      is_probably_prime_2 != 0) {
+    fprintf(stderr,
+            "TestPrimeChecking failed for -7 (is_prime: 0 vs %d without "
+            "trial division vs %d with it)\n",
+            is_probably_prime_1, is_probably_prime_2);
+    return false;
+  }
+
   return true;
 }
 
diff --git a/crypto/bn/prime.c b/crypto/bn/prime.c
index 400b4f1..5fc7e49 100644
--- a/crypto/bn/prime.c
+++ b/crypto/bn/prime.c
@@ -482,12 +482,8 @@
   return BN_is_prime_fasttest_ex(candidate, checks, ctx, 0, cb);
 }
 
-int BN_is_prime_fasttest_ex(const BIGNUM *a, int checks, BN_CTX *ctx_passed,
+int BN_is_prime_fasttest_ex(const BIGNUM *a, int checks, BN_CTX *ctx,
                             int do_trial_division, BN_GENCB *cb) {
-  int i, ret = -1;
-  BN_CTX *ctx = NULL;
-  const BIGNUM *A = NULL;
-
   if (BN_cmp(a, BN_value_one()) <= 0) {
     return 0;
   }
@@ -503,10 +499,10 @@
   }
 
   if (do_trial_division) {
-    for (i = 1; i < NUMPRIMES; i++) {
+    for (int i = 1; i < NUMPRIMES; i++) {
       BN_ULONG mod = BN_mod_word(a, primes[i]);
       if (mod == (BN_ULONG)-1) {
-        goto err;
+        return -1;
       }
       if (mod == 0) {
         return BN_is_word(a, primes[i]);
@@ -514,44 +510,29 @@
     }
 
     if (!BN_GENCB_call(cb, 1, -1)) {
-      goto err;
+      return -1;
     }
   }
 
-  if (ctx_passed != NULL) {
-    ctx = ctx_passed;
-  } else if ((ctx = BN_CTX_new()) == NULL) {
-    goto err;
-  }
-  BN_CTX_start(ctx);
-
-  /* A := abs(a) */
-  if (a->neg) {
-    BIGNUM *t = BN_CTX_get(ctx);
-    if (t == NULL || !BN_copy(t, a)) {
-      goto err;
+  int ret = -1;
+  BN_CTX *ctx_allocated = NULL;
+  if (ctx == NULL) {
+    ctx_allocated = BN_CTX_new();
+    if (ctx_allocated == NULL) {
+      return -1;
     }
-    t->neg = 0;
-    A = t;
-  } else {
-    A = a;
+    ctx = ctx_allocated;
   }
 
   enum bn_primality_result_t result;
-  if (!BN_enhanced_miller_rabin_primality_test(&result, A, checks, ctx, cb)) {
+  if (!BN_enhanced_miller_rabin_primality_test(&result, a, checks, ctx, cb)) {
     goto err;
   }
 
   ret = (result == bn_probably_prime);
 
 err:
-  if (ctx != NULL) {
-    BN_CTX_end(ctx);
-    if (ctx_passed == NULL) {
-      BN_CTX_free(ctx);
-    }
-  }
-
+  BN_CTX_free(ctx_allocated);
   return ret;
 }
 
@@ -613,10 +594,8 @@
 
   /* Montgomery setup for computations mod A */
   mont = BN_MONT_CTX_new();
-  if (mont == NULL) {
-    goto err;
-  }
-  if (!BN_MONT_CTX_set(mont, w, ctx)) {
+  if (mont == NULL ||
+      !BN_MONT_CTX_set(mont, w, ctx)) {
     goto err;
   }