Skip to content

Commit 13e72ec

Browse files
authored
sse: fix regression in URL joining (#265)
* sse: fix regression in URL joining This was broken by #197. The URL join behaves like this: ``` $ let baseUrl = "https://example.com/sse"; $ new URL("?sessionId=x", baseUrl).href 'https://example.com/sse?sessionId=x' $ new URL("/?sessionId=x", baseUrl).href 'https://example.com/?sessionId=x' ``` The PR #197 did not take into account the relative URL Fixes #252 * Address review comments
1 parent 8a57765 commit 13e72ec

File tree

1 file changed

+67
-9
lines changed

1 file changed

+67
-9
lines changed

crates/rmcp/src/transport/sse_client.rs

Lines changed: 67 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,10 @@ impl<C: SseClient> SseClientTransport<C> {
121121

122122
let mut sse_stream = client.get_stream(sse_endpoint.clone(), None, None).await?;
123123
let message_endpoint = if let Some(endpoint) = config.use_message_endpoint.clone() {
124-
endpoint.parse::<http::Uri>()?
124+
let ep = endpoint.parse::<http::Uri>()?;
125+
let mut sse_endpoint_parts = sse_endpoint.clone().into_parts();
126+
sse_endpoint_parts.path_and_query = ep.into_parts().path_and_query;
127+
Uri::from_parts(sse_endpoint_parts)?
125128
} else {
126129
// wait the endpoint event
127130
loop {
@@ -132,17 +135,12 @@ impl<C: SseClient> SseClientTransport<C> {
132135
let Some("endpoint") = sse.event.as_deref() else {
133136
continue;
134137
};
135-
let sse_endpoint = sse.data.unwrap_or_default();
136-
break sse_endpoint.parse::<http::Uri>()?;
138+
let ep = sse.data.unwrap_or_default();
139+
140+
break message_endpoint(sse_endpoint.clone(), ep)?;
137141
}
138142
};
139143

140-
// sse: <authority><sse_pq> -> <authority><message_pq>
141-
let message_endpoint = {
142-
let mut sse_endpoint_parts = sse_endpoint.clone().into_parts();
143-
sse_endpoint_parts.path_and_query = message_endpoint.into_parts().path_and_query;
144-
Uri::from_parts(sse_endpoint_parts)?
145-
};
146144
let stream = Box::pin(SseAutoReconnectStream::new(
147145
sse_stream,
148146
SseClientReconnect {
@@ -160,6 +158,36 @@ impl<C: SseClient> SseClientTransport<C> {
160158
}
161159
}
162160

161+
fn message_endpoint(base: http::Uri, endpoint: String) -> Result<http::Uri, http::uri::InvalidUri> {
162+
// If endpoint is a full URL, parse and return it directly
163+
if endpoint.starts_with("http://") || endpoint.starts_with("https://") {
164+
return endpoint.parse::<http::Uri>();
165+
}
166+
167+
let mut base_parts = base.into_parts();
168+
let endpoint_clone = endpoint.clone();
169+
170+
if endpoint.starts_with("?") {
171+
// Query only - keep base path and append query
172+
if let Some(base_path_and_query) = &base_parts.path_and_query {
173+
let base_path = base_path_and_query.path();
174+
base_parts.path_and_query = Some(format!("{}{}", base_path, endpoint).parse()?);
175+
} else {
176+
base_parts.path_and_query = Some(format!("/{}", endpoint).parse()?);
177+
}
178+
} else {
179+
// Path (with optional query) - replace entire path_and_query
180+
let path_to_use = if endpoint.starts_with("/") {
181+
endpoint // Use absolute path as-is
182+
} else {
183+
format!("/{}", endpoint) // Make relative path absolute
184+
};
185+
base_parts.path_and_query = Some(path_to_use.parse()?);
186+
}
187+
188+
http::Uri::from_parts(base_parts).map_err(|_| endpoint_clone.parse::<http::Uri>().unwrap_err())
189+
}
190+
163191
#[derive(Debug, Clone)]
164192
pub struct SseClientConfig {
165193
/// client sse endpoint
@@ -188,3 +216,33 @@ impl Default for SseClientConfig {
188216
}
189217
}
190218
}
219+
220+
#[cfg(test)]
221+
mod tests {
222+
use super::*;
223+
224+
#[test]
225+
fn test_message_endpoint() {
226+
let base_url = "https://localhost/sse".parse::<http::Uri>().unwrap();
227+
228+
// Query only
229+
let result = message_endpoint(base_url.clone(), "?sessionId=x".to_string()).unwrap();
230+
assert_eq!(result.to_string(), "https://localhost/sse?sessionId=x");
231+
232+
// Relative path with query
233+
let result = message_endpoint(base_url.clone(), "mypath?sessionId=x".to_string()).unwrap();
234+
assert_eq!(result.to_string(), "https://localhost/mypath?sessionId=x");
235+
236+
// Absolute path with query
237+
let result = message_endpoint(base_url.clone(), "/xxx?sessionId=x".to_string()).unwrap();
238+
assert_eq!(result.to_string(), "https://localhost/xxx?sessionId=x");
239+
240+
// Full URL
241+
let result = message_endpoint(
242+
base_url.clone(),
243+
"http://example.com/xxx?sessionId=x".to_string(),
244+
)
245+
.unwrap();
246+
assert_eq!(result.to_string(), "http://example.com/xxx?sessionId=x");
247+
}
248+
}

0 commit comments

Comments
 (0)