Skip to content

Commit 51f0e00

Browse files
authored
Merge pull request #778 from mkstoyanov/fix_wavelet_changegpu
fix bug when wavelet grids change the gpu context
2 parents 513b72c + a6009ff commit 51f0e00

File tree

2 files changed

+4
-1
lines changed

2 files changed

+4
-1
lines changed

InterfacePython/TasmanianSG.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2140,7 +2140,7 @@ def enableAcceleration(self, sAccelerationType, iGPUID = None):
21402140
else:
21412141
if ((iGPUID < 0) or (iGPUID >= self.getNumGPUs())):
21422142
raise TasmanianInputError("iGPUID", "ERROR: invalid GPU ID number")
2143-
pLibTSG.tsgEnableAcceleration(self.pGrid, bytes(sAccelerationType, encoding='utf8'), iGPUID)
2143+
pLibTSG.tsgEnableAccelerationGPU(self.pGrid, bytes(sAccelerationType, encoding='utf8'), iGPUID)
21442144

21452145
def getAccelerationType(self):
21462146
'''

SparseGrids/tsgGridWavelet.cpp

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -395,6 +395,7 @@ void GridWavelet::buildInterpolationMatrix() const{
395395

396396
if (order == 1 and TasSparse::WaveletBasisMatrix::useDense(acceleration, num_points)
397397
and acceleration->useKernels()){ // using the GPU algorithm
398+
acceleration->setDevice();
398399
std::vector<double> pnts(Utils::size_mult(num_dimensions, num_points));
399400
getPoints(pnts.data());
400401
GpuVector<double> gpu_pnts(acceleration, pnts);
@@ -1040,6 +1041,8 @@ void GridWavelet::updateAccelerationData(AccelerationContext::ChangeType change)
10401041
case AccelerationContext::change_gpu_device:
10411042
gpu_cache.reset();
10421043
gpu_cachef.reset();
1044+
if (inter_matrix.getNumRows() > 0)
1045+
inter_matrix = TasSparse::WaveletBasisMatrix();
10431046
break;
10441047
case AccelerationContext::change_sparse_dense:
10451048
if ((acceleration->algorithm_select == AccelerationContext::algorithm_dense and inter_matrix.isSparse())

0 commit comments

Comments
 (0)