|
| 1 | +from dataclasses import dataclass |
| 2 | +from enum import Enum |
1 | 3 | from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple, Union |
2 | 4 |
|
3 | | -from redis.exceptions import RedisError, ResponseError |
| 5 | +from redis.exceptions import RedisError, ResponseError, IncorrectPolicyType |
4 | 6 | from redis.utils import str_if_bytes |
5 | 7 |
|
6 | 8 | if TYPE_CHECKING: |
7 | 9 | from redis.asyncio.cluster import ClusterNode |
8 | 10 |
|
| 11 | +class RequestPolicy(Enum): |
| 12 | + ALL_NODES = 'all_nodes' |
| 13 | + ALL_SHARDS = 'all_shards' |
| 14 | + MULTI_SHARD = 'multi_shard' |
| 15 | + SPECIAL = 'special' |
| 16 | + DEFAULT_KEYLESS = 'default_keyless' |
| 17 | + DEFAULT_KEYED = 'default_keyed' |
| 18 | + |
| 19 | +class ResponsePolicy(Enum): |
| 20 | + ONE_SUCCEEDED = 'one_succeeded' |
| 21 | + ALL_SUCCEEDED = 'all_succeeded' |
| 22 | + AGG_LOGICAL_AND = 'agg_logical_and' |
| 23 | + AGG_LOGICAL_OR = 'agg_logical_or' |
| 24 | + AGG_MIN = 'agg_min' |
| 25 | + AGG_MAX = 'agg_max' |
| 26 | + AGG_SUM = 'agg_sum' |
| 27 | + SPECIAL = 'special' |
| 28 | + DEFAULT_KEYLESS = 'default_keyless' |
| 29 | + DEFAULT_KEYED = 'default_keyed' |
| 30 | + |
| 31 | +class CommandPolicies: |
| 32 | + def __init__( |
| 33 | + self, |
| 34 | + request_policy: RequestPolicy = RequestPolicy.DEFAULT_KEYLESS, |
| 35 | + response_policy: ResponsePolicy = ResponsePolicy.DEFAULT_KEYLESS |
| 36 | + ): |
| 37 | + self.request_policy = request_policy |
| 38 | + self.response_policy = response_policy |
| 39 | + |
| 40 | +PolicyRecords = dict[str, dict[str, CommandPolicies]] |
9 | 41 |
|
10 | 42 | class AbstractCommandsParser: |
11 | 43 | def _get_pubsub_keys(self, *args): |
@@ -64,7 +96,8 @@ class CommandsParser(AbstractCommandsParser): |
64 | 96 |
|
65 | 97 | def __init__(self, redis_connection): |
66 | 98 | self.commands = {} |
67 | | - self.initialize(redis_connection) |
| 99 | + self.redis_connection = redis_connection |
| 100 | + self.initialize(self.redis_connection) |
68 | 101 |
|
69 | 102 | def initialize(self, r): |
70 | 103 | commands = r.command() |
@@ -169,6 +202,173 @@ def _get_moveable_keys(self, redis_conn, *args): |
169 | 202 | raise e |
170 | 203 | return keys |
171 | 204 |
|
| 205 | + def _is_keyless_command(self, command_name: str, subcommand_name: Optional[str]=None) -> bool: |
| 206 | + """ |
| 207 | + Determines whether a given command or subcommand is considered "keyless". |
| 208 | +
|
| 209 | + A keyless command does not operate on specific keys, which is determined based |
| 210 | + on the first key position in the command or subcommand details. If the command |
| 211 | + or subcommand's first key position is zero or negative, it is treated as keyless. |
| 212 | +
|
| 213 | + Parameters: |
| 214 | + command_name: str |
| 215 | + The name of the command to check. |
| 216 | + subcommand_name: Optional[str], default=None |
| 217 | + The name of the subcommand to check, if applicable. If not provided, |
| 218 | + the check is performed only on the command. |
| 219 | +
|
| 220 | + Returns: |
| 221 | + bool |
| 222 | + True if the specified command or subcommand is considered keyless, |
| 223 | + False otherwise. |
| 224 | +
|
| 225 | + Raises: |
| 226 | + ValueError |
| 227 | + If the specified subcommand is not found within the command or the |
| 228 | + specified command does not exist in the available commands. |
| 229 | + """ |
| 230 | + if subcommand_name: |
| 231 | + for subcommand in self.commands.get(command_name)['subcommands']: |
| 232 | + if str_if_bytes(subcommand[0]) == subcommand_name: |
| 233 | + parsed_subcmd = self.parse_subcommand(subcommand) |
| 234 | + return parsed_subcmd['first_key_pos'] <= 0 |
| 235 | + raise ValueError(f"Subcommand {subcommand_name} not found in command {command_name}") |
| 236 | + else: |
| 237 | + command_details = self.commands.get(command_name, None) |
| 238 | + if command_details is not None: |
| 239 | + return command_details['first_key_pos'] <= 0 |
| 240 | + |
| 241 | + raise ValueError(f"Command {command_name} not found in commands") |
| 242 | + |
| 243 | + def get_command_policies(self) -> PolicyRecords: |
| 244 | + """ |
| 245 | + Retrieve and process the command policies for all commands and subcommands. |
| 246 | +
|
| 247 | + This method traverses through commands and subcommands, extracting policy details |
| 248 | + from associated data structures and constructing a dictionary of commands with their |
| 249 | + associated policies. It supports nested data structures and handles both main commands |
| 250 | + and their subcommands. |
| 251 | +
|
| 252 | + Returns: |
| 253 | + PolicyRecords: A collection of commands and subcommands associated with their |
| 254 | + respective policies. |
| 255 | +
|
| 256 | + Raises: |
| 257 | + IncorrectPolicyType: If an invalid policy type is encountered during policy extraction. |
| 258 | + """ |
| 259 | + command_with_policies = {} |
| 260 | + |
| 261 | + def extract_policies(data, module_name, command_name): |
| 262 | + """ |
| 263 | + Recursively extract policies from nested data structures. |
| 264 | + |
| 265 | + Args: |
| 266 | + data: The data structure to search (can be list, dict, str, bytes, etc.) |
| 267 | + command_name: The command name to associate with found policies |
| 268 | + """ |
| 269 | + if isinstance(data, (str, bytes)): |
| 270 | + # Decode bytes to string if needed |
| 271 | + policy = str_if_bytes(data.decode()) |
| 272 | + |
| 273 | + # Check if this is a policy string |
| 274 | + if policy.startswith('request_policy') or policy.startswith('response_policy'): |
| 275 | + if policy.startswith('request_policy'): |
| 276 | + policy_type = policy.split(':')[1] |
| 277 | + |
| 278 | + try: |
| 279 | + command_with_policies[module_name][command_name].request_policy = RequestPolicy(policy_type) |
| 280 | + except ValueError: |
| 281 | + raise IncorrectPolicyType(f"Incorrect request policy type: {policy_type}") |
| 282 | + |
| 283 | + if policy.startswith('response_policy'): |
| 284 | + policy_type = policy.split(':')[1] |
| 285 | + |
| 286 | + try: |
| 287 | + command_with_policies[module_name][command_name].response_policy = ResponsePolicy(policy_type) |
| 288 | + except ValueError: |
| 289 | + raise IncorrectPolicyType(f"Incorrect response policy type: {policy_type}") |
| 290 | + |
| 291 | + elif isinstance(data, list): |
| 292 | + # For lists, recursively process each element |
| 293 | + for item in data: |
| 294 | + extract_policies(item, module_name, command_name) |
| 295 | + |
| 296 | + elif isinstance(data, dict): |
| 297 | + # For dictionaries, recursively process each value |
| 298 | + for value in data.values(): |
| 299 | + extract_policies(value, module_name, command_name) |
| 300 | + |
| 301 | + for command, details in self.commands.items(): |
| 302 | + # Check whether the command has keys |
| 303 | + is_keyless = self._is_keyless_command(command) |
| 304 | + |
| 305 | + if is_keyless: |
| 306 | + default_request_policy = RequestPolicy.DEFAULT_KEYLESS |
| 307 | + default_response_policy = ResponsePolicy.DEFAULT_KEYLESS |
| 308 | + else: |
| 309 | + default_request_policy = RequestPolicy.DEFAULT_KEYED |
| 310 | + default_response_policy = ResponsePolicy.DEFAULT_KEYED |
| 311 | + |
| 312 | + # Check if it's a core or module command |
| 313 | + split_name = command.split('.') |
| 314 | + |
| 315 | + if len(split_name) > 1: |
| 316 | + module_name = split_name[0] |
| 317 | + command_name = split_name[1] |
| 318 | + else: |
| 319 | + module_name = 'core' |
| 320 | + command_name = split_name[0] |
| 321 | + |
| 322 | + # Create a CommandPolicies object with default policies on the new command. |
| 323 | + if command_with_policies.get(module_name, None) is None: |
| 324 | + command_with_policies[module_name] = {command_name: CommandPolicies( |
| 325 | + request_policy=default_request_policy, |
| 326 | + response_policy=default_response_policy |
| 327 | + )} |
| 328 | + else: |
| 329 | + command_with_policies[module_name][command_name] = CommandPolicies( |
| 330 | + request_policy=default_request_policy, |
| 331 | + response_policy=default_response_policy |
| 332 | + ) |
| 333 | + |
| 334 | + tips = details.get('tips') |
| 335 | + subcommands = details.get('subcommands') |
| 336 | + |
| 337 | + # Process tips for the main command |
| 338 | + if tips: |
| 339 | + extract_policies(tips, module_name, command_name) |
| 340 | + |
| 341 | + # Process subcommands |
| 342 | + if subcommands: |
| 343 | + for subcommand_details in subcommands: |
| 344 | + # Get the subcommand name (first element) |
| 345 | + subcmd_name = subcommand_details[0] |
| 346 | + if isinstance(subcmd_name, bytes): |
| 347 | + subcmd_name = subcmd_name.decode() |
| 348 | + |
| 349 | + # Check whether the subcommand has keys |
| 350 | + is_keyless = self._is_keyless_command(command, subcmd_name) |
| 351 | + |
| 352 | + if is_keyless: |
| 353 | + default_request_policy = RequestPolicy.DEFAULT_KEYLESS |
| 354 | + default_response_policy = ResponsePolicy.DEFAULT_KEYLESS |
| 355 | + else: |
| 356 | + default_request_policy = RequestPolicy.DEFAULT_KEYED |
| 357 | + default_response_policy = ResponsePolicy.DEFAULT_KEYED |
| 358 | + |
| 359 | + subcmd_name = subcmd_name.replace('|', ' ') |
| 360 | + |
| 361 | + # Create a CommandPolicies object with default policies on the new command. |
| 362 | + command_with_policies[module_name][subcmd_name] = CommandPolicies( |
| 363 | + request_policy=default_request_policy, |
| 364 | + response_policy=default_response_policy |
| 365 | + ) |
| 366 | + |
| 367 | + # Recursively extract policies from the rest of the subcommand details |
| 368 | + for subcommand_detail in subcommand_details[1:]: |
| 369 | + extract_policies(subcommand_detail, module_name, subcmd_name) |
| 370 | + |
| 371 | + return command_with_policies |
172 | 372 |
|
173 | 373 | class AsyncCommandsParser(AbstractCommandsParser): |
174 | 374 | """ |
|
0 commit comments