     RandomPolicy,
 )
 from torchrl.modules import Actor, OrnsteinUhlenbeckProcessModule, SafeModule
-from torchrl.weight_update import SharedMemWeightSyncScheme
+from torchrl.weight_update import (
+    MultiProcessWeightSyncScheme,
+    SharedMemWeightSyncScheme,
+)

 if os.getenv("PYTORCH_TEST_FBCODE"):
     IS_FB = True
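The import hunk above swaps the single SharedMemWeightSyncScheme import for both schemes from torchrl.weight_update, since the tests below parametrize over them. As a minimal sketch of how such a scheme reaches a collector (the Gym env and linear policy are illustrative assumptions, not code from this PR):

import torch.nn as nn
from tensordict.nn import TensorDictModule
from torchrl.collectors import MultiSyncDataCollector
from torchrl.envs import GymEnv
from torchrl.weight_update import MultiProcessWeightSyncScheme

def make_env():
    # Any env factory works; Pendulum is just a stand-in.
    return GymEnv("Pendulum-v1")

policy = TensorDictModule(
    nn.Linear(3, 1), in_keys=["observation"], out_keys=["action"]
)
collector = MultiSyncDataCollector(
    [make_env] * 2,
    policy=policy,
    total_frames=200,
    frames_per_batch=20,
    # One scheme per model to keep in sync; the tests use the "policy" key.
    weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},
)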
@@ -1485,12 +1488,12 @@ def env_fn(seed):

     @pytest.mark.parametrize("use_async", [False, True])
     @pytest.mark.parametrize("cudagraph", [False, True])
+    @pytest.mark.parametrize(
+        "weight_sync_scheme",
+        [None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme],
+    )
     @pytest.mark.skipif(not torch.cuda.is_available(), reason="no cuda device found")
-    def test_update_weights(self, use_async, cudagraph):
-        from torchrl.weight_update.weight_sync_schemes import (
-            MultiProcessWeightSyncScheme,
-        )
-
+    def test_update_weights(self, use_async, cudagraph, weight_sync_scheme):
         def create_env():
             return ContinuousActionVecMockEnv()

@@ -1503,6 +1506,9 @@ def create_env():
         collector_class = (
             MultiSyncDataCollector if not use_async else MultiaSyncDataCollector
         )
+        kwargs = {}
+        if weight_sync_scheme is not None:
+            kwargs["weight_sync_schemes"] = {"policy": weight_sync_scheme()}
         collector = collector_class(
             [create_env] * 3,
             policy=policy,
@@ -1511,7 +1517,7 @@ def create_env():
             frames_per_batch=20,
             cat_results="stack",
             cudagraph_policy=cudagraph,
-            weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},
+            **kwargs,
         )
         assert "policy" in collector._weight_senders, collector._weight_senders.keys()
         try:
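After construction, the collector registers one weight sender per scheme, which is what the _weight_senders assertion above checks. The loop the test then runs amounts to the following sketch; the training step is elided, and the in-place versus pipe-based behavior is our reading of the two schemes rather than a documented guarantee:

try:
    for i, batch in enumerate(collector):
        ...  # optimizer step mutating the trainer-side policy
        # SharedMemWeightSyncScheme: workers see updates through shared buffers.
        # MultiProcessWeightSyncScheme: weights are shipped over the worker pipes.
        collector.update_policy_weights_()
        if i == 2:
            break
finally:
    collector.shutdown()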
@@ -2857,23 +2863,28 @@ def forward(self, td):
             # ["cuda:0", "cuda"],
         ],
     )
-    def test_param_sync(self, give_weights, collector, policy_device, env_device):
-        from torchrl.weight_update.weight_sync_schemes import (
-            MultiProcessWeightSyncScheme,
-        )
-
+    @pytest.mark.parametrize(
+        "weight_sync_scheme",
+        [None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme],
+    )
+    def test_param_sync(
+        self, give_weights, collector, policy_device, env_device, weight_sync_scheme
+    ):
         policy = TestUpdateParams.Policy().to(policy_device)

         env = EnvCreator(lambda: TestUpdateParams.DummyEnv(device=env_device))
         device = env().device
         env = [env]
+        kwargs = {}
+        if weight_sync_scheme is not None:
+            kwargs["weight_sync_schemes"] = {"policy": weight_sync_scheme()}
         col = collector(
             env,
             policy,
             device=device,
             total_frames=200,
             frames_per_batch=10,
-            weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},
+            **kwargs,
         )
         try:
             for i, data in enumerate(col):
@@ -2918,13 +2929,13 @@ def test_param_sync(self, give_weights, collector, policy_device, env_device):
             # ["cuda:0", "cuda"],
         ],
     )
+    @pytest.mark.parametrize(
+        "weight_sync_scheme",
+        [None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme],
+    )
     def test_param_sync_mixed_device(
-        self, give_weights, collector, policy_device, env_device
+        self, give_weights, collector, policy_device, env_device, weight_sync_scheme
     ):
-        from torchrl.weight_update.weight_sync_schemes import (
-            MultiProcessWeightSyncScheme,
-        )
-
         with torch.device("cpu"):
             policy = TestUpdateParams.Policy()
             policy.param = nn.Parameter(policy.param.data.to(policy_device))
@@ -2933,13 +2944,16 @@ def test_param_sync_mixed_device(
         env = EnvCreator(lambda: TestUpdateParams.DummyEnv(device=env_device))
         device = env().device
         env = [env]
+        kwargs = {}
+        if weight_sync_scheme is not None:
+            kwargs["weight_sync_schemes"] = {"policy": weight_sync_scheme()}
         col = collector(
             env,
             policy,
             device=device,
             total_frames=200,
             frames_per_batch=10,
-            weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},
+            **kwargs,
         )
         try:
             for i, data in enumerate(col):
@@ -3851,7 +3865,7 @@ def test_weight_update(self, weight_updater):
         if weight_updater == "scheme_shared":
             kwargs = {"weight_sync_schemes": {"policy": SharedMemWeightSyncScheme()}}
         elif weight_updater == "scheme_pipe":
-            kwargs = {"weight_sync_schemes": {"policy": SharedMemWeightSyncScheme()}}
+            kwargs = {"weight_sync_schemes": {"policy": MultiProcessWeightSyncScheme()}}
         elif weight_updater == "weight_updater":
             kwargs = {"weight_updater": self.MPSWeightUpdaterBase(policy_weights, 2)}
         else:
@@ -3870,14 +3884,16 @@ def test_weight_update(self, weight_updater):
             **kwargs,
         )

-        collector.update_policy_weights_()
+        # When using policy_factory, weights must be passed explicitly
+        collector.update_policy_weights_(policy_weights)
         try:
             for i, data in enumerate(collector):
                 if i == 2:
                     assert (data["action"] != 0).any()
                     # zero the policy
                     policy_weights.data.zero_()
-                    collector.update_policy_weights_()
+                    # When using policy_factory, weights must be passed explicitly
+                    collector.update_policy_weights_(policy_weights)
                 elif i == 3:
                     assert (data["action"] == 0).all(), data["action"]
                     break
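The behavioral change in this hunk: with a policy_factory, the collector holds no trainer-side policy instance to read weights from, so update_policy_weights_() now receives them explicitly. A hedged sketch of that pattern; make_policy is a hypothetical factory, and TensorDict.from_module is one way a test-style policy_weights tensordict can be built:

from tensordict import TensorDict

policy_weights = TensorDict.from_module(make_policy())  # trainer-side copy
collector = MultiSyncDataCollector(
    [make_env] * 2,
    policy_factory=make_policy,  # each worker constructs its own policy
    total_frames=200,
    frames_per_batch=20,
    weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},
)
# No policy attribute to introspect, so pass the weights explicitly:
collector.update_policy_weights_(policy_weights)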
@@ -3973,11 +3989,11 @@ def test_start_multi(self, total_frames, cls):
     @pytest.mark.parametrize(
         "cls", [SyncDataCollector, MultiaSyncDataCollector, MultiSyncDataCollector]
     )
-    def test_start_update_policy(self, total_frames, cls):
-        from torchrl.weight_update.weight_sync_schemes import (
-            MultiProcessWeightSyncScheme,
-        )
-
+    @pytest.mark.parametrize(
+        "weight_sync_scheme",
+        [None, MultiProcessWeightSyncScheme, SharedMemWeightSyncScheme],
+    )
+    def test_start_update_policy(self, total_frames, cls, weight_sync_scheme):
         rb = ReplayBuffer(storage=LazyMemmapStorage(max_size=1000))
         env = CountingEnv()
         m = nn.Linear(env.observation_spec["observation"].shape[-1], 1)
@@ -3998,8 +4014,8 @@ def test_start_update_policy(self, total_frames, cls):

         # Add weight sync schemes for multi-process collectors
         kwargs = {}
-        if cls != SyncDataCollector:
-            kwargs["weight_sync_schemes"] = {"policy": MultiProcessWeightSyncScheme()}
+        if cls != SyncDataCollector and weight_sync_scheme is not None:
+            kwargs["weight_sync_schemes"] = {"policy": weight_sync_scheme()}

         collector = cls(
             env,
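test_start_update_policy (truncated above) exercises background collection: start() runs the workers while the main thread trains from the replay buffer and pushes weights. A sketch of that pattern as we understand it from the test; the replay_buffer wiring and the async_shutdown() call are assumptions to verify against the docs:

collector = MultiSyncDataCollector(
    [make_env] * 2,
    policy=policy,
    total_frames=-1,  # run until explicitly shut down
    frames_per_batch=20,
    replay_buffer=rb,  # workers write straight into the buffer
    weight_sync_schemes={"policy": MultiProcessWeightSyncScheme()},
)
collector.start()  # non-blocking; collection happens in the background
for _ in range(5):
    ...  # sample from rb and take an optimizer step
    collector.update_policy_weights_()
collector.async_shutdown()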