Skip to content

Simplify codegen of mixed-type checked integer addition and subtraction #15878

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
158 changes: 118 additions & 40 deletions src/compiler/crystal/codegen/primitives.cr
Original file line number Diff line number Diff line change
Expand Up @@ -135,8 +135,8 @@ class Crystal::CodeGenVisitor
end

case op
when "+", "-", "*"
return codegen_binary_op_with_overflow(op, t1, t2, p1, p2)
when "+", "-" then return codegen_addsub_with_overflow(op, t1, t2, p1, p2)
when "*" then return codegen_mul_with_overflow(t1, t2, p1, p2)
end

tmax, p1, p2 = codegen_binary_extend_int(t1, t2, p1, p2)
Expand All @@ -156,48 +156,92 @@ class Crystal::CodeGenVisitor
end
end

def codegen_binary_op_with_overflow(op, t1, t2, p1, p2)
if op == "*"
if t1.unsigned? && t2.signed?
return codegen_mul_unsigned_signed_with_overflow(t1, t2, p1, p2)
elsif t1.signed? && t2.unsigned?
return codegen_mul_signed_unsigned_with_overflow(t1, t2, p1, p2)
end
def codegen_addsub_with_overflow(op, t1, t2, p1, p2)
  # Operands that already agree on signedness need no preprocessing.
  if t1.signed? == t2.signed?
    return codegen_addsub_same_signedness_with_overflow(op, t1, t2, p1, p2)
  end

  # Mixed signedness: reinterpret `p1` as the integer type of the same width
  # but opposite signedness, shifted by half of the integer range (added when
  # `p1` is unsigned, subtracted when it is signed). In two's complement that
  # shift is the same as flipping the sign bit, so a single XOR does both the
  # reinterpretation and the bias:
  #
  # ```
  # # `t1` signed:
  # biased = (p1 ^ t1::MIN).to_unsigned!
  # (biased <op> p2).to_signed! ^ t1::MIN
  #
  # # `t1` unsigned:
  # mask = typeof(p1.to_signed!)::MIN
  # biased = p1.to_signed! ^ mask
  # ((biased <op> p2) ^ mask).to_unsigned!
  # ```
  flipped_type = @program.int_type(!t1.signed?, t1.bytes)

  # The mask is the minimum of whichever of the two same-width types is the
  # signed one, materialized as a constant of type `t1`.
  signed_variant = t1.signed? ? t1 : flipped_type
  min_value, _ = signed_variant.range
  sign_mask = int(min_value, t1)

  lhs_biased = builder.xor(p1, sign_mask)

  # With the bias applied, overflow happens exactly when the corresponding
  # same-signedness operation overflows.
  biased_result = codegen_addsub_same_signedness_with_overflow(op, flipped_type, t2, lhs_biased, p2)

  # Flip the sign bit back to undo the bias on the result.
  builder.xor(biased_result, sign_mask)
end

calc_signed = t1.signed? || t2.signed?
calc_width = {t1, t2}.max_of { |t| t.bytes * 8 + ((calc_signed && t.unsigned?) ? 1 : 0) }
calc_type = llvm_context.int(calc_width)
def codegen_addsub_same_signedness_with_overflow(op, t1, t2, p1, p2)
if t2.bytes > t1.bytes
if t2.signed?
# e.g. Int8+Int16
# t1.new(t2.new!(p1) &+ p2)
p1 = extend_int(t1, t2, p1)

e1 = t1.signed? ? builder.sext(p1, calc_type) : builder.zext(p1, calc_type)
e2 = t2.signed? ? builder.sext(p2, calc_type) : builder.zext(p2, calc_type)
# use unchecked arithmetic; the signed overflow here cannot result in
# any value that fits into `t1`'s range
result = op == "+" ? builder.add(p1, p2) : builder.sub(p1, p2)

llvm_op =
case {calc_signed, op}
when {false, "+"} then "uadd"
when {false, "-"} then "usub"
when {false, "*"} then "umul"
when {true, "+"} then "sadd"
when {true, "-"} then "ssub"
when {true, "*"} then "smul"
else raise "BUG: unknown overflow op"
# catch the overflow via truncation instead
codegen_convert(t2, t1, result, checked: true)
else
# e.g. UInt8+UInt16
# p1 + t1.new!(p2)
p2_trunc = trunc(p2, llvm_type(t1))
result, overflow = call_binary_overflow_fun op, t1, p1, p2_trunc
codegen_raise_overflow_cond overflow

# if `p2` is outside `t1`'s range, any addition or subtraction must
# overflow regardless of `p1`'s value
_, max = t1.range
p2_too_large = builder.icmp(LLVM::IntPredicate::UGT, p2, int(max, t2))
codegen_raise_overflow_cond p2_too_large

result
end
else
# e.g. Int8+Int8, UInt8+UInt8, Int16+Int8, UInt16+UInt8
if t2.bytes < t1.bytes
# p1 + t1.new!(p2)
p2 = extend_int(t2, t1, p2)
end

llvm_fun = binary_overflow_fun "llvm.#{llvm_op}.with.overflow.i#{calc_width}", calc_type
res_with_overflow = builder.call(llvm_fun.type, llvm_fun.func, [e1, e2])

result = extract_value res_with_overflow, 0
overflow = extract_value res_with_overflow, 1

if calc_width > t1.bytes * 8
result_trunc = trunc result, llvm_type(t1)
result_trunc_ext = t1.signed? ? builder.sext(result_trunc, calc_type) : builder.zext(result_trunc, calc_type)
overflow = or(overflow, builder.icmp LLVM::IntPredicate::NE, result, result_trunc_ext)
result, overflow = call_binary_overflow_fun op, t1, p1, p2
codegen_raise_overflow_cond overflow
result
end
end

codegen_raise_overflow_cond overflow

trunc result, llvm_type(t1)
# Dispatches a checked multiplication to the helper matching the signedness
# combination of the two operand types.
def codegen_mul_with_overflow(t1, t2, p1, p2)
  # unsigned * signed
  return codegen_mul_unsigned_signed_with_overflow(t1, t2, p1, p2) if t1.unsigned? && t2.signed?

  # signed * unsigned
  return codegen_mul_signed_unsigned_with_overflow(t1, t2, p1, p2) if t1.signed? && t2.unsigned?

  # both operands share the same signedness
  codegen_mul_same_signedness_with_overflow(t1, t2, p1, p2)
end

def codegen_mul_unsigned_signed_with_overflow(t1, t2, p1, p2)
Expand All @@ -207,7 +251,7 @@ class Crystal::CodeGenVisitor
)
codegen_raise_overflow_cond overflow

codegen_binary_op_with_overflow("*", t1, @program.int_type(false, t2.bytes), p1, p2)
codegen_mul_same_signedness_with_overflow(t1, @program.int_type(false, t2.bytes), p1, p2)
end

def codegen_mul_signed_unsigned_with_overflow(t1, t2, p1, p2)
Expand All @@ -218,7 +262,7 @@ class Crystal::CodeGenVisitor

# tmp is the abs value of the result
# there is overflow when |result| > max + (negative ? 1 : 0)
tmp = codegen_binary_op_with_overflow("*", u1, t2, abs, p2)
tmp = codegen_mul_same_signedness_with_overflow(u1, t2, abs, p2)
_, max = t1.range
max_result = builder.add(int(max, t1), builder.zext(negative, llvm_type(t1)))
overflow = codegen_binary_op_gt(u1, u1, tmp, max_result)
Expand All @@ -229,6 +273,22 @@ class Crystal::CodeGenVisitor
builder.select negative, minus_tmp, tmp
end

# Emits a checked multiplication through the `*` overflow intrinsic on the
# type returned by `codegen_binary_extend_int`, raises on overflow, and
# returns the product truncated to `t1`'s LLVM type.
def codegen_mul_same_signedness_with_overflow(t1, t2, p1, p2)
  wide_type, lhs, rhs = codegen_binary_extend_int(t1, t2, p1, p2)

  product, overflow = call_binary_overflow_fun "*", wide_type, lhs, rhs

  if wide_type.bytes > t1.bytes
    # The multiplication ran in a type wider than `t1`: the product also
    # overflows when truncating to `t1` and extending back changes its value.
    narrowed = trunc product, llvm_type(t1)
    round_trip = extend_int(t1, wide_type, narrowed)
    differs = builder.icmp(LLVM::IntPredicate::NE, product, round_trip)
    overflow = or(overflow, differs)
  end

  codegen_raise_overflow_cond overflow

  trunc product, llvm_type(t1)
end

def codegen_binary_extend_int(t1, t2, p1, p2)
if t1.normal_rank == t2.normal_rank
# Nothing to do
Expand Down Expand Up @@ -369,13 +429,31 @@ class Crystal::CodeGenVisitor
position_at_end op_normal
end

private def binary_overflow_fun(fun_name, llvm_operand_type)
fetch_typed_fun(@llvm_mod, fun_name) do
# Calls the `llvm.*.with.overflow` intrinsic selected by `op` and the
# signedness of `t` on the operands `p1` and `p2`.
#
# Returns a `{result, overflow}` tuple holding the two values extracted from
# the struct returned by the intrinsic.
private def call_binary_overflow_fun(op, t, p1, p2)
  intrinsic =
    case {t.signed?, op}
    when {false, "+"} then "uadd"
    when {false, "-"} then "usub"
    when {false, "*"} then "umul"
    when {true, "+"}  then "sadd"
    when {true, "-"}  then "ssub"
    when {true, "*"}  then "smul"
    else                   raise "BUG: unknown overflow op"
    end

  bit_width = t.bytes * 8
  operand_type = llvm_type(t)

  # Signature: (T, T) -> {T, i1}; the block only supplies the function type
  # when `fetch_typed_fun` asks for it.
  overflow_fun = fetch_typed_fun(@llvm_mod, "llvm.#{intrinsic}.with.overflow.i#{bit_width}") do
    LLVM::Type.function(
      [operand_type, operand_type],
      @llvm_context.struct([operand_type, @llvm_context.int1]),
    )
  end

  call_result = builder.call(overflow_fun.type, overflow_fun.func, [p1, p2])

  # Field 0 carries the arithmetic result, field 1 the overflow flag.
  {extract_value(call_result, 0), extract_value(call_result, 1)}
end

private def llvm_expect_i1_fun
Expand Down
Loading