Commit aee1979f authored by Mira Arabi Haddad's avatar Mira Arabi Haddad
Browse files

Delete SGD.py

parent 20efc7bd
import torch
class SGD_Optim():
    """Stochastic gradient descent with momentum for a custom model type.

    The wrapped model exposes a ``mods`` sequence; each module carries its
    weights (``w``, ``b``), accumulated gradients (``grad_w``, ``grad_b``)
    and per-module velocity buffers (``velocity_w``, ``velocity_b``).
    """

    def __init__(self, model, batch_size, lr, gamma) -> None:
        # gamma is the momentum coefficient; lr the learning rate.
        self.model = model
        self.batch_size = batch_size
        self.lr = lr
        self.gamma = gamma
        # Unused by step() (velocities live on the modules themselves);
        # kept so the public attribute surface is unchanged.
        self.velocity = []

    def step(self):
        """Apply one momentum-SGD update to every parameterized module."""
        for mod in self.model.mods:
            params = mod.param()
            # A module with no trainable parameters reports None — skip it.
            if params[0][0] is None:
                continue
            # Lazily seed the velocity on the first step, scaled by
            # lr/(1 - gamma).  NOTE(review): gamma == 1 would divide by
            # zero here — confirm callers never pass it.
            seed_scale = self.lr / (1 - self.gamma)
            if mod.velocity_w is None:
                mod.velocity_w = seed_scale * mod.grad_w
            if mod.velocity_b is None:
                mod.velocity_b = seed_scale * mod.grad_b
            # v <- gamma * v + (lr / batch_size) * grad
            # (presumably gradients are summed over the batch, hence the
            # division by batch_size — verify against the model's backward)
            step_scale = self.lr / self.batch_size
            mod.velocity_w = self.gamma * mod.velocity_w + step_scale * mod.grad_w
            mod.velocity_b = self.gamma * mod.velocity_b + step_scale * mod.grad_b
            # Parameter update: w <- w - v_w, b <- b - v_b.
            mod.update_params(opt=True,
                              w=(mod.w - mod.velocity_w),
                              b=(mod.b - mod.velocity_b))

    def zero_grad(self):
        """Clear accumulated gradients on the wrapped model."""
        self.model.zero_grad()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment