DistilBERT + Adapted HMCN-F
API
Configuration schema
The configuration for this model defines the following hyperparameters:
classifier_lr
: Classifier (Adapted HMCN-F) learning rate.
lambda_h
: Hierarchical loss gain.
dropout
: The adapted HMCN-F model’s dropout rate.
global_hidden_sizes
: Hidden sizes for constructing the global layers (main flow).
local_hidden_sizes
: Hidden sizes for constructing the local branches (local flow).
global_weight
: Global weight to be forward-propagated.
hidden_nonlinear
: Hidden nonlinear type, which can be either relu or tanh.
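For illustration, a concrete configuration dictionary for this classifier might look like the following. The values are made up and only indicate the expected types; the actual constructor may expect additional keys.

config = {
    "classifier_lr": 0.001,             # Adapted HMCN-F learning rate
    "lambda_h": 0.05,                   # hierarchical loss gain
    "dropout": 0.3,                     # dropout rate inside the classifier
    "global_hidden_sizes": [384, 384],  # hidden sizes of the per-level global layers
    "local_hidden_sizes": [384, 384],   # hidden sizes of the per-level local branches
    "global_weight": 0.5,               # weight of the global flow in the final output
    "hidden_nonlinear": "relu",         # "relu" or "tanh"
}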
Default tuning configuration
"db_ahmcnf": {
"display_name": "DistilBERT + Adapted HMCN-F",
"max_len": 64,
"train_minibatch_size": 4,
"val_test_minibatch_size": 4,
"range": {
"classifier_lr": [0.0005,0.002],
"lambda_h": [0.02, 0.08],
"dropout": [0.1, 0.7],
"global_hidden_sizes": [384, 384],
"local_hidden_sizes": [384, 384],
"global_weight": [0.2, 0.8],
"hidden_nonlinear": "relu"
},
"mode": {
"classifier_lr": "uniform",
"lambda_h": "uniform",
"dropout": "uniform",
"global_hidden_sizes": "fixed",
"local_hidden_sizes": "fixed",
"global_weight": "uniform",
"hidden_nonlinear": "fixed"
}
},
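For illustration, a rough sketch of how a tuner could draw one concrete trial from such an entry, assuming that "uniform" means sampling uniformly within the given two-element range and "fixed" means using the listed value verbatim (this helper is hypothetical, not the project’s tuner):

import random

def sample_trial_config(spec):
    """Draw one concrete hyperparameter set from a tuning entry such as "db_ahmcnf"."""
    trial = {}
    for name, value in spec["range"].items():
        if spec["mode"][name] == "uniform":
            low, high = value
            trial[name] = random.uniform(low, high)
        else:  # "fixed": copy the value unchanged
            trial[name] = value
    return trial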
Checkpoint schema
config
: A copy of the configuration dictionary passed to this instance’s constructor, either explicitly or by from_checkpoint (extracted from a prior checkpoint).
hierarchy
: A serialised dictionary of hierarchical metadata created by PerLevelHierarchy.to_dict().
classifier_state_dict
: Weights of the classifier (Adapted HMCN-F).
optimizer_state_dict
: Saved state of the optimiser that was used to train the model for that checkpoint.
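For illustration, a minimal sketch of how such a checkpoint could be written and restored with PyTorch (the variable names here are hypothetical; the model class itself handles the actual saving and loading):

import torch

# Saving: bundle the four documented keys into one file.
torch.save({
    "config": config,                                  # constructor configuration dict
    "hierarchy": hierarchy.to_dict(),                  # PerLevelHierarchy metadata
    "classifier_state_dict": classifier.state_dict(),  # Adapted HMCN-F weights
    "optimizer_state_dict": optimizer.state_dict(),    # optimiser state for resuming
}, "checkpoint.pt")

# Loading: read the same keys back.
checkpoint = torch.load("checkpoint.pt")
classifier.load_state_dict(checkpoint["classifier_state_dict"])
optimizer.load_state_dict(checkpoint["optimizer_state_dict"])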
Theory
This model replaces the simple linear layer found in the DistilBERT + Linear model with an adapted version of the HMCN-F model [WCB18]. HMCN-F is a fully-connected neural network with two flows, a global one and a local one, designed to better exploit the hierarchical nature of the data. An example computation graph for the classifier is given below.
The global flow starts at the vectorised input and runs through a stack of per-level global layers to a global prediction layer, which predicts the classes of all levels at once using the information accumulated by those layers. Each per-level global layer takes in a concatenation of the vectorised input and the previous layer’s ReLU output, then outputs a hidden vector of the same arbitrarily-chosen size, which in turn becomes half of the next layer’s input. The local flows take their inputs as residual outputs from the global flow’s per-level layers, pass them through a per-level transition layer, and finish at a per-level local classification layer. These classification layers each predict the class at their level using information transformed from the corresponding level of the global flow. The outputs of all local flow classification layers are then concatenated from left to right, i.e. from the top level to the bottom level. For this reason, the global classification layer’s output scores are also ordered by hierarchical level.
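The following is a minimal PyTorch sketch of the classifier just described, including the weighted combination of the two flows discussed below. The class name, constructor signature, and the exact wiring of dropout and of the global prediction layer’s input are assumptions for illustration, not the actual implementation; level_sizes denotes the number of classes at each level.

import torch
from torch import nn

class AdaptedHMCNF(nn.Module):
    """Illustrative sketch only; hidden_nonlinear is fixed to ReLU for brevity."""

    def __init__(self, input_dim, level_sizes, global_hidden_sizes,
                 local_hidden_sizes, dropout=0.1, global_weight=0.5):
        super().__init__()
        self.beta = global_weight
        self.nonlinear = nn.ReLU()
        self.dropout = nn.Dropout(dropout)
        # Per-level global layers: each consumes (vectorised input ++ previous hidden).
        self.global_layers = nn.ModuleList()
        previous = 0
        for size in global_hidden_sizes:
            self.global_layers.append(nn.Linear(input_dim + previous, size))
            previous = size
        # Global prediction layer scores every class of every level at once,
        # ordered by hierarchical level.
        self.global_classifier = nn.Linear(global_hidden_sizes[-1], sum(level_sizes))
        # Local flow: a per-level transition layer followed by a per-level classifier.
        self.transitions = nn.ModuleList(
            nn.Linear(g, l) for g, l in zip(global_hidden_sizes, local_hidden_sizes))
        self.local_classifiers = nn.ModuleList(
            nn.Linear(l, n) for l, n in zip(local_hidden_sizes, level_sizes))

    def forward(self, x):
        hidden = None
        local_outputs = []
        for layer, transition, classifier in zip(
                self.global_layers, self.transitions, self.local_classifiers):
            # The first level has no previous hidden state, so it sees only the input.
            inputs = x if hidden is None else torch.cat([x, hidden], dim=1)
            hidden = self.dropout(self.nonlinear(layer(inputs)))
            # Residual branch off this level's global hidden state.
            local_hidden = self.dropout(self.nonlinear(transition(hidden)))
            local_outputs.append(torch.sigmoid(classifier(local_hidden)))
        p_global = torch.sigmoid(self.global_classifier(hidden))
        p_local = torch.cat(local_outputs, dim=1)  # top level -> bottom level
        # Weighted average of the two flows with ratio beta (the global weight).
        return self.beta * p_global + (1 - self.beta) * p_local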
The final output \(P_F\) of the model is then a weighted average of the global output \(P_G\) and the concatenated local outputs \(P_L\), with ratio \(\beta\) (the global_weight hyperparameter):

\[
P_F = \beta P_G + (1 - \beta) P_L
\]
where \(P_L\) is the concatenation of the local outputs from the top to the bottom levels of the hierarchy DAG:

\[
P_L = P_L^1 ++ P_L^2 ++ \cdots ++ P_L^{|H|}
\]
where ++ signifies 1D vector concatenation along the vectors’ main axis. This ordering has a second benefit: it allows for very simple parsing. Instead of training the model against a threshold that picks out the top \(|H|\) categories (for \(|H|\) levels), which would guarantee neither hierarchical compliance nor even that each top-scoring category belongs to a different level, we can simply slice the output vector into per-level segments and pass each through an \(argmax\) function to determine the predictions.
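A small sketch of this parsing step, assuming hypothetical names (outputs is a batch of final scores ordered by level, level_sizes the number of classes per level):

import torch

def parse_predictions(outputs: torch.Tensor, level_sizes: list) -> list:
    """Slice the concatenated scores per level and take the argmax of each slice."""
    predictions = []
    offset = 0
    for size in level_sizes:
        level_scores = outputs[:, offset:offset + size]
        predictions.append(level_scores.argmax(dim=1))  # one class index per level
        offset += size
    return predictions  # list of (batch,) index tensors, top level first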
We employ three loss functions. The first is a global loss used to train the global flow; the second is the sum of the losses of each local classification layer’s output; and the third is applied to the final output \(P_F\) to directly train the model against hierarchical constraints. For the third loss, we define a hierarchical violation as the case where a node’s predicted score is larger than its parent node’s score. All three loss functions revolve around the averaged binary cross-entropy (BCE) loss formula:

\[
BCE(f(x), y) = mean_{n=1}^{|N|}\left(-\left[\, y_n \log f(x)_n + (1 - y_n) \log\left(1 - f(x)_n\right) \right]\right)
\]
where \(mean_{n=1}^{|N|}()\) averages over the \(|N|\) examples in \(f(x)\), allowing for minibatch training. With this base BCE loss, we construct the three aforementioned loss functions: the global, the local and the hierarchical loss.
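A plausible formulation of the three, following the corresponding definitions in [WCB18] (the notation here is ours: \(y\) is the full multi-hot label vector, \(y^l\) and \(P_L^l\) are its slices for level \(l\), and the exact reduction over nodes in \(L_H\) may differ in the implementation), would be:

\[
L_G = BCE(P_G, y), \qquad
L_L = \sum_{l=1}^{|H|} BCE(P_L^l, y^l), \qquad
L_H = mean_c\left(\max\left(0,\, P_F(c) - P_F(parent(c))\right)^2\right)
\]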
We have implemented \(parent()\) as a lookup vector in which each element contains the index of the parent of the node at that element’s index. As we do not encode the root node, nodes immediately below it have their parents set to themselves, which zeroes out their term in the formula; this is in line with the expected behaviour, as technically they do not have to obey any hierarchical constraint at all. PyTorch’s slicing and indexing capabilities retain the order of indexing and make this possible without un-accelerated loops, adding little to no overhead. The final loss function is then the sum of the three functions above:

\[
L = L_G + L_L + \lambda L_H
\]
where \(\lambda\) is the hierarchical loss scale factor.
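As an illustration of the lookup-vector trick described above, one way the hierarchical violation term might be computed in PyTorch (the function and tensor names are hypothetical, and the squared hinge follows [WCB18] rather than the actual source):

import torch

def hierarchical_violation(p_final: torch.Tensor, parent_of: torch.Tensor) -> torch.Tensor:
    """p_final: (batch, num_nodes) scores in (0, 1); parent_of: (num_nodes,) parent indices."""
    # Indexing with the lookup vector keeps the node order, so each child's score
    # lines up with its parent's score without any Python-level loop.
    parent_scores = p_final[:, parent_of]
    # Penalise children scoring higher than their parents; top-level nodes point to
    # themselves, so their term is exactly zero.
    violation = torch.clamp(p_final - parent_scores, min=0.0) ** 2
    return violation.mean()

# Example: node 0 is top-level (its own parent); nodes 1 and 2 are its children.
parent_of = torch.tensor([0, 0, 0])
scores = torch.sigmoid(torch.randn(4, 3))
loss = hierarchical_violation(scores, parent_of)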