fix: preserve mobile bootstrap auth fallback (#60238) (thanks @ngutman)

This commit is contained in:
Peter Steinberger
2026-04-04 15:55:24 +09:00
parent 226ca1f324
commit 20266ff7dd
7 changed files with 318 additions and 13 deletions

View File

@@ -418,11 +418,63 @@ class GatewaySession(
}
throw GatewayConnectFailure(error)
}
handleConnectSuccess(res, identity.deviceId)
handleConnectSuccess(res, identity.deviceId, selectedAuth.authSource)
connectDeferred.complete(Unit)
}
private fun handleConnectSuccess(res: RpcResponse, deviceId: String) {
private fun shouldPersistBootstrapHandoffTokens(authSource: GatewayConnectAuthSource): Boolean {
if (authSource != GatewayConnectAuthSource.BOOTSTRAP_TOKEN) return false
if (isLoopbackGatewayHost(endpoint.host)) return true
return tls != null
}
private fun filteredBootstrapHandoffScopes(role: String, scopes: List<String>): List<String>? {
return when (role.trim()) {
"node" -> emptyList()
"operator" -> {
val allowedOperatorScopes =
setOf(
"operator.approvals",
"operator.read",
"operator.talk.secrets",
"operator.write",
)
scopes.filter { allowedOperatorScopes.contains(it) }.distinct().sorted()
}
else -> null
}
}
private fun persistBootstrapHandoffToken(
deviceId: String,
role: String,
token: String,
scopes: List<String>,
) {
if (filteredBootstrapHandoffScopes(role, scopes) == null) return
deviceAuthStore.saveToken(deviceId, role, token)
}
private fun persistIssuedDeviceToken(
authSource: GatewayConnectAuthSource,
deviceId: String,
role: String,
token: String,
scopes: List<String>,
) {
if (authSource == GatewayConnectAuthSource.BOOTSTRAP_TOKEN) {
if (!shouldPersistBootstrapHandoffTokens(authSource)) return
persistBootstrapHandoffToken(deviceId, role, token, scopes)
return
}
deviceAuthStore.saveToken(deviceId, role, token)
}
private fun handleConnectSuccess(
res: RpcResponse,
deviceId: String,
authSource: GatewayConnectAuthSource,
) {
val payloadJson = res.payloadJson ?: throw IllegalStateException("connect failed: missing payload")
val obj = json.parseToJsonElement(payloadJson).asObjectOrNull() ?: throw IllegalStateException("connect failed")
pendingDeviceTokenRetry = false
@@ -432,8 +484,27 @@ class GatewaySession(
val authObj = obj["auth"].asObjectOrNull()
val deviceToken = authObj?.get("deviceToken").asStringOrNull()
val authRole = authObj?.get("role").asStringOrNull() ?: options.role
val authScopes =
authObj?.get("scopes").asArrayOrNull()
?.mapNotNull { it.asStringOrNull() }
?: emptyList()
if (!deviceToken.isNullOrBlank()) {
deviceAuthStore.saveToken(deviceId, authRole, deviceToken)
persistIssuedDeviceToken(authSource, deviceId, authRole, deviceToken, authScopes)
}
if (shouldPersistBootstrapHandoffTokens(authSource)) {
authObj?.get("deviceTokens").asArrayOrNull()
?.mapNotNull { it.asObjectOrNull() }
?.forEach { tokenEntry ->
val handoffToken = tokenEntry["deviceToken"].asStringOrNull()
val handoffRole = tokenEntry["role"].asStringOrNull()
val handoffScopes =
tokenEntry["scopes"].asArrayOrNull()
?.mapNotNull { it.asStringOrNull() }
?: emptyList()
if (!handoffToken.isNullOrBlank() && !handoffRole.isNullOrBlank()) {
persistBootstrapHandoffToken(deviceId, handoffRole, handoffToken, handoffScopes)
}
}
}
val rawCanvas = obj["canvasHostUrl"].asStringOrNull()
canvasHostUrl = normalizeCanvasHostUrl(rawCanvas, endpoint, isTlsConnection = tls != null)
@@ -899,6 +970,8 @@ private fun formatGatewayAuthorityHost(host: String): String {
private fun JsonElement?.asObjectOrNull(): JsonObject? = this as? JsonObject
private fun JsonElement?.asArrayOrNull(): JsonArray? = this as? JsonArray
private fun JsonElement?.asStringOrNull(): String? =
when (this) {
is JsonNull -> null

View File

@@ -213,6 +213,137 @@ class GatewaySessionInvokeTest {
}
}
@Test
fun connect_storesPrimaryDeviceTokenFromSuccessfulSharedTokenConnect() = runBlocking {
val json = testJson()
val connected = CompletableDeferred<Unit>()
val lastDisconnect = AtomicReference("")
val server =
startGatewayServer(json) { webSocket, id, method, _ ->
when (method) {
"connect" -> {
webSocket.send(
connectResponseFrame(
id,
authJson = """{"deviceToken":"shared-node-token","role":"node","scopes":[]}""",
),
)
webSocket.close(1000, "done")
}
}
}
val harness =
createNodeHarness(
connected = connected,
lastDisconnect = lastDisconnect,
) { GatewaySession.InvokeResult.ok("""{"handled":true}""") }
try {
connectNodeSession(
session = harness.session,
port = server.port,
token = "shared-auth-token",
bootstrapToken = null,
)
awaitConnectedOrThrow(connected, lastDisconnect, server)
val deviceId = DeviceIdentityStore(RuntimeEnvironment.getApplication()).loadOrCreate().deviceId
assertEquals("shared-node-token", harness.deviceAuthStore.loadToken(deviceId, "node"))
assertNull(harness.deviceAuthStore.loadToken(deviceId, "operator"))
} finally {
shutdownHarness(harness, server)
}
}
@Test
fun bootstrapConnect_storesAdditionalBoundedDeviceTokensOnTrustedTransport() = runBlocking {
val json = testJson()
val connected = CompletableDeferred<Unit>()
val lastDisconnect = AtomicReference("")
val server =
startGatewayServer(json) { webSocket, id, method, _ ->
when (method) {
"connect" -> {
webSocket.send(
connectResponseFrame(
id,
authJson =
"""{"deviceToken":"bootstrap-node-token","role":"node","scopes":[],"deviceTokens":[{"deviceToken":"bootstrap-operator-token","role":"operator","scopes":["operator.admin","operator.approvals","operator.read","operator.talk.secrets","operator.write"]}]}""",
),
)
webSocket.close(1000, "done")
}
}
}
val harness =
createNodeHarness(
connected = connected,
lastDisconnect = lastDisconnect,
) { GatewaySession.InvokeResult.ok("""{"handled":true}""") }
try {
connectNodeSession(
session = harness.session,
port = server.port,
token = null,
bootstrapToken = "bootstrap-token",
)
awaitConnectedOrThrow(connected, lastDisconnect, server)
val deviceId = DeviceIdentityStore(RuntimeEnvironment.getApplication()).loadOrCreate().deviceId
assertEquals("bootstrap-node-token", harness.deviceAuthStore.loadToken(deviceId, "node"))
assertEquals("bootstrap-operator-token", harness.deviceAuthStore.loadToken(deviceId, "operator"))
} finally {
shutdownHarness(harness, server)
}
}
@Test
fun nonBootstrapConnect_ignoresAdditionalBootstrapDeviceTokens() = runBlocking {
val json = testJson()
val connected = CompletableDeferred<Unit>()
val lastDisconnect = AtomicReference("")
val server =
startGatewayServer(json) { webSocket, id, method, _ ->
when (method) {
"connect" -> {
webSocket.send(
connectResponseFrame(
id,
authJson =
"""{"deviceToken":"shared-node-token","role":"node","scopes":[],"deviceTokens":[{"deviceToken":"shared-operator-token","role":"operator","scopes":["operator.approvals","operator.read"]}]}""",
),
)
webSocket.close(1000, "done")
}
}
}
val harness =
createNodeHarness(
connected = connected,
lastDisconnect = lastDisconnect,
) { GatewaySession.InvokeResult.ok("""{"handled":true}""") }
try {
connectNodeSession(
session = harness.session,
port = server.port,
token = "shared-auth-token",
bootstrapToken = null,
)
awaitConnectedOrThrow(connected, lastDisconnect, server)
val deviceId = DeviceIdentityStore(RuntimeEnvironment.getApplication()).loadOrCreate().deviceId
assertEquals("shared-node-token", harness.deviceAuthStore.loadToken(deviceId, "node"))
assertNull(harness.deviceAuthStore.loadToken(deviceId, "operator"))
} finally {
shutdownHarness(harness, server)
}
}
@Test
fun nodeInvokeRequest_roundTripsInvokeResult() = runBlocking {
val handshakeOrigin = AtomicReference<String?>(null)
@@ -470,9 +601,14 @@ class GatewaySessionInvokeTest {
}
}
private fun connectResponseFrame(id: String, canvasHostUrl: String? = null): String {
private fun connectResponseFrame(
id: String,
canvasHostUrl: String? = null,
authJson: String? = null,
): String {
val canvas = canvasHostUrl?.let { "\"canvasHostUrl\":\"$it\"," } ?: ""
return """{"type":"res","id":"$id","ok":true,"payload":{$canvas"snapshot":{"sessionDefaults":{"mainSessionKey":"main"}}}}"""
val auth = authJson?.let { "\"auth\":$it," } ?: ""
return """{"type":"res","id":"$id","ok":true,"payload":{$canvas$auth"snapshot":{"sessionDefaults":{"mainSessionKey":"main"}}}}"""
}
private fun startGatewayServer(