Skip to content

Conversation

@cetagostini
Copy link

Introduces MLX as a backend option in compile_pymc_model, allowing gradient computation via MLX or Pytensor. Updates dependency groups to include MLX, extends internal functions to handle MLX mode, and adds corresponding tests for deterministic sampling with MLX.

Introduces MLX as a backend option in compile_pymc_model, allowing gradient computation via MLX or Pytensor. Updates dependency groups to include MLX, extends internal functions to handle MLX mode, and adds corresponding tests for deterministic sampling with MLX.
@aseyboldt
Copy link
Member

Thanks, that looks great!
I think we probably should call mlx.compile on the final functions though?

Bump MLX version requirement to >=0.29.0 in pyproject.toml for all relevant extras. In compile_pymc.py, JIT compile the logp function using mx.compile for improved performance, aligning with JAX backend behavior.
@cetagostini
Copy link
Author

cetagostini commented Oct 27, 2025

Thanks, that looks great! I think we probably should call mlx.compile on the final functions though?

Good point, that simple addition brings between 5% to 20% more performance! @aseyboldt

@cetagostini
Copy link
Author

@aseyboldt solve the test issue to work only on macs with intel chips.

@cetagostini
Copy link
Author

@aseyboldt can you give me a hand? The test failing its strange. My local pass everythig.

@cetagostini cetagostini requested review from aseyboldt and jessegrabowski and removed request for aseyboldt October 30, 2025 12:46
Copy link
Member

@aseyboldt aseyboldt left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That failure is annoying. For some reason the results seem to differ between different machines? I think we really should figure out what's going on here. Maybe it helps if we print the first couple of values in warmup_posterior to see if the initial values are already different, or if small differences accumulate?

updated.update(**updates)

# Convert to MLX arrays if using MLX backend (indicated by force_single_core)
if self._force_single_core:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We should not use that argument to detect mlx.
How about we add an attribute _convert_data_item or so to the dataclass, that contains a function that transforms data arrays? We could then also use that for jax.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants