Skip to content

Commit

Permalink
Merge pull request lammps#4092 from stanmoore1/comm_tiled
Browse files Browse the repository at this point in the history
Add better Kokkos support for comm_style tiled
  • Loading branch information
akohlmey authored Mar 1, 2024
2 parents b0ca503 + 6f03b22 commit 554f53d
Show file tree
Hide file tree
Showing 59 changed files with 1,328 additions and 1,402 deletions.
310 changes: 10 additions & 300 deletions src/KOKKOS/atom_vec_angle_kokkos.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -186,302 +186,13 @@ void AtomVecAngleKokkos::sort_kokkos(Kokkos::BinSort<KeyViewType, BinOp> &Sorter

/* ---------------------------------------------------------------------- */

// Kokkos functor that packs atom coordinates for one forward-communication
// swap.  Work item i looks up atom index j = _list(_iswap,i) and writes the
// three components of _x(j,*) into row i of _buf.
//
// Template parameters:
//   DeviceType - selects which view of x, list and buf is used
//   PBC_FLAG   - 0: coordinates are copied unchanged;
//                otherwise a shift built from _pbc is added
//   TRICLINIC  - only consulted when PBC_FLAG != 0;
//                0: shift uses _xprd/_yprd/_zprd only,
//                otherwise the _xy/_xz/_yz terms are added as well

template<class DeviceType,int PBC_FLAG,int TRICLINIC>
struct AtomVecAngleKokkos_PackComm {
  using device_type = DeviceType;

  typename ArrayTypes<DeviceType>::t_x_array_randomread _x;   // source coordinates (read only)
  typename ArrayTypes<DeviceType>::t_xfloat_2d_um _buf;       // destination, set up in the constructor body
  typename ArrayTypes<DeviceType>::t_int_2d_const _list;      // send list, indexed as (swap,item)
  const int _iswap;                                           // row of _list used by this functor
  X_FLOAT _xprd,_yprd,_zprd,_xy,_xz,_yz;
  X_FLOAT _pbc[6];                                            // integer pbc flags stored as X_FLOAT

  // x, buf and list are the dual views to operate on; iswap selects the row
  // of list; xprd..yz and pbc[0..5] are the values used to build the shift.
  // All six pbc entries are read even when PBC_FLAG == 0.

  AtomVecAngleKokkos_PackComm(
      const typename DAT::tdual_x_array &x,
      const typename DAT::tdual_xfloat_2d &buf,
      const typename DAT::tdual_int_2d &list,
      const int & iswap,
      const X_FLOAT &xprd, const X_FLOAT &yprd, const X_FLOAT &zprd,
      const X_FLOAT &xy, const X_FLOAT &xz, const X_FLOAT &yz, const int* const pbc)
    : _x(x.view<DeviceType>()),
      _list(list.view<DeviceType>()),
      _iswap(iswap),
      _xprd(xprd),_yprd(yprd),_zprd(zprd),
      _xy(xy),_xz(xz),_yz(yz)
  {
    // total number of scalars in buf, regrouped as rows of three values
    // NOTE(review): buffer_view presumably aliases buf's storage as an
    // unmanaged maxsend x 3 view -- confirm against its definition
    const size_t elements = 3;
    const auto full = buf.view<DeviceType>();
    const size_t maxsend = (full.extent(0)*full.extent(1))/elements;
    buffer_view<DeviceType>(_buf,buf,maxsend,elements);

    for (int k = 0; k < 6; ++k) _pbc[k] = pbc[k];
  }

  KOKKOS_INLINE_FUNCTION
  void operator() (const int& i) const {
    const int j = _list(_iswap,i);

    // no shift requested: straight copy of the three components
    if (PBC_FLAG == 0) {
      for (int d = 0; d < 3; ++d) _buf(i,d) = _x(j,d);
      return;
    }

    // shift without tilt terms
    if (TRICLINIC == 0) {
      _buf(i,0) = _x(j,0) + _pbc[0]*_xprd;
      _buf(i,1) = _x(j,1) + _pbc[1]*_yprd;
      _buf(i,2) = _x(j,2) + _pbc[2]*_zprd;
      return;
    }

    // shift including the xy/xz/yz contributions
    _buf(i,0) = _x(j,0) + _pbc[0]*_xprd + _pbc[5]*_xy + _pbc[4]*_xz;
    _buf(i,1) = _x(j,1) + _pbc[1]*_yprd + _pbc[3]*_yz;
    _buf(i,2) = _x(j,2) + _pbc[2]*_zprd;
  }
};

/* ---------------------------------------------------------------------- */

// Pack the coordinates of n atoms (row iswap of list) into buf by launching
// the AtomVecAngleKokkos_PackComm functor.
//
//   n        - number of work items handed to Kokkos::parallel_for
//   list     - send list passed through to the functor
//   iswap    - row of list to use
//   buf      - destination buffer
//   pbc_flag - non-zero selects the functor variants that add a shift
//   pbc      - six flags forwarded to the functor
//
// Returns n*size_forward.

int AtomVecAngleKokkos::pack_comm_kokkos(const int &n,
                                         const DAT::tdual_int_2d &list,
                                         const int & iswap,
                                         const DAT::tdual_xfloat_2d &buf,
                                         const int &pbc_flag,
                                         const int* const pbc)
{
  // forward communication may be forced onto the host; otherwise use the device
  const bool on_host = commKK->forward_comm_on_host;

  // sync x in the space the kernel is going to run in
  atomKK->sync(on_host ? Host : Device,X_MASK);

  // the two runtime switches that pick the compile-time functor variant
  const bool shifted = (pbc_flag != 0);
  const bool tilted = (domain->triclinic != 0);

  if (on_host) {
    if (shifted && tilted) {
      AtomVecAngleKokkos_PackComm<LMPHostType,1,1> f(atomKK->k_x,buf,list,iswap,
          domain->xprd,domain->yprd,domain->zprd,
          domain->xy,domain->xz,domain->yz,pbc);
      Kokkos::parallel_for(n,f);
    } else if (shifted) {
      AtomVecAngleKokkos_PackComm<LMPHostType,1,0> f(atomKK->k_x,buf,list,iswap,
          domain->xprd,domain->yprd,domain->zprd,
          domain->xy,domain->xz,domain->yz,pbc);
      Kokkos::parallel_for(n,f);
    } else if (tilted) {
      AtomVecAngleKokkos_PackComm<LMPHostType,0,1> f(atomKK->k_x,buf,list,iswap,
          domain->xprd,domain->yprd,domain->zprd,
          domain->xy,domain->xz,domain->yz,pbc);
      Kokkos::parallel_for(n,f);
    } else {
      AtomVecAngleKokkos_PackComm<LMPHostType,0,0> f(atomKK->k_x,buf,list,iswap,
          domain->xprd,domain->yprd,domain->zprd,
          domain->xy,domain->xz,domain->yz,pbc);
      Kokkos::parallel_for(n,f);
    }
  } else {
    if (shifted && tilted) {
      AtomVecAngleKokkos_PackComm<LMPDeviceType,1,1> f(atomKK->k_x,buf,list,iswap,
          domain->xprd,domain->yprd,domain->zprd,
          domain->xy,domain->xz,domain->yz,pbc);
      Kokkos::parallel_for(n,f);
    } else if (shifted) {
      AtomVecAngleKokkos_PackComm<LMPDeviceType,1,0> f(atomKK->k_x,buf,list,iswap,
          domain->xprd,domain->yprd,domain->zprd,
          domain->xy,domain->xz,domain->yz,pbc);
      Kokkos::parallel_for(n,f);
    } else if (tilted) {
      AtomVecAngleKokkos_PackComm<LMPDeviceType,0,1> f(atomKK->k_x,buf,list,iswap,
          domain->xprd,domain->yprd,domain->zprd,
          domain->xy,domain->xz,domain->yz,pbc);
      Kokkos::parallel_for(n,f);
    } else {
      AtomVecAngleKokkos_PackComm<LMPDeviceType,0,0> f(atomKK->k_x,buf,list,iswap,
          domain->xprd,domain->yprd,domain->zprd,
          domain->xy,domain->xz,domain->yz,pbc);
      Kokkos::parallel_for(n,f);
    }
  }

  return n*size_forward;
}

/* ---------------------------------------------------------------------- */

template<class DeviceType,int PBC_FLAG,int TRICLINIC>
struct AtomVecAngleKokkos_PackCommSelf {
typedef DeviceType device_type;

typename ArrayTypes<DeviceType>::t_x_array_randomread _x;
typename ArrayTypes<DeviceType>::t_x_array _xw;
int _nfirst;
typename ArrayTypes<DeviceType>::t_int_2d_const _list;
const int _iswap;
X_FLOAT _xprd,_yprd,_zprd,_xy,_xz,_yz;
X_FLOAT _pbc[6];

AtomVecAngleKokkos_PackCommSelf(
const typename DAT::tdual_x_array &x,
const int &nfirst,
const typename DAT::tdual_int_2d &list,
const int & iswap,
const X_FLOAT &xprd, const X_FLOAT &yprd, const X_FLOAT &zprd,
const X_FLOAT &xy, const X_FLOAT &xz, const X_FLOAT &yz, const int* const pbc):
_x(x.view<DeviceType>()),_xw(x.view<DeviceType>()),_nfirst(nfirst),_list(list.view<DeviceType>()),_iswap(iswap),
_xprd(xprd),_yprd(yprd),_zprd(zprd),
_xy(xy),_xz(xz),_yz(yz) {
_pbc[0] = pbc[0]; _pbc[1] = pbc[1]; _pbc[2] = pbc[2];
_pbc[3] = pbc[3]; _pbc[4] = pbc[4]; _pbc[5] = pbc[5];
};

KOKKOS_INLINE_FUNCTION
void operator() (const int& i) const {
const int j = _list(_iswap,i);
if (PBC_FLAG == 0) {
_xw(i+_nfirst,0) = _x(j,0);
_xw(i+_nfirst,1) = _x(j,1);
_xw(i+_nfirst,2) = _x(j,2);
} else {
if (TRICLINIC == 0) {
_xw(i+_nfirst,0) = _x(j,0) + _pbc[0]*_xprd;
_xw(i+_nfirst,1) = _x(j,1) + _pbc[1]*_yprd;
_xw(i+_nfirst,2) = _x(j,2) + _pbc[2]*_zprd;
} else {
_xw(i+_nfirst,0) = _x(j,0) + _pbc[0]*_xprd + _pbc[5]*_xy + _pbc[4]*_xz;
_xw(i+_nfirst,1) = _x(j,1) + _pbc[1]*_yprd + _pbc[3]*_yz;
_xw(i+_nfirst,2) = _x(j,2) + _pbc[2]*_zprd;
}
}

}
};

/* ---------------------------------------------------------------------- */

// Copy the coordinates of n atoms (row iswap of list) to rows starting at
// nfirst of the same coordinate array by launching the
// AtomVecAngleKokkos_PackCommSelf functor.
//
//   n        - number of work items handed to Kokkos::parallel_for
//   list     - send list passed through to the functor
//   iswap    - row of list to use
//   nfirst   - offset of the first destination row
//   pbc_flag - non-zero selects the functor variants that add a shift
//   pbc      - six flags forwarded to the functor
//
// Returns n*3.

int AtomVecAngleKokkos::pack_comm_self(const int &n, const DAT::tdual_int_2d &list,
                                       const int & iswap,
                                       const int nfirst, const int &pbc_flag,
                                       const int* const pbc) {
  // forward communication may be forced onto the host; otherwise use the device
  const bool on_host = commKK->forward_comm_on_host;

  // x is both read and written by the kernel: sync it in the chosen space
  // and flag it as modified there
  atomKK->sync(on_host ? Host : Device,X_MASK);
  atomKK->modified(on_host ? Host : Device,X_MASK);

  // the two runtime switches that pick the compile-time functor variant
  const bool shifted = (pbc_flag != 0);
  const bool tilted = (domain->triclinic != 0);

  if (on_host) {
    if (shifted && tilted) {
      AtomVecAngleKokkos_PackCommSelf<LMPHostType,1,1> f(atomKK->k_x,nfirst,list,iswap,
          domain->xprd,domain->yprd,domain->zprd,
          domain->xy,domain->xz,domain->yz,pbc);
      Kokkos::parallel_for(n,f);
    } else if (shifted) {
      AtomVecAngleKokkos_PackCommSelf<LMPHostType,1,0> f(atomKK->k_x,nfirst,list,iswap,
          domain->xprd,domain->yprd,domain->zprd,
          domain->xy,domain->xz,domain->yz,pbc);
      Kokkos::parallel_for(n,f);
    } else if (tilted) {
      AtomVecAngleKokkos_PackCommSelf<LMPHostType,0,1> f(atomKK->k_x,nfirst,list,iswap,
          domain->xprd,domain->yprd,domain->zprd,
          domain->xy,domain->xz,domain->yz,pbc);
      Kokkos::parallel_for(n,f);
    } else {
      AtomVecAngleKokkos_PackCommSelf<LMPHostType,0,0> f(atomKK->k_x,nfirst,list,iswap,
          domain->xprd,domain->yprd,domain->zprd,
          domain->xy,domain->xz,domain->yz,pbc);
      Kokkos::parallel_for(n,f);
    }
  } else {
    if (shifted && tilted) {
      AtomVecAngleKokkos_PackCommSelf<LMPDeviceType,1,1> f(atomKK->k_x,nfirst,list,iswap,
          domain->xprd,domain->yprd,domain->zprd,
          domain->xy,domain->xz,domain->yz,pbc);
      Kokkos::parallel_for(n,f);
    } else if (shifted) {
      AtomVecAngleKokkos_PackCommSelf<LMPDeviceType,1,0> f(atomKK->k_x,nfirst,list,iswap,
          domain->xprd,domain->yprd,domain->zprd,
          domain->xy,domain->xz,domain->yz,pbc);
      Kokkos::parallel_for(n,f);
    } else if (tilted) {
      AtomVecAngleKokkos_PackCommSelf<LMPDeviceType,0,1> f(atomKK->k_x,nfirst,list,iswap,
          domain->xprd,domain->yprd,domain->zprd,
          domain->xy,domain->xz,domain->yz,pbc);
      Kokkos::parallel_for(n,f);
    } else {
      AtomVecAngleKokkos_PackCommSelf<LMPDeviceType,0,0> f(atomKK->k_x,nfirst,list,iswap,
          domain->xprd,domain->yprd,domain->zprd,
          domain->xy,domain->xz,domain->yz,pbc);
      Kokkos::parallel_for(n,f);
    }
  }

  return n*3;
}

/* ---------------------------------------------------------------------- */

// Kokkos functor that unpacks received coordinates: work item i copies
// columns 0..2 of row i of _buf into row i+_first of _x.
//
// NOTE(review): _buf is taken with the shape buf already has, whereas the
// matching pack functor regroups its buffer into rows of three values --
// confirm callers hand in a buffer whose rows line up with that layout.

template<class DeviceType>
struct AtomVecAngleKokkos_UnpackComm {
  using device_type = DeviceType;

  typename ArrayTypes<DeviceType>::t_x_array _x;              // destination coordinates
  typename ArrayTypes<DeviceType>::t_xfloat_2d_const _buf;    // source buffer (read only)
  int _first;                                                 // row offset of the first written atom

  // x is the coordinate dual view, buf the receive buffer,
  // first the index of the first atom to overwrite.

  AtomVecAngleKokkos_UnpackComm(
      const typename DAT::tdual_x_array &x,
      const typename DAT::tdual_xfloat_2d &buf,
      const int& first)
    : _x(x.view<DeviceType>()),
      _buf(buf.view<DeviceType>()),
      _first(first) {}

  KOKKOS_INLINE_FUNCTION
  void operator() (const int& i) const {
    const int dest = i+_first;
    for (int d = 0; d < 3; ++d) _x(dest,d) = _buf(i,d);
  }
};

/* ---------------------------------------------------------------------- */

// Unpack n received coordinate triples from buf into the coordinate array,
// starting at atom index first, by launching AtomVecAngleKokkos_UnpackComm.
// x is synced and flagged as modified in the space where the kernel runs.

void AtomVecAngleKokkos::unpack_comm_kokkos(const int &n, const int &first,
                                            const DAT::tdual_xfloat_2d &buf) {
  // forward communication forced onto the host
  if (commKK->forward_comm_on_host) {
    atomKK->sync(Host,X_MASK);
    atomKK->modified(Host,X_MASK);
    AtomVecAngleKokkos_UnpackComm<LMPHostType> host_unpack(atomKK->k_x,buf,first);
    Kokkos::parallel_for(n,host_unpack);
    return;
  }

  // default: run on the device
  atomKK->sync(Device,X_MASK);
  atomKK->modified(Device,X_MASK);
  AtomVecAngleKokkos_UnpackComm<LMPDeviceType> device_unpack(atomKK->k_x,buf,first);
  Kokkos::parallel_for(n,device_unpack);
}

/* ---------------------------------------------------------------------- */

template<class DeviceType,int PBC_FLAG>
struct AtomVecAngleKokkos_PackBorder {
typedef DeviceType device_type;
typedef ArrayTypes<DeviceType> AT;

typename AT::t_xfloat_2d _buf;
const typename AT::t_int_2d_const _list;
const int _iswap;
const typename AT::t_int_1d_const _list;
const typename AT::t_x_array_randomread _x;
const typename AT::t_tagint_1d _tag;
const typename AT::t_int_1d _type;
Expand All @@ -491,21 +202,20 @@ struct AtomVecAngleKokkos_PackBorder {

AtomVecAngleKokkos_PackBorder(
const typename AT::t_xfloat_2d &buf,
const typename AT::t_int_2d_const &list,
const int & iswap,
const typename AT::t_int_1d_const &list,
const typename AT::t_x_array &x,
const typename AT::t_tagint_1d &tag,
const typename AT::t_int_1d &type,
const typename AT::t_int_1d &mask,
const typename AT::t_tagint_1d &molecule,
const X_FLOAT &dx, const X_FLOAT &dy, const X_FLOAT &dz):
_buf(buf),_list(list),_iswap(iswap),
_buf(buf),_list(list),
_x(x),_tag(tag),_type(type),_mask(mask),_molecule(molecule),
_dx(dx),_dy(dy),_dz(dz) {}

KOKKOS_INLINE_FUNCTION
void operator() (const int& i) const {
const int j = _list(_iswap,i);
const int j = _list(i);
if (PBC_FLAG == 0) {
_buf(i,0) = _x(j,0);
_buf(i,1) = _x(j,1);
Expand All @@ -528,8 +238,8 @@ struct AtomVecAngleKokkos_PackBorder {

/* ---------------------------------------------------------------------- */

int AtomVecAngleKokkos::pack_border_kokkos(int n, DAT::tdual_int_2d k_sendlist,
DAT::tdual_xfloat_2d buf,int iswap,
int AtomVecAngleKokkos::pack_border_kokkos(int n, DAT::tdual_int_1d k_sendlist,
DAT::tdual_xfloat_2d buf,
int pbc_flag, int *pbc, ExecutionSpace space)
{
X_FLOAT dx,dy,dz;
Expand All @@ -547,12 +257,12 @@ int AtomVecAngleKokkos::pack_border_kokkos(int n, DAT::tdual_int_2d k_sendlist,
if (space==Host) {
AtomVecAngleKokkos_PackBorder<LMPHostType,1> f(
buf.view<LMPHostType>(), k_sendlist.view<LMPHostType>(),
iswap,h_x,h_tag,h_type,h_mask,h_molecule,dx,dy,dz);
h_x,h_tag,h_type,h_mask,h_molecule,dx,dy,dz);
Kokkos::parallel_for(n,f);
} else {
AtomVecAngleKokkos_PackBorder<LMPDeviceType,1> f(
buf.view<LMPDeviceType>(), k_sendlist.view<LMPDeviceType>(),
iswap,d_x,d_tag,d_type,d_mask,d_molecule,dx,dy,dz);
d_x,d_tag,d_type,d_mask,d_molecule,dx,dy,dz);
Kokkos::parallel_for(n,f);
}

Expand All @@ -561,12 +271,12 @@ int AtomVecAngleKokkos::pack_border_kokkos(int n, DAT::tdual_int_2d k_sendlist,
if (space==Host) {
AtomVecAngleKokkos_PackBorder<LMPHostType,0> f(
buf.view<LMPHostType>(), k_sendlist.view<LMPHostType>(),
iswap,h_x,h_tag,h_type,h_mask,h_molecule,dx,dy,dz);
h_x,h_tag,h_type,h_mask,h_molecule,dx,dy,dz);
Kokkos::parallel_for(n,f);
} else {
AtomVecAngleKokkos_PackBorder<LMPDeviceType,0> f(
buf.view<LMPDeviceType>(), k_sendlist.view<LMPDeviceType>(),
iswap,d_x,d_tag,d_type,d_mask,d_molecule,dx,dy,dz);
d_x,d_tag,d_type,d_mask,d_molecule,dx,dy,dz);
Kokkos::parallel_for(n,f);
}
}
Expand Down
13 changes: 2 additions & 11 deletions src/KOKKOS/atom_vec_angle_kokkos.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,17 +35,8 @@ class AtomVecAngleKokkos : public AtomVecKokkos, public AtomVecAngle {
void grow(int) override;
void grow_pointers() override;
void sort_kokkos(Kokkos::BinSort<KeyViewType, BinOp> &Sorter) override;
int pack_comm_kokkos(const int &n, const DAT::tdual_int_2d &k_sendlist,
const int & iswap,
const DAT::tdual_xfloat_2d &buf,
const int &pbc_flag, const int pbc[]) override;
void unpack_comm_kokkos(const int &n, const int &nfirst,
const DAT::tdual_xfloat_2d &buf) override;
int pack_comm_self(const int &n, const DAT::tdual_int_2d &list,
const int & iswap, const int nfirst,
const int &pbc_flag, const int pbc[]) override;
int pack_border_kokkos(int n, DAT::tdual_int_2d k_sendlist,
DAT::tdual_xfloat_2d buf,int iswap,
int pack_border_kokkos(int n, DAT::tdual_int_1d k_sendlist,
DAT::tdual_xfloat_2d buf,
int pbc_flag, int *pbc, ExecutionSpace space) override;
void unpack_border_kokkos(const int &n, const int &nfirst,
const DAT::tdual_xfloat_2d &buf,
Expand Down
Loading

0 comments on commit 554f53d

Please sign in to comment.