Commit 83458c6

Cleanup tests
1 parent d206c53 commit 83458c6

File tree: 1 file changed (+34, -122 lines)


crates/burn-import/src/burn/node/constant.rs

@@ -48,16 +48,6 @@ impl ConstantValue {
         }
     }
 
-    pub fn tensor_ty_tokens(&self) -> TokenStream {
-        match self {
-            ConstantValue::Tensor(tensor_type, _) => {
-                let ty = tensor_type.ty();
-                quote! { #ty }
-            }
-            _ => panic!("Not a tensor constant"),
-        }
-    }
-
     pub fn val_tokens(&self) -> TokenStream {
         match self {
             ConstantValue::Float32(val) => quote! { #val },
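
Note: the removed helper only re-emitted tokens already available from `TensorType::ty()`, and the tests below were its sole callers; they now spell the concrete type out. A minimal sketch of the equivalence using plain `quote`/`proc-macro2` (standalone illustration, not burn-import's actual code):

    use quote::quote;

    fn main() {
        // What tensor_ty_tokens() effectively returned for a 1-D float tensor:
        let ty = quote! { Tensor<B, 1> };
        let via_helper = quote! { pub fn forward(&self) -> #ty };
        // What the cleaned-up tests spell out directly (see hunks below):
        let literal = quote! { pub fn forward(&self) -> Tensor<B, 1> };
        assert_eq!(via_helper.to_string(), literal.to_string());
    }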
@@ -137,23 +127,23 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ConstantNode {
             crate::burn::TensorKind::Int => Some(quote! {
                 let #name: burn::module::Param<#ty> = burn::module::Param::uninitialized(
                     burn::module::ParamId::new(),
-                    move |device, _require_grad| Tensor::<B, #dim, burn::tensor::Int>::zeros(#shape, &device),
+                    move |device, _require_grad| Tensor::<B, #dim, Int>::zeros(#shape, &device),
                     device.clone(),
                     false
                 );
             }),
             crate::burn::TensorKind::Float => Some(quote! {
                 let #name: burn::module::Param<#ty> = burn::module::Param::uninitialized(
                     burn::module::ParamId::new(),
-                    move |device, _require_grad| Tensor::<B, #dim, burn::tensor::Float>::zeros(#shape, &device),
+                    move |device, _require_grad| Tensor::<B, #dim>::zeros(#shape, &device),
                     device.clone(),
                     false,
                 );
             }),
             crate::burn::TensorKind::Bool => Some(quote! {
                 let #name: burn::module::Param<#ty> = burn::module::Param::uninitialized(
                     burn::module::ParamId::new(),
-                    move |device, _require_grad| Tensor::<B, #dim, burn::tensor::Bool>::empty(#shape, &device),
+                    move |device, _require_grad| Tensor::<B, #dim, Bool>::empty(#shape, &device),
                     device.clone(),
                     false,
                 );
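
Note: `Float` can be dropped here because it is the default kind parameter of burn's `Tensor`, while the bare `Int`/`Bool` paths assume those imports exist in the generated file (the expected test output further down emits `use burn::tensor::Bool;`). A sketch of the `Float` equivalence, assuming burn's default type parameter:

    use burn::tensor::{backend::Backend, Float, Tensor};

    // Sketch: Float is the default kind parameter on burn's Tensor, so the
    // generated `Tensor<B, #dim>` names the same type as the old
    // `Tensor<B, #dim, burn::tensor::Float>`.
    fn same_type<B: Backend>(t: Tensor<B, 2, Float>) -> Tensor<B, 2> {
        t
    }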
@@ -204,7 +194,7 @@ impl<PS: PrecisionSettings> NodeCodegen<PS> for ConstantNode {
 mod tests {
     use super::*;
     use crate::burn::{
-        graph::BurnGraph, node::test::assert_tokens, ScalarKind, ScalarType, TensorType,
+        ScalarKind, ScalarType, TensorType, graph::BurnGraph, node::test::assert_tokens,
     };
     use burn::record::FullPrecisionSettings;
     use burn::tensor::TensorData;
@@ -292,15 +282,6 @@ mod tests {
         assert_codegen_constant_scalar(ConstantValue::Bool(false), ScalarKind::Bool);
     }
 
-    /// Transforms e.g. `&[1usize, 2usize, 3usize]` into literal tokens [1, 2, 3].
-    fn shape_to_tokens(shape: &[usize]) -> TokenStream {
-        let dims = shape.iter().map(|d| {
-            let lit = proc_macro2::Literal::usize_unsuffixed(*d);
-            quote! { #lit }
-        });
-        quote! { [#(#dims),*] }
-    }
-
     #[test]
     fn test_codegen_constant_tensor_float() {
         let mut graph = BurnGraph::<FullPrecisionSettings>::default();
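
Note: with the expected token streams now hardcoding shapes such as `[4]` and `[1, 3, 2]`, `shape_to_tokens` had no remaining callers. A runnable sketch of what it produced, showing the literal form is token-for-token identical:

    use quote::quote;

    fn main() {
        // shape_to_tokens(&[1, 3, 2]) built unsuffixed literals...
        let dims = [1usize, 3, 2]
            .iter()
            .map(|d| proc_macro2::Literal::usize_unsuffixed(*d));
        let built = quote! { [#(#dims),*] };
        // ...which matches the literal the tests now embed directly:
        assert_eq!(built.to_string(), quote! { [1, 3, 2] }.to_string());
    }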
@@ -326,11 +307,6 @@ mod tests {
             )),
         ));
 
-        let con = const_tensor.to_token_stream();
-        let ty = constant.ty_tokens();
-        let tensor_ty = constant.tensor_ty_tokens();
-        let shp = shape_to_tokens(&shape);
-
         graph.register_input_output(vec![], vec!["output".to_string()]);
 
         let expected = quote! {
@@ -341,26 +317,31 @@ mod tests {
 
             #[derive(Module, Debug)]
             pub struct Model<B: Backend> {
-                #con: #ty,
+                const_tensor: burn::module::Param<Tensor<B, 1>>,
                 phantom: core::marker::PhantomData<B>,
                 device: burn::module::Ignored<B::Device>,
             }
 
             impl<B: Backend> Model<B> {
                 #[allow(unused_variables)]
                 pub fn new(device: &B::Device) -> Self {
-                    let #con: #ty = burn::nn::Initializer::Zeros.init(#shp, device).set_require_grad(false);
+                    let const_tensor: burn::module::Param<Tensor<B, 1>> = burn::module::Param::uninitialized(
+                        burn::module::ParamId::new(),
+                        move |device, _require_grad| Tensor::<B, 1>::zeros([4], &device),
+                        device.clone(),
+                        false
+                    );
 
                     Self {
-                        #con,
+                        const_tensor,
                         phantom: core::marker::PhantomData,
                         device: burn::module::Ignored(device.clone()),
                     }
                 }
 
                 #[allow(clippy::let_and_return, clippy::approx_constant)]
-                pub fn forward(&self) -> #tensor_ty {
-                    let output = self.#con.val();
+                pub fn forward(&self) -> Tensor<B, 1> {
+                    let output = self.const_tensor.val();
                     output
                 }
             }
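
Note: the expected output mirrors the codegen change above: the constant is a lazily initialized `Param` rather than going through `burn::nn::Initializer`. Argument roles as they read from this diff (annotated excerpt, not verified against burn's docs):

    // Annotated excerpt of the generated code above (not standalone):
    let const_tensor: burn::module::Param<Tensor<B, 1>> = burn::module::Param::uninitialized(
        burn::module::ParamId::new(),   // fresh parameter id
        // lazy initializer: presumably only runs if the param is accessed
        // before a record supplies the real value (assumption from the name)
        move |device, _require_grad| Tensor::<B, 1>::zeros([4], &device),
        device.clone(),                 // device to materialize on
        false,                          // constants never require grad
    );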
@@ -394,11 +375,6 @@ mod tests {
             )),
         ));
 
-        let con = const_tensor.to_token_stream();
-        let ty = constant.ty_tokens();
-        let tensor_ty = constant.tensor_ty_tokens();
-        let shp = shape_to_tokens(&shape);
-
         graph.register_input_output(vec![], vec!["output".to_string()]);
 
         let expected = quote! {
@@ -410,26 +386,31 @@ mod tests {
 
             #[derive(Module, Debug)]
             pub struct Model<B: Backend> {
-                #con: #ty,
+                const_tensor_int: burn::module::Param<Tensor<B, 1, Int>>,
                 phantom: core::marker::PhantomData<B>,
                 device: burn::module::Ignored<B::Device>,
             }
 
             impl<B: Backend> Model<B> {
                 #[allow(unused_variables)]
                 pub fn new(device: &B::Device) -> Self {
-                    let #con: #ty = burn::nn::Initializer::Zeros.init(#shp, device).set_require_grad(false);
+                    let const_tensor_int: burn::module::Param<Tensor<B, 1, Int>> = burn::module::Param::uninitialized(
+                        burn::module::ParamId::new(),
+                        move |device, _require_grad| Tensor::<B, 1, Int>::zeros([3], &device),
+                        device.clone(),
+                        false
+                    );
 
                     Self {
-                        #con,
+                        const_tensor_int,
                         phantom: core::marker::PhantomData,
                         device: burn::module::Ignored(device.clone()),
                     }
                 }
 
                 #[allow(clippy::let_and_return, clippy::approx_constant)]
-                pub fn forward(&self) -> #tensor_ty {
-                    let output = self.#con.val();
+                pub fn forward(&self) -> Tensor<B, 1, Int> {
+                    let output = self.const_tensor_int.val();
                     output
                 }
             }
@@ -442,75 +423,6 @@ mod tests {
     fn test_codegen_constant_tensor_bool() {
         let mut graph = BurnGraph::<FullPrecisionSettings>::default();
 
-        let const_tensor = Ident::new("const_tensor_bool", Span::call_site());
-        let dimensions = 1;
-        let shape = vec![2];
-        let data = TensorData::from([true, false]);
-        let tensor_type = TensorType::new_bool_with_shape(
-            const_tensor.to_string(),
-            dimensions,
-            Some(shape.clone()),
-        );
-        let constant = ConstantValue::Tensor(tensor_type.clone(), data);
-
-        graph.register(ConstantNode::new(
-            const_tensor.to_string(),
-            constant.clone(),
-            Type::Tensor(TensorType::new_bool_with_shape(
-                "output",
-                dimensions,
-                Some(shape.clone()),
-            )),
-        ));
-
-        let con = const_tensor.to_token_stream();
-        let ty = constant.ty_tokens();
-        let tensor_ty = constant.tensor_ty_tokens();
-        let shp = shape_to_tokens(&shape);
-
-        graph.register_input_output(vec![], vec!["output".to_string()]);
-
-        let expected = quote! {
-            use burn::{
-                module::Module,
-                tensor::{backend::Backend, Tensor},
-            };
-            use burn::tensor::Bool;
-
-            #[derive(Module, Debug)]
-            pub struct Model<B: Backend> {
-                #con: #ty,
-                phantom: core::marker::PhantomData<B>,
-                device: burn::module::Ignored<B::Device>,
-            }
-
-            impl<B: Backend> Model<B> {
-                #[allow(unused_variables)]
-                pub fn new(device: &B::Device) -> Self {
-                    let #con: #ty = burn::nn::Initializer::Zeros.init(#shp, device).set_require_grad(false);
-
-                    Self {
-                        #con,
-                        phantom: core::marker::PhantomData,
-                        device: burn::module::Ignored(device.clone()),
-                    }
-                }
-
-                #[allow(clippy::let_and_return, clippy::approx_constant)]
-                pub fn forward(&self) -> #tensor_ty {
-                    let output = self.#con.val();
-                    output
-                }
-            }
-        };
-
-        assert_tokens(graph.codegen(), expected);
-    }
-
-    #[test]
-    fn test_codegen_constant_tensor_3d() {
-        let mut graph = BurnGraph::<FullPrecisionSettings>::default();
-
         let const_tensor = Ident::new("const_tensor_3d", Span::call_site());
         let dimensions = 3;
         let shape = vec![1, 3, 2];
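
Note: the net effect of this hunk is a merge: the 1-D bool test is deleted, and the body of the former `test_codegen_constant_tensor_3d` (whose constant is a 3-D bool tensor, per the following hunks) now runs under the `test_codegen_constant_tensor_bool` name.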
@@ -532,11 +444,6 @@ mod tests {
             )),
         ));
 
-        let con = const_tensor.to_token_stream();
-        let ty = constant.ty_tokens();
-        let tensor_ty = constant.tensor_ty_tokens();
-        let shp = shape_to_tokens(&shape);
-
        graph.register_input_output(vec![], vec!["output".to_string()]);
 
        let expected = quote! {
@@ -548,26 +455,31 @@ mod tests {
 
            #[derive(Module, Debug)]
            pub struct Model<B: Backend> {
-                #con: #ty,
+                const_tensor_3d: burn::module::Param<Tensor<B, 3, Bool>>,
                phantom: core::marker::PhantomData<B>,
                device: burn::module::Ignored<B::Device>,
            }
 
            impl<B: Backend> Model<B> {
                #[allow(unused_variables)]
                pub fn new(device: &B::Device) -> Self {
-                    let #con: #ty = burn::nn::Initializer::Zeros.init(#shp, device).set_require_grad(false);
+                    let const_tensor_3d: burn::module::Param<Tensor<B, 3, Bool>> = burn::module::Param::uninitialized(
+                        burn::module::ParamId::new(),
+                        move |device, _require_grad| Tensor::<B, 3, Bool>::empty([1, 3, 2], &device),
+                        device.clone(),
+                        false
+                    );
 
                    Self {
-                        #con,
+                        const_tensor_3d,
                        phantom: core::marker::PhantomData,
                        device: burn::module::Ignored(device.clone()),
                    }
                }
 
                #[allow(clippy::let_and_return, clippy::approx_constant)]
-                pub fn forward(&self) -> #tensor_ty {
-                    let output = self.#con.val();
+                pub fn forward(&self) -> Tensor<B, 3, Bool> {
+                    let output = self.const_tensor_3d.val();
                    output
                }
            }
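
Note: the bool arm allocates with `empty` where the int/float arms use `zeros`, presumably because `zeros` is a numeric-kind operation in burn and is unavailable for `Bool` tensors. A hedged sketch of the bool placeholder (shape and calls taken from this diff):

    use burn::tensor::{backend::Backend, Bool, Tensor};

    // Sketch: the bool placeholder is created uninitialized via empty();
    // its contents are placeholder values, with the real constant expected
    // to come from the model record (assumption, not verified here).
    fn bool_placeholder<B: Backend>(device: &B::Device) -> Tensor<B, 3, Bool> {
        Tensor::<B, 3, Bool>::empty([1, 3, 2], device)
    }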
