 #include <executorch/runtime/platform/assert.h>
 #include <executorch/runtime/platform/log.h>
 
+#ifdef CUDA_AVAILABLE
+#include <executorch/extension/llm/sampler/cuda_sampler.h>
+#endif
+
 namespace executorch::extension::asr {
 namespace {
 
@@ -110,19 +114,25 @@ Error AsrRunner::load() {
   ET_CHECK_OK_OR_RETURN_ERROR(module_->load_method(kDecoderMethodName));
   decoder_method_loaded_ = true;
 #ifdef CUDA_AVAILABLE
-  executorch::runtime::BackendOptions<1> backend_options;
-  // For decoder still copy output from GPU to CPU for sampling.
-  // TODO: change sampler to use a CUDA kernel to sample and then skip copying
-  // decoder output as well
-  ET_CHECK_OK_OR_RETURN_ERROR(backend_options.set_option(
-      "skip_copy_output_to_cpu_for_method", kEncoderMethodName));
-  const auto opt_err =
-      executorch::runtime::set_option("CudaBackend", backend_options.view());
-  if (opt_err != ::executorch::runtime::Error::Ok) {
-    ET_LOG(
-        Error,
-        "Failed to set CUDA backend options: %d",
-        static_cast<int>(opt_err));
+  {
+    // Skip copying outputs to CPU for both the encoder and decoder methods.
+    // The encoder output stays on the GPU for the decoder to consume directly.
+    // Decoder logits stay on the GPU for CUDA-based sampling (temperature == 0).
+    // For temperature != 0 we fall back to CPU sampling, which requires a
+    // copy, but that path is less common for ASR applications.
+    std::string skip_methods =
+        std::string(kEncoderMethodName) + "," + kDecoderMethodName;
+    executorch::runtime::BackendOptions<1> backend_options;
+    ET_CHECK_OK_OR_RETURN_ERROR(backend_options.set_option(
+        "skip_copy_output_to_cpu_for_method", skip_methods.c_str()));
+    const auto opt_err =
+        executorch::runtime::set_option("CudaBackend", backend_options.view());
+    if (opt_err != ::executorch::runtime::Error::Ok) {
+      ET_LOG(
+          Error,
+          "Failed to set CUDA backend options: %d",
+          static_cast<int>(opt_err));
+    }
   }
 #endif
   ET_CHECK_OK_OR_RETURN_ERROR(load_tokenizer());
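
The new code joins the encoder and decoder method names with a comma, which implies the CudaBackend splits the `skip_copy_output_to_cpu_for_method` value into a per-method list. For illustration only, a splitter along these lines would do (`parse_method_list` is a hypothetical helper, not the actual backend code):

```cpp
#include <sstream>
#include <string>
#include <vector>

// Hypothetical helper, not actual CudaBackend code: split a comma-separated
// option value such as "encoder,decoder" into individual method names.
std::vector<std::string> parse_method_list(const std::string& csv) {
  std::vector<std::string> methods;
  std::stringstream ss(csv);
  std::string method;
  while (std::getline(ss, method, ',')) {
    if (!method.empty()) {
      methods.push_back(method);
    }
  }
  return methods;
}
```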
@@ -266,6 +276,18 @@ Result<std::vector<int64_t>> AsrRunner::transcribe(
   decoder_inputs.emplace_back(decoder_input_ptr);
   decoder_inputs.emplace_back(encoder_output_ptr);
   decoder_inputs.emplace_back(cache_position_ptr);
+
+#ifdef CUDA_AVAILABLE
+  // Create the CUDA sampler outside the loop to avoid per-token allocation
+  // overhead. It is only used when temperature == 0 (argmax sampling).
+  const bool use_cuda_sampler = (config.temperature == 0.0f);
+  std::optional<::executorch::extension::llm::CudaSampler> cuda_sampler;
+  if (use_cuda_sampler) {
+    cuda_sampler.emplace();
+    ET_LOG(Info, "Using CUDA sampler for argmax sampling");
+  }
+#endif
+
   // Add some green coloring for the first generated token
   // token_callback("\033[1;32m");
   while (generated_tokens < config.max_new_tokens) {
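
`cuda_sampler.h` is not shown in this diff; judging from the call sites (a default constructor plus `sample_argmax` taking a device pointer, vocab size, and dtype, returning a negative value on failure), the interface presumably resembles the sketch below. Everything here is inferred, including the `ScalarType` spelling:

```cpp
// Inferred sketch of the CudaSampler interface; the real header may differ.
namespace executorch::extension::llm {

class CudaSampler {
 public:
  CudaSampler(); // presumably allocates reusable device scratch buffers
  ~CudaSampler();

  // Returns the argmax index over vocab_size logits resident on the GPU,
  // or a negative value on failure.
  int sample_argmax(
      const void* gpu_logits,
      int vocab_size,
      ::executorch::aten::ScalarType dtype);
};

} // namespace executorch::extension::llm
```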
@@ -286,9 +308,34 @@ Result<std::vector<int64_t>> AsrRunner::transcribe(
     ET_CHECK_OR_RETURN_ERROR(
         vocab_size > 0, Internal, "Decoder logits tensor is empty.");
 
-    const int64_t next_token =
+    int64_t next_token;
+#ifdef CUDA_AVAILABLE
+    if (use_cuda_sampler && cuda_sampler.has_value()) {
+      // Use CUDA-based argmax sampling; the logits are already on the GPU.
+      next_token = static_cast<int64_t>(cuda_sampler->sample_argmax(
+          logits_tensor.const_data_ptr(),
+          static_cast<int>(vocab_size),
+          logits_tensor.scalar_type()));
+      ET_CHECK_OR_RETURN_ERROR(
+          next_token >= 0,
+          Internal,
+          "CUDA sampler failed to sample token");
+    } else {
+      // Fall back to CPU sampling for temperature != 0.
+      // Note: this requires the logits to be copied to the CPU, which happens
+      // automatically when skip_copy_output_to_cpu_for_method does not include
+      // the decoder method. Since we do include the decoder in the skip list,
+      // this case will need different handling in the future if we want to
+      // support temperature != 0 with CUDA.
+      next_token =
+          static_cast<int64_t>(::executorch::extension::llm::logits_to_token(
+              logits_tensor, config.temperature));
+    }
+#else
+    next_token =
         static_cast<int64_t>(::executorch::extension::llm::logits_to_token(
             logits_tensor, config.temperature));
+#endif
 
     if (!first_token_generated) {
       stats_.first_token_ms = ::executorch::extension::llm::time_in_ms();
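
The old TODO ("change sampler to use a CUDA kernel to sample") is what this change delivers: only the winning token index crosses back to the host instead of a full vocab-size logits row. Below is a minimal sketch of how a device-side argmax can be done with CUB, assuming float32 logits; the actual CudaSampler implementation is not shown here and may differ.

```cpp
#include <cub/device/device_reduce.cuh>
#include <cuda_runtime.h>

// Sketch only: device-side argmax over float logits via CUB.
int64_t argmax_on_device(const float* d_logits, int vocab_size) {
  cub::KeyValuePair<int, float>* d_out = nullptr;
  cudaMalloc(&d_out, sizeof(*d_out));

  // CUB's two-phase pattern: the first call computes the scratch size,
  // the second performs the reduction.
  void* d_temp = nullptr;
  size_t temp_bytes = 0;
  cub::DeviceReduce::ArgMax(d_temp, temp_bytes, d_logits, d_out, vocab_size);
  cudaMalloc(&d_temp, temp_bytes);
  cub::DeviceReduce::ArgMax(d_temp, temp_bytes, d_logits, d_out, vocab_size);

  // Copy back only the (index, value) pair, not the logits row.
  cub::KeyValuePair<int, float> h_out;
  cudaMemcpy(&h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  cudaFree(d_temp);
  cudaFree(d_out);
  return static_cast<int64_t>(h_out.key);
}
```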