Skip to content

Commit e1ad1ea

Browse files
Merge branch 'main' into fix/tgpu-functions-declarations-order
2 parents 2acfcd4 + 83d1428 commit e1ad1ea

File tree

5 files changed

+259
-29
lines changed

5 files changed

+259
-29
lines changed
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
export const smoothstepScalar = (
2+
edge0: number,
3+
edge1: number,
4+
x: number,
5+
): number => {
6+
if (edge0 === edge1) {
7+
return 0; // WGSL spec says this is an indeterminate value
8+
}
9+
const t = clamp((x - edge0) / (edge1 - edge0), 0.0, 1.0);
10+
return t * t * (3 - 2 * t);
11+
};
12+
13+
export const clamp = (value: number, low: number, high: number) =>
14+
Math.min(Math.max(low, value), high);
15+
16+
export const divInteger = (lhs: number, rhs: number) => {
17+
if (rhs === 0) {
18+
return lhs;
19+
}
20+
return Math.trunc(lhs / rhs);
21+
};

packages/typegpu/src/data/vectorOps.ts

Lines changed: 67 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import { mat2x2f, mat3x3f, mat4x4f } from './matrix.ts';
2+
import { clamp, divInteger, smoothstepScalar } from './numberOps.ts';
23
import {
34
vec2b,
45
vec2f,
@@ -38,9 +39,6 @@ const dotVec3 = (lhs: v3, rhs: v3) =>
3839
const dotVec4 = (lhs: v4, rhs: v4) =>
3940
lhs.x * rhs.x + lhs.y * rhs.y + lhs.z * rhs.z + lhs.w * rhs.w;
4041

41-
const clamp = (value: number, low: number, high: number) =>
42-
Math.min(Math.max(low, value), high);
43-
4442
type UnaryOp = (a: number) => number;
4543
type BinaryOp = (a: number, b: number) => number;
4644

@@ -162,14 +160,41 @@ const binaryComponentWise4x4f =
162160
);
163161
};
164162

165-
export const NumberOps = {
166-
divInteger: (lhs: number, rhs: number) => {
167-
if (rhs === 0) {
168-
return lhs;
169-
}
170-
return Math.trunc(lhs / rhs);
171-
},
172-
};
163+
type TernaryOp = (a: number, b: number, c: number) => number;
164+
165+
const ternaryComponentWise2f =
166+
(op: TernaryOp) => (a: wgsl.v2f, b: wgsl.v2f, c: wgsl.v2f) =>
167+
vec2f(op(a.x, b.x, c.x), op(a.y, b.y, c.y));
168+
169+
const ternaryComponentWise2h =
170+
(op: TernaryOp) => (a: wgsl.v2h, b: wgsl.v2h, c: wgsl.v2h) =>
171+
vec2h(op(a.x, b.x, c.x), op(a.y, b.y, c.y));
172+
173+
const ternaryComponentWise3f =
174+
(op: TernaryOp) => (a: wgsl.v3f, b: wgsl.v3f, c: wgsl.v3f) =>
175+
vec3f(op(a.x, b.x, c.x), op(a.y, b.y, c.y), op(a.z, b.z, c.z));
176+
177+
const ternaryComponentWise3h =
178+
(op: TernaryOp) => (a: wgsl.v3h, b: wgsl.v3h, c: wgsl.v3h) =>
179+
vec3h(op(a.x, b.x, c.x), op(a.y, b.y, c.y), op(a.z, b.z, c.z));
180+
181+
const ternaryComponentWise4f =
182+
(op: TernaryOp) => (a: wgsl.v4f, b: wgsl.v4f, c: wgsl.v4f) =>
183+
vec4f(
184+
op(a.x, b.x, c.x),
185+
op(a.y, b.y, c.y),
186+
op(a.z, b.z, c.z),
187+
op(a.w, b.w, c.w),
188+
);
189+
190+
const ternaryComponentWise4h =
191+
(op: TernaryOp) => (a: wgsl.v4h, b: wgsl.v4h, c: wgsl.v4h) =>
192+
vec4h(
193+
op(a.x, b.x, c.x),
194+
op(a.y, b.y, c.y),
195+
op(a.z, b.z, c.z),
196+
op(a.w, b.w, c.w),
197+
);
173198

174199
export const VectorOps = {
175200
eq: {
@@ -446,6 +471,25 @@ export const VectorOps = {
446471
<T extends vBase | mBase>(lhs: T, rhs: T) => T
447472
>,
448473

474+
smoothstep: {
475+
vec2f: ternaryComponentWise2f(smoothstepScalar),
476+
vec2h: ternaryComponentWise2h(smoothstepScalar),
477+
vec3f: ternaryComponentWise3f(smoothstepScalar),
478+
vec3h: ternaryComponentWise3h(smoothstepScalar),
479+
vec4f: ternaryComponentWise4f(smoothstepScalar),
480+
vec4h: ternaryComponentWise4h(smoothstepScalar),
481+
} as Record<
482+
VecKind,
483+
<T extends vBase>(
484+
edge0: T,
485+
edge1: T,
486+
x: T,
487+
) => T extends wgsl.AnyVec2Instance ? wgsl.v2f
488+
: T extends wgsl.AnyVec3Instance ? wgsl.v3f
489+
: T extends wgsl.AnyVec4Instance ? wgsl.v4f
490+
: wgsl.AnyVecInstance
491+
>,
492+
449493
addMixed: {
450494
vec2f: (a: wgsl.v2f, b: number) => unary2f((e) => e + b)(a),
451495
vec2h: (a: wgsl.v2h, b: number) => unary2h((e) => e + b)(a),
@@ -688,41 +732,35 @@ export const VectorOps = {
688732
div: {
689733
vec2f: binaryComponentWise2f((a, b) => a / b),
690734
vec2h: binaryComponentWise2h((a, b) => a / b),
691-
vec2i: binaryComponentWise2i(NumberOps.divInteger),
692-
vec2u: binaryComponentWise2u(NumberOps.divInteger),
735+
vec2i: binaryComponentWise2i(divInteger),
736+
vec2u: binaryComponentWise2u(divInteger),
693737

694738
vec3f: binaryComponentWise3f((a, b) => a / b),
695739
vec3h: binaryComponentWise3h((a, b) => a / b),
696-
vec3i: binaryComponentWise3i(NumberOps.divInteger),
697-
vec3u: binaryComponentWise3u(NumberOps.divInteger),
740+
vec3i: binaryComponentWise3i(divInteger),
741+
vec3u: binaryComponentWise3u(divInteger),
698742

699743
vec4f: binaryComponentWise4f((a, b) => a / b),
700744
vec4h: binaryComponentWise4h((a, b) => a / b),
701-
vec4i: binaryComponentWise4i(NumberOps.divInteger),
702-
vec4u: binaryComponentWise4u(NumberOps.divInteger),
745+
vec4i: binaryComponentWise4i(divInteger),
746+
vec4u: binaryComponentWise4u(divInteger),
703747
} as Record<VecKind, <T extends vBase>(a: T, b: T) => T>,
704748

705749
divMixed: {
706750
vec2f: (a: wgsl.v2f, b: number) => unary2f((e) => e / b)(a),
707751
vec2h: (a: wgsl.v2h, b: number) => unary2h((e) => e / b)(a),
708-
vec2i: (a: wgsl.v2i, b: number) =>
709-
unary2i((e) => NumberOps.divInteger(e, b))(a),
710-
vec2u: (a: wgsl.v2u, b: number) =>
711-
unary2u((e) => NumberOps.divInteger(e, b))(a),
752+
vec2i: (a: wgsl.v2i, b: number) => unary2i((e) => divInteger(e, b))(a),
753+
vec2u: (a: wgsl.v2u, b: number) => unary2u((e) => divInteger(e, b))(a),
712754

713755
vec3f: (a: wgsl.v3f, b: number) => unary3f((e) => e / b)(a),
714756
vec3h: (a: wgsl.v3h, b: number) => unary3h((e) => e / b)(a),
715-
vec3i: (a: wgsl.v3i, b: number) =>
716-
unary3i((e) => NumberOps.divInteger(e, b))(a),
717-
vec3u: (a: wgsl.v3u, b: number) =>
718-
unary3u((e) => NumberOps.divInteger(e, b))(a),
757+
vec3i: (a: wgsl.v3i, b: number) => unary3i((e) => divInteger(e, b))(a),
758+
vec3u: (a: wgsl.v3u, b: number) => unary3u((e) => divInteger(e, b))(a),
719759

720760
vec4f: (a: wgsl.v4f, b: number) => unary4f((e) => e / b)(a),
721761
vec4h: (a: wgsl.v4h, b: number) => unary4h((e) => e / b)(a),
722-
vec4i: (a: wgsl.v4i, b: number) =>
723-
unary4i((e) => NumberOps.divInteger(e, b))(a),
724-
vec4u: (a: wgsl.v4u, b: number) =>
725-
unary4u((e) => NumberOps.divInteger(e, b))(a),
762+
vec4i: (a: wgsl.v4i, b: number) => unary4i((e) => divInteger(e, b))(a),
763+
vec4u: (a: wgsl.v4u, b: number) => unary4u((e) => divInteger(e, b))(a),
726764
} as Record<VecKind, <T extends vBase>(lhs: T, rhs: number) => T>,
727765

728766
dot: {

packages/typegpu/src/std/index.ts

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,7 @@ export {
3737
reflect,
3838
sign,
3939
sin,
40+
smoothstep,
4041
sqrt,
4142
tanh
4243
} from './numeric.ts';

packages/typegpu/src/std/numeric.ts

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import { vecTypeToConstructor } from '../data/vector.ts';
22
import { type AnyData, snip, type Snippet } from '../data/dataTypes.ts';
3+
import { smoothstepScalar } from '../data/numberOps.ts';
34
import { f32 } from '../data/numeric.ts';
45
import { VectorOps } from '../data/vectorOps.ts';
56
import {
@@ -764,3 +765,25 @@ export const tanh = createDualImpl(
764765
(value) => snip(`tanh(${value.value})`, value.dataType),
765766
'tanh',
766767
);
768+
769+
export const smoothstep = createDualImpl(
770+
// CPU implementation
771+
<T extends AnyFloatVecInstance | number>(edge0: T, edge1: T, x: T): T => {
772+
if (typeof x === 'number') {
773+
return smoothstepScalar(
774+
edge0 as number,
775+
edge1 as number,
776+
x as number,
777+
) as T;
778+
}
779+
return VectorOps.smoothstep[x.kind](
780+
edge0 as AnyFloatVecInstance,
781+
edge1 as AnyFloatVecInstance,
782+
x as AnyFloatVecInstance,
783+
) as T;
784+
},
785+
// GPU implementation
786+
(edge0, edge1, x) =>
787+
snip(`smoothstep(${edge0.value}, ${edge1.value}, ${x.value})`, x.dataType),
788+
'smoothstep',
789+
);
Lines changed: 147 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,147 @@
1+
import { describe, expect, it } from 'vitest';
2+
import { vec2f, vec3f, vec4f } from '../../../src/data/index.ts';
3+
import { smoothstep } from '../../../src/std/index.ts';
4+
import { isCloseTo } from '../../../src/std/boolean.ts';
5+
6+
describe('smoothstep', () => {
7+
it('returns 0 when x is less than or equal to edge0', () => {
8+
expect(smoothstep(0, 1, -0.5)).toBe(0);
9+
expect(smoothstep(0, 1, 0)).toBe(0);
10+
});
11+
12+
it('returns 1 when x is greater than or equal to edge1', () => {
13+
expect(smoothstep(0, 1, 1)).toBe(1);
14+
expect(smoothstep(0, 1, 1.5)).toBe(1);
15+
});
16+
17+
it('returns smoothly interpolated value between 0 and 1 when x is between edge0 and edge1', () => {
18+
expect(smoothstep(0, 1, 0.25)).toBeCloseTo(0.15625); // t = 0.25, t² * (3 - 2t) = 0.15625
19+
expect(smoothstep(0, 1, 0.5)).toBeCloseTo(0.5); // t = 0.5, t² * (3 - 2t) = 0.5
20+
expect(smoothstep(0, 1, 0.75)).toBeCloseTo(0.84375); // t = 0.75, t² * (3 - 2t) = 0.84375
21+
});
22+
23+
it('works with vec2f', () => {
24+
// all components less than edge0
25+
const result1 = smoothstep(
26+
vec2f(0.3, 0.5),
27+
vec2f(0.8, 0.9),
28+
vec2f(0.1, 0.2),
29+
);
30+
expect(isCloseTo(result1, vec2f(0, 0))).toBe(true);
31+
32+
// all components greater than edge1
33+
const result2 = smoothstep(
34+
vec2f(0.3, 0.5),
35+
vec2f(0.8, 0.9),
36+
vec2f(0.9, 1.0),
37+
);
38+
expect(isCloseTo(result2, vec2f(1, 1))).toBe(true);
39+
40+
// components between edge0 and edge1
41+
const result3 = smoothstep(
42+
vec2f(0.3, 0.5),
43+
vec2f(0.8, 0.9),
44+
vec2f(0.55, 0.7),
45+
);
46+
expect(isCloseTo(result3, vec2f(0.5, 0.5))).toBe(true);
47+
48+
// mixed results
49+
const result4 = smoothstep(
50+
vec2f(0.3, 0.5),
51+
vec2f(0.8, 0.9),
52+
vec2f(0.2, 0.95),
53+
);
54+
expect(isCloseTo(result4, vec2f(0, 1))).toBe(true);
55+
});
56+
57+
it('works with vec3f', () => {
58+
const result = smoothstep(
59+
vec3f(0.1, 0.2, 0.3),
60+
vec3f(0.6, 0.7, 0.8),
61+
vec3f(0.35, 0.7, 0.55),
62+
);
63+
expect(isCloseTo(result, vec3f(0.5, 1, 0.5))).toBe(true);
64+
});
65+
66+
it('works with vec4f', () => {
67+
const result = smoothstep(
68+
vec4f(0.0, 0.1, 0.2, 0.3),
69+
vec4f(1.0, 0.9, 0.8, 0.7),
70+
vec4f(0.0, 0.5, 0.8, 0.7),
71+
);
72+
expect(isCloseTo(result, vec4f(0, 0.5, 1, 1))).toBe(true);
73+
});
74+
75+
it('works with vector edges with same values for all components', () => {
76+
const result = smoothstep(
77+
vec3f(0.3, 0.3, 0.3),
78+
vec3f(0.7, 0.7, 0.7),
79+
vec3f(0.2, 0.5, 0.8),
80+
);
81+
// For first component: x < edge0, result = 0
82+
// For second component: t = (0.5-0.3)/(0.7-0.3) = 0.5, result = 0.5
83+
// For third component: x > edge1, result = 1
84+
expect(isCloseTo(result, vec3f(0, 0.5, 1))).toBe(true);
85+
});
86+
87+
it('handles edge case with equal edge values', () => {
88+
expect(smoothstep(0.5, 0.5, 0.4)).toBe(0);
89+
expect(smoothstep(0.5, 0.5, 0.5)).toBe(0);
90+
expect(smoothstep(0.5, 0.5, 0.6)).toBe(0);
91+
92+
const result = smoothstep(
93+
vec2f(0.5, 0.3),
94+
vec2f(0.5, 0.3),
95+
vec2f(0.4, 0.3),
96+
);
97+
expect(isCloseTo(result, vec2f(0, 1))).toBe(false);
98+
});
99+
100+
it('handles reversed edge values (edge0 > edge1)', () => {
101+
expect(smoothstep(0.8, 0.2, 0.5)).toBeCloseTo(0.5);
102+
103+
const result = smoothstep(
104+
vec2f(0.8, 0.7),
105+
vec2f(0.2, 0.3),
106+
vec2f(0.5, 0.5),
107+
);
108+
expect(isCloseTo(result, vec2f(0.5, 0.5))).toBe(true);
109+
});
110+
111+
it('handles negative values correctly', () => {
112+
expect(smoothstep(-2, -1, -1.5)).toBeCloseTo(0.5);
113+
expect(smoothstep(-10, -5, -20)).toBe(0);
114+
expect(smoothstep(-10, -5, 0)).toBe(1);
115+
116+
const result = smoothstep(
117+
vec3f(-1, 0, -10),
118+
vec3f(1, 1, -5),
119+
vec3f(0, 0.5, -7.5),
120+
);
121+
expect(isCloseTo(result, vec3f(0.5, 0.5, 0.5))).toBe(true);
122+
});
123+
124+
it('handles very small differences between edges', () => {
125+
expect(smoothstep(0.5000, 0.5001, 0.50005)).toBeCloseTo(0.5);
126+
127+
const result = smoothstep(
128+
vec2f(0.1, 0.2),
129+
vec2f(0.1001, 0.2001),
130+
vec2f(0.1, 0.20005),
131+
);
132+
expect(isCloseTo(result, vec2f(0, 0.5))).toBe(true);
133+
});
134+
135+
it('handles extreme values', () => {
136+
expect(smoothstep(1000, 2000, 1500)).toBeCloseTo(0.5);
137+
138+
expect(smoothstep(0.00001, 0.00002, 0.000015)).toBeCloseTo(0.5);
139+
140+
const result = smoothstep(
141+
vec3f(0.00001, 1000, 0),
142+
vec3f(0.00002, 2000, 1),
143+
vec3f(0.000015, 1500, 0.5),
144+
);
145+
expect(isCloseTo(result, vec3f(0.5, 0.5, 0.5))).toBe(true);
146+
});
147+
});

0 commit comments

Comments
 (0)