Skip the field inversion when just measuring output size.

https://boringssl-review.googlesource.com/c/boringssl/+/41084
inadvertently added a somewhat expensive operation (field inversion) in
the path of EC_POINT_point2oct when passed with buf == NULL. The result
is a caller that calls the function twice, first to measure and then to
serialize, actually ends up doing the field inversion twice.

Fix this by removing the dual-use calling convention from the internal
function and just have a separate function to measure the output size
separately. It's slightly subtle because EC_POINT_point2oct would check
for the point at infinity by way of converting to affine coordinates, so
we do need to repeat that check.

As part of this, add a unit test for
https://boringssl-review.googlesource.com/6488, which rejected the point
at infinity way back.

Change-Id: I3b6c0f95cced9c00489386f064a2c3f0bb1776f8
Reviewed-on: https://boringssl-review.googlesource.com/c/boringssl/+/55065
Reviewed-by: Adam Langley <agl@google.com>
Commit-Queue: David Benjamin <davidben@google.com>
diff --git a/crypto/fipsmodule/ec/ec_test.cc b/crypto/fipsmodule/ec/ec_test.cc
index edcfeaa..c54adb5 100644
--- a/crypto/fipsmodule/ec/ec_test.cc
+++ b/crypto/fipsmodule/ec/ec_test.cc
@@ -891,6 +891,29 @@
   EXPECT_TRUE(EC_POINT_is_at_infinity(group(), sum.get()));
 }
 
+// Test that we refuse to encode or decode the point at infinity.
+TEST_P(ECCurveTest, EncodeInfinity) {
+  // The point at infinity is encoded as a single zero byte, but we do not
+  // support it.
+  static const uint8_t kInfinity[] = {0};
+  bssl::UniquePtr<EC_POINT> inf(EC_POINT_new(group()));
+  ASSERT_TRUE(inf);
+  EXPECT_FALSE(EC_POINT_oct2point(group(), inf.get(), kInfinity,
+                                  sizeof(kInfinity), nullptr));
+
+  // Encoding it also fails.
+  ASSERT_TRUE(EC_POINT_set_to_infinity(group(), inf.get()));
+  uint8_t buf[128];
+  EXPECT_EQ(
+      0u, EC_POINT_point2oct(group(), inf.get(), POINT_CONVERSION_UNCOMPRESSED,
+                             buf, sizeof(buf), nullptr));
+
+  // Measuring the length of the encoding also fails.
+  EXPECT_EQ(
+      0u, EC_POINT_point2oct(group(), inf.get(), POINT_CONVERSION_UNCOMPRESSED,
+                             nullptr, 0, nullptr));
+}
+
 static std::vector<EC_builtin_curve> AllCurves() {
   const size_t num_curves = EC_get_builtin_curves(nullptr, 0);
   std::vector<EC_builtin_curve> curves(num_curves);
diff --git a/crypto/fipsmodule/ec/internal.h b/crypto/fipsmodule/ec/internal.h
index f6c8e8a..0d53546 100644
--- a/crypto/fipsmodule/ec/internal.h
+++ b/crypto/fipsmodule/ec/internal.h
@@ -439,11 +439,18 @@
                                  size_t *out_len, size_t max_out,
                                  const EC_RAW_POINT *p);
 
-// ec_point_to_bytes behaves like |EC_POINT_point2oct| but takes an
-// |EC_AFFINE|.
+// ec_point_byte_len returns the number of bytes in the byte representation of
+// a non-infinity point in |group|, encoded according to |form|, or zero if
+// |form| is invalid.
+size_t ec_point_byte_len(const EC_GROUP *group, point_conversion_form_t form);
+
+// ec_point_to_bytes encodes |point| according to |form| and writes the result
+// |buf|. It returns the size of the output on success or zero on error. At most
+// |max_out| bytes will be written. The buffer should be at least
+// |ec_point_byte_len| long to guarantee success.
 size_t ec_point_to_bytes(const EC_GROUP *group, const EC_AFFINE *point,
                          point_conversion_form_t form, uint8_t *buf,
-                         size_t len);
+                         size_t max_out);
 
 // ec_point_from_uncompressed parses |in| as a point in uncompressed form and
 // sets the result to |out|. It returns one on success and zero if the input was
diff --git a/crypto/fipsmodule/ec/oct.c b/crypto/fipsmodule/ec/oct.c
index ddd0f37..7032635 100644
--- a/crypto/fipsmodule/ec/oct.c
+++ b/crypto/fipsmodule/ec/oct.c
@@ -73,9 +73,7 @@
 #include "internal.h"
 
 
-size_t ec_point_to_bytes(const EC_GROUP *group, const EC_AFFINE *point,
-                         point_conversion_form_t form, uint8_t *buf,
-                         size_t len) {
+size_t ec_point_byte_len(const EC_GROUP *group, point_conversion_form_t form) {
   if (form != POINT_CONVERSION_COMPRESSED &&
       form != POINT_CONVERSION_UNCOMPRESSED) {
     OPENSSL_PUT_ERROR(EC, EC_R_INVALID_FORM);
@@ -88,27 +86,30 @@
     // Uncompressed points have a second coordinate.
     output_len += field_len;
   }
+  return output_len;
+}
 
-  // if 'buf' is NULL, just return required length
-  if (buf != NULL) {
-    if (len < output_len) {
-      OPENSSL_PUT_ERROR(EC, EC_R_BUFFER_TOO_SMALL);
-      return 0;
-    }
+size_t ec_point_to_bytes(const EC_GROUP *group, const EC_AFFINE *point,
+                         point_conversion_form_t form, uint8_t *buf,
+                         size_t len) {
+  size_t output_len = ec_point_byte_len(group, form);
+  if (len < output_len) {
+    OPENSSL_PUT_ERROR(EC, EC_R_BUFFER_TOO_SMALL);
+    return 0;
+  }
 
-    size_t field_len_out;
-    ec_felem_to_bytes(group, buf + 1, &field_len_out, &point->X);
-    assert(field_len_out == field_len);
+  size_t field_len;
+  ec_felem_to_bytes(group, buf + 1, &field_len, &point->X);
+  assert(field_len == BN_num_bytes(&group->field));
 
-    if (form == POINT_CONVERSION_UNCOMPRESSED) {
-      ec_felem_to_bytes(group, buf + 1 + field_len, &field_len_out, &point->Y);
-      assert(field_len_out == field_len);
-      buf[0] = form;
-    } else {
-      uint8_t y_buf[EC_MAX_BYTES];
-      ec_felem_to_bytes(group, y_buf, &field_len_out, &point->Y);
-      buf[0] = form + (y_buf[field_len_out - 1] & 1);
-    }
+  if (form == POINT_CONVERSION_UNCOMPRESSED) {
+    ec_felem_to_bytes(group, buf + 1 + field_len, &field_len, &point->Y);
+    assert(field_len == BN_num_bytes(&group->field));
+    buf[0] = form;
+  } else {
+    uint8_t y_buf[EC_MAX_BYTES];
+    ec_felem_to_bytes(group, y_buf, &field_len, &point->Y);
+    buf[0] = form + (y_buf[field_len - 1] & 1);
   }
 
   return output_len;
@@ -214,6 +215,15 @@
     OPENSSL_PUT_ERROR(EC, EC_R_INCOMPATIBLE_OBJECTS);
     return 0;
   }
+  if (buf == NULL) {
+    // When |buf| is NULL, just return the number of bytes that would be
+    // written, without doing an expensive Jacobian-to-affine conversion.
+    if (ec_GFp_simple_is_at_infinity(group, &point->raw)) {
+      OPENSSL_PUT_ERROR(EC, EC_R_POINT_AT_INFINITY);
+      return 0;
+    }
+    return ec_point_byte_len(group, form);
+  }
   EC_AFFINE affine;
   if (!ec_jacobian_to_affine(group, &affine, &point->raw)) {
     return 0;
diff --git a/crypto/trust_token/pmbtoken.c b/crypto/trust_token/pmbtoken.c
index 68d8909..ab09f01 100644
--- a/crypto/trust_token/pmbtoken.c
+++ b/crypto/trust_token/pmbtoken.c
@@ -123,8 +123,7 @@
 
 static int point_to_cbb(CBB *out, const EC_GROUP *group,
                         const EC_AFFINE *point) {
-  size_t len =
-      ec_point_to_bytes(group, point, POINT_CONVERSION_UNCOMPRESSED, NULL, 0);
+  size_t len = ec_point_byte_len(group, POINT_CONVERSION_UNCOMPRESSED);
   if (len == 0) {
     return 0;
   }
diff --git a/crypto/trust_token/voprf.c b/crypto/trust_token/voprf.c
index cedee1e..f8e1c4c 100644
--- a/crypto/trust_token/voprf.c
+++ b/crypto/trust_token/voprf.c
@@ -62,8 +62,7 @@
 
 static int cbb_add_point(CBB *out, const EC_GROUP *group,
                          const EC_AFFINE *point) {
-  size_t len =
-      ec_point_to_bytes(group, point, POINT_CONVERSION_UNCOMPRESSED, NULL, 0);
+  size_t len = ec_point_byte_len(group,  POINT_CONVERSION_UNCOMPRESSED);
   if (len == 0) {
     return 0;
   }