Created August 31, 2023 05:54
-
-
Save unitycoder/0ba1bfa2cb82e434a0614c42b9ed778a to your computer and use it in GitHub Desktop.
Parallel prefix sum ComputeShader
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// https://forum.unity.com/threads/parallel-prefix-sum-computeshader.518397/#post-7887517

// Thread-group width used by every kernel in this file.
// Must equal the "threadsPerGroup" constant in any host script that dispatches these kernels.
#define THREADS_PER_GROUP 512

// Element count used to bound buffer accesses in the kernels below
// (set by the host via SetInt("N", ...) before each dispatch).
int N;

// Source and destination buffers; the host rebinds them (SetBuffer) between dispatches.
StructuredBuffer<uint> InputBufR;
RWStructuredBuffer<uint> OutputBufW;

// Per-group scratch used by Scan(): one slot per thread of the group.
groupshared uint bucket[THREADS_PER_GROUP];
// Inclusive scan (Hillis-Steele, log2(THREADS_PER_GROUP) passes) of one value per
// thread across the current thread group, using the groupshared `bucket` array.
// Every thread of the group must call this, because it contains group barriers.
//   id : SV_DispatchThreadID of the caller; index written in OutputBufW.
//   gi : SV_GroupIndex of the caller; slot in `bucket`.
//   x  : this thread's input value.
// Result: OutputBufW[id] = sum of x over threads 0..gi of this group.
void Scan(uint id, uint gi, uint x)
{
    bucket[gi] = x;

    [unroll]
    for (uint t = 1; t < THREADS_PER_GROUP; t <<= 1)
    {
        // Make the previous pass's writes visible before reading a neighbour's slot.
        GroupMemoryBarrierWithGroupSync();
        uint temp = bucket[gi];
        if (gi >= t)
            temp += bucket[gi - t];
        // Every thread must finish reading before any slot is overwritten.
        GroupMemoryBarrierWithGroupSync();
        bucket[gi] = temp;
    }

    // Dispatches are rounded up to whole groups, so trailing threads of the last
    // group can index past the end of OutputBufW. Skip those writes instead of
    // relying on platform-specific out-of-bounds behaviour. This check sits after
    // the loop so that all threads still reach every barrier above.
    uint outputCount, outputStride;
    OutputBufW.GetDimensions(outputCount, outputStride);
    if (id < outputCount)
        OutputBufW[id] = bucket[gi];
}
// Kernel: inclusive scan of InputBufR within each thread group, each group independently.
// Threads at or beyond N contribute 0 so that the whole group still takes part in Scan().
#pragma kernel ScanInGroupsInclusive
[numthreads(THREADS_PER_GROUP, 1, 1)]
void ScanInGroupsInclusive(uint id : SV_DispatchThreadID, uint gi : SV_GroupIndex)
{
    bool inRange = (int)id < N;
    uint value = 0;
    if (inRange)
    {
        value = InputBufR[id];
    }
    Scan(id, gi, value);
}
// Kernel: per-group scan of the input shifted right by one element. Thread `id`
// reads InputBufR[id - 1], and thread 0 (or any out-of-range index) uses 0. After
// the ScanSums and AddScannedSums passes, this makes the final result exclusive.
#pragma kernel ScanInGroupsExclusive
[numthreads(THREADS_PER_GROUP, 1, 1)]
void ScanInGroupsExclusive(uint id : SV_DispatchThreadID, uint gi : SV_GroupIndex)
{
    int source = (int)(id - 1); // unsigned wrap makes this -1 when id == 0
    uint value = 0;
    // Plain `if` rather than `?:`, so the buffer is never read at an invalid index.
    if (source >= 0 && source < N)
    {
        value = InputBufR[source];
    }
    Scan(id, gi, value);
}
// Kernel: scan the per-group results of a preceding ScanInGroupsInclusive/Exclusive
// (or ScanSums) pass. Thread `id` picks up InputBufR[id * THREADS_PER_GROUP - 1],
// which is the last element of input group id - 1. Thread 0 and indices at or beyond
// N use 0. Scanning these values yields the amount AddScannedSums adds to each input
// group; higher pyramid levels complete it when there are more than
// THREADS_PER_GROUP groups.
#pragma kernel ScanSums
[numthreads(THREADS_PER_GROUP, 1, 1)]
void ScanSums(uint id : SV_DispatchThreadID, uint gi : SV_GroupIndex)
{
    int lastOfPrevGroup = (int)(id * THREADS_PER_GROUP - 1); // -1 when id == 0
    uint groupTotal = 0;
    if (lastOfPrevGroup >= 0 && lastOfPrevGroup < N)
    {
        groupTotal = InputBufR[lastOfPrevGroup];
    }
    Scan(id, gi, groupTotal);
}
// Kernel: fold the scanned group offsets back into the per-group scan results.
// Every element of OutputBufW in thread group `gid` gets InputBufR[gid] added to it,
// which completes the prefix sum at this level.
#pragma kernel AddScannedSums
[numthreads(THREADS_PER_GROUP, 1, 1)]
void AddScannedSums(uint id : SV_DispatchThreadID, uint gid : SV_GroupID)
{
    // Threads at or past N have no element to update.
    if ((int)id >= N)
        return;

    OutputBufW[id] += InputBufR[gid];
}
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/// <summary>
/// Host-side driver for the prefix-sum kernels in ScanOperations.compute
/// (ScanInGroupsInclusive, ScanSums, AddScannedSums).
/// </summary>
/// <remarks>
/// An inclusive scan is built in three phases:
/// (1) scan each thread group in place;
/// (2) scan the per-group results up through a pyramid of intermediate buffers;
/// (3) add the scanned totals back down, level by level.
/// This is a mutable struct that owns GPU buffers. Keep it in a field rather than
/// copying it (copies share the same buffers), and call <see cref="Release"/> when finished.
/// NOTE(review): NUM_GROUPS is defined elsewhere; presumably it is ceil(count / groupSize).
/// </remarks>
struct ScanHelper
{
    const int threadsPerGroup = 512; // THREADS_PER_GROUP in ScanOperations.compute

    /// <summary>Element capacity the intermediate buffers are sized for; 0 when nothing is allocated.</summary>
    public int size;

    /// <summary>Intermediate group-sum buffers, one per pyramid level (level 0 first).</summary>
    public List<ComputeBuffer> group_buffer;

    /// <summary>Element count of the buffer at the same index in <see cref="group_buffer"/>.</summary>
    public List<int> work_size;

    /// <summary>
    /// Writes the inclusive prefix sum of the first <paramref name="num"/> elements of
    /// <paramref name="inputs"/> into <paramref name="outputs"/>.
    /// </summary>
    /// <param name="num">Number of elements to scan; 0 is a no-op.</param>
    /// <param name="scanOperations">The ScanOperations compute shader.</param>
    /// <param name="inputs">Source uint buffer (assumed to hold at least <paramref name="num"/> elements).</param>
    /// <param name="outputs">Destination uint buffer (assumed to hold at least <paramref name="num"/> elements).</param>
    /// <exception cref="System.ArgumentOutOfRangeException"><paramref name="num"/> is negative.</exception>
    public void InclusiveScan(int num, ComputeShader scanOperations,
        ComputeBuffer inputs, ComputeBuffer outputs)
    {
        if (num < 0)
            throw new System.ArgumentOutOfRangeException(nameof(num), num, "Element count must be non-negative.");
        if (num == 0)
            return; // nothing to scan; also avoids dispatching zero thread groups

        this.RequireBuffer(num);

        // 1. Per-group scan: each outputs[i] becomes the inclusive sum of inputs within i's group.
        int kernelScan = scanOperations.FindKernel("ScanInGroupsInclusive");
        scanOperations.SetInt("N", num);
        scanOperations.SetBuffer(kernelScan, "InputBufR", inputs);
        scanOperations.SetBuffer(kernelScan, "OutputBufW", outputs);
        scanOperations.Dispatch(kernelScan, NUM_GROUPS(num, threadsPerGroup), 1, 1);

        // A single (possibly full) group already holds the complete scan.
        if (num <= threadsPerGroup)
            return;

        int kernelScanSums = scanOperations.FindKernel("ScanSums");
        int kernelAdd = scanOperations.FindKernel("AddScannedSums");

        // 2. Scan the per-group sums, continuing up the pyramid.
        DispatchScanSums(scanOperations, kernelScanSums, num,
            outputs, this.group_buffer[0], this.work_size[0]);
        for (int l = 0; l < this.group_buffer.Count - 1; ++l)
        {
            DispatchScanSums(scanOperations, kernelScanSums, this.work_size[l],
                this.group_buffer[l], this.group_buffer[l + 1], this.work_size[l + 1]);
        }

        // 3. Add the scanned group sums back down, top level first.
        for (int l = this.group_buffer.Count - 1; l > 0; --l)
        {
            DispatchAddScannedSums(scanOperations, kernelAdd, this.work_size[l - 1],
                this.group_buffer[l], this.group_buffer[l - 1]);
        }
        // Dispatch exactly the groups that cover `num`, as in step 1. work_size[0] is
        // sized for the over-allocated capacity and would launch idle groups.
        DispatchAddScannedSums(scanOperations, kernelAdd, num, this.group_buffer[0], outputs);
    }

    /// <summary>
    /// Ensures intermediate buffers exist for scans of up to <paramref name="alloc_sz"/> elements.
    /// If the current capacity is too small, this releases the old buffers and builds a new
    /// pyramid sized for 1.5x the request, so that moderate growth does not reallocate.
    /// </summary>
    /// <param name="alloc_sz">Number of elements the next scan will process.</param>
    public void RequireBuffer(int alloc_sz)
    {
        if (this.size >= alloc_sz && this.group_buffer != null)
            return;

        this.Release();
        this.size = (int)(alloc_sz * 1.5);
        this.group_buffer = new List<ComputeBuffer>();
        this.work_size = new List<int>();

        // Each level holds one entry per thread group of the level below. Stop once a
        // single thread group can scan the whole level.
        int work_sz = this.size;
        while (work_sz > threadsPerGroup)
        {
            work_sz = NUM_GROUPS(work_sz, threadsPerGroup);
            this.group_buffer.Add(new ComputeBuffer(work_sz, sizeof(uint)));
            this.work_size.Add(work_sz);
        }
    }

    /// <summary>
    /// Disposes all intermediate buffers and resets the capacity, so the next
    /// <see cref="RequireBuffer"/> call reallocates. Calling it again on the same
    /// instance is a no-op.
    /// </summary>
    public void Release()
    {
        if (this.group_buffer != null)
        {
            foreach (ComputeBuffer buffer in this.group_buffer)
            {
                if (buffer != null)
                    buffer.Dispose();
            }
        }
        this.group_buffer = null;
        this.work_size = null;
        // Without this reset, a scan after Release() would skip reallocation in
        // RequireBuffer and then index the null group_buffer.
        this.size = 0;
    }

    // Binds and dispatches ScanSums: scans the group results of `source` (inputCount
    // elements) into `target`, launching enough groups to cover targetCount entries.
    static void DispatchScanSums(ComputeShader shader, int kernel, int inputCount,
        ComputeBuffer source, ComputeBuffer target, int targetCount)
    {
        shader.SetInt("N", inputCount);
        shader.SetBuffer(kernel, "InputBufR", source);
        shader.SetBuffer(kernel, "OutputBufW", target);
        shader.Dispatch(kernel, NUM_GROUPS(targetCount, threadsPerGroup), 1, 1);
    }

    // Binds and dispatches AddScannedSums: adds groupSums[g] to every element of
    // thread group g in `target` (count elements).
    static void DispatchAddScannedSums(ComputeShader shader, int kernel, int count,
        ComputeBuffer groupSums, ComputeBuffer target)
    {
        shader.SetInt("N", count);
        shader.SetBuffer(kernel, "InputBufR", groupSums);
        shader.SetBuffer(kernel, "OutputBufW", target);
        shader.Dispatch(kernel, NUM_GROUPS(count, threadsPerGroup), 1, 1);
    }
}
// Example usage from a Unity component.

// The ScanOperations compute shader (presumably assigned in the Inspector).
[SerializeField] private ComputeShader scanOperations;

// Owns the intermediate per-level buffers between scans.
// NOTE(review): this snippet never calls mScanHelper.Release(). Call it when done
// (e.g. from OnDestroy) so the intermediate ComputeBuffers are freed.
private ScanHelper mScanHelper;

// Inclusive prefix sum of the first N elements of `inputs` into `outputs`.
// This statement belongs inside a method of the component.
mScanHelper.InclusiveScan(N, scanOperations, inputs, outputs);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.