diff --git a/CMakeLists.txt b/CMakeLists.txt
index 49ec7006..f6c50e6b 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -212,6 +212,12 @@ include(AISSanitizers)
 #------------------
 option(AIS_INSTALL_EXAMPLES "Install example programs" ON)
 
+#------------------
+# Tool programs
+#------------------
+option(AIS_INSTALL_TOOLS "Install tool programs" ON)
+
+
 #------
 # docs
 #------
@@ -282,6 +288,12 @@ if(AIS_INSTALL_EXAMPLES)
     add_subdirectory(examples/aiscp)
 endif()
 
+# Add tools directory
+if(AIS_INSTALL_TOOLS AND AIS_BUILD_AMD_DETAIL)
+    add_subdirectory(tools/ais-stats)
+endif()
+
+
 #-----------------------------------------------------------------------------
 # Set up installs (must come AFTER target creation)
 #-----------------------------------------------------------------------------
diff --git a/src/amd_detail/CMakeLists.txt b/src/amd_detail/CMakeLists.txt
index 75b28d2e..c264f630 100644
--- a/src/amd_detail/CMakeLists.txt
+++ b/src/amd_detail/CMakeLists.txt
@@ -23,6 +23,7 @@ set(HIPFILE_SOURCES
     io.cpp
     mountinfo.cpp
     state.cpp
+    stats.cpp
     stream.cpp
     sys.cpp
 )
diff --git a/src/amd_detail/backend/fallback.cpp b/src/amd_detail/backend/fallback.cpp
index aa386f31..1d2caa73 100644
--- a/src/amd_detail/backend/fallback.cpp
+++ b/src/amd_detail/backend/fallback.cpp
@@ -14,6 +14,7 @@
 #include "hip.h"
 #include "hipfile.h"
 #include "io.h"
+#include "stats.h"
 #include "sys.h"
 #include "stream.h"
 #include "util.h"
@@ -114,6 +115,16 @@ Fallback::io(IoType io_type, shared_ptr file, shared_ptr buffer,
         }
     } while (static_cast(total_io_bytes) < size);
 
+    switch (io_type) {
+        case IoType::Read:
+            statsAddFallbackPathRead(static_cast(total_io_bytes));
+            break;
+        case IoType::Write:
+            statsAddFallbackPathWrite(static_cast(total_io_bytes));
+            break;
+        default:
+            break;
+    }
     return total_io_bytes;
 }
 
diff --git a/src/amd_detail/backend/fastpath.cpp b/src/amd_detail/backend/fastpath.cpp
index 192ce3de..aba587c8 100644
--- a/src/amd_detail/backend/fastpath.cpp
+++ b/src/amd_detail/backend/fastpath.cpp
@@ -10,6 +10,7 @@
 #include "hip.h"
 #include "hipfile.h"
 #include "io.h"
+#include "stats.h"
 
 #include
 #include
@@ -183,6 +184,15 @@ Fastpath::io(IoType type, shared_ptr file, shared_ptr buffer, si
         default:
             throw std::runtime_error("Invalid IoType");
     }
-
+    switch (type) {
+        case IoType::Read:
+            statsAddFastPathRead(nbytes);
+            break;
+        case IoType::Write:
+            statsAddFastPathWrite(nbytes);
+            break;
+        default:
+            break;
+    }
     return static_cast(nbytes);
 }
diff --git a/src/amd_detail/context.cpp b/src/amd_detail/context.cpp
index 4e909938..5adc1d87 100644
--- a/src/amd_detail/context.cpp
+++ b/src/amd_detail/context.cpp
@@ -7,6 +7,7 @@
 #include "hip.h"
 #include "hipfile-warnings.h"
 #include "state.h"
+#include "stats.h"
 #include "sys.h"
 
 namespace hipFile {
@@ -15,6 +16,7 @@ HipFileInit::HipFileInit()
 {
     Context::get();
     Context::get();
+    Context::get();
     Context::get();
 }
 
diff --git a/src/amd_detail/stats.cpp b/src/amd_detail/stats.cpp
new file mode 100644
index 00000000..619c45c6
--- /dev/null
+++ b/src/amd_detail/stats.cpp
@@ -0,0 +1,292 @@
+/* Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * SPDX-License-Identifier: MIT
+ */
+
+#include "context.h"
+#include "stats.h"
+#include "sys.h"
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+/// Sends descriptor `fd` over the connected UNIX socket `sock` as SCM_RIGHTS ancillary data.
+/// Returns 0 on success and -1 when sendmsg() fails.
+static int
+sendFd(int sock, int fd) noexcept
+{
+    int data{1};
+    iovec iov{&data, sizeof(data)};
+    msghdr msgh{}; // zero-initialised so no field reaches the kernel uninitialised
+    cmsghdr *cmsgp;
+
+    union {
+        char buff[CMSG_SPACE(sizeof(int))];
+        cmsghdr align;
+    } controlMsg;
+
+    msgh.msg_name = nullptr;
+    msgh.msg_namelen = 0;
+    msgh.msg_iov = &iov;
+    msgh.msg_iovlen = 1;
+    msgh.msg_control = controlMsg.buff;
+    msgh.msg_controllen = sizeof(controlMsg.buff);
+
+    cmsgp = CMSG_FIRSTHDR(&msgh);
+    cmsgp->cmsg_level = SOL_SOCKET;
+    cmsgp->cmsg_type = SCM_RIGHTS;
+    cmsgp->cmsg_len = CMSG_LEN(sizeof(int));
+    memcpy(CMSG_DATA(cmsgp), &fd, sizeof(int));
+
+    if (sendmsg(sock, &msgh, 0) == -1)
+        return -1;
+
+    return 0;
+}
+
+/// Receives one descriptor sent by sendFd() on `sockfd`.
+/// Returns the descriptor, or -1 (errno = EINVAL when the control message is missing or malformed).
+static int
+recvFd(int sockfd) noexcept
+{
+    int data{0}, fd{-1};
+    iovec iov{&data, sizeof(data)};
+    msghdr msgh{};
+    cmsghdr *cmsgp;
+
+    union {
+        char buff[CMSG_SPACE(sizeof(int))];
+        cmsghdr align;
+    } controlMsg;
+
+    msgh.msg_name = nullptr;
+    msgh.msg_namelen = 0;
+    msgh.msg_iov = &iov;
+    msgh.msg_iovlen = 1;
+    msgh.msg_control = controlMsg.buff;
+    msgh.msg_controllen = sizeof(controlMsg.buff);
+
+    if (recvmsg(sockfd, &msgh, 0) == -1)
+        return -1;
+
+    cmsgp = CMSG_FIRSTHDR(&msgh);
+    if (cmsgp == nullptr || cmsgp->cmsg_len != CMSG_LEN(sizeof(int)) || cmsgp->cmsg_level != SOL_SOCKET ||
+        cmsgp->cmsg_type != SCM_RIGHTS) {
+        errno = EINVAL;
+        return -1;
+    }
+    memcpy(&fd, CMSG_DATA(cmsgp), sizeof(int));
+    return fd;
+}
+
+static void
+populateSocketAddr(sockaddr_un &addr, pid_t pid) noexcept
+{
+    memset(&addr, 0, sizeof(addr));
+    addr.sun_family = AF_UNIX;
+    addr.sun_path[0] = '\0'; // abstract namespace
+    snprintf(&(addr.sun_path[1]), sizeof(addr.sun_path) - 1, "AISSTATS%08x", static_cast(pid));
+}
+
+namespace hipFile {
+StatsServer::StatsServer()
+    : m_fd{FileDescriptor::make_managed(Context::get()->memfd_create("AISSTATS", MFD_ALLOW_SEALING))},
+      m_efd{FileDescriptor::make_managed(Context::get()->eventfd(0, 0))}, m_stats{nullptr, &statsDeleter}
+{
+    int fd{m_fd.get()};
+    Context::get()->ftruncate(fd, sizeof(Stats));
+    void *shm = Context::get()->mmap(nullptr, sizeof(Stats), PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
+    Context::get()->fcntl(fd, F_ADD_SEALS, F_SEAL_SHRINK | F_SEAL_FUTURE_WRITE);
+    m_stats = UniqueStats{new (shm) Stats{}, &statsDeleter};
+    m_thread = std::thread(&StatsServer::threadFn, this);
+}
+
+StatsServer::~StatsServer()
+{
+    if (m_thread.joinable()) {
+        uint64_t i{1};
+        write(m_efd.get(), &i, sizeof(i));
+        m_thread.join();
+    }
+}
+
+void
+StatsServer::statsDeleter(Stats *s)
+{
+    if (s == nullptr) {
+        return;
+    }
+    s->~Stats();
+    Context::get()->munmap(s, sizeof(Stats));
+}
+
+/// Thread body: listens on the per-pid abstract UNIX socket and sends the stats memfd to every client
+/// that connects. Leaves the loop once the eventfd written by the destructor becomes readable.
+void
+StatsServer::threadFn()
+{
+    int sock{socket(AF_UNIX, SOCK_STREAM, 0)};
+    pid_t pid{getpid()};
+    if (sock == -1) {
+        return;
+    }
+    sockaddr_un addr;
+    populateSocketAddr(addr, pid);
+    // Every early exit below closes the listening socket so a failed bind()/listen() cannot leak it.
+    if (bind(sock, reinterpret_cast(&addr), sizeof(addr)) == -1) {
+        close(sock);
+        return;
+    }
+    if (listen(sock, 64) == -1) {
+        close(sock);
+        return;
+    }
+    while (true) {
+        pollfd pfd[2]{{sock, POLLIN, 0}, {m_efd.get(), POLLIN, 0}};
+        if (poll(&pfd[0], 2, -1) == -1) {
+            if (errno == EINTR) {
+                continue;
+            }
+            break;
+        }
+        if (pfd[0].revents & POLLIN) {
+            socklen_t addrlen{sizeof(addr)};
+            int conn{accept(sock, reinterpret_cast(&addr), &addrlen)};
+            if (conn == -1) {
+                continue;
+            }
+            sendFd(conn, m_fd.get());
+            close(conn);
+        }
+        if (pfd[1].revents & POLLIN) {
+            break;
+        }
+    }
+    close(sock);
+}
+
+StatsClient::StatsClient(pid_t p)
+    : m_pfd{FileDescriptor::make_managed(Context::get()->pidfd_open(p, 0))}, m_pid{p}
+{
+}
+
+bool
+StatsClient::pollProcess(int timeout)
+{
+    if (m_pfd.get() == -1) {
+        return true;
+    }
+    pollfd pfd{m_pfd.get(), POLLIN, 0};
+    return poll(&pfd, 1, timeout) > 0;
+}
+
+/// Connects to the target process's stats socket and stores the descriptor it sends in m_sfd.
+/// Returns false if the socket cannot be created, the connect never succeeds, or no descriptor arrives.
+bool
+StatsClient::connectServer()
+{
+    FileDescriptor sock{FileDescriptor::make_managed(socket(AF_UNIX, SOCK_STREAM, 0))};
+    if (sock.get() == -1) {
+        return false;
+    }
+    int success{-1};
+
+    sockaddr_un addr;
+    populateSocketAddr(addr, m_pid);
+    for (int timeout{0}; !pollProcess(timeout); ++timeout) { // backoff on connect attempts
+        success = connect(sock.get(), reinterpret_cast(&addr), sizeof(struct sockaddr_un));
+        if (success == 0) {
+            break;
+        }
+    }
+    if (success == -1) {
+        return false;
+    }
+    m_sfd = FileDescriptor::make_managed(recvFd(sock.get()));
+    // recvFd() returns -1 on failure; a connection that delivered no descriptor is not a success.
+    return m_sfd.get() != -1;
+}
+
+/// Maps the received stats descriptor read-only and writes a report for its layout version to `stream`.
+/// Returns false when there is no descriptor, the mapping fails, or the version is not recognised.
+bool
+StatsClient::generateReport(std::ostream &stream)
+{
+    if (m_sfd.get() == -1) {
+        return false;
+    }
+    void *shm = mmap(nullptr, sizeof(Stats), PROT_READ, MAP_SHARED, m_sfd.get(), 0);
+    if (shm == MAP_FAILED) {
+        return false;
+    }
+    const Stats *stats = static_cast<const Stats *>(shm);
+    bool reported{true};
+    switch (stats->version) {
+        case 1:
+            generateReportV1(stream, stats);
+            break;
+        default:
+            // Unknown layout: nothing was written, so do not claim a report was produced.
+            reported = false;
+            break;
+    }
+    munmap(shm, sizeof(Stats));
+    return reported;
+}
+
+void
+StatsClient::generateReportV1(std::ostream &stream, const Stats *stats)
+{
+    if (stats == nullptr) {
+        return;
+    }
+    stream << "Total fast path reads (B): " << stats->getCounter(StatsCounters::TotalFastPathReadBytes).load()
+           << "\nTotal fast path writes (B): "
+           << stats->getCounter(StatsCounters::TotalFastPathWriteBytes).load()
+           << "\nTotal fallback path reads (B): "
+           << stats->getCounter(StatsCounters::TotalFallbackPathReadBytes).load()
+           << "\nTotal fallback path writes (B): "
+           << stats->getCounter(StatsCounters::TotalFallbackPathWriteBytes).load() << '\n';
+}
+
+void
+statsAddFastPathRead(uint64_t bytes)
+{
+    Stats *stats{Context::get()->getStats()};
+    if (stats) {
+        stats->getCounter(StatsCounters::TotalFastPathReadBytes) += bytes;
+    }
+}
+
+void
+statsAddFastPathWrite(uint64_t bytes)
+{
+    Stats *stats{Context::get()->getStats()};
+    if (stats) {
+        stats->getCounter(StatsCounters::TotalFastPathWriteBytes) += bytes;
+    }
+}
+
+void
+statsAddFallbackPathRead(uint64_t bytes)
+{
+    Stats *stats{Context::get()->getStats()};
+    if (stats) {
+        stats->getCounter(StatsCounters::TotalFallbackPathReadBytes) += bytes;
+    }
+}
+
+void
+statsAddFallbackPathWrite(uint64_t bytes)
+{
+    Stats *stats{Context::get()->getStats()};
+    if (stats) {
+        stats->getCounter(StatsCounters::TotalFallbackPathWriteBytes) += bytes;
+    }
+}
+}
diff --git a/src/amd_detail/stats.h b/src/amd_detail/stats.h
new file mode 100644
index 00000000..43ccf70b
--- /dev/null
+++ b/src/amd_detail/stats.h
@@ -0,0 +1,83 @@
+/* Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * SPDX-License-Identifier: MIT
+ */
+#pragma once
+
+#include "file-descriptor.h"
+
+#include
+#include
+#include
+#include
+#include
+
+namespace hipFile {
+
+/// When adding new fields, remember to increment Stats::version
+enum class StatsCounters {
+    TotalFastPathReadBytes,
+    TotalFastPathWriteBytes,
+    TotalFallbackPathReadBytes,
+    TotalFallbackPathWriteBytes,
+
+    Max,
+    Capacity = 64,
+};
+
+static_assert(StatsCounters::Max <= StatsCounters::Capacity, "Increase StatsCounters::Capacity");
+
+struct Stats {
+    using Array = std::array(StatsCounters::Capacity)>;
+    const uint64_t version{1};
+    Array counters;
+
+    std::atomic_uint64_t &getCounter(StatsCounters index)
+    {
+        return counters[static_cast(index)];
+    }
+    const std::atomic_uint64_t &getCounter(StatsCounters index) const
+    {
+        return counters[static_cast(index)];
+    }
+};
+
+void statsAddFastPathRead(uint64_t bytes);
+void statsAddFastPathWrite(uint64_t bytes);
+void statsAddFallbackPathRead(uint64_t bytes);
+void statsAddFallbackPathWrite(uint64_t bytes);
+
+class StatsServer {
+public:
+    StatsServer();
+    virtual ~StatsServer();
+    virtual Stats *getStats()
+    {
+        return m_stats.get();
+    }
+
+    static void statsDeleter(Stats *s);
+    using UniqueStats = std::unique_ptr;
+
+private:
+    void threadFn();
+    FileDescriptor m_fd;
+    FileDescriptor m_efd;
+    UniqueStats m_stats;
+    std::thread m_thread;
+};
+
+class StatsClient {
+public:
+    explicit StatsClient(pid_t p);
+    bool pollProcess(int timeout);
+    bool connectServer();
+    bool generateReport(std::ostream &stream);
+
+    static void generateReportV1(std::ostream &stream, const Stats *stats);
+
+private:
+    FileDescriptor m_pfd, m_sfd;
+    pid_t m_pid{0};
+};
+}
diff --git a/src/amd_detail/sys.cpp b/src/amd_detail/sys.cpp
index 23e7a7d7..912a2733 100644
--- a/src/amd_detail/sys.cpp
+++ b/src/amd_detail/sys.cpp
@@ -8,8 +8,10 @@
 #include
 #include
 #include
+#include
 #include
 #include // IWYU pragma: keep
+#include
 #include
 #include
 #include
@@ -102,6 +104,12 @@ Sys::fcntl(int fd, int op, uintptr_t arg) const
     return throwOn(-1, ::fcntl(fd, op, arg));
 }
 
+void
+Sys::ftruncate(int fd, off_t offset) const
+{
+    throwOn(-1, ::ftruncate(fd, offset));
+}
+
 struct statx
 Sys::statx(int dirfd, const char *pathname, int flags, unsigned int mask) const
 {
@@ -116,4 +124,22 @@ Sys::getenv(const char *name) const noexcept
     return ::getenv(name);
 }
 
+int
+Sys::memfd_create(const char *name, unsigned int flags) const
+{
+    return throwOn(-1, ::memfd_create(name, flags));
+}
+
+int
+Sys::eventfd(unsigned int initval, int flags) const
+{
+    return throwOn(-1, ::eventfd(initval, flags));
+}
+
+int
+Sys::pidfd_open(pid_t pid, unsigned int flags) const
+{
+    return throwOn(-1, static_cast(::syscall(SYS_pidfd_open, pid, flags)));
+}
+
 }
diff --git a/src/amd_detail/sys.h b/src/amd_detail/sys.h
index 2b3b8512..81fbc660 100644
--- a/src/amd_detail/sys.h
+++ b/src/amd_detail/sys.h
@@ -40,9 +40,14 @@ struct Sys {
     virtual struct stat fstat(int fd) const;
     virtual struct statx statx(int dirfd, const char *pathname, int flags, unsigned int mask) const;
 
-    virtual int fcntl(int fd, int op, uintptr_t arg) const;
+    virtual int  fcntl(int fd, int op, uintptr_t arg) const;
+    virtual void ftruncate(int fd, off_t offset) const;
 
     virtual char *getenv(const char *name) const noexcept;
+
+    virtual int memfd_create(const char *name, unsigned int flags) const;
+    virtual int eventfd(unsigned int initval, int flags) const;
+    virtual int pidfd_open(pid_t pid, unsigned int flags) const;
 };
 
 }
diff --git a/test/amd_detail/CMakeLists.txt b/test/amd_detail/CMakeLists.txt
index 07f61c23..c44514f5 100644
--- a/test/amd_detail/CMakeLists.txt
+++ b/test/amd_detail/CMakeLists.txt
@@ -26,6 +26,7 @@ set(TEST_SOURCE_FILES
     fastpath.cpp
     main.cpp
     mountinfo.cpp
+    stats.cpp
     stream.cpp
 )
 
diff --git a/test/amd_detail/mstats.h b/test/amd_detail/mstats.h
new file mode 100644
index 00000000..503f5e16
--- /dev/null
+++ b/test/amd_detail/mstats.h
@@ -0,0 +1,23 @@
+/* Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * SPDX-License-Identifier: MIT
+ */
+
+#pragma once
+
+#include "context.h"
+#include "stats.h"
+
+#include
+
+namespace hipFile {
+class MStatsServer : public StatsServer {
+    ContextOverride co;
+
+public:
+    MStatsServer() : co{this}
+    {
+    }
+    MOCK_METHOD(Stats *, getStats, (), (override));
+};
+}
diff --git a/test/amd_detail/msys.h b/test/amd_detail/msys.h
index 27277d7a..e921e34b 100644
--- a/test/amd_detail/msys.h
+++ b/test/amd_detail/msys.h
@@ -37,7 +37,12 @@ struct MSys : Sys {
                 (const, override));
 
     MOCK_METHOD(int, fcntl, (int fd, int op, uintptr_t arg), (const, override));
+    MOCK_METHOD(void, ftruncate, (int fd, off_t offset), (const, override));
     MOCK_METHOD(char *, getenv, (const char *name), (const, noexcept, override));
+
+    MOCK_METHOD(int, memfd_create, (const char *name, unsigned int flags), (const, override));
+    MOCK_METHOD(int, eventfd, (unsigned int, int), (const, override));
+    MOCK_METHOD(int, pidfd_open, (pid_t, unsigned int), (const, override));
 };
 
 }
diff --git a/test/amd_detail/stats.cpp b/test/amd_detail/stats.cpp
new file mode 100644
index 00000000..d00a8690
--- /dev/null
+++ b/test/amd_detail/stats.cpp
@@ -0,0 +1,72 @@
+/* Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * SPDX-License-Identifier: MIT
+ */
+
+#include "hipfile-test.h"
+#include "mstats.h"
+#include "stats.h"
+#include "msys.h"
+
+#include
+#include
+#include
+
+using namespace hipFile;
+
+using ::testing::StrictMock;
+
+// Put tests inside the macros to suppress the global constructor
+// warnings
+HIPFILE_WARN_NO_GLOBAL_CTOR_OFF
+
+struct HipFileStats : public HipFileUnopened {};
+
+#define STAT_TEST(name) \
+    TEST_F(HipFileStats, statsAdd##name) \
+    { \
+        Stats stats{}; \
+        StrictMock mstats{}; \
+        EXPECT_CALL(mstats, getStats).WillRepeatedly(testing::Return(&stats)); \
+        statsAdd##name(0x10); \
+        ASSERT_EQ(0x10, stats.getCounter(StatsCounters::Total##name##Bytes).load()); \
+        statsAdd##name(0x10); \
+        ASSERT_EQ(0x20, stats.getCounter(StatsCounters::Total##name##Bytes).load()); \
+    }
+
+STAT_TEST(FastPathRead)
+STAT_TEST(FastPathWrite)
+STAT_TEST(FallbackPathRead)
+STAT_TEST(FallbackPathWrite)
+
+TEST_F(HipFileStats, StatsServerLifetime)
+{
+    StrictMock msys{};
+    char buff[sizeof(Stats)];
+    EXPECT_CALL(msys, memfd_create).WillOnce(testing::Return(10));
+    EXPECT_CALL(msys, eventfd).WillOnce(testing::Return(11));
+    EXPECT_CALL(msys, fcntl).WillOnce(testing::Return(0));
+    EXPECT_CALL(msys, ftruncate);
+    EXPECT_CALL(msys, mmap).WillOnce(testing::Return(&buff));
+    EXPECT_CALL(msys, munmap);
+    EXPECT_CALL(msys, close).Times(2);
+    StatsServer srvr{};
+}
+
+TEST_F(HipFileStats, GenerateReportV1)
+{
+    Stats stats{};
+    std::ostringstream os{};
+    stats.getCounter(StatsCounters::TotalFastPathReadBytes) = 2;
+    stats.getCounter(StatsCounters::TotalFastPathWriteBytes) = 4;
+    stats.getCounter(StatsCounters::TotalFallbackPathReadBytes) = 6;
+    stats.getCounter(StatsCounters::TotalFallbackPathWriteBytes) = 8;
+    StatsClient::generateReportV1(os, &stats);
+    std::string str{os.str()};
+    ASSERT_GT(std::string::npos, str.find('2'));
+    ASSERT_GT(std::string::npos, str.find('4'));
+    ASSERT_GT(std::string::npos, str.find('6'));
+    ASSERT_GT(std::string::npos, str.find('8'));
+}
+
+HIPFILE_WARN_NO_GLOBAL_CTOR_ON
diff --git a/tools/ais-stats/CMakeLists.txt b/tools/ais-stats/CMakeLists.txt
new file mode 100644
index 00000000..4d268b51
--- /dev/null
+++ b/tools/ais-stats/CMakeLists.txt
@@ -0,0 +1,12 @@
+# Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
+#
+# SPDX-License-Identifier: MIT
+
+include(AISAddExecutable)
+
+ais_add_executable(
+    NAME ais-stats
+    DEPS hipfile
+    SRCS "ais-stats.cpp"
+    SYSINCLS ${HIPFILE_AMD_SOURCE_PATH} ${HIPFILE_INCLUDE_PATH}
+)
diff --git a/tools/ais-stats/ais-stats.cpp b/tools/ais-stats/ais-stats.cpp
new file mode 100644
index 00000000..46543394
--- /dev/null
+++ b/tools/ais-stats/ais-stats.cpp
@@ -0,0 +1,71 @@
+/* Copyright (c) Advanced Micro Devices, Inc. All rights reserved.
+ *
+ * SPDX-License-Identifier: MIT
+ */
+
+#include "stats.h"
+#include
+#include
+#include
+
+static void
+printUsage()
+{
+    std::cerr << "Usage:\n"
+              << " ais-stats <program> [args...]\n"
+              << " ais-stats -p <pid> [-i]\n"
+              << " -p Process to collect data from\n"
+              << " -i Generate report immediately instead of waiting for process to exit\n";
+}
+
+int
+main(int argc, char *argv[])
+{
+    pid_t pid{-1};
+    bool imm{false};
+    if (argc < 2) {
+        printUsage();
+        return 1;
+    }
+    if (std::strcmp(argv[1], "-p") == 0) {
+        if (argc < 3) {
+            printUsage();
+            return 1;
+        }
+        // Reject empty, non-numeric and non-positive values instead of silently turning them into pid 0.
+        char *end{nullptr};
+        const long parsed{std::strtol(argv[2], &end, 10)};
+        if (end == argv[2] || *end != '\0' || parsed <= 0) {
+            std::cerr << "Invalid pid: " << argv[2] << '\n';
+            return 1;
+        }
+        pid = static_cast<pid_t>(parsed);
+        imm = argc == 4 && std::strcmp(argv[3], "-i") == 0;
+    }
+    else {
+        pid = fork();
+        if (pid < 0) {
+            std::cerr << "Failed to launch " << argv[1] << '\n';
+            return 1;
+        }
+        if (pid == 0) {
+            execvp(argv[1], &argv[1]);
+            std::cerr << "Failed to launch " << argv[1] << '\n';
+            // Still in the forked child: leave without running the parent's atexit handlers/destructors.
+            _exit(1);
+        }
+    }
+    hipFile::StatsClient client{pid};
+    if (!client.connectServer()) {
+        std::cerr << "Failed to collect info from target process.\n";
+        return 1;
+    }
+    if (!imm) {
+        client.pollProcess(-1);
+    }
+    if (!client.generateReport(std::cout)) {
+        std::cerr << "No stats could be collected from target process.\n";
+        return 1;
+    }
+    return 0;
+}
diff --git a/util/format-source.sh b/util/format-source.sh
index 405a401d..afff659d 100755
--- a/util/format-source.sh
+++ b/util/format-source.sh
@@ -11,6 +11,7 @@ DIRS=(
     "examples"
     "src"
     "test"
+    "tools"
 )
 
 if [ $# -eq 1 ]; then