Skip to content

Commit 5399411

Browse files
committed
small fixes and formatting
1 parent e8b79aa commit 5399411

File tree

5 files changed

+139
-95
lines changed

5 files changed

+139
-95
lines changed

src/boundingbox.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
1+
# Compute Bounding Boxes
22
function computeBB(cov2ds, bbs, means, sz)
33
idx = (blockIdx().x - 1i32)*blockDim().x + threadIdx().x
44
BB = MArray{Tuple{2, 2}, Float32}(undef)

src/camera.jl

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,15 @@ mutable struct Camera
2222
end
2323

2424
function defaultCamera(;id=0)
25-
eye = [0.0, 0.0, 30.0] .|> Float32
25+
eye = [0.0, 0.0, 35.0] .|> Float32
2626
lookat = [0, 0, 0] .|> Float32
2727
up = [0, 1, 0] .|> Float32
2828
scale = [1, 1, 1] .|> Float32
2929
fx = 3200.0f0
3030
fy = 3200.0f0
3131
aspectRatio = 1.0 |> Float32
32-
nearPlane = 0.1 |> Float32
33-
farPlane = 100.0 |> Float32
32+
nearPlane = -10.0 |> Float32
33+
farPlane = -100.0 |> Float32
3434
return Camera(
3535
fx,
3636
fy,
@@ -94,6 +94,7 @@ function computeTransform(camera::Camera)
9494
v = cross(w, u)
9595
m = MMatrix{4, 4, Float32}(I)
9696
m[1:3, 1:3] .= (cat([u, v, w]..., dims=2) |> adjoint .|> Float32 |> collect)
97+
m[4, 4] = 0.0
9798
m = SMatrix(m)
9899
return LinearMap(m) translateCamera(camera)
99100
end
@@ -128,9 +129,9 @@ function getCamera(path, idx)
128129
id = camera["id"]
129130
up = [0, 1, 0] .|> Float32
130131
eye = -(rotation |> adjoint)*position
131-
lookAt = (rotation |> adjoint)*[0.0f0, 0.0f0, -1.0f0]
132-
near = 0.1f0 # TODO hardcoded
133-
far = 100.0f0 # TODO hardcoded
132+
lookAt = -(rotation |> adjoint)*[0.0f0, 0.0f0, 1.0f0]
133+
near = -1.0f0 # TODO hardcoded
134+
far = -100.0f0 # TODO hardcoded
134135
scale = [1, 1, 1] .|> Float32
135136
aspectRatio=1.0f0
136137
data = imgName

src/forward.jl

Lines changed: 97 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -15,9 +15,20 @@ function preprocess(renderer::GaussianRenderer2D)
1515
scales = renderer.splatData.scales;
1616
n = renderer.nGaussians
1717
bbs = renderer.bbs
18-
CUDA.@sync begin @cuda threads=32 blocks=div(n, 32) computeCov2d_kernel(cov2ds, rots, scales) end
19-
CUDA.@sync begin @cuda threads=32 blocks=div(n, 32) computeInvCov2d(cov2ds, invCov2ds) end
20-
CUDA.@sync begin @cuda threads=32 blocks=div(n, 32) computeBB(cov2ds, bbs, means, size(renderer.imageData)[1:end-1]) end
18+
CUDA.@sync begin
19+
@cuda threads=32 blocks=div(n, 32) computeCov2d_kernel(cov2ds, rots, scales)
20+
end
21+
CUDA.@sync begin
22+
@cuda threads=32 blocks=div(n, 32) computeInvCov2d(cov2ds, invCov2ds)
23+
end
24+
CUDA.@sync begin
25+
@cuda threads=32 blocks=div(n, 32) computeBB(
26+
cov2ds,
27+
bbs,
28+
means,
29+
size(renderer.imageData)[1:end-1]
30+
)
31+
end
2132
return nothing
2233
end
2334

@@ -30,17 +41,22 @@ function preprocess(renderer::GaussianRenderer3D)
3041
(w, h) = size(renderer.imageData)[1:2];
3142
# Camera related params
3243
camerasPath = joinpath(
33-
ENV["HOMEPATH"], "Downloads", "GaussianSplatting", "GaussianSplatting", "bonsai", "cameras.json"
44+
ENV["HOMEPATH"],
45+
"Downloads",
46+
"GaussianSplatting",
47+
"GaussianSplatting",
48+
"bonsai",
49+
"cameras.json"
3450
) # TODO this is hardcoded
35-
camIdx = 2
36-
#camera = getCamera(camerasPath, camIdx) # defaultCamera();
51+
camIdx = 1
52+
# camera = getCamera(camerasPath, camIdx) # defaultCamera();
3753
camera = defaultCamera();
3854
near = camera.near
3955
far = camera.far
4056
T = computeTransform(camera).linear |> gpu;
4157
P = computeProjection(camera, w, h).linear |> gpu;
4258
cx = w/2.0
43-
cy = w/2.0
59+
cy = h/2.0
4460
n = renderer.nGaussians
4561
fx = camera.fx |> Float32
4662
fy = camera.fy |> Float32
@@ -52,84 +68,110 @@ function preprocess(renderer::GaussianRenderer3D)
5268
quaternions = renderer.splatData.quaternions |> gpu;
5369
scales = renderer.splatData.scales |> gpu;
5470
n = renderer.nGaussians;
55-
71+
5672
CUDA.@sync begin @cuda threads=32 blocks=div(n, 32) frustumCulling(
57-
ts, tps, cov3ds, means, μ′, fx, fy,
58-
quaternions, scales, T, P, w, h, cx, cy,
59-
cov2ds, far, near
60-
)
73+
ts, tps, μ′, # outs
74+
means, T, P, # ins
75+
w, h, cx, cy, # Numbers
76+
)
6177
end
62-
78+
6379
CUDA.@sync begin @cuda threads=32 blocks=div(n, 32) tValues(
6480
ts, cov3ds, fx, fy,
6581
quaternions, scales, cov2ds
82+
)
83+
end
84+
85+
CUDA.unsafe_free!(ts)
86+
87+
# renderer.positions = μ′
88+
# TODO this is temporary hack
89+
# CUDA.@sync begin
90+
# @cuda threads=32 blocks=div(n, 32) computeCov2d_kernel(cov2ds, rots, scales)
91+
# end
92+
93+
CUDA.@sync begin
94+
@cuda threads=32 blocks=div(n, 32) computeInvCov2d(
95+
cov2ds,
96+
invCov2ds
6697
)
6798
end
68-
69-
#renderer.positions = μ′
99+
100+
CUDA.@sync begin
101+
@cuda threads=32 blocks=div(n, 32) computeBB(
102+
cov2ds,
103+
bbs,
104+
μ′,
105+
(w, h)
106+
)
107+
end
108+
109+
sortIdxs = CUDA.sortperm(tps[3, :], lt=!isless)
70110
renderer.camera = camera
71-
sortIdxs = CUDA.sortperm(tps[3, :], lt=isless) # chck
72111
renderer.sortIdxs = sortIdxs
73-
CUDA.unsafe_free!(ts)
74-
#CUDA.unsafe_free!(tps)
75112
renderer.cov2ds = cov2ds[:, :, sortIdxs]
76113
renderer.positions = μ′[:, sortIdxs]
77-
# TODO this is temporary hack
78-
#CUDA.@sync begin @cuda threads=32 blocks=div(n, 32) computeCov2d_kernel(cov2ds, rots, scales) end
79-
CUDA.@sync begin @cuda threads=32 blocks=div(n, 32) computeInvCov2d(renderer.cov2ds, invCov2ds) end
80-
CUDA.@sync begin @cuda threads=32 blocks=div(n, 32) computeBB(renderer.cov2ds, bbs, renderer.positions, (w, h)) end
114+
renderer.invCov2ds = invCov2ds[:, :, sortIdxs]
115+
renderer.bbs = bbs[:, :, sortIdxs]
81116
return tps
82117
end
83118

119+
"""
120+
compactIndex(renderer::Renderer)
84121
85-
function packedTileIds(renderer)
86-
bbs = renderer.bbs
87-
packedIds = CUDA.zeros(UInt64, nGaussians)
88-
CUDA.@sync begin
89-
@cuda threads=32 blocks=div(nGaussians, 32) binPacking(bbs, packedIds, threads..., blocks...)
90-
end
91-
end
92-
93-
122+
This function compute compact indexes.
123+
"""
94124
function compactIdxs(renderer)
95125
bbs = renderer.bbs
96126
hits = CUDA.zeros(UInt8, blocks..., renderer.nGaussians);
97127
n = renderer.nGaussians
128+
98129
CUDA.@sync begin
99130
@cuda threads=32 blocks=div(n, 32) hitBinning(hits, bbs, threads..., blocks...)
100131
end
101132

102133
# This is not memory efficient but works for small list of gaussians in tile ...
134+
103135
hitScans = CUDA.zeros(UInt16, size(hits));
104136
CUDA.@sync CUDA.scan!(+, hitScans, hits; dims=3);
105137
CUDA.@sync maxHits = CUDA.maximum(hitScans) |> Int
106138

107139
# TODO hardcoding UInt16 will cause issues if number of gaussians in a Tile
108140
# if maxHits < typemax(UInt32)
109-
maxBinSize = min((typemax(UInt16) |> Int), nextpow(2, maxHits))# TODO limiting maxBinSize hardcoded to 4096
141+
# TODO limiting maxBinSize hardcoded to 4096
142+
143+
maxBinSize = min((typemax(UInt16) |> Int), nextpow(2, maxHits))
110144
renderer.hitIdxs = CUDA.zeros(UInt32, blocks..., maxBinSize);
145+
111146
# else
112147
# maxBinSize = 2*nextpow(2, maxHits)
113148
# renderer.hitIdxs = CUDA.zeros(UInt32, blocks..., maxBinSize);
114149
# end
115150

116151
CUDA.@sync begin
117-
@cuda threads=blocks blocks=(32, div(n, 32)) shmem=reduce(*, blocks)*sizeof(UInt32) compactHits(
118-
hits,
119-
bbs,
120-
hitScans,
121-
renderer.hitIdxs
152+
@cuda(
153+
threads=blocks,
154+
blocks=(32, div(n, 32)),
155+
shmem=reduce(*, blocks)*sizeof(UInt32),
156+
compactHits(
157+
hits,
158+
bbs,
159+
hitScans,
160+
renderer.hitIdxs
161+
)
122162
)
123163
end
164+
124165
CUDA.unsafe_free!(hits)
125166
CUDA.unsafe_free!(hitScans)
126167
return nothing
127168
end
128169

129170
function forward(renderer, tps)
171+
sortIdxs = renderer.sortIdxs
172+
tps = tps[:, sortIdxs]
130173
cimage = renderer.imageData
131174
invCov2ds = renderer.invCov2ds
132-
sortIdxs = renderer.sortIdxs
133175
transmittance = renderer.transmittance
134176
positions = renderer.positions
135177
bbs = renderer.bbs
@@ -141,18 +183,23 @@ function forward(renderer, tps)
141183
eye = renderer.camera.eye .|> Float32 |>gpu
142184
lookAt = renderer.camera.lookAt .|> Float32 |> gpu
143185
CUDA.@sync begin
144-
@cuda threads=threads blocks=blocks shmem=(4*(reduce(*, threads))*sizeof(Float32)) splatDraw(
145-
cimage,
146-
transmittance,
147-
positions,
148-
tps,
149-
bbs,
150-
invCov2ds,
151-
hitIdxs,
152-
opacities,
153-
shs,
154-
eye,
155-
lookAt
186+
@cuda(
187+
threads=threads,
188+
blocks=blocks,
189+
shmem=(4*(reduce(*, threads))*sizeof(Float32)),
190+
splatDraw(
191+
cimage,
192+
transmittance,
193+
positions,
194+
tps,
195+
bbs,
196+
invCov2ds,
197+
hitIdxs,
198+
opacities,
199+
shs,
200+
eye,
201+
lookAt
202+
)
156203
)
157204
end
158205
return nothing

src/main.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ imSize = (512, 512, 3)
1717
# renderer = getRenderer(GAUSSIAN_2D, imSize, nGaussians, threads, blocks)
1818
renderer = getRenderer(
1919
GAUSSIAN_3D,
20-
joinpath(ENV["HOMEPATH"], "Downloads", "GaussianSplatting", "GaussianSplatting", "bonsai", "bonsai_30000.ply"),
20+
joinpath(ENV["HOMEPATH"], "Downloads", "GaussianSplatting", "GaussianSplatting", "train", "train_30000.ply"),
2121
imSize,
2222
threads,
2323
blocks;
@@ -26,9 +26,9 @@ renderer = getRenderer(
2626
GC.gc()
2727
CUDA.reclaim()
2828

29-
ts = preprocess(renderer)
29+
tps = preprocess(renderer)
3030
compactIdxs(renderer)
31-
forward(renderer, ts[:, renderer.sortIdxs])
31+
forward(renderer, tps)
3232
renderer.imageData[findall((x) -> isequal(x, NaN), renderer.imageData)] .= 0.0f0
3333
img = renderer.imageData |> cpu;
3434
tmpimageview = reshape(renderer.imageData, size(renderer.imageData)..., 1)
@@ -41,10 +41,10 @@ yimg = colorview(RGB{N0f8},
4141
yimg = Images.imrotate(yimg, -pi/2)
4242
imshow(yimg)
4343

44-
include("train.jl")
44+
# include("train.jl")
4545

46-
windowSize = 11
47-
nChannels = 3
48-
lossFunc = getLossFunction(imSize, windowSize, nChannels)
46+
# windowSize = 11
47+
# nChannels = 3
48+
# lossFunc = getLossFunction(imSize, windowSize, nChannels)
4949

50-
#train(renderer, gtimg, 1e-5, lossFunc)
50+
# train(renderer, gtimg, 1e-5, lossFunc)

0 commit comments

Comments
 (0)