Google OR-Tools: ortools/graph/linear_assignment.h Source File
193#ifndef ORTOOLS_GRAPH_LINEAR_ASSIGNMENT_H_
194#define ORTOOLS_GRAPH_LINEAR_ASSIGNMENT_H_
206#include "absl/flags/declare.h"
207#include "absl/flags/flag.h"
208#include "absl/strings/str_format.h"
224template <typename GraphType, typename CostValue = int64_t>
227 typedef typename GraphType::NodeIndex NodeIndex;
228 typedef typename GraphType::ArcIndex ArcIndex;
270 inline const GraphType& Graph() const { return *graph_; }
281 DCHECK_EQ(0, scaled_arc_cost_[arc] % cost_scaling_factor_);
340 DCHECK_LT(left_node, num_left_nodes_);
342 DCHECK_NE(GraphType::kNilArc, matching_arc);
343 return Head(matching_arc);
346 std::string StatsString() const { return total_stats_.StatsString(); }
361 Stats() : pushes(0), double_pushes(0), relabelings(0), refinements(0) {}
368 void Add(const Stats& that) {
370 double_pushes += that.double_pushes;
371 relabelings += that.relabelings;
372 refinements += that.refinements;
376 "%d refinements; %d relabelings; "
377 "%d double pushes; %d pushes",
378 refinements, relabelings, double_pushes, pushes);
387 class ActiveNodeContainerInterface {
389 virtual ~ActiveNodeContainerInterface() {}
390 virtual bool Empty() const = 0;
391 virtual void Add(NodeIndex node) = 0;
395 class ActiveNodeStack : public ActiveNodeContainerInterface {
397 ~ActiveNodeStack() override {}
399 bool Empty() const override { return v_.empty(); }
401 void Add(NodeIndex node) override { v_.push_back(node); }
411 std::vector<NodeIndex> v_;
414 class ActiveNodeQueue : public ActiveNodeContainerInterface {
416 ~ActiveNodeQueue() override {}
418 bool Empty() const override { return q_.empty(); }
420 void Add(NodeIndex node) override { q_.push_front(node); }
441 typedef std::pair<ArcIndex, CostValue> ImplicitPriceSummary;
452 inline ImplicitPriceSummary BestArcAndGap(NodeIndex left_node) const;
456 void ReportAndAccumulateStats();
471 inline bool IsActive(NodeIndex left_node) const;
478 inline bool IsActiveForDebugging(NodeIndex node) const;
485 void InitializeActiveNodeContainer();
493 void SaturateNegativeArcs();
498 bool DoublePush(NodeIndex source);
502 return scaled_arc_cost_[arc] - price_[Head(arc)];
516 bool incidence_precondition_satisfied_;
526 const CostValue cost_scaling_factor_;
533 static constexpr CostValue kMinEpsilon = 1;
830 CostValue slack_relabeling_price_;
844 const CostValue n = graph_->num_nodes();
857 static_cast<double>(std::max<CostValue>(1, n / 2 - 1)) *
858 (static_cast<double>(old_epsilon) + static_cast<double>(new_epsilon));
860 static_cast<double>(std::numeric_limits<CostValue>::max());
863 if (in_range != nullptr) *in_range = false;
864 return std::numeric_limits<CostValue>::max();
868 return static_cast<CostValue>(result);
884 CostValue largest_scaled_cost_magnitude_;
897 ZVector<CostValue> price_;
902 std::vector<ArcIndex> matched_arc_;
910 ZVector<NodeIndex> matched_node_;
915 std::vector<CostValue> scaled_arc_cost_;
920 std::unique_ptr<ActiveNodeContainerInterface> active_nodes_;
934template <typename GraphType, typename CostValue>
936 const GraphType& graph, const NodeIndex num_left_nodes)
938 num_left_nodes_(num_left_nodes),
940 cost_scaling_factor_(1 + num_left_nodes),
941 alpha_(absl::GetFlag(FLAGS_assignment_alpha)),
944 slack_relabeling_price_(0),
945 largest_scaled_cost_magnitude_(0),
947 price_(num_left_nodes, 2 * num_left_nodes - 1),
948 matched_arc_(num_left_nodes, 0),
949 matched_node_(num_left_nodes, 2 * num_left_nodes - 1),
950 scaled_arc_cost_(graph.arc_capacity(), 0),
951 active_nodes_(absl::GetFlag(FLAGS_assignment_stack_order)
952 ? static_cast<ActiveNodeContainerInterface*>(
957template <typename GraphType, typename CostValue>
961 num_left_nodes_(num_left_nodes),
963 cost_scaling_factor_(1 + num_left_nodes),
964 alpha_(absl::GetFlag(FLAGS_assignment_alpha)),
967 slack_relabeling_price_(0),
968 largest_scaled_cost_magnitude_(0),
970 price_(num_left_nodes, 2 * num_left_nodes - 1),
971 matched_arc_(num_left_nodes, 0),
972 matched_node_(num_left_nodes, 2 * num_left_nodes - 1),
973 scaled_arc_cost_(num_arcs, 0),
974 active_nodes_(absl::GetFlag(FLAGS_assignment_stack_order)
975 ? static_cast<ActiveNodeContainerInterface*>(
980template <typename GraphType, typename CostValue>
985 DCHECK_LT(arc, graph_->num_arcs());
987 DCHECK_LE(num_left_nodes_, head);
989 cost *= cost_scaling_factor_;
990 const CostValue cost_magnitude = std::abs(cost);
991 largest_scaled_cost_magnitude_ =
992 std::max(largest_scaled_cost_magnitude_, cost_magnitude);
996template <typename ArcIndexType, typename CostValue>
1031template <typename GraphType>
1055template <typename GraphType, typename CostValue>
1056PermutationCycleHandler<typename GraphType::ArcIndex>*
1062template <typename GraphType, typename CostValue>
1063CostValue LinearSumAssignment<GraphType, CostValue>::NewEpsilon(
1064 const CostValue current_epsilon) const {
1065 return std::max(current_epsilon / alpha_, kMinEpsilon);
1068template <typename GraphType, typename CostValue>
1069bool LinearSumAssignment<GraphType, CostValue>::UpdateEpsilon() {
1070 CostValue new_epsilon = NewEpsilon(epsilon_);
1071 slack_relabeling_price_ = PriceChangeBound(epsilon_, new_epsilon, nullptr);
1073 VLOG(3) << "Updated: epsilon_ == " << epsilon_;
1074 VLOG(4) << "slack_relabeling_price_ == " << slack_relabeling_price_;
1075 DCHECK_GT(slack_relabeling_price_, 0);
1083template <typename GraphType, typename CostValue>
1084inline bool LinearSumAssignment<GraphType, CostValue>::IsActive(
1086 DCHECK_LT(left_node, num_left_nodes_);
1087 return matched_arc_[left_node] == GraphType::kNilArc;
1093template <typename GraphType, typename CostValue>
1094inline bool LinearSumAssignment<GraphType, CostValue>::IsActiveForDebugging(
1096 if (node < num_left_nodes_) {
1099 return matched_node_[node] == GraphType::kNilNode;
1103template <typename GraphType, typename CostValue>
1105 CostValue>::InitializeActiveNodeContainer() {
1106 DCHECK(active_nodes_->Empty());
1107 for (const NodeIndex node : BipartiteLeftNodes()) {
1109 active_nodes_->Add(node);
1124template <typename GraphType, typename CostValue>
1125void LinearSumAssignment<GraphType, CostValue>::SaturateNegativeArcs() {
1127 for (const NodeIndex node : BipartiteLeftNodes()) {
1135 const NodeIndex mate = GetMate(node);
1136 matched_arc_[node] = GraphType::kNilArc;
1137 matched_node_[mate] = GraphType::kNilNode;
1143template <typename GraphType, typename CostValue>
1144bool LinearSumAssignment<GraphType, CostValue>::DoublePush(NodeIndex source) {
1145 DCHECK_GT(num_left_nodes_, source);
1146 DCHECK(IsActive(source)) << "Node " << source
1147 << "must be active (unmatched)!";
1148 ImplicitPriceSummary summary = BestArcAndGap(source);
1149 const ArcIndex best_arc = summary.first;
1150 const CostValue gap = summary.second;
1154 if (best_arc == GraphType::kNilArc) {
1157 const NodeIndex new_mate = Head(best_arc);
1158 const NodeIndex to_unmatch = matched_node_[new_mate];
1159 if (to_unmatch != GraphType::kNilNode) {
1162 matched_arc_[to_unmatch] = GraphType::kNilArc;
1163 active_nodes_->Add(to_unmatch);
1165 iteration_stats_.double_pushes += 1;
1170 iteration_stats_.pushes += 1;
1172 matched_arc_[source] = best_arc;
1173 matched_node_[new_mate] = source;
1175 iteration_stats_.relabelings += 1;
1176 const CostValue new_price = price_[new_mate] - gap - epsilon_;
1177 price_[new_mate] = new_price;
1178 return new_price >= price_lower_bound_;
1181template <typename GraphType, typename CostValue>
1182bool LinearSumAssignment<GraphType, CostValue>::Refine() {
1184 InitializeActiveNodeContainer();
1185 while (total_excess_ > 0) {
1188 const NodeIndex node = active_nodes_->Get();
1197 LOG_IF(DFATAL, total_stats_.refinements > 0)
1198 << "Infeasibility detection triggered after first iteration found "
1199 << "a feasible assignment!";
1203 DCHECK(active_nodes_->Empty());
1204 iteration_stats_.refinements += 1;
1222template <typename GraphType, typename CostValue>
1223inline typename LinearSumAssignment<GraphType, CostValue>::ImplicitPriceSummary
1224LinearSumAssignment<GraphType, CostValue>::BestArcAndGap(
1226 DCHECK(IsActive(left_node))
1227 << "Node " << left_node << " must be active (unmatched)!";
1229 const auto arcs = graph_->OutgoingArcs(left_node);
1230 auto arc_it = arcs.begin();
1232 ArcIndex best_arc = *arc_it;
1233 CostValue min_partial_reduced_cost = PartialReducedCost(best_arc);
1239 const CostValue max_gap = slack_relabeling_price_ - epsilon_;
1240 CostValue second_min_partial_reduced_cost =
1241 min_partial_reduced_cost + max_gap;
1242 for (++arc_it; arc_it != arcs.end(); ++arc_it) {
1243 const ArcIndex arc = *arc_it;
1244 const CostValue partial_reduced_cost = PartialReducedCost(arc);
1245 if (partial_reduced_cost < second_min_partial_reduced_cost) {
1246 if (partial_reduced_cost < min_partial_reduced_cost) {
1248 second_min_partial_reduced_cost = min_partial_reduced_cost;
1249 min_partial_reduced_cost = partial_reduced_cost;
1251 second_min_partial_reduced_cost = partial_reduced_cost;
1255 const CostValue gap = std::min<CostValue>(
1256 second_min_partial_reduced_cost - min_partial_reduced_cost, max_gap);
1258 return std::make_pair(best_arc, gap);
1265template <typename GraphType, typename CostValue>
1266inline CostValue LinearSumAssignment<GraphType, CostValue>::ImplicitPrice(
1268 DCHECK_GT(num_left_nodes_, left_node);
1270 const auto arcs = graph_->OutgoingArcs(left_node);
1273 auto arc_it = arcs.begin();
1274 ArcIndex best_arc = *arc_it;
1275 if (best_arc == matched_arc_[left_node]) {
1277 if (arc_it != arcs.end()) {
1281 CostValue min_partial_reduced_cost = PartialReducedCost(best_arc);
1282 if (arc_it == arcs.end()) {
1287 return -(min_partial_reduced_cost + slack_relabeling_price_);
1289 for (++arc_it; arc_it != arcs.end(); ++arc_it) {
1290 const ArcIndex arc = *arc_it;
1291 if (arc != matched_arc_[left_node]) {
1292 const CostValue partial_reduced_cost = PartialReducedCost(arc);
1293 if (partial_reduced_cost < min_partial_reduced_cost) {
1294 min_partial_reduced_cost = partial_reduced_cost;
1298 return -min_partial_reduced_cost;
1302template <typename GraphType, typename CostValue>
1303bool LinearSumAssignment<GraphType, CostValue>::AllMatched() const {
1304 for (NodeIndex node = 0; node < graph_->num_nodes(); ++node) {
1305 if (IsActiveForDebugging(node)) {
1313template <typename GraphType, typename CostValue>
1318 CostValue left_node_price = ImplicitPrice(left_node);
1319 for (const ArcIndex arc : graph_->OutgoingArcs(left_node)) {
1320 const CostValue reduced_cost = left_node_price + PartialReducedCost(arc);
1325 if (matched_arc_[left_node] == arc) {
1344template <typename GraphType, typename CostValue>
1346 incidence_precondition_satisfied_ = true;
1350 epsilon_ = std::max(largest_scaled_cost_magnitude_, kMinEpsilon + 1);
1351 VLOG(2) << "Largest given cost magnitude: "
1352 << largest_scaled_cost_magnitude_ / cost_scaling_factor_;
1355 for (NodeIndex node = 0; node < num_left_nodes_; ++node) {
1356 matched_arc_[node] = GraphType::kNilArc;
1357 if (graph_->OutgoingArcs(node).empty()) {
1358 incidence_precondition_satisfied_ = false;
1363 for (NodeIndex node = num_left_nodes_; node < graph_->num_nodes(); ++node) {
1365 matched_node_[node] = GraphType::kNilNode;
1368 double double_price_lower_bound = 0.0;
1370 CostValue old_error_parameter = epsilon_;
1372 new_error_parameter = NewEpsilon(old_error_parameter);
1373 double_price_lower_bound -=
1374 2.0 * static_cast<double>(PriceChangeBound(
1375 old_error_parameter, new_error_parameter, &in_range));
1376 old_error_parameter = new_error_parameter;
1377 } while (new_error_parameter != kMinEpsilon);
1379 -static_cast<double>(std::numeric_limits<CostValue>::max());
1380 if (double_price_lower_bound < limit) {
1382 price_lower_bound_ = -std::numeric_limits<CostValue>::max();
1384 price_lower_bound_ = static_cast<CostValue>(double_price_lower_bound);
1386 VLOG(4) << "price_lower_bound_ == " << price_lower_bound_;
1387 DCHECK_LE(price_lower_bound_, 0);
1389 LOG(WARNING) << "Price change bound exceeds range of representable "
1390 << "costs; arithmetic overflow is not ruled out and "
1396template <typename GraphType, typename CostValue>
1397void LinearSumAssignment<GraphType, CostValue>::ReportAndAccumulateStats() {
1398 total_stats_.Add(iteration_stats_);
1399 VLOG(3) << "Iteration stats: " << iteration_stats_.StatsString();
1400 iteration_stats_.Clear();
1403template <typename GraphType, typename CostValue>
1405 CHECK(graph_ != nullptr);
1406 bool ok = graph_->num_nodes() == 2 * num_left_nodes_;
1414 ok = ok && incidence_precondition_satisfied_;
1416 while (ok && epsilon_ > kMinEpsilon) {
1417 ok = ok && UpdateEpsilon();
1419 ReportAndAccumulateStats();
1421 DCHECK(!ok || AllMatched());
1424 VLOG(1) << "Overall stats: " << total_stats_.StatsString();
1428template <typename GraphType, typename CostValue>
bool operator()(typename GraphType::ArcIndex a, typename GraphType::ArcIndex b) const
Definition linear_assignment.h:1039
ArcIndexOrderingByTailNode(const GraphType &graph)
Definition linear_assignment.h:1034
CostValueCycleHandler(const CostValueCycleHandler &)=delete
CostValueCycleHandler & operator=(const CostValueCycleHandler &)=delete
void SetTempFromIndex(ArcIndexType source) override
Definition linear_assignment.h:1006
void SetIndexFromTemp(ArcIndexType destination) const override
Definition linear_assignment.h:1015
~CostValueCycleHandler() override
Definition linear_assignment.h:1019
CostValueCycleHandler(std::vector< CostValue > *cost)
Definition linear_assignment.h:999
void SetIndexFromIndex(ArcIndexType source, ArcIndexType destination) const override
Definition linear_assignment.h:1010
void SetArcCost(ArcIndex arc, CostValue cost)
Definition linear_assignment.h:981
CostValue GetAssignmentCost(NodeIndex node) const
Definition linear_assignment.h:334
operations_research::PermutationCycleHandler< typename GraphType::ArcIndex > * ArcAnnotationCycleHandler()
Definition linear_assignment.h:1057
std::string StatsString() const
Definition linear_assignment.h:346
NodeIndex NumLeftNodes() const
Definition linear_assignment.h:324
NodeIndex NumNodes() const
Definition linear_assignment.h:312
NodeIndex Head(ArcIndex arc) const
Definition linear_assignment.h:276
ArcIndex GetAssignmentArc(NodeIndex left_node) const
Definition linear_assignment.h:327
bool ComputeAssignment()
Definition linear_assignment.h:1404
~LinearSumAssignment()
Definition linear_assignment.h:246
CostValue CostValueT
Definition linear_assignment.h:229
const GraphType & Graph() const
Definition linear_assignment.h:270
NodeIndex GetMate(NodeIndex left_node) const
Definition linear_assignment.h:339
GraphType::NodeIndex NodeIndex
Definition linear_assignment.h:227
bool FinalizeSetup()
Definition linear_assignment.h:1345
bool EpsilonOptimal() const
Definition linear_assignment.h:1314
CostValue GetCost() const
Definition linear_assignment.h:1429
::util::IntegerRange< NodeIndex > BipartiteLeftNodes() const
Definition linear_assignment.h:349
LinearSumAssignment(const LinearSumAssignment &)=delete
void SetGraph(const GraphType *graph)
Definition linear_assignment.h:251
CostValue ArcCost(ArcIndex arc) const
Definition linear_assignment.h:280
LinearSumAssignment(const GraphType &graph, NodeIndex num_left_nodes)
Definition linear_assignment.h:935
LinearSumAssignment & operator=(const LinearSumAssignment &)=delete
void SetCostScalingDivisor(CostValue factor)
Definition linear_assignment.h:258
GraphType::ArcIndex ArcIndex
Definition linear_assignment.h:228
PermutationCycleHandler(const PermutationCycleHandler &)=delete
OR_DLL ABSL_DECLARE_FLAG(int64_t, assignment_alpha)
BlossomGraph::CostValue CostValue
BlossomGraph::NodeIndex NodeIndex