diff --git a/pandas/core/arrays/_arrow_string_mixins.py b/pandas/core/arrays/_arrow_string_mixins.py index c99ee7d02a226..1288b5815d0e6 100644 --- a/pandas/core/arrays/_arrow_string_mixins.py +++ b/pandas/core/arrays/_arrow_string_mixins.py @@ -174,15 +174,12 @@ def _str_replace( or callable(repl) or not case or flags - or ( - isinstance(repl, str) - and (r"\g<" in repl or re.search(r"\\\d", repl) is not None) - ) + or (isinstance(repl, str) and r"\g<" in repl) ): raise NotImplementedError( "replace is not supported with a re.Pattern, callable repl, " "case=False, flags!=0, or when the replacement string contains " - "named group references (\\g<...>, \\d+)" + "named group references (\\g<...>)" ) func = pc.replace_substring_regex if regex else pc.replace_substring diff --git a/pandas/core/arrays/string_arrow.py b/pandas/core/arrays/string_arrow.py index f9fd74cbd76b1..8c820f723668c 100644 --- a/pandas/core/arrays/string_arrow.py +++ b/pandas/core/arrays/string_arrow.py @@ -427,7 +427,7 @@ def _str_replace( or ( # substitution contains a named group pattern # https://docs.python.org/3/library/re.html isinstance(repl, str) - and (r"\g<" in repl or re.search(r"\\\d", repl) is not None) + and r"\g<" in repl ) ): return super()._str_replace(pat, repl, n, case, flags, regex) diff --git a/pandas/tests/strings/test_find_replace.py b/pandas/tests/strings/test_find_replace.py index 14c704f4dd1e5..2ad15e9c0c937 100644 --- a/pandas/tests/strings/test_find_replace.py +++ b/pandas/tests/strings/test_find_replace.py @@ -9,6 +9,7 @@ import pandas as pd from pandas import ( Series, + StringDtype, _testing as tm, ) from pandas.tests.strings import ( @@ -601,6 +602,10 @@ def test_replace_callable_raises(any_string_dtype, repl): r"\g \g \g", ["Three Two One", "Baz Bar Foo"], ), + ( + r"\3 \2 \1", + ["Three Two One", "Baz Bar Foo"], + ), ( r"\g<3> \g<2> \g<1>", ["Three Two One", "Baz Bar Foo"], @@ -616,6 +621,7 @@ def test_replace_callable_raises(any_string_dtype, repl): ], ids=[ "named_groups_full_swap", + "numbered_groups_no_g_full_swap", "numbered_groups_full_swap", "single_group_with_literal", "mixed_group_reference_with_literal", @@ -640,22 +646,83 @@ def test_replace_named_groups_regex_swap( [ r"\g<20>", r"\20", + r"\40", + r"\4", ], ) @pytest.mark.parametrize("use_compile", [True, False]) def test_replace_named_groups_regex_swap_expected_fail( - any_string_dtype, repl, use_compile + any_string_dtype, repl, use_compile, request ): # GH#57636 + if ( + not use_compile + and r"\g" not in repl + and isinstance(any_string_dtype, StringDtype) + and any_string_dtype.storage == "pyarrow" + ): + # calls pyarrow method directly + if repl == r"\20": + mark = pytest.mark.xfail(reason="PyArrow interprets as group + literal") + request.applymarker(mark) + + pa = pytest.importorskip("pyarrow") + error_type = pa.ArrowInvalid + error_msg = r"only has \d parenthesized subexpressions" + else: + error_type = re.error + error_msg = "invalid group reference" + pattern = r"(?P\w+) (?P\w+) (?P\w+)" if use_compile: pattern = re.compile(pattern) ser = Series(["One Two Three", "Foo Bar Baz"], dtype=any_string_dtype) - with pytest.raises(re.error, match="invalid group reference"): + with pytest.raises(error_type, match=error_msg): ser.str.replace(pattern, repl, regex=True) +@pytest.mark.parametrize( + "pattern, repl", + [ + (r"(\w+) (\w+) (\w+)", r"\20"), + (r"(?P\w+) (?P\w+) (?P\w+)", r"\20"), + ], +) +def test_pyarrow_ambiguous_group_references(pyarrow_string_dtype, pattern, repl): + # GH#62653 + ser = Series(["One Two Three", "Foo Bar Baz"], dtype=pyarrow_string_dtype) + + result = ser.str.replace(pattern, repl, regex=True) + expected = Series(["Two0", "Bar0"], dtype=pyarrow_string_dtype) + tm.assert_series_equal(result, expected) + + +@pytest.mark.parametrize( + "pattern, repl, expected_list", + [ + ( + r"\[(?P\d+)\]", + r"(\1)", + ["var.one(0)", "var.two(1)", "var.three(2)"], + ), + ( + r"\[(\d+)\]", + r"(\1)", + ["var.one(0)", "var.two(1)", "var.three(2)"], + ), + ], +) +@td.skip_if_no("pyarrow") +def test_pyarrow_backend_group_replacement(pattern, repl, expected_list): + ser = Series(["var.one[0]", "var.two[1]", "var.three[2]"]).convert_dtypes( + dtype_backend="pyarrow" + ) + result = ser.str.replace(pattern, repl, regex=True) + expected = Series(expected_list).convert_dtypes(dtype_backend="pyarrow") + tm.assert_series_equal(result, expected) + + def test_replace_callable_named_groups(any_string_dtype): # test regex named groups ser = Series(["Foo Bar Baz", np.nan], dtype=any_string_dtype)