@@ -24,11 +24,13 @@ import akka.http.scaladsl.marshalling._
24
24
import akka .http .scaladsl .model ._
25
25
import akka .http .scaladsl .settings .ConnectionPoolSettings
26
26
import akka .http .scaladsl .unmarshalling ._
27
- import akka .stream .{OverflowStrategy , QueueOfferResult }
28
27
import akka .stream .scaladsl .{Flow , _ }
28
+ import akka .stream .{KillSwitches , QueueOfferResult }
29
+ import org .apache .openwhisk .common .AkkaLogging
29
30
import spray .json ._
30
- import scala . concurrent .{ ExecutionContext , Future , Promise }
31
+
31
32
import scala .concurrent .duration ._
33
+ import scala .concurrent .{ExecutionContext , Future , Promise }
32
34
import scala .util .{Failure , Success , Try }
33
35
34
36
/**
@@ -45,10 +47,10 @@ class PoolingRestClient(
45
47
port : Int ,
46
48
queueSize : Int ,
47
49
httpFlow : Option [Flow [(HttpRequest , Promise [HttpResponse ]), (Try [HttpResponse ], Promise [HttpResponse ]), Any ]] = None ,
48
- timeout : Option [FiniteDuration ] = None )(implicit system : ActorSystem ) {
50
+ timeout : Option [FiniteDuration ] = None )(implicit system : ActorSystem , ec : ExecutionContext ) {
49
51
require(protocol == " http" || protocol == " https" , " Protocol must be one of { http, https }." )
50
52
51
- protected implicit val context : ExecutionContext = system.dispatcher
53
+ private val logging = new AkkaLogging ( system.log)
52
54
53
55
// if specified, override the ClientConnection idle-timeout and keepalive socket option value
54
56
private val timeoutSettings = {
@@ -72,16 +74,19 @@ class PoolingRestClient(
72
74
// Additional queue in case all connections are busy. Should hardly ever be
73
75
// filled in practice but can be useful, e.g., in tests starting many
74
76
// asynchronous requests in a very short period of time.
75
- private val requestQueue = Source
76
- .queue(queueSize, OverflowStrategy .dropNew )
77
+ private val (( requestQueue, killSwitch), sinkCompletion) = Source
78
+ .queue(queueSize)
77
79
.via(httpFlow.getOrElse(pool))
80
+ .viaMat(KillSwitches .single)(Keep .both)
78
81
.toMat(Sink .foreach({
79
82
case (Success (response), p) =>
80
83
p.success(response)
81
84
case (Failure (error), p) =>
82
85
p.failure(error)
83
- }))(Keep .left)
84
- .run
86
+ }))(Keep .both)
87
+ .run()
88
+
89
+ sinkCompletion.onComplete(_ => shutdown())
85
90
86
91
/**
87
92
* Execute an HttpRequest on the underlying connection pool.
@@ -96,10 +101,10 @@ class PoolingRestClient(
96
101
97
102
// When the future completes, we know whether the request made it
98
103
// through the queue.
99
- requestQueue.offer(request -> promise).flatMap {
104
+ requestQueue.offer(request -> promise) match {
100
105
case QueueOfferResult .Enqueued => promise.future
101
- case QueueOfferResult .Dropped => Future .failed(new Exception (" DB request queue is full." ))
102
- case QueueOfferResult .QueueClosed => Future .failed(new Exception (" DB request queue was closed." ))
106
+ case QueueOfferResult .Dropped => Future .failed(new Exception (" Request queue is full." ))
107
+ case QueueOfferResult .QueueClosed => Future .failed(new Exception (" Request queue was closed." ))
103
108
case QueueOfferResult .Failure (f) => Future .failed(f)
104
109
}
105
110
}
@@ -127,7 +132,13 @@ class PoolingRestClient(
127
132
}
128
133
}
129
134
130
- def shutdown (): Future [Unit ] = Future .unit
135
+ def shutdown (): Future [Unit ] = {
136
+ killSwitch.shutdown()
137
+ Try (requestQueue.complete()).recover {
138
+ case t : IllegalStateException => logging.warn(this , t.getMessage)
139
+ }
140
+ Future .unit
141
+ }
131
142
}
132
143
133
144
object PoolingRestClient {
0 commit comments