Skip to content

Commit 628e275

Browse files
committed
snapshot
1 parent aa3c10e commit 628e275

File tree

6 files changed

+144
-103
lines changed

6 files changed

+144
-103
lines changed

src/binning.jl

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,33 @@ function hitBinning(hits, bbs, blockSizeX, blockSizeY, gridSizeX, gridSizeY)
2929
return
3030
end
3131

32+
# Placeholder: the body is empty, so calling this currently returns `nothing`.
# NOTE(review): presumably meant to place a tile id into a packed UInt64 key — the
# bit layout is not defined anywhere visible here; TODO implement.
function packTileId(x::UInt64)
33+
34+
end
35+
36+
# Placeholder: the body is empty, so calling this currently returns `nothing`.
# NOTE(review): presumably meant to place a depth (z) value into a packed UInt64 key —
# TODO implement and confirm the intended encoding.
function packZValue(x::UInt64)
37+
38+
end
39+
40+
# Placeholder: the body is empty, so calling this currently returns `nothing`.
# NOTE(review): presumably the inverse of `packTileId` — TODO implement once the
# packed layout is decided.
function unpackTileId(x::UInt64)
3241

33-
function bbTileHit(bbs, )
42+
end
3443

44+
# Placeholder: the body is empty, so calling this currently returns `nothing`.
# NOTE(review): presumably the inverse of `packZValue` — TODO implement once the
# packed layout is decided.
function unpackZValue(x::UInt64)
45+
46+
end
47+
48+
"""
    binPacking(bbs, packedIds, blockSizeX, blockSizeY, gridSizeX, gridSizeY)

Device-side routine indexed by `blockIdx`/`blockDim`/`threadIdx`: each thread reads one
bounding box `bbs[:, :, idx]`, rounds its extents outward (`floor` for the minimum,
`ceil` for the maximum) and converts them to tile indices by dividing by the tile size
and adding one.

The tile indices are only computed, not stored: `packedIds`, `gridSizeX` and
`gridSizeY` are not used yet.
"""
function binPacking(bbs, packedIds, blockSizeX, blockSizeY, gridSizeX, gridSizeY)
    gIdx = (blockIdx().x - 1i32)*blockDim().x + threadIdx().x

    # Bounding-box extents rounded outward to whole numbers.
    xLo = floor(bbs[1, 1, gIdx])
    xHi = ceil(bbs[1, 2, gIdx])
    yLo = floor(bbs[2, 1, gIdx])
    yHi = ceil(bbs[2, 2, gIdx])

    tileW = float32(blockSizeX)
    tileH = float32(blockSizeY)

    # Tile index range touched by the bounding box (the `+ 1i32` shifts to 1-based).
    tileXMin = Int32(div(xLo, tileW)) + 1i32
    tileYMin = Int32(div(yLo, tileH)) + 1i32
    tileXMax = Int32(div(xHi, tileW)) + 1i32
    tileYMax = Int32(div(yHi, tileH)) + 1i32

    # TODO: bounding-box cover — write the covered tiles into `packedIds`.
    return sync_threads()
end

src/camera.jl

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,7 +87,7 @@ function computeTransform(camera::Camera)
8787
eye = camera.eye
8888
lookat = camera.lookat
8989
up = camera.up
90-
w = -(lookat .- eye) |> normalize
90+
w = (lookat .- eye) |> normalize
9191
u = cross(up, w) |> normalize
9292
v = cross(w, u)
9393
m = MMatrix{4, 4, Float32}(I)
@@ -96,6 +96,13 @@ function computeTransform(camera::Camera)
9696
return LinearMap(m) translateCamera(camera)
9797
end
9898

99+
"""
    computeTransform(camera::GroundTruthCamera)

Assemble a 4×4 `Float32` matrix from the camera pose and wrap it in a `LinearMap`.

Starting from the identity, the upper-left 3×3 block is set to `camera.rotation` and
rows 1–3 of the fourth column are set to `-camera.position`.
"""
function computeTransform(camera::GroundTruthCamera)
    pose = MMatrix{4, 4, Float32}(I)
    # NOTE(review): the translation is the negated position as-is (not rotated) —
    # confirm this matches the convention of the stored rotation.
    pose[1:3, 4] .= -camera.position
    pose[1:3, 1:3] .= camera.rotation
    return LinearMap(pose)
end
105+
99106
function computeProjection(camera::Camera, w, h)
100107
p = MArray{Tuple{4, 4}, Float32}(undef)
101108
p .= 0.0f0
@@ -107,6 +114,17 @@ function computeProjection(camera::Camera, w, h)
107114
return LinearMap(p)
108115
end
109116

117+
"""
    computeProjection(camera::GroundTruthCamera, near, far)

Build a 4×4 `Float32` projection matrix from the camera's `fx`, `fy`, `width` and
`height` fields together with the `near` and `far` values, and wrap it in a
`LinearMap`. All entries not listed below are zero.
"""
function computeProjection(camera::GroundTruthCamera, near, far)
    proj = MArray{Tuple{4, 4}, Float32}(undef)
    fill!(proj, 0.0f0)

    depthRange = far - near

    # Focal lengths scaled by the image extent.
    proj[1, 1] = 2.0f0*(camera.fx)/camera.width
    proj[2, 2] = 2.0f0*(camera.fy)/camera.height

    # Depth terms built from the near/far values.
    proj[3, 3] = (far + near)/depthRange
    proj[3, 4] = -2.0f0*(far*near)/depthRange

    # The fourth output component takes the third input component.
    proj[4, 3] = 1
    return LinearMap(proj)
end
127+
110128
function loadCameras(path)
111129
cameras = JSON.parsefile(path)
112130
return cameras

src/forward.jl

Lines changed: 57 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -22,40 +22,69 @@ function preprocess(renderer::GaussianRenderer2D)
2222
end
2323

2424
"""
    preprocess(renderer::GaussianRenderer3D; camerasPath, camIdx=1, near=0.1f0, far=10.0f0)

Project the renderer's gaussians through a camera loaded with `getCamera` and refresh
the renderer's screen-space state.

Steps:
1. Allocate scratch buffers `ts`, `tps` (4 × n) and `μ′` (2 × n).
2. Launch `frustumCulling`, then `tValues`.
3. Order the gaussians by the third row of `tps` and store the permuted `cov2ds` and
   positions in `renderer.cov2ds` / `renderer.positions`.
4. Launch `computeInvCov2d` and `computeBB` on the permuted data, writing into
   `renderer.invCov2ds` and `renderer.bbs`.

# Keywords
- `camerasPath`: camera JSON file; defaults to `assets/bonsai/cameras.json` under
  `pkgdir(WGPUgfx)`.
- `camIdx`: index passed to `getCamera`.
- `near`, `far`: values passed to `computeProjection` and `frustumCulling`.
"""
function preprocess(
    renderer::GaussianRenderer3D;
    camerasPath=joinpath(pkgdir(WGPUgfx), "assets", "bonsai", "cameras.json"),
    camIdx=1,
    near=0.1f0,
    far=10.0f0,
)
    n = renderer.nGaussians
    # NOTE(review): `div(n, 32)` leaves the last `n % 32` gaussians unprocessed; the
    # kernels launched here do not appear to bounds-check `idx`, so rounding the block
    # count up is only safe once they do.
    nBlocks = div(n, 32)

    # Scratch buffers for the transformed / projected means.
    # TODO avoid dynamic memory allocations
    ts = CUDA.zeros(4, n)
    tps = CUDA.zeros(4, n)
    μ′ = CUDA.zeros(2, n)

    # Camera related params
    camera = getCamera(camerasPath, camIdx)
    T = computeTransform(camera).linear |> MArray |> gpu
    P = computeProjection(camera, near, far).linear |> gpu
    w = camera.width
    h = camera.height
    cx = div(w, 2)
    cy = div(h, 2)
    fx = camera.fx
    fy = camera.fy

    means = renderer.splatData.means |> gpu
    quaternions = renderer.splatData.quaternions |> gpu
    scales = renderer.splatData.scales |> gpu
    cov2ds = renderer.cov2ds
    cov3ds = renderer.cov3ds
    bbs = renderer.bbs
    invCov2ds = renderer.invCov2ds

    CUDA.@sync begin
        @cuda threads=32 blocks=nBlocks frustumCulling(
            ts, tps, cov3ds, means, μ′, fx, fy,
            quaternions, scales, T, P, w, h, cx, cy,
            cov2ds, far, near
        )
    end

    CUDA.@sync begin
        @cuda threads=32 blocks=nBlocks tValues(
            ts, cov3ds, fx, fy,
            quaternions, scales, cov2ds
        )
    end

    # Order gaussians by the third row of `tps` (presumably depth — confirm).
    sortIdxs = CUDA.sortperm(tps[3, :])
    CUDA.unsafe_free!(ts)
    CUDA.unsafe_free!(tps)

    sortedCov2ds = cov2ds[:, :, sortIdxs]
    sortedPositions = μ′[:, sortIdxs]
    renderer.cov2ds = sortedCov2ds
    renderer.positions = sortedPositions

    # The remaining kernels must see covariances in the same (sorted) order as the
    # positions; feeding the unsorted `cov2ds` here would pair each centre with another
    # gaussian's covariance.
    # NOTE(review): any other per-gaussian data consumed alongside these (e.g. colors,
    # opacities) has to be permuted with the same `sortIdxs` — verify at the call sites.
    CUDA.@sync begin
        @cuda threads=32 blocks=nBlocks computeInvCov2d(sortedCov2ds, invCov2ds)
    end
    CUDA.@sync begin
        @cuda threads=32 blocks=nBlocks computeBB(sortedCov2ds, bbs, sortedPositions, (w, h))
    end
end
80+
81+
82+
"""
    packedTileIds(renderer)

Allocate one `UInt64` slot per gaussian on the GPU, launch the `binPacking` kernel over
the renderer's bounding boxes, and return the buffer.

The kernel is called as `binPacking(bbs, packedIds, blockSizeX, blockSizeY, gridSizeX,
gridSizeY)`, so the bounding boxes are passed first, followed by the output buffer.
"""
function packedTileIds(renderer)
    n = renderer.nGaussians
    bbs = renderer.bbs
    packedIds = CUDA.zeros(UInt64, n)
    # NOTE(review): `threads` and `blocks` (tile size and grid size tuples) are not
    # defined inside this function — they must come from an enclosing scope; confirm
    # or pass them in explicitly.
    CUDA.@sync begin
        @cuda threads=32 blocks=div(n, 32) binPacking(bbs, packedIds, threads..., blocks...)
    end
    return packedIds
end
6089

6190
function compactIdxs(renderer)
@@ -65,11 +94,21 @@ function compactIdxs(renderer)
6594
CUDA.@sync begin
6695
@cuda threads=32 blocks=div(n, 32) hitBinning(hits, bbs, threads..., blocks...)
6796
end
68-
hitScans = CUDA.zeros(UInt16, size(hits));
97+
98+
# This is not memory efficient but works for small list of gaussians in tile ...
99+
# hitScans = CUDA.zeros(UInt16, size(hits));
69100
CUDA.@sync CUDA.scan!(+, hitScans, hits; dims=3);
70101
CUDA.@sync maxHits = CUDA.maximum(hitScans) |> Int
71-
maxBinSize = min(typemax(UInt16) |> Int, nextpow(2, maxHits))# TODO limiting maxBinSize hardcoded to 4096
72-
renderer.hitIdxs = CUDA.zeros(UInt32, blocks..., maxBinSize);
102+
103+
# TODO hardcoding UInt16 will cause issues if the number of gaussians in a tile exceeds typemax(UInt16)
104+
if maxHits < typemax(UInt16)
105+
maxBinSize = min((typemax(UInt16) |> Int), nextpow(2, maxHits))# TODO limiting maxBinSize hardcoded to 4096
106+
renderer.hitIdxs = CUDA.zeros(UInt32, blocks..., maxBinSize);
107+
else
108+
maxBinSize = 2*nextpow(2, maxHits)
109+
renderer.hitIdxs = CUDA.zeros(UInt32, blocks..., maxBinSize);
110+
end
111+
73112
CUDA.@sync begin
74113
@cuda threads=blocks blocks=(32, div(n, 32)) shmem=reduce(*, blocks)*sizeof(UInt32) compactHits(
75114
hits,

src/main.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ include("cov2d.jl")
33
include("boundingbox.jl")
44
include("binning.jl")
55
include("compact.jl")
6+
67
include("camera.jl")
78
include("renderer.jl")
89
include("projection.jl")
@@ -38,9 +39,6 @@ yimg = colorview(RGB{N0f8},
3839
)
3940
yimg = Images.imrotate(yimg, -pi/2)
4041

41-
42-
43-
4442
include("train.jl")
4543

4644
windowSize = 11

src/projection.jl

Lines changed: 30 additions & 74 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,9 @@ function computeCov3dProjection_kernel(cov2ds, cov3ds, rotation, affineTransform
99
S[i, j] = 0.0f0
1010
end
1111
end
12-
S[1, 1] = scales[1, idx]
13-
S[2, 2] = scales[2, idx]
14-
S[3, 3] = scales[3, idx]
12+
S[1, 1] = exp(scales[1, idx])
13+
S[2, 2] = exp(scales[2, idx])
14+
S[3, 3] = exp(scales[3, idx])
1515
W = R*S
1616
J = W*adjoint(W)
1717
for i in 1:3
@@ -40,76 +40,10 @@ end
4040
return R
4141
end
4242

43-
function tValues(ts, tps, meansList, μ′, T, P, w, h, cx, cy)
44-
idx = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
45-
46-
meanVec = MVector{4, Float32}(undef)
47-
meanVec[1] = meansList[1, idx]
48-
meanVec[2] = meansList[2, idx]
49-
meanVec[3] = meansList[3, idx]
50-
meanVec[4] = 1
51-
52-
Tcw = MArray{Tuple{4, 4}, Float32}(undef)
53-
for ii in 1:4
54-
for jj in 1:4
55-
Tcw[ii, jj] = T[ii, jj]
56-
end
57-
end
58-
59-
tstmp = Tcw*meanVec
60-
ts[1, idx] = tstmp[1]
61-
ts[2, idx] = tstmp[2]
62-
ts[3, idx] = tstmp[3]
63-
ts[4, idx] = tstmp[4]
64-
65-
Ptmp = MArray{Tuple{4, 4}, Float32}(undef)
66-
for ii in 1:4
67-
for jj in 1:4
68-
Ptmp[ii, jj] = P[ii, jj]
69-
end
70-
end
71-
72-
tpstmp = Ptmp*tstmp
73-
tps[1, idx] = tpstmp[1]
74-
tps[2, idx] = tpstmp[2]
75-
tps[3, idx] = tpstmp[3]
76-
tps[4, idx] = tpstmp[4]
77-
78-
tx = tpstmp[1]
79-
ty = tpstmp[2]
80-
tz = tpstmp[3]
81-
tw = tpstmp[4]
82-
83-
μ′[1, idx] = (w*tx/tw) + cx
84-
μ′[2, idx] = (w*ty/tw) + cy
85-
86-
quat = quaternions[1, idx]
87-
@inline R = quatToRot(quat)
88-
S = MArray{Tuple{3, 3}, Float32}(undef)
89-
for i in 1:3
90-
for j in 1:3
91-
S[i, j] = 0.0f0
92-
end
93-
end
94-
S[1, 1] = scales[1, idx]
95-
S[2, 2] = scales[2, idx]
96-
S[3, 3] = scales[3, idx]
97-
W = R*S
98-
J = W*adjoint(W)
99-
for i in 1:3
100-
for j in 1:3
101-
cov3ds[i, j, idx] = J[i, j]
102-
end
103-
end
104-
105-
return nothing
106-
end
107-
108-
109-
function tValues(
43+
function frustumCulling(
11044
ts, tps, cov3ds, meansList, μ′, fx, fy,
11145
quaternions, scales, T, P, w, h, cx, cy,
112-
cov2ds
46+
cov2ds, far, near
11347
)
11448
idx = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
11549

@@ -155,9 +89,31 @@ function tValues(
15589
tz′ = tpstmp[3]
15690
tw′ = tpstmp[4]
15791

158-
μ′[1, idx] = ((w*tx′/tw′) + 1)/2 + cx
159-
μ′[2, idx] = ((h*ty′/tw′) + 1)/2 + cy
92+
# Screen-space centre of this gaussian after the perspective divide.
x = ((w*tx′/tw′) + 1)/2 + cx
y = ((h*ty′/tw′) + 1)/2 + cy

# Keep the centre only when it lies strictly inside (-w, w) × (-h, h).
if (-w < x < w) && (-h < y < h)# && (near < tz′ < far)
    μ′[1, idx] = x
    μ′[2, idx] = y
else
    # TODO zero values are used for culling checks for now
    # Both coordinates are reset so a culled gaussian is marked even when the output
    # buffer is reused and its second row is not already zero.
    μ′[1, idx] = 0.0f0
    μ′[2, idx] = 0.0f0
end
103+
return nothing
104+
end
105+
106+
107+
function tValues(
108+
ts, cov3ds, fx, fy, quaternions, scales, cov2ds
109+
)
110+
idx = (blockIdx().x - 1i32) * blockDim().x + threadIdx().x
111+
112+
tx = ts[1, idx]
113+
ty = ts[2, idx]
114+
tz = ts[3, idx]
115+
tw = ts[4, idx]
116+
161117
quat = MVector{4, Float32}(undef)
162118
quat[1] = quaternions[1, idx]
163119
quat[2] = quaternions[2, idx]
@@ -195,7 +151,7 @@ function tValues(
195151

196152
for ii in 1:2
197153
for jj in 1:2
198-
cov2ds[ii, jj, idx] = cov2d[ii, jj]
154+
cov2ds[ii, jj, idx] = cov2d[ii, jj] + 0.1
199155
end
200156
end
201157

src/splat.jl

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -220,13 +220,17 @@ function splatDraw(cimage, transGlobal, means, bbs, invCov2ds, hitIdxs, opacitie
220220
deltaY = float(j) - means[2, bIdx]
221221
delta[2] = deltaY
222222
disttmp = invCov2d*delta
223-
dist = disttmp[1]*delta[1] + disttmp[2]*delta[2]
223+
dist = 0.50f0*(disttmp[1]*delta[1] + disttmp[2]*delta[2])
224224
alpha = cusigmoid(opacity*exp(-dist))
225225
transmittance = splatData[txIdx, tyIdx, transIdx]
226-
CUDA.@atomic splatData[txIdx, tyIdx, 1] += (colors[1, bIdx]*alpha*transmittance)
227-
CUDA.@atomic splatData[txIdx, tyIdx, 2] += (colors[2, bIdx]*alpha*transmittance)
228-
CUDA.@atomic splatData[txIdx, tyIdx, 3] += (colors[3, bIdx]*alpha*transmittance)
229-
CUDA.@atomic splatData[txIdx, tyIdx, transIdx] *= (1.0f0 - alpha)
226+
color1 = SH_C0*colors[1, bIdx]
227+
color2 = SH_C0*colors[2, bIdx]
228+
color3 = SH_C0*colors[3, bIdx]
229+
230+
splatData[txIdx, tyIdx, 1] += (color1*alpha*transmittance)
231+
splatData[txIdx, tyIdx, 2] += (color2*alpha*transmittance)
232+
splatData[txIdx, tyIdx, 3] += (color3*alpha*transmittance)
233+
splatData[txIdx, tyIdx, transIdx] *= (1.0f0 - alpha)
230234
end
231235
end
232236
sync_threads()
@@ -312,7 +316,7 @@ function splatGrads(
312316
dist = disttmp[1]*delta[1] + disttmp[2]*delta[2]
313317
ΔMean = invCov2d*delta
314318
ΔΣ = ΔMean*adjoint(ΔMean)
315-
Δo = exp(-dist)
319+
Δo = exp(-0.5f0*dist)
316320
Δσ = -opacity*Δo
317321
transmittance = transData[txIdx, tyIdx]
318322
alpha = opacity*exp(-dist)

0 commit comments

Comments
 (0)