Add digest sum handling to the tool.

Android might want to replace the system *sum (i.e. md5sum, sha256sum
etc) binaries with a symlink to the BoringSSL tool binary.

This change also allows the tool to figure out what to do based on
argv[0] if it matches one of the known commands.

Change-Id: Ia4fc3cff45ce2ae623dae6786eea5d7ad127d44b
Reviewed-on: https://boringssl-review.googlesource.com/2940
Reviewed-by: Adam Langley <agl@google.com>
diff --git a/tool/CMakeLists.txt b/tool/CMakeLists.txt
index 03c1c21..c9c093b 100644
--- a/tool/CMakeLists.txt
+++ b/tool/CMakeLists.txt
@@ -5,9 +5,10 @@
 
 	args.cc
 	client.cc
-	server.cc
 	const.cc
+	digest.cc
 	pkcs12.cc
+	server.cc
 	speed.cc
 	tool.cc
 	transport_common.cc
diff --git a/tool/digest.cc b/tool/digest.cc
new file mode 100644
index 0000000..93058ed
--- /dev/null
+++ b/tool/digest.cc
@@ -0,0 +1,459 @@
+/* Copyright (c) 2014, Google Inc.
+ *
+ * Permission to use, copy, modify, and/or distribute this software for any
+ * purpose with or without fee is hereby granted, provided that the above
+ * copyright notice and this permission notice appear in all copies.
+ *
+ * THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ * WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ * MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR ANY
+ * SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ * WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN ACTION
+ * OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF OR IN
+ * CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. */
+
+#include <openssl/base.h>
+
+#if !defined(OPENSSL_WINDOWS)
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include <errno.h>
+#include <fcntl.h>
+#include <limits.h>
+#include <stdio.h>
+#include <sys/stat.h>
+#include <sys/types.h>
+#include <unistd.h>
+
+#include <openssl/digest.h>
+
+
+struct close_delete {
+  void operator()(int *fd) {
+    close(*fd);
+  }
+};
+
+template<typename T, typename R, R (*func) (T*)>
+struct func_delete {
+  void operator()(T* obj) {
+    func(obj);
+  }
+};
+
+// Source is an awkward expression of a union type in C++: Stdin | File filename.
+struct Source {
+  enum Type {
+    STDIN,
+  };
+
+  Source() : is_stdin_(false) {}
+  Source(Type) : is_stdin_(true) {}
+  Source(const std::string &name) : is_stdin_(false), filename_(name) {}
+
+  bool is_stdin() const { return is_stdin_; }
+  const std::string &filename() const { return filename_; }
+
+ private:
+  bool is_stdin_;
+  std::string filename_;
+};
+
+static const char kStdinName[] = "standard input";
+
+// OpenFile opens the regular file named |filename| and sets |*out_fd| to be a
+// file descriptor to it. Returns true on sucess or prints an error to stderr
+// and returns false on error.
+static bool OpenFile(int *out_fd, const std::string &filename) {
+  *out_fd = -1;
+
+  int fd = open(filename.c_str(), O_RDONLY);
+  if (fd < 0) {
+    fprintf(stderr, "Failed to open input file '%s': %s\n", filename.c_str(),
+            strerror(errno));
+    return false;
+  }
+
+  struct stat st;
+  if (fstat(fd, &st)) {
+    fprintf(stderr, "Failed to stat input file '%s': %s\n", filename.c_str(),
+            strerror(errno));
+    goto err;
+  }
+
+  if (!S_ISREG(st.st_mode)) {
+    fprintf(stderr, "%s: not a regular file\n", filename.c_str());
+    goto err;
+  }
+
+  *out_fd = fd;
+  return true;
+
+err:
+  close(fd);
+  return false;
+}
+
+// SumFile hashes the contents of |source| with |md| and sets |*out_hex| to the
+// hex-encoded result.
+//
+// It returns true on success or prints an error to stderr and returns false on
+// error.
+static bool SumFile(std::string *out_hex, const EVP_MD *md,
+                    const Source &source) {
+  std::unique_ptr<int, close_delete> scoped_fd;
+  int fd;
+
+  if (source.is_stdin()) {
+    fd = 0;
+  } else {
+    if (!OpenFile(&fd, source.filename())) {
+      return false;
+    }
+    scoped_fd.reset(&fd);
+  }
+
+  static const size_t kBufSize = 8192;
+  std::unique_ptr<uint8_t[]> buf(new uint8_t[kBufSize]);
+
+  EVP_MD_CTX ctx;
+  EVP_MD_CTX_init(&ctx);
+  std::unique_ptr<EVP_MD_CTX, func_delete<EVP_MD_CTX, int, EVP_MD_CTX_cleanup>>
+      scoped_ctx(&ctx);
+
+  if (!EVP_DigestInit_ex(&ctx, md, NULL)) {
+    fprintf(stderr, "Failed to initialize EVP_MD_CTX.\n");
+    return false;
+  }
+
+  for (;;) {
+    ssize_t n;
+
+    do {
+      n = read(fd, buf.get(), kBufSize);
+    } while (n == -1 && errno == EINTR);
+
+    if (n == 0) {
+      break;
+    } else if (n < 0) {
+      fprintf(stderr, "Failed to read from %s: %s\n",
+              source.is_stdin() ? kStdinName : source.filename().c_str(),
+              strerror(errno));
+      return false;
+    }
+
+    if (!EVP_DigestUpdate(&ctx, buf.get(), n)) {
+      fprintf(stderr, "Failed to update hash.\n");
+      return false;
+    }
+  }
+
+  uint8_t digest[EVP_MAX_MD_SIZE];
+  unsigned digest_len;
+  if (!EVP_DigestFinal_ex(&ctx, digest, &digest_len)) {
+    fprintf(stderr, "Failed to finish hash.\n");
+    return false;
+  }
+
+  char hex_digest[EVP_MAX_MD_SIZE * 2];
+  static const char kHextable[] = "0123456789abcdef";
+  for (unsigned i = 0; i < digest_len; i++) {
+    const uint8_t b = digest[i];
+    hex_digest[i * 2] = kHextable[b >> 4];
+    hex_digest[i * 2 + 1] = kHextable[b & 0xf];
+  }
+  *out_hex = std::string(hex_digest, digest_len * 2);
+
+  return true;
+}
+
+// PrintFileSum hashes |source| with |md| and prints a line to stdout in the
+// format of the coreutils *sum utilities. It returns true on success or prints
+// an error to stderr and returns false on error.
+static bool PrintFileSum(const EVP_MD *md, const Source &source) {
+  std::string hex_digest;
+  if (!SumFile(&hex_digest, md, source)) {
+    return false;
+  }
+
+  printf("%s  %s\n", hex_digest.c_str(),
+         source.is_stdin() ? "-" : source.filename().c_str());
+  return true;
+}
+
+// CheckModeArguments contains arguments for the check mode. See the
+// sha256sum(1) man page for details.
+struct CheckModeArguments {
+  bool quiet = false;
+  bool status = false;
+  bool warn = false;
+  bool strict = false;
+};
+
+// Check reads lines from |source| where each line is in the format of the
+// coreutils *sum utilities. It attempts to verify each hash by reading the
+// file named in the line.
+//
+// It returns true if all files were verified and, if |args.strict|, no input
+// lines had formatting errors. Otherwise it prints errors to stderr and
+// returns false.
+static bool Check(const CheckModeArguments &args, const EVP_MD *md,
+                  const Source &source) {
+  std::unique_ptr<FILE, func_delete<FILE, int, fclose>> scoped_file;
+  FILE *file;
+
+  if (source.is_stdin()) {
+    file = stdin;
+  } else {
+    int fd;
+    if (!OpenFile(&fd, source.filename())) {
+      return false;
+    }
+
+    file = fdopen(fd, "r");
+    if (!file) {
+      perror("fdopen");
+      close(fd);
+      return false;
+    }
+
+    scoped_file = std::unique_ptr<FILE, func_delete<FILE, int, fclose>>(file);
+  }
+
+  const size_t hex_size = EVP_MD_size(md) * 2;
+  char line[EVP_MAX_MD_SIZE * 2 + 2 /* spaces */ + PATH_MAX + 1 /* newline */ +
+            1 /* NUL */];
+  unsigned bad_lines = 0;
+  unsigned parsed_lines = 0;
+  unsigned error_lines = 0;
+  unsigned bad_hash_lines = 0;
+  unsigned line_no = 0;
+  bool ok = true;
+  bool draining_overlong_line = false;
+
+  for (;;) {
+    line_no++;
+
+    if (fgets(line, sizeof(line), file) == nullptr) {
+      if (feof(file)) {
+        break;
+      }
+      fprintf(stderr, "Error reading from input.\n");
+      return false;
+    }
+
+    size_t len = strlen(line);
+
+    if (draining_overlong_line) {
+      if (line[len - 1] == '\n') {
+        draining_overlong_line = false;
+      }
+      continue;
+    }
+
+    const bool overlong = line[len - 1] != '\n' && !feof(file);
+
+    if (len < hex_size + 2 /* spaces */ + 1 /* filename */ ||
+        line[hex_size] != ' ' ||
+        line[hex_size + 1] != ' ' ||
+        overlong) {
+      bad_lines++;
+      if (args.warn) {
+        fprintf(stderr, "%s: %u: improperly formatted line\n",
+                source.is_stdin() ? kStdinName : source.filename().c_str(), line_no);
+      }
+      if (args.strict) {
+        ok = false;
+      }
+      if (overlong) {
+        draining_overlong_line = true;
+      }
+      continue;
+    }
+
+    if (line[len - 1] == '\n') {
+      line[len - 1] = 0;
+      len--;
+    }
+
+    parsed_lines++;
+
+    // coreutils does not attempt to restrict relative or absolute paths in the
+    // input so nor does this code.
+    std::string calculated_hex_digest;
+    const std::string target_filename(&line[hex_size + 2]);
+    Source target_source;
+    if (target_filename == "-") {
+      // coreutils reads from stdin if the filename is "-".
+      target_source = Source(Source::STDIN);
+    } else {
+      target_source = Source(target_filename);
+    }
+
+    if (!SumFile(&calculated_hex_digest, md, target_source)) {
+      error_lines++;
+      ok = false;
+      continue;
+    }
+
+    if (calculated_hex_digest != std::string(line, hex_size)) {
+      bad_hash_lines++;
+      if (!args.status) {
+        printf("%s: FAILED\n", target_filename.c_str());
+      }
+      ok = false;
+      continue;
+    }
+
+    if (!args.quiet) {
+      printf("%s: OK\n", target_filename.c_str());
+    }
+  }
+
+  if (!args.status) {
+    if (bad_lines > 0 && parsed_lines > 0) {
+      fprintf(stderr, "WARNING: %u line%s improperly formatted\n", bad_lines,
+              bad_lines == 1 ? " is" : "s are");
+    }
+    if (error_lines > 0) {
+      fprintf(stderr, "WARNING: %u computed checksum(s) did NOT match\n",
+              error_lines);
+    }
+  }
+
+  if (parsed_lines == 0) {
+    fprintf(stderr, "%s: no properly formatted checksum lines found.\n",
+            source.is_stdin() ? kStdinName : source.filename().c_str());
+    ok = false;
+  }
+
+  return ok;
+}
+
+// DigestSum acts like the coreutils *sum utilites, with the given hash
+// function.
+static bool DigestSum(const EVP_MD *md,
+                      const std::vector<std::string> &args) {
+  bool check_mode = false;
+  CheckModeArguments check_args;
+  bool check_mode_args_given = false;
+  std::vector<Source> sources;
+
+  auto it = args.begin();
+  while (it != args.end()) {
+    const std::string &arg = *it;
+    if (!arg.empty() && arg[0] != '-') {
+      break;
+    }
+
+    it++;
+
+    if (arg == "--") {
+      break;
+    }
+
+    if (arg[0] == "-") {
+      // "-" ends the argument list and indicates that stdin should be used.
+      sources.push_back(Source(Source::STDIN));
+      break;
+    }
+
+    if (arg.size() >= 2 && arg[0] == '-' && arg[1] != '-') {
+      for (size_t i = 1; i < arg.size(); i++) {
+        switch (arg[i]) {
+          case 'b':
+          case 't':
+            // Binary/text mode – irrelevent.
+            break;
+          case 'c':
+            check_mode = true;
+            break;
+          case 'w':
+            check_mode_args_given = true;
+            check_args.warn = true;
+            break;
+          default:
+            fprintf(stderr, "Unknown option '%c'.\n", arg[i]);
+            return false;
+        }
+      }
+    } else if (arg == "--binary" || arg == "--text") {
+      // Binary/text mode – irrelevent.
+    } else if (arg == "--check") {
+      check_mode = true;
+    } else if (arg == "--quiet") {
+      check_mode_args_given = true;
+      check_args.quiet = true;
+    } else if (arg == "--status") {
+      check_mode_args_given = true;
+      check_args.status = true;
+    } else if (arg == "--warn") {
+      check_mode_args_given = true;
+      check_args.warn = true;
+    } else if (arg == "--strict") {
+      check_mode_args_given = true;
+      check_args.strict = true;
+    } else {
+      fprintf(stderr, "Unknown option '%s'.\n", arg.c_str());
+      return false;
+    }
+  }
+
+  if (check_mode_args_given && !check_mode) {
+    fprintf(
+        stderr,
+        "Check mode arguments are only meaningful when verifying checksums.\n");
+    return false;
+  }
+
+  for (; it != args.end(); it++) {
+    sources.push_back(Source(*it));
+  }
+
+  if (sources.empty()) {
+    sources.push_back(Source(Source::STDIN));
+  }
+
+  bool ok = true;
+
+  if (check_mode) {
+    for (auto &source : sources) {
+      ok &= Check(check_args, md, source);
+    }
+  } else {
+    for (auto &source : sources) {
+      ok &= PrintFileSum(md, source);
+    }
+  }
+
+  return ok;
+}
+
+bool MD5Sum(const std::vector<std::string> &args) {
+  return DigestSum(EVP_md5(), args);
+}
+
+bool SHA1Sum(const std::vector<std::string> &args) {
+  return DigestSum(EVP_sha1(), args);
+}
+
+bool SHA224Sum(const std::vector<std::string> &args) {
+  return DigestSum(EVP_sha224(), args);
+}
+
+bool SHA256Sum(const std::vector<std::string> &args) {
+  return DigestSum(EVP_sha256(), args);
+}
+
+bool SHA384Sum(const std::vector<std::string> &args) {
+  return DigestSum(EVP_sha384(), args);
+}
+
+bool SHA512Sum(const std::vector<std::string> &args) {
+  return DigestSum(EVP_sha512(), args);
+}
+
+#endif  /* !OPENSSL_WINDOWS */
diff --git a/tool/tool.cc b/tool/tool.cc
index a57cd16..88b6f24 100644
--- a/tool/tool.cc
+++ b/tool/tool.cc
@@ -18,43 +18,97 @@
 #include <openssl/err.h>
 #include <openssl/ssl.h>
 
+#if !defined(OPENSSL_WINDOWS)
+#include <libgen.h>
+#endif
+
 
 #if !defined(OPENSSL_WINDOWS)
 bool Client(const std::vector<std::string> &args);
 bool Server(const std::vector<std::string> &args);
+bool MD5Sum(const std::vector<std::string> &args);
+bool SHA1Sum(const std::vector<std::string> &args);
+bool SHA224Sum(const std::vector<std::string> &args);
+bool SHA256Sum(const std::vector<std::string> &args);
+bool SHA384Sum(const std::vector<std::string> &args);
+bool SHA512Sum(const std::vector<std::string> &args);
 #endif
 bool DoPKCS12(const std::vector<std::string> &args);
 bool Speed(const std::vector<std::string> &args);
 
+typedef bool (*tool_func_t)(const std::vector<std::string> &args);
+
+struct Tool {
+  char name[16];
+  tool_func_t func;
+};
+
+static const Tool kTools[] = {
+  { "speed", Speed },
+  { "pkcs12", DoPKCS12 },
+#if !defined(OPENSSL_WINDOWS)
+  { "client", Client },
+  { "s_client", Client },
+  { "server", Server },
+  { "s_server", Server },
+  { "md5sum", MD5Sum },
+  { "sha1sum", SHA1Sum },
+  { "sha224sum", SHA224Sum },
+  { "sha256sum", SHA256Sum },
+  { "sha384sum", SHA384Sum },
+  { "sha512sum", SHA512Sum },
+#endif
+  { "", nullptr },
+};
+
 static void usage(const char *name) {
-  printf("Usage: %s [speed|client|server|pkcs12]\n", name);
+  printf("Usage: %s [", name);
+
+  for (size_t i = 0;; i++) {
+    const Tool &tool = kTools[i];
+    if (tool.func == nullptr) {
+      break;
+    }
+    if (i > 0) {
+      printf("|");
+    }
+    printf("%s", tool.name);
+  }
+  printf("]\n");
+}
+
+tool_func_t FindTool(const std::string &name) {
+  for (size_t i = 0;; i++) {
+    const Tool &tool = kTools[i];
+    if (tool.func == nullptr || name == tool.name) {
+      return tool.func;
+    }
+  }
 }
 
 int main(int argc, char **argv) {
-  std::string tool;
-  if (argc >= 2) {
-    tool = argv[1];
-  }
-
   SSL_library_init();
 
-  std::vector<std::string> args;
-  for (int i = 2; i < argc; i++) {
-    args.push_back(argv[i]);
-  }
-
-  if (tool == "speed") {
-    return !Speed(args);
+  int starting_arg = 1;
+  tool_func_t tool = nullptr;
 #if !defined(OPENSSL_WINDOWS)
-  } else if (tool == "s_client" || tool == "client") {
-    return !Client(args);
-  } else if (tool == "s_server" || tool == "server") {
-    return !Server(args);
+  tool = FindTool(basename(argv[0]));
 #endif
-  } else if (tool == "pkcs12") {
-    return !DoPKCS12(args);
-  } else {
+  if (tool == nullptr) {
+    starting_arg++;
+    if (argc > 1) {
+      tool = FindTool(argv[1]);
+    }
+  }
+  if (tool == nullptr) {
     usage(argv[0]);
     return 1;
   }
+
+  std::vector<std::string> args;
+  for (int i = starting_arg; i < argc; i++) {
+    args.push_back(argv[i]);
+  }
+
+  return !tool(args);
 }