{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "w1RJ9v2tJXPc" }, "source": [ "# Generative Manim LangGraph Implementation\n", "\n", "Taking the example of [Code generation with flow](https://github.com/langchain-ai/langgraph/blob/main/examples/code_assistant/langgraph_code_assistant.ipynb?ref=blog.langchain.dev), we will implement a similar approach to generate code for Manim animations. So far, I think we would not need test validation, we can delay this step for later." ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "rzT1zS_AVMvv", "outputId": "c02b7e9c-7e99-4f7b-8a1a-032687fc89e2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Collecting langchain_community\n", " Downloading langchain_community-0.0.34-py3-none-any.whl (1.9 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.9/1.9 MB\u001b[0m \u001b[31m7.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting langchain-openai\n", " Downloading langchain_openai-0.1.3-py3-none-any.whl (33 kB)\n", "Collecting langchain-anthropic\n", " Downloading langchain_anthropic-0.1.11-py3-none-any.whl (16 kB)\n", "Collecting langchain\n", " Downloading langchain-0.1.16-py3-none-any.whl (817 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m817.7/817.7 kB\u001b[0m \u001b[31m6.1 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting langgraph\n", " Downloading langgraph-0.0.38-py3-none-any.whl (59 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m59.3/59.3 kB\u001b[0m \u001b[31m5.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting bs4\n", " Downloading bs4-0.0.2-py2.py3-none-any.whl (1.2 kB)\n", "Requirement already satisfied: PyYAML>=5.3 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (6.0.1)\n", "Requirement already 
satisfied: SQLAlchemy<3,>=1.4 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (2.0.29)\n", "Requirement already satisfied: aiohttp<4.0.0,>=3.8.3 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (3.9.5)\n", "Collecting dataclasses-json<0.7,>=0.5.7 (from langchain_community)\n", " Downloading dataclasses_json-0.6.4-py3-none-any.whl (28 kB)\n", "Collecting langchain-core<0.2.0,>=0.1.45 (from langchain_community)\n", " Downloading langchain_core-0.1.45-py3-none-any.whl (291 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m291.3/291.3 kB\u001b[0m \u001b[31m10.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting langsmith<0.2.0,>=0.1.0 (from langchain_community)\n", " Downloading langsmith-0.1.49-py3-none-any.whl (115 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m115.2/115.2 kB\u001b[0m \u001b[31m6.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: numpy<2,>=1 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (1.25.2)\n", "Requirement already satisfied: requests<3,>=2 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (2.31.0)\n", "Requirement already satisfied: tenacity<9.0.0,>=8.1.0 in /usr/local/lib/python3.10/dist-packages (from langchain_community) (8.2.3)\n", "Collecting openai<2.0.0,>=1.10.0 (from langchain-openai)\n", " Downloading openai-1.23.2-py3-none-any.whl (311 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m311.2/311.2 kB\u001b[0m \u001b[31m12.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting tiktoken<1,>=0.5.2 (from langchain-openai)\n", " Downloading tiktoken-0.6.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.8 MB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m1.8/1.8 MB\u001b[0m \u001b[31m18.5 MB/s\u001b[0m 
eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting anthropic<1,>=0.23.0 (from langchain-anthropic)\n", " Downloading anthropic-0.25.6-py3-none-any.whl (870 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m870.7/870.7 kB\u001b[0m \u001b[31m21.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: defusedxml<0.8.0,>=0.7.1 in /usr/local/lib/python3.10/dist-packages (from langchain-anthropic) (0.7.1)\n", "Requirement already satisfied: async-timeout<5.0.0,>=4.0.0 in /usr/local/lib/python3.10/dist-packages (from langchain) (4.0.3)\n", "Collecting jsonpatch<2.0,>=1.33 (from langchain)\n", " Downloading jsonpatch-1.33-py2.py3-none-any.whl (12 kB)\n", "Collecting langchain-text-splitters<0.1,>=0.0.1 (from langchain)\n", " Downloading langchain_text_splitters-0.0.1-py3-none-any.whl (21 kB)\n", "Requirement already satisfied: pydantic<3,>=1 in /usr/local/lib/python3.10/dist-packages (from langchain) (2.7.0)\n", "Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.10/dist-packages (from bs4) (4.12.3)\n", "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain_community) (1.3.1)\n", "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain_community) (23.2.0)\n", "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain_community) (1.4.1)\n", "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain_community) (6.0.5)\n", "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp<4.0.0,>=3.8.3->langchain_community) (1.9.4)\n", "Requirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.10/dist-packages (from 
anthropic<1,>=0.23.0->langchain-anthropic) (3.7.1)\n", "Requirement already satisfied: distro<2,>=1.7.0 in /usr/lib/python3/dist-packages (from anthropic<1,>=0.23.0->langchain-anthropic) (1.7.0)\n", "Collecting httpx<1,>=0.23.0 (from anthropic<1,>=0.23.0->langchain-anthropic)\n", " Downloading httpx-0.27.0-py3-none-any.whl (75 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m75.6/75.6 kB\u001b[0m \u001b[31m5.0 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from anthropic<1,>=0.23.0->langchain-anthropic) (1.3.1)\n", "Requirement already satisfied: tokenizers>=0.13.0 in /usr/local/lib/python3.10/dist-packages (from anthropic<1,>=0.23.0->langchain-anthropic) (0.15.2)\n", "Requirement already satisfied: typing-extensions<5,>=4.7 in /usr/local/lib/python3.10/dist-packages (from anthropic<1,>=0.23.0->langchain-anthropic) (4.11.0)\n", "Collecting marshmallow<4.0.0,>=3.18.0 (from dataclasses-json<0.7,>=0.5.7->langchain_community)\n", " Downloading marshmallow-3.21.1-py3-none-any.whl (49 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m49.4/49.4 kB\u001b[0m \u001b[31m4.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting typing-inspect<1,>=0.4.0 (from dataclasses-json<0.7,>=0.5.7->langchain_community)\n", " Downloading typing_inspect-0.9.0-py3-none-any.whl (8.8 kB)\n", "Collecting jsonpointer>=1.9 (from jsonpatch<2.0,>=1.33->langchain)\n", " Downloading jsonpointer-2.4-py2.py3-none-any.whl (7.8 kB)\n", "Collecting packaging<24.0,>=23.2 (from langchain-core<0.2.0,>=0.1.45->langchain_community)\n", " Downloading packaging-23.2-py3-none-any.whl (53 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m53.0/53.0 kB\u001b[0m \u001b[31m4.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting orjson<4.0.0,>=3.9.14 (from 
langsmith<0.2.0,>=0.1.0->langchain_community)\n", " Downloading orjson-3.10.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (141 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m141.1/141.1 kB\u001b[0m \u001b[31m2.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: tqdm>4 in /usr/local/lib/python3.10/dist-packages (from openai<2.0.0,>=1.10.0->langchain-openai) (4.66.2)\n", "Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1->langchain) (0.6.0)\n", "Requirement already satisfied: pydantic-core==2.18.1 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1->langchain) (2.18.1)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain_community) (3.3.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain_community) (3.7)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain_community) (2.0.7)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2->langchain_community) (2024.2.2)\n", "Requirement already satisfied: greenlet!=0.4.17 in /usr/local/lib/python3.10/dist-packages (from SQLAlchemy<3,>=1.4->langchain_community) (3.0.3)\n", "Requirement already satisfied: regex>=2022.1.18 in /usr/local/lib/python3.10/dist-packages (from tiktoken<1,>=0.5.2->langchain-openai) (2023.12.25)\n", "Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.10/dist-packages (from beautifulsoup4->bs4) (2.5)\n", "Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->anthropic<1,>=0.23.0->langchain-anthropic) (1.2.0)\n", "Collecting httpcore==1.* (from 
httpx<1,>=0.23.0->anthropic<1,>=0.23.0->langchain-anthropic)\n", " Downloading httpcore-1.0.5-py3-none-any.whl (77 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m77.9/77.9 kB\u001b[0m \u001b[31m5.2 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hCollecting h11<0.15,>=0.13 (from httpcore==1.*->httpx<1,>=0.23.0->anthropic<1,>=0.23.0->langchain-anthropic)\n", " Downloading h11-0.14.0-py3-none-any.whl (58 kB)\n", "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m58.3/58.3 kB\u001b[0m \u001b[31m2.3 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", "\u001b[?25hRequirement already satisfied: huggingface_hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from tokenizers>=0.13.0->anthropic<1,>=0.23.0->langchain-anthropic) (0.20.3)\n", "Collecting mypy-extensions>=0.3.0 (from typing-inspect<1,>=0.4.0->dataclasses-json<0.7,>=0.5.7->langchain_community)\n", " Downloading mypy_extensions-1.0.0-py3-none-any.whl (4.7 kB)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface_hub<1.0,>=0.16.4->tokenizers>=0.13.0->anthropic<1,>=0.23.0->langchain-anthropic) (3.13.4)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub<1.0,>=0.16.4->tokenizers>=0.13.0->anthropic<1,>=0.23.0->langchain-anthropic) (2023.6.0)\n", "Installing collected packages: packaging, orjson, mypy-extensions, jsonpointer, h11, typing-inspect, tiktoken, marshmallow, jsonpatch, httpcore, bs4, langsmith, httpx, dataclasses-json, openai, langchain-core, anthropic, langgraph, langchain-text-splitters, langchain-openai, langchain_community, langchain-anthropic, langchain\n", " Attempting uninstall: packaging\n", " Found existing installation: packaging 24.0\n", " Uninstalling packaging-24.0:\n", " Successfully uninstalled packaging-24.0\n", "Successfully installed anthropic-0.25.6 bs4-0.0.2 dataclasses-json-0.6.4 
h11-0.14.0 httpcore-1.0.5 httpx-0.27.0 jsonpatch-1.33 jsonpointer-2.4 langchain-0.1.16 langchain-anthropic-0.1.11 langchain-core-0.1.45 langchain-openai-0.1.3 langchain-text-splitters-0.0.1 langchain_community-0.0.34 langgraph-0.0.38 langsmith-0.1.49 marshmallow-3.21.1 mypy-extensions-1.0.0 openai-1.23.2 orjson-3.10.1 packaging-23.2 tiktoken-0.6.0 typing-inspect-0.9.0\n" ] } ], "source": [ "! pip install -U langchain_community langchain-openai langchain-anthropic langchain langgraph bs4" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "28d29y8r4K0o", "outputId": "927f2b8b-75eb-44c7-ba9b-2c20a5fd84b3" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: anthropic in /usr/local/lib/python3.10/dist-packages (0.25.6)\n", "Requirement already satisfied: anyio<5,>=3.5.0 in /usr/local/lib/python3.10/dist-packages (from anthropic) (3.7.1)\n", "Requirement already satisfied: distro<2,>=1.7.0 in /usr/lib/python3/dist-packages (from anthropic) (1.7.0)\n", "Requirement already satisfied: httpx<1,>=0.23.0 in /usr/local/lib/python3.10/dist-packages (from anthropic) (0.27.0)\n", "Requirement already satisfied: pydantic<3,>=1.9.0 in /usr/local/lib/python3.10/dist-packages (from anthropic) (2.7.0)\n", "Requirement already satisfied: sniffio in /usr/local/lib/python3.10/dist-packages (from anthropic) (1.3.1)\n", "Requirement already satisfied: tokenizers>=0.13.0 in /usr/local/lib/python3.10/dist-packages (from anthropic) (0.15.2)\n", "Requirement already satisfied: typing-extensions<5,>=4.7 in /usr/local/lib/python3.10/dist-packages (from anthropic) (4.11.0)\n", "Requirement already satisfied: idna>=2.8 in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->anthropic) (3.7)\n", "Requirement already satisfied: exceptiongroup in /usr/local/lib/python3.10/dist-packages (from anyio<5,>=3.5.0->anthropic) (1.2.0)\n", "Requirement already satisfied: 
certifi in /usr/local/lib/python3.10/dist-packages (from httpx<1,>=0.23.0->anthropic) (2024.2.2)\n", "Requirement already satisfied: httpcore==1.* in /usr/local/lib/python3.10/dist-packages (from httpx<1,>=0.23.0->anthropic) (1.0.5)\n", "Requirement already satisfied: h11<0.15,>=0.13 in /usr/local/lib/python3.10/dist-packages (from httpcore==1.*->httpx<1,>=0.23.0->anthropic) (0.14.0)\n", "Requirement already satisfied: annotated-types>=0.4.0 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->anthropic) (0.6.0)\n", "Requirement already satisfied: pydantic-core==2.18.1 in /usr/local/lib/python3.10/dist-packages (from pydantic<3,>=1.9.0->anthropic) (2.18.1)\n", "Requirement already satisfied: huggingface_hub<1.0,>=0.16.4 in /usr/local/lib/python3.10/dist-packages (from tokenizers>=0.13.0->anthropic) (0.20.3)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from huggingface_hub<1.0,>=0.16.4->tokenizers>=0.13.0->anthropic) (3.13.4)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub<1.0,>=0.16.4->tokenizers>=0.13.0->anthropic) (2023.6.0)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from huggingface_hub<1.0,>=0.16.4->tokenizers>=0.13.0->anthropic) (2.31.0)\n", "Requirement already satisfied: tqdm>=4.42.1 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub<1.0,>=0.16.4->tokenizers>=0.13.0->anthropic) (4.66.2)\n", "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub<1.0,>=0.16.4->tokenizers>=0.13.0->anthropic) (6.0.1)\n", "Requirement already satisfied: packaging>=20.9 in /usr/local/lib/python3.10/dist-packages (from huggingface_hub<1.0,>=0.16.4->tokenizers>=0.13.0->anthropic) (23.2)\n", "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from 
requests->huggingface_hub<1.0,>=0.16.4->tokenizers>=0.13.0->anthropic) (3.3.2)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->huggingface_hub<1.0,>=0.16.4->tokenizers>=0.13.0->anthropic) (2.0.7)\n" ] } ], "source": [ "!pip install anthropic" ] }, { "cell_type": "markdown", "metadata": { "id": "-ujfHFxjZlt3" }, "source": [ "## Extracting examples from Manim docs" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "GkOreR2TJO2l" }, "outputs": [], "source": [ "from bs4 import BeautifulSoup as Soup\n", "from langchain_community.document_loaders.recursive_url_loader import RecursiveUrlLoader\n", "\n", "# Manim Examples docs\n", "url = \"https://docs.manim.community/en/stable/examples.html\"\n", "loader = RecursiveUrlLoader(\n", " url=url, max_depth=20, extractor=lambda x: Soup(x, \"html.parser\").text\n", ")\n", "docs = loader.load()\n", "\n", "# Sort the list based on the URLs and get the text\n", "d_sorted = sorted(docs, key=lambda x: x.metadata[\"source\"])\n", "d_reversed = list(reversed(d_sorted))\n", "concatenated_content = \"\\n\\n\\n --- \\n\\n\\n\".join(\n", " [doc.page_content for doc in d_reversed]\n", ")" ] }, { "cell_type": "markdown", "metadata": { "id": "odSEgc2CZp8r" }, "source": [ "## LLM Solution" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "DwM4uTgHZsiT" }, "outputs": [], "source": [ "from langchain_openai import ChatOpenAI\n", "from langchain_core.prompts import ChatPromptTemplate\n", "from langchain_core.pydantic_v1 import BaseModel, Field\n", "\n", "### OpenAI\n", "\n", "# Manim code generation prompt\n", "code_gen_prompt = ChatPromptTemplate.from_messages(\n", " [(\"system\",\"\"\"You are a coding assistant with expertise in Manim, Graphical Animation Library. \\n\n", " Here is a full set of Manim documentation: \\n ------- \\n {context} \\n ------- \\n Answer the user\n", " question based on the above provided documentation. 
Ensure any code you provide can be executed \\n\n", " with all required imports and variables defined. Structure your answer with a description of the code solution. \\n\n", " Then list the imports. And finally list the functioning code block. Here is the user question:\"\"\"),\n", " (\"placeholder\", \"{messages}\")]\n", ")" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "5cYxsBA8Z__F" }, "outputs": [], "source": [ "class code(BaseModel):\n", " \"\"\"Code output\"\"\"\n", " prefix: str = Field(description=\"Description of the problem and approach\")\n", " imports: str = Field(description=\"Code block import statements\")\n", " code: str = Field(description=\"Code block not including import statements\")\n", " description = \"Schema for code solutions to requests on code for Manim Animations.\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "1nlIIoDCaNJA" }, "outputs": [], "source": [ "from google.colab import userdata\n", "import anthropic\n", "\n", "OPENAI_API_KEY = userdata.get('OPENAI_API_KEY')\n", "anthropic.api_key = userdata.get('ANTHROPIC_API_KEY')\n", "\n", "expt_llm = \"gpt-4-0125-preview\"\n", "llm = ChatOpenAI(temperature=0, model=expt_llm, openai_api_key=OPENAI_API_KEY)\n", "code_gen_chain = code_gen_prompt | llm.with_structured_output(code)" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "EBS6nNJ4bUSe" }, "outputs": [], "source": [ "question = \"Draw three red circles\"" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ewQImmNObYhB", "outputId": "66159a2f-91d7-434d-965c-eeea09644c24" }, "outputs": [ { "data": { "text/plain": [ "code(prefix='Draw three red circles using Manim', imports='from manim import *', code='class ThreeRedCircles(Scene):\\n def construct(self):\\n # Create three red circles\\n circle1 = Circle(color=RED).shift(LEFT)\\n circle2 = Circle(color=RED)\\n circle3 = 
Circle(color=RED).shift(RIGHT)\\n\\n # Add circles to the scene\\n self.add(circle1, circle2, circle3)\\n\\n# To render the scene, uncomment the following line and run the script\\n# scene = ThreeRedCircles()\\n# scene.render()', description='Schema for code solutions to requests on code for Manim Animations.')" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } ], "source": [ "solution = code_gen_chain.invoke({\"context\":concatenated_content,\"messages\":[(\"user\",question)]})\n", "\n", "solution" ] }, { "cell_type": "markdown", "metadata": { "id": "umRuYuWczxoe" }, "source": [ "Now let's try with Anthropic" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8pIusQp6fmUF" }, "outputs": [], "source": [ "from langchain_anthropic import ChatAnthropic\n", "from langchain_core.prompts import ChatPromptTemplate\n", "from langchain_core.pydantic_v1 import BaseModel, Field" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "jkb4MB04z3j1" }, "outputs": [], "source": [ "code_gen_prompt_claude = ChatPromptTemplate.from_messages(\n", " [(\"system\",\"\"\" You are a coding assistant with expertise in Manim, Graphical Animation Library. \\n\n", " Here is a full set of Manim documentation: \\n ------- \\n {context} \\n ------- \\n Answer the user request based on the \\n\n", " above provided documentation. Ensure any code you provide can be executed with all required imports and variables \\n\n", " defined. Structure your answer: 1) a prefix describing the code solution, 2) the imports, 3) the functioning code block. \\n\n", " Invoke the code tool from Manim to structure the output correctly. 
\\n Here is the user request:\"\"\",),\n",
"    (\"placeholder\", \"{messages}\"),])" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "aCA4IZuE0EmP", "outputId": "f7fe31fd-511d-4576-ca64-f3b5f5006b89" }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/usr/local/lib/python3.10/dist-packages/langchain_core/_api/beta_decorator.py:87: LangChainBetaWarning: The function `with_structured_output` is in beta. It is actively being worked on, so the API may change.\n", "  warn_beta(\n" ] } ], "source": [
"# Data model\n",
"class code(BaseModel):\n",
"    \"\"\"Code output\"\"\"\n",
"\n",
"    prefix: str = Field(description=\"Description of the problem and approach\")\n",
"    imports: str = Field(description=\"Code block import statements\")\n",
"    code: str = Field(description=\"Code block not including import statements\")\n",
"    description = \"Schema for code solutions to questions about Manim.\"\n",
"\n",
"# LLM\n",
"# expt_llm = \"claude-3-haiku-20240307\"\n",
"expt_llm = \"claude-3-opus-20240229\"\n",
"# Pass the key explicitly: ChatAnthropic does not read the module-level `anthropic.api_key`\n",
"# assigned earlier; without this argument it falls back to the ANTHROPIC_API_KEY env variable.\n",
"llm = ChatAnthropic(\n",
"    model=expt_llm,\n",
"    anthropic_api_key=userdata.get('ANTHROPIC_API_KEY'),\n",
"    default_headers={\"anthropic-beta\": \"tools-2024-04-04\"},\n",
")\n",
"\n",
"structured_llm_claude = llm.with_structured_output(code, include_raw=True)" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "DYr8DwNc0bQG" }, "outputs": [], "source": [
"# Optional: Check for errors in case tool use is flaky\n",
"def check_claude_output(tool_output):\n",
"    \"\"\"Check for parse error or failure to call the tool.\n",
"\n",
"    Args:\n",
"        tool_output (dict): Output of `structured_llm_claude` (built with `include_raw=True`);\n",
"            its 'raw', 'parsed' and 'parsing_error' keys are read here.\n",
"\n",
"    Returns:\n",
"        dict: `tool_output` unchanged, when it parsed and the tool was invoked.\n",
"\n",
"    Raises:\n",
"        ValueError: If the output could not be parsed or the tool was not invoked.\n",
"    \"\"\"\n",
"\n",
"    # Error with parsing\n",
"    if tool_output[\"parsing_error\"]:\n",
"        # Report back output and parsing errors\n",
"        print(\"Parsing error!\")\n",
"        # Read the raw message from `tool_output` (the old `code_output` name was undefined -> NameError)\n",
"        raw_output = str(tool_output[\"raw\"].content)\n",
"        error = tool_output[\"parsing_error\"]\n",
"        raise ValueError(\n",
"            f\"Error parsing your output! Be sure to invoke the tool. Output: {raw_output}. \\n Parse error: {error}\"\n",
"        )\n",
"\n",
"    # Tool was not invoked\n",
"    elif not tool_output[\"parsed\"]:\n",
"        print(\"Failed to invoke tool!\")\n",
"        raise ValueError(\n",
"            f\"You did not use the provided tool! Be sure to invoke the tool to structure the output.\"\n",
"        )\n",
"    return tool_output" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "Pvf-8bvd2z7x" }, "outputs": [], "source": [ "code_chain_claude_raw = code_gen_prompt_claude | structured_llm_claude | check_claude_output" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "zzd6Y1dc32KQ" }, "outputs": [], "source": [
"def insert_errors(inputs):\n",
"    \"\"\"Insert errors for tool parsing in the messages.\n",
"\n",
"    Args:\n",
"        inputs (dict): Chain input with 'messages', 'context' and the caught exception under 'error'\n",
"            (the `exception_key` used by the fallback chain).\n",
"\n",
"    Returns:\n",
"        dict: A new 'messages' list with the retry instruction appended, plus the original 'context'.\n",
"    \"\"\"\n",
"\n",
"    # Get errors\n",
"    error = inputs[\"error\"]\n",
"    # Build a new list rather than `+=`: in-place mutation would alter the caller's message list,\n",
"    # which is shared across fallback attempts, so retry messages would pile up.\n",
"    messages = inputs[\"messages\"] + [\n",
"        (\n",
"            \"assistant\",\n",
"            f\"Retry. You are required to fix the parsing errors: {error} \\n\\n You must invoke the provided tool.\",\n",
"        )\n",
"    ]\n",
"    return {\n",
"        \"messages\": messages,\n",
"        \"context\": inputs[\"context\"],\n",
"    }" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "pueLK_Kq34c5" }, "outputs": [], "source": [
"\n",
"# This will be run as a fallback chain\n",
"fallback_chain = insert_errors | code_chain_claude_raw\n",
"N = 3 # Max re-tries\n",
"code_gen_chain_re_try = code_chain_claude_raw.with_fallbacks(fallbacks=[fallback_chain] * N, exception_key=\"error\")\n",
"\n",
"def parse_output(solution):\n",
"    \"\"\"When we add 'include_raw=True' to structured output,\n",
"    it will return a dict w 'raw', 'parsed', 'parsing_error'. 
\"\"\"\n", "\n", " return solution['parsed']\n", "\n", "# Wtih re-try to correct for failure to invoke tool\n", "# TODO: Annoying errors w/ \"user\" vs \"assistant\"\n", "# Roles must alternate between \"user\" and \"assistant\", but found multiple \"user\" roles in a row\n", "code_gen_chain = code_gen_chain_re_try | parse_output\n", "\n", "# No re-try\n", "code_gen_chain = code_gen_prompt_claude | structured_llm_claude | parse_output" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "vtOZfTjx38Af", "outputId": "d639d87b-6468-4b65-e27f-7aff429ec000" }, "outputs": [ { "data": { "text/plain": [ "code(prefix='Draw a red circle using Manim', imports='from manim import *', code='class RedCircle(Scene):\\n def construct(self):\\n circle = Circle(radius=1, color=RED)\\n self.add(circle)\\n self.wait(2)\\n\\n# To run this scene, use the following command in your terminal:\\n# manim -p -ql .py RedCircle', description='Schema for code solutions to questions about Manim.')" ] }, "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ "# Test\n", "question = \"Draw a red circle\"\n", "solution = code_gen_chain.invoke({\"context\":concatenated_content,\"messages\":[(\"user\",question)]})\n", "solution" ] }, { "cell_type": "markdown", "metadata": { "id": "MOQQ03Qg-eni" }, "source": [ "## State\n", "\n", "Our state is a dict that will contain keys (errors, question, code generation) relevant to code generation." 
] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "UuOnK7Cv-gww" }, "outputs": [], "source": [
"from typing import Dict, TypedDict, List\n",
"\n",
"class GraphState(TypedDict):\n",
"    \"\"\"\n",
"    Represents the state of our graph.\n",
"\n",
"    Attributes:\n",
"        error : 'yes' / 'no' flag set by code_check; decides between retrying and finishing\n",
"        messages : Chat history with the user question, generated solutions, error messages and reflections\n",
"        generation : Latest code solution (an instance of the `code` schema, despite the `str` annotation)\n",
"        iterations : Number of generation attempts so far\n",
"    \"\"\"\n",
"\n",
"    error : str\n",
"    messages : List\n",
"    generation : str\n",
"    iterations : int" ] },
{ "cell_type": "markdown", "metadata": { "id": "FkK7aYTjC9_I" }, "source": [ "## Graph\n", "\n", "Our graph implements the following loop: `generate` produces a solution, `check_code` executes its imports and code, and `decide_to_finish` either ends the run (on success, or once `max_iterations` attempts are used) or routes back to `generate`, optionally through `reflect` first.\n", "\n" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "7gU5YRQpDC_K" }, "outputs": [], "source": [
"from operator import itemgetter\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain_core.pydantic_v1 import BaseModel, Field\n",
"from langchain_core.runnables import RunnablePassthrough\n",
"\n",
"### Parameter\n",
"\n",
"# Max tries\n",
"max_iterations = 3\n",
"# Reflect\n",
"# flag = 'reflect'\n",
"flag = 'do not reflect'\n",
"\n",
"### Nodes\n",
"\n",
"def generate(state: GraphState):\n",
"    \"\"\"\n",
"    Generate a code solution\n",
"\n",
"    Args:\n",
"        state (dict): The current graph state\n",
"\n",
"    Returns:\n",
"        state (dict): Updated generation, messages (solution appended) and iterations (+1)\n",
"    \"\"\"\n",
"\n",
"    print(\"---GENERATING CODE SOLUTION---\")\n",
"\n",
"    # State\n",
"    messages = state[\"messages\"]\n",
"    iterations = state[\"iterations\"]\n",
"    # `.get`: the initial invocation only supplies messages and iterations\n",
"    error = state.get(\"error\")\n",
"\n",
"    # We have been routed back to generation with an error\n",
"    if error == \"yes\":\n",
"        messages += [(\"user\",\"Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:\")]\n",
"\n",
"    # Solution\n",
"    code_solution = code_gen_chain.invoke({\"context\": concatenated_content, \"messages\" : messages})\n",
"    messages += [(\"assistant\",f\"{code_solution.prefix} \\n Imports: {code_solution.imports} \\n Code: {code_solution.code}\")]\n",
"\n",
"    # Increment\n",
"    iterations = iterations + 1\n",
"    return {\"generation\": code_solution, \"messages\": messages, \"iterations\": iterations}\n",
"\n",
"def code_check(state: GraphState):\n",
"    \"\"\"\n",
"    Check code by executing the generated imports, then imports + code.\n",
"\n",
"    Args:\n",
"        state (dict): The current graph state\n",
"\n",
"    Returns:\n",
"        state (dict): generation, messages (error feedback appended on failure), iterations,\n",
"            and error set to 'yes' or 'no'\n",
"    \"\"\"\n",
"\n",
"    print(\"---CHECKING CODE---\")\n",
"\n",
"    # State\n",
"    messages = state[\"messages\"]\n",
"    code_solution = state[\"generation\"]\n",
"    iterations = state[\"iterations\"]\n",
"\n",
"    # Get solution components\n",
"    imports = code_solution.imports\n",
"    code_block = code_solution.code\n",
"\n",
"    # NOTE(review): `exec` runs model-generated code in this kernel without sandboxing.\n",
"    # Each check gets a fresh namespace so generated code cannot read or clobber notebook state.\n",
"    # `manim` must be installed in this kernel, otherwise every attempt fails the import test.\n",
"\n",
"    # Check imports\n",
"    try:\n",
"        exec(imports, {})\n",
"    except Exception as e:\n",
"        print(\"---CODE IMPORT CHECK: FAILED---\")\n",
"        error_message = [(\"user\", f\"Your solution failed the import test: {e}\")]\n",
"        messages += error_message\n",
"        return {\"generation\": code_solution, \"messages\": messages, \"iterations\": iterations, \"error\": \"yes\"}\n",
"\n",
"    # Check execution\n",
"    try:\n",
"        exec(imports + \"\\n\" + code_block, {})\n",
"    except Exception as e:\n",
"        print(\"---CODE BLOCK CHECK: FAILED---\")\n",
"        error_message = [(\"user\", f\"Your solution failed the code execution test: {e}\")]\n",
"        messages += error_message\n",
"        return {\"generation\": code_solution, \"messages\": messages, \"iterations\": iterations, \"error\": \"yes\"}\n",
"\n",
"    # No errors\n",
"    print(\"---NO CODE TEST FAILURES---\")\n",
"    return {\"generation\": code_solution, \"messages\": messages, \"iterations\": iterations, \"error\": \"no\"}\n",
"\n",
"def reflect(state: GraphState):\n",
"    \"\"\"\n",
"    Reflect on errors\n",
"\n",
"    Sends the history plus a reflection prompt to the model and appends its answer to\n",
"    the messages; generation and iterations are carried over unchanged.\n",
"\n",
"    Args:\n",
"        state (dict): The current graph state\n",
"\n",
"    Returns:\n",
"        state (dict): generation, messages (reflection appended) and iterations\n",
"    \"\"\"\n",
"\n",
"    print(\"---REFLECTING ON ERRORS---\")\n",
"\n",
"    # State\n",
"    messages = state[\"messages\"]\n",
"    iterations = state[\"iterations\"]\n",
"    code_solution = state[\"generation\"]\n",
"\n",
"    # Prompt reflection\n",
"    reflection_message = [(\n",
"        \"user\",\n",
"        \"You tried to solve this problem and failed a unit test. Reflect on this failure \"\n",
"        \"given the provided documentation. Write a few key suggestions based on the \"\n",
"        \"documentation to avoid making this mistake again.\",\n",
"    )]\n",
"\n",
"    # Add reflection: the reflection prompt must actually be sent to the model (it was\n",
"    # previously built but never used); only the model's answer is kept in the history.\n",
"    reflections = code_gen_chain.invoke({\"context\" : concatenated_content, \"messages\" : messages + reflection_message})\n",
"    messages += [(\"assistant\" , f\"Here are reflections on the error: {reflections}\")]\n",
"    return {\"generation\": code_solution, \"messages\": messages, \"iterations\": iterations}\n",
"\n",
"### Edges\n",
"\n",
"def decide_to_finish(state: GraphState):\n",
"    \"\"\"\n",
"    Determines whether to finish.\n",
"\n",
"    Args:\n",
"        state (dict): The current graph state\n",
"\n",
"    Returns:\n",
"        str: Next node to call: \"end\", \"reflect\" (when flag == 'reflect') or \"generate\"\n",
"    \"\"\"\n",
"    error = state[\"error\"]\n",
"    iterations = state[\"iterations\"]\n",
"\n",
"    # `>=` rather than `==` so the loop still stops if the counter ever overshoots max_iterations\n",
"    if error == \"no\" or iterations >= max_iterations:\n",
"        print(\"---DECISION: FINISH---\")\n",
"        return \"end\"\n",
"    else:\n",
"        print(\"---DECISION: RE-TRY SOLUTION---\")\n",
"        if flag == 'reflect':\n",
"            return \"reflect\"\n",
"        else:\n",
"            return \"generate\"" ] },
{ "cell_type": "code", "execution_count": null, "metadata": { "id": "Me7Tp-cSDszt" }, "outputs": [], "source": [
"from langgraph.graph import END, StateGraph\n",
"\n",
"workflow = StateGraph(GraphState)\n",
"\n",
"# Define the nodes\n",
"workflow.add_node(\"generate\", generate) # generation solution\n", 
"workflow.add_node(\"check_code\", code_check) # check code\n", "workflow.add_node(\"reflect\", reflect) # reflect\n", "\n", "# Build graph\n", "workflow.set_entry_point(\"generate\")\n", "workflow.add_edge(\"generate\", \"check_code\")\n", "workflow.add_conditional_edges(\n", " \"check_code\",\n", " decide_to_finish,\n", " {\n", " \"end\": END,\n", " \"reflect\": \"reflect\",\n", " \"generate\": \"generate\",\n", " },\n", ")\n", "workflow.add_edge(\"reflect\", \"generate\")\n", "app = workflow.compile()" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "EVhrtV03D6I1", "outputId": "51bba349-cf26-4758-d1c8-4d4b9f961501" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "---GENERATING CODE SOLUTION---\n", "---CHECKING CODE---\n", "---CODE IMPORT CHECK: FAILED---\n", "---DECISION: RE-TRY SOLUTION---\n", "---GENERATING CODE SOLUTION---\n", "---CHECKING CODE---\n", "---CODE IMPORT CHECK: FAILED---\n", "---DECISION: RE-TRY SOLUTION---\n", "---GENERATING CODE SOLUTION---\n", "---CHECKING CODE---\n", "---CODE IMPORT CHECK: FAILED---\n", "---DECISION: FINISH---\n" ] }, { "data": { "text/plain": [ "{'error': 'yes',\n", " 'messages': [('user', 'Draw a red circle'),\n", " ('assistant',\n", " 'Draw a red circle using Manim \\n Imports: from manim import * \\n Code: class RedCircle(Scene):\\n def construct(self):\\n circle = Circle(color=RED)\\n self.add(circle)\\n self.wait(2)'),\n", " ('user', \"Your solution failed the import test: No module named 'manim'\"),\n", " ('user',\n", " 'Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:'),\n", " ('assistant',\n", " 'To draw a red circle using Manim, you can use the following code snippet. This example creates a simple scene with a red circle and displays it for a short duration. 
\\n Imports: from manim import * \\n Code: class RedCircle(Scene):\\n def construct(self):\\n circle = Circle(color=RED)\\n self.add(circle)\\n self.wait(2)'),\n", " ('user', \"Your solution failed the import test: No module named 'manim'\"),\n", " ('user',\n", " 'Now, try again. Invoke the code tool to structure the output with a prefix, imports, and code block:'),\n", " ('assistant',\n", " 'To draw a red circle using Manim, follow this structured approach: \\n Imports: from manim import * \\n Code: class RedCircle(Scene):\\n def construct(self):\\n circle = Circle(color=RED)\\n self.add(circle)\\n self.wait(2)'),\n", " ('user', \"Your solution failed the import test: No module named 'manim'\")],\n", " 'generation': code(prefix='To draw a red circle using Manim, follow this structured approach:', imports='from manim import *', code='class RedCircle(Scene):\\n def construct(self):\\n circle = Circle(color=RED)\\n self.add(circle)\\n self.wait(2)', description='Schema for code solutions to questions about Manim.'),\n", " 'iterations': 3}" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "question = \"Draw a red circle\"\n", "app.invoke({\"messages\":[(\"user\",question)],\"iterations\":0})" ] }, { "cell_type": "markdown", "metadata": { "id": "-pz9B4D0El1Q" }, "source": [] } ], "metadata": { "colab": { "provenance": [] }, "kernelspec": { "display_name": "Python 3", "name": "python3" }, "language_info": { "name": "python" } }, "nbformat": 4, "nbformat_minor": 0 }