Skip to content

Commit 2b22fef

Browse files
authored
Remove the simplifier's ability to preserve dead letstmts (#8899)
Remove the simplifiers ability to preserve dead letstmts We shouldn't need it anymore thanks to .loop_max symbols now being regular lets instead of magic things in need of preservation.
1 parent bed0d3a commit 2b22fef

File tree

11 files changed

+43
-51
lines changed

11 files changed

+43
-51
lines changed

src/BoundConstantExtentLoops.cpp

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -67,7 +67,6 @@ class BoundLoops : public IRMutator {
6767
extent = remove_likelies(extent);
6868
extent = substitute_in_all_lets(extent);
6969
extent = simplify(extent,
70-
true,
7170
Scope<Interval>::empty_scope(),
7271
Scope<ModulusRemainder>::empty_scope(),
7372
facts);

src/LowerWarpShuffles.cpp

Lines changed: 11 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -76,8 +76,8 @@ Expr reduce_expr_helper(Expr e, const Expr &modulus) {
7676
}
7777

7878
Expr reduce_expr(Expr e, const Expr &modulus, const Scope<Interval> &bounds) {
79-
e = reduce_expr_helper(simplify(e, true, bounds), modulus);
80-
if (is_const_one(simplify(e >= 0 && e < modulus, true, bounds))) {
79+
e = reduce_expr_helper(simplify(e, bounds), modulus);
80+
if (is_const_one(simplify(e >= 0 && e < modulus, bounds))) {
8181
return e;
8282
} else {
8383
return e % modulus;
@@ -285,7 +285,7 @@ class DetermineAllocStride : public IRVisitor {
285285

286286
// A version of can_prove which exploits the constant bounds we've been tracking
287287
bool can_prove(const Expr &e) {
288-
return is_const_one(simplify(e, true, bounds));
288+
return is_const_one(simplify(e, bounds));
289289
}
290290

291291
Expr get_stride() {
@@ -411,7 +411,7 @@ class LowerWarpShuffles : public IRMutator {
411411
// the number of lanes (rounded up).
412412
Expr extent = op->extent();
413413
Expr new_size = (alloc->extents[0] + extent - 1) / extent;
414-
new_size = simplify(new_size, true, bounds);
414+
new_size = simplify(new_size, bounds);
415415
new_size = find_constant_bound(new_size, Direction::Upper, bounds);
416416
auto sz = as_const_int(new_size);
417417
user_assert(sz) << "Warp-level allocation with non-constant size: "
@@ -511,7 +511,7 @@ class LowerWarpShuffles : public IRMutator {
511511
// of the index and shifting the high bits down to cover
512512
// them. Reassembling the result into a flat address gives
513513
// the expression below.
514-
Expr in_warp_idx = simplify((idx / (warp_size * stride)) * stride + reduce_expr(idx, stride, bounds), true, bounds);
514+
Expr in_warp_idx = simplify((idx / (warp_size * stride)) * stride + reduce_expr(idx, stride, bounds), bounds);
515515
return Store::make(op->name, value, in_warp_idx, op->param, op->predicate, ModulusRemainder());
516516
} else {
517517
return IRMutator::visit(op);
@@ -536,7 +536,7 @@ class LowerWarpShuffles : public IRMutator {
536536
// Load the right lanes from stripe number i
537537
equiv = select(idx >= i, make_warp_load(type, name, make_const(idx.type(), i), lane), equiv);
538538
}
539-
return simplify(equiv, true, bounds);
539+
return simplify(equiv, bounds);
540540
}
541541

542542
// Load the value to be shuffled
@@ -606,7 +606,7 @@ class LowerWarpShuffles : public IRMutator {
606606
} else if (expr_match((this_lane + wild) % wild, lane, result) &&
607607
(bits = is_const_power_of_two_integer(result[1])) &&
608608
*bits <= 5) {
609-
result[0] = simplify(result[0] % result[1], true, bounds);
609+
result[0] = simplify(result[0] % result[1], bounds);
610610
// Rotate. Mux a shuffle up and a shuffle down. Uses fewer
611611
// intermediate registers than using a general gather for
612612
// this.
@@ -617,7 +617,7 @@ class LowerWarpShuffles : public IRMutator {
617617
shfl_args({membermask, base_val, (1 << *bits) - result[0], 0}), Call::PureExtern);
618618
Expr cond = (this_lane >= (1 << *bits) - result[0]);
619619
Expr equiv = select(cond, up, down);
620-
shuffled = simplify(equiv, true, bounds);
620+
shuffled = simplify(equiv, bounds);
621621
} else {
622622
// The format of the mask is a pain. The high bits tell
623623
// you how large the a warp is for this instruction
@@ -647,10 +647,10 @@ class LowerWarpShuffles : public IRMutator {
647647
Expr stride = alloc->stride;
648648

649649
// Break the index into lane and stripe components
650-
Expr lane = simplify(reduce_expr(idx / stride, warp_size, bounds), true, bounds);
651-
idx = simplify((idx / (warp_size * stride)) * stride + reduce_expr(idx, stride, bounds), true, bounds);
650+
Expr lane = simplify(reduce_expr(idx / stride, warp_size, bounds), bounds);
651+
idx = simplify((idx / (warp_size * stride)) * stride + reduce_expr(idx, stride, bounds), bounds);
652652
// We don't want the idx to depend on the lane var, so try to eliminate it
653-
idx = simplify(solve_expression(idx, this_lane_name).result, true, bounds);
653+
idx = simplify(solve_expression(idx, this_lane_name).result, bounds);
654654
return make_warp_load(op->type, op->name, idx, lane);
655655
} else {
656656
return IRMutator::visit(op);

src/ParallelRVar.cpp

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -157,14 +157,9 @@ bool can_parallelize_rvar(const string &v,
157157
// Pull out common non-boolean terms
158158
hazard = common_subexpression_elimination(hazard);
159159
hazard = SubstituteInBooleanLets().mutate(hazard);
160-
hazard = simplify(hazard, false, bounds);
160+
hazard = simplify(hazard, bounds);
161161
debug(3) << "Simplified to: " << hazard << "\n";
162162

163-
// strip lets
164-
while (const Let *l = hazard.as<Let>()) {
165-
hazard = l->body;
166-
}
167-
168163
return is_const_zero(hazard);
169164
}
170165

src/PrintLoopNest.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,7 +216,7 @@ string print_loop_nest(const vector<Function> &output_funcs) {
216216
s = allocation_bounds_inference(s, env, func_bounds);
217217
s = remove_undef(s);
218218
s = uniquify_variable_names(s);
219-
s = simplify(s, false);
219+
s = simplify(s);
220220

221221
// Now convert that to pseudocode
222222
std::ostringstream sstr;

src/Simplify.cpp

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,7 @@ using std::pair;
1515
using std::string;
1616
using std::vector;
1717

18-
Simplify::Simplify(bool r, const Scope<Interval> *bi, const Scope<ModulusRemainder> *ai)
19-
: remove_dead_code(r) {
18+
Simplify::Simplify(const Scope<Interval> *bi, const Scope<ModulusRemainder> *ai) {
2019

2120
// Only respect the constant bounds from the containing scope.
2221
for (auto iter = bi->cbegin(); iter != bi->cend(); ++iter) {
@@ -361,11 +360,11 @@ Simplify::ScopedFact::~ScopedFact() {
361360
}
362361
}
363362

364-
Expr simplify(const Expr &e, bool remove_dead_let_stmts,
363+
Expr simplify(const Expr &e,
365364
const Scope<Interval> &bounds,
366365
const Scope<ModulusRemainder> &alignment,
367366
const std::vector<Expr> &assumptions) {
368-
Simplify m(remove_dead_let_stmts, &bounds, &alignment);
367+
Simplify m(&bounds, &alignment);
369368
std::vector<Simplify::ScopedFact> facts;
370369
facts.reserve(assumptions.size());
371370
for (const Expr &a : assumptions) {
@@ -378,11 +377,11 @@ Expr simplify(const Expr &e, bool remove_dead_let_stmts,
378377
return result;
379378
}
380379

381-
Stmt simplify(const Stmt &s, bool remove_dead_let_stmts,
380+
Stmt simplify(const Stmt &s,
382381
const Scope<Interval> &bounds,
383382
const Scope<ModulusRemainder> &alignment,
384383
const std::vector<Expr> &assumptions) {
385-
Simplify m(remove_dead_let_stmts, &bounds, &alignment);
384+
Simplify m(&bounds, &alignment);
386385
std::vector<Simplify::ScopedFact> facts;
387386
facts.reserve(assumptions.size());
388387
for (const Expr &a : assumptions) {
@@ -416,7 +415,7 @@ bool can_prove(Expr e, const Scope<Interval> &bounds) {
416415

417416
Expr orig = e;
418417

419-
e = simplify(e, true, bounds);
418+
e = simplify(e, bounds);
420419

421420
// Take a closer look at all failed proof attempts to hunt for
422421
// simplifier weaknesses

src/Simplify.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,10 @@ namespace Internal {
2222
*/
2323
// @{
2424
Stmt simplify(const Stmt &,
25-
bool remove_dead_code = true,
2625
const Scope<Interval> &bounds = Scope<Interval>::empty_scope(),
2726
const Scope<ModulusRemainder> &alignment = Scope<ModulusRemainder>::empty_scope(),
2827
const std::vector<Expr> &assumptions = std::vector<Expr>());
2928
Expr simplify(const Expr &,
30-
bool remove_dead_code = true,
3129
const Scope<Interval> &bounds = Scope<Interval>::empty_scope(),
3230
const Scope<ModulusRemainder> &alignment = Scope<ModulusRemainder>::empty_scope(),
3331
const std::vector<Expr> &assumptions = std::vector<Expr>());

src/Simplify_Internal.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
3434
using Super = VariadicVisitor<Simplify, Expr, Stmt>;
3535

3636
public:
37-
Simplify(bool r, const Scope<Interval> *bi, const Scope<ModulusRemainder> *ai);
37+
Simplify(const Scope<Interval> *bi, const Scope<ModulusRemainder> *ai);
3838

3939
struct ExprInfo {
4040
// We track constant integer bounds when they exist
@@ -353,8 +353,6 @@ class Simplify : public VariadicVisitor<Simplify, Expr, Stmt> {
353353
}
354354
#endif
355355

356-
bool remove_dead_code;
357-
358356
// Returns true iff t is an integral type where overflow is undefined
359357
HALIDE_ALWAYS_INLINE
360358
bool no_overflow_int(Type t) {

src/Simplify_Let.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -293,8 +293,7 @@ Body Simplify::simplify_let(const LetOrLetStmt *op, ExprInfo *info) {
293293
find_var_uses(frame.new_value, unused_vars);
294294
}
295295

296-
if ((!remove_dead_code && std::is_same_v<LetOrLetStmt, LetStmt>) ||
297-
(frame.info.old_uses > 0 && !unused_vars.count(frame.op->name))) {
296+
if (frame.info.old_uses > 0 && !unused_vars.count(frame.op->name)) {
298297
// The old name is still in use. We'd better keep it as well.
299298
result = LetOrLetStmt::make(frame.op->name, frame.value, result);
300299
find_var_uses(frame.value, unused_vars);

src/Simplify_Stmts.cpp

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -79,9 +79,13 @@ Stmt Simplify::visit(const IfThenElse *op) {
7979
else_case = Stmt();
8080
}
8181

82-
// Pull out common nodes, but only when the "late in lowering" flag is set. This
83-
// avoids simplifying specializations before they have a chance to specialize.
84-
if (remove_dead_code && equal(then_case, else_case)) {
82+
// This code used to use the remove_dead_lets flag to not merge equal
83+
// clauses on the grounds that they might be specializations that will
84+
// simplify later. However, specializations should be simplified
85+
// aggressively quite early in lowering. If in future there is a bug with
86+
// specializations seemingly disappearing halfway through lowering, try
87+
// disabling this.
88+
if (equal(then_case, else_case)) {
8589
return then_case;
8690
}
8791
const Acquire *then_acquire = then_case.as<Acquire>();

src/StorageFolding.cpp

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -588,14 +588,14 @@ class AttemptStorageFoldingOfFunction : public IRMutator {
588588
Expr loop_var = Variable::make(Int(32), op->name);
589589
Expr steady_state = (op->min < loop_var);
590590

591-
Expr min_steady = simplify(substitute(steady_state, const_true(), min), true, steady_bounds);
592-
Expr max_steady = simplify(substitute(steady_state, const_true(), max), true, steady_bounds);
593-
Expr min_initial = simplify(substitute(steady_state, const_false(), min), true, bounds);
594-
Expr max_initial = simplify(substitute(steady_state, const_false(), max), true, bounds);
595-
Expr extent_initial = simplify(substitute(loop_var, op->min, max_initial - min_initial + 1), true, bounds);
596-
Expr extent_steady = simplify(max_steady - min_steady + 1, true, steady_bounds);
591+
Expr min_steady = simplify(substitute(steady_state, const_true(), min), steady_bounds);
592+
Expr max_steady = simplify(substitute(steady_state, const_true(), max), steady_bounds);
593+
Expr min_initial = simplify(substitute(steady_state, const_false(), min), bounds);
594+
Expr max_initial = simplify(substitute(steady_state, const_false(), max), bounds);
595+
Expr extent_initial = simplify(substitute(loop_var, op->min, max_initial - min_initial + 1), bounds);
596+
Expr extent_steady = simplify(max_steady - min_steady + 1, steady_bounds);
597597
Expr extent = Max::make(extent_initial, extent_steady);
598-
extent = simplify(common_subexpression_elimination(extent), true, bounds);
598+
extent = simplify(common_subexpression_elimination(extent), bounds);
599599

600600
// Find the StorageDim corresponding to dim.
601601
const std::vector<StorageDim> &storage_dims = func.schedule().storage_dims();

0 commit comments

Comments
 (0)