pyhub_ai.mixins.chat의 소스 코드

import asyncio
import logging
from functools import cached_property
from typing import (
    AsyncIterator,
    Callable,
    Coroutine,
    Dict,
    List,
    Literal,
    Optional,
    Type,
)

from asgiref.sync import sync_to_async
from django.core.exceptions import ImproperlyConfigured
from django.core.files.base import File
from django.forms import Form
from django.http import QueryDict
from django.utils.datastructures import MultiValueDict

from pyhub_ai.agents.chat import ContentBlock
from pyhub_ai.blocks import (
    EventContentBlock,
    MessageBlock,
    MessageBlockRenderModeType,
    TextContentBlock,
)
from pyhub_ai.forms import MessageForm
from pyhub_ai.models import Conversation, UserType
from pyhub_ai.utils import amerge

logger = logging.getLogger(__name__)


class ChatMixin:
    form_class = MessageForm
    user_text_field_name = "user_text"
    user_images_field_name = "images"
    conversation_pk_url_kwarg = "conversation_pk"

    chat_messages_dom_id = "chat-messages"
    base64_field_name_postfix = "__base64"
    template_name = "pyhub_ai/_chat_message.html"

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        self.send_func: Optional[Callable[[str], Coroutine]] = None
        self.llm_response_queue = asyncio.Queue()
        self.llm_response_task: Optional[asyncio.Task] = None

    async def chat_setup(self, send_func: Callable[[str], Coroutine]):
        self.send_func = send_func
        self.llm_response_task = asyncio.create_task(self.run_llm_response_task())
        self.llm_response_task.add_done_callback(self.handle_llm_response_task_result)

    async def chat_message_put(self, text: Optional[str]) -> None:
        await self.llm_response_queue.put(text)

    async def chat_shutdown(self):
        # sub task, llm_response_stream task가 종료될 수 있도록 None을 Queue에 추가합니다.
        await self.llm_response_queue.put(None)

        # main task, llm_response_task를 취소 요청합니다.
        if self.llm_response_task and not self.llm_response_task.done():
            self.llm_response_task.cancel()
            try:
                await self.llm_response_task
            except asyncio.CancelledError:
                pass

    async def llm_response_stream(self) -> AsyncIterator[str]:
        """
        LLM response stream

        Returns:
            AsyncIterator[str]: 유저 음성 데이터 스트림
        """
        while True:
            item = await self.llm_response_queue.get()
            yield item
            if item is None:
                break

    async def run_llm_response_task(self) -> None:
        async for stream_key, data_raw in amerge(
            llm_response_stream=self.llm_response_stream(),
        ):
            if stream_key == "llm_response_stream":
                await self.send_func(data_raw)
            else:
                raise NotImplementedError(f"[run_llm_response_task] Not Implemented stream_key : {stream_key}")

    @staticmethod
    async def handle_llm_response_task_result(task: asyncio.Task) -> None:
        try:
            task.result()
        except asyncio.CancelledError:
            pass
        except Exception as e:
            logger.error(f"LLM Response Task failed with exception: %s", e)

    def get_template_name(self) -> str:
        return self.template_name

    def get_form_class(self) -> Type[Form]:
        return self.form_class

    @cached_property
    def query_params(self) -> QueryDict:
        if hasattr(self, "scope"):
            query_string = self.scope.get("query_string", b"")
            return QueryDict(query_string)
        elif hasattr(self, "request"):
            return self.request.GET
        return QueryDict()

    @cached_property
    def url_route_kwargs(self) -> Dict[str, str]:
        if hasattr(self, "scope"):
            return self.scope["url_route"]["kwargs"]
        elif hasattr(self, "kwargs"):
            return self.kwargs
        return {}

    async def get_user(self) -> Optional[UserType]:
        cache_key = "_cached_user"
        if not hasattr(self, cache_key):
            if hasattr(self, "scope"):
                try:
                    user = self.scope["user"]
                except KeyError:
                    raise ImproperlyConfigured(
                        "scope['user']에 접근할 수 없습니다. "
                        "channels.auth.AuthMiddlewareStack이 ASGI 애플리케이션에 올바르게 구성되어 있는지 확인하세요. "
                        "\n예시: application = ProtocolTypeRouter({"
                        "\n    'websocket': AuthMiddlewareStack(URLRouter(websocket_urlpatterns))"
                        "\n})"
                    )
            elif hasattr(self, "request"):
                user = self.request.user
            else:
                user = None

            if user is not None:
                is_authenticated = await sync_to_async(lambda: user.is_authenticated)()
                if not is_authenticated:
                    user = None

            setattr(self, cache_key, user)

        return getattr(self, cache_key, None)

[문서] def get_conversation_pk(self) -> Optional[str]: """웹소켓 요청 URL에서 추출한 대화방 식별자를 반환합니다. Returns: Optional[str]: 대화방 식별자 """ for candidate_key in ( self.conversation_pk_url_kwarg, "conversation_pk", "conversation_id", "pk", "id", ): if candidate_key in self.url_route_kwargs: return str(self.url_route_kwargs[candidate_key]) return None
[문서] async def aget_conversation(self) -> Optional[Conversation]: cache_key = "_cached_conversation" if not hasattr(self, cache_key): conversation_pk = self.get_conversation_pk() if conversation_pk: qs = Conversation.objects.filter(pk=conversation_pk) # noqa conversation = await qs.afirst() else: conversation = None setattr(self, cache_key, conversation) return getattr(self, cache_key, None)
[문서] async def can_accept(self) -> bool: """연결을 수락할 수 있는지 여부를 반환합니다. Returns: bool: 연결을 수락할 수 있는 경우 True, 그렇지 않은 경우 False. """ user = await self.get_user() if user and user.is_authenticated: return True return False
async def form_handler(self, data: QueryDict, files: MultiValueDict, add_none_to_queue: bool = False): form_cls = self.get_form_class() form = form_cls(data=data, files=files) if form.is_valid(): await self.form_valid(form) else: await self.form_invalid(form) if add_none_to_queue: await self.chat_message_put(None) async def form_valid(self, form: Form) -> None: """유효한 폼 데이터를 처리하고, 유저 입력에 대한 응답을 생성합니다. Args: form (Form): 유효한 폼 객체. """ user_text = form.cleaned_data[self.user_text_field_name] images: List[File] = form.cleaned_data[self.user_images_field_name] await self.make_response(user_text=user_text, images=images) async def form_invalid(self, form: Form) -> None: """유효하지 않은 폼 데이터를 처리하고, 에러 메시지를 렌더링합니다. Args: form (Form): 유효하지 않은 폼 객체. """ error_message: str = ", ".join((f"{field}: {', '.join(errors)}" for field, errors in form.errors.items())) content_block = TextContentBlock(role="error", value=error_message) await self.render_block(content_block) async def think(self, input_query: str, files: Optional[List[File]] = None) -> AsyncIterator[ContentBlock]: """에이전트를 통해 입력 쿼리에 대해 생각하고 결과를 비동기적으로 반환합니다. Args: input_query (str): 입력 쿼리. files (Optional[List[File]]): 파일 목록 Yields: AsyncIterator[ContentBlock]: 생성된 메시지 청크. """ yield TextContentBlock("") @cached_property def render_format(self) -> Literal["json", "htmx"]: """JSON 요청이 아니라면 HTMX 응답""" # Consumer에서 요청을 처리할 경우, scope에서 헤더 확인 if hasattr(self, "scope"): # 참고: htmx ws (v2.0.1) 확장을 통한 요청에서는 htmx 헤더가 하나도 없습니다. # WebSocket 요청인 경우 scope에서 헤더 확인 headers = dict(self.scope.get("headers", ())) if headers.get("hx-request") == "true": return "htmx" if "application/json" in headers.get("accept", ""): return "json" # Django View 에서 요청을 처리할 경우, request.headers 확인 elif hasattr(self, "request"): if self.request.headers.get("HX-Request") == "true": return "htmx" if "application/json" in self.request.headers.get("Accept", ""): return "json" return self.query_params.get("format", "htmx") async def render_block( self, content_block: Optional[ContentBlock] = None, mode: MessageBlockRenderModeType = "overwrite", ) -> MessageBlock: if content_block is None: content_block = EventContentBlock() message_block = MessageBlock( chat_messages_dom_id=self.chat_messages_dom_id, content_block=content_block, template_name=self.get_template_name(), send_func=self.chat_message_put, render_format=self.render_format, ) await message_block.render(mode) return message_block async def make_response( self, user_text: str, images: Optional[List[File]] = None, ) -> None: """사용자 입력에 대한 응답을 생성하고, 메시지 타입에 맞게 렌더링합니다. Args: user_text (str): 사용자 입력 텍스트 images (Optional[List[File]]): 첨부된 사진 파일 목록 """ thinking_block = await self.render_block(mode="thinking-start") current_message_block: Optional[MessageBlock] = None content_block: ContentBlock async for content_block in self.think(input_query=user_text, files=images): # 새 메시지 블록을 렌더링하거나, 기존 메시지 블록에 추가합니다. if current_message_block is None or content_block.id != current_message_block.content_block.id: current_message_block = await self.render_block(content_block) else: await current_message_block.append(content_block) # 사용량 블록이 있는 경우, 렌더링합니다. usage_block = content_block.get_usage_block() if usage_block: await self.render_block(usage_block) # 모든 응답 생성이 완료되면, "생각 중" 메시지를 삭제합니다. if thinking_block is not None: await thinking_block.thinking_end()