title: pytorch 计算图理解
date: 2018-08-06 10:21:14
因为笔者在train的时候发现梯度流被阻断了,所以学习下pytorch 计算图的原理。
version 0.3
gdown.pl hosts.txt main.py sort.sh
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
| import torch from torch.autograd import Variable
x = torch.ones(1, requires_grad=True) y = torch.ones(1) z = x + y
w = torch.ones(1, requires_grad=True) total = w + z
total.backward() print(x.requires_grad, x.grad) print(y.requires_grad, y.grad) print(z.requires_grad, z.grad) print(w.requires_grad, w.grad)
|
非叶子结点无法访问梯度
因为只有叶子结点是Variable, 它们的值可以变
非叶子结点可以通过使用.retrain_grad() 来修改梯度
只能一次backward
先前向计算得到graph, backward后graph就被释放了
然后在前向计算的时候,非叶子结点的梯度就destroy了,所以用retrain_grad进行修复
1 2 3 4
| b = w1 * a c = w2 * a d = (w3 * b) + (w4 * c) L = f(d)
|




https://towardsdatascience.com/getting-started-with-pytorch-part-1-understanding-how-automatic-differentiation-works-5008282073ec