about summary refs log tree commit diff
diff options
context:
space:
mode:
-rw-r--r--third_party/nix/src/libutil/CMakeLists.txt1
-rw-r--r--third_party/nix/src/libutil/hash.cc120
-rw-r--r--third_party/nix/src/libutil/hash.hh10
-rw-r--r--third_party/nix/src/libutil/types.hh2
-rw-r--r--third_party/nix/src/tests/hash_test.cc42
5 files changed, 123 insertions, 52 deletions
diff --git a/third_party/nix/src/libutil/CMakeLists.txt b/third_party/nix/src/libutil/CMakeLists.txt
index 8713d7e9b86d..db504940a07e 100644
--- a/third_party/nix/src/libutil/CMakeLists.txt
+++ b/third_party/nix/src/libutil/CMakeLists.txt
@@ -48,6 +48,7 @@ target_sources(nixutil
 
 target_link_libraries(nixutil
   absl::strings
+  absl::statusor
   glog
   BZip2::BZip2
   LibLZMA::LibLZMA
diff --git a/third_party/nix/src/libutil/hash.cc b/third_party/nix/src/libutil/hash.cc
index 5596ef01784f..50169b0f19c7 100644
--- a/third_party/nix/src/libutil/hash.cc
+++ b/third_party/nix/src/libutil/hash.cc
@@ -4,6 +4,7 @@
 #include <iostream>
 
 #include <absl/strings/escaping.h>
+#include <absl/strings/str_format.h>
 #include <fcntl.h>
 #include <openssl/md5.h>
 #include <openssl/sha.h>
@@ -75,8 +76,18 @@ static std::string printHash16(const Hash& hash) {
   return std::string(buf, hash.hashSize * 2);
 }
 
+bool Hash::IsValidBase16(absl::string_view s) {
+  for (char c : s) {
+    if ('0' <= c && c <= '9') continue;
+    if ('a' <= c && c <= 'f') continue;
+    if ('A' <= c && c <= 'F') continue;
+    return false;
+  }
+  return true;
+}
+
 // omitted: E O U T
-const std::string base32Chars = "0123456789abcdfghijklmnpqrsvwxyz";
+constexpr char base32Chars[] = "0123456789abcdfghijklmnpqrsvwxyz";
 
 constexpr signed char kUnBase32[] = {
     -1, -1, -1, -1, -1, -1, -1, -1, /* unprintables */
@@ -167,6 +178,15 @@ std::string Hash::to_string(Base base, bool includeType) const {
 }
 
 Hash::Hash(const std::string& s, HashType type) : type(type) {
+  absl::StatusOr<Hash> result = deserialize(s, type);
+  if (result.ok()) {
+    *this = *result;
+  } else {
+    throw BadHash(result.status().message());
+  }
+}
+
+absl::StatusOr<Hash> Hash::deserialize(const std::string& s, HashType type) {
   size_t pos = 0;
   bool isSRI = false;
 
@@ -176,90 +196,88 @@ Hash::Hash(const std::string& s, HashType type) : type(type) {
     if (sep != std::string::npos) {
       isSRI = true;
     } else if (type == htUnknown) {
-      throw BadHash("hash '%s' does not include a type", s);
+      return absl::InvalidArgumentError(
+          absl::StrCat("hash string '", s, " does not include a type"));
     }
   }
 
+  HashType parsedType = type;
   if (sep != std::string::npos) {
     std::string hts = std::string(s, 0, sep);
-    this->type = parseHashType(hts);
-    if (this->type == htUnknown) {
-      throw BadHash("unknown hash type '%s'", hts);
-    }
-    if (type != htUnknown && type != this->type) {
-      throw BadHash("hash '%s' should have type '%s'", s, printHashType(type));
+    parsedType = parseHashType(hts);
+    if (parsedType != type) {
+      return absl::InvalidArgumentError(
+          absl::StrCat("hash '", s, "' should have type '", printHashType(type),
+                       "', found '", printHashType(parsedType), "'"));
     }
     pos = sep + 1;
   }
 
-  init();
+  Hash dest(parsedType);
 
   size_t size = s.size() - pos;
+  absl::string_view sv(s.data() + pos, size);
 
-  if (!isSRI && size == base16Len()) {
-    auto parseHexDigit = [&](char c) {
-      if (c >= '0' && c <= '9') {
-        return c - '0';
-      }
-      if (c >= 'A' && c <= 'F') {
-        return c - 'A' + 10;
-      }
-      if (c >= 'a' && c <= 'f') {
-        return c - 'a' + 10;
-      }
-      throw BadHash("invalid base-16 hash '%s'", s);
-    };
-
-    for (unsigned int i = 0; i < hashSize; i++) {
-      hash[i] = parseHexDigit(s[pos + i * 2]) << 4 |
-                parseHexDigit(s[pos + i * 2 + 1]);
+  if (!isSRI && size == dest.base16Len()) {
+    std::string bytes;
+    if (!IsValidBase16(sv)) {
+      return absl::InvalidArgumentError(
+          absl::StrCat("invalid base-16 hash: bad character in '", s, "'"));
     }
+    bytes = absl::HexStringToBytes(sv);
+    if (bytes.size() != dest.hashSize) {
+      return absl::InvalidArgumentError(
+          absl::StrCat("hash '", s, "' has wrong length for base16 ",
+                       printHashType(dest.type)));
+    }
+    memcpy(dest.hash, bytes.data(), dest.hashSize);
   }
 
-  else if (!isSRI && size == base32Len()) {
+  else if (!isSRI && size == dest.base32Len()) {
     for (unsigned int n = 0; n < size; ++n) {
-      char c = s[pos + size - n - 1];
-      unsigned char digit = 0;
-      for (digit = 0; digit < base32Chars.size(); ++digit) { /* !!! slow */
-        if (base32Chars[digit] == c) {
-          break;
-        }
-      }
-      if (digit >= 32) {
-        throw BadHash("invalid base-32 hash '%s'", s);
+      char c = sv[size - n - 1];
+      // range: -1, 0..31
+      signed char digit = kUnBase32[static_cast<unsigned char>(c)];
+      if (digit < 0) {
+        return absl::InvalidArgumentError(
+            absl::StrCat("invalid base-32 hash: bad character ",
+                         absl::CEscape(absl::string_view(&c, 1))));
       }
       unsigned int b = n * 5;
       unsigned int i = b / 8;
       unsigned int j = b % 8;
-      hash[i] |= digit << j;
+      dest.hash[i] |= digit << j;
 
-      if (i < hashSize - 1) {
-        hash[i + 1] |= digit >> (8 - j);
+      if (i < dest.hashSize - 1) {
+        dest.hash[i + 1] |= digit >> (8 - j);
       } else {
         if ((digit >> (8 - j)) != 0) {
-          throw BadHash("invalid base-32 hash '%s'", s);
+          return absl::InvalidArgumentError(
+              absl::StrCat("invalid base-32 hash '", s, "'"));
         }
       }
     }
   }
 
-  else if (isSRI || size == base64Len()) {
-    std::string d;
-    if (!absl::Base64Unescape(std::string(s, pos), &d)) {
-      // TODO(grfn): replace this with StatusOr
-      throw Error("Invalid Base64");
+  else if (isSRI || size == dest.base64Len()) {
+    std::string decoded;
+    if (!absl::Base64Unescape(sv, &decoded)) {
+      return absl::InvalidArgumentError("invalid base-64 hash");
     }
-    if (d.size() != hashSize) {
-      throw BadHash("invalid %s hash '%s'", isSRI ? "SRI" : "base-64", s);
+    if (decoded.size() != dest.hashSize) {
+      return absl::InvalidArgumentError(
+          absl::StrCat("hash '", s, "' has wrong length for base64 ",
+                       printHashType(dest.type)));
     }
-    assert(hashSize);
-    memcpy(hash, d.data(), hashSize);
+    memcpy(dest.hash, decoded.data(), dest.hashSize);
   }
 
   else {
-    throw BadHash("hash '%s' has wrong length for hash type '%s'", s,
-                  printHashType(type));
+    return absl::InvalidArgumentError(absl::StrCat(
+        "hash '", s, "' has wrong length for ", printHashType(dest.type)));
   }
+
+  return dest;
 }
 
 union Ctx {
diff --git a/third_party/nix/src/libutil/hash.hh b/third_party/nix/src/libutil/hash.hh
index 58f808896fe2..0b7b11edd072 100644
--- a/third_party/nix/src/libutil/hash.hh
+++ b/third_party/nix/src/libutil/hash.hh
@@ -1,5 +1,7 @@
 #pragma once
 
+#include <absl/status/statusor.h>
+
 #include "libutil/serialise.hh"
 #include "libutil/types.hh"
 
@@ -36,6 +38,10 @@ struct Hash {
      string. */
   Hash(const std::string& s, HashType type = htUnknown);
 
+  /* Status-returning version of above constructor */
+  static absl::StatusOr<Hash> deserialize(const std::string& s,
+                                          HashType type = htUnknown);
+
   void init();
 
   /* Check whether a hash is set. */
@@ -64,6 +70,10 @@ struct Hash {
      (e.g. "sha256:"). */
   std::string to_string(Base base = Base32, bool includeType = true) const;
 
+  /* Returns whether the passed string contains entirely valid base16
+     characters. */
+  static bool IsValidBase16(absl::string_view s);
+
   /* Returns whether the passed string contains entirely valid base32
      characters. */
   static bool IsValidBase32(absl::string_view s);
diff --git a/third_party/nix/src/libutil/types.hh b/third_party/nix/src/libutil/types.hh
index e2ea86fdcf5f..3d37e4efee37 100644
--- a/third_party/nix/src/libutil/types.hh
+++ b/third_party/nix/src/libutil/types.hh
@@ -44,6 +44,8 @@ struct FormatOrString {
 
 inline std::string fmt(const std::string& s) { return s; }
 
+inline std::string fmt(std::string_view s) { return std::string(s); }
+
 inline std::string fmt(const char* s) { return s; }
 
 inline std::string fmt(const FormatOrString& fs) { return fs.s; }
diff --git a/third_party/nix/src/tests/hash_test.cc b/third_party/nix/src/tests/hash_test.cc
index 2ed4dca3bd44..ea10e7b700b1 100644
--- a/third_party/nix/src/tests/hash_test.cc
+++ b/third_party/nix/src/tests/hash_test.cc
@@ -1,12 +1,16 @@
 #include "libutil/hash.hh"
 
+#include <gmock/gmock.h>
 #include <gtest/gtest.h>
 
 class HashTest : public ::testing::Test {};
 
+using testing::EndsWith;
+using testing::HasSubstr;
+
 namespace nix {
 
-TEST(HASH_TEST, SHA256) {
+TEST(HashTest, SHA256) {
   auto hash = hashString(HashType::htSHA256, "foo");
   ASSERT_EQ(hash.base64Len(), 44);
   ASSERT_EQ(hash.base32Len(), 52);
@@ -40,4 +44,40 @@ TEST(HashTest, SHA256Decode) {
   ASSERT_EQ(hash, *base64);
 }
 
+TEST(HashTest, SHA256DecodeFail) {
+  EXPECT_THAT(
+      Hash::deserialize("sha256:LCa0a2j/xo/5m0U8HTBBNBNCLXBkg7+g+YpeiGJm56==",
+                        HashType::htSHA256)
+          .status()
+          .message(),
+      HasSubstr("wrong length"));
+  EXPECT_THAT(
+      Hash::deserialize("sha256:LCa0a2j/xo/5m0U8HTBBNBNCLXBkg7+g+YpeiGJm56,=",
+                        HashType::htSHA256)
+          .status()
+          .message(),
+      HasSubstr("invalid base-64"));
+
+  EXPECT_THAT(Hash::deserialize(
+                  "sha256:1bp7cri8hplaz6hbz0v4f0nl44rl84q1sg25kgwqzipzd1mv89i",
+                  HashType::htSHA256)
+                  .status()
+                  .message(),
+              HasSubstr("wrong length"));
+  absl::StatusOr<Hash> badB32Char = Hash::deserialize(
+      "sha256:1bp7cri8hplaz6hbz0v4f0nl44rl84q1sg25kgwqzipzd1mv89i,",
+      HashType::htSHA256);
+  EXPECT_THAT(badB32Char.status().message(), HasSubstr("invalid base-32"));
+  EXPECT_THAT(badB32Char.status().message(), EndsWith(","));
+
+  EXPECT_THAT(
+      Hash::deserialize(
+          "sha256:"
+          "2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7  ",
+          HashType::htSHA256)
+          .status()
+          .message(),
+      HasSubstr("invalid base-16"));
+}
+
 }  // namespace nix