Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/shmem_env_defs.h
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,8 @@ SHMEM_INTERNAL_ENV_DEF(OFI_DISABLE_MULTIRAIL, bool, false, SHMEM_INTERNAL_ENV_CA
"Disable usage of multirail functionality")
SHMEM_INTERNAL_ENV_DEF(OFI_DISABLE_SINGLE_EP, bool, false, SHMEM_INTERNAL_ENV_CAT_TRANSPORT,
"Disable single endpoint resource optimization (enable separate Tx and Rx EPs)")
SHMEM_INTERNAL_ENV_DEF(OFI_TX_MUX_ENDPOINTS, long, 1, SHMEM_INTERNAL_ENV_CAT_TRANSPORT,
"Number of TX endpoints per context for transmit multiplexing (1 = disabled)")
#endif

#ifdef USE_UCX
Expand Down
128 changes: 128 additions & 0 deletions src/transport_ofi.c
Original file line number Diff line number Diff line change
Expand Up @@ -383,6 +383,7 @@ static stx_allocator_t shmem_transport_ofi_stx_allocator;

/* NOTE(review): presumably the capacity of the shared-TX-context (STX) pool — confirm against allocator */
static long shmem_transport_ofi_stx_max;
/* NOTE(review): presumably the sharing threshold used when searching the STX pool — confirm */
static long shmem_transport_ofi_stx_threshold;
static long shmem_transport_ofi_tx_mux_max; /* Number of TX endpoints per context for muxing */

struct shmem_transport_ofi_stx_t {
struct fid_stx* stx;
Expand Down Expand Up @@ -1800,6 +1801,105 @@ static int shmem_transport_ofi_ctx_init(shmem_transport_ctx_t *ctx, int id)
ret = bind_enable_ep_resources(ctx);
OFI_CHECK_RETURN_MSG(ret, "context bind/enable endpoint failed (%s)\n", fi_strerror(errno));

/* Initialize TX muxing fields */
ctx->num_tx_eps = 1;
ctx->tx_ep_idx = 0;
ctx->tx_ep_arr = NULL;
ctx->tx_stx_idxs = NULL;

/* Create additional TX endpoints for muxing when enabled and applicable.
* TX muxing is not applied to the default context in single-endpoint mode
* since it shares the target endpoint. */
if (shmem_transport_ofi_tx_mux_max > 1 &&
!(shmem_transport_ofi_single_ep && id == SHMEM_TRANSPORT_CTX_DEFAULT_ID))
{
int n = (int) shmem_transport_ofi_tx_mux_max;

ctx->tx_ep_arr = calloc(n, sizeof(struct fid_ep *));
if (ctx->tx_ep_arr == NULL) {
RAISE_ERROR_STR("Out of memory allocating TX muxing endpoint array");
}
ctx->tx_stx_idxs = malloc(n * sizeof(int));
if (ctx->tx_stx_idxs == NULL) {
RAISE_ERROR_STR("Out of memory allocating TX muxing STX index array");
}
/* Initialize to -1 (no STX) so the cleanup loop can safely check all entries */
memset(ctx->tx_stx_idxs, -1, n * sizeof(int));

/* Slot 0 is the primary endpoint already created above */
ctx->tx_ep_arr[0] = ctx->ep;
ctx->tx_stx_idxs[0] = ctx->stx_idx;

for (int i = 1; i < n; i++) {
struct fid_ep *extra_ep;

/* Allocate a shared STX from the pool for this extra endpoint */
int extra_stx = -1;
if (shmem_transport_ofi_stx_max > 0) {
extra_stx = shmem_transport_ofi_stx_search_shared(shmem_transport_ofi_stx_threshold);
if (extra_stx < 0)
extra_stx = shmem_transport_ofi_stx_search_unused();
if (extra_stx < 0)
extra_stx = shmem_transport_ofi_stx_search_shared(-1);
if (extra_stx >= 0)
shmem_transport_ofi_stx_pool[extra_stx].ref_cnt++;
}

/* Create the additional TX endpoint */
ret = fi_endpoint(shmem_transport_ofi_domainfd, info->p_info, &extra_ep, NULL);
if (ret) {
RAISE_WARN_MSG("Additional TX endpoint %d creation failed (%s); "
"using %d TX endpoints for muxing\n",
i, fi_strerror(ret), i);
if (extra_stx >= 0)
shmem_transport_ofi_stx_pool[extra_stx].ref_cnt--;
break;
}

/* Record the endpoint and its STX index now (before bind/enable) so that
* ctx_destroy can clean up even if a subsequent operation fails. */
ctx->tx_ep_arr[i] = extra_ep;
ctx->tx_stx_idxs[i] = extra_stx;
ctx->num_tx_eps = i + 1;

/* Bind the STX (if available) */
if (extra_stx >= 0) {
ret = fi_ep_bind(extra_ep,
&shmem_transport_ofi_stx_pool[extra_stx].stx->fid, 0);
OFI_CHECK_RETURN_MSG(ret, "fi_ep_bind STX to extra TX endpoint %d failed (%s)\n",
i, fi_strerror(ret));
}

/* Bind shared put counter */
ret = fi_ep_bind(extra_ep, &ctx->put_cntr->fid, FI_WRITE);
OFI_CHECK_RETURN_MSG(ret, "fi_ep_bind put CNTR to extra TX endpoint %d failed (%s)\n",
i, fi_strerror(ret));

/* Bind shared get counter */
ret = fi_ep_bind(extra_ep, &ctx->get_cntr->fid, FI_READ);
OFI_CHECK_RETURN_MSG(ret, "fi_ep_bind get CNTR to extra TX endpoint %d failed (%s)\n",
i, fi_strerror(ret));

/* Bind shared CQ */
ret = fi_ep_bind(extra_ep, &ctx->cq->fid,
FI_SELECTIVE_COMPLETION | FI_TRANSMIT | FI_RECV);
OFI_CHECK_RETURN_MSG(ret, "fi_ep_bind CQ to extra TX endpoint %d failed (%s)\n",
i, fi_strerror(ret));

/* Bind address vector */
ret = fi_ep_bind(extra_ep, &shmem_transport_ofi_avfd->fid, 0);
OFI_CHECK_RETURN_MSG(ret, "fi_ep_bind AV to extra TX endpoint %d failed (%s)\n",
i, fi_strerror(ret));

/* Enable the endpoint */
ret = fi_enable(extra_ep);
OFI_CHECK_RETURN_MSG(ret, "fi_enable on extra TX endpoint %d failed (%s)\n",
i, fi_strerror(ret));
}

DEBUG_MSG("TX muxing: context %d using %d TX endpoints\n", id, ctx->num_tx_eps);
}

if (ctx->options & SHMEMX_CTX_BOUNCE_BUFFER &&
shmem_transport_ofi_bounce_buffer_size > 0 &&
shmem_transport_ofi_max_bounce_buffers > 0)
Expand Down Expand Up @@ -1902,6 +2002,13 @@ int shmem_transport_init(void)
shmem_transport_ofi_put_poll_limit = shmem_internal_params.OFI_TX_POLL_LIMIT;
shmem_transport_ofi_get_poll_limit = shmem_internal_params.OFI_RX_POLL_LIMIT;

/* TX muxing: validate and store the number of TX endpoints per context */
if (shmem_internal_params.OFI_TX_MUX_ENDPOINTS < 1) {
RAISE_ERROR_MSG("OFI_TX_MUX_ENDPOINTS must be >= 1 (got %ld)\n",
shmem_internal_params.OFI_TX_MUX_ENDPOINTS);
}
shmem_transport_ofi_tx_mux_max = shmem_internal_params.OFI_TX_MUX_ENDPOINTS;

#ifdef USE_CTX_LOCK
/* In multithreaded mode, force completion polling so that threads yield
* the lock during put/get completion operations. User can still override
Expand Down Expand Up @@ -2110,6 +2217,27 @@ void shmem_transport_ctx_destroy(shmem_transport_ctx_t *ctx)
OFI_CHECK_ERROR_MSG(ret, "Context endpoint close failed (%s)\n", fi_strerror(errno));
}

/* Close additional TX endpoints created for TX muxing (index 0 is ctx->ep,
* already closed above; release their STX references as well). */
if (ctx->tx_ep_arr != NULL) {
SHMEM_MUTEX_LOCK(shmem_transport_ofi_lock);
for (int i = 1; i < ctx->num_tx_eps; i++) {
if (ctx->tx_ep_arr[i]) {
ret = fi_close(&ctx->tx_ep_arr[i]->fid);
OFI_CHECK_ERROR_MSG(ret, "Extra TX endpoint close failed (%s)\n",
fi_strerror(errno));
}
if (ctx->tx_stx_idxs != NULL && ctx->tx_stx_idxs[i] >= 0) {
shmem_transport_ofi_stx_pool[ctx->tx_stx_idxs[i]].ref_cnt--;
}
}
SHMEM_MUTEX_UNLOCK(shmem_transport_ofi_lock);
free(ctx->tx_ep_arr);
free(ctx->tx_stx_idxs);
ctx->tx_ep_arr = NULL;
ctx->tx_stx_idxs = NULL;
}

if (ctx->bounce_buffers) {
shmem_free_list_destroy(ctx->bounce_buffers);
}
Expand Down
Loading