Refactor MAL code to not spam refresh token when it fails

This commit is contained in:
AntsyLich 2024-01-28 00:12:18 +06:00
parent 05efc4ebeb
commit 32188f9f65
No known key found for this signature in database
4 changed files with 49 additions and 38 deletions

View File

@ -19,9 +19,15 @@ class TrackPreferences(
"", "",
) )
fun trackAuthExpired(tracker: Tracker) = preferenceStore.getBoolean(
Preference.privateKey("pref_tracker_auth_expired_${tracker.id}"),
false,
)
fun setCredentials(tracker: Tracker, username: String, password: String) { fun setCredentials(tracker: Tracker, username: String, password: String) {
trackUsername(tracker).set(username) trackUsername(tracker).set(username)
trackPassword(tracker).set(password) trackPassword(tracker).set(password)
trackAuthExpired(tracker).set(false)
} }
fun trackToken(tracker: Tracker) = preferenceStore.getString(Preference.privateKey("track_token_${tracker.id}"), "") fun trackToken(tracker: Tracker) = preferenceStore.getString(Preference.privateKey("track_token_${tracker.id}"), "")

View File

@ -35,7 +35,7 @@ class MyAnimeList(id: Long) : BaseTracker(id, "MyAnimeList"), DeletableTracker {
private val json: Json by injectLazy() private val json: Json by injectLazy()
private val interceptor by lazy { MyAnimeListInterceptor(this, getPassword()) } private val interceptor by lazy { MyAnimeListInterceptor(this) }
private val api by lazy { MyAnimeListApi(id, client, interceptor) } private val api by lazy { MyAnimeListApi(id, client, interceptor) }
override val supportsReadingDates: Boolean = true override val supportsReadingDates: Boolean = true
@ -155,6 +155,14 @@ class MyAnimeList(id: Long) : BaseTracker(id, "MyAnimeList"), DeletableTracker {
interceptor.setAuth(null) interceptor.setAuth(null)
} }
fun getIfAuthExpired(): Boolean {
return trackPreferences.trackAuthExpired(this).get()
}
fun setAuthExpired() {
trackPreferences.trackAuthExpired(this).set(true)
}
fun saveOAuth(oAuth: OAuth?) { fun saveOAuth(oAuth: OAuth?) {
trackPreferences.trackToken(this).set(json.encodeToString(oAuth)) trackPreferences.trackToken(this).set(json.encodeToString(oAuth))
} }

View File

@ -25,6 +25,7 @@ import kotlinx.serialization.json.jsonArray
import kotlinx.serialization.json.jsonObject import kotlinx.serialization.json.jsonObject
import kotlinx.serialization.json.jsonPrimitive import kotlinx.serialization.json.jsonPrimitive
import kotlinx.serialization.json.long import kotlinx.serialization.json.long
import logcat.logcat
import okhttp3.FormBody import okhttp3.FormBody
import okhttp3.Headers import okhttp3.Headers
import okhttp3.OkHttpClient import okhttp3.OkHttpClient
@ -49,13 +50,13 @@ class MyAnimeListApi(
suspend fun getAccessToken(authCode: String): OAuth { suspend fun getAccessToken(authCode: String): OAuth {
return withIOContext { return withIOContext {
val formBody: RequestBody = FormBody.Builder() val formBody: RequestBody = FormBody.Builder()
.add("client_id", clientId) .add("client_id", CLIENT_ID)
.add("code", authCode) .add("code", authCode)
.add("code_verifier", codeVerifier) .add("code_verifier", codeVerifier)
.add("grant_type", "authorization_code") .add("grant_type", "authorization_code")
.build() .build()
with(json) { with(json) {
client.newCall(POST("$baseOAuthUrl/token", body = formBody)) client.newCall(POST("$BASE_OAUTH_URL/token", body = formBody))
.awaitSuccess() .awaitSuccess()
.parseAs() .parseAs()
} }
@ -65,7 +66,7 @@ class MyAnimeListApi(
suspend fun getCurrentUser(): String { suspend fun getCurrentUser(): String {
return withIOContext { return withIOContext {
val request = Request.Builder() val request = Request.Builder()
.url("$baseApiUrl/users/@me") .url("$BASE_API_URL/users/@me")
.get() .get()
.build() .build()
with(json) { with(json) {
@ -79,7 +80,7 @@ class MyAnimeListApi(
suspend fun search(query: String): List<TrackSearch> { suspend fun search(query: String): List<TrackSearch> {
return withIOContext { return withIOContext {
val url = "$baseApiUrl/manga".toUri().buildUpon() val url = "$BASE_API_URL/manga".toUri().buildUpon()
// MAL API throws a 400 when the query is over 64 characters... // MAL API throws a 400 when the query is over 64 characters...
.appendQueryParameter("q", query.take(64)) .appendQueryParameter("q", query.take(64))
.appendQueryParameter("nsfw", "true") .appendQueryParameter("nsfw", "true")
@ -104,7 +105,7 @@ class MyAnimeListApi(
suspend fun getMangaDetails(id: Int): TrackSearch { suspend fun getMangaDetails(id: Int): TrackSearch {
return withIOContext { return withIOContext {
val url = "$baseApiUrl/manga".toUri().buildUpon() val url = "$BASE_API_URL/manga".toUri().buildUpon()
.appendPath(id.toString()) .appendPath(id.toString())
.appendQueryParameter( .appendQueryParameter(
"fields", "fields",
@ -180,7 +181,7 @@ class MyAnimeListApi(
suspend fun findListItem(track: Track): Track? { suspend fun findListItem(track: Track): Track? {
return withIOContext { return withIOContext {
val uri = "$baseApiUrl/manga".toUri().buildUpon() val uri = "$BASE_API_URL/manga".toUri().buildUpon()
.appendPath(track.remote_id.toString()) .appendPath(track.remote_id.toString())
.appendQueryParameter("fields", "num_chapters,my_list_status{start_date,finish_date}") .appendQueryParameter("fields", "num_chapters,my_list_status{start_date,finish_date}")
.build() .build()
@ -218,7 +219,7 @@ class MyAnimeListApi(
// Check next page if there's more // Check next page if there's more
if (!obj["paging"]!!.jsonObject["next"]?.jsonPrimitive?.contentOrNull.isNullOrBlank()) { if (!obj["paging"]!!.jsonObject["next"]?.jsonPrimitive?.contentOrNull.isNullOrBlank()) {
matches + findListItems(query, offset + listPaginationAmount) matches + findListItems(query, offset + LIST_PAGINATION_AMOUNT)
} else { } else {
matches matches
} }
@ -227,9 +228,9 @@ class MyAnimeListApi(
private suspend fun getListPage(offset: Int): JsonObject { private suspend fun getListPage(offset: Int): JsonObject {
return withIOContext { return withIOContext {
val urlBuilder = "$baseApiUrl/users/@me/mangalist".toUri().buildUpon() val urlBuilder = "$BASE_API_URL/users/@me/mangalist".toUri().buildUpon()
.appendQueryParameter("fields", "list_status{start_date,finish_date}") .appendQueryParameter("fields", "list_status{start_date,finish_date}")
.appendQueryParameter("limit", listPaginationAmount.toString()) .appendQueryParameter("limit", LIST_PAGINATION_AMOUNT.toString())
if (offset > 0) { if (offset > 0) {
urlBuilder.appendQueryParameter("offset", offset.toString()) urlBuilder.appendQueryParameter("offset", offset.toString())
} }
@ -279,30 +280,29 @@ class MyAnimeListApi(
} }
companion object { companion object {
// Registered under arkon's MAL account private const val CLIENT_ID = "c46c9e24640a64dad5be5ca7a1a53a0f"
private const val clientId = "f46004a9c16483b6d87b5bf10de56d97"
private const val baseOAuthUrl = "https://myanimelist.net/v1/oauth2" private const val BASE_OAUTH_URL = "https://myanimelist.net/v1/oauth2"
private const val baseApiUrl = "https://api.myanimelist.net/v2" private const val BASE_API_URL = "https://api.myanimelist.net/v2"
private const val listPaginationAmount = 250 private const val LIST_PAGINATION_AMOUNT = 250
private var codeVerifier: String = "" private var codeVerifier: String = ""
fun authUrl(): Uri = "$baseOAuthUrl/authorize".toUri().buildUpon() fun authUrl(): Uri = "$BASE_OAUTH_URL/authorize".toUri().buildUpon()
.appendQueryParameter("client_id", clientId) .appendQueryParameter("client_id", CLIENT_ID)
.appendQueryParameter("code_challenge", getPkceChallengeCode()) .appendQueryParameter("code_challenge", getPkceChallengeCode())
.appendQueryParameter("response_type", "code") .appendQueryParameter("response_type", "code")
.build() .build()
fun mangaUrl(id: Long): Uri = "$baseApiUrl/manga".toUri().buildUpon() fun mangaUrl(id: Long): Uri = "$BASE_API_URL/manga".toUri().buildUpon()
.appendPath(id.toString()) .appendPath(id.toString())
.appendPath("my_list_status") .appendPath("my_list_status")
.build() .build()
fun refreshTokenRequest(oauth: OAuth): Request { fun refreshTokenRequest(oauth: OAuth): Request {
val formBody: RequestBody = FormBody.Builder() val formBody: RequestBody = FormBody.Builder()
.add("client_id", clientId) .add("client_id", CLIENT_ID)
.add("refresh_token", oauth.refresh_token) .add("refresh_token", oauth.refresh_token)
.add("grant_type", "refresh_token") .add("grant_type", "refresh_token")
.build() .build()
@ -314,7 +314,7 @@ class MyAnimeListApi(
.add("Authorization", "Bearer ${oauth.access_token}") .add("Authorization", "Bearer ${oauth.access_token}")
.build() .build()
return POST("$baseOAuthUrl/token", body = formBody, headers = headers) return POST("$BASE_OAUTH_URL/token", body = formBody, headers = headers)
} }
private fun getPkceChallengeCode(): String { private fun getPkceChallengeCode(): String {

View File

@ -8,28 +8,26 @@ import okhttp3.Response
import uy.kohesive.injekt.injectLazy import uy.kohesive.injekt.injectLazy
import java.io.IOException import java.io.IOException
class MyAnimeListInterceptor(private val myanimelist: MyAnimeList, private var token: String?) : Interceptor { class MyAnimeListInterceptor(private val myanimelist: MyAnimeList) : Interceptor {
private val json: Json by injectLazy() private val json: Json by injectLazy()
private var oauth: OAuth? = null private var oauth: OAuth? = myanimelist.loadOAuth()
private val tokenExpired get() = myanimelist.getIfAuthExpired()
override fun intercept(chain: Interceptor.Chain): Response { override fun intercept(chain: Interceptor.Chain): Response {
if (tokenExpired) {
throw MALTokenExpired()
}
val originalRequest = chain.request() val originalRequest = chain.request()
if (token.isNullOrEmpty()) {
throw IOException("Not authenticated with MyAnimeList")
}
if (oauth == null) {
oauth = myanimelist.loadOAuth()
}
// Refresh access token if expired // Refresh access token if expired
if (oauth != null && oauth!!.isExpired()) { if (oauth != null && oauth!!.isExpired()) {
setAuth(refreshToken(chain)) setAuth(refreshToken(chain))
} }
if (oauth == null) { if (oauth == null) {
throw IOException("No authentication token") throw IOException("MAL: User is not authenticated")
} }
// Add the authorization header to the original request // Add the authorization header to the original request
@ -66,15 +64,16 @@ class MyAnimeListInterceptor(private val myanimelist: MyAnimeList, private var t
* and the oauth object. * and the oauth object.
*/ */
fun setAuth(oauth: OAuth?) { fun setAuth(oauth: OAuth?) {
token = oauth?.access_token
this.oauth = oauth this.oauth = oauth
myanimelist.saveOAuth(oauth) myanimelist.saveOAuth(oauth)
} }
private fun refreshToken(chain: Interceptor.Chain): OAuth { private fun refreshToken(chain: Interceptor.Chain): OAuth {
val newOauth = runCatching { return runCatching {
val oauthResponse = chain.proceed(MyAnimeListApi.refreshTokenRequest(oauth!!)) val oauthResponse = chain.proceed(MyAnimeListApi.refreshTokenRequest(oauth!!))
if (oauthResponse.code == 401) {
myanimelist.setAuthExpired()
}
if (oauthResponse.isSuccessful) { if (oauthResponse.isSuccessful) {
with(json) { oauthResponse.parseAs<OAuth>() } with(json) { oauthResponse.parseAs<OAuth>() }
} else { } else {
@ -82,11 +81,9 @@ class MyAnimeListInterceptor(private val myanimelist: MyAnimeList, private var t
null null
} }
} }
.getOrNull()
if (newOauth.getOrNull() == null) { ?: throw MALTokenExpired()
throw IOException("Failed to refresh the access token")
}
return newOauth.getOrNull()!!
} }
} }
class MALTokenExpired: IOException("MAL: Login has expired")