neuralmonkey.trainers package


neuralmonkey.trainers.cross_entropy_trainer module

class neuralmonkey.trainers.cross_entropy_trainer.CrossEntropyTrainer(decoders: typing.List[typing.Any], decoder_weights: typing.Union[typing.List[typing.Union[tensorflow.python.framework.ops.Tensor, float, NoneType]], NoneType] = None, l1_weight=0.0, l2_weight=0.0, clip_norm=False, optimizer=None, global_step=None) → None

Bases: neuralmonkey.trainers.generic_trainer.GenericTrainer

neuralmonkey.trainers.cross_entropy_trainer.xent_objective(decoder, weight=None) → neuralmonkey.trainers.generic_trainer.Objective

Get the cross-entropy (XENT) objective from a decoder that defines a cost.
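
The following is a minimal sketch of what such an objective could look like. It assumes, hypothetically, that the decoder exposes `name` and `cost` attributes. `Objective` here is a local namedtuple that mirrors the documented field order, and `xent_objective_sketch` and `FakeDecoder` are illustrative names, not part of the library:

```python
from collections import namedtuple
from typing import Any, Optional

# Local stand-in with the documented field order of
# neuralmonkey.trainers.generic_trainer.Objective.
Objective = namedtuple("Objective", ["name", "decoder", "loss", "gradients", "weight"])


def xent_objective_sketch(decoder: Any, weight: Optional[float] = None) -> Objective:
    """Wrap a decoder's own cost as a cross-entropy objective (illustrative)."""
    return Objective(
        name="{} - cross-entropy".format(decoder.name),  # assumed naming scheme
        decoder=decoder,
        loss=decoder.cost,  # assumed: the decoder already computed its XENT cost
        gradients=None,     # assumed: None lets the trainer differentiate `loss`
        weight=weight,      # mixing weight; stays None if not given
    )


class FakeDecoder:
    """Hypothetical decoder stand-in with only the attributes the sketch reads."""
    name = "decoder"
    cost = 2.5


objective = xent_objective_sketch(FakeDecoder(), weight=0.5)
print(objective.name, objective.loss, objective.weight)  # decoder - cross-entropy 2.5 0.5
```

In this sketch the objective only packages an existing loss. Any differentiation or weighting is left to whatever trainer consumes it.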

neuralmonkey.trainers.generic_trainer module

class neuralmonkey.trainers.generic_trainer.GenericTrainer(objectives: typing.List[neuralmonkey.trainers.generic_trainer.Objective], l1_weight: float = 0.0, l2_weight: float = 0.0, clip_norm: typing.Union[float, NoneType] = None, optimizer=None, global_step=None) → None

Bases: object

get_executable(compute_losses=True, summaries=True) → neuralmonkey.runners.base_runner.Executable

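The GenericTrainer parameters suggest how a total loss could be formed. The sketch below assumes three things:

* a weighted sum of objective losses, where a missing weight counts as 1.0;
* `l1_weight` times the sum of absolute parameter values, plus `l2_weight` times the sum of squared values;
* clipping of a gradient vector to `clip_norm` in the same way as TensorFlow's `tf.clip_by_norm`.

Plain Python floats stand in for tensors. `combine_losses` and `clip_by_norm_sketch` are hypothetical helpers. This page does not state whether GenericTrainer clips each tensor separately or by a global norm.

```python
import math
from typing import List, Optional, Sequence


def combine_losses(losses: Sequence[float],
                   weights: Sequence[Optional[float]],
                   params: Sequence[float],
                   l1_weight: float = 0.0,
                   l2_weight: float = 0.0) -> float:
    """Weighted sum of objective losses plus L1 and L2 penalties (assumed form)."""
    # A missing (None) weight is treated as 1.0 in this sketch.
    total = sum((1.0 if w is None else w) * loss
                for loss, w in zip(losses, weights))
    total += l1_weight * sum(abs(p) for p in params)  # L1: sum of |p|
    total += l2_weight * sum(p * p for p in params)   # L2: sum of p^2
    return total


def clip_by_norm_sketch(grad: List[float], clip_norm: Optional[float]) -> List[float]:
    """Rescale `grad` so its L2 norm is at most `clip_norm`, like tf.clip_by_norm.

    None, False or 0 disables clipping (CrossEntropyTrainer defaults to False).
    """
    norm = math.sqrt(sum(g * g for g in grad))
    if not clip_norm or norm <= clip_norm:
        return list(grad)
    return [g * clip_norm / norm for g in grad]


print(combine_losses([1.0, 2.0], [None, 0.5], [1.0, -2.0], l1_weight=0.1, l2_weight=0.01))
print(clip_by_norm_sketch([3.0, 4.0], 1.0))  # [0.6, 0.8]
```

Clipping rescales the whole vector instead of truncating single components, so the gradient direction is preserved.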
class neuralmonkey.trainers.generic_trainer.Objective(name, decoder, loss, gradients, weight)

Bases: tuple

decoder

Alias for field number 1

gradients

Alias for field number 3

loss

Alias for field number 2

name

Alias for field number 0

weight

Alias for field number 4
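
Objective is a namedtuple, so each field can be read either by name or by its position. The sketch below recreates it with `collections.namedtuple`, using the documented signature. The field values are made up:

```python
from collections import namedtuple

# Field order as in the documented signature Objective(name, decoder, loss, gradients, weight).
Objective = namedtuple("Objective", ["name", "decoder", "loss", "gradients", "weight"])

obj = Objective(name="xent", decoder=None, loss=0.7, gradients=None, weight=1.0)

# Named access and positional access reach the same slot.
print(obj.loss == obj[2], obj.name == obj[0])  # True True

# Namedtuples are immutable; _replace returns a modified copy.
reweighted = obj._replace(weight=0.5)
print(obj.weight, reweighted.weight)  # 1.0 0.5
```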

class neuralmonkey.trainers.generic_trainer.TrainExecutable(all_coders, train_op, losses, scalar_summaries, histogram_summaries)

Bases: neuralmonkey.runners.base_runner.Executable

collect_results(results: typing.List[typing.Dict]) → None
next_to_execute() → typing.Tuple[typing.List[typing.Any], typing.Union[typing.Dict, typing.List], typing.Dict[tensorflow.python.framework.ops.Tensor, typing.Union[int, float, numpy.ndarray]]]
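
The two methods suggest a simple protocol:

1. A runner asks the executable what to run.
2. The runner runs it.
3. The runner hands the results back.

The toy below illustrates only that loop. All of the following are assumptions, not the library's runner: `ToyExecutable`, `fake_session_run`, `drive`, the `result` attribute used as a "done" marker, and the string keys that stand in for tensors.

```python
from typing import Any, Dict, List, Optional, Tuple


class ToyExecutable:
    """Hypothetical executable with the two documented methods."""

    def __init__(self) -> None:
        self.result: Optional[Dict[str, Any]] = None  # assumed "done" marker

    def next_to_execute(self) -> Tuple[List[Any], Dict[str, str], Dict[str, float]]:
        # (coders involved, what to fetch, feed dictionary); strings stand in for tensors.
        return [], {"loss": "loss_tensor"}, {"learning_rate": 0.1}

    def collect_results(self, results: List[Dict]) -> None:
        # Assumed: one result dict per session run; keep the first one.
        self.result = results[0]


def fake_session_run(fetches: Dict[str, str], feed_dict: Dict[str, float]) -> Dict[str, Any]:
    """Hypothetical stand-in for running the fetches in a TensorFlow session."""
    return {key: 0.42 for key in fetches}


def drive(executable: ToyExecutable) -> Dict[str, Any]:
    """Assumed runner loop: execute until the executable reports a result."""
    while executable.result is None:
        _coders, fetches, feed_dict = executable.next_to_execute()
        executable.collect_results([fake_session_run(fetches, feed_dict)])
    return executable.result


print(drive(ToyExecutable()))  # {'loss': 0.42}
```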

neuralmonkey.trainers.self_critical_objective module

Training objective for self-critical learning.

Self-critical learning is a modification of the REINFORCE algorithm that uses the reward of the train-time decoder output as a baseline in the update step.
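
The baseline turns the raw reward into an advantage: an output is reinforced only if it scores better than the baseline output for the same reference. The sketch below uses a hypothetical toy reward, `token_accuracy`. `sampled` holds the outputs whose probabilities would be updated, and `baseline_outputs` plays the role of the train-time decoder outputs:

```python
from typing import Callable, List, Sequence

Reward = Callable[[Sequence[int], Sequence[int]], float]


def token_accuracy(reference: Sequence[int], hypothesis: Sequence[int]) -> float:
    """Hypothetical toy reward: share of reference positions matched exactly."""
    if not reference:
        return 0.0
    hits = sum(1 for r, h in zip(reference, hypothesis) if r == h)
    return hits / len(reference)


def self_critical_advantages(references: Sequence[Sequence[int]],
                             sampled: Sequence[Sequence[int]],
                             baseline_outputs: Sequence[Sequence[int]],
                             reward: Reward) -> List[float]:
    """Reward of each sampled output minus the reward of its baseline output."""
    return [reward(ref, s) - reward(ref, b)
            for ref, s, b in zip(references, sampled, baseline_outputs)]


refs = [[1, 2, 3, 4], [1, 2, 3, 4]]
sampled = [[1, 2, 3, 0], [5, 6, 7, 8]]
baseline = [[1, 0, 0, 0], [1, 2, 0, 0]]
print(self_critical_advantages(refs, sampled, baseline, token_accuracy))  # [0.5, -0.5]
```

A positive advantage pushes the probability of the sampled output up, and a negative one pushes it down. Without the baseline, every output with a positive reward would be pushed up.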

For more details see:

neuralmonkey.trainers.self_critical_objective.reinforce_score(reward: tensorflow.python.framework.ops.Tensor, baseline: tensorflow.python.framework.ops.Tensor, decoded: tensorflow.python.framework.ops.Tensor, logits: tensorflow.python.framework.ops.Tensor) → tensorflow.python.framework.ops.Tensor

Cost function whose derivative is the REINFORCE equation.

This implements the primitive function (antiderivative) of the central equation of the REINFORCE algorithm, which estimates the gradients of the loss with respect to the decoder logits.
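
Consider a single time step with logits s, word distribution p = softmax(s), decoded word w, reward r and baseline b. The relationship can be written as follows, assuming the sign convention of a cost to be minimized:

```latex
\mathcal{L}(s) = -(r - b)\,\log p_w ,
\qquad
\frac{\partial \log p_w}{\partial s_k} = \delta_{kw} - p_k
\quad\Longrightarrow\quad
\nabla_s \mathcal{L} = (r - b)\,\bigl(p - \mathbf{1}_w\bigr)
```

Here \mathbf{1}_w is the one-hot vector of w. For a whole sentence, the cost sums -(r - b) log p_{w_t} over the time steps t, with the same constant r - b at every step.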

It uses the fact that the second factor of the product (the difference between the word distribution and the one-hot vector of the decoded word) is the derivative of the negative log-likelihood of the decoded word with respect to the logits. The reward and the baseline, however, are treated as constants, so they influence the derivative only multiplicatively.
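
The claim can be checked numerically for one time step. The sketch below is plain Python, not the library's TensorFlow code. It takes the cost -(reward - baseline) * log p[decoded], which is our assumed sign convention, and holds reward and baseline fixed. It then compares central finite differences against (reward - baseline) * (p - one_hot(decoded)). All function names are hypothetical:

```python
import math
from typing import List


def softmax(logits: List[float]) -> List[float]:
    """Numerically stable softmax."""
    top = max(logits)
    exps = [math.exp(x - top) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def reinforce_cost(logits: List[float], decoded: int, reward: float, baseline: float) -> float:
    """-(reward - baseline) * log p[decoded]; reward and baseline are constants."""
    return -(reward - baseline) * math.log(softmax(logits)[decoded])


def analytic_grad(logits: List[float], decoded: int, reward: float, baseline: float) -> List[float]:
    """(reward - baseline) * (softmax(logits) - one_hot(decoded))."""
    probs = softmax(logits)
    return [(reward - baseline) * (p - (1.0 if k == decoded else 0.0))
            for k, p in enumerate(probs)]


def numeric_grad(logits: List[float], decoded: int, reward: float, baseline: float,
                 eps: float = 1e-6) -> List[float]:
    """Central finite differences of reinforce_cost, one logit at a time."""
    grads = []
    for k in range(len(logits)):
        up = list(logits)
        up[k] += eps
        down = list(logits)
        down[k] -= eps
        diff = (reinforce_cost(up, decoded, reward, baseline)
                - reinforce_cost(down, decoded, reward, baseline))
        grads.append(diff / (2 * eps))
    return grads


logits = [0.5, -1.0, 2.0]
print(analytic_grad(logits, 1, reward=0.8, baseline=0.3))
print(numeric_grad(logits, 1, reward=0.8, baseline=0.3))
```

The two printed lists should agree closely. Setting the reward equal to the baseline makes every component vanish, which is the multiplicative influence described above.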

neuralmonkey.trainers.self_critical_objective.self_critical_objective(decoder: neuralmonkey.decoders.decoder.Decoder, reward_function: typing.Callable[[numpy.ndarray, numpy.ndarray], numpy.ndarray], weight: float = None) → neuralmonkey.trainers.generic_trainer.Objective

Self-critical objective.

Parameters:
  • decoder – A recurrent decoder.
  • reward_function – A reward function that computes the score in Python.
  • weight – Mixing weight for a trainer.

Returns:
An Objective object to be used in a generic trainer.
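
A reward function of the documented shape maps a batch of references and a batch of hypotheses to one score per sentence. The hypothetical exact-match reward below only illustrates that contract. Python lists stand in for `numpy.ndarray`, and one row per sentence is assumed. In practice, `sentence_bleu` and `sentence_gleu` from this module have the documented signature.

```python
from typing import List, Sequence


def exact_match_reward(references: Sequence[Sequence[int]],
                       hypotheses: Sequence[Sequence[int]]) -> List[float]:
    """Hypothetical reward: 1.0 when a hypothesis equals its reference, else 0.0.

    One row per sentence is assumed; a real reward would also need to deal with
    padding and end-of-sequence indices, which this sketch ignores.
    """
    return [1.0 if list(ref) == list(hyp) else 0.0
            for ref, hyp in zip(references, hypotheses)]


print(exact_match_reward([[4, 5, 6], [4, 5, 6]], [[4, 5, 6], [4, 5, 7]]))  # [1.0, 0.0]
```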

neuralmonkey.trainers.self_critical_objective.sentence_bleu(references: numpy.ndarray, hypotheses: numpy.ndarray) → numpy.ndarray

Compute index-based sentence-level BLEU score.

Computes sentence-level BLEU on the indices output by the decoder. Whatever unit the decoder uses is treated as a token in the BLEU computation, even if those units are sub-word units.
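
The following is a minimal pure-Python sketch of sentence-level BLEU over index sequences, assuming one reference per hypothesis. It computes clipped n-gram precisions for n = 1..4, takes their geometric mean, and applies the usual brevity penalty. Sentence-level BLEU needs smoothing to avoid zero scores, so the sketch adds one to the n-gram counts for n > 1. That smoothing is our choice, not necessarily the module's, and so is the absence of any padding or end-token handling. The function names are hypothetical.

```python
import math
from collections import Counter
from typing import List, Sequence


def ngram_counts(seq: Sequence[int], n: int) -> Counter:
    """Counts of all n-grams (as tuples) in an index sequence."""
    return Counter(tuple(seq[i:i + n]) for i in range(len(seq) - n + 1))


def sentence_bleu_sketch(reference: Sequence[int], hypothesis: Sequence[int],
                         max_n: int = 4) -> float:
    """Sentence-level BLEU on indices with add-one smoothing for n > 1 (our choice)."""
    if not hypothesis:
        return 0.0
    log_precisions = 0.0
    for n in range(1, max_n + 1):
        hyp_counts = ngram_counts(hypothesis, n)
        ref_counts = ngram_counts(reference, n)
        # Clipped matches: an n-gram counts at most as often as it occurs in the reference.
        matches = sum(min(count, ref_counts[gram]) for gram, count in hyp_counts.items())
        total = sum(hyp_counts.values())
        if n > 1:
            matches, total = matches + 1, total + 1
        if matches == 0:
            return 0.0  # no unigram overlap at all
        log_precisions += math.log(matches / total)
    # Brevity penalty: hypotheses shorter than the reference are penalized.
    if len(hypothesis) >= len(reference):
        brevity = 1.0
    else:
        brevity = math.exp(1.0 - len(reference) / len(hypothesis))
    return brevity * math.exp(log_precisions / max_n)


def batch_sentence_bleu(references: Sequence[Sequence[int]],
                        hypotheses: Sequence[Sequence[int]]) -> List[float]:
    """One score per (reference, hypothesis) pair, mirroring the batch signature."""
    return [sentence_bleu_sketch(r, h) for r, h in zip(references, hypotheses)]


print(batch_sentence_bleu([[1, 2, 3, 4], [1, 2, 3, 4]], [[1, 2, 3, 4], [7, 8, 9]]))  # [1.0, 0.0]
```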

neuralmonkey.trainers.self_critical_objective.sentence_gleu(references: numpy.ndarray, hypotheses: numpy.ndarray) → numpy.ndarray

Compute index-based GLEU score.

The GLEU score is a sentence-level metric used in Google’s Neural MT as a reward in reinforcement learning. It is the minimum of precision and recall computed over all 1- to 4-grams.

It operates on the indices emitted by the decoder, which are not necessarily words (they could be characters or subword units).
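
The following pure-Python sketch of GLEU follows the definition above:

1. Collect all 1- to 4-gram counts of the reference and of the hypothesis.
2. Count the clipped matches between them.
3. Return the minimum of precision (matches over hypothesis n-grams) and recall (matches over reference n-grams).

The function names are hypothetical, and padding or end tokens are not handled.

```python
from collections import Counter
from typing import Sequence


def all_ngram_counts(seq: Sequence[int], max_n: int = 4) -> Counter:
    """Counts of every n-gram with 1 <= n <= max_n, keyed by tuple."""
    counts: Counter = Counter()
    for n in range(1, max_n + 1):
        for i in range(len(seq) - n + 1):
            counts[tuple(seq[i:i + n])] += 1
    return counts


def sentence_gleu_sketch(reference: Sequence[int], hypothesis: Sequence[int],
                         max_n: int = 4) -> float:
    """GLEU: min(precision, recall) over all 1- to max_n-grams."""
    ref_counts = all_ngram_counts(reference, max_n)
    hyp_counts = all_ngram_counts(hypothesis, max_n)
    if not ref_counts or not hyp_counts:
        return 0.0
    # Counter & keeps the minimum count of each shared n-gram (clipped matches).
    matches = sum((ref_counts & hyp_counts).values())
    precision = matches / sum(hyp_counts.values())
    recall = matches / sum(ref_counts.values())
    return min(precision, recall)


print(sentence_gleu_sketch([1, 2, 3, 4], [1, 2, 3, 5]))  # 0.6
print(sentence_gleu_sketch([1, 2, 3, 4], [1, 2]))        # 0.3
```

The second call shows why the minimum matters. A short hypothesis has perfect precision, but its low recall caps the score.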

Module contents