Skip to content

feat: Better add, sub, mul, div interface #1283

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 31 commits into from
May 26, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
ba3f5c4
Update mul signature
Apr 25, 2025
7a5ee37
Add missing mul tests
Apr 25, 2025
5ce563d
Mul overload shenanigans
Apr 25, 2025
ba4b816
Update m2x2f to have fluent mul
Apr 25, 2025
9fee6f2
Add vectorOps helpers for matrices, update `add` signature
Apr 25, 2025
70a8a71
Update sub signature
Apr 28, 2025
d65e11c
Add missing `add` tests
Apr 28, 2025
7a8eba8
Update v2f signature with mul
Apr 28, 2025
4a7fba3
Merge remote-tracking branch 'origin/main' into feat/implement-fluent…
Apr 28, 2025
7dde273
Update Vec2fImpl mul overload
Apr 28, 2025
14dfa09
Add a temporary workaround to wgslGenerator
Apr 28, 2025
78c0933
Add tests for new `sub` signature
Apr 28, 2025
bf2c0de
Update remaining interfaces
Apr 28, 2025
71a522a
Update wgslGenerator to handle all
Apr 28, 2025
30f23d8
Add tests for resolve
Apr 29, 2025
d4127b6
Add example tests for fluent operators
Apr 29, 2025
cd49475
Clean up vector ops a little, add missing mul tests
Apr 29, 2025
020eb44
Remove fluent operator changes
May 22, 2025
6a3f636
Merge remote-tracking branch 'origin/main' into feat/better-mul-inter…
May 22, 2025
3097def
Fix types
May 22, 2025
b407148
Reformat numeric.ts
May 22, 2025
509e0df
Add union overload types and tests
May 22, 2025
a89bae7
Lint
May 22, 2025
d06c115
Rewrite `div`
May 22, 2025
54ee606
Update div overload and tests
May 23, 2025
ad474a0
Nits
May 23, 2025
5ddcfba
Optimize imports
May 23, 2025
19d988b
Refactor and update instance checkers
May 23, 2025
9f23dc7
Fix interfaces to comply with the docs
May 23, 2025
f311b4a
Review fixes
May 26, 2025
ea132e2
Merge remote-tracking branch 'origin/main' into feat/better-mul-inter…
May 26, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
314 changes: 181 additions & 133 deletions packages/typegpu/src/data/vectorOps.ts
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ import type * as wgsl from './wgslTypes.ts';
import type { VecKind } from './wgslTypes.ts';

type vBase = { kind: VecKind };
type mBase = { kind: MatKind };
type v2 = wgsl.v2f | wgsl.v2h | wgsl.v2i | wgsl.v2u;
type v3 = wgsl.v3f | wgsl.v3h | wgsl.v3i | wgsl.v3u;
type v4 = wgsl.v4f | wgsl.v4h | wgsl.v4i | wgsl.v4u;
Expand Down Expand Up @@ -72,6 +73,26 @@ const unary4i = (op: UnaryOp) => (a: wgsl.v4i) =>
const unary4u = (op: UnaryOp) => (a: wgsl.v4u) =>
vec4u(op(a.x), op(a.y), op(a.z), op(a.w));

const unary2x2f = (op: UnaryOp) => (a: wgsl.m2x2f) => {
const a_ = a.columns as [wgsl.v2f, wgsl.v2f];
return mat2x2f(unary2f(op)(a_[0]), unary2f(op)(a_[1]));
};

const unary3x3f = (op: UnaryOp) => (a: wgsl.m3x3f) => {
const a_ = a.columns as [wgsl.v3f, wgsl.v3f, wgsl.v3f];
return mat3x3f(unary3f(op)(a_[0]), unary3f(op)(a_[1]), unary3f(op)(a_[2]));
};

const unary4x4f = (op: UnaryOp) => (a: wgsl.m4x4f) => {
const a_ = a.columns as [wgsl.v4f, wgsl.v4f, wgsl.v4f, wgsl.v4f];
return mat4x4f(
unary4f(op)(a_[0]),
unary4f(op)(a_[1]),
unary4f(op)(a_[2]),
unary4f(op)(a_[3]),
);
};

const binaryComponentWise2f = (op: BinaryOp) => (a: wgsl.v2f, b: wgsl.v2f) =>
vec2f(op(a.x, b.x), op(a.y, b.y));

Expand Down Expand Up @@ -108,6 +129,48 @@ const binaryComponentWise4i = (op: BinaryOp) => (a: wgsl.v4i, b: wgsl.v4i) =>
const binaryComponentWise4u = (op: BinaryOp) => (a: wgsl.v4u, b: wgsl.v4u) =>
vec4u(op(a.x, b.x), op(a.y, b.y), op(a.z, b.z), op(a.w, b.w));

const binaryComponentWise2x2f =
(op: BinaryOp) => (a: wgsl.m2x2f, b: wgsl.m2x2f) => {
const a_ = a.columns as [wgsl.v2f, wgsl.v2f];
const b_ = b.columns as [wgsl.v2f, wgsl.v2f];
return mat2x2f(
binaryComponentWise2f(op)(a_[0], b_[0]),
binaryComponentWise2f(op)(a_[1], b_[1]),
);
};

const binaryComponentWise3x3f =
(op: BinaryOp) => (a: wgsl.m3x3f, b: wgsl.m3x3f) => {
const a_ = a.columns as [wgsl.v3f, wgsl.v3f, wgsl.v3f];
const b_ = b.columns as [wgsl.v3f, wgsl.v3f, wgsl.v3f];
return mat3x3f(
binaryComponentWise3f(op)(a_[0], b_[0]),
binaryComponentWise3f(op)(a_[1], b_[1]),
binaryComponentWise3f(op)(a_[2], b_[2]),
);
};

const binaryComponentWise4x4f =
(op: BinaryOp) => (a: wgsl.m4x4f, b: wgsl.m4x4f) => {
const a_ = a.columns as [wgsl.v4f, wgsl.v4f, wgsl.v4f, wgsl.v4f];
const b_ = b.columns as [wgsl.v4f, wgsl.v4f, wgsl.v4f, wgsl.v4f];
return mat4x4f(
binaryComponentWise4f(op)(a_[0], b_[0]),
binaryComponentWise4f(op)(a_[1], b_[1]),
binaryComponentWise4f(op)(a_[2], b_[2]),
binaryComponentWise4f(op)(a_[3], b_[3]),
);
};

export const NumberOps = {
divInteger: (lhs: number, rhs: number) => {
if (rhs === 0) {
return lhs;
}
return Math.trunc(lhs / rhs);
},
};

export const VectorOps = {
eq: {
vec2f: (e1: wgsl.v2f, e2: wgsl.v2f) => vec2b(e1.x === e2.x, e1.y === e2.y),
Expand Down Expand Up @@ -349,132 +412,92 @@ export const VectorOps = {
} as Record<VecKind, (v: vBase) => number>,

add: {
vec2f: (a: wgsl.v2f, b: wgsl.v2f) => vec2f(a.x + b.x, a.y + b.y),
vec2h: (a: wgsl.v2h, b: wgsl.v2h) => vec2h(a.x + b.x, a.y + b.y),
vec2i: (a: wgsl.v2i, b: wgsl.v2i) => vec2i(a.x + b.x, a.y + b.y),
vec2u: (a: wgsl.v2u, b: wgsl.v2u) => vec2u(a.x + b.x, a.y + b.y),

vec3f: (a: wgsl.v3f, b: wgsl.v3f) => vec3f(a.x + b.x, a.y + b.y, a.z + b.z),
vec3h: (a: wgsl.v3h, b: wgsl.v3h) => vec3h(a.x + b.x, a.y + b.y, a.z + b.z),
vec3i: (a: wgsl.v3i, b: wgsl.v3i) => vec3i(a.x + b.x, a.y + b.y, a.z + b.z),
vec3u: (a: wgsl.v3u, b: wgsl.v3u) => vec3u(a.x + b.x, a.y + b.y, a.z + b.z),

vec4f: (a: wgsl.v4f, b: wgsl.v4f) =>
vec4f(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w),
vec4h: (a: wgsl.v4h, b: wgsl.v4h) =>
vec4h(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w),
vec4i: (a: wgsl.v4i, b: wgsl.v4i) =>
vec4i(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w),
vec4u: (a: wgsl.v4u, b: wgsl.v4u) =>
vec4u(a.x + b.x, a.y + b.y, a.z + b.z, a.w + b.w),
} as Record<VecKind, <T extends vBase>(lhs: T, rhs: T) => T>,

sub: {
vec2f: (a: wgsl.v2f, b: wgsl.v2f) => vec2f(a.x - b.x, a.y - b.y),
vec2h: (a: wgsl.v2h, b: wgsl.v2h) => vec2h(a.x - b.x, a.y - b.y),
vec2i: (a: wgsl.v2i, b: wgsl.v2i) => vec2i(a.x - b.x, a.y - b.y),
vec2u: (a: wgsl.v2u, b: wgsl.v2u) => vec2u(a.x - b.x, a.y - b.y),

vec3f: (a: wgsl.v3f, b: wgsl.v3f) => vec3f(a.x - b.x, a.y - b.y, a.z - b.z),
vec3h: (a: wgsl.v3h, b: wgsl.v3h) => vec3h(a.x - b.x, a.y - b.y, a.z - b.z),
vec3i: (a: wgsl.v3i, b: wgsl.v3i) => vec3i(a.x - b.x, a.y - b.y, a.z - b.z),
vec3u: (a: wgsl.v3u, b: wgsl.v3u) => vec3u(a.x - b.x, a.y - b.y, a.z - b.z),

vec4f: (a: wgsl.v4f, b: wgsl.v4f) =>
vec4f(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w),
vec4h: (a: wgsl.v4h, b: wgsl.v4h) =>
vec4h(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w),
vec4i: (a: wgsl.v4i, b: wgsl.v4i) =>
vec4i(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w),
vec4u: (a: wgsl.v4u, b: wgsl.v4u) =>
vec4u(a.x - b.x, a.y - b.y, a.z - b.z, a.w - b.w),
} as Record<VecKind, <T extends vBase>(lhs: T, rhs: T) => T>,

mulSxV: {
vec2f: (s: number, v: wgsl.v2f) => vec2f(s * v.x, s * v.y),
vec2h: (s: number, v: wgsl.v2h) => vec2h(s * v.x, s * v.y),
vec2i: (s: number, v: wgsl.v2i) => vec2i(s * v.x, s * v.y),
vec2u: (s: number, v: wgsl.v2u) => vec2u(s * v.x, s * v.y),

vec3f: (s: number, v: wgsl.v3f) => vec3f(s * v.x, s * v.y, s * v.z),
vec3h: (s: number, v: wgsl.v3h) => vec3h(s * v.x, s * v.y, s * v.z),
vec3i: (s: number, v: wgsl.v3i) => vec3i(s * v.x, s * v.y, s * v.z),
vec3u: (s: number, v: wgsl.v3u) => vec3u(s * v.x, s * v.y, s * v.z),

vec4f: (s: number, v: wgsl.v4f) =>
vec4f(s * v.x, s * v.y, s * v.z, s * v.w),
vec4h: (s: number, v: wgsl.v4h) =>
vec4h(s * v.x, s * v.y, s * v.z, s * v.w),
vec4i: (s: number, v: wgsl.v4i) =>
vec4i(s * v.x, s * v.y, s * v.z, s * v.w),
vec4u: (s: number, v: wgsl.v4u) =>
vec4u(s * v.x, s * v.y, s * v.z, s * v.w),

mat2x2f: (s: number, m: wgsl.m2x2f) => {
const m_ = m.columns as [wgsl.v2f, wgsl.v2f];
return mat2x2f(s * m_[0].x, s * m_[0].y, s * m_[1].x, s * m_[1].y);
},
vec2f: binaryComponentWise2f((a, b) => a + b),
vec2h: binaryComponentWise2h((a, b) => a + b),
vec2i: binaryComponentWise2i((a, b) => a + b),
vec2u: binaryComponentWise2u((a, b) => a + b),

vec3f: binaryComponentWise3f((a, b) => a + b),
vec3h: binaryComponentWise3h((a, b) => a + b),
vec3i: binaryComponentWise3i((a, b) => a + b),
vec3u: binaryComponentWise3u((a, b) => a + b),

vec4f: binaryComponentWise4f((a, b) => a + b),
vec4h: binaryComponentWise4h((a, b) => a + b),
vec4i: binaryComponentWise4i((a, b) => a + b),
vec4u: binaryComponentWise4u((a, b) => a + b),

mat2x2f: binaryComponentWise2x2f((a, b) => a + b),
mat3x3f: binaryComponentWise3x3f((a, b) => a + b),
mat4x4f: binaryComponentWise4x4f((a, b) => a + b),
} as Record<
VecKind | MatKind,
<T extends vBase | mBase>(lhs: T, rhs: T) => T
>,

mat3x3f: (s: number, m: wgsl.m3x3f) => {
const m_ = m.columns as [wgsl.v3f, wgsl.v3f, wgsl.v3f];
return mat3x3f(
s * m_[0].x,
s * m_[0].y,
s * m_[0].z,
s * m_[1].x,
s * m_[1].y,
s * m_[1].z,
s * m_[2].x,
s * m_[2].y,
s * m_[2].z,
);
},
addMixed: {
vec2f: (a: wgsl.v2f, b: number) => unary2f((e) => e + b)(a),
vec2h: (a: wgsl.v2h, b: number) => unary2h((e) => e + b)(a),
vec2i: (a: wgsl.v2i, b: number) => unary2i((e) => e + b)(a),
vec2u: (a: wgsl.v2u, b: number) => unary2u((e) => e + b)(a),

vec3f: (a: wgsl.v3f, b: number) => unary3f((e) => e + b)(a),
vec3h: (a: wgsl.v3h, b: number) => unary3h((e) => e + b)(a),
vec3i: (a: wgsl.v3i, b: number) => unary3i((e) => e + b)(a),
vec3u: (a: wgsl.v3u, b: number) => unary3u((e) => e + b)(a),

vec4f: (a: wgsl.v4f, b: number) => unary4f((e) => e + b)(a),
vec4h: (a: wgsl.v4h, b: number) => unary4h((e) => e + b)(a),
vec4i: (a: wgsl.v4i, b: number) => unary4i((e) => e + b)(a),
vec4u: (a: wgsl.v4u, b: number) => unary4u((e) => e + b)(a),

mat2x2f: (a: wgsl.m2x2f, b: number) => unary2x2f((e) => e + b)(a),
mat3x3f: (a: wgsl.m3x3f, b: number) => unary3x3f((e) => e + b)(a),
mat4x4f: (a: wgsl.m4x4f, b: number) => unary4x4f((e) => e + b)(a),
} as Record<
VecKind | MatKind,
<T extends vBase | mBase>(lhs: T, rhs: number) => T
>,

mat4x4f: (s: number, m: wgsl.m4x4f) => {
const m_ = m.columns as [wgsl.v4f, wgsl.v4f, wgsl.v4f, wgsl.v4f];
return mat4x4f(
s * m_[0].x,
s * m_[0].y,
s * m_[0].z,
s * m_[0].w,
s * m_[1].x,
s * m_[1].y,
s * m_[1].z,
s * m_[1].w,
s * m_[2].x,
s * m_[2].y,
s * m_[2].z,
s * m_[2].w,
s * m_[3].x,
s * m_[3].y,
s * m_[3].z,
s * m_[3].w,
);
},
mulSxV: {
vec2f: (s: number, v: wgsl.v2f) => unary2f((e) => s * e)(v),
vec2h: (s: number, v: wgsl.v2h) => unary2h((e) => s * e)(v),
vec2i: (s: number, v: wgsl.v2i) => unary2i((e) => s * e)(v),
vec2u: (s: number, v: wgsl.v2u) => unary2u((e) => s * e)(v),

vec3f: (s: number, v: wgsl.v3f) => unary3f((e) => s * e)(v),
vec3h: (s: number, v: wgsl.v3h) => unary3h((e) => s * e)(v),
vec3i: (s: number, v: wgsl.v3i) => unary3i((e) => s * e)(v),
vec3u: (s: number, v: wgsl.v3u) => unary3u((e) => s * e)(v),

vec4f: (s: number, v: wgsl.v4f) => unary4f((e) => s * e)(v),
vec4h: (s: number, v: wgsl.v4h) => unary4h((e) => s * e)(v),
vec4i: (s: number, v: wgsl.v4i) => unary4i((e) => s * e)(v),
vec4u: (s: number, v: wgsl.v4u) => unary4u((e) => s * e)(v),

mat2x2f: (s: number, m: wgsl.m2x2f) => unary2x2f((e) => s * e)(m),
mat3x3f: (s: number, m: wgsl.m3x3f) => unary3x3f((e) => s * e)(m),
mat4x4f: (s: number, m: wgsl.m4x4f) => unary4x4f((e) => s * e)(m),
} as Record<
VecKind | MatKind,
<T extends vBase | wgsl.AnyMatInstance>(s: number, v: T) => T
>,

mulVxV: {
vec2f: (a: wgsl.v2f, b: wgsl.v2f) => vec2f(a.x * b.x, a.y * b.y),
vec2h: (a: wgsl.v2h, b: wgsl.v2h) => vec2h(a.x * b.x, a.y * b.y),
vec2i: (a: wgsl.v2i, b: wgsl.v2i) => vec2i(a.x * b.x, a.y * b.y),
vec2u: (a: wgsl.v2u, b: wgsl.v2u) => vec2u(a.x * b.x, a.y * b.y),

vec3f: (a: wgsl.v3f, b: wgsl.v3f) => vec3f(a.x * b.x, a.y * b.y, a.z * b.z),
vec3h: (a: wgsl.v3h, b: wgsl.v3h) => vec3h(a.x * b.x, a.y * b.y, a.z * b.z),
vec3i: (a: wgsl.v3i, b: wgsl.v3i) => vec3i(a.x * b.x, a.y * b.y, a.z * b.z),
vec3u: (a: wgsl.v3u, b: wgsl.v3u) => vec3u(a.x * b.x, a.y * b.y, a.z * b.z),

vec4f: (a: wgsl.v4f, b: wgsl.v4f) =>
vec4f(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w),
vec4h: (a: wgsl.v4h, b: wgsl.v4h) =>
vec4h(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w),
vec4i: (a: wgsl.v4i, b: wgsl.v4i) =>
vec4i(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w),
vec4u: (a: wgsl.v4u, b: wgsl.v4u) =>
vec4u(a.x * b.x, a.y * b.y, a.z * b.z, a.w * b.w),
vec2f: binaryComponentWise2f((a, b) => a * b),
vec2h: binaryComponentWise2h((a, b) => a * b),
vec2i: binaryComponentWise2i((a, b) => a * b),
vec2u: binaryComponentWise2u((a, b) => a * b),

vec3f: binaryComponentWise3f((a, b) => a * b),
vec3h: binaryComponentWise3h((a, b) => a * b),
vec3i: binaryComponentWise3i((a, b) => a * b),
vec3u: binaryComponentWise3u((a, b) => a * b),

vec4f: binaryComponentWise4f((a, b) => a * b),
vec4h: binaryComponentWise4h((a, b) => a * b),
vec4i: binaryComponentWise4i((a, b) => a * b),
vec4u: binaryComponentWise4u((a, b) => a * b),

mat2x2f: (a: wgsl.m2x2f, b: wgsl.m2x2f) => {
const a_ = a.columns as [wgsl.v2f, wgsl.v2f];
Expand Down Expand Up @@ -651,6 +674,46 @@ export const VectorOps = {
) => wgsl.vBaseForMat<T>
>,

div: {
vec2f: binaryComponentWise2f((a, b) => a / b),
vec2h: binaryComponentWise2h((a, b) => a / b),
vec2i: binaryComponentWise2i(NumberOps.divInteger),
vec2u: binaryComponentWise2u(NumberOps.divInteger),

vec3f: binaryComponentWise3f((a, b) => a / b),
vec3h: binaryComponentWise3h((a, b) => a / b),
vec3i: binaryComponentWise3i(NumberOps.divInteger),
vec3u: binaryComponentWise3u(NumberOps.divInteger),

vec4f: binaryComponentWise4f((a, b) => a / b),
vec4h: binaryComponentWise4h((a, b) => a / b),
vec4i: binaryComponentWise4i(NumberOps.divInteger),
vec4u: binaryComponentWise4u(NumberOps.divInteger),
} as Record<VecKind, <T extends vBase>(a: T, b: T) => T>,

divMixed: {
vec2f: (a: wgsl.v2f, b: number) => unary2f((e) => e / b)(a),
vec2h: (a: wgsl.v2h, b: number) => unary2h((e) => e / b)(a),
vec2i: (a: wgsl.v2i, b: number) =>
unary2i((e) => NumberOps.divInteger(e, b))(a),
vec2u: (a: wgsl.v2u, b: number) =>
unary2u((e) => NumberOps.divInteger(e, b))(a),

vec3f: (a: wgsl.v3f, b: number) => unary3f((e) => e / b)(a),
vec3h: (a: wgsl.v3h, b: number) => unary3h((e) => e / b)(a),
vec3i: (a: wgsl.v3i, b: number) =>
unary3i((e) => NumberOps.divInteger(e, b))(a),
vec3u: (a: wgsl.v3u, b: number) =>
unary3u((e) => NumberOps.divInteger(e, b))(a),

vec4f: (a: wgsl.v4f, b: number) => unary4f((e) => e / b)(a),
vec4h: (a: wgsl.v4h, b: number) => unary4h((e) => e / b)(a),
vec4i: (a: wgsl.v4i, b: number) =>
unary4i((e) => NumberOps.divInteger(e, b))(a),
vec4u: (a: wgsl.v4u, b: number) =>
unary4u((e) => NumberOps.divInteger(e, b))(a),
} as Record<VecKind, <T extends vBase>(lhs: T, rhs: number) => T>,

dot: {
vec2f: dotVec2,
vec2h: dotVec2,
Expand Down Expand Up @@ -839,21 +902,6 @@ export const VectorOps = {
vec4h: unary4h(Math.sqrt),
} as Record<VecKind, <T extends vBase>(v: T) => T>,

div: {
vec2f: binaryComponentWise2f((a, b) => a / b),
vec2h: binaryComponentWise2h((a, b) => a / b),
vec2i: binaryComponentWise2i((a, b) => a / b),
vec2u: binaryComponentWise2u((a, b) => a / b),
vec3f: binaryComponentWise3f((a, b) => a / b),
vec3h: binaryComponentWise3h((a, b) => a / b),
vec3i: binaryComponentWise3i((a, b) => a / b),
vec3u: binaryComponentWise3u((a, b) => a / b),
vec4f: binaryComponentWise4f((a, b) => a / b),
vec4h: binaryComponentWise4h((a, b) => a / b),
vec4i: binaryComponentWise4i((a, b) => a / b),
vec4u: binaryComponentWise4u((a, b) => a / b),
} as Record<VecKind, <T extends vBase>(a: T, b: T) => T>,

mix: {
vec2f: (e1: wgsl.v2f, e2: wgsl.v2f, e3: wgsl.v2f | number) => {
if (typeof e3 === 'number') {
Expand Down
Loading