Skip to content

feat: Expected type stack in WgslGenerator #1532

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 68 commits into from
Jul 31, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
447ece2
Add struct default constructor
Jul 11, 2025
85e26bd
Update docs
Jul 11, 2025
4a0ee79
Merge branch 'main' into feat/struct-default-constructor
aleksanderkatan Jul 11, 2025
a51d6b6
Update packages/typegpu/tests/struct.test.ts
aleksanderkatan Jul 14, 2025
bd03e7d
Merge branch 'main' into feat/struct-default-constructor
aleksanderkatan Jul 15, 2025
fb59992
Update packages/typegpu/src/tgsl/wgslGenerator.ts
aleksanderkatan Jul 15, 2025
4aed513
Change [1] to [0]
Jul 15, 2025
9e13516
Merge branch 'feat/struct-default-constructor' into feat/make-array-s…
Jul 16, 2025
2b6ce6a
Make array schemas callable (but they do nothing)
Jul 16, 2025
89653aa
Move the struct resolve test to wgslGenerator.test.ts
Jul 16, 2025
8d49f55
Move tests back to struct.test.ts
Jul 16, 2025
f9f397b
Add array tests
Jul 16, 2025
32df513
Implement JS side of array calls
Jul 16, 2025
9e121dc
Implement WGSL side
Jul 16, 2025
7a18c01
Change u32 to i32
Jul 17, 2025
d3dbeff
Merge remote-tracking branch 'origin/main' into feat/make-array-schem…
Jul 17, 2025
154743c
Add tgsl parsing tests
Jul 17, 2025
1817dee
Remove unnecessary (I hope so) parentheses
Jul 17, 2025
5595381
Docs
Jul 17, 2025
4a03948
Lint
Jul 17, 2025
3db0533
Merge remote-tracking branch 'origin/main' into feat/make-array-schem…
Jul 17, 2025
dc7c91c
Fix import
Jul 17, 2025
ec64007
Lint
Jul 17, 2025
48a2f6a
Nits
Jul 17, 2025
29e952c
Merge remote-tracking branch 'origin/main' into feat/make-array-schem…
Jul 18, 2025
eb207c1
Remove generic from struct constructor, update docs
Jul 22, 2025
dea2917
Merge remote-tracking branch 'origin/main' into feat/make-array-schem…
Jul 22, 2025
dbe2f9b
Apply suggestions from code review
aleksanderkatan Jul 22, 2025
5cb0913
Update array toString, override dualImpl toString
Jul 22, 2025
34c0d31
Add some tests
Jul 22, 2025
eaa0c9b
Add expectedTypeStack
Jul 22, 2025
e9cbdbd
Merge branch 'feat/make-array-schemas-callable' into feat/passing-the…
Jul 22, 2025
c60e12d
Use expectedTypeStack in struct call
Jul 23, 2025
4fd127d
Use expectedTypeStack for returns and nested structs
Jul 23, 2025
be63fc2
Remove callStack from return value generation, comment out some weird…
Jul 23, 2025
6857afe
Remove the arbitrary object cast to output type
Jul 23, 2025
c4c4f87
CallStack is kil
Jul 23, 2025
ce897cb
Remove legacy calls to stack from tgpuFn
Jul 23, 2025
8a86db4
Remove legacy calls from fragmentFn and tgpuFn
Jul 24, 2025
e1d615f
Add array support
Jul 24, 2025
47b01fb
Add type coercion, apply type coercion to return value
Jul 24, 2025
8d5c717
Merge remote-tracking branch 'origin/main' into feat/passing-the-type…
Jul 24, 2025
152350f
Merge fixes
Jul 24, 2025
a17451b
Move struct/array handling
Jul 24, 2025
311c438
Cleanup struct/array call
Jul 24, 2025
67372f7
Refactor
Jul 25, 2025
5a43b33
Add function support
Jul 25, 2025
13d7e82
Add bool to the stack for if/for/while conditions
Jul 25, 2025
b7ec863
Merge remote-tracking branch 'origin/main' into feat/passing-the-type…
Jul 25, 2025
0f25a28
Bring back commented test 1
Jul 25, 2025
594913f
Bring back commented tests 2 & 3
Jul 25, 2025
7ed4db9
Refactor array generation, add better comments
Jul 25, 2025
65bcb6c
Rename `argTypes` to `argConversionHint`
Jul 25, 2025
1f09942
Remove UnknownData from stack type
Jul 25, 2025
3f21f2f
Docs for expectedTypeStack
Jul 25, 2025
215008e
Lint
Jul 25, 2025
59f2713
Update error messages
Jul 25, 2025
c13d847
Nits
Jul 25, 2025
112465c
Better convert to null message
Jul 25, 2025
fe02111
Merge branch 'main' into feat/passing-the-type-down-the-expression-chain
iwoplaza Jul 30, 2025
42d3c62
Update tests to still create Output structs before return statement
iwoplaza Jul 30, 2025
4e3edf6
Unify type errors
iwoplaza Jul 30, 2025
dae51bf
Simpler tests, review fixes
iwoplaza Jul 31, 2025
038a186
Holding the expected type stack... on the stack
iwoplaza Jul 31, 2025
5c2048e
Tweaks
iwoplaza Jul 31, 2025
f2ac24b
Review fixes
iwoplaza Jul 31, 2025
8a6e947
Shuffle things around
iwoplaza Jul 31, 2025
fd7e0a4
Merge branch 'main' into feat/passing-the-type-down-the-expression-chain
iwoplaza Jul 31, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions packages/typegpu/src/core/function/dualImpl.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
import type { TgpuDualFn } from '../../data/dataTypes.ts';
import type { MapValueToSnippet, Snippet } from '../../data/snippet.ts';
import { inCodegenMode } from '../../execMode.ts';
import type { FnArgsConversionHint } from '../../types.ts';
import { setName } from '../../shared/meta.ts';
import { $internal } from '../../shared/symbols.ts';

export function createDualImpl<T extends (...args: never[]) => unknown>(
jsImpl: T,
gpuImpl: (...args: MapValueToSnippet<Parameters<T>>) => Snippet,
name: string,
argConversionHint: FnArgsConversionHint = 'keep',
): TgpuDualFn<T> {
const impl = ((...args: Parameters<T>) => {
if (inCodegenMode()) {
return gpuImpl(...(args as MapValueToSnippet<Parameters<T>>)) as Snippet;
}
return jsImpl(...args);
}) as T;

setName(impl, name);
impl.toString = () => name;
(impl as TgpuDualFn<T>)[$internal] = { jsImpl, gpuImpl, argConversionHint };

return impl as TgpuDualFn<T>;
}
25 changes: 5 additions & 20 deletions packages/typegpu/src/core/function/tgpuFn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ import { snip } from '../../data/snippet.ts';
import { Void } from '../../data/wgslTypes.ts';
import { ExecutionError } from '../../errors.ts';
import { provideInsideTgpuFn } from '../../execMode.ts';
import { createDualImpl } from '../../shared/generators.ts';
import type { TgpuNamable } from '../../shared/meta.ts';
import { getName, setName } from '../../shared/meta.ts';
import type { Infer } from '../../shared/repr.ts';
Expand All @@ -14,7 +13,6 @@ import {
$providing,
} from '../../shared/symbols.ts';
import type { Prettify } from '../../shared/utilityTypes.ts';
import type { GenerationCtx } from '../../tgsl/generationHelpers.ts';
import type {
FnArgsConversionHint,
ResolutionCtx,
Expand Down Expand Up @@ -43,6 +41,7 @@ import type {
InheritArgNames,
} from './fnTypes.ts';
import { stripTemplate } from './templateUtils.ts';
import { createDualImpl } from './dualImpl.ts';

// ----------
// Public API
Expand Down Expand Up @@ -85,7 +84,7 @@ export type TgpuFnShell<
interface TgpuFnBase<ImplSchema extends AnyFn> extends TgpuNamable {
readonly [$internal]: {
implementation: Implementation<ImplSchema>;
argTypes: FnArgsConversionHint;
argConversionHint: FnArgsConversionHint;
};
readonly resourceType: 'function';
readonly shell: TgpuFnShellHeader<
Expand Down Expand Up @@ -170,7 +169,7 @@ function createFn<ImplSchema extends AnyFn>(
const fnBase: This = {
[$internal]: {
implementation,
argTypes: shell.argTypes,
argConversionHint: shell.argTypes,
},
shell,
resourceType: 'function' as const,
Expand Down Expand Up @@ -207,23 +206,9 @@ function createFn<ImplSchema extends AnyFn>(
shell.returnType,
core.applyExternals,
);

return core.resolve(ctx, shell.argTypes, shell.returnType);
}

const generationCtx = ctx as GenerationCtx;
if (generationCtx.callStack === undefined) {
throw new Error(
'Cannot resolve a TGSL function outside of a generation context',
);
}

try {
generationCtx.callStack.push(shell.returnType);
return core.resolve(ctx, shell.argTypes, shell.returnType);
} finally {
generationCtx.callStack.pop();
}
return core.resolve(ctx, shell.argTypes, shell.returnType);
},
};

Expand Down Expand Up @@ -284,7 +269,7 @@ function createBoundFunction<ImplSchema extends AnyFn>(
const fnBase: This = {
[$internal]: {
implementation: innerFn[$internal].implementation,
argTypes: innerFn[$internal].argTypes,
argConversionHint: innerFn[$internal].argConversionHint,
},
resourceType: 'function',
shell: innerFn.shell,
Expand Down
31 changes: 5 additions & 26 deletions packages/typegpu/src/core/function/tgpuFragmentFn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ import {
type TgpuNamable,
} from '../../shared/meta.ts';
import { $getNameForward, $internal } from '../../shared/symbols.ts';
import type { GenerationCtx } from '../../tgsl/generationHelpers.ts';
import type { ResolutionCtx, SelfResolvable } from '../../types.ts';
import { addReturnTypeToExternals } from '../resolve/externals.ts';
import { createFnCore, type FnCore } from './fnCore.ts';
Expand Down Expand Up @@ -227,31 +226,11 @@ function createFragmentFn(
}
core.applyExternals({ Out: outputType });

if (typeof implementation === 'string') {
return core.resolve(
ctx,
inputWithLocation ? [inputWithLocation] : [],
shell.returnType,
);
}

const generationCtx = ctx as GenerationCtx;
if (generationCtx.callStack === undefined) {
throw new Error(
'Cannot resolve a TGSL function outside of a generation context',
);
}

try {
generationCtx.callStack.push(outputType);
return core.resolve(
ctx,
inputWithLocation ? [inputWithLocation] : [],
shell.returnType,
);
} finally {
generationCtx.callStack.pop();
}
return core.resolve(
ctx,
inputWithLocation ? [inputWithLocation] : [],
shell.returnType,
);
},

toString() {
Expand Down
29 changes: 5 additions & 24 deletions packages/typegpu/src/core/function/tgpuVertexFn.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ import {
type TgpuNamable,
} from '../../shared/meta.ts';
import { $getNameForward, $internal } from '../../shared/symbols.ts';
import type { GenerationCtx } from '../../tgsl/generationHelpers.ts';
import type { ResolutionCtx, SelfResolvable } from '../../types.ts';
import { createFnCore, type FnCore } from './fnCore.ts';
import type {
Expand Down Expand Up @@ -207,31 +206,13 @@ function createVertexFn(
core.applyExternals({ In: inputType });
}
core.applyExternals({ Out: outputWithLocation });

return core.resolve(
ctx,
shell.argTypes,
outputWithLocation,
);
}

const generationCtx = ctx as GenerationCtx;
if (generationCtx.callStack === undefined) {
throw new Error(
'Cannot resolve a TGSL function outside of a generation context',
);
}

try {
generationCtx.callStack.push(outputWithLocation);
return core.resolve(
ctx,
shell.argTypes,
outputWithLocation,
);
} finally {
generationCtx.callStack.pop();
}
return core.resolve(
ctx,
shell.argTypes,
outputWithLocation,
);
},

toString() {
Expand Down
2 changes: 1 addition & 1 deletion packages/typegpu/src/data/dataTypes.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ export type TgpuDualFn<TImpl extends (...args: never[]) => unknown> =
[$internal]: {
jsImpl: TImpl | string;
gpuImpl: (...args: MapValueToSnippet<Parameters<TImpl>>) => Snippet;
argTypes: FnArgsConversionHint;
argConversionHint: FnArgsConversionHint;
};
};

Expand Down
2 changes: 1 addition & 1 deletion packages/typegpu/src/data/matrix.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { createDualImpl } from '../shared/generators.ts';
import { createDualImpl } from '../core/function/dualImpl.ts';
import type { $repr } from '../shared/symbols.ts';
import { $internal } from '../shared/symbols.ts';
import type { SelfResolvable } from '../types.ts';
Expand Down
2 changes: 1 addition & 1 deletion packages/typegpu/src/data/numeric.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { createDualImpl } from '../shared/generators.ts';
import { createDualImpl } from '../core/function/dualImpl.ts';
import { $internal } from '../shared/symbols.ts';
import { snip } from './snippet.ts';
import type {
Expand Down
2 changes: 1 addition & 1 deletion packages/typegpu/src/data/vector.ts
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
import { createDualImpl } from '../shared/generators.ts';
import { createDualImpl } from '../core/function/dualImpl.ts';
import { $repr } from '../shared/symbols.ts';
import { snip } from './snippet.ts';
import { bool, f16, f32, i32, u32 } from './numeric.ts';
Expand Down
9 changes: 9 additions & 0 deletions packages/typegpu/src/errors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -197,3 +197,12 @@ export class IllegalBufferAccessError extends Error {
Object.setPrototypeOf(this, IllegalBufferAccessError.prototype);
}
}

export class WgslTypeError extends Error {
constructor(msg: string) {
super(msg);

// Set the prototype explicitly.
Object.setPrototypeOf(this, WgslTypeError.prototype);
}
}
14 changes: 13 additions & 1 deletion packages/typegpu/src/resolutionCtx.ts
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,14 @@ class ItemStateStackImpl implements ItemStateStack {
return state;
}

get topFunctionReturnType(): AnyData {
const scope = this._stack.findLast((e) => e.type === 'functionScope');
if (!scope) {
throw new Error('Internal error, expected function scope to be present.');
}
return scope.returnType;
}

pushItem() {
this._itemDepth++;
this._stack.push({
Expand Down Expand Up @@ -343,8 +351,8 @@ export class ResolutionCtxImpl implements ResolutionCtx {
public readonly fixedBindings: FixedBindingConfig[] = [];
// --

public readonly callStack: unknown[] = [];
public readonly names: NameRegistry;
public expectedType: AnyData | undefined;

constructor(opts: ResolutionCtxImplOptions) {
this.names = opts.names;
Expand All @@ -354,6 +362,10 @@ export class ResolutionCtxImpl implements ResolutionCtx {
return this._indentController.pre;
}

get topFunctionReturnType() {
return this._itemStateStack.topFunctionReturnType;
}

indent(): string {
return this._indentController.indent();
}
Expand Down
27 changes: 0 additions & 27 deletions packages/typegpu/src/shared/generators.ts
Original file line number Diff line number Diff line change
@@ -1,10 +1,3 @@
import type { TgpuDualFn } from '../data/dataTypes.ts';
import type { MapValueToSnippet, Snippet } from '../data/snippet.ts';
import { inCodegenMode } from '../execMode.ts';
import type { FnArgsConversionHint } from '../types.ts';
import { setName } from './meta.ts';
import { $internal } from './symbols.ts';

/**
* Yields values in the sequence 0,1,2..∞ except for the ones in the `excluded` set.
*/
Expand All @@ -21,23 +14,3 @@ export function* naturalsExcept(
next++;
}
}

export function createDualImpl<T extends (...args: never[]) => unknown>(
jsImpl: T,
gpuImpl: (...args: MapValueToSnippet<Parameters<T>>) => Snippet,
name: string,
argTypes?: FnArgsConversionHint,
): TgpuDualFn<T> {
const impl = ((...args: Parameters<T>) => {
if (inCodegenMode()) {
return gpuImpl(...(args as MapValueToSnippet<Parameters<T>>)) as Snippet;
}
return jsImpl(...args);
}) as T;

setName(impl, name);
impl.toString = () => name;
(impl as TgpuDualFn<T>)[$internal] = { jsImpl, gpuImpl, argTypes };

return impl as TgpuDualFn<T>;
}
2 changes: 1 addition & 1 deletion packages/typegpu/src/std/array.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ import { abstractInt, u32 } from '../data/numeric.ts';
import { ptrFn } from '../data/ptr.ts';
import type { AnyWgslData } from '../data/wgslTypes.ts';
import { isPtr, isWgslArray } from '../data/wgslTypes.ts';
import { createDualImpl } from '../shared/generators.ts';
import { createDualImpl } from '../core/function/dualImpl.ts';

export const arrayLength = createDualImpl(
// CPU implementation
Expand Down
2 changes: 1 addition & 1 deletion packages/typegpu/src/std/atomic.ts
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ import {
isWgslData,
Void,
} from '../data/wgslTypes.ts';
import { createDualImpl } from '../shared/generators.ts';
import { createDualImpl } from '../core/function/dualImpl.ts';
type AnyAtomic = atomicI32 | atomicU32;

export const workgroupBarrier = createDualImpl(
Expand Down
2 changes: 1 addition & 1 deletion packages/typegpu/src/std/boolean.ts
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ import {
type v3b,
type v4b,
} from '../data/wgslTypes.ts';
import { createDualImpl } from '../shared/generators.ts';
import { createDualImpl } from '../core/function/dualImpl.ts';
import { isSnippetNumeric, sub } from './numeric.ts';

function correspondingBooleanVectorSchema(value: Snippet) {
Expand Down
2 changes: 1 addition & 1 deletion packages/typegpu/src/std/discard.ts
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import { snip } from '../data/snippet.ts';
import { Void } from '../data/wgslTypes.ts';
import { createDualImpl } from '../shared/generators.ts';
import { createDualImpl } from '../core/function/dualImpl.ts';

export const discard = createDualImpl(
// CPU
Expand Down
2 changes: 1 addition & 1 deletion packages/typegpu/src/std/matrix.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import type { Snippet } from '../data/snippet.ts';
import { mat4x4f } from '../data/matrix.ts';
import type { m4x4f, v3f } from '../data/wgslTypes.ts';
import { createDualImpl } from '../shared/generators.ts';
import { createDualImpl } from '../core/function/dualImpl.ts';
import { mul } from './numeric.ts';

/**
Expand Down
10 changes: 5 additions & 5 deletions packages/typegpu/src/std/numeric.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ import {
type v4i,
type vBaseForMat,
} from '../data/wgslTypes.ts';
import { createDualImpl } from '../shared/generators.ts';
import { createDualImpl } from '../core/function/dualImpl.ts';
import { $internal } from '../shared/symbols.ts';

type NumVec = AnyNumericVecInstance;
Expand Down Expand Up @@ -96,7 +96,7 @@ export const add = createDualImpl(
`(${lhs.value} + ${rhs.value})`,
isSnippetNumeric(lhs) ? rhs.dataType : lhs.dataType,
),
'coerce',
'unify',
);

function cpuSub(lhs: number, rhs: number): number; // default subtraction
Expand Down Expand Up @@ -126,7 +126,7 @@ export const sub = createDualImpl(
isSnippetNumeric(lhs) ? rhs.dataType : lhs.dataType,
),
'sub',
'coerce',
'unify',
);

function cpuMul(lhs: number, rhs: number): number; // default multiplication
Expand Down Expand Up @@ -510,7 +510,7 @@ export const max = createDualImpl(
// GPU implementation
(a, b) => snip(`max(${a.value}, ${b.value})`, a.dataType),
'max',
'coerce',
'unify',
);

/**
Expand All @@ -528,7 +528,7 @@ export const min = createDualImpl(
// GPU implementation
(a, b) => snip(`min(${a.value}, ${b.value})`, a.dataType),
'min',
'coerce',
'unify',
);

export const sign = createDualImpl(
Expand Down
Loading