#include "libutil/serialise.hh"

#include <boost/coroutine2/coroutine.hpp>
#include <cerrno>
#include <cstring>
#include <memory>
#include <utility>

#include <glog/logging.h>

#include "libutil/util.hh"

namespace nix {

/* Accept `len` bytes starting at `data`. The staging buffer is allocated
   lazily on first use. Input that would reach or exceed the end of the
   buffer causes the staged bytes to be flushed and the input to be passed
   straight to write(); anything smaller is copied into the buffer. */
void BufferedSink::operator()(const unsigned char* data, size_t len) {
  if (!buffer) {
    buffer = decltype(buffer)(new unsigned char[bufSize]);
  }

  while (len > 0) {
    /* Optimisation: bypass the buffer if the data would fill it up or
       exceed it. Flush first so that output order is preserved. */
    const bool staysBelowCapacity = bufPos + len < bufSize;
    if (!staysBelowCapacity) {
      flush();
      write(data, len);
      return;
    }

    /* Otherwise stage as many bytes as fit, and flush once the buffer
       is full. */
    size_t chunk = len;
    if (bufPos + len > bufSize) {
      chunk = bufSize - bufPos;
    }
    memcpy(buffer.get() + bufPos, data, chunk);
    bufPos += chunk;
    data += chunk;
    len -= chunk;

    if (bufPos == bufSize) {
      flush();
    }
  }
}

/* Pass any staged bytes to write() and mark the buffer as empty. Does
   nothing when the buffer holds no data. */
void BufferedSink::flush() {
  if (bufPos != 0) {
    /* Reset bufPos before calling write(): don't trigger the assert()
       in ~BufferedSink(). */
    const size_t pending = bufPos;
    bufPos = 0;
    write(buffer.get(), pending);
  }
}

/* Destructor: attempt a final flush() of any buffered data. Any exception
   raised while flushing is caught and handed to ignoreException() instead
   of propagating out of the destructor. */
FdSink::~FdSink() {
  try {
    flush();
  } catch (...) {
    ignoreException();
  }
}

/* Size in bytes (256 MiB) above which the "very large path" warning is
   emitted. Compared against the running byte count in FdSink::write() and
   against the accumulated string size in StringSink::operator(). */
size_t threshold = 256 * 1024 * 1024;

/* Log a warning that a very large path is being dumped. Callers guard
   this with their own `warned` flag. */
static void warnLargeDump() {
  LOG(WARNING)
      << "dumping very large path (> 256 MiB); this may run out of memory";
}

/* Write `len` bytes from `data` to the file descriptor via writeFull().
   Keeps a running total in `written` and, when `warn` is set, emits the
   large-dump warning the first time that total exceeds `threshold`.
   If writeFull() throws SysError, `_good` is cleared and the exception is
   rethrown. */
void FdSink::write(const unsigned char* data, size_t len) {
  /* Shared by every FdSink: the warning is issued at most once. */
  static bool warned = false;

  written += len;
  if (warn && !warned && written > threshold) {
    warnLargeDump();
    warned = true;
  }

  try {
    writeFull(fd, data, len);
  } catch (SysError&) {
    _good = false;
    throw;
  }
}

/* Report the `_good` flag, which write() clears when writeFull() throws
   SysError. */
bool FdSink::good() { return _good; }

/* Fill `data` with exactly `len` bytes by calling read() repeatedly,
   advancing past whatever each call delivered. The loop only ends once
   all bytes have arrived or read() throws. */
void Source::operator()(unsigned char* data, size_t len) {
  for (size_t got = 0; len != 0u; data += got, len -= got) {
    got = read(data, len);
  }
}

/* Read from this source until read() throws EndOfFile and return
   everything collected so far as one string. Other exceptions propagate
   to the caller. */
std::string Source::drain() {
  std::string collected;
  std::vector<unsigned char> chunk(8192);

  try {
    for (;;) {
      const size_t got = read(chunk.data(), chunk.size());
      collected.append(reinterpret_cast<char*>(chunk.data()), got);
    }
  } catch (EndOfFile&) {
    /* End of input: fall through and return what was collected. */
  }

  return collected;
}

/* Copy up to `len` buffered bytes into `data` and return how many were
   copied. The buffer is allocated lazily; when it is empty it is refilled
   with one readUnbuffered() call first. */
size_t BufferedSource::read(unsigned char* data, size_t len) {
  if (!buffer) {
    buffer = decltype(buffer)(new unsigned char[bufSize]);
  }

  /* Nothing buffered: fetch a fresh batch from the underlying source. */
  if (bufPosIn == 0u) {
    bufPosIn = readUnbuffered(buffer.get(), bufSize);
  }

  /* Hand out the smaller of the request and what is still buffered. */
  size_t count = bufPosIn - bufPosOut;
  if (len < count) {
    count = len;
  }
  memcpy(data, buffer.get() + bufPosOut, count);
  bufPosOut += count;

  /* Everything consumed: rewind so the next call refills the buffer. */
  if (bufPosOut == bufPosIn) {
    bufPosOut = 0;
    bufPosIn = 0;
  }
  return count;
}

/* True while the buffer still holds bytes that read() has not yet handed
   out. */
bool BufferedSource::hasData() {
  return bufPosIn > bufPosOut;
}

/* Read up to `len` bytes from the file descriptor into `data` with a
   single successful ::read() call, retrying while the call fails with
   EINTR (checkInterrupt() runs before every attempt). Returns the number
   of bytes read and adds it to the `read` counter.
   Throws SysError on a read error and EndOfFile when ::read() returns 0;
   `_good` is cleared in both cases. */
size_t FdSource::readUnbuffered(unsigned char* data, size_t len) {
  ssize_t got = 0;
  for (;;) {
    checkInterrupt();
    got = ::read(fd, reinterpret_cast<char*>(data), len);
    if (got != -1 || errno != EINTR) {
      break;
    }
  }

  if (got == -1 || got == 0) {
    _good = false;
    if (got == -1) {
      throw SysError("reading from file");
    }
    throw EndOfFile("unexpected end-of-file");
  }

  read += got;
  return got;
}

/* Report the `_good` flag, which readUnbuffered() clears on a read error
   or at end-of-file. */
bool FdSource::good() { return _good; }

/* Copy up to `len` characters of the underlying string, starting at the
   current position, into `data`. Advances the position and returns the
   number of bytes copied. Throws EndOfFile once the whole string has been
   consumed. */
size_t StringSource::read(unsigned char* data, size_t len) {
  if (pos == s.size()) {
    throw EndOfFile("end of string reached");
  }

  char* const out = reinterpret_cast<char*>(data);
  const size_t copied = s.copy(out, len, pos);
  pos += copied;
  return copied;
}

#if BOOST_VERSION >= 106300 && BOOST_VERSION < 106600
#error Coroutines are broken in this version of Boost!
#endif

/* Turn a function that pushes data into a Sink into a Source that can be
   pulled from.

   `fun` is run inside a Boost coroutine. Every non-empty write it makes to
   the Sink it is given is yielded as one chunk, and read() on the returned
   Source hands those chunks out piece by piece, resuming the coroutine
   whenever the current chunk has been consumed.

   `eof` is invoked when read() is called and `fun` has no more data to
   deliver. It is expected not to return (e.g. by throwing); if it does
   return, the process is aborted. */
std::unique_ptr<Source> sinkToSource(const std::function<void(Sink&)>& fun,
                                     const std::function<void()>& eof) {
  struct SinkToSource : Source {
    using coro_t = boost::coroutines2::coroutine<std::string>;

    std::function<void(Sink&)> fun;
    std::function<void()> eof;
    /* Created lazily on the first read(). */
    std::optional<coro_t::pull_type> coro;

    SinkToSource(std::function<void(Sink&)> fun, std::function<void()> eof)
        : fun(std::move(fun)), eof(std::move(eof)) {}

    /* The chunk currently being handed out, and how much of it has
       already been consumed. */
    std::string cur;
    size_t pos = 0;

    size_t read(unsigned char* data, size_t len) override {
      if (!coro) {
        coro = coro_t::pull_type([&](coro_t::push_type& yield) {
          LambdaSink sink([&](const unsigned char* data, size_t len) {
            /* Empty writes are dropped so that every yielded chunk is
               non-empty. */
            if (len != 0u) {
              yield(std::string(reinterpret_cast<const char*>(data), len));
            }
          });
          fun(sink);
        });
      }

      /* The producer has already finished (possibly without ever
         writing anything). */
      if (!*coro) {
        eof();
        abort();
      }

      if (pos == cur.size()) {
        if (!cur.empty()) {
          /* The current chunk is used up: resume the producer. */
          (*coro)();

          /* The producer may have returned instead of yielding another
             chunk. In that case there is no new value to get(); report
             end-of-data rather than reading from a finished coroutine. */
          if (!*coro) {
            eof();
            abort();
          }
        }
        cur = coro->get();
        pos = 0;
      }

      auto n = std::min(cur.size() - pos, len);
      memcpy(data, cur.data() + pos, n);
      pos += n;

      return n;
    }
  };

  return std::make_unique<SinkToSource>(fun, eof);
}

/* Emit the zero bytes needed to round `len` up to the next multiple of
   8. Writes nothing when `len` is already a multiple of 8. */
void writePadding(size_t len, Sink& sink) {
  const size_t remainder = len % 8;
  if (remainder == 0u) {
    return;
  }

  unsigned char zero[8] = {};
  sink(zero, 8 - remainder);
}

/* Serialise a byte string to `sink`: first the length (via operator<<),
   then the `len` bytes at `buf`, then zero padding up to a multiple of 8
   bytes (see writePadding()). */
void writeString(const unsigned char* buf, size_t len, Sink& sink) {
  sink << len;
  sink(buf, len);
  writePadding(len, sink);
}

/* Serialise a std::string to `sink` using writeString(). Returns `sink`
   so that writes can be chained. */
Sink& operator<<(Sink& sink, const std::string& s) {
  const auto* bytes = reinterpret_cast<const unsigned char*>(s.data());
  writeString(bytes, s.size(), sink);
  return sink;
}

/* Serialise a container of strings: the element count first, followed by
   each element in iteration order. */
template <class T>
void writeStrings(const T& ss, Sink& sink) {
  sink << ss.size();
  for (const auto& item : ss) {
    sink << item;
  }
}

/* Serialise a Strings container with writeStrings(). Returns `sink` so
   that writes can be chained. */
Sink& operator<<(Sink& sink, const Strings& s) {
  writeStrings(s, sink);
  return sink;
}

/* Serialise a StringSet with writeStrings(). Returns `sink` so that
   writes can be chained. */
Sink& operator<<(Sink& sink, const StringSet& s) {
  writeStrings(s, sink);
  return sink;
}

/* Consume the padding that follows a `len`-byte payload (enough bytes to
   reach a multiple of 8) and verify that every padding byte is zero.
   Throws SerialisationError on a non-zero padding byte. */
void readPadding(size_t len, Source& source) {
  const size_t remainder = len % 8;
  if (remainder == 0u) {
    return;
  }

  const size_t padLen = 8 - remainder;
  unsigned char pad[8];
  source(pad, padLen);

  for (size_t i = 0; i < padLen; ++i) {
    if (pad[i] != 0u) {
      throw SerialisationError("non-zero padding");
    }
  }
}

/* Read a length-prefixed string into the caller-supplied buffer `buf`,
   consume its padding, and return the string length. Throws
   SerialisationError when the announced length exceeds `max`; nothing is
   written to `buf` in that case. */
size_t readString(unsigned char* buf, size_t max, Source& source) {
  const auto len = readNum<size_t>(source);
  if (len <= max) {
    source(buf, len);
    readPadding(len, source);
    return len;
  }
  throw SerialisationError("string is too long");
}

/* Read a length-prefixed string from `source`, consume its padding, and
   return it. Throws SerialisationError when the announced length exceeds
   `max`; the check happens before any storage for the payload is
   allocated. */
std::string readString(Source& source, size_t max) {
  const auto len = readNum<size_t>(source);
  if (len > max) {
    throw SerialisationError("string is too long");
  }

  std::string res(len, '\0');
  auto* const out = reinterpret_cast<unsigned char*>(res.data());
  source(out, len);
  readPadding(len, source);
  return res;
}

/* Read one serialised string from `in` (via readString()) and store it
   in `s`. Returns `in` so that reads can be chained. */
Source& operator>>(Source& in, std::string& s) {
  s = readString(in);
  return in;
}

/* Read a serialised container of strings: an element count followed by
   that many strings, each inserted at the end() position of the
   result. */
template <class T>
T readStrings(Source& source) {
  const auto count = readNum<size_t>(source);
  T ss;
  for (size_t i = 0; i < count; ++i) {
    ss.insert(ss.end(), readString(source));
  }
  return ss;
}

/* Explicit instantiations for the container types used by callers. */
template Paths readStrings(Source& source);
template PathSet readStrings(Source& source);

/* Append `len` bytes from `data` to the target string. Before appending,
   emits the large-dump warning the first time the string is found to be
   larger than `threshold`. */
void StringSink::operator()(const unsigned char* data, size_t len) {
  /* Shared by every StringSink: the warning is issued at most once. */
  static bool warned = false;

  if (!warned) {
    if (s->size() > threshold) {
      warnLargeDump();
      warned = true;
    }
  }

  const auto* bytes = reinterpret_cast<const char*>(data);
  s->append(bytes, len);
}

}  // namespace nix