Add MLX backend support for Nutpie compilation #254
Conversation
Introduces MLX as a backend option in compile_pymc_model, allowing gradient computation via MLX or PyTensor. Updates dependency groups to include MLX, extends internal functions to handle MLX mode, and adds corresponding tests for deterministic sampling with MLX.
Thanks, that looks great!
Bump MLX version requirement to >=0.29.0 in pyproject.toml for all relevant extras. In compile_pymc.py, JIT compile the logp function using mx.compile for improved performance, aligning with JAX backend behavior.
Good point, that simple addition brings a 5% to 20% performance improvement! @aseyboldt
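The JIT-compilation change mentioned above can be sketched as follows. This is a minimal illustration of wrapping a logp function with `mx.compile`, mirroring the JAX backend's jit behavior; the helper name `make_logp` and the import-guarded fallback are hypothetical, not nutpie's actual code:

```python
# Sketch: wrap the logp function with mx.compile when MLX is available,
# falling back to the uncompiled function otherwise. `make_logp` is a
# hypothetical helper name for illustration only.
try:
    import mlx.core as mx

    def make_logp(logp_fn):
        # mx.compile traces the function and caches the compiled graph,
        # which is where the reported 5-20% speedup comes from.
        return mx.compile(logp_fn)
except ImportError:
    def make_logp(logp_fn):
        # No MLX available: return the function unchanged.
        return logp_fn
```

The wrapped function is called exactly like the original, so the rest of the sampling code does not need to know whether compilation happened.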
@aseyboldt solved the test issue so it works only on Macs with Intel chips.
@aseyboldt can you give me a hand? The failing test is strange; everything passes locally.
That failure is annoying. For some reason the results seem to differ between different machines? I think we really should figure out what's going on here. Maybe it helps if we print the first couple of values in warmup_posterior to see if the initial values are already different, or if small differences accumulate?
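The debugging step suggested above can be made concrete with a small helper. This is a hypothetical sketch (the helper name and the treatment of `warmup_posterior` as an array-like are assumptions, not part of the PR):

```python
import numpy as np

def first_draws(warmup_posterior, n=5):
    # Flatten the warmup draws and return the first n values, so traces
    # from two machines can be compared side by side to see whether the
    # initial values already differ or small differences accumulate.
    arr = np.asarray(warmup_posterior, dtype=float)
    return arr.reshape(-1)[:n]

print(first_draws([[0.1, 0.2], [0.3, 0.4]], n=3))
```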
```python
updated.update(**updates)
```

```python
# Convert to MLX arrays if using MLX backend (indicated by force_single_core)
if self._force_single_core:
```
We should not use that argument to detect MLX.
How about we add an attribute `_convert_data_item` (or similar) to the dataclass, containing a function that transforms data arrays? We could then also use that for JAX.
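The suggested design could look roughly like this. All names here (`CompiledModelData`, `with_data`) are hypothetical sketches, not nutpie's actual API; the point is that each backend supplies its own converter instead of the code inferring the backend from `_force_single_core`:

```python
from dataclasses import dataclass, field
from typing import Any, Callable

import numpy as np

@dataclass
class CompiledModelData:
    # Backend-specific converter for data arrays; NumPy identity by default.
    # An MLX backend would store e.g. `mx.array`, a JAX backend `jnp.asarray`.
    _convert_data_item: Callable[[Any], Any] = field(default=np.asarray)

    def with_data(self, **updates):
        # Convert every updated data array with the stored converter,
        # with no backend-detection branching at the call site.
        return {name: self._convert_data_item(value)
                for name, value in updates.items()}
```

With this shape, the `if self._force_single_core:` branch shown above collapses to an unconditional call of the stored converter.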