Skip to content

Commit b492829

Browse files
committed
Refactor VisitExpr_
1 parent b6e091b commit b492829

File tree

1 file changed

+136
-35
lines changed

1 file changed

+136
-35
lines changed

src/relax/transform/canonicalize_shape_expr.cc

Lines changed: 136 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -62,53 +62,126 @@ bool IsCanonicalPrimExpr(const PrimExpr& expr) {
6262
* \brief Mutator to canonicalize ShapeExpr in struct info
6363
*
6464
* This pass handles ShapeExpr canonicalization by:
65-
* 1. Detecting compound PrimExpr in ShapeExpr dimensions
66-
* 2. Lifting them into separate ShapeExpr bindings
65+
* 1. Detecting compound PrimExpr in variable struct_info
66+
* 2. Emitting ShapeExpr bindings to compute expressions
6767
* 3. Using MatchCast to extract values into fresh symbolic tir::Var
68-
* 4. Replacing compound expressions with these canonical vars
68+
* 4. Replacing compound expressions with these canonical vars in struct_info
6969
*/
7070
class ShapeExprCanonicalizer : public ExprMutator {
7171
public:
7272
using ExprMutator::VisitExpr_;
7373

7474
Expr VisitExpr_(const FunctionNode* func) override {
7575
// Reset state for each function
76-
auto cached_compound_to_var = compound_expr_to_var_;
77-
auto cached_counter = symbolic_var_counter_;
76+
symbolic_var_counter_ = 0;
77+
compound_expr_to_var_.clear();
78+
emitted_bindings_.clear();
7879

79-
auto result = ExprMutator::VisitExpr_(func);
80+
// Process the function body
81+
Expr new_body = VisitExpr(func->body);
8082

81-
compound_expr_to_var_ = cached_compound_to_var;
82-
symbolic_var_counter_ = cached_counter;
83+
if (new_body.same_as(func->body)) {
84+
return ffi::GetRef<Function>(func);
85+
}
8386

84-
return result;
87+
return Function(func->params, new_body, func->ret_struct_info, func->is_pure, func->attrs,
88+
func->span);
89+
}
90+
91+
void VisitBinding(const Binding& binding) override {
92+
// Emit canonicalization bindings before processing the binding itself
93+
auto sinfo = GetStructInfo(binding->var);
94+
if (NeedsCanonicalization(sinfo)) {
95+
CanonicalizeStructInfoAndEmit(sinfo);
96+
}
97+
98+
// Now process the binding normally - VisitVarDef will handle struct_info canonicalization
99+
ExprMutator::VisitBinding(binding);
85100
}
86101

87-
/*!
88-
* \brief Override VisitVarDef to canonicalize struct_info
89-
*
90-
* This is where we intercept variable definitions and canonicalize any
91-
* compound PrimExpr in their TensorStructInfo shapes.
92-
*/
93102
Var VisitVarDef(const Var& var) override {
94103
auto sinfo = GetStructInfo(var);
95-
96-
// Check if we need to canonicalize the struct_info
97104
auto canonical_sinfo = CanonicalizeStructInfo(sinfo);
98105

99106
if (canonical_sinfo.same_as(sinfo)) {
100-
// No changes needed
101107
return ExprMutator::VisitVarDef(var);
102108
}
103109

104-
// Create a new var with canonicalized strcut_info
110+
// Create a new var with canonicalized struct_info
111+
Var canonical_var;
105112
if (var->IsInstance<DataflowVarNode>()) {
106-
return DataflowVar(var->vid, canonical_sinfo, var->span);
113+
canonical_var = DataflowVar(var->vid, canonical_sinfo, var->span);
114+
} else {
115+
canonical_var = Var(var->vid, canonical_sinfo, var->span);
107116
}
108-
return Var(var->vid, canonical_sinfo, var->span);
117+
118+
return ExprMutator::VisitVarDef(canonical_var);
109119
}
110120

111121
private:
122+
/*!
123+
* \brief Check if struct_info needs canonicalization
124+
*/
125+
bool NeedsCanonicalization(const StructInfo& sinfo) {
126+
if (auto tensor_sinfo = sinfo.as<TensorStructInfoNode>()) {
127+
if (!tensor_sinfo->shape.defined()) {
128+
return false;
129+
}
130+
auto shape_expr = tensor_sinfo->shape.as<ShapeExprNode>();
131+
if (!shape_expr) {
132+
return false;
133+
}
134+
for (const PrimExpr& dim : shape_expr->values) {
135+
if (!IsCanonicalPrimExpr(dim)) {
136+
return true;
137+
}
138+
}
139+
return false;
140+
} else if (auto tuple_sinfo = sinfo.as<TupleStructInfoNode>()) {
141+
for (const StructInfo& field : tuple_sinfo->fields) {
142+
if (NeedsCanonicalization(field)) {
143+
return true;
144+
}
145+
}
146+
return false;
147+
}
148+
return false;
149+
}
150+
151+
/*!
152+
* Canonicalize struct info and emit necessary bindings
153+
*/
154+
void CanonicalizeStructInfoAndEmit(const StructInfo& sinfo) {
155+
if (auto tensor_sinfo = sinfo.as<TensorStructInfoNode>()) {
156+
CanonicalizeTensorStructInfoAndEmit(ffi::GetRef<TensorStructInfo>(tensor_sinfo));
157+
} else if (auto tuple_sinfo = sinfo.as<TupleStructInfoNode>()) {
158+
for (const StructInfo& field : tuple_sinfo->fields) {
159+
CanonicalizeStructInfoAndEmit(field);
160+
}
161+
}
162+
}
163+
164+
/*!
165+
* Canonicalize tensor struct info and emit necessary bindings
166+
*/
167+
void CanonicalizeTensorStructInfoAndEmit(const TensorStructInfo& sinfo) {
168+
if (!sinfo->shape.defined()) {
169+
return;
170+
}
171+
172+
auto shape_expr = sinfo->shape.as<ShapeExprNode>();
173+
if (!shape_expr) {
174+
return;
175+
}
176+
177+
// Emit bindings for each compound dimension
178+
for (const PrimExpr& dim : shape_expr->values) {
179+
if (!IsCanonicalPrimExpr(dim)) {
180+
CanonicalizeDimension(dim);
181+
}
182+
}
183+
}
184+
112185
/*!
113186
* \brief Canonicalize struct info by lifting compound shape expressions
114187
*/
@@ -140,7 +213,7 @@ class ShapeExprCanonicalizer : public ExprMutator {
140213
bool changed = false;
141214

142215
for (const PrimExpr& dim : shape_expr->values) {
143-
PrimExpr canonical_dim = CanonicalizeDimension(dim);
216+
PrimExpr canonical_dim = GetCanonicalDimension(dim);
144217
canonical_dims.push_back(canonical_dim);
145218
changed |= !canonical_dim.same_as(dim);
146219
}
@@ -174,15 +247,9 @@ class ShapeExprCanonicalizer : public ExprMutator {
174247
}
175248

176249
/*!
177-
* \brief Canonicalize a single shape dimension
178-
*
179-
* If the dimension is a compound PrimExpr:
180-
* 1. Emit a ShapeExpr binding containing the compound expression
181-
* 2. Create a fresh symbolic tir::Var
182-
* 3. Emit a MatchCast to bind the computed value to the symbolic var
183-
* 4. Return the symbolic var
250+
* \brief Get the canonical form of a dimension (returns the symbolic var if already emitted)
184251
*/
185-
PrimExpr CanonicalizeDimension(const PrimExpr& dim) {
252+
PrimExpr GetCanonicalDimension(const PrimExpr& dim) {
186253
// If already canonical, return as is
187254
if (IsCanonicalPrimExpr(dim)) {
188255
return dim;
@@ -193,9 +260,42 @@ class ShapeExprCanonicalizer : public ExprMutator {
193260
return it->second;
194261
}
195262

196-
// Create a fresh symbolic variable
263+
// Create a fresh symbolic variable, but don't emit yet
197264
tir::Var symbolic_var = CreateFreshSymbolicVar(dim->dtype);
198265

266+
compound_expr_to_var_[dim] = symbolic_var;
267+
268+
return symbolic_var;
269+
}
270+
271+
/*!
272+
* \brief Emit bindings for a single compound dimension
273+
*
274+
* If the dimension is a compound PrimExpr:
275+
* 1. Emit a ShapeExpr binding containing the compound expression
276+
* 2. Create a fresh symbolic tir::Var
277+
* 3. Emit a MatchCast to bind the computed value to the symbolic var
278+
*/
279+
void CanonicalizeDimension(const PrimExpr& dim) {
280+
// If already canonical, nothing to emit
281+
if (IsCanonicalPrimExpr(dim)) {
282+
return;
283+
}
284+
285+
// Check If we've already emitted bindings for this expression
286+
auto it = compound_expr_to_var_.find(dim);
287+
if (it == compound_expr_to_var_.end()) {
288+
// This should not happen if GetCanonicalDimension was called first
289+
return;
290+
}
291+
292+
// Check if we've already emitted the bindings
293+
if (emitted_bindings_.count(dim)) {
294+
return;
295+
}
296+
297+
tir::Var symbolic_var = it->second;
298+
199299
// Emit shape binding: shape_var = R.shape([compound_expr])
200300
ShapeExpr shape_value({dim});
201301
Var shape_var = builder_->Emit(shape_value);
@@ -206,10 +306,8 @@ class ShapeExprCanonicalizer : public ExprMutator {
206306
Var match_cast_var("_", match_sinfo);
207307
builder_->EmitNormalized(MatchCast(match_cast_var, shape_var, match_sinfo));
208308

209-
// Cache the mapping to avoid duplicate bindings
210-
compound_expr_to_var_[dim] = symbolic_var;
211-
212-
return symbolic_var;
309+
// Mark as emitted
310+
emitted_bindings_.insert(dim);
213311
}
214312

215313
/*!
@@ -223,6 +321,9 @@ class ShapeExprCanonicalizer : public ExprMutator {
223321
// Cache to avoid creating duplicate bindings for the same compound expression
224322
std::unordered_map<PrimExpr, tir::Var, StructuralHash, StructuralEqual> compound_expr_to_var_;
225323

324+
// Track which compound expressions have had their bindings emitted
325+
std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> emitted_bindings_;
326+
226327
// Counter for generating unique symbolic variable names
227328
int symbolic_var_counter_ = 0;
228329
};

0 commit comments

Comments
 (0)