Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions _codeql_detected_source_root
78 changes: 45 additions & 33 deletions include/wil/result.h
Original file line number Diff line number Diff line change
Expand Up @@ -383,65 +383,77 @@ namespace details_abi

~ThreadLocalStorage() WI_NOEXCEPT
{
for (auto& entry : m_hashArray)
{
Node* pNode = entry;
while (pNode != nullptr)
{
auto pCurrent = pNode;
#pragma warning(push)
#pragma warning(disable : 6001) // https://github.com/microsoft/wil/issues/164
pNode = pNode->pNext;
#pragma warning(pop)
pCurrent->~Node();
::HeapFree(::GetProcessHeap(), 0, pCurrent);
}
entry = nullptr;
}
// Cleanup is automatic via unique_process_heap_ptr destructors
}

// Note: Can return nullptr even when (shouldAllocate == true) upon allocation failure
T* GetLocal(bool shouldAllocate = false) WI_NOEXCEPT
{
// Get the current thread ID
DWORD const threadId = ::GetCurrentThreadId();

// Determine the appropriate bucket for this thread
size_t const index = ((threadId >> 2) % ARRAYSIZE(m_hashArray)); // Reduce hash collisions; thread IDs are even.
for (auto pNode = m_hashArray[index]; pNode != nullptr; pNode = pNode->pNext)
Bucket& bucket = m_hashArray[index];

// Lock the bucket and search for an existing entry
{
if (pNode->threadId == threadId)
auto lock = bucket.lock.lock_shared();
for (auto pNode = bucket.head.get(); pNode != nullptr; pNode = pNode->pNext.get())
{
return &pNode->value;
if (pNode->threadId == threadId)
{
return &pNode->value;
}
}
}

if (shouldAllocate)
if (!shouldAllocate)
{
if (auto pNewRaw = details::ProcessHeapAlloc(0, sizeof(Node)))
{
auto pNew = new (pNewRaw) Node{threadId};
return nullptr;
}

Node* pFirst;
do
{
pFirst = m_hashArray[index];
pNew->pNext = pFirst;
} while (::InterlockedCompareExchangePointer(reinterpret_cast<PVOID volatile*>(m_hashArray + index), pNew, pFirst) !=
pFirst);
// No entry for us, make a new one and insert it at the head
wil::unique_process_heap_ptr<Node> newNode;
void* newMemory = details::ProcessHeapAlloc(0, sizeof(Node));
if (!newMemory)
{
return nullptr;
}
newNode.reset(new (newMemory) Node{threadId});

return &pNew->value;
auto lock = bucket.lock.lock_exclusive();

// Check again if an entry was created by another thread while we were waiting for the exclusive lock
for (auto pNode = bucket.head.get(); pNode != nullptr; pNode = pNode->pNext.get())
{
if (pNode->threadId == threadId)
{
return &pNode->value;
}
}
return nullptr;

// No duplicate entry, safe to insert
newNode->pNext = wistd::move(bucket.head);
bucket.head = wistd::move(newNode);
return &bucket.head->value;
}

private:
struct Node
{
DWORD threadId = 0xffffffffU;
Node* pNext = nullptr;
wil::unique_process_heap_ptr<Node> pNext;
T value{};
};

Node* volatile m_hashArray[10]{};
struct Bucket
{
wil::srwlock lock;
wil::unique_process_heap_ptr<Node> head;
};

Bucket m_hashArray[10]{};
};

struct ThreadLocalFailureInfo
Expand Down
62 changes: 62 additions & 0 deletions tests/ResultTests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@

#include "common.h"

#include <thread>
#include <vector>
#include <atomic>

static volatile long objectCount = 0;
struct SharedObject
{
Expand Down Expand Up @@ -773,3 +777,61 @@ TEST_CASE("ResultTests::ReportDoesNotChangeLastError", "[result]")
LOG_IF_WIN32_BOOL_FALSE(FALSE);
REQUIRE(::GetLastError() == ERROR_ABIOS_ERROR);
}

TEST_CASE("ResultTests::ThreadLocalStorage", "[result]")
{
constexpr int NUM_THREADS = 10;
constexpr int ITERATIONS = 1000;

wil::details_abi::ThreadLocalStorage<int> storage;
std::atomic<int> errors{0};
std::vector<std::thread> threads;

// Create multiple threads that will access the thread local storage concurrently
for (int i = 0; i < NUM_THREADS; ++i)
{
threads.emplace_back([&storage, &errors, i]() {
for (int j = 0; j < ITERATIONS; ++j)
{
// Get or create thread local value
int* pValue = storage.GetLocal(true);
if (!pValue)
{
errors.fetch_add(1);
continue;
}

// First time should be zero-initialized
if (j == 0 && *pValue != 0)
{
errors.fetch_add(1);
}

// Set a thread-specific value
*pValue = i * 1000 + j;

// Verify we get the same pointer on subsequent calls
int* pValue2 = storage.GetLocal(false);
if (pValue != pValue2)
{
errors.fetch_add(1);
}

// Verify the value is correct
if (*pValue2 != i * 1000 + j)
{
errors.fetch_add(1);
}
}
});
}

// Wait for all threads to complete
for (auto& thread : threads)
{
thread.join();
}

// Verify no errors occurred
REQUIRE(errors.load() == 0);
}
Loading