Skip to content

Commit c0c65bf

Browse files
authored
Merge pull request pytorch#696 from colesbury/unsqueeze
Add unsqueeze to THC
2 parents ed8e92f + 3884d36 commit c0c65bf

File tree

2 files changed

+28
-0
lines changed

2 files changed

+28
-0
lines changed

generic/THCTensor.c

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -516,6 +516,33 @@ void THCTensor_(squeeze1d)(THCState *state, THCTensor *self, THCTensor *src, int
516516
}
517517
}
518518

519+
void THCTensor_(unsqueeze1d)(THCState *state, THCTensor *self, THCTensor *src, int dimension)
520+
{
521+
int d;
522+
523+
if(!src)
524+
src = self;
525+
526+
THArgCheck((dimension >= 0) && (dimension <= src->nDimension), 3, "dimension out of range");
527+
THArgCheck(src->nDimension > 0, 3, "cannot unsqueeze empty tensor");
528+
529+
THCTensor_(set)(state, self, src);
530+
531+
self->size = (long*)THRealloc(self->size, sizeof(long)*(self->nDimension+1));
532+
self->stride = (long*)THRealloc(self->stride, sizeof(long)*(self->nDimension+1));
533+
self->nDimension++;
534+
for (d = self->nDimension-1; d > dimension; d--) {
535+
self->size[d] = self->size[d-1];
536+
self->stride[d] = self->stride[d-1];
537+
}
538+
if (dimension+1 < self->nDimension) {
539+
self->stride[dimension] = self->size[dimension+1] * self->stride[dimension+1];
540+
} else {
541+
self->stride[dimension] = 1;
542+
}
543+
self->size[dimension] = 1;
544+
}
545+
519546
int THCTensor_(isContiguous)(THCState *state, const THCTensor *self)
520547
{
521548
long z = 1;

generic/THCTensor.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -101,6 +101,7 @@ THC_API void THCTensor_(unfold)(THCState *state, THCTensor *self, THCTensor *src
101101

102102
THC_API void THCTensor_(squeeze)(THCState *state, THCTensor *self, THCTensor *src);
103103
THC_API void THCTensor_(squeeze1d)(THCState *state, THCTensor *self, THCTensor *src, int dimension_);
104+
THC_API void THCTensor_(unsqueeze1d)(THCState *state, THCTensor *self, THCTensor *src, int dimension_);
104105

105106
THC_API int THCTensor_(isContiguous)(THCState *state, const THCTensor *self);
106107
THC_API int THCTensor_(isSameSizeAs)(THCState *state, const THCTensor *self, const THCTensor *src);

0 commit comments

Comments
 (0)