Source code for climb.tool.impl.sub_agents

from typing import Any, Dict

from openai import AzureOpenAI, OpenAI

from climb.common import Session
from climb.engine.const import MODEL_MAX_MESSAGE_TOKENS


def create_llm_client(
    session: Session,
    additional_kwargs_required: Dict[str, Any],
) -> Any:
    """Construct the OpenAI SDK client selected by ``session.engine_name``.

    Args:
        session: Session whose ``engine_name`` chooses the backend
            (``"openai_v1"`` or ``"azure_openai_v1"``).
        additional_kwargs_required: Connection settings. ``"api_key"`` is read
            for both backends; the Azure backend also reads
            ``"azure_endpoint"`` and ``"api_version"``.

    Returns:
        An ``OpenAI`` client for ``"openai_v1"``, or an ``AzureOpenAI`` client
        for ``"azure_openai_v1"``.

    Raises:
        ValueError: If ``session.engine_name`` is not a supported engine.
        KeyError: If a required setting is missing.
    """
    settings = additional_kwargs_required

    if session.engine_name in ("openai_v1",):
        return OpenAI(api_key=settings["api_key"])

    if session.engine_name in ("azure_openai_v1",):
        return AzureOpenAI(
            azure_endpoint=settings["azure_endpoint"],
            api_version=settings["api_version"],
            api_key=settings["api_key"],
        )

    raise ValueError(f"Unknown engine name: {session.engine_name}")


def get_llm_chat(
    client: Any,
    session: Session,
    additional_kwargs_required: Dict[str, Any],
    chat_kwargs: Dict,
) -> str:
    """Send one chat-completion request and return the assistant's text.

    Args:
        client: Client created for the same ``session.engine_name``
            (e.g. by ``create_llm_client``).
        session: Session whose ``engine_name`` selects how the request is built.
        additional_kwargs_required: For ``"openai_v1"``, must contain
            ``"engine_params"`` with ``"model_id"`` and ``"temperature"``.
            For ``"azure_openai_v1"``, must contain ``"azure_openai_config"``
            exposing ``.model`` and ``.deployment_name``.
        chat_kwargs: Must contain ``"messages"``. ``"stream"`` is optional and
            defaults to ``False``.

    Returns:
        The assistant message text. When streaming, the text fragments of all
        chunks are concatenated. A ``None`` message content is returned as ``""``
        so the result is always a ``str``.

    Raises:
        ValueError: If ``session.engine_name`` is not a supported engine.
        KeyError: If a required setting is missing, or the model has no entry
            in ``MODEL_MAX_MESSAGE_TOKENS``.
    """
    stream = chat_kwargs.get("stream", False)

    if session.engine_name in ("openai_v1",):
        engine_params = additional_kwargs_required["engine_params"]
        model_type = engine_params["model_id"]
        out = client.chat.completions.create(
            model=model_type,
            max_tokens=MODEL_MAX_MESSAGE_TOKENS[model_type],
            temperature=engine_params["temperature"],
            # ---
            messages=chat_kwargs["messages"],
            stream=stream,
        )
    elif session.engine_name in ("azure_openai_v1",):
        azure_config = additional_kwargs_required["azure_openai_config"]
        # The request passes the deployment name as ``model``; ``.model`` is
        # only used here to look up the token limit.
        model_type = azure_config.model
        out = client.chat.completions.create(
            model=azure_config.deployment_name,
            max_tokens=MODEL_MAX_MESSAGE_TOKENS[model_type],
            # ---
            messages=chat_kwargs["messages"],
            stream=stream,
        )
    else:
        raise ValueError(f"Unknown engine name: {session.engine_name}")

    return _response_text(out, stream)


def _response_text(out: Any, stream: Any) -> str:
    """Extract assistant text from a completion, or from an iterable of stream chunks."""
    if not stream:
        # Non-streamed: a single completion object; content may be None.
        return out.choices[0].message.content or ""

    # Streamed: each chunk carries an incremental ``delta``. A chunk may have
    # no choices, or a ``None`` delta content, so both cases are skipped.
    parts = []
    for chunk in out:
        if not chunk.choices:
            continue
        piece = chunk.choices[0].delta.content
        if piece:
            parts.append(piece)
    return "".join(parts)