diff options
Diffstat (limited to 'third_party/nix/src/libutil/hash.cc')
-rw-r--r-- | third_party/nix/src/libutil/hash.cc | 120 |
1 files changed, 69 insertions, 51 deletions
diff --git a/third_party/nix/src/libutil/hash.cc b/third_party/nix/src/libutil/hash.cc index 5596ef01784f..50169b0f19c7 100644 --- a/third_party/nix/src/libutil/hash.cc +++ b/third_party/nix/src/libutil/hash.cc @@ -4,6 +4,7 @@ #include <iostream> #include <absl/strings/escaping.h> +#include <absl/strings/str_format.h> #include <fcntl.h> #include <openssl/md5.h> #include <openssl/sha.h> @@ -75,8 +76,18 @@ static std::string printHash16(const Hash& hash) { return std::string(buf, hash.hashSize * 2); } +bool Hash::IsValidBase16(absl::string_view s) { + for (char c : s) { + if ('0' <= c && c <= '9') continue; + if ('a' <= c && c <= 'f') continue; + if ('A' <= c && c <= 'F') continue; + return false; + } + return true; +} + // omitted: E O U T -const std::string base32Chars = "0123456789abcdfghijklmnpqrsvwxyz"; +constexpr char base32Chars[] = "0123456789abcdfghijklmnpqrsvwxyz"; constexpr signed char kUnBase32[] = { -1, -1, -1, -1, -1, -1, -1, -1, /* unprintables */ @@ -167,6 +178,15 @@ std::string Hash::to_string(Base base, bool includeType) const { } Hash::Hash(const std::string& s, HashType type) : type(type) { + absl::StatusOr<Hash> result = deserialize(s, type); + if (result.ok()) { + *this = *result; + } else { + throw BadHash(result.status().message()); + } +} + +absl::StatusOr<Hash> Hash::deserialize(const std::string& s, HashType type) { size_t pos = 0; bool isSRI = false; @@ -176,90 +196,88 @@ Hash::Hash(const std::string& s, HashType type) : type(type) { if (sep != std::string::npos) { isSRI = true; } else if (type == htUnknown) { - throw BadHash("hash '%s' does not include a type", s); + return absl::InvalidArgumentError( + absl::StrCat("hash string '", s, " does not include a type")); } } + HashType parsedType = type; if (sep != std::string::npos) { std::string hts = std::string(s, 0, sep); - this->type = parseHashType(hts); - if (this->type == htUnknown) { - throw BadHash("unknown hash type '%s'", hts); - } - if (type != htUnknown && type != this->type) { - throw BadHash("hash '%s' should have type '%s'", s, printHashType(type)); + parsedType = parseHashType(hts); + if (parsedType != type) { + return absl::InvalidArgumentError( + absl::StrCat("hash '", s, "' should have type '", printHashType(type), + "', found '", printHashType(parsedType), "'")); } pos = sep + 1; } - init(); + Hash dest(parsedType); size_t size = s.size() - pos; + absl::string_view sv(s.data() + pos, size); - if (!isSRI && size == base16Len()) { - auto parseHexDigit = [&](char c) { - if (c >= '0' && c <= '9') { - return c - '0'; - } - if (c >= 'A' && c <= 'F') { - return c - 'A' + 10; - } - if (c >= 'a' && c <= 'f') { - return c - 'a' + 10; - } - throw BadHash("invalid base-16 hash '%s'", s); - }; - - for (unsigned int i = 0; i < hashSize; i++) { - hash[i] = parseHexDigit(s[pos + i * 2]) << 4 | - parseHexDigit(s[pos + i * 2 + 1]); + if (!isSRI && size == dest.base16Len()) { + std::string bytes; + if (!IsValidBase16(sv)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid base-16 hash: bad character in '", s, "'")); } + bytes = absl::HexStringToBytes(sv); + if (bytes.size() != dest.hashSize) { + return absl::InvalidArgumentError( + absl::StrCat("hash '", s, "' has wrong length for base16 ", + printHashType(dest.type))); + } + memcpy(dest.hash, bytes.data(), dest.hashSize); } - else if (!isSRI && size == base32Len()) { + else if (!isSRI && size == dest.base32Len()) { for (unsigned int n = 0; n < size; ++n) { - char c = s[pos + size - n - 1]; - unsigned char digit = 0; - for (digit = 0; digit < base32Chars.size(); ++digit) { /* !!! slow */ - if (base32Chars[digit] == c) { - break; - } - } - if (digit >= 32) { - throw BadHash("invalid base-32 hash '%s'", s); + char c = sv[size - n - 1]; + // range: -1, 0..31 + signed char digit = kUnBase32[static_cast<unsigned char>(c)]; + if (digit < 0) { + return absl::InvalidArgumentError( + absl::StrCat("invalid base-32 hash: bad character ", + absl::CEscape(absl::string_view(&c, 1)))); } unsigned int b = n * 5; unsigned int i = b / 8; unsigned int j = b % 8; - hash[i] |= digit << j; + dest.hash[i] |= digit << j; - if (i < hashSize - 1) { - hash[i + 1] |= digit >> (8 - j); + if (i < dest.hashSize - 1) { + dest.hash[i + 1] |= digit >> (8 - j); } else { if ((digit >> (8 - j)) != 0) { - throw BadHash("invalid base-32 hash '%s'", s); + return absl::InvalidArgumentError( + absl::StrCat("invalid base-32 hash '", s, "'")); } } } } - else if (isSRI || size == base64Len()) { - std::string d; - if (!absl::Base64Unescape(std::string(s, pos), &d)) { - // TODO(grfn): replace this with StatusOr - throw Error("Invalid Base64"); + else if (isSRI || size == dest.base64Len()) { + std::string decoded; + if (!absl::Base64Unescape(sv, &decoded)) { + return absl::InvalidArgumentError("invalid base-64 hash"); } - if (d.size() != hashSize) { - throw BadHash("invalid %s hash '%s'", isSRI ? "SRI" : "base-64", s); + if (decoded.size() != dest.hashSize) { + return absl::InvalidArgumentError( + absl::StrCat("hash '", s, "' has wrong length for base64 ", + printHashType(dest.type))); } - assert(hashSize); - memcpy(hash, d.data(), hashSize); + memcpy(dest.hash, decoded.data(), dest.hashSize); } else { - throw BadHash("hash '%s' has wrong length for hash type '%s'", s, - printHashType(type)); + return absl::InvalidArgumentError(absl::StrCat( + "hash '", s, "' has wrong length for ", printHashType(dest.type))); } + + return dest; } union Ctx { |