Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 18 additions & 3 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,18 +41,33 @@ endif()

include_directories(${CMAKE_CURRENT_SOURCE_DIR}/src)

# NVIDIA and Iluvatar are parallel backends; only one GPU backend at a time.
# Exactly one GPU backend may be selected; abort the configure step otherwise.
if(WITH_ILUVATAR AND WITH_NVIDIA)
message(FATAL_ERROR "WITH_NVIDIA and WITH_ILUVATAR cannot both be ON. Build one GPU backend at a time.")
endif()

if(WITH_NVIDIA)
# Make WITH_NVIDIA=1 visible to every translation unit.
add_compile_definitions(WITH_NVIDIA=1)
# Register CUDA as a project language, then locate the CUDA toolkit (required).
enable_language(CUDA)
find_package(CUDAToolkit REQUIRED)
endif()

# Iluvatar: CUDA-compatible device, uses clang++ with -x ivcore (not nvcc).
# Reference: InfiniCore xmake/iluvatar.lua
if(WITH_ILUVATAR)
add_compile_definitions(WITH_ILUVATAR=1)
if(NOT WITH_NVIDIA)
enable_language(CUDA)
find_package(CUDAToolkit REQUIRED)
set(ILUVATAR_ARCH "ivcore20" CACHE STRING "Iluvatar GPU architecture")
find_program(CLANGXX NAMES clang++)
if(CLANGXX)
set(CMAKE_CUDA_COMPILER "${CLANGXX}" CACHE STRING "Iluvatar CUDA compiler (clang++)")
else()
set(CMAKE_CUDA_COMPILER "clang++" CACHE STRING "Iluvatar CUDA compiler (clang++)")
endif()
set(CMAKE_CUDA_FLAGS "-x ivcore -std=c++17 --cuda-gpu-arch=${ILUVATAR_ARCH} -fPIC -Wno-error=unused-variable -Wno-error=unused-private-field -Wno-unused-variable" CACHE STRING "Iluvatar CUDA flags")
set(CMAKE_CUDA_SEPARABLE_COMPILATION OFF CACHE BOOL "Disable RDC for Iluvatar")
message(STATUS "Iluvatar: CUDA compiler ${CMAKE_CUDA_COMPILER}, arch ${ILUVATAR_ARCH}")
enable_language(CUDA)
find_package(CUDAToolkit REQUIRED)
endif()

if(WITH_METAX)
Expand Down
133 changes: 94 additions & 39 deletions scripts/generate_wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,27 +72,25 @@
"""


class _OperatorExtractor:
def __call__(self, op_name):
def _get_system_include_flags():
system_include_flags = []

for line in subprocess.getoutput(
"clang++ -E -x c++ -v /dev/null"
).splitlines():
if not line.startswith(" "):
continue

system_include_flags.append("-isystem")
system_include_flags.append(line.strip())
def _get_system_include_flags():
    """Return clang++'s builtin system include directories as -isystem flags.

    Runs ``clang++ -E -x c++ -v /dev/null`` and collects the search-path
    lines (clang prints each with a leading space) into alternating
    ``-isystem <dir>`` entries.
    """
    flags = []
    output = subprocess.getoutput("clang++ -E -x c++ -v /dev/null")
    for raw_line in output.splitlines():
        # Search-path entries are the lines clang indents with a space.
        if raw_line.startswith(" "):
            flags.append("-isystem")
            flags.append(raw_line.strip())
    return flags

return system_include_flags

class _OperatorExtractor:
def __call__(self, op_name, base_stem=None):
system_include_flags = _get_system_include_flags()

index = clang.cindex.Index.create()
args = ("-std=c++17", "-x", "c++", "-I", "src") + tuple(system_include_flags)
translation_unit = index.parse(f"src/base/{op_name.lower()}.h", args=args)
header = f"src/base/{(base_stem or op_name.lower())}.h"
translation_unit = index.parse(header, args=args)

nodes = tuple(type(self)._find(translation_unit.cursor, op_name))

Expand All @@ -105,7 +103,8 @@ def _get_system_include_flags():
elif node.kind == CursorKind.CXX_METHOD and node.spelling == "operator()":
calls.append(node)

return _Operator(op_name, constructors, calls)
header_name = base_stem if base_stem is not None else op_name.lower()
return _Operator(op_name, constructors, calls, header_name=header_name)

@staticmethod
def _find(node, op_name):
Expand All @@ -117,12 +116,34 @@ def _find(node, op_name):


class _Operator:
def __init__(self, name, constructors, calls):
def __init__(self, name, constructors, calls, header_name=None):
    """Store an operator's name, ctor/call cursors, and base-header stem.

    header_name defaults to the lowercased operator name when not given,
    matching the ``src/base/<name>.h`` naming convention.
    """
    self.name = name
    self.constructors = constructors
    self.calls = calls
    self.header_name = name.lower() if header_name is None else header_name


def _make_mock_node(params):
"""Create a mock node with get_arguments() for manual operator specs."""

class _Type:
def __init__(self, spelling):
self.spelling = spelling

class _Arg:
def __init__(self, type_spelling, name):
self.type = _Type(type_spelling)
self.spelling = name

class _MockNode:
def get_arguments(self):
return [_Arg(typ, name) for typ, name in params]

return _MockNode()


# Operators that fail libclang parse; provide manual spec for wrapper generation.
_MANUAL_OP_SPECS = {}


def _generate_pybind11(operator):
Expand All @@ -135,6 +156,8 @@ def _generate_params(node):
)
.replace("const Tensor", "py::object")
.replace("Tensor", "py::object")
.replace("std::optional<float>", "float")
.replace("std::optional<int>", "bool")
)

def _generate_arguments(node):
Expand Down Expand Up @@ -173,7 +196,8 @@ def _generate_call(op_name, call, method=True):
)
calls = "\n".join(_generate_call(operator.name, call) for call in operator.calls)
callers = "\n".join(
_generate_call(operator.name, call, method=False) for call in operator.calls
_generate_call(operator.header_name, call, method=False)
for call in operator.calls
)

return f"""#ifndef INFINI_OPS_BINDINGS_{op_name.upper()}_H_
Expand All @@ -182,7 +206,7 @@ def _generate_call(op_name, call, method=True):
#include <pybind11/pybind11.h>
#include <pybind11/stl.h>

#include "base/{op_name.lower()}.h"
#include "base/{operator.header_name}.h"
#include "utils.h"

namespace py = pybind11;
Expand Down Expand Up @@ -213,7 +237,7 @@ def _generate_source(operator):

return f"""#include "../../handle.h"
#include "../../tensor.h"
#include "infiniop/ops/{operator.name.lower()}.h"
#include "infiniop/ops/{operator.header_name}.h"
{impl_includes}

static infini::ops::DataType DataTypeFromInfiniDType(
Expand Down Expand Up @@ -270,7 +294,7 @@ def _generate_header(operator):
return f"""#ifndef __INFINIOP_{operator.name.upper()}_API_H__
#define __INFINIOP_{operator.name.upper()}_API_H__

#include "base/{operator.name.lower()}.h"
#include "base/{operator.header_name}.h"

typedef struct infini::ops::Operator<infini::ops::{operator.name}> *infiniop{operator.name}Descriptor_t;

Expand Down Expand Up @@ -382,20 +406,21 @@ def _generate_tensor_caster(name, is_data=False):
def _get_all_ops(devices):
ops = {}

for file_path in _BASE_DIR.iterdir():
if not file_path.is_file():
for base_file in _BASE_DIR.iterdir():
if not base_file.is_file():
continue

op_name = "".join(word.capitalize() for word in file_path.stem.split("_"))

ops[op_name] = []
op_name = "".join(word.capitalize() for word in base_file.stem.split("_"))
impl_paths = []

for file_path in _SRC_DIR.rglob("*"):
if not file_path.is_file() or file_path.parent.parent.name not in devices:
for impl_path in _SRC_DIR.rglob("*"):
if not impl_path.is_file() or impl_path.parent.parent.name not in devices:
continue

if f"class Operator<{op_name}" in file_path.read_text():
ops[op_name].append(file_path)
if f"class Operator<{op_name}" in impl_path.read_text():
impl_paths.append(impl_path)

ops[op_name] = {"base_stem": base_file.stem, "impl_paths": impl_paths}

return ops

Expand Down Expand Up @@ -429,12 +454,37 @@ def _get_all_ops(devices):

(_BINDINGS_DIR / "utils.h").write_text(_UTILS_H_CONTENT)

for op_name, impl_paths in ops.items():
extractor = _OperatorExtractor()
operator = extractor(op_name)
valid_ops = {}
for op_name, op_data in ops.items():
base_stem = op_data.get("base_stem") if isinstance(op_data, dict) else None
impl_paths = (
op_data.get("impl_paths", op_data)
if isinstance(op_data, dict)
else op_data
)

operator = None
if op_name in _MANUAL_OP_SPECS:
spec = _MANUAL_OP_SPECS[op_name]
operator = _Operator(
op_name,
constructors=[_make_mock_node(spec["constructor"])],
calls=[_make_mock_node(spec["call"])],
header_name=spec.get("header"),
)
else:
extractor = _OperatorExtractor()
try:
operator = extractor(op_name, base_stem=base_stem)
except clang.cindex.TranslationUnitLoadError as e:
print(
f"Warning: Skipping {op_name} - failed to parse base header: {e}"
)
continue

valid_ops[op_name] = impl_paths
source_path = _GENERATED_SRC_DIR / op_name.lower()
header_name = f"{op_name.lower()}.h"
header_name = f"{operator.header_name}.h"
bind_func_name = f"Bind{op_name}"

(_BINDINGS_DIR / header_name).write_text(_generate_pybind11(operator))
Expand All @@ -451,15 +501,20 @@ def _get_all_ops(devices):

impl_includes = "\n".join(
f'#include "{impl_path}"'
for impl_paths in ops.values()
for impl_paths in valid_ops.values()
for impl_path in impl_paths
)
op_includes = "\n".join(f'#include "{header_path}"' for header_path in header_paths)
bind_func_calls = "\n".join(
f"{bind_func_name}(m);" for bind_func_name in bind_func_names
)

(_BINDINGS_DIR / "ops.cc").write_text(f"""#include <pybind11/pybind11.h>
has_cuda_impl = any(
str(p).endswith(".cu") for impls in valid_ops.values() for p in impls
)
ops_source = "ops.cu" if has_cuda_impl else "ops.cc"

(_BINDINGS_DIR / ops_source).write_text(f"""#include <pybind11/pybind11.h>

// clang-format off
{impl_includes}
Expand Down
8 changes: 7 additions & 1 deletion src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,11 @@ if(WITH_NVIDIA)
target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas CUDA::cuda_driver)

list(APPEND DEVICE_LIST "nvidia")
set_target_properties(infiniops PROPERTIES CUDA_STANDARD 17
CUDA_STANDARD_REQUIRED ON)
endif()

# Iluvatar: CUDA-compatible device; -x ivcore and flags from top-level CMakeLists.txt
if(WITH_ILUVATAR)
set(ILUVATAR_PATTERNS
"cuda/*.cc"
Expand All @@ -65,6 +68,9 @@ if(WITH_ILUVATAR)
find_package(CUDAToolkit REQUIRED)
target_link_libraries(infiniops PUBLIC CUDA::cudart CUDA::cublas CUDA::cuda_driver)

set_target_properties(infiniops PROPERTIES CUDA_STANDARD 17
CUDA_STANDARD_REQUIRED ON)

list(APPEND DEVICE_LIST "iluvatar")
endif()

Expand Down Expand Up @@ -112,7 +118,7 @@ if(GENERATE_PYTHON_BINDINGS)
set(PYBIND11_SOURCES "${PROJECT_SOURCE_DIR}/generated/bindings/ops.cc")

# TODO: There might be a better solution.
if(WITH_NVIDIA)
if(WITH_NVIDIA OR WITH_ILUVATAR)
set_source_files_properties(${PYBIND11_SOURCES} PROPERTIES LANGUAGE CUDA)
endif()

Expand Down
15 changes: 15 additions & 0 deletions src/base/gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,12 @@ class Gemm : public Operator<Gemm> {
// TODO: Check constraints.
}

// Convenience constructor taking plain float/bool parameters — presumably for
// the generated pybind11 wrappers (TODO confirm against the generator).
// Forwards to the std::optional-based constructor, mapping each bool transpose
// flag to std::optional<int>(0 or 1).
Gemm(const Tensor a, const Tensor b, float alpha, float beta, bool trans_a,
bool trans_b, Tensor c)
: Gemm{a, b, std::optional<float>(alpha), std::optional<float>(beta),
std::optional<int>(static_cast<int>(trans_a)),
std::optional<int>(static_cast<int>(trans_b)), c} {}

// Default constructor variant: no alpha/beta scaling and no transpose flags
// (all forwarded as std::nullopt).
Gemm(const Tensor a, const Tensor b, Tensor c)
: Gemm{a, b, std::nullopt, std::nullopt, std::nullopt, std::nullopt, c} {}

Expand All @@ -44,6 +50,15 @@ class Gemm : public Operator<Gemm> {
std::optional<int> trans_a,
std::optional<int> trans_b, Tensor c) const = 0;

// Non-optional overload: wraps alpha/beta into std::optional<float> and the
// bool transpose flags into std::optional<int>, then delegates to the pure
// virtual optional-based operator() implemented by each device backend.
virtual void operator()(void* stream, const Tensor a, const Tensor b,
float alpha, float beta, bool trans_a, bool trans_b,
Tensor c) const {
return operator()(stream, a, b, std::optional<float>(alpha),
std::optional<float>(beta),
std::optional<int>(static_cast<int>(trans_a)),
std::optional<int>(static_cast<int>(trans_b)), c);
}

virtual void operator()(void* stream, const Tensor a, const Tensor b,
Tensor c) const {
return operator()(stream, a, b, std::nullopt, std::nullopt, std::nullopt,
Expand Down
58 changes: 58 additions & 0 deletions src/base/rms_norm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
#ifndef INFINI_OPS_BASE_RMS_NORM_H_
#define INFINI_OPS_BASE_RMS_NORM_H_

#include <cstddef>
#include <vector>

#include "operator.h"
#include "tensor.h"

namespace infini::ops {

// RMS normalization operator: caches shape/stride metadata of the output y
// and input x at construction; device backends implement the pure virtual
// operator(). The weight tensor w is accepted but not cached here —
// presumably consumed directly by the device implementations (verify).
class RmsNorm : public Operator<RmsNorm> {
public:
// NOTE(review): the ternaries below only distinguish ndim == 2 from
// "anything else"; assumes y is rank 2 or 3 — TODO confirm callers.
// Member initialization follows declaration order, so ndim_ is set before
// batch_size_ and nhead_, which read it.
RmsNorm(const Tensor y, const Tensor x, const Tensor w, float epsilon)
: epsilon_{epsilon},
y_shape_{y.shape()},
x_shape_{x.shape()},
y_strides_{y.strides()},
x_strides_{x.strides()},
dim_{y.size(-1)},
ndim_{y.ndim()},
batch_size_{ndim_ == 2 ? y.size(-2) : y.size(-3)},
nhead_{ndim_ == 2 ? 1 : y.size(-2)} {}

// Convenience constructor using the default epsilon of 1e-6.
RmsNorm(const Tensor y, const Tensor x, const Tensor w)
: RmsNorm{y, x, w, 1e-6f} {}

// Device backends implement the explicit-epsilon call.
virtual void operator()(void* stream, Tensor y, const Tensor x,
const Tensor w, float epsilon) const = 0;

// Epsilon-less overload: delegates with the epsilon captured at construction.
virtual void operator()(void* stream, Tensor y, const Tensor x,
const Tensor w) const {
return operator()(stream, y, x, w, epsilon_);
}

protected:
// Numerical-stability term added inside the RMS denominator.
float epsilon_{1e-6f};

Tensor::Shape y_shape_;

Tensor::Shape x_shape_;

Tensor::Strides y_strides_;

Tensor::Strides x_strides_;

// Size of the innermost (normalized) dimension of y.
Tensor::Size dim_{0};

// Rank of y.
Tensor::Size ndim_{0};

// Leading batch dimension of y (dim -2 when rank 2, else dim -3).
Tensor::Size batch_size_{0};

// Head count: 1 for rank-2 y, otherwise y's dim -2.
Tensor::Size nhead_{1};
};

} // namespace infini::ops

#endif
2 changes: 2 additions & 0 deletions src/common/cuda/kernel_commons.h
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

// Select the device runtime header. NVIDIA and Iluvatar both use the CUDA
// runtime API (Iluvatar is CUDA-compatible); Metax ships its own runtime.
// Merged the two identical cuda_runtime.h branches and test every macro with
// defined() for consistency (the original mixed #ifdef, #elif defined(), and
// a bare #elif WITH_METAX).
#if defined(WITH_NVIDIA) || defined(WITH_ILUVATAR)
#include <cuda_runtime.h>
#elif defined(WITH_METAX)
#include <mcr/mc_runtime.h>
#endif
Expand Down
Loading