Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
41 commits
Select commit Hold shift + click to select a range
749cddd
Clean up some pointless code
abadams Sep 13, 2022
23b54ae
Improve comment on Halide::round
abadams Sep 13, 2022
41663c3
Make Halide::round round to even as documented
abadams Sep 13, 2022
6f3b7d4
Explicitly set the rounding mode in the C backend
abadams Sep 13, 2022
053002a
Use rint on ptx, which is documented to round to even
abadams Sep 13, 2022
b3b3685
round to even on win32
abadams Sep 13, 2022
e6e2d84
the nvidia libdevice is buggy for doubles
abadams Sep 13, 2022
083e651
Add missing include to C output
abadams Sep 13, 2022
4195d5f
Fix rounding in opencl
abadams Sep 14, 2022
5d8e7d7
Don't test opencl with doubles if CLDoubles is not enabled
abadams Sep 14, 2022
e429084
Work around hexagon issue
abadams Sep 14, 2022
8f0e387
Don't try to emit roundeven on wasm
abadams Sep 14, 2022
e31746d
wasm doesn't support float16
abadams Sep 15, 2022
7da057b
Add vectorizable lowering for round on platforms without roundeven
abadams Sep 16, 2022
c615490
Use rint on metal for Halide::round
abadams Sep 16, 2022
7e0963f
Make round an intrinsic
abadams Sep 17, 2022
8cf0eb8
Constant-fold round in simplifier
abadams Sep 18, 2022
063a208
d3d12 fix
abadams Sep 18, 2022
b6b4ced
Bounds of Call::round
abadams Sep 18, 2022
aa29466
Teach the mullapudi cost model about round
abadams Sep 18, 2022
3907e98
Handle PureIntrinsics of const args in bounds
abadams Sep 18, 2022
8429b17
scatter, undef, and require aren't pure
abadams Sep 19, 2022
fdc4759
metal doesn't support doubles
abadams Sep 19, 2022
e35c1a7
More parens
abadams Sep 19, 2022
469b0da
Add missing return
abadams Sep 19, 2022
48a78bd
Add vector versions of rint for wasm
abadams Sep 19, 2022
de33792
Use nearbyint for wasm instead of rint
abadams Sep 19, 2022
a542e07
Merge branch 'main' into abadams/fix_round
steven-johnson Sep 19, 2022
c353b00
Merge branch 'main' into abadams/fix_round
steven-johnson Sep 20, 2022
464e47b
Merge branch 'main' into abadams/fix_round
steven-johnson Sep 21, 2022
c426703
revert change to mangling
abadams Sep 21, 2022
3e3944c
Merge branch 'abadams/fix_round' of https://github.com/halide/Halide …
abadams Sep 21, 2022
8d930a0
d3d12 doesn't like double input/output buffers
abadams Sep 22, 2022
7332345
Lower round on arm-32 not-linux
abadams Sep 22, 2022
3815d3f
Don't simplify lowering of round-to-nearest-ties-to-even in codegen
abadams Sep 23, 2022
4e771ca
Fix infinite loop in round lowering on arm-32-notlinux
abadams Sep 23, 2022
1a581fc
Take care to never revisit args in bounds call visitor
abadams Sep 23, 2022
3ad3841
Merge branch 'main' into abadams/fix_round
steven-johnson Sep 23, 2022
66bb8d4
Remove defunct comment
abadams Sep 23, 2022
b8d2e08
Merge branch 'abadams/fix_round' of https://github.com/halide/Halide …
steven-johnson Sep 24, 2022
5c063bb
Merge branch 'main' into abadams/fix_round
steven-johnson Sep 26, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 51 additions & 47 deletions src/Bounds.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1114,19 +1114,44 @@ class Bounds : public IRVisitor {
void visit(const Call *op) override {
TRACK_BOUNDS_INTERVAL;
TRACK_BOUNDS_INFO("name:", op->name);

// Tags are hints that don't affect the results of the expression,
// and can be very deeply nested in the case of strict_float. The
// bounds of this call are *always* exactly that of its first argument,
// so short circuit it here before checking for const_args. This is
// important because evaluating const_args for such a deeply nested case
// essentially becomes O(n^2) doing work that is unnecessary, making
// otherwise simple pipelines take several minutes to compile.
// so short circuit it here.
if (op->is_tag()) {
internal_assert(op->args.size() == 1);
op->args[0].accept(this);
return;
}

// For call nodes, we want to only evaluate the bounds of each arg once, but
// lazily because for many functions we don't need them at all. This class
// helps avoid accidentally revisiting nodes.
class LazyArgBounds {
const vector<Expr> &args;
Bounds *visitor;
vector<Interval> intervals;

public:
LazyArgBounds(const vector<Expr> &args, Bounds *visitor)
: args(args), visitor(visitor) {
}

const Interval &get(int i) {
if (intervals.empty()) {
intervals.resize(args.size(), Interval::nothing());
}
if (intervals[i].is_empty()) {
args[i].accept(visitor);
intervals[i] = visitor->interval;
}
return intervals[i];
}
};

LazyArgBounds arg_bounds(op->args, this);

Type t = op->type.element_of();

if (t.is_handle()) {
Expand All @@ -1136,22 +1161,19 @@ class Bounds : public IRVisitor {

if (!const_bound &&
(op->call_type == Call::PureExtern ||
op->call_type == Call::PureIntrinsic ||
op->call_type == Call::Image)) {

// If the args are const we can return the call of those args
// for pure functions. For other types of functions, the same
// call in two different places might produce different
// results (e.g. during the update step of a reduction), so we
// can't move around call nodes.
//
// Note: Only evaluate new_args if we know the call is a candidate;
// otherwise we can get n^2 evaluation time for deeply-nested
// Expr trees.

std::vector<Expr> new_args(op->args.size());
bool const_args = true;
for (size_t i = 0; i < op->args.size() && const_args; i++) {
op->args[i].accept(this);
const Interval &interval = arg_bounds.get(i);
if (interval.is_single_point()) {
new_args[i] = interval.min;
} else {
Expand All @@ -1168,8 +1190,7 @@ class Bounds : public IRVisitor {
}

if (op->is_intrinsic(Call::abs)) {
op->args[0].accept(this);
Interval a = interval;
Interval a = arg_bounds.get(0);
interval.min = make_zero(t);
if (a.is_bounded()) {
if (equal(a.min, a.max)) {
Expand All @@ -1193,17 +1214,8 @@ class Bounds : public IRVisitor {
} else {
// absd() for int types will always produce a uint result
internal_assert(t.is_uint());

Expr a = op->args[0];
Expr b = op->args[1];
internal_assert(a.type() == b.type());

a.accept(this);
Interval a_interval = interval;

b.accept(this);
Interval b_interval = interval;

Interval a_interval = arg_bounds.get(0);
Interval b_interval = arg_bounds.get(1);
if (a_interval.is_bounded() && b_interval.is_bounded()) {
interval.min = make_zero(t);
interval.max = max(absd(a_interval.max, b_interval.min), absd(a_interval.min, b_interval.max));
Expand All @@ -1221,11 +1233,9 @@ class Bounds : public IRVisitor {
op->is_intrinsic(Call::promise_clamped)) {
// Unlike an explicit clamp, we are also permitted to
// assume the upper bound is greater than the lower bound.
op->args[1].accept(this);
Interval lower = interval;
op->args[2].accept(this);
Interval upper = interval;
op->args[0].accept(this);
Interval lower = arg_bounds.get(1);
Interval upper = arg_bounds.get(2);
interval = arg_bounds.get(0);

if (op->is_intrinsic(Call::promise_clamped) &&
interval.is_single_point()) {
Expand Down Expand Up @@ -1257,11 +1267,9 @@ class Bounds : public IRVisitor {

interval.min = Interval::make_max(interval.min, lower.min);
interval.max = Interval::make_min(interval.max, upper.max);
} else if (Call::as_tag(op)) {
op->args[0].accept(this);
} else if (op->is_intrinsic(Call::return_second)) {
internal_assert(op->args.size() == 2);
op->args[1].accept(this);
interval = arg_bounds.get(1);
} else if (op->is_intrinsic(Call::if_then_else)) {
internal_assert(op->args.size() == 2 || op->args.size() == 3);
// Probably more conservative than necessary
Expand All @@ -1270,17 +1278,15 @@ class Bounds : public IRVisitor {
equivalent_select.accept(this);
} else if (op->is_intrinsic(Call::require)) {
internal_assert(op->args.size() == 3);
op->args[1].accept(this);
interval = arg_bounds.get(1);
} else if (op->is_intrinsic(Call::shift_left) ||
op->is_intrinsic(Call::shift_right) ||
op->is_intrinsic(Call::bitwise_xor) ||
op->is_intrinsic(Call::bitwise_and) ||
op->is_intrinsic(Call::bitwise_or)) {
Expr a = op->args[0], b = op->args[1];
a.accept(this);
Interval a_interval = interval;
b.accept(this);
Interval b_interval = interval;
Interval a_interval = arg_bounds.get(0);
Interval b_interval = arg_bounds.get(1);
if (a_interval.is_single_point(a) && b_interval.is_single_point(b)) {
interval = Interval::single_point(op);
} else if (a_interval.is_single_point() && b_interval.is_single_point()) {
Expand Down Expand Up @@ -1438,8 +1444,7 @@ class Bounds : public IRVisitor {
// In 2's complement bitwise not inverts the ordering of
// the space, without causing overflow (unlike negation),
// so bitwise not is monotonic decreasing.
op->args[0].accept(this);
Interval a_interval = interval;
Interval a_interval = arg_bounds.get(0);
if (a_interval.is_single_point(op->args[0])) {
interval = Interval::single_point(op);
} else if (a_interval.is_single_point()) {
Expand All @@ -1455,12 +1460,13 @@ class Bounds : public IRVisitor {
}
}
}
} else if (op->args.size() == 1 && interval.is_bounded() &&
(op->name == "ceil_f32" || op->name == "ceil_f64" ||
} else if (op->args.size() == 1 &&
(op->is_intrinsic(Call::round) ||
op->name == "ceil_f32" || op->name == "ceil_f64" ||
op->name == "floor_f32" || op->name == "floor_f64" ||
op->name == "round_f32" || op->name == "round_f64" ||
op->name == "exp_f32" || op->name == "exp_f64" ||
op->name == "log_f32" || op->name == "log_f64")) {
op->name == "log_f32" || op->name == "log_f64") &&
(interval = arg_bounds.get(0)).is_bounded()) {
// For monotonic, pure, single-argument functions, we can
// make two calls for the min and the max.
interval = Interval(
Expand Down Expand Up @@ -1490,21 +1496,19 @@ class Bounds : public IRVisitor {
interval = Interval(min, max);
} else if (op->is_intrinsic(Call::memoize_expr)) {
internal_assert(!op->args.empty());
op->args[0].accept(this);
interval = arg_bounds.get(0);
} else if (op->is_intrinsic(Call::scatter_gather)) {
// Take the union of the args
Interval result = Interval::nothing();
for (const Expr &e : op->args) {
e.accept(this);
result.include(interval);
for (size_t i = 0; i < op->args.size(); i++) {
result.include(arg_bounds.get(i));
}
interval = result;
} else if (op->is_intrinsic(Call::mux)) {
// Take the union of all args but the first
Interval result = Interval::nothing();
for (size_t i = 1; i < op->args.size(); i++) {
op->args[i].accept(this);
result.include(interval);
result.include(arg_bounds.get(i));
}
interval = result;
} else if (op->is_intrinsic(Call::widen_right_add)) {
Expand Down
22 changes: 21 additions & 1 deletion src/CodeGen_ARM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,6 +240,13 @@ const ArmIntrinsic intrinsic_defs[] = {
{"llvm.sqrt", "llvm.sqrt", Float(32, 2), "sqrt_f32", {Float(32, 2)}, ArmIntrinsic::HalfWidth},
{"llvm.sqrt", "llvm.sqrt", Float(64, 2), "sqrt_f64", {Float(64, 2)}},

{"llvm.roundeven", "llvm.roundeven", Float(16, 8), "round", {Float(16, 8)}, ArmIntrinsic::RequireFp16},
{"llvm.roundeven", "llvm.roundeven", Float(32, 4), "round", {Float(32, 4)}},
{"llvm.roundeven", "llvm.roundeven", Float(64, 2), "round", {Float(64, 2)}},
{"llvm.roundeven.f16", "llvm.roundeven.f16", Float(16), "round", {Float(16)}, ArmIntrinsic::RequireFp16 | ArmIntrinsic::NoMangle},
{"llvm.roundeven.f32", "llvm.roundeven.f32", Float(32), "round", {Float(32)}, ArmIntrinsic::NoMangle},
{"llvm.roundeven.f64", "llvm.roundeven.f64", Float(64), "round", {Float(64)}, ArmIntrinsic::NoMangle},

// SABD, UABD - Absolute difference
{"vabds", "sabd", UInt(8, 8), "absd", {Int(8, 8), Int(8, 8)}, ArmIntrinsic::HalfWidth},
{"vabdu", "uabd", UInt(8, 8), "absd", {UInt(8, 8), UInt(8, 8)}, ArmIntrinsic::HalfWidth},
Expand Down Expand Up @@ -599,7 +606,6 @@ const std::set<string> float16_native_funcs = {
"is_finite_f16",
"is_inf_f16",
"is_nan_f16",
"round_f16",
"sqrt_f16",
"trunc_f16",
};
Expand Down Expand Up @@ -1139,6 +1145,20 @@ void CodeGen_ARM::visit(const Call *op) {
// We want these as left shifts with a negative b instead.
value = codegen(op->args[0] << simplify(-op->args[1]));
return;
} else if (op->is_intrinsic(Call::round)) {
// llvm's roundeven intrinsic reliably lowers to the correct
// instructions on aarch64, but despite having the same instruction
// available, it doesn't seem to work for arm-32.
if (target.bits == 64) {
value = call_overloaded_intrin(op->type, "round", op->args);
if (value) {
return;
}
} else if (target.os != Target::Linux) {
// Furthermore, roundevenf isn't always in the standard library on arm-32
value = codegen(lower_round_to_nearest_ties_to_even(op->args[0]));
return;
}
}

if (op->type.is_vector()) {
Expand Down
8 changes: 6 additions & 2 deletions src/CodeGen_C.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ const string headers = R"INLINE_CODE(
#include <stdio.h>
#include <string.h>
#include <type_traits>
#include <fenv.h>
)INLINE_CODE";

// We now add definitions of things in the runtime which are
Expand Down Expand Up @@ -109,7 +110,6 @@ inline float log_f32(float x) {return logf(x);}
inline float pow_f32(float x, float y) {return powf(x, y);}
inline float floor_f32(float x) {return floorf(x);}
inline float ceil_f32(float x) {return ceilf(x);}
inline float round_f32(float x) {return nearbyint(x);}

inline double sqrt_f64(double x) {return sqrt(x);}
inline double sin_f64(double x) {return sin(x);}
Expand All @@ -128,7 +128,6 @@ inline double log_f64(double x) {return log(x);}
inline double pow_f64(double x, double y) {return pow(x, y);}
inline double floor_f64(double x) {return floor(x);}
inline double ceil_f64(double x) {return ceil(x);}
inline double round_f64(double x) {return nearbyint(x);}

inline float nan_f32() {return NAN;}
inline float neg_inf_f32() {return -INFINITY;}
Expand Down Expand Up @@ -2383,6 +2382,11 @@ void CodeGen_C::visit(const Call *op) {
create_assertion(op->args[0], op->args[2]);
rhs << print_expr(op->args[1]);
}
} else if (op->is_intrinsic(Call::round)) {
// There's no way to get rounding with ties to nearest even that works
// in all contexts where someone might be compiling generated C++ code,
// so we just lower it into primitive operations.
rhs << print_expr(lower_round_to_nearest_ties_to_even(op->args[0]));
} else if (op->is_intrinsic(Call::abs)) {
internal_assert(op->args.size() == 1);
Expr a0 = op->args[0];
Expand Down
4 changes: 3 additions & 1 deletion src/CodeGen_D3D12Compute_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -378,6 +378,9 @@ void CodeGen_D3D12Compute_Dev::CodeGen_D3D12Compute_C::visit(const Call *op) {
// If we know pow(x, y) is called with x > 0, we can use HLSL's pow
// directly.
stream << "pow(" << print_expr(op->args[0]) << ", " << print_expr(op->args[1]) << ")";
} else if (op->is_intrinsic(Call::round)) {
// HLSL's round intrinsic has the correct semantics for our rounding.
print_assignment(op->type, "round(" + print_expr(op->args[0]) + ")");
} else {
CodeGen_GPU_C::visit(op);
}
Expand Down Expand Up @@ -1311,7 +1314,6 @@ void CodeGen_D3D12Compute_Dev::init_module() {
<< "#define abs_f32 abs \n"
<< "#define floor_f32 floor \n"
<< "#define ceil_f32 ceil \n"
<< "#define round_f32 round \n"
<< "#define trunc_f32 trunc \n"
// pow() in HLSL has the same semantics as C if
// x > 0. Otherwise, we need to emulate C
Expand Down
10 changes: 10 additions & 0 deletions src/CodeGen_GPU_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -199,5 +199,15 @@ void CodeGen_GPU_C::visit(const Shuffle *op) {
}
}

void CodeGen_GPU_C::visit(const Call *op) {
// Metal and OpenCL both provide a polymorphic "rint" whose semantics match
// Halide's round intrinsic. GLSL is different ("roundEven") and handles
// this intrinsic in its own backend, not here.
if (!op->is_intrinsic(Call::round)) {
CodeGen_C::visit(op);
return;
}
print_assignment(op->type, "rint(" + print_expr(op->args[0]) + ")");
}

} // namespace Internal
} // namespace Halide
1 change: 1 addition & 0 deletions src/CodeGen_GPU_Dev.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ class CodeGen_GPU_C : public CodeGen_C {
protected:
using CodeGen_C::visit;
void visit(const Shuffle *op) override;
void visit(const Call *op) override;

VectorDeclarationStyle vector_declaration_style = VectorDeclarationStyle::CLikeSyntax;
};
Expand Down
29 changes: 29 additions & 0 deletions src/CodeGen_Internal.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -523,6 +523,35 @@ Expr lower_mux(const Call *mux) {
return equiv;
}

// An implementation of rounding to the nearest integer with ties to even, used
// to lower Halide::round. Written to avoid all use of C standard library
// functions so that it's cleanly vectorizable and a safe fallback on all
// platforms.
Expr lower_round_to_nearest_ties_to_even(const Expr &x) {
// Unsigned and signed integer types with the same bit width as x, used for
// bit-level manipulation of the float.
Type bits_type = x.type().with_code(halide_type_uint);
Type int_type = x.type().with_code(halide_type_int);

// Make one half with the same sign as x, by OR-ing the sign bit of x into
// the bit pattern of 0.5.
Expr sign_bit = reinterpret(bits_type, x) & (cast(bits_type, 1) << (x.type().bits() - 1));
Expr one_half = reinterpret(bits_type, cast(x.type(), 0.5f)) | sign_bit;
// Decrementing the bit pattern gives the adjacent representable float: the
// value of the same sign whose magnitude is just below one half.
Expr just_under_one_half = reinterpret(x.type(), one_half - 1);
one_half = reinterpret(x.type(), one_half);
// Do the same for the constant one.
Expr one = reinterpret(bits_type, cast(x.type(), 1)) | sign_bit;
// Round to nearest, with ties going towards zero
Expr ix = cast(int_type, x + just_under_one_half);
Expr a = cast(x.type(), ix);
// Get the residual
Expr diff = a - x;
// Make a mask of all ones if the result is odd
Expr odd = -cast(bits_type, ix & 1);
// Make a mask of all ones if the result was a tie
// (NOTE: this detects the case where |a| ended up exactly one half above
// |x|, i.e. the truncation landed on the far side of a tie).
Expr tie = select(diff == one_half, cast(bits_type, -1), cast(bits_type, 0));
// If it was a tie, and the result is odd, we should have rounded in the
// other direction.
Expr correction = reinterpret(x.type(), odd & tie & one);
return common_subexpression_elimination(a - correction);
}

bool get_md_bool(llvm::Metadata *value, bool &result) {
if (!value) {
return false;
Expand Down
4 changes: 4 additions & 0 deletions src/CodeGen_Internal.h
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,10 @@ Expr lower_extract_bits(const Call *c);
Expr lower_concat_bits(const Call *c);
///@}

/** A vectorizable implementation of Halide::round that doesn't depend on any
* standard library being present. */
Expr lower_round_to_nearest_ties_to_even(const Expr &);

/** Given an llvm::Module, set llvm:TargetOptions information */
void get_target_options(const llvm::Module &module, llvm::TargetOptions &options);

Expand Down
2 changes: 2 additions & 0 deletions src/CodeGen_LLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2758,6 +2758,8 @@ void CodeGen_LLVM::visit(const Call *op) {

value = phi;
}
} else if (op->is_intrinsic(Call::round)) {
value = codegen(lower_round_to_nearest_ties_to_even(op->args[0]));
} else if (op->is_intrinsic(Call::require)) {
internal_assert(op->args.size() == 3);
Expr cond = op->args[0];
Expand Down
1 change: 0 additions & 1 deletion src/CodeGen_Metal_Dev.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -763,7 +763,6 @@ void CodeGen_Metal_Dev::init_module() {
<< "#define abs_f32 fabs\n"
<< "#define floor_f32 floor\n"
<< "#define ceil_f32 ceil\n"
<< "#define round_f32 round\n"
<< "#define trunc_f32 trunc\n"
<< "#define pow_f32 pow\n"
<< "#define asin_f32 asin\n"
Expand Down
Loading