@@ -67,16 +67,7 @@ struct TORCH_API StringView {
67
67
// Soft limit on the number of callbacks to use;
68
68
constexpr std::size_t kSoftLimitCallbacks = 4 ;
69
69
70
- // An abstract base class for various observer contexts that can be attached to
71
- // the RecordFunction.
72
- struct ObserverContext {
73
- virtual ~ObserverContext () {}
74
- protected:
75
- ObserverContext () {}
76
- };
77
-
78
70
typedef c10::SmallVector<uint64_t , kSoftLimitCallbacks > CallbackHandles;
79
- typedef std::vector<std::unique_ptr<ObserverContext>> ObserverContextList;
80
71
typedef uint64_t RecordFunctionHandle;
81
72
82
73
struct TORCH_API RecordFunction {
@@ -173,15 +164,6 @@ struct TORCH_API RecordFunction {
173
164
// public because of anonymous "friend" class
174
165
CallbackHandles sorted_active_tls_handles_;
175
166
CallbackHandles sorted_active_global_handles_;
176
-
177
- // Stores various ObserverContext objects with event metadata for thread local
178
- // callbacks.
179
- ObserverContextList tls_ctx_;
180
-
181
- // Stores various ObserverContext objects with event metadata for global
182
- // callbacks.
183
- ObserverContextList global_ctx_;
184
-
185
167
// Whether this RecordFunction runs any callbacks
186
168
bool active = false ;
187
169
// / Whether any of the picked callbacks require inputs
@@ -216,8 +198,6 @@ struct TORCH_API RecordFunction {
216
198
* RecordFunctionCallback represents a pair of callbacks to be used with
217
199
* RecordFunction, members:
218
200
* start, end - the callbacks to run when entering and exiting the scope;
219
- * optionally, the start callback may return an ObserverContext which will
220
- * be passed to the end callback, use appropriate constructor accordingly.
221
201
* needs_inputs - whether the callbacks need the inputs passed from the observed
222
202
* function/range; NOTE: passing the inputs incurs an additional overhead;
223
203
* sampling_probability - if not 1.0, then the callback is probabilistically sampled
@@ -231,25 +211,12 @@ struct TORCH_API RecordFunction {
231
211
*/
232
212
class TORCH_API RecordFunctionCallback {
233
213
public:
234
- // This interface supports observers that require passing an ObserverContext
235
- // between start and end callbacks.
236
- explicit RecordFunctionCallback (
237
- std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start,
238
- std::function<void(const RecordFunction&, ObserverContext*)> end =
239
- [](const RecordFunction&, ObserverContext*) {}):
240
- start_ (std::move(start)),
241
- end_ (std::move(end)) {
242
- scopes_.fill (true );
243
- }
244
-
245
- // This interface is for observers that do not pass an ObserverContext object
246
- // between start and end callbacks.
247
214
explicit RecordFunctionCallback (
248
215
std::function<void (const RecordFunction&)> start,
249
216
std::function<void(const RecordFunction&)> end =
250
217
[](const RecordFunction&) {}):
251
- start_{[start]( const RecordFunction& rf) { start (rf); return nullptr ; }} ,
252
- end_{[end]( const RecordFunction& rf, ObserverContext*) { end (rf); }} {
218
+ start_ (std::move(start)) ,
219
+ end_ (std::move(end)) {
253
220
scopes_.fill (true );
254
221
}
255
222
@@ -305,20 +272,20 @@ class TORCH_API RecordFunctionCallback {
305
272
return scopes_[(size_t )sc];
306
273
}
307
274
308
- inline const std::function<std::unique_ptr<ObserverContext> (const RecordFunction&)>& start () const {
275
+ inline const std::function<void (const RecordFunction&)>& start () const {
309
276
return start_;
310
277
}
311
278
312
- inline const std::function<void (const RecordFunction&, ObserverContext* )>& end () const {
279
+ inline const std::function<void (const RecordFunction&)>& end () const {
313
280
return end_;
314
281
}
315
282
316
283
// whether the callbacks should run in the given scope
317
284
bool shouldRun (RecordScope scope) const ;
318
285
319
286
private:
320
- std::function<std::unique_ptr<ObserverContext> (const RecordFunction&)> start_;
321
- std::function<void (const RecordFunction&, ObserverContext* )> end_;
287
+ std::function<void (const RecordFunction&)> start_;
288
+ std::function<void (const RecordFunction&)> end_;
322
289
std::function<bool (const RecordFunctionCallback&)> should_run_;
323
290
bool needs_inputs_ = false ;
324
291
bool needs_ids_ = false ;
0 commit comments