Gradients

Collection of Ivy gradient functions.

ivy.adam_step(dcdw, mw, vw, step, /, *, beta1=0.9, beta2=0.999, epsilon=1e-07, out=None)

Compute the Adam step delta, given the derivatives of some cost c with respect to the weights ws, using the Adam update. Reference: https://en.wikipedia.org/wiki/Stochastic_gradient_descent#Adam

Parameters:
  • dcdw (Union[Array, NativeArray]) – Derivatives of the cost c with respect to the weights ws, [dc/dw for w in ws].

  • mw (Union[Array, NativeArray]) – running average of the gradients

  • vw (Union[Array, NativeArray]) – running average of second moments of the gradients

  • step (Union[int, float]) – training step

  • beta1 (float, default: 0.9) – gradient forgetting factor (Default value = 0.9)

  • beta2 (float, default: 0.999) – second moment of gradient forgetting factor (Default value = 0.999)

  • epsilon (float, default: 1e-07) – divisor during adam update, preventing division by zero (Default value = 1e-7)

  • out (Optional[Array], default: None) – optional output array, for writing the effective grad of adam_step to. It must have a shape that the inputs broadcast to.

Return type:

Tuple[Array, Array, Array]

Returns:

ret – The Adam step delta, together with the updated mw and vw.

Examples

With ivy.Array inputs:

>>> dcdw = ivy.array([1, 2, 3])
>>> mw = ivy.ones(3)
>>> vw = ivy.ones(1)
>>> step = ivy.array(3)
>>> adam_step_delta = ivy.adam_step(dcdw, mw, vw, step)
>>> print(adam_step_delta)
(ivy.array([0.2020105 , 0.22187898, 0.24144873]),
ivy.array([0.99999998, 1.09999998, 1.19999998]),
ivy.array([1.00000001, 1.00300001, 1.00800001]))
>>> dcdw = ivy.array([[1., 4., -3.], [2., 3., 0.5]])
>>> mw = ivy.zeros((2,3))
>>> vw = ivy.zeros(3)
>>> step = ivy.array(1)
>>> beta1 = 0.86
>>> beta2 = 0.95
>>> epsilon = 1e-6
>>> adam_step_delta = ivy.adam_step(dcdw, mw, vw, step, beta1=beta1, beta2=beta2,
...                                 epsilon=epsilon)
>>> print(adam_step_delta)
(ivy.array([[ 1.,  1., -1.],
            [ 1.,  1.,  1.]]),
 ivy.array([[ 0.14,  0.56, -0.42],
            [ 0.28,  0.42,  0.07]]),
 ivy.array([[0.05  , 0.8   , 0.45  ],
            [0.2   , 0.45  , 0.0125]]))
>>> dcdw = ivy.array([0.1, -0.7, 2])
>>> mw = ivy.ones(1)
>>> vw = ivy.ones(1)
>>> step = ivy.array(3.6)
>>> out = ivy.zeros_like(dcdw)
>>> adam_step_delta = ivy.adam_step(dcdw, mw, vw, step, out=out)
>>> print(out)
ivy.array([0.17294501, 0.15770318, 0.20863818])

With one ivy.Container input:

>>> dcdw = ivy.Container(a=ivy.array([0., 1., 2.]),
...                      b=ivy.array([3., 4., 5.]))
>>> mw = ivy.array([1., 4., 9.])
>>> vw = ivy.array([0.,])
>>> step = ivy.array([3.4])
>>> beta1 = 0.87
>>> beta2 = 0.976
>>> epsilon = 1e-5
>>> adam_step_delta = ivy.adam_step(dcdw, mw, vw, step, beta1=beta1, beta2=beta2,
...                                 epsilon=epsilon)
>>> print(adam_step_delta)
({
    a: ivy.array([6.49e+04, 1.74e+01, 1.95e+01]),
    b: ivy.array([2.02, 4.82, 8.17])
}, {
    a: ivy.array([0.87, 3.61, 8.09]),
    b: ivy.array([1.26, 4., 8.48])
}, {
    a: ivy.array([0., 0.024, 0.096]),
    b: ivy.array([0.216, 0.384, 0.6])
})

With multiple ivy.Container inputs:

>>> dcdw = ivy.Container(a=ivy.array([0., 1., 2.]),
...                      b=ivy.array([3., 4., 5.]))
>>> mw = ivy.Container(a=ivy.array([0., 0., 0.]),
...                    b=ivy.array([0., 0., 0.]))
>>> vw = ivy.Container(a=ivy.array([0.,]),
...                    b=ivy.array([0.,]))
>>> step = ivy.array([3.4])
>>> beta1 = 0.87
>>> beta2 = 0.976
>>> epsilon = 1e-5
>>> adam_step_delta = ivy.adam_step(dcdw, mw, vw, step, beta1=beta1, beta2=beta2,
...                                 epsilon=epsilon)
>>> print(adam_step_delta)
({
    a: ivy.array([0., 0.626, 0.626]),
    b: ivy.array([0.626, 0.626, 0.626])
}, {
    a: ivy.array([0., 0.13, 0.26]),
    b: ivy.array([0.39, 0.52, 0.65])
}, {
    a: ivy.array([0., 0.024, 0.096]),
    b: ivy.array([0.216, 0.384, 0.6])
})
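
The numbers in the first ivy.Array example can be reproduced by hand. The sketch below is a plain-Python reading of the Adam step, inferred from the printed outputs above rather than taken from the Ivy source. The helper name adam_step_sketch and the exact placement of epsilon are assumptions.

```python
import math

def adam_step_sketch(dcdw, mw, vw, step, beta1=0.9, beta2=0.999, epsilon=1e-7):
    """Hypothetical pure-Python Adam step on flat lists (not Ivy's implementation)."""
    # Exponential moving averages of the gradient and of the squared gradient.
    mw_new = [beta1 * m + (1 - beta1) * g for m, g in zip(mw, dcdw)]
    vw_new = [beta2 * v + (1 - beta2) * g * g for v, g in zip(vw, dcdw)]
    # Bias-correction factor for the given training step.
    alpha = math.sqrt(1 - beta2 ** step) / (1 - beta1 ** step + epsilon)
    # Effective gradient: scaled first moment over the root of the second moment.
    delta = [alpha * m / (math.sqrt(v) + epsilon) for m, v in zip(mw_new, vw_new)]
    return delta, mw_new, vw_new

# Same inputs as the first example; vw = ivy.ones(1) is written out in full here.
delta, mw, vw = adam_step_sketch([1.0, 2.0, 3.0], [1.0] * 3, [1.0] * 3, 3)
print([round(d, 4) for d in delta])  # close to the step delta printed above
```

The three returned lists correspond to the three entries of the tuple that ivy.adam_step prints.
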

ivy.adam_update(w, dcdw, lr, mw_tm1, vw_tm1, step, /, *, beta1=0.9, beta2=0.999, epsilon=1e-07, stop_gradients=True, out=None)

Update weights ws of some function, given the derivatives of some cost c with respect to ws, using the Adam update. Reference: https://en.wikipedia.org/wiki/Stochastic_gradient_descent#Adam

Parameters:
  • w (Union[Array, NativeArray]) – Weights of the function to be updated.

  • dcdw (Union[Array, NativeArray]) – Derivatives of the cost c with respect to the weights ws, [dc/dw for w in ws].

  • lr (Union[float, Array, NativeArray]) – Learning rate(s), the rate(s) at which the weights should be updated relative to the gradient.

  • mw_tm1 (Union[Array, NativeArray]) – running average of the gradients, from the previous time-step.

  • vw_tm1 (Union[Array, NativeArray]) – running average of second moments of the gradients, from the previous time-step.

  • step (int) – training step.

  • beta1 (float, default: 0.9) – gradient forgetting factor (Default value = 0.9).

  • beta2 (float, default: 0.999) – second moment of gradient forgetting factor (Default value = 0.999).

  • epsilon (float, default: 1e-07) – divisor during adam update, preventing division by zero (Default value = 1e-7).

  • stop_gradients (bool, default: True) – Whether to stop the gradients of the variables after each gradient step. Default is True.

  • out (Optional[Array], default: None) – optional output array, for writing the new function weights ws_new to. It must have a shape that the inputs broadcast to.

Return type:

Tuple[Array, Array, Array]

Returns:

ret – The new function weights ws_new, and also new mw and vw, following the adam updates.

Examples

With ivy.Array inputs:

>>> w = ivy.array([1., 2, 3])
>>> dcdw = ivy.array([0.5,0.2,0.1])
>>> lr = ivy.array(0.1)
>>> vw_tm1 = ivy.zeros(1)
>>> mw_tm1 = ivy.zeros(3)
>>> step = 1
>>> updated_weights = ivy.adam_update(w, dcdw, lr, mw_tm1, vw_tm1, step)
>>> print(updated_weights)
(ivy.array([0.90000075, 1.90000164, 2.9000032 ]),
ivy.array([0.05, 0.02, 0.01]),
ivy.array([2.50000012e-04, 4.00000063e-05, 1.00000016e-05]))
>>> w = ivy.array([[1., 2, 3],[4, 2, 4],[6, 4, 2]])
>>> dcdw = ivy.array([[0.1, 0.2, 0.3],[0.4, 0.5, 0.1],[0.1, 0.5, 0.3]])
>>> lr = ivy.array(0.1)
>>> mw_tm1 = ivy.zeros((3,3))
>>> vw_tm1 = ivy.zeros(3)
>>> step = 2
>>> beta1 = 0.9
>>> beta2 = 0.999
>>> epsilon = 1e-7
>>> out = ivy.zeros_like(w)
>>> stop_gradients = True
>>> updated_weights = ivy.adam_update(w, dcdw, lr, mw_tm1, vw_tm1, step,
...                               beta1=beta1, beta2=beta2,
...                               epsilon=epsilon, out=out,
...                               stop_gradients=stop_gradients)
>>> print(updated_weights)
(
ivy.array([[0.92558873, 1.92558754, 2.92558718],
           [3.92558694, 1.92558682, 3.92558861],
           [5.92558861, 3.92558694, 1.92558718]]),
ivy.array([[0.01, 0.02, 0.03],
           [0.04, 0.05, 0.01],
           [0.01, 0.05, 0.03]]),
ivy.array([[1.00000016e-05, 4.00000063e-05, 9.00000086e-05],
           [1.60000025e-04, 2.50000012e-04, 1.00000016e-05],
           [1.00000016e-05, 2.50000012e-04, 9.00000086e-05]])
)

With one ivy.Container input:

>>> w = ivy.Container(a=ivy.array([1., 2., 3.]), b=ivy.array([4., 5., 6.]))
>>> dcdw = ivy.array([0.5, 0.2, 0.4])
>>> mw_tm1 = ivy.array([0., 0., 0.])
>>> vw_tm1 = ivy.array([0.])
>>> lr = ivy.array(0.01)
>>> step = 2
>>> updated_weights = ivy.adam_update(w, dcdw, mw_tm1, vw_tm1, lr, step)
>>> print(updated_weights)
({
    a: ivy.array([1., 2., 3.]),
    b: ivy.array([4., 5., 6.])
}, ivy.array([0.05, 0.02, 0.04]), ivy.array([0.01024, 0.01003, 0.01015]))

With multiple ivy.Container inputs:

>>> w = ivy.Container(a=ivy.array([1., 2., 3.]),
...                   b=ivy.array([4., 5., 6.]))
>>> dcdw = ivy.Container(a=ivy.array([0.1,0.3,0.3]),
...                      b=ivy.array([0.3,0.2,0.2]))
>>> mw_tm1 = ivy.Container(a=ivy.array([0.,0.,0.]),
...                        b=ivy.array([0.,0.,0.]))
>>> vw_tm1 = ivy.Container(a=ivy.array([0.,]),
...                        b=ivy.array([0.,]))
>>> step = 3
>>> beta1 = 0.9
>>> beta2 = 0.999
>>> epsilon = 1e-7
>>> stop_gradients = False
>>> lr = ivy.array(0.001)
>>> updated_weights = ivy.adam_update(w, dcdw, lr, mw_tm1, vw_tm1, step,
...                               beta1=beta1,
...                               beta2=beta2, epsilon=epsilon,
...                               stop_gradients=stop_gradients)
>>> print(updated_weights)
({
    a: ivy.array([0.99936122, 1.99936116, 2.99936128]),
    b: ivy.array([3.99936128, 4.99936104, 5.99936104])
}, {
    a: ivy.array([0.01, 0.03, 0.03]),
    b: ivy.array([0.03, 0.02, 0.02])
}, {
    a: ivy.array([1.00000016e-05, 9.00000086e-05, 9.00000086e-05]),
    b: ivy.array([9.00000086e-05, 4.00000063e-05, 4.00000063e-05])
})
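
Because adam_update returns the new mw and vw alongside the new weights, all three results are normally fed back into the next call. The sketch below shows that loop in plain Python, assuming the update is w - lr * delta with the Adam delta implied by the outputs above. The helper adam_update_sketch and the quadratic toy cost are illustrative assumptions, not Ivy code.

```python
import math

def adam_update_sketch(w, dcdw, lr, mw_tm1, vw_tm1, step,
                       beta1=0.9, beta2=0.999, epsilon=1e-7):
    """Hypothetical one-step Adam update on flat lists (not Ivy's implementation)."""
    mw = [beta1 * m + (1 - beta1) * g for m, g in zip(mw_tm1, dcdw)]
    vw = [beta2 * v + (1 - beta2) * g * g for v, g in zip(vw_tm1, dcdw)]
    alpha = math.sqrt(1 - beta2 ** step) / (1 - beta1 ** step + epsilon)
    w_new = [wi - lr * alpha * m / (math.sqrt(v) + epsilon)
             for wi, m, v in zip(w, mw, vw)]
    return w_new, mw, vw

# Minimise the toy cost c(w) = sum(w_i ** 2), whose gradient is 2 * w.
w = [1.0, -2.0, 3.0]
mw, vw = [0.0] * 3, [0.0] * 3
for step in range(1, 201):
    dcdw = [2.0 * wi for wi in w]
    # Feed the returned moments back in as mw_tm1 and vw_tm1 for the next step.
    w, mw, vw = adam_update_sketch(w, dcdw, 0.05, mw, vw, step)

cost = sum(wi * wi for wi in w)
print(cost)  # well below the starting cost of 14.0
```

Note that step starts at 1 in this sketch: with step = 0 the bias-correction denominator would collapse to epsilon.
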

ivy.execute_with_gradients(func, xs, /, *, retain_grads=False, xs_grad_idxs=((0,),), ret_grad_idxs=((0,),))

Call function func with the input variables xs, and return the function result func_ret and the gradients of each output variable w.r.t. each input variable.

Parameters:
  • func – Function for which we compute the gradients of the output with respect to xs input.

  • xs (Union[Array, NativeArray]) – Variables with respect to which the function gradients are computed. This can be a single array or an arbitrary nest of arrays.

  • retain_grads (bool, default: False) – Whether to retain the gradients of the returned values. (Default value = False)

  • xs_grad_idxs (Sequence[Sequence[Union[str, int]]], default: ((0,),)) – Indices of the input arrays to compute gradients with respect to. If None, gradients are returned with respect to all input arrays. If xs is an ivy.Array or ivy.Container, the default value is None, otherwise the default value is [[0]].

  • ret_grad_idxs (Sequence[Sequence[Union[str, int]]], default: ((0,),)) – Indices of the returned arrays for which to return computed gradients. If None, gradients are returned for all returned arrays. If the object returned by func is an ivy.Array or ivy.Container, the default value is None, otherwise the default value is [[0]].

Return type:

Tuple[Array, Array]

Returns:

ret – the function result func_ret and a dictionary of gradients of each output variable w.r.t each input variable.

Examples

With ivy.Array input:

>>> x = ivy.array([[1, 4, 6], [2, 6, 9]])
>>> func = lambda x: ivy.mean(ivy.square(x))
>>> func_ret = ivy.execute_with_gradients(func, x, retain_grads=True)
>>> print(func_ret)
(ivy.array(29.), ivy.array([[0.33333334, 1.33333337, 2.        ],
   [0.66666669, 2.        , 3.        ]]))

With ivy.Container input:

>>> x = ivy.Container(a = ivy.array([1, 4, 6]),
...                   b = ivy.array([2, 6, 9]))
>>> func = lambda x: ivy.mean(ivy.square(x))
>>> func_ret = ivy.execute_with_gradients(func, x, retain_grads=True)
>>> print(func_ret)
({
a: ivy.array(17.666666),
b: ivy.array(40.333332)
},
{
a: {
    a: ivy.array([0.66666669, 2.66666675, 4.]),
    b: ivy.array([0., 0., 0.])
},
b: {
    a: ivy.array([0., 0., 0.]),
    b: ivy.array([1.33333337, 4., 6.])
}
})
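
The gradients in the ivy.Array example can be checked by hand: for func(x) = mean(x ** 2) over n elements, the partial derivative with respect to each element is 2 * x / n. A small plain-Python check follows (no Ivy calls; the variable names are illustrative).

```python
# Inputs from the ivy.Array example, as nested Python lists.
x = [[1.0, 4.0, 6.0], [2.0, 6.0, 9.0]]
n = sum(len(row) for row in x)

# Function value: mean of the squared elements.
value = sum(v * v for row in x for v in row) / n

# Analytic gradient of mean(x ** 2): d/dx_ij = 2 * x_ij / n.
grads = [[2.0 * v / n for v in row] for row in x]

print(value)
print([[round(g, 4) for g in row] for row in grads])
```

The value and the gradient entries agree with the tuple printed by ivy.execute_with_gradients above.
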

ivy.grad(func, argnums=0)

Call function func, and return func’s gradients.

Parameters:
  • func (Callable) – Function for which we compute the gradients of the output with respect to xs input.

  • argnums (Union[int, Sequence[int]], default: 0) – Indices of the input arrays to compute gradients with respect to. Default is 0.

Return type:

Callable

Returns:

ret – the grad function

Examples

>>> x = ivy.array([[4.6, 2.1, 5], [2.8, 1.3, 6.2]])
>>> func = lambda x: ivy.mean(ivy.square(x))
>>> grad_fn = ivy.grad(func)
>>> grad = grad_fn(x)
>>> print(grad)
ivy.array([[1.53 , 0.7  , 1.67 ],
           [0.933, 0.433, 2.07 ]])
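
To illustrate what argnums selects, here is a finite-difference stand-in for the returned callable. It is a sketch only: numerical_grad is a hypothetical helper that works on Python floats, whereas ivy.grad operates on arrays.

```python
def numerical_grad(func, argnums=0, h=1e-6):
    """Return a callable approximating d func / d args[argnums] (hypothetical helper)."""
    def grad_fn(*args):
        bumped_up = list(args)
        bumped_down = list(args)
        bumped_up[argnums] += h
        bumped_down[argnums] -= h
        # Central difference with respect to the selected positional argument.
        return (func(*bumped_up) - func(*bumped_down)) / (2 * h)
    return grad_fn

func = lambda a, b: a * a * b  # d/da = 2 * a * b, d/db = a ** 2
d_da = numerical_grad(func)(3.0, 2.0)             # argnums=0, the default
d_db = numerical_grad(func, argnums=1)(3.0, 2.0)  # differentiate w.r.t. b
print(d_da, d_db)  # approximately 12 and 9
```

As with ivy.grad, the helper returns a function, and the derivative is only evaluated when that function is called with concrete inputs.
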

ivy.gradient_descent_update(w, dcdw, lr, /, *, stop_gradients=True, out=None)

Update weights ws of some function, given the derivatives of some cost c with respect to ws, [dc/dw for w in ws].

Parameters:
  • w (Union[Array, NativeArray]) – Weights of the function to be updated.

  • dcdw (Union[Array, NativeArray]) – Derivatives of the cost c with respect to the weights ws, [dc/dw for w in ws].

  • lr (Union[float, Array, NativeArray]) – Learning rate(s), the rate(s) at which the weights should be updated relative to the gradient.

  • stop_gradients (bool, default: True) – Whether to stop the gradients of the variables after each gradient step. Default is True.

  • out (Optional[Array], default: None) – optional output array, for writing the result to. It must have a shape that the inputs broadcast to.

Return type:

Array

Returns:

ret – The new weights, following the gradient descent updates.

Examples

With ivy.Array inputs:

>>> w = ivy.array([[1., 2, 3],
...                [4, 6, 1],
...                [1, 0, 7]])
>>> dcdw = ivy.array([[0.5, 0.2, 0.1],
...                   [0.3, 0.6, 0.4],
...                   [0.4, 0.7, 0.2]])
>>> lr = ivy.array(0.1)
>>> new_weights = ivy.gradient_descent_update(w, dcdw, lr, stop_gradients=True)
>>> print(new_weights)
ivy.array([[ 0.95,  1.98,  2.99],
           [ 3.97,  5.94,  0.96],
           [ 0.96, -0.07,  6.98]])
>>> w = ivy.array([1., 2., 3.])
>>> dcdw = ivy.array([0.5, 0.2, 0.1])
>>> lr = ivy.array(0.3)
>>> out = ivy.zeros_like(w)
>>> ivy.gradient_descent_update(w, dcdw, lr, out=out)
>>> print(out)
ivy.array([0.85, 1.94, 2.97])

With one ivy.Container input:

>>> w = ivy.Container(a=ivy.array([1., 2., 3.]),
...                   b=ivy.array([3.48, 5.72, 1.98]))
>>> dcdw = ivy.array([0.5, 0.2, 0.1])
>>> lr = ivy.array(0.3)
>>> w_new = ivy.gradient_descent_update(w, dcdw, lr)
>>> print(w_new)
{
    a: ivy.array([0.85, 1.94, 2.97]),
    b: ivy.array([3.33, 5.66, 1.95])
}

With multiple ivy.Container inputs:

>>> w = ivy.Container(a=ivy.array([1., 2., 3.]),
...                   b=ivy.array([3.48, 5.72, 1.98]))
>>> dcdw = ivy.Container(a=ivy.array([0.5, 0.2, 0.1]),
...                      b=ivy.array([2., 3.42, 1.69]))
>>> lr = ivy.array(0.3)
>>> w_new = ivy.gradient_descent_update(w, dcdw, lr)
>>> print(w_new)
{
    a: ivy.array([0.85, 1.94, 2.97]),
    b: ivy.array([2.88, 4.69, 1.47])
}
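
The outputs above are consistent with the plain rule w_new = w - lr * dcdw. A minimal sketch on Python lists, assuming exactly that rule (gd_update_sketch is a hypothetical name):

```python
def gd_update_sketch(w, dcdw, lr):
    """Hypothetical list-based gradient descent step: w - lr * dcdw."""
    return [wi - lr * gi for wi, gi in zip(w, dcdw)]

# Same numbers as the second ivy.Array example.
w_new = gd_update_sketch([1.0, 2.0, 3.0], [0.5, 0.2, 0.1], 0.3)
print([round(v, 2) for v in w_new])  # [0.85, 1.94, 2.97]
```
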

ivy.jac(func)

Call function func, and return func’s Jacobian partial derivatives.

Parameters:

func (Callable) – Function for which we compute the gradients of the output with respect to xs input.

Return type:

Callable

Returns:

ret – the Jacobian function

Examples

With ivy.Array input:

>>> x = ivy.array([[4.6, 2.1, 5], [2.8, 1.3, 6.2]])
>>> func = lambda x: ivy.mean(ivy.square(x))
>>> jac_fn = ivy.jac(func)
>>> jacobian = jac_fn(x)
>>> print(jacobian)
ivy.array([[1.53 , 0.7  , 1.67 ],
           [0.933, 0.433, 2.07 ]])
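
The example uses a scalar-valued function, so the Jacobian has the same shape as x and coincides with the gradient. For a vector-valued function, each output gets its own row of partial derivatives. The sketch below shows this with a finite-difference Jacobian on Python lists; numerical_jac is a hypothetical helper, not part of Ivy.

```python
def numerical_jac(func, h=1e-6):
    """Return a callable computing J[i][j] = d func(x)[i] / d x[j] (hypothetical helper)."""
    def jac_fn(x):
        n_out = len(func(x))
        jac = [[0.0] * len(x) for _ in range(n_out)]
        for j in range(len(x)):
            up = list(x)
            down = list(x)
            up[j] += h
            down[j] -= h
            f_up, f_down = func(up), func(down)
            for i in range(n_out):
                # Central difference of output i with respect to input j.
                jac[i][j] = (f_up[i] - f_down[i]) / (2 * h)
        return jac
    return jac_fn

# Two outputs, two inputs: f(x) = [x0 * x1, x0 + x1 ** 2].
func = lambda x: [x[0] * x[1], x[0] + x[1] ** 2]
jacobian = numerical_jac(func)([2.0, 3.0])
print(jacobian)  # approximately [[3, 2], [1, 6]]
```
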

ivy.lamb_update(w, dcdw, lr, mw_tm1, vw_tm1, step, /, *, beta1=0.9, beta2=0.999, epsilon=1e-07, max_trust_ratio=10, decay_lambda=0, stop_gradients=True, out=None)

Update weights ws of some function, given the derivatives of some cost c with respect to ws, [dc/dw for w in ws], by applying LAMB method.

Parameters:
  • w (Union[Array, NativeArray]) – Weights of the function to be updated.

  • dcdw (Union[Array, NativeArray]) – Derivatives of the cost c with respect to the weights ws, [dc/dw for w in ws].

  • lr (Union[float, Array, NativeArray]) – Learning rate(s), the rate(s) at which the weights should be updated relative to the gradient.

  • mw_tm1 (Union[Array, NativeArray]) – running average of the gradients, from the previous time-step.

  • vw_tm1 (Union[Array, NativeArray]) – running average of second moments of the gradients, from the previous time-step.

  • step (int) – training step.

  • beta1 (float, default: 0.9) – gradient forgetting factor (Default value = 0.9).

  • beta2 (float, default: 0.999) – second moment of gradient forgetting factor (Default value = 0.999).

  • epsilon (float, default: 1e-07) – divisor during adam update, preventing division by zero (Default value = 1e-7).

  • max_trust_ratio (Union[int, float], default: 10) – The maximum value for the trust ratio. (Default value = 10)

  • decay_lambda (float, default: 0) – The factor used for weight decay. (Default value = 0).

  • stop_gradients (bool, default: True) – Whether to stop the gradients of the variables after each gradient step. Default is True.

  • out (Optional[Array], default: None) – optional output array, for writing the new function weights ws_new to. It must have a shape that the inputs broadcast to.

Return type:

Tuple[Array, Array, Array]

Returns:

ret – The new function weights ws_new, and also new mw and vw, following the LAMB updates.

Examples

With ivy.Array inputs:

>>> w = ivy.array([1., 2, 3])
>>> dcdw = ivy.array([0.5,0.2,0.1])
>>> lr = ivy.array(0.1)
>>> vw_tm1 = ivy.zeros(1)
>>> mw_tm1 = ivy.zeros(3)
>>> step = ivy.array(1)
>>> new_weights = ivy.lamb_update(w, dcdw, lr, mw_tm1, vw_tm1, step)
>>> print(new_weights)
(ivy.array([0.784, 1.78 , 2.78 ]),
 ivy.array([0.05, 0.02, 0.01]),
 ivy.array([2.5e-04, 4.0e-05, 1.0e-05]))
>>> w = ivy.array([[1., 2, 3],[4, 6, 1],[1, 0, 7]])
>>> dcdw = ivy.array([[0.5, 0.2, 0.1],[0.3, 0.6, 0.4],[0.4, 0.7, 0.2]])
>>> lr = ivy.array(0.1)
>>> mw_tm1 = ivy.zeros((3,3))
>>> vw_tm1 = ivy.zeros(3)
>>> step = ivy.array(1)
>>> beta1 = 0.9
>>> beta2 = 0.999
>>> epsilon = 1e-7
>>> max_trust_ratio = 10
>>> decay_lambda = 0
>>> out = ivy.zeros_like(w)
>>> stop_gradients = True
>>> new_weights = ivy.lamb_update(w, dcdw, lr, mw_tm1, vw_tm1, step, beta1=beta1,
...                               beta2=beta2, epsilon=epsilon,
...                               max_trust_ratio=max_trust_ratio,
...                               decay_lambda=decay_lambda, out=out,
...                               stop_gradients=stop_gradients)
>>> print(out)
ivy.array([[ 0.639,  1.64 ,  2.64 ],
           [ 3.64 ,  5.64 ,  0.639],
           [ 0.639, -0.361,  6.64 ]])

With one ivy.Container input:

>>> w = ivy.Container(a=ivy.array([1., 2., 3.]), b=ivy.array([4., 5., 6.]))
>>> dcdw = ivy.array([3., 4., 5.])
>>> mw_tm1 = ivy.array([0., 0., 0.])
>>> vw_tm1 = ivy.array([0.])
>>> lr = ivy.array(1.)
>>> step = ivy.array([2])
>>> new_weights = ivy.lamb_update(w, dcdw, mw_tm1, vw_tm1, lr, step)
>>> print(new_weights)
({
    a: ivy.array([1., 2., 3.]),
    b: ivy.array([4., 5., 6.])
}, ivy.array([0.3, 0.4, 0.5]), ivy.array([1.01, 1.01, 1.02]))

With multiple ivy.Container inputs:

>>> w = ivy.Container(a=ivy.array([1.,3.,5.]),
...                   b=ivy.array([3.,4.,2.]))
>>> dcdw = ivy.Container(a=ivy.array([0.2,0.3,0.6]),
...                      b=ivy.array([0.6,0.4,0.7]))
>>> mw_tm1 = ivy.Container(a=ivy.array([0.,0.,0.]),
...                        b=ivy.array([0.,0.,0.]))
>>> vw_tm1 = ivy.Container(a=ivy.array([0.,]),
...                        b=ivy.array([0.,]))
>>> step = ivy.array([3.4])
>>> beta1 = 0.9
>>> beta2 = 0.999
>>> epsilon = 1e-7
>>> max_trust_ratio = 10
>>> decay_lambda = 0
>>> stop_gradients = True
>>> lr = ivy.array(0.5)
>>> new_weights = ivy.lamb_update(w, dcdw, lr, mw_tm1, vw_tm1, step, beta1=beta1,
...                               beta2=beta2, epsilon=epsilon,
...                               max_trust_ratio=max_trust_ratio,
...                               decay_lambda=decay_lambda,
...                               stop_gradients=stop_gradients)
>>> print(new_weights)
({
    a: ivy.array([-0.708, 1.29, 3.29]),
    b: ivy.array([1.45, 2.45, 0.445])
}, {
    a: ivy.array([0.02, 0.03, 0.06]),
    b: ivy.array([0.06, 0.04, 0.07])
}, {
    a: ivy.array([4.0e-05, 9.0e-05, 3.6e-04]),
    b: ivy.array([0.00036, 0.00016, 0.00049])
})
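
The first ivy.Array example is consistent with the following reading of LAMB: compute the Adam step delta, scale the learning rate by the trust ratio ||w|| / ||delta|| capped at max_trust_ratio, then apply w - lr * delta. The sketch below assumes that reading with decay_lambda = 0. It is inferred from the printed outputs, and lamb_update_sketch is a hypothetical name.

```python
import math

def lamb_update_sketch(w, dcdw, lr, mw_tm1, vw_tm1, step, beta1=0.9,
                       beta2=0.999, epsilon=1e-7, max_trust_ratio=10):
    """Hypothetical list-based LAMB step without weight decay (not Ivy's code)."""
    mw = [beta1 * m + (1 - beta1) * g for m, g in zip(mw_tm1, dcdw)]
    vw = [beta2 * v + (1 - beta2) * g * g for v, g in zip(vw_tm1, dcdw)]
    alpha = math.sqrt(1 - beta2 ** step) / (1 - beta1 ** step + epsilon)
    delta = [alpha * m / (math.sqrt(v) + epsilon) for m, v in zip(mw, vw)]
    # Trust ratio: norm of the weights over norm of the Adam delta, capped.
    # (This sketch does not guard against a zero-norm delta.)
    w_norm = math.sqrt(sum(x * x for x in w))
    delta_norm = math.sqrt(sum(d * d for d in delta))
    trust = min(w_norm / delta_norm, max_trust_ratio)
    w_new = [wi - trust * lr * d for wi, d in zip(w, delta)]
    return w_new, mw, vw

w_new, mw, vw = lamb_update_sketch([1.0, 2.0, 3.0], [0.5, 0.2, 0.1], 0.1,
                                   [0.0] * 3, [0.0] * 3, 1)
print([round(v, 3) for v in w_new])  # close to the weights printed above
```

Compared with a plain Adam update, the only extra ingredient here is the trust ratio, which rescales the step to the size of the weights.
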

ivy.lars_update(w, dcdw, lr, /, *, decay_lambda=0, stop_gradients=True, out=None)

Update weights ws of some function, given the derivatives of some cost c with respect to ws, [dc/dw for w in ws], by applying Layerwise Adaptive Rate Scaling (LARS) method.

Parameters:
  • w (Union[Array, NativeArray]) – Weights of the function to be updated.

  • dcdw (Union[Array, NativeArray]) – Derivatives of the cost c with respect to the weights ws, [dc/dw for w in ws].

  • lr (Union[float, Array, NativeArray]) – Learning rate, the rate at which the weights should be updated relative to the gradient.

  • decay_lambda (float, default: 0) – The factor used for weight decay. Default is zero.

  • stop_gradients (bool, default: True) – Whether to stop the gradients of the variables after each gradient step. Default is True.

  • out (Optional[Array], default: None) – optional output array, for writing the result to. It must have a shape that the inputs broadcast to.

Return type:

Array

Returns:

ret – The new function weights ws_new, following the LARS updates.

Examples

With ivy.Array inputs:

>>> w = ivy.array([[3., 1, 5],
...                [7, 2, 9]])
>>> dcdw = ivy.array([[0.3, 0.1, 0.2],
...                   [0.1, 0.2, 0.4]])
>>> lr = ivy.array(0.1)
>>> new_weights = ivy.lars_update(w, dcdw, lr)
>>> print(new_weights)
ivy.array([[2.34077978, 0.78025991, 4.56051969],
           [6.78026009, 1.56051981, 8.12103939]])
>>> w = ivy.array([3., 1, 5])
>>> dcdw = ivy.array([0.3, 0.1, 0.2])
>>> lr = ivy.array(0.1)
>>> out = ivy.zeros_like(dcdw)
>>> ivy.lars_update(w, dcdw, lr, out=out)
>>> print(out)
ivy.array([2.52565837, 0.8418861 , 4.68377209])

With one ivy.Container input:

>>> w = ivy.Container(a=ivy.array([3.2, 2.6, 1.3]),
...                    b=ivy.array([1.4, 3.1, 5.1]))
>>> dcdw = ivy.array([0.2, 0.4, 0.1])
>>> lr = ivy.array(0.1)
>>> new_weights = ivy.lars_update(w, dcdw, lr)
>>> print(new_weights)
{
    a: ivy.array([3.01132035, 2.22264051, 1.2056601]),
    b: ivy.array([1.1324538, 2.56490755, 4.96622658])
}

With multiple ivy.Container inputs:

>>> w = ivy.Container(a=ivy.array([3.2, 2.6, 1.3]),
...                    b=ivy.array([1.4, 3.1, 5.1]))
>>> dcdw = ivy.Container(a=ivy.array([0.2, 0.4, 0.1]),
...                       b=ivy.array([0.3,0.1,0.2]))
>>> lr = ivy.array(0.1)
>>> new_weights = ivy.lars_update(w, dcdw, lr)
>>> print(new_weights)
{
    a: ivy.array([3.01132035, 2.22264051, 1.2056601]),
    b: ivy.array([0.90848625, 2.93616199, 4.77232409])
}
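
The second ivy.Array example matches a layer-wise scaling in which the learning rate is multiplied by ||w|| / ||dcdw|| before an ordinary gradient descent step. A minimal sketch under that assumption, with weight decay left out (lars_update_sketch is a hypothetical name):

```python
import math

def lars_update_sketch(w, dcdw, lr):
    """Hypothetical list-based LARS step with decay_lambda = 0 (not Ivy's code)."""
    w_norm = math.sqrt(sum(x * x for x in w))
    g_norm = math.sqrt(sum(g * g for g in dcdw))
    # Layer-wise learning rate: scale lr by the ratio of the two norms.
    scaled_lr = lr * w_norm / g_norm
    return [wi - scaled_lr * gi for wi, gi in zip(w, dcdw)]

w_new = lars_update_sketch([3.0, 1.0, 5.0], [0.3, 0.1, 0.2], 0.1)
print([round(v, 4) for v in w_new])  # close to the out array printed above
```

Because the scale factor depends on the norms of the whole array, the same gradient produces a larger step on a layer with larger weights.
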

ivy.optimizer_update(w, effective_grad, lr, /, *, stop_gradients=True, out=None)

Update weights ws of some function, given the true or effective derivatives of some cost c with respect to ws, [dc/dw for w in ws].

Parameters:
  • w (Union[Array, NativeArray]) – Weights of the function to be updated.

  • effective_grad (Union[Array, NativeArray]) – Effective gradients of the cost c with respect to the weights ws, [dc/dw for w in ws].

  • lr (Union[float, Array, NativeArray]) – Learning rate(s), the rate(s) at which the weights should be updated relative to the gradient.

  • stop_gradients (bool, default: True) – Whether to stop the gradients of the variables after each gradient step. Default is True.

  • out (Optional[Array], default: None) – optional output array, for writing the result to. It must have a shape that the inputs broadcast to.

Return type:

Array

Returns:

ret – The new function weights ws_new, following the optimizer updates.

Examples

With ivy.Array inputs:

>>> w = ivy.array([1., 2., 3.])
>>> effective_grad = ivy.zeros(3)
>>> lr = 3e-4
>>> ws_new = ivy.optimizer_update(w, effective_grad, lr)
>>> print(ws_new)
ivy.array([1., 2., 3.])
>>> w = ivy.array([1., 2., 3.])
>>> effective_grad = ivy.zeros(3)
>>> lr = 3e-4
>>> ws_new = ivy.optimizer_update(w, effective_grad, lr,
...                               out=None, stop_gradients=True)
>>> print(ws_new)
ivy.array([1., 2., 3.])
>>> w = ivy.array([[1., 2.], [4., 5.]])
>>> out = ivy.zeros_like(w)
>>> effective_grad = ivy.array([[4., 5.], [7., 8.]])
>>> lr = ivy.array([3e-4, 1e-2])
>>> ws_new = ivy.optimizer_update(w, effective_grad, lr, out=out)
>>> print(out)
ivy.array([[0.999, 1.95],
           [4., 4.92]])
>>> w = ivy.array([1., 2., 3.])
>>> out = ivy.zeros_like(w)
>>> effective_grad = ivy.array([4., 5., 6.])
>>> lr = ivy.array([3e-4])
>>> ws_new = ivy.optimizer_update(w, effective_grad, lr,
...                               stop_gradients=False, out=out)
>>> print(out)
ivy.array([0.999, 2.   , 3.   ])

With one ivy.Container input:

>>> w = ivy.Container(a=ivy.array([0., 1., 2.]),
...                   b=ivy.array([3., 4., 5.]))
>>> effective_grad = ivy.array([0., 0., 0.])
>>> lr = 3e-4
>>> ws_new = ivy.optimizer_update(w, effective_grad, lr)
>>> print(ws_new)
{
    a: ivy.array([0., 1., 2.]),
    b: ivy.array([3., 4., 5.])
}

With multiple ivy.Container inputs:

>>> w = ivy.Container(a=ivy.array([0., 1., 2.]),
...                   b=ivy.array([3., 4., 5.]))
>>> effective_grad = ivy.Container(a=ivy.array([0., 0., 0.]),
...                                b=ivy.array([0., 0., 0.]))
>>> lr = 3e-4
>>> ws_new = ivy.optimizer_update(w, effective_grad, lr, out=w)
>>> print(w)
{
    a: ivy.array([0., 1., 2.]),
    b: ivy.array([3., 4., 5.])
}
>>> w = ivy.Container(a=ivy.array([0., 1., 2.]),
...                   b=ivy.array([3., 4., 5.]))
>>> effective_grad = ivy.Container(a=ivy.array([0., 0., 0.]),
...                                b=ivy.array([0., 0., 0.]))
>>> lr = ivy.array([3e-4])
>>> ws_new = ivy.optimizer_update(w, effective_grad, lr,
...                               stop_gradients=False)
>>> print(ws_new)
{
    a: ivy.array([0., 1., 2.]),
    b: ivy.array([3., 4., 5.])
}
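
In the third ivy.Array example, lr has two entries and w is 2x2, so each column gets its own learning rate through broadcasting. A plain-Python sketch of that arithmetic, assuming the rule w - lr * effective_grad (the names are illustrative):

```python
w = [[1.0, 2.0], [4.0, 5.0]]
effective_grad = [[4.0, 5.0], [7.0, 8.0]]
lr = [3e-4, 1e-2]  # one learning rate per column

# Broadcasting lr along the last axis: column j uses lr[j].
ws_new = [[wij - lr[j] * gij for j, (wij, gij) in enumerate(zip(w_row, g_row))]
          for w_row, g_row in zip(w, effective_grad)]

print([[round(v, 4) for v in row] for row in ws_new])
```

To three significant digits, the result matches the array printed in that example.
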

ivy.stop_gradient(x, /, *, preserve_type=True, out=None)

Stop gradient computation.

Parameters:
  • x (Union[Array, NativeArray]) – Array for which to stop the gradient.

  • preserve_type (bool, default: True) – Whether to preserve gradient computation on ivy.Array instances. Default is True.

  • out (Optional[Array], default: None) – optional output array, for writing the result to. It must have a shape that the inputs broadcast to.

Return type:

Array

Returns:

ret – The same array x, but with no gradient information.

Both the description and the type hints above assume an array input for simplicity, but this function is nestable, and therefore also accepts ivy.Container instances in place of any of the arguments.

Examples

With ivy.Array inputs:

>>> x = ivy.array([1., 2., 3.])
>>> y = ivy.stop_gradient(x, preserve_type=True)
>>> print(y)
ivy.array([1., 2., 3.])
>>> x = ivy.zeros((2, 3))
>>> ivy.stop_gradient(x, preserve_type=False, out=x)
>>> print(x)
ivy.array([[0., 0., 0.],
           [0., 0., 0.]])

With one ivy.Container input:

>>> x = ivy.Container(a=ivy.array([0., 1., 2.]),
...                   b=ivy.array([3., 4., 5.]))
>>> y = ivy.stop_gradient(x, preserve_type=False)
>>> print(y)
{
    a: ivy.array([0., 1., 2.]),
    b: ivy.array([3., 4., 5.])
}

With multiple ivy.Container inputs:

>>> x = ivy.Container(a=ivy.array([0., 1., 2.]),
...                   b=ivy.array([3., 4., 5.]))
>>> ivy.stop_gradient(x, preserve_type=True, out=x)
>>> print(x)
{
    a: ivy.array([0., 1., 2.]),
    b: ivy.array([3., 4., 5.])
}

ivy.value_and_grad(func)

Create a function that evaluates both func and the gradient of func.

Parameters:

func (Callable) – Function for which we compute the gradients of the output with respect to xs input.

Return type:

Callable

Returns:

ret – A function that returns both the value of func and the gradient of func.

Examples

With ivy.Array input:

>>> x = ivy.array([[4.6, 2.1, 5], [2.8, 1.3, 6.2]])
>>> func = lambda x: ivy.mean(ivy.square(x))
>>> grad_fn = ivy.value_and_grad(func)
>>> value_grad = grad_fn(x)
>>> print(value_grad)
(ivy.array(16.42333412), ivy.array([[1.5333333 , 0.69999999, 1.66666675],
       [0.93333334, 0.43333334, 2.0666666 ]]))

This should have hopefully given you an overview of the gradients submodule. If you have any questions, please feel free to reach out on our Discord in the gradients channel!