
framework package

The framework package contains utilities for building assistants.

AssistantToolkit

Bases: Toolkit

Source code in framework/toolkit.py
class AssistantToolkit(Toolkit):

    TOOL_NAME = "delegate_task"

    def __init__(self, assistants: list[str]):
        self.assistants = assistants

        assistant_enum = StrEnum("Assistant", self.assistants)

        class DelegateTask(BaseModel):
            """
            Delegate a task to an assistant.
            """

            assistant: assistant_enum
            task: str
            id: str | None = Field(
                default=None,
                description=(
                    "ID of the task. Used to track the task."
                    "Pass null if you want to start a new conversation. The tool will generate a new ID."
                    "Pass the same ID if you want to continue the conversation."
                ),
            )

        self._model = DelegateTask
        self._tools = [pydantic_function_tool(self._model, name=self.TOOL_NAME)]

    async def get_tools(self) -> list[ChatCompletionToolParam]:
        return self._tools

    async def handle_tool_calls(
        self, tool_calls: list[ChatCompletionMessageToolCall], context: ToolContext
    ) -> list[LiteLLMMessage]:
        ret = []
        for tool_call in tool_calls:
            if tool_call.function.name != self.TOOL_NAME:
                continue
            logger.info(f"Executing tool call: {tool_call}")
            instance = self._model.model_validate_json(tool_call.function.arguments)
            task_response = await self._complete_task(instance, context.caller)
            logger.debug(f"Task response: {task_response}")
            ret.append(
                LiteLLMMessage(
                    role="tool",  # type: ignore
                    content=task_response.model_dump_json(),
                    tool_call_id=tool_call.id,
                )
            )
        return ret

    async def _complete_task(self, tool_call, assistant_name: str) -> TaskResponse:
        from deps import registry
        from framework import LLMAssistant

        if tool_call.id:
            chat = Chat(state=ChatState.load_from_disk(tool_call.id))
        else:
            chat = Chat()

        # Run the assistant on the task's chat session
        assistant = registry.get_assistant(tool_call.assistant)
        chat.state.assistant = assistant.name
        chat.state.messages.append(Message(role="user", name=assistant_name, content=tool_call.task))
        await assistant.run(chat)

        task_analyzer = LLMAssistant(
            name="TaskAnalyzer",
            model="gpt-4.1",
            system_prompt="""
            Analyze the conversation and extract the task details and output.
            You must extract every detail that was asked for, because the user cannot see this conversation.
            For example, if the user asked for emails, you must extract the list of emails and output them.
            """,
            output_type=TaskAnalysis,
        )

        temp = Chat()
        temp.state.messages = chat.state.messages.copy()
        await task_analyzer.run(temp)

        analysis = TaskAnalysis.model_validate_json(temp.state.messages[-1].content)
        chat.state.save_to_disk()

        return TaskResponse(id=chat.state.id, analysis=analysis)
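
A minimal usage sketch. The assistant names and the coordinator wiring are illustrative assumptions; the toolkit itself only needs the list of registered assistant names:

toolkit = AssistantToolkit(assistants=["Mailer", "Researcher"])

# The toolkit exposes a single "delegate_task" tool whose `assistant`
# field is an enum restricted to the names passed above.
coordinator = LLMAssistant(
    name="Coordinator",
    system_prompt="Break the user's request into tasks and delegate them.",
    toolkit=toolkit,
)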

__init__(assistants)

Source code in framework/toolkit.py
def __init__(self, assistants: list[str]):
    self.assistants = assistants

    assistant_enum = StrEnum("Assistant", self.assistants)

    class DelegateTask(BaseModel):
        """
        Delegate a task to an assistant.
        """

        assistant: assistant_enum
        task: str
        id: str | None = Field(
            default=None,
            description=(
                "ID of the task. Used to track the task."
                "Pass null if you want to start a new conversation. The tool will generate a new ID."
                "Pass the same ID if you want to continue the conversation."
            ),
        )

    self._model = DelegateTask
    self._tools = [pydantic_function_tool(self._model, name=self.TOOL_NAME)]

FunctionToolkit

Bases: Toolkit

Manages the list of tools to be passed into the completion request.

Source code in framework/toolkit.py
class FunctionToolkit(Toolkit):
    """Manages the list of tools to be passed into completion reqeust."""

    def __init__(self, functions: list[Callable]) -> None:
        self.functions = {f.__name__: f for f in functions}
        self.models = {f.__name__: _function_to_pydantic_model(f) for f in functions}
        self.tools = [pydantic_function_tool(model) for model in self.models.values()]

    async def get_tools(self) -> list[ChatCompletionToolParam]:
        """Returns the list of tools to be passed into completion reqeust."""
        return self.tools

    async def handle_tool_calls(
        self, tool_calls: list[ChatCompletionMessageToolCall], context: ToolContext
    ) -> list[LiteLLMMessage]:
        """This is called each time a response is received from completion method."""
        logger.info("Number of tool calls: %s", len(tool_calls))
        messages = []
        for tool_call in tool_calls:
            function = tool_call.function
            assert isinstance(function.name, str)
            if function.name not in self.functions:
                continue
            logger.info("Tool call: %s(%s)", function.name, function.arguments)
            func = self.functions[function.name]
            model = self.models[function.name]
            instance = model.model_validate_json(tool_call.function.arguments)
            kwargs = {name: getattr(instance, name) for name in model.model_fields}

            # Fill in default values
            for param in signature(func).parameters.values():
                if kwargs[param.name] is None and param.default is not Parameter.empty:
                    kwargs[param.name] = param.default

            if asyncio.iscoroutinefunction(func):
                result = await func(**kwargs)
            else:
                result = func(**kwargs)

            logger.info("%s call result: %s", function.name, result)
            messages.append(
                LiteLLMMessage(
                    role="tool",  # type: ignore
                    tool_call_id=tool_call.id,
                    content=result if isinstance(result, str) else json.dumps(result),
                )
            )

        return messages
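
A hedged sketch of exposing plain callables as tools. The functions below are illustrative assumptions; any callable whose annotated parameters can be mapped to a Pydantic model should work, and defaults are filled in when the model passes null:

def add(a: int, b: int) -> int:
    """Add two integers."""
    return a + b

async def greet(name: str, greeting: str = "Hello") -> str:
    """Build a greeting; `greeting` falls back to its default if omitted."""
    return f"{greeting}, {name}!"

toolkit = FunctionToolkit([add, greet])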

get_tools() async

Returns the list of tools to be passed into the completion request.

Source code in framework/toolkit.py
async def get_tools(self) -> list[ChatCompletionToolParam]:
    """Returns the list of tools to be passed into completion reqeust."""
    return self.tools

handle_tool_calls(tool_calls, context) async

This is called each time a response is received from the completion method.

Source code in framework/toolkit.py
async def handle_tool_calls(
    self, tool_calls: list[ChatCompletionMessageToolCall], context: ToolContext
) -> list[LiteLLMMessage]:
    """This is called each time a response is received from completion method."""
    logger.info("Number of tool calls: %s", len(tool_calls))
    messages = []
    for tool_call in tool_calls:
        function = tool_call.function
        assert isinstance(function.name, str)
        if function.name not in self.functions:
            continue
        logger.info("Tool call: %s(%s)", function.name, function.arguments)
        func = self.functions[function.name]
        model = self.models[function.name]
        instance = model.model_validate_json(tool_call.function.arguments)
        kwargs = {name: getattr(instance, name) for name in model.model_fields}

        # Fill in default values
        for param in signature(func).parameters.values():
            if kwargs[param.name] is None and param.default is not Parameter.empty:
                kwargs[param.name] = param.default

        if asyncio.iscoroutinefunction(func):
            result = await func(**kwargs)
        else:
            result = func(**kwargs)

        logger.info("%s call result: %s", function.name, result)
        messages.append(
            LiteLLMMessage(
                role="tool",  # type: ignore
                tool_call_id=tool_call.id,
                content=result if isinstance(result, str) else json.dumps(result),
            )
        )

    return messages
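
A sketch of the round trip for a single tool call, reusing the add/greet toolkit above. The ChatCompletionMessageToolCall and Function construction follows the OpenAI SDK types and is an assumption here:

# Assumed imports from the OpenAI SDK:
from openai.types.chat import ChatCompletionMessageToolCall
from openai.types.chat.chat_completion_message_tool_call import Function

call = ChatCompletionMessageToolCall(
    id="call_1",
    type="function",
    function=Function(name="add", arguments='{"a": 2, "b": 3}'),
)
replies = await toolkit.handle_tool_calls([call], ToolContext(caller="Coordinator"))
# replies[0] is a tool message with tool_call_id="call_1" and content "5"
# (non-string results are serialized with json.dumps).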

LLMAssistant

Bases: Assistant

Provides an Assistant implementation with a given system prompt and toolkit.

Source code in framework/llm_assistant.py
class LLMAssistant(Assistant):
    """Provides an Assistant implementation with a given system prompt and toolkit."""

    def __init__(
        self,
        name: str,
        description: Optional[str] = None,
        model: str = DEFAULT_MODEL,
        system_prompt: Optional[str] = None,
        output_type: Optional[type[BaseModel]] = None,
        toolkit: Optional[Toolkit] = None,
        max_turns: int = 10,
    ):
        """
        Creates a new LLMAssistant.
        """
        self.name = name
        self.description = description
        self.model = model
        self.system_prompt = system_prompt
        self.output_type = output_type
        self.toolkit = toolkit
        self.max_turns = max_turns
        self.examples: list[tuple[str, BaseModel]] = []

    async def run(self, chat: Chat) -> None:
        logger.info("Running assistant %s", self.name)

        # These messages are sent to the LLM API, prefixed by the system prompt.
        messages = self._get_messages(chat)

        async def handle_tool_calls(message: LitellmMessage):
            assert self.toolkit
            assert message.tool_calls
            tool_messages = await self.toolkit.handle_tool_calls(message.tool_calls, ToolContext(caller=self.name))
            assert len(tool_messages) == len(message.tool_calls)
            for tool_message in tool_messages:
                messages.append(tool_message)
                reply = await chat.reply("tool", name=self.name)
                await reply.add_chunk(tool_message["tool_call_id"], field="tool_call_id")
                if tool_message.content:
                    await reply.add_chunk(tool_message.content)
                await reply.end()

        # We start by sending the first message.
        message = await self._complete(messages, chat)
        messages.append(message)

        # We keep calling the OpenAI API until there are no more tool calls.
        current_turn = 0
        while message.tool_calls:
            current_turn += 1
            if current_turn > self.max_turns:
                raise Exception(f"Max turns ({self.max_turns}) exceeded")

            await handle_tool_calls(message)

            # Send messages with tool calls.
            message = await self._complete(messages, chat)
            messages.append(message)

    async def _complete(self, messages: list[LitellmMessage], chat: Chat) -> LitellmMessage:
        # Replace invalid characters in assistant name
        for message in messages:
            if message.get("name"):
                message["name"] = re.sub(r"[^a-zA-Z0-9-]", "_", message["name"])

        logger.info("Completing chat")
        for message in messages:
            logger.debug(message)

        kwargs = {}
        if self.toolkit:
            tools = await self.toolkit.get_tools()
            if tools:
                kwargs["tools"] = tools
                kwargs["tool_choice"] = "auto"
                kwargs["parallel_tool_calls"] = False

        if self.output_type:
            kwargs["response_format"] = self.output_type

        response = await acompletion(
            model=self.model,
            messages=messages,
            stream=True,
            metadata={
                "existing_trace_id": langfuse_context.get_current_trace_id(),
                "parent_observation_id": langfuse_context.get_current_observation_id(),
            },
            **kwargs,
        )
        assert isinstance(response, CustomStreamWrapper)

        # We start by sending a begin_message event to the web client.
        # This will cause the web client to draw a new message box for the assistant.
        reply = await chat.reply("assistant", name=self.name)

        # We aggregate delta messages in this builder until we see a finish_reason.
        # This is the only way to get the full content of the message.
        builder = MessageBuilder()

        # We will return this value at the end of the function.
        message: Optional[LitellmMessage] = None

        # Do not break this loop. Otherwise, litellm will not be able to run callbacks.
        async for chunk in response:
            assert chunk.__class__.__name__ == "ModelResponseStream"
            assert len(chunk.choices) == 1
            choice = chunk.choices[0]
            events = builder.write(choice.delta)
            for event in events:
                await reply.add_chunk(event.chunk, field=event.name)

            if finish_reason := choice.finish_reason:
                message = builder.getvalue()
                if finish_reason not in ("stop", "tool_calls"):
                    raise NotImplementedError(f"finish_reason={finish_reason}")
                await reply.end()

        if not message:
            raise Exception("Stream ended unexpectedly")

        return message

    def _get_messages(self, chat: Chat) -> list[LitellmMessage]:
        messages: list[LitellmMessage] = []

        messages.append(
            LitellmMessage(
                role="system",  # type: ignore
                content=self._get_system_prompt(),
            )
        )

        for user_message, response in self.examples:
            messages.extend(
                [
                    LitellmMessage(
                        role="system",  # type: ignore
                        name="example_user",
                        content=user_message,
                    ),
                    LitellmMessage(
                        role="system",  # type: ignore
                        name="example_assistant",
                        content=response.model_dump_json(),
                    ),
                ]
            )

        messages.extend([message_to_litellm(message) for message in chat.state.messages])
        return messages

    def _get_system_prompt(self) -> str:
        prompt = self.system_prompt or ""
        if prompt:
            prompt += "\n\n"

        t = datetime.now().strftime("%A, %B %d, %Y at %I:%M:%S %p")
        o = time.strftime("%z")  # Timezone offset
        prompt += f"Today's date and time is {t} ({o})"

        return prompt

    def add_example(self, user_message: str, response: BaseModel):
        """Add an example to the prompt."""
        self.examples.append((user_message, response))
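
A usage sketch, mirroring how AssistantToolkit._complete_task drives an assistant above. Chat and Message construction follow that code; the `add` function and toolkit are the illustrative ones from the FunctionToolkit example:

assistant = LLMAssistant(
    name="Helper",
    model="gpt-4.1",
    system_prompt="Answer concisely.",
    toolkit=FunctionToolkit([add]),
)

chat = Chat()
chat.state.messages.append(Message(role="user", name="user", content="What is 2 + 3?"))
await assistant.run(chat)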

__init__(name, description=None, model=DEFAULT_MODEL, system_prompt=None, output_type=None, toolkit=None, max_turns=10)

Creates a new LLMAssistant.

Source code in framework/llm_assistant.py
def __init__(
    self,
    name: str,
    description: Optional[str] = None,
    model: str = DEFAULT_MODEL,
    system_prompt: Optional[str] = None,
    output_type: Optional[type[BaseModel]] = None,
    toolkit: Optional[Toolkit] = None,
    max_turns: int = 10,
):
    """
    Creates a new LLMAssistant.
    """
    self.name = name
    self.description = description
    self.model = model
    self.system_prompt = system_prompt
    self.output_type = output_type
    self.toolkit = toolkit
    self.max_turns = max_turns
    self.examples: list[tuple[str, BaseModel]] = []

add_example(user_message, response)

Add an example to the prompt.

Source code in framework/llm_assistant.py
def add_example(self, user_message: str, response: BaseModel):
    """Add an example to the prompt."""
    self.examples.append((user_message, response))
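
A short sketch of few-shot examples with a structured output type. The Sentiment model is an illustrative assumption; examples are injected as system messages named "example_user" and "example_assistant" (see _get_messages above):

class Sentiment(BaseModel):
    label: str
    score: float

classifier = LLMAssistant(name="Classifier", output_type=Sentiment)
classifier.add_example("I love this!", Sentiment(label="positive", score=0.98))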

MultiToolkit

Bases: Toolkit

Combines multiple toolkits into one.

Source code in framework/toolkit.py
class MultiToolkit(Toolkit):
    """Combines multiple toolkits into one."""

    def __init__(self, toolkits: list[Toolkit]) -> None:
        self.toolkits = toolkits

    async def get_tools(self) -> list[ChatCompletionToolParam]:
        tools = []
        for toolkit in self.toolkits:
            tools.extend(await toolkit.get_tools())
        return tools

    async def handle_tool_calls(
        self, tool_calls: list[ChatCompletionMessageToolCall], context: ToolContext
    ) -> list[LiteLLMMessage]:
        messages = []
        for toolkit in self.toolkits:
            messages.extend(await toolkit.handle_tool_calls(tool_calls, context))
        return messages
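
Because each toolkit skips tool calls it does not own (FunctionToolkit ignores unknown function names, AssistantToolkit only handles "delegate_task"), combining them is safe. A sketch, reusing the illustrative toolkits above:

toolkit = MultiToolkit([
    FunctionToolkit([add, greet]),
    AssistantToolkit(assistants=["Mailer", "Researcher"]),
])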

Toolkit

Bases: ABC

Manages the list of tools to be passed into the completion request.

Source code in framework/toolkit.py
class Toolkit(ABC):
    """Manages the list of tools to be passed into completion reqeust."""

    @abstractmethod
    async def get_tools(self) -> list[ChatCompletionToolParam]: ...

    @abstractmethod
    async def handle_tool_calls(
        self, tool_calls: list[ChatCompletionMessageToolCall], context: ToolContext
    ) -> list[LiteLLMMessage]: ...
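
A minimal custom implementation sketch. The echo tool and its JSON schema are illustrative assumptions; only the two abstract methods are required, and ChatCompletionToolParam is a TypedDict, so a plain dict literal is accepted:

import json  # assumed; used to parse the tool call arguments below

class EchoToolkit(Toolkit):
    async def get_tools(self) -> list[ChatCompletionToolParam]:
        # A single hand-written tool schema (hypothetical "echo" tool).
        return [
            {
                "type": "function",
                "function": {
                    "name": "echo",
                    "description": "Echo the input back to the model.",
                    "parameters": {
                        "type": "object",
                        "properties": {"text": {"type": "string"}},
                        "required": ["text"],
                    },
                },
            }
        ]

    async def handle_tool_calls(
        self, tool_calls: list[ChatCompletionMessageToolCall], context: ToolContext
    ) -> list[LiteLLMMessage]:
        messages = []
        for tool_call in tool_calls:
            # Skip calls owned by other toolkits, as FunctionToolkit does.
            if tool_call.function.name != "echo":
                continue
            args = json.loads(tool_call.function.arguments)
            messages.append(
                LiteLLMMessage(
                    role="tool",
                    tool_call_id=tool_call.id,
                    content=args["text"],
                )
            )
        return messages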