Skip to content

Commit 2c73f83

Browse files
Merge pull request from GHSA-7vrm-3jc8-5wwm
* add more tests for string comparison explicitly test the codepath with <= 32 bytes * refactor keccak256 helper a bit * fix bytestring equality existing bytestring equality checks do not check length equality or for dirty bytes.
1 parent 0807a60 commit 2c73f83

File tree

5 files changed

+159
-104
lines changed

5 files changed

+159
-104
lines changed

tests/parser/functions/test_slice.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,36 @@ def ret10_slice() -> Bytes[10]:
210210
assert c.ret10_slice() == b"A"
211211

212212

213+
def test_slice_equality(get_contract):
214+
# test for equality with dirty bytes
215+
code = """
216+
@external
217+
def assert_eq() -> bool:
218+
dirty_bytes: String[4] = "abcd"
219+
dirty_bytes = slice(dirty_bytes, 0, 3)
220+
clean_bytes: String[4] = "abc"
221+
return dirty_bytes == clean_bytes
222+
"""
223+
224+
c = get_contract(code)
225+
assert c.assert_eq()
226+
227+
228+
def test_slice_inequality(get_contract):
229+
# test for equality with dirty bytes
230+
code = """
231+
@external
232+
def assert_ne() -> bool:
233+
dirty_bytes: String[4] = "abcd"
234+
dirty_bytes = slice(dirty_bytes, 0, 3)
235+
clean_bytes: String[4] = "abcd"
236+
return dirty_bytes != clean_bytes
237+
"""
238+
239+
c = get_contract(code)
240+
assert c.assert_ne()
241+
242+
213243
def test_slice_convert(get_contract):
214244
# test slice of converting between bytes32 and Bytes
215245
code = """

tests/parser/types/test_string.py

Lines changed: 101 additions & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -168,120 +168,165 @@ def test(a: uint256, b: String[50] = "foo") -> Bytes[100]:
168168
assert c.test(12345, "bar")[-3:] == b"bar"
169169

170170

171-
def test_string_equality(get_contract_with_gas_estimation):
172-
code = """
173-
_compA: String[100]
174-
_compB: String[100]
171+
string_equality_tests = [
172+
(
173+
100,
174+
"The quick brown fox jumps over the lazy dog",
175+
"The quick brown fox jumps over the lazy hog",
176+
),
177+
# check <= 32 codepath
178+
(32, "abc", "abc\0"),
179+
(32, "abc", "abc\1"), # use a_init dirty bytes
180+
(32, "abc\2", "abc"), # use b_init dirty bytes
181+
(32, "", "\0"),
182+
(32, "", "\1"),
183+
(33, "", "\1"),
184+
(33, "", "\0"),
185+
]
186+
187+
188+
@pytest.mark.parametrize("len_,a,b", string_equality_tests)
189+
def test_string_equality(get_contract_with_gas_estimation, len_, a, b):
190+
# fixtures to initialize strings with dirty bytes
191+
a_init = "\\1" * len_
192+
b_init = "\\2" * len_
193+
string1 = a.encode("unicode_escape").decode("utf-8")
194+
string2 = b.encode("unicode_escape").decode("utf-8")
195+
code = f"""
196+
a: String[{len_}]
197+
b: String[{len_}]
175198
176199
@external
177200
def equal_true() -> bool:
178-
compA: String[100] = "The quick brown fox jumps over the lazy dog"
179-
compB: String[100] = "The quick brown fox jumps over the lazy dog"
180-
return compA == compB
201+
a: String[{len_}] = "{a_init}"
202+
b: String[{len_}] = "{b_init}"
203+
a = "{string1}"
204+
b = "{string1}"
205+
return a == b
181206
182207
@external
183208
def equal_false() -> bool:
184-
compA: String[100] = "The quick brown fox jumps over the lazy dog"
185-
compB: String[100] = "The quick brown fox jumps over the lazy hog"
186-
return compA == compB
209+
a: String[{len_}] = "{a_init}"
210+
b: String[{len_}] = "{b_init}"
211+
a = "{string1}"
212+
b = "{string2}"
213+
return a == b
187214
188215
@external
189216
def not_equal_true() -> bool:
190-
compA: String[100] = "The quick brown fox jumps over the lazy dog"
191-
compB: String[100] = "The quick brown fox jumps over the lazy hog"
192-
return compA != compB
217+
a: String[{len_}] = "{a_init}"
218+
b: String[{len_}] = "{b_init}"
219+
a = "{string1}"
220+
b = "{string2}"
221+
return a != b
193222
194223
@external
195224
def not_equal_false() -> bool:
196-
compA: String[100] = "The quick brown fox jumps over the lazy dog"
197-
compB: String[100] = "The quick brown fox jumps over the lazy dog"
198-
return compA != compB
225+
a: String[{len_}] = "{a_init}"
226+
b: String[{len_}] = "{b_init}"
227+
a = "{string1}"
228+
b = "{string1}"
229+
return a != b
199230
200231
@external
201232
def literal_equal_true() -> bool:
202-
return "The quick brown fox jumps over the lazy dog" == \
203-
"The quick brown fox jumps over the lazy dog"
233+
return "{string1}" == "{string1}"
204234
205235
@external
206236
def literal_equal_false() -> bool:
207-
return "The quick brown fox jumps over the lazy dog" == \
208-
"The quick brown fox jumps over the lazy hog"
237+
return "{string1}" == "{string2}"
209238
210239
@external
211240
def literal_not_equal_true() -> bool:
212-
return "The quick brown fox jumps over the lazy dog" != \
213-
"The quick brown fox jumps over the lazy hog"
241+
return "{string1}" != "{string2}"
214242
215243
@external
216244
def literal_not_equal_false() -> bool:
217-
return "The quick brown fox jumps over the lazy dog" != \
218-
"The quick brown fox jumps over the lazy dog"
245+
return "{string1}" != "{string1}"
219246
220247
@external
221248
def storage_equal_true() -> bool:
222-
self._compA = "The quick brown fox jumps over the lazy dog"
223-
self._compB = "The quick brown fox jumps over the lazy dog"
224-
return self._compA == self._compB
249+
self.a = "{a_init}"
250+
self.b = "{b_init}"
251+
self.a = "{string1}"
252+
self.b = "{string1}"
253+
return self.a == self.b
225254
226255
@external
227256
def storage_equal_false() -> bool:
228-
self._compA = "The quick brown fox jumps over the lazy dog"
229-
self._compB = "The quick brown fox jumps over the lazy hog"
230-
return self._compA == self._compB
257+
self.a = "{a_init}"
258+
self.b = "{b_init}"
259+
self.a = "{string1}"
260+
self.b = "{string2}"
261+
return self.a == self.b
231262
232263
@external
233264
def storage_not_equal_true() -> bool:
234-
self._compA = "The quick brown fox jumps over the lazy dog"
235-
self._compB = "The quick brown fox jumps over the lazy hog"
236-
return self._compA != self._compB
265+
self.a = "{a_init}"
266+
self.b = "{b_init}"
267+
self.a = "{string1}"
268+
self.b = "{string2}"
269+
return self.a != self.b
237270
238271
@external
239272
def storage_not_equal_false() -> bool:
240-
self._compA = "The quick brown fox jumps over the lazy dog"
241-
self._compB = "The quick brown fox jumps over the lazy dog"
242-
return self._compA != self._compB
273+
self.a = "{a_init}"
274+
self.b = "{b_init}"
275+
self.a = "{string1}"
276+
self.b = "{string1}"
277+
return self.a != self.b
243278
244279
@external
245-
def string_compare_equal(str1: String[100], str2: String[100]) -> bool:
280+
def string_compare_equal(str1: String[{len_}], str2: String[{len_}]) -> bool:
246281
return str1 == str2
247282
248283
@external
249-
def string_compare_not_equal(str1: String[100], str2: String[100]) -> bool:
284+
def string_compare_not_equal(str1: String[{len_}], str2: String[{len_}]) -> bool:
250285
return str1 != str2
251286
252287
@external
253-
def compare_passed_storage_equal(str: String[100]) -> bool:
254-
self._compA = "The quick brown fox jumps over the lazy dog"
255-
return self._compA == str
288+
def compare_passed_storage_equal(str_: String[{len_}]) -> bool:
289+
self.a = "{a_init}"
290+
self.a = "{string1}"
291+
return self.a == str_
256292
257293
@external
258-
def compare_passed_storage_not_equal(str: String[100]) -> bool:
259-
self._compA = "The quick brown fox jumps over the lazy dog"
260-
return self._compA != str
294+
def compare_passed_storage_not_equal(str_: String[{len_}]) -> bool:
295+
self.a = "{a_init}"
296+
self.a = "{string1}"
297+
return self.a != str_
261298
262299
@external
263300
def compare_var_storage_equal_true() -> bool:
264-
self._compA = "The quick brown fox jumps over the lazy dog"
265-
compB: String[100] = "The quick brown fox jumps over the lazy dog"
266-
return self._compA == compB
301+
self.a = "{a_init}"
302+
b: String[{len_}] = "{b_init}"
303+
self.a = "{string1}"
304+
b = "{string1}"
305+
return self.a == b
267306
268307
@external
269308
def compare_var_storage_equal_false() -> bool:
270-
self._compA = "The quick brown fox jumps over the lazy dog"
271-
compB: String[100] = "The quick brown fox jumps over the lazy hog"
272-
return self._compA == compB
309+
self.a = "{a_init}"
310+
b: String[{len_}] = "{b_init}"
311+
self.a = "{string1}"
312+
b = "{string2}"
313+
return self.a == b
273314
274315
@external
275316
def compare_var_storage_not_equal_true() -> bool:
276-
self._compA = "The quick brown fox jumps over the lazy dog"
277-
compB: String[100] = "The quick brown fox jumps over the lazy hog"
278-
return self._compA != compB
317+
self.a = "{a_init}"
318+
b: String[{len_}] = "{b_init}"
319+
self.a = "{string1}"
320+
b = "{string2}"
321+
return self.a != b
279322
280323
@external
281324
def compare_var_storage_not_equal_false() -> bool:
282-
self._compA = "The quick brown fox jumps over the lazy dog"
283-
compB: String[100] = "The quick brown fox jumps over the lazy dog"
284-
return self._compA != compB
325+
self.a = "{a_init}"
326+
b: String[{len_}] = "{b_init}"
327+
self.a = "{string1}"
328+
b = "{string1}"
329+
return self.a != b
285330
"""
286331

287332
c = get_contract_with_gas_estimation(code)
@@ -298,8 +343,6 @@ def compare_var_storage_not_equal_false() -> bool:
298343
assert c.storage_not_equal_true() is True
299344
assert c.storage_not_equal_false() is False
300345

301-
a = "The quick brown fox jumps over the lazy dog"
302-
b = "The quick brown fox jumps over the lazy hog"
303346
assert c.string_compare_equal(a, a) is True
304347
assert c.string_compare_equal(a, b) is False
305348
assert c.string_compare_not_equal(b, a) is True

vyper/codegen/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ def copy_bytes(dst, src, length, length_bound):
186186

187187
with src.cache_when_complex("src") as (b1, src), length.cache_when_complex(
188188
"copy_bytes_count"
189-
) as (b2, length,), dst.cache_when_complex("dst") as (b3, dst):
189+
) as (b2, length), dst.cache_when_complex("dst") as (b3, dst):
190190

191191
# fast code for common case where num bytes is small
192192
# TODO expand this for more cases where num words is less than ~8

vyper/codegen/expr.py

Lines changed: 7 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,6 @@
55
from vyper.address_space import DATA, IMMUTABLES, MEMORY, STORAGE
66
from vyper.codegen import external_call, self_call
77
from vyper.codegen.core import (
8-
LOAD,
9-
bytes_data_ptr,
108
clamp_basetype,
119
ensure_in_memory,
1210
get_dyn_array_count,
@@ -801,30 +799,15 @@ def parse_Compare(self):
801799
left = Expr(self.expr.left, self.context).ir_node
802800
right = Expr(self.expr.right, self.context).ir_node
803801

804-
length_mismatch = left.typ.maxlen != right.typ.maxlen
805-
left_over_32 = left.typ.maxlen > 32
806-
right_over_32 = right.typ.maxlen > 32
807-
808-
if length_mismatch or left_over_32 or right_over_32:
809-
left_keccak = keccak256_helper(self.expr, left, self.context)
810-
right_keccak = keccak256_helper(self.expr, right, self.context)
811-
812-
if op == "eq" or op == "ne":
813-
return IRnode.from_list([op, left_keccak, right_keccak], typ="bool")
814-
815-
else:
816-
return
802+
left_keccak = keccak256_helper(self.expr, left, self.context)
803+
right_keccak = keccak256_helper(self.expr, right, self.context)
817804

805+
if op not in ("eq", "ne"):
806+
return # raises
818807
else:
819-
820-
def load_bytearray(side):
821-
return LOAD(bytes_data_ptr(side))
822-
823-
return IRnode.from_list(
824-
# CMC 2022-03-24 TODO investigate this.
825-
[op, load_bytearray(left), load_bytearray(right)],
826-
typ="bool",
827-
)
808+
# use hash even for Bytes[N<=32], because there could be dirty
809+
# bytes past the bytes data.
810+
return IRnode.from_list([op, left_keccak, right_keccak], typ="bool")
828811

829812
# Compare other types.
830813
elif is_numeric_type(left.typ) and is_numeric_type(right.typ):

vyper/codegen/keccak256_helper.py

Lines changed: 20 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from math import ceil
22

3-
from vyper.codegen.core import ensure_in_memory
3+
from vyper.codegen.core import bytes_data_ptr, ensure_in_memory, get_bytearray_length
44
from vyper.codegen.ir_node import IRnode
55
from vyper.codegen.types import BaseType, ByteArrayLike, is_base_type
66
from vyper.exceptions import CompilerPanic
@@ -21,37 +21,36 @@ def _gas_bound(num_words):
2121
return SHA3_BASE + num_words * SHA3_PER_WORD
2222

2323

24-
def keccak256_helper(expr, ir_arg, context):
25-
sub = ir_arg # TODO get rid of useless variable
26-
_check_byteslike(sub.typ, expr)
24+
def keccak256_helper(expr, to_hash, context):
25+
_check_byteslike(to_hash.typ, expr)
2726

2827
# Can hash literals
2928
# TODO this is dead code.
30-
if isinstance(sub, bytes):
31-
return IRnode.from_list(bytes_to_int(keccak256(sub)), typ=BaseType("bytes32"))
29+
if isinstance(to_hash, bytes):
30+
return IRnode.from_list(bytes_to_int(keccak256(to_hash)), typ=BaseType("bytes32"))
3231

3332
# Can hash bytes32 objects
34-
if is_base_type(sub.typ, "bytes32"):
33+
if is_base_type(to_hash.typ, "bytes32"):
3534
return IRnode.from_list(
3635
[
3736
"seq",
38-
["mstore", MemoryPositions.FREE_VAR_SPACE, sub],
37+
["mstore", MemoryPositions.FREE_VAR_SPACE, to_hash],
3938
["sha3", MemoryPositions.FREE_VAR_SPACE, 32],
4039
],
4140
typ=BaseType("bytes32"),
4241
add_gas_estimate=_gas_bound(1),
4342
)
4443

45-
sub = ensure_in_memory(sub, context)
46-
47-
return IRnode.from_list(
48-
[
49-
"with",
50-
"_buf",
51-
sub,
52-
["sha3", ["add", "_buf", 32], ["mload", "_buf"]],
53-
],
54-
typ=BaseType("bytes32"),
55-
annotation="keccak256",
56-
add_gas_estimate=_gas_bound(ceil(sub.typ.maxlen / 32)),
57-
)
44+
to_hash = ensure_in_memory(to_hash, context)
45+
46+
with to_hash.cache_when_complex("buf") as (b1, to_hash):
47+
data = bytes_data_ptr(to_hash)
48+
len_ = get_bytearray_length(to_hash)
49+
return b1.resolve(
50+
IRnode.from_list(
51+
["sha3", data, len_],
52+
typ="bytes32",
53+
annotation="keccak256",
54+
add_gas_estimate=_gas_bound(ceil(to_hash.typ.maxlen / 32)),
55+
)
56+
)

0 commit comments

Comments
 (0)