@@ -621,6 +621,7 @@ PETSC_INTERN void MatDiagDomRatio_kokkos(Mat *input_mat, PetscIntKokkosView &is_
621
621
// Can't use the global directly within the parallel
622
622
// regions on the device
623
623
intKokkosView cf_markers_d = cf_markers_local_d;
624
+ intKokkosView cf_markers_nonlocal_d;
624
625
625
626
// ~~~~~~~~~~~~
626
627
// Get the F point local indices from cf_markers_local_d
@@ -635,37 +636,23 @@ PETSC_INTERN void MatDiagDomRatio_kokkos(Mat *input_mat, PetscIntKokkosView &is_
635
636
// ~~~~~~~~~~~~~~~
636
637
// Can now go and compute the diagonal dominance sums
637
638
// ~~~~~~~~~~~~~~~
638
- Vec x, lvec = NULL ;
639
- PetscScalarKokkosView x_d, lvec_d;
640
- PetscScalar *x_d_ptr = NULL ;
641
- PetscScalar *lvec_d_ptr = NULL ;
639
+ int *cf_markers_d_ptr = cf_markers_d.data ();
640
+ int *cf_markers_nonlocal_d_ptr = NULL ;
642
641
PetscMemType mem_type = PETSC_MEMTYPE_KOKKOS;
643
642
PetscMemType mtype;
644
643
645
644
// The off-diagonal component requires some comms which we can start now
646
645
if (mpi)
647
646
{
648
- // Basically a copy of ISGetSeqIS_SameColDist_Private
649
- /* (1) iscol is a sub-column vector of mat, pad it with '-1.' to form a full vector x */
650
- lvec = mat_mpi->lvec ;
651
- MatCreateVecs (*input_mat, &x, NULL );
652
-
653
- // Use the vecs in the scatter provided by the input mat
654
- // We're going to overwrite everything in x and lvec
655
- VecGetKokkosViewWrite (x, &x_d);
656
- VecGetKokkosViewWrite (lvec, &lvec_d);
657
-
658
- // Copy in cf markers to x
659
- Kokkos::deep_copy (x_d, cf_markers_d);
660
-
661
- x_d_ptr = x_d.data ();
662
- lvec_d_ptr = lvec_d.data ();
647
+ cf_markers_nonlocal_d = intKokkosView (" cf_markers_nonlocal_d" , cols_ao);
648
+ cf_markers_nonlocal_d_ptr = cf_markers_nonlocal_d.data ();
663
649
664
- // Start the scatter of the x - the kokkos memtype is set as PETSC_MEMTYPE_HOST or
650
+ // Start the scatter of the cf splitting - the kokkos memtype is set as PETSC_MEMTYPE_HOST or
665
651
// one of the kokkos backends like PETSC_MEMTYPE_HIP
666
- PetscSFBcastWithMemTypeBegin (mat_mpi->Mvctx , MPIU_SCALAR,
667
- mem_type, x_d_ptr,
668
- mem_type, lvec_d_ptr,
652
+ // Be careful: the cf markers are plain ints, not PetscInts, so the MPI datatype must be MPI_INT
653
+ PetscSFBcastWithMemTypeBegin (mat_mpi->Mvctx , MPI_INT,
654
+ mem_type, cf_markers_d_ptr,
655
+ mem_type, cf_markers_nonlocal_d_ptr,
669
656
MPI_REPLACE);
670
657
}
671
658
@@ -704,7 +691,7 @@ PETSC_INTERN void MatDiagDomRatio_kokkos(Mat *input_mat, PetscIntKokkosView &is_
704
691
[&](const PetscInt j, PetscScalar& thread_sum) {
705
692
706
693
// Get this local column in the input_mat
707
- PetscInt target_col = device_local_j[device_local_i[i] + j];
694
+ const PetscInt target_col = device_local_j[device_local_i[i] + j];
708
695
// Is this column fine? F_POINT == -1
709
696
if (cf_markers_d (target_col) == -1 )
710
697
{
@@ -744,11 +731,9 @@ PETSC_INTERN void MatDiagDomRatio_kokkos(Mat *input_mat, PetscIntKokkosView &is_
744
731
// Basically a copy of MatCreateSubMatrix_MPIAIJ_SameRowColDist
745
732
if (mpi)
746
733
{
747
- // Finish the x scatter
748
- PetscSFBcastEnd (mat_mpi->Mvctx , MPIU_SCALAR, x_d_ptr, lvec_d_ptr, MPI_REPLACE);
749
- // We're done with x now
750
- VecRestoreKokkosViewWrite (x, &x_d);
751
- VecDestroy (&x);
734
+ // Finish the scatter of the cf splitting
735
+ // Be careful: the cf markers are plain ints, not PetscInts, so the MPI datatype must be MPI_INT
736
+ PetscSFBcastEnd (mat_mpi->Mvctx , MPI_INT, cf_markers_d_ptr, cf_markers_nonlocal_d_ptr, MPI_REPLACE);
752
737
753
738
// ~~~~~~~~~~~~
754
739
// Get pointers to the nonlocal i,j,vals on the device
@@ -776,9 +761,9 @@ PETSC_INTERN void MatDiagDomRatio_kokkos(Mat *input_mat, PetscIntKokkosView &is_
776
761
[&](const PetscInt j, PetscScalar& thread_sum) {
777
762
778
763
// This is the non-local column we have to check is present
779
- PetscInt target_col = device_nonlocal_j[device_nonlocal_i[i] + j];
764
+ const PetscInt target_col = device_nonlocal_j[device_nonlocal_i[i] + j];
780
765
// Is this column in the input IS? F_POINT == -1
781
- if (lvec_d_ptr[ target_col] < 0.0 )
766
+ if (cf_markers_nonlocal_d ( target_col) == - 1 )
782
767
{
783
768
// Get the abs value of the entry
784
769
thread_sum += Kokkos::abs (device_nonlocal_vals[device_nonlocal_i[i] + j]);
@@ -794,8 +779,6 @@ PETSC_INTERN void MatDiagDomRatio_kokkos(Mat *input_mat, PetscIntKokkosView &is_
794
779
});
795
780
});
796
781
}
797
-
798
- VecRestoreKokkosViewWrite (lvec, &lvec_d);
799
782
}
800
783
801
784
// ~~~~~~~~~~~~~
0 commit comments