Skip to content

Commit 867eeae

Browse files
authored
Merge branch 'main' into feat/array-dot-length
2 parents c8a6024 + 634d1b1 commit 867eeae

File tree

2 files changed

+107
-1
lines changed

2 files changed

+107
-1
lines changed

packages/typegpu/src/core/function/tgpuFn.ts

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ import type { AnyWgslData } from '../../data/wgslTypes';
33
import type { TgpuNamable } from '../../namable';
44
import { createDualImpl } from '../../shared/generators';
55
import { $internal } from '../../shared/symbols';
6+
import type { GenerationCtx } from '../../smol/wgslGenerator';
67
import {
78
type Labelled,
89
type ResolutionCtx,
@@ -155,7 +156,23 @@ function createFn<
155156
},
156157

157158
'~resolve'(ctx: ResolutionCtx): string {
158-
return core.resolve(ctx);
159+
if (typeof implementation === 'string') {
160+
return core.resolve(ctx);
161+
}
162+
163+
const generationCtx = ctx as GenerationCtx;
164+
if (generationCtx.callStack === undefined) {
165+
throw new Error(
166+
'Cannot resolve a TGSL function outside of a generation context',
167+
);
168+
}
169+
170+
try {
171+
generationCtx.callStack.push(shell.returnType);
172+
return core.resolve(ctx);
173+
} finally {
174+
generationCtx.callStack.pop();
175+
}
159176
},
160177
};
161178

packages/typegpu/tests/tgslFn.test.ts

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,6 +344,95 @@ describe('TGSL tgpu.fn function', () => {
344344
expect(actual).toEqual(expected);
345345
});
346346

347+
it('allows for an object based on return type struct to be returned', () => {
348+
const TestStruct = struct({
349+
a: f32,
350+
b: f32,
351+
c: vec2f,
352+
}).$name('TestStruct');
353+
354+
const fn = tgpu['~unstable']
355+
.fn([], TestStruct)
356+
.does(() => {
357+
return {
358+
a: 1,
359+
b: 2,
360+
c: vec2f(3, 4),
361+
};
362+
})
363+
.$name('test_struct');
364+
365+
const actual = parseResolved({ fn });
366+
367+
const expected = parse(`
368+
struct TestStruct {
369+
a: f32,
370+
b: f32,
371+
c: vec2f,
372+
}
373+
374+
fn test_struct() -> TestStruct {
375+
return TestStruct(1, 2, vec2f(3, 4));
376+
}
377+
`);
378+
379+
expect(actual).toEqual(expected);
380+
});
381+
382+
it('correctly handles object based on return type struct with a function call inside another function', () => {
383+
const TestStruct = struct({
384+
a: f32,
385+
b: f32,
386+
c: vec2f,
387+
}).$name('TestStruct');
388+
389+
const fn = tgpu['~unstable']
390+
.fn([], TestStruct)
391+
.does(() => {
392+
return {
393+
a: 1,
394+
b: 2,
395+
c: vec2f(3, 4),
396+
};
397+
})
398+
.$name('test_struct');
399+
400+
const fn2 = tgpu['~unstable']
401+
.computeFn({
402+
in: { gid: builtin.globalInvocationId },
403+
workgroupSize: [24],
404+
})
405+
.does((input) => {
406+
const testStruct = fn();
407+
})
408+
.$name('compute_fn');
409+
410+
const actual = parseResolved({ fn2 });
411+
412+
const expected = parse(`
413+
struct TestStruct {
414+
a: f32,
415+
b: f32,
416+
c: vec2f,
417+
}
418+
419+
fn test_struct() -> TestStruct {
420+
return TestStruct(1, 2, vec2f(3, 4));
421+
}
422+
423+
struct compute_fn_Input {
424+
@builtin(global_invocation_id) gid: vec3u,
425+
}
426+
427+
@compute @workgroup_size(24)
428+
fn compute_fn(input: compute_fn_Input) {
429+
var testStruct = test_struct();
430+
}
431+
`);
432+
433+
expect(actual).toEqual(expected);
434+
});
435+
347436
// TODO: Add this back when we can properly infer ast types (and implement appropriate behavior for pointers)
348437
// it('resolves a function with a pointer parameter', () => {
349438
// const addOnes = tgpu['~unstable']

0 commit comments

Comments
 (0)