Python Ray: passing non-trivial objects to workers causes memory overflow

Topic: Memory overflow caused by a small amount of data

Use case: I have instances of objects that do some work on data, and these instances need to be passed to the workers along with the data. I'm currently testing this on a single machine (EC2 c6i.12xlarge, Ubuntu 18.04).

Problem: The object instances cause a memory overflow even though the data and the instances are only a couple of MB in size. I found that when the instances use third-party libraries such as nltk, memory grows quickly with the number of CPUs used. When I don't use those libraries, everything works as expected.
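One way to check what actually travels with each task is to compare serialized sizes. Below is a minimal sketch using the standard library's pickle, with `json` standing in for nltk; the helper names `pickled_size_bytes` and `make_payload` are hypothetical. Ray itself serializes with cloudpickle, which ships classes defined in `__main__` by value, so exact sizes will differ. As far as I understand, though, modules referenced by a method are still recorded by name and imported on the receiving side rather than copied. If the serialized instance is only a few dozen bytes, per-CPU growth is more plausibly the cost of each worker process importing the library than the cost of shipping the object.

```python
import json  # stand-in for a heavy third-party library such as nltk
import pickle


class DummyObject:
    """Mirrors the question's DummyObject, with json standing in for nltk."""

    def do_something(self):
        print(json.__name__)


def pickled_size_bytes(obj) -> int:
    """Hypothetical helper: number of bytes pickle produces for obj."""
    return len(pickle.dumps(obj))


def make_payload(entries: int = 100_000, length: int = 51) -> list:
    # Same shape as the question's create_data: many separate ASCII strings.
    return ["a" * length for _ in range(entries)]


instance_size = pickled_size_bytes(DummyObject())
payload_size = pickled_size_bytes(make_payload())
print(f"instance: {instance_size} bytes, payload: {payload_size} bytes")
```

The instance pickles as little more than a reference to its class, while the payload grows with the number of strings.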

Expected behavior: Memory usage should not increase linearly with the number of CPUs.

Minimal example: The minimal example and its output are included below. When I pass only the data (10 MB in the example) to the workers, without the object instance, the memory overhead is negligibly small. When I pass only the instance, without data, the memory overhead scales almost linearly (1 CPU: 6 MB, 2 CPUs: 11 MB, 10 CPUs: 60 MB), so it seems some package information is passed to every CPU along with the object instance, which is fine. However, when I pass both the data (10 MB) and the object instance, the data also seems to be copied multiple times (1 CPU: 20 MB, 10 CPUs: 180 MB). Since I want to run on 30-50 CPUs on a single machine with a couple of GB of data, this causes a memory overflow.

Questions: How can I pass instances of objects that depend on third-party libraries to the workers without the behavior described above? Is there a best practice for handling small global variables that differs from putting them in the object store?
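A common pattern for the first question is to build the library-dependent object once per worker process and send only the data with each task. The sketch below uses the standard library rather than Ray: `multiprocessing.Pool` runs its `initializer` once in every worker. In Ray, the analogous construct would be an actor (a class decorated with `@ray.remote`) that creates the object in its `__init__`. The names `Processor`, `_init_worker`, `_process_chunk` and `run_pool` are hypothetical, and the sketch assumes a Linux host (fork start method), as in the question's Ubuntu setup. Note that the library's import cost is still paid once per worker process; only the repeated shipping of the object goes away.

```python
import multiprocessing

_processor = None  # per-process global, set once by the pool initializer


class Processor:
    """Hypothetical stand-in for an object that wraps a heavy library like nltk."""

    def __init__(self):
        # A real version would import/load the heavy library here, once per process.
        self.suffix = "!"

    def run(self, text: str) -> str:
        return text.upper() + self.suffix


def _init_worker():
    # Runs once in each worker process when the pool starts.
    global _processor
    _processor = Processor()


def _process_chunk(chunk):
    # Tasks receive only data; the processor already lives in the worker.
    return [_processor.run(item) for item in chunk]


def run_pool(chunks, num_workers=2):
    ctx = multiprocessing.get_context("fork")  # assumes Linux, as in the question
    with ctx.Pool(processes=num_workers, initializer=_init_worker) as pool:
        return pool.map(_process_chunk, chunks)


results = run_pool([["a", "b"], ["c"]])
print(results)  # [['A!', 'B!'], ['C!']]
```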

import nltk
import psutil
import ray


class DummyObject:

    def do_something(self):
        # Uses a third-party library, which each worker process has to import
        print(nltk.__version__)


@ray.remote
def dummy_fun(*args):
    # No-op task: only argument transfer and worker start-up are measured
    pass


def create_data(target_size_mb=10):
    """
    Create some dummy data
    :param target_size_mb: nominal payload size used to scale the number of rows
    :return: list of identical-looking ASCII strings
    """
    # Create a list of strings
    data_entries = 80000 * target_size_mb  # Number of rows
    size_per_entry = 100  # Byte size per entry, as reported by sys.getsizeof
    length_string = size_per_entry - 49  # An empty str takes 49 bytes in 64-bit CPython
    # Note: 80000 rows * 100 bytes = 8,000,000 bytes per unit of target_size_mb
    payload = ['a' * length_string for _ in range(data_entries)]  # Create payload as specified
    return payload


def run_problem(payload=None, config=None, num_cpu=1):
    tasks = num_cpu  # One task per CPU

    # Init ray
    ray.init(num_cpus=num_cpu)

    # Put the payload and config in the object store
    payload_id = ray.put(payload)
    config_id = ray.put(config)

    # Track memory in a naive, system-wide way; psutil reports bytes
    start_memory = psutil.virtual_memory().used

    # Create jobs
    result_id = [dummy_fun.remote(config_id, payload_id) for _ in range(tasks)]

    # Run jobs
    result = ray.get(result_id)

    end_memory = psutil.virtual_memory().used
    # Note: dividing by 8 treats the byte delta as bits, so the printed value is 1/8 of the delta in MB
    print('Memory usage {} MB'.format((end_memory - start_memory) / 8 / 1000 / 1000))

    ray.shutdown()


print("Payload: None \t config: Dummy Object")
run_problem(payload=None, config=DummyObject)
print("-" * 100)

print("Payload: 10 MB \t config: None")
run_problem(payload=create_data(target_size_mb=10), config=None)
print("-" * 100)

print("Payload: 10 MB \t config: Dummy Object")
run_problem(payload=create_data(target_size_mb=10), config=DummyObject)
print("-" * 100)

Output:

Payload: None    config: Dummy Object
Memory usage 5.612544 MB
----------------------------------------------------------------------------------------------------
Payload: 10 MB   config: None
Memory usage 0.23705600000000002 MB
----------------------------------------------------------------------------------------------------
Payload: 10 MB   config: Dummy Object
Memory usage 20.628991999999997 MB
----------------------------------------------------------------------------------------------------

Process finished with exit code 0
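`psutil.virtual_memory()` reports system-wide figures in bytes, so it also picks up every other process on the host, and converting bytes to MB only needs a division by 10**6. Below is a minimal sketch of a per-process measurement with the standard library's `resource` module (Unix only). A Ray task could call a helper like the hypothetical `peak_rss_mb` and return the value, which would let the memory of each worker be compared directly; `bytes_to_mb` is likewise a hypothetical helper.

```python
import resource
import sys


def peak_rss_mb() -> float:
    """Peak resident set size of the current process, in MB (10**6 bytes)."""
    peak = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss
    # ru_maxrss is reported in kilobytes on Linux but in bytes on macOS.
    peak_bytes = peak if sys.platform == "darwin" else peak * 1024
    return peak_bytes / 1_000_000


def bytes_to_mb(n_bytes: int) -> float:
    # psutil and most OS counters report bytes; no division by 8 is needed.
    return n_bytes / 1_000_000


before = peak_rss_mb()
blob = ["a" * 51 for _ in range(100_000)]  # allocate some memory in this process
after = peak_rss_mb()
print(f"peak RSS grew by about {after - before:.1f} MB")
```

Because `ru_maxrss` is a peak value, it never decreases, so it answers "how large did this worker get" rather than "how much is it using right now".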

