binding/c: add stream synchronize in pt2pt and coll
If a stream communicator is backed by a GPU stream and we call the
regular pt2pt and collective functions instead of the enqueue
functions, we should run a stream synchronize to ensure pending
operations on the buffer have completed before the pt2pt and coll
operations access it.
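
A minimal sketch of the scenario, assuming CUDA and the MPIX stream
extension (fill_kernel, d_buf, grid, block, and count are hypothetical
placeholders):

    cudaStream_t cu_stream;
    cudaStreamCreate(&cu_stream);

    /* attach the CUDA stream to an MPIX stream */
    MPI_Info info;
    MPI_Info_create(&info);
    MPI_Info_set(info, "type", "cudaStream_t");
    MPIX_Info_set_hex(info, "value", &cu_stream, sizeof(cu_stream));
    MPIX_Stream mpi_stream;
    MPIX_Stream_create(info, &mpi_stream);
    MPI_Info_free(&info);

    MPI_Comm stream_comm;
    MPIX_Stream_comm_create(MPI_COMM_WORLD, mpi_stream, &stream_comm);

    /* kernel writing d_buf may still be in flight on cu_stream */
    fill_kernel<<<grid, block, 0, cu_stream>>>(d_buf, count);

    /* regular (non-enqueue) send on the GPU-stream-backed communicator:
     * the binding now runs a stream synchronize first, so d_buf is
     * complete before it is read */
    MPI_Send(d_buf, count, MPI_INT, 1, 0, stream_comm);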

There is no need to stream synchronize in the completion functions,
i.e. Test and Wait: buffer safety is already asserted by the
nonblocking semantics, and offloading calls issued after the completion
function can safely use the buffer.
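
A minimal sketch of the exempt case, reusing stream_comm and cu_stream
from above (consume_kernel and d_recv are hypothetical):

    MPI_Request req;
    MPI_Irecv(d_recv, count, MPI_INT, 0, 0, stream_comm, &req);
    MPI_Wait(&req, MPI_STATUS_IGNORE);   /* no stream synchronize here */

    /* enqueued after MPI_Wait returns, so it observes the received data */
    consume_kernel<<<grid, block, 0, cu_stream>>>(d_recv, count);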

Amend: add RMA operations too. For the same reason, we don't need a
stream synchronize for Win_fence, Win_lock, etc.; the RMA
synchronization calls are essentially the host-side counterpart to
stream synchronize.
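
A minimal sketch of the RMA case, assuming win is a window created over
stream_comm (update_kernel and origin_buf are hypothetical):

    MPI_Win_fence(0, win);   /* host-side sync; no stream synchronize needed */

    /* kernel writing origin_buf may still be in flight on cu_stream */
    update_kernel<<<grid, block, 0, cu_stream>>>(origin_buf, count);

    /* Put now runs a stream synchronize, so origin_buf is complete
     * before it is read */
    MPI_Put(origin_buf, count, MPI_INT, 1, 0, count, MPI_INT, win);

    MPI_Win_fence(0, win);
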
hzhou committed Jun 5, 2024
1 parent 8018793 commit 61cbd22
Showing 3 changed files with 63 additions and 21 deletions.
54 changes: 37 additions & 17 deletions src/binding/c/coll_api.txt
@@ -2,7 +2,7 @@

MPI_Allgather:
.desc: Gathers data from all tasks and distribute the combined data to all tasks
.extra: threadcomm
.extra: threadcomm, streamsync
/*
Notes:
The MPI standard (1.0 and 1.1) says that
@@ -35,7 +35,7 @@ MPI_Allgather_init:

MPI_Allgatherv:
.desc: Gathers data from all tasks and deliver the combined data to all tasks
.extra: threadcomm
.extra: threadcomm, streamsync
/*
Notes:
The MPI standard (1.0 and 1.1) says that
@@ -75,7 +75,7 @@ MPI_Allgatherv_init:
MPI_Allreduce:
.desc: Combines values from all processes and distributes the result back to all processes
.docnotes: collops
.extra: threadcomm
.extra: threadcomm, streamsync
{ -- early_return --
if (MPIR_is_self_comm(comm_ptr)) {
if (sendbuf != MPI_IN_PLACE) {
@@ -90,7 +90,7 @@ MPI_Allreduce_init:

MPI_Alltoall:
.desc: Sends data from all to all processes
.extra: threadcomm
.extra: threadcomm, streamsync
{ -- early_return --
if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM && recvcount == 0) {
goto fn_exit;
@@ -102,21 +102,21 @@ MPI_Alltoall_init:

MPI_Alltoallv:
.desc: Sends data from all to all processes; each process may send a different amount of data and provide displacements for the input and output data.
.extra: threadcomm
.extra: threadcomm, streamsync

MPI_Alltoallv_init:
.desc: Create a persistent request for alltoallv.

MPI_Alltoallw:
.desc: Generalized all-to-all communication allowing different datatypes, counts, and displacements for each partner
.extra: threadcomm
.extra: threadcomm, streamsync

MPI_Alltoallw_init:
.desc: Create a persistent request for alltoallw.

MPI_Barrier:
.desc: Blocks until all processes in the communicator have reached this routine.
.extra: threadcomm
.extra: threadcomm, streamsync
/*
Notes:
Blocks the caller until all processes in the communicator have called it;
@@ -134,7 +134,7 @@ MPI_Barrier_init:

MPI_Bcast:
.desc: Broadcasts a message from the process with rank "root" to all other processes of the communicator
.extra: threadcomm
.extra: threadcomm, streamsync
{ -- early_return --
if (count == 0 || MPIR_is_self_comm(comm_ptr)) {
goto fn_exit;
@@ -165,7 +165,7 @@ MPI_Exscan_init:

MPI_Gather:
.desc: Gathers together values from a group of processes
.extra: threadcomm
.extra: threadcomm, streamsync
{ -- early_return --
if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) {
if ((MPIR_Comm_rank(comm_ptr) == root && recvcount == 0) || (MPIR_Comm_rank(comm_ptr) != root && sendcount == 0)) {
@@ -179,13 +179,14 @@ MPI_Gather_init:

MPI_Gatherv:
.desc: Gathers into specified locations from all processes in a group
.extra: threadcomm
.extra: threadcomm, streamsync

MPI_Gatherv_init:
.desc: Create a persistent request for gatherv.

MPI_Iallgather:
.desc: Gathers data from all tasks and distribute the combined data to all tasks in a nonblocking way
.extra: streamsync
{ -- early_return --
if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) {
if ((sendcount == 0 && sendbuf != MPI_IN_PLACE) || recvcount == 0) {
@@ -198,6 +199,7 @@ MPI_Iallgather:

MPI_Iallgatherv:
.desc: Gathers data from all tasks and deliver the combined data to all tasks in a nonblocking way
.extra: streamsync
{ -- early_return --
if (MPIR_is_self_comm(comm_ptr)) {
if (sendbuf != MPI_IN_PLACE) {
@@ -214,6 +216,7 @@ MPI_Iallgatherv:

MPI_Iallreduce:
.desc: Combines values from all processes and distributes the result back to all processes in a nonblocking way
.extra: streamsync
{ -- early_return --
if (MPIR_is_self_comm(comm_ptr)) {
if (sendbuf != MPI_IN_PLACE) {
@@ -227,6 +230,7 @@ MPI_Iallreduce:

MPI_Ialltoall:
.desc: Sends data from all to all processes in a nonblocking way
.extra: streamsync
{ -- early_return --
if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM && recvcount == 0) {
MPIR_Request *request_ptr = MPIR_Request_create_complete(MPIR_REQUEST_KIND__COLL);
@@ -237,12 +241,15 @@ MPI_Ialltoall:

MPI_Ialltoallv:
.desc: Sends data from all to all processes in a nonblocking way; each process may send a different amount of data and provide displacements for the input and output data.
.extra: streamsync

MPI_Ialltoallw:
.desc: Nonblocking generalized all-to-all communication allowing different datatypes, counts, and displacements for each partner
.extra: streamsync

MPI_Ibarrier:
.desc: Notifies the process that it has reached the barrier and returns immediately
.extra: streamsync
/*
Notes:
MPI_Ibarrier is a nonblocking version of MPI_barrier. By calling MPI_Ibarrier,
@@ -264,6 +271,7 @@ MPI_Ibarrier:

MPI_Ibcast:
.desc: Broadcasts a message from the process with rank "root" to all other processes of the communicator in a nonblocking way
.extra: streamsync
{ -- early_return --
if (count == 0 || MPIR_is_self_comm(comm_ptr)) {
MPIR_Request *request_ptr = MPIR_Request_create_complete(MPIR_REQUEST_KIND__COLL);
@@ -275,7 +283,7 @@ MPI_Ibcast:
MPI_Iexscan:
.desc: Computes the exclusive scan (partial reductions) of data on a collection of processes in a nonblocking way
.docnotes: collops
.extra: errtest_comm_intra
.extra: errtest_comm_intra, streamsync
{ -- early_return --
if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM && count == 0) {
MPIR_Request *request_ptr = MPIR_Request_create_complete(MPIR_REQUEST_KIND__COLL);
@@ -286,6 +294,7 @@ MPI_Iexscan:

MPI_Igather:
.desc: Gathers together values from a group of processes in a nonblocking way
.extra: streamsync
{ -- early_return --
if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) {
if ((MPIR_Comm_rank(comm_ptr) == root && recvcount == 0) || (MPIR_Comm_rank(comm_ptr) != root && sendcount == 0)) {
@@ -298,24 +307,31 @@ MPI_Igather:

MPI_Igatherv:
.desc: Gathers into specified locations from all processes in a group in a nonblocking way
.extra: streamsync

MPI_Ineighbor_allgather:
.desc: Nonblocking version of MPI_Neighbor_allgather.
.extra: streamsync

MPI_Ineighbor_allgatherv:
.desc: Nonblocking version of MPI_Neighbor_allgatherv.
.extra: streamsync

MPI_Ineighbor_alltoall:
.desc: Nonblocking version of MPI_Neighbor_alltoall.
.extra: streamsync

MPI_Ineighbor_alltoallv:
.desc: Nonblocking version of MPI_Neighbor_alltoallv.
.extra: streamsync

MPI_Ineighbor_alltoallw:
.desc: Nonblocking version of MPI_Neighbor_alltoallw.
.extra: streamsync

MPI_Ireduce:
.desc: Reduces values on all processes to a single value in a nonblocking way
.extra: streamsync
{ -- early_return --
if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM && (count == 0 || MPIR_is_self_comm(comm_ptr))) {
if (sendbuf != MPI_IN_PLACE) {
@@ -329,6 +345,7 @@ MPI_Ireduce:

MPI_Ireduce_scatter:
.desc: Combines values and scatters the results in a nonblocking way
.extra: streamsync
{ -- early_return --
if (MPIR_is_self_comm(comm_ptr)) {
if (sendbuf != MPI_IN_PLACE) {
@@ -342,6 +359,7 @@ MPI_Ireduce_scatter:

MPI_Ireduce_scatter_block:
.desc: Combines values and scatters the results in a nonblocking way
.extra: streamsync
{ -- early_return --
if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM && (recvcount == 0 || MPIR_is_self_comm(comm_ptr))) {
if (sendbuf != MPI_IN_PLACE) {
@@ -356,7 +374,7 @@ MPI_Ireduce_scatter_block:
MPI_Iscan:
.desc: Computes the scan (partial reductions) of data on a collection of processes in a nonblocking way
.docnotes: collops
.extra: errtest_comm_intra
.extra: errtest_comm_intra, streamsync
{ -- early_return --
if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM && count == 0) {
MPIR_Request *request_ptr = MPIR_Request_create_complete(MPIR_REQUEST_KIND__COLL);
@@ -367,6 +385,7 @@ MPI_Iscan:

MPI_Iscatter:
.desc: Sends data from one process to all other processes in a communicator in a nonblocking way
.extra: streamsync
{ -- early_return --
if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) {
if ((MPIR_Comm_rank(comm_ptr) == root && sendcount == 0) || (MPIR_Comm_rank(comm_ptr) != root && recvcount == 0)) {
@@ -379,6 +398,7 @@ MPI_Iscatter:

MPI_Iscatterv:
.desc: Scatters a buffer in parts to all processes in a communicator in a nonblocking way
.extra: streamsync

MPI_Neighbor_allgather:
.desc: Gathers data from all neighboring processes and distribute the combined data to all neighboring processes
@@ -422,7 +442,7 @@ MPI_Neighbor_alltoallw_init:
MPI_Reduce:
.desc: Reduces values on all processes to a single value
.docnotes: collops
.extra: threadcomm
.extra: threadcomm, streamsync
{ -- early_return --
if (count == 0 && comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) {
goto fn_exit;
@@ -456,7 +476,7 @@ MPI_Reduce_local:
MPI_Reduce_scatter:
.desc: Combines values and scatters the results
.docnotes: collops
.extra: threadcomm
.extra: threadcomm, streamsync
{ -- early_return --
if (MPIR_is_self_comm(comm_ptr)) {
if (sendbuf != MPI_IN_PLACE) {
@@ -472,7 +492,7 @@ MPI_Reduce_scatter_init:
MPI_Reduce_scatter_block:
.desc: Combines values and scatters the results
.docnotes: collops
.extra: threadcomm
.extra: threadcomm, streamsync
{ -- early_return --
if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM && (recvcount == 0 || MPIR_is_self_comm(comm_ptr))) {
if (sendbuf != MPI_IN_PLACE) {
@@ -500,7 +520,7 @@ MPI_Scan_init:

MPI_Scatter:
.desc: Sends data from one process to all other processes in a communicator
.extra: threadcomm
.extra: threadcomm, streamsync
{ -- early_return --
if (comm_ptr->comm_kind == MPIR_COMM_KIND__INTRACOMM) {
if ((MPIR_Comm_rank(comm_ptr) == root && sendcount == 0) || (MPIR_Comm_rank(comm_ptr) != root && recvcount == 0)) {
@@ -514,7 +534,7 @@ MPI_Scatter_init:

MPI_Scatterv:
.desc: Scatters a buffer in parts to all processes in a communicator
.extra: threadcomm
.extra: threadcomm, streamsync

MPI_Scatterv_init:
.desc: Create a persistent request for scatterv.
20 changes: 16 additions & 4 deletions src/binding/c/pt2pt_api.txt
@@ -4,6 +4,7 @@ MPI_Bsend:
.desc: Basic send with user-provided buffering
.seealso: MPI_Buffer_attach, MPI_Ibsend, MPI_Bsend_init
.earlyreturn: pt2pt_proc_null
.extra: streamsync
/*
Notes:
This send is provided as a convenience function; it allows the user to
@@ -174,6 +175,7 @@ MPI_Comm_iflush_buffer:
MPI_Ibsend:
.desc: Starts a nonblocking buffered send
.earlyreturn: pt2pt_proc_null
.extra: streamsync
{
mpi_errno = MPIR_Bsend_isend(buf, count, datatype, dest, tag, comm_ptr, NULL);
if (mpi_errno)
@@ -202,6 +204,7 @@ MPI_Improbe:

MPI_Imrecv:
.desc: Nonblocking receive of message matched by MPI_Mprobe or MPI_Improbe.
.extra: streamsync
{ -- early_return --
if (message_ptr == NULL || message_ptr->handle == MPIR_REQUEST_NULL_RECV) {
MPIR_Request *rreq;
@@ -235,7 +238,7 @@ MPI_Iprobe:
MPI_Irecv:
.desc: Begins a nonblocking receive
.earlyreturn: pt2pt_proc_null
.extra: threadcomm
.extra: threadcomm, streamsync
{
MPIR_Request *request_ptr = NULL;

@@ -258,6 +261,7 @@ MPI_Irecv:
MPI_Irsend:
.desc: Starts a nonblocking ready send
.earlyreturn: pt2pt_proc_null
.extra: streamsync
{
MPIR_Request *request_ptr = NULL;

@@ -275,7 +279,7 @@ MPI_Irsend:
MPI_Isend:
.desc: Begins a nonblocking send
.earlyreturn: pt2pt_proc_null
.extra: threadcomm
.extra: threadcomm, streamsync
{
MPIR_Request *request_ptr = NULL;

@@ -293,6 +297,7 @@ MPI_Isend:
MPI_Issend:
.desc: Starts a nonblocking synchronous send
.earlyreturn: pt2pt_proc_null
.extra: streamsync
{
MPIR_Request *request_ptr = NULL;

@@ -323,6 +328,7 @@ MPI_Mprobe:

MPI_Mrecv:
.desc: Blocking receive of message matched by MPI_Mprobe or MPI_Improbe.
.extra: streamsync
{ -- early_return --
if (message_ptr == NULL || message_ptr->handle == MPIR_REQUEST_NULL_RECV) {
/* treat as though MPI_MESSAGE_NO_PROC was passed */
@@ -358,7 +364,7 @@ MPI_Probe:
MPI_Recv:
.desc: Blocking receive for a message
.earlyreturn: pt2pt_proc_null
.extra: threadcomm
.extra: threadcomm, streamsync
/*
Notes:
The 'count' argument indicates the maximum length of a message; the actual
@@ -401,6 +407,7 @@ MPI_Recv_init:
MPI_Rsend:
.desc: Blocking ready send
.earlyreturn: pt2pt_proc_null
.extra: streamsync
{
MPIR_Request *request_ptr = NULL;

@@ -437,7 +444,7 @@ MPI_Send:
.desc: Performs a blocking send
.seealso: MPI_Isend, MPI_Bsend
.earlyreturn: pt2pt_proc_null
.extra: threadcomm
.extra: threadcomm, streamsync
/*
Notes:
This routine may block until the message is received by the destination
@@ -475,10 +482,12 @@ MPI_Send_init:

MPI_Sendrecv:
.desc: Sends and receives a message
.extra: streamsync

MPI_Sendrecv_replace:
.desc: Sends and receives using a single buffer
.decl: MPIR_Sendrecv_replace_impl
.extra: streamsync
{
#if defined(MPID_Sendrecv_replace)
mpi_errno = MPID_Sendrecv_replace(buf, count, datatype, dest,
@@ -513,6 +522,7 @@ MPI_Session_iflush_buffer:
MPI_Ssend:
.desc: Blocking synchronous send
.earlyreturn: pt2pt_proc_null
.extra: streamsync
{
MPIR_Request *request_ptr = NULL;

@@ -546,6 +556,8 @@ MPI_Ssend_init:

MPI_Isendrecv:
.desc: Starts a nonblocking send and receive
.extra: streamsync

MPI_Isendrecv_replace:
.desc: Starts a nonblocking send and receive with a single buffer
.extra: streamsync
