Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions maldoca/astgen/ast_def.cc
Original file line number Diff line number Diff line change
Expand Up @@ -360,6 +360,24 @@ absl::StatusOr<AstDef> AstDef::FromProto(const AstDefPb& pb) {
for (auto kind : union_type_pb.kinds()) {
union_type_node->kinds_.push_back(static_cast<FieldKind>(kind));
}
if (union_type_pb.has_ir_op_name()) {
union_type_node->ir_op_name_ = union_type_pb.ir_op_name();
}
if (union_type_pb.has_should_generate_dispatch()) {
union_type_node->should_generate_dispatch_ =
union_type_pb.should_generate_dispatch();
}
for (const auto& o : union_type_pb.dispatch_overrides()) {
std::optional<std::string> ir_op_name;
if (o.has_ir_op_name()) {
ir_op_name = o.ir_op_name();
}
union_type_node->dispatch_overrides_.emplace(
o.type(), NodeDef::DispatchOverride{o.visitor(), ir_op_name});
}
for (const auto& s : union_type_pb.dispatch_skip()) {
union_type_node->dispatch_skip_.insert(s);
}
if (nodes.contains(union_type_pb.name())) {
return absl::InvalidArgumentError(
absl::StrCat(union_type_pb.name(), " already exists!"));
Expand Down
42 changes: 42 additions & 0 deletions maldoca/astgen/ast_def.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "absl/container/flat_hash_set.h"
#include "absl/log/log.h"
#include "absl/status/statusor.h"
#include "absl/strings/string_view.h"
Expand Down Expand Up @@ -175,6 +176,44 @@ class NodeDef {
// If false, the op is expected to be manually written.
bool should_generate_ir_op() const { return should_generate_ir_op_; }

// Whether IR dispatch code should be automatically generated for unions.
bool should_generate_dispatch() const { return should_generate_dispatch_; }

// Overrides for generated dispatch code.
//
// For a union type (e.g., `Expression`), the generator automatically
// produces dispatch code (e.g., `dynamic_cast` chain in AST-to-IR, or
// `TypeSwitch` in IR-to-AST) to route to the correct visitor for each member
// (e.g., `VisitIdentifier`).
//
// `DispatchOverride` allows customizing this routing for specific members.
struct DispatchOverride {
// The name of the visitor function to call.
// E.g., "VisitIdentifierRef" instead of the default "VisitIdentifier".
std::string visitor;

// The IR op name to use for this member.
// E.g., "jsir.IdentifierRef" instead of "jsir.Identifier".
std::optional<std::string> ir_op_name;
};

// Map from union member type name (e.g., "Identifier") to its dispatch
// override.
const absl::flat_hash_map<std::string, DispatchOverride>& dispatch_overrides()
const {
return dispatch_overrides_;
}

// Set of union member type names to skip in the generated dispatch code.
//
// If a member is skipped, it will not be included in the generated
// dispatch methods, and must be handled manually if needed.
//
// E.g., {"InvalidExpression"} to skip generating dispatch for invalid nodes.
const absl::flat_hash_set<std::string>& dispatch_skip() const {
return dispatch_skip_;
}

// The allowed FieldKinds for this node. Does not include those specified in
// ancestors.
//
Expand Down Expand Up @@ -293,6 +332,9 @@ class NodeDef {
std::vector<FieldKind> aggregated_kinds_;
bool has_control_flow_;
std::optional<std::string> ir_op_name_;
bool should_generate_dispatch_ = true;
absl::flat_hash_map<std::string, DispatchOverride> dispatch_overrides_;
absl::flat_hash_set<std::string> dispatch_skip_;
bool has_fold_;
std::vector<MlirTrait> additional_mlir_traits_;
std::vector<MlirTrait> aggregated_additional_mlir_traits_;
Expand Down
34 changes: 34 additions & 0 deletions maldoca/astgen/ast_def.proto
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,40 @@ message UnionTypePb {

// Supported kinds. Each kind leads to a different IR op.
repeated FieldKind kinds = 5;

// [Optional] Custom MLIR op name.
optional string ir_op_name = 6;

// If true, automatically generate the corresponding IR dispatch code.
optional bool should_generate_dispatch = 7 [default = true];

message DispatchOverridePb {
// The name of the union member type to override.
// E.g., "Identifier".
optional string type = 1;

// The name of the visitor function to call.
// E.g., "VisitIdentifierRef" instead of the default "VisitIdentifier".
optional string visitor = 2;

// The IR op name to use for this member.
// E.g., "jsir.IdentifierRef" instead of "jsir.Identifier".
optional string ir_op_name = 3;
}

// Overrides for generated dispatch code.
//
// Allows customizing the generated visitor and IR op name for specific
// union members.
repeated DispatchOverridePb dispatch_overrides = 8;

// Set of union member type names to skip in the generated dispatch code.
//
// If a member is skipped, it will not be included in the generated
// dispatch methods, and must be handled manually if needed.
//
// E.g., "InvalidExpression" to skip generating dispatch for invalid nodes.
repeated string dispatch_skip = 9;
}

// Top-level AST definition.
Expand Down
76 changes: 59 additions & 17 deletions maldoca/astgen/ast_gen_main.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,31 @@ ABSL_FLAG(std::string, ast_path, "", "The directory for the AST code in C++.");
ABSL_FLAG(std::string, ir_path, "",
"The directory for the IR code in TableGen and C++.");

// Flags to support mapping AST nodes to a different target IR dialect
// (e.g. mapping SWC's "jsswc" AST to the standard "jsir" dialect).
ABSL_FLAG(std::string, ir_lang_name, "",
"The language name for the IR (e.g. 'js').");

// Overrides to prevent generated files from overwriting other dialect
// conversions and to use custom names/paths.
ABSL_FLAG(std::string, ast_to_ir_cc_path, "",
"Override output path for generated AST to IR C++ source.");

ABSL_FLAG(std::string, ir_to_ast_cc_path, "",
"Override output path for generated IR to AST C++ source.");

ABSL_FLAG(std::string, ast_to_ir_header_include_path, "",
"Override include path for AST to IR header in generated source.");

ABSL_FLAG(std::string, ir_to_ast_header_include_path, "",
"Override include path for IR to AST header in generated source.");

ABSL_FLAG(std::string, ast_to_ir_class_name, "",
"Override class name for AST to IR converter.");

ABSL_FLAG(std::string, ir_to_ast_class_name, "",
"Override class name for IR to AST converter.");

namespace maldoca {
namespace {

Expand All @@ -58,6 +83,15 @@ absl::Status AstGenMain() {
auto cc_namespace = absl::GetFlag(FLAGS_cc_namespace);
auto ast_path = absl::GetFlag(FLAGS_ast_path);
auto ir_path = absl::GetFlag(FLAGS_ir_path);
auto ir_lang_name = absl::GetFlag(FLAGS_ir_lang_name);
auto ast_to_ir_cc_path_flag = absl::GetFlag(FLAGS_ast_to_ir_cc_path);
auto ir_to_ast_cc_path_flag = absl::GetFlag(FLAGS_ir_to_ast_cc_path);
auto ast_to_ir_header_include_path =
absl::GetFlag(FLAGS_ast_to_ir_header_include_path);
auto ir_to_ast_header_include_path =
absl::GetFlag(FLAGS_ir_to_ast_header_include_path);
auto ast_to_ir_class_name = absl::GetFlag(FLAGS_ast_to_ir_class_name);
auto ir_to_ast_class_name = absl::GetFlag(FLAGS_ir_to_ast_class_name);

AstDefPb ast_def_pb;
MALDOCA_RETURN_IF_ERROR(ParseTextProtoFile(ast_def_path, &ast_def_pb));
Expand Down Expand Up @@ -87,27 +121,35 @@ absl::Status AstGenMain() {
SetFileContents(ast_from_json_path, ast_from_json));

if (!ir_path.empty()) {
std::string ir_tablegen = PrintIrTableGen(ast_def, ir_path);
auto ir_tablegen_path = JoinPath(
ir_path, absl::StrCat(ast_def.lang_name(), "ir_ops.generated.td"));
std::cout << "Writing ir_tablegen to " << ir_tablegen_path << "\n";
MALDOCA_RETURN_IF_ERROR(
SetFileContents(ir_tablegen_path, ir_tablegen));

std::string ast_to_ir =
PrintAstToIrSource(ast_def, cc_namespace, ast_path, ir_path);
auto ast_to_ir_path = JoinPath(
ir_path, "conversion",
absl::StrCat("ast_to_", ast_def.lang_name(), "ir.generated.cc"));
if (ir_lang_name.empty() || ir_lang_name == ast_def.lang_name()) {
std::string ir_tablegen = PrintIrTableGen(ast_def, ir_path);
auto ir_tablegen_path = JoinPath(
ir_path, absl::StrCat(ast_def.lang_name(), "ir_ops.generated.td"));
std::cout << "Writing ir_tablegen to " << ir_tablegen_path << "\n";
MALDOCA_RETURN_IF_ERROR(SetFileContents(ir_tablegen_path, ir_tablegen));
}

std::string ast_to_ir = PrintAstToIrSource(
ast_def, cc_namespace, ast_path, ir_path, ir_lang_name,
ast_to_ir_header_include_path, ast_to_ir_class_name);
auto ast_to_ir_path =
ast_to_ir_cc_path_flag.empty()
? JoinPath(ir_path, "conversion",
absl::StrCat("ast_to_", ast_def.lang_name(),
"ir.generated.cc"))
: ast_to_ir_cc_path_flag;
std::cout << "Writing ast_to_ir to " << ast_to_ir_path << "\n";
MALDOCA_RETURN_IF_ERROR(
SetFileContents(ast_to_ir_path, ast_to_ir));

std::string ir_to_ast =
PrintIrToAstSource(ast_def, cc_namespace, ast_path, ir_path);
auto ir_to_ast_path = JoinPath(
ir_path, "conversion",
absl::StrCat(ast_def.lang_name(), "ir_to_ast.generated.cc"));
std::string ir_to_ast = PrintIrToAstSource(
ast_def, cc_namespace, ast_path, ir_path, ir_lang_name,
ir_to_ast_header_include_path, ir_to_ast_class_name);
auto ir_to_ast_path = ir_to_ast_cc_path_flag.empty()
? JoinPath(ir_path, "conversion",
absl::StrCat(ast_def.lang_name(),
"ir_to_ast.generated.cc"))
: ir_to_ast_cc_path_flag;
std::cout << "Writing ir_to_ast to " << ir_to_ast_path << "\n";
MALDOCA_RETURN_IF_ERROR(
SetFileContents(ir_to_ast_path, ir_to_ast));
Expand Down
3 changes: 3 additions & 0 deletions maldoca/astgen/ast_source_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -275,6 +275,9 @@ void AstSourcePrinter::PrintConstructor(const NodeDef& node,
}

Print(")");
if (!ancestor->aggregated_fields().empty()) {
Print(" /* NOLINT */");
}
}

for (const FieldDef& field : node.fields()) {
Expand Down
Loading
Loading