Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 17 additions & 35 deletions cpp/basix/polyset.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,9 @@ void tabulate_polyset_line_derivs(xt::xtensor<double, 3>& P, std::size_t n,
const xt::xtensor<double, 2>& x)
{
assert(x.shape(0) > 0);
const std::size_t m = (n + 1);
assert(P.shape(0) == nderiv + 1);
assert(P.shape(1) == x.shape(0));
assert(P.shape(2) == m);
assert(P.shape(2) == n + 1);

std::fill(P.begin(), P.end(), 0.0);
xt::view(P, 0, xt::all(), 0) = 1.0;
Expand Down Expand Up @@ -108,11 +107,9 @@ void tabulate_polyset_triangle_derivs(xt::xtensor<double, 3>& P, std::size_t n,
auto x0 = xt::col(x, 0);
auto x1 = xt::col(x, 1);

const std::size_t m = (n + 1) * (n + 2) / 2;
const std::size_t md = (nderiv + 1) * (nderiv + 2) / 2;
assert(P.shape(0) == md);
assert(P.shape(0) == (nderiv + 1) * (nderiv + 2) / 2);
assert(P.shape(1) == x.shape(0));
assert(P.shape(2) == m);
assert(P.shape(2) == (n + 1) * (n + 2) / 2);

// f3 = ((1 - y) / 2)^2
const auto f3 = xt::square(1.0 - (x1 * 2.0 - 1.0)) * 0.25;
Expand Down Expand Up @@ -211,8 +208,9 @@ void tabulate_polyset_tetrahedron_derivs(xt::xtensor<double, 3>& P,
const xt::xtensor<double, 2>& x)
{
assert(x.shape(1) == 3);
const std::size_t m = (n + 1) * (n + 2) * (n + 3) / 6;
const std::size_t md = (nderiv + 1) * (nderiv + 2) * (nderiv + 3) / 6;
assert(P.shape(0) == (nderiv + 1) * (nderiv + 2) * (nderiv + 3) / 6);
assert(P.shape(1) == x.shape(0));
assert(P.shape(2) == (n + 1) * (n + 2) * (n + 3) / 6);

const auto x0 = xt::col(x, 0);
const auto x1 = xt::col(x, 1);
Expand All @@ -224,10 +222,6 @@ void tabulate_polyset_tetrahedron_derivs(xt::xtensor<double, 3>& P,
const auto f5 = f4 * f4;

// Traverse derivatives in increasing order
assert(P.shape(0) == md);
assert(P.shape(1) == x.shape(0));
assert(P.shape(2) == m);

std::fill(P.begin(), P.end(), 0.0);
xt::view(P, idx(0, 0, 0), xt::all(), 0) = 1.0;

Expand Down Expand Up @@ -437,8 +431,9 @@ void tabulate_polyset_pyramid_derivs(xt::xtensor<double, 3>& P, std::size_t n,
const xt::xtensor<double, 2>& x)
{
assert(x.shape(1) == 3);
const std::size_t m = (n + 1) * (n + 2) * (2 * n + 3) / 6;
const std::size_t md = (nderiv + 1) * (nderiv + 2) * (nderiv + 3) / 6;
assert(P.shape(0) == (nderiv + 1) * (nderiv + 2) * (nderiv + 3) / 6);
assert(P.shape(1) == x.shape(0));
assert(P.shape(2) == (n + 1) * (n + 2) * (2 * n + 3) / 6);

// Indexing for pyramidal basis functions
auto pyr_idx = [n](std::size_t p, std::size_t q, std::size_t r) -> std::size_t
Expand All @@ -456,10 +451,6 @@ void tabulate_polyset_pyramid_derivs(xt::xtensor<double, 3>& P, std::size_t n,
const auto f2 = 0.25 * xt::square(1.0 - (x2 * 2.0 - 1.0));

// Traverse derivatives in increasing order
assert(P.shape(0) == md);
assert(P.shape(1) == x.shape(0));
assert(P.shape(2) == m);

std::fill(P.begin(), P.end(), 0.0);
xt::view(P, idx(0, 0, 0), xt::all(), pyr_idx(0, 0, 0)) = 1.0;

Expand Down Expand Up @@ -650,8 +641,9 @@ void tabulate_polyset_quad_derivs(xt::xtensor<double, 3>& P, std::size_t n,
const xt::xtensor<double, 2>& x)
{
assert(x.shape(1) == 2);
const std::size_t m = (n + 1) * (n + 1);
const std::size_t md = (nderiv + 1) * (nderiv + 2) / 2;
assert(P.shape(0) == (nderiv + 1) * (nderiv + 2) / 2);
assert(P.shape(1) == x.shape(0));
assert(P.shape(2) == (n + 1) * (n + 1));

// Indexing for quadrilateral basis functions
auto quad_idx = [n](std::size_t px, std::size_t py) -> std::size_t
Expand All @@ -664,10 +656,6 @@ void tabulate_polyset_quad_derivs(xt::xtensor<double, 3>& P, std::size_t n,
assert(x0.shape(0) > 0);
assert(x1.shape(0) > 0);

assert(P.shape(0) == md);
assert(P.shape(1) == x.shape(0));
assert(P.shape(2) == m);

// Compute tabulation of interval for px = 0
std::fill(P.begin(), P.end(), 0.0);
xt::view(P, idx(0, 0), xt::all(), quad_idx(0, 0)) = 1.0;
Expand Down Expand Up @@ -778,8 +766,9 @@ void tabulate_polyset_hex_derivs(xt::xtensor<double, 3>& P, std::size_t n,
const xt::xtensor<double, 2>& x)
{
assert(x.shape(1) == 3);
const std::size_t m = (n + 1) * (n + 1) * (n + 1);
const std::size_t md = (nderiv + 1) * (nderiv + 2) * (nderiv + 3) / 6;
assert(P.shape(0) == (nderiv + 1) * (nderiv + 2) * (nderiv + 3) / 6);
assert(P.shape(1) == x.shape(0));
assert(P.shape(2) == (n + 1) * (n + 1) * (n + 1));

// Indexing for hexahedral basis functions
auto hex_idx
Expand All @@ -795,10 +784,6 @@ void tabulate_polyset_hex_derivs(xt::xtensor<double, 3>& P, std::size_t n,
assert(x1.shape(0) > 0);
assert(x2.shape(0) > 0);

assert(P.shape(0) == md);
assert(P.shape(1) == x.shape(0));
assert(P.shape(2) == m);

std::fill(P.begin(), P.end(), 0.0);
xt::view(P, idx(0, 0, 0), xt::all(), hex_idx(0, 0, 0)) = 1.0;

Expand Down Expand Up @@ -1010,12 +995,9 @@ void tabulate_polyset_prism_derivs(xt::xtensor<double, 3>& P, std::size_t n,
const xt::xtensor<double, 2>& x)
{
assert(x.shape(1) == 3);
const std::size_t m = (n + 1) * (n + 1) * (n + 2) / 2;
const std::size_t md = (nderiv + 1) * (nderiv + 2) * (nderiv + 3) / 6;

assert(P.shape(0) == md);
assert(P.shape(0) == (nderiv + 1) * (nderiv + 2) * (nderiv + 3) / 6);
assert(P.shape(1) == x.shape(0));
assert(P.shape(2) == m);
assert(P.shape(2) == (n + 1) * (n + 1) * (n + 2) / 2);

const auto x0 = xt::col(x, 0);
const auto x1 = xt::col(x, 1);
Expand Down