sampling | Modular

Python module

Token sampling algorithms for text generation pipelines.

greedy_acceptance_sampler()

max.pipelines.lib.sampling.sampling.greedy_acceptance_sampler(device)

Builds a graph that implements strict greedy acceptance for multi-token prediction (MTP) speculative decoding.

A draft token is accepted only when it matches the argmax of the target logits at its position. The graph always produces a recovered token for every draft position, plus a bonus token from the final (+1) target position.

Parameters:

device (DeviceRef) – Device for the graph.

Returns:

A graph that takes draft tokens, target logits, and target logit offsets, and outputs the first rejected index, the target tokens for all draft positions, and a bonus token.

Return type:

Graph
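The acceptance rule above can be sketched in plain Python for a single sequence. This is an illustrative sketch, not the MAX graph: the names greedy_accept and argmax are hypothetical, batching and the target logit offsets are omitted, and reporting first_rejected == len(draft_tokens) when every draft token matches is an assumed convention.

```python
from typing import List, Sequence, Tuple


def argmax(row: Sequence[float]) -> int:
    """Index of the largest logit (the first one wins on ties)."""
    return max(range(len(row)), key=lambda i: row[i])


def greedy_accept(
    draft_tokens: Sequence[int],
    target_logits: Sequence[Sequence[float]],
) -> Tuple[int, List[int], int]:
    """Hypothetical single-sequence version of strict greedy acceptance.

    target_logits has len(draft_tokens) + 1 rows: one per draft position
    plus the final (+1) position used for the bonus token.
    """
    assert len(target_logits) == len(draft_tokens) + 1
    # Target (recovered) token for every draft position.
    target_tokens = [argmax(row) for row in target_logits[:-1]]
    # First position where the draft disagrees with the target argmax;
    # equals len(draft_tokens) if every draft token is accepted (assumption).
    first_rejected = next(
        (i for i, (d, t) in enumerate(zip(draft_tokens, target_tokens)) if d != t),
        len(draft_tokens),
    )
    # Bonus token from the final (+1) target position.
    bonus_token = argmax(target_logits[-1])
    return first_rejected, target_tokens, bonus_token
```

For example, greedy_accept([3, 1, 2], [[0, 0, 0, 5], [0, 9, 0, 0], [7, 0, 0, 0], [0, 0, 4, 0]]) returns (2, [3, 1, 0], 2): the third draft token disagrees with the target argmax 0, and the bonus token is 2.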

rejection_sampler()

max.pipelines.lib.sampling.sampling.rejection_sampler(device, *, seed=0)

Builds a graph that implements speculative decoding rejection sampling.

Accepts or rejects each draft token by comparing its target and draft probabilities, and resamples from the target distribution when a token is rejected.

Parameters:

  • device (DeviceRef) – Device for the graph.
  • seed (int) – Random seed for sampling.

Returns:

A graph that takes draft tokens, draft logits, and target logits and outputs accepted tokens and metadata.

Return type:

Graph
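A minimal per-sequence sketch of the accept/resample loop follows. It assumes the common speculative decoding acceptance test u < min(1, p_target(x) / p_draft(x)) and, following the description above, resamples from the full target distribution on rejection. The function names are hypothetical, the metadata outputs of the real graph are omitted, and seeding Python's random.Random only stands in for however the graph consumes seed.

```python
import math
import random
from typing import List, Sequence


def softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable softmax over one row of logits."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def rejection_sample(
    draft_tokens: Sequence[int],
    draft_logits: Sequence[Sequence[float]],
    target_logits: Sequence[Sequence[float]],
    seed: int = 0,
) -> List[int]:
    """Hypothetical sketch: returns the accepted tokens plus one resampled token."""
    rng = random.Random(seed)
    out: List[int] = []
    for tok, d_row, t_row in zip(draft_tokens, draft_logits, target_logits):
        q = softmax(d_row)[tok]  # draft probability of the proposed token
        p_target = softmax(t_row)
        p = p_target[tok]  # target probability of the same token
        if rng.random() < min(1.0, p / q):
            out.append(tok)  # accepted: keep the draft token
        else:
            # Rejected: resample from the target distribution and stop.
            out.append(rng.choices(range(len(p_target)), weights=p_target, k=1)[0])
            break
    return out
```

When the draft and target logits are identical, the ratio is 1 and every draft token is accepted; a token the target assigns zero probability is always rejected.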

rejection_sampler_with_residuals()

max.pipelines.lib.sampling.sampling.rejection_sampler_with_residuals(device, *, seed=0, debug=False)

Builds a rejection sampler with residual sampling for speculative decoding.

Computes the acceptance ratio for each draft token, finds the first rejected position, samples a replacement from the residual distribution (target − draft), and generates bonus tokens.

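Parameters:

  • device (DeviceRef) – Device for the graph.
  • seed (int) – Random seed for sampling.
  • debug (bool) – Debug flag; defaults to False.
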
Return type:

Graph
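The residual step can be sketched as follows, assuming the inputs are already probabilities and that the residual is the normalized max(0, p_target − p_draft) used in standard speculative sampling. The helper names, the (first_rejected_index, next_token) return shape, and the fallback to p_target when the residual has no mass are all assumptions of this sketch; the debug flag is not modeled.

```python
import random
from typing import List, Sequence, Tuple


def residual_distribution(
    p_target: Sequence[float], p_draft: Sequence[float]
) -> List[float]:
    """Normalized max(0, p_target - p_draft); falls back to p_target (assumption)."""
    diff = [max(0.0, p - q) for p, q in zip(p_target, p_draft)]
    total = sum(diff)
    if total == 0.0:
        return list(p_target)
    return [d / total for d in diff]


def speculative_step(
    draft_tokens: Sequence[int],
    p_draft: Sequence[Sequence[float]],   # one row per draft position
    p_target: Sequence[Sequence[float]],  # one row per draft position, plus one
    seed: int = 0,
) -> Tuple[int, int]:
    """Hypothetical sketch returning (first_rejected_index, next_token)."""
    rng = random.Random(seed)
    k = len(draft_tokens)
    for i, tok in enumerate(draft_tokens):
        # The draft token was sampled from p_draft, so p_draft[i][tok] > 0.
        ratio = min(1.0, p_target[i][tok] / p_draft[i][tok])
        if rng.random() >= ratio:
            # First rejection: sample a replacement from the residual.
            residual = residual_distribution(p_target[i], p_draft[i])
            token = rng.choices(range(len(residual)), weights=residual, k=1)[0]
            return i, token
    # Every draft token accepted: sample a bonus token from the final row.
    bonus = rng.choices(range(len(p_target[k])), weights=p_target[k], k=1)[0]
    return k, bonus
```

Sampling from the residual rather than from p_target is what keeps the overall output distributed as the target model in the standard analysis of speculative sampling.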

token_sampler()

max.pipelines.lib.sampling.sampling.token_sampler(sampling_config, device, return_logits=False)

Builds a sampling graph that samples tokens from logits.

Parameters:

  • sampling_config (SamplingConfig) – Sampling configuration (top-k, temperature, etc.).
  • device (DeviceRef) – Device for the graph inputs and ops.
  • return_logits (bool) – Whether the graph should expose logits as an output.

Returns:

A graph that takes logits (and optional penalty inputs) and outputs tokens.

Return type:

Graph
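To make the temperature and top-k settings concrete, here is a single-row sketch of the kind of computation such a sampling step performs. It does not use SamplingConfig or DeviceRef: the function sample_token and its temperature and top_k parameters are hypothetical stand-ins, treating temperature <= 0 as greedy decoding is an assumption, and penalty inputs and return_logits are not modeled.

```python
import math
import random
from typing import Optional, Sequence


def sample_token(
    logits: Sequence[float],
    temperature: float = 1.0,
    top_k: Optional[int] = None,
    rng: Optional[random.Random] = None,
) -> int:
    """Hypothetical temperature + top-k sampling for one row of logits."""
    # Assumption: non-positive temperature or top_k == 1 means greedy argmax.
    if temperature <= 0.0 or top_k == 1:
        return max(range(len(logits)), key=lambda i: logits[i])
    rng = rng or random.Random(0)
    # Keep only the top_k highest logits (all of them if top_k is None or 0).
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    kept = order[:top_k] if top_k else order
    # Temperature-scaled softmax weights over the kept candidates.
    scaled = [logits[i] / temperature for i in kept]
    m = max(scaled)
    weights = [math.exp(s - m) for s in scaled]
    return rng.choices(kept, weights=weights, k=1)[0]
```

Lower temperatures sharpen the distribution toward the argmax, while top_k bounds the candidate set before sampling.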

typical_acceptance_sampler()

max.pipelines.lib.sampling.sampling.typical_acceptance_sampler(device, *, seed=0)

Builds a target-only stochastic rejection sampler for speculative decoding.

Accepts a draft token when a uniform random draw coin satisfies coin < p_target(draft_token), where p_target is the target distribution after applying temperature, top-k, and top-p filtering. No draft probabilities are needed.

Parameters:

  • device (DeviceRef) – Device for the graph.
  • seed (int) – Random seed for sampling.

Returns:

A graph that takes draft tokens, target logits, target logit offsets, and sampling parameters, and outputs the first rejected index, recovered tokens, and a bonus token.

Return type:

Graph
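A plain-Python sketch of this target-only rule for one sequence follows. The filtering order (top-k, then top-p on the un-renormalized softmax), the helper names filtered_probs and typical_accept, and drawing the recovered tokens from the filtered target distribution are assumptions of the sketch; batching and target logit offsets are omitted.

```python
import math
import random
from typing import List, Sequence, Tuple


def filtered_probs(
    logits: Sequence[float],
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
) -> List[float]:
    """softmax(logits / temperature), restricted by top-k then top-p, renormalized."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    probs = [e / total for e in exps]
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    if top_k > 0:
        order = order[:top_k]  # keep the top_k most likely tokens
    kept: List[int] = []
    cum = 0.0
    for i in order:  # smallest prefix whose mass reaches top_p
        kept.append(i)
        cum += probs[i]
        if cum >= top_p:
            break
    mass = sum(probs[i] for i in kept)
    out = [0.0] * len(probs)
    for i in kept:
        out[i] = probs[i] / mass
    return out


def typical_accept(
    draft_tokens: Sequence[int],
    target_logits: Sequence[Sequence[float]],  # len(draft_tokens) + 1 rows
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    seed: int = 0,
) -> Tuple[int, List[int], int]:
    """Hypothetical sketch returning (first_rejected, recovered, bonus)."""
    rng = random.Random(seed)
    k = len(draft_tokens)
    rows = [filtered_probs(r, temperature, top_k, top_p) for r in target_logits]

    def draw(p: List[float]) -> int:
        return rng.choices(range(len(p)), weights=p, k=1)[0]

    # A recovered token for every draft position, drawn from p_target.
    recovered = [draw(p) for p in rows[:k]]
    first_rejected = k  # k means every draft token was accepted
    for i, tok in enumerate(draft_tokens):
        coin = rng.random()
        if not coin < rows[i][tok]:  # accept only while coin < p_target(draft_token)
            first_rejected = i
            break
    bonus = draw(rows[k])  # bonus token from the final (+1) position
    return first_rejected, recovered, bonus
```

With top_k=1 the filtered target is one-hot on the argmax, so this rule reduces to strict greedy acceptance.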