about summary refs log tree commit diff
path: root/third_party/nix/src/libutil/hash.cc
diff options
context:
space:
mode:
Diffstat (limited to 'third_party/nix/src/libutil/hash.cc')
-rw-r--r--third_party/nix/src/libutil/hash.cc120
1 files changed, 69 insertions, 51 deletions
diff --git a/third_party/nix/src/libutil/hash.cc b/third_party/nix/src/libutil/hash.cc
index 5596ef0178..50169b0f19 100644
--- a/third_party/nix/src/libutil/hash.cc
+++ b/third_party/nix/src/libutil/hash.cc
@@ -4,6 +4,7 @@
 #include <iostream>
 
 #include <absl/strings/escaping.h>
+#include <absl/strings/str_format.h>
 #include <fcntl.h>
 #include <openssl/md5.h>
 #include <openssl/sha.h>
@@ -75,8 +76,18 @@ static std::string printHash16(const Hash& hash) {
   return std::string(buf, hash.hashSize * 2);
 }
 
+bool Hash::IsValidBase16(absl::string_view s) {
+  for (char c : s) {
+    if ('0' <= c && c <= '9') continue;
+    if ('a' <= c && c <= 'f') continue;
+    if ('A' <= c && c <= 'F') continue;
+    return false;
+  }
+  return true;
+}
+
 // omitted: E O U T
-const std::string base32Chars = "0123456789abcdfghijklmnpqrsvwxyz";
+constexpr char base32Chars[] = "0123456789abcdfghijklmnpqrsvwxyz";
 
 constexpr signed char kUnBase32[] = {
     -1, -1, -1, -1, -1, -1, -1, -1, /* unprintables */
@@ -167,6 +178,15 @@ std::string Hash::to_string(Base base, bool includeType) const {
 }
 
 Hash::Hash(const std::string& s, HashType type) : type(type) {
+  absl::StatusOr<Hash> result = deserialize(s, type);
+  if (result.ok()) {
+    *this = *result;
+  } else {
+    throw BadHash(result.status().message());
+  }
+}
+
+absl::StatusOr<Hash> Hash::deserialize(const std::string& s, HashType type) {
   size_t pos = 0;
   bool isSRI = false;
 
@@ -176,90 +196,88 @@ Hash::Hash(const std::string& s, HashType type) : type(type) {
     if (sep != std::string::npos) {
       isSRI = true;
     } else if (type == htUnknown) {
-      throw BadHash("hash '%s' does not include a type", s);
+      return absl::InvalidArgumentError(
+          absl::StrCat("hash string '", s, " does not include a type"));
     }
   }
 
+  HashType parsedType = type;
   if (sep != std::string::npos) {
     std::string hts = std::string(s, 0, sep);
-    this->type = parseHashType(hts);
-    if (this->type == htUnknown) {
-      throw BadHash("unknown hash type '%s'", hts);
-    }
-    if (type != htUnknown && type != this->type) {
-      throw BadHash("hash '%s' should have type '%s'", s, printHashType(type));
+    parsedType = parseHashType(hts);
+    if (parsedType != type) {
+      return absl::InvalidArgumentError(
+          absl::StrCat("hash '", s, "' should have type '", printHashType(type),
+                       "', found '", printHashType(parsedType), "'"));
     }
     pos = sep + 1;
   }
 
-  init();
+  Hash dest(parsedType);
 
   size_t size = s.size() - pos;
+  absl::string_view sv(s.data() + pos, size);
 
-  if (!isSRI && size == base16Len()) {
-    auto parseHexDigit = [&](char c) {
-      if (c >= '0' && c <= '9') {
-        return c - '0';
-      }
-      if (c >= 'A' && c <= 'F') {
-        return c - 'A' + 10;
-      }
-      if (c >= 'a' && c <= 'f') {
-        return c - 'a' + 10;
-      }
-      throw BadHash("invalid base-16 hash '%s'", s);
-    };
-
-    for (unsigned int i = 0; i < hashSize; i++) {
-      hash[i] = parseHexDigit(s[pos + i * 2]) << 4 |
-                parseHexDigit(s[pos + i * 2 + 1]);
+  if (!isSRI && size == dest.base16Len()) {
+    std::string bytes;
+    if (!IsValidBase16(sv)) {
+      return absl::InvalidArgumentError(
+          absl::StrCat("invalid base-16 hash: bad character in '", s, "'"));
     }
+    bytes = absl::HexStringToBytes(sv);
+    if (bytes.size() != dest.hashSize) {
+      return absl::InvalidArgumentError(
+          absl::StrCat("hash '", s, "' has wrong length for base16 ",
+                       printHashType(dest.type)));
+    }
+    memcpy(dest.hash, bytes.data(), dest.hashSize);
   }
 
-  else if (!isSRI && size == base32Len()) {
+  else if (!isSRI && size == dest.base32Len()) {
     for (unsigned int n = 0; n < size; ++n) {
-      char c = s[pos + size - n - 1];
-      unsigned char digit = 0;
-      for (digit = 0; digit < base32Chars.size(); ++digit) { /* !!! slow */
-        if (base32Chars[digit] == c) {
-          break;
-        }
-      }
-      if (digit >= 32) {
-        throw BadHash("invalid base-32 hash '%s'", s);
+      char c = sv[size - n - 1];
+      // range: -1, 0..31
+      signed char digit = kUnBase32[static_cast<unsigned char>(c)];
+      if (digit < 0) {
+        return absl::InvalidArgumentError(
+            absl::StrCat("invalid base-32 hash: bad character ",
+                         absl::CEscape(absl::string_view(&c, 1))));
       }
       unsigned int b = n * 5;
       unsigned int i = b / 8;
       unsigned int j = b % 8;
-      hash[i] |= digit << j;
+      dest.hash[i] |= digit << j;
 
-      if (i < hashSize - 1) {
-        hash[i + 1] |= digit >> (8 - j);
+      if (i < dest.hashSize - 1) {
+        dest.hash[i + 1] |= digit >> (8 - j);
       } else {
         if ((digit >> (8 - j)) != 0) {
-          throw BadHash("invalid base-32 hash '%s'", s);
+          return absl::InvalidArgumentError(
+              absl::StrCat("invalid base-32 hash '", s, "'"));
         }
       }
     }
   }
 
-  else if (isSRI || size == base64Len()) {
-    std::string d;
-    if (!absl::Base64Unescape(std::string(s, pos), &d)) {
-      // TODO(grfn): replace this with StatusOr
-      throw Error("Invalid Base64");
+  else if (isSRI || size == dest.base64Len()) {
+    std::string decoded;
+    if (!absl::Base64Unescape(sv, &decoded)) {
+      return absl::InvalidArgumentError("invalid base-64 hash");
     }
-    if (d.size() != hashSize) {
-      throw BadHash("invalid %s hash '%s'", isSRI ? "SRI" : "base-64", s);
+    if (decoded.size() != dest.hashSize) {
+      return absl::InvalidArgumentError(
+          absl::StrCat("hash '", s, "' has wrong length for base64 ",
+                       printHashType(dest.type)));
     }
-    assert(hashSize);
-    memcpy(hash, d.data(), hashSize);
+    memcpy(dest.hash, decoded.data(), dest.hashSize);
   }
 
   else {
-    throw BadHash("hash '%s' has wrong length for hash type '%s'", s,
-                  printHashType(type));
+    return absl::InvalidArgumentError(absl::StrCat(
+        "hash '", s, "' has wrong length for ", printHashType(dest.type)));
   }
+
+  return dest;
 }
 
 union Ctx {