Skip to content
This repository was archived by the owner on Apr 1, 2025. It is now read-only.

Commit 4f91eb8

Browse files
authored
Sync from internal repo (2024-03-18) (#15)
* feat(sdk/go): add support for temporary tokens (#4127) GitOrigin-RevId: 3f1b1e5f61d0a4193424b7b54c00b87ad3d1ea81 * fix(sdk/go): add lemur model enums (#4147) GitOrigin-RevId: bf5eb9e70ce2ff5bdb7faadf97af01cf0a18b576 * fix(sdk): add conformer-2 enum to go and node sdks (#4146) GitOrigin-RevId: f5c7342796ba568e6ab8572042bb01e861c0b589 * fix(sdk/go): tidy up go.sum
1 parent 5b819ff commit 4f91eb8

File tree

7 files changed

+281
-152
lines changed

7 files changed

+281
-152
lines changed

assemblyai.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ import (
1111
)
1212

1313
const (
14-
version = "1.3.0"
14+
version = "1.4.0"
1515
defaultBaseURLScheme = "https"
1616
defaultBaseURLHost = "api.assemblyai.com"
1717
defaultUserAgent = "assemblyai-go/" + version
@@ -27,6 +27,7 @@ type Client struct {
2727

2828
Transcripts *TranscriptService
2929
LeMUR *LeMURService
30+
RealTime *RealTimeService
3031
}
3132

3233
// NewClientWithOptions returns a new configurable AssemblyAI client. If you provide client
@@ -51,6 +52,7 @@ func NewClientWithOptions(opts ...ClientOption) *Client {
5152

5253
c.Transcripts = &TranscriptService{client: c}
5354
c.LeMUR = &LeMURService{client: c}
55+
c.RealTime = &RealTimeService{client: c}
5456

5557
return c
5658
}

go.mod

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,13 @@ require (
66
github.com/cenkalti/backoff v2.2.1+incompatible
77
github.com/google/go-cmp v0.5.9
88
github.com/google/go-querystring v1.1.0
9+
github.com/stretchr/testify v1.9.0
910
nhooyr.io/websocket v1.8.7
1011
)
1112

12-
require github.com/klauspost/compress v1.10.3 // indirect
13+
require (
14+
github.com/davecgh/go-spew v1.1.1 // indirect
15+
github.com/klauspost/compress v1.10.3 // indirect
16+
github.com/pmezard/go-difflib v1.0.0 // indirect
17+
gopkg.in/yaml.v3 v3.0.1 // indirect
18+
)

go.sum

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
github.com/cenkalti/backoff v2.2.1+incompatible h1:tNowT99t7UNflLxfYYSlKYsBpXdEet03Pg2g16Swow4=
22
github.com/cenkalti/backoff v2.2.1+incompatible/go.mod h1:90ReRw6GdpyfrHakVjL/QHaoyV4aDUVVkXQJJJ3NXXM=
33
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
4+
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
45
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
56
github.com/gin-contrib/sse v0.1.0 h1:Y/yl/+YNO8GZSjAhjMsSuLt29uWRFHdHYUb5lYOV9qE=
67
github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI=
@@ -29,8 +30,6 @@ github.com/google/go-cmp v0.5.9/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeN
2930
github.com/google/go-querystring v1.1.0 h1:AnCroh3fv4ZBgVIf1Iwtovgjaw/GiKJo8M8yD/fhyJ8=
3031
github.com/google/go-querystring v1.1.0/go.mod h1:Kcdr2DB4koayq7X8pmAG4sNG59So17icRSOU623lUBU=
3132
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
32-
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5 h1:5AlozfqaVjGYGhms2OsdUyfdJME76E6rx5MdGpjzZpc=
33-
github.com/gordonklaus/portaudio v0.0.0-20230709114228-aafa478834f5/go.mod h1:WY8R6YKlI2ZI3UyzFk7P6yGSuS+hFwNtEzrexRyD7Es=
3433
github.com/gorilla/websocket v1.4.1 h1:q7AeDBpnBk8AogcD4DSag/Ukw/KV+YhzLj2bP5HvKCM=
3534
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
3635
github.com/json-iterator/go v1.1.9 h1:9yzud/Ht36ygwatGx56VwCZtlI/2AD15T1X2sjSuGns=
@@ -45,10 +44,13 @@ github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OH
4544
github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q=
4645
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742 h1:Esafd1046DLDQ0W1YjYsBW+p8U2u7vzgW2SQVmlNazg=
4746
github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0=
47+
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
4848
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
4949
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
5050
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
5151
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
52+
github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
53+
github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
5254
github.com/ugorji/go v1.1.7 h1:/68gy2h+1mWMrwZFeD1kQialdSzAb432dtpeJ42ovdo=
5355
github.com/ugorji/go v1.1.7/go.mod h1:kZn38zHttfInRq0xu/PH0az30d+z6vm202qpg1oXVMw=
5456
github.com/ugorji/go/codec v1.1.7 h1:2SvQaVZ1ouYrrKKwoSk2pzd4A9evlKJb9oTL+OaLUSs=
@@ -59,9 +61,12 @@ golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
5961
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
6062
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
6163
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
64+
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
6265
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
6366
gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
6467
gopkg.in/yaml.v2 v2.2.8 h1:obN1ZagJSUGI0Ek/LBmuj4SNLPfIny3KsKFopxRdj10=
6568
gopkg.in/yaml.v2 v2.2.8/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI=
69+
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
70+
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
6671
nhooyr.io/websocket v1.8.7 h1:usjR2uOr/zjjkVMy0lW+PPohFok7PCow5sDjLgX4P4g=
6772
nhooyr.io/websocket v1.8.7/go.mod h1:B70DZP8IakI65RVQ51MsWP/8jndNma26DVA/nFSCgW0=

lemur.go

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,26 @@ import (
44
"context"
55
)
66

7+
const (
8+
// LeMUR Default is best at complex reasoning. It offers more nuanced
9+
// responses and improved contextual comprehension.
10+
LeMURModelDefault LeMURModel = "default"
11+
12+
// LeMUR Basic is a simplified model optimized for speed and cost. LeMUR
13+
// Basic can complete requests up to 20% faster than Default.
14+
LeMURModelBasic LeMURModel = "basic"
15+
16+
// LeMUR Mistral 7B is an LLM self-hosted by AssemblyAI. It's the fastest
17+
// and cheapest of the LLM options. We recommend it for use cases like basic
18+
// summaries and factual Q&A.
19+
LeMURModelAssemblyAIMistral7B LeMURModel = "assemblyai/mistral-7b"
20+
21+
// Claude 2.1 is similar to Default, with key improvements: it minimizes
22+
// model hallucination and system prompts, has a larger context window, and
23+
// performs better in citations.
24+
LeMURModelAnthropicClaude2_1 LeMURModel = "anthropic/claude-2-1"
25+
)
26+
727
// LeMURService groups the operations related to LeMUR.
828
type LeMURService struct {
929
client *Client

realtime.go

Lines changed: 90 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"context"
55
"encoding/json"
66
"errors"
7-
"fmt"
87
"net/http"
98
"net/url"
109
"strconv"
@@ -15,10 +14,12 @@ import (
1514
)
1615

1716
var (
18-
// ErrSessionClosed is returned when attempting to write to a closed session.
17+
// ErrSessionClosed is returned when attempting to write to a closed
18+
// session.
1919
ErrSessionClosed = errors.New("session closed")
2020

21-
// ErrDisconnected is returned when attempting to write to a disconnected client.
21+
// ErrDisconnected is returned when attempting to write to a disconnected
22+
// client.
2223
ErrDisconnected = errors.New("client is disconnected")
2324
)
2425

@@ -72,8 +73,10 @@ type RealTimeBaseTranscript struct {
7273
// The partial transcript for your audio
7374
Text string `json:"text"`
7475

75-
// An array of objects, with the information for each word in the transcription text.
76-
// Includes the start and end time of the word in milliseconds, the confidence score of the word, and the text, which is the word itself.
76+
// An array of objects, with the information for each word in the
77+
// transcription text. Includes the start and end time of the word in
78+
// milliseconds, the confidence score of the word, and the text, which is
79+
// the word itself.
7780
Words []Word `json:"words"`
7881
}
7982

@@ -116,8 +119,10 @@ var DefaultSampleRate = 16_000
116119
type RealTimeClient struct {
117120
baseURL *url.URL
118121
apiKey string
122+
token string
119123

120-
conn *websocket.Conn
124+
conn *websocket.Conn
125+
httpClient *http.Client
121126

122127
mtx sync.RWMutex
123128
sessionOpen bool
@@ -126,6 +131,10 @@ type RealTimeClient struct {
126131
done chan bool
127132

128133
handler RealTimeHandler
134+
135+
sampleRate int
136+
encoding RealTimeEncoding
137+
wordBoost []string
129138
}
130139

131140
func (c *RealTimeClient) isSessionOpen() bool {
@@ -148,7 +157,8 @@ type RealTimeError struct {
148157

149158
type RealTimeClientOption func(*RealTimeClient)
150159

151-
// WithRealTimeBaseURL sets the API endpoint used by the client. Mainly used for testing.
160+
// WithRealTimeBaseURL sets the API endpoint used by the client. Mainly used for
161+
// testing.
152162
func WithRealTimeBaseURL(rawurl string) RealTimeClientOption {
153163
return func(c *RealTimeClient) {
154164
if u, err := url.Parse(rawurl); err == nil {
@@ -157,12 +167,22 @@ func WithRealTimeBaseURL(rawurl string) RealTimeClientOption {
157167
}
158168
}
159169

170+
// WithRealTimeAPIKey configures the client to authenticate using an
171+
// AssemblyAI API key.
160172
func WithRealTimeAPIKey(apiKey string) RealTimeClientOption {
161173
return func(rtc *RealTimeClient) {
162174
rtc.apiKey = apiKey
163175
}
164176
}
165177

178+
// WithRealTimeAuthToken configures the client to authenticate using a temporary
179+
// token generated using [CreateTemporaryToken].
180+
func WithRealTimeAuthToken(token string) RealTimeClientOption {
181+
return func(rtc *RealTimeClient) {
182+
rtc.token = token
183+
}
184+
}
185+
166186
func WithHandler(handler RealTimeHandler) RealTimeClientOption {
167187
return func(rtc *RealTimeClient) {
168188
rtc.handler = handler
@@ -171,24 +191,13 @@ func WithHandler(handler RealTimeHandler) RealTimeClientOption {
171191

172192
func WithRealTimeSampleRate(sampleRate int) RealTimeClientOption {
173193
return func(rtc *RealTimeClient) {
174-
if sampleRate > 0 {
175-
vs := rtc.baseURL.Query()
176-
vs.Set("sample_rate", strconv.Itoa(sampleRate))
177-
rtc.baseURL.RawQuery = vs.Encode()
178-
}
194+
rtc.sampleRate = sampleRate
179195
}
180196
}
181197

182198
func WithRealTimeWordBoost(wordBoost []string) RealTimeClientOption {
183199
return func(rtc *RealTimeClient) {
184-
vs := rtc.baseURL.Query()
185-
186-
if len(wordBoost) > 0 {
187-
b, _ := json.Marshal(wordBoost)
188-
vs.Set("word_boost", string(b))
189-
}
190-
191-
rtc.baseURL.RawQuery = vs.Encode()
200+
rtc.wordBoost = wordBoost
192201
}
193202
}
194203

@@ -205,26 +214,26 @@ const (
205214

206215
func WithRealTimeEncoding(encoding RealTimeEncoding) RealTimeClientOption {
207216
return func(rtc *RealTimeClient) {
208-
vs := rtc.baseURL.Query()
209-
vs.Set("encoding", string(encoding))
210-
rtc.baseURL.RawQuery = vs.Encode()
217+
rtc.encoding = encoding
211218
}
212219
}
213220

214221
func NewRealTimeClientWithOptions(options ...RealTimeClientOption) *RealTimeClient {
215222
client := &RealTimeClient{
216223
baseURL: &url.URL{
217-
Scheme: "wss",
218-
Host: "api.assemblyai.com",
219-
Path: "/v2/realtime/ws",
220-
RawQuery: fmt.Sprintf("sample_rate=%v", DefaultSampleRate),
224+
Scheme: "wss",
225+
Host: "api.assemblyai.com",
226+
Path: "/v2/realtime/ws",
221227
},
228+
httpClient: &http.Client{},
222229
}
223230

224231
for _, option := range options {
225232
option(client)
226233
}
227234

235+
client.baseURL.RawQuery = client.queryFromOptions()
236+
228237
return client
229238
}
230239

@@ -261,7 +270,6 @@ func NewRealTimeClient(apiKey string, handler RealTimeHandler) *RealTimeClient {
261270
// Closes any open WebSocket connection in case of errors.
262271
func (c *RealTimeClient) Connect(ctx context.Context) error {
263272
header := make(http.Header)
264-
header.Set("Authorization", c.apiKey)
265273

266274
opts := &websocket.DialOptions{
267275
HTTPHeader: header,
@@ -360,6 +368,33 @@ func (c *RealTimeClient) Connect(ctx context.Context) error {
360368
return nil
361369
}
362370

371+
func (c *RealTimeClient) queryFromOptions() string {
372+
values := url.Values{}
373+
374+
// Temporary token
375+
if c.token != "" {
376+
values.Set("token", c.token)
377+
}
378+
379+
// Sample rate
380+
if c.sampleRate > 0 {
381+
values.Set("sample_rate", strconv.Itoa(c.sampleRate))
382+
}
383+
384+
// Encoding
385+
if c.encoding != "" {
386+
values.Set("encoding", string(c.encoding))
387+
}
388+
389+
// Word boost
390+
if len(c.wordBoost) > 0 {
391+
b, _ := json.Marshal(c.wordBoost)
392+
values.Set("word_boost", string(b))
393+
}
394+
395+
return values.Encode()
396+
}
397+
363398
// Disconnect sends the terminate_session message and waits for the server to
364399
// send a SessionTerminated message before closing the connection.
365400
func (c *RealTimeClient) Disconnect(ctx context.Context, waitForSessionTermination bool) error {
@@ -405,3 +440,30 @@ func (c *RealTimeClient) SetEndUtteranceSilenceThreshold(ctx context.Context, th
405440
EndUtteranceSilenceThreshold: threshold,
406441
})
407442
}
443+
444+
// RealTimeService groups operations related to real-time transcription.
445+
type RealTimeService struct {
446+
client *Client
447+
}
448+
449+
// CreateTemporaryToken creates a temporary token that can be used to
450+
// authenticate a real-time client.
451+
func (svc *RealTimeService) CreateTemporaryToken(ctx context.Context, expiresIn int64) (*RealtimeTemporaryTokenResponse, error) {
452+
params := &CreateRealtimeTemporaryTokenParams{
453+
ExpiresIn: Int64(expiresIn),
454+
}
455+
456+
req, err := svc.client.newJSONRequest("POST", "/v2/realtime/token", params)
457+
if err != nil {
458+
return nil, err
459+
}
460+
461+
var tokenResponse RealtimeTemporaryTokenResponse
462+
resp, err := svc.client.do(ctx, req, &tokenResponse)
463+
if err != nil {
464+
return nil, err
465+
}
466+
defer resp.Body.Close()
467+
468+
return &tokenResponse, nil
469+
}

0 commit comments

Comments
 (0)