#include "THCSTensor.h"
#include "THCApply.cuh"
+#include "THCTensorSort.cuh"
#include "THCTensorMathPointwise.cuh"
#include "stdio.h"

+const int WARP_SIZE = 32;
+
template <typename IndexType, typename Real, typename Op>
__device__ void applyOp2(
    Op op, IndexType blockSize,
@@ -235,6 +238,81 @@ __global__ void THCSTensor_indexSparseIntersectionKernel( |
  *resultNnz = r_i;
}

+// template <typename Dtype, typename Acctype>
+// __global__ void THCSTensor_coalesceValuesKernel_gridStrided(
+//   long *segment_offsets, long *value_indices,
+//   Dtype *values, Dtype *newValues,
+//   long nnz, long newNnz, long stride) {
+//
+//   long chunksPerSeg = THCCeilDiv(stride, (long) blockDim.x);
+//   long numChunks = newNnz * chunksPerSeg;
+//   long chunkOffset = blockIdx.x * blockDim.y + threadIdx.y;
+//   long chunkStride = gridDim.x * blockDim.y;
+//
+//   for (long chunk = chunkOffset; chunk < numChunks; chunk += chunkStride) {
+//     long featureDim = (chunk % chunksPerSeg) * blockDim.x + threadIdx.x;
+//     if (featureDim < stride) {
+//       auto valFeat = values + featureDim;
+//       long seg = chunk / chunksPerSeg;
+//       auto begin = segment_offsets[seg];
+//       auto end = (seg < newNnz - 1) ? segment_offsets[seg + 1] : nnz;
+//       Acctype valSum = ScalarConvert<float, Acctype>::to(0);
+//       for (long valIdx = begin; valIdx < end; valIdx++) {
+//         const long valRow = value_indices[valIdx] * stride;
+//         valSum += ScalarConvert<Dtype, Acctype>::to(valFeat[valRow]);
+//       }
+//       newValues[seg * stride + featureDim] = ScalarConvert<Acctype, Dtype>::to(valSum);
+//     }
+//   }
+// }
+
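+// Coalesce kernel: for each output segment `seg`, the rows of `values`
+// selected by value_indices[segment_offsets[seg] .. next offset) are summed
+// in Acctype precision into row `seg` of `newValues`. Each block covers 4
+// segments along threadIdx.y, and each thread accumulates SZ feature columns
+// spaced WARP_SIZE apart (blockIdx.y tiles the remaining columns).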
+template <typename Dtype, typename Acctype>
+__global__ void THCSTensor_coalesceValuesKernel(
+  long *segment_offsets, long *value_indices,
+  Dtype *values, Dtype *newValues,
+  long nnz, long newNnz, long stride) {
+
+  int seg = blockIdx.x * 4 + threadIdx.y;
+
+  // Number of values processed by each thread (grain size)
+  const int SZ = 4;
+
+  if (seg < newNnz) {
+    const int newValueRow = seg * stride;
+    const int begin = segment_offsets[seg];
+    const int end = (seg < newNnz - 1) ? segment_offsets[seg + 1] : nnz;
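+    // Feature-dimension tiling: threadIdx.x picks the lane, and each
+    // blockIdx.y covers a chunk of blockDim.x * SZ contiguous columns.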
+    const int startFeature = threadIdx.x + blockIdx.y * blockDim.x * SZ;
+    Acctype tmp[SZ];
+    #pragma unroll
+    for (int ii = 0; ii < SZ; ii++) {
+      tmp[ii] = ScalarConvert<float, Acctype>::to(0);
+    }
+    for (int row = begin; row < end; row++) {
+      const int valueRow = ((int) value_indices[row]) * stride;
+
+      #pragma unroll
+      for (int ii = 0; ii < SZ; ii++)
+      {
+        int featureDim = startFeature + ii * WARP_SIZE;
+        if (featureDim < stride)
+        {
+          tmp[ii] += ScalarConvert<Dtype, Acctype>::to(values[valueRow + featureDim]);
+        }
+      }
+    }
+    #pragma unroll
+    for (int ii = 0; ii < SZ; ii++)
+    {
+      int featureDim = startFeature + ii * WARP_SIZE;
+      if (featureDim < stride)
+      {
+        newValues[newValueRow + featureDim] = ScalarConvert<Acctype, Dtype>::to(tmp[ii]);
+      }
+    }
+  }
+}
+
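+// A plausible launch configuration for the kernel above (an assumption; the
+// caller is not part of this hunk): blockDim must be (WARP_SIZE, 4) so that
+// the hard-coded factors match, e.g.
+//   const int SZ = 4;
+//   dim3 block(WARP_SIZE, SZ);
+//   dim3 grid(THCCeilDiv(newNnz, (long) SZ),
+//             THCCeilDiv(stride, (long) (WARP_SIZE * SZ)));
+//   THCSTensor_coalesceValuesKernel<Dtype, Acctype><<<grid, block, 0, stream>>>(
+//       segment_offsets, value_indices, values, newValues, nnz, newNnz, stride);
+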
#include "generic/THCSTensor.cu"
#include "THCSGenerateAllTypes.h"