Skip to content

Commit a7e7975

Browse files
feat: Better add, sub, mul, div interface (#1283)
1 parent f6564ed commit a7e7975

File tree

8 files changed

+822
-317
lines changed

8 files changed

+822
-317
lines changed

packages/typegpu/src/data/vectorOps.ts

Lines changed: 181 additions & 133 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ import type * as wgsl from './wgslTypes.ts';
2020
import type { VecKind } from './wgslTypes.ts';
2121

2222
type vBase = { kind: VecKind };
23+
type mBase = { kind: MatKind };
2324
type v2 = wgsl.v2f | wgsl.v2h | wgsl.v2i | wgsl.v2u;
2425
type v3 = wgsl.v3f | wgsl.v3h | wgsl.v3i | wgsl.v3u;
2526
type v4 = wgsl.v4f | wgsl.v4h | wgsl.v4i | wgsl.v4u;
@@ -72,6 +73,26 @@ const unary4i = (op: UnaryOp) => (a: wgsl.v4i) =>
7273
const unary4u = (op: UnaryOp) => (a: wgsl.v4u) =>
7374
vec4u(op(a.x), op(a.y), op(a.z), op(a.w));
7475

76+
const unary2x2f = (op: UnaryOp) => (a: wgsl.m2x2f) => {
77+
const a_ = a.columns as [wgsl.v2f, wgsl.v2f];
78+
return mat2x2f(unary2f(op)(a_[0]), unary2f(op)(a_[1]));
79+
};
80+
81+
const unary3x3f = (op: UnaryOp) => (a: wgsl.m3x3f) => {
82+
const a_ = a.columns as [wgsl.v3f, wgsl.v3f, wgsl.v3f];
83+
return mat3x3f(unary3f(op)(a_[0]), unary3f(op)(a_[1]), unary3f(op)(a_[2]));
84+
};
85+
86+
const unary4x4f = (op: UnaryOp) => (a: wgsl.m4x4f) => {
87+
const a_ = a.columns as [wgsl.v4f, wgsl.v4f, wgsl.v4f, wgsl.v4f];
88+
return mat4x4f(
89+
unary4f(op)(a_[0]),
90+
unary4f(op)(a_[1]),
91+
unary4f(op)(a_[2]),
92+
unary4f(op)(a_[3]),
93+
);
94+
};
95+
7596
const binaryComponentWise2f = (op: BinaryOp) => (a: wgsl.v2f, b: wgsl.v2f) =>
7697
vec2f(op(a.x, b.x), op(a.y, b.y));
7798

@@ -108,6 +129,48 @@ const binaryComponentWise4i = (op: BinaryOp) => (a: wgsl.v4i, b: wgsl.v4i) =>
108129
const binaryComponentWise4u = (op: BinaryOp) => (a: wgsl.v4u, b: wgsl.v4u) =>
109130
vec4u(op(a.x, b.x), op(a.y, b.y), op(a.z, b.z), op(a.w, b.w));
110131

132+
const binaryComponentWise2x2f =
133+
(op: BinaryOp) => (a: wgsl.m2x2f, b: wgsl.m2x2f) => {
134+
const a_ = a.columns as [wgsl.v2f, wgsl.v2f];
135+
const b_ = b.columns as [wgsl.v2f, wgsl.v2f];
136+
return mat2x2f(
137+
binaryComponentWise2f(op)(a_[0], b_[0]),
138+
binaryComponentWise2f(op)(a_[1], b_[1]),
139+
);
140+
};
141+
142+
const binaryComponentWise3x3f =
143+
(op: BinaryOp) => (a: wgsl.m3x3f, b: wgsl.m3x3f) => {
144+
const a_ = a.columns as [wgsl.v3f, wgsl.v3f, wgsl.v3f];
145+
const b_ = b.columns as [wgsl.v3f, wgsl.v3f, wgsl.v3f];
146+
return mat3x3f(
147+
binaryComponentWise3f(op)(a_[0], b_[0]),
148+
binaryComponentWise3f(op)(a_[1], b_[1]),
149+
binaryComponentWise3f(op)(a_[2], b_[2]),
150+
);
151+
};
152+
153+
const binaryComponentWise4x4f =
154+
(op: BinaryOp) => (a: wgsl.m4x4f, b: wgsl.m4x4f) => {
155+
const a_ = a.columns as [wgsl.v4f, wgsl.v4f, wgsl.v4f, wgsl.v4f];
156+
const b_ = b.columns as [wgsl.v4f, wgsl.v4f, wgsl.v4f, wgsl.v4f];
157+
return mat4x4f(
158+
binaryComponentWise4f(op)(a_[0], b_[0]),
159+
binaryComponentWise4f(op)(a_[1], b_[1]),
160+
binaryComponentWise4f(op)(a_[2], b_[2]),
161+
binaryComponentWise4f(op)(a_[3], b_[3]),
162+
);
163+
};
164+
165+
export const NumberOps = {
166+
divInteger: (lhs: number, rhs: number) => {
167+
if (rhs === 0) {
168+
return lhs;
169+
}
170+
return Math.trunc(lhs / rhs);
171+
},
172+
};
173+
111174
export const VectorOps = {
112175
eq: {
113176
vec2f: (e1: wgsl.v2f, e2: wgsl.v2f) => vec2b(e1.x === e2.x, e1.y === e2.y),
@@ -349,132 +412,92 @@ export const VectorOps = {
349412
} as Record<VecKind, (v: vBase) => number>,
350413

351414
add: {
352-
vec2f: (a: wgsl.v2f, b: wgsl.v2f) => vec2f(a.x + b.x, a.y + b.y),
353-
vec2h: (a: wgsl.v2h, b: wgsl.v2h) => vec2h(a.x + b.x, a.y + b.y),
354-
vec2i: (a: wgsl.v2i, b: wgsl.v2i) => vec2i(a.x + b.x, a.y + b.y),
355-
vec2u: (a: wgsl.v2u, b: wgsl.v2u) => vec2u(a.x + b.x, a.y + b.y),
356-
357-
vec3f: (a: wgsl.v3f, b: wgsl.v3f) => vec3f(a.x + b.x, a.y + b.y, a.z + b.z),
358-
vec3h: (a: wgsl.v3h, b: wgsl.v3h) => vec3h(a.x + b.x, a.y + b.y, a.z + b.z),
359-
vec3i: (a: wgsl.v3i, b: wgsl.v3i) => vec3i(a.x + b.x, a.y + b.y, a.z + b.z),
360-
vec3u: (a: wgsl.v3u, b: wgsl.v3u) => vec3u(a.x + b.x, a.y + b.y, a.z + b.z),
361-
362-
vec4f: (a: wgsl.v4f, b: wgsl.v4f) =>
363-
vec4f(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w),
364-
vec4h: (a: wgsl.v4h, b: wgsl.v4h) =>
365-
vec4h(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w),
366-
vec4i: (a: wgsl.v4i, b: wgsl.v4i) =>
367-
vec4i(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w),
368-
vec4u: (a: wgsl.v4u, b: wgsl.v4u) =>
369-
vec4u(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w),
370-
} as Record<VecKind, <T extends vBase>(lhs: T, rhs: T) => T>,
371-
372-
sub: {
373-
vec2f: (a: wgsl.v2f, b: wgsl.v2f) => vec2f(a.x - b.x, a.y - b.y),
374-
vec2h: (a: wgsl.v2h, b: wgsl.v2h) => vec2h(a.x - b.x, a.y - b.y),
375-
vec2i: (a: wgsl.v2i, b: wgsl.v2i) => vec2i(a.x - b.x, a.y - b.y),
376-
vec2u: (a: wgsl.v2u, b: wgsl.v2u) => vec2u(a.x - b.x, a.y - b.y),
377-
378-
vec3f: (a: wgsl.v3f, b: wgsl.v3f) => vec3f(a.x - b.x, a.y - b.y, a.z - b.z),
379-
vec3h: (a: wgsl.v3h, b: wgsl.v3h) => vec3h(a.x - b.x, a.y - b.y, a.z - b.z),
380-
vec3i: (a: wgsl.v3i, b: wgsl.v3i) => vec3i(a.x - b.x, a.y - b.y, a.z - b.z),
381-
vec3u: (a: wgsl.v3u, b: wgsl.v3u) => vec3u(a.x - b.x, a.y - b.y, a.z - b.z),
382-
383-
vec4f: (a: wgsl.v4f, b: wgsl.v4f) =>
384-
vec4f(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w),
385-
vec4h: (a: wgsl.v4h, b: wgsl.v4h) =>
386-
vec4h(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w),
387-
vec4i: (a: wgsl.v4i, b: wgsl.v4i) =>
388-
vec4i(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w),
389-
vec4u: (a: wgsl.v4u, b: wgsl.v4u) =>
390-
vec4u(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w),
391-
} as Record<VecKind, <T extends vBase>(lhs: T, rhs: T) => T>,
392-
393-
mulSxV: {
394-
vec2f: (s: number, v: wgsl.v2f) => vec2f(s * v.x, s * v.y),
395-
vec2h: (s: number, v: wgsl.v2h) => vec2h(s * v.x, s * v.y),
396-
vec2i: (s: number, v: wgsl.v2i) => vec2i(s * v.x, s * v.y),
397-
vec2u: (s: number, v: wgsl.v2u) => vec2u(s * v.x, s * v.y),
398-
399-
vec3f: (s: number, v: wgsl.v3f) => vec3f(s * v.x, s * v.y, s * v.z),
400-
vec3h: (s: number, v: wgsl.v3h) => vec3h(s * v.x, s * v.y, s * v.z),
401-
vec3i: (s: number, v: wgsl.v3i) => vec3i(s * v.x, s * v.y, s * v.z),
402-
vec3u: (s: number, v: wgsl.v3u) => vec3u(s * v.x, s * v.y, s * v.z),
403-
404-
vec4f: (s: number, v: wgsl.v4f) =>
405-
vec4f(s * v.x, s * v.y, s * v.z, s * v.w),
406-
vec4h: (s: number, v: wgsl.v4h) =>
407-
vec4h(s * v.x, s * v.y, s * v.z, s * v.w),
408-
vec4i: (s: number, v: wgsl.v4i) =>
409-
vec4i(s * v.x, s * v.y, s * v.z, s * v.w),
410-
vec4u: (s: number, v: wgsl.v4u) =>
411-
vec4u(s * v.x, s * v.y, s * v.z, s * v.w),
412-
413-
mat2x2f: (s: number, m: wgsl.m2x2f) => {
414-
const m_ = m.columns as [wgsl.v2f, wgsl.v2f];
415-
return mat2x2f(s * m_[0].x, s * m_[0].y, s * m_[1].x, s * m_[1].y);
416-
},
415+
vec2f: binaryComponentWise2f((a, b) => a + b),
416+
vec2h: binaryComponentWise2h((a, b) => a + b),
417+
vec2i: binaryComponentWise2i((a, b) => a + b),
418+
vec2u: binaryComponentWise2u((a, b) => a + b),
419+
420+
vec3f: binaryComponentWise3f((a, b) => a + b),
421+
vec3h: binaryComponentWise3h((a, b) => a + b),
422+
vec3i: binaryComponentWise3i((a, b) => a + b),
423+
vec3u: binaryComponentWise3u((a, b) => a + b),
424+
425+
vec4f: binaryComponentWise4f((a, b) => a + b),
426+
vec4h: binaryComponentWise4h((a, b) => a + b),
427+
vec4i: binaryComponentWise4i((a, b) => a + b),
428+
vec4u: binaryComponentWise4u((a, b) => a + b),
429+
430+
mat2x2f: binaryComponentWise2x2f((a, b) => a + b),
431+
mat3x3f: binaryComponentWise3x3f((a, b) => a + b),
432+
mat4x4f: binaryComponentWise4x4f((a, b) => a + b),
433+
} as Record<
434+
VecKind | MatKind,
435+
<T extends vBase | mBase>(lhs: T, rhs: T) => T
436+
>,
417437

418-
mat3x3f: (s: number, m: wgsl.m3x3f) => {
419-
const m_ = m.columns as [wgsl.v3f, wgsl.v3f, wgsl.v3f];
420-
return mat3x3f(
421-
s * m_[0].x,
422-
s * m_[0].y,
423-
s * m_[0].z,
424-
s * m_[1].x,
425-
s * m_[1].y,
426-
s * m_[1].z,
427-
s * m_[2].x,
428-
s * m_[2].y,
429-
s * m_[2].z,
430-
);
431-
},
438+
addMixed: {
439+
vec2f: (a: wgsl.v2f, b: number) => unary2f((e) => e + b)(a),
440+
vec2h: (a: wgsl.v2h, b: number) => unary2h((e) => e + b)(a),
441+
vec2i: (a: wgsl.v2i, b: number) => unary2i((e) => e + b)(a),
442+
vec2u: (a: wgsl.v2u, b: number) => unary2u((e) => e + b)(a),
443+
444+
vec3f: (a: wgsl.v3f, b: number) => unary3f((e) => e + b)(a),
445+
vec3h: (a: wgsl.v3h, b: number) => unary3h((e) => e + b)(a),
446+
vec3i: (a: wgsl.v3i, b: number) => unary3i((e) => e + b)(a),
447+
vec3u: (a: wgsl.v3u, b: number) => unary3u((e) => e + b)(a),
448+
449+
vec4f: (a: wgsl.v4f, b: number) => unary4f((e) => e + b)(a),
450+
vec4h: (a: wgsl.v4h, b: number) => unary4h((e) => e + b)(a),
451+
vec4i: (a: wgsl.v4i, b: number) => unary4i((e) => e + b)(a),
452+
vec4u: (a: wgsl.v4u, b: number) => unary4u((e) => e + b)(a),
453+
454+
mat2x2f: (a: wgsl.m2x2f, b: number) => unary2x2f((e) => e + b)(a),
455+
mat3x3f: (a: wgsl.m3x3f, b: number) => unary3x3f((e) => e + b)(a),
456+
mat4x4f: (a: wgsl.m4x4f, b: number) => unary4x4f((e) => e + b)(a),
457+
} as Record<
458+
VecKind | MatKind,
459+
<T extends vBase | mBase>(lhs: T, rhs: number) => T
460+
>,
432461

433-
mat4x4f: (s: number, m: wgsl.m4x4f) => {
434-
const m_ = m.columns as [wgsl.v4f, wgsl.v4f, wgsl.v4f, wgsl.v4f];
435-
return mat4x4f(
436-
s * m_[0].x,
437-
s * m_[0].y,
438-
s * m_[0].z,
439-
s * m_[0].w,
440-
s * m_[1].x,
441-
s * m_[1].y,
442-
s * m_[1].z,
443-
s * m_[1].w,
444-
s * m_[2].x,
445-
s * m_[2].y,
446-
s * m_[2].z,
447-
s * m_[2].w,
448-
s * m_[3].x,
449-
s * m_[3].y,
450-
s * m_[3].z,
451-
s * m_[3].w,
452-
);
453-
},
462+
mulSxV: {
463+
vec2f: (s: number, v: wgsl.v2f) => unary2f((e) => s * e)(v),
464+
vec2h: (s: number, v: wgsl.v2h) => unary2h((e) => s * e)(v),
465+
vec2i: (s: number, v: wgsl.v2i) => unary2i((e) => s * e)(v),
466+
vec2u: (s: number, v: wgsl.v2u) => unary2u((e) => s * e)(v),
467+
468+
vec3f: (s: number, v: wgsl.v3f) => unary3f((e) => s * e)(v),
469+
vec3h: (s: number, v: wgsl.v3h) => unary3h((e) => s * e)(v),
470+
vec3i: (s: number, v: wgsl.v3i) => unary3i((e) => s * e)(v),
471+
vec3u: (s: number, v: wgsl.v3u) => unary3u((e) => s * e)(v),
472+
473+
vec4f: (s: number, v: wgsl.v4f) => unary4f((e) => s * e)(v),
474+
vec4h: (s: number, v: wgsl.v4h) => unary4h((e) => s * e)(v),
475+
vec4i: (s: number, v: wgsl.v4i) => unary4i((e) => s * e)(v),
476+
vec4u: (s: number, v: wgsl.v4u) => unary4u((e) => s * e)(v),
477+
478+
mat2x2f: (s: number, m: wgsl.m2x2f) => unary2x2f((e) => s * e)(m),
479+
mat3x3f: (s: number, m: wgsl.m3x3f) => unary3x3f((e) => s * e)(m),
480+
mat4x4f: (s: number, m: wgsl.m4x4f) => unary4x4f((e) => s * e)(m),
454481
} as Record<
455482
VecKind | MatKind,
456483
<T extends vBase | wgsl.AnyMatInstance>(s: number, v: T) => T
457484
>,
458485

459486
mulVxV: {
460-
vec2f: (a: wgsl.v2f, b: wgsl.v2f) => vec2f(a.x * b.x, a.y * b.y),
461-
vec2h: (a: wgsl.v2h, b: wgsl.v2h) => vec2h(a.x * b.x, a.y * b.y),
462-
vec2i: (a: wgsl.v2i, b: wgsl.v2i) => vec2i(a.x * b.x, a.y * b.y),
463-
vec2u: (a: wgsl.v2u, b: wgsl.v2u) => vec2u(a.x * b.x, a.y * b.y),
464-
465-
vec3f: (a: wgsl.v3f, b: wgsl.v3f) => vec3f(a.x * b.x, a.y * b.y, a.z * b.z),
466-
vec3h: (a: wgsl.v3h, b: wgsl.v3h) => vec3h(a.x * b.x, a.y * b.y, a.z * b.z),
467-
vec3i: (a: wgsl.v3i, b: wgsl.v3i) => vec3i(a.x * b.x, a.y * b.y, a.z * b.z),
468-
vec3u: (a: wgsl.v3u, b: wgsl.v3u) => vec3u(a.x * b.x, a.y * b.y, a.z * b.z),
469-
470-
vec4f: (a: wgsl.v4f, b: wgsl.v4f) =>
471-
vec4f(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w),
472-
vec4h: (a: wgsl.v4h, b: wgsl.v4h) =>
473-
vec4h(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w),
474-
vec4i: (a: wgsl.v4i, b: wgsl.v4i) =>
475-
vec4i(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w),
476-
vec4u: (a: wgsl.v4u, b: wgsl.v4u) =>
477-
vec4u(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w),
487+
vec2f: binaryComponentWise2f((a, b) => a * b),
488+
vec2h: binaryComponentWise2h((a, b) => a * b),
489+
vec2i: binaryComponentWise2i((a, b) => a * b),
490+
vec2u: binaryComponentWise2u((a, b) => a * b),
491+
492+
vec3f: binaryComponentWise3f((a, b) => a * b),
493+
vec3h: binaryComponentWise3h((a, b) => a * b),
494+
vec3i: binaryComponentWise3i((a, b) => a * b),
495+
vec3u: binaryComponentWise3u((a, b) => a * b),
496+
497+
vec4f: binaryComponentWise4f((a, b) => a * b),
498+
vec4h: binaryComponentWise4h((a, b) => a * b),
499+
vec4i: binaryComponentWise4i((a, b) => a * b),
500+
vec4u: binaryComponentWise4u((a, b) => a * b),
478501

479502
mat2x2f: (a: wgsl.m2x2f, b: wgsl.m2x2f) => {
480503
const a_ = a.columns as [wgsl.v2f, wgsl.v2f];
@@ -651,6 +674,46 @@ export const VectorOps = {
651674
) => wgsl.vBaseForMat<T>
652675
>,
653676

677+
div: {
678+
vec2f: binaryComponentWise2f((a, b) => a / b),
679+
vec2h: binaryComponentWise2h((a, b) => a / b),
680+
vec2i: binaryComponentWise2i(NumberOps.divInteger),
681+
vec2u: binaryComponentWise2u(NumberOps.divInteger),
682+
683+
vec3f: binaryComponentWise3f((a, b) => a / b),
684+
vec3h: binaryComponentWise3h((a, b) => a / b),
685+
vec3i: binaryComponentWise3i(NumberOps.divInteger),
686+
vec3u: binaryComponentWise3u(NumberOps.divInteger),
687+
688+
vec4f: binaryComponentWise4f((a, b) => a / b),
689+
vec4h: binaryComponentWise4h((a, b) => a / b),
690+
vec4i: binaryComponentWise4i(NumberOps.divInteger),
691+
vec4u: binaryComponentWise4u(NumberOps.divInteger),
692+
} as Record<VecKind, <T extends vBase>(a: T, b: T) => T>,
693+
694+
divMixed: {
695+
vec2f: (a: wgsl.v2f, b: number) => unary2f((e) => e / b)(a),
696+
vec2h: (a: wgsl.v2h, b: number) => unary2h((e) => e / b)(a),
697+
vec2i: (a: wgsl.v2i, b: number) =>
698+
unary2i((e) => NumberOps.divInteger(e, b))(a),
699+
vec2u: (a: wgsl.v2u, b: number) =>
700+
unary2u((e) => NumberOps.divInteger(e, b))(a),
701+
702+
vec3f: (a: wgsl.v3f, b: number) => unary3f((e) => e / b)(a),
703+
vec3h: (a: wgsl.v3h, b: number) => unary3h((e) => e / b)(a),
704+
vec3i: (a: wgsl.v3i, b: number) =>
705+
unary3i((e) => NumberOps.divInteger(e, b))(a),
706+
vec3u: (a: wgsl.v3u, b: number) =>
707+
unary3u((e) => NumberOps.divInteger(e, b))(a),
708+
709+
vec4f: (a: wgsl.v4f, b: number) => unary4f((e) => e / b)(a),
710+
vec4h: (a: wgsl.v4h, b: number) => unary4h((e) => e / b)(a),
711+
vec4i: (a: wgsl.v4i, b: number) =>
712+
unary4i((e) => NumberOps.divInteger(e, b))(a),
713+
vec4u: (a: wgsl.v4u, b: number) =>
714+
unary4u((e) => NumberOps.divInteger(e, b))(a),
715+
} as Record<VecKind, <T extends vBase>(lhs: T, rhs: number) => T>,
716+
654717
dot: {
655718
vec2f: dotVec2,
656719
vec2h: dotVec2,
@@ -839,21 +902,6 @@ export const VectorOps = {
839902
vec4h: unary4h(Math.sqrt),
840903
} as Record<VecKind, <T extends vBase>(v: T) => T>,
841904

842-
div: {
843-
vec2f: binaryComponentWise2f((a, b) => a / b),
844-
vec2h: binaryComponentWise2h((a, b) => a / b),
845-
vec2i: binaryComponentWise2i((a, b) => a / b),
846-
vec2u: binaryComponentWise2u((a, b) => a / b),
847-
vec3f: binaryComponentWise3f((a, b) => a / b),
848-
vec3h: binaryComponentWise3h((a, b) => a / b),
849-
vec3i: binaryComponentWise3i((a, b) => a / b),
850-
vec3u: binaryComponentWise3u((a, b) => a / b),
851-
vec4f: binaryComponentWise4f((a, b) => a / b),
852-
vec4h: binaryComponentWise4h((a, b) => a / b),
853-
vec4i: binaryComponentWise4i((a, b) => a / b),
854-
vec4u: binaryComponentWise4u((a, b) => a / b),
855-
} as Record<VecKind, <T extends vBase>(a: T, b: T) => T>,
856-
857905
mix: {
858906
vec2f: (e1: wgsl.v2f, e2: wgsl.v2f, e3: wgsl.v2f | number) => {
859907
if (typeof e3 === 'number') {

0 commit comments

Comments
 (0)