diff --git a/mypyc/doc/str_operations.rst b/mypyc/doc/str_operations.rst index 4a7aff00f2ad..2eebd2f6ab57 100644 --- a/mypyc/doc/str_operations.rst +++ b/mypyc/doc/str_operations.rst @@ -19,7 +19,7 @@ Operators --------- * Concatenation (``s1 + s2``) -* Indexing (``s[n]``) +* Indexing (``s[n]``; also ``ord(s[n])``, which avoids the temporary length-1 string) * Slicing (``s[n:m]``, ``s[n:]``, ``s[:m]``) * Comparisons (``==``, ``!=``) * Augmented assignment (``s1 += s2``) diff --git a/mypyc/ir/deps.py b/mypyc/ir/deps.py index 249b456e2c85..cee7263a8c92 100644 --- a/mypyc/ir/deps.py +++ b/mypyc/ir/deps.py @@ -52,3 +52,4 @@ def get_header(self) -> str: BYTES_EXTRA_OPS: Final = SourceDep("bytes_extra_ops.c") BYTES_WRITER_EXTRA_OPS: Final = SourceDep("byteswriter_extra_ops.c") BYTEARRAY_EXTRA_OPS: Final = SourceDep("bytearray_extra_ops.c") +STR_EXTRA_OPS: Final = SourceDep("str_extra_ops.c") diff --git a/mypyc/irbuild/specialize.py b/mypyc/irbuild/specialize.py index d75c9144dcda..d0934914dfe9 100644 --- a/mypyc/irbuild/specialize.py +++ b/mypyc/irbuild/specialize.py @@ -24,6 +24,7 @@ DictExpr, Expression, GeneratorExpr, + IndexExpr, IntExpr, ListExpr, MemberExpr, @@ -72,6 +73,8 @@ is_int_rprimitive, is_list_rprimitive, is_sequence_rprimitive, + is_str_rprimitive, + is_tagged, is_uint8_rprimitive, list_rprimitive, object_rprimitive, @@ -125,9 +128,12 @@ bytes_decode_latin1_strict, bytes_decode_utf8_strict, isinstance_str, + str_adjust_index_op, str_encode_ascii_strict, str_encode_latin1_strict, str_encode_utf8_strict, + str_get_item_unsafe_as_int_op, + str_range_check_op, ) from mypyc.primitives.tuple_ops import isinstance_tuple, new_tuple_set_item_op @@ -1126,9 +1132,33 @@ def translate_float(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Valu def translate_ord(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None: if len(expr.args) != 1 or expr.arg_kinds[0] != ARG_POS: return None - arg = constant_fold_expr(builder, expr.args[0]) + arg_expr = 
expr.args[0] + arg = constant_fold_expr(builder, arg_expr) if isinstance(arg, (str, bytes)) and len(arg) == 1: return Integer(ord(arg)) + + # Check for ord(s[i]) where s is str and i is an integer + if isinstance(arg_expr, IndexExpr): + # Check base type + base_type = builder.node_type(arg_expr.base) + if is_str_rprimitive(base_type): + # Check index type + index_expr = arg_expr.index + index_type = builder.node_type(index_expr) + if is_tagged(index_type) or is_fixed_width_rtype(index_type): + # This is ord(s[i]) where s is str and i is an integer. + # Generate specialized inline code using the helper. + result = translate_getitem_with_bounds_check( + builder, + arg_expr.base, + [arg_expr.index], + expr, + str_adjust_index_op, + str_range_check_op, + str_get_item_unsafe_as_int_op, + ) + return result + return None diff --git a/mypyc/lib-rt/str_extra_ops.c b/mypyc/lib-rt/str_extra_ops.c new file mode 100644 index 000000000000..bb2adabd6d7b --- /dev/null +++ b/mypyc/lib-rt/str_extra_ops.c @@ -0,0 +1,4 @@ +#include "str_extra_ops.h" + +// All str extra ops are inline functions in str_extra_ops.h +// This file exists to satisfy the SourceDep requirements diff --git a/mypyc/lib-rt/str_extra_ops.h b/mypyc/lib-rt/str_extra_ops.h new file mode 100644 index 000000000000..82f92bf85d46 --- /dev/null +++ b/mypyc/lib-rt/str_extra_ops.h @@ -0,0 +1,29 @@ +#ifndef MYPYC_STR_EXTRA_OPS_H +#define MYPYC_STR_EXTRA_OPS_H + +#include <Python.h> +#include <stdint.h> +#include "CPy.h" + +// Optimized str indexing for ord(s[i]) + +// If index is negative, convert to non-negative index (no range checking) +static inline int64_t CPyStr_AdjustIndex(PyObject *obj, int64_t index) { + if (index < 0) { + return index + PyUnicode_GET_LENGTH(obj); + } + return index; +} + +// Check if index is in valid range [0, len) +static inline bool CPyStr_RangeCheck(PyObject *obj, int64_t index) { + return index >= 0 && index < PyUnicode_GET_LENGTH(obj); +} + +// Get character at index as int (ord value) - no bounds checking, 
returns as CPyTagged +static inline CPyTagged CPyStr_GetItemUnsafeAsInt(PyObject *obj, int64_t index) { + int kind = PyUnicode_KIND(obj); + return PyUnicode_READ(kind, PyUnicode_DATA(obj), index) << 1; +} + +#endif diff --git a/mypyc/primitives/str_ops.py b/mypyc/primitives/str_ops.py index ceaf1cfe5dd2..f6d3f722dd7b 100644 --- a/mypyc/primitives/str_ops.py +++ b/mypyc/primitives/str_ops.py @@ -2,6 +2,7 @@ from __future__ import annotations +from mypyc.ir.deps import STR_EXTRA_OPS from mypyc.ir.ops import ERR_MAGIC, ERR_NEVER from mypyc.ir.rtypes import ( RType, @@ -10,10 +11,12 @@ bytes_rprimitive, c_int_rprimitive, c_pyssize_t_rprimitive, + int64_rprimitive, int_rprimitive, list_rprimitive, object_rprimitive, pointer_rprimitive, + short_int_rprimitive, str_rprimitive, tuple_rprimitive, ) @@ -507,3 +510,35 @@ c_function_name="CPyStr_Ord", error_kind=ERR_MAGIC, ) + +# Optimized str indexing for ord(s[i]) + +# str index adjustment - convert negative index to positive +str_adjust_index_op = custom_primitive_op( + name="str_adjust_index", + arg_types=[str_rprimitive, int64_rprimitive], + return_type=int64_rprimitive, + c_function_name="CPyStr_AdjustIndex", + error_kind=ERR_NEVER, + dependencies=[STR_EXTRA_OPS], +) + +# str range check - check if index is in valid range +str_range_check_op = custom_primitive_op( + name="str_range_check", + arg_types=[str_rprimitive, int64_rprimitive], + return_type=bool_rprimitive, + c_function_name="CPyStr_RangeCheck", + error_kind=ERR_NEVER, + dependencies=[STR_EXTRA_OPS], +) + +# str.__getitem__() as int - get character at index as int (ord value) - no bounds checking +str_get_item_unsafe_as_int_op = custom_primitive_op( + name="str_get_item_unsafe_as_int", + arg_types=[str_rprimitive, int64_rprimitive], + return_type=short_int_rprimitive, + c_function_name="CPyStr_GetItemUnsafeAsInt", + error_kind=ERR_NEVER, + dependencies=[STR_EXTRA_OPS], +) diff --git a/mypyc/test-data/irbuild-str.test b/mypyc/test-data/irbuild-str.test index 
881ddc3656ab..b199f0706ec6 100644 --- a/mypyc/test-data/irbuild-str.test +++ b/mypyc/test-data/irbuild-str.test @@ -507,6 +507,65 @@ L0: r6 = unbox(int, r5) return r6 +[case testOrdOfStrIndex_64bit] +from mypy_extensions import i64 +def ord_str_index(s: str, i: int) -> int: + return ord(s[i]) +def ord_str_index_i64(s: str, i: i64) -> int: + return ord(s[i]) +[typing fixtures/typing-full.pyi] +[out] +def ord_str_index(s, i): + s :: str + i :: int + r0 :: native_int + r1 :: bit + r2, r3 :: i64 + r4 :: ptr + r5 :: c_ptr + r6, r7 :: i64 + r8, r9 :: bool + r10 :: short_int +L0: + r0 = i & 1 + r1 = r0 == 0 + if r1 goto L1 else goto L2 :: bool +L1: + r2 = i >> 1 + r3 = r2 + goto L3 +L2: + r4 = i ^ 1 + r5 = r4 + r6 = CPyLong_AsInt64(r5) + r3 = r6 + keep_alive i +L3: + r7 = CPyStr_AdjustIndex(s, r3) + r8 = CPyStr_RangeCheck(s, r7) + if r8 goto L5 else goto L4 :: bool +L4: + r9 = raise IndexError('index out of range') + unreachable +L5: + r10 = CPyStr_GetItemUnsafeAsInt(s, r7) + return r10 +def ord_str_index_i64(s, i): + s :: str + i, r0 :: i64 + r1, r2 :: bool + r3 :: short_int +L0: + r0 = CPyStr_AdjustIndex(s, i) + r1 = CPyStr_RangeCheck(s, r0) + if r1 goto L2 else goto L1 :: bool +L1: + r2 = raise IndexError('index out of range') + unreachable +L2: + r3 = CPyStr_GetItemUnsafeAsInt(s, r0) + return r3 + [case testStrip] from typing import NewType, Union NewStr = NewType("NewStr", str) diff --git a/mypyc/test-data/run-strings.test b/mypyc/test-data/run-strings.test index 3c5a1f1d31e1..de976bbab78a 100644 --- a/mypyc/test-data/run-strings.test +++ b/mypyc/test-data/run-strings.test @@ -808,6 +808,7 @@ def test_chr() -> None: [case testOrd] from testutil import assertRaises +from mypy_extensions import i64, i32, i16 def test_ord() -> None: assert ord(' ') == 32 @@ -834,6 +835,77 @@ def test_ord() -> None: with assertRaises(TypeError): ord('') +def test_ord_str_index() -> None: + # ASCII + s1 = "hello" + assert ord(s1[0 + int()]) == 104 # 'h' + assert ord(s1[1 + int()]) == 101 
# 'e' + assert ord(s1[4 + int()]) == 111 # 'o' + assert ord(s1[-1 + int()]) == 111 # 'o' + assert ord(s1[-5 + int()]) == 104 # 'h' + + # Latin-1 (8 bits) + s2 = "café" + assert ord(s2[0 + int()]) == 99 # 'c' + assert ord(s2[3 + int()]) == 233 # 'é' (U+00E9) + assert ord(s2[-1 + int()]) == 233 + + # 16-bit unicode + s3 = "你好" # Chinese + assert ord(s3[0 + int()]) == 20320 # '你' (U+4F60) + assert ord(s3[1 + int()]) == 22909 # '好' (U+597D) + assert ord(s3[-1 + int()]) == 22909 + assert ord(s3[-2 + int()]) == 20320 + + # 4-byte unicode + s5 = "a😀b" # Emoji between ASCII chars + assert ord(s5[0 + int()]) == 97 # 'a' + assert ord(s5[1 + int()]) == 128512 # '😀' (U+1F600) + assert ord(s5[2 + int()]) == 98 # 'b' + assert ord(s5[-1 + int()]) == 98 + assert ord(s5[-2 + int()]) == 128512 + assert ord(s5[-3 + int()]) == 97 + + with assertRaises(IndexError, "index out of range"): + ord(s1[5 + int()]) + with assertRaises(IndexError, "index out of range"): + ord(s1[100 + int()]) + with assertRaises(IndexError, "index out of range"): + ord(s1[-6 + int()]) + with assertRaises(IndexError, "index out of range"): + ord(s1[-100 + int()]) + + s_empty = "" + with assertRaises(IndexError, "index out of range"): + ord(s_empty[0 + int()]) + with assertRaises(IndexError, "index out of range"): + ord(s_empty[-1 + int()]) + +def test_ord_str_index_i64() -> None: + s = "hello" + + idx_i64: i64 = 2 + int() + assert ord(s[idx_i64]) == 108 # 'l' + + idx_i64_neg: i64 = -1 + int() + assert ord(s[idx_i64_neg]) == 111 # 'o' + + idx_overflow: i64 = 10 + int() + with assertRaises(IndexError, "index out of range"): + ord(s[idx_overflow]) + + idx_underflow: i64 = -10 + int() + with assertRaises(IndexError, "index out of range"): + ord(s[idx_underflow]) + +def test_ord_str_index_unicode_mix() -> None: + # Mix of 1-byte, 2-byte, 3-byte, and 4-byte characters + s = "a\u00e9\u4f60😀" # 'a', 'é', '你', '😀' + assert ord(s[0 + int()]) == 97 # 1-byte + assert ord(s[1 + int()]) == 233 # 2-byte + assert ord(s[2 + 
int()]) == 20320 # 3-byte + assert ord(s[3 + int()]) == 128512 # 4-byte + [case testDecode] from testutil import assertRaises diff --git a/mypyc/test/test_cheader.py b/mypyc/test/test_cheader.py index d21eefdb9bc5..82223a0c451f 100644 --- a/mypyc/test/test_cheader.py +++ b/mypyc/test/test_cheader.py @@ -8,7 +8,23 @@ import unittest from mypyc.ir.deps import SourceDep -from mypyc.primitives import registry +from mypyc.ir.ops import PrimitiveDescription +from mypyc.primitives import ( + bytearray_ops, + bytes_ops, + dict_ops, + exc_ops, + float_ops, + generic_ops, + int_ops, + list_ops, + misc_ops, + registry, + set_ops, + str_ops, + tuple_ops, + weakref_ops, +) class TestHeaderInclusion(unittest.TestCase): @@ -35,6 +51,26 @@ def check_name(name: str) -> None: for ops in values: all_ops.extend(ops) + for module in [ + bytes_ops, + str_ops, + dict_ops, + list_ops, + bytearray_ops, + generic_ops, + int_ops, + misc_ops, + tuple_ops, + exc_ops, + float_ops, + set_ops, + weakref_ops, + ]: + for name in dir(module): + val = getattr(module, name, None) + if isinstance(val, PrimitiveDescription): + all_ops.append(val) + # Find additional headers via extra C source file dependencies. for op in all_ops: if op.dependencies: