From 33b04bb32ab4cdce3e68f989927e926b09e42289 Mon Sep 17 00:00:00 2001 From: drryanhuang Date: Sat, 21 Jan 2023 02:02:56 +0800 Subject: [PATCH 1/2] pickle5 --- parl/remote/communication.py | 18 ++++-------------- parl/remote/job.py | 4 ++-- parl/remote/monitor.py | 4 ++-- parl/remote/tests/log_server_test.py | 4 ++-- 4 files changed, 10 insertions(+), 20 deletions(-) diff --git a/parl/remote/communication.py b/parl/remote/communication.py index 467f21a8e..c5bd93adb 100644 --- a/parl/remote/communication.py +++ b/parl/remote/communication.py @@ -15,6 +15,7 @@ import cloudpickle import subprocess import os +import pickle5 from parl.utils import SerializeError, DeserializeError __all__ = ['dumps_argument', 'loads_argument', 'dumps_return', 'loads_return'] @@ -35,21 +36,10 @@ def _deserialize_serializable(obj): val.__dict__.update(obj["data"]) return val - context = pyarrow.default_serialization_context() + buffer_list = [] - # support deserialize in another environment - context.set_pickle(cloudpickle.dumps, cloudpickle.loads) - - # support serialize and deserialize custom class - context.register_type( - object, - "object", - custom_serializer=_serialize_serializable, - custom_deserializer=_deserialize_serializable) - - # if pyarrow is installed, parl will use pyarrow to serialize/deserialize objects. - serialize = lambda data: pyarrow.serialize(data, context=context).to_buffer() - deserialize = lambda data: pyarrow.deserialize(data, context=context) + serialize = lambda data: pickle5.dumps(data, protocol=5, buffer_callback=buffer_list.append) + deserialize = lambda data: pickle5.loads(data, buffers=buffer_list) else: # if pyarrow is not installed, parl will use cloudpickle to serialize/deserialize objects. serialize = lambda data: cloudpickle.dumps(data) diff --git a/parl/remote/job.py b/parl/remote/job.py index 61e36dd4e..2ba3aaf69 100644 --- a/parl/remote/job.py +++ b/parl/remote/job.py @@ -22,7 +22,7 @@ import argparse import cloudpickle -import pickle +import pickle5 import psutil import re import sys @@ -240,7 +240,7 @@ def wait_for_files(self, reply_socket, job_address): message = reply_socket.recv_multipart() tag = message[0] if tag == remote_constants.SEND_FILE_TAG: - pyfiles = pickle.loads(message[1]) + pyfiles = pickle5.loads(message[1]) envdir = tempfile.mkdtemp() for empty_subfolder in pyfiles['empty_subfolders']: diff --git a/parl/remote/monitor.py b/parl/remote/monitor.py index 452888940..f25488b08 100644 --- a/parl/remote/monitor.py +++ b/parl/remote/monitor.py @@ -13,7 +13,7 @@ # limitations under the License. import argparse -import pickle +import pickle5 import random import time import zmq @@ -57,7 +57,7 @@ def run(self): self.socket.send_multipart([b'[MONITOR]']) msg = self.socket.recv_multipart() - status = pickle.loads(msg[1]) + status = pickle5.loads(msg[1]) data = {'workers': [], 'clients': []} total_vacant_cpus = 0 total_used_cpus = 0 diff --git a/parl/remote/tests/log_server_test.py b/parl/remote/tests/log_server_test.py index f9a81fedd..5c5e22919 100644 --- a/parl/remote/tests/log_server_test.py +++ b/parl/remote/tests/log_server_test.py @@ -15,7 +15,7 @@ import json import multiprocessing import os -import pickle +import pickle5 import subprocess import sys import tempfile @@ -86,7 +86,7 @@ def test_log_server(self): # Get status status = master._get_status() - client_jobs = pickle.loads(status).get('client_jobs') + client_jobs = pickle5.loads(status).get('client_jobs') self.assertIsNotNone(client_jobs) # Get job id From 3ecfad68d0f5cae6117110abac78157fc2131bb9 Mon Sep 17 00:00:00 2001 From: drryanhuang Date: Sat, 21 Jan 2023 10:02:17 +0800 Subject: [PATCH 2/2] update setup pickle5 --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 8165837e9..ec5cbd3b7 100644 --- a/setup.py +++ b/setup.py @@ -80,6 +80,7 @@ def find_version(*file_paths): "tensorboard", "flask>=1.0.4", "click", + "pickle5", "psutil>=5.6.2", "flask_cors", "requests",