@@ -2,6 +2,7 @@ package service
2
2
3
3
import (
4
4
"bytes"
5
+ "context"
5
6
"crypto/tls"
6
7
"encoding/json"
7
8
"fmt"
@@ -10,9 +11,11 @@ import (
10
11
"net/http"
11
12
"net/url"
12
13
"reflect"
14
+ "strings"
13
15
"time"
14
16
15
17
"github.com/google/go-querystring/query"
18
+ "github.com/hashicorp/go-retryablehttp"
16
19
)
17
20
18
21
// CloudServiceProvider is a custom type for different types of cloud service providers
@@ -62,25 +65,87 @@ type DBApiClientConfig struct {
62
65
DefaultHeaders map [string ]string
63
66
InsecureSkipVerify bool
64
67
TimeoutSeconds int
65
- client http.Client
68
+ client * retryablehttp.Client
69
+ }
70
+
71
// transientErrorStringMatches lists substrings of API error messages that
// identify a failure as transient; checkHTTPRetry retries a failed request
// when the decoded error message contains any of them.
//
// TODO: Should we make these regexes to match more of the message or is this sufficient?
var transientErrorStringMatches = []string{
	"com.databricks.backend.manager.util.UnknownWorkerEnvironmentException",
	"does not have any associated worker environments",
}
67
75
68
76
// Setup initializes the client
69
77
func (c * DBApiClientConfig ) Setup () {
70
78
if c .TimeoutSeconds == 0 {
71
79
c .TimeoutSeconds = 60
72
80
}
73
- c .client = http.Client {
74
- Timeout : time .Duration (time .Duration (c .TimeoutSeconds ) * time .Second ),
75
- Transport : & http.Transport {
76
- TLSClientConfig : & tls.Config {
77
- InsecureSkipVerify : c .InsecureSkipVerify ,
81
+ // Set up a retryable HTTP Client to handle cases where the service returns
82
+ // a transient error on initial creation
83
+ retryDelayDuration := 10 * time .Second
84
+ retryMaximumDuration := 5 * time .Minute
85
+ c .client = & retryablehttp.Client {
86
+ HTTPClient : & http.Client {
87
+ Timeout : time .Duration (time .Duration (c .TimeoutSeconds ) * time .Second ),
88
+ Transport : & http.Transport {
89
+ TLSClientConfig : & tls.Config {
90
+ InsecureSkipVerify : c .InsecureSkipVerify ,
91
+ },
78
92
},
79
93
},
94
+ CheckRetry : checkHTTPRetry ,
95
+ // Using a linear retry rather than the default exponential retry
96
+ // as the creation condition is normally passed after 30-40 seconds
97
+ // Setting the retry interval to 10 seconds. Setting RetryWaitMin and RetryWaitMax
98
+ // to the same value removes jitter (which would be useful in a high-volume traffic scenario
99
+ // but wouldn't add much here)
100
+ Backoff : retryablehttp .LinearJitterBackoff ,
101
+ RetryWaitMin : retryDelayDuration ,
102
+ RetryWaitMax : retryDelayDuration ,
103
+ RetryMax : int (retryMaximumDuration / retryDelayDuration ),
104
+ }
105
+ }
106
+
107
+ // checkHTTPRetry inspects HTTP errors from the Databricks API for known transient errors on Workspace creation
108
+ func checkHTTPRetry (ctx context.Context , resp * http.Response , err error ) (bool , error ) {
109
+ if resp == nil {
110
+ // If response is nil we can't make retry choices.
111
+ // In this case don't retry and return the original error from httpclient
112
+ return false , err
113
+ }
114
+ if resp .StatusCode >= 400 {
115
+ log .Printf ("Failed request detected. Status Code: %v\n " , resp .StatusCode )
116
+ // reading the body means that the caller cannot read it themselves
117
+ // But that's ok because we've hit an error case
118
+ // Our job now is to
119
+ // - capture the error and return it
120
+ // - determine if the error is retryable
121
+
122
+ body , err := ioutil .ReadAll (resp .Body )
123
+ if err != nil {
124
+ return false , err
125
+ }
126
+
127
+ var errorBody DBApiErrorBody
128
+ err = json .Unmarshal (body , & errorBody )
129
+ if err != nil {
130
+ return false , fmt .Errorf ("Response from server (%d) %s: %v" , resp .StatusCode , string (body ), err )
131
+ }
132
+ dbAPIError := DBApiError {
133
+ ErrorBody : & errorBody ,
134
+ StatusCode : resp .StatusCode ,
135
+ Err : fmt .Errorf ("Response from server %s" , string (body )),
136
+ }
137
+ for _ , substring := range transientErrorStringMatches {
138
+ if strings .Contains (errorBody .Message , substring ) {
139
+ log .Println ("Failed request detected: Retryable type found. Attempting retry..." )
140
+ return true , dbAPIError
141
+ }
142
+ }
143
+ return false , dbAPIError
80
144
}
145
+ return false , nil
81
146
}
82
147
83
- func (c DBApiClientConfig ) getAuthHeader () map [string ]string {
148
+ func (c * DBApiClientConfig ) getAuthHeader () map [string ]string {
84
149
auth := make (map [string ]string )
85
150
if c .AuthType == BasicAuth {
86
151
auth ["Authorization" ] = "Basic " + c .Token
@@ -91,7 +156,7 @@ func (c DBApiClientConfig) getAuthHeader() map[string]string {
91
156
return auth
92
157
}
93
158
94
- func (c DBApiClientConfig ) getUserAgentHeader () map [string ]string {
159
+ func (c * DBApiClientConfig ) getUserAgentHeader () map [string ]string {
95
160
if reflect .ValueOf (c .UserAgent ).IsZero () {
96
161
return map [string ]string {
97
162
"User-Agent" : "databricks-go-client-sdk" ,
@@ -102,7 +167,7 @@ func (c DBApiClientConfig) getUserAgentHeader() map[string]string {
102
167
}
103
168
}
104
169
105
- func (c DBApiClientConfig ) getDefaultHeaders () map [string ]string {
170
+ func (c * DBApiClientConfig ) getDefaultHeaders () map [string ]string {
106
171
auth := c .getAuthHeader ()
107
172
userAgent := c .getUserAgentHeader ()
108
173
@@ -119,7 +184,7 @@ func (c DBApiClientConfig) getDefaultHeaders() map[string]string {
119
184
return defaultHeaders
120
185
}
121
186
122
- func (c DBApiClientConfig ) getRequestURI (path string , apiVersion string ) (string , error ) {
187
+ func (c * DBApiClientConfig ) getRequestURI (path string , apiVersion string ) (string , error ) {
123
188
var apiVersionString string
124
189
if apiVersion == "" {
125
190
apiVersionString = "2.0"
@@ -189,6 +254,9 @@ func PerformQuery(config *DBApiClientConfig, method, path string, apiVersion str
189
254
}
190
255
}
191
256
requestHeaders := config .getDefaultHeaders ()
257
+ if config .client == nil {
258
+ config .Setup ()
259
+ }
192
260
193
261
if len (headers ) > 0 {
194
262
for k , v := range headers {
@@ -221,7 +289,7 @@ func PerformQuery(config *DBApiClientConfig, method, path string, apiVersion str
221
289
auditNonGetPayload (method , requestURL , data , secretsMask )
222
290
}
223
291
224
- request , err := http .NewRequest (method , requestURL , bytes .NewBuffer (requestBody ))
292
+ request , err := retryablehttp .NewRequest (method , requestURL , bytes .NewBuffer (requestBody ))
225
293
if err != nil {
226
294
return nil , err
227
295
}
@@ -244,19 +312,8 @@ func PerformQuery(config *DBApiClientConfig, method, path string, apiVersion str
244
312
if err != nil {
245
313
return nil , err
246
314
}
247
-
248
- if resp .StatusCode >= 400 {
249
- var errorBody DBApiErrorBody
250
- err = json .Unmarshal (body , & errorBody )
251
- if err != nil {
252
- return nil , fmt .Errorf ("Response from server (%d) %s" , resp .StatusCode , string (body ))
253
- }
254
- return nil , DBApiError {
255
- ErrorBody : & errorBody ,
256
- StatusCode : resp .StatusCode ,
257
- Err : fmt .Errorf ("Response from server %s" , string (body )),
258
- }
259
- }
315
+ // No status-code check is needed here: the retryablehttp.Client's CheckRetry
316
+ // callback (checkHTTPRetry) already returns an error for responses with status >= 400
260
317
261
318
return body , nil
262
319
}
0 commit comments