Skip to content

Commit b309609

Browse files
authored
Merge branch 'main' into fix/cleanup-infer-return
2 parents c26cd10 + 3ac9bd8 commit b309609

File tree

17 files changed

+1115
-268
lines changed

17 files changed

+1115
-268
lines changed

apps/typegpu-docs/src/content/examples/rendering/3d-fish/render.ts

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -19,8 +19,7 @@ export const vertexShader = tgpu['~unstable']
1919
// https://simple.wikipedia.org/wiki/Pitch,_yaw,_and_roll
2020
const currentModelData = layout.$.modelData[input.instanceIndex];
2121

22-
// apply sin wave
23-
22+
// apply sin wave to imitate swimming motion
2423
let wavedVertex = PosAndNormal({
2524
position: input.modelPosition,
2625
normal: input.modelNormal,
@@ -37,43 +36,37 @@ export const vertexShader = tgpu['~unstable']
3736
}
3837

3938
// rotate model
40-
4139
const direction = std.normalize(currentModelData.direction);
40+
const yaw = -std.atan2(direction.z, direction.x) + Math.PI;
41+
const pitch = std.asin(-direction.y);
4242

43-
const yaw = std.atan2(direction.z, direction.x) + Math.PI;
44-
// deno-fmt-ignore
45-
const yawMatrix = d.mat3x3f(
46-
std.cos(yaw), 0, std.sin(yaw),
47-
0, 1, 0,
48-
-std.sin(yaw), 0, std.cos(yaw),
49-
);
50-
51-
const pitch = -std.asin(-direction.y);
52-
// deno-fmt-ignore
53-
const pitchMatrix = d.mat3x3f(
54-
std.cos(pitch), -std.sin(pitch), 0,
55-
std.sin(pitch), std.cos(pitch), 0,
56-
0, 0, 1,
57-
);
43+
const scaleMatrix = d.mat4x4f.scaling(d.vec3f(currentModelData.scale));
44+
const pitchMatrix = d.mat4x4f.rotationZ(pitch);
45+
const yawMatrix = d.mat4x4f.rotationY(yaw);
46+
const translationMatrix = d.mat4x4f.translation(currentModelData.position);
5847

59-
const worldPosition = std.add(
48+
const worldPosition = std.mul(
49+
translationMatrix,
6050
std.mul(
6151
yawMatrix,
6252
std.mul(
6353
pitchMatrix,
64-
std.mul(currentModelData.scale, wavedVertex.position),
54+
std.mul(
55+
scaleMatrix,
56+
d.vec4f(wavedVertex.position, 1),
57+
),
6558
),
6659
),
67-
currentModelData.position,
6860
);
6961

7062
// calculate where the normal vector points to
7163
const worldNormal = std.normalize(
72-
std.mul(pitchMatrix, std.mul(yawMatrix, wavedVertex.normal)),
64+
std.mul(yawMatrix, std.mul(pitchMatrix, d.vec4f(wavedVertex.normal, 1)))
65+
.xyz,
7366
);
7467

7568
// project the world position into the camera
76-
const worldPositionUniform = d.vec4f(worldPosition.xyz, 1);
69+
const worldPositionUniform = worldPosition;
7770
const canvasPosition = std.mul(
7871
layout.$.camera.projection,
7972
std.mul(layout.$.camera.view, worldPositionUniform),
@@ -83,7 +76,7 @@ export const vertexShader = tgpu['~unstable']
8376
canvasPosition: canvasPosition,
8477
textureUV: input.textureUV,
8578
worldNormal: worldNormal,
86-
worldPosition: worldPosition,
79+
worldPosition: worldPosition.xyz,
8780
applySeaFog: currentModelData.applySeaFog,
8881
applySeaDesaturation: currentModelData.applySeaDesaturation,
8982
variant: currentModelData.variant,

apps/typegpu-docs/src/content/examples/tests/tgsl-parsing-test/index.ts

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import tgpu from 'typegpu';
22
import * as d from 'typegpu/data';
33
import { logicalExpressionTests } from './logical-expressions.ts';
4+
import { matrixOpsTests } from './matrix-ops.ts';
45

56
const root = await tgpu.init();
67
const result = root['~unstable'].createMutable(d.i32, 0);
@@ -9,6 +10,7 @@ const computeRunTests = tgpu['~unstable']
910
.computeFn({ workgroupSize: [1] })(() => {
1011
let s = true;
1112
s = s && logicalExpressionTests();
13+
s = s && matrixOpsTests();
1214

1315
if (s) {
1416
result.value = 1;
Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,146 @@
1+
import tgpu from 'typegpu';
2+
import * as d from 'typegpu/data';
3+
import * as std from 'typegpu/std';
4+
5+
// TODO: replace `s = s &&` with `s &&=` when implemented
6+
export const matrixOpsTests = tgpu['~unstable']
7+
.fn([], d.bool)(() => {
8+
let s = true;
9+
10+
s = s &&
11+
std.isCloseTo(
12+
std.mul(d.mat4x4f.translation(d.vec3f(-1, 0, 1)), d.vec4f(1, 2, 3, 1)),
13+
d.vec4f(0, 2, 4, 1),
14+
);
15+
16+
s = s &&
17+
std.isCloseTo(
18+
std.mul(d.mat4x4f.scaling(d.vec3f(-1, 0, 1)), d.vec4f(1, 2, 3, 1)),
19+
d.vec4f(-1, 0, 3, 1),
20+
);
21+
22+
s = s &&
23+
std.isCloseTo(
24+
std.mul(d.mat4x4f.rotationX(Math.PI / 2), d.vec4f(1, 2, 3, 1)),
25+
d.vec4f(1, -3, 2, 1),
26+
);
27+
28+
s = s &&
29+
std.isCloseTo(
30+
std.mul(d.mat4x4f.rotationY(Math.PI / 2), d.vec4f(1, 2, 3, 1)),
31+
d.vec4f(3, 2, -1, 1),
32+
);
33+
34+
s = s &&
35+
std.isCloseTo(
36+
std.mul(d.mat4x4f.rotationZ(Math.PI / 2), d.vec4f(1, 2, 3, 1)),
37+
d.vec4f(-2, 1, 3, 1),
38+
);
39+
40+
s = s &&
41+
std.isCloseTo(
42+
std.mul(
43+
std.translate4(d.mat4x4f.identity(), d.vec3f(-1, 0, 1)),
44+
d.vec4f(1, 2, 3, 1),
45+
),
46+
d.vec4f(0, 2, 4, 1),
47+
);
48+
49+
s = s &&
50+
std.isCloseTo(
51+
std.mul(
52+
std.scale4(d.mat4x4f.identity(), d.vec3f(-1, 0, 1)),
53+
d.vec4f(1, 2, 3, 1),
54+
),
55+
d.vec4f(-1, 0, 3, 1),
56+
);
57+
58+
s = s &&
59+
std.isCloseTo(
60+
std.mul(
61+
std.rotateX4(d.mat4x4f.identity(), Math.PI / 2),
62+
d.vec4f(1, 2, 3, 1),
63+
),
64+
d.vec4f(1, -3, 2, 1),
65+
);
66+
67+
s = s &&
68+
std.isCloseTo(
69+
std.mul(
70+
std.rotateY4(d.mat4x4f.identity(), Math.PI / 2),
71+
d.vec4f(1, 2, 3, 1),
72+
),
73+
d.vec4f(3, 2, -1, 1),
74+
);
75+
76+
s = s &&
77+
std.isCloseTo(
78+
std.mul(
79+
std.rotateZ4(d.mat4x4f.identity(), Math.PI / 2),
80+
d.vec4f(1, 2, 3, 1),
81+
),
82+
d.vec4f(-2, 1, 3, 1),
83+
);
84+
85+
s = s &&
86+
std.isCloseTo(
87+
std.mul(
88+
std.rotateZ4(
89+
std.rotateX4(d.mat4x4f.identity(), Math.PI / 2),
90+
Math.PI / 2,
91+
),
92+
d.vec4f(1, 0, 0, 1),
93+
),
94+
d.vec4f(0, 1, 0, 1),
95+
);
96+
97+
s = s &&
98+
std.isCloseTo(
99+
std.mul(
100+
std.rotateX4(
101+
std.rotateZ4(d.mat4x4f.identity(), Math.PI / 2),
102+
Math.PI / 2,
103+
),
104+
d.vec4f(1, 0, 0, 1),
105+
),
106+
d.vec4f(0, 0, 1, 1),
107+
);
108+
109+
s = s &&
110+
std.isCloseTo(
111+
std.mul(
112+
std.translate4(
113+
std.scale4(d.mat4x4f.identity(), d.vec3f(2, 3, 4)),
114+
d.vec3f(0, 1, 0),
115+
),
116+
d.vec4f(1, 0, 0, 1),
117+
),
118+
d.vec4f(2, 1, 0, 1),
119+
);
120+
121+
s = s &&
122+
std.isCloseTo(
123+
std.mul(
124+
std.scale4(
125+
std.translate4(d.mat4x4f.identity(), d.vec3f(0, 1, 0)),
126+
d.vec3f(2, 3, 4),
127+
),
128+
d.vec4f(0, 0, 0, 1),
129+
),
130+
d.vec4f(0, 3, 0, 1),
131+
);
132+
133+
s = s &&
134+
std.isCloseTo(
135+
std.mul(
136+
std.rotateZ4(
137+
std.rotateY4(d.mat4x4f.identity(), Math.PI / 2),
138+
Math.PI / 2,
139+
),
140+
d.vec4f(0, 1, 0, 1),
141+
),
142+
d.vec4f(-1, 0, 0, 1),
143+
);
144+
145+
return s;
146+
});

packages/typegpu/src/core/function/extractArgs.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,11 @@ function strip(
145145
};
146146
}
147147

148-
if (code.isAt('(') && !argsStart) {
148+
if (code.isAt('(') && argsStart === undefined) {
149149
argsStart = code.pos;
150150
}
151151

152-
if (argsStart) {
152+
if (argsStart !== undefined) {
153153
strippedCode += code.str[code.pos];
154154
}
155155
code.advanceBy(1); // parsed character

packages/typegpu/src/core/function/fnCore.ts

Lines changed: 72 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import {
1818
type ExternalMap,
1919
replaceExternalsInWgsl,
2020
} from '../resolve/externals.ts';
21+
import { extractArgs } from './extractArgs.ts';
2122
import type { Implementation } from './fnTypes.ts';
2223

2324
export interface TgpuFnShellBase<Args extends unknown[], Return> {
@@ -81,31 +82,70 @@ export function createFnCore(
8182
const id = ctx.names.makeUnique(getName(this));
8283

8384
if (typeof implementation === 'string') {
85+
const replacedImpl = replaceExternalsInWgsl(
86+
ctx,
87+
externalMap,
88+
implementation,
89+
);
90+
8491
let header = '';
92+
let body = '';
8593

8694
if (shell.isEntry) {
87-
const input = isWgslStruct(shell.argTypes[0]) ? '(in: In)' : '()';
95+
const input = isWgslStruct(shell.argTypes[0])
96+
? `(in: ${ctx.resolve(shell.argTypes[0])})`
97+
: '()';
8898

8999
const attributes = isWgslData(shell.returnType)
90100
? getAttributesString(shell.returnType)
91101
: '';
92102
const output = shell.returnType !== Void
93103
? isWgslStruct(shell.returnType)
94-
? '-> Out'
104+
? `-> ${ctx.resolve(shell.returnType)}`
95105
: `-> ${attributes !== '' ? attributes : '@location(0)'} ${
96106
ctx.resolve(shell.returnType)
97107
}`
98108
: '';
109+
99110
header = `${input} ${output} `;
111+
body = replacedImpl;
112+
} else {
113+
const providedArgs = extractArgs(replacedImpl);
114+
115+
if (providedArgs.args.length !== shell.argTypes.length) {
116+
throw new Error(
117+
`WGSL implementation has ${providedArgs.args.length} arguments, while the shell has ${shell.argTypes.length} arguments.`,
118+
);
119+
}
120+
121+
const input = providedArgs.args.map((argInfo, i) =>
122+
`${argInfo.identifier}: ${
123+
checkAndReturnType(
124+
ctx,
125+
`parameter ${argInfo.identifier}`,
126+
argInfo.type,
127+
shell.argTypes[i],
128+
)
129+
}`
130+
).join(', ');
131+
132+
const output = shell.returnType === Void
133+
? ''
134+
: `-> ${
135+
checkAndReturnType(
136+
ctx,
137+
'return type',
138+
providedArgs.ret?.type,
139+
shell.returnType,
140+
)
141+
}`;
142+
143+
header = `(${input}) ${output}`;
144+
145+
body = replacedImpl.slice(providedArgs.range.end);
100146
}
101147

102-
const replacedImpl = replaceExternalsInWgsl(
103-
ctx,
104-
externalMap,
105-
`${header}${implementation.trim()}`,
106-
);
107-
108-
ctx.addDeclaration(`${fnAttribute}fn ${id}${replacedImpl}`);
148+
ctx.addDeclaration(`${fnAttribute}fn ${id}${header}${body}`);
109149
} else {
110150
// get data generated by the plugin
111151
const pluginData = getMetaData(implementation);
@@ -176,3 +216,26 @@ export function createFnCore(
176216

177217
return core;
178218
}
219+
220+
function checkAndReturnType(
221+
ctx: ResolutionCtx,
222+
name: string,
223+
wgslType: string | undefined,
224+
jsType: unknown,
225+
) {
226+
const resolvedJsType = ctx.resolve(jsType).replace(/\s/g, '');
227+
228+
if (!wgslType) {
229+
return resolvedJsType;
230+
}
231+
232+
const resolvedWgslType = wgslType.replace(/\s/g, '');
233+
234+
if (resolvedWgslType !== resolvedJsType) {
235+
throw new Error(
236+
`Type mismatch between TGPU shell and WGSL code string: ${name}, JS type "${resolvedJsType}", WGSL type "${resolvedWgslType}".`,
237+
);
238+
}
239+
240+
return wgslType;
241+
}

0 commit comments

Comments
 (0)