diff --git a/Makefile b/Makefile index 12b1fb13426a..3c278ec703cc 100644 --- a/Makefile +++ b/Makefile @@ -481,6 +481,7 @@ SOURCE_FILES = \ InjectHostDevBufferCopies.cpp \ Inline.cpp \ InlineReductions.cpp \ + InstructionSelector.cpp \ IntegerDivisionTable.cpp \ Interval.cpp \ Introspection.cpp \ @@ -579,7 +580,8 @@ SOURCE_FILES = \ Var.cpp \ VectorizeLoops.cpp \ WasmExecutor.cpp \ - WrapCalls.cpp + WrapCalls.cpp \ + X86Optimize.cpp # The externally-visible header files that go into making Halide.h. # Don't include anything here that includes llvm headers. @@ -662,6 +664,7 @@ HEADER_FILES = \ InjectHostDevBufferCopies.h \ Inline.h \ InlineReductions.h \ + InstructionSelector.h \ IntegerDivisionTable.h \ Interval.h \ Introspection.h \ @@ -745,7 +748,8 @@ HEADER_FILES = \ Util.h \ Var.h \ VectorizeLoops.h \ - WrapCalls.h + WrapCalls.h \ + X86Optimize.h OBJECTS = $(SOURCE_FILES:%.cpp=$(BUILD_DIR)/%.o) HEADERS = $(HEADER_FILES:%.h=$(SRC_DIR)/%.h) diff --git a/src/Bounds.cpp b/src/Bounds.cpp index 4d6d7f2c3a86..a603ac12db2a 100644 --- a/src/Bounds.cpp +++ b/src/Bounds.cpp @@ -1111,6 +1111,11 @@ class Bounds : public IRVisitor { op->value.accept(this); } + void visit(const VectorInstruction *op) override { + // TODO(rootjalex): we may need to implement bounds queries. 
+ internal_error << "Unexpected VectorInstruction in bounds query: " << Expr(op) << "\n"; + } + void visit(const Call *op) override { TRACK_BOUNDS_INTERVAL; TRACK_BOUNDS_INFO("name:", op->name); diff --git a/src/CMakeLists.txt b/src/CMakeLists.txt index cc9f6805ba4a..47e10888c299 100644 --- a/src/CMakeLists.txt +++ b/src/CMakeLists.txt @@ -82,6 +82,7 @@ set(HEADER_FILES InjectHostDevBufferCopies.h Inline.h InlineReductions.h + InstructionSelector.h IntegerDivisionTable.h Interval.h Introspection.h @@ -166,6 +167,7 @@ set(HEADER_FILES VectorizeLoops.h WasmExecutor.h WrapCalls.h + X86Optimize.h ) set(SOURCE_FILES @@ -245,6 +247,7 @@ set(SOURCE_FILES InjectHostDevBufferCopies.cpp Inline.cpp InlineReductions.cpp + InstructionSelector.cpp IntegerDivisionTable.cpp Interval.cpp Introspection.cpp @@ -344,6 +347,7 @@ set(SOURCE_FILES VectorizeLoops.cpp WasmExecutor.cpp WrapCalls.cpp + X86Optimize.cpp ) ## diff --git a/src/CodeGen_C.cpp b/src/CodeGen_C.cpp index c5b64f0610f5..8bc1abda3502 100644 --- a/src/CodeGen_C.cpp +++ b/src/CodeGen_C.cpp @@ -2829,6 +2829,11 @@ Expr CodeGen_C::scalarize_vector_reduce(const VectorReduce *op) { return Shuffle::make_concat(lanes); } +void CodeGen_C::visit(const VectorInstruction *op) { + internal_error << "CodeGen_C should never receive a VectorInstruction, received:\n" + << Expr(op) << "\n"; +} + void CodeGen_C::visit(const VectorReduce *op) { stream << get_indent() << "// Vector reduce: " << op->op << "\n"; diff --git a/src/CodeGen_C.h b/src/CodeGen_C.h index 9c06d4bb5630..256b35f55efe 100644 --- a/src/CodeGen_C.h +++ b/src/CodeGen_C.h @@ -235,6 +235,7 @@ class CodeGen_C : public IRPrinter { void visit(const Fork *) override; void visit(const Acquire *) override; void visit(const Atomic *) override; + void visit(const VectorInstruction *) override; void visit(const VectorReduce *) override; void visit_binop(Type t, const Expr &a, const Expr &b, const char *op); diff --git a/src/CodeGen_LLVM.cpp b/src/CodeGen_LLVM.cpp index 
9721a1c2ad80..fb9f59407245 100644 --- a/src/CodeGen_LLVM.cpp +++ b/src/CodeGen_LLVM.cpp @@ -4022,11 +4022,16 @@ void CodeGen_LLVM::visit(const Shuffle *op) { } } +void CodeGen_LLVM::visit(const VectorInstruction *op) { + internal_error << "CodeGen_LLVM received VectorInstruction node, should be handled by architecture-specific CodeGen class:\n" + << Expr(op) << "\n"; +} + void CodeGen_LLVM::visit(const VectorReduce *op) { codegen_vector_reduce(op, Expr()); } -void CodeGen_LLVM::codegen_vector_reduce(const VectorReduce *op, const Expr &init) { +Expr CodeGen_LLVM::split_vector_reduce(const VectorReduce *op, const Expr &init) const { Expr val = op->value; const int output_lanes = op->type.lanes(); const int native_lanes = native_vector_bits() / op->type.bits(); @@ -4066,8 +4071,7 @@ void CodeGen_LLVM::codegen_vector_reduce(const VectorReduce *op, const Expr &ini equiv = max(equiv, init); } equiv = cast(op->type, equiv); - equiv.accept(this); - return; + return equiv; } if (op->type.is_bool() && op->op == VectorReduce::And) { @@ -4078,8 +4082,7 @@ void CodeGen_LLVM::codegen_vector_reduce(const VectorReduce *op, const Expr &ini if (init.defined()) { equiv = min(equiv, init); } - equiv.accept(this); - return; + return equiv; } if (elt == Float(16) && upgrade_type_for_arithmetic(elt) != elt) { @@ -4089,8 +4092,7 @@ void CodeGen_LLVM::codegen_vector_reduce(const VectorReduce *op, const Expr &ini equiv = binop(equiv, init); } equiv = cast(op->type, equiv); - equiv.accept(this); - return; + return equiv; } if (output_lanes == 1) { @@ -4189,8 +4191,7 @@ void CodeGen_LLVM::codegen_vector_reduce(const VectorReduce *op, const Expr &ini if (initial_value.defined()) { equiv = binop(initial_value, equiv); } - equiv.accept(this); - return; + return equiv; } } @@ -4213,8 +4214,7 @@ void CodeGen_LLVM::codegen_vector_reduce(const VectorReduce *op, const Expr &ini equiv = binop(equiv, init); } equiv = common_subexpression_elimination(equiv); - equiv.accept(this); - return; + return 
equiv; } if (factor > 2 && ((factor & 1) == 0)) { @@ -4246,8 +4246,7 @@ void CodeGen_LLVM::codegen_vector_reduce(const VectorReduce *op, const Expr &ini equiv = binop(equiv, init); } equiv = common_subexpression_elimination(equiv); - codegen(equiv); - return; + return equiv; } // Extract each slice and combine @@ -4261,8 +4260,13 @@ void CodeGen_LLVM::codegen_vector_reduce(const VectorReduce *op, const Expr &ini } } equiv = common_subexpression_elimination(equiv); - codegen(equiv); -} // namespace Internal + return equiv; +} + +void CodeGen_LLVM::codegen_vector_reduce(const VectorReduce *op, const Expr &init) { + Expr equiv = split_vector_reduce(op, init); + equiv.accept(this); +} void CodeGen_LLVM::visit(const Atomic *op) { if (!op->mutex_name.empty()) { diff --git a/src/CodeGen_LLVM.h b/src/CodeGen_LLVM.h index d6ee5b26adff..47c5e714bcbb 100644 --- a/src/CodeGen_LLVM.h +++ b/src/CodeGen_LLVM.h @@ -48,6 +48,8 @@ struct ExternSignature; namespace Internal { +class InstructionSelector; + /** A code generator abstract base class. Actual code generators * (e.g. CodeGen_X86) inherit from this. This class is responsible * for taking a Halide Stmt and producing llvm bitcode, machine @@ -361,6 +363,7 @@ class CodeGen_LLVM : public IRVisitor { void visit(const IfThenElse *) override; void visit(const Evaluate *) override; void visit(const Shuffle *) override; + void visit(const VectorInstruction *) override; void visit(const VectorReduce *) override; void visit(const Prefetch *) override; void visit(const Atomic *) override; @@ -514,6 +517,11 @@ class CodeGen_LLVM : public IRVisitor { * across backends. */ virtual void codegen_vector_reduce(const VectorReduce *op, const Expr &init); + /** Split up a VectorReduce node if possible, or generate LLVM + intrinsics for full reductions. This is used in + `codegen_vector_reduce`. **/ + virtual Expr split_vector_reduce(const VectorReduce *op, const Expr &init) const; + /** Are we inside an atomic node that uses mutex locks? 
This is used for detecting deadlocks from nested atomics & illegal vectorization. */ bool inside_atomic_mutex_node; @@ -621,6 +629,12 @@ class CodeGen_LLVM : public IRVisitor { * represents a unique struct type created by a closure or similar. */ std::map struct_type_recovery; + + /** Instruction selection uses `split_vector_reduce` and + * `upgrade_type_for_arithmetic`, so needs access to those + * methods. + */ + friend class InstructionSelector; }; } // namespace Internal diff --git a/src/CodeGen_X86.cpp b/src/CodeGen_X86.cpp index 5d599409fb61..a280227dfd11 100644 --- a/src/CodeGen_X86.cpp +++ b/src/CodeGen_X86.cpp @@ -7,6 +7,7 @@ #include "LLVM_Headers.h" #include "Simplify.h" #include "Util.h" +#include "X86Optimize.h" namespace Halide { namespace Internal { @@ -53,6 +54,9 @@ class CodeGen_X86 : public CodeGen_Posix { CodeGen_X86(Target); protected: + void compile_func(const LoweredFunc &f, + const std::string &simple_name, const std::string &extern_name) override; + string mcpu_target() const override; string mcpu_tune() const override; string mattrs() const override; @@ -69,10 +73,7 @@ class CodeGen_X86 : public CodeGen_Posix { /** Nodes for which we want to emit specific sse/avx intrinsics */ // @{ - void visit(const Add *) override; - void visit(const Sub *) override; void visit(const Cast *) override; - void visit(const Call *) override; void visit(const GT *) override; void visit(const LT *) override; void visit(const LE *) override; @@ -83,7 +84,7 @@ class CodeGen_X86 : public CodeGen_Posix { void visit(const Allocate *) override; void visit(const Load *) override; void visit(const Store *) override; - void codegen_vector_reduce(const VectorReduce *, const Expr &init) override; + void visit(const VectorInstruction *) override; // @} private: @@ -132,9 +133,9 @@ const x86Intrinsic intrinsic_defs[] = { {"llvm.ssub.sat.v8i16", Int(16, 8), "saturating_sub", {Int(16, 8), Int(16, 8)}}, // Sum of absolute differences - {"llvm.x86.sse2.psad.bw", UInt(64, 2), 
"sum_of_absolute_differences", {UInt(8, 16), UInt(8, 16)}}, - {"llvm.x86.avx2.psad.bw", UInt(64, 4), "sum_of_absolute_differences", {UInt(8, 32), UInt(8, 32)}, Target::AVX2}, - {"llvm.x86.avx512.psad.bw.512", UInt(64, 8), "sum_of_absolute_differences", {UInt(8, 64), UInt(8, 64)}, Target::AVX512_Skylake}, + {"llvm.x86.sse2.psad.bw", UInt(64, 2), "sum_absd", {UInt(8, 16), UInt(8, 16)}}, + {"llvm.x86.avx2.psad.bw", UInt(64, 4), "sum_absd", {UInt(8, 32), UInt(8, 32)}, Target::AVX2}, + {"llvm.x86.avx512.psad.bw.512", UInt(64, 8), "sum_absd", {UInt(8, 64), UInt(8, 64)}, Target::AVX512_Skylake}, // Some of the instructions referred to below only appear with // AVX2, but LLVM generates better AVX code if you give it @@ -177,11 +178,9 @@ const x86Intrinsic intrinsic_defs[] = { {"llvm.x86.avx2.pmulh.w", Int(16, 16), "pmulh", {Int(16, 16), Int(16, 16)}, Target::AVX2}, {"llvm.x86.avx2.pmulhu.w", UInt(16, 16), "pmulh", {UInt(16, 16), UInt(16, 16)}, Target::AVX2}, {"llvm.x86.avx2.pmul.hr.sw", Int(16, 16), "pmulhrs", {Int(16, 16), Int(16, 16)}, Target::AVX2}, - {"saturating_pmulhrswx16", Int(16, 16), "saturating_pmulhrs", {Int(16, 16), Int(16, 16)}, Target::AVX2}, {"llvm.x86.sse2.pmulh.w", Int(16, 8), "pmulh", {Int(16, 8), Int(16, 8)}}, {"llvm.x86.sse2.pmulhu.w", UInt(16, 8), "pmulh", {UInt(16, 8), UInt(16, 8)}}, {"llvm.x86.ssse3.pmul.hr.sw.128", Int(16, 8), "pmulhrs", {Int(16, 8), Int(16, 8)}, Target::SSE41}, - {"saturating_pmulhrswx8", Int(16, 8), "saturating_pmulhrs", {Int(16, 8), Int(16, 8)}, Target::SSE41}, // Convert FP32 to BF16 {"vcvtne2ps2bf16x32", BFloat(16, 32), "f32_to_bf16", {Float(32, 32)}, Target::AVX512_SapphireRapids}, @@ -190,6 +189,16 @@ const x86Intrinsic intrinsic_defs[] = { // LLVM does not provide an unmasked 128bit cvtneps2bf16 intrinsic, so provide a wrapper around the masked version. {"vcvtneps2bf16x4", BFloat(16, 4), "f32_to_bf16", {Float(32, 4)}, Target::AVX512_SapphireRapids}, + // Horizontal adds that use (v)phadd(w | d). 
+ {"phaddw_sse3", UInt(16, 8), "horizontal_add", {UInt(16, 16)}, Target::SSE41}, + {"phaddw_sse3", Int(16, 8), "horizontal_add", {Int(16, 16)}, Target::SSE41}, + {"phaddw_avx2", UInt(16, 16), "horizontal_add", {UInt(16, 32)}, Target::AVX2}, + {"phaddw_avx2", Int(16, 16), "horizontal_add", {Int(16, 32)}, Target::AVX2}, + {"phaddd_sse3", UInt(32, 4), "horizontal_add", {UInt(32, 8)}, Target::SSE41}, + {"phaddd_sse3", Int(32, 4), "horizontal_add", {Int(32, 8)}, Target::SSE41}, + {"phaddd_avx2", UInt(32, 8), "horizontal_add", {UInt(32, 16)}, Target::AVX2}, + {"phaddd_avx2", Int(32, 8), "horizontal_add", {Int(32, 16)}, Target::AVX2}, + // 2-way dot products {"llvm.x86.avx2.pmadd.ub.sw", Int(16, 16), "saturating_dot_product", {UInt(8, 32), Int(8, 32)}, Target::AVX2}, {"llvm.x86.ssse3.pmadd.ub.sw.128", Int(16, 8), "saturating_dot_product", {UInt(8, 16), Int(8, 16)}, Target::SSE41}, @@ -270,83 +279,40 @@ void CodeGen_X86::init_module() { } } -// i32(i16_a)*i32(i16_b) +/- i32(i16_c)*i32(i16_d) can be done by -// interleaving a, c, and b, d, and then using dot_product. -bool should_use_dot_product(const Expr &a, const Expr &b, vector &result) { - Type t = a.type(); - internal_assert(b.type() == t); - - if (!(t.is_int() && t.bits() == 32 && t.lanes() >= 4)) { - return false; - } - - const Call *ma = Call::as_intrinsic(a, {Call::widening_mul}); - const Call *mb = Call::as_intrinsic(b, {Call::widening_mul}); - // dot_product can't handle mixed type widening muls. - if (ma && ma->args[0].type() != ma->args[1].type()) { - return false; - } - if (mb && mb->args[0].type() != mb->args[1].type()) { - return false; - } - // If the operands are widening shifts, we might be able to treat these as - // multiplies. 
- const Call *sa = Call::as_intrinsic(a, {Call::widening_shift_left}); - const Call *sb = Call::as_intrinsic(b, {Call::widening_shift_left}); - if (sa && !is_const(sa->args[1])) { - sa = nullptr; - } - if (sb && !is_const(sb->args[1])) { - sb = nullptr; - } - if ((ma || sa) && (mb || sb)) { - Expr a0 = ma ? ma->args[0] : sa->args[0]; - Expr a1 = ma ? ma->args[1] : lossless_cast(sa->args[0].type(), simplify(make_const(sa->type, 1) << sa->args[1])); - Expr b0 = mb ? mb->args[0] : sb->args[0]; - Expr b1 = mb ? mb->args[1] : lossless_cast(sb->args[0].type(), simplify(make_const(sb->type, 1) << sb->args[1])); - if (a1.defined() && b1.defined()) { - std::vector args = {a0, a1, b0, b1}; - result.swap(args); - return true; +// FIXME: This is nearly identical to CodeGen_LLVM, should re-factor this somehow. +// Only difference is the call to `optimize_x86_instructions()` +void CodeGen_X86::compile_func(const LoweredFunc &f, const std::string &simple_name, + const std::string &extern_name) { + // Generate the function declaration and argument unpacking code. + begin_func(f.linkage, simple_name, extern_name, f.args); + + // If building with MSAN, ensure that calls to halide_msan_annotate_buffer_is_initialized() + // happen for every output buffer if the function succeeds. 
+ if (f.linkage != LinkageType::Internal && + target.has_feature(Target::MSAN)) { + llvm::Function *annotate_buffer_fn = + module->getFunction("halide_msan_annotate_buffer_is_initialized_as_destructor"); + internal_assert(annotate_buffer_fn) + << "Could not find halide_msan_annotate_buffer_is_initialized_as_destructor in module\n"; + annotate_buffer_fn->addParamAttr(0, Attribute::NoAlias); + for (const auto &arg : f.args) { + if (arg.kind == Argument::OutputBuffer) { + register_destructor(annotate_buffer_fn, sym_get(arg.name + ".buffer"), OnSuccess); + } } } - return false; -} -void CodeGen_X86::visit(const Add *op) { - vector matches; - if (should_use_dot_product(op->a, op->b, matches)) { - Expr ac = Shuffle::make_interleave({matches[0], matches[2]}); - Expr bd = Shuffle::make_interleave({matches[1], matches[3]}); - value = call_overloaded_intrin(op->type, "dot_product", {ac, bd}); - if (value) { - return; - } - } - CodeGen_Posix::visit(op); -} + // Generate the function body. + debug(1) << "Generating llvm bitcode for function " << f.name << "...\n"; + debug(1) << "X86: Optimizing vector instructions...\n"; + Stmt body = optimize_x86_instructions(f.body, target, this); + debug(2) << "X86: Lowering after vector instructions:\n" + << body << "\n\n"; -void CodeGen_X86::visit(const Sub *op) { - vector matches; - if (should_use_dot_product(op->a, op->b, matches)) { - // Negate one of the factors in the second expression - Expr negative_2 = lossless_negate(matches[2]); - Expr negative_3 = lossless_negate(matches[3]); - if (negative_2.defined() || negative_3.defined()) { - if (negative_2.defined()) { - matches[2] = negative_2; - } else { - matches[3] = negative_3; - } - Expr ac = Shuffle::make_interleave({matches[0], matches[2]}); - Expr bd = Shuffle::make_interleave({matches[1], matches[3]}); - value = call_overloaded_intrin(op->type, "dot_product", {ac, bd}); - if (value) { - return; - } - } - } - CodeGen_Posix::visit(op); + body.accept(this); + + // Clean up and 
return. + end_func(f.args); } void CodeGen_X86::visit(const GT *op) { @@ -455,38 +421,12 @@ void CodeGen_X86::visit(const Select *op) { } void CodeGen_X86::visit(const Cast *op) { - if (!op->type.is_vector()) { // We only have peephole optimizations for vectors in here. CodeGen_Posix::visit(op); return; } - struct Pattern { - string intrin; - Expr pattern; - }; - - // clang-format off - static Pattern patterns[] = { - // This isn't rounding_multiply_quantzied(i16, i16, 15) because it doesn't - // saturate the result. - {"pmulhrs", i16(rounding_shift_right(widening_mul(wild_i16x_, wild_i16x_), 15))}, - - {"f32_to_bf16", bf16(wild_f32x_)}, - }; - // clang-format on - - vector matches; - for (const Pattern &p : patterns) { - if (expr_match(p.pattern, op, matches)) { - value = call_overloaded_intrin(op->type, p.intrin, matches); - if (value) { - return; - } - } - } - if (const Call *mul = Call::as_intrinsic(op->value, {Call::widening_mul})) { if (op->value.type().bits() < op->type.bits() && op->type.bits() <= 32) { // LLVM/x86 really doesn't like 8 -> 16 bit multiplication. If we're @@ -501,250 +441,6 @@ void CodeGen_X86::visit(const Cast *op) { CodeGen_Posix::visit(op); } -void CodeGen_X86::visit(const Call *op) { - if (!op->type.is_vector()) { - // We only have peephole optimizations for vectors in here. - CodeGen_Posix::visit(op); - return; - } - - // A 16-bit mul-shift-right of less than 16 can sometimes be rounded up to a - // full 16 to use pmulh(u)w by left-shifting one of the operands. This is - // handled here instead of in the lowering of mul_shift_right because it's - // unlikely to be a good idea on platforms other than x86, as it adds an - // extra shift in the fully-lowered case. 
- if ((op->type.element_of() == UInt(16) || - op->type.element_of() == Int(16)) && - op->is_intrinsic(Call::mul_shift_right)) { - internal_assert(op->args.size() == 3); - const uint64_t *shift = as_const_uint(op->args[2]); - if (shift && *shift < 16 && *shift >= 8) { - Type narrow = op->type.with_bits(8); - Expr narrow_a = lossless_cast(narrow, op->args[0]); - Expr narrow_b = narrow_a.defined() ? Expr() : lossless_cast(narrow, op->args[1]); - int shift_left = 16 - (int)(*shift); - if (narrow_a.defined()) { - codegen(mul_shift_right(op->args[0] << shift_left, op->args[1], 16)); - return; - } else if (narrow_b.defined()) { - codegen(mul_shift_right(op->args[0], op->args[1] << shift_left, 16)); - return; - } - } - } else if (op->type.is_int() && - op->type.bits() <= 16 && - op->is_intrinsic(Call::rounding_halving_add)) { - // We can redirect signed rounding halving add to unsigned rounding - // halving add by adding 128 / 32768 to the result if the sign of the - // args differs. - internal_assert(op->args.size() == 2); - Type t = op->type.with_code(halide_type_uint); - Expr a = cast(t, op->args[0]); - Expr b = cast(t, op->args[1]); - codegen(cast(op->type, rounding_halving_add(a, b) + ((a ^ b) & (1 << (t.bits() - 1))))); - return; - } else if (op->is_intrinsic(Call::absd)) { - internal_assert(op->args.size() == 2); - if (op->args[0].type().is_uint()) { - // On x86, there are many 3-instruction sequences to compute absd of - // unsigned integers. This one consists solely of instructions with - // throughput of 3 ops per cycle on Cannon Lake. 
- // - // Solution due to Wojciech Mula: - // http://0x80.pl/notesen/2018-03-11-sse-abs-unsigned.html - codegen(saturating_sub(op->args[0], op->args[1]) | saturating_sub(op->args[1], op->args[0])); - return; - } else if (op->args[0].type().is_int()) { - codegen(Max::make(op->args[0], op->args[1]) - Min::make(op->args[0], op->args[1])); - return; - } - } - - struct Pattern { - string intrin; - Expr pattern; - }; - - // clang-format off - static Pattern patterns[] = { - {"pmulh", mul_shift_right(wild_i16x_, wild_i16x_, 16)}, - {"pmulh", mul_shift_right(wild_u16x_, wild_u16x_, 16)}, - {"saturating_pmulhrs", rounding_mul_shift_right(wild_i16x_, wild_i16x_, 15)}, - {"saturating_narrow", i16_sat(wild_i32x_)}, - {"saturating_narrow", u16_sat(wild_i32x_)}, - {"saturating_narrow", i8_sat(wild_i16x_)}, - {"saturating_narrow", u8_sat(wild_i16x_)}, - }; - // clang-format on - - vector matches; - for (const auto &pattern : patterns) { - if (expr_match(pattern.pattern, op, matches)) { - value = call_overloaded_intrin(op->type, pattern.intrin, matches); - if (value) { - return; - } - } - } - - CodeGen_Posix::visit(op); -} - -void CodeGen_X86::codegen_vector_reduce(const VectorReduce *op, const Expr &init) { - if (op->op != VectorReduce::Add && op->op != VectorReduce::SaturatingAdd) { - CodeGen_Posix::codegen_vector_reduce(op, init); - return; - } - const int factor = op->value.type().lanes() / op->type.lanes(); - - struct Pattern { - VectorReduce::Operator reduce_op; - int factor; - Expr pattern; - const char *intrin; - Type narrow_type; - uint32_t flags = 0; - enum { - CombineInit = 1 << 0, - SwapOperands = 1 << 1, - SingleArg = 1 << 2, - }; - }; - // clang-format off - // These patterns are roughly sorted "best to worst", in case there are two - // patterns that match the expression. 
- static const Pattern patterns[] = { - // 4-way dot products - {VectorReduce::Add, 4, i32(widening_mul(wild_u8x_, wild_i8x_)), "dot_product", {}, Pattern::CombineInit}, - {VectorReduce::Add, 4, i32(widening_mul(wild_i8x_, wild_u8x_)), "dot_product", {}, Pattern::CombineInit | Pattern::SwapOperands}, - {VectorReduce::SaturatingAdd, 4, i32(widening_mul(wild_u8x_, wild_i8x_)), "saturating_dot_product", {}, Pattern::CombineInit}, - {VectorReduce::SaturatingAdd, 4, i32(widening_mul(wild_i8x_, wild_u8x_)), "saturating_dot_product", {}, Pattern::CombineInit | Pattern::SwapOperands}, - - // 2-way dot products - {VectorReduce::Add, 2, i32(widening_mul(wild_i8x_, wild_i8x_)), "dot_product", Int(16)}, - {VectorReduce::Add, 2, i32(widening_mul(wild_i8x_, wild_u8x_)), "dot_product", Int(16)}, - {VectorReduce::Add, 2, i32(widening_mul(wild_u8x_, wild_i8x_)), "dot_product", Int(16)}, - {VectorReduce::Add, 2, i32(widening_mul(wild_u8x_, wild_u8x_)), "dot_product", Int(16)}, - {VectorReduce::SaturatingAdd, 2, i32(widening_mul(wild_u8x_, wild_i8x_)), "saturating_dot_product", {}, Pattern::CombineInit}, - {VectorReduce::SaturatingAdd, 2, i32(widening_mul(wild_i8x_, wild_u8x_)), "saturating_dot_product", {}, Pattern::CombineInit | Pattern::SwapOperands}, - {VectorReduce::SaturatingAdd, 2, widening_mul(wild_u8x_, wild_i8x_), "saturating_dot_product"}, - {VectorReduce::SaturatingAdd, 2, widening_mul(wild_i8x_, wild_u8x_), "saturating_dot_product", {}, Pattern::SwapOperands}, - - {VectorReduce::Add, 2, i32(widening_mul(wild_i16x_, wild_i16x_)), "dot_product", {}, Pattern::CombineInit}, - {VectorReduce::Add, 2, i32(widening_mul(wild_i16x_, wild_i16x_)), "dot_product", Int(16)}, - {VectorReduce::SaturatingAdd, 2, i32(widening_mul(wild_i16x_, wild_i16x_)), "saturating_dot_product", {}, Pattern::CombineInit}, - - {VectorReduce::Add, 2, wild_f32x_ * wild_f32x_, "dot_product", BFloat(16), Pattern::CombineInit}, - - // One could do a horizontal widening addition with - // other dot_products 
against a vector of ones. Currently disabled - // because I haven't found other cases where it's clearly better. - {VectorReduce::Add, 2, u16(wild_u8x_), "horizontal_widening_add", {}, Pattern::SingleArg}, - {VectorReduce::Add, 2, i16(wild_u8x_), "horizontal_widening_add", {}, Pattern::SingleArg}, - {VectorReduce::Add, 2, i16(wild_i8x_), "horizontal_widening_add", {}, Pattern::SingleArg}, - - // Sum of absolute differences - {VectorReduce::Add, 8, u64(absd(wild_u8x_, wild_u8x_)), "sum_of_absolute_differences", {}}, - - }; - // clang-format on - - std::vector matches; - for (const Pattern &p : patterns) { - if (op->op != p.reduce_op || p.factor != factor) { - continue; - } - if (expr_match(p.pattern, op->value, matches)) { - if (p.flags & Pattern::SingleArg) { - Expr a = matches[0]; - - if (p.narrow_type.bits() > 0) { - a = lossless_cast(p.narrow_type.with_lanes(a.type().lanes()), a); - } - if (!a.defined()) { - continue; - } - - if (init.defined() && (p.flags & Pattern::CombineInit)) { - value = call_overloaded_intrin(op->type, p.intrin, {init, a}); - if (value) { - return; - } - } else { - value = call_overloaded_intrin(op->type, p.intrin, {a}); - if (value) { - if (init.defined()) { - Value *x = value; - Value *y = codegen(init); - value = builder->CreateAdd(x, y); - } - return; - } - } - } else { - Expr a = matches[0]; - Expr b = matches[1]; - if (p.flags & Pattern::SwapOperands) { - std::swap(a, b); - } - if (p.narrow_type.bits() > 0) { - a = lossless_cast(p.narrow_type.with_lanes(a.type().lanes()), a); - b = lossless_cast(p.narrow_type.with_lanes(b.type().lanes()), b); - } - if (!a.defined() || !b.defined()) { - continue; - } - - if (init.defined() && (p.flags & Pattern::CombineInit)) { - value = call_overloaded_intrin(op->type, p.intrin, {init, a, b}); - if (value) { - return; - } - } else { - value = call_overloaded_intrin(op->type, p.intrin, {a, b}); - if (value) { - if (init.defined()) { - Value *x = value; - Value *y = codegen(init); - value = 
builder->CreateAdd(x, y); - } - return; - } - } - } - } - } - - // Rewrite non-native sum-of-absolute-difference variants to the native - // op. We support reducing to various types. We could consider supporting - // multiple reduction factors too, but in general we don't handle non-native - // reduction factors for VectorReduce nodes (yet?). - if (op->op == VectorReduce::Add && - factor == 8) { - const Cast *cast = op->value.as(); - const Call *call = cast ? cast->value.as() : nullptr; - if (call && - call->is_intrinsic(Call::absd) && - cast->type.element_of().can_represent(UInt(8)) && - (cast->type.is_int() || cast->type.is_uint()) && - call->args[0].type().element_of() == UInt(8)) { - - internal_assert(cast->type.element_of() != UInt(64)) << "Should have pattern-matched above\n"; - - // Cast to uint64 instead - Expr equiv = Cast::make(UInt(64, cast->value.type().lanes()), cast->value); - // Reduce on that to hit psadbw - equiv = VectorReduce::make(VectorReduce::Add, equiv, op->type.lanes()); - // Then cast that to the desired type - equiv = Cast::make(cast->type.with_lanes(equiv.type().lanes()), equiv); - codegen(equiv); - return; - } - } - - CodeGen_Posix::codegen_vector_reduce(op, init); -} - void CodeGen_X86::visit(const Allocate *op) { ScopedBinding bind(mem_type, op->name, op->memory_type); CodeGen_Posix::visit(op); @@ -777,6 +473,12 @@ void CodeGen_X86::visit(const Store *op) { CodeGen_Posix::visit(op); } +void CodeGen_X86::visit(const VectorInstruction *op) { + const std::string name = op->get_instruction_name(); + value = call_overloaded_intrin(op->type, name, op->args); + internal_assert(value) << "CodeGen_X86 failed on " << Expr(op) << "\n"; +} + string CodeGen_X86::mcpu_target() const { // Perform an ad-hoc guess for the -mcpu given features. // WARNING: this is used to drive -mcpu, *NOT* -mtune! 
diff --git a/src/Deinterleave.cpp b/src/Deinterleave.cpp index f5840a0074b3..3f380be535c3 100644 --- a/src/Deinterleave.cpp +++ b/src/Deinterleave.cpp @@ -195,6 +195,11 @@ class Deinterleaver : public IRGraphMutator { using IRMutator::visit; + Expr visit(const VectorInstruction *op) override { + // We can't do anything special here. + return give_up_and_shuffle(op); + } + Expr visit(const VectorReduce *op) override { std::vector input_lanes; int factor = op->value.type().lanes() / op->type.lanes(); @@ -402,6 +407,12 @@ Expr deinterleave(Expr e, int starting_lane, int lane_stride, int new_lanes, con Deinterleaver d(starting_lane, lane_stride, new_lanes, lets); e = d.mutate(e); e = common_subexpression_elimination(e); + if (const Shuffle *shuffle = e.as()) { + if (shuffle->is_extract_element() && shuffle->vectors.size() == 1) { + // calling `simplify` here will produce an infinite recursive loop. + return e; + } + } return simplify(e); } } // namespace diff --git a/src/Derivative.cpp b/src/Derivative.cpp index 08a1c617ca00..7677e335c302 100644 --- a/src/Derivative.cpp +++ b/src/Derivative.cpp @@ -88,6 +88,9 @@ class ReverseAccumulationVisitor : public IRVisitor { void visit(const Shuffle *op) override { internal_error << "Encounter unexpected expression \"Shuffle\" when differentiating."; } + void visit(const VectorInstruction *op) override { + internal_error << "Encounter unexpected expression \"VectorInstruction\" when differentiating."; + } void visit(const VectorReduce *op) override { internal_error << "Encounter unexpected expression \"VectorReduce\" when differentiating."; } diff --git a/src/Expr.h b/src/Expr.h index ac0ec6521d68..efb7526e0eb5 100644 --- a/src/Expr.h +++ b/src/Expr.h @@ -57,6 +57,7 @@ enum class IRNodeType { Call, Let, Shuffle, + VectorInstruction, VectorReduce, // Stmts LetStmt, diff --git a/src/IR.cpp b/src/IR.cpp index e2f54318214a..a03cfd99d3fd 100644 --- a/src/IR.cpp +++ b/src/IR.cpp @@ -907,6 +907,46 @@ Stmt Atomic::make(const std::string 
&producer_name, return node; } +namespace { + +const char *const instruction_op_names[] = { + // Shared: + "abs", + "dot_product", + "rounding_halving_add", + "saturating_add", + "saturating_narrow", + "saturating_sub", + "widening_mul", + + // x86-specific + "f32_to_bf16", + "horizontal_add", + "pmulh", + "pmulhrs", + "saturating_dot_product", + "sum_absd", +}; + +static_assert(sizeof(instruction_op_names) / sizeof(instruction_op_names[0]) == VectorInstruction::InstructionOpCount, + "instruction_op_names needs attention"); + +} // namespace + +Expr VectorInstruction::make(Type type, InstructionOp op, const std::vector &args) { + user_assert(!args.empty()) << "VectorInstruction without arguments\n"; /* every instruction needs at least one operand */ + + VectorInstruction *node = new VectorInstruction; + node->type = type; + node->op = op; + node->args = args; + return node; +} + +const char *VectorInstruction::get_instruction_name() const { + return instruction_op_names[op]; /* indexed by enum value; the table above lists names in InstructionOp order, only its length is checked by the static_assert */ +} + Expr VectorReduce::make(VectorReduce::Operator op, Expr vec, int lanes) { @@ -1087,6 +1127,10 @@ void ExprNode::accept(IRVisitor *v) const { v->visit((const Shuffle *)this); } template<> +void ExprNode::accept(IRVisitor *v) const { + v->visit((const VectorInstruction *)this); +} +template<> void ExprNode::accept(IRVisitor *v) const { v->visit((const VectorReduce *)this); } @@ -1276,6 +1320,10 @@ Expr ExprNode::mutate_expr(IRMutator *v) const { return v->visit((const Shuffle *)this); } template<> +Expr ExprNode::mutate_expr(IRMutator *v) const { + return v->visit((const VectorInstruction *)this); +} +template<> Expr ExprNode::mutate_expr(IRMutator *v) const { return v->visit((const VectorReduce *)this); } diff --git a/src/IR.h b/src/IR.h index 0da5ffa1aaa6..51a371cbc07d 100644 --- a/src/IR.h +++ b/src/IR.h @@ -910,6 +910,47 @@ struct Atomic : public StmtNode { static const IRNodeType _node_type = IRNodeType::Atomic; }; +/** Represent a length-agnostic and target-specific + * vector instruction. 
Intrinsic may not be element-wise + * operation, i.e. dot_products. Should only be generated + * and consumed during CodeGen. */ +struct VectorInstruction : public ExprNode { + // Enum of vector instructions. The name is recovered via get_instruction_name(). + // Specific enum values are *not* guaranteed to be stable across time. + // Please keep this list sorted by target architecture (with a shared section first). + // This list will become more complete as we add Optimize passes for more backends. + // If you add an instruction here, update `instruction_op_names` in IR.cpp to match. + enum InstructionOp { + // Shared: + abs, + dot_product, + rounding_halving_add, + saturating_add, + saturating_narrow, + saturating_sub, + widening_mul, + + // x86-specific + f32_to_bf16, + horizontal_add, + pmulh, + pmulhrs, + saturating_dot_product, + sum_absd, + + InstructionOpCount // Sentinel: keep last. + }; + + InstructionOp op; // Which instruction this node represents. + std::vector args; // Operands of the instruction; make() rejects an empty list. + + static Expr make(Type type, InstructionOp op, const std::vector &args); + + static const IRNodeType _node_type = IRNodeType::VectorInstruction; + + const char *get_instruction_name() const; // Name string for `op`, looked up in `instruction_op_names` (IR.cpp). +}; + /** Horizontally reduce a vector to a scalar or narrower vector using * the given commutative and associative binary operator. 
The reduction * factor is dictated by the number of lanes in the input and output diff --git a/src/IREquality.cpp b/src/IREquality.cpp index 20cb616d2c32..edcfc3d067dc 100644 --- a/src/IREquality.cpp +++ b/src/IREquality.cpp @@ -98,6 +98,7 @@ class IRComparer : public IRVisitor { void visit(const Shuffle *) override; void visit(const Prefetch *) override; void visit(const Atomic *) override; + void visit(const VectorInstruction *) override; void visit(const VectorReduce *) override; }; @@ -629,6 +630,13 @@ void IRComparer::visit(const Atomic *op) { compare_stmt(s->body, op->body); } +void IRComparer::visit(const VectorInstruction *op) { + const VectorInstruction *e = expr.as(); + + compare_scalar(e->op, op->op); + compare_expr_vector(e->args, op->args); +} + void IRComparer::visit(const VectorReduce *op) { const VectorReduce *e = expr.as(); diff --git a/src/IRMatch.cpp b/src/IRMatch.cpp index 8416d223ffc4..056713836448 100644 --- a/src/IRMatch.cpp +++ b/src/IRMatch.cpp @@ -296,6 +296,22 @@ class IRMatch : public IRVisitor { } } + void visit(const VectorInstruction *op) override { + const VectorInstruction *e = expr.as(); + if (result && e && + types_match(op->type, e->type) && + e->op == op->op && + e->args.size() == op->args.size()) { + for (size_t i = 0; result && (i < e->args.size()); i++) { + // FIXME: should we early-out? 
Here and in Call* + expr = e->args[i]; + op->args[i].accept(this); + } + } else { + result = false; + } + } + void visit(const VectorReduce *op) override { const VectorReduce *e = expr.as(); if (result && e && op->op == e->op && types_match(op->type, e->type)) { @@ -515,6 +531,9 @@ bool equal_helper(const BaseExprNode &a, const BaseExprNode &b) noexcept { case IRNodeType::Shuffle: return (equal_helper(((const Shuffle &)a).vectors, ((const Shuffle &)b).vectors) && equal_helper(((const Shuffle &)a).indices, ((const Shuffle &)b).indices)); + case IRNodeType::VectorInstruction: + return (((const VectorInstruction &)a).op == ((const VectorInstruction &)b).op && + equal_helper(((const VectorInstruction &)a).args, ((const VectorInstruction &)b).args)); case IRNodeType::VectorReduce: // As with Cast above, we use equal instead of equal_helper // here, because while we know a.type == b.type, we don't know diff --git a/src/IRMatch.h b/src/IRMatch.h index 06d35ba8b4b0..42c4e99480aa 100644 --- a/src/IRMatch.h +++ b/src/IRMatch.h @@ -211,6 +211,8 @@ struct SpecificExpr { constexpr static IRNodeType max_node_type = IRNodeType::Shuffle; constexpr static bool canonical = true; + // Having SpecificExpr hold an Expr instead of a BaseExprNode reference + // is catastrophic for performance and stack space usage. 
const BaseExprNode &expr; template @@ -585,8 +587,13 @@ IntLiteral pattern_arg(int64_t x) { } template -HALIDE_ALWAYS_INLINE void assert_is_lvalue_if_expr() { - static_assert(!std::is_same::type, Expr>::value || std::is_lvalue_reference::value, +static constexpr bool is_lvalue_if_expr() { + return !std::is_same::type, Expr>::value || std::is_lvalue_reference::value; +} + +template +HALIDE_ALWAYS_INLINE static constexpr void assert_is_lvalue_if_expr() { + static_assert(is_lvalue_if_expr(), "Exprs are captured by reference by IRMatcher objects and so must be lvalues"); } @@ -1459,6 +1466,12 @@ struct Intrin { return rounding_shift_left(arg0, arg1); } else if (intrin == Call::rounding_shift_right) { return rounding_shift_right(arg0, arg1); + } else if (intrin == Call::bitwise_xor) { + return arg0 ^ arg1; + } else if (intrin == Call::bitwise_and) { + return arg0 & arg1; + } else if (intrin == Call::bitwise_or) { + return arg0 | arg1; } Expr arg2 = std::get(args).make(state, type_hint); @@ -1535,6 +1548,17 @@ HALIDE_ALWAYS_INLINE auto intrin(Call::IntrinsicOp intrinsic_op, Args... 
args) n return {intrinsic_op, pattern_arg(args)...}; } +template +auto abs(A &&a) noexcept -> Intrin { + assert_is_lvalue_if_expr(); + return {Call::abs, pattern_arg(a)}; +} +template +auto absd(A &&a, B &&b) noexcept -> Intrin { + assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); + return {Call::absd, pattern_arg(a), pattern_arg(b)}; +} template auto widen_right_add(A &&a, B &&b) noexcept -> Intrin { return {Call::widen_right_add, pattern_arg(a), pattern_arg(b)}; @@ -1550,64 +1574,131 @@ auto widen_right_sub(A &&a, B &&b) noexcept -> Intrin auto widening_add(A &&a, B &&b) noexcept -> Intrin { + assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); return {Call::widening_add, pattern_arg(a), pattern_arg(b)}; } template auto widening_sub(A &&a, B &&b) noexcept -> Intrin { + assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); return {Call::widening_sub, pattern_arg(a), pattern_arg(b)}; } template auto widening_mul(A &&a, B &&b) noexcept -> Intrin { + assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); return {Call::widening_mul, pattern_arg(a), pattern_arg(b)}; } template auto saturating_add(A &&a, B &&b) noexcept -> Intrin { + assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); return {Call::saturating_add, pattern_arg(a), pattern_arg(b)}; } template auto saturating_sub(A &&a, B &&b) noexcept -> Intrin { + assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); return {Call::saturating_sub, pattern_arg(a), pattern_arg(b)}; } template auto saturating_cast(const Type &t, A &&a) noexcept -> Intrin { + assert_is_lvalue_if_expr(); Intrin p = {Call::saturating_cast, pattern_arg(a)}; p.optional_type_hint = t; return p; } template auto halving_add(A &&a, B &&b) noexcept -> Intrin { + assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); return {Call::halving_add, pattern_arg(a), pattern_arg(b)}; } template auto halving_sub(A &&a, B &&b) noexcept -> Intrin { + assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); return 
{Call::halving_sub, pattern_arg(a), pattern_arg(b)}; } template auto rounding_halving_add(A &&a, B &&b) noexcept -> Intrin { + assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); return {Call::rounding_halving_add, pattern_arg(a), pattern_arg(b)}; } template auto shift_left(A &&a, B &&b) noexcept -> Intrin { + assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); return {Call::shift_left, pattern_arg(a), pattern_arg(b)}; } template auto shift_right(A &&a, B &&b) noexcept -> Intrin { + assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); return {Call::shift_right, pattern_arg(a), pattern_arg(b)}; } template auto rounding_shift_left(A &&a, B &&b) noexcept -> Intrin { + assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); return {Call::rounding_shift_left, pattern_arg(a), pattern_arg(b)}; } template auto rounding_shift_right(A &&a, B &&b) noexcept -> Intrin { + assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); return {Call::rounding_shift_right, pattern_arg(a), pattern_arg(b)}; } +template +auto bitwise_xor(A &&a, B &&b) noexcept -> Intrin { + assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); + return {Call::bitwise_xor, pattern_arg(a), pattern_arg(b)}; +} +template +HALIDE_ALWAYS_INLINE auto operator^(A &&a, B &&b) noexcept -> auto{ + assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); + return bitwise_xor(a, b); +} +template +auto bitwise_and(A &&a, B &&b) noexcept -> Intrin { + assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); + return {Call::bitwise_and, pattern_arg(a), pattern_arg(b)}; +} +template +HALIDE_ALWAYS_INLINE auto operator&(A &&a, B &&b) noexcept -> auto{ + assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); + return bitwise_and(a, b); +} +template +auto bitwise_or(A &&a, B &&b) noexcept -> Intrin { + assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); + return {Call::bitwise_or, pattern_arg(a), pattern_arg(b)}; +} +template +HALIDE_ALWAYS_INLINE auto operator|(A &&a, B &&b) noexcept -> 
auto{ + assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); + return bitwise_or(a, b); +} template auto mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin { + assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); return {Call::mul_shift_right, pattern_arg(a), pattern_arg(b), pattern_arg(c)}; } template auto rounding_mul_shift_right(A &&a, B &&b, C &&c) noexcept -> Intrin { + assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); return {Call::rounding_mul_shift_right, pattern_arg(a), pattern_arg(b), pattern_arg(c)}; } @@ -1803,6 +1894,7 @@ inline std::ostream &operator<<(std::ostream &s, const BroadcastOp &op) { template HALIDE_ALWAYS_INLINE auto broadcast(A &&a, B lanes) noexcept -> BroadcastOp { assert_is_lvalue_if_expr(); + assert_is_lvalue_if_expr(); return {pattern_arg(a), pattern_arg(lanes)}; } @@ -1872,6 +1964,110 @@ HALIDE_ALWAYS_INLINE auto ramp(A &&a, B &&b, C &&c) noexcept -> RampOp +struct VectorInstructionOp { + struct pattern_tag {}; + const VectorInstruction::InstructionOp op; + std::tuple args; + + static constexpr uint32_t binds = bitwise_or_reduce((bindings::mask)...); + + constexpr static IRNodeType min_node_type = IRNodeType::VectorInstruction; + constexpr static IRNodeType max_node_type = IRNodeType::VectorInstruction; + constexpr static bool canonical = and_reduce((Args::canonical)...); + + template::type> + HALIDE_ALWAYS_INLINE bool match_args(int, const VectorInstruction &v, MatcherState &state) const noexcept { + using T = decltype(std::get(args)); + return (std::get(args).template match(*v.args[i].get(), state) && + match_args::mask>(0, v, state)); + } + + template + HALIDE_ALWAYS_INLINE bool match_args(double, const VectorInstruction &v, MatcherState &state) const noexcept { + return true; + } + + template + HALIDE_ALWAYS_INLINE bool match(const BaseExprNode &e, MatcherState &state) const noexcept { + if (e.node_type != IRNodeType::VectorInstruction) { + return 
false; + } + const VectorInstruction &v = (const VectorInstruction &)e; + return (v.op == op && match_args<0, bound>(0, v, state)); + } + + template::type> + HALIDE_ALWAYS_INLINE void print_args(int, std::ostream &s) const { + s << std::get(args); + if (i + 1 < sizeof...(Args)) { + s << ", "; + } + print_args(0, s); + } + + template + HALIDE_ALWAYS_INLINE void print_args(double, std::ostream &s) const { + } + + HALIDE_ALWAYS_INLINE + void print_args(std::ostream &s) const { + print_args<0>(0, s); + } + + HALIDE_ALWAYS_INLINE + Expr make(MatcherState &state, halide_type_t type_hint) const { + std::vector r_args(sizeof...(Args)); + // TODO(rootjalex): How do we do type hints for the args? + // TODO(rootjalex): Is there a way to do basically an unrolled + // loop of the below? this is ugly. + // Supposedly C++20 will have constexpr std::transform, perhaps + // we can use that when Halide upgrades. + + r_args[0] = std::get<0>(args).make(state, {}); + if constexpr (sizeof...(Args) > 1) { + r_args[1] = std::get(args).make(state, {}); + } + if constexpr (sizeof...(Args) > 2) { + r_args[2] = std::get(args).make(state, {}); + } + + // for (int i = 0; i < sizeof...(Args); i++) { + // // TODO(rootjalex): how do we do type-hints here? + // args[i] = std::get(args).make(state, {}); + // } + return VectorInstruction::make(type_hint, op, r_args); + } + + constexpr static bool foldable = false; + + HALIDE_ALWAYS_INLINE + VectorInstructionOp(const VectorInstruction::InstructionOp _op, Args... args) noexcept + : op(_op), args(args...) { + static_assert(sizeof...(Args) > 0 && sizeof...(Args) <= 3, + "VectorInstructionOp must have non-zero arguments, and update make() if more than 3 arguments."); + } +}; + +template +std::ostream &operator<<(std::ostream &s, const VectorInstructionOp &op) { + // TODO(rootjalex): Should we print the type? 
+ s << "vector_instr(\""; + s << op.op << "\", "; + op.print_args(s); + s << ")"; + return s; +} + +template +HALIDE_ALWAYS_INLINE auto v_instr(const VectorInstruction::InstructionOp op, Args &&...args) noexcept -> VectorInstructionOp { + static_assert(and_reduce((is_lvalue_if_expr())...), "All parameters to a VectorInstructionOp must be lvalues if Exprs"); + return {op, pattern_arg(args)...}; +} + template struct VectorReduceOp { struct pattern_tag {}; @@ -1928,6 +2124,12 @@ HALIDE_ALWAYS_INLINE auto h_add(A &&a, B lanes) noexcept -> VectorReduceOp +HALIDE_ALWAYS_INLINE auto h_satadd(A &&a, B lanes) noexcept -> VectorReduceOp { + assert_is_lvalue_if_expr(); + return {pattern_arg(a), pattern_arg(lanes)}; +} + template HALIDE_ALWAYS_INLINE auto h_min(A &&a, B lanes) noexcept -> VectorReduceOp { assert_is_lvalue_if_expr(); @@ -2080,6 +2282,39 @@ HALIDE_ALWAYS_INLINE auto cast(halide_type_t t, A &&a) noexcept -> CastOp +struct TypeHint { + struct pattern_tag {}; + Type type; + A a; + + constexpr static uint32_t binds = bindings::mask; + + constexpr static IRNodeType min_node_type = IRNodeType::Cast; + constexpr static IRNodeType max_node_type = IRNodeType::Cast; + constexpr static bool canonical = A::canonical; + + HALIDE_ALWAYS_INLINE + Expr make(MatcherState &state, halide_type_t type_hint) const { + return a.make(state, type); + } + + constexpr static bool foldable = false; +}; + +template +std::ostream &operator<<(std::ostream &s, const TypeHint &op) { + s << "typed(" << op.type << ", " << op.a << ")"; + return s; +} + +template +HALIDE_ALWAYS_INLINE auto typed(halide_type_t t, A &&a) noexcept -> TypeHint { + assert_is_lvalue_if_expr(); + return {t, pattern_arg(a)}; +} + template struct Fold { struct pattern_tag {}; @@ -2306,6 +2541,8 @@ template struct IsFloat { struct pattern_tag {}; A a; + int bits; + int lanes; constexpr static uint32_t binds = bindings::mask; @@ -2320,7 +2557,7 @@ struct IsFloat { void make_folded_const(halide_scalar_value_t &val, 
halide_type_t &ty, MatcherState &state) const { // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method. Type t = a.make(state, {}).type(); - val.u.u64 = t.is_float(); + val.u.u64 = t.is_float() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes); ty.code = halide_type_uint; ty.bits = 1; ty.lanes = t.lanes(); @@ -2328,14 +2565,67 @@ struct IsFloat { }; template -HALIDE_ALWAYS_INLINE auto is_float(A &&a) noexcept -> IsFloat { +HALIDE_ALWAYS_INLINE auto is_float(A &&a, int bits = 0, int lanes = 0) noexcept -> IsFloat { assert_is_lvalue_if_expr(); - return {pattern_arg(a)}; + return {pattern_arg(a), bits, lanes}; } template std::ostream &operator<<(std::ostream &s, const IsFloat &op) { - s << "is_float(" << op.a << ")"; + s << "is_float(" << op.a; + if (op.bits > 0) { + s << ", " << op.bits; + } + if (op.lanes > 0) { + s << ", " << op.lanes; + } + s << ")"; + return s; +} + +template +struct IsBFloat { + struct pattern_tag {}; + A a; + int bits; + int lanes; + + constexpr static uint32_t binds = bindings::mask; + + // This rule is a boolean-valued predicate. Bools have type UIntImm. + constexpr static IRNodeType min_node_type = IRNodeType::UIntImm; + constexpr static IRNodeType max_node_type = IRNodeType::UIntImm; + constexpr static bool canonical = true; + + constexpr static bool foldable = true; + + HALIDE_ALWAYS_INLINE + void make_folded_const(halide_scalar_value_t &val, halide_type_t &ty, MatcherState &state) const { + // a is almost certainly a very simple pattern (e.g. a wild), so just inline the make method. 
+ Type t = a.make(state, {}).type(); + val.u.u64 = t.is_bfloat() && (bits == 0 || t.bits() == bits) && (lanes == 0 || t.lanes() == lanes); + ty.code = halide_type_uint; + ty.bits = 1; + ty.lanes = t.lanes(); + } +}; + +template +HALIDE_ALWAYS_INLINE auto is_bfloat(A &&a, int bits = 0, int lanes = 0) noexcept -> IsBFloat { + assert_is_lvalue_if_expr(); + return {pattern_arg(a), bits, lanes}; +} + +template +std::ostream &operator<<(std::ostream &s, const IsBFloat &op) { + s << "is_bfloat(" << op.a; + if (op.bits > 0) { + s << ", " << op.bits; + } + if (op.lanes > 0) { + s << ", " << op.lanes; + } + s << ")"; return s; } @@ -2343,7 +2633,8 @@ template struct IsInt { struct pattern_tag {}; A a; - int bits, lanes; + int bits; + int lanes; constexpr static uint32_t binds = bindings::mask; @@ -2388,7 +2679,8 @@ template struct IsUInt { struct pattern_tag {}; A a; - int bits, lanes; + int bits; + int lanes; constexpr static uint32_t binds = bindings::mask; diff --git a/src/IRMutator.cpp b/src/IRMutator.cpp index 005937a17008..b1703a6cccd1 100644 --- a/src/IRMutator.cpp +++ b/src/IRMutator.cpp @@ -327,6 +327,14 @@ Expr IRMutator::visit(const Shuffle *op) { return Shuffle::make(new_vectors, op->indices); } +Expr IRMutator::visit(const VectorInstruction *op) { + auto [new_args, changed] = mutate_with_changes(op->args); + if (!changed) { + return op; + } + return VectorInstruction::make(op->type, op->op, new_args); +} + Expr IRMutator::visit(const VectorReduce *op) { Expr value = mutate(op->value); if (value.same_as(op->value)) { diff --git a/src/IRMutator.h b/src/IRMutator.h index c7a1984269d3..4729bb08344f 100644 --- a/src/IRMutator.h +++ b/src/IRMutator.h @@ -81,6 +81,7 @@ class IRMutator { virtual Expr visit(const Call *); virtual Expr visit(const Let *); virtual Expr visit(const Shuffle *); + virtual Expr visit(const VectorInstruction *); virtual Expr visit(const VectorReduce *); virtual Stmt visit(const LetStmt *); diff --git a/src/IRPrinter.cpp b/src/IRPrinter.cpp 
index 38f57e46649e..78d0e087d7cb 100644 --- a/src/IRPrinter.cpp +++ b/src/IRPrinter.cpp @@ -1073,6 +1073,16 @@ void IRPrinter::visit(const Shuffle *op) { } } +void IRPrinter::visit(const VectorInstruction *op) { + stream << "(" + << op->type + << ")vector_instruction(\"" + << op->get_instruction_name() + << "\", "; + print_list(op->args); + stream << ")"; +} + void IRPrinter::visit(const VectorReduce *op) { stream << "(" << op->type diff --git a/src/IRPrinter.h b/src/IRPrinter.h index 666235988cd7..e4e89efd5806 100644 --- a/src/IRPrinter.h +++ b/src/IRPrinter.h @@ -194,6 +194,7 @@ class IRPrinter : public IRVisitor { void visit(const IfThenElse *) override; void visit(const Evaluate *) override; void visit(const Shuffle *) override; + void visit(const VectorInstruction *) override; void visit(const VectorReduce *) override; void visit(const Prefetch *) override; void visit(const Atomic *) override; diff --git a/src/IRVisitor.cpp b/src/IRVisitor.cpp index 7f9993987200..97c55d8075ac 100644 --- a/src/IRVisitor.cpp +++ b/src/IRVisitor.cpp @@ -257,6 +257,12 @@ void IRVisitor::visit(const Shuffle *op) { } } +void IRVisitor::visit(const VectorInstruction *op) { + for (const auto &arg : op->args) { + arg.accept(this); + } +} + void IRVisitor::visit(const VectorReduce *op) { op->value.accept(this); } @@ -515,6 +521,12 @@ void IRGraphVisitor::visit(const Shuffle *op) { } } +void IRGraphVisitor::visit(const VectorInstruction *op) { + for (const auto &arg : op->args) { + include(arg); + } +} + void IRGraphVisitor::visit(const VectorReduce *op) { include(op->value); } diff --git a/src/IRVisitor.h b/src/IRVisitor.h index 4e1650ff22be..5df16880dfed 100644 --- a/src/IRVisitor.h +++ b/src/IRVisitor.h @@ -71,6 +71,7 @@ class IRVisitor { virtual void visit(const IfThenElse *); virtual void visit(const Evaluate *); virtual void visit(const Shuffle *); + virtual void visit(const VectorInstruction *); virtual void visit(const VectorReduce *); virtual void visit(const Prefetch *); 
virtual void visit(const Fork *); @@ -142,6 +143,7 @@ class IRGraphVisitor : public IRVisitor { void visit(const IfThenElse *) override; void visit(const Evaluate *) override; void visit(const Shuffle *) override; + void visit(const VectorInstruction *) override; void visit(const VectorReduce *) override; void visit(const Prefetch *) override; void visit(const Acquire *) override; @@ -224,6 +226,8 @@ class VariadicVisitor { return ((T *)this)->visit((const Let *)node, std::forward(args)...); case IRNodeType::Shuffle: return ((T *)this)->visit((const Shuffle *)node, std::forward(args)...); + case IRNodeType::VectorInstruction: + return ((T *)this)->visit((const VectorInstruction *)node, std::forward(args)...); case IRNodeType::VectorReduce: return ((T *)this)->visit((const VectorReduce *)node, std::forward(args)...); // Explicitly list the Stmt types rather than using a @@ -286,6 +290,7 @@ class VariadicVisitor { case IRNodeType::Call: case IRNodeType::Let: case IRNodeType::Shuffle: + case IRNodeType::VectorInstruction: case IRNodeType::VectorReduce: internal_error << "Unreachable"; break; diff --git a/src/InstructionSelector.cpp b/src/InstructionSelector.cpp new file mode 100644 index 000000000000..072d07aa4cf8 --- /dev/null +++ b/src/InstructionSelector.cpp @@ -0,0 +1,34 @@ +#include "InstructionSelector.h" + +#include "CodeGen_Internal.h" +#include "IROperator.h" + +namespace Halide { +namespace Internal { + +InstructionSelector::InstructionSelector(const Target &t, const CodeGen_LLVM *c) + : target(t), codegen(c) { +} + +Expr InstructionSelector::visit(const Div *op) { + if (op->type.is_vector() && op->type.is_int_or_uint()) { + // Lower division here in order to do pattern-matching on intrinsics. 
+ return mutate(lower_int_uint_div(op->a, op->b)); + } + return IRGraphMutator::visit(op); +} + +Expr InstructionSelector::visit(const Mod *op) { + if (op->type.is_vector() && op->type.is_int_or_uint()) { + // Lower mod here in order to do pattern-matching on intrinsics. + return mutate(lower_int_uint_mod(op->a, op->b)); + } + return IRGraphMutator::visit(op); +} + +Expr InstructionSelector::visit(const VectorReduce *op) { + return mutate(codegen->split_vector_reduce(op, Expr())); +} + +} // namespace Internal +} // namespace Halide diff --git a/src/InstructionSelector.h b/src/InstructionSelector.h new file mode 100644 index 000000000000..bc7b1541374a --- /dev/null +++ b/src/InstructionSelector.h @@ -0,0 +1,38 @@ +#ifndef HALIDE_INSTRUCTION_SELECTOR_H +#define HALIDE_INSTRUCTION_SELECTOR_H + +/** \file + * Defines a base class for VectorInstruction selection. + */ + +#include "CodeGen_LLVM.h" +#include "IR.h" +#include "IRMutator.h" +#include "Target.h" + +namespace Halide { +namespace Internal { + +/** A base class for vector instruction selection. + * The default implementation lowers int and uint + * div and mod, and splits VectorReduce nodes via + * CodeGen_LLVM::split_vector_reduce(). 
 + */ +class InstructionSelector : public IRGraphMutator { +protected: + const Target &target; + const CodeGen_LLVM *codegen; + + using IRGraphMutator::visit; + Expr visit(const Div *) override; + Expr visit(const Mod *) override; + Expr visit(const VectorReduce *) override; + +public: + InstructionSelector(const Target &target, const CodeGen_LLVM *codegen); +}; + +} // namespace Internal +} // namespace Halide + +#endif diff --git a/src/Lower.cpp b/src/Lower.cpp index 38ad867686e6..f25f209ecea4 100644 --- a/src/Lower.cpp +++ b/src/Lower.cpp @@ -75,6 +75,7 @@ #include "UnsafePromises.h" #include "VectorizeLoops.h" #include "WrapCalls.h" +#include "X86Optimize.h" namespace Halide { namespace Internal { diff --git a/src/ModulusRemainder.cpp b/src/ModulusRemainder.cpp index 34a598e4c7e3..fcce870a5a29 100644 --- a/src/ModulusRemainder.cpp +++ b/src/ModulusRemainder.cpp @@ -74,6 +74,7 @@ class ComputeModulusRemainder : public IRVisitor { void visit(const Free *) override; void visit(const Evaluate *) override; void visit(const Shuffle *) override; + void visit(const VectorInstruction *) override; void visit(const VectorReduce *) override; void visit(const Prefetch *) override; void visit(const Atomic *) override; @@ -213,6 +214,12 @@ void ComputeModulusRemainder::visit(const Shuffle *op) { result = ModulusRemainder{}; } +void ComputeModulusRemainder::visit(const VectorInstruction *op) { + internal_error << "modulus_remainder of VectorInstruction:\n" + << Expr(op) << "\n"; + result = ModulusRemainder{}; +} + void ComputeModulusRemainder::visit(const VectorReduce *op) { internal_assert(op->type.is_scalar()) << "modulus_remainder of vector\n"; result = ModulusRemainder{}; diff --git a/src/Monotonic.cpp b/src/Monotonic.cpp index cec309571aa8..e2718917b83e 100644 --- a/src/Monotonic.cpp +++ b/src/Monotonic.cpp @@ -535,6 +535,11 @@ class DerivativeBounds : public IRVisitor { result = ConstantInterval::single_point(0); } + void visit(const VectorInstruction *op) override { + // 
TODO(rootjalex): Should this be an error? + result = ConstantInterval::everything(); + } + void visit(const VectorReduce *op) override { op->value.accept(this); switch (op->op) { diff --git a/src/Simplify_Exprs.cpp b/src/Simplify_Exprs.cpp index 574754686cc6..0496fb4fc353 100644 --- a/src/Simplify_Exprs.cpp +++ b/src/Simplify_Exprs.cpp @@ -59,6 +59,11 @@ Expr Simplify::visit(const Broadcast *op, ExprInfo *bounds) { } } +Expr Simplify::visit(const VectorInstruction *op, ExprInfo *bounds) { + clear_bounds_info(bounds); + return op; +} + Expr Simplify::visit(const VectorReduce *op, ExprInfo *bounds) { Expr value = mutate(op->value, bounds); diff --git a/src/Simplify_Internal.h b/src/Simplify_Internal.h index a510e5c51f64..1b0258a1d150 100644 --- a/src/Simplify_Internal.h +++ b/src/Simplify_Internal.h @@ -333,6 +333,7 @@ class Simplify : public VariadicVisitor { Expr visit(const Load *op, ExprInfo *bounds); Expr visit(const Call *op, ExprInfo *bounds); Expr visit(const Shuffle *op, ExprInfo *bounds); + Expr visit(const VectorInstruction *op, ExprInfo *bounds); Expr visit(const VectorReduce *op, ExprInfo *bounds); Expr visit(const Let *op, ExprInfo *bounds); Stmt visit(const LetStmt *op); diff --git a/src/StmtToHtml.cpp b/src/StmtToHtml.cpp index 21bc74dd20ac..36db8155c525 100644 --- a/src/StmtToHtml.cpp +++ b/src/StmtToHtml.cpp @@ -712,6 +712,13 @@ class StmtToHtml : public IRVisitor { stream << close_span(); } + void visit(const VectorInstruction *op) override { + stream << open_span("VectorInstruction"); + stream << open_span("Type") << op->type << close_span(); + print_list(symbol("vector_instruction") + "(\"" + op->get_instruction_name() + "\"", op->args, ")"); + stream << close_span(); + } + void visit(const VectorReduce *op) override { stream << open_span("VectorReduce"); stream << open_span("Type") << op->type << close_span(); diff --git a/src/X86Optimize.cpp b/src/X86Optimize.cpp new file mode 100644 index 000000000000..ae5a4424679b --- /dev/null +++ 
 b/src/X86Optimize.cpp @@ -0,0 +1,633 @@ +#include "X86Optimize.h" + +#include "CSE.h" +#include "FindIntrinsics.h" +#include "IR.h" +#include "IRMatch.h" +#include "IRMutator.h" +#include "IROperator.h" +#include "InstructionSelector.h" +#include "Simplify.h" + +namespace Halide { +namespace Internal { + +#if defined(WITH_X86) + +namespace { + +// i32(i16_a)*i32(i16_b) +/- i32(i16_c)*i32(i16_d) can be done by +// interleaving a, c, and b, d, and then using dot_product. +bool should_use_dot_product(const Expr &a, const Expr &b, std::vector<Expr> &result) { + Type t = a.type(); + internal_assert(b.type() == t) << a << " and " << b << " don't match types\n"; + + if (!(t.is_int() && t.bits() == 32 && t.lanes() >= 4)) { + return false; + } + + const Call *ma = Call::as_intrinsic(a, {Call::widening_mul}); + const Call *mb = Call::as_intrinsic(b, {Call::widening_mul}); + // dot_product can't handle mixed type widening muls. + if (ma && ma->args[0].type() != ma->args[1].type()) { + return false; + } + if (mb && mb->args[0].type() != mb->args[1].type()) { + return false; + } + // If the operands are widening shifts, we might be able to treat these as + // multiplies. + const Call *sa = Call::as_intrinsic(a, {Call::widening_shift_left}); + const Call *sb = Call::as_intrinsic(b, {Call::widening_shift_left}); + if (sa && !is_const(sa->args[1])) { + sa = nullptr; + } + if (sb && !is_const(sb->args[1])) { + sb = nullptr; + } + if ((ma || sa) && (mb || sb)) { + Expr a0 = ma ? ma->args[0] : sa->args[0]; + Expr a1 = ma ? ma->args[1] : lossless_cast(sa->args[0].type(), simplify(make_const(sa->type, 1) << sa->args[1])); + Expr b0 = mb ? mb->args[0] : sb->args[0]; + Expr b1 = mb ? 
mb->args[1] : lossless_cast(sb->args[0].type(), simplify(make_const(sb->type, 1) << sb->args[1])); + if (a1.defined() && b1.defined()) { + std::vector args = {a0, a1, b0, b1}; + result.swap(args); + return true; + } + } + return false; +} + +/** A top-down code optimizer that replaces Halide IR with VectorInstructions specific to x86. */ +class Optimize_X86 : public InstructionSelector { +public: + /** Create an x86 code optimizer. Processor features can be + * enabled using the appropriate flags in the target struct. */ + Optimize_X86(const Target &target, const CodeGen_LLVM *codegen) + : InstructionSelector(target, codegen) { + } + + using IRGraphMutator::mutate; + Expr mutate(const Expr &e) override { + Expr expr = IRGraphMutator::mutate(e); + internal_assert(expr.type() == e.type()) << "(X86Optimize) Found type mismatch: " << e << " -> " << expr << "\n"; + return expr; + } + +protected: + bool should_peephole_optimize(const Type &type) { + // We only have peephole optimizations for vectors here. + // FIXME: should we only optimize vectors that are multiples of the native vector width? + // when we do, we fail simd_op_check tests on weird vector sizes. + return type.is_vector(); + } + + using IRGraphMutator::visit; + + /** Nodes for which we want to emit specific sse/avx intrinsics */ + Expr visit(const Add *op) override { + if (!should_peephole_optimize(op->type)) { + return IRGraphMutator::visit(op); + } + + std::vector matches; + // TODO(rootjalex): is it possible to rewrite should_use_dot_product + // as a series of rewrite-rules? lossless_cast is the hardest part. + const int lanes = op->type.lanes(); + + // FIXME: should we check for accumulating dot_products first? + // can there even be overlap between these? + auto rewrite = IRMatcher::rewriter(IRMatcher::add(op->a, op->b), op->type); + if ( + // Only AVX512_SapphireRapids has accumulating dot products. 
+ target.has_feature(Target::AVX512_SapphireRapids) && + ((op->type.element_of() == Int(32)) || + (op->type.element_of() == Float(32))) && + + // Accumulating pmaddubsw + (rewrite( + x + h_add(cast(Int(32, lanes * 4), widening_mul(y, z)), lanes), + v_instr(VectorInstruction::dot_product, x, y, z), + is_uint(y, 8) && is_int(z, 8)) || + + rewrite( + x + h_add(cast(Int(32, lanes * 4), widening_mul(y, z)), lanes), + v_instr(VectorInstruction::dot_product, x, z, y), + is_int(y, 8) && is_uint(z, 8)) || + + rewrite( + h_add(cast(Int(32, lanes * 4), widening_mul(x, y)), lanes) + z, + v_instr(VectorInstruction::dot_product, z, x, y), + is_uint(x, 8) && is_int(y, 8)) || + + rewrite( + h_add(cast(Int(32, lanes * 4), widening_mul(x, y)), lanes) + z, + v_instr(VectorInstruction::dot_product, z, y, x), + is_int(x, 8) && is_uint(y, 8)) || + + // Accumulating pmaddwd. + rewrite( + x + h_add(widening_mul(y, z), lanes), + v_instr(VectorInstruction::dot_product, x, y, z), + is_int(y, 16, lanes * 2) && is_int(z, 16, lanes * 2)) || + + rewrite( + h_add(widening_mul(x, y), lanes) + z, + v_instr(VectorInstruction::dot_product, z, x, y), + is_int(x, 16, lanes * 2) && is_int(y, 16, lanes * 2)) || + + // Accumulating fp dot products. + // TODO(rootjalex): This would be more powerful with lossless_cast checking. 
+ rewrite( + x + h_add(cast(Float(32, lanes * 4), y) * cast(Float(32, lanes * 4), z), lanes), + v_instr(VectorInstruction::dot_product, x, y, z), + is_bfloat(y, 16) && is_bfloat(z, 16)) || + + rewrite( + h_add(cast(Float(32, lanes * 4), x) * cast(Float(32, lanes * 4), y), lanes) + z, + v_instr(VectorInstruction::dot_product, z, x, y), + is_bfloat(x, 16) && is_bfloat(y, 16)) || + + false)) { + return mutate(rewrite.result); + } + + if ((op->type.lanes() % 4 == 0) && should_use_dot_product(op->a, op->b, matches)) { + Expr ac = Shuffle::make_interleave({matches[0], matches[2]}); + Expr bd = Shuffle::make_interleave({matches[1], matches[3]}); + // We have dot_products for every x86 arch (because SSE2 has it), + // so this is `always` safe (as long as the output type lanes has + // a factor of 4). + return mutate(VectorInstruction::make(op->type, VectorInstruction::dot_product, {ac, bd})); + } + + return IRGraphMutator::visit(op); + } + + Expr visit(const Sub *op) override { + if (!should_peephole_optimize(op->type)) { + return IRGraphMutator::visit(op); + } + + std::vector matches; + // TODO(rootjalex): same issue as the Add case, lossless_cast and + // lossless_negate are hard to use in rewrite rules. + + if ((op->type.lanes() % 4 == 0) && should_use_dot_product(op->a, op->b, matches)) { + // Negate one of the factors in the second expression + Expr negative_2 = lossless_negate(matches[2]); + Expr negative_3 = lossless_negate(matches[3]); + if (negative_2.defined() || negative_3.defined()) { + if (negative_2.defined()) { + matches[2] = negative_2; + } else { + matches[3] = negative_3; + } + Expr ac = Shuffle::make_interleave({matches[0], matches[2]}); + Expr bd = Shuffle::make_interleave({matches[1], matches[3]}); + // Always safe, see comment in Add case above. 
+ return mutate(VectorInstruction::make(op->type, VectorInstruction::dot_product, {ac, bd})); + } + } + + return IRGraphMutator::visit(op); + } + + Expr visit(const Cast *op) override { + if (!should_peephole_optimize(op->type)) { + return IRGraphMutator::visit(op); + } + + const int lanes = op->type.lanes(); + + auto rewrite = IRMatcher::rewriter(IRMatcher::cast(op->type, op->value), op->type); + + if ( + // pmulhrs is supported via AVX2 and SSE41, so SSE41 is the LCD. + (target.has_feature(Target::SSE41) && + rewrite( + cast(Int(16, lanes), rounding_shift_right(widening_mul(x, y), 15)), + v_instr(VectorInstruction::pmulhrs, x, y), + is_int(x, 16) && is_int(y, 16))) || + + // f32_to_bf16 is supported only via Target::AVX512_SapphireRapids + (target.has_feature(Target::AVX512_SapphireRapids) && + rewrite( + cast(BFloat(16, lanes), x), + v_instr(VectorInstruction::f32_to_bf16, x), + is_float(x, 32))) || + + false) { + return mutate(rewrite.result); + } + + // TODO: should we handle CodeGen_X86's weird 8 -> 16 bit issue here? + + return IRGraphMutator::visit(op); + } + + Expr visit(const Call *op) override { + if (!should_peephole_optimize(op->type)) { + return IRGraphMutator::visit(op); + } + + // TODO(rootjalex): This optimization is hard to do via a rewrite-rule because of lossless_cast. + + // A 16-bit mul-shift-right of less than 16 can sometimes be rounded up to a + // full 16 to use pmulh(u)w by left-shifting one of the operands. This is + // handled here instead of in the lowering of mul_shift_right because it's + // unlikely to be a good idea on platforms other than x86, as it adds an + // extra shift in the fully-lowered case. 
+ if ((op->type.element_of() == UInt(16) || + op->type.element_of() == Int(16)) && + op->is_intrinsic(Call::mul_shift_right)) { + internal_assert(op->args.size() == 3); + const uint64_t *shift = as_const_uint(op->args[2]); + if (shift && *shift < 16 && *shift >= 8) { + Type narrow = op->type.with_bits(8); + Expr narrow_a = lossless_cast(narrow, op->args[0]); + Expr narrow_b = narrow_a.defined() ? Expr() : lossless_cast(narrow, op->args[1]); + int shift_left = 16 - (int)(*shift); + if (narrow_a.defined()) { + return mutate(mul_shift_right(op->args[0] << shift_left, op->args[1], 16)); + } else if (narrow_b.defined()) { + return mutate(mul_shift_right(op->args[0], op->args[1] << shift_left, 16)); + } + } + } + + const int lanes = op->type.lanes(); + const int bits = op->type.bits(); + + auto rewrite = IRMatcher::rewriter(op, op->type); + using IRMatcher::typed; + + Type unsigned_type = op->type.with_code(halide_type_uint); + auto x_uint = cast(unsigned_type, x); + auto y_uint = cast(unsigned_type, y); + + if ( + // saturating_narrow is always supported (via SSE2) for: + // int32 -> int16, int16 -> int8, int16 -> uint8 + rewrite( + saturating_cast(Int(16, lanes), x), + v_instr(VectorInstruction::saturating_narrow, x), + is_int(x, 32)) || + + rewrite( + saturating_cast(Int(8, lanes), x), + v_instr(VectorInstruction::saturating_narrow, x), + is_int(x, 16)) || + + rewrite( + saturating_cast(UInt(8, lanes), x), + v_instr(VectorInstruction::saturating_narrow, x), + is_int(x, 16)) || + + // int32 -> uint16 is supported via SSE41 + (target.has_feature(Target::SSE41) && + rewrite( + saturating_cast(UInt(16, lanes), x), + v_instr(VectorInstruction::saturating_narrow, x), + is_int(x, 32))) || + + // We can redirect signed rounding halving add to unsigned rounding + // halving add by adding 128 / 32768 to the result if the sign of the + // args differs. 
+ ((op->type.is_int() && bits <= 16) && + rewrite( + rounding_halving_add(x, y), + cast(op->type, rounding_halving_add(x_uint, y_uint) + + ((x_uint ^ y_uint) & (1 << (bits - 1)))))) || + + // On x86, there are many 3-instruction sequences to compute absd of + // unsigned integers. This one consists solely of instructions with + // throughput of 3 ops per cycle on Cannon Lake. + // + // Solution due to Wojciech Mula: + // http://0x80.pl/notesen/2018-03-11-sse-abs-unsigned.html + rewrite( + absd(x, y), + saturating_sub(x, y) | saturating_sub(y, x), + is_uint(x) && is_uint(y)) || + + // Current best way to lower absd on x86. + rewrite( + absd(x, y), + // Cast is a no-op reinterpret. + cast(op->type, max(x, y) - min(x, y)), + is_int(x) && is_int(y)) || + + // pmulh is always supported (via SSE2). + ((op->type.is_int_or_uint() && bits == 16) && + rewrite( + mul_shift_right(x, y, 16), + v_instr(VectorInstruction::pmulh, x, y))) || + + // saturating_pmulhrs is supported via SSE41 + ((target.has_feature(Target::SSE41) && + op->type.is_int() && bits == 16) && + rewrite( + rounding_mul_shift_right(x, y, 15), + // saturating_pmulhrs + select((x == typed(Int(16, lanes), -32768)) && (y == typed(Int(16, lanes), -32768)), + typed(Int(16, lanes), 32767), + v_instr(VectorInstruction::pmulhrs, x, y)))) || + + // int(8 | 16 | 32) -> uint is supported via SSE41 + // float32 is always supported (via SSE2). + (((target.has_feature(Target::SSE41) && op->type.is_int() && bits <= 32) || + (op->type.is_float() && bits == 32)) && + rewrite( + abs(x), + v_instr(VectorInstruction::abs, x))) || + + // saturating ops for 8 and 16 bits are always supported (via SSE2). + ((bits == 8 || bits == 16) && + (rewrite( + saturating_add(x, y), + v_instr(VectorInstruction::saturating_add, x, y)) || + rewrite( + saturating_sub(x, y), + v_instr(VectorInstruction::saturating_sub, x, y)))) || + + // pavg ops for 8 and 16 bits are always supported (via SSE2). 
+ ((op->type.is_uint() && (bits == 8 || bits == 16)) && + rewrite( + rounding_halving_add(x, y), + v_instr(VectorInstruction::rounding_halving_add, x, y))) || + + // int16 -> int32 widening_mul has a (v)pmaddwd implementation. + // always supported (via SSE2). + ((op->type.is_int() && (bits == 32)) && + rewrite( + widening_mul(x, y), + v_instr(VectorInstruction::widening_mul, x, y), + is_int(x, 16) && is_int(y, 16))) || + + (target.has_feature(Target::AVX512_SapphireRapids) && + (op->type.is_int() && (bits == 32)) && + // SapphireRapids accumulating dot products. + (rewrite( + saturating_add(x, h_satadd(cast(Int(32, lanes * 4), widening_mul(y, z)), lanes)), + v_instr(VectorInstruction::saturating_dot_product, x, y, z), + is_uint(y, 8) && is_int(z, 8)) || + + rewrite( + saturating_add(x, h_satadd(cast(Int(32, lanes * 4), widening_mul(y, z)), lanes)), + v_instr(VectorInstruction::saturating_dot_product, x, z, y), + is_int(y, 8) && is_uint(z, 8)) || + + rewrite( + saturating_add(x, h_satadd(cast(Int(32, lanes * 2), widening_mul(y, z)), lanes)), + v_instr(VectorInstruction::saturating_dot_product, x, y, z), + is_uint(y, 8) && is_int(z, 8)) || + + rewrite( + saturating_add(x, h_satadd(cast(Int(32, lanes * 2), widening_mul(y, z)), lanes)), + v_instr(VectorInstruction::saturating_dot_product, x, z, y), + is_int(y, 8) && is_uint(z, 8)) || + + rewrite( + saturating_add(x, h_satadd(widening_mul(y, z), lanes)), + v_instr(VectorInstruction::saturating_dot_product, x, z, y), + is_int(y, 16, lanes * 2) && is_int(z, 16, lanes * 2)) || + + false)) || + + false) { + return mutate(rewrite.result); + } + + // Fixed-point intrinsics should be lowered here. + // This is safe because this mutator is top-down. + // FIXME: Should this be default behavior of the base InstructionSelector class? 
+ if (op->is_intrinsic({ + Call::halving_add, + Call::halving_sub, + Call::mul_shift_right, + Call::rounding_halving_add, + Call::rounding_mul_shift_right, + Call::rounding_shift_left, + Call::rounding_shift_right, + Call::saturating_add, + Call::saturating_sub, + Call::sorted_avg, + Call::widening_add, + Call::widening_mul, + Call::widening_shift_left, + Call::widening_shift_right, + Call::widening_sub, + })) { + return mutate(lower_intrinsic(op)); + } + + return IRGraphMutator::visit(op); + } + + Expr visit(const VectorReduce *op) override { + // FIXME: We need to split up VectorReduce nodes in the same way that + // CodeGen_LLVM::codegen_vector_reduce does, in order to do all + // matching here. + if ((op->op != VectorReduce::Add && op->op != VectorReduce::SaturatingAdd) || + !should_peephole_optimize(op->type)) { + return InstructionSelector::visit(op); + } + + const int lanes = op->type.lanes(); + const int value_lanes = op->value.type().lanes(); + const int factor = value_lanes / lanes; + Expr value = op->value; + + // Useful constants for some of the below rules. + const Expr one_i16 = make_one(Int(16, value_lanes)); + const Expr one_i8 = make_one(Int(8, value_lanes)); + const Expr one_u8 = make_one(UInt(8, value_lanes)); + const Expr zero_i32 = make_zero(Int(32, lanes)); + const Expr zero_f32 = make_zero(Float(32, lanes)); + + switch (op->op) { + case VectorReduce::Add: { + auto rewrite = IRMatcher::rewriter(IRMatcher::h_add(value, lanes), op->type); + auto x_is_int_or_uint = is_int(x) || is_uint(x); + auto y_is_int_or_uint = is_int(y) || is_uint(y); + if ( + // 2-way dot-products, int16 -> int32 is always supported (via SSE2). 
+ ((factor == 2) && + (rewrite( + h_add(cast(Int(32, value_lanes), widening_mul(x, y)), lanes), + v_instr(VectorInstruction::dot_product, cast(Int(16, value_lanes), x), cast(Int(16, value_lanes), y)), + x_is_int_or_uint && y_is_int_or_uint) || + + // Horizontal widening add via pmaddwd + rewrite( + h_add(cast(Int(32, value_lanes), x), lanes), + v_instr(VectorInstruction::dot_product, x, one_i16), + is_int(x, 16)) || + + (rewrite( + h_add(widening_mul(x, y), lanes), + v_instr(VectorInstruction::dot_product, x, y), + is_int(x, 16) && is_int(y, 16))) || + + // pmaddub supported via SSE41 + (target.has_feature(Target::SSE41) && + // Horizontal widening adds using 2-way saturating dot products. + (rewrite( + h_add(cast(UInt(16, value_lanes), x), lanes), + cast(UInt(16, lanes), typed(Int(16, lanes), v_instr(VectorInstruction::saturating_dot_product, x, one_i8))), + is_uint(x, 8)) || + + rewrite( + h_add(cast(Int(16, value_lanes), x), lanes), + v_instr(VectorInstruction::saturating_dot_product, x, one_i8), + is_uint(x, 8)) || + + rewrite( + h_add(cast(Int(16, value_lanes), x), lanes), + v_instr(VectorInstruction::saturating_dot_product, one_u8, x), + is_int(x, 8)) || + + // SSE41 and AVX2 support horizontal_add via phadd intrinsics. + rewrite( + h_add(x, lanes), + v_instr(VectorInstruction::horizontal_add, x), + is_int(x, 16, lanes * 2) || is_uint(x, 16, lanes * 2) || + is_int(x, 32, lanes * 2) || is_uint(x, 32, lanes * 2)) || + + false)) || + false)) || + + // We can use the AVX512_SapphireRapids accumulating dot products + // on pure VectorReduce nodes with 0 as the accumulator. 
+ ((factor == 4) && + target.has_feature(Target::AVX512_SapphireRapids) && + ((op->type.element_of() == Int(32)) || + (op->type.element_of() == Float(32))) && + + // Accumulating pmaddubsw + (rewrite( + h_add(cast(Int(32, lanes * 4), widening_mul(x, y)), lanes), + v_instr(VectorInstruction::dot_product, zero_i32, x, y), + is_uint(x, 8) && is_int(y, 8)) || + + rewrite( + h_add(cast(Int(32, lanes * 4), widening_mul(x, y)), lanes), + v_instr(VectorInstruction::dot_product, zero_i32, y, x), + is_int(x, 8) && is_uint(y, 8)) || + + // Accumulating pmaddwd. + rewrite( + h_add(widening_mul(x, y), lanes), + v_instr(VectorInstruction::dot_product, zero_i32, x, y), + is_int(x, 16, lanes * 2) && is_int(y, 16, lanes * 2)) || + + // Accumulating fp dot products. + // TODO(rootjalex): This would be more powerful with lossless_cast checking. + rewrite( + h_add(cast(Float(32, lanes * 4), x) * cast(Float(32, lanes * 4), y), lanes), + v_instr(VectorInstruction::dot_product, zero_f32, x, y), + is_bfloat(x, 16) && is_bfloat(y, 16)) || + + false)) || + + // psadbw is always supported via SSE2. + ((factor == 8) && + (rewrite( + h_add(cast(UInt(64, value_lanes), absd(x, y)), lanes), + v_instr(VectorInstruction::sum_absd, x, y), + is_uint(x, 8) && is_uint(y, 8)) || + + // Rewrite non-native sum-of-absolute-difference variants to the native + // op. We support reducing to various types. We could consider supporting + // multiple reduction factors too, but in general we don't handle non-native + // reduction factors for VectorReduce nodes (yet?). 
+ rewrite( + h_add(cast(UInt(16, value_lanes), absd(x, y)), lanes), + cast(UInt(16, lanes), typed(UInt(64, lanes), v_instr(VectorInstruction::sum_absd, x, y))), + is_uint(x, 8) && is_uint(y, 8)) || + + rewrite( + h_add(cast(UInt(32, value_lanes), absd(x, y)), lanes), + cast(UInt(32, lanes), typed(UInt(64, lanes), v_instr(VectorInstruction::sum_absd, x, y))), + is_uint(x, 8) && is_uint(y, 8)) || + + rewrite( + h_add(cast(Int(16, value_lanes), absd(x, y)), lanes), + cast(Int(16, lanes), typed(UInt(64, lanes), v_instr(VectorInstruction::sum_absd, x, y))), + is_uint(x, 8) && is_uint(y, 8)) || + + rewrite( + h_add(cast(Int(32, value_lanes), absd(x, y)), lanes), + cast(Int(32, lanes), typed(UInt(64, lanes), v_instr(VectorInstruction::sum_absd, x, y))), + is_uint(x, 8) && is_uint(y, 8)) || + + rewrite( + h_add(cast(Int(64, value_lanes), absd(x, y)), lanes), + cast(Int(64, lanes), typed(UInt(64, lanes), v_instr(VectorInstruction::sum_absd, x, y))), + is_uint(x, 8) && is_uint(y, 8)) || + + false))) { + return mutate(rewrite.result); + } + break; + } + case VectorReduce::SaturatingAdd: { + auto rewrite = IRMatcher::rewriter(IRMatcher::h_satadd(value, lanes), op->type); + if ( + // Saturating dot products are supported via SSE41 and AVX2. 
+ ((factor == 2) && target.has_feature(Target::SSE41) && + (rewrite( + h_satadd(widening_mul(x, y), lanes), + v_instr(VectorInstruction::saturating_dot_product, x, y), + is_uint(x, 8) && is_int(y, 8)) || + + rewrite( + h_satadd(widening_mul(x, y), lanes), + v_instr(VectorInstruction::saturating_dot_product, y, x), + is_int(x, 8) && is_uint(y, 8)) || + + false))) { + return mutate(rewrite.result); + } + break; + } + default: + break; + } + + return InstructionSelector::visit(op); + } + +private: + IRMatcher::Wild<0> x; + IRMatcher::Wild<1> y; + IRMatcher::Wild<2> z; +}; + +} // namespace + +Stmt optimize_x86_instructions(const Stmt &s, const Target &target, const CodeGen_LLVM *codegen) { + Stmt stmt = Optimize_X86(target, codegen).mutate(s); + + // Some of the rules above can introduce repeated sub-terms, so run CSE again. + if (!stmt.same_as(s)) { + stmt = common_subexpression_elimination(stmt); + return stmt; + } else { + return s; + } +} + +#else // WITH_X86 + +Stmt optimize_x86_instructions(const Stmt &s, const Target &t, const CodeGen_LLVM *codegen) { + user_error << "x86 not enabled for this build of Halide.\n"; + return Stmt(); +} + +#endif // WITH_X86 + +} // namespace Internal +} // namespace Halide diff --git a/src/X86Optimize.h b/src/X86Optimize.h new file mode 100644 index 000000000000..9732a2dba545 --- /dev/null +++ b/src/X86Optimize.h @@ -0,0 +1,21 @@ +#ifndef HALIDE_IR_X86_OPTIMIZE_H +#define HALIDE_IR_X86_OPTIMIZE_H + +/** \file + * Tools for optimizing IR for x86. + */ + +#include "CodeGen_LLVM.h" +#include "Expr.h" +#include "Target.h" + +namespace Halide { +namespace Internal { + +/** Perform vector instruction selection, inserting VectorInstruction nodes. 
*/ +Stmt optimize_x86_instructions(const Stmt &stmt, const Target &target, const CodeGen_LLVM *codegen); + +} // namespace Internal +} // namespace Halide + +#endif diff --git a/src/runtime/x86_avx2.ll b/src/runtime/x86_avx2.ll index d4d88be839c6..221d9560502d 100644 --- a/src/runtime/x86_avx2.ll +++ b/src/runtime/x86_avx2.ll @@ -52,16 +52,6 @@ define weak_odr <8 x i32> @abs_i32x8(<8 x i32> %arg) { ret <8 x i32> %3 } -define weak_odr <16 x i16> @saturating_pmulhrswx16(<16 x i16> %a, <16 x i16> %b) nounwind uwtable readnone alwaysinline { - %1 = tail call <16 x i16> @llvm.x86.avx2.pmul.hr.sw(<16 x i16> %a, <16 x i16> %b) - %2 = icmp eq <16 x i16> %a, - %3 = icmp eq <16 x i16> %b, - %4 = and <16 x i1> %2, %3 - %5 = select <16 x i1> %4, <16 x i16> , <16 x i16> %1 - ret <16 x i16> %5 -} -declare <16 x i16> @llvm.x86.avx2.pmul.hr.sw(<16 x i16>, <16 x i16>) nounwind readnone - define weak_odr <16 x i16> @hadd_pmadd_u8_avx2(<32 x i8> %a) nounwind alwaysinline { %1 = tail call <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8> %a, <32 x i8> ) ret <16 x i16> %1 @@ -72,3 +62,34 @@ define weak_odr <16 x i16> @hadd_pmadd_i8_avx2(<32 x i8> %a) nounwind alwaysinli ret <16 x i16> %1 } declare <16 x i16> @llvm.x86.avx2.pmadd.ub.sw(<32 x i8>, <32 x i8>) nounwind readnone + +define weak_odr <16 x i16> @phaddw_avx2(<32 x i16> %a) nounwind alwaysinline { + %1 = shufflevector <32 x i16> %a, <32 x i16> poison, <16 x i32> + %2 = shufflevector <32 x i16> %a, <32 x i16> poison, <16 x i32> + %3 = tail call <16 x i16> @llvm.x86.avx2.phadd.w(<16 x i16> %1, <16 x i16> %2) + ret <16 x i16> %3 + } + declare <16 x i16> @llvm.x86.avx2.phadd.w(<16 x i16>, <16 x i16>) nounwind readnone + + define weak_odr <8 x i32> @phaddd_avx2(<16 x i32> %a) nounwind alwaysinline { + %1 = shufflevector <16 x i32> %a, <16 x i32> poison, <8 x i32> + %2 = shufflevector <16 x i32> %a, <16 x i32> poison, <8 x i32> + %3 = tail call <8 x i32> @llvm.x86.avx2.phadd.d(<8 x i32> %1, <8 x i32> %2) + ret <8 x i32> %3 + } + declare 
<8 x i32> @llvm.x86.avx2.phadd.d(<8 x i32>, <8 x i32>) nounwind readnone + + define weak_odr <8 x i32> @hadd_pmadd_i16_avx2(<16 x i16> %a) nounwind alwaysinline { + %1 = tail call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> %a, <16 x i16> ) + ret <8 x i32> %1 + } + + define weak_odr <8 x i32> @wmul_pmaddwd_avx2(<8 x i16> %a, <8 x i16> %b) nounwind alwaysinline { + %1 = zext <8 x i16> %a to <8 x i32> + %2 = zext <8 x i16> %b to <8 x i32> + %3 = bitcast <8 x i32> %1 to <16 x i16> + %4 = bitcast <8 x i32> %2 to <16 x i16> + %res = call <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16> %3, <16 x i16> %4) + ret <8 x i32> %res + } + declare <8 x i32> @llvm.x86.avx2.pmadd.wd(<16 x i16>, <16 x i16>) nounwind readnone diff --git a/src/runtime/x86_sse41.ll b/src/runtime/x86_sse41.ll index d181de3d67e8..6c7b2356de75 100644 --- a/src/runtime/x86_sse41.ll +++ b/src/runtime/x86_sse41.ll @@ -72,16 +72,6 @@ define weak_odr <4 x i32> @abs_i32x4(<4 x i32> %x) nounwind uwtable readnone alw ret <4 x i32> %3 } -define weak_odr <8 x i16> @saturating_pmulhrswx8(<8 x i16> %a, <8 x i16> %b) nounwind uwtable readnone alwaysinline { - %1 = tail call <8 x i16> @llvm.x86.ssse3.pmul.hr.sw.128(<8 x i16> %a, <8 x i16> %b) - %2 = icmp eq <8 x i16> %a, - %3 = icmp eq <8 x i16> %b, - %4 = and <8 x i1> %2, %3 - %5 = select <8 x i1> %4, <8 x i16> , <8 x i16> %1 - ret <8 x i16> %5 -} -declare <8 x i16> @llvm.x86.ssse3.pmul.hr.sw.128(<8 x i16>, <8 x i16>) nounwind readnone - define weak_odr <8 x i16> @hadd_pmadd_u8_sse3(<16 x i8> %a) nounwind alwaysinline { %1 = tail call <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8> %a, <16 x i8> ) ret <8 x i16> %1 @@ -92,3 +82,19 @@ define weak_odr <8 x i16> @hadd_pmadd_i8_sse3(<16 x i8> %a) nounwind alwaysinlin ret <8 x i16> %1 } declare <8 x i16> @llvm.x86.ssse3.pmadd.ub.sw.128(<16 x i8>, <16 x i8>) nounwind readnone + +define weak_odr <8 x i16> @phaddw_sse3(<16 x i16> %a) nounwind alwaysinline { + %1 = shufflevector <16 x i16> %a, <16 x i16> poison, <8 x i32> + 
%2 = shufflevector <16 x i16> %a, <16 x i16> poison, <8 x i32> + %3 = tail call <8 x i16> @llvm.x86.ssse3.phadd.w.128(<8 x i16> %1, <8 x i16> %2) + ret <8 x i16> %3 + } + declare <8 x i16> @llvm.x86.ssse3.phadd.w.128(<8 x i16>, <8 x i16>) nounwind readnone + + define weak_odr <4 x i32> @phaddd_sse3(<8 x i32> %a) nounwind alwaysinline { + %1 = shufflevector <8 x i32> %a, <8 x i32> poison, <4 x i32> + %2 = shufflevector <8 x i32> %a, <8 x i32> poison, <4 x i32> + %3 = tail call <4 x i32> @llvm.x86.ssse3.phadd.d.128(<4 x i32> %1, <4 x i32> %2) + ret <4 x i32> %3 + } + declare <4 x i32> @llvm.x86.ssse3.phadd.d.128(<4 x i32>, <4 x i32>) nounwind readnone