Pyrefly's tensor shape tracking is designed so most PyTorch coverage can be extended by editing stubs and tests, without changing Pyrefly's Rust internals. This page explains the main mechanisms and how to validate changes.
Most external contributions should be stub-only or example/test-only changes.
Kernel changes are possible, but they are a narrower workflow for changes to
Pyrefly's shape machinery or the shape_extensions runtime package.
Shape tracking uses three complementary mechanisms:
- Fixture stubs: `.pyi` files with shape-generic type signatures. These cover modules like `nn.Linear`, `nn.Conv2d`, and functions like `torch.mm`.
- Shape DSL functions: shape transforms written in a small Python subset in `tensor-shapes/pyrefly-torch-stubs/torch-stubs/_shapes.pyi`, decorated with `@shape_dsl_function`, and attached to stubs with `@uses_shape_dsl(...)`. These cover operations with computed shape logic like `reshape`, `cat`, pooling, convolution, and interpolation.
- Special handlers: Pyrefly implementation logic for constructs that need deeper type system integration, like `nn.Sequential` chaining, `.shape`, `.size()`, `assert_shape`, and decorator interpretation.
The first two mechanisms live in tensor-shapes/ and are the normal way to add
or improve shape coverage. Special handlers require Pyrefly implementation
changes and should be treated as kernel work.
tensor-shapes/pyrefly-torch-stubs/torch-stubs/
|-- __init__.pyi
|-- _shapes.pyi
|-- nn/
| |-- __init__.pyi # nn.Linear, nn.Conv2d, nn.LSTM, etc.
| `-- functional.pyi # F.relu, F.softmax, F.conv2d, etc.
|-- distributions/
| `-- ... # torch.distributions
`-- ...
The tensor-shape test runner passes tensor-shapes/ as a Pyrefly search path,
so these stubs override the normal torch stubs during validation.
A fixture stub provides a shape-generic type signature. For example,
nn.Linear:

    class Linear[N, M](Module):
        def __init__(
            self,
            in_features: Dim[N],
            out_features: Dim[M],
            bias: bool = True,
        ) -> None: ...
        def forward[*Xs](self, input: Tensor[*Xs, N]) -> Tensor[*Xs, M]: ...

The constructor captures input and output dimensions as type parameters. The `forward` method uses those parameters plus a variadic `*Xs` for batch dimensions.
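
To make that relationship concrete, here is a minimal pure-Python sketch of the rule the `forward` signature encodes, applied to plain integer tuples. The helper name `linear_output_shape` is hypothetical; it only mimics the shape rule and is not part of Pyrefly or the stubs.

```python
def linear_output_shape(
    input_shape: tuple[int, ...], in_features: int, out_features: int
) -> tuple[int, ...]:
    """Mirror Tensor[*Xs, N] -> Tensor[*Xs, M] on plain integer tuples."""
    if not input_shape or input_shape[-1] != in_features:
        raise ValueError(
            f"expected last dimension {in_features}, got shape {input_shape}"
        )
    # Leading (batch) dimensions pass through; only the last one changes.
    return (*input_shape[:-1], out_features)


print(linear_output_shape((32, 10, 128), 128, 64))  # (32, 10, 64)
```

A mismatched last dimension raises `ValueError` here; in the stub, the analogous mismatch is what the type checker is expected to flag.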
- Identify the shape signature: input dimensions, output dimensions, and how they relate.
- Use `Dim[X]` for parameters that determine tensor dimensions. Non-shape parameters like `bias` and `dropout` stay as their original types.
- Write the method or function signature expressing the shape transform. Use `*Xs` or `*Bs` for batch dimensions that pass through unchanged.
- Add the stub to the appropriate `.pyi` file in `tensor-shapes/pyrefly-torch-stubs/torch-stubs`.
- Add or update focused tests under `tensor-shapes/pyrefly-torch-stubs/test/`.
Suppose you want to add nn.GroupNorm, which preserves spatial dimensions:

    class GroupNorm[NumGroups, NumChannels](Module):
        def __init__(
            self,
            num_groups: Dim[NumGroups],
            num_channels: Dim[NumChannels],
            eps: float = 1e-5,
            affine: bool = True,
        ) -> None: ...
        def forward[*S](self, input: Tensor[*S]) -> Tensor[*S]: ...

Since GroupNorm does not change shape, the forward signature is simply `Tensor[*S] -> Tensor[*S]`.
Use the DSL when a plain signature cannot express the output shape.
DSL functions live in:
tensor-shapes/pyrefly-torch-stubs/torch-stubs/_shapes.pyi
Stubs attach a DSL function with @uses_shape_dsl(...). For example, a stub may
declare a broad return type like Tensor and let the DSL refine the result shape
at call sites:

    from shape_extensions import uses_shape_dsl
    from torch._shapes import reshape_ir

    @uses_shape_dsl(reshape_ir)
    def reshape(self: Tensor, shape: tuple[int, ...]) -> Tensor: ...

The DSL is intentionally small. It supports common shape computation patterns, including:

- `ShapedArray(shape=[...])` to construct result shapes
- `self.shape` and other shaped-array argument shapes
- Lists, slices, indexing, and comprehensions
- Arithmetic such as `+`, `-`, `*`, `//`, `%`, and `**`
- `if`/`else`
- Helper calls to other `@shape_dsl_function` functions
- DSL helpers from `shape_extensions.dsl`, such as `prod`, `sum`, `Unknown`, and `Error`

Keep DSL functions simple and algebraic. They are analyzed by Pyrefly; they are not normal runtime implementations of PyTorch operations.
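
To give a sense of the algebra such functions typically contain, the sketch below writes the standard output-length formula for a convolution or pooling window in ordinary Python. It is not DSL code and is not taken from `_shapes.pyi`; the name `conv_out_size` is illustrative only.

```python
def conv_out_size(
    size: int, kernel: int, stride: int = 1, padding: int = 0, dilation: int = 1
) -> int:
    """Output length along one spatial dimension (floor mode)."""
    return (size + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1


print(conv_out_size(32, kernel=3, padding=1))            # 32
print(conv_out_size(32, kernel=3, stride=2, padding=1))  # 16
print(conv_out_size(28, kernel=5))                       # 24
```

The whole computation is integer arithmetic with `+`, `-`, `*`, and `//`, which is the kind of expression the DSL is described as supporting.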

    @shape_dsl_function
    def cat_ir(tensors: list[ShapedArray], dim: int = 0) -> ShapedArray:
        first = tensors[0]
        d = normalize_dim(len(first.shape), dim)
        return ShapedArray(
            shape=[
                sum([t.shape[i] for t in tensors]) if i == d else dim_val
                for i, dim_val in enumerate(first.shape)
            ]
        )

This sums the sizes along the concatenation dimension and keeps every other dimension from the first tensor.
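
The same rule can be tried on concrete integer shapes with a small plain-Python analogue. This is a sketch under assumptions: `cat_shape` is a hypothetical name, and the body of `normalize_dim` is a guess at its behavior (wrapping negative indices), since its definition is not shown on this page.

```python
def normalize_dim(rank: int, dim: int) -> int:
    """Assumed behavior: map a possibly negative index into [0, rank)."""
    if not -rank <= dim < rank:
        raise IndexError(f"dim {dim} is out of range for rank {rank}")
    return dim % rank


def cat_shape(shapes: list[tuple[int, ...]], dim: int = 0) -> tuple[int, ...]:
    """Concrete-integer analogue of cat_ir, operating on plain tuples."""
    first = shapes[0]
    d = normalize_dim(len(first), dim)
    return tuple(
        sum(s[i] for s in shapes) if i == d else size
        for i, size in enumerate(first)
    )


print(cat_shape([(2, 3, 4), (2, 5, 4)], dim=1))   # (2, 8, 4)
print(cat_shape([(2, 3, 4), (2, 3, 6)], dim=-1))  # (2, 3, 10)
```

Like the `cat_ir` excerpt, this sketch takes the non-concatenated sizes from the first shape and does not check that the other inputs agree with them.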
- Write the shape transform in `tensor-shapes/pyrefly-torch-stubs/torch-stubs/_shapes.pyi`.
- Decorate it with `@shape_dsl_function`.
- Attach it to the public stub with `@uses_shape_dsl(...)`.
- Add positive tests that assert the computed shape.
- Add negative tests with `# E:` expectations if the DSL should reject invalid shapes or report shape errors.

Example models live in `tensor-shapes/pyrefly-torch-stubs/examples/`. Each file is a fully annotated port of a real-world PyTorch model with `assert_type` checkpoints and smoke tests.
- Choose a model from TorchBench or another source.
- Port it using the tutorials or the agent skill.
- Add `assert_type` or `assert_shape` checkpoints after shape-changing operations.
- Add smoke tests at the bottom of the file when runtime execution is useful.
- Run `verify_port.sh` to check for common quality issues.

The `verify_port.sh` script checks a ported model for common issues:

    tensor-shapes/skills/add-shape-types-to-torch-model/verify_port.sh tensor-shapes/pyrefly-torch-stubs/examples/<model>.py

It reports:

| Metric | Description |
|---|---|
| `ig` | `type: ignore` count |
| `bs` | Bare `Tensor` in signatures |
| `bv` | Bare `Tensor` in variable annotations |
| `sh` | Shaped `assert_type` count |
| `ba` | Bare `assert_type` count |
| `sm` | Smoke test count |

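
As a rough illustration of what counts like these measure, the sketch below tallies `type: ignore` comments and `assert_type` calls in a source string. It is a hypothetical approximation built on simple regular expressions, not the logic of `verify_port.sh`, which this page does not show; the name `count_metrics` and both patterns are assumptions.

```python
import re


def count_metrics(source: str) -> dict[str, int]:
    """Approximate two verify_port.sh-style counts with simple regexes."""
    return {
        # Occurrences of a "# type: ignore" comment.
        "ig": len(re.findall(r"#\s*type:\s*ignore", source)),
        # Calls to assert_type(...), shaped or bare.
        "assert_type": len(re.findall(r"\bassert_type\s*\(", source)),
    }


example = """
x = model(inp)  # type: ignore
assert_type(x, Tensor[8, 16])
assert_type(y, Tensor)
"""
print(count_metrics(example))  # {'ig': 1, 'assert_type': 2}
```

A lower `ig` count and a higher share of shaped checkpoints generally indicate a more precisely annotated port.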
For most contributions, the important validation is the tensor-shape Pyrefly runner. It checks the focused tests, negative expectations, jaxtyping examples, and the example corpus using the shape-aware stubs.
Build Pyrefly first, then run:

    cargo build
    python3 tensor-shapes/pyrefly-torch-stubs/run_pyrefly.py

If your build uses a custom target directory, `run_pyrefly.py` respects `CARGO_TARGET_DIR`. You can also pass the binary explicitly:

    python3 tensor-shapes/pyrefly-torch-stubs/run_pyrefly.py --pyrefly /path/to/pyrefly

Run a single suite while iterating:

    python3 tensor-shapes/pyrefly-torch-stubs/run_pyrefly.py --suite torch-positive
    python3 tensor-shapes/pyrefly-torch-stubs/run_pyrefly.py --suite torch-negative
    python3 tensor-shapes/pyrefly-torch-stubs/run_pyrefly.py --suite torch-examples

Use `--nocapture` when you want the full Pyrefly output on success. By default, the runner prints a compact `PASS ...` line and only dumps checker output on failure.
In an internal Buck checkout, the equivalent static validation targets are:

    buck test tensor-shapes/pyrefly-torch-stubs/test:tensor_shapes_all_test
    buck test tensor-shapes/pyrefly-torch-stubs/test:tensor_shapes_error_test
    buck test tensor-shapes/pyrefly-torch-stubs/test:tensor_shapes_jaxtyping_test
    buck test tensor-shapes/pyrefly-torch-stubs/test:tensor_shapes_jaxtyping_error_test
    buck test tensor-shapes/pyrefly-torch-stubs/examples:torch_examples_test

The project-level `test.py` runner keeps tensor-shape validation separate from the default Pyrefly test loop. To run just these validations through `test.py`:

    python3 test.py --no-fmt --no-lint --no-test --tensor-shapes --no-conformance --no-jsonschema

Runtime tests validate that the annotation helpers and runnable example models behave correctly in Python, not just in Pyrefly's static checker.
The tests live in:
tensor-shapes/pyrefly-torch-stubs/test/runtime_tests/
Run them from a Python 3.12+ virtualenv with torch installed:

    python3.12 -m venv .tensor-shapes-venv
    . .tensor-shapes-venv/bin/activate
    python -m pip install --upgrade pip
    python -m pip install torch
    python tensor-shapes/pyrefly-torch-stubs/run_runtime_tests.py

Run one suite while iterating:

    python tensor-shapes/pyrefly-torch-stubs/run_runtime_tests.py --suite annotation
    python tensor-shapes/pyrefly-torch-stubs/run_runtime_tests.py --suite model

The runtime runner sets up import paths for `shape_extensions` and the runnable example modules. In an internal Buck checkout, the existing runtime targets are:

    buck test tensor-shapes/pyrefly-torch-stubs/test:annotation_runtime_test
    buck test tensor-shapes/pyrefly-torch-stubs/test:model_runtime_test

Most contributors should not need this section. Use these tests when you change Pyrefly's tensor-shape kernel rather than only stubs or examples. Kernel changes include:

- `shape_extensions` primitives or decorators
- `assert_shape` type-checker behavior
- `@shape_dsl_function` parsing, validation, or evaluation
- `@uses_shape_dsl` handling
- special handlers in Pyrefly's Rust source

The focused Pyrefly unit tests live in:
pyrefly/lib/test/shape_dsl.rs
Run them with Cargo:

    cargo test shape_dsl

In an internal Buck checkout:

    buck test pyrefly:pyrefly_library -- shape_dsl

Kernel tests are intentionally much smaller than the stub/example suites. They cover the core primitives and invariants; the tensor-shape stub tests stress the DSL through realistic PyTorch signatures.
Before handing off changes, run formatting and linting:

    ./test.py --no-test --no-tensor-shapes --no-conformance --no-jsonschema

Also run the relevant tensor-shape checks for the files you touched:

- Stub/test/example changes: `python3 tensor-shapes/pyrefly-torch-stubs/run_pyrefly.py`
- Runtime helper or runnable model changes: `python tensor-shapes/pyrefly-torch-stubs/run_runtime_tests.py`
- Kernel changes: `cargo test shape_dsl` or the Buck equivalent above