99import bisect
1010
1111
12- class Spline :
12+ class CubicSpline1D :
1313 """
14- Cubic Spline class
14+ 1D Cubic Spline class
15+
16+ Parameters
17+ ----------
18+ x : list
19+ x coordinates for data points. This x coordinates must be
20+ sorted
21+ in ascending order.
22+ y : list
23+ y coordinates for data points
24+
25+ Examples
26+ --------
27+ You can interpolate 1D data points.
28+
29+ >>> import numpy as np
30+ >>> import matplotlib.pyplot as plt
31+ >>> x = np.arange(5)
32+ >>> y = [1.7, -6, 5, 6.5, 0.0]
33+ >>> sp = CubicSpline1D(x, y)
34+ >>> xi = np.linspace(0.0, 5.0)
35+ >>> yi = [sp.calc_position(x) for x in xi]
36+ >>> plt.plot(x, y, "xb", label="Data points")
37+ >>> plt.plot(xi, yi , "r", label="Cubic spline interpolation")
38+ >>> plt.grid(True)
39+ >>> plt.legend()
40+ >>> plt.show()
41+
42+ .. image:: cubic_spline_1d.png
43+
1544 """
1645
1746 def __init__ (self , x , y ):
18- self .b , self .c , self .d , self .w = [], [], [], []
1947
48+ h = np .diff (x )
49+ if np .any (h < 0 ):
50+ raise ValueError ("x coordinates must be sorted in ascending order" )
51+
52+ self .a , self .b , self .c , self .d = [], [], [], []
2053 self .x = x
2154 self .y = y
22-
2355 self .nx = len (x ) # dimension of x
24- h = np .diff (x )
2556
26- # calc coefficient c
57+ # calc coefficient a
2758 self .a = [iy for iy in y ]
2859
2960 # calc coefficient c
3061 A = self .__calc_A (h )
31- B = self .__calc_B (h )
62+ B = self .__calc_B (h , self . a )
3263 self .c = np .linalg .solve (A , B )
33- # print(self.c1)
3464
3565 # calc spline coefficient b and d
3666 for i in range (self .nx - 1 ):
37- self .d .append ((self .c [i + 1 ] - self .c [i ]) / (3.0 * h [i ]))
38- tb = (self .a [i + 1 ] - self .a [i ]) / h [i ] - h [i ] * \
39- (self .c [i + 1 ] + 2.0 * self .c [i ]) / 3.0
40- self .b .append (tb )
67+ d = (self .c [i + 1 ] - self .c [i ]) / (3.0 * h [i ])
68+ b = 1.0 / h [i ] * (self .a [i + 1 ] - self .a [i ]) \
69+ - h [i ] / 3.0 * (2.0 * self .c [i ] + self .c [i + 1 ])
70+ self .d .append (d )
71+ self .b .append (b )
4172
42- def calc (self , t ):
73+ def calc_position (self , x ):
4374 """
44- Calc position
75+ Calc `y` position for given `x`.
4576
46- if t is outside of the input x , return None
77+ if `x` is outside the data point's `x` range , return None.
4778
79+ Returns
80+ -------
81+ y : float
82+ y position for given x.
4883 """
49-
50- if t < self .x [0 ]:
84+ if x < self .x [0 ]:
5185 return None
52- elif t > self .x [- 1 ]:
86+ elif x > self .x [- 1 ]:
5387 return None
5488
55- i = self .__search_index (t )
56- dx = t - self .x [i ]
57- result = self .a [i ] + self .b [i ] * dx + \
89+ i = self .__search_index (x )
90+ dx = x - self .x [i ]
91+ position = self .a [i ] + self .b [i ] * dx + \
5892 self .c [i ] * dx ** 2.0 + self .d [i ] * dx ** 3.0
5993
60- return result
94+ return position
6195
62- def calcd (self , t ):
96+ def calc_first_derivative (self , x ):
6397 """
64- Calc first derivative
98+ Calc first derivative at given x.
99+
100+ if x is outside the input x, return None
65101
66- if t is outside of the input x, return None
102+ Returns
103+ -------
104+ dy : float
105+ first derivative for given x.
67106 """
68107
69- if t < self .x [0 ]:
108+ if x < self .x [0 ]:
70109 return None
71- elif t > self .x [- 1 ]:
110+ elif x > self .x [- 1 ]:
72111 return None
73112
74- i = self .__search_index (t )
75- dx = t - self .x [i ]
76- result = self .b [i ] + 2.0 * self .c [i ] * dx + 3.0 * self .d [i ] * dx ** 2.0
77- return result
113+ i = self .__search_index (x )
114+ dx = x - self .x [i ]
115+ dy = self .b [i ] + 2.0 * self .c [i ] * dx + 3.0 * self .d [i ] * dx ** 2.0
116+ return dy
78117
79- def calcdd (self , t ):
118+ def calc_second_derivative (self , x ):
80119 """
81- Calc second derivative
120+ Calc second derivative at given x.
121+
122+ if x is outside the input x, return None
123+
124+ Returns
125+ -------
126+ ddy : float
127+ second derivative for given x.
82128 """
83129
84- if t < self .x [0 ]:
130+ if x < self .x [0 ]:
85131 return None
86- elif t > self .x [- 1 ]:
132+ elif x > self .x [- 1 ]:
87133 return None
88134
89- i = self .__search_index (t )
90- dx = t - self .x [i ]
91- result = 2.0 * self .c [i ] + 6.0 * self .d [i ] * dx
92- return result
135+ i = self .__search_index (x )
136+ dx = x - self .x [i ]
137+ ddy = 2.0 * self .c [i ] + 6.0 * self .d [i ] * dx
138+ return ddy
93139
94140 def __search_index (self , x ):
95141 """
@@ -112,30 +158,82 @@ def __calc_A(self, h):
112158 A [0 , 1 ] = 0.0
113159 A [self .nx - 1 , self .nx - 2 ] = 0.0
114160 A [self .nx - 1 , self .nx - 1 ] = 1.0
115- # print(A)
116161 return A
117162
118- def __calc_B (self , h ):
163+ def __calc_B (self , h , a ):
119164 """
120165 calc matrix B for spline coefficient c
121166 """
122167 B = np .zeros (self .nx )
123168 for i in range (self .nx - 2 ):
124- B [i + 1 ] = 3.0 * (self . a [i + 2 ] - self . a [i + 1 ]) / \
125- h [ i + 1 ] - 3.0 * (self . a [i + 1 ] - self . a [i ]) / h [i ]
169+ B [i + 1 ] = 3.0 * (a [i + 2 ] - a [i + 1 ]) / h [ i + 1 ] \
170+ - 3.0 * (a [i + 1 ] - a [i ]) / h [i ]
126171 return B
127172
128173
129- class Spline2D :
174+ class CubicSpline2D :
130175 """
131- 2D Cubic Spline class
132-
176+ Cubic CubicSpline2D class
177+
178+ Parameters
179+ ----------
180+ x : list
181+ x coordinates for data points.
182+ y : list
183+ y coordinates for data points.
184+
185+ Examples
186+ --------
187+ You can interpolate a 2D data points.
188+
189+ >>> import matplotlib.pyplot as plt
190+ >>> x = [-2.5, 0.0, 2.5, 5.0, 7.5, 3.0, -1.0]
191+ >>> y = [0.7, -6, 5, 6.5, 0.0, 5.0, -2.0]
192+ >>> ds = 0.1 # [m] distance of each interpolated points
193+ >>> sp = CubicSpline2D(x, y)
194+ >>> s = np.arange(0, sp.s[-1], ds)
195+ >>> rx, ry, ryaw, rk = [], [], [], []
196+ >>> for i_s in s:
197+ ... ix, iy = sp.calc_position(i_s)
198+ ... rx.append(ix)
199+ ... ry.append(iy)
200+ ... ryaw.append(sp.calc_yaw(i_s))
201+ ... rk.append(sp.calc_curvature(i_s))
202+ >>> plt.subplots(1)
203+ >>> plt.plot(x, y, "xb", label="Data points")
204+ >>> plt.plot(rx, ry, "-r", label="Cubic spline path")
205+ >>> plt.grid(True)
206+ >>> plt.axis("equal")
207+ >>> plt.xlabel("x[m]")
208+ >>> plt.ylabel("y[m]")
209+ >>> plt.legend()
210+ >>> plt.show()
211+
212+ .. image:: cubic_spline_2d_path.png
213+
214+ >>> plt.subplots(1)
215+ >>> plt.plot(s, [np.rad2deg(iyaw) for iyaw in ryaw], "-r", label="yaw")
216+ >>> plt.grid(True)
217+ >>> plt.legend()
218+ >>> plt.xlabel("line length[m]")
219+ >>> plt.ylabel("yaw angle[deg]")
220+
221+ .. image:: cubic_spline_2d_yaw.png
222+
223+ >>> plt.subplots(1)
224+ >>> plt.plot(s, rk, "-r", label="curvature")
225+ >>> plt.grid(True)
226+ >>> plt.legend()
227+ >>> plt.xlabel("line length[m]")
228+ >>> plt.ylabel("curvature [1/m]")
229+
230+ .. image:: cubic_spline_2d_curvature.png
133231 """
134232
135233 def __init__ (self , x , y ):
136234 self .s = self .__calc_s (x , y )
137- self .sx = Spline (self .s , x )
138- self .sy = Spline (self .s , y )
235+ self .sx = CubicSpline1D (self .s , x )
236+ self .sy = CubicSpline1D (self .s , y )
139237
140238 def __calc_s (self , x , y ):
141239 dx = np .diff (x )
@@ -148,35 +246,70 @@ def __calc_s(self, x, y):
148246 def calc_position (self , s ):
149247 """
150248 calc position
249+
250+ Parameters
251+ ----------
252+ s : float
253+ distance from the start point. if `s` is outside the data point's
254+ range, return None.
255+
256+ Returns
257+ -------
258+ x : float
259+ x position for given s.
260+ y : float
261+ y position for given s.
151262 """
152- x = self .sx .calc (s )
153- y = self .sy .calc (s )
263+ x = self .sx .calc_position (s )
264+ y = self .sy .calc_position (s )
154265
155266 return x , y
156267
157268 def calc_curvature (self , s ):
158269 """
159270 calc curvature
271+
272+ Parameters
273+ ----------
274+ s : float
275+ distance from the start point. if `s` is outside the data point's
276+ range, return None.
277+
278+ Returns
279+ -------
280+ k : float
281+ curvature for given s.
160282 """
161- dx = self .sx .calcd (s )
162- ddx = self .sx .calcdd (s )
163- dy = self .sy .calcd (s )
164- ddy = self .sy .calcdd (s )
283+ dx = self .sx .calc_first_derivative (s )
284+ ddx = self .sx .calc_second_derivative (s )
285+ dy = self .sy .calc_first_derivative (s )
286+ ddy = self .sy .calc_second_derivative (s )
165287 k = (ddy * dx - ddx * dy ) / ((dx ** 2 + dy ** 2 )** (3 / 2 ))
166288 return k
167289
168290 def calc_yaw (self , s ):
169291 """
170292 calc yaw
293+
294+ Parameters
295+ ----------
296+ s : float
297+ distance from the start point. if `s` is outside the data point's
298+ range, return None.
299+
300+ Returns
301+ -------
302+ yaw : float
303+ yaw angle (tangent vector) for given s.
171304 """
172- dx = self .sx .calcd (s )
173- dy = self .sy .calcd (s )
305+ dx = self .sx .calc_first_derivative (s )
306+ dy = self .sy .calc_first_derivative (s )
174307 yaw = math .atan2 (dy , dx )
175308 return yaw
176309
177310
178311def calc_spline_course (x , y , ds = 0.1 ):
179- sp = Spline2D (x , y )
312+ sp = CubicSpline2D (x , y )
180313 s = list (np .arange (0 , sp .s [- 1 ], ds ))
181314
182315 rx , ry , ryaw , rk = [], [], [], []
@@ -190,14 +323,30 @@ def calc_spline_course(x, y, ds=0.1):
190323 return rx , ry , ryaw , rk , s
191324
192325
193- def main (): # pragma: no cover
194- print ("Spline 2D test" )
326+ def main_1d ():
327+ print ("CubicSpline1D test" )
328+ import matplotlib .pyplot as plt
329+ x = np .arange (5 )
330+ y = [1.7 , - 6 , 5 , 6.5 , 0.0 ]
331+ sp = CubicSpline1D (x , y )
332+ xi = np .linspace (0.0 , 5.0 )
333+
334+ plt .plot (x , y , "xb" , label = "Data points" )
335+ plt .plot (xi , [sp .calc_position (x ) for x in xi ], "r" ,
336+ label = "Cubic spline interpolation" )
337+ plt .grid (True )
338+ plt .legend ()
339+ plt .show ()
340+
341+
342+ def main_2d (): # pragma: no cover
343+ print ("CubicSpline1D 2D test" )
195344 import matplotlib .pyplot as plt
196345 x = [- 2.5 , 0.0 , 2.5 , 5.0 , 7.5 , 3.0 , - 1.0 ]
197346 y = [0.7 , - 6 , 5 , 6.5 , 0.0 , 5.0 , - 2.0 ]
198347 ds = 0.1 # [m] distance of each interpolated points
199348
200- sp = Spline2D (x , y )
349+ sp = CubicSpline2D (x , y )
201350 s = np .arange (0 , sp .s [- 1 ], ds )
202351
203352 rx , ry , ryaw , rk = [], [], [], []
@@ -209,8 +358,8 @@ def main(): # pragma: no cover
209358 rk .append (sp .calc_curvature (i_s ))
210359
211360 plt .subplots (1 )
212- plt .plot (x , y , "xb" , label = "input " )
213- plt .plot (rx , ry , "-r" , label = "spline" )
361+ plt .plot (x , y , "xb" , label = "Data points " )
362+ plt .plot (rx , ry , "-r" , label = "Cubic spline path " )
214363 plt .grid (True )
215364 plt .axis ("equal" )
216365 plt .xlabel ("x[m]" )
@@ -235,4 +384,5 @@ def main(): # pragma: no cover
235384
236385
237386if __name__ == '__main__' :
238- main ()
387+ # main_1d()
388+ main_2d ()
0 commit comments