From dc132b2035169cb2cf6e84be4bf6533b5520a7c5 Mon Sep 17 00:00:00 2001
From: Sutou Kouhei
Date: Wed, 13 Aug 2025 11:47:26 +0900
Subject: [PATCH 01/14] GH-47317: [C++][C++23][Gandiva] Use pointer for Cache test

`gandiva::Cache` requires a pointer type for its value type.
---
 cpp/src/gandiva/cache_test.cc | 12 +++++++-----
 1 file changed, 7 insertions(+), 5 deletions(-)

diff --git a/cpp/src/gandiva/cache_test.cc b/cpp/src/gandiva/cache_test.cc
index 96cf4a12e58..d371db59dfc 100644
--- a/cpp/src/gandiva/cache_test.cc
+++ b/cpp/src/gandiva/cache_test.cc
@@ -35,11 +35,13 @@ class TestCacheKey {
 };
 
 TEST(TestCache, TestGetPut) {
-  Cache<TestCacheKey, std::string> cache(2);
-  cache.PutObjectCode(TestCacheKey(1), "hello");
-  cache.PutObjectCode(TestCacheKey(2), "world");
-  ASSERT_EQ(cache.GetObjectCode(TestCacheKey(1)), "hello");
-  ASSERT_EQ(cache.GetObjectCode(TestCacheKey(2)), "world");
+  Cache<TestCacheKey, std::shared_ptr<std::string>> cache(2);
+  auto hello = std::make_shared<std::string>("hello");
+  cache.PutObjectCode(TestCacheKey(1), hello);
+  auto world = std::make_shared<std::string>("world");
+  cache.PutObjectCode(TestCacheKey(2), world);
+  ASSERT_EQ(cache.GetObjectCode(TestCacheKey(1)), hello);
+  ASSERT_EQ(cache.GetObjectCode(TestCacheKey(2)), world);
 }
 
 namespace {

From a39c27c341cf67bb51077988584191914351232b Mon Sep 17 00:00:00 2001
From: Sutou Kouhei
Date: Wed, 13 Aug 2025 18:38:12 +0900
Subject: [PATCH 02/14] Use Debian 13 to use more recent GCC

---
 .github/workflows/cpp_extra.yml    |   8 ++
 ci/docker/debian-13-cpp.dockerfile | 144 +++++++++++++++++++++++++++++
 2 files changed, 152 insertions(+)
 create mode 100644 ci/docker/debian-13-cpp.dockerfile

diff --git a/.github/workflows/cpp_extra.yml b/.github/workflows/cpp_extra.yml
index 891aab51ccb..a391d497dde 100644
--- a/.github/workflows/cpp_extra.yml
+++ b/.github/workflows/cpp_extra.yml
@@ -126,6 +126,8 @@ jobs:
             title: AMD64 Ubuntu Meson
           # TODO: We should remove this "continue-on-error: true" once GH-47207 is resolved
           - continue-on-error: true
+            envs:
+              - DEBIAN=13
             image: debian-cpp
             run-options: >-
               -e CMAKE_CXX_STANDARD=23
@@ -158,10 +160,16 @@
         env:
           ARCHERY_DOCKER_USER: ${{ secrets.DOCKERHUB_USER }}
           ARCHERY_DOCKER_PASSWORD: ${{ secrets.DOCKERHUB_TOKEN }}
+          ENVS: ${{ toJSON(matrix.envs) }}
        run: |
           # GH-40558: reduce ASLR to avoid ASAN/LSAN crashes
           sudo sysctl -w vm.mmap_rnd_bits=28
           source ci/scripts/util_enable_core_dumps.sh
+          if [ "${ENVS}" != "null" ]; then
+            echo "${ENVS}" | jq -r '.[]' | while read env; do
+              echo "${env}" >> .env
+            done
+          fi
           archery docker run ${{ matrix.run-options || '' }} ${{ matrix.image }}
       - name: Docker Push
         if: >-
diff --git a/ci/docker/debian-13-cpp.dockerfile b/ci/docker/debian-13-cpp.dockerfile
new file mode 100644
index 00000000000..3e5c645c81a
--- /dev/null
+++ b/ci/docker/debian-13-cpp.dockerfile
@@ -0,0 +1,144 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+ +ARG arch=amd64 +FROM ${arch}/debian:13 +ARG arch + +ENV DEBIAN_FRONTEND noninteractive + +ARG llvm +RUN apt-get update -y -q && \ + apt-get install -y -q --no-install-recommends \ + apt-transport-https \ + ca-certificates \ + gnupg \ + lsb-release \ + wget && \ + if [ ${llvm} -ge 20 ]; then \ + wget -O /usr/share/keyrings/llvm-snapshot.asc \ + https://apt.llvm.org/llvm-snapshot.gpg.key && \ + (echo "Types: deb"; \ + echo "URIs: https://apt.llvm.org/$(lsb_release --codename --short)/"; \ + echo "Suites: llvm-toolchain-$(lsb_release --codename --short)-${llvm}"; \ + echo "Components: main"; \ + echo "Signed-By: /usr/share/keyrings/llvm-snapshot.asc") | \ + tee /etc/apt/sources.list.d/llvm.sources; \ + fi && \ + apt-get update -y -q && \ + apt-get install -y -q --no-install-recommends \ + autoconf \ + ccache \ + clang-${llvm} \ + cmake \ + curl \ + g++ \ + gcc \ + gdb \ + git \ + libbenchmark-dev \ + libboost-filesystem-dev \ + libboost-system-dev \ + libbrotli-dev \ + libbz2-dev \ + libc-ares-dev \ + libcurl4-openssl-dev \ + libgflags-dev \ + libgmock-dev \ + libgoogle-glog-dev \ + libgrpc++-dev \ + libidn2-dev \ + libkrb5-dev \ + libldap-dev \ + liblz4-dev \ + libnghttp2-dev \ + libprotobuf-dev \ + libprotoc-dev \ + libpsl-dev \ + libre2-dev \ + librtmp-dev \ + libsnappy-dev \ + libsqlite3-dev \ + libssh-dev \ + libssh2-1-dev \ + libssl-dev \ + libthrift-dev \ + libutf8proc-dev \ + libxml2-dev \ + libxsimd-dev \ + libzstd-dev \ + llvm-${llvm}-dev \ + make \ + ninja-build \ + nlohmann-json3-dev \ + npm \ + opentelemetry-cpp-dev \ + pkg-config \ + protobuf-compiler-grpc \ + python3-dev \ + python3-pip \ + python3-venv \ + rapidjson-dev \ + rsync \ + tzdata \ + zlib1g-dev && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +COPY ci/scripts/install_minio.sh /arrow/ci/scripts/ +RUN /arrow/ci/scripts/install_minio.sh latest /usr/local + +COPY ci/scripts/install_gcs_testbench.sh /arrow/ci/scripts/ +RUN /arrow/ci/scripts/install_gcs_testbench.sh default + +COPY ci/scripts/install_azurite.sh /arrow/ci/scripts/ +RUN /arrow/ci/scripts/install_azurite.sh + +COPY ci/scripts/install_sccache.sh /arrow/ci/scripts/ +RUN /arrow/ci/scripts/install_sccache.sh unknown-linux-musl /usr/local/bin + +# Prioritize system packages and local installation. 
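For reference, the echo/tee pipeline in the RUN step near the top of this Dockerfile writes a deb822-style apt source. Assuming the `llvm` build argument is 20 and that `lsb_release --codename --short` reports `trixie` on Debian 13 (both values are illustrative, not taken from the patch), the generated /etc/apt/sources.list.d/llvm.sources would read roughly:

Types: deb
URIs: https://apt.llvm.org/trixie/
Suites: llvm-toolchain-trixie-20
Components: main
Signed-By: /usr/share/keyrings/llvm-snapshot.asc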
+ENV ARROW_ACERO=ON \ + ARROW_AZURE=ON \ + ARROW_BUILD_TESTS=ON \ + ARROW_DATASET=ON \ + ARROW_DEPENDENCY_SOURCE=SYSTEM \ + ARROW_DATASET=ON \ + ARROW_FLIGHT=ON \ + ARROW_FLIGHT_SQL=ON \ + ARROW_GANDIVA=ON \ + ARROW_GCS=ON \ + ARROW_HOME=/usr/local \ + ARROW_JEMALLOC=ON \ + ARROW_ORC=ON \ + ARROW_PARQUET=ON \ + ARROW_S3=ON \ + ARROW_SUBSTRAIT=ON \ + ARROW_USE_CCACHE=ON \ + ARROW_WITH_BROTLI=ON \ + ARROW_WITH_BZ2=ON \ + ARROW_WITH_LZ4=ON \ + ARROW_WITH_OPENTELEMETRY=ON \ + ARROW_WITH_SNAPPY=ON \ + ARROW_WITH_ZLIB=ON \ + ARROW_WITH_ZSTD=ON \ + AWSSDK_SOURCE=BUNDLED \ + Azure_SOURCE=BUNDLED \ + google_cloud_cpp_storage_SOURCE=BUNDLED \ + ORC_SOURCE=BUNDLED \ + PATH=/usr/lib/ccache/:$PATH \ + PYTHON=python3 From 71c0f90e0d5c4d7c7fadc4064de92b17d8ae800b Mon Sep 17 00:00:00 2001 From: Benjamin Leff Date: Fri, 22 Aug 2025 00:16:56 +0000 Subject: [PATCH 03/14] chore: merge conflicts --- cpp/src/arrow/flight/types.h | 118 +++++++++++++++++------------------ 1 file changed, 59 insertions(+), 59 deletions(-) diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index 656cc00e676..74d5093601a 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -468,38 +468,6 @@ struct ARROW_FLIGHT_EXPORT FlightDescriptor } }; -/// \brief Data structure providing an opaque identifier or credential to use -/// when requesting a data stream with the DoGet RPC -struct ARROW_FLIGHT_EXPORT Ticket : public internal::BaseType { - std::string ticket; - - Ticket() = default; - Ticket(std::string ticket) // NOLINT runtime/explicit - : ticket(std::move(ticket)) {} - - std::string ToString() const; - bool Equals(const Ticket& other) const; - - using SuperT::Deserialize; - using SuperT::SerializeToString; - - /// \brief Get the wire-format representation of this type. - /// - /// Useful when interoperating with non-Flight systems (e.g. REST - /// services) that may want to return Flight types. - /// - /// Use `SerializeToString()` if you want a Result-returning version. - arrow::Status SerializeToString(std::string* out) const; - - /// \brief Parse the wire-format representation of this type. - /// - /// Useful when interoperating with non-Flight systems (e.g. REST - /// services) that may want to return Flight types. - /// - /// Use `Deserialize(serialized)` if you want a Result-returning version. - static arrow::Status Deserialize(std::string_view serialized, Ticket* out); -}; - /// \brief A host location (a URI) struct ARROW_FLIGHT_EXPORT Location : public internal::BaseType { public: @@ -569,6 +537,38 @@ struct ARROW_FLIGHT_EXPORT Location : public internal::BaseType { std::shared_ptr uri_; }; +/// \brief Data structure providing an opaque identifier or credential to use +/// when requesting a data stream with the DoGet RPC +struct ARROW_FLIGHT_EXPORT Ticket : public internal::BaseType { + std::string ticket; + + Ticket() = default; + Ticket(std::string ticket) // NOLINT runtime/explicit + : ticket(std::move(ticket)) {} + + std::string ToString() const; + bool Equals(const Ticket& other) const; + + using SuperT::Deserialize; + using SuperT::SerializeToString; + + /// \brief Get the wire-format representation of this type. + /// + /// Useful when interoperating with non-Flight systems (e.g. REST + /// services) that may want to return Flight types. + /// + /// Use `SerializeToString()` if you want a Result-returning version. + arrow::Status SerializeToString(std::string* out) const; + + /// \brief Parse the wire-format representation of this type. 
+ /// + /// Useful when interoperating with non-Flight systems (e.g. REST + /// services) that may want to return Flight types. + /// + /// Use `Deserialize(serialized)` if you want a Result-returning version. + static arrow::Status Deserialize(std::string_view serialized, Ticket* out); +}; + /// \brief A flight ticket and list of locations where the ticket can be /// redeemed struct ARROW_FLIGHT_EXPORT FlightEndpoint : public internal::BaseType { @@ -613,6 +613,33 @@ struct ARROW_FLIGHT_EXPORT FlightEndpoint : public internal::BaseType { + FlightEndpoint endpoint; + + RenewFlightEndpointRequest() = default; + explicit RenewFlightEndpointRequest(FlightEndpoint endpoint) + : endpoint(std::move(endpoint)) {} + + std::string ToString() const; + bool Equals(const RenewFlightEndpointRequest& other) const; + + using SuperT::Deserialize; + using SuperT::SerializeToString; + + /// \brief Serialize this message to its wire-format representation. + /// + /// Use `SerializeToString()` if you want a Result-returning version. + arrow::Status SerializeToString(std::string* out) const; + + /// \brief Deserialize this message from its wire-format representation. + /// + /// Use `Deserialize(serialized)` if you want a Result-returning version. + static arrow::Status Deserialize(std::string_view serialized, + RenewFlightEndpointRequest* out); +}; + /// \brief The access coordinates for retrieval of a dataset, returned by /// GetFlightInfo class ARROW_FLIGHT_EXPORT FlightInfo @@ -857,33 +884,6 @@ struct ARROW_FLIGHT_EXPORT CancelFlightInfoResult ARROW_FLIGHT_EXPORT std::ostream& operator<<(std::ostream& os, CancelStatus status); -/// \brief The request of the RenewFlightEndpoint action. -struct ARROW_FLIGHT_EXPORT RenewFlightEndpointRequest - : public internal::BaseType { - FlightEndpoint endpoint; - - RenewFlightEndpointRequest() = default; - explicit RenewFlightEndpointRequest(FlightEndpoint endpoint) - : endpoint(std::move(endpoint)) {} - - std::string ToString() const; - bool Equals(const RenewFlightEndpointRequest& other) const; - - using SuperT::Deserialize; - using SuperT::SerializeToString; - - /// \brief Serialize this message to its wire-format representation. - /// - /// Use `SerializeToString()` if you want a Result-returning version. - arrow::Status SerializeToString(std::string* out) const; - - /// \brief Deserialize this message from its wire-format representation. - /// - /// Use `Deserialize(serialized)` if you want a Result-returning version. - static arrow::Status Deserialize(std::string_view serialized, - RenewFlightEndpointRequest* out); -}; - // FlightData in Flight.proto maps to FlightPayload here. 
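Patch 03 only relocates the Ticket declaration below Location; the moved block keeps both serialization flavors, the Status-returning overloads declared here and the Result-returning ones pulled in through `using SuperT::SerializeToString;` and `using SuperT::Deserialize;`. A minimal round-trip sketch of the Result-returning pair follows; it is illustrative only and assumes the inherited overloads behave as the doc comments above describe.

#include <arrow/flight/types.h>
#include <arrow/result.h>

arrow::Status RoundTripTicket() {
  arrow::flight::Ticket ticket{"opaque-ticket-bytes"};
  // Result-returning form inherited from the BaseType helper.
  ARROW_ASSIGN_OR_RAISE(std::string wire, ticket.SerializeToString());
  // Parse the wire bytes back, e.g. after receiving them over a REST channel.
  ARROW_ASSIGN_OR_RAISE(auto parsed, arrow::flight::Ticket::Deserialize(wire));
  return parsed.Equals(ticket) ? arrow::Status::OK()
                               : arrow::Status::Invalid("Ticket did not round-trip");
}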
/// \brief Staging data structure for messages about to be put on the wire From bce867d58abbde5ec31975e1edd7a66a0ce78860 Mon Sep 17 00:00:00 2001 From: Benjamin Leff Date: Fri, 22 Aug 2025 00:26:39 +0000 Subject: [PATCH 04/14] chore: merge conflicts in types.h --- cpp/src/arrow/flight/types.h | 64 +++++++++++++++++++----------------- 1 file changed, 33 insertions(+), 31 deletions(-) diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h index 74d5093601a..aac23754b81 100644 --- a/cpp/src/arrow/flight/types.h +++ b/cpp/src/arrow/flight/types.h @@ -468,6 +468,38 @@ struct ARROW_FLIGHT_EXPORT FlightDescriptor } }; +/// \brief Data structure providing an opaque identifier or credential to use +/// when requesting a data stream with the DoGet RPC +struct ARROW_FLIGHT_EXPORT Ticket : public internal::BaseType { + std::string ticket; + + Ticket() = default; + Ticket(std::string ticket) // NOLINT runtime/explicit + : ticket(std::move(ticket)) {} + + std::string ToString() const; + bool Equals(const Ticket& other) const; + + using SuperT::Deserialize; + using SuperT::SerializeToString; + + /// \brief Get the wire-format representation of this type. + /// + /// Useful when interoperating with non-Flight systems (e.g. REST + /// services) that may want to return Flight types. + /// + /// Use `SerializeToString()` if you want a Result-returning version. + arrow::Status SerializeToString(std::string* out) const; + + /// \brief Parse the wire-format representation of this type. + /// + /// Useful when interoperating with non-Flight systems (e.g. REST + /// services) that may want to return Flight types. + /// + /// Use `Deserialize(serialized)` if you want a Result-returning version. + static arrow::Status Deserialize(std::string_view serialized, Ticket* out); +}; + /// \brief A host location (a URI) struct ARROW_FLIGHT_EXPORT Location : public internal::BaseType { public: @@ -537,37 +569,6 @@ struct ARROW_FLIGHT_EXPORT Location : public internal::BaseType { std::shared_ptr uri_; }; -/// \brief Data structure providing an opaque identifier or credential to use -/// when requesting a data stream with the DoGet RPC -struct ARROW_FLIGHT_EXPORT Ticket : public internal::BaseType { - std::string ticket; - - Ticket() = default; - Ticket(std::string ticket) // NOLINT runtime/explicit - : ticket(std::move(ticket)) {} - - std::string ToString() const; - bool Equals(const Ticket& other) const; - - using SuperT::Deserialize; - using SuperT::SerializeToString; - - /// \brief Get the wire-format representation of this type. - /// - /// Useful when interoperating with non-Flight systems (e.g. REST - /// services) that may want to return Flight types. - /// - /// Use `SerializeToString()` if you want a Result-returning version. - arrow::Status SerializeToString(std::string* out) const; - - /// \brief Parse the wire-format representation of this type. - /// - /// Useful when interoperating with non-Flight systems (e.g. REST - /// services) that may want to return Flight types. - /// - /// Use `Deserialize(serialized)` if you want a Result-returning version. - static arrow::Status Deserialize(std::string_view serialized, Ticket* out); -}; /// \brief A flight ticket and list of locations where the ticket can be /// redeemed @@ -884,6 +885,7 @@ struct ARROW_FLIGHT_EXPORT CancelFlightInfoResult ARROW_FLIGHT_EXPORT std::ostream& operator<<(std::ostream& os, CancelStatus status); + // FlightData in Flight.proto maps to FlightPayload here. 
/// \brief Staging data structure for messages about to be put on the wire From dff6c79aafef90570c212fce35b4568210150e5a Mon Sep 17 00:00:00 2001 From: Benjamin Leff Date: Fri, 22 Aug 2025 00:29:14 +0000 Subject: [PATCH 05/14] chore: merge conflicts in mockfs.cc --- cpp/src/arrow/filesystem/mockfs.cc | 66 ++++++++++++++++++------------ 1 file changed, 39 insertions(+), 27 deletions(-) diff --git a/cpp/src/arrow/filesystem/mockfs.cc b/cpp/src/arrow/filesystem/mockfs.cc index 15bc3f9b212..c8989c8401b 100644 --- a/cpp/src/arrow/filesystem/mockfs.cc +++ b/cpp/src/arrow/filesystem/mockfs.cc @@ -54,6 +54,7 @@ Status ValidatePath(std::string_view s) { //////////////////////////////////////////////////////////////////////////// // Filesystem structure +struct Directory; class Entry; struct File { @@ -80,37 +81,18 @@ struct Directory { TimePoint mtime; std::map> entries; - Directory(std::string name, TimePoint mtime) : name(std::move(name)), mtime(mtime) {} - Directory(Directory&& other) noexcept - : name(std::move(other.name)), - mtime(other.mtime), - entries(std::move(other.entries)) {} + Directory(std::string name, TimePoint mtime); + Directory(Directory&& other) noexcept; - Directory& operator=(Directory&& other) noexcept { - name = std::move(other.name); - mtime = other.mtime; - entries = std::move(other.entries); - return *this; - } + Directory& operator=(Directory&& other) noexcept; - Entry* Find(const std::string& s) { - auto it = entries.find(s); - if (it != entries.end()) { - return it->second.get(); - } else { - return nullptr; - } - } + Entry* Find(const std::string& s); - bool CreateEntry(const std::string& s, std::unique_ptr entry) { - DCHECK(!s.empty()); - auto p = entries.emplace(s, std::move(entry)); - return p.second; - } + bool CreateEntry(const std::string& s, std::unique_ptr entry); void AssignEntry(const std::string& s, std::unique_ptr entry); - bool DeleteEntry(const std::string& s) { return entries.erase(s) > 0; } + bool DeleteEntry(const std::string& s); private: ARROW_DISALLOW_COPY_AND_ASSIGN(Directory); @@ -120,7 +102,7 @@ struct Directory { using EntryBase = std::variant; class Entry : public EntryBase { - public: +public: Entry(Entry&&) = default; Entry& operator=(Entry&&) = default; explicit Entry(Directory&& v) : EntryBase(std::move(v)) {} @@ -180,15 +162,45 @@ class Entry : public EntryBase { } } - private: +private: ARROW_DISALLOW_COPY_AND_ASSIGN(Entry); }; +Directory::Directory(std::string name, TimePoint mtime) : name(std::move(name)), mtime(mtime) {} +Directory::Directory(Directory&& other) noexcept + : name(std::move(other.name)), + mtime(other.mtime), + entries(std::move(other.entries)) {} + +Directory& Directory::operator=(Directory&& other) noexcept { + name = std::move(other.name); + mtime = other.mtime; + entries = std::move(other.entries); + return *this; +} + +Entry* Directory::Find(const std::string& s) { + auto it = entries.find(s); + if (it != entries.end()) { + return it->second.get(); + } else { + return nullptr; + } +} + +bool Directory::CreateEntry(const std::string& s, std::unique_ptr entry) { + DCHECK(!s.empty()); + auto p = entries.emplace(s, std::move(entry)); + return p.second; +} + void Directory::AssignEntry(const std::string& s, std::unique_ptr entry) { DCHECK(!s.empty()); entries[s] = std::move(entry); } +bool Directory::DeleteEntry(const std::string& s) { return entries.erase(s) > 0; } + //////////////////////////////////////////////////////////////////////////// // Streams From 0b19e8553dc093fafe732366af04454222c5b72a Mon Sep 17 00:00:00 
2001 From: gorloffslava Date: Mon, 4 Nov 2024 19:46:52 +0500 Subject: [PATCH 06/14] + Fix building in C++ 20 and 23 language modes --- cpp/src/arrow/filesystem/mockfs.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/cpp/src/arrow/filesystem/mockfs.cc b/cpp/src/arrow/filesystem/mockfs.cc index c8989c8401b..e476f7dcbda 100644 --- a/cpp/src/arrow/filesystem/mockfs.cc +++ b/cpp/src/arrow/filesystem/mockfs.cc @@ -101,6 +101,8 @@ struct Directory { // A filesystem entry using EntryBase = std::variant; +struct Directory; + class Entry : public EntryBase { public: Entry(Entry&&) = default; From bf1c40dac0b4e420a47e73fe37ff6de27d8fadfa Mon Sep 17 00:00:00 2001 From: gorloffslava Date: Mon, 4 Nov 2024 19:50:13 +0500 Subject: [PATCH 07/14] + Fix building in C++ 20 and 23 language modes --- cpp/src/arrow/filesystem/mockfs.cc | 2 -- 1 file changed, 2 deletions(-) diff --git a/cpp/src/arrow/filesystem/mockfs.cc b/cpp/src/arrow/filesystem/mockfs.cc index e476f7dcbda..c8989c8401b 100644 --- a/cpp/src/arrow/filesystem/mockfs.cc +++ b/cpp/src/arrow/filesystem/mockfs.cc @@ -101,8 +101,6 @@ struct Directory { // A filesystem entry using EntryBase = std::variant; -struct Directory; - class Entry : public EntryBase { public: Entry(Entry&&) = default; From 8945f1abbfc1773682c1e6a6477ddee226b4b71f Mon Sep 17 00:00:00 2001 From: Benjamin Leff Date: Fri, 22 Aug 2025 00:30:22 +0000 Subject: [PATCH 08/14] chore: merge conflicts in asof_join_node.cc --- cpp/src/arrow/acero/asof_join_node.cc | 1981 +++++++++++++------------ 1 file changed, 1069 insertions(+), 912 deletions(-) diff --git a/cpp/src/arrow/acero/asof_join_node.cc b/cpp/src/arrow/acero/asof_join_node.cc index 55fa45543e4..4edcf1d7c10 100644 --- a/cpp/src/arrow/acero/asof_join_node.cc +++ b/cpp/src/arrow/acero/asof_join_node.cc @@ -135,8 +135,6 @@ static inline uint64_t key_value(T t) { return static_cast(t); } -class AsofJoinNode; - #ifndef NDEBUG // Get the debug-stream associated with the as-of-join node std::ostream* GetDebugStream(AsofJoinNode* node); @@ -384,80 +382,8 @@ struct MemoStore { } }; -// a specialized higher-performance variation of Hashing64 logic from hash_join_node -// the code here avoids recreating objects that are independent of each batch processed -class KeyHasher { - friend class AsofJoinNode; - - static constexpr int kMiniBatchLength = arrow::util::MiniBatch::kMiniBatchLength; - - public: - // the key hasher is not thread-safe and is only used in sequential batch processing - // of the input it is associated with - KeyHasher(size_t index, const std::vector& indices) - : index_(index), - indices_(indices), - metadata_(indices.size()), - batch_(NULLPTR), - hashes_(), - ctx_(), - column_arrays_(), - stack_() { - ctx_.stack = &stack_; - column_arrays_.resize(indices.size()); - } - - Status Init(ExecContext* exec_context, const std::shared_ptr& schema) { - ctx_.hardware_flags = exec_context->cpu_info()->hardware_flags(); - const auto& fields = schema->fields(); - for (size_t k = 0; k < metadata_.size(); k++) { - ARROW_ASSIGN_OR_RAISE(metadata_[k], - ColumnMetadataFromDataType(fields[indices_[k]]->type())); - } - return stack_.Init(exec_context->memory_pool(), - 4 * kMiniBatchLength * sizeof(uint32_t)); - } - - // invalidate cached hashes for batch - required when it changes - // only this method can be called concurrently with HashesFor - void Invalidate() { batch_ = NULLPTR; } - - // compute and cache a hash for each row of the given batch - const std::vector& HashesFor(const RecordBatch* batch) { - if (batch_ == batch) { - 
return hashes_; // cache hit - return cached hashes - } - Invalidate(); - size_t batch_length = batch->num_rows(); - hashes_.resize(batch_length); - for (int64_t i = 0; i < static_cast(batch_length); i += kMiniBatchLength) { - int64_t length = std::min(static_cast(batch_length - i), - static_cast(kMiniBatchLength)); - for (size_t k = 0; k < indices_.size(); k++) { - auto array_data = batch->column_data(indices_[k]); - column_arrays_[k] = - ColumnArrayFromArrayDataAndMetadata(array_data, metadata_[k], i, length); - } - // write directly to the cache - Hashing64::HashMultiColumn(column_arrays_, &ctx_, hashes_.data() + i); - } - DEBUG_SYNC(node_, "key hasher ", index_, " got hashes ", - compute::internal::GenericToString(hashes_), DEBUG_MANIP(std::endl)); - batch_ = batch; // associate cache with current batch - return hashes_; - } - - private: - AsofJoinNode* node_ = nullptr; // avoids circular dependency during initialization - size_t index_; - std::vector indices_; - std::vector metadata_; - std::atomic batch_; - std::vector hashes_; - LightContext ctx_; - std::vector column_arrays_; - arrow::util::TempVectorStack stack_; -}; +class KeyHasher; +class AsofJoinNode; class BackpressureController : public BackpressureControl { public: @@ -484,140 +410,50 @@ class InputState : public util::SerialSequencingQueue::Processor { KeyHasher* key_hasher, AsofJoinNode* node, BackpressureHandler handler, const std::shared_ptr& schema, const col_index_t time_col_index, - const std::vector& key_col_index) - : sequencer_(util::SerialSequencingQueue::Make(this)), - queue_(std::move(handler)), - schema_(schema), - time_col_index_(time_col_index), - key_col_index_(key_col_index), - time_type_id_(schema_->fields()[time_col_index_]->type()->id()), - key_type_id_(key_col_index.size()), - key_hasher_(key_hasher), - node_(node), - index_(index), - must_hash_(must_hash), - may_rehash_(may_rehash), - tolerance_(tolerance), - memo_(DEBUG_ADD(/*no_future=*/index == 0 || !tolerance.positive, node, index)) { - for (size_t k = 0; k < key_col_index_.size(); k++) { - key_type_id_[k] = schema_->fields()[key_col_index_[k]]->type()->id(); - } - } + const std::vector& key_col_index); static Result> Make( size_t index, TolType tolerance, bool must_hash, bool may_rehash, KeyHasher* key_hasher, ExecNode* asof_input, AsofJoinNode* asof_node, std::atomic& backpressure_counter, const std::shared_ptr& schema, const col_index_t time_col_index, - const std::vector& key_col_index) { - constexpr size_t low_threshold = 4, high_threshold = 8; - std::unique_ptr backpressure_control = - std::make_unique( - /*node=*/asof_input, /*output=*/asof_node, backpressure_counter); - ARROW_ASSIGN_OR_RAISE( - auto handler, BackpressureHandler::Make(asof_input, low_threshold, high_threshold, - std::move(backpressure_control))); - return std::make_unique(index, tolerance, must_hash, may_rehash, - key_hasher, asof_node, std::move(handler), schema, - time_col_index, key_col_index); - } + const std::vector& key_col_index); - col_index_t InitSrcToDstMapping(col_index_t dst_offset, bool skip_time_and_key_fields) { - src_to_dst_.resize(schema_->num_fields()); - for (int i = 0; i < schema_->num_fields(); ++i) - if (!(skip_time_and_key_fields && IsTimeOrKeyColumn(i))) - src_to_dst_[i] = dst_offset++; - return dst_offset; - } + col_index_t InitSrcToDstMapping(col_index_t dst_offset, bool skip_time_and_key_fields); - const std::optional& MapSrcToDst(col_index_t src) const { - return src_to_dst_[src]; - } + const std::optional& MapSrcToDst(col_index_t src) const; - 
bool IsTimeOrKeyColumn(col_index_t i) const { - DCHECK_LT(i, schema_->num_fields()); - return (i == time_col_index_) || std_has(key_col_index_, i); - } + bool IsTimeOrKeyColumn(col_index_t i) const; // Gets the latest row index, assuming the queue isn't empty - row_index_t GetLatestRow() const { return latest_ref_row_; } + row_index_t GetLatestRow() const; - bool Empty() const { - // cannot be empty if ref row is >0 -- can avoid slow queue lock - // below - if (latest_ref_row_ > 0) return false; - return queue_.Empty(); - } + bool Empty() const; // true when the queue is empty and, when memo may have future entries (the case of a // positive tolerance), when the memo is empty. // used when checking whether RHS is up to date with LHS. // NOTE: The emptiness must be decided by a single call to Empty() in caller, due to the // potential race with Push(), see GH-41614. - bool CurrentEmpty(bool empty) const { - return memo_.no_future_ ? empty : (memo_.times_.empty() && empty); - } + bool CurrentEmpty(bool empty) const; // in case memo may not have future entries (the case of a non-positive tolerance), // returns the latest time (which is current); otherwise, returns the current time. // used when checking whether RHS is up to date with LHS. - OnType GetCurrentTime() const { - return memo_.no_future_ ? GetLatestTime() : static_cast(memo_.current_time_); - } + OnType GetCurrentTime() const; - int total_batches() const { return total_batches_; } + int total_batches() const; // Gets latest batch (precondition: must not be empty) const std::shared_ptr& GetLatestBatch() const { return queue_.Front(); } -#define LATEST_VAL_CASE(id, val) \ - case Type::id: { \ - using T = typename TypeIdTraits::Type; \ - using CType = typename TypeTraits::CType; \ - return val(data->GetValues(1)[row]); \ - } - - inline ByType GetLatestKey() const { - return GetKey(GetLatestBatch().get(), latest_ref_row_); - } + inline ByType GetLatestKey() const; - inline ByType GetKey(const RecordBatch* batch, row_index_t row) const { - if (must_hash_) { - // Query the key hasher. This may hit cache, which must be valid for the batch. - // Therefore, the key hasher is invalidated when a new batch is pushed - see - // `InputState::Push`. - return key_hasher_->HashesFor(batch)[row]; - } - if (key_col_index_.size() == 0) { - return 0; - } - auto data = batch->column_data(key_col_index_[0]); - switch (key_type_id_[0]) { - LATEST_VAL_CASE(INT8, key_value) - LATEST_VAL_CASE(INT16, key_value) - LATEST_VAL_CASE(INT32, key_value) - LATEST_VAL_CASE(INT64, key_value) - LATEST_VAL_CASE(UINT8, key_value) - LATEST_VAL_CASE(UINT16, key_value) - LATEST_VAL_CASE(UINT32, key_value) - LATEST_VAL_CASE(UINT64, key_value) - LATEST_VAL_CASE(DATE32, key_value) - LATEST_VAL_CASE(DATE64, key_value) - LATEST_VAL_CASE(TIME32, key_value) - LATEST_VAL_CASE(TIME64, key_value) - LATEST_VAL_CASE(TIMESTAMP, key_value) - default: - DCHECK(false); - return 0; // cannot happen - } - } + inline ByType GetKey(const RecordBatch* batch, row_index_t row) const; - inline OnType GetLatestTime() const { - return GetTime(GetLatestBatch().get(), time_type_id_, time_col_index_, - latest_ref_row_); - } + inline OnType GetLatestTime() const; #undef LATEST_VAL_CASE @@ -655,119 +491,27 @@ class InputState : public util::SerialSequencingQueue::Processor { // Returns true if updates were made, false if not. // NOTE: The emptiness must be decided by a single call to Empty() in caller, due to the // potential race with Push(), see GH-41614. 
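Throughout this hunk the patch applies one mechanical change: member functions of InputState, KeyHasher, and AsofJoinNode that used to be defined inline inside the class bodies are reduced to declarations, with the bodies moved further down the file, after all of the involved classes are complete. The classes can then refer to one another through forward declarations (`class KeyHasher;`, `class AsofJoinNode;`) alone. A minimal sketch of the idiom, with made-up names rather than the real classes:

class Hasher;  // forward declaration: enough to hold a pointer

class State {
 public:
  explicit State(Hasher* hasher) : hasher_(hasher) {}
  int KeyFor(int row) const;  // declared only; Hasher is still incomplete here
 private:
  Hasher* hasher_;
};

class Hasher {
 public:
  int Hash(int row) const { return row * 31; }
};

// Defined after Hasher is complete, so calling through the pointer is valid.
int State::KeyFor(int row) const { return hasher_->Hash(row); }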
- Result AdvanceAndMemoize(OnType ts, bool empty) { - // Advance the right side row index until we reach the latest right row (for each key) - // for the given left timestamp. - DEBUG_SYNC(node_, "Advancing input ", index_, DEBUG_MANIP(std::endl)); - - // Check if already updated for TS (or if there is no latest) - if (empty) { // can't advance if empty and no future entries - return memo_.no_future_ ? false : memo_.RemoveEntriesWithLesserTime(ts); - } + Result AdvanceAndMemoize(OnType ts, bool empty); - // Not updated. Try to update and possibly advance. - bool advanced, updated = false; - OnType latest_time; - do { - latest_time = GetLatestTime(); - // if Advance() returns true, then the latest_ts must also be valid - // Keep advancing right table until we hit the latest row that has - // timestamp <= ts. This is because we only need the latest row for the - // match given a left ts. - if (latest_time > tolerance_.Horizon(ts)) { // hit a distant timestamp - DEBUG_SYNC(node_, "Advancing input ", index_, " hit distant time=", latest_time, - " at=", ts, DEBUG_MANIP(std::endl)); - // if no future entries, which would have been earlier than the distant time, no - // need to queue it - if (memo_.future_entries_.empty()) break; - } - auto rb = GetLatestBatch(); - if (may_rehash_ && rb->column_data(key_col_index_[0])->GetNullCount() > 0) { - must_hash_ = true; - may_rehash_ = false; - Rehash(); - } - memo_.Store(rb, latest_ref_row_, latest_time, DEBUG_ADD(GetLatestKey(), ts)); - // negative tolerance means a last-known entry was stored - set `updated` to `true` - updated = memo_.no_future_; - ARROW_ASSIGN_OR_RAISE(advanced, Advance()); - } while (advanced); - if (!memo_.no_future_ && latest_time >= ts) { - // `updated` was not modified in the loop from the initial `false` value; set it now - updated = memo_.RemoveEntriesWithLesserTime(ts); - } - DEBUG_SYNC(node_, "Advancing input ", index_, " updated=", updated, - DEBUG_MANIP(std::endl)); - return updated; - } - Status InsertBatch(ExecBatch batch) { - return sequencer_->InsertBatch(std::move(batch)); - } - - Status Process(ExecBatch batch) override { - auto rb = *batch.ToRecordBatch(schema_); - DEBUG_SYNC(node_, "received batch from input ", index_, ":", DEBUG_MANIP(std::endl), - rb->ToString(), DEBUG_MANIP(std::endl)); - return Push(rb); - } - void Rehash() { - DEBUG_SYNC(node_, "rehashing for input ", index_, ":", DEBUG_MANIP(std::endl)); - MemoStore new_memo(DEBUG_ADD(memo_.no_future_, node_, index_)); - new_memo.current_time_ = (OnType)memo_.current_time_; - for (auto e = memo_.entries_.begin(); e != memo_.entries_.end(); ++e) { - auto& entry = e->second; - auto new_key = GetKey(entry.batch.get(), entry.row); - DEBUG_SYNC(node_, " ", e->first, " to ", new_key, DEBUG_MANIP(std::endl)); - new_memo.entries_[new_key].swap(entry); - auto fe = memo_.future_entries_.find(e->first); - if (fe != memo_.future_entries_.end()) { - new_memo.future_entries_[new_key].swap(fe->second); - } - } - memo_.times_.swap(new_memo.times_); - memo_.swap(new_memo); - } + Status InsertBatch(ExecBatch batch); - Status Push(const std::shared_ptr& rb) { - if (rb->num_rows() > 0) { - key_hasher_->Invalidate(); // batch changed - invalidate key hasher's cache - queue_.Push(rb); // only now push batch for processing - } else { - ++batches_processed_; // don't enqueue empty batches, just record as processed - } - return Status::OK(); - } + Status Process(ExecBatch batch) override; - std::optional GetMemoEntryForKey(ByType key) { - return memo_.GetEntryForKey(key); - } + void 
Rehash(); - std::optional GetMemoTimeForKey(ByType key) { - auto r = GetMemoEntryForKey(key); - if (r.has_value()) { - return (*r)->time; - } else { - return std::nullopt; - } - } + Status Push(const std::shared_ptr& rb); - void RemoveMemoEntriesWithLesserTime(OnType ts) { - memo_.RemoveEntriesWithLesserTime(ts); - } + std::optional GetMemoEntryForKey(ByType key); - const std::shared_ptr& get_schema() const { return schema_; } + std::optional GetMemoTimeForKey(ByType key); - void set_total_batches(int n) { - DCHECK_GE(n, 0); - DCHECK_EQ(total_batches_, -1) << "Set total batch more than once"; - total_batches_ = n; - } + void RemoveMemoEntriesWithLesserTime(OnType ts); - Status ForceShutdown() { - // Force the upstream input node to unpause. Necessary to avoid deadlock when we - // terminate the process thread - return queue_.ForceShutdown(); - } + const std::shared_ptr& get_schema() const; + + void set_total_batches(int n); + + Status ForceShutdown(); private: // ExecBatch Sequencer @@ -811,121 +555,37 @@ class InputState : public util::SerialSequencingQueue::Processor { std::vector> src_to_dst_; }; -/// Wrapper around UnmaterializedCompositeTable that knows how to emplace -/// the join row-by-row -template -class CompositeTableBuilder { - using SliceBuilder = UnmaterializedSliceBuilder; - using CompositeTable = UnmaterializedCompositeTable; - - public: - NDEBUG_EXPLICIT CompositeTableBuilder( - const std::vector>& inputs, - const std::shared_ptr& schema, arrow::MemoryPool* pool, - DEBUG_ADD(size_t n_tables, AsofJoinNode* node)) - : unmaterialized_table(InitUnmaterializedTable(schema, inputs, pool)), - DEBUG_ADD(n_tables_(n_tables), node_(node)) { - DCHECK_GE(n_tables_, 1); - DCHECK_LE(n_tables_, MAX_TABLES); - } - - size_t n_rows() const { return unmaterialized_table.Size(); } - - // Adds the latest row from the input state as a new composite reference row - // - LHS must have a valid key,timestep,and latest rows - // - RHS must have valid data memo'ed for the key - void Emplace(std::vector>& in, TolType tolerance) { - DCHECK_EQ(in.size(), n_tables_); - - // Get the LHS key - ByType key = in[0]->GetLatestKey(); - - // Add row and setup LHS - // (the LHS state comes just from the latest row of the LHS table) - DCHECK(!in[0]->Empty()); - const std::shared_ptr& lhs_latest_batch = in[0]->GetLatestBatch(); - row_index_t lhs_latest_row = in[0]->GetLatestRow(); - OnType lhs_latest_time = in[0]->GetLatestTime(); - if (0 == lhs_latest_row) { - // On the first row of the batch, we resize the destination. - // The destination size is dictated by the size of the LHS batch. 
- row_index_t new_batch_size = lhs_latest_batch->num_rows(); - row_index_t new_capacity = unmaterialized_table.Size() + new_batch_size; - if (unmaterialized_table.capacity() < new_capacity) { - unmaterialized_table.reserve(new_capacity); - } - } - - SliceBuilder new_row{&unmaterialized_table}; +// a specialized higher-performance variation of Hashing64 logic from hash_join_node +// the code here avoids recreating objects that are independent of each batch processed +class KeyHasher { + friend class AsofJoinNode; - // Each item represents a portion of the columns of the output table - new_row.AddEntry(lhs_latest_batch, lhs_latest_row, lhs_latest_row + 1); + static constexpr int kMiniBatchLength = arrow::util::MiniBatch::kMiniBatchLength; - DEBUG_SYNC(node_, "Emplace: key=", key, " lhs_latest_row=", lhs_latest_row, - " lhs_latest_time=", lhs_latest_time, DEBUG_MANIP(std::endl)); + public: + // the key hasher is not thread-safe and is only used in sequential batch processing + // of the input it is associated with + KeyHasher(size_t index, const std::vector& indices); - // Get the state for that key from all on the RHS -- assumes it's up to date - // (the RHS state comes from the memoized row references) - for (size_t i = 1; i < in.size(); ++i) { - std::optional opt_entry = in[i]->GetMemoEntryForKey(key); -#ifndef NDEBUG - { - bool has_entry = opt_entry.has_value(); - OnType entry_time = has_entry ? (*opt_entry)->time : TolType::kMinValue; - row_index_t entry_row = has_entry ? (*opt_entry)->row : 0; - bool accepted = has_entry && tolerance.Accepts(lhs_latest_time, entry_time); - DEBUG_SYNC(node_, " i=", i, " has_entry=", has_entry, " time=", entry_time, - " row=", entry_row, " accepted=", accepted, DEBUG_MANIP(std::endl)); - } -#endif - if (opt_entry.has_value()) { - DCHECK(*opt_entry); - if (tolerance.Accepts(lhs_latest_time, (*opt_entry)->time)) { - // Have a valid entry - const MemoStore::Entry* entry = *opt_entry; - new_row.AddEntry(entry->batch, entry->row, entry->row + 1); - continue; - } - } - new_row.AddEntry(nullptr, 0, 1); - } - new_row.Finalize(); - } + Status Init(ExecContext* exec_context, const std::shared_ptr& schema); - // Materializes the current reference table into a target record batch - Result>> Materialize() { - return unmaterialized_table.Materialize(); - } + // invalidate cached hashes for batch - required when it changes + // only this method can be called concurrently with HashesFor + void Invalidate(); - // Returns true if there are no rows - bool empty() const { return unmaterialized_table.Empty(); } + // compute and cache a hash for each row of the given batch + const std::vector& HashesFor(const RecordBatch* batch); private: - CompositeTable unmaterialized_table; - - // Total number of tables in the composite table - size_t n_tables_; - -#ifndef NDEBUG - // Owning node - AsofJoinNode* node_; -#endif - - static CompositeTable InitUnmaterializedTable( - const std::shared_ptr& schema, - const std::vector>& inputs, arrow::MemoryPool* pool) { - std::unordered_map> dst_to_src; - for (size_t i = 0; i < inputs.size(); i++) { - auto& input = inputs[i]; - for (int src = 0; src < input->get_schema()->num_fields(); src++) { - auto dst = input->MapSrcToDst(src); - if (dst.has_value()) { - dst_to_src[dst.value()] = std::make_pair(static_cast(i), src); - } - } - } - return CompositeTable{schema, inputs.size(), dst_to_src, pool}; - } + AsofJoinNode* node_ = nullptr; // avoids circular dependency during initialization + size_t index_; + std::vector indices_; + std::vector 
metadata_; + std::atomic batch_; + std::vector hashes_; + LightContext ctx_; + std::vector column_arrays_; + arrow::util::TempVectorStack stack_; }; // TODO: Currently, AsofJoinNode uses 64-bit hashing which leads to a non-negligible @@ -941,93 +601,16 @@ class AsofJoinNode : public ExecNode { bool any_advanced; bool all_up_to_date_with_lhs; }; + // Advances the RHS as far as possible to be up to date for the current LHS timestamp, // and checks if all RHS are up to date with LHS. The reason they have to be performed // together is that they both depend on the emptiness of the RHS, which can be changed // by Push() executing in another thread. - Result UpdateRhs() { - auto& lhs = *state_.at(0); - auto lhs_latest_time = lhs.GetLatestTime(); - RhsUpdateState update_state{/*any_advanced=*/false, /*all_up_to_date_with_lhs=*/true}; - for (size_t i = 1; i < state_.size(); ++i) { - auto& rhs = *state_[i]; - - // Obtain RHS emptiness once for subsequent AdvanceAndMemoize() and CurrentEmpty(). - bool rhs_empty = rhs.Empty(); - // Obtain RHS current time here because AdvanceAndMemoize() can change the - // emptiness. - OnType rhs_current_time = rhs_empty ? OnType{} : rhs.GetLatestTime(); - - ARROW_ASSIGN_OR_RAISE(bool advanced, - rhs.AdvanceAndMemoize(lhs_latest_time, rhs_empty)); - update_state.any_advanced |= advanced; - - if (update_state.all_up_to_date_with_lhs && !rhs.Finished()) { - // If RHS is finished, then we know it's up to date - if (rhs.CurrentEmpty(rhs_empty)) { - // RHS isn't finished, but is empty --> not up to date - update_state.all_up_to_date_with_lhs = false; - } else if (lhs_latest_time > rhs_current_time) { - // RHS isn't up to date (and not finished) - update_state.all_up_to_date_with_lhs = false; - } - } - } - return update_state; - } - - Result> ProcessInner() { - DCHECK(!state_.empty()); - auto& lhs = *state_.at(0); - - // Construct new target table if needed - CompositeTableBuilder dst(state_, output_schema_, - plan()->query_context()->memory_pool(), - DEBUG_ADD(state_.size(), this)); - - // Generate rows into the dst table until we either run out of data or hit the row - // limit, or run out of input - for (;;) { - // If LHS is finished or empty then there's nothing we can do here - if (lhs.Finished() || lhs.Empty()) break; - - ARROW_ASSIGN_OR_RAISE(auto rhs_update_state, UpdateRhs()); - - // If we have received enough inputs to produce the next output batch - // (decided by IsUpToDateWithLhsRow), we will perform the join and - // materialize the output batch. The join is done by advancing through - // the LHS and adding joined row to rows_ (done by Emplace). Finally, - // input batches that are no longer needed are removed to free up memory. - if (rhs_update_state.all_up_to_date_with_lhs) { - dst.Emplace(state_, tolerance_); - ARROW_ASSIGN_OR_RAISE(bool advanced, lhs.Advance()); - if (!advanced) break; // if we can't advance LHS, we're done for this batch - } else { - if (!rhs_update_state.any_advanced) break; // need to wait for new data - } - } + Result UpdateRhs(); - // Prune memo entries that have expired (to bound memory consumption) - if (!lhs.Empty()) { - for (size_t i = 1; i < state_.size(); ++i) { - OnType ts = tolerance_.Expiry(lhs.GetLatestTime()); - if (ts != TolType::kMinValue) { - state_[i]->RemoveMemoEntriesWithLesserTime(ts); - } - } - } - - // Emit the batch - if (dst.empty()) { - return NULLPTR; - } else { - ARROW_ASSIGN_OR_RAISE(auto out, dst.Materialize()); - return out.has_value() ? 
out.value() : NULLPTR; - } - } + Result> ProcessInner(); #ifdef ARROW_ENABLE_THREADING - template struct Defer { Callable callable; @@ -1035,85 +618,15 @@ class AsofJoinNode : public ExecNode { ~Defer() noexcept { callable(); } }; - void EndFromProcessThread(Status st = Status::OK()) { - // We must spawn a new task to transfer off the process thread when - // marking this finished. Otherwise there is a chance that doing so could - // mark the plan finished which may destroy the plan which will destroy this - // node which will cause us to join on ourselves. - ARROW_UNUSED( - plan_->query_context()->executor()->Spawn([this, st = std::move(st)]() mutable { - Defer cleanup([this, &st]() { process_task_.MarkFinished(st); }); - if (st.ok()) { - st = output_->InputFinished(this, batches_produced_); - } - for (const auto& s : state_) { - st &= s->ForceShutdown(); - } - })); - } - - bool CheckEnded() { - if (state_.at(0)->Finished()) { - EndFromProcessThread(); - return false; - } - return true; - } - - bool Process() { - std::lock_guard guard(gate_); - if (!CheckEnded()) { - return false; - } + void EndFromProcessThread(Status st = Status::OK()); - // Process batches while we have data - for (;;) { - Result> result = ProcessInner(); - - if (result.ok()) { - auto out_rb = *result; - if (!out_rb) break; - ExecBatch out_b(*out_rb); - out_b.index = batches_produced_++; - DEBUG_SYNC(this, "produce batch ", out_b.index, ":", DEBUG_MANIP(std::endl), - out_rb->ToString(), DEBUG_MANIP(std::endl)); - Status st = output_->InputReceived(this, std::move(out_b)); - if (!st.ok()) { - EndFromProcessThread(std::move(st)); - } - } else { - EndFromProcessThread(result.status()); - return false; - } - } - - // Report to the output the total batch count, if we've already finished everything - // (there are two places where this can happen: here and InputFinished) - // - // It may happen here in cases where InputFinished was called before we were finished - // producing results (so we didn't know the output size at that time) - if (!CheckEnded()) { - return false; - } + bool CheckEnded(); - // There is no more we can do now but there is still work remaining for later when - // more data arrives. 
- return true; - } + bool Process(); - void ProcessThread() { - for (;;) { - if (!process_.WaitAndPop()) { - EndFromProcessThread(); - return; - } - if (!Process()) { - return; - } - } - } + void ProcessThread(); - static void ProcessThreadWrapper(AsofJoinNode* node) { node->ProcessThread(); } + static void ProcessThreadWrapper(AsofJoinNode* node); #endif public: @@ -1124,117 +637,16 @@ class AsofJoinNode : public ExecNode { std::vector> key_hashers, bool must_hash, bool may_rehash); - Status Init() override { - auto inputs = this->inputs(); - for (size_t i = 0; i < inputs.size(); i++) { - RETURN_NOT_OK(key_hashers_[i]->Init(plan()->query_context()->exec_context(), - inputs[i]->output_schema())); - ARROW_ASSIGN_OR_RAISE( - auto input_state, - InputState::Make(i, tolerance_, must_hash_, may_rehash_, key_hashers_[i].get(), - inputs[i], this, backpressure_counter_, - inputs[i]->output_schema(), indices_of_on_key_[i], - indices_of_by_key_[i])); - state_.push_back(std::move(input_state)); - } - - col_index_t dst_offset = 0; - for (auto& state : state_) - dst_offset = state->InitSrcToDstMapping(dst_offset, !!dst_offset); + Status Init() override; - return Status::OK(); - } + virtual ~AsofJoinNode(); - virtual ~AsofJoinNode() { -#ifdef ARROW_ENABLE_THREADING - PushProcess(false); - if (process_thread_.joinable()) { - process_thread_.join(); - } -#endif - } + const std::vector& indices_of_on_key(); + const std::vector>& indices_of_by_key(); - const std::vector& indices_of_on_key() { return indices_of_on_key_; } - const std::vector>& indices_of_by_key() { - return indices_of_by_key_; - } - - static Status is_valid_on_field(const std::shared_ptr& field) { - switch (field->type()->id()) { - case Type::INT8: - case Type::INT16: - case Type::INT32: - case Type::INT64: - case Type::UINT8: - case Type::UINT16: - case Type::UINT32: - case Type::UINT64: - case Type::DATE32: - case Type::DATE64: - case Type::TIME32: - case Type::TIME64: - case Type::TIMESTAMP: - return Status::OK(); - default: - return Status::Invalid("Unsupported type for on-key ", field->name(), " : ", - field->type()->ToString()); - } - } - - static Status is_valid_by_field(const std::shared_ptr& field) { - switch (field->type()->id()) { - case Type::INT8: - case Type::INT16: - case Type::INT32: - case Type::INT64: - case Type::UINT8: - case Type::UINT16: - case Type::UINT32: - case Type::UINT64: - case Type::DATE32: - case Type::DATE64: - case Type::TIME32: - case Type::TIME64: - case Type::TIMESTAMP: - case Type::STRING: - case Type::LARGE_STRING: - case Type::BINARY: - case Type::LARGE_BINARY: - return Status::OK(); - default: - return Status::Invalid("Unsupported type for by-key ", field->name(), " : ", - field->type()->ToString()); - } - } - - static Status is_valid_data_field(const std::shared_ptr& field) { - switch (field->type()->id()) { - case Type::BOOL: - case Type::INT8: - case Type::INT16: - case Type::INT32: - case Type::INT64: - case Type::UINT8: - case Type::UINT16: - case Type::UINT32: - case Type::UINT64: - case Type::FLOAT: - case Type::DOUBLE: - case Type::DATE32: - case Type::DATE64: - case Type::TIME32: - case Type::TIME64: - case Type::TIMESTAMP: - case Type::STRING: - case Type::LARGE_STRING: - case Type::BINARY: - case Type::LARGE_BINARY: - return Status::OK(); - default: - return Status::Invalid("Unsupported type for data field ", field->name(), " : ", - field->type()->ToString()); - } - } + static Status is_valid_on_field(const std::shared_ptr& field); + static Status is_valid_by_field(const std::shared_ptr& 
field); + static Status is_valid_data_field(const std::shared_ptr& field); /// \brief Make the output schema of an as-of-join node /// @@ -1244,331 +656,1076 @@ class AsofJoinNode : public ExecNode { static arrow::Result> MakeOutputSchema( const std::vector> input_schema, const std::vector& indices_of_on_key, - const std::vector>& indices_of_by_key) { - std::vector> fields; - - size_t n_by = indices_of_by_key.size() == 0 ? 0 : indices_of_by_key[0].size(); - const DataType* on_key_type = NULLPTR; - std::vector by_key_type(n_by, NULLPTR); - // Take all non-key, non-time RHS fields - for (size_t j = 0; j < input_schema.size(); ++j) { - const auto& on_field_ix = indices_of_on_key[j]; - const auto& by_field_ix = indices_of_by_key[j]; - - if ((on_field_ix == -1) || std_has(by_field_ix, -1)) { - return Status::Invalid("Missing join key on table ", j); - } + const std::vector>& indices_of_by_key); - const auto& on_field = input_schema[j]->fields()[on_field_ix]; - std::vector by_field(n_by); - for (size_t k = 0; k < n_by; k++) { - by_field[k] = input_schema[j]->fields()[by_field_ix[k]].get(); - } + static inline Result FindColIndex(const Schema& schema, + const FieldRef& field_ref, + std::string_view key_kind); - if (on_key_type == NULLPTR) { - on_key_type = on_field->type().get(); - } else if (*on_key_type != *on_field->type()) { - return Status::Invalid("Expected on-key type ", *on_key_type, " but got ", - *on_field->type(), " for field ", on_field->name(), - " in input ", j); - } - for (size_t k = 0; k < n_by; k++) { - if (by_key_type[k] == NULLPTR) { - by_key_type[k] = by_field[k]->type().get(); - } else if (*by_key_type[k] != *by_field[k]->type()) { - return Status::Invalid("Expected by-key type ", *by_key_type[k], " but got ", - *by_field[k]->type(), " for field ", by_field[k]->name(), - " in input ", j); - } + static Result GetByKeySize( + const std::vector& input_keys); + + static Result> GetIndicesOfOnKey( + const std::vector>& input_schema, + const std::vector& input_keys); + + static Result>> GetIndicesOfByKey( + const std::vector>& input_schema, + const std::vector& input_keys); + + static arrow::Result Make(ExecPlan* plan, std::vector inputs, + const ExecNodeOptions& options); + + const char* kind_name() const override; + const Ordering& ordering() const override; + + Status InputReceived(ExecNode* input, ExecBatch batch) override; + + Status InputFinished(ExecNode* input, int total_batches) override; + + void PushProcess(bool value); + +#ifdef ARROW_ENABLE_THREADING + bool ProcessNonThreaded(); + void EndFromSingleThread(Status st = Status::OK()); +#endif + + Status StartProducing() override; + + void PauseProducing(ExecNode* output, int32_t counter) override; + void ResumeProducing(ExecNode* output, int32_t counter) override; + + Status StopProducingImpl() override; + +#ifndef NDEBUG + std::ostream* GetDebugStream(); + + std::mutex* GetDebugMutex(); +#endif + + private: + // Outputs from this node are always in ascending order according to the on key + const Ordering ordering_; + std::vector indices_of_on_key_; + std::vector> indices_of_by_key_; + std::vector> key_hashers_; + bool must_hash_; + bool may_rehash_; + // InputStates + // Each input state corresponds to an input table + std::vector> state_; + std::mutex gate_; + TolType tolerance_; +#ifndef NDEBUG + std::ostream* debug_os_; + std::mutex* debug_mutex_; +#endif + + // Backpressure counter common to all inputs + std::atomic backpressure_counter_; +#ifdef ARROW_ENABLE_THREADING + // Queue for triggering processing of a given 
input + // (a false value is a poison pill) + ConcurrentQueue process_; + // Worker thread + std::thread process_thread_; +#endif + Future<> process_task_; + + // In-progress batches produced + int batches_produced_ = 0; +}; + +InputState::InputState(size_t index, TolType tolerance, bool must_hash, bool may_rehash, + KeyHasher* key_hasher, AsofJoinNode* node, + BackpressureHandler handler, + const std::shared_ptr& schema, + const col_index_t time_col_index, + const std::vector& key_col_index) + : sequencer_(util::SerialSequencingQueue::Make(this)), + queue_(std::move(handler)), + schema_(schema), + time_col_index_(time_col_index), + key_col_index_(key_col_index), + time_type_id_(schema_->fields()[time_col_index_]->type()->id()), + key_type_id_(key_col_index.size()), + key_hasher_(key_hasher), + node_(node), + index_(index), + must_hash_(must_hash), + may_rehash_(may_rehash), + tolerance_(tolerance), + memo_(DEBUG_ADD(/*no_future=*/index == 0 || !tolerance.positive, node, index)) { + for (size_t k = 0; k < key_col_index_.size(); k++) { + key_type_id_[k] = schema_->fields()[key_col_index_[k]]->type()->id(); + } +} + +Result> InputState::Make( + size_t index, TolType tolerance, bool must_hash, bool may_rehash, + KeyHasher* key_hasher, ExecNode* asof_input, AsofJoinNode* asof_node, + std::atomic& backpressure_counter, + const std::shared_ptr& schema, const col_index_t time_col_index, + const std::vector& key_col_index) { + constexpr size_t low_threshold = 4, high_threshold = 8; + std::unique_ptr backpressure_control = + std::make_unique( + /*node=*/asof_input, /*output=*/asof_node, backpressure_counter); + ARROW_ASSIGN_OR_RAISE( + auto handler, BackpressureHandler::Make(asof_input, low_threshold, high_threshold, + std::move(backpressure_control))); + return std::make_unique(index, tolerance, must_hash, may_rehash, key_hasher, + asof_node, std::move(handler), schema, + time_col_index, key_col_index); +} + +col_index_t InputState::InitSrcToDstMapping(col_index_t dst_offset, + bool skip_time_and_key_fields) { + src_to_dst_.resize(schema_->num_fields()); + for (int i = 0; i < schema_->num_fields(); ++i) + if (!(skip_time_and_key_fields && IsTimeOrKeyColumn(i))) + src_to_dst_[i] = dst_offset++; + return dst_offset; +} + +const std::optional& InputState::MapSrcToDst(col_index_t src) const { + return src_to_dst_[src]; +} + +bool InputState::IsTimeOrKeyColumn(col_index_t i) const { + DCHECK_LT(i, schema_->num_fields()); + return (i == time_col_index_) || std_has(key_col_index_, i); +} + +row_index_t InputState::GetLatestRow() const { return latest_ref_row_; } + +bool InputState::Empty() const { + // cannot be empty if ref row is >0 -- can avoid slow queue lock + // below + if (latest_ref_row_ > 0) return false; + return queue_.Empty(); +} + +bool InputState::CurrentEmpty(bool empty) const { + return memo_.no_future_ ? empty : (memo_.times_.empty() && empty); +} + +OnType InputState::GetCurrentTime() const { + return memo_.no_future_ ? GetLatestTime() : static_cast(memo_.current_time_); +} + +inline ByType InputState::GetLatestKey() const { + return GetKey(GetLatestBatch().get(), latest_ref_row_); +} + +#define LATEST_VAL_CASE(id, val) \ + case Type::id: { \ + using T = typename TypeIdTraits::Type; \ + using CType = typename TypeTraits::CType; \ + return val(data->GetValues(1)[row]); \ + } + +inline ByType InputState::GetKey(const RecordBatch* batch, row_index_t row) const { + if (must_hash_) { + // Query the key hasher. This may hit cache, which must be valid for the batch. 
+ // Therefore, the key hasher is invalidated when a new batch is pushed - see + // `InputState::Push`. + return key_hasher_->HashesFor(batch)[row]; + } + if (key_col_index_.size() == 0) { + return 0; + } + auto data = batch->column_data(key_col_index_[0]); + switch (key_type_id_[0]) { + LATEST_VAL_CASE(INT8, key_value) + LATEST_VAL_CASE(INT16, key_value) + LATEST_VAL_CASE(INT32, key_value) + LATEST_VAL_CASE(INT64, key_value) + LATEST_VAL_CASE(UINT8, key_value) + LATEST_VAL_CASE(UINT16, key_value) + LATEST_VAL_CASE(UINT32, key_value) + LATEST_VAL_CASE(UINT64, key_value) + LATEST_VAL_CASE(DATE32, key_value) + LATEST_VAL_CASE(DATE64, key_value) + LATEST_VAL_CASE(TIME32, key_value) + LATEST_VAL_CASE(TIME64, key_value) + LATEST_VAL_CASE(TIMESTAMP, key_value) + default: + DCHECK(false); + return 0; // cannot happen + } +} + +#undef LATEST_VAL_CASE + +inline OnType InputState::GetLatestTime() const { + return GetTime(GetLatestBatch().get(), time_type_id_, time_col_index_, latest_ref_row_); +} + +int InputState::total_batches() const { return total_batches_; } + +Result InputState::AdvanceAndMemoize(OnType ts, bool empty) { + // Advance the right side row index until we reach the latest right row (for each key) + // for the given left timestamp. + DEBUG_SYNC(node_, "Advancing input ", index_, DEBUG_MANIP(std::endl)); + + // Check if already updated for TS (or if there is no latest) + if (empty) { // can't advance if empty and no future entries + return memo_.no_future_ ? false : memo_.RemoveEntriesWithLesserTime(ts); + } + + // Not updated. Try to update and possibly advance. + bool advanced, updated = false; + OnType latest_time; + do { + latest_time = GetLatestTime(); + // if Advance() returns true, then the latest_ts must also be valid + // Keep advancing right table until we hit the latest row that has + // timestamp <= ts. This is because we only need the latest row for the + // match given a left ts. 
+ if (latest_time > tolerance_.Horizon(ts)) { // hit a distant timestamp + DEBUG_SYNC(node_, "Advancing input ", index_, " hit distant time=", latest_time, + " at=", ts, DEBUG_MANIP(std::endl)); + // if no future entries, which would have been earlier than the distant time, no + // need to queue it + if (memo_.future_entries_.empty()) break; + } + auto rb = GetLatestBatch(); + if (may_rehash_ && rb->column_data(key_col_index_[0])->GetNullCount() > 0) { + must_hash_ = true; + may_rehash_ = false; + Rehash(); + } + memo_.Store(rb, latest_ref_row_, latest_time, DEBUG_ADD(GetLatestKey(), ts)); + // negative tolerance means a last-known entry was stored - set `updated` to `true` + updated = memo_.no_future_; + ARROW_ASSIGN_OR_RAISE(advanced, Advance()); + } while (advanced); + if (!memo_.no_future_ && latest_time >= ts) { + // `updated` was not modified in the loop from the initial `false` value; set it now + updated = memo_.RemoveEntriesWithLesserTime(ts); + } + DEBUG_SYNC(node_, "Advancing input ", index_, " updated=", updated, + DEBUG_MANIP(std::endl)); + return updated; +} + +Status InputState::InsertBatch(ExecBatch batch) { + return sequencer_->InsertBatch(std::move(batch)); +} + +Status InputState::Process(ExecBatch batch) { + auto rb = *batch.ToRecordBatch(schema_); + DEBUG_SYNC(node_, "received batch from input ", index_, ":", DEBUG_MANIP(std::endl), + rb->ToString(), DEBUG_MANIP(std::endl)); + return Push(rb); +} + +void InputState::Rehash() { + DEBUG_SYNC(node_, "rehashing for input ", index_, ":", DEBUG_MANIP(std::endl)); + MemoStore new_memo(DEBUG_ADD(memo_.no_future_, node_, index_)); + new_memo.current_time_ = (OnType)memo_.current_time_; + for (auto e = memo_.entries_.begin(); e != memo_.entries_.end(); ++e) { + auto& entry = e->second; + auto new_key = GetKey(entry.batch.get(), entry.row); + DEBUG_SYNC(node_, " ", e->first, " to ", new_key, DEBUG_MANIP(std::endl)); + new_memo.entries_[new_key].swap(entry); + auto fe = memo_.future_entries_.find(e->first); + if (fe != memo_.future_entries_.end()) { + new_memo.future_entries_[new_key].swap(fe->second); + } + } + memo_.times_.swap(new_memo.times_); + memo_.swap(new_memo); +} + +Status InputState::Push(const std::shared_ptr& rb) { + if (rb->num_rows() > 0) { + key_hasher_->Invalidate(); // batch changed - invalidate key hasher's cache + queue_.Push(rb); // only now push batch for processing + } else { + ++batches_processed_; // don't enqueue empty batches, just record as processed + } + return Status::OK(); +} + +std::optional InputState::GetMemoEntryForKey(ByType key) { + return memo_.GetEntryForKey(key); +} + +std::optional InputState::GetMemoTimeForKey(ByType key) { + auto r = GetMemoEntryForKey(key); + if (r.has_value()) { + return (*r)->time; + } else { + return std::nullopt; + } +} + +void InputState::RemoveMemoEntriesWithLesserTime(OnType ts) { + memo_.RemoveEntriesWithLesserTime(ts); +} + +const std::shared_ptr& InputState::get_schema() const { return schema_; } + +void InputState::set_total_batches(int n) { + DCHECK_GE(n, 0); + DCHECK_EQ(total_batches_, -1) << "Set total batch more than once"; + total_batches_ = n; +} + +Status InputState::ForceShutdown() { + // Force the upstream input node to unpause. 
Necessary to avoid deadlock when we + // terminate the process thread + return queue_.ForceShutdown(); +} + +KeyHasher::KeyHasher(size_t index, const std::vector& indices) + : index_(index), + indices_(indices), + metadata_(indices.size()), + batch_(NULLPTR), + hashes_(), + ctx_(), + column_arrays_(), + stack_() { + ctx_.stack = &stack_; + column_arrays_.resize(indices.size()); +} + +Status KeyHasher::Init(ExecContext* exec_context, + const std::shared_ptr& schema) { + ctx_.hardware_flags = exec_context->cpu_info()->hardware_flags(); + const auto& fields = schema->fields(); + for (size_t k = 0; k < metadata_.size(); k++) { + ARROW_ASSIGN_OR_RAISE(metadata_[k], + ColumnMetadataFromDataType(fields[indices_[k]]->type())); + } + return stack_.Init(exec_context->memory_pool(), + 4 * kMiniBatchLength * sizeof(uint32_t)); +} + +void KeyHasher::Invalidate() { batch_ = NULLPTR; } + +// compute and cache a hash for each row of the given batch +const std::vector& KeyHasher::HashesFor(const RecordBatch* batch) { + if (batch_ == batch) { + return hashes_; // cache hit - return cached hashes + } + Invalidate(); + size_t batch_length = batch->num_rows(); + hashes_.resize(batch_length); + for (int64_t i = 0; i < static_cast(batch_length); i += kMiniBatchLength) { + int64_t length = std::min(static_cast(batch_length - i), + static_cast(kMiniBatchLength)); + for (size_t k = 0; k < indices_.size(); k++) { + auto array_data = batch->column_data(indices_[k]); + column_arrays_[k] = + ColumnArrayFromArrayDataAndMetadata(array_data, metadata_[k], i, length); + } + // write directly to the cache + Hashing64::HashMultiColumn(column_arrays_, &ctx_, hashes_.data() + i); + } + DEBUG_SYNC(node_, "key hasher ", index_, " got hashes ", + compute::internal::GenericToString(hashes_), DEBUG_MANIP(std::endl)); + batch_ = batch; // associate cache with current batch + return hashes_; +} + +Result AsofJoinNode::UpdateRhs() { + auto& lhs = *state_.at(0); + auto lhs_latest_time = lhs.GetLatestTime(); + RhsUpdateState update_state{/*any_advanced=*/false, /*all_up_to_date_with_lhs=*/true}; + for (size_t i = 1; i < state_.size(); ++i) { + auto& rhs = *state_[i]; + + // Obtain RHS emptiness once for subsequent AdvanceAndMemoize() and CurrentEmpty(). + bool rhs_empty = rhs.Empty(); + // Obtain RHS current time here because AdvanceAndMemoize() can change the + // emptiness. + OnType rhs_current_time = rhs_empty ? 
OnType{} : rhs.GetLatestTime(); + + ARROW_ASSIGN_OR_RAISE(bool advanced, + rhs.AdvanceAndMemoize(lhs_latest_time, rhs_empty)); + update_state.any_advanced |= advanced; + + if (update_state.all_up_to_date_with_lhs && !rhs.Finished()) { + // If RHS is finished, then we know it's up to date + if (rhs.CurrentEmpty(rhs_empty)) { + // RHS isn't finished, but is empty --> not up to date + update_state.all_up_to_date_with_lhs = false; + } else if (lhs_latest_time > rhs_current_time) { + // RHS isn't up to date (and not finished) + update_state.all_up_to_date_with_lhs = false; } + } + } + return update_state; +} - for (int i = 0; i < input_schema[j]->num_fields(); ++i) { - const auto field = input_schema[j]->field(i); - bool as_output; // true if the field appears as an output - if (i == on_field_ix) { - ARROW_RETURN_NOT_OK(is_valid_on_field(field)); - // Only add on field from the left table - as_output = (j == 0); - } else if (std_has(by_field_ix, i)) { - ARROW_RETURN_NOT_OK(is_valid_by_field(field)); - // Only add by field from the left table - as_output = (j == 0); - } else { - ARROW_RETURN_NOT_OK(is_valid_data_field(field)); - as_output = true; +#ifdef ARROW_ENABLE_THREADING +void AsofJoinNode::EndFromProcessThread(Status st) { + // We must spawn a new task to transfer off the process thread when + // marking this finished. Otherwise there is a chance that doing so could + // mark the plan finished which may destroy the plan which will destroy this + // node which will cause us to join on ourselves. + ARROW_UNUSED( + plan_->query_context()->executor()->Spawn([this, st = std::move(st)]() mutable { + Defer cleanup([this, &st]() { process_task_.MarkFinished(st); }); + if (st.ok()) { + st = output_->InputFinished(this, batches_produced_); } - if (as_output) { - fields.push_back(field); + for (const auto& s : state_) { + st &= s->ForceShutdown(); } + })); +} + +bool AsofJoinNode::CheckEnded() { + if (state_.at(0)->Finished()) { + EndFromProcessThread(); + return false; + } + return true; +} + +bool AsofJoinNode::Process() { + std::lock_guard guard(gate_); + if (!CheckEnded()) { + return false; + } + + // Process batches while we have data + for (;;) { + Result> result = ProcessInner(); + + if (result.ok()) { + auto out_rb = *result; + if (!out_rb) break; + ExecBatch out_b(*out_rb); + out_b.index = batches_produced_++; + DEBUG_SYNC(this, "produce batch ", out_b.index, ":", DEBUG_MANIP(std::endl), + out_rb->ToString(), DEBUG_MANIP(std::endl)); + Status st = output_->InputReceived(this, std::move(out_b)); + if (!st.ok()) { + EndFromProcessThread(std::move(st)); } + } else { + EndFromProcessThread(result.status()); + return false; } - return std::make_shared(fields); } - static inline Result FindColIndex(const Schema& schema, - const FieldRef& field_ref, - std::string_view key_kind) { - auto match_res = field_ref.FindOne(schema); - if (!match_res.ok()) { - return Status::Invalid("Bad join key on table : ", match_res.status().message()); + // Report to the output the total batch count, if we've already finished everything + // (there are two places where this can happen: here and InputFinished) + // + // It may happen here in cases where InputFinished was called before we were finished + // producing results (so we didn't know the output size at that time) + if (!CheckEnded()) { + return false; + } + + // There is no more we can do now but there is still work remaining for later when + // more data arrives. 
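+ // Returning true keeps ProcessThread() looping: it blocks in process_.WaitAndPop()
+ // until the next PushProcess() (from InputReceived or InputFinished) signals that
+ // more work may be possible.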
+ return true; +} + +void AsofJoinNode::ProcessThread() { + for (;;) { + if (!process_.WaitAndPop()) { + EndFromProcessThread(); + return; } - ARROW_ASSIGN_OR_RAISE(auto match, match_res); - if (match.indices().size() != 1) { - return Status::Invalid("AsOfJoinNode does not support a nested ", key_kind, "-key ", - field_ref.ToString()); + if (!Process()) { + return; } - return match.indices()[0]; } +} - static Result GetByKeySize( - const std::vector& input_keys) { - size_t n_by = 0; - for (size_t i = 0; i < input_keys.size(); ++i) { - const auto& by_key = input_keys[i].by_key; - if (i == 0) { - n_by = by_key.size(); - } else if (n_by != by_key.size()) { - return Status::Invalid("inconsistent size of by-key across inputs"); - } - } - return n_by; +void AsofJoinNode::ProcessThreadWrapper(AsofJoinNode* node) { node->ProcessThread(); } +#endif + +Status AsofJoinNode::Init() { + auto inputs = this->inputs(); + for (size_t i = 0; i < inputs.size(); i++) { + RETURN_NOT_OK(key_hashers_[i]->Init(plan()->query_context()->exec_context(), + inputs[i]->output_schema())); + ARROW_ASSIGN_OR_RAISE( + auto input_state, + InputState::Make(i, tolerance_, must_hash_, may_rehash_, key_hashers_[i].get(), + inputs[i], this, backpressure_counter_, + inputs[i]->output_schema(), indices_of_on_key_[i], + indices_of_by_key_[i])); + state_.push_back(std::move(input_state)); } - static Result> GetIndicesOfOnKey( - const std::vector>& input_schema, - const std::vector& input_keys) { - if (input_schema.size() != input_keys.size()) { - return Status::Invalid("mismatching number of input schema and keys"); + col_index_t dst_offset = 0; + for (auto& state : state_) + dst_offset = state->InitSrcToDstMapping(dst_offset, !!dst_offset); + + return Status::OK(); +} + +AsofJoinNode::~AsofJoinNode() { +#ifdef ARROW_ENABLE_THREADING + PushProcess(false); + if (process_thread_.joinable()) { + process_thread_.join(); + } +#endif +} + +const std::vector& AsofJoinNode::indices_of_on_key() { + return indices_of_on_key_; +} +const std::vector>& AsofJoinNode::indices_of_by_key() { + return indices_of_by_key_; +} + +Status AsofJoinNode::is_valid_on_field(const std::shared_ptr& field) { + switch (field->type()->id()) { + case Type::INT8: + case Type::INT16: + case Type::INT32: + case Type::INT64: + case Type::UINT8: + case Type::UINT16: + case Type::UINT32: + case Type::UINT64: + case Type::DATE32: + case Type::DATE64: + case Type::TIME32: + case Type::TIME64: + case Type::TIMESTAMP: + return Status::OK(); + default: + return Status::Invalid("Unsupported type for on-key ", field->name(), " : ", + field->type()->ToString()); + } +} + +Status AsofJoinNode::is_valid_by_field(const std::shared_ptr& field) { + switch (field->type()->id()) { + case Type::INT8: + case Type::INT16: + case Type::INT32: + case Type::INT64: + case Type::UINT8: + case Type::UINT16: + case Type::UINT32: + case Type::UINT64: + case Type::DATE32: + case Type::DATE64: + case Type::TIME32: + case Type::TIME64: + case Type::TIMESTAMP: + case Type::STRING: + case Type::LARGE_STRING: + case Type::BINARY: + case Type::LARGE_BINARY: + return Status::OK(); + default: + return Status::Invalid("Unsupported type for by-key ", field->name(), " : ", + field->type()->ToString()); + } +} + +Status AsofJoinNode::is_valid_data_field(const std::shared_ptr& field) { + switch (field->type()->id()) { + case Type::BOOL: + case Type::INT8: + case Type::INT16: + case Type::INT32: + case Type::INT64: + case Type::UINT8: + case Type::UINT16: + case Type::UINT32: + case Type::UINT64: + case 
Type::FLOAT: + case Type::DOUBLE: + case Type::DATE32: + case Type::DATE64: + case Type::TIME32: + case Type::TIME64: + case Type::TIMESTAMP: + case Type::STRING: + case Type::LARGE_STRING: + case Type::BINARY: + case Type::LARGE_BINARY: + return Status::OK(); + default: + return Status::Invalid("Unsupported type for data field ", field->name(), " : ", + field->type()->ToString()); + } +} + +/// \brief Make the output schema of an as-of-join node +/// +/// \param[in] input_schema the schema of each input to the node +/// \param[in] indices_of_on_key the on-key index of each input to the node +/// \param[in] indices_of_by_key the by-key indices of each input to the node +arrow::Result> AsofJoinNode::MakeOutputSchema( + const std::vector> input_schema, + const std::vector& indices_of_on_key, + const std::vector>& indices_of_by_key) { + std::vector> fields; + + size_t n_by = indices_of_by_key.size() == 0 ? 0 : indices_of_by_key[0].size(); + const DataType* on_key_type = NULLPTR; + std::vector by_key_type(n_by, NULLPTR); + // Take all non-key, non-time RHS fields + for (size_t j = 0; j < input_schema.size(); ++j) { + const auto& on_field_ix = indices_of_on_key[j]; + const auto& by_field_ix = indices_of_by_key[j]; + + if ((on_field_ix == -1) || std_has(by_field_ix, -1)) { + return Status::Invalid("Missing join key on table ", j); } - size_t n_input = input_schema.size(); - std::vector indices_of_on_key(n_input); - for (size_t i = 0; i < n_input; ++i) { - const auto& on_key = input_keys[i].on_key; - ARROW_ASSIGN_OR_RAISE(indices_of_on_key[i], - FindColIndex(*input_schema[i], on_key, "on")); + + const auto& on_field = input_schema[j]->fields()[on_field_ix]; + std::vector by_field(n_by); + for (size_t k = 0; k < n_by; k++) { + by_field[k] = input_schema[j]->fields()[by_field_ix[k]].get(); } - return indices_of_on_key; - } - static Result>> GetIndicesOfByKey( - const std::vector>& input_schema, - const std::vector& input_keys) { - if (input_schema.size() != input_keys.size()) { - return Status::Invalid("mismatching number of input schema and keys"); + if (on_key_type == NULLPTR) { + on_key_type = on_field->type().get(); + } else if (*on_key_type != *on_field->type()) { + return Status::Invalid("Expected on-key type ", *on_key_type, " but got ", + *on_field->type(), " for field ", on_field->name(), + " in input ", j); } - ARROW_ASSIGN_OR_RAISE(size_t n_by, GetByKeySize(input_keys)); - size_t n_input = input_schema.size(); - std::vector> indices_of_by_key( - n_input, std::vector(n_by)); - for (size_t i = 0; i < n_input; ++i) { - for (size_t k = 0; k < n_by; k++) { - const auto& by_key = input_keys[i].by_key; - ARROW_ASSIGN_OR_RAISE(indices_of_by_key[i][k], - FindColIndex(*input_schema[i], by_key[k], "by")); + for (size_t k = 0; k < n_by; k++) { + if (by_key_type[k] == NULLPTR) { + by_key_type[k] = by_field[k]->type().get(); + } else if (*by_key_type[k] != *by_field[k]->type()) { + return Status::Invalid("Expected by-key type ", *by_key_type[k], " but got ", + *by_field[k]->type(), " for field ", by_field[k]->name(), + " in input ", j); } } - return indices_of_by_key; - } - static arrow::Result Make(ExecPlan* plan, std::vector inputs, - const ExecNodeOptions& options) { - DCHECK_GE(inputs.size(), 2) << "Must have at least two inputs"; - const auto& join_options = checked_cast(options); - ARROW_ASSIGN_OR_RAISE(size_t n_by, GetByKeySize(join_options.input_keys)); - size_t n_input = inputs.size(); - std::vector input_labels(n_input); - std::vector> input_schema(n_input); - for (size_t i = 0; i < n_input; 
++i) { - input_labels[i] = i == 0 ? "left" : "right_" + ToChars(i); - input_schema[i] = inputs[i]->output_schema(); + for (int i = 0; i < input_schema[j]->num_fields(); ++i) { + const auto field = input_schema[j]->field(i); + bool as_output; // true if the field appears as an output + if (i == on_field_ix) { + ARROW_RETURN_NOT_OK(is_valid_on_field(field)); + // Only add on field from the left table + as_output = (j == 0); + } else if (std_has(by_field_ix, i)) { + ARROW_RETURN_NOT_OK(is_valid_by_field(field)); + // Only add by field from the left table + as_output = (j == 0); + } else { + ARROW_RETURN_NOT_OK(is_valid_data_field(field)); + as_output = true; + } + if (as_output) { + fields.push_back(field); + } } - ARROW_ASSIGN_OR_RAISE(std::vector indices_of_on_key, - GetIndicesOfOnKey(input_schema, join_options.input_keys)); - ARROW_ASSIGN_OR_RAISE(std::vector> indices_of_by_key, - GetIndicesOfByKey(input_schema, join_options.input_keys)); - ARROW_ASSIGN_OR_RAISE( - std::shared_ptr output_schema, - MakeOutputSchema(input_schema, indices_of_on_key, indices_of_by_key)); + } + return std::make_shared(fields); +} - std::vector> key_hashers; - for (size_t i = 0; i < n_input; i++) { - key_hashers.push_back(std::make_unique(i, indices_of_by_key[i])); +inline Result AsofJoinNode::FindColIndex(const Schema& schema, + const FieldRef& field_ref, + std::string_view key_kind) { + auto match_res = field_ref.FindOne(schema); + if (!match_res.ok()) { + return Status::Invalid("Bad join key on table : ", match_res.status().message()); + } + ARROW_ASSIGN_OR_RAISE(auto match, match_res); + if (match.indices().size() != 1) { + return Status::Invalid("AsOfJoinNode does not support a nested ", key_kind, "-key ", + field_ref.ToString()); + } + return match.indices()[0]; +} + +Result AsofJoinNode::GetByKeySize( + const std::vector& input_keys) { + size_t n_by = 0; + for (size_t i = 0; i < input_keys.size(); ++i) { + const auto& by_key = input_keys[i].by_key; + if (i == 0) { + n_by = by_key.size(); + } else if (n_by != by_key.size()) { + return Status::Invalid("inconsistent size of by-key across inputs"); } - bool must_hash = - n_by > 1 || - (n_by == 1 && - !is_primitive( - inputs[0]->output_schema()->field(indices_of_by_key[0][0])->type()->id())); - bool may_rehash = n_by == 1 && !must_hash; - return plan->EmplaceNode( - plan, inputs, std::move(input_labels), std::move(indices_of_on_key), - std::move(indices_of_by_key), std::move(join_options), std::move(output_schema), - std::move(key_hashers), must_hash, may_rehash); - } - - const char* kind_name() const override { return "AsofJoinNode"; } - const Ordering& ordering() const override { return ordering_; } - - Status InputReceived(ExecNode* input, ExecBatch batch) override { - // InputReceived may be called after execution was finished. Pushing it to the - // InputState is unnecessary since we're done (and anyway may cause the - // BackPressureController to pause the input, causing a deadlock), so drop it. - if (::arrow::compute::kUnsequencedIndex == batch.index) - return Status::Invalid("AsofJoin requires sequenced input"); - - if (process_task_.is_finished()) { - DEBUG_SYNC(this, "Input received while done. 
Short circuiting.", - DEBUG_MANIP(std::endl)); - return Status::OK(); + } + return n_by; +} + +Result> AsofJoinNode::GetIndicesOfOnKey( + const std::vector>& input_schema, + const std::vector& input_keys) { + if (input_schema.size() != input_keys.size()) { + return Status::Invalid("mismatching number of input schema and keys"); + } + size_t n_input = input_schema.size(); + std::vector indices_of_on_key(n_input); + for (size_t i = 0; i < n_input; ++i) { + const auto& on_key = input_keys[i].on_key; + ARROW_ASSIGN_OR_RAISE(indices_of_on_key[i], + FindColIndex(*input_schema[i], on_key, "on")); + } + return indices_of_on_key; +} + +Result>> AsofJoinNode::GetIndicesOfByKey( + const std::vector>& input_schema, + const std::vector& input_keys) { + if (input_schema.size() != input_keys.size()) { + return Status::Invalid("mismatching number of input schema and keys"); + } + ARROW_ASSIGN_OR_RAISE(size_t n_by, GetByKeySize(input_keys)); + size_t n_input = input_schema.size(); + std::vector> indices_of_by_key(n_input, + std::vector(n_by)); + for (size_t i = 0; i < n_input; ++i) { + for (size_t k = 0; k < n_by; k++) { + const auto& by_key = input_keys[i].by_key; + ARROW_ASSIGN_OR_RAISE(indices_of_by_key[i][k], + FindColIndex(*input_schema[i], by_key[k], "by")); } + } + return indices_of_by_key; +} - // Get the input - ARROW_DCHECK(std_has(inputs_, input)); - size_t k = std_find(inputs_, input) - inputs_.begin(); +arrow::Result AsofJoinNode::Make(ExecPlan* plan, std::vector inputs, + const ExecNodeOptions& options) { + DCHECK_GE(inputs.size(), 2) << "Must have at least two inputs"; + const auto& join_options = checked_cast(options); + ARROW_ASSIGN_OR_RAISE(size_t n_by, GetByKeySize(join_options.input_keys)); + size_t n_input = inputs.size(); + std::vector input_labels(n_input); + std::vector> input_schema(n_input); + for (size_t i = 0; i < n_input; ++i) { + input_labels[i] = i == 0 ? "left" : "right_" + ToChars(i); + input_schema[i] = inputs[i]->output_schema(); + } + ARROW_ASSIGN_OR_RAISE(std::vector indices_of_on_key, + GetIndicesOfOnKey(input_schema, join_options.input_keys)); + ARROW_ASSIGN_OR_RAISE(std::vector> indices_of_by_key, + GetIndicesOfByKey(input_schema, join_options.input_keys)); + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr output_schema, + MakeOutputSchema(input_schema, indices_of_on_key, indices_of_by_key)); + + std::vector> key_hashers; + for (size_t i = 0; i < n_input; i++) { + key_hashers.push_back(std::make_unique(i, indices_of_by_key[i])); + } + bool must_hash = + n_by > 1 || + (n_by == 1 && + !is_primitive( + inputs[0]->output_schema()->field(indices_of_by_key[0][0])->type()->id())); + bool may_rehash = n_by == 1 && !must_hash; + return plan->EmplaceNode( + plan, inputs, std::move(input_labels), std::move(indices_of_on_key), + std::move(indices_of_by_key), std::move(join_options), std::move(output_schema), + std::move(key_hashers), must_hash, may_rehash); +} - // Put into the sequencing queue - ARROW_RETURN_NOT_OK(state_.at(k)->InsertBatch(std::move(batch))); +const char* AsofJoinNode::kind_name() const { return "AsofJoinNode"; } +const Ordering& AsofJoinNode::ordering() const { return ordering_; } - PushProcess(true); +Status AsofJoinNode::InputReceived(ExecNode* input, ExecBatch batch) { + // InputReceived may be called after execution was finished. Pushing it to the + // InputState is unnecessary since we're done (and anyway may cause the + // BackPressureController to pause the input, causing a deadlock), so drop it. 
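+ // Output correctness depends on per-input batch order, so every batch must carry a
+ // sequence index; InsertBatch() below routes it through a SerialSequencingQueue,
+ // which replays batches to InputState::Process() in order.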
+ if (::arrow::compute::kUnsequencedIndex == batch.index) + return Status::Invalid("AsofJoin requires sequenced input"); + if (process_task_.is_finished()) { + DEBUG_SYNC(this, "Input received while done. Short circuiting.", + DEBUG_MANIP(std::endl)); return Status::OK(); } - Status InputFinished(ExecNode* input, int total_batches) override { - { - std::lock_guard guard(gate_); - ARROW_DCHECK(std_has(inputs_, input)); - size_t k = std_find(inputs_, input) - inputs_.begin(); - state_.at(k)->set_total_batches(total_batches); - } - // Trigger a process call - // The reason for this is that there are cases at the end of a table where we don't - // know whether the RHS of the join is up-to-date until we know that the table is - // finished. - PushProcess(true); + // Get the input + ARROW_DCHECK(std_has(inputs_, input)); + size_t k = std_find(inputs_, input) - inputs_.begin(); - return Status::OK(); + // Put into the sequencing queue + ARROW_RETURN_NOT_OK(state_.at(k)->InsertBatch(std::move(batch))); + + PushProcess(true); + + return Status::OK(); +} + +Status AsofJoinNode::InputFinished(ExecNode* input, int total_batches) { + { + std::lock_guard guard(gate_); + ARROW_DCHECK(std_has(inputs_, input)); + size_t k = std_find(inputs_, input) - inputs_.begin(); + state_.at(k)->set_total_batches(total_batches); } - void PushProcess(bool value) { + // Trigger a process call + // The reason for this is that there are cases at the end of a table where we don't + // know whether the RHS of the join is up-to-date until we know that the table is + // finished. + PushProcess(true); + + return Status::OK(); +} + +void AsofJoinNode::PushProcess(bool value) { #ifdef ARROW_ENABLE_THREADING - process_.Push(value); + process_.Push(value); #else - if (value) { - ProcessNonThreaded(); - } else if (!process_task_.is_finished()) { - EndFromSingleThread(); - } -#endif + if (value) { + ProcessNonThreaded(); + } else if (!process_task_.is_finished()) { + EndFromSingleThread(); } +#endif +} #ifndef ARROW_ENABLE_THREADING - bool ProcessNonThreaded() { - while (!process_task_.is_finished()) { - Result> result = ProcessInner(); - - if (result.ok()) { - auto out_rb = *result; - if (!out_rb) break; - ExecBatch out_b(*out_rb); - out_b.index = batches_produced_++; - DEBUG_SYNC(this, "produce batch ", out_b.index, ":", DEBUG_MANIP(std::endl), - out_rb->ToString(), DEBUG_MANIP(std::endl)); - Status st = output_->InputReceived(this, std::move(out_b)); - if (!st.ok()) { - // this isn't really from a thread, - // but we call through to this for consistency - EndFromSingleThread(std::move(st)); - return false; - } - } else { +bool AsofJoinNode::ProcessNonThreaded() { + while (!process_task_.is_finished()) { + Result> result = ProcessInner(); + + if (result.ok()) { + auto out_rb = *result; + if (!out_rb) break; + ExecBatch out_b(*out_rb); + out_b.index = batches_produced_++; + DEBUG_SYNC(this, "produce batch ", out_b.index, ":", DEBUG_MANIP(std::endl), + out_rb->ToString(), DEBUG_MANIP(std::endl)); + Status st = output_->InputReceived(this, std::move(out_b)); + if (!st.ok()) { // this isn't really from a thread, // but we call through to this for consistency - EndFromSingleThread(result.status()); + EndFromSingleThread(std::move(st)); return false; } + } else { + // this isn't really from a thread, + // but we call through to this for consistency + EndFromSingleThread(result.status()); + return false; } - auto& lhs = *state_.at(0); - if (lhs.Finished() && !process_task_.is_finished()) { - EndFromSingleThread(Status::OK()); - } - return 
true; } - - void EndFromSingleThread(Status st = Status::OK()) { - process_task_.MarkFinished(st); - if (st.ok()) { - st = output_->InputFinished(this, batches_produced_); - } - for (const auto& s : state_) { - st &= s->ForceShutdown(); - } + auto& lhs = *state_.at(0); + if (lhs.Finished() && !process_task_.is_finished()) { + EndFromSingleThread(Status::OK()); } + return true; +} +void AsofJoinNode::EndFromSingleThread(Status st = Status::OK()) { + process_task_.MarkFinished(st); + if (st.ok()) { + st = output_->InputFinished(this, batches_produced_); + } + for (const auto& s : state_) { + st &= s->ForceShutdown(); + } +} #endif - Status StartProducing() override { - ARROW_ASSIGN_OR_RAISE(process_task_, plan_->query_context()->BeginExternalTask( - "AsofJoinNode::ProcessThread")); - if (!process_task_.is_valid()) { - // Plan has already aborted. Do not start process thread - return Status::OK(); - } -#ifdef ARROW_ENABLE_THREADING - process_thread_ = std::thread(&AsofJoinNode::ProcessThreadWrapper, this); -#endif +Status AsofJoinNode::StartProducing() { + ARROW_ASSIGN_OR_RAISE(process_task_, plan_->query_context()->BeginExternalTask( + "AsofJoinNode::ProcessThread")); + if (!process_task_.is_valid()) { + // Plan has already aborted. Do not start process thread return Status::OK(); } +#ifdef ARROW_ENABLE_THREADING + process_thread_ = std::thread(&AsofJoinNode::ProcessThreadWrapper, this); +#endif + return Status::OK(); +} - void PauseProducing(ExecNode* output, int32_t counter) override {} - void ResumeProducing(ExecNode* output, int32_t counter) override {} +void AsofJoinNode::PauseProducing(ExecNode* output, int32_t counter) {} +void AsofJoinNode::ResumeProducing(ExecNode* output, int32_t counter) {} - Status StopProducingImpl() override { +Status AsofJoinNode::StopProducingImpl() { #ifdef ARROW_ENABLE_THREADING - process_.Clear(); + process_.Clear(); #endif - PushProcess(false); - return Status::OK(); - } + PushProcess(false); + return Status::OK(); +} #ifndef NDEBUG - std::ostream* GetDebugStream() { return debug_os_; } +std::ostream* AsofJoinNode::GetDebugStream() { return debug_os_; } - std::mutex* GetDebugMutex() { return debug_mutex_; } +std::mutex* AsofJoinNode::GetDebugMutex() { return debug_mutex_; } #endif - private: - // Outputs from this node are always in ascending order according to the on key - const Ordering ordering_; - std::vector indices_of_on_key_; - std::vector> indices_of_by_key_; - std::vector> key_hashers_; - bool must_hash_; - bool may_rehash_; - // InputStates - // Each input state corresponds to an input table - std::vector> state_; - std::mutex gate_; - TolType tolerance_; +/// Wrapper around UnmaterializedCompositeTable that knows how to emplace +/// the join row-by-row +template +class CompositeTableBuilder { + using SliceBuilder = UnmaterializedSliceBuilder; + using CompositeTable = UnmaterializedCompositeTable; + + public: + NDEBUG_EXPLICIT CompositeTableBuilder( + const std::vector>& inputs, + const std::shared_ptr& schema, arrow::MemoryPool* pool, + DEBUG_ADD(size_t n_tables, AsofJoinNode* node)) + : unmaterialized_table(InitUnmaterializedTable(schema, inputs, pool)), + DEBUG_ADD(n_tables_(n_tables), node_(node)) { + DCHECK_GE(n_tables_, 1); + DCHECK_LE(n_tables_, MAX_TABLES); + } + + size_t n_rows() const { return unmaterialized_table.Size(); } + + // Adds the latest row from the input state as a new composite reference row + // - LHS must have a valid key,timestep,and latest rows + // - RHS must have valid data memo'ed for the key + void 
Emplace(std::vector>& in, TolType tolerance) { + DCHECK_EQ(in.size(), n_tables_); + + // Get the LHS key + ByType key = in[0]->GetLatestKey(); + + // Add row and setup LHS + // (the LHS state comes just from the latest row of the LHS table) + DCHECK(!in[0]->Empty()); + const std::shared_ptr& lhs_latest_batch = in[0]->GetLatestBatch(); + row_index_t lhs_latest_row = in[0]->GetLatestRow(); + OnType lhs_latest_time = in[0]->GetLatestTime(); + if (0 == lhs_latest_row) { + // On the first row of the batch, we resize the destination. + // The destination size is dictated by the size of the LHS batch. + row_index_t new_batch_size = lhs_latest_batch->num_rows(); + row_index_t new_capacity = unmaterialized_table.Size() + new_batch_size; + if (unmaterialized_table.capacity() < new_capacity) { + unmaterialized_table.reserve(new_capacity); + } + } + + SliceBuilder new_row{&unmaterialized_table}; + + // Each item represents a portion of the columns of the output table + new_row.AddEntry(lhs_latest_batch, lhs_latest_row, lhs_latest_row + 1); + + DEBUG_SYNC(node_, "Emplace: key=", key, " lhs_latest_row=", lhs_latest_row, + " lhs_latest_time=", lhs_latest_time, DEBUG_MANIP(std::endl)); + + // Get the state for that key from all on the RHS -- assumes it's up to date + // (the RHS state comes from the memoized row references) + for (size_t i = 1; i < in.size(); ++i) { + std::optional opt_entry = in[i]->GetMemoEntryForKey(key); #ifndef NDEBUG - std::ostream* debug_os_; - std::mutex* debug_mutex_; + { + bool has_entry = opt_entry.has_value(); + OnType entry_time = has_entry ? (*opt_entry)->time : TolType::kMinValue; + row_index_t entry_row = has_entry ? (*opt_entry)->row : 0; + bool accepted = has_entry && tolerance.Accepts(lhs_latest_time, entry_time); + DEBUG_SYNC(node_, " i=", i, " has_entry=", has_entry, " time=", entry_time, + " row=", entry_row, " accepted=", accepted, DEBUG_MANIP(std::endl)); + } #endif + if (opt_entry.has_value()) { + DCHECK(*opt_entry); + if (tolerance.Accepts(lhs_latest_time, (*opt_entry)->time)) { + // Have a valid entry + const MemoStore::Entry* entry = *opt_entry; + new_row.AddEntry(entry->batch, entry->row, entry->row + 1); + continue; + } + } + new_row.AddEntry(nullptr, 0, 1); + } + new_row.Finalize(); + } - // Backpressure counter common to all inputs - std::atomic backpressure_counter_; -#ifdef ARROW_ENABLE_THREADING - // Queue for triggering processing of a given input - // (a false value is a poison pill) - ConcurrentQueue process_; - // Worker thread - std::thread process_thread_; + // Materializes the current reference table into a target record batch + Result>> Materialize() { + return unmaterialized_table.Materialize(); + } + + // Returns true if there are no rows + bool empty() const { return unmaterialized_table.Empty(); } + + private: + CompositeTable unmaterialized_table; + + // Total number of tables in the composite table + size_t n_tables_; + +#ifndef NDEBUG + // Owning node + AsofJoinNode* node_; #endif - Future<> process_task_; - // In-progress batches produced - int batches_produced_ = 0; + static CompositeTable InitUnmaterializedTable( + const std::shared_ptr& schema, + const std::vector>& inputs, arrow::MemoryPool* pool) { + std::unordered_map> dst_to_src; + for (size_t i = 0; i < inputs.size(); i++) { + auto& input = inputs[i]; + for (int src = 0; src < input->get_schema()->num_fields(); src++) { + auto dst = input->MapSrcToDst(src); + if (dst.has_value()) { + dst_to_src[dst.value()] = std::make_pair(static_cast(i), src); + } + } + } + return 
CompositeTable{schema, inputs.size(), dst_to_src, pool}; + } }; +Result> AsofJoinNode::ProcessInner() { + DCHECK(!state_.empty()); + auto& lhs = *state_.at(0); + + // Construct new target table if needed + CompositeTableBuilder dst(state_, output_schema_, + plan()->query_context()->memory_pool(), + DEBUG_ADD(state_.size(), this)); + + // Generate rows into the dst table until we either run out of data or hit the row + // limit, or run out of input + for (;;) { + // If LHS is finished or empty then there's nothing we can do here + if (lhs.Finished() || lhs.Empty()) break; + + ARROW_ASSIGN_OR_RAISE(auto rhs_update_state, UpdateRhs()); + + // If we have received enough inputs to produce the next output batch + // (decided by IsUpToDateWithLhsRow), we will perform the join and + // materialize the output batch. The join is done by advancing through + // the LHS and adding joined row to rows_ (done by Emplace). Finally, + // input batches that are no longer needed are removed to free up memory. + if (rhs_update_state.all_up_to_date_with_lhs) { + dst.Emplace(state_, tolerance_); + ARROW_ASSIGN_OR_RAISE(bool advanced, lhs.Advance()); + if (!advanced) break; // if we can't advance LHS, we're done for this batch + } else { + if (!rhs_update_state.any_advanced) break; // need to wait for new data + } + } + + // Prune memo entries that have expired (to bound memory consumption) + if (!lhs.Empty()) { + for (size_t i = 1; i < state_.size(); ++i) { + OnType ts = tolerance_.Expiry(lhs.GetLatestTime()); + if (ts != TolType::kMinValue) { + state_[i]->RemoveMemoEntriesWithLesserTime(ts); + } + } + } + + // Emit the batch + if (dst.empty()) { + return NULLPTR; + } else { + ARROW_ASSIGN_OR_RAISE(auto out, dst.Materialize()); + return out.has_value() ? out.value() : NULLPTR; + } +} + AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs, std::vector input_labels, const std::vector& indices_of_on_key, @@ -1635,4 +1792,4 @@ std::mutex* GetDebugMutex(AsofJoinNode* node) { return node->GetDebugMutex(); } #undef DEBUG_ADD } // namespace acero -} // namespace arrow +} // namespace arrow \ No newline at end of file From b20e2e98c7c19c9058b6c97b9796d7f7ed53296d Mon Sep 17 00:00:00 2001 From: Benjamin Leff Date: Fri, 22 Aug 2025 13:15:43 +0000 Subject: [PATCH 09/14] chore: serialize.cc --- python/pyarrow/src/arrow/python/serialize.cc | 798 +++++++++++++++++++ 1 file changed, 798 insertions(+) create mode 100644 python/pyarrow/src/arrow/python/serialize.cc diff --git a/python/pyarrow/src/arrow/python/serialize.cc b/python/pyarrow/src/arrow/python/serialize.cc new file mode 100644 index 00000000000..2295421165d --- /dev/null +++ b/python/pyarrow/src/arrow/python/serialize.cc @@ -0,0 +1,798 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
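+//
+// (Expository note, not part of the upstream header.) This translation unit encodes an
+// arbitrary Python object graph as a single Arrow RecordBatch whose one column is a
+// dense union over the Python types encountered; tensors, sparse tensors, ndarrays and
+// buffers are kept out-of-band in SerializedPyObject and referenced by index.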
+ +#include "arrow/python/serialize.h" +#include "arrow/python/numpy_interop.h" + +#include +#include +#include +#include +#include +#include + +#include +#include + +#include "arrow/array.h" +#include "arrow/array/builder_binary.h" +#include "arrow/array/builder_nested.h" +#include "arrow/array/builder_primitive.h" +#include "arrow/array/builder_union.h" +#include "arrow/io/interfaces.h" +#include "arrow/io/memory.h" +#include "arrow/ipc/util.h" +#include "arrow/ipc/writer.h" +#include "arrow/record_batch.h" +#include "arrow/result.h" +#include "arrow/tensor.h" +#include "arrow/util/logging.h" + +#include "arrow/python/common.h" +#include "arrow/python/datetime.h" +#include "arrow/python/helpers.h" +#include "arrow/python/iterators.h" +#include "arrow/python/numpy_convert.h" +#include "arrow/python/platform.h" +#include "arrow/python/pyarrow.h" + +constexpr int32_t kMaxRecursionDepth = 100; + +namespace arrow { + +using internal::checked_cast; + +namespace py { + +class SequenceBuilder; +class DictBuilder; + +Status Append(PyObject* context, PyObject* elem, SequenceBuilder* builder, + int32_t recursion_depth, SerializedPyObject* blobs_out); + +// Constructing dictionaries of key/value pairs. Sequences of +// keys and values are built separately using a pair of +// SequenceBuilders. The resulting Arrow representation +// can be obtained via the Finish method. +class DictBuilder { +public: + explicit DictBuilder(MemoryPool* pool = nullptr) : keys_(pool), vals_(pool) { + builder_.reset(new StructBuilder(struct_({field("keys", dense_union(FieldVector{})), + field("vals", dense_union(FieldVector{}))}), + pool, {keys_.builder(), vals_.builder()})); + } + + // Builder for the keys of the dictionary + SequenceBuilder& keys() { return keys_; } + // Builder for the values of the dictionary + SequenceBuilder& vals() { return vals_; } + + // Construct an Arrow StructArray representing the dictionary. + // Contains a field "keys" for the keys and "vals" for the values. + Status Finish(std::shared_ptr* out) { return builder_->Finish(out); } + + std::shared_ptr builder() { return builder_; } + +private: + SequenceBuilder keys_; + SequenceBuilder vals_; + std::shared_ptr builder_; +}; + +// A Sequence is a heterogeneous collections of elements. It can contain +// scalar Python types, lists, tuples, dictionaries, tensors and sparse tensors. 
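+// Each Python type that appears gets its own lazily-created child builder inside a
+// DenseUnionBuilder (see CreateAndUpdate and type_map_). Rough usage sketch
+// (expository only, not part of this file's API surface):
+//   SequenceBuilder seq;
+//   RETURN_NOT_OK(seq.AppendInt64(42));
+//   RETURN_NOT_OK(seq.AppendString("hi", 2));
+//   std::shared_ptr<Array> out;
+//   RETURN_NOT_OK(seq.Finish(&out));  // dense union array with int64 and string children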
+class SequenceBuilder { + public: + explicit SequenceBuilder(MemoryPool* pool = default_memory_pool()) + : pool_(pool), + types_(::arrow::int8(), pool), + offsets_(::arrow::int32(), pool), + type_map_(PythonType::NUM_PYTHON_TYPES, -1) { + auto null_builder = std::make_shared(pool); + auto initial_ty = dense_union({field("0", null())}); + builder_.reset(new DenseUnionBuilder(pool, {null_builder}, initial_ty)); + } + + // Appending a none to the sequence + Status AppendNone() { return builder_->AppendNull(); } + + template + Status CreateAndUpdate(std::shared_ptr* child_builder, int8_t tag, + MakeBuilderFn make_builder) { + if (!*child_builder) { + child_builder->reset(make_builder()); + std::ostringstream convert; + convert.imbue(std::locale::classic()); + convert << static_cast(tag); + type_map_[tag] = builder_->AppendChild(*child_builder, convert.str()); + } + return builder_->Append(type_map_[tag]); + } + + template + Status AppendPrimitive(std::shared_ptr* child_builder, const T val, + int8_t tag) { + RETURN_NOT_OK( + CreateAndUpdate(child_builder, tag, [this]() { return new BuilderType(pool_); })); + return (*child_builder)->Append(val); + } + + // Appending a boolean to the sequence + Status AppendBool(const bool data) { + return AppendPrimitive(&bools_, data, PythonType::BOOL); + } + + // Appending an int64_t to the sequence + Status AppendInt64(const int64_t data) { + return AppendPrimitive(&ints_, data, PythonType::INT); + } + + // Append a list of bytes to the sequence + Status AppendBytes(const uint8_t* data, int32_t length) { + RETURN_NOT_OK(CreateAndUpdate(&bytes_, PythonType::BYTES, + [this]() { return new BinaryBuilder(pool_); })); + return bytes_->Append(data, length); + } + + // Appending a string to the sequence + Status AppendString(const char* data, int32_t length) { + RETURN_NOT_OK(CreateAndUpdate(&strings_, PythonType::STRING, + [this]() { return new StringBuilder(pool_); })); + return strings_->Append(data, length); + } + + // Appending a half_float to the sequence + Status AppendHalfFloat(const npy_half data) { + return AppendPrimitive(&half_floats_, data, PythonType::HALF_FLOAT); + } + + // Appending a float to the sequence + Status AppendFloat(const float data) { + return AppendPrimitive(&floats_, data, PythonType::FLOAT); + } + + // Appending a double to the sequence + Status AppendDouble(const double data) { + return AppendPrimitive(&doubles_, data, PythonType::DOUBLE); + } + + // Appending a Date64 timestamp to the sequence + Status AppendDate64(const int64_t timestamp) { + return AppendPrimitive(&date64s_, timestamp, PythonType::DATE64); + } + + // Appending a tensor to the sequence + // + // \param tensor_index Index of the tensor in the object. + Status AppendTensor(const int32_t tensor_index) { + RETURN_NOT_OK(CreateAndUpdate(&tensor_indices_, PythonType::TENSOR, + [this]() { return new Int32Builder(pool_); })); + return tensor_indices_->Append(tensor_index); + } + + // Appending a sparse coo tensor to the sequence + // + // \param sparse_coo_tensor_index Index of the sparse coo tensor in the object. + Status AppendSparseCOOTensor(const int32_t sparse_coo_tensor_index) { + RETURN_NOT_OK(CreateAndUpdate(&sparse_coo_tensor_indices_, + PythonType::SPARSECOOTENSOR, + [this]() { return new Int32Builder(pool_); })); + return sparse_coo_tensor_indices_->Append(sparse_coo_tensor_index); + } + + // Appending a sparse csr matrix to the sequence + // + // \param sparse_csr_matrix_index Index of the sparse csr matrix in the object. 
+ Status AppendSparseCSRMatrix(const int32_t sparse_csr_matrix_index) { + RETURN_NOT_OK(CreateAndUpdate(&sparse_csr_matrix_indices_, + PythonType::SPARSECSRMATRIX, + [this]() { return new Int32Builder(pool_); })); + return sparse_csr_matrix_indices_->Append(sparse_csr_matrix_index); + } + + // Appending a sparse csc matrix to the sequence + // + // \param sparse_csc_matrix_index Index of the sparse csc matrix in the object. + Status AppendSparseCSCMatrix(const int32_t sparse_csc_matrix_index) { + RETURN_NOT_OK(CreateAndUpdate(&sparse_csc_matrix_indices_, + PythonType::SPARSECSCMATRIX, + [this]() { return new Int32Builder(pool_); })); + return sparse_csc_matrix_indices_->Append(sparse_csc_matrix_index); + } + + // Appending a sparse csf tensor to the sequence + // + // \param sparse_csf_tensor_index Index of the sparse csf tensor in the object. + Status AppendSparseCSFTensor(const int32_t sparse_csf_tensor_index) { + RETURN_NOT_OK(CreateAndUpdate(&sparse_csf_tensor_indices_, + PythonType::SPARSECSFTENSOR, + [this]() { return new Int32Builder(pool_); })); + return sparse_csf_tensor_indices_->Append(sparse_csf_tensor_index); + } + + // Appending a numpy ndarray to the sequence + // + // \param tensor_index Index of the tensor in the object. + Status AppendNdarray(const int32_t ndarray_index) { + RETURN_NOT_OK(CreateAndUpdate(&ndarray_indices_, PythonType::NDARRAY, + [this]() { return new Int32Builder(pool_); })); + return ndarray_indices_->Append(ndarray_index); + } + + // Appending a buffer to the sequence + // + // \param buffer_index Index of the buffer in the object. + Status AppendBuffer(const int32_t buffer_index) { + RETURN_NOT_OK(CreateAndUpdate(&buffer_indices_, PythonType::BUFFER, + [this]() { return new Int32Builder(pool_); })); + return buffer_indices_->Append(buffer_index); + } + + Status AppendSequence(PyObject* context, PyObject* sequence, int8_t tag, + std::shared_ptr& target_sequence, + std::unique_ptr& values, int32_t recursion_depth, + SerializedPyObject* blobs_out) { + if (recursion_depth >= kMaxRecursionDepth) { + return Status::NotImplemented( + "This object exceeds the maximum recursion depth. It may contain itself " + "recursively."); + } + RETURN_NOT_OK(CreateAndUpdate(&target_sequence, tag, [this, &values]() { + values.reset(new SequenceBuilder(pool_)); + return new ListBuilder(pool_, values->builder()); + })); + RETURN_NOT_OK(target_sequence->Append()); + return internal::VisitIterable( + sequence, [&](PyObject* obj, bool* keep_going /* unused */) { + return Append(context, obj, values.get(), recursion_depth, blobs_out); + }); + } + + Status AppendList(PyObject* context, PyObject* list, int32_t recursion_depth, + SerializedPyObject* blobs_out) { + return AppendSequence(context, list, PythonType::LIST, lists_, list_values_, + recursion_depth + 1, blobs_out); + } + + Status AppendTuple(PyObject* context, PyObject* tuple, int32_t recursion_depth, + SerializedPyObject* blobs_out) { + return AppendSequence(context, tuple, PythonType::TUPLE, tuples_, tuple_values_, + recursion_depth + 1, blobs_out); + } + + Status AppendSet(PyObject* context, PyObject* set, int32_t recursion_depth, + SerializedPyObject* blobs_out) { + return AppendSequence(context, set, PythonType::SET, sets_, set_values_, + recursion_depth + 1, blobs_out); + } + + Status AppendDict(PyObject* context, PyObject* dict, int32_t recursion_depth, + SerializedPyObject* blobs_out); + + // Finish building the sequence and return the result. 
+ // Input arrays may be nullptr + Status Finish(std::shared_ptr* out) { return builder_->Finish(out); } + + std::shared_ptr builder() { return builder_; } + + private: + MemoryPool* pool_; + + Int8Builder types_; + Int32Builder offsets_; + + /// Mapping from PythonType to child index + std::vector type_map_; + + std::shared_ptr bools_; + std::shared_ptr ints_; + std::shared_ptr bytes_; + std::shared_ptr strings_; + std::shared_ptr half_floats_; + std::shared_ptr floats_; + std::shared_ptr doubles_; + std::shared_ptr date64s_; + + std::unique_ptr list_values_; + std::shared_ptr lists_; + std::unique_ptr dict_values_; + std::shared_ptr dicts_; + std::unique_ptr tuple_values_; + std::shared_ptr tuples_; + std::unique_ptr set_values_; + std::shared_ptr sets_; + + std::shared_ptr tensor_indices_; + std::shared_ptr sparse_coo_tensor_indices_; + std::shared_ptr sparse_csr_matrix_indices_; + std::shared_ptr sparse_csc_matrix_indices_; + std::shared_ptr sparse_csf_tensor_indices_; + std::shared_ptr ndarray_indices_; + std::shared_ptr buffer_indices_; + + std::shared_ptr builder_; +}; + +Status SequenceBuilder::AppendDict(PyObject* context, PyObject* dict, + int32_t recursion_depth, + SerializedPyObject* blobs_out) { + if (recursion_depth >= kMaxRecursionDepth) { + return Status::NotImplemented( + "This object exceeds the maximum recursion depth. It may contain itself " + "recursively."); + } + RETURN_NOT_OK(CreateAndUpdate(&dicts_, PythonType::DICT, [this]() { + dict_values_.reset(new DictBuilder(pool_)); + return new ListBuilder(pool_, dict_values_->builder()); + })); + RETURN_NOT_OK(dicts_->Append()); + PyObject* key; + PyObject* value; + Py_ssize_t pos = 0; + while (PyDict_Next(dict, &pos, &key, &value)) { + RETURN_NOT_OK(dict_values_->builder()->Append()); + RETURN_NOT_OK( + Append(context, key, &dict_values_->keys(), recursion_depth + 1, blobs_out)); + RETURN_NOT_OK( + Append(context, value, &dict_values_->vals(), recursion_depth + 1, blobs_out)); + } + + // This block is used to decrement the reference counts of the results + // returned by the serialization callback, which is called in AppendArray, + // in DeserializeDict and in Append + static PyObject* py_type = PyUnicode_FromString("_pytype_"); + if (PyDict_Contains(dict, py_type)) { + // If the dictionary contains the key "_pytype_", then the user has to + // have registered a callback. 
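+ // A "_pytype_" marker means this dict is the temporary produced by the user's
+ // _serialize_callback (see CallSerializeCallback), so the reference returned by that
+ // call is released here; ordinary user dicts are left untouched.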
+ if (context == Py_None) { + return Status::Invalid("No serialization callback set"); + } + Py_XDECREF(dict); + } + return Status::OK(); +} + +Status CallCustomCallback(PyObject* context, PyObject* method_name, PyObject* elem, + PyObject** result) { + if (context == Py_None) { + *result = NULL; + return Status::SerializationError("error while calling callback on ", + internal::PyObject_StdStringRepr(elem), + ": handler not registered"); + } else { + *result = PyObject_CallMethodObjArgs(context, method_name, elem, NULL); + return CheckPyError(); + } +} + +Status CallSerializeCallback(PyObject* context, PyObject* value, + PyObject** serialized_object) { + OwnedRef method_name(PyUnicode_FromString("_serialize_callback")); + RETURN_NOT_OK(CallCustomCallback(context, method_name.obj(), value, serialized_object)); + if (!PyDict_Check(*serialized_object)) { + return Status::TypeError("serialization callback must return a valid dictionary"); + } + return Status::OK(); +} + +Status CallDeserializeCallback(PyObject* context, PyObject* value, + PyObject** deserialized_object) { + OwnedRef method_name(PyUnicode_FromString("_deserialize_callback")); + return CallCustomCallback(context, method_name.obj(), value, deserialized_object); +} + +Status AppendArray(PyObject* context, PyArrayObject* array, SequenceBuilder* builder, + int32_t recursion_depth, SerializedPyObject* blobs_out); + +template +Status AppendIntegerScalar(PyObject* obj, SequenceBuilder* builder) { + int64_t value = reinterpret_cast(obj)->obval; + return builder->AppendInt64(value); +} + +// Append a potentially 64-bit wide unsigned Numpy scalar. +// Must check for overflow as we reinterpret it as signed int64. +template +Status AppendLargeUnsignedScalar(PyObject* obj, SequenceBuilder* builder) { + constexpr uint64_t max_value = std::numeric_limits::max(); + + uint64_t value = reinterpret_cast(obj)->obval; + if (value > max_value) { + return Status::Invalid("cannot serialize Numpy uint64 scalar >= 2**63"); + } + return builder->AppendInt64(static_cast(value)); +} + +Status AppendScalar(PyObject* obj, SequenceBuilder* builder) { + if (PyArray_IsScalar(obj, Bool)) { + return builder->AppendBool(reinterpret_cast(obj)->obval != 0); + } else if (PyArray_IsScalar(obj, Half)) { + return builder->AppendHalfFloat(reinterpret_cast(obj)->obval); + } else if (PyArray_IsScalar(obj, Float)) { + return builder->AppendFloat(reinterpret_cast(obj)->obval); + } else if (PyArray_IsScalar(obj, Double)) { + return builder->AppendDouble(reinterpret_cast(obj)->obval); + } + if (PyArray_IsScalar(obj, Byte)) { + return AppendIntegerScalar(obj, builder); + } else if (PyArray_IsScalar(obj, Short)) { + return AppendIntegerScalar(obj, builder); + } else if (PyArray_IsScalar(obj, Int)) { + return AppendIntegerScalar(obj, builder); + } else if (PyArray_IsScalar(obj, Long)) { + return AppendIntegerScalar(obj, builder); + } else if (PyArray_IsScalar(obj, LongLong)) { + return AppendIntegerScalar(obj, builder); + } else if (PyArray_IsScalar(obj, Int64)) { + return AppendIntegerScalar(obj, builder); + } else if (PyArray_IsScalar(obj, UByte)) { + return AppendIntegerScalar(obj, builder); + } else if (PyArray_IsScalar(obj, UShort)) { + return AppendIntegerScalar(obj, builder); + } else if (PyArray_IsScalar(obj, UInt)) { + return AppendIntegerScalar(obj, builder); + } else if (PyArray_IsScalar(obj, ULong)) { + return AppendLargeUnsignedScalar(obj, builder); + } else if (PyArray_IsScalar(obj, ULongLong)) { + return AppendLargeUnsignedScalar(obj, builder); + } else if 
(PyArray_IsScalar(obj, UInt64)) { + return AppendLargeUnsignedScalar(obj, builder); + } + return Status::NotImplemented("Numpy scalar type not recognized"); +} + +Status Append(PyObject* context, PyObject* elem, SequenceBuilder* builder, + int32_t recursion_depth, SerializedPyObject* blobs_out) { + // The bool case must precede the int case (PyInt_Check passes for bools) + if (PyBool_Check(elem)) { + RETURN_NOT_OK(builder->AppendBool(elem == Py_True)); + } else if (PyArray_DescrFromScalar(elem)->type_num == NPY_HALF) { + npy_half halffloat = reinterpret_cast(elem)->obval; + RETURN_NOT_OK(builder->AppendHalfFloat(halffloat)); + } else if (PyFloat_Check(elem)) { + RETURN_NOT_OK(builder->AppendDouble(PyFloat_AS_DOUBLE(elem))); + } else if (PyLong_Check(elem)) { + int overflow = 0; + int64_t data = PyLong_AsLongLongAndOverflow(elem, &overflow); + if (!overflow) { + RETURN_NOT_OK(builder->AppendInt64(data)); + } else { + // Attempt to serialize the object using the custom callback. + PyObject* serialized_object; + // The reference count of serialized_object will be decremented in SerializeDict + RETURN_NOT_OK(CallSerializeCallback(context, elem, &serialized_object)); + RETURN_NOT_OK( + builder->AppendDict(context, serialized_object, recursion_depth, blobs_out)); + } + } else if (PyBytes_Check(elem)) { + auto data = reinterpret_cast(PyBytes_AS_STRING(elem)); + int32_t size = -1; + RETURN_NOT_OK(internal::CastSize(PyBytes_GET_SIZE(elem), &size)); + RETURN_NOT_OK(builder->AppendBytes(data, size)); + } else if (PyUnicode_Check(elem)) { + ARROW_ASSIGN_OR_RAISE(auto view, PyBytesView::FromUnicode(elem)); + int32_t size = -1; + RETURN_NOT_OK(internal::CastSize(view.size, &size)); + RETURN_NOT_OK(builder->AppendString(view.bytes, size)); + } else if (PyList_CheckExact(elem)) { + RETURN_NOT_OK(builder->AppendList(context, elem, recursion_depth, blobs_out)); + } else if (PyDict_CheckExact(elem)) { + RETURN_NOT_OK(builder->AppendDict(context, elem, recursion_depth, blobs_out)); + } else if (PyTuple_CheckExact(elem)) { + RETURN_NOT_OK(builder->AppendTuple(context, elem, recursion_depth, blobs_out)); + } else if (PySet_Check(elem)) { + RETURN_NOT_OK(builder->AppendSet(context, elem, recursion_depth, blobs_out)); + } else if (PyArray_IsScalar(elem, Generic)) { + RETURN_NOT_OK(AppendScalar(elem, builder)); + } else if (PyArray_CheckExact(elem)) { + RETURN_NOT_OK(AppendArray(context, reinterpret_cast(elem), builder, + recursion_depth, blobs_out)); + } else if (elem == Py_None) { + RETURN_NOT_OK(builder->AppendNone()); + } else if (PyDateTime_Check(elem)) { + PyDateTime_DateTime* datetime = reinterpret_cast(elem); + RETURN_NOT_OK(builder->AppendDate64(internal::PyDateTime_to_us(datetime))); + } else if (is_buffer(elem)) { + RETURN_NOT_OK(builder->AppendBuffer(static_cast(blobs_out->buffers.size()))); + ARROW_ASSIGN_OR_RAISE(auto buffer, unwrap_buffer(elem)); + blobs_out->buffers.push_back(buffer); + } else if (is_tensor(elem)) { + RETURN_NOT_OK(builder->AppendTensor(static_cast(blobs_out->tensors.size()))); + ARROW_ASSIGN_OR_RAISE(auto tensor, unwrap_tensor(elem)); + blobs_out->tensors.push_back(tensor); + } else if (is_sparse_coo_tensor(elem)) { + RETURN_NOT_OK(builder->AppendSparseCOOTensor( + static_cast(blobs_out->sparse_tensors.size()))); + ARROW_ASSIGN_OR_RAISE(auto tensor, unwrap_sparse_coo_tensor(elem)); + blobs_out->sparse_tensors.push_back(tensor); + } else if (is_sparse_csr_matrix(elem)) { + RETURN_NOT_OK(builder->AppendSparseCSRMatrix( + static_cast(blobs_out->sparse_tensors.size()))); + 
ARROW_ASSIGN_OR_RAISE(auto matrix, unwrap_sparse_csr_matrix(elem)); + blobs_out->sparse_tensors.push_back(matrix); + } else if (is_sparse_csc_matrix(elem)) { + RETURN_NOT_OK(builder->AppendSparseCSCMatrix( + static_cast(blobs_out->sparse_tensors.size()))); + ARROW_ASSIGN_OR_RAISE(auto matrix, unwrap_sparse_csc_matrix(elem)); + blobs_out->sparse_tensors.push_back(matrix); + } else if (is_sparse_csf_tensor(elem)) { + RETURN_NOT_OK(builder->AppendSparseCSFTensor( + static_cast(blobs_out->sparse_tensors.size()))); + ARROW_ASSIGN_OR_RAISE(auto tensor, unwrap_sparse_csf_tensor(elem)); + blobs_out->sparse_tensors.push_back(tensor); + } else { + // Attempt to serialize the object using the custom callback. + PyObject* serialized_object; + // The reference count of serialized_object will be decremented in SerializeDict + RETURN_NOT_OK(CallSerializeCallback(context, elem, &serialized_object)); + RETURN_NOT_OK( + builder->AppendDict(context, serialized_object, recursion_depth, blobs_out)); + } + return Status::OK(); +} + +Status AppendArray(PyObject* context, PyArrayObject* array, SequenceBuilder* builder, + int32_t recursion_depth, SerializedPyObject* blobs_out) { + int dtype = PyArray_TYPE(array); + switch (dtype) { + case NPY_UINT8: + case NPY_INT8: + case NPY_UINT16: + case NPY_INT16: + case NPY_UINT32: + case NPY_INT32: + case NPY_UINT64: + case NPY_INT64: + case NPY_HALF: + case NPY_FLOAT: + case NPY_DOUBLE: { + RETURN_NOT_OK( + builder->AppendNdarray(static_cast(blobs_out->ndarrays.size()))); + std::shared_ptr tensor; + RETURN_NOT_OK(NdarrayToTensor(default_memory_pool(), + reinterpret_cast(array), {}, &tensor)); + blobs_out->ndarrays.push_back(tensor); + } break; + default: { + PyObject* serialized_object; + // The reference count of serialized_object will be decremented in SerializeDict + RETURN_NOT_OK(CallSerializeCallback(context, reinterpret_cast(array), + &serialized_object)); + RETURN_NOT_OK(builder->AppendDict(context, serialized_object, recursion_depth + 1, + blobs_out)); + } + } + return Status::OK(); +} + +std::shared_ptr MakeBatch(std::shared_ptr data) { + auto field = std::make_shared("list", data->type()); + auto schema = ::arrow::schema({field}); + return RecordBatch::Make(schema, data->length(), {data}); +} + +Status SerializeObject(PyObject* context, PyObject* sequence, SerializedPyObject* out) { + PyAcquireGIL lock; + SequenceBuilder builder; + RETURN_NOT_OK(internal::VisitIterable( + sequence, [&](PyObject* obj, bool* keep_going /* unused */) { + return Append(context, obj, &builder, 0, out); + })); + std::shared_ptr array; + RETURN_NOT_OK(builder.Finish(&array)); + out->batch = MakeBatch(array); + return Status::OK(); +} + +Status SerializeNdarray(std::shared_ptr tensor, SerializedPyObject* out) { + std::shared_ptr array; + SequenceBuilder builder; + RETURN_NOT_OK(builder.AppendNdarray(static_cast(out->ndarrays.size()))); + out->ndarrays.push_back(tensor); + RETURN_NOT_OK(builder.Finish(&array)); + out->batch = MakeBatch(array); + return Status::OK(); +} + +Status WriteNdarrayHeader(std::shared_ptr dtype, + const std::vector& shape, int64_t tensor_num_bytes, + io::OutputStream* dst) { + auto empty_tensor = std::make_shared( + dtype, std::make_shared(nullptr, tensor_num_bytes), shape); + SerializedPyObject serialized_tensor; + RETURN_NOT_OK(SerializeNdarray(empty_tensor, &serialized_tensor)); + return serialized_tensor.WriteTo(dst); +} + +SerializedPyObject::SerializedPyObject() + : ipc_options(ipc::IpcWriteOptions::Defaults()) {} + +Status 
SerializedPyObject::WriteTo(io::OutputStream* dst) { + int32_t num_tensors = static_cast(this->tensors.size()); + int32_t num_sparse_tensors = static_cast(this->sparse_tensors.size()); + int32_t num_ndarrays = static_cast(this->ndarrays.size()); + int32_t num_buffers = static_cast(this->buffers.size()); + RETURN_NOT_OK( + dst->Write(reinterpret_cast(&num_tensors), sizeof(int32_t))); + RETURN_NOT_OK( + dst->Write(reinterpret_cast(&num_sparse_tensors), sizeof(int32_t))); + RETURN_NOT_OK( + dst->Write(reinterpret_cast(&num_ndarrays), sizeof(int32_t))); + RETURN_NOT_OK( + dst->Write(reinterpret_cast(&num_buffers), sizeof(int32_t))); + + // Align stream to 8-byte offset + RETURN_NOT_OK(ipc::AlignStream(dst, ipc::kArrowIpcAlignment)); + RETURN_NOT_OK(ipc::WriteRecordBatchStream({this->batch}, this->ipc_options, dst)); + + // Align stream to 64-byte offset so tensor bodies are 64-byte aligned + RETURN_NOT_OK(ipc::AlignStream(dst, ipc::kTensorAlignment)); + + int32_t metadata_length; + int64_t body_length; + for (const auto& tensor : this->tensors) { + RETURN_NOT_OK(ipc::WriteTensor(*tensor, dst, &metadata_length, &body_length)); + RETURN_NOT_OK(ipc::AlignStream(dst, ipc::kTensorAlignment)); + } + + for (const auto& sparse_tensor : this->sparse_tensors) { + RETURN_NOT_OK( + ipc::WriteSparseTensor(*sparse_tensor, dst, &metadata_length, &body_length)); + RETURN_NOT_OK(ipc::AlignStream(dst, ipc::kTensorAlignment)); + } + + for (const auto& tensor : this->ndarrays) { + RETURN_NOT_OK(ipc::WriteTensor(*tensor, dst, &metadata_length, &body_length)); + RETURN_NOT_OK(ipc::AlignStream(dst, ipc::kTensorAlignment)); + } + + for (const auto& buffer : this->buffers) { + int64_t size = buffer->size(); + RETURN_NOT_OK(dst->Write(reinterpret_cast(&size), sizeof(int64_t))); + RETURN_NOT_OK(dst->Write(buffer->data(), size)); + } + + return Status::OK(); +} + +namespace { + +Status CountSparseTensors( + const std::vector>& sparse_tensors, PyObject** out) { + OwnedRef num_sparse_tensors(PyDict_New()); + size_t num_coo = 0; + size_t num_csr = 0; + size_t num_csc = 0; + size_t num_csf = 0; + size_t ndim_csf = 0; + + for (const auto& sparse_tensor : sparse_tensors) { + switch (sparse_tensor->format_id()) { + case SparseTensorFormat::COO: + ++num_coo; + break; + case SparseTensorFormat::CSR: + ++num_csr; + break; + case SparseTensorFormat::CSC: + ++num_csc; + break; + case SparseTensorFormat::CSF: + ++num_csf; + ndim_csf += sparse_tensor->ndim(); + break; + } + } + + PyDict_SetItemString(num_sparse_tensors.obj(), "coo", PyLong_FromSize_t(num_coo)); + PyDict_SetItemString(num_sparse_tensors.obj(), "csr", PyLong_FromSize_t(num_csr)); + PyDict_SetItemString(num_sparse_tensors.obj(), "csc", PyLong_FromSize_t(num_csc)); + PyDict_SetItemString(num_sparse_tensors.obj(), "csf", PyLong_FromSize_t(num_csf)); + PyDict_SetItemString(num_sparse_tensors.obj(), "ndim_csf", PyLong_FromSize_t(ndim_csf)); + RETURN_IF_PYERROR(); + + *out = num_sparse_tensors.detach(); + return Status::OK(); +} + +} // namespace + +Status SerializedPyObject::GetComponents(MemoryPool* memory_pool, PyObject** out) { + PyAcquireGIL py_gil; + + OwnedRef result(PyDict_New()); + PyObject* buffers = PyList_New(0); + PyObject* num_sparse_tensors = nullptr; + + // TODO(wesm): Not sure how pedantic we need to be about checking the return + // values of these functions. 
There are other places where we do not check + // PyDict_SetItem/SetItemString return value, but these failures would be + // quite esoteric + PyDict_SetItemString(result.obj(), "num_tensors", + PyLong_FromSize_t(this->tensors.size())); + RETURN_NOT_OK(CountSparseTensors(this->sparse_tensors, &num_sparse_tensors)); + PyDict_SetItemString(result.obj(), "num_sparse_tensors", num_sparse_tensors); + PyDict_SetItemString(result.obj(), "ndim_csf", num_sparse_tensors); + PyDict_SetItemString(result.obj(), "num_ndarrays", + PyLong_FromSize_t(this->ndarrays.size())); + PyDict_SetItemString(result.obj(), "num_buffers", + PyLong_FromSize_t(this->buffers.size())); + PyDict_SetItemString(result.obj(), "data", buffers); + RETURN_IF_PYERROR(); + + Py_DECREF(buffers); + + auto PushBuffer = [&buffers](const std::shared_ptr& buffer) { + PyObject* wrapped_buffer = wrap_buffer(buffer); + RETURN_IF_PYERROR(); + if (PyList_Append(buffers, wrapped_buffer) < 0) { + Py_DECREF(wrapped_buffer); + RETURN_IF_PYERROR(); + } + Py_DECREF(wrapped_buffer); + return Status::OK(); + }; + + constexpr int64_t kInitialCapacity = 1024; + + // Write the record batch describing the object structure + py_gil.release(); + ARROW_ASSIGN_OR_RAISE(auto stream, + io::BufferOutputStream::Create(kInitialCapacity, memory_pool)); + RETURN_NOT_OK( + ipc::WriteRecordBatchStream({this->batch}, this->ipc_options, stream.get())); + ARROW_ASSIGN_OR_RAISE(auto buffer, stream->Finish()); + py_gil.acquire(); + + RETURN_NOT_OK(PushBuffer(buffer)); + + // For each tensor, get a metadata buffer and a buffer for the body + for (const auto& tensor : this->tensors) { + ARROW_ASSIGN_OR_RAISE(std::unique_ptr message, + ipc::GetTensorMessage(*tensor, memory_pool)); + RETURN_NOT_OK(PushBuffer(message->metadata())); + RETURN_NOT_OK(PushBuffer(message->body())); + } + + // For each sparse tensor, get a metadata buffer and buffers containing index and data + for (const auto& sparse_tensor : this->sparse_tensors) { + ipc::IpcPayload payload; + RETURN_NOT_OK(ipc::GetSparseTensorPayload(*sparse_tensor, memory_pool, &payload)); + RETURN_NOT_OK(PushBuffer(payload.metadata)); + for (const auto& body : payload.body_buffers) { + RETURN_NOT_OK(PushBuffer(body)); + } + } + + // For each ndarray, get a metadata buffer and a buffer for the body + for (const auto& ndarray : this->ndarrays) { + ARROW_ASSIGN_OR_RAISE(std::unique_ptr message, + ipc::GetTensorMessage(*ndarray, memory_pool)); + RETURN_NOT_OK(PushBuffer(message->metadata())); + RETURN_NOT_OK(PushBuffer(message->body())); + } + + for (const auto& buf : this->buffers) { + RETURN_NOT_OK(PushBuffer(buf)); + } + + *out = result.detach(); + return Status::OK(); +} + +} // namespace py +} // namespace arrow From a10d3cc19f730d39a2be2955cbd65ef1f8bc26d2 Mon Sep 17 00:00:00 2001 From: gorloffslava Date: Tue, 5 Nov 2024 20:54:19 +0500 Subject: [PATCH 10/14] + Fix building in C++ 20 and 23 language modes --- python/pyarrow/src/arrow/python/serialize.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pyarrow/src/arrow/python/serialize.cc b/python/pyarrow/src/arrow/python/serialize.cc index 2295421165d..f1ce7e10c4e 100644 --- a/python/pyarrow/src/arrow/python/serialize.cc +++ b/python/pyarrow/src/arrow/python/serialize.cc @@ -87,7 +87,7 @@ class DictBuilder { std::shared_ptr builder() { return builder_; } -private: + private: SequenceBuilder keys_; SequenceBuilder vals_; std::shared_ptr builder_; From a202c8fb5d15b809fc647c99d730796affdbc3c2 Mon Sep 17 00:00:00 2001 From: gorloffslava Date: Tue, 5 Nov 
2024 21:20:09 +0500 Subject: [PATCH 11/14] + Fix building in C++ 20 and 23 language modes --- python/pyarrow/src/arrow/python/serialize.cc | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/python/pyarrow/src/arrow/python/serialize.cc b/python/pyarrow/src/arrow/python/serialize.cc index f1ce7e10c4e..26ebae3c4d2 100644 --- a/python/pyarrow/src/arrow/python/serialize.cc +++ b/python/pyarrow/src/arrow/python/serialize.cc @@ -69,7 +69,7 @@ Status Append(PyObject* context, PyObject* elem, SequenceBuilder* builder, // SequenceBuilders. The resulting Arrow representation // can be obtained via the Finish method. class DictBuilder { -public: + public: explicit DictBuilder(MemoryPool* pool = nullptr) : keys_(pool), vals_(pool) { builder_.reset(new StructBuilder(struct_({field("keys", dense_union(FieldVector{})), field("vals", dense_union(FieldVector{}))}), @@ -77,9 +77,9 @@ class DictBuilder { } // Builder for the keys of the dictionary - SequenceBuilder& keys() { return keys_; } + SequenceBuilder& keys() { return (SequenceBuilder)keys_; } // Builder for the values of the dictionary - SequenceBuilder& vals() { return vals_; } + SequenceBuilder& vals() { return (SequenceBuilder)vals_; } // Construct an Arrow StructArray representing the dictionary. // Contains a field "keys" for the keys and "vals" for the values. @@ -88,8 +88,8 @@ class DictBuilder { std::shared_ptr builder() { return builder_; } private: - SequenceBuilder keys_; - SequenceBuilder vals_; + std::any keys_; + std::any vals_; std::shared_ptr builder_; }; From 16f9b91518a9b706afcecfc36fc58d5a5b2c8db5 Mon Sep 17 00:00:00 2001 From: gorloffslava Date: Tue, 5 Nov 2024 21:53:47 +0500 Subject: [PATCH 12/14] + Fix building in C++ 20 and 23 language modes --- python/pyarrow/src/arrow/python/serialize.cc | 21 +++++++++++--------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/python/pyarrow/src/arrow/python/serialize.cc b/python/pyarrow/src/arrow/python/serialize.cc index 26ebae3c4d2..1329728b48f 100644 --- a/python/pyarrow/src/arrow/python/serialize.cc +++ b/python/pyarrow/src/arrow/python/serialize.cc @@ -24,6 +24,7 @@ #include #include #include +#include #include #include @@ -70,16 +71,12 @@ Status Append(PyObject* context, PyObject* elem, SequenceBuilder* builder, // can be obtained via the Finish method. class DictBuilder { public: - explicit DictBuilder(MemoryPool* pool = nullptr) : keys_(pool), vals_(pool) { - builder_.reset(new StructBuilder(struct_({field("keys", dense_union(FieldVector{})), - field("vals", dense_union(FieldVector{}))}), - pool, {keys_.builder(), vals_.builder()})); - } + explicit DictBuilder(MemoryPool* pool = nullptr); // Builder for the keys of the dictionary - SequenceBuilder& keys() { return (SequenceBuilder)keys_; } + SequenceBuilder& keys() { return *keys_; } // Builder for the values of the dictionary - SequenceBuilder& vals() { return (SequenceBuilder)vals_; } + SequenceBuilder& vals() { return *vals_; } // Construct an Arrow StructArray representing the dictionary. // Contains a field "keys" for the keys and "vals" for the values. 
@@ -88,8 +85,8 @@ class DictBuilder { std::shared_ptr builder() { return builder_; } private: - std::any keys_; - std::any vals_; + std::unique_ptr keys_; + std::unique_ptr vals_; std::shared_ptr builder_; }; @@ -327,6 +324,12 @@ class SequenceBuilder { std::shared_ptr builder_; }; +DictBuilder::DictBuilder(MemoryPool* pool) : keys_(std::make_unique(SequenceBuilder(pool))), vals_(std::make_unique(SequenceBuilder(pool))) { + builder_.reset(new StructBuilder(struct_({field("keys", dense_union(FieldVector{})), + field("vals", dense_union(FieldVector{}))}), + pool, {keys_->builder(), vals_->builder()})); +} + Status SequenceBuilder::AppendDict(PyObject* context, PyObject* dict, int32_t recursion_depth, SerializedPyObject* blobs_out) { From 1aa6a621b538391c907debd7ffe994237d11dc44 Mon Sep 17 00:00:00 2001 From: gorloffslava Date: Tue, 5 Nov 2024 22:07:16 +0500 Subject: [PATCH 13/14] + Fix building in C++ 20 and 23 language modes --- python/pyarrow/src/arrow/python/serialize.cc | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/python/pyarrow/src/arrow/python/serialize.cc b/python/pyarrow/src/arrow/python/serialize.cc index 1329728b48f..7c30bdf39c9 100644 --- a/python/pyarrow/src/arrow/python/serialize.cc +++ b/python/pyarrow/src/arrow/python/serialize.cc @@ -74,9 +74,9 @@ class DictBuilder { explicit DictBuilder(MemoryPool* pool = nullptr); // Builder for the keys of the dictionary - SequenceBuilder& keys() { return *keys_; } + SequenceBuilder* keys() { return keys_.get(); } // Builder for the values of the dictionary - SequenceBuilder& vals() { return *vals_; } + SequenceBuilder* vals() { return vals_.get(); } // Construct an Arrow StructArray representing the dictionary. // Contains a field "keys" for the keys and "vals" for the values. @@ -349,9 +349,9 @@ Status SequenceBuilder::AppendDict(PyObject* context, PyObject* dict, while (PyDict_Next(dict, &pos, &key, &value)) { RETURN_NOT_OK(dict_values_->builder()->Append()); RETURN_NOT_OK( - Append(context, key, &dict_values_->keys(), recursion_depth + 1, blobs_out)); + Append(context, key, dict_values_->keys(), recursion_depth + 1, blobs_out)); RETURN_NOT_OK( - Append(context, value, &dict_values_->vals(), recursion_depth + 1, blobs_out)); + Append(context, value, dict_values_->vals(), recursion_depth + 1, blobs_out)); } // This block is used to decrement the reference counts of the results From f1d315bd7ce2e651a4aaeee475ff49acfc029af7 Mon Sep 17 00:00:00 2001 From: Benjamin Leff Date: Wed, 13 Aug 2025 11:08:19 -0700 Subject: [PATCH 14/14] chore: deleting python/pyarrow/src/arrow/python/serialize.cc as it no longer exists in upstream --- python/pyarrow/src/arrow/python/serialize.cc | 801 ------------------- 1 file changed, 801 deletions(-) delete mode 100644 python/pyarrow/src/arrow/python/serialize.cc diff --git a/python/pyarrow/src/arrow/python/serialize.cc b/python/pyarrow/src/arrow/python/serialize.cc deleted file mode 100644 index 7c30bdf39c9..00000000000 --- a/python/pyarrow/src/arrow/python/serialize.cc +++ /dev/null @@ -1,801 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. 
You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -#include "arrow/python/serialize.h" -#include "arrow/python/numpy_interop.h" - -#include -#include -#include -#include -#include -#include -#include - -#include -#include - -#include "arrow/array.h" -#include "arrow/array/builder_binary.h" -#include "arrow/array/builder_nested.h" -#include "arrow/array/builder_primitive.h" -#include "arrow/array/builder_union.h" -#include "arrow/io/interfaces.h" -#include "arrow/io/memory.h" -#include "arrow/ipc/util.h" -#include "arrow/ipc/writer.h" -#include "arrow/record_batch.h" -#include "arrow/result.h" -#include "arrow/tensor.h" -#include "arrow/util/logging.h" - -#include "arrow/python/common.h" -#include "arrow/python/datetime.h" -#include "arrow/python/helpers.h" -#include "arrow/python/iterators.h" -#include "arrow/python/numpy_convert.h" -#include "arrow/python/platform.h" -#include "arrow/python/pyarrow.h" - -constexpr int32_t kMaxRecursionDepth = 100; - -namespace arrow { - -using internal::checked_cast; - -namespace py { - -class SequenceBuilder; -class DictBuilder; - -Status Append(PyObject* context, PyObject* elem, SequenceBuilder* builder, - int32_t recursion_depth, SerializedPyObject* blobs_out); - -// Constructing dictionaries of key/value pairs. Sequences of -// keys and values are built separately using a pair of -// SequenceBuilders. The resulting Arrow representation -// can be obtained via the Finish method. -class DictBuilder { - public: - explicit DictBuilder(MemoryPool* pool = nullptr); - - // Builder for the keys of the dictionary - SequenceBuilder* keys() { return keys_.get(); } - // Builder for the values of the dictionary - SequenceBuilder* vals() { return vals_.get(); } - - // Construct an Arrow StructArray representing the dictionary. - // Contains a field "keys" for the keys and "vals" for the values. - Status Finish(std::shared_ptr* out) { return builder_->Finish(out); } - - std::shared_ptr builder() { return builder_; } - - private: - std::unique_ptr keys_; - std::unique_ptr vals_; - std::shared_ptr builder_; -}; - -// A Sequence is a heterogeneous collections of elements. It can contain -// scalar Python types, lists, tuples, dictionaries, tensors and sparse tensors. 
-class SequenceBuilder { - public: - explicit SequenceBuilder(MemoryPool* pool = default_memory_pool()) - : pool_(pool), - types_(::arrow::int8(), pool), - offsets_(::arrow::int32(), pool), - type_map_(PythonType::NUM_PYTHON_TYPES, -1) { - auto null_builder = std::make_shared(pool); - auto initial_ty = dense_union({field("0", null())}); - builder_.reset(new DenseUnionBuilder(pool, {null_builder}, initial_ty)); - } - - // Appending a none to the sequence - Status AppendNone() { return builder_->AppendNull(); } - - template - Status CreateAndUpdate(std::shared_ptr* child_builder, int8_t tag, - MakeBuilderFn make_builder) { - if (!*child_builder) { - child_builder->reset(make_builder()); - std::ostringstream convert; - convert.imbue(std::locale::classic()); - convert << static_cast(tag); - type_map_[tag] = builder_->AppendChild(*child_builder, convert.str()); - } - return builder_->Append(type_map_[tag]); - } - - template - Status AppendPrimitive(std::shared_ptr* child_builder, const T val, - int8_t tag) { - RETURN_NOT_OK( - CreateAndUpdate(child_builder, tag, [this]() { return new BuilderType(pool_); })); - return (*child_builder)->Append(val); - } - - // Appending a boolean to the sequence - Status AppendBool(const bool data) { - return AppendPrimitive(&bools_, data, PythonType::BOOL); - } - - // Appending an int64_t to the sequence - Status AppendInt64(const int64_t data) { - return AppendPrimitive(&ints_, data, PythonType::INT); - } - - // Append a list of bytes to the sequence - Status AppendBytes(const uint8_t* data, int32_t length) { - RETURN_NOT_OK(CreateAndUpdate(&bytes_, PythonType::BYTES, - [this]() { return new BinaryBuilder(pool_); })); - return bytes_->Append(data, length); - } - - // Appending a string to the sequence - Status AppendString(const char* data, int32_t length) { - RETURN_NOT_OK(CreateAndUpdate(&strings_, PythonType::STRING, - [this]() { return new StringBuilder(pool_); })); - return strings_->Append(data, length); - } - - // Appending a half_float to the sequence - Status AppendHalfFloat(const npy_half data) { - return AppendPrimitive(&half_floats_, data, PythonType::HALF_FLOAT); - } - - // Appending a float to the sequence - Status AppendFloat(const float data) { - return AppendPrimitive(&floats_, data, PythonType::FLOAT); - } - - // Appending a double to the sequence - Status AppendDouble(const double data) { - return AppendPrimitive(&doubles_, data, PythonType::DOUBLE); - } - - // Appending a Date64 timestamp to the sequence - Status AppendDate64(const int64_t timestamp) { - return AppendPrimitive(&date64s_, timestamp, PythonType::DATE64); - } - - // Appending a tensor to the sequence - // - // \param tensor_index Index of the tensor in the object. - Status AppendTensor(const int32_t tensor_index) { - RETURN_NOT_OK(CreateAndUpdate(&tensor_indices_, PythonType::TENSOR, - [this]() { return new Int32Builder(pool_); })); - return tensor_indices_->Append(tensor_index); - } - - // Appending a sparse coo tensor to the sequence - // - // \param sparse_coo_tensor_index Index of the sparse coo tensor in the object. - Status AppendSparseCOOTensor(const int32_t sparse_coo_tensor_index) { - RETURN_NOT_OK(CreateAndUpdate(&sparse_coo_tensor_indices_, - PythonType::SPARSECOOTENSOR, - [this]() { return new Int32Builder(pool_); })); - return sparse_coo_tensor_indices_->Append(sparse_coo_tensor_index); - } - - // Appending a sparse csr matrix to the sequence - // - // \param sparse_csr_matrix_index Index of the sparse csr matrix in the object. 
- Status AppendSparseCSRMatrix(const int32_t sparse_csr_matrix_index) { - RETURN_NOT_OK(CreateAndUpdate(&sparse_csr_matrix_indices_, - PythonType::SPARSECSRMATRIX, - [this]() { return new Int32Builder(pool_); })); - return sparse_csr_matrix_indices_->Append(sparse_csr_matrix_index); - } - - // Appending a sparse csc matrix to the sequence - // - // \param sparse_csc_matrix_index Index of the sparse csc matrix in the object. - Status AppendSparseCSCMatrix(const int32_t sparse_csc_matrix_index) { - RETURN_NOT_OK(CreateAndUpdate(&sparse_csc_matrix_indices_, - PythonType::SPARSECSCMATRIX, - [this]() { return new Int32Builder(pool_); })); - return sparse_csc_matrix_indices_->Append(sparse_csc_matrix_index); - } - - // Appending a sparse csf tensor to the sequence - // - // \param sparse_csf_tensor_index Index of the sparse csf tensor in the object. - Status AppendSparseCSFTensor(const int32_t sparse_csf_tensor_index) { - RETURN_NOT_OK(CreateAndUpdate(&sparse_csf_tensor_indices_, - PythonType::SPARSECSFTENSOR, - [this]() { return new Int32Builder(pool_); })); - return sparse_csf_tensor_indices_->Append(sparse_csf_tensor_index); - } - - // Appending a numpy ndarray to the sequence - // - // \param tensor_index Index of the tensor in the object. - Status AppendNdarray(const int32_t ndarray_index) { - RETURN_NOT_OK(CreateAndUpdate(&ndarray_indices_, PythonType::NDARRAY, - [this]() { return new Int32Builder(pool_); })); - return ndarray_indices_->Append(ndarray_index); - } - - // Appending a buffer to the sequence - // - // \param buffer_index Index of the buffer in the object. - Status AppendBuffer(const int32_t buffer_index) { - RETURN_NOT_OK(CreateAndUpdate(&buffer_indices_, PythonType::BUFFER, - [this]() { return new Int32Builder(pool_); })); - return buffer_indices_->Append(buffer_index); - } - - Status AppendSequence(PyObject* context, PyObject* sequence, int8_t tag, - std::shared_ptr& target_sequence, - std::unique_ptr& values, int32_t recursion_depth, - SerializedPyObject* blobs_out) { - if (recursion_depth >= kMaxRecursionDepth) { - return Status::NotImplemented( - "This object exceeds the maximum recursion depth. It may contain itself " - "recursively."); - } - RETURN_NOT_OK(CreateAndUpdate(&target_sequence, tag, [this, &values]() { - values.reset(new SequenceBuilder(pool_)); - return new ListBuilder(pool_, values->builder()); - })); - RETURN_NOT_OK(target_sequence->Append()); - return internal::VisitIterable( - sequence, [&](PyObject* obj, bool* keep_going /* unused */) { - return Append(context, obj, values.get(), recursion_depth, blobs_out); - }); - } - - Status AppendList(PyObject* context, PyObject* list, int32_t recursion_depth, - SerializedPyObject* blobs_out) { - return AppendSequence(context, list, PythonType::LIST, lists_, list_values_, - recursion_depth + 1, blobs_out); - } - - Status AppendTuple(PyObject* context, PyObject* tuple, int32_t recursion_depth, - SerializedPyObject* blobs_out) { - return AppendSequence(context, tuple, PythonType::TUPLE, tuples_, tuple_values_, - recursion_depth + 1, blobs_out); - } - - Status AppendSet(PyObject* context, PyObject* set, int32_t recursion_depth, - SerializedPyObject* blobs_out) { - return AppendSequence(context, set, PythonType::SET, sets_, set_values_, - recursion_depth + 1, blobs_out); - } - - Status AppendDict(PyObject* context, PyObject* dict, int32_t recursion_depth, - SerializedPyObject* blobs_out); - - // Finish building the sequence and return the result. 
- // Input arrays may be nullptr - Status Finish(std::shared_ptr* out) { return builder_->Finish(out); } - - std::shared_ptr builder() { return builder_; } - - private: - MemoryPool* pool_; - - Int8Builder types_; - Int32Builder offsets_; - - /// Mapping from PythonType to child index - std::vector type_map_; - - std::shared_ptr bools_; - std::shared_ptr ints_; - std::shared_ptr bytes_; - std::shared_ptr strings_; - std::shared_ptr half_floats_; - std::shared_ptr floats_; - std::shared_ptr doubles_; - std::shared_ptr date64s_; - - std::unique_ptr list_values_; - std::shared_ptr lists_; - std::unique_ptr dict_values_; - std::shared_ptr dicts_; - std::unique_ptr tuple_values_; - std::shared_ptr tuples_; - std::unique_ptr set_values_; - std::shared_ptr sets_; - - std::shared_ptr tensor_indices_; - std::shared_ptr sparse_coo_tensor_indices_; - std::shared_ptr sparse_csr_matrix_indices_; - std::shared_ptr sparse_csc_matrix_indices_; - std::shared_ptr sparse_csf_tensor_indices_; - std::shared_ptr ndarray_indices_; - std::shared_ptr buffer_indices_; - - std::shared_ptr builder_; -}; - -DictBuilder::DictBuilder(MemoryPool* pool) : keys_(std::make_unique(SequenceBuilder(pool))), vals_(std::make_unique(SequenceBuilder(pool))) { - builder_.reset(new StructBuilder(struct_({field("keys", dense_union(FieldVector{})), - field("vals", dense_union(FieldVector{}))}), - pool, {keys_->builder(), vals_->builder()})); -} - -Status SequenceBuilder::AppendDict(PyObject* context, PyObject* dict, - int32_t recursion_depth, - SerializedPyObject* blobs_out) { - if (recursion_depth >= kMaxRecursionDepth) { - return Status::NotImplemented( - "This object exceeds the maximum recursion depth. It may contain itself " - "recursively."); - } - RETURN_NOT_OK(CreateAndUpdate(&dicts_, PythonType::DICT, [this]() { - dict_values_.reset(new DictBuilder(pool_)); - return new ListBuilder(pool_, dict_values_->builder()); - })); - RETURN_NOT_OK(dicts_->Append()); - PyObject* key; - PyObject* value; - Py_ssize_t pos = 0; - while (PyDict_Next(dict, &pos, &key, &value)) { - RETURN_NOT_OK(dict_values_->builder()->Append()); - RETURN_NOT_OK( - Append(context, key, dict_values_->keys(), recursion_depth + 1, blobs_out)); - RETURN_NOT_OK( - Append(context, value, dict_values_->vals(), recursion_depth + 1, blobs_out)); - } - - // This block is used to decrement the reference counts of the results - // returned by the serialization callback, which is called in AppendArray, - // in DeserializeDict and in Append - static PyObject* py_type = PyUnicode_FromString("_pytype_"); - if (PyDict_Contains(dict, py_type)) { - // If the dictionary contains the key "_pytype_", then the user has to - // have registered a callback. 
- if (context == Py_None) { - return Status::Invalid("No serialization callback set"); - } - Py_XDECREF(dict); - } - return Status::OK(); -} - -Status CallCustomCallback(PyObject* context, PyObject* method_name, PyObject* elem, - PyObject** result) { - if (context == Py_None) { - *result = NULL; - return Status::SerializationError("error while calling callback on ", - internal::PyObject_StdStringRepr(elem), - ": handler not registered"); - } else { - *result = PyObject_CallMethodObjArgs(context, method_name, elem, NULL); - return CheckPyError(); - } -} - -Status CallSerializeCallback(PyObject* context, PyObject* value, - PyObject** serialized_object) { - OwnedRef method_name(PyUnicode_FromString("_serialize_callback")); - RETURN_NOT_OK(CallCustomCallback(context, method_name.obj(), value, serialized_object)); - if (!PyDict_Check(*serialized_object)) { - return Status::TypeError("serialization callback must return a valid dictionary"); - } - return Status::OK(); -} - -Status CallDeserializeCallback(PyObject* context, PyObject* value, - PyObject** deserialized_object) { - OwnedRef method_name(PyUnicode_FromString("_deserialize_callback")); - return CallCustomCallback(context, method_name.obj(), value, deserialized_object); -} - -Status AppendArray(PyObject* context, PyArrayObject* array, SequenceBuilder* builder, - int32_t recursion_depth, SerializedPyObject* blobs_out); - -template -Status AppendIntegerScalar(PyObject* obj, SequenceBuilder* builder) { - int64_t value = reinterpret_cast(obj)->obval; - return builder->AppendInt64(value); -} - -// Append a potentially 64-bit wide unsigned Numpy scalar. -// Must check for overflow as we reinterpret it as signed int64. -template -Status AppendLargeUnsignedScalar(PyObject* obj, SequenceBuilder* builder) { - constexpr uint64_t max_value = std::numeric_limits::max(); - - uint64_t value = reinterpret_cast(obj)->obval; - if (value > max_value) { - return Status::Invalid("cannot serialize Numpy uint64 scalar >= 2**63"); - } - return builder->AppendInt64(static_cast(value)); -} - -Status AppendScalar(PyObject* obj, SequenceBuilder* builder) { - if (PyArray_IsScalar(obj, Bool)) { - return builder->AppendBool(reinterpret_cast(obj)->obval != 0); - } else if (PyArray_IsScalar(obj, Half)) { - return builder->AppendHalfFloat(reinterpret_cast(obj)->obval); - } else if (PyArray_IsScalar(obj, Float)) { - return builder->AppendFloat(reinterpret_cast(obj)->obval); - } else if (PyArray_IsScalar(obj, Double)) { - return builder->AppendDouble(reinterpret_cast(obj)->obval); - } - if (PyArray_IsScalar(obj, Byte)) { - return AppendIntegerScalar(obj, builder); - } else if (PyArray_IsScalar(obj, Short)) { - return AppendIntegerScalar(obj, builder); - } else if (PyArray_IsScalar(obj, Int)) { - return AppendIntegerScalar(obj, builder); - } else if (PyArray_IsScalar(obj, Long)) { - return AppendIntegerScalar(obj, builder); - } else if (PyArray_IsScalar(obj, LongLong)) { - return AppendIntegerScalar(obj, builder); - } else if (PyArray_IsScalar(obj, Int64)) { - return AppendIntegerScalar(obj, builder); - } else if (PyArray_IsScalar(obj, UByte)) { - return AppendIntegerScalar(obj, builder); - } else if (PyArray_IsScalar(obj, UShort)) { - return AppendIntegerScalar(obj, builder); - } else if (PyArray_IsScalar(obj, UInt)) { - return AppendIntegerScalar(obj, builder); - } else if (PyArray_IsScalar(obj, ULong)) { - return AppendLargeUnsignedScalar(obj, builder); - } else if (PyArray_IsScalar(obj, ULongLong)) { - return AppendLargeUnsignedScalar(obj, builder); - } else if 
(PyArray_IsScalar(obj, UInt64)) { - return AppendLargeUnsignedScalar(obj, builder); - } - return Status::NotImplemented("Numpy scalar type not recognized"); -} - -Status Append(PyObject* context, PyObject* elem, SequenceBuilder* builder, - int32_t recursion_depth, SerializedPyObject* blobs_out) { - // The bool case must precede the int case (PyInt_Check passes for bools) - if (PyBool_Check(elem)) { - RETURN_NOT_OK(builder->AppendBool(elem == Py_True)); - } else if (PyArray_DescrFromScalar(elem)->type_num == NPY_HALF) { - npy_half halffloat = reinterpret_cast(elem)->obval; - RETURN_NOT_OK(builder->AppendHalfFloat(halffloat)); - } else if (PyFloat_Check(elem)) { - RETURN_NOT_OK(builder->AppendDouble(PyFloat_AS_DOUBLE(elem))); - } else if (PyLong_Check(elem)) { - int overflow = 0; - int64_t data = PyLong_AsLongLongAndOverflow(elem, &overflow); - if (!overflow) { - RETURN_NOT_OK(builder->AppendInt64(data)); - } else { - // Attempt to serialize the object using the custom callback. - PyObject* serialized_object; - // The reference count of serialized_object will be decremented in SerializeDict - RETURN_NOT_OK(CallSerializeCallback(context, elem, &serialized_object)); - RETURN_NOT_OK( - builder->AppendDict(context, serialized_object, recursion_depth, blobs_out)); - } - } else if (PyBytes_Check(elem)) { - auto data = reinterpret_cast(PyBytes_AS_STRING(elem)); - int32_t size = -1; - RETURN_NOT_OK(internal::CastSize(PyBytes_GET_SIZE(elem), &size)); - RETURN_NOT_OK(builder->AppendBytes(data, size)); - } else if (PyUnicode_Check(elem)) { - ARROW_ASSIGN_OR_RAISE(auto view, PyBytesView::FromUnicode(elem)); - int32_t size = -1; - RETURN_NOT_OK(internal::CastSize(view.size, &size)); - RETURN_NOT_OK(builder->AppendString(view.bytes, size)); - } else if (PyList_CheckExact(elem)) { - RETURN_NOT_OK(builder->AppendList(context, elem, recursion_depth, blobs_out)); - } else if (PyDict_CheckExact(elem)) { - RETURN_NOT_OK(builder->AppendDict(context, elem, recursion_depth, blobs_out)); - } else if (PyTuple_CheckExact(elem)) { - RETURN_NOT_OK(builder->AppendTuple(context, elem, recursion_depth, blobs_out)); - } else if (PySet_Check(elem)) { - RETURN_NOT_OK(builder->AppendSet(context, elem, recursion_depth, blobs_out)); - } else if (PyArray_IsScalar(elem, Generic)) { - RETURN_NOT_OK(AppendScalar(elem, builder)); - } else if (PyArray_CheckExact(elem)) { - RETURN_NOT_OK(AppendArray(context, reinterpret_cast(elem), builder, - recursion_depth, blobs_out)); - } else if (elem == Py_None) { - RETURN_NOT_OK(builder->AppendNone()); - } else if (PyDateTime_Check(elem)) { - PyDateTime_DateTime* datetime = reinterpret_cast(elem); - RETURN_NOT_OK(builder->AppendDate64(internal::PyDateTime_to_us(datetime))); - } else if (is_buffer(elem)) { - RETURN_NOT_OK(builder->AppendBuffer(static_cast(blobs_out->buffers.size()))); - ARROW_ASSIGN_OR_RAISE(auto buffer, unwrap_buffer(elem)); - blobs_out->buffers.push_back(buffer); - } else if (is_tensor(elem)) { - RETURN_NOT_OK(builder->AppendTensor(static_cast(blobs_out->tensors.size()))); - ARROW_ASSIGN_OR_RAISE(auto tensor, unwrap_tensor(elem)); - blobs_out->tensors.push_back(tensor); - } else if (is_sparse_coo_tensor(elem)) { - RETURN_NOT_OK(builder->AppendSparseCOOTensor( - static_cast(blobs_out->sparse_tensors.size()))); - ARROW_ASSIGN_OR_RAISE(auto tensor, unwrap_sparse_coo_tensor(elem)); - blobs_out->sparse_tensors.push_back(tensor); - } else if (is_sparse_csr_matrix(elem)) { - RETURN_NOT_OK(builder->AppendSparseCSRMatrix( - static_cast(blobs_out->sparse_tensors.size()))); - 
ARROW_ASSIGN_OR_RAISE(auto matrix, unwrap_sparse_csr_matrix(elem)); - blobs_out->sparse_tensors.push_back(matrix); - } else if (is_sparse_csc_matrix(elem)) { - RETURN_NOT_OK(builder->AppendSparseCSCMatrix( - static_cast(blobs_out->sparse_tensors.size()))); - ARROW_ASSIGN_OR_RAISE(auto matrix, unwrap_sparse_csc_matrix(elem)); - blobs_out->sparse_tensors.push_back(matrix); - } else if (is_sparse_csf_tensor(elem)) { - RETURN_NOT_OK(builder->AppendSparseCSFTensor( - static_cast(blobs_out->sparse_tensors.size()))); - ARROW_ASSIGN_OR_RAISE(auto tensor, unwrap_sparse_csf_tensor(elem)); - blobs_out->sparse_tensors.push_back(tensor); - } else { - // Attempt to serialize the object using the custom callback. - PyObject* serialized_object; - // The reference count of serialized_object will be decremented in SerializeDict - RETURN_NOT_OK(CallSerializeCallback(context, elem, &serialized_object)); - RETURN_NOT_OK( - builder->AppendDict(context, serialized_object, recursion_depth, blobs_out)); - } - return Status::OK(); -} - -Status AppendArray(PyObject* context, PyArrayObject* array, SequenceBuilder* builder, - int32_t recursion_depth, SerializedPyObject* blobs_out) { - int dtype = PyArray_TYPE(array); - switch (dtype) { - case NPY_UINT8: - case NPY_INT8: - case NPY_UINT16: - case NPY_INT16: - case NPY_UINT32: - case NPY_INT32: - case NPY_UINT64: - case NPY_INT64: - case NPY_HALF: - case NPY_FLOAT: - case NPY_DOUBLE: { - RETURN_NOT_OK( - builder->AppendNdarray(static_cast(blobs_out->ndarrays.size()))); - std::shared_ptr tensor; - RETURN_NOT_OK(NdarrayToTensor(default_memory_pool(), - reinterpret_cast(array), {}, &tensor)); - blobs_out->ndarrays.push_back(tensor); - } break; - default: { - PyObject* serialized_object; - // The reference count of serialized_object will be decremented in SerializeDict - RETURN_NOT_OK(CallSerializeCallback(context, reinterpret_cast(array), - &serialized_object)); - RETURN_NOT_OK(builder->AppendDict(context, serialized_object, recursion_depth + 1, - blobs_out)); - } - } - return Status::OK(); -} - -std::shared_ptr MakeBatch(std::shared_ptr data) { - auto field = std::make_shared("list", data->type()); - auto schema = ::arrow::schema({field}); - return RecordBatch::Make(schema, data->length(), {data}); -} - -Status SerializeObject(PyObject* context, PyObject* sequence, SerializedPyObject* out) { - PyAcquireGIL lock; - SequenceBuilder builder; - RETURN_NOT_OK(internal::VisitIterable( - sequence, [&](PyObject* obj, bool* keep_going /* unused */) { - return Append(context, obj, &builder, 0, out); - })); - std::shared_ptr array; - RETURN_NOT_OK(builder.Finish(&array)); - out->batch = MakeBatch(array); - return Status::OK(); -} - -Status SerializeNdarray(std::shared_ptr tensor, SerializedPyObject* out) { - std::shared_ptr array; - SequenceBuilder builder; - RETURN_NOT_OK(builder.AppendNdarray(static_cast(out->ndarrays.size()))); - out->ndarrays.push_back(tensor); - RETURN_NOT_OK(builder.Finish(&array)); - out->batch = MakeBatch(array); - return Status::OK(); -} - -Status WriteNdarrayHeader(std::shared_ptr dtype, - const std::vector& shape, int64_t tensor_num_bytes, - io::OutputStream* dst) { - auto empty_tensor = std::make_shared( - dtype, std::make_shared(nullptr, tensor_num_bytes), shape); - SerializedPyObject serialized_tensor; - RETURN_NOT_OK(SerializeNdarray(empty_tensor, &serialized_tensor)); - return serialized_tensor.WriteTo(dst); -} - -SerializedPyObject::SerializedPyObject() - : ipc_options(ipc::IpcWriteOptions::Defaults()) {} - -Status 
SerializedPyObject::WriteTo(io::OutputStream* dst) { - int32_t num_tensors = static_cast(this->tensors.size()); - int32_t num_sparse_tensors = static_cast(this->sparse_tensors.size()); - int32_t num_ndarrays = static_cast(this->ndarrays.size()); - int32_t num_buffers = static_cast(this->buffers.size()); - RETURN_NOT_OK( - dst->Write(reinterpret_cast(&num_tensors), sizeof(int32_t))); - RETURN_NOT_OK( - dst->Write(reinterpret_cast(&num_sparse_tensors), sizeof(int32_t))); - RETURN_NOT_OK( - dst->Write(reinterpret_cast(&num_ndarrays), sizeof(int32_t))); - RETURN_NOT_OK( - dst->Write(reinterpret_cast(&num_buffers), sizeof(int32_t))); - - // Align stream to 8-byte offset - RETURN_NOT_OK(ipc::AlignStream(dst, ipc::kArrowIpcAlignment)); - RETURN_NOT_OK(ipc::WriteRecordBatchStream({this->batch}, this->ipc_options, dst)); - - // Align stream to 64-byte offset so tensor bodies are 64-byte aligned - RETURN_NOT_OK(ipc::AlignStream(dst, ipc::kTensorAlignment)); - - int32_t metadata_length; - int64_t body_length; - for (const auto& tensor : this->tensors) { - RETURN_NOT_OK(ipc::WriteTensor(*tensor, dst, &metadata_length, &body_length)); - RETURN_NOT_OK(ipc::AlignStream(dst, ipc::kTensorAlignment)); - } - - for (const auto& sparse_tensor : this->sparse_tensors) { - RETURN_NOT_OK( - ipc::WriteSparseTensor(*sparse_tensor, dst, &metadata_length, &body_length)); - RETURN_NOT_OK(ipc::AlignStream(dst, ipc::kTensorAlignment)); - } - - for (const auto& tensor : this->ndarrays) { - RETURN_NOT_OK(ipc::WriteTensor(*tensor, dst, &metadata_length, &body_length)); - RETURN_NOT_OK(ipc::AlignStream(dst, ipc::kTensorAlignment)); - } - - for (const auto& buffer : this->buffers) { - int64_t size = buffer->size(); - RETURN_NOT_OK(dst->Write(reinterpret_cast(&size), sizeof(int64_t))); - RETURN_NOT_OK(dst->Write(buffer->data(), size)); - } - - return Status::OK(); -} - -namespace { - -Status CountSparseTensors( - const std::vector>& sparse_tensors, PyObject** out) { - OwnedRef num_sparse_tensors(PyDict_New()); - size_t num_coo = 0; - size_t num_csr = 0; - size_t num_csc = 0; - size_t num_csf = 0; - size_t ndim_csf = 0; - - for (const auto& sparse_tensor : sparse_tensors) { - switch (sparse_tensor->format_id()) { - case SparseTensorFormat::COO: - ++num_coo; - break; - case SparseTensorFormat::CSR: - ++num_csr; - break; - case SparseTensorFormat::CSC: - ++num_csc; - break; - case SparseTensorFormat::CSF: - ++num_csf; - ndim_csf += sparse_tensor->ndim(); - break; - } - } - - PyDict_SetItemString(num_sparse_tensors.obj(), "coo", PyLong_FromSize_t(num_coo)); - PyDict_SetItemString(num_sparse_tensors.obj(), "csr", PyLong_FromSize_t(num_csr)); - PyDict_SetItemString(num_sparse_tensors.obj(), "csc", PyLong_FromSize_t(num_csc)); - PyDict_SetItemString(num_sparse_tensors.obj(), "csf", PyLong_FromSize_t(num_csf)); - PyDict_SetItemString(num_sparse_tensors.obj(), "ndim_csf", PyLong_FromSize_t(ndim_csf)); - RETURN_IF_PYERROR(); - - *out = num_sparse_tensors.detach(); - return Status::OK(); -} - -} // namespace - -Status SerializedPyObject::GetComponents(MemoryPool* memory_pool, PyObject** out) { - PyAcquireGIL py_gil; - - OwnedRef result(PyDict_New()); - PyObject* buffers = PyList_New(0); - PyObject* num_sparse_tensors = nullptr; - - // TODO(wesm): Not sure how pedantic we need to be about checking the return - // values of these functions. 
There are other places where we do not check - // PyDict_SetItem/SetItemString return value, but these failures would be - // quite esoteric - PyDict_SetItemString(result.obj(), "num_tensors", - PyLong_FromSize_t(this->tensors.size())); - RETURN_NOT_OK(CountSparseTensors(this->sparse_tensors, &num_sparse_tensors)); - PyDict_SetItemString(result.obj(), "num_sparse_tensors", num_sparse_tensors); - PyDict_SetItemString(result.obj(), "ndim_csf", num_sparse_tensors); - PyDict_SetItemString(result.obj(), "num_ndarrays", - PyLong_FromSize_t(this->ndarrays.size())); - PyDict_SetItemString(result.obj(), "num_buffers", - PyLong_FromSize_t(this->buffers.size())); - PyDict_SetItemString(result.obj(), "data", buffers); - RETURN_IF_PYERROR(); - - Py_DECREF(buffers); - - auto PushBuffer = [&buffers](const std::shared_ptr& buffer) { - PyObject* wrapped_buffer = wrap_buffer(buffer); - RETURN_IF_PYERROR(); - if (PyList_Append(buffers, wrapped_buffer) < 0) { - Py_DECREF(wrapped_buffer); - RETURN_IF_PYERROR(); - } - Py_DECREF(wrapped_buffer); - return Status::OK(); - }; - - constexpr int64_t kInitialCapacity = 1024; - - // Write the record batch describing the object structure - py_gil.release(); - ARROW_ASSIGN_OR_RAISE(auto stream, - io::BufferOutputStream::Create(kInitialCapacity, memory_pool)); - RETURN_NOT_OK( - ipc::WriteRecordBatchStream({this->batch}, this->ipc_options, stream.get())); - ARROW_ASSIGN_OR_RAISE(auto buffer, stream->Finish()); - py_gil.acquire(); - - RETURN_NOT_OK(PushBuffer(buffer)); - - // For each tensor, get a metadata buffer and a buffer for the body - for (const auto& tensor : this->tensors) { - ARROW_ASSIGN_OR_RAISE(std::unique_ptr message, - ipc::GetTensorMessage(*tensor, memory_pool)); - RETURN_NOT_OK(PushBuffer(message->metadata())); - RETURN_NOT_OK(PushBuffer(message->body())); - } - - // For each sparse tensor, get a metadata buffer and buffers containing index and data - for (const auto& sparse_tensor : this->sparse_tensors) { - ipc::IpcPayload payload; - RETURN_NOT_OK(ipc::GetSparseTensorPayload(*sparse_tensor, memory_pool, &payload)); - RETURN_NOT_OK(PushBuffer(payload.metadata)); - for (const auto& body : payload.body_buffers) { - RETURN_NOT_OK(PushBuffer(body)); - } - } - - // For each ndarray, get a metadata buffer and a buffer for the body - for (const auto& ndarray : this->ndarrays) { - ARROW_ASSIGN_OR_RAISE(std::unique_ptr message, - ipc::GetTensorMessage(*ndarray, memory_pool)); - RETURN_NOT_OK(PushBuffer(message->metadata())); - RETURN_NOT_OK(PushBuffer(message->body())); - } - - for (const auto& buf : this->buffers) { - RETURN_NOT_OK(PushBuffer(buf)); - } - - *out = result.detach(); - return Status::OK(); -} - -} // namespace py -} // namespace arrow
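A note for readers following [PATCH 11/14] through [PATCH 13/14]: the serialize.cc fix settles on holding DictBuilder's two SequenceBuilder members behind std::unique_ptr, moving the constructor out of line to a point after SequenceBuilder's full definition, and exposing the members through pointer accessors so AppendDict can pass them straight to Append(). The sketch below shows that pattern in isolation; it uses hypothetical Outer/Inner names and is only an illustration of the technique, not code taken from the patches.

    #include <memory>

    class Inner;  // forward declaration; the full definition comes later

    class Outer {
     public:
      Outer();   // defined out of line, once Inner is complete
      ~Outer();  // ditto: destroying unique_ptr<Inner> needs a complete Inner
      // Pointer accessor, mirroring the keys()/vals() change in [PATCH 13/14].
      Inner* inner() { return inner_.get(); }

     private:
      std::unique_ptr<Inner> inner_;  // fine while Inner is still incomplete
    };

    class Inner {
     public:
      int value() const { return 42; }
    };

    // Out-of-line definitions, placed where Inner is a complete type.
    Outer::Outer() : inner_(std::make_unique<Inner>()) {}
    Outer::~Outer() = default;

    int main() {
      Outer outer;
      return outer.inner()->value() == 42 ? 0 : 1;
    }

A smart pointer to a forward-declared type already breaks the declaration-order dependency between the two builders, which is why [PATCH 12/14] replaces the std::any members introduced in [PATCH 11/14], and why [PATCH 13/14] can hand the underlying builders to Append() as plain pointers.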