diff --git a/hipfile/src/amd_detail/batch/batch.cpp b/hipfile/src/amd_detail/batch/batch.cpp index 7ff7c14f..59aef6e1 100644 --- a/hipfile/src/amd_detail/batch/batch.cpp +++ b/hipfile/src/amd_detail/batch/batch.cpp @@ -100,9 +100,10 @@ BatchContext::get_capacity() const noexcept } void -BatchContext::submit_operations(const hipFileIOParams_t *params, unsigned num_params) +BatchContext::submit_operations(const hipFileIOParams_t *params, unsigned num_params, const BatchOpMaker& make_op) { std::unique_lock _ulock{context_mutex}; + (void)default_make_op; // Check num_params first before doing anything else if (num_params > capacity - outstanding_ops.size()) { @@ -113,7 +114,7 @@ BatchContext::submit_operations(const hipFileIOParams_t *params, unsigned num_pa throw std::invalid_argument(msg.str()); } - std::vector> pending_ops{}; + std::vector> pending_ops{}; // It would be more performant to be able to perform multiple lookups // rather than waiting to lock the DriverState lock for each lookup. @@ -124,8 +125,8 @@ BatchContext::submit_operations(const hipFileIOParams_t *params, unsigned num_pa // file flags. auto [_file, _buffer] = Context::get()->getFileAndBuffer( param_copy->fh, param_copy->u.batch.devPtr_base, param_copy->u.batch.size, 0); - auto op = std::make_shared(std::move(param_copy), _buffer, _file); - + //auto op = std::shared_ptr{new BatchOperation{std::move(param_copy), _buffer, _file}}; + auto op = make_op(std::move(param_copy), _buffer, _file); pending_ops.push_back(op); } diff --git a/hipfile/src/amd_detail/batch/batch.h b/hipfile/src/amd_detail/batch/batch.h index f8372a17..c7aa666d 100644 --- a/hipfile/src/amd_detail/batch/batch.h +++ b/hipfile/src/amd_detail/batch/batch.h @@ -7,6 +7,7 @@ #include "hipfile.h" +#include #include #include #include @@ -28,8 +29,13 @@ struct InvalidBatchHandle : public std::invalid_argument { } }; +class IBatchOperation { +public: + virtual ~IBatchOperation() = default; +}; + /// @brief Represents a single IO Request -class BatchOperation { +class BatchOperation : public IBatchOperation { public: /// @brief Create an operation to handle and track an IO request. /// @param [in] params IO parameters @@ -50,13 +56,27 @@ class BatchOperation { const std::shared_ptr file; }; +#include "hipfile-warnings.h" +HIPFILE_WARN_OFF("unused-function") +using BatchOpMaker = std::function(std::unique_ptr, std::shared_ptr, std::shared_ptr)>; +static const std::shared_ptr default_make_op(std::unique_ptr p, std::shared_ptr b, std::shared_ptr f) +{ + return std::dynamic_pointer_cast(std::make_shared(std::move(p), b, f)); +} +// OR +inline constexpr auto DefaultBatchOpMaker = [](std::unique_ptr p, std::shared_ptr b, std::shared_ptr f){ + return std::dynamic_pointer_cast(std::make_shared(std::move(p), b, f)); +}; +HIPFILE_WARN_ON("unused-function") + class IBatchContext { public: static constexpr unsigned MAX_SIZE = 128; virtual ~IBatchContext() = default; virtual unsigned get_capacity() const noexcept = 0; - virtual void submit_operations(const hipFileIOParams_t *params, unsigned num_params) = 0; + virtual void submit_operations(const hipFileIOParams_t *params, unsigned num_params, const BatchOpMaker& make_op = DefaultBatchOpMaker) = 0; + virtual std::unordered_set> get_ops() = 0; }; class BatchContext : public IBatchContext { @@ -76,8 +96,12 @@ class BatchContext : public IBatchContext { /// @note This is an All or None operation. If one submitted operation is not valid, no operations /// will be submitted. /// - void submit_operations(const hipFileIOParams_t *params, const unsigned num_params) override; + void submit_operations(const hipFileIOParams_t *params, const unsigned num_params, const BatchOpMaker& make_op) override; + // not a real function - PoC used to peer internally that the correct factory was used + std::unordered_set> get_ops() override { + return outstanding_ops; + } private: const unsigned capacity; @@ -89,7 +113,7 @@ class BatchContext : public IBatchContext { /// but is not yet complete or completed but not yet retrieved by the /// application. /// shared_ptr as it may need to be passed to a backend. - std::unordered_set> outstanding_ops; + std::unordered_set> outstanding_ops; BatchContext(unsigned capacity); diff --git a/hipfile/test/amd_detail/batch/batch.cpp b/hipfile/test/amd_detail/batch/batch.cpp index 1c1c0044..c7df4198 100644 --- a/hipfile/test/amd_detail/batch/batch.cpp +++ b/hipfile/test/amd_detail/batch/batch.cpp @@ -10,6 +10,7 @@ #include "hipfile-test.h" #include "hipfile-warnings.h" #include "invalid-enum.h" +#include "mbatch.h" #include "mbuffer.h" #include "mfile.h" #include "mstate.h" @@ -19,6 +20,7 @@ #include #include #include +#include #include #include @@ -348,4 +350,32 @@ TEST_F(HipFileBatchContext, SubmitSingleBadParamModeInvalid) ASSERT_THROW(_context->submit_operations(&bad_io_params, 1), std::invalid_argument); } +// Not a real test - testing proof of concept +TEST_F(HipFileBatchContext, _UseMockedFactory) +{ + _context->submit_operations(&io_params, 1, MBatchOperation::MBatchOpMaker); + auto context_ops = _context->get_ops(); + for(auto op : context_ops) { + // hack since we don't have the key to directly reference + //ASSERT_EQ(typeid(op.get()), typeid(MBatchOperation)); + ASSERT_NE(dynamic_cast(op.get()), nullptr); + } +} + +TEST_F(HipFileBatchContext, _UseMockedFactoryWithQueue) +{ + auto& mocked_ops = MBatchOperation::get_queue(); + std::shared_ptr m_op = std::make_shared(); + + mocked_ops.push(m_op); + ASSERT_FALSE(mocked_ops.empty()); + _context->submit_operations(&io_params, 1, MBatchOperation::MBatchOpMaker_queue); + + ASSERT_TRUE(mocked_ops.empty()); // Queue of mocks has been emptied. + + auto context_ops = _context->get_ops(); + ASSERT_EQ(context_ops.count(m_op), 1); // m_op is in the Context. +} + + HIPFILE_WARN_NO_GLOBAL_CTOR_ON diff --git a/hipfile/test/amd_detail/mbatch.h b/hipfile/test/amd_detail/mbatch.h index c59ba908..b62dd812 100644 --- a/hipfile/test/amd_detail/mbatch.h +++ b/hipfile/test/amd_detail/mbatch.h @@ -6,20 +6,59 @@ #pragma once #include "batch/batch.h" +#include "hipfile-warnings.h" #include +#include +#include + /* * A place to create mocks for the batch module. */ namespace hipFile { +class MBatchOperation : public IBatchOperation { +public: + MBatchOperation() = default; + + inline static std::queue>& get_queue() { + HIPFILE_WARN_NO_EXIT_DTOR_OFF + static std::queue> mocked_ops; + HIPFILE_WARN_NO_EXIT_DTOR_ON + return mocked_ops; + } + + inline static constexpr auto MBatchOpMaker = [](std::unique_ptr p, std::shared_ptr b, std::shared_ptr f) { + // Discard params + (void) p; + (void) b; + (void) f; + return std::make_shared(); + }; + // OR + inline static constexpr auto MBatchOpMaker_queue = [](std::unique_ptr p, std::shared_ptr b, std::shared_ptr f) { + // Discard params + (void) p; + (void) b; + (void) f; + auto& mocked_ops = get_queue(); + if (mocked_ops.empty()) { + throw std::runtime_error("Testing error: No mocks available to construct."); + } + auto op = mocked_ops.front(); + mocked_ops.pop(); + return op; + }; +}; + class MBatchContext : public IBatchContext { public: MOCK_METHOD(unsigned, get_capacity, (), (const, noexcept, override)); - MOCK_METHOD(void, submit_operations, (const hipFileIOParams_t *params, const unsigned num_params), + MOCK_METHOD(void, submit_operations, (const hipFileIOParams_t *params, const unsigned num_params, const BatchOpMaker& make_op), (override)); + MOCK_METHOD(std::unordered_set>, get_ops, (), (override)); }; }