Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion mypyc/doc/str_operations.rst
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ Operators
---------

* Concatenation (``s1 + s2``)
* Indexing (``s[n]``)
* Indexing (``s[n]``; also ``ord(s[n])``, which avoids the temporary length-1 string)
* Slicing (``s[n:m]``, ``s[n:]``, ``s[:m]``)
* Comparisons (``==``, ``!=``)
* Augmented assignment (``s1 += s2``)
Expand Down
1 change: 1 addition & 0 deletions mypyc/ir/deps.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,4 @@ def get_header(self) -> str:
BYTES_EXTRA_OPS: Final = SourceDep("bytes_extra_ops.c")
BYTES_WRITER_EXTRA_OPS: Final = SourceDep("byteswriter_extra_ops.c")
BYTEARRAY_EXTRA_OPS: Final = SourceDep("bytearray_extra_ops.c")
STR_EXTRA_OPS: Final = SourceDep("str_extra_ops.c")
32 changes: 31 additions & 1 deletion mypyc/irbuild/specialize.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
DictExpr,
Expression,
GeneratorExpr,
IndexExpr,
IntExpr,
ListExpr,
MemberExpr,
Expand Down Expand Up @@ -72,6 +73,8 @@
is_int_rprimitive,
is_list_rprimitive,
is_sequence_rprimitive,
is_str_rprimitive,
is_tagged,
is_uint8_rprimitive,
list_rprimitive,
object_rprimitive,
Expand Down Expand Up @@ -125,9 +128,12 @@
bytes_decode_latin1_strict,
bytes_decode_utf8_strict,
isinstance_str,
str_adjust_index_op,
str_encode_ascii_strict,
str_encode_latin1_strict,
str_encode_utf8_strict,
str_get_item_unsafe_as_int_op,
str_range_check_op,
)
from mypyc.primitives.tuple_ops import isinstance_tuple, new_tuple_set_item_op

Expand Down Expand Up @@ -1126,9 +1132,33 @@ def translate_float(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Valu
def translate_ord(builder: IRBuilder, expr: CallExpr, callee: RefExpr) -> Value | None:
    """Specialize a call to ord().

    Two forms are handled:

    * The argument constant-folds to a str or bytes value of length 1:
      the result is emitted directly as an Integer constant.
    * The argument is an index expression s[i] where s is a str and i is
      a tagged or fixed-width integer: the code point is read through the
      str index primitives (adjust index, range check, unchecked read)
      instead of first building a temporary length-1 string.

    Returns None when neither form applies.
    """
    if len(expr.args) != 1 or expr.arg_kinds[0] != ARG_POS:
        return None

    arg_expr = expr.args[0]
    # Fold exactly once; the folded value is only needed by the constant path.
    arg = constant_fold_expr(builder, arg_expr)
    if isinstance(arg, (str, bytes)) and len(arg) == 1:
        return Integer(ord(arg))

    # Everything below handles ord(s[i]) only.
    if not isinstance(arg_expr, IndexExpr):
        return None

    # The base must be a str ...
    if not is_str_rprimitive(builder.node_type(arg_expr.base)):
        return None

    # ... and the index must be an integer (tagged or fixed-width).
    index_type = builder.node_type(arg_expr.index)
    if not (is_tagged(index_type) or is_fixed_width_rtype(index_type)):
        return None

    return translate_getitem_with_bounds_check(
        builder,
        arg_expr.base,
        [arg_expr.index],
        expr,
        str_adjust_index_op,
        str_range_check_op,
        str_get_item_unsafe_as_int_op,
    )


Expand Down
4 changes: 4 additions & 0 deletions mypyc/lib-rt/str_extra_ops.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#include "str_extra_ops.h"

// All str extra ops are inline functions in str_extra_ops.h
// This file exists to satisfy the SourceDep requirements
29 changes: 29 additions & 0 deletions mypyc/lib-rt/str_extra_ops.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#ifndef MYPYC_STR_EXTRA_OPS_H
#define MYPYC_STR_EXTRA_OPS_H

#include <Python.h>
#include <stdint.h>
#include "CPy.h"

// Optimized str indexing for ord(s[i])

// Map a possibly negative index to an offset from the start of the string
// by adding the string length. The result is NOT range checked; a
// non-negative index is returned unchanged.
static inline int64_t CPyStr_AdjustIndex(PyObject *obj, int64_t index) {
    return index < 0 ? index + PyUnicode_GET_LENGTH(obj) : index;
}

// Return true iff index lies in the half-open range [0, len(obj)).
static inline bool CPyStr_RangeCheck(PyObject *obj, int64_t index) {
    if (index < 0) {
        return false;
    }
    return index < PyUnicode_GET_LENGTH(obj);
}

// Read the code point stored at index and return it as a CPyTagged value
// (the code point shifted left by one). No bounds checking is done here:
// the caller must pass an index that is already known to be in range.
static inline CPyTagged CPyStr_GetItemUnsafeAsInt(PyObject *obj, int64_t index) {
    int kind = PyUnicode_KIND(obj);
    const void *data = PyUnicode_DATA(obj);
    Py_UCS4 ch = PyUnicode_READ(kind, data, index);
    return ch << 1;
}

#endif
35 changes: 35 additions & 0 deletions mypyc/primitives/str_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

from mypyc.ir.deps import STR_EXTRA_OPS
from mypyc.ir.ops import ERR_MAGIC, ERR_NEVER
from mypyc.ir.rtypes import (
RType,
Expand All @@ -10,10 +11,12 @@
bytes_rprimitive,
c_int_rprimitive,
c_pyssize_t_rprimitive,
int64_rprimitive,
int_rprimitive,
list_rprimitive,
object_rprimitive,
pointer_rprimitive,
short_int_rprimitive,
str_rprimitive,
tuple_rprimitive,
)
Expand Down Expand Up @@ -507,3 +510,35 @@
c_function_name="CPyStr_Ord",
error_kind=ERR_MAGIC,
)

# Optimized str indexing for ord(s[i])

# Normalize a str index: a negative index is offset by the string length.
# The result is not range checked.
str_adjust_index_op = custom_primitive_op(
    name="str_adjust_index",
    c_function_name="CPyStr_AdjustIndex",
    arg_types=[str_rprimitive, int64_rprimitive],
    return_type=int64_rprimitive,
    error_kind=ERR_NEVER,
    dependencies=[STR_EXTRA_OPS],
)

# Test whether an (already adjusted) index is a valid position in a str,
# i.e. lies in [0, len).
str_range_check_op = custom_primitive_op(
    name="str_range_check",
    c_function_name="CPyStr_RangeCheck",
    arg_types=[str_rprimitive, int64_rprimitive],
    return_type=bool_rprimitive,
    error_kind=ERR_NEVER,
    dependencies=[STR_EXTRA_OPS],
)

# Fetch the code point at a str index as a short int (the ord() value).
# Performs no bounds checking, so the index must already be validated.
str_get_item_unsafe_as_int_op = custom_primitive_op(
    name="str_get_item_unsafe_as_int",
    c_function_name="CPyStr_GetItemUnsafeAsInt",
    arg_types=[str_rprimitive, int64_rprimitive],
    return_type=short_int_rprimitive,
    error_kind=ERR_NEVER,
    dependencies=[STR_EXTRA_OPS],
)
59 changes: 59 additions & 0 deletions mypyc/test-data/irbuild-str.test
Original file line number Diff line number Diff line change
Expand Up @@ -507,6 +507,65 @@ L0:
r6 = unbox(int, r5)
return r6

[case testOrdOfStrIndex_64bit]
from mypy_extensions import i64
def ord_str_index(s: str, i: int) -> int:
return ord(s[i])
def ord_str_index_i64(s: str, i: i64) -> int:
return ord(s[i])
[typing fixtures/typing-full.pyi]
[out]
def ord_str_index(s, i):
s :: str
i :: int
r0 :: native_int
r1 :: bit
r2, r3 :: i64
r4 :: ptr
r5 :: c_ptr
r6, r7 :: i64
r8, r9 :: bool
r10 :: short_int
L0:
r0 = i & 1
r1 = r0 == 0
if r1 goto L1 else goto L2 :: bool
L1:
r2 = i >> 1
r3 = r2
goto L3
L2:
r4 = i ^ 1
r5 = r4
r6 = CPyLong_AsInt64(r5)
r3 = r6
keep_alive i
L3:
r7 = CPyStr_AdjustIndex(s, r3)
r8 = CPyStr_RangeCheck(s, r7)
if r8 goto L5 else goto L4 :: bool
L4:
r9 = raise IndexError('index out of range')
unreachable
L5:
r10 = CPyStr_GetItemUnsafeAsInt(s, r7)
return r10
def ord_str_index_i64(s, i):
s :: str
i, r0 :: i64
r1, r2 :: bool
r3 :: short_int
L0:
r0 = CPyStr_AdjustIndex(s, i)
r1 = CPyStr_RangeCheck(s, r0)
if r1 goto L2 else goto L1 :: bool
L1:
r2 = raise IndexError('index out of range')
unreachable
L2:
r3 = CPyStr_GetItemUnsafeAsInt(s, r0)
return r3

[case testStrip]
from typing import NewType, Union
NewStr = NewType("NewStr", str)
Expand Down
72 changes: 72 additions & 0 deletions mypyc/test-data/run-strings.test
Original file line number Diff line number Diff line change
Expand Up @@ -808,6 +808,7 @@ def test_chr() -> None:

[case testOrd]
from testutil import assertRaises
from mypy_extensions import i64, i32, i16

def test_ord() -> None:
assert ord(' ') == 32
Expand All @@ -834,6 +835,77 @@ def test_ord() -> None:
with assertRaises(TypeError):
ord('')

def test_ord_str_index() -> None:
    # Exercise ord(s[i]) with a plain int index over strings that contain
    # ASCII, Latin-1, 16-bit and beyond-16-bit code points, using both
    # non-negative and negative indexes, and check that out-of-range indexes
    # raise IndexError.
    #
    # NOTE(review): every index is written as "<literal> + int()" (int() is 0),
    # presumably so the index is not a compile-time constant -- confirm.

    # ASCII
    s1 = "hello"
    assert ord(s1[0 + int()]) == 104  # 'h'
    assert ord(s1[1 + int()]) == 101  # 'e'
    assert ord(s1[4 + int()]) == 111  # 'o'
    assert ord(s1[-1 + int()]) == 111  # 'o'
    assert ord(s1[-5 + int()]) == 104  # 'h'

    # Latin-1 (8 bits)
    s2 = "café"
    assert ord(s2[0 + int()]) == 99  # 'c'
    assert ord(s2[3 + int()]) == 233  # 'é' (U+00E9)
    assert ord(s2[-1 + int()]) == 233

    # 16-bit unicode
    s3 = "你好"  # Chinese
    assert ord(s3[0 + int()]) == 20320  # '你' (U+4F60)
    assert ord(s3[1 + int()]) == 22909  # '好' (U+597D)
    assert ord(s3[-1 + int()]) == 22909
    assert ord(s3[-2 + int()]) == 20320

    # 4-byte unicode
    s5 = "a😀b"  # Emoji between ASCII chars
    assert ord(s5[0 + int()]) == 97  # 'a'
    assert ord(s5[1 + int()]) == 128512  # '😀' (U+1F600)
    assert ord(s5[2 + int()]) == 98  # 'b'
    assert ord(s5[-1 + int()]) == 98
    assert ord(s5[-2 + int()]) == 128512
    assert ord(s5[-3 + int()]) == 97

    # Out of range on either side of a length-5 string: the first invalid
    # positive index (5), the first invalid negative index (-6), and values
    # far outside the range.
    with assertRaises(IndexError, "index out of range"):
        ord(s1[5 + int()])
    with assertRaises(IndexError, "index out of range"):
        ord(s1[100 + int()])
    with assertRaises(IndexError, "index out of range"):
        ord(s1[-6 + int()])
    with assertRaises(IndexError, "index out of range"):
        ord(s1[-100 + int()])

    # An empty string has no valid index at all.
    s_empty = ""
    with assertRaises(IndexError, "index out of range"):
        ord(s_empty[0 + int()])
    with assertRaises(IndexError, "index out of range"):
        ord(s_empty[-1 + int()])

def test_ord_str_index_i64() -> None:
    # Exercise ord(s[i]) where the index variable is declared as i64 rather
    # than int: one in-range positive index, one in-range negative index, and
    # one out-of-range index in each direction.
    s = "hello"

    idx_i64: i64 = 2 + int()
    assert ord(s[idx_i64]) == 108  # 'l'

    idx_i64_neg: i64 = -1 + int()
    assert ord(s[idx_i64_neg]) == 111  # 'o'

    # 10 is past the end of the length-5 string.
    idx_overflow: i64 = 10 + int()
    with assertRaises(IndexError, "index out of range"):
        ord(s[idx_overflow])

    # -10 is before the start of the length-5 string.
    idx_underflow: i64 = -10 + int()
    with assertRaises(IndexError, "index out of range"):
        ord(s[idx_underflow])

def test_ord_str_index_unicode_mix() -> None:
    # Index a single string that mixes characters of increasing width.
    # The byte counts below are the UTF-8 encoded sizes of each character.
    # Mix of 1-byte, 2-byte, 3-byte, and 4-byte characters
    s = "a\u00e9\u4f60😀"  # 'a', 'é', '你', '😀'
    assert ord(s[0 + int()]) == 97  # 1-byte
    assert ord(s[1 + int()]) == 233  # 2-byte
    assert ord(s[2 + int()]) == 20320  # 3-byte
    assert ord(s[3 + int()]) == 128512  # 4-byte

[case testDecode]
from testutil import assertRaises

Expand Down
38 changes: 37 additions & 1 deletion mypyc/test/test_cheader.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,23 @@
import unittest

from mypyc.ir.deps import SourceDep
from mypyc.primitives import registry
from mypyc.ir.ops import PrimitiveDescription
from mypyc.primitives import (
bytearray_ops,
bytes_ops,
dict_ops,
exc_ops,
float_ops,
generic_ops,
int_ops,
list_ops,
misc_ops,
registry,
set_ops,
str_ops,
tuple_ops,
weakref_ops,
)


class TestHeaderInclusion(unittest.TestCase):
Expand All @@ -35,6 +51,26 @@ def check_name(name: str) -> None:
for ops in values:
all_ops.extend(ops)

for module in [
bytes_ops,
str_ops,
dict_ops,
list_ops,
bytearray_ops,
generic_ops,
int_ops,
misc_ops,
tuple_ops,
exc_ops,
float_ops,
set_ops,
weakref_ops,
]:
for name in dir(module):
val = getattr(module, name, None)
if isinstance(val, PrimitiveDescription):
all_ops.append(val)

# Find additional headers via extra C source file dependencies.
for op in all_ops:
if op.dependencies:
Expand Down