# Source code for ls_mlkit.util.scheduler

import math
from enum import Enum
from typing import Any, Callable


def cosine_decay_with_warmup(value, current, total, warmup_steps=0):
    """Scale ``value`` with a linear warmup followed by a half-cosine decay.

    Args:
        value: Base value; this is what the schedule returns at
            ``current == warmup_steps``.
        current: Current step index.
        total: Total number of steps in the schedule.
        warmup_steps: Number of warmup steps. Defaults to 0.

    Returns:
        ``value * current / warmup_steps`` while ``current < warmup_steps``.
        Afterwards ``value * (1 + cos(pi * progress)) / 2``, where
        ``progress`` is ``(current - warmup_steps) / (total - warmup_steps)``.
        If there is no decay phase (``total <= warmup_steps``), ``value`` is
        returned unchanged.
    """
    if current < warmup_steps:
        return value * current / warmup_steps

    decay_steps = total - warmup_steps
    if decay_steps <= 0:
        # Warmup spans the whole schedule: there is nothing to decay over, and
        # dividing by ``decay_steps`` would raise ZeroDivisionError.
        return value

    progress = (current - warmup_steps) / decay_steps
    return value * (1 + math.cos(math.pi * progress)) / 2
def linear_decay_with_warmup(value, current, total, warmup_steps=0):
    """Scale ``value`` with a linear warmup followed by a linear decay.

    Args:
        value: Base value; this is what the schedule returns at
            ``current == warmup_steps``.
        current: Current step index.
        total: Total number of steps in the schedule.
        warmup_steps: Number of warmup steps. Defaults to 0.

    Returns:
        ``value * current / warmup_steps`` while ``current < warmup_steps``.
        Afterwards ``value * (1 - progress)``, where ``progress`` is
        ``(current - warmup_steps) / (total - warmup_steps)``. If there is no
        decay phase (``total <= warmup_steps``), ``value`` is returned
        unchanged.
    """
    if current < warmup_steps:
        return value * current / warmup_steps

    decay_steps = total - warmup_steps
    if decay_steps <= 0:
        # Warmup spans the whole schedule: there is nothing to decay over, and
        # dividing by ``decay_steps`` would raise ZeroDivisionError.
        return value

    progress = (current - warmup_steps) / decay_steps
    return value * (1 - progress)
def constant_with_warmup(value, current, total, warmup_steps=0):
    """Ramp linearly up to ``value`` during warmup, then hold it constant.

    Args:
        value: Value returned once warmup is over.
        current: Current step index.
        total: Total number of steps. Not used by this schedule; it is part
            of the signature shared with the other schedule functions.
        warmup_steps: Number of warmup steps. Defaults to 0.

    Returns:
        ``value * current / warmup_steps`` while ``current < warmup_steps``,
        otherwise ``value`` unchanged.
    """
    in_warmup = current < warmup_steps
    return value * current / warmup_steps if in_warmup else value
def exponential_decay_with_warmup(value, current, total, warmup_steps=0, decay_rate=5.0):
    """Scale ``value`` with a linear warmup followed by an exponential decay.

    Args:
        value: Base value; this is what the schedule returns at
            ``current == warmup_steps``.
        current: Current step index.
        total: Total number of steps in the schedule.
        warmup_steps: Number of warmup steps. Defaults to 0.
        decay_rate: Exponent multiplier; the value at ``current == total`` is
            ``value * exp(-decay_rate)``. Defaults to 5.0.

    Returns:
        ``value * current / warmup_steps`` while ``current < warmup_steps``.
        Afterwards ``value * exp(-decay_rate * progress)``, where ``progress``
        is ``(current - warmup_steps) / (total - warmup_steps)``. If there is
        no decay phase (``total <= warmup_steps``), ``value`` is returned
        unchanged.
    """
    if current < warmup_steps:
        return value * current / warmup_steps

    decay_steps = total - warmup_steps
    if decay_steps <= 0:
        # Warmup spans the whole schedule: there is nothing to decay over, and
        # dividing by ``decay_steps`` would raise ZeroDivisionError.
        return value

    progress = (current - warmup_steps) / decay_steps
    return value * math.exp(-decay_rate * progress)
class SchedulerType(Enum):
    """Identifiers for the available schedule functions.

    Each member's value is the name of the schedule function it stands for.
    """

    # Decaying schedules.
    COSINE_DECAY_WITH_WARMUP = "cosine_decay_with_warmup"
    LINEAR_DECAY_WITH_WARMUP = "linear_decay_with_warmup"
    EXPONENTIAL_DECAY_WITH_WARMUP = "exponential_decay_with_warmup"

    # Non-decaying schedule.
    CONSTANT_WITH_WARMUP = "constant_with_warmup"
# Dispatch table: resolves each SchedulerType member to the schedule function
# that implements it.
FUNCTION_MAPPING: dict[SchedulerType, Callable[..., Any]] = {
    SchedulerType.COSINE_DECAY_WITH_WARMUP: cosine_decay_with_warmup,
    SchedulerType.LINEAR_DECAY_WITH_WARMUP: linear_decay_with_warmup,
    SchedulerType.EXPONENTIAL_DECAY_WITH_WARMUP: exponential_decay_with_warmup,
    SchedulerType.CONSTANT_WITH_WARMUP: constant_with_warmup,
}
class Scheduler:
    """Step-driven scheduler for a collection of named values.

    ``info`` maps each name to a configuration dict with these keys:

    * ``"value"``: base value handed to the schedule function.
    * ``"schedule"``: callable invoked as
      ``schedule(value, current, total, warmup_steps)``.
    * ``"warmup_steps"``: number of warmup steps; if missing or ``None`` it is
      derived from ``"warmup_ratio"``.
    * ``"warmup_ratio"``: fraction of ``total`` used for warmup; required only
      when ``"warmup_steps"`` is not given.

    Each call to :meth:`step` stores the freshly computed value under the
    ``"current_value"`` key of every entry.
    """

    def __init__(
        self,
        info: dict[str, dict[str, Any]],
        total: int,
    ):
        """Validate ``info`` and resolve warmup lengths.

        The ``info`` dict is stored by reference and its entries are updated
        in place (``"warmup_steps"`` here, ``"current_value"`` in ``step``).

        Args:
            info: Per-name schedule configuration, see the class docstring.
            total: Total number of steps, forwarded to every schedule call.

        Raises:
            ValueError: If an entry lacks ``"value"`` or ``"schedule"``, or
                provides neither ``"warmup_steps"`` nor ``"warmup_ratio"``.
        """
        self.info = info
        self.total = total
        self.current = 0
        for key, entry in self.info.items():
            if entry.get("value") is None:
                raise ValueError(f"value of {key} is not defined")
            if entry.get("schedule") is None:
                raise ValueError(f"schedule of {key} is not defined")
            if entry.get("warmup_steps") is None:
                # Raise explicitly instead of asserting: assertions are
                # stripped under ``python -O``, which would turn a
                # configuration mistake into an opaque KeyError below.
                if entry.get("warmup_ratio") is None:
                    raise ValueError(
                        f"warmup_ratio of {key} must be provided if warmup_steps is not provided"
                    )
                entry["warmup_steps"] = int(self.total * entry["warmup_ratio"])

    def step(self):
        """Step the scheduler

        Increments the step counter, then recomputes ``"current_value"`` for
        every entry by calling its schedule function.
        """
        self.current += 1
        for entry in self.info.values():
            entry["current_value"] = entry["schedule"](
                entry["value"], self.current, self.total, entry["warmup_steps"]
            )

    def get(self, key=None):
        """Get the current value of the scheduler

        Args:
            key (str, optional): The key of the scheduler to get. If None, return the entire scheduler info. Defaults to None.

        Returns:
            dict[str, Any] or Any: The entire scheduler info or the value of the scheduler for the given key

        Raises:
            KeyError: If ``key`` is unknown, or if its entry has no
                ``"current_value"`` yet (it is written by ``step``).
        """
        if key is None:
            return self.info
        return self.info[key]["current_value"]
class ObjectAttrsScheduler:
    """Apply one schedule strategy to several attributes of an object.

    The value each attribute holds at construction time is used as the base
    value of its schedule. Every :meth:`step` advances an internal
    ``Scheduler`` and writes the newly scheduled values back onto the object.

    Attributes can be accessed directly (``getattr``/``setattr``) or through
    accessor methods named in ``getter_methods``/``setter_methods``.
    """

    def __init__(
        self,
        obj: object,
        attr_names: list[str],
        total: int,
        warmup_steps: int = None,
        warmup_ratio: float = 0,
        strategy: SchedulerType = SchedulerType.CONSTANT_WITH_WARMUP,
        setter_methods: dict[str, str] = None,
        getter_methods: dict[str, str] = None,
    ):
        """Capture the initial attribute values and build the scheduler.

        Args:
            obj: Object whose attributes are scheduled.
            attr_names: Names of the attributes to schedule.
            total: Total number of steps.
            warmup_steps: Number of warmup steps. When ``None``, the length is
                derived from ``warmup_ratio``.
            warmup_ratio: Fraction of ``total`` used for warmup when
                ``warmup_steps`` is ``None``. Defaults to 0.
            strategy: Schedule to apply; looked up in ``FUNCTION_MAPPING``.
            setter_methods: Maps an attribute name to the NAME of a method on
                ``obj`` that is called with the new value instead of using
                ``setattr``.
            getter_methods: Maps an attribute name to the NAME of a
                zero-argument method on ``obj`` that is called instead of
                using ``getattr``.

        Raises:
            AttributeError: If an attribute, getter or setter named above does
                not exist on ``obj``.
        """
        self.obj = obj
        self.attr_names = attr_names
        self.strategy = strategy
        self.setter_methods = setter_methods or {}
        self.getter_methods = getter_methods or {}
        self.info = {}

        for attr_name in attr_names:
            # Explicit checks rather than asserts: assertions vanish under
            # ``python -O``, which would defer a setter typo until step().
            getter_method = self.getter_methods.get(attr_name)
            if getter_method is not None:
                if not hasattr(obj, getter_method):
                    raise AttributeError(f"{getter_method} is not a method of {obj}")
            elif not hasattr(obj, attr_name):
                raise AttributeError(f"{attr_name} is not an attribute of {obj}")

            setter_method = self.setter_methods.get(attr_name)
            if setter_method is not None and not hasattr(obj, setter_method):
                raise AttributeError(f"{setter_method} is not a method of {obj}")

            self.info[attr_name] = {
                "value": self._read(attr_name),
                "schedule": FUNCTION_MAPPING[strategy],
                "warmup_steps": warmup_steps,
                "warmup_ratio": warmup_ratio,
            }

        self.scheduler = Scheduler(self.info, total)

    def _read(self, attr_name):
        """Return the attribute's current value, via its getter if one is set."""
        getter_method = self.getter_methods.get(attr_name)
        if getter_method is not None:
            return getattr(self.obj, getter_method)()
        return getattr(self.obj, attr_name)

    def step(self):
        """Advance the schedule one step and write the new values to the object."""
        self.scheduler.step()
        for attr_name in self.attr_names:
            new_value = self.scheduler.get(attr_name)
            # Use custom setter if provided, otherwise use setattr
            setter_method = self.setter_methods.get(attr_name)
            if setter_method is not None:
                getattr(self.obj, setter_method)(new_value)
            else:
                setattr(self.obj, attr_name, new_value)

    def get(self):
        """Return a dict mapping each scheduled attribute name to its current value.

        Values are read from the object itself (through the custom getter when
        one is configured), not from the internal scheduler state.
        """
        return {attr_name: self._read(attr_name) for attr_name in self.attr_names}
if __name__ == "__main__":
    # Manual smoke test: schedule three attributes of a toy object with
    # exponential decay and log the resulting values to Weights & Biases.
    import wandb

    wandb.init(project="scheduler-test")

    total = 100
    warmup_ratio = 0.1

    class Test:
        """Toy object; ``x`` is accessed through explicit getter/setter methods."""

        def __init__(self, value):
            self.x = self.y = self.z = value

        def set_x(self, value):
            self.x = value

        def get_x(self):
            return self.x

    test = Test(10)
    scheduler = ObjectAttrsScheduler(
        test,
        attr_names=["x", "y", "z"],
        total=total,
        warmup_ratio=warmup_ratio,
        strategy=SchedulerType.EXPONENTIAL_DECAY_WITH_WARMUP,
        setter_methods={"x": "set_x"},
        getter_methods={"x": "get_x"},
    )

    for i in range(total):
        scheduler.step()
        wandb.log(scheduler.get(), step=i)