Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 21 additions & 7 deletions RLTest/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ def split_by_semicolon(s):

def args_list_to_dict(args_list):
def dicty(args):
return dict((seq.split(' ')[0], seq) for seq in args)
return {seq.split(' ')[0].upper(): seq for seq in args}
return list(map(lambda args: dicty(args), args_list))

def join_lists(lists):
Expand All @@ -105,18 +105,32 @@ def fix_modulesArgs(modules, modulesArgs, defaultArgs=None, haveSeqs=True):
# ['args ...', ...]: arg list for a single module
# [['arg', ...', ...], ...]: arg strings for multiple modules

# arg string is a string of words seperated by whitespace
# arg string can be seperated by semicolons into (logical) arg lists.
# arg string is a string of words separated by whitespace.
# arg string can be separated by semicolons into (logical) arg lists.
# semicolons can be escaped with a backslash.
# if no semicolons are present, the string is treated as space-separated key-value pairs,
# where each consecutive pair of words forms a 'KEY VALUE' arg.
# thus, 'K1 V1 K2 V2' becomes ['K1 V1', 'K2 V2']
# an odd number of words without semicolons is an error.
# for args with multiple values, semicolons are required:
# thus, 'K1 V1; K2 V2 V3' becomes ['K1 V1', 'K2 V2 V3']
# arg list is a list of arg strings.
# arg list starts with an arg name that can later be used for argument overriding.
# arg strings are transformed into arg lists (haveSeqs parameter controls this behavior):
# thus, 'num 1; names a b' becomes ['num 1', 'names a b']

if type(modulesArgs) == str:
# case # 'args ...': arg string for a single module
# transformed into [['arg', ...]]
modulesArgs = [split_by_semicolon(modulesArgs)]
parts = split_by_semicolon(modulesArgs)
if len(parts) == 1:
# No semicolons found - treat as space-separated key-value pairs
words = parts[0].split()
if len(words) % 2 != 0:
print(Colors.Bred(f"Error in args: odd number of words in key-value pairs: '{modulesArgs}'. "
f"Use semicolons to separate args with multiple values (e.g. 'KEY1 V1; KEY2 V2 V3')."))
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: f-string not needed here.

sys.exit(1)
if len(words) > 2:
parts = [f"{words[i]} {words[i + 1]}" for i in range(0, len(words), 2)]
modulesArgs = [parts]
elif type(modulesArgs) == list:
args = []
is_list = False
Expand Down Expand Up @@ -180,7 +194,7 @@ def fix_modulesArgs(modules, modulesArgs, defaultArgs=None, haveSeqs=True):
modules_args_dict = args_list_to_dict(modulesArgs)
for imod, args_list in enumerate(defaultArgs):
for arg in args_list:
name = arg.split(' ')[0]
name = arg.split(' ')[0].upper()
if name not in modules_args_dict[imod]:
modulesArgs[imod] += [arg]

Expand Down
110 changes: 110 additions & 0 deletions tests/unit/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
from unittest import TestCase

from RLTest.utils import fix_modulesArgs


class TestFixModulesArgs(TestCase):

# 1. Single key-value pair string
def test_single_key_value_pair(self):
result = fix_modulesArgs(['/mod.so'], 'WORKERS 4')
self.assertEqual(result, [['WORKERS 4']])

# 2. Multiple key-value pairs without semicolons (new behavior)
def test_multiple_kv_pairs_no_semicolons(self):
result = fix_modulesArgs(['/mod.so'], '_FREE_RESOURCE_ON_THREAD FALSE TIMEOUT 80 WORKERS 4')
self.assertEqual(result, [['_FREE_RESOURCE_ON_THREAD FALSE', 'TIMEOUT 80', 'WORKERS 4']])

# 3. Semicolon-separated args (existing behavior)
def test_semicolon_separated_args(self):
result = fix_modulesArgs(['/mod.so'], 'KEY1 V1; KEY2 V2')
self.assertEqual(result, [['KEY1 V1', 'KEY2 V2']])

# 4a. Odd number of words without semicolons - should error
def test_odd_words_no_semicolons_exits(self):
with self.assertRaises(SystemExit):
fix_modulesArgs(['/mod.so'], 'FLAG TIMEOUT 80')

# 4b. Odd number of words with semicolons - valid, semicolons split first
def test_odd_words_with_semicolons_valid(self):
result = fix_modulesArgs(['/mod.so'], 'FLAG; TIMEOUT 80')
self.assertEqual(result, [['FLAG', 'TIMEOUT 80']])

# 5a. Space-separated string overrides matching defaults, non-matching defaults added
def test_space_separated_overrides_defaults(self):
defaults = [['WORKERS 8', 'TIMEOUT 60', 'EXTRA 1']]
result = fix_modulesArgs(['/mod.so'], 'WORKERS 4 TIMEOUT 80', defaults)
result_dict = {arg.split(' ')[0]: arg for arg in result[0]}
self.assertEqual(result_dict['WORKERS'], 'WORKERS 4')
self.assertEqual(result_dict['TIMEOUT'], 'TIMEOUT 80')
self.assertEqual(result_dict['EXTRA'], 'EXTRA 1')

# 5b. Semicolon-separated string overrides matching defaults
def test_semicolon_separated_overrides_defaults(self):
defaults = [['WORKERS 8', 'TIMEOUT 60', 'EXTRA 1']]
result = fix_modulesArgs(['/mod.so'], 'WORKERS 4; TIMEOUT 80', defaults)
result_dict = {arg.split(' ')[0]: arg for arg in result[0]}
self.assertEqual(result_dict['WORKERS'], 'WORKERS 4')
self.assertEqual(result_dict['TIMEOUT'], 'TIMEOUT 80')
self.assertEqual(result_dict['EXTRA'], 'EXTRA 1')

# 5c. Space-separated explicit overrides some defaults, non-overlapping defaults are merged
def test_space_separated_partial_override_with_defaults(self):
defaults = [['_FREE_RESOURCE_ON_THREAD TRUE', 'TIMEOUT 100', 'WORKERS 8']]
result = fix_modulesArgs(['/mod.so'], 'WORKERS 4 TIMEOUT 80', defaults)
result_dict = {arg.split(' ')[0]: arg for arg in result[0]}
self.assertEqual(result_dict['WORKERS'], 'WORKERS 4')
self.assertEqual(result_dict['TIMEOUT'], 'TIMEOUT 80')
self.assertEqual(result_dict['_FREE_RESOURCE_ON_THREAD'], '_FREE_RESOURCE_ON_THREAD TRUE')

# 6. None input with defaults - deep copy of defaults
def test_none_uses_defaults(self):
defaults = [['WORKERS 8', 'TIMEOUT 60']]
result = fix_modulesArgs(['/mod.so'], None, defaults)
self.assertEqual(result, defaults)
# Verify it's a deep copy
result[0][0] = 'MODIFIED'
self.assertEqual(defaults[0][0], 'WORKERS 8')

# 7. List of strings with defaults - overlapping and non-overlapping keys
def test_list_of_strings_with_defaults(self):
defaults = [['K1 default1', 'K2 default2', 'K4 default4']]
result = fix_modulesArgs(['/mod.so'], ['K1 override1', 'K2 override2', 'K3 new3'], defaults)
result_dict = {arg.split(' ')[0]: arg for arg in result[0]}
self.assertEqual(result_dict['K1'], 'K1 override1')
self.assertEqual(result_dict['K2'], 'K2 override2')
self.assertEqual(result_dict['K3'], 'K3 new3')
self.assertEqual(result_dict['K4'], 'K4 default4')

# 8. List of lists (multi-module) with defaults - overlapping and non-overlapping keys
def test_multi_module_with_defaults(self):
modules = ['/mod1.so', '/mod2.so']
explicit = [['K1 v1', 'K2 v2'], ['K3 v3']]
defaults = [['K1 d1', 'K5 d5'], ['K3 d3', 'K4 d4']]
result = fix_modulesArgs(modules, explicit, defaults)
# Module 1: K1 overridden, K5 added from defaults
dict1 = {arg.split(' ')[0]: arg for arg in result[0]}
self.assertEqual(dict1['K1'], 'K1 v1')
self.assertEqual(dict1['K2'], 'K2 v2')
self.assertEqual(dict1['K5'], 'K5 d5')
# Module 2: K3 overridden, K4 added from defaults
dict2 = {arg.split(' ')[0]: arg for arg in result[1]}
self.assertEqual(dict2['K3'], 'K3 v3')
self.assertEqual(dict2['K4'], 'K4 d4')


# 9. Case-insensitive matching between explicit args and defaults (both directions)
def test_case_insensitive_override(self):
# Uppercase explicit overrides lowercase defaults
defaults = [['workers 8', 'timeout 60', 'EXTRA 1', 'MIxEd 7', 'lower true']]
result = fix_modulesArgs(['/mod.so'], 'WORKERS 4 TIMEOUT 80 miXed 0 LOWER false', defaults)
result_dict = {arg.split(' ')[0]: arg for arg in result[0]}
self.assertEqual(result_dict['WORKERS'], 'WORKERS 4')
self.assertEqual(result_dict['TIMEOUT'], 'TIMEOUT 80')
self.assertEqual(result_dict['EXTRA'], 'EXTRA 1')
self.assertEqual(result_dict['miXed'], 'miXed 0')
self.assertEqual(result_dict['LOWER'], 'LOWER false')
self.assertNotIn('workers', result_dict)
self.assertNotIn('timeout', result_dict)
self.assertNotIn('MIxEd', result_dict)
self.assertNotIn('lower', result_dict)
Loading