Skip to content

Commit a0d3f71

Browse files
andylytensorflower-gardener
authored andcommitted
Migrate constant sinking pass to use declarative pass registration instead of manually defined pass registration (NFC).
Pass documentation is also improved and migrated to the declarative pass spec. Pass is also renamed to better reflect what ops it operates on (tf_device.cluster). The pass does not operate on the tf_executor dialect. PiperOrigin-RevId: 348120524 Change-Id: Ic828b720a8a087548c162a1299aef1655b479cd6
1 parent 2235bb1 commit a0d3f71

File tree

4 files changed

+51
-18
lines changed

4 files changed

+51
-18
lines changed

tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,7 +110,7 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
110110
pm.addPass(CreateTPUOutsideCompilationClusterPass());
111111
pm.addPass(CreateTPUExtractOutsideCompilationPass());
112112

113-
pm.addNestedPass<FuncOp>(tf_executor::CreateTFExecutorConstantSinkingPass());
113+
pm.addNestedPass<FuncOp>(TFDevice::CreateClusterConstantSinkingPass());
114114
pm.addPass(TF::CreateResourceDeviceInferencePass());
115115
pm.addPass(TFDevice::CreateClusterOutliningPass());
116116
pm.addPass(CreateTPUDynamicPaddingMapperPass());

tensorflow/compiler/mlir/tensorflow/transforms/passes.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -237,18 +237,18 @@ CreateTFExecutorTPUV1IslandInliningPass();
237237

238238
// Creates a pass to prune tf_executor.graph from dead nodes.
239239
std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorGraphPruningPass();
240-
241-
// Sink `tf.Const` operations in the LaunchOp region using them. This is
242-
// performed in order to limit the number of values implicitly captured in this
243-
// region before outlining.
244-
std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorConstantSinkingPass();
245240
} // namespace tf_executor
246241

247242
namespace TFDevice {
248243
// Creates a pass that forms clusters from instructions that are assigned to
249244
// same device.
250245
std::unique_ptr<OperationPass<FuncOp>> CreateClusterFormationPass();
251246

247+
// Sinks `tf.Const` operations in the ClusterOp region using them. This is
248+
// performed in order to limit the number of values implicitly captured in this
249+
// region before outlining.
250+
std::unique_ptr<OperationPass<FuncOp>> CreateClusterConstantSinkingPass();
251+
252252
// Creates a pass that outlines regions of tf_device.launch operations.
253253
std::unique_ptr<OperationPass<ModuleOp>> CreateClusterOutliningPass();
254254

tensorflow/compiler/mlir/tensorflow/transforms/sink_constant.cc

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -25,21 +25,21 @@ limitations under the License.
2525
#include "mlir/Transforms/Passes.h" // from @llvm-project
2626
#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
2727
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
28-
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
2928
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
3029
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
30+
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes_detail.h"
3131
#include "tensorflow/compiler/mlir/tensorflow/utils/error_util.h"
3232

3333
#define DEBUG_TYPE "tf-executor-sink-constant"
3434

3535
namespace mlir {
36-
namespace tf_executor {
36+
namespace TFDevice {
3737

3838
namespace {
3939
using ::mlir::TF::ConstOp;
4040

41-
class ExecutorConstantSinking
42-
: public mlir::PassWrapper<ExecutorConstantSinking, FunctionPass> {
41+
class ClusterConstantSinkingPass
42+
: public TF::ClusterConstantSinkingPassBase<ClusterConstantSinkingPass> {
4343
void runOnFunction() override {
4444
getFunction().walk([](tf_device::ClusterOp cluster) {
4545
LLVM_DEBUG(llvm::dbgs() << "Visit " << *cluster.getOperation() << "\n");
@@ -82,16 +82,11 @@ class ExecutorConstantSinking
8282
}
8383
};
8484

85-
static mlir::PassRegistration<ExecutorConstantSinking> pass(
86-
"tf-device-constant-sinking",
87-
"Sink constants implicitly captured in a tf_device.cluster region. This "
88-
"reduces the number of arguments when outlining later.");
89-
9085
} // anonymous namespace
9186

92-
std::unique_ptr<OperationPass<FuncOp>> CreateTFExecutorConstantSinkingPass() {
93-
return std::make_unique<ExecutorConstantSinking>();
87+
std::unique_ptr<OperationPass<FuncOp>> CreateClusterConstantSinkingPass() {
88+
return std::make_unique<ClusterConstantSinkingPass>();
9489
}
9590

96-
} // namespace tf_executor
91+
} // namespace TFDevice
9792
} // namespace mlir

tensorflow/compiler/mlir/tensorflow/transforms/tf_passes.td

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -202,6 +202,44 @@ func @tpu_computation(%arg0: tensor<i32>, %arg1: tensor<i32>) -> (tensor<i32>, t
202202
let constructor = "TFTPU::CreateTPUClusterFormationPass()";
203203
}
204204

205+
def ClusterConstantSinkingPass : FunctionPass<"tf-device-constant-sinking"> {
206+
let summary = "Sinks constants implicitly captured in a tf_device.cluster region.";
207+
208+
let description = [{
209+
This pass sinks implicitly captured constants (`tf.Const` ops) used by and into
210+
a `tf_device.cluster` region. Performing this prior to outlining will reduce the
211+
number of arguments of the outlined function.
212+
213+
For example, the following:
214+
215+
```mlir
216+
func @cluster() -> tensor<i32> {
217+
%const = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
218+
%cluster = "tf_device.cluster"() ( {
219+
%identity = "tf.Identity"(%const) : (tensor<i32>) -> tensor<i32>
220+
tf_device.return %identity : tensor<i32>
221+
}) : () -> (tensor<i32>)
222+
return %cluster : tensor<i32>
223+
}
224+
```
225+
226+
will be transformed into:
227+
228+
```mlir
229+
func @cluster() -> tensor<i32> {
230+
%cluster = "tf_device.cluster"() ( {
231+
%const = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
232+
%identity = "tf.Identity"(%const) : (tensor<i32>) -> tensor<i32>
233+
tf_device.return %identity : tensor<i32>
234+
}) : () -> (tensor<i32>)
235+
return %cluster : tensor<i32>
236+
}
237+
```
238+
}];
239+
240+
let constructor = "TFDevice::CreateClusterConstantSinkingPass()";
241+
}
242+
205243
def TPUExtractOutsideCompilationPass : Pass<"tf-tpu-extract-outside-compilation", "ModuleOp"> {
206244
let summary = "Extracts TPU outside compilation computation to a separate tf_device.parallel_execute region.";
207245

0 commit comments

Comments
 (0)