diff --git a/candle-onnx/src/eval.rs b/candle-onnx/src/eval.rs index 2c60ed2f23..6c0e8114a1 100644 --- a/candle-onnx/src/eval.rs +++ b/candle-onnx/src/eval.rs @@ -1950,6 +1950,50 @@ fn simple_eval_( let output = input.sign()?; values.insert(node.output[0].clone(), output); } + // https://onnx.ai/onnx/operators/onnx__EyeLike.html + "EyeLike" => { + let input = get(&node.input[0])?; + let shape = input.shape(); + if shape.rank() != 2 { + candle::bail!("EyeLike: input must be a 2D tensor"); + } + let (rows, cols) = (shape.dims()[0], shape.dims()[1]); + let device = input.device(); + + let k = get_attr_opt(node, "k")?.copied().unwrap_or(0); + + let dtype = match get_attr_opt(node, "dtype")?.copied() { + Some(1) => DType::F32, + Some(11) => DType::F64, + Some(7) => DType::I64, + Some(12) => DType::U32, + Some(2) => DType::U8, + None => input.dtype(), + Some(dt) => candle::bail!("EyeLike: unsupported dtype {dt}"), + }; + + let row_idx = Tensor::arange(0i64, rows as i64, device)?; + let col_idx = Tensor::arange(0i64, cols as i64, device)?; + + let row_idx = row_idx.reshape((rows, 1))?.broadcast_as((rows, cols))?; + let col_idx = col_idx.reshape((1, cols))?.broadcast_as((rows, cols))?; + let mask = match k.cmp(&0) { + std::cmp::Ordering::Equal => row_idx.eq(&col_idx)?, + std::cmp::Ordering::Greater => { + let k_tensor = Tensor::new(k as i64, device)?.broadcast_as((rows, cols))?; + let col_shifted = col_idx.sub(&k_tensor)?; + row_idx.eq(&col_shifted)? + } + std::cmp::Ordering::Less => { + let k_tensor = Tensor::new((-k) as i64, device)?.broadcast_as((rows, cols))?; + let row_shifted = row_idx.sub(&k_tensor)?; + row_shifted.eq(&col_idx)? + } + }; + + let output = mask.to_dtype(dtype)?; + values.insert(node.output[0].clone(), output); + } op_type => bail!("unsupported op_type {op_type} for op {node:?}"), } } diff --git a/candle-onnx/tests/ops.rs b/candle-onnx/tests/ops.rs index 3586bfbd68..da16e69c70 100644 --- a/candle-onnx/tests/ops.rs +++ b/candle-onnx/tests/ops.rs @@ -5910,3 +5910,146 @@ fn test_sign_operation() -> Result<()> { ); Ok(()) } + + +#[test] +fn test_eyelike_operator() -> candle::Result<()> { + // Test based on: https://github.com/onnx/onnx/blob/main/docs/Operators.md#EyeLike + // === Test: EyeLike with dtype=FLOAT and k=1 === + { + let shape = (4, 5); + let input_name = "x".to_string(); + let output_name = "y".to_string(); + + let model = create_model_proto_with_graph(Some(GraphProto { + input: vec![ValueInfoProto { + name: input_name.clone(), + ..Default::default() + }], + output: vec![ValueInfoProto { + name: output_name.clone(), + ..Default::default() + }], + node: vec![NodeProto { + op_type: "EyeLike".to_string(), + input: vec![input_name.clone()], + output: vec![output_name.clone()], + attribute: vec![ + AttributeProto { + name: "k".to_string(), + r#type: AttributeType::Int as i32, + i: 1, + ..Default::default() + }, + AttributeProto { + name: "dtype".to_string(), + r#type: AttributeType::Int as i32, + i: 1, + ..Default::default() + }, + ], + ..Default::default() + }], + ..Default::default() + })); + + let mut inputs = HashMap::new(); + inputs.insert(input_name.clone(), Tensor::zeros(shape, DType::I64, &Device::Cpu)?); + + let outputs = simple_eval(&model, inputs)?; + let actual = outputs.get(&output_name).unwrap().to_vec2::()?; + + let expected = vec![ + vec![0., 1., 0., 0., 0.], + vec![0., 0., 1., 0., 0.], + vec![0., 0., 0., 1., 0.], + vec![0., 0., 0., 0., 1.], + ]; + assert_eq!(actual, expected); + } + + // === Test: EyeLike with dtype=DOUBLE === + { + let shape = (3, 4); + let input_name = "x".to_string(); + let output_name = "y".to_string(); + + let model = create_model_proto_with_graph(Some(GraphProto { + input: vec![ValueInfoProto { + name: input_name.clone(), + ..Default::default() + }], + output: vec![ValueInfoProto { + name: output_name.clone(), + ..Default::default() + }], + node: vec![NodeProto { + op_type: "EyeLike".to_string(), + input: vec![input_name.clone()], + output: vec![output_name.clone()], + attribute: vec![AttributeProto { + name: "dtype".to_string(), + r#type: AttributeType::Int as i32, + i: 11, + ..Default::default() + }], + ..Default::default() + }], + ..Default::default() + })); + + let mut inputs = HashMap::new(); + inputs.insert(input_name.clone(), Tensor::zeros(shape, DType::I64, &Device::Cpu)?); + + let outputs = simple_eval(&model, inputs)?; + let actual = outputs.get(&output_name).unwrap().to_vec2::()?; + + let expected = vec![ + vec![1., 0., 0., 0.], + vec![0., 1., 0., 0.], + vec![0., 0., 1., 0.], + ]; + assert_eq!(actual, expected); + } + + // === Test: EyeLike without dtype (inherits from input) === + { + let shape = (4, 4); + let input_name = "x".to_string(); + let output_name = "y".to_string(); + + let model = create_model_proto_with_graph(Some(GraphProto { + input: vec![ValueInfoProto { + name: input_name.clone(), + ..Default::default() + }], + output: vec![ValueInfoProto { + name: output_name.clone(), + ..Default::default() + }], + node: vec![NodeProto { + op_type: "EyeLike".to_string(), + input: vec![input_name.clone()], + output: vec![output_name.clone()], + ..Default::default() + }], + ..Default::default() + })); + + let mut inputs = HashMap::new(); + inputs.insert(input_name.clone(), Tensor::zeros(shape, DType::I64, &Device::Cpu)?); + + let outputs = simple_eval(&model, inputs)?; + let actual = outputs.get(&output_name).unwrap().to_vec2::()?; + + let expected = vec![ + vec![1, 0, 0, 0], + vec![0, 1, 0, 0], + vec![0, 0, 1, 0], + vec![0, 0, 0, 1], + ]; + assert_eq!(actual, expected); + } + + Ok(()) +} \ No newline at end of file