
Commit fbd9a6e

Format ml_tools/data.cpp and ml_tools/grid.cpp
1 parent 3e25aaa commit fbd9a6e

2 files changed: +120 -82 lines changed

source/module_hamilt_pw/hamilt_ofdft/ml_tools/data.cpp

+114 -79
@@ -17,8 +17,7 @@ Data::~Data()
 
 void Data::load_data(Input &input, const int ndata, std::string *dir, const torch::Device device)
 {
-    if (ndata <= 0) { return;
-    }
+    if (ndata <= 0) return;
     this->init_label(input);
     this->init_data(input.nkernel, ndata, input.fftdim, device);
     this->load_data_(input, ndata, input.fftdim, dir);
@@ -30,36 +29,51 @@ void Data::load_data(Input &input, const int ndata, std::string *dir, const torc
 }
 
 torch::Tensor Data::get_data(std::string parameter, const int ikernel){
-    if (parameter == "gamma") { return this->gamma.reshape({this->nx_tot});
-    }
-    if (parameter == "p") { return this->p.reshape({this->nx_tot});
-    }
-    if (parameter == "q") { return this->q.reshape({this->nx_tot});
-    }
-    if (parameter == "tanhp") { return this->tanhp.reshape({this->nx_tot});
-    }
-    if (parameter == "tanhq") { return this->tanhq.reshape({this->nx_tot});
-    }
-    if (parameter == "gammanl") { return this->gammanl[ikernel].reshape({this->nx_tot});
-    }
-    if (parameter == "pnl") { return this->pnl[ikernel].reshape({this->nx_tot});
-    }
-    if (parameter == "qnl") { return this->qnl[ikernel].reshape({this->nx_tot});
-    }
-    if (parameter == "xi") { return this->xi[ikernel].reshape({this->nx_tot});
-    }
-    if (parameter == "tanhxi") { return this->tanhxi[ikernel].reshape({this->nx_tot});
-    }
-    if (parameter == "tanhxi_nl") { return this->tanhxi_nl[ikernel].reshape({this->nx_tot});
-    }
-    if (parameter == "tanh_pnl") { return this->tanh_pnl[ikernel].reshape({this->nx_tot});
-    }
-    if (parameter == "tanh_qnl") { return this->tanh_qnl[ikernel].reshape({this->nx_tot});
-    }
-    if (parameter == "tanhp_nl") { return this->tanhp_nl[ikernel].reshape({this->nx_tot});
-    }
-    if (parameter == "tanhq_nl") { return this->tanhq_nl[ikernel].reshape({this->nx_tot});
-    }
+    if (parameter == "gamma"){
+        return this->gamma.reshape({this->nx_tot});
+    }
+    if (parameter == "p"){
+        return this->p.reshape({this->nx_tot});
+    }
+    if (parameter == "q"){
+        return this->q.reshape({this->nx_tot});
+    }
+    if (parameter == "tanhp"){
+        return this->tanhp.reshape({this->nx_tot});
+    }
+    if (parameter == "tanhq"){
+        return this->tanhq.reshape({this->nx_tot});
+    }
+    if (parameter == "gammanl"){
+        return this->gammanl[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "pnl"){
+        return this->pnl[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "qnl"){
+        return this->qnl[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "xi"){
+        return this->xi[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "tanhxi"){
+        return this->tanhxi[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "tanhxi_nl"){
+        return this->tanhxi_nl[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "tanh_pnl"){
+        return this->tanh_pnl[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "tanh_qnl"){
+        return this->tanh_qnl[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "tanhp_nl"){
+        return this->tanhp_nl[ikernel].reshape({this->nx_tot});
+    }
+    if (parameter == "tanhq_nl"){
+        return this->tanhq_nl[ikernel].reshape({this->nx_tot});
+    }
     return torch::zeros({});
 }
 
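Note on the hunk above: Data::get_data dispatches on the descriptor name and returns the corresponding tensor flattened to {nx_tot}, falling back to a scalar zero tensor for an unrecognized key. A minimal usage sketch follows; the function name inspect_descriptors and the header path "data.h" are placeholders and not taken from the repository.

// Illustrative sketch only; assumes a Data object that load_data has already filled.
#include <torch/torch.h>
#include "data.h"   // hypothetical header name for the Data class

void inspect_descriptors(Data &data)
{
    // Semilocal descriptors ignore the kernel index; 0 is passed as a dummy.
    torch::Tensor p = data.get_data("p", 0);
    // Nonlocal descriptors are selected per kernel.
    torch::Tensor gammanl0 = data.get_data("gammanl", 0);
    // Any unknown key falls through to the scalar zero tensor returned at the end.
    torch::Tensor missing = data.get_data("no_such_key", 0);
}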
@@ -139,25 +153,31 @@ void Data::init_data(const int nkernel, const int ndata, const int fftdim, const
     this->nx_tot = this->nx * ndata;
 
     this->rho = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-    if (this->load_p) { this->nablaRho = torch::zeros({ndata, 3, fftdim, fftdim, fftdim}).to(device);
-    }
+    if (this->load_p){
+        this->nablaRho = torch::zeros({ndata, 3, fftdim, fftdim, fftdim}).to(device);
+    }
 
     this->enhancement = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
     this->enhancement_mean = torch::zeros(ndata).to(device);
     this->tau_mean = torch::zeros(ndata).to(device);
     this->pauli = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
     this->pauli_mean = torch::zeros(ndata).to(device);
 
-    if (this->load_gamma) { this->gamma = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-    }
-    if (this->load_p) { this->p = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-    }
-    if (this->load_q) { this->q = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-    }
-    if (this->load_tanhp) { this->tanhp = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-    }
-    if (this->load_tanhq) { this->tanhq = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-    }
+    if (this->load_gamma){
+        this->gamma = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+    }
+    if (this->load_p){
+        this->p = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+    }
+    if (this->load_q){
+        this->q = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+    }
+    if (this->load_tanhp){
+        this->tanhp = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+    }
+    if (this->load_tanhq){
+        this->tanhq = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+    }
 
     for (int ik = 0; ik < nkernel; ++ik)
     {
@@ -172,26 +192,36 @@ void Data::init_data(const int nkernel, const int ndata, const int fftdim, const
         this->tanhp_nl.push_back(torch::zeros({}).to(device));
         this->tanhq_nl.push_back(torch::zeros({}).to(device));
 
-        if (this->load_gammanl[ik]) { this->gammanl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        }
-        if (this->load_pnl[ik]) { this->pnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        }
-        if (this->load_qnl[ik]) { this->qnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        }
-        if (this->load_xi[ik]) { this->xi[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        }
-        if (this->load_tanhxi[ik]) { this->tanhxi[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        }
-        if (this->load_tanhxi_nl[ik]) { this->tanhxi_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        }
-        if (this->load_tanh_pnl[ik]) { this->tanh_pnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        }
-        if (this->load_tanh_qnl[ik]) { this->tanh_qnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        }
-        if (this->load_tanhp_nl[ik]) { this->tanhp_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        }
-        if (this->load_tanhq_nl[ik]) { this->tanhq_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
-        }
+        if (this->load_gammanl[ik]){
+            this->gammanl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_pnl[ik]){
+            this->pnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_qnl[ik]){
+            this->qnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_xi[ik]){
+            this->xi[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_tanhxi[ik]){
+            this->tanhxi[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_tanhxi_nl[ik]){
+            this->tanhxi_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_tanh_pnl[ik]){
+            this->tanh_pnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_tanh_qnl[ik]){
+            this->tanh_qnl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_tanhp_nl[ik]){
+            this->tanhp_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
+        if (this->load_tanhq_nl[ik]){
+            this->tanhq_nl[ik] = torch::zeros({ndata, fftdim, fftdim, fftdim}).to(device);
+        }
     }
 
     // Input::print("init_data done");
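Aside on the init_data hunks above: each field is built on the CPU and then moved with .to(device). A purely illustrative alternative, not part of this commit, would allocate directly on the target device via TensorOptions; the helper name make_field below is a placeholder.

// Sketch only: direct-on-device allocation with libtorch TensorOptions.
#include <torch/torch.h>

torch::Tensor make_field(int ndata, int fftdim, const torch::Device &device)
{
    // torch::zeros accepts TensorOptions, so the tensor is created on `device`
    // without an intermediate CPU allocation and copy.
    return torch::zeros({ndata, fftdim, fftdim, fftdim},
                        torch::TensorOptions().device(device));
}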
@@ -205,20 +235,21 @@ void Data::load_data_(
 )
 {
     // Input::print("load_data");
-    if (ndata <= 0) { return;
-    }
-
+    if (ndata <= 0){
+        return;
+    }
+
     std::vector<long unsigned int> cshape = {(long unsigned) nx};
     std::vector<double> container(nx);
     bool fortran_order = false;
 
     for (int idata = 0; idata < ndata; ++idata)
     {
         this->loadTensor(dir[idata] + "/rho.npy", cshape, fortran_order, container, idata, fftdim, rho);
-        if (this->load_gamma) { this->loadTensor(dir[idata] + "/gamma.npy", cshape, fortran_order, container, idata, fftdim, gamma);
-        }
-        if (this->load_p)
-        {
+        if (this->load_gamma){
+            this->loadTensor(dir[idata] + "/gamma.npy", cshape, fortran_order, container, idata, fftdim, gamma);
+        }
+        if (this->load_p){
             this->loadTensor(dir[idata] + "/p.npy", cshape, fortran_order, container, idata, fftdim, p);
             npy::LoadArrayFromNumpy(dir[idata] + "/nablaRhox.npy", cshape, fortran_order, container);
             nablaRho[idata][0] = torch::tensor(container).reshape({fftdim, fftdim, fftdim});
@@ -227,12 +258,15 @@ void Data::load_data_(
             npy::LoadArrayFromNumpy(dir[idata] + "/nablaRhoz.npy", cshape, fortran_order, container);
             nablaRho[idata][2] = torch::tensor(container).reshape({fftdim, fftdim, fftdim});
         }
-        if (this->load_q) { this->loadTensor(dir[idata] + "/q.npy", cshape, fortran_order, container, idata, fftdim, q);
-        }
-        if (this->load_tanhp) { this->loadTensor(dir[idata] + "/tanhp.npy", cshape, fortran_order, container, idata, fftdim, tanhp);
-        }
-        if (this->load_tanhq) { this->loadTensor(dir[idata] + "/tanhq.npy", cshape, fortran_order, container, idata, fftdim, tanhq);
-        }
+        if (this->load_q){
+            this->loadTensor(dir[idata] + "/q.npy", cshape, fortran_order, container, idata, fftdim, q);
+        }
+        if (this->load_tanhp){
+            this->loadTensor(dir[idata] + "/tanhp.npy", cshape, fortran_order, container, idata, fftdim, tanhp);
+        }
+        if (this->load_tanhq){
+            this->loadTensor(dir[idata] + "/tanhq.npy", cshape, fortran_order, container, idata, fftdim, tanhq);
+        }
 
         for (int ik = 0; ik < input.nkernel; ++ik)
         {
@@ -305,8 +339,9 @@ void Data::loadTensor(
 void Data::dumpTensor(const torch::Tensor &data, std::string filename, int nx)
 {
     std::vector<double> v(nx);
-    for (int ir = 0; ir < nx; ++ir) { v[ir] = data[ir].item<double>();
-    }
+    for (int ir = 0; ir < nx; ++ir){
+        v[ir] = data[ir].item<double>();
+    }
     // std::vector<double> v(data.data_ptr<float>(), data.data_ptr<float>() + data.numel()); // this works, but only supports float tensor
     const long unsigned cshape[] = {(long unsigned) nx}; // shape
     npy::SaveArrayAsNumpy(filename, false, 1, cshape, v);
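The commented-out line in dumpTensor notes that the data_ptr<float> range copy only supports float tensors. Below is a hedged sketch of a double-precision variant, not part of this commit; it assumes the npy helper header is available as "npy.hpp" and the function name dump_tensor_sketch is a placeholder.

// Sketch: bulk copy from a contiguous double tensor instead of the per-element loop.
#include <torch/torch.h>
#include <string>
#include <vector>
#include "npy.hpp"   // assumed header name of the npy helper used above

void dump_tensor_sketch(const torch::Tensor &data, const std::string &filename, int nx)
{
    // Ensure double precision and contiguous memory before taking the raw pointer.
    torch::Tensor t = data.to(torch::kFloat64).contiguous();
    std::vector<double> v(t.data_ptr<double>(), t.data_ptr<double>() + t.numel());
    const long unsigned cshape[] = {(long unsigned) nx};
    npy::SaveArrayAsNumpy(filename, false, 1, cshape, v);
}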

source/module_hamilt_pw/hamilt_ofdft/ml_tools/grid.cpp

+6 -3
@@ -29,12 +29,15 @@ void Grid::initGrid_(const int fftdim,
 
     for (int it = 0; it < ndata; ++it)
     {
-        if (cell[it] == "sc")
+        if (cell[it] == "sc"){
             this->initScRecipGrid(fftdim, a[it], it, device, volume, grid, gg);
-        else if (cell[it] == "fcc")
+        }
+        else if (cell[it] == "fcc"){
             this->initFccRecipGrid(fftdim, a[it], it, device, volume, grid, gg);
-        else if (cell[it] == "bcc")
+        }
+        else if (cell[it] == "bcc"){
             this->initBccRecipGrid(fftdim, a[it], it, device, volume, grid, gg);
+        }
     }
 }
 