@@ -341,7 +341,27 @@ INSTANTIATE_TEST_SUITE_P(
341341 return name;
342342 });
343343
344- class WhileTest : public subgraph_test_util ::ControlFlowOpTest {};
344+ class WhileTest : public subgraph_test_util ::ControlFlowOpTest {
345+ protected:
346+ TfLiteCustomAllocation NewCustomAlloc (size_t num_bytes,
347+ int required_alignment) {
348+ // Extra memory to ensure alignment.
349+ char * new_alloc = new char [num_bytes + required_alignment];
350+ char * new_underlying_buffer_aligned_ptr = reinterpret_cast <char *>(
351+ AlignTo (required_alignment, reinterpret_cast <intptr_t >(new_alloc)));
352+ custom_alloc_buffers_.emplace_back (new_alloc);
353+
354+ return TfLiteCustomAllocation (
355+ {new_underlying_buffer_aligned_ptr, num_bytes});
356+ }
357+
358+ intptr_t AlignTo (size_t alignment, intptr_t offset) {
359+ return offset % alignment == 0 ? offset
360+ : offset + (alignment - offset % alignment);
361+ }
362+
363+ std::vector<std::unique_ptr<char []>> custom_alloc_buffers_;
364+ };
345365
346366// The test builds a model that produces the i-th number of
347367// triangular number sequence: 1, 3, 6, 10, 15, 21, 28.
@@ -359,7 +379,15 @@ TEST_F(WhileTest, TestTriangularNumberSequence) {
359379 interpreter_->ResizeInputTensor (interpreter_->inputs ()[1 ], {1 });
360380 ASSERT_EQ (interpreter_->AllocateTensors (), kTfLiteOk );
361381 FillIntTensor (interpreter_->tensor (interpreter_->inputs ()[0 ]), {1 });
362- FillIntTensor (interpreter_->tensor (interpreter_->inputs ()[1 ]), {1 });
382+
383+ // Use custom allocation for second input, to ensure things work well for
384+ // non-traditional allocation types.
385+ auto alloc =
386+ NewCustomAlloc (interpreter_->tensor (interpreter_->inputs ()[1 ])->bytes ,
387+ kDefaultTensorAlignment );
388+ auto * input_data = reinterpret_cast <int *>(alloc.data );
389+ input_data[0 ] = 1 ;
390+ interpreter_->SetCustomAllocationForTensor (interpreter_->inputs ()[1 ], alloc);
363391
364392 ASSERT_EQ (interpreter_->Invoke (), kTfLiteOk );
365393 TfLiteTensor* output1 = interpreter_->tensor (interpreter_->outputs ()[0 ]);
0 commit comments