Skip to content

Commit f72c3ce

Browse files
committed
openaiv2: clarify streaming contract and test corner cases
1 parent 2bd265b commit f72c3ce

File tree

4 files changed

+679
-292
lines changed

4 files changed

+679
-292
lines changed

instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/patch.py

Lines changed: 13 additions & 292 deletions
Original file line numberDiff line numberDiff line change
@@ -13,24 +13,28 @@
1313
# limitations under the License.
1414

1515

16-
import asyncio
17-
import inspect
18-
from collections.abc import AsyncIterator, Iterator
1916
from timeit import default_timer
20-
from typing import Any, Optional, cast
17+
from typing import Any, Optional
2118

22-
from opentelemetry._logs import Logger, LogRecord
23-
from opentelemetry.context import get_current
19+
from opentelemetry._logs import Logger
2420
from opentelemetry.semconv._incubating.attributes import (
2521
gen_ai_attributes as GenAIAttributes,
2622
)
2723
from opentelemetry.semconv._incubating.attributes import (
2824
server_attributes as ServerAttributes,
2925
)
3026
from opentelemetry.trace import Span, SpanKind, Tracer
31-
from opentelemetry.trace.propagation import set_span_in_context
3227

3328
from .instruments import Instruments
29+
from .streaming import (
30+
AsyncStreamWrapper,
31+
StreamWrapper,
32+
SyncStreamWrapper,
33+
)
34+
35+
# Re-export StreamWrapper for backwards compatibility
36+
__all__ = ["StreamWrapper"]
37+
3438
from .utils import (
3539
choice_to_event,
3640
get_llm_request_attributes,
@@ -68,7 +72,7 @@ def traced_method(wrapped, instance, args, kwargs):
6872
try:
6973
result = wrapped(*args, **kwargs)
7074
if is_streaming(kwargs):
71-
return StreamWrapper(result, span, logger, capture_content)
75+
return SyncStreamWrapper(result, span, logger, capture_content)
7276

7377
if span.is_recording():
7478
_set_response_attributes(
@@ -125,7 +129,7 @@ async def traced_method(wrapped, instance, args, kwargs):
125129
try:
126130
result = await wrapped(*args, **kwargs)
127131
if is_streaming(kwargs):
128-
return StreamWrapper(result, span, logger, capture_content)
132+
return AsyncStreamWrapper(result, span, logger, capture_content)
129133

130134
if span.is_recording():
131135
_set_response_attributes(
@@ -472,286 +476,3 @@ def _set_embeddings_response_attributes(
472476
result.usage.prompt_tokens,
473477
)
474478
# Don't set output tokens for embeddings as all tokens are input tokens
475-
476-
477-
class ToolCallBuffer:
    """Accumulates the streamed fragments of a single tool call.

    A streamed tool call arrives as a sequence of deltas; this buffer
    records the call's identity once and collects the argument fragments
    so they can be joined into the full JSON arguments string later.
    """

    def __init__(self, index, tool_call_id, function_name):
        # Identity fields are fixed at creation; only arguments grow.
        self.index, self.tool_call_id, self.function_name = (
            index,
            tool_call_id,
            function_name,
        )
        self.arguments = []

    def append_arguments(self, arguments):
        """Record one streamed fragment of the function's arguments."""
        self.arguments.append(arguments)
486-
487-
488-
class ChoiceBuffer:
    """Accumulates the streamed deltas belonging to one response choice."""

    def __init__(self, index):
        self.index = index
        self.finish_reason = None
        self.text_content = []
        self.tool_calls_buffers = []

    def append_text_content(self, content):
        """Collect one streamed text fragment for later joining."""
        self.text_content.append(content)

    def append_tool_call(self, tool_call):
        """Route a streamed tool-call delta into its per-index buffer."""
        idx = tool_call.index
        # Grow the buffer list with placeholders until idx is addressable.
        while len(self.tool_calls_buffers) <= idx:
            self.tool_calls_buffers.append(None)

        if self.tool_calls_buffers[idx] is None:
            # First delta for this tool call carries its id and name.
            self.tool_calls_buffers[idx] = ToolCallBuffer(
                idx, tool_call.id, tool_call.function.name
            )
        self.tool_calls_buffers[idx].append_arguments(
            tool_call.function.arguments
        )
511-
512-
513-
class StreamWrapper:
    """Wraps a sync or async OpenAI streaming response for telemetry.

    Chunks pass through unmodified while the wrapper accumulates the
    response id/model, service tier, token usage, and per-choice text and
    tool calls. When the stream completes (or fails), the collected data
    is written onto the span, one ``gen_ai.choice`` log event is emitted
    per choice, and the span is ended.
    """

    span: Span
    response_id: Optional[str] = None
    response_model: Optional[str] = None
    service_tier: Optional[str] = None
    prompt_tokens: Optional[int] = 0
    completion_tokens: Optional[int] = 0

    def __init__(
        self,
        stream: Iterator[Any] | AsyncIterator[Any],
        span: Span,
        logger: Logger,
        capture_content: bool,
    ):
        self.stream = stream
        self.span = span
        self.choice_buffers = []
        # Bug fix: ``finish_reasons`` used to be a mutable CLASS attribute
        # ([]), shared by every instance and never populated, so the
        # GEN_AI_RESPONSE_FINISH_REASONS attribute was always emitted as
        # an empty list. Each wrapper now owns its list, filled in
        # build_streaming_response() as finish reasons stream in.
        self.finish_reasons: list = []
        self._span_started = False
        self.capture_content = capture_content

        self.logger = logger
        self.setup()

    def setup(self):
        """Mark the span as started; idempotent."""
        if not self._span_started:
            self._span_started = True

    def cleanup(self):
        """Flush accumulated data to the span and log events, then end the span.

        Safe to call multiple times; only the first call after setup()
        does any work.
        """
        if self._span_started:
            if self.span.is_recording():
                if self.response_model:
                    set_span_attribute(
                        self.span,
                        GenAIAttributes.GEN_AI_RESPONSE_MODEL,
                        self.response_model,
                    )

                if self.response_id:
                    set_span_attribute(
                        self.span,
                        GenAIAttributes.GEN_AI_RESPONSE_ID,
                        self.response_id,
                    )

                set_span_attribute(
                    self.span,
                    GenAIAttributes.GEN_AI_USAGE_INPUT_TOKENS,
                    self.prompt_tokens,
                )
                set_span_attribute(
                    self.span,
                    GenAIAttributes.GEN_AI_USAGE_OUTPUT_TOKENS,
                    self.completion_tokens,
                )

                set_span_attribute(
                    self.span,
                    GenAIAttributes.GEN_AI_OPENAI_RESPONSE_SERVICE_TIER,
                    self.service_tier,
                )

                set_span_attribute(
                    self.span,
                    GenAIAttributes.GEN_AI_RESPONSE_FINISH_REASONS,
                    self.finish_reasons,
                )

            # Emit one gen_ai.choice event per buffered choice, honoring
            # the content-capture opt-in for text and tool-call arguments.
            for idx, choice in enumerate(self.choice_buffers):
                message: dict[str, Any] = {"role": "assistant"}
                if self.capture_content and choice.text_content:
                    message["content"] = "".join(choice.text_content)
                if choice.tool_calls_buffers:
                    tool_calls = []
                    for tool_call in choice.tool_calls_buffers:
                        function = {"name": tool_call.function_name}
                        if self.capture_content:
                            function["arguments"] = "".join(
                                tool_call.arguments
                            )
                        tool_call_dict = {
                            "id": tool_call.tool_call_id,
                            "type": "function",
                            "function": function,
                        }
                        tool_calls.append(tool_call_dict)
                    message["tool_calls"] = tool_calls

                body = {
                    "index": idx,
                    # "error" marks a choice whose stream never delivered
                    # a finish reason (e.g. it was interrupted).
                    "finish_reason": choice.finish_reason or "error",
                    "message": message,
                }

                event_attributes = {
                    GenAIAttributes.GEN_AI_SYSTEM: GenAIAttributes.GenAiSystemValues.OPENAI.value
                }
                context = set_span_in_context(self.span, get_current())
                self.logger.emit(
                    LogRecord(
                        event_name="gen_ai.choice",
                        attributes=event_attributes,
                        body=body,
                        context=context,
                    )
                )

            self.span.end()
            self._span_started = False

    def __enter__(self):
        self.setup()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type is not None:
                handle_span_exception(self.span, exc_val)
        finally:
            self.cleanup()
        return False  # Propagate the exception

    async def __aenter__(self):
        self.setup()
        return self

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type is not None:
                handle_span_exception(self.span, exc_val)
        finally:
            self.cleanup()
        return False  # Propagate the exception

    def close(self):
        """Close the underlying stream (best effort) and flush telemetry."""
        try:
            close_fn = getattr(self.stream, "close", None)
            if not callable(close_fn):
                return

            close_result = close_fn()
            if inspect.isawaitable(close_result):
                try:
                    loop = asyncio.get_running_loop()
                except RuntimeError:
                    # No running loop: drive the coroutine to completion here.
                    asyncio.run(cast(Any, close_result))
                else:
                    # NOTE(review): fire-and-forget task; no reference is
                    # kept, so it may be garbage-collected before running.
                    loop.create_task(cast(Any, close_result))
        finally:
            self.cleanup()

    def __iter__(self):
        return self

    def __aiter__(self):
        return self

    def __next__(self):
        try:
            chunk = next(cast(Iterator[Any], self.stream))
            self.process_chunk(chunk)
            return chunk
        except StopIteration:
            self.cleanup()
            raise
        except Exception as error:
            handle_span_exception(self.span, error)
            self.cleanup()
            raise

    async def __anext__(self):
        try:
            chunk = await anext(cast(AsyncIterator[Any], self.stream))
            self.process_chunk(chunk)
            return chunk
        except StopAsyncIteration:
            self.cleanup()
            raise
        except Exception as error:
            handle_span_exception(self.span, error)
            self.cleanup()
            raise

    def set_response_model(self, chunk):
        """Capture the response model from the first chunk that carries one."""
        if self.response_model:
            return

        if getattr(chunk, "model", None):
            self.response_model = chunk.model

    def set_response_id(self, chunk):
        """Capture the response id from the first chunk that carries one."""
        if self.response_id:
            return

        if getattr(chunk, "id", None):
            self.response_id = chunk.id

    def set_response_service_tier(self, chunk):
        """Capture the service tier from the first chunk that carries one."""
        if self.service_tier:
            return

        if getattr(chunk, "service_tier", None):
            self.service_tier = chunk.service_tier

    def build_streaming_response(self, chunk):
        """Fold one chunk's choice deltas into the per-choice buffers."""
        if getattr(chunk, "choices", None) is None:
            return

        choices = chunk.choices
        for choice in choices:
            if not choice.delta:
                continue

            # make sure we have enough choice buffers
            for idx in range(len(self.choice_buffers), choice.index + 1):
                self.choice_buffers.append(ChoiceBuffer(idx))

            if choice.finish_reason:
                self.choice_buffers[
                    choice.index
                ].finish_reason = choice.finish_reason
                # Bug fix: also record it for the span-level
                # GEN_AI_RESPONSE_FINISH_REASONS attribute, which was
                # previously always empty.
                self.finish_reasons.append(choice.finish_reason)

            if choice.delta.content is not None:
                self.choice_buffers[choice.index].append_text_content(
                    choice.delta.content
                )

            if choice.delta.tool_calls is not None:
                for tool_call in choice.delta.tool_calls:
                    self.choice_buffers[choice.index].append_tool_call(
                        tool_call
                    )

    def set_usage(self, chunk):
        """Capture token usage when the chunk carries a usage payload."""
        if getattr(chunk, "usage", None):
            self.completion_tokens = chunk.usage.completion_tokens
            self.prompt_tokens = chunk.usage.prompt_tokens

    def process_chunk(self, chunk):
        """Apply all per-chunk accumulators to a streamed chunk."""
        self.set_response_id(chunk)
        self.set_response_model(chunk)
        self.set_response_service_tier(chunk)
        self.build_streaming_response(chunk)
        self.set_usage(chunk)

0 commit comments

Comments
 (0)