diff --git a/.github/workflows/cpp_extra.yml b/.github/workflows/cpp_extra.yml index 891aab51ccb..a391d497dde 100644 --- a/.github/workflows/cpp_extra.yml +++ b/.github/workflows/cpp_extra.yml @@ -126,6 +126,8 @@ jobs: title: AMD64 Ubuntu Meson # TODO: We should remove this "continue-on-error: true" once GH-47207 is resolved - continue-on-error: true + envs: + - DEBIAN=13 image: debian-cpp run-options: >- -e CMAKE_CXX_STANDARD=23 @@ -158,10 +160,17 @@ jobs: env: ARCHERY_DOCKER_USER: ${{ secrets.DOCKERHUB_USER }} ARCHERY_DOCKER_PASSWORD: ${{ secrets.DOCKERHUB_TOKEN }} + ENVS: ${{ toJSON(matrix.envs) }} run: | # GH-40558: reduce ASLR to avoid ASAN/LSAN crashes sudo sysctl -w vm.mmap_rnd_bits=28 source ci/scripts/util_enable_core_dumps.sh + # Forward matrix-specific environment variables to docker compose via .env + if [ "${ENVS}" != "null" ]; then + echo "${ENVS}" | jq -r '.[]' | while read -r env; do + echo "${env}" >> .env + done + fi archery docker run ${{ matrix.run-options || '' }} ${{ matrix.image }} - name: Docker Push if: >-
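+# C++ build and test image for Debian 13, following the layout of the existing debian-*-cpp images.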
+ +ARG arch=amd64 +FROM ${arch}/debian:13 +ARG arch + +ENV DEBIAN_FRONTEND=noninteractive + +ARG llvm +RUN apt-get update -y -q && \ + apt-get install -y -q --no-install-recommends \ + apt-transport-https \ + ca-certificates \ + gnupg \ + lsb-release \ + wget && \ + if [ ${llvm} -ge 20 ]; then \ + wget -O /usr/share/keyrings/llvm-snapshot.asc \ + https://apt.llvm.org/llvm-snapshot.gpg.key && \ + (echo "Types: deb"; \ + echo "URIs: https://apt.llvm.org/$(lsb_release --codename --short)/"; \ + echo "Suites: llvm-toolchain-$(lsb_release --codename --short)-${llvm}"; \ + echo "Components: main"; \ + echo "Signed-By: /usr/share/keyrings/llvm-snapshot.asc") | \ + tee /etc/apt/sources.list.d/llvm.sources; \ + fi && \ + apt-get update -y -q && \ + apt-get install -y -q --no-install-recommends \ + autoconf \ + ccache \ + clang-${llvm} \ + cmake \ + curl \ + g++ \ + gcc \ + gdb \ + git \ + libbenchmark-dev \ + libboost-filesystem-dev \ + libboost-system-dev \ + libbrotli-dev \ + libbz2-dev \ + libc-ares-dev \ + libcurl4-openssl-dev \ + libgflags-dev \ + libgmock-dev \ + libgoogle-glog-dev \ + libgrpc++-dev \ + libidn2-dev \ + libkrb5-dev \ + libldap-dev \ + liblz4-dev \ + libnghttp2-dev \ + libprotobuf-dev \ + libprotoc-dev \ + libpsl-dev \ + libre2-dev \ + librtmp-dev \ + libsnappy-dev \ + libsqlite3-dev \ + libssh-dev \ + libssh2-1-dev \ + libssl-dev \ + libthrift-dev \ + libutf8proc-dev \ + libxml2-dev \ + libxsimd-dev \ + libzstd-dev \ + llvm-${llvm}-dev \ + make \ + ninja-build \ + nlohmann-json3-dev \ + npm \ + opentelemetry-cpp-dev \ + pkg-config \ + protobuf-compiler-grpc \ + python3-dev \ + python3-pip \ + python3-venv \ + rapidjson-dev \ + rsync \ + tzdata \ + zlib1g-dev && \ + apt-get clean && \ + rm -rf /var/lib/apt/lists/* + +COPY ci/scripts/install_minio.sh /arrow/ci/scripts/ +RUN /arrow/ci/scripts/install_minio.sh latest /usr/local + +COPY ci/scripts/install_gcs_testbench.sh /arrow/ci/scripts/ +RUN /arrow/ci/scripts/install_gcs_testbench.sh default + +COPY ci/scripts/install_azurite.sh /arrow/ci/scripts/ +RUN /arrow/ci/scripts/install_azurite.sh + +COPY ci/scripts/install_sccache.sh /arrow/ci/scripts/ +RUN /arrow/ci/scripts/install_sccache.sh unknown-linux-musl /usr/local/bin + +# Prioritize system packages and local installation.
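+# Dependencies without suitable Debian packages fall back to the *_SOURCE=BUNDLED overrides below.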
+ENV ARROW_ACERO=ON \ ARROW_AZURE=ON \ ARROW_BUILD_TESTS=ON \ ARROW_DATASET=ON \ ARROW_DEPENDENCY_SOURCE=SYSTEM \ ARROW_FLIGHT=ON \ ARROW_FLIGHT_SQL=ON \ ARROW_GANDIVA=ON \ ARROW_GCS=ON \ ARROW_HOME=/usr/local \ ARROW_JEMALLOC=ON \ ARROW_ORC=ON \ ARROW_PARQUET=ON \ ARROW_S3=ON \ ARROW_SUBSTRAIT=ON \ ARROW_USE_CCACHE=ON \ ARROW_WITH_BROTLI=ON \ ARROW_WITH_BZ2=ON \ ARROW_WITH_LZ4=ON \ ARROW_WITH_OPENTELEMETRY=ON \ ARROW_WITH_SNAPPY=ON \ ARROW_WITH_ZLIB=ON \ ARROW_WITH_ZSTD=ON \ AWSSDK_SOURCE=BUNDLED \ Azure_SOURCE=BUNDLED \ google_cloud_cpp_storage_SOURCE=BUNDLED \ ORC_SOURCE=BUNDLED \ PATH=/usr/lib/ccache/:$PATH \ PYTHON=python3 diff --git a/cpp/src/arrow/acero/asof_join_node.cc b/cpp/src/arrow/acero/asof_join_node.cc index 55fa45543e4..4edcf1d7c10 100644 --- a/cpp/src/arrow/acero/asof_join_node.cc +++ b/cpp/src/arrow/acero/asof_join_node.cc @@ -135,8 +135,6 @@ static inline uint64_t key_value(T t) { return static_cast<uint64_t>(t); } -class AsofJoinNode; - #ifndef NDEBUG // Get the debug-stream associated with the as-of-join node std::ostream* GetDebugStream(AsofJoinNode* node); @@ -384,80 +382,8 @@ struct MemoStore { } }; -// a specialized higher-performance variation of Hashing64 logic from hash_join_node -// the code here avoids recreating objects that are independent of each batch processed -class KeyHasher { - friend class AsofJoinNode; - - static constexpr int kMiniBatchLength = arrow::util::MiniBatch::kMiniBatchLength; - - public: - // the key hasher is not thread-safe and is only used in sequential batch processing - // of the input it is associated with - KeyHasher(size_t index, const std::vector& indices) - : index_(index), - indices_(indices), - metadata_(indices.size()), - batch_(NULLPTR), - hashes_(), - ctx_(), - column_arrays_(), - stack_() { - ctx_.stack = &stack_; - column_arrays_.resize(indices.size()); - } - - Status Init(ExecContext* exec_context, const std::shared_ptr& schema) { - ctx_.hardware_flags = exec_context->cpu_info()->hardware_flags(); - const auto& fields = schema->fields(); - for (size_t k = 0; k < metadata_.size(); k++) { - ARROW_ASSIGN_OR_RAISE(metadata_[k], - ColumnMetadataFromDataType(fields[indices_[k]]->type())); - } - return stack_.Init(exec_context->memory_pool(), - 4 * kMiniBatchLength * sizeof(uint32_t)); - } - - // invalidate cached hashes for batch - required when it changes - // only this method can be called concurrently with HashesFor - void Invalidate() { batch_ = NULLPTR; } - - // compute and cache a hash for each row of the given batch - const std::vector& HashesFor(const RecordBatch* batch) { - if (batch_ == batch) { - return hashes_; // cache hit - return cached hashes - } - Invalidate(); - size_t batch_length = batch->num_rows(); - hashes_.resize(batch_length); - for (int64_t i = 0; i < static_cast(batch_length); i += kMiniBatchLength) { - int64_t length = std::min(static_cast(batch_length - i), - static_cast(kMiniBatchLength)); - for (size_t k = 0; k < indices_.size(); k++) { - auto array_data = batch->column_data(indices_[k]); - column_arrays_[k] = - ColumnArrayFromArrayDataAndMetadata(array_data, metadata_[k], i, length); - } - // write directly to the cache - Hashing64::HashMultiColumn(column_arrays_, &ctx_, hashes_.data() + i); - } - DEBUG_SYNC(node_, "key hasher ", index_, " got hashes ", - compute::internal::GenericToString(hashes_), DEBUG_MANIP(std::endl)); - batch_ = batch; // associate cache with current batch - return hashes_; - } - - private: - AsofJoinNode* 
node_ = nullptr; // avoids circular dependency during initialization - size_t index_; - std::vector indices_; - std::vector metadata_; - std::atomic batch_; - std::vector hashes_; - LightContext ctx_; - std::vector column_arrays_; - arrow::util::TempVectorStack stack_; -}; +class KeyHasher; +class AsofJoinNode; class BackpressureController : public BackpressureControl { public: @@ -484,140 +410,50 @@ class InputState : public util::SerialSequencingQueue::Processor { KeyHasher* key_hasher, AsofJoinNode* node, BackpressureHandler handler, const std::shared_ptr& schema, const col_index_t time_col_index, - const std::vector& key_col_index) - : sequencer_(util::SerialSequencingQueue::Make(this)), - queue_(std::move(handler)), - schema_(schema), - time_col_index_(time_col_index), - key_col_index_(key_col_index), - time_type_id_(schema_->fields()[time_col_index_]->type()->id()), - key_type_id_(key_col_index.size()), - key_hasher_(key_hasher), - node_(node), - index_(index), - must_hash_(must_hash), - may_rehash_(may_rehash), - tolerance_(tolerance), - memo_(DEBUG_ADD(/*no_future=*/index == 0 || !tolerance.positive, node, index)) { - for (size_t k = 0; k < key_col_index_.size(); k++) { - key_type_id_[k] = schema_->fields()[key_col_index_[k]]->type()->id(); - } - } + const std::vector& key_col_index); static Result> Make( size_t index, TolType tolerance, bool must_hash, bool may_rehash, KeyHasher* key_hasher, ExecNode* asof_input, AsofJoinNode* asof_node, std::atomic& backpressure_counter, const std::shared_ptr& schema, const col_index_t time_col_index, - const std::vector& key_col_index) { - constexpr size_t low_threshold = 4, high_threshold = 8; - std::unique_ptr backpressure_control = - std::make_unique( - /*node=*/asof_input, /*output=*/asof_node, backpressure_counter); - ARROW_ASSIGN_OR_RAISE( - auto handler, BackpressureHandler::Make(asof_input, low_threshold, high_threshold, - std::move(backpressure_control))); - return std::make_unique(index, tolerance, must_hash, may_rehash, - key_hasher, asof_node, std::move(handler), schema, - time_col_index, key_col_index); - } + const std::vector& key_col_index); - col_index_t InitSrcToDstMapping(col_index_t dst_offset, bool skip_time_and_key_fields) { - src_to_dst_.resize(schema_->num_fields()); - for (int i = 0; i < schema_->num_fields(); ++i) - if (!(skip_time_and_key_fields && IsTimeOrKeyColumn(i))) - src_to_dst_[i] = dst_offset++; - return dst_offset; - } + col_index_t InitSrcToDstMapping(col_index_t dst_offset, bool skip_time_and_key_fields); - const std::optional& MapSrcToDst(col_index_t src) const { - return src_to_dst_[src]; - } + const std::optional& MapSrcToDst(col_index_t src) const; - bool IsTimeOrKeyColumn(col_index_t i) const { - DCHECK_LT(i, schema_->num_fields()); - return (i == time_col_index_) || std_has(key_col_index_, i); - } + bool IsTimeOrKeyColumn(col_index_t i) const; // Gets the latest row index, assuming the queue isn't empty - row_index_t GetLatestRow() const { return latest_ref_row_; } + row_index_t GetLatestRow() const; - bool Empty() const { - // cannot be empty if ref row is >0 -- can avoid slow queue lock - // below - if (latest_ref_row_ > 0) return false; - return queue_.Empty(); - } + bool Empty() const; // true when the queue is empty and, when memo may have future entries (the case of a // positive tolerance), when the memo is empty. // used when checking whether RHS is up to date with LHS. 
// NOTE: The emptiness must be decided by a single call to Empty() in caller, due to the // potential race with Push(), see GH-41614. - bool CurrentEmpty(bool empty) const { - return memo_.no_future_ ? empty : (memo_.times_.empty() && empty); - } + bool CurrentEmpty(bool empty) const; // in case memo may not have future entries (the case of a non-positive tolerance), // returns the latest time (which is current); otherwise, returns the current time. // used when checking whether RHS is up to date with LHS. - OnType GetCurrentTime() const { - return memo_.no_future_ ? GetLatestTime() : static_cast(memo_.current_time_); - } + OnType GetCurrentTime() const; - int total_batches() const { return total_batches_; } + int total_batches() const; // Gets latest batch (precondition: must not be empty) const std::shared_ptr& GetLatestBatch() const { return queue_.Front(); } -#define LATEST_VAL_CASE(id, val) \ - case Type::id: { \ - using T = typename TypeIdTraits::Type; \ - using CType = typename TypeTraits::CType; \ - return val(data->GetValues(1)[row]); \ - } - - inline ByType GetLatestKey() const { - return GetKey(GetLatestBatch().get(), latest_ref_row_); - } + inline ByType GetLatestKey() const; - inline ByType GetKey(const RecordBatch* batch, row_index_t row) const { - if (must_hash_) { - // Query the key hasher. This may hit cache, which must be valid for the batch. - // Therefore, the key hasher is invalidated when a new batch is pushed - see - // `InputState::Push`. - return key_hasher_->HashesFor(batch)[row]; - } - if (key_col_index_.size() == 0) { - return 0; - } - auto data = batch->column_data(key_col_index_[0]); - switch (key_type_id_[0]) { - LATEST_VAL_CASE(INT8, key_value) - LATEST_VAL_CASE(INT16, key_value) - LATEST_VAL_CASE(INT32, key_value) - LATEST_VAL_CASE(INT64, key_value) - LATEST_VAL_CASE(UINT8, key_value) - LATEST_VAL_CASE(UINT16, key_value) - LATEST_VAL_CASE(UINT32, key_value) - LATEST_VAL_CASE(UINT64, key_value) - LATEST_VAL_CASE(DATE32, key_value) - LATEST_VAL_CASE(DATE64, key_value) - LATEST_VAL_CASE(TIME32, key_value) - LATEST_VAL_CASE(TIME64, key_value) - LATEST_VAL_CASE(TIMESTAMP, key_value) - default: - DCHECK(false); - return 0; // cannot happen - } - } + inline ByType GetKey(const RecordBatch* batch, row_index_t row) const; - inline OnType GetLatestTime() const { - return GetTime(GetLatestBatch().get(), time_type_id_, time_col_index_, - latest_ref_row_); - } + inline OnType GetLatestTime() const; #undef LATEST_VAL_CASE @@ -655,119 +491,27 @@ class InputState : public util::SerialSequencingQueue::Processor { // Returns true if updates were made, false if not. // NOTE: The emptiness must be decided by a single call to Empty() in caller, due to the // potential race with Push(), see GH-41614. - Result AdvanceAndMemoize(OnType ts, bool empty) { - // Advance the right side row index until we reach the latest right row (for each key) - // for the given left timestamp. - DEBUG_SYNC(node_, "Advancing input ", index_, DEBUG_MANIP(std::endl)); - - // Check if already updated for TS (or if there is no latest) - if (empty) { // can't advance if empty and no future entries - return memo_.no_future_ ? false : memo_.RemoveEntriesWithLesserTime(ts); - } + Result AdvanceAndMemoize(OnType ts, bool empty); - // Not updated. Try to update and possibly advance. 
- bool advanced, updated = false; - OnType latest_time; - do { - latest_time = GetLatestTime(); - // if Advance() returns true, then the latest_ts must also be valid - // Keep advancing right table until we hit the latest row that has - // timestamp <= ts. This is because we only need the latest row for the - // match given a left ts. - if (latest_time > tolerance_.Horizon(ts)) { // hit a distant timestamp - DEBUG_SYNC(node_, "Advancing input ", index_, " hit distant time=", latest_time, - " at=", ts, DEBUG_MANIP(std::endl)); - // if no future entries, which would have been earlier than the distant time, no - // need to queue it - if (memo_.future_entries_.empty()) break; - } - auto rb = GetLatestBatch(); - if (may_rehash_ && rb->column_data(key_col_index_[0])->GetNullCount() > 0) { - must_hash_ = true; - may_rehash_ = false; - Rehash(); - } - memo_.Store(rb, latest_ref_row_, latest_time, DEBUG_ADD(GetLatestKey(), ts)); - // negative tolerance means a last-known entry was stored - set `updated` to `true` - updated = memo_.no_future_; - ARROW_ASSIGN_OR_RAISE(advanced, Advance()); - } while (advanced); - if (!memo_.no_future_ && latest_time >= ts) { - // `updated` was not modified in the loop from the initial `false` value; set it now - updated = memo_.RemoveEntriesWithLesserTime(ts); - } - DEBUG_SYNC(node_, "Advancing input ", index_, " updated=", updated, - DEBUG_MANIP(std::endl)); - return updated; - } - Status InsertBatch(ExecBatch batch) { - return sequencer_->InsertBatch(std::move(batch)); - } - - Status Process(ExecBatch batch) override { - auto rb = *batch.ToRecordBatch(schema_); - DEBUG_SYNC(node_, "received batch from input ", index_, ":", DEBUG_MANIP(std::endl), - rb->ToString(), DEBUG_MANIP(std::endl)); - return Push(rb); - } - void Rehash() { - DEBUG_SYNC(node_, "rehashing for input ", index_, ":", DEBUG_MANIP(std::endl)); - MemoStore new_memo(DEBUG_ADD(memo_.no_future_, node_, index_)); - new_memo.current_time_ = (OnType)memo_.current_time_; - for (auto e = memo_.entries_.begin(); e != memo_.entries_.end(); ++e) { - auto& entry = e->second; - auto new_key = GetKey(entry.batch.get(), entry.row); - DEBUG_SYNC(node_, " ", e->first, " to ", new_key, DEBUG_MANIP(std::endl)); - new_memo.entries_[new_key].swap(entry); - auto fe = memo_.future_entries_.find(e->first); - if (fe != memo_.future_entries_.end()) { - new_memo.future_entries_[new_key].swap(fe->second); - } - } - memo_.times_.swap(new_memo.times_); - memo_.swap(new_memo); - } + Status InsertBatch(ExecBatch batch); - Status Push(const std::shared_ptr& rb) { - if (rb->num_rows() > 0) { - key_hasher_->Invalidate(); // batch changed - invalidate key hasher's cache - queue_.Push(rb); // only now push batch for processing - } else { - ++batches_processed_; // don't enqueue empty batches, just record as processed - } - return Status::OK(); - } + Status Process(ExecBatch batch) override; - std::optional GetMemoEntryForKey(ByType key) { - return memo_.GetEntryForKey(key); - } + void Rehash(); - std::optional GetMemoTimeForKey(ByType key) { - auto r = GetMemoEntryForKey(key); - if (r.has_value()) { - return (*r)->time; - } else { - return std::nullopt; - } - } + Status Push(const std::shared_ptr& rb); - void RemoveMemoEntriesWithLesserTime(OnType ts) { - memo_.RemoveEntriesWithLesserTime(ts); - } + std::optional GetMemoEntryForKey(ByType key); - const std::shared_ptr& get_schema() const { return schema_; } + std::optional GetMemoTimeForKey(ByType key); - void set_total_batches(int n) { - DCHECK_GE(n, 0); - DCHECK_EQ(total_batches_, -1) 
<< "Set total batch more than once"; - total_batches_ = n; - } + void RemoveMemoEntriesWithLesserTime(OnType ts); - Status ForceShutdown() { - // Force the upstream input node to unpause. Necessary to avoid deadlock when we - // terminate the process thread - return queue_.ForceShutdown(); - } + const std::shared_ptr& get_schema() const; + + void set_total_batches(int n); + + Status ForceShutdown(); private: // ExecBatch Sequencer @@ -811,121 +555,37 @@ class InputState : public util::SerialSequencingQueue::Processor { std::vector> src_to_dst_; }; -/// Wrapper around UnmaterializedCompositeTable that knows how to emplace -/// the join row-by-row -template -class CompositeTableBuilder { - using SliceBuilder = UnmaterializedSliceBuilder; - using CompositeTable = UnmaterializedCompositeTable; - - public: - NDEBUG_EXPLICIT CompositeTableBuilder( - const std::vector>& inputs, - const std::shared_ptr& schema, arrow::MemoryPool* pool, - DEBUG_ADD(size_t n_tables, AsofJoinNode* node)) - : unmaterialized_table(InitUnmaterializedTable(schema, inputs, pool)), - DEBUG_ADD(n_tables_(n_tables), node_(node)) { - DCHECK_GE(n_tables_, 1); - DCHECK_LE(n_tables_, MAX_TABLES); - } - - size_t n_rows() const { return unmaterialized_table.Size(); } - - // Adds the latest row from the input state as a new composite reference row - // - LHS must have a valid key,timestep,and latest rows - // - RHS must have valid data memo'ed for the key - void Emplace(std::vector>& in, TolType tolerance) { - DCHECK_EQ(in.size(), n_tables_); - - // Get the LHS key - ByType key = in[0]->GetLatestKey(); - - // Add row and setup LHS - // (the LHS state comes just from the latest row of the LHS table) - DCHECK(!in[0]->Empty()); - const std::shared_ptr& lhs_latest_batch = in[0]->GetLatestBatch(); - row_index_t lhs_latest_row = in[0]->GetLatestRow(); - OnType lhs_latest_time = in[0]->GetLatestTime(); - if (0 == lhs_latest_row) { - // On the first row of the batch, we resize the destination. - // The destination size is dictated by the size of the LHS batch. - row_index_t new_batch_size = lhs_latest_batch->num_rows(); - row_index_t new_capacity = unmaterialized_table.Size() + new_batch_size; - if (unmaterialized_table.capacity() < new_capacity) { - unmaterialized_table.reserve(new_capacity); - } - } - - SliceBuilder new_row{&unmaterialized_table}; +// a specialized higher-performance variation of Hashing64 logic from hash_join_node +// the code here avoids recreating objects that are independent of each batch processed +class KeyHasher { + friend class AsofJoinNode; - // Each item represents a portion of the columns of the output table - new_row.AddEntry(lhs_latest_batch, lhs_latest_row, lhs_latest_row + 1); + static constexpr int kMiniBatchLength = arrow::util::MiniBatch::kMiniBatchLength; - DEBUG_SYNC(node_, "Emplace: key=", key, " lhs_latest_row=", lhs_latest_row, - " lhs_latest_time=", lhs_latest_time, DEBUG_MANIP(std::endl)); + public: + // the key hasher is not thread-safe and is only used in sequential batch processing + // of the input it is associated with + KeyHasher(size_t index, const std::vector& indices); - // Get the state for that key from all on the RHS -- assumes it's up to date - // (the RHS state comes from the memoized row references) - for (size_t i = 1; i < in.size(); ++i) { - std::optional opt_entry = in[i]->GetMemoEntryForKey(key); -#ifndef NDEBUG - { - bool has_entry = opt_entry.has_value(); - OnType entry_time = has_entry ? (*opt_entry)->time : TolType::kMinValue; - row_index_t entry_row = has_entry ? 
(*opt_entry)->row : 0; - bool accepted = has_entry && tolerance.Accepts(lhs_latest_time, entry_time); - DEBUG_SYNC(node_, " i=", i, " has_entry=", has_entry, " time=", entry_time, - " row=", entry_row, " accepted=", accepted, DEBUG_MANIP(std::endl)); - } -#endif - if (opt_entry.has_value()) { - DCHECK(*opt_entry); - if (tolerance.Accepts(lhs_latest_time, (*opt_entry)->time)) { - // Have a valid entry - const MemoStore::Entry* entry = *opt_entry; - new_row.AddEntry(entry->batch, entry->row, entry->row + 1); - continue; - } - } - new_row.AddEntry(nullptr, 0, 1); - } - new_row.Finalize(); - } + Status Init(ExecContext* exec_context, const std::shared_ptr& schema); - // Materializes the current reference table into a target record batch - Result>> Materialize() { - return unmaterialized_table.Materialize(); - } + // invalidate cached hashes for batch - required when it changes + // only this method can be called concurrently with HashesFor + void Invalidate(); - // Returns true if there are no rows - bool empty() const { return unmaterialized_table.Empty(); } + // compute and cache a hash for each row of the given batch + const std::vector& HashesFor(const RecordBatch* batch); private: - CompositeTable unmaterialized_table; - - // Total number of tables in the composite table - size_t n_tables_; - -#ifndef NDEBUG - // Owning node - AsofJoinNode* node_; -#endif - - static CompositeTable InitUnmaterializedTable( - const std::shared_ptr& schema, - const std::vector>& inputs, arrow::MemoryPool* pool) { - std::unordered_map> dst_to_src; - for (size_t i = 0; i < inputs.size(); i++) { - auto& input = inputs[i]; - for (int src = 0; src < input->get_schema()->num_fields(); src++) { - auto dst = input->MapSrcToDst(src); - if (dst.has_value()) { - dst_to_src[dst.value()] = std::make_pair(static_cast(i), src); - } - } - } - return CompositeTable{schema, inputs.size(), dst_to_src, pool}; - } + AsofJoinNode* node_ = nullptr; // avoids circular dependency during initialization + size_t index_; + std::vector indices_; + std::vector metadata_; + std::atomic batch_; + std::vector hashes_; + LightContext ctx_; + std::vector column_arrays_; + arrow::util::TempVectorStack stack_; }; // TODO: Currently, AsofJoinNode uses 64-bit hashing which leads to a non-negligible @@ -941,93 +601,16 @@ class AsofJoinNode : public ExecNode { bool any_advanced; bool all_up_to_date_with_lhs; }; + // Advances the RHS as far as possible to be up to date for the current LHS timestamp, // and checks if all RHS are up to date with LHS. The reason they have to be performed // together is that they both depend on the emptiness of the RHS, which can be changed // by Push() executing in another thread. - Result UpdateRhs() { - auto& lhs = *state_.at(0); - auto lhs_latest_time = lhs.GetLatestTime(); - RhsUpdateState update_state{/*any_advanced=*/false, /*all_up_to_date_with_lhs=*/true}; - for (size_t i = 1; i < state_.size(); ++i) { - auto& rhs = *state_[i]; - - // Obtain RHS emptiness once for subsequent AdvanceAndMemoize() and CurrentEmpty(). - bool rhs_empty = rhs.Empty(); - // Obtain RHS current time here because AdvanceAndMemoize() can change the - // emptiness. - OnType rhs_current_time = rhs_empty ? 
OnType{} : rhs.GetLatestTime(); - - ARROW_ASSIGN_OR_RAISE(bool advanced, - rhs.AdvanceAndMemoize(lhs_latest_time, rhs_empty)); - update_state.any_advanced |= advanced; - - if (update_state.all_up_to_date_with_lhs && !rhs.Finished()) { - // If RHS is finished, then we know it's up to date - if (rhs.CurrentEmpty(rhs_empty)) { - // RHS isn't finished, but is empty --> not up to date - update_state.all_up_to_date_with_lhs = false; - } else if (lhs_latest_time > rhs_current_time) { - // RHS isn't up to date (and not finished) - update_state.all_up_to_date_with_lhs = false; - } - } - } - return update_state; - } - - Result> ProcessInner() { - DCHECK(!state_.empty()); - auto& lhs = *state_.at(0); - - // Construct new target table if needed - CompositeTableBuilder dst(state_, output_schema_, - plan()->query_context()->memory_pool(), - DEBUG_ADD(state_.size(), this)); - - // Generate rows into the dst table until we either run out of data or hit the row - // limit, or run out of input - for (;;) { - // If LHS is finished or empty then there's nothing we can do here - if (lhs.Finished() || lhs.Empty()) break; - - ARROW_ASSIGN_OR_RAISE(auto rhs_update_state, UpdateRhs()); - - // If we have received enough inputs to produce the next output batch - // (decided by IsUpToDateWithLhsRow), we will perform the join and - // materialize the output batch. The join is done by advancing through - // the LHS and adding joined row to rows_ (done by Emplace). Finally, - // input batches that are no longer needed are removed to free up memory. - if (rhs_update_state.all_up_to_date_with_lhs) { - dst.Emplace(state_, tolerance_); - ARROW_ASSIGN_OR_RAISE(bool advanced, lhs.Advance()); - if (!advanced) break; // if we can't advance LHS, we're done for this batch - } else { - if (!rhs_update_state.any_advanced) break; // need to wait for new data - } - } + Result UpdateRhs(); - // Prune memo entries that have expired (to bound memory consumption) - if (!lhs.Empty()) { - for (size_t i = 1; i < state_.size(); ++i) { - OnType ts = tolerance_.Expiry(lhs.GetLatestTime()); - if (ts != TolType::kMinValue) { - state_[i]->RemoveMemoEntriesWithLesserTime(ts); - } - } - } - - // Emit the batch - if (dst.empty()) { - return NULLPTR; - } else { - ARROW_ASSIGN_OR_RAISE(auto out, dst.Materialize()); - return out.has_value() ? out.value() : NULLPTR; - } - } + Result> ProcessInner(); #ifdef ARROW_ENABLE_THREADING - template struct Defer { Callable callable; @@ -1035,85 +618,15 @@ class AsofJoinNode : public ExecNode { ~Defer() noexcept { callable(); } }; - void EndFromProcessThread(Status st = Status::OK()) { - // We must spawn a new task to transfer off the process thread when - // marking this finished. Otherwise there is a chance that doing so could - // mark the plan finished which may destroy the plan which will destroy this - // node which will cause us to join on ourselves. 
- ARROW_UNUSED( - plan_->query_context()->executor()->Spawn([this, st = std::move(st)]() mutable { - Defer cleanup([this, &st]() { process_task_.MarkFinished(st); }); - if (st.ok()) { - st = output_->InputFinished(this, batches_produced_); - } - for (const auto& s : state_) { - st &= s->ForceShutdown(); - } - })); - } - - bool CheckEnded() { - if (state_.at(0)->Finished()) { - EndFromProcessThread(); - return false; - } - return true; - } - - bool Process() { - std::lock_guard guard(gate_); - if (!CheckEnded()) { - return false; - } + void EndFromProcessThread(Status st = Status::OK()); - // Process batches while we have data - for (;;) { - Result> result = ProcessInner(); - - if (result.ok()) { - auto out_rb = *result; - if (!out_rb) break; - ExecBatch out_b(*out_rb); - out_b.index = batches_produced_++; - DEBUG_SYNC(this, "produce batch ", out_b.index, ":", DEBUG_MANIP(std::endl), - out_rb->ToString(), DEBUG_MANIP(std::endl)); - Status st = output_->InputReceived(this, std::move(out_b)); - if (!st.ok()) { - EndFromProcessThread(std::move(st)); - } - } else { - EndFromProcessThread(result.status()); - return false; - } - } - - // Report to the output the total batch count, if we've already finished everything - // (there are two places where this can happen: here and InputFinished) - // - // It may happen here in cases where InputFinished was called before we were finished - // producing results (so we didn't know the output size at that time) - if (!CheckEnded()) { - return false; - } + bool CheckEnded(); - // There is no more we can do now but there is still work remaining for later when - // more data arrives. - return true; - } + bool Process(); - void ProcessThread() { - for (;;) { - if (!process_.WaitAndPop()) { - EndFromProcessThread(); - return; - } - if (!Process()) { - return; - } - } - } + void ProcessThread(); - static void ProcessThreadWrapper(AsofJoinNode* node) { node->ProcessThread(); } + static void ProcessThreadWrapper(AsofJoinNode* node); #endif public: @@ -1124,117 +637,16 @@ class AsofJoinNode : public ExecNode { std::vector> key_hashers, bool must_hash, bool may_rehash); - Status Init() override { - auto inputs = this->inputs(); - for (size_t i = 0; i < inputs.size(); i++) { - RETURN_NOT_OK(key_hashers_[i]->Init(plan()->query_context()->exec_context(), - inputs[i]->output_schema())); - ARROW_ASSIGN_OR_RAISE( - auto input_state, - InputState::Make(i, tolerance_, must_hash_, may_rehash_, key_hashers_[i].get(), - inputs[i], this, backpressure_counter_, - inputs[i]->output_schema(), indices_of_on_key_[i], - indices_of_by_key_[i])); - state_.push_back(std::move(input_state)); - } - - col_index_t dst_offset = 0; - for (auto& state : state_) - dst_offset = state->InitSrcToDstMapping(dst_offset, !!dst_offset); + Status Init() override; - return Status::OK(); - } + virtual ~AsofJoinNode(); - virtual ~AsofJoinNode() { -#ifdef ARROW_ENABLE_THREADING - PushProcess(false); - if (process_thread_.joinable()) { - process_thread_.join(); - } -#endif - } + const std::vector& indices_of_on_key(); + const std::vector>& indices_of_by_key(); - const std::vector& indices_of_on_key() { return indices_of_on_key_; } - const std::vector>& indices_of_by_key() { - return indices_of_by_key_; - } - - static Status is_valid_on_field(const std::shared_ptr& field) { - switch (field->type()->id()) { - case Type::INT8: - case Type::INT16: - case Type::INT32: - case Type::INT64: - case Type::UINT8: - case Type::UINT16: - case Type::UINT32: - case Type::UINT64: - case Type::DATE32: - case Type::DATE64: - 
case Type::TIME32: - case Type::TIME64: - case Type::TIMESTAMP: - return Status::OK(); - default: - return Status::Invalid("Unsupported type for on-key ", field->name(), " : ", - field->type()->ToString()); - } - } - - static Status is_valid_by_field(const std::shared_ptr& field) { - switch (field->type()->id()) { - case Type::INT8: - case Type::INT16: - case Type::INT32: - case Type::INT64: - case Type::UINT8: - case Type::UINT16: - case Type::UINT32: - case Type::UINT64: - case Type::DATE32: - case Type::DATE64: - case Type::TIME32: - case Type::TIME64: - case Type::TIMESTAMP: - case Type::STRING: - case Type::LARGE_STRING: - case Type::BINARY: - case Type::LARGE_BINARY: - return Status::OK(); - default: - return Status::Invalid("Unsupported type for by-key ", field->name(), " : ", - field->type()->ToString()); - } - } - - static Status is_valid_data_field(const std::shared_ptr& field) { - switch (field->type()->id()) { - case Type::BOOL: - case Type::INT8: - case Type::INT16: - case Type::INT32: - case Type::INT64: - case Type::UINT8: - case Type::UINT16: - case Type::UINT32: - case Type::UINT64: - case Type::FLOAT: - case Type::DOUBLE: - case Type::DATE32: - case Type::DATE64: - case Type::TIME32: - case Type::TIME64: - case Type::TIMESTAMP: - case Type::STRING: - case Type::LARGE_STRING: - case Type::BINARY: - case Type::LARGE_BINARY: - return Status::OK(); - default: - return Status::Invalid("Unsupported type for data field ", field->name(), " : ", - field->type()->ToString()); - } - } + static Status is_valid_on_field(const std::shared_ptr& field); + static Status is_valid_by_field(const std::shared_ptr& field); + static Status is_valid_data_field(const std::shared_ptr& field); /// \brief Make the output schema of an as-of-join node /// @@ -1244,331 +656,1076 @@ class AsofJoinNode : public ExecNode { static arrow::Result> MakeOutputSchema( const std::vector> input_schema, const std::vector& indices_of_on_key, - const std::vector>& indices_of_by_key) { - std::vector> fields; - - size_t n_by = indices_of_by_key.size() == 0 ? 
0 : indices_of_by_key[0].size(); - const DataType* on_key_type = NULLPTR; - std::vector by_key_type(n_by, NULLPTR); - // Take all non-key, non-time RHS fields - for (size_t j = 0; j < input_schema.size(); ++j) { - const auto& on_field_ix = indices_of_on_key[j]; - const auto& by_field_ix = indices_of_by_key[j]; - - if ((on_field_ix == -1) || std_has(by_field_ix, -1)) { - return Status::Invalid("Missing join key on table ", j); - } + const std::vector>& indices_of_by_key); - const auto& on_field = input_schema[j]->fields()[on_field_ix]; - std::vector by_field(n_by); - for (size_t k = 0; k < n_by; k++) { - by_field[k] = input_schema[j]->fields()[by_field_ix[k]].get(); - } + static inline Result FindColIndex(const Schema& schema, + const FieldRef& field_ref, + std::string_view key_kind); - if (on_key_type == NULLPTR) { - on_key_type = on_field->type().get(); - } else if (*on_key_type != *on_field->type()) { - return Status::Invalid("Expected on-key type ", *on_key_type, " but got ", - *on_field->type(), " for field ", on_field->name(), - " in input ", j); - } - for (size_t k = 0; k < n_by; k++) { - if (by_key_type[k] == NULLPTR) { - by_key_type[k] = by_field[k]->type().get(); - } else if (*by_key_type[k] != *by_field[k]->type()) { - return Status::Invalid("Expected by-key type ", *by_key_type[k], " but got ", - *by_field[k]->type(), " for field ", by_field[k]->name(), - " in input ", j); - } + static Result GetByKeySize( + const std::vector& input_keys); + + static Result> GetIndicesOfOnKey( + const std::vector>& input_schema, + const std::vector& input_keys); + + static Result>> GetIndicesOfByKey( + const std::vector>& input_schema, + const std::vector& input_keys); + + static arrow::Result Make(ExecPlan* plan, std::vector inputs, + const ExecNodeOptions& options); + + const char* kind_name() const override; + const Ordering& ordering() const override; + + Status InputReceived(ExecNode* input, ExecBatch batch) override; + + Status InputFinished(ExecNode* input, int total_batches) override; + + void PushProcess(bool value); + +#ifdef ARROW_ENABLE_THREADING + bool ProcessNonThreaded(); + void EndFromSingleThread(Status st = Status::OK()); +#endif + + Status StartProducing() override; + + void PauseProducing(ExecNode* output, int32_t counter) override; + void ResumeProducing(ExecNode* output, int32_t counter) override; + + Status StopProducingImpl() override; + +#ifndef NDEBUG + std::ostream* GetDebugStream(); + + std::mutex* GetDebugMutex(); +#endif + + private: + // Outputs from this node are always in ascending order according to the on key + const Ordering ordering_; + std::vector indices_of_on_key_; + std::vector> indices_of_by_key_; + std::vector> key_hashers_; + bool must_hash_; + bool may_rehash_; + // InputStates + // Each input state corresponds to an input table + std::vector> state_; + std::mutex gate_; + TolType tolerance_; +#ifndef NDEBUG + std::ostream* debug_os_; + std::mutex* debug_mutex_; +#endif + + // Backpressure counter common to all inputs + std::atomic backpressure_counter_; +#ifdef ARROW_ENABLE_THREADING + // Queue for triggering processing of a given input + // (a false value is a poison pill) + ConcurrentQueue process_; + // Worker thread + std::thread process_thread_; +#endif + Future<> process_task_; + + // In-progress batches produced + int batches_produced_ = 0; +}; + +InputState::InputState(size_t index, TolType tolerance, bool must_hash, bool may_rehash, + KeyHasher* key_hasher, AsofJoinNode* node, + BackpressureHandler handler, + const std::shared_ptr& schema, + 
const col_index_t time_col_index, + const std::vector& key_col_index) + : sequencer_(util::SerialSequencingQueue::Make(this)), + queue_(std::move(handler)), + schema_(schema), + time_col_index_(time_col_index), + key_col_index_(key_col_index), + time_type_id_(schema_->fields()[time_col_index_]->type()->id()), + key_type_id_(key_col_index.size()), + key_hasher_(key_hasher), + node_(node), + index_(index), + must_hash_(must_hash), + may_rehash_(may_rehash), + tolerance_(tolerance), + memo_(DEBUG_ADD(/*no_future=*/index == 0 || !tolerance.positive, node, index)) { + for (size_t k = 0; k < key_col_index_.size(); k++) { + key_type_id_[k] = schema_->fields()[key_col_index_[k]]->type()->id(); + } +} + +Result> InputState::Make( + size_t index, TolType tolerance, bool must_hash, bool may_rehash, + KeyHasher* key_hasher, ExecNode* asof_input, AsofJoinNode* asof_node, + std::atomic& backpressure_counter, + const std::shared_ptr& schema, const col_index_t time_col_index, + const std::vector& key_col_index) { + constexpr size_t low_threshold = 4, high_threshold = 8; + std::unique_ptr backpressure_control = + std::make_unique( + /*node=*/asof_input, /*output=*/asof_node, backpressure_counter); + ARROW_ASSIGN_OR_RAISE( + auto handler, BackpressureHandler::Make(asof_input, low_threshold, high_threshold, + std::move(backpressure_control))); + return std::make_unique(index, tolerance, must_hash, may_rehash, key_hasher, + asof_node, std::move(handler), schema, + time_col_index, key_col_index); +} + +col_index_t InputState::InitSrcToDstMapping(col_index_t dst_offset, + bool skip_time_and_key_fields) { + src_to_dst_.resize(schema_->num_fields()); + for (int i = 0; i < schema_->num_fields(); ++i) + if (!(skip_time_and_key_fields && IsTimeOrKeyColumn(i))) + src_to_dst_[i] = dst_offset++; + return dst_offset; +} + +const std::optional& InputState::MapSrcToDst(col_index_t src) const { + return src_to_dst_[src]; +} + +bool InputState::IsTimeOrKeyColumn(col_index_t i) const { + DCHECK_LT(i, schema_->num_fields()); + return (i == time_col_index_) || std_has(key_col_index_, i); +} + +row_index_t InputState::GetLatestRow() const { return latest_ref_row_; } + +bool InputState::Empty() const { + // cannot be empty if ref row is >0 -- can avoid slow queue lock + // below + if (latest_ref_row_ > 0) return false; + return queue_.Empty(); +} + +bool InputState::CurrentEmpty(bool empty) const { + return memo_.no_future_ ? empty : (memo_.times_.empty() && empty); +} + +OnType InputState::GetCurrentTime() const { + return memo_.no_future_ ? GetLatestTime() : static_cast(memo_.current_time_); +} + +inline ByType InputState::GetLatestKey() const { + return GetKey(GetLatestBatch().get(), latest_ref_row_); +} + +#define LATEST_VAL_CASE(id, val) \ + case Type::id: { \ + using T = typename TypeIdTraits::Type; \ + using CType = typename TypeTraits::CType; \ + return val(data->GetValues(1)[row]); \ + } + +inline ByType InputState::GetKey(const RecordBatch* batch, row_index_t row) const { + if (must_hash_) { + // Query the key hasher. This may hit cache, which must be valid for the batch. + // Therefore, the key hasher is invalidated when a new batch is pushed - see + // `InputState::Push`. 
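+ // When hashing is in use, the 64-bit hash itself serves as the by-key value.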
+ return key_hasher_->HashesFor(batch)[row]; + } + if (key_col_index_.size() == 0) { + return 0; + } + auto data = batch->column_data(key_col_index_[0]); + switch (key_type_id_[0]) { + LATEST_VAL_CASE(INT8, key_value) + LATEST_VAL_CASE(INT16, key_value) + LATEST_VAL_CASE(INT32, key_value) + LATEST_VAL_CASE(INT64, key_value) + LATEST_VAL_CASE(UINT8, key_value) + LATEST_VAL_CASE(UINT16, key_value) + LATEST_VAL_CASE(UINT32, key_value) + LATEST_VAL_CASE(UINT64, key_value) + LATEST_VAL_CASE(DATE32, key_value) + LATEST_VAL_CASE(DATE64, key_value) + LATEST_VAL_CASE(TIME32, key_value) + LATEST_VAL_CASE(TIME64, key_value) + LATEST_VAL_CASE(TIMESTAMP, key_value) + default: + DCHECK(false); + return 0; // cannot happen + } +} + +#undef LATEST_VAL_CASE + +inline OnType InputState::GetLatestTime() const { + return GetTime(GetLatestBatch().get(), time_type_id_, time_col_index_, latest_ref_row_); +} + +int InputState::total_batches() const { return total_batches_; } + +Result InputState::AdvanceAndMemoize(OnType ts, bool empty) { + // Advance the right side row index until we reach the latest right row (for each key) + // for the given left timestamp. + DEBUG_SYNC(node_, "Advancing input ", index_, DEBUG_MANIP(std::endl)); + + // Check if already updated for TS (or if there is no latest) + if (empty) { // can't advance if empty and no future entries + return memo_.no_future_ ? false : memo_.RemoveEntriesWithLesserTime(ts); + } + + // Not updated. Try to update and possibly advance. + bool advanced, updated = false; + OnType latest_time; + do { + latest_time = GetLatestTime(); + // if Advance() returns true, then the latest_ts must also be valid + // Keep advancing right table until we hit the latest row that has + // timestamp <= ts. This is because we only need the latest row for the + // match given a left ts. 
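+ // (With a positive tolerance, the horizon extends beyond ts, so rows up to it are memoized as future entries.)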
+ if (latest_time > tolerance_.Horizon(ts)) { // hit a distant timestamp + DEBUG_SYNC(node_, "Advancing input ", index_, " hit distant time=", latest_time, + " at=", ts, DEBUG_MANIP(std::endl)); + // if no future entries, which would have been earlier than the distant time, no + // need to queue it + if (memo_.future_entries_.empty()) break; + } + auto rb = GetLatestBatch(); + if (may_rehash_ && rb->column_data(key_col_index_[0])->GetNullCount() > 0) { + must_hash_ = true; + may_rehash_ = false; + Rehash(); + } + memo_.Store(rb, latest_ref_row_, latest_time, DEBUG_ADD(GetLatestKey(), ts)); + // negative tolerance means a last-known entry was stored - set `updated` to `true` + updated = memo_.no_future_; + ARROW_ASSIGN_OR_RAISE(advanced, Advance()); + } while (advanced); + if (!memo_.no_future_ && latest_time >= ts) { + // `updated` was not modified in the loop from the initial `false` value; set it now + updated = memo_.RemoveEntriesWithLesserTime(ts); + } + DEBUG_SYNC(node_, "Advancing input ", index_, " updated=", updated, + DEBUG_MANIP(std::endl)); + return updated; +} + +Status InputState::InsertBatch(ExecBatch batch) { + return sequencer_->InsertBatch(std::move(batch)); +} + +Status InputState::Process(ExecBatch batch) { + auto rb = *batch.ToRecordBatch(schema_); + DEBUG_SYNC(node_, "received batch from input ", index_, ":", DEBUG_MANIP(std::endl), + rb->ToString(), DEBUG_MANIP(std::endl)); + return Push(rb); +} + +void InputState::Rehash() { + DEBUG_SYNC(node_, "rehashing for input ", index_, ":", DEBUG_MANIP(std::endl)); + MemoStore new_memo(DEBUG_ADD(memo_.no_future_, node_, index_)); + new_memo.current_time_ = (OnType)memo_.current_time_; + for (auto e = memo_.entries_.begin(); e != memo_.entries_.end(); ++e) { + auto& entry = e->second; + auto new_key = GetKey(entry.batch.get(), entry.row); + DEBUG_SYNC(node_, " ", e->first, " to ", new_key, DEBUG_MANIP(std::endl)); + new_memo.entries_[new_key].swap(entry); + auto fe = memo_.future_entries_.find(e->first); + if (fe != memo_.future_entries_.end()) { + new_memo.future_entries_[new_key].swap(fe->second); + } + } + memo_.times_.swap(new_memo.times_); + memo_.swap(new_memo); +} + +Status InputState::Push(const std::shared_ptr& rb) { + if (rb->num_rows() > 0) { + key_hasher_->Invalidate(); // batch changed - invalidate key hasher's cache + queue_.Push(rb); // only now push batch for processing + } else { + ++batches_processed_; // don't enqueue empty batches, just record as processed + } + return Status::OK(); +} + +std::optional InputState::GetMemoEntryForKey(ByType key) { + return memo_.GetEntryForKey(key); +} + +std::optional InputState::GetMemoTimeForKey(ByType key) { + auto r = GetMemoEntryForKey(key); + if (r.has_value()) { + return (*r)->time; + } else { + return std::nullopt; + } +} + +void InputState::RemoveMemoEntriesWithLesserTime(OnType ts) { + memo_.RemoveEntriesWithLesserTime(ts); +} + +const std::shared_ptr& InputState::get_schema() const { return schema_; } + +void InputState::set_total_batches(int n) { + DCHECK_GE(n, 0); + DCHECK_EQ(total_batches_, -1) << "Set total batch more than once"; + total_batches_ = n; +} + +Status InputState::ForceShutdown() { + // Force the upstream input node to unpause. 
Necessary to avoid deadlock when we + // terminate the process thread + return queue_.ForceShutdown(); +} + +KeyHasher::KeyHasher(size_t index, const std::vector& indices) + : index_(index), + indices_(indices), + metadata_(indices.size()), + batch_(NULLPTR), + hashes_(), + ctx_(), + column_arrays_(), + stack_() { + ctx_.stack = &stack_; + column_arrays_.resize(indices.size()); +} + +Status KeyHasher::Init(ExecContext* exec_context, + const std::shared_ptr& schema) { + ctx_.hardware_flags = exec_context->cpu_info()->hardware_flags(); + const auto& fields = schema->fields(); + for (size_t k = 0; k < metadata_.size(); k++) { + ARROW_ASSIGN_OR_RAISE(metadata_[k], + ColumnMetadataFromDataType(fields[indices_[k]]->type())); + } + return stack_.Init(exec_context->memory_pool(), + 4 * kMiniBatchLength * sizeof(uint32_t)); +} + +void KeyHasher::Invalidate() { batch_ = NULLPTR; } + +// compute and cache a hash for each row of the given batch +const std::vector& KeyHasher::HashesFor(const RecordBatch* batch) { + if (batch_ == batch) { + return hashes_; // cache hit - return cached hashes + } + Invalidate(); + size_t batch_length = batch->num_rows(); + hashes_.resize(batch_length); + for (int64_t i = 0; i < static_cast(batch_length); i += kMiniBatchLength) { + int64_t length = std::min(static_cast(batch_length - i), + static_cast(kMiniBatchLength)); + for (size_t k = 0; k < indices_.size(); k++) { + auto array_data = batch->column_data(indices_[k]); + column_arrays_[k] = + ColumnArrayFromArrayDataAndMetadata(array_data, metadata_[k], i, length); + } + // write directly to the cache + Hashing64::HashMultiColumn(column_arrays_, &ctx_, hashes_.data() + i); + } + DEBUG_SYNC(node_, "key hasher ", index_, " got hashes ", + compute::internal::GenericToString(hashes_), DEBUG_MANIP(std::endl)); + batch_ = batch; // associate cache with current batch + return hashes_; +} + +Result AsofJoinNode::UpdateRhs() { + auto& lhs = *state_.at(0); + auto lhs_latest_time = lhs.GetLatestTime(); + RhsUpdateState update_state{/*any_advanced=*/false, /*all_up_to_date_with_lhs=*/true}; + for (size_t i = 1; i < state_.size(); ++i) { + auto& rhs = *state_[i]; + + // Obtain RHS emptiness once for subsequent AdvanceAndMemoize() and CurrentEmpty(). + bool rhs_empty = rhs.Empty(); + // Obtain RHS current time here because AdvanceAndMemoize() can change the + // emptiness. + OnType rhs_current_time = rhs_empty ? 
OnType{} : rhs.GetLatestTime(); + + ARROW_ASSIGN_OR_RAISE(bool advanced, + rhs.AdvanceAndMemoize(lhs_latest_time, rhs_empty)); + update_state.any_advanced |= advanced; + + if (update_state.all_up_to_date_with_lhs && !rhs.Finished()) { + // If RHS is finished, then we know it's up to date + if (rhs.CurrentEmpty(rhs_empty)) { + // RHS isn't finished, but is empty --> not up to date + update_state.all_up_to_date_with_lhs = false; + } else if (lhs_latest_time > rhs_current_time) { + // RHS isn't up to date (and not finished) + update_state.all_up_to_date_with_lhs = false; } + } + } + return update_state; +} - for (int i = 0; i < input_schema[j]->num_fields(); ++i) { - const auto field = input_schema[j]->field(i); - bool as_output; // true if the field appears as an output - if (i == on_field_ix) { - ARROW_RETURN_NOT_OK(is_valid_on_field(field)); - // Only add on field from the left table - as_output = (j == 0); - } else if (std_has(by_field_ix, i)) { - ARROW_RETURN_NOT_OK(is_valid_by_field(field)); - // Only add by field from the left table - as_output = (j == 0); - } else { - ARROW_RETURN_NOT_OK(is_valid_data_field(field)); - as_output = true; +#ifdef ARROW_ENABLE_THREADING +void AsofJoinNode::EndFromProcessThread(Status st) { + // We must spawn a new task to transfer off the process thread when + // marking this finished. Otherwise there is a chance that doing so could + // mark the plan finished which may destroy the plan which will destroy this + // node which will cause us to join on ourselves. + ARROW_UNUSED( + plan_->query_context()->executor()->Spawn([this, st = std::move(st)]() mutable { + Defer cleanup([this, &st]() { process_task_.MarkFinished(st); }); + if (st.ok()) { + st = output_->InputFinished(this, batches_produced_); } - if (as_output) { - fields.push_back(field); + for (const auto& s : state_) { + st &= s->ForceShutdown(); } + })); +} + +bool AsofJoinNode::CheckEnded() { + if (state_.at(0)->Finished()) { + EndFromProcessThread(); + return false; + } + return true; +} + +bool AsofJoinNode::Process() { + std::lock_guard guard(gate_); + if (!CheckEnded()) { + return false; + } + + // Process batches while we have data + for (;;) { + Result> result = ProcessInner(); + + if (result.ok()) { + auto out_rb = *result; + if (!out_rb) break; + ExecBatch out_b(*out_rb); + out_b.index = batches_produced_++; + DEBUG_SYNC(this, "produce batch ", out_b.index, ":", DEBUG_MANIP(std::endl), + out_rb->ToString(), DEBUG_MANIP(std::endl)); + Status st = output_->InputReceived(this, std::move(out_b)); + if (!st.ok()) { + EndFromProcessThread(std::move(st)); } + } else { + EndFromProcessThread(result.status()); + return false; } - return std::make_shared(fields); } - static inline Result FindColIndex(const Schema& schema, - const FieldRef& field_ref, - std::string_view key_kind) { - auto match_res = field_ref.FindOne(schema); - if (!match_res.ok()) { - return Status::Invalid("Bad join key on table : ", match_res.status().message()); + // Report to the output the total batch count, if we've already finished everything + // (there are two places where this can happen: here and InputFinished) + // + // It may happen here in cases where InputFinished was called before we were finished + // producing results (so we didn't know the output size at that time) + if (!CheckEnded()) { + return false; + } + + // There is no more we can do now but there is still work remaining for later when + // more data arrives. 
+ return true; +} + +void AsofJoinNode::ProcessThread() { + for (;;) { + if (!process_.WaitAndPop()) { + EndFromProcessThread(); + return; } - ARROW_ASSIGN_OR_RAISE(auto match, match_res); - if (match.indices().size() != 1) { - return Status::Invalid("AsOfJoinNode does not support a nested ", key_kind, "-key ", - field_ref.ToString()); + if (!Process()) { + return; } - return match.indices()[0]; } +} - static Result GetByKeySize( - const std::vector& input_keys) { - size_t n_by = 0; - for (size_t i = 0; i < input_keys.size(); ++i) { - const auto& by_key = input_keys[i].by_key; - if (i == 0) { - n_by = by_key.size(); - } else if (n_by != by_key.size()) { - return Status::Invalid("inconsistent size of by-key across inputs"); - } - } - return n_by; +void AsofJoinNode::ProcessThreadWrapper(AsofJoinNode* node) { node->ProcessThread(); } +#endif + +Status AsofJoinNode::Init() { + auto inputs = this->inputs(); + for (size_t i = 0; i < inputs.size(); i++) { + RETURN_NOT_OK(key_hashers_[i]->Init(plan()->query_context()->exec_context(), + inputs[i]->output_schema())); + ARROW_ASSIGN_OR_RAISE( + auto input_state, + InputState::Make(i, tolerance_, must_hash_, may_rehash_, key_hashers_[i].get(), + inputs[i], this, backpressure_counter_, + inputs[i]->output_schema(), indices_of_on_key_[i], + indices_of_by_key_[i])); + state_.push_back(std::move(input_state)); } - static Result> GetIndicesOfOnKey( - const std::vector>& input_schema, - const std::vector& input_keys) { - if (input_schema.size() != input_keys.size()) { - return Status::Invalid("mismatching number of input schema and keys"); + col_index_t dst_offset = 0; + for (auto& state : state_) + dst_offset = state->InitSrcToDstMapping(dst_offset, !!dst_offset); + + return Status::OK(); +} + +AsofJoinNode::~AsofJoinNode() { +#ifdef ARROW_ENABLE_THREADING + PushProcess(false); + if (process_thread_.joinable()) { + process_thread_.join(); + } +#endif +} + +const std::vector& AsofJoinNode::indices_of_on_key() { + return indices_of_on_key_; +} +const std::vector>& AsofJoinNode::indices_of_by_key() { + return indices_of_by_key_; +} + +Status AsofJoinNode::is_valid_on_field(const std::shared_ptr& field) { + switch (field->type()->id()) { + case Type::INT8: + case Type::INT16: + case Type::INT32: + case Type::INT64: + case Type::UINT8: + case Type::UINT16: + case Type::UINT32: + case Type::UINT64: + case Type::DATE32: + case Type::DATE64: + case Type::TIME32: + case Type::TIME64: + case Type::TIMESTAMP: + return Status::OK(); + default: + return Status::Invalid("Unsupported type for on-key ", field->name(), " : ", + field->type()->ToString()); + } +} + +Status AsofJoinNode::is_valid_by_field(const std::shared_ptr& field) { + switch (field->type()->id()) { + case Type::INT8: + case Type::INT16: + case Type::INT32: + case Type::INT64: + case Type::UINT8: + case Type::UINT16: + case Type::UINT32: + case Type::UINT64: + case Type::DATE32: + case Type::DATE64: + case Type::TIME32: + case Type::TIME64: + case Type::TIMESTAMP: + case Type::STRING: + case Type::LARGE_STRING: + case Type::BINARY: + case Type::LARGE_BINARY: + return Status::OK(); + default: + return Status::Invalid("Unsupported type for by-key ", field->name(), " : ", + field->type()->ToString()); + } +} + +Status AsofJoinNode::is_valid_data_field(const std::shared_ptr& field) { + switch (field->type()->id()) { + case Type::BOOL: + case Type::INT8: + case Type::INT16: + case Type::INT32: + case Type::INT64: + case Type::UINT8: + case Type::UINT16: + case Type::UINT32: + case Type::UINT64: + case 
Type::FLOAT: + case Type::DOUBLE: + case Type::DATE32: + case Type::DATE64: + case Type::TIME32: + case Type::TIME64: + case Type::TIMESTAMP: + case Type::STRING: + case Type::LARGE_STRING: + case Type::BINARY: + case Type::LARGE_BINARY: + return Status::OK(); + default: + return Status::Invalid("Unsupported type for data field ", field->name(), " : ", + field->type()->ToString()); + } +} + +/// \brief Make the output schema of an as-of-join node +/// +/// \param[in] input_schema the schema of each input to the node +/// \param[in] indices_of_on_key the on-key index of each input to the node +/// \param[in] indices_of_by_key the by-key indices of each input to the node +arrow::Result> AsofJoinNode::MakeOutputSchema( + const std::vector> input_schema, + const std::vector& indices_of_on_key, + const std::vector>& indices_of_by_key) { + std::vector> fields; + + size_t n_by = indices_of_by_key.size() == 0 ? 0 : indices_of_by_key[0].size(); + const DataType* on_key_type = NULLPTR; + std::vector by_key_type(n_by, NULLPTR); + // Take all non-key, non-time RHS fields + for (size_t j = 0; j < input_schema.size(); ++j) { + const auto& on_field_ix = indices_of_on_key[j]; + const auto& by_field_ix = indices_of_by_key[j]; + + if ((on_field_ix == -1) || std_has(by_field_ix, -1)) { + return Status::Invalid("Missing join key on table ", j); } - size_t n_input = input_schema.size(); - std::vector indices_of_on_key(n_input); - for (size_t i = 0; i < n_input; ++i) { - const auto& on_key = input_keys[i].on_key; - ARROW_ASSIGN_OR_RAISE(indices_of_on_key[i], - FindColIndex(*input_schema[i], on_key, "on")); + + const auto& on_field = input_schema[j]->fields()[on_field_ix]; + std::vector by_field(n_by); + for (size_t k = 0; k < n_by; k++) { + by_field[k] = input_schema[j]->fields()[by_field_ix[k]].get(); } - return indices_of_on_key; - } - static Result>> GetIndicesOfByKey( - const std::vector>& input_schema, - const std::vector& input_keys) { - if (input_schema.size() != input_keys.size()) { - return Status::Invalid("mismatching number of input schema and keys"); + if (on_key_type == NULLPTR) { + on_key_type = on_field->type().get(); + } else if (*on_key_type != *on_field->type()) { + return Status::Invalid("Expected on-key type ", *on_key_type, " but got ", + *on_field->type(), " for field ", on_field->name(), + " in input ", j); } - ARROW_ASSIGN_OR_RAISE(size_t n_by, GetByKeySize(input_keys)); - size_t n_input = input_schema.size(); - std::vector> indices_of_by_key( - n_input, std::vector(n_by)); - for (size_t i = 0; i < n_input; ++i) { - for (size_t k = 0; k < n_by; k++) { - const auto& by_key = input_keys[i].by_key; - ARROW_ASSIGN_OR_RAISE(indices_of_by_key[i][k], - FindColIndex(*input_schema[i], by_key[k], "by")); + for (size_t k = 0; k < n_by; k++) { + if (by_key_type[k] == NULLPTR) { + by_key_type[k] = by_field[k]->type().get(); + } else if (*by_key_type[k] != *by_field[k]->type()) { + return Status::Invalid("Expected by-key type ", *by_key_type[k], " but got ", + *by_field[k]->type(), " for field ", by_field[k]->name(), + " in input ", j); } } - return indices_of_by_key; - } - static arrow::Result Make(ExecPlan* plan, std::vector inputs, - const ExecNodeOptions& options) { - DCHECK_GE(inputs.size(), 2) << "Must have at least two inputs"; - const auto& join_options = checked_cast(options); - ARROW_ASSIGN_OR_RAISE(size_t n_by, GetByKeySize(join_options.input_keys)); - size_t n_input = inputs.size(); - std::vector input_labels(n_input); - std::vector> input_schema(n_input); - for (size_t i = 0; i < n_input; 
-      input_labels[i] = i == 0 ? "left" : "right_" + ToChars(i);
-      input_schema[i] = inputs[i]->output_schema();
+    for (int i = 0; i < input_schema[j]->num_fields(); ++i) {
+      const auto field = input_schema[j]->field(i);
+      bool as_output;  // true if the field appears as an output
+      if (i == on_field_ix) {
+        ARROW_RETURN_NOT_OK(is_valid_on_field(field));
+        // Only add on field from the left table
+        as_output = (j == 0);
+      } else if (std_has(by_field_ix, i)) {
+        ARROW_RETURN_NOT_OK(is_valid_by_field(field));
+        // Only add by field from the left table
+        as_output = (j == 0);
+      } else {
+        ARROW_RETURN_NOT_OK(is_valid_data_field(field));
+        as_output = true;
+      }
+      if (as_output) {
+        fields.push_back(field);
+      }
     }
-    ARROW_ASSIGN_OR_RAISE(std::vector<col_index_t> indices_of_on_key,
-                          GetIndicesOfOnKey(input_schema, join_options.input_keys));
-    ARROW_ASSIGN_OR_RAISE(std::vector<std::vector<col_index_t>> indices_of_by_key,
-                          GetIndicesOfByKey(input_schema, join_options.input_keys));
-    ARROW_ASSIGN_OR_RAISE(
-        std::shared_ptr<Schema> output_schema,
-        MakeOutputSchema(input_schema, indices_of_on_key, indices_of_by_key));
+  }
+  return std::make_shared<Schema>(fields);
+}
 
-    std::vector<std::unique_ptr<KeyHasher>> key_hashers;
-    for (size_t i = 0; i < n_input; i++) {
-      key_hashers.push_back(std::make_unique<KeyHasher>(i, indices_of_by_key[i]));
+inline Result<col_index_t> AsofJoinNode::FindColIndex(const Schema& schema,
+                                                      const FieldRef& field_ref,
+                                                      std::string_view key_kind) {
+  auto match_res = field_ref.FindOne(schema);
+  if (!match_res.ok()) {
+    return Status::Invalid("Bad join key on table : ", match_res.status().message());
+  }
+  ARROW_ASSIGN_OR_RAISE(auto match, match_res);
+  if (match.indices().size() != 1) {
+    return Status::Invalid("AsOfJoinNode does not support a nested ", key_kind, "-key ",
+                           field_ref.ToString());
+  }
+  return match.indices()[0];
+}
+
+Result<size_t> AsofJoinNode::GetByKeySize(
+    const std::vector<AsofJoinNodeOptions::Keys>& input_keys) {
+  size_t n_by = 0;
+  for (size_t i = 0; i < input_keys.size(); ++i) {
+    const auto& by_key = input_keys[i].by_key;
+    if (i == 0) {
+      n_by = by_key.size();
+    } else if (n_by != by_key.size()) {
+      return Status::Invalid("inconsistent size of by-key across inputs");
     }
-    bool must_hash =
-        n_by > 1 ||
-        (n_by == 1 &&
-         !is_primitive(
-             inputs[0]->output_schema()->field(indices_of_by_key[0][0])->type()->id()));
-    bool may_rehash = n_by == 1 && !must_hash;
-    return plan->EmplaceNode<AsofJoinNode>(
-        plan, inputs, std::move(input_labels), std::move(indices_of_on_key),
-        std::move(indices_of_by_key), std::move(join_options), std::move(output_schema),
-        std::move(key_hashers), must_hash, may_rehash);
-  }
-
-  const char* kind_name() const override { return "AsofJoinNode"; }
-  const Ordering& ordering() const override { return ordering_; }
-
-  Status InputReceived(ExecNode* input, ExecBatch batch) override {
-    // InputReceived may be called after execution was finished. Pushing it to the
-    // InputState is unnecessary since we're done (and anyway may cause the
-    // BackPressureController to pause the input, causing a deadlock), so drop it.
-    if (::arrow::compute::kUnsequencedIndex == batch.index)
-      return Status::Invalid("AsofJoin requires sequenced input");
-
-    if (process_task_.is_finished()) {
-      DEBUG_SYNC(this, "Input received while done. Short circuiting.",
Short circuiting.", - DEBUG_MANIP(std::endl)); - return Status::OK(); + } + return n_by; +} + +Result> AsofJoinNode::GetIndicesOfOnKey( + const std::vector>& input_schema, + const std::vector& input_keys) { + if (input_schema.size() != input_keys.size()) { + return Status::Invalid("mismatching number of input schema and keys"); + } + size_t n_input = input_schema.size(); + std::vector indices_of_on_key(n_input); + for (size_t i = 0; i < n_input; ++i) { + const auto& on_key = input_keys[i].on_key; + ARROW_ASSIGN_OR_RAISE(indices_of_on_key[i], + FindColIndex(*input_schema[i], on_key, "on")); + } + return indices_of_on_key; +} + +Result>> AsofJoinNode::GetIndicesOfByKey( + const std::vector>& input_schema, + const std::vector& input_keys) { + if (input_schema.size() != input_keys.size()) { + return Status::Invalid("mismatching number of input schema and keys"); + } + ARROW_ASSIGN_OR_RAISE(size_t n_by, GetByKeySize(input_keys)); + size_t n_input = input_schema.size(); + std::vector> indices_of_by_key(n_input, + std::vector(n_by)); + for (size_t i = 0; i < n_input; ++i) { + for (size_t k = 0; k < n_by; k++) { + const auto& by_key = input_keys[i].by_key; + ARROW_ASSIGN_OR_RAISE(indices_of_by_key[i][k], + FindColIndex(*input_schema[i], by_key[k], "by")); } + } + return indices_of_by_key; +} - // Get the input - ARROW_DCHECK(std_has(inputs_, input)); - size_t k = std_find(inputs_, input) - inputs_.begin(); +arrow::Result AsofJoinNode::Make(ExecPlan* plan, std::vector inputs, + const ExecNodeOptions& options) { + DCHECK_GE(inputs.size(), 2) << "Must have at least two inputs"; + const auto& join_options = checked_cast(options); + ARROW_ASSIGN_OR_RAISE(size_t n_by, GetByKeySize(join_options.input_keys)); + size_t n_input = inputs.size(); + std::vector input_labels(n_input); + std::vector> input_schema(n_input); + for (size_t i = 0; i < n_input; ++i) { + input_labels[i] = i == 0 ? "left" : "right_" + ToChars(i); + input_schema[i] = inputs[i]->output_schema(); + } + ARROW_ASSIGN_OR_RAISE(std::vector indices_of_on_key, + GetIndicesOfOnKey(input_schema, join_options.input_keys)); + ARROW_ASSIGN_OR_RAISE(std::vector> indices_of_by_key, + GetIndicesOfByKey(input_schema, join_options.input_keys)); + ARROW_ASSIGN_OR_RAISE( + std::shared_ptr output_schema, + MakeOutputSchema(input_schema, indices_of_on_key, indices_of_by_key)); + + std::vector> key_hashers; + for (size_t i = 0; i < n_input; i++) { + key_hashers.push_back(std::make_unique(i, indices_of_by_key[i])); + } + bool must_hash = + n_by > 1 || + (n_by == 1 && + !is_primitive( + inputs[0]->output_schema()->field(indices_of_by_key[0][0])->type()->id())); + bool may_rehash = n_by == 1 && !must_hash; + return plan->EmplaceNode( + plan, inputs, std::move(input_labels), std::move(indices_of_on_key), + std::move(indices_of_by_key), std::move(join_options), std::move(output_schema), + std::move(key_hashers), must_hash, may_rehash); +} - // Put into the sequencing queue - ARROW_RETURN_NOT_OK(state_.at(k)->InsertBatch(std::move(batch))); +const char* AsofJoinNode::kind_name() const { return "AsofJoinNode"; } +const Ordering& AsofJoinNode::ordering() const { return ordering_; } - PushProcess(true); +Status AsofJoinNode::InputReceived(ExecNode* input, ExecBatch batch) { + // InputReceived may be called after execution was finished. Pushing it to the + // InputState is unnecessary since we're done (and anyway may cause the + // BackPressureController to pause the input, causing a deadlock), so drop it. 
+  if (::arrow::compute::kUnsequencedIndex == batch.index)
+    return Status::Invalid("AsofJoin requires sequenced input");
+
+  if (process_task_.is_finished()) {
+    DEBUG_SYNC(this, "Input received while done. Short circuiting.",
+               DEBUG_MANIP(std::endl));
     return Status::OK();
   }
 
-  Status InputFinished(ExecNode* input, int total_batches) override {
-    {
-      std::lock_guard<std::mutex> guard(gate_);
-      ARROW_DCHECK(std_has(inputs_, input));
-      size_t k = std_find(inputs_, input) - inputs_.begin();
-      state_.at(k)->set_total_batches(total_batches);
-    }
-    // Trigger a process call
-    // The reason for this is that there are cases at the end of a table where we don't
-    // know whether the RHS of the join is up-to-date until we know that the table is
-    // finished.
-    PushProcess(true);
+  // Get the input
+  ARROW_DCHECK(std_has(inputs_, input));
+  size_t k = std_find(inputs_, input) - inputs_.begin();
 
-    return Status::OK();
+  // Put into the sequencing queue
+  ARROW_RETURN_NOT_OK(state_.at(k)->InsertBatch(std::move(batch)));
+
+  PushProcess(true);
+
+  return Status::OK();
+}
+
+Status AsofJoinNode::InputFinished(ExecNode* input, int total_batches) {
+  {
+    std::lock_guard<std::mutex> guard(gate_);
+    ARROW_DCHECK(std_has(inputs_, input));
+    size_t k = std_find(inputs_, input) - inputs_.begin();
+    state_.at(k)->set_total_batches(total_batches);
   }
-  void PushProcess(bool value) {
+  // Trigger a process call
+  // The reason for this is that there are cases at the end of a table where we don't
+  // know whether the RHS of the join is up-to-date until we know that the table is
+  // finished.
+  PushProcess(true);
+
+  return Status::OK();
+}
+
+void AsofJoinNode::PushProcess(bool value) {
 #ifdef ARROW_ENABLE_THREADING
-    process_.Push(value);
+  process_.Push(value);
 #else
-    if (value) {
-      ProcessNonThreaded();
-    } else if (!process_task_.is_finished()) {
-      EndFromSingleThread();
-    }
-#endif
+  if (value) {
+    ProcessNonThreaded();
+  } else if (!process_task_.is_finished()) {
+    EndFromSingleThread();
   }
+#endif
+}
 
 #ifndef ARROW_ENABLE_THREADING
-  bool ProcessNonThreaded() {
-    while (!process_task_.is_finished()) {
-      Result<std::shared_ptr<RecordBatch>> result = ProcessInner();
-
-      if (result.ok()) {
-        auto out_rb = *result;
-        if (!out_rb) break;
-        ExecBatch out_b(*out_rb);
-        out_b.index = batches_produced_++;
-        DEBUG_SYNC(this, "produce batch ", out_b.index, ":", DEBUG_MANIP(std::endl),
-                   out_rb->ToString(), DEBUG_MANIP(std::endl));
-        Status st = output_->InputReceived(this, std::move(out_b));
-        if (!st.ok()) {
-          // this isn't really from a thread,
-          // but we call through to this for consistency
-          EndFromSingleThread(std::move(st));
-          return false;
-        }
-      } else {
+bool AsofJoinNode::ProcessNonThreaded() {
+  while (!process_task_.is_finished()) {
+    Result<std::shared_ptr<RecordBatch>> result = ProcessInner();
+
+    if (result.ok()) {
+      auto out_rb = *result;
+      if (!out_rb) break;
+      ExecBatch out_b(*out_rb);
+      out_b.index = batches_produced_++;
+      DEBUG_SYNC(this, "produce batch ", out_b.index, ":", DEBUG_MANIP(std::endl),
+                 out_rb->ToString(), DEBUG_MANIP(std::endl));
+      Status st = output_->InputReceived(this, std::move(out_b));
+      if (!st.ok()) {
         // this isn't really from a thread,
         // but we call through to this for consistency
-        EndFromSingleThread(result.status());
+        EndFromSingleThread(std::move(st));
         return false;
       }
+    } else {
+      // this isn't really from a thread,
+      // but we call through to this for consistency
+      EndFromSingleThread(result.status());
+      return false;
     }
-    auto& lhs = *state_.at(0);
-    if (lhs.Finished() && !process_task_.is_finished()) {
-      EndFromSingleThread(Status::OK());
-    }
-    return true;
-  }
-
-  void EndFromSingleThread(Status st = Status::OK()) {
-    process_task_.MarkFinished(st);
-    if (st.ok()) {
-      st = output_->InputFinished(this, batches_produced_);
-    }
-    for (const auto& s : state_) {
-      st &= s->ForceShutdown();
-    }
+  auto& lhs = *state_.at(0);
+  if (lhs.Finished() && !process_task_.is_finished()) {
+    EndFromSingleThread(Status::OK());
   }
+  return true;
+}
 
+void AsofJoinNode::EndFromSingleThread(Status st = Status::OK()) {
+  process_task_.MarkFinished(st);
+  if (st.ok()) {
+    st = output_->InputFinished(this, batches_produced_);
+  }
+  for (const auto& s : state_) {
+    st &= s->ForceShutdown();
+  }
+}
 #endif
 
-  Status StartProducing() override {
-    ARROW_ASSIGN_OR_RAISE(process_task_, plan_->query_context()->BeginExternalTask(
-                                             "AsofJoinNode::ProcessThread"));
-    if (!process_task_.is_valid()) {
-      // Plan has already aborted. Do not start process thread
-      return Status::OK();
-    }
-#ifdef ARROW_ENABLE_THREADING
-    process_thread_ = std::thread(&AsofJoinNode::ProcessThreadWrapper, this);
-#endif
+Status AsofJoinNode::StartProducing() {
+  ARROW_ASSIGN_OR_RAISE(process_task_, plan_->query_context()->BeginExternalTask(
+                                           "AsofJoinNode::ProcessThread"));
+  if (!process_task_.is_valid()) {
+    // Plan has already aborted. Do not start process thread
     return Status::OK();
   }
+#ifdef ARROW_ENABLE_THREADING
+  process_thread_ = std::thread(&AsofJoinNode::ProcessThreadWrapper, this);
+#endif
+  return Status::OK();
+}
 
-  void PauseProducing(ExecNode* output, int32_t counter) override {}
-  void ResumeProducing(ExecNode* output, int32_t counter) override {}
+void AsofJoinNode::PauseProducing(ExecNode* output, int32_t counter) {}
+void AsofJoinNode::ResumeProducing(ExecNode* output, int32_t counter) {}
 
-  Status StopProducingImpl() override {
+Status AsofJoinNode::StopProducingImpl() {
 #ifdef ARROW_ENABLE_THREADING
-    process_.Clear();
+  process_.Clear();
 #endif
-    PushProcess(false);
-    return Status::OK();
-  }
+  PushProcess(false);
+  return Status::OK();
+}
 
 #ifndef NDEBUG
-  std::ostream* GetDebugStream() { return debug_os_; }
+std::ostream* AsofJoinNode::GetDebugStream() { return debug_os_; }
 
-  std::mutex* GetDebugMutex() { return debug_mutex_; }
+std::mutex* AsofJoinNode::GetDebugMutex() { return debug_mutex_; }
 #endif
 
- private:
-  // Outputs from this node are always in ascending order according to the on key
-  const Ordering ordering_;
-  std::vector<col_index_t> indices_of_on_key_;
-  std::vector<std::vector<col_index_t>> indices_of_by_key_;
-  std::vector<std::unique_ptr<KeyHasher>> key_hashers_;
-  bool must_hash_;
-  bool may_rehash_;
-  // InputStates
-  // Each input state corresponds to an input table
-  std::vector<std::unique_ptr<InputState>> state_;
-  std::mutex gate_;
-  TolType tolerance_;
+/// Wrapper around UnmaterializedCompositeTable that knows how to emplace
+/// the join row-by-row
+template <size_t MAX_TABLES>
+class CompositeTableBuilder {
+  using SliceBuilder = UnmaterializedSliceBuilder<MAX_TABLES>;
+  using CompositeTable = UnmaterializedCompositeTable<MAX_TABLES>;
+
+ public:
+  NDEBUG_EXPLICIT CompositeTableBuilder(
+      const std::vector<std::unique_ptr<InputState>>& inputs,
+      const std::shared_ptr<Schema>& schema, arrow::MemoryPool* pool,
+      DEBUG_ADD(size_t n_tables, AsofJoinNode* node))
+      : unmaterialized_table(InitUnmaterializedTable(schema, inputs, pool)),
+        DEBUG_ADD(n_tables_(n_tables), node_(node)) {
+    DCHECK_GE(n_tables_, 1);
+    DCHECK_LE(n_tables_, MAX_TABLES);
+  }
+
+  size_t n_rows() const { return unmaterialized_table.Size(); }
+
+  // Adds the latest row from the input state as a new composite reference row
+  // - LHS must have a valid key, timestep, and latest rows
+  // - RHS must have valid data memo'ed for the key
+  void Emplace(std::vector<std::unique_ptr<InputState>>& in, TolType tolerance) {
+    DCHECK_EQ(in.size(), n_tables_);
+
+    // Get the LHS key
+    ByType key = in[0]->GetLatestKey();
+
+    // Add row and setup LHS
+    // (the LHS state comes just from the latest row of the LHS table)
+    DCHECK(!in[0]->Empty());
+    const std::shared_ptr<RecordBatch>& lhs_latest_batch = in[0]->GetLatestBatch();
+    row_index_t lhs_latest_row = in[0]->GetLatestRow();
+    OnType lhs_latest_time = in[0]->GetLatestTime();
+    if (0 == lhs_latest_row) {
+      // On the first row of the batch, we resize the destination.
+      // The destination size is dictated by the size of the LHS batch.
+      row_index_t new_batch_size = lhs_latest_batch->num_rows();
+      row_index_t new_capacity = unmaterialized_table.Size() + new_batch_size;
+      if (unmaterialized_table.capacity() < new_capacity) {
+        unmaterialized_table.reserve(new_capacity);
+      }
+    }
+
+    SliceBuilder new_row{&unmaterialized_table};
+
+    // Each item represents a portion of the columns of the output table
+    new_row.AddEntry(lhs_latest_batch, lhs_latest_row, lhs_latest_row + 1);
+
+    DEBUG_SYNC(node_, "Emplace: key=", key, " lhs_latest_row=", lhs_latest_row,
+               " lhs_latest_time=", lhs_latest_time, DEBUG_MANIP(std::endl));
+
+    // Get the state for that key from all on the RHS -- assumes it's up to date
+    // (the RHS state comes from the memoized row references)
+    for (size_t i = 1; i < in.size(); ++i) {
+      std::optional<const MemoStore::Entry*> opt_entry = in[i]->GetMemoEntryForKey(key);
 #ifndef NDEBUG
-  std::ostream* debug_os_;
-  std::mutex* debug_mutex_;
+      {
+        bool has_entry = opt_entry.has_value();
+        OnType entry_time = has_entry ? (*opt_entry)->time : TolType::kMinValue;
+        row_index_t entry_row = has_entry ? (*opt_entry)->row : 0;
+        bool accepted = has_entry && tolerance.Accepts(lhs_latest_time, entry_time);
+        DEBUG_SYNC(node_, " i=", i, " has_entry=", has_entry, " time=", entry_time,
+                   " row=", entry_row, " accepted=", accepted, DEBUG_MANIP(std::endl));
+      }
 #endif
+      if (opt_entry.has_value()) {
+        DCHECK(*opt_entry);
+        if (tolerance.Accepts(lhs_latest_time, (*opt_entry)->time)) {
+          // Have a valid entry
+          const MemoStore::Entry* entry = *opt_entry;
+          new_row.AddEntry(entry->batch, entry->row, entry->row + 1);
+          continue;
+        }
+      }
+      new_row.AddEntry(nullptr, 0, 1);
+    }
+    new_row.Finalize();
+  }
 
-  // Backpressure counter common to all inputs
-  std::atomic<int32_t> backpressure_counter_;
-#ifdef ARROW_ENABLE_THREADING
-  // Queue for triggering processing of a given input
-  // (a false value is a poison pill)
-  ConcurrentQueue<bool> process_;
-  // Worker thread
-  std::thread process_thread_;
+  // Materializes the current reference table into a target record batch
+  Result<std::optional<std::shared_ptr<RecordBatch>>> Materialize() {
+    return unmaterialized_table.Materialize();
+  }
+
+  // Returns true if there are no rows
+  bool empty() const { return unmaterialized_table.Empty(); }
+
+ private:
+  CompositeTable unmaterialized_table;
+
+  // Total number of tables in the composite table
+  size_t n_tables_;
+
+#ifndef NDEBUG
+  // Owning node
+  AsofJoinNode* node_;
 #endif
-  Future<> process_task_;
-  // In-progress batches produced
-  int batches_produced_ = 0;
 
+  static CompositeTable InitUnmaterializedTable(
+      const std::shared_ptr<Schema>& schema,
+      const std::vector<std::unique_ptr<InputState>>& inputs, arrow::MemoryPool* pool) {
+    std::unordered_map<int, std::pair<int, int>> dst_to_src;
+    for (size_t i = 0; i < inputs.size(); i++) {
+      auto& input = inputs[i];
+      for (int src = 0; src < input->get_schema()->num_fields(); src++) {
+        auto dst = input->MapSrcToDst(src);
+        if (dst.has_value()) {
+          dst_to_src[dst.value()] = std::make_pair(static_cast<int>(i), src);
+        }
+      }
+    }
+    return CompositeTable{schema, inputs.size(), dst_to_src, pool};
+  }
 };
 
+Result<std::shared_ptr<RecordBatch>> AsofJoinNode::ProcessInner() {
+  DCHECK(!state_.empty());
+  auto& lhs = *state_.at(0);
+
+  // Construct new target table if needed
+  CompositeTableBuilder<MAX_TABLES> dst(state_, output_schema_,
+                                        plan()->query_context()->memory_pool(),
+                                        DEBUG_ADD(state_.size(), this));
+
+  // Generate rows into the dst table until we either run out of data or hit the row
+  // limit, or run out of input
+  for (;;) {
+    // If LHS is finished or empty then there's nothing we can do here
+    if (lhs.Finished() || lhs.Empty()) break;
+
+    ARROW_ASSIGN_OR_RAISE(auto rhs_update_state, UpdateRhs());
+
+    // If we have received enough inputs to produce the next output batch
+    // (decided by IsUpToDateWithLhsRow), we will perform the join and
+    // materialize the output batch. The join is done by advancing through
+    // the LHS and adding joined row to rows_ (done by Emplace). Finally,
+    // input batches that are no longer needed are removed to free up memory.
+    if (rhs_update_state.all_up_to_date_with_lhs) {
+      dst.Emplace(state_, tolerance_);
+      ARROW_ASSIGN_OR_RAISE(bool advanced, lhs.Advance());
+      if (!advanced) break;  // if we can't advance LHS, we're done for this batch
+    } else {
+      if (!rhs_update_state.any_advanced) break;  // need to wait for new data
+    }
+  }
+
+  // Prune memo entries that have expired (to bound memory consumption)
+  if (!lhs.Empty()) {
+    for (size_t i = 1; i < state_.size(); ++i) {
+      OnType ts = tolerance_.Expiry(lhs.GetLatestTime());
+      if (ts != TolType::kMinValue) {
+        state_[i]->RemoveMemoEntriesWithLesserTime(ts);
+      }
+    }
+  }
+
+  // Emit the batch
+  if (dst.empty()) {
+    return NULLPTR;
+  } else {
+    ARROW_ASSIGN_OR_RAISE(auto out, dst.Materialize());
+    return out.has_value() ? out.value() : NULLPTR;
+  }
+}
+
 AsofJoinNode::AsofJoinNode(ExecPlan* plan, NodeVector inputs,
                            std::vector<std::string> input_labels,
                            const std::vector<col_index_t>& indices_of_on_key,
@@ -1635,4 +1792,4 @@ std::mutex* GetDebugMutex(AsofJoinNode* node) { return node->GetDebugMutex(); }
 #undef DEBUG_ADD
 
 }  // namespace acero
-}  // namespace arrow
+}  // namespace arrow
\ No newline at end of file
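For orientation, the node refactored above is still reached through the `asofjoin` factory registered with Acero. The following is a minimal usage sketch, not part of the patch: the input declarations, the "time"/"key" column names, and the tolerance value are illustrative assumptions; `AsofJoinNodeOptions` and `DeclarationToTable` come from arrow/acero/options.h and arrow/acero/exec_plan.h.

// Hedged sketch: as-of-joins two time-ordered inputs on "time", matched by "key".
#include <arrow/acero/exec_plan.h>
#include <arrow/acero/options.h>
#include <arrow/table.h>

namespace ac = arrow::acero;

arrow::Result<std::shared_ptr<arrow::Table>> RunAsofJoin(ac::Declaration left,
                                                         ac::Declaration right) {
  // One Keys entry per input; see AsofJoinNodeOptions::Keys for the layout.
  std::vector<ac::AsofJoinNodeOptions::Keys> keys = {
      {arrow::FieldRef("time"), {arrow::FieldRef("key")}},   // left input
      {arrow::FieldRef("time"), {arrow::FieldRef("key")}}};  // right input
  // The tolerance is expressed in units of the on-key column; 1000 is an
  // arbitrary illustrative value (see AsofJoinNodeOptions for the sign convention).
  ac::Declaration asof{"asofjoin",
                       {std::move(left), std::move(right)},
                       ac::AsofJoinNodeOptions(std::move(keys), /*tolerance=*/1000)};
  return ac::DeclarationToTable(std::move(asof));
}

Both inputs must already be ordered on the on-key, which is why `InputReceived` above rejects unsequenced batches.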
diff --git a/cpp/src/arrow/filesystem/mockfs.cc b/cpp/src/arrow/filesystem/mockfs.cc
index 15bc3f9b212..c8989c8401b 100644
--- a/cpp/src/arrow/filesystem/mockfs.cc
+++ b/cpp/src/arrow/filesystem/mockfs.cc
@@ -54,6 +54,7 @@ Status ValidatePath(std::string_view s) {
 ////////////////////////////////////////////////////////////////////////////
 // Filesystem structure
 
+struct Directory;
 class Entry;
 
 struct File {
@@ -80,37 +81,18 @@ struct Directory {
   TimePoint mtime;
   std::map<std::string, std::unique_ptr<Entry>> entries;
 
-  Directory(std::string name, TimePoint mtime) : name(std::move(name)), mtime(mtime) {}
-  Directory(Directory&& other) noexcept
-      : name(std::move(other.name)),
-        mtime(other.mtime),
-        entries(std::move(other.entries)) {}
+  Directory(std::string name, TimePoint mtime);
+  Directory(Directory&& other) noexcept;
 
-  Directory& operator=(Directory&& other) noexcept {
-    name = std::move(other.name);
-    mtime = other.mtime;
-    entries = std::move(other.entries);
-    return *this;
-  }
+  Directory& operator=(Directory&& other) noexcept;
 
-  Entry* Find(const std::string& s) {
-    auto it = entries.find(s);
-    if (it != entries.end()) {
-      return it->second.get();
-    } else {
-      return nullptr;
-    }
-  }
+  Entry* Find(const std::string& s);
 
-  bool CreateEntry(const std::string& s, std::unique_ptr<Entry> entry) {
-    DCHECK(!s.empty());
-    auto p = entries.emplace(s, std::move(entry));
-    return p.second;
-  }
+  bool CreateEntry(const std::string& s, std::unique_ptr<Entry> entry);
 
   void AssignEntry(const std::string& s, std::unique_ptr<Entry> entry);
 
-  bool DeleteEntry(const std::string& s) { return entries.erase(s) > 0; }
+  bool DeleteEntry(const std::string& s);
 
  private:
   ARROW_DISALLOW_COPY_AND_ASSIGN(Directory);
@@ -120,7 +102,7 @@ struct Directory {
 using EntryBase = std::variant<std::nullptr_t, File, Directory>;
 
 class Entry : public EntryBase {
- public:
+public:
   Entry(Entry&&) = default;
   Entry& operator=(Entry&&) = default;
   explicit Entry(Directory&& v) : EntryBase(std::move(v)) {}
@@ -180,15 +162,45 @@ class Entry : public EntryBase {
     }
   }
 
- private:
+private:
   ARROW_DISALLOW_COPY_AND_ASSIGN(Entry);
 };
 
+Directory::Directory(std::string name, TimePoint mtime) : name(std::move(name)), mtime(mtime) {}
+Directory::Directory(Directory&& other) noexcept
+    : name(std::move(other.name)),
+      mtime(other.mtime),
+      entries(std::move(other.entries)) {}
+
+Directory& Directory::operator=(Directory&& other) noexcept {
+  name = std::move(other.name);
+  mtime = other.mtime;
+  entries = std::move(other.entries);
+  return *this;
+}
+
+Entry* Directory::Find(const std::string& s) {
+  auto it = entries.find(s);
+  if (it != entries.end()) {
+    return it->second.get();
+  } else {
+    return nullptr;
+  }
+}
+
+bool Directory::CreateEntry(const std::string& s, std::unique_ptr<Entry> entry) {
+  DCHECK(!s.empty());
+  auto p = entries.emplace(s, std::move(entry));
+  return p.second;
+}
+
 void Directory::AssignEntry(const std::string& s, std::unique_ptr<Entry> entry) {
   DCHECK(!s.empty());
   entries[s] = std::move(entry);
 }
 
+bool Directory::DeleteEntry(const std::string& s) { return entries.erase(s) > 0; }
+
 ////////////////////////////////////////////////////////////////////////////
 // Streams
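The mockfs.cc change is the standard way to break a definition cycle: Directory stores std::unique_ptr<Entry> while Entry is a variant that contains Directory, so every Directory member that needs a complete Entry is reduced to a declaration and defined once Entry exists (the new `struct Directory;` forward declaration serves the same purpose in the other direction). Below is a standalone sketch of the pattern with simplified types that are not the Arrow ones:

#include <map>
#include <memory>
#include <string>
#include <variant>

struct Directory;  // forward declaration, as added by the patch
class Entry;       // Entry is still incomplete here

struct Directory {
  // std::unique_ptr may point to an incomplete type, so this member is legal.
  std::map<std::string, std::unique_ptr<Entry>> entries;
  Entry* Find(const std::string& s);  // declaration only; the body needs Entry
};

using EntryBase = std::variant<std::nullptr_t, Directory>;

class Entry : public EntryBase {
 public:
  explicit Entry(Directory&& v) : EntryBase(std::move(v)) {}
};

// From here on Entry is complete, so the out-of-line definition compiles.
Entry* Directory::Find(const std::string& s) {
  auto it = entries.find(s);
  return it != entries.end() ? it->second.get() : nullptr;
}

int main() {
  Directory root;
  root.entries["child"] = std::make_unique<Entry>(Directory{});
  return root.Find("child") != nullptr ? 0 : 1;
}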
diff --git a/cpp/src/arrow/flight/types.h b/cpp/src/arrow/flight/types.h
index 656cc00e676..aac23754b81 100644
--- a/cpp/src/arrow/flight/types.h
+++ b/cpp/src/arrow/flight/types.h
@@ -569,6 +570,7 @@ struct ARROW_FLIGHT_EXPORT Location : public internal::BaseType<Location> {
   std::shared_ptr<arrow::util::Uri> uri_;
 };
 
+
 /// \brief A flight ticket and list of locations where the ticket can be
 /// redeemed
 struct ARROW_FLIGHT_EXPORT FlightEndpoint : public internal::BaseType<FlightEndpoint> {
@@ -613,6 +614,33 @@ struct ARROW_FLIGHT_EXPORT FlightEndpoint : public internal::BaseType<FlightEndpoint> {
   static arrow::Status Deserialize(std::string_view serialized, FlightEndpoint* out);
 };
 
+/// \brief The request of the RenewFlightEndpoint action.
+struct ARROW_FLIGHT_EXPORT RenewFlightEndpointRequest
+    : public internal::BaseType<RenewFlightEndpointRequest> {
+  FlightEndpoint endpoint;
+
+  RenewFlightEndpointRequest() = default;
+  explicit RenewFlightEndpointRequest(FlightEndpoint endpoint)
+      : endpoint(std::move(endpoint)) {}
+
+  std::string ToString() const;
+  bool Equals(const RenewFlightEndpointRequest& other) const;
+
+  using SuperT::Deserialize;
+  using SuperT::SerializeToString;
+
+  /// \brief Serialize this message to its wire-format representation.
+  ///
+  /// Use `SerializeToString()` if you want a Result-returning version.
+  arrow::Status SerializeToString(std::string* out) const;
+
+  /// \brief Deserialize this message from its wire-format representation.
+  ///
+  /// Use `Deserialize(serialized)` if you want a Result-returning version.
+  static arrow::Status Deserialize(std::string_view serialized,
+                                   RenewFlightEndpointRequest* out);
+};
+
 /// \brief The access coordinates for retrieval of a dataset, returned by
 /// GetFlightInfo
 class ARROW_FLIGHT_EXPORT FlightInfo
@@ -857,32 +885,6 @@ struct ARROW_FLIGHT_EXPORT CancelFlightInfoResult
 
 ARROW_FLIGHT_EXPORT
 std::ostream& operator<<(std::ostream& os, CancelStatus status);
 
-/// \brief The request of the RenewFlightEndpoint action.
-struct ARROW_FLIGHT_EXPORT RenewFlightEndpointRequest
-    : public internal::BaseType<RenewFlightEndpointRequest> {
-  FlightEndpoint endpoint;
-
-  RenewFlightEndpointRequest() = default;
-  explicit RenewFlightEndpointRequest(FlightEndpoint endpoint)
-      : endpoint(std::move(endpoint)) {}
-
-  std::string ToString() const;
-  bool Equals(const RenewFlightEndpointRequest& other) const;
-
-  using SuperT::Deserialize;
-  using SuperT::SerializeToString;
-
-  /// \brief Serialize this message to its wire-format representation.
-  ///
-  /// Use `SerializeToString()` if you want a Result-returning version.
-  arrow::Status SerializeToString(std::string* out) const;
-
-  /// \brief Deserialize this message from its wire-format representation.
-  ///
-  /// Use `Deserialize(serialized)` if you want a Result-returning version.
-  static arrow::Status Deserialize(std::string_view serialized,
-                                   RenewFlightEndpointRequest* out);
-};
-
 // FlightData in Flight.proto maps to FlightPayload here.
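RenewFlightEndpointRequest is only being moved so that it is declared next to FlightEndpoint, which it contains; the API itself is unchanged. A minimal round-trip sketch of that API follows, using both the Status-returning methods declared above and the Result-returning forms pulled in via SuperT (the endpoint argument is assumed to exist elsewhere):

#include <arrow/flight/types.h>
#include <arrow/result.h>
#include <arrow/status.h>

arrow::Status RoundTrip(const arrow::flight::FlightEndpoint& endpoint) {
  arrow::flight::RenewFlightEndpointRequest request{endpoint};

  // Status-returning spelling, as declared in the struct above.
  std::string wire;
  ARROW_RETURN_NOT_OK(request.SerializeToString(&wire));
  arrow::flight::RenewFlightEndpointRequest decoded;
  ARROW_RETURN_NOT_OK(
      arrow::flight::RenewFlightEndpointRequest::Deserialize(wire, &decoded));

  // Result-returning spelling, inherited through SuperT.
  ARROW_ASSIGN_OR_RAISE(std::string wire2, request.SerializeToString());
  ARROW_ASSIGN_OR_RAISE(auto decoded2,
                        arrow::flight::RenewFlightEndpointRequest::Deserialize(wire2));

  return decoded.Equals(decoded2) ? arrow::Status::OK()
                                  : arrow::Status::Invalid("round-trip mismatch");
}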
diff --git a/cpp/src/gandiva/cache_test.cc b/cpp/src/gandiva/cache_test.cc
index 96cf4a12e58..d371db59dfc 100644
--- a/cpp/src/gandiva/cache_test.cc
+++ b/cpp/src/gandiva/cache_test.cc
@@ -35,11 +35,13 @@ class TestCacheKey {
 };
 
 TEST(TestCache, TestGetPut) {
-  Cache<TestCacheKey, std::string> cache(2);
-  cache.PutObjectCode(TestCacheKey(1), "hello");
-  cache.PutObjectCode(TestCacheKey(2), "world");
-  ASSERT_EQ(cache.GetObjectCode(TestCacheKey(1)), "hello");
-  ASSERT_EQ(cache.GetObjectCode(TestCacheKey(2)), "world");
+  Cache<TestCacheKey, std::shared_ptr<std::string>> cache(2);
+  auto hello = std::make_shared<std::string>("hello");
+  cache.PutObjectCode(TestCacheKey(1), hello);
+  auto world = std::make_shared<std::string>("world");
+  cache.PutObjectCode(TestCacheKey(2), world);
+  ASSERT_EQ(cache.GetObjectCode(TestCacheKey(1)), hello);
+  ASSERT_EQ(cache.GetObjectCode(TestCacheKey(2)), world);
 }
 
 namespace {
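The gandiva test change tracks the cache's switch from storing std::string values to std::shared_ptr<std::string>: a Get now hands back a reference to the cached buffer instead of copying it, and the assertions compare shared_ptr values, which is pointer identity rather than content equality. A standalone sketch of what the assertions now mean (a plain std::map stands in for gandiva's Cache, which is an assumption here):

#include <cassert>
#include <map>
#include <memory>
#include <string>

int main() {
  std::map<int, std::shared_ptr<std::string>> cache;
  auto hello = std::make_shared<std::string>("hello");
  cache[1] = hello;  // stores a reference to the buffer, no copy of "hello"
  // shared_ptr::operator== compares the stored pointers, not the contents,
  // so this asserts that the cache returns the very same buffer.
  assert(cache[1] == hello);
  assert(cache[1].get() == hello.get());
  return 0;
}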