neuralmonkey.trainers.rl_trainer module
Training objectives for reinforcement learning.
neuralmonkey.trainers.rl_trainer.rl_objective(decoder: neuralmonkey.decoders.decoder.Decoder, reward_function: Callable[[numpy.ndarray, numpy.ndarray], numpy.ndarray], subtract_baseline: bool = False, normalize: bool = False, temperature: float = 1.0, ce_smoothing: float = 0.0, alpha: float = 1.0, sample_size: int = 1) → neuralmonkey.trainers.generic_trainer.Objective

Construct an RL objective for training with sentence-level feedback.
Depending on the options, the objective corresponds to one of the following:
1) sample_size = 1, normalize = False, ce_smoothing = 0.0
   The bandit objective (Eq. 2) described in ‘Bandit Structured Prediction for Neural Sequence-to-Sequence Learning’ (http://www.aclweb.org/anthology/P17-1138). Setting subtract_baseline = True is recommended.
2) sample_size > 1, normalize = True, ce_smoothing = 0.0
   Minimum Risk Training as described in ‘Minimum Risk Training for Neural Machine Translation’ (http://www.aclweb.org/anthology/P16-1159) (Eq. 12).
3) sample_size > 1, normalize = False, ce_smoothing = 0.0
   The Google ‘Reinforce’ objective as proposed in ‘Google’s Neural Machine Translation System: Bridging the Gap between Human and Machine Translation’ (https://arxiv.org/pdf/1609.08144.pdf) (Eq. 8).
4) sample_size > 1, normalize = False, ce_smoothing > 0.0
   Google’s ‘Mixed’ objective from the same paper (Eq. 9), where ce_smoothing plays the role of the paper’s α, the weight of the maximum-likelihood term. It is unrelated to the alpha parameter of this function.

Note that alpha controls the sharpness of the normalized distribution, while temperature controls the sharpness of the distribution used for sampling.
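To make the differences between the four configurations concrete, here is a minimal NumPy sketch of the per-sentence loss value. It assumes the sequence log-probabilities of the samples, their sentence-level rewards and the log-probability of the reference have already been computed. The function `rl_loss` and its arguments are hypothetical; this is not Neural Monkey's TensorFlow implementation, and in particular it assumes the cross-entropy term is simply added to the RL loss with weight ce_smoothing.

```python
import numpy as np


def rl_loss(sample_logprobs, rewards, reference_logprob=0.0,
            normalize=False, alpha=1.0, ce_smoothing=0.0, baseline=0.0):
    """Hypothetical single-sentence loss for the configurations above.

    sample_logprobs: log p(y_s | x) of each sample, shape (sample_size,)
    rewards: reward r(y_s, y*) of each sample, shape (sample_size,)
    reference_logprob: log p(y* | x), used only by the cross-entropy term
    baseline: value subtracted from every reward (0.0 means no baseline)
    """
    sample_logprobs = np.asarray(sample_logprobs, dtype=np.float64)
    advantages = np.asarray(rewards, dtype=np.float64) - baseline
    if normalize:
        # Minimum Risk Training: Q(y) = p(y)^alpha / sum_y' p(y')^alpha over
        # the sample set, computed in log space for numerical stability.
        scaled = alpha * sample_logprobs
        q = np.exp(scaled - scaled.max())
        q /= q.sum()
        # Negative expected reward under Q (risk with cost = -reward).
        rl = -np.sum(q * advantages)
    else:
        # Bandit / REINFORCE surrogate: with rewards treated as constants,
        # its gradient is the score-function estimate -mean(r * grad log p).
        rl = -np.mean(advantages * sample_logprobs)
    # Mixed objective: add the weighted cross-entropy of the reference.
    return float(rl + ce_smoothing * (-reference_logprob))


# 1) Bandit: one sample with p(y) = 0.5 and reward 0.8 -> about 0.555
bandit = rl_loss(np.log([0.5]), [0.8])
# 2) MRT: p = 0.2 and 0.6 give Q = [0.25, 0.75] for alpha = 1 -> about -0.25
mrt = rl_loss(np.log([0.2, 0.6]), [1.0, 0.0], normalize=True)
# 4) Mixed: REINFORCE over two samples plus 0.5 * cross-entropy
mixed = rl_loss(np.log([0.2, 0.6]), [1.0, 0.0],
                reference_logprob=np.log(0.5), ce_smoothing=0.5)
```

Configurations 1 and 3 share the same formula in this sketch and differ only in how many samples are averaged, which mirrors how the list above distinguishes them by sample_size alone.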
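Parameters:
- decoder – a recurrent decoder to sample from
- reward_function – the sentence-level reward, e.g. any evaluator object (a callable that scores hypotheses against references)
- subtract_baseline – if True, the average reward is subtracted from the obtained reward
- normalize – if True, the probabilities of the samples are re-normalized over the set of samples
- sample_size – the number of samples drawn per sentence to obtain feedback for
- ce_smoothing – if greater than 0, the cross-entropy loss multiplied by this coefficient is added to the loss
- alpha – determines the shape (sharpness) of the normalized distribution; only relevant with normalize = True
- temperature – the softmax temperature used for sampling
Returns: an Objective object to be used in the generic trainer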
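The temperature and subtract_baseline parameters can be illustrated in the same hedged way. The sketch below assumes tokens are drawn from softmax(logits / temperature) and that the baseline is the running average of all rewards observed so far, including the current batch. Both helpers, `sample_with_temperature` and `AverageRewardBaseline`, are hypothetical names and not part of Neural Monkey.

```python
import numpy as np


def sample_with_temperature(logits, temperature=1.0, rng=None):
    """Hypothetical helper: draw one token id from softmax(logits / T).

    Returns the sampled id together with the sampling distribution.
    """
    rng = np.random.default_rng() if rng is None else rng
    scaled = np.asarray(logits, dtype=np.float64) / temperature
    probs = np.exp(scaled - scaled.max())  # stable softmax
    probs /= probs.sum()
    return int(rng.choice(len(probs), p=probs)), probs


class AverageRewardBaseline:
    """Hypothetical running-average baseline over every reward seen so far."""

    def __init__(self):
        self.total = 0.0
        self.count = 0

    def subtract(self, rewards):
        rewards = np.asarray(rewards, dtype=np.float64)
        # Update the statistics first, so the current batch is included.
        self.total += rewards.sum()
        self.count += rewards.size
        return rewards - self.total / max(self.count, 1)


# Lower temperature -> sharper sampling distribution.
_, cold = sample_with_temperature([2.0, 0.0], temperature=0.5)
_, warm = sample_with_temperature([2.0, 0.0], temperature=2.0)

baseline = AverageRewardBaseline()
centered = baseline.subtract([1.0, 0.0])  # average 0.5 -> [0.5, -0.5]
```

Temperature only changes how samples are drawn, whereas alpha (with normalize = True) reshapes the distribution over samples that have already been drawn, which is the distinction the note on the two parameters makes.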