@@ -48,16 +48,6 @@ impl ConstantValue {
48
48
}
49
49
}
50
50
51
- pub fn tensor_ty_tokens ( & self ) -> TokenStream {
52
- match self {
53
- ConstantValue :: Tensor ( tensor_type, _) => {
54
- let ty = tensor_type. ty ( ) ;
55
- quote ! { #ty }
56
- }
57
- _ => panic ! ( "Not a tensor constant" ) ,
58
- }
59
- }
60
-
61
51
pub fn val_tokens ( & self ) -> TokenStream {
62
52
match self {
63
53
ConstantValue :: Float32 ( val) => quote ! { #val } ,
@@ -137,23 +127,23 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ConstantNode {
137
127
crate :: burn:: TensorKind :: Int => Some ( quote ! {
138
128
let #name: burn:: module:: Param <#ty> = burn:: module:: Param :: uninitialized(
139
129
burn:: module:: ParamId :: new( ) ,
140
- move |device, _require_grad| Tensor :: <B , #dim, burn :: tensor :: Int >:: zeros( #shape, & device) ,
130
+ move |device, _require_grad| Tensor :: <B , #dim, Int >:: zeros( #shape, & device) ,
141
131
device. clone( ) ,
142
132
false
143
133
) ;
144
134
} ) ,
145
135
crate :: burn:: TensorKind :: Float => Some ( quote ! {
146
136
let #name: burn:: module:: Param <#ty> = burn:: module:: Param :: uninitialized(
147
137
burn:: module:: ParamId :: new( ) ,
148
- move |device, _require_grad| Tensor :: <B , #dim, burn :: tensor :: Float >:: zeros( #shape, & device) ,
138
+ move |device, _require_grad| Tensor :: <B , #dim>:: zeros( #shape, & device) ,
149
139
device. clone( ) ,
150
140
false ,
151
141
) ;
152
142
} ) ,
153
143
crate :: burn:: TensorKind :: Bool => Some ( quote ! {
154
144
let #name: burn:: module:: Param <#ty> = burn:: module:: Param :: uninitialized(
155
145
burn:: module:: ParamId :: new( ) ,
156
- move |device, _require_grad| Tensor :: <B , #dim, burn :: tensor :: Bool >:: empty( #shape, & device) ,
146
+ move |device, _require_grad| Tensor :: <B , #dim, Bool >:: empty( #shape, & device) ,
157
147
device. clone( ) ,
158
148
false ,
159
149
) ;
@@ -204,7 +194,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ConstantNode {
204
194
mod tests {
205
195
use super :: * ;
206
196
use crate :: burn:: {
207
- graph:: BurnGraph , node:: test:: assert_tokens, ScalarKind , ScalarType , TensorType ,
197
+ ScalarKind , ScalarType , TensorType , graph:: BurnGraph , node:: test:: assert_tokens,
208
198
} ;
209
199
use burn:: record:: FullPrecisionSettings ;
210
200
use burn:: tensor:: TensorData ;
@@ -292,15 +282,6 @@ mod tests {
292
282
assert_codegen_constant_scalar ( ConstantValue :: Bool ( false ) , ScalarKind :: Bool ) ;
293
283
}
294
284
295
- /// Transforms e.g. `&[1usize, 2usize, 3usize]` into literal tokens [1, 2, 3].
296
- fn shape_to_tokens ( shape : & [ usize ] ) -> TokenStream {
297
- let dims = shape. iter ( ) . map ( |d| {
298
- let lit = proc_macro2:: Literal :: usize_unsuffixed ( * d) ;
299
- quote ! { #lit }
300
- } ) ;
301
- quote ! { [ #( #dims) , * ] }
302
- }
303
-
304
285
#[ test]
305
286
fn test_codegen_constant_tensor_float ( ) {
306
287
let mut graph = BurnGraph :: < FullPrecisionSettings > :: default ( ) ;
@@ -326,11 +307,6 @@ mod tests {
326
307
) ) ,
327
308
) ) ;
328
309
329
- let con = const_tensor. to_token_stream ( ) ;
330
- let ty = constant. ty_tokens ( ) ;
331
- let tensor_ty = constant. tensor_ty_tokens ( ) ;
332
- let shp = shape_to_tokens ( & shape) ;
333
-
334
310
graph. register_input_output ( vec ! [ ] , vec ! [ "output" . to_string( ) ] ) ;
335
311
336
312
let expected = quote ! {
@@ -341,26 +317,31 @@ mod tests {
341
317
342
318
#[ derive( Module , Debug ) ]
343
319
pub struct Model <B : Backend > {
344
- #con : #ty ,
320
+ const_tensor : burn :: module :: Param < Tensor < B , 1 >> ,
345
321
phantom: core:: marker:: PhantomData <B >,
346
322
device: burn:: module:: Ignored <B :: Device >,
347
323
}
348
324
349
325
impl <B : Backend > Model <B > {
350
326
#[ allow( unused_variables) ]
351
327
pub fn new( device: & B :: Device ) -> Self {
352
- let #con: #ty = burn:: nn:: Initializer :: Zeros . init( #shp, device) . set_require_grad( false ) ;
328
+ let const_tensor: burn:: module:: Param <Tensor <B , 1 >> = burn:: module:: Param :: uninitialized(
329
+ burn:: module:: ParamId :: new( ) ,
330
+ move |device, _require_grad| Tensor :: <B , 1 >:: zeros( [ 4 ] , & device) ,
331
+ device. clone( ) ,
332
+ false
333
+ ) ;
353
334
354
335
Self {
355
- #con ,
336
+ const_tensor ,
356
337
phantom: core:: marker:: PhantomData ,
357
338
device: burn:: module:: Ignored ( device. clone( ) ) ,
358
339
}
359
340
}
360
341
361
342
#[ allow( clippy:: let_and_return, clippy:: approx_constant) ]
362
- pub fn forward( & self ) -> #tensor_ty {
363
- let output = self . #con . val( ) ;
343
+ pub fn forward( & self ) -> Tensor < B , 1 > {
344
+ let output = self . const_tensor . val( ) ;
364
345
output
365
346
}
366
347
}
@@ -394,11 +375,6 @@ mod tests {
394
375
) ) ,
395
376
) ) ;
396
377
397
- let con = const_tensor. to_token_stream ( ) ;
398
- let ty = constant. ty_tokens ( ) ;
399
- let tensor_ty = constant. tensor_ty_tokens ( ) ;
400
- let shp = shape_to_tokens ( & shape) ;
401
-
402
378
graph. register_input_output ( vec ! [ ] , vec ! [ "output" . to_string( ) ] ) ;
403
379
404
380
let expected = quote ! {
@@ -410,26 +386,31 @@ mod tests {
410
386
411
387
#[ derive( Module , Debug ) ]
412
388
pub struct Model <B : Backend > {
413
- #con : #ty ,
389
+ const_tensor_int : burn :: module :: Param < Tensor < B , 1 , Int >> ,
414
390
phantom: core:: marker:: PhantomData <B >,
415
391
device: burn:: module:: Ignored <B :: Device >,
416
392
}
417
393
418
394
impl <B : Backend > Model <B > {
419
395
#[ allow( unused_variables) ]
420
396
pub fn new( device: & B :: Device ) -> Self {
421
- let #con: #ty = burn:: nn:: Initializer :: Zeros . init( #shp, device) . set_require_grad( false ) ;
397
+ let const_tensor_int: burn:: module:: Param <Tensor <B , 1 , Int >> = burn:: module:: Param :: uninitialized(
398
+ burn:: module:: ParamId :: new( ) ,
399
+ move |device, _require_grad| Tensor :: <B , 1 , Int >:: zeros( [ 3 ] , & device) ,
400
+ device. clone( ) ,
401
+ false
402
+ ) ;
422
403
423
404
Self {
424
- #con ,
405
+ const_tensor_int ,
425
406
phantom: core:: marker:: PhantomData ,
426
407
device: burn:: module:: Ignored ( device. clone( ) ) ,
427
408
}
428
409
}
429
410
430
411
#[ allow( clippy:: let_and_return, clippy:: approx_constant) ]
431
- pub fn forward( & self ) -> #tensor_ty {
432
- let output = self . #con . val( ) ;
412
+ pub fn forward( & self ) -> Tensor < B , 1 , Int > {
413
+ let output = self . const_tensor_int . val( ) ;
433
414
output
434
415
}
435
416
}
@@ -442,75 +423,6 @@ mod tests {
442
423
fn test_codegen_constant_tensor_bool ( ) {
443
424
let mut graph = BurnGraph :: < FullPrecisionSettings > :: default ( ) ;
444
425
445
- let const_tensor = Ident :: new ( "const_tensor_bool" , Span :: call_site ( ) ) ;
446
- let dimensions = 1 ;
447
- let shape = vec ! [ 2 ] ;
448
- let data = TensorData :: from ( [ true , false ] ) ;
449
- let tensor_type = TensorType :: new_bool_with_shape (
450
- const_tensor. to_string ( ) ,
451
- dimensions,
452
- Some ( shape. clone ( ) ) ,
453
- ) ;
454
- let constant = ConstantValue :: Tensor ( tensor_type. clone ( ) , data) ;
455
-
456
- graph. register ( ConstantNode :: new (
457
- const_tensor. to_string ( ) ,
458
- constant. clone ( ) ,
459
- Type :: Tensor ( TensorType :: new_bool_with_shape (
460
- "output" ,
461
- dimensions,
462
- Some ( shape. clone ( ) ) ,
463
- ) ) ,
464
- ) ) ;
465
-
466
- let con = const_tensor. to_token_stream ( ) ;
467
- let ty = constant. ty_tokens ( ) ;
468
- let tensor_ty = constant. tensor_ty_tokens ( ) ;
469
- let shp = shape_to_tokens ( & shape) ;
470
-
471
- graph. register_input_output ( vec ! [ ] , vec ! [ "output" . to_string( ) ] ) ;
472
-
473
- let expected = quote ! {
474
- use burn:: {
475
- module:: Module ,
476
- tensor:: { backend:: Backend , Tensor } ,
477
- } ;
478
- use burn:: tensor:: Bool ;
479
-
480
- #[ derive( Module , Debug ) ]
481
- pub struct Model <B : Backend > {
482
- #con: #ty,
483
- phantom: core:: marker:: PhantomData <B >,
484
- device: burn:: module:: Ignored <B :: Device >,
485
- }
486
-
487
- impl <B : Backend > Model <B > {
488
- #[ allow( unused_variables) ]
489
- pub fn new( device: & B :: Device ) -> Self {
490
- let #con: #ty = burn:: nn:: Initializer :: Zeros . init( #shp, device) . set_require_grad( false ) ;
491
-
492
- Self {
493
- #con,
494
- phantom: core:: marker:: PhantomData ,
495
- device: burn:: module:: Ignored ( device. clone( ) ) ,
496
- }
497
- }
498
-
499
- #[ allow( clippy:: let_and_return, clippy:: approx_constant) ]
500
- pub fn forward( & self ) -> #tensor_ty {
501
- let output = self . #con. val( ) ;
502
- output
503
- }
504
- }
505
- } ;
506
-
507
- assert_tokens ( graph. codegen ( ) , expected) ;
508
- }
509
-
510
- #[ test]
511
- fn test_codegen_constant_tensor_3d ( ) {
512
- let mut graph = BurnGraph :: < FullPrecisionSettings > :: default ( ) ;
513
-
514
426
let const_tensor = Ident :: new ( "const_tensor_3d" , Span :: call_site ( ) ) ;
515
427
let dimensions = 3 ;
516
428
let shape = vec ! [ 1 , 3 , 2 ] ;
@@ -532,11 +444,6 @@ mod tests {
532
444
) ) ,
533
445
) ) ;
534
446
535
- let con = const_tensor. to_token_stream ( ) ;
536
- let ty = constant. ty_tokens ( ) ;
537
- let tensor_ty = constant. tensor_ty_tokens ( ) ;
538
- let shp = shape_to_tokens ( & shape) ;
539
-
540
447
graph. register_input_output ( vec ! [ ] , vec ! [ "output" . to_string( ) ] ) ;
541
448
542
449
let expected = quote ! {
@@ -548,26 +455,31 @@ mod tests {
548
455
549
456
#[ derive( Module , Debug ) ]
550
457
pub struct Model <B : Backend > {
551
- #con : #ty ,
458
+ const_tensor_3d : burn :: module :: Param < Tensor < B , 3 , Bool >> ,
552
459
phantom: core:: marker:: PhantomData <B >,
553
460
device: burn:: module:: Ignored <B :: Device >,
554
461
}
555
462
556
463
impl <B : Backend > Model <B > {
557
464
#[ allow( unused_variables) ]
558
465
pub fn new( device: & B :: Device ) -> Self {
559
- let #con: #ty = burn:: nn:: Initializer :: Zeros . init( #shp, device) . set_require_grad( false ) ;
466
+ let const_tensor_3d: burn:: module:: Param <Tensor <B , 3 , Bool >> = burn:: module:: Param :: uninitialized(
467
+ burn:: module:: ParamId :: new( ) ,
468
+ move |device, _require_grad| Tensor :: <B , 3 , Bool >:: empty( [ 1 , 3 , 2 ] , & device) ,
469
+ device. clone( ) ,
470
+ false
471
+ ) ;
560
472
561
473
Self {
562
- #con ,
474
+ const_tensor_3d ,
563
475
phantom: core:: marker:: PhantomData ,
564
476
device: burn:: module:: Ignored ( device. clone( ) ) ,
565
477
}
566
478
}
567
479
568
480
#[ allow( clippy:: let_and_return, clippy:: approx_constant) ]
569
- pub fn forward( & self ) -> #tensor_ty {
570
- let output = self . #con . val( ) ;
481
+ pub fn forward( & self ) -> Tensor < B , 3 , Bool > {
482
+ let output = self . const_tensor_3d . val( ) ;
571
483
output
572
484
}
573
485
}
0 commit comments