sampling | Modular
Python module
Token sampling algorithms for text generation pipelines.
greedy_acceptance_sampler()
max.pipelines.lib.sampling.sampling.greedy_acceptance_sampler(device)
Builds a graph that implements strict greedy acceptance for multi-token prediction (MTP).
Draft tokens are accepted only when they match the argmax of the target logits at each position. The sampler always produces a recovered token for every draft position and a bonus token from the final (+1) target position.
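The acceptance rule above can be illustrated with a minimal pure-Python sketch for a single sequence. This is not the MAX graph itself: the function name `greedy_accept`, the list-based inputs, and the returned tuple layout are hypothetical, and stopping at the first mismatch (accepting only the matching prefix) is an assumption about how acceptance is counted.

```python
def greedy_accept(draft_tokens, target_logits):
    """Sketch of strict greedy acceptance for one sequence (hypothetical helper).

    draft_tokens: list of K draft token ids.
    target_logits: K + 1 rows of logits, one per draft position plus the
        final (+1) position.
    """
    def argmax(row):
        return max(range(len(row)), key=row.__getitem__)

    target_tokens = [argmax(row) for row in target_logits]
    recovered = target_tokens[:-1]  # one recovered token per draft position
    bonus = target_tokens[-1]       # token from the final (+1) position

    # Assumption: count only the leading draft tokens that match the argmax.
    num_accepted = 0
    for draft, target in zip(draft_tokens, recovered):
        if draft != target:
            break
        num_accepted += 1
    return num_accepted, recovered, bonus


logits = [
    [0.1, 0.2, 0.9],  # argmax 2
    [0.8, 0.1, 0.1],  # argmax 0
    [0.3, 0.1, 0.6],  # argmax 2
    [0.0, 1.0, 0.0],  # argmax 1 (bonus position)
]
print(greedy_accept([2, 0, 1], logits))  # (2, [2, 0, 2], 1)
```

In this example the first two draft tokens match the target argmax, the third does not, and the recovered token at the mismatch position (2) would replace it.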
rejection_sampler()
max.pipelines.lib.sampling.sampling.rejection_sampler(device, *, seed=0)
Builds a graph that implements speculative decoding rejection sampling.
Accepts or rejects each draft token by comparing its target and draft probabilities, and resamples from the target distribution when a token is rejected.
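As a rough illustration of this accept-or-resample loop, the sketch below works on plain Python lists for a single sequence. The name `rejection_sample`, its arguments, and its return shape are hypothetical, and the acceptance probability min(1, p_target / p_draft) is the usual speculative decoding rule, assumed here rather than taken from the MAX source.

```python
import random


def rejection_sample(draft_tokens, draft_probs, target_probs, rng):
    """Sketch of speculative-decoding rejection sampling (hypothetical helper).

    draft_probs and target_probs hold one probability row per draft position.
    Returns (accepted_tokens, resampled_token), where resampled_token is None
    if every draft token was accepted.
    """
    accepted = []
    for token, q, p in zip(draft_tokens, draft_probs, target_probs):
        # Assumed acceptance rule: accept with probability min(1, p / q).
        ratio = min(1.0, p[token] / q[token]) if q[token] > 0 else 1.0
        if rng.random() < ratio:
            accepted.append(token)
            continue
        # Rejected: resample this position from the target distribution.
        resampled = rng.choices(range(len(p)), weights=p, k=1)[0]
        return accepted, resampled
    return accepted, None


rng = random.Random(0)
same = [[0.5, 0.5], [0.5, 0.5]]
# Identical draft and target distributions: every draft token is accepted.
print(rejection_sample([0, 1], same, same, rng))  # ([0, 1], None)
```

Passing an explicit `random.Random` instance mirrors the role of the `seed` argument in the signature above, although how the real graph consumes the seed is not shown here.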
rejection_sampler_with_residuals()
max.pipelines.lib.sampling.sampling.rejection_sampler_with_residuals(device, *, seed=0, debug=False)
Builds a rejection sampler with residual sampling for speculative decoding.
Computes acceptance ratios for the draft tokens, finds the first rejection, samples from the residual distribution (target - draft), and generates bonus tokens.
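The steps above can be sketched in plain Python for a single sequence. The helper names, argument layout, and return values are hypothetical. The sketch also assumes that the residual is clamped at zero and renormalized, and that the bonus token is drawn from the target distribution at the final position only when every draft token is accepted; the source does not spell out either detail.

```python
import random


def residual_distribution(target_probs, draft_probs):
    """Normalize max(0, target - draft); fall back to target if it is all zero."""
    residual = [max(0.0, p - q) for p, q in zip(target_probs, draft_probs)]
    total = sum(residual)
    if total == 0.0:
        return list(target_probs)
    return [r / total for r in residual]


def sample_with_residuals(draft_tokens, draft_probs, target_probs, bonus_probs, rng):
    """Return (index of first rejection, token sampled at that point).

    If no draft token is rejected, the index equals len(draft_tokens) and the
    token is a bonus token drawn from bonus_probs.
    """
    for i, token in enumerate(draft_tokens):
        p, q = target_probs[i], draft_probs[i]
        ratio = min(1.0, p[token] / q[token]) if q[token] > 0 else 1.0
        if rng.random() >= ratio:
            # First rejection: sample the replacement from the residual.
            residual = residual_distribution(p, q)
            recovered = rng.choices(range(len(residual)), weights=residual, k=1)[0]
            return i, recovered
    bonus = rng.choices(range(len(bonus_probs)), weights=bonus_probs, k=1)[0]
    return len(draft_tokens), bonus


print(residual_distribution([0.5, 0.5], [0.8, 0.2]))  # [0.0, 1.0]
```

Sampling from the residual rather than the raw target distribution is what keeps the combined accept-or-resample procedure distributed like the target model in the standard speculative decoding analysis.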
token_sampler()
max.pipelines.lib.sampling.sampling.token_sampler(sampling_config, device, return_logits=False)
Builds a sampling graph that samples tokens from logits.
Returns:
A graph that takes logits (and optional penalty inputs) and outputs tokens.
typical_acceptance_sampler()
max.pipelines.lib.sampling.sampling.typical_acceptance_sampler(device, *, seed=0)
Builds a target-only stochastic rejection sampler for speculative decoding.
Accepts a draft token when coin < p_target(draft_token), where coin is a random draw and p_target is the target probability computed after applying temperature, top-k, and top-p filtering. No draft probabilities are needed.
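A minimal sketch of this target-only test, using only the standard library, is shown below. The names `filtered_probs` and `typical_accept` are hypothetical. The order of operations (temperature, then top-k, then top-p on the unrenormalized probabilities) and the uniform draw on [0, 1) for the coin are assumptions, not details confirmed by the source.

```python
import math
import random


def filtered_probs(logits, temperature=1.0, top_k=0, top_p=1.0):
    """Softmax with temperature, then top-k and top-p filtering (assumed order)."""
    scaled = [x / temperature for x in logits]
    peak = max(scaled)
    exps = [math.exp(x - peak) for x in scaled]  # subtract max for stability
    total = sum(exps)
    probs = [e / total for e in exps]

    # Rank tokens by probability; top_k == 0 means no top-k limit.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]

    # Keep the smallest prefix whose cumulative probability reaches top_p.
    kept, cumulative = [], 0.0
    for i in order:
        kept.append(i)
        cumulative += probs[i]
        if cumulative >= top_p:
            break

    kept_total = sum(probs[i] for i in kept)
    out = [0.0] * len(probs)
    for i in kept:
        out[i] = probs[i] / kept_total  # renormalize over surviving tokens
    return out


def typical_accept(draft_token, target_logits, rng, **filter_kwargs):
    """Accept when a uniform coin falls below the filtered target probability."""
    p_target = filtered_probs(target_logits, **filter_kwargs)
    coin = rng.random()
    return coin < p_target[draft_token]


rng = random.Random(0)
print(typical_accept(2, [0.0, 0.0, 10.0], rng, top_k=1))  # True
print(typical_accept(0, [0.0, 0.0, 10.0], rng, top_k=1))  # False
```

With `top_k=1` only the highest-probability token survives filtering, so its probability is 1 and it is always accepted, while any other draft token has probability 0 and is always rejected.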