Skip to content

Commit 43d26ff

Browse files
authored
Off-grid sensors. Band-limited interpolation and tests. (#192)
* Add BLI tests * Add example for off grid sensors. * Add off-grid sensors.
1 parent fc778ec commit 43d26ff

File tree

6 files changed

+564
-54
lines changed

6 files changed

+564
-54
lines changed

.gitignore

Lines changed: 37 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -1,36 +1,37 @@
1-
# Project
2-
*.ipynb_checkpoints
3-
*.pytest_cache
4-
*.vscode
5-
*.egg-info
6-
*__pycache__
7-
.venv
8-
site/*
9-
build/*
10-
11-
# poetry
12-
poetry.lock
13-
14-
# Testing
15-
.coverage
16-
coverage.xml
17-
18-
# Pages to be generated
19-
docs/test_reports/test_accuracy.png
20-
docs/test_reports/test_data.txt
21-
docs/test_reports/test_report.md
22-
23-
# Data
24-
*.mat
25-
docs/notebooks/others/ct.jpg
26-
27-
#k-wave outputs
28-
docs/benchmarks/kwave/*.png
29-
30-
31-
# Misc
32-
*.DS_Store
33-
*.env.local
34-
*.env.development.local
35-
*.env.test.local
36-
*.env.production.local
1+
.idea
2+
# Project
3+
*.ipynb_checkpoints
4+
*.pytest_cache
5+
*.vscode
6+
*.egg-info
7+
*__pycache__
8+
.venv
9+
site/*
10+
build/*
11+
12+
# poetry
13+
poetry.lock
14+
15+
# Testing
16+
.coverage
17+
coverage.xml
18+
19+
# Pages to be generated
20+
docs/test_reports/test_accuracy.png
21+
docs/test_reports/test_data.txt
22+
docs/test_reports/test_report.md
23+
24+
# Data
25+
*.mat
26+
docs/notebooks/others/ct.jpg
27+
28+
#k-wave outputs
29+
docs/benchmarks/kwave/*.png
30+
31+
32+
# Misc
33+
*.DS_Store
34+
*.env.local
35+
*.env.development.local
36+
*.env.test.local
37+
*.env.production.local

CHANGELOG.md

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ All notable changes to this project will be documented in this file.
44
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
55

66
## [Unreleased]
7+
### Added
8+
- Added off grid sensors [@tomelse]
79

810
## [0.1.2] - 2023-06-22
911
### Changed
@@ -64,4 +66,3 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
6466
[0.0.3]: https://github.com/ucl-bug/jwave/compare/0.0.2...0.0.3
6567
[0.0.2]: https://github.com/ucl-bug/jwave/compare/0.0.1...0.0.2
6668
[0.0.1]: https://github.com/ucl-bug/jwave/releases/tag/0.0.1
67-

docs/notebooks/ivp/off_grid_sensors.ipynb

Lines changed: 319 additions & 0 deletions
Large diffs are not rendered by default.

jwave/geometry.py

Lines changed: 123 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import math
1717
from dataclasses import dataclass
1818
from typing import List, Tuple, Union
19+
from numpy.typing import ArrayLike
1920

2021
import numpy as np
2122
from jax import numpy as jnp
@@ -35,7 +36,7 @@ class Medium:
3536
3637
Attributes:
3738
domain (Domain): domain of the medium
38-
sound_speed (jnp.darray): speed of sound map, can be a scalar
39+
sound_speed (jnp.ndarray): speed of sound map, can be a scalar
3940
density (jnp.ndarray): density map, can be a scalar
4041
attenuation (jnp.ndarray): attenuation map, can be a scalar
4142
pml_size (int): size of the PML layer in grid-points
@@ -148,8 +149,8 @@ class MediumType(Medium):
148149
@type_of.dispatch
149150
def type_of(m: Medium):
150151
return MediumType[type(m.sound_speed),
151-
type(m.density),
152-
type(m.attenuation)]
152+
type(m.density),
153+
type(m.attenuation)]
153154

154155

155156
MediumAllScalars = MediumType[object, object, object]
@@ -187,11 +188,11 @@ def _unit_fibonacci_sphere(
187188
coordinates of the points on the sphere.
188189
"""
189190
points = []
190-
phi = math.pi * (3.0 - math.sqrt(5.0)) # golden angle in radians
191+
phi = math.pi * (3.0 - math.sqrt(5.0)) # golden angle in radians
191192
for i in range(samples):
192-
y = 1 - (i / float(samples - 1)) * 2 # y goes from 1 to -1
193-
radius = math.sqrt(1 - y * y) # radius at y
194-
theta = phi * i # golden angle increment
193+
y = 1 - (i / float(samples - 1)) * 2 # y goes from 1 to -1
194+
radius = math.sqrt(1 - y * y) # radius at y
195+
theta = phi * i # golden angle increment
195196
x = math.cos(theta) * radius
196197
z = math.sin(theta) * radius
197198
points.append((x, y, z))
@@ -228,15 +229,15 @@ def _fibonacci_sphere(
228229

229230
def _circ_mask(N, radius, centre):
230231
x, y = np.mgrid[0:N[0], 0:N[1]]
231-
dist_from_centre = np.sqrt((x - centre[0])**2 + (y - centre[1])**2)
232+
dist_from_centre = np.sqrt((x - centre[0]) ** 2 + (y - centre[1]) ** 2)
232233
mask = (dist_from_centre < radius).astype(int)
233234
return mask
234235

235236

236237
def _sphere_mask(N, radius, centre):
237238
x, y, z = np.mgrid[0:N[0], 0:N[1], 0:N[2]]
238-
dist_from_centre = np.sqrt((x - centre[0])**2 + (y - centre[1])**2 +
239-
(z - centre[2])**2)
239+
dist_from_centre = np.sqrt((x - centre[0]) ** 2 + (y - centre[1]) ** 2 +
240+
(z - centre[2]) ** 2)
240241
mask = (dist_from_centre < radius).astype(int)
241242
return mask
242243

@@ -327,7 +328,7 @@ def __init__(self, mask, signal, dt, domain):
327328

328329
def tree_flatten(self):
329330
children = (self.mask, self.signal, self.dt)
330-
aux = (self.domain, )
331+
aux = (self.domain,)
331332
return (children, aux)
332333

333334
@classmethod
@@ -430,7 +431,7 @@ def __init__(self, positions):
430431

431432
def tree_flatten(self):
432433
children = None
433-
aux = (self.positions, )
434+
aux = (self.positions,)
434435
return (children, aux)
435436

436437
@classmethod
@@ -461,10 +462,115 @@ def __call__(self, p: Field, u: Field, rho: Field):
461462
return p.on_grid[self.positions[0]]
462463
elif len(self.positions) == 2:
463464
return p.on_grid[self.positions[0],
464-
self.positions[1]] # type: ignore
465+
self.positions[1]] # type: ignore
465466
elif len(self.positions) == 3:
466467
return p.on_grid[self.positions[0], self.positions[1],
467-
self.positions[2]] # type: ignore
468+
self.positions[2]] # type: ignore
469+
else:
470+
raise ValueError(
471+
"Sensors positions must be 1, 2 or 3 dimensional. Not {}".
472+
format(len(self.positions)))
473+
474+
475+
def _bli_function(x0: jnp.ndarray, x: jnp.ndarray, n: int, include_imag: bool = False) -> jnp.ndarray:
476+
"""
477+
The function used to compute the band limited interpolation function.
478+
479+
Args:
480+
x0 (jnp.ndarray): Position of the sensors along the axis.
481+
x (jnp.ndarray): Grid positions.
482+
n (int): Size of the grid
483+
include_imag (bool): Include the imaginary component?
484+
485+
Returns:
486+
jnp.ndarray: The values of the function at the grid positions.
487+
"""
488+
dx = jnp.where((x - x0[:, None]) == 0, 1, x - x0[:, None]) # https://github.com/google/jax/issues/1052
489+
dx_nonzero = (x - x0[:, None]) != 0
490+
491+
if n % 2 == 0:
492+
y = jnp.sin(jnp.pi * dx) / \
493+
jnp.tan(jnp.pi * dx / n) / n
494+
y -= jnp.sin(jnp.pi * x0[:, None]) * jnp.sin(jnp.pi * x) / n
495+
if include_imag:
496+
y += 1j * jnp.cos(jnp.pi * x0[:, None]) * jnp.sin(jnp.pi * x) / n
497+
else:
498+
y = jnp.sin(jnp.pi * dx) / \
499+
jnp.sin(jnp.pi * dx / n) / n
500+
501+
# Deal with case of precisely on grid.
502+
y = y * jnp.all(dx_nonzero, axis=1)[:, None] + (1 - dx_nonzero) * (~jnp.all(dx_nonzero, axis=1)[:, None])
503+
return y
504+
505+
506+
@register_pytree_node_class
507+
class BLISensors:
508+
""" Band-limited interpolant (off-grid) sensors.
509+
510+
Args:
511+
positions (Tuple of List of float): Sensor positions.
512+
n (Tuple of int): Grid size.
513+
514+
Attributes:
515+
positions (Tuple[jnp.ndarray]): Sensor positions
516+
n (Tuple[int]): Grid size.
517+
"""
518+
519+
positions: Tuple[jnp.ndarray]
520+
n: Tuple[int]
521+
522+
def __init__(self, positions: Tuple[jnp.ndarray], n: Tuple[int]):
523+
self.positions = positions
524+
self.n = n
525+
526+
# Calculate the band-limited interpolant weights if not provided.
527+
x = jnp.arange(n[0])[None]
528+
self.bx = jnp.expand_dims(_bli_function(positions[0], x, n[0]),
529+
axis=range(2, 2 + len(n)))
530+
531+
if len(n) > 1:
532+
y = jnp.arange(n[1])[None]
533+
self.by = jnp.expand_dims(_bli_function(positions[1], y, n[1]),
534+
axis=range(2, 2 + len(n) - 1))
535+
else:
536+
self.by = None
537+
538+
if len(n) > 2:
539+
z = jnp.arange(n[2])[None]
540+
self.bz = jnp.expand_dims(_bli_function(positions[2], z, n[2]),
541+
axis=range(2, 2 + len(n) - 2))
542+
else:
543+
self.bz = None
544+
545+
def tree_flatten(self):
546+
children = self.positions,
547+
aux = self.n,
548+
return children, aux
549+
550+
@classmethod
551+
def tree_unflatten(cls, aux, children):
552+
return cls(*children, *aux)
553+
554+
def __call__(self, p: Field, u, v):
555+
r"""Returns the values of the field p at the sensors positions.
556+
Args:
557+
p (Field): The field to be sampled.
558+
"""
559+
if len(self.positions) == 1:
560+
# 1D
561+
pw = jnp.sum(p.on_grid[None] * self.bx, axis=1)
562+
return pw
563+
elif len(self.positions) == 2:
564+
# 2D
565+
pw = jnp.sum(p.on_grid[None] * self.bx, axis=1)
566+
pw = jnp.sum(pw * self.by, axis=1)
567+
return pw
568+
elif len(self.positions) == 3:
569+
# 3D
570+
pw = jnp.sum(p.on_grid[None] * self.bx, axis=1)
571+
pw = jnp.sum(pw * self.by, axis=1)
572+
pw = jnp.sum(pw * self.bz, axis=1)
573+
return pw
468574
else:
469575
raise ValueError(
470576
"Sensors positions must be 1, 2 or 3 dimensional. Not {}".
@@ -488,7 +594,7 @@ def __init__(self, dt, t_end):
488594
self.t_end = t_end
489595

490596
def tree_flatten(self):
491-
children = (None, )
597+
children = (None,)
492598
aux = (self.dt, self.t_end)
493599
return (children, aux)
494600

@@ -522,7 +628,7 @@ def from_medium(medium: Medium, cfl: float = 0.3, t_end=None):
522628
np.max)
523629
if t_end is None:
524630
t_end = np.sqrt(
525-
sum((x[-1] - x[0])**2
631+
sum((x[-1] - x[0]) ** 2
526632
for x in medium.domain.spatial_axis)) / functional(
527-
medium.sound_speed)(np.min)
633+
medium.sound_speed)(np.min)
528634
return TimeAxis(dt=float(dt), t_end=float(t_end))

mkdocs.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,7 @@ nav:
7474
- Homogeneous wave propagation: notebooks/ivp/homogeneous_medium.ipynb
7575
- 3D simulations: notebooks/ivp/3d.ipynb
7676
- Sensors: notebooks/ivp/homogeneous_medium_sensors.ipynb
77+
- Off-Grid Sensors: notebooks/ivp/off_grid_sensors.ipynb
7778
- Custom sensors: notebooks/ivp/custom_sensors.ipynb
7879
- Automatic differentiation: notebooks/ivp/homogeneous_medium_backprop.ipynb
7980
- Heterogeneous medium: notebooks/ivp/heterogeneous_medium.ipynb

tests/test_off_grid_sensors.py

Lines changed: 82 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,82 @@
1+
from jwave.geometry import _bli_function, BLISensors, Domain, FourierSeries
2+
import numpy as np
3+
import pytest
4+
5+
6+
@pytest.mark.parametrize("n_grid", [100, 101])
7+
def test_bli_function(n_grid):
8+
# Make a load of sensors on the grid. Check the bli function.
9+
y = _bli_function(np.arange(n_grid), np.arange(n_grid), n_grid)
10+
# Assert that the bli function is 1 at one place for each detector.
11+
assert (np.all(np.sum(y != 0, axis=1) == 1))
12+
# Assert that the bli function is 0 everywhere else
13+
assert (np.all(np.sum(y == 0, axis=1) == n_grid - 1))
14+
# Assert that the 1 is in the correct place.
15+
assert (np.all(y[np.arange(n_grid), np.arange(n_grid)] == 1))
16+
17+
# Check off-grid points:
18+
y = _bli_function(np.arange(0, n_grid - 1) + 0.25, np.arange(n_grid), n_grid)
19+
# Check that the sensor is non-zero at more than one place.
20+
assert (np.all(np.sum(y != 0, axis=1) > 1))
21+
assert (np.all(np.isclose(np.sum(y, axis=1), 1)))
22+
23+
24+
@pytest.mark.parametrize("nx,ny,nz", [(100, 100, 100), (100, 101, 102)])
25+
def test_sensor(nx, ny, nz):
26+
# Check that it is identical to on-grid detectors in that case.
27+
n_detectors = min(nx, ny, nz)
28+
29+
xi = np.arange(n_detectors)
30+
yi = np.arange(n_detectors)
31+
zi = np.arange(n_detectors)
32+
np.random.shuffle(xi)
33+
np.random.shuffle(yi)
34+
np.random.shuffle(zi)
35+
36+
x = xi.astype(float)
37+
y = yi.astype(float)
38+
z = zi.astype(float)
39+
40+
p = np.random.random((nx, ny, nz))
41+
42+
s1d = BLISensors((x,), (nx,))
43+
s2d = BLISensors((x, y), (nx, ny))
44+
s3d = BLISensors((x, y, z), (nx, ny, nz))
45+
46+
domain1d = Domain((nx,), (1,))
47+
p1d = FourierSeries(p[:, 0, 0], domain1d)
48+
49+
domain2d = Domain((nx, ny), (1, 1))
50+
p2d = FourierSeries(p[:, :, 0], domain2d)
51+
52+
domain3d = Domain((nx, ny, nz), (1, 1, 1))
53+
p3d = FourierSeries(p, domain3d)
54+
55+
result = s1d(p1d, None, None)
56+
assert (np.all(result[..., 0] == p[xi, 0, 0]))
57+
58+
result = s2d(p2d, None, None)
59+
assert (np.all(result[..., 0] == p[xi, yi, 0]))
60+
61+
result = s3d(p3d, None, None)
62+
assert (np.all(result[..., 0] == p[xi, yi, zi]))
63+
64+
# Check off-grid (perturb a bit):
65+
s3d = BLISensors((x + 0.25, y + 0.3, z + 0.1), (nx, ny, nz))
66+
domain3d = Domain((nx, ny, nz), (1, 1, 1))
67+
# Check ones in ones out.
68+
p3d = FourierSeries(np.ones((nx, ny, nz)), domain3d)
69+
y = s3d(p3d, None, None)
70+
assert (np.all(np.isclose(y, 1)))
71+
72+
# Check zeros in zeros out
73+
p3d = FourierSeries(np.zeros((nx, ny, nz)), domain3d)
74+
y = s3d(p3d, None, None)
75+
assert (np.all(y == 0))
76+
77+
78+
if __name__ == "__main__":
79+
test_bli_function(100)
80+
test_bli_function(101)
81+
test_sensor()
82+
test_sensor(100, 101, 102)

0 commit comments

Comments
 (0)