diff options
Diffstat (limited to 'third_party')
-rw-r--r-- | third_party/nix/src/libutil/CMakeLists.txt | 1 | ||||
-rw-r--r-- | third_party/nix/src/libutil/hash.cc | 120 | ||||
-rw-r--r-- | third_party/nix/src/libutil/hash.hh | 10 | ||||
-rw-r--r-- | third_party/nix/src/libutil/types.hh | 2 | ||||
-rw-r--r-- | third_party/nix/src/tests/hash_test.cc | 42 |
5 files changed, 123 insertions, 52 deletions
diff --git a/third_party/nix/src/libutil/CMakeLists.txt b/third_party/nix/src/libutil/CMakeLists.txt index 8713d7e9b86d..db504940a07e 100644 --- a/third_party/nix/src/libutil/CMakeLists.txt +++ b/third_party/nix/src/libutil/CMakeLists.txt @@ -48,6 +48,7 @@ target_sources(nixutil target_link_libraries(nixutil absl::strings + absl::statusor glog BZip2::BZip2 LibLZMA::LibLZMA diff --git a/third_party/nix/src/libutil/hash.cc b/third_party/nix/src/libutil/hash.cc index 5596ef01784f..50169b0f19c7 100644 --- a/third_party/nix/src/libutil/hash.cc +++ b/third_party/nix/src/libutil/hash.cc @@ -4,6 +4,7 @@ #include <iostream> #include <absl/strings/escaping.h> +#include <absl/strings/str_format.h> #include <fcntl.h> #include <openssl/md5.h> #include <openssl/sha.h> @@ -75,8 +76,18 @@ static std::string printHash16(const Hash& hash) { return std::string(buf, hash.hashSize * 2); } +bool Hash::IsValidBase16(absl::string_view s) { + for (char c : s) { + if ('0' <= c && c <= '9') continue; + if ('a' <= c && c <= 'f') continue; + if ('A' <= c && c <= 'F') continue; + return false; + } + return true; +} + // omitted: E O U T -const std::string base32Chars = "0123456789abcdfghijklmnpqrsvwxyz"; +constexpr char base32Chars[] = "0123456789abcdfghijklmnpqrsvwxyz"; constexpr signed char kUnBase32[] = { -1, -1, -1, -1, -1, -1, -1, -1, /* unprintables */ @@ -167,6 +178,15 @@ std::string Hash::to_string(Base base, bool includeType) const { } Hash::Hash(const std::string& s, HashType type) : type(type) { + absl::StatusOr<Hash> result = deserialize(s, type); + if (result.ok()) { + *this = *result; + } else { + throw BadHash(result.status().message()); + } +} + +absl::StatusOr<Hash> Hash::deserialize(const std::string& s, HashType type) { size_t pos = 0; bool isSRI = false; @@ -176,90 +196,88 @@ Hash::Hash(const std::string& s, HashType type) : type(type) { if (sep != std::string::npos) { isSRI = true; } else if (type == htUnknown) { - throw BadHash("hash '%s' does not include a type", s); + return absl::InvalidArgumentError( + absl::StrCat("hash string '", s, " does not include a type")); } } + HashType parsedType = type; if (sep != std::string::npos) { std::string hts = std::string(s, 0, sep); - this->type = parseHashType(hts); - if (this->type == htUnknown) { - throw BadHash("unknown hash type '%s'", hts); - } - if (type != htUnknown && type != this->type) { - throw BadHash("hash '%s' should have type '%s'", s, printHashType(type)); + parsedType = parseHashType(hts); + if (parsedType != type) { + return absl::InvalidArgumentError( + absl::StrCat("hash '", s, "' should have type '", printHashType(type), + "', found '", printHashType(parsedType), "'")); } pos = sep + 1; } - init(); + Hash dest(parsedType); size_t size = s.size() - pos; + absl::string_view sv(s.data() + pos, size); - if (!isSRI && size == base16Len()) { - auto parseHexDigit = [&](char c) { - if (c >= '0' && c <= '9') { - return c - '0'; - } - if (c >= 'A' && c <= 'F') { - return c - 'A' + 10; - } - if (c >= 'a' && c <= 'f') { - return c - 'a' + 10; - } - throw BadHash("invalid base-16 hash '%s'", s); - }; - - for (unsigned int i = 0; i < hashSize; i++) { - hash[i] = parseHexDigit(s[pos + i * 2]) << 4 | - parseHexDigit(s[pos + i * 2 + 1]); + if (!isSRI && size == dest.base16Len()) { + std::string bytes; + if (!IsValidBase16(sv)) { + return absl::InvalidArgumentError( + absl::StrCat("invalid base-16 hash: bad character in '", s, "'")); } + bytes = absl::HexStringToBytes(sv); + if (bytes.size() != dest.hashSize) { + return absl::InvalidArgumentError( + absl::StrCat("hash '", s, "' has wrong length for base16 ", + printHashType(dest.type))); + } + memcpy(dest.hash, bytes.data(), dest.hashSize); } - else if (!isSRI && size == base32Len()) { + else if (!isSRI && size == dest.base32Len()) { for (unsigned int n = 0; n < size; ++n) { - char c = s[pos + size - n - 1]; - unsigned char digit = 0; - for (digit = 0; digit < base32Chars.size(); ++digit) { /* !!! slow */ - if (base32Chars[digit] == c) { - break; - } - } - if (digit >= 32) { - throw BadHash("invalid base-32 hash '%s'", s); + char c = sv[size - n - 1]; + // range: -1, 0..31 + signed char digit = kUnBase32[static_cast<unsigned char>(c)]; + if (digit < 0) { + return absl::InvalidArgumentError( + absl::StrCat("invalid base-32 hash: bad character ", + absl::CEscape(absl::string_view(&c, 1)))); } unsigned int b = n * 5; unsigned int i = b / 8; unsigned int j = b % 8; - hash[i] |= digit << j; + dest.hash[i] |= digit << j; - if (i < hashSize - 1) { - hash[i + 1] |= digit >> (8 - j); + if (i < dest.hashSize - 1) { + dest.hash[i + 1] |= digit >> (8 - j); } else { if ((digit >> (8 - j)) != 0) { - throw BadHash("invalid base-32 hash '%s'", s); + return absl::InvalidArgumentError( + absl::StrCat("invalid base-32 hash '", s, "'")); } } } } - else if (isSRI || size == base64Len()) { - std::string d; - if (!absl::Base64Unescape(std::string(s, pos), &d)) { - // TODO(grfn): replace this with StatusOr - throw Error("Invalid Base64"); + else if (isSRI || size == dest.base64Len()) { + std::string decoded; + if (!absl::Base64Unescape(sv, &decoded)) { + return absl::InvalidArgumentError("invalid base-64 hash"); } - if (d.size() != hashSize) { - throw BadHash("invalid %s hash '%s'", isSRI ? "SRI" : "base-64", s); + if (decoded.size() != dest.hashSize) { + return absl::InvalidArgumentError( + absl::StrCat("hash '", s, "' has wrong length for base64 ", + printHashType(dest.type))); } - assert(hashSize); - memcpy(hash, d.data(), hashSize); + memcpy(dest.hash, decoded.data(), dest.hashSize); } else { - throw BadHash("hash '%s' has wrong length for hash type '%s'", s, - printHashType(type)); + return absl::InvalidArgumentError(absl::StrCat( + "hash '", s, "' has wrong length for ", printHashType(dest.type))); } + + return dest; } union Ctx { diff --git a/third_party/nix/src/libutil/hash.hh b/third_party/nix/src/libutil/hash.hh index 58f808896fe2..0b7b11edd072 100644 --- a/third_party/nix/src/libutil/hash.hh +++ b/third_party/nix/src/libutil/hash.hh @@ -1,5 +1,7 @@ #pragma once +#include <absl/status/statusor.h> + #include "libutil/serialise.hh" #include "libutil/types.hh" @@ -36,6 +38,10 @@ struct Hash { string. */ Hash(const std::string& s, HashType type = htUnknown); + /* Status-returning version of above constructor */ + static absl::StatusOr<Hash> deserialize(const std::string& s, + HashType type = htUnknown); + void init(); /* Check whether a hash is set. */ @@ -64,6 +70,10 @@ struct Hash { (e.g. "sha256:"). */ std::string to_string(Base base = Base32, bool includeType = true) const; + /* Returns whether the passed string contains entirely valid base16 + characters. */ + static bool IsValidBase16(absl::string_view s); + /* Returns whether the passed string contains entirely valid base32 characters. */ static bool IsValidBase32(absl::string_view s); diff --git a/third_party/nix/src/libutil/types.hh b/third_party/nix/src/libutil/types.hh index e2ea86fdcf5f..3d37e4efee37 100644 --- a/third_party/nix/src/libutil/types.hh +++ b/third_party/nix/src/libutil/types.hh @@ -44,6 +44,8 @@ struct FormatOrString { inline std::string fmt(const std::string& s) { return s; } +inline std::string fmt(std::string_view s) { return std::string(s); } + inline std::string fmt(const char* s) { return s; } inline std::string fmt(const FormatOrString& fs) { return fs.s; } diff --git a/third_party/nix/src/tests/hash_test.cc b/third_party/nix/src/tests/hash_test.cc index 2ed4dca3bd44..ea10e7b700b1 100644 --- a/third_party/nix/src/tests/hash_test.cc +++ b/third_party/nix/src/tests/hash_test.cc @@ -1,12 +1,16 @@ #include "libutil/hash.hh" +#include <gmock/gmock.h> #include <gtest/gtest.h> class HashTest : public ::testing::Test {}; +using testing::EndsWith; +using testing::HasSubstr; + namespace nix { -TEST(HASH_TEST, SHA256) { +TEST(HashTest, SHA256) { auto hash = hashString(HashType::htSHA256, "foo"); ASSERT_EQ(hash.base64Len(), 44); ASSERT_EQ(hash.base32Len(), 52); @@ -40,4 +44,40 @@ TEST(HashTest, SHA256Decode) { ASSERT_EQ(hash, *base64); } +TEST(HashTest, SHA256DecodeFail) { + EXPECT_THAT( + Hash::deserialize("sha256:LCa0a2j/xo/5m0U8HTBBNBNCLXBkg7+g+YpeiGJm56==", + HashType::htSHA256) + .status() + .message(), + HasSubstr("wrong length")); + EXPECT_THAT( + Hash::deserialize("sha256:LCa0a2j/xo/5m0U8HTBBNBNCLXBkg7+g+YpeiGJm56,=", + HashType::htSHA256) + .status() + .message(), + HasSubstr("invalid base-64")); + + EXPECT_THAT(Hash::deserialize( + "sha256:1bp7cri8hplaz6hbz0v4f0nl44rl84q1sg25kgwqzipzd1mv89i", + HashType::htSHA256) + .status() + .message(), + HasSubstr("wrong length")); + absl::StatusOr<Hash> badB32Char = Hash::deserialize( + "sha256:1bp7cri8hplaz6hbz0v4f0nl44rl84q1sg25kgwqzipzd1mv89i,", + HashType::htSHA256); + EXPECT_THAT(badB32Char.status().message(), HasSubstr("invalid base-32")); + EXPECT_THAT(badB32Char.status().message(), EndsWith(",")); + + EXPECT_THAT( + Hash::deserialize( + "sha256:" + "2c26b46b68ffc68ff99b453c1d30413413422d706483bfa0f98a5e886266e7 ", + HashType::htSHA256) + .status() + .message(), + HasSubstr("invalid base-16")); +} + } // namespace nix |