"""Source code for qualia_core.utils.process.AsyncTeeProtocol."""

import asyncio
import logging
import os
import sys
from typing import BinaryIO, Optional

if sys.version_info >= (3, 12):
    from typing import override
else:
    from typing_extensions import override

# Module-level logger; AsyncTeeProtocol reports unexpected transport / return-code states here.
logger = logging.getLogger(__name__)


class AsyncTeeProtocol(asyncio.SubprocessProtocol):
    """asyncio subprocess protocol that tees child pipe output and captures it.

    Each chunk the child writes on a pipe is:

    * echoed to the file descriptor with the same number in *this* process
      (child fd 1 -> our fd 1, child fd 2 -> our fd 2),
    * appended to ``self.buffers[fd]``,
    * written to ``files[fd]`` when a file object was registered for that fd.

    ``done_future`` is resolved with ``(return_code, buffers)`` only once the
    child has exited *and* :meth:`connection_lost` has been called. asyncio may
    call :meth:`process_exited` before the last :meth:`pipe_data_received`, so
    resolving on exit alone could hand the caller truncated output.
    If the return code cannot be determined, or the connection is lost with an
    error, ``done_future`` receives an exception instead of staying pending.

    Args:
        done_future: Future resolved with ``(return_code, buffers)``, where
            ``buffers`` maps each child fd to all bytes received on it.
        files: Optional mapping from child fd to a binary file that should also
            receive that fd's output.
    """

    transport: Optional[asyncio.SubprocessTransport] = None

    def __init__(self,
                 done_future: asyncio.Future[tuple[int, dict[int, bytearray]]],
                 files: Optional[dict[int, BinaryIO]] = None) -> None:
        super().__init__()
        self.done = done_future
        self.buffers: dict[int, bytearray] = {}
        self.files = files if files is not None else {}
        # Both flags must be True before ``done`` is resolved (see class docstring).
        self._exited = False
        self._connection_lost = False

    @override
    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        """Keep the subprocess transport so the return code can be read later."""
        if isinstance(transport, asyncio.SubprocessTransport):
            self.transport = transport
        else:
            # Report the type we actually received, which is what helps diagnosis.
            logger.error('Expected transport type asyncio.SubprocessTransport, got: %s',
                         type(transport))

    @override
    def pipe_data_received(self, fd: int, data: bytes) -> None:
        """Echo ``data`` to our own ``fd``, buffer it, and copy it to ``files[fd]`` if registered."""
        self._write_all(fd, data)
        self.buffers.setdefault(fd, bytearray()).extend(data)
        if fd in self.files:
            _ = self.files[fd].write(data)

    @override
    def process_exited(self) -> None:
        """Record that the child exited; ``done`` is resolved once the connection is lost too."""
        self._exited = True
        self._try_finish()

    @override
    def connection_lost(self, exc: Optional[Exception]) -> None:
        """Record that the transport is finished; fail ``done`` if it ended with an error."""
        self._connection_lost = True
        if exc is not None:
            self._fail(exc)
            return
        self._try_finish()

    @staticmethod
    def _write_all(fd: int, data: bytes) -> None:
        """Write all of ``data`` to ``fd``; ``os.write`` may perform partial writes."""
        view = memoryview(data)
        while view:
            view = view[os.write(fd, view):]

    def _fail(self, exc: BaseException) -> None:
        """Set ``exc`` on ``done`` unless it is already done (e.g. cancelled by the caller)."""
        if not self.done.done():
            self.done.set_exception(exc)

    def _try_finish(self) -> None:
        """Resolve ``done`` with ``(return_code, buffers)`` once exit and connection loss were seen."""
        if not (self._exited and self._connection_lost) or self.done.done():
            return
        if self.transport is None:
            logger.error('Transport is None')
            # Settle the future so the awaiting caller does not hang forever.
            self._fail(RuntimeError('Transport is None'))
            return
        return_code = self.transport.get_returncode()
        if return_code is None:
            logger.error('return_code is None')
            self._fail(RuntimeError('return_code is None'))
            return
        self.done.set_result((return_code, self.buffers))