6 changes: 6 additions & 0 deletions .github/workflows/build-test-linux-x86_64.yml
@@ -459,6 +459,12 @@ jobs:
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_automatic_plugin_with_attrs.py
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_converters_test_results.xml automatic_plugin/test_flashinfer_rmsnorm.py
popd
pushd .
# cuda-python is an optional runtime dep for the torch_tensorrt.annotation QDP layer.
python -m pip install cuda-python
cd tests/py/annotation
python -m pytest -ra --junitxml=${RUNNER_TEST_RESULTS_DIR}/dynamo_annotation_test_results.xml .
popd
L2-torchscript-tests:
name: ${{ matrix.display-name }}
156 changes: 156 additions & 0 deletions docsrc/py_api/annotation.rst
@@ -0,0 +1,156 @@
.. _torch_tensorrt_annotation_py:

torch_tensorrt.annotation
==========================

.. currentmodule:: torch_tensorrt.annotation

.. automodule:: torch_tensorrt.annotation

.. note::

   This module is **experimental**. It requires ``cuda-python`` at runtime
   and TensorRT ``>=10.7.0`` (but not ``10.14.x``) for Quick Deployable
   Plugin (QDP) support. Install ``cuda-python`` with ``pip install
   cuda-python``.

Overview
--------

The ``annotation`` module registers NVRTC-compiled CUDA C++ kernels as
TensorRT Quick Deployable Plugins with full Ahead-of-Time (AOT)
compilation support. It offers two entry points that trade
declarativeness for flexibility — start with :func:`auto_cuda_kernel_plugin` and
drop down to :func:`manual_cuda_kernel_plugin` only when your kernel falls outside
the declarative DSL:

.. list-table::
   :header-rows: 1
   :widths: 28 36 36

   * - Entry point
     - What you provide
     - What you get for free
   * - :func:`auto_cuda_kernel_plugin`
     - A :class:`KernelSpec` dataclass (source, inputs, outputs, extras,
       geometry)
     - Meta / eager / AOT functions and the PyTorch schema — all derived
   * - :func:`manual_cuda_kernel_plugin`
     - ``aot_fn`` + ``eager_fn`` + a meta function decorated with the
       one-shot decorator
     - PyTorch op + TRT plugin + converter, registered together

For unary-pointwise kernels, :func:`pointwise_aot` and
:func:`pointwise_eager` produce the two callables so users can plug them
directly into :func:`manual_cuda_kernel_plugin`.
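
A hypothetical sketch of that wiring (every keyword argument below is an
assumption inferred from the table above, not a confirmed signature; see
``manual_cuda_kernel_plugin_annotation.py`` for the real usage)::

   import torch_tensorrt.annotation as tta

   # Assumed signatures: pointwise helpers built from a CUDA source string
   # and kernel name, then handed to the one-shot registration call.
   aot_fn = tta.pointwise_aot(CU_SOURCE, "my_kernel")
   eager_fn = tta.pointwise_eager(CU_SOURCE, "my_kernel")

   tta.manual_cuda_kernel_plugin(
       "ann_ex::my_op",
       aot_fn=aot_fn,
       eager_fn=eager_fn,
   )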

Declarative entry point
-----------------------

.. autofunction:: auto_cuda_kernel_plugin

KernelSpec DSL
^^^^^^^^^^^^^^

.. autoclass:: KernelSpec
   :members:

.. autoclass:: InputDecl
   :members:

.. autoclass:: ScalarInput
   :members:

.. autoclass:: OutputDecl
   :members:

Shape relations
"""""""""""""""

.. autoclass:: SameAs
   :members:

.. autoclass:: ReduceDims
   :members:

Extra scalar args
"""""""""""""""""

Extras are passed to the kernel between the input and output pointer
lists in :class:`KernelSpec` order.
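
For instance, the sigmoid example shipped with this PR declares
``extras=[Numel("x")]`` between one input and one output, which maps onto
the kernel parameters in declaration order::

   inputs=[InputDecl("x")]                      # -> const float* x
   extras=[Numel("x")]                          # -> int n  (x.numel())
   outputs=[OutputDecl("y", shape=SameAs(0))]   # -> float* y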

.. autoclass:: Numel
   :members:

.. autoclass:: DimSize
   :members:

Launch geometry
"""""""""""""""

.. autoclass:: Elementwise
   :members:

.. autoclass:: Reduction
   :members:

.. autoclass:: Custom
   :members:
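
A quick sketch of the three geometries (the ``Elementwise`` lines mirror
spec values used elsewhere in this PR; the ``Reduction`` argument is
inferred from the error list below, so treat it as illustrative)::

   Elementwise(block=(256,), layout="flat")   # 1-D launch over the flattened output
   Elementwise(block=(16, 16), layout="nd")   # block rank must match the output rank
   Reduction(reduce_dims=(1,))                # reduce over dim 1; block_size is validated
   Custom(fn=my_launch_fn)                    # fn must be callable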

One-shot hand-written entry point
---------------------------------

.. autofunction:: manual_cuda_kernel_plugin

Lower-level building blocks
---------------------------

.. autofunction:: cuda_python

.. autofunction:: custom_plugin

Spec class
^^^^^^^^^^

.. autoclass:: CudaPythonSpec
   :members:

Pointwise helpers
-----------------

.. autofunction:: pointwise_aot

.. autofunction:: pointwise_eager

Kernel signature convention
---------------------------

All entry points assume the ``__global__`` kernel takes its arguments in
the fixed order::

   (input_ptrs..., extras..., output_ptrs...)

Pointers are ``void*`` cast to the appropriate element type. Extras
follow the order declared in :attr:`KernelSpec.extras` for the
declarative path, or the order your ``aot_fn`` builds for the manual
path.
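
As a concrete illustration, a hypothetical two-input kernel with a single
``Numel`` extra lays out its parameters like this::

   extern "C" __global__ void my_add(
       const float* __restrict__ a,   // input_ptrs...
       const float* __restrict__ b,
       int n,                         // extras...
       float* __restrict__ c) {       // output_ptrs...
     int i = blockIdx.x * blockDim.x + threadIdx.x;
     if (i < n) c[i] = a[i] + b[i];
   }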

Error behavior
--------------

:func:`auto_cuda_kernel_plugin` validates the :class:`KernelSpec` at decorator
time and raises :class:`ValueError` for the common authoring mistakes:

- Empty or duplicate-named ``inputs`` / ``outputs``.
- ``ReduceDims(input_idx=N)`` or ``SameAs(input_idx=N)`` where ``N`` is
out of range.
- ``Numel`` / ``DimSize`` referencing a name that is not an input.
- ``dtype_from`` pointing at an unknown input.
- ``Elementwise(layout='flat')`` with a multi-dimensional block tuple.
- Invalid block sizes (including ``block_size`` in :class:`Reduction`) or a
  non-callable :attr:`Custom.fn`.

Shape-dependent errors — for example
``Elementwise(layout='nd', block=(16, 16))`` invoked against a 1-D
output — are raised at launch time with a descriptive ``ValueError``,
because the offending ranks are only known once concrete tensors arrive.
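
For instance, a duplicate input name is rejected before any kernel is
compiled (a minimal sketch; the op name and trivial kernel source are
placeholders)::

   import torch_tensorrt.annotation as tta

   try:
       tta.auto_cuda_kernel_plugin(
           "ann_ex::bad",
           tta.KernelSpec(
               kernel_source='extern "C" __global__ void k() {}',
               kernel_name="k",
               inputs=[tta.InputDecl("x"), tta.InputDecl("x")],  # duplicate name
               outputs=[tta.OutputDecl("y", shape=tta.SameAs(0))],
               geometry=tta.Elementwise(block=(256,), layout="flat"),
           ),
       )
   except ValueError as err:
       print("Rejected at decorator time:", err)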
1 change: 1 addition & 0 deletions docsrc/py_api/index.rst
@@ -13,6 +13,7 @@ Core
dynamo
logging
runtime
annotation
../cli/torchtrtc
../indices/supported_ops

120 changes: 120 additions & 0 deletions examples/dynamo/auto_cuda_kernel_plugin_annotation.py
@@ -0,0 +1,120 @@
"""
.. _auto_cuda_kernel_plugin_annotation:

Declarative Custom Kernel via ``torch_tensorrt.annotation.auto_cuda_kernel_plugin``
====================================================================================

``auto_cuda_kernel_plugin`` is the recommended entry point for exposing a
hand-written CUDA C++ kernel to both PyTorch eager and the Torch-TensorRT
compile path.

You describe the kernel with a :class:`KernelSpec` dataclass (inputs, outputs,
extras, launch geometry) and the framework derives the meta function, the
eager CUDA launch, the TensorRT AOT implementation, and the PyTorch op schema
— no hand-written ``aot_fn`` / ``eager_fn`` required.

Use ``auto_cuda_kernel_plugin`` whenever your kernel follows the standard
calling convention ``(input_ptrs..., extras..., output_ptrs...)`` and fits one
of the built-in geometries: ``Elementwise(layout="flat" | "nd")`` or
``Reduction(reduce_dims=...)``.

For shape-changing kernels or anything outside that envelope, drop down to
:func:`torch_tensorrt.annotation.manual_cuda_kernel_plugin` (see
``manual_cuda_kernel_plugin_annotation.py``).
"""

import sys

import torch

import torch_tensorrt

if not torch_tensorrt.ENABLED_FEATURES.qdp_plugin:
    print(
        "[auto_cuda_kernel_plugin_annotation] Skipping example: "
        "torch_tensorrt.annotation requires TensorRT QDP plugin support."
    )
    sys.exit(0)

try:
    import tensorrt.plugin  # noqa: F401
except ImportError:
    print(
        "[auto_cuda_kernel_plugin_annotation] Skipping example: "
        "tensorrt.plugin unavailable."
    )
    sys.exit(0)

try:
    import cuda.core  # noqa: F401
except ImportError:
    try:
        import cuda.core.experimental  # noqa: F401
    except ImportError:
        print(
            "[auto_cuda_kernel_plugin_annotation] Skipping example: cuda-python "
            "is not installed. Install with `pip install cuda-python`."
        )
        sys.exit(0)

import torch_tensorrt.annotation as tta

# Calling convention expected by auto_cuda_kernel_plugin:
# (input_ptrs..., extras..., output_ptrs...)

CU_SIGMOID = """
extern "C" __global__ void my_sigmoid(
    const float* __restrict__ x, int n, float* __restrict__ y) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) y[i] = 1.0f / (1.0f + __expf(-x[i]));
}
"""


# SameAs(0)          -> the output has the same shape and dtype as input 0.
# Numel("x")         -> passes x.numel() to the kernel as an int extra.
# Elementwise(flat)  -> 1-D launch over the flattened output; any input rank works.

tta.auto_cuda_kernel_plugin(
    "ann_ex::sigmoid",
    tta.KernelSpec(
        kernel_source=CU_SIGMOID,
        kernel_name="my_sigmoid",
        inputs=[tta.InputDecl("x")],
        outputs=[tta.OutputDecl("y", shape=tta.SameAs(0))],
        extras=[tta.Numel("x")],
        geometry=tta.Elementwise(block=(256,), layout="flat"),
    ),
    supports_dynamic_shapes=True,
)


class Model(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.ops.ann_ex.sigmoid(x)


if __name__ == "__main__":
    x = torch.randn(4, 128, device="cuda", dtype=torch.float32)
    ref = torch.sigmoid(x)

    model = Model().cuda().eval()
    eager_out = model(x)
    print(
        "Eager result matches torch.sigmoid:", torch.allclose(eager_out, ref, atol=1e-4)
    )

    print("Compiling with Torch-TensorRT...")
    trt_model = torch_tensorrt.compile(
        model,
        inputs=[x],
        enabled_precisions={torch.float32},
        min_block_size=1,
    )

    with torch.no_grad():
        for _ in range(5):
            out = trt_model(x)
            assert torch.allclose(out, ref, atol=1e-2, rtol=1e-2), "Mismatch!"

    print("TRT inference successful - results match torch.sigmoid")