@@ -15,9 +15,20 @@ function preprocess(renderer::GaussianRenderer2D)
1515 scales = renderer. splatData. scales;
1616 n = renderer. nGaussians
1717 bbs = renderer. bbs
18- CUDA. @sync begin @cuda threads= 32 blocks= div (n, 32 ) computeCov2d_kernel (cov2ds, rots, scales) end
19- CUDA. @sync begin @cuda threads= 32 blocks= div (n, 32 ) computeInvCov2d (cov2ds, invCov2ds) end
20- CUDA. @sync begin @cuda threads= 32 blocks= div (n, 32 ) computeBB (cov2ds, bbs, means, size (renderer. imageData)[1 : end - 1 ]) end
18+ CUDA. @sync begin
19+ @cuda threads= 32 blocks= div (n, 32 ) computeCov2d_kernel (cov2ds, rots, scales)
20+ end
21+ CUDA. @sync begin
22+ @cuda threads= 32 blocks= div (n, 32 ) computeInvCov2d (cov2ds, invCov2ds)
23+ end
24+ CUDA. @sync begin
25+ @cuda threads= 32 blocks= div (n, 32 ) computeBB (
26+ cov2ds,
27+ bbs,
28+ means,
29+ size (renderer. imageData)[1 : end - 1 ]
30+ )
31+ end
2132 return nothing
2233end
2334
@@ -30,17 +41,22 @@ function preprocess(renderer::GaussianRenderer3D)
3041 (w, h) = size (renderer. imageData)[1 : 2 ];
3142 # Camera related params
3243 camerasPath = joinpath (
33- ENV [" HOMEPATH" ], " Downloads" , " GaussianSplatting" , " GaussianSplatting" , " bonsai" , " cameras.json"
44+ ENV [" HOMEPATH" ],
45+ " Downloads" ,
46+ " GaussianSplatting" ,
47+ " GaussianSplatting" ,
48+ " bonsai" ,
49+ " cameras.json"
3450 ) # TODO this is hardcoded
35- camIdx = 2
36- # camera = getCamera(camerasPath, camIdx) # defaultCamera();
51+ camIdx = 1
52+ # camera = getCamera(camerasPath, camIdx) # defaultCamera();
3753 camera = defaultCamera ();
3854 near = camera. near
3955 far = camera. far
4056 T = computeTransform (camera). linear |> gpu;
4157 P = computeProjection (camera, w, h). linear |> gpu;
4258 cx = w/ 2.0
43- cy = w / 2.0
59+ cy = h / 2.0
4460 n = renderer. nGaussians
4561 fx = camera. fx |> Float32
4662 fy = camera. fy |> Float32
@@ -52,84 +68,110 @@ function preprocess(renderer::GaussianRenderer3D)
5268 quaternions = renderer. splatData. quaternions |> gpu;
5369 scales = renderer. splatData. scales |> gpu;
5470 n = renderer. nGaussians;
55-
71+
5672 CUDA. @sync begin @cuda threads= 32 blocks= div (n, 32 ) frustumCulling (
57- ts, tps, cov3ds, means, μ′, fx, fy,
58- quaternions, scales, T, P, w, h, cx, cy,
59- cov2ds, far, near
60- )
73+ ts, tps, μ′, # outs
74+ means, T, P, # ins
75+ w, h, cx, cy, # Numbers
76+ )
6177 end
62-
78+
6379 CUDA. @sync begin @cuda threads= 32 blocks= div (n, 32 ) tValues (
6480 ts, cov3ds, fx, fy,
6581 quaternions, scales, cov2ds
82+ )
83+ end
84+
85+ CUDA. unsafe_free! (ts)
86+
87+ # renderer.positions = μ′
88+ # TODO this is temporary hack
89+ # CUDA.@sync begin
90+ # @cuda threads=32 blocks=div(n, 32) computeCov2d_kernel(cov2ds, rots, scales)
91+ # end
92+
93+ CUDA. @sync begin
94+ @cuda threads= 32 blocks= div (n, 32 ) computeInvCov2d (
95+ cov2ds,
96+ invCov2ds
6697 )
6798 end
68-
69- # renderer.positions = μ′
99+
100+ CUDA. @sync begin
101+ @cuda threads= 32 blocks= div (n, 32 ) computeBB (
102+ cov2ds,
103+ bbs,
104+ μ′,
105+ (w, h)
106+ )
107+ end
108+
109+ sortIdxs = CUDA. sortperm (tps[3 , :], lt= ! isless)
70110 renderer. camera = camera
71- sortIdxs = CUDA. sortperm (tps[3 , :], lt= isless) # chck
72111 renderer. sortIdxs = sortIdxs
73- CUDA. unsafe_free! (ts)
74- # CUDA.unsafe_free!(tps)
75112 renderer. cov2ds = cov2ds[:, :, sortIdxs]
76113 renderer. positions = μ′[:, sortIdxs]
77- # TODO this is temporary hack
78- # CUDA.@sync begin @cuda threads=32 blocks=div(n, 32) computeCov2d_kernel(cov2ds, rots, scales) end
79- CUDA. @sync begin @cuda threads= 32 blocks= div (n, 32 ) computeInvCov2d (renderer. cov2ds, invCov2ds) end
80- CUDA. @sync begin @cuda threads= 32 blocks= div (n, 32 ) computeBB (renderer. cov2ds, bbs, renderer. positions, (w, h)) end
114+ renderer. invCov2ds = invCov2ds[:, :, sortIdxs]
115+ renderer. bbs = bbs[:, :, sortIdxs]
81116 return tps
82117end
83118
119+ """
120+ compactIndex(renderer::Renderer)
84121
85- function packedTileIds (renderer)
86- bbs = renderer. bbs
87- packedIds = CUDA. zeros (UInt64, nGaussians)
88- CUDA. @sync begin
89- @cuda threads= 32 blocks= div (nGaussians, 32 ) binPacking (bbs, packedIds, threads... , blocks... )
90- end
91- end
92-
93-
122+ This function compute compact indexes.
123+ """
94124function compactIdxs (renderer)
95125 bbs = renderer. bbs
96126 hits = CUDA. zeros (UInt8, blocks... , renderer. nGaussians);
97127 n = renderer. nGaussians
128+
98129 CUDA. @sync begin
99130 @cuda threads= 32 blocks= div (n, 32 ) hitBinning (hits, bbs, threads... , blocks... )
100131 end
101132
102133 # This is not memory efficient but works for small list of gaussians in tile ...
134+
103135 hitScans = CUDA. zeros (UInt16, size (hits));
104136 CUDA. @sync CUDA. scan! (+ , hitScans, hits; dims= 3 );
105137 CUDA. @sync maxHits = CUDA. maximum (hitScans) |> Int
106138
107139 # TODO hardcoding UInt16 will cause issues if number of gaussians in a Tile
108140 # if maxHits < typemax(UInt32)
109- maxBinSize = min ((typemax (UInt16) |> Int), nextpow (2 , maxHits))# TODO limiting maxBinSize hardcoded to 4096
141+ # TODO limiting maxBinSize hardcoded to 4096
142+
143+ maxBinSize = min ((typemax (UInt16) |> Int), nextpow (2 , maxHits))
110144 renderer. hitIdxs = CUDA. zeros (UInt32, blocks... , maxBinSize);
145+
111146 # else
112147 # maxBinSize = 2*nextpow(2, maxHits)
113148 # renderer.hitIdxs = CUDA.zeros(UInt32, blocks..., maxBinSize);
114149 # end
115150
116151 CUDA. @sync begin
117- @cuda threads= blocks blocks= (32 , div (n, 32 )) shmem= reduce (* , blocks)* sizeof (UInt32) compactHits (
118- hits,
119- bbs,
120- hitScans,
121- renderer. hitIdxs
152+ @cuda (
153+ threads= blocks,
154+ blocks= (32 , div (n, 32 )),
155+ shmem= reduce (* , blocks)* sizeof (UInt32),
156+ compactHits (
157+ hits,
158+ bbs,
159+ hitScans,
160+ renderer. hitIdxs
161+ )
122162 )
123163 end
164+
124165 CUDA. unsafe_free! (hits)
125166 CUDA. unsafe_free! (hitScans)
126167 return nothing
127168end
128169
129170function forward (renderer, tps)
171+ sortIdxs = renderer. sortIdxs
172+ tps = tps[:, sortIdxs]
130173 cimage = renderer. imageData
131174 invCov2ds = renderer. invCov2ds
132- sortIdxs = renderer. sortIdxs
133175 transmittance = renderer. transmittance
134176 positions = renderer. positions
135177 bbs = renderer. bbs
@@ -141,18 +183,23 @@ function forward(renderer, tps)
141183 eye = renderer. camera. eye .| > Float32 |> gpu
142184 lookAt = renderer. camera. lookAt .| > Float32 |> gpu
143185 CUDA. @sync begin
144- @cuda threads= threads blocks= blocks shmem= (4 * (reduce (* , threads))* sizeof (Float32)) splatDraw (
145- cimage,
146- transmittance,
147- positions,
148- tps,
149- bbs,
150- invCov2ds,
151- hitIdxs,
152- opacities,
153- shs,
154- eye,
155- lookAt
186+ @cuda (
187+ threads= threads,
188+ blocks= blocks,
189+ shmem= (4 * (reduce (* , threads))* sizeof (Float32)),
190+ splatDraw (
191+ cimage,
192+ transmittance,
193+ positions,
194+ tps,
195+ bbs,
196+ invCov2ds,
197+ hitIdxs,
198+ opacities,
199+ shs,
200+ eye,
201+ lookAt
202+ )
156203 )
157204 end
158205 return nothing
0 commit comments