Official implementation of the paper Reasoning with Memory: Adaptive Information Management for Retrieval-Augmented Generation.
Hieu Man, Ro-ee Tal, Abhishek Kumar, JJ Cho, Benjamin Hsu. University of Oregon, AWS AI.
State-Aware RAG is a modular multi-hop Retrieval-Augmented Generation (RAG) framework featuring two complementary search paradigms:
- MCTS (Monte Carlo Tree Search) planner for exploratory, branching multi-step reasoning
- CoT (Chain-of-Thought) linear planner for lightweight sequential reasoning
The system performs iterative (decompose → retrieve → extract → evaluate → synthesize) cycles while maintaining an explicit evolving reasoning state ("memory"). Trees and intermediate artifacts can be persisted, re-loaded, and post-analyzed.
Core design goals: transparency, reproducibility, pluggability (LLMs & retrievers), and scalability (multi-processing + concurrency + caching).
- Features
- Architecture Overview
- Quick Start
- Environment
- Minimal Run (Single Question)
- Dataset Inference
- Configuration System
- Components & Roles
- Search Modes (MCTS vs CoT)
- Caching & Performance Tips
- Evaluation & Metrics
- Directory Structure
- Extending the Framework
- Troubleshooting
- Roadmap / Ideas
- License & Citation
- Multi-hop reasoning with explicit search state (reasoning nodes + memory)
- Two planners: MCTS (branching) and CoT (linear)
- Modular role agents (Generator / Retriever / Extractor / Evaluator)
- Pluggable LLM backends via OpenAI-compatible or LiteLLM interface
- Online (HTTP API) and optional offline (FlashRAG) retrieval
- Structured output extraction with fallback text parsing
- Concurrency + multi-processing + deterministic caching
- Reasoning tree export & reload (resumable search)
- Evaluation suite: F1, Exact Match, Sub EM, Retrieval Recall, LLM Judge
- Hydra + YAML config for reproducible experiment control
High-level loop per question:
- Generator proposes sub-questions, answers, synthesis attempts, rephrasings, or self-corrections using prompt templates under state_aware_rag/agents/prompts/.
- Retriever fetches context passages (HTTP retriever API and/or local index).
- Extractor pulls fine-grained answer spans / structured fields from retrieved context.
- Evaluator scores candidate answers, ranks or synthesizes final answers, and may perform majority vote or reasoning synthesis.
- Planner (MCTS or CoT) orchestrates iterative expansion until termination conditions (depth, rollouts, convergence).
- Final answer + reasoning chain saved to a Hugging Face dataset directory under results/, with optional result_trees/ JSONL.
Key files:
- inference.py – main entry (single question or dataset map)
- evaluate.py – metric computation and judged scoring
- state_aware_rag/agents/ – agent abstractions & role implementations
- state_aware_rag/planners/ – search logic (MCTS and CoT)
- configs/infer/ – Hydra configs for experiments
- cache/ – per-model LLM call cache
Each node logs: node type, memory snapshot, confidence, content (sub-question, answer, synthesis, final answer), plus lineage. Trees can be reloaded to continue or analyze.
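To make the node record concrete, here is a minimal sketch of writing a tree to JSONL and reading it back. The NodeRecord class, its field names, and the write_tree_jsonl / read_tree_jsonl / lineage helpers are hypothetical illustrations of the fields listed above (node type, memory snapshot, confidence, content, lineage via a parent id). They are not the framework's actual ReasoningNode schema.

```python
import json
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import Optional


@dataclass
class NodeRecord:
    """Hypothetical flat record for one reasoning-tree node."""
    node_id: int
    parent_id: Optional[int]   # lineage link; None for the root question
    node_type: str             # e.g. "USER_QUESTION", "SUB_QA", "FINAL_ANSWER"
    content: str               # sub-question, answer, synthesis, ...
    confidence: float
    memory: list[str] = field(default_factory=list)  # memory snapshot at this node


def write_tree_jsonl(nodes: list[NodeRecord], path: Path) -> None:
    # JSONL: one JSON object per line, one line per node.
    with path.open("w", encoding="utf-8") as f:
        for node in nodes:
            f.write(json.dumps(asdict(node), ensure_ascii=False) + "\n")


def read_tree_jsonl(path: Path) -> list[NodeRecord]:
    with path.open(encoding="utf-8") as f:
        return [NodeRecord(**json.loads(line)) for line in f if line.strip()]


def lineage(nodes: list[NodeRecord], node_id: int) -> list[int]:
    """Follow parent links from node_id back to the root; return root-first ids."""
    by_id = {n.node_id: n for n in nodes}
    chain: list[int] = []
    current: Optional[int] = node_id
    while current is not None:
        chain.append(current)
        current = by_id[current].parent_id
    return chain[::-1]
```

Because each line is an independent JSON object, a partially written tree is still readable up to the last complete line. That property is convenient for resumable search.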
Python >= 3.10
Install (editable):
pip install -e .
You will need:
- An OpenAI-compatible endpoint (e.g. local server or provider) OR models accessible via LiteLLM routing
- A retriever service URL (see scripts/deploy_retriever_server.sh) or FlashRAG index
- API keys exported (e.g. OPENAI_API_KEY) if required by your LLM backend
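Every run depends on these services, so a quick TCP check before launching can save a failed job. This sketch assumes the default ports described below for configs/infer/base.yaml: 30000 for the LLM, 8000 for embeddings, and 5000 for the retriever. is_reachable is a hypothetical helper. It only confirms that something is listening on each port. It does not confirm that the right model or index is loaded.

```python
import socket


def is_reachable(host: str, port: int, timeout: float = 2.0) -> bool:
    """Return True if a TCP connection to host:port succeeds within timeout."""
    try:
        with socket.create_connection((host, port), timeout=timeout):
            return True
    except OSError:  # refused, timed out, unresolvable host, ...
        return False


# Assumed defaults, matching the endpoints this README lists for base.yaml.
SERVICES = {
    "llm": ("localhost", 30000),
    "embedding": ("localhost", 8000),
    "retriever": ("localhost", 5000),
}

if __name__ == "__main__":
    for name, (host, port) in SERVICES.items():
        status = "up" if is_reachable(host, port) else "DOWN"
        print(f"{name:10s} {host}:{port} {status}")
```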
Optional (FlashRAG offline retrieval):
pip install flashrag
Important: State-Aware RAG does not run standalone. Before any inference or evaluation command below, you must have three services reachable (the default endpoints are configured in configs/infer/base.yaml):
- an LLM server (OpenAI-compatible) on localhost:30000, used by the extractor/evaluator-metric models;
- an embedding server on localhost:8000, used by the generator/evaluator and the retriever's encoder;
- the retriever server on localhost:5000.
These require GPUs and model weights. See docs/server_deployment.md for the full guide (LLM + embedding + retriever + evaluation). Minimal local stack:
# Start LLM
python -m sglang.launch_server --model-path Qwen/Qwen3-8B --host 0.0.0.0 --port 30000 --dtype bfloat16 &
# Start embedding server
python -m sglang.launch_server --model-path Qwen/Qwen3-Embedding-4B --is-embedding --host 0.0.0.0 --port 8000 --dtype bfloat16 &
# Start retriever
python -m state_aware_rag.servers.retriever --config configs/servers/retriever-Qwen3-4B-wiki-23.yaml --port 5000 --workers 2 --mmap_index &
All agent settings (models, endpoints, generation params) are defined inline in configs/infer/base.yaml, so a single-question run needs no extra overrides:
python -m inference \
mode=mcts \
question="Who founded the company that created the iPhone?"
Hydra merges any overrides you pass: mode=cot switches the planner, and search.save_tree=true persists the reasoning tree. To point an agent at a different model without editing base.yaml, either override a single field (e.g. agents.generator.client_kwargs.model_name=...) or swap in a standalone agent YAML, e.g. agents.generator=configs/generator.yaml (sample per-agent configs are provided under configs/).
Example (2Wiki dev subset, MCTS):
python -m inference \
mode=mcts \
data.name=2wiki \
data.limit=32 \
num_proc=4 \
search.max_depth=6 \
search.num_rollouts=8 \
search.save_tree=true
Results saved under:
results/mcts/Generator_<model>/Extractor_<model>/Evaluator_<model>/<dataset>/
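Scripts that post-process runs can rebuild that path from a run's settings. results_path is a hypothetical helper that inserts model names verbatim. This README does not document how the framework sanitizes names containing / (e.g. Qwen/Qwen3-8B).

```python
from pathlib import Path


def results_path(results_dir: str, mode: str, generator: str, extractor: str,
                 evaluator: str, dataset: str) -> Path:
    """Build <results_dir>/<mode>/Generator_<m>/Extractor_<m>/Evaluator_<m>/<dataset>."""
    return (Path(results_dir) / mode
            / f"Generator_{generator}"
            / f"Extractor_{extractor}"
            / f"Evaluator_{evaluator}"
            / dataset)
```

The results are saved as a Hugging Face dataset directory, so the returned path can be opened with datasets.load_from_disk(str(path)).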
python -m evaluate \
mode=mcts \
data.name=2wiki \
data.metrics='["all"]'
The judge model used for llm_judge is the agents.evaluator_metric block in configs/infer/base.yaml. By default, evaluate scores the results directory produced by the matching inference run; pass to_eval_path=<dir> to score a specific saved dataset instead.
This creates <results_path>_with_scores, containing the metric-annotated dataset plus evaluation_results.json.
Hydra config tree (simplified):
configs/infer/base.yaml
├─ mode: mcts | cot
├─ results_dir: results
├─ num_proc: <int>
├─ agents:
│ ├─ generator: (inline dict or path to YAML)
│ ├─ extractor:
│ ├─ evaluator:
│ └─ retriever:
├─ search:
│ ├─ max_depth
│ ├─ num_rollouts
│ ├─ exploration_weight
│ ├─ top_k
│ ├─ save_tree
│ └─ verbose
└─ data:
├─ name (2wiki | hotpotqa | musique | ...)
├─ split
└─ limit
Each agent YAML defines:
name: generator
client_kwargs:
  client_type: openai|litellm
  model_name: <model-id>
  api_base: <endpoint-url>
generation_config:
  temperature: 0.2
  max_tokens: 512
use_cache: true
concurrency: 8
Retriever config may include:
online_retrieval_config:
  url: http://localhost:5000/search
  timeout: 30
offline_retrieval_config:
  index_path: data/wiki23-Qwen3-4B-Emb-Indexed/
  top_k: 5
- GeneratorAgent – calls the LLM to produce sub-questions, candidate answers, rephrasings, self-corrections, and synthesis segments.
- RetrievalAgent – queries HTTP retriever or offline index; returns passages + metadata.
- ExtractorAgent – extracts structured spans or normalized answers from context.
- EvaluatorAgent – scores candidates, performs majority vote or synthesizes final answer reasoning.
All wrap a shared caching layer (agents/llm_agents.py) keyed by (messages + params hash) → JSON.
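Here is a minimal sketch of that keying scheme. It assumes one JSON file per key inside a per-role, per-model directory such as cache/<role>/<model_name>/. cache_key and cached_call are hypothetical names. SHA-256 over canonical JSON is one reasonable hash choice, not necessarily the one agents/llm_agents.py uses.

```python
import hashlib
import json
from pathlib import Path
from typing import Any, Callable

Messages = list[dict[str, str]]


def cache_key(messages: Messages, params: dict[str, Any]) -> str:
    # sort_keys=True makes the key independent of dict insertion order.
    payload = json.dumps({"messages": messages, "params": params},
                         sort_keys=True, ensure_ascii=False)
    return hashlib.sha256(payload.encode("utf-8")).hexdigest()


def cached_call(cache_dir: Path, messages: Messages, params: dict[str, Any],
                call_llm: Callable[[Messages, dict[str, Any]], str]) -> str:
    """Return the cached response if present; otherwise call the LLM and store it."""
    cache_dir.mkdir(parents=True, exist_ok=True)
    path = cache_dir / f"{cache_key(messages, params)}.json"
    if path.exists():
        return json.loads(path.read_text(encoding="utf-8"))["response"]
    response = call_llm(messages, params)
    path.write_text(json.dumps({"response": response}), encoding="utf-8")
    return response
```

Identical requests map to the same file, so repeated runs skip the LLM call. Any change to the messages or generation params yields a new key.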
MCTS (state_aware_rag/planners/MCTS):
- Expands reasoning tree via rollouts.
- Nodes: USER_QUESTION, REPHASED_QUESTION, SUB_QA, SYNTHESIS, FINAL_ANSWER, SELF_CORRECTED.
- Selection guided by UCT (exploration_weight).
- Optional tree persistence: JSONL per node with visits & reward.
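The textbook UCT rule selects the child that maximizes average reward plus an exploration bonus: Q_i / n_i + c * sqrt(ln N / n_i), where N is the parent's visit count and c plays the role of exploration_weight. The sketch below implements that standard formula on a hypothetical Node class. The planner's actual reward definition, and any variant of the bonus term it uses, are not specified in this README.

```python
import math
from dataclasses import dataclass, field


@dataclass
class Node:
    visits: int = 0
    total_reward: float = 0.0
    children: list["Node"] = field(default_factory=list)


def uct_score(child: Node, parent_visits: int, exploration_weight: float) -> float:
    if child.visits == 0:
        return math.inf  # expand unvisited children before revisiting others
    exploit = child.total_reward / child.visits
    explore = exploration_weight * math.sqrt(math.log(parent_visits) / child.visits)
    return exploit + explore


def select_child(node: Node, exploration_weight: float = 1.0) -> Node:
    """Pick the child with the highest UCT score."""
    return max(node.children,
               key=lambda c: uct_score(c, node.visits, exploration_weight))
```

With exploration_weight=0 selection is purely greedy on average reward. Larger values favour rarely visited branches.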
CoT (state_aware_rag/planners/CoT):
- Linear chain expansion (no branching) – lower compute, faster baselines.
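The linear CoT loop can be sketched as one sub-question per step: each extracted answer is appended to an explicit memory, and a final synthesis follows. This is a simplified illustration. The callback names (propose_subquestion, retrieve, extract, synthesize) are hypothetical stand-ins for the Generator, Retriever, and Extractor agents, and the sketch omits the Evaluator's scoring and any self-correction step.

```python
from dataclasses import dataclass, field
from typing import Callable, Optional


@dataclass
class Memory:
    facts: list[str] = field(default_factory=list)


def linear_chain(question: str,
                 propose_subquestion: Callable[[str, Memory], Optional[str]],
                 retrieve: Callable[[str], list[str]],
                 extract: Callable[[str, list[str]], str],
                 synthesize: Callable[[str, Memory], str],
                 max_depth: int = 6) -> tuple[str, Memory]:
    """Run one non-branching chain and return (final answer, memory)."""
    memory = Memory()
    for _ in range(max_depth):
        sub_q = propose_subquestion(question, memory)
        if sub_q is None:  # generator signals it has enough information
            break
        passages = retrieve(sub_q)
        memory.facts.append(f"{sub_q} -> {extract(sub_q, passages)}")
    return synthesize(question, memory), memory
```

Because there is exactly one path, cost grows linearly with max_depth rather than with max_depth * num_rollouts as in MCTS.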
Tips:
- Enable the per-model cache with use_cache: true in agent configs to avoid repeat LLM costs.
- Adjust concurrency/num_workers to match endpoint QPS capacity; start conservative (e.g. 8–16).
- Use num_proc (datasets map) ≤ physical cores.
- Trim the search space: lower max_depth/num_rollouts for quick iteration.
- Clear outdated cache entries if prompts change: delete cache/<role>/<model_name>/.
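A semaphore is a common way to cap the number of concurrent requests to an endpoint. This asyncio sketch assumes an async worker function, and bounded_gather is a hypothetical helper. This README does not say whether the framework implements its concurrency setting with asyncio, threads, or processes.

```python
import asyncio
from typing import Awaitable, Callable, TypeVar

T = TypeVar("T")


async def bounded_gather(items: list[str],
                         worker: Callable[[str], Awaitable[T]],
                         concurrency: int = 8) -> list[T]:
    """Run worker over items with at most `concurrency` calls in flight."""
    semaphore = asyncio.Semaphore(concurrency)

    async def run(item: str) -> T:
        async with semaphore:  # blocks while `concurrency` calls are active
            return await worker(item)

    # gather returns results in input order, regardless of completion order.
    return await asyncio.gather(*(run(item) for item in items))
```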
Implemented in evaluate.py and state_aware_rag/utils/metrics.py.
Available metrics:
- F1 (token overlap)
- Exact Match (EM)
- Sub Exact Match (Sub-EM; partial multi-answer coverage)
- LLM Judge (configurable judging model)
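The sketch below follows the widely used SQuAD-style definitions of EM and token-level F1. Sub-EM is taken here as "some normalized gold answer occurs inside the normalized prediction". These are assumptions about the conventions: state_aware_rag/utils/metrics.py may normalize text or combine multiple gold answers differently.

```python
import re
import string
from collections import Counter

PUNCTUATION = set(string.punctuation)


def normalize_answer(text: str) -> str:
    """SQuAD-style: lowercase, strip punctuation and articles, collapse whitespace."""
    text = "".join(ch for ch in text.lower() if ch not in PUNCTUATION)
    text = re.sub(r"\b(a|an|the)\b", " ", text)
    return " ".join(text.split())


def exact_match(prediction: str, golds: list[str]) -> float:
    pred = normalize_answer(prediction)
    return float(any(pred == normalize_answer(g) for g in golds))


def sub_exact_match(prediction: str, golds: list[str]) -> float:
    # Assumed definition: a gold answer appears as a substring of the prediction.
    pred = normalize_answer(prediction)
    return float(any(normalize_answer(g) in pred for g in golds))


def f1_score(prediction: str, golds: list[str]) -> float:
    """Token-overlap F1, taking the best score over all gold answers."""
    pred_tokens = normalize_answer(prediction).split()
    best = 0.0
    for gold in golds:
        gold_tokens = normalize_answer(gold).split()
        overlap = sum((Counter(pred_tokens) & Counter(gold_tokens)).values())
        if overlap == 0:
            continue
        precision = overlap / len(pred_tokens)
        recall = overlap / len(gold_tokens)
        best = max(best, 2 * precision * recall / (precision + recall))
    return best
```

For example, the prediction "Paris, France" against gold "Paris" scores EM 0, Sub-EM 1, and F1 2/3 (precision 1/2, recall 1).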
inference.py # Main inference entry
evaluate.py # Metrics & dataset annotation
configs/ # Hydra configs
state_aware_rag/
  agents/ # Agent abstractions & role logic
  planners/ # MCTS & CoT planners
  preprocess/ # Normalization utilities
  utils/ # Metrics & helpers
results/ # Saved HF datasets of predictions
cache/ # LLM response caches
mcts_trees/ # (Optional) reasoning tree JSONLs
scripts/ # Helper shell scripts (deploy, eval)
New LLM backend:
- Implement a client in agents/llm_agents.py (follow the LiteLLMClient pattern).
- Add client_type selection logic.
Custom retriever:
- Adapt to the agents/retriever_agents.py expectation: POST {query, top_k, return_score?, instruction?} → {retrieved_docs: [[{id, contents, url?}, ...], ...]}.
- Update your retriever server or index loader.
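A minimal standard-library client for that request/response contract might look like this. build_request, parse_response, and search are hypothetical helpers, and the field names come from the contract above. Reading the outer list of retrieved_docs as one entry per query is an inference from its nested shape.

```python
import json
import urllib.request
from typing import Any, Optional, Union

Query = Union[str, list[str]]
DocGroups = list[list[dict[str, Any]]]


def build_request(query: Query, top_k: int = 5,
                  return_score: Optional[bool] = None,
                  instruction: Optional[str] = None) -> dict[str, Any]:
    body: dict[str, Any] = {"query": query, "top_k": top_k}
    # Optional fields from the contract are only sent when explicitly set.
    if return_score is not None:
        body["return_score"] = return_score
    if instruction is not None:
        body["instruction"] = instruction
    return body


def parse_response(payload: dict[str, Any]) -> DocGroups:
    """Check the {retrieved_docs: [[{id, contents, url?}, ...], ...]} shape."""
    groups = payload["retrieved_docs"]
    for group in groups:
        for doc in group:
            if "id" not in doc or "contents" not in doc:
                raise ValueError(f"document missing 'id' or 'contents': {doc!r}")
    return groups


def search(url: str, query: Query, top_k: int = 5, timeout: float = 30.0) -> DocGroups:
    data = json.dumps(build_request(query, top_k)).encode("utf-8")
    # urllib sends a POST when a request body is supplied.
    req = urllib.request.Request(url, data=data,
                                 headers={"Content-Type": "application/json"})
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return parse_response(json.load(resp))
```

A custom server that returns this shape should be interchangeable with the bundled one, as far as the retriever agent's input expectations go.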
New planner:
- Create a directory under planners/<Name>/ with a search function whose signature mirrors the existing planners.
- Wire the selection into inference.generate_answer based on mode.
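A small registry keyed by mode is one way to keep planner selection open-ended. This is a hypothetical pattern, not a description of how inference.generate_answer is implemented. register_planner and dispatch_planner are illustrative names, and real planner search functions take more arguments than a question string.

```python
from typing import Callable

# Hypothetical simplified signature: question in, final answer out.
Planner = Callable[[str], str]
PLANNERS: dict[str, Planner] = {}


def register_planner(mode: str) -> Callable[[Planner], Planner]:
    def decorator(fn: Planner) -> Planner:
        PLANNERS[mode] = fn
        return fn
    return decorator


@register_planner("cot")
def cot_search(question: str) -> str:
    return f"[cot] {question}"  # placeholder for a real linear planner


def dispatch_planner(question: str, mode: str) -> str:
    if mode not in PLANNERS:
        raise ValueError(f"unknown mode {mode!r}; expected one of {sorted(PLANNERS)}")
    return PLANNERS[mode](question)
```

With this pattern, adding a planner needs only a decorated function. No branching logic has to be edited.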
Additional node / role:
- Extend NodeType & ReasoningNode logic, and ensure serialization fields are updated.
| Symptom | Likely Cause | Fix |
|---|---|---|
| Empty retrieval_docs | Retriever URL wrong or no results | Verify API endpoint & corpus |
| Cache not updating | Prompt or params changed but same hash path reused | Manually clear cache/<role>/<model> |
| MCTS very slow | Large num_rollouts * max_depth | Reduce both; set verbose=false |
| OOM / rate errors | Concurrency too high | Lower num_workers & concurrency |
Logging: set LOGGING_LEVEL=INFO (default) or adjust per run.
- Retrieval-conditioned adaptive rollouts (stop early on convergence)
- Graph-based memory
- Tool invocation / function calling integration
- Multi-corpus hybrid retrieval (dense + sparse fusion)
Contributions via issues and PRs are welcome. See CONTRIBUTING.md for guidelines.
See CONTRIBUTING for information on reporting security issues.
This project is licensed under the Creative Commons Attribution-NonCommercial 4.0 International (CC-BY-NC-4.0) License. See the LICENSE file for the full text. Some bundled third-party components retain their original licenses under finetune/rl/verl/.
If you use State-Aware RAG in academic or industrial work, please cite:
@article{man2025reasoning,
title = {Reasoning with Memory: Adaptive Information Management for Retrieval-Augmented Generation},
author = {Man, Hieu and Tal, Ro-ee and Kumar, Abhishek and Cho, JJ and Hsu, Benjamin},
year = {2025},
url = {https://github.com/amazon-science/state-aware-rag}
}
Q: Can I resume a previous MCTS run?
A: Yes. Set search.save_tree=true; if a tree JSONL already exists for the question_id, it is reloaded instead of recomputed.
Q: How do I change the model for the Generator only?
A: Point agents.generator to a different YAML or override agents.generator.client_kwargs.model_name=....
Q: Do I need golden answers?
A: No. They are optional and only used when use_golden_answer is true (e.g., supervision / debugging modes).
Please open an issue for bugs or feature requests.
Happy reasoning!