@@ -17,8 +17,7 @@ Data::~Data()

void Data::load_data(Input &input, const int ndata, std::string *dir, const torch::Device device)
{
-    if (ndata <= 0) { return;
-    }
+    if (ndata <= 0) return;
    this->init_label(input);
    this->init_data(input.nkernel, ndata, input.fftdim, device);
    this->load_data_(input, ndata, input.fftdim, dir);
@@ -30,36 +29,51 @@ void Data::load_data(Input &input, const int ndata, std::string *dir, const torch::Device device)
}

torch::Tensor Data::get_data(std::string parameter, const int ikernel){
-    if (parameter == "gamma") { return this->gamma.reshape({this->nx_tot});
-    }
-    if (parameter == "p") { return this->p.reshape({this->nx_tot});
-    }
-    if (parameter == "q") { return this->q.reshape({this->nx_tot});
-    }
-    if (parameter == "tanhp") { return this->tanhp.reshape({this->nx_tot});
-    }
-    if (parameter == "tanhq") { return this->tanhq.reshape({this->nx_tot});
-    }
-    if (parameter == "gammanl") { return this->gammanl[ikernel].reshape({this->nx_tot});
-    }
-    if (parameter == "pnl") { return this->pnl[ikernel].reshape({this->nx_tot});
-    }
-    if (parameter == "qnl") { return this->qnl[ikernel].reshape({this->nx_tot});
-    }
-    if (parameter == "xi") { return this->xi[ikernel].reshape({this->nx_tot});
-    }
-    if (parameter == "tanhxi") { return this->tanhxi[ikernel].reshape({this->nx_tot});
-    }
-    if (parameter == "tanhxi_nl") { return this->tanhxi_nl[ikernel].reshape({this->nx_tot});
-    }
-    if (parameter == "tanh_pnl") { return this->tanh_pnl[ikernel].reshape({this->nx_tot});
-    }
-    if (parameter == "tanh_qnl") { return this->tanh_qnl[ikernel].reshape({this->nx_tot});
-    }
-    if (parameter == "tanhp_nl") { return this->tanhp_nl[ikernel].reshape({this->nx_tot});
-    }
-    if (parameter == "tanhq_nl") { return this->tanhq_nl[ikernel].reshape({this->nx_tot});
-    }
+    if (parameter == "gamma"){
+        return this->gamma.reshape({this->nx_tot});
+    }
+    if (parameter == "p"){
+        return this->p.reshape({this->nx_tot});
+    }
+    if (parameter == "q"){
+        return this->q.reshape({this->nx_tot});
+    }
+    if (parameter == "tanhp"){
+        return this->tanhp.reshape({this->nx_tot});
+    }
+    if (parameter == "tanhq"){
+        return this->tanhq.reshape({this->nx_tot});
+    }
+    if (parameter == "gammanl"){
+        return this->gammanl[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "pnl"){
+        return this->pnl[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "qnl"){
+        return this->qnl[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "xi"){
+        return this->xi[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "tanhxi"){
+        return this->tanhxi[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "tanhxi_nl"){
+        return this->tanhxi_nl[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "tanh_pnl"){
+        return this->tanh_pnl[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "tanh_qnl"){
+        return this->tanh_qnl[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "tanhp_nl"){
+        return this->tanhp_nl[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "tanhq_nl"){
+        return this->tanhq_nl[ikernel].reshape({this->nx_tot});
+    }
    return torch::zeros({});
}

@@ -139,25 +153,31 @@ void Data::init_data(const int nkernel, const int ndata, const int fftdim, const
    this->nx_tot = this->nx * ndata;

    this->rho = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-    if (this->load_p) { this->nablaRho = torch::zeros({ndata, 3, fftdim, fftdim, fftdim}).to(device);
-    }
+    if (this->load_p){
+        this->nablaRho = torch::zeros({ndata, 3, fftdim, fftdim, fftdim}).to(device);
+    }

    this->enhancement = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
    this->enhancement_mean = torch::zeros(ndata).to(device);
    this->tau_mean = torch::zeros(ndata).to(device);
    this->pauli = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
    this->pauli_mean = torch::zeros(ndata).to(device);

-    if (this->load_gamma) { this->gamma = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-    }
-    if (this->load_p) { this->p = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-    }
-    if (this->load_q) { this->q = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-    }
-    if (this->load_tanhp) { this->tanhp = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-    }
-    if (this->load_tanhq) { this->tanhq = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-    }
+    if (this->load_gamma){
+        this->gamma = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+    }
+    if (this->load_p){
+        this->p = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+    }
+    if (this->load_q){
+        this->q = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+    }
+    if (this->load_tanhp){
+        this->tanhp = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+    }
+    if (this->load_tanhq){
+        this->tanhq = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+    }

    for (int ik = 0; ik < nkernel; ++ik)
    {
@@ -172,26 +192,36 @@ void Data::init_data(const int nkernel, const int ndata, const int fftdim, const
        this->tanhp_nl.push_back(torch::zeros({}).to(device));
        this->tanhq_nl.push_back(torch::zeros({}).to(device));

-        if (this->load_gammanl[ik]) { this->gammanl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        }
-        if (this->load_pnl[ik]) { this->pnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        }
-        if (this->load_qnl[ik]) { this->qnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        }
-        if (this->load_xi[ik]) { this->xi[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        }
-        if (this->load_tanhxi[ik]) { this->tanhxi[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        }
-        if (this->load_tanhxi_nl[ik]) { this->tanhxi_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        }
-        if (this->load_tanh_pnl[ik]) { this->tanh_pnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        }
-        if (this->load_tanh_qnl[ik]) { this->tanh_qnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        }
-        if (this->load_tanhp_nl[ik]) { this->tanhp_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        }
-        if (this->load_tanhq_nl[ik]) { this->tanhq_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        }
+        if (this->load_gammanl[ik]){
+            this->gammanl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_pnl[ik]){
+            this->pnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_qnl[ik]){
+            this->qnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_xi[ik]){
+            this->xi[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_tanhxi[ik]){
+            this->tanhxi[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_tanhxi_nl[ik]){
+            this->tanhxi_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_tanh_pnl[ik]){
+            this->tanh_pnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_tanh_qnl[ik]){
+            this->tanh_qnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_tanhp_nl[ik]){
+            this->tanhp_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_tanhq_nl[ik]){
+            this->tanhq_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
    }

    // Input::print("init_data done");
@@ -205,20 +235,21 @@ void Data::load_data_(
)
{
    // Input::print("load_data");
-    if (ndata <= 0) { return;
-    }
-
+    if (ndata <= 0){
+        return;
+    }
+
    std::vector<long unsigned int> cshape = {(long unsigned) nx};
    std::vector<double> container(nx);
    bool fortran_order = false;

    for (int idata = 0; idata < ndata; ++idata)
    {
        this->loadTensor(dir[idata] + "/rho.npy", cshape, fortran_order, container, idata, fftdim, rho);
-        if (this->load_gamma) { this->loadTensor(dir[idata] + "/gamma.npy", cshape, fortran_order, container, idata, fftdim, gamma);
-        }
-        if (this->load_p)
-        {
+        if (this->load_gamma){
+            this->loadTensor(dir[idata] + "/gamma.npy", cshape, fortran_order, container, idata, fftdim, gamma);
+        }
+        if (this->load_p) {
            this->loadTensor(dir[idata] + "/p.npy", cshape, fortran_order, container, idata, fftdim, p);
            npy::LoadArrayFromNumpy(dir[idata] + "/nablaRhox.npy", cshape, fortran_order, container);
            nablaRho[idata][0] = torch::tensor(container).reshape({fftdim, fftdim, fftdim});
@@ -227,12 +258,15 @@ void Data::load_data_(
            npy::LoadArrayFromNumpy(dir[idata] + "/nablaRhoz.npy", cshape, fortran_order, container);
            nablaRho[idata][2] = torch::tensor(container).reshape({fftdim, fftdim, fftdim});
        }
-        if (this->load_q) { this->loadTensor(dir[idata] + "/q.npy", cshape, fortran_order, container, idata, fftdim, q);
-        }
-        if (this->load_tanhp) { this->loadTensor(dir[idata] + "/tanhp.npy", cshape, fortran_order, container, idata, fftdim, tanhp);
-        }
-        if (this->load_tanhq) { this->loadTensor(dir[idata] + "/tanhq.npy", cshape, fortran_order, container, idata, fftdim, tanhq);
-        }
+        if (this->load_q){
+            this->loadTensor(dir[idata] + "/q.npy", cshape, fortran_order, container, idata, fftdim, q);
+        }
+        if (this->load_tanhp){
+            this->loadTensor(dir[idata] + "/tanhp.npy", cshape, fortran_order, container, idata, fftdim, tanhp);
+        }
+        if (this->load_tanhq){
+            this->loadTensor(dir[idata] + "/tanhq.npy", cshape, fortran_order, container, idata, fftdim, tanhq);
+        }

        for (int ik = 0; ik < input.nkernel; ++ik)
        {
@@ -305,8 +339,9 @@ void Data::loadTensor(
void Data::dumpTensor(const torch::Tensor &data, std::string filename, int nx)
{
    std::vector<double> v(nx);
-    for (int ir = 0; ir < nx; ++ir) { v[ir] = data[ir].item<double>();
-    }
+    for (int ir = 0; ir < nx; ++ir){
+        v[ir] = data[ir].item<double>();
+    }
    // std::vector<double> v(data.data_ptr<float>(), data.data_ptr<float>() + data.numel()); // this works, but only supports float tensor
    const long unsigned cshape[] = {(long unsigned) nx}; // shape
    npy::SaveArrayAsNumpy(filename, false, 1, cshape, v);