@@ -23,11 +23,25 @@ public enum ConvolutionDomain {
2323 // case full // M+K-1
2424}
2525
26+ // MARK: - Generic Dimensional Type Ops (apply to Vector, Matrix, Tensor)
27+ // Typically element-wise operations that can be implemented in terms to linearized access (any dimensions).
28+ extension Numerics where Element: AccelerateFloatingPoint {
29+ public static func subtract< DT: NStorageAccessible > ( _ a: DT , _ b: DT , _ result: DT ) where DT. Element == Element {
30+ precondition ( a. shape == b. shape && a. shape == result. shape)
31+
32+ withLinearizedAccesses ( a, b, result) { aacc, bacc, racc in
33+ Element . mx_vsub ( aacc. base, aacc. stride, bacc. base, bacc. stride, racc. base, racc. stride, numericCast ( racc. count) )
34+ }
35+ }
36+ public static func subtract< DT: NStorageAccessible > ( _ a: DT , _ b: DT ) -> DT where DT. Element == Element { return a. _deriving { subtract ( a, b, $0) } }
37+ }
38+
39+
2640// MARK: - Vector Ops
2741extension Numerics where Element: AccelerateFloatingPoint {
2842 /// Creation of vectors
29- public static func zeros( count: Int ) -> Vector { return Vector ( repeating: 0.0 , count : count) }
30- public static func ones( count: Int ) -> Vector { return NVector ( repeating: 1.0 , count : count) }
43+ public static func zeros( count: Int ) -> Vector { return Vector ( repeating: 0.0 , size : count) }
44+ public static func ones( count: Int ) -> Vector { return NVector ( repeating: 1.0 , size : count) }
3145 public static func linspace( start: Element , stop: Element , count: Int , output: Vector ) {
3246 precondition ( count == output. size)
3347 precondition ( count >= 2 )
@@ -102,8 +116,8 @@ extension Numerics where Element: AccelerateFloatingPoint {
102116 let afterstart = before+ input. size
103117
104118 output [ before ..< afterstart] = input
105- output [ 0 ..< before] = NVector ( repeating: input. first!, count : before)
106- output [ afterstart ..< output. size] = NVector ( repeating: input. last!, count : after)
119+ output [ 0 ..< before] = NVector ( repeating: input. first!, size : before)
120+ output [ afterstart ..< output. size] = NVector ( repeating: input. last!, size : after)
107121 }
108122
109123 // Arithmetic
@@ -141,13 +155,13 @@ extension Numerics where Element: AccelerateFloatingPoint {
141155 }
142156 public static func multiply( _ a: Vector , _ b: Vector ) -> Vector { return a. _deriving { multiply ( a, b, $0) } }
143157
144- public static func subtract( _ a: Vector , _ b: Vector , _ result: Vector ) {
145- precondition ( a. shape == b. shape && a. shape == result. shape)
146- withStorageAccess ( a, b, result) { aacc, bacc, racc in
147- Element . mx_vsub ( aacc. base, aacc. stride, bacc. base, bacc. stride, racc. base, racc. stride, numericCast ( racc. count) )
148- }
149- }
150- public static func subtract( _ a: Vector , _ b: Vector ) -> Vector { return a. _deriving { subtract ( a, b, $0) } }
158+ // public static func subtract(_ a: Vector, _ b: Vector, _ result: Vector) {
159+ // precondition(a.shape == b.shape && a.shape == result.shape)
160+ // withStorageAccess(a, b, result) { aacc, bacc, racc in
161+ // Element.mx_vsub(aacc.base, aacc.stride, bacc.base, bacc.stride, racc.base, racc.stride, numericCast(racc.count))
162+ // }
163+ // }
164+ // public static func subtract(_ a: Vector, _ b: Vector) -> Vector { return a._deriving { subtract(a, b, $0) } }
151165 public static func add( _ a: Vector , _ b: Vector , _ result: Vector ) {
152166 precondition ( a. shape == b. shape && a. shape == result. shape)
153167 withStorageAccess ( a, b, result) { aacc, bacc, racc in
@@ -270,6 +284,13 @@ extension NVector where Element: AccelerateFloatingPoint {
270284extension Numerics where Element: AccelerateFloatingPoint {
271285 public static func zeros( rows: Int , columns: Int ) -> Matrix { return Matrix ( repeating: 0.0 , rows: rows, columns: columns) }
272286 public static func ones( rows: Int , columns: Int ) -> Matrix { return Matrix ( repeating: 1.0 , rows: rows, columns: columns) }
287+ // public static func add(_ a: Vector, _ b: Vector, _ result: Vector) {
288+ // precondition(a.shape == b.shape && a.shape == result.shape)
289+ // withStorageAccess(a, b, result) { aacc, bacc, racc in
290+ // Element.mx_vadd(aacc.base, aacc.stride, bacc.base, bacc.stride, racc.base, racc.stride, numericCast(racc.count))
291+ // }
292+ // }
293+ // public static func add(_ a: Vector, _ b: Vector) -> Vector { return a._deriving { add(a, b, $0) } }
273294
274295 public static func multiply( _ a: Matrix , _ b: Element , _ result: Matrix ) {
275296 precondition ( a. shape == result. shape)
@@ -526,6 +547,10 @@ extension NMatrix where Element: AccelerateFloatingPoint {
526547 return result
527548 }
528549
550+ // Matrix/Matric
551+ public static func - ( lhs: Matrix , rhs: Matrix ) -> Matrix { return Numerics . subtract ( lhs, rhs) }
552+ // public static func +(lhs: Matrix, rhs: Matrix) -> Matrix { return Numerics.add(lhs, rhs) }
553+
529554 // Matrix/Vector
530555 public static func * ( lhs: Matrix , rhs: Matrix ) -> Matrix { return Numerics . multiply ( lhs, rhs) }
531556 public static func * ( lhs: Matrix , rhs: Vector ) -> Vector { return Numerics . multiply ( lhs, rhs) }
0 commit comments