Using a multi-task learning model shared on the Hugging Face Hub for inference

I have successfully trained a multi-task BERT model. It consists of a shared BERT-style encoder and two task-specific heads: a binary classification head (num_labels = 2) and a sentiment classification head (num_labels = 5).

I then tried to push it to the Hub and reload it for inference, but that failed.

Here is the code:

import torch
from torch import nn
from typing import List
from transformers import BertModel, BertPreTrainedModel, PretrainedConfig


class SequenceClassificationHead(nn.Module):
    def __init__(self, hidden_size, num_labels, dropout_p=0.1):
        super().__init__()
        self.num_labels = num_labels
        self.dropout = nn.Dropout(dropout_p)
        self.classifier = nn.Linear(hidden_size, num_labels)

        self._init_weights()

    def _init_weights(self):
        self.classifier.weight.data.normal_(mean=0.0, std=0.02)
        if self.classifier.bias is not None:
            self.classifier.bias.data.zero_()

    def forward(self, sequence_output, pooled_output, labels=None, **kwargs):
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(
                logits.view(-1, self.num_labels), labels.long().view(-1)
            )

        # Return both so the multi-task forward can unpack (logits, task_loss)
        return logits, loss
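For reference, a quick smoke test of the head on dummy tensors (a sketch; the shapes assume BERT-base's hidden size of 768, and the labels are hypothetical 5-class sentiment labels):

import torch

head = SequenceClassificationHead(hidden_size=768, num_labels=5)
pooled = torch.randn(4, 768)                      # 4 pooled [CLS] vectors
labels = torch.tensor([0, 1, 4, 2])               # 5-class sentiment labels
logits, loss = head(None, pooled, labels=labels)  # sequence_output is unused here
print(logits.shape)                               # torch.Size([4, 5])
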
class MultiTaskModel(BertPreTrainedModel):
    def __init__(self, checkpoint, tasks: List):
        super().__init__(PretrainedConfig())

        self.encoder = BertModel.from_pretrained(checkpoint)

        self.output_heads = nn.ModuleDict()
        for task in tasks:
            decoder = self._create_output_head(self.encoder.config.hidden_size, task)
            # ModuleDict requires keys to be strings
            self.output_heads[str(task.id)] = decoder

    @staticmethod
    def _create_output_head(encoder_hidden_size: int, task):
        if task.type == "seq_classification":
            return SequenceClassificationHead(encoder_hidden_size, task.num_labels)
        else:
            raise NotImplementedError()

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        labels=None,
        task_ids=None,
        **kwargs,
    ):
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            position_ids=position_ids,
            head_mask=head_mask,
            inputs_embeds=inputs_embeds,
        )

        sequence_output, pooled_output = outputs[:2]
        unique_task_ids_list = torch.unique(task_ids).tolist()

        loss_list = []
        logits = None
        for unique_task_id in unique_task_ids_list:
            # Route each sample to the head that matches its task id
            task_id_filter = task_ids == unique_task_id
            logits, task_loss = self.output_heads[str(unique_task_id)](
                sequence_output[task_id_filter],
                pooled_output[task_id_filter],
                labels=None if labels is None else labels[task_id_filter],
                attention_mask=attention_mask[task_id_filter],
            )
            if task_loss is not None:
                loss_list.append(task_loss)

        # Average the per-task losses into the single scalar the Trainer expects
        loss = torch.stack(loss_list).mean() if loss_list else None
        return (loss, logits) if loss is not None else (logits,)

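The code above assumes a tasks list whose items expose id, type, and num_labels. A minimal sketch consistent with the two heads described earlier (the Task class itself is an assumption; only its three fields are implied by how the code uses them):

from dataclasses import dataclass

@dataclass
class Task:
    id: int          # used (stringified) as the ModuleDict key
    type: str        # dispatched in _create_output_head
    num_labels: int  # output size of the head

tasks = [
    Task(id=0, type="seq_classification", num_labels=2),  # binary head
    Task(id=1, type="seq_classification", num_labels=5),  # sentiment head
]
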
I train the model with the Trainer API and also push it to the Hub through the Trainer API. That part works, but when I load it back from the Hub for inference I get this message:

loading file vocab.txt from cache at /root/.cache/huggingface/hub/models--HCKLab--BiBert-MultiTask/snapshots/f3523728d3e144c0b7d262f6ff924cc174bc0d03/vocab.txt
loading file tokenizer.json from cache at /root/.cache/huggingface/hub/models--HCKLab--BiBert-MultiTask/snapshots/f3523728d3e144c0b7d262f6ff924cc174bc0d03/tokenizer.json
loading file added_tokens.json from cache at None
loading file special_tokens_map.json from cache at /root/.cache/huggingface/hub/models--HCKLab--BiBert-MultiTask/snapshots/f3523728d3e144c0b7d262f6ff924cc174bc0d03/special_tokens_map.json
loading file tokenizer_config.json from cache at /root/.cache/huggingface/hub/models--HCKLab--BiBert-MultiTask/snapshots/f3523728d3e144c0b7d262f6ff924cc174bc0d03/tokenizer_config.json
loading configuration file config.json from cache at /root/.cache/huggingface/hub/models--HCKLab--BiBert-MultiTask/snapshots/f3523728d3e144c0b7d262f6ff924cc174bc0d03/config.json
Model config BertConfig {
  "architectures": [
    "MultiTaskModel"
  ],
  "attention_probs_dropout_prob": 0.1,
  "classifier_dropout": null,
  "hidden_act": "gelu",
  "hidden_dropout_prob": 0.1,
  "hidden_size": 768,
  "initializer_range": 0.02,
  "intermediate_size": 3072,
  "layer_norm_eps": 1e-12,
  "max_position_embeddings": 512,
  "model_type": "bert",
  "num_attention_heads": 12,
  "num_hidden_layers": 12,
  "pad_token_id": 0,
  "position_embedding_type": "absolute",
  "torch_dtype": "float32",
  "transformers_version": "4.22.1",
  "type_vocab_size": 2,
  "use_cache": true,
  "vocab_size": 30522
}

loading weights file pytorch_model.bin from cache at /root/.cache/huggingface/hub/models--HCKLab--BiBert-MultiTask/snapshots/f3523728d3e144c0b7d262f6ff924cc174bc0d03/pytorch_model.bin
Some weights of the model checkpoint at HCKLab/BiBert-MultiTask were not used when initializing BertModel: ['encoder.embeddings.word_embeddings.weight', 'encoder.encoder.layer.0.attention.self.query.weight', ..., 'output_heads.0.classifier.weight', 'output_heads.0.classifier.bias', 'output_heads.1.classifier.weight', 'output_heads.1.classifier.bias'] (list truncated; it covers every parameter of the saved model, all prefixed with 'encoder.' or 'output_heads.')

    This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
    This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
    Some weights of BertModel were not initialized from the model checkpoint at HCKLab/BiBert-MultiTask and are newly initialized: ['embeddings.word_embeddings.weight', 'encoder.layer.0.attention.self.query.weight', ..., 'pooler.dense.weight', 'pooler.dense.bias'] (list truncated; it covers every parameter of a plain BertModel, with no 'encoder.' prefix)
    You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
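The warnings hint at a key-name mismatch: every saved key carries an encoder. or output_heads. prefix (the attribute names of MultiTaskModel), while a plain BertModel expects unprefixed BERT keys. One way to confirm this is to inspect the checkpoint's state dict directly (a sketch using huggingface_hub):

import torch
from huggingface_hub import hf_hub_download

# Download the raw weights file from the Hub and look at its keys
weights_path = hf_hub_download("HCKLab/BiBert-MultiTask", "pytorch_model.bin")
state_dict = torch.load(weights_path, map_location="cpu")
print(sorted(state_dict.keys())[:3])
print([k for k in state_dict if k.startswith("output_heads.")])
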

To load it from the Hub I do:

import torch
from transformers import AutoTokenizer

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

checkpoint = "HCKLab/BiBert-MultiTask"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

model1 = MultiTaskModel(checkpoint, tasks).to(device)
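For comparison, a reload that restores the full multi-task state dict by hand might look like this (a sketch, assuming tasks and device are defined as above and that the checkpoint's keys match a freshly constructed MultiTaskModel):

import torch
from huggingface_hub import hf_hub_download

weights_path = hf_hub_download(checkpoint, "pytorch_model.bin")
state_dict = torch.load(weights_path, map_location="cpu")

model1 = MultiTaskModel(checkpoint, tasks)
missing, unexpected = model1.load_state_dict(state_dict, strict=False)
print(missing, unexpected)  # both lists should be empty if every key matched
model1 = model1.to(device).eval()
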

What am I missing? Why are the trained weights not restored when I load the model back from the Hub after pushing it?


