@@ -51,15 +51,15 @@ static inline void nvtx_end(nvtxRangeId_t id) {
5151}
5252
5353#define NCCL_OFI_TRACE_SEND_NVTX (dev , size , comm , msg_seq_num , request , nccl_req ) do { \
54- if (NCCL_OFI_NVTX_TRACE_PER_COMM ) { \
54+ if (ofi_nccl_nvtx_trace_dimension() == NVTX_TRACE_DIMENSION::PER_COMM ) { \
5555 nvtxDomainHandle_t handle = ((nccl_net_ofi_rdma_send_comm_t*)comm) \
5656 ->nvtx_domain[msg_seq_num % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; \
5757 get_send_data(request)->trace_id = nvtx_start_domain(true, handle, "Send", 0xeb9234); \
5858 } \
5959} while (0)
6060
6161#define NCCL_OFI_TRACE_SEND_END_NVTX (request ) do { \
62- if (NCCL_OFI_NVTX_TRACE_PER_COMM ) { \
62+ if (ofi_nccl_nvtx_trace_dimension() == NVTX_TRACE_DIMENSION::PER_COMM ) { \
6363 nvtxDomainHandle_t handle = ((nccl_net_ofi_rdma_send_comm_t*)(request->comm)) \
6464 ->nvtx_domain[request->msg_seq_num % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; \
6565 nvtx_end_domain(handle, get_send_data(request)->trace_id); \
@@ -68,98 +68,98 @@ static inline void nvtx_end(nvtxRangeId_t id) {
6868
6969#define NCCL_OFI_TRACE_EAGER_SEND_START_NVTX (dev , rail_id , size , comm , msg_seq_num , request ) do { \
7070 nvtxDomainHandle_t handle; \
71- if (NCCL_OFI_NVTX_TRACE_PER_COMM ) { \
71+ if (ofi_nccl_nvtx_trace_dimension() == NVTX_TRACE_DIMENSION::PER_COMM ) { \
7272 handle = ((nccl_net_ofi_rdma_send_comm_t*)comm)->nvtx_domain[msg_seq_num % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; \
7373 get_send_data(request)->seg_trace_id[rail_id] = nvtx_start_domain(true, handle, "Send_eager", 0x0000FF); \
7474 } \
75- if (NCCL_OFI_NVTX_TRACE_PER_DEV ) { \
75+ if (ofi_nccl_nvtx_trace_dimension() == NVTX_TRACE_DIMENSION::PER_DEV ) { \
7676 handle = (static_cast<nccl_net_ofi_rdma_ep_t *>(comm->ep)->rdma_endpoint_get_device())->nvtx_domain[rail_id]; \
7777 get_send_data(request)->seg_trace_id[rail_id] = nvtx_start_domain(true, handle, "Send_eager", 0x0000FF); \
7878 } \
7979} while (0)
8080
8181#define NCCL_OFI_TRACE_EAGER_SEND_COMPLETE_NVTX (dev , rail_id , comm , msg_seq_num , request ) do { \
8282 nvtxDomainHandle_t handle; \
83- if (NCCL_OFI_NVTX_TRACE_PER_COMM ) { \
83+ if (ofi_nccl_nvtx_trace_dimension() == NVTX_TRACE_DIMENSION::PER_COMM ) { \
8484 handle = ((nccl_net_ofi_rdma_send_comm_t*)comm)->nvtx_domain[msg_seq_num % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; \
8585 nvtx_end_domain(handle, get_send_data(request)->seg_trace_id[rail_id]); \
8686 } \
87- if (NCCL_OFI_NVTX_TRACE_PER_DEV ) { \
87+ if (ofi_nccl_nvtx_trace_dimension() == NVTX_TRACE_DIMENSION::PER_DEV ) { \
8888 handle = (static_cast<nccl_net_ofi_rdma_ep_t *>(comm->ep)->rdma_endpoint_get_device())->nvtx_domain[rail_id]; \
8989 nvtx_end_domain(handle, get_send_data(request)->seg_trace_id[rail_id]); \
9090 } \
9191} while(0)
9292
9393#define NCCL_OFI_TRACE_SEND_CTRL_RECV_NVTX (dev , rail_id , comm , msg_seq_num ) do { \
9494 nvtxDomainHandle_t handle; \
95- if (NCCL_OFI_NVTX_TRACE_PER_COMM ) { \
95+ if (ofi_nccl_nvtx_trace_dimension() == NVTX_TRACE_DIMENSION::PER_COMM ) { \
9696 handle = ((nccl_net_ofi_rdma_send_comm_t*)comm)->nvtx_domain[msg_seq_num % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; \
9797 nvtx_mark_domain(handle, "Send_ctrl_recv", 0x00ffff); \
9898 } \
99- if (NCCL_OFI_NVTX_TRACE_PER_DEV ) { \
99+ if (ofi_nccl_nvtx_trace_dimension() == NVTX_TRACE_DIMENSION::PER_DEV ) { \
100100 handle = static_cast<nccl_net_ofi_rdma_ep_t *>(s_comm->base.base.ep)->rdma_endpoint_get_device()->nvtx_domain[rail_id]; \
101101 nvtx_mark_domain(handle, "Send_ctrl_recv", 0x00ffff); \
102102 } \
103103} while (0)
104104
105105#define NCCL_OFI_TRACE_WRITE_CTRL_START_NVTX (dev , rail_id , comm , req , msg_seq_num ) do { \
106106 nvtxDomainHandle_t handle; \
107- if (NCCL_OFI_NVTX_TRACE_PER_COMM ) { \
107+ if (ofi_nccl_nvtx_trace_dimension() == NVTX_TRACE_DIMENSION::PER_COMM ) { \
108108 handle = ((nccl_net_ofi_rdma_recv_comm_t *)comm)->nvtx_domain[msg_seq_num % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; \
109109 get_recv_data(req)->write_ctrl_trace_id = nvtx_start_domain(true, handle, "Write_ctrl_start", 0x00ffff); \
110110 } \
111- if (NCCL_OFI_NVTX_TRACE_PER_DEV ) { \
111+ if (ofi_nccl_nvtx_trace_dimension() == NVTX_TRACE_DIMENSION::PER_DEV ) { \
112112 handle = static_cast<nccl_net_ofi_rdma_ep_t *>(comm->ep)->rdma_endpoint_get_device()->nvtx_domain[rail_id]; \
113113 get_recv_data(req)->write_ctrl_trace_id = nvtx_start_domain(true, handle, "Write_ctrl_start", 0x00ffff); \
114114 } \
115115} while (0)
116116
117117#define NCCL_OFI_TRACE_WRITE_CTRL_END_NVTX (dev , rail_id , comm , req , msg_seq_num ) do { \
118118 nvtxDomainHandle_t handle; \
119- if (NCCL_OFI_NVTX_TRACE_PER_COMM ) { \
119+ if (ofi_nccl_nvtx_trace_dimension() == NVTX_TRACE_DIMENSION::PER_COMM ) { \
120120 handle = ((nccl_net_ofi_rdma_recv_comm_t *)comm)->nvtx_domain[msg_seq_num % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; \
121121 nvtx_end_domain(handle, get_recv_data(req)->write_ctrl_trace_id); \
122122 } \
123- if (NCCL_OFI_NVTX_TRACE_PER_DEV ) { \
123+ if (ofi_nccl_nvtx_trace_dimension() == NVTX_TRACE_DIMENSION::PER_DEV ) { \
124124 handle = static_cast<nccl_net_ofi_rdma_ep_t *>(comm->ep)->rdma_endpoint_get_device()->nvtx_domain[rail_id]; \
125125 nvtx_end_domain(handle, get_recv_data(req)->write_ctrl_trace_id);\
126126 } \
127127} while (0)
128128
129129#define NCCL_OFI_TRACE_SEND_WRITE_SEG_START_NVTX (dev , rail_id , size , comm , msg_seq_num , request ) do { \
130130 nvtxDomainHandle_t handle; \
131- if (NCCL_OFI_NVTX_TRACE_PER_COMM ) { \
131+ if (ofi_nccl_nvtx_trace_dimension() == NVTX_TRACE_DIMENSION::PER_COMM ) { \
132132 handle = ((nccl_net_ofi_rdma_send_comm_t*)comm)->nvtx_domain[msg_seq_num % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; \
133133 get_send_data(request)->seg_trace_id[rail_id] = nvtx_start_domain(true, handle, "Send_write_seg", 0xff0000); \
134134 } \
135- if (NCCL_OFI_NVTX_TRACE_PER_DEV ) { \
135+ if (ofi_nccl_nvtx_trace_dimension() == NVTX_TRACE_DIMENSION::PER_DEV ) { \
136136 handle = static_cast<nccl_net_ofi_rdma_ep_t *>(comm->ep)->rdma_endpoint_get_device()->nvtx_domain[rail_id]; \
137137 get_send_data(request)->seg_trace_id[rail_id] = nvtx_start_domain(true, handle, "Send_write_seg", 0xff0000); \
138138 } \
139139} while(0)
140140
141141#define NCCL_OFI_TRACE_SEND_WRITE_SEG_COMPLETE_NVTX (dev , rail_id , comm , msg_seq_num , request ) do { \
142142 nvtxDomainHandle_t handle; \
143- if (NCCL_OFI_NVTX_TRACE_PER_COMM ) { \
143+ if (ofi_nccl_nvtx_trace_dimension() == NVTX_TRACE_DIMENSION::PER_COMM ) { \
144144 handle = ((nccl_net_ofi_rdma_send_comm_t*)comm)->nvtx_domain[msg_seq_num % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; \
145145 nvtx_end_domain(handle, get_send_data(request)->seg_trace_id[rail_id]); \
146146 } \
147- if (NCCL_OFI_NVTX_TRACE_PER_DEV ) { \
147+ if (ofi_nccl_nvtx_trace_dimension() == NVTX_TRACE_DIMENSION::PER_DEV ) { \
148148 handle = static_cast<nccl_net_ofi_rdma_ep_t *>(comm->ep)->rdma_endpoint_get_device()->nvtx_domain[rail_id]; \
149149 nvtx_end_domain(handle, get_send_data(request)->seg_trace_id[rail_id]); \
150150 } \
151151} while(0)
152152
153153#define NCCL_OFI_TRACE_RECV_NVTX (dev , r_comm , size , request , nccl_req ) do { \
154- if (NCCL_OFI_NVTX_TRACE_PER_COMM ) { \
154+ if (ofi_nccl_nvtx_trace_dimension() == NVTX_TRACE_DIMENSION::PER_COMM ) { \
155155 nvtxDomainHandle_t handle = ((nccl_net_ofi_rdma_recv_comm_t *)request->comm) \
156156 ->nvtx_domain[msg_seq_num % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; \
157157 get_recv_data(request)->trace_id = nvtx_start_domain(true, handle, "Recv", 0x34EB37); \
158158 } \
159159} while(0)
160160
161161#define NCCL_OFI_TRACE_RECV_END_NVTX (request ) do { \
162- if (NCCL_OFI_NVTX_TRACE_PER_COMM ) { \
162+ if (ofi_nccl_nvtx_trace_dimension() == NVTX_TRACE_DIMENSION::PER_COMM ) { \
163163 nvtxDomainHandle_t handle = ((nccl_net_ofi_rdma_recv_comm_t *)request->comm) \
164164 ->nvtx_domain[request->msg_seq_num % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; \
165165 nvtx_end_domain(handle, get_recv_data(request)->trace_id); \
@@ -168,23 +168,23 @@ static inline void nvtx_end(nvtxRangeId_t id) {
168168
169169#define NCCL_OFI_TRACE_RECV_SEGMENT_COMPLETE_NVTX (dev , rail_id , size , request , msg_seq_num ) do { \
170170 nvtxDomainHandle_t handle; \
171- if (NCCL_OFI_NVTX_TRACE_PER_COMM ) { \
171+ if (ofi_nccl_nvtx_trace_dimension() == NVTX_TRACE_DIMENSION::PER_COMM ) { \
172172 handle = ((nccl_net_ofi_rdma_recv_comm_t *)request->comm)->nvtx_domain[msg_seq_num % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; \
173173 nvtx_mark_domain(handle, "Recv_segment_complete", 0xff0000); \
174174 } \
175- if (NCCL_OFI_NVTX_TRACE_PER_DEV ) { \
175+ if (ofi_nccl_nvtx_trace_dimension() == NVTX_TRACE_DIMENSION::PER_DEV ) { \
176176 handle = static_cast<nccl_net_ofi_rdma_ep_t *>(request->comm->ep)->rdma_endpoint_get_device()->nvtx_domain[rail_id]; \
177177 nvtx_mark_domain(handle, "Recv_segment_complete", 0xff0000); \
178178 } \
179179} while(0)
180180
181181#define NCCL_OFI_TRACE_EAGER_RECV_NVTX (dev , rail_id , comm , msg_seq_num ) do { \
182182 nvtxDomainHandle_t handle; \
183- if (NCCL_OFI_NVTX_TRACE_PER_COMM ) { \
183+ if (ofi_nccl_nvtx_trace_dimension() == NVTX_TRACE_DIMENSION::PER_COMM ) { \
184184 handle = comm->nvtx_domain[msg_seq_num % NCCL_OFI_N_NVTX_DOMAIN_PER_COMM]; \
185185 nvtx_mark_domain(handle, "Eager_recv", 0x0000FF); \
186186 } \
187- if (NCCL_OFI_NVTX_TRACE_PER_DEV ) { \
187+ if (ofi_nccl_nvtx_trace_dimension() == NVTX_TRACE_DIMENSION::PER_DEV ) { \
188188 handle = static_cast<nccl_net_ofi_rdma_ep_t *>(r_comm->base.base.ep)->rdma_endpoint_get_device()->nvtx_domain[rail_id]; \
189189 nvtx_mark_domain(handle, "Eager_recv", 0x0000FF); \
190190 } \
0 commit comments