about summary refs log tree commit diff
path: root/third_party/abseil_cpp/absl/strings/internal/str_format/float_conversion.cc
diff options
context:
space:
mode:
authorVincent Ambo <mail@tazj.in>2020-11-21T13·43+0100
committerVincent Ambo <mail@tazj.in>2020-11-21T14·48+0100
commit082c006c04343a78d87b6c6ab3608c25d6213c3f (patch)
tree16e6f04f8d1d1d2d67e8e917d5e7bb48c1b60375 /third_party/abseil_cpp/absl/strings/internal/str_format/float_conversion.cc
parentcc27324d0226953943f408ce3c69ad7d648e005e (diff)
merge(3p/absl): subtree merge of Abseil up to e19260f r/1889
... notably, this includes Abseil's own StatusOr type, which
conflicted with our implementation (that was taken from TensorFlow).

Change-Id: Ie7d6764b64055caaeb8dc7b6b9d066291e6b538f
Diffstat (limited to 'third_party/abseil_cpp/absl/strings/internal/str_format/float_conversion.cc')
-rw-r--r--third_party/abseil_cpp/absl/strings/internal/str_format/float_conversion.cc303
1 files changed, 289 insertions, 14 deletions
diff --git a/third_party/abseil_cpp/absl/strings/internal/str_format/float_conversion.cc b/third_party/abseil_cpp/absl/strings/internal/str_format/float_conversion.cc
index 10e4695411..0ded0a66af 100644
--- a/third_party/abseil_cpp/absl/strings/internal/str_format/float_conversion.cc
+++ b/third_party/abseil_cpp/absl/strings/internal/str_format/float_conversion.cc
@@ -1,3 +1,17 @@
+// Copyright 2020 The Abseil Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//      https://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
 #include "absl/strings/internal/str_format/float_conversion.h"
 
 #include <string.h>
@@ -15,6 +29,7 @@
 #include "absl/functional/function_ref.h"
 #include "absl/meta/type_traits.h"
 #include "absl/numeric/int128.h"
+#include "absl/strings/numbers.h"
 #include "absl/types/optional.h"
 #include "absl/types/span.h"
 
@@ -119,7 +134,7 @@ class BinaryToDecimal {
     assert(exp > 0);
     assert(exp <= std::numeric_limits<long double>::max_exponent);
     static_assert(
-        StackArray::kMaxCapacity >=
+        static_cast<int>(StackArray::kMaxCapacity) >=
             ChunksNeeded(std::numeric_limits<long double>::max_exponent),
         "");
 
@@ -204,7 +219,7 @@ class BinaryToDecimal {
   }
 
  private:
-  static constexpr size_t kDigitsPerChunk = 9;
+  static constexpr int kDigitsPerChunk = 9;
 
   int decimal_start_;
   int decimal_end_;
@@ -441,8 +456,10 @@ struct Padding {
 };
 
 Padding ExtraWidthToPadding(size_t total_size, const FormatState &state) {
-  if (state.conv.width() < 0 || state.conv.width() <= total_size)
+  if (state.conv.width() < 0 ||
+      static_cast<size_t>(state.conv.width()) <= total_size) {
     return {0, 0, 0};
+  }
   int missing_chars = state.conv.width() - total_size;
   if (state.conv.has_left_flag()) {
     return {0, 0, missing_chars};
@@ -453,26 +470,31 @@ Padding ExtraWidthToPadding(size_t total_size, const FormatState &state) {
   }
 }
 
-void FinalPrint(absl::string_view data, int trailing_zeros,
-                const FormatState &state) {
+void FinalPrint(const FormatState &state, absl::string_view data,
+                int padding_offset, int trailing_zeros,
+                absl::string_view data_postfix) {
   if (state.conv.width() < 0) {
     // No width specified. Fast-path.
     if (state.sign_char != '\0') state.sink->Append(1, state.sign_char);
     state.sink->Append(data);
     state.sink->Append(trailing_zeros, '0');
+    state.sink->Append(data_postfix);
     return;
   }
 
-  auto padding =
-      ExtraWidthToPadding((state.sign_char != '\0' ? 1 : 0) + data.size() +
-                              static_cast<size_t>(trailing_zeros),
-                          state);
+  auto padding = ExtraWidthToPadding((state.sign_char != '\0' ? 1 : 0) +
+                                         data.size() + data_postfix.size() +
+                                         static_cast<size_t>(trailing_zeros),
+                                     state);
 
   state.sink->Append(padding.left_spaces, ' ');
   if (state.sign_char != '\0') state.sink->Append(1, state.sign_char);
+  // Padding in general needs to be inserted somewhere in the middle of `data`.
+  state.sink->Append(data.substr(0, padding_offset));
   state.sink->Append(padding.zeros, '0');
-  state.sink->Append(data);
+  state.sink->Append(data.substr(padding_offset));
   state.sink->Append(trailing_zeros, '0');
+  state.sink->Append(data_postfix);
   state.sink->Append(padding.right_spaces, ' ');
 }
 
@@ -525,10 +547,11 @@ void FormatFFast(Int v, int exp, const FormatState &state) {
   // In `alt` mode (flag #) we keep the `.` even if there are no fractional
   // digits. In non-alt mode, we strip it.
   if (!state.ShouldPrintDot()) --size;
-  FinalPrint(absl::string_view(integral_digits_start, size),
+  FinalPrint(state, absl::string_view(integral_digits_start, size),
+             /*padding_offset=*/0,
              static_cast<int>(state.precision - (fractional_digits_end -
                                                  fractional_digits_start)),
-             state);
+             /*data_postfix=*/"");
 }
 
 // Slow %f formatter for when the shifted value does not fit in a uint128, and
@@ -655,6 +678,255 @@ void FormatF(Int mantissa, int exp, const FormatState &state) {
   return FormatFFast(mantissa, exp, state);
 }
 
+// Grab the group of four bits (nibble) from `n`. E.g., nibble 1 corresponds to
+// bits 4-7.
+template <typename Int>
+uint8_t GetNibble(Int n, int nibble_index) {
+  constexpr Int mask_low_nibble = Int{0xf};
+  int shift = nibble_index * 4;
+  n &= mask_low_nibble << shift;
+  return static_cast<uint8_t>((n >> shift) & 0xf);
+}
+
+// Add one to the given nibble, applying carry to higher nibbles. Returns true
+// if overflow, false otherwise.
+template <typename Int>
+bool IncrementNibble(int nibble_index, Int *n) {
+  constexpr int kShift = sizeof(Int) * 8 - 1;
+  constexpr int kNumNibbles = sizeof(Int) * 8 / 4;
+  Int before = *n >> kShift;
+  // Here we essentially want to take the number 1 and move it into the requsted
+  // nibble, then add it to *n to effectively increment the nibble. However,
+  // ASan will complain if we try to shift the 1 beyond the limits of the Int,
+  // i.e., if the nibble_index is out of range. So therefore we check for this
+  // and if we are out of range we just add 0 which leaves *n unchanged, which
+  // seems like the reasonable thing to do in that case.
+  *n += ((nibble_index >= kNumNibbles) ? 0 : (Int{1} << (nibble_index * 4)));
+  Int after = *n >> kShift;
+  return (before && !after) || (nibble_index >= kNumNibbles);
+}
+
+// Return a mask with 1's in the given nibble and all lower nibbles.
+template <typename Int>
+Int MaskUpToNibbleInclusive(int nibble_index) {
+  constexpr int kNumNibbles = sizeof(Int) * 8 / 4;
+  static const Int ones = ~Int{0};
+  return ones >> std::max(0, 4 * (kNumNibbles - nibble_index - 1));
+}
+
+// Return a mask with 1's below the given nibble.
+template <typename Int>
+Int MaskUpToNibbleExclusive(int nibble_index) {
+  return nibble_index <= 0 ? 0 : MaskUpToNibbleInclusive<Int>(nibble_index - 1);
+}
+
+template <typename Int>
+Int MoveToNibble(uint8_t nibble, int nibble_index) {
+  return Int{nibble} << (4 * nibble_index);
+}
+
+// Given mantissa size, find optimal # of mantissa bits to put in initial digit.
+//
+// In the hex representation we keep a single hex digit to the left of the dot.
+// However, the question as to how many bits of the mantissa should be put into
+// that hex digit in theory is arbitrary, but in practice it is optimal to
+// choose based on the size of the mantissa. E.g., for a `double`, there are 53
+// mantissa bits, so that means that we should put 1 bit to the left of the dot,
+// thereby leaving 52 bits to the right, which is evenly divisible by four and
+// thus all fractional digits represent actual precision. For a `long double`,
+// on the other hand, there are 64 bits of mantissa, thus we can use all four
+// bits for the initial hex digit and still have a number left over (60) that is
+// a multiple of four. Once again, the goal is to have all fractional digits
+// represent real precision.
+template <typename Float>
+constexpr int HexFloatLeadingDigitSizeInBits() {
+  return std::numeric_limits<Float>::digits % 4 > 0
+             ? std::numeric_limits<Float>::digits % 4
+             : 4;
+}
+
+// This function captures the rounding behavior of glibc for hex float
+// representations. E.g. when rounding 0x1.ab800000 to a precision of .2
+// ("%.2a") glibc will round up because it rounds toward the even number (since
+// 0xb is an odd number, it will round up to 0xc). However, when rounding at a
+// point that is not followed by 800000..., it disregards the parity and rounds
+// up if > 8 and rounds down if < 8.
+template <typename Int>
+bool HexFloatNeedsRoundUp(Int mantissa, int final_nibble_displayed,
+                          uint8_t leading) {
+  // If the last nibble (hex digit) to be displayed is the lowest on in the
+  // mantissa then that means that we don't have any further nibbles to inform
+  // rounding, so don't round.
+  if (final_nibble_displayed <= 0) {
+    return false;
+  }
+  int rounding_nibble_idx = final_nibble_displayed - 1;
+  constexpr int kTotalNibbles = sizeof(Int) * 8 / 4;
+  assert(final_nibble_displayed <= kTotalNibbles);
+  Int mantissa_up_to_rounding_nibble_inclusive =
+      mantissa & MaskUpToNibbleInclusive<Int>(rounding_nibble_idx);
+  Int eight = MoveToNibble<Int>(8, rounding_nibble_idx);
+  if (mantissa_up_to_rounding_nibble_inclusive != eight) {
+    return mantissa_up_to_rounding_nibble_inclusive > eight;
+  }
+  // Nibble in question == 8.
+  uint8_t round_if_odd = (final_nibble_displayed == kTotalNibbles)
+                             ? leading
+                             : GetNibble(mantissa, final_nibble_displayed);
+  return round_if_odd % 2 == 1;
+}
+
+// Stores values associated with a Float type needed by the FormatA
+// implementation in order to avoid templatizing that function by the Float
+// type.
+struct HexFloatTypeParams {
+  template <typename Float>
+  explicit HexFloatTypeParams(Float)
+      : min_exponent(std::numeric_limits<Float>::min_exponent - 1),
+        leading_digit_size_bits(HexFloatLeadingDigitSizeInBits<Float>()) {
+    assert(leading_digit_size_bits >= 1 && leading_digit_size_bits <= 4);
+  }
+
+  int min_exponent;
+  int leading_digit_size_bits;
+};
+
+// Hex Float Rounding. First check if we need to round; if so, then we do that
+// by manipulating (incrementing) the mantissa, that way we can later print the
+// mantissa digits by iterating through them in the same way regardless of
+// whether a rounding happened.
+template <typename Int>
+void FormatARound(bool precision_specified, const FormatState &state,
+                  uint8_t *leading, Int *mantissa, int *exp) {
+  constexpr int kTotalNibbles = sizeof(Int) * 8 / 4;
+  // Index of the last nibble that we could display given precision.
+  int final_nibble_displayed =
+      precision_specified ? std::max(0, (kTotalNibbles - state.precision)) : 0;
+  if (HexFloatNeedsRoundUp(*mantissa, final_nibble_displayed, *leading)) {
+    // Need to round up.
+    bool overflow = IncrementNibble(final_nibble_displayed, mantissa);
+    *leading += (overflow ? 1 : 0);
+    if (ABSL_PREDICT_FALSE(*leading > 15)) {
+      // We have overflowed the leading digit. This would mean that we would
+      // need two hex digits to the left of the dot, which is not allowed. So
+      // adjust the mantissa and exponent so that the result is always 1.0eXXX.
+      *leading = 1;
+      *mantissa = 0;
+      *exp += 4;
+    }
+  }
+  // Now that we have handled a possible round-up we can go ahead and zero out
+  // all the nibbles of the mantissa that we won't need.
+  if (precision_specified) {
+    *mantissa &= ~MaskUpToNibbleExclusive<Int>(final_nibble_displayed);
+  }
+}
+
+template <typename Int>
+void FormatANormalize(const HexFloatTypeParams float_traits, uint8_t *leading,
+                      Int *mantissa, int *exp) {
+  constexpr int kIntBits = sizeof(Int) * 8;
+  static const Int kHighIntBit = Int{1} << (kIntBits - 1);
+  const int kLeadDigitBitsCount = float_traits.leading_digit_size_bits;
+  // Normalize mantissa so that highest bit set is in MSB position, unless we
+  // get interrupted by the exponent threshold.
+  while (*mantissa && !(*mantissa & kHighIntBit)) {
+    if (ABSL_PREDICT_FALSE(*exp - 1 < float_traits.min_exponent)) {
+      *mantissa >>= (float_traits.min_exponent - *exp);
+      *exp = float_traits.min_exponent;
+      return;
+    }
+    *mantissa <<= 1;
+    --*exp;
+  }
+  // Extract bits for leading digit then shift them away leaving the
+  // fractional part.
+  *leading =
+      static_cast<uint8_t>(*mantissa >> (kIntBits - kLeadDigitBitsCount));
+  *exp -= (*mantissa != 0) ? kLeadDigitBitsCount : *exp;
+  *mantissa <<= kLeadDigitBitsCount;
+}
+
+template <typename Int>
+void FormatA(const HexFloatTypeParams float_traits, Int mantissa, int exp,
+             bool uppercase, const FormatState &state) {
+  // Int properties.
+  constexpr int kIntBits = sizeof(Int) * 8;
+  constexpr int kTotalNibbles = sizeof(Int) * 8 / 4;
+  // Did the user specify a precision explicitly?
+  const bool precision_specified = state.conv.precision() >= 0;
+
+  // ========== Normalize/Denormalize ==========
+  exp += kIntBits;  // make all digits fractional digits.
+  // This holds the (up to four) bits of leading digit, i.e., the '1' in the
+  // number 0x1.e6fp+2. It's always > 0 unless number is zero or denormal.
+  uint8_t leading = 0;
+  FormatANormalize(float_traits, &leading, &mantissa, &exp);
+
+  // =============== Rounding ==================
+  // Check if we need to round; if so, then we do that by manipulating
+  // (incrementing) the mantissa before beginning to print characters.
+  FormatARound(precision_specified, state, &leading, &mantissa, &exp);
+
+  // ============= Format Result ===============
+  // This buffer holds the "0x1.ab1de3" portion of "0x1.ab1de3pe+2". Compute the
+  // size with long double which is the largest of the floats.
+  constexpr size_t kBufSizeForHexFloatRepr =
+      2                                               // 0x
+      + std::numeric_limits<long double>::digits / 4  // number of hex digits
+      + 1                                             // round up
+      + 1;                                            // "." (dot)
+  char digits_buffer[kBufSizeForHexFloatRepr];
+  char *digits_iter = digits_buffer;
+  const char *const digits =
+      static_cast<const char *>("0123456789ABCDEF0123456789abcdef") +
+      (uppercase ? 0 : 16);
+
+  // =============== Hex Prefix ================
+  *digits_iter++ = '0';
+  *digits_iter++ = uppercase ? 'X' : 'x';
+
+  // ========== Non-Fractional Digit ===========
+  *digits_iter++ = digits[leading];
+
+  // ================== Dot ====================
+  // There are three reasons we might need a dot. Keep in mind that, at this
+  // point, the mantissa holds only the fractional part.
+  if ((precision_specified && state.precision > 0) ||
+      (!precision_specified && mantissa > 0) || state.conv.has_alt_flag()) {
+    *digits_iter++ = '.';
+  }
+
+  // ============ Fractional Digits ============
+  int digits_emitted = 0;
+  while (mantissa > 0) {
+    *digits_iter++ = digits[GetNibble(mantissa, kTotalNibbles - 1)];
+    mantissa <<= 4;
+    ++digits_emitted;
+  }
+  int trailing_zeros =
+      precision_specified ? state.precision - digits_emitted : 0;
+  assert(trailing_zeros >= 0);
+  auto digits_result = string_view(digits_buffer, digits_iter - digits_buffer);
+
+  // =============== Exponent ==================
+  constexpr size_t kBufSizeForExpDecRepr =
+      numbers_internal::kFastToBufferSize  // requred for FastIntToBuffer
+      + 1                                  // 'p' or 'P'
+      + 1;                                 // '+' or '-'
+  char exp_buffer[kBufSizeForExpDecRepr];
+  exp_buffer[0] = uppercase ? 'P' : 'p';
+  exp_buffer[1] = exp >= 0 ? '+' : '-';
+  numbers_internal::FastIntToBuffer(exp < 0 ? -exp : exp, exp_buffer + 2);
+
+  // ============ Assemble Result ==============
+  FinalPrint(state,           //
+             digits_result,   // 0xN.NNN...
+             2,               // offset in `data` to start padding if needed.
+             trailing_zeros,  // num remaining mantissa padding zeros
+             exp_buffer);     // exponent
+}
+
 char *CopyStringTo(absl::string_view v, char *out) {
   std::memcpy(out, v.data(), v.size());
   return out + v.size();
@@ -1103,7 +1375,10 @@ bool FloatToSink(const Float v, const FormatConversionSpecImpl &conv,
     }
   } else if (c == FormatConversionCharInternal::a ||
              c == FormatConversionCharInternal::A) {
-    return FallbackToSnprintf(v, conv, sink);
+    bool uppercase = (c == FormatConversionCharInternal::A);
+    FormatA(HexFloatTypeParams(Float{}), decomposed.mantissa,
+            decomposed.exponent, uppercase, {sign_char, precision, conv, sink});
+    return true;
   } else {
     return false;
   }
@@ -1131,7 +1406,7 @@ bool ConvertFloatImpl(long double v, const FormatConversionSpecImpl &conv,
 
 bool ConvertFloatImpl(float v, const FormatConversionSpecImpl &conv,
                       FormatSinkImpl *sink) {
-  return FloatToSink(v, conv, sink);
+  return FloatToSink(static_cast<double>(v), conv, sink);
 }
 
 bool ConvertFloatImpl(double v, const FormatConversionSpecImpl &conv,