Skip to content

Commit

Permalink
Merge pull request #4 from harubaru/main
Browse files Browse the repository at this point in the history
Add initial GPT-J support
  • Loading branch information
corolla-johnson authored Sep 27, 2021
2 parents aba081c + 3a52860 commit a25c72d
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
9 changes: 7 additions & 2 deletions mkultra/inference.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,11 @@
from transformers import GPT2LMHeadModel, GPTNeoForCausalLM, TextGenerationPipeline
from transformers import GPT2LMHeadModel, GPTNeoForCausalLM, GPTJForCausalLM, TextGenerationPipeline
from mkultra.soft_prompt import SoftPrompt
import torch

EXTRA_ALLOWED_MODELS = [
"GPT2SoftPromptLM",
"GPTNeoSoftPromptLM"
"GPTNeoSoftPromptLM",
"GPTJSoftPromptLM"
]

for model in EXTRA_ALLOWED_MODELS:
Expand Down Expand Up @@ -85,5 +86,9 @@ def __init__(self, config):
super().__init__(config)

class GPTNeoSoftPromptLM(GPTSoftPromptMixin, GPTNeoForCausalLM):
def __init__(self, config):
super().__init__(config)

class GPTJSoftPromptLM(GPTSoftPromptMixin, GPTJForCausalLM):
def __init__(self, config):
super().__init__(config)
6 changes: 5 additions & 1 deletion mkultra/tuning.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from transformers import GPT2LMHeadModel, GPTNeoForCausalLM
from transformers import GPT2LMHeadModel, GPTNeoForCausalLM, GPTJForCausalLM
from mkultra.soft_prompt import SoftPrompt
import torch
import torch.nn as nn
Expand Down Expand Up @@ -119,5 +119,9 @@ def __init__(self, config):
super().__init__(config)

class GPTNeoPromptTuningLM(GPTPromptTuningMixin, GPTNeoForCausalLM):
def __init__(self, config):
super().__init__(config)

class GPTJPromptTuningLM(GPTPromptTuningMixin, GPTJForCausalLM):
def __init__(self, config):
super().__init__(config)

0 comments on commit a25c72d

Please sign in to comment.