Fix ClipQuantFusion crash when Clip has multiple input edges (#27016) · microsoft/onnxruntime@cc2b01b

@@ -3221,6 +3221,37 @@ TEST(QDQTransformerTests, ReluQuantFusion_Level2Only) {

32213221

test_case(TransformerLevel::Level3, 0); // Will not fuse Relu into QuantizeLinear due to zero-point != -128

32223222

}

322332233224+

// Test skip removing node when min/max come from DequantizeLinear nodes instead of initializers.

3225+

TEST(QDQTransformerTests, ClipQuantFusion_MultipleInputEdges) {

3226+

auto build_test_case = [&](ModelTestBuilder& builder) {

3227+

// Clip's min coming from another DQ node (creating 2 input edges to Clip)

3228+

auto* input_arg = builder.MakeInput<uint8_t>({1, 2, 2, 2}, std::numeric_limits<uint8_t>::min(),

3229+

std::numeric_limits<uint8_t>::max());

3230+

auto* data_dq = builder.MakeIntermediate();

3231+

builder.AddDequantizeLinearNode<uint8_t>(input_arg, 0.04f, static_cast<uint8_t>(0), data_dq);

3232+

auto* min_q = builder.MakeScalarInitializer<uint8_t>(0);

3233+

auto* min_dq = builder.MakeIntermediate();

3234+

builder.AddDequantizeLinearNode<uint8_t>(min_q, 0.04f, static_cast<uint8_t>(0), min_dq);

3235+

auto* clip_output = builder.MakeIntermediate();

3236+

builder.AddNode("Clip", {data_dq, min_dq}, {clip_output});

3237+

auto* output_q = builder.MakeIntermediate();

3238+

builder.AddQuantizeLinearNode<uint8_t>(clip_output, 0.04f, static_cast<uint8_t>(0), output_q);

3239+

auto* output_arg = builder.MakeOutput();

3240+

builder.AddDequantizeLinearNode<uint8_t>(output_q, 0.04f, static_cast<uint8_t>(0), output_arg);

3241+

};

3242+3243+

auto check_graph = [&](InferenceSessionWrapper& session) {

3244+

auto op_to_count = CountOpsInGraph(session.GetGraph());

3245+

// ClipQuantFusion should skip it due to CanRemoveNode check

3246+

EXPECT_EQ(op_to_count["Clip"], 1);

3247+

};

3248+3249+

TransformerTester(build_test_case, check_graph,

3250+

TransformerLevel::Default,

3251+

TransformerLevel::Level2,

3252+

18); // opset

3253+

}

3254+32243255

template <typename ScaleType, typename ZpType>

32253256

void TestWhereWithDqInput(bool is_dq_1,

32263257

bool is_dq_2,