agents

import "github.com/umbralcalc/stochadex/pkg/agents"

Package agents provides decision-making agents that operate over a generic Environment[S, A] interface. The package is intended to host any agent built on the same environment framework. Currently it ships MCTS (UCT) as the only agent, with MAST as an optional rollout strategy on top.

Per-player terminal scores are []float64 values in [0,1] (the established stochadex value convention). Codecs (an encoder/decoder pair mapping S to and from the stochadex row’s []float64) are supplied by the caller as function fields on each partition, so this package does not depend on any encoding protocol.

Key Features:

Usage Patterns:

Constants

Default hyperparameters used when MCTSConfig fields are zero.

const (
    MCTSDefaultSimulations     = 120
    MCTSDefaultRolloutMaxSteps = 220
    MCTSDefaultMaxTreeDepth    = 14
    MCTSDefaultExploration     = 1.41
)

Param key for direct (within-step) best-action input. The value should be a 1-element slice — the legal-action index. A negative value (e.g. the -1 sentinel) leaves the state unchanged.

const ApplyParamBestIdx = "best_legal_idx"

Param key for state-history (lag-1) best-action input. The value (set via ParamsAsPartitions) is a 1-element slice containing the partition index of the upstream partition whose row[BestIdxSlot] holds the best-action index.

const ApplyParamBestIdxPartition = "best_idx_partition"

MASTAggregationParamPartition is the params_as_partitions key used by downstream samplers to learn this partition’s index for state-history reads. The value is a 1-element slice containing the partition index.

const MASTAggregationParamPartition = "mast_aggregates_partition"

MASTAggregationParamUpdates is the params_from_upstream key used to read the variable-length update batch from an upstream rollout partition.

const MASTAggregationParamUpdates = "mast_updates"

MASTDefaultTau is the softmax temperature applied to MAST means when sampling. Smaller values are more exploitative; larger values move the sampling distribution closer to uniform. In practice, at 1.0 the distribution diverges from uniform within ~50 updates per key.

const MASTDefaultTau = 1.0

MASTSamplePrior is the score assigned to actions whose key has not yet been observed, so they retain selection probability early in search. 0.5 sits at the midpoint of the [0, 1] reward range — neither encouraged nor discouraged.

const MASTSamplePrior = 0.5
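
To illustrate how the temperature and the prior interact, the following is a minimal sketch of softmax weighting over per-key means. The helper name softmaxWeights and the exact weighting exp(mean/tau) are assumptions made for illustration; the package's own sampler may differ in detail.

```go
package main

import (
	"fmt"
	"math"
)

// softmaxWeights is a hypothetical helper (not part of the package). It turns
// per-key mean rewards into sampling probabilities using exp(mean/tau).
// Keys with a zero count have no observed mean, so the prior is used instead.
func softmaxWeights(means []float64, counts []int, tau, prior float64) []float64 {
	weights := make([]float64, len(means))
	total := 0.0
	for i, m := range means {
		if counts[i] == 0 {
			m = prior // unobserved key: neutral midpoint of the reward range
		}
		weights[i] = math.Exp(m / tau)
		total += weights[i]
	}
	for i := range weights {
		weights[i] /= total // normalise so the weights sum to 1
	}
	return weights
}

func main() {
	means := []float64{0.9, 0.1, 0.0}
	counts := []int{12, 7, 0} // the third key has never been observed

	// A larger tau flattens the distribution; a smaller tau sharpens it.
	fmt.Println(softmaxWeights(means, counts, 1.0, 0.5))
	fmt.Println(softmaxWeights(means, counts, 0.1, 0.5))
}
```

With a prior of 0.5, the unobserved key ranks between the well-performing and the poorly-performing key, which is the "neither encouraged nor discouraged" behaviour described above.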

Param key used by MCTSRolloutIteration.Iterate to read the leaf state from an upstream MCTSTreeIteration via params_from_upstream. The slice should be of length StateWidth followed by a single has_leaf flag (matching MCTSTreeIteration’s row layout: leaf_state then has_leaf). Use Indices on the NamedUpstreamConfig to slice out the leaf_state + has_leaf section of the tree’s row.

const MCTSRolloutParamLeaf = "leaf"

Param key used by MCTSTreeIteration.Iterate to read rollout scores from an upstream MCTSRolloutIteration via params_from_upstream (within-step). The slice should be of length Players + 1 (P scores + ok flag) — the layout produced by MCTSRolloutIteration.

This mode creates a within-step dependency that breaks if the rollout partition also depends on the tree (the standard MCTS pipeline does). Use MCTSTreeParamRolloutScoresPartition for the lag-1 state-history mode instead when wiring the tree + rollout pipeline.

const MCTSTreeParamRolloutScores = "rollout_scores"

Param key used by MCTSTreeIteration.Iterate to read rollout scores from an upstream MCTSRolloutIteration via params_as_partitions (state-history, lag-1). The value is a 1-element slice containing the rollout partition’s index; the tree reads stateHistories[idx].Values row 0 (= the previous step’s rollout output) at runtime.

This is the standard wiring used by NewMCTSSelfPlayPartitions: rollout reads tree’s leaf within-step (so the rollout sees the freshest leaf), and tree reads rollout’s scores lag-1 (so rollout doesn’t have to wait on tree for last step’s scores). The 1-step lag aligns correctly: at step N+1 tree backs up the path it selected at step N with scores from rollout at step N (which were for that very leaf).

State-history mode takes priority over within-step mode if both keys are present.

const MCTSTreeParamRolloutScoresPartition = "rollout_scores_partition"

Param key used by MCTSTreeIteration.Iterate to read the current search root state. Whenever this param’s value differs from the cached root encoding, the tree is reset to the decoded new root. The slice should be of length StateWidth.

Set this via the embedded simulation run’s outer params_from_upstream (e.g. an outer apply partition piping its current game state into the inner sim via the “<innerName>/root_state” forwarding mechanism). When MCTSTreeIteration is used standalone with no outer pipeline, the param is absent and the tree retains the root set at Configure time.

const MCTSTreeParamRootState = "root_state"

Row layout slot accessors. Use these to compute params_from_upstream indices when wiring downstream partitions to MCTSTreeIteration’s row.

const MCTSTreeRowBestRootIdx = 0

const MCTSTreeRowLeafStateOffset = 1

TTTWidth is the encoded row width: 9 cells + current player.

const TTTWidth = 10

Variables

WinLines is the eight three-in-a-row patterns.

var WinLines = [...][3]int{
    {0, 1, 2}, {3, 4, 5}, {6, 7, 8},
    {0, 3, 6}, {1, 4, 7}, {2, 5, 8},
    {0, 4, 8}, {2, 4, 6},
}
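
As a rough sketch of how a table like WinLines can be used, the snippet below scans the eight lines for a completed three-in-a-row. The helper winnerOf is hypothetical and the table is copied locally so the example compiles on its own; the cell encoding (0=empty, 1=X, 2=O) follows the TTTState description.

```go
package main

import "fmt"

// winLines is a local copy of the WinLines table above.
var winLines = [...][3]int{
	{0, 1, 2}, {3, 4, 5}, {6, 7, 8},
	{0, 3, 6}, {1, 4, 7}, {2, 5, 8},
	{0, 4, 8}, {2, 4, 6},
}

// winnerOf is a hypothetical helper: it returns the mark (1=X, 2=O) that
// occupies a complete line, or 0 if no line is complete.
func winnerOf(cells [9]int8) int8 {
	for _, l := range winLines {
		m := cells[l[0]]
		if m != 0 && m == cells[l[1]] && m == cells[l[2]] {
			return m
		}
	}
	return 0
}

func main() {
	board := [9]int8{
		1, 1, 1,
		2, 2, 0,
		0, 0, 0,
	}
	fmt.Println(winnerOf(board)) // prints 1 (X owns the top row)
}
```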

func MASTAggregationCountSlot

func MASTAggregationCountSlot(k int) int

MASTAggregationCountSlot returns the row offset of the count for key k.

func MASTAggregationRowWidth

func MASTAggregationRowWidth(maxKeys int) int

MASTAggregationRowWidth returns the required state_width for a MASTAggregationIteration with the given key bound.

func MASTAggregationSumSlot

func MASTAggregationSumSlot(k int) int

MASTAggregationSumSlot returns the row offset of the sum for key k.

func MASTMeanForKey

func MASTMeanForKey(row []float64, k int) (mean float64, count int)

MASTMeanForKey reads the running mean reward for key k from a row in the MASTAggregationIteration’s layout. Returns (0, 0) when the key has not been observed. Used by samplers that have read the partition’s row via params_as_partitions.
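
The following sketch shows one way the mean could be derived from the documented aggregation layout (row[2*k] holds the count and row[2*k+1] the sum). The names countSlot, sumSlot and meanForKey are local stand-ins, and the bounds check is an assumption rather than documented behaviour.

```go
package main

import "fmt"

// countSlot and sumSlot follow the documented layout:
// row[2*k] is the count and row[2*k+1] is the sum for key k.
func countSlot(k int) int { return 2 * k }
func sumSlot(k int) int   { return 2*k + 1 }

// meanForKey is a hypothetical stand-in for MASTMeanForKey. It returns (0, 0)
// for unobserved keys. Treating out-of-range keys the same way is an assumption.
func meanForKey(row []float64, k int) (float64, int) {
	if k < 0 || sumSlot(k) >= len(row) {
		return 0, 0
	}
	count := int(row[countSlot(k)])
	if count == 0 {
		return 0, 0
	}
	return row[sumSlot(k)] / float64(count), count
}

func main() {
	// Two keys: key 0 seen 4 times with total reward 3, key 1 never seen.
	row := []float64{4, 3, 0, 0}
	fmt.Println(meanForKey(row, 0)) // prints 0.75 4
	fmt.Println(meanForKey(row, 1)) // prints 0 0
}
```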

func MASTRolloutNumPathOffset

func MASTRolloutNumPathOffset(players int) int

MASTRolloutNumPathOffset returns the row offset of the num_path counter.

func MASTRolloutOkOffset

func MASTRolloutOkOffset(players int) int

MASTRolloutOkOffset returns the row offset of the ok flag.

func MASTRolloutPathOffset

func MASTRolloutPathOffset(players int) int

MASTRolloutPathOffset returns the row offset of the first (key_idx, reward) pair.

func MASTRolloutRowWidth

func MASTRolloutRowWidth(players, maxPath int) int

MASTRolloutRowWidth returns the required state_width for a MASTRolloutIteration with the given player count and path bound.

func MASTRolloutScoresOffset

func MASTRolloutScoresOffset(i int) int

MASTRolloutScoresOffset returns the row offset of score slot i.

func MCTSRolloutRowWidth

func MCTSRolloutRowWidth(players int) int

MCTSRolloutRowWidth returns the required InitStateValues / StateWidth for a MCTSRolloutIteration with the given player count.

func MCTSTreeRowHasLeafOffset

func MCTSTreeRowHasLeafOffset(stateWidth int) int

func MCTSTreeRowVisitsOffset

func MCTSTreeRowVisitsOffset(stateWidth int) int

func MCTSTreeRowWidth

func MCTSTreeRowWidth(stateWidth, maxLegalActions int) int

MCTSTreeRowWidth returns the required InitStateValues / StateWidth for a MCTSTreeIteration with the given encoded-state width and max legal-action count.

func MCTSTreeRowWinsOffset

func MCTSTreeRowWinsOffset(stateWidth, maxLegalActions int) int

func TTTEncode

func TTTEncode(s TTTState) []float64

TTTEncode produces the []float64 row representation of a TTTState. Done/Winner are derivable from the cells so they are not encoded.
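
A codec of this shape might look like the sketch below. It assumes the nine cells come first and the current player last (the ordering suggested by "9 cells + current player"); the type state and the functions encode and decode are local illustrations, not the package's TTTState, TTTEncode or TTTDecode.

```go
package main

import "fmt"

// state is a trimmed-down, hypothetical stand-in for TTTState.
type state struct {
	Cells   [9]int8
	Current int
}

// encode writes the cells into slots 0..8 and the player to move into slot 9.
// This slot order is an assumption.
func encode(s state) []float64 {
	row := make([]float64, 10)
	for i, c := range s.Cells {
		row[i] = float64(c)
	}
	row[9] = float64(s.Current)
	return row
}

// decode reverses encode and rejects rows of the wrong width.
func decode(row []float64) (state, error) {
	if len(row) != 10 {
		return state{}, fmt.Errorf("want width 10, got %d", len(row))
	}
	var s state
	for i := 0; i < 9; i++ {
		s.Cells[i] = int8(row[i])
	}
	s.Current = int(row[9])
	return s, nil
}

func main() {
	s := state{Cells: [9]int8{1, 0, 2, 0, 1, 0, 0, 0, 0}, Current: 1}
	back, err := decode(encode(s))
	fmt.Println(back == s, err)
}
```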

func TTTKey

func TTTKey(a TTTAction) string

TTTKey is a stable string key for an action, useful for any aggregation or lookup that needs a string-typed action identifier.

func WinnerToTerminal

func WinnerToTerminal(winner, players int, done bool) []float64

WinnerToTerminal builds an Environment.Terminal-compatible result from a binary winner. For envs whose native terminal predicate is “winner int, done bool” rather than per-player scores, embed this in your Environment implementation:

func (g *MyGame) Terminal(s State) ([]float64, bool) {
    w, done := g.winnerOrDone(s)
    return agents.WinnerToTerminal(w, g.Players(s), done), done
}

type ApplyIteration

ApplyIteration advances the environment by one ply per stochadex outer step using a best-action signal supplied either via params_from_upstream (within-step) OR via params_as_partitions (lagged read of an upstream partition’s state-history row). The partition row is the encoded current game state; one outer step decodes the row, applies the chosen legal action, and writes the encoded post-move state back.

Row layout (width = StateWidth):

row[0 .. StateWidth-1]   encoded current game state.

Two read modes

  1. Direct param mode (ApplyParamBestIdx): set ParamsFromUpstream[ApplyParamBestIdx] to read the best-action index directly within the same step. Used when the upstream partition is not in a self-referential cycle with apply.

  2. State-history mode (ApplyParamBestIdxPartition + BestIdxSlot): set ParamsAsPartitions[ApplyParamBestIdxPartition] to the upstream partition’s name, and set BestIdxSlot to the offset of the best-action index within that partition’s row. Apply will read the PREVIOUS step’s row 0 of that partition each step. Used to break the apply ↔︎ search dependency cycle in NewMCTSSelfPlayPartitions: search reads apply within-step (so apply does not wait on search), and apply reads search lagged by one step (so search does not wait on apply). The 1-step lag aligns correctly because at step N apply applies the best_idx that search produced at step N-1 for apply’s state at step N-1 — which is the same state apply currently holds (apply only advances when it applies a move).

State-history mode takes priority if both are configured.

Warm fields (must be set before Configure):

Optional warm field:

type ApplyIteration[S any, A any] struct {
    Env         Environment[S, A]
    Decoder     func([]float64) (S, error)
    Encoder     func(S) []float64
    BestIdxSlot int
    // contains filtered or unexported fields
}

func (*ApplyIteration[S, A]) Configure

func (m *ApplyIteration[S, A]) Configure(partitionIndex int, settings *simulator.Settings)

Configure implements simulator.Iteration.

func (*ApplyIteration[S, A]) Iterate

func (m *ApplyIteration[S, A]) Iterate(params *simulator.Params, partitionIndex int, stateHistories []*simulator.StateHistory, timestepsHistory *simulator.CumulativeTimestepsHistory) []float64

Iterate implements simulator.Iteration.

type Environment

Environment is the game/decision-process interface that MCTS searches over. Implementations are pure: Legal/Apply must not mutate s, and Apply must return a fresh value (the search clones aggressively).

Terminal returns the per-player [0,1] score vector and a done flag. The score vector lets the env represent draws (0.5/0.5), graded scoring (Catan-like), and standard winner-takes-all (one-hot) without faking a “winner” int. For binary games, see WinnerToTerminal in rollout.go.

Actor returns the index of the player whose decision creates the next edge from s; backups credit nodes by Actor at the parent. Players returns the total seat count and bounds the score vector length.

type Environment[S any, A any] interface {
    Legal(s S) []A
    Apply(s S, a A) (S, error)
    Terminal(s S) (scores []float64, done bool)
    Actor(s S) int
    Players(s S) int
}
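
To make the contract concrete, here is a minimal sketch of an implementation for a toy two-player take-away game. The interface is copied locally so the example compiles standalone; nimState and nimGame are hypothetical, and returning nil scores for non-terminal states is an assumption about what callers tolerate.

```go
package main

import (
	"errors"
	"fmt"
)

// Environment is copied from the interface above so the sketch is standalone.
type Environment[S any, A any] interface {
	Legal(s S) []A
	Apply(s S, a A) (S, error)
	Terminal(s S) (scores []float64, done bool)
	Actor(s S) int
	Players(s S) int
}

// nimState is a hypothetical game: players alternately remove 1 or 2 stones,
// and whoever takes the last stone wins.
type nimState struct {
	Stones int
	ToMove int
}

type nimGame struct{}

func (nimGame) Legal(s nimState) []int {
	var out []int
	for take := 1; take <= 2 && take <= s.Stones; take++ {
		out = append(out, take)
	}
	return out
}

// Apply returns a fresh value and never mutates its input.
func (nimGame) Apply(s nimState, take int) (nimState, error) {
	if take < 1 || take > 2 || take > s.Stones {
		return s, errors.New("illegal move")
	}
	return nimState{Stones: s.Stones - take, ToMove: 1 - s.ToMove}, nil
}

func (nimGame) Terminal(s nimState) ([]float64, bool) {
	if s.Stones > 0 {
		return nil, false
	}
	scores := make([]float64, 2)
	scores[1-s.ToMove] = 1 // the previous mover took the last stone
	return scores, true
}

func (nimGame) Actor(s nimState) int   { return s.ToMove }
func (nimGame) Players(s nimState) int { return 2 }

// Compile-time check that nimGame satisfies the interface.
var _ Environment[nimState, int] = nimGame{}

func main() {
	g := nimGame{}
	s := nimState{Stones: 3}
	s, _ = g.Apply(s, 2) // player 0 takes two
	s, _ = g.Apply(s, 1) // player 1 takes the last stone
	scores, done := g.Terminal(s)
	fmt.Println(scores, done)
}
```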

type MASTAggregationIteration

MASTAggregationIteration is a stochadex iteration that maintains running (count, sum) pairs per action-key index for MAST (Move-Average Sampling Technique). Each step it reads a variable-length update batch from an upstream rollout partition via params_from_upstream and applies the (key_idx, reward) increments to its row.

Row layout (width = 2 * MaxKeys):

row[2*k]     count for key k
row[2*k+1]   sum   for key k

Use MASTAggregationRowWidth(K) to compute state_width.

Update batch format

The upstream batch is a single []float64 with the layout

[num_updates, key_idx_0, reward_0, key_idx_1, reward_1, ...]

where num_updates is the number of valid (key_idx, reward) pairs in the slice. MASTRolloutIteration emits exactly this layout in its row’s path-suffix slots; wire MASTAggregationIteration’s params_from_upstream (key MASTAggregationParamUpdates) to those slots.

Out-of-range key indices and updates beyond the slice’s declared num_updates are silently dropped.
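
The batch format and the dropping rules can be sketched as follows. The function applyUpdates is a hypothetical illustration of the described behaviour, not the iteration's actual code.

```go
package main

import "fmt"

// applyUpdates is a hypothetical sketch. It applies a batch of the form
// [num_updates, key_idx_0, reward_0, key_idx_1, reward_1, ...] to a row laid
// out as row[2*k] = count, row[2*k+1] = sum.
func applyUpdates(row, batch []float64, maxKeys int) {
	if len(batch) == 0 {
		return
	}
	n := int(batch[0])
	for i := 0; i < n; i++ {
		keyPos, rewardPos := 1+2*i, 2+2*i
		if rewardPos >= len(batch) {
			break // the batch declared more pairs than the slice holds
		}
		k := int(batch[keyPos])
		if k < 0 || k >= maxKeys {
			continue // out-of-range key: silently dropped
		}
		row[2*k]++
		row[2*k+1] += batch[rewardPos]
	}
}

func main() {
	row := make([]float64, 2*3) // three keys
	// Two declared updates; the trailing (0, 0.5) pair lies beyond num_updates.
	batch := []float64{2, 1, 0.75, 7, 1.0, 0, 0.5}
	applyUpdates(row, batch, 3)
	fmt.Println(row) // prints [0 0 1 0.75 0 0]
}
```

Only the (1, 0.75) pair lands: key 7 is out of range for three keys, and the final pair is ignored because num_updates is 2.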

Read access

Downstream samplers (e.g. MASTRolloutIteration) read the aggregates via state-history mode (lag-1) using params_as_partitions. See MASTAggregationParamPartition for the canonical key.

type MASTAggregationIteration[A any] struct {
    MaxKeys int
    // contains filtered or unexported fields
}

func (*MASTAggregationIteration[A]) Configure

func (m *MASTAggregationIteration[A]) Configure(partitionIndex int, settings *simulator.Settings)

Configure implements simulator.Iteration.

func (*MASTAggregationIteration[A]) Iterate

func (m *MASTAggregationIteration[A]) Iterate(params *simulator.Params, partitionIndex int, stateHistories []*simulator.StateHistory, timestepsHistory *simulator.CumulativeTimestepsHistory) []float64

Iterate implements simulator.Iteration.

type MASTRolloutIteration

MASTRolloutIteration is a stochadex iteration that runs one MAST-biased rollout per step. It reads the leaf state from an upstream MCTSTreeIteration (within-step), reads the running aggregates from a MASTAggregationIteration (lag-1, via params_as_partitions), runs a playout sampling each ply via softmax over the aggregates, and emits the per-player scores plus the (key_idx, reward) path that the MASTAggregationIteration will absorb on the next step.

Row layout (width = Players + 2 + 2 * MaxPath):

row[0 .. Players-1]                           per-player [0,1] scores
row[Players]                                  ok flag (1 = scores valid)
row[Players+1]                                num_path (length of valid pairs)
row[Players+2 .. Players+1+2*MaxPath]         (key_idx, reward) pairs,
                                              padded with zeros

The (num_path, pairs…) suffix matches MASTAggregationParamUpdates so the downstream MASTAggregationIteration can read it as a single slice.
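
The offsets implied by this layout can be written out directly. The sketch below uses local helper names (okOffset, numPathOffset, pathOffset, rowWidth, updateBatch) that mirror the exported accessors but are defined here only for illustration.

```go
package main

import "fmt"

// These helpers restate the documented row layout; they are local stand-ins
// for the package's MASTRollout* offset functions.
func okOffset(players int) int          { return players }
func numPathOffset(players int) int     { return players + 1 }
func pathOffset(players int) int        { return players + 2 }
func rowWidth(players, maxPath int) int { return players + 2 + 2*maxPath }

// updateBatch slices out the (num_path, pairs...) suffix, which has the same
// shape as the aggregation partition's update batch.
func updateBatch(row []float64, players int) []float64 {
	return row[numPathOffset(players):]
}

func main() {
	players, maxPath := 2, 3
	row := make([]float64, rowWidth(players, maxPath))
	row[okOffset(players)] = 1      // scores are valid
	row[numPathOffset(players)] = 1 // one valid (key_idx, reward) pair
	row[pathOffset(players)] = 4    // key_idx
	row[pathOffset(players)+1] = 1  // reward
	fmt.Println(len(row), updateBatch(row, players))
}
```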

Warm fields (must be set before Configure):

type MASTRolloutIteration[S any, A any] struct {
    Env      Environment[S, A]
    Cfg      MCTSConfig[S, A]
    Decoder  func([]float64) (S, error)
    KeyToIdx func(A) int
    MaxKeys  int
    MaxPath  int
    Players  int
    Tau      float64
    Progress func(s S, player int) (float64, bool)
    // contains filtered or unexported fields
}

func (*MASTRolloutIteration[S, A]) Configure

func (m *MASTRolloutIteration[S, A]) Configure(partitionIndex int, settings *simulator.Settings)

Configure implements simulator.Iteration.

func (*MASTRolloutIteration[S, A]) Iterate

func (m *MASTRolloutIteration[S, A]) Iterate(params *simulator.Params, partitionIndex int, stateHistories []*simulator.StateHistory, timestepsHistory *simulator.CumulativeTimestepsHistory) []float64

Iterate implements simulator.Iteration.

type MCTSConfig

MCTSConfig holds UCT hyperparameters and the rollout driver.

Rollout is the single graded rollout signature: it returns a per-player score vector in [0,1] and an ok flag (false signals “no signal”: the caller counts the visit but skips the win credit). For binary winner games or games using a Progress proxy, build the rollout from helpers in rollout.go (UniformRandomRollout, WinnerToTerminal, FromProgress).

Progress is an optional per-state, per-player [0,1] value proxy used by the FromProgress rollout adapter to score truncated rollouts.

type MCTSConfig[S any, A any] struct {
    Simulations     int
    Exploration     float64
    MaxTreeDepth    int
    RolloutMaxSteps int
    Rollout         MCTSRolloutFn[S, A]
    Progress        func(s S, player int) (float64, bool)
}

func (*MCTSConfig[S, A]) ApplyDefaults

func (c *MCTSConfig[S, A]) ApplyDefaults()

ApplyDefaults fills in zero-valued hyperparameters with the package defaults, mutating the receiver. Called once at the start of each search run; safe to call multiple times.

Exported so external packages building on MCTSConfig can share the defaults logic without duplicating it.
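
The zero-value fill can be pictured with the sketch below. The config type is a trimmed, non-generic stand-in for MCTSConfig, and the literal values are copied from the default constants listed earlier; the real method may handle additional fields such as Rollout.

```go
package main

import "fmt"

// config is a hypothetical, non-generic stand-in for MCTSConfig.
type config struct {
	Simulations     int
	Exploration     float64
	MaxTreeDepth    int
	RolloutMaxSteps int
}

// applyDefaults fills only zero-valued fields, so calling it repeatedly
// leaves the result unchanged.
func (c *config) applyDefaults() {
	if c.Simulations == 0 {
		c.Simulations = 120 // MCTSDefaultSimulations
	}
	if c.Exploration == 0 {
		c.Exploration = 1.41 // MCTSDefaultExploration
	}
	if c.MaxTreeDepth == 0 {
		c.MaxTreeDepth = 14 // MCTSDefaultMaxTreeDepth
	}
	if c.RolloutMaxSteps == 0 {
		c.RolloutMaxSteps = 220 // MCTSDefaultRolloutMaxSteps
	}
}

func main() {
	c := config{Simulations: 50} // caller overrides one field only
	c.applyDefaults()
	c.applyDefaults() // second call is a no-op
	fmt.Printf("%+v\n", c)
}
```

One consequence of this pattern is that an explicit zero cannot be distinguished from "unset", so a caller cannot request, for example, an exploration constant of exactly 0 through the zero value.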

type MCTSEdgeStat

MCTSEdgeStat is per-action telemetry exposed at the root after a search. Useful for JSON reports or logging the search’s distribution over moves.

type MCTSEdgeStat[A any] struct {
    Action       A       `json:"action"`
    Visits       int     `json:"visits"`
    MeanForActor float64 `json:"mean_for_actor"`
}

func RunMCTSSearch

func RunMCTSSearch[S any, A any](env Environment[S, A], root S, cfg MCTSConfig[S, A], baseSeed uint64, sims int) (A, []MCTSEdgeStat[A], error)

RunMCTSSearch runs sims UCT simulations from root and returns the best legal action plus per-edge stats. Independent of stochadex partitions — useful for one-shot “what’s the best move?” queries.

Defaults are filled in from the package constants. If cfg.Rollout is nil, UniformRandomRollout is used.

type MCTSRolloutFn

MCTSRolloutFn is the single rollout signature used by the search. It runs a stochastic playout from s for at most maxSteps actions and returns:

The seed argument is the rollout’s full RNG seed; implementations should be deterministic given the same seed.

type MCTSRolloutFn[S any, A any] func(env Environment[S, A], s S, maxSteps int, seed uint64) (scores []float64, ok bool, err error)

func FromProgress

func FromProgress[S any, A any](inner MCTSRolloutFn[S, A], progress func(s S, player int) (float64, bool)) MCTSRolloutFn[S, A]

FromProgress wraps an inner rollout so that truncated rollouts (ok=false from inner) are rescued by scoring the final state via the supplied progress function. progress returns a per-player [0,1] proxy of “how close is this player to winning” and an ok flag (false = no proxy available for that player; treated as 0).

All-equal progress vectors carry no comparative signal and are treated as no signal (ok=false from the wrapper) so the search relies on UCB exploration alone — see the docstring on MCTSTree.backupVisits for the stall-move failure mode this avoids.

FromProgress needs to know the final state of the inner rollout, so it re-runs the playout itself rather than wrapping inner. The inner rollout is therefore only consulted for early termination.

func UniformRandomRollout

func UniformRandomRollout[S any, A any]() MCTSRolloutFn[S, A]

UniformRandomRollout returns a MCTSRolloutFn that plays uniformly random legal actions until either Terminal returns done or maxSteps is reached. On termination the env’s Terminal scores are returned. On truncation (maxSteps reached without termination) the rollout returns ok=false — compose with FromProgress if you have a progress proxy to score truncated rollouts.
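
The behaviour described above can be sketched as a function with the MCTSRolloutFn shape. Everything here is illustrative: the interface is copied locally, uniformRollout and the toy walk environment are hypothetical, and the use of math/rand seeded from the seed argument is an assumption.

```go
package main

import (
	"fmt"
	"math/rand"
)

// Environment is copied from the package interface so the sketch is standalone.
type Environment[S any, A any] interface {
	Legal(s S) []A
	Apply(s S, a A) (S, error)
	Terminal(s S) (scores []float64, done bool)
	Actor(s S) int
	Players(s S) int
}

// uniformRollout is a hypothetical rollout: it plays random legal actions
// for at most maxSteps actions and reports ok=false on truncation.
func uniformRollout[S any, A any](env Environment[S, A], s S, maxSteps int, seed uint64) ([]float64, bool, error) {
	rng := rand.New(rand.NewSource(int64(seed))) // deterministic per seed
	for step := 0; step < maxSteps; step++ {
		if scores, done := env.Terminal(s); done {
			return scores, true, nil
		}
		legal := env.Legal(s)
		if len(legal) == 0 {
			return nil, false, nil
		}
		next, err := env.Apply(s, legal[rng.Intn(len(legal))])
		if err != nil {
			return nil, false, err
		}
		s = next
	}
	// The last permitted action may have ended the game.
	if scores, done := env.Terminal(s); done {
		return scores, true, nil
	}
	return nil, false, nil // truncated: no signal
}

// walk is a toy single-player environment: add 1 or 2 until reaching 5.
type walk struct{}

func (walk) Legal(s int) []int {
	if s >= 5 {
		return nil
	}
	return []int{1, 2}
}
func (walk) Apply(s, a int) (int, error) { return s + a, nil }
func (walk) Terminal(s int) ([]float64, bool) {
	if s >= 5 {
		return []float64{1}, true
	}
	return nil, false
}
func (walk) Actor(int) int   { return 0 }
func (walk) Players(int) int { return 1 }

func main() {
	scores, ok, err := uniformRollout[int, int](walk{}, 0, 10, 42)
	fmt.Println(scores, ok, err)
	_, ok, _ = uniformRollout[int, int](walk{}, 0, 1, 42)
	fmt.Println(ok) // one action cannot reach 5, so the rollout is truncated
}
```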

type MCTSRolloutIteration

MCTSRolloutIteration runs one rollout per stochadex step. It reads the leaf state to roll out from via params_from_upstream (typically wired to a MCTSTreeIteration’s leaf_state slot via Indices) and outputs a per-player score vector plus an ok flag in its own row.

Row layout (width = Players + 1):

row[0 .. Players-1]   per-player [0,1] scores from the rollout
row[Players]          ok flag (1 if the rollout produced valid scores,
                      0 otherwise — the upstream MCTSTreeIteration uses
                      this to decide whether to apply backupScores or
                      backupVisits with no signal)

Use MCTSRolloutRowWidth(P) to size InitStateValues / state_width.

Stateless across steps — each Iterate is one independent rollout. Swap in FromProgress, WinnerToTerminal, or any custom MCTSRolloutFn via Cfg.Rollout without touching the partition wiring.

Warm fields (must be set before Configure):

type MCTSRolloutIteration[S any, A any] struct {
    Env     Environment[S, A]
    Cfg     MCTSConfig[S, A]
    Decoder func([]float64) (S, error)
    Players int
    // contains filtered or unexported fields
}

func (*MCTSRolloutIteration[S, A]) Configure

func (m *MCTSRolloutIteration[S, A]) Configure(partitionIndex int, settings *simulator.Settings)

Configure implements simulator.Iteration.

func (*MCTSRolloutIteration[S, A]) Iterate

func (m *MCTSRolloutIteration[S, A]) Iterate(params *simulator.Params, partitionIndex int, stateHistories []*simulator.StateHistory, timestepsHistory *simulator.CumulativeTimestepsHistory) []float64

Iterate implements simulator.Iteration.

type MCTSTree

MCTSTree is the in-memory UCT search tree for a fixed root state. Methods are not safe for concurrent use; one MCTSTree per goroutine.

The MCTSTree owns no environment or config — those are passed in to RunOne per-simulation. This makes MCTSTree easy to embed in iterations that may want to tweak config between simulations (e.g. adaptive exploration).

type MCTSTree[S any, A any] struct {
    // contains filtered or unexported fields
}

func NewMCTSTree

func NewMCTSTree[S any, A any](root S) *MCTSTree[S, A]

NewMCTSTree returns a MCTSTree with a single root node containing root.

func (*MCTSTree[S, A]) AdvanceRoot

func (t *MCTSTree[S, A]) AdvanceRoot(env Environment[S, A], legalIdx int)

AdvanceRoot promotes the root’s child at the given legal index to be the new root, preserving its subtree (classic MCTS tree reuse). If that child was never expanded, the tree is rebuilt fresh from the resulting state.

env is needed to compute the post-move state if the subtree is missing.

func (*MCTSTree[S, A]) BackupScores

func (t *MCTSTree[S, A]) BackupScores(path []int, scores []float64)

BackupScores credits each node along path with the score belonging to its actor (visits and wins both increment). Exported wrapper around the internal backupScores so iterations split across selection / rollout / backup partitions can apply the scores when they arrive.

func (*MCTSTree[S, A]) BackupVisits

func (t *MCTSTree[S, A]) BackupVisits(path []int, scores []float64)

BackupVisits is the no-signal-tolerant variant: visits always increment, but wins are only credited when scores is non-nil. See backupVisits docs for the engine-stall reasoning.

func (*MCTSTree[S, A]) NodeCount

func (t *MCTSTree[S, A]) NodeCount() int

NodeCount returns the number of nodes currently in the tree (including the root). Useful for telemetry and capacity tuning.

func (*MCTSTree[S, A]) Reset

func (t *MCTSTree[S, A]) Reset(root S)

Reset replaces the entire tree with a fresh root at the given state. Use when no usable subtree exists (opening move, or after an external change to the root).

func (*MCTSTree[S, A]) Root

func (t *MCTSTree[S, A]) Root() S

Root returns the root state.

func (*MCTSTree[S, A]) RootBestLegalIdx

func (t *MCTSTree[S, A]) RootBestLegalIdx() (int, bool)

RootBestLegalIdx returns the most-visited (then most-winning) child legal index. Ties are broken via reservoir sampling over equally-good children so the choice is not biased toward the first-listed action — important for engine-heavy games where the first listed legal action is often a stall (recycle, pass) and a deterministic first-tie pick would deadlock the agent. Reservoir randomness is seeded from the current tree shape so the result is reproducible without taking an external rng.

func (*MCTSTree[S, A]) RootEdgeStats

func (t *MCTSTree[S, A]) RootEdgeStats(legal []A) []MCTSEdgeStat[A]

RootEdgeStats reports per-action visit counts and mean-for-actor for each expanded child of the root. legal must be the same slice ordering used by the env’s Legal(root); pass env.Legal(tree.Root()) at the call site.

func (*MCTSTree[S, A]) RootStatsByLegalIdx

func (t *MCTSTree[S, A]) RootStatsByLegalIdx(maxLegalActions int) (visits, wins []float64)

RootStatsByLegalIdx returns per-legal-action visit counts and win sums at the root, padded with zeros up to maxLegalActions. Returns (visits, wins) each of length maxLegalActions. Used to expose root statistics in fixed-width row layouts.

func (*MCTSTree[S, A]) RunOne

func (t *MCTSTree[S, A]) RunOne(env Environment[S, A], cfg *MCTSConfig[S, A], rng *rand.Rand)

RunOne does one UCT iteration: selection → expansion → rollout → backup. rng must be seeded by the caller; one call uses one RNG.

Calls cfg.ApplyDefaults() so a fresh MCTSConfig with only Rollout set works out of the box. The mutation is idempotent (only zero values are filled).

RunOne is the all-in-one path used by RunMCTSSearch and by callers who don’t need the selection / rollout / backup phases as separate stochadex partitions. For the decomposed pipeline use SelectLeaf + BackupScores.

func (*MCTSTree[S, A]) SelectLeaf

func (t *MCTSTree[S, A]) SelectLeaf(env Environment[S, A], cfg *MCTSConfig[S, A], rng *rand.Rand) (path []int, leafState S, leafIdx int, ok bool)

SelectLeaf walks the tree from the root using UCB1 (with first-visit preference for unvisited children) until it reaches an unexpanded edge, then expands it by creating a new child node. Returns the path of node indices from the root’s child down to the new leaf, the leaf’s state, the leaf’s node index, and ok=true. Returns ok=false if the root is terminal, has no legal moves, MaxTreeDepth is reached, or env.Apply fails during expansion (the caller should treat this as “no leaf selected this step”).

SelectLeaf does NOT roll out and does NOT back up — it is the (selection + expansion) half of one MCTS iteration. Pair it with MCTSTree.BackupScores or MCTSTree.BackupVisits to apply the scores when they arrive.

Calls cfg.ApplyDefaults() so a fresh MCTSConfig works out of the box. The mutation is idempotent (only zero values are filled).
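
For reference, the selection score in standard UCT is the UCB1 value: mean reward plus an exploration bonus. The sketch below shows that textbook formula with unvisited children ranked first; the function ucb1 is hypothetical, and whether the package uses exactly this form is an assumption.

```go
package main

import (
	"fmt"
	"math"
)

// ucb1 is the textbook UCT selection score:
//   wins/visits + c * sqrt(ln(parentVisits) / visits)
// Unvisited children return +Inf so they are always tried first.
func ucb1(wins, visits, parentVisits, c float64) float64 {
	if visits == 0 {
		return math.Inf(1) // first-visit preference
	}
	return wins/visits + c*math.Sqrt(math.Log(parentVisits)/visits)
}

func main() {
	const c = 1.41 // the package's default exploration constant
	// Same mean reward (0.5), different visit counts: the less-visited
	// child receives the larger exploration bonus.
	fmt.Println(ucb1(5, 10, 100, c))
	fmt.Println(ucb1(25, 50, 100, c))
	fmt.Println(ucb1(0, 0, 100, c)) // prints +Inf
}
```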

type MCTSTreeIteration

MCTSTreeIteration runs the (selection + expansion + backup) phase of a UCT MCTS search as a stochadex iteration. The tree itself lives on the struct (graph state, fundamentally not []float64-shaped); the partition row exposes a fixed-width summary that downstream partitions can consume via params_from_upstream:

row[MCTSTreeRowBestRootIdx]                           — most-visited root legal-action index
                                                        after the most recent update (-1 if
                                                        not yet decided)
row[MCTSTreeRowLeafStateOffset .. +StateWidth-1]      — encoded state of the leaf the search
                                                        just selected (input to a rollout
                                                        partition)
row[MCTSTreeRowHasLeafOffset(W)]                      — 1 if the leaf_state slot is real, 0
                                                        otherwise (used by rollout partitions
                                                        to short-circuit when nothing has been
                                                        selected yet)
row[MCTSTreeRowVisitsOffset(W) .. +MaxLegalActions-1] — per-legal-action root visit counts,
                                                        padded with zeros
row[MCTSTreeRowWinsOffset(W,K) .. +MaxLegalActions-1] — per-legal-action root win sums,
                                                        padded with zeros

Use MCTSTreeRowWidth(W, K) to compute the required state_width / init slice length.
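
Assuming the slots are contiguous and appear in the order listed above, the offsets work out as in the sketch below. The helper names are local stand-ins for the exported accessors, and leafIndices shows one way to build the index list for slicing out the leaf_state + has_leaf section.

```go
package main

import "fmt"

// The two constants match MCTSTreeRowBestRootIdx and MCTSTreeRowLeafStateOffset.
const (
	bestRootIdx     = 0
	leafStateOffset = 1
)

// The remaining offsets assume contiguous slots in the documented order.
func hasLeafOffset(w int) int { return leafStateOffset + w }
func visitsOffset(w int) int  { return hasLeafOffset(w) + 1 }
func winsOffset(w, k int) int { return visitsOffset(w) + k }
func rowWidth(w, k int) int   { return winsOffset(w, k) + k }

// leafIndices lists the row indices covering leaf_state followed by has_leaf.
func leafIndices(w int) []int {
	idx := make([]int, 0, w+1)
	for i := leafStateOffset; i <= hasLeafOffset(w); i++ {
		idx = append(idx, i)
	}
	return idx
}

func main() {
	w, k := 10, 9 // e.g. tic-tac-toe: state width 10, at most 9 legal actions
	fmt.Println(bestRootIdx, hasLeafOffset(w), visitsOffset(w), winsOffset(w, k), rowWidth(w, k))
	fmt.Println(leafIndices(w))
}
```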

Pipeline lag

MCTSTreeIteration is one half of a 2-step pipeline with a downstream rollout partition. The rollout partition reads (leaf_state, has_leaf) and outputs scores; MCTSTreeIteration then reads those scores via params_from_upstream (key MCTSTreeParamRolloutScores) and applies a backup to the path it selected two steps earlier. The 2-step lag is fundamental to expressing selection-then-backup as stochadex’s single-row dataflow.

In steady state each outer step does one selection + one backup, so the throughput is one MCTS iteration per stochadex step (after a 2-step fill).

Warm fields (must be set before Configure):

type MCTSTreeIteration[S any, A any] struct {
    Env             Environment[S, A]
    Cfg             MCTSConfig[S, A]
    Decoder         func([]float64) (S, error)
    Encoder         func(S) []float64
    MaxLegalActions int
    StateWidth      int
    Players         int
    // contains filtered or unexported fields
}

func (*MCTSTreeIteration[S, A]) Configure

func (m *MCTSTreeIteration[S, A]) Configure(partitionIndex int, settings *simulator.Settings)

Configure implements simulator.Iteration. Decodes the encoded root from the partition’s InitStateValues[1 .. 1+StateWidth] and resets the tree. The first slot (MCTSTreeRowBestRootIdx) and the stats slots can be left zero in the init: best_root_idx is initialised to -1 to signal “not yet decided”.

func (*MCTSTreeIteration[S, A]) Iterate

func (m *MCTSTreeIteration[S, A]) Iterate(params *simulator.Params, partitionIndex int, stateHistories []*simulator.StateHistory, timestepsHistory *simulator.CumulativeTimestepsHistory) []float64

Iterate implements simulator.Iteration.

func (*MCTSTreeIteration[S, A]) MCTSTree

func (m *MCTSTreeIteration[S, A]) MCTSTree() *MCTSTree[S, A]

MCTSTree exposes the underlying search tree (typically for telemetry).

type TTTAction

TTTAction is a cell index 0..8.

type TTTAction int

type TTTGame

TTTGame implements Environment[TTTState, TTTAction]. Tic-tac-toe is the canonical fixture for testing this package and downstream consumers (e.g. pkg/analysis): small enough to be obvious, large enough that random play loses, and deterministic at endgame — from a “win in one” position MCTS must pick the winning move; from a “block in one” position it must block.

type TTTGame struct{}

func (*TTTGame) Actor

func (g *TTTGame) Actor(s TTTState) int

func (*TTTGame) Apply

func (g *TTTGame) Apply(s TTTState, a TTTAction) (TTTState, error)

Apply places the current player’s mark on cell a and updates Done/Winner.

func (*TTTGame) Legal

func (g *TTTGame) Legal(s TTTState) []TTTAction

Legal returns the empty cells; nil if the game is over.

func (*TTTGame) Players

func (g *TTTGame) Players(s TTTState) int

func (*TTTGame) Terminal

func (g *TTTGame) Terminal(s TTTState) ([]float64, bool)

Terminal returns the per-player [0,1] score vector and a done flag. Draw splits 0.5/0.5; a winner takes 1, the other 0.

type TTTState

TTTState is a tic-tac-toe position. Cells holds 0=empty, 1=X, 2=O.

type TTTState struct {
    Cells   [9]int8
    Current int // 0=X to move, 1=O to move
    Done    bool
    Winner  int // -1=draw or in-progress, 0=X won, 1=O won
}

func TTTDecode

func TTTDecode(v []float64) (TTTState, error)

TTTDecode rebuilds a TTTState from its row representation, recomputing Done/Winner so callers can spell positions declaratively.

func TTTFromGrid

func TTTFromGrid(grid [9]int8, currentPlayer int) TTTState

TTTFromGrid builds a TTTState from a literal cell grid plus the player to move. Done/Winner are derived. Useful for spelling test positions inline.

Generated by gomarkdoc