Skip to content

Commit 8ab8b13

Browse files
feat: Allow omitting types in wgsl header (#1306)
1 parent 84a4d9d commit 8ab8b13

File tree

4 files changed

+328
-11
lines changed

4 files changed

+328
-11
lines changed

packages/typegpu/src/core/function/extractArgs.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -145,11 +145,11 @@ function strip(
145145
};
146146
}
147147

148-
if (code.isAt('(') && !argsStart) {
148+
if (code.isAt('(') && argsStart === undefined) {
149149
argsStart = code.pos;
150150
}
151151

152-
if (argsStart) {
152+
if (argsStart !== undefined) {
153153
strippedCode += code.str[code.pos];
154154
}
155155
code.advanceBy(1); // parsed character

packages/typegpu/src/core/function/fnCore.ts

Lines changed: 72 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ import {
1818
type ExternalMap,
1919
replaceExternalsInWgsl,
2020
} from '../resolve/externals.ts';
21+
import { extractArgs } from './extractArgs.ts';
2122
import type { Implementation } from './fnTypes.ts';
2223

2324
export interface TgpuFnShellBase<Args extends unknown[], Return> {
@@ -81,31 +82,70 @@ export function createFnCore(
8182
const id = ctx.names.makeUnique(getName(this));
8283

8384
if (typeof implementation === 'string') {
85+
const replacedImpl = replaceExternalsInWgsl(
86+
ctx,
87+
externalMap,
88+
implementation,
89+
);
90+
8491
let header = '';
92+
let body = '';
8593

8694
if (shell.isEntry) {
87-
const input = isWgslStruct(shell.argTypes[0]) ? '(in: In)' : '()';
95+
const input = isWgslStruct(shell.argTypes[0])
96+
? `(in: ${ctx.resolve(shell.argTypes[0])})`
97+
: '()';
8898

8999
const attributes = isWgslData(shell.returnType)
90100
? getAttributesString(shell.returnType)
91101
: '';
92102
const output = shell.returnType !== Void
93103
? isWgslStruct(shell.returnType)
94-
? '-> Out'
104+
? `-> ${ctx.resolve(shell.returnType)}`
95105
: `-> ${attributes !== '' ? attributes : '@location(0)'} ${
96106
ctx.resolve(shell.returnType)
97107
}`
98108
: '';
109+
99110
header = `${input} ${output} `;
111+
body = replacedImpl;
112+
} else {
113+
const providedArgs = extractArgs(replacedImpl);
114+
115+
if (providedArgs.args.length !== shell.argTypes.length) {
116+
throw new Error(
117+
`WGSL implementation has ${providedArgs.args.length} arguments, while the shell has ${shell.argTypes.length} arguments.`,
118+
);
119+
}
120+
121+
const input = providedArgs.args.map((argInfo, i) =>
122+
`${argInfo.identifier}: ${
123+
checkAndReturnType(
124+
ctx,
125+
`parameter ${argInfo.identifier}`,
126+
argInfo.type,
127+
shell.argTypes[i],
128+
)
129+
}`
130+
).join(', ');
131+
132+
const output = shell.returnType === Void
133+
? ''
134+
: `-> ${
135+
checkAndReturnType(
136+
ctx,
137+
'return type',
138+
providedArgs.ret?.type,
139+
shell.returnType,
140+
)
141+
}`;
142+
143+
header = `(${input}) ${output}`;
144+
145+
body = replacedImpl.slice(providedArgs.range.end);
100146
}
101147

102-
const replacedImpl = replaceExternalsInWgsl(
103-
ctx,
104-
externalMap,
105-
`${header}${implementation.trim()}`,
106-
);
107-
108-
ctx.addDeclaration(`${fnAttribute}fn ${id}${replacedImpl}`);
148+
ctx.addDeclaration(`${fnAttribute}fn ${id}${header}${body}`);
109149
} else {
110150
// get data generated by the plugin
111151
const pluginData = getMetaData(implementation);
@@ -176,3 +216,26 @@ export function createFnCore(
176216

177217
return core;
178218
}
219+
220+
function checkAndReturnType(
221+
ctx: ResolutionCtx,
222+
name: string,
223+
wgslType: string | undefined,
224+
jsType: unknown,
225+
) {
226+
const resolvedJsType = ctx.resolve(jsType).replace(/\s/g, '');
227+
228+
if (!wgslType) {
229+
return resolvedJsType;
230+
}
231+
232+
const resolvedWgslType = wgslType.replace(/\s/g, '');
233+
234+
if (resolvedWgslType !== resolvedJsType) {
235+
throw new Error(
236+
`Type mismatch between TGPU shell and WGSL code string: ${name}, JS type "${resolvedJsType}", WGSL type "${resolvedWgslType}".`,
237+
);
238+
}
239+
240+
return wgslType;
241+
}

packages/typegpu/tests/extractArgs.test.ts

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -331,4 +331,14 @@ describe('extract args', () => {
331331
expect(ret).toBeUndefined();
332332
expect(range).toStrictEqual({ begin: 21, end: 57 });
333333
});
334+
335+
it('extracts when no arguments, no name, no return type', () => {
336+
const wgslFn = /* wgsl */ '() { return 42; }';
337+
338+
const { args, ret, range } = extractArgs(wgslFn);
339+
340+
expect(args).toStrictEqual([]);
341+
expect(ret).toStrictEqual(undefined);
342+
expect(range).toStrictEqual({ begin: 0, end: 3 });
343+
});
334344
});

0 commit comments

Comments
 (0)