Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 2 additions & 5 deletions pandas/core/arrays/_arrow_string_mixins.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,15 +174,12 @@ def _str_replace(
or callable(repl)
or not case
or flags
or (
isinstance(repl, str)
and (r"\g<" in repl or re.search(r"\\\d", repl) is not None)
)
or (isinstance(repl, str) and r"\g<" in repl)
):
raise NotImplementedError(
"replace is not supported with a re.Pattern, callable repl, "
"case=False, flags!=0, or when the replacement string contains "
"named group references (\\g<...>, \\d+)"
"named group references (\\g<...>)"
)

func = pc.replace_substring_regex if regex else pc.replace_substring
Expand Down
2 changes: 1 addition & 1 deletion pandas/core/arrays/string_arrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,7 @@ def _str_replace(
or ( # substitution contains a named group pattern
# https://docs.python.org/3/library/re.html
isinstance(repl, str)
and (r"\g<" in repl or re.search(r"\\\d", repl) is not None)
and r"\g<" in repl
)
):
return super()._str_replace(pat, repl, n, case, flags, regex)
Expand Down
71 changes: 69 additions & 2 deletions pandas/tests/strings/test_find_replace.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import pandas as pd
from pandas import (
Series,
StringDtype,
_testing as tm,
)
from pandas.tests.strings import (
Expand Down Expand Up @@ -601,6 +602,10 @@ def test_replace_callable_raises(any_string_dtype, repl):
r"\g<three> \g<two> \g<one>",
["Three Two One", "Baz Bar Foo"],
),
(
r"\3 \2 \1",
["Three Two One", "Baz Bar Foo"],
),
(
r"\g<3> \g<2> \g<1>",
["Three Two One", "Baz Bar Foo"],
Expand All @@ -616,6 +621,7 @@ def test_replace_callable_raises(any_string_dtype, repl):
],
ids=[
"named_groups_full_swap",
"numbered_groups_no_g_full_swap",
"numbered_groups_full_swap",
"single_group_with_literal",
"mixed_group_reference_with_literal",
Expand All @@ -640,22 +646,83 @@ def test_replace_named_groups_regex_swap(
[
r"\g<20>",
r"\20",
r"\40",
r"\4",
],
)
@pytest.mark.parametrize("use_compile", [True, False])
def test_replace_named_groups_regex_swap_expected_fail(
any_string_dtype, repl, use_compile
any_string_dtype, repl, use_compile, request
):
# GH#57636
if (
not use_compile
and r"\g" not in repl
and isinstance(any_string_dtype, StringDtype)
and any_string_dtype.storage == "pyarrow"
):
# calls pyarrow method directly
if repl == r"\20":
mark = pytest.mark.xfail(reason="PyArrow interprets as group + literal")
request.applymarker(mark)

pa = pytest.importorskip("pyarrow")
error_type = pa.ArrowInvalid
error_msg = r"only has \d parenthesized subexpressions"
else:
error_type = re.error
error_msg = "invalid group reference"

pattern = r"(?P<one>\w+) (?P<two>\w+) (?P<three>\w+)"
if use_compile:
pattern = re.compile(pattern)
ser = Series(["One Two Three", "Foo Bar Baz"], dtype=any_string_dtype)

with pytest.raises(re.error, match="invalid group reference"):
with pytest.raises(error_type, match=error_msg):
ser.str.replace(pattern, repl, regex=True)


@pytest.mark.parametrize(
    "pattern, repl",
    [
        (r"(\w+) (\w+) (\w+)", r"\20"),
        (r"(?P<one>\w+) (?P<two>\w+) (?P<three>\w+)", r"\20"),
    ],
)
def test_pyarrow_ambiguous_group_references(pyarrow_string_dtype, pattern, repl):
    # GH#62653
    # Both patterns define exactly three capture groups, so "\20" cannot name
    # a twentieth group. The expected values below pin the outcome for the
    # pyarrow string dtype: the text of group 2 followed by a literal "0".
    dtype = pyarrow_string_dtype
    expected = Series(["Two0", "Bar0"], dtype=dtype)

    strings = Series(["One Two Three", "Foo Bar Baz"], dtype=dtype)
    replaced = strings.str.replace(pattern, repl, regex=True)

    tm.assert_series_equal(replaced, expected)


@pytest.mark.parametrize(
    "pattern, repl, expected_list",
    [
        (
            r"\[(?P<one>\d+)\]",
            r"(\1)",
            ["var.one(0)", "var.two(1)", "var.three(2)"],
        ),
        (
            r"\[(\d+)\]",
            r"(\1)",
            ["var.one(0)", "var.two(1)", "var.three(2)"],
        ),
    ],
)
@td.skip_if_no("pyarrow")
def test_pyarrow_backend_group_replacement(pattern, repl, expected_list):
    # A bare numbered reference ("\1") in the replacement string is exercised
    # against data converted with dtype_backend="pyarrow", once with a named
    # capture group and once with an unnamed one; both cases expect the
    # bracketed index to be rewritten in parentheses.
    def to_pyarrow_backed(values):
        # Build the input and the expected result through the same conversion
        # so the two Series share a dtype.
        return Series(values).convert_dtypes(dtype_backend="pyarrow")

    source = to_pyarrow_backed(["var.one[0]", "var.two[1]", "var.three[2]"])
    replaced = source.str.replace(pattern, repl, regex=True)

    tm.assert_series_equal(replaced, to_pyarrow_backed(expected_list))


def test_replace_callable_named_groups(any_string_dtype):
# test regex named groups
ser = Series(["Foo Bar Baz", np.nan], dtype=any_string_dtype)
Expand Down
Loading