Loading...
Searching...
No Matches
GaussianSplatNode.hpp
1#pragma once
2
3#include <Gfx/Graph/Node.hpp>
4#include <Gfx/Graph/NodeRenderer.hpp>
5#include <Gfx/Graph/RenderList.hpp>
6#include <Gfx/Graph/Utils.hpp>
7
8#include <ossia/detail/pod_vector.hpp>
9
10// clang-format off
11#if defined(near)
12#undef near
13#undef far
14#endif
15// clang-format on
16
17namespace score::gfx
18{
19
// Node-side parameters for Gaussian splat rendering. The class head and the
// constructor declaration are not part of this view.
41{
42public:
44 virtual ~GaussianSplatNode();
45
46 score::gfx::NodeRenderer* createRenderer(RenderList&) const noexcept override;
47 void process(Message&& msg) override;
48
  // Splat-cloud settings. NOTE(review): presumably updated from incoming
  // messages in process() and read by GaussianSplatRenderer -- confirm in
  // GaussianSplatNode.cpp.
49 int splatCount{}; // NOTE(review): int here, while the renderer tracks the count as int64_t
50 float scaleFactor{1.0f}; // presumably forwarded as the preprocess shader's scaleMod
51 bool enableSorting{true}; // presumably drives the vertex shader's useSorting flag
52 uint32_t shDegree{3}; // 0, 1, 2, or 3 (the preprocess shader treats values >= 3 as 3)
53
54 // Model transform
55 ossia::vec3f modelPosition{0.f, 0.f, 0.f};
56 ossia::vec3f modelRotation{0.f, 0.f, 0.f}; // Euler angles in degrees (pitch, yaw, roll)
57 ossia::vec3f modelScale{1.f, 1.f, 1.f};
58
59 // Camera parameters
  // NOTE(review): position/center presumably form a look-at camera (eye and
  // target) and fov is presumably in degrees -- confirm in the .cpp.
60 ossia::vec3f position{-1.f, -1.f, -1.f};
61 ossia::vec3f center{0.f, 0.f, 0.f};
62 float fov{90.f};
  // `near`/`far` can clash with platform macros of the same name; the
  // `#if defined(near)` / `#undef` guard at the top of this header exists so
  // these member names compile. The depth-key sort shader maps depth
  // logarithmically between its near/far uniforms, so these are presumably
  // expected to satisfy 0 < near < far.
63 float near{0.001f};
64 float far{10000.f};
65};
66
// Renderer for GaussianSplatNode. It owns the GPU resources for three stages,
// grouped below:
//   - the SH preprocess compute pass (raw 256-byte splats -> 64-byte splats),
//   - the radix depth-sort compute passes (depth key, histogram, prefix sum,
//     scatter),
//   - the instanced splat draw.
// The class head and one declaration after the constructor are not part of
// this view.
75{
76public:
77 explicit GaussianSplatRenderer(const GaussianSplatNode& node);
79
80 TextureRenderTarget renderTargetForInput(const Port& p) override;
81 void init(RenderList& renderer, QRhiResourceUpdateBatch& res) override;
82 void update(RenderList& renderer, QRhiResourceUpdateBatch& res, Edge* edge) override;
  // NOTE(review): presumably where the preprocess/sort compute passes are
  // recorded before the render pass -- confirm in GaussianSplatNode.cpp.
83 void runInitialPasses(
84 RenderList&, QRhiCommandBuffer& commands, QRhiResourceUpdateBatch*& res,
85 Edge& edge) override;
86 void runRenderPass(RenderList&, QRhiCommandBuffer& cb, Edge& edge) override;
87 void release(RenderList&) override;
88
89private:
90 void createPreprocessPipeline(RenderList& renderer);
91 void createRenderPipeline(RenderList& renderer);
92 void createSortPipelines(RenderList& renderer);
93
94 const GaussianSplatNode& m_node;
95
96 // Input render target
97 TextureRenderTarget m_inputTarget;
98
99 // Render pipeline resources
100 QRhiBuffer* m_uniformBuffer{};
101 QRhiBuffer* m_dummyStorageBuffer{}; // Small buffer for unused bindings
102 QRhiGraphicsPipeline* m_pipeline{};
103 QRhiShaderResourceBindings* m_bindings{};
104
105 // SH preprocessing compute resources
106 // Converts raw 256-byte splats → compact 64-byte rendering splats
107 QRhiBuffer* m_rawSplatBuffer{}; // Input: raw PLY data (256 bytes/splat)
108 QRhiBuffer* m_renderSplatBuffer{}; // Output: compact (64 bytes/splat)
109 QRhiBuffer* m_preprocessUniformBuffer{};
110 QRhiComputePipeline* m_preprocessPipeline{};
111 QRhiShaderResourceBindings* m_preprocessSrb{};
112
113 // Sorting compute resources
114 QRhiBuffer* m_sortKeysBuffer{}; // Sort keys: depth_key_shader packs a 16-bit log-depth bucket (high half) with (index & 0xFFFF) (low half)
115 QRhiBuffer* m_sortKeysAltBuffer{}; // Double buffer for key ping-pong
116 QRhiBuffer* m_sortIndicesBuffer{}; // Sorted indices
117 QRhiBuffer* m_sortIndicesAltBuffer{}; // Double buffer for index ping-pong
118 QRhiBuffer* m_histogramBuffer{}; // Histogram for radix sort
119 QRhiBuffer* m_sortUniformBuffer{}; // Depth key pass uniforms
120 QRhiBuffer* m_sortPassUniformBuffer{}; // Histogram/scatter/prefix uniforms
121 QRhiBuffer* m_prefixSumUniformBuffer{}; // Prefix sum uniforms
122
123 QRhiComputePipeline* m_depthKeyPipeline{};
124 QRhiComputePipeline* m_histogramPipeline{};
125 QRhiComputePipeline* m_prefixSumPipeline{};
126 QRhiComputePipeline* m_sortPipeline{};
127
128 QRhiShaderResourceBindings* m_depthKeySrb{};
129 QRhiShaderResourceBindings* m_histogramSrb{};
130 QRhiShaderResourceBindings* m_histogramSrbAlt{}; // For odd passes
131 QRhiShaderResourceBindings* m_prefixSumSrb{};
132 QRhiShaderResourceBindings* m_sortSrb{};
133 QRhiShaderResourceBindings* m_sortSrbAlt{}; // For ping-pong
134
135 ossia::small_vector<Sampler, 8> m_samplers;
136
137 int64_t m_lastSplatCount{0};
138 bool m_preprocessResourcesCreated{false};
139 bool m_sortResourcesCreated{false};
140
  // NOTE(review): at 256 bytes per raw splat, MAX_SPLATS raw splats need about
  // 12.8 GB, well beyond common storage-buffer size limits -- confirm how the
  // .cpp caps buffer allocation.
141 static constexpr int64_t MAX_SPLATS = 50000000;
  // The GLSL sort shaders hard-code local_size_x = 256 and 256-entry shared
  // histograms, so these constants must stay in sync with those literals.
142 static constexpr int SORT_WORKGROUP_SIZE = 256;
143 static constexpr int RADIX_BITS = 8;
144 static constexpr int NUM_BUCKETS = 256; // 2^RADIX_BITS
145};
146
147// Shader sources
148namespace GaussianSplatShaders
149{
150
151//=============================================================================
152// COMPUTE SHADER: SH PREPROCESSING (raw 256B → compact 64B per splat)
153//=============================================================================
154
// SH preprocessing compute shader (GLSL 450, 256 invocations per workgroup).
//
// Each invocation handles one splat:
//   - reads one raw 64-float (256-byte) record written by the PLY loader,
//   - writes one compact 16-float (64-byte) RenderSplat containing:
//       * a view-dependent colour (spherical harmonics up to `shDegree`,
//         evaluated for the current camera position),
//       * sigmoid-activated opacity,
//       * exp-activated extents multiplied by `scaleMod`,
//       * a normalized quaternion in xyzw order.
//
// Degenerate inputs are guarded so they cannot produce NaNs:
//   - a splat located exactly at the camera,
//   - an all-zero rotation.
static constexpr auto preprocess_shader = R"_(#version 450
layout(local_size_x = 256) in;

// Raw splat: 64 floats = 256 bytes (matches PLY loader output)
// [0..2]   position (x,y,z)
// [3..5]   normal (nx,ny,nz) - unused
// [6..8]   SH DC (f_dc_0, f_dc_1, f_dc_2)
// [9..53]  SH rest (f_rest_0 .. f_rest_44)
// [54]     opacity (pre-sigmoid)
// [55..57] scale (log-space)
// [58..61] rotation (w,x,y,z)
// [62..63] padding

layout(std430, binding = 0) readonly buffer RawSplatBuffer {
  float rawData[]; // 64 floats per splat
};

// Compact rendering splat: 16 floats = 64 bytes
// vec4 position (xyz, 0)
// vec4 scale    (xyz, 0) - already exp'd
// vec4 rotation (x,y,z,w) - normalized
// vec4 color    (r,g,b,a) - SH evaluated, alpha = sigmoid(opacity)

struct RenderSplat {
  vec4 position;
  vec4 scale;
  vec4 rotation;
  vec4 color;
};

layout(std430, binding = 1) writeonly buffer RenderSplatBuffer {
  RenderSplat renderSplats[];
};

layout(std140, binding = 2) uniform Params {
  mat4 view;
  vec3 camPos;    // Camera position in world space
  uint splatCount;
  uint shDegree;  // 0, 1, 2, or 3
  float scaleMod;
  uint _pad0;
  uint _pad1;
};

// Spherical harmonics constants
const float SH_C0 = 0.28209479177387814;

const float SH_C1 = 0.4886025119029199;

const float SH_C2[5] = float[5](
  1.0925484305920792,
  -1.0925484305920792,
  0.31539156525252005,
  -1.0925484305920792,
  0.5462742152960396
);

const float SH_C3[7] = float[7](
  -0.5900435899266435,
  2.890611442640554,
  -0.4570457994644658,
  0.3731763325901154,
  -0.4570457994644658,
  1.445305721320277,
  -0.5900435899266435
);

vec3 evaluateSH(uint base, vec3 dir) {
  // Degree 0
  vec3 result = SH_C0 * vec3(
    rawData[base + 6],
    rawData[base + 7],
    rawData[base + 8]
  );

  if (shDegree < 1) {
    return result + 0.5;
  }

  float x = dir.x, y = dir.y, z = dir.z;

  // f_rest is stored channel-planar, not as interleaved RGB triplets:
  //   f_rest[ 0..14]: SH coefficients 1..15 of channel 0 (R)
  //   f_rest[15..29]: SH coefficients 1..15 of channel 1 (G)
  //   f_rest[30..44]: SH coefficients 1..15 of channel 2 (B)
  // so coefficient k of channel c is at f_rest[c * 15 + (k - 1)].
  uint r = base + 9; // f_rest_0 start

  // Degree 1: 3 coefficients per channel
  // R=[0..2], G=[15..17], B=[30..32]
  result += SH_C1 * (
    - y * vec3(rawData[r+0], rawData[r+15], rawData[r+30])
    + z * vec3(rawData[r+1], rawData[r+16], rawData[r+31])
    - x * vec3(rawData[r+2], rawData[r+17], rawData[r+32])
  );

  if (shDegree < 2) {
    return result + 0.5;
  }

  // Degree 2: 5 coefficients per channel
  // R=[3..7], G=[18..22], B=[33..37]
  float xx = x*x, yy = y*y, zz = z*z, xy = x*y, yz = y*z, xz = x*z;

  result += SH_C2[0] * xy * vec3(rawData[r+3], rawData[r+18], rawData[r+33]);
  result += SH_C2[1] * yz * vec3(rawData[r+4], rawData[r+19], rawData[r+34]);
  result += SH_C2[2] * (2.*zz - xx - yy)
            * vec3(rawData[r+5], rawData[r+20], rawData[r+35]);
  result += SH_C2[3] * xz * vec3(rawData[r+6], rawData[r+21], rawData[r+36]);
  result += SH_C2[4] * (xx - yy) * vec3(rawData[r+7], rawData[r+22], rawData[r+37]);

  if (shDegree < 3) {
    return result + 0.5;
  }

  // Degree 3: 7 coefficients per channel
  // R=[8..14], G=[23..29], B=[38..44]
  result += SH_C3[0] * y*(3.*xx - yy)
            * vec3(rawData[r+8], rawData[r+23], rawData[r+38]);
  result += SH_C3[1] * xy*z * vec3(rawData[r+9], rawData[r+24], rawData[r+39]);
  result += SH_C3[2] * y*(4.*zz - xx - yy)
            * vec3(rawData[r+10], rawData[r+25], rawData[r+40]);
  result += SH_C3[3] * z*(2.*zz - 3.*xx - 3.*yy)
            * vec3(rawData[r+11], rawData[r+26], rawData[r+41]);
  result += SH_C3[4] * x*(4.*zz - xx - yy)
            * vec3(rawData[r+12], rawData[r+27], rawData[r+42]);
  result += SH_C3[5] * z*(xx - yy)
            * vec3(rawData[r+13], rawData[r+28], rawData[r+43]);
  result += SH_C3[6] * x*(xx - 3.*yy)
            * vec3(rawData[r+14], rawData[r+29], rawData[r+44]);

  return result + 0.5;
}

void main() {
  uint idx = gl_GlobalInvocationID.x;
  if (idx >= splatCount) return;

  uint base = idx * 64u; // 64 floats per raw splat

  // Position
  vec3 pos = vec3(rawData[base], rawData[base + 1], rawData[base + 2]);

  // View direction for SH evaluation (world space, from camera towards splat).
  // Must match the INRIA training convention: dir = pos - campos.
  // A splat sitting on the camera has no defined direction; use a fixed unit
  // vector instead of normalizing a zero vector, which would yield NaN.
  vec3 toSplat = pos - camPos;
  float toSplatLen = length(toSplat);
  vec3 dir = toSplatLen > 1e-8 ? toSplat / toSplatLen : vec3(0.0, 0.0, 1.0);

  // Evaluate SH for view-dependent color
  vec3 color = evaluateSH(base, dir);
  color = clamp(color, 0.0, 1.0);

  // Opacity: sigmoid(raw_opacity)
  float rawOpacity = rawData[base + 54];
  float alpha = 1.0 / (1.0 + exp(-rawOpacity));

  // Scale: exp(log_scale) * scaleMod
  vec3 scale = vec3(
    exp(rawData[base + 55]),
    exp(rawData[base + 56]),
    exp(rawData[base + 57])
  ) * scaleMod;

  // Rotation: PLY stores (w,x,y,z), the renderer expects a normalized (x,y,z,w).
  // An all-zero quaternion cannot be normalized; treat it as the identity.
  vec4 rawRot = vec4(
    rawData[base + 58], // w
    rawData[base + 59], // x
    rawData[base + 60], // y
    rawData[base + 61]  // z
  );
  float rotLen = length(rawRot);
  rawRot = rotLen > 1e-8 ? rawRot / rotLen : vec4(1.0, 0.0, 0.0, 0.0);
  vec4 rot = vec4(rawRot.y, rawRot.z, rawRot.w, rawRot.x); // xyzw

  // Write compact rendering splat
  renderSplats[idx].position = vec4(pos, 0.0);
  renderSplats[idx].scale = vec4(scale, 0.0);
  renderSplats[idx].rotation = rot;
  renderSplats[idx].color = vec4(color, alpha);
}
)_";
347
348//=============================================================================
349// COMPUTE SHADERS FOR DEPTH SORTING
350//=============================================================================
351
// Depth-key compute shader for the GPU radix sort (GLSL 450, 256 invocations
// per workgroup).
//
// Writes one 32-bit key and the identity index per splat:
//   - the high 16 bits hold a logarithmic depth bucket between the near and
//     far planes,
//   - the low 16 bits hold (index & 0xFFFF).
// Splats at or behind the near plane get the maximum bucket, so they sort last.
//
// The clip range is sanitized before use; otherwise near <= 0 or far <= near
// turns the log mapping into inf/NaN, and uint(NaN) is undefined. Valid ranges
// are unaffected.
static constexpr auto depth_key_shader = R"_(#version 450
layout(local_size_x = 256) in;

struct RenderSplat {
  vec4 position;
  vec4 scale;
  vec4 rotation;
  vec4 color;
};

layout(std430, binding = 0) readonly buffer SplatBuffer {
  RenderSplat splats[];
};

layout(std430, binding = 1) writeonly buffer KeyBuffer {
  uint keys[];
};

layout(std430, binding = 2) writeonly buffer IndexBuffer {
  uint indices[];
};

layout(std140, binding = 3) uniform Params {
  mat4 view;
  uint splatCount;
  float nearPlane;
  float farPlane;
  uint _pad;
};

void main() {
  uint idx = gl_GlobalInvocationID.x;
  if (idx >= splatCount) return;

  // Transform to view space
  vec4 viewPos = view * vec4(splats[idx].position.xyz, 1.0);
  float depth = -viewPos.z; // Negate because view space Z is negative

  // Sanitize the clip range: the mapping below divides by the near plane and
  // by log2(far / near), both of which break for near <= 0 or far <= near.
  float nearP = max(nearPlane, 1e-6);
  float farP = max(farPlane, nearP * 1.0001);

  // Front-to-back sort key: top 16 bits = depth, bottom 16 bits = splat index.
  // The depth gives correct rendering order; the index provides stable
  // tie-breaking for splats at similar depths (same buffer order every frame).
  // This eliminates the "wave" artifact from coherent sort-order swaps.
  // Combined with "under" blending for correct front-to-back compositing.
  uint key;
  if (depth <= nearP) {
    // At or behind the near plane (including behind the camera): use the
    // largest depth bucket so these sort last, keeping the index sub-order.
    key = (0xFFFFu << 16u) | (idx & 0xFFFFu);
  } else {
    // Logarithmic depth in [0, 1] between the sanitized near and far planes.
    float t = log2(depth / nearP) / log2(farP / nearP);
    t = clamp(t, 0.0, 1.0);
    uint depthKey = uint(t * 65535.0);
    key = (depthKey << 16u) | (idx & 0xFFFFu);
  }

  keys[idx] = key;
  indices[idx] = idx;
}
)_";
415
// Radix-sort histogram stage: per-workgroup counts of one 8-bit digit.
//
// Each 256-wide workgroup counts the digit (key >> bitOffset) & 0xFF of its
// keys in a shared-memory histogram. It then writes its 256 counts to
// histogram[workgroupId * 256 + digit]; prefix_sum_shader later turns these
// counts into scatter offsets in place.
// Each invocation clears and writes exactly one bucket, so the workgroup size
// must equal the bucket count (256).
// The `numWorkgroups` uniform member is declared but not read here.
420static constexpr auto histogram_shader = R"_(#version 450
421layout(local_size_x = 256) in;
422
423layout(std430, binding = 0) readonly buffer KeyBuffer {
424 uint keys[];
425};
426
427layout(std430, binding = 1) buffer HistogramBuffer {
428 uint histogram[]; // 256 buckets * num_workgroups
429};
430
431layout(std140, binding = 2) uniform Params {
432 uint splatCount;
433 uint bitOffset; // Which 8 bits to sort (0, 8, 16, 24)
434 uint numWorkgroups;
435 uint _pad;
436};
437
438shared uint localHistogram[256];
439
440void main() {
441 uint localId = gl_LocalInvocationID.x;
442 uint globalId = gl_GlobalInvocationID.x;
443 uint workgroupId = gl_WorkGroupID.x;
444
445 // Clear local histogram
446 localHistogram[localId] = 0;
447 barrier();
448
449 // Count digits in this workgroup
450 if (globalId < splatCount) {
451 uint key = keys[globalId];
452 uint digit = (key >> bitOffset) & 0xFFu;
453 atomicAdd(localHistogram[digit], 1);
454 }
455 barrier();
456
457 // Write local histogram to global memory
458 histogram[workgroupId * 256 + localId] = localHistogram[localId];
459}
460)_";
461
// Radix-sort scatter stage: moves each (key, index) pair to its sorted slot
// for the 8-bit digit at `bitOffset`.
//
// Expects `histogram` to already hold exclusive global offsets laid out as
// histogram[workgroup * 256 + digit], i.e. the in-place output of
// prefix_sum_shader. It must therefore be dispatched with the same workgroup
// count as histogram_shader.
//
// Within a workgroup, an element's rank is the number of lower-numbered
// invocations with the same digit. This keeps the pass stable: equal digits
// keep their input order. The rank loop costs O(local size) shared-memory
// reads per invocation.
//
// Out-of-range invocations use digit 256 as a sentinel, so they never match a
// real digit.
466static constexpr auto sort_scatter_shader = R"_(#version 450
467layout(local_size_x = 256) in;
468
469layout(std430, binding = 0) readonly buffer KeyBufferIn {
470 uint keysIn[];
471};
472
473layout(std430, binding = 1) readonly buffer IndexBufferIn {
474 uint indicesIn[];
475};
476
477layout(std430, binding = 2) writeonly buffer KeyBufferOut {
478 uint keysOut[];
479};
480
481layout(std430, binding = 3) writeonly buffer IndexBufferOut {
482 uint indicesOut[];
483};
484
485layout(std430, binding = 4) buffer HistogramBuffer {
486 uint histogram[]; // Global prefix sums
487};
488
489layout(std140, binding = 5) uniform Params {
490 uint splatCount;
491 uint bitOffset;
492 uint numWorkgroups;
493 uint _pad;
494};
495
496shared uint localDigits[256];
497shared uint localOffset[256];
498
499void main() {
500 uint localId = gl_LocalInvocationID.x;
501 uint globalId = gl_GlobalInvocationID.x;
502 uint workgroupId = gl_WorkGroupID.x;
503
504 // Load global prefix sum for this workgroup's digit
505 localOffset[localId] = histogram[workgroupId * 256 + localId];
506
507 // Load this thread's element
508 uint key = 0u;
509 uint idx = 0u;
510 uint digit = 256u; // invalid sentinel (> any real digit)
511 bool valid = globalId < splatCount;
512 if (valid) {
513 key = keysIn[globalId];
514 idx = indicesIn[globalId];
515 digit = (key >> bitOffset) & 0xFFu;
516 }
517 localDigits[localId] = digit;
518 barrier();
519
520 if (valid) {
521 // Stable rank: count threads with LOWER ID that share the same digit.
522 // This is deterministic (no atomicAdd race), so the sort is stable
523 // and identical across frames — eliminates flickering.
524 uint rank = 0u;
525 for (uint i = 0u; i < localId; i++) {
526 if (localDigits[i] == digit)
527 rank++;
528 }
529
530 uint globalPos = localOffset[digit] + rank;
531 if (globalPos < splatCount) {
532 keysOut[globalPos] = key;
533 indicesOut[globalPos] = idx;
534 }
535 }
536}
537)_";
538
// Radix-sort prefix stage: converts per-workgroup digit counts into exclusive
// global scatter offsets, in place.
//
// Written for a single dispatch of one 256-invocation workgroup. Each
// invocation owns one digit column (it only uses gl_LocalInvocationID), so
// columns never race on the global buffer.
//
// Offsets are ordered first by digit, then by workgroup. This ordering keeps
// the LSD radix sort stable across workgroups.
//
// Cost: two serial passes over numWorkgroups per invocation, plus a 256-step
// serial scan on invocation 0.
553static constexpr auto prefix_sum_shader = R"_(#version 450
554layout(local_size_x = 256) in;
555
556layout(std430, binding = 0) buffer HistogramBuffer {
557 uint histogram[]; // Layout: histogram[workgroup * 256 + digit]
558};
559
560layout(std140, binding = 1) uniform Params {
561 uint numWorkgroups;
562 uint _pad0;
563 uint _pad1;
564 uint _pad2;
565};
566
567shared uint digitTotal[256];
568shared uint digitPrefix[256];
569
570void main() {
571 uint digit = gl_LocalInvocationID.x; // 0-255, one thread per digit
572
573 // Step 1: Sum all workgroup counts for this digit
574 uint total = 0;
575 for (uint wg = 0; wg < numWorkgroups; wg++) {
576 total += histogram[wg * 256 + digit];
577 }
578 digitTotal[digit] = total;
579 barrier();
580
581 // Step 2: Thread 0 computes exclusive prefix sum across all digits
582 // This determines the global starting offset for each digit bucket
583 if (digit == 0) {
584 digitPrefix[0] = 0;
585 for (uint d = 1; d < 256; d++) {
586 digitPrefix[d] = digitPrefix[d-1] + digitTotal[d-1];
587 }
588 }
589 barrier();
590
591 // Step 3: Convert per-workgroup counts to global offsets
592 // For each workgroup: offset = digitPrefix[digit] + sum of same-digit counts in earlier workgroups
593 uint running = digitPrefix[digit];
594 for (uint wg = 0; wg < numWorkgroups; wg++) {
595 uint idx = wg * 256 + digit;
596 uint val = histogram[idx];
597 histogram[idx] = running;
598 running += val;
599 }
600}
601)_";
602
603//=============================================================================
604// RENDER SHADERS
605//=============================================================================
606
// Splat vertex shader (instanced).
//
// Draws one instance per splat, with six vertices per instance forming a
// two-triangle quad (positions[]). When useSorting != 0, the instance index is
// remapped through sortedIndices.
//
// Per splat, the shader:
//   - projects the splat's 3D covariance to a 2D screen-space covariance (EWA),
//   - dilates it by a 0.3 px^2 Mip-Splatting filter with opacity compensation,
//   - aligns the quad with the covariance eigenvectors and sizes it to
//     3 standard deviations (capped at 2048 px),
//   - scales down alpha for splats whose projected radius exceeds 512 px.
// Culled splats emit gl_Position = (0, 0, 2, 1), which lies outside the clip
// volume.
//
// NOTE(review): the f_center Y mapping and the conic cross-term sign both key
// off clipSpaceCorr[1][1]. The in-shader comment assumes Metal/D3D also have a
// negative [1][1].
// Per the QRhi documentation, clipSpaceCorrMatrix() corrects for Vulkan's
// Y-down NDC and for the 0..1 depth range, so [1][1] is expected to be +1 on
// D3D and Metal. On those backends gl_FragCoord is still top-left (see
// QRhi::isYUpInFramebuffer()).
// If this uniform holds QRhi::clipSpaceCorrMatrix(), f_center.y and the
// cross-term sign may be mirrored on D3D/Metal -- verify on those backends.
607static constexpr auto vertex_shader = R"_(#version 450
608
609// Quad vertex positions
610const vec2 positions[6] = vec2[6](
611 vec2(-1.0, -1.0),
612 vec2( 1.0, -1.0),
613 vec2( 1.0, 1.0),
614 vec2(-1.0, -1.0),
615 vec2( 1.0, 1.0),
616 vec2(-1.0, 1.0)
617);
618
619// Compact rendering splat (output of preprocess compute shader)
620struct RenderSplat {
621 vec4 position; // xyz = position
622 vec4 scale; // xyz = scale (already exp'd)
623 vec4 rotation; // quaternion xyzw (already normalized)
624 vec4 color; // RGBA (SH evaluated, sigmoid applied)
625};
626
627layout(std430, binding = 0) readonly buffer SplatBuffer {
628 RenderSplat splats[];
629};
630
631// Sorted indices from depth sort pass
632layout(std430, binding = 1) readonly buffer SortedIndices {
633 uint sortedIndices[];
634};
635
636layout(std140, binding = 2) uniform Uniforms {
637 mat4 view;
638 mat4 projection;
639 mat4 clipSpaceCorr;
640 vec2 viewport;
641 float _pad0;
642 uint useSorting; // 0 = no sorting, 1 = use sorted indices
643};
644
645layout(location = 0) out vec2 f_center; // screen-space splat center (pixels)
646layout(location = 1) out vec4 f_color;
647layout(location = 2) out vec3 f_conic;
648
649mat3 quatToMat(vec4 q) {
650 float x = q.x, y = q.y, z = q.z, w = q.w;
651 // GLSL mat3 is column-major: mat3(col0, col1, col2)
652 return mat3(
653 1.0 - 2.0*(y*y + z*z), 2.0*(x*y + w*z), 2.0*(x*z - w*y), // col 0
654 2.0*(x*y - w*z), 1.0 - 2.0*(x*x + z*z), 2.0*(y*z + w*x), // col 1
655 2.0*(x*z + w*y), 2.0*(y*z - w*x), 1.0 - 2.0*(x*x + y*y) // col 2
656 );
657}
658
659void main() {
660 // Get splat index (sorted or unsorted)
661 uint splatIdx = useSorting != 0 ? sortedIndices[gl_InstanceIndex] : gl_InstanceIndex;
662 RenderSplat splat = splats[splatIdx];
663 vec2 quadPos = positions[gl_VertexIndex];
664
665 // Early opacity cull: skip splats that are nearly invisible
666 if (splat.color.a < 1.0 / 255.0) {
667 gl_Position = vec4(0.0, 0.0, 2.0, 1.0);
668 return;
669 }
670
671 // View space position
672 vec4 viewPos = view * vec4(splat.position.xyz, 1.0);
673
674 // Focal lengths in pixels
675 float focal = projection[0][0] * viewport.x * 0.5;
676 float focal_y = projection[1][1] * viewport.y * 0.5;
677 float tanFovX = 0.5 * viewport.x / focal;
678 float tanFovY = 0.5 * viewport.y / focal_y;
679
680 // Frustum culling: project to clip space and check NDC bounds
681 // (matches INRIA reference: cull behind camera + outside 1.3x viewport)
682 vec4 clipPos = projection * viewPos;
683 if (clipPos.w <= 0.2) {
684 gl_Position = vec4(0.0, 0.0, 2.0, 1.0);
685 return;
686 }
687 vec3 ndc = clipPos.xyz / clipPos.w;
688 if (abs(ndc.x) > 1.3 || abs(ndc.y) > 1.3) {
689 gl_Position = vec4(0.0, 0.0, 2.0, 1.0);
690 return;
691 }
692
693 // Clamp view-space position to prevent numerical issues at screen edges
694 // (matches INRIA CUDA reference: 1.3x FOV tangent)
695 float limX = 1.3 * tanFovX;
696 float limY = 1.3 * tanFovY;
697 float txtz = viewPos.x / viewPos.z;
698 float tytz = viewPos.y / viewPos.z;
699 viewPos.x = clamp(txtz, -limX, limX) * viewPos.z;
700 viewPos.y = clamp(tytz, -limY, limY) * viewPos.z;
701
702 // Build 3D covariance from scale and rotation (already preprocessed)
703 // INRIA convention: Sigma = R * S * S^T * R^T = R * S² * R^T
704 // The principal axes are the COLUMNS of R.
705 vec3 scale = splat.scale.xyz;
706 mat3 R = quatToMat(splat.rotation);
707 mat3 S = mat3(scale.x, 0, 0, 0, scale.y, 0, 0, 0, scale.z);
708 mat3 M = R * S;
709 mat3 Sigma = M * transpose(M);
710
711 // 2D covariance via EWA projection
712 mat3 W = mat3(view);
713 float z2 = viewPos.z * viewPos.z;
714
715 // Jacobian of projection (column-major: mat3(col0, col1, col2))
716 mat3 J = mat3(
717 focal / viewPos.z, 0.0, 0.0, // col 0
718 0.0, focal_y / viewPos.z, 0.0, // col 1
719 -focal * viewPos.x / z2, -focal_y * viewPos.y / z2, 0.0 // col 2
720 );
721
722 mat3 T = J * W;
723 mat3 cov = T * Sigma * transpose(T);
724
725 float cov_xx = cov[0][0], cov_xy = cov[0][1], cov_yy = cov[1][1];
726
727 // Mip-Splatting 2D filter (Yu et al. 2024): approximate the pixel box filter
728 // as a Gaussian and convolve with the projected 2D covariance.
729 // Opacity is compensated to preserve each splat's total contribution:
730 // alpha' = alpha * sqrt(det(Sigma) / det(Sigma + kernel_size * I))
731 float kernel_size = 0.3;
732 float det_0 = max(1e-6, cov_xx * cov_yy - cov_xy * cov_xy);
733 cov_xx += kernel_size;
734 cov_yy += kernel_size;
735 float det_1 = max(1e-6, cov_xx * cov_yy - cov_xy * cov_xy);
736 float mipCoef = sqrt(det_0 / det_1);
737
738 float det = cov_xx * cov_yy - cov_xy * cov_xy;
739 float mid = 0.5 * (cov_xx + cov_yy);
740 float disc = max(0.0, mid * mid - det);
741 float lambda1 = mid + sqrt(disc);
742 float lambda2 = mid - sqrt(disc);
743
744 // Eigenvectors of 2D covariance for ellipse-aligned quad
745 vec2 eigVec1;
746 if (abs(cov_xy) > 1e-6) {
747 eigVec1 = normalize(vec2(cov_xy, lambda1 - cov_xx));
748 } else {
749 eigVec1 = (cov_xx >= cov_yy) ? vec2(1.0, 0.0) : vec2(0.0, 1.0);
750 }
751 vec2 eigVec2 = vec2(-eigVec1.y, eigVec1.x);
752
753 float maxExtent = 2048.0;
754 float r1 = min(ceil(3.0 * sqrt(max(lambda1, 0.0))), maxExtent);
755 float r2 = min(ceil(3.0 * sqrt(max(lambda2, 0.0))), maxExtent);
756
757 // Cull degenerate or invisible splats
758 if (det < 1e-3 || max(r1, r2) < 0.1) {
759 gl_Position = vec4(0.0, 0.0, 2.0, 1.0);
760 return;
761 }
762
763 // Inverse covariance (conic) for fragment Gaussian evaluation.
764 // The cross-term sign must match the screen-space convention of gl_FragCoord:
765 // Vulkan/Metal/D3D (clipSpaceCorr[1][1] < 0): both screen axes flip
766 // relative to J-space, preserving the cross-product sign.
767 // OpenGL (clipSpaceCorr[1][1] > 0): only X flips, requiring correction.
768 float inv_det = 1.0 / det;
769 float crossSign = sign(clipSpaceCorr[1][1]);
770 f_conic = vec3(cov_yy * inv_det, crossSign * cov_xy * inv_det, cov_xx * inv_det);
771
772 // Oriented quad: major axis along eigVec1, minor along eigVec2
773 vec2 pixelOffset = quadPos.x * r1 * eigVec1 + quadPos.y * r2 * eigVec2;
774 vec2 center = ndc.xy;
775 vec2 ndcOffset = pixelOffset * 2.0 / viewport;
776
777 gl_Position = clipSpaceCorr * vec4(center + ndcOffset, ndc.z, 1.0);
778
779 // Score's texture compositing pipeline flips Y when sampling for Vulkan/HLSL/Metal.
780 // To match this convention (same as ISF shaders), we undo clipSpaceCorr's Y-flip here
781 // so the compositing re-flip produces a correctly oriented final image.
782 gl_Position.y = -gl_Position.y;
783
784 // Screen-space center in pixels (matches gl_FragCoord coordinate system)
785 vec4 centerClip = clipSpaceCorr * vec4(ndc.xy, ndc.z, 1.0);
786 centerClip.y = -centerClip.y;
787 f_center = (centerClip.xy / centerClip.w * 0.5 + 0.5) * viewport;
788
789 // Fade out excessively large projected splats.
790 float alpha = splat.color.a * mipCoef;
791 float maxR = max(r1, r2);
792 float fadeRadius = 512.0;
793 if (maxR > fadeRadius) {
794 float fade = fadeRadius / maxR;
795 alpha *= fade;
796 if (alpha < 1.0 / 255.0) {
797 gl_Position = vec4(0.0, 0.0, 2.0, 1.0);
798 return;
799 }
800 }
801 f_color = vec4(splat.color.rgb, alpha);
802}
803)_";
804
// Splat fragment shader.
//
// Evaluates the projected 2D Gaussian at gl_FragCoord. It uses the conic
// (inverse 2D covariance) and the pixel-space center produced by the vertex
// shader.
// Fragments whose alpha falls below 1/255 are discarded, and alpha is capped
// at 0.99.
// The output colour is premultiplied (rgb * alpha, alpha). The blend state
// that composites it is configured on the C++ side, which is not visible here.
805static constexpr auto fragment_shader = R"_(#version 450
806
807layout(location = 0) in vec2 f_center; // screen-space splat center (pixels)
808layout(location = 1) in vec4 f_color;
809layout(location = 2) in vec3 f_conic;
810
811layout(location = 0) out vec4 fragColor;
812
813void main() {
814 // Pixel offset from splat center, computed per-fragment for precision.
815 // Unlike interpolated UVs, this is exact regardless of quad orientation.
816 vec2 d = gl_FragCoord.xy - f_center;
817
818 float power = -0.5 * (f_conic.x * d.x * d.x +
819 2.0 * f_conic.y * d.x * d.y +
820 f_conic.z * d.y * d.y);
821
822 if (power > 0.0) discard;
823
824 float gaussian = exp(power);
825 float alpha = min(0.99, gaussian * f_color.a);
826 if (alpha < 1.0/255.0) discard;
827
828 fragColor = vec4(f_color.rgb * alpha, alpha);
829}
830)_";
831
832} // namespace GaussianSplatShaders
833
834} // namespace score::gfx
Renderer for GaussianSplatNode.
Definition GaussianSplatNode.hpp:75
Generic renderer.
Definition NodeRenderer.hpp:96
Common base class for most single-pass, simple nodes.
Definition score-plugin-gfx/Gfx/Graph/Node.hpp:203
Renderer for a given node.
Definition NodeRenderer.hpp:11
List of nodes to be rendered to an output.
Definition RenderList.hpp:19
Graphics rendering pipeline for ossia score.
Definition Filter/PreviewWidget.hpp:12
Connection between two score::gfx::Port.
Definition score-plugin-gfx/Gfx/Graph/Utils.hpp:71
Gaussian Splat rendering node.
Definition GaussianSplatNode.hpp:41
score::gfx::NodeRenderer * createRenderer(RenderList &) const noexcept override
Create a renderer in a given context for this node.
Definition GaussianSplatNode.cpp:78
void process(Message &&msg) override
Process a message from the execution engine.
Definition GaussianSplatNode.cpp:35
Definition score-plugin-gfx/Gfx/Graph/Node.hpp:49
Port of a score::gfx::Node.
Definition score-plugin-gfx/Gfx/Graph/Utils.hpp:53
Useful abstraction for storing all the data related to a render target.
Definition score-plugin-gfx/Gfx/Graph/Utils.hpp:116