Commit 456aafc

openaiv2: clarify streaming contract and test corner cases

1 parent d87745d

File tree

4 files changed: +679 −300 lines

instrumentation-genai/opentelemetry-instrumentation-openai-v2/src/opentelemetry/instrumentation/openai_v2/patch.py

Lines changed: 13 additions & 300 deletions
@@ -13,24 +13,28 @@
 # limitations under the License.
 
 
-import asyncio
-import inspect
-from collections.abc import AsyncIterator, Iterator
 from timeit import default_timer
-from typing import Any, Optional, cast
+from typing import Any, Optional
 
-from opentelemetry._logs import Logger, LogRecord
-from opentelemetry.context import get_current
+from opentelemetry._logs import Logger
 from opentelemetry.semconv._incubating.attributes import (
     gen_ai_attributes as GenAIAttributes,
 )
 from opentelemetry.semconv._incubating.attributes import (
     server_attributes as ServerAttributes,
 )
 from opentelemetry.trace import Span, SpanKind, Tracer
-from opentelemetry.trace.propagation import set_span_in_context
 
 from .instruments import Instruments
+from .streaming import (
+    AsyncStreamWrapper,
+    StreamWrapper,
+    SyncStreamWrapper,
+)
+
+# Re-export StreamWrapper for backwards compatibility
+__all__ = ["StreamWrapper"]
+
 from .utils import (
     choice_to_event,
     get_llm_request_attributes,
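
The wrapper implementation now lives in a dedicated streaming module, split into sync and async variants behind a shared base. That module is not part of this file's diff; the following is only a minimal sketch of the presumed shape, using the three names visible in the import above:

# Hypothetical sketch -- streaming.py itself is not shown in this diff.
from typing import Any


class StreamWrapper:
    """Assumed shared base: holds the span and accumulates chunk state."""

    def __init__(self, stream: Any, span: Any, logger: Any, capture_content: bool):
        self.stream = stream
        self.span = span
        self.logger = logger
        self.capture_content = capture_content

    def process_chunk(self, chunk: Any) -> None:
        ...  # accumulate response id/model, choices, usage


class SyncStreamWrapper(StreamWrapper):
    def __iter__(self):
        return self

    def __next__(self) -> Any:
        chunk = next(self.stream)  # propagate StopIteration to the caller
        self.process_chunk(chunk)
        return chunk


class AsyncStreamWrapper(StreamWrapper):
    def __aiter__(self):
        return self

    async def __anext__(self) -> Any:
        chunk = await anext(self.stream)  # propagate StopAsyncIteration
        self.process_chunk(chunk)
        return chunk
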
@@ -73,9 +77,7 @@ def traced_method(wrapped, instance, args, kwargs):
             else:
                 parsed_result = result
             if is_streaming(kwargs):
-                return StreamWrapper(
-                    parsed_result, span, logger, capture_content
-                )
+                return SyncStreamWrapper(parsed_result, span, logger, capture_content)
 
             if span.is_recording():
                 _set_response_attributes(
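
On the sync path the caller now receives a SyncStreamWrapper, which iterates exactly like the raw stream while accumulating telemetry per chunk. A usage sketch against the standard OpenAI client (the model name is illustrative):

from openai import OpenAI

client = OpenAI()  # assumes OPENAI_API_KEY is set in the environment
stream = client.chat.completions.create(
    model="gpt-4o-mini",  # illustrative model name
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
)
# Under instrumentation, `stream` is a SyncStreamWrapper: iterating it
# yields the usual chunks while each one feeds span/log accumulation.
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="")
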
@@ -137,9 +139,7 @@ async def traced_method(wrapped, instance, args, kwargs):
             else:
                 parsed_result = result
             if is_streaming(kwargs):
-                return StreamWrapper(
-                    parsed_result, span, logger, capture_content
-                )
+                return AsyncStreamWrapper(parsed_result, span, logger, capture_content)
 
             if span.is_recording():
                 _set_response_attributes(
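
The async path mirrors this with AsyncStreamWrapper, consumed via async for. A hedged counterpart sketch:

import asyncio

from openai import AsyncOpenAI


async def main() -> None:
    client = AsyncOpenAI()  # assumes OPENAI_API_KEY is set
    stream = await client.chat.completions.create(
        model="gpt-4o-mini",  # illustrative model name
        messages=[{"role": "user", "content": "Say hi"}],
        stream=True,
    )
    # Under instrumentation this is an AsyncStreamWrapper; `async for`
    # drives the same per-chunk telemetry as the sync path.
    async for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")


asyncio.run(main())
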
@@ -487,290 +487,3 @@ def _set_embeddings_response_attributes(
         result.usage.prompt_tokens,
     )
     # Don't set output tokens for embeddings as all tokens are input tokens
-
-
-class ToolCallBuffer:
-    def __init__(self, index, tool_call_id, function_name):
-        self.index = index
-        self.function_name = function_name
-        self.tool_call_id = tool_call_id
-        self.arguments = []
-
-    def append_arguments(self, arguments):
-        self.arguments.append(arguments)
-
-
-class ChoiceBuffer:
-    def __init__(self, index):
-        self.index = index
-        self.finish_reason = None
-        self.text_content = []
-        self.tool_calls_buffers = []
-
-    def append_text_content(self, content):
-        self.text_content.append(content)
-
-    def append_tool_call(self, tool_call):
-        idx = tool_call.index
-        # make sure we have enough tool call buffers
-        for _ in range(len(self.tool_calls_buffers), idx + 1):
-            self.tool_calls_buffers.append(None)
-
-        if not self.tool_calls_buffers[idx]:
-            self.tool_calls_buffers[idx] = ToolCallBuffer(
-                idx, tool_call.id, tool_call.function.name
-            )
-        self.tool_calls_buffers[idx].append_arguments(
-            tool_call.function.arguments
-        )
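
These buffers existed to reassemble streamed deltas: tool-call arguments arrive as string fragments spread over many chunks, keyed by index, and are joined only when the span is finalized. A small illustrative run against the removed classes (the SimpleNamespace objects stand in for OpenAI delta tool calls):

from types import SimpleNamespace

buf = ChoiceBuffer(0)
# Arguments for one tool call (index 0) arrive split across two chunks:
for fragment in ('{"location": "Par', 'is"}'):
    buf.append_tool_call(
        SimpleNamespace(
            index=0,
            id="call_123",  # illustrative id
            function=SimpleNamespace(name="get_weather", arguments=fragment),
        )
    )
# Joining happens only at span finalization time:
assert "".join(buf.tool_calls_buffers[0].arguments) == '{"location": "Paris"}'
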
-
-
-class StreamWrapper:
-    span: Span
-    response_id: Optional[str] = None
-    response_model: Optional[str] = None
-    service_tier: Optional[str] = None
-    finish_reasons: list = []
-    prompt_tokens: Optional[int] = 0
-    completion_tokens: Optional[int] = 0
-
-    def __init__(
-        self,
-        stream: Iterator[Any] | AsyncIterator[Any],
-        span: Span,
-        logger: Logger,
-        capture_content: bool,
-    ):
-        self.stream = stream
-        self.span = span
-        self.choice_buffers = []
-        self._span_started = False
-        self.capture_content = capture_content
-
-        self.logger = logger
-        self.setup()
-
-    def setup(self):
-        if not self._span_started:
-            self._span_started = True
-
-    def cleanup(self):
-        if self._span_started:
-            if self.span.is_recording():
-                if self.response_model:
-                    set_span_attribute(
-                        self.span,
-                        GenAIAttributes.GEN_AI_RESPONSE_MODEL,
-                        self.response_model,
-                    )
-
-                if self.response_id:
-                    set_span_attribute(
-                        self.span,
-                        GenAIAttributes.GEN_AI_RESPONSE_ID,
-                        self.response_id,
-                    )
-
-                set_span_attribute(
-                    self.span,
-                    GenAIAttributes.GEN_AI_USAGE_INPUT_TOKENS,
-                    self.prompt_tokens,
-                )
-                set_span_attribute(
-                    self.span,
-                    GenAIAttributes.GEN_AI_USAGE_OUTPUT_TOKENS,
-                    self.completion_tokens,
-                )
-
-                set_span_attribute(
-                    self.span,
-                    GenAIAttributes.GEN_AI_OPENAI_RESPONSE_SERVICE_TIER,
-                    self.service_tier,
-                )
-
-                set_span_attribute(
-                    self.span,
-                    GenAIAttributes.GEN_AI_RESPONSE_FINISH_REASONS,
-                    self.finish_reasons,
-                )
-
-            for idx, choice in enumerate(self.choice_buffers):
-                message: dict[str, Any] = {"role": "assistant"}
-                if self.capture_content and choice.text_content:
-                    message["content"] = "".join(choice.text_content)
-                if choice.tool_calls_buffers:
-                    tool_calls = []
-                    for tool_call in choice.tool_calls_buffers:
-                        function = {"name": tool_call.function_name}
-                        if self.capture_content:
-                            function["arguments"] = "".join(
-                                tool_call.arguments
-                            )
-                        tool_call_dict = {
-                            "id": tool_call.tool_call_id,
-                            "type": "function",
-                            "function": function,
-                        }
-                        tool_calls.append(tool_call_dict)
-                    message["tool_calls"] = tool_calls
-
-                body = {
-                    "index": idx,
-                    "finish_reason": choice.finish_reason or "error",
-                    "message": message,
-                }
-
-                event_attributes = {
-                    GenAIAttributes.GEN_AI_SYSTEM: GenAIAttributes.GenAiSystemValues.OPENAI.value
-                }
-                context = set_span_in_context(self.span, get_current())
-                self.logger.emit(
-                    LogRecord(
-                        event_name="gen_ai.choice",
-                        attributes=event_attributes,
-                        body=body,
-                        context=context,
-                    )
-                )
-
-            self.span.end()
-            self._span_started = False
-
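
cleanup() was the single finalization point: it flushed response attributes, emitted one gen_ai.choice log record per buffered choice, and ended the span, with the _span_started flag making a second call a no-op. A hedged test-style sketch of that corner case (both fixtures are hypothetical):

# make_wrapper / exported_spans are hypothetical test fixtures.
def test_cleanup_is_idempotent(make_wrapper, exported_spans):
    wrapper = make_wrapper()
    wrapper.cleanup()
    wrapper.cleanup()  # guarded by _span_started: must be a no-op
    assert len(exported_spans()) == 1  # span ended exactly once
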
-    def __enter__(self):
-        self.setup()
-        return self
-
-    def __exit__(self, exc_type, exc_val, exc_tb):
-        try:
-            if exc_type is not None:
-                handle_span_exception(self.span, exc_val)
-        finally:
-            self.cleanup()
-        return False  # Propagate the exception
-
-    async def __aenter__(self):
-        self.setup()
-        return self
-
-    async def __aexit__(self, exc_type, exc_val, exc_tb):
-        try:
-            if exc_type is not None:
-                handle_span_exception(self.span, exc_val)
-        finally:
-            self.cleanup()
-        return False  # Propagate the exception
-
-    def close(self):
-        try:
-            close_fn = getattr(self.stream, "close", None)
-            if not callable(close_fn):
-                return
-
-            close_result = close_fn()
-            if inspect.isawaitable(close_result):
-                try:
-                    loop = asyncio.get_running_loop()
-                except RuntimeError:
-                    asyncio.run(cast(Any, close_result))
-                else:
-                    loop.create_task(cast(Any, close_result))
-        finally:
-            self.cleanup()
-
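
close() had to cope with both sync and async underlying streams: a sync close() returns None, while an async one returns an awaitable that must either be driven with asyncio.run (no loop running) or scheduled on the current loop. The same pattern in isolation, as a hedged standalone helper:

import asyncio
import inspect
from typing import Any


def close_maybe_async(stream: Any) -> None:
    # Mirrors the StreamWrapper.close() above (helper name is hypothetical).
    close_fn = getattr(stream, "close", None)
    if not callable(close_fn):
        return
    result = close_fn()
    if inspect.isawaitable(result):
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            asyncio.run(result)  # no event loop: drive the coroutine here
        else:
            loop.create_task(result)  # inside a loop: schedule, don't block
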
-    def __iter__(self):
-        return self
-
-    def __aiter__(self):
-        return self
-
-    def __next__(self):
-        try:
-            chunk = next(cast(Iterator[Any], self.stream))
-            self.process_chunk(chunk)
-            return chunk
-        except StopIteration:
-            self.cleanup()
-            raise
-        except Exception as error:
-            handle_span_exception(self.span, error)
-            self.cleanup()
-            raise
-
-    async def __anext__(self):
-        try:
-            chunk = await anext(cast(AsyncIterator[Any], self.stream))
-            self.process_chunk(chunk)
-            return chunk
-        except StopAsyncIteration:
-            self.cleanup()
-            raise
-        except Exception as error:
-            handle_span_exception(self.span, error)
-            self.cleanup()
-            raise
-
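
The iterator protocol guaranteed the span always closed: normal exhaustion (StopIteration / StopAsyncIteration) triggered cleanup, and any mid-stream error was first recorded via handle_span_exception, then cleaned up, then re-raised. A hedged sketch of the failure path (make_chunk, span and logger are stand-ins for real test plumbing):

# span, logger and make_chunk are stand-ins, not real fixtures.
def flaky_chunks():
    yield make_chunk("Hel")  # hypothetical chunk factory
    raise RuntimeError("connection reset")


wrapper = SyncStreamWrapper(flaky_chunks(), span, logger, capture_content=True)
try:
    for chunk in wrapper:
        pass
except RuntimeError:
    # handle_span_exception recorded the error and cleanup() ended the
    # span before the exception reached us.
    pass
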
-    def set_response_model(self, chunk):
-        if self.response_model:
-            return
-
-        if getattr(chunk, "model", None):
-            self.response_model = chunk.model
-
-    def set_response_id(self, chunk):
-        if self.response_id:
-            return
-
-        if getattr(chunk, "id", None):
-            self.response_id = chunk.id
-
-    def set_response_service_tier(self, chunk):
-        if self.service_tier:
-            return
-
-        if getattr(chunk, "service_tier", None):
-            self.service_tier = chunk.service_tier
-
-    def build_streaming_response(self, chunk):
-        if getattr(chunk, "choices", None) is None:
-            return
-
-        choices = chunk.choices
-        for choice in choices:
-            if not choice.delta:
-                continue
-
-            # make sure we have enough choice buffers
-            for idx in range(len(self.choice_buffers), choice.index + 1):
-                self.choice_buffers.append(ChoiceBuffer(idx))
-
-            if choice.finish_reason:
-                self.choice_buffers[
-                    choice.index
-                ].finish_reason = choice.finish_reason
-
-            if choice.delta.content is not None:
-                self.choice_buffers[choice.index].append_text_content(
-                    choice.delta.content
-                )
-
-            if choice.delta.tool_calls is not None:
-                for tool_call in choice.delta.tool_calls:
-                    self.choice_buffers[choice.index].append_tool_call(
-                        tool_call
-                    )
-
-    def set_usage(self, chunk):
-        if getattr(chunk, "usage", None):
-            self.completion_tokens = chunk.usage.completion_tokens
-            self.prompt_tokens = chunk.usage.prompt_tokens
-
-    def process_chunk(self, chunk):
-        self.set_response_id(chunk)
-        self.set_response_model(chunk)
-        self.set_response_service_tier(chunk)
-        self.build_streaming_response(chunk)
-        self.set_usage(chunk)
-
-    def parse(self):
-        """Called when using with_raw_response with stream=True"""
-        return self
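
set_usage() only ever sees token counts if the server sends a usage payload, which for OpenAI chat streams typically requires opting in via stream_options; the final chunk then carries usage and an empty choices list. Continuing the earlier sync example (hedged; exact behavior depends on the API version):

stream = client.chat.completions.create(
    model="gpt-4o-mini",
    messages=[{"role": "user", "content": "Say hi"}],
    stream=True,
    stream_options={"include_usage": True},  # ask for a final usage chunk
)
for chunk in stream:
    pass  # the last chunk has empty `choices` and a populated `usage`
# set_usage() lifts prompt_tokens / completion_tokens off that final chunk,
# so cleanup() can record gen_ai.usage.* attributes on the span.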
