diff --git a/src/shmem_env_defs.h b/src/shmem_env_defs.h index c0a27a30..fe37be71 100644 --- a/src/shmem_env_defs.h +++ b/src/shmem_env_defs.h @@ -111,6 +111,8 @@ SHMEM_INTERNAL_ENV_DEF(OFI_DISABLE_MULTIRAIL, bool, false, SHMEM_INTERNAL_ENV_CA "Disable usage of multirail functionality") SHMEM_INTERNAL_ENV_DEF(OFI_DISABLE_SINGLE_EP, bool, false, SHMEM_INTERNAL_ENV_CAT_TRANSPORT, "Disable single endpoint resource optimization (enable separate Tx and Rx EPs)") +SHMEM_INTERNAL_ENV_DEF(OFI_TX_MUX_ENDPOINTS, long, 1, SHMEM_INTERNAL_ENV_CAT_TRANSPORT, + "Number of TX endpoints per context for transmit multiplexing (1 = disabled)") #endif #ifdef USE_UCX diff --git a/src/transport_ofi.c b/src/transport_ofi.c index f8f04d65..f98170ec 100644 --- a/src/transport_ofi.c +++ b/src/transport_ofi.c @@ -383,6 +383,7 @@ static stx_allocator_t shmem_transport_ofi_stx_allocator; static long shmem_transport_ofi_stx_max; static long shmem_transport_ofi_stx_threshold; +static long shmem_transport_ofi_tx_mux_max; /* Number of TX endpoints per context for muxing */ struct shmem_transport_ofi_stx_t { struct fid_stx* stx; @@ -1800,6 +1801,105 @@ static int shmem_transport_ofi_ctx_init(shmem_transport_ctx_t *ctx, int id) ret = bind_enable_ep_resources(ctx); OFI_CHECK_RETURN_MSG(ret, "context bind/enable endpoint failed (%s)\n", fi_strerror(errno)); + /* Initialize TX muxing fields */ + ctx->num_tx_eps = 1; + ctx->tx_ep_idx = 0; + ctx->tx_ep_arr = NULL; + ctx->tx_stx_idxs = NULL; + + /* Create additional TX endpoints for muxing when enabled and applicable. + * TX muxing is not applied to the default context in single-endpoint mode + * since it shares the target endpoint. */ + if (shmem_transport_ofi_tx_mux_max > 1 && + !(shmem_transport_ofi_single_ep && id == SHMEM_TRANSPORT_CTX_DEFAULT_ID)) + { + int n = (int) shmem_transport_ofi_tx_mux_max; + + ctx->tx_ep_arr = calloc(n, sizeof(struct fid_ep *)); + if (ctx->tx_ep_arr == NULL) { + RAISE_ERROR_STR("Out of memory allocating TX muxing endpoint array"); + } + ctx->tx_stx_idxs = malloc(n * sizeof(int)); + if (ctx->tx_stx_idxs == NULL) { + RAISE_ERROR_STR("Out of memory allocating TX muxing STX index array"); + } + /* Initialize to -1 (no STX) so the cleanup loop can safely check all entries */ + memset(ctx->tx_stx_idxs, -1, n * sizeof(int)); + + /* Slot 0 is the primary endpoint already created above */ + ctx->tx_ep_arr[0] = ctx->ep; + ctx->tx_stx_idxs[0] = ctx->stx_idx; + + for (int i = 1; i < n; i++) { + struct fid_ep *extra_ep; + + /* Allocate a shared STX from the pool for this extra endpoint */ + int extra_stx = -1; + if (shmem_transport_ofi_stx_max > 0) { + extra_stx = shmem_transport_ofi_stx_search_shared(shmem_transport_ofi_stx_threshold); + if (extra_stx < 0) + extra_stx = shmem_transport_ofi_stx_search_unused(); + if (extra_stx < 0) + extra_stx = shmem_transport_ofi_stx_search_shared(-1); + if (extra_stx >= 0) + shmem_transport_ofi_stx_pool[extra_stx].ref_cnt++; + } + + /* Create the additional TX endpoint */ + ret = fi_endpoint(shmem_transport_ofi_domainfd, info->p_info, &extra_ep, NULL); + if (ret) { + RAISE_WARN_MSG("Additional TX endpoint %d creation failed (%s); " + "using %d TX endpoints for muxing\n", + i, fi_strerror(ret), i); + if (extra_stx >= 0) + shmem_transport_ofi_stx_pool[extra_stx].ref_cnt--; + break; + } + + /* Record the endpoint and its STX index now (before bind/enable) so that + * ctx_destroy can clean up even if a subsequent operation fails. */ + ctx->tx_ep_arr[i] = extra_ep; + ctx->tx_stx_idxs[i] = extra_stx; + ctx->num_tx_eps = i + 1; + + /* Bind the STX (if available) */ + if (extra_stx >= 0) { + ret = fi_ep_bind(extra_ep, + &shmem_transport_ofi_stx_pool[extra_stx].stx->fid, 0); + OFI_CHECK_RETURN_MSG(ret, "fi_ep_bind STX to extra TX endpoint %d failed (%s)\n", + i, fi_strerror(ret)); + } + + /* Bind shared put counter */ + ret = fi_ep_bind(extra_ep, &ctx->put_cntr->fid, FI_WRITE); + OFI_CHECK_RETURN_MSG(ret, "fi_ep_bind put CNTR to extra TX endpoint %d failed (%s)\n", + i, fi_strerror(ret)); + + /* Bind shared get counter */ + ret = fi_ep_bind(extra_ep, &ctx->get_cntr->fid, FI_READ); + OFI_CHECK_RETURN_MSG(ret, "fi_ep_bind get CNTR to extra TX endpoint %d failed (%s)\n", + i, fi_strerror(ret)); + + /* Bind shared CQ */ + ret = fi_ep_bind(extra_ep, &ctx->cq->fid, + FI_SELECTIVE_COMPLETION | FI_TRANSMIT | FI_RECV); + OFI_CHECK_RETURN_MSG(ret, "fi_ep_bind CQ to extra TX endpoint %d failed (%s)\n", + i, fi_strerror(ret)); + + /* Bind address vector */ + ret = fi_ep_bind(extra_ep, &shmem_transport_ofi_avfd->fid, 0); + OFI_CHECK_RETURN_MSG(ret, "fi_ep_bind AV to extra TX endpoint %d failed (%s)\n", + i, fi_strerror(ret)); + + /* Enable the endpoint */ + ret = fi_enable(extra_ep); + OFI_CHECK_RETURN_MSG(ret, "fi_enable on extra TX endpoint %d failed (%s)\n", + i, fi_strerror(ret)); + } + + DEBUG_MSG("TX muxing: context %d using %d TX endpoints\n", id, ctx->num_tx_eps); + } + if (ctx->options & SHMEMX_CTX_BOUNCE_BUFFER && shmem_transport_ofi_bounce_buffer_size > 0 && shmem_transport_ofi_max_bounce_buffers > 0) @@ -1902,6 +2002,13 @@ int shmem_transport_init(void) shmem_transport_ofi_put_poll_limit = shmem_internal_params.OFI_TX_POLL_LIMIT; shmem_transport_ofi_get_poll_limit = shmem_internal_params.OFI_RX_POLL_LIMIT; + /* TX muxing: validate and store the number of TX endpoints per context */ + if (shmem_internal_params.OFI_TX_MUX_ENDPOINTS < 1) { + RAISE_ERROR_MSG("OFI_TX_MUX_ENDPOINTS must be >= 1 (got %ld)\n", + shmem_internal_params.OFI_TX_MUX_ENDPOINTS); + } + shmem_transport_ofi_tx_mux_max = shmem_internal_params.OFI_TX_MUX_ENDPOINTS; + #ifdef USE_CTX_LOCK /* In multithreaded mode, force completion polling so that threads yield * the lock during put/get completion operations. User can still override @@ -2110,6 +2217,27 @@ void shmem_transport_ctx_destroy(shmem_transport_ctx_t *ctx) OFI_CHECK_ERROR_MSG(ret, "Context endpoint close failed (%s)\n", fi_strerror(errno)); } + /* Close additional TX endpoints created for TX muxing (index 0 is ctx->ep, + * already closed above; release their STX references as well). */ + if (ctx->tx_ep_arr != NULL) { + SHMEM_MUTEX_LOCK(shmem_transport_ofi_lock); + for (int i = 1; i < ctx->num_tx_eps; i++) { + if (ctx->tx_ep_arr[i]) { + ret = fi_close(&ctx->tx_ep_arr[i]->fid); + OFI_CHECK_ERROR_MSG(ret, "Extra TX endpoint close failed (%s)\n", + fi_strerror(errno)); + } + if (ctx->tx_stx_idxs != NULL && ctx->tx_stx_idxs[i] >= 0) { + shmem_transport_ofi_stx_pool[ctx->tx_stx_idxs[i]].ref_cnt--; + } + } + SHMEM_MUTEX_UNLOCK(shmem_transport_ofi_lock); + free(ctx->tx_ep_arr); + free(ctx->tx_stx_idxs); + ctx->tx_ep_arr = NULL; + ctx->tx_stx_idxs = NULL; + } + if (ctx->bounce_buffers) { shmem_free_list_destroy(ctx->bounce_buffers); } diff --git a/src/transport_ofi.h b/src/transport_ofi.h index f7e2ebf3..3df81b61 100644 --- a/src/transport_ofi.h +++ b/src/transport_ofi.h @@ -337,6 +337,16 @@ struct shmem_transport_ctx_t { int stx_idx; struct shmem_internal_tid tid; struct shmem_internal_team_t *team; + /* TX Muxing: allows a context to use multiple TX endpoints to improve + * transmit throughput by distributing operations across endpoints. + * When num_tx_eps <= 1 (disabled), ctx->ep is used directly. + * When num_tx_eps > 1, tx_ep_arr[0..num_tx_eps-1] holds all TX endpoints, + * where tx_ep_arr[0] == ctx->ep. All endpoints share the same counters + * and CQ so existing quiet/fence/wait logic works unchanged. */ + int num_tx_eps; /* Total TX eps; 1 = muxing disabled */ + int tx_ep_idx; /* Round-robin index (0..num_tx_eps-1) */ + struct fid_ep **tx_ep_arr; /* All TX eps (NULL if num_tx_eps <= 1) */ + int *tx_stx_idxs; /* STX index per additional TX ep (NULL if num_tx_eps <= 1) */ }; typedef struct shmem_transport_ctx_t shmem_transport_ctx_t; @@ -344,6 +354,21 @@ extern shmem_transport_ctx_t shmem_transport_ctx_default; extern struct fid_ep* shmem_transport_ofi_target_ep; +/* Select the TX endpoint for the next operation using round-robin scheduling. + * When TX muxing is disabled (num_tx_eps <= 1), returns ctx->ep directly. + * Otherwise, cycles through tx_ep_arr to distribute TX load across endpoints. + * Must be called with the context lock held (or from a private/serialized ctx). */ +static inline +struct fid_ep* shmem_transport_ofi_get_tx_ep(shmem_transport_ctx_t *ctx) +{ + if (ctx->num_tx_eps <= 1) + return ctx->ep; + + int idx = ctx->tx_ep_idx; + ctx->tx_ep_idx = (idx + 1 >= ctx->num_tx_eps) ? 0 : idx + 1; + return ctx->tx_ep_arr[idx]; +} + #ifdef USE_CTX_LOCK #define SHMEM_TRANSPORT_OFI_CTX_LOCK(ctx) \ do { \ @@ -628,16 +653,19 @@ void shmem_transport_put_scalar(shmem_transport_ctx_t* ctx, void *target, const SHMEM_TRANSPORT_OFI_CTX_LOCK(ctx); SHMEM_TRANSPORT_OFI_CNTR_INC(&ctx->pending_put_cntr); - do { + { + struct fid_ep *tx_ep = shmem_transport_ofi_get_tx_ep(ctx); + do { - ret = fi_inject_write(ctx->ep, - source, - len, - GET_DEST(dst), - (uint64_t) addr, - key); + ret = fi_inject_write(tx_ep, + source, + len, + GET_DEST(dst), + (uint64_t) addr, + key); - } while (try_again(ctx, ret, &polled)); + } while (try_again(ctx, ret, &polled)); + } SHMEM_TRANSPORT_OFI_CTX_UNLOCK(ctx); } @@ -667,13 +695,16 @@ void shmem_transport_ofi_put_large(shmem_transport_ctx_t* ctx, void *target, con SHMEM_TRANSPORT_OFI_CNTR_INC(&ctx->pending_put_cntr); - do { - ret = fi_write(ctx->ep, - frag_source, frag_len, - GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(source)), - GET_DEST(dst), frag_target, - key, NULL); - } while (try_again(ctx, ret, &polled)); + { + struct fid_ep *tx_ep = shmem_transport_ofi_get_tx_ep(ctx); + do { + ret = fi_write(tx_ep, + frag_source, frag_len, + GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(source)), + GET_DEST(dst), frag_target, + key, NULL); + } while (try_again(ctx, ret, &polled)); + } frag_source += frag_len; frag_target += frag_len; @@ -720,7 +751,7 @@ void shmem_transport_put_nb(shmem_transport_ctx_t* ctx, void *target, const void .data = 0 }; do { - ret = fi_writemsg(ctx->ep, &msg, FI_COMPLETION | FI_DELIVERY_COMPLETE); + ret = fi_writemsg(shmem_transport_ofi_get_tx_ep(ctx), &msg, FI_COMPLETION | FI_DELIVERY_COMPLETE); } while (try_again(ctx, ret, &polled)); SHMEM_TRANSPORT_OFI_CTX_UNLOCK(ctx); @@ -769,7 +800,7 @@ void shmem_transport_put_signal_nbi(shmem_transport_ctx_t* ctx, void *target, co }; do { - ret = fi_writemsg(ctx->ep, &msg, FI_DELIVERY_COMPLETE | FI_INJECT); + ret = fi_writemsg(shmem_transport_ofi_get_tx_ep(ctx), &msg, FI_DELIVERY_COMPLETE | FI_INJECT); } while (try_again(ctx, ret, &polled)); SHMEM_TRANSPORT_OFI_CTX_UNLOCK(ctx); @@ -816,9 +847,12 @@ void shmem_transport_put_signal_nbi(shmem_transport_ctx_t* ctx, void *target, co SHMEM_TRANSPORT_OFI_CNTR_INC(&ctx->pending_put_cntr); - do { - ret = fi_writemsg(ctx->ep, &msg, FI_DELIVERY_COMPLETE); - } while (try_again(ctx, ret, &polled)); + { + struct fid_ep *tx_ep = shmem_transport_ofi_get_tx_ep(ctx); + do { + ret = fi_writemsg(tx_ep, &msg, FI_DELIVERY_COMPLETE); + } while (try_again(ctx, ret, &polled)); + } frag_source += frag_len; frag_target += frag_len; @@ -867,7 +901,7 @@ void shmem_transport_put_signal_nbi(shmem_transport_ctx_t* ctx, void *target, co }; do { - ret = fi_atomicmsg(ctx->ep, &msg_signal, flags_signal); + ret = fi_atomicmsg(shmem_transport_ofi_get_tx_ep(ctx), &msg_signal, flags_signal); } while (try_again(ctx, ret, &polled)); SHMEM_TRANSPORT_OFI_CTX_UNLOCK(ctx); @@ -915,16 +949,19 @@ void shmem_transport_get(shmem_transport_ctx_t* ctx, void *target, const void *s if (len <= shmem_transport_ofi_max_msg_size) { SHMEM_TRANSPORT_OFI_CNTR_INC(&ctx->pending_get_cntr); - do { - ret = fi_read(ctx->ep, - target, - len, - GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(target)), - GET_DEST(dst), - (uint64_t) addr, - key, - NULL); - } while (try_again(ctx, ret, &polled)); + { + struct fid_ep *tx_ep = shmem_transport_ofi_get_tx_ep(ctx); + do { + ret = fi_read(tx_ep, + target, + len, + GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(target)), + GET_DEST(dst), + (uint64_t) addr, + key, + NULL); + } while (try_again(ctx, ret, &polled)); + } } else { uint8_t *frag_target = (uint8_t *) target; @@ -938,13 +975,16 @@ void shmem_transport_get(shmem_transport_ctx_t* ctx, void *target, const void *s SHMEM_TRANSPORT_OFI_CNTR_INC(&ctx->pending_get_cntr); - do { - ret = fi_read(ctx->ep, - frag_target, frag_len, - GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(target)), - GET_DEST(dst), frag_source, - key, NULL); - } while (try_again(ctx, ret, &polled)); + { + struct fid_ep *tx_ep = shmem_transport_ofi_get_tx_ep(ctx); + do { + ret = fi_read(tx_ep, + frag_target, frag_len, + GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(target)), + GET_DEST(dst), frag_source, + key, NULL); + } while (try_again(ctx, ret, &polled)); + } frag_source += frag_len; frag_target += frag_len; @@ -1038,18 +1078,21 @@ void shmem_transport_cswap_nbi(shmem_transport_ctx_t* ctx, void *target, const SHMEM_TRANSPORT_OFI_CTX_LOCK(ctx); SHMEM_TRANSPORT_OFI_CNTR_INC(&ctx->pending_get_cntr); - do { - ret = fi_compare_atomicmsg(ctx->ep, - &msg, - &comparev, - NULL, - 1, - &resultv, - GET_MR_DESC_ADDR(shmem_transport_ofi_get_mr_desc_index(dest)), - 1, - FI_INJECT); /* FI_DELIVERY_COMPLETE is not required as - it is implied for fetch atomicmsgs */ - } while (try_again(ctx, ret, &polled)); + { + struct fid_ep *tx_ep = shmem_transport_ofi_get_tx_ep(ctx); + do { + ret = fi_compare_atomicmsg(tx_ep, + &msg, + &comparev, + NULL, + 1, + &resultv, + GET_MR_DESC_ADDR(shmem_transport_ofi_get_mr_desc_index(dest)), + 1, + FI_INJECT); /* FI_DELIVERY_COMPLETE is not required as + it is implied for fetch atomicmsgs */ + } while (try_again(ctx, ret, &polled)); + } SHMEM_TRANSPORT_OFI_CTX_UNLOCK(ctx); } @@ -1079,22 +1122,25 @@ void shmem_transport_cswap(shmem_transport_ctx_t* ctx, void *target, const void SHMEM_TRANSPORT_OFI_CTX_LOCK(ctx); SHMEM_TRANSPORT_OFI_CNTR_INC(&ctx->pending_get_cntr); - do { - ret = fi_compare_atomic(ctx->ep, - source, - 1, - GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(source)), - operand, - NULL, - dest, - GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(dest)), - GET_DEST(dst), - (uint64_t) addr, - key, - SHMEM_TRANSPORT_DTYPE(datatype), - FI_CSWAP, - NULL); - } while (try_again(ctx, ret, &polled)); + { + struct fid_ep *tx_ep = shmem_transport_ofi_get_tx_ep(ctx); + do { + ret = fi_compare_atomic(tx_ep, + source, + 1, + GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(source)), + operand, + NULL, + dest, + GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(dest)), + GET_DEST(dst), + (uint64_t) addr, + key, + SHMEM_TRANSPORT_DTYPE(datatype), + FI_CSWAP, + NULL); + } while (try_again(ctx, ret, &polled)); + } SHMEM_TRANSPORT_OFI_CTX_UNLOCK(ctx); #endif } @@ -1118,22 +1164,25 @@ void shmem_transport_mswap(shmem_transport_ctx_t* ctx, void *target, const void SHMEM_TRANSPORT_OFI_CTX_LOCK(ctx); SHMEM_TRANSPORT_OFI_CNTR_INC(&ctx->pending_get_cntr); - do { - ret = fi_compare_atomic(ctx->ep, - source, - 1, - GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(source)), - mask, - NULL, - dest, - GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(dest)), - GET_DEST(dst), - (uint64_t) addr, - key, - SHMEM_TRANSPORT_DTYPE(datatype), - FI_MSWAP, - NULL); - } while (try_again(ctx, ret, &polled)); + { + struct fid_ep *tx_ep = shmem_transport_ofi_get_tx_ep(ctx); + do { + ret = fi_compare_atomic(tx_ep, + source, + 1, + GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(source)), + mask, + NULL, + dest, + GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(dest)), + GET_DEST(dst), + (uint64_t) addr, + key, + SHMEM_TRANSPORT_DTYPE(datatype), + FI_MSWAP, + NULL); + } while (try_again(ctx, ret, &polled)); + } SHMEM_TRANSPORT_OFI_CTX_UNLOCK(ctx); } @@ -1155,16 +1204,19 @@ void shmem_transport_atomic(shmem_transport_ctx_t* ctx, void *target, const void SHMEM_TRANSPORT_OFI_CTX_LOCK(ctx); SHMEM_TRANSPORT_OFI_CNTR_INC(&ctx->pending_put_cntr); - do { - ret = fi_inject_atomic(ctx->ep, - source, - 1, - GET_DEST(dst), - (uint64_t) addr, - key, - SHMEM_TRANSPORT_DTYPE(datatype), - op); - } while (try_again(ctx, ret, &polled)); + { + struct fid_ep *tx_ep = shmem_transport_ofi_get_tx_ep(ctx); + do { + ret = fi_inject_atomic(tx_ep, + source, + 1, + GET_DEST(dst), + (uint64_t) addr, + key, + SHMEM_TRANSPORT_DTYPE(datatype), + op); + } while (try_again(ctx, ret, &polled)); + } SHMEM_TRANSPORT_OFI_CTX_UNLOCK(ctx); } @@ -1206,16 +1258,19 @@ void shmem_transport_atomicv(shmem_transport_ctx_t* ctx, void *target, const voi SHMEM_TRANSPORT_OFI_CNTR_INC(&ctx->pending_put_cntr); - do { - ret = fi_inject_atomic(ctx->ep, - source, - len, - GET_DEST(dst), - (uint64_t) addr, - key, - dt, - op); - } while (try_again(ctx, ret, &polled)); + { + struct fid_ep *tx_ep = shmem_transport_ofi_get_tx_ep(ctx); + do { + ret = fi_inject_atomic(tx_ep, + source, + len, + GET_DEST(dst), + (uint64_t) addr, + key, + dt, + op); + } while (try_again(ctx, ret, &polled)); + } } else if (full_len <= MIN(shmem_transport_ofi_bounce_buffer_size, max_atomic_size) && @@ -1241,9 +1296,12 @@ void shmem_transport_atomicv(shmem_transport_ctx_t* ctx, void *target, const voi .context = buff, .data = 0 }; - do { - ret = fi_atomicmsg(ctx->ep, &msg, FI_COMPLETION | FI_DELIVERY_COMPLETE); - } while (try_again(ctx, ret, &polled)); + { + struct fid_ep *tx_ep = shmem_transport_ofi_get_tx_ep(ctx); + do { + ret = fi_atomicmsg(tx_ep, &msg, FI_COMPLETION | FI_DELIVERY_COMPLETE); + } while (try_again(ctx, ret, &polled)); + } } else { size_t sent = 0; @@ -1254,20 +1312,23 @@ void shmem_transport_atomicv(shmem_transport_ctx_t* ctx, void *target, const voi (max_atomic_size/SHMEM_Dtsize[dt])); polled = 0; SHMEM_TRANSPORT_OFI_CNTR_INC(&ctx->pending_put_cntr); - do { - ret = fi_atomic(ctx->ep, - (void *)((char *)source + - (sent*SHMEM_Dtsize[dt])), - chunksize, - GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(source)), - GET_DEST(dst), - ((uint64_t) addr + - (sent*SHMEM_Dtsize[dt])), - key, - dt, - op, - NULL); - } while (try_again(ctx, ret, &polled)); + { + struct fid_ep *tx_ep = shmem_transport_ofi_get_tx_ep(ctx); + do { + ret = fi_atomic(tx_ep, + (void *)((char *)source + + (sent*SHMEM_Dtsize[dt])), + chunksize, + GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(source)), + GET_DEST(dst), + ((uint64_t) addr + + (sent*SHMEM_Dtsize[dt])), + key, + dt, + op, + NULL); + } while (try_again(ctx, ret, &polled)); + } sent += chunksize; } @@ -1313,15 +1374,18 @@ void shmem_transport_fetch_atomic_nbi(shmem_transport_ctx_t* ctx, void *target, SHMEM_TRANSPORT_OFI_CTX_LOCK(ctx); SHMEM_TRANSPORT_OFI_CNTR_INC(&ctx->pending_get_cntr); - do { - ret = fi_fetch_atomicmsg(ctx->ep, - &msg, - &resultv, - GET_MR_DESC_ADDR(shmem_transport_ofi_get_mr_desc_index(dest)), - 1, - FI_INJECT); /* FI_DELIVERY_COMPLETE is not required as it's - implied for fetch atomicmsgs */ - } while (try_again(ctx, ret, &polled)); + { + struct fid_ep *tx_ep = shmem_transport_ofi_get_tx_ep(ctx); + do { + ret = fi_fetch_atomicmsg(tx_ep, + &msg, + &resultv, + GET_MR_DESC_ADDR(shmem_transport_ofi_get_mr_desc_index(dest)), + 1, + FI_INJECT); /* FI_DELIVERY_COMPLETE is not required as it's + implied for fetch atomicmsgs */ + } while (try_again(ctx, ret, &polled)); + } SHMEM_TRANSPORT_OFI_CTX_UNLOCK(ctx); } @@ -1352,20 +1416,23 @@ void shmem_transport_fetch_atomic(shmem_transport_ctx_t* ctx, void *target, SHMEM_TRANSPORT_OFI_CTX_LOCK(ctx); SHMEM_TRANSPORT_OFI_CNTR_INC(&ctx->pending_get_cntr); - do { - ret = fi_fetch_atomic(ctx->ep, - source, - 1, - GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(source)), - dest, - GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(dest)), - GET_DEST(dst), - (uint64_t) addr, - key, - SHMEM_TRANSPORT_DTYPE(datatype), - op, - NULL); - } while (try_again(ctx, ret, &polled)); + { + struct fid_ep *tx_ep = shmem_transport_ofi_get_tx_ep(ctx); + do { + ret = fi_fetch_atomic(tx_ep, + source, + 1, + GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(source)), + dest, + GET_MR_DESC(shmem_transport_ofi_get_mr_desc_index(dest)), + GET_DEST(dst), + (uint64_t) addr, + key, + SHMEM_TRANSPORT_DTYPE(datatype), + op, + NULL); + } while (try_again(ctx, ret, &polled)); + } SHMEM_TRANSPORT_OFI_CTX_UNLOCK(ctx); #endif }