import tensorflow as tf
import numpy as np
def make_linear_var(
    step, start, end, start_value, end_value, clip_min=None, clip_max=None, **kwargs
):
    r"""
    Linear schedule through :math:`(a, \alpha)` and :math:`(b, \beta)`, i.e.
    :math:`y = (\beta - \alpha)/(b - a) * (x - a) + \alpha`,
    clipped to the interval ``[clip_min, clip_max]``.

    Parameters
    ----------
    step : tf.Tensor
        :math:`x`; cast to ``tf.float32`` before use.
    start : int
        :math:`a`
    end : int
        :math:`b`; must differ from ``start`` because ``end - start`` is
        used as a divisor.
    start_value : float
        :math:`\alpha`
    end_value : float
        :math:`\beta`
    clip_min : float, optional
        Minimal value returned. Defaults to ``min(start_value, end_value)``.
    clip_max : float, optional
        Maximum value returned. Defaults to ``max(start_value, end_value)``.
    **kwargs
        Ignored; accepted so that a shared options dict can be passed.

    Returns
    -------
    :math:`y` : tf.Tensor
    """
    # Without explicit bounds the schedule saturates at its two end values.
    if clip_min is None:
        clip_min = min(start_value, end_value)
    if clip_max is None:
        clip_max = max(start_value, end_value)

    slope = (end_value - start_value) / (end - start)
    steps_since_start = tf.cast(step, tf.float32) - start
    linear = slope * steps_since_start + start_value
    return tf.clip_by_value(linear, clip_min, clip_max)
def make_periodic_step(step, start_step: int, period_duration_in_steps: int, **kwargs):
    """
    Returns the position of ``step`` within a repeating period.

    Parameters
    ----------
    step : tf.Tensor
        step variable
    start_step : int
        an offset parameter specifying when the first period begins
    period_duration_in_steps : int
        length of one period, in steps
    **kwargs
        Ignored; accepted so that a shared options dict can be passed.

    Returns
    -------
    unit_step : tf.Tensor
        float32 step within the current period. Steps before ``start_step``
        map to 0.
    """
    # tf.cast replaces tf.to_float, which is deprecated and absent in TF2.
    steps_since_start = tf.cast(step, tf.float32) - start_step
    # Steps before the first period are pinned to the beginning of the period.
    steps_since_start = tf.maximum(steps_since_start, 0.0)
    return tf.math.floormod(steps_since_start, period_duration_in_steps)
def make_exponential_var(step, start, end, start_value, end_value, decay, **kwargs):
    r"""
    Exponential from :math:`(a, \alpha)` to :math:`(b, \beta)` with decay
    rate decay.

    The value is :math:`\alpha \cdot decay^{e(x)}`, where the exponent
    :math:`e(x)` is a clipped linear ramp (see ``make_linear_var``) from 0 at
    :math:`a` to :math:`\log(\beta / \alpha) / \log(decay)` at :math:`b`.
    The result is therefore :math:`\alpha` up to ``start`` and :math:`\beta`
    from ``end`` onwards.

    Parameters
    ----------
    step : tf.Tensor
        :math:`x`
    start : int
        :math:`a`
    end : int
        :math:`b`
    start_value : float
        :math:`\alpha`; its logarithm is taken, so it must be positive.
    end_value : float
        :math:`\beta`; its logarithm is taken, so it must be positive.
    decay : float
        Base of the exponential; its logarithm is used as a divisor, so it
        must be positive and different from 1.
    **kwargs
        Ignored; accepted so that a shared options dict can be passed.

    Returns
    -------
    :math:`y` : tf.Tensor
    """
    # The exponent must start at 0 so that the schedule equals start_value at
    # `start`. Using `start` itself here would scale it by decay ** start.
    start_exponent = 0.0
    # Exponent at which decay ** exponent == end_value / start_value.
    end_exponent = (np.log(end_value) - np.log(start_value)) / np.log(decay)
    exponent = make_linear_var(step, start, end, start_exponent, end_exponent)
    return tf.math.pow(decay, exponent) * start_value
def make_staircase_var(
    step,
    start,
    start_value,
    step_size,
    stair_factor,
    clip_min=0.0,
    clip_max=1.0,
    **kwargs
):
    r"""
    Staircase schedule: ``start_value`` multiplied by ``stair_factor`` once
    for every ``step_size`` steps elapsed since ``start``, clipped to
    ``[clip_min, clip_max]``.

    Parameters
    ----------
    step : tf.Tensor
        :math:`x`; cast to ``tf.float32`` before use.
    start : int
        :math:`a`
    start_value : float
        :math:`\alpha`
    step_size : int
        after how many steps the value should be changed
    stair_factor : float
        factor that the value is multiplied with at every 'step_size' steps
    clip_min : float
        Minimal value returned.
    clip_max : float
        Maximum value returned.
    **kwargs
        Ignored; accepted so that a shared options dict can be passed.

    Returns
    -------
    :math:`y` : tf.Tensor
    """
    # Number of completed stairs; floor division keeps the value constant
    # within each stair.
    stairs_done = (tf.cast(step, tf.float32) - start) // step_size
    stair_case = stair_factor ** stairs_done * start_value
    return tf.clip_by_value(stair_case, clip_min, clip_max)
def make_periodic_wrapper(step_function):
    """
    A wrapper to wrap the step variable of a step function into a periodic
    step variable.

    Parameters
    ----------
    step_function : callable
        the step function where to exchange the step variable with a
        periodic step variable. It is called as
        ``step_function(periodic_step, **kwargs)``.

    Returns
    -------
    callable
        a function ``f(step, **kwargs)`` that first maps ``step`` through
        ``make_periodic_step`` and then evaluates ``step_function`` on the
        result. The same ``kwargs`` are forwarded to both calls.
    """

    def make_periodic_xxx_var(step, **kwargs):
        periodic_step = make_periodic_step(step, **kwargs)
        return step_function(periodic_step, **kwargs)

    return make_periodic_xxx_var
def make_var(step, var_type, options):
    r"""
    Build a scheduled variable of the requested type.

    Example
    -------
    usage within trainer

    .. code-block:: python

        grad_weight = make_var(step=self.global_step,
                               var_type=self.config["grad_weight"]["var_type"],
                               options=self.config["grad_weight"]["options"])

    within yaml file

    .. code-block:: yaml

        grad_weight:
          var_type: linear
          options:
            start: 50000
            end: 60000
            start_value: 0.0
            end_value: 1.0
            clip_min: 1.0e-6
            clip_max: 1.0

    Parameters
    ----------
    step : tf.Tensor
        scalar tensor variable
    var_type : str
        a string from ["linear", "exponential", "staircase",
        "periodic_linear", "periodic_staircase"]
    options : dict
        keyword arguments passed to specific 'make_xxx_var' function. The
        periodic variants additionally need the arguments of
        ``make_periodic_step``.

    Returns
    -------
    :math:`y` : tf.Tensor

    Raises
    ------
    KeyError
        If ``var_type`` is not one of the supported strings.
    """
    factories = {
        "linear": make_linear_var,
        "exponential": make_exponential_var,
        "staircase": make_staircase_var,
        "periodic_linear": make_periodic_wrapper(make_linear_var),
        "periodic_staircase": make_periodic_wrapper(make_staircase_var),
    }
    factory = factories[var_type]
    return factory(step=step, **options)
if __name__ == "__main__":
    # Demo: plot one example of each schedule type over N steps.
    # TF1 needs eager mode switched on explicitly; TF2 is eager by default and
    # no longer exposes this function at the top level.
    if hasattr(tf, "enable_eager_execution"):
        tf.enable_eager_execution()

    import pprint

    import seaborn as sns
    from matplotlib import pyplot as plt

    sns.set()

    N = 10000
    t = tf.range(0, N)
    pp = pprint.PrettyPrinter(indent=4)

    # Each entry pairs a schedule config with extra keyword arguments for the
    # legend of its plot.
    demos = [
        (
            {
                "var_type": "linear",
                "options": {
                    "start": 2000,
                    "end": 3000,
                    "start_value": 0.0,
                    "end_value": 1.0,
                    "clip_min": 1.0e-6,
                    "clip_max": 1.0,
                },
            },
            {},
        ),
        (
            {
                "var_type": "periodic_linear",
                "options": {
                    "start": 0,
                    "end": 1000,
                    "start_value": 0.0,
                    "end_value": 1.0,
                    "clip_min": 1.0e-6,
                    "clip_max": 1.0,
                    "start_step": 4999,
                    "period_duration_in_steps": 2000,
                },
            },
            {"loc": 1},
        ),
        # TODO: make this nice with some annotations
        (
            {
                "var_type": "periodic_staircase",
                "options": {
                    "start": 0,
                    "start_value": 1.0,
                    "step_size": 1000,
                    "stair_factor": 10,
                    "clip_min": 1.0,
                    "clip_max": 1.0e3,
                    "start_step": 1000,
                    "period_duration_in_steps": 4000,
                },
            },
            {"loc": 1},
        ),
        (
            {
                "var_type": "staircase",
                "options": {
                    "start": 0,
                    "start_value": 1.0,
                    "step_size": 1000,
                    "stair_factor": 10,
                    "clip_min": 1.0,
                    "clip_max": 1.0e3,
                    "start_step": 1000,
                    "period_duration_in_steps": 4000,
                },
            },
            {"loc": 1},
        ),
    ]

    for schedule_config, legend_kwargs in demos:
        scheduled_variable = make_var(
            t, schedule_config["var_type"], schedule_config["options"]
        )
        textstr = pp.pformat(schedule_config["options"])
        # x/y are passed by keyword: newer seaborn releases no longer accept
        # them positionally. Eager tensors are converted to NumPy arrays first.
        sns.lineplot(x=t.numpy(), y=scheduled_variable.numpy(), label=textstr)
        ax = plt.gca()
        ax.legend(frameon=True, fontsize=16, **legend_kwargs)
        plt.show()