diff --git a/src/bench/checkblock.cpp b/src/bench/checkblock.cpp index 765b8b0dadcd..3f0dd2b1f51b 100644 --- a/src/bench/checkblock.cpp +++ b/src/bench/checkblock.cpp @@ -21,15 +21,38 @@ #include #include +static void SizeComputerBlock(benchmark::Bench& bench) { + CBlock block; + DataStream(benchmark::data::block413567) >> TX_WITH_WITNESS(block); + + bench.unit("block").run([&] { + SizeComputer size_computer; + size_computer << TX_WITH_WITNESS(block); + assert(size_computer.size() == benchmark::data::block413567.size()); + }); +} + +static void SerializeBlock(benchmark::Bench& bench) { + CBlock block; + DataStream(benchmark::data::block413567) >> TX_WITH_WITNESS(block); + + // Create output stream and verify first serialization matches input + bench.unit("block").run([&] { + DataStream output_stream(benchmark::data::block413567.size()); + output_stream << TX_WITH_WITNESS(block); + assert(output_stream.size() == benchmark::data::block413567.size()); + }); +} + // These are the two major time-sinks which happen after we have fully received // a block off the wire, but before we can relay the block on to peers using // compact block relay. -static void DeserializeBlockTest(benchmark::Bench& bench) +static void DeserializeBlock(benchmark::Bench& bench) { DataStream stream(benchmark::data::block413567); std::byte a{0}; - stream.write({&a, 1}); // Prevent compaction + stream.write(std::span{&a, 1}); // Prevent compaction bench.unit("block").run([&] { CBlock block; @@ -39,11 +62,11 @@ static void DeserializeBlockTest(benchmark::Bench& bench) }); } -static void DeserializeAndCheckBlockTest(benchmark::Bench& bench) +static void DeserializeAndCheckBlock(benchmark::Bench& bench) { DataStream stream(benchmark::data::block413567); std::byte a{0}; - stream.write({&a, 1}); // Prevent compaction + stream.write(std::span{&a, 1}); // Prevent compaction ArgsManager bench_args; const auto chainParams = CreateChainParams(bench_args, ChainType::MAIN); @@ -60,5 +83,7 @@ static void DeserializeAndCheckBlockTest(benchmark::Bench& bench) }); } -BENCHMARK(DeserializeBlockTest); -BENCHMARK(DeserializeAndCheckBlockTest); +BENCHMARK(SizeComputerBlock); +BENCHMARK(SerializeBlock); +BENCHMARK(DeserializeBlock); +BENCHMARK(DeserializeAndCheckBlock); diff --git a/src/bench/rpc_blockchain.cpp b/src/bench/rpc_blockchain.cpp index 0e89ac78a136..3f8ad5351980 100644 --- a/src/bench/rpc_blockchain.cpp +++ b/src/bench/rpc_blockchain.cpp @@ -33,7 +33,7 @@ struct TestBlockAndIndex { { DataStream stream{benchmark::data::block413567}; std::byte a{0}; - stream.write({&a, 1}); // Prevent compaction + stream.write(std::span{&a, 1}); // Prevent compaction stream >> TX_WITH_WITNESS(block); diff --git a/src/crypto/sha256.cpp b/src/crypto/sha256.cpp index 54bd9a59f9ef..7d4dfa81941f 100644 --- a/src/crypto/sha256.cpp +++ b/src/crypto/sha256.cpp @@ -721,6 +721,21 @@ CSHA256& CSHA256::Write(const unsigned char* data, size_t len) } return *this; } +CSHA256& CSHA256::Write(unsigned char data) +{ + size_t bufsize = bytes % 64; + + // Add the single byte to the buffer + buf[bufsize] = data; + bytes += 1; + + if (bufsize == 63) { + // Process the buffer if full + Transform(s, buf, 1); + } + + return *this; +} void CSHA256::Finalize(unsigned char hash[OUTPUT_SIZE]) { diff --git a/src/crypto/sha256.h b/src/crypto/sha256.h index 3ac771c5d0db..ba4b5eb9c5e1 100644 --- a/src/crypto/sha256.h +++ b/src/crypto/sha256.h @@ -22,6 +22,7 @@ class CSHA256 CSHA256(); CSHA256& Write(const unsigned char* data, size_t len); + CSHA256& Write(unsigned char data); void Finalize(unsigned char hash[OUTPUT_SIZE]); CSHA256& Reset(); }; diff --git a/src/hash.h b/src/hash.h index 34486af64a1d..da3b1ab1145f 100644 --- a/src/hash.h +++ b/src/hash.h @@ -38,6 +38,10 @@ class CHash256 { sha.Write(input.data(), input.size()); return *this; } + CHash256& Write(std::span input) { + sha.Write(input[0]); + return *this; + } CHash256& Reset() { sha.Reset(); @@ -63,6 +67,10 @@ class CHash160 { sha.Write(input.data(), input.size()); return *this; } + CHash160& Write(std::span input) { + sha.Write(input[0]); + return *this; + } CHash160& Reset() { sha.Reset(); @@ -107,6 +115,10 @@ class HashWriter { ctx.Write(UCharCast(src.data()), src.size()); } + void write(std::span src) + { + ctx.Write(*UCharCast(&src[0])); + } /** Compute the double-SHA256 hash of all data written to this object. * @@ -160,13 +172,18 @@ class HashVerifier : public HashWriter m_source.read(dst); this->write(dst); } + void read(std::span dst) + { + m_source.read(dst); + this->write(std::span{dst}); + } void ignore(size_t num_bytes) { std::byte data[1024]; while (num_bytes > 0) { size_t now = std::min(num_bytes, 1024); - read({data, now}); + read(std::span{data, now}); num_bytes -= now; } } @@ -194,6 +211,11 @@ class HashedSourceWriter : public HashWriter m_source.write(src); HashWriter::write(src); } + void write(std::span src) + { + m_source.write(src); + HashWriter::write(src); + } template HashedSourceWriter& operator<<(const T& obj) diff --git a/src/serialize.h b/src/serialize.h index 21b3325f7a7a..c946f4e90782 100644 --- a/src/serialize.h +++ b/src/serialize.h @@ -48,67 +48,75 @@ static const unsigned int MAX_VECTOR_ALLOCATE = 5000000; struct deserialize_type {}; constexpr deserialize_type deserialize {}; +class SizeComputer; + +//! Check if type contains a stream by seeing if it has a GetStream() method. +template +concept ContainsStream = requires(T t) { t.GetStream(); }; + +template +concept ContainsSizeComputer = ContainsStream && + std::is_same_v().GetStream())>, SizeComputer>; + /* * Lowest-level serialization and conversion. */ template inline void ser_writedata8(Stream &s, uint8_t obj) { - s.write(std::as_bytes(std::span{&obj, 1})); + s.write(std::as_bytes(std::span{&obj, 1})); } template inline void ser_writedata16(Stream &s, uint16_t obj) { obj = htole16_internal(obj); - s.write(std::as_bytes(std::span{&obj, 1})); + s.write(std::as_bytes(std::span{&obj, 1})); } template inline void ser_writedata32(Stream &s, uint32_t obj) { obj = htole32_internal(obj); - s.write(std::as_bytes(std::span{&obj, 1})); + s.write(std::as_bytes(std::span{&obj, 1})); } template inline void ser_writedata32be(Stream &s, uint32_t obj) { obj = htobe32_internal(obj); - s.write(std::as_bytes(std::span{&obj, 1})); + s.write(std::as_bytes(std::span{&obj, 1})); } template inline void ser_writedata64(Stream &s, uint64_t obj) { obj = htole64_internal(obj); - s.write(std::as_bytes(std::span{&obj, 1})); + s.write(std::as_bytes(std::span{&obj, 1})); } template inline uint8_t ser_readdata8(Stream &s) { uint8_t obj; - s.read(std::as_writable_bytes(std::span{&obj, 1})); + s.read(std::as_writable_bytes(std::span{&obj, 1})); return obj; } template inline uint16_t ser_readdata16(Stream &s) { uint16_t obj; - s.read(std::as_writable_bytes(std::span{&obj, 1})); + s.read(std::as_writable_bytes(std::span{&obj, 1})); return le16toh_internal(obj); } template inline uint32_t ser_readdata32(Stream &s) { uint32_t obj; - s.read(std::as_writable_bytes(std::span{&obj, 1})); + s.read(std::as_writable_bytes(std::span{&obj, 1})); return le32toh_internal(obj); } template inline uint32_t ser_readdata32be(Stream &s) { uint32_t obj; - s.read(std::as_writable_bytes(std::span{&obj, 1})); + s.read(std::as_writable_bytes(std::span{&obj, 1})); return be32toh_internal(obj); } template inline uint64_t ser_readdata64(Stream &s) { uint64_t obj; - s.read(std::as_writable_bytes(std::span{&obj, 1})); + s.read(std::as_writable_bytes(std::span{&obj, 1})); return le64toh_internal(obj); } -class SizeComputer; - /** * Convert any argument to a reference to X, maintaining constness. * @@ -240,41 +248,76 @@ const Out& AsBase(const In& x) template concept CharNotInt8 = std::same_as && !std::same_as; -// clang-format off +template +concept ByteOrIntegral = std::is_same_v || + (std::is_integral_v && !std::is_same_v); + template void Serialize(Stream&, V) = delete; // char serialization forbidden. Use uint8_t or int8_t -template void Serialize(Stream& s, std::byte a) { ser_writedata8(s, uint8_t(a)); } -template void Serialize(Stream& s, int8_t a) { ser_writedata8(s, uint8_t(a)); } -template void Serialize(Stream& s, uint8_t a) { ser_writedata8(s, a); } -template void Serialize(Stream& s, int16_t a) { ser_writedata16(s, uint16_t(a)); } -template void Serialize(Stream& s, uint16_t a) { ser_writedata16(s, a); } -template void Serialize(Stream& s, int32_t a) { ser_writedata32(s, uint32_t(a)); } -template void Serialize(Stream& s, uint32_t a) { ser_writedata32(s, a); } -template void Serialize(Stream& s, int64_t a) { ser_writedata64(s, uint64_t(a)); } -template void Serialize(Stream& s, uint64_t a) { ser_writedata64(s, a); } - -template void Serialize(Stream& s, const B (&a)[N]) { s.write(MakeByteSpan(a)); } -template void Serialize(Stream& s, const std::array& a) { s.write(MakeByteSpan(a)); } -template void Serialize(Stream& s, std::span span) { s.write(std::as_bytes(span)); } -template void Serialize(Stream& s, std::span span) { s.write(std::as_bytes(span)); } +template void Serialize(Stream& s, T a) +{ + if constexpr (ContainsSizeComputer) { + s.GetStream().seek(sizeof(T)); + } else if constexpr (sizeof(T) == 1) { + ser_writedata8(s, static_cast(a)); // (u)int8_t or std::byte or bool + } else if constexpr (sizeof(T) == 2) { + ser_writedata16(s, static_cast(a)); // (u)int16_t + } else if constexpr (sizeof(T) == 4) { + ser_writedata32(s, static_cast(a)); // (u)int32_t + } else { + static_assert(sizeof(T) == 8); + ser_writedata64(s, static_cast(a)); // (u)int64_t + } +} +template void Serialize(Stream& s, const B (&a)[N]) +{ + if constexpr (ContainsSizeComputer) { + s.GetStream().seek(N); + } else { + s.write(MakeByteSpan(a)); + } +} +template void Serialize(Stream& s, const std::array& a) +{ + if constexpr (ContainsSizeComputer) { + s.GetStream().seek(N); + } else { + s.write(MakeByteSpan(a)); + } +} +template void Serialize(Stream& s, std::span span) +{ + if constexpr (ContainsSizeComputer) { + s.GetStream().seek(N); + } else { + s.write(std::as_bytes(span)); + } +} +template void Serialize(Stream& s, std::span span) +{ + if constexpr (ContainsSizeComputer) { + s.GetStream().seek(span.size()); + } else { + s.write(std::as_bytes(span)); + } +} template void Unserialize(Stream&, V) = delete; // char serialization forbidden. Use uint8_t or int8_t -template void Unserialize(Stream& s, std::byte& a) { a = std::byte(ser_readdata8(s)); } -template void Unserialize(Stream& s, int8_t& a) { a = int8_t(ser_readdata8(s)); } -template void Unserialize(Stream& s, uint8_t& a) { a = ser_readdata8(s); } -template void Unserialize(Stream& s, int16_t& a) { a = int16_t(ser_readdata16(s)); } -template void Unserialize(Stream& s, uint16_t& a) { a = ser_readdata16(s); } -template void Unserialize(Stream& s, int32_t& a) { a = int32_t(ser_readdata32(s)); } -template void Unserialize(Stream& s, uint32_t& a) { a = ser_readdata32(s); } -template void Unserialize(Stream& s, int64_t& a) { a = int64_t(ser_readdata64(s)); } -template void Unserialize(Stream& s, uint64_t& a) { a = ser_readdata64(s); } - -template void Unserialize(Stream& s, B (&a)[N]) { s.read(MakeWritableByteSpan(a)); } -template void Unserialize(Stream& s, std::array& a) { s.read(MakeWritableByteSpan(a)); } -template void Unserialize(Stream& s, std::span span) { s.read(std::as_writable_bytes(span)); } -template void Unserialize(Stream& s, std::span span) { s.read(std::as_writable_bytes(span)); } - -template void Serialize(Stream& s, bool a) { uint8_t f = a; ser_writedata8(s, f); } -template void Unserialize(Stream& s, bool& a) { uint8_t f = ser_readdata8(s); a = f; } +template void Unserialize(Stream& s, T& a) +{ + if constexpr (sizeof(T) == 1) { + a = static_cast(ser_readdata8(s)); // (u)int8_t or std::byte or bool + } else if constexpr (sizeof(T) == 2) { + a = static_cast(ser_readdata16(s)); // (u)int16_t + } else if constexpr (sizeof(T) == 4) { + a = static_cast(ser_readdata32(s)); // (u)int32_t + } else { + static_assert(sizeof(T) == 8); + a = static_cast(ser_readdata64(s)); // (u)int64_t + } +} +template void Unserialize(Stream& s, B (&a)[N]) { s.read(MakeWritableByteSpan(a)); } +template void Unserialize(Stream& s, std::array& a) { s.read(MakeWritableByteSpan(a)); } +template void Unserialize(Stream& s, std::span span) { s.read(std::as_writable_bytes(span)); } // clang-format on @@ -293,12 +336,14 @@ constexpr inline unsigned int GetSizeOfCompactSize(uint64_t nSize) else return sizeof(unsigned char) + sizeof(uint64_t); } -inline void WriteCompactSize(SizeComputer& os, uint64_t nSize); - template void WriteCompactSize(Stream& os, uint64_t nSize) { - if (nSize < 253) + if constexpr (ContainsSizeComputer) + { + os.GetStream().seek(GetSizeOfCompactSize(nSize)); + } + else if (nSize < 253) { ser_writedata8(os, nSize); } @@ -405,7 +450,7 @@ struct CheckVarIntMode { }; template -inline unsigned int GetSizeOfVarInt(I n) +constexpr unsigned int GetSizeOfVarInt(I n) { CheckVarIntMode(); int nRet = 0; @@ -418,25 +463,92 @@ inline unsigned int GetSizeOfVarInt(I n) return nRet; } -template -inline void WriteVarInt(SizeComputer& os, I n); +template +ALWAYS_INLINE void WriteVarIntFixed(Stream& os, I n) +{ + unsigned char out[N]; + if constexpr (N == 2) { + out[0] = static_cast(((n >> 7) - 1) | 0x80); + out[1] = static_cast(n & 0x7F); + } else { + I x = n; + out[N - 1] = static_cast(x & 0x7F); + if constexpr (N > 1) { + x = (x >> 7) - 1; + out[N - 2] = static_cast((x & 0x7F) | 0x80); + } + if constexpr (N > 2) { + x = (x >> 7) - 1; + out[N - 3] = static_cast((x & 0x7F) | 0x80); + } + if constexpr (N > 3) { + x = (x >> 7) - 1; + out[N - 4] = static_cast((x & 0x7F) | 0x80); + } + if constexpr (N > 4) { + x = (x >> 7) - 1; + out[N - 5] = static_cast((x & 0x7F) | 0x80); + } + if constexpr (N > 5) { + x = (x >> 7) - 1; + out[N - 6] = static_cast((x & 0x7F) | 0x80); + } + if constexpr (N > 6) { + x = (x >> 7) - 1; + out[N - 7] = static_cast((x & 0x7F) | 0x80); + } + if constexpr (N > 7) { + x = (x >> 7) - 1; + out[N - 8] = static_cast((x & 0x7F) | 0x80); + } + } + os.write(std::as_bytes(std::span{out})); +} template -void WriteVarInt(Stream& os, I n) +ALWAYS_INLINE void WriteVarInt(Stream& os, I n) { - CheckVarIntMode(); - unsigned char tmp[(sizeof(n)*8+6)/7]; - int len=0; - while(true) { - tmp[len] = (n & 0x7F) | (len ? 0x80 : 0x00); - if (n <= 0x7F) - break; - n = (n >> 7) - 1; - len++; + if constexpr (ContainsSizeComputer) { + os.GetStream().seek(GetSizeOfVarInt(n)); + } else { + CheckVarIntMode(); + if (n <= 0x7F) { + ser_writedata8(os, n); + return; + } + if (n <= 0x1020407F) { + if (n <= 0x407F) { + WriteVarIntFixed<2>(os, n); + return; + } + if (n <= 0x20407F) { + WriteVarIntFixed<3>(os, n); + return; + } + WriteVarIntFixed<4>(os, n); + return; + } + if (n <= 0x4081020407FULL) { + if (n <= 0x81020407FULL) { + WriteVarIntFixed<5>(os, n); + return; + } + WriteVarIntFixed<6>(os, n); + return; + } + if (n <= 0x204081020407FULL) { + WriteVarIntFixed<7>(os, n); + return; + } + unsigned char tmp[(sizeof(n) * 8 + 6) / 7]; + size_t pos = std::size(tmp); + tmp[--pos] = n & 0x7F; + while (n > 0x7F) { + n = (n >> 7) - 1; + tmp[--pos] = (n & 0x7F) | 0x80; + } + os.write(std::as_bytes(std::span{tmp}.subspan(pos))); } - do { - ser_writedata8(os, tmp[len]); - } while(len--); } template @@ -480,7 +592,7 @@ class Wrapper * serialization, and Unser(stream, object&) for deserialization. Serialization routines (inside * READWRITE, or directly with << and >> operators), can then use Using(object). * - * This works by constructing a Wrapper-wrapped version of object, where T is + * This works by constructing a Wrapper-wrapped version of object, where T is * const during serialization, and non-const during deserialization, which maintains const * correctness. */ @@ -525,12 +637,14 @@ struct CustomUintFormatter template void Ser(Stream& s, I v) { if (v < 0 || v > MAX) throw std::ios_base::failure("CustomUintFormatter value out of range"); - if (BigEndian) { + if constexpr (ContainsSizeComputer) { + s.GetStream().seek(Bytes); + } else if (BigEndian) { uint64_t raw = htobe64_internal(v); - s.write(std::as_bytes(std::span{&raw, 1}).last(Bytes)); + s.write(std::as_bytes(std::span{&raw, 1}).template last()); } else { uint64_t raw = htole64_internal(v); - s.write(std::as_bytes(std::span{&raw, 1}).first(Bytes)); + s.write(std::as_bytes(std::span{&raw, 1}).template first()); } } @@ -540,10 +654,10 @@ struct CustomUintFormatter static_assert(std::numeric_limits::max() >= MAX && std::numeric_limits::min() <= 0, "Assigned type too small"); uint64_t raw = 0; if (BigEndian) { - s.read(std::as_writable_bytes(std::span{&raw, 1}).last(Bytes)); + s.read(std::as_writable_bytes(std::span{&raw, 1}).last()); v = static_cast(be64toh_internal(raw)); } else { - s.read(std::as_writable_bytes(std::span{&raw, 1}).first(Bytes)); + s.read(std::as_writable_bytes(std::span{&raw, 1}).first()); v = static_cast(le64toh_internal(raw)); } } @@ -815,10 +929,18 @@ void Unserialize(Stream& is, prevector& v) if constexpr (BasicByte) { // Use optimized version for unformatted basic bytes // Limit size per read so bogus size value won't cause out of memory v.clear(); - unsigned int nSize = ReadCompactSize(is); - unsigned int i = 0; + size_t nSize = ReadCompactSize(is); + constexpr size_t max_chunk{static_cast(1 + 4999999 / sizeof(T))}; + if (nSize <= max_chunk) { + v.resize_uninitialized(nSize); + if (nSize != 0) { + is.read(std::as_writable_bytes(std::span{v.data(), nSize})); + } + return; + } + size_t i = 0; while (i < nSize) { - unsigned int blk = std::min(nSize - i, (unsigned int)(1 + 4999999 / sizeof(T))); + size_t blk = std::min(nSize - i, max_chunk); v.resize_uninitialized(i + blk); is.read(std::as_writable_bytes(std::span{&v[i], blk})); i += blk; @@ -858,10 +980,18 @@ void Unserialize(Stream& is, std::vector& v) if constexpr (BasicByte) { // Use optimized version for unformatted basic bytes // Limit size per read so bogus size value won't cause out of memory v.clear(); - unsigned int nSize = ReadCompactSize(is); - unsigned int i = 0; + size_t nSize = ReadCompactSize(is); + constexpr size_t max_chunk{static_cast(1 + 4999999 / sizeof(T))}; + if (nSize <= max_chunk) { + v.resize(nSize); + if (nSize != 0) { + is.read(std::as_writable_bytes(std::span{v.data(), nSize})); + } + return; + } + size_t i = 0; while (i < nSize) { - unsigned int blk = std::min(nSize - i, (unsigned int)(1 + 4999999 / sizeof(T))); + size_t blk = std::min(nSize - i, max_chunk); v.resize(i + blk); is.read(std::as_writable_bytes(std::span{&v[i], blk})); i += blk; @@ -912,7 +1042,7 @@ void Unserialize(Stream& is, std::map& m) { std::pair item; Unserialize(is, item); - mi = m.insert(mi, item); + mi = m.insert(mi, std::move(item)); } } @@ -939,7 +1069,7 @@ void Unserialize(Stream& is, std::set& m) { K key; Unserialize(is, key); - it = m.insert(it, key); + it = m.insert(it, std::move(key)); } } @@ -1056,10 +1186,17 @@ class SizeComputer public: SizeComputer() = default; + SizeComputer& GetStream() { return *this; } + const SizeComputer& GetStream() const { return *this; }; + void write(std::span src) { m_size += src.size(); } + void write(std::span) + { + this->m_size += 1; + } /** Pretend this many bytes are written, without specifying them. */ void seek(uint64_t num) @@ -1080,27 +1217,12 @@ class SizeComputer } }; -template -inline void WriteVarInt(SizeComputer &s, I n) -{ - s.seek(GetSizeOfVarInt(n)); -} - -inline void WriteCompactSize(SizeComputer &s, uint64_t nSize) -{ - s.seek(GetSizeOfCompactSize(nSize)); -} - template uint64_t GetSerializeSize(const T& t) { return (SizeComputer() << t).size(); } -//! Check if type contains a stream by seeing if has a GetStream() method. -template -concept ContainsStream = requires(T t) { t.GetStream(); }; - /** Wrapper that overrides the GetParams() function of a stream. */ template class ParamsStream @@ -1125,7 +1247,9 @@ class ParamsStream template ParamsStream& operator<<(const U& obj) { ::Serialize(*this, obj); return *this; } template ParamsStream& operator>>(U&& obj) { ::Unserialize(*this, obj); return *this; } void write(std::span src) { GetStream().write(src); } + void write(std::span src) { GetStream().write(src); } void read(std::span dst) { GetStream().read(dst); } + void read(std::span dst) { GetStream().read(dst); } void ignore(size_t num) { GetStream().ignore(num); } bool empty() const { return GetStream().empty(); } size_t size() const { return GetStream().size(); } diff --git a/src/streams.cpp b/src/streams.cpp index e38b9592942e..65df3c4b6f55 100644 --- a/src/streams.cpp +++ b/src/streams.cpp @@ -78,6 +78,13 @@ void AutoFile::read(std::span dst) } } +void AutoFile::read(std::span dst) +{ + if (detail_fread(dst) != 1) { + throw std::ios_base::failure(feof() ? "AutoFile::read: end of file" : "AutoFile::read: fread failed"); + } +} + void AutoFile::ignore(size_t nSize) { if (!m_file) throw std::ios_base::failure("AutoFile::ignore: file handle is nullptr"); @@ -112,6 +119,12 @@ void AutoFile::write(std::span src) } } +void AutoFile::write(std::span src) +{ + std::byte temp_byte = src[0]; + write_buffer(std::span(&temp_byte, 1)); +} + void AutoFile::write_buffer(std::span src) { if (!m_file) throw std::ios_base::failure("AutoFile::write_buffer: file handle is nullptr"); diff --git a/src/streams.h b/src/streams.h index f70adcf74a71..0567f7a2bf26 100644 --- a/src/streams.h +++ b/src/streams.h @@ -12,7 +12,6 @@ #include #include #include -#include #include #include @@ -53,15 +52,32 @@ class VectorWriter { ::SerializeMany(*this, std::forward(args)...); } - void write(std::span src) + template + void write(std::span src) { assert(nPos <= vchData.size()); + const auto src_ptr{UCharCast(src.data())}; + if constexpr (Extent == 1) { + const auto byte{src_ptr[0]}; + if (nPos < vchData.size()) { + vchData[nPos] = byte; + } else { + vchData.push_back(byte); + } + nPos += 1; + return; + } + if (nPos == vchData.size()) { + vchData.insert(vchData.end(), src_ptr, src_ptr + src.size()); + nPos += src.size(); + return; + } size_t nOverwrite = std::min(src.size(), vchData.size() - nPos); if (nOverwrite) { - memcpy(vchData.data() + nPos, src.data(), nOverwrite); + memcpy(vchData.data() + nPos, src_ptr, nOverwrite); } if (nOverwrite < src.size()) { - vchData.insert(vchData.end(), UCharCast(src.data()) + nOverwrite, UCharCast(src.data() + src.size())); + vchData.insert(vchData.end(), src_ptr + nOverwrite, src_ptr + src.size()); } nPos += src.size(); } @@ -101,18 +117,24 @@ class SpanReader size_t size() const { return m_data.size(); } bool empty() const { return m_data.empty(); } - void read(std::span dst) + template + void read(std::span dst) { - if (dst.size() == 0) { - return; - } - - // Read from the beginning of the buffer - if (dst.size() > m_data.size()) { - throw std::ios_base::failure("SpanReader::read(): end of data"); + if constexpr (Extent == 1) { + if (m_data.empty()) { + throw std::ios_base::failure("SpanReader::read(): end of data"); + } + dst[0] = m_data[0]; + m_data = m_data.subspan(1); + } else { + const auto n{dst.size()}; + // Read from the beginning of the buffer + if (n > m_data.size()) { + throw std::ios_base::failure("SpanReader::read(): end of data"); + } + memcpy(dst.data(), m_data.data(), n); + m_data = m_data.subspan(n); } - memcpy(dst.data(), m_data.data(), dst.size()); - m_data = m_data.subspan(dst.size()); } void ignore(size_t n) @@ -148,6 +170,7 @@ class DataStream typedef vector_type::reverse_iterator reverse_iterator; explicit DataStream() = default; + explicit DataStream(size_type n) { reserve(n); } explicit DataStream(std::span sp) : DataStream{std::as_bytes(sp)} {} explicit DataStream(std::span sp) : vch(sp.data(), sp.data() + sp.size()) {} @@ -200,43 +223,65 @@ class DataStream // int in_avail() const { return size(); } - void read(std::span dst) + template + void read(std::span dst) { - if (dst.size() == 0) return; - - // Read from the beginning of the buffer - auto next_read_pos{CheckedAdd(m_read_pos, dst.size())}; - if (!next_read_pos.has_value() || next_read_pos.value() > vch.size()) { - throw std::ios_base::failure("DataStream::read(): end of data"); - } - memcpy(dst.data(), &vch[m_read_pos], dst.size()); - if (next_read_pos.value() == vch.size()) { - m_read_pos = 0; - vch.clear(); - return; + if constexpr (Extent == 1) { + if (m_read_pos == vch.size()) { + throw std::ios_base::failure("DataStream::read(): end of data"); + } + dst[0] = vch[m_read_pos]; + ++m_read_pos; + if (m_read_pos == vch.size()) { + m_read_pos = 0; + vch.clear(); + } + } else { + const auto n{dst.size()}; + const auto avail{vch.size() - m_read_pos}; + if (n > avail) { + throw std::ios_base::failure("DataStream::read(): end of data"); + } + memcpy(dst.data(), &vch[m_read_pos], n); + if (n == avail) { + m_read_pos = 0; + vch.clear(); + return; + } + m_read_pos += n; } - m_read_pos = next_read_pos.value(); } void ignore(size_t num_ignore) { - // Ignore from the beginning of the buffer - auto next_read_pos{CheckedAdd(m_read_pos, num_ignore)}; - if (!next_read_pos.has_value() || next_read_pos.value() > vch.size()) { + const auto avail{vch.size() - m_read_pos}; + if (num_ignore > avail) { throw std::ios_base::failure("DataStream::ignore(): end of data"); } - if (next_read_pos.value() == vch.size()) { + if (num_ignore == avail) { m_read_pos = 0; vch.clear(); return; } - m_read_pos = next_read_pos.value(); + m_read_pos += num_ignore; } - void write(std::span src) + template + void write(std::span src) { // Write to the end of the buffer - vch.insert(vch.end(), src.begin(), src.end()); + if constexpr (Extent == 1) { + vch.push_back(src[0]); + } else if constexpr (Extent == 2) { + vch.push_back(src[0]); + vch.push_back(src[1]); + } else if constexpr (Extent != std::dynamic_extent) { + // Keep Extent a compile-time constant so small fixed-size writes can be optimized better + // than the dynamic-size path. + vch.insert(vch.end(), src.data(), src.data() + Extent); + } else { + vch.insert(vch.end(), src.data(), src.data() + src.size()); + } } template @@ -453,8 +498,10 @@ class AutoFile // Stream subset // void read(std::span dst); + void read(std::span dst); void ignore(size_t nSize); void write(std::span src); + void write(std::span src); template AutoFile& operator<<(const T& obj) diff --git a/src/test/crypto_tests.cpp b/src/test/crypto_tests.cpp index 5588d4cdbc66..0aab9ef0e77d 100644 --- a/src/test/crypto_tests.cpp +++ b/src/test/crypto_tests.cpp @@ -1079,7 +1079,7 @@ BOOST_AUTO_TEST_CASE(sha256d64) in[j] = m_rng.randbits(8); } for (int j = 0; j < i; ++j) { - CHash256().Write({in + 64 * j, 64}).Finalize({out1 + 32 * j, 32}); + CHash256().Write(std::span{in + 64 * j, 64}).Finalize({out1 + 32 * j, 32}); } SHA256D64(out2, in, i); BOOST_CHECK(memcmp(out1, out2, 32 * i) == 0); diff --git a/src/test/fuzz/autofile.cpp b/src/test/fuzz/autofile.cpp index 5aa5d8c13322..b1be07c11239 100644 --- a/src/test/fuzz/autofile.cpp +++ b/src/test/fuzz/autofile.cpp @@ -31,14 +31,14 @@ FUZZ_TARGET(autofile) [&] { std::array arr{}; try { - auto_file.read({arr.data(), fuzzed_data_provider.ConsumeIntegralInRange(0, 4096)}); + auto_file.read(std::span{arr.data(), fuzzed_data_provider.ConsumeIntegralInRange(0, 4096)}); } catch (const std::ios_base::failure&) { } }, [&] { const std::array arr{}; try { - auto_file.write({arr.data(), fuzzed_data_provider.ConsumeIntegralInRange(0, 4096)}); + auto_file.write(std::span{arr.data(), fuzzed_data_provider.ConsumeIntegralInRange(0, 4096)}); } catch (const std::ios_base::failure&) { } }, diff --git a/src/test/streams_tests.cpp b/src/test/streams_tests.cpp index af75ee987ad3..1c76497d3c35 100644 --- a/src/test/streams_tests.cpp +++ b/src/test/streams_tests.cpp @@ -98,9 +98,9 @@ BOOST_AUTO_TEST_CASE(xor_file) { // Check errors for missing file AutoFile xor_file{raw_file("rb"), obfuscation}; - BOOST_CHECK_EXCEPTION(xor_file << std::byte{}, std::ios_base::failure, HasReason{"AutoFile::write: file handle is nullptr"}); - BOOST_CHECK_EXCEPTION(xor_file >> std::byte{}, std::ios_base::failure, HasReason{"AutoFile::read: file handle is nullptr"}); - BOOST_CHECK_EXCEPTION(xor_file.ignore(1), std::ios_base::failure, HasReason{"AutoFile::ignore: file handle is nullptr"}); + BOOST_CHECK_EXCEPTION(xor_file << std::byte{}, std::ios_base::failure, HasReason{"file handle is nullptr"}); + BOOST_CHECK_EXCEPTION(xor_file >> std::byte{}, std::ios_base::failure, HasReason{"file handle is nullptr"}); + BOOST_CHECK_EXCEPTION(xor_file.ignore(1), std::ios_base::failure, HasReason{"file handle is nullptr"}); BOOST_CHECK_EXCEPTION(xor_file.size(), std::ios_base::failure, HasReason{"AutoFile::size: file handle is nullptr"}); } {