Skip to content

Commit 369d105

Browse files
[Inference Providers] isolate image-to-image payload build for HF Inference API (#1439)
This PR is a prerequisite for #1427. It refactors the payload construction for `hf-inference` by isolating it into a separate async function. Note that adding a new async function to build the payload is necessary because `HFInferenceImageToImageTask.preparePayload` cannot be made async, yet the payload construction requires asynchronous operations. A similar pattern has already been implemented for [automaticSpeechRecognition](https://github.com/huggingface/huggingface.js/blob/361a0fad4c68943592f0bcbe41592d785eedcb81/packages/inference/src/tasks/audio/automaticSpeechRecognition.ts#L36) with fal.
1 parent 42e928a commit 369d105

File tree

5 files changed

+67
-54
lines changed

5 files changed

+67
-54
lines changed

packages/inference/src/providers/fal-ai.ts

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@
1414
*
1515
* Thanks!
1616
*/
17+
import { base64FromBytes } from "../utils/base64FromBytes";
18+
1719
import type { AutomaticSpeechRecognitionOutput } from "@huggingface/tasks";
1820
import { InferenceOutputError } from "../lib/InferenceOutputError";
1921
import { isUrl } from "../lib/isUrl";
20-
import type { BodyParams, HeaderParams, ModelId, UrlParams } from "../types";
22+
import type { BodyParams, HeaderParams, ModelId, RequestArgs, UrlParams } from "../types";
2123
import { delay } from "../utils/delay";
2224
import { omit } from "../utils/omit";
2325
import {
@@ -27,6 +29,7 @@ import {
2729
type TextToVideoTaskHelper,
2830
} from "./providerHelper";
2931
import { HF_HUB_URL } from "../config";
32+
import type { AutomaticSpeechRecognitionArgs } from "../tasks/audio/automaticSpeechRecognition";
3033

3134
export interface FalAiQueueOutput {
3235
request_id: string;
@@ -224,6 +227,28 @@ export class FalAIAutomaticSpeechRecognitionTask extends FalAITask implements Au
224227
}
225228
return { text: res.text };
226229
}
230+
231+
async preparePayloadAsync(args: AutomaticSpeechRecognitionArgs): Promise<RequestArgs> {
232+
const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : undefined;
233+
const contentType = blob?.type;
234+
if (!contentType) {
235+
throw new Error(
236+
`Unable to determine the input's content-type. Make sure your are passing a Blob when using provider fal-ai.`
237+
);
238+
}
239+
if (!FAL_AI_SUPPORTED_BLOB_TYPES.includes(contentType)) {
240+
throw new Error(
241+
`Provider fal-ai does not support blob type ${contentType} - supported content types are: ${FAL_AI_SUPPORTED_BLOB_TYPES.join(
242+
", "
243+
)}`
244+
);
245+
}
246+
const base64audio = base64FromBytes(new Uint8Array(await blob.arrayBuffer()));
247+
return {
248+
...("data" in args ? omit(args, "data") : omit(args, "inputs")),
249+
audio_url: `data:${contentType};base64,${base64audio}`,
250+
};
251+
}
227252
}
228253

229254
export class FalAITextToSpeechTask extends FalAITask {

packages/inference/src/providers/hf-inference.ts

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ import type {
3636
import { HF_ROUTER_URL } from "../config";
3737
import { InferenceOutputError } from "../lib/InferenceOutputError";
3838
import type { TabularClassificationOutput } from "../tasks/tabular/tabularClassification";
39-
import type { BodyParams, UrlParams } from "../types";
39+
import type { BodyParams, RequestArgs, UrlParams } from "../types";
4040
import { toArray } from "../utils/toArray";
4141
import type {
4242
AudioClassificationTaskHelper,
@@ -70,7 +70,10 @@ import type {
7070
} from "./providerHelper";
7171

7272
import { TaskProviderHelper } from "./providerHelper";
73-
73+
import { base64FromBytes } from "../utils/base64FromBytes";
74+
import type { ImageToImageArgs } from "../tasks/cv/imageToImage";
75+
import type { AutomaticSpeechRecognitionArgs } from "../tasks/audio/automaticSpeechRecognition";
76+
import { omit } from "../utils/omit";
7477
interface Base64ImageGeneration {
7578
data: Array<{
7679
b64_json: string;
@@ -221,6 +224,15 @@ export class HFInferenceAutomaticSpeechRecognitionTask
221224
override async getResponse(response: AutomaticSpeechRecognitionOutput): Promise<AutomaticSpeechRecognitionOutput> {
222225
return response;
223226
}
227+
228+
async preparePayloadAsync(args: AutomaticSpeechRecognitionArgs): Promise<RequestArgs> {
229+
return "data" in args
230+
? args
231+
: {
232+
...omit(args, "inputs"),
233+
data: args.inputs,
234+
};
235+
}
224236
}
225237

226238
export class HFInferenceAudioToAudioTask extends HFInferenceTask implements AudioToAudioTaskHelper {
@@ -326,6 +338,23 @@ export class HFInferenceImageToTextTask extends HFInferenceTask implements Image
326338
}
327339

328340
export class HFInferenceImageToImageTask extends HFInferenceTask implements ImageToImageTaskHelper {
341+
async preparePayloadAsync(args: ImageToImageArgs): Promise<RequestArgs> {
342+
if (!args.parameters) {
343+
return {
344+
...args,
345+
model: args.model,
346+
data: args.inputs,
347+
};
348+
} else {
349+
return {
350+
...args,
351+
inputs: base64FromBytes(
352+
new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await (args.inputs as Blob).arrayBuffer())
353+
),
354+
};
355+
}
356+
}
357+
329358
override async getResponse(response: Blob): Promise<Blob> {
330359
if (response instanceof Blob) {
331360
return response;

packages/inference/src/providers/providerHelper.ts

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,8 +48,10 @@ import type {
4848
import { HF_ROUTER_URL } from "../config";
4949
import { InferenceOutputError } from "../lib/InferenceOutputError";
5050
import type { AudioToAudioOutput } from "../tasks/audio/audioToAudio";
51-
import type { BaseArgs, BodyParams, HeaderParams, InferenceProvider, UrlParams } from "../types";
51+
import type { BaseArgs, BodyParams, HeaderParams, InferenceProvider, RequestArgs, UrlParams } from "../types";
5252
import { toArray } from "../utils/toArray";
53+
import type { ImageToImageArgs } from "../tasks/cv/imageToImage";
54+
import type { AutomaticSpeechRecognitionArgs } from "../tasks/audio/automaticSpeechRecognition";
5355

5456
/**
5557
* Base class for task-specific provider helpers
@@ -142,6 +144,7 @@ export interface TextToVideoTaskHelper {
142144
export interface ImageToImageTaskHelper {
143145
getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<Blob>;
144146
preparePayload(params: BodyParams<ImageToImageInput & BaseArgs>): Record<string, unknown>;
147+
preparePayloadAsync(args: ImageToImageArgs): Promise<RequestArgs>;
145148
}
146149

147150
export interface ImageSegmentationTaskHelper {
@@ -245,6 +248,7 @@ export interface AudioToAudioTaskHelper {
245248
export interface AutomaticSpeechRecognitionTaskHelper {
246249
getResponse(response: unknown, url?: string, headers?: HeadersInit): Promise<AutomaticSpeechRecognitionOutput>;
247250
preparePayload(params: BodyParams<AutomaticSpeechRecognitionInput & BaseArgs>): Record<string, unknown> | BodyInit;
251+
preparePayloadAsync(args: AutomaticSpeechRecognitionArgs): Promise<RequestArgs>;
248252
}
249253

250254
export interface AudioClassificationTaskHelper {

packages/inference/src/tasks/audio/automaticSpeechRecognition.ts

Lines changed: 2 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,9 @@ import type { AutomaticSpeechRecognitionInput, AutomaticSpeechRecognitionOutput
22
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
33
import { getProviderHelper } from "../../lib/getProviderHelper";
44
import { InferenceOutputError } from "../../lib/InferenceOutputError";
5-
import { FAL_AI_SUPPORTED_BLOB_TYPES } from "../../providers/fal-ai";
6-
import type { BaseArgs, Options, RequestArgs } from "../../types";
7-
import { base64FromBytes } from "../../utils/base64FromBytes";
8-
import { omit } from "../../utils/omit";
5+
import type { BaseArgs, Options } from "../../types";
96
import { innerRequest } from "../../utils/request";
107
import type { LegacyAudioInput } from "./utils";
11-
import { preparePayload } from "./utils";
128

139
export type AutomaticSpeechRecognitionArgs = BaseArgs & (AutomaticSpeechRecognitionInput | LegacyAudioInput);
1410
/**
@@ -21,7 +17,7 @@ export async function automaticSpeechRecognition(
2117
): Promise<AutomaticSpeechRecognitionOutput> {
2218
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
2319
const providerHelper = getProviderHelper(provider, "automatic-speech-recognition");
24-
const payload = await buildPayload(args);
20+
const payload = await providerHelper.preparePayloadAsync(args);
2521
const { data: res } = await innerRequest<AutomaticSpeechRecognitionOutput>(payload, providerHelper, {
2622
...options,
2723
task: "automatic-speech-recognition",
@@ -32,29 +28,3 @@ export async function automaticSpeechRecognition(
3228
}
3329
return providerHelper.getResponse(res);
3430
}
35-
36-
async function buildPayload(args: AutomaticSpeechRecognitionArgs): Promise<RequestArgs> {
37-
if (args.provider === "fal-ai") {
38-
const blob = "data" in args && args.data instanceof Blob ? args.data : "inputs" in args ? args.inputs : undefined;
39-
const contentType = blob?.type;
40-
if (!contentType) {
41-
throw new Error(
42-
`Unable to determine the input's content-type. Make sure your are passing a Blob when using provider fal-ai.`
43-
);
44-
}
45-
if (!FAL_AI_SUPPORTED_BLOB_TYPES.includes(contentType)) {
46-
throw new Error(
47-
`Provider fal-ai does not support blob type ${contentType} - supported content types are: ${FAL_AI_SUPPORTED_BLOB_TYPES.join(
48-
", "
49-
)}`
50-
);
51-
}
52-
const base64audio = base64FromBytes(new Uint8Array(await blob.arrayBuffer()));
53-
return {
54-
...("data" in args ? omit(args, "data") : omit(args, "inputs")),
55-
audio_url: `data:${contentType};base64,${base64audio}`,
56-
};
57-
} else {
58-
return preparePayload(args);
59-
}
60-
}

packages/inference/src/tasks/cv/imageToImage.ts

Lines changed: 3 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
import type { ImageToImageInput } from "@huggingface/tasks";
22
import { resolveProvider } from "../../lib/getInferenceProviderMapping";
33
import { getProviderHelper } from "../../lib/getProviderHelper";
4-
import type { BaseArgs, Options, RequestArgs } from "../../types";
5-
import { base64FromBytes } from "../../utils/base64FromBytes";
4+
import type { BaseArgs, Options } from "../../types";
65
import { innerRequest } from "../../utils/request";
76

87
export type ImageToImageArgs = BaseArgs & ImageToImageInput;
@@ -14,22 +13,8 @@ export type ImageToImageArgs = BaseArgs & ImageToImageInput;
1413
export async function imageToImage(args: ImageToImageArgs, options?: Options): Promise<Blob> {
1514
const provider = await resolveProvider(args.provider, args.model, args.endpointUrl);
1615
const providerHelper = getProviderHelper(provider, "image-to-image");
17-
let reqArgs: RequestArgs;
18-
if (!args.parameters) {
19-
reqArgs = {
20-
accessToken: args.accessToken,
21-
model: args.model,
22-
data: args.inputs,
23-
};
24-
} else {
25-
reqArgs = {
26-
...args,
27-
inputs: base64FromBytes(
28-
new Uint8Array(args.inputs instanceof ArrayBuffer ? args.inputs : await args.inputs.arrayBuffer())
29-
),
30-
};
31-
}
32-
const { data: res } = await innerRequest<Blob>(reqArgs, providerHelper, {
16+
const payload = await providerHelper.preparePayloadAsync(args);
17+
const { data: res } = await innerRequest<Blob>(payload, providerHelper, {
3318
...options,
3419
task: "image-to-image",
3520
});

0 commit comments

Comments (0)