Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 41 additions & 0 deletions examples/online/config/dmc/algo/aca.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
# @package _global_

algo:
name: aca
target_update_freq: 1
feature_dim: 512
rff_dim: 1024
critic_hidden_dims: [512, 512]
reward_hidden_dims: [512, 512]
phi_hidden_dims: [512, 512]
mu_hidden_dims: [512, 512]
ctrl_coef: 1.0
reward_coef: 1.0
critic_coef: 1.0
critic_activation: elu # not used
back_critic_grad: false
feature_lr: 0.0001
critic_lr: 0.0003
discount: 0.99
num_samples: 10
ema: 0.005
feature_ema: 0.005
clip_grad_norm: null
temp: 0.1
diffusion:
time_dim: 64
mlp_hidden_dims: [512, 512, 512]
lr: 0.0003
end_lr: null
lr_decay_steps: null
lr_decay_begin: null
steps: 20
clip_sampler: true
x_min: -1.0
x_max: 1.0
solver: ddpm
num_noises: 25
linear: false
ranking: true

norm_obs: true
49 changes: 49 additions & 0 deletions examples/online/config/dmc/algo/ctrl_qsm.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
# @package _global_

algo:
name: ctrl_qsm
actor_update_freq: 1
target_update_freq: 1
discount: 0.99
ema: 0.005
# critic_hidden_dims: [512, 512, 512] # not used
critic_activation: elu # not used
critic_ensemble_size: 2
layer_norm: true
critic_lr: 0.0003
clip_grad_norm: null

# below are the ctrl-specific (feature learning) params; this file configures ctrl_qsm, not ctrl_td3
feature_dim: 512
feature_lr: 0.0001
feature_ema: 0.005
phi_hidden_dims: [512, 512]
mu_hidden_dims: [512, 512]
critic_hidden_dims: [512, ]
reward_hidden_dims: [512, ]
rff_dim: 1024
ctrl_coef: 1.0
reward_coef: 1.0
back_critic_grad: false
critic_coef: 1.0

num_noises: 25
linear: false
ranking: true

num_samples: 10
temp: 0.1
diffusion:
time_dim: 64
mlp_hidden_dims: [512, 512, 512]
lr: 0.0003
end_lr: null
lr_decay_steps: null
lr_decay_begin: null
steps: 20
clip_sampler: true
x_min: -1.0
x_max: 1.0
solver: ddpm

norm_obs: true
23 changes: 23 additions & 0 deletions examples/online/config/dmc/algo/qsm.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# @package _global_

algo:
name: qsm
critic_hidden_dims: [512, 512, 512]
critic_activation: elu
critic_lr: 0.0003
discount: 0.99
num_samples: 10
ema: 0.005
temp: 0.1
diffusion:
time_dim: 64
mlp_hidden_dims: [512, 512, 512]
lr: 0.0003
end_lr: null
lr_decay_steps: null
lr_decay_begin: null
steps: 20
clip_sampler: true
x_min: -1.0
x_max: 1.0
solver: ddpm
24 changes: 24 additions & 0 deletions examples/online/config/mujoco/algo/alac.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
# @package _global_

algo:
name: alac
discount: 0.99
num_samples: 10
ema: 0.005
ld:
resnet: false
activation: relu
ensemble_size: 2
time_dim: 64
hidden_dims: [512, 512]
cond_hidden_dims: [128, 128]
steps: 20
step_size: 0.05
noise_scale: 1.0
noise_schedule: "none"
clip_sampler: true
x_min: -1.0
x_max: 1.0
epsilon: 0.001
lr: 0.0003
clip_grad_norm: null
23 changes: 23 additions & 0 deletions examples/online/config/mujoco/algo/idem.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# @package _global_

algo:
name: idem
critic_hidden_dims: [256, 256]
critic_lr: 0.0003
discount: 0.99
num_samples: 10
num_reverse_samples: 500
ema: 0.005
temp: 0.2
diffusion:
time_dim: 64
mlp_hidden_dims: [256, 256]
lr: 0.0003
end_lr: null
lr_decay_steps: null
lr_decay_begin: null
steps: 20
clip_sampler: true
x_min: -1.0
x_max: 1.0
solver: ddpm
23 changes: 23 additions & 0 deletions examples/online/config/mujoco/algo/qsm.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# @package _global_

algo:
name: qsm
critic_hidden_dims: [512, 512]
critic_activation: relu
critic_lr: 0.0003
discount: 0.99
num_samples: 10
ema: 0.005
temp: 0.1
diffusion:
time_dim: 64
mlp_hidden_dims: [512, 512]
lr: 0.0003
end_lr: null
lr_decay_steps: null
lr_decay_begin: null
steps: 20
clip_sampler: true
x_min: -1.0
x_max: 1.0
solver: ddpm
1 change: 0 additions & 1 deletion examples/online/config/mujoco/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,6 @@ random_frames: 5_000
eval_frames: 10_000
log_frames: 1_000
lap_reset_frames: 250
eval_episodes: 10
log:
dir: logs
tag: debug
Expand Down
3 changes: 3 additions & 0 deletions examples/online/main_dmc_offpolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,10 @@
"td7": TD7Agent,
"sdac": SDACAgent,
"dpmd": DPMDAgent,
"qsm": QSMAgent,
"ctrl_td3": CtrlTD3Agent,
"ctrl_qsm": CtrlQSMAgent,
"aca": ACAAgent,
}

class OffPolicyTrainer():
Expand Down
5 changes: 4 additions & 1 deletion examples/online/main_mujoco_offpolicy.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
import jax
import numpy as np
import omegaconf
import wandb
from omegaconf import OmegaConf
from tqdm import tqdm

import wandb
from flowrl.agent.online import *
from flowrl.config.online.mujoco import Config
from flowrl.dataset.buffer.state import ReplayBuffer
Expand All @@ -25,6 +25,9 @@
"td7": TD7Agent,
"sdac": SDACAgent,
"dpmd": DPMDAgent,
"qsm": QSMAgent,
"idem": IDEMAgent,
"alac": ALACAgent,
}

class OffPolicyTrainer():
Expand Down
11 changes: 10 additions & 1 deletion flowrl/agent/online/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,15 @@
from ..base import BaseAgent
from .ctrl.ctrl import CtrlTD3Agent
from .alac.alac import ALACAgent
from .ctrl import *
from .dpmd import DPMDAgent
from .idem import IDEMAgent
from .ppo import PPOAgent
from .qsm import QSMAgent
from .sac import SACAgent
from .sdac import SDACAgent
from .td3 import TD3Agent
from .td7.td7 import TD7Agent
from .unirep import *

__all__ = [
"BaseAgent",
Expand All @@ -15,5 +19,10 @@
"SDACAgent",
"DPMDAgent",
"PPOAgent",
"QSMAgent",
"IDEMAgent",
"ALACAgent",
"CtrlTD3Agent",
"CtrlQSMAgent",
"ACAAgent",
]
Loading