{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": []
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
}
},
"cells": [
{
"cell_type": "code",
"source": [
"class Value:\n",
"\n",
" def __init__(self, data, _children=(), _op='', label=''):\n",
" self.data = data\n",
" self.grad = 0.0\n",
" self._prev = set(_children)\n",
" self._op = _op\n",
" self.label = label\n",
"\n",
"\n",
" def __repr__(self): # This basically allows us to print nicer looking expressions for the final output\n",
" return f\"Value(data={self.data})\"\n",
"\n",
" def __add__(self, other):\n",
" out = Value(self.data + other.data, (self, other), '+')\n",
" return out\n",
"\n",
" def __mul__(self, other):\n",
" out = Value(self.data * other.data, (self, other), '*')\n",
" return out"
],
"metadata": {
"id": "jtRAdDVT6jf2"
},
"execution_count": 1,
"outputs": []
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "AIP2sPDm6Los",
"outputId": "f467edac-8c3a-4695-a651-5325a4ecea8f"
},
"outputs": [
{
"output_type": "execute_result",
"data": {
"text/plain": [
"Value(data=-8.0)"
]
},
"metadata": {},
"execution_count": 2
}
],
"source": [
"a = Value(2.0, label='a')\n",
"b = Value(-3.0, label='b')\n",
"c = Value(10.0, label='c')\n",
"e = a*b; e.label='e'\n",
"d= e + c; d.label='d'\n",
"f = Value(-2.0, label='f')\n",
"L = d*f; L.label='L'\n",
"L"
]
},
{
"cell_type": "code",
"source": [
"from graphviz import Digraph\n",
"\n",
"def trace(root):\n",
" #Builds a set of all nodes and edges in a graph\n",
" nodes, edges = set(), set()\n",
" def build(v):\n",
" if v not in nodes:\n",
" nodes.add(v)\n",
" for child in v._prev:\n",
" edges.add((child, v))\n",
" build(child)\n",
" build(root)\n",
" return nodes, edges\n",
"\n",
"def draw_dot(root):\n",
" dot = Digraph(format='svg', graph_attr={'rankdir': 'LR'}) #LR == Left to Right\n",
"\n",
" nodes, edges = trace(root)\n",
" for n in nodes:\n",
" uid = str(id(n))\n",
" #For any value in the graph, create a rectangular ('record') node for it\n",
" dot.node(name = uid, label = \"{ %s | data %.4f | grad %.4f }\" % ( n.label, n.data, n.grad), shape='record')\n",
" if n._op:\n",
" #If this value is a result of some operation, then create an op node for it\n",
" dot.node(name = uid + n._op, label=n._op)\n",
" #and connect this node to it\n",
" dot.edge(uid + n._op, uid)\n",
"\n",
" for n1, n2 in edges:\n",
" #Connect n1 to the node of n2\n",
" dot.edge(str(id(n1)), str(id(n2)) + n2._op)\n",
"\n",
" return dot"
],
"metadata": {
"id": "T0rN8d146jvF"
},
"execution_count": 3,
"outputs": []
},
{
"cell_type": "code",
"source": [
"L.grad = 1.0\n",
"f.grad = 4.0\n",
"d.grad = -2.0\n",
"c.grad = -2.0\n",
"e.grad = -2.0\n",
"a.grad = 6.0\n",
"b.grad = -4.0"
],
"metadata": {
"id": "3TCgz-n6DbzI"
},
"execution_count": 4,
"outputs": []
},
{
"cell_type": "code",
"source": [
"draw_dot(L)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/",
"height": 212
},
"id": "k7wjwrfo6nUl",
"outputId": "b915567c-ba7b-44ec-fd34-97f35d258fd4"
},
"execution_count": 5,
"outputs": [
{
"output_type": "execute_result",
"data": {
"image/svg+xml": "\n\n\n\n\n",
"text/plain": [
""
]
},
"metadata": {},
"execution_count": 5
}
]
},
{
"cell_type": "markdown",
"source": [
"--------------------"
],
"metadata": {
"id": "WqQ2p-U1eUnJ"
}
},
{
"cell_type": "markdown",
"source": [
"Now, we are going to try and nudge the leaf nodes (because those are usually what we have control over. In our example: a,b,c,f) slightly towards the gradient value, to nudge L towards a more positive direction."
],
"metadata": {
"id": "cpx9Me4LeVfx"
}
},
{
"cell_type": "code",
"source": [
"a.data += 0.01 * a.grad\n",
"b.data += 0.01 * b.grad\n",
"c.data += 0.01 * c.grad\n",
"f.data += 0.01 * f.grad\n",
"\n",
"e = a*b;\n",
"d= e + c;\n",
"L = d*f;\n",
"L\n",
"\n",
"print(L.data)"
],
"metadata": {
"colab": {
"base_uri": "https://localhost:8080/"
},
"id": "_edHadeReTBn",
"outputId": "3631e3e6-5bbf-4ceb-d1e1-45599a5a2736"
},
"execution_count": 7,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"-6.723584000000001\n"
]
}
]
},
{
"cell_type": "markdown",
"source": [
"Therefore the value of L was pushed to a more positive direction from -8.0 to -6.0"
],
"metadata": {
"id": "aXwYpQKKgYGg"
}
}
]
}