@@ -62,53 +62,126 @@ bool IsCanonicalPrimExpr(const PrimExpr& expr) {
6262 * \brief Mutator to canonicalize ShapeExpr in struct info
6363 *
6464 * This pass handles ShapeExpr canonicalization by:
65- * 1. Detecting compound PrimExpr in ShapeExpr dimensions
66- * 2. Lifting them into separate ShapeExpr bindings
65+ * 1. Detecting compound PrimExpr in variable struct_info
66+ * 2. Emitting ShapeExpr bindings to compute expressions
6767 * 3. Using MatchCast to extract values into fresh symbolic tir::Var
68- * 4. Replacing compound expressions with these canonical vars
68+ * 4. Replacing compound expressions with these canonical vars in struct_info
6969 */
7070class ShapeExprCanonicalizer : public ExprMutator {
7171 public:
7272 using ExprMutator::VisitExpr_;
7373
7474 Expr VisitExpr_ (const FunctionNode* func) override {
7575 // Reset state for each function
76- auto cached_compound_to_var = compound_expr_to_var_;
77- auto cached_counter = symbolic_var_counter_;
76+ symbolic_var_counter_ = 0 ;
77+ compound_expr_to_var_.clear ();
78+ emitted_bindings_.clear ();
7879
79- auto result = ExprMutator::VisitExpr_ (func);
80+ // Process the function body
81+ Expr new_body = VisitExpr (func->body );
8082
81- compound_expr_to_var_ = cached_compound_to_var;
82- symbolic_var_counter_ = cached_counter;
83+ if (new_body.same_as (func->body )) {
84+ return ffi::GetRef<Function>(func);
85+ }
8386
84- return result;
87+ return Function (func->params , new_body, func->ret_struct_info , func->is_pure , func->attrs ,
88+ func->span );
89+ }
90+
91+ void VisitBinding (const Binding& binding) override {
92+ // Emit canonicalization bindings before processing the binding itself
93+ auto sinfo = GetStructInfo (binding->var );
94+ if (NeedsCanonicalization (sinfo)) {
95+ CanonicalizeStructInfoAndEmit (sinfo);
96+ }
97+
98+ // Now process the binding normally - VisitVarDef will handle struct_info canonicalization
99+ ExprMutator::VisitBinding (binding);
85100 }
86101
87- /* !
88- * \brief Override VisitVarDef to canonicalize struct_info
89- *
90- * This is where we intercept variable definitions and canonicalize any
91- * compound PrimExpr in their TensorStructInfo shapes.
92- */
93102 Var VisitVarDef (const Var& var) override {
94103 auto sinfo = GetStructInfo (var);
95-
96- // Check if we need to canonicalize the struct_info
97104 auto canonical_sinfo = CanonicalizeStructInfo (sinfo);
98105
99106 if (canonical_sinfo.same_as (sinfo)) {
100- // No changes needed
101107 return ExprMutator::VisitVarDef (var);
102108 }
103109
104- // Create a new var with canonicalized strcut_info
110+ // Create a new var with canonicalized struct_info
111+ Var canonical_var;
105112 if (var->IsInstance <DataflowVarNode>()) {
106- return DataflowVar (var->vid , canonical_sinfo, var->span );
113+ canonical_var = DataflowVar (var->vid , canonical_sinfo, var->span );
114+ } else {
115+ canonical_var = Var (var->vid , canonical_sinfo, var->span );
107116 }
108- return Var (var->vid , canonical_sinfo, var->span );
117+
118+ return ExprMutator::VisitVarDef (canonical_var);
109119 }
110120
111121 private:
122+ /* !
123+ * \brief Check if struct_info needs canonicalization
124+ */
125+ bool NeedsCanonicalization (const StructInfo& sinfo) {
126+ if (auto tensor_sinfo = sinfo.as <TensorStructInfoNode>()) {
127+ if (!tensor_sinfo->shape .defined ()) {
128+ return false ;
129+ }
130+ auto shape_expr = tensor_sinfo->shape .as <ShapeExprNode>();
131+ if (!shape_expr) {
132+ return false ;
133+ }
134+ for (const PrimExpr& dim : shape_expr->values ) {
135+ if (!IsCanonicalPrimExpr (dim)) {
136+ return true ;
137+ }
138+ }
139+ return false ;
140+ } else if (auto tuple_sinfo = sinfo.as <TupleStructInfoNode>()) {
141+ for (const StructInfo& field : tuple_sinfo->fields ) {
142+ if (NeedsCanonicalization (field)) {
143+ return true ;
144+ }
145+ }
146+ return false ;
147+ }
148+ return false ;
149+ }
150+
151+ /* !
152+ * Canonicalize struct info and emit necessary bindings
153+ */
154+ void CanonicalizeStructInfoAndEmit (const StructInfo& sinfo) {
155+ if (auto tensor_sinfo = sinfo.as <TensorStructInfoNode>()) {
156+ CanonicalizeTensorStructInfoAndEmit (ffi::GetRef<TensorStructInfo>(tensor_sinfo));
157+ } else if (auto tuple_sinfo = sinfo.as <TupleStructInfoNode>()) {
158+ for (const StructInfo& field : tuple_sinfo->fields ) {
159+ CanonicalizeStructInfoAndEmit (field);
160+ }
161+ }
162+ }
163+
164+ /* !
165+ * Canonicalize tensor struct info and emit necessary bindings
166+ */
167+ void CanonicalizeTensorStructInfoAndEmit (const TensorStructInfo& sinfo) {
168+ if (!sinfo->shape .defined ()) {
169+ return ;
170+ }
171+
172+ auto shape_expr = sinfo->shape .as <ShapeExprNode>();
173+ if (!shape_expr) {
174+ return ;
175+ }
176+
177+ // Emit bindings for each compound dimension
178+ for (const PrimExpr& dim : shape_expr->values ) {
179+ if (!IsCanonicalPrimExpr (dim)) {
180+ CanonicalizeDimension (dim);
181+ }
182+ }
183+ }
184+
112185 /* !
113186 * \brief Canonicalize struct info by lifting compound shape expressions
114187 */
@@ -140,7 +213,7 @@ class ShapeExprCanonicalizer : public ExprMutator {
140213 bool changed = false ;
141214
142215 for (const PrimExpr& dim : shape_expr->values ) {
143- PrimExpr canonical_dim = CanonicalizeDimension (dim);
216+ PrimExpr canonical_dim = GetCanonicalDimension (dim);
144217 canonical_dims.push_back (canonical_dim);
145218 changed |= !canonical_dim.same_as (dim);
146219 }
@@ -174,15 +247,9 @@ class ShapeExprCanonicalizer : public ExprMutator {
174247 }
175248
176249 /* !
177- * \brief Canonicalize a single shape dimension
178- *
179- * If the dimension is a compound PrimExpr:
180- * 1. Emit a ShapeExpr binding containing the compound expression
181- * 2. Create a fresh symbolic tir::Var
182- * 3. Emit a MatchCast to bind the computed value to the symbolic var
183- * 4. Return the symbolic var
250+ * \brief Get the canonical form of a dimension (returns the symbolic var if already emitted)
184251 */
185- PrimExpr CanonicalizeDimension (const PrimExpr& dim) {
252+ PrimExpr GetCanonicalDimension (const PrimExpr& dim) {
186253 // If already canonical, return as is
187254 if (IsCanonicalPrimExpr (dim)) {
188255 return dim;
@@ -193,9 +260,42 @@ class ShapeExprCanonicalizer : public ExprMutator {
193260 return it->second ;
194261 }
195262
196- // Create a fresh symbolic variable
263+ // Create a fresh symbolic variable, but don't emit yet
197264 tir::Var symbolic_var = CreateFreshSymbolicVar (dim->dtype );
198265
266+ compound_expr_to_var_[dim] = symbolic_var;
267+
268+ return symbolic_var;
269+ }
270+
271+ /* !
272+ * \brief Emit bindings for a single compound dimension
273+ *
274+ * If the dimension is a compound PrimExpr:
275+ * 1. Emit a ShapeExpr binding containing the compound expression
276+ * 2. Create a fresh symbolic tir::Var
277+ * 3. Emit a MatchCast to bind the computed value to the symbolic var
278+ */
279+ void CanonicalizeDimension (const PrimExpr& dim) {
280+ // If already canonical, nothing to emit
281+ if (IsCanonicalPrimExpr (dim)) {
282+ return ;
283+ }
284+
285+ // Check If we've already emitted bindings for this expression
286+ auto it = compound_expr_to_var_.find (dim);
287+ if (it == compound_expr_to_var_.end ()) {
288+ // This should not happen if GetCanonicalDimension was called first
289+ return ;
290+ }
291+
292+ // Check if we've already emitted the bindings
293+ if (emitted_bindings_.count (dim)) {
294+ return ;
295+ }
296+
297+ tir::Var symbolic_var = it->second ;
298+
199299 // Emit shape binding: shape_var = R.shape([compound_expr])
200300 ShapeExpr shape_value ({dim});
201301 Var shape_var = builder_->Emit (shape_value);
@@ -206,10 +306,8 @@ class ShapeExprCanonicalizer : public ExprMutator {
206306 Var match_cast_var (" _" , match_sinfo);
207307 builder_->EmitNormalized (MatchCast (match_cast_var, shape_var, match_sinfo));
208308
209- // Cache the mapping to avoid duplicate bindings
210- compound_expr_to_var_[dim] = symbolic_var;
211-
212- return symbolic_var;
309+ // Mark as emitted
310+ emitted_bindings_.insert (dim);
213311 }
214312
215313 /* !
@@ -223,6 +321,9 @@ class ShapeExprCanonicalizer : public ExprMutator {
223321 // Cache to avoid creating duplicate bindings for the same compound expression
224322 std::unordered_map<PrimExpr, tir::Var, StructuralHash, StructuralEqual> compound_expr_to_var_;
225323
324+ // Track which compound expressions have had their bindings emitted
325+ std::unordered_set<PrimExpr, StructuralHash, StructuralEqual> emitted_bindings_;
326+
226327 // Counter for generating unique symbolic variable names
227328 int symbolic_var_counter_ = 0 ;
228329};
0 commit comments