Batch size unrecognized #20

Open · mao-code opened this issue Mar 22, 2024 · 1 comment · May be fixed by #21

mao-code commented Mar 22, 2024

I am using the DDPO logic to fine-tune my own model. However, I found that the example reward function (LLaVA BERTScore) uses a fixed batch size.

After reading the source code in this repo and the TRL DDPOTrainer class, I found that this batch size may be related to sample_batch_size.

I recommend replacing the hard-coded batch size with the one from the config, or at least leaving a comment about it. That way, people who want to design their own reward functions have a more sensible guide.

Below is the example reward function from this repo that I mentioned above.

import numpy as np  # needed by the snippet below (not shown in the original excerpt)
import torch


def llava_bertscore():
    """Submits images to LLaVA and computes a reward by comparing the responses to the prompts using BERTScore. See
    https://github.com/kvablack/LLaVA-server for server-side code.
    """
    import requests
    from requests.adapters import HTTPAdapter, Retry
    from io import BytesIO
    import pickle

    batch_size = 16  # hard-coded LLaVA request batch size discussed in this issue
    url = "http://127.0.0.1:8085"
    sess = requests.Session()
    retries = Retry(
        total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
    )
    sess.mount("http://", HTTPAdapter(max_retries=retries))

    def _fn(images, prompts, metadata):
        del metadata
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC

        # Split the incoming images/prompts into fixed-size chunks for the LLaVA server.
        images_batched = np.array_split(images, np.ceil(len(images) / batch_size))
        prompts_batched = np.array_split(prompts, np.ceil(len(prompts) / batch_size))
...
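
A sketch of the change I am suggesting (not code from this repo; the reward_batch_size parameter and the llava_bertscore_configurable name are hypothetical): the factory takes the batch size as an argument, so a training script can pass in whatever value its config uses.

import numpy as np
import torch


def llava_bertscore_configurable(reward_batch_size=16):
    """Same reward as above, but the LLaVA request batch size is an argument."""
    import requests
    from requests.adapters import HTTPAdapter, Retry

    url = "http://127.0.0.1:8085"
    sess = requests.Session()
    retries = Retry(
        total=1000, backoff_factor=1, status_forcelist=[500], allowed_methods=False
    )
    sess.mount("http://", HTTPAdapter(max_retries=retries))

    def _fn(images, prompts, metadata):
        del metadata
        if isinstance(images, torch.Tensor):
            images = (images * 255).round().clamp(0, 255).to(torch.uint8).cpu().numpy()
            images = images.transpose(0, 2, 3, 1)  # NCHW -> NHWC

        # Chunk the incoming sample batch into request-sized pieces for the LLaVA server.
        images_batched = np.array_split(images, np.ceil(len(images) / reward_batch_size))
        prompts_batched = np.array_split(prompts, np.ceil(len(prompts) / reward_batch_size))
        ...  # the rest of the original reward computation is unchanged

    return _fn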

And this is the code that calls compute_rewards() in the DDPOTrainer class in the TRL repo.

    def step(self, epoch: int, global_step: int):
        """
        Perform a single step of training.

        Args:
            epoch (int): The current epoch.
            global_step (int): The current global step.

        Side Effects:
            - Model weights are updated
            - Logs the statistics to the accelerator trackers.
            - If `self.image_samples_callback` is not None, it will be called with the prompt_image_pairs, global_step, and the accelerator tracker.

        Returns:
            global_step (int): The updated global step.

        """
        samples, prompt_image_data = self._generate_samples(
            iterations=self.config.sample_num_batches_per_epoch,
            batch_size=self.config.sample_batch_size,
        )

        # collate samples into dict where each entry has shape (num_batches_per_epoch * sample.batch_size, ...)
        samples = {k: torch.cat([s[k] for s in samples]) for k in samples[0].keys()}
        rewards, rewards_metadata = self.compute_rewards(
            prompt_image_data, is_async=self.config.async_reward_computation
        )
...
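
From the step excerpt, each reward call appears to receive one sample batch of sample_batch_size images, so one option is to size the LLaVA requests from the same config value. A minimal sketch, assuming the hypothetical llava_bertscore_configurable factory above and TRL's DDPOConfig / DDPOTrainer / DefaultDDPOStableDiffusionPipeline interface; the model name and batch size are just example values.

from trl import DDPOConfig, DDPOTrainer, DefaultDDPOStableDiffusionPipeline

config = DDPOConfig(sample_batch_size=8)


def prompt_fn():
    # Placeholder prompt function; DDPOTrainer expects (prompt, prompt_metadata).
    return "a photo of a cat", {}


# Keep the reward function's request chunking in sync with the sampler's batch size.
reward_fn = llava_bertscore_configurable(reward_batch_size=config.sample_batch_size)

pipeline = DefaultDDPOStableDiffusionPipeline("runwayml/stable-diffusion-v1-5")
trainer = DDPOTrainer(config, reward_fn, prompt_fn, pipeline)
trainer.train()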
mao-code linked a pull request (#21) on Mar 22, 2024 that will close this issue.

mao-code (Author) commented:

Or is it just the batch size used for LLaVA to generate responses? I ask because I saw that it ultimately gathers all the rewards into a single list.
