Skip to content

Commit a4edaec

Browse files
committed
Merge commit 'aeb7a72620be47c0e6a8928a9cb6df49c06902a0'
2 parents 92481b5 + aeb7a72 commit a4edaec

File tree

5 files changed

+119
-68
lines changed

5 files changed

+119
-68
lines changed

torch/lib/THC/CMakeLists.txt

Lines changed: 34 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ CMAKE_POLICY(VERSION 2.8)
33

44
SET(CMAKE_MODULE_PATH ${CMAKE_CURRENT_SOURCE_DIR}/cmake ${CMAKE_MODULE_PATH})
55

6+
SET(CUDA_ATTACH_VS_BUILD_RULE_TO_CUDA_FILE OFF)
67
OPTION(NDEBUG "disable asserts (WARNING: this may result in invalid memory accesses)")
78
IF(NOT NDEBUG)
89
MESSAGE(STATUS "Removing -DNDEBUG from compile flags")
@@ -59,6 +60,10 @@ ENDIF()
5960
INCLUDE_DIRECTORIES(${CUDA_INCLUDE_DIRS})
6061
INCLUDE_DIRECTORIES("${CUDA_SDK_ROOT_DIR}/common/inc")
6162

63+
IF ("$ENV{STATIC_TH}" STREQUAL "YES")
64+
LIST(APPEND CUDA_NVCC_FLAGS "-Xcompiler -fPIC")
65+
ENDIF()
66+
6267
IF(MAGMA_FOUND)
6368
INCLUDE_DIRECTORIES(${MAGMA_INCLUDE_DIR})
6469
SET(CMAKE_REQUIRED_INCLUDES "${MAGMA_INCLUDE_DIR};${CUDA_INCLUDE_DIRS}")
@@ -130,9 +135,9 @@ IF(NOT THC_INSTALL_BIN_SUBDIR
130135
SET(THC_INSTALL_CMAKE_SUBDIR ${Torch_INSTALL_CMAKE_SUBDIR})
131136
ELSE(Torch_INSTALL_BIN_SUBDIR)
132137
# not installing in a Torch context, so Torch_INSTALL_BIN_SUBDIR is not available
133-
SET(THC_INSTALL_BIN_SUBDIR "bin" CACHE PATH "THC install binary subdirectory")
134-
SET(THC_INSTALL_LIB_SUBDIR "lib" CACHE PATH "THC install library subdirectory")
135-
SET(THC_INSTALL_INCLUDE_SUBDIR "include" CACHE PATH "THC install include subdirectory")
138+
SET(THC_INSTALL_BIN_SUBDIR "bin" CACHE PATH "THC install binary subdirectory")
139+
SET(THC_INSTALL_LIB_SUBDIR "lib" CACHE PATH "THC install library subdirectory")
140+
SET(THC_INSTALL_INCLUDE_SUBDIR "include" CACHE PATH "THC install include subdirectory")
136141
SET(THC_INSTALL_CMAKE_SUBDIR "share/cmake/THC" CACHE PATH "THC install cmake subdirectory")
137142
ENDIF(Torch_INSTALL_BIN_SUBDIR)
138143

@@ -208,28 +213,33 @@ ELSE(CUDA_HAS_FP16 OR NOT ${CUDA_VERSION} LESS 7.5)
208213
ENDIF(CUDA_HAS_FP16 OR NOT ${CUDA_VERSION} LESS 7.5)
209214

210215
MESSAGE(STATUS "CUDA_NVCC_FLAGS: ${CUDA_NVCC_FLAGS}")
216+
IF ("$ENV{STATIC_TH}" STREQUAL "YES")
217+
CUDA_ADD_LIBRARY(THC STATIC ${src} ${src-cuda})
218+
SET_TARGET_PROPERTIES(THC PROPERTIES COMPILE_FLAGS "-fPIC")
219+
ELSE()
220+
CUDA_ADD_LIBRARY(THC SHARED ${src} ${src-cuda})
221+
CUDA_ADD_CUBLAS_TO_TARGET(THC)
222+
TARGET_LINK_LIBRARIES(THC ${TH_LIBRARIES} ${CUDA_curand_LIBRARY})
223+
224+
IF(USE_MAGMA)
225+
TARGET_LINK_LIBRARIES(THC ${MAGMA_LIBRARIES} ${CUDA_cusparse_LIBRARY})
226+
ENDIF(USE_MAGMA)
227+
228+
IF(NOT THC_SO_VERSION)
229+
SET(THC_SO_VERSION 0)
230+
ENDIF(NOT THC_SO_VERSION)
231+
MESSAGE(STATUS "THC_SO_VERSION: ${THC_SO_VERSION}")
232+
SET_TARGET_PROPERTIES(THC PROPERTIES
233+
VERSION ${THC_SO_VERSION}
234+
SOVERSION ${THC_SO_VERSION})
235+
236+
237+
INSTALL(TARGETS THC
238+
RUNTIME DESTINATION "${THC_INSTALL_BIN_SUBDIR}"
239+
LIBRARY DESTINATION "${THC_INSTALL_LIB_SUBDIR}"
240+
ARCHIVE DESTINATION "${THC_INSTALL_LIB_SUBDIR}")
241+
ENDIF()
211242

212-
CUDA_ADD_LIBRARY(THC SHARED ${src} ${src-cuda})
213-
CUDA_ADD_CUBLAS_TO_TARGET(THC)
214-
TARGET_LINK_LIBRARIES(THC ${TH_LIBRARIES} ${CUDA_curand_LIBRARY})
215-
216-
IF(USE_MAGMA)
217-
TARGET_LINK_LIBRARIES(THC ${MAGMA_LIBRARIES} ${CUDA_cusparse_LIBRARY})
218-
ENDIF(USE_MAGMA)
219-
220-
IF(NOT THC_SO_VERSION)
221-
SET(THC_SO_VERSION 0)
222-
ENDIF(NOT THC_SO_VERSION)
223-
MESSAGE(STATUS "THC_SO_VERSION: ${THC_SO_VERSION}")
224-
SET_TARGET_PROPERTIES(THC PROPERTIES
225-
VERSION ${THC_SO_VERSION}
226-
SOVERSION ${THC_SO_VERSION})
227-
228-
229-
INSTALL(TARGETS THC
230-
RUNTIME DESTINATION "${THC_INSTALL_BIN_SUBDIR}"
231-
LIBRARY DESTINATION "${THC_INSTALL_LIB_SUBDIR}"
232-
ARCHIVE DESTINATION "${THC_INSTALL_LIB_SUBDIR}")
233243

234244
INSTALL(FILES
235245
THC.h

torch/lib/THC/generic/THCTensor.c

Lines changed: 76 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@ void THCTensor_(clearFlag)(THCState *state, THCTensor *self, const char flag)
6565
/**** creation methods ****/
6666

6767
static void THCTensor_(rawInit)(THCState *state, THCTensor *self);
68-
static void THCTensor_(rawSet)(THCState *state, THCTensor *self, THCStorage *storage, ptrdiff_t storageOffset, int nDimension, long *size, long *stride);
6968

7069

7170
/* Empty init */
@@ -81,13 +80,13 @@ THCTensor *THCTensor_(newWithTensor)(THCState *state, THCTensor *tensor)
8180
{
8281
THCTensor *self = (THCTensor*)THAlloc(sizeof(THCTensor));
8382
THCTensor_(rawInit)(state, self);
84-
THCTensor_(rawSet)(state,
85-
self,
86-
tensor->storage,
87-
tensor->storageOffset,
88-
tensor->nDimension,
89-
tensor->size,
90-
tensor->stride);
83+
THCTensor_(setStorageNd)(state,
84+
self,
85+
tensor->storage,
86+
tensor->storageOffset,
87+
tensor->nDimension,
88+
tensor->size,
89+
tensor->stride);
9190
return self;
9291
}
9392

@@ -99,13 +98,13 @@ THCTensor *THCTensor_(newWithStorage)(THCState *state, THCStorage *storage, ptrd
9998
THArgCheck(size->size == stride->size, 4, "inconsistent size");
10099

101100
THCTensor_(rawInit)(state, self);
102-
THCTensor_(rawSet)(state,
103-
self,
104-
storage,
105-
storageOffset,
106-
(size ? size->size : (stride ? stride->size : 0)),
107-
(size ? size->data : NULL),
108-
(stride ? stride->data : NULL));
101+
THCTensor_(setStorageNd)(state,
102+
self,
103+
storage,
104+
storageOffset,
105+
(size ? size->size : (stride ? stride->size : 0)),
106+
(size ? size->data : NULL),
107+
(stride ? stride->data : NULL));
109108

110109
return self;
111110
}
@@ -141,7 +140,7 @@ THCTensor *THCTensor_(newWithStorage4d)(THCState *state, THCStorage *storage, pt
141140

142141
THCTensor *self = (THCTensor*)THAlloc(sizeof(THCTensor));
143142
THCTensor_(rawInit)(state, self);
144-
THCTensor_(rawSet)(state, self, storage, storageOffset, 4, size, stride);
143+
THCTensor_(setStorageNd)(state, self, storage, storageOffset, 4, size, stride);
145144

146145
return self;
147146
}
@@ -172,7 +171,7 @@ THCTensor *THCTensor_(newWithSize4d)(THCState *state, long size0, long size1, lo
172171

173172
THCTensor *self = (THCTensor*)THAlloc(sizeof(THCTensor));
174173
THCTensor_(rawInit)(state, self);
175-
THCTensor_(rawResize)(state, self, 4, size, NULL);
174+
THCTensor_(resizeNd)(state, self, 4, size, NULL);
176175

177176
return self;
178177
}
@@ -224,14 +223,25 @@ THCTensor *THCTensor_(newUnfold)(THCState *state, THCTensor *tensor, int dimensi
224223
return self;
225224
}
226225

226+
THCTensor *THCTensor_(newView)(THCState *state, THCTensor *tensor, THLongStorage *size)
227+
{
228+
THArgCheck(THCTensor_(isContiguous)(state, tensor), 2, "input is not contiguous");
229+
ptrdiff_t numel = THCTensor_(nElement)(state, tensor);
230+
THCTensor *self = THCTensor_(new)(state);
231+
THLongStorage *inferred_size = THLongStorage_newInferSize(size, numel);
232+
THCTensor_(setStorage)(state, self, tensor->storage, tensor->storageOffset, inferred_size, NULL);
233+
THLongStorage_free(inferred_size);
234+
return self;
235+
}
236+
227237
/* Resize */
228238
void THCTensor_(resize)(THCState *state, THCTensor *self, THLongStorage *size, THLongStorage *stride)
229239
{
230240
THArgCheck(size != NULL, 2, "invalid size");
231241
if(stride)
232242
THArgCheck(stride->size == size->size, 3, "invalid stride");
233243

234-
THCTensor_(rawResize)(state, self, size->size, size->data, (stride ? stride->data : NULL));
244+
THCTensor_(resizeNd)(state, self, size->size, size->data, (stride ? stride->data : NULL));
235245
}
236246

237247
void THCTensor_(resizeAs)(THCState *state, THCTensor *self, THCTensor *src)
@@ -252,7 +262,7 @@ void THCTensor_(resizeAs)(THCState *state, THCTensor *self, THCTensor *src)
252262
}
253263

254264
if(!isSame)
255-
THCTensor_(rawResize)(state, self, src->nDimension, src->size, NULL);
265+
THCTensor_(resizeNd)(state, self, src->nDimension, src->size, NULL);
256266
}
257267

258268
void THCTensor_(resize1d)(THCState *state, THCTensor *tensor, long size0)
@@ -274,40 +284,40 @@ void THCTensor_(resize4d)(THCState *state, THCTensor *self, long size0, long siz
274284
{
275285
long size[4] = {size0, size1, size2, size3};
276286

277-
THCTensor_(rawResize)(state, self, 4, size, NULL);
287+
THCTensor_(resizeNd)(state, self, 4, size, NULL);
278288
}
279289

280290
void THCTensor_(resize5d)(THCState *state, THCTensor *self, long size0, long size1, long size2, long size3, long size4)
281291
{
282292
long size[5] = {size0, size1, size2, size3, size4};
283293

284-
THCTensor_(rawResize)(state, self, 5, size, NULL);
294+
THCTensor_(resizeNd)(state, self, 5, size, NULL);
285295
}
286296

287297
void THCTensor_(set)(THCState *state, THCTensor *self, THCTensor *src)
288298
{
289299
if(self != src)
290-
THCTensor_(rawSet)(state,
291-
self,
292-
src->storage,
293-
src->storageOffset,
294-
src->nDimension,
295-
src->size,
296-
src->stride);
300+
THCTensor_(setStorageNd)(state,
301+
self,
302+
src->storage,
303+
src->storageOffset,
304+
src->nDimension,
305+
src->size,
306+
src->stride);
297307
}
298308

299309
void THCTensor_(setStorage)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_, THLongStorage *size_, THLongStorage *stride_)
300310
{
301311
if(size_ && stride_)
302312
THArgCheck(size_->size == stride_->size, 5, "inconsistent size/stride sizes");
303313

304-
THCTensor_(rawSet)(state,
305-
self,
306-
storage_,
307-
storageOffset_,
308-
(size_ ? size_->size : (stride_ ? stride_->size : 0)),
309-
(size_ ? size_->data : NULL),
310-
(stride_ ? stride_->data : NULL));
314+
THCTensor_(setStorageNd)(state,
315+
self,
316+
storage_,
317+
storageOffset_,
318+
(size_ ? size_->size : (stride_ ? stride_->size : 0)),
319+
(size_ ? size_->data : NULL),
320+
(stride_ ? stride_->data : NULL));
311321
}
312322

313323
void THCTensor_(setStorage1d)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_,
@@ -353,7 +363,7 @@ void THCTensor_(setStorage4d)(THCState *state, THCTensor *self, THCStorage *stor
353363
long size[4] = {size0_, size1_, size2_, size3_};
354364
long stride[4] = {stride0_, stride1_, stride2_, stride3_};
355365

356-
THCTensor_(rawSet)(state, self, storage_, storageOffset_, 4, size, stride);
366+
THCTensor_(setStorageNd)(state, self, storage_, storageOffset_, 4, size, stride);
357367
}
358368

359369

@@ -517,6 +527,33 @@ void THCTensor_(squeeze1d)(THCState *state, THCTensor *self, THCTensor *src, int
517527
}
518528
}
519529

530+
void THCTensor_(unsqueeze1d)(THCState *state, THCTensor *self, THCTensor *src, int dimension)
531+
{
532+
int d;
533+
534+
if(!src)
535+
src = self;
536+
537+
THArgCheck((dimension >= 0) && (dimension <= src->nDimension), 3, "dimension out of range");
538+
THArgCheck(src->nDimension > 0, 3, "cannot unsqueeze empty tensor");
539+
540+
THCTensor_(set)(state, self, src);
541+
542+
self->size = (long*)THRealloc(self->size, sizeof(long)*(self->nDimension+1));
543+
self->stride = (long*)THRealloc(self->stride, sizeof(long)*(self->nDimension+1));
544+
self->nDimension++;
545+
for (d = self->nDimension-1; d > dimension; d--) {
546+
self->size[d] = self->size[d-1];
547+
self->stride[d] = self->stride[d-1];
548+
}
549+
if (dimension+1 < self->nDimension) {
550+
self->stride[dimension] = self->size[dimension+1] * self->stride[dimension+1];
551+
} else {
552+
self->stride[dimension] = 1;
553+
}
554+
self->size[dimension] = 1;
555+
}
556+
520557
int THCTensor_(isContiguous)(THCState *state, const THCTensor *self)
521558
{
522559
long z = 1;
@@ -637,7 +674,7 @@ static void THCTensor_(rawInit)(THCState *state, THCTensor *self)
637674
self->flag = TH_TENSOR_REFCOUNTED;
638675
}
639676

640-
static void THCTensor_(rawSet)(THCState *state, THCTensor *self, THCStorage *storage, ptrdiff_t storageOffset, int nDimension, long *size, long *stride)
677+
void THCTensor_(setStorageNd)(THCState *state, THCTensor *self, THCStorage *storage, ptrdiff_t storageOffset, int nDimension, long *size, long *stride)
641678
{
642679
/* storage */
643680
if(self->storage != storage)
@@ -660,10 +697,10 @@ static void THCTensor_(rawSet)(THCState *state, THCTensor *self, THCStorage *sto
660697
self->storageOffset = storageOffset;
661698

662699
/* size and stride */
663-
THCTensor_(rawResize)(state, self, nDimension, size, stride);
700+
THCTensor_(resizeNd)(state, self, nDimension, size, stride);
664701
}
665702

666-
void THCTensor_(rawResize)(THCState *state, THCTensor *self, int nDimension, long *size, long *stride)
703+
void THCTensor_(resizeNd)(THCState *state, THCTensor *self, int nDimension, long *size, long *stride)
667704
{
668705
int d;
669706
int nDimension_;

torch/lib/THC/generic/THCTensor.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,8 @@ THC_API THCTensor *THCTensor_(newSelect)(THCState *state, THCTensor *tensor, int
6666
THC_API THCTensor *THCTensor_(newNarrow)(THCState *state, THCTensor *tensor, int dimension_, long firstIndex_, long size_);
6767
THC_API THCTensor *THCTensor_(newTranspose)(THCState *state, THCTensor *tensor, int dimension1_, int dimension2_);
6868
THC_API THCTensor *THCTensor_(newUnfold)(THCState *state, THCTensor *tensor, int dimension_, long size_, long step_);
69+
THC_API THCTensor *THCTensor_(newView)(THCState *state, THCTensor *tensor, THLongStorage *size);
70+
6971

7072
THC_API void THCTensor_(resize)(THCState *state, THCTensor *tensor, THLongStorage *size, THLongStorage *stride);
7173
THC_API void THCTensor_(resizeAs)(THCState *state, THCTensor *tensor, THCTensor *src);
@@ -74,10 +76,11 @@ THC_API void THCTensor_(resize2d)(THCState *state, THCTensor *tensor, long size0
7476
THC_API void THCTensor_(resize3d)(THCState *state, THCTensor *tensor, long size0_, long size1_, long size2_);
7577
THC_API void THCTensor_(resize4d)(THCState *state, THCTensor *tensor, long size0_, long size1_, long size2_, long size3_);
7678
THC_API void THCTensor_(resize5d)(THCState *state, THCTensor *tensor, long size0_, long size1_, long size2_, long size3_, long size4_);
77-
THC_API void THCTensor_(rawResize)(THCState *state, THCTensor *self, int nDimension, long *size, long *stride);
79+
THC_API void THCTensor_(resizeNd)(THCState *state, THCTensor *tensor, int nDimension, long *size, long *stride);
7880

7981
THC_API void THCTensor_(set)(THCState *state, THCTensor *self, THCTensor *src);
8082
THC_API void THCTensor_(setStorage)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_, THLongStorage *size_, THLongStorage *stride_);
83+
THC_API void THCTensor_(setStorageNd)(THCState *state, THCTensor *self, THCStorage *storage, ptrdiff_t storageOffset, int nDimension, long *size, long *stride);
8184
THC_API void THCTensor_(setStorage1d)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_,
8285
long size0_, long stride0_);
8386
THC_API void THCTensor_(setStorage2d)(THCState *state, THCTensor *self, THCStorage *storage_, ptrdiff_t storageOffset_,
@@ -100,6 +103,7 @@ THC_API void THCTensor_(unfold)(THCState *state, THCTensor *self, THCTensor *src
100103

101104
THC_API void THCTensor_(squeeze)(THCState *state, THCTensor *self, THCTensor *src);
102105
THC_API void THCTensor_(squeeze1d)(THCState *state, THCTensor *self, THCTensor *src, int dimension_);
106+
THC_API void THCTensor_(unsqueeze1d)(THCState *state, THCTensor *self, THCTensor *src, int dimension_);
103107

104108
THC_API int THCTensor_(isContiguous)(THCState *state, const THCTensor *self);
105109
THC_API int THCTensor_(isSameSizeAs)(THCState *state, const THCTensor *self, const THCTensor *src);

torch/lib/THC/generic/THCTensorMathBlas.cu

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ __global__ void createBatchGemmBuffer(const real** buffer, real* data,
424424
const long idx = blockIdx.x * blockDim.x + threadIdx.x;
425425
if (idx < num_batches) {
426426
buffer[idx] = data + idx * stride;
427-
}
427+
}
428428
}
429429

430430
THC_API void

torch/lib/THC/generic/THCTensorMathMagma.cu

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ static void THCTensor_(copyArray1d)(THCState *state, THCTensor *self, real *src,
1010
{
1111
long size[1] = { k };
1212
long stride[1] = { 1 };
13-
THCTensor_(rawResize)(state, self, 1, size, stride);
13+
THCTensor_(resizeNd)(state, self, 1, size, stride);
1414
size_t len = k * sizeof(real);
1515
THCudaCheck(cudaMemcpy(self->storage->data + self->storageOffset, src, len, cudaMemcpyHostToDevice));
1616
}
@@ -19,7 +19,7 @@ static void THCTensor_(copyArray2d)(THCState *state, THCTensor *self, real *src,
1919
{
2020
long size[2] = { m, n };
2121
long stride[2] = { 1, m };
22-
THCTensor_(rawResize)(state, self, 2, size, stride);
22+
THCTensor_(resizeNd)(state, self, 2, size, stride);
2323
size_t len = m * n * sizeof(real);
2424
THCudaCheck(cudaMemcpy(self->storage->data + self->storageOffset, src, len, cudaMemcpyHostToDevice));
2525
}
@@ -54,7 +54,7 @@ static THCTensor* THCTensor_(newColumnMajor)(THCState *state, THCTensor *self, T
5454
long size[2] = { src->size[0], src->size[1] };
5555
long stride[2] = { 1, src->size[0] };
5656

57-
THCTensor_(rawResize)(state, self, 2, size, stride);
57+
THCTensor_(resizeNd)(state, self, 2, size, stride);
5858
THCTensor_(copy)(state, self, src);
5959
return self;
6060
}

0 commit comments

Comments
 (0)