mirror of
				https://github.com/mihonapp/mihon.git
				synced 2025-10-31 14:27:57 +01:00 
			
		
		
		
	Rewrite RateLimitInterceptor (#7889)
This commit is contained in:
		| @@ -5,6 +5,8 @@ import okhttp3.Interceptor | ||||
| import okhttp3.OkHttpClient | ||||
| import okhttp3.Response | ||||
| import java.io.IOException | ||||
| import java.util.ArrayDeque | ||||
| import java.util.concurrent.Semaphore | ||||
| import java.util.concurrent.TimeUnit | ||||
|  | ||||
| /** | ||||
| @@ -25,54 +27,77 @@ fun OkHttpClient.Builder.rateLimit( | ||||
|     permits: Int, | ||||
|     period: Long = 1, | ||||
|     unit: TimeUnit = TimeUnit.SECONDS, | ||||
| ) = addInterceptor(RateLimitInterceptor(permits, period, unit)) | ||||
| ) = addInterceptor(RateLimitInterceptor(null, permits, period, unit)) | ||||
|  | ||||
| private class RateLimitInterceptor( | ||||
| /** We can probably accept domains or wildcards by comparing with [endsWith], etc. */ | ||||
| @Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN") | ||||
| internal class RateLimitInterceptor( | ||||
|     private val host: String?, | ||||
|     private val permits: Int, | ||||
|     period: Long, | ||||
|     unit: TimeUnit, | ||||
| ) : Interceptor { | ||||
|  | ||||
|     private val requestQueue = ArrayList<Long>(permits) | ||||
|     private val requestQueue = ArrayDeque<Long>(permits) | ||||
|     private val rateLimitMillis = unit.toMillis(period) | ||||
|     private val fairLock = Semaphore(1, true) | ||||
|  | ||||
|     override fun intercept(chain: Interceptor.Chain): Response { | ||||
|         // Ignore canceled calls, otherwise they would jam the queue | ||||
|         if (chain.call().isCanceled()) { | ||||
|             throw IOException() | ||||
|         val call = chain.call() | ||||
|         if (call.isCanceled()) throw IOException("Canceled") | ||||
|  | ||||
|         val request = chain.request() | ||||
|         when (host) { | ||||
|             null, request.url.host -> {} // need rate limit | ||||
|             else -> return chain.proceed(request) | ||||
|         } | ||||
|  | ||||
|         synchronized(requestQueue) { | ||||
|             val now = SystemClock.elapsedRealtime() | ||||
|             val waitTime = if (requestQueue.size < permits) { | ||||
|                 0 | ||||
|             } else { | ||||
|                 val oldestReq = requestQueue[0] | ||||
|                 val newestReq = requestQueue[permits - 1] | ||||
|         try { | ||||
|             fairLock.acquire() | ||||
|         } catch (e: InterruptedException) { | ||||
|             throw IOException(e) | ||||
|         } | ||||
|  | ||||
|                 if (newestReq - oldestReq > rateLimitMillis) { | ||||
|                     0 | ||||
|                 } else { | ||||
|                     oldestReq + rateLimitMillis - now // Remaining time | ||||
|         val requestQueue = this.requestQueue | ||||
|         val timestamp: Long | ||||
|  | ||||
|         try { | ||||
|             synchronized(requestQueue) { | ||||
|                 while (requestQueue.size >= permits) { // queue is full, remove expired entries | ||||
|                     val periodStart = SystemClock.elapsedRealtime() - rateLimitMillis | ||||
|                     var hasRemovedExpired = false | ||||
|                     while (requestQueue.isEmpty().not() && requestQueue.first <= periodStart) { | ||||
|                         requestQueue.removeFirst() | ||||
|                         hasRemovedExpired = true | ||||
|                     } | ||||
|                     if (call.isCanceled()) { | ||||
|                         throw IOException("Canceled") | ||||
|                     } else if (hasRemovedExpired) { | ||||
|                         break | ||||
|                     } else try { // wait for the first entry to expire, or notified by cached response | ||||
|                         (requestQueue as Object).wait(requestQueue.first - periodStart) | ||||
|                     } catch (_: InterruptedException) { | ||||
|                         continue | ||||
|                     } | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             // Final check | ||||
|             if (chain.call().isCanceled()) { | ||||
|                 throw IOException() | ||||
|                 // add request to queue | ||||
|                 timestamp = SystemClock.elapsedRealtime() | ||||
|                 requestQueue.addLast(timestamp) | ||||
|             } | ||||
|         } finally { | ||||
|             fairLock.release() | ||||
|         } | ||||
|  | ||||
|             if (requestQueue.size == permits) { | ||||
|                 requestQueue.removeAt(0) | ||||
|             } | ||||
|             if (waitTime > 0) { | ||||
|                 requestQueue.add(now + waitTime) | ||||
|                 Thread.sleep(waitTime) // Sleep inside synchronized to pause queued requests | ||||
|             } else { | ||||
|                 requestQueue.add(now) | ||||
|         val response = chain.proceed(request) | ||||
|         if (response.networkResponse == null) { // response is cached, remove it from queue | ||||
|             synchronized(requestQueue) { | ||||
|                 if (requestQueue.isEmpty() || timestamp < requestQueue.first) return@synchronized | ||||
|                 requestQueue.removeFirstOccurrence(timestamp) | ||||
|                 (requestQueue as Object).notifyAll() | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         return chain.proceed(chain.request()) | ||||
|         return response | ||||
|     } | ||||
| } | ||||
|   | ||||
| @@ -1,11 +1,7 @@ | ||||
| package eu.kanade.tachiyomi.network.interceptor | ||||
|  | ||||
| import android.os.SystemClock | ||||
| import okhttp3.HttpUrl | ||||
| import okhttp3.Interceptor | ||||
| import okhttp3.OkHttpClient | ||||
| import okhttp3.Response | ||||
| import java.io.IOException | ||||
| import java.util.concurrent.TimeUnit | ||||
|  | ||||
| /** | ||||
| @@ -28,58 +24,4 @@ fun OkHttpClient.Builder.rateLimitHost( | ||||
|     permits: Int, | ||||
|     period: Long = 1, | ||||
|     unit: TimeUnit = TimeUnit.SECONDS, | ||||
| ) = addInterceptor(SpecificHostRateLimitInterceptor(httpUrl, permits, period, unit)) | ||||
|  | ||||
| class SpecificHostRateLimitInterceptor( | ||||
|     httpUrl: HttpUrl, | ||||
|     private val permits: Int, | ||||
|     period: Long, | ||||
|     unit: TimeUnit, | ||||
| ) : Interceptor { | ||||
|  | ||||
|     private val requestQueue = ArrayList<Long>(permits) | ||||
|     private val rateLimitMillis = unit.toMillis(period) | ||||
|     private val host = httpUrl.host | ||||
|  | ||||
|     override fun intercept(chain: Interceptor.Chain): Response { | ||||
|         // Ignore canceled calls, otherwise they would jam the queue | ||||
|         if (chain.call().isCanceled()) { | ||||
|             throw IOException() | ||||
|         } else if (chain.request().url.host != host) { | ||||
|             return chain.proceed(chain.request()) | ||||
|         } | ||||
|  | ||||
|         synchronized(requestQueue) { | ||||
|             val now = SystemClock.elapsedRealtime() | ||||
|             val waitTime = if (requestQueue.size < permits) { | ||||
|                 0 | ||||
|             } else { | ||||
|                 val oldestReq = requestQueue[0] | ||||
|                 val newestReq = requestQueue[permits - 1] | ||||
|  | ||||
|                 if (newestReq - oldestReq > rateLimitMillis) { | ||||
|                     0 | ||||
|                 } else { | ||||
|                     oldestReq + rateLimitMillis - now // Remaining time | ||||
|                 } | ||||
|             } | ||||
|  | ||||
|             // Final check | ||||
|             if (chain.call().isCanceled()) { | ||||
|                 throw IOException() | ||||
|             } | ||||
|  | ||||
|             if (requestQueue.size == permits) { | ||||
|                 requestQueue.removeAt(0) | ||||
|             } | ||||
|             if (waitTime > 0) { | ||||
|                 requestQueue.add(now + waitTime) | ||||
|                 Thread.sleep(waitTime) // Sleep inside synchronized to pause queued requests | ||||
|             } else { | ||||
|                 requestQueue.add(now) | ||||
|             } | ||||
|         } | ||||
|  | ||||
|         return chain.proceed(chain.request()) | ||||
|     } | ||||
| } | ||||
| ) = addInterceptor(RateLimitInterceptor(httpUrl.host, permits, period, unit)) | ||||
|   | ||||
		Reference in New Issue
	
	Block a user