from agno.db.sqlite import SqliteDb

from src.agent.base import creat_agent
from src.infra.logger import get_logger

logger = get_logger(__name__)


def create_planner_agent(model, base_kwargs=None, tools=None, session_id=None, db_file="tmp/planner.db"):
    """Create the "planner" agent backed by a SQLite session store.

    The agent starts with the session state ``{"task_list": None}``. It is
    configured to add that state and the current datetime to its context and
    to answer in markdown.

    Args:
        model: Model instance forwarded to ``creat_agent``.
        base_kwargs: Extra keyword arguments forwarded to ``creat_agent``. They
            must not repeat a keyword this function already passes (``name``,
            ``model``, ``tools``, ``db``, ``session_state``, ...). Otherwise
            Python raises ``TypeError`` for a duplicate keyword argument.
        tools: Optional tools for the agent.
        session_id: Optional session id, forwarded to ``creat_agent``.
        db_file: Path of the SQLite file handed to ``SqliteDb``.

    Returns:
        The agent returned by ``creat_agent``.
    """
    # Use None rather than a shared ``{}`` default, so one call can never leak
    # state into another.
    base_kwargs = {} if base_kwargs is None else base_kwargs
    # NOTE(review): unverified whether SqliteDb creates the parent directory of
    # db_file; if it does not, "tmp/" must exist beforehand.
    planner_db = SqliteDb(db_file=db_file)
    planner_session_state = {"task_list": None}
    planner_agent = creat_agent(
        name="planner",
        model=model,
        tools=tools,
        db=planner_db,
        session_state=planner_session_state,
        add_session_state_to_context=True,
        add_datetime_to_context=True,
        markdown=True,
        session_id=session_id,
        **base_kwargs,
    )
    logger.info(f"🧑‍✈️ Planner agent created - {planner_agent.session_id}")
    return planner_agent


if __name__ == "__main__":
    from agno.agent import RunEvent
    from agno.models.google import Gemini

    from src.agent.base import Location, UserState, get_context
    from src.infra.config import get_settings

    # Marker separating the planner's user-facing reply from the rest of its
    # output. Nothing after it is echoed to the console.
    # NOTE(review): the marker/JSON protocol is defined by the agent's prompt,
    # which is not visible here — confirm.
    HIDDEN_MARKER = "@@@"

    def planner_stream_handle(stream_item):
        """Echo the visible part of a streamed run and extract its JSON payload.

        Args:
            stream_item: Iterable of run events yielded by ``agent.run(..., stream=True)``.

        Returns:
            A tuple ``(json_data, response)``:

            - ``json_data`` is the text from the first ``"{"`` to the end of the
              response, or ``None`` when the response contains no ``"{"``.
            - ``response`` is the concatenated content of all content events.
        """
        show = True
        response = ""
        for chunk in stream_item:
            # Only content events contribute text; skip empty payloads instead
            # of failing on ``str += None``.
            if chunk.event != RunEvent.run_content or chunk.content is None:
                continue
            content = chunk.content
            response += content
            if show:
                if HIDDEN_MARKER in response:
                    show = False
                    content = content.split(HIDDEN_MARKER)[0]
                # Chunks are fragments of one message: print them contiguously.
                print(content, end="", flush=True)
        print()
        _, brace, tail = response.partition("{")
        # Without a "{" there is no payload; don't fabricate one.
        json_data = brace + tail if brace else None
        return json_data, response

    def planner_message(agent, message):
        """Run the planner on *message* and store the extracted task list in session state.

        The session state is left unchanged when the response carries no task list.
        """
        stream = agent.run(
            f"help user to update the task_list, user's message: {message}",
            stream=True,
            stream_events=True,
        )
        task_list, _response = planner_stream_handle(stream)
        if task_list is None:
            logger.warning("Planner response contained no task list; session state left unchanged")
            return
        agent.update_session_state(
            session_id=agent.session_id,
            session_state_updates={"task_list": task_list},
        )

    user_message = "I'm going to San Francisco for tourism tomorrow, please help me plan a one-day itinerary."
    settings = get_settings()
    main_model = Gemini(id="gemini-2.5-flash", thinking_budget=1024, api_key=settings.gemini_api_key)
    user_state = UserState(location=Location(lat=25.058903, lng=121.549131))
    kwargs = {
        "additional_context": get_context(user_state),
        # NOTE(review): timezone_identifier presumably expects an IANA zone name
        # (e.g. "Asia/Taipei"); confirm that UserState.utc_offset has that form.
        "timezone_identifier": user_state.utc_offset,
    }
    planner_agent = create_planner_agent(main_model, kwargs)
    planner_message(planner_agent, user_message)
    print(planner_agent.get_session_state())