You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
352 lines
11 KiB
352 lines
11 KiB
# -*- coding: utf-8 -*-
"""Generative Manim LangGraph Implementation.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1YSO9TG2fJVVH4l7yTHE_V-v8VaDfpM2s

# Generative Manim LangGraph Implementation

Following the [Code generation with flow](https://github.com/langchain-ai/langgraph/blob/main/examples/code_assistant/langgraph_code_assistant.ipynb?ref=blog.langchain.dev) example, we implement a similar approach to generate code for Manim animations. For now we do not think unit-test validation is needed; that step can be added later.
"""
"""## Extracting examples from Manim docs"""
|
|
# Load .env
|
|
from dotenv import load_dotenv
|
|
load_dotenv()
|
|
|
|
import os
|
|
from bs4 import BeautifulSoup as Soup
|
|
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
|
|
|
|
# Manim Examples docs
|
|
url = "https://docs.manim.community/en/stable/examples.html"
|
|
loader = RecursiveUrlLoader(
|
|
url=url, max_depth=20, extractor=lambda x: Soup(x, "html.parser").text
|
|
)
|
|
docs = loader.load()
|
|
|
|
# Sort the list based on the URLs and get the text
|
|
d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
|
|
d_reversed = list(reversed(d_sorted))
|
|
concatenated_content = "\n\n\n --- \n\n\n".join(
|
|
[doc.page_content for doc in d_reversed]
|
|
)
|
|
|
|
"""## LLM Solution"""
|
|
|
|
from langchain_openai import ChatOpenAI
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
|
|
|
### OpenAI
|
|
|
|
# Manim code generation prompt
|
|
code_gen_prompt = ChatPromptTemplate.from_messages(
|
|
[("system","""You are a coding assistant with expertise in Manim, Graphical Animation Library. \n
|
|
Here is a full set of Manim documentation: \n ------- \n {context} \n ------- \n Answer the user
|
|
question based on the above provided documentation. Ensure any code you provide can be executed \n
|
|
with all required imports and variables defined. Structure your answer with a description of the code solution. \n
|
|
Then list the imports. And finally list the functioning code block. Here is the user question:"""),
|
|
("placeholder", "{messages}")]
|
|
)
|
|
|
|
class code(BaseModel):
    """Schema for code solutions to requests on code for Manim Animations."""

    # NOTE: LangChain uses this model's docstring as the tool description.
    # Do not add an un-annotated `description = "..."` class attribute:
    # pydantic v1 infers un-annotated class attributes as model fields, so it
    # would add an extra optional `description` parameter to the tool schema.

    # Prose explanation of the problem and the chosen approach.
    prefix: str = Field(description="Description of the problem and approach")
    # Import statements only (code_check executes these on their own first).
    imports: str = Field(description="Code block import statements")
    # The solution body, executed after `imports`.
    code: str = Field(description="Code block not including import statements")
import anthropic

OPENAI_API_KEY = os.getenv('OPENAI_API_KEY')
# NOTE(review): the modern anthropic SDK does not appear to read a module-level
# `api_key`; ChatAnthropic picks up ANTHROPIC_API_KEY from the environment on
# its own. Kept as-is for compatibility. Confirm, then drop this line.
anthropic.api_key = os.getenv('ANTHROPIC_API_KEY')

expt_llm = "gpt-4-0125-preview"
# temperature=0 for the most repeatable code generation.
llm = ChatOpenAI(temperature=0, model=expt_llm, openai_api_key=OPENAI_API_KEY)
# Structured output: the chain returns an instance of the `code` schema.
code_gen_chain = code_gen_prompt | llm.with_structured_output(code)

question = "Draw three red circles"

solution = code_gen_chain.invoke({"context": concatenated_content, "messages": [("user", question)]})

# A bare `solution` expression only displays inside a notebook; print it so
# the result is visible when this file runs as a script.
print(solution)
"""Now let's try with Anthropic"""
|
|
|
|
from langchain_anthropic import ChatAnthropic
|
|
from langchain_core.prompts import ChatPromptTemplate
|
|
from langchain_core.pydantic_v1 import BaseModel, Field
|
|
|
|
code_gen_prompt_claude = ChatPromptTemplate.from_messages(
|
|
[("system","""<instructions> You are a coding assistant with expertise in Manim, Graphical Animation Library. \n
|
|
Here is a full set of Manim documentation: \n ------- \n {context} \n ------- \n Answer the user request based on the \n
|
|
above provided documentation. Ensure any code you provide can be executed with all required imports and variables \n
|
|
defined. Structure your answer: 1) a prefix describing the code solution, 2) the imports, 3) the functioning code block. \n
|
|
Invoke the code tool from Manim to structure the output correctly. </instructions> \n Here is the user request:""",),
|
|
("placeholder", "{messages}"),])
|
|
|
|
# Data model
# NOTE: this re-binds `code`; the Claude structured-output chain below uses
# this definition.
class code(BaseModel):
    """Schema for code solutions to questions about Manim."""

    # NOTE: LangChain uses this model's docstring as the tool description.
    # Do not add an un-annotated `description = "..."` class attribute:
    # pydantic v1 infers un-annotated class attributes as model fields, so it
    # would add an extra optional `description` parameter to the tool schema.

    # Prose explanation of the problem and the chosen approach.
    prefix: str = Field(description="Description of the problem and approach")
    # Import statements only (code_check executes these on their own first).
    imports: str = Field(description="Code block import statements")
    # The solution body, executed after `imports`.
    code: str = Field(description="Code block not including import statements")
# LLM
# expt_llm = "claude-3-haiku-20240307"
expt_llm = "claude-3-opus-20240229"
llm = ChatAnthropic(
    model=expt_llm,
    # Opt-in header for Anthropic's (then beta) tool-use API.
    default_headers={"anthropic-beta": "tools-2024-04-04"},
)

# include_raw=True: the runnable returns a dict with "raw", "parsed" and
# "parsing_error" instead of raising when the tool call cannot be parsed.
structured_llm_claude = llm.with_structured_output(code, include_raw=True)


def check_claude_output(tool_output):
    """Raise if Claude's structured output is unusable, else pass it through.

    With include_raw=True, parse failures are *returned*, not raised. Without
    this check, the `with_fallbacks` re-try chain built on
    `code_chain_claude_raw` never sees an exception and never retries.

    Args:
        tool_output (dict): Output of `structured_llm_claude`, with keys
            "raw", "parsed" and "parsing_error".

    Returns:
        dict: `tool_output`, unchanged, when a parsed object is present.

    Raises:
        ValueError: if parsing failed or the model did not call the tool.
    """
    if tool_output["parsing_error"]:
        raw = tool_output["raw"]
        raw_output = str(getattr(raw, "content", raw))
        raise ValueError(
            f"Error parsing your output! Be sure to invoke the tool. Output: {raw_output}. \n"
            f" Parse error: {tool_output['parsing_error']}"
        )
    if tool_output["parsed"] is None:
        raise ValueError(
            "You did not use the provided tool! Be sure to invoke the tool to structure the output."
        )
    return tool_output


# prompt -> Claude (structured output) -> validation; a raise in the last step
# is what triggers the fallback/re-try chain.
code_chain_claude_raw = code_gen_prompt_claude | structured_llm_claude | check_claude_output
def insert_errors(inputs):
    """Insert errors for tool parsing in the messages.

    Used as the first step of the fallback chain. With
    ``with_fallbacks(..., exception_key="error")``, the exception from the
    previous attempt is passed in under ``inputs["error"]``.

    Args:
        inputs (dict): Chain input with keys ``"error"``, ``"messages"`` and
            ``"context"``.

    Returns:
        dict: ``{"messages": ..., "context": ...}``. ``messages`` is a NEW list:
        the incoming messages followed by an ("assistant", ...) retry note that
        quotes the error. The caller's list is not modified.
    """
    error = inputs["error"]
    retry_note = (
        "assistant",
        f"Retry. You are required to fix the parsing errors: {error} \n\n You must invoke the provided tool.",
    )
    # Build a fresh list: the previous `messages += [...]` mutated the caller's
    # list in place, leaking retry notes into the original conversation and
    # stacking them up across successive fallbacks.
    messages = list(inputs["messages"]) + [retry_note]
    return {
        "messages": messages,
        "context": inputs["context"],
    }
# Fallback step: fold the previous failure into the conversation, then call
# the Claude chain again.
fallback_chain = insert_errors | code_chain_claude_raw

N = 3  # Max re-tries

# Try the plain chain first, then up to N identical fallbacks. Each fallback
# receives the previous exception under the "error" input key.
code_gen_chain_re_try = code_chain_claude_raw.with_fallbacks(
    fallbacks=[fallback_chain for _ in range(N)],
    exception_key="error",
)
def parse_output(solution):
    """Return the parsed object from an ``include_raw=True`` structured output.

    When ``include_raw=True`` is passed to ``with_structured_output``, the
    result is a dict with keys ``'raw'``, ``'parsed'`` and ``'parsing_error'``.

    Args:
        solution (dict): That structured-output dict.

    Returns:
        The parsed ``code`` object.

    Raises:
        ValueError: if ``'parsed'`` is None (parsing failed or the tool was
            not called). Returning None would only fail later, and less
            clearly, on attribute access such as ``.prefix``.
    """
    parsed = solution["parsed"]
    if parsed is None:
        raise ValueError(
            f"Model output could not be parsed into the code schema: {solution.get('parsing_error')}"
        )
    return parsed
# With re-try to correct for failure to invoke tool
# TODO: Annoying errors w/ "user" vs "assistant"
# Roles must alternate between "user" and "assistant", but found multiple "user" roles in a row
code_gen_chain = code_gen_chain_re_try | parse_output

# No re-try. This assignment overrides the one above, so the re-try chain is
# currently unused; swap the order to switch back to it.
code_gen_chain = code_gen_prompt_claude | structured_llm_claude | parse_output

# Test
question = "Draw a red circle"
solution = code_gen_chain.invoke({"context": concatenated_content, "messages": [("user", question)]})
# A bare `solution` expression only displays inside a notebook; print it here.
print(solution)
"""## State
|
|
|
|
Our state is a dict that will contain keys (errors, question, code generation) relevant to code generation.
|
|
"""
|
|
|
|
from typing import Dict, TypedDict, List


class GraphState(TypedDict):
    """State shared by the nodes of the code-generation graph.

    Keys:
        error: "yes" when the latest code check failed, "no" when it passed.
        messages: The conversation: the user question, model answers, check
            failures and any reflections, stored as (role, text) pairs.
        generation: The latest solution produced by ``generate``. Despite the
            ``str`` annotation, this holds the structured solution object.
        iterations: Number of generation attempts made so far.
    """

    error: str  # "yes" / "no" flag written by code_check
    messages: List  # (role, text) pairs fed back to the LLM
    generation: str  # latest solution object (see docstring)
    iterations: int  # generation attempts so far
"""## Graph
|
|
|
|
Our graph lays out the logical flow shown in the figure above.
|
|
|
|
|
|
"""
|
|
|
|
from operator import itemgetter
from langchain.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnablePassthrough

### Parameters

# Generation budget: decide_to_finish ends the run once `iterations`
# reaches this value.
max_iterations = 3

# Re-try strategy. Set to 'reflect' to send failures through the reflect node
# before re-generating; any other value goes straight back to generate.
flag = 'do not reflect'
def generate(state: GraphState):
    """
    Generate a code solution

    Args:
        state (dict): The current graph state. Reads ``messages`` and
            ``iterations``, and ``error`` if present.

    Returns:
        state (dict): ``generation`` (the new solution), ``messages`` (history
        plus the model's answer) and ``iterations`` (incremented by one).
    """
    print("---GENERATING CODE SOLUTION---")

    # State
    messages = state["messages"]
    iterations = state["iterations"]
    # `error` is only written by code_check, so it may be missing on the first
    # pass unless the caller seeds it; `state["error"]` raised KeyError there.
    error = state.get("error")

    # We have been routed back to generation with an error
    if error == "yes":
        # NOTE(review): code_check has just appended a "user" error message, so
        # this adds a second consecutive "user" turn. The TODO next to the
        # Claude chain setup reports Anthropic rejecting that; consider merging.
        messages += [("user", "Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:")]

    # Solution
    code_solution = code_gen_chain.invoke({"context": concatenated_content, "messages": messages})
    messages += [("assistant", f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}")]

    # Increment
    iterations = iterations + 1
    return {"generation": code_solution, "messages": messages, "iterations": iterations}
def code_check(state: GraphState):
    """
    Check code

    Executes the solution's imports on their own, then imports + code together,
    each in a fresh namespace. The first failure appends a ("user", ...) message
    quoting the exception and sets ``error`` to "yes". If both steps run
    cleanly, ``error`` is "no".

    SECURITY NOTE(review): this exec()s LLM-generated code in-process with this
    process's full privileges. Only run it where that is acceptable, e.g.
    inside a disposable sandbox or container.

    Args:
        state (dict): The current graph state; reads ``messages``,
            ``generation`` and ``iterations``.

    Returns:
        state (dict): ``generation``, ``messages`` (possibly extended) and
        ``iterations`` passed through, plus the ``error`` flag ("yes"/"no").
    """
    print("---CHECKING CODE---")

    # State
    messages = state["messages"]
    code_solution = state["generation"]
    iterations = state["iterations"]

    # Get solution components
    imports = code_solution.imports
    code = code_solution.code

    def fresh_namespace():
        # A new globals dict per exec, so generated code cannot read or
        # overwrite this module's globals (exec's default scope). __name__ is
        # kept so any `if __name__ == ...` guard behaves as it did before.
        return {"__name__": __name__}

    # Check imports
    try:
        exec(imports, fresh_namespace())
    except Exception as e:
        print("---CODE IMPORT CHECK: FAILED---")
        error_message = [("user", f"Your solution failed the import test: {e}")]
        messages += error_message
        return {"generation": code_solution, "messages": messages, "iterations": iterations, "error": "yes"}

    # Check execution
    try:
        exec(imports + "\n" + code, fresh_namespace())
    except Exception as e:
        print("---CODE BLOCK CHECK: FAILED---")
        error_message = [("user", f"Your solution failed the code execution test: {e}")]
        messages += error_message
        return {"generation": code_solution, "messages": messages, "iterations": iterations, "error": "yes"}

    # No errors
    print("---NO CODE TEST FAILURES---")
    return {"generation": code_solution, "messages": messages, "iterations": iterations, "error": "no"}
def reflect(state: GraphState):
    """
    Reflect on errors

    Asks the model to reflect on the latest failure and appends its answer to
    the message history, so the next ``generate`` call can take it into account.

    Args:
        state (dict): The current graph state; reads ``messages``,
            ``iterations`` and ``generation``.

    Returns:
        state (dict): ``messages`` extended with the reflection prompt and the
        model's reflection. ``generation`` and ``iterations`` pass through
        unchanged.
    """
    print("---REFLECTING ON ERRORS---")

    # State
    messages = state["messages"]
    iterations = state["iterations"]
    code_solution = state["generation"]

    # Prompt reflection
    reflection_prompt = """You tried to solve this problem and failed a unit test. Reflect on this failure
given the provided documentation. Write a few key suggestions based on the
documentation to avoid making this mistake again."""

    # The prompt must actually reach the model. Previously it was built but
    # never added, so the model was simply asked to generate code again.
    # code_check normally leaves a trailing "user" error message; fold the
    # prompt into it instead of adding a second consecutive "user" turn, which
    # the TODO next to the Claude chain setup reports Anthropic rejecting.
    last = messages[-1] if messages else None
    if isinstance(last, tuple) and len(last) == 2 and last[0] == "user":
        messages[-1] = ("user", f"{last[1]}\n\n{reflection_prompt}")
    else:
        messages += [("user", reflection_prompt)]

    # Add reflection.
    # NOTE(review): code_gen_chain is bound to the `code` schema, so the
    # reflection comes back as a structured code object. A plain-text chain
    # may suit reflection better.
    reflections = code_gen_chain.invoke({"context": concatenated_content, "messages": messages})
    messages += [("assistant", f"Here are reflections on the error: {reflections}")]
    return {"generation": code_solution, "messages": messages, "iterations": iterations}
### Edges


def decide_to_finish(state: GraphState):
    """
    Determines whether to finish.

    Args:
        state (dict): The current graph state; reads ``error`` and
            ``iterations``.

    Returns:
        str: Next node to call. Returns "end" when the last check passed or the
        iteration budget (``max_iterations``) is spent. Otherwise returns
        "reflect" if the module-level ``flag`` is 'reflect', else "generate".
    """
    error = state["error"]
    iterations = state["iterations"]

    # `>=` instead of `==`: never keep looping once the budget is reached,
    # even if `iterations` starts or ends up above `max_iterations`.
    if error == "no" or iterations >= max_iterations:
        print("---DECISION: FINISH---")
        return "end"

    print("---DECISION: RE-TRY SOLUTION---")
    return "reflect" if flag == 'reflect' else "generate"
from langgraph.graph import END, StateGraph

workflow = StateGraph(GraphState)

# Define the nodes
workflow.add_node("generate", generate)  # generation solution
workflow.add_node("check_code", code_check)  # check code
workflow.add_node("reflect", reflect)  # reflect

# Build graph: generate -> check_code -> (end | reflect | generate),
# and reflect -> generate.
workflow.set_entry_point("generate")
workflow.add_edge("generate", "check_code")
workflow.add_conditional_edges(
    "check_code",
    decide_to_finish,
    {
        "end": END,
        "reflect": "reflect",
        "generate": "generate",
    },
)
workflow.add_edge("reflect", "generate")
app = workflow.compile()

question = "Draw a red circle"
# Seed every key the first generate() call reads. `error` is otherwise only
# written by code_check, so it would be missing on the first pass.
solution = app.invoke({"messages": [("user", question)], "iterations": 0, "error": ""})
# Show the final solution; a bare expression would only display in a notebook.
print(solution["generation"])