Commit ce2fe36

Merge branch 'main' into xsn/llama_snippet
2 parents: d8dafa2 + 7004980

File tree: 8 files changed (+134, −52 lines)


packages/gguf/src/gguf.spec.ts

Lines changed: 17 additions & 11 deletions

@@ -1,4 +1,5 @@
-import { describe, expect, it } from "vitest";
+import { beforeAll, describe, expect, it } from "vitest";
+import type { GGUFParseOutput } from "./gguf";
 import { GGMLQuantizationType, gguf, ggufAllShards, parseGgufShardFilename } from "./gguf";
 import fs from "node:fs";

@@ -12,8 +13,19 @@ const URL_V1 =
 	"https://huggingface.co/tmadge/testing/resolve/66c078028d1ff92d7a9264a1590bc61ba6437933/tinyllamas-stories-260k-f32.gguf";
 const URL_SHARDED_GROK =
 	"https://huggingface.co/Arki05/Grok-1-GGUF/resolve/ecafa8d8eca9b8cd75d11a0d08d3a6199dc5a068/grok-1-IQ3_XS-split-00001-of-00009.gguf";
+const URL_BIG_METADATA = "https://huggingface.co/ngxson/test_gguf_models/resolve/main/gguf_test_big_metadata.gguf";

 describe("gguf", () => {
+	beforeAll(async () => {
+		// download the gguf for "load file" test, save to .cache directory
+		if (!fs.existsSync(".cache")) {
+			fs.mkdirSync(".cache");
+		}
+		const res = await fetch(URL_BIG_METADATA);
+		const arrayBuf = await res.arrayBuffer();
+		fs.writeFileSync(".cache/model.gguf", Buffer.from(arrayBuf));
+	});
+
 	it("should parse a llama2 7b", async () => {
 		const { metadata, tensorInfos } = await gguf(URL_LLAMA);

@@ -228,16 +240,10 @@ describe("gguf", () => {
 	});

 	it("should parse a local file", async () => {
-		// download the file and save to .cache folder
-		if (!fs.existsSync(".cache")) {
-			fs.mkdirSync(".cache");
-		}
-		const res = await fetch(URL_V1);
-		const arrayBuf = await res.arrayBuffer();
-		fs.writeFileSync(".cache/model.gguf", Buffer.from(arrayBuf));
-
-		const { metadata } = await gguf(".cache/model.gguf", { allowLocalFile: true });
-		expect(metadata).toMatchObject({ "general.name": "tinyllamas-stories-260k" });
+		const parsedGguf = await gguf(".cache/model.gguf", { allowLocalFile: true });
+		const { metadata } = parsedGguf as GGUFParseOutput<{ strict: false }>; // custom metadata arch, no need for typing
+		expect(metadata["dummy.1"]).toBeDefined(); // first metadata in the list
+		expect(metadata["dummy.32767"]).toBeDefined(); // last metadata in the list
 	});

 	it("should detect sharded gguf filename", async () => {

packages/tasks/package.json

Lines changed: 1 addition & 1 deletion

@@ -1,7 +1,7 @@
 {
 	"name": "@huggingface/tasks",
 	"packageManager": "[email protected]",
-	"version": "0.10.19",
+	"version": "0.10.20",
 	"description": "List of ML tasks for huggingface.co/tasks",
 	"repository": "https://github.com/huggingface/huggingface.js.git",
 	"publishConfig": {

packages/tasks/src/local-apps.ts

Lines changed: 4 additions & 3 deletions

@@ -38,14 +38,14 @@ export type LocalApp = {
 			/**
 			 * If the app supports deeplink, URL to open.
 			 */
-			deeplink: (model: ModelData) => URL;
+			deeplink: (model: ModelData, filepath?: string) => URL;
 	  }
 	| {
 			/**
 			 * And if not (mostly llama.cpp), snippet to copy/paste in your terminal
 			 * Supports the placeholder {{GGUF_FILE}} that will be replaced by the gguf file path or the list of available files.
 			 */
-			snippet: (model: ModelData) => Snippet | Snippet[];
+			snippet: (model: ModelData, filepath?: string) => string | string[] | Snippet | Snippet[];
 	  }
 );

@@ -118,7 +118,8 @@ export const LOCAL_APPS = {
 		docsUrl: "https://lmstudio.ai",
 		mainTask: "text-generation",
 		displayOnModelPage: isGgufModel,
-		deeplink: (model) => new URL(`lmstudio://open_from_hf?model=${model.id}`),
+		deeplink: (model, filepath) =>
+			new URL(`lmstudio://open_from_hf?model=${model.id}` + (filepath ? `&file=${filepath}` : "")),
 	},
 	jan: {
 		prettyLabel: "Jan",

packages/tasks/src/snippets/curl.ts

Lines changed: 19 additions & 1 deletion

@@ -10,6 +10,24 @@ export const snippetBasic = (model: ModelDataMinimal, accessToken: string): stri
 	-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}"
 `;

+export const snippetTextGeneration = (model: ModelDataMinimal, accessToken: string): string => {
+	if (model.config?.tokenizer_config?.chat_template) {
+		// Conversational model detected, so we display a code snippet that features the Messages API
+		return `curl 'https://api-inference.huggingface.co/models/${model.id}/v1/chat/completions' \\
+-H "Authorization: Bearer ${accessToken || `{API_TOKEN}`}" \\
+-H 'Content-Type: application/json' \\
+-d '{
+	"model": "${model.id}",
+	"messages": [{"role": "user", "content": "What is the capital of France?"}],
+	"max_tokens": 500,
+	"stream": false
+}'
+`;
+	} else {
+		return snippetBasic(model, accessToken);
+	}
+};
+
 export const snippetZeroShotClassification = (model: ModelDataMinimal, accessToken: string): string =>
 	`curl https://api-inference.huggingface.co/models/${model.id} \\
 	-X POST \\

@@ -35,7 +53,7 @@ export const curlSnippets: Partial<Record<PipelineType, (model: ModelDataMinimal
 	translation: snippetBasic,
 	summarization: snippetBasic,
 	"feature-extraction": snippetBasic,
-	"text-generation": snippetBasic,
+	"text-generation": snippetTextGeneration,
 	"text2text-generation": snippetBasic,
 	"fill-mask": snippetBasic,
 	"sentence-similarity": snippetBasic,

packages/tasks/src/snippets/inputs.ts

Lines changed: 25 additions & 25 deletions (re-indentation of the template literals only)

@@ -11,30 +11,30 @@ const inputsSummarization = () =>

 const inputsTableQuestionAnswering = () =>
 	`{
-	"query": "How many stars does the transformers repository have?",
-	"table": {
-		"Repository": ["Transformers", "Datasets", "Tokenizers"],
-		"Stars": ["36542", "4512", "3934"],
-		"Contributors": ["651", "77", "34"],
-		"Programming language": [
-			"Python",
-			"Python",
-			"Rust, Python and NodeJS"
-		]
-	}
-}`;
+		"query": "How many stars does the transformers repository have?",
+		"table": {
+			"Repository": ["Transformers", "Datasets", "Tokenizers"],
+			"Stars": ["36542", "4512", "3934"],
+			"Contributors": ["651", "77", "34"],
+			"Programming language": [
+				"Python",
+				"Python",
+				"Rust, Python and NodeJS"
+			]
+		}
+	}`;

 const inputsVisualQuestionAnswering = () =>
 	`{
-	"image": "cat.png",
-	"question": "What is in this image?"
-}`;
+		"image": "cat.png",
+		"question": "What is in this image?"
+	}`;

 const inputsQuestionAnswering = () =>
 	`{
-	"question": "What is my name?",
-	"context": "My name is Clara and I live in Berkeley."
-}`;
+		"question": "What is my name?",
+		"context": "My name is Clara and I live in Berkeley."
+	}`;

 const inputsTextClassification = () => `"I like you. I love you"`;

@@ -48,13 +48,13 @@ const inputsFillMask = (model: ModelDataMinimal) => `"The answer to the universe

 const inputsSentenceSimilarity = () =>
 	`{
-	"source_sentence": "That is a happy person",
-	"sentences": [
-		"That is a happy dog",
-		"That is a very happy person",
-		"Today is a sunny day"
-	]
-}`;
+		"source_sentence": "That is a happy person",
+		"sentences": [
+			"That is a happy dog",
+			"That is a very happy person",
+			"Today is a sunny day"
+		]
+	}`;

 const inputsFeatureExtraction = () => `"Today is a sunny day and I will get some ice cream."`;

packages/tasks/src/snippets/js.ts

Lines changed: 40 additions & 6 deletions

@@ -7,7 +7,10 @@ export const snippetBasic = (model: ModelDataMinimal, accessToken: string): stri
 	const response = await fetch(
 		"https://api-inference.huggingface.co/models/${model.id}",
 		{
-			headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}" },
+			headers: {
+				Authorization: "Bearer ${accessToken || `{API_TOKEN}`}",
+				"Content-Type": "application/json",
+			},
 			method: "POST",
 			body: JSON.stringify(data),
 		}

@@ -20,12 +23,34 @@ query({"inputs": ${getModelInputSnippet(model)}}).then((response) => {
 	console.log(JSON.stringify(response));
 });`;

+export const snippetTextGeneration = (model: ModelDataMinimal, accessToken: string): string => {
+	if (model.config?.tokenizer_config?.chat_template) {
+		// Conversational model detected, so we display a code snippet that features the Messages API
+		return `import { HfInference } from "@huggingface/inference";
+
+const inference = new HfInference("${accessToken || `{API_TOKEN}`}");
+
+for await (const chunk of inference.chatCompletionStream({
+	model: "${model.id}",
+	messages: [{ role: "user", content: "What is the capital of France?" }],
+	max_tokens: 500,
+})) {
+	process.stdout.write(chunk.choices[0]?.delta?.content || "");
+}
+`;
+	} else {
+		return snippetBasic(model, accessToken);
+	}
+};
 export const snippetZeroShotClassification = (model: ModelDataMinimal, accessToken: string): string =>
 	`async function query(data) {
 	const response = await fetch(
 		"https://api-inference.huggingface.co/models/${model.id}",
 		{
-			headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}" },
+			headers: {
+				Authorization: "Bearer ${accessToken || `{API_TOKEN}`}",
+				"Content-Type": "application/json",
+			},
 			method: "POST",
 			body: JSON.stringify(data),
 		}

@@ -45,7 +70,10 @@ export const snippetTextToImage = (model: ModelDataMinimal, accessToken: string)
 	const response = await fetch(
 		"https://api-inference.huggingface.co/models/${model.id}",
 		{
-			headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}" },
+			headers: {
+				Authorization: "Bearer ${accessToken || `{API_TOKEN}`}",
+				"Content-Type": "application/json",
+			},
 			method: "POST",
 			body: JSON.stringify(data),
 		}

@@ -62,7 +90,10 @@ export const snippetTextToAudio = (model: ModelDataMinimal, accessToken: string)
 	const response = await fetch(
 		"https://api-inference.huggingface.co/models/${model.id}",
 		{
-			headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}" },
+			headers: {
+				Authorization: "Bearer ${accessToken || `{API_TOKEN}`}",
+				"Content-Type": "application/json",
+			},
 			method: "POST",
 			body: JSON.stringify(data),
 		}

@@ -99,7 +130,10 @@ export const snippetFile = (model: ModelDataMinimal, accessToken: string): strin
 	const response = await fetch(
 		"https://api-inference.huggingface.co/models/${model.id}",
 		{
-			headers: { Authorization: "Bearer ${accessToken || `{API_TOKEN}`}" },
+			headers: {
+				Authorization: "Bearer ${accessToken || `{API_TOKEN}`}",
+				"Content-Type": "application/json",
+			},
 			method: "POST",
 			body: data,
 		}

@@ -122,7 +156,7 @@ export const jsSnippets: Partial<Record<PipelineType, (model: ModelDataMinimal,
 	translation: snippetBasic,
 	summarization: snippetBasic,
 	"feature-extraction": snippetBasic,
-	"text-generation": snippetBasic,
+	"text-generation": snippetTextGeneration,
 	"text2text-generation": snippetBasic,
 	"fill-mask": snippetBasic,
 	"sentence-similarity": snippetBasic,

packages/tasks/src/snippets/python.ts

Lines changed: 27 additions & 4 deletions

@@ -2,6 +2,22 @@ import type { PipelineType } from "../pipelines.js";
 import { getModelInputSnippet } from "./inputs.js";
 import type { ModelDataMinimal } from "./types.js";

+export const snippetConversational = (model: ModelDataMinimal, accessToken: string): string =>
+	`from huggingface_hub import InferenceClient
+
+client = InferenceClient(
+    "${model.id}",
+    token="${accessToken || "{API_TOKEN}"}",
+)
+
+for message in client.chat_completion(
+    messages=[{"role": "user", "content": "What is the capital of France?"}],
+    max_tokens=500,
+    stream=True,
+):
+    print(message.choices[0].delta.content, end="")
+`;
+
 export const snippetZeroShotClassification = (model: ModelDataMinimal): string =>
 	`def query(payload):
 	response = requests.post(API_URL, headers=headers, json=payload)

@@ -107,7 +123,7 @@ output = query({
 	"inputs": ${getModelInputSnippet(model)},
 })`;

-export const pythonSnippets: Partial<Record<PipelineType, (model: ModelDataMinimal) => string>> = {
+export const pythonSnippets: Partial<Record<PipelineType, (model: ModelDataMinimal, accessToken: string) => string>> = {
 	// Same order as in tasks/src/pipelines.ts
 	"text-classification": snippetBasic,
 	"token-classification": snippetBasic,

@@ -138,15 +154,22 @@ export const pythonSnippets: Partial<Record<PipelineType, (model: ModelDataMinim
 };

 export function getPythonInferenceSnippet(model: ModelDataMinimal, accessToken: string): string {
-	const body =
-		model.pipeline_tag && model.pipeline_tag in pythonSnippets ? pythonSnippets[model.pipeline_tag]?.(model) ?? "" : "";
+	if (model.pipeline_tag === "text-generation" && model.config?.tokenizer_config?.chat_template) {
+		// Conversational model detected, so we display a code snippet that features the Messages API
+		return snippetConversational(model, accessToken);
+	} else {
+		const body =
+			model.pipeline_tag && model.pipeline_tag in pythonSnippets
+				? pythonSnippets[model.pipeline_tag]?.(model, accessToken) ?? ""
+				: "";

-	return `import requests
+		return `import requests

 API_URL = "https://api-inference.huggingface.co/models/${model.id}"
 headers = {"Authorization": ${accessToken ? `"Bearer ${accessToken}"` : `f"Bearer {API_TOKEN}"`}}

 ${body}`;
+	}
 }

 export function hasPythonInferenceSnippet(model: ModelDataMinimal): boolean {
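A sketch of the new routing from the caller's perspective: getPythonInferenceSnippet now branches on both the pipeline tag and the chat template before picking a generator. The model data and token below are placeholders:

import { getPythonInferenceSnippet } from "./python";
import type { ModelDataMinimal } from "./types";

const conversational: ModelDataMinimal = {
    id: "some-org/some-instruct-model", // placeholder id
    pipeline_tag: "text-generation",
    config: { tokenizer_config: { chat_template: "{{ messages }}" } }, // dummy template
};

// Emits the huggingface_hub InferenceClient chat_completion snippet; the same
// model without a chat_template would get the classic requests-based snippet.
console.log(getPythonInferenceSnippet(conversational, "hf_xxx"));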

packages/tasks/src/snippets/types.ts

Lines changed: 1 addition & 1 deletion

@@ -5,4 +5,4 @@ import type { ModelData } from "../model-data";
 *
 * Add more fields as needed.
 */
-export type ModelDataMinimal = Pick<ModelData, "id" | "pipeline_tag" | "mask_token" | "library_name">;
+export type ModelDataMinimal = Pick<ModelData, "id" | "pipeline_tag" | "mask_token" | "library_name" | "config">;
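This one-key widening matters because ModelDataMinimal is a Pick over ModelData: without "config" in the union, the model.config?.tokenizer_config?.chat_template checks added in the snippet files would not type-check. A minimal illustration, with dummy field values:

import type { ModelDataMinimal } from "./types";

const minimal: ModelDataMinimal = {
    id: "some-org/some-model", // placeholder id
    pipeline_tag: "text-generation",
    config: { tokenizer_config: { chat_template: "{{ messages }}" } }, // dummy template
};

// Legal only now that "config" is part of the Pick; previously this access
// was a compile-time error on ModelDataMinimal.
const isConversational = Boolean(minimal.config?.tokenizer_config?.chat_template);
console.log(isConversational); // true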
