With dataloader RSS memory consumed by HF datasets monotonically increases #4883
Are you sure there is a leak? How can I see it? You shared the script but not the output that you believe indicates a leak. I modified your reproduction script to print only once per try, as your original was printing too much info, and you absolutely must add
Now running it:
It's normal to see a small growth in memory usage at the beginning, but after 5 cycles it becomes steady.
Unless, of course, you're referring to the memory growth during the first try. Is that what you're referring to? And since your ds is small, it's hard to see the growth. Could it be just because some records are longer and it needs to allocate more memory for those? While experimenting with this, though, I observed a peculiar thing: if I concatenate 2 datasets, I don't see any growth at all. But that's probably because the program allocated additional peak RSS memory to concatenate and is then re-using that memory. I basically tried to see what happens if I make the dataset much longer: I'd expect not to see any memory growth once the 780 records of the imdb ds have been processed once.
It is hard to say if it is directly reproducible in this setup. Maybe the memory leak is specific to the images stored in the CM4 case. I am still running your script to see if I can reproduce that particular leak in this case.
I was able to reproduce the leak with:
You need to adjust the DATASET_PATH setting, which you get from
(I assume the HF folks have the permissions). It's a smallish dataset (10k). Then you run:
You should be able to see the leak.
This issue has nothing to do with
I then traced this leak to this single call: datasets/src/datasets/formatting/formatting.py, lines 138 to 140 in 08a7b38
I can make it leak much faster by modifying that code to repeat
@lhoestq, do you know what might be happening inside
Probably we next need to remove
The problem already happens with
I'm also trying to dig in with
I think this is the same issue as #4528.
I went all the way back to
Could it be that the leak is in some 3rd-party component?
I also found this warning:
Perhaps something triggers this condition? I have no idea if it's related; this is just something that came up during my research.
Does it crash with OOM at some point? If it doesn't, it isn't a leak, just aggressive caching or a custom allocator that doesn't like to give memory back (not uncommon). #4528 looks like it hits a steady state. I believe the underlying Arrow libs use a custom C allocator. Some of those are designed not to give memory back to the OS, but to keep heap memory for themselves to re-use (hitting up the OS involves more expensive mutex locks, contention, etc.). The greedy behaviour can be undesirable, though. There are likely flags to change the allocator behaviour, and one could likely build without any custom allocators (or use a different one).
In the original setup where we noticed this problem, it did indeed end in an OOM.
@rwightman, in the plot I shared, the steady state comes from the
Could this be related to this discussion about a potential memory leak in pyarrow: https://issues.apache.org/jira/browse/ARROW-11007 ? (Note: I've tried
The Arrow team is pretty responsive at [email protected], if that helps.
That would be ideal indeed. I'd be happy to help with this; can you give me access to the bucket so I can try with your data?
I added you to the bucket, @lhoestq.
It looks like an issue with memory mapping:
Here is code to reproduce this issue using only PyArrow and a dummy arrow file:

```python
import gc
import os
import time

import psutil
import pyarrow as pa

ARROW_PATH = "tmp.arrow"

# Create a ~200MB Arrow IPC stream file once: 1000 rows of 200KB each
if not os.path.exists(ARROW_PATH):
    arr = pa.array([b"a" * (200 * 1024)] * 1000)  # ~200MB
    table = pa.table({"a": arr})
    with open(ARROW_PATH, "wb") as f:
        writer = pa.RecordBatchStreamWriter(f, schema=table.schema)
        writer.write_table(table)
        writer.close()


def memory_mapped_arrow_table_from_file(filename: str) -> pa.Table:
    # Zero-copy read: the table's buffers point into the memory-mapped file
    memory_mapped_stream = pa.memory_map(filename)
    opened_stream = pa.ipc.open_stream(memory_mapped_stream)
    pa_table = opened_stream.read_all()
    return pa_table


table = memory_mapped_arrow_table_from_file(ARROW_PATH)
arr = table[0]  # column "a" as a ChunkedArray

mem_before = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
for idx, x in enumerate(arr):
    # Report the RSS delta (in MB) every 100 rows
    if idx % 100 == 0:
        gc.collect()
        time.sleep(0.1)
        mem_after = psutil.Process(os.getpid()).memory_info().rss / (1024 * 1024)
        print(f"{idx:4d} {mem_after - mem_before:12.4f}MB")
```

This prints:
Note that this example simply iterates over the
@lhoestq, that does indeed increase memory usage, but if you iterate over the array again after the first time, or re-open and remap the same file (repeat
Are the pa_tables held on to anywhere after they are iterated in the real code? In my hack, if you do a bunch of cut & paste and change the `arr` name for each iteration, it leaks. If all the arrays have the same name (so the previous one gets cleaned up), it does not leak and memory goes back to 0. Is there anything that could be holding onto a reference to an intermediate equivalent of `arr` in the real use case?
Yes, we have already established here #4883 (comment) that when one iterates over the whole dataset multiple times, it consumes a bit more memory in the next few repetitions and then remains steady. This means that when a new iterator is created over the same dataset, all the memory from the previous iterator is re-used. So the leak happens primarily when the iterator is "drained" the first time. This tells me that either a circular reference is created somewhere that only gets released when the iterator is destroyed, or some global variable keeps piling up memory and doesn't release it in time.
Also I noticed some
There are also some
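Here is a minimal, self-contained illustration of the circular-reference hypothesis. It assumes plain CPython and uses nothing from datasets or pyarrow. An object that references itself is not freed when its last external reference goes away; it is only freed when the cycle collector runs. `Batch` and `make_batch` are hypothetical names chosen for the demonstration.

```python
import gc
import weakref


class Batch:
    """Hypothetical holder of a large buffer that accidentally references itself."""

    def __init__(self, size: int):
        self.payload = bytearray(size)
        self.me = self  # reference cycle: the refcount never drops to zero on its own


def make_batch() -> weakref.ref:
    batch = Batch(10 * 1024 * 1024)
    return weakref.ref(batch)  # a weak reference does not keep the object alive


gc.disable()  # make the demo deterministic: no automatic collection
try:
    ref = make_batch()
    alive_after_return = ref() is not None  # still alive: the cycle keeps it referenced
    gc.collect()  # explicit collections still run while automatic GC is disabled
    alive_after_collect = ref() is not None
finally:
    gc.enable()

print(alive_after_return, alive_after_collect)  # True False
```

In a real pipeline the "collection" point is less predictable. Memory held by cycles like this can therefore look like steady growth until the garbage collector happens to run.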
@stas00, my point was that I'm not convinced @lhoestq's last example illustrates the leak. I think it illustrates the difference between memory-mapped and in-memory usage patterns instead. If you destroy `arr`, the memory-map implementation goes back to 0 each iteration. The amount of memory that 'looks' leaked in the first pass differs quite a bit between memory-mapped and in-memory. The underlying issue is likely a circular reference, or reference(s) that were not cleaned up, which would impact either case but is likely much more visible with mmap.
Thank you for clarifying, Ross. I think we agree that it's almost certain that the
I wish there were a way on Linux to tell the program to free no-longer-used memory at will.
FWIW, I revisited some code I had in the works to use HF datasets w/ timm train & val scripts. There is no leak there across multiple epochs. It uses the defaults. It's worth noting that with imagenet
Notes
After reading many issues and trying many things, here is a summary of what I learned. I'm now using @lhoestq's repro case, as it's pyarrow-isolated: #4883 (comment)
1. pyarrow memory backends
It has 3 backends; I tried them all with the same results.
2. quick release
The
It doesn't make any difference in this case.
3. actual memory allocations
This is a useful tracer for PA memory allocators:
It nicely reports memory allocations and releases when the arrow file is created the first time, but when we then try to do
This summary also reports no allocations when the script is run the second time (post file creation):
However, it's easy to see by using (this is bolted on top of the original repro script)
gives:
The 3rd and 4th columns are
the 5th column is the size of the mmapped stream (fixed). The last 2 are PA's malloc reports; you can see they're totally fixed at 0.
So what gives? PA's memory allocator says nothing was allocated, and we can see Python doesn't allocate any memory either. Someone suggested in one of the PA issues that IPC/gRPC could be the issue. Any suggestions on how to debug this one? The main issue is that one can't step through with a Python debugger as
Please see the next comment for a possible answer.
ref-count
I also traced reference counts and they are all fixed using either
so it's not the Python object
Important related discussions
https://issues.apache.org/jira/browse/ARROW-11007 - looks very similar to our issue
There is no leak, just badly communicated Linux RSS memory usage stats
Next, let's revisit @rwightman's suggestion that there is actually no leak. After all, we are using mmap, which will try to map the file into RAM as much as it can and then page out if there is no memory; i.e., mmap is only fast if you have a lot of CPU RAM. So let's test it:
Memory mapping OOM test
We first quickly start a cgroups-controlled shell which will instantly kill any program that consumes more than 1GB of memory:
Let's check that it indeed does so. Let's change @lhoestq's script to allocate a 10GB arrow file:
Oops, that didn't work, as we tried to allocate 10GB when only 1GB is allowed. This is what we want! Let's do a sanity check: can we allocate 0.1GB?
Yes. So the limited shell does the right thing. It lets us allocate
Next, let's go back to @lhoestq's script but with a 10GB arrow file. We change his repro script #4883 (comment) to use a 50x larger file:
We first have to run it in a normal unlimited shell so that we don't get killed (as the script allocates 10GB). Now let's run the script in the 1GB-limited shell while running a monitor:
So we have 2 sources of RSS info, just in case.
But wait, it reported 10GB RSS both in
So that means it never actually allocated 10GB, otherwise it would have been killed. This tells us that there is no leak whatsoever. This is just a really difficult situation where mmapped memory is reported as part of RSS, which it probably shouldn't be. So now we have no way to measure real memory usage. I also attached the script with all the different things I have tried in it, so it should be easy to turn them on/off if you want to reproduce any of my findings. Just rename it to
(I have to remember to exit that special mem-limited shell or else I won't be able to do anything serious there.)
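One partial way to separate real allocations from mapped file pages is available on Linux. This assumes kernel 4.5 or later on the host. There, `/proc/self/status` splits RSS into three fields:
- `RssAnon`: anonymous memory such as the malloc'ed heap.
- `RssFile`: file-backed pages, which is where memory-mapped Arrow files land.
- `RssShmem`: shared memory and tmpfs.

Watching `RssAnon` instead of the total is one partial way to tell them apart. The sketch below maps a temporary file and touches every page; the file size and helper names are illustrative only. If the temp directory lives on tmpfs, the pages are counted under `RssShmem` rather than `RssFile`, so the sketch looks at both.

```python
import mmap
import os
import tempfile

FIELDS = ("VmRSS", "RssAnon", "RssFile", "RssShmem")


def rss_breakdown_kb() -> dict:
    """Read the RSS breakdown (in kB) from /proc/self/status (Linux >= 4.5)."""
    stats = {}
    with open("/proc/self/status") as f:
        for line in f:
            key, _, value = line.partition(":")
            if key in FIELDS:
                stats[key] = int(value.split()[0])
    return stats


def touch_mmapped_file(size_mb: int = 64):
    """Memory-map a temporary file, read one byte per page, and return the RSS deltas."""
    with tempfile.NamedTemporaryFile(delete=False) as tmp:
        tmp.write(b"a" * (size_mb * 1024 * 1024))
        path = tmp.name
    try:
        before = rss_breakdown_kb()
        with open(path, "rb") as f, mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ) as mm:
            checksum = 0
            for offset in range(0, len(mm), mmap.PAGESIZE):
                checksum += mm[offset]  # faults the page into this process's RSS
            after = rss_breakdown_kb()
    finally:
        os.unlink(path)
    deltas = {key: after[key] - before[key] for key in before}
    return deltas, checksum


if os.path.exists("/proc/self/status"):
    deltas, checksum = touch_mmapped_file(64)
    # Nearly all of the growth is file-backed (RssFile, or RssShmem on tmpfs), not RssAnon
    print(deltas)
```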
The original leak in the multi-modal code is very likely something else. But of course, now it would be very difficult to trace it using mmap. I think to debug it we have to set
To add to what @stas00 found, I'm going to leave some links to where I believe the confusion came from in pyarrow's APIs, for future reference:
And where their example shows 0 RSS memory allocation, the way we used to measure RSS shows 39.6719MB allocated. Here's the script to reproduce:
gives:
This again shows that we measure RSS incorrectly in the case of mmapped memory.
@lhoestq, I have been working on a detailed article showing that MMAP doesn't leak, and it's mostly ready. I will share it when it's done. The issue is that we still need to be able to debug memory leaks by turning MMAP off. But once I tried to show the user that using
Here is the repro:
As I suggested on Slack, perhaps it was due to variation in dataset record length. So with your help I wrote another repro with synthetic records that are all identical. That should remove my hypothesis from the equation, and we should expect 0 incremental growth as we iterate over the dataset. But alas, this is not the case: there is a tiny but definite leak-like behavior. Here is the new repro:
And the run:
So I'm still not sure why we get this. As you can see, I started skipping the first few iterations, where memory isn't stable yet, since the actual diff is much larger if we count all iterations. What do you think?
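The "skip the warm-up iterations, then look at the delta" measurement can be packaged as a small harness. This is a generic sketch using only the standard library on Linux (`/proc/self/statm`). `leaky_epoch` and `clean_epoch` are hypothetical stand-ins for one pass over a dataset. The epoch count and warm-up length are arbitrary choices, not values from the repro.

```python
import gc
import os


def current_rss_kb() -> int:
    """Current resident set size of this process in kB (Linux, via /proc/self/statm)."""
    with open("/proc/self/statm") as f:
        resident_pages = int(f.read().split()[1])
    return resident_pages * os.sysconf("SC_PAGE_SIZE") // 1024


def steady_state_growth_kb(run_epoch, epochs: int = 10, warmup: int = 3):
    """Run `run_epoch` repeatedly; return RSS growth (kB) after the warm-up epochs, plus all samples."""
    samples = []
    for _ in range(epochs):
        run_epoch()
        gc.collect()
        samples.append(current_rss_kb())
    steady = samples[warmup:]
    return steady[-1] - steady[0], samples


_leaked = []  # hypothetical accumulator that is never cleared


def leaky_epoch() -> None:
    # b"x" * n writes every byte, so the pages really become resident
    _leaked.append(b"x" * (5 * 1024 * 1024))


def clean_epoch() -> None:
    for _ in range(10):
        buf = b"x" * (1024 * 1024)  # freed again on the next iteration
        del buf


if os.path.exists("/proc/self/statm"):
    leak_growth, _ = steady_state_growth_kb(leaky_epoch)
    clean_growth, _ = steady_state_growth_kb(clean_epoch)
    print(f"leaky: +{leak_growth} kB, clean: +{clean_growth} kB after warm-up")
```

`b"x" * n` is used rather than `bytes(n)` on purpose. The latter may be backed by zero pages that are never touched and so never show up in RSS at all, which is itself an example of how RSS can mislead.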
@stas00, my 2 cents from having looked at a LOT of memory leaks over the years, esp. in Python: with a .3% memory increase over that many iterations of something, it is difficult to say with certainty that it is a leak. Also, just looking at RSS makes it hard to analyze leaks. RSS can stay near constant while you are leaking. RSS is paged-in memory, so if you have a big leak, your RSS might not increase much (leaked memory tends not to get used again, so it is often paged out). Meanwhile, your virtual page allocation could be going through the roof...
Yes, that's true, but unless the leak is big, I have yet to find another measurement tool. To prove your point, here is a very simple IO-in-a-loop program that just reads the same line over and over again:
It has some other instrumentation to do mmap and accumulate data, but let's ignore that for now. Here it is running with simple non-mmap IO:
As you can see, even this super-simplistic program that just performs
If you have a better measurement tool than RSS, I'm all ears.
@stas00, if you aren't using memory maps, you should be able to clearly see the increase in the virtual memory of the process as well. Even then, it could still be challenging to determine whether it's a leak or fragmentation due to problematic allocation patterns (not uncommon with Python). Using a better memory allocator like tcmalloc via LD_PRELOAD hooks could reduce the impact of fragmentation across both Python and C libs. Not sure that plays nicely with any allocator that Arrow might use itself, though.
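To see the distinction between virtual and resident memory directly, here is a stdlib-only sketch for Linux. An anonymous mapping immediately grows `VmSize`, but `VmRSS` only grows once pages are actually written. The 64MB size is arbitrary.

```python
import mmap
import os


def vm_stats_kb() -> dict:
    """Return VmSize (virtual) and VmRSS (resident) in kB from /proc/self/status."""
    stats = {}
    with open("/proc/self/status") as f:
        for line in f:
            key, _, value = line.partition(":")
            if key in ("VmSize", "VmRSS"):
                stats[key] = int(value.split()[0])
    return stats


SIZE = 64 * 1024 * 1024

if os.path.exists("/proc/self/status"):
    before = vm_stats_kb()
    region = mmap.mmap(-1, SIZE)  # anonymous mapping: address space reserved, no pages touched
    reserved = vm_stats_kb()
    for offset in range(0, SIZE, mmap.PAGESIZE):
        region[offset] = 1  # the first write to each page makes it resident
    touched = vm_stats_kb()
    region.close()
    print("virtual growth after mmap:", reserved["VmSize"] - before["VmSize"], "kB")
    print("resident growth after mmap:", reserved["VmRSS"] - before["VmRSS"], "kB")
    print("resident growth after writing:", touched["VmRSS"] - reserved["VmRSS"], "kB")
```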
Thank you for these suggestions, Ross. The problem is that most of the time we use a bunch of non-Python libs that are bound to Python. So besides Python, one has to deal with the not-quite-controllable allocation strategies of those other components as well. So it's a super-tricky world. Good suggestion on memory fragmentation, which could definitely be one of the sources of ever-growing RSS. PyTorch's memory management utils are mostly quite excellent, and fragmentation is one of the main issues there. Projects like DeepSpeed try to solve it by pre-allocating memory themselves and then managing it tightly to avoid fragmentation, which seems to work quite well. BTW, I'm not sure if you have seen this tool I developed some years back to automatically track and report CPU and GPU memory usage in Jupyter notebooks: https://github.com/stas00/ipyexperiments I found it quite useful for detecting memory leaks. Of course it's the same RSS for CPU, but it's automated so that each cell reports the delta. One other tricky thing to measure is CPU peak memory, which it also provides, since there are often temporary leaks that lead to OOMs.
OK, I ended up compiling the different stages of the research into an article, including a recommendation on how to remove the
@lhoestq, please have a look at your convenience and let's discuss how we use that in the
Of course, I'm open to other options. (I still need to proofread it, and it could surely use an editing pass; I only checked that the numbers made sense, but it should be quite readable already.)
And I will paste the last section of the article here for posterity, should the original disappear:
Using a synthetic MMAP-disabled dataset to debug memory leaks
Therefore, the easiest approach is to create a synthetic dataset of the desired length with all records being the same. That way, the data is no longer a factor in the memory usage patterns, as it's always the same.
We run this program once to create the dataset, and then the second time to profile its memory usage:
This is much better. There are still tiny fluctuations due to Python, and as you can see, I skipped the first few iterations in the code while things are being set up. But otherwise, you can now easily debug the rest of your code for memory leaks, since
Of course, do not forget to flip
Thanks! Before thinking about the documentation, I'd like to make sure we can explain what happens in a correct and simple manner. When accessing memory-mapped data, the corresponding pages are loaded one at a time into main memory and increase RSS. When the pages are no longer used, they are paged out, otherwise they would fill up the physical RAM. However, RSS doesn't decrease. If you read the entire file, then the RSS will end up bigger than the file, no matter how much physical RAM you have. Is this correct?
All but the last sentence, I think. It won't be able to read more data into memory than there is actual free host memory, as those are real paged-in data. If there is no free memory, it can't read it in. The cgroups-limited shell just shows that paged-in cached MMAPed data doesn't count towards the process's actually used memory. In other words, suppose your host has 8GB of RAM and 0 swap, all of that RAM is free, and your process is unlimited. Then it'll only be able to page in 8GB of MMAPed data.
It's challenging to describe what happens in a simple manner because it's not so simple. It's system-level behaviour that depends on the system-wide configuration and all of the processes running. The OS virtual memory manager needs to manage the pages across all of the processes, incl. the mmap'd files. Pages can stay in for a while if there is little demand, or get paged out quickly if there is high demand (possibly from any of the other processes on the system). Random access patterns (i.e., needing one entry here and one there) require a new page-in for each access, or multiple pages if sequential read-ahead is enabled. If you have little system memory and are mapping large files while jumping around, you can easily end up thrashing (constantly paging in and then ejecting, which results in disk IO for mmap'd files). So what you observe as resident for a single process can change based on what other processes are doing as well. Many databases and applications that do heavy IO on large files opt to handle IO and buffer management themselves rather than leave it to the OS VMM. While mmap'd files are usually good enough, there are some pretty confusing edge-case behaviours. Just loading the buffers directly into memory will be more consistent behaviour-wise. However, that is only practical if the maximum file size is limited to a reasonable value (i.e., data is sharded into 256MB-2GB chunks).
Indeed, this is all very complex, and there is no deterministic way of tracking the real memory usage of a process as a whole. This is especially so when different unrelated components are intermixed, e.g. Python and the 3rd-party C++ libraries bound to it. Each component's memory allocations can be traced separately, e.g.
Instead of trying to explain all these complexities, I propose that all the lay user needs to know is this:
In my write-up I have crafted a few setups using code samples and cgroups-limited shells that demonstrate the above. So the above summary can be the TL;DR, and the article can be the long story for those who want an empirical demonstration. If you have about 10 minutes, please skim through my write-up. Let me know if you find any holes or unclear parts, or whether it clears things up. Thank you!
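For the Python side of "trace each component separately", the standard library's `tracemalloc` is one option. It counts only allocations made through Python's allocator, so Arrow buffers and mmapped pages do not show up in it at all. Here is a minimal sketch, with hypothetical `leaky_step`/`clean_step` functions standing in for the code under test:

```python
import tracemalloc

_cache = []  # hypothetical module-level accumulator that is never cleared


def leaky_step() -> None:
    _cache.append(bytes(1024 * 1024))  # 1 MiB that stays referenced forever


def clean_step() -> int:
    buf = bytes(1024 * 1024)  # 1 MiB released when the function returns
    return len(buf)


def traced_growth(step, repeats: int = 10) -> int:
    """Net bytes still held by Python-allocated objects after calling `step` repeatedly."""
    tracemalloc.start()
    try:
        before, _peak = tracemalloc.get_traced_memory()
        for _ in range(repeats):
            step()
        after, _peak = tracemalloc.get_traced_memory()
    finally:
        tracemalloc.stop()
    return after - before


print("leaky:", traced_growth(leaky_step), "bytes")
print("clean:", traced_growth(clean_step), "bytes")
```

Because tracemalloc ignores pages the kernel maps in for a file, it gives a view of Python-level growth that is unaffected by the mmap/RSS confusion discussed above.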
I asked around, and it seems my write-up won't fit into the HF blog, so I published it on my blog
@lhoestq, I am not sure if you still want to include any of my notes in the
Ok :) I think the explanations in #4883 (comment) should already be helpful for many users. And maybe we could redirect to your blog for the details? Then, to debug how much Arrow data is physically in RAM, we can also advise pa.total_allocated_bytes. What do you think?
That sounds like a good plan, @lhoestq!
I wanted to add one method for dealing with issues in datasets/pyarrow that seem related to this thread (I hope it helps someone; I imagine others will land here like I did). I was encoding image data on a fairly large dataset using datasets
I came up with a simple way to reduce the pressure on datasets/pyarrow during mapping/index flattening while still getting some of the benefit of parallel processing. This method was able to encode and save a 229GB dataset on a system with much less memory than that. In my case, the dataset starts as rows with file paths; during the
This method is much slower than vanilla mapping, but it enabled me to get past all the weird pyarrow issues and successfully encode a dataset that was much larger than the available memory on the instance it was running on.
Plz give us the Python code.
I don't have time to generalize or really explain each step here, but maybe this will help someone. I recently used this technique on a dataset with ~5,000,000 sentences when a straightforward dataset.map() failed in a way I've become familiar with. Note that the chosen

```python
# Assumes `dataset`, `config`, `num_processes`, `preprocess_data`, `tokenizer`, `labels`,
# `label2id`, `features` and `SAGEMAKER_CHUNK_DIR` are defined earlier in the script.
import os

from datasets import concatenate_datasets, load_from_disk

dataset_len = len(dataset)

# Find the smallest step >= config.batch_size that divides the dataset evenly
step = None
for num in range(config.batch_size, dataset_len):
    if dataset_len % num == 0:
        step = num
        break
# If there is none, drop the last row and try again
if not step:
    dataset = dataset.select(range(0, dataset_len - 1))
    dataset_len = dataset_len - 1
    for num in range(config.batch_size, dataset_len):
        if dataset_len % num == 0:
            step = num
            break
if not step:
    raise Exception(f"Could not find a step size for {dataset_len}")

batch_size = int(step / num_processes) if int(step / num_processes) >= 200 else 200
print(
    f"Initiating encoding via concatenating datasets. num_processes={num_processes}, "
    f"keep_in_memory={config.keep_in_memory}, dataset_len={dataset_len}, step={step}, batch_size={batch_size}"
)

# Encode one contiguous chunk at a time and save each encoded chunk to disk
for idx in range(0, dataset_len, step):
    print(f"Operating on {idx} -> {idx + step} of {dataset_len}")
    chunk = dataset.select(range(idx, idx + step))
    encoded_chunk = chunk.map(
        preprocess_data,
        fn_kwargs=dict(tokenizer=tokenizer, labels=labels, label2id=label2id, config=config,
                       batch_size=batch_size),
        keep_in_memory=config.keep_in_memory,
        batched=True,
        features=features,
        remove_columns=dataset.column_names,
        num_proc=num_processes,
        batch_size=batch_size,
        writer_batch_size=batch_size,
    )
    print(f"Saving {idx} -> {idx + step} chunk to disk")
    encoded_chunk.save_to_disk(f"{SAGEMAKER_CHUNK_DIR}/chunk_{idx}_{idx + step}")

print("Loading encoded chunks for concatenating")
encoded_chunks = []
parent_path, subdirs, parent_files = next(os.walk(SAGEMAKER_CHUNK_DIR))
# os.walk() lists subdirectories in arbitrary order; sort by start index to keep the original row order
for chunk in sorted(subdirs, key=lambda name: int(name.split("_")[1])):
    print(f"Loading chunk: {chunk}")
    curr_path = os.path.join(parent_path, chunk)
    encoded_chunks.append(load_from_disk(curr_path))
print(f"Concatenating {len(encoded_chunks)} datasets")
encoded_dataset = concatenate_datasets(encoded_chunks)
```
I ended up splitting my
Then I ran my function
Interestingly enough, a list comprehension does not work here and hangs too! So I went with a good old-fashioned loop. Then I concatenated the results:
Describe the bug
When HF datasets is used in conjunction with the PyTorch DataLoader, the RSS memory of the process keeps increasing when it should stay constant.
Steps to reproduce the bug
Run this snippet, which logs RSS memory, and observe its output.
Expected results
Memory should not increase after the initial setup and loading of the dataset.
Actual results
Memory continuously increases as can be seen in the log.
Environment info
datasets version: 2.3.2