Cherrypick #3440 for release/2.7 #3485


Merged 1 commit into release/2.7 on Apr 27, 2025
241 changes: 241 additions & 0 deletions examples/dynamo/llama2_flashinfer_rmsnorm.py
@@ -0,0 +1,241 @@
from typing import Sequence

import flashinfer
import torch
import torch_tensorrt
from torch.fx.passes.shape_prop import TensorMetadata
from torch_tensorrt.dynamo.lowering.passes._aten_lowering_pass import (
_aten_lowering_pass,
)
from torch_tensorrt.dynamo.lowering.passes.pass_utils import (
clean_up_graph_after_modifications,
)
from transformers import LlamaConfig, LlamaForCausalLM


@torch.library.custom_op("flashinfer::rmsnorm", mutates_args=())  # type: ignore[misc]
def flashinfer_rmsnorm(
    input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6
) -> torch.Tensor:
    # Eager kernel: dispatch to FlashInfer's fused RMSNorm, forwarding eps
    # rather than silently dropping it.
    return flashinfer.norm.rmsnorm(input, weight, eps)


@torch.library.register_fake("flashinfer::rmsnorm")
def _(input: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Fake (meta) kernel: RMSNorm preserves shape and dtype, so returning the
    # input is enough for tracing.
    return input


torch_tensorrt.dynamo.conversion.plugins.custom_op(
"flashinfer::rmsnorm", supports_dynamic_shapes=True
)
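# Quick eager sanity check (a sketch, not part of the lowering flow; assumes a
# CUDA device with flashinfer installed). The custom op should match a
# reference RMSNorm computed in fp32:
#
#   x = torch.randn(8, 4096, device="cuda", dtype=torch.float16)
#   w = torch.randn(4096, device="cuda", dtype=torch.float16)
#   ref = (x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + 1e-6)).to(x.dtype) * w
#   torch.testing.assert_close(torch.ops.flashinfer.rmsnorm(x, w), ref, rtol=1e-2, atol=1e-2)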


@_aten_lowering_pass
def replace_rmsnorm(
gm: torch.fx.GraphModule, sample_inputs: Sequence[torch.Tensor]
) -> torch.fx.GraphModule:
    # Match the chain that torch.export produces for LlamaRMSNorm:
    #   to(fp32) -> pow(2) -> mean(dim) -> add(eps) -> sqrt -> div -> mul -> to(half) -> mul(weight)
    # and collapse it into a single flashinfer::rmsnorm call on a 2D view.
    modified_graph = False
    for node in gm.graph.nodes:
if (
node.target == torch.ops.aten._to_copy.default
and node.kwargs.get("dtype") is torch.float32
and len(node.users) == 2
):
if (
list(node.users)[0].target == torch.ops.aten.pow.Tensor_Scalar
and list(node.users)[1].target == torch.ops.aten.mul.Tensor
):
pow_node = list(node.users)[0]
if (
len(pow_node.users) == 1
and list(pow_node.users)[0].target == torch.ops.aten.mean.dim
):
mean_node = list(pow_node.users)[0]
if (
len(mean_node.users) == 1
and list(mean_node.users)[0].target == torch.ops.aten.add.Tensor
):
add_node = list(mean_node.users)[0]
if (
len(add_node.users) == 1
and list(add_node.users)[0].target
== torch.ops.aten.sqrt.default
):
sqrt_node = list(add_node.users)[0]
if (
len(sqrt_node.users) == 1
and list(sqrt_node.users)[0].target
== torch.ops.aten.div.Tensor
):
div_node = list(sqrt_node.users)[0]
if list(div_node.users)[0] == list(node.users)[1]:
mul_node = list(div_node.users)[0]
copy_node = list(mul_node.users)[0]
weight_mul_node = list(copy_node.users)[0]

weight = weight_mul_node.args[0]

original_meta = weight_mul_node.meta.get(
"tensor_meta", {}
)
memory_format = original_meta.memory_format

with gm.graph.inserting_after(weight_mul_node):
b = gm.graph.create_node(
op="call_function",
target=torch.ops.aten.sym_size.int,
args=(node.args[0], 0),
)
b.meta["tensor_meta"] = TensorMetadata(
shape=torch.Size([]),
dtype=torch.int64,
requires_grad=False,
stride=None,
memory_format=memory_format,
is_quantized=False,
qparams={},
)
s = gm.graph.create_node(
op="call_function",
target=torch.ops.aten.sym_size.int,
args=(node.args[0], 1),
)
s.meta.update(b.meta)

d = gm.graph.create_node(
op="call_function",
target=torch.ops.aten.sym_size.int,
args=(node.args[0], 2),
)
d.meta.update(b.meta)

with gm.graph.inserting_after(b):
new_first_dim = gm.graph.create_node(
op="call_function",
target=torch.ops.aten.mul.Scalar,
args=(b, s),
)
new_first_dim.meta.update(b.meta)

with gm.graph.inserting_after(new_first_dim):
                                        # Flatten (b, s, d) -> (b*s, d): flashinfer's rmsnorm expects a 2D input.
reshape_node = gm.graph.create_node(
op="call_function",
target=torch.ops.aten.reshape.default,
args=(node.args[0], [new_first_dim, d]),
)
b_val = original_meta.shape[0]
s_val = original_meta.shape[1]
d_val = original_meta.shape[2]

reshape_node.meta["tensor_meta"] = (
TensorMetadata(
shape=torch.Size(
[b_val * s_val, d_val]
),
dtype=original_meta.dtype,
requires_grad=True,
stride=None,
memory_format=memory_format,
is_quantized=False,
qparams={},
)
)

with gm.graph.inserting_after(reshape_node):
flashinfer_rmsnorm_node = gm.graph.create_node(
op="call_function",
target=torch.ops.flashinfer.rmsnorm.default,
args=(
reshape_node,
weight,
add_node.args[1],
),
)
flashinfer_rmsnorm_node.meta.update(
reshape_node.meta
)

with gm.graph.inserting_after(
flashinfer_rmsnorm_node
):
                                        # Restore the original (b, s, d) shape after the fused norm.
                                        reshape_back_node = gm.graph.create_node(
                                            op="call_function",
                                            target=torch.ops.aten.reshape.default,
                                            args=(
                                                flashinfer_rmsnorm_node,
                                                [b, s, d],
                                            ),
                                        )

                                    weight_mul_node.replace_all_uses_with(
                                        reshape_back_node
                                    )
                                    reshape_back_node.meta.update(weight_mul_node.meta)

modified_graph = True

gm.graph.erase_node(weight_mul_node)
gm.graph.erase_node(copy_node)
gm.graph.erase_node(mul_node)
gm.graph.erase_node(div_node)
gm.graph.erase_node(sqrt_node)
gm.graph.erase_node(add_node)
gm.graph.erase_node(mean_node)
gm.graph.erase_node(pow_node)
gm.graph.erase_node(node)

if modified_graph:
gm = clean_up_graph_after_modifications(gm)

return gm
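# A sketch of a post-hoc check (hypothetical helper; torch_tensorrt applies
# registered lowering passes automatically during dynamo.compile):
#
#   def count_rmsnorm(gm: torch.fx.GraphModule) -> int:
#       return sum(
#           1
#           for n in gm.graph.nodes
#           if n.target == torch.ops.flashinfer.rmsnorm.default
#       )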


# 1. Create a custom config with 1 layer
config = LlamaConfig(
vocab_size=32000,
hidden_size=4096, # LLaMA2-7B dimensions
intermediate_size=11008, # FFN hidden_dim = 4 * 4096 * 0.7 (SwiGLU scaling)
num_hidden_layers=1, # Only 1 decoder layer
num_attention_heads=32,
max_position_embeddings=4096,
use_cache=False, # Disable KV caching for export
)
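# A single decoder layer is enough to exercise every RMSNorm site in the
# network: input_layernorm, post_attention_layernorm, and the final model norm.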

# 2. Initialize model (random weights)
with torch.no_grad():
model = LlamaForCausalLM(config).eval().half()

# 3. Export with static shapes
input_ids = torch.randint(0, 32000, (1, 64)) # Static [batch=1, seq=64]
exported = torch.export.export(
model,
(input_ids,),
dynamic_shapes=None, # Fully static
)
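# With dynamic_shapes=None, the exported graph (and any TRT engine built from
# it) is specialized to the exact [1, 64] input shape.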

# Test forward pass
input_ids = torch.randint(0, 32000, (1, 64))
output = model(input_ids)
print(output)

# Export validation

DEVICE = torch.device("cuda:0")

with torch_tensorrt.logging.errors():
trt_model = torch_tensorrt.dynamo.compile(
exported,
inputs=[input_ids],
enabled_precisions={torch.float32, torch.float16},
truncate_double=True,
device=DEVICE,
disable_tf32=True,
use_explicit_typing=False,
use_fp32_acc=True,
# debug=True,
)

input_ids = input_ids.to(DEVICE)

res = trt_model.forward(input_ids)
print(res)
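# Optional spot-check (a sketch; assumes both outputs expose logits, with loose
# tolerances for the fp16 engine):
#
#   torch.testing.assert_close(
#       res.logits, output.logits.to(DEVICE), rtol=1e-1, atol=1e-1
#   )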
21 changes: 15 additions & 6 deletions py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py
@@ -1,3 +1,4 @@
+import itertools
import logging
from types import FunctionType
from typing import Any, Callable, Tuple
@@ -108,7 +109,6 @@ def generate_signature(

def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc]:
shape_env = ShapeEnv()
-    fake_mode = FakeTensorMode(shape_env=shape_env)
syms_args = []
tensor_args = [elem for elem in args if isinstance(elem, trtp.TensorDesc)]

@@ -121,7 +121,7 @@ def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc]:
]
syms_args.append(syms_arg)

-    with FakeTensorMode() as fake_mode:
+    with FakeTensorMode(shape_env=shape_env) as fake_mode:
fake_args = []
for syms_arg in syms_args:
fake_arg = torch.randn(syms_arg)
@@ -130,16 +130,25 @@ def _generic_plugin_desc(*args: Any, **kwargs: Any) -> Tuple[trtp.TensorDesc]:
output = torch_op(*fake_args, **kwargs)

# We assume that number of dimensions are the same in torch op
-    shape_calc_fns = [None] * args[0].ndim
-    for i in range(args[0].ndim):
-        input_node_expr = [syms_arg[i].node.expr for syms_arg in syms_args]
+    shape_calc_fns = [None] * output.ndim
+
+    for i in range(output.ndim):
+        input_node_expr = list(
+            itertools.chain.from_iterable(
+                [sym.node.expr for sym in syms_arg] for syms_arg in syms_args
+            )
+        )
+
shape_calc_fns[i] = lambdify(
tuple(input_node_expr), output.shape[i].node.expr, "math"
)

out_desc = tensor_args[0].like()
for i in range(out_desc.ndim):
-        input_shape_expr = [tensor_arg.shape_expr[i] for tensor_arg in tensor_args]
+        input_shape_expr = list(
+            itertools.chain.from_iterable(arg.shape_expr for arg in tensor_args)
+        )
+
+        if output.shape[i].node.expr is None:
+            raise ValueError(f"output.shape[{i}].node.expr cannot be None")
out_desc.shape_expr[i] = shape_calc_fns[i](*input_shape_expr) # type: ignore[misc]
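For context, a minimal sketch of what the lambdified shape functions compute (names here are illustrative, not from this patch): each output dimension's sympy expression becomes a plain function of the input dimension symbols, which is then evaluated on TensorRT's shape expressions.

    from sympy import lambdify, symbols

    b, s, d = symbols("b s d")
    # An op that flattens (b, s, d) -> (b*s, d): output dim 0 is the product b*s.
    out_dim0 = lambdify((b, s, d), b * s, "math")
    print(out_dim0(2, 64, 4096))  # 128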
@@ -1,4 +1,5 @@
import logging
+import uuid
from typing import Callable, Dict, Optional, Sequence, Tuple, Union

import numpy as np
@@ -47,11 +48,15 @@ def custom_kernel_converter(
kwargs: Dict[str, Argument],
name: str,
) -> Union[trt.ITensor, Sequence[trt.ITensor]]:

plugin = getattr(getattr(trtp.op, namespace), op_name)

tensor_inputs = plugin.input_tensor_names
tensor_args = args[0 : len(tensor_inputs)]

+    unique_id = uuid.uuid4()
itensor_args = [
get_trt_tensor(ctx, t, f"{t_name}")
get_trt_tensor(ctx, t, f"{t_name}_{unique_id}")
for (t, t_name) in zip(tensor_args, tensor_inputs)
]

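The uuid suffix appears intended to keep TensorRT tensor names unique when the same plugin op occurs more than once in a graph; without it, two converter calls would register constants under the same bare input name. A toy illustration of the naming scheme (illustrative only):

    import uuid

    t_name = "weight"
    print(f"{t_name}_{uuid.uuid4()}")  # e.g. weight_1b9d6bcd-... unique per converter call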