Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
90 changes: 59 additions & 31 deletions cli/fossilize_prune.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ struct PruneReplayer : StateCreatorInterface
unordered_map<Hash, const VkPipelineLayoutCreateInfo *> pipeline_layouts;
unordered_map<Hash, const VkRayTracingPipelineCreateInfoKHR *> raytracing_pipelines;
unordered_map<Hash, const VkGraphicsPipelineCreateInfo *> graphics_pipelines;
unordered_map<Hash, const VkComputePipelineCreateInfo *> compute_pipelines;
unordered_map<Hash, const VkRayTracingPipelineCreateInfoKHR *> library_raytracing_pipelines;
unordered_map<Hash, const VkGraphicsPipelineCreateInfo *> library_graphics_pipelines;

Expand Down Expand Up @@ -310,12 +311,9 @@ struct PruneReplayer : StateCreatorInterface
bool allow_pipeline = filter_shader_module((Hash)create_info->stage.module);

if (allow_pipeline)
{
access_pipeline_layout((Hash) create_info->layout);
accessed_shader_modules.insert((Hash) create_info->stage.module);
accessed_compute_pipelines.insert(hash);
}
compute_pipelines[hash] = create_info;
}

return true;
}

Expand Down Expand Up @@ -384,20 +382,27 @@ struct PruneReplayer : StateCreatorInterface

// Need to defer this since we need to access pipeline libraries.
if (allow_pipeline)
graphics_pipelines[hash] = create_info;
else if ((create_info->flags & VK_PIPELINE_CREATE_LIBRARY_BIT_KHR) != 0)
library_graphics_pipelines[hash] = create_info;
{
if ((create_info->flags & VK_PIPELINE_CREATE_LIBRARY_BIT_KHR) != 0)
library_graphics_pipelines[hash] = create_info;
else
graphics_pipelines[hash] = create_info;
}

return true;
}

void access_graphics_pipeline(Hash hash, const VkGraphicsPipelineCreateInfo *create_info)
void access_graphics_pipeline(DatabaseInterface &iface, Hash hash, const VkGraphicsPipelineCreateInfo *create_info)
{
if (accessed_graphics_pipelines.count(hash))
return;
accessed_graphics_pipelines.insert(hash);

if (!iface.has_entry(RESOURCE_GRAPHICS_PIPELINE, hash))
return;

accessed_graphics_pipelines.insert(hash);
access_pipeline_layout((Hash) create_info->layout);

if (create_info->renderPass != VK_NULL_HANDLE)
accessed_render_passes.insert((Hash) create_info->renderPass);
for (uint32_t stage = 0; stage < create_info->stageCount; stage++)
Expand All @@ -414,18 +419,25 @@ struct PruneReplayer : StateCreatorInterface
// Only need to recurse if a pipeline was only allowed due to having CREATE_LIBRARY_KHR flag.
auto lib_itr = library_graphics_pipelines.find((Hash) library_info->pLibraries[i]);
if (lib_itr != library_graphics_pipelines.end())
access_graphics_pipeline(lib_itr->first, lib_itr->second);
{
iface.add_to_implicit_whitelist(RESOURCE_GRAPHICS_PIPELINE, lib_itr->first);
access_graphics_pipeline(iface, lib_itr->first, lib_itr->second);
}
}
}
}

void access_raytracing_pipeline(Hash hash, const VkRayTracingPipelineCreateInfoKHR *create_info)
void access_raytracing_pipeline(DatabaseInterface &iface, Hash hash, const VkRayTracingPipelineCreateInfoKHR *create_info)
{
if (accessed_raytracing_pipelines.count(hash))
return;
accessed_raytracing_pipelines.insert(hash);

if (!iface.has_entry(RESOURCE_RAYTRACING_PIPELINE, hash))
return;

accessed_raytracing_pipelines.insert(hash);
access_pipeline_layout((Hash) create_info->layout);

for (uint32_t stage = 0; stage < create_info->stageCount; stage++)
accessed_shader_modules.insert((Hash) create_info->pStages[stage].module);

Expand All @@ -436,21 +448,40 @@ struct PruneReplayer : StateCreatorInterface
// Only need to recurse if a pipeline was only allowed due to having CREATE_LIBRARY_KHR flag.
auto lib_itr = library_raytracing_pipelines.find((Hash) create_info->pLibraryInfo->pLibraries[i]);
if (lib_itr != library_raytracing_pipelines.end())
access_raytracing_pipeline(lib_itr->first, lib_itr->second);
{
iface.add_to_implicit_whitelist(RESOURCE_RAYTRACING_PIPELINE, lib_itr->first);
access_raytracing_pipeline(iface, lib_itr->first, lib_itr->second);
}
}
}
}

void access_graphics_pipelines()
void access_compute_pipeline(DatabaseInterface &iface, Hash hash, const VkComputePipelineCreateInfo *create_info)
{
if (iface.has_entry(RESOURCE_COMPUTE_PIPELINE, hash))
{
access_pipeline_layout((Hash) create_info->layout);
accessed_shader_modules.insert((Hash) create_info->stage.module);
accessed_compute_pipelines.insert(hash);
}
}

void access_graphics_pipelines(DatabaseInterface &iface)
{
for (auto &pipe : graphics_pipelines)
access_graphics_pipeline(pipe.first, pipe.second);
access_graphics_pipeline(iface, pipe.first, pipe.second);
}

void access_raytracing_pipelines()
void access_raytracing_pipelines(DatabaseInterface &iface)
{
for (auto &pipe : raytracing_pipelines)
access_raytracing_pipeline(pipe.first, pipe.second);
access_raytracing_pipeline(iface, pipe.first, pipe.second);
}

void access_compute_pipelines(DatabaseInterface &iface)
{
for (auto &pipe : compute_pipelines)
access_compute_pipeline(iface, pipe.first, pipe.second);
}

bool enqueue_create_raytracing_pipeline(Hash hash, const VkRayTracingPipelineCreateInfoKHR *create_info, VkPipeline *pipeline) override
Expand Down Expand Up @@ -639,15 +670,6 @@ int main(int argc, char *argv[])
auto input_db = std::unique_ptr<DatabaseInterface>(create_database(input_db_path.c_str(), DatabaseMode::ReadOnly));
auto output_db = std::unique_ptr<DatabaseInterface>(create_database(output_db_path.c_str(), DatabaseMode::OverWrite));

if (input_db && !whitelist.empty())
{
if (!input_db->load_whitelist_database(whitelist.c_str()))
{
LOGE("Failed to install whitelist database %s.\n", whitelist.c_str());
return EXIT_FAILURE;
}
}

if (input_db && !blacklist.empty())
{
if (!input_db->load_blacklist_database(blacklist.c_str()))
Expand Down Expand Up @@ -822,13 +844,19 @@ int main(int argc, char *argv[])
}
}
}
}

if (tag == RESOURCE_GRAPHICS_PIPELINE)
prune_replayer.access_graphics_pipelines();
else if (tag == RESOURCE_RAYTRACING_PIPELINE)
prune_replayer.access_raytracing_pipelines();
// Load whitelist in order to collect accessed pipelines
if (!whitelist.empty() && !input_db->load_whitelist_database(whitelist.c_str()))
{
LOGE("Failed to install whitelist database %s.\n", whitelist.c_str());
return EXIT_FAILURE;
}

prune_replayer.access_compute_pipelines(*input_db);
prune_replayer.access_graphics_pipelines(*input_db);
prune_replayer.access_raytracing_pipelines(*input_db);

if (invert_module_pruning)
{
// In this mode we're only interesting in emitting the shader modules we did not emit for whatever reason.
Expand Down
5 changes: 5 additions & 0 deletions fossilize_db.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -265,6 +265,11 @@ bool DatabaseInterface::add_to_implicit_whitelist(DatabaseInterface &iface)
return true;
}

void DatabaseInterface::add_to_implicit_whitelist(ResourceTag tag, Hash hash)
{
impl->implicit_whitelisted[tag].insert(hash);
}

DatabaseInterface::~DatabaseInterface()
{
#ifdef _WIN32
Expand Down
2 changes: 2 additions & 0 deletions fossilize_db.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,8 @@ class DatabaseInterface
void set_overwrite_db_clear(bool value);
bool get_overwrite_db_clear() const;

void add_to_implicit_whitelist(ResourceTag tag, Hash hash);

protected:
bool test_resource_filter(ResourceTag tag, Hash hash) const;
bool add_to_implicit_whitelist(DatabaseInterface &iface);
Expand Down