
Speedup sample #7578

Draft · wants to merge 11 commits into main from speedup_nuts
Conversation

@ricardoV94 ricardoV94 (Member) commented Nov 19, 2024

TODO

  • back_compat for grad_out in LogpDlogpFunc
  • faster model.logp() when everything is measurable (just need to make Censored have its own logp)?

Major changes

  1. Internal uses of logp_dlogp_function now work with raveled inputs. External use will issue a warning unless ravel_inputs is specified explicitly; eventually only ravel_inputs=True will be possible.
  2. Step sampler arguments besides vars must now be passed by keyword.
  3. RaveledVars point_map_info entries are now 4-tuples, with the variable size introduced.
  4. assign_step_method no longer calls instantiate_steppers; instead, it returns the arguments needed for the latter.
  5. compile_kwargs can now be passed to pm.sample, and is forwarded to the step samplers' compiled functions. (A usage sketch follows this list.)
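
A rough usage sketch of points 1, 2, and 5 above (illustrative only; exact names and signatures may differ from the merged code):

import pymc as pm

with pm.Model() as model:
    x = pm.Normal("x")

# (1) Opting in explicitly silences the warning; eventually
# ravel_inputs=True will be the only supported behavior
logp_dlogp_fn = model.logp_dlogp_function(ravel_inputs=True)

# (2) step sampler arguments besides vars are now keyword-only
step = pm.NUTS([x], target_accept=0.9)

# (5) compile_kwargs is forwarded to the step samplers' compiled functions
with model:
    idata = pm.sample(chains=1, tune=100, draws=100, compile_kwargs=dict(mode="NUMBA"))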

Enhancement

This PR speeds up NUTS (and other step samplers) by:

  1. Avoiding many variable unravels and copies, by doing them inside PyTensor
  2. Avoiding copies when setting shared variables (borrow=True)
  3. Setting trust_input=True, since input validation can have a large overhead
  4. Disabling garbage collection for the C-backend function (related to #7539, "Consider disabling PyTensor GC in sampling functions")
  5. Using __slots__ for faster attribute access in the Tree class (and a smaller memory footprint)
  6. Inlining some functions and being lazier where possible (a PyTensor-level sketch of items 2–5 follows this list)
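
A minimal sketch of items 2–5, assuming PyTensor's In/Out borrow flags, the trust_input attribute, and the VM's allow_gc attribute (not the actual PyMC code):

import numpy as np
import pytensor
import pytensor.tensor as pt

x = pt.dvector("x")
cost = (x**2).sum()
f = pytensor.function(
    [pytensor.In(x, borrow=True)],    # (2) don't copy the input buffer
    pytensor.Out(cost, borrow=True),  # ...nor the output
)
f.trust_input = True                  # (3) skip per-call input validation
if hasattr(f.vm, "allow_gc"):         # (4) keep intermediate storage alive between calls
    f.vm.allow_gc = False
print(f(np.zeros(3)))

# (5) __slots__ on the Tree class: faster attribute access, smaller footprint
class Tree:
    __slots__ = ("depth", "log_size", "is_turning")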

This PR speeds up pm.sample by:

  1. Avoiding many redundant PyTensor function compilations (model.initial_point(), and the trace.fn that was needlessly recompiled after slicing at the end). It is also wasteful to compile the same function for every trace; we should just copy it (see the sketch after this list).
  2. Avoiding the initialization of a NUTS step sampler that is, most of the time, immediately discarded in favor of the one created inside init_nuts. This also shortens the path towards external samplers with nutpie/numpyro, as it avoids the costly and useless compilation of the logp_dlogp_function.
  3. Using trust_input and avoiding deepcopies in the trace function via pytensor.In(borrow=True) and pytensor.Out(borrow=True).
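
The "copy instead of recompile" idea from point 1, sketched with PyTensor's Function.copy (a sketch under assumed API, not the PR's code):

import pytensor
import pytensor.tensor as pt

x = pt.dscalar("x")
f = pytensor.function([x], x + 1.0)  # compile once...

# ...then cheaply clone per trace instead of recompiling
f_copy = f.copy(share_memory=True)
assert f(1.0) == f_copy(1.0) == 2.0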

Further speedups should come for free from #7539, especially for the Numba backend.

Benchmark

In the example below, sampling time is now only 7x slower than nutpie (5s vs 0.7s), compared to 13.5x slower (9.45s vs 0.7s) before. This assumes the same number of logp evals; in fact, nutpie's tuning lets it get away with half the evals! We can hopefully bring that over as well.

The full time from calling pm.sample to getting a trace is roughly halved as well (7.5s vs 14.4s), although this gain is not proportional to the number of draws.

With compile_kwargs=dict(mode="NUMBA"), sampling time is only 3x slower (2.3s).

import time
import pymc as pm
import numpy as np
import nutpie
import pandas as pd

# Load the radon dataset
data = pd.read_csv(pm.get_data("radon.csv"))
data["log_radon"] = data["log_radon"].astype(np.float64)
county_idx, counties = pd.factorize(data.county)
coords = {"county": counties, "obs_id": np.arange(len(county_idx))}

# Create a simple hierarchical model for the radon dataset
with pm.Model(coords=coords, check_bounds=False) as model:
    intercept = pm.Normal("intercept", sigma=10)

    # County effects
    raw = pm.ZeroSumNormal("county_raw", dims="county")
    sd = pm.HalfNormal("county_sd")
    county_effect = pm.Deterministic("county_effect", raw * sd, dims="county")

    # Global floor effect
    floor_effect = pm.Normal("floor_effect", sigma=2)

    # County:floor interaction
    raw = pm.ZeroSumNormal("county_floor_raw", dims="county")
    sd = pm.HalfNormal("county_floor_sd")
    county_floor_effect = pm.Deterministic(
        "county_floor_effect", raw * sd, dims="county"
    )

    mu = (
        intercept
        + county_effect[county_idx]
        + floor_effect * data.floor.values
        + county_floor_effect[county_idx] * data.floor.values
    )

    sigma = pm.HalfNormal("sigma", sigma=1.5)
    pm.Normal(
        "log_radon", mu=mu, sigma=sigma, observed=data.log_radon.values, dims="obs_id"
    )

from pymc.model.transform.optimization import freeze_dims_and_data
model = freeze_dims_and_data(model)
compiled_model = nutpie.compile_pymc_model(model)

start = time.perf_counter()
# More draws to make up for the fact that nutpie tunes better
trace_nutpie = nutpie.sample(compiled_model, chains=1, tune=500, draws=1500, progress_bar=False)
end = time.perf_counter()
print("nutpie wall time:", end - start)

trace_pymc = pm.sample(
    model=model,
    chains=1,
    tune=500,
    draws=500,
    progressbar=False,
    compute_convergence_checks=False,
    return_inferencedata=False,
    # compile_kwargs=dict(mode="NUMBA"),
)
print("pymc sampling time:", trace_pymc._report.t_sampling)

📚 Documentation preview 📚: https://pymc--7578.org.readthedocs.build/en/7578/

@ricardoV94 ricardoV94 changed the title from "WIP Speedup NUTS" to "Speedup sample" on Nov 22, 2024
@ricardoV94 ricardoV94 force-pushed the speedup_nuts branch 4 times, most recently from d6f9e14 to 87fd299 on Nov 23, 2024
@ricardoV94 ricardoV94 added the major (Include in major changes release notes section), enhancements, and samplers labels on Nov 23, 2024
@ricardoV94 ricardoV94 force-pushed the speedup_nuts branch 2 times, most recently from 0f4972c to 874ae65 on Nov 24, 2024