mirror of
				https://github.com/mihonapp/mihon.git
				synced 2025-11-04 08:08:55 +01:00 
			
		
		
		
	Rewrite RateLimitInterceptor (#7889)
This commit is contained in:
		@@ -5,6 +5,8 @@ import okhttp3.Interceptor
 | 
			
		||||
import okhttp3.OkHttpClient
 | 
			
		||||
import okhttp3.Response
 | 
			
		||||
import java.io.IOException
 | 
			
		||||
import java.util.ArrayDeque
 | 
			
		||||
import java.util.concurrent.Semaphore
 | 
			
		||||
import java.util.concurrent.TimeUnit
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
@@ -25,54 +27,77 @@ fun OkHttpClient.Builder.rateLimit(
 | 
			
		||||
    permits: Int,
 | 
			
		||||
    period: Long = 1,
 | 
			
		||||
    unit: TimeUnit = TimeUnit.SECONDS,
 | 
			
		||||
) = addInterceptor(RateLimitInterceptor(permits, period, unit))
 | 
			
		||||
) = addInterceptor(RateLimitInterceptor(null, permits, period, unit))
 | 
			
		||||
 | 
			
		||||
private class RateLimitInterceptor(
 | 
			
		||||
/** We can probably accept domains or wildcards by comparing with [endsWith], etc. */
 | 
			
		||||
@Suppress("PLATFORM_CLASS_MAPPED_TO_KOTLIN")
 | 
			
		||||
internal class RateLimitInterceptor(
 | 
			
		||||
    private val host: String?,
 | 
			
		||||
    private val permits: Int,
 | 
			
		||||
    period: Long,
 | 
			
		||||
    unit: TimeUnit,
 | 
			
		||||
) : Interceptor {
 | 
			
		||||
 | 
			
		||||
    private val requestQueue = ArrayList<Long>(permits)
 | 
			
		||||
    private val requestQueue = ArrayDeque<Long>(permits)
 | 
			
		||||
    private val rateLimitMillis = unit.toMillis(period)
 | 
			
		||||
    private val fairLock = Semaphore(1, true)
 | 
			
		||||
 | 
			
		||||
    override fun intercept(chain: Interceptor.Chain): Response {
 | 
			
		||||
        // Ignore canceled calls, otherwise they would jam the queue
 | 
			
		||||
        if (chain.call().isCanceled()) {
 | 
			
		||||
            throw IOException()
 | 
			
		||||
        val call = chain.call()
 | 
			
		||||
        if (call.isCanceled()) throw IOException("Canceled")
 | 
			
		||||
 | 
			
		||||
        val request = chain.request()
 | 
			
		||||
        when (host) {
 | 
			
		||||
            null, request.url.host -> {} // need rate limit
 | 
			
		||||
            else -> return chain.proceed(request)
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        synchronized(requestQueue) {
 | 
			
		||||
            val now = SystemClock.elapsedRealtime()
 | 
			
		||||
            val waitTime = if (requestQueue.size < permits) {
 | 
			
		||||
                0
 | 
			
		||||
            } else {
 | 
			
		||||
                val oldestReq = requestQueue[0]
 | 
			
		||||
                val newestReq = requestQueue[permits - 1]
 | 
			
		||||
        try {
 | 
			
		||||
            fairLock.acquire()
 | 
			
		||||
        } catch (e: InterruptedException) {
 | 
			
		||||
            throw IOException(e)
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
                if (newestReq - oldestReq > rateLimitMillis) {
 | 
			
		||||
                    0
 | 
			
		||||
                } else {
 | 
			
		||||
                    oldestReq + rateLimitMillis - now // Remaining time
 | 
			
		||||
        val requestQueue = this.requestQueue
 | 
			
		||||
        val timestamp: Long
 | 
			
		||||
 | 
			
		||||
        try {
 | 
			
		||||
            synchronized(requestQueue) {
 | 
			
		||||
                while (requestQueue.size >= permits) { // queue is full, remove expired entries
 | 
			
		||||
                    val periodStart = SystemClock.elapsedRealtime() - rateLimitMillis
 | 
			
		||||
                    var hasRemovedExpired = false
 | 
			
		||||
                    while (requestQueue.isEmpty().not() && requestQueue.first <= periodStart) {
 | 
			
		||||
                        requestQueue.removeFirst()
 | 
			
		||||
                        hasRemovedExpired = true
 | 
			
		||||
                    }
 | 
			
		||||
                    if (call.isCanceled()) {
 | 
			
		||||
                        throw IOException("Canceled")
 | 
			
		||||
                    } else if (hasRemovedExpired) {
 | 
			
		||||
                        break
 | 
			
		||||
                    } else try { // wait for the first entry to expire, or notified by cached response
 | 
			
		||||
                        (requestQueue as Object).wait(requestQueue.first - periodStart)
 | 
			
		||||
                    } catch (_: InterruptedException) {
 | 
			
		||||
                        continue
 | 
			
		||||
                    }
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            // Final check
 | 
			
		||||
            if (chain.call().isCanceled()) {
 | 
			
		||||
                throw IOException()
 | 
			
		||||
                // add request to queue
 | 
			
		||||
                timestamp = SystemClock.elapsedRealtime()
 | 
			
		||||
                requestQueue.addLast(timestamp)
 | 
			
		||||
            }
 | 
			
		||||
        } finally {
 | 
			
		||||
            fairLock.release()
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
            if (requestQueue.size == permits) {
 | 
			
		||||
                requestQueue.removeAt(0)
 | 
			
		||||
            }
 | 
			
		||||
            if (waitTime > 0) {
 | 
			
		||||
                requestQueue.add(now + waitTime)
 | 
			
		||||
                Thread.sleep(waitTime) // Sleep inside synchronized to pause queued requests
 | 
			
		||||
            } else {
 | 
			
		||||
                requestQueue.add(now)
 | 
			
		||||
        val response = chain.proceed(request)
 | 
			
		||||
        if (response.networkResponse == null) { // response is cached, remove it from queue
 | 
			
		||||
            synchronized(requestQueue) {
 | 
			
		||||
                if (requestQueue.isEmpty() || timestamp < requestQueue.first) return@synchronized
 | 
			
		||||
                requestQueue.removeFirstOccurrence(timestamp)
 | 
			
		||||
                (requestQueue as Object).notifyAll()
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return chain.proceed(chain.request())
 | 
			
		||||
        return response
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
 
 | 
			
		||||
@@ -1,11 +1,7 @@
 | 
			
		||||
package eu.kanade.tachiyomi.network.interceptor
 | 
			
		||||
 | 
			
		||||
import android.os.SystemClock
 | 
			
		||||
import okhttp3.HttpUrl
 | 
			
		||||
import okhttp3.Interceptor
 | 
			
		||||
import okhttp3.OkHttpClient
 | 
			
		||||
import okhttp3.Response
 | 
			
		||||
import java.io.IOException
 | 
			
		||||
import java.util.concurrent.TimeUnit
 | 
			
		||||
 | 
			
		||||
/**
 | 
			
		||||
@@ -28,58 +24,4 @@ fun OkHttpClient.Builder.rateLimitHost(
 | 
			
		||||
    permits: Int,
 | 
			
		||||
    period: Long = 1,
 | 
			
		||||
    unit: TimeUnit = TimeUnit.SECONDS,
 | 
			
		||||
) = addInterceptor(SpecificHostRateLimitInterceptor(httpUrl, permits, period, unit))
 | 
			
		||||
 | 
			
		||||
class SpecificHostRateLimitInterceptor(
 | 
			
		||||
    httpUrl: HttpUrl,
 | 
			
		||||
    private val permits: Int,
 | 
			
		||||
    period: Long,
 | 
			
		||||
    unit: TimeUnit,
 | 
			
		||||
) : Interceptor {
 | 
			
		||||
 | 
			
		||||
    private val requestQueue = ArrayList<Long>(permits)
 | 
			
		||||
    private val rateLimitMillis = unit.toMillis(period)
 | 
			
		||||
    private val host = httpUrl.host
 | 
			
		||||
 | 
			
		||||
    override fun intercept(chain: Interceptor.Chain): Response {
 | 
			
		||||
        // Ignore canceled calls, otherwise they would jam the queue
 | 
			
		||||
        if (chain.call().isCanceled()) {
 | 
			
		||||
            throw IOException()
 | 
			
		||||
        } else if (chain.request().url.host != host) {
 | 
			
		||||
            return chain.proceed(chain.request())
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        synchronized(requestQueue) {
 | 
			
		||||
            val now = SystemClock.elapsedRealtime()
 | 
			
		||||
            val waitTime = if (requestQueue.size < permits) {
 | 
			
		||||
                0
 | 
			
		||||
            } else {
 | 
			
		||||
                val oldestReq = requestQueue[0]
 | 
			
		||||
                val newestReq = requestQueue[permits - 1]
 | 
			
		||||
 | 
			
		||||
                if (newestReq - oldestReq > rateLimitMillis) {
 | 
			
		||||
                    0
 | 
			
		||||
                } else {
 | 
			
		||||
                    oldestReq + rateLimitMillis - now // Remaining time
 | 
			
		||||
                }
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            // Final check
 | 
			
		||||
            if (chain.call().isCanceled()) {
 | 
			
		||||
                throw IOException()
 | 
			
		||||
            }
 | 
			
		||||
 | 
			
		||||
            if (requestQueue.size == permits) {
 | 
			
		||||
                requestQueue.removeAt(0)
 | 
			
		||||
            }
 | 
			
		||||
            if (waitTime > 0) {
 | 
			
		||||
                requestQueue.add(now + waitTime)
 | 
			
		||||
                Thread.sleep(waitTime) // Sleep inside synchronized to pause queued requests
 | 
			
		||||
            } else {
 | 
			
		||||
                requestQueue.add(now)
 | 
			
		||||
            }
 | 
			
		||||
        }
 | 
			
		||||
 | 
			
		||||
        return chain.proceed(chain.request())
 | 
			
		||||
    }
 | 
			
		||||
}
 | 
			
		||||
) = addInterceptor(RateLimitInterceptor(httpUrl.host, permits, period, unit))
 | 
			
		||||
 
 | 
			
		||||
		Reference in New Issue
	
	Block a user