16
16
import math
17
17
from dataclasses import dataclass
18
18
from typing import List , Tuple , Union
19
+ from numpy .typing import ArrayLike
19
20
20
21
import numpy as np
21
22
from jax import numpy as jnp
@@ -35,7 +36,7 @@ class Medium:
35
36
36
37
Attributes:
37
38
domain (Domain): domain of the medium
38
- sound_speed (jnp.darray ): speed of sound map, can be a scalar
39
+ sound_speed (jnp.ndarray ): speed of sound map, can be a scalar
39
40
density (jnp.ndarray): density map, can be a scalar
40
41
attenuation (jnp.ndarray): attenuation map, can be a scalar
41
42
pml_size (int): size of the PML layer in grid-points
@@ -148,8 +149,8 @@ class MediumType(Medium):
148
149
@type_of .dispatch
149
150
def type_of (m : Medium ):
150
151
return MediumType [type (m .sound_speed ),
151
- type (m .density ),
152
- type (m .attenuation )]
152
+ type (m .density ),
153
+ type (m .attenuation )]
153
154
154
155
155
156
MediumAllScalars = MediumType [object , object , object ]
@@ -187,11 +188,11 @@ def _unit_fibonacci_sphere(
187
188
coordinates of the points on the sphere.
188
189
"""
189
190
points = []
190
- phi = math .pi * (3.0 - math .sqrt (5.0 )) # golden angle in radians
191
+ phi = math .pi * (3.0 - math .sqrt (5.0 )) # golden angle in radians
191
192
for i in range (samples ):
192
- y = 1 - (i / float (samples - 1 )) * 2 # y goes from 1 to -1
193
- radius = math .sqrt (1 - y * y ) # radius at y
194
- theta = phi * i # golden angle increment
193
+ y = 1 - (i / float (samples - 1 )) * 2 # y goes from 1 to -1
194
+ radius = math .sqrt (1 - y * y ) # radius at y
195
+ theta = phi * i # golden angle increment
195
196
x = math .cos (theta ) * radius
196
197
z = math .sin (theta ) * radius
197
198
points .append ((x , y , z ))
@@ -228,15 +229,15 @@ def _fibonacci_sphere(
228
229
229
230
def _circ_mask (N , radius , centre ):
230
231
x , y = np .mgrid [0 :N [0 ], 0 :N [1 ]]
231
- dist_from_centre = np .sqrt ((x - centre [0 ])** 2 + (y - centre [1 ])** 2 )
232
+ dist_from_centre = np .sqrt ((x - centre [0 ]) ** 2 + (y - centre [1 ]) ** 2 )
232
233
mask = (dist_from_centre < radius ).astype (int )
233
234
return mask
234
235
235
236
236
237
def _sphere_mask (N , radius , centre ):
237
238
x , y , z = np .mgrid [0 :N [0 ], 0 :N [1 ], 0 :N [2 ]]
238
- dist_from_centre = np .sqrt ((x - centre [0 ])** 2 + (y - centre [1 ])** 2 +
239
- (z - centre [2 ])** 2 )
239
+ dist_from_centre = np .sqrt ((x - centre [0 ]) ** 2 + (y - centre [1 ]) ** 2 +
240
+ (z - centre [2 ]) ** 2 )
240
241
mask = (dist_from_centre < radius ).astype (int )
241
242
return mask
242
243
@@ -327,7 +328,7 @@ def __init__(self, mask, signal, dt, domain):
327
328
328
329
def tree_flatten (self ):
329
330
children = (self .mask , self .signal , self .dt )
330
- aux = (self .domain , )
331
+ aux = (self .domain ,)
331
332
return (children , aux )
332
333
333
334
@classmethod
@@ -430,7 +431,7 @@ def __init__(self, positions):
430
431
431
432
def tree_flatten (self ):
432
433
children = None
433
- aux = (self .positions , )
434
+ aux = (self .positions ,)
434
435
return (children , aux )
435
436
436
437
@classmethod
@@ -461,10 +462,115 @@ def __call__(self, p: Field, u: Field, rho: Field):
461
462
return p .on_grid [self .positions [0 ]]
462
463
elif len (self .positions ) == 2 :
463
464
return p .on_grid [self .positions [0 ],
464
- self .positions [1 ]] # type: ignore
465
+ self .positions [1 ]] # type: ignore
465
466
elif len (self .positions ) == 3 :
466
467
return p .on_grid [self .positions [0 ], self .positions [1 ],
467
- self .positions [2 ]] # type: ignore
468
+ self .positions [2 ]] # type: ignore
469
+ else :
470
+ raise ValueError (
471
+ "Sensors positions must be 1, 2 or 3 dimensional. Not {}" .
472
+ format (len (self .positions )))
473
+
474
+
475
+ def _bli_function (x0 : jnp .ndarray , x : jnp .ndarray , n : int , include_imag : bool = False ) -> jnp .ndarray :
476
+ """
477
+ The function used to compute the band limited interpolation function.
478
+
479
+ Args:
480
+ x0 (jnp.ndarray): Position of the sensors along the axis.
481
+ x (jnp.ndarray): Grid positions.
482
+ n (int): Size of the grid
483
+ include_imag (bool): Include the imaginary component?
484
+
485
+ Returns:
486
+ jnp.ndarray: The values of the function at the grid positions.
487
+ """
488
+ dx = jnp .where ((x - x0 [:, None ]) == 0 , 1 , x - x0 [:, None ]) # https://github.com/google/jax/issues/1052
489
+ dx_nonzero = (x - x0 [:, None ]) != 0
490
+
491
+ if n % 2 == 0 :
492
+ y = jnp .sin (jnp .pi * dx ) / \
493
+ jnp .tan (jnp .pi * dx / n ) / n
494
+ y -= jnp .sin (jnp .pi * x0 [:, None ]) * jnp .sin (jnp .pi * x ) / n
495
+ if include_imag :
496
+ y += 1j * jnp .cos (jnp .pi * x0 [:, None ]) * jnp .sin (jnp .pi * x ) / n
497
+ else :
498
+ y = jnp .sin (jnp .pi * dx ) / \
499
+ jnp .sin (jnp .pi * dx / n ) / n
500
+
501
+ # Deal with case of precisely on grid.
502
+ y = y * jnp .all (dx_nonzero , axis = 1 )[:, None ] + (1 - dx_nonzero ) * (~ jnp .all (dx_nonzero , axis = 1 )[:, None ])
503
+ return y
504
+
505
+
506
+ @register_pytree_node_class
507
+ class BLISensors :
508
+ """ Band-limited interpolant (off-grid) sensors.
509
+
510
+ Args:
511
+ positions (Tuple of List of float): Sensor positions.
512
+ n (Tuple of int): Grid size.
513
+
514
+ Attributes:
515
+ positions (Tuple[jnp.ndarray]): Sensor positions
516
+ n (Tuple[int]): Grid size.
517
+ """
518
+
519
+ positions : Tuple [jnp .ndarray ]
520
+ n : Tuple [int ]
521
+
522
+ def __init__ (self , positions : Tuple [jnp .ndarray ], n : Tuple [int ]):
523
+ self .positions = positions
524
+ self .n = n
525
+
526
+ # Calculate the band-limited interpolant weights if not provided.
527
+ x = jnp .arange (n [0 ])[None ]
528
+ self .bx = jnp .expand_dims (_bli_function (positions [0 ], x , n [0 ]),
529
+ axis = range (2 , 2 + len (n )))
530
+
531
+ if len (n ) > 1 :
532
+ y = jnp .arange (n [1 ])[None ]
533
+ self .by = jnp .expand_dims (_bli_function (positions [1 ], y , n [1 ]),
534
+ axis = range (2 , 2 + len (n ) - 1 ))
535
+ else :
536
+ self .by = None
537
+
538
+ if len (n ) > 2 :
539
+ z = jnp .arange (n [2 ])[None ]
540
+ self .bz = jnp .expand_dims (_bli_function (positions [2 ], z , n [2 ]),
541
+ axis = range (2 , 2 + len (n ) - 2 ))
542
+ else :
543
+ self .bz = None
544
+
545
+ def tree_flatten (self ):
546
+ children = self .positions ,
547
+ aux = self .n ,
548
+ return children , aux
549
+
550
+ @classmethod
551
+ def tree_unflatten (cls , aux , children ):
552
+ return cls (* children , * aux )
553
+
554
+ def __call__ (self , p : Field , u , v ):
555
+ r"""Returns the values of the field p at the sensors positions.
556
+ Args:
557
+ p (Field): The field to be sampled.
558
+ """
559
+ if len (self .positions ) == 1 :
560
+ # 1D
561
+ pw = jnp .sum (p .on_grid [None ] * self .bx , axis = 1 )
562
+ return pw
563
+ elif len (self .positions ) == 2 :
564
+ # 2D
565
+ pw = jnp .sum (p .on_grid [None ] * self .bx , axis = 1 )
566
+ pw = jnp .sum (pw * self .by , axis = 1 )
567
+ return pw
568
+ elif len (self .positions ) == 3 :
569
+ # 3D
570
+ pw = jnp .sum (p .on_grid [None ] * self .bx , axis = 1 )
571
+ pw = jnp .sum (pw * self .by , axis = 1 )
572
+ pw = jnp .sum (pw * self .bz , axis = 1 )
573
+ return pw
468
574
else :
469
575
raise ValueError (
470
576
"Sensors positions must be 1, 2 or 3 dimensional. Not {}" .
@@ -488,7 +594,7 @@ def __init__(self, dt, t_end):
488
594
self .t_end = t_end
489
595
490
596
def tree_flatten (self ):
491
- children = (None , )
597
+ children = (None ,)
492
598
aux = (self .dt , self .t_end )
493
599
return (children , aux )
494
600
@@ -522,7 +628,7 @@ def from_medium(medium: Medium, cfl: float = 0.3, t_end=None):
522
628
np .max )
523
629
if t_end is None :
524
630
t_end = np .sqrt (
525
- sum ((x [- 1 ] - x [0 ])** 2
631
+ sum ((x [- 1 ] - x [0 ]) ** 2
526
632
for x in medium .domain .spatial_axis )) / functional (
527
- medium .sound_speed )(np .min )
633
+ medium .sound_speed )(np .min )
528
634
return TimeAxis (dt = float (dt ), t_end = float (t_end ))
0 commit comments