@@ -1503,48 +1503,50 @@ private static void matrixMultUltraSparse(MatrixBlock m1, MatrixBlock m2, Matrix
1503
1503
1504
1504
for ( int i =rl ; i <ru ; i ++ )
1505
1505
{
1506
- if ( !a .isEmpty (i ) )
1507
- {
1508
- int apos = a .pos (i );
1509
- int alen = a .size (i );
1510
- int [] aixs = a .indexes (i );
1511
- double [] avals = a .values (i );
1512
-
1513
- if ( alen ==1 && avals [apos ]==1 ) //ROW SELECTION (no aggregation)
1514
- {
1515
- int aix = aixs [apos ];
1516
- if ( rightSparse ) { //sparse right matrix (full row copy)
1517
- if ( !m2 .sparseBlock .isEmpty (aix ) ) {
1518
- ret .rlen =m ;
1519
- ret .allocateSparseRowsBlock (false ); //allocation on demand
1520
- boolean ldeep = (m2 .sparseBlock instanceof SparseBlockMCSR );
1521
- ret .sparseBlock .set (i , m2 .sparseBlock .get (aix ), ldeep );
1522
- ret .nonZeros += ret .sparseBlock .size (i );
1523
- }
1506
+ if ( a .isEmpty (i ) ) continue ;
1507
+ int apos = a .pos (i );
1508
+ int alen = a .size (i );
1509
+ int [] aixs = a .indexes (i );
1510
+ double [] avals = a .values (i );
1511
+
1512
+ if ( alen ==1 ) {
1513
+ //row selection (now aggregation) with potential scaling
1514
+ int aix = aixs [apos ];
1515
+ if ( rightSparse ) { //sparse right matrix (full row copy)
1516
+ if ( !m2 .sparseBlock .isEmpty (aix ) ) {
1517
+ ret .rlen =m ;
1518
+ ret .allocateSparseRowsBlock (false ); //allocation on demand
1519
+ boolean ldeep = (m2 .sparseBlock instanceof SparseBlockMCSR );
1520
+ ret .sparseBlock .set (i , m2 .sparseBlock .get (aix ), ldeep );
1521
+ ret .nonZeros += ret .sparseBlock .size (i );
1524
1522
}
1525
- else { //dense right matrix (append all values)
1526
- int lnnz = (int )m2 .recomputeNonZeros (aix , aix , 0 , n -1 );
1527
- if ( lnnz > 0 ) {
1528
- c .allocate (i , lnnz ); //allocate once
1529
- for ( int j =0 ; j <n ; j ++ )
1530
- c .append (i , j , m2 .quickGetValue (aix , j ));
1531
- ret .nonZeros += lnnz ;
1532
- }
1523
+ }
1524
+ else { //dense right matrix (append all values)
1525
+ int lnnz = (int )m2 .recomputeNonZeros (aix , aix , 0 , n -1 );
1526
+ if ( lnnz > 0 ) {
1527
+ c .allocate (i , lnnz ); //allocate once
1528
+ double [] bvals = m2 .getDenseBlock ().values (aix );
1529
+ for ( int j =0 , bix =m2 .getDenseBlock ().pos (aix ); j <n ; j ++ )
1530
+ c .append (i , j , bvals [bix +j ]);
1531
+ ret .nonZeros += lnnz ;
1533
1532
}
1534
1533
}
1535
- else //GENERAL CASE
1534
+ //optional scaling if not pure selection
1535
+ if ( avals [apos ] != 1 )
1536
+ vectMultiplyInPlace (avals [apos ], c .values (i ), c .pos (i ), c .size (i ));
1537
+ }
1538
+ else //GENERAL CASE
1539
+ {
1540
+ for ( int k =apos ; k <apos +alen ; k ++ )
1536
1541
{
1537
- for ( int k =apos ; k <apos +alen ; k ++ )
1542
+ double aval = avals [k ];
1543
+ int aix = aixs [k ];
1544
+ for ( int j =0 ; j <n ; j ++ )
1538
1545
{
1539
- double aval = avals [k ];
1540
- int aix = aixs [k ];
1541
- for ( int j =0 ; j <n ; j ++ )
1542
- {
1543
- double cval = ret .quickGetValue (i , j );
1544
- double cvald = aval *m2 .quickGetValue (aix , j );
1545
- if ( cvald != 0 )
1546
- ret .quickSetValue (i , j , cval +cvald );
1547
- }
1546
+ double cval = ret .quickGetValue (i , j );
1547
+ double cvald = aval *m2 .quickGetValue (aix , j );
1548
+ if ( cvald != 0 )
1549
+ ret .quickSetValue (i , j , cval +cvald );
1548
1550
}
1549
1551
}
1550
1552
}
@@ -3209,6 +3211,20 @@ public static void vectMultiplyWrite( final double aval, double[] b, double[] c,
3209
3211
c [ ci +7 ] = aval * b [ bi +7 ];
3210
3212
}
3211
3213
}
3214
+
3215
+ public static void vectMultiplyInPlace ( final double aval , double [] c , int ci , final int len ) {
3216
+ final int bn = len %8 ;
3217
+ //rest, not aligned to 8-blocks
3218
+ for ( int j = 0 ; j < bn ; j ++, ci ++)
3219
+ c [ ci ] *= aval ;
3220
+ //unrolled 8-block (for better instruction-level parallelism)
3221
+ for ( int j = bn ; j < len ; j +=8 , ci +=8 ) {
3222
+ c [ ci +0 ] *= aval ; c [ ci +1 ] *= aval ;
3223
+ c [ ci +2 ] *= aval ; c [ ci +3 ] *= aval ;
3224
+ c [ ci +4 ] *= aval ; c [ ci +5 ] *= aval ;
3225
+ c [ ci +6 ] *= aval ; c [ ci +7 ] *= aval ;
3226
+ }
3227
+ }
3212
3228
3213
3229
//note: public for use by codegen for consistency
3214
3230
public static void vectMultiplyWrite ( double [] a , double [] b , double [] c , int ai , int bi , int ci , final int len )
0 commit comments