
Commit f44082c

Integrate CUDA sampler into ASR runner and enable skip_copy for decoder
- cuda_backend.cpp: Support comma-separated method names in the skip_copy_output_to_cpu_for_method backend option.
- runner.cpp: Use CudaSampler for argmax when CUDA is available and temperature == 0. Skip the copy to CPU for both the encoder and decoder methods.
- CMakeLists.txt updates: Link against the extension_llm_sampler_cuda library and include the sampler subdirectory.

This optimization keeps decoder logits on GPU and performs argmax directly on GPU memory, avoiding unnecessary device-to-host copies in the decode loop.

ghstack-source-id: e3e85c9
ghstack-comment-id: 3688378475
Pull-Request: #16388
1 parent f872869 commit f44082c
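
The core of the change is that the final argmax over the decoder logits now runs against device memory, so only the selected token index comes back to the host on each decode step. As a rough illustration only (not the CudaSampler implementation, which lives under extension/llm/sampler), a device-side argmax over float logits can be sketched with Thrust; the function name and float assumption are hypothetical:

// gpu_argmax_sketch.cu -- illustrative sketch, not the CudaSampler code.
#include <cstdint>
#include <thrust/device_ptr.h>
#include <thrust/execution_policy.h>
#include <thrust/extrema.h>

// Returns the index of the largest logit without copying the logits to host.
// Assumes `device_logits` points to `vocab_size` floats in GPU memory.
int64_t gpu_argmax(const float* device_logits, int vocab_size) {
  thrust::device_ptr<const float> begin(device_logits);
  thrust::device_ptr<const float> end = begin + vocab_size;
  // thrust::max_element runs the reduction on the GPU; only the resulting
  // iterator (i.e. the winning index) is materialized on the host.
  thrust::device_ptr<const float> max_it =
      thrust::max_element(thrust::device, begin, end);
  return static_cast<int64_t>(max_it - begin);
}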

File tree: 4 files changed, +111 / -15 lines


backends/cuda/runtime/cuda_backend.cpp

Lines changed: 30 additions & 1 deletion
@@ -83,7 +83,36 @@ class ET_EXPERIMENTAL CudaBackend final
       return false;
     }
     std::lock_guard<std::mutex> guard(skip_copy_method_mutex_);
-    return method_name == skip_copy_method_;
+    // Support comma-separated list of method names
+    if (skip_copy_method_.empty()) {
+      return false;
+    }
+    // Check if method_name matches any entry in the comma-separated list
+    size_t start = 0;
+    size_t end = skip_copy_method_.find(',');
+    while (end != std::string::npos) {
+      std::string entry = skip_copy_method_.substr(start, end - start);
+      // Trim whitespace
+      size_t entry_start = entry.find_first_not_of(" \t");
+      size_t entry_end = entry.find_last_not_of(" \t");
+      if (entry_start != std::string::npos) {
+        entry = entry.substr(entry_start, entry_end - entry_start + 1);
+        if (entry == method_name) {
+          return true;
+        }
+      }
+      start = end + 1;
+      end = skip_copy_method_.find(',', start);
+    }
+    // Check last (or only) entry
+    std::string entry = skip_copy_method_.substr(start);
+    size_t entry_start = entry.find_first_not_of(" \t");
+    size_t entry_end = entry.find_last_not_of(" \t");
+    if (entry_start != std::string::npos) {
+      entry = entry.substr(entry_start, entry_end - entry_start + 1);
+      return entry == method_name;
+    }
+    return false;
   }
 
   Error load_function_pointers_into_handle(
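
For readers skimming the hunk above: the new logic is simply "split skip_copy_method_ on commas, trim whitespace, and compare each entry against method_name". An equivalent standalone sketch is below, using a hypothetical helper name that is not part of the backend:

#include <sstream>
#include <string>

// Hypothetical helper mirroring the matching logic above: returns true if
// `method_name` appears in the comma-separated, whitespace-tolerant list.
bool matches_any_method(
    const std::string& skip_list, const std::string& method_name) {
  std::istringstream stream(skip_list);
  std::string entry;
  while (std::getline(stream, entry, ',')) {
    const size_t first = entry.find_first_not_of(" \t");
    if (first == std::string::npos) {
      continue;  // entry was empty or all whitespace
    }
    const size_t last = entry.find_last_not_of(" \t");
    if (entry.substr(first, last - first + 1) == method_name) {
      return true;
    }
  }
  return false;
}

// Example: matches_any_method("encoder, decoder", "decoder") == true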

extension/asr/runner/CMakeLists.txt

Lines changed: 9 additions & 0 deletions
@@ -42,6 +42,15 @@ if(EXECUTORCH_BUILD_CUDA)
   find_package(CUDAToolkit QUIET)
   if(CUDAToolkit_FOUND)
     target_compile_definitions(extension_asr_runner PUBLIC CUDA_AVAILABLE)
+    target_include_directories(
+      extension_asr_runner PUBLIC ${CUDAToolkit_INCLUDE_DIRS}
+    )
+    # Link against the CUDA sampler library from extension/llm/sampler
+    if(TARGET extension_llm_sampler_cuda)
+      target_link_libraries(extension_asr_runner PUBLIC extension_llm_sampler_cuda)
+    else()
+      target_link_libraries(extension_asr_runner PUBLIC CUDA::cudart)
+    endif()
     message(STATUS "CUDAToolkit found; defining CUDA_AVAILABLE for ASR runner")
   else()
     message(

extension/asr/runner/runner.cpp

Lines changed: 61 additions & 14 deletions
@@ -22,6 +22,10 @@
 #include <executorch/runtime/platform/assert.h>
 #include <executorch/runtime/platform/log.h>
 
+#ifdef CUDA_AVAILABLE
+#include <executorch/extension/llm/sampler/cuda_sampler.h>
+#endif
+
 namespace executorch::extension::asr {
 namespace {
 
@@ -110,19 +114,25 @@ Error AsrRunner::load() {
   ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kDecoderMethodName));
   decoder_method_loaded_ = true;
 #ifdef CUDA_AVAILABLE
-  executorch::runtime::BackendOptions<1> backend_options;
-  // For decoder still copy output from GPU to CPU for sampling.
-  // TODO: change sampler to use a CUDA kernel to sample and then skip copying
-  // decoder output as well
-  ET_CHECK_OK_OR_RETURN_ERROR(backend_options.set_option(
-      "skip_copy_output_to_cpu_for_method", kEncoderMethodName));
-  const auto opt_err =
-      executorch::runtime::set_option("CudaBackend", backend_options.view());
-  if (opt_err != ::executorch::runtime::Error::Ok) {
-    ET_LOG(
-        Error,
-        "Failed to set CUDA backend options: %d",
-        static_cast<int>(opt_err));
+  {
+    // Skip copying outputs to CPU for both encoder and decoder methods.
+    // Encoder output stays on GPU for the decoder to consume directly.
+    // Decoder logits stay on GPU for CUDA-based sampling (temperature=0).
+    // For temperature != 0, we fall back to CPU sampling which will require
+    // a copy, but that path is less common for ASR applications.
+    std::string skip_methods =
+        std::string(kEncoderMethodName) + "," + kDecoderMethodName;
+    executorch::runtime::BackendOptions<1> backend_options;
+    ET_CHECK_OK_OR_RETURN_ERROR(backend_options.set_option(
+        "skip_copy_output_to_cpu_for_method", skip_methods.c_str()));
+    const auto opt_err =
+        executorch::runtime::set_option("CudaBackend", backend_options.view());
+    if (opt_err != ::executorch::runtime::Error::Ok) {
+      ET_LOG(
+          Error,
+          "Failed to set CUDA backend options: %d",
+          static_cast<int>(opt_err));
+    }
   }
 #endif
   ET_CHECK_OK_OR_RETURN_ERROR(load_tokenizer());
@@ -266,6 +276,18 @@ Result<std::vector<int64_t>> AsrRunner::transcribe(
   decoder_inputs.emplace_back(decoder_input_ptr);
   decoder_inputs.emplace_back(encoder_output_ptr);
   decoder_inputs.emplace_back(cache_position_ptr);
+
+#ifdef CUDA_AVAILABLE
+  // Create CUDA sampler outside the loop to avoid memory allocation overhead.
+  // Only used when temperature == 0 (argmax sampling).
+  const bool use_cuda_sampler = (config.temperature == 0.0f);
+  std::optional<::executorch::extension::llm::CudaSampler> cuda_sampler;
+  if (use_cuda_sampler) {
+    cuda_sampler.emplace();
+    ET_LOG(Info, "Using CUDA sampler for argmax sampling");
+  }
+#endif
+
   // Add some green coloring for the first generated token
   // token_callback("\033[1;32m");
   while (generated_tokens < config.max_new_tokens) {
@@ -286,9 +308,34 @@ Result<std::vector<int64_t>> AsrRunner::transcribe(
     ET_CHECK_OR_RETURN_ERROR(
         vocab_size > 0, Internal, "Decoder logits tensor is empty.");
 
-    const int64_t next_token =
+    int64_t next_token;
+#ifdef CUDA_AVAILABLE
+    if (use_cuda_sampler && cuda_sampler.has_value()) {
+      // Use CUDA-based argmax sampling - logits are already on GPU
+      next_token = static_cast<int64_t>(cuda_sampler->sample_argmax(
+          logits_tensor.const_data_ptr(),
+          static_cast<int>(vocab_size),
+          logits_tensor.scalar_type()));
+      ET_CHECK_OR_RETURN_ERROR(
+          next_token >= 0,
+          Internal,
+          "CUDA sampler failed to sample token");
+    } else {
+      // Fall back to CPU sampling for temperature != 0
+      // Note: This requires the logits to be copied to CPU, which happens
+      // automatically when skip_copy_output_to_cpu_for_method doesn't include
+      // the decoder method. Since we include decoder in the skip list, we need
+      // to handle this case differently in the future if we want to support
+      // temperature != 0 with CUDA.
+      next_token =
+          static_cast<int64_t>(::executorch::extension::llm::logits_to_token(
+              logits_tensor, config.temperature));
+    }
+#else
+    next_token =
         static_cast<int64_t>(::executorch::extension::llm::logits_to_token(
             logits_tensor, config.temperature));
+#endif
 
     if (!first_token_generated) {
       stats_.first_token_ms = ::executorch::extension::llm::time_in_ms();
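
The fallback branch above notes that temperature != 0 sampling still expects host-resident logits even though the decoder is now in the skip-copy list. One possible way to support that path later, sketched here under the assumption that the logits are float32 and that the raw data pointer is a CUDA device pointer in this configuration, is an explicit device-to-host copy before calling the existing CPU sampler:

#include <cuda_runtime.h>
#include <vector>

// Sketch only: copy `vocab_size` float logits from GPU memory into a host
// buffer so an existing CPU sampler can consume them. Error handling elided.
std::vector<float> copy_logits_to_host(const void* device_logits, int vocab_size) {
  std::vector<float> host_logits(static_cast<size_t>(vocab_size));
  cudaMemcpy(
      host_logits.data(),
      device_logits,
      host_logits.size() * sizeof(float),
      cudaMemcpyDeviceToHost);  // blocks until the copy completes
  return host_logits;
}

This would simply re-create, on demand, the copy the backend previously performed automatically when the decoder method was not in the skip list.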

extension/llm/runner/CMakeLists.txt

Lines changed: 11 additions & 0 deletions
@@ -66,6 +66,17 @@ if(EXECUTORCH_BUILD_CUDA)
     target_compile_definitions(extension_llm_runner PUBLIC CUDA_AVAILABLE)
     target_link_libraries(extension_llm_runner PUBLIC CUDA::cudart)
     message(STATUS "CUDAToolkit found; defining CUDA_AVAILABLE")
+
+    # Build the CUDA sampler library
+    if(NOT TARGET extension_llm_sampler_cuda)
+      add_subdirectory(
+        ${EXECUTORCH_ROOT}/extension/llm/sampler
+        ${CMAKE_CURRENT_BINARY_DIR}/sampler
+      )
+    endif()
+    if(TARGET extension_llm_sampler_cuda)
+      target_link_libraries(extension_llm_runner PUBLIC extension_llm_sampler_cuda)
+    endif()
   else()
     message(
       STATUS
