Skip to content

Commit 655e8f7

Browse files
committed
better transitions
rtx broken for now
1 parent fb635cd commit 655e8f7

File tree

21 files changed

+852
-401
lines changed

21 files changed

+852
-401
lines changed

sources/DirectXFramework/DX12/CommandList.cpp

Lines changed: 93 additions & 86 deletions
Large diffs are not rendered by default.

sources/DirectXFramework/DX12/CommandList.h

Lines changed: 215 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -336,32 +336,132 @@ namespace DX12
336336

337337
class TransitionCommandList;
338338

339+
340+
struct TransitionPoint
341+
{
342+
std::list<Transition> transitions;
343+
std::list<Resource*> uav_transitions;
344+
std::list<Resource*> aliasing;
345+
346+
};
347+
348+
enum class TransitionType:int
349+
{
350+
ZERO,
351+
FIRST,
352+
LAST
353+
};
339354
class Transitions : public virtual CommandListBase
340355
{
341-
std::vector<D3D12_RESOURCE_BARRIER> transitions;
342-
//unsigned int transition_count = 0;
356+
343357

344358
std::list<Resource*> used_resources;
345359
std::list<TrackedResource::ptr> tracked_resources;
346360

347361
std::shared_ptr<TransitionCommandList> transition_list;
348362

349-
363+
friend class SignatureDataSetter;
364+
friend class Sendable;
365+
friend class Eventer;
366+
friend class ResourceStateManager;
350367
protected:
351-
void reset();
368+
void begin();
369+
void on_execute();
352370
std::list<ComPtr<ID3D12PipelineState>> tracked_psos;
353371

372+
std::list<TransitionPoint> transition_points;
373+
TransitionPoint zero_tranzition;
374+
375+
void create_transition_point()
376+
{
377+
auto point = &transition_points.emplace_back();
378+
bool first = transition_points.size() == 1;
379+
// auto &point = transition_points.back();
380+
compiler.func([point, first](ID3D12GraphicsCommandList4* list)
381+
{
382+
383+
std::vector<D3D12_RESOURCE_BARRIER> transitions;
384+
for(auto uav:point->uav_transitions)
385+
{
386+
transitions.emplace_back(CD3DX12_RESOURCE_BARRIER::UAV(uav->get_native().Get()));
387+
}
388+
389+
for (auto& uav : point->aliasing)
390+
{
391+
transitions.emplace_back(CD3DX12_RESOURCE_BARRIER::Aliasing(nullptr, uav->get_native().Get()));
392+
}
393+
394+
for (auto& transition : point->transitions)
395+
{
396+
auto prev_transition = transition.prev_transition;
397+
398+
if (!prev_transition)
399+
continue;
400+
401+
if(prev_transition->wanted_state == transition.wanted_state)
402+
continue;
403+
404+
405+
transitions.emplace_back(CD3DX12_RESOURCE_BARRIER::Transition(transition.resource->get_native().Get(),
406+
static_cast<D3D12_RESOURCE_STATES>(prev_transition->wanted_state),
407+
static_cast<D3D12_RESOURCE_STATES>(transition.wanted_state),
408+
transition.subres));
409+
}
410+
411+
if (!transitions.empty())
412+
{
413+
list->ResourceBarrier((UINT)transitions.size(), transitions.data());
414+
transitions.clear();
415+
}
416+
});
417+
}
418+
419+
public:
420+
void transition(const Resource* resource, ResourceState state, UINT subres = D3D12_RESOURCE_BARRIER_ALL_SUBRESOURCES);
421+
void transition(const Resource::ptr& resource, ResourceState state, UINT subres = D3D12_RESOURCE_BARRIER_ALL_SUBRESOURCES);
422+
354423
public:
355424
void free_resources();
356425
std::list<ComPtr<ID3D12Heap>> tracked_heaps;
357-
void flush_transitions();
426+
UINT transition_count = 0;
427+
428+
Transition* create_transition(const Resource* resource, UINT subres = D3D12_RESOURCE_BARRIER_ALL_SUBRESOURCES, TransitionType type = TransitionType::LAST)
429+
{
430+
TransitionPoint* point = nullptr;
431+
432+
if (type == TransitionType::FIRST) point = &transition_points.front();
433+
if (type == TransitionType::LAST) point = &transition_points.back();
434+
if (type == TransitionType::ZERO) point = &zero_tranzition;
435+
436+
437+
438+
Transition& transition = point->transitions.emplace_back();
439+
440+
// transition.index = 0æ..first?0:transition_points.size() - 1;
441+
transition.resource = const_cast<Resource*>(resource);
442+
transition.subres = subres;
443+
444+
return &transition;
445+
}
446+
447+
void create_uav_transition(const Resource* resource)
448+
{
449+
auto& point = transition_points.back();
450+
point.uav_transitions.emplace_back(const_cast<Resource*>(resource));
451+
}
452+
453+
void create_aliasing_transition(const Resource* resource)
454+
{
455+
auto& point = transition_points.back();
456+
point.aliasing.emplace_back(const_cast<Resource*>(resource));
457+
}
358458

359-
void transition(const Resource* resource, ResourceState state, UINT subres = D3D12_RESOURCE_BARRIER_ALL_SUBRESOURCES);
360-
void transition(const Resource::ptr& resource, ResourceState state, UINT subres = D3D12_RESOURCE_BARRIER_ALL_SUBRESOURCES);
361459
void use_resource(const Resource* resource);
362460
public:
363461
void prepare_transitions(Transitions* to, bool all);
364462

463+
464+
void merge_transition(Transitions* to, Resource* res);
365465
void transition_uav(Resource* resource);
366466
void transition(Resource* from, Resource* to);
367467
std::shared_ptr<TransitionCommandList> fix_pretransitions();
@@ -654,7 +754,7 @@ namespace DX12
654754
};
655755

656756

657-
class Eventer : public virtual CommandListBase, public TimedRoot
757+
class Eventer : public virtual CommandListBase, public TimedRoot
658758
{
659759

660760
std::list<std::wstring> names;
@@ -862,18 +962,18 @@ namespace DX12
862962

863963
void clear_uav(const Handle& h, vec4 ClearColor = vec4(0, 0, 0, 0))
864964
{
965+
create_transition_point();
865966
transition_uav(h.resource_info);
866-
867-
flush_transitions();
967+
868968
auto handle = get_cpu_heap(DescriptorHeapType::CBV_SRV_UAV).place(h);
869969
get_native_list()->ClearUnorderedAccessViewFloat(handle.gpu, h.cpu, h.resource_info->resource_ptr->get_native().Get(), reinterpret_cast<FLOAT*>(ClearColor.data()), 0, nullptr);
870970
}
871971

872972

873973
void clear_rtv(const Handle& h, vec4 ClearColor = vec4(0, 0, 0, 0))
874974
{
975+
create_transition_point();
875976
transition_rtv(h.resource_info);
876-
flush_transitions();
877977
get_native_list()->ClearRenderTargetView(h.cpu, ClearColor.data(), 0, nullptr);
878978
}
879979

@@ -886,15 +986,17 @@ namespace DX12
886986

887987
void clear_stencil(Handle dsv, UINT8 stencil = 0)
888988
{
989+
create_transition_point();
889990
transition_dsv(dsv.resource_info);
890-
flush_transitions();
991+
891992
get_native_list()->ClearDepthStencilView(dsv.cpu, D3D12_CLEAR_FLAG_STENCIL, 0, stencil, 0, nullptr);
892993
}
893994

894995
void clear_depth(Handle dsv, float depth = 0)
895996
{
997+
create_transition_point();
896998
transition_dsv(dsv.resource_info);
897-
flush_transitions();
999+
8981000
get_native_list()->ClearDepthStencilView(dsv.cpu, D3D12_CLEAR_FLAG_DEPTH, depth, 0, 0, nullptr);
8991001
}
9001002

@@ -920,7 +1022,6 @@ namespace DX12
9201022

9211023

9221024
//TODO: remove
923-
void update_resource(Resource::ptr resource, UINT first_subresource, UINT sub_count, D3D12_SUBRESOURCE_DATA* data);
9241025
void update_buffer(Resource::ptr resource, UINT offset, const char* data, UINT size);
9251026
void update_texture(Resource::ptr resource, ivec3 offset, ivec3 box, UINT sub_resource, const char* data, UINT row_stride, UINT slice_stride = 0);
9261027
void update_buffer(Resource* resource, UINT offset, const char* data, UINT size);
@@ -938,17 +1039,106 @@ namespace DX12
9381039

9391040
class SignatureDataSetter
9401041
{
1042+
struct RowInfo
1043+
{
1044+
HandleType type;
1045+
bool dirty = false;
1046+
HandleTableLight table;
1047+
};
1048+
std::vector<RowInfo> tables;
1049+
1050+
friend class CommandList;
1051+
1052+
9411053
protected:
9421054
CommandList& base;
943-
SignatureDataSetter(CommandList& base) :base(base) { }
1055+
SignatureDataSetter(CommandList& base) :base(base) {
1056+
tables.resize(32); // !!!!!!!!!!!
1057+
}
1058+
1059+
virtual void set(UINT, const HandleTableLight&) = 0;
1060+
virtual void set_const_buffer(UINT i, const D3D12_GPU_VIRTUAL_ADDRESS&) = 0;
1061+
1062+
1063+
void reset_tables()
1064+
{
1065+
for (auto& row : tables)
1066+
{
1067+
row.dirty = false;
1068+
}
1069+
1070+
}
1071+
1072+
1073+
void commit_tables()
1074+
{
1075+
for(auto &row:tables)
1076+
{
1077+
if (!row.dirty) continue;
1078+
1079+
auto &table = row.table;
1080+
auto type = row.type;
1081+
for (UINT i = 0; i < (UINT)table.get_count(); ++i)
1082+
{
1083+
const auto& h = table[i];
1084+
if (h.resource_info && h.resource_info->resource_ptr)
1085+
{
1086+
if (h.resource_info->resource_ptr->get_heap_type() == HeapType::DEFAULT || h.resource_info->resource_ptr->get_heap_type() == HeapType::RESERVED)
1087+
{
1088+
if (type == HandleType::SRV) get_base().transition_srv(h.resource_info);
1089+
else if (type == HandleType::UAV) get_base().transition_uav(h.resource_info);
1090+
else assert(false);
1091+
}
1092+
else
1093+
{
1094+
get_base().use_resource(h.resource_info->resource_ptr);
1095+
}
1096+
}
1097+
}
1098+
1099+
row.dirty = false;
1100+
}
1101+
1102+
}
9441103
public:
9451104

9461105
CommandList& get_base() {
9471106
return base;
9481107
}
9491108
virtual void set_signature(const RootSignature::ptr&) = 0;
950-
virtual void set(UINT, const HandleTableLight&) = 0;
951-
virtual void set_const_buffer(UINT i, const D3D12_GPU_VIRTUAL_ADDRESS&) = 0;
1109+
1110+
template<HandleType type>
1111+
void set_table(UINT index, const HandleTableLight& table)
1112+
{
1113+
1114+
1115+
auto &row = tables[index];
1116+
1117+
row.type = type;
1118+
row.table = table;
1119+
row.dirty = true;
1120+
set(index, table);
1121+
}
1122+
1123+
1124+
void set_cb(UINT index, const ResourceAddress& address)
1125+
{
1126+
1127+
if (address.resource)
1128+
{
1129+
if (address.resource->get_heap_type() == HeapType::DEFAULT || address.resource->get_heap_type() == HeapType::RESERVED)
1130+
{
1131+
get_base().transition(address.resource, ResourceState::VERTEX_AND_CONSTANT_BUFFER);
1132+
}
1133+
else
1134+
{
1135+
get_base().use_resource(address.resource);
1136+
}
1137+
1138+
}
1139+
set_const_buffer(index, address.address);
1140+
}
1141+
9521142

9531143
template<class T>
9541144
std::unique_ptr<T> wrap()
@@ -996,13 +1186,10 @@ namespace DX12
9961186
void begin();
9971187
void end();
9981188
void on_execute();
999-
public:
10001189

1001-
void set_const_buffer(UINT, const D3D12_GPU_VIRTUAL_ADDRESS&)override;
1002-
1190+
void set_const_buffer(UINT, const D3D12_GPU_VIRTUAL_ADDRESS&)override;
10031191
void set(UINT, const HandleTableLight&)override;
10041192

1005-
10061193
public:
10071194

10081195
CommandList& get_base()
@@ -1031,23 +1218,18 @@ namespace DX12
10311218

10321219

10331220
void set_layout(Layouts layout);
1034-
10351221
void set_heaps(DescriptorHeap::ptr& a, DescriptorHeap::ptr& b);
10361222

10371223
void set_scissor(sizer_long rect);
10381224
void set_viewport(Viewport viewport);
1039-
10401225
void set_viewport(vec4 viewport);
1041-
10421226
void set_scissors(sizer_long rect);
10431227
void set_viewports(std::vector<Viewport> viewports);
10441228

10451229

10461230
void set_rtv(std::initializer_list<Handle> rt, Handle h);
1047-
10481231
void set_rtv(const HandleTable&, Handle);
1049-
void set_rtv(int c, Handle rt, Handle h);
1050-
1232+
void set_rtv(int c, Handle rt, Handle h);
10511233
void set_rtv(const HandleTableLight&, Handle);
10521234

10531235
void draw(D3D12_DRAW_INDEXED_ARGUMENTS args)
@@ -1103,10 +1285,7 @@ namespace DX12
11031285
if (h.is_valid())
11041286
get_base().transition_dsv(h.resource_info);
11051287

1106-
get_base().flush_transitions();
1107-
1108-
1109-
1288+
11101289
CD3DX12_CPU_DESCRIPTOR_HANDLE ar[] = { (rtvlist.cpu)... };
11111290
list->OMSetRenderTargets(size(ar), ar, false, h.is_valid() ? &h.cpu : nullptr);
11121291
}
@@ -1123,11 +1302,10 @@ namespace DX12
11231302
list->IASetIndexBuffer(&view.view);
11241303
}
11251304

1305+
1306+
11261307
void draw(UINT vertex_count, UINT vertex_offset = 0, UINT instance_count = 1, UINT instance_offset = 0);
11271308
void draw_indexed(UINT index_count, UINT index_offset, UINT vertex_offset, UINT instance_count = 1, UINT instance_offset = 0);
1128-
1129-
1130-
11311309
void execute_indirect(IndirectCommand& command_types, UINT max_commands, Resource* command_buffer, UINT64 command_offset = 0, Resource* counter_buffer = nullptr, UINT64 counter_offset = 0);
11321310

11331311
};
@@ -1191,12 +1369,14 @@ namespace DX12
11911369

11921370
void build_ras(const D3D12_BUILD_RAYTRACING_ACCELERATION_STRUCTURE_DESC &desc)
11931371
{
1194-
get_base().flush_transitions();
1372+
11951373
list->BuildRaytracingAccelerationStructure(&desc, 0, nullptr);
11961374
}
11971375

11981376
void dispatch_rays(const D3D12_DISPATCH_RAYS_DESC &desc)
11991377
{
1378+
base.create_transition_point();
1379+
commit_tables();
12001380
list->DispatchRays(&desc);
12011381
}
12021382

0 commit comments

Comments
 (0)