Skip to content

Commit b0ac915

Browse files
committed
coreml : use Core ML encoder inference
1 parent 72af0f5 commit b0ac915

9 files changed

+643
-24
lines changed

.gitignore

+2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
*.o
22
*.a
3+
*.mlmodel
4+
*.mlmodelc
35
.cache/
46
.vs/
57
.vscode/

CMakeLists.txt

+59-9
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,8 @@ if (APPLE)
5454
option(WHISPER_NO_AVX "whisper: disable AVX" OFF)
5555
option(WHISPER_NO_AVX2 "whisper: disable AVX2" OFF)
5656
option(WHISPER_NO_FMA "whisper: disable FMA" OFF)
57+
58+
option(WHISPER_COREML "whisper: enable Core ML framework" OFF)
5759
else()
5860
option(WHISPER_SUPPORT_OPENBLAS "whisper: support for OpenBLAS" OFF)
5961
endif()
@@ -86,16 +88,33 @@ endif()
8688

8789
find_package(Threads REQUIRED)
8890

89-
# on APPLE - include Accelerate framework
90-
if (APPLE AND NOT WHISPER_NO_ACCELERATE)
91-
find_library(ACCELERATE_FRAMEWORK Accelerate)
92-
if (ACCELERATE_FRAMEWORK)
93-
message(STATUS "Accelerate framework found")
91+
# on APPLE
92+
if (APPLE)
93+
# include Accelerate framework
94+
if (NOT WHISPER_NO_ACCELERATE)
95+
find_library(ACCELERATE_FRAMEWORK Accelerate)
96+
97+
if (ACCELERATE_FRAMEWORK)
98+
message(STATUS "Accelerate framework found")
9499

95-
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
96-
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
97-
else()
98-
message(WARNING "Accelerate framework not found")
100+
set(WHISPER_EXTRA_LIBS ${WHISPER_EXTRA_LIBS} ${ACCELERATE_FRAMEWORK})
101+
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_USE_ACCELERATE)
102+
else()
103+
message(WARNING "Accelerate framework not found")
104+
endif()
105+
endif()
106+
107+
if (WHISPER_COREML)
108+
find_library(FOUNDATION_FRAMEWORK Foundation)
109+
find_library(COREML_FRAMEWORK CoreML)
110+
111+
if (COREML_FRAMEWORK)
112+
message(STATUS "CoreML framework found")
113+
114+
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DWHISPER_USE_COREML)
115+
else()
116+
message(WARNING "CoreML framework not found")
117+
endif()
99118
endif()
100119
endif()
101120

@@ -181,6 +200,33 @@ if (WHISPER_PERF)
181200
set(WHISPER_EXTRA_FLAGS ${WHISPER_EXTRA_FLAGS} -DGGML_PERF)
182201
endif()
183202

203+
#
204+
# whisper.coreml - Core ML support
205+
#
206+
207+
if (WHISPER_COREML)
208+
set(TARGET whisper.coreml)
209+
210+
add_library(${TARGET}
211+
coreml/whisper-encoder.h
212+
coreml/whisper-encoder.mm
213+
coreml/whisper-encoder-impl.h
214+
coreml/whisper-encoder-impl.m
215+
)
216+
217+
include(DefaultTargetOptions)
218+
219+
target_include_directories(${TARGET} PUBLIC
220+
.
221+
)
222+
223+
target_link_libraries(${TARGET} PRIVATE ${FOUNDATION_FRAMEWORK} ${COREML_FRAMEWORK})
224+
225+
set_target_properties(${TARGET} PROPERTIES
226+
COMPILE_FLAGS "-fobjc-arc"
227+
)
228+
endif()
229+
184230
#
185231
# whisper - this is the main library of the project
186232
#
@@ -200,6 +246,10 @@ target_include_directories(${TARGET} PUBLIC
200246
.
201247
)
202248

249+
if (WHISPER_COREML)
250+
target_link_libraries(${TARGET} PRIVATE whisper.coreml)
251+
endif()
252+
203253
if (MSVC)
204254
target_link_libraries(${TARGET} PRIVATE ${WHISPER_EXTRA_LIBS} ${CMAKE_THREAD_LIBS_INIT})
205255

Makefile

+30-14
Original file line numberDiff line numberDiff line change
@@ -132,6 +132,10 @@ ifndef WHISPER_NO_ACCELERATE
132132
LDFLAGS += -framework Accelerate
133133
endif
134134
endif
135+
ifdef WHISPER_COREML
136+
CXXFLAGS += -DWHISPER_USE_COREML
137+
LDFLAGS += -framework Foundation -framework CoreML
138+
endif
135139
ifdef WHISPER_OPENBLAS
136140
CFLAGS += -DGGML_USE_OPENBLAS -I/usr/local/include/openblas
137141
LDFLAGS += -lopenblas
@@ -184,11 +188,23 @@ ggml.o: ggml.c ggml.h
184188
whisper.o: whisper.cpp whisper.h
185189
$(CXX) $(CXXFLAGS) -c whisper.cpp -o whisper.o
186190

187-
libwhisper.a: ggml.o whisper.o
188-
$(AR) rcs libwhisper.a ggml.o whisper.o
191+
ifndef WHISPER_COREML
192+
WHISPER_OBJ = whisper.o
193+
else
194+
whisper-encoder.o: coreml/whisper-encoder.mm coreml/whisper-encoder.h
195+
$(CXX) -O3 -I . -c coreml/whisper-encoder.mm -o whisper-encoder.o
196+
197+
whisper-encoder-impl.o: coreml/whisper-encoder-impl.m coreml/whisper-encoder-impl.h
198+
$(CXX) -O3 -I . -fobjc-arc -c coreml/whisper-encoder-impl.m -o whisper-encoder-impl.o
199+
200+
WHISPER_OBJ = whisper.o whisper-encoder.o whisper-encoder-impl.o
201+
endif
202+
203+
libwhisper.a: ggml.o $(WHISPER_OBJ)
204+
$(AR) rcs libwhisper.a ggml.o $(WHISPER_OBJ)
189205

190-
libwhisper.so: ggml.o whisper.o
191-
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o whisper.o $(LDFLAGS)
206+
libwhisper.so: ggml.o $(WHISPER_OBJ)
207+
$(CXX) $(CXXFLAGS) -shared -o libwhisper.so ggml.o $(WHISPER_OBJ) $(LDFLAGS)
192208

193209
clean:
194210
rm -f *.o main stream command talk bench libwhisper.a libwhisper.so
@@ -202,21 +218,21 @@ CC_SDL=`sdl2-config --cflags --libs`
202218
SRC_COMMON = examples/common.cpp
203219
SRC_COMMON_SDL = examples/common-sdl.cpp
204220

205-
main: examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o
206-
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o whisper.o -o main $(LDFLAGS)
221+
main: examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ)
222+
$(CXX) $(CXXFLAGS) examples/main/main.cpp $(SRC_COMMON) ggml.o $(WHISPER_OBJ) -o main $(LDFLAGS)
207223
./main -h
208224

209-
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
210-
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o stream $(CC_SDL) $(LDFLAGS)
225+
stream: examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
226+
$(CXX) $(CXXFLAGS) examples/stream/stream.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o stream $(CC_SDL) $(LDFLAGS)
211227

212-
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
213-
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o command $(CC_SDL) $(LDFLAGS)
228+
command: examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
229+
$(CXX) $(CXXFLAGS) examples/command/command.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o command $(CC_SDL) $(LDFLAGS)
214230

215-
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o
216-
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o whisper.o -o talk $(CC_SDL) $(LDFLAGS)
231+
talk: examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ)
232+
$(CXX) $(CXXFLAGS) examples/talk/talk.cpp examples/talk/gpt-2.cpp $(SRC_COMMON) $(SRC_COMMON_SDL) ggml.o $(WHISPER_OBJ) -o talk $(CC_SDL) $(LDFLAGS)
217233

218-
bench: examples/bench/bench.cpp ggml.o whisper.o
219-
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o whisper.o -o bench $(LDFLAGS)
234+
bench: examples/bench/bench.cpp ggml.o $(WHISPER_OBJ)
235+
$(CXX) $(CXXFLAGS) examples/bench/bench.cpp ggml.o $(WHISPER_OBJ) -o bench $(LDFLAGS)
220236

221237
#
222238
# Audio samples

coreml/whisper-encoder-impl.h

+142
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,142 @@
1+
//
2+
// CoremlEncoder.h
3+
//
4+
// This file was automatically generated and should not be edited.
5+
//
6+
7+
#import <Foundation/Foundation.h>
8+
#import <CoreML/CoreML.h>
9+
#include <stdint.h>
10+
#include <os/log.h>
11+
12+
NS_ASSUME_NONNULL_BEGIN
13+
14+
15+
/// Model Prediction Input Type
16+
API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
17+
@interface CoremlEncoderInput : NSObject<MLFeatureProvider>
18+
19+
/// melSegment as 1 × 80 × 3000 3-dimensional array of floats
20+
@property (readwrite, nonatomic, strong) MLMultiArray * melSegment;
21+
- (instancetype)init NS_UNAVAILABLE;
22+
- (instancetype)initWithMelSegment:(MLMultiArray *)melSegment NS_DESIGNATED_INITIALIZER;
23+
24+
@end
25+
26+
27+
/// Model Prediction Output Type
28+
API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
29+
@interface CoremlEncoderOutput : NSObject<MLFeatureProvider>
30+
31+
/// output as multidimensional array of floats
32+
@property (readwrite, nonatomic, strong) MLMultiArray * output;
33+
- (instancetype)init NS_UNAVAILABLE;
34+
- (instancetype)initWithOutput:(MLMultiArray *)output NS_DESIGNATED_INITIALIZER;
35+
36+
@end
37+
38+
39+
/// Class for model loading and prediction
40+
API_AVAILABLE(macos(10.15), ios(13.0), watchos(6.0), tvos(13.0)) __attribute__((visibility("hidden")))
41+
@interface CoremlEncoder : NSObject
42+
@property (readonly, nonatomic, nullable) MLModel * model;
43+
44+
/**
45+
URL of the underlying .mlmodelc directory.
46+
*/
47+
+ (nullable NSURL *)URLOfModelInThisBundle;
48+
49+
/**
50+
Initialize CoremlEncoder instance from an existing MLModel object.
51+
52+
Usually the application does not use this initializer unless it makes a subclass of CoremlEncoder.
53+
Such application may want to use `-[MLModel initWithContentsOfURL:configuration:error:]` and `+URLOfModelInThisBundle` to create a MLModel object to pass-in.
54+
*/
55+
- (instancetype)initWithMLModel:(MLModel *)model NS_DESIGNATED_INITIALIZER;
56+
57+
/**
58+
Initialize CoremlEncoder instance with the model in this bundle.
59+
*/
60+
- (nullable instancetype)init;
61+
62+
/**
63+
Initialize CoremlEncoder instance with the model in this bundle.
64+
65+
@param configuration The model configuration object
66+
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
67+
*/
68+
- (nullable instancetype)initWithConfiguration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
69+
70+
/**
71+
Initialize CoremlEncoder instance from the model URL.
72+
73+
@param modelURL URL to the .mlmodelc directory for CoremlEncoder.
74+
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
75+
*/
76+
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL error:(NSError * _Nullable __autoreleasing * _Nullable)error;
77+
78+
/**
79+
Initialize CoremlEncoder instance from the model URL.
80+
81+
@param modelURL URL to the .mlmodelc directory for CoremlEncoder.
82+
@param configuration The model configuration object
83+
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
84+
*/
85+
- (nullable instancetype)initWithContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration error:(NSError * _Nullable __autoreleasing * _Nullable)error;
86+
87+
/**
88+
Construct CoremlEncoder instance asynchronously with configuration.
89+
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
90+
91+
@param configuration The model configuration
92+
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid CoremlEncoder instance or NSError object.
93+
*/
94+
+ (void)loadWithConfiguration:(MLModelConfiguration *)configuration completionHandler:(void (^)(CoremlEncoder * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
95+
96+
/**
97+
Construct CoremlEncoder instance asynchronously with URL of .mlmodelc directory and optional configuration.
98+
99+
Model loading may take time when the model content is not immediately available (e.g. encrypted model). Use this factory method especially when the caller is on the main thread.
100+
101+
@param modelURL The model URL.
102+
@param configuration The model configuration
103+
@param handler When the model load completes successfully or unsuccessfully, the completion handler is invoked with a valid CoremlEncoder instance or NSError object.
104+
*/
105+
+ (void)loadContentsOfURL:(NSURL *)modelURL configuration:(MLModelConfiguration *)configuration completionHandler:(void (^)(CoremlEncoder * _Nullable model, NSError * _Nullable error))handler API_AVAILABLE(macos(11.0), ios(14.0), watchos(7.0), tvos(14.0)) __attribute__((visibility("hidden")));
106+
107+
/**
108+
Make a prediction using the standard interface
109+
@param input an instance of CoremlEncoderInput to predict from
110+
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
111+
@return the prediction as CoremlEncoderOutput
112+
*/
113+
- (nullable CoremlEncoderOutput *)predictionFromFeatures:(CoremlEncoderInput *)input error:(NSError * _Nullable __autoreleasing * _Nullable)error;
114+
115+
/**
116+
Make a prediction using the standard interface
117+
@param input an instance of CoremlEncoderInput to predict from
118+
@param options prediction options
119+
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
120+
@return the prediction as CoremlEncoderOutput
121+
*/
122+
- (nullable CoremlEncoderOutput *)predictionFromFeatures:(CoremlEncoderInput *)input options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
123+
124+
/**
125+
Make a prediction using the convenience interface
126+
@param melSegment as 1 × 80 × 3000 3-dimensional array of floats:
127+
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
128+
@return the prediction as CoremlEncoderOutput
129+
*/
130+
- (nullable CoremlEncoderOutput *)predictionFromMelSegment:(MLMultiArray *)melSegment error:(NSError * _Nullable __autoreleasing * _Nullable)error;
131+
132+
/**
133+
Batch prediction
134+
@param inputArray array of CoremlEncoderInput instances to obtain predictions from
135+
@param options prediction options
136+
@param error If an error occurs, upon return contains an NSError object that describes the problem. If you are not interested in possible errors, pass in NULL.
137+
@return the predictions as NSArray<CoremlEncoderOutput *>
138+
*/
139+
- (nullable NSArray<CoremlEncoderOutput *> *)predictionsFromInputs:(NSArray<CoremlEncoderInput*> *)inputArray options:(MLPredictionOptions *)options error:(NSError * _Nullable __autoreleasing * _Nullable)error;
140+
@end
141+
142+
NS_ASSUME_NONNULL_END

0 commit comments

Comments
 (0)