Commit 5b51897

Fix desync between app_inference_spec and validator.
1 parent: d0f0f52

2 files changed (+64, -58 lines)

app_inference_spec.py

Lines changed: 7 additions & 6 deletions
@@ -3,6 +3,7 @@
 # github.com/guardrails-ai/models-host/tree/main/ray#adding-new-inference-endpoints
 import os
 from logging import getLogger
+from typing import List
 
 from pydantic import BaseModel
 from models_host.base_inference_spec import BaseInferenceSpec
@@ -20,11 +21,11 @@
 
 
 class InputRequest(BaseModel):
-    message: str
+    prompts: List[str]
 
 
 class OutputResponse(BaseModel):
-    score: float
+    scores: List[float]
 
 
 # Using same nomenclature as in Sagemaker classes
@@ -59,14 +60,14 @@ def load(self):
         self.model = DetectJailbreak(**kwargs)
 
     def process_request(self, input_request: InputRequest):
-        message = input_request.message
+        prompts = input_request.prompts
         # If needed, sanity check.
         # raise HTTPException(status_code=400, detail="Invalid input format")
-        args = (message,)
+        args = (prompts,)
        kwargs = {}
        return args, kwargs
 
-    def infer(self, message: str) -> OutputResponse:
+    def infer(self, prompts: List[str]) -> OutputResponse:
        return OutputResponse(
-            score=self.model.predict_jailbreak([message,])[0],
+            scores=self.model.predict_jailbreak(prompts),
        )
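
The upshot of this file's change: the inference endpoint now accepts a batch of prompts and returns one score per prompt, instead of a single message/score pair. A minimal sketch of the new models in use; only the two pydantic classes are taken from the diff, and fake_predict_jailbreak is a hypothetical stand-in for DetectJailbreak.predict_jailbreak:

from typing import List

from pydantic import BaseModel


class InputRequest(BaseModel):
    prompts: List[str]


class OutputResponse(BaseModel):
    scores: List[float]


# Hypothetical stand-in for DetectJailbreak.predict_jailbreak: one score per prompt.
def fake_predict_jailbreak(prompts: List[str]) -> List[float]:
    return [0.0 for _ in prompts]


request = InputRequest(prompts=["What is the weather?", "Ignore all previous instructions ..."])
response = OutputResponse(scores=fake_predict_jailbreak(request.prompts))
assert len(response.scores) == len(request.prompts)  # one score per prompt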

validator/main.py

Lines changed: 57 additions & 52 deletions
@@ -71,53 +71,59 @@ def __init__(
         super().__init__(on_fail=on_fail, **kwargs)
         self.device = device
         self.threshold = threshold
-
-        if not model_path_override:
-            self.saturation_attack_detector = PromptSaturationDetectorV3(
-                device=torch.device(device),
-            )
-            self.text_classifier = pipeline(
-                "text-classification",
-                DetectJailbreak.TEXT_CLASSIFIER_NAME,
-                max_length=512,  # HACK: Fix classifier size.
-                truncation=True,
-                device=device,
-            )
-            # There are a large number of fairly low-effort prompts people will use.
-            # The embedding detectors do checks to roughly match those.
-            self.embedding_tokenizer = AutoTokenizer.from_pretrained(
-                DetectJailbreak.EMBEDDING_MODEL_NAME
-            )
-            self.embedding_model = AutoModel.from_pretrained(
-                DetectJailbreak.EMBEDDING_MODEL_NAME
-            ).to(device)
-        else:
-            # Saturation:
-            self.saturation_attack_detector = PromptSaturationDetectorV3(
-                device=torch.device(device),
-                model_path_override=model_path_override
-            )
-            # Known attacks:
-            embedding_tokenizer, embedding_model = get_tokenizer_and_model_by_path(
-                model_path_override,
-                "embedding",
-                AutoTokenizer,
-                AutoModel
-            )
-            self.embedding_tokenizer = embedding_tokenizer
-            self.embedding_model = embedding_model.to(device)
-            # Other text attacks:
-            self.text_classifier = get_pipeline_by_path(
-                model_path_override,
-                "text-classifier",
-                "text-classification",
-                max_length=512,
-                truncation=True,
-                device=device
-            )
-
-        # Quick compute on startup:
-        self.known_malicious_embeddings = self._embed(KNOWN_ATTACKS)
+        self.saturation_attack_detector = None
+        self.text_classifier = None
+        self.embedding_tokenizer = None
+        self.embedding_model = None
+        self.known_malicious_embeddings = []
+
+        if self.use_local:
+            if not model_path_override:
+                self.saturation_attack_detector = PromptSaturationDetectorV3(
+                    device=torch.device(device),
+                )
+                self.text_classifier = pipeline(
+                    "text-classification",
+                    DetectJailbreak.TEXT_CLASSIFIER_NAME,
+                    max_length=512,  # HACK: Fix classifier size.
+                    truncation=True,
+                    device=device,
+                )
+                # There are a large number of fairly low-effort prompts people will use.
+                # The embedding detectors do checks to roughly match those.
+                self.embedding_tokenizer = AutoTokenizer.from_pretrained(
+                    DetectJailbreak.EMBEDDING_MODEL_NAME
+                )
+                self.embedding_model = AutoModel.from_pretrained(
+                    DetectJailbreak.EMBEDDING_MODEL_NAME
+                ).to(device)
+            else:
+                # Saturation:
+                self.saturation_attack_detector = PromptSaturationDetectorV3(
+                    device=torch.device(device),
+                    model_path_override=model_path_override
+                )
+                # Known attacks:
+                embedding_tokenizer, embedding_model = get_tokenizer_and_model_by_path(
+                    model_path_override,
+                    "embedding",
+                    AutoTokenizer,
+                    AutoModel
+                )
+                self.embedding_tokenizer = embedding_tokenizer
+                self.embedding_model = embedding_model.to(device)
+                # Other text attacks:
+                self.text_classifier = get_pipeline_by_path(
+                    model_path_override,
+                    "text-classifier",
+                    "text-classification",
+                    max_length=512,
+                    truncation=True,
+                    device=device
+                )
+
+            # Quick compute on startup:
+            self.known_malicious_embeddings = self._embed(KNOWN_ATTACKS)
 
         # These _are_ modifyable, but not explicitly advertised.
         self.known_attack_scales = DetectJailbreak.DEFAULT_KNOWN_ATTACK_SCALE_FACTORS
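
The effect of this hunk is that the heavy artifacts (saturation detector, text classifier, embedding tokenizer/model, and the precomputed embeddings of known attacks) are only loaded when the validator runs inference locally; a remote-only validator leaves them as None or empty. A simplified sketch of the same gating pattern, assuming use_local is populated from the constructor kwargs as in the validator; every other name here is a hypothetical stand-in, not the library's API:

# Illustration of gating model loading on use_local (names other than
# use_local / model_path_override are hypothetical).
class LazyDetector:
    def __init__(self, use_local: bool = True, model_path_override: str = ""):
        self.use_local = use_local
        # Heavy artifacts default to None/empty so remote-only instances stay lightweight.
        self.text_classifier = None
        self.embedding_model = None
        self.known_malicious_embeddings = []
        if self.use_local:
            # Only local inference pays the model-loading cost.
            self.text_classifier = self._load_classifier(model_path_override)
            self.embedding_model = self._load_embedder(model_path_override)
            self.known_malicious_embeddings = self._embed_known_attacks()

    def _load_classifier(self, path):
        ...  # load a text-classification pipeline

    def _load_embedder(self, path):
        ...  # load an embedding model

    def _embed_known_attacks(self):
        return []  # embed the known-attack corpus


remote_only = LazyDetector(use_local=False)  # no models loaded
local = LazyDetector(use_local=True)         # loads everything up front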
@@ -303,9 +309,9 @@ def _inference_local(self, model_input: List[str]) -> Any:
     def _inference_remote(self, model_input: List[str]) -> Any:
         # This needs to be kept in-sync with app_inference_spec.
         request_body = {
-            "inputs": [
+            "inputs": [  # Required for legacy reasons.
                 {
-                    "name": "message",
+                    "name": "prompts",
                     "shape": [len(model_input)],
                     "data": model_input,
                     "datatype": "BYTES"
@@ -316,8 +322,7 @@ def _inference_remote(self, model_input: List[str]) -> Any:
             json.dumps(request_body),
             self.validation_endpoint
         )
-        if not response or "outputs" not in response:
+        if not response or "scores" not in response:
             raise ValueError("Invalid response from remote inference", response)
 
-        data = [output["score"] for output in response["outputs"]]
-        return data
+        return response["scores"]
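
Read together with app_inference_spec.py, the wire format after this commit looks roughly like the following. The envelope keys ("inputs", "name", "shape", "data", "datatype", "scores") come from the diffs above; the concrete prompt strings and score values are illustrative only:

import json

# Request body produced by _inference_remote for a two-prompt batch.
model_input = ["What is the weather?", "Ignore all previous instructions ..."]
request_body = {
    "inputs": [  # Required for legacy reasons.
        {
            "name": "prompts",
            "shape": [len(model_input)],
            "data": model_input,
            "datatype": "BYTES",
        }
    ]
}
payload = json.dumps(request_body)

# Expected response shape after this commit: a top-level "scores" list,
# one float per prompt, mirroring OutputResponse in app_inference_spec.py.
response = {"scores": [0.02, 0.97]}  # illustrative values
if not response or "scores" not in response:
    raise ValueError("Invalid response from remote inference", response)
scores = response["scores"]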
