Commit 34a3e662 authored by Elif Ceylan's avatar Elif Ceylan
Browse files

GROUPWORK_create optim.py for optimizers

parent 3ced0fb3
# External
import torch
class SGD():
    """Stochastic gradient descent with momentum ("velocity") updates.

    Walks the modules of *model* on every step, maintains per-module
    velocities for the weight and bias gradients, and writes the updated
    parameters back through ``module.update_params``.
    """

    def __init__(self, model, batch_size, lr, gamma) -> None:
        # Model to optimize and the hyper-parameters of the update rule.
        self.model = model
        self.batch_size = batch_size  # gradients are averaged by this
        self.lr = lr                  # learning rate
        self.gamma = gamma            # momentum coefficient
        # NOTE(review): unused here — velocities live on the modules
        # themselves; kept for interface compatibility.
        self.velocity = []

    def step(self):
        """Apply one momentum-SGD update to every parameterized module."""
        for mod in self.model.mods:
            params = mod.param()
            # Modules without trainable parameters report None; skip them.
            if params[0][0] is None:
                continue
            # Lazily seed the velocities on the first step; the 1/(1-gamma)
            # factor scales the initial velocity (assumes gamma != 1).
            seed = self.lr / (1 - self.gamma)
            if mod.velocity_w is None:
                mod.velocity_w = seed * mod.grad_w
            if mod.velocity_b is None:
                mod.velocity_b = seed * mod.grad_b
            # Momentum update: decay old velocity, add averaged gradient.
            scale = self.lr / self.batch_size
            mod.velocity_w = self.gamma * mod.velocity_w + scale * mod.grad_w
            mod.velocity_b = self.gamma * mod.velocity_b + scale * mod.grad_b
            # Write the descended parameters back into the module.
            mod.update_params(opt=True,
                              w=(mod.w - mod.velocity_w),
                              b=(mod.b - mod.velocity_b))

    def zero_grad(self):
        """Delegate gradient clearing to the wrapped model."""
        self.model.zero_grad()
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment