From 3a52860dd15236e6dbf8ae1ff513cbde59cdb816 Mon Sep 17 00:00:00 2001
From: haru
Date: Fri, 24 Sep 2021 02:47:46 -0700
Subject: [PATCH] Add initial GPT-J support

---
 mkultra/inference.py | 9 +++++++--
 mkultra/tuning.py    | 6 +++++-
 2 files changed, 12 insertions(+), 3 deletions(-)

diff --git a/mkultra/inference.py b/mkultra/inference.py
index 0f1b9ae..988a6d3 100644
--- a/mkultra/inference.py
+++ b/mkultra/inference.py
@@ -1,10 +1,11 @@
-from transformers import GPT2LMHeadModel, GPTNeoForCausalLM, TextGenerationPipeline
+from transformers import GPT2LMHeadModel, GPTNeoForCausalLM, GPTJForCausalLM, TextGenerationPipeline
 from mkultra.soft_prompt import SoftPrompt
 import torch
 
 EXTRA_ALLOWED_MODELS = [
     "GPT2SoftPromptLM",
-    "GPTNeoSoftPromptLM"
+    "GPTNeoSoftPromptLM",
+    "GPTJSoftPromptLM"
 ]
 
 for model in EXTRA_ALLOWED_MODELS:
@@ -85,5 +86,9 @@ def __init__(self, config):
         super().__init__(config)
 
 class GPTNeoSoftPromptLM(GPTSoftPromptMixin, GPTNeoForCausalLM):
+    def __init__(self, config):
+        super().__init__(config)
+
+class GPTJSoftPromptLM(GPTSoftPromptMixin, GPTJForCausalLM):
     def __init__(self, config):
         super().__init__(config)
\ No newline at end of file
diff --git a/mkultra/tuning.py b/mkultra/tuning.py
index 7031e03..dc079ac 100644
--- a/mkultra/tuning.py
+++ b/mkultra/tuning.py
@@ -1,4 +1,4 @@
-from transformers import GPT2LMHeadModel, GPTNeoForCausalLM
+from transformers import GPT2LMHeadModel, GPTNeoForCausalLM, GPTJForCausalLM
 from mkultra.soft_prompt import SoftPrompt
 import torch
 import torch.nn as nn
@@ -119,5 +119,9 @@ def __init__(self, config):
         super().__init__(config)
 
 class GPTNeoPromptTuningLM(GPTPromptTuningMixin, GPTNeoForCausalLM):
+    def __init__(self, config):
+        super().__init__(config)
+
+class GPTJPromptTuningLM(GPTPromptTuningMixin, GPTJForCausalLM):
     def __init__(self, config):
         super().__init__(config)
\ No newline at end of file