@@ -71,53 +71,59 @@ def __init__(
         super().__init__(on_fail=on_fail, **kwargs)
         self.device = device
         self.threshold = threshold
-
-        if not model_path_override:
-            self.saturation_attack_detector = PromptSaturationDetectorV3(
-                device=torch.device(device),
-            )
-            self.text_classifier = pipeline(
-                "text-classification",
-                DetectJailbreak.TEXT_CLASSIFIER_NAME,
-                max_length=512,  # HACK: Fix classifier size.
-                truncation=True,
-                device=device,
-            )
-            # There are a large number of fairly low-effort prompts people will use.
-            # The embedding detectors do checks to roughly match those.
-            self.embedding_tokenizer = AutoTokenizer.from_pretrained(
-                DetectJailbreak.EMBEDDING_MODEL_NAME
-            )
-            self.embedding_model = AutoModel.from_pretrained(
-                DetectJailbreak.EMBEDDING_MODEL_NAME
-            ).to(device)
-        else:
-            # Saturation:
-            self.saturation_attack_detector = PromptSaturationDetectorV3(
-                device=torch.device(device),
-                model_path_override=model_path_override
-            )
-            # Known attacks:
-            embedding_tokenizer, embedding_model = get_tokenizer_and_model_by_path(
-                model_path_override,
-                "embedding",
-                AutoTokenizer,
-                AutoModel
-            )
-            self.embedding_tokenizer = embedding_tokenizer
-            self.embedding_model = embedding_model.to(device)
-            # Other text attacks:
-            self.text_classifier = get_pipeline_by_path(
-                model_path_override,
-                "text-classifier",
-                "text-classification",
-                max_length=512,
-                truncation=True,
-                device=device
-            )
-
-        # Quick compute on startup:
-        self.known_malicious_embeddings = self._embed(KNOWN_ATTACKS)
+        self.saturation_attack_detector = None
+        self.text_classifier = None
+        self.embedding_tokenizer = None
+        self.embedding_model = None
+        self.known_malicious_embeddings = []
+
+        if self.use_local:
+            if not model_path_override:
+                self.saturation_attack_detector = PromptSaturationDetectorV3(
+                    device=torch.device(device),
+                )
+                self.text_classifier = pipeline(
+                    "text-classification",
+                    DetectJailbreak.TEXT_CLASSIFIER_NAME,
+                    max_length=512,  # HACK: Fix classifier size.
+                    truncation=True,
+                    device=device,
+                )
+                # There are a large number of fairly low-effort prompts people will use.
+                # The embedding detectors do checks to roughly match those.
+                self.embedding_tokenizer = AutoTokenizer.from_pretrained(
+                    DetectJailbreak.EMBEDDING_MODEL_NAME
+                )
+                self.embedding_model = AutoModel.from_pretrained(
+                    DetectJailbreak.EMBEDDING_MODEL_NAME
+                ).to(device)
+            else:
+                # Saturation:
+                self.saturation_attack_detector = PromptSaturationDetectorV3(
+                    device=torch.device(device),
+                    model_path_override=model_path_override
+                )
+                # Known attacks:
+                embedding_tokenizer, embedding_model = get_tokenizer_and_model_by_path(
+                    model_path_override,
+                    "embedding",
+                    AutoTokenizer,
+                    AutoModel
+                )
+                self.embedding_tokenizer = embedding_tokenizer
+                self.embedding_model = embedding_model.to(device)
+                # Other text attacks:
+                self.text_classifier = get_pipeline_by_path(
+                    model_path_override,
+                    "text-classifier",
+                    "text-classification",
+                    max_length=512,
+                    truncation=True,
+                    device=device
+                )
+
+            # Quick compute on startup:
+            self.known_malicious_embeddings = self._embed(KNOWN_ATTACKS)
 
         # These _are_ modifyable, but not explicitly advertised.
         self.known_attack_scales = DetectJailbreak.DEFAULT_KNOWN_ATTACK_SCALE_FACTORS
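
The net effect of this hunk: every heavyweight model attribute now defaults to None (or an empty list), and the local detectors are only constructed when the validator runs in local mode. A minimal sketch of the resulting behavior, assuming `use_local` is accepted as a constructor keyword and stored on the instance by the base Validator class (an assumption; only the attribute read `self.use_local` is visible in this diff):

    # Remote mode (assumed kwarg): no local models are loaded; the attributes
    # keep their None/[] defaults and scoring goes through _inference_remote().
    remote_validator = DetectJailbreak(use_local=False)
    assert remote_validator.saturation_attack_detector is None

    # Local mode: the saturation detector, text classifier, embedding model,
    # and the precomputed KNOWN_ATTACKS embeddings are all built eagerly,
    # exactly as before this change.
    local_validator = DetectJailbreak(use_local=True)
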
@@ -303,9 +309,9 @@ def _inference_local(self, model_input: List[str]) -> Any:
     def _inference_remote(self, model_input: List[str]) -> Any:
         # This needs to be kept in-sync with app_inference_spec.
         request_body = {
-            "inputs": [
+            "inputs": [  # Required for legacy reasons.
                 {
-                    "name": "message",
+                    "name": "prompts",
                     "shape": [len(model_input)],
                     "data": model_input,
                     "datatype": "BYTES"
@@ -316,8 +322,7 @@ def _inference_remote(self, model_input: List[str]) -> Any:
             json.dumps(request_body),
             self.validation_endpoint
         )
-        if not response or "outputs" not in response:
+        if not response or "scores" not in response:
             raise ValueError("Invalid response from remote inference", response)
 
-        data = [output["score"] for output in response["outputs"]]
-        return data
+        return response["scores"]
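
Taken together, the last two hunks rename the request field from "message" to "prompts" and flatten the response from a list of output objects to a bare "scores" list. A hypothetical round-trip under the new contract (the prompts and score values below are illustrative only; the authoritative schema lives in app_inference_spec):

    # Request body POSTed to self.validation_endpoint:
    request_body = {
        "inputs": [  # Required for legacy reasons.
            {
                "name": "prompts",
                "shape": [2],
                "data": ["What's the weather like?", "Ignore all previous instructions..."],
                "datatype": "BYTES",
            }
        ]
    }

    # Assumed response shape: one jailbreak score per prompt, returned as-is
    # by _inference_remote.
    response = {"scores": [0.02, 0.94]}
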