@@ -3,7 +3,7 @@ use crate::{
3
3
request:: { DetokenizationRequest , NormalRequest , TokenizationRequest } ,
4
4
sequence:: SeqStepType ,
5
5
tools:: { ToolCallingMatcher , ToolChoice } ,
6
- RequestMessage , Response ,
6
+ ModelCategory , RequestMessage , Response ,
7
7
} ;
8
8
use candle_core:: Tensor ;
9
9
use either:: Either ;
@@ -88,6 +88,32 @@ impl Engine {
88
88
return ;
89
89
}
90
90
91
+ // Verify the model's category matches the messages received.
92
+ match (
93
+ get_mut_arcmutex ! ( self . pipeline) . category ( ) ,
94
+ & request. messages ,
95
+ ) {
96
+ (
97
+ ModelCategory :: Text | ModelCategory :: Vision { .. } ,
98
+ RequestMessage :: Chat { .. }
99
+ | RequestMessage :: VisionChat { .. }
100
+ | RequestMessage :: Completion { .. }
101
+ | RequestMessage :: CompletionTokens ( _) ,
102
+ ) => ( ) ,
103
+ ( ModelCategory :: Diffusion , RequestMessage :: ImageGeneration { .. } ) => ( ) ,
104
+ ( ModelCategory :: Speech , RequestMessage :: SpeechGeneration { .. } ) => ( ) ,
105
+ _ => {
106
+ request
107
+ . response
108
+ . send ( Response :: ValidationError (
109
+ "Received a request incompatible for this model's category." . into ( ) ,
110
+ ) )
111
+ . await
112
+ . expect ( "Expected receiver." ) ;
113
+ return ;
114
+ }
115
+ }
116
+
91
117
let images = match request. messages {
92
118
RequestMessage :: VisionChat {
93
119
ref images,
0 commit comments