# -*- coding: utf-8 -*-
"""Generative Manim LangGraph Implementation.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1YSO9TG2fJVVH4l7yTHE_V-v8VaDfpM2s

# Generative Manim LangGraph Implementation

Following the LangGraph example
[Code generation with flow](https://github.com/langchain-ai/langgraph/blob/main/examples/code_assistant/langgraph_code_assistant.ipynb?ref=blog.langchain.dev),
we implement a similar approach to generate code for Manim animations.

For now we skip unit-test validation of the generated animations; that step
can be added later. The graph below only checks that the generated imports and
code block execute without raising.
"""

# ---------------------------------------------------------------------------
# Extracting examples from the Manim docs
# ---------------------------------------------------------------------------

# Load .env (API keys are read from the environment below).
from dotenv import load_dotenv

load_dotenv()

import os

from bs4 import BeautifulSoup as Soup
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader

# Crawl the Manim examples page (following links up to max_depth) and keep
# only the visible text of each fetched page.
url = "https://docs.manim.community/en/stable/examples.html"
loader = RecursiveUrlLoader(
    url=url, max_depth=20, extractor=lambda x: Soup(x, "html.parser").text
)
docs = loader.load()

# Order the pages by source URL (reverse alphabetical) and join their text
# into one context string with a visible separator between pages.
d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
d_reversed = list(reversed(d_sorted))
concatenated_content = "\n\n\n --- \n\n\n".join(
    [doc.page_content for doc in d_reversed]
)

# ---------------------------------------------------------------------------
# LLM solution (OpenAI)
# ---------------------------------------------------------------------------

from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field

# Manim code-generation prompt: the docs text is injected as {context} and the
# conversation so far is spliced in at the {messages} placeholder.
code_gen_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """You are a coding assistant with expertise in Manim, Graphical Animation Library. \n
    Here is a full set of Manim documentation: \n ------- \n {context} \n ------- \n Answer the user
    question based on the above provided documentation. Ensure any code you provide can be executed \n
    with all required imports and variables defined. Structure your answer with a description of the code solution. \n
    Then list the imports. And finally list the functioning code block. Here is the user question:""",
        ),
        ("placeholder", "{messages}"),
    ]
)


# Structured-output schema. The lowercase class name is kept on purpose:
# presumably it becomes the tool name the model sees (the prompts refer to
# "the code tool") — confirm before renaming.
class code(BaseModel):
    """Code output"""

    prefix: str = Field(description="Description of the problem and approach")
    imports: str = Field(description="Code block import statements")
    code: str = Field(description="Code block not including import statements")
    description = "Schema for code solutions to requests on code for Manim Animations."


import anthropic

OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
# NOTE(review): ChatAnthropic below is not given this key explicitly; it
# presumably reads ANTHROPIC_API_KEY from the environment — confirm this
# module attribute assignment is needed at all.
anthropic.api_key = os.getenv('ANTHROPIC_API_KEY')

expt_llm = "gpt-4-0125-preview"
llm = ChatOpenAI(temperature=0, model=expt_llm, openai_api_key=OPENAI_API_KEY)
code_gen_chain = code_gen_prompt | llm.with_structured_output(code)

question = "Draw three red circles"
solution = code_gen_chain.invoke({"context": concatenated_content, "messages": [("user", question)]})
print(solution)

# ---------------------------------------------------------------------------
# Now let's try with Anthropic
# ---------------------------------------------------------------------------

from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field

# Claude variant of the prompt; it explicitly asks the model to invoke the
# `code` tool so the answer comes back as structured output.
code_gen_prompt_claude = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """ You are a coding assistant with expertise in Manim, Graphical Animation Library. \n
    Here is a full set of Manim documentation: \n ------- \n {context} \n ------- \n Answer the user request based on the \n
    above provided documentation. Ensure any code you provide can be executed with all required imports and variables \n
    defined. Structure your answer: 1) a prefix describing the code solution, 2) the imports, 3) the functioning code block. \n
    Invoke the code tool from Manim to structure the output correctly. \n Here is the user request:""",
        ),
        ("placeholder", "{messages}"),
    ]
)


# Data model (redefines `code`; the OpenAI chain above already holds the
# previous class object).
class code(BaseModel):
    """Code output"""

    prefix: str = Field(description="Description of the problem and approach")
    imports: str = Field(description="Code block import statements")
    code: str = Field(description="Code block not including import statements")
    description = "Schema for code solutions to questions about Manim."


# LLM
# expt_llm = "claude-3-haiku-20240307"
expt_llm = "claude-3-opus-20240229"
llm = ChatAnthropic(
    model=expt_llm,
    default_headers={"anthropic-beta": "tools-2024-04-04"},
)
# include_raw=True: the runnable returns a dict with 'raw', 'parsed' and
# 'parsing_error' instead of just the parsed object.
structured_llm_claude = llm.with_structured_output(code, include_raw=True)
code_chain_claude_raw = code_gen_prompt_claude | structured_llm_claude


def insert_errors(inputs):
    """Append a retry instruction describing the previous error to the messages.

    Used as the first step of the fallback chain below; ``with_fallbacks(...,
    exception_key="error")`` supplies the previous attempt's exception under
    the ``"error"`` key.

    Args:
        inputs: dict with ``"messages"`` (list of (role, content) tuples),
            ``"context"`` and ``"error"``.

    Returns:
        A new dict with the extended ``"messages"`` and the same ``"context"``.
    """
    error = inputs["error"]
    # Build a new list instead of `+=`: the message list is shared with the
    # caller's input, so mutating it would make every retry accumulate.
    messages = inputs["messages"] + [
        (
            "assistant",
            f"Retry. You are required to fix the parsing errors: {error} \n\n You must invoke the provided tool.",
        )
    ]
    return {
        "messages": messages,
        "context": inputs["context"],
    }


# This will be run as a fallback chain
fallback_chain = insert_errors | code_chain_claude_raw
N = 3  # Max re-tries
code_gen_chain_re_try = code_chain_claude_raw.with_fallbacks(
    fallbacks=[fallback_chain] * N, exception_key="error"
)


def parse_output(solution):
    """Return the parsed ``code`` object from an ``include_raw=True`` result.

    Args:
        solution: dict with ``'raw'``, ``'parsed'`` and ``'parsing_error'``.

    Returns:
        The parsed ``code`` instance.

    Raises:
        ValueError: if nothing was parsed (e.g. the model did not invoke the
            tool), instead of returning ``None`` and failing later with an
            ``AttributeError`` on ``.prefix``.
    """
    parsed = solution['parsed']
    if parsed is None:
        raise ValueError(
            f"Model output could not be parsed into `code`: {solution.get('parsing_error')}"
        )
    return parsed


# With re-try to correct for failure to invoke tool
# TODO: Annoying errors w/ "user" vs "assistant":
# Roles must alternate between "user" and "assistant", but found multiple "user" roles in a row
code_gen_chain = code_gen_chain_re_try | parse_output

# No re-try. This later assignment wins: it is the chain used by the graph below.
code_gen_chain = code_gen_prompt_claude | structured_llm_claude | parse_output

# Test
question = "Draw a red circle"
solution = code_gen_chain.invoke({"context": concatenated_content, "messages": [("user", question)]})
print(solution)

# ---------------------------------------------------------------------------
# State
#
# Our state is a dict holding the keys relevant to code generation
# (error flag, messages, latest generation, iteration count).
# ---------------------------------------------------------------------------

from typing import Any, Dict, TypedDict, List


class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        error : "yes" if the last code check failed, "no" if it passed
                ("" before the first check)
        messages : Conversation as (role, content) tuples: user question,
                   generated solutions, error reports, reflections
        generation : Latest structured `code` solution
        iterations : Number of generation attempts so far
    """

    error: str
    messages: List
    generation: Any  # a `code` instance
    iterations: int


# ---------------------------------------------------------------------------
# Graph
#
# generate -> check_code -> (end | generate | reflect -> generate)
# ---------------------------------------------------------------------------

from operator import itemgetter
from langchain.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnablePassthrough

### Parameter

# Max generate/check attempts before finishing anyway
max_iterations = 3

# Set to 'reflect' to insert a reflection step after a failed check
# flag = 'reflect'
flag = 'do not reflect'


def _with_user_message(messages, text):
    """Return a copy of ``messages`` with a user turn ``text`` appended.

    If the last message is already a ("user", ...) tuple, ``text`` is merged
    into it so the conversation keeps alternating user/assistant roles, which
    the Anthropic chat API has rejected otherwise ("Roles must alternate").
    The input list is never mutated.
    """
    if messages and isinstance(messages[-1], tuple) and messages[-1][0] == "user":
        role, content = messages[-1]
        return messages[:-1] + [(role, f"{content}\n\n{text}")]
    return messages + [("user", text)]


### Nodes


def generate(state: GraphState):
    """
    Generate a code solution

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Update with the new generation, the extended messages
        and the incremented iteration count
    """
    print("---GENERATING CODE SOLUTION---")

    # State
    messages = state["messages"]
    iterations = state["iterations"]
    # `.get`: callers may not seed "error" in the initial state.
    error = state.get("error", "")

    # We have been routed back to generation with an error
    if error == "yes":
        messages = _with_user_message(
            messages,
            "Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:",
        )

    # Solution
    code_solution = code_gen_chain.invoke({"context": concatenated_content, "messages": messages})
    messages = messages + [
        ("assistant", f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}")
    ]

    return {"generation": code_solution, "messages": messages, "iterations": iterations + 1}


def code_check(state: GraphState):
    """
    Check that the latest generation's imports and code execute.

    SECURITY NOTE: this exec()s LLM-generated code inside the current process
    with full privileges. Only run it in a sandboxed / throwaway environment.

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Update with error set to "yes" (plus an error report
        appended to messages) or "no"
    """
    print("---CHECKING CODE---")

    # State
    messages = state["messages"]
    code_solution = state["generation"]
    iterations = state["iterations"]

    # Get solution components (named `code_block` to avoid shadowing the
    # `code` schema class)
    imports = code_solution.imports
    code_block = code_solution.code

    # Each check runs in a fresh namespace dict: the snippet gets module-like
    # semantics (names it imports are visible to functions/classes it defines)
    # and cannot touch this function's or this module's namespaces.

    # Check imports
    try:
        exec(imports, {})
    except Exception as e:
        print("---CODE IMPORT CHECK: FAILED---")
        messages = _with_user_message(messages, f"Your solution failed the import test: {e}")
        return {"generation": code_solution, "messages": messages, "iterations": iterations, "error": "yes"}

    # Check execution
    try:
        exec(imports + "\n" + code_block, {})
    except Exception as e:
        print("---CODE BLOCK CHECK: FAILED---")
        messages = _with_user_message(messages, f"Your solution failed the code execution test: {e}")
        return {"generation": code_solution, "messages": messages, "iterations": iterations, "error": "yes"}

    # No errors
    print("---NO CODE TEST FAILURES---")
    return {"generation": code_solution, "messages": messages, "iterations": iterations, "error": "no"}


def reflect(state: GraphState):
    """
    Reflect on errors

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): Update with the reflection request and the model's
        reflections appended to messages
    """
    print("---REFLECTING ON ERROR---")

    # State
    messages = state["messages"]
    iterations = state["iterations"]
    code_solution = state["generation"]

    # Prompt reflection: actually send the request, otherwise the model is
    # simply re-asked the original problem.
    messages = _with_user_message(
        messages,
        "You tried to solve this problem and failed a unit test. Reflect on this failure "
        "given the provided documentation. Write a few key suggestions based on the "
        "documentation to avoid making this mistake again.",
    )

    # Add reflection
    # NOTE(review): code_gen_chain returns a structured `code` object, so the
    # reflection is embedded via its repr; a plain LLM call may fit better.
    reflections = code_gen_chain.invoke({"context": concatenated_content, "messages": messages})
    messages = messages + [("assistant", f"Here are reflections on the error: {reflections}")]
    return {"generation": code_solution, "messages": messages, "iterations": iterations}


### Edges


def decide_to_finish(state: GraphState):
    """
    Determines whether to finish.

    Args:
        state (dict): The current graph state

    Returns:
        str: Next node to call ("end", "reflect" or "generate")
    """
    error = state["error"]
    iterations = state["iterations"]

    # `>=` so the loop still stops if iterations starts at or passes the limit.
    if error == "no" or iterations >= max_iterations:
        print("---DECISION: FINISH---")
        return "end"

    print("---DECISION: RE-TRY SOLUTION---")
    return "reflect" if flag == 'reflect' else "generate"


from langgraph.graph import END, StateGraph

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("generate", generate)  # generation solution
workflow.add_node("check_code", code_check)  # check code
workflow.add_node("reflect", reflect)  # reflect

# Build graph
workflow.set_entry_point("generate")
workflow.add_edge("generate", "check_code")
workflow.add_conditional_edges(
    "check_code",
    decide_to_finish,
    {
        "end": END,
        "reflect": "reflect",
        "generate": "generate",
    },
)
workflow.add_edge("reflect", "generate")
app = workflow.compile()

question = "Draw a red circle"
# Seed every state key the nodes read, including "error".
result = app.invoke({"messages": [("user", question)], "iterations": 0, "error": ""})
print(result["generation"])