13 | 13 | # limitations under the License. |
14 | 14 |
15 | 15 |
16 | | -import asyncio |
17 | | -import inspect |
18 | | -from collections.abc import AsyncIterator, Iterator |
19 | 16 | from timeit import default_timer |
20 | | -from typing import Any, Optional, cast |
| 17 | +from typing import Any, Optional |
21 | 18 |
22 | | -from opentelemetry._logs import Logger, LogRecord |
23 | | -from opentelemetry.context import get_current |
| 19 | +from opentelemetry._logs import Logger |
24 | 20 | from opentelemetry.semconv._incubating.attributes import ( |
25 | 21 | gen_ai_attributes as GenAIAttributes, |
26 | 22 | ) |
27 | 23 | from opentelemetry.semconv._incubating.attributes import ( |
28 | 24 | server_attributes as ServerAttributes, |
29 | 25 | ) |
30 | 26 | from opentelemetry.trace import Span, SpanKind, Tracer |
31 | | -from opentelemetry.trace.propagation import set_span_in_context |
32 | 27 |
33 | 28 | from .instruments import Instruments |
| 29 | +from .streaming import ( |
| 30 | + AsyncStreamWrapper, |
| 31 | + StreamWrapper, |
| 32 | + SyncStreamWrapper, |
| 33 | +) |
| 34 | + |
| 35 | +# Re-export StreamWrapper for backwards compatibility |
| 36 | +__all__ = ["StreamWrapper"] |
| 37 | + |
34 | 38 | from .utils import ( |
35 | 39 | choice_to_event, |
36 | 40 | get_llm_request_attributes, |
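
Note: the streaming wrappers now live in a dedicated `.streaming` module, and the `__all__` entry above re-exports `StreamWrapper` so existing imports keep resolving. A minimal sketch of both import paths, assuming this file is the package's patch module (the dotted path below is illustrative, not confirmed by the diff):

    # Old import path still works via the re-export
    from opentelemetry.instrumentation.openai_v2.patch import StreamWrapper

    # New code can import the split wrappers directly
    from opentelemetry.instrumentation.openai_v2.streaming import (
        AsyncStreamWrapper,
        SyncStreamWrapper,
    )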
@@ -68,7 +72,7 @@ def traced_method(wrapped, instance, args, kwargs): |
68 | 72 | try: |
69 | 73 | result = wrapped(*args, **kwargs) |
70 | 74 | if is_streaming(kwargs): |
71 | | - return StreamWrapper(result, span, logger, capture_content) |
| 75 | + return SyncStreamWrapper(result, span, logger, capture_content) |
72 | 76 |
73 | 77 | if span.is_recording(): |
74 | 78 | _set_response_attributes( |
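
In the sync path, the raw OpenAI stream is now handed to SyncStreamWrapper, which the caller iterates exactly like the underlying stream while the span stays open until the stream is exhausted or fails. A caller-side sketch (the client setup and the handle() consumer are hypothetical):

    stream = client.chat.completions.create(
        model="gpt-4o-mini", messages=messages, stream=True
    )  # the instrumented call returns SyncStreamWrapper, not the raw stream
    for chunk in stream:    # each chunk passes through process_chunk()
        handle(chunk)       # hypothetical consumer
    # StopIteration triggers cleanup(): attributes are flushed and the span ends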
@@ -125,7 +129,7 @@ async def traced_method(wrapped, instance, args, kwargs): |
125 | 129 | try: |
126 | 130 | result = await wrapped(*args, **kwargs) |
127 | 131 | if is_streaming(kwargs): |
128 | | - return StreamWrapper(result, span, logger, capture_content) |
| 132 | + return AsyncStreamWrapper(result, span, logger, capture_content) |
129 | 133 |
130 | 134 | if span.is_recording(): |
131 | 135 | _set_response_attributes( |
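
The async path mirrors this with AsyncStreamWrapper, so `async for` consumption is instrumented the same way (again a sketch; handle() is hypothetical):

    stream = await client.chat.completions.create(
        model="gpt-4o-mini", messages=messages, stream=True
    )
    async for chunk in stream:   # AsyncStreamWrapper.__anext__ records each chunk
        handle(chunk)
    # StopAsyncIteration triggers cleanup() and ends the span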
@@ -472,286 +476,3 @@ def _set_embeddings_response_attributes( |
472 | 476 | result.usage.prompt_tokens, |
473 | 477 | ) |
474 | 478 | # Don't set output tokens for embeddings as all tokens are input tokens |
475 | | - |
476 | | - |
477 | | -class ToolCallBuffer: |
478 | | - def __init__(self, index, tool_call_id, function_name): |
479 | | - self.index = index |
480 | | - self.function_name = function_name |
481 | | - self.tool_call_id = tool_call_id |
482 | | - self.arguments = [] |
483 | | - |
484 | | - def append_arguments(self, arguments): |
485 | | - self.arguments.append(arguments) |
486 | | - |
487 | | - |
488 | | -class ChoiceBuffer: |
489 | | - def __init__(self, index): |
490 | | - self.index = index |
491 | | - self.finish_reason = None |
492 | | - self.text_content = [] |
493 | | - self.tool_calls_buffers = [] |
494 | | - |
495 | | - def append_text_content(self, content): |
496 | | - self.text_content.append(content) |
497 | | - |
498 | | - def append_tool_call(self, tool_call): |
499 | | - idx = tool_call.index |
500 | | - # make sure we have enough tool call buffers |
501 | | - for _ in range(len(self.tool_calls_buffers), idx + 1): |
502 | | - self.tool_calls_buffers.append(None) |
503 | | - |
504 | | - if not self.tool_calls_buffers[idx]: |
505 | | - self.tool_calls_buffers[idx] = ToolCallBuffer( |
506 | | - idx, tool_call.id, tool_call.function.name |
507 | | - ) |
508 | | - self.tool_calls_buffers[idx].append_arguments( |
509 | | - tool_call.function.arguments |
510 | | - ) |
511 | | - |
512 | | - |
513 | | -class StreamWrapper: |
514 | | - span: Span |
515 | | - response_id: Optional[str] = None |
516 | | - response_model: Optional[str] = None |
517 | | - service_tier: Optional[str] = None |
518 | | - finish_reasons: list = [] |
519 | | - prompt_tokens: Optional[int] = 0 |
520 | | - completion_tokens: Optional[int] = 0 |
521 | | - |
522 | | - def __init__( |
523 | | - self, |
524 | | - stream: Iterator[Any] | AsyncIterator[Any], |
525 | | - span: Span, |
526 | | - logger: Logger, |
527 | | - capture_content: bool, |
528 | | - ): |
529 | | - self.stream = stream |
530 | | - self.span = span |
531 | | - self.choice_buffers = [] |
532 | | - self._span_started = False |
533 | | - self.capture_content = capture_content |
534 | | - |
535 | | - self.logger = logger |
536 | | - self.setup() |
537 | | - |
538 | | - def setup(self): |
539 | | - if not self._span_started: |
540 | | - self._span_started = True |
541 | | - |
542 | | - def cleanup(self): |
543 | | - if self._span_started: |
544 | | - if self.span.is_recording(): |
545 | | - if self.response_model: |
546 | | - set_span_attribute( |
547 | | - self.span, |
548 | | - GenAIAttributes.GEN_AI_RESPONSE_MODEL, |
549 | | - self.response_model, |
550 | | - ) |
551 | | - |
552 | | - if self.response_id: |
553 | | - set_span_attribute( |
554 | | - self.span, |
555 | | - GenAIAttributes.GEN_AI_RESPONSE_ID, |
556 | | - self.response_id, |
557 | | - ) |
558 | | - |
559 | | - set_span_attribute( |
560 | | - self.span, |
561 | | - GenAIAttributes.GEN_AI_USAGE_INPUT_TOKENS, |
562 | | - self.prompt_tokens, |
563 | | - ) |
564 | | - set_span_attribute( |
565 | | - self.span, |
566 | | - GenAIAttributes.GEN_AI_USAGE_OUTPUT_TOKENS, |
567 | | - self.completion_tokens, |
568 | | - ) |
569 | | - |
570 | | - set_span_attribute( |
571 | | - self.span, |
572 | | - GenAIAttributes.GEN_AI_OPENAI_RESPONSE_SERVICE_TIER, |
573 | | - self.service_tier, |
574 | | - ) |
575 | | - |
576 | | - set_span_attribute( |
577 | | - self.span, |
578 | | - GenAIAttributes.GEN_AI_RESPONSE_FINISH_REASONS, |
579 | | - self.finish_reasons, |
580 | | - ) |
581 | | - |
582 | | - for idx, choice in enumerate(self.choice_buffers): |
583 | | - message: dict[str, Any] = {"role": "assistant"} |
584 | | - if self.capture_content and choice.text_content: |
585 | | - message["content"] = "".join(choice.text_content) |
586 | | - if choice.tool_calls_buffers: |
587 | | - tool_calls = [] |
588 | | - for tool_call in choice.tool_calls_buffers: |
589 | | - function = {"name": tool_call.function_name} |
590 | | - if self.capture_content: |
591 | | - function["arguments"] = "".join( |
592 | | - tool_call.arguments |
593 | | - ) |
594 | | - tool_call_dict = { |
595 | | - "id": tool_call.tool_call_id, |
596 | | - "type": "function", |
597 | | - "function": function, |
598 | | - } |
599 | | - tool_calls.append(tool_call_dict) |
600 | | - message["tool_calls"] = tool_calls |
601 | | - |
602 | | - body = { |
603 | | - "index": idx, |
604 | | - "finish_reason": choice.finish_reason or "error", |
605 | | - "message": message, |
606 | | - } |
607 | | - |
608 | | - event_attributes = { |
609 | | - GenAIAttributes.GEN_AI_SYSTEM: GenAIAttributes.GenAiSystemValues.OPENAI.value |
610 | | - } |
611 | | - context = set_span_in_context(self.span, get_current()) |
612 | | - self.logger.emit( |
613 | | - LogRecord( |
614 | | - event_name="gen_ai.choice", |
615 | | - attributes=event_attributes, |
616 | | - body=body, |
617 | | - context=context, |
618 | | - ) |
619 | | - ) |
620 | | - |
621 | | - self.span.end() |
622 | | - self._span_started = False |
623 | | - |
624 | | - def __enter__(self): |
625 | | - self.setup() |
626 | | - return self |
627 | | - |
628 | | - def __exit__(self, exc_type, exc_val, exc_tb): |
629 | | - try: |
630 | | - if exc_type is not None: |
631 | | - handle_span_exception(self.span, exc_val) |
632 | | - finally: |
633 | | - self.cleanup() |
634 | | - return False # Propagate the exception |
635 | | - |
636 | | - async def __aenter__(self): |
637 | | - self.setup() |
638 | | - return self |
639 | | - |
640 | | - async def __aexit__(self, exc_type, exc_val, exc_tb): |
641 | | - try: |
642 | | - if exc_type is not None: |
643 | | - handle_span_exception(self.span, exc_val) |
644 | | - finally: |
645 | | - self.cleanup() |
646 | | - return False # Propagate the exception |
647 | | - |
648 | | - def close(self): |
649 | | - try: |
650 | | - close_fn = getattr(self.stream, "close", None) |
651 | | - if not callable(close_fn): |
652 | | - return |
653 | | - |
654 | | - close_result = close_fn() |
655 | | - if inspect.isawaitable(close_result): |
656 | | - try: |
657 | | - loop = asyncio.get_running_loop() |
658 | | - except RuntimeError: |
659 | | - asyncio.run(cast(Any, close_result)) |
660 | | - else: |
661 | | - loop.create_task(cast(Any, close_result)) |
662 | | - finally: |
663 | | - self.cleanup() |
664 | | - |
665 | | - def __iter__(self): |
666 | | - return self |
667 | | - |
668 | | - def __aiter__(self): |
669 | | - return self |
670 | | - |
671 | | - def __next__(self): |
672 | | - try: |
673 | | - chunk = next(cast(Iterator[Any], self.stream)) |
674 | | - self.process_chunk(chunk) |
675 | | - return chunk |
676 | | - except StopIteration: |
677 | | - self.cleanup() |
678 | | - raise |
679 | | - except Exception as error: |
680 | | - handle_span_exception(self.span, error) |
681 | | - self.cleanup() |
682 | | - raise |
683 | | - |
684 | | - async def __anext__(self): |
685 | | - try: |
686 | | - chunk = await anext(cast(AsyncIterator[Any], self.stream)) |
687 | | - self.process_chunk(chunk) |
688 | | - return chunk |
689 | | - except StopAsyncIteration: |
690 | | - self.cleanup() |
691 | | - raise |
692 | | - except Exception as error: |
693 | | - handle_span_exception(self.span, error) |
694 | | - self.cleanup() |
695 | | - raise |
696 | | - |
697 | | - def set_response_model(self, chunk): |
698 | | - if self.response_model: |
699 | | - return |
700 | | - |
701 | | - if getattr(chunk, "model", None): |
702 | | - self.response_model = chunk.model |
703 | | - |
704 | | - def set_response_id(self, chunk): |
705 | | - if self.response_id: |
706 | | - return |
707 | | - |
708 | | - if getattr(chunk, "id", None): |
709 | | - self.response_id = chunk.id |
710 | | - |
711 | | - def set_response_service_tier(self, chunk): |
712 | | - if self.service_tier: |
713 | | - return |
714 | | - |
715 | | - if getattr(chunk, "service_tier", None): |
716 | | - self.service_tier = chunk.service_tier |
717 | | - |
718 | | - def build_streaming_response(self, chunk): |
719 | | - if getattr(chunk, "choices", None) is None: |
720 | | - return |
721 | | - |
722 | | - choices = chunk.choices |
723 | | - for choice in choices: |
724 | | - if not choice.delta: |
725 | | - continue |
726 | | - |
727 | | - # make sure we have enough choice buffers |
728 | | - for idx in range(len(self.choice_buffers), choice.index + 1): |
729 | | - self.choice_buffers.append(ChoiceBuffer(idx)) |
730 | | - |
731 | | - if choice.finish_reason: |
732 | | - self.choice_buffers[ |
733 | | - choice.index |
734 | | - ].finish_reason = choice.finish_reason |
735 | | - |
736 | | - if choice.delta.content is not None: |
737 | | - self.choice_buffers[choice.index].append_text_content( |
738 | | - choice.delta.content |
739 | | - ) |
740 | | - |
741 | | - if choice.delta.tool_calls is not None: |
742 | | - for tool_call in choice.delta.tool_calls: |
743 | | - self.choice_buffers[choice.index].append_tool_call( |
744 | | - tool_call |
745 | | - ) |
746 | | - |
747 | | - def set_usage(self, chunk): |
748 | | - if getattr(chunk, "usage", None): |
749 | | - self.completion_tokens = chunk.usage.completion_tokens |
750 | | - self.prompt_tokens = chunk.usage.prompt_tokens |
751 | | - |
752 | | - def process_chunk(self, chunk): |
753 | | - self.set_response_id(chunk) |
754 | | - self.set_response_model(chunk) |
755 | | - self.set_response_service_tier(chunk) |
756 | | - self.build_streaming_response(chunk) |
757 | | - self.set_usage(chunk) |
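
The ~280 deleted lines above are not dropped; they move into the new streaming module, whose contents this diff does not show. Given the SyncStreamWrapper/AsyncStreamWrapper names imported at the top, a plausible layout is a shared base holding the buffering and cleanup logic, with thin sync/async iterator subclasses. The sketch below reuses the deleted __next__/__anext__ bodies (the handle_span_exception branches are elided); everything else about the real module is an assumption:

    # Hypothetical sketch of the new .streaming module layout
    class StreamWrapper:
        """Shared state; full buffering/cleanup logic is the deleted block above."""

        def __init__(self, stream, span, logger, capture_content):
            self.stream = stream
            self.span = span
            self.logger = logger
            self.capture_content = capture_content

        def process_chunk(self, chunk): ...  # id/model/usage/choice buffering
        def cleanup(self): ...               # flush attributes, end the span

    class SyncStreamWrapper(StreamWrapper):
        def __iter__(self):
            return self

        def __next__(self):
            try:
                chunk = next(self.stream)
                self.process_chunk(chunk)
                return chunk
            except StopIteration:
                self.cleanup()
                raise

    class AsyncStreamWrapper(StreamWrapper):
        def __aiter__(self):
            return self

        async def __anext__(self):
            try:
                chunk = await anext(self.stream)
                self.process_chunk(chunk)
                return chunk
            except StopAsyncIteration:
                self.cleanup()
                raise

Splitting the sync and async protocols into separate classes removes the runtime casts and the asyncio close() juggling the old single class needed to serve both iterator protocols at once.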