Skip to content

Commit 0348f3f

Browse files
zlsh80826jeng1220
andauthored
Enhance several unit tests (#62477) (#62776)
* Manually release predictor_tuned * Add indices to no_cast_list to keep it as fp32 * Set both atol and rtol for the fp16 test_trt_convert_solve * Merge branch 'rewang/fix_test_sparse_fused_attention_seed' into 'nv-2.6.0' Fix test_sparse_fused_attention random seed See merge request dl/paddle/paddle!312 --------- Signed-off-by: rewang <[email protected]> Co-authored-by: Ryan Jeng <[email protected]>
1 parent 97ffa07 commit 0348f3f

File tree

4 files changed

+8
-1
lines changed

4 files changed

+8
-1
lines changed

test/cpp/inference/api/trt_dynamic_shape_test.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ void TestTunedDynamic() {
191191
output_t->copy_to_cpu(out_data.data());
192192
};
193193
check_func(predictor_tuned.get());
194+
predictor_tuned.reset(nullptr);
194195

195196
// check tuned_dynamic_shape
196197
AnalysisConfig config;

test/ir/inference/test_trt_convert_lookup_table.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ def generate_input2(dims, attrs: List[Dict[str, Any]]):
8080
)
8181
},
8282
outputs=["out_data"],
83+
no_cast_list=["indices"],
8384
)
8485

8586
yield program_config

test/ir/inference/test_trt_convert_solve.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ def clear_dynamic_shape():
9090
yield self.create_inference_config(), (1, 3), 1e-5
9191

9292
self.trt_param.precision = paddle_infer.PrecisionType.Half
93-
yield self.create_inference_config(), (1, 3), 1e-3
93+
yield self.create_inference_config(), (1, 3), (1e-3, 1e-3)
9494

9595
def test(self):
9696
self.run_test()

test/legacy_test/test_sparse_fused_attention_op.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ def get_cuda_version():
4242
)
4343
class TestSparseAttentionAPI1(unittest.TestCase):
4444
def setUp(self):
45+
paddle.seed(0)
4546
self.batch_size = 16
4647
self.num_heads = 16
4748
self.seq_len = 128
@@ -134,6 +135,7 @@ def test_dygraph(self):
134135

135136
class TestSparseAttentionAPI2(TestSparseAttentionAPI1):
136137
def setUp(self):
138+
super().setUp()
137139
self.batch_size = 16
138140
self.num_heads = 16
139141
self.seq_len = 128
@@ -144,6 +146,7 @@ def setUp(self):
144146

145147
class TestSparseAttentionAPI3(TestSparseAttentionAPI1):
146148
def setUp(self):
149+
super().setUp()
147150
self.batch_size = 16
148151
self.num_heads = 16
149152
self.seq_len = 512
@@ -154,6 +157,7 @@ def setUp(self):
154157

155158
class TestSparseAttentionAPI4(TestSparseAttentionAPI1):
156159
def setUp(self):
160+
super().setUp()
157161
self.batch_size = 16
158162
self.num_heads = 16
159163
self.seq_len = 512
@@ -164,6 +168,7 @@ def setUp(self):
164168

165169
class TestSparseAttentionAPI5(TestSparseAttentionAPI1):
166170
def setUp(self):
171+
super().setUp()
167172
self.batch_size = 16
168173
self.num_heads = 16
169174
self.seq_len = 512

0 commit comments

Comments
 (0)