neuralmonkey.trainers.rl_trainer module

Training objectives for reinforcement learning.

neuralmonkey.trainers.rl_trainer.rl_objective(decoder: neuralmonkey.decoders.decoder.Decoder, reward_function: Callable[[numpy.ndarray, numpy.ndarray], numpy.ndarray], subtract_baseline: bool = False, normalize: bool = False, temperature: float = 1.0, ce_smoothing: float = 0.0, alpha: float = 1.0, sample_size: int = 1) → neuralmonkey.trainers.generic_trainer.Objective

Construct RL objective for training with sentence-level feedback.

Depending on the options, the objective corresponds to:

  1. sample_size = 1, normalize = False, ce_smoothing = 0.0
     Bandit objective (Eq. 2) described in ‘Bandit Structured Prediction for Neural Sequence-to-Sequence Learning’ (http://www.aclweb.org/anthology/P17-1138). It’s recommended to set subtract_baseline = True.
  2. sample_size > 1, normalize = True, ce_smoothing = 0.0
     Minimum Risk Training as described in ‘Minimum Risk Training for Neural Machine Translation’ (http://www.aclweb.org/anthology/P16-1159) (Eq. 12).
  3. sample_size > 1, normalize = False, ce_smoothing = 0.0
     The Google ‘Reinforce’ objective as proposed in ‘Google’s NMT System: Bridging the Gap between Human and Machine Translation’ (https://arxiv.org/pdf/1609.08144.pdf) (Eq. 8).
  4. sample_size > 1, normalize = False, ce_smoothing > 0.0
     Google’s ‘Mixed’ objective in the above paper (Eq. 9), where ce_smoothing implements alpha.

Note that ‘alpha’ controls the sharpness of the normalized distribution, while ‘temperature’ controls the sharpness during sampling.
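As a minimal sketch of the four settings above (the decoder instance, the reward function my_reward, and the concrete hyperparameter values are illustrative placeholders, not prescribed by this module), the variants differ only in the keyword arguments passed to rl_objective:

    from neuralmonkey.trainers.rl_trainer import rl_objective

    # 1) Bandit objective (Eq. 2): one sample per sentence,
    #    average-reward baseline subtracted.
    bandit = rl_objective(decoder, reward_function=my_reward,
                          subtract_baseline=True,
                          sample_size=1, normalize=False, ce_smoothing=0.0)

    # 2) Minimum Risk Training (Eq. 12): several samples, sample
    #    probabilities re-normalized; alpha sharpens the distribution.
    mrt = rl_objective(decoder, reward_function=my_reward,
                       sample_size=5, normalize=True, alpha=0.005)

    # 3) Google 'Reinforce' objective (Eq. 8): several samples,
    #    no re-normalization.
    reinforce = rl_objective(decoder, reward_function=my_reward,
                             sample_size=5, normalize=False)

    # 4) Google 'Mixed' objective (Eq. 9): cross-entropy loss added
    #    to the objective with coefficient ce_smoothing.
    mixed = rl_objective(decoder, reward_function=my_reward,
                         sample_size=5, normalize=False, ce_smoothing=0.3)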

Parameters:
  • decoder – a recurrent decoder to sample from
  • reward_function – any evaluator object, i.e. a callable mapping references and sampled outputs to per-sentence rewards (see the sketch after the Returns section)
  • subtract_baseline – whether the average reward is subtracted from the obtained reward as a baseline
  • normalize – the probabilities of the samples are re-normalized
  • sample_size – number of samples to obtain feedback for
  • ce_smoothing – add the cross-entropy loss, weighted by this coefficient, to the loss
  • alpha – determines the shape of the normalized distribution
  • temperature – the softmax temperature for sampling
Returns:

Objective object to be used in the generic trainer.
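
A hedged sketch of a reward_function with the expected signature (the assumption here, not stated on this page, is that the two arrays hold reference and sampled token indices of the same shape, one sentence per row; a real setup would typically decode the indices and score them with a sentence-level metric such as BLEU or GLEU):

    import numpy as np

    def my_reward(references: np.ndarray, hypotheses: np.ndarray) -> np.ndarray:
        """Toy per-sentence reward: fraction of positions matching the reference.

        Assumes both arrays have the same shape (batch, time); returns one
        reward per sentence as a float array of shape (batch,).
        """
        matches = (references == hypotheses).astype(np.float32)
        return matches.mean(axis=1)

The resulting Objective is then passed to the generic trainer (neuralmonkey.trainers.generic_trainer) together with any other objectives.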