# -*- coding: utf-8 -*-
"""Generative Manim LangGraph Implementation.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1YSO9TG2fJVVH4l7yTHE_V-v8VaDfpM2s
# Generative Manim LangGraph Implementation
Following the example of [Code generation with flow](https://github.com/langchain-ai/langgraph/blob/main/examples/code_assistant/langgraph_code_assistant.ipynb?ref=blog.langchain.dev), we will implement a similar approach to generate code for Manim animations. For now, we likely do not need test validation; we can defer that step until later.
"""
"""## Extracting examples from Manim docs"""
# Load .env
from dotenv import load_dotenv
load_dotenv()
import os
from bs4 import BeautifulSoup as Soup
from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader
# Manim Examples docs
url = "https://docs.manim.community/en/stable/examples.html"
loader = RecursiveUrlLoader(
    url=url, max_depth=20, extractor=lambda x: Soup(x, "html.parser").text
)
docs = loader.load()
# Sort the list based on the URLs and get the text
d_sorted = sorted(docs, key=lambda x: x.metadata["source"])
d_reversed = list(reversed(d_sorted))
concatenated_content = "\n\n\n --- \n\n\n".join(
    [doc.page_content for doc in d_reversed]
)
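"""As a quick sanity check, inspect how much documentation the loader pulled in (the exact counts depend on the docs site at crawl time)."""
# How many pages were crawled and how large the combined context is
print(f"Loaded {len(docs)} pages")
print(f"Concatenated context length: {len(concatenated_content)} characters")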
"""## LLM Solution"""
from langchain_openai import ChatOpenAI
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
### OpenAI
# Manim code generation prompt
code_gen_prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """You are a coding assistant with expertise in Manim, the Graphical Animation Library. \n
    Here is a full set of Manim documentation: \n ------- \n {context} \n ------- \n Answer the user
    question based on the above provided documentation. Ensure any code you provide can be executed \n
    with all required imports and variables defined. Structure your answer with a description of the code solution. \n
    Then list the imports. And finally list the functioning code block. Here is the user question:""",
        ),
        ("placeholder", "{messages}"),
    ]
)
class code(BaseModel):
    """Code output"""

    prefix: str = Field(description="Description of the problem and approach")
    imports: str = Field(description="Code block import statements")
    code: str = Field(description="Code block not including import statements")
    description = "Schema for code solutions to requests on code for Manim animations."
OPENAI_API_KEY = os.getenv("OPENAI_API_KEY")
# ChatAnthropic reads ANTHROPIC_API_KEY from the environment automatically;
# no module-level assignment is needed.
expt_llm = "gpt-4-0125-preview"
llm = ChatOpenAI(temperature=0, model=expt_llm, openai_api_key=OPENAI_API_KEY)
code_gen_chain = code_gen_prompt | llm.with_structured_output(code)
question = "Draw three red circles"
solution = code_gen_chain.invoke({"context": concatenated_content, "messages": [("user", question)]})
solution
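"""The parsed solution is an instance of the `code` schema above, so its parts can be read off directly; a minimal sketch of reassembling them into one script string:"""
# Reassemble the structured fields into a single runnable script
print(solution.prefix)
full_script = solution.imports + "\n\n" + solution.code
print(full_script)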
"""Now let's try with Anthropic"""
from langchain_anthropic import ChatAnthropic
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
code_gen_prompt_claude = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            """<instructions> You are a coding assistant with expertise in Manim, the Graphical Animation Library. \n
    Here is a full set of Manim documentation: \n ------- \n {context} \n ------- \n Answer the user request based on the \n
    above provided documentation. Ensure any code you provide can be executed with all required imports and variables \n
    defined. Structure your answer: 1) a prefix describing the code solution, 2) the imports, 3) the functioning code block. \n
    Invoke the code tool to structure the output correctly. </instructions> \n Here is the user request:""",
        ),
        ("placeholder", "{messages}"),
    ]
)
# Data model
class code(BaseModel):
    """Code output"""

    prefix: str = Field(description="Description of the problem and approach")
    imports: str = Field(description="Code block import statements")
    code: str = Field(description="Code block not including import statements")
    description = "Schema for code solutions to questions about Manim."
# LLM
# expt_llm = "claude-3-haiku-20240307"
expt_llm = "claude-3-opus-20240229"
llm = ChatAnthropic(
    model=expt_llm,
    default_headers={"anthropic-beta": "tools-2024-04-04"},
)
structured_llm_claude = llm.with_structured_output(code, include_raw=True)
code_chain_claude_raw = code_gen_prompt_claude | structured_llm_claude
def insert_errors(inputs):
    """Insert errors for tool parsing in the messages"""
    # Get errors
    error = inputs["error"]
    messages = inputs["messages"]
    messages += [
        (
            "assistant",
            f"Retry. You are required to fix the parsing errors: {error} \n\n You must invoke the provided tool.",
        )
    ]
    return {
        "messages": messages,
        "context": inputs["context"],
    }
# This will be run as a fallback chain
fallback_chain = insert_errors | code_chain_claude_raw
N = 3 # Max re-tries
code_gen_chain_re_try = code_chain_claude_raw.with_fallbacks(fallbacks=[fallback_chain] * N, exception_key="error")
def parse_output(solution):
    """When we add 'include_raw=True' to structured output,
    it will return a dict w/ 'raw', 'parsed', 'parsing_error'."""
    return solution["parsed"]
# With re-try to correct for failure to invoke tool
# TODO: Annoying errors w/ "user" vs "assistant":
# roles must alternate between "user" and "assistant", but found multiple "user" roles in a row
code_gen_chain = code_gen_chain_re_try | parse_output
# No re-try (this overrides the re-try chain above)
code_gen_chain = code_gen_prompt_claude | structured_llm_claude | parse_output
# Test
question = "Draw a red circle"
solution = code_gen_chain.invoke({"context": concatenated_content, "messages": [("user", question)]})
solution
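"""To actually render the generated animation, one option is to write the script to disk and call the Manim CLI. This is a sketch: the file name `scene.py` and the regex used to find the Scene subclass are illustrative assumptions, not part of the chain."""
import re
import subprocess

# Write the generated script to disk (file name is an assumption)
script = solution.imports + "\n\n" + solution.code
with open("scene.py", "w") as f:
    f.write(script)

# Find the first Scene subclass and render it at low quality
match = re.search(r"class\s+(\w+)\s*\(\s*Scene\s*\)", script)
if match:
    subprocess.run(["manim", "-ql", "scene.py", match.group(1)], check=False)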
"""## State
Our state is a dict that will contain keys (errors, question, code generation) relevant to code generation.
"""
from typing import List, TypedDict

class GraphState(TypedDict):
    """
    Represents the state of our graph.

    Attributes:
        error : Binary flag for control flow to indicate whether test error was tripped
        messages : With user question, error messages, reasoning
        generation : Code solution
        iterations : Number of tries
    """

    error: str
    messages: List
    generation: str
    iterations: int
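# An example initial state (illustrative): "error" starts empty so the
# first pass through generate() is not treated as a retry.
example_state = {
    "error": "",
    "messages": [("user", "Draw a red circle")],
    "generation": "",
    "iterations": 0,
}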
"""## Graph
Our graph lays out the logical flow shown in the figure above.
"""
from operator import itemgetter
from langchain.prompts import PromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_core.runnables import RunnablePassthrough
### Parameter
# Max tries
max_iterations = 3
# Reflect
# flag = 'reflect'
flag = 'do not reflect'
### Nodes
def generate(state: GraphState):
    """
    Generate a code solution

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation
    """
    print("---GENERATING CODE SOLUTION---")

    # State
    messages = state["messages"]
    iterations = state["iterations"]
    error = state["error"]

    # We have been routed back to generation with an error
    if error == "yes":
        messages += [
            (
                "user",
                "Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:",
            )
        ]

    # Solution
    code_solution = code_gen_chain.invoke({"context": concatenated_content, "messages": messages})
    messages += [
        (
            "assistant",
            f"{code_solution.prefix} \n Imports: {code_solution.imports} \n Code: {code_solution.code}",
        )
    ]

    # Increment
    iterations = iterations + 1
    return {"generation": code_solution, "messages": messages, "iterations": iterations}
def code_check(state: GraphState):
    """
    Check code

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, error
    """
    print("---CHECKING CODE---")

    # State
    messages = state["messages"]
    code_solution = state["generation"]
    iterations = state["iterations"]

    # Get solution components
    prefix = code_solution.prefix
    imports = code_solution.imports
    code = code_solution.code

    # Check imports
    try:
        exec(imports)
    except Exception as e:
        print("---CODE IMPORT CHECK: FAILED---")
        error_message = [("user", f"Your solution failed the import test: {e}")]
        messages += error_message
        return {"generation": code_solution, "messages": messages, "iterations": iterations, "error": "yes"}

    # Check execution
    try:
        exec(imports + "\n" + code)
    except Exception as e:
        print("---CODE BLOCK CHECK: FAILED---")
        error_message = [("user", f"Your solution failed the code execution test: {e}")]
        messages += error_message
        return {"generation": code_solution, "messages": messages, "iterations": iterations, "error": "yes"}

    # No errors
    print("---NO CODE TEST FAILURES---")
    return {"generation": code_solution, "messages": messages, "iterations": iterations, "error": "no"}
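"""Note that `exec` runs the generated, untrusted code inside the current process. A possible hardening, sketched here as an assumption rather than part of the original flow, is to execute the snippet in a fresh interpreter with a timeout:"""
import subprocess
import sys

def run_isolated(source: str, timeout: int = 60):
    """Run `source` in a subprocess; return stderr on failure, None on success."""
    proc = subprocess.run(
        [sys.executable, "-c", source],
        capture_output=True,
        text=True,
        timeout=timeout,  # raises subprocess.TimeoutExpired if exceeded
    )
    return None if proc.returncode == 0 else proc.stderr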
def reflect(state: GraphState):
    """
    Reflect on errors

    Args:
        state (dict): The current graph state

    Returns:
        state (dict): New key added to state, generation
    """
    print("---REFLECTING ON ERRORS---")

    # State
    messages = state["messages"]
    iterations = state["iterations"]
    code_solution = state["generation"]

    # Prompt reflection
    reflection_message = [
        (
            "user",
            """You tried to solve this problem and failed a unit test. Reflect on this failure
    given the provided documentation. Write a few key suggestions based on the
    documentation to avoid making this mistake again.""",
        )
    ]
    messages += reflection_message

    # Add reflection
    reflections = code_gen_chain.invoke({"context": concatenated_content, "messages": messages})
    messages += [("assistant", f"Here are reflections on the error: {reflections}")]
    return {"generation": code_solution, "messages": messages, "iterations": iterations}
### Edges
def decide_to_finish(state: GraphState):
    """
    Determines whether to finish.

    Args:
        state (dict): The current graph state

    Returns:
        str: Next node to call
    """
    error = state["error"]
    iterations = state["iterations"]

    if error == "no" or iterations == max_iterations:
        print("---DECISION: FINISH---")
        return "end"
    else:
        print("---DECISION: RE-TRY SOLUTION---")
        if flag == "reflect":
            return "reflect"
        else:
            return "generate"
from langgraph.graph import END, StateGraph
workflow = StateGraph(GraphState)
# Define the nodes
workflow.add_node("generate", generate) # generation solution
workflow.add_node("check_code", code_check) # check code
workflow.add_node("reflect", reflect) # reflect
# Build graph
workflow.set_entry_point("generate")
workflow.add_edge("generate", "check_code")
workflow.add_conditional_edges(
    "check_code",
    decide_to_finish,
    {
        "end": END,
        "reflect": "reflect",
        "generate": "generate",
    },
)
workflow.add_edge("reflect", "generate")
app = workflow.compile()
question = "Draw a red circle"
app.invoke({"messages":[("user",question)],"iterations":0})
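"""The graph returns the final state dict, so the last parsed solution can be pulled out of the `generation` key and inspected field by field."""
final_solution = result["generation"]
print(final_solution.prefix)
print(final_solution.imports)
print(final_solution.code)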