Skip to content

Commit a9849c0

Browse files
committed
Implement shader identifier patching during capture
1 parent cfc5204 commit a9849c0

File tree

5 files changed

+526
-15
lines changed

5 files changed

+526
-15
lines changed

renderdoc/data/hlsl/hlsl_cbuffers.h

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,23 @@ struct InstanceDesc
286286
GPUAddress blasAddress;
287287
};
288288

289+
cbuffer RayDispatchPatchCB REG(b0)
290+
{
291+
uint raydispatch_missoffs;
292+
uint raydispatch_missstride;
293+
uint raydispatch_misscount;
294+
295+
uint raydispatch_hitoffs;
296+
uint raydispatch_hitstride;
297+
uint raydispatch_hitcount;
298+
299+
uint raydispatch_calloffs;
300+
uint raydispatch_callstride;
301+
uint raydispatch_callcount;
302+
};
303+
304+
#define MAX_LOCALSIG_HANDLES 31
305+
289306
cbuffer DebugSampleOperation REG(b0)
290307
{
291308
float4 debugSampleUV;

renderdoc/data/hlsl/raytracing.hlsl

Lines changed: 116 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,3 +56,119 @@ bool InRange(BlasAddressRange addressRange, GPUAddress address)
5656
// This might cause device hang but at least we won't access incorrect addresses
5757
instanceDescs[dispatchGroup.x].blasAddress = 0;
5858
}
59+
60+
struct StateObjectLookup
61+
{
62+
uint2 id; // ResourceId
63+
uint offset;
64+
};
65+
66+
StructuredBuffer<StateObjectLookup> stateObjects : register(t0);
67+
68+
struct RecordData
69+
{
70+
uint4 identifier[2]; // 32-byte real identifier
71+
uint rootSigIndex; // only lower 16-bits are valid
72+
};
73+
74+
StructuredBuffer<RecordData> records : register(t1);
75+
76+
struct RootSig
77+
{
78+
uint numHandles;
79+
uint handleOffsets[MAX_LOCALSIG_HANDLES];
80+
};
81+
82+
StructuredBuffer<RootSig> rootsigs : register(t2);
83+
84+
struct WrappedRecord
85+
{
86+
uint2 id; // ResourceId
87+
uint index;
88+
};
89+
90+
RWByteAddressBuffer bufferToPatch : register(u0);
91+
92+
void PatchTable(uint byteOffset)
93+
{
94+
// load our wrapped record from the start of the table
95+
WrappedRecord wrappedRecord;
96+
wrappedRecord.id = bufferToPatch.Load2(byteOffset);
97+
wrappedRecord.index = bufferToPatch.Load(byteOffset + 8);
98+
99+
// find the state object it came from
100+
int i = 0;
101+
StateObjectLookup objectLookup;
102+
do
103+
{
104+
objectLookup = stateObjects[i];
105+
106+
if(objectLookup.id.x == wrappedRecord.id.x && objectLookup.id.y == wrappedRecord.id.y)
107+
break;
108+
109+
// terminate when the lookup is empty, we're out of state objects
110+
} while(objectLookup.id.x != 0 || objectLookup.id.y != 0);
111+
112+
// if didn't find a match, set a NULL shader identifier. This will fail if it's raygen but others
113+
// will in theory not crash.
114+
if(objectLookup.id.x == 0 && objectLookup.id.y == 0)
115+
{
116+
bufferToPatch.Store4(byteOffset, uint4(0, 0, 0, 0));
117+
bufferToPatch.Store4(byteOffset + 16, uint4(0, 0, 0, 0));
118+
return;
119+
}
120+
121+
// the exports from this state object are contiguous starting from the given index, look up this
122+
// identifier's export
123+
RecordData recordData = records[objectLookup.offset + wrappedRecord.index];
124+
125+
// store the unwrapped shader identifier
126+
bufferToPatch.Store4(byteOffset, recordData.identifier[0]);
127+
bufferToPatch.Store4(byteOffset + 16, recordData.identifier[1]);
128+
129+
if(recordData.rootSigIndex & 0xffff != 0xffff)
130+
{
131+
RootSig sig = rootsigs[recordData.rootSigIndex];
132+
133+
for(int i = 0; i < sig.numHandles; i++)
134+
{
135+
// TODO: patch descriptor handle at offset sig.handleOffsets[i]
136+
}
137+
}
138+
}
139+
140+
// Each SV_GroupId corresponds to one shader record to patch
141+
[numthreads(1, 1, 1)] void RENDERDOC_PatchRayDispatchCS(uint3 dispatchGroup
142+
: SV_GroupId) {
143+
uint group = dispatchGroup.x;
144+
145+
if(group == 0)
146+
{
147+
PatchTable(0);
148+
return;
149+
}
150+
151+
group--;
152+
153+
if(group < raydispatch_misscount)
154+
{
155+
PatchTable(raydispatch_missoffs + raydispatch_missstride * group);
156+
return;
157+
}
158+
159+
group -= raydispatch_misscount;
160+
161+
if(group < raydispatch_hitcount)
162+
{
163+
PatchTable(raydispatch_hitoffs + raydispatch_hitstride * group);
164+
return;
165+
}
166+
167+
group -= raydispatch_hitcount;
168+
169+
if(group < raydispatch_callcount)
170+
{
171+
PatchTable(raydispatch_calloffs + raydispatch_callstride * group);
172+
return;
173+
}
174+
}

renderdoc/driver/d3d12/d3d12_command_list4_wrap.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -851,13 +851,12 @@ bool WrappedID3D12GraphicsCommandList::PatchAccStructBlasAddress(
851851

852852
dxrCmd->SetPipelineState(patchInfo.m_pipeline);
853853
dxrCmd->SetComputeRootSignature(patchInfo.m_rootSignature);
854-
dxrCmd->SetComputeRoot32BitConstant(
855-
(UINT)D3D12PatchAccStructRootParamIndices::RootConstantBuffer, (UINT)addressCount, 0);
856-
dxrCmd->SetComputeRootShaderResourceView(
857-
(UINT)D3D12PatchAccStructRootParamIndices::RootAddressPairSrv, addressPairResAddress);
858-
dxrCmd->SetComputeRootUnorderedAccessView(
859-
(UINT)D3D12PatchAccStructRootParamIndices::RootPatchedAddressUav,
860-
patchRaytracing->m_patchedInstanceBuffer->Address());
854+
dxrCmd->SetComputeRoot32BitConstant((UINT)D3D12PatchTLASBuildParam::RootConstantBuffer,
855+
(UINT)addressCount, 0);
856+
dxrCmd->SetComputeRootShaderResourceView((UINT)D3D12PatchTLASBuildParam::RootAddressPairSrv,
857+
addressPairResAddress);
858+
dxrCmd->SetComputeRootUnorderedAccessView((UINT)D3D12PatchTLASBuildParam::RootPatchedAddressUav,
859+
patchRaytracing->m_patchedInstanceBuffer->Address());
861860
dxrCmd->Dispatch(accStructInput->Inputs.NumDescs, 1, 1);
862861

863862
{

0 commit comments

Comments
 (0)