diff --git a/src/maxdiffusion/configs/base_2_base.yml b/src/maxdiffusion/configs/base_2_base.yml index b535762e..fc344d42 100644 --- a/src/maxdiffusion/configs/base_2_base.yml +++ b/src/maxdiffusion/configs/base_2_base.yml @@ -27,6 +27,8 @@ write_timing_metrics: True save_config_to_gcs: False log_period: 10000000000 # Flushes Tensorboard +tmp_dir: '/tmp' # directory for downloading gs:// files + pretrained_model_name_or_path: 'stabilityai/stable-diffusion-2-base' unet_checkpoint: '' revision: 'main' diff --git a/src/maxdiffusion/configs/base_flux_dev.yml b/src/maxdiffusion/configs/base_flux_dev.yml index a7ca1350..29bd32b8 100644 --- a/src/maxdiffusion/configs/base_flux_dev.yml +++ b/src/maxdiffusion/configs/base_flux_dev.yml @@ -27,6 +27,8 @@ gcs_metrics: False save_config_to_gcs: False log_period: 100 +tmp_dir: '/tmp' # directory for downloading gs:// files + pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev' clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax' t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax' diff --git a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml index 0da843fd..c8480365 100644 --- a/src/maxdiffusion/configs/base_flux_dev_multi_res.yml +++ b/src/maxdiffusion/configs/base_flux_dev_multi_res.yml @@ -27,6 +27,8 @@ write_timing_metrics: True save_config_to_gcs: False log_period: 100 +tmp_dir: '/tmp' # directory for downloading gs:// files + pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-dev' clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax' t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax' diff --git a/src/maxdiffusion/configs/base_flux_schnell.yml b/src/maxdiffusion/configs/base_flux_schnell.yml index 300ec039..cd3a289b 100644 --- a/src/maxdiffusion/configs/base_flux_schnell.yml +++ b/src/maxdiffusion/configs/base_flux_schnell.yml @@ -27,6 +27,8 @@ write_timing_metrics: True save_config_to_gcs: False log_period: 100 +tmp_dir: '/tmp' # directory for downloading gs:// files + pretrained_model_name_or_path: 'black-forest-labs/FLUX.1-schnell' clip_model_name_or_path: 'ariG23498/clip-vit-large-patch14-text-flax' t5xxl_model_name_or_path: 'ariG23498/t5-v1-1-xxl-flax' diff --git a/src/maxdiffusion/configs/base_wan_14b.yml b/src/maxdiffusion/configs/base_wan_14b.yml index 8149c829..2defc4ad 100644 --- a/src/maxdiffusion/configs/base_wan_14b.yml +++ b/src/maxdiffusion/configs/base_wan_14b.yml @@ -27,6 +27,8 @@ gcs_metrics: False save_config_to_gcs: False log_period: 100 +tmp_dir: '/tmp' # directory for downloading gs:// files + pretrained_model_name_or_path: 'Wan-AI/Wan2.1-T2V-14B-Diffusers' # Overrides the transformer from pretrained_model_name_or_path @@ -151,6 +153,7 @@ logical_axis_rules: [ ['conv_batch', ['data','fsdp']], ['out_channels', 'tensor'], ['conv_out', 'fsdp'], + ['conv_in', ''] # not sharded ] data_sharding: [['data', 'fsdp', 'tensor']] diff --git a/src/maxdiffusion/configs/base_xl.yml b/src/maxdiffusion/configs/base_xl.yml index aa07940e..2cfc0f54 100644 --- a/src/maxdiffusion/configs/base_xl.yml +++ b/src/maxdiffusion/configs/base_xl.yml @@ -27,6 +27,8 @@ gcs_metrics: False save_config_to_gcs: False log_period: 100 +tmp_dir: '/tmp' # directory for downloading gs:// files + pretrained_model_name_or_path: 'stabilityai/stable-diffusion-xl-base-1.0' unet_checkpoint: '' revision: 'refs/pr/95' diff --git a/src/maxdiffusion/configs/base_xl_lightning.yml b/src/maxdiffusion/configs/base_xl_lightning.yml index ee2e59d5..28ff10e7 100644 --- a/src/maxdiffusion/configs/base_xl_lightning.yml +++ b/src/maxdiffusion/configs/base_xl_lightning.yml @@ -26,6 +26,8 @@ write_timing_metrics: True save_config_to_gcs: False log_period: 100 +tmp_dir: '/tmp' # directory for downloading gs:// files + pretrained_model_name_or_path: 'stabilityai/stable-diffusion-xl-base-1.0' unet_checkpoint: '' revision: 'refs/pr/95' diff --git a/src/maxdiffusion/pipelines/wan/wan_pipeline.py b/src/maxdiffusion/pipelines/wan/wan_pipeline.py index c78d8bae..a4d950a2 100644 --- a/src/maxdiffusion/pipelines/wan/wan_pipeline.py +++ b/src/maxdiffusion/pipelines/wan/wan_pipeline.py @@ -71,8 +71,9 @@ def create_sharded_logical_transformer( ): def create_model(rngs: nnx.Rngs, wan_config: dict): - wan_transformer = WanModel(**wan_config, rngs=rngs) - return wan_transformer + with nn_partitioning.axis_rules(config.logical_axis_rules): + wan_transformer = WanModel(**wan_config, rngs=rngs) + return wan_transformer # 1. Load config. if restored_checkpoint: @@ -204,15 +205,16 @@ def load_tokenizer(cls, config: HyperParameters): def load_vae(cls, devices_array: np.array, mesh: Mesh, rngs: nnx.Rngs, config: HyperParameters): def create_model(rngs: nnx.Rngs, config: HyperParameters): - wan_vae = AutoencoderKLWan.from_config( - config.pretrained_model_name_or_path, - subfolder="vae", - rngs=rngs, - mesh=mesh, - dtype=config.activations_dtype, - weights_dtype=config.weights_dtype, - ) - return wan_vae + with nn_partitioning.axis_rules(config.logical_axis_rules): + wan_vae = AutoencoderKLWan.from_config( + config.pretrained_model_name_or_path, + subfolder="vae", + rngs=rngs, + mesh=mesh, + dtype=config.activations_dtype, + weights_dtype=config.weights_dtype, + ) + return wan_vae # 1. eval shape p_model_factory = partial(create_model, config=config) diff --git a/src/maxdiffusion/pyconfig.py b/src/maxdiffusion/pyconfig.py index 3bb5bd13..a7b9eedc 100644 --- a/src/maxdiffusion/pyconfig.py +++ b/src/maxdiffusion/pyconfig.py @@ -196,14 +196,15 @@ def user_init(raw_keys): # Orbax doesn't save the tokenizer params, instead it loads them from the pretrained_model_name_or_path raw_keys["tokenizer_model_name_or_path"] = raw_keys["pretrained_model_name_or_path"] + tmp_dir = raw_keys.get("tmp_dir", "/tmp") if "gs://" in raw_keys["tokenizer_model_name_or_path"]: - raw_keys["pretrained_model_name_or_path"] = max_utils.download_blobs(raw_keys["pretrained_model_name_or_path"], "/tmp") + raw_keys["pretrained_model_name_or_path"] = max_utils.download_blobs(raw_keys["pretrained_model_name_or_path"], tmp_dir) if "gs://" in raw_keys["pretrained_model_name_or_path"]: - raw_keys["pretrained_model_name_or_path"] = max_utils.download_blobs(raw_keys["pretrained_model_name_or_path"], "/tmp") + raw_keys["pretrained_model_name_or_path"] = max_utils.download_blobs(raw_keys["pretrained_model_name_or_path"], tmp_dir) if "gs://" in raw_keys["unet_checkpoint"]: - raw_keys["unet_checkpoint"] = max_utils.download_blobs(raw_keys["unet_checkpoint"], "/tmp") + raw_keys["unet_checkpoint"] = max_utils.download_blobs(raw_keys["unet_checkpoint"], tmp_dir) if "gs://" in raw_keys["tokenizer_model_name_or_path"]: - raw_keys["tokenizer_model_name_or_path"] = max_utils.download_blobs(raw_keys["tokenizer_model_name_or_path"], "/tmp") + raw_keys["tokenizer_model_name_or_path"] = max_utils.download_blobs(raw_keys["tokenizer_model_name_or_path"], tmp_dir) if "gs://" in raw_keys["dataset_name"]: raw_keys["dataset_name"] = max_utils.download_blobs(raw_keys["dataset_name"], raw_keys["dataset_save_location"]) raw_keys["dataset_save_location"] = raw_keys["dataset_name"]