Skip to content

Commit

Permalink
mpl: add MPL_gpu_stream_synchronize
Browse files Browse the repository at this point in the history
Useful for GPU stream based MPI extensions.
  • Loading branch information
hzhou committed Jun 5, 2024
1 parent 01cb171 commit 36ade32
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 2 deletions.
3 changes: 2 additions & 1 deletion src/mpl/include/mpl_gpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ typedef enum {
MPL_GPU_COPY_H2D,
MPL_GPU_COPY_D2D_INCOMING, /* copy from remote to local */
MPL_GPU_COPY_D2D_OUTGOING, /* copy from local to remote */
MPL_GPU_COPY_DIRECTION_NONE, /* copy in any direction and to/from any buffer type */
MPL_GPU_COPY_DIRECTION_NONE, /* copy in any direction and to/from any buffer type */
} MPL_gpu_copy_direction_t;

#define MPL_GPU_COPY_DIRECTION_TYPES 4
Expand Down Expand Up @@ -152,6 +152,7 @@ int MPL_gpu_launch_hostfn(MPL_gpu_stream_t stream, MPL_gpu_hostfn fn, void *data
bool MPL_gpu_stream_is_valid(MPL_gpu_stream_t stream);
void MPL_gpu_enqueue_trigger(MPL_gpu_event_t * var, MPL_gpu_stream_t stream);
void MPL_gpu_enqueue_wait(MPL_gpu_event_t * var, MPL_gpu_stream_t stream);
/* Synchronize with a GPU stream.  The CUDA and HIP backends forward to
 * cudaStreamSynchronize / hipStreamSynchronize and return that runtime's
 * status code as an int; the fallback and ZE backends ignore the stream
 * and always return -1. */
int MPL_gpu_stream_synchronize(MPL_gpu_stream_t stream);

/* the synchronization event has the similar semantics as completion counter,
* init to a count, then each completion decrement it by 1. */
Expand Down
7 changes: 7 additions & 0 deletions src/mpl/src/gpu/mpl_gpu_cuda.c
Original file line number Diff line number Diff line change
Expand Up @@ -589,3 +589,10 @@ bool MPL_gpu_stream_is_valid(MPL_gpu_stream_t stream)
result = cudaStreamQuery(stream);
return (result != cudaErrorInvalidResourceHandle);
}

/* Calls cudaStreamSynchronize on `stream` and hands the resulting
 * cudaError_t back to the caller converted to int. */
int MPL_gpu_stream_synchronize(cudaStream_t stream)
{
    return (int) cudaStreamSynchronize(stream);
}
5 changes: 5 additions & 0 deletions src/mpl/src/gpu/mpl_gpu_fallback.c
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,11 @@ int MPL_gpu_launch_hostfn(int stream, MPL_gpu_hostfn fn, void *data)
return -1;
}

/* Stub for this backend: the stream argument is ignored and -1 is
 * returned unconditionally. */
int MPL_gpu_stream_synchronize(int stream)
{
    const int rc = -1;

    (void) stream;      /* intentionally unused */
    return rc;
}

bool MPL_gpu_stream_is_valid(MPL_gpu_stream_t stream)
{
return false;
Expand Down
7 changes: 7 additions & 0 deletions src/mpl/src/gpu/mpl_gpu_hip.c
Original file line number Diff line number Diff line change
Expand Up @@ -530,6 +530,13 @@ int MPL_gpu_launch_hostfn(hipStream_t stream, MPL_gpu_hostfn fn, void *data)
return result;
}

/* Calls hipStreamSynchronize on `stream` and hands the resulting
 * hipError_t back to the caller converted to int. */
int MPL_gpu_stream_synchronize(hipStream_t stream)
{
    return (int) hipStreamSynchronize(stream);
}

/* ---- */
bool MPL_gpu_stream_is_valid(MPL_gpu_stream_t stream)
{
Expand Down
7 changes: 6 additions & 1 deletion src/mpl/src/gpu/mpl_gpu_ze.c
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ static int gpu_initialized = 0;
static uint32_t device_count; /* Counts all local devices, does not include subdevices */
static uint32_t local_ze_device_count; /* Counts all local devices and subdevices */
static uint32_t global_ze_device_count; /* Counts all global devices and subdevices */
static int max_dev_id; /* Does not include subdevices */
static int max_dev_id; /* Does not include subdevices */
static int max_subdev_id;
static char **device_list = NULL;
static int *engine_conversion = NULL;
Expand Down Expand Up @@ -2460,6 +2460,11 @@ int MPL_gpu_launch_hostfn(int stream, MPL_gpu_hostfn fn, void *data)
return -1;
}

/* Stub for this backend: `stream` is never read and the call always
 * reports -1.
 * NOTE(review): presumably a placeholder until stream synchronization is
 * implemented for this backend -- confirm. */
int MPL_gpu_stream_synchronize(int stream)
{
    (void) stream;      /* intentionally unused */

    int status = -1;
    return status;
}

bool MPL_gpu_stream_is_valid(MPL_gpu_stream_t stream)
{
return false;
Expand Down

0 comments on commit 36ade32

Please sign in to comment.