@@ -1085,14 +1085,16 @@ static inline int handle_write_comp(struct fi_cq_data_entry *cq_entry, nccl_net_
10851085static inline int handle_flush_comp (nccl_net_ofi_rdma_req_t *req)
10861086{
10871087 int ret = 0 ;
1088- rdma_req_flush_data_t *flush_data = get_flush_data (req);
1088+
10891089
10901090#if HAVE_NEURON
1091+ rdma_req_flush_data_t *flush_data = get_flush_data (req);
10911092 ret = inc_req_completion (req, 0 , flush_data->total_num_compls );
10921093#endif
10931094
10941095#if HAVE_CUDA
10951096
1097+ rdma_req_flush_data_t *flush_data = get_flush_data (req);
10961098 int num_completions = ++(req->ncompls );
10971099 /* Check if the number of completions is equal to total completions
10981100 * and if the req has not errored.
@@ -3235,10 +3237,10 @@ int nccl_net_ofi_rdma_domain_t::dealloc_and_dereg_flush_buff()
32353237int nccl_net_ofi_rdma_domain_t::alloc_and_reg_flush_buff (int dev_id)
32363238{
32373239 int ret = 0 ;
3238- int rc;
32393240 nccl_net_ofi_rdma_mr_handle_t *mr_handle = NULL ;
32403241
32413242#if HAVE_NEURON
3243+ int rc;
32423244 NCCL_OFI_TRACE (NCCL_NET, " Registering buffer for flush operations" );
32433245
32443246 this ->flush_buff .size = NCCL_OFI_FLUSH_SIZE;
@@ -3267,6 +3269,7 @@ int nccl_net_ofi_rdma_domain_t::alloc_and_reg_flush_buff(int dev_id)
32673269#endif
32683270
32693271#if HAVE_CUDA
3272+ int rc;
32703273 NCCL_OFI_TRACE (NCCL_NET, " Registering buffer in GPU for flush operations" );
32713274
32723275 /*
@@ -5347,6 +5350,7 @@ static int post_eager_copy(nccl_net_ofi_rdma_req_t *req)
53475350 return rc;
53485351}
53495352
5353+ #ifdef HAVE_NEURON
53505354static int post_flush_req (nccl_net_ofi_rdma_req_t *req)
53515355{
53525356 nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)req->comm ;
@@ -5361,17 +5365,10 @@ static int post_flush_req(nccl_net_ofi_rdma_req_t *req)
53615365 comm_rail = rdma_recv_comm_get_rail (r_comm, rail_id);
53625366 struct fid_mr *mr_handle = NULL ;
53635367
5364- #if HAVE_NEURON
53655368 void *desc = fi_mr_desc (domain->flush_buff .mr_handle ->mr [rail_id].get ());
53665369 mr_handle = flush_data->mr_handle ->mr [rail_id].get ();
5367- #endif
53685370
5369- #if HAVE_CUDA
5370- freelist_regmr_fn_handle_t *fl_handle =
5371- (freelist_regmr_fn_handle_t *)flush_data->flush_fl_elem ->mr_handle ;
5372- void *desc = fi_mr_desc (fl_handle->mr_handle ->mr [rail_id].get ());
5373- mr_handle = domain->flush_buff .mr_handle ->mr [rail_id].get ();
5374- #endif
5371+
53755372 uint64_t cuda_key = 0ULL ;
53765373
53775374 if (mr_handle != NULL ) {
@@ -5384,23 +5381,61 @@ static int post_flush_req(nccl_net_ofi_rdma_req_t *req)
53845381 }
53855382 }
53865383
5387- #if HAVE_NEURON
53885384 nccl_net_ofi_rdma_flush_buffer_t *f_buff = &domain->flush_buff ;
53895385 uint64_t host_buff_addr = (uint64_t )f_buff->buffer + (NCCL_OFI_DEFAULT_CPU_CACHE_LINE_SIZE * rail_id);
53905386 rc = fi_read (comm_rail->local_ep ,
53915387 (void *)host_buff_addr,
53925388 NCCL_OFI_DEFAULT_CPU_CACHE_LINE_SIZE, desc, comm_rail->local_addr ,
53935389 (uint64_t )(virt_addr_mr ? flush_data->data : 0 ),
53945390 cuda_key, rdma_req_get_ofi_context (req, rail_id));
5395- #endif
5396- #if HAVE_CUDA
5391+
5392+ if ((rc != 0 ) && (rc != -FI_EAGAIN)) {
5393+ NCCL_OFI_WARN (" Error posting flush request. RC: %zd, Error: %s" ,
5394+ rc, fi_strerror (-rc));
5395+ goto exit;
5396+ }
5397+ }
5398+
5399+ exit:
5400+ return (int )rc;
5401+ }
5402+ #elif HAVE_CUDA
5403+ static int post_flush_req (nccl_net_ofi_rdma_req_t *req)
5404+ {
5405+ nccl_net_ofi_rdma_recv_comm_t *r_comm = (nccl_net_ofi_rdma_recv_comm_t *)req->comm ;
5406+ nccl_net_ofi_rdma_ep_t *ep = (nccl_net_ofi_rdma_ep_t *)r_comm->base .base .ep ;
5407+ nccl_net_ofi_rdma_domain_t *domain = ep->rdma_endpoint_get_domain ();
5408+ rdma_req_flush_data_t *flush_data = get_flush_data (req);
5409+ nccl_net_ofi_rdma_recv_comm_rail_t *comm_rail;
5410+ ssize_t rc = 0 ;
5411+
5412+ /* iterate all rails and post RDMA local read */
5413+ for (uint16_t rail_id = 0 ; rail_id < ep->num_rails ; rail_id++) {
5414+ comm_rail = rdma_recv_comm_get_rail (r_comm, rail_id);
5415+ struct fid_mr *mr_handle = NULL ;
5416+
5417+ freelist_regmr_fn_handle_t *fl_handle =
5418+ (freelist_regmr_fn_handle_t *)flush_data->flush_fl_elem ->mr_handle ;
5419+ void *desc = fi_mr_desc (fl_handle->mr_handle ->mr [rail_id].get ());
5420+ mr_handle = domain->flush_buff .mr_handle ->mr [rail_id].get ();
5421+ uint64_t cuda_key = 0ULL ;
5422+
5423+ if (mr_handle != NULL ) {
5424+ /* Extract remote key */
5425+ cuda_key = fi_mr_key (mr_handle);
5426+ if (OFI_UNLIKELY (cuda_key == FI_KEY_NOTAVAIL)) {
5427+ NCCL_OFI_WARN (" Memory registration may not have completed." );
5428+ rc = -FI_ENODATA;
5429+ goto exit;
5430+ }
5431+ }
5432+
53975433 uint64_t *host_buff_addr = get_flush_buffer_for_rail (flush_data->flush_fl_elem ->ptr , rail_id);
53985434 rc = fi_read (comm_rail->local_ep ,
53995435 (void *)host_buff_addr,
54005436 NCCL_OFI_DEFAULT_CPU_CACHE_LINE_SIZE, desc, comm_rail->local_addr ,
54015437 (uint64_t )domain->flush_buff .buffer ,
54025438 cuda_key, rdma_req_get_ofi_context (req, rail_id));
5403- #endif
54045439 if ((rc != 0 ) && (rc != -FI_EAGAIN)) {
54055440 NCCL_OFI_WARN (" Error posting flush request. RC: %zd, Error: %s" ,
54065441 rc, fi_strerror (-rc));
@@ -5411,6 +5446,11 @@ static int post_flush_req(nccl_net_ofi_rdma_req_t *req)
54115446 exit:
54125447 return (int )rc;
54135448}
5449+ #else
5450+ static int post_flush_req (nccl_net_ofi_rdma_req_t *req) {
5451+ return -FI_EOPNOTSUPP;
5452+ }
5453+ #endif
54145454
54155455static inline int check_post_rx_buff_req (nccl_net_ofi_rdma_req_t *rx_buff_req)
54165456{
0 commit comments