Skip to content

Commit 7b88899

Browse files
tdejagerkonstin
authored andcommitted
Encapsulate middleware switching in MiddlewareStack
1 parent 6c23ff2 commit 7b88899

File tree

8 files changed

+76
-105
lines changed

8 files changed

+76
-105
lines changed

crates/uv-client/src/base_client.rs

+49-63
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ use std::env;
88
use std::fmt::Debug;
99
use std::ops::Deref;
1010
use std::path::Path;
11+
use std::slice::Iter;
1112
use std::sync::Arc;
1213
use tracing::debug;
1314
use uv_auth::AuthMiddleware;
@@ -20,69 +21,76 @@ use crate::linehaul::LineHaul;
2021
use crate::middleware::OfflineMiddleware;
2122
use crate::Connectivity;
2223

23-
#[derive(Clone, Default)]
24-
pub struct CustomMiddleware {
25-
pub middleware: Vec<Arc<dyn Middleware>>,
24+
/// Newtype to implement [`Debug`] on [`Middleware`].
25+
#[derive(Clone)]
26+
pub struct MiddlewareStack(Vec<Arc<dyn Middleware>>);
27+
28+
impl MiddlewareStack {
29+
/// Use a custom middleware stack, skipping the default retry and auth layers.
30+
// This function exists for rustlib users.
31+
pub fn custom(middleware_stack: Vec<Arc<dyn Middleware>>) -> Self {
32+
Self(middleware_stack)
33+
}
34+
35+
/// Initialize the middleware stack, using an [`ExponentialBackoff`] with the given number of
36+
/// retries and an [`AuthMiddleware`] layer with the given keyring provider.
37+
pub fn new(retries: u32, keyring: KeyringProviderType) -> Self {
38+
let retry_policy = ExponentialBackoff::builder().build_with_max_retries(retries);
39+
let retry_strategy = RetryTransientMiddleware::new_with_policy(retry_policy);
40+
41+
let auth_middleware = AuthMiddleware::new().with_keyring(keyring.to_provider());
42+
43+
Self(vec![Arc::new(retry_strategy), Arc::new(auth_middleware)])
44+
}
2645
}
27-
impl Debug for CustomMiddleware {
46+
47+
impl Default for MiddlewareStack {
48+
/// Initialize the default middleware stack, consisting of an [`ExponentialBackoff`] retry
49+
/// strategy and an [`AuthMiddleware`] layer not using the keyring.
50+
fn default() -> Self {
51+
Self::new(3, KeyringProviderType::default())
52+
}
53+
}
54+
55+
impl Debug for MiddlewareStack {
2856
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
2957
f.debug_struct("CustomMiddleware").finish()
3058
}
3159
}
3260

61+
impl<'a> IntoIterator for &'a MiddlewareStack {
62+
type Item = &'a Arc<dyn Middleware>;
63+
type IntoIter = Iter<'a, Arc<dyn Middleware>>;
64+
65+
fn into_iter(self) -> Self::IntoIter {
66+
self.0.iter()
67+
}
68+
}
69+
3370
/// A builder for an [`BaseClient`].
34-
#[derive(Debug, Clone)]
71+
#[derive(Debug, Clone, Default)]
3572
pub struct BaseClientBuilder<'a> {
36-
keyring: KeyringProviderType,
3773
native_tls: bool,
38-
retries: u32,
3974
connectivity: Connectivity,
4075
client: Option<Client>,
4176
markers: Option<&'a MarkerEnvironment>,
4277
platform: Option<&'a Platform>,
43-
custom_middleware: CustomMiddleware,
44-
}
45-
46-
impl Default for BaseClientBuilder<'_> {
47-
fn default() -> Self {
48-
Self::new()
49-
}
78+
middleware_stack: MiddlewareStack,
5079
}
5180

5281
impl BaseClientBuilder<'_> {
5382
pub fn new() -> Self {
54-
Self {
55-
keyring: KeyringProviderType::default(),
56-
native_tls: false,
57-
connectivity: Connectivity::Online,
58-
retries: 3,
59-
client: None,
60-
markers: None,
61-
platform: None,
62-
custom_middleware: CustomMiddleware::default(),
63-
}
83+
Self::default()
6484
}
6585
}
6686

6787
impl<'a> BaseClientBuilder<'a> {
68-
#[must_use]
69-
pub fn keyring(mut self, keyring_type: KeyringProviderType) -> Self {
70-
self.keyring = keyring_type;
71-
self
72-
}
73-
7488
#[must_use]
7589
pub fn connectivity(mut self, connectivity: Connectivity) -> Self {
7690
self.connectivity = connectivity;
7791
self
7892
}
7993

80-
#[must_use]
81-
pub fn retries(mut self, retries: u32) -> Self {
82-
self.retries = retries;
83-
self
84-
}
85-
8694
#[must_use]
8795
pub fn native_tls(mut self, native_tls: bool) -> Self {
8896
self.native_tls = native_tls;
@@ -108,14 +116,8 @@ impl<'a> BaseClientBuilder<'a> {
108116
}
109117

110118
#[must_use]
111-
pub fn add_middleware<M: Middleware>(mut self, middleware: M) -> Self {
112-
self.custom_middleware.middleware.push(Arc::new(middleware));
113-
self
114-
}
115-
116-
#[must_use]
117-
pub fn custom_middleware(mut self, custom_middleware: CustomMiddleware) -> Self {
118-
self.custom_middleware = custom_middleware;
119+
pub fn middleware_stack(mut self, middleware_stack: MiddlewareStack) -> Self {
120+
self.middleware_stack = middleware_stack;
119121
self
120122
}
121123

@@ -187,26 +189,10 @@ impl<'a> BaseClientBuilder<'a> {
187189
let client = match self.connectivity {
188190
Connectivity::Online => {
189191
let mut client = reqwest_middleware::ClientBuilder::new(client.clone());
190-
191-
// Use custom middleware instead if provided
192-
if !self.custom_middleware.middleware.is_empty() {
193-
for middleware in &self.custom_middleware.middleware {
194-
client = client.with_arc(middleware.clone());
195-
}
196-
client.build()
197-
} else {
198-
// Initialize the retry strategy.
199-
let retry_policy =
200-
ExponentialBackoff::builder().build_with_max_retries(self.retries);
201-
let retry_strategy = RetryTransientMiddleware::new_with_policy(retry_policy);
202-
let client = client.with(retry_strategy);
203-
204-
// Initialize the authentication middleware to set headers.
205-
let client =
206-
client.with(AuthMiddleware::new().with_keyring(self.keyring.to_provider()));
207-
208-
client.build()
192+
for middleware in &self.middleware_stack {
193+
client = client.with_arc(middleware.clone());
209194
}
195+
client.build()
210196
}
211197
Connectivity::Offline => reqwest_middleware::ClientBuilder::new(client.clone())
212198
.with(OfflineMiddleware)

crates/uv-client/src/lib.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
pub use base_client::{BaseClient, BaseClientBuilder};
1+
pub use base_client::{BaseClient, BaseClientBuilder, MiddlewareStack};
22
pub use cached_client::{CacheControl, CachedClient, CachedClientError, DataWithCachePolicy};
33
pub use error::{BetterReqwestError, Error, ErrorKind};
44
pub use flat_index::{FlatIndexClient, FlatIndexEntries, FlatIndexError};

crates/uv-client/src/registry_client.rs

+8-28
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@ use std::collections::BTreeMap;
22
use std::fmt::Debug;
33
use std::path::Path;
44
use std::str::FromStr;
5-
use std::sync::Arc;
65

76
use async_http_range_reader::AsyncHttpRangeReader;
87
use futures::{FutureExt, TryStreamExt};
98
use http::HeaderMap;
109
use reqwest::{Client, Response, StatusCode};
11-
use reqwest_middleware::Middleware;
1210
use serde::{Deserialize, Serialize};
1311
use tokio::io::AsyncReadExt;
1412
use tokio_util::compat::{FuturesAsyncReadCompatExt, TokioAsyncReadCompatExt};
@@ -24,10 +22,9 @@ use platform_tags::Platform;
2422
use pypi_types::{Metadata23, SimpleJson};
2523
use uv_cache::{Cache, CacheBucket, WheelCache};
2624
use uv_configuration::IndexStrategy;
27-
use uv_configuration::KeyringProviderType;
2825
use uv_normalize::PackageName;
2926

30-
use crate::base_client::CustomMiddleware;
27+
use crate::base_client::MiddlewareStack;
3128
use crate::base_client::{BaseClient, BaseClientBuilder};
3229
use crate::cached_client::CacheControl;
3330
use crate::html::SimpleHtml;
@@ -40,31 +37,27 @@ use crate::{CachedClient, CachedClientError, Error, ErrorKind};
4037
pub struct RegistryClientBuilder<'a> {
4138
index_urls: IndexUrls,
4239
index_strategy: IndexStrategy,
43-
keyring: KeyringProviderType,
4440
native_tls: bool,
45-
retries: u32,
4641
connectivity: Connectivity,
4742
cache: Cache,
4843
client: Option<Client>,
4944
markers: Option<&'a MarkerEnvironment>,
5045
platform: Option<&'a Platform>,
51-
custom_middleware: CustomMiddleware,
46+
middleware_stack: MiddlewareStack,
5247
}
5348

5449
impl RegistryClientBuilder<'_> {
5550
pub fn new(cache: Cache) -> Self {
5651
Self {
5752
index_urls: IndexUrls::default(),
5853
index_strategy: IndexStrategy::default(),
59-
keyring: KeyringProviderType::default(),
6054
native_tls: false,
6155
cache,
6256
connectivity: Connectivity::Online,
63-
retries: 3,
6457
client: None,
6558
markers: None,
6659
platform: None,
67-
custom_middleware: CustomMiddleware::default(),
60+
middleware_stack: MiddlewareStack::default(),
6861
}
6962
}
7063
}
@@ -82,24 +75,12 @@ impl<'a> RegistryClientBuilder<'a> {
8275
self
8376
}
8477

85-
#[must_use]
86-
pub fn keyring(mut self, keyring_type: KeyringProviderType) -> Self {
87-
self.keyring = keyring_type;
88-
self
89-
}
90-
9178
#[must_use]
9279
pub fn connectivity(mut self, connectivity: Connectivity) -> Self {
9380
self.connectivity = connectivity;
9481
self
9582
}
9683

97-
#[must_use]
98-
pub fn retries(mut self, retries: u32) -> Self {
99-
self.retries = retries;
100-
self
101-
}
102-
10384
#[must_use]
10485
pub fn native_tls(mut self, native_tls: bool) -> Self {
10586
self.native_tls = native_tls;
@@ -131,8 +112,8 @@ impl<'a> RegistryClientBuilder<'a> {
131112
}
132113

133114
#[must_use]
134-
pub fn add_middleware<M: Middleware>(mut self, middleware: M) -> Self {
135-
self.custom_middleware.middleware.push(Arc::new(middleware));
115+
pub fn middleware_stack(mut self, middleware_stack: MiddlewareStack) -> Self {
116+
self.middleware_stack = middleware_stack;
136117
self
137118
}
138119

@@ -153,11 +134,9 @@ impl<'a> RegistryClientBuilder<'a> {
153134
}
154135

155136
let client = builder
156-
.retries(self.retries)
157137
.connectivity(self.connectivity)
158138
.native_tls(self.native_tls)
159-
.custom_middleware(self.custom_middleware)
160-
.keyring(self.keyring)
139+
.middleware_stack(self.middleware_stack)
161140
.build();
162141

163142
let timeout = client.timeout();
@@ -859,9 +838,10 @@ impl MediaType {
859838
}
860839
}
861840

862-
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
841+
#[derive(Debug, Copy, Clone, Eq, PartialEq, Default)]
863842
pub enum Connectivity {
864843
/// Allow access to the network.
844+
#[default]
865845
Online,
866846

867847
/// Do not allow access to the network.

crates/uv/src/commands/pip/compile.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ use pypi_types::Metadata23;
2626
use requirements_txt::EditableRequirement;
2727
use uv_auth::store_credentials_from_url;
2828
use uv_cache::Cache;
29-
use uv_client::{BaseClientBuilder, Connectivity, FlatIndexClient, RegistryClientBuilder};
29+
use uv_client::{
30+
BaseClientBuilder, Connectivity, FlatIndexClient, MiddlewareStack, RegistryClientBuilder,
31+
};
3032
use uv_configuration::{
3133
Concurrency, ConfigSettings, Constraints, IndexStrategy, NoBinary, NoBuild, Overrides,
3234
PreviewMode, SetupPyStrategy, Upgrade,
@@ -113,7 +115,7 @@ pub(crate) async fn pip_compile(
113115
let client_builder = BaseClientBuilder::new()
114116
.connectivity(connectivity)
115117
.native_tls(native_tls)
116-
.keyring(keyring_provider);
118+
.middleware_stack(MiddlewareStack::new(3, keyring_provider));
117119

118120
// Read all requirements from the provided sources.
119121
let RequirementsSpecification {
@@ -273,7 +275,7 @@ pub(crate) async fn pip_compile(
273275
.connectivity(connectivity)
274276
.index_urls(index_locations.index_urls())
275277
.index_strategy(index_strategy)
276-
.keyring(keyring_provider)
278+
.middleware_stack(MiddlewareStack::new(3, keyring_provider))
277279
.markers(&markers)
278280
.platform(interpreter.platform())
279281
.build();

crates/uv/src/commands/pip/install.rs

+4-3
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,8 @@ use pypi_types::Yanked;
2222
use uv_auth::store_credentials_from_url;
2323
use uv_cache::Cache;
2424
use uv_client::{
25-
BaseClientBuilder, Connectivity, FlatIndexClient, RegistryClient, RegistryClientBuilder,
25+
BaseClientBuilder, Connectivity, FlatIndexClient, MiddlewareStack, RegistryClient,
26+
RegistryClientBuilder,
2627
};
2728
use uv_configuration::{
2829
Concurrency, ConfigSettings, Constraints, IndexStrategy, NoBinary, NoBuild, Overrides,
@@ -98,7 +99,7 @@ pub(crate) async fn pip_install(
9899
let client_builder = BaseClientBuilder::new()
99100
.connectivity(connectivity)
100101
.native_tls(native_tls)
101-
.keyring(keyring_provider);
102+
.middleware_stack(MiddlewareStack::new(3, keyring_provider));
102103

103104
// Read all requirements from the provided sources.
104105
let RequirementsSpecification {
@@ -287,7 +288,7 @@ pub(crate) async fn pip_install(
287288
.connectivity(connectivity)
288289
.index_urls(index_locations.index_urls())
289290
.index_strategy(index_strategy)
290-
.keyring(keyring_provider)
291+
.middleware_stack(MiddlewareStack::new(3, keyring_provider))
291292
.markers(&markers)
292293
.platform(interpreter.platform())
293294
.build();

crates/uv/src/commands/pip/sync.rs

+5-3
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,9 @@ use platform_tags::Tags;
1313
use pypi_types::Yanked;
1414
use uv_auth::store_credentials_from_url;
1515
use uv_cache::Cache;
16-
use uv_client::{BaseClientBuilder, Connectivity, FlatIndexClient, RegistryClientBuilder};
16+
use uv_client::{
17+
BaseClientBuilder, Connectivity, FlatIndexClient, MiddlewareStack, RegistryClientBuilder,
18+
};
1719
use uv_configuration::{
1820
Concurrency, ConfigSettings, IndexStrategy, NoBinary, NoBuild, PreviewMode, Reinstall,
1921
SetupPyStrategy,
@@ -74,7 +76,7 @@ pub(crate) async fn pip_sync(
7476
let client_builder = BaseClientBuilder::new()
7577
.connectivity(connectivity)
7678
.native_tls(native_tls)
77-
.keyring(keyring_provider);
79+
.middleware_stack(MiddlewareStack::new(3, keyring_provider));
7880

7981
// Read all requirements from the provided sources.
8082
let RequirementsSpecification {
@@ -213,7 +215,7 @@ pub(crate) async fn pip_sync(
213215
.connectivity(connectivity)
214216
.index_urls(index_locations.index_urls())
215217
.index_strategy(index_strategy)
216-
.keyring(keyring_provider)
218+
.middleware_stack(MiddlewareStack::new(3, keyring_provider))
217219
.markers(venv.interpreter().markers())
218220
.platform(venv.interpreter().platform())
219221
.build();

crates/uv/src/commands/pip/uninstall.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ use tracing::debug;
88
use distribution_types::{InstalledMetadata, Name, Requirement, UnresolvedRequirement};
99
use pep508_rs::UnnamedRequirement;
1010
use uv_cache::Cache;
11-
use uv_client::{BaseClientBuilder, Connectivity};
11+
use uv_client::{BaseClientBuilder, Connectivity, MiddlewareStack};
1212
use uv_configuration::{KeyringProviderType, PreviewMode};
1313
use uv_fs::Simplified;
1414
use uv_interpreter::{PythonEnvironment, Target};
@@ -36,7 +36,7 @@ pub(crate) async fn pip_uninstall(
3636
let client_builder = BaseClientBuilder::new()
3737
.connectivity(connectivity)
3838
.native_tls(native_tls)
39-
.keyring(keyring_provider);
39+
.middleware_stack(MiddlewareStack::new(3, keyring_provider));
4040

4141
// Read all requirements from the provided sources.
4242
let spec =

0 commit comments

Comments
 (0)