Always use |BN_mod_exp_mont|/|BN_mod_exp_mont_consttime| in RSA.

This removes a hard dependency on |BN_mod_exp|, which will allow the
linker to drop it in programs that don't use other features that
require it.

Also, remove the |mont| member of |bn_blinding_st| in favor of having
callers pass it when necssaary. The |mont| member was a weak reference,
and weak references tend to be error-prone.

Finally, reduce the scope of some parts of the blinding code to
|static|.

Change-Id: I16d8ccc2d6d950c1bb40377988daf1a377a21fe6
Reviewed-on: https://boringssl-review.googlesource.com/7111
Reviewed-by: David Benjamin <davidben@google.com>
diff --git a/crypto/rsa/blinding.c b/crypto/rsa/blinding.c
index e6b987d..9bd0263 100644
--- a/crypto/rsa/blinding.c
+++ b/crypto/rsa/blinding.c
@@ -125,13 +125,12 @@
   BIGNUM *e;
   BIGNUM *mod;
   int counter;
-  /* mont is the Montgomery context used for this |BN_BLINDING|. It is not
-   * owned and must outlive this structure. */
-  const BN_MONT_CTX *mont;
-  int (*bn_mod_exp)(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
-                    const BIGNUM *m, BN_CTX *ctx, const BN_MONT_CTX *mont);
 };
 
+static BN_BLINDING *bn_blinding_create_param(BN_BLINDING *b, const BIGNUM *e,
+                                             BIGNUM *m, BN_CTX *ctx,
+                                             const BN_MONT_CTX *mont_ctx);
+
 BN_BLINDING *BN_BLINDING_new(const BIGNUM *A, const BIGNUM *Ai, BIGNUM *mod) {
   BN_BLINDING *ret = NULL;
 
@@ -159,9 +158,7 @@
   if (ret->mod == NULL) {
     goto err;
   }
-  if (BN_get_flags(mod, BN_FLG_CONSTTIME) != 0) {
-    BN_set_flags(ret->mod, BN_FLG_CONSTTIME);
-  }
+  BN_set_flags(ret->mod, BN_FLG_CONSTTIME);
 
   /* Set the counter to the special value -1
    * to indicate that this is never-used fresh blinding
@@ -186,7 +183,8 @@
   OPENSSL_free(r);
 }
 
-int BN_BLINDING_update(BN_BLINDING *b, BN_CTX *ctx) {
+static int bn_blinding_update(BN_BLINDING *b, BN_CTX *ctx,
+                              const BN_MONT_CTX *mont_ctx) {
   int ret = 0;
 
   if (b->A == NULL || b->Ai == NULL) {
@@ -200,7 +198,7 @@
 
   if (++b->counter == BN_BLINDING_COUNTER && b->e != NULL) {
     /* re-create blinding parameters */
-    if (!BN_BLINDING_create_param(b, NULL, NULL, ctx, NULL, NULL)) {
+    if (!bn_blinding_create_param(b, NULL, NULL, ctx, mont_ctx)) {
       goto err;
     }
   } else {
@@ -221,7 +219,8 @@
   return ret;
 }
 
-int BN_BLINDING_convert(BIGNUM *n, BN_BLINDING *b, BN_CTX *ctx) {
+int BN_BLINDING_convert(BIGNUM *n, BN_BLINDING *b, BN_CTX *ctx,
+                        const BN_MONT_CTX *mont_ctx) {
   int ret = 1;
 
   if (b->A == NULL || b->Ai == NULL) {
@@ -232,7 +231,7 @@
   if (b->counter == -1) {
     /* Fresh blinding, doesn't need updating. */
     b->counter = 0;
-  } else if (!BN_BLINDING_update(b, ctx)) {
+  } else if (!bn_blinding_update(b, ctx, mont_ctx)) {
     return 0;
   }
 
@@ -251,11 +250,9 @@
   return BN_mod_mul(n, n, b->Ai, b->mod, ctx);
 }
 
-BN_BLINDING *BN_BLINDING_create_param(
+static BN_BLINDING *bn_blinding_create_param(
     BN_BLINDING *b, const BIGNUM *e, BIGNUM *m, BN_CTX *ctx,
-    int (*bn_mod_exp)(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
-                      const BIGNUM *m, BN_CTX *ctx, const BN_MONT_CTX *mont),
-    const BN_MONT_CTX *mont) {
+    const BN_MONT_CTX *mont_ctx) {
   int retry_counter = 32;
   BN_BLINDING *ret = NULL;
 
@@ -284,13 +281,6 @@
     goto err;
   }
 
-  if (bn_mod_exp != NULL) {
-    ret->bn_mod_exp = bn_mod_exp;
-  }
-  if (mont != NULL) {
-    ret->mont = mont;
-  }
-
   do {
     if (!BN_rand_range(ret->A, ret->mod)) {
       goto err;
@@ -313,14 +303,8 @@
     }
   } while (1);
 
-  if (ret->bn_mod_exp != NULL && ret->mont != NULL) {
-    if (!ret->bn_mod_exp(ret->A, ret->A, ret->e, ret->mod, ctx, ret->mont)) {
-      goto err;
-    }
-  } else {
-    if (!BN_mod_exp(ret->A, ret->A, ret->e, ret->mod, ctx)) {
-      goto err;
-    }
+  if (!BN_mod_exp_mont(ret->A, ret->A, ret->e, ret->mod, ctx, mont_ctx)) {
+    goto err;
   }
 
   return ret;
@@ -403,8 +387,7 @@
     }
   }
 
-  ret = BN_BLINDING_create_param(NULL, e, n, ctx, rsa->meth->bn_mod_exp,
-                                 mont_ctx);
+  ret = bn_blinding_create_param(NULL, e, n, ctx, mont_ctx);
   if (ret == NULL) {
     OPENSSL_PUT_ERROR(RSA, ERR_R_BN_LIB);
     goto err;
diff --git a/crypto/rsa/internal.h b/crypto/rsa/internal.h
index 4d27344..f8e0fa2 100644
--- a/crypto/rsa/internal.h
+++ b/crypto/rsa/internal.h
@@ -92,14 +92,9 @@
 
 BN_BLINDING *BN_BLINDING_new(const BIGNUM *A, const BIGNUM *Ai, BIGNUM *mod);
 void BN_BLINDING_free(BN_BLINDING *b);
-int BN_BLINDING_update(BN_BLINDING *b, BN_CTX *ctx);
-int BN_BLINDING_convert(BIGNUM *n, BN_BLINDING *b, BN_CTX *ctx);
+int BN_BLINDING_convert(BIGNUM *n, BN_BLINDING *b, BN_CTX *ctx,
+                        const BN_MONT_CTX *mont_ctx);
 int BN_BLINDING_invert(BIGNUM *n, const BN_BLINDING *b, BN_CTX *ctx);
-BN_BLINDING *BN_BLINDING_create_param(
-    BN_BLINDING *b, const BIGNUM *e, BIGNUM *m, BN_CTX *ctx,
-    int (*bn_mod_exp)(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
-                      const BIGNUM *m, BN_CTX *ctx, const BN_MONT_CTX *mont),
-    const BN_MONT_CTX *mont);
 BN_BLINDING *rsa_setup_blinding(RSA *rsa, BN_CTX *in_ctx);
 
 
diff --git a/crypto/rsa/rsa_impl.c b/crypto/rsa/rsa_impl.c
index 7eedf6f..8dd59dc 100644
--- a/crypto/rsa/rsa_impl.c
+++ b/crypto/rsa/rsa_impl.c
@@ -177,7 +177,7 @@
     }
   }
 
-  if (!rsa->meth->bn_mod_exp(result, f, rsa->e, rsa->n, ctx, rsa->mont_n)) {
+  if (!BN_mod_exp_mont(result, f, rsa->e, rsa->n, ctx, rsa->mont_n)) {
     goto err;
   }
 
@@ -488,7 +488,7 @@
     }
   }
 
-  if (!rsa->meth->bn_mod_exp(result, f, rsa->e, rsa->n, ctx, rsa->mont_n)) {
+  if (!BN_mod_exp_mont(result, f, rsa->e, rsa->n, ctx, rsa->mont_n)) {
     goto err;
   }
 
@@ -565,7 +565,7 @@
       OPENSSL_PUT_ERROR(RSA, ERR_R_INTERNAL_ERROR);
       goto err;
     }
-    if (!BN_BLINDING_convert(f, blinding, ctx)) {
+    if (!BN_BLINDING_convert(f, blinding, ctx, rsa->mont_n)) {
       goto err;
     }
   }
@@ -591,7 +591,7 @@
       }
     }
 
-    if (!rsa->meth->bn_mod_exp(result, f, d, rsa->n, ctx, rsa->mont_n)) {
+    if (!BN_mod_exp_mont_consttime(result, f, d, rsa->n, ctx, rsa->mont_n)) {
       goto err;
     }
   }
@@ -677,7 +677,7 @@
   /* compute r1^dmq1 mod q */
   dmq1 = &local_dmq1;
   BN_with_flags(dmq1, rsa->dmq1, BN_FLG_CONSTTIME);
-  if (!rsa->meth->bn_mod_exp(m1, r1, dmq1, rsa->q, ctx, rsa->mont_q)) {
+  if (!BN_mod_exp_mont_consttime(m1, r1, dmq1, rsa->q, ctx, rsa->mont_q)) {
     goto err;
   }
 
@@ -691,7 +691,7 @@
   /* compute r1^dmp1 mod p */
   dmp1 = &local_dmp1;
   BN_with_flags(dmp1, rsa->dmp1, BN_FLG_CONSTTIME);
-  if (!rsa->meth->bn_mod_exp(r0, r1, dmp1, rsa->p, ctx, rsa->mont_p)) {
+  if (!BN_mod_exp_mont_consttime(r0, r1, dmp1, rsa->p, ctx, rsa->mont_p)) {
     goto err;
   }
 
@@ -756,7 +756,7 @@
       goto err;
     }
 
-    if (!rsa->meth->bn_mod_exp(m1, r1, exp, prime, ctx, ap->mont)) {
+    if (!BN_mod_exp_mont_consttime(m1, r1, exp, prime, ctx, ap->mont)) {
       goto err;
     }
 
@@ -773,7 +773,7 @@
   }
 
   if (rsa->e && rsa->n) {
-    if (!rsa->meth->bn_mod_exp(vrfy, r0, rsa->e, rsa->n, ctx, rsa->mont_n)) {
+    if (!BN_mod_exp_mont(vrfy, r0, rsa->e, rsa->n, ctx, rsa->mont_n)) {
       goto err;
     }
     /* If 'I' was greater than (or equal to) rsa->n, the operation
@@ -801,7 +801,7 @@
 
       d = &local_d;
       BN_with_flags(d, rsa->d, BN_FLG_CONSTTIME);
-      if (!rsa->meth->bn_mod_exp(r0, I, d, rsa->n, ctx, rsa->mont_n)) {
+      if (!BN_mod_exp_mont_consttime(r0, I, d, rsa->n, ctx, rsa->mont_n)) {
         goto err;
       }
     }
@@ -1133,7 +1133,7 @@
   NULL /* private_transform (defaults to rsa_default_private_transform) */,
 
   mod_exp,
-  BN_mod_exp_mont /* bn_mod_exp */,
+  NULL /* bn_mod_exp */,
 
   RSA_FLAG_CACHE_PUBLIC | RSA_FLAG_CACHE_PRIVATE,
 
diff --git a/include/openssl/rsa.h b/include/openssl/rsa.h
index 3798f48..df75af0 100644
--- a/include/openssl/rsa.h
+++ b/include/openssl/rsa.h
@@ -525,6 +525,8 @@
 
   int (*mod_exp)(BIGNUM *r0, const BIGNUM *I, RSA *rsa,
                  BN_CTX *ctx); /* Can be null */
+
+  /* bn_mod_exp is deprecated and ignored. Set it to NULL. */
   int (*bn_mod_exp)(BIGNUM *r, const BIGNUM *a, const BIGNUM *p,
                     const BIGNUM *m, BN_CTX *ctx,
                     const BN_MONT_CTX *mont);