@@ -18,6 +18,7 @@ import {
18
18
type ExternalMap ,
19
19
replaceExternalsInWgsl ,
20
20
} from '../resolve/externals.ts' ;
21
+ import { extractArgs } from './extractArgs.ts' ;
21
22
import type { Implementation } from './fnTypes.ts' ;
22
23
23
24
export interface TgpuFnShellBase < Args extends unknown [ ] , Return > {
@@ -81,31 +82,70 @@ export function createFnCore(
81
82
const id = ctx . names . makeUnique ( getName ( this ) ) ;
82
83
83
84
if ( typeof implementation === 'string' ) {
85
+ const replacedImpl = replaceExternalsInWgsl (
86
+ ctx ,
87
+ externalMap ,
88
+ implementation ,
89
+ ) ;
90
+
84
91
let header = '' ;
92
+ let body = '' ;
85
93
86
94
if ( shell . isEntry ) {
87
- const input = isWgslStruct ( shell . argTypes [ 0 ] ) ? '(in: In)' : '()' ;
95
+ const input = isWgslStruct ( shell . argTypes [ 0 ] )
96
+ ? `(in: ${ ctx . resolve ( shell . argTypes [ 0 ] ) } )`
97
+ : '()' ;
88
98
89
99
const attributes = isWgslData ( shell . returnType )
90
100
? getAttributesString ( shell . returnType )
91
101
: '' ;
92
102
const output = shell . returnType !== Void
93
103
? isWgslStruct ( shell . returnType )
94
- ? ' -> Out'
104
+ ? ` -> ${ ctx . resolve ( shell . returnType ) } `
95
105
: `-> ${ attributes !== '' ? attributes : '@location(0)' } ${
96
106
ctx . resolve ( shell . returnType )
97
107
} `
98
108
: '' ;
109
+
99
110
header = `${ input } ${ output } ` ;
111
+ body = replacedImpl ;
112
+ } else {
113
+ const providedArgs = extractArgs ( replacedImpl ) ;
114
+
115
+ if ( providedArgs . args . length !== shell . argTypes . length ) {
116
+ throw new Error (
117
+ `WGSL implementation has ${ providedArgs . args . length } arguments, while the shell has ${ shell . argTypes . length } arguments.` ,
118
+ ) ;
119
+ }
120
+
121
+ const input = providedArgs . args . map ( ( argInfo , i ) =>
122
+ `${ argInfo . identifier } : ${
123
+ checkAndReturnType (
124
+ ctx ,
125
+ `parameter ${ argInfo . identifier } ` ,
126
+ argInfo . type ,
127
+ shell . argTypes [ i ] ,
128
+ )
129
+ } `
130
+ ) . join ( ', ' ) ;
131
+
132
+ const output = shell . returnType === Void
133
+ ? ''
134
+ : `-> ${
135
+ checkAndReturnType (
136
+ ctx ,
137
+ 'return type' ,
138
+ providedArgs . ret ?. type ,
139
+ shell . returnType ,
140
+ )
141
+ } `;
142
+
143
+ header = `(${ input } ) ${ output } ` ;
144
+
145
+ body = replacedImpl . slice ( providedArgs . range . end ) ;
100
146
}
101
147
102
- const replacedImpl = replaceExternalsInWgsl (
103
- ctx ,
104
- externalMap ,
105
- `${ header } ${ implementation . trim ( ) } ` ,
106
- ) ;
107
-
108
- ctx . addDeclaration ( `${ fnAttribute } fn ${ id } ${ replacedImpl } ` ) ;
148
+ ctx . addDeclaration ( `${ fnAttribute } fn ${ id } ${ header } ${ body } ` ) ;
109
149
} else {
110
150
// get data generated by the plugin
111
151
const pluginData = getMetaData ( implementation ) ;
@@ -176,3 +216,26 @@ export function createFnCore(
176
216
177
217
return core ;
178
218
}
219
+
220
+ function checkAndReturnType (
221
+ ctx : ResolutionCtx ,
222
+ name : string ,
223
+ wgslType : string | undefined ,
224
+ jsType : unknown ,
225
+ ) {
226
+ const resolvedJsType = ctx . resolve ( jsType ) . replace ( / \s / g, '' ) ;
227
+
228
+ if ( ! wgslType ) {
229
+ return resolvedJsType ;
230
+ }
231
+
232
+ const resolvedWgslType = wgslType . replace ( / \s / g, '' ) ;
233
+
234
+ if ( resolvedWgslType !== resolvedJsType ) {
235
+ throw new Error (
236
+ `Type mismatch between TGPU shell and WGSL code string: ${ name } , JS type "${ resolvedJsType } ", WGSL type "${ resolvedWgslType } ".` ,
237
+ ) ;
238
+ }
239
+
240
+ return wgslType ;
241
+ }
0 commit comments