Google OR-Tools: ortools/constraint_solver/element.cc Source File

1

2

3

4

5

6

7

8

9

10

11

12

13

14#include <algorithm>

15#include <cstdint>

16#include <functional>

17#include <limits>

18#include <memory>

19#include <numeric>

20#include <string>

21#include <utility>

22#include <vector>

23

24#include "absl/strings/str_format.h"

25#include "absl/strings/str_join.h"

32

// Command-line switch for the model cache of constant-array element
// expressions. BuildElement() checks it twice: before looking up an existing
// VAR_CONSTANT_ARRAY_ELEMENT expression and before inserting a newly built
// one. Default is true, i.e. caching is off unless explicitly enabled.
33ABSL_FLAG(bool, cp_disable_element_cache, true,
34 "If true, caching for IntElement is disabled.");

35

37

38

// Forward declaration only; no definition appears in the visible source.
// NOTE(review): from its name and parameters this presumably ties `var` to
// `expr` (e.g. via a cast constraint) — confirm against its definition.
39void LinkVarExpr(Solver* s, IntExpr* expr, IntVar* var);

40

41namespace {

42

// Comparator over indices into an external vector: x orders before y when
// the value stored at x is smaller than the value stored at y. Only a pointer
// to the vector is kept, so the vector must outlive the comparator.
template <class T>
class VectorLess {
 public:
  explicit VectorLess(const std::vector<T>* values) : values_(values) {}

  bool operator()(const T& x, const T& y) const {
    const std::vector<T>& table = *values_;
    const T& left = table[x];
    const T& right = table[y];
    return left < right;
  }

 private:
  const std::vector<T>* values_;
};

54

// Comparator over indices into an external vector: x orders before y when
// the value stored at x is larger than the value stored at y. Only a pointer
// to the vector is kept, so the vector must outlive the comparator.
template <class T>
class VectorGreater {
 public:
  explicit VectorGreater(const std::vector<T>* values) : values_(values) {}

  bool operator()(const T& x, const T& y) const {
    const std::vector<T>& table = *values_;
    const T& left = table[x];
    const T& right = table[y];
    return left > right;
  }

 private:
  const std::vector<T>* values_;
};

66

67

68

69class BaseIntExprElement : public BaseIntExpr {

70 public:

71 BaseIntExprElement(Solver* s, IntVar* e);

72 ~BaseIntExprElement() override {}

73 int64_t Min() const override;

74 int64_t Max() const override;

75 void Range(int64_t* mi, int64_t* ma) override;

76 void SetMin(int64_t m) override;

77 void SetMax(int64_t m) override;

78 void SetRange(int64_t mi, int64_t ma) override;

79 bool Bound() const override { return (expr_->Bound()); }

80

81 void WhenRange(Demon* d) override { expr_->WhenRange(d); }

82

83 protected:

84 virtual int64_t ElementValue(int index) const = 0;

85 virtual int64_t ExprMin() const = 0;

86 virtual int64_t ExprMax() const = 0;

87

88 IntVar* const expr_;

89

90 private:

91 void UpdateSupports() const;

92 template <typename T>

93 void UpdateElementIndexBounds(T check_value) {

94 const int64_t emin = ExprMin();

95 const int64_t emax = ExprMax();

96 int64_t nmin = emin;

97 int64_t value = ElementValue(nmin);

98 while (nmin < emax && check_value(value)) {

99 nmin++;

100 value = ElementValue(nmin);

101 }

102 if (nmin == emax && check_value(value)) {

103 solver()->Fail();

104 }

105 int64_t nmax = emax;

106 value = ElementValue(nmax);

107 while (nmax >= nmin && check_value(value)) {

108 nmax--;

109 value = ElementValue(nmax);

110 }

111 expr_->SetRange(nmin, nmax);

112 }

113

114 mutable int64_t min_;

115 mutable int min_support_;

116 mutable int64_t max_;

117 mutable int max_support_;

118 mutable bool initial_update_;

119 IntVarIterator* const expr_iterator_;

120};

121

// Stores the index variable and creates a domain iterator over it (used by
// UpdateSupports() when the domain has holes). The cached bounds start out
// invalid: initial_update_ forces the first UpdateSupports() to compute them,
// and the supports are set to -1.
// NOTE(review): the base-class initializer line of this constructor is
// missing from this listing; it is expected to be `: BaseIntExpr(s),` —
// confirm against upstream.
122BaseIntExprElement::BaseIntExprElement(Solver* const s, IntVar* const e)
124 expr_(e),
125 min_(0),
126 min_support_(-1),
127 max_(0),
128 max_support_(-1),
129 initial_update_(true),
130 expr_iterator_(expr_->MakeDomainIterator(true)) {
// NOTE(review): expr_ is already dereferenced in the initializer above, so a
// null `e` crashes before this CHECK can report it.
131 CHECK(s != nullptr);
132 CHECK(e != nullptr);
133}

134

135int64_t BaseIntExprElement::Min() const {

136 UpdateSupports();

137 return min_;

138}

139

140int64_t BaseIntExprElement::Max() const {

141 UpdateSupports();

142 return max_;

143}

144

145void BaseIntExprElement::Range(int64_t* mi, int64_t* ma) {

146 UpdateSupports();

147 *mi = min_;

148 *ma = max_;

149}

150

151void BaseIntExprElement::SetMin(int64_t m) {

152 UpdateElementIndexBounds([m](int64_t value) { return value < m; });

153}

154

155void BaseIntExprElement::SetMax(int64_t m) {

156 UpdateElementIndexBounds([m](int64_t value) { return value > m; });

157}

158

159void BaseIntExprElement::SetRange(int64_t mi, int64_t ma) {

160 if (mi > ma) {

161 solver()->Fail();

162 }

163 UpdateElementIndexBounds(

164 [mi, ma](int64_t value) { return value < mi || value > ma; });

165}

166

167void BaseIntExprElement::UpdateSupports() const {

168 if (initial_update_ || !expr_->Contains(min_support_) ||

169 !expr_->Contains(max_support_)) {

170 const int64_t emin = ExprMin();

171 const int64_t emax = ExprMax();

172 int64_t min_value = ElementValue(emax);

173 int64_t max_value = min_value;

174 int min_support = emax;

175 int max_support = emax;

176 const uint64_t expr_size = expr_->Size();

177 if (expr_size > 1) {

178 if (expr_size == emax - emin + 1) {

179

180 for (int64_t index = emin; index < emax; ++index) {

181 const int64_t value = ElementValue(index);

182 if (value > max_value) {

183 max_value = value;

184 max_support = index;

185 } else if (value < min_value) {

186 min_value = value;

187 min_support = index;

188 }

189 }

190 } else {

191 for (const int64_t index : InitAndGetValues(expr_iterator_)) {

192 if (index >= emin && index <= emax) {

193 const int64_t value = ElementValue(index);

194 if (value > max_value) {

195 max_value = value;

196 max_support = index;

197 } else if (value < min_value) {

198 min_value = value;

199 min_support = index;

200 }

201 }

202 }

203 }

204 }

205 Solver* s = solver();

206 s->SaveAndSetValue(&min_, min_value);

207 s->SaveAndSetValue(&min_support_, min_support);

208 s->SaveAndSetValue(&max_, max_value);

209 s->SaveAndSetValue(&max_support_, max_support);

210 s->SaveAndSetValue(&initial_update_, false);

211 }

212}

213

214

215

216

217

218

219class IntElementConstraint : public CastConstraint {

220 public:

221 IntElementConstraint(Solver* const s, const std::vector<int64_t>& values,

222 IntVar* const index, IntVar* const elem)

223 : CastConstraint(s, elem),

224 values_(values),

225 index_(index),

226 index_iterator_(index_->MakeDomainIterator(true)) {

227 CHECK(index != nullptr);

228 }

229

230 void Post() override {

231 Demon* const d =

232 solver()->MakeDelayedConstraintInitialPropagateCallback(this);

233 index_->WhenDomain(d);

234 target_var_->WhenRange(d);

235 }

236

237 void InitialPropagate() override {

238 index_->SetRange(0, values_.size() - 1);

239 const int64_t target_var_min = target_var_->Min();

240 const int64_t target_var_max = target_var_->Max();

241 int64_t new_min = target_var_max;

242 int64_t new_max = target_var_min;

243 to_remove_.clear();

244 for (const int64_t index : InitAndGetValues(index_iterator_)) {

245 const int64_t value = values_[index];

246 if (value < target_var_min || value > target_var_max) {

247 to_remove_.push_back(index);

248 } else {

249 if (value < new_min) {

250 new_min = value;

251 }

252 if (value > new_max) {

253 new_max = value;

254 }

255 }

256 }

257 target_var_->SetRange(new_min, new_max);

258 if (!to_remove_.empty()) {

259 index_->RemoveValues(to_remove_);

260 }

261 }

262

263 std::string DebugString() const override {

264 return absl::StrFormat("IntElementConstraint(%s, %s, %s)",

265 absl::StrJoin(values_, ", "), index_->DebugString(),

266 target_var_->DebugString());

267 }

268

269 void Accept(ModelVisitor* const visitor) const override {

270 visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);

271 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);

272 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,

273 index_);

274 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,

275 target_var_);

276 visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);

277 }

278

279 private:

280 const std::vector<int64_t> values_;

281 IntVar* const index_;

282 IntVarIterator* const index_iterator_;

283 std::vector<int64_t> to_remove_;

284};

285

286

287

// Forward declaration only; neither a definition nor a use appears in the
// visible source.
// NOTE(review): the name suggests it builds a variable whose domain is
// `*values` (passed by non-const pointer, so it may be modified) — confirm.
288IntVar* BuildDomainIntVar(Solver* solver, std::vector<int64_t>* values);

289

290class IntExprElement : public BaseIntExprElement {

291 public:

292 IntExprElement(Solver* const s, const std::vector<int64_t>& vals,

293 IntVar* const expr)

294 : BaseIntExprElement(s, expr), values_(vals) {}

295

296 ~IntExprElement() override {}

297

298 std::string name() const override {

299 const int size = values_.size();

300 if (size > 10) {

301 return absl::StrFormat("IntElement(array of size %d, %s)", size,

302 expr_->name());

303 } else {

304 return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),

305 expr_->name());

306 }

307 }

308

309 std::string DebugString() const override {

310 const int size = values_.size();

311 if (size > 10) {

312 return absl::StrFormat("IntElement(array of size %d, %s)", size,

313 expr_->DebugString());

314 } else {

315 return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),

316 expr_->DebugString());

317 }

318 }

319

320 IntVar* CastToVar() override {

321 Solver* const s = solver();

322 IntVar* const var = s->MakeIntVar(values_);

323 s->AddCastConstraint(

324 s->RevAlloc(new IntElementConstraint(s, values_, expr_, var)), var,

325 this);

326 return var;

327 }

328

329 void Accept(ModelVisitor* const visitor) const override {

330 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);

331 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);

332 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,

333 expr_);

334 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);

335 }

336

337 protected:

338 int64_t ElementValue(int index) const override {

339 DCHECK_LT(index, values_.size());

340 return values_[index];

341 }

342 int64_t ExprMin() const override {

343 return std::max<int64_t>(0, expr_->Min());

344 }

345 int64_t ExprMax() const override {

346 return values_.empty()

347 ? 0

348 : std::min<int64_t>(values_.size() - 1, expr_->Max());

349 }

350

351 private:

352 const std::vector<int64_t> values_;

353};

354

355

356

357class RangeMinimumQueryExprElement : public BaseIntExpr {

358 public:

359 RangeMinimumQueryExprElement(Solver* solver,

360 const std::vector<int64_t>& values,

361 IntVar* index);

362 ~RangeMinimumQueryExprElement() override {}

363 int64_t Min() const override;

364 int64_t Max() const override;

365 void Range(int64_t* mi, int64_t* ma) override;

366 void SetMin(int64_t m) override;

367 void SetMax(int64_t m) override;

368 void SetRange(int64_t mi, int64_t ma) override;

369 bool Bound() const override { return (index_->Bound()); }

370

371 void WhenRange(Demon* d) override { index_->WhenRange(d); }

372 IntVar* CastToVar() override {

373

374

375

376 IntVar* const var = solver()->MakeIntVar(min_rmq_.array());

377 solver()->AddCastConstraint(solver()->RevAlloc(new IntElementConstraint(

378 solver(), min_rmq_.array(), index_, var)),

379 var, this);

380 return var;

381 }

382 void Accept(ModelVisitor* const visitor) const override {

383 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);

384 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument,

385 min_rmq_.array());

386 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,

387 index_);

388 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);

389 }

390

391 private:

392 int64_t IndexMin() const { return std::max<int64_t>(0, index_->Min()); }

393 int64_t IndexMax() const {

394 return std::min<int64_t>(min_rmq_.array().size() - 1, index_->Max());

395 }

396

397 IntVar* const index_;

398 const RangeMinimumQuery<int64_t, std::less<int64_t>> min_rmq_;

399 const RangeMinimumQuery<int64_t, std::greater<int64_t>> max_rmq_;

400};

401

402RangeMinimumQueryExprElement::RangeMinimumQueryExprElement(

403 Solver* solver, const std::vector<int64_t>& values, IntVar* index)

404 : BaseIntExpr(solver), index_(index), min_rmq_(values), max_rmq_(values) {

405 CHECK(solver != nullptr);

406 CHECK(index != nullptr);

407}

408

409int64_t RangeMinimumQueryExprElement::Min() const {

410 return min_rmq_.RangeMinimum(IndexMin(), IndexMax() + 1);

411}

412

413int64_t RangeMinimumQueryExprElement::Max() const {

414 return max_rmq_.RangeMinimum(IndexMin(), IndexMax() + 1);

415}

416

417void RangeMinimumQueryExprElement::Range(int64_t* mi, int64_t* ma) {

418 const int64_t range_min = IndexMin();

419 const int64_t range_max = IndexMax() + 1;

420 *mi = min_rmq_.RangeMinimum(range_min, range_max);

421 *ma = max_rmq_.RangeMinimum(range_min, range_max);

422}

423

424#define UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(test) \

425 const std::vector<int64_t>& values = min_rmq_.array(); \

426 int64_t index_min = IndexMin(); \

427 int64_t index_max = IndexMax(); \

428 int64_t value = values[index_min]; \

429 while (index_min < index_max && (test)) { \

430 index_min++; \

431 value = values[index_min]; \

432 } \

433 if (index_min == index_max && (test)) { \

434 solver()->Fail(); \

435 } \

436 value = values[index_max]; \

437 while (index_max >= index_min && (test)) { \

438 index_max--; \

439 value = values[index_max]; \

440 } \

441 index_->SetRange(index_min, index_max);

442

443void RangeMinimumQueryExprElement::SetMin(int64_t m) {

445}

446

447void RangeMinimumQueryExprElement::SetMax(int64_t m) {

449}

450

451void RangeMinimumQueryExprElement::SetRange(int64_t mi, int64_t ma) {

452 if (mi > ma) {

453 solver()->Fail();

454 }

456}

457

458#undef UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS

459

460

461

462class IncreasingIntExprElement : public BaseIntExpr {

463 public:

464 IncreasingIntExprElement(Solver* s, const std::vector<int64_t>& values,

465 IntVar* index);

466 ~IncreasingIntExprElement() override {}

467

468 int64_t Min() const override;

469 void SetMin(int64_t m) override;

470 int64_t Max() const override;

471 void SetMax(int64_t m) override;

472 void SetRange(int64_t mi, int64_t ma) override;

473 bool Bound() const override { return (index_->Bound()); }

474

475 std::string name() const override {

476 return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),

477 index_->name());

478 }

479 std::string DebugString() const override {

480 return absl::StrFormat("IntElement(%s, %s)", absl::StrJoin(values_, ", "),

481 index_->DebugString());

482 }

483

484 void Accept(ModelVisitor* const visitor) const override {

485 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);

486 visitor->VisitIntegerArrayArgument(ModelVisitor::kValuesArgument, values_);

487 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,

488 index_);

489 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);

490 }

491

492 void WhenRange(Demon* d) override { index_->WhenRange(d); }

493

494 IntVar* CastToVar() override {

495 Solver* const s = solver();

496 IntVar* const var = s->MakeIntVar(values_);

498 return var;

499 }

500

501 private:

502 const std::vector<int64_t> values_;

503 IntVar* const index_;

504};

505

506IncreasingIntExprElement::IncreasingIntExprElement(

507 Solver* const s, const std::vector<int64_t>& values, IntVar* const index)

508 : BaseIntExpr(s), values_(values), index_(index) {

509 DCHECK(index);

510 DCHECK(s);

511}

512

513int64_t IncreasingIntExprElement::Min() const {

514 const int64_t expression_min = std::max<int64_t>(0, index_->Min());

515 return (expression_min < values_.size()

516 ? values_[expression_min]

517 : std::numeric_limits<int64_t>::max());

518}

519

520void IncreasingIntExprElement::SetMin(int64_t m) {

521 const int64_t index_min = std::max<int64_t>(0, index_->Min());

522 const int64_t index_max =

523 std::min<int64_t>(values_.size() - 1, index_->Max());

524

525 if (index_min > index_max || m > values_[index_max]) {

526 solver()->Fail();

527 }

528

529 const std::vector<int64_t>::const_iterator first =

530 std::lower_bound(values_.begin(), values_.end(), m);

531 const int64_t new_index_min = first - values_.begin();

532 index_->SetMin(new_index_min);

533}

534

535int64_t IncreasingIntExprElement::Max() const {

536 const int64_t expression_max =

537 std::min<int64_t>(values_.size() - 1, index_->Max());

538 return (expression_max >= 0 ? values_[expression_max]

539 : std::numeric_limits<int64_t>::max());

540}

541

542void IncreasingIntExprElement::SetMax(int64_t m) {

543 int64_t index_min = std::max<int64_t>(0, index_->Min());

544 if (m < values_[index_min]) {

545 solver()->Fail();

546 }

547

548 const std::vector<int64_t>::const_iterator last_after =

549 std::upper_bound(values_.begin(), values_.end(), m);

550 const int64_t new_index_max = (last_after - values_.begin()) - 1;

551 index_->SetRange(0, new_index_max);

552}

553

554void IncreasingIntExprElement::SetRange(int64_t mi, int64_t ma) {

555 if (mi > ma) {

556 solver()->Fail();

557 }

558 const int64_t index_min = std::max<int64_t>(0, index_->Min());

559 const int64_t index_max =

560 std::min<int64_t>(values_.size() - 1, index_->Max());

561

562 if (mi > ma || ma < values_[index_min] || mi > values_[index_max]) {

563 solver()->Fail();

564 }

565

566 const std::vector<int64_t>::const_iterator first =

567 std::lower_bound(values_.begin(), values_.end(), mi);

568 const int64_t new_index_min = first - values_.begin();

569

570 const std::vector<int64_t>::const_iterator last_after =

571 std::upper_bound(first, values_.end(), ma);

572 const int64_t new_index_max = (last_after - values_.begin()) - 1;

573

574

575 index_->SetRange(new_index_min, new_index_max);

576}

577

578

// Builds the expression `values[index]` for a constant array, choosing a
// specialized representation based on the shape of `values`.
// NOTE(review): several guard lines of this function are missing from this
// listing: the `if` opening the first branch (which returns a constant), the
// `if` opening the 0/1-array branch, and the two `else if` lines before the
// `MakeSum(index, values[0])` assignment and before the
// IncreasingIntExprElement construction. As shown, the braces do not
// balance. Restore them from upstream before changing the logic.
579IntExpr* BuildElement(Solver* const solver, const std::vector<int64_t>& values,
580 IntVar* const index) {
581
582
// Returns values[0] as a constant after restricting index to valid
// positions; presumably guarded by an "all entries equal" test (missing).
584 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
585 return solver->MakeIntConst(values[0]);
586 }
587
588
// Boolean-array case (guard missing; presumably all entries are 0 or 1).
// `ones` collects the positions holding 1. Despite its name, `first_zero`
// ends up holding the LAST position whose value is not 1.
590 std::vector<int64_t> ones;
591 int first_zero = -1;
592 for (int i = 0; i < values.size(); ++i) {
593 if (values[i] == 1) {
594 ones.push_back(i);
595 } else {
596 first_zero = i;
597 }
598 }
// A single 1: the element is the Boolean (index == that position).
599 if (ones.size() == 1) {
600 DCHECK_EQ(int64_t{1}, values[ones.back()]);
601 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
602 return solver->MakeIsEqualCstVar(index, ones.back());
// A single non-1 entry: the element is (index != that position).
603 } else if (ones.size() == values.size() - 1) {
604 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
605 return solver->MakeIsDifferentCstVar(index, first_zero);
// The 1s form one contiguous run: the element is
// (index in [ones.front(), ones.back()]).
// NOTE(review): ones.back()/front() assume `ones` is non-empty; presumably
// the constant-array branch above already handles the all-zero case.
606 } else if (ones.size() == ones.back() - ones.front() + 1) {
607 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
608 IntVar* const b = solver->MakeBoolVar("ContiguousBooleanElementVar");
609 solver->AddConstraint(
610 solver->MakeIsBetweenCt(index, ones.front(), ones.back(), b));
611 return b;
// Scattered 1s: the element is (index is a member of `ones`).
612 } else {
613 IntVar* const b = solver->MakeBoolVar("NonContiguousBooleanElementVar");
614 solver->AddConstraint(solver->MakeBetweenCt(index, 0, values.size() - 1));
615 solver->AddConstraint(solver->MakeIsMemberCt(index, ones, b));
616 return b;
617 }
618 }
// General case: reuse a cached expression when the cache is enabled.
619 IntExpr* cache = nullptr;
620 if (!absl::GetFlag(FLAGS_cp_disable_element_cache)) {
621 cache = solver->Cache()->FindVarConstantArrayExpression(
622 index, values, ModelCache::VAR_CONSTANT_ARRAY_ELEMENT);
623 }
624 if (cache != nullptr) {
625 return cache;
626 } else {
627 IntExpr* result = nullptr;
// Index domain is exactly [0, 1]: affine form
// values[0] + index * (values[1] - values[0]).
628 if (values.size() >= 2 && index->Min() == 0 && index->Max() == 1) {
629 result = solver->MakeSum(solver->MakeProd(index, values[1] - values[0]),
630 values[0]);
// Two-entry array with 0 and 1 both feasible: restrict index to [0, 1],
// then use the same affine form.
631 } else if (values.size() == 2 && index->Contains(0) && index->Contains(1)) {
632 solver->AddConstraint(solver->MakeBetweenCt(index, 0, 1));
633 result = solver->MakeSum(solver->MakeProd(index, values[1] - values[0]),
634 values[0]);
// Presumably the "values increase by exactly 1" case (guard missing):
// values[index] == index + values[0].
636 result = solver->MakeSum(index, values[0]);
// Presumably the "values sorted non-decreasing" case (guard missing).
638 result = solver->RegisterIntExpr(solver->RevAlloc(
639 new IncreasingIntExprElement(solver, values, index)));
640 } else {
// Generic array: RMQ-based bounds when the solver parameter
// use_element_rmq() is set, support-based bounds otherwise.
641 if (solver->parameters().use_element_rmq()) {
642 result = solver->RegisterIntExpr(solver->RevAlloc(
643 new RangeMinimumQueryExprElement(solver, values, index)));
644 } else {
645 result = solver->RegisterIntExpr(
646 solver->RevAlloc(new IntExprElement(solver, values, index)));
647 }
648 }
649 if (!absl::GetFlag(FLAGS_cp_disable_element_cache)) {
650 solver->Cache()->InsertVarConstantArrayExpression(
651 result, index, values, ModelCache::VAR_CONSTANT_ARRAY_ELEMENT);
652 }
653 return result;
654 }
655}

656}

657

// Solver factory for `values[index]` over an int64_t array. The first line
// of the signature is missing from this listing.
// NOTE(review): the body of the `index->Bound()` branch is also missing;
// presumably it returns the constant at the bound index — confirm.
// Otherwise the work is delegated to BuildElement().
659 IntVar* const index) {
660 DCHECK(index);
661 DCHECK_EQ(this, index->solver());
662 if (index->Bound()) {
664 }
665 return BuildElement(this, values, index);
666}

667

// Overload of MakeElement for an array of a narrower integer type (the first
// signature line is missing from this listing). The values are widened with
// ToInt64Vector() before delegating to BuildElement().
// NOTE(review): the body of the `index->Bound()` branch is missing from this
// listing; presumably it returns a constant — confirm.
669 IntVar* const index) {
670 DCHECK(index);
671 DCHECK_EQ(this, index->solver());
672 if (index->Bound()) {
674 }
675 return BuildElement(this, ToInt64Vector(values), index);
676}

677

678

679

680namespace {

681class IntExprFunctionElement : public BaseIntExprElement {

682 public:

684 ~IntExprFunctionElement() override;

685

686 std::string name() const override {

687 return absl::StrFormat("IntFunctionElement(%s)", expr_->name());

688 }

689

690 std::string DebugString() const override {

691 return absl::StrFormat("IntFunctionElement(%s)", expr_->DebugString());

692 }

693

694 void Accept(ModelVisitor* const visitor) const override {

695

696 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);

697 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,

698 expr_);

699 visitor->VisitInt64ToInt64Extension(values_, expr_->Min(), expr_->Max());

700 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);

701 }

702

703 protected:

704 int64_t ElementValue(int index) const override { return values_(index); }

705 int64_t ExprMin() const override { return expr_->Min(); }

706 int64_t ExprMax() const override { return expr_->Max(); }

707

708 private:

709 Solver::IndexEvaluator1 values_;

710};

711

712IntExprFunctionElement::IntExprFunctionElement(Solver* const s,

713 Solver::IndexEvaluator1 values,

714 IntVar* const e)

715 : BaseIntExprElement(s, e), values_(std::move(values)) {

716 CHECK(values_ != nullptr);

717}

718

719IntExprFunctionElement::~IntExprFunctionElement() {}

720

721

722

723class IncreasingIntExprFunctionElement : public BaseIntExpr {

724 public:

725 IncreasingIntExprFunctionElement(Solver* const s,

726 Solver::IndexEvaluator1 values,

727 IntVar* const index)

728 : BaseIntExpr(s), values_(std::move(values)), index_(index) {

729 DCHECK(values_ != nullptr);

730 DCHECK(index);

731 DCHECK(s);

732 }

733

734 ~IncreasingIntExprFunctionElement() override {}

735

736 int64_t Min() const override { return values_(index_->Min()); }

737

738 void SetMin(int64_t m) override {

739 const int64_t index_min = index_->Min();

740 const int64_t index_max = index_->Max();

741 if (m > values_(index_max)) {

742 solver()->Fail();

743 }

744 const int64_t new_index_min = FindNewIndexMin(index_min, index_max, m);

745 index_->SetMin(new_index_min);

746 }

747

748 int64_t Max() const override { return values_(index_->Max()); }

749

750 void SetMax(int64_t m) override {

751 int64_t index_min = index_->Min();

752 int64_t index_max = index_->Max();

753 if (m < values_(index_min)) {

754 solver()->Fail();

755 }

756 const int64_t new_index_max = FindNewIndexMax(index_min, index_max, m);

757 index_->SetMax(new_index_max);

758 }

759

760 void SetRange(int64_t mi, int64_t ma) override {

761 const int64_t index_min = index_->Min();

762 const int64_t index_max = index_->Max();

763 const int64_t value_min = values_(index_min);

764 const int64_t value_max = values_(index_max);

765 if (mi > ma || ma < value_min || mi > value_max) {

766 solver()->Fail();

767 }

768 if (mi <= value_min && ma >= value_max) {

769

770 return;

771 }

772

773 const int64_t new_index_min = FindNewIndexMin(index_min, index_max, mi);

774 const int64_t new_index_max = FindNewIndexMax(new_index_min, index_max, ma);

775

776 index_->SetRange(new_index_min, new_index_max);

777 }

778

779 std::string name() const override {

780 return absl::StrFormat("IncreasingIntExprFunctionElement(values, %s)",

781 index_->name());

782 }

783

784 std::string DebugString() const override {

785 return absl::StrFormat("IncreasingIntExprFunctionElement(values, %s)",

787 }

788

789 void WhenRange(Demon* d) override { index_->WhenRange(d); }

790

791 void Accept(ModelVisitor* const visitor) const override {

792

793 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);

794 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,

795 index_);

796 if (index_->Min() == 0) {

797 visitor->VisitInt64ToInt64AsArray(values_, ModelVisitor::kValuesArgument,

798 index_->Max());

799 } else {

800 visitor->VisitInt64ToInt64Extension(values_, index_->Min(),

801 index_->Max());

802 }

803 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);

804 }

805

806 private:

807 int64_t FindNewIndexMin(int64_t index_min, int64_t index_max, int64_t m) {

808 if (m <= values_(index_min)) {

809 return index_min;

810 }

811

812 DCHECK_LT(values_(index_min), m);

813 DCHECK_GE(values_(index_max), m);

814

815 int64_t index_lower_bound = index_min;

816 int64_t index_upper_bound = index_max;

817 while (index_upper_bound - index_lower_bound > 1) {

818 DCHECK_LT(values_(index_lower_bound), m);

819 DCHECK_GE(values_(index_upper_bound), m);

820 const int64_t pivot = (index_lower_bound + index_upper_bound) / 2;

821 const int64_t pivot_value = values_(pivot);

822 if (pivot_value < m) {

823 index_lower_bound = pivot;

824 } else {

825 index_upper_bound = pivot;

826 }

827 }

828 DCHECK(values_(index_upper_bound) >= m);

829 return index_upper_bound;

830 }

831

832 int64_t FindNewIndexMax(int64_t index_min, int64_t index_max, int64_t m) {

833 if (m >= values_(index_max)) {

834 return index_max;

835 }

836

837 DCHECK_LE(values_(index_min), m);

838 DCHECK_GT(values_(index_max), m);

839

840 int64_t index_lower_bound = index_min;

841 int64_t index_upper_bound = index_max;

842 while (index_upper_bound - index_lower_bound > 1) {

843 DCHECK_LE(values_(index_lower_bound), m);

844 DCHECK_GT(values_(index_upper_bound), m);

845 const int64_t pivot = (index_lower_bound + index_upper_bound) / 2;

846 const int64_t pivot_value = values_(pivot);

847 if (pivot_value > m) {

848 index_upper_bound = pivot;

849 } else {

850 index_lower_bound = pivot;

851 }

852 }

853 DCHECK(values_(index_lower_bound) <= m);

854 return index_lower_bound;

855 }

856

857 Solver::IndexEvaluator1 values_;

858 IntVar* const index_;

859};

860}

861

// Solver factory for `values(index)` with an arbitrary index evaluator. It
// allocates an IntExprFunctionElement through RevAlloc. The first signature
// line and the line opening the return expression are missing from this
// listing.
863 IntVar* const index) {
864 CHECK_EQ(this, index->solver());
866 RevAlloc(new IntExprFunctionElement(this, std::move(values), index)));
867}

868

// Solver factory for `values(index)` with a monotonic evaluator. For
// increasing evaluators an IncreasingIntExprFunctionElement is built
// directly. For decreasing ones the evaluator is negated (making it
// increasing) before building the same kind of element.
// NOTE(review): several lines are missing from this listing (the first
// signature line and the lines opening each return expression). The extra
// closing parentheses suggest the decreasing case wraps the result in a
// negation to undo the sign flip — confirm against upstream.
870 bool increasing, IntVar* const index) {
871 CHECK_EQ(this, index->solver());
872 if (increasing) {
874 new IncreasingIntExprFunctionElement(this, std::move(values), index)));
875 } else {
878 this,
879 [values = std::move(values)](int64_t i) { return -values(i); },
880 index))));
881 }
882}

883

884

885

886namespace {

// Two-index element expression `values(expr1, expr2)` backed by a
// two-argument evaluator. Bounds are cached in min_/max_ with one supporting
// index pair each and refreshed by UpdateSupports() when a support leaves
// its domain.
// NOTE(review): the constructor declaration, the second DebugString argument
// line, and the WhenRange body are missing from this listing.
887class IntIntExprFunctionElement : public BaseIntExpr {
888 public:
891 ~IntIntExprFunctionElement() override;
892 std::string DebugString() const override {
893 return absl::StrFormat("IntIntFunctionElement(%s,%s)",
895 }
896 int64_t Min() const override;
897 int64_t Max() const override;
898 void Range(int64_t* lower_bound, int64_t* upper_bound) override;
899 void SetMin(int64_t lower_bound) override;
900 void SetMax(int64_t upper_bound) override;
901 void SetRange(int64_t lower_bound, int64_t upper_bound) override;
902 bool Bound() const override { return expr1_->Bound() && expr2_->Bound(); }
903
904 void WhenRange(Demon* d) override {
907 }
908
// Reports both index expressions, the first index's bounds, and one
// evaluator extension per value of the first index over the second index's
// bounds.
909 void Accept(ModelVisitor* const visitor) const override {
910 visitor->BeginVisitIntegerExpression(ModelVisitor::kElement, this);
911 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,
912 expr1_);
913 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndex2Argument,
914 expr2_);
915
916 const int64_t expr1_min = expr1_->Min();
917 const int64_t expr1_max = expr1_->Max();
918 visitor->VisitIntegerArgument(ModelVisitor::kMinArgument, expr1_min);
919 visitor->VisitIntegerArgument(ModelVisitor::kMaxArgument, expr1_max);
// NOTE(review): `int i` narrows the int64_t bounds; indices outside the int
// range would be truncated.
920 for (int i = expr1_min; i <= expr1_max; ++i) {
921 visitor->VisitInt64ToInt64Extension(
922 [this, i](int64_t j) { return values_(i, j); }, expr2_->Min(),
923 expr2_->Max());
924 }
925 visitor->EndVisitIntegerExpression(ModelVisitor::kElement, this);
926 }
927
928 private:
929 int64_t ElementValue(int index1, int index2) const {
930 return values_(index1, index2);
931 }
932 void UpdateSupports() const;
933
934 IntVar* const expr1_;
935 IntVar* const expr2_;
// Cached bounds and their supporting index pairs; mutable because the const
// accessors refresh them.
936 mutable int64_t min_;
937 mutable int min_support1_;
938 mutable int min_support2_;
939 mutable int64_t max_;
940 mutable int max_support1_;
941 mutable int max_support2_;
// True until the first UpdateSupports() has filled the cache.
942 mutable bool initial_update_;
943 Solver::IndexEvaluator2 values_;
944 IntVarIterator* const expr1_iterator_;
945 IntVarIterator* const expr2_iterator_;
946};

947

// Stores both index variables and the evaluator, and creates a domain
// iterator per index (UpdateSupports() walks both domains). The cache starts
// invalid (initial_update_ = true, supports = -1).
// NOTE(review): the second line of the parameter list is missing from this
// listing. Also, expr1_/expr2_ are dereferenced in the initializers without a
// prior null check.
948IntIntExprFunctionElement::IntIntExprFunctionElement(
950 IntVar* const expr2)
951 : BaseIntExpr(s),
952 expr1_(expr1),
953 expr2_(expr2),
954 min_(0),
955 min_support1_(-1),
956 min_support2_(-1),
957 max_(0),
958 max_support1_(-1),
959 max_support2_(-1),
960 initial_update_(true),
961 values_(std::move(values)),
962 expr1_iterator_(expr1_->MakeDomainIterator(true)),
963 expr2_iterator_(expr2_->MakeDomainIterator(true)) {
964 CHECK(values_ != nullptr);
965}

966

967IntIntExprFunctionElement::~IntIntExprFunctionElement() {}

968

969int64_t IntIntExprFunctionElement::Min() const {

970 UpdateSupports();

971 return min_;

972}

973

974int64_t IntIntExprFunctionElement::Max() const {

975 UpdateSupports();

976 return max_;

977}

978

979void IntIntExprFunctionElement::Range(int64_t* lower_bound,

980 int64_t* upper_bound) {

981 UpdateSupports();

982 *lower_bound = min_;

983 *upper_bound = max_;

984}

985

986#define UPDATE_ELEMENT_INDEX_BOUNDS(test) \

987 const int64_t emin1 = expr1_->Min(); \

988 const int64_t emax1 = expr1_->Max(); \

989 const int64_t emin2 = expr2_->Min(); \

990 const int64_t emax2 = expr2_->Max(); \

991 int64_t nmin1 = emin1; \

992 bool found = false; \

993 while (nmin1 <= emax1 && !found) { \

994 for (int i = emin2; i <= emax2; ++i) { \

995 int64_t value = ElementValue(nmin1, i); \

996 if (test) { \

997 found = true; \

998 break; \

999 } \

1000 } \

1001 if (!found) { \

1002 nmin1++; \

1003 } \

1004 } \

1005 if (nmin1 > emax1) { \

1006 solver()->Fail(); \

1007 } \

1008 int64_t nmin2 = emin2; \

1009 found = false; \

1010 while (nmin2 <= emax2 && !found) { \

1011 for (int i = emin1; i <= emax1; ++i) { \

1012 int64_t value = ElementValue(i, nmin2); \

1013 if (test) { \

1014 found = true; \

1015 break; \

1016 } \

1017 } \

1018 if (!found) { \

1019 nmin2++; \

1020 } \

1021 } \

1022 if (nmin2 > emax2) { \

1023 solver()->Fail(); \

1024 } \

1025 int64_t nmax1 = emax1; \

1026 found = false; \

1027 while (nmax1 >= nmin1 && !found) { \

1028 for (int i = emin2; i <= emax2; ++i) { \

1029 int64_t value = ElementValue(nmax1, i); \

1030 if (test) { \

1031 found = true; \

1032 break; \

1033 } \

1034 } \

1035 if (!found) { \

1036 nmax1--; \

1037 } \

1038 } \

1039 int64_t nmax2 = emax2; \

1040 found = false; \

1041 while (nmax2 >= nmin2 && !found) { \

1042 for (int i = emin1; i <= emax1; ++i) { \

1043 int64_t value = ElementValue(i, nmax2); \

1044 if (test) { \

1045 found = true; \

1046 break; \

1047 } \

1048 } \

1049 if (!found) { \

1050 nmax2--; \

1051 } \

1052 } \

1053 expr1_->SetRange(nmin1, nmax1); \

1054 expr2_->SetRange(nmin2, nmax2);

1055

1056void IntIntExprFunctionElement::SetMin(int64_t lower_bound) {

1058}

1059

1060void IntIntExprFunctionElement::SetMax(int64_t upper_bound) {

1062}

1063

1064void IntIntExprFunctionElement::SetRange(int64_t lower_bound,

1065 int64_t upper_bound) {

1066 if (lower_bound > upper_bound) {

1067 solver()->Fail();

1068 }

1070}

1071

1072#undef UPDATE_ELEMENT_INDEX_BOUNDS

1073

1074void IntIntExprFunctionElement::UpdateSupports() const {

1075 if (initial_update_ || !expr1_->Contains(min_support1_) ||

1076 !expr1_->Contains(max_support1_) || !expr2_->Contains(min_support2_) ||

1077 !expr2_->Contains(max_support2_)) {

1078 const int64_t emax1 = expr1_->Max();

1079 const int64_t emax2 = expr2_->Max();

1080 int64_t min_value = ElementValue(emax1, emax2);

1081 int64_t max_value = min_value;

1082 int min_support1 = emax1;

1083 int max_support1 = emax1;

1084 int min_support2 = emax2;

1085 int max_support2 = emax2;

1086 for (const int64_t index1 : InitAndGetValues(expr1_iterator_)) {

1087 for (const int64_t index2 : InitAndGetValues(expr2_iterator_)) {

1088 const int64_t value = ElementValue(index1, index2);

1089 if (value > max_value) {

1090 max_value = value;

1091 max_support1 = index1;

1092 max_support2 = index2;

1093 } else if (value < min_value) {

1094 min_value = value;

1095 min_support1 = index1;

1096 min_support2 = index2;

1097 }

1098 }

1099 }

1100 Solver* s = solver();

1101 s->SaveAndSetValue(&min_, min_value);

1102 s->SaveAndSetValue(&min_support1_, min_support1);

1103 s->SaveAndSetValue(&min_support2_, min_support2);

1104 s->SaveAndSetValue(&max_, max_value);

1105 s->SaveAndSetValue(&max_support1_, max_support1);

1106 s->SaveAndSetValue(&max_support2_, max_support2);

1107 s->SaveAndSetValue(&initial_update_, false);

1108 }

1109}

1110}

1111

1113 IntVar* const index1, IntVar* const index2) {

1114 CHECK_EQ(this, index1->solver());

1115 CHECK_EQ(this, index2->solver());

1117 new IntIntExprFunctionElement(this, std::move(values), index1, index2)));

1118}

1119

1120

1121

1122

1123

1125 public:

1129 condition_(condition),

1130 zero_(zero),

1131 one_(one) {}

1132

1134

1136 Demon* const demon = solver()->MakeConstraintInitialPropagateCallback(this);

1137 condition_->WhenBound(demon);

1138 one_->WhenRange(demon);

1139 zero_->WhenRange(demon);

1141 }

1142

1144 condition_->SetRange(0, 1);

1145 const int64_t target_var_min = target_var_->Min();

1146 const int64_t target_var_max = target_var_->Max();

1147 int64_t new_min = std::numeric_limits<int64_t>::min();

1148 int64_t new_max = std::numeric_limits<int64_t>::max();

1149 if (condition_->Max() == 0) {

1150 zero_->SetRange(target_var_min, target_var_max);

1151 zero_->Range(&new_min, &new_max);

1152 } else if (condition_->Min() == 1) {

1153 one_->SetRange(target_var_min, target_var_max);

1154 one_->Range(&new_min, &new_max);

1155 } else {

1156 if (target_var_max < zero_->Min() || target_var_min > zero_->Max()) {

1157 condition_->SetValue(1);

1158 one_->SetRange(target_var_min, target_var_max);

1159 one_->Range(&new_min, &new_max);

1160 } else if (target_var_max < one_->Min() || target_var_min > one_->Max()) {

1161 condition_->SetValue(0);

1162 zero_->SetRange(target_var_min, target_var_max);

1163 zero_->Range(&new_min, &new_max);

1164 } else {

1165 int64_t zl = 0;

1166 int64_t zu = 0;

1167 int64_t ol = 0;

1168 int64_t ou = 0;

1169 zero_->Range(&zl, &zu);

1170 one_->Range(&ol, &ou);

1171 new_min = std::min(zl, ol);

1172 new_max = std::max(zu, ou);

1173 }

1174 }

1175 target_var_->SetRange(new_min, new_max);

1176 }

1177

1179 return absl::StrFormat("(%s ? %s : %s) == %s", condition_->DebugString(),

1180 one_->DebugString(), zero_->DebugString(),

1182 }

1183

1185

1186 private:

1187 IntVar* const condition_;

1190};

1191

1192

1193

1194

1195

1196

1197

1198namespace {

// Constraint enforcing target_var == evaluator(index), where `evaluator`
// maps each integer of [range_start, range_end) to an IntVar.
//
// Besides the bound reasoning on the index, the constraint caches the index
// values ("supports") whose expressions realized the current min and max of
// the target; a value of -1 means "not cached, recompute on next Propagate()".
1199class IntExprEvaluatorElementCt : public CastConstraint {
1200 public:
// `evaluator` is stored and only invoked later (Post/Propagate/DebugString),
// never during construction.
1201 IntExprEvaluatorElementCt(Solver* s, Solver::Int64ToIntVar evaluator,
1202 int64_t range_start, int64_t range_end,
1203 IntVar* index, IntVar* target_var);
1204 ~IntExprEvaluatorElementCt() override {}
1205
1206 void Post() override;
1207 void InitialPropagate() override;
1208
// Narrows the index and the target; run by the demons created in Post().
1209 void Propagate();
// Invalidates the cached supports when evaluator(index)'s range changes and
// `index` is one of them.
1210 void Update(int index);
// Invalidates the cached supports when the index domain no longer contains
// one of them.
1211 void UpdateExpr();
1212
1213 std::string DebugString() const override;
1214 void Accept(ModelVisitor* visitor) const override;
1215
1216 protected:
// Protected so that derived classes (e.g. IntExprArrayElementCt) can report
// it in DebugString()/Accept().
1217 IntVar* const index_;
1218
1219 private:
1220 const Solver::Int64ToIntVar evaluator_;
// Valid indices are range_start_ <= i < range_end_.
1221 const int64_t range_start_;
1222 const int64_t range_end_;
// Reversible support caches; -1 means "unknown". NOTE(review): if
// range_start_ may be negative, -1 is also a legal index and only loses the
// caching benefit for it.
1223 int min_support_;
1224 int max_support_;
1225};

1226

1227IntExprEvaluatorElementCt::IntExprEvaluatorElementCt(

1229 int64_t range_end, IntVar* const index, IntVar* const target_var)

1230 : CastConstraint(s, target_var),

1231 index_(index),

1232 evaluator_(std::move(evaluator)),

1233 range_start_(range_start),

1234 range_end_(range_end),

1235 min_support_(-1),

1236 max_support_(-1) {}

1237

1238void IntExprEvaluatorElementCt::Post() {

1240 solver(), this, &IntExprEvaluatorElementCt::Propagate, "Propagate");

1241 for (int i = range_start_; i < range_end_; ++i) {

1242 IntVar* const current_var = evaluator_(i);

1243 current_var->WhenRange(delayed_propagate_demon);

1245 solver(), this, &IntExprEvaluatorElementCt::Update, "Update", i);

1246 current_var->WhenRange(update_demon);

1247 }

1248 index_->WhenRange(delayed_propagate_demon);

1250 solver(), this, &IntExprEvaluatorElementCt::UpdateExpr, "UpdateExpr");

1251 index_->WhenRange(update_expr_demon);

1253 solver(), this, &IntExprEvaluatorElementCt::Propagate, "UpdateVar");

1254

1255 target_var_->WhenRange(update_var_demon);

1256}

1257

1258void IntExprEvaluatorElementCt::InitialPropagate() { Propagate(); }

1259

1260void IntExprEvaluatorElementCt::Propagate() {

1261 const int64_t emin = std::max(range_start_, index_->Min());

1262 const int64_t emax = std::min<int64_t>(range_end_ - 1, index_->Max());

1263 const int64_t vmin = target_var_->Min();

1264 const int64_t vmax = target_var_->Max();

1265 if (emin == emax) {

1266 index_->SetValue(emin);

1267 evaluator_(emin)->SetRange(vmin, vmax);

1268 } else {

1269 int64_t nmin = emin;

1270 for (; nmin <= emax; nmin++) {

1271

1272

1273

1274 IntVar* const nmin_var = evaluator_(nmin);

1275 if (nmin_var->Min() <= vmax && nmin_var->Max() >= vmin) break;

1276 }

1277 int64_t nmax = emax;

1278 for (; nmin <= nmax; nmax--) {

1279

1280

1281

1282 IntExpr* const nmax_var = evaluator_(nmax);

1283 if (nmax_var->Min() <= vmax && nmax_var->Max() >= vmin) break;

1284 }

1285 index_->SetRange(nmin, nmax);

1286 if (nmin == nmax) {

1287 evaluator_(nmin)->SetRange(vmin, vmax);

1288 }

1289 }

1290 if (min_support_ == -1 || max_support_ == -1) {

1291 int min_support = -1;

1292 int max_support = -1;

1293 int64_t gmin = std::numeric_limits<int64_t>::max();

1294 int64_t gmax = std::numeric_limits<int64_t>::min();

1295 for (int i = index_->Min(); i <= index_->Max(); ++i) {

1296 IntExpr* const var_i = evaluator_(i);

1297 const int64_t vmin = var_i->Min();

1298 if (vmin < gmin) {

1299 gmin = vmin;

1300 }

1301 const int64_t vmax = var_i->Max();

1302 if (vmax > gmax) {

1303 gmax = vmax;

1304 }

1305 }

1306 solver()->SaveAndSetValue(&min_support_, min_support);

1307 solver()->SaveAndSetValue(&max_support_, max_support);

1308 target_var_->SetRange(gmin, gmax);

1309 }

1310}

1311

1312void IntExprEvaluatorElementCt::Update(int index) {

1313 if (index == min_support_ || index == max_support_) {

1314 solver()->SaveAndSetValue(&min_support_, -1);

1315 solver()->SaveAndSetValue(&max_support_, -1);

1316 }

1317}

1318

1319void IntExprEvaluatorElementCt::UpdateExpr() {

1320 if (!index_->Contains(min_support_) || !index_->Contains(max_support_)) {

1321 solver()->SaveAndSetValue(&min_support_, -1);

1322 solver()->SaveAndSetValue(&max_support_, -1);

1323 }

1324}

1325

1326namespace {

1328 int64_t range_start, int64_t range_end) {

1329 std::string out;

1330 for (int64_t i = range_start; i < range_end; ++i) {

1331 if (i != range_start) {

1332 out += ", ";

1333 }

1334 out += absl::StrFormat("%d -> %s", i, evaluator(i)->DebugString());

1335 }

1336 return out;

1337}

1338

1340 int64_t range_begin, int64_t range_end) {

1341 std::string out;

1342 if (range_end - range_begin > 10) {

1343 out = absl::StrFormat(

1344 "IntToIntVar(%s, ...%s)",

1345 StringifyEvaluatorBare(evaluator, range_begin, range_begin + 5),

1346 StringifyEvaluatorBare(evaluator, range_end - 5, range_end));

1347 } else {

1348 out = absl::StrFormat(

1349 "IntToIntVar(%s)",

1350 StringifyEvaluatorBare(evaluator, range_begin, range_end));

1351 }

1352 return out;

1353}

1354}

1355

1356std::string IntExprEvaluatorElementCt::DebugString() const {

1357 return StringifyInt64ToIntVar(evaluator_, range_start_, range_end_);

1358}

1359

// Reports this constraint to `visitor` as a kElementEqual constraint with the
// evaluator, the index expression and the target variable as arguments.
// NOTE(review): the evaluator's domain [range_start_, range_end_) is not
// reported here — confirm that visitors do not need it.
1360void IntExprEvaluatorElementCt::Accept(ModelVisitor* const visitor) const {
1361 visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
1362 visitor->VisitIntegerVariableEvaluatorArgument(
1363 ModelVisitor::kEvaluatorArgument, evaluator_);
1364 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument, index_);
1365 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1366 target_var_);
1367 visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
1368}

1369

1370

1371

1372

1373

1374

// Array specialization of IntExprEvaluatorElementCt: enforces
// target_var == vars[index] for an explicit vector of variables. The base
// class's evaluator simply indexes into `vars_`.
1375class IntExprArrayElementCt : public IntExprEvaluatorElementCt {
1376 public:
1377 IntExprArrayElementCt(Solver* s, std::vector<IntVar*> vars, IntVar* index,
1378 IntVar* target_var);
1379
// Overridden so that output and visiting refer to the variable array rather
// than to an opaque evaluator.
1380 std::string DebugString() const override;
1381 void Accept(ModelVisitor* visitor) const override;
1382
1383 private:
// Owned copy of the variables; read by the evaluator lambda installed in the
// constructor.
1384 const std::vector<IntVar*> vars_;
1385};

1386

// Builds the base evaluator as a lambda reading `vars_`, over [0, vars.size()).
// Initialization order matters here: the base is constructed first, so
// `vars.size()` is read before `vars` is moved into `vars_`. The lambda
// captures `this` and must not be invoked until `vars_` is initialized; the
// base constructor only stores it.
1387IntExprArrayElementCt::IntExprArrayElementCt(Solver* const s,
1388 std::vector<IntVar*> vars,
1389 IntVar* const index,
1390 IntVar* const target_var)
1391 : IntExprEvaluatorElementCt(
1392 s, [this](int64_t idx) { return vars_[idx]; }, 0, vars.size(), index,
1393 target_var),
1394 vars_(std::move(vars)) {}

1395

1396std::string IntExprArrayElementCt::DebugString() const {

1397 int64_t size = vars_.size();

1398 if (size > 10) {

1399 return absl::StrFormat(

1400 "IntExprArrayElement(var array of size %d, %s) == %s", size,

1401 index_->DebugString(), target_var_->DebugString());

1402 } else {

1403 return absl::StrFormat("IntExprArrayElement([%s], %s) == %s",

1405 index_->DebugString(), target_var_->DebugString());

1406 }

1407}

1408

// Reports this constraint to `visitor` as a kElementEqual constraint whose
// arguments are the variable array, the index expression and the target
// variable. Unlike the base class, the array is passed instead of an evaluator.
1409void IntExprArrayElementCt::Accept(ModelVisitor* const visitor) const {
1410 visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);
1411 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,
1412 vars_);
1413 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument, index_);
1414 visitor->VisitIntegerExpressionArgument(ModelVisitor::kTargetArgument,
1415 target_var_);
1416 visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);
1417}

1418

1419

1420

1421

1422

1423class IntExprArrayElementCstCt : public Constraint {

1424 public:

1425 IntExprArrayElementCstCt(Solver* const s, const std::vector<IntVar*>& vars,

1426 IntVar* const index, int64_t target)

1427 : Constraint(s),

1428 vars_(vars),

1429 index_(index),

1430 target_(target),

1431 demons_(vars.size()) {}

1432

1433 ~IntExprArrayElementCstCt() override {}

1434

1435 void Post() override {

1436 for (int i = 0; i < vars_.size(); ++i) {

1438 solver(), this, &IntExprArrayElementCstCt::Propagate, "Propagate", i);

1439 vars_[i]->WhenDomain(demons_[i]);

1440 }

1442 solver(), this, &IntExprArrayElementCstCt::PropagateIndex,

1443 "PropagateIndex");

1445 }

1446

1447 void InitialPropagate() override {

1448 for (int i = 0; i < vars_.size(); ++i) {

1449 Propagate(i);

1450 }

1451 PropagateIndex();

1452 }

1453

1454 void Propagate(int index) {

1455 if (!vars_[index]->Contains(target_)) {

1457 demons_[index]->inhibit(solver());

1458 }

1459 }

1460

1461 void PropagateIndex() {

1462 if (index_->Bound()) {

1463 vars_[index_->Min()]->SetValue(target_);

1464 }

1465 }

1466

1467 std::string DebugString() const override {

1468 return absl::StrFormat("IntExprArrayElement([%s], %s) == %d",

1471 }

1472

1473 void Accept(ModelVisitor* const visitor) const override {

1474 visitor->BeginVisitConstraint(ModelVisitor::kElementEqual, this);

1475 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,

1476 vars_);

1477 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,

1478 index_);

1479 visitor->VisitIntegerArgument(ModelVisitor::kTargetArgument, target_);

1480 visitor->EndVisitConstraint(ModelVisitor::kElementEqual, this);

1481 }

1482

1483 private:

1484 const std::vector<IntVar*> vars_;

1485 IntVar* const index_;

1486 const int64_t target_;

1487 std::vector<Demon*> demons_;

1488};

1489

1490

1491

1492class IntExprIndexOfCt : public Constraint {

1493 public:

1494 IntExprIndexOfCt(Solver* const s, const std::vector<IntVar*>& vars,

1495 IntVar* const index, int64_t target)

1496 : Constraint(s),

1497 vars_(vars),

1498 index_(index),

1499 target_(target),

1500 demons_(vars_.size()),

1501 index_iterator_(index->MakeHoleIterator(true)) {}

1502

1503 ~IntExprIndexOfCt() override {}

1504

1505 void Post() override {

1506 for (int i = 0; i < vars_.size(); ++i) {

1508 solver(), this, &IntExprIndexOfCt::Propagate, "Propagate", i);

1509 vars_[i]->WhenDomain(demons_[i]);

1510 }

1512 solver(), this, &IntExprIndexOfCt::PropagateIndex, "PropagateIndex");

1514 }

1515

1516 void InitialPropagate() override {

1517 for (int i = 0; i < vars_.size(); ++i) {

1519 vars_[i]->RemoveValue(target_);

1520 } else if (!vars_[i]->Contains(target_)) {

1522 demons_[i]->inhibit(solver());

1523 } else if (vars_[i]->Bound()) {

1525 demons_[i]->inhibit(solver());

1526 }

1527 }

1528 }

1529

1530 void Propagate(int index) {

1531 if (!vars_[index]->Contains(target_)) {

1533 demons_[index]->inhibit(solver());

1534 } else if (vars_[index]->Bound()) {

1536 }

1537 }

1538

1539 void PropagateIndex() {

1540 const int64_t oldmax = index_->OldMax();

1541 const int64_t vmin = index_->Min();

1542 const int64_t vmax = index_->Max();

1543 for (int64_t value = index_->OldMin(); value < vmin; ++value) {

1544 vars_[value]->RemoveValue(target_);

1545 demons_[value]->inhibit(solver());

1546 }

1547 for (const int64_t value : InitAndGetValues(index_iterator_)) {

1548 vars_[value]->RemoveValue(target_);

1549 demons_[value]->inhibit(solver());

1550 }

1551 for (int64_t value = vmax + 1; value <= oldmax; ++value) {

1552 vars_[value]->RemoveValue(target_);

1553 demons_[value]->inhibit(solver());

1554 }

1555 if (index_->Bound()) {

1556 vars_[index_->Min()]->SetValue(target_);

1557 }

1558 }

1559

1560 std::string DebugString() const override {

1561 return absl::StrFormat("IntExprIndexOf([%s], %s) == %d",

1564 }

1565

1566 void Accept(ModelVisitor* const visitor) const override {

1567 visitor->BeginVisitConstraint(ModelVisitor::kIndexOf, this);

1568 visitor->VisitIntegerVariableArrayArgument(ModelVisitor::kVarsArgument,

1569 vars_);

1570 visitor->VisitIntegerExpressionArgument(ModelVisitor::kIndexArgument,

1571 index_);

1572 visitor->VisitIntegerArgument(ModelVisitor::kTargetArgument, target_);

1573 visitor->EndVisitConstraint(ModelVisitor::kIndexOf, this);

1574 }

1575

1576 private:

1577 const std::vector<IntVar*> vars_;

1578 IntVar* const index_;

1579 const int64_t target_;

1580 std::vector<Demon*> demons_;

1581 IntVarIterator* const index_iterator_;

1582};

1583

1584

1585

1586Constraint* MakeElementEqualityFunc(Solver* const solver,

1587 const std::vector<int64_t>& vals,

1588 IntVar* const index, IntVar* const target) {

1589 if (index->Bound()) {

1590 const int64_t val = index->Min();

1591 if (val < 0 || val >= vals.size()) {

1592 return solver->MakeFalseConstraint();

1593 } else {

1594 return solver->MakeEquality(target, vals[val]);

1595 }

1596 } else {

1598 return solver->MakeEquality(target, solver->MakeSum(index, vals[0]));

1599 } else {

1600 return solver->RevAlloc(

1601 new IntElementConstraint(solver, vals, index, target));

1602 }

1603 }

1604}

1605}

1606

1608 IntExpr* const then_expr,

1609 IntExpr* const else_expr,

1610 IntVar* const target_var) {

1612 new IfThenElseCt(this, condition, then_expr, else_expr, target_var));

1613}

1614

1616 IntVar* const index) {

1617 if (index->Bound()) {

1618 return vars[index->Min()];

1619 }

1620 const int size = vars.size();

1622 std::vector<int64_t> values(size);

1623 for (int i = 0; i < size; ++i) {

1624 values[i] = vars[i]->Value();

1625 }

1627 }

1628 if (index->Size() == 2 && index->Min() + 1 == index->Max() &&

1629 index->Min() >= 0 && index->Max() < vars.size()) {

1630

1631 IntVar* const scaled_index = MakeSum(index, -index->Min())->Var();

1632 IntVar* const zero = vars[index->Min()];

1633 IntVar* const one = vars[index->Max()];

1634 const std::string name = absl::StrFormat(

1635 "ElementVar([%s], %s)", JoinNamePtr(vars, ", "), index->name());

1637 std::max(zero->Max(), one->Max()), name);

1640 return target;

1641 }

1642 int64_t emin = std::numeric_limits<int64_t>::max();

1643 int64_t emax = std::numeric_limits<int64_t>::min();

1644 std::unique_ptr<IntVarIterator> iterator(index->MakeDomainIterator(false));

1645 for (const int64_t index_value : InitAndGetValues(iterator.get())) {

1646 if (index_value >= 0 && index_value < size) {

1647 emin = std::min(emin, vars[index_value]->Min());

1648 emax = std::max(emax, vars[index_value]->Max());

1649 }

1650 }

1651 const std::string vname =

1652 size > 10 ? absl::StrFormat("ElementVar(var array of size %d, %s)", size,

1654 : absl::StrFormat("ElementVar([%s], %s)",

1658 RevAlloc(new IntExprArrayElementCt(this, vars, index, element_var)));

1659 return element_var;

1660}

1661

1663 int64_t range_end, IntVar* argument) {

1664 const std::string index_name =

1666 const std::string vname = absl::StrFormat(

1667 "ElementVar(%s, %s)",

1668 StringifyInt64ToIntVar(vars, range_start, range_end), index_name);

1669 IntVar* const element_var =

1670 MakeIntVar(std::numeric_limits<int64_t>::min(),

1671 std::numeric_limits<int64_t>::max(), vname);

1672 IntExprEvaluatorElementCt* evaluation_ct = new IntExprEvaluatorElementCt(

1673 this, std::move(vars), range_start, range_end, argument, element_var);

1675 evaluation_ct->Propagate();

1676 return element_var;

1677}

1678

1680 IntVar* const index,

1681 IntVar* const target) {

1682 return MakeElementEqualityFunc(this, vals, index, target);

1683}

1684

1686 IntVar* const index,

1687 IntVar* const target) {

1688 return MakeElementEqualityFunc(this, ToInt64Vector(vals), index, target);

1689}

1690

1692 IntVar* const index,

1693 IntVar* const target) {

1695 std::vector<int64_t> values(vars.size());

1696 for (int i = 0; i < vars.size(); ++i) {

1697 values[i] = vars[i]->Value();

1698 }

1700 }

1701 if (index->Bound()) {

1702 const int64_t val = index->Min();

1703 if (val < 0 || val >= vars.size()) {

1705 } else {

1707 }

1708 } else {

1709 if (target->Bound()) {

1711 new IntExprArrayElementCstCt(this, vars, index, target->Min()));

1712 } else {

1713 return RevAlloc(new IntExprArrayElementCt(this, vars, index, target));

1714 }

1715 }

1716}

1717

1719 IntVar* const index, int64_t target) {

1721 std::vector<int> valid_indices;

1722 for (int i = 0; i < vars.size(); ++i) {

1723 if (vars[i]->Value() == target) {

1724 valid_indices.push_back(i);

1725 }

1726 }

1728 }

1729 if (index->Bound()) {

1730 const int64_t pos = index->Min();

1731 if (pos >= 0 && pos < vars.size()) {

1732 IntVar* const var = vars[pos];

1734 } else {

1736 }

1737 } else {

1738 return RevAlloc(new IntExprArrayElementCstCt(this, vars, index, target));

1739 }

1740}

1741

1743 IntVar* const index, int64_t target) {

1744 if (index->Bound()) {

1745 const int64_t pos = index->Min();

1746 if (pos >= 0 && pos < vars.size()) {

1747 IntVar* const var = vars[pos];

1749 } else {

1751 }

1752 } else {

1753 return RevAlloc(new IntExprIndexOfCt(this, vars, index, target));

1754 }

1755}

1756

1758 int64_t value) {

1759 IntExpr* const cache = model_cache_->FindVarArrayConstantExpression(

1761 if (cache != nullptr) {

1762 return cache->Var();

1763 } else {

1764 const std::string name =

1765 absl::StrFormat("Index(%s, %d)", JoinNamePtr(vars, ", "), value);

1768 model_cache_->InsertVarArrayConstantExpression(

1770 return index;

1771 }

1772}

1773}

CastConstraint(Solver *const solver, IntVar *const target_var)

IntVar *const target_var_

Definition element.cc:1124

void Accept(ModelVisitor *const visitor) const override

Accepts the given visitor.

Definition element.cc:1184

~IfThenElseCt() override

Definition element.cc:1133

IfThenElseCt(Solver *const solver, IntVar *const condition, IntExpr *const one, IntExpr *const zero, IntVar *const target)

Definition element.cc:1126

std::string DebugString() const override

Definition element.cc:1178

void Post() override

Definition element.cc:1135

void InitialPropagate() override

Definition element.cc:1143

virtual void SetValue(int64_t v)

This method sets the value of the expression.

virtual bool Bound() const

Returns true if the min and the max of the expression are equal.

virtual void SetMax(int64_t m)=0

virtual void SetRange(int64_t l, int64_t u)

This method sets both the min and the max of the expression.

virtual int64_t Min() const =0

virtual void SetMin(int64_t m)=0

virtual void WhenRange(Demon *d)=0

Attach a demon that will watch the min or the max of the expression.

virtual IntVar * Var()=0

Creates a variable from the expression.

virtual int64_t Max() const =0

virtual void WhenBound(Demon *d)=0

virtual void WhenDomain(Demon *d)=0

virtual IntVarIterator * MakeDomainIterator(bool reversible) const =0

virtual int64_t OldMax() const =0

Returns the previous max.

virtual bool Contains(int64_t v) const =0

virtual void RemoveValue(int64_t v)=0

This method removes the value 'v' from the domain of the variable.

virtual uint64_t Size() const =0

This method returns the number of values in the domain of the variable.

virtual int64_t OldMin() const =0

Returns the previous min.

@ VAR_ARRAY_CONSTANT_INDEX

virtual std::string name() const

Object naming.

std::string DebugString() const override

IntExpr * MakeElement(const std::vector< int64_t > &values, IntVar *index)

values[index]

Definition element.cc:658

Constraint * MakeMemberCt(IntExpr *expr, const std::vector< int64_t > &values)

Constraint * MakeIfThenElseCt(IntVar *condition, IntExpr *then_expr, IntExpr *else_expr, IntVar *target_var)

Special cases with arrays of size two.

Definition element.cc:1607

IntExpr * MakeSum(IntExpr *left, IntExpr *right)

left + right.

std::function< int64_t(int64_t)> IndexEvaluator1

Callback typedefs.

IntExpr * MakeMonotonicElement(IndexEvaluator1 values, bool increasing, IntVar *index)

Definition element.cc:869

IntExpr * MakeOpposite(IntExpr *expr)

-expr

IntExpr * RegisterIntExpr(IntExpr *expr)

Registers a new IntExpr and wraps it inside a TraceIntExpr if necessary.

Constraint * MakeElementEquality(const std::vector< int64_t > &vals, IntVar *index, IntVar *target)

Definition element.cc:1679

std::function< int64_t(int64_t, int64_t)> IndexEvaluator2

Constraint * MakeIndexOfConstraint(const std::vector< IntVar * > &vars, IntVar *index, int64_t target)

Definition element.cc:1742

IntExpr * MakeIndexExpression(const std::vector< IntVar * > &vars, int64_t value)

Definition element.cc:1757

std::function< IntVar *(int64_t)> Int64ToIntVar

Constraint * MakeEquality(IntExpr *left, IntExpr *right)

left == right

IntVar * MakeIntVar(int64_t min, int64_t max, const std::string &name)

MakeIntVar will create the best range based int var for the bounds given.

Constraint * MakeFalseConstraint()

This constraint always fails.

IntVar * MakeIntConst(int64_t val, const std::string &name)

IntConst will create a constant expression.

void AddConstraint(Constraint *c)

Adds the constraint 'c' to the model.

#define UPDATE_ELEMENT_INDEX_BOUNDS(test)

Definition element.cc:986

ABSL_FLAG(bool, cp_disable_element_cache, true, "If true, caching for IntElement is disabled.")

#define UPDATE_RMQ_BASE_ELEMENT_INDEX_BOUNDS(test)

Definition element.cc:424

std::pair< double, double > Range

bool IsArrayConstant(const std::vector< T > &values, const T &value)

std::string JoinDebugStringPtr(const std::vector< T > &v, absl::string_view separator)

bool IsIncreasing(const std::vector< T > &values)

bool IsArrayBoolean(const std::vector< T > &values)

Demon * MakeDelayedConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)

void LinkVarExpr(Solver *s, IntExpr *expr, IntVar *var)

Demon * MakeConstraintDemon0(Solver *const s, T *const ct, void(T::*method)(), const std::string &name)

std::vector< int64_t > ToInt64Vector(const std::vector< int > &input)

bool IsIncreasingContiguous(const std::vector< T > &values)

Demon * MakeConstraintDemon1(Solver *const s, T *const ct, void(T::*method)(P), const std::string &name, P param1)

bool AreAllBound(const std::vector< IntVar * > &vars)

std::string JoinNamePtr(const std::vector< T > &v, absl::string_view separator)