Skip to content
This repository has been archived by the owner on Nov 21, 2022. It is now read-only.

Commit

Permalink
Remove kwargs
Browse files Browse the repository at this point in the history
  • Loading branch information
SeanNaren committed Jun 23, 2022
1 parent f070bd0 commit d972675
Show file tree
Hide file tree
Showing 5 changed files with 2 additions and 25 deletions.
9 changes: 0 additions & 9 deletions lightning_transformers/core/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -212,12 +212,3 @@ def predict_dataloader(self) -> Optional[DataLoader]:
@property
def collate_fn(self) -> Optional[Callable]:
return None

@property
def model_data_kwargs(self) -> Dict:
"""Override to provide the model with additional kwargs.
This is useful to provide the number of classes/pixels to the model or any other data specific args
Returns: Dict of args
"""
return {}
1 change: 0 additions & 1 deletion lightning_transformers/core/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,6 @@ class TaskTransformer(pl.LightningModule):
pretrained_model_name_or_path: Huggingface model to use if backbone config not passed.
tokenizer: The pre-trained tokenizer.
pipeline_kwargs: Arguments required for the HuggingFace inference pipeline class.
**model_data_kwargs: Arguments passed from the data module to the class.
"""

def __init__(
Expand Down
5 changes: 0 additions & 5 deletions lightning_transformers/task/nlp/multiple_choice/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Dict

from transformers import default_data_collator

Expand All @@ -35,7 +34,3 @@ def collate_fn(self) -> callable:
@property
def num_classes(self) -> int:
raise NotImplementedError

@property
def model_data_kwargs(self) -> Dict[str, int]:
return {"num_labels": self.num_classes}
6 changes: 1 addition & 5 deletions lightning_transformers/task/nlp/text_classification/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, List, Optional
from typing import Any, List, Optional

from datasets import ClassLabel, Dataset
from pytorch_lightning.utilities import rank_zero_warn
Expand Down Expand Up @@ -53,10 +53,6 @@ def num_classes(self) -> int:
self.setup("fit")
return self.labels.num_classes

@property
def model_data_kwargs(self) -> Dict[str, int]:
return {"num_labels": self.num_classes}

@staticmethod
def convert_to_features(
example_batch: Any, _, tokenizer: PreTrainedTokenizerBase, input_feature_fields: List[str], **tokenizer_kwargs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, Dict, Optional
from typing import Any, Optional

from datasets import ClassLabel, Dataset
from pytorch_lightning.utilities import rank_zero_warn
Expand Down Expand Up @@ -52,7 +52,3 @@ def num_classes(self) -> int:
rank_zero_warn("Labels has not been set, calling `setup('fit')`.")
self.setup("fit")
return self.labels.num_classes

@property
def model_data_kwargs(self) -> Dict[str, int]:
return {"num_labels": self.num_classes}

0 comments on commit d972675

Please sign in to comment.