Skip to content

Commit 84350cc

Browse files
committed
Run suspending calls within Dispatchers.Default
This addresses a need for off-main-thread invocation of Call.Factory.newCall to support lazy HttpClient initialization.
1 parent 6cd6f7d commit 84350cc

File tree

2 files changed

+84
-48
lines changed

2 files changed

+84
-48
lines changed

retrofit/kotlin-test/src/test/java/retrofit2/KotlinSuspendTest.kt

+27
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ package retrofit2
1717

1818
import kotlinx.coroutines.CoroutineDispatcher
1919
import kotlinx.coroutines.GlobalScope
20+
import kotlinx.coroutines.Runnable
2021
import kotlinx.coroutines.async
2122
import kotlinx.coroutines.runBlocking
2223
import kotlinx.coroutines.withContext
@@ -353,6 +354,32 @@ class KotlinSuspendTest {
353354
}
354355
}
355356

357+
@Test fun usesCoroutineContextForCallFactory() {
358+
val okHttpClient = OkHttpClient()
359+
var callFactoryThread: Thread? = null
360+
val outerContextThread: Thread
361+
val retrofit = Retrofit.Builder()
362+
.baseUrl(server.url("/"))
363+
.callFactory {
364+
callFactoryThread = Thread.currentThread()
365+
okHttpClient.newCall(it)
366+
}
367+
.addConverterFactory(ToStringConverterFactory())
368+
.build()
369+
val example = retrofit.create(Service::class.java)
370+
371+
server.enqueue(MockResponse().setBody("Hi"))
372+
373+
runBlocking {
374+
outerContextThread = Thread.currentThread()
375+
example.body()
376+
}
377+
378+
assertThat(callFactoryThread).isNotNull
379+
assertThat(outerContextThread).isNotEqualTo(callFactoryThread)
380+
}
381+
382+
356383
@Suppress("EXPERIMENTAL_OVERRIDE")
357384
private object DirectUnconfinedDispatcher : CoroutineDispatcher() {
358385
override fun isDispatchNeeded(context: CoroutineContext): Boolean = false

retrofit/src/main/java/retrofit2/KotlinExtensions.kt

+57-48
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ package retrofit2
2020

2121
import kotlinx.coroutines.Dispatchers
2222
import kotlinx.coroutines.suspendCancellableCoroutine
23-
import java.lang.reflect.ParameterizedType
23+
import kotlinx.coroutines.withContext
2424
import kotlin.coroutines.intrinsics.COROUTINE_SUSPENDED
2525
import kotlin.coroutines.intrinsics.intercepted
2626
import kotlin.coroutines.intrinsics.suspendCoroutineUninterceptedOrReturn
@@ -30,57 +30,63 @@ import kotlin.coroutines.resumeWithException
3030
inline fun <reified T: Any> Retrofit.create(): T = create(T::class.java)
3131

3232
suspend fun <T : Any> Call<T>.await(): T {
33-
return suspendCancellableCoroutine { continuation ->
34-
continuation.invokeOnCancellation {
35-
cancel()
36-
}
37-
enqueue(object : Callback<T> {
38-
override fun onResponse(call: Call<T>, response: Response<T>) {
39-
if (response.isSuccessful) {
40-
val body = response.body()
41-
if (body == null) {
42-
val invocation = call.request().tag(Invocation::class.java)!!
43-
val method = invocation.method()
44-
val e = KotlinNullPointerException("Response from " +
33+
// TODO: a better solution for off-main-thread call factories than this.
34+
return withContext(Dispatchers.Default) {
35+
suspendCancellableCoroutine { continuation ->
36+
continuation.invokeOnCancellation {
37+
cancel()
38+
}
39+
enqueue(object : Callback<T> {
40+
override fun onResponse(call: Call<T>, response: Response<T>) {
41+
if (response.isSuccessful) {
42+
val body = response.body()
43+
if (body == null) {
44+
val invocation = call.request().tag(Invocation::class.java)!!
45+
val method = invocation.method()
46+
val e = KotlinNullPointerException("Response from " +
4547
method.declaringClass.name +
4648
'.' +
4749
method.name +
4850
" was null but response body type was declared as non-null")
49-
continuation.resumeWithException(e)
51+
continuation.resumeWithException(e)
52+
} else {
53+
continuation.resume(body)
54+
}
5055
} else {
51-
continuation.resume(body)
56+
continuation.resumeWithException(HttpException(response))
5257
}
53-
} else {
54-
continuation.resumeWithException(HttpException(response))
5558
}
56-
}
5759

58-
override fun onFailure(call: Call<T>, t: Throwable) {
59-
continuation.resumeWithException(t)
60-
}
61-
})
60+
override fun onFailure(call: Call<T>, t: Throwable) {
61+
continuation.resumeWithException(t)
62+
}
63+
})
64+
}
6265
}
6366
}
6467

6568
@JvmName("awaitNullable")
6669
suspend fun <T : Any> Call<T?>.await(): T? {
67-
return suspendCancellableCoroutine { continuation ->
68-
continuation.invokeOnCancellation {
69-
cancel()
70-
}
71-
enqueue(object : Callback<T?> {
72-
override fun onResponse(call: Call<T?>, response: Response<T?>) {
73-
if (response.isSuccessful) {
74-
continuation.resume(response.body())
75-
} else {
76-
continuation.resumeWithException(HttpException(response))
77-
}
70+
// TODO: a better solution for off-main-thread call factories than this.
71+
return withContext(Dispatchers.Default) {
72+
suspendCancellableCoroutine { continuation ->
73+
continuation.invokeOnCancellation {
74+
cancel()
7875
}
76+
enqueue(object : Callback<T?> {
77+
override fun onResponse(call: Call<T?>, response: Response<T?>) {
78+
if (response.isSuccessful) {
79+
continuation.resume(response.body())
80+
} else {
81+
continuation.resumeWithException(HttpException(response))
82+
}
83+
}
7984

80-
override fun onFailure(call: Call<T?>, t: Throwable) {
81-
continuation.resumeWithException(t)
82-
}
83-
})
85+
override fun onFailure(call: Call<T?>, t: Throwable) {
86+
continuation.resumeWithException(t)
87+
}
88+
})
89+
}
8490
}
8591
}
8692

@@ -91,19 +97,22 @@ suspend fun Call<Unit>.await() {
9197
}
9298

9399
suspend fun <T> Call<T>.awaitResponse(): Response<T> {
94-
return suspendCancellableCoroutine { continuation ->
95-
continuation.invokeOnCancellation {
96-
cancel()
97-
}
98-
enqueue(object : Callback<T> {
99-
override fun onResponse(call: Call<T>, response: Response<T>) {
100-
continuation.resume(response)
100+
// TODO: a better solution for off-main-thread call factories than this.
101+
return withContext(Dispatchers.Default) {
102+
suspendCancellableCoroutine { continuation ->
103+
continuation.invokeOnCancellation {
104+
cancel()
101105
}
106+
enqueue(object : Callback<T> {
107+
override fun onResponse(call: Call<T>, response: Response<T>) {
108+
continuation.resume(response)
109+
}
102110

103-
override fun onFailure(call: Call<T>, t: Throwable) {
104-
continuation.resumeWithException(t)
105-
}
106-
})
111+
override fun onFailure(call: Call<T>, t: Throwable) {
112+
continuation.resumeWithException(t)
113+
}
114+
})
115+
}
107116
}
108117
}
109118

0 commit comments

Comments
 (0)