Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 87 additions & 3 deletions hipfile/src/amd_detail/batch/batch.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,72 @@

namespace hipFile {

// /*
// Factory was created with the help of an LLM
// - typename... Args - "parameter pack" of zero or more types
// - Args... - Expand the "parameter pack"
// - typename = std::enable_if_t<...> - dummy template param used as a gate
// used to check for function resolution
// Allows for trying all overloads.
// If no matching signature found at all: compiler error.
// - forward<Args>(args)... - "perfectly forward" all passed args into T's matched ctor.
// aka. expand the arg type pack and arg value pack into
// separate std::forward calls.
// */
// template <typename T>
// struct GenericFactory {

// // Return an instance of T
// // Usage: Factory<T>::make(...)
// template <
// typename... Args,
// typename = std::enable_if_t<std::is_constructible<T, Args...>::value>
// >
// static T make(Args&&... args)
// {
// return T(std::forward<Args>(args)...);
// }

// // Return a std::shared_ptr of T
// // Usage:: Factory<T>::make_shared(...)
// template <
// typename... Args,
// typename = std::enable_if_t<std::is_constructible<T, Args...>::value>
// >
// static std::shared_ptr<T> make_shared(Args&&... args)
// {
// return std::make_shared<T>(std::forward<Args>(args)...);
// }

// // Return a std::shared_ptr of a base class of T
// // Usage: Factory<T>::make_shared_as<BaseT>(...)
// template <
// typename BaseT,
// typename... Args,
// typename = std::enable_if_t<
// std::is_constructible<T, Args...>::value &&
// std::is_base_of<BaseT, T>::value
// >
// >
// static std::shared_ptr<BaseT> make_shared_as(Args&&... args)
// {
// return std::static_pointer_cast<BaseT>(std::make_shared<T>(std::forward<Args>(args)...));
// //return std::shared_ptr<BaseT>{new T{std::forward<Args(args)...}}
// //return std::make_shared<T>(std::forward<Args>(args)...);
// }

// // Return a std::unique_ptr of T
// template <
// typename... Args,
// typename = std::enable_if_t<std::is_constructible<T, Args...>::value>
// >
// static std::unique_ptr<T> make_unique(Args&&... args)
// {
// return std::make_unique<T>(std::forward<Args>(args)...);
// }
// };


BatchOperation::BatchOperation(std::unique_ptr<const hipFileIOParams_t> params,
std::shared_ptr<IBuffer> _buffer, std::shared_ptr<IFile> _file)
: io_params{std::move(params)}, buffer{_buffer}, file{_file}
Expand Down Expand Up @@ -99,9 +165,26 @@ BatchContext::get_capacity() const noexcept
return capacity;
}

//using BatchOpFactory = std::function<Factory>

// template <typename Factory>
// void
// test_usage(Factory& fac)
// {
// (void)fac;
// }

// void
// test_caller()
// {
// Factory<BatchOperation> fac;
// test_usage(fac);
// }

void
BatchContext::submit_operations(const hipFileIOParams_t *params, unsigned num_params)
BatchContext::submit_operations(const hipFileIOParams_t *params, unsigned num_params, IFactory<IBatchOperation> factory)
{
(void)factory;
std::unique_lock<std::shared_mutex> _ulock{context_mutex};

// Check num_params first before doing anything else
Expand All @@ -113,7 +196,7 @@ BatchContext::submit_operations(const hipFileIOParams_t *params, unsigned num_pa
throw std::invalid_argument(msg.str());
}

std::vector<std::shared_ptr<BatchOperation>> pending_ops{};
std::vector<std::shared_ptr<IBatchOperation>> pending_ops{};

// It would be more performant to be able to perform multiple lookups
// rather than waiting to lock the DriverState lock for each lookup.
Expand All @@ -124,7 +207,8 @@ BatchContext::submit_operations(const hipFileIOParams_t *params, unsigned num_pa
// file flags.
auto [_file, _buffer] = Context<DriverState>::get()->getFileAndBuffer(
param_copy->fh, param_copy->u.batch.devPtr_base, param_copy->u.batch.size, 0);
auto op = std::make_shared<BatchOperation>(std::move(param_copy), _buffer, _file);
auto op = std::shared_ptr<IBatchOperation>{new BatchOperation{std::move(param_copy), _buffer, _file}};


pending_ops.push_back(op);
}
Expand Down
98 changes: 94 additions & 4 deletions hipfile/src/amd_detail/batch/batch.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,94 @@ class IFile;

namespace hipFile {

/*
* Intentionally empty interface
*/
template <typename Interface>
class IFactory {
};

/*
Factory was created with the help of an LLM
- typename... Args - "parameter pack" of zero or more types
- Args... - Expand the "parameter pack"
- typename = std::enable_if_t<...> - dummy template param used as a gate
used to check for function resolution
Allows for trying all overloads.
If no matching signature found at all: compiler error.
- forward<Args>(args)... - "perfectly forward" all passed args into T's matched ctor.
aka. expand the arg type pack and arg value pack into
separate std::forward calls.
*/
template <
typename Interface,
typename Implementation
>
class GenericFactory {
public:
// Return an instance of Implementation
// Usage: Factory<Implementation>::make(...)
template <
typename... Args,
typename = std::enable_if_t<std::is_constructible<Implementation, Args...>::value>
>
static Implementation make(Args&&... args)
{
return Implementation(std::forward<Args>(args)...);
}

// Return a std::shared_ptr<Interface> to an Implementation instance
// Usage:: Factory<Interface>::make_shared(...)
template <
typename... Args,
typename = std::enable_if_t<std::is_constructible<Implementation, Args...>::value>
>
static std::shared_ptr<Interface> make_shared(Args&&... args)
{
return std::static_pointer_cast<Interface>(std::make_shared<Implementation>(std::forward<Args>(args)...));
}

// // Return a std::shared_ptr<Interface> to an arbitrary
// // Usage: Factory<T>::make_shared_as<BaseT>(...)
// template <
// typename BaseT,
// typename... Args,
// typename = std::enable_if_t<
// std::is_constructible<T, Args...>::value &&
// std::is_base_of<BaseT, T>::value
// >
// >
// static std::shared_ptr<BaseT> make_shared_as(Args&&... args)
// {
// return std::static_pointer_cast<BaseT>(std::make_shared<T>(std::forward<Args>(args)...));
// //return std::shared_ptr<BaseT>{new T{std::forward<Args(args)...}}
// //return std::make_shared<T>(std::forward<Args>(args)...);
// }

// // Return a std::unique_ptr of T
// template <
// typename... Args,
// typename = std::enable_if_t<std::is_constructible<T, Args...>::value>
// >
// static std::unique_ptr<T> make_unique(Args&&... args)
// {
// return std::make_unique<T>(std::forward<Args>(args)...);
// }
};

struct InvalidBatchHandle : public std::invalid_argument {
InvalidBatchHandle() : std::invalid_argument{"Invalid batch handle"}
{
}
};

class IBatchOperation {
public:
virtual ~IBatchOperation() = default;
};

/// @brief Represents a single IO Request
class BatchOperation {
class BatchOperation : public IBatchOperation {
public:
/// @brief Create an operation to handle and track an IO request.
/// @param [in] params IO parameters
Expand All @@ -50,13 +130,21 @@ class BatchOperation {
const std::shared_ptr<const IFile> file;
};

template <
typename Implementation,
typename Interface = IBatchOperation
>
class BatchOperationFactory : public IFactory<Interface>, public GenericFactory<Interface, Implementation>
{
};

class IBatchContext {
public:
static constexpr unsigned MAX_SIZE = 128;

virtual ~IBatchContext() = default;
virtual unsigned get_capacity() const noexcept = 0;
virtual void submit_operations(const hipFileIOParams_t *params, unsigned num_params) = 0;
virtual void submit_operations(const hipFileIOParams_t *params, unsigned num_params, IFactory<IBatchOperation> factory = BatchOperationFactory<BatchOperation>{}) = 0;
};

class BatchContext : public IBatchContext {
Expand All @@ -76,7 +164,7 @@ class BatchContext : public IBatchContext {
/// @note This is an All or None operation. If one submitted operation is not valid, no operations
/// will be submitted.
///
void submit_operations(const hipFileIOParams_t *params, const unsigned num_params) override;
void submit_operations(const hipFileIOParams_t *params, const unsigned num_params, IFactory<IBatchOperation> factory) override;

private:
const unsigned capacity;
Expand All @@ -89,7 +177,7 @@ class BatchContext : public IBatchContext {
/// but is not yet complete or completed but not yet retrieved by the
/// application.
/// shared_ptr as it may need to be passed to a backend.
std::unordered_set<std::shared_ptr<BatchOperation>> outstanding_ops;
std::unordered_set<std::shared_ptr<IBatchOperation>> outstanding_ops;

BatchContext(unsigned capacity);

Expand Down Expand Up @@ -131,4 +219,6 @@ class BatchContextMap {
mutable std::shared_mutex batch_mutex;
};

static_assert(!std::is_abstract<BatchOperationFactory<BatchOperation>>::value, "Not concrete");

}
5 changes: 4 additions & 1 deletion hipfile/test/amd_detail/mbatch.h
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@

namespace hipFile {

class MBatchOperation : public IBatchOperation {
};

class MBatchContext : public IBatchContext {
public:
MOCK_METHOD(unsigned, get_capacity, (), (const, noexcept, override));
MOCK_METHOD(void, submit_operations, (const hipFileIOParams_t *params, const unsigned num_params),
MOCK_METHOD(void, submit_operations, (const hipFileIOParams_t *params, const unsigned num_params, IFactory<IBatchOperation> factory),
(override));
};

Expand Down
Loading