Skip to content

[FEAT] Implemented covariance calculation Closes #48 #66

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
Open
43 changes: 43 additions & 0 deletions numpower.c
Original file line number Diff line number Diff line change
Expand Up @@ -2894,6 +2894,48 @@ PHP_METHOD(NDArray, variance) {
RETURN_NDARRAY(rtn, return_value);
}

/**
* NDArray::cov
*
* @param execute_data
* @param return_value
*/
ZEND_BEGIN_ARG_INFO_EX(arginfo_ndarray_cov, 0, 0, 1)
ZEND_ARG_INFO(0, array)
ZEND_ARG_INFO(0, rowvar)
ZEND_END_ARG_INFO()
PHP_METHOD(NDArray, cov) {
NDArray *rtn = NULL;
zval *array;
bool rowvar = true;
ZEND_PARSE_PARAMETERS_START(1, 2)
Z_PARAM_ZVAL(array)
Z_PARAM_OPTIONAL
Z_PARAM_BOOL(rowvar)
ZEND_PARSE_PARAMETERS_END();
NDArray *nda = ZVAL_TO_NDARRAY(array);
if (nda == NULL) {
return;
}

if (NDArray_DEVICE(nda) == NDARRAY_DEVICE_CPU) {
rtn = NDArray_cov(nda, rowvar);
} else {
#ifdef HAVE_CUBLAS
rtn = NDArray_cov(nda, rowvar);
#else
zend_throw_error(NULL, "GPU operations unavailable. CUBLAS not detected.");
#endif
}
if (rtn == NULL) {
return;
}
if (Z_TYPE_P(array) == IS_ARRAY) {
NDArray_FREE(nda);
}
RETURN_NDARRAY(rtn, return_value);
}

/**
* NDArray::ceil
*
Expand Down Expand Up @@ -5180,6 +5222,7 @@ static const zend_function_entry class_NDArray_methods[] = {
ZEND_ME(NDArray, average, arginfo_ndarray_average, ZEND_ACC_PUBLIC | ZEND_ACC_STATIC)
ZEND_ME(NDArray, std, arginfo_ndarray_std, ZEND_ACC_PUBLIC | ZEND_ACC_STATIC)
ZEND_ME(NDArray, quantile, arginfo_ndarray_quantile, ZEND_ACC_PUBLIC | ZEND_ACC_STATIC)
ZEND_ME(NDArray, cov, arginfo_ndarray_cov, ZEND_ACC_PUBLIC | ZEND_ACC_STATIC)

// ARITHMETICS
ZEND_ME(NDArray, add, arginfo_ndarray_add, ZEND_ACC_PUBLIC | ZEND_ACC_STATIC)
Expand Down
66 changes: 66 additions & 0 deletions src/ndmath/statistics.c
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@
#include "string.h"
#include "../initializers.h"
#include "arithmetics.h"
#include "../manipulation.h"
#include "linalg.h"

// Comparison function for sorting
int compare_quantile(const void* a, const void* b) {
Expand Down Expand Up @@ -151,4 +153,68 @@ NDArray_Average(NDArray *a, NDArray *weights) {
NDArray_FREE(m_weights);
}
return rtn;
}

/**
* NDArray::cov
*
* @param a
* @return
*/
NDArray *NDArray_cov(NDArray *a, bool rowvar)
{
if (!rowvar) {
a = NDArray_Transpose(a, NULL);
}

if (a == NULL || NDArray_NUMELEMENTS(a) == 0)
{
zend_throw_error(NULL, "Input cannot be null or empty.");
return NULL;
}
if (NDArray_NDIM(a) != 2 || NDArray_SHAPE(a)[1] == 1)
{
zend_throw_error(NULL, "Input must be a 2D NDArray.");
return NULL;
}

int cols = NDArray_SHAPE(a)[0];
int rows = NDArray_SHAPE(a)[1];

int *indices_shape = emalloc(sizeof(int) * 2);
indices_shape[0] = 2;
indices_shape[1] = 1;

NDArray** indices_axis = emalloc(sizeof(NDArray*) * 2);
indices_axis[0] = NDArray_Zeros(indices_shape, 1, NDArray_TYPE(a), NDArray_DEVICE(a));
indices_axis[1] = NDArray_Zeros(indices_shape, 1, NDArray_TYPE(a), NDArray_DEVICE(a));

NDArray_FDATA(indices_axis[1])[0] = 0;
NDArray_FDATA(indices_axis[1])[1] = rows;

NDArray **centered_vectors = emalloc(sizeof(NDArray *) * cols);
for (int i = 0; i < cols; i++)
{
NDArray_FDATA(indices_axis[0])[0] = i;
NDArray_FDATA(indices_axis[0])[1] = i + 1;
NDArray *col_vector = NDArray_Slice(a, indices_axis, 2);
NDArray *centered = NDArray_Subtract_Float(col_vector, NDArray_CreateFromFloatScalar(NDArray_Sum_Float(col_vector) / NDArray_NUMELEMENTS(col_vector)));
NDArray_FREE(col_vector);
centered_vectors[i] = centered;
}
efree(indices_shape);
efree(indices_axis[0]);
efree(indices_axis[1]);
efree(indices_axis);
NDArray *centered_a = NDArray_Reshape(NDArray_ConcatenateFlat(centered_vectors, cols), NDArray_SHAPE(a), NDArray_NDIM(a));
for (int i = 0; i < cols; i++)
{
NDArray_FREE(centered_vectors[i]);
}
efree(centered_vectors);
NDArray *multiplied = NDArray_Dot(centered_a, NDArray_Transpose(centered_a, NULL));
NDArray_FREE(centered_a);
NDArray *rtn = NDArray_Divide_Float(multiplied, NDArray_CreateFromFloatScalar((float)rows - 1));
NDArray_FREE(multiplied);
return rtn;
}
1 change: 1 addition & 0 deletions src/ndmath/statistics.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ NDArray* NDArray_Quantile(NDArray *target, NDArray *q);
NDArray* NDArray_Std(NDArray *a);
NDArray* NDArray_Variance(NDArray *a);
NDArray* NDArray_Average(NDArray *a, NDArray *weights);
NDArray* NDArray_cov(NDArray *a, bool rowvar);

#endif //NUMPOWER_STATISTICS_H
127 changes: 127 additions & 0 deletions tests/math/048-ndarray-cov.phpt
Original file line number Diff line number Diff line change
@@ -0,0 +1,127 @@
--TEST--
NDArray::cov
--FILE--
<?php
$a = \NDArray::array([[3, 7, 8], [2, 4, 3]]);
print_r(\NDArray::cov($a)->toArray());
$b = \NDArray::array([[1, 2, 3, 4], [5, 4, 3, 2]]);
print_r(\NDArray::cov($b)->toArray());
$c = \NDArray::array([[1, 2, 3, 4], [5, 6, 7, 8]]);
print_r(\NDArray::cov($c)->toArray());
$d = \NDArray::array([[1, 2, 3, 4], [1, 2, 3, 4]]);
print_r(\NDArray::cov($d)->toArray());
$e = \NDArray::array([[1, 2, 3, 4]]);
print_r(\NDArray::cov($e)->toArray());
$f = \NDArray::array([[0, 0, 0, 0], [0, 0, 0, 0]]);
print_r(\NDArray::cov($f)->toArray());
$g = \NDArray::array([[3, 7, 8], [2, 4, 3]]);
print_r(\NDArray::cov($g, False)->toArray());
?>
--EXPECT--
Array
(
[0] => Array
(
[0] => 7
[1] => 2
)

[1] => Array
(
[0] => 2
[1] => 1
)

)
Array
(
[0] => Array
(
[0] => 1.6666666269302
[1] => -1.6666666269302
)

[1] => Array
(
[0] => -1.6666666269302
[1] => 1.6666666269302
)

)
Array
(
[0] => Array
(
[0] => 1.6666666269302
[1] => 1.6666666269302
)

[1] => Array
(
[0] => 1.6666666269302
[1] => 1.6666666269302
)

)
Array
(
[0] => Array
(
[0] => 1.6666666269302
[1] => 1.6666666269302
)

[1] => Array
(
[0] => 1.6666666269302
[1] => 1.6666666269302
)

)
Array
(
[0] => Array
(
[0] => 1.6666666269302
)

)
Array
(
[0] => Array
(
[0] => 0
[1] => 0
)

[1] => Array
(
[0] => 0
[1] => 0
)

)
Array
(
[0] => Array
(
[0] => 0.5
[1] => 1.5
[2] => 2.5
)

[1] => Array
(
[0] => 1.5
[1] => 4.5
[2] => 7.5
)

[2] => Array
(
[0] => 2.5
[1] => 7.5
[2] => 12.5
)

)
Loading