Support Token Expiry properly (#878)

* Support token expiry properly

* Small fix

* Lint

* Use newer fixes for expiry

* Lint
This commit is contained in:
Mitchell Syer
2024-02-19 11:06:00 -05:00
committed by GitHub
parent 6803ac0611
commit 07e011092a
7 changed files with 93 additions and 48 deletions

View File

@@ -67,6 +67,19 @@ class TrackerScoresDataLoader : KotlinDataLoader<Int, List<String>> {
} }
} }
class TrackerTokenExpiredDataLoader : KotlinDataLoader<Int, Boolean> {
override val dataLoaderName = "TrackerTokenExpiredDataLoader"
override fun getDataLoader(): DataLoader<Int, Boolean> =
DataLoaderFactory.newDataLoader { ids ->
future {
ids.map { id ->
TrackerManager.getTracker(id)?.getIfAuthExpired()
}
}
}
}
class TrackRecordsForMangaIdDataLoader : KotlinDataLoader<Int, TrackRecordNodeList> { class TrackRecordsForMangaIdDataLoader : KotlinDataLoader<Int, TrackRecordNodeList> {
override val dataLoaderName = "TrackRecordsForMangaIdDataLoader" override val dataLoaderName = "TrackRecordsForMangaIdDataLoader"

View File

@@ -38,6 +38,7 @@ import suwayomi.tachidesk.graphql.dataLoaders.TrackRecordsForTrackerIdDataLoader
import suwayomi.tachidesk.graphql.dataLoaders.TrackerDataLoader import suwayomi.tachidesk.graphql.dataLoaders.TrackerDataLoader
import suwayomi.tachidesk.graphql.dataLoaders.TrackerScoresDataLoader import suwayomi.tachidesk.graphql.dataLoaders.TrackerScoresDataLoader
import suwayomi.tachidesk.graphql.dataLoaders.TrackerStatusesDataLoader import suwayomi.tachidesk.graphql.dataLoaders.TrackerStatusesDataLoader
import suwayomi.tachidesk.graphql.dataLoaders.TrackerTokenExpiredDataLoader
import suwayomi.tachidesk.graphql.dataLoaders.UnreadChapterCountForMangaDataLoader import suwayomi.tachidesk.graphql.dataLoaders.UnreadChapterCountForMangaDataLoader
class TachideskDataLoaderRegistryFactory { class TachideskDataLoaderRegistryFactory {
@@ -71,6 +72,7 @@ class TachideskDataLoaderRegistryFactory {
TrackerDataLoader(), TrackerDataLoader(),
TrackerStatusesDataLoader(), TrackerStatusesDataLoader(),
TrackerScoresDataLoader(), TrackerScoresDataLoader(),
TrackerTokenExpiredDataLoader(),
TrackRecordsForMangaIdDataLoader(), TrackRecordsForMangaIdDataLoader(),
DisplayScoreForTrackRecordDataLoader(), DisplayScoreForTrackRecordDataLoader(),
TrackRecordsForTrackerIdDataLoader(), TrackRecordsForTrackerIdDataLoader(),

View File

@@ -49,6 +49,10 @@ class TrackerType(
fun trackRecords(dataFetchingEnvironment: DataFetchingEnvironment): CompletableFuture<TrackRecordNodeList> { fun trackRecords(dataFetchingEnvironment: DataFetchingEnvironment): CompletableFuture<TrackRecordNodeList> {
return dataFetchingEnvironment.getValueFromDataLoader<Int, TrackRecordNodeList>("TrackRecordsForTrackerIdDataLoader", id) return dataFetchingEnvironment.getValueFromDataLoader<Int, TrackRecordNodeList>("TrackRecordsForTrackerIdDataLoader", id)
} }
fun isTokenExpired(dataFetchingEnvironment: DataFetchingEnvironment): CompletableFuture<Boolean> {
return dataFetchingEnvironment.getValueFromDataLoader<Int, Boolean>("TrackerTokenExpiredDataLoader", id)
}
} }
class TrackStatusType( class TrackStatusType(

View File

@@ -5,6 +5,7 @@ import okhttp3.OkHttpClient
import suwayomi.tachidesk.manga.impl.track.tracker.model.Track import suwayomi.tachidesk.manga.impl.track.tracker.model.Track
import suwayomi.tachidesk.manga.impl.track.tracker.model.TrackSearch import suwayomi.tachidesk.manga.impl.track.tracker.model.TrackSearch
import uy.kohesive.injekt.injectLazy import uy.kohesive.injekt.injectLazy
import java.io.IOException
abstract class Tracker(val id: Int, val name: String) { abstract class Tracker(val id: Int, val name: String) {
val trackPreferences = TrackerPreferences val trackPreferences = TrackerPreferences
@@ -81,6 +82,14 @@ abstract class Tracker(val id: Int, val name: String) {
) { ) {
trackPreferences.setTrackCredentials(this, username, password) trackPreferences.setTrackCredentials(this, username, password)
} }
fun getIfAuthExpired(): Boolean {
return trackPreferences.trackAuthExpired(this)
}
fun setAuthExpired() {
trackPreferences.setTrackTokenExpired(this)
}
} }
fun String.extractToken(key: String): String? { fun String.extractToken(key: String): String? {
@@ -93,3 +102,7 @@ fun String.extractToken(key: String): String? {
} }
return null return null
} }
class TokenExpired : IOException("Token is expired, re-logging required")
class TokenRefreshFailed : IOException("Token refresh failed")

View File

@@ -16,6 +16,12 @@ object TrackerPreferences {
fun getTrackPassword(sync: Tracker) = preferenceStore.getString(trackPassword(sync.id), "") fun getTrackPassword(sync: Tracker) = preferenceStore.getString(trackPassword(sync.id), "")
fun trackAuthExpired(tracker: Tracker) =
preferenceStore.getBoolean(
trackTokenExpired(tracker.id),
false,
)
fun setTrackCredentials( fun setTrackCredentials(
sync: Tracker, sync: Tracker,
username: String, username: String,
@@ -25,6 +31,7 @@ object TrackerPreferences {
preferenceStore.edit() preferenceStore.edit()
.putString(trackUsername(sync.id), username) .putString(trackUsername(sync.id), username)
.putString(trackPassword(sync.id), password) .putString(trackPassword(sync.id), password)
.putBoolean(trackTokenExpired(sync.id), false)
.apply() .apply()
} }
@@ -38,14 +45,22 @@ object TrackerPreferences {
if (token == null) { if (token == null) {
preferenceStore.edit() preferenceStore.edit()
.remove(trackToken(sync.id)) .remove(trackToken(sync.id))
.putBoolean(trackTokenExpired(sync.id), false)
.apply() .apply()
} else { } else {
preferenceStore.edit() preferenceStore.edit()
.putString(trackToken(sync.id), token) .putString(trackToken(sync.id), token)
.putBoolean(trackTokenExpired(sync.id), false)
.apply() .apply()
} }
} }
fun setTrackTokenExpired(sync: Tracker) {
preferenceStore.edit()
.putBoolean(trackTokenExpired(sync.id), true)
.apply()
}
fun getScoreType(sync: Tracker) = preferenceStore.getString(scoreType(sync.id), Anilist.POINT_10) fun getScoreType(sync: Tracker) = preferenceStore.getString(scoreType(sync.id), Anilist.POINT_10)
fun setScoreType( fun setScoreType(
@@ -63,5 +78,7 @@ object TrackerPreferences {
private fun trackToken(trackerId: Int) = "track_token_$trackerId" private fun trackToken(trackerId: Int) = "track_token_$trackerId"
private fun trackTokenExpired(trackerId: Int) = "track_token_expired_$trackerId"
private fun scoreType(trackerId: Int) = "score_type_$trackerId" private fun scoreType(trackerId: Int) = "score_type_$trackerId"
} }

View File

@@ -2,6 +2,7 @@ package suwayomi.tachidesk.manga.impl.track.tracker.anilist
import okhttp3.Interceptor import okhttp3.Interceptor
import okhttp3.Response import okhttp3.Response
import suwayomi.tachidesk.manga.impl.track.tracker.TokenExpired
import java.io.IOException import java.io.IOException
class AnilistInterceptor(val anilist: Anilist, private var token: String?) : Interceptor { class AnilistInterceptor(val anilist: Anilist, private var token: String?) : Interceptor {
@@ -17,6 +18,9 @@ class AnilistInterceptor(val anilist: Anilist, private var token: String?) : Int
} }
override fun intercept(chain: Interceptor.Chain): Response { override fun intercept(chain: Interceptor.Chain): Response {
if (anilist.getIfAuthExpired()) {
throw TokenExpired()
}
val originalRequest = chain.request() val originalRequest = chain.request()
if (token.isNullOrEmpty()) { if (token.isNullOrEmpty()) {
@@ -26,9 +30,9 @@ class AnilistInterceptor(val anilist: Anilist, private var token: String?) : Int
oauth = anilist.loadOAuth() oauth = anilist.loadOAuth()
} }
// Refresh access token if null or expired. // Refresh access token if null or expired.
if (oauth!!.isExpired()) { if (oauth?.isExpired() == true) {
anilist.logout() anilist.setAuthExpired()
throw IOException("Token expired") throw TokenExpired()
} }
// Throw on null auth. // Throw on null auth.

View File

@@ -1,9 +1,12 @@
package suwayomi.tachidesk.manga.impl.track.tracker.myanimelist package suwayomi.tachidesk.manga.impl.track.tracker.myanimelist
import eu.kanade.tachiyomi.AppInfo
import eu.kanade.tachiyomi.network.parseAs import eu.kanade.tachiyomi.network.parseAs
import kotlinx.serialization.json.Json import kotlinx.serialization.json.Json
import okhttp3.Interceptor import okhttp3.Interceptor
import okhttp3.Response import okhttp3.Response
import suwayomi.tachidesk.manga.impl.track.tracker.TokenExpired
import suwayomi.tachidesk.manga.impl.track.tracker.TokenRefreshFailed
import uy.kohesive.injekt.injectLazy import uy.kohesive.injekt.injectLazy
import java.io.IOException import java.io.IOException
@@ -13,51 +16,27 @@ class MyAnimeListInterceptor(private val myanimelist: MyAnimeList, private var t
private var oauth: OAuth? = null private var oauth: OAuth? = null
override fun intercept(chain: Interceptor.Chain): Response { override fun intercept(chain: Interceptor.Chain): Response {
if (myanimelist.getIfAuthExpired()) {
throw TokenExpired()
}
val originalRequest = chain.request() val originalRequest = chain.request()
if (token.isNullOrEmpty()) { if (oauth?.isExpired() == true) {
throw IOException("Not authenticated with MyAnimeList") refreshToken(chain)
}
if (oauth == null) {
oauth = myanimelist.loadOAuth()
}
// Refresh access token if expired
if (oauth != null && oauth!!.isExpired()) {
setAuth(refreshToken(chain))
} }
if (oauth == null) { if (oauth == null) {
throw IOException("No authentication token") throw IOException("MAL: User is not authenticated")
} }
// Add the authorization header to the original request // Add the authorization header to the original request
val authRequest = val authRequest =
originalRequest.newBuilder() originalRequest.newBuilder()
.addHeader("Authorization", "Bearer ${oauth!!.access_token}") .addHeader("Authorization", "Bearer ${oauth!!.access_token}")
.header("User-Agent", "Suwayomi v${AppInfo.getVersionName()}")
.build() .build()
val response = chain.proceed(authRequest) return chain.proceed(authRequest)
val tokenIsExpired =
response.headers["www-authenticate"]
?.contains("The access token expired") ?: false
// Retry the request once with a new token in case it was not already refreshed
// by the is expired check before.
if (response.code == 401 && tokenIsExpired) {
response.close()
val newToken = refreshToken(chain)
setAuth(newToken)
val newRequest =
originalRequest.newBuilder()
.addHeader("Authorization", "Bearer ${newToken.access_token}")
.build()
return chain.proceed(newRequest)
}
return response
} }
/** /**
@@ -70,23 +49,36 @@ class MyAnimeListInterceptor(private val myanimelist: MyAnimeList, private var t
myanimelist.saveOAuth(oauth) myanimelist.saveOAuth(oauth)
} }
private fun refreshToken(chain: Interceptor.Chain): OAuth { private fun refreshToken(chain: Interceptor.Chain): OAuth =
val newOauth = synchronized(this) {
runCatching { if (myanimelist.getIfAuthExpired()) throw TokenExpired()
val oauthResponse = chain.proceed(MyAnimeListApi.refreshTokenRequest(oauth!!)) oauth?.takeUnless { it.isExpired() }?.let { return@synchronized it }
if (oauthResponse.isSuccessful) { val response =
with(json) { oauthResponse.parseAs<OAuth>() } try {
chain.proceed(MyAnimeListApi.refreshTokenRequest(oauth!!))
} catch (_: Throwable) {
throw TokenRefreshFailed()
}
if (response.code == 401) {
myanimelist.setAuthExpired()
throw TokenExpired()
}
return runCatching {
if (response.isSuccessful) {
with(json) { response.parseAs<OAuth>() }
} else { } else {
oauthResponse.close() response.close()
null null
} }
} }
.getOrNull()
if (newOauth.getOrNull() == null) { ?.also {
throw IOException("Failed to refresh the access token") this.oauth = it
myanimelist.saveOAuth(it)
}
?: throw TokenRefreshFailed()
} }
return newOauth.getOrNull()!!
}
} }