Having some problems related to dataset loading: freezing / process crashes / memory problems #29

Open
radna0 opened this issue Dec 22, 2024 · 8 comments


radna0 commented Dec 22, 2024

@tdrussell These are the logs after DeepSpeed spawns the processes and its subprocesses start to load the dataset.

This might be related to huggingface/datasets#4883 as well. The program then freezes. It also uses a large amount of system memory, climbing from about 30GB to 250GB.

Fetching 8 files: 100%|████████████████████████████████████████████████| 8/8 [00:00<00:00, 9238.56it/s]
Downloading shards: 100%|██████████████████████████████████████████████| 2/2 [00:00<00:00, 6657.63it/s]
Downloading shards: 100%|██████████████████████████████████████████████| 2/2 [00:00<00:00, 8674.88it/s]
Downloading shards: 100%|██████████████████████████████████████████████| 2/2 [00:00<00:00, 7536.93it/s]
Downloading shards: 100%|██████████████████████████████████████████████| 2/2 [00:00<00:00, 7345.54it/s]
Downloading shards: 100%|██████████████████████████████████████████████| 2/2 [00:00<00:00, 4868.61it/s]
Downloading shards: 100%|██████████████████████████████████████████████| 2/2 [00:00<00:00, 7163.63it/s]
Downloading shards: 100%|██████████████████████████████████████████████| 2/2 [00:00<00:00, 9167.88it/s]
Downloading shards: 100%|█████████████████████████████████████████████| 2/2 [00:00<00:00, 12427.57it/s]
Loading checkpoint shards: 100%|█████████████████████████████████████████| 2/2 [00:09<00:00,  4.78s/it]
Loading checkpoint shards: 100%|█████████████████████████████████████████| 2/2 [00:05<00:00,  2.94s/it]
Loading checkpoint shards: 100%|█████████████████████████████████████████| 2/2 [00:05<00:00,  2.97s/it]
Loading checkpoint shards: 100%|█████████████████████████████████████████| 2/2 [00:06<00:00,  3.26s/it]
Loading checkpoint shards: 100%|█████████████████████████████████████████| 2/2 [00:06<00:00,  3.20s/it]
using video_clip_mode=single_middle
Loading checkpoint shards:   0%|                                                 | 0/2 [00:00<?, ?it/s]
Found 5171 images/videos in /home/kojoe/diffusion-pipe/data/data_0
caching metadata
Loading checkpoint shards: 100%|█████████████████████████████████████████| 2/2 [00:03<00:00,  1.55s/it]
Loading checkpoint shards: 100%|█████████████████████████████████████████| 2/2 [00:03<00:00,  1.86s/it]
Loading checkpoint shards: 100%|█████████████████████████████████████████| 2/2 [00:04<00:00,  2.47s/it]
Map (num_proc=8):   0%|                                       | 15/5171 [00:01<07:39, 11.21 examples/s]
using video_clip_mode=single_middle
Map (num_proc=8):   1%|▏                                      | 31/5171 [00:03<06:14, 13.73 examples/s]
Found 5171 images/videos in /home/kojoe/diffusion-pipe/data/data_0
caching metadata
Map (num_proc=8): 100%|█████████████████████████████████████| 5171/5171 [07:26<00:00, 11.59 examples/s]
*** SIGTERM received by PID 226779 (TID 226779) on cpu 1 from PID 226629; stack trace: ***
*** SIGTERM received by PID 226781 (TID 226781) on cpu 0 from PID 226629; stack trace: ***
*** SIGTERM received by PID 226833 (TID 226833) on cpu 11 from PID 226629; stack trace: ***
*** SIGTERM received by PID 226915 (TID 226915) on cpu 2 from PID 226629; stack trace: ***
*** SIGTERM received by PID 226806 (TID 226806) on cpu 30 from PID 226629; stack trace: ***
*** SIGTERM received by PID 226828 (TID 226828) on cpu 64 from PID 226629; stack trace: ***
*** SIGTERM received by PID 226862 (TID 226862) on cpu 6 from PID 226629; stack trace: ***
*** SIGTERM received by PID 226856 (TID 226856) on cpu 54 from PID 226629; stack trace: ***
PC: @     0x7f973710e22c  (unknown)  read
    @     0x7f95340c8a01       1888  (unknown)
    @     0x7f9737043090  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7f973710e22c,7f95340c8a00,7f973704308f&map=
E1222 15:27:39.683511  226806 coredump_hook.cc:247] RAW: Remote crash gathering disabled for SIGTERM.
https://symbolize.stripped_domain/r/?trace=7f953414641d,7f973704308f,7f9533f9bfcb,7f9534124ebb,7f9534123932,7f9534123437,7f953437c62d,7f95340c8c04,7f973704308f&map=
E1222 15:27:39.684261  226806 process_state.cc:1128] RAW: Signal 11 raised at PC: 0x7f953414641d while already in FailureSignalHandler!
E1222 15:27:39.684284  226806 process_state.cc:1163] RAW: Raising 11 signal with default behavior
(the same PC/stack trace, coredump_hook, and Signal 11 lines repeat, interleaved, for each of the other SIGTERMed PIDs)
Map: 100%|███████████████████████████████████████████████| 1330/1330 [00:00<00:00, 23894.66 examples/s]
Map: 100%|███████████████████████████████████████████████| 2489/2489 [00:00<00:00, 23854.85 examples/s]
Map: 100%|█████████████████████████████████████████████████| 794/794 [00:00<00:00, 23423.63 examples/s]
Map: 100%|█████████████████████████████████████████████████| 558/558 [00:00<00:00, 22097.38 examples/s]
caching latents: /home/kojoe/diffusion-pipe/data/data_0
caching latents: (1.78, 72)
caching latents: (1280, 704, 72)
Map (num_proc=8): 100%|█████████████████████████████████████| 5171/5171 [08:24<00:00, 10.26 examples/s]
*** SIGTERM received by PID 230940 (TID 230940) on cpu 0 from PID 229839; stack trace: ***
*** SIGTERM received by PID 231730 (TID 231730) on cpu 14 from PID 229839; stack trace: ***
*** SIGTERM received by PID 232669 (TID 232669) on cpu 59 from PID 229839; stack trace: ***
*** SIGTERM received by PID 231636 (TID 231636) on cpu 28 from PID 229839; stack trace: ***
*** SIGTERM received by PID 232872 (TID 232872) on cpu 60 from PID 229839; stack trace: ***
*** SIGTERM received by PID 233638 (TID 233638) on cpu 27 from PID 229839; stack trace: ***
*** SIGTERM received by PID 232320 (TID 232320) on cpu 25 from PID 229839; stack trace: ***
*** SIGTERM received by PID 233480 (TID 233480) on cpu 26 from PID 229839; stack trace: ***
PC: @     0x7f973710e22c  (unknown)  read
    @     0x7f95340c8a01       1888  (unknown)
    @     0x7f9737043090  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7f973710e22c,7f95340c8a00,7f973704308f&map=
E1222 15:28:42.883335  231636 coredump_hook.cc:247] RAW: Remote crash gathering disabled for SIGTERM.
https://symbolize.stripped_domain/r/?trace=7f953414641d,7f973704308f,7f9533f9bfcb,7f9534124ebb,7f9534123932,7f9534123437,7f953437c62d,7f95340c8c04,7f973704308f&map=
E1222 15:28:42.883983  231636 process_state.cc:1128] RAW: Signal 11 raised at PC: 0x7f953414641d while already in FailureSignalHandler!
E1222 15:28:42.883996  231636 process_state.cc:1163] RAW: Raising 11 signal with default behavior
(the same PC/stack trace, coredump_hook, and Signal 11 lines repeat, interleaved, for each of the other SIGTERMed PIDs)
Map: 100%|███████████████████████████████████████████████| 1330/1330 [00:00<00:00, 25913.86 examples/s]
Map: 100%|███████████████████████████████████████████████| 2489/2489 [00:00<00:00, 27830.83 examples/s]
Map: 100%|█████████████████████████████████████████████████| 794/794 [00:00<00:00, 26929.39 examples/s]
Map: 100%|█████████████████████████████████████████████████| 558/558 [00:00<00:00, 22436.53 examples/s]
caching latents: /home/kojoe/diffusion-pipe/data/data_0
caching latents: (1.78, 72)
caching latents: (1280, 704, 72)
Map (num_proc=8):   0%|                                     | 4/1330 [03:10<14:57:33, 40.61s/ examples]
tdrussell (Owner) commented

caching latents: (1280, 704, 72)

Your dataset consists of relatively high-resolution videos. That uses a lot more RAM than images, just to load and preprocess the raw video frames. In dataset.py, can you try changing NUM_PROC at the top to 1? That makes all the dataset.map() calls run in-process without parallelism. It should use less memory because there are no longer 8 processes loading videos at the same time, but it will be much slower.
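
For reference, a minimal sketch of what that change does, assuming the standard Hugging Face datasets API (the actual call sites in dataset.py may differ):

from datasets import Dataset

ds = Dataset.from_dict({"path": ["a.mp4", "b.mp4"]})

def preprocess(example):
    # video decoding happens here; this is where the memory goes
    return example

# num_proc=1 runs in-process, loading one video at a time
ds = ds.map(preprocess, num_proc=1)

# num_proc=8 forks 8 workers decoding videos concurrently,
# multiplying peak RAM roughly by the worker count:
# ds = ds.map(preprocess, num_proc=8)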

I've tried to optimize the memory usage of the video preprocessing, but maybe there's more that can be done. The relevant code is in models/base.py, PreprocessMediaFile class.


radna0 commented Dec 22, 2024

@tdrussell I set the training resolution to 960; doesn't it resize the videos accordingly as well? I tried NUM_PROC=1 and it just hangs after reaching step 4 again, and yes, it's much slower. It uses about 150GB. There's no OOM or process crash; it just freezes.

caching latents: /home/kojoe/diffusion-pipe/data/data_0
caching latents: (1.78, 72)
caching latents: (1280, 704, 72)
Map:   0%|▏                                                  | 4/1330 [01:13<6:44:10, 18.29s/ examples]
Map:   0%|▏                                                | 4/1330 [11:51<38:33:33, 104.69s/ examples]

tdrussell (Owner) commented

How many frames is your longest video? The code must load the entire raw video into memory before extracting one or more clips (depending on how you configured it) of the correct number of frames for the bucket. The raw video pixels use a lot of memory: even for the size bucket (1280, 704, 72), at float32 one clip uses nearly 1GB. If your video files on disk are much longer than this, it would explain the high memory use.
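
To make the arithmetic concrete, a quick back-of-the-envelope check in Python (assuming 3 color channels and 4 bytes per float32 value):

width, height, frames, channels = 1280, 704, 72, 3
bytes_per_float32 = 4
raw_bytes = width * height * frames * channels * bytes_per_float32
print(f"{raw_bytes / 1024**3:.2f} GiB")  # ~0.73 GiB for one clip's raw pixels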


radna0 commented Dec 22, 2024

The videos are within 60-120 frames: around 5,100 videos totaling about 4GB on disk. Is there a batch size we can set for how many videos are loaded at once, or is it just 1 by default? Could we instead have the videos resized, encoded into latents, the tensors saved, and then categorized by bucket? Is that not possible?


tdrussell commented Dec 22, 2024

Okay, then something unexpected is happening. With NUM_PROC=1 it will load only a single video at a time. If your input videos are a few seconds long at most, it might use a few GB worst case (since rearranging dimensions makes a copy, etc), but it should not use 100+ GB. Let me do some tests with video on my end.

EDIT: one more thing to confirm: are you keeping caching_batch_size=1? If it's >1, it will still load the input videos one at a time, but it will need to keep all the latents in memory for the whole batch.


radna0 commented Dec 22, 2024

Yes @tdrussell, caching_batch_size is 1. I can share my config here; I don't think I have modified much.

EDIT: the json referenced in dataset.toml stores the captions. I just create a .txt caption file for each video from it if one doesn't exist.
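
For illustration, that caption step could look something like the sketch below, assuming the JSON maps media filenames to captions (the actual schema isn't shown in this thread):

import json
from pathlib import Path

data_dir = Path("/home/kojoe/diffusion-pipe/data/data_0")
captions = json.loads(Path("/home/kojoe/diffusion-pipe/data/train.json").read_text())

# write filename.txt next to each video if a caption file doesn't already exist
for filename, caption in captions.items():
    txt_path = (data_dir / filename).with_suffix(".txt")
    if not txt_path.exists():
        txt_path.write_text(caption)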

# Output path for training runs. Each training run makes a new directory in here.
output_dir = '/data/diffusion_pipe_training_runs/ltx_video_test'

# Dataset config file.
dataset = 'examples/dataset.toml'
# You can have separate eval datasets. Give them a name for Tensorboard metrics.
# eval_datasets = [
#     {name = 'something', config = 'path/to/eval_dataset.toml'},
# ]

# training settings

# I usually set this to a really high value because I don't know how long I want to train.
epochs = 1000
# Batch size of a single forward/backward pass for one GPU.
micro_batch_size_per_gpu = 1
# Pipeline parallelism degree. A single instance of the model is divided across this many GPUs.
pipeline_stages = 1
# Number of micro-batches sent through the pipeline for each training step.
# If pipeline_stages > 1, a higher GAS means better GPU utilization due to smaller pipeline bubbles (where GPUs aren't overlapping computation).
gradient_accumulation_steps = 4
# Grad norm clipping.
gradient_clipping = 1.0
# Learning rate warmup.
warmup_steps = 100

# eval settings

eval_every_n_epochs = 1
eval_before_first_step = true
# Might want to set these lower for eval so that fewer images get dropped (the eval dataset is usually much smaller than the training set).
# Each size bucket of images/videos is rounded down to the nearest multiple of the global batch size, so a higher global batch size means
# more dropped images. That usually doesn't matter for training, but the eval set is much smaller, so it can matter there.
eval_micro_batch_size_per_gpu = 1
eval_gradient_accumulation_steps = 1

# misc settings

# Probably want to set this a bit higher if you have a smaller dataset so you don't end up with a million saved models.
save_every_n_epochs = 5
# Can checkpoint the training state every n number of epochs or minutes. Set only one of these. You can resume from checkpoints using the --resume_from_checkpoint flag.
#checkpoint_every_n_epochs = 1
checkpoint_every_n_minutes = 120
# Always set to true unless you have a huge amount of VRAM.
activation_checkpointing = true
# Controls how Deepspeed decides how to divide layers across GPUs. Probably don't change this.
partition_method = 'parameters'
# dtype for saving the LoRA or model, if different from training dtype
save_dtype = 'bfloat16'
# Batch size for caching latents and text embeddings. Increasing can lead to higher GPU utilization during caching phase but uses more memory.
caching_batch_size = 1
# How often deepspeed logs to console.
steps_per_print = 1
# How to extract video clips for training from a single input video file.
# The video file is first assigned to one of the configured frame buckets, but then we must extract one or more clips of exactly the right
# number of frames for that bucket.
# single_beginning: one clip starting at the beginning of the video
# single_middle: one clip from the middle of the video (cutting off the start and end equally)
# multiple_overlapping: extract the minimum number of clips to cover the full range of the video. They might overlap some.
# default is single_middle (a sketch of this mode follows the pasted config below)
video_clip_mode = 'single_middle'

[model]
# flux, ltx-video, or hunyuan-video
type = 'ltx-video'
# Path to the Huggingface Diffusers directory for the model.
diffusers_path = '/dev/shm'
timestep_sample_method = 'logit_normal'
dtype = 'bfloat16'

[adapter]
type = 'lora'
rank = 32
# Dtype for the LoRA weights you are training.
dtype = 'bfloat16'
# You can initialize the lora weights from a previously trained lora.
#init_from_existing = '/data/diffusion_pipe_training_runs/something/epoch50'

[optimizer]
# AdamW from the optimi library is a good default since it automatically uses Kahan summation when training bfloat16 weights.
# Look at train.py for other options. You could also easily edit the file and add your own.
type = 'adamw_optimi'
lr = 2e-5
betas = [0.9, 0.99]
weight_decay = 0.01
eps = 1e-8
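
# ---- dataset.toml ----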
# Resolutions to train on, given as the side length of a square image. You can have multiple sizes here.
resolutions = [960]

# Enable aspect ratio bucketing. For the different AR buckets, the final size will be such that
# the areas match the resolutions you configured above.
enable_ar_bucket = true

# The aspect ratio and frame bucket settings may be specified for each [[directory]] entry as well.
# Directory-level settings will override top-level settings.

# Aspect ratio buckets, given as width/height ratio.
ar_buckets = [1.78, 2.35]

# Each entry can be a width/height ratio, or a (width, height) pair. But you can't mix them, because of TOML.
# ar_buckets = [[512, 512], [448, 576]]
# ar_buckets = [1.0, 1.5]

# For video training, you need to configure frame buckets (similar to aspect ratio buckets). There will always
# be a frame bucket of 1 for images. Each video is assigned to the first frame bucket whose length it meets or exceeds.
# But videos are never assigned to the image frame bucket (1); a video too short for any other bucket is simply dropped.
frame_buckets = [1, 60, 72, 96, 120]


[[directory]]
# Path to directory of images/videos, and corresponding caption files. The caption files should match the media file name, but with a .txt extension.
# A missing caption file will log a warning, but then just train using an empty caption.
path = '/home/kojoe/diffusion-pipe/data/data_0'
json = '/home/kojoe/diffusion-pipe/data/train.json'
# The dataset will act like it is duplicated this many times.
num_repeats = 1


# You can list multiple directories.

# [[directory]]
# path = '/home/anon/data/images/something_else'
# json = '/home/anon/data/images/something_else.json'
# num_repeats = 1
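
Regarding the video_clip_mode comments in the config above, here is a minimal sketch of what single_middle amounts to (illustrative only, not necessarily the repo's actual code):

def single_middle_clip(video_frames, bucket_frames):
    # keep bucket_frames frames from the center, trimming the start and end equally
    start = max((len(video_frames) - bucket_frames) // 2, 0)
    return video_frames[start:start + bucket_frames]

# e.g. a 100-frame video in the 72-frame bucket keeps frames 14..85
clip = single_middle_clip(list(range(100)), 72)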

tdrussell (Owner) commented

I can't reproduce this. Using LTX-Video, with resolutions=[960] and frame_buckets=[1, 72], it successfully processes and caches the latents with about 51GB peak RAM usage. This is even with leaving NUM_PROC=8, which means it's loading 8 videos in parallel.

Are you using the latest commit and the latest version of the LTX-Video model from Huggingface? And what is your --num_gpus flag for deepspeed? Each process has to load all the models (transformer, VAE, text encoder), so that could be a cause as well.
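
For reference, a typical single-GPU launch looks something like the line below (the script name and config path are assumptions based on the usual setup, not taken from this thread):

deepspeed --num_gpus=1 train.py --deepspeed --config examples/config.toml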


radna0 commented Dec 26, 2024

@tdrussell This issue is reproducible on systems using TPUs/XLA with a custom DeepSpeed backend implementation. That backend is still very experimental and I haven't made a PR to merge and review the changes fully, so maybe that's why it's having problems.

But I'm currently testing on another system with NVIDIA GPUs using the DeepSpeed CUDA backend, and with NUM_PROC anywhere from 1 to 8, everything works as expected: it uses about 50GB, as you stated. It's still very slow, though, and increasing NUM_PROC to 16, 32, or 36 (probably anything higher than 8) makes dataset loading take significantly longer, up to minutes for just the first step; afterwards it only seems to utilize as many workers as NUM_PROC=8. Memory usage is a bit higher, around 100-130GB, but in terms of speed it's the same or even slower.
