Skip to content

Commit c88dd36

Browse files
authored
feat(realtime): pull access token mechanism (#615)
* feat(realtime): pull access token mechanism * add tests * report issue if custom access token is assigned * add docs * pull token from third party or fallback to auth
1 parent 2824f14 commit c88dd36

File tree

5 files changed

+99
-45
lines changed

5 files changed

+99
-45
lines changed

Sources/Realtime/V2/RealtimeChannelV2.swift

+9-4
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ struct Socket: Sendable {
2929
var broadcastURL: @Sendable () -> URL
3030
var status: @Sendable () -> RealtimeClientStatus
3131
var options: @Sendable () -> RealtimeClientOptions
32-
var accessToken: @Sendable () -> String?
32+
var accessToken: @Sendable () async -> String?
3333
var apiKey: @Sendable () -> String?
3434
var makeRef: @Sendable () -> Int
3535

@@ -46,7 +46,12 @@ extension Socket {
4646
broadcastURL: { [weak client] in client?.broadcastURL ?? URL(string: "http://localhost")! },
4747
status: { [weak client] in client?.status ?? .disconnected },
4848
options: { [weak client] in client?.options ?? .init() },
49-
accessToken: { [weak client] in client?.mutableState.accessToken },
49+
accessToken: { [weak client] in
50+
if let accessToken = try? await client?.options.accessToken?() {
51+
return accessToken
52+
}
53+
return client?.mutableState.accessToken
54+
},
5055
apiKey: { [weak client] in client?.apikey },
5156
makeRef: { [weak client] in client?.makeRef() ?? 0 },
5257
connect: { [weak client] in await client?.connect() },
@@ -139,7 +144,7 @@ public final class RealtimeChannelV2: Sendable {
139144

140145
let payload = RealtimeJoinPayload(
141146
config: joinConfig,
142-
accessToken: socket.accessToken()
147+
accessToken: await socket.accessToken()
143148
)
144149

145150
let joinRef = socket.makeRef().description
@@ -213,7 +218,7 @@ public final class RealtimeChannelV2: Sendable {
213218
if let apiKey = socket.apiKey() {
214219
headers[.apiKey] = apiKey
215220
}
216-
if let accessToken = socket.accessToken() {
221+
if let accessToken = await socket.accessToken() {
217222
headers[.authorization] = "Bearer \(accessToken)"
218223
}
219224

Sources/Realtime/V2/RealtimeClientV2.swift

+22-4
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,11 @@ public final class RealtimeClientV2: Sendable {
107107
apikey = options.apikey
108108

109109
mutableState.withValue {
110-
$0.accessToken = options.accessToken ?? options.apikey
110+
if let accessToken = options.headers[.authorization]?.split(separator: " ").last {
111+
$0.accessToken = String(accessToken)
112+
} else {
113+
$0.accessToken = options.apikey
114+
}
111115
}
112116
}
113117

@@ -361,8 +365,22 @@ public final class RealtimeClientV2: Sendable {
361365
}
362366

363367
/// Sets the JWT access token used for channel subscription authorization and Realtime RLS.
364-
/// - Parameter token: A JWT string.
365-
public func setAuth(_ token: String?) async {
368+
///
369+
/// If `token` is nil it will use the ``RealtimeClientOptions/accessToken`` callback function or the token set on the client.
370+
///
371+
/// On callback used, it will set the value of the token internal to the client.
372+
/// - Parameter token: A JWT string to override the token set on the client.
373+
public func setAuth(_ token: String? = nil) async {
374+
var token = token
375+
376+
if token == nil {
377+
token = try? await options.accessToken?()
378+
}
379+
380+
if token == nil {
381+
token = mutableState.accessToken
382+
}
383+
366384
if let token, let payload = JWT.decodePayload(token),
367385
let exp = payload["exp"] as? TimeInterval, exp < Date().timeIntervalSince1970
368386
{
@@ -371,7 +389,7 @@ public final class RealtimeClientV2: Sendable {
371389
return
372390
}
373391

374-
mutableState.withValue {
392+
mutableState.withValue { [token] in
375393
$0.accessToken = token
376394
}
377395

Sources/Realtime/V2/Types.swift

+4-8
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@
66
//
77

88
import Foundation
9-
import Helpers
109
import HTTPTypes
10+
import Helpers
1111

1212
#if canImport(FoundationNetworking)
1313
import FoundationNetworking
@@ -22,6 +22,7 @@ public struct RealtimeClientOptions: Sendable {
2222
var disconnectOnSessionLoss: Bool
2323
var connectOnSubscribe: Bool
2424
var fetch: (@Sendable (_ request: URLRequest) async throws -> (Data, URLResponse))?
25+
package var accessToken: (@Sendable () async throws -> String)?
2526
package var logger: (any SupabaseLogger)?
2627

2728
public static let defaultHeartbeatInterval: TimeInterval = 15
@@ -38,6 +39,7 @@ public struct RealtimeClientOptions: Sendable {
3839
disconnectOnSessionLoss: Bool = Self.defaultDisconnectOnSessionLoss,
3940
connectOnSubscribe: Bool = Self.defaultConnectOnSubscribe,
4041
fetch: (@Sendable (_ request: URLRequest) async throws -> (Data, URLResponse))? = nil,
42+
accessToken: (@Sendable () async throws -> String)? = nil,
4143
logger: (any SupabaseLogger)? = nil
4244
) {
4345
self.headers = HTTPFields(headers)
@@ -47,19 +49,13 @@ public struct RealtimeClientOptions: Sendable {
4749
self.disconnectOnSessionLoss = disconnectOnSessionLoss
4850
self.connectOnSubscribe = connectOnSubscribe
4951
self.fetch = fetch
52+
self.accessToken = accessToken
5053
self.logger = logger
5154
}
5255

5356
var apikey: String? {
5457
headers[.apiKey]
5558
}
56-
57-
var accessToken: String? {
58-
guard let accessToken = headers[.authorization]?.split(separator: " ").last else {
59-
return nil
60-
}
61-
return String(accessToken)
62-
}
6359
}
6460

6561
public typealias RealtimeSubscription = ObservationToken

Sources/Supabase/SupabaseClient.swift

+56-24
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
import ConcurrencyExtras
33
import Foundation
44
@_exported import Functions
5+
import HTTPTypes
56
import Helpers
67
import IssueReporting
78
@_exported import PostgREST
89
@_exported import Realtime
910
@_exported import Storage
10-
import HTTPTypes
1111

1212
#if canImport(FoundationNetworking)
1313
import FoundationNetworking
@@ -33,10 +33,11 @@ public final class SupabaseClient: Sendable {
3333
/// Supabase Auth allows you to create and manage user sessions for access to data that is secured by access policies.
3434
public var auth: AuthClient {
3535
if options.auth.accessToken != nil {
36-
reportIssue("""
37-
Supabase Client is configured with the auth.accessToken option,
38-
accessing supabase.auth is not possible.
39-
""")
36+
reportIssue(
37+
"""
38+
Supabase Client is configured with the auth.accessToken option,
39+
accessing supabase.auth is not possible.
40+
""")
4041
}
4142
return _auth
4243
}
@@ -80,7 +81,14 @@ public final class SupabaseClient: Sendable {
8081
let _realtime: UncheckedSendable<RealtimeClient>
8182

8283
/// Realtime client for Supabase
83-
public let realtimeV2: RealtimeClientV2
84+
public var realtimeV2: RealtimeClientV2 {
85+
mutableState.withValue {
86+
if $0.realtime == nil {
87+
$0.realtime = _initRealtimeClient()
88+
}
89+
return $0.realtime!
90+
}
91+
}
8492

8593
/// Supabase Functions allows you to deploy and invoke edge functions.
8694
public var functions: FunctionsClient {
@@ -112,6 +120,7 @@ public final class SupabaseClient: Sendable {
112120
var storage: SupabaseStorageClient?
113121
var rest: PostgrestClient?
114122
var functions: FunctionsClient?
123+
var realtime: RealtimeClientV2?
115124

116125
var changedAccessToken: String?
117126
}
@@ -189,18 +198,6 @@ public final class SupabaseClient: Sendable {
189198
)
190199
)
191200

192-
var realtimeOptions = options.realtime
193-
realtimeOptions.headers.merge(with: _headers)
194-
195-
if realtimeOptions.logger == nil {
196-
realtimeOptions.logger = options.global.logger
197-
}
198-
199-
realtimeV2 = RealtimeClientV2(
200-
url: supabaseURL.appendingPathComponent("/realtime/v1"),
201-
options: realtimeOptions
202-
)
203-
204201
if options.auth.accessToken == nil {
205202
listenForAuthEvents()
206203
}
@@ -351,11 +348,7 @@ public final class SupabaseClient: Sendable {
351348
}
352349

353350
private func adapt(request: URLRequest) async -> URLRequest {
354-
let token: String? = if let accessToken = options.auth.accessToken {
355-
try? await accessToken()
356-
} else {
357-
try? await auth.session.accessToken
358-
}
351+
let token = try? await _getAccessToken()
359352

360353
var request = request
361354
if let token {
@@ -364,6 +357,14 @@ public final class SupabaseClient: Sendable {
364357
return request
365358
}
366359

360+
private func _getAccessToken() async throws -> String {
361+
if let accessToken = options.auth.accessToken {
362+
try await accessToken()
363+
} else {
364+
try await auth.session.accessToken
365+
}
366+
}
367+
367368
private func listenForAuthEvents() {
368369
let task = Task {
369370
for await (event, session) in auth.authStateChanges {
@@ -377,7 +378,9 @@ public final class SupabaseClient: Sendable {
377378

378379
private func handleTokenChanged(event: AuthChangeEvent, session: Session?) async {
379380
let accessToken: String? = mutableState.withValue {
380-
if [.initialSession, .signedIn, .tokenRefreshed].contains(event), $0.changedAccessToken != session?.accessToken {
381+
if [.initialSession, .signedIn, .tokenRefreshed].contains(event),
382+
$0.changedAccessToken != session?.accessToken
383+
{
381384
$0.changedAccessToken = session?.accessToken
382385
return session?.accessToken ?? supabaseKey
383386
}
@@ -393,4 +396,33 @@ public final class SupabaseClient: Sendable {
393396
realtime.setAuth(accessToken)
394397
await realtimeV2.setAuth(accessToken)
395398
}
399+
400+
private func _initRealtimeClient() -> RealtimeClientV2 {
401+
var realtimeOptions = options.realtime
402+
realtimeOptions.headers.merge(with: _headers)
403+
404+
if realtimeOptions.logger == nil {
405+
realtimeOptions.logger = options.global.logger
406+
}
407+
408+
if realtimeOptions.accessToken == nil {
409+
realtimeOptions.accessToken = { [weak self] in
410+
try await self?._getAccessToken() ?? ""
411+
}
412+
} else {
413+
reportIssue(
414+
"""
415+
You assigned a custom `accessToken` closure to the RealtimeClientV2. This might not work as you expect
416+
as SupabaseClient uses Auth for pulling an access token to send on the realtime channels.
417+
418+
Please make sure you know what you're doing.
419+
"""
420+
)
421+
}
422+
423+
return RealtimeClientV2(
424+
url: supabaseURL.appendingPathComponent("/realtime/v1"),
425+
options: realtimeOptions
426+
)
427+
}
396428
}

Tests/RealtimeTests/RealtimeTests.swift

+8-5
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@ final class RealtimeTests: XCTestCase {
3636
headers: ["apikey": apiKey],
3737
heartbeatInterval: 1,
3838
reconnectDelay: 1,
39-
timeoutInterval: 2
39+
timeoutInterval: 2,
40+
accessToken: {
41+
"custom.access.token"
42+
}
4043
),
4144
ws: ws,
4245
http: http
@@ -100,7 +103,7 @@ final class RealtimeTests: XCTestCase {
100103
"event" : "phx_join",
101104
"join_ref" : "1",
102105
"payload" : {
103-
"access_token" : "anon.api.key",
106+
"access_token" : "custom.access.token",
104107
"config" : {
105108
"broadcast" : {
106109
"ack" : false,
@@ -179,7 +182,7 @@ final class RealtimeTests: XCTestCase {
179182
"event" : "phx_join",
180183
"join_ref" : "1",
181184
"payload" : {
182-
"access_token" : "anon.api.key",
185+
"access_token" : "custom.access.token",
183186
"config" : {
184187
"broadcast" : {
185188
"ack" : false,
@@ -201,7 +204,7 @@ final class RealtimeTests: XCTestCase {
201204
"event" : "phx_join",
202205
"join_ref" : "2",
203206
"payload" : {
204-
"access_token" : "anon.api.key",
207+
"access_token" : "custom.access.token",
205208
"config" : {
206209
"broadcast" : {
207210
"ack" : false,
@@ -322,7 +325,7 @@ final class RealtimeTests: XCTestCase {
322325
assertInlineSnapshot(of: request?.urlRequest, as: .raw(pretty: true)) {
323326
"""
324327
POST https://localhost:54321/realtime/v1/api/broadcast
325-
Authorization: Bearer anon.api.key
328+
Authorization: Bearer custom.access.token
326329
Content-Type: application/json
327330
apiKey: anon.api.key
328331

0 commit comments

Comments
 (0)