Skip to content

Commit 34b3ed4

Browse files
authored
Merge pull request #27 from apollographql/jeffrey/support-passing-in-a-custom-scalar-map-file
add support for passing in a custom_scalar_map file
2 parents 0fa9f4b + 7aae84b commit 34b3ed4

File tree

6 files changed

+265
-28
lines changed

6 files changed

+265
-28
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
use crate::errors::ServerError;
2+
use rmcp::{
3+
schemars::schema::{Schema, SchemaObject, SingleOrVec},
4+
serde_json,
5+
};
6+
use std::{collections::HashMap, path::PathBuf, str::FromStr};
7+
8+
impl FromStr for CustomScalarMap {
9+
type Err = ServerError;
10+
11+
fn from_str(string_custom_scalar_file: &str) -> Result<Self, Self::Err> {
12+
// Parse the string into an initial map of serde_json::Values
13+
let parsed_custom_scalar_file: serde_json::Map<String, serde_json::Value> =
14+
serde_json::from_str(string_custom_scalar_file)
15+
.map_err(ServerError::CustomScalarConfig)?;
16+
17+
// Validate each of the values in the map and coerce into schemars::schema::SchemaObject
18+
let custom_scalar_map = parsed_custom_scalar_file
19+
.into_iter()
20+
.map(|(key, value)| {
21+
let value_string = value.to_string();
22+
// The only way I could find to do this was to reparse it.
23+
let schema: SchemaObject = serde_json::from_str(value_string.as_str())
24+
.map_err(ServerError::CustomScalarConfig)?;
25+
26+
if has_invalid_schema(&Schema::Object(schema.clone())) {
27+
Err(ServerError::CustomScalarJsonSchema(value))
28+
} else {
29+
Ok((key, schema))
30+
}
31+
})
32+
.collect::<Result<_, _>>()?;
33+
34+
// panic!("hello2! {:?}", parsed_custom_scalar_file);
35+
36+
Ok::<_, ServerError>(CustomScalarMap(custom_scalar_map))
37+
}
38+
}
39+
40+
impl TryFrom<&PathBuf> for CustomScalarMap {
41+
type Error = ServerError;
42+
43+
fn try_from(file_path_buf: &PathBuf) -> Result<Self, Self::Error> {
44+
let custom_scalars_config_path = file_path_buf.as_path();
45+
tracing::info!(custom_scalars_config=?custom_scalars_config_path, "Loading custom_scalars_config");
46+
let string_custom_scalar_file = std::fs::read_to_string(custom_scalars_config_path)?;
47+
CustomScalarMap::from_str(string_custom_scalar_file.as_str())
48+
}
49+
}
50+
51+
#[derive(Debug)]
52+
pub struct CustomScalarMap(HashMap<String, SchemaObject>);
53+
54+
impl CustomScalarMap {
55+
pub fn get(&self, key: &str) -> Option<&SchemaObject> {
56+
self.0.get(key)
57+
}
58+
}
59+
60+
// Unknown keys will be put into "extensions" in the schema object, check for those and consider those invalid
61+
fn has_invalid_schema(schema: &Schema) -> bool {
62+
match schema {
63+
Schema::Object(schema_object) => {
64+
!schema_object.extensions.is_empty()
65+
|| schema_object
66+
.object
67+
.as_ref()
68+
.is_some_and(|object| object.properties.values().any(has_invalid_schema))
69+
|| schema_object.array.as_ref().is_some_and(|object| {
70+
object.items.as_ref().is_some_and(|items| match items {
71+
SingleOrVec::Single(item) => has_invalid_schema(item),
72+
SingleOrVec::Vec(items) => items.iter().any(has_invalid_schema),
73+
})
74+
})
75+
}
76+
Schema::Bool(_) => false,
77+
}
78+
}
79+
80+
#[cfg(test)]
81+
mod tests {
82+
use std::{
83+
collections::{BTreeMap, HashMap},
84+
str::FromStr,
85+
};
86+
87+
use rmcp::schemars::schema::{
88+
InstanceType, ObjectValidation, Schema, SchemaObject, SingleOrVec,
89+
};
90+
91+
use crate::custom_scalar_map::CustomScalarMap;
92+
93+
#[test]
94+
fn empty_file() {
95+
let result = CustomScalarMap::from_str("").err().unwrap();
96+
97+
insta::assert_debug_snapshot!(result, @r###"
98+
CustomScalarConfig(
99+
Error("EOF while parsing a value", line: 1, column: 0),
100+
)
101+
"###)
102+
}
103+
104+
#[test]
105+
fn only_spaces() {
106+
let result = CustomScalarMap::from_str(" ").err().unwrap();
107+
108+
insta::assert_debug_snapshot!(result, @r###"
109+
CustomScalarConfig(
110+
Error("EOF while parsing a value", line: 1, column: 4),
111+
)
112+
"###)
113+
}
114+
115+
#[test]
116+
fn invalid_json() {
117+
let result = CustomScalarMap::from_str("Hello: }").err().unwrap();
118+
119+
insta::assert_debug_snapshot!(result, @r###"
120+
CustomScalarConfig(
121+
Error("expected value", line: 1, column: 1),
122+
)
123+
"###)
124+
}
125+
126+
#[test]
127+
fn invalid_simple_schema() {
128+
let result = CustomScalarMap::from_str(
129+
r###"{
130+
"custom": {
131+
"test": true
132+
}
133+
}"###,
134+
)
135+
.err()
136+
.unwrap();
137+
138+
insta::assert_debug_snapshot!(result, @r###"
139+
CustomScalarJsonSchema(
140+
Object {
141+
"test": Bool(true),
142+
},
143+
)
144+
"###)
145+
}
146+
147+
#[test]
148+
fn invalid_complex_schema() {
149+
let result = CustomScalarMap::from_str(
150+
r###"{
151+
"custom": {
152+
"type": "object",
153+
"properties": {
154+
"test": {
155+
"test": true
156+
}
157+
}
158+
}
159+
}"###,
160+
)
161+
.err()
162+
.unwrap();
163+
164+
insta::assert_debug_snapshot!(result, @r###"
165+
CustomScalarJsonSchema(
166+
Object {
167+
"properties": Object {
168+
"test": Object {
169+
"test": Bool(true),
170+
},
171+
},
172+
"type": String("object"),
173+
},
174+
)
175+
"###)
176+
}
177+
178+
#[test]
179+
fn valid_schema() {
180+
let result = CustomScalarMap::from_str(
181+
r###"
182+
{
183+
"simple": {
184+
"type": "string"
185+
},
186+
"complex": {
187+
"type": "object",
188+
"properties": { "name": { "type": "string" } }
189+
}
190+
}
191+
"###,
192+
)
193+
.unwrap()
194+
.0;
195+
196+
let expected_data = HashMap::from_iter([
197+
(
198+
"simple".to_string(),
199+
SchemaObject {
200+
instance_type: Some(SingleOrVec::Single(Box::new(InstanceType::String))),
201+
..Default::default()
202+
},
203+
),
204+
(
205+
"complex".to_string(),
206+
SchemaObject {
207+
instance_type: Some(SingleOrVec::Single(Box::new(InstanceType::Object))),
208+
object: Some(Box::new(ObjectValidation {
209+
properties: BTreeMap::from_iter([(
210+
"name".to_string(),
211+
Schema::Object(SchemaObject {
212+
instance_type: Some(SingleOrVec::Single(Box::new(
213+
InstanceType::String,
214+
))),
215+
..Default::default()
216+
}),
217+
)]),
218+
..Default::default()
219+
})),
220+
..Default::default()
221+
},
222+
),
223+
]);
224+
assert_eq!(result, expected_data);
225+
}
226+
}

crates/mcp-apollo-server/src/errors.rs

+6
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,12 @@ pub enum ServerError {
5353

5454
#[error("invalid header: {0}")]
5555
Header(String),
56+
57+
#[error("invalid custom_scalar_config: {0}")]
58+
CustomScalarConfig(serde_json::Error),
59+
60+
#[error("invalid json schema: {0}")]
61+
CustomScalarJsonSchema(serde_json::Value),
5662
}
5763

5864
/// An MCP tool error

crates/mcp-apollo-server/src/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
pub mod custom_scalar_map;
12
pub mod errors;
23
mod graphql;
34
mod introspection;

crates/mcp-apollo-server/src/main.rs

+10
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ use apollo_compiler::Schema;
33
use clap::builder::Styles;
44
use clap::builder::styling::{AnsiColor, Effects};
55
use clap::{Parser, ValueEnum};
6+
use mcp_apollo_server::custom_scalar_map::CustomScalarMap;
67
use mcp_apollo_server::errors::ServerError;
78
use mcp_apollo_server::server::Server;
89
use rmcp::ServiceExt;
@@ -35,6 +36,10 @@ struct Args {
3536
#[clap(long, short = 's')]
3637
schema: PathBuf,
3738

39+
/// The path to the GraphQL custom_scalars_config file
40+
#[clap(long, short = 'c', required = false)]
41+
custom_scalars_config: Option<PathBuf>,
42+
3843
/// The GraphQL endpoint the server will invoke
3944
#[clap(long, short = 'e', default_value = "http://127.0.0.1:4000")]
4045
endpoint: String,
@@ -102,6 +107,11 @@ async fn main() -> anyhow::Result<()> {
102107
.operations(args.operations)
103108
.headers(args.headers)
104109
.introspection(args.introspection)
110+
.and_custom_scalar_map(
111+
args.custom_scalars_config
112+
.map(|custom_scalars_config| CustomScalarMap::try_from(&custom_scalars_config))
113+
.transpose()?,
114+
)
105115
.and_persisted_query_manifest(
106116
args.pq_manifest
107117
.map(

crates/mcp-apollo-server/src/operations.rs

+14-26
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
use std::collections::HashMap;
2-
31
use apollo_compiler::ast::{FragmentDefinition, Selection};
42
use apollo_compiler::{
53
Name, Node, Schema as GraphqlSchema,
@@ -17,6 +15,7 @@ use rmcp::{
1715
use rover_copy::pq_manifest::ApolloPersistedQueryManifest;
1816
use serde::Serialize;
1917

18+
use crate::custom_scalar_map::CustomScalarMap;
2019
use crate::errors::{McpError, OperationError};
2120
use crate::graphql;
2221
use crate::tree_shake::TreeShaker;
@@ -42,7 +41,7 @@ impl Operation {
4241
pub fn from_document(
4342
source_text: &str,
4443
graphql_schema: &GraphqlSchema,
45-
custom_scalar_map: Option<&HashMap<String, SchemaObject>>,
44+
custom_scalar_map: Option<&CustomScalarMap>,
4645
) -> Result<Self, OperationError> {
4746
let document = Parser::new()
4847
.parse_ast(source_text, "operation.graphql")
@@ -110,14 +109,15 @@ impl Operation {
110109
pub fn from_manifest(
111110
schema: &GraphqlSchema,
112111
manifest: ApolloPersistedQueryManifest,
112+
custom_scalar_map: Option<&CustomScalarMap>,
113113
) -> Result<Vec<Self>, OperationError> {
114114
manifest
115115
.operations
116116
.into_iter()
117117
.map(|pq| {
118118
tracing::info!(pesisted_query = pq.name, "Loading persisted query");
119119

120-
Self::from_document(&pq.body, schema, None)
120+
Self::from_document(&pq.body, schema, custom_scalar_map)
121121
})
122122
.collect::<Result<Vec<_>, _>>()
123123
}
@@ -230,7 +230,7 @@ impl Operation {
230230
fn get_json_schema(
231231
operation: &Node<OperationDefinition>,
232232
graphql_schema: &GraphqlSchema,
233-
custom_scalar_map: Option<&HashMap<String, SchemaObject>>,
233+
custom_scalar_map: Option<&CustomScalarMap>,
234234
) -> RootSchema {
235235
let mut obj = ObjectValidation::default();
236236

@@ -321,7 +321,7 @@ fn type_to_schema(
321321
description: Option<String>,
322322
variable_type: &Type,
323323
graphql_schema: &GraphqlSchema,
324-
custom_scalar_map: Option<&HashMap<String, SchemaObject>>,
324+
custom_scalar_map: Option<&CustomScalarMap>,
325325
) -> Schema {
326326
match variable_type {
327327
Type::NonNullNamed(named) | Type::Named(named) => match named.as_str() {
@@ -459,20 +459,13 @@ impl graphql::Executable for Operation {
459459

460460
#[cfg(test)]
461461
mod tests {
462-
use std::{
463-
collections::{HashMap, HashSet},
464-
sync::LazyLock,
465-
};
462+
use std::{collections::HashSet, str::FromStr, sync::LazyLock};
466463

467464
use apollo_compiler::{Schema, parser::Parser, validation::Valid};
468-
use rmcp::{
469-
model::Tool,
470-
schemars::schema::{InstanceType, SchemaObject, SingleOrVec},
471-
serde_json,
472-
};
465+
use rmcp::{model::Tool, serde_json};
473466
use rover_copy::pq_manifest::ApolloPersistedQueryManifest;
474467

475-
use crate::operations::Operation;
468+
use crate::{custom_scalar_map::CustomScalarMap, operations::Operation};
476469

477470
// Example schema for tests
478471
static SCHEMA: LazyLock<Valid<Schema>> = LazyLock::new(|| {
@@ -1069,7 +1062,7 @@ mod tests {
10691062
let operation = Operation::from_document(
10701063
"query QueryName($id: RealCustomScalar) { id }",
10711064
&SCHEMA,
1072-
Some(&HashMap::new()),
1065+
Some(&CustomScalarMap::from_str("{}").unwrap()),
10731066
)
10741067
.unwrap();
10751068
let tool = Tool::from(operation);
@@ -1092,18 +1085,13 @@ mod tests {
10921085

10931086
#[test]
10941087
fn custom_scalar_with_map() {
1095-
let custom_scalar_map = HashMap::from([(
1096-
"RealCustomScalar".to_string(),
1097-
SchemaObject {
1098-
instance_type: Some(SingleOrVec::Single(Box::new(InstanceType::String))),
1099-
..Default::default()
1100-
},
1101-
)]);
1088+
let custom_scalar_map =
1089+
CustomScalarMap::from_str("{ \"RealCustomScalar\": { \"type\": \"string\" }}");
11021090

11031091
let operation = Operation::from_document(
11041092
"query QueryName($id: RealCustomScalar) { id }",
11051093
&SCHEMA,
1106-
Some(&custom_scalar_map),
1094+
custom_scalar_map.ok().as_ref(),
11071095
)
11081096
.unwrap();
11091097
let tool = Tool::from(operation);
@@ -1355,7 +1343,7 @@ mod tests {
13551343
}))
13561344
.expect("apollo pq should be valid");
13571345

1358-
let operations = Operation::from_manifest(&SCHEMA, apollo_pq.clone())
1346+
let operations = Operation::from_manifest(&SCHEMA, apollo_pq.clone(), None)
13591347
.expect("operations from manifest should parse");
13601348
assert_eq!(
13611349
operations

0 commit comments

Comments
 (0)