Skip to content

Commit

Permalink
check signing my commit
Browse files Browse the repository at this point in the history
  • Loading branch information
adelmemariani committed Dec 17, 2024
1 parent 2df1d30 commit 741b129
Showing 1 changed file with 8 additions and 8 deletions.
16 changes: 8 additions & 8 deletions dicee/models/base_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,14 +136,6 @@ def __init__(self, args: dict):
self.kernel_size = None
self.num_of_output_channels = None
self.weight_decay = None

if self.args["loss_fn"] == "BCELoss":
self.loss = torch.nn.BCEWithLogitsLoss()
if self.args["loss_fn"] == "LRLoss":
self.loss = LabelRelaxationLoss()
else:
self.loss = torch.nn.BCEWithLogitsLoss()

self.selected_optimizer = None
self.normalizer_class = None
self.normalize_head_entity_embeddings = IdentityClass()
Expand All @@ -162,6 +154,14 @@ def __init__(self, args: dict):
self.byte_pair_encoding = self.args.get("byte_pair_encoding", False)
self.max_length_subword_tokens = self.args.get("max_length_subword_tokens", None)
self.block_size=self.args.get("block_size", None)

if self.args["loss_fn"] == "BCELoss":
self.loss = torch.nn.BCEWithLogitsLoss()
if self.args["loss_fn"] == "LRLoss":
self.loss = LabelRelaxationLoss()
else:
self.loss = torch.nn.BCEWithLogitsLoss()

if self.byte_pair_encoding and self.args['model'] != "BytE":
self.token_embeddings = torch.nn.Embedding(self.num_tokens, self.embedding_dim)
self.param_init(self.token_embeddings.weight.data)
Expand Down

0 comments on commit 741b129

Please sign in to comment.