
I think I cooked again, but maybe not. I'll share with you the product of today's cooking:

import functools
import hashlib
import pickle
import threading
import zlib
from collections.abc import Hashable

import inject
from pymongo import MongoClient

# CacheEntry (a Pydantic model with cache_hash, content and compressed fields)
# and SERVICE_ID are defined elsewhere in the application.

@inject.autoparams()
def get_database(client: MongoClient):
    return client["cache"]

lock = threading.Lock()

def _mongocached(func, *args, **kwargs):
    """Call func(*args, **kwargs), caching the pickled result in MongoDB."""
    # Reject arguments whose repr is unusable as a cache key: unhashable values,
    # or values that fall back to object.__repr__ (which embeds the memory
    # address and therefore changes between runs).
    invalid_args = [
        (i, arg) for i, arg in enumerate(args)
        if not isinstance(arg, Hashable) or type(arg).__repr__ is object.__repr__
    ]
    
    invalid_kwargs = [
        (k, kwarg) for k, kwarg in kwargs.items()
        if not isinstance(kwarg, Hashable) or type(kwarg).__repr__ is object.__repr__
    ]
    
    if invalid_args or invalid_kwargs:
        errors = []
        if invalid_args:
            errors.append("Positional arguments:")
            for index, arg in invalid_args:
                errors.append(f"    Index {index}: {arg} (type: {type(arg).__name__})")
        
        if invalid_kwargs:
            errors.append("Keyword arguments:")
            for key, kwarg in invalid_kwargs:
                errors.append(f"    Key '{key}': {kwarg} (type: {type(kwarg).__name__})")
        
        raise ValueError("The following arguments are not hashable:\n" + "\n".join(errors))

    combined_key = f"{func.__qualname__}{repr(args)}{repr(kwargs)}"
    cache_hash = hashlib.sha256(combined_key.encode()).hexdigest()

    database = get_database()
    cache_collection = database[f"service_{SERVICE_ID}"]

    # A single global lock: lookups and the wrapped call itself are serialized.
    with lock:
        cached_entry = cache_collection.find_one({"cache_hash": cache_hash})

        if cached_entry:
            entry = CacheEntry(**cached_entry)
            serialized_content = entry.content
            if entry.compressed:
                serialized_content = zlib.decompress(serialized_content)
            return pickle.loads(serialized_content)

        result = func(*args, **kwargs)
        serialized_result = pickle.dumps(result)
        # Compress payloads larger than 16 KiB before storing them
        is_compressed = len(serialized_result) > 16384
        if is_compressed:
            serialized_result = zlib.compress(serialized_result, level=7)

        new_cache_entry = CacheEntry(
            cache_hash=cache_hash,
            content=serialized_result,
            compressed=is_compressed
        )
        cache_collection.insert_one(new_cache_entry.model_dump())

    return result

def mongocached(func):
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        return _mongocached(func, *args, **kwargs)
    return wrapper

def mongocachedmethod(method):
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
        # Bind the method to the instance so the call works, while keeping
        # self out of the cache key (only *args and **kwargs are hashed).
        return _mongocached(method.__get__(self), *args, **kwargs)
    return wrapper
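
A minimal usage sketch (expensive_lookup and its body are hypothetical, not part of the code above): the first call computes and stores the pickled result; later calls with the same arguments read it back from MongoDB, even after a restart.

@mongocached
def expensive_lookup(term: str) -> list:
    # Stand-in for a slow computation or remote call.
    return [term.upper()]

expensive_lookup("abc")  # computed, then stored in the service_<SERVICE_ID> collection
expensive_lookup("abc")  # served from the cache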


What do y'all think? Is it understandable? Could it be implemented better?

Perchance.

I know someone will say something about using repr as a way of hashing the arguments. Well, I know, I know... But are there any better alternatives for consistent caching across multiple application restarts? I think not, but who knows?

Comment (Nov 5, 2024 at 21:14): I find this to be not very readable. Consider adding a one-sentence """docstring""" when you write a function. And choose better identifiers than e.g. mongocachedmethod. Notice that an f-string treats {repr(args)} and {args!r} identically.

1 Answer


Personally I'd implement these changes:

1) Custom serialization protocol

def serialize_arg(arg):
    if isinstance(arg, (int, str, float, bool)):
        return repr(arg)
    elif isinstance(arg, (list, tuple)):
        # Record the container type so [1, 2] and (1, 2) produce different keys
        return f"({type(arg).__name__})" + ",".join(serialize_arg(x) for x in arg)
    elif isinstance(arg, dict):
        # Sort items so insertion order never affects the key
        return "{" + ",".join(f"{serialize_arg(k)}:{serialize_arg(v)}" for k, v in sorted(arg.items())) + "}"
    else:
        # Fall back to repr for other types
        return repr(arg)

This approach brings several advantages: explicit handling of nested structures, type preservation in the cache key via type(arg).__name__, and consistent ordering for dictionaries through sorted().
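
A quick sanity check (hypothetical calls; the strings in the comments are what serialize_arg returns for them):

serialize_arg([1, "a"])          # "(list)1,'a'"
serialize_arg((1, "a"))          # "(tuple)1,'a'" -- distinct from the list
serialize_arg({"b": 2, "a": 1})  # "{'a':1,'b':2}" -- insertion order is irrelevant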

2) JSON serialization

import hashlib
import json

def make_cache_key(func, args, kwargs):
    try:
        key_parts = [func.__qualname__, json.dumps(args), json.dumps(kwargs, sort_keys=True)]
        return hashlib.sha256("|".join(key_parts).encode()).hexdigest()
    except TypeError:
        # Fall back to repr if JSON serialization fails
        return hashlib.sha256(f"{func.__qualname__}{repr(args)}{repr(kwargs)}".encode()).hexdigest()

This utility function leverages a standardized serialization format, and sort_keys=True guarantees a consistent ordering of dictionary keys. The pipe character "|" separates the function name from the argument parts, preventing key collisions between them.
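
For illustration, a hypothetical call site (fetch_user and its arguments are invented for this example, not taken from the original code):

def fetch_user(user_id, include_posts=False):
    ...

key = make_cache_key(fetch_user, (42,), {"include_posts": True})
# 64-character hex digest, stable across restarts: json.dumps with
# sort_keys=True yields the same string for the same arguments every time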
