diff --git a/mypyc/irbuild/for_helpers.py b/mypyc/irbuild/for_helpers.py index 715f5432cd13..0ebdf597fdd7 100644 --- a/mypyc/irbuild/for_helpers.py +++ b/mypyc/irbuild/for_helpers.py @@ -7,7 +7,7 @@ from __future__ import annotations -from typing import Callable, ClassVar +from typing import Any, Callable, ClassVar from mypy.nodes import ( ARG_POS, @@ -16,12 +16,14 @@ DictionaryComprehension, Expression, GeneratorExpr, + IndexExpr, ListExpr, Lvalue, MemberExpr, NameExpr, RefExpr, SetExpr, + SliceExpr, StarExpr, StrExpr, TupleExpr, @@ -67,6 +69,7 @@ short_int_rprimitive, ) from mypyc.irbuild.builder import IRBuilder +from mypyc.irbuild.constant_fold import constant_fold_expr from mypyc.irbuild.prepare import GENERATOR_HELPER_NAME from mypyc.irbuild.targets import AssignmentTarget, AssignmentTargetTuple from mypyc.primitives.dict_ops import ( @@ -436,12 +439,33 @@ def make_for_loop_generator( rtyp = builder.node_type(expr) if is_sequence_rprimitive(rtyp): - # Special case "for x in ". - expr_reg = builder.accept(expr) + # Special case "for x in ". target_type = builder.get_sequence_type(expr) - for_list = ForSequence(builder, index, body_block, loop_exit, line, nested) - for_list.init(expr_reg, target_type, reverse=False) + + if isinstance(expr, IndexExpr) and isinstance(expr.index, SliceExpr): + # TODO: maybe we must not apply this optimization to list type specifically + # because the need to check length changes at each iteration? + + def constant_fold_or_none(expr: Expression | None) -> Any: + return None if expr is None else constant_fold_expr(builder, expr) + + start = constant_fold_or_none(expr.index.begin_index) + stop = constant_fold_or_none(expr.index.end_index) + step = constant_fold_or_none(expr.index.stride) + + if all(s is None or isinstance(s, int) for s in (start, stop, step)): + for_list.init( + builder.accept(expr.base), + target_type, + reverse=False, + start=start, + stop=stop, + step=step, + ) + return for_list + + for_list.init(builder.accept(expr), target_type, reverse=False) return for_list if is_dict_rprimitive(rtyp): @@ -821,13 +845,33 @@ class ForSequence(ForGenerator): length_reg: Value | AssignmentTarget | None def init( - self, expr_reg: Value, target_type: RType, reverse: bool, length: Value | None = None + self, + expr_reg: Value, + target_type: RType, + reverse: bool, + length: Value | None = None, + *, + start: int | None = None, + stop: int | None = None, + step: int | None = None, ) -> None: assert is_sequence_rprimitive(expr_reg.type), (expr_reg, expr_reg.type) builder = self.builder # Record a Value indicating the length of the sequence, if known at compile time. self.length = length self.reverse = reverse + + self.start = 0 if start is None else start + assert self.start >= 0, "implement me!" + + self.stop = -1 if stop is None else stop + assert self.stop == -1, "implement me!" + + self.step = 1 if step is None else step + assert self.step and self.step >= 1, "this should be unreachable for step None and step 0" + if reverse: + self.step *= -1 + # Define target to contain the expression, along with the index that will be used # for the for-loop. If we are inside of a generator function, spill these into the # environment class. @@ -835,13 +879,16 @@ def init( if is_immutable_rprimitive(expr_reg.type): # If the expression is an immutable type, we can load the length just once. self.length_reg = builder.maybe_spill(self.length or self.load_len(self.expr_target)) + # TODO: if stop != -1 implement a safety check and then set self.stop_reg + # gen_condition will need to read stop_reg if present else: # Otherwise, even if the length is known, we must recalculate the length # at every iteration for compatibility with python semantics. self.length_reg = None if not reverse: - index_reg: Value = Integer(0, c_pyssize_t_rprimitive) + index_reg: Value = Integer(self.start, c_pyssize_t_rprimitive) else: + # TODO implement start logic if self.length_reg is not None: len_val = builder.read(self.length_reg) else: @@ -854,6 +901,7 @@ def gen_condition(self) -> None: builder = self.builder line = self.line if self.reverse: + # TODO implement start stop step # If we are iterating in reverse order, we obviously need # to check that the index is still positive. Somewhat less # obviously we still need to check against the length, @@ -898,8 +946,7 @@ def gen_step(self) -> None: # Step to the next item. builder = self.builder line = self.line - step = 1 if not self.reverse else -1 - add = builder.builder.int_add(builder.read(self.index_target, line), step) + add = builder.builder.int_add(builder.read(self.index_target, line), self.step) builder.assign(self.index_target, add, line)