Implement Graphql-WS Subscriptions (#630)

* Implement graphql-ws subscriptions

* Fix subscription payload issue

* Close session directly

* Improve id handling
This commit is contained in:
Mitchell Syer
2023-08-03 22:28:28 -04:00
committed by GitHub
parent 06d7a6d892
commit 92f494d0fe
4 changed files with 58 additions and 122 deletions

View File

@@ -46,4 +46,4 @@ class InfoQuery {
}
}
}
}
}

View File

@@ -13,30 +13,27 @@ import com.fasterxml.jackson.module.kotlin.convertValue
import com.fasterxml.jackson.module.kotlin.readValue
import io.javalin.websocket.WsContext
import io.javalin.websocket.WsMessageContext
import kotlinx.coroutines.FlowPreview
import kotlinx.coroutines.currentCoroutineContext
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.catch
import kotlinx.coroutines.flow.emitAll
import kotlinx.coroutines.flow.emptyFlow
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.flow.onCompletion
import kotlinx.coroutines.flow.onStart
import kotlinx.coroutines.flow.sample
import kotlinx.coroutines.job
import kotlinx.coroutines.runBlocking
import mu.KotlinLogging
import org.eclipse.jetty.websocket.api.CloseStatus
import suwayomi.tachidesk.graphql.server.TachideskGraphQLContextFactory
import suwayomi.tachidesk.graphql.server.subscriptions.SubscriptionOperationMessage.ClientMessages.GQL_CONNECTION_INIT
import suwayomi.tachidesk.graphql.server.subscriptions.SubscriptionOperationMessage.ClientMessages.GQL_CONNECTION_TERMINATE
import suwayomi.tachidesk.graphql.server.subscriptions.SubscriptionOperationMessage.ClientMessages.GQL_START
import suwayomi.tachidesk.graphql.server.subscriptions.SubscriptionOperationMessage.ClientMessages.GQL_STOP
import suwayomi.tachidesk.graphql.server.subscriptions.SubscriptionOperationMessage.ClientMessages.GQL_SUBSCRIBE
import suwayomi.tachidesk.graphql.server.subscriptions.SubscriptionOperationMessage.CommonMessages.GQL_COMPLETE
import suwayomi.tachidesk.graphql.server.subscriptions.SubscriptionOperationMessage.CommonMessages.GQL_PING
import suwayomi.tachidesk.graphql.server.subscriptions.SubscriptionOperationMessage.CommonMessages.GQL_PONG
import suwayomi.tachidesk.graphql.server.subscriptions.SubscriptionOperationMessage.ServerMessages.GQL_CONNECTION_ACK
import suwayomi.tachidesk.graphql.server.subscriptions.SubscriptionOperationMessage.ServerMessages.GQL_CONNECTION_ERROR
import suwayomi.tachidesk.graphql.server.subscriptions.SubscriptionOperationMessage.ServerMessages.GQL_CONNECTION_KEEP_ALIVE
import suwayomi.tachidesk.graphql.server.subscriptions.SubscriptionOperationMessage.ServerMessages.GQL_DATA
import suwayomi.tachidesk.graphql.server.subscriptions.SubscriptionOperationMessage.ServerMessages.GQL_ERROR
import suwayomi.tachidesk.graphql.server.subscriptions.SubscriptionOperationMessage.ServerMessages.GQL_NEXT
import suwayomi.tachidesk.graphql.server.toGraphQLContext
/**
@@ -51,8 +48,8 @@ class ApolloSubscriptionProtocolHandler(
) {
private val sessionState = ApolloSubscriptionSessionState()
private val logger = KotlinLogging.logger {}
private val keepAliveMessage = SubscriptionOperationMessage(type = GQL_CONNECTION_KEEP_ALIVE.type)
private val basicConnectionErrorMessage = SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type)
private val pongMessage = SubscriptionOperationMessage(type = GQL_PONG.type)
private val basicConnectionErrorMessage = SubscriptionOperationMessage(type = GQL_ERROR.type)
private val acknowledgeMessage = SubscriptionOperationMessage(GQL_CONNECTION_ACK.type)
fun handleMessage(context: WsMessageContext): Flow<SubscriptionOperationMessage> {
@@ -62,9 +59,10 @@ class ApolloSubscriptionProtocolHandler(
return try {
when (operationMessage.type) {
GQL_CONNECTION_INIT.type -> onInit(operationMessage, context)
GQL_START.type -> startSubscription(operationMessage, context)
GQL_STOP.type -> onStop(operationMessage, context)
GQL_CONNECTION_TERMINATE.type -> onDisconnect(context)
GQL_SUBSCRIBE.type -> startSubscription(operationMessage, context)
GQL_COMPLETE.type -> onComplete(operationMessage, context)
GQL_PING.type -> onPing()
GQL_PONG.type -> emptyFlow()
else -> onUnknownOperation(operationMessage, context)
}
} catch (exception: Exception) {
@@ -85,46 +83,28 @@ class ApolloSubscriptionProtocolHandler(
}
}
/**
* If the keep alive configuration is set, send a message back to client at every interval until the session is terminated.
* Otherwise just return empty flux to append to the acknowledge message.
*/
@OptIn(FlowPreview::class)
private fun getKeepAliveFlow(context: WsContext): Flow<SubscriptionOperationMessage> {
val keepAliveInterval: Long? = 2000
if (keepAliveInterval != null) {
return flowOf(keepAliveMessage).sample(keepAliveInterval)
.onStart {
sessionState.saveKeepAliveSubscription(context, currentCoroutineContext().job)
}
}
return emptyFlow()
}
@Suppress("Detekt.TooGenericExceptionCaught")
private fun startSubscription(
operationMessage: SubscriptionOperationMessage,
context: WsContext
): Flow<SubscriptionOperationMessage> {
val graphQLContext = sessionState.getGraphQLContext(context)
if (operationMessage.id == null) {
logger.error("GraphQL subscription operation id is required")
return flowOf(basicConnectionErrorMessage)
}
if (sessionState.doesOperationExist(context, operationMessage)) {
if (sessionState.doesOperationExist(operationMessage)) {
sessionState.terminateSession(context, CloseStatus(4409, "Subscriber for ${operationMessage.id} already exists"))
logger.info("Already subscribed to operation ${operationMessage.id} for session ${context.sessionId}")
return emptyFlow()
}
val graphQLContext = sessionState.getGraphQLContext(context)
val payload = operationMessage.payload
if (payload == null) {
logger.error("GraphQL subscription payload was null instead of a GraphQLRequest object")
sessionState.stopOperation(context, operationMessage)
return flowOf(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id))
return flowOf(SubscriptionOperationMessage(type = GQL_ERROR.type, id = operationMessage.id))
}
try {
@@ -134,7 +114,7 @@ class ApolloSubscriptionProtocolHandler(
if (it.errors?.isNotEmpty() == true) {
SubscriptionOperationMessage(type = GQL_ERROR.type, id = operationMessage.id, payload = it)
} else {
SubscriptionOperationMessage(type = GQL_DATA.type, id = operationMessage.id, payload = it)
SubscriptionOperationMessage(type = GQL_NEXT.type, id = operationMessage.id, payload = it)
}
}
.onCompletion { if (it == null) emitAll(onComplete(operationMessage, context)) }
@@ -142,17 +122,14 @@ class ApolloSubscriptionProtocolHandler(
} catch (exception: Exception) {
logger.error("Error running graphql subscription", exception)
// Do not terminate the session, just stop the operation messages
sessionState.stopOperation(context, operationMessage)
return flowOf(SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id))
sessionState.completeOperation(operationMessage)
return flowOf(SubscriptionOperationMessage(type = GQL_ERROR.type, id = operationMessage.id))
}
}
private fun onInit(operationMessage: SubscriptionOperationMessage, context: WsContext): Flow<SubscriptionOperationMessage> {
saveContext(operationMessage, context)
val acknowledgeMessage = flowOf(acknowledgeMessage)
val keepAliveFlux = getKeepAliveFlow(context)
return acknowledgeMessage.onCompletion { if (it == null) emitAll(keepAliveFlux) }
.catch { emit(getConnectionErrorMessage(operationMessage)) }
return flowOf(acknowledgeMessage)
}
/**
@@ -172,36 +149,26 @@ class ApolloSubscriptionProtocolHandler(
operationMessage: SubscriptionOperationMessage,
context: WsContext
): Flow<SubscriptionOperationMessage> {
return sessionState.completeOperation(context, operationMessage)
return sessionState.completeOperation(operationMessage)
}
/**
* Called with the client has called stop manually, or on error, and we need to cancel the publisher
*/
private fun onStop(
operationMessage: SubscriptionOperationMessage,
context: WsContext
): Flow<SubscriptionOperationMessage> {
return sessionState.stopOperation(context, operationMessage)
private fun onPing(): Flow<SubscriptionOperationMessage> {
return flowOf(pongMessage)
}
private fun onDisconnect(context: WsContext): Flow<SubscriptionOperationMessage> {
sessionState.terminateSession(context)
sessionState.terminateSession(context, CloseStatus(1000, "Normal Closure"))
return emptyFlow()
}
private fun onUnknownOperation(operationMessage: SubscriptionOperationMessage, context: WsContext): Flow<SubscriptionOperationMessage> {
logger.error("Unknown subscription operation $operationMessage")
sessionState.stopOperation(context, operationMessage)
return flowOf(getConnectionErrorMessage(operationMessage))
sessionState.completeOperation(operationMessage)
return emptyFlow()
}
private fun onException(exception: Exception): Flow<SubscriptionOperationMessage> {
logger.error("Error parsing the subscription message", exception)
return flowOf(basicConnectionErrorMessage)
}
private fun getConnectionErrorMessage(operationMessage: SubscriptionOperationMessage): SubscriptionOperationMessage {
return SubscriptionOperationMessage(type = GQL_CONNECTION_ERROR.type, id = operationMessage.id)
}
}

View File

@@ -12,23 +12,22 @@ import io.javalin.websocket.WsContext
import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.Flow
import kotlinx.coroutines.flow.emptyFlow
import kotlinx.coroutines.flow.flowOf
import kotlinx.coroutines.flow.onCompletion
import suwayomi.tachidesk.graphql.server.subscriptions.SubscriptionOperationMessage.ServerMessages.GQL_COMPLETE
import org.eclipse.jetty.websocket.api.CloseStatus
import suwayomi.tachidesk.graphql.server.toGraphQLContext
import java.util.concurrent.ConcurrentHashMap
import java.util.concurrent.CopyOnWriteArrayList
internal class ApolloSubscriptionSessionState {
// Sessions are saved by web socket session id
internal val activeKeepAliveSessions = ConcurrentHashMap<String, Job>()
// Operations are saved by web socket session id, then operation id
internal val activeOperations = ConcurrentHashMap<String, ConcurrentHashMap<String, Job>>()
internal val activeOperations = ConcurrentHashMap<String, Job>()
// The graphQL context is saved by web socket session id
private val cachedGraphQLContext = ConcurrentHashMap<String, GraphQLContext>()
private val sessionToOperationId = ConcurrentHashMap<String, CopyOnWriteArrayList<String>>()
/**
* Save the context created from the factory and possibly updated in the onConnect hook.
* This allows us to include some initial state to be used when handling all the messages.
@@ -43,15 +42,6 @@ internal class ApolloSubscriptionSessionState {
*/
fun getGraphQLContext(context: WsContext): GraphQLContext = cachedGraphQLContext[context.sessionId] ?: emptyMap<Any, Any>().toGraphQLContext()
/**
* Save the session that is sending keep alive messages.
* This will override values without cancelling the subscription, so it is the responsibility of the consumer to cancel.
* These messages will be stopped on [terminateSession].
*/
fun saveKeepAliveSubscription(context: WsContext, subscription: Job) {
activeKeepAliveSessions[context.sessionId] = subscription
}
/**
* Save the operation that is sending data to the client.
* This will override values without cancelling the subscription so it is the responsibility of the consumer to cancel.
@@ -60,8 +50,8 @@ internal class ApolloSubscriptionSessionState {
fun saveOperation(context: WsContext, operationMessage: SubscriptionOperationMessage, subscription: Job) {
val id = operationMessage.id
if (id != null) {
val operationsForSession: ConcurrentHashMap<String, Job> = activeOperations.getOrPut(context.sessionId) { ConcurrentHashMap() }
operationsForSession[id] = subscription
activeOperations[id] = subscription
sessionToOperationId.getOrPut(context.sessionId) { CopyOnWriteArrayList() } += id
}
}
@@ -69,60 +59,36 @@ internal class ApolloSubscriptionSessionState {
* Send the [GQL_COMPLETE] message.
* This can happen when the publisher finishes or if the client manually sends the stop message.
*/
fun completeOperation(context: WsContext, operationMessage: SubscriptionOperationMessage): Flow<SubscriptionOperationMessage> {
return getCompleteMessage(operationMessage)
.onCompletion { removeActiveOperation(context, operationMessage.id, cancelSubscription = false) }
fun completeOperation(operationMessage: SubscriptionOperationMessage): Flow<SubscriptionOperationMessage> {
return getCompleteMessage()
.onCompletion { removeActiveOperation(operationMessage.id ?: return@onCompletion) }
}
/**
* Stop the subscription sending data and send the [GQL_COMPLETE] message.
* Does NOT terminate the session.
*/
fun stopOperation(context: WsContext, operationMessage: SubscriptionOperationMessage): Flow<SubscriptionOperationMessage> {
return getCompleteMessage(operationMessage)
.onCompletion { removeActiveOperation(context, operationMessage.id, cancelSubscription = true) }
}
private fun getCompleteMessage(operationMessage: SubscriptionOperationMessage): Flow<SubscriptionOperationMessage> {
val id = operationMessage.id
if (id != null) {
return flowOf(SubscriptionOperationMessage(type = GQL_COMPLETE.type, id = id))
}
private fun getCompleteMessage(): Flow<SubscriptionOperationMessage> {
return emptyFlow()
}
/**
* Remove active running subscription from the cache and cancel if needed
*/
private fun removeActiveOperation(context: WsContext, id: String?, cancelSubscription: Boolean) {
val operationsForSession = activeOperations[context.sessionId]
val subscription = operationsForSession?.get(id)
if (subscription != null) {
if (cancelSubscription) {
subscription.cancel()
}
operationsForSession.remove(id)
if (operationsForSession.isEmpty()) {
activeOperations.remove(context.sessionId)
}
}
private fun removeActiveOperation(id: String) {
activeOperations.remove(id)?.cancel()
}
/**
* Terminate the session, cancelling the keep alive messages and all operations active for this session.
*/
fun terminateSession(context: WsContext) {
activeOperations[context.sessionId]?.forEach { (_, subscription) -> subscription.cancel() }
activeOperations.remove(context.sessionId)
fun terminateSession(context: WsContext, code: CloseStatus) {
sessionToOperationId.remove(context.sessionId)?.forEach {
activeOperations[it]?.cancel()
}
cachedGraphQLContext.remove(context.sessionId)
activeKeepAliveSessions[context.sessionId]?.cancel()
activeKeepAliveSessions.remove(context.sessionId)
context.closeSession()
context.closeSession(code)
}
/**
* Looks up the operation for the client, to check if it already exists
*/
fun doesOperationExist(context: WsContext, operationMessage: SubscriptionOperationMessage): Boolean =
activeOperations[context.sessionId]?.containsKey(operationMessage.id) ?: false
fun doesOperationExist(operationMessage: SubscriptionOperationMessage): Boolean =
activeOperations.containsKey(operationMessage.id)
}

View File

@@ -8,6 +8,7 @@
package suwayomi.tachidesk.graphql.server.subscriptions
import com.fasterxml.jackson.annotation.JsonIgnoreProperties
import com.fasterxml.jackson.annotation.JsonInclude
/**
* The `graphql-ws` protocol from Apollo Client has some special text messages to signal events.
@@ -16,24 +17,26 @@ import com.fasterxml.jackson.annotation.JsonIgnoreProperties
* https://github.com/apollographql/subscriptions-transport-ws/blob/master/PROTOCOL.md
*/
@JsonIgnoreProperties(ignoreUnknown = true)
@JsonInclude(JsonInclude.Include.NON_NULL)
data class SubscriptionOperationMessage(
val type: String,
val id: String? = null,
val payload: Any? = null
) {
enum class CommonMessages(val type: String) {
GQL_PING("ping"),
GQL_PONG("pong"),
GQL_COMPLETE("complete")
}
enum class ClientMessages(val type: String) {
GQL_CONNECTION_INIT("connection_init"),
GQL_START("start"),
GQL_STOP("stop"),
GQL_CONNECTION_TERMINATE("connection_terminate")
GQL_SUBSCRIBE("subscribe")
}
enum class ServerMessages(val type: String) {
GQL_CONNECTION_ACK("connection_ack"),
GQL_CONNECTION_ERROR("connection_error"),
GQL_DATA("data"),
GQL_ERROR("error"),
GQL_COMPLETE("complete"),
GQL_CONNECTION_KEEP_ALIVE("ka")
GQL_NEXT("next"),
GQL_ERROR("error")
}
}