2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/base_2_base.yml
@@ -27,6 +27,8 @@ write_timing_metrics: True
 save_config_to_gcs: False
 log_period: 10000000000 # Flushes Tensorboard

+tmp_dir: '/tmp' # directory for downloading gs:// files
+
 pretrained_model_name_or_path: 'stabilityai/stable-diffusion-2-base'
 unet_checkpoint: ''
 revision: 'main'
2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/base_flux_dev.yml
@@ -27,6 +27,8 @@ gcs_metrics: False
 save_config_to_gcs: False
 log_period: 100

+tmp_dir: '/tmp' # directory for downloading gs:// files
+
 pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev'
 clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax'
 t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax'
2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/base_flux_dev_multi_res.yml
@@ -27,6 +27,8 @@ write_timing_metrics: True
 save_config_to_gcs: False
 log_period: 100

+tmp_dir: '/tmp' # directory for downloading gs:// files
+
 pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev'
 clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax'
 t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax'
2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/base_flux_schnell.yml
@@ -27,6 +27,8 @@ write_timing_metrics: True
 save_config_to_gcs: False
 log_period: 100

+tmp_dir: '/tmp' # directory for downloading gs:// files
+
 pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-schnell'
 clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax'
 t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax'
3 changes: 3 additions & 0 deletions src/maxdiffusion/configs/base_wan_14b.yml
@@ -27,6 +27,8 @@ gcs_metrics: False
 save_config_to_gcs: False
 log_period: 100

+tmp_dir: '/tmp' # directory for downloading gs:// files
+
 pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers'

 # Overrides the transformer from pretrained_model_name_or_path
@@ -151,6 +153,7 @@ logical_axis_rules: [
   ['conv_batch', ['data','fsdp']],
   ['out_channels', 'tensor'],
   ['conv_out', 'fsdp'],
+  ['conv_in', ''] # not sharded
 ]
 data_sharding: [['data', 'fsdp', 'tensor']]

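Each entry in logical_axis_rules maps a logical axis name to a mesh axis (or to nothing, as with the new conv_in entry). The sketch below shows how such a rule set resolves into a PartitionSpec, assuming Flax's flax.linen.partitioning helpers (axis_rules, logical_to_mesh_axes); the rule values mirror the yaml above, and the queried axis names are purely illustrative.

# Sketch only: resolve a logical-axis rule set to a PartitionSpec.
# None plays the role of '' in the yaml ("not sharded").
from flax.linen import partitioning as nn_partitioning

rules = (
    ("conv_batch", ("data", "fsdp")),
    ("out_channels", "tensor"),
    ("conv_out", "fsdp"),
    ("conv_in", None),
)

with nn_partitioning.axis_rules(rules):
  # Resolve the logical names of a hypothetical 3-D array's dimensions.
  spec = nn_partitioning.logical_to_mesh_axes(("conv_batch", "conv_in", "out_channels"))

print(spec)  # PartitionSpec(('data', 'fsdp'), None, 'tensor')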
2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/base_xl.yml
@@ -27,6 +27,8 @@ gcs_metrics: False
 save_config_to_gcs: False
 log_period: 100

+tmp_dir: '/tmp' # directory for downloading gs:// files
+
 pretrained_model_name_or_path: 'stabilityai/stable-diffusion-xl-base-1.0'
 unet_checkpoint: ''
 revision: 'refs/pr/95'
2 changes: 2 additions & 0 deletions src/maxdiffusion/configs/base_xl_lightning.yml
@@ -26,6 +26,8 @@ write_timing_metrics: True
 save_config_to_gcs: False
 log_period: 100

+tmp_dir: '/tmp' # directory for downloading gs:// files
+
 pretrained_model_name_or_path: 'stabilityai/stable-diffusion-xl-base-1.0'
 unet_checkpoint: ''
 revision: 'refs/pr/95'
24 changes: 13 additions & 11 deletions src/maxdiffusion/pipelines/wan/wan_pipeline.py
@@ -71,8 +71,9 @@ def create_sharded_logical_transformer(
   ):

     def create_model(rngs: nnx.Rngs, wan_config: dict):
-      wan_transformer = WanModel(**wan_config, rngs=rngs)
-      return wan_transformer
+      with nn_partitioning.axis_rules(config.logical_axis_rules):
+        wan_transformer = WanModel(**wan_config, rngs=rngs)
+        return wan_transformer

     # 1. Load config.
     if restored_checkpoint:
@@ -204,15 +205,16 @@ def load_tokenizer(cls, config: HyperParameters):
   def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters):

     def create_model(rngs: nnx.Rngs, config: HyperParameters):
-      wan_vae = AutoencoderKLWan.from_config(
-          config.pretrained_model_name_or_path,
-          subfolder="vae",
-          rngs=rngs,
-          mesh=mesh,
-          dtype=config.activations_dtype,
-          weights_dtype=config.weights_dtype,
-      )
-      return wan_vae
+      with nn_partitioning.axis_rules(config.logical_axis_rules):
+        wan_vae = AutoencoderKLWan.from_config(
+            config.pretrained_model_name_or_path,
+            subfolder="vae",
+            rngs=rngs,
+            mesh=mesh,
+            dtype=config.activations_dtype,
+            weights_dtype=config.weights_dtype,
+        )
+        return wan_vae

     # 1. eval shape
     p_model_factory = partial(create_model, config=config)
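The point of this change is that module construction now happens inside nn_partitioning.axis_rules, so logical partitioning annotations created while parameters are built are interpreted under the config's logical-to-mesh mapping rather than an empty rule set. A minimal sketch of that wrapping pattern follows; build_with_rules is a hypothetical helper, not maxdiffusion's API.

# Illustrative pattern only: build a model while logical-axis rules are active,
# so sharding metadata recorded at parameter-creation time uses these rules.
from flax.linen import partitioning as nn_partitioning

def build_with_rules(model_cls, model_kwargs, logical_axis_rules, rngs):
  with nn_partitioning.axis_rules(logical_axis_rules):
    return model_cls(**model_kwargs, rngs=rngs)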
9 changes: 5 additions & 4 deletions src/maxdiffusion/pyconfig.py
@@ -196,14 +196,15 @@ def user_init(raw_keys):

   # Orbax doesn't save the tokenizer params, instead it loads them from the pretrained_model_name_or_path
   raw_keys["tokenizer_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"]
+  tmp_dir = raw_keys.get("tmp_dir", "/tmp")
   if "gs://" in raw_keys["tokenizer_model_name_or_path"]:
-    raw_keys["pretrained_model_name_or_path"] = max_utils.download_blobs(raw_keys["pretrained_model_name_or_path"], "/tmp")
+    raw_keys["pretrained_model_name_or_path"] = max_utils.download_blobs(raw_keys["pretrained_model_name_or_path"], tmp_dir)
   if "gs://" in raw_keys["pretrained_model_name_or_path"]:
-    raw_keys["pretrained_model_name_or_path"] = max_utils.download_blobs(raw_keys["pretrained_model_name_or_path"], "/tmp")
+    raw_keys["pretrained_model_name_or_path"] = max_utils.download_blobs(raw_keys["pretrained_model_name_or_path"], tmp_dir)
   if "gs://" in raw_keys["unet_checkpoint"]:
-    raw_keys["unet_checkpoint"] = max_utils.download_blobs(raw_keys["unet_checkpoint"], "/tmp")
+    raw_keys["unet_checkpoint"] = max_utils.download_blobs(raw_keys["unet_checkpoint"], tmp_dir)
   if "gs://" in raw_keys["tokenizer_model_name_or_path"]:
-    raw_keys["tokenizer_model_name_or_path"] = max_utils.download_blobs(raw_keys["tokenizer_model_name_or_path"], "/tmp")
+    raw_keys["tokenizer_model_name_or_path"] = max_utils.download_blobs(raw_keys["tokenizer_model_name_or_path"], tmp_dir)
   if "gs://" in raw_keys["dataset_name"]:
     raw_keys["dataset_name"] = max_utils.download_blobs(raw_keys["dataset_name"], raw_keys["dataset_save_location"])
     raw_keys["dataset_save_location"] = raw_keys["dataset_name"]
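The net effect in user_init: gs:// paths are now downloaded under the configurable tmp_dir, and because the key is read with raw_keys.get, older configs without tmp_dir keep the previous /tmp behaviour. The sketch below illustrates that resolution; download_blobs here is a hypothetical stand-in for max_utils.download_blobs (which actually copies the objects), and the bucket path is made up.

# Sketch of the intended behaviour, not the real implementation.
import os

def download_blobs(gcs_path: str, destination: str) -> str:
  # Pretend-download: just compute where the files would land locally.
  return os.path.join(destination, gcs_path.removeprefix("gs://"))

def resolve(raw_keys: dict, key: str) -> None:
  tmp_dir = raw_keys.get("tmp_dir", "/tmp")  # new key, defaults to /tmp
  if "gs://" in raw_keys[key]:
    raw_keys[key] = download_blobs(raw_keys[key], tmp_dir)

raw_keys = {"unet_checkpoint": "gs://my-bucket/ckpts/unet", "tmp_dir": "/mnt/localssd"}
resolve(raw_keys, "unet_checkpoint")
print(raw_keys["unet_checkpoint"])  # /mnt/localssd/my-bucket/ckpts/unet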