Skip to content

Commit cd78c4b

Browse files
authored
Simplify codegen of mixed-type checked integer addition and subtraction (#15878)
Code generation for mixed-type uses of overflow-checked primitive integer addition or subtraction is currently somewhat convoluted, in particular when the receiver is signed and the argument unsigned, or vice versa. The compiler would create an intermediate integer type that has one more bit than the operands, e.g. `i9` or `i33`; most LLVM targets do not like these types, and this leads to [some rather unsightly LLVM IR that is hard to optimize](https://godbolt.org/z/h7EP71W51).

This PR is a complete overhaul of the way these additions and subtractions are emitted. There are now multiple specialized code paths depending on whether the two operand types have the same signedness or width. The following Crystal snippet illustrates how each code path could be expressed equivalently in terms of other primitive calls in native Crystal:

```crystal
fun i8_add_u8(p1 : Int8, p2 : UInt8) : Int8
  p1_biased = (p1 ^ Int8::MIN).to_u8!
  result = p1_biased + p2 # same-type, checked
  result.to_i8! ^ Int8::MIN
end

fun u16_add_i8(p1 : UInt16, p2 : Int8) : UInt16
  p1_biased = p1.to_i16! ^ Int16::MIN
  result = i16_add_i8(p1_biased, p2) # checked, see below
  (result ^ Int16::MIN).to_u16!
end

fun i8_add_u16(p1 : Int8, p2 : UInt16) : Int8
  p1_biased = (p1 ^ Int8::MIN).to_u8!
  result = u8_add_u16(p1_biased, p2) # checked, see below
  result.to_i8! ^ Int8::MIN
end

fun i8_add_i16(p1 : Int8, p2 : Int16) : Int8
  p1_ext = p1.to_i16!
  result = p1_ext &+ p2
  result.to_i8 # checked
end

# the actual optimal call sequence is slightly different,
# probably due to some short-circuit evaluation issue
fun u8_add_u16(p1 : UInt8, p2 : UInt16) : UInt8
  p2_trunc = p2.to_u8 # checked
  p1 + p2_trunc # same-type, checked
end

fun i16_add_i8(p1 : Int16, p2 : Int8) : Int16
  p2_ext = p2.to_i16!
  p1 + p2_ext # same-type, checked
end
```

([Before](https://godbolt.org/z/b5vdnscnK) and [after](https://godbolt.org/z/qa5avE9cW) on Compiler Explorer)

The gist here is that mixed-signedness operations are transformed into same-signedness ones by applying a bias to the first operand and switching its signedness, using a bitwise XOR. For example, `-0x80_i8..0x7F_i8` maps linearly to `0x00_u8..0xFF_u8`, and vice versa. The same-signedness arithmetic operation that follows will overflow if and only if the original operation does. The result is XOR'ed with the same bias again afterwards, since applying the bias is its own inverse. This is the trick that allows mixed-type addition or subtraction without resorting to arcane integer bit widths.
1 parent 5acf3d4 commit cd78c4b

File tree

1 file changed

+118
-40
lines changed

1 file changed

+118
-40
lines changed

src/compiler/crystal/codegen/primitives.cr

Lines changed: 118 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -135,8 +135,8 @@ class Crystal::CodeGenVisitor
135135
end
136136

137137
case op
138-
when "+", "-", "*"
139-
return codegen_binary_op_with_overflow(op, t1, t2, p1, p2)
138+
when "+", "-" then return codegen_addsub_with_overflow(op, t1, t2, p1, p2)
139+
when "*" then return codegen_mul_with_overflow(t1, t2, p1, p2)
140140
end
141141

142142
tmax, p1, p2 = codegen_binary_extend_int(t1, t2, p1, p2)
@@ -156,48 +156,92 @@ class Crystal::CodeGenVisitor
156156
end
157157
end
158158

159-
def codegen_binary_op_with_overflow(op, t1, t2, p1, p2)
160-
if op == "*"
161-
if t1.unsigned? && t2.signed?
162-
return codegen_mul_unsigned_signed_with_overflow(t1, t2, p1, p2)
163-
elsif t1.signed? && t2.unsigned?
164-
return codegen_mul_signed_unsigned_with_overflow(t1, t2, p1, p2)
165-
end
159+
def codegen_addsub_with_overflow(op, t1, t2, p1, p2)
160+
if t1.signed? != t2.signed?
161+
# Convert `p1` to the opposite signedness, while simultaneously applying
162+
# a bias equal to half the integer range: add bias if `p1` is unsigned,
163+
# subtract bias if `p1` is signed, which for two's complement is
164+
# equivalent to a bitwise XOR on the sign bit. Thus, if `t1` is signed:
165+
#
166+
# ```
167+
# p1_biased = (p1 ^ t1::MIN).to_unsigned!
168+
# result_biased = ...(p1_biased, p2)
169+
# result_biased.to_signed! ^ t1::MIN
170+
# ```
171+
#
172+
# If `t1` is unsigned:
173+
#
174+
# ```
175+
# bias = typeof(p1.to_signed!)::MIN
176+
# p1_biased = p1.to_signed! ^ bias
177+
# result_biased = ...(p1_biased, p2)
178+
# (result_biased ^ bias).to_unsigned!
179+
# ```
180+
t1_biased = @program.int_type(!t1.signed?, t1.bytes)
181+
sign_bit, _ = (t1.signed? ? t1 : t1_biased).range
182+
bias = int(sign_bit, t1)
183+
p1_biased = builder.xor(p1, bias)
184+
185+
# now the overflow criterion is identical to that of the respective
186+
# same-signedness operation
187+
result_biased = codegen_addsub_same_signedness_with_overflow(op, t1_biased, t2, p1_biased, p2)
188+
189+
# revert the bias
190+
builder.xor(result_biased, bias)
191+
else
192+
codegen_addsub_same_signedness_with_overflow(op, t1, t2, p1, p2)
166193
end
194+
end
167195

168-
calc_signed = t1.signed? || t2.signed?
169-
calc_width = {t1, t2}.max_of { |t| t.bytes * 8 + ((calc_signed && t.unsigned?) ? 1 : 0) }
170-
calc_type = llvm_context.int(calc_width)
196+
def codegen_addsub_same_signedness_with_overflow(op, t1, t2, p1, p2)
197+
if t2.bytes > t1.bytes
198+
if t2.signed?
199+
# e.g. Int8+Int16
200+
# t1.new(t2.new!(p1) &+ p2)
201+
p1 = extend_int(t1, t2, p1)
171202

172-
e1 = t1.signed? ? builder.sext(p1, calc_type) : builder.zext(p1, calc_type)
173-
e2 = t2.signed? ? builder.sext(p2, calc_type) : builder.zext(p2, calc_type)
203+
# use unchecked arithmetic; the signed overflow here cannot result in
204+
# any value that fits into `t1`'s range
205+
result = op == "+" ? builder.add(p1, p2) : builder.sub(p1, p2)
174206

175-
llvm_op =
176-
case {calc_signed, op}
177-
when {false, "+"} then "uadd"
178-
when {false, "-"} then "usub"
179-
when {false, "*"} then "umul"
180-
when {true, "+"} then "sadd"
181-
when {true, "-"} then "ssub"
182-
when {true, "*"} then "smul"
183-
else raise "BUG: unknown overflow op"
207+
# catch the overflow via truncation instead
208+
codegen_convert(t2, t1, result, checked: true)
209+
else
210+
# e.g. UInt8+UInt16
211+
# p1 + t1.new!(p2)
212+
p2_trunc = trunc(p2, llvm_type(t1))
213+
result, overflow = call_binary_overflow_fun op, t1, p1, p2_trunc
214+
codegen_raise_overflow_cond overflow
215+
216+
# if `p2` is outside `t1`'s range, any addition or subtraction must
217+
# overflow regardless of `p1`'s value
218+
_, max = t1.range
219+
p2_too_large = builder.icmp(LLVM::IntPredicate::UGT, p2, int(max, t2))
220+
codegen_raise_overflow_cond p2_too_large
221+
222+
result
223+
end
224+
else
225+
# e.g. Int8+Int8, UInt8+UInt8, Int16+Int8, UInt16+UInt8
226+
if t2.bytes < t1.bytes
227+
# p1 + t1.new!(p2)
228+
p2 = extend_int(t2, t1, p2)
184229
end
185230

186-
llvm_fun = binary_overflow_fun "llvm.#{llvm_op}.with.overflow.i#{calc_width}", calc_type
187-
res_with_overflow = builder.call(llvm_fun.type, llvm_fun.func, [e1, e2])
188-
189-
result = extract_value res_with_overflow, 0
190-
overflow = extract_value res_with_overflow, 1
191-
192-
if calc_width > t1.bytes * 8
193-
result_trunc = trunc result, llvm_type(t1)
194-
result_trunc_ext = t1.signed? ? builder.sext(result_trunc, calc_type) : builder.zext(result_trunc, calc_type)
195-
overflow = or(overflow, builder.icmp LLVM::IntPredicate::NE, result, result_trunc_ext)
231+
result, overflow = call_binary_overflow_fun op, t1, p1, p2
232+
codegen_raise_overflow_cond overflow
233+
result
196234
end
235+
end
197236

198-
codegen_raise_overflow_cond overflow
199-
200-
trunc result, llvm_type(t1)
237+
def codegen_mul_with_overflow(t1, t2, p1, p2)
238+
if t1.unsigned? && t2.signed?
239+
codegen_mul_unsigned_signed_with_overflow(t1, t2, p1, p2)
240+
elsif t1.signed? && t2.unsigned?
241+
codegen_mul_signed_unsigned_with_overflow(t1, t2, p1, p2)
242+
else
243+
codegen_mul_same_signedness_with_overflow(t1, t2, p1, p2)
244+
end
201245
end
202246

203247
def codegen_mul_unsigned_signed_with_overflow(t1, t2, p1, p2)
@@ -207,7 +251,7 @@ class Crystal::CodeGenVisitor
207251
)
208252
codegen_raise_overflow_cond overflow
209253

210-
codegen_binary_op_with_overflow("*", t1, @program.int_type(false, t2.bytes), p1, p2)
254+
codegen_mul_same_signedness_with_overflow(t1, @program.int_type(false, t2.bytes), p1, p2)
211255
end
212256

213257
def codegen_mul_signed_unsigned_with_overflow(t1, t2, p1, p2)
@@ -218,7 +262,7 @@ class Crystal::CodeGenVisitor
218262

219263
# tmp is the abs value of the result
220264
# there is overflow when |result| > max + (negative ? 1 : 0)
221-
tmp = codegen_binary_op_with_overflow("*", u1, t2, abs, p2)
265+
tmp = codegen_mul_same_signedness_with_overflow(u1, t2, abs, p2)
222266
_, max = t1.range
223267
max_result = builder.add(int(max, t1), builder.zext(negative, llvm_type(t1)))
224268
overflow = codegen_binary_op_gt(u1, u1, tmp, max_result)
@@ -229,6 +273,22 @@ class Crystal::CodeGenVisitor
229273
builder.select negative, minus_tmp, tmp
230274
end
231275

276+
def codegen_mul_same_signedness_with_overflow(t1, t2, p1, p2)
277+
tmax, p1, p2 = codegen_binary_extend_int(t1, t2, p1, p2)
278+
279+
result, overflow = call_binary_overflow_fun "*", tmax, p1, p2
280+
281+
if tmax.bytes > t1.bytes
282+
result_trunc = trunc result, llvm_type(t1)
283+
result_trunc_ext = extend_int(t1, tmax, result_trunc)
284+
overflow = or(overflow, builder.icmp LLVM::IntPredicate::NE, result, result_trunc_ext)
285+
end
286+
287+
codegen_raise_overflow_cond overflow
288+
289+
trunc result, llvm_type(t1)
290+
end
291+
232292
def codegen_binary_extend_int(t1, t2, p1, p2)
233293
if t1.normal_rank == t2.normal_rank
234294
# Nothing to do
@@ -369,13 +429,31 @@ class Crystal::CodeGenVisitor
369429
position_at_end op_normal
370430
end
371431

372-
private def binary_overflow_fun(fun_name, llvm_operand_type)
373-
fetch_typed_fun(@llvm_mod, fun_name) do
432+
private def call_binary_overflow_fun(op, t, p1, p2)
433+
llvm_op =
434+
case {t.signed?, op}
435+
when {false, "+"} then "uadd"
436+
when {false, "-"} then "usub"
437+
when {false, "*"} then "umul"
438+
when {true, "+"} then "sadd"
439+
when {true, "-"} then "ssub"
440+
when {true, "*"} then "smul"
441+
else raise "BUG: unknown overflow op"
442+
end
443+
444+
fun_name = "llvm.#{llvm_op}.with.overflow.i#{t.bytes * 8}"
445+
446+
llvm_operand_type = llvm_type(t)
447+
llvm_fun = fetch_typed_fun(@llvm_mod, fun_name) do
374448
LLVM::Type.function(
375449
[llvm_operand_type, llvm_operand_type],
376450
@llvm_context.struct([llvm_operand_type, @llvm_context.int1]),
377451
)
378452
end
453+
res_with_overflow = builder.call(llvm_fun.type, llvm_fun.func, [p1, p2])
454+
result = extract_value res_with_overflow, 0
455+
overflow = extract_value res_with_overflow, 1
456+
{result, overflow}
379457
end
380458

381459
private def llvm_expect_i1_fun

0 commit comments

Comments
 (0)