agents
import "github.com/umbralcalc/stochadex/pkg/agents"
Package agents provides decision-making agents that operate over a generic Environment[S, A] interface. The package is intended to host any agent built on the same environment framework. Currently it ships MCTS (UCT) as the only agent, with MAST as an optional rollout strategy on top.
Per-player terminal scores are []float64 in [0,1] (the established stochadex value convention). Codecs (encoder/decoder for S into the stochadex row’s []float64) are supplied by the caller as function fields on each partition. This package does not depend on any encoding protocol.
Key Features:
- Generic Environment[S, A] interface (Legal/Apply/Terminal/Actor/Players)
- MCTS decomposed as three partitions: tree (selection + backup), rollout (one playout per step), and apply (state advancer)
- Pluggable rollout functions (UniformRandomRollout, FromProgress), plus the WinnerToTerminal helper for environments whose native result is a single winner
- MAST as an optional rollout strategy: a learning rollout policy backed by a separate aggregation partition holding running per-action-key (count, sum) state
- Cycle-breaking by mixing within-step params_from_upstream with lag-1 state-history reads, so tree ↔︎ rollout and apply ↔︎ search both compose without deadlock
- One-shot helper RunMCTSSearch for ad-hoc searches outside a coordinator
Usage Patterns:
- Self-play: compose ApplyIteration + an embedded search sim (analysis.NewMCTSSelfPlayPartitions wires this up)
- Per-simulation telemetry: run MCTSTreeIteration + MCTSRolloutIteration directly without the apply layer, then read the tree’s row
- YAML / pkg/api use: ship non-generic façade types per environment family that bake in the type parameters, since generic types in YAML iteration strings are awkward
Index
- Constants
- Variables
- func MASTAggregationCountSlot(k int) int
- func MASTAggregationRowWidth(maxKeys int) int
- func MASTAggregationSumSlot(k int) int
- func MASTMeanForKey(row []float64, k int) (mean float64, count int)
- func MASTRolloutNumPathOffset(players int) int
- func MASTRolloutOkOffset(players int) int
- func MASTRolloutPathOffset(players int) int
- func MASTRolloutRowWidth(players, maxPath int) int
- func MASTRolloutScoresOffset(i int) int
- func MCTSRolloutRowWidth(players int) int
- func MCTSTreeRowHasLeafOffset(stateWidth int) int
- func MCTSTreeRowVisitsOffset(stateWidth int) int
- func MCTSTreeRowWidth(stateWidth, maxLegalActions int) int
- func MCTSTreeRowWinsOffset(stateWidth, maxLegalActions int) int
- func TTTEncode(s TTTState) []float64
- func TTTKey(a TTTAction) string
- func WinnerToTerminal(winner, players int, done bool) []float64
- type ApplyIteration
- type Environment
- type MASTAggregationIteration
- type MASTRolloutIteration
- type MCTSConfig
- type MCTSEdgeStat
- type MCTSRolloutFn
- type MCTSRolloutIteration
- type MCTSTree
- func NewMCTSTree[S any, A any](root S) *MCTSTree[S, A]
- func (t *MCTSTree[S, A]) AdvanceRoot(env Environment[S, A], legalIdx int)
- func (t *MCTSTree[S, A]) BackupScores(path []int, scores []float64)
- func (t *MCTSTree[S, A]) BackupVisits(path []int, scores []float64)
- func (t *MCTSTree[S, A]) NodeCount() int
- func (t *MCTSTree[S, A]) Reset(root S)
- func (t *MCTSTree[S, A]) Root() S
- func (t *MCTSTree[S, A]) RootBestLegalIdx() (int, bool)
- func (t *MCTSTree[S, A]) RootEdgeStats(legal []A) []MCTSEdgeStat[A]
- func (t *MCTSTree[S, A]) RootStatsByLegalIdx(maxLegalActions int) (visits, wins []float64)
- func (t *MCTSTree[S, A]) RunOne(env Environment[S, A], cfg *MCTSConfig[S, A], rng *rand.Rand)
- func (t *MCTSTree[S, A]) SelectLeaf(env Environment[S, A], cfg *MCTSConfig[S, A], rng *rand.Rand) (path []int, leafState S, leafIdx int, ok bool)
- type MCTSTreeIteration
- func (m *MCTSTreeIteration[S, A]) Configure(partitionIndex int, settings *simulator.Settings)
- func (m *MCTSTreeIteration[S, A]) Iterate(params *simulator.Params, partitionIndex int, stateHistories []*simulator.StateHistory, timestepsHistory *simulator.CumulativeTimestepsHistory) []float64
- func (m *MCTSTreeIteration[S, A]) MCTSTree() *MCTSTree[S, A]
- type TTTAction
- type TTTGame
- type TTTState
Constants
Default hyperparameters used when MCTSConfig fields are zero.
const (
MCTSDefaultSimulations = 120
MCTSDefaultRolloutMaxSteps = 220
MCTSDefaultMaxTreeDepth = 14
MCTSDefaultExploration = 1.41
)
Param key for direct (within-step) best-action input. The value should be a 1-element slice: the legal-action index. A negative value (e.g. the -1 sentinel) leaves the state unchanged.
const ApplyParamBestIdx = "best_legal_idx"
Param key for state-history (lag-1) best-action input. The value (set via ParamsAsPartitions) is a 1-element slice containing the partition index of the upstream partition whose row[BestIdxSlot] holds the best-action index.
const ApplyParamBestIdxPartition = "best_idx_partition"
MASTAggregationParamPartition is the params_as_partitions key used by downstream samplers to learn this partition’s index for state-history reads. The value is a 1-element slice containing the partition index.
const MASTAggregationParamPartition = "mast_aggregates_partition"
MASTAggregationParamUpdates is the params_from_upstream key used to read the variable-length update batch from an upstream rollout partition.
const MASTAggregationParamUpdates = "mast_updates"
MASTDefaultTau is the softmax temperature applied to MAST means when sampling. Smaller → more exploitative; larger → closer to uniform. 1.0 diverges from uniform within ~50 updates per key in practice.
const MASTDefaultTau = 1.0
MASTSamplePrior is the score assigned to actions whose key has not yet been observed, so they retain selection probability early in search. 0.5 sits at the midpoint of the [0, 1] reward range, neither encouraged nor discouraged.
const MASTSamplePrior = 0.5
Param key used by MCTSRolloutIteration.Iterate to read the leaf state from an upstream MCTSTreeIteration via params_from_upstream. The slice should be of length StateWidth followed by a single has_leaf flag (matching MCTSTreeIteration’s row layout: leaf_state then has_leaf). Use Indices on the NamedUpstreamConfig to slice out the leaf_state + has_leaf section of the tree’s row.
const MCTSRolloutParamLeaf = "leaf"
Param key used by MCTSTreeIteration.Iterate to read rollout scores from an upstream MCTSRolloutIteration via params_from_upstream (within-step). The slice should be of length Players + 1 (P scores + ok flag), the layout produced by MCTSRolloutIteration.
This mode creates a within-step dependency that breaks if the rollout partition also depends on the tree (the standard MCTS pipeline does). Use MCTSTreeParamRolloutScoresPartition for the lag-1 state-history mode instead when wiring the tree + rollout pipeline.
const MCTSTreeParamRolloutScores = "rollout_scores"
Param key used by MCTSTreeIteration.Iterate to read rollout scores from an upstream MCTSRolloutIteration via params_as_partitions (state-history, lag-1). The value is a 1-element slice containing the rollout partition’s index; the tree reads stateHistories[idx].Values row 0 (= the previous step’s rollout output) at runtime.
This is the standard wiring used by NewMCTSSelfPlayPartitions: rollout reads tree’s leaf within-step (so the rollout sees the freshest leaf), and tree reads rollout’s scores lag-1 (so the tree never waits on the rollout within a step, which breaks the cycle). The 1-step lag aligns correctly: at step N+1 tree backs up the path it selected at step N with the scores rollout produced at step N, which were for that very leaf.
State-history mode takes priority over within-step mode if both keys are present.
const MCTSTreeParamRolloutScoresPartition = "rollout_scores_partition"
Param key used by MCTSTreeIteration.Iterate to read the current search root state. Whenever this param’s value differs from the cached root encoding, the tree is reset to the decoded new root. The slice should be of length StateWidth.
Set this via the embedded simulation run’s outer params_from_upstream (e.g. an outer apply partition piping its current game state into the inner sim via the “<innerName>/root_state” forwarding mechanism). When MCTSTreeIteration is used standalone with no outer pipeline, the param is absent and the tree retains the root set at Configure time.
const MCTSTreeParamRootState = "root_state"
Row layout slot accessors. Use these to compute params_from_upstream indices when wiring downstream partitions to MCTSTreeIteration’s row.
const MCTSTreeRowBestRootIdx = 0
const MCTSTreeRowLeafStateOffset = 1
TTTWidth is the encoded row width: 9 cells + current player.
const TTTWidth = 10
Variables
WinLines is the eight three-in-a-row patterns.
var WinLines = [...][3]int{
{0, 1, 2}, {3, 4, 5}, {6, 7, 8},
{0, 3, 6}, {1, 4, 7}, {2, 5, 8},
{0, 4, 8}, {2, 4, 6},
}
func MASTAggregationCountSlot
func MASTAggregationCountSlot(k int) int
MASTAggregationCountSlot returns the row offset of the count for key k.
func MASTAggregationRowWidth
func MASTAggregationRowWidth(maxKeys int) int
MASTAggregationRowWidth returns the required state_width for an MASTAggregationIteration with the given key bound.
func MASTAggregationSumSlot
func MASTAggregationSumSlot(k int) int
MASTAggregationSumSlot returns the row offset of the sum for key k.
func MASTMeanForKey
func MASTMeanForKey(row []float64, k int) (mean float64, count int)
MASTMeanForKey reads the running mean reward for key k from a row in the MASTAggregationIteration’s layout. Returns (0, 0) when the key has not been observed. Used by samplers that have read the partition’s row via params_as_partitions.
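The (count, sum) layout documented for MASTAggregationIteration reduces the mean to one division. The following is a minimal local sketch of that read, assuming the row[2*k] / row[2*k+1] layout; meanForKey is a hypothetical stand-in rather than the package's implementation, and its bounds handling is an assumption.

```go
package main

import "fmt"

// meanForKey is a hypothetical stand-in for MASTMeanForKey. It assumes the
// documented layout: row[2*k] holds the count and row[2*k+1] holds the sum.
func meanForKey(row []float64, k int) (mean float64, count int) {
	if k < 0 || 2*k+1 >= len(row) {
		return 0, 0 // assumption: out-of-range keys read as unobserved
	}
	c := row[2*k]
	if c <= 0 {
		return 0, 0 // key not observed yet
	}
	return row[2*k+1] / c, int(c)
}

func main() {
	// Key 0 has been seen 4 times with total reward 3; key 1 is unobserved.
	row := []float64{4, 3, 0, 0}
	fmt.Println(meanForKey(row, 0))
	fmt.Println(meanForKey(row, 1))
}
```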
func MASTRolloutNumPathOffset
func MASTRolloutNumPathOffset(players int) int
MASTRolloutNumPathOffset returns the row offset of the num_path counter.
func MASTRolloutOkOffset
func MASTRolloutOkOffset(players int) int
MASTRolloutOkOffset returns the row offset of the ok flag.
func MASTRolloutPathOffset
func MASTRolloutPathOffset(players int) int
MASTRolloutPathOffset returns the row offset of the first (key_idx, reward) pair.
func MASTRolloutRowWidth
func MASTRolloutRowWidth(players, maxPath int) int
MASTRolloutRowWidth returns the required state_width for an MASTRolloutIteration with the given player count and path bound.
func MASTRolloutScoresOffset
func MASTRolloutScoresOffset(i int) int
MASTRolloutScoresOffset returns the row offset of score slot i.
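The MASTRollout offset helpers above can be read against the row layout documented for MASTRolloutIteration (scores, ok flag, num_path, then key/reward pairs). The sketch below recomputes those offsets locally; the lower-case functions are hypothetical mirrors derived from that layout, not the package's code.

```go
package main

import "fmt"

// Hypothetical local mirrors of the MASTRollout* offset helpers, derived from
// the documented layout: [scores..., ok, num_path, (key_idx, reward)...].
func scoresOffset(i int) int            { return i }
func okOffset(players int) int          { return players }
func numPathOffset(players int) int     { return players + 1 }
func pathOffset(players int) int        { return players + 2 }
func rowWidth(players, maxPath int) int { return players + 2 + 2*maxPath }

func main() {
	players, maxPath := 2, 3
	row := make([]float64, rowWidth(players, maxPath))

	row[scoresOffset(0)], row[scoresOffset(1)] = 1, 0 // player 0 won
	row[okOffset(players)] = 1                        // scores are valid
	row[numPathOffset(players)] = 1                   // one (key, reward) pair
	row[pathOffset(players)], row[pathOffset(players)+1] = 5, 1

	// The suffix starting at num_path has the update-batch shape that the
	// aggregation partition consumes.
	batch := row[numPathOffset(players):]
	fmt.Println(len(row), batch)
}
```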
func MCTSRolloutRowWidth
func MCTSRolloutRowWidth(players int) int
MCTSRolloutRowWidth returns the required InitStateValues / StateWidth for a MCTSRolloutIteration with the given player count.
func MCTSTreeRowHasLeafOffset
func MCTSTreeRowHasLeafOffset(stateWidth int) int
func MCTSTreeRowVisitsOffset
func MCTSTreeRowVisitsOffset(stateWidth int) int
func MCTSTreeRowWidth
func MCTSTreeRowWidth(stateWidth, maxLegalActions int) int
MCTSTreeRowWidth returns the required InitStateValues / StateWidth for a MCTSTreeIteration with the given encoded-state width and max legal-action count.
func MCTSTreeRowWinsOffset
func MCTSTreeRowWinsOffset(stateWidth, maxLegalActions int) int
func TTTEncode
func TTTEncode(s TTTState) []float64
TTTEncode produces the []float64 row representation of a TTTState. Done/Winner are derivable from the cells so they are not encoded.
func TTTKey
func TTTKey(a TTTAction) string
TTTKey is a stable string key for an action, useful for any aggregation or lookup that needs a string-typed action identifier.
func WinnerToTerminal
func WinnerToTerminal(winner, players int, done bool) []float64
WinnerToTerminal builds an Environment.Terminal-compatible result from a binary winner. For envs whose native terminal predicate is “winner int, done bool” rather than per-player scores, embed this in your Environment implementation:
func (g *MyGame) Terminal(s State) ([]float64, bool) {
w, done := g.winnerOrDone(s)
return agents.WinnerToTerminal(w, g.Players(s), done), done
}
type ApplyIteration
ApplyIteration advances the environment by one ply per stochadex outer step using a best-action signal supplied either via params_from_upstream (within-step) OR via params_as_partitions (lagged read of an upstream partition’s state-history row). The partition row is the encoded current game state; one outer step decodes the row, applies the chosen legal action, and writes the encoded post-move state back.
Row layout (width = StateWidth):
row[0 .. StateWidth-1] encoded current game state.
Two read modes
Direct param mode (ApplyParamBestIdx): set ParamsFromUpstream[ApplyParamBestIdx] to read the best-action index directly within the same step. Used when the upstream partition is not in a self-referential cycle with apply.
State-history mode (ApplyParamBestIdxPartition + BestIdxSlot): set ParamsAsPartitions[ApplyParamBestIdxPartition] to the upstream partition’s name, and set BestIdxSlot to the offset of the best-action index within that partition’s row. Each step, apply reads row 0 of that partition’s state history, i.e. the PREVIOUS step’s output. Used to break the apply ↔︎ search dependency cycle in NewMCTSSelfPlayPartitions: search reads apply’s row within-step, while apply reads search’s row lagged by one step, so apply never waits on search inside a step. The 1-step lag aligns correctly because at step N apply applies the best_idx that search produced at step N-1 for apply’s state at step N-1, which is the same state apply currently holds (apply only advances when it applies a move).
State-history mode takes priority if both are configured.
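As a rough illustration of the two read modes and their priority, the sketch below resolves a best-action index and applies it to a toy one-slot state. Everything here is hypothetical (the counter state, the legal increments, and the choice to ignore an index past the end of the legal list); it is not how ApplyIteration is implemented, only a model of the documented behaviour.

```go
package main

import "fmt"

// resolveBestIdx mirrors the documented priority: the lagged state-history
// value wins whenever it is configured; otherwise the direct param is used.
func resolveBestIdx(direct float64, hasDirect bool, lagged float64, hasLagged bool) int {
	switch {
	case hasLagged:
		return int(lagged)
	case hasDirect:
		return int(direct)
	default:
		return -1 // nothing wired: behave like the sentinel
	}
}

// applyStep decodes a one-slot row, applies legal[bestIdx], and re-encodes.
// A negative index leaves the state unchanged, as documented; treating an
// index past the end of legal the same way is an assumption of this sketch.
func applyStep(row []float64, bestIdx int) []float64 {
	state := int(row[0])
	legal := []int{1, 2, 3} // hypothetical legal increments for the toy state
	if bestIdx < 0 || bestIdx >= len(legal) {
		return []float64{float64(state)}
	}
	return []float64{float64(state + legal[bestIdx])}
}

func main() {
	row := []float64{10}
	row = applyStep(row, resolveBestIdx(0, true, 2, true)) // lagged index 2 wins
	fmt.Println(row)
	row = applyStep(row, resolveBestIdx(-1, true, 0, false)) // -1 sentinel: no move
	fmt.Println(row)
}
```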
Warm fields (must be set before Configure):
- Env: the typed Environment[S, A] to advance.
- Decoder, Encoder: codec for the game state.
Optional warm field:
- BestIdxSlot: row offset of the best-action index within the upstream partition’s row when using state-history mode.
type ApplyIteration[S any, A any] struct {
Env Environment[S, A]
Decoder func([]float64) (S, error)
Encoder func(S) []float64
BestIdxSlot int
// contains filtered or unexported fields
}
func (*ApplyIteration[S, A]) Configure
func (m *ApplyIteration[S, A]) Configure(partitionIndex int, settings *simulator.Settings)
Configure implements simulator.Iteration.
func (*ApplyIteration[S, A]) Iterate
func (m *ApplyIteration[S, A]) Iterate(params *simulator.Params, partitionIndex int, stateHistories []*simulator.StateHistory, timestepsHistory *simulator.CumulativeTimestepsHistory) []float64
Iterate implements simulator.Iteration.
type Environment
Environment is the game/decision-process interface that MCTS searches over. Implementations are pure: Legal/Apply must not mutate s, and Apply must return a fresh value (the search clones aggressively).
Terminal returns the per-player [0,1] score vector and a done flag. The score vector lets the env represent draws (0.5/0.5), graded scoring (Catan-like), and standard winner-takes-all (one-hot) without faking a “winner” int. For binary games, see WinnerToTerminal in rollout.go.
Actor returns the index of the player whose decision creates the next edge from s; backups credit nodes by Actor at the parent. Players returns the total seat count and bounds the score vector length.
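To make the contract concrete, here is a small sketch of a type satisfying this interface. The interface is copied locally so the example compiles on its own; NimState and NimGame are hypothetical names for a toy two-player game (take 1 or 2 stones, whoever takes the last stone wins) and are not part of the package.

```go
package main

import (
	"errors"
	"fmt"
)

// Environment is a local copy of the documented interface.
type Environment[S any, A any] interface {
	Legal(s S) []A
	Apply(s S, a A) (S, error)
	Terminal(s S) (scores []float64, done bool)
	Actor(s S) int
	Players(s S) int
}

// NimState and NimGame are hypothetical illustration types.
type NimState struct {
	Stones int
	ToMove int
}

type NimGame struct{}

// Legal lists how many stones may be taken; nil once the pile is empty.
func (NimGame) Legal(s NimState) []int {
	var out []int
	for take := 1; take <= 2 && take <= s.Stones; take++ {
		out = append(out, take)
	}
	return out
}

// Apply returns a fresh state and never mutates its argument.
func (NimGame) Apply(s NimState, a int) (NimState, error) {
	if a < 1 || a > 2 || a > s.Stones {
		return s, errors.New("illegal move")
	}
	return NimState{Stones: s.Stones - a, ToMove: 1 - s.ToMove}, nil
}

// Terminal gives a one-hot score to the player who took the last stone.
func (NimGame) Terminal(s NimState) ([]float64, bool) {
	if s.Stones > 0 {
		return nil, false
	}
	scores := []float64{0, 0}
	scores[1-s.ToMove] = 1 // the previous mover emptied the pile
	return scores, true
}

func (NimGame) Actor(s NimState) int   { return s.ToMove }
func (NimGame) Players(s NimState) int { return 2 }

// Compile-time check that NimGame satisfies the interface.
var _ Environment[NimState, int] = NimGame{}

func main() {
	g := NimGame{}
	s, _ := g.Apply(NimState{Stones: 2}, 2)
	fmt.Println(g.Terminal(s))
}
```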
type Environment[S any, A any] interface {
Legal(s S) []A
Apply(s S, a A) (S, error)
Terminal(s S) (scores []float64, done bool)
Actor(s S) int
Players(s S) int
}
type MASTAggregationIteration
MASTAggregationIteration is a stochadex iteration that maintains running (count, sum) pairs per action-key index for MAST (Move-Average Sampling Technique). Each step it reads a variable-length update batch from an upstream rollout partition via params_from_upstream and applies the (key_idx, reward) increments to its row.
Row layout (width = 2 * MaxKeys):
row[2*k] count for key k
row[2*k+1] sum for key k
Use MASTAggregationRowWidth(K) to compute state_width.
Update batch format
The upstream batch is a single []float64 with the layout
[num_updates, key_idx_0, reward_0, key_idx_1, reward_1, ...]
where num_updates is the number of valid (key_idx, reward) pairs in the slice. MASTRolloutIteration emits exactly this layout in its row’s path-suffix slots; wire MASTAggregationIteration’s params_from_upstream (key MASTAggregationParamUpdates) to those slots.
Out-of-range key indices and updates beyond the slice’s declared num_updates are silently dropped.
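The batch format and drop rules above can be sketched as a pure function over the row. applyUpdates is a hypothetical helper written from the documented layout; the real iteration may differ in details such as how it handles a num_updates larger than the slice can hold.

```go
package main

import "fmt"

// applyUpdates is a hypothetical model of one aggregation step. It copies the
// row, then folds in the first num_updates (key_idx, reward) pairs of batch,
// skipping keys outside [0, maxKeys).
func applyUpdates(row, batch []float64, maxKeys int) []float64 {
	out := append([]float64(nil), row...)
	if len(batch) == 0 {
		return out
	}
	n := int(batch[0])
	for i := 0; i < n; i++ {
		p := 1 + 2*i
		if p+1 >= len(batch) {
			break // assumption: stop if fewer pairs are present than declared
		}
		k := int(batch[p])
		if k < 0 || k >= maxKeys {
			continue // out-of-range key: dropped
		}
		out[2*k]++                // count
		out[2*k+1] += batch[p+1] // sum
	}
	return out
}

func main() {
	row := make([]float64, 2*3) // MaxKeys = 3
	// Two declared updates: key 1 (kept) and key 7 (out of range, dropped).
	// The trailing pair lies beyond num_updates and is ignored.
	batch := []float64{2, 1, 1.0, 7, 0.5, 0, 1.0}
	fmt.Println(applyUpdates(row, batch, 3))
}
```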
Read access
Downstream samplers (e.g. MASTRolloutIteration) read the aggregates via state-history mode (lag-1) using params_as_partitions. See MASTAggregationParamPartition for the canonical key.
type MASTAggregationIteration[A any] struct {
MaxKeys int
// contains filtered or unexported fields
}
func (*MASTAggregationIteration[A]) Configure
func (m *MASTAggregationIteration[A]) Configure(partitionIndex int, settings *simulator.Settings)
Configure implements simulator.Iteration.
func (*MASTAggregationIteration[A]) Iterate
func (m *MASTAggregationIteration[A]) Iterate(params *simulator.Params, partitionIndex int, stateHistories []*simulator.StateHistory, timestepsHistory *simulator.CumulativeTimestepsHistory) []float64
Iterate implements simulator.Iteration.
type MASTRolloutIteration
MASTRolloutIteration is a stochadex iteration that runs one MAST-biased rollout per step. It reads the leaf state from an upstream MCTSTreeIteration (within-step), reads the running aggregates from a MASTAggregationIteration (lag-1, via params_as_partitions), runs a playout sampling each ply via softmax over the aggregates, and emits the per-player scores plus the (key_idx, reward) path that the MASTAggregationIteration will absorb on the next step.
Row layout (width = Players + 2 + 2 * MaxPath):
row[0 .. Players-1] per-player [0,1] scores
row[Players] ok flag (1 = scores valid)
row[Players+1] num_path (length of valid pairs)
row[Players+2 .. Players+1+2*MaxPath] (key_idx, reward) pairs,
padded with zeros
The (num_path, pairs…) suffix matches MASTAggregationParamUpdates so the downstream MASTAggregationIteration can read it as a single slice.
Warm fields (must be set before Configure):
- Env: the typed Environment[S, A] to roll out against.
- Cfg: must have RolloutMaxSteps; the rollout function in Cfg is ignored (this partition implements its own MAST-biased sampling).
- Decoder: decodes the leaf_state from the upstream-supplied slice.
- KeyToIdx: maps an action to a bounded index in [0, MaxKeys).
- MaxKeys: same K as the MASTAggregationIteration this is wired to.
- MaxPath: maximum number of (key, reward) updates emitted per rollout. Paths longer than this are truncated.
- Players: per-player score vector length.
- Tau: softmax temperature (MASTDefaultTau if <= 0).
- Progress: optional [0,1] per-player progress proxy used to score truncated rollouts.
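The softmax sampling over aggregates that this partition performs can be approximated as follows. This is a sketch under stated assumptions: weights proportional to exp(mean / tau), the 0.5 prior for unseen keys, and the fallback to a tau of 1.0 are taken from the documented constants, but the exact formula used by the package is not shown in this reference, and softmaxOverKeys is a hypothetical name.

```go
package main

import (
	"fmt"
	"math"
)

const (
	samplePrior = 0.5 // mirrors MASTSamplePrior
	defaultTau  = 1.0 // mirrors MASTDefaultTau
)

// softmaxOverKeys returns one probability per candidate key, assuming weights
// proportional to exp(mean/tau). row uses the (count, sum) aggregation layout.
func softmaxOverKeys(row []float64, keys []int, tau float64) []float64 {
	if tau <= 0 {
		tau = defaultTau
	}
	probs := make([]float64, len(keys))
	total := 0.0
	for i, k := range keys {
		mean := samplePrior // unseen keys keep a neutral score
		if k >= 0 && 2*k+1 < len(row) && row[2*k] > 0 {
			mean = row[2*k+1] / row[2*k]
		}
		probs[i] = math.Exp(mean / tau)
		total += probs[i]
	}
	for i := range probs {
		probs[i] /= total
	}
	return probs
}

func main() {
	row := []float64{10, 9, 0, 0} // key 0: mean 0.9; key 1: unseen
	fmt.Println(softmaxOverKeys(row, []int{0, 1}, 1.0))
	fmt.Println(softmaxOverKeys(row, []int{0, 1}, 0.1)) // sharper preference
}
```

Lowering tau concentrates probability on the key with the higher mean, which matches the "smaller is more exploitative" description of MASTDefaultTau.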
type MASTRolloutIteration[S any, A any] struct {
Env Environment[S, A]
Cfg MCTSConfig[S, A]
Decoder func([]float64) (S, error)
KeyToIdx func(A) int
MaxKeys int
MaxPath int
Players int
Tau float64
Progress func(s S, player int) (float64, bool)
// contains filtered or unexported fields
}
func (*MASTRolloutIteration[S, A]) Configure
func (m *MASTRolloutIteration[S, A]) Configure(partitionIndex int, settings *simulator.Settings)
Configure implements simulator.Iteration.
func (*MASTRolloutIteration[S, A]) Iterate
func (m *MASTRolloutIteration[S, A]) Iterate(params *simulator.Params, partitionIndex int, stateHistories []*simulator.StateHistory, timestepsHistory *simulator.CumulativeTimestepsHistory) []float64
Iterate implements simulator.Iteration.
type MCTSConfig
MCTSConfig holds UCT hyperparameters and the rollout driver.
Rollout is the single graded rollout signature: it returns a per-player score vector in [0,1] and an ok flag (false signals “no signal”: the caller counts the visit but skips the win credit). For binary winner games or games using a Progress proxy, build the rollout from helpers in rollout.go (UniformRandomRollout, WinnerToTerminal, FromProgress).
Progress is an optional per-state, per-player [0,1] value proxy used by the FromProgress rollout adapter to score truncated rollouts.
type MCTSConfig[S any, A any] struct {
Simulations int
Exploration float64
MaxTreeDepth int
RolloutMaxSteps int
Rollout MCTSRolloutFn[S, A]
Progress func(s S, player int) (float64, bool)
}
func (*MCTSConfig[S, A]) ApplyDefaults
func (c *MCTSConfig[S, A]) ApplyDefaults()
ApplyDefaults fills in zero-valued hyperparameters with the package defaults, mutating the receiver. Called once at the start of each search run; safe to call multiple times.
Exported so external packages building on MCTSConfig can share the defaults logic without duplicating it.
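The zero-value filling described for ApplyDefaults follows a common Go pattern, sketched below with a hypothetical local config type and the default values listed under Constants. This is an illustration of the pattern, not the package source.

```go
package main

import "fmt"

// Local mirrors of the documented MCTSDefault* constants.
const (
	defaultSimulations     = 120
	defaultRolloutMaxSteps = 220
	defaultMaxTreeDepth    = 14
	defaultExploration     = 1.41
)

// config is a hypothetical stand-in for the numeric fields of MCTSConfig.
type config struct {
	Simulations     int
	Exploration     float64
	MaxTreeDepth    int
	RolloutMaxSteps int
}

// applyDefaults fills only zero-valued fields, so calling it twice is a no-op.
func (c *config) applyDefaults() {
	if c.Simulations == 0 {
		c.Simulations = defaultSimulations
	}
	if c.Exploration == 0 {
		c.Exploration = defaultExploration
	}
	if c.MaxTreeDepth == 0 {
		c.MaxTreeDepth = defaultMaxTreeDepth
	}
	if c.RolloutMaxSteps == 0 {
		c.RolloutMaxSteps = defaultRolloutMaxSteps
	}
}

func main() {
	c := config{Simulations: 500} // caller overrides one field only
	c.applyDefaults()
	c.applyDefaults() // idempotent
	fmt.Printf("%+v\n", c)
}
```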
type MCTSEdgeStat
MCTSEdgeStat is per-action telemetry exposed at the root after a search. Useful for JSON reports or logging the search’s distribution over moves.
type MCTSEdgeStat[A any] struct {
Action A `json:"action"`
Visits int `json:"visits"`
MeanForActor float64 `json:"mean_for_actor"`
}
func RunMCTSSearch
func RunMCTSSearch[S any, A any](env Environment[S, A], root S, cfg MCTSConfig[S, A], baseSeed uint64, sims int) (A, []MCTSEdgeStat[A], error)
RunMCTSSearch runs sims UCT simulations from root and returns the best legal action plus per-edge stats. It is independent of stochadex partitions, which makes it useful for one-shot “what’s the best move?” queries.
Defaults are filled in from the package constants. If cfg.Rollout is nil, UniformRandomRollout is used.
type MCTSRolloutFn
MCTSRolloutFn is the single rollout signature used by the search. It runs a stochastic playout from s for at most maxSteps actions and returns:
- scores: per-player [0,1] score vector (length Players(s)).
- ok: false signals “no signal” — the caller counts the simulation as a visit but does not credit any win, so UCB exploration still progresses.
- err: a hard failure (e.g. Apply rejected what it claimed was legal); the caller treats this like ok=false.
The seed argument is the rollout’s full RNG seed; implementations should be deterministic given the same seed.
type MCTSRolloutFn[S any, A any] func(env Environment[S, A], s S, maxSteps int, seed uint64) (scores []float64, ok bool, err error)
func FromProgress
func FromProgress[S any, A any](inner MCTSRolloutFn[S, A], progress func(s S, player int) (float64, bool)) MCTSRolloutFn[S, A]
FromProgress wraps an inner rollout so that truncated rollouts (ok=false from inner) are rescued by scoring the final state via the supplied progress function. progress returns a per-player [0,1] proxy of “how close is this player to winning” and an ok flag (false = no proxy available for that player; treated as 0).
All-equal progress vectors carry no comparative signal and are treated as no signal (ok=false from the wrapper) so the search relies on UCB exploration alone — see the docstring on MCTSTree.backupVisits for the stall-move failure mode this avoids.
FromProgress needs to know the final state of the inner rollout, so it re-runs the playout itself rather than wrapping inner. The inner rollout is therefore only consulted for early termination.
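The scoring rule described for truncated rollouts (missing proxies count as 0, all-equal vectors carry no signal) might look like the sketch below. progressScores is a hypothetical helper operating on a final state; it does not reproduce how FromProgress replays the playout.

```go
package main

import "fmt"

// progressScores builds a per-player vector from a progress proxy. A player
// with no proxy (ok=false) is scored 0. If every entry is equal the vector
// carries no comparative signal, so the helper reports ok=false.
func progressScores[S any](s S, players int, progress func(S, int) (float64, bool)) ([]float64, bool) {
	scores := make([]float64, players)
	allEqual := true
	for p := 0; p < players; p++ {
		if v, ok := progress(s, p); ok {
			scores[p] = v
		}
		if p > 0 && scores[p] != scores[0] {
			allEqual = false
		}
	}
	if allEqual {
		return nil, false
	}
	return scores, true
}

func main() {
	// Hypothetical state: the slice itself holds each player's progress.
	progress := func(s []float64, p int) (float64, bool) { return s[p], true }

	fmt.Println(progressScores([]float64{0.2, 0.7}, 2, progress))
	fmt.Println(progressScores([]float64{0.4, 0.4}, 2, progress))
}
```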
func UniformRandomRollout
func UniformRandomRollout[S any, A any]() MCTSRolloutFn[S, A]
UniformRandomRollout returns a MCTSRolloutFn that plays uniformly random legal actions until either Terminal returns done or maxSteps is reached. On termination the env’s Terminal scores are returned. On truncation (maxSteps reached without termination) the rollout returns ok=false; compose with FromProgress if you have a progress proxy to score truncated rollouts.
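A uniform random playout with the behaviour described above can be sketched in a few lines. The interface and function type are copied locally; uniformRollout, nimState and nim are hypothetical names, and seeding math/rand from the uint64 seed is an assumption about how determinism might be achieved.

```go
package main

import (
	"errors"
	"fmt"
	"math/rand"
)

type Environment[S any, A any] interface {
	Legal(s S) []A
	Apply(s S, a A) (S, error)
	Terminal(s S) (scores []float64, done bool)
	Actor(s S) int
	Players(s S) int
}

type RolloutFn[S any, A any] func(env Environment[S, A], s S, maxSteps int, seed uint64) ([]float64, bool, error)

// uniformRollout plays random legal actions until done or maxSteps is hit.
func uniformRollout[S any, A any]() RolloutFn[S, A] {
	return func(env Environment[S, A], s S, maxSteps int, seed uint64) ([]float64, bool, error) {
		rng := rand.New(rand.NewSource(int64(seed)))
		for step := 0; ; step++ {
			if scores, done := env.Terminal(s); done {
				return scores, true, nil
			}
			legal := env.Legal(s)
			if step >= maxSteps || len(legal) == 0 {
				return nil, false, nil // truncated: no signal
			}
			next, err := env.Apply(s, legal[rng.Intn(len(legal))])
			if err != nil {
				return nil, false, err
			}
			s = next
		}
	}
}

// nim is a toy game: take 1 or 2 stones; taking the last stone wins.
type nimState struct{ Stones, ToMove int }
type nim struct{}

func (nim) Legal(s nimState) []int {
	var out []int
	for t := 1; t <= 2 && t <= s.Stones; t++ {
		out = append(out, t)
	}
	return out
}
func (nim) Apply(s nimState, a int) (nimState, error) {
	if a < 1 || a > 2 || a > s.Stones {
		return s, errors.New("illegal move")
	}
	return nimState{s.Stones - a, 1 - s.ToMove}, nil
}
func (nim) Terminal(s nimState) ([]float64, bool) {
	if s.Stones > 0 {
		return nil, false
	}
	scores := []float64{0, 0}
	scores[1-s.ToMove] = 1
	return scores, true
}
func (nim) Actor(s nimState) int   { return s.ToMove }
func (nim) Players(s nimState) int { return 2 }

func main() {
	roll := uniformRollout[nimState, int]()
	fmt.Println(roll(nim{}, nimState{Stones: 5}, 100, 42))
	fmt.Println(roll(nim{}, nimState{Stones: 5}, 1, 42)) // truncated
}
```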
type MCTSRolloutIteration
MCTSRolloutIteration runs one rollout per stochadex step. It reads the leaf state to roll out from via params_from_upstream (typically wired to a MCTSTreeIteration’s leaf_state slot via Indices) and outputs a per-player score vector plus an ok flag in its own row.
Row layout (width = Players + 1):
row[0 .. Players-1] per-player [0,1] scores from the rollout
row[Players] ok flag (1 if the rollout produced valid scores,
0 otherwise — the upstream MCTSTreeIteration uses
this to decide whether to apply backupScores or
backupVisits with no signal)
Use MCTSRolloutRowWidth(P) to size InitStateValues / state_width.
Stateless across steps: each Iterate is one independent rollout. Swap in UniformRandomRollout, a FromProgress-wrapped rollout, or any custom MCTSRolloutFn via Cfg.Rollout without touching the partition wiring.
Warm fields (must be set before Configure):
- Env: the typed Environment[S, A] to roll out against.
- Cfg: must have Rollout set (the rollout function to invoke).
- Decoder: decodes the leaf_state from the upstream-supplied []float64 back into the typed S.
- Players: P, the per-player score vector length.
type MCTSRolloutIteration[S any, A any] struct {
Env Environment[S, A]
Cfg MCTSConfig[S, A]
Decoder func([]float64) (S, error)
Players int
// contains filtered or unexported fields
}
func (*MCTSRolloutIteration[S, A]) Configure
func (m *MCTSRolloutIteration[S, A]) Configure(partitionIndex int, settings *simulator.Settings)
Configure implements simulator.Iteration.
func (*MCTSRolloutIteration[S, A]) Iterate
func (m *MCTSRolloutIteration[S, A]) Iterate(params *simulator.Params, partitionIndex int, stateHistories []*simulator.StateHistory, timestepsHistory *simulator.CumulativeTimestepsHistory) []float64
Iterate implements simulator.Iteration.
type MCTSTree
MCTSTree is the in-memory UCT search tree for a fixed root state. Methods are not safe for concurrent use; one MCTSTree per goroutine.
The MCTSTree owns no environment or config — those are passed in to RunOne per-simulation. This makes MCTSTree easy to embed in iterations that may want to tweak config between simulations (e.g. adaptive exploration).
type MCTSTree[S any, A any] struct {
// contains filtered or unexported fields
}
func NewMCTSTree
func NewMCTSTree[S any, A any](root S) *MCTSTree[S, A]
NewMCTSTree returns a MCTSTree with a single root node containing root.
func (*MCTSTree[S, A]) AdvanceRoot
func (t *MCTSTree[S, A]) AdvanceRoot(env Environment[S, A], legalIdx int)
AdvanceRoot promotes the root’s child at the given legal index to be the new root, preserving its subtree (classic MCTS tree reuse). If that child was never expanded, the tree is rebuilt fresh from the resulting state.
env is needed to compute the post-move state if the subtree is missing.
func (*MCTSTree[S, A]) BackupScores
func (t *MCTSTree[S, A]) BackupScores(path []int, scores []float64)
BackupScores credits each node along path with the score belonging to its actor (visits and wins both increment). Exported wrapper around the internal backupScores so iterations split across selection / rollout / backup partitions can apply the scores when they arrive.
func (*MCTSTree[S, A]) BackupVisits
func (t *MCTSTree[S, A]) BackupVisits(path []int, scores []float64)
BackupVisits is the no-signal-tolerant variant: visits always increment, but wins are only credited when scores is non-nil. See backupVisits docs for the engine-stall reasoning.
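The difference between the two backup variants can be modelled on a flat node slice. The node type and backup function below are hypothetical; they only illustrate "visits always increment, wins are credited by actor when scores are present", not the tree's internal representation.

```go
package main

import "fmt"

// node is a hypothetical tree node. actor is the player who chose the edge
// leading into this node (the Actor at the parent state).
type node struct {
	actor  int
	visits int
	wins   float64
}

// backup walks path and increments visits on every node. Wins are credited
// from the actor's score only when scores is non-nil, which models the
// no-signal-tolerant behaviour described for BackupVisits.
func backup(nodes []node, path []int, scores []float64) {
	for _, idx := range path {
		n := &nodes[idx]
		n.visits++
		if scores != nil && n.actor >= 0 && n.actor < len(scores) {
			n.wins += scores[n.actor]
		}
	}
}

func main() {
	nodes := []node{{actor: 0}, {actor: 1}}
	backup(nodes, []int{0, 1}, []float64{1, 0}) // player 0 won this rollout
	backup(nodes, []int{0}, nil)                // no signal: visit only
	fmt.Printf("%+v\n", nodes)
}
```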
func (*MCTSTree[S, A]) NodeCount
func (t *MCTSTree[S, A]) NodeCount() int
NodeCount returns the number of nodes currently in the tree (including the root). Useful for telemetry and capacity tuning.
func (*MCTSTree[S, A]) Reset
func (t *MCTSTree[S, A]) Reset(root S)
Reset replaces the entire tree with a fresh root at the given state. Use when no usable subtree exists (opening move, or after an external change to the root).
func (*MCTSTree[S, A]) Root
func (t *MCTSTree[S, A]) Root() S
Root returns the root state.
func (*MCTSTree[S, A]) RootBestLegalIdx
func (t *MCTSTree[S, A]) RootBestLegalIdx() (int, bool)
RootBestLegalIdx returns the most-visited (then most-winning) child legal index. Ties are broken via reservoir sampling over equally-good children so the choice is not biased toward the first-listed action. This matters for engine-heavy games where the first listed legal action is often a stall (recycle, pass) and a deterministic first-tie pick would deadlock the agent. Reservoir randomness is seeded from the current tree shape so the result is reproducible without taking an external rng.
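Reservoir sampling over ties can be done in a single pass, as the sketch below shows. The edge type, bestIdx and countPicks are hypothetical, and the rng is passed in explicitly here; the documented method instead derives its randomness from the tree shape, which this sketch does not attempt to reproduce.

```go
package main

import (
	"fmt"
	"math/rand"
)

// edge is a hypothetical per-child summary at the root.
type edge struct {
	visits int
	wins   float64
}

// bestIdx picks the most-visited, then most-winning edge. Among exact ties
// the k-th tied candidate replaces the current pick with probability 1/k,
// which leaves every tied candidate equally likely after one pass.
func bestIdx(edges []edge, rng *rand.Rand) (int, bool) {
	best, ties := -1, 0
	for i, e := range edges {
		switch {
		case best < 0 || e.visits > edges[best].visits ||
			(e.visits == edges[best].visits && e.wins > edges[best].wins):
			best, ties = i, 1
		case e.visits == edges[best].visits && e.wins == edges[best].wins:
			ties++
			if rng.Intn(ties) == 0 {
				best = i
			}
		}
	}
	return best, best >= 0
}

// countPicks tallies which of two identical edges is chosen over many trials.
func countPicks(trials int, seed int64) [2]int {
	rng := rand.New(rand.NewSource(seed))
	tied := []edge{{visits: 5, wins: 2}, {visits: 5, wins: 2}}
	var counts [2]int
	for i := 0; i < trials; i++ {
		idx, _ := bestIdx(tied, rng)
		counts[idx]++
	}
	return counts
}

func main() {
	clear := []edge{{3, 1}, {5, 2}, {5, 3}, {1, 0}}
	fmt.Println(bestIdx(clear, rand.New(rand.NewSource(1)))) // no tie to break
	fmt.Println(countPicks(200, 1))                          // both indices appear
}
```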
func (*MCTSTree[S, A]) RootEdgeStats
func (t *MCTSTree[S, A]) RootEdgeStats(legal []A) []MCTSEdgeStat[A]
RootEdgeStats reports per-action visit counts and mean-for-actor for each expanded child of the root. legal must be the same slice ordering used by the env’s Legal(root); pass env.Legal(tree.Root()) at the call site.
func (*MCTSTree[S, A]) RootStatsByLegalIdx
func (t *MCTSTree[S, A]) RootStatsByLegalIdx(maxLegalActions int) (visits, wins []float64)
RootStatsByLegalIdx returns per-legal-action visit counts and win sums at the root, padded with zeros up to maxLegalActions. Returns (visits, wins) each of length maxLegalActions. Used to expose root statistics in fixed-width row layouts.
func (*MCTSTree[S, A]) RunOne
func (t *MCTSTree[S, A]) RunOne(env Environment[S, A], cfg *MCTSConfig[S, A], rng *rand.Rand)
RunOne does one UCT iteration: selection → expansion → rollout → backup. rng must be seeded by the caller; one call uses one RNG.
Calls cfg.ApplyDefaults() so a fresh MCTSConfig with only Rollout set works out of the box. The mutation is idempotent (only zero values are filled).
RunOne is the all-in-one path used by RunMCTSSearch and by callers who don’t need the selection / rollout / backup phases as separate stochadex partitions. For the decomposed pipeline use SelectLeaf + BackupScores.
func (*MCTSTree[S, A]) SelectLeaf
func (t *MCTSTree[S, A]) SelectLeaf(env Environment[S, A], cfg *MCTSConfig[S, A], rng *rand.Rand) (path []int, leafState S, leafIdx int, ok bool)
SelectLeaf walks the tree from the root using UCB1 (with first-visit preference for unvisited children) until it reaches an unexpanded edge, then expands it by creating a new child node. Returns the path of node indices from the root’s child down to the new leaf, the leaf’s state, the leaf’s node index, and ok=true. Returns ok=false if the root is terminal, has no legal moves, MaxTreeDepth is reached, or env.Apply fails during expansion (the caller should treat this as “no leaf selected this step”).
SelectLeaf does NOT roll out and does NOT back up — it is the (selection + expansion) half of one MCTS iteration. Pair it with MCTSTree.BackupScores or MCTSTree.BackupVisits to apply the scores when they arrive.
Calls cfg.ApplyDefaults() so a fresh MCTSConfig works out of the box. The mutation is idempotent (only zero values are filled).
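For readers unfamiliar with the selection rule, the sketch below shows textbook UCB1 with first-visit preference. It is an assumption that the package uses exactly this form (mean plus c times the square root of ln(parent visits) over child visits); ucb1 and selectChild are hypothetical names.

```go
package main

import (
	"fmt"
	"math"
)

// ucb1 is the textbook UCB1 score. Unvisited children score +Inf, which
// gives them first-visit preference.
func ucb1(wins float64, visits, parentVisits int, c float64) float64 {
	if visits == 0 {
		return math.Inf(1)
	}
	mean := wins / float64(visits)
	return mean + c*math.Sqrt(math.Log(float64(parentVisits))/float64(visits))
}

// selectChild returns the index with the highest UCB1 score. The parent visit
// count is taken as the sum of the children's visits (a simplification).
func selectChild(wins []float64, visits []int, c float64) int {
	parent := 0
	for _, v := range visits {
		parent += v
	}
	best, bestScore := -1, math.Inf(-1)
	for i := range visits {
		if s := ucb1(wins[i], visits[i], parent, c); s > bestScore {
			best, bestScore = i, s
		}
	}
	return best
}

func main() {
	// Child 1 has never been visited, so it is selected first.
	fmt.Println(selectChild([]float64{3, 0, 1}, []int{5, 0, 2}, 1.41))
	// Equal visit counts: the higher mean wins.
	fmt.Println(selectChild([]float64{9, 1}, []int{10, 10}, 1.41))
}
```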
type MCTSTreeIteration
MCTSTreeIteration runs the (selection + expansion + backup) phase of a UCT MCTS search as a stochadex iteration. The tree itself lives on the struct (graph state, fundamentally not []float64-shaped); the partition row exposes a fixed-width summary that downstream partitions can consume via params_from_upstream:
row[MCTSTreeRowBestRootIdx] — most-visited root
legal-action index
after the most recent
update (-1 if not
yet decided)
row[MCTSTreeRowLeafStateOffset .. +StateWidth-1] — encoded state of the
leaf the search just
selected (input to a
rollout partition)
row[MCTSTreeRowHasLeafOffset(W)] — 1 if the leaf_state
slot is real, 0
otherwise (used by
rollout partitions to
short-circuit when
nothing has been
selected yet)
row[MCTSTreeRowVisitsOffset(W) .. +MaxLegalActions-1] — per-legal-action
root visit counts,
padded with zeros
row[MCTSTreeRowWinsOffset(W,K) .. +MaxLegalActions-1] — per-legal-action
root win sums,
padded with zeros
Use MCTSTreeRowWidth(W, K) to compute the required state_width / init slice length.
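Reading the layout above as contiguous slots in the listed order gives the offsets sketched below. This ordering is an assumption inferred from the row description and the helper signatures; the lower-case functions are hypothetical mirrors of the MCTSTreeRow* helpers. leafIndices shows the kind of index list one might pass as Indices when slicing leaf_state + has_leaf for a rollout partition.

```go
package main

import "fmt"

const (
	bestRootIdx     = 0 // mirrors MCTSTreeRowBestRootIdx
	leafStateOffset = 1 // mirrors MCTSTreeRowLeafStateOffset
)

// Assumed contiguous layout: best idx, leaf state (W), has_leaf, visits (K), wins (K).
func hasLeafOffset(w int) int { return leafStateOffset + w }
func visitsOffset(w int) int  { return hasLeafOffset(w) + 1 }
func winsOffset(w, k int) int { return visitsOffset(w) + k }
func rowWidth(w, k int) int   { return winsOffset(w, k) + k }

// leafIndices lists the row indices covering leaf_state followed by has_leaf.
func leafIndices(w int) []int {
	idx := make([]int, 0, w+1)
	for i := leafStateOffset; i <= hasLeafOffset(w); i++ {
		idx = append(idx, i)
	}
	return idx
}

func main() {
	w, k := 10, 9 // tic-tac-toe: TTTWidth is 10, at most 9 legal moves
	fmt.Println(hasLeafOffset(w), visitsOffset(w), winsOffset(w, k), rowWidth(w, k))
	fmt.Println(leafIndices(w))
}
```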
Pipeline lag
MCTSTreeIteration is one half of a 2-step pipeline with a downstream rollout partition. The rollout partition reads (leaf_state, has_leaf) and outputs scores; MCTSTreeIteration then reads those scores via params_from_upstream (key MCTSTreeParamRolloutScores) and applies a backup to the path it selected two steps earlier. The 2-step lag is fundamental to expressing selection-then-backup as stochadex’s single-row dataflow.
In steady state each outer step does one selection + one backup, so the throughput is one MCTS iteration per stochadex step (after a 2-step fill).
Warm fields (must be set before Configure):
- Env: the typed Environment[S, A] to search.
- Cfg: UCT hyperparameters and the rollout function (the rollout fn here is only used by MCTSTree.RunOne when no upstream rollout is wired — left nil if a separate MCTSRolloutIteration supplies the scores).
- Decoder, Encoder: codec for the leaf state slots.
- MaxLegalActions: K, the maximum legal-action count at any node. Stats slots beyond the actual legal count are zero-padded.
- StateWidth: W, the width of one encoded state.
- Players: P, the per-player score vector length.
type MCTSTreeIteration[S any, A any] struct {
Env Environment[S, A]
Cfg MCTSConfig[S, A]
Decoder func([]float64) (S, error)
Encoder func(S) []float64
MaxLegalActions int
StateWidth int
Players int
// contains filtered or unexported fields
}
func (*MCTSTreeIteration[S, A]) Configure
func (m *MCTSTreeIteration[S, A]) Configure(partitionIndex int, settings *simulator.Settings)
Configure implements simulator.Iteration. Decodes the encoded root from is.InitStateValues[1 .. 1+StateWidth] and resets the tree. The first slot (MCTSTreeRowBestRootIdx) and the stats slots can be left zero in the init: best_root_idx is initialised to -1 to signal “not yet decided”.
func (*MCTSTreeIteration[S, A]) Iterate
func (m *MCTSTreeIteration[S, A]) Iterate(params *simulator.Params, partitionIndex int, stateHistories []*simulator.StateHistory, timestepsHistory *simulator.CumulativeTimestepsHistory) []float64
Iterate implements simulator.Iteration.
func (*MCTSTreeIteration[S, A]) MCTSTree
func (m *MCTSTreeIteration[S, A]) MCTSTree() *MCTSTree[S, A]
MCTSTree exposes the underlying search tree (typically for telemetry).
type TTTAction
TTTAction is a cell index 0..8.
type TTTAction int
type TTTGame
TTTGame implements Environment[TTTState, TTTAction]. Tic-tac-toe is the canonical fixture for testing this package and downstream consumers (e.g. pkg/analysis): small enough to be obvious, large enough that random play loses, and deterministic at endgame: from a “win in one” position MCTS must pick the winning move; from a “block in one” position it must block.
type TTTGame struct{}
func (*TTTGame) Actor
func (g *TTTGame) Actor(s TTTState) int
func (*TTTGame) Apply
func (g *TTTGame) Apply(s TTTState, a TTTAction) (TTTState, error)
Apply places the current player’s mark on cell a and updates Done/Winner.
func (*TTTGame) Legal
func (g *TTTGame) Legal(s TTTState) []TTTAction
Legal returns the empty cells; nil if the game is over.
func (*TTTGame) Players
func (g *TTTGame) Players(s TTTState) int
func (*TTTGame) Terminal
func (g *TTTGame) Terminal(s TTTState) ([]float64, bool)
Terminal returns the per-player [0,1] score vector and a done flag. Draw splits 0.5/0.5; a winner takes 1, the other 0.
type TTTState
TTTState is a tic-tac-toe position. Cells holds 0=empty, 1=X, 2=O.
type TTTState struct {
Cells [9]int8
Current int // 0=X to move, 1=O to move
Done bool
Winner int // -1=draw or in-progress, 0=X won, 1=O won
}
func TTTDecode
func TTTDecode(v []float64) (TTTState, error)
TTTDecode rebuilds a TTTState from its row representation, recomputing Done/Winner so callers can spell positions declaratively.
func TTTFromGrid
func TTTFromGrid(grid [9]int8, currentPlayer int) TTTState
TTTFromGrid builds a TTTState from a literal cell grid plus the player to move. Done/Winner are derived. Useful for spelling test positions inline.
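Because Done/Winner are derived from the cells, a grid alone is enough to reconstruct them. The sketch below shows one way that derivation and the 10-slot encoding could work, using a local copy of WinLines. winnerOf and encode are hypothetical helpers, and placing the nine cells before the current player in the row is an assumption based on the TTTWidth description.

```go
package main

import "fmt"

// winLines is a local copy of the documented WinLines variable.
var winLines = [...][3]int{
	{0, 1, 2}, {3, 4, 5}, {6, 7, 8},
	{0, 3, 6}, {1, 4, 7}, {2, 5, 8},
	{0, 4, 8}, {2, 4, 6},
}

// winnerOf returns 0 if X (mark 1) has a line, 1 if O (mark 2) has one, and
// -1 otherwise. done is true on a win or when the board is full.
func winnerOf(cells [9]int8) (winner int, done bool) {
	for _, l := range winLines {
		m := cells[l[0]]
		if m != 0 && m == cells[l[1]] && m == cells[l[2]] {
			return int(m) - 1, true
		}
	}
	for _, c := range cells {
		if c == 0 {
			return -1, false // still in progress
		}
	}
	return -1, true // full board, no line: draw
}

// encode flattens a position into 9 cells followed by the player to move.
func encode(cells [9]int8, current int) []float64 {
	row := make([]float64, 0, 10)
	for _, c := range cells {
		row = append(row, float64(c))
	}
	return append(row, float64(current))
}

func main() {
	xWins := [9]int8{1, 1, 1, 2, 2, 0, 0, 0, 0}
	fmt.Println(winnerOf(xWins))
	fmt.Println(encode(xWins, 1))
}
```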
Generated by gomarkdoc