diff options
Diffstat (limited to 'src')
44 files changed, 4317 insertions, 613 deletions
diff --git a/src/cpptoml/LICENSE b/src/cpptoml/LICENSE new file mode 100644 index 000000000000..8802c4fa5a36 --- /dev/null +++ b/src/cpptoml/LICENSE @@ -0,0 +1,18 @@ +Copyright (c) 2014 Chase Geigle + +Permission is hereby granted, free of charge, to any person obtaining a copy of +this software and associated documentation files (the "Software"), to deal in +the Software without restriction, including without limitation the rights to +use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of +the Software, and to permit persons to whom the Software is furnished to do so, +subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS +FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR +COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER +IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN +CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
diff --git a/src/cpptoml/cpptoml.h b/src/cpptoml/cpptoml.h new file mode 100644 index 000000000000..07de010b1520 --- /dev/null +++ b/src/cpptoml/cpptoml.h @@ -0,0 +1,3457 @@ +/** + * @file cpptoml.h + * @author Chase Geigle + * @date May 2013 + */ + +#ifndef _CPPTOML_H_ +#define _CPPTOML_H_ + +#include <algorithm> +#include <cassert> +#include <clocale> +#include <cstdint> +#include <cstring> +#include <fstream> +#include <iomanip> +#include <map> +#include <memory> +#include <sstream> +#include <stdexcept> +#include <string> +#include <unordered_map> +#include <vector> + +#if __cplusplus > 201103L +#define CPPTOML_DEPRECATED(reason) [[deprecated(reason)]] +#elif defined(__clang__) +#define CPPTOML_DEPRECATED(reason) __attribute__((deprecated(reason))) +#elif defined(__GNUG__) +#define CPPTOML_DEPRECATED(reason) __attribute__((deprecated)) +#elif defined(_MSC_VER) +#if _MSC_VER < 1910 +#define CPPTOML_DEPRECATED(reason) __declspec(deprecated) +#else +#define CPPTOML_DEPRECATED(reason) [[deprecated(reason)]] +#endif +#endif + +namespace cpptoml +{ +class writer; // forward declaration +class base; // forward declaration +#if defined(CPPTOML_USE_MAP) +// a std::map will ensure that entries a sorted, albeit at a slight +// performance penalty relative to the (default) unordered_map +using string_to_base_map = std::map<std::string, std::shared_ptr<base>>; +#else +// by default an unordered_map is used for best performance as the +// toml specification does not require entries to be sorted +using string_to_base_map + = std::unordered_map<std::string, std::shared_ptr<base>>; +#endif + +// if defined, `base` will retain type information in form of an enum class +// such that static_cast can be used instead of dynamic_cast +// #define CPPTOML_NO_RTTI + +template <class T> +class option +{ + public: + option() : empty_{true} + { + // nothing + } + + option(T value) : empty_{false}, value_(std::move(value)) + { + // nothing + } + + explicit operator bool() const + { + return 
!empty_; + } + + const T& operator*() const + { + return value_; + } + + const T* operator->() const + { + return &value_; + } + + const T& value_or(const T& alternative) const + { + if (!empty_) + return value_; + return alternative; + } + + private: + bool empty_; + T value_; +}; + +struct local_date +{ + int year = 0; + int month = 0; + int day = 0; +}; + +struct local_time +{ + int hour = 0; + int minute = 0; + int second = 0; + int microsecond = 0; +}; + +struct zone_offset +{ + int hour_offset = 0; + int minute_offset = 0; +}; + +struct local_datetime : local_date, local_time +{ +}; + +struct offset_datetime : local_datetime, zone_offset +{ + static inline struct offset_datetime from_zoned(const struct tm& t) + { + offset_datetime dt; + dt.year = t.tm_year + 1900; + dt.month = t.tm_mon + 1; + dt.day = t.tm_mday; + dt.hour = t.tm_hour; + dt.minute = t.tm_min; + dt.second = t.tm_sec; + + char buf[16]; + strftime(buf, 16, "%z", &t); + + int offset = std::stoi(buf); + dt.hour_offset = offset / 100; + dt.minute_offset = offset % 100; + return dt; + } + + CPPTOML_DEPRECATED("from_local has been renamed to from_zoned") + static inline struct offset_datetime from_local(const struct tm& t) + { + return from_zoned(t); + } + + static inline struct offset_datetime from_utc(const struct tm& t) + { + offset_datetime dt; + dt.year = t.tm_year + 1900; + dt.month = t.tm_mon + 1; + dt.day = t.tm_mday; + dt.hour = t.tm_hour; + dt.minute = t.tm_min; + dt.second = t.tm_sec; + return dt; + } +}; + +CPPTOML_DEPRECATED("datetime has been renamed to offset_datetime") +typedef offset_datetime datetime; + +class fill_guard +{ + public: + fill_guard(std::ostream& os) : os_(os), fill_{os.fill()} + { + // nothing + } + + ~fill_guard() + { + os_.fill(fill_); + } + + private: + std::ostream& os_; + std::ostream::char_type fill_; +}; + +inline std::ostream& operator<<(std::ostream& os, const local_date& dt) +{ + fill_guard g{os}; + os.fill('0'); + + using std::setw; + os << setw(4) << 
dt.year << "-" << setw(2) << dt.month << "-" << setw(2) + << dt.day; + + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const local_time& ltime) +{ + fill_guard g{os}; + os.fill('0'); + + using std::setw; + os << setw(2) << ltime.hour << ":" << setw(2) << ltime.minute << ":" + << setw(2) << ltime.second; + + if (ltime.microsecond > 0) + { + os << "."; + int power = 100000; + for (int curr_us = ltime.microsecond; curr_us; power /= 10) + { + auto num = curr_us / power; + os << num; + curr_us -= num * power; + } + } + + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const zone_offset& zo) +{ + fill_guard g{os}; + os.fill('0'); + + using std::setw; + + if (zo.hour_offset != 0 || zo.minute_offset != 0) + { + if (zo.hour_offset > 0) + { + os << "+"; + } + else + { + os << "-"; + } + os << setw(2) << std::abs(zo.hour_offset) << ":" << setw(2) + << std::abs(zo.minute_offset); + } + else + { + os << "Z"; + } + + return os; +} + +inline std::ostream& operator<<(std::ostream& os, const local_datetime& dt) +{ + return os << static_cast<const local_date&>(dt) << "T" + << static_cast<const local_time&>(dt); +} + +inline std::ostream& operator<<(std::ostream& os, const offset_datetime& dt) +{ + return os << static_cast<const local_datetime&>(dt) + << static_cast<const zone_offset&>(dt); +} + +template <class T, class... Ts> +struct is_one_of; + +template <class T, class V> +struct is_one_of<T, V> : std::is_same<T, V> +{ +}; + +template <class T, class V, class... 
Ts> +struct is_one_of<T, V, Ts...> +{ + const static bool value + = std::is_same<T, V>::value || is_one_of<T, Ts...>::value; +}; + +template <class T> +class value; + +template <class T> +struct valid_value + : is_one_of<T, std::string, int64_t, double, bool, local_date, local_time, + local_datetime, offset_datetime> +{ +}; + +template <class T, class Enable = void> +struct value_traits; + +template <class T> +struct valid_value_or_string_convertible +{ + + const static bool value = valid_value<typename std::decay<T>::type>::value + || std::is_convertible<T, std::string>::value; +}; + +template <class T> +struct value_traits<T, typename std:: + enable_if<valid_value_or_string_convertible<T>:: + value>::type> +{ + using value_type = typename std:: + conditional<valid_value<typename std::decay<T>::type>::value, + typename std::decay<T>::type, std::string>::type; + + using type = value<value_type>; + + static value_type construct(T&& val) + { + return value_type(val); + } +}; + +template <class T> +struct value_traits<T, + typename std:: + enable_if<!valid_value_or_string_convertible<T>::value + && std::is_floating_point< + typename std::decay<T>::type>::value>:: + type> +{ + using value_type = typename std::decay<T>::type; + + using type = value<double>; + + static value_type construct(T&& val) + { + return value_type(val); + } +}; + +template <class T> +struct value_traits<T, + typename std:: + enable_if<!valid_value_or_string_convertible<T>::value + && std::is_signed<typename std::decay<T>:: + type>::value>::type> +{ + using value_type = int64_t; + + using type = value<int64_t>; + + static value_type construct(T&& val) + { + if (val < (std::numeric_limits<int64_t>::min)()) + throw std::underflow_error{"constructed value cannot be " + "represented by a 64-bit signed " + "integer"}; + + if (val > (std::numeric_limits<int64_t>::max)()) + throw std::overflow_error{"constructed value cannot be represented " + "by a 64-bit signed integer"}; + + return 
static_cast<int64_t>(val); + } +}; + +template <class T> +struct value_traits<T, + typename std:: + enable_if<!valid_value_or_string_convertible<T>::value + && std::is_unsigned<typename std::decay<T>:: + type>::value>::type> +{ + using value_type = int64_t; + + using type = value<int64_t>; + + static value_type construct(T&& val) + { + if (val > static_cast<uint64_t>((std::numeric_limits<int64_t>::max)())) + throw std::overflow_error{"constructed value cannot be represented " + "by a 64-bit signed integer"}; + + return static_cast<int64_t>(val); + } +}; + +class array; +class table; +class table_array; + +template <class T> +struct array_of_trait +{ + using return_type = option<std::vector<T>>; +}; + +template <> +struct array_of_trait<array> +{ + using return_type = option<std::vector<std::shared_ptr<array>>>; +}; + +template <class T> +inline std::shared_ptr<typename value_traits<T>::type> make_value(T&& val); +inline std::shared_ptr<array> make_array(); +template <class T> +inline std::shared_ptr<T> make_element(); +inline std::shared_ptr<table> make_table(); +inline std::shared_ptr<table_array> make_table_array(); + +#if defined(CPPTOML_NO_RTTI) +/// Base type used to store underlying data type explicitly if RTTI is disabled +enum class base_type +{ + NONE, + STRING, + LOCAL_TIME, + LOCAL_DATE, + LOCAL_DATETIME, + OFFSET_DATETIME, + INT, + FLOAT, + BOOL, + TABLE, + ARRAY, + TABLE_ARRAY +}; + +/// Type traits class to convert C++ types to enum member +template <class T> +struct base_type_traits; + +template <> +struct base_type_traits<std::string> +{ + static const base_type type = base_type::STRING; +}; + +template <> +struct base_type_traits<local_time> +{ + static const base_type type = base_type::LOCAL_TIME; +}; + +template <> +struct base_type_traits<local_date> +{ + static const base_type type = base_type::LOCAL_DATE; +}; + +template <> +struct base_type_traits<local_datetime> +{ + static const base_type type = base_type::LOCAL_DATETIME; +}; + +template <> 
+struct base_type_traits<offset_datetime> +{ + static const base_type type = base_type::OFFSET_DATETIME; +}; + +template <> +struct base_type_traits<int64_t> +{ + static const base_type type = base_type::INT; +}; + +template <> +struct base_type_traits<double> +{ + static const base_type type = base_type::FLOAT; +}; + +template <> +struct base_type_traits<bool> +{ + static const base_type type = base_type::BOOL; +}; + +template <> +struct base_type_traits<table> +{ + static const base_type type = base_type::TABLE; +}; + +template <> +struct base_type_traits<array> +{ + static const base_type type = base_type::ARRAY; +}; + +template <> +struct base_type_traits<table_array> +{ + static const base_type type = base_type::TABLE_ARRAY; +}; +#endif + +/** + * A generic base TOML value used for type erasure. + */ +class base : public std::enable_shared_from_this<base> +{ + public: + virtual ~base() = default; + + virtual std::shared_ptr<base> clone() const = 0; + + /** + * Determines if the given TOML element is a value. + */ + virtual bool is_value() const + { + return false; + } + + /** + * Determines if the given TOML element is a table. + */ + virtual bool is_table() const + { + return false; + } + + /** + * Converts the TOML element into a table. + */ + std::shared_ptr<table> as_table() + { + if (is_table()) + return std::static_pointer_cast<table>(shared_from_this()); + return nullptr; + } + /** + * Determines if the TOML element is an array of "leaf" elements. + */ + virtual bool is_array() const + { + return false; + } + + /** + * Converts the TOML element to an array. + */ + std::shared_ptr<array> as_array() + { + if (is_array()) + return std::static_pointer_cast<array>(shared_from_this()); + return nullptr; + } + + /** + * Determines if the given TOML element is an array of tables. + */ + virtual bool is_table_array() const + { + return false; + } + + /** + * Converts the TOML element into a table array. 
+ */ + std::shared_ptr<table_array> as_table_array() + { + if (is_table_array()) + return std::static_pointer_cast<table_array>(shared_from_this()); + return nullptr; + } + + /** + * Attempts to coerce the TOML element into a concrete TOML value + * of type T. + */ + template <class T> + std::shared_ptr<value<T>> as(); + + template <class T> + std::shared_ptr<const value<T>> as() const; + + template <class Visitor, class... Args> + void accept(Visitor&& visitor, Args&&... args) const; + +#if defined(CPPTOML_NO_RTTI) + base_type type() const + { + return type_; + } + + protected: + base(const base_type t) : type_(t) + { + // nothing + } + + private: + const base_type type_ = base_type::NONE; + +#else + protected: + base() + { + // nothing + } +#endif +}; + +/** + * A concrete TOML value representing the "leaves" of the "tree". + */ +template <class T> +class value : public base +{ + struct make_shared_enabler + { + // nothing; this is a private key accessible only to friends + }; + + template <class U> + friend std::shared_ptr<typename value_traits<U>::type> + cpptoml::make_value(U&& val); + + public: + static_assert(valid_value<T>::value, "invalid value type"); + + std::shared_ptr<base> clone() const override; + + value(const make_shared_enabler&, const T& val) : value(val) + { + // nothing; note that users cannot actually invoke this function + // because they lack access to the make_shared_enabler. + } + + bool is_value() const override + { + return true; + } + + /** + * Gets the data associated with this value. + */ + T& get() + { + return data_; + } + + /** + * Gets the data associated with this value. Const version. + */ + const T& get() const + { + return data_; + } + + private: + T data_; + + /** + * Constructs a value from the given data. 
+ */ +#if defined(CPPTOML_NO_RTTI) + value(const T& val) : base(base_type_traits<T>::type), data_(val) + { + } +#else + value(const T& val) : data_(val) + { + } +#endif + + value(const value& val) = delete; + value& operator=(const value& val) = delete; +}; + +template <class T> +std::shared_ptr<typename value_traits<T>::type> make_value(T&& val) +{ + using value_type = typename value_traits<T>::type; + using enabler = typename value_type::make_shared_enabler; + return std::make_shared<value_type>( + enabler{}, value_traits<T>::construct(std::forward<T>(val))); +} + +template <class T> +inline std::shared_ptr<value<T>> base::as() +{ +#if defined(CPPTOML_NO_RTTI) + if (type() == base_type_traits<T>::type) + return std::static_pointer_cast<value<T>>(shared_from_this()); + else + return nullptr; +#else + return std::dynamic_pointer_cast<value<T>>(shared_from_this()); +#endif +} + +// special case value<double> to allow getting an integer parameter as a +// double value +template <> +inline std::shared_ptr<value<double>> base::as() +{ +#if defined(CPPTOML_NO_RTTI) + if (type() == base_type::FLOAT) + return std::static_pointer_cast<value<double>>(shared_from_this()); + + if (type() == base_type::INT) + { + auto v = std::static_pointer_cast<value<int64_t>>(shared_from_this()); + return make_value<double>(static_cast<double>(v->get()));; + } +#else + if (auto v = std::dynamic_pointer_cast<value<double>>(shared_from_this())) + return v; + + if (auto v = std::dynamic_pointer_cast<value<int64_t>>(shared_from_this())) + return make_value<double>(static_cast<double>(v->get())); +#endif + + return nullptr; +} + +template <class T> +inline std::shared_ptr<const value<T>> base::as() const +{ +#if defined(CPPTOML_NO_RTTI) + if (type() == base_type_traits<T>::type) + return std::static_pointer_cast<const value<T>>(shared_from_this()); + else + return nullptr; +#else + return std::dynamic_pointer_cast<const value<T>>(shared_from_this()); +#endif +} + +// special case value<double> 
to allow getting an integer parameter as a +// double value +template <> +inline std::shared_ptr<const value<double>> base::as() const +{ +#if defined(CPPTOML_NO_RTTI) + if (type() == base_type::FLOAT) + return std::static_pointer_cast<const value<double>>(shared_from_this()); + + if (type() == base_type::INT) + { + auto v = as<int64_t>(); + // the below has to be a non-const value<double> due to a bug in + // libc++: https://llvm.org/bugs/show_bug.cgi?id=18843 + return make_value<double>(static_cast<double>(v->get())); + } +#else + if (auto v + = std::dynamic_pointer_cast<const value<double>>(shared_from_this())) + return v; + + if (auto v = as<int64_t>()) + { + // the below has to be a non-const value<double> due to a bug in + // libc++: https://llvm.org/bugs/show_bug.cgi?id=18843 + return make_value<double>(static_cast<double>(v->get())); + } +#endif + + return nullptr; +} + +/** + * Exception class for array insertion errors. + */ +class array_exception : public std::runtime_error +{ + public: + array_exception(const std::string& err) : std::runtime_error{err} + { + } +}; + +class array : public base +{ + public: + friend std::shared_ptr<array> make_array(); + + std::shared_ptr<base> clone() const override; + + virtual bool is_array() const override + { + return true; + } + + using size_type = std::size_t; + + /** + * arrays can be iterated over + */ + using iterator = std::vector<std::shared_ptr<base>>::iterator; + + /** + * arrays can be iterated over. Const version. + */ + using const_iterator = std::vector<std::shared_ptr<base>>::const_iterator; + + iterator begin() + { + return values_.begin(); + } + + const_iterator begin() const + { + return values_.begin(); + } + + iterator end() + { + return values_.end(); + } + + const_iterator end() const + { + return values_.end(); + } + + /** + * Obtains the array (vector) of base values. 
+ */ + std::vector<std::shared_ptr<base>>& get() + { + return values_; + } + + /** + * Obtains the array (vector) of base values. Const version. + */ + const std::vector<std::shared_ptr<base>>& get() const + { + return values_; + } + + std::shared_ptr<base> at(size_t idx) const + { + return values_.at(idx); + } + + /** + * Obtains an array of value<T>s. Note that elements may be + * nullptr if they cannot be converted to a value<T>. + */ + template <class T> + std::vector<std::shared_ptr<value<T>>> array_of() const + { + std::vector<std::shared_ptr<value<T>>> result(values_.size()); + + std::transform(values_.begin(), values_.end(), result.begin(), + [&](std::shared_ptr<base> v) { return v->as<T>(); }); + + return result; + } + + /** + * Obtains a option<vector<T>>. The option will be empty if the array + * contains values that are not of type T. + */ + template <class T> + inline typename array_of_trait<T>::return_type get_array_of() const + { + std::vector<T> result; + result.reserve(values_.size()); + + for (const auto& val : values_) + { + if (auto v = val->as<T>()) + result.push_back(v->get()); + else + return {}; + } + + return {std::move(result)}; + } + + /** + * Obtains an array of arrays. Note that elements may be nullptr + * if they cannot be converted to a array. 
+ */ + std::vector<std::shared_ptr<array>> nested_array() const + { + std::vector<std::shared_ptr<array>> result(values_.size()); + + std::transform(values_.begin(), values_.end(), result.begin(), + [&](std::shared_ptr<base> v) -> std::shared_ptr<array> { + if (v->is_array()) + return std::static_pointer_cast<array>(v); + return std::shared_ptr<array>{}; + }); + + return result; + } + + /** + * Add a value to the end of the array + */ + template <class T> + void push_back(const std::shared_ptr<value<T>>& val) + { + if (values_.empty() || values_[0]->as<T>()) + { + values_.push_back(val); + } + else + { + throw array_exception{"Arrays must be homogenous."}; + } + } + + /** + * Add an array to the end of the array + */ + void push_back(const std::shared_ptr<array>& val) + { + if (values_.empty() || values_[0]->is_array()) + { + values_.push_back(val); + } + else + { + throw array_exception{"Arrays must be homogenous."}; + } + } + + /** + * Convenience function for adding a simple element to the end + * of the array. 
+ */ + template <class T> + void push_back(T&& val, typename value_traits<T>::type* = 0) + { + push_back(make_value(std::forward<T>(val))); + } + + /** + * Insert a value into the array + */ + template <class T> + iterator insert(iterator position, const std::shared_ptr<value<T>>& value) + { + if (values_.empty() || values_[0]->as<T>()) + { + return values_.insert(position, value); + } + else + { + throw array_exception{"Arrays must be homogenous."}; + } + } + + /** + * Insert an array into the array + */ + iterator insert(iterator position, const std::shared_ptr<array>& value) + { + if (values_.empty() || values_[0]->is_array()) + { + return values_.insert(position, value); + } + else + { + throw array_exception{"Arrays must be homogenous."}; + } + } + + /** + * Convenience function for inserting a simple element in the array + */ + template <class T> + iterator insert(iterator position, T&& val, + typename value_traits<T>::type* = 0) + { + return insert(position, make_value(std::forward<T>(val))); + } + + /** + * Erase an element from the array + */ + iterator erase(iterator position) + { + return values_.erase(position); + } + + /** + * Clear the array + */ + void clear() + { + values_.clear(); + } + + /** + * Reserve space for n values. 
+ */ + void reserve(size_type n) + { + values_.reserve(n); + } + + private: +#if defined(CPPTOML_NO_RTTI) + array() : base(base_type::ARRAY) + { + // empty + } +#else + array() = default; +#endif + + template <class InputIterator> + array(InputIterator begin, InputIterator end) : values_{begin, end} + { + // nothing + } + + array(const array& obj) = delete; + array& operator=(const array& obj) = delete; + + std::vector<std::shared_ptr<base>> values_; +}; + +inline std::shared_ptr<array> make_array() +{ + struct make_shared_enabler : public array + { + make_shared_enabler() + { + // nothing + } + }; + + return std::make_shared<make_shared_enabler>(); +} + +template <> +inline std::shared_ptr<array> make_element<array>() +{ + return make_array(); +} + +/** + * Obtains a option<vector<T>>. The option will be empty if the array + * contains values that are not of type T. + */ +template <> +inline typename array_of_trait<array>::return_type +array::get_array_of<array>() const +{ + std::vector<std::shared_ptr<array>> result; + result.reserve(values_.size()); + + for (const auto& val : values_) + { + if (auto v = val->as_array()) + result.push_back(v); + else + return {}; + } + + return {std::move(result)}; +} + +class table; + +class table_array : public base +{ + friend class table; + friend std::shared_ptr<table_array> make_table_array(); + + public: + std::shared_ptr<base> clone() const override; + + using size_type = std::size_t; + + /** + * arrays can be iterated over + */ + using iterator = std::vector<std::shared_ptr<table>>::iterator; + + /** + * arrays can be iterated over. Const version. 
+ */ + using const_iterator = std::vector<std::shared_ptr<table>>::const_iterator; + + iterator begin() + { + return array_.begin(); + } + + const_iterator begin() const + { + return array_.begin(); + } + + iterator end() + { + return array_.end(); + } + + const_iterator end() const + { + return array_.end(); + } + + virtual bool is_table_array() const override + { + return true; + } + + std::vector<std::shared_ptr<table>>& get() + { + return array_; + } + + const std::vector<std::shared_ptr<table>>& get() const + { + return array_; + } + + /** + * Add a table to the end of the array + */ + void push_back(const std::shared_ptr<table>& val) + { + array_.push_back(val); + } + + /** + * Insert a table into the array + */ + iterator insert(iterator position, const std::shared_ptr<table>& value) + { + return array_.insert(position, value); + } + + /** + * Erase an element from the array + */ + iterator erase(iterator position) + { + return array_.erase(position); + } + + /** + * Clear the array + */ + void clear() + { + array_.clear(); + } + + /** + * Reserve space for n tables. 
+ */ + void reserve(size_type n) + { + array_.reserve(n); + } + + private: +#if defined(CPPTOML_NO_RTTI) + table_array() : base(base_type::TABLE_ARRAY) + { + // nothing + } +#else + table_array() + { + // nothing + } +#endif + + table_array(const table_array& obj) = delete; + table_array& operator=(const table_array& rhs) = delete; + + std::vector<std::shared_ptr<table>> array_; +}; + +inline std::shared_ptr<table_array> make_table_array() +{ + struct make_shared_enabler : public table_array + { + make_shared_enabler() + { + // nothing + } + }; + + return std::make_shared<make_shared_enabler>(); +} + +template <> +inline std::shared_ptr<table_array> make_element<table_array>() +{ + return make_table_array(); +} + +// The below are overloads for fetching specific value types out of a value +// where special casting behavior (like bounds checking) is desired + +template <class T> +typename std::enable_if<!std::is_floating_point<T>::value + && std::is_signed<T>::value, + option<T>>::type +get_impl(const std::shared_ptr<base>& elem) +{ + if (auto v = elem->as<int64_t>()) + { + if (v->get() < (std::numeric_limits<T>::min)()) + throw std::underflow_error{ + "T cannot represent the value requested in get"}; + + if (v->get() > (std::numeric_limits<T>::max)()) + throw std::overflow_error{ + "T cannot represent the value requested in get"}; + + return {static_cast<T>(v->get())}; + } + else + { + return {}; + } +} + +template <class T> +typename std::enable_if<!std::is_same<T, bool>::value + && std::is_unsigned<T>::value, + option<T>>::type +get_impl(const std::shared_ptr<base>& elem) +{ + if (auto v = elem->as<int64_t>()) + { + if (v->get() < 0) + throw std::underflow_error{"T cannot store negative value in get"}; + + if (static_cast<uint64_t>(v->get()) > (std::numeric_limits<T>::max)()) + throw std::overflow_error{ + "T cannot represent the value requested in get"}; + + return {static_cast<T>(v->get())}; + } + else + { + return {}; + } +} + +template <class T> +typename 
std::enable_if<!std::is_integral<T>::value + || std::is_same<T, bool>::value, + option<T>>::type +get_impl(const std::shared_ptr<base>& elem) +{ + if (auto v = elem->as<T>()) + { + return {v->get()}; + } + else + { + return {}; + } +} + +/** + * Represents a TOML keytable. + */ +class table : public base +{ + public: + friend class table_array; + friend std::shared_ptr<table> make_table(); + + std::shared_ptr<base> clone() const override; + + /** + * tables can be iterated over. + */ + using iterator = string_to_base_map::iterator; + + /** + * tables can be iterated over. Const version. + */ + using const_iterator = string_to_base_map::const_iterator; + + iterator begin() + { + return map_.begin(); + } + + const_iterator begin() const + { + return map_.begin(); + } + + iterator end() + { + return map_.end(); + } + + const_iterator end() const + { + return map_.end(); + } + + bool is_table() const override + { + return true; + } + + bool empty() const + { + return map_.empty(); + } + + /** + * Determines if this key table contains the given key. + */ + bool contains(const std::string& key) const + { + return map_.find(key) != map_.end(); + } + + /** + * Determines if this key table contains the given key. Will + * resolve "qualified keys". Qualified keys are the full access + * path separated with dots like "grandparent.parent.child". + */ + bool contains_qualified(const std::string& key) const + { + return resolve_qualified(key); + } + + /** + * Obtains the base for a given key. + * @throw std::out_of_range if the key does not exist + */ + std::shared_ptr<base> get(const std::string& key) const + { + return map_.at(key); + } + + /** + * Obtains the base for a given key. Will resolve "qualified + * keys". Qualified keys are the full access path separated with + * dots like "grandparent.parent.child". 
+ * + * @throw std::out_of_range if the key does not exist + */ + std::shared_ptr<base> get_qualified(const std::string& key) const + { + std::shared_ptr<base> p; + resolve_qualified(key, &p); + return p; + } + + /** + * Obtains a table for a given key, if possible. + */ + std::shared_ptr<table> get_table(const std::string& key) const + { + if (contains(key) && get(key)->is_table()) + return std::static_pointer_cast<table>(get(key)); + return nullptr; + } + + /** + * Obtains a table for a given key, if possible. Will resolve + * "qualified keys". + */ + std::shared_ptr<table> get_table_qualified(const std::string& key) const + { + if (contains_qualified(key) && get_qualified(key)->is_table()) + return std::static_pointer_cast<table>(get_qualified(key)); + return nullptr; + } + + /** + * Obtains an array for a given key. + */ + std::shared_ptr<array> get_array(const std::string& key) const + { + if (!contains(key)) + return nullptr; + return get(key)->as_array(); + } + + /** + * Obtains an array for a given key. Will resolve "qualified keys". + */ + std::shared_ptr<array> get_array_qualified(const std::string& key) const + { + if (!contains_qualified(key)) + return nullptr; + return get_qualified(key)->as_array(); + } + + /** + * Obtains a table_array for a given key, if possible. + */ + std::shared_ptr<table_array> get_table_array(const std::string& key) const + { + if (!contains(key)) + return nullptr; + return get(key)->as_table_array(); + } + + /** + * Obtains a table_array for a given key, if possible. Will resolve + * "qualified keys". + */ + std::shared_ptr<table_array> + get_table_array_qualified(const std::string& key) const + { + if (!contains_qualified(key)) + return nullptr; + return get_qualified(key)->as_table_array(); + } + + /** + * Helper function that attempts to get a value corresponding + * to the template parameter from a given key. 
+ */ + template <class T> + option<T> get_as(const std::string& key) const + { + try + { + return get_impl<T>(get(key)); + } + catch (const std::out_of_range&) + { + return {}; + } + } + + /** + * Helper function that attempts to get a value corresponding + * to the template parameter from a given key. Will resolve "qualified + * keys". + */ + template <class T> + option<T> get_qualified_as(const std::string& key) const + { + try + { + return get_impl<T>(get_qualified(key)); + } + catch (const std::out_of_range&) + { + return {}; + } + } + + /** + * Helper function that attempts to get an array of values of a given + * type corresponding to the template parameter for a given key. + * + * If the key doesn't exist, doesn't exist as an array type, or one or + * more keys inside the array type are not of type T, an empty option + * is returned. Otherwise, an option containing a vector of the values + * is returned. + */ + template <class T> + inline typename array_of_trait<T>::return_type + get_array_of(const std::string& key) const + { + if (auto v = get_array(key)) + { + std::vector<T> result; + result.reserve(v->get().size()); + + for (const auto& b : v->get()) + { + if (auto val = b->as<T>()) + result.push_back(val->get()); + else + return {}; + } + return {std::move(result)}; + } + + return {}; + } + + /** + * Helper function that attempts to get an array of values of a given + * type corresponding to the template parameter for a given key. Will + * resolve "qualified keys". + * + * If the key doesn't exist, doesn't exist as an array type, or one or + * more keys inside the array type are not of type T, an empty option + * is returned. Otherwise, an option containing a vector of the values + * is returned. 
+ */ + template <class T> + inline typename array_of_trait<T>::return_type + get_qualified_array_of(const std::string& key) const + { + if (auto v = get_array_qualified(key)) + { + std::vector<T> result; + result.reserve(v->get().size()); + + for (const auto& b : v->get()) + { + if (auto val = b->as<T>()) + result.push_back(val->get()); + else + return {}; + } + return {std::move(result)}; + } + + return {}; + } + + /** + * Adds an element to the keytable. + */ + void insert(const std::string& key, const std::shared_ptr<base>& value) + { + map_[key] = value; + } + + /** + * Convenience shorthand for adding a simple element to the + * keytable. + */ + template <class T> + void insert(const std::string& key, T&& val, + typename value_traits<T>::type* = 0) + { + insert(key, make_value(std::forward<T>(val))); + } + + /** + * Removes an element from the table. + */ + void erase(const std::string& key) + { + map_.erase(key); + } + + private: +#if defined(CPPTOML_NO_RTTI) + table() : base(base_type::TABLE) + { + // nothing + } +#else + table() + { + // nothing + } +#endif + + table(const table& obj) = delete; + table& operator=(const table& rhs) = delete; + + std::vector<std::string> split(const std::string& value, + char separator) const + { + std::vector<std::string> result; + std::string::size_type p = 0; + std::string::size_type q; + while ((q = value.find(separator, p)) != std::string::npos) + { + result.emplace_back(value, p, q - p); + p = q + 1; + } + result.emplace_back(value, p); + return result; + } + + // If output parameter p is specified, fill it with the pointer to the + // specified entry and throw std::out_of_range if it couldn't be found. + // + // Otherwise, just return true if the entry could be found or false + // otherwise and do not throw. 
+ bool resolve_qualified(const std::string& key, + std::shared_ptr<base>* p = nullptr) const + { + auto parts = split(key, '.'); + auto last_key = parts.back(); + parts.pop_back(); + + auto cur_table = this; + for (const auto& part : parts) + { + cur_table = cur_table->get_table(part).get(); + if (!cur_table) + { + if (!p) + return false; + + throw std::out_of_range{key + " is not a valid key"}; + } + } + + if (!p) + return cur_table->map_.count(last_key) != 0; + + *p = cur_table->map_.at(last_key); + return true; + } + + string_to_base_map map_; +}; + +/** + * Helper function that attempts to get an array of arrays for a given + * key. + * + * If the key doesn't exist, doesn't exist as an array type, or one or + * more keys inside the array type are not of type T, an empty option + * is returned. Otherwise, an option containing a vector of the values + * is returned. + */ +template <> +inline typename array_of_trait<array>::return_type +table::get_array_of<array>(const std::string& key) const +{ + if (auto v = get_array(key)) + { + std::vector<std::shared_ptr<array>> result; + result.reserve(v->get().size()); + + for (const auto& b : v->get()) + { + if (auto val = b->as_array()) + result.push_back(val); + else + return {}; + } + + return {std::move(result)}; + } + + return {}; +} + +/** + * Helper function that attempts to get an array of arrays for a given + * key. Will resolve "qualified keys". + * + * If the key doesn't exist, doesn't exist as an array type, or one or + * more keys inside the array type are not of type T, an empty option + * is returned. Otherwise, an option containing a vector of the values + * is returned. 
+ */ +template <> +inline typename array_of_trait<array>::return_type +table::get_qualified_array_of<array>(const std::string& key) const +{ + if (auto v = get_array_qualified(key)) + { + std::vector<std::shared_ptr<array>> result; + result.reserve(v->get().size()); + + for (const auto& b : v->get()) + { + if (auto val = b->as_array()) + result.push_back(val); + else + return {}; + } + + return {std::move(result)}; + } + + return {}; +} + +std::shared_ptr<table> make_table() +{ + struct make_shared_enabler : public table + { + make_shared_enabler() + { + // nothing + } + }; + + return std::make_shared<make_shared_enabler>(); +} + +template <> +inline std::shared_ptr<table> make_element<table>() +{ + return make_table(); +} + +template <class T> +std::shared_ptr<base> value<T>::clone() const +{ + return make_value(data_); +} + +inline std::shared_ptr<base> array::clone() const +{ + auto result = make_array(); + result->reserve(values_.size()); + for (const auto& ptr : values_) + result->values_.push_back(ptr->clone()); + return result; +} + +inline std::shared_ptr<base> table_array::clone() const +{ + auto result = make_table_array(); + result->reserve(array_.size()); + for (const auto& ptr : array_) + result->array_.push_back(ptr->clone()->as_table()); + return result; +} + +inline std::shared_ptr<base> table::clone() const +{ + auto result = make_table(); + for (const auto& pr : map_) + result->insert(pr.first, pr.second->clone()); + return result; +} + +/** + * Exception class for all TOML parsing errors. + */ +class parse_exception : public std::runtime_error +{ + public: + parse_exception(const std::string& err) : std::runtime_error{err} + { + } + + parse_exception(const std::string& err, std::size_t line_number) + : std::runtime_error{err + " at line " + std::to_string(line_number)} + { + } +}; + +inline bool is_number(char c) +{ + return c >= '0' && c <= '9'; +} + +/** + * Helper object for consuming expected characters. 
+ */ +template <class OnError> +class consumer +{ + public: + consumer(std::string::iterator& it, const std::string::iterator& end, + OnError&& on_error) + : it_(it), end_(end), on_error_(std::forward<OnError>(on_error)) + { + // nothing + } + + void operator()(char c) + { + if (it_ == end_ || *it_ != c) + on_error_(); + ++it_; + } + + template <std::size_t N> + void operator()(const char (&str)[N]) + { + std::for_each(std::begin(str), std::end(str) - 1, + [&](char c) { (*this)(c); }); + } + + int eat_digits(int len) + { + int val = 0; + for (int i = 0; i < len; ++i) + { + if (!is_number(*it_) || it_ == end_) + on_error_(); + val = 10 * val + (*it_++ - '0'); + } + return val; + } + + void error() + { + on_error_(); + } + + private: + std::string::iterator& it_; + const std::string::iterator& end_; + OnError on_error_; +}; + +template <class OnError> +consumer<OnError> make_consumer(std::string::iterator& it, + const std::string::iterator& end, + OnError&& on_error) +{ + return consumer<OnError>(it, end, std::forward<OnError>(on_error)); +} + +// replacement for std::getline to handle incorrectly line-ended files +// https://stackoverflow.com/questions/6089231/getting-std-ifstream-to-handle-lf-cr-and-crlf +namespace detail +{ +inline std::istream& getline(std::istream& input, std::string& line) +{ + line.clear(); + + std::istream::sentry sentry{input, true}; + auto sb = input.rdbuf(); + + while (true) + { + auto c = sb->sbumpc(); + if (c == '\r') + { + if (sb->sgetc() == '\n') + c = sb->sbumpc(); + } + + if (c == '\n') + return input; + + if (c == std::istream::traits_type::eof()) + { + if (line.empty()) + input.setstate(std::ios::eofbit); + return input; + } + + line.push_back(static_cast<char>(c)); + } +} +} + +/** + * The parser class. + */ +class parser +{ + public: + /** + * Parsers are constructed from streams. 
+ */ + parser(std::istream& stream) : input_(stream) + { + // nothing + } + + parser& operator=(const parser& parser) = delete; + + /** + * Parses the stream this parser was created on until EOF. + * @throw parse_exception if there are errors in parsing + */ + std::shared_ptr<table> parse() + { + std::shared_ptr<table> root = make_table(); + + table* curr_table = root.get(); + + while (detail::getline(input_, line_)) + { + line_number_++; + auto it = line_.begin(); + auto end = line_.end(); + consume_whitespace(it, end); + if (it == end || *it == '#') + continue; + if (*it == '[') + { + curr_table = root.get(); + parse_table(it, end, curr_table); + } + else + { + parse_key_value(it, end, curr_table); + consume_whitespace(it, end); + eol_or_comment(it, end); + } + } + return root; + } + + private: +#if defined _MSC_VER + __declspec(noreturn) +#elif defined __GNUC__ + __attribute__((noreturn)) +#endif + void throw_parse_exception(const std::string& err) + { + throw parse_exception{err, line_number_}; + } + + void parse_table(std::string::iterator& it, + const std::string::iterator& end, table*& curr_table) + { + // remove the beginning keytable marker + ++it; + if (it == end) + throw_parse_exception("Unexpected end of table"); + if (*it == '[') + parse_table_array(it, end, curr_table); + else + parse_single_table(it, end, curr_table); + } + + void parse_single_table(std::string::iterator& it, + const std::string::iterator& end, + table*& curr_table) + { + if (it == end || *it == ']') + throw_parse_exception("Table name cannot be empty"); + + std::string full_table_name; + bool inserted = false; + while (it != end && *it != ']') + { + auto part = parse_key(it, end, + [](char c) { return c == '.' 
|| c == ']'; }); + + if (part.empty()) + throw_parse_exception("Empty component of table name"); + + if (!full_table_name.empty()) + full_table_name += "."; + full_table_name += part; + + if (curr_table->contains(part)) + { + auto b = curr_table->get(part); + if (b->is_table()) + curr_table = static_cast<table*>(b.get()); + else if (b->is_table_array()) + curr_table = std::static_pointer_cast<table_array>(b) + ->get() + .back() + .get(); + else + throw_parse_exception("Key " + full_table_name + + "already exists as a value"); + } + else + { + inserted = true; + curr_table->insert(part, make_table()); + curr_table = static_cast<table*>(curr_table->get(part).get()); + } + consume_whitespace(it, end); + if (it != end && *it == '.') + ++it; + consume_whitespace(it, end); + } + + if (it == end) + throw_parse_exception( + "Unterminated table declaration; did you forget a ']'?"); + + // table already existed + if (!inserted) + { + auto is_value + = [](const std::pair<const std::string&, + const std::shared_ptr<base>&>& p) { + return p.second->is_value(); + }; + + // if there are any values, we can't add values to this table + // since it has already been defined. If there aren't any + // values, then it was implicitly created by something like + // [a.b] + if (curr_table->empty() || std::any_of(curr_table->begin(), + curr_table->end(), is_value)) + { + throw_parse_exception("Redefinition of table " + + full_table_name); + } + } + + ++it; + consume_whitespace(it, end); + eol_or_comment(it, end); + } + + void parse_table_array(std::string::iterator& it, + const std::string::iterator& end, table*& curr_table) + { + ++it; + if (it == end || *it == ']') + throw_parse_exception("Table array name cannot be empty"); + + std::string full_ta_name; + while (it != end && *it != ']') + { + auto part = parse_key(it, end, + [](char c) { return c == '.' 
|| c == ']'; }); + + if (part.empty()) + throw_parse_exception("Empty component of table array name"); + + if (!full_ta_name.empty()) + full_ta_name += "."; + full_ta_name += part; + + consume_whitespace(it, end); + if (it != end && *it == '.') + ++it; + consume_whitespace(it, end); + + if (curr_table->contains(part)) + { + auto b = curr_table->get(part); + + // if this is the end of the table array name, add an + // element to the table array that we just looked up + if (it != end && *it == ']') + { + if (!b->is_table_array()) + throw_parse_exception("Key " + full_ta_name + + " is not a table array"); + auto v = b->as_table_array(); + v->get().push_back(make_table()); + curr_table = v->get().back().get(); + } + // otherwise, just keep traversing down the key name + else + { + if (b->is_table()) + curr_table = static_cast<table*>(b.get()); + else if (b->is_table_array()) + curr_table = std::static_pointer_cast<table_array>(b) + ->get() + .back() + .get(); + else + throw_parse_exception("Key " + full_ta_name + + " already exists as a value"); + } + } + else + { + // if this is the end of the table array name, add a new + // table array and a new table inside that array for us to + // add keys to next + if (it != end && *it == ']') + { + curr_table->insert(part, make_table_array()); + auto arr = std::static_pointer_cast<table_array>( + curr_table->get(part)); + arr->get().push_back(make_table()); + curr_table = arr->get().back().get(); + } + // otherwise, create the implicitly defined table and move + // down to it + else + { + curr_table->insert(part, make_table()); + curr_table + = static_cast<table*>(curr_table->get(part).get()); + } + } + } + + // consume the last "]]" + if (it == end) + throw_parse_exception("Unterminated table array name"); + ++it; + if (it == end) + throw_parse_exception("Unterminated table array name"); + ++it; + + consume_whitespace(it, end); + eol_or_comment(it, end); + } + + void parse_key_value(std::string::iterator& it, 
std::string::iterator& end, + table* curr_table) + { + auto key = parse_key(it, end, [](char c) { return c == '='; }); + if (curr_table->contains(key)) + throw_parse_exception("Key " + key + " already present"); + if (it == end || *it != '=') + throw_parse_exception("Value must follow after a '='"); + ++it; + consume_whitespace(it, end); + curr_table->insert(key, parse_value(it, end)); + consume_whitespace(it, end); + } + + template <class Function> + std::string parse_key(std::string::iterator& it, + const std::string::iterator& end, Function&& fun) + { + consume_whitespace(it, end); + if (*it == '"') + { + return parse_quoted_key(it, end); + } + else + { + auto bke = std::find_if(it, end, std::forward<Function>(fun)); + return parse_bare_key(it, bke); + } + } + + std::string parse_bare_key(std::string::iterator& it, + const std::string::iterator& end) + { + if (it == end) + { + throw_parse_exception("Bare key missing name"); + } + + auto key_end = end; + --key_end; + consume_backwards_whitespace(key_end, it); + ++key_end; + std::string key{it, key_end}; + + if (std::find(it, key_end, '#') != key_end) + { + throw_parse_exception("Bare key " + key + " cannot contain #"); + } + + if (std::find_if(it, key_end, + [](char c) { return c == ' ' || c == '\t'; }) + != key_end) + { + throw_parse_exception("Bare key " + key + + " cannot contain whitespace"); + } + + if (std::find_if(it, key_end, + [](char c) { return c == '[' || c == ']'; }) + != key_end) + { + throw_parse_exception("Bare key " + key + + " cannot contain '[' or ']'"); + } + + it = end; + return key; + } + + std::string parse_quoted_key(std::string::iterator& it, + const std::string::iterator& end) + { + return string_literal(it, end, '"'); + } + + enum class parse_type + { + STRING = 1, + LOCAL_TIME, + LOCAL_DATE, + LOCAL_DATETIME, + OFFSET_DATETIME, + INT, + FLOAT, + BOOL, + ARRAY, + INLINE_TABLE + }; + + std::shared_ptr<base> parse_value(std::string::iterator& it, + std::string::iterator& end) + { + 
parse_type type = determine_value_type(it, end); + switch (type) + { + case parse_type::STRING: + return parse_string(it, end); + case parse_type::LOCAL_TIME: + return parse_time(it, end); + case parse_type::LOCAL_DATE: + case parse_type::LOCAL_DATETIME: + case parse_type::OFFSET_DATETIME: + return parse_date(it, end); + case parse_type::INT: + case parse_type::FLOAT: + return parse_number(it, end); + case parse_type::BOOL: + return parse_bool(it, end); + case parse_type::ARRAY: + return parse_array(it, end); + case parse_type::INLINE_TABLE: + return parse_inline_table(it, end); + default: + throw_parse_exception("Failed to parse value"); + } + } + + parse_type determine_value_type(const std::string::iterator& it, + const std::string::iterator& end) + { + if(it == end) + { + throw_parse_exception("Failed to parse value type"); + } + if (*it == '"' || *it == '\'') + { + return parse_type::STRING; + } + else if (is_time(it, end)) + { + return parse_type::LOCAL_TIME; + } + else if (auto dtype = date_type(it, end)) + { + return *dtype; + } + else if (is_number(*it) || *it == '-' || *it == '+') + { + return determine_number_type(it, end); + } + else if (*it == 't' || *it == 'f') + { + return parse_type::BOOL; + } + else if (*it == '[') + { + return parse_type::ARRAY; + } + else if (*it == '{') + { + return parse_type::INLINE_TABLE; + } + throw_parse_exception("Failed to parse value type"); + } + + parse_type determine_number_type(const std::string::iterator& it, + const std::string::iterator& end) + { + // determine if we are an integer or a float + auto check_it = it; + if (*check_it == '-' || *check_it == '+') + ++check_it; + while (check_it != end && is_number(*check_it)) + ++check_it; + if (check_it != end && *check_it == '.') + { + ++check_it; + while (check_it != end && is_number(*check_it)) + ++check_it; + return parse_type::FLOAT; + } + else + { + return parse_type::INT; + } + } + + std::shared_ptr<value<std::string>> parse_string(std::string::iterator& it, + 
std::string::iterator& end) + { + auto delim = *it; + assert(delim == '"' || delim == '\''); + + // end is non-const here because we have to be able to potentially + // parse multiple lines in a string, not just one + auto check_it = it; + ++check_it; + if (check_it != end && *check_it == delim) + { + ++check_it; + if (check_it != end && *check_it == delim) + { + it = ++check_it; + return parse_multiline_string(it, end, delim); + } + } + return make_value<std::string>(string_literal(it, end, delim)); + } + + std::shared_ptr<value<std::string>> + parse_multiline_string(std::string::iterator& it, + std::string::iterator& end, char delim) + { + std::stringstream ss; + + auto is_ws = [](char c) { return c == ' ' || c == '\t'; }; + + bool consuming = false; + std::shared_ptr<value<std::string>> ret; + + auto handle_line + = [&](std::string::iterator& local_it, + std::string::iterator& local_end) { + if (consuming) + { + local_it = std::find_if_not(local_it, local_end, is_ws); + + // whole line is whitespace + if (local_it == local_end) + return; + } + + consuming = false; + + while (local_it != local_end) + { + // handle escaped characters + if (delim == '"' && *local_it == '\\') + { + auto check = local_it; + // check if this is an actual escape sequence or a + // whitespace escaping backslash + ++check; + consume_whitespace(check, local_end); + if (check == local_end) + { + consuming = true; + break; + } + + ss << parse_escape_code(local_it, local_end); + continue; + } + + // if we can end the string + if (std::distance(local_it, local_end) >= 3) + { + auto check = local_it; + // check for """ + if (*check++ == delim && *check++ == delim + && *check++ == delim) + { + local_it = check; + ret = make_value<std::string>(ss.str()); + break; + } + } + + ss << *local_it++; + } + }; + + // handle the remainder of the current line + handle_line(it, end); + if (ret) + return ret; + + // start eating lines + while (detail::getline(input_, line_)) + { + ++line_number_; + + it = 
line_.begin(); + end = line_.end(); + + handle_line(it, end); + + if (ret) + return ret; + + if (!consuming) + ss << std::endl; + } + + throw_parse_exception("Unterminated multi-line basic string"); + } + + std::string string_literal(std::string::iterator& it, + const std::string::iterator& end, char delim) + { + ++it; + std::string val; + while (it != end) + { + // handle escaped characters + if (delim == '"' && *it == '\\') + { + val += parse_escape_code(it, end); + } + else if (*it == delim) + { + ++it; + consume_whitespace(it, end); + return val; + } + else + { + val += *it++; + } + } + throw_parse_exception("Unterminated string literal"); + } + + std::string parse_escape_code(std::string::iterator& it, + const std::string::iterator& end) + { + ++it; + if (it == end) + throw_parse_exception("Invalid escape sequence"); + char value; + if (*it == 'b') + { + value = '\b'; + } + else if (*it == 't') + { + value = '\t'; + } + else if (*it == 'n') + { + value = '\n'; + } + else if (*it == 'f') + { + value = '\f'; + } + else if (*it == 'r') + { + value = '\r'; + } + else if (*it == '"') + { + value = '"'; + } + else if (*it == '\\') + { + value = '\\'; + } + else if (*it == 'u' || *it == 'U') + { + return parse_unicode(it, end); + } + else + { + throw_parse_exception("Invalid escape sequence"); + } + ++it; + return std::string(1, value); + } + + std::string parse_unicode(std::string::iterator& it, + const std::string::iterator& end) + { + bool large = *it++ == 'U'; + auto codepoint = parse_hex(it, end, large ? 
0x10000000 : 0x1000); + + if ((codepoint > 0xd7ff && codepoint < 0xe000) || codepoint > 0x10ffff) + { + throw_parse_exception( + "Unicode escape sequence is not a Unicode scalar value"); + } + + std::string result; + // See Table 3-6 of the Unicode standard + if (codepoint <= 0x7f) + { + // 1-byte codepoints: 00000000 0xxxxxxx + // repr: 0xxxxxxx + result += static_cast<char>(codepoint & 0x7f); + } + else if (codepoint <= 0x7ff) + { + // 2-byte codepoints: 00000yyy yyxxxxxx + // repr: 110yyyyy 10xxxxxx + // + // 0x1f = 00011111 + // 0xc0 = 11000000 + // + result += static_cast<char>(0xc0 | ((codepoint >> 6) & 0x1f)); + // + // 0x80 = 10000000 + // 0x3f = 00111111 + // + result += static_cast<char>(0x80 | (codepoint & 0x3f)); + } + else if (codepoint <= 0xffff) + { + // 3-byte codepoints: zzzzyyyy yyxxxxxx + // repr: 1110zzzz 10yyyyyy 10xxxxxx + // + // 0xe0 = 11100000 + // 0x0f = 00001111 + // + result += static_cast<char>(0xe0 | ((codepoint >> 12) & 0x0f)); + result += static_cast<char>(0x80 | ((codepoint >> 6) & 0x1f)); + result += static_cast<char>(0x80 | (codepoint & 0x3f)); + } + else + { + // 4-byte codepoints: 000uuuuu zzzzyyyy yyxxxxxx + // repr: 11110uuu 10uuzzzz 10yyyyyy 10xxxxxx + // + // 0xf0 = 11110000 + // 0x07 = 00000111 + // + result += static_cast<char>(0xf0 | ((codepoint >> 18) & 0x07)); + result += static_cast<char>(0x80 | ((codepoint >> 12) & 0x3f)); + result += static_cast<char>(0x80 | ((codepoint >> 6) & 0x3f)); + result += static_cast<char>(0x80 | (codepoint & 0x3f)); + } + return result; + } + + uint32_t parse_hex(std::string::iterator& it, + const std::string::iterator& end, uint32_t place) + { + uint32_t value = 0; + while (place > 0) + { + if (it == end) + throw_parse_exception("Unexpected end of unicode sequence"); + + if (!is_hex(*it)) + throw_parse_exception("Invalid unicode escape sequence"); + + value += place * hex_to_digit(*it++); + place /= 16; + } + return value; + } + + bool is_hex(char c) + { + return is_number(c) || (c >= 'a' 
&& c <= 'f') || (c >= 'A' && c <= 'F'); + } + + uint32_t hex_to_digit(char c) + { + if (is_number(c)) + return static_cast<uint32_t>(c - '0'); + return 10 + static_cast<uint32_t>( + c - ((c >= 'a' && c <= 'f') ? 'a' : 'A')); + } + + std::shared_ptr<base> parse_number(std::string::iterator& it, + const std::string::iterator& end) + { + auto check_it = it; + auto check_end = find_end_of_number(it, end); + + auto eat_sign = [&]() { + if (check_it != end && (*check_it == '-' || *check_it == '+')) + ++check_it; + }; + + eat_sign(); + + auto eat_numbers = [&]() { + auto beg = check_it; + while (check_it != end && is_number(*check_it)) + { + ++check_it; + if (check_it != end && *check_it == '_') + { + ++check_it; + if (check_it == end || !is_number(*check_it)) + throw_parse_exception("Malformed number"); + } + } + + if (check_it == beg) + throw_parse_exception("Malformed number"); + }; + + auto check_no_leading_zero = [&]() { + if (check_it != end && *check_it == '0' && check_it + 1 != check_end + && check_it[1] != '.') + { + throw_parse_exception("Numbers may not have leading zeros"); + } + }; + + check_no_leading_zero(); + eat_numbers(); + + if (check_it != end + && (*check_it == '.' 
|| *check_it == 'e' || *check_it == 'E')) + { + bool is_exp = *check_it == 'e' || *check_it == 'E'; + + ++check_it; + if (check_it == end) + throw_parse_exception("Floats must have trailing digits"); + + auto eat_exp = [&]() { + eat_sign(); + check_no_leading_zero(); + eat_numbers(); + }; + + if (is_exp) + eat_exp(); + else + eat_numbers(); + + if (!is_exp && check_it != end + && (*check_it == 'e' || *check_it == 'E')) + { + ++check_it; + eat_exp(); + } + + return parse_float(it, check_it); + } + else + { + return parse_int(it, check_it); + } + } + + std::shared_ptr<value<int64_t>> parse_int(std::string::iterator& it, + const std::string::iterator& end) + { + std::string v{it, end}; + v.erase(std::remove(v.begin(), v.end(), '_'), v.end()); + it = end; + try + { + return make_value<int64_t>(std::stoll(v)); + } + catch (const std::invalid_argument& ex) + { + throw_parse_exception("Malformed number (invalid argument: " + + std::string{ex.what()} + ")"); + } + catch (const std::out_of_range& ex) + { + throw_parse_exception("Malformed number (out of range: " + + std::string{ex.what()} + ")"); + } + } + + std::shared_ptr<value<double>> parse_float(std::string::iterator& it, + const std::string::iterator& end) + { + std::string v{it, end}; + v.erase(std::remove(v.begin(), v.end(), '_'), v.end()); + it = end; + char decimal_point = std::localeconv()->decimal_point[0]; + std::replace(v.begin(), v.end(), '.', decimal_point); + try + { + return make_value<double>(std::stod(v)); + } + catch (const std::invalid_argument& ex) + { + throw_parse_exception("Malformed number (invalid argument: " + + std::string{ex.what()} + ")"); + } + catch (const std::out_of_range& ex) + { + throw_parse_exception("Malformed number (out of range: " + + std::string{ex.what()} + ")"); + } + } + + std::shared_ptr<value<bool>> parse_bool(std::string::iterator& it, + const std::string::iterator& end) + { + auto eat = make_consumer(it, end, [this]() { + throw_parse_exception("Attempted to parse invalid 
boolean value"); + }); + + if (*it == 't') + { + eat("true"); + return make_value<bool>(true); + } + else if (*it == 'f') + { + eat("false"); + return make_value<bool>(false); + } + + eat.error(); + return nullptr; + } + + std::string::iterator find_end_of_number(std::string::iterator it, + std::string::iterator end) + { + return std::find_if(it, end, [](char c) { + return !is_number(c) && c != '_' && c != '.' && c != 'e' && c != 'E' + && c != '-' && c != '+'; + }); + } + + std::string::iterator find_end_of_date(std::string::iterator it, + std::string::iterator end) + { + return std::find_if(it, end, [](char c) { + return !is_number(c) && c != 'T' && c != 'Z' && c != ':' && c != '-' + && c != '+' && c != '.'; + }); + } + + std::string::iterator find_end_of_time(std::string::iterator it, + std::string::iterator end) + { + return std::find_if(it, end, [](char c) { + return !is_number(c) && c != ':' && c != '.'; + }); + } + + local_time read_time(std::string::iterator& it, + const std::string::iterator& end) + { + auto time_end = find_end_of_time(it, end); + + auto eat = make_consumer( + it, time_end, [&]() { throw_parse_exception("Malformed time"); }); + + local_time ltime; + + ltime.hour = eat.eat_digits(2); + eat(':'); + ltime.minute = eat.eat_digits(2); + eat(':'); + ltime.second = eat.eat_digits(2); + + int power = 100000; + if (it != time_end && *it == '.') + { + ++it; + while (it != time_end && is_number(*it)) + { + ltime.microsecond += power * (*it++ - '0'); + power /= 10; + } + } + + if (it != time_end) + throw_parse_exception("Malformed time"); + + return ltime; + } + + std::shared_ptr<value<local_time>> + parse_time(std::string::iterator& it, const std::string::iterator& end) + { + return make_value(read_time(it, end)); + } + + std::shared_ptr<base> parse_date(std::string::iterator& it, + const std::string::iterator& end) + { + auto date_end = find_end_of_date(it, end); + + auto eat = make_consumer( + it, date_end, [&]() { throw_parse_exception("Malformed 
date"); }); + + local_date ldate; + ldate.year = eat.eat_digits(4); + eat('-'); + ldate.month = eat.eat_digits(2); + eat('-'); + ldate.day = eat.eat_digits(2); + + if (it == date_end) + return make_value(ldate); + + eat('T'); + + local_datetime ldt; + static_cast<local_date&>(ldt) = ldate; + static_cast<local_time&>(ldt) = read_time(it, date_end); + + if (it == date_end) + return make_value(ldt); + + offset_datetime dt; + static_cast<local_datetime&>(dt) = ldt; + + int hoff = 0; + int moff = 0; + if (*it == '+' || *it == '-') + { + auto plus = *it == '+'; + ++it; + + hoff = eat.eat_digits(2); + dt.hour_offset = (plus) ? hoff : -hoff; + eat(':'); + moff = eat.eat_digits(2); + dt.minute_offset = (plus) ? moff : -moff; + } + else if (*it == 'Z') + { + ++it; + } + + if (it != date_end) + throw_parse_exception("Malformed date"); + + return make_value(dt); + } + + std::shared_ptr<base> parse_array(std::string::iterator& it, + std::string::iterator& end) + { + // this gets ugly because of the "homogeneity" restriction: + // arrays can either be of only one type, or contain arrays + // (each of those arrays could be of different types, though) + // + // because of the latter portion, we don't really have a choice + // but to represent them as arrays of base values... + ++it; + + // ugh---have to read the first value to determine array type... 
+ skip_whitespace_and_comments(it, end); + + // edge case---empty array + if (*it == ']') + { + ++it; + return make_array(); + } + + auto val_end = std::find_if( + it, end, [](char c) { return c == ',' || c == ']' || c == '#'; }); + parse_type type = determine_value_type(it, val_end); + switch (type) + { + case parse_type::STRING: + return parse_value_array<std::string>(it, end); + case parse_type::LOCAL_TIME: + return parse_value_array<local_time>(it, end); + case parse_type::LOCAL_DATE: + return parse_value_array<local_date>(it, end); + case parse_type::LOCAL_DATETIME: + return parse_value_array<local_datetime>(it, end); + case parse_type::OFFSET_DATETIME: + return parse_value_array<offset_datetime>(it, end); + case parse_type::INT: + return parse_value_array<int64_t>(it, end); + case parse_type::FLOAT: + return parse_value_array<double>(it, end); + case parse_type::BOOL: + return parse_value_array<bool>(it, end); + case parse_type::ARRAY: + return parse_object_array<array>(&parser::parse_array, '[', it, + end); + case parse_type::INLINE_TABLE: + return parse_object_array<table_array>( + &parser::parse_inline_table, '{', it, end); + default: + throw_parse_exception("Unable to parse array"); + } + } + + template <class Value> + std::shared_ptr<array> parse_value_array(std::string::iterator& it, + std::string::iterator& end) + { + auto arr = make_array(); + while (it != end && *it != ']') + { + auto value = parse_value(it, end); + if (auto v = value->as<Value>()) + arr->get().push_back(value); + else + throw_parse_exception("Arrays must be homogeneous"); + skip_whitespace_and_comments(it, end); + if (*it != ',') + break; + ++it; + skip_whitespace_and_comments(it, end); + } + if (it != end) + ++it; + return arr; + } + + template <class Object, class Function> + std::shared_ptr<Object> parse_object_array(Function&& fun, char delim, + std::string::iterator& it, + std::string::iterator& end) + { + auto arr = make_element<Object>(); + + while (it != end && *it != ']') + 
{ + if (*it != delim) + throw_parse_exception("Unexpected character in array"); + + arr->get().push_back(((*this).*fun)(it, end)); + skip_whitespace_and_comments(it, end); + + if (*it != ',') + break; + + ++it; + skip_whitespace_and_comments(it, end); + } + + if (it == end || *it != ']') + throw_parse_exception("Unterminated array"); + + ++it; + return arr; + } + + std::shared_ptr<table> parse_inline_table(std::string::iterator& it, + std::string::iterator& end) + { + auto tbl = make_table(); + do + { + ++it; + if (it == end) + throw_parse_exception("Unterminated inline table"); + + consume_whitespace(it, end); + parse_key_value(it, end, tbl.get()); + consume_whitespace(it, end); + } while (*it == ','); + + if (it == end || *it != '}') + throw_parse_exception("Unterminated inline table"); + + ++it; + consume_whitespace(it, end); + + return tbl; + } + + void skip_whitespace_and_comments(std::string::iterator& start, + std::string::iterator& end) + { + consume_whitespace(start, end); + while (start == end || *start == '#') + { + if (!detail::getline(input_, line_)) + throw_parse_exception("Unclosed array"); + line_number_++; + start = line_.begin(); + end = line_.end(); + consume_whitespace(start, end); + } + } + + void consume_whitespace(std::string::iterator& it, + const std::string::iterator& end) + { + while (it != end && (*it == ' ' || *it == '\t')) + ++it; + } + + void consume_backwards_whitespace(std::string::iterator& back, + const std::string::iterator& front) + { + while (back != front && (*back == ' ' || *back == '\t')) + --back; + } + + void eol_or_comment(const std::string::iterator& it, + const std::string::iterator& end) + { + if (it != end && *it != '#') + throw_parse_exception("Unidentified trailing character '" + + std::string{*it} + + "'---did you forget a '#'?"); + } + + bool is_time(const std::string::iterator& it, + const std::string::iterator& end) + { + auto time_end = find_end_of_time(it, end); + auto len = std::distance(it, time_end); + + if 
(len < 8) + return false; + + if (it[2] != ':' || it[5] != ':') + return false; + + if (len > 8) + return it[8] == '.' && len > 9; + + return true; + } + + option<parse_type> date_type(const std::string::iterator& it, + const std::string::iterator& end) + { + auto date_end = find_end_of_date(it, end); + auto len = std::distance(it, date_end); + + if (len < 10) + return {}; + + if (it[4] != '-' || it[7] != '-') + return {}; + + if (len >= 19 && it[10] == 'T' && is_time(it + 11, date_end)) + { + // datetime type + auto time_end = find_end_of_time(it + 11, date_end); + if (time_end == date_end) + return {parse_type::LOCAL_DATETIME}; + else + return {parse_type::OFFSET_DATETIME}; + } + else if (len == 10) + { + // just a regular date + return {parse_type::LOCAL_DATE}; + } + + return {}; + } + + std::istream& input_; + std::string line_; + std::size_t line_number_ = 0; +}; + +/** + * Utility function to parse a file as a TOML file. Returns the root table. + * Throws a parse_exception if the file cannot be opened. + */ +inline std::shared_ptr<table> parse_file(const std::string& filename) +{ +#if defined(BOOST_NOWIDE_FSTREAM_INCLUDED_HPP) + boost::nowide::ifstream file{filename.c_str()}; +#elif defined(NOWIDE_FSTREAM_INCLUDED_HPP) + nowide::ifstream file{filename.c_str()}; +#else + std::ifstream file{filename}; +#endif + if (!file.is_open()) + throw parse_exception{filename + " could not be opened for parsing"}; + parser p{file}; + return p.parse(); +} + +template <class... Ts> +struct value_accept; + +template <> +struct value_accept<> +{ + template <class Visitor, class... Args> + static void accept(const base&, Visitor&&, Args&&...) + { + // nothing + } +}; + +template <class T, class... Ts> +struct value_accept<T, Ts...> +{ + template <class Visitor, class... Args> + static void accept(const base& b, Visitor&& visitor, Args&&... 
args) + { + if (auto v = b.as<T>()) + { + visitor.visit(*v, std::forward<Args>(args)...); + } + else + { + value_accept<Ts...>::accept(b, std::forward<Visitor>(visitor), + std::forward<Args>(args)...); + } + } +}; + +/** + * base implementation of accept() that calls visitor.visit() on the concrete + * class. + */ +template <class Visitor, class... Args> +void base::accept(Visitor&& visitor, Args&&... args) const +{ + if (is_value()) + { + using value_acceptor + = value_accept<std::string, int64_t, double, bool, local_date, + local_time, local_datetime, offset_datetime>; + value_acceptor::accept(*this, std::forward<Visitor>(visitor), + std::forward<Args>(args)...); + } + else if (is_table()) + { + visitor.visit(static_cast<const table&>(*this), + std::forward<Args>(args)...); + } + else if (is_array()) + { + visitor.visit(static_cast<const array&>(*this), + std::forward<Args>(args)...); + } + else if (is_table_array()) + { + visitor.visit(static_cast<const table_array&>(*this), + std::forward<Args>(args)...); + } +} + +/** + * Writer that can be passed to accept() functions of cpptoml objects and + * will output valid TOML to a stream. + */ +class toml_writer +{ + public: + /** + * Construct a toml_writer that will write to the given stream + */ + toml_writer(std::ostream& s, const std::string& indent_space = "\t") + : stream_(s), indent_(indent_space), has_naked_endline_(false) + { + // nothing + } + + public: + /** + * Output a base value of the TOML tree. 
+     */
+    template <class T>
+    void visit(const value<T>& v, bool = false)
+    {
+        // The unnamed flag is unused for scalar values.
+        write(v);
+    }
+
+    /**
+     * Output a table element of the TOML tree.
+     *
+     * The table header is written first, then every entry that is neither
+     * a table nor a table array, and finally the nested tables and table
+     * arrays. Entries are separated from one another by endline().
+     *
+     * @param t the table to write
+     * @param in_array forwarded unchanged to write_table_header()
+     */
+    void visit(const table& t, bool in_array = false)
+    {
+        write_table_header(in_array);
+
+        // Split the keys into the two groups that are emitted in turn.
+        std::vector<std::string> value_keys;
+        std::vector<std::string> table_keys;
+        for (const auto& entry : t)
+        {
+            const bool nested = entry.second->is_table()
+                                || entry.second->is_table_array();
+            (nested ? table_keys : value_keys).push_back(entry.first);
+        }
+
+        // True once at least one entry has been emitted; every later entry
+        // is preceded by an endline().
+        bool need_separator = false;
+
+        for (const auto& key : value_keys)
+        {
+            path_.push_back(key);
+
+            if (need_separator)
+                endline();
+            need_separator = true;
+
+            const auto item = t.get(key);
+            write_table_item_header(*item);
+            item->accept(*this, false);
+            path_.pop_back();
+        }
+
+        for (const auto& key : table_keys)
+        {
+            path_.push_back(key);
+
+            if (need_separator)
+                endline();
+            need_separator = true;
+
+            const auto item = t.get(key);
+            write_table_item_header(*item);
+            item->accept(*this, false);
+            path_.pop_back();
+        }
+
+        endline();
+    }
+
+    /**
+     * Output an array element of the TOML tree as a bracketed list whose
+     * elements are separated by ", ". Every element is visited with the
+     * flag argument set to true.
+     */
+    void visit(const array& a, bool = false)
+    {
+        write("[");
+
+        bool first = true;
+        for (const auto& elem : a.get())
+        {
+            if (!first)
+                write(", ");
+            first = false;
+
+            if (elem->is_array())
+            {
+                elem->as_array()->accept(*this, true);
+            }
+            else
+            {
+                elem->accept(*this, true);
+            }
+        }
+
+        write("]");
+    }
+
+    /**
+     * Output a table_array element of the TOML tree. Each contained table
+     * is visited with the flag argument set to true; consecutive tables are
+     * separated by endline(), and a final endline() follows the last one.
+     */
+    void visit(const table_array& t, bool = false)
+    {
+        bool first = true;
+        for (const auto& tbl : t.get())
+        {
+            if (!first)
+                endline();
+            first = false;
+
+            tbl->accept(*this, true);
+        }
+
+        endline();
+    }
+
+    /**
+     * Escape a string for output.
+ */ + static std::string escape_string(const std::string& str) + { + std::string res; + for (auto it = str.begin(); it != str.end(); ++it) + { + if (*it == '\b') + { + res += "\\b"; + } + else if (*it == '\t') + { + res += "\\t"; + } + else if (*it == '\n') + { + res += "\\n"; + } + else if (*it == '\f') + { + res += "\\f"; + } + else if (*it == '\r') + { + res += "\\r"; + } + else if (*it == '"') + { + res += "\\\""; + } + else if (*it == '\\') + { + res += "\\\\"; + } + else if ((const uint32_t)*it <= 0x001f) + { + res += "\\u"; + std::stringstream ss; + ss << std::hex << static_cast<uint32_t>(*it); + res += ss.str(); + } + else + { + res += *it; + } + } + return res; + } + + protected: + /** + * Write out a string. + */ + void write(const value<std::string>& v) + { + write("\""); + write(escape_string(v.get())); + write("\""); + } + + /** + * Write out a double. + */ + void write(const value<double>& v) + { + std::ios::fmtflags flags{stream_.flags()}; + + stream_ << std::showpoint; + write(v.get()); + + stream_.flags(flags); + } + + /** + * Write out an integer, local_date, local_time, local_datetime, or + * offset_datetime. + */ + template <class T> + typename std::enable_if<is_one_of<T, int64_t, local_date, local_time, + local_datetime, + offset_datetime>::value>::type + write(const value<T>& v) + { + write(v.get()); + } + + /** + * Write out a boolean. + */ + void write(const value<bool>& v) + { + write((v.get() ? "true" : "false")); + } + + /** + * Write out the header of a table. 
+ */ + void write_table_header(bool in_array = false) + { + if (!path_.empty()) + { + indent(); + + write("["); + + if (in_array) + { + write("["); + } + + for (unsigned int i = 0; i < path_.size(); ++i) + { + if (i > 0) + { + write("."); + } + + if (path_[i].find_first_not_of("ABCDEFGHIJKLMNOPQRSTUVWXYZabcde" + "fghijklmnopqrstuvwxyz0123456789" + "_-") + == std::string::npos) + { + write(path_[i]); + } + else + { + write("\""); + write(escape_string(path_[i])); + write("\""); + } + } + + if (in_array) + { + write("]"); + } + + write("]"); + endline(); + } + } + + /** + * Write out the identifier for an item in a table. + */ + void write_table_item_header(const base& b) + { + if (!b.is_table() && !b.is_table_array()) + { + indent(); + + if (path_.back().find_first_not_of("ABCDEFGHIJKLMNOPQRSTUVWXYZabcde" + "fghijklmnopqrstuvwxyz0123456789" + "_-") + == std::string::npos) + { + write(path_.back()); + } + else + { + write("\""); + write(escape_string(path_.back())); + write("\""); + } + + write(" = "); + } + } + + private: + /** + * Indent the proper number of tabs given the size of + * the path. + */ + void indent() + { + for (std::size_t i = 1; i < path_.size(); ++i) + write(indent_); + } + + /** + * Write a value out to the stream. 
+ */ + template <class T> + void write(const T& v) + { + stream_ << v; + has_naked_endline_ = false; + } + + /** + * Write an endline out to the stream + */ + void endline() + { + if (!has_naked_endline_) + { + stream_ << "\n"; + has_naked_endline_ = true; + } + } + + private: + std::ostream& stream_; + const std::string indent_; + std::vector<std::string> path_; + bool has_naked_endline_; +}; + +inline std::ostream& operator<<(std::ostream& stream, const base& b) +{ + toml_writer writer{stream}; + b.accept(writer); + return stream; +} + +template <class T> +std::ostream& operator<<(std::ostream& stream, const value<T>& v) +{ + toml_writer writer{stream}; + v.accept(writer); + return stream; +} + +inline std::ostream& operator<<(std::ostream& stream, const table& t) +{ + toml_writer writer{stream}; + t.accept(writer); + return stream; +} + +inline std::ostream& operator<<(std::ostream& stream, const table_array& t) +{ + toml_writer writer{stream}; + t.accept(writer); + return stream; +} + +inline std::ostream& operator<<(std::ostream& stream, const array& a) +{ + toml_writer writer{stream}; + a.accept(writer); + return stream; +} +} +#endif diff --git a/src/libexpr/eval.cc b/src/libexpr/eval.cc index e09297546c95..f41905787f9e 100644 --- a/src/libexpr/eval.cc +++ b/src/libexpr/eval.cc @@ -349,19 +349,25 @@ Path EvalState::checkSourcePath(const Path & path_) bool found = false; + /* First canonicalize the path without symlinks, so we make sure an + * attacker can't append ../../... to a path that would be in allowedPaths + * and thus leak symlink targets. + */ + Path abspath = canonPath(path_); + for (auto & i : *allowedPaths) { - if (isDirOrInDir(path_, i)) { + if (isDirOrInDir(abspath, i)) { found = true; break; } } if (!found) - throw RestrictedPathError("access to path '%1%' is forbidden in restricted mode", path_); + throw RestrictedPathError("access to path '%1%' is forbidden in restricted mode", abspath); /* Resolve symlinks. 
*/ - debug(format("checking access to '%s'") % path_); - Path path = canonPath(path_, true); + debug(format("checking access to '%s'") % abspath); + Path path = canonPath(abspath, true); for (auto & i : *allowedPaths) { if (isDirOrInDir(path, i)) { @@ -1076,6 +1082,8 @@ void EvalState::callPrimOp(Value & fun, Value & arg, Value & v, const Pos & pos) void EvalState::callFunction(Value & fun, Value & arg, Value & v, const Pos & pos) { + forceValue(fun, pos); + if (fun.type == tPrimOp || fun.type == tPrimOpApp) { callPrimOp(fun, arg, v, pos); return; @@ -1091,10 +1099,8 @@ void EvalState::callFunction(Value & fun, Value & arg, Value & v, const Pos & po auto & fun2 = *allocValue(); fun2 = fun; /* !!! Should we use the attr pos here? */ - forceValue(*found->value, pos); Value v2; callFunction(*found->value, fun2, v2, pos); - forceValue(v2, pos); return callFunction(v2, arg, v, pos); } } @@ -1181,7 +1187,6 @@ void EvalState::autoCallFunction(Bindings & args, Value & fun, Value & res) if (fun.type == tAttrs) { auto found = fun.attrs->find(sFunctor); if (found != fun.attrs->end()) { - forceValue(*found->value); Value * v = allocValue(); callFunction(*found->value, fun, *v, noPos); forceValue(*v); @@ -1565,7 +1570,6 @@ string EvalState::coerceToString(const Pos & pos, Value & v, PathSet & context, if (v.type == tAttrs) { auto i = v.attrs->find(sToString); if (i != v.attrs->end()) { - forceValue(*i->value, pos); Value v1; callFunction(*i->value, v, v1, pos); return coerceToString(pos, v1, context, coerceMore, copyToStore); @@ -1720,20 +1724,6 @@ bool EvalState::eqValues(Value & v1, Value & v2) } -void EvalState::printStats2() -{ - struct rusage ru; - getrusage(RUSAGE_SELF, &ru); - - GC_prof_stats_s gc; - GC_get_prof_stats(&gc, sizeof(gc)); - - printError("STATS %d %d %d %d %d %d", - nrValues, nrValuesFreed.load(), nrValues - nrValuesFreed, - ru.ru_maxrss, - gc.heapsize_full, gc.free_bytes_full); -} - void EvalState::printStats() { bool showStats = getEnv("NIX_SHOW_STATS", 
"0") != "0"; diff --git a/src/libexpr/eval.hh b/src/libexpr/eval.hh index 46bda86d084e..d0f298e168e9 100644 --- a/src/libexpr/eval.hh +++ b/src/libexpr/eval.hh @@ -276,7 +276,6 @@ public: /* Print statistics. */ void printStats(); - void printStats2(); void realiseContext(const PathSet & context); diff --git a/src/libexpr/lexer.l b/src/libexpr/lexer.l index 29ca327c1e4e..a052447d3dce 100644 --- a/src/libexpr/lexer.l +++ b/src/libexpr/lexer.l @@ -12,6 +12,8 @@ %{ +#include <boost/lexical_cast.hpp> + #include "nixexpr.hh" #include "parser-tab.hh" @@ -124,9 +126,11 @@ or { return OR_KW; } {ID} { yylval->id = strdup(yytext); return ID; } {INT} { errno = 0; - yylval->n = strtol(yytext, 0, 10); - if (errno != 0) + try { + yylval->n = boost::lexical_cast<int64_t>(yytext); + } catch (const boost::bad_lexical_cast &) { throw ParseError(format("invalid integer '%1%'") % yytext); + } return INT; } {FLOAT} { errno = 0; diff --git a/src/libexpr/parser.y b/src/libexpr/parser.y index eee48887dc22..cbd576d7d126 100644 --- a/src/libexpr/parser.y +++ b/src/libexpr/parser.y @@ -273,11 +273,11 @@ void yyerror(YYLTYPE * loc, yyscan_t scanner, ParseData * data, const char * err %token IND_STRING_OPEN IND_STRING_CLOSE %token ELLIPSIS -%nonassoc IMPL +%right IMPL %left OR %left AND %nonassoc EQ NEQ -%left '<' '>' LEQ GEQ +%nonassoc '<' '>' LEQ GEQ %right UPDATE %left NOT %left '+' '-' diff --git a/src/libexpr/primops.cc b/src/libexpr/primops.cc index 3a6c4035b8b8..6f82c6c404f2 100644 --- a/src/libexpr/primops.cc +++ b/src/libexpr/primops.cc @@ -1356,6 +1356,24 @@ static void prim_functionArgs(EvalState & state, const Pos & pos, Value * * args } +/* Apply a function to every element of an attribute set. 
*/ +static void prim_mapAttrs(EvalState & state, const Pos & pos, Value * * args, Value & v) +{ + state.forceAttrs(*args[1], pos); + + state.mkAttrs(v, args[1]->attrs->size()); + + for (auto & i : *args[1]->attrs) { + Value * vName = state.allocValue(); + Value * vFun2 = state.allocValue(); + mkString(*vName, i.name); + mkApp(*vFun2, *args[0], *vName); + mkApp(*state.allocAttr(v, i.name), *vFun2, *i.value); + } +} + + + /************************************************************* * Lists *************************************************************/ @@ -1410,7 +1428,6 @@ static void prim_tail(EvalState & state, const Pos & pos, Value * * args, Value /* Apply a function to every element of a list. */ static void prim_map(EvalState & state, const Pos & pos, Value * * args, Value & v) { - state.forceFunction(*args[0], pos); state.forceList(*args[1], pos); state.mkList(v, args[1]->listSize()); @@ -1489,19 +1506,20 @@ static void prim_foldlStrict(EvalState & state, const Pos & pos, Value * * args, state.forceFunction(*args[0], pos); state.forceList(*args[2], pos); - Value * vCur = args[1]; + if (args[2]->listSize()) { + Value * vCur = args[1]; - if (args[2]->listSize()) for (unsigned int n = 0; n < args[2]->listSize(); ++n) { Value vTmp; state.callFunction(*args[0], *vCur, vTmp, pos); vCur = n == args[2]->listSize() - 1 ? 
&v : state.allocValue(); state.callFunction(vTmp, *args[2]->listElems()[n], *vCur, pos); } - else - v = *vCur; - - state.forceValue(v); + state.forceValue(v); + } else { + state.forceValue(*args[1]); + v = *args[1]; + } } @@ -1538,7 +1556,6 @@ static void prim_all(EvalState & state, const Pos & pos, Value * * args, Value & static void prim_genList(EvalState & state, const Pos & pos, Value * * args, Value & v) { - state.forceFunction(*args[0], pos); auto len = state.forceInt(*args[1], pos); if (len < 0) @@ -1627,6 +1644,35 @@ static void prim_partition(EvalState & state, const Pos & pos, Value * * args, V } +/* concatMap = f: list: concatLists (map f list); */ +/* C++-version is to avoid allocating `mkApp', call `f' eagerly */ +static void prim_concatMap(EvalState & state, const Pos & pos, Value * * args, Value & v) +{ + state.forceFunction(*args[0], pos); + state.forceList(*args[1], pos); + auto nrLists = args[1]->listSize(); + + Value lists[nrLists]; + size_t len = 0; + + for (unsigned int n = 0; n < nrLists; ++n) { + Value * vElem = args[1]->listElems()[n]; + state.callFunction(*args[0], *vElem, lists[n], pos); + state.forceList(lists[n], pos); + len += lists[n].listSize(); + } + + state.mkList(v, len); + auto out = v.listElems(); + for (unsigned int n = 0, pos = 0; n < nrLists; ++n) { + auto l = lists[n].listSize(); + if (l) + memcpy(out + pos, lists[n].listElems(), l * sizeof(Value *)); + pos += l; + } +} + + /************************************************************* * Integer arithmetic *************************************************************/ @@ -1634,6 +1680,8 @@ static void prim_partition(EvalState & state, const Pos & pos, Value * * args, V static void prim_add(EvalState & state, const Pos & pos, Value * * args, Value & v) { + state.forceValue(*args[0], pos); + state.forceValue(*args[1], pos); if (args[0]->type == tFloat || args[1]->type == tFloat) mkFloat(v, state.forceFloat(*args[0], pos) + state.forceFloat(*args[1], pos)); else @@ -1643,6 
+1691,8 @@ static void prim_add(EvalState & state, const Pos & pos, Value * * args, Value & static void prim_sub(EvalState & state, const Pos & pos, Value * * args, Value & v) { + state.forceValue(*args[0], pos); + state.forceValue(*args[1], pos); if (args[0]->type == tFloat || args[1]->type == tFloat) mkFloat(v, state.forceFloat(*args[0], pos) - state.forceFloat(*args[1], pos)); else @@ -1652,6 +1702,8 @@ static void prim_sub(EvalState & state, const Pos & pos, Value * * args, Value & static void prim_mul(EvalState & state, const Pos & pos, Value * * args, Value & v) { + state.forceValue(*args[0], pos); + state.forceValue(*args[1], pos); if (args[0]->type == tFloat || args[1]->type == tFloat) mkFloat(v, state.forceFloat(*args[0], pos) * state.forceFloat(*args[1], pos)); else @@ -1661,6 +1713,9 @@ static void prim_mul(EvalState & state, const Pos & pos, Value * * args, Value & static void prim_div(EvalState & state, const Pos & pos, Value * * args, Value & v) { + state.forceValue(*args[0], pos); + state.forceValue(*args[1], pos); + NixFloat f2 = state.forceFloat(*args[1], pos); if (f2 == 0) throw EvalError(format("division by zero, at %1%") % pos); @@ -2212,6 +2267,7 @@ void EvalState::createBaseEnv() addPrimOp("__intersectAttrs", 2, prim_intersectAttrs); addPrimOp("__catAttrs", 2, prim_catAttrs); addPrimOp("__functionArgs", 1, prim_functionArgs); + addPrimOp("__mapAttrs", 2, prim_mapAttrs); // Lists addPrimOp("__isList", 1, prim_isList); @@ -2229,6 +2285,7 @@ void EvalState::createBaseEnv() addPrimOp("__genList", 2, prim_genList); addPrimOp("__sort", 2, prim_sort); addPrimOp("__partition", 2, prim_partition); + addPrimOp("__concatMap", 2, prim_concatMap); // Integer arithmetic addPrimOp("__add", 2, prim_add); diff --git a/src/libexpr/primops/fetchGit.cc b/src/libexpr/primops/fetchGit.cc index 7aa98e0bfab3..0c6539959bf6 100644 --- a/src/libexpr/primops/fetchGit.cc +++ b/src/libexpr/primops/fetchGit.cc @@ -219,8 +219,6 @@ static void prim_fetchGit(EvalState & state, 
const Pos & pos, Value * * args, Va } else url = state.coerceToString(pos, *args[0], context, false, false); - if (!isUri(url)) url = absPath(url); - // FIXME: git externals probably can be used to bypass the URI // whitelist. Ah well. state.checkURI(url); diff --git a/src/libexpr/primops/fetchMercurial.cc b/src/libexpr/primops/fetchMercurial.cc index 9d35f6d0d6d7..97cda2458c9b 100644 --- a/src/libexpr/primops/fetchMercurial.cc +++ b/src/libexpr/primops/fetchMercurial.cc @@ -184,8 +184,6 @@ static void prim_fetchMercurial(EvalState & state, const Pos & pos, Value * * ar } else url = state.coerceToString(pos, *args[0], context, false, false); - if (!isUri(url)) url = absPath(url); - // FIXME: git externals probably can be used to bypass the URI // whitelist. Ah well. state.checkURI(url); diff --git a/src/libexpr/primops/fromTOML.cc b/src/libexpr/primops/fromTOML.cc new file mode 100644 index 000000000000..4128de05d0cf --- /dev/null +++ b/src/libexpr/primops/fromTOML.cc @@ -0,0 +1,77 @@ +#include "primops.hh" +#include "eval-inline.hh" + +#include "cpptoml/cpptoml.h" + +namespace nix { + +static void prim_fromTOML(EvalState & state, const Pos & pos, Value * * args, Value & v) +{ + using namespace cpptoml; + + auto toml = state.forceStringNoCtx(*args[0], pos); + + std::istringstream tomlStream(toml); + + std::function<void(Value &, std::shared_ptr<base>)> visit; + + visit = [&](Value & v, std::shared_ptr<base> t) { + + if (auto t2 = t->as_table()) { + + size_t size = 0; + for (auto & i : *t2) { (void) i; size++; } + + state.mkAttrs(v, size); + + for (auto & i : *t2) { + auto & v2 = *state.allocAttr(v, state.symbols.create(i.first)); + + if (auto i2 = i.second->as_table_array()) { + size_t size2 = i2->get().size(); + state.mkList(v2, size2); + for (size_t j = 0; j < size2; ++j) + visit(*(v2.listElems()[j] = state.allocValue()), i2->get()[j]); + } + else + visit(v2, i.second); + } + + v.attrs->sort(); + } + + else if (auto t2 = t->as_array()) { + size_t size = 
t2->get().size(); + + state.mkList(v, size); + + for (size_t i = 0; i < size; ++i) + visit(*(v.listElems()[i] = state.allocValue()), t2->get()[i]); + } + + else if (t->is_value()) { + if (auto val = t->as<int64_t>()) + mkInt(v, val->get()); + else if (auto val = t->as<NixFloat>()) + mkFloat(v, val->get()); + else if (auto val = t->as<bool>()) + mkBool(v, val->get()); + else if (auto val = t->as<std::string>()) + mkString(v, val->get()); + else + throw EvalError("unsupported value type in TOML"); + } + + else abort(); + }; + + try { + visit(v, parser(tomlStream).parse()); + } catch (std::runtime_error & e) { + throw EvalError("while parsing a TOML string at %s: %s", pos, e.what()); + } +} + +static RegisterPrimOp r("fromTOML", 1, prim_fromTOML); + +} diff --git a/src/libexpr/value.hh b/src/libexpr/value.hh index 66b41a158400..e1ec87d3b84c 100644 --- a/src/libexpr/value.hh +++ b/src/libexpr/value.hh @@ -43,8 +43,8 @@ class XMLWriter; class JSONPlaceholder; -typedef long NixInt; -typedef float NixFloat; +typedef int64_t NixInt; +typedef double NixFloat; /* External values must descend from ExternalValueBase, so that * type-agnostic nix functions (e.g. 
showType) can be implemented diff --git a/src/libstore/binary-cache-store.cc b/src/libstore/binary-cache-store.cc index 76c0a1a891b8..4527ee6ba660 100644 --- a/src/libstore/binary-cache-store.cc +++ b/src/libstore/binary-cache-store.cc @@ -217,17 +217,6 @@ void BinaryCacheStore::narFromPath(const Path & storePath, Sink & sink) { auto info = queryPathInfo(storePath).cast<const NarInfo>(); - auto source = sinkToSource([this, url{info->url}](Sink & sink) { - try { - getFile(url, sink); - } catch (NoSuchBinaryCacheFile & e) { - throw SubstituteGone(e.what()); - } - }); - - stats.narRead++; - //stats.narReadCompressedBytes += nar->size(); // FIXME - uint64_t narSize = 0; LambdaSink wrapperSink([&](const unsigned char * data, size_t len) { @@ -235,8 +224,18 @@ void BinaryCacheStore::narFromPath(const Path & storePath, Sink & sink) narSize += len; }); - decompress(info->compression, *source, wrapperSink); + auto decompressor = makeDecompressionSink(info->compression, wrapperSink); + try { + getFile(info->url, *decompressor); + } catch (NoSuchBinaryCacheFile & e) { + throw SubstituteGone(e.what()); + } + + decompressor->finish(); + + stats.narRead++; + //stats.narReadCompressedBytes += nar->size(); // FIXME stats.narReadBytes += narSize; } diff --git a/src/libstore/build.cc b/src/libstore/build.cc index d75ca0be86ef..cd37f7a3fc08 100644 --- a/src/libstore/build.cc +++ b/src/libstore/build.cc @@ -2007,7 +2007,7 @@ void DerivationGoal::startBuilder() /* Create /etc/hosts with localhost entry. */ if (!fixedOutput) - writeFile(chrootRootDir + "/etc/hosts", "127.0.0.1 localhost\n"); + writeFile(chrootRootDir + "/etc/hosts", "127.0.0.1 localhost\n::1 localhost\n"); /* Make the closure of the inputs available in the chroot, rather than the whole Nix store. 
This prevents any access diff --git a/src/libstore/builtins/fetchurl.cc b/src/libstore/builtins/fetchurl.cc index 1f4abd374f54..b4dcb35f951a 100644 --- a/src/libstore/builtins/fetchurl.cc +++ b/src/libstore/builtins/fetchurl.cc @@ -39,21 +39,16 @@ void builtinFetchurl(const BasicDerivation & drv, const std::string & netrcData) request.verifyTLS = false; request.decompress = false; - downloader->download(std::move(request), sink); + auto decompressor = makeDecompressionSink( + hasSuffix(mainUrl, ".xz") ? "xz" : "none", sink); + downloader->download(std::move(request), *decompressor); + decompressor->finish(); }); - if (get(drv.env, "unpack", "") == "1") { - - if (hasSuffix(mainUrl, ".xz")) { - auto source2 = sinkToSource([&](Sink & sink) { - decompress("xz", *source, sink); - }); - restorePath(storePath, *source2); - } else - restorePath(storePath, *source); - - } else - writeFile(storePath, *source); + if (get(drv.env, "unpack", "") == "1") + restorePath(storePath, *source); + else + writeFile(storePath, *source); auto executable = drv.env.find("executable"); if (executable != drv.env.end() && executable->second == "1") { diff --git a/src/libstore/download.cc b/src/libstore/download.cc index 7a5deed32143..973fca0b130f 100644 --- a/src/libstore/download.cc +++ b/src/libstore/download.cc @@ -58,16 +58,6 @@ std::string resolveUri(const std::string & uri) return uri; } -ref<std::string> decodeContent(const std::string & encoding, ref<std::string> data) -{ - if (encoding == "") - return data; - else if (encoding == "br") - return decompress(encoding, *data); - else - throw Error("unsupported Content-Encoding '%s'", encoding); -} - struct CurlDownloader : public Downloader { CURLM * curlm = 0; @@ -106,6 +96,12 @@ struct CurlDownloader : public Downloader fmt(request.data ? 
"uploading '%s'" : "downloading '%s'", request.uri), {request.uri}, request.parentAct) , callback(callback) + , finalSink([this](const unsigned char * data, size_t len) { + if (this->request.dataCallback) + this->request.dataCallback((char *) data, len); + else + this->result.data->append((char *) data, len); + }) { if (!request.expectedETag.empty()) requestHeaders = curl_slist_append(requestHeaders, ("If-None-Match: " + request.expectedETag).c_str()); @@ -129,22 +125,40 @@ struct CurlDownloader : public Downloader } } - template<class T> - void fail(const T & e) + void failEx(std::exception_ptr ex) { assert(!done); done = true; - callback.rethrow(std::make_exception_ptr(e)); + callback.rethrow(ex); } + template<class T> + void fail(const T & e) + { + failEx(std::make_exception_ptr(e)); + } + + LambdaSink finalSink; + std::shared_ptr<CompressionSink> decompressionSink; + + std::exception_ptr writeException; + size_t writeCallback(void * contents, size_t size, size_t nmemb) { - size_t realSize = size * nmemb; - if (request.dataCallback) - request.dataCallback((char *) contents, realSize); - else - result.data->append((char *) contents, realSize); - return realSize; + try { + size_t realSize = size * nmemb; + result.bodySize += realSize; + + if (!decompressionSink) + decompressionSink = makeDecompressionSink(encoding, finalSink); + + (*decompressionSink)((unsigned char *) contents, realSize); + + return realSize; + } catch (...) { + writeException = std::current_exception(); + return 0; + } } static size_t writeCallbackWrapper(void * contents, size_t size, size_t nmemb, void * userp) @@ -162,6 +176,7 @@ struct CurlDownloader : public Downloader auto ss = tokenizeString<vector<string>>(line, " "); status = ss.size() >= 2 ? 
ss[1] : ""; result.data = std::make_shared<std::string>(); + result.bodySize = 0; encoding = ""; } else { auto i = line.find(':'); @@ -244,6 +259,7 @@ struct CurlDownloader : public Downloader curl_easy_setopt(req, CURLOPT_URL, request.uri.c_str()); curl_easy_setopt(req, CURLOPT_FOLLOWLOCATION, 1L); + curl_easy_setopt(req, CURLOPT_MAXREDIRS, 10); curl_easy_setopt(req, CURLOPT_NOSIGNAL, 1); curl_easy_setopt(req, CURLOPT_USERAGENT, ("curl/" LIBCURL_VERSION " Nix/" + nixVersion + @@ -295,6 +311,7 @@ struct CurlDownloader : public Downloader curl_easy_setopt(req, CURLOPT_NETRC, CURL_NETRC_OPTIONAL); result.data = std::make_shared<std::string>(); + result.bodySize = 0; } void finish(CURLcode code) @@ -308,29 +325,35 @@ struct CurlDownloader : public Downloader result.effectiveUrl = effectiveUrlCStr; debug("finished %s of '%s'; curl status = %d, HTTP status = %d, body = %d bytes", - request.verb(), request.uri, code, httpStatus, result.data ? result.data->size() : 0); + request.verb(), request.uri, code, httpStatus, result.bodySize); + + if (decompressionSink) + decompressionSink->finish(); if (code == CURLE_WRITE_ERROR && result.etag == request.expectedETag) { code = CURLE_OK; httpStatus = 304; } - if (code == CURLE_OK && + if (writeException) + failEx(writeException); + + else if (code == CURLE_OK && (httpStatus == 200 || httpStatus == 201 || httpStatus == 204 || httpStatus == 304 || httpStatus == 226 /* FTP */ || httpStatus == 0 /* other protocol */)) { result.cached = httpStatus == 304; done = true; try { - if (request.decompress) - result.data = decodeContent(encoding, ref<std::string>(result.data)); act.progress(result.data->size(), result.data->size()); callback(std::move(result)); } catch (...) 
{ done = true; callback.rethrow(); } - } else { + } + + else { // We treat most errors as transient, but won't retry when hopeless Error err = Transient; @@ -364,6 +387,8 @@ struct CurlDownloader : public Downloader case CURLE_INTERFACE_FAILED: case CURLE_UNKNOWN_OPTION: case CURLE_SSL_CACERT_BADFILE: + case CURLE_TOO_MANY_REDIRECTS: + case CURLE_WRITE_ERROR: err = Misc; break; default: // Shut up warnings @@ -596,7 +621,7 @@ struct CurlDownloader : public Downloader // FIXME: do this on a worker thread try { #ifdef ENABLE_S3 - S3Helper s3Helper("", Aws::Region::US_EAST_1); // FIXME: make configurable + S3Helper s3Helper("", Aws::Region::US_EAST_1, ""); // FIXME: make configurable auto slash = request.uri.find('/', 5); if (slash == std::string::npos) throw nix::Error("bad S3 URI '%s'", request.uri); @@ -716,15 +741,17 @@ void Downloader::download(DownloadRequest && request, Sink & sink) while (true) { checkInterrupt(); - if (state->quit) { - if (state->exc) std::rethrow_exception(state->exc); - break; - } - /* If no data is available, then wait for the download thread to wake us up. */ - if (state->data.empty()) + if (state->data.empty()) { + + if (state->quit) { + if (state->exc) std::rethrow_exception(state->exc); + break; + } + state.wait(state->avail); + } /* If data is available, then flush it to the sink and wake up the download thread if it's blocked on a full buffer. */ diff --git a/src/libstore/download.hh b/src/libstore/download.hh index da55df7a6e71..f0228f7d053a 100644 --- a/src/libstore/download.hh +++ b/src/libstore/download.hh @@ -38,6 +38,7 @@ struct DownloadResult std::string etag; std::string effectiveUrl; std::shared_ptr<std::string> data; + uint64_t bodySize = 0; }; class Store; @@ -87,7 +88,4 @@ public: bool isUri(const string & s); -/* Decode data according to the Content-Encoding header. 
*/ -ref<std::string> decodeContent(const std::string & encoding, ref<std::string> data); - } diff --git a/src/libstore/gc.cc b/src/libstore/gc.cc index b5020a506beb..b415d5421476 100644 --- a/src/libstore/gc.cc +++ b/src/libstore/gc.cc @@ -366,7 +366,7 @@ try_again: char buf[bufsiz]; auto res = readlink(file.c_str(), buf, bufsiz); if (res == -1) { - if (errno == ENOENT || errno == EACCES) + if (errno == ENOENT || errno == EACCES || errno == ESRCH) return; throw SysError("reading symlink"); } diff --git a/src/libstore/legacy-ssh-store.cc b/src/libstore/legacy-ssh-store.cc index 02d91ded04cd..88d2574e86ef 100644 --- a/src/libstore/legacy-ssh-store.cc +++ b/src/libstore/legacy-ssh-store.cc @@ -17,6 +17,7 @@ struct LegacySSHStore : public Store const Setting<Path> sshKey{this, "", "ssh-key", "path to an SSH private key"}; const Setting<bool> compress{this, false, "compress", "whether to compress the connection"}; const Setting<Path> remoteProgram{this, "nix-store", "remote-program", "path to the nix-store executable on the remote system"}; + const Setting<std::string> remoteStore{this, "", "remote-store", "URI of the store on the remote system"}; // Hack for getting remote build log output. 
const Setting<int> logFD{this, -1, "log-fd", "file descriptor to which SSH's stderr is connected"}; @@ -27,6 +28,7 @@ struct LegacySSHStore : public Store FdSink to; FdSource from; int remoteVersion; + bool good = true; }; std::string host; @@ -41,7 +43,7 @@ struct LegacySSHStore : public Store , connections(make_ref<Pool<Connection>>( std::max(1, (int) maxConnections), [this]() { return openConnection(); }, - [](const ref<Connection> & r) { return true; } + [](const ref<Connection> & r) { return r->good; } )) , master( host, @@ -56,7 +58,9 @@ struct LegacySSHStore : public Store ref<Connection> openConnection() { auto conn = make_ref<Connection>(); - conn->sshConn = master.startCommand(fmt("%s --serve --write", remoteProgram)); + conn->sshConn = master.startCommand( + fmt("%s --serve --write", remoteProgram) + + (remoteStore.get() == "" ? "" : " --store " + shellEscape(remoteStore.get()))); conn->to = FdSink(conn->sshConn->in.get()); conn->from = FdSource(conn->sshConn->out.get()); @@ -127,18 +131,48 @@ struct LegacySSHStore : public Store auto conn(connections->get()); - conn->to - << cmdImportPaths - << 1; - copyNAR(source, conn->to); - conn->to - << exportMagic - << info.path - << info.references - << info.deriver - << 0 - << 0; - conn->to.flush(); + if (GET_PROTOCOL_MINOR(conn->remoteVersion) >= 5) { + + conn->to + << cmdAddToStoreNar + << info.path + << info.deriver + << info.narHash.to_string(Base16, false) + << info.references + << info.registrationTime + << info.narSize + << info.ultimate + << info.sigs + << info.ca; + try { + copyNAR(source, conn->to); + } catch (...) { + conn->good = false; + throw; + } + conn->to.flush(); + + } else { + + conn->to + << cmdImportPaths + << 1; + try { + copyNAR(source, conn->to); + } catch (...) 
{ + conn->good = false; + throw; + } + conn->to + << exportMagic + << info.path + << info.references + << info.deriver + << 0 + << 0; + conn->to.flush(); + + } if (readInt(conn->from) != 1) throw Error("failed to add path '%s' to remote host '%s', info.path, host"); diff --git a/src/libstore/local-store.cc b/src/libstore/local-store.cc index 3b2ba65f3b46..c91dbf241bcf 100644 --- a/src/libstore/local-store.cc +++ b/src/libstore/local-store.cc @@ -450,7 +450,7 @@ static void canonicalisePathMetaData_(const Path & path, uid_t fromUid, InodesSe ssize_t eaSize = llistxattr(path.c_str(), nullptr, 0); if (eaSize < 0) { - if (errno != ENOTSUP) + if (errno != ENOTSUP && errno != ENODATA) throw SysError("querying extended attributes of '%s'", path); } else if (eaSize > 0) { std::vector<char> eaBuf(eaSize); diff --git a/src/libstore/s3-binary-cache-store.cc b/src/libstore/s3-binary-cache-store.cc index 26144ccb40cc..7711388f05a9 100644 --- a/src/libstore/s3-binary-cache-store.cc +++ b/src/libstore/s3-binary-cache-store.cc @@ -84,8 +84,8 @@ static void initAWS() }); } -S3Helper::S3Helper(const std::string & profile, const std::string & region) - : config(makeConfig(region)) +S3Helper::S3Helper(const std::string & profile, const std::string & region, const std::string & endpoint) + : config(makeConfig(region, endpoint)) , client(make_ref<Aws::S3::S3Client>( profile == "" ? 
std::dynamic_pointer_cast<Aws::Auth::AWSCredentialsProvider>( @@ -99,7 +99,7 @@ S3Helper::S3Helper(const std::string & profile, const std::string & region) #else Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never, #endif - false)) + endpoint.empty())) { } @@ -116,11 +116,14 @@ class RetryStrategy : public Aws::Client::DefaultRetryStrategy } }; -ref<Aws::Client::ClientConfiguration> S3Helper::makeConfig(const string & region) +ref<Aws::Client::ClientConfiguration> S3Helper::makeConfig(const string & region, const string & endpoint) { initAWS(); auto res = make_ref<Aws::Client::ClientConfiguration>(); res->region = region; + if (!endpoint.empty()) { + res->endpointOverride = endpoint; + } res->requestTimeoutMs = 600 * 1000; res->retryStrategy = std::make_shared<RetryStrategy>(); res->caFile = settings.caFile; @@ -150,10 +153,8 @@ S3Helper::DownloadResult S3Helper::getObject( auto result = checkAws(fmt("AWS error fetching '%s'", key), client->GetObject(request)); - res.data = decodeContent( - result.GetContentEncoding(), - make_ref<std::string>( - dynamic_cast<std::stringstream &>(result.GetBody()).str())); + res.data = decompress(result.GetContentEncoding(), + dynamic_cast<std::stringstream &>(result.GetBody()).str()); } catch (S3Error & e) { if (e.err != Aws::S3::S3Errors::NO_SUCH_KEY) throw; @@ -170,6 +171,7 @@ struct S3BinaryCacheStoreImpl : public S3BinaryCacheStore { const Setting<std::string> profile{this, "", "profile", "The name of the AWS configuration profile to use."}; const Setting<std::string> region{this, Aws::Region::US_EAST_1, "region", {"aws-region"}}; + const Setting<std::string> endpoint{this, "", "endpoint", "An optional override of the endpoint to use when talking to S3."}; const Setting<std::string> narinfoCompression{this, "", "narinfo-compression", "compression method for .narinfo files"}; const Setting<std::string> lsCompression{this, "", "ls-compression", "compression method for .ls files"}; const Setting<std::string> 
logCompression{this, "", "log-compression", "compression method for log/* files"}; @@ -186,7 +188,7 @@ struct S3BinaryCacheStoreImpl : public S3BinaryCacheStore const Params & params, const std::string & bucketName) : S3BinaryCacheStore(params) , bucketName(bucketName) - , s3Helper(profile, region) + , s3Helper(profile, region, endpoint) { diskCache = getNarInfoDiskCache(); } @@ -273,6 +275,9 @@ struct S3BinaryCacheStoreImpl : public S3BinaryCacheStore return true; } + std::shared_ptr<TransferManager> transferManager; + std::once_flag transferManagerCreated; + void uploadFile(const std::string & path, const std::string & data, const std::string & mimeType, const std::string & contentEncoding) @@ -284,61 +289,49 @@ struct S3BinaryCacheStoreImpl : public S3BinaryCacheStore static std::shared_ptr<Aws::Utils::Threading::PooledThreadExecutor> executor = std::make_shared<Aws::Utils::Threading::PooledThreadExecutor>(maxThreads); - TransferManagerConfiguration transferConfig(executor.get()); - - transferConfig.s3Client = s3Helper.client; - transferConfig.bufferSize = bufferSize; - - if (contentEncoding != "") - transferConfig.createMultipartUploadTemplate.SetContentEncoding( - contentEncoding); - - transferConfig.uploadProgressCallback = - [&](const TransferManager *transferManager, - const std::shared_ptr<const TransferHandle> - &transferHandle) { - //FIXME: find a way to properly abort the multipart upload. 
- checkInterrupt(); - debug("upload progress ('%s'): '%d' of '%d' bytes", - path, - transferHandle->GetBytesTransferred(), - transferHandle->GetBytesTotalSize()); - }; + std::call_once(transferManagerCreated, [&]() { - transferConfig.transferStatusUpdatedCallback = - [&](const TransferManager *, - const std::shared_ptr<const TransferHandle> - &transferHandle) { - switch (transferHandle->GetStatus()) { - case TransferStatus::COMPLETED: - printTalkative("upload of '%s' completed", path); - stats.put++; - stats.putBytes += data.size(); - break; - case TransferStatus::IN_PROGRESS: - break; - case TransferStatus::FAILED: - throw Error("AWS error: failed to upload 's3://%s/%s'", - bucketName, path); - break; - default: - throw Error("AWS error: transfer status of 's3://%s/%s' " - "in unexpected state", - bucketName, path); - }; - }; + TransferManagerConfiguration transferConfig(executor.get()); + + transferConfig.s3Client = s3Helper.client; + transferConfig.bufferSize = bufferSize; - std::shared_ptr<TransferManager> transferManager = - TransferManager::Create(transferConfig); + transferConfig.uploadProgressCallback = + [&](const TransferManager *transferManager, + const std::shared_ptr<const TransferHandle> + &transferHandle) + { + //FIXME: find a way to properly abort the multipart upload. 
+ //checkInterrupt(); + debug("upload progress ('%s'): '%d' of '%d' bytes", + path, + transferHandle->GetBytesTransferred(), + transferHandle->GetBytesTotalSize()); + }; + + transferManager = TransferManager::Create(transferConfig); + }); auto now1 = std::chrono::steady_clock::now(); std::shared_ptr<TransferHandle> transferHandle = - transferManager->UploadFile(stream, bucketName, path, mimeType, - Aws::Map<Aws::String, Aws::String>()); + transferManager->UploadFile( + stream, bucketName, path, mimeType, + Aws::Map<Aws::String, Aws::String>(), + nullptr, contentEncoding); transferHandle->WaitUntilFinished(); + if (transferHandle->GetStatus() == TransferStatus::FAILED) + throw Error("AWS error: failed to upload 's3://%s/%s': %s", + bucketName, path, transferHandle->GetLastError().GetMessage()); + + if (transferHandle->GetStatus() != TransferStatus::COMPLETED) + throw Error("AWS error: transfer status of 's3://%s/%s' in unexpected state", + bucketName, path); + + printTalkative("upload of '%s' completed", path); + auto now2 = std::chrono::steady_clock::now(); auto duration = @@ -349,6 +342,8 @@ struct S3BinaryCacheStoreImpl : public S3BinaryCacheStore bucketName % path % data.size() % duration); stats.putTimeMs += duration; + stats.putBytes += data.size(); + stats.put++; } void upsertFile(const std::string & path, const std::string & data, diff --git a/src/libstore/s3.hh b/src/libstore/s3.hh index 4f996400343c..95d612b66335 100644 --- a/src/libstore/s3.hh +++ b/src/libstore/s3.hh @@ -14,9 +14,9 @@ struct S3Helper ref<Aws::Client::ClientConfiguration> config; ref<Aws::S3::S3Client> client; - S3Helper(const std::string & profile, const std::string & region); + S3Helper(const std::string & profile, const std::string & region, const std::string & endpoint); - ref<Aws::Client::ClientConfiguration> makeConfig(const std::string & region); + ref<Aws::Client::ClientConfiguration> makeConfig(const std::string & region, const std::string & endpoint); struct DownloadResult { 
diff --git a/src/libstore/serve-protocol.hh b/src/libstore/serve-protocol.hh index f67d1e2580a5..9fae6d5349f1 100644 --- a/src/libstore/serve-protocol.hh +++ b/src/libstore/serve-protocol.hh @@ -5,7 +5,7 @@ namespace nix { #define SERVE_MAGIC_1 0x390c9deb #define SERVE_MAGIC_2 0x5452eecb -#define SERVE_PROTOCOL_VERSION 0x204 +#define SERVE_PROTOCOL_VERSION 0x205 #define GET_PROTOCOL_MAJOR(x) ((x) & 0xff00) #define GET_PROTOCOL_MINOR(x) ((x) & 0x00ff) @@ -18,6 +18,7 @@ typedef enum { cmdBuildPaths = 6, cmdQueryClosure = 7, cmdBuildDerivation = 8, + cmdAddToStoreNar = 9, } ServeCommand; } diff --git a/src/libstore/ssh.cc b/src/libstore/ssh.cc index 033c580936ad..5e0e44935cca 100644 --- a/src/libstore/ssh.cc +++ b/src/libstore/ssh.cc @@ -4,8 +4,9 @@ namespace nix { SSHMaster::SSHMaster(const std::string & host, const std::string & keyFile, bool useMaster, bool compress, int logFD) : host(host) + , fakeSSH(host == "localhost") , keyFile(keyFile) - , useMaster(useMaster) + , useMaster(useMaster && !fakeSSH) , compress(compress) , logFD(logFD) { @@ -45,12 +46,19 @@ std::unique_ptr<SSHMaster::Connection> SSHMaster::startCommand(const std::string if (logFD != -1 && dup2(logFD, STDERR_FILENO) == -1) throw SysError("duping over stderr"); - Strings args = { "ssh", host.c_str(), "-x", "-a" }; - addCommonSSHOpts(args); - if (socketPath != "") - args.insert(args.end(), {"-S", socketPath}); - if (verbosity >= lvlChatty) - args.push_back("-v"); + Strings args; + + if (fakeSSH) { + args = { "bash", "-c" }; + } else { + args = { "ssh", host.c_str(), "-x", "-a" }; + addCommonSSHOpts(args); + if (socketPath != "") + args.insert(args.end(), {"-S", socketPath}); + if (verbosity >= lvlChatty) + args.push_back("-v"); + } + args.push_back(command); execvp(args.begin()->c_str(), stringsToCharPtrs(args).data()); diff --git a/src/libstore/ssh.hh b/src/libstore/ssh.hh index 1268e6d00054..4f0f0bd29f9f 100644 --- a/src/libstore/ssh.hh +++ b/src/libstore/ssh.hh @@ -10,6 +10,7 @@ class SSHMaster 
private: const std::string host; + bool fakeSSH; const std::string keyFile; const bool useMaster; const bool compress; diff --git a/src/libstore/store-api.cc b/src/libstore/store-api.cc index 9b0b7d6327e0..1f42097fccfb 100644 --- a/src/libstore/store-api.cc +++ b/src/libstore/store-api.cc @@ -609,6 +609,8 @@ void copyStorePath(ref<Store> srcStore, ref<Store> dstStore, act.progress(total, info->narSize); }); srcStore->narFromPath({storePath}, wrapperSink); + }, [&]() { + throw EndOfFile("NAR for '%s' fetched from '%s' is incomplete", storePath, srcStore->getUri()); }); dstStore->addToStore(*info, *source, repair, checkSigs); @@ -629,11 +631,12 @@ void copyPaths(ref<Store> srcStore, ref<Store> dstStore, const PathSet & storePa Activity act(*logger, lvlInfo, actCopyPaths, fmt("copying %d paths", missing.size())); std::atomic<size_t> nrDone{0}; + std::atomic<size_t> nrFailed{0}; std::atomic<uint64_t> bytesExpected{0}; std::atomic<uint64_t> nrRunning{0}; auto showProgress = [&]() { - act.progress(nrDone, missing.size(), nrRunning); + act.progress(nrDone, missing.size(), nrRunning, nrFailed); }; ThreadPool pool; @@ -662,7 +665,16 @@ void copyPaths(ref<Store> srcStore, ref<Store> dstStore, const PathSet & storePa if (!dstStore->isValidPath(storePath)) { MaintainCount<decltype(nrRunning)> mc(nrRunning); showProgress(); - copyStorePath(srcStore, dstStore, storePath, repair, checkSigs); + try { + copyStorePath(srcStore, dstStore, storePath, repair, checkSigs); + } catch (Error &e) { + nrFailed++; + if (!settings.keepGoing) + throw e; + logger->log(lvlError, format("could not copy %s: %s") % storePath % e.what()); + showProgress(); + return; + } } nrDone++; @@ -834,8 +846,24 @@ ref<Store> openStore(const std::string & uri_, if (q != std::string::npos) { for (auto s : tokenizeString<Strings>(uri.substr(q + 1), "&")) { auto e = s.find('='); - if (e != std::string::npos) - params[s.substr(0, e)] = s.substr(e + 1); + if (e != std::string::npos) { + auto value = s.substr(e + 1); + 
std::string decoded; + for (size_t i = 0; i < value.size(); ) { + if (value[i] == '%') { + if (i + 2 >= value.size()) + throw Error("invalid URI parameter '%s'", value); + try { + decoded += std::stoul(std::string(value, i + 1, 2), 0, 16); + i += 3; + } catch (...) { + throw Error("invalid URI parameter '%s'", value); + } + } else + decoded += value[i++]; + } + params[s.substr(0, e)] = decoded; + } } uri = uri_.substr(0, q); } diff --git a/src/libutil/compression.cc b/src/libutil/compression.cc index e1782f8c4bd9..204c63cd26fc 100644 --- a/src/libutil/compression.cc +++ b/src/libutil/compression.cc @@ -8,246 +8,265 @@ #include <cstdio> #include <cstring> -#if HAVE_BROTLI #include <brotli/decode.h> #include <brotli/encode.h> -#endif // HAVE_BROTLI #include <iostream> namespace nix { -static const size_t bufSize = 32 * 1024; - -static void decompressNone(Source & source, Sink & sink) +// Don't feed brotli too much at once. +struct ChunkedCompressionSink : CompressionSink { - std::vector<unsigned char> buf(bufSize); - while (true) { - size_t n; - try { - n = source.read(buf.data(), buf.size()); - } catch (EndOfFile &) { - break; + uint8_t outbuf[32 * 1024]; + + void write(const unsigned char * data, size_t len) override + { + const size_t CHUNK_SIZE = sizeof(outbuf) << 2; + while (len) { + size_t n = std::min(CHUNK_SIZE, len); + writeInternal(data, n); + data += n; + len -= n; } - sink(buf.data(), n); } -} -static void decompressXZ(Source & source, Sink & sink) + virtual void writeInternal(const unsigned char * data, size_t len) = 0; +}; + +struct NoneSink : CompressionSink { - lzma_stream strm(LZMA_STREAM_INIT); - - lzma_ret ret = lzma_stream_decoder( - &strm, UINT64_MAX, LZMA_CONCATENATED); - if (ret != LZMA_OK) - throw CompressionError("unable to initialise lzma decoder"); - - Finally free([&]() { lzma_end(&strm); }); - - lzma_action action = LZMA_RUN; - std::vector<uint8_t> inbuf(bufSize), outbuf(bufSize); - strm.next_in = nullptr; - strm.avail_in = 0; - 
strm.next_out = outbuf.data(); - strm.avail_out = outbuf.size(); - bool eof = false; - - while (true) { - checkInterrupt(); - - if (strm.avail_in == 0 && !eof) { - strm.next_in = inbuf.data(); - try { - strm.avail_in = source.read((unsigned char *) strm.next_in, inbuf.size()); - } catch (EndOfFile &) { - eof = true; - } - } + Sink & nextSink; + NoneSink(Sink & nextSink) : nextSink(nextSink) { } + void finish() override { flush(); } + void write(const unsigned char * data, size_t len) override { nextSink(data, len); } +}; - if (strm.avail_in == 0) - action = LZMA_FINISH; +struct XzDecompressionSink : CompressionSink +{ + Sink & nextSink; + uint8_t outbuf[BUFSIZ]; + lzma_stream strm = LZMA_STREAM_INIT; + bool finished = false; - lzma_ret ret = lzma_code(&strm, action); + XzDecompressionSink(Sink & nextSink) : nextSink(nextSink) + { + lzma_ret ret = lzma_stream_decoder( + &strm, UINT64_MAX, LZMA_CONCATENATED); + if (ret != LZMA_OK) + throw CompressionError("unable to initialise lzma decoder"); - if (strm.avail_out < outbuf.size()) { - sink((unsigned char *) outbuf.data(), outbuf.size() - strm.avail_out); - strm.next_out = outbuf.data(); - strm.avail_out = outbuf.size(); - } + strm.next_out = outbuf; + strm.avail_out = sizeof(outbuf); + } - if (ret == LZMA_STREAM_END) return; + ~XzDecompressionSink() + { + lzma_end(&strm); + } - if (ret != LZMA_OK) - throw CompressionError("error %d while decompressing xz file", ret); + void finish() override + { + CompressionSink::flush(); + write(nullptr, 0); } -} -static void decompressBzip2(Source & source, Sink & sink) -{ - bz_stream strm; - memset(&strm, 0, sizeof(strm)); - - int ret = BZ2_bzDecompressInit(&strm, 0, 0); - if (ret != BZ_OK) - throw CompressionError("unable to initialise bzip2 decoder"); - - Finally free([&]() { BZ2_bzDecompressEnd(&strm); }); - - std::vector<char> inbuf(bufSize), outbuf(bufSize); - strm.next_in = nullptr; - strm.avail_in = 0; - strm.next_out = outbuf.data(); - strm.avail_out = outbuf.size(); - 
bool eof = false; - - while (true) { - checkInterrupt(); - - if (strm.avail_in == 0 && !eof) { - strm.next_in = inbuf.data(); - try { - strm.avail_in = source.read((unsigned char *) strm.next_in, inbuf.size()); - } catch (EndOfFile &) { - eof = true; - } - } + void write(const unsigned char * data, size_t len) override + { + strm.next_in = data; + strm.avail_in = len; + + while (!finished && (!data || strm.avail_in)) { + checkInterrupt(); - int ret = BZ2_bzDecompress(&strm); + lzma_ret ret = lzma_code(&strm, data ? LZMA_RUN : LZMA_FINISH); + if (ret != LZMA_OK && ret != LZMA_STREAM_END) + throw CompressionError("error %d while decompressing xz file", ret); - if (strm.avail_in == 0 && strm.avail_out == outbuf.size() && eof) - throw CompressionError("bzip2 data ends prematurely"); + finished = ret == LZMA_STREAM_END; - if (strm.avail_out < outbuf.size()) { - sink((unsigned char *) outbuf.data(), outbuf.size() - strm.avail_out); - strm.next_out = outbuf.data(); - strm.avail_out = outbuf.size(); + if (strm.avail_out < sizeof(outbuf) || strm.avail_in == 0) { + nextSink(outbuf, sizeof(outbuf) - strm.avail_out); + strm.next_out = outbuf; + strm.avail_out = sizeof(outbuf); + } } + } +}; - if (ret == BZ_STREAM_END) return; +struct BzipDecompressionSink : ChunkedCompressionSink +{ + Sink & nextSink; + bz_stream strm; + bool finished = false; + BzipDecompressionSink(Sink & nextSink) : nextSink(nextSink) + { + memset(&strm, 0, sizeof(strm)); + int ret = BZ2_bzDecompressInit(&strm, 0, 0); if (ret != BZ_OK) - throw CompressionError("error while decompressing bzip2 file"); + throw CompressionError("unable to initialise bzip2 decoder"); + + strm.next_out = (char *) outbuf; + strm.avail_out = sizeof(outbuf); } -} -static void decompressBrotli(Source & source, Sink & sink) -{ -#if !HAVE_BROTLI - RunOptions options(BROTLI, {"-d"}); - options.standardIn = &source; - options.standardOut = &sink; - runProgram2(options); -#else - auto *s = BrotliDecoderCreateInstance(nullptr, nullptr, 
nullptr); - if (!s) - throw CompressionError("unable to initialize brotli decoder"); - - Finally free([s]() { BrotliDecoderDestroyInstance(s); }); - - std::vector<uint8_t> inbuf(bufSize), outbuf(bufSize); - const uint8_t * next_in = nullptr; - size_t avail_in = 0; - bool eof = false; - - while (true) { - checkInterrupt(); - - if (avail_in == 0 && !eof) { - next_in = inbuf.data(); - try { - avail_in = source.read((unsigned char *) next_in, inbuf.size()); - } catch (EndOfFile &) { - eof = true; + ~BzipDecompressionSink() + { + BZ2_bzDecompressEnd(&strm); + } + + void finish() override + { + flush(); + write(nullptr, 0); + } + + void writeInternal(const unsigned char * data, size_t len) override + { + assert(len <= std::numeric_limits<decltype(strm.avail_in)>::max()); + + strm.next_in = (char *) data; + strm.avail_in = len; + + while (strm.avail_in) { + checkInterrupt(); + + int ret = BZ2_bzDecompress(&strm); + if (ret != BZ_OK && ret != BZ_STREAM_END) + throw CompressionError("error while decompressing bzip2 file"); + + finished = ret == BZ_STREAM_END; + + if (strm.avail_out < sizeof(outbuf) || strm.avail_in == 0) { + nextSink(outbuf, sizeof(outbuf) - strm.avail_out); + strm.next_out = (char *) outbuf; + strm.avail_out = sizeof(outbuf); } } + } +}; - uint8_t * next_out = outbuf.data(); - size_t avail_out = outbuf.size(); - - auto ret = BrotliDecoderDecompressStream(s, - &avail_in, &next_in, - &avail_out, &next_out, - nullptr); - - switch (ret) { - case BROTLI_DECODER_RESULT_ERROR: - throw CompressionError("error while decompressing brotli file"); - case BROTLI_DECODER_RESULT_NEEDS_MORE_INPUT: - if (eof) - throw CompressionError("incomplete or corrupt brotli file"); - break; - case BROTLI_DECODER_RESULT_SUCCESS: - if (avail_in != 0) - throw CompressionError("unexpected input after brotli decompression"); - break; - case BROTLI_DECODER_RESULT_NEEDS_MORE_OUTPUT: - // I'm not sure if this can happen, but abort if this happens with empty buffer - if (avail_out == 
outbuf.size()) - throw CompressionError("brotli decompression requires larger buffer"); - break; - } +struct BrotliDecompressionSink : ChunkedCompressionSink +{ + Sink & nextSink; + BrotliDecoderState * state; + bool finished = false; - // Always ensure we have full buffer for next invocation - if (avail_out < outbuf.size()) - sink((unsigned char *) outbuf.data(), outbuf.size() - avail_out); + BrotliDecompressionSink(Sink & nextSink) : nextSink(nextSink) + { + state = BrotliDecoderCreateInstance(nullptr, nullptr, nullptr); + if (!state) + throw CompressionError("unable to initialize brotli decoder"); + } - if (ret == BROTLI_DECODER_RESULT_SUCCESS) return; + ~BrotliDecompressionSink() + { + BrotliDecoderDestroyInstance(state); } -#endif // HAVE_BROTLI -} + + void finish() override + { + flush(); + writeInternal(nullptr, 0); + } + + void writeInternal(const unsigned char * data, size_t len) override + { + const uint8_t * next_in = data; + size_t avail_in = len; + uint8_t * next_out = outbuf; + size_t avail_out = sizeof(outbuf); + + while (!finished && (!data || avail_in)) { + checkInterrupt(); + + if (!BrotliDecoderDecompressStream(state, + &avail_in, &next_in, + &avail_out, &next_out, + nullptr)) + throw CompressionError("error while decompressing brotli file"); + + if (avail_out < sizeof(outbuf) || avail_in == 0) { + nextSink(outbuf, sizeof(outbuf) - avail_out); + next_out = outbuf; + avail_out = sizeof(outbuf); + } + + finished = BrotliDecoderIsFinished(state); + } + } +}; ref<std::string> decompress(const std::string & method, const std::string & in) { - StringSource source(in); - StringSink sink; - decompress(method, source, sink); - return sink.s; + StringSink ssink; + auto sink = makeDecompressionSink(method, ssink); + (*sink)(in); + sink->finish(); + return ssink.s; } -void decompress(const std::string & method, Source & source, Sink & sink) +ref<CompressionSink> makeDecompressionSink(const std::string & method, Sink & nextSink) { - if (method == "none") - 
return decompressNone(source, sink); + if (method == "none" || method == "") + return make_ref<NoneSink>(nextSink); else if (method == "xz") - return decompressXZ(source, sink); + return make_ref<XzDecompressionSink>(nextSink); else if (method == "bzip2") - return decompressBzip2(source, sink); + return make_ref<BzipDecompressionSink>(nextSink); else if (method == "br") - return decompressBrotli(source, sink); + return make_ref<BrotliDecompressionSink>(nextSink); else throw UnknownCompressionMethod("unknown compression method '%s'", method); } -struct NoneSink : CompressionSink -{ - Sink & nextSink; - NoneSink(Sink & nextSink) : nextSink(nextSink) { } - void finish() override { flush(); } - void write(const unsigned char * data, size_t len) override { nextSink(data, len); } -}; - -struct XzSink : CompressionSink +struct XzCompressionSink : CompressionSink { Sink & nextSink; uint8_t outbuf[BUFSIZ]; lzma_stream strm = LZMA_STREAM_INIT; bool finished = false; - template <typename F> - XzSink(Sink & nextSink, F&& initEncoder) : nextSink(nextSink) { - lzma_ret ret = initEncoder(); + XzCompressionSink(Sink & nextSink, bool parallel) : nextSink(nextSink) + { + lzma_ret ret; + bool done = false; + + if (parallel) { +#ifdef HAVE_LZMA_MT + lzma_mt mt_options = {}; + mt_options.flags = 0; + mt_options.timeout = 300; // Using the same setting as the xz cmd line + mt_options.preset = LZMA_PRESET_DEFAULT; + mt_options.filters = NULL; + mt_options.check = LZMA_CHECK_CRC64; + mt_options.threads = lzma_cputhreads(); + mt_options.block_size = 0; + if (mt_options.threads == 0) + mt_options.threads = 1; + // FIXME: maybe use lzma_stream_encoder_mt_memusage() to control the + // number of threads. 
+ ret = lzma_stream_encoder_mt(&strm, &mt_options); + done = true; +#else + printMsg(lvlError, "warning: parallel compression requested but not supported for metho d '%1%', falling back to single-threaded compression", method); +#endif + } + + if (!done) + ret = lzma_easy_encoder(&strm, 6, LZMA_CHECK_CRC64); + if (ret != LZMA_OK) throw CompressionError("unable to initialise lzma encoder"); + // FIXME: apply the x86 BCJ filter? strm.next_out = outbuf; strm.avail_out = sizeof(outbuf); } - XzSink(Sink & nextSink) : XzSink(nextSink, [this]() { - return lzma_easy_encoder(&strm, 6, LZMA_CHECK_CRC64); - }) {} - ~XzSink() + ~XzCompressionSink() { lzma_end(&strm); } @@ -255,43 +274,25 @@ struct XzSink : CompressionSink void finish() override { CompressionSink::flush(); - - assert(!finished); - finished = true; - - while (true) { - checkInterrupt(); - - lzma_ret ret = lzma_code(&strm, LZMA_FINISH); - if (ret != LZMA_OK && ret != LZMA_STREAM_END) - throw CompressionError("error while flushing xz file"); - - if (strm.avail_out == 0 || ret == LZMA_STREAM_END) { - nextSink(outbuf, sizeof(outbuf) - strm.avail_out); - strm.next_out = outbuf; - strm.avail_out = sizeof(outbuf); - } - - if (ret == LZMA_STREAM_END) break; - } + write(nullptr, 0); } void write(const unsigned char * data, size_t len) override { - assert(!finished); - strm.next_in = data; strm.avail_in = len; - while (strm.avail_in) { + while (!finished && (!data || strm.avail_in)) { checkInterrupt(); - lzma_ret ret = lzma_code(&strm, LZMA_RUN); - if (ret != LZMA_OK) - throw CompressionError("error while compressing xz file"); + lzma_ret ret = lzma_code(&strm, data ? 
LZMA_RUN : LZMA_FINISH); + if (ret != LZMA_OK && ret != LZMA_STREAM_END) + throw CompressionError("error %d while compressing xz file", ret); + + finished = ret == LZMA_STREAM_END; - if (strm.avail_out == 0) { - nextSink(outbuf, sizeof(outbuf)); + if (strm.avail_out < sizeof(outbuf) || strm.avail_in == 0) { + nextSink(outbuf, sizeof(outbuf) - strm.avail_out); strm.next_out = outbuf; strm.avail_out = sizeof(outbuf); } @@ -299,46 +300,24 @@ struct XzSink : CompressionSink } }; -#ifdef HAVE_LZMA_MT -struct ParallelXzSink : public XzSink -{ - ParallelXzSink(Sink &nextSink) : XzSink(nextSink, [this]() { - lzma_mt mt_options = {}; - mt_options.flags = 0; - mt_options.timeout = 300; // Using the same setting as the xz cmd line - mt_options.preset = LZMA_PRESET_DEFAULT; - mt_options.filters = NULL; - mt_options.check = LZMA_CHECK_CRC64; - mt_options.threads = lzma_cputhreads(); - mt_options.block_size = 0; - if (mt_options.threads == 0) - mt_options.threads = 1; - // FIXME: maybe use lzma_stream_encoder_mt_memusage() to control the - // number of threads. 
- return lzma_stream_encoder_mt(&strm, &mt_options); - }) {} -}; -#endif - -struct BzipSink : CompressionSink +struct BzipCompressionSink : ChunkedCompressionSink { Sink & nextSink; - char outbuf[BUFSIZ]; bz_stream strm; bool finished = false; - BzipSink(Sink & nextSink) : nextSink(nextSink) + BzipCompressionSink(Sink & nextSink) : nextSink(nextSink) { memset(&strm, 0, sizeof(strm)); int ret = BZ2_bzCompressInit(&strm, 9, 0, 30); if (ret != BZ_OK) throw CompressionError("unable to initialise bzip2 encoder"); - strm.next_out = outbuf; + strm.next_out = (char *) outbuf; strm.avail_out = sizeof(outbuf); } - ~BzipSink() + ~BzipCompressionSink() { BZ2_bzCompressEnd(&strm); } @@ -346,114 +325,49 @@ struct BzipSink : CompressionSink void finish() override { flush(); - - assert(!finished); - finished = true; - - while (true) { - checkInterrupt(); - - int ret = BZ2_bzCompress(&strm, BZ_FINISH); - if (ret != BZ_FINISH_OK && ret != BZ_STREAM_END) - throw CompressionError("error while flushing bzip2 file"); - - if (strm.avail_out == 0 || ret == BZ_STREAM_END) { - nextSink((unsigned char *) outbuf, sizeof(outbuf) - strm.avail_out); - strm.next_out = outbuf; - strm.avail_out = sizeof(outbuf); - } - - if (ret == BZ_STREAM_END) break; - } - } - - void write(const unsigned char * data, size_t len) override - { - /* Bzip2's 'avail_in' parameter is an unsigned int, so we need - to split the input into chunks of at most 4 GiB. 
*/ - while (len) { - auto n = std::min((size_t) std::numeric_limits<decltype(strm.avail_in)>::max(), len); - writeInternal(data, n); - data += n; - len -= n; - } + writeInternal(nullptr, 0); } - void writeInternal(const unsigned char * data, size_t len) + void writeInternal(const unsigned char * data, size_t len) override { - assert(!finished); assert(len <= std::numeric_limits<decltype(strm.avail_in)>::max()); strm.next_in = (char *) data; strm.avail_in = len; - while (strm.avail_in) { + while (!finished && (!data || strm.avail_in)) { checkInterrupt(); - int ret = BZ2_bzCompress(&strm, BZ_RUN); - if (ret != BZ_OK) - CompressionError("error while compressing bzip2 file"); + int ret = BZ2_bzCompress(&strm, data ? BZ_RUN : BZ_FINISH); + if (ret != BZ_RUN_OK && ret != BZ_FINISH_OK && ret != BZ_STREAM_END) + throw CompressionError("error %d while compressing bzip2 file", ret); - if (strm.avail_out == 0) { - nextSink((unsigned char *) outbuf, sizeof(outbuf)); - strm.next_out = outbuf; + finished = ret == BZ_STREAM_END; + + if (strm.avail_out < sizeof(outbuf) || strm.avail_in == 0) { + nextSink(outbuf, sizeof(outbuf) - strm.avail_out); + strm.next_out = (char *) outbuf; strm.avail_out = sizeof(outbuf); } } } }; -struct LambdaCompressionSink : CompressionSink -{ - Sink & nextSink; - std::string data; - using CompressFnTy = std::function<std::string(const std::string&)>; - CompressFnTy compressFn; - LambdaCompressionSink(Sink& nextSink, CompressFnTy compressFn) - : nextSink(nextSink) - , compressFn(std::move(compressFn)) - { - }; - - void finish() override - { - flush(); - nextSink(compressFn(data)); - } - - void write(const unsigned char * data, size_t len) override - { - checkInterrupt(); - this->data.append((const char *) data, len); - } -}; - -struct BrotliCmdSink : LambdaCompressionSink -{ - BrotliCmdSink(Sink& nextSink) - : LambdaCompressionSink(nextSink, [](const std::string& data) { - return runProgram(BROTLI, true, {}, data); - }) - { - } -}; - -#if HAVE_BROTLI 
-struct BrotliSink : CompressionSink +struct BrotliCompressionSink : ChunkedCompressionSink { Sink & nextSink; uint8_t outbuf[BUFSIZ]; BrotliEncoderState *state; bool finished = false; - BrotliSink(Sink & nextSink) : nextSink(nextSink) + BrotliCompressionSink(Sink & nextSink) : nextSink(nextSink) { state = BrotliEncoderCreateInstance(nullptr, nullptr, nullptr); if (!state) throw CompressionError("unable to initialise brotli encoder"); } - ~BrotliSink() + ~BrotliCompressionSink() { BrotliEncoderDestroyInstance(state); } @@ -461,94 +375,47 @@ struct BrotliSink : CompressionSink void finish() override { flush(); - assert(!finished); - - const uint8_t *next_in = nullptr; - size_t avail_in = 0; - uint8_t *next_out = outbuf; - size_t avail_out = sizeof(outbuf); - while (!finished) { - checkInterrupt(); - - if (!BrotliEncoderCompressStream(state, - BROTLI_OPERATION_FINISH, - &avail_in, &next_in, - &avail_out, &next_out, - nullptr)) - throw CompressionError("error while finishing brotli file"); - - finished = BrotliEncoderIsFinished(state); - if (avail_out == 0 || finished) { - nextSink(outbuf, sizeof(outbuf) - avail_out); - next_out = outbuf; - avail_out = sizeof(outbuf); - } - } + writeInternal(nullptr, 0); } - void write(const unsigned char * data, size_t len) override + void writeInternal(const unsigned char * data, size_t len) override { - // Don't feed brotli too much at once - const size_t CHUNK_SIZE = sizeof(outbuf) << 2; - while (len) { - size_t n = std::min(CHUNK_SIZE, len); - writeInternal(data, n); - data += n; - len -= n; - } - } - - void writeInternal(const unsigned char * data, size_t len) - { - assert(!finished); - - const uint8_t *next_in = data; + const uint8_t * next_in = data; size_t avail_in = len; - uint8_t *next_out = outbuf; + uint8_t * next_out = outbuf; size_t avail_out = sizeof(outbuf); - while (avail_in > 0) { + while (!finished && (!data || avail_in)) { checkInterrupt(); if (!BrotliEncoderCompressStream(state, - BROTLI_OPERATION_PROCESS, - 
&avail_in, &next_in, - &avail_out, &next_out, - nullptr)) - throw CompressionError("error while compressing brotli file"); + data ? BROTLI_OPERATION_PROCESS : BROTLI_OPERATION_FINISH, + &avail_in, &next_in, + &avail_out, &next_out, + nullptr)) + throw CompressionError("error while compressing brotli compression"); if (avail_out < sizeof(outbuf) || avail_in == 0) { nextSink(outbuf, sizeof(outbuf) - avail_out); next_out = outbuf; avail_out = sizeof(outbuf); } + + finished = BrotliEncoderIsFinished(state); } } }; -#endif // HAVE_BROTLI ref<CompressionSink> makeCompressionSink(const std::string & method, Sink & nextSink, const bool parallel) { - if (parallel) { -#ifdef HAVE_LZMA_MT - if (method == "xz") - return make_ref<ParallelXzSink>(nextSink); -#endif - printMsg(lvlError, format("Warning: parallel compression requested but not supported for method '%1%', falling back to single-threaded compression") % method); - } - if (method == "none") return make_ref<NoneSink>(nextSink); else if (method == "xz") - return make_ref<XzSink>(nextSink); + return make_ref<XzCompressionSink>(nextSink, parallel); else if (method == "bzip2") - return make_ref<BzipSink>(nextSink); + return make_ref<BzipCompressionSink>(nextSink); else if (method == "br") -#if HAVE_BROTLI - return make_ref<BrotliSink>(nextSink); -#else - return make_ref<BrotliCmdSink>(nextSink); -#endif + return make_ref<BrotliCompressionSink>(nextSink); else throw UnknownCompressionMethod(format("unknown compression method '%s'") % method); } diff --git a/src/libutil/compression.hh b/src/libutil/compression.hh index f7a3e3fbd32e..dd666a4e19fd 100644 --- a/src/libutil/compression.hh +++ b/src/libutil/compression.hh @@ -8,17 +8,17 @@ namespace nix { -ref<std::string> decompress(const std::string & method, const std::string & in); - -void decompress(const std::string & method, Source & source, Sink & sink); - -ref<std::string> compress(const std::string & method, const std::string & in, const bool parallel = false); - struct 
CompressionSink : BufferedSink { virtual void finish() = 0; }; +ref<std::string> decompress(const std::string & method, const std::string & in); + +ref<CompressionSink> makeDecompressionSink(const std::string & method, Sink & nextSink); + +ref<std::string> compress(const std::string & method, const std::string & in, const bool parallel = false); + ref<CompressionSink> makeCompressionSink(const std::string & method, Sink & nextSink, const bool parallel = false); MakeError(UnknownCompressionMethod, Error); diff --git a/src/libutil/json.cc b/src/libutil/json.cc index 813b257016e4..0a6fb65f0605 100644 --- a/src/libutil/json.cc +++ b/src/libutil/json.cc @@ -31,6 +31,7 @@ template<> void toJSON<unsigned long>(std::ostream & str, const unsigned long & template<> void toJSON<long long>(std::ostream & str, const long long & n) { str << n; } template<> void toJSON<unsigned long long>(std::ostream & str, const unsigned long long & n) { str << n; } template<> void toJSON<float>(std::ostream & str, const float & n) { str << n; } +template<> void toJSON<double>(std::ostream & str, const double & n) { str << n; } template<> void toJSON<std::string>(std::ostream & str, const std::string & s) { diff --git a/src/libutil/serialise.cc b/src/libutil/serialise.cc index 21803edd056a..17448f70efb6 100644 --- a/src/libutil/serialise.cc +++ b/src/libutil/serialise.cc @@ -157,16 +157,24 @@ size_t StringSource::read(unsigned char * data, size_t len) } -std::unique_ptr<Source> sinkToSource(std::function<void(Sink &)> fun) +#if BOOST_VERSION >= 106300 && BOOST_VERSION < 106600 +#error Coroutines are broken in this version of Boost! 
+#endif + +std::unique_ptr<Source> sinkToSource( + std::function<void(Sink &)> fun, + std::function<void()> eof) { struct SinkToSource : Source { typedef boost::coroutines2::coroutine<std::string> coro_t; + std::function<void()> eof; coro_t::pull_type coro; - SinkToSource(std::function<void(Sink &)> fun) - : coro([&](coro_t::push_type & yield) { + SinkToSource(std::function<void(Sink &)> fun, std::function<void()> eof) + : eof(eof) + , coro([&](coro_t::push_type & yield) { LambdaSink sink([&](const unsigned char * data, size_t len) { if (len) yield(std::string((const char *) data, len)); }); @@ -180,8 +188,7 @@ std::unique_ptr<Source> sinkToSource(std::function<void(Sink &)> fun) size_t read(unsigned char * data, size_t len) override { - if (!coro) - throw EndOfFile("coroutine has finished"); + if (!coro) { eof(); abort(); } if (pos == cur.size()) { if (!cur.empty()) coro(); @@ -197,7 +204,7 @@ std::unique_ptr<Source> sinkToSource(std::function<void(Sink &)> fun) } }; - return std::make_unique<SinkToSource>(fun); + return std::make_unique<SinkToSource>(fun, eof); } diff --git a/src/libutil/serialise.hh b/src/libutil/serialise.hh index 14b62fdb6774..4b6ad5da5b9c 100644 --- a/src/libutil/serialise.hh +++ b/src/libutil/serialise.hh @@ -214,7 +214,11 @@ struct LambdaSource : Source /* Convert a function that feeds data into a Sink into a Source. The Source executes the function as a coroutine. 
*/ -std::unique_ptr<Source> sinkToSource(std::function<void(Sink &)> fun); +std::unique_ptr<Source> sinkToSource( + std::function<void(Sink &)> fun, + std::function<void()> eof = []() { + throw EndOfFile("coroutine has finished"); + }); void writePadding(size_t len, Sink & sink); diff --git a/src/libutil/sync.hh b/src/libutil/sync.hh index 3b2710f6fd30..e1d591d77a84 100644 --- a/src/libutil/sync.hh +++ b/src/libutil/sync.hh @@ -57,11 +57,11 @@ public: } template<class Rep, class Period> - void wait_for(std::condition_variable & cv, + std::cv_status wait_for(std::condition_variable & cv, const std::chrono::duration<Rep, Period> & duration) { assert(s); - cv.wait_for(lk, duration); + return cv.wait_for(lk, duration); } template<class Rep, class Period, class Predicate> diff --git a/src/nix-build/nix-build.cc b/src/nix-build/nix-build.cc index de0e9118fd21..94d3a27560fe 100755 --- a/src/nix-build/nix-build.cc +++ b/src/nix-build/nix-build.cc @@ -85,7 +85,6 @@ void mainWrapped(int argc, char * * argv) BuildMode buildMode = bmNormal; bool readStdin = false; - auto shell = getEnv("SHELL", "/bin/sh"); std::string envCommand; // interactive shell Strings envExclude; @@ -99,6 +98,9 @@ void mainWrapped(int argc, char * * argv) std::string outLink = "./result"; + // List of environment variables kept for --pure + std::set<string> keepVars{"HOME", "USER", "LOGNAME", "DISPLAY", "PATH", "TERM", "IN_NIX_SHELL", "TZ", "PAGER", "NIX_BUILD_SHELL", "SHLVL"}; + Strings args; for (int i = 1; i < argc; ++i) args.push_back(argv[i]); @@ -218,6 +220,9 @@ void mainWrapped(int argc, char * * argv) } } + else if (*arg == "--keep") + keepVars.insert(getArg(*arg, arg, end)); + else if (*arg == "-") readStdin = true; @@ -300,6 +305,8 @@ void mainWrapped(int argc, char * * argv) } } + state->printStats(); + auto buildPaths = [&](const PathSet & paths) { /* Note: we do this even when !printMissing to efficiently fetch binary cache data. 
*/ @@ -368,7 +375,6 @@ void mainWrapped(int argc, char * * argv) auto tmp = getEnv("TMPDIR", getEnv("XDG_RUNTIME_DIR", "/tmp")); if (pure) { - std::set<string> keepVars{"HOME", "USER", "LOGNAME", "DISPLAY", "PATH", "TERM", "IN_NIX_SHELL", "TZ", "PAGER", "NIX_BUILD_SHELL", "SHLVL"}; decltype(env) newEnv; for (auto & i : env) if (keepVars.count(i.first)) @@ -415,7 +421,6 @@ void mainWrapped(int argc, char * * argv) R"s([ -n "$PS1" ] && PS1='\n\[\033[1;32m\][nix-shell:\w]\$\[\033[0m\] '; )s" "if [ \"$(type -t runHook)\" = function ]; then runHook shellHook; fi; " "unset NIX_ENFORCE_PURITY; " - "unset NIX_INDENT_MAKE; " "shopt -u nullglob; " "unset TZ; %4%" "%5%", diff --git a/src/nix-daemon/nix-daemon.cc b/src/nix-daemon/nix-daemon.cc index 423e6bb67893..644fa6681de3 100644 --- a/src/nix-daemon/nix-daemon.cc +++ b/src/nix-daemon/nix-daemon.cc @@ -233,7 +233,7 @@ struct RetrieveRegularNARSink : ParseSink }; -static void performOp(TunnelLogger * logger, ref<LocalStore> store, +static void performOp(TunnelLogger * logger, ref<Store> store, bool trusted, unsigned int clientVersion, Source & from, Sink & to, unsigned int op) { @@ -362,7 +362,11 @@ static void performOp(TunnelLogger * logger, ref<LocalStore> store, logger->startWork(); if (!savedRegular.regular) throw Error("regular file expected"); - Path path = store->addToStoreFromDump(recursive ? *savedNAR.data : savedRegular.s, baseName, recursive, hashAlgo); + + auto store2 = store.dynamic_pointer_cast<LocalStore>(); + if (!store2) throw Error("operation is only supported by LocalStore"); + + Path path = store2->addToStoreFromDump(recursive ? *savedNAR.data : savedRegular.s, baseName, recursive, hashAlgo); logger->stopWork(); to << path; @@ -703,6 +707,7 @@ static void performOp(TunnelLogger * logger, ref<LocalStore> store, logger->startWork(); + // FIXME: race if addToStore doesn't read source? store.cast<Store>()->addToStore(info, *source, (RepairFlag) repair, dontCheckSigs ? 
NoCheckSigs : CheckSigs, nullptr); @@ -776,7 +781,7 @@ static void processConnection(bool trusted) Store::Params params; // FIXME: get params from somewhere // Disable caching since the client already does that. params["path-info-cache-size"] = "0"; - auto store = make_ref<LocalStore>(params); + auto store = openStore(settings.storeUri, params); tunnelLogger->stopWork(); to.flush(); diff --git a/src/nix-prefetch-url/nix-prefetch-url.cc b/src/nix-prefetch-url/nix-prefetch-url.cc index 50b2c2803ec9..a3b025723cf1 100644 --- a/src/nix-prefetch-url/nix-prefetch-url.cc +++ b/src/nix-prefetch-url/nix-prefetch-url.cc @@ -9,6 +9,10 @@ #include <iostream> +#include <sys/types.h> +#include <sys/stat.h> +#include <fcntl.h> + using namespace nix; @@ -160,14 +164,20 @@ int main(int argc, char * * argv) auto actualUri = resolveMirrorUri(*state, uri); - /* Download the file. */ - DownloadRequest req(actualUri); - req.decompress = false; - auto result = getDownloader()->download(req); - AutoDelete tmpDir(createTempDir(), true); Path tmpFile = (Path) tmpDir + "/tmp"; - writeFile(tmpFile, *result.data); + + /* Download the file. */ + { + AutoCloseFD fd = open(tmpFile.c_str(), O_WRONLY | O_CREAT | O_EXCL, 0600); + if (!fd) throw SysError("creating temporary file '%s'", tmpFile); + + FdSink sink(fd.get()); + + DownloadRequest req(actualUri); + req.decompress = false; + getDownloader()->download(std::move(req), sink); + } /* Optionally unpack the file. */ if (unpack) { @@ -191,7 +201,7 @@ int main(int argc, char * * argv) /* FIXME: inefficient; addToStore() will also hash this. */ - hash = unpack ? hashPath(ht, tmpFile).first : hashString(ht, *result.data); + hash = unpack ? 
hashPath(ht, tmpFile).first : hashFile(ht, tmpFile); if (expectedHash != Hash(ht) && expectedHash != hash) throw Error(format("hash mismatch for '%1%'") % uri); diff --git a/src/nix-store/nix-store.cc b/src/nix-store/nix-store.cc index e1e27ceef94d..fe68f681ae28 100644 --- a/src/nix-store/nix-store.cc +++ b/src/nix-store/nix-store.cc @@ -860,7 +860,7 @@ static void opServe(Strings opFlags, Strings opArgs) } case cmdDumpStorePath: - dumpPath(readStorePath(*store, in), out); + store->narFromPath(readStorePath(*store, in), out); break; case cmdImportPaths: { @@ -924,6 +924,28 @@ static void opServe(Strings opFlags, Strings opArgs) break; } + case cmdAddToStoreNar: { + if (!writeAllowed) throw Error("importing paths is not allowed"); + + ValidPathInfo info; + info.path = readStorePath(*store, in); + in >> info.deriver; + if (!info.deriver.empty()) + store->assertStorePath(info.deriver); + info.narHash = Hash(readString(in), htSHA256); + info.references = readStorePaths<PathSet>(*store, in); + in >> info.registrationTime >> info.narSize >> info.ultimate; + info.sigs = readStrings<StringSet>(in); + in >> info.ca; + + // FIXME: race if addToStore doesn't read source? 
+ store->addToStore(info, in, NoRepair, NoCheckSigs); + + out << 1; // indicate success + + break; + } + default: throw Error(format("unknown serve command %1%") % cmd); } diff --git a/src/nix/copy.cc b/src/nix/copy.cc index e4e6c3e303ed..91711c8b46da 100644 --- a/src/nix/copy.cc +++ b/src/nix/copy.cc @@ -72,6 +72,10 @@ struct CmdCopy : StorePathsCommand "To populate the current folder build output to a S3 binary cache:", "nix copy --to s3://my-bucket?region=eu-west-1" }, + Example{ + "To populate the current folder build output to an S3-compatible binary cache:", + "nix copy --to s3://my-bucket?region=eu-west-1&endpoint=example.com" + }, #endif }; } diff --git a/src/nix/installables.cc b/src/nix/installables.cc index 0be992b03c5a..0c1ad3ab3db0 100644 --- a/src/nix/installables.cc +++ b/src/nix/installables.cc @@ -96,7 +96,7 @@ struct InstallableStorePath : Installable Buildables toBuildables() override { - return {{"", {{"out", storePath}}}}; + return {{isDerivation(storePath) ? storePath : "", {{"out", storePath}}}}; } }; diff --git a/src/nix/main.cc b/src/nix/main.cc index 9cd5d21c84b6..69791e223c22 100644 --- a/src/nix/main.cc +++ b/src/nix/main.cc @@ -24,7 +24,6 @@ struct NixArgs : virtual MultiCommand, virtual MixCommonArgs { mkFlag() .longName("help") - .shortName('h') .description("show usage information") .handler([&]() { showHelpAndExit(); }); diff --git a/src/nix/path-info.cc b/src/nix/path-info.cc index 47caa401d3c9..dea5f0557b81 100644 --- a/src/nix/path-info.cc +++ b/src/nix/path-info.cc @@ -4,8 +4,8 @@ #include "json.hh" #include "common-args.hh" -#include <iomanip> #include <algorithm> +#include <array> using namespace nix; @@ -13,12 +13,14 @@ struct CmdPathInfo : StorePathsCommand, MixJSON { bool showSize = false; bool showClosureSize = false; + bool humanReadable = false; bool showSigs = false; CmdPathInfo() { mkFlag('s', "size", "print size of the NAR dump of each path", &showSize); mkFlag('S', "closure-size", "print sum size of the NAR dumps of 
the closure of each path", &showClosureSize); + mkFlag('h', "human-readable", "with -s and -S, print sizes like 1K 234M 5.67G etc.", &humanReadable); mkFlag(0, "sigs", "show signatures", &showSigs); } @@ -40,6 +42,10 @@ struct CmdPathInfo : StorePathsCommand, MixJSON "nix path-info -rS /run/current-system | sort -nk2" }, Example{ + "To show a package's closure size and all its dependencies with human readable sizes:", + "nix path-info -rsSh nixpkgs.rust" + }, + Example{ "To check the existence of a path in a binary cache:", "nix path-info -r /nix/store/7qvk5c91...-geeqie-1.1 --store https://cache.nixos.org/" }, @@ -58,6 +64,25 @@ struct CmdPathInfo : StorePathsCommand, MixJSON }; } + void printSize(unsigned long long value) + { + if (!humanReadable) { + std::cout << fmt("\t%11d", value); + return; + } + + static const std::array<char, 9> idents{{ + ' ', 'K', 'M', 'G', 'T', 'P', 'E', 'Z', 'Y' + }}; + size_t power = 0; + double res = value; + while (res > 1024 && power < idents.size()) { + ++power; + res /= 1024; + } + std::cout << fmt("\t%6.1f%c", res, idents.at(power)); + } + void run(ref<Store> store, Paths storePaths) override { size_t pathLen = 0; @@ -78,13 +103,16 @@ struct CmdPathInfo : StorePathsCommand, MixJSON auto info = store->queryPathInfo(storePath); storePath = info->path; // FIXME: screws up padding - std::cout << storePath << std::string(std::max(0, (int) pathLen - (int) storePath.size()), ' '); + std::cout << storePath; + + if (showSize || showClosureSize || showSigs) + std::cout << std::string(std::max(0, (int) pathLen - (int) storePath.size()), ' '); if (showSize) - std::cout << '\t' << std::setw(11) << info->narSize; + printSize(info->narSize); if (showClosureSize) - std::cout << '\t' << std::setw(11) << store->getClosureSize(storePath).first; + printSize(store->getClosureSize(storePath).first); if (showSigs) { std::cout << '\t'; diff --git a/src/nix/repl.cc b/src/nix/repl.cc index 4723a1974b77..b71e6f905f23 100644 --- a/src/nix/repl.cc +++ 
b/src/nix/repl.cc @@ -173,9 +173,14 @@ void NixRepl::mainLoop(const std::vector<std::string> & files) printMsg(lvlError, format(error + "%1%%2%") % (settings.showTrace ? e.prefix() : "") % e.msg()); } + if (input.size() > 0) { + // Remove trailing newline before adding to history + input.erase(input.size() - 1); + linenoiseHistoryAdd(input.c_str()); + } + // We handled the current input fully, so we should clear it // and read brand new input. - linenoiseHistoryAdd(input.c_str()); input.clear(); std::cout << std::endl; } @@ -385,7 +390,7 @@ bool NixRepl::processLine(string line) /* We could do the build in this process using buildPaths(), but doing it in a child makes it easier to recover from problems / SIGINT. */ - if (runProgram(settings.nixBinDir + "/nix-store", Strings{"-r", drvPath}) == 0) { + if (runProgram(settings.nixBinDir + "/nix", Strings{"build", drvPath}) == 0) { Derivation drv = readDerivation(drvPath); std::cout << std::endl << "this derivation produced the following outputs:" << std::endl; for (auto & i : drv.outputs) diff --git a/src/nix/run.cc b/src/nix/run.cc index d04e106e037b..35b763345872 100644 --- a/src/nix/run.cc +++ b/src/nix/run.cc @@ -7,11 +7,14 @@ #include "finally.hh" #include "fs-accessor.hh" #include "progress-bar.hh" +#include "affinity.hh" #if __linux__ #include <sys/mount.h> #endif +#include <queue> + using namespace nix; std::string chrootHelperName = "__run_in_chroot"; @@ -121,10 +124,27 @@ struct CmdRun : InstallablesCommand unsetenv(var.c_str()); } + std::unordered_set<Path> done; + std::queue<Path> todo; + for (auto & path : outPaths) todo.push(path); + auto unixPath = tokenizeString<Strings>(getEnv("PATH"), ":"); - for (auto & path : outPaths) - if (accessor->stat(path + "/bin").type != FSAccessor::tMissing) + + while (!todo.empty()) { + Path path = todo.front(); + todo.pop(); + if (!done.insert(path).second) continue; + + if (true) unixPath.push_front(path + "/bin"); + + auto propPath = path + 
"/nix-support/propagated-user-env-packages"; + if (accessor->stat(propPath).type == FSAccessor::tRegular) { + for (auto & p : tokenizeString<Paths>(readFile(propPath))) + todo.push(p); + } + } + setenv("PATH", concatStringsSep(":", unixPath).c_str(), 1); std::string cmd = *command.begin(); @@ -135,6 +155,8 @@ struct CmdRun : InstallablesCommand restoreSignals(); + restoreAffinity(); + /* If this is a diverted store (i.e. its "logical" location (typically /nix/store) differs from its "physical" location (e.g. /home/eelco/nix/store), then run the command in a diff --git a/src/nix/search.cc b/src/nix/search.cc index 539676698086..4cb1efa7955b 100644 --- a/src/nix/search.cc +++ b/src/nix/search.cc @@ -7,19 +7,25 @@ #include "common-args.hh" #include "json.hh" #include "json-to-value.hh" +#include "shared.hh" #include <regex> #include <fstream> using namespace nix; -std::string hilite(const std::string & s, const std::smatch & m) +std::string wrap(std::string prefix, std::string s) +{ + return prefix + s + ANSI_NORMAL; +} + +std::string hilite(const std::string & s, const std::smatch & m, std::string postfix) { return m.empty() ? 
s : std::string(m.prefix()) - + ANSI_RED + std::string(m.str()) + ANSI_NORMAL + + ANSI_RED + std::string(m.str()) + postfix + std::string(m.suffix()); } @@ -75,6 +81,10 @@ struct CmdSearch : SourceExprCommand, MixJSON "To search for git and frontend or gui:", "nix search git 'frontend|gui'" }, + Example{ + "To display the description of the found packages:", + "nix search git --verbose" + } }; } @@ -164,14 +174,10 @@ struct CmdSearch : SourceExprCommand, MixJSON } else { results[attrPath] = fmt( - "Attribute name: %s\n" - "Package name: %s\n" - "Version: %s\n" - "Description: %s\n", - hilite(attrPath, attrPathMatch), - hilite(name, nameMatch), - parsed.version, - hilite(description, descriptionMatch)); + "* %s (%s)\n %s\n", + wrap("\e[0;1m", hilite(attrPath, attrPathMatch, "\e[0;1m")), + wrap("\e[0;2m", hilite(parsed.fullName, nameMatch, "\e[0;2m")), + hilite(description, descriptionMatch, ANSI_NORMAL)); } } @@ -263,6 +269,10 @@ struct CmdSearch : SourceExprCommand, MixJSON throw SysError("cannot rename '%s' to '%s'", tmpFile, jsonCacheFileName); } + if (results.size() == 0) + throw Error("no results for the given search term(s)!"); + + RunPager pager; for (auto el : results) std::cout << el.second << "\n"; } diff --git a/src/nix/upgrade-nix.cc b/src/nix/upgrade-nix.cc index e23ae792369c..35c44a70cf52 100644 --- a/src/nix/upgrade-nix.cc +++ b/src/nix/upgrade-nix.cc @@ -1,14 +1,18 @@ #include "command.hh" +#include "common-args.hh" #include "store-api.hh" #include "download.hh" #include "eval.hh" #include "attr-path.hh" +#include "names.hh" +#include "progress-bar.hh" using namespace nix; -struct CmdUpgradeNix : StoreCommand +struct CmdUpgradeNix : MixDryRun, StoreCommand { Path profileDir; + std::string storePathsUrl = "https://github.com/NixOS/nixpkgs/raw/master/nixos/modules/installer/tools/nix-fallback-paths.nix"; CmdUpgradeNix() { @@ -18,6 +22,12 @@ struct CmdUpgradeNix : StoreCommand .labels({"profile-dir"}) .description("the Nix profile to upgrade") 
.dest(&profileDir); + + mkFlag() + .longName("nix-store-paths-url") + .labels({"url"}) + .description("URL of the file that contains the store paths of the latest Nix release") + .dest(&storePathsUrl); } std::string name() override @@ -59,6 +69,14 @@ struct CmdUpgradeNix : StoreCommand storePath = getLatestNix(store); } + auto version = DrvName(storePathToName(storePath)).version; + + if (dryRun) { + stopProgressBar(); + printError("would upgrade to version %s", version); + return; + } + { Activity act(*logger, lvlInfo, actUnknown, fmt("downloading '%s'...", storePath)); store->ensurePath(storePath); @@ -72,11 +90,15 @@ struct CmdUpgradeNix : StoreCommand throw Error("could not verify that '%s' works", program); } + stopProgressBar(); + { Activity act(*logger, lvlInfo, actUnknown, fmt("installing '%s' into profile '%s'...", storePath, profileDir)); runProgram(settings.nixBinDir + "/nix-env", false, {"--profile", profileDir, "-i", storePath, "--no-sandbox"}); } + + printError(ANSI_GREEN "upgrade to version %s done" ANSI_NORMAL, version); } /* Return the profile in which Nix is installed. */ @@ -98,11 +120,18 @@ struct CmdUpgradeNix : StoreCommand if (hasPrefix(where, "/run/current-system")) throw Error("Nix on NixOS must be upgraded via 'nixos-rebuild'"); - Path profileDir; - Path userEnv; + Path profileDir = dirOf(where); + + // Resolve profile to /nix/var/nix/profiles/<name> link. 
+ while (canonPath(profileDir).find("/profiles/") == std::string::npos && isLink(profileDir)) + profileDir = readLink(profileDir); + + printInfo("found profile '%s'", profileDir); + + Path userEnv = canonPath(profileDir, true); if (baseNameOf(where) != "bin" || - !hasSuffix(userEnv = canonPath(profileDir = dirOf(where), true), "user-environment")) + !hasSuffix(userEnv, "user-environment")) throw Error("directory '%s' does not appear to be part of a Nix profile", where); if (!store->isValidPath(userEnv)) @@ -115,7 +144,7 @@ struct CmdUpgradeNix : StoreCommand Path getLatestNix(ref<Store> store) { // FIXME: use nixos.org? - auto req = DownloadRequest("https://github.com/NixOS/nixpkgs/raw/master/nixos/modules/installer/tools/nix-fallback-paths.nix"); + auto req = DownloadRequest(storePathsUrl); auto res = getDownloader()->download(req); auto state = std::make_unique<EvalState>(Strings(), store); diff --git a/src/nix/why-depends.cc b/src/nix/why-depends.cc index 17e0595ae887..325a2be0a793 100644 --- a/src/nix/why-depends.cc +++ b/src/nix/why-depends.cc @@ -2,6 +2,7 @@ #include "store-api.hh" #include "progress-bar.hh" #include "fs-accessor.hh" +#include "shared.hh" #include <queue> @@ -237,6 +238,7 @@ struct CmdWhyDepends : SourceExprCommand visitPath(node.path); + RunPager pager; for (auto & ref : refs) { auto hash = storePathToHash(ref.second->path); |