fomaml_step

ivy.fomaml_step(batch, inner_cost_fn, outer_cost_fn, variables, inner_grad_steps, inner_learning_rate, /, *, inner_optimization_step=<function gradient_descent_update>, inner_batch_fn=None, outer_batch_fn=None, average_across_steps=False, batched=True, inner_v=None, keep_inner_v=True, outer_v=None, keep_outer_v=True, return_inner_v=False, num_tasks=None, stop_gradients=True)

Perform one meta step of first-order MAML (FOMAML).
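For orientation, the textbook first-order MAML update is written out below. This is the standard formulation, not a transcription of Ivy's source, and the notation is assumed: α stands for inner_learning_rate, K for inner_grad_steps, N for num_tasks, and L_i^in and L_i^out for inner_cost_fn and outer_cost_fn evaluated on the sub-batch of task i. The inner line shows plain gradient descent (the default inner_optimization_step), and averaging the task gradients is an assumption.

```latex
\begin{aligned}
\theta_i^{(0)} &= \theta, \\
\theta_i^{(k+1)} &= \theta_i^{(k)} - \alpha \, \nabla \mathcal{L}^{\mathrm{in}}_i\big(\theta_i^{(k)}\big), \qquad k = 0, \dots, K-1, \\
g &= \frac{1}{N} \sum_{i=1}^{N} \nabla \mathcal{L}^{\mathrm{out}}_i\big(\theta_i^{(K)}\big).
\end{aligned}
```

The first-order approximation treats the Jacobian of θ_i^(K) with respect to θ as the identity, so no second derivatives are computed. Applying g to the outer variables (for example θ ← θ − βg) is left to the caller, since the function returns gradients rather than updated variables.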

Parameters
  • batch (Container) – The input batch

  • inner_cost_fn (Callable) – callable for the inner loop cost function, receiving the task-specific sub-batch, inner vars and outer vars.

  • outer_cost_fn (Callable) – callable for the outer loop cost function, receiving the task-specific sub-batch, inner vars and outer vars. If None, the cost from the inner loop will also be optimized in the outer loop.

  • variables (Container) – Variables to be optimized during the meta step

  • inner_grad_steps (int) – Number of gradient steps to perform during the inner loop.

  • inner_learning_rate (float) – The learning rate of the inner loop.

  • inner_optimization_step (Callable) – The function used for the inner loop optimization. Default is ivy.gradient_descent_update.

  • inner_batch_fn (Optional[Callable]) – Function to apply to the task sub-batch before it is passed to inner_cost_fn. Default is None.

  • outer_batch_fn (Optional[Callable]) – Function to apply to the task sub-batch before it is passed to outer_cost_fn. Default is None.

  • average_across_steps (bool) – Whether to average the inner loop steps for the outer loop update. Default is False.

  • batched (bool) – Whether to batch along the time dimension, and run the meta steps in batch. Default is True.

  • inner_v (Optional[Container]) – Nested variable keys to be optimized during the inner loop, with the same keys as variables and boolean values. Default is None.

  • keep_inner_v (bool) – If True, the key chains in inner_v will be kept, otherwise they will be removed. Default is True.

  • outer_v (Optional[Container]) – Nested variable keys to be optimized during the outer loop, with the same keys as variables and boolean values. Default is None.

  • keep_outer_v (bool) – If True, the key chains in outer_v will be kept, otherwise they will be removed. Default is True.

  • return_inner_v (Union[str, bool]) – Either ‘first’, ‘all’, or False. With ‘first’, the inner loop variables for the first task are also returned; with ‘all’, the inner loop variables for all tasks are returned. Default is False.

  • num_tasks (Optional[int]) – Number of unique tasks to inner-loop optimize for the meta step. Determined from the batch by default.

  • stop_gradients (bool) – Whether to stop the gradients of the cost. Default is True.
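The inner_v/keep_inner_v pair (and likewise outer_v/keep_outer_v) selects variables by key chain. The sketch below mimics that selection on plain nested dicts so it runs without Ivy. It is an analogy, not Ivy's code: flatten and filter_key_chains are hypothetical helpers, the '/' separator is modelled on Ivy's key-chain notation, and in Ivy itself the inputs are ivy.Container objects handled internally.

```python
from typing import Any, Dict

Nested = Dict[str, Any]  # nested dict standing in for an ivy.Container (assumption)


def flatten(tree: Nested, prefix: str = "") -> Dict[str, Any]:
    """Map every leaf of a nested dict to its '/'-joined key chain."""
    flat: Dict[str, Any] = {}
    for key, value in tree.items():
        chain = f"{prefix}/{key}" if prefix else key
        if isinstance(value, dict):
            flat.update(flatten(value, chain))
        else:
            flat[chain] = value
    return flat


def filter_key_chains(variables: Nested, mask: Nested, keep: bool = True) -> Dict[str, Any]:
    """Hypothetical helper: keep (keep=True) or remove (keep=False) the
    leaves whose key chain is marked True in mask."""
    marked = {chain for chain, flag in flatten(mask).items() if flag}
    return {
        chain: value
        for chain, value in flatten(variables).items()
        if (chain in marked) == keep
    }


variables = {"encoder": {"w": 1.0, "b": 0.0}, "head": {"w": 2.0}}
inner_v = {"head": {"w": True}}  # mark only the head weight
print(filter_key_chains(variables, inner_v, keep=True))   # {'head/w': 2.0}
print(filter_key_chains(variables, inner_v, keep=False))  # {'encoder/w': 1.0, 'encoder/b': 0.0}
```

With keep=True only the marked variable remains, which corresponds to optimizing just those key chains; keep=False drops the marked chains and leaves everything else.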

Return type

Tuple[Array, Container, Any]

Returns

ret – The cost and the gradients with respect to the outer loop variables, followed by the inner loop variables when return_inner_v is ‘first’ or ‘all’.
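To make the returned tuple concrete, here is a self-contained pure-Python sketch of one first-order MAML step on toy one-dimensional regression tasks. It is not Ivy's implementation: the linear model, the mean-squared-error cost, the plain dicts standing in for ivy.Container, and the name fomaml_step_sketch are assumptions for illustration. The same cost is used in both loops (as when outer_cost_fn is None), the inner loop is plain gradient descent, costs and gradients are assumed to be averaged over tasks, and options such as batched, average_across_steps and stop_gradients are not modelled.

```python
from typing import Dict, List, Tuple, Union

Params = Dict[str, float]
Task = List[Tuple[float, float]]  # (x, y) pairs for one task


def mse_and_grads(params: Params, data: Task) -> Tuple[float, Params]:
    """Mean squared error of y_hat = w*x + b and its analytic gradients."""
    n = len(data)
    cost, dw, db = 0.0, 0.0, 0.0
    for x, y in data:
        err = params["w"] * x + params["b"] - y
        cost += err * err / n
        dw += 2.0 * err * x / n
        db += 2.0 * err / n
    return cost, {"w": dw, "b": db}


def fomaml_step_sketch(
    tasks: List[Task],
    params: Params,
    inner_grad_steps: int,
    inner_learning_rate: float,
    return_inner_v: Union[str, bool] = False,
):
    """One first-order MAML step: (cost, grads) or (cost, grads, inner_vars)."""
    num_tasks = len(tasks)
    total_cost = 0.0
    total_grads = {k: 0.0 for k in params}
    inner_vars = []
    for data in tasks:
        # Inner loop: plain gradient descent on a copy of the outer variables.
        adapted = dict(params)
        for _ in range(inner_grad_steps):
            _, grads = mse_and_grads(adapted, data)
            adapted = {k: adapted[k] - inner_learning_rate * grads[k] for k in adapted}
        # First-order approximation: the gradient at the adapted variables is
        # used directly as the gradient w.r.t. the outer variables.
        cost, grads = mse_and_grads(adapted, data)
        total_cost += cost
        for k in total_grads:
            total_grads[k] += grads[k]
        inner_vars.append(adapted)
    cost = total_cost / num_tasks
    grads = {k: g / num_tasks for k, g in total_grads.items()}
    if return_inner_v == "first":
        return cost, grads, inner_vars[0]
    if return_inner_v == "all":
        return cost, grads, inner_vars
    return cost, grads


tasks = [[(1.0, 2.0)], [(1.0, 4.0)]]
params = {"w": 0.0, "b": 0.0}
cost, grads = fomaml_step_sketch(tasks, params, inner_grad_steps=1, inner_learning_rate=0.1)
print(cost, grads)
# Outer update applied by the caller with a hypothetical outer learning rate of 0.1.
params = {k: params[k] - 0.1 * grads[k] for k in params}
```

The caller, not the step function, applies the returned gradients to the outer variables, as in the last line; passing return_inner_v="first" or "all" appends the adapted variables as a third element.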

Device Support

Devices: CPU, GPU

Supported Frameworks: JAX, NumPy, TensorFlow, PyTorch