Skip to content

Commit a2e436e

Browse files
tomasonjoeaswee
andauthored
community[patch]: Fix strict mode comparison and formatting for llm graph transformer (#4988)
* Fix strict mode comparison and formatting for llm graph transformer * Map only once, extend serializable. * formatting * lc namespace * formattting --------- Co-authored-by: Anej Gorkič <[email protected]>
1 parent be4f8b7 commit a2e436e

File tree

3 files changed

+77
-14
lines changed

3 files changed

+77
-14
lines changed

libs/langchain-community/src/experimental/graph_transformers/llm.int.test.ts

+43-4
Original file line numberDiff line numberDiff line change
@@ -45,13 +45,52 @@ test("convertToGraphDocuments with allowed", async () => {
4545
expect(result).toEqual([
4646
new GraphDocument({
4747
nodes: [
48-
new Node({ id: "Elon Musk", type: "PERSON" }),
49-
new Node({ id: "OpenAI", type: "ORGANIZATION" }),
48+
new Node({ id: "Elon Musk", type: "Person" }),
49+
new Node({ id: "OpenAI", type: "Organization" }),
5050
],
5151
relationships: [
5252
new Relationship({
53-
source: new Node({ id: "Elon Musk", type: "PERSON" }),
54-
target: new Node({ id: "OpenAI", type: "ORGANIZATION" }),
53+
source: new Node({ id: "Elon Musk", type: "Person" }),
54+
target: new Node({ id: "OpenAI", type: "Organization" }),
55+
type: "SUES",
56+
}),
57+
],
58+
source: new Document({
59+
pageContent: "Elon Musk is suing OpenAI",
60+
metadata: {},
61+
}),
62+
}),
63+
]);
64+
});
65+
66+
test("convertToGraphDocuments with allowed lowercased", async () => {
67+
const model = new ChatOpenAI({
68+
temperature: 0,
69+
modelName: "gpt-4-turbo-preview",
70+
});
71+
72+
const llmGraphTransformer = new LLMGraphTransformer({
73+
llm: model,
74+
allowedNodes: ["Person", "Organization"],
75+
allowedRelationships: ["SUES"],
76+
});
77+
78+
const result = await llmGraphTransformer.convertToGraphDocuments([
79+
new Document({ pageContent: "Elon Musk is suing OpenAI" }),
80+
]);
81+
82+
console.log(JSON.stringify(result));
83+
84+
expect(result).toEqual([
85+
new GraphDocument({
86+
nodes: [
87+
new Node({ id: "Elon Musk", type: "Person" }),
88+
new Node({ id: "OpenAI", type: "Organization" }),
89+
],
90+
relationships: [
91+
new Relationship({
92+
source: new Node({ id: "Elon Musk", type: "Person" }),
93+
target: new Node({ id: "OpenAI", type: "Organization" }),
5594
type: "SUES",
5695
}),
5796
],

libs/langchain-community/src/experimental/graph_transformers/llm.ts

+27-7
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,13 @@ interface OptionalEnumFieldProps {
4747
fieldKwargs?: object;
4848
}
4949

50+
function toTitleCase(str: string): string {
51+
return str
52+
.split(" ")
53+
.map((w) => w[0].toUpperCase() + w.substring(1).toLowerCase())
54+
.join("");
55+
}
56+
5057
function createOptionalEnumType({
5158
enumValues = undefined,
5259
description = "",
@@ -122,7 +129,7 @@ function createSchema(allowedNodes: string[], allowedRelationships: string[]) {
122129
function mapToBaseNode(node: any): Node {
123130
return new Node({
124131
id: node.id,
125-
type: node.type.replace(" ", "_").toUpperCase(),
132+
type: toTitleCase(node.type),
126133
});
127134
}
128135

@@ -131,11 +138,11 @@ function mapToBaseRelationship(relationship: any): Relationship {
131138
return new Relationship({
132139
source: new Node({
133140
id: relationship.sourceNodeId,
134-
type: relationship.sourceNodeType.replace(" ", "_").toUpperCase(),
141+
type: toTitleCase(relationship.sourceNodeType),
135142
}),
136143
target: new Node({
137144
id: relationship.targetNodeId,
138-
type: relationship.targetNodeType.replace(" ", "_").toUpperCase(),
145+
type: toTitleCase(relationship.targetNodeType),
139146
}),
140147
type: relationship.relationshipType.replace(" ", "_").toUpperCase(),
141148
});
@@ -208,16 +215,29 @@ export class LLMGraphTransformer {
208215
(this.allowedNodes.length > 0 || this.allowedRelationships.length > 0)
209216
) {
210217
if (this.allowedNodes.length > 0) {
211-
nodes = nodes.filter((node) => this.allowedNodes.includes(node.type));
218+
const allowedNodesLowerCase = this.allowedNodes.map((node) =>
219+
node.toLowerCase()
220+
);
221+
222+
// For nodes, compare lowercased types
223+
nodes = nodes.filter((node) =>
224+
allowedNodesLowerCase.includes(node.type.toLowerCase())
225+
);
226+
227+
// For relationships, compare lowercased types for both source and target nodes
212228
relationships = relationships.filter(
213229
(rel) =>
214-
this.allowedNodes.includes(rel.source.type) &&
215-
this.allowedNodes.includes(rel.target.type)
230+
allowedNodesLowerCase.includes(rel.source.type.toLowerCase()) &&
231+
allowedNodesLowerCase.includes(rel.target.type.toLowerCase())
216232
);
217233
}
234+
218235
if (this.allowedRelationships.length > 0) {
236+
// For relationships, compare lowercased types
219237
relationships = relationships.filter((rel) =>
220-
this.allowedRelationships.includes(rel.type)
238+
this.allowedRelationships
239+
.map((rel) => rel.toLowerCase())
240+
.includes(rel.type.toLowerCase())
221241
);
222242
}
223243
}

libs/langchain-community/src/graphs/graph_document.ts

+7-3
Original file line numberDiff line numberDiff line change
@@ -60,14 +60,14 @@ export class Relationship extends Serializable {
6060
}
6161
}
6262

63-
export class GraphDocument extends Document {
63+
export class GraphDocument extends Serializable {
6464
nodes: Node[];
6565

6666
relationships: Relationship[];
6767

6868
source: Document;
6969

70-
lc_namespace = ["langchain", "graph", "document_node"];
70+
lc_namespace = ["langchain", "graph", "graph_document"];
7171

7272
constructor({
7373
nodes,
@@ -78,7 +78,11 @@ export class GraphDocument extends Document {
7878
relationships: Relationship[];
7979
source: Document;
8080
}) {
81-
super(source);
81+
super({
82+
nodes,
83+
relationships,
84+
source,
85+
});
8286
this.nodes = nodes;
8387
this.relationships = relationships;
8488
this.source = source;

0 commit comments

Comments
 (0)