@@ -178,31 +178,53 @@ def test_groupby_reduce(
178178 assert_equal (expected_result , result )
179179
180180
181- def gen_array_by (size , func ):
182- by = np .ones (size [- 1 ])
183- rng = np .random .default_rng (12345 )
181+ def maybe_skip_cupy (array_module , func , engine ):
182+ if array_module is np :
183+ return
184+
185+ import cupy
186+
187+ assert array_module is cupy
188+
189+ if engine == "numba" :
190+ pytest .skip ()
191+
192+ if engine == "numpy" and ("prod" in func or "first" in func or "last" in func ):
193+ pytest .xfail ()
194+ elif engine == "flox" and not (
195+ "sum" in func or "mean" in func or "std" in func or "var" in func
196+ ):
197+ pytest .xfail ()
198+
199+
200+ def gen_array_by (size , func , array_module ):
201+ xp = array_module
202+ by = xp .ones (size [- 1 ])
203+ rng = xp .random .default_rng (12345 )
184204 array = rng .random (size )
185205 if "nan" in func and "nanarg" not in func :
186- array [[1 , 4 , 5 ], ...] = np .nan
206+ array [[1 , 4 , 5 ], ...] = xp .nan
187207 elif "nanarg" in func and len (size ) > 1 :
188- array [[1 , 4 , 5 ], 1 ] = np .nan
208+ array [[1 , 4 , 5 ], 1 ] = xp .nan
189209 if func in ["any" , "all" ]:
190210 array = array > 0.5
191211 return array , by
192212
193213
194- @pytest .mark .parametrize ("chunks" , [None , - 1 , 3 , 4 ])
195214@pytest .mark .parametrize ("nby" , [1 , 2 , 3 ])
196215@pytest .mark .parametrize ("size" , ((12 ,), (12 , 9 )))
197- @pytest .mark .parametrize ("add_nan_by " , [True , False ])
216+ @pytest .mark .parametrize ("chunks " , [None , - 1 , 3 , 4 ])
198217@pytest .mark .parametrize ("func" , ALL_FUNCS )
199- def test_groupby_reduce_all (nby , size , chunks , func , add_nan_by , engine ):
218+ @pytest .mark .parametrize ("add_nan_by" , [True , False ])
219+ def test_groupby_reduce_all (nby , size , chunks , func , add_nan_by , engine , array_module ):
200220 if chunks is not None and not has_dask :
201221 pytest .skip ()
202222 if "arg" in func and engine == "flox" :
203223 pytest .skip ()
204224
205- array , by = gen_array_by (size , func )
225+ maybe_skip_cupy (array_module , func , engine )
226+
227+ array , by = gen_array_by (size , func , array_module )
206228 if chunks :
207229 array = dask .array .from_array (array , chunks = chunks )
208230 by = (by ,) * nby
0 commit comments