13 | 13 | # limitations under the License. |
14 | 14 |
15 | 15 |
16 | | -import asyncio |
17 | | -import inspect |
18 | | -from collections.abc import AsyncIterator, Iterator |
19 | 16 | from timeit import default_timer |
20 | | -from typing import Any, Optional, cast |
| 17 | +from typing import Any, Optional |
21 | 18 |
22 | | -from opentelemetry._logs import Logger, LogRecord |
23 | | -from opentelemetry.context import get_current |
| 19 | +from opentelemetry._logs import Logger |
24 | 20 | from opentelemetry.semconv._incubating.attributes import ( |
25 | 21 | gen_ai_attributes as GenAIAttributes, |
26 | 22 | ) |
27 | 23 | from opentelemetry.semconv._incubating.attributes import ( |
28 | 24 | server_attributes as ServerAttributes, |
29 | 25 | ) |
30 | 26 | from opentelemetry.trace import Span, SpanKind, Tracer |
31 | | -from opentelemetry.trace.propagation import set_span_in_context |
32 | 27 |
33 | 28 | from .instruments import Instruments |
| 29 | +from .streaming import ( |
| 30 | + AsyncStreamWrapper, |
| 31 | + StreamWrapper, |
| 32 | + SyncStreamWrapper, |
| 33 | +) |
| 34 | + |
| 35 | +# Re-export StreamWrapper for backwards compatibility |
| 36 | +__all__ = ["StreamWrapper"] |
| 37 | + |
34 | 38 | from .utils import ( |
35 | 39 | choice_to_event, |
36 | 40 | get_llm_request_attributes, |
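The `__all__` re-export added above keeps older imports working after the classes move to the new `.streaming` module. A minimal sketch of the consumer-side effect; the dotted module path is an assumption inferred from the package layout, not shown in this diff:

```python
# Assumed consumer-side effect of the re-export above: an import that
# previously resolved against this module keeps working after the move.
# The dotted path below is assumed, not confirmed by this diff.
from opentelemetry.instrumentation.openai_v2.patch import StreamWrapper
```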
@@ -73,9 +77,7 @@ def traced_method(wrapped, instance, args, kwargs): |
73 | 77 | else: |
74 | 78 | parsed_result = result |
75 | 79 | if is_streaming(kwargs): |
76 | | - return StreamWrapper( |
77 | | - parsed_result, span, logger, capture_content |
78 | | - ) |
| 80 | + return SyncStreamWrapper(parsed_result, span, logger, capture_content) |
79 | 81 |
80 | 82 | if span.is_recording(): |
81 | 83 | _set_response_attributes( |
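Caller-visible behavior of the sync path, as a hedged sketch assuming the standard openai-python v1 client: with instrumentation applied, `create(stream=True)` now returns a `SyncStreamWrapper` that acts as a drop-in replacement for the raw stream.

```python
from openai import OpenAI

# Sketch (assumes openai-python v1): the instrumented create(stream=True)
# returns a SyncStreamWrapper, which iterates like the raw stream.
client = OpenAI()
stream = client.chat.completions.create(
    model="gpt-4o-mini",  # hypothetical model choice
    messages=[{"role": "user", "content": "hi"}],
    stream=True,
)
for chunk in stream:  # each chunk is recorded by the wrapper, then passed through
    print(chunk.choices[0].delta.content or "", end="")
# exhausting the stream triggers cleanup(): response attributes are set and the span ends
```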
@@ -137,9 +139,7 @@ async def traced_method(wrapped, instance, args, kwargs): |
137 | 139 | else: |
138 | 140 | parsed_result = result |
139 | 141 | if is_streaming(kwargs): |
140 | | - return StreamWrapper( |
141 | | - parsed_result, span, logger, capture_content |
142 | | - ) |
| 142 | + return AsyncStreamWrapper(parsed_result, span, logger, capture_content) |
143 | 143 |
144 | 144 | if span.is_recording(): |
145 | 145 | _set_response_attributes( |
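The async path mirrors this. A sketch under the assumption that `AsyncStreamWrapper` implements `__aiter__`/`__anext__` plus the async context-manager protocol, like the removed class below did:

```python
import asyncio

from openai import AsyncOpenAI

# Async analog (sketch): the instrumented stream still works with
# `async for`, flowing through AsyncStreamWrapper.__anext__.
async def main() -> None:
    client = AsyncOpenAI()
    stream = await client.chat.completions.create(
        model="gpt-4o-mini",  # hypothetical model choice
        messages=[{"role": "user", "content": "hi"}],
        stream=True,
    )
    async for chunk in stream:
        print(chunk.choices[0].delta.content or "", end="")

asyncio.run(main())
```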
@@ -487,290 +487,3 @@ def _set_embeddings_response_attributes( |
487 | 487 | result.usage.prompt_tokens, |
488 | 488 | ) |
489 | 489 | # Don't set output tokens for embeddings as all tokens are input tokens |
490 | | - |
491 | | - |
492 | | -class ToolCallBuffer: |
493 | | - def __init__(self, index, tool_call_id, function_name): |
494 | | - self.index = index |
495 | | - self.function_name = function_name |
496 | | - self.tool_call_id = tool_call_id |
497 | | - self.arguments = [] |
498 | | - |
499 | | - def append_arguments(self, arguments): |
500 | | - self.arguments.append(arguments) |
501 | | - |
502 | | - |
503 | | -class ChoiceBuffer: |
504 | | - def __init__(self, index): |
505 | | - self.index = index |
506 | | - self.finish_reason = None |
507 | | - self.text_content = [] |
508 | | - self.tool_calls_buffers = [] |
509 | | - |
510 | | - def append_text_content(self, content): |
511 | | - self.text_content.append(content) |
512 | | - |
513 | | - def append_tool_call(self, tool_call): |
514 | | - idx = tool_call.index |
515 | | - # make sure we have enough tool call buffers |
516 | | - for _ in range(len(self.tool_calls_buffers), idx + 1): |
517 | | - self.tool_calls_buffers.append(None) |
518 | | - |
519 | | - if not self.tool_calls_buffers[idx]: |
520 | | - self.tool_calls_buffers[idx] = ToolCallBuffer( |
521 | | - idx, tool_call.id, tool_call.function.name |
522 | | - ) |
523 | | - self.tool_calls_buffers[idx].append_arguments( |
524 | | - tool_call.function.arguments |
525 | | - ) |
526 | | - |
527 | | - |
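For context on the two buffer classes being moved: OpenAI streams a function call's JSON arguments fragmented across many chunks that share the same `index`, and only the first fragment carries the id and name, which is why `append_tool_call` creates the `ToolCallBuffer` once and then only appends. A worked example using `SimpleNamespace` stand-ins for the delta objects:

```python
from types import SimpleNamespace as NS

# Worked example of how ChoiceBuffer reassembles a streamed tool call
# (hypothetical fragments; continuation deltas carry id=None, name=None).
buf = ChoiceBuffer(0)
first = NS(index=0, id="call_1", function=NS(name="get_weather", arguments='{"ci'))
rest = NS(index=0, id=None, function=NS(name=None, arguments='ty": "Paris"}'))
buf.append_tool_call(first)
buf.append_tool_call(rest)
assert "".join(buf.tool_calls_buffers[0].arguments) == '{"city": "Paris"}'
```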
528 | | -class StreamWrapper: |
529 | | - span: Span |
530 | | - response_id: Optional[str] = None |
531 | | - response_model: Optional[str] = None |
532 | | - service_tier: Optional[str] = None |
533 | | - finish_reasons: list = [] |
534 | | - prompt_tokens: Optional[int] = 0 |
535 | | - completion_tokens: Optional[int] = 0 |
536 | | - |
537 | | - def __init__( |
538 | | - self, |
539 | | - stream: Iterator[Any] | AsyncIterator[Any], |
540 | | - span: Span, |
541 | | - logger: Logger, |
542 | | - capture_content: bool, |
543 | | - ): |
544 | | - self.stream = stream |
545 | | - self.span = span |
546 | | - self.choice_buffers = [] |
547 | | - self._span_started = False |
548 | | - self.capture_content = capture_content |
549 | | - |
550 | | - self.logger = logger |
551 | | - self.setup() |
552 | | - |
553 | | - def setup(self): |
554 | | - if not self._span_started: |
555 | | - self._span_started = True |
556 | | - |
557 | | - def cleanup(self): |
558 | | - if self._span_started: |
559 | | - if self.span.is_recording(): |
560 | | - if self.response_model: |
561 | | - set_span_attribute( |
562 | | - self.span, |
563 | | - GenAIAttributes.GEN_AI_RESPONSE_MODEL, |
564 | | - self.response_model, |
565 | | - ) |
566 | | - |
567 | | - if self.response_id: |
568 | | - set_span_attribute( |
569 | | - self.span, |
570 | | - GenAIAttributes.GEN_AI_RESPONSE_ID, |
571 | | - self.response_id, |
572 | | - ) |
573 | | - |
574 | | - set_span_attribute( |
575 | | - self.span, |
576 | | - GenAIAttributes.GEN_AI_USAGE_INPUT_TOKENS, |
577 | | - self.prompt_tokens, |
578 | | - ) |
579 | | - set_span_attribute( |
580 | | - self.span, |
581 | | - GenAIAttributes.GEN_AI_USAGE_OUTPUT_TOKENS, |
582 | | - self.completion_tokens, |
583 | | - ) |
584 | | - |
585 | | - set_span_attribute( |
586 | | - self.span, |
587 | | - GenAIAttributes.GEN_AI_OPENAI_RESPONSE_SERVICE_TIER, |
588 | | - self.service_tier, |
589 | | - ) |
590 | | - |
591 | | - set_span_attribute( |
592 | | - self.span, |
593 | | - GenAIAttributes.GEN_AI_RESPONSE_FINISH_REASONS, |
594 | | - self.finish_reasons, |
595 | | - ) |
596 | | - |
597 | | - for idx, choice in enumerate(self.choice_buffers): |
598 | | - message: dict[str, Any] = {"role": "assistant"} |
599 | | - if self.capture_content and choice.text_content: |
600 | | - message["content"] = "".join(choice.text_content) |
601 | | - if choice.tool_calls_buffers: |
602 | | - tool_calls = [] |
603 | | - for tool_call in choice.tool_calls_buffers: |
604 | | - function = {"name": tool_call.function_name} |
605 | | - if self.capture_content: |
606 | | - function["arguments"] = "".join( |
607 | | - tool_call.arguments |
608 | | - ) |
609 | | - tool_call_dict = { |
610 | | - "id": tool_call.tool_call_id, |
611 | | - "type": "function", |
612 | | - "function": function, |
613 | | - } |
614 | | - tool_calls.append(tool_call_dict) |
615 | | - message["tool_calls"] = tool_calls |
616 | | - |
617 | | - body = { |
618 | | - "index": idx, |
619 | | - "finish_reason": choice.finish_reason or "error", |
620 | | - "message": message, |
621 | | - } |
622 | | - |
623 | | - event_attributes = { |
624 | | - GenAIAttributes.GEN_AI_SYSTEM: GenAIAttributes.GenAiSystemValues.OPENAI.value |
625 | | - } |
626 | | - context = set_span_in_context(self.span, get_current()) |
627 | | - self.logger.emit( |
628 | | - LogRecord( |
629 | | - event_name="gen_ai.choice", |
630 | | - attributes=event_attributes, |
631 | | - body=body, |
632 | | - context=context, |
633 | | - ) |
634 | | - ) |
635 | | - |
636 | | - self.span.end() |
637 | | - self._span_started = False |
638 | | - |
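For reference, `cleanup()` above assembles one `gen_ai.choice` log body per buffered choice. With content capture enabled and a single completed tool call, the emitted body would look like this (illustrative values):

```python
# Illustrative gen_ai.choice log body built by cleanup() for one buffered
# choice that finished in a tool call (values are made up):
body = {
    "index": 0,
    "finish_reason": "tool_calls",
    "message": {
        "role": "assistant",
        "tool_calls": [
            {
                "id": "call_1",
                "type": "function",
                "function": {
                    "name": "get_weather",
                    "arguments": '{"city": "Paris"}',
                },
            }
        ],
    },
}
```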
639 | | - def __enter__(self): |
640 | | - self.setup() |
641 | | - return self |
642 | | - |
643 | | - def __exit__(self, exc_type, exc_val, exc_tb): |
644 | | - try: |
645 | | - if exc_type is not None: |
646 | | - handle_span_exception(self.span, exc_val) |
647 | | - finally: |
648 | | - self.cleanup() |
649 | | - return False # Propagate the exception |
650 | | - |
651 | | - async def __aenter__(self): |
652 | | - self.setup() |
653 | | - return self |
654 | | - |
655 | | - async def __aexit__(self, exc_type, exc_val, exc_tb): |
656 | | - try: |
657 | | - if exc_type is not None: |
658 | | - handle_span_exception(self.span, exc_val) |
659 | | - finally: |
660 | | - self.cleanup() |
661 | | - return False # Propagate the exception |
662 | | - |
663 | | - def close(self): |
664 | | - try: |
665 | | - close_fn = getattr(self.stream, "close", None) |
666 | | - if not callable(close_fn): |
667 | | - return |
668 | | - |
669 | | - close_result = close_fn() |
670 | | - if inspect.isawaitable(close_result): |
671 | | - try: |
672 | | - loop = asyncio.get_running_loop() |
673 | | - except RuntimeError: |
674 | | - asyncio.run(cast(Any, close_result)) |
675 | | - else: |
676 | | - loop.create_task(cast(Any, close_result)) |
677 | | - finally: |
678 | | - self.cleanup() |
679 | | - |
680 | | - def __iter__(self): |
681 | | - return self |
682 | | - |
683 | | - def __aiter__(self): |
684 | | - return self |
685 | | - |
686 | | - def __next__(self): |
687 | | - try: |
688 | | - chunk = next(cast(Iterator[Any], self.stream)) |
689 | | - self.process_chunk(chunk) |
690 | | - return chunk |
691 | | - except StopIteration: |
692 | | - self.cleanup() |
693 | | - raise |
694 | | - except Exception as error: |
695 | | - handle_span_exception(self.span, error) |
696 | | - self.cleanup() |
697 | | - raise |
698 | | - |
699 | | - async def __anext__(self): |
700 | | - try: |
701 | | - chunk = await anext(cast(AsyncIterator[Any], self.stream)) |
702 | | - self.process_chunk(chunk) |
703 | | - return chunk |
704 | | - except StopAsyncIteration: |
705 | | - self.cleanup() |
706 | | - raise |
707 | | - except Exception as error: |
708 | | - handle_span_exception(self.span, error) |
709 | | - self.cleanup() |
710 | | - raise |
711 | | - |
712 | | - def set_response_model(self, chunk): |
713 | | - if self.response_model: |
714 | | - return |
715 | | - |
716 | | - if getattr(chunk, "model", None): |
717 | | - self.response_model = chunk.model |
718 | | - |
719 | | - def set_response_id(self, chunk): |
720 | | - if self.response_id: |
721 | | - return |
722 | | - |
723 | | - if getattr(chunk, "id", None): |
724 | | - self.response_id = chunk.id |
725 | | - |
726 | | - def set_response_service_tier(self, chunk): |
727 | | - if self.service_tier: |
728 | | - return |
729 | | - |
730 | | - if getattr(chunk, "service_tier", None): |
731 | | - self.service_tier = chunk.service_tier |
732 | | - |
733 | | - def build_streaming_response(self, chunk): |
734 | | - if getattr(chunk, "choices", None) is None: |
735 | | - return |
736 | | - |
737 | | - choices = chunk.choices |
738 | | - for choice in choices: |
739 | | - if not choice.delta: |
740 | | - continue |
741 | | - |
742 | | - # make sure we have enough choice buffers |
743 | | - for idx in range(len(self.choice_buffers), choice.index + 1): |
744 | | - self.choice_buffers.append(ChoiceBuffer(idx)) |
745 | | - |
746 | | - if choice.finish_reason: |
747 | | - self.choice_buffers[ |
748 | | - choice.index |
749 | | - ].finish_reason = choice.finish_reason |
750 | | - |
751 | | - if choice.delta.content is not None: |
752 | | - self.choice_buffers[choice.index].append_text_content( |
753 | | - choice.delta.content |
754 | | - ) |
755 | | - |
756 | | - if choice.delta.tool_calls is not None: |
757 | | - for tool_call in choice.delta.tool_calls: |
758 | | - self.choice_buffers[choice.index].append_tool_call( |
759 | | - tool_call |
760 | | - ) |
761 | | - |
762 | | - def set_usage(self, chunk): |
763 | | - if getattr(chunk, "usage", None): |
764 | | - self.completion_tokens = chunk.usage.completion_tokens |
765 | | - self.prompt_tokens = chunk.usage.prompt_tokens |
766 | | - |
767 | | - def process_chunk(self, chunk): |
768 | | - self.set_response_id(chunk) |
769 | | - self.set_response_model(chunk) |
770 | | - self.set_response_service_tier(chunk) |
771 | | - self.build_streaming_response(chunk) |
772 | | - self.set_usage(chunk) |
773 | | - |
774 | | - def parse(self): |
775 | | - """Called when using with_raw_response with stream=True""" |
776 | | - return self |
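The removed classes presumably land in the new `.streaming` module largely unchanged, split so each subclass keeps only the protocol methods that match its client type. A plausible shape inferred from the names imported at the top of this diff; this is an assumption, as the file itself is not shown here:

```python
# Assumed shape of the new `.streaming` module (not shown in this diff).
class StreamWrapper:
    """Shared span/log bookkeeping and chunk processing (as removed above)."""


class SyncStreamWrapper(StreamWrapper):
    def __iter__(self): ...
    def __next__(self): ...            # sync-only iteration
    def __enter__(self): ...
    def __exit__(self, *exc): ...


class AsyncStreamWrapper(StreamWrapper):
    def __aiter__(self): ...
    async def __anext__(self): ...     # async-only iteration
    async def __aenter__(self): ...
    async def __aexit__(self, *exc): ...
```

Splitting the wrapper this way removes the `inspect.isawaitable` juggling in `close()` above: each subclass knows statically whether its underlying stream is sync or async.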