diff --git a/examples/online/config/dmc/algo/aca.yaml b/examples/online/config/dmc/algo/aca.yaml new file mode 100644 index 0000000..d40491a --- /dev/null +++ b/examples/online/config/dmc/algo/aca.yaml @@ -0,0 +1,41 @@ +# @package _global_ + +algo: + name: aca + target_update_freq: 1 + feature_dim: 512 + rff_dim: 1024 + critic_hidden_dims: [512, 512] + reward_hidden_dims: [512, 512] + phi_hidden_dims: [512, 512] + mu_hidden_dims: [512, 512] + ctrl_coef: 1.0 + reward_coef: 1.0 + critic_coef: 1.0 + critic_activation: elu # not used + back_critic_grad: false + feature_lr: 0.0001 + critic_lr: 0.0003 + discount: 0.99 + num_samples: 10 + ema: 0.005 + feature_ema: 0.005 + clip_grad_norm: null + temp: 0.1 + diffusion: + time_dim: 64 + mlp_hidden_dims: [512, 512, 512] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + solver: ddpm + num_noises: 25 + linear: false + ranking: true + +norm_obs: true diff --git a/examples/online/config/dmc/algo/ctrl_qsm.yaml b/examples/online/config/dmc/algo/ctrl_qsm.yaml new file mode 100644 index 0000000..af85061 --- /dev/null +++ b/examples/online/config/dmc/algo/ctrl_qsm.yaml @@ -0,0 +1,49 @@ +# @package _global_ + +algo: + name: ctrl_qsm + actor_update_freq: 1 + target_update_freq: 1 + discount: 0.99 + ema: 0.005 + # critic_hidden_dims: [512, 512, 512] # not used + critic_activation: elu # not used + critic_ensemble_size: 2 + layer_norm: true + critic_lr: 0.0003 + clip_grad_norm: null + + # below are params specific to ctrl_td3 + feature_dim: 512 + feature_lr: 0.0001 + feature_ema: 0.005 + phi_hidden_dims: [512, 512] + mu_hidden_dims: [512, 512] + critic_hidden_dims: [512, ] + reward_hidden_dims: [512, ] + rff_dim: 1024 + ctrl_coef: 1.0 + reward_coef: 1.0 + back_critic_grad: false + critic_coef: 1.0 + + num_noises: 25 + linear: false + ranking: true + + num_samples: 10 + temp: 0.1 + diffusion: + time_dim: 64 + mlp_hidden_dims: [512, 512, 512] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + solver: ddpm + +norm_obs: true diff --git a/examples/online/config/dmc/algo/qsm.yaml b/examples/online/config/dmc/algo/qsm.yaml new file mode 100644 index 0000000..12b9624 --- /dev/null +++ b/examples/online/config/dmc/algo/qsm.yaml @@ -0,0 +1,23 @@ +# @package _global_ + +algo: + name: qsm + critic_hidden_dims: [512, 512, 512] + critic_activation: elu + critic_lr: 0.0003 + discount: 0.99 + num_samples: 10 + ema: 0.005 + temp: 0.1 + diffusion: + time_dim: 64 + mlp_hidden_dims: [512, 512, 512] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + solver: ddpm diff --git a/examples/online/config/mujoco/algo/alac.yaml b/examples/online/config/mujoco/algo/alac.yaml new file mode 100644 index 0000000..94fe9e3 --- /dev/null +++ b/examples/online/config/mujoco/algo/alac.yaml @@ -0,0 +1,24 @@ +# @package _global_ + +algo: + name: alac + discount: 0.99 + num_samples: 10 + ema: 0.005 + ld: + resnet: false + activation: relu + ensemble_size: 2 + time_dim: 64 + hidden_dims: [512, 512] + cond_hidden_dims: [128, 128] + steps: 20 + step_size: 0.05 + noise_scale: 1.0 + noise_schedule: "none" + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + epsilon: 0.001 + lr: 0.0003 + clip_grad_norm: null diff --git a/examples/online/config/mujoco/algo/idem.yaml b/examples/online/config/mujoco/algo/idem.yaml new file mode 100644 index 0000000..f4a9b36 --- 
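# --- Illustrative sketch (annotation, not part of the patch) ---------------
# The algo YAML files above are Hydra groups marked "# @package _global_",
# so their contents are lifted to the top level of the run config: keys
# under `algo:` land at cfg.algo.* and file-level keys such as `norm_obs`
# override the root config. A rough stand-in using plain OmegaConf (Hydra
# performs the real composition; the base values here are placeholders):
from omegaconf import OmegaConf

base = OmegaConf.create({"task": "walker-walk", "norm_obs": False, "algo": {}})
algo_file = OmegaConf.create({
    "algo": {"name": "qsm", "discount": 0.99, "num_samples": 10, "temp": 0.1},
    "norm_obs": True,  # top-level key set by the algo file
})
cfg = OmegaConf.merge(base, algo_file)
assert cfg.algo.name == "qsm" and cfg.norm_obs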
/dev/null +++ b/examples/online/config/mujoco/algo/idem.yaml @@ -0,0 +1,23 @@ +# @package _global_ + +algo: + name: idem + critic_hidden_dims: [256, 256] + critic_lr: 0.0003 + discount: 0.99 + num_samples: 10 + num_reverse_samples: 500 + ema: 0.005 + temp: 0.2 + diffusion: + time_dim: 64 + mlp_hidden_dims: [256, 256] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + solver: ddpm diff --git a/examples/online/config/mujoco/algo/qsm.yaml b/examples/online/config/mujoco/algo/qsm.yaml new file mode 100644 index 0000000..c646df7 --- /dev/null +++ b/examples/online/config/mujoco/algo/qsm.yaml @@ -0,0 +1,23 @@ +# @package _global_ + +algo: + name: qsm + critic_hidden_dims: [512, 512] + critic_activation: relu + critic_lr: 0.0003 + discount: 0.99 + num_samples: 10 + ema: 0.005 + temp: 0.1 + diffusion: + time_dim: 64 + mlp_hidden_dims: [512, 512] + lr: 0.0003 + end_lr: null + lr_decay_steps: null + lr_decay_begin: null + steps: 20 + clip_sampler: true + x_min: -1.0 + x_max: 1.0 + solver: ddpm diff --git a/examples/online/config/mujoco/config.yaml b/examples/online/config/mujoco/config.yaml index f662225..7658820 100644 --- a/examples/online/config/mujoco/config.yaml +++ b/examples/online/config/mujoco/config.yaml @@ -27,7 +27,6 @@ random_frames: 5_000 eval_frames: 10_000 log_frames: 1_000 lap_reset_frames: 250 -eval_episodes: 10 log: dir: logs tag: debug diff --git a/examples/online/main_dmc_offpolicy.py b/examples/online/main_dmc_offpolicy.py index 0f8e59f..cbd1c71 100644 --- a/examples/online/main_dmc_offpolicy.py +++ b/examples/online/main_dmc_offpolicy.py @@ -26,7 +26,10 @@ "td7": TD7Agent, "sdac": SDACAgent, "dpmd": DPMDAgent, + "qsm": QSMAgent, "ctrl_td3": CtrlTD3Agent, + "ctrl_qsm": CtrlQSMAgent, + "aca": ACAAgent, } class OffPolicyTrainer(): diff --git a/examples/online/main_mujoco_offpolicy.py b/examples/online/main_mujoco_offpolicy.py index f408493..4432994 100644 --- a/examples/online/main_mujoco_offpolicy.py +++ b/examples/online/main_mujoco_offpolicy.py @@ -6,10 +6,10 @@ import jax import numpy as np import omegaconf -import wandb from omegaconf import OmegaConf from tqdm import tqdm +import wandb from flowrl.agent.online import * from flowrl.config.online.mujoco import Config from flowrl.dataset.buffer.state import ReplayBuffer @@ -25,6 +25,9 @@ "td7": TD7Agent, "sdac": SDACAgent, "dpmd": DPMDAgent, + "qsm": QSMAgent, + "idem": IDEMAgent, + "alac": ALACAgent, } class OffPolicyTrainer(): diff --git a/flowrl/agent/online/__init__.py b/flowrl/agent/online/__init__.py index 9041320..4bc4767 100644 --- a/flowrl/agent/online/__init__.py +++ b/flowrl/agent/online/__init__.py @@ -1,11 +1,15 @@ from ..base import BaseAgent -from .ctrl.ctrl import CtrlTD3Agent +from .alac.alac import ALACAgent +from .ctrl import * from .dpmd import DPMDAgent +from .idem import IDEMAgent from .ppo import PPOAgent +from .qsm import QSMAgent from .sac import SACAgent from .sdac import SDACAgent from .td3 import TD3Agent from .td7.td7 import TD7Agent +from .unirep import * __all__ = [ "BaseAgent", @@ -15,5 +19,10 @@ "SDACAgent", "DPMDAgent", "PPOAgent", + "QSMAgent", + "IDEMAgent", + "ALACAgent", "CtrlTD3Agent", + "CtrlQSMAgent", + "ACAAgent", ] diff --git a/flowrl/agent/online/alac/alac.py b/flowrl/agent/online/alac/alac.py new file mode 100644 index 0000000..45bbe3c --- /dev/null +++ b/flowrl/agent/online/alac/alac.py @@ -0,0 +1,339 @@ +from functools import partial +from typing import Tuple + +import flax.linen as nn +import jax 
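# --- Illustrative sketch (annotation, not part of the patch) ---------------
# The trainer scripts above register the new agents in a name -> class dict
# and look them up via cfg.algo.name. The SUPPORTED_AGENTS/make_agent names
# below are hypothetical; the constructor signature (obs_dim, act_dim, cfg,
# seed) matches the agents added in this patch.
from flowrl.agent.online import ALACAgent, IDEMAgent, QSMAgent

SUPPORTED_AGENTS = {
    "qsm": QSMAgent,
    "idem": IDEMAgent,
    "alac": ALACAgent,
}

def make_agent(name, obs_dim, act_dim, cfg, seed):
    try:
        agent_cls = SUPPORTED_AGENTS[name]
    except KeyError as e:
        raise ValueError(f"unknown algo {name!r}; known: {sorted(SUPPORTED_AGENTS)}") from e
    return agent_cls(obs_dim, act_dim, cfg, seed)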
+import jax.numpy as jnp +import optax + +from flowrl.agent.base import BaseAgent +from flowrl.agent.online.alac.network import EnsembleEnergyNet +from flowrl.config.online.mujoco.algo.alac import ALACConfig +from flowrl.flow.langevin_dynamics import AnnealedLangevinDynamics +from flowrl.functional.activation import mish +from flowrl.functional.ema import ema_update +from flowrl.module.mlp import MLP, ResidualMLP +from flowrl.module.time_embedding import LearnableFourierEmbedding +from flowrl.types import Batch, Metric, Param, PRNGKey + + +@partial(jax.jit, static_argnames=("training", "num_samples")) +def jit_sample_actions( + rng: PRNGKey, + actor: AnnealedLangevinDynamics, + ld: AnnealedLangevinDynamics, + obs, + training: bool, + num_samples: int +) -> Tuple[PRNGKey, jnp.ndarray]: + assert len(obs.shape) == 2 + B = obs.shape[0] + rng, x_init_rng = jax.random.split(rng) + obs_repeat = obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2) + x_init = jax.random.normal(x_init_rng, (*obs_repeat.shape[:-1], actor.x_dim)) + rng, actions, _ = actor.sample(rng, x_init, obs_repeat, training=training) + if num_samples == 1: + actions = actions[:, 0] + else: + qs = ld(actions, t=jnp.zeros((B, num_samples, 1), dtype=jnp.float32), condition=obs_repeat) + qs = qs.min(axis=0).reshape(B, num_samples) + best_idx = qs.argmax(axis=-1) + actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), best_idx] + return rng, actions + +@partial(jax.jit, static_argnames=("discount", "ema")) +def jit_update_ld( + rng: PRNGKey, + ld: AnnealedLangevinDynamics, + ld_target: AnnealedLangevinDynamics, + actor: AnnealedLangevinDynamics, + batch: Batch, + discount: float, + ema: float, +) -> Tuple[PRNGKey, AnnealedLangevinDynamics, AnnealedLangevinDynamics, Metric]: + B, A = batch.action.shape[0], batch.action.shape[1] + feed_t = jnp.zeros((B, 1), dtype=jnp.float32) + + rng, next_xT_rng = jax.random.split(rng) + # next_action_init = jax.random.normal(next_xT_rng, (*batch.next_obs.shape[:-1], ld.x_dim)) + next_action_init = jax.random.normal(next_xT_rng, (*batch.next_obs.shape[:-1], actor.x_dim)) + # rng, next_action, history = ld_target.sample( + # rng, + # next_action_init, + # batch.next_obs, + # training=False, + # ) + rng, next_action, history = actor.sample( + rng, + next_action_init, + batch.next_obs, + training=False, + ) + q_target = ld_target(next_action, feed_t, batch.next_obs, training=False) + # q_target = ld_target(batch.next_obs, next_action, training=False) + q_target = batch.reward + discount * (1 - batch.terminal) * q_target.min(axis=0) + + def ld_loss_fn(params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + q_pred = ld.apply( + {"params": params}, + batch.action, + t=feed_t, + condition=batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + # q_pred = ld.apply( + # {"params": params}, + # batch.obs, + # batch.action, + # training=True, + # rngs={"dropout": dropout_rng}, + # ) + ld_loss = ((q_pred - q_target[jnp.newaxis, :])**2).mean() + return ld_loss, { + "loss/ld_loss": ld_loss, + "misc/q_mean": q_pred.mean(), + "misc/reward": batch.reward.mean(), + # "misc/q_grad_l1": jnp.abs(history[1]).mean(), + } + + new_ld, ld_metrics = ld.apply_gradient(ld_loss_fn) + new_ld_target = ema_update(new_ld, ld_target, ema) + + # record energy + # num_checkpoints = 5 + # stepsize_checkpoint = ld.steps // num_checkpoints + # energy_history = history[2][jnp.arange(0, ld.steps, stepsize_checkpoint)] + # energy_history = energy_history.mean(axis=[-2, -1]) + # ld_metrics.update({ + # 
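# --- Illustrative sketch (annotation, not part of the patch) ---------------
# Worked example of the TD target built in jit_update_ld above: the target
# ensemble is evaluated on (s', a'), reduced with a min over the ensemble
# axis (clipped double Q), then backed up through the reward. Toy values.
import jax.numpy as jnp

q_next = jnp.array([[1.0, 2.0],
                    [0.5, 2.5]])          # (ensemble=2, batch=2)
reward = jnp.array([0.1, 0.2])
terminal = jnp.array([0.0, 1.0])
discount = 0.99
q_target = reward + discount * (1.0 - terminal) * q_next.min(axis=0)
# q_target == [0.1 + 0.99 * 0.5, 0.2] == [0.595, 0.2]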
f"info/energy_step{i}": energy for i, energy in enumerate(energy_history) + # }) + + return rng, new_ld, new_ld_target, ld_metrics + + +@partial(jax.jit, static_argnames=()) +def jit_update_actor( + rng: PRNGKey, + actor: AnnealedLangevinDynamics, + critic_target: AnnealedLangevinDynamics, + batch: Batch, +) -> Tuple[PRNGKey, AnnealedLangevinDynamics, Metric]: + x0 = batch.action + rng, xt, t, eps = actor.add_noise(rng, x0) + # rng, t_rng, noise_rng = jax.random.split(rng, 3) + # t = jax.random.uniform(t_rng, (*x0.shape[:-1], 1), dtype=jnp.float32, minval=actor.t_diffusion[0], maxval=actor.t_diffusion[1]) + # eps = jax.random.normal(noise_rng, x0.shape, dtype=jnp.float32) + alpha, sigma = actor.noise_schedule_func(t) + xt = alpha * x0 + sigma * eps + + q_grad_fn = jax.vmap(jax.grad(lambda a, s: critic_target(a, None, condition=s).min(axis=0).mean())) + # q_grad_fn = jax.vmap(jax.grad(lambda a, s: critic_target(s, a).min(axis=0).mean())) + q_grad = q_grad_fn(xt, batch.obs) + q_grad = alpha * q_grad - sigma * xt + eps_estimation = sigma * q_grad / (jnp.abs(q_grad).mean() + 1e-6) + + def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + eps_pred = actor.apply( + {"params": actor_params}, + xt, + t, + condition=batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + loss = ((eps_pred - eps_estimation) ** 2).mean() + return loss, { + "loss/actor_loss": loss, + "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), + } + + new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) + return rng, new_actor, actor_metrics + + +class ALACAgent(BaseAgent): + """ + Annealed Langevin Dynamics Actor-Critic (ALAC) agent. + """ + name = "ALACAgent" + model_names = ["actor", "critic", "critic_target"] + + def __init__(self, obs_dim: int, act_dim: int, cfg: ALACConfig, seed: int): + super().__init__(obs_dim, act_dim, cfg, seed) + self.cfg = cfg + self.rng, ld_rng = jax.random.split(self.rng, 2) + + # define the critic + # from flowrl.module.critic import EnsembleCritic + # from flowrl.module.model import Model + # critic_def = EnsembleCritic( + # hidden_dims=cfg.ld.hidden_dims, + # activation=jax.nn.relu, + # layer_norm=False, + # dropout=None, + # ensemble_size=2, + # ) + # self.ld = Model.create( + # critic_def, + # ld_rng, + # inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim))), + # optimizer=optax.adam(learning_rate=cfg.ld.lr), + # ) + # self.ld_target = Model.create( + # critic_def, + # ld_rng, + # inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim))), + # ) + + mlp_impl = ResidualMLP if cfg.ld.resnet else MLP + activation = {"mish": mish, "relu": jax.nn.relu}[cfg.ld.activation] + energy_def = EnsembleEnergyNet( + mlp_impl=mlp_impl, + hidden_dims=cfg.ld.hidden_dims, + output_dim=1, + activation=activation, + layer_norm=False, + dropout=None, + ensemble_size=cfg.ld.ensemble_size, + # time_embedding=partial(LearnableFourierEmbedding, output_dim=cfg.ld.time_dim), + time_embedding=None, + # cond_embedding=partial(MLP, hidden_dims=cfg.ld.cond_hidden_dims, activation=activation), + cond_embedding=None, + ) + self.ld = AnnealedLangevinDynamics.create( + network=energy_def, + rng=ld_rng, + inputs=(jnp.ones((1, self.act_dim)), jnp.ones((1, 1)), jnp.ones((1, self.obs_dim))), + x_dim=self.act_dim, + grad_prediction=False, + steps=cfg.ld.steps, + step_size=cfg.ld.step_size, + noise_scale=cfg.ld.noise_scale, + noise_schedule=cfg.ld.noise_schedule, + noise_schedule_params={}, + clip_sampler=cfg.ld.clip_sampler, + x_min=cfg.ld.x_min, 
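# --- Illustrative sketch (annotation, not part of the patch) ---------------
# The actor update above perturbs the batch action with the forward noising
# x_t = alpha(t) * x_0 + sigma(t) * eps. The cosine schedule written out
# here is an assumption standing in for actor.noise_schedule_func; the real
# schedule lives in flowrl.flow and may differ in parameterization.
import jax
import jax.numpy as jnp

def cosine_schedule(t):                 # t in [0, 1]
    alpha = jnp.cos(0.5 * jnp.pi * t)
    sigma = jnp.sin(0.5 * jnp.pi * t)   # alpha**2 + sigma**2 == 1
    return alpha, sigma

rng = jax.random.PRNGKey(0)
x0 = jnp.array([[0.3, -0.7]])           # a clean action
t = jnp.array([[0.5]])
eps = jax.random.normal(rng, x0.shape)
alpha, sigma = cosine_schedule(t)
xt = alpha * x0 + sigma * eps           # noised action fed to the networks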
+ x_max=cfg.ld.x_max, + t_schedule_n=1.0, + epsilon=cfg.ld.epsilon, + optimizer=optax.adam(learning_rate=cfg.ld.lr), + clip_grad_norm=cfg.ld.clip_grad_norm, + ) + self.ld_target = AnnealedLangevinDynamics.create( + network=energy_def, + rng=ld_rng, + inputs=(jnp.ones((1, self.act_dim)), jnp.ones((1, 1)), jnp.ones((1, self.obs_dim))), + x_dim=self.act_dim, + grad_prediction=False, + steps=cfg.ld.steps, + step_size=cfg.ld.step_size, + noise_scale=cfg.ld.noise_scale, + noise_schedule=cfg.ld.noise_schedule, + noise_schedule_params={}, + clip_sampler=cfg.ld.clip_sampler, + x_min=cfg.ld.x_min, + x_max=cfg.ld.x_max, + t_schedule_n=1.0, + epsilon=cfg.ld.epsilon, + ) + + # DEBUG define the actor + from flowrl.flow.continuous_ddpm import ContinuousDDPM, ContinuousDDPMBackbone + self.rng, actor_rng = jax.random.split(self.rng, 2) + time_embedding = partial[LearnableFourierEmbedding](LearnableFourierEmbedding, output_dim=cfg.ld.time_dim) + cond_embedding = partial(MLP, hidden_dims=[128, 128], activation=mish) + noise_predictor = partial( + MLP, + hidden_dims=cfg.ld.hidden_dims, + output_dim=act_dim, + activation=mish, + layer_norm=False, + dropout=None, + ) + backbone_def = ContinuousDDPMBackbone( + noise_predictor=noise_predictor, + time_embedding=time_embedding, + cond_embedding=cond_embedding, + ) + # self.actor = ContinuousDDPM.create( + # network=backbone_def, + # rng=actor_rng, + # inputs=(jnp.ones((1, self.act_dim)), jnp.zeros((1, 1)), jnp.ones((1, self.obs_dim)), ), + # x_dim=self.act_dim, + # steps=cfg.ld.steps, + # noise_schedule="cosine", + # noise_schedule_params={}, + # clip_sampler=cfg.ld.clip_sampler, + # x_min=cfg.ld.x_min, + # x_max=cfg.ld.x_max, + # t_schedule_n=1.0, + # optimizer=optax.adam(learning_rate=cfg.ld.lr), + # ) + self.actor = AnnealedLangevinDynamics.create( + network=backbone_def, + rng=actor_rng, + inputs=(jnp.ones((1, self.act_dim)), jnp.ones((1, 1)), jnp.ones((1, self.obs_dim))), + x_dim=self.act_dim, + grad_prediction=True, + steps=cfg.ld.steps, + step_size=cfg.ld.step_size, + noise_scale=cfg.ld.noise_scale, + noise_schedule="cosine", + noise_schedule_params={}, + clip_sampler=cfg.ld.clip_sampler, + x_min=cfg.ld.x_min, + x_max=cfg.ld.x_max, + t_schedule_n=1.0, + epsilon=cfg.ld.epsilon, + optimizer=optax.adam(learning_rate=cfg.ld.lr), + clip_grad_norm=cfg.ld.clip_grad_norm, + ) + + # define tracking variables + self._n_training_steps = 0 + + def train_step(self, batch: Batch, step: int) -> Metric: + self.rng, self.ld, self.ld_target, ld_metrics = jit_update_ld( + self.rng, + self.ld, + self.ld_target, + self.actor, + batch, + self.cfg.discount, + self.cfg.ema, + ) + self.rng, self.actor, actor_metrics = jit_update_actor( + self.rng, + self.actor, + self.ld_target, + batch, + ) + + self._n_training_steps += 1 + return {**ld_metrics, **actor_metrics} + + def sample_actions( + self, + obs: jnp.ndarray, + deterministic: bool = True, + num_samples: int = 1, + ) -> Tuple[jnp.ndarray, Metric]: + # if deterministic is true, sample cfg.num_samples actions and select the best one + # if not, sample 1 action + if deterministic: + num_samples = self.cfg.num_samples + else: + num_samples = 1 + self.rng, action = jit_sample_actions( + self.rng, + # self.ld, + self.actor, + self.ld, + obs, + training=False, + num_samples=num_samples, + ) + if not deterministic: + action = action + 0.1 * jax.random.normal(self.rng, action.shape) + return action, {} diff --git a/flowrl/agent/online/alac/network.py b/flowrl/agent/online/alac/network.py new file mode 100644 index 0000000..6f9e05b --- 
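# --- Illustrative sketch (annotation, not part of the patch) ---------------
# Best-of-N action selection as used by sample_actions/jit_sample_actions
# above: N candidate actions per state are scored with the pessimistic (min
# over ensemble) value and the argmax per state is kept. Toy shapes only.
import jax.numpy as jnp

B, N, A = 2, 4, 3
actions = jnp.zeros((B, N, A))                                   # candidates
qs = jnp.arange(2 * B * N, dtype=jnp.float32).reshape(2, B, N)   # (ensemble, B, N)
qs = qs.min(axis=0)                                              # (B, N)
best_idx = qs.argmax(axis=-1)                                    # (B,)
best_actions = actions[jnp.arange(B), best_idx]                  # (B, A)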
/dev/null +++ b/flowrl/agent/online/alac/network.py @@ -0,0 +1,87 @@ +import flax.linen as nn +import jax.numpy as jnp + +from flowrl.functional.activation import mish +from flowrl.module.mlp import MLP +from flowrl.types import * + + +class EnergyNet(nn.Module): + mlp_impl: nn.Module + hidden_dims: Sequence[int] + output_dim: int = 1 + activation: Callable = nn.relu + layer_norm: bool = False + dropout: Optional[float] = None + cond_embedding: Optional[nn.Module] = None + time_embedding: Optional[nn.Module] = None + + @nn.compact + def __call__( + self, + x: jnp.ndarray, + t: Optional[jnp.ndarray] = None, + condition: Optional[jnp.ndarray] = None, + training: bool = False, + ) -> jnp.ndarray: + if condition is not None: + if self.cond_embedding is not None: + condition = self.cond_embedding()(condition, training=training) + else: + condition = condition + x = jnp.concatenate([x, condition], axis=-1) + if self.time_embedding is not None: + t_ff = self.time_embedding()(t) + t_ff = MLP( + hidden_dims=[t_ff.shape[-1], t_ff.shape[-1]], + activation=mish, + )(t_ff) + x = jnp.concatenate([x, t_ff], axis=-1) + x = self.mlp_impl( + hidden_dims=self.hidden_dims, + output_dim=self.output_dim, + activation=self.activation, + layer_norm=self.layer_norm, + dropout=self.dropout, + )(x, training) + return x + + +class EnsembleEnergyNet(nn.Module): + mlp_impl: nn.Module + hidden_dims: Sequence[int] + output_dim: int = 1 + activation: Callable = nn.relu + layer_norm: bool = False + dropout: Optional[float] = None + cond_embedding: Optional[nn.Module] = None + time_embedding: Optional[nn.Module] = None + ensemble_size: int = 2 + + @nn.compact + def __call__( + self, + x: jnp.ndarray, + t: Optional[jnp.ndarray] = None, + condition: Optional[jnp.ndarray] = None, + training: bool = False, + ) -> jnp.ndarray: + vmap_energy_net = nn.vmap( + EnergyNet, + variable_axes={"params": 0}, + split_rngs={"params": True, "dropout": True}, + in_axes=None, + out_axes=0, + axis_size=self.ensemble_size + ) + x = vmap_energy_net( + mlp_impl=self.mlp_impl, + hidden_dims=self.hidden_dims, + output_dim=self.output_dim, + activation=self.activation, + layer_norm=self.layer_norm, + dropout=self.dropout, + cond_embedding=self.cond_embedding, + time_embedding=self.time_embedding, + )(x, t, condition, training) + return x diff --git a/flowrl/agent/online/ctrl/__init__.py b/flowrl/agent/online/ctrl/__init__.py new file mode 100644 index 0000000..602a563 --- /dev/null +++ b/flowrl/agent/online/ctrl/__init__.py @@ -0,0 +1,7 @@ +from .ctrl_qsm import CtrlQSMAgent +from .ctrl_td3 import CtrlTD3Agent + +__all__ = [ + "CtrlTD3Agent", + "CtrlQSMAgent", +] diff --git a/flowrl/agent/online/ctrl/ctrl_qsm.py b/flowrl/agent/online/ctrl/ctrl_qsm.py new file mode 100644 index 0000000..2960ae1 --- /dev/null +++ b/flowrl/agent/online/ctrl/ctrl_qsm.py @@ -0,0 +1,287 @@ +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp +import optax + +from flowrl.agent.online.ctrl.network import FactorizedNCE, update_factorized_nce +from flowrl.agent.online.qsm import QSMAgent +from flowrl.config.online.mujoco.algo.ctrl.ctrl_qsm import CtrlQSMConfig +from flowrl.flow.continuous_ddpm import ContinuousDDPM +from flowrl.functional.ema import ema_update +from flowrl.module.model import Model +from flowrl.module.rff import RffEnsembleCritic +from flowrl.types import Batch, Metric, Param, PRNGKey + + +@partial(jax.jit, static_argnames=("training", "num_samples", "solver")) +def jit_sample_actions( + rng: PRNGKey, + actor: 
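# --- Illustrative sketch (annotation, not part of the patch) ---------------
# EnsembleEnergyNet above builds its ensemble with nn.vmap: the parameters
# get a new leading axis (variable_axes={"params": 0}) while the inputs are
# broadcast (in_axes=None), so the output gains a leading ensemble axis.
# Minimal standalone version of the same pattern:
import flax.linen as nn
import jax
import jax.numpy as jnp

class Scalar(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1)(x)

Ensemble = nn.vmap(
    Scalar,
    variable_axes={"params": 0},
    split_rngs={"params": True},
    in_axes=None,
    out_axes=0,
    axis_size=2,
)
variables = Ensemble().init(jax.random.PRNGKey(0), jnp.ones((4, 3)))
out = Ensemble().apply(variables, jnp.ones((4, 3)))   # shape (2, 4, 1)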
ContinuousDDPM, + nce_target: Model, + critic: Model, + obs: jnp.ndarray, + training: bool, + num_samples: int, + solver: str, +) -> Tuple[PRNGKey, jnp.ndarray]: + assert len(obs.shape) == 2 + B = obs.shape[0] + rng, xT_rng = jax.random.split(rng) + + # sample + obs_repeat = obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2) + xT = jax.random.normal(xT_rng, (*obs_repeat.shape[:-1], actor.x_dim)) + rng, actions, _ = actor.sample(rng, xT, obs_repeat, training, solver) + if num_samples == 1: + actions = actions[:, 0] + else: + feature = nce_target(obs_repeat, actions, method="forward_phi") + qs = critic(feature) + qs = qs.min(axis=0).reshape(B, num_samples) + best_idx = qs.argmax(axis=-1) + actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), best_idx] + return rng, actions + +@partial(jax.jit, static_argnames=("discount", "solver")) +def update_critic( + rng: PRNGKey, + critic: Model, + critic_target: Model, + actor: ContinuousDDPM, + nce_target: Model, + batch: Batch, + discount: float, + solver: str, + critic_coef: float +) -> Tuple[PRNGKey, Model, Metric]: + rng, sample_rng = jax.random.split(rng) + next_xT = jax.random.normal(sample_rng, (*batch.next_obs.shape[:-1], actor.x_dim)) + rng, next_action, _ = actor.sample( + rng, + next_xT, + batch.next_obs, + training=False, + solver=solver, + ) + next_feature = nce_target(batch.next_obs, next_action, method="forward_phi") + q_target = critic_target(next_feature).min(0) + q_target = batch.reward + discount * (1 - batch.terminal) * q_target + + feature = nce_target(batch.obs, batch.action, method="forward_phi") + + def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + q_pred = critic.apply( + {"params": critic_params}, + feature, + rngs={"dropout": dropout_rng}, + ) + critic_loss = critic_coef * ((q_pred - q_target[jnp.newaxis, :])**2).sum(0).mean() + return critic_loss, { + "loss/critic_loss": critic_loss, + "misc/q_mean": q_pred.mean(), + "misc/reward": batch.reward.mean(), + } + + new_critic, metrics = critic.apply_gradient(critic_loss_fn) + return rng, new_critic, metrics + +@partial(jax.jit, static_argnames=("temp")) +def update_actor( + rng: PRNGKey, + actor: ContinuousDDPM, + nce_target: Model, + critic_target: Model, + batch: Batch, + temp: float, +) -> Tuple[PRNGKey, Model, Metric]: + + a0 = batch.action + rng, at, t, eps = actor.add_noise(rng, a0) + alpha1, alpha2 = actor.noise_schedule_func(t) + + def get_q_value(action: jnp.ndarray, obs: jnp.ndarray) -> jnp.ndarray: + feature = nce_target(obs, action, method="forward_phi") + q = critic_target(feature) + return q.min(axis=0).mean() + q_grad_fn = jax.vmap(jax.grad(get_q_value)) + q_grad = q_grad_fn(at, batch.obs) + eps_estimation = - alpha2 * q_grad / temp / (jnp.abs(q_grad).mean() + 1e-6) + + def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + eps_pred = actor.apply( + {"params": actor_params}, + at, + t, + condition=batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + loss = ((eps_pred - eps_estimation) ** 2).mean() + return loss, { + "loss/actor_loss": loss, + "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), + "misc/eps_estimation_std": jnp.std(eps_estimation, axis=0).mean(), + } + + new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) + return rng, new_actor, actor_metrics + + +class CtrlQSMAgent(QSMAgent): + """ + CTRL with Q Score Matching (QSM) agent. 
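# --- Illustrative sketch (annotation, not part of the patch) ---------------
# update_actor above differentiates a scalar, pessimistic Q through the NCE
# feature map with respect to the action, batched via jax.vmap(jax.grad(...)).
# phi and q_fn below are toy stand-ins for nce_target.forward_phi and the
# feature critic; only the grad/vmap mechanics are the point.
import jax
import jax.numpy as jnp

def phi(obs, act):                            # toy feature map
    return jnp.tanh(jnp.concatenate([obs, act], axis=-1))

def q_fn(feat):                               # toy 2-member ensemble critic
    return jnp.stack([feat.sum(), 0.5 * feat.sum()])

def q_min(act, obs):                          # scalar objective per sample
    return q_fn(phi(obs, act)).min()

q_grad_fn = jax.vmap(jax.grad(q_min))         # grad w.r.t. the action
obs = jnp.ones((4, 3))
act = jnp.zeros((4, 2))
q_grad = q_grad_fn(act, obs)                  # (4, 2)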
+ """ + + name = "CtrlQSMAgent" + model_names = ["nce", "nce_target", "actor", "actor_target", "critic", "critic_target"] + + def __init__(self, obs_dim: int, act_dim: int, cfg: CtrlQSMConfig, seed: int): + super().__init__(obs_dim, act_dim, cfg, seed) + self.cfg = cfg + + self.ctrl_coef = cfg.ctrl_coef + self.critic_coef = cfg.critic_coef + + self.linear = cfg.linear + self.ranking = cfg.ranking + self.feature_dim = cfg.feature_dim + self.num_noises = cfg.num_noises + self.reward_coef = cfg.reward_coef + self.rff_dim = cfg.rff_dim + self.actor_update_freq = cfg.actor_update_freq + self.target_update_freq = cfg.target_update_freq + + + # sanity checks for the hyper-parameters + assert not self.linear, "linear mode is not supported yet" + + # networks + self.rng, nce_rng, nce_init_rng, actor_rng, critic_rng = jax.random.split(self.rng, 5) + nce_def = FactorizedNCE( + self.obs_dim, + self.act_dim, + self.feature_dim, + cfg.phi_hidden_dims, + cfg.mu_hidden_dims, + cfg.reward_hidden_dims, + cfg.rff_dim, + cfg.num_noises, + self.ranking, + ) + self.nce = Model.create( + nce_def, + nce_rng, + inputs=( + nce_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + optimizer=optax.adam(learning_rate=cfg.feature_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.nce_target = Model.create( + nce_def, + nce_rng, + inputs=( + nce_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + ) + + critic_def = RffEnsembleCritic( + feature_dim=self.feature_dim, + hidden_dims=cfg.critic_hidden_dims, + rff_dim=cfg.rff_dim, + ensemble_size=2, + ) + self.critic = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim)),), + optimizer=optax.adam(learning_rate=cfg.critic_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.critic_target = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim)),), + ) + + self._n_training_steps = 0 + + def train_step(self, batch: Batch, step: int) -> Metric: + metrics = {} + + self.rng, self.nce, nce_metrics = update_factorized_nce( + self.rng, + self.nce, + batch, + self.ranking, + self.reward_coef, + ) + metrics.update(nce_metrics) + + self.rng, self.critic, critic_metrics = update_critic( + self.rng, + self.critic, + self.critic_target, + self.actor, + self.nce_target, + batch, + discount=self.cfg.discount, + solver=self.cfg.diffusion.solver, + critic_coef=self.critic_coef, + ) + metrics.update(critic_metrics) + + if self._n_training_steps % self.actor_update_freq == 0: + self.rng, self.actor, actor_metrics = update_actor( + self.rng, + self.actor, + self.nce_target, + self.critic_target, + batch, + temp=self.cfg.temp, + ) + metrics.update(actor_metrics) + + if self._n_training_steps % self.target_update_freq == 0: + self.sync_target() + + self._n_training_steps += 1 + return metrics + + def sample_actions( + self, + obs: jnp.ndarray, + deterministic: bool = True, + num_samples: int = 1, + ) -> Tuple[jnp.ndarray, Metric]: + # if deterministic is true, sample cfg.num_samples actions and select the best one + # if not, sample 1 action + if deterministic: + num_samples = self.cfg.num_samples + else: + num_samples = 1 + self.rng, action = jit_sample_actions( + self.rng, + self.actor, + self.nce_target, + self.critic, + obs, + training=False, + num_samples=num_samples, + solver=self.cfg.diffusion.solver, + ) + if not deterministic: + action = action + 0.1 * jax.random.normal(self.rng, action.shape) + return action, {} + + 
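# --- Illustrative sketch (annotation, not part of the patch) ---------------
# sync_target below relies on ema_update for the soft (Polyak) target sync.
# It is assumed to apply target <- tau * online + (1 - tau) * target
# leaf-wise over the parameter pytrees; a minimal standalone version:
import jax
import jax.numpy as jnp

def ema_params(online, target, tau):
    return jax.tree_util.tree_map(
        lambda o, t: tau * o + (1.0 - tau) * t, online, target
    )

online = {"w": jnp.ones((2, 2)), "b": jnp.zeros((2,))}
target = {"w": jnp.zeros((2, 2)), "b": jnp.zeros((2,))}
target = ema_params(online, target, tau=0.005)   # "w" entries become 0.005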
def sync_target(self): + self.critic_target = ema_update(self.critic, self.critic_target, self.cfg.ema) + self.nce_target = ema_update(self.nce, self.nce_target, self.cfg.feature_ema) diff --git a/flowrl/agent/online/ctrl/ctrl.py b/flowrl/agent/online/ctrl/ctrl_td3.py similarity index 99% rename from flowrl/agent/online/ctrl/ctrl.py rename to flowrl/agent/online/ctrl/ctrl_td3.py index 186c4d9..f379b46 100644 --- a/flowrl/agent/online/ctrl/ctrl.py +++ b/flowrl/agent/online/ctrl/ctrl_td3.py @@ -7,7 +7,7 @@ from flowrl.agent.online.ctrl.network import FactorizedNCE, update_factorized_nce from flowrl.agent.online.td3 import TD3Agent -from flowrl.config.online.mujoco.algo.ctrl_td3 import CtrlTD3Config +from flowrl.config.online.mujoco.algo.ctrl.ctrl_td3 import CtrlTD3Config from flowrl.functional.ema import ema_update from flowrl.module.actor import SquashedDeterministicActor from flowrl.module.mlp import MLP diff --git a/flowrl/agent/online/idem.py b/flowrl/agent/online/idem.py new file mode 100644 index 0000000..4fdfe41 --- /dev/null +++ b/flowrl/agent/online/idem.py @@ -0,0 +1,105 @@ +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp +import optax + +from flowrl.agent.base import BaseAgent +from flowrl.agent.online.qsm import QSMAgent, jit_update_qsm_critic +from flowrl.config.online.mujoco.algo.idem import IDEMConfig +from flowrl.flow.continuous_ddpm import ContinuousDDPM, ContinuousDDPMBackbone +from flowrl.functional.activation import mish +from flowrl.functional.ema import ema_update +from flowrl.module.critic import EnsembleCritic +from flowrl.module.mlp import MLP +from flowrl.module.model import Model +from flowrl.module.time_embedding import LearnableFourierEmbedding +from flowrl.types import Batch, Metric, Param, PRNGKey + +jit_update_idem_critic = jit_update_qsm_critic + +@partial(jax.jit, static_argnames=("num_reverse_samples", "temp",)) +def jit_update_idem_actor( + rng: PRNGKey, + actor: ContinuousDDPM, + critic_target: Model, + batch: Batch, + num_reverse_samples: int, + temp: float, +) -> Tuple[PRNGKey, ContinuousDDPM, Metric]: + a0 = batch.action + obs_repeat = batch.obs[jnp.newaxis, ...].repeat(num_reverse_samples, axis=0) + + rng, tnormal_rng, clipped_rng = jax.random.split(rng, 3) + rng, at, t, eps = actor.add_noise(rng, a0) + alpha1, alpha2 = actor.noise_schedule_func(t) + lower_bound = - 1.0 / alpha2 * at - alpha1 / alpha2 + upper_bound = - 1.0 / alpha2 * at + alpha1 / alpha2 + tnormal_noise = jax.random.truncated_normal(tnormal_rng, lower_bound, upper_bound, (num_reverse_samples, *at.shape)) + normal_noise = jax.random.normal(clipped_rng, (num_reverse_samples, *at.shape)) + normal_noise_clipped = jnp.clip(normal_noise, lower_bound, upper_bound) + eps_reverse = jnp.where(jnp.isnan(tnormal_noise), normal_noise_clipped, tnormal_noise) + a0_hat = 1 / alpha1 * at + alpha2 / alpha1 * eps_reverse + + q_value_and_grad_fn = jax.vmap( + jax.vmap( + jax.value_and_grad(lambda a, s: critic_target(s, a).min(axis=0).mean()), + ) + ) + q_value, q_grad = q_value_and_grad_fn(a0_hat, obs_repeat) + q_grad = q_grad / temp + weight = jax.nn.softmax(q_value / temp, axis=0) + eps_estimation = - (alpha2 / alpha1) * jnp.sum(weight[:, :, jnp.newaxis] * q_grad, axis=0) + + def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + eps_pred = actor.apply( + {"params": actor_params}, + at, + t, + condition=batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + loss = ((eps_pred - eps_estimation) ** 2).mean() 
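# --- Illustrative sketch (annotation, not part of the patch) ---------------
# jit_update_idem_actor above averages critic gradients over K reverse
# samples with softmax(Q / temp) weights along the sample axis. Worked toy
# example of that weighting step:
import jax
import jax.numpy as jnp

K, B, A = 3, 2, 2
q_value = jnp.array([[1.0, 0.0],
                     [2.0, 0.0],
                     [3.0, 0.0]])                 # (K, B)
q_grad = jnp.ones((K, B, A))                      # (K, B, A)
temp = 0.2
weight = jax.nn.softmax(q_value / temp, axis=0)   # sums to 1 over K
weighted_grad = jnp.sum(weight[:, :, jnp.newaxis] * q_grad, axis=0)   # (B, A)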
+ return loss, { + "loss/actor_loss": loss, + "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), + "misc/weights": weight.mean(), + "misc/weight_std": weight.std(0).mean(), + "misc/weight_max": weight.max(0).mean(), + "misc/weight_min": weight.min(0).mean(), + } + + new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) + return rng, new_actor, actor_metrics + + +class IDEMAgent(QSMAgent): + """ + Iterative Denoising Energy Matching (iDEM) Agent. + """ + name = "IDEMAgent" + model_names = ["actor", "critic", "critic_target"] + + def train_step(self, batch: Batch, step: int) -> Metric: + self.rng, self.critic, self.critic_target, critic_metrics = jit_update_idem_critic( + self.rng, + self.actor, + self.critic, + self.critic_target, + batch, + discount=self.cfg.discount, + solver=self.cfg.diffusion.solver, + ema=self.cfg.ema, + ) + self.rng, self.actor, actor_metrics = jit_update_idem_actor( + self.rng, + self.actor, + self.critic_target, + batch, + num_reverse_samples=self.cfg.num_reverse_samples, + temp=self.cfg.temp, + ) + self._n_training_steps += 1 + return {**critic_metrics, **actor_metrics} diff --git a/flowrl/agent/online/qsm.py b/flowrl/agent/online/qsm.py new file mode 100644 index 0000000..0a4e0f6 --- /dev/null +++ b/flowrl/agent/online/qsm.py @@ -0,0 +1,246 @@ +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp +import optax + +from flowrl.agent.base import BaseAgent +from flowrl.config.online.mujoco.algo.qsm import QSMConfig +from flowrl.flow.continuous_ddpm import ContinuousDDPM, ContinuousDDPMBackbone +from flowrl.functional.activation import mish +from flowrl.functional.ema import ema_update +from flowrl.module.critic import EnsembleCritic +from flowrl.module.mlp import MLP +from flowrl.module.model import Model +from flowrl.module.time_embedding import LearnableFourierEmbedding +from flowrl.types import Batch, Metric, Param, PRNGKey + + +@partial(jax.jit, static_argnames=("training", "num_samples", "solver")) +def jit_sample_actions( + rng: PRNGKey, + actor: ContinuousDDPM, + critic: Model, + obs: jnp.ndarray, + training: bool, + num_samples: int, + solver: str, +) -> Tuple[PRNGKey, jnp.ndarray]: + assert len(obs.shape) == 2 + B = obs.shape[0] + rng, xT_rng = jax.random.split(rng) + + # sample + obs_repeat = obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2) + xT = jax.random.normal(xT_rng, (*obs_repeat.shape[:-1], actor.x_dim)) + rng, actions, _ = actor.sample(rng, xT, obs_repeat, training, solver) + if num_samples == 1: + actions = actions[:, 0] + else: + qs = critic(obs_repeat, actions) + qs = qs.min(axis=0).reshape(B, num_samples) + best_idx = qs.argmax(axis=-1) + actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), best_idx] + return rng, actions + +@partial(jax.jit, static_argnames=("discount", "solver", "ema")) +def jit_update_qsm_critic( + rng: PRNGKey, + actor: ContinuousDDPM, + critic: Model, + critic_target: Model, + batch: Batch, + discount: float, + solver: str, + ema: float, +) -> Tuple[PRNGKey, Model, Model, Metric]: + rng, next_xT_rng = jax.random.split(rng) + next_xT = jax.random.normal(next_xT_rng, (*batch.next_obs.shape[:-1], actor.x_dim)) + rng, next_action, _ = actor.sample(rng, next_xT, batch.next_obs, training=False, solver=solver) + q_target = critic_target(batch.next_obs, next_action) + q_target = batch.reward + discount * (1 - batch.terminal) * q_target.min(axis=0) + + def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + q = critic.apply( 
+ {"params": critic_params}, + batch.obs, + batch.action, + training=True, + rngs={"dropout": dropout_rng}, + ) + critic_loss = ((q - q_target[jnp.newaxis, :])**2).mean() + return critic_loss, { + "loss/critic_loss": critic_loss, + "misc/q_mean": q.mean(), + "misc/reward": batch.reward.mean(), + "misc/next_action_l1": jnp.abs(next_action).mean(), + } + + new_critic, critic_metrics = critic.apply_gradient(critic_loss_fn) + + new_critic_target = ema_update(new_critic, critic_target, ema) + return rng, new_critic, new_critic_target, critic_metrics + +@partial(jax.jit, static_argnames=("temp",)) +def jit_update_qsm_actor( + rng: PRNGKey, + actor: ContinuousDDPM, + critic_target: Model, + batch: Batch, + temp: float, +) -> Tuple[PRNGKey, ContinuousDDPM, Metric]: + a0 = batch.action + rng, at, t, eps = actor.add_noise(rng, a0) + alpha1, alpha2 = actor.noise_schedule_func(t) + + q_grad_fn = jax.vmap(jax.grad(lambda a, s: critic_target(s, a).min(axis=0).mean())) + q_grad = q_grad_fn(at, batch.obs) + eps_estimation = - alpha2 * q_grad / temp / (jnp.abs(q_grad).mean() + 1e-6) + + def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + eps_pred = actor.apply( + {"params": actor_params}, + at, + t, + condition=batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + loss = ((eps_pred - eps_estimation) ** 2).mean() + return loss, { + "loss/actor_loss": loss, + "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), + "misc/eps_estimation_std": jnp.std(eps_estimation, axis=0).mean(), + } + + new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) + return rng, new_actor, actor_metrics + + +class QSMAgent(BaseAgent): + """ + Q Score Matching (QSM) agent and beyond. + """ + name = "QSMAgent" + model_names = ["actor", "critic", "critic_target"] + + def __init__(self, obs_dim: int, act_dim: int, cfg: QSMConfig, seed: int): + super().__init__(obs_dim, act_dim, cfg, seed) + self.cfg = cfg + self.rng, actor_rng, critic_rng = jax.random.split(self.rng, 3) + + # define the actor + time_embedding = partial(LearnableFourierEmbedding, output_dim=cfg.diffusion.time_dim) + cond_embedding = partial(MLP, hidden_dims=(128, 128), activation=mish) + noise_predictor = partial( + MLP, + hidden_dims=cfg.diffusion.mlp_hidden_dims, + output_dim=act_dim, + activation=mish, + layer_norm=False, + dropout=None, + ) + backbone_def = ContinuousDDPMBackbone( + noise_predictor=noise_predictor, + time_embedding=time_embedding, + cond_embedding=cond_embedding, + ) + + if cfg.diffusion.lr_decay_steps is not None: + actor_lr = optax.linear_schedule( + init_value=cfg.diffusion.lr, + end_value=cfg.diffusion.end_lr, + transition_steps=cfg.diffusion.lr_decay_steps, + transition_begin=cfg.diffusion.lr_decay_begin, + ) + else: + actor_lr = cfg.diffusion.lr + + self.actor = ContinuousDDPM.create( + network=backbone_def, + rng=actor_rng, + inputs=(jnp.ones((1, self.act_dim)), jnp.zeros((1, 1)), jnp.ones((1, self.obs_dim)), ), + x_dim=self.act_dim, + steps=cfg.diffusion.steps, + noise_schedule="cosine", + noise_schedule_params={}, + clip_sampler=cfg.diffusion.clip_sampler, + x_min=cfg.diffusion.x_min, + x_max=cfg.diffusion.x_max, + t_schedule_n=1.0, + optimizer=optax.adam(learning_rate=actor_lr), + ) + + # define the critic + critic_activation = { + "relu": jax.nn.relu, + "elu": jax.nn.elu, + }[cfg.critic_activation] + critic_def = EnsembleCritic( + hidden_dims=cfg.critic_hidden_dims, + activation=critic_activation, + layer_norm=False, + dropout=None, + ensemble_size=2, + ) + self.critic = 
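# --- Illustrative sketch (annotation, not part of the patch) ---------------
# The Q-score-matching target built in jit_update_qsm_actor above: the
# action gradient of the pessimistic Q is scaled by sigma(t), a temperature,
# and a global L1 normalizer, and the actor regresses its predicted noise
# onto the negated result. Toy numbers only.
import jax.numpy as jnp

sigma = jnp.array([[0.7], [0.7]])       # sigma(t) per sample, (B, 1)
q_grad = jnp.array([[0.2, -0.4],
                    [0.1,  0.3]])       # d(min_i Q_i)/da, (B, A)
temp = 0.1
eps_estimation = -sigma * q_grad / temp / (jnp.abs(q_grad).mean() + 1e-6)
eps_pred = jnp.zeros_like(q_grad)       # stand-in for the actor's output
actor_loss = ((eps_pred - eps_estimation) ** 2).mean()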
Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim))), + optimizer=optax.adam(learning_rate=cfg.critic_lr), + ) + self.critic_target = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.obs_dim)), jnp.ones((1, self.act_dim))), + ) + + # define tracking variables + self._n_training_steps = 0 + + def train_step(self, batch: Batch, step: int) -> Metric: + self.rng, self.critic, self.critic_target, critic_metrics = jit_update_qsm_critic( + self.rng, + self.actor, + self.critic, + self.critic_target, + batch, + discount=self.cfg.discount, + solver=self.cfg.diffusion.solver, + ema=self.cfg.ema, + ) + self.rng, self.actor, actor_metrics = jit_update_qsm_actor( + self.rng, + self.actor, + self.critic_target, + batch, + temp=self.cfg.temp, + ) + self._n_training_steps += 1 + return {**critic_metrics, **actor_metrics} + + def sample_actions( + self, + obs: jnp.ndarray, + deterministic: bool = True, + num_samples: int = 1, + ) -> Tuple[jnp.ndarray, Metric]: + # if deterministic is true, sample cfg.num_samples actions and select the best one + # if not, sample 1 action + if deterministic: + num_samples = self.cfg.num_samples + else: + num_samples = 1 + self.rng, action = jit_sample_actions( + self.rng, + self.actor, + self.critic, + obs, + training=False, + num_samples=num_samples, + solver=self.cfg.diffusion.solver, + ) + if not deterministic: + action = action + 0.1 * jax.random.normal(self.rng, action.shape) + return action, {} diff --git a/flowrl/agent/online/unirep/__init__.py b/flowrl/agent/online/unirep/__init__.py new file mode 100644 index 0000000..e789c67 --- /dev/null +++ b/flowrl/agent/online/unirep/__init__.py @@ -0,0 +1,5 @@ +from .aca import ACAAgent + +__all__ = [ + "ACAAgent", +] diff --git a/flowrl/agent/online/unirep/aca.py b/flowrl/agent/online/unirep/aca.py new file mode 100644 index 0000000..e5451ee --- /dev/null +++ b/flowrl/agent/online/unirep/aca.py @@ -0,0 +1,338 @@ +from functools import partial +from typing import Tuple + +import jax +import jax.numpy as jnp +import optax + +from flowrl.agent.base import BaseAgent +from flowrl.agent.online.unirep.network import FactorizedNCE, update_factorized_nce +from flowrl.config.online.mujoco.algo.unirep.aca import ACAConfig +from flowrl.flow.continuous_ddpm import ContinuousDDPM, ContinuousDDPMBackbone +from flowrl.functional.activation import mish +from flowrl.functional.ema import ema_update +from flowrl.module.critic import EnsembleCritic +from flowrl.module.mlp import MLP +from flowrl.module.model import Model +from flowrl.module.time_embedding import LearnableFourierEmbedding +from flowrl.types import Batch, Metric, Param, PRNGKey + + +@partial(jax.jit, static_argnames=("training", "num_samples", "solver")) +def jit_sample_actions( + rng: PRNGKey, + actor: ContinuousDDPM, + critic: Model, + nce_target: Model, + obs: jnp.ndarray, + training: bool, + num_samples: int, + solver: str, +) -> Tuple[PRNGKey, jnp.ndarray]: + assert len(obs.shape) == 2 + B = obs.shape[0] + rng, xT_rng = jax.random.split(rng) + + # sample + obs_repeat = obs[..., jnp.newaxis, :].repeat(num_samples, axis=-2) + xT = jax.random.normal(xT_rng, (*obs_repeat.shape[:-1], actor.x_dim)) + rng, actions, _ = actor.sample(rng, xT, obs_repeat, training, solver) + if num_samples == 1: + actions = actions[:, 0] + else: + t0 = jnp.ones((obs_repeat.shape[0], num_samples, 1)) + f0 = nce_target(obs_repeat, actions, t0, method="forward_phi") + qs = critic(f0) + qs = qs.min(axis=0).reshape(B, 
num_samples) + best_idx = qs.argmax(axis=-1) + actions = actions.reshape(B, num_samples, -1)[jnp.arange(B), best_idx] + return rng, actions + +@partial(jax.jit, static_argnames=("discount", "solver", "critic_coef")) +def jit_update_critic( + rng: PRNGKey, + critic: Model, + critic_target: Model, + actor: ContinuousDDPM, + nce_target: Model, + batch: Batch, + discount: float, + solver: str, + critic_coef: float, +) -> Tuple[PRNGKey, Model, Metric]: + # q0 target + t0 = jnp.ones((batch.obs.shape[0], 1)) + rng, next_aT_rng = jax.random.split(rng) + next_aT = jax.random.normal(next_aT_rng, (*batch.next_obs.shape[:-1], actor.x_dim)) + rng, next_a0, _ = actor.sample(rng, next_aT, batch.next_obs, training=False, solver=solver) + next_f0 = nce_target(batch.next_obs, next_a0, t0, method="forward_phi") + q0_target = critic_target(next_f0) + q0_target = batch.reward + discount * (1 - batch.terminal) * q0_target.min(axis=0) + + # qt target + a0 = batch.action + f0 = nce_target(batch.obs, a0, t0, method="forward_phi") + qt_target = critic_target(f0) + + # features + rng, at, t, eps = actor.add_noise(rng, a0) + ft = nce_target(batch.obs, at, t, method="forward_phi") + + def critic_loss_fn(critic_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + q0_pred = critic.apply( + {"params": critic_params}, + f0, + training=True, + rngs={"dropout": dropout_rng}, + ) + qt_pred = critic.apply( + {"params": critic_params}, + ft, + training=True, + rngs={"dropout": dropout_rng}, + ) + critic_loss = ( + ((q0_pred - q0_target[jnp.newaxis, :])**2).mean() + + ((qt_pred - qt_target[jnp.newaxis, :])**2).mean() + ) + return critic_loss, { + "loss/critic_loss": critic_loss, + "misc/q0_mean": q0_pred.mean(), + "misc/qt_mean": qt_pred.mean(), + "misc/reward": batch.reward.mean(), + "misc/next_action_l1": jnp.abs(next_a0).mean(), + } + + new_critic, critic_metrics = critic.apply_gradient(critic_loss_fn) + return rng, new_critic, critic_metrics + +@partial(jax.jit, static_argnames=("temp",)) +def jit_update_actor( + rng: PRNGKey, + actor: ContinuousDDPM, + nce_target: Model, + critic_target: Model, + batch: Batch, + temp: float, +) -> Tuple[PRNGKey, ContinuousDDPM, Metric]: + a0 = batch.action + rng, at, t, eps = actor.add_noise(rng, a0) + alpha, sigma = actor.noise_schedule_func(t) + + def get_q_value(at: jnp.ndarray, obs: jnp.ndarray, t: jnp.ndarray) -> jnp.ndarray: + ft = nce_target(obs, at, t, method="forward_phi") + q = critic_target(ft) + return q.mean(axis=0).mean() + q_grad_fn = jax.vmap(jax.grad(get_q_value)) + q_grad = q_grad_fn(at, batch.obs, t) + eps_estimation = - sigma * q_grad / temp / (jnp.abs(q_grad).mean() + 1e-6) + + def actor_loss_fn(actor_params: Param, dropout_rng: PRNGKey) -> Tuple[jnp.ndarray, Metric]: + eps_pred = actor.apply( + {"params": actor_params}, + at, + t, + condition=batch.obs, + training=True, + rngs={"dropout": dropout_rng}, + ) + loss = ((eps_pred - eps_estimation) ** 2).mean() + return loss, { + "loss/actor_loss": loss, + "misc/eps_estimation_l1": jnp.abs(eps_estimation).mean(), + } + new_actor, actor_metrics = actor.apply_gradient(actor_loss_fn) + return rng, new_actor, actor_metrics + + +class ACAAgent(BaseAgent): + """ + ACA (Actor-Critic with Actor) agent. 
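# --- Illustrative sketch (annotation, not part of the patch) ---------------
# The ACA critic loss in jit_update_critic above has two terms: features of
# the clean action get a Bellman target, while features of a noised action
# are regressed onto the target critic's clean-action value, so the critic
# is trained along the diffusion path. Shapes are simplified toy values.
import jax.numpy as jnp

q0_pred = jnp.ones((2, 4))         # (ensemble, batch), critic on phi(s, a0)
qt_pred = jnp.ones((2, 4))         # (ensemble, batch), critic on phi(s, a_t)
q0_target = jnp.full((4,), 1.5)    # r + discount * (1 - d) * min_i Q_i(s', a0')
qt_target = jnp.full((2, 4), 1.2)  # target critic on the clean-action feature
critic_loss = (((q0_pred - q0_target[jnp.newaxis, :]) ** 2).mean()
               + ((qt_pred - qt_target) ** 2).mean())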
+ """ + name = "ACAAgent" + model_names = ["nce", "nce_target", "actor", "actor_target", "critic", "critic_target"] + + def __init__(self, obs_dim: int, act_dim: int, cfg: ACAConfig, seed: int): + super().__init__(obs_dim, act_dim, cfg, seed) + self.cfg = cfg + + self.feature_dim = cfg.feature_dim + self.ranking = cfg.ranking + self.linear = cfg.linear + self.reward_coef = cfg.reward_coef + self.critic_coef = cfg.critic_coef + + self.rng, nce_rng, nce_init_rng, actor_rng, critic_rng = jax.random.split(self.rng, 5) + + # define the nce + nce_def = FactorizedNCE( + self.obs_dim, + self.act_dim, + self.feature_dim, + cfg.phi_hidden_dims, + cfg.mu_hidden_dims, + cfg.reward_hidden_dims, + cfg.rff_dim, + cfg.num_noises, + self.ranking, + ) + self.nce = Model.create( + nce_def, + nce_rng, + inputs=( + nce_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + optimizer=optax.adam(learning_rate=cfg.feature_lr), + clip_grad_norm=cfg.clip_grad_norm, + ) + self.nce_target = Model.create( + nce_def, + nce_rng, + inputs=( + nce_init_rng, + jnp.ones((1, self.obs_dim)), + jnp.ones((1, self.act_dim)), + jnp.ones((1, self.obs_dim)), + ), + ) + + # define the actor + time_embedding = partial(LearnableFourierEmbedding, output_dim=cfg.diffusion.time_dim) + cond_embedding = partial(MLP, hidden_dims=(128, 128), activation=mish) + noise_predictor = partial( + MLP, + hidden_dims=cfg.diffusion.mlp_hidden_dims, + output_dim=act_dim, + activation=mish, + layer_norm=False, + dropout=None, + ) + backbone_def = ContinuousDDPMBackbone( + noise_predictor=noise_predictor, + time_embedding=time_embedding, + cond_embedding=cond_embedding, + ) + + if cfg.diffusion.lr_decay_steps is not None: + actor_lr = optax.linear_schedule( + init_value=cfg.diffusion.lr, + end_value=cfg.diffusion.end_lr, + transition_steps=cfg.diffusion.lr_decay_steps, + transition_begin=cfg.diffusion.lr_decay_begin, + ) + else: + actor_lr = cfg.diffusion.lr + + self.actor = ContinuousDDPM.create( + network=backbone_def, + rng=actor_rng, + inputs=(jnp.ones((1, self.act_dim)), jnp.zeros((1, 1)), jnp.ones((1, self.obs_dim)), ), + x_dim=self.act_dim, + steps=cfg.diffusion.steps, + noise_schedule="cosine", + noise_schedule_params={}, + clip_sampler=cfg.diffusion.clip_sampler, + x_min=cfg.diffusion.x_min, + x_max=cfg.diffusion.x_max, + t_schedule_n=1.0, + optimizer=optax.adam(learning_rate=actor_lr), + ) + + # define the critic + critic_activation = { + "relu": jax.nn.relu, + "elu": jax.nn.elu, + }[cfg.critic_activation] + critic_def = EnsembleCritic( + hidden_dims=cfg.critic_hidden_dims, + activation=critic_activation, + layer_norm=True, + dropout=None, + ensemble_size=2, + ) + self.critic = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim))), + optimizer=optax.adam(learning_rate=cfg.critic_lr), + ) + self.critic_target = Model.create( + critic_def, + critic_rng, + inputs=(jnp.ones((1, self.feature_dim))), + ) + + # define tracking variables + self._n_training_steps = 0 + + def train_step(self, batch: Batch, step: int) -> Metric: + metrics = {} + + self.rng, self.nce, nce_metrics = update_factorized_nce( + self.rng, + self.nce, + batch, + self.ranking, + self.reward_coef, + ) + metrics.update(nce_metrics) + self.rng, self.critic, critic_metrics = jit_update_critic( + self.rng, + self.critic, + self.critic_target, + self.actor, + self.nce_target, + batch, + discount=self.cfg.discount, + solver=self.cfg.diffusion.solver, + critic_coef=self.critic_coef, + ) + 
metrics.update(critic_metrics) + self.rng, self.actor, actor_metrics = jit_update_actor( + self.rng, + self.actor, + self.nce_target, + self.critic_target, + batch, + temp=self.cfg.temp, + ) + metrics.update(actor_metrics) + + if self._n_training_steps % self.cfg.target_update_freq == 0: + self.sync_target() + + self._n_training_steps += 1 + return metrics + + def sample_actions( + self, + obs: jnp.ndarray, + deterministic: bool = True, + num_samples: int = 1, + ) -> Tuple[jnp.ndarray, Metric]: + # if deterministic is true, sample cfg.num_samples actions and select the best one + # if not, sample 1 action + if deterministic: + num_samples = self.cfg.num_samples + else: + num_samples = 1 + self.rng, action = jit_sample_actions( + self.rng, + self.actor, + self.critic, + self.nce_target, + obs, + training=False, + num_samples=num_samples, + solver=self.cfg.diffusion.solver, + ) + if not deterministic: + action = action + 0.1 * jax.random.normal(self.rng, action.shape) + return action, {} + + def sync_target(self): + self.critic_target = ema_update(self.critic, self.critic_target, self.cfg.ema) + self.nce_target = ema_update(self.nce, self.nce_target, self.cfg.feature_ema) diff --git a/flowrl/agent/online/unirep/network.py b/flowrl/agent/online/unirep/network.py new file mode 100644 index 0000000..a4bc468 --- /dev/null +++ b/flowrl/agent/online/unirep/network.py @@ -0,0 +1,208 @@ +from functools import partial + +import flax.linen as nn +import jax +import jax.numpy as jnp +import optax + +from flowrl.flow.continuous_ddpm import cosine_noise_schedule +from flowrl.flow.ddpm import get_noise_schedule +from flowrl.functional.activation import l2_normalize, mish +from flowrl.module.critic import Critic +from flowrl.module.mlp import ResidualMLP +from flowrl.module.model import Model +from flowrl.module.rff import RffReward +from flowrl.module.time_embedding import LearnableFourierEmbedding +from flowrl.types import * +from flowrl.types import Sequence + + +class FactorizedNCE(nn.Module): + obs_dim: int + action_dim: int + feature_dim: int + phi_hidden_dims: Sequence[int] + mu_hidden_dims: Sequence[int] + reward_hidden_dims: Sequence[int] + rff_dim: int = 0 + num_noises: int = 0 + ranking: bool = False + + def setup(self): + self.mlp_t = nn.Sequential( + [LearnableFourierEmbedding(128), nn.Dense(256), mish, nn.Dense(128)] + ) + self.mlp_phi = ResidualMLP( + self.phi_hidden_dims, + self.feature_dim, + multiplier=1, + activation=mish, + layer_norm=True, + dropout=None, + ) + self.mlp_mu = ResidualMLP( + self.mu_hidden_dims, + self.feature_dim, + multiplier=1, + activation=mish, + layer_norm=True, + dropout=None, + ) + # self.reward = RffReward( + # self.feature_dim, + # self.reward_hidden_dims, + # rff_dim=self.rff_dim, + # ) + self.reward = Critic( + hidden_dims=self.reward_hidden_dims, + activation=nn.elu, + layer_norm=True, + dropout=None, + ) + if self.num_noises > 0: + self.use_noise_perturbation = True + self.noise_schedule_fn = cosine_noise_schedule + else: + self.use_noise_perturbation = False + self.N = max(self.num_noises, 1) + if not self.ranking: + self.normalizer = self.param("normalizer", lambda key: jnp.zeros((self.N,), jnp.float32)) + else: + self.normalizer = self.param("normalizer", lambda key: jnp.zeros((self.N,), jnp.float32)) + + def forward_phi(self, s, at, t): + x = jnp.concat([s, at], axis=-1) + if t is not None: + t_ff = self.mlp_t(t) + x = jnp.concat([x, t_ff], axis=-1) + x = self.mlp_phi(x) + x = l2_normalize(x, group_size=None) + return x + + def forward_mu(self, sp): + 
sp = self.mlp_mu(sp) + return sp + + def forward_reward(self, x: jnp.ndarray): # for z_phi + return self.reward(x) + + def forward_logits( + self, + rng: PRNGKey, + s: jnp.ndarray, + a: jnp.ndarray, + sp: jnp.ndarray, + z_mu: jnp.ndarray | None=None + ): + B, D = sp.shape + rng, t_rng, eps_rng = jax.random.split(rng, 3) + if z_mu is None: + z_mu = self.forward_mu(sp) + if self.use_noise_perturbation: + s = jnp.broadcast_to(s, (self.N, B, s.shape[-1])) + a0 = jnp.broadcast_to(a, (self.N, B, a.shape[-1])) + t = jax.random.uniform(t_rng, (self.N,), dtype=jnp.float32) # check removing min val and max val is valid + t = jnp.repeat(t, B).reshape(self.N, B, 1) + eps = jax.random.normal(eps_rng, a0.shape) + alpha, sigma = self.noise_schedule_fn(t) + at = alpha * a0 + sigma * eps + else: + s = jnp.expand_dims(s, 0) + at = jnp.expand_dims(a, 0) + t = None + z_phi = self.forward_phi(s, at, t) + z_mu = jnp.broadcast_to(z_mu, (self.N, B, self.feature_dim)) + logits = jax.lax.batch_matmul(z_phi, jnp.swapaxes(z_mu, -1, -2)) + logits = logits / jnp.exp(self.normalizer[:, None, None]) + rewards = self.forward_reward(z_phi) + return logits, rewards + + def forward_normalizer(self): + return self.normalizer + + def __call__( + self, + rng: PRNGKey, + s, + a, + sp, + ): + logits, rewards = self.forward_logits(rng, s, a, sp) + _ = self.forward_normalizer() + + return logits, rewards + + +@partial(jax.jit, static_argnames=("ranking", "reward_coef")) +def update_factorized_nce( + rng: PRNGKey, + nce: Model, + batch: Batch, + ranking: bool, + reward_coef: float, +) -> Tuple[PRNGKey, Model, Metric]: + B = batch.obs.shape[0] + rng, logits_rng = jax.random.split(rng) + if ranking: + labels = jnp.arange(B) + else: + labels = jnp.eye(B) + + def loss_fn(nce_params: Param, dropout_rng: PRNGKey): + z_mu = nce.apply({"params": nce_params}, batch.next_obs, method="forward_mu") + logits, rewards = nce.apply( + {"params": nce_params}, + logits_rng, + batch.obs, + batch.action, + batch.next_obs, + z_mu, + method="forward_logits", + ) + + if ranking: + model_loss = optax.softmax_cross_entropy_with_integer_labels( + logits, jnp.broadcast_to(labels, (logits.shape[0], B)) + ).mean(axis=-1) + else: + normalizer = nce.apply({"params": nce_params}, method="forward_normalizer") + eff_logits = logits + normalizer[:, None, None] - jnp.log(B) + model_loss = optax.sigmoid_binary_cross_entropy(eff_logits, labels).mean([-2, -1]) + normalizer = nce.apply({"params": nce_params}, method="forward_normalizer") + rewards_target = jnp.broadcast_to(batch.reward, rewards.shape) + reward_loss = jnp.mean((rewards - rewards_target) ** 2) + + nce_loss = model_loss.mean() + reward_coef * reward_loss + 0.000 * (logits**2).mean() + + pos_logits = logits[ + jnp.arange(logits.shape[0])[..., jnp.newaxis], + jnp.arange(logits.shape[1]), + jnp.arange(logits.shape[2])[jnp.newaxis, ...].repeat(logits.shape[0], axis=0) + ] + pos_logits_per_noise = pos_logits.mean(axis=-1) + neg_logits = (logits.sum(axis=-1) - pos_logits) / (logits.shape[-1] - 1) + neg_logits_per_noise = neg_logits.mean(axis=-1) + metrics = { + "loss/nce_loss": nce_loss, + "loss/model_loss": model_loss.mean(), + "loss/reward_loss": reward_loss, + "misc/obs_mean": batch.obs.mean(), + "misc/obs_std": batch.obs.std(axis=0).mean(), + } + checkpoints = list(range(0, logits.shape[0], logits.shape[0]//5)) + [logits.shape[0]-1] + metrics.update({ + f"misc/positive_logits_{i}": pos_logits_per_noise[i].mean() for i in checkpoints + }) + metrics.update({ + f"misc/negative_logits_{i}": 
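# --- Illustrative sketch (annotation, not part of the patch) ---------------
# The ranking branch of update_factorized_nce above treats logits[n, i, j]
# as the score of transition i against next-state embedding j at noise
# level n; true pairs sit on the diagonal, so the loss is a B-way cross
# entropy with integer labels arange(B). Toy tensors only.
import jax.numpy as jnp
import optax

N, B = 2, 4                                                   # noise levels, batch
logits = 5.0 * jnp.eye(B)[jnp.newaxis].repeat(N, axis=0)      # (N, B, B), diag positives
labels = jnp.arange(B)
model_loss = optax.softmax_cross_entropy_with_integer_labels(
    logits, jnp.broadcast_to(labels, (N, B))
).mean(axis=-1)                                               # (N,), per-noise loss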
neg_logits_per_noise[i].mean() for i in checkpoints + }) + metrics.update({ + f"misc/logits_gap_{i}": (pos_logits_per_noise[i] - neg_logits_per_noise[i]).mean() for i in checkpoints + }) + metrics.update({ + f"misc/normalizer_{i}": jnp.exp(normalizer[i]) for i in checkpoints + }) + return nce_loss, metrics + + new_nce, metrics = nce.apply_gradient(loss_fn) + return rng, new_nce, metrics diff --git a/flowrl/config/online/mujoco/__init__.py b/flowrl/config/online/mujoco/__init__.py index e775d25..8a72576 100644 --- a/flowrl/config/online/mujoco/__init__.py +++ b/flowrl/config/online/mujoco/__init__.py @@ -1,12 +1,16 @@ from hydra.core.config_store import ConfigStore +from .algo.alac import ALACConfig from .algo.base import BaseAlgoConfig -from .algo.ctrl_td3 import CtrlTD3Config +from .algo.ctrl import * from .algo.dpmd import DPMDConfig +from .algo.idem import IDEMConfig +from .algo.qsm import QSMConfig from .algo.sac import SACConfig from .algo.sdac import SDACConfig from .algo.td3 import TD3Config from .algo.td7 import TD7Config +from .algo.unirep import * from .config import Config, LogConfig _DEF_SUFFIX = "_cfg_def" @@ -23,7 +27,12 @@ "td3": TD3Config, "td7": TD7Config, "dpmd": DPMDConfig, - "ctrl": CtrlTD3Config, + "qsm": QSMConfig, + "alac": ALACConfig, + "idem": IDEMConfig, + "ctrl_td3": CtrlTD3Config, + "ctrl_qsm": CtrlQSMConfig, + "aca": ACAConfig, } for name, cfg in _CONFIGS.items(): diff --git a/flowrl/config/online/mujoco/algo/alac.py b/flowrl/config/online/mujoco/algo/alac.py new file mode 100644 index 0000000..b07cc24 --- /dev/null +++ b/flowrl/config/online/mujoco/algo/alac.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass +from typing import List + +from .base import BaseAlgoConfig + + +@dataclass +class ALACLangevinDynamicsConfig: + resnet: bool + activation: str + ensemble_size: int + time_dim: int + hidden_dims: List[int] + cond_hidden_dims: List[int] + steps: int + step_size: float + noise_scale: float + noise_schedule: str + clip_sampler: bool + x_min: float + x_max: float + epsilon: float + lr: float + clip_grad_norm: float | None + +@dataclass +class ALACConfig(BaseAlgoConfig): + name: str + discount: float + ema: float + num_samples: int + ld: ALACLangevinDynamicsConfig diff --git a/flowrl/config/online/mujoco/algo/ctrl/__init__.py b/flowrl/config/online/mujoco/algo/ctrl/__init__.py new file mode 100644 index 0000000..7ed1456 --- /dev/null +++ b/flowrl/config/online/mujoco/algo/ctrl/__init__.py @@ -0,0 +1,7 @@ +from .ctrl_qsm import CtrlQSMConfig +from .ctrl_td3 import CtrlTD3Config + +__all__ = [ + "CtrlTD3Config", + "CtrlQSMConfig", +] diff --git a/flowrl/config/online/mujoco/algo/ctrl/ctrl_qsm.py b/flowrl/config/online/mujoco/algo/ctrl/ctrl_qsm.py new file mode 100644 index 0000000..402aa92 --- /dev/null +++ b/flowrl/config/online/mujoco/algo/ctrl/ctrl_qsm.py @@ -0,0 +1,41 @@ +from dataclasses import dataclass +from typing import List + +from ..base import BaseAlgoConfig +from ..qsm import QSMDiffusionConfig + + +@dataclass +class CtrlQSMConfig(BaseAlgoConfig): + name: str + actor_update_freq: int + target_update_freq: int + discount: float + ema: float + # critic_hidden_dims: List[int] + critic_activation: str # not used + critic_ensemble_size: int + layer_norm: bool + critic_lr: float + clip_grad_norm: float | None + + feature_dim: int + feature_lr: float + feature_ema: float + phi_hidden_dims: List[int] + mu_hidden_dims: List[int] + critic_hidden_dims: List[int] + reward_hidden_dims: List[int] + rff_dim: int + ctrl_coef: float + reward_coef: float + 
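# --- Illustrative sketch (annotation, not part of the patch) ---------------
# The new config dataclasses are registered with Hydra's ConfigStore by the
# loop over _CONFIGS shown above (names get the "_cfg_def" suffix). A toy
# registration of the same kind, with the group name assumed to be "algo":
from dataclasses import dataclass
from hydra.core.config_store import ConfigStore

@dataclass
class ToyAlgoConfig:
    name: str = "qsm"
    discount: float = 0.99

cs = ConfigStore.instance()
cs.store(group="algo", name="qsm_cfg_def", node=ToyAlgoConfig)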
back_critic_grad: bool + critic_coef: float + + num_noises: int + linear: bool + ranking: bool + + num_samples: int + temp: float + diffusion: QSMDiffusionConfig diff --git a/flowrl/config/online/mujoco/algo/ctrl_td3.py b/flowrl/config/online/mujoco/algo/ctrl/ctrl_td3.py similarity index 96% rename from flowrl/config/online/mujoco/algo/ctrl_td3.py rename to flowrl/config/online/mujoco/algo/ctrl/ctrl_td3.py index 667496a..374d820 100644 --- a/flowrl/config/online/mujoco/algo/ctrl_td3.py +++ b/flowrl/config/online/mujoco/algo/ctrl/ctrl_td3.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from typing import List -from .base import BaseAlgoConfig +from ..base import BaseAlgoConfig @dataclass diff --git a/flowrl/config/online/mujoco/algo/idem.py b/flowrl/config/online/mujoco/algo/idem.py new file mode 100644 index 0000000..c46e7b6 --- /dev/null +++ b/flowrl/config/online/mujoco/algo/idem.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass +from typing import List + +from .base import BaseAlgoConfig + + +@dataclass +class IDEMDiffusionConfig: + time_dim: int + mlp_hidden_dims: List[int] + lr: float + end_lr: float + lr_decay_steps: int | None + lr_decay_begin: int + steps: int + clip_sampler: bool + x_min: float + x_max: float + solver: str + + +@dataclass +class IDEMConfig(BaseAlgoConfig): + name: str + critic_hidden_dims: List[int] + critic_lr: float + discount: float + num_samples: int + num_reverse_samples: int + ema: float + temp: float + diffusion: IDEMDiffusionConfig diff --git a/flowrl/config/online/mujoco/algo/qsm.py b/flowrl/config/online/mujoco/algo/qsm.py new file mode 100644 index 0000000..8c02f0f --- /dev/null +++ b/flowrl/config/online/mujoco/algo/qsm.py @@ -0,0 +1,32 @@ +from dataclasses import dataclass +from typing import List + +from .base import BaseAlgoConfig + + +@dataclass +class QSMDiffusionConfig: + time_dim: int + mlp_hidden_dims: List[int] + lr: float + end_lr: float + lr_decay_steps: int | None + lr_decay_begin: int + steps: int + clip_sampler: bool + x_min: float + x_max: float + solver: str + + +@dataclass +class QSMConfig(BaseAlgoConfig): + name: str + critic_hidden_dims: List[int] + critic_activation: str + critic_lr: float + discount: float + num_samples: int + ema: float + temp: float + diffusion: QSMDiffusionConfig diff --git a/flowrl/config/online/mujoco/algo/unirep/__init__.py b/flowrl/config/online/mujoco/algo/unirep/__init__.py new file mode 100644 index 0000000..dafaab0 --- /dev/null +++ b/flowrl/config/online/mujoco/algo/unirep/__init__.py @@ -0,0 +1,5 @@ +from .aca import ACAConfig + +__all__ = [ + "ACAConfig", +] diff --git a/flowrl/config/online/mujoco/algo/unirep/aca.py b/flowrl/config/online/mujoco/algo/unirep/aca.py new file mode 100644 index 0000000..3d68d4a --- /dev/null +++ b/flowrl/config/online/mujoco/algo/unirep/aca.py @@ -0,0 +1,49 @@ +from dataclasses import dataclass +from typing import List + +from ..base import BaseAlgoConfig + + +@dataclass +class ACADiffusionConfig: + time_dim: int + mlp_hidden_dims: List[int] + lr: float + end_lr: float + lr_decay_steps: int | None + lr_decay_begin: int + steps: int + clip_sampler: bool + x_min: float + x_max: float + solver: str + + +@dataclass +class ACAConfig(BaseAlgoConfig): + name: str + target_update_freq: int + feature_dim: int + rff_dim: int + critic_hidden_dims: List[int] + reward_hidden_dims: List[int] + phi_hidden_dims: List[int] + mu_hidden_dims: List[int] + ctrl_coef: float + reward_coef: float + critic_coef: float + critic_activation: str + back_critic_grad: bool + feature_lr: 
float + critic_lr: float + discount: float + num_samples: int + ema: float + feature_ema: float + clip_grad_norm: float | None + temp: float + diffusion: ACADiffusionConfig + + num_noises: int + linear: bool + ranking: bool diff --git a/flowrl/flow/langevin_dynamics.py b/flowrl/flow/langevin_dynamics.py new file mode 100644 index 0000000..81fe14b --- /dev/null +++ b/flowrl/flow/langevin_dynamics.py @@ -0,0 +1,293 @@ +from functools import partial +from typing import Callable, Optional, Sequence, Tuple + +import flax.linen as nn +import jax +import jax.numpy as jnp +import optax +from flax.struct import PyTreeNode, dataclass, field +from flax.training.train_state import TrainState + +from flowrl.flow.continuous_ddpm import cosine_noise_schedule, linear_noise_schedule +from flowrl.module.model import Model +from flowrl.types import * + +# ======= Langevin Dynamics Sampling ======= + +@dataclass +class LangevinDynamics(Model): + state: TrainState + dropout_rng: PRNGKey = field(pytree_node=True) + x_dim: int = field(pytree_node=False, default=None) + grad_prediction: bool = field(pytree_node=False, default=True) + steps: int = field(pytree_node=False, default=None) + step_size: float = field(pytree_node=False, default=None) + noise_scale: float = field(pytree_node=False, default=None) + clip_sampler: bool = field(pytree_node=False, default=None) + x_min: float = field(pytree_node=False, default=None) + x_max: float = field(pytree_node=False, default=None) + + @classmethod + def create( + cls, + network: nn.Module, + rng: PRNGKey, + inputs: Sequence[jnp.ndarray], + x_dim: int, + grad_prediction: bool = True, + steps: int = 100, + step_size: float = 0.01, + noise_scale: float = 1.0, + clip_sampler: bool = False, + x_min: Optional[float] = None, + x_max: Optional[float] = None, + optimizer: Optional[optax.GradientTransformation] = None, + clip_grad_norm: float = None + ) -> 'LangevinDynamics': + ret = super().create(network, rng, inputs, optimizer, clip_grad_norm) + + return ret.replace( + x_dim=x_dim, + grad_prediction=grad_prediction, + steps=steps, + step_size=step_size, + noise_scale=noise_scale, + clip_sampler=clip_sampler, + x_min=x_min, + x_max=x_max, + ) + + @partial(jax.jit, static_argnames=("training")) + def compute_grad( + self, + x: jnp.ndarray, + i: int, + condition: Optional[jnp.ndarray] = None, + training: bool = False, + params: Optional[Param] = None, + dropout_rng: Optional[PRNGKey] = None + ) -> jnp.ndarray: + original_shape = x.shape[:-1] + t = i * jnp.ones((*x.shape[:-1], 1), dtype=jnp.int32) + + x = x.reshape(-1, x.shape[-1]) + t = t.reshape(-1, 1) + condition = condition.reshape(-1, condition.shape[-1]) + if self.grad_prediction: + if training: + grad = self.apply( + {"params": params}, x, t, condition=condition, training=training, rngs={"dropout": dropout_rng} + ) + else: + grad = self(x, t, condition=condition, training=training) + energy = jnp.zeros_like((*x.shape[:-1], 1), dtype=jnp.float32) + else: + if training: + energy_and_grad_fn = jax.vmap(jax.value_and_grad(lambda x, t, condition: self.apply( + {"params": params}, x, t, condition=condition, training=training, rngs={"dropout": dropout_rng} + ).mean())) + else: + energy_and_grad_fn = jax.vmap(jax.value_and_grad(lambda x, t, condition: self(x, t, condition=condition, training=training).mean())) + energy, grad = energy_and_grad_fn(x, t, condition) + return grad.reshape(*original_shape, self.x_dim), energy.reshape(*original_shape, 1) + + @partial(jax.jit, static_argnames=("training", 
"steps","step_size","noise_scale")) + def sample( + self, + rng: PRNGKey, + x_init: jnp.ndarray, + condition: Optional[jnp.ndarray] = None, + training: bool = False, + steps: Optional[int] = None, + step_size: Optional[float] = None, + noise_scale: Optional[float] = None, + params: Optional[Param] = None, + ) -> Tuple[PRNGKey, jnp.ndarray, Optional[jnp.ndarray]]: + steps = steps or self.steps + step_size = step_size or self.step_size + noise_scale = noise_scale or self.noise_scale + + def fn(input_tuple, i): + rng_, xt = input_tuple + rng_, noise_rng, dropout_rng_ = jax.random.split(rng_, 3) + + grad, energy = self.compute_grad(xt, i, condition=condition, training=training, params=params, dropout_rng=dropout_rng_) + + xt_1 = xt + step_size * grad + if self.clip_sampler: + xt_1 = jnp.clip(xt_1, self.x_min, self.x_max) + noise = jax.random.normal(noise_rng, xt_1.shape, dtype=jnp.float32) + xt_1 += (i>1) * jnp.sqrt(2 * step_size * noise_scale) * noise + + return (rng_, xt_1), (xt, grad, energy) + + output, history = jax.lax.scan(fn, (rng, x_init), jnp.arange(steps, 0, -1), unroll=True) + rng, action = output + return rng, action, history + + +@dataclass +class AnnealedLangevinDynamics(LangevinDynamics): + state: TrainState + dropout_rng: PRNGKey = field(pytree_node=True) + x_dim: int = field(pytree_node=False, default=None) + grad_prediction: bool = field(pytree_node=False, default=True) + steps: int = field(pytree_node=False, default=None) + step_size: float = field(pytree_node=False, default=None) + noise_scale: float = field(pytree_node=False, default=None) + clip_sampler: bool = field(pytree_node=False, default=None) + x_min: float = field(pytree_node=False, default=None) + x_max: float = field(pytree_node=False, default=None) + t_schedule_n: float = field(pytree_node=False, default=None) + t_diffusion: Tuple[float, float] = field(pytree_node=False, default=None) + noise_schedule_func: Callable = field(pytree_node=False, default=None) + + @classmethod + def create( + cls, + network: nn.Module, + rng: PRNGKey, + inputs: Sequence[jnp.ndarray], + x_dim: int, + grad_prediction: bool, + steps: int, + step_size: float, + noise_scale: float, + noise_schedule: str, + noise_schedule_params: Optional[Dict]=None, + clip_sampler: bool = False, + x_min: Optional[float] = None, + x_max: Optional[float] = None, + t_schedule_n: float=1.0, + epsilon: float=0.001, + optimizer: Optional[optax.GradientTransformation]=None, + clip_grad_norm: float=None + ) -> 'AnnealedLangevinDynamics': + ret = super().create( + network, + rng, + inputs, + x_dim, + grad_prediction, + steps, + step_size, + noise_scale, + clip_sampler, + x_min, + x_max, + optimizer, + clip_grad_norm, + ) + + if noise_schedule_params is None: + noise_schedule_params = {} + if noise_schedule == "cosine": + t_diffusion = [epsilon, 0.9946] + else: + t_diffusion = [epsilon, 1.0] + if noise_schedule == "linear": + noise_schedule_func = partial(linear_noise_schedule, **noise_schedule_params) + elif noise_schedule == "cosine": + noise_schedule_func = partial(cosine_noise_schedule, **noise_schedule_params) + elif noise_schedule == "none": + noise_schedule_func = lambda t: (jnp.ones_like(t), jnp.zeros_like(t)) + else: + raise NotImplementedError(f"Unsupported noise schedule: {noise_schedule}") + + return ret.replace( + t_schedule_n=t_schedule_n, + t_diffusion=t_diffusion, + noise_schedule_func=noise_schedule_func, + ) + + @partial(jax.jit, static_argnames=("training")) + def compute_grad( + self, + x: jnp.ndarray, + t: jnp.ndarray, + condition: 
Optional[jnp.ndarray] = None, + training: bool = False, + params: Optional[Param] = None, + dropout_rng: Optional[PRNGKey] = None + ) -> jnp.ndarray: + original_shape = x.shape[:-1] + t = t * jnp.ones((*x.shape[:-1], 1), dtype=jnp.int32) + + x = x.reshape(-1, x.shape[-1]) + t = t.reshape(-1, 1) + condition = condition.reshape(-1, condition.shape[-1]) + if self.grad_prediction: + if training: + grad = self.apply( + {"params": params}, x, t, condition=condition, training=training, rngs={"dropout": dropout_rng} + ) + else: + grad = self(x, t, condition=condition, training=training) + energy = jnp.zeros((*x.shape[:-1], 1), dtype=jnp.float32) + else: + if training: + energy_and_grad_fn = jax.vmap(jax.value_and_grad(lambda x, t, condition: self.apply( + {"params": params}, x, t, condition=condition, training=training, rngs={"dropout": dropout_rng} + ).mean())) + else: + energy_and_grad_fn = jax.vmap(jax.value_and_grad(lambda x, t, condition: self(x, t, condition=condition, training=training).mean())) + energy, grad = energy_and_grad_fn(x, t, condition) + # alpha, sigma = self.noise_schedule_func(t) + # grad = alpha * grad - sigma * x + return grad.reshape(*original_shape, self.x_dim), energy.reshape(*original_shape, 1) + + def add_noise(self, rng: PRNGKey, x: jnp.ndarray) -> Tuple[PRNGKey, jnp.ndarray, jnp.ndarray, jnp.ndarray]: + rng, t_rng, noise_rng = jax.random.split(rng, 3) + t = jax.random.uniform(t_rng, (*x.shape[:-1], 1), dtype=jnp.float32, minval=self.t_diffusion[0], maxval=self.t_diffusion[1]) + alpha, sigma = self.noise_schedule_func(t) + eps = jax.random.normal(noise_rng, x.shape, dtype=jnp.float32) + xt = alpha * x + sigma * eps + return rng, xt, t, eps + + @partial(jax.jit, static_argnames=("training", "steps","step_size","noise_scale")) + def sample( + self, + rng: PRNGKey, + x_init: jnp.ndarray, + condition: Optional[jnp.ndarray] = None, + training: bool = False, + steps: Optional[int] = None, + step_size: Optional[float] = None, + noise_scale: Optional[float] = None, + params: Optional[Param] = None, + ) -> Tuple[PRNGKey, jnp.ndarray, Optional[jnp.ndarray]]: + steps = steps or self.steps + # step_size = step_size or self.step_size + # noise_scale = noise_scale or self.noise_scale + t_schedule_n = 1.0 + from flowrl.flow.continuous_ddpm import quad_t_schedule + ts = quad_t_schedule(steps, n=t_schedule_n, tmin=self.t_diffusion[0], tmax=self.t_diffusion[1]) + alpha_hats = self.noise_schedule_func(ts)[0] ** 2 + alphas = alpha_hats[1:] / alpha_hats[:-1] + alphas = jnp.concat([jnp.ones((1, )), alphas], axis=0) + betas = 1 - alphas + alpha1, alpha2 = self.noise_schedule_func(ts) + + t_proto = jnp.ones((*x_init.shape[:-1], 1), dtype=jnp.int32) + + def fn(input_tuple, i): + rng_, xt = input_tuple + rng_, dropout_rng_, key_ = jax.random.split(rng_, 3) + input_t = t_proto * ts[i] + + q_grad, energy = self.compute_grad(xt, ts[i], condition=condition, training=training, params=params, dropout_rng=dropout_rng_) + eps_theta = q_grad + + x0_hat = (xt - jnp.sqrt(1 - alpha_hats[i]) * eps_theta) / jnp.sqrt(alpha_hats[i]) + x0_hat = jnp.clip(x0_hat, self.x_min, self.x_max) if self.clip_sampler else x0_hat + + mean_coef1 = jnp.sqrt(alpha_hats[i-1]) * betas[i] / (1 - alpha_hats[i]) + mean_coef2 = jnp.sqrt(alphas[i]) * (1 - alpha_hats[i-1]) / (1 - alpha_hats[i]) + xt_1 = mean_coef1 * x0_hat + mean_coef2 * xt + xt_1 += (i>1) * jnp.sqrt(betas[i]) * jax.random.normal(key_, xt_1.shape) + + return (rng_, xt_1), (xt, eps_theta, energy) + + output, history = jax.lax.scan(fn, (rng, x_init), jnp.arange(steps, 
0, -1), unroll=True) + rng, action = output + return rng, action, history diff --git a/scripts/dmc/aca.sh b/scripts/dmc/aca.sh new file mode 100644 index 0000000..c4ac7c0 --- /dev/null +++ b/scripts/dmc/aca.sh @@ -0,0 +1,68 @@ +# Specify which GPUs to use +GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use +SEEDS=(0 1 2 3) +NUM_EACH_GPU=3 + +PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) + +TASKS=( + "acrobot-swingup" + "ball_in_cup-catch" + "cartpole-balance" + "cartpole-balance_sparse" + "cartpole-swingup" + "cartpole-swingup_sparse" + "cheetah-run" + "dog-run" + "dog-stand" + "dog-trot" + "dog-walk" + "finger-spin" + "finger-turn_easy" + "finger-turn_hard" + "fish-swim" + "hopper-hop" + "hopper-stand" + "humanoid-run" + "humanoid-stand" + "humanoid-walk" + "pendulum-swingup" + "quadruped-run" + "quadruped-walk" + "reacher-easy" + "reacher-hard" + "walker-run" + "walker-stand" + "walker-walk" +) + +SHARED_ARGS=( + "algo=aca" + "log.tag=default" + "log.project=flow-rl" + "log.entity=lambda-rl" +) + +run_task() { + task=$1 + seed=$2 + slot=$3 + num_gpus=${#GPUS[@]} + device_idx=$((slot % num_gpus)) + device=${GPUS[$device_idx]} + echo "Running $env $seed on GPU $device" + command="python3 examples/online/main_dmc_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}" + if [ -n "$DRY_RUN" ]; then + echo $command + else + echo $command + $command + fi +} + +. env_parallel.bash +if [ -n "$DRY_RUN" ]; then + env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +else + env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +fi diff --git a/scripts/dmc/ctrl_qsm.sh b/scripts/dmc/ctrl_qsm.sh new file mode 100644 index 0000000..34729d6 --- /dev/null +++ b/scripts/dmc/ctrl_qsm.sh @@ -0,0 +1,68 @@ +# Specify which GPUs to use +GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use +SEEDS=(0 1 2 3) +NUM_EACH_GPU=3 + +PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) + +TASKS=( + "acrobot-swingup" + "ball_in_cup-catch" + "cartpole-balance" + "cartpole-balance_sparse" + "cartpole-swingup" + "cartpole-swingup_sparse" + "cheetah-run" + "dog-run" + "dog-stand" + "dog-trot" + "dog-walk" + "finger-spin" + "finger-turn_easy" + "finger-turn_hard" + "fish-swim" + "hopper-hop" + "hopper-stand" + "humanoid-run" + "humanoid-stand" + "humanoid-walk" + "pendulum-swingup" + "quadruped-run" + "quadruped-walk" + "reacher-easy" + "reacher-hard" + "walker-run" + "walker-stand" + "walker-walk" +) + +SHARED_ARGS=( + "algo=ctrl_qsm" + "log.tag=default" + "log.project=flow-rl" + "log.entity=lambda-rl" +) + +run_task() { + task=$1 + seed=$2 + slot=$3 + num_gpus=${#GPUS[@]} + device_idx=$((slot % num_gpus)) + device=${GPUS[$device_idx]} + echo "Running $env $seed on GPU $device" + command="python3 examples/online/main_dmc_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}" + if [ -n "$DRY_RUN" ]; then + echo $command + else + echo $command + $command + fi +} + +. 
env_parallel.bash +if [ -n "$DRY_RUN" ]; then + env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +else + env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +fi diff --git a/scripts/dmc/qsm.sh b/scripts/dmc/qsm.sh new file mode 100644 index 0000000..fea0999 --- /dev/null +++ b/scripts/dmc/qsm.sh @@ -0,0 +1,68 @@ +# Specify which GPUs to use +GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use +SEEDS=(0 1 2 3 4) +NUM_EACH_GPU=3 + +PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) + +TASKS=( + "acrobot-swingup" + "ball_in_cup-catch" + "cartpole-balance" + "cartpole-balance_sparse" + "cartpole-swingup" + "cartpole-swingup_sparse" + "cheetah-run" + "dog-run" + "dog-stand" + "dog-trot" + "dog-walk" + "finger-spin" + "finger-turn_easy" + "finger-turn_hard" + "fish-swim" + "hopper-hop" + "hopper-stand" + "humanoid-run" + "humanoid-stand" + "humanoid-walk" + "pendulum-swingup" + "quadruped-run" + "quadruped-walk" + "reacher-easy" + "reacher-hard" + "walker-run" + "walker-stand" + "walker-walk" +) + +SHARED_ARGS=( + "algo=qsm" + "log.tag=default" + "log.project=flow-rl" + "log.entity=lamda-rl" +) + +run_task() { + task=$1 + seed=$2 + slot=$3 + num_gpus=${#GPUS[@]} + device_idx=$((slot % num_gpus)) + device=${GPUS[$device_idx]} + echo "Running $env $seed on GPU $device" + command="python3 examples/online/main_dmc_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}" + if [ -n "$DRY_RUN" ]; then + echo $command + else + echo $command + $command + fi +} + +. env_parallel.bash +if [ -n "$DRY_RUN" ]; then + env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +else + env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +fi diff --git a/scripts/mujoco/alac.sh b/scripts/mujoco/alac.sh new file mode 100644 index 0000000..2232a3a --- /dev/null +++ b/scripts/mujoco/alac.sh @@ -0,0 +1,60 @@ +# Specify which GPUs to use +GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use +SEEDS=(0 1 2 3) +NUM_EACH_GPU=3 + +PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) + +TASKS=( + "Ant-v5" + "HalfCheetah-v5" + # "Hopper-v5" + # "HumanoidStandup-v5" + "Humanoid-v5" + # "InvertedDoublePendulum-v5" + # "InvertedPendulum-v5" + # "Pusher-v5" + # "Reacher-v5" + # "Swimmer-v5" + "Walker2d-v5" +) + +SHARED_ARGS=( + "algo=alac" + # "algo.ld.step_size=0.1" + # "algo.ld.noise_scale=0.01" + # "algo.ld.steps=50" + # "log.tag=noise_none-stepsize0.1-noise0.01-steps50-no_last_noise" + "algo.ld.activation=relu" + "algo.ld.steps=20" + "algo.ld.noise_schedule=cosine" + "log.tag=use_ld_but_actually_diffusion-decay_q-temp0.5" + "log.project=flow-rl" + "log.entity=lamda-rl" +) + + +run_task() { + task=$1 + seed=$2 + slot=$3 + num_gpus=${#GPUS[@]} + device_idx=$((slot % num_gpus)) + device=${GPUS[$device_idx]} + echo "Running $env $seed on GPU $device" + command="python3 examples/online/main_mujoco_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}" + if [ -n "$DRY_RUN" ]; then + echo $command + else + echo $command + $command + fi +} + + +. 
env_parallel.bash +if [ -n "$DRY_RUN" ]; then + env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +else + env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +fi diff --git a/scripts/mujoco/qsm.sh b/scripts/mujoco/qsm.sh new file mode 100644 index 0000000..04023c8 --- /dev/null +++ b/scripts/mujoco/qsm.sh @@ -0,0 +1,53 @@ +# Specify which GPUs to use +GPUS=(0 1 2 3 4 5 6 7) # Modify this array to specify which GPUs to use +SEEDS=(0 1 2 3 4) +NUM_EACH_GPU=3 + +PARALLEL=$((NUM_EACH_GPU * ${#GPUS[@]})) + +TASKS=( + "Ant-v5" + "HalfCheetah-v5" + "Hopper-v5" + "HumanoidStandup-v5" + "Humanoid-v5" + "InvertedDoublePendulum-v5" + "InvertedPendulum-v5" + "Pusher-v5" + "Reacher-v5" + "Swimmer-v5" + "Walker2d-v5" +) + +SHARED_ARGS=( + "algo=qsm" + "log.tag=default" + "log.project=flow-rl" + "log.entity=lamda-rl" +) + + +run_task() { + task=$1 + seed=$2 + slot=$3 + num_gpus=${#GPUS[@]} + device_idx=$((slot % num_gpus)) + device=${GPUS[$device_idx]} + echo "Running $task $seed on GPU $device" + command="python3 examples/online/main_mujoco_offpolicy.py task=$task device=$device seed=$seed ${SHARED_ARGS[@]}" + if [ -n "$DRY_RUN" ]; then + echo $command + else + echo $command + $command + fi +} + + +. env_parallel.bash +if [ -n "$DRY_RUN" ]; then + env_parallel -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +else + env_parallel --bar --results log/parallel/$name -P${PARALLEL} run_task {1} {2} {%} ::: ${TASKS[@]} ::: ${SEEDS[@]} +fi
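Note on the sampler added in flowrl/flow/langevin_dynamics.py: despite its name, AnnealedLangevinDynamics.sample performs a DDPM-style reverse update rather than plain Langevin steps. It evaluates the noise schedule on the time grid from quad_t_schedule, recovers per-step alphas and betas from consecutive alpha_hat values, clips the implied x0 estimate (when clip_sampler is set), and forms the posterior mean, adding noise on every step except the last. The sketch below is a minimal, self-contained illustration of that update only; cosine_noise_schedule and quad_t_schedule are assumed local stand-ins for the helpers in flowrl.flow.continuous_ddpm (their exact definitions are not part of this diff), and the epsilon prediction is a dummy placeholder rather than the trained score network.

# Illustrative sketch (assumptions noted in comments), not the flowrl implementation itself.
import jax
import jax.numpy as jnp

def cosine_noise_schedule(t):
    # Assumed form: alpha(t) = cos(pi*t/2), sigma(t) = sin(pi*t/2).
    return jnp.cos(0.5 * jnp.pi * t), jnp.sin(0.5 * jnp.pi * t)

def quad_t_schedule(steps, n=1.0, tmin=1e-3, tmax=1.0):
    # Assumed form: monotone time grid in [tmin, tmax]; n=1.0 gives a linear grid.
    u = jnp.linspace(0.0, 1.0, steps + 1) ** n
    return tmin + (tmax - tmin) * u

steps, act_dim = 20, 6
ts = quad_t_schedule(steps, n=1.0, tmin=1e-3, tmax=0.9946)
alpha_hats = cosine_noise_schedule(ts)[0] ** 2            # cumulative signal fraction at each grid point
alphas = jnp.concatenate([jnp.ones((1,)), alpha_hats[1:] / alpha_hats[:-1]])
betas = 1.0 - alphas

def reverse_step(key, xt, i, eps_theta):
    # One DDPM posterior step from time index i to i-1, clipping the x0 estimate to [-1, 1].
    x0_hat = (xt - jnp.sqrt(1.0 - alpha_hats[i]) * eps_theta) / jnp.sqrt(alpha_hats[i])
    x0_hat = jnp.clip(x0_hat, -1.0, 1.0)
    mean = (jnp.sqrt(alpha_hats[i - 1]) * betas[i] / (1.0 - alpha_hats[i])) * x0_hat \
         + (jnp.sqrt(alphas[i]) * (1.0 - alpha_hats[i - 1]) / (1.0 - alpha_hats[i])) * xt
    noise = jax.random.normal(key, xt.shape)
    return mean + (i > 1) * jnp.sqrt(betas[i]) * noise     # no noise on the final step

rng = jax.random.PRNGKey(0)
xt = jax.random.normal(rng, (act_dim,))                    # start from a Gaussian "action"
for i in range(steps, 0, -1):
    rng, key = jax.random.split(rng)
    eps_theta = jnp.zeros_like(xt)                         # dummy stand-in for the score network output
    xt = reverse_step(key, xt, i, eps_theta)

For running the new algorithms end to end, the launcher scripts above reduce to a single Hydra-style invocation per task and seed, e.g. python3 examples/online/main_mujoco_offpolicy.py algo=qsm task=Hopper-v5 seed=0 device=0 (and main_dmc_offpolicy.py with algo=aca, ctrl_qsm, or qsm for the DMC suite).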