Skip to content

Commit 7036e91

Browse files
Mike Ruberryfacebook-github-bot
Mike Ruberry
authored andcommitted
Revert D23323486: DPP Async Tracing
Test Plan: revert-hammer Differential Revision: D23323486 (71673b3) Original commit changeset: 4b6ca6c0e320 fbshipit-source-id: c6bd6d277aca070bef2de3522c2a60e23b4395ad
1 parent 2435d94 commit 7036e91

File tree

5 files changed

+18
-133
lines changed

5 files changed

+18
-133
lines changed

aten/src/ATen/record_function.cpp

+12-25
Original file line numberDiff line numberDiff line change
@@ -92,14 +92,11 @@ class CallbackManager {
9292
bool found_needs_ids = false;
9393
auto init_handles = [
9494
scope, &found_active_cb, &found_needs_inputs, &found_needs_ids](
95-
CallbackHandles& handles, RecordFunctionCallbacks& cbs, ObserverContextList& ctx_list) {
95+
CallbackHandles& handles, RecordFunctionCallbacks& cbs) {
9696
handles.clear();
97-
98-
size_t num_callbacks = 0;
9997
for (const auto& cb : cbs) {
10098
if (cb.first.shouldRun(scope)) {
10199
handles.push_back(cb.second);
102-
++num_callbacks;
103100
found_active_cb = true;
104101
if (cb.first.needsInputs()) {
105102
found_needs_inputs = true;
@@ -109,12 +106,10 @@ class CallbackManager {
109106
}
110107
}
111108
}
112-
// Pre-allocate observer context list with nullptr.
113-
ctx_list.resize(num_callbacks);
114109
};
115110

116-
init_handles(rec_fn.sorted_active_tls_handles_, sorted_tls_callbacks_, rec_fn.tls_ctx_);
117-
init_handles(rec_fn.sorted_active_global_handles_, sorted_global_callbacks_, rec_fn.global_ctx_);
111+
init_handles(rec_fn.sorted_active_tls_handles_, sorted_tls_callbacks_);
112+
init_handles(rec_fn.sorted_active_global_handles_, sorted_global_callbacks_);
118113
rec_fn.active = found_active_cb;
119114
rec_fn.needs_inputs = found_needs_inputs;
120115
if (found_needs_ids && found_active_cb) {
@@ -126,13 +121,11 @@ class CallbackManager {
126121
mergeRunCallbacks(
127122
sorted_global_callbacks_,
128123
rf.sorted_active_global_handles_,
129-
rf.global_ctx_,
130124
/* is_start */ true,
131125
rf);
132126
mergeRunCallbacks(
133127
sorted_tls_callbacks_,
134128
rf.sorted_active_tls_handles_,
135-
rf.tls_ctx_,
136129
/* is_start */ true,
137130
rf);
138131
rf.called_start_callbacks_ = true;
@@ -142,30 +135,21 @@ class CallbackManager {
142135
mergeRunCallbacks(
143136
sorted_global_callbacks_,
144137
rf.sorted_active_global_handles_,
145-
rf.global_ctx_,
146138
/* is_start */ false,
147139
rf);
148140
mergeRunCallbacks(
149141
sorted_tls_callbacks_,
150142
rf.sorted_active_tls_handles_,
151-
rf.tls_ctx_,
152143
/* is_start */ false,
153144
rf);
154145
}
155146

156147
private:
157148
bool tryRunCallback(
158-
const RecordFunctionCallback& rfcb,
159-
RecordFunction& rf,
160-
std::unique_ptr<ObserverContext>& ctx,
161-
bool is_start) {
149+
const std::function<void(const RecordFunction&)>& fn,
150+
RecordFunction& rf) {
162151
try {
163-
if (is_start) {
164-
ctx = rfcb.start()(rf);
165-
}
166-
else {
167-
rfcb.end()(rf, ctx.get());
168-
}
152+
fn(rf);
169153
return true;
170154
} catch (const std::exception &e) {
171155
LOG(WARNING) << "Exception in RecordFunction callback: "
@@ -181,12 +165,11 @@ class CallbackManager {
181165
void mergeRunCallbacks(
182166
const RecordFunctionCallbacks& sorted_callbacks,
183167
const CallbackHandles& sorted_handles,
184-
ObserverContextList& ctx_list,
185168
bool is_start,
186169
RecordFunction& rf) {
187170
size_t num_executed = 0;
188171
size_t idx_c = 0;
189-
for (size_t idx_h = 0; idx_h < sorted_handles.size() && idx_h < ctx_list.size(); ++idx_h) {
172+
for (size_t idx_h = 0; idx_h < sorted_handles.size(); ++idx_h) {
190173
while (idx_c < sorted_callbacks.size() &&
191174
sorted_callbacks[idx_c].second < sorted_handles[idx_h]) {
192175
++idx_c;
@@ -195,7 +178,11 @@ class CallbackManager {
195178
break;
196179
}
197180
if (sorted_callbacks[idx_c].second == sorted_handles[idx_h]) {
198-
tryRunCallback(sorted_callbacks[idx_c].first, rf, ctx_list[idx_h], is_start);
181+
if (is_start) {
182+
tryRunCallback(sorted_callbacks[idx_c].first.start(), rf);
183+
} else {
184+
tryRunCallback(sorted_callbacks[idx_c].first.end(), rf);
185+
}
199186
++num_executed;
200187
}
201188
}

aten/src/ATen/record_function.h

+6-39
Original file line numberDiff line numberDiff line change
@@ -67,16 +67,7 @@ struct TORCH_API StringView {
6767
// Soft limit on the number of callbacks to use;
6868
constexpr std::size_t kSoftLimitCallbacks = 4;
6969

70-
// An abstract base class for various observer contexts that can be attached to
71-
// the RecordFunction.
72-
struct ObserverContext {
73-
virtual ~ObserverContext() {}
74-
protected:
75-
ObserverContext() {}
76-
};
77-
7870
typedef c10::SmallVector<uint64_t, kSoftLimitCallbacks> CallbackHandles;
79-
typedef std::vector<std::unique_ptr<ObserverContext>> ObserverContextList;
8071
typedef uint64_t RecordFunctionHandle;
8172

8273
struct TORCH_API RecordFunction {
@@ -173,15 +164,6 @@ struct TORCH_API RecordFunction {
173164
// public because of anonymous "friend" class
174165
CallbackHandles sorted_active_tls_handles_;
175166
CallbackHandles sorted_active_global_handles_;
176-
177-
// Stores various ObserverContext objects with event metadata for thread local
178-
// callbacks.
179-
ObserverContextList tls_ctx_;
180-
181-
// Stores various ObserverContext objects with event metadata for global
182-
// callbacks.
183-
ObserverContextList global_ctx_;
184-
185167
// Whether this RecordFunction runs any callbacks
186168
bool active = false;
187169
/// Whether any of the picked callbacks require inputs
@@ -216,8 +198,6 @@ struct TORCH_API RecordFunction {
216198
* RecordFunctionCallback represents a pair of callbacks to be used with
217199
* RecordFunction, members:
218200
* start, end - the callbacks to run when entering and exiting the scope;
219-
* optionally, the start callback may return an ObserverContext which will
220-
* be passed to the end callback, use appropriate constructor accordingly.
221201
* needs_inputs - whether the callbacks need the inputs passed from the observed
222202
* function/range; NOTE: passing the inputs incurs an additional overhead;
223203
* sampling_probability - if not 1.0, then the callback is probabilistically sampled
@@ -231,25 +211,12 @@ struct TORCH_API RecordFunction {
231211
*/
232212
class TORCH_API RecordFunctionCallback {
233213
public:
234-
// This interface supports observers that require passing an ObserverContext
235-
// between start and end callbacks.
236-
explicit RecordFunctionCallback(
237-
std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start,
238-
std::function<void(const RecordFunction&, ObserverContext*)> end =
239-
[](const RecordFunction&, ObserverContext*) {}):
240-
start_(std::move(start)),
241-
end_(std::move(end)) {
242-
scopes_.fill(true);
243-
}
244-
245-
// This interface is for observers that do not pass an ObserverContext object
246-
// between start and end callbacks.
247214
explicit RecordFunctionCallback(
248215
std::function<void(const RecordFunction&)> start,
249216
std::function<void(const RecordFunction&)> end =
250217
[](const RecordFunction&) {}):
251-
start_{[start](const RecordFunction& rf) { start(rf); return nullptr; }},
252-
end_{[end](const RecordFunction& rf, ObserverContext*) { end(rf); }} {
218+
start_(std::move(start)),
219+
end_(std::move(end)) {
253220
scopes_.fill(true);
254221
}
255222

@@ -305,20 +272,20 @@ class TORCH_API RecordFunctionCallback {
305272
return scopes_[(size_t)sc];
306273
}
307274

308-
inline const std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)>& start() const {
275+
inline const std::function<void(const RecordFunction&)>& start() const {
309276
return start_;
310277
}
311278

312-
inline const std::function<void(const RecordFunction&, ObserverContext*)>& end() const {
279+
inline const std::function<void(const RecordFunction&)>& end() const {
313280
return end_;
314281
}
315282

316283
// whether the callbacks should run in the given scope
317284
bool shouldRun(RecordScope scope) const;
318285

319286
private:
320-
std::function<std::unique_ptr<ObserverContext>(const RecordFunction&)> start_;
321-
std::function<void(const RecordFunction&, ObserverContext*)> end_;
287+
std::function<void(const RecordFunction&)> start_;
288+
std::function<void(const RecordFunction&)> end_;
322289
std::function<bool(const RecordFunctionCallback&)> should_run_;
323290
bool needs_inputs_ = false;
324291
bool needs_ids_ = false;

c10/macros/Macros.h

-3
Original file line numberDiff line numberDiff line change
@@ -304,9 +304,6 @@ __host__ __device__
304304
#endif // ANDROID / IOS
305305

306306
// Portably determine if a type T is trivially copyable or not.
307-
// Warning: __has_trivial_copy for GCC may not always detect the non-POD
308-
// correctly. For example, T = std::unique_ptr may evaluate to true and be
309-
// treated as POD. This can cause unexpected behavior.
310307
#if defined(__GNUG__) && __GNUC__ < 5
311308
#define C10_IS_TRIVIALLY_COPYABLE(T) __has_trivial_copy(T)
312309
#else

c10/util/SmallVector.h

-3
Original file line numberDiff line numberDiff line change
@@ -378,9 +378,6 @@ class SmallVectorTemplateBase<T, true> : public SmallVectorTemplateCommon<T> {
378378

379379
/// This class consists of common code factored out of the SmallVector class to
380380
/// reduce code duplication based on the SmallVector 'N' template parameter.
381-
/// Warning: C10_IS_TRIVIALLY_COPYABLE may not always detect non-POD
382-
/// type correctly. For example, std::unique_ptr may be treated as POD and cause
383-
/// memory leaks.
384381
template <typename T>
385382
class SmallVectorImpl
386383
: public SmallVectorTemplateBase<T, C10_IS_TRIVIALLY_COPYABLE(T)> {

test/cpp/jit/test_misc.cpp

-63
Original file line numberDiff line numberDiff line change
@@ -1036,69 +1036,6 @@ void testRecordFunction() {
10361036

10371037
clearCallbacks();
10381038

1039-
// START: thread local / global context check callbacks
1040-
struct TestContext : public ObserverContext {
1041-
int a{0};
1042-
std::string b;
1043-
};
1044-
ids.clear();
1045-
{ // START: global test
1046-
const int test_val = 123;
1047-
const std::string test_str = "test str";
1048-
addGlobalCallback(RecordFunctionCallback(
1049-
[test_str, &ids](const RecordFunction& /* unused */) {
1050-
auto ctx = std::make_unique<TestContext>();
1051-
ctx->a = test_val;
1052-
ctx->b = test_str;
1053-
ids.push_back(1);
1054-
return ctx;
1055-
},
1056-
[test_str](
1057-
const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
1058-
auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
1059-
TORCH_CHECK(ctx_ptr != nullptr);
1060-
TORCH_CHECK(ctx->a == test_val);
1061-
TORCH_CHECK(ctx->b == test_str);
1062-
}));
1063-
1064-
{ RECORD_USER_SCOPE("test"); }
1065-
1066-
TORCH_CHECK(ids.size() == 1);
1067-
TORCH_CHECK(ids[0] == 1);
1068-
ids.clear();
1069-
} // END: global test
1070-
{ // START: thread local test
1071-
auto ctx_th = std::thread([&ids]() {
1072-
const int test_val = 234;
1073-
const std::string test_str = "test thread str";
1074-
addThreadLocalCallback(RecordFunctionCallback(
1075-
[test_str, &ids](const RecordFunction& /* unused */) {
1076-
auto ctx = std::make_unique<TestContext>();
1077-
ctx->a = test_val;
1078-
ctx->b = test_str;
1079-
ids.push_back(2);
1080-
return ctx;
1081-
},
1082-
[test_str](
1083-
const RecordFunction& /* unused */, ObserverContext* ctx_ptr) {
1084-
auto ctx = dynamic_cast<TestContext*>(ctx_ptr);
1085-
TORCH_CHECK(ctx_ptr != nullptr);
1086-
TORCH_CHECK(ctx->a == test_val);
1087-
TORCH_CHECK(ctx->b == test_str);
1088-
}));
1089-
1090-
// Will call both global and thread local callbacks.
1091-
{ RECORD_USER_SCOPE("test_thread"); }
1092-
});
1093-
ctx_th.join();
1094-
TORCH_CHECK(ids.size() == 2);
1095-
TORCH_CHECK(std::find(ids.begin(), ids.end(), 1) != ids.end());
1096-
TORCH_CHECK(std::find(ids.begin(), ids.end(), 2) != ids.end());
1097-
ids.clear();
1098-
} // END: thread local test
1099-
1100-
clearCallbacks();
1101-
11021039
// test should_run
11031040

11041041
bool ran = false;

0 commit comments

Comments
 (0)