From 6a3d97d54dbfe8b936f489a3d32e5e3b971740ba Mon Sep 17 00:00:00 2001
From: Miquel Anglada Girotto <45081549+MiqG@users.noreply.github.com>
Date: Thu, 5 Dec 2024 19:47:06 +0100
Subject: [PATCH] Added tokenizer argument in
 StripedHyenaModelForExtractingEmbeddings class

---
 training-model/modeling_hyena.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/training-model/modeling_hyena.py b/training-model/modeling_hyena.py
index 8a615a8..e0641b3 100644
--- a/training-model/modeling_hyena.py
+++ b/training-model/modeling_hyena.py
@@ -219,10 +219,10 @@ class CausalLMEmbeddingOutput(ModelOutput):
 class StripedHyenaModelForExtractingEmbeddings(StripedHyenaPreTrainedModel):
     supports_gradient_checkpointing = True
 
-    def __init__(self, config, **kwargs):
+    def __init__(self, config, tokenizer, **kwargs):
         super().__init__(config, **kwargs)
         model_config = dotdict(config.to_dict())
-        self.backbone = StripedHyenaForEmbeddings(model_config)
+        self.backbone = StripedHyenaForEmbeddings(model_config, tokenizer)
         self.backbone.gradient_checkpointing = False
         self.config = config
         self.post_init()
@@ -297,4 +297,4 @@ def forward(
 
         return CausalLMEmbeddingOutput(
             loss=loss,
-            e_token_embs=e_token_embeddings)
\ No newline at end of file
+            e_token_embs=e_token_embeddings)
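
Note: with this patch, a tokenizer must be passed when constructing the
embedding model. A minimal usage sketch, assuming the checkpoint path is a
placeholder, the patched file is importable as modeling_hyena, and the config
and tokenizer are loaded through the Hugging Face Auto classes:

    from transformers import AutoConfig, AutoTokenizer

    from modeling_hyena import StripedHyenaModelForExtractingEmbeddings

    # Placeholder path; replace with your StripedHyena checkpoint or repo id.
    checkpoint = "path/to/striped-hyena-checkpoint"

    # Load the model config and tokenizer from the checkpoint.
    config = AutoConfig.from_pretrained(checkpoint, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)

    # The tokenizer is now a required positional argument; the constructor
    # forwards it to StripedHyenaForEmbeddings.
    model = StripedHyenaModelForExtractingEmbeddings(config, tokenizer)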