@@ -47,6 +47,13 @@ interface OptionalEnumFieldProps {
47
47
fieldKwargs ?: object ;
48
48
}
49
49
50
+ function toTitleCase ( str : string ) : string {
51
+ return str
52
+ . split ( " " )
53
+ . map ( ( w ) => w [ 0 ] . toUpperCase ( ) + w . substring ( 1 ) . toLowerCase ( ) )
54
+ . join ( "" ) ;
55
+ }
56
+
50
57
function createOptionalEnumType ( {
51
58
enumValues = undefined ,
52
59
description = "" ,
@@ -122,7 +129,7 @@ function createSchema(allowedNodes: string[], allowedRelationships: string[]) {
122
129
function mapToBaseNode ( node : any ) : Node {
123
130
return new Node ( {
124
131
id : node . id ,
125
- type : node . type . replace ( " " , "_" ) . toUpperCase ( ) ,
132
+ type : toTitleCase ( node . type ) ,
126
133
} ) ;
127
134
}
128
135
@@ -131,11 +138,11 @@ function mapToBaseRelationship(relationship: any): Relationship {
131
138
return new Relationship ( {
132
139
source : new Node ( {
133
140
id : relationship . sourceNodeId ,
134
- type : relationship . sourceNodeType . replace ( " " , "_" ) . toUpperCase ( ) ,
141
+ type : toTitleCase ( relationship . sourceNodeType ) ,
135
142
} ) ,
136
143
target : new Node ( {
137
144
id : relationship . targetNodeId ,
138
- type : relationship . targetNodeType . replace ( " " , "_" ) . toUpperCase ( ) ,
145
+ type : toTitleCase ( relationship . targetNodeType ) ,
139
146
} ) ,
140
147
type : relationship . relationshipType . replace ( " " , "_" ) . toUpperCase ( ) ,
141
148
} ) ;
@@ -208,16 +215,29 @@ export class LLMGraphTransformer {
208
215
( this . allowedNodes . length > 0 || this . allowedRelationships . length > 0 )
209
216
) {
210
217
if ( this . allowedNodes . length > 0 ) {
211
- nodes = nodes . filter ( ( node ) => this . allowedNodes . includes ( node . type ) ) ;
218
+ const allowedNodesLowerCase = this . allowedNodes . map ( ( node ) =>
219
+ node . toLowerCase ( )
220
+ ) ;
221
+
222
+ // For nodes, compare lowercased types
223
+ nodes = nodes . filter ( ( node ) =>
224
+ allowedNodesLowerCase . includes ( node . type . toLowerCase ( ) )
225
+ ) ;
226
+
227
+ // For relationships, compare lowercased types for both source and target nodes
212
228
relationships = relationships . filter (
213
229
( rel ) =>
214
- this . allowedNodes . includes ( rel . source . type ) &&
215
- this . allowedNodes . includes ( rel . target . type )
230
+ allowedNodesLowerCase . includes ( rel . source . type . toLowerCase ( ) ) &&
231
+ allowedNodesLowerCase . includes ( rel . target . type . toLowerCase ( ) )
216
232
) ;
217
233
}
234
+
218
235
if ( this . allowedRelationships . length > 0 ) {
236
+ // For relationships, compare lowercased types
219
237
relationships = relationships . filter ( ( rel ) =>
220
- this . allowedRelationships . includes ( rel . type )
238
+ this . allowedRelationships
239
+ . map ( ( rel ) => rel . toLowerCase ( ) )
240
+ . includes ( rel . type . toLowerCase ( ) )
221
241
) ;
222
242
}
223
243
}
0 commit comments