
I have about 160 million embeddings that I created with a BERT-based encoder model. Right now they are stored in .pt format and take about 500 GB on disk. I want two things:

  1. To save them as efficiently as possible (minimum disk space).

  2. Fast random access to the embeddings is very important to me: I need to fetch a group of embeddings given a list of their ids (the id can be the row number).

Is there a recommended way to do that?


what I already have tried for the random access:

  • polars: I saved them in Parquet format, scanned the DataFrame lazily, and then filtered by row index to get only the rows (embeddings) I want. This was very slow. If there is another, faster way to do this in Polars, I will be happy to hear it!

details:

  1. The embeddings are in float32, and for now I want to keep this precision.
  2. The embedding dim is 768.

1 Answer

FYI, I would convert to float16 (2x smaller, almost no quality loss).
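If you go that route, the float16 conversion can use the same streaming-memmap pattern as the options below. A minimal sketch (the helper name `to_fp16_memmap` is mine, and I assume the .pt tensors are already converted to NumPy batches, e.g. via `tensor.numpy()`):

```python
import numpy as np

def to_fp16_memmap(batches, out_path, total_rows, dim):
    """Stream float32 batches into one float16 .npy memmap (half the disk)."""
    mm = np.lib.format.open_memmap(out_path, mode='w+', dtype=np.float16,
                                   shape=(total_rows, dim))
    i = 0
    for x in batches:              # x: (B, dim) float32 ndarray
        mm[i:i + len(x)] = x.astype(np.float16)
        i += len(x)
    mm.flush()

# Tiny demo with random data
data = np.random.randn(10, 4).astype(np.float32)
to_fp16_memmap([data[:6], data[6:]], 'emb_fp16.npy', total_rows=10, dim=4)
back = np.load('emb_fp16.npy', mmap_mode='r')
print(back.dtype, back.shape)      # float16 (10, 4)
```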

But the best practical answer (returns float32 at read time, shrinks disk by ~4x, gives millisecond random row fetches):

  1. Quantize to float8

    • 1 byte per value instead of 4 bytes -> 500 GB shrinks to ~125 GB.

    • Benchmarks (Naamán Huerga-Pérez et al.) show a quality loss below 0.3.

  2. Store them as a single memory-mapped binary file

    • One line of code gives you any subset of rows in ~1-2 ms, with no servers and no Polars.

If you later want to go even smaller, combine a light PCA (drop to 384 dims) with float8 -> ~60 GB total. That is the simplest, fastest route to both goals.
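For the PCA step, here is a plain-NumPy sketch (no sklearn) that fits components on a representative sample and projects batches before quantizing. The names `fit_pca`/`project` are illustrative; in practice you would fit on a large random sample (say 1M rows), not the full 160M:

```python
import numpy as np

def fit_pca(sample, k):
    """Fit PCA on a (N, dim) float32 sample; returns the mean and top-k components."""
    mean = sample.mean(axis=0)
    # SVD of the centered sample: rows of vt are the principal directions
    _, _, vt = np.linalg.svd(sample - mean, full_matrices=False)
    return mean, vt[:k].T                      # shapes: (dim,), (dim, k)

def project(x, mean, components):
    """Project (B, dim) embeddings down to (B, k)."""
    return (x - mean) @ components

# Demo: 768 -> 384 on random data
sample = np.random.randn(1000, 768).astype(np.float32)
mean, comps = fit_pca(sample, k=384)
reduced = project(sample[:5], mean, comps)
print(reduced.shape)                           # (5, 384)
```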

  1. FP8 (float8) + memory-mapped .npy

    import numpy as np
    # NumPy has no native float8 dtype; the ml_dtypes package
    # (pip install ml-dtypes) provides float8_e4m3fn / float8_e5m2.
    import ml_dtypes
    
    # --- WRITE ---
    def write_fp8_memmap(batches, out_path, total_rows, dim, fp8='e4m3'):
        fp8_dtype = ml_dtypes.float8_e4m3fn if fp8 == 'e4m3' else ml_dtypes.float8_e5m2
        # Store the raw bytes as uint8 so the .npy file stays loadable anywhere;
        # the reader views them back as float8.
        mm = np.lib.format.open_memmap(out_path, mode='w+', dtype=np.uint8,
                                       shape=(total_rows, dim))
        i = 0
        for x in batches:  # x: (B, dim) float32 ndarray (convert torch tensors first)
            b = x.shape[0]
            mm[i:i+b] = x.astype(fp8_dtype).view(np.uint8)
            i += b
        mm.flush()
    # --- EOF WRITE ---
    
    
    # --- READ ---
    import ml_dtypes  # NumPy has no native float8; ml_dtypes provides it
    
    class FP8Memmap:
        def __init__(self, path, fp8='e4m3'):
            self.fp8_dtype = ml_dtypes.float8_e4m3fn if fp8 == 'e4m3' else ml_dtypes.float8_e5m2
            self.mm = np.load(path, mmap_mode='r')  # uint8 bytes on disk
        def get_rows(self, ids):
            return self.mm[ids].view(self.fp8_dtype).astype(np.float32)
    # --- EOF READ ---
    
    # Example:
    # write_fp8_memmap(iter_batches(), 'emb_fp8.npy', total_rows=160_000_000, dim=768)
    # reader = FP8Memmap('emb_fp8.npy')
    # subset = reader.get_rows([5, 42, 12345678])
    

  2. Row-wise int8 + scale factors (works anywhere, also 1 byte/value)

    If FP8 is not available, or you want tighter control, quantize each row separately with a single scale value:

    
    # --- WRITE ---
    def write_int8_rowwise(batches, data_path, scale_path, total_rows, dim, eps=1e-8):
        q_mm   = np.lib.format.open_memmap(data_path,  mode='w+', dtype=np.int8,    shape=(total_rows, dim))
        scl_mm = np.lib.format.open_memmap(scale_path, mode='w+', dtype=np.float32, shape=(total_rows,))
        i = 0
        for x in batches:  # x: (B, dim) float32 ndarray
            b = x.shape[0]
            # one scale per row: the row's max |value| maps to 127
            m = np.max(np.abs(x), axis=1, keepdims=True) + eps
            s = (m / 127.0).astype(np.float32)
            q = np.clip(np.round(x / s), -127, 127).astype(np.int8)
            q_mm[i:i+b]   = q
            scl_mm[i:i+b] = s[:, 0]
            i += b
        q_mm.flush(); scl_mm.flush()
    # --- EOF WRITE ---
    
    
    # --- READ ---
    class Int8RowwiseMemmap:
        def __init__(self, data_path, scale_path):
            self.q   = np.load(data_path,  mmap_mode='r')
            self.scl = np.load(scale_path, mmap_mode='r')
        def get_rows(self, ids):
            q_rows = self.q[ids].astype(np.float32, copy=False)
            s = self.scl[ids].astype(np.float32, copy=False)
            return q_rows * s[:, None]
    # --- EOF READ ---
    
    # Example:
    # write_int8_rowwise(iter_batches(), 'emb_i8.npy', 'emb_i8_scales.npy', 160_000_000, 768)
    # reader = Int8RowwiseMemmap('emb_i8.npy', 'emb_i8_scales.npy')
    # subset = reader.get_rows([10, 11, 12])
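To sanity-check what the int8 round trip actually costs you on your data, here is a quick in-memory check of the mean cosine similarity between original and dequantized rows (same quantization scheme as above, just without the memmap; on Gaussian-like data the result is typically very close to 1.0):

```python
import numpy as np

def rowwise_int8_roundtrip(x, eps=1e-8):
    """Quantize (B, dim) float32 rows to int8 with one scale per row, then dequantize."""
    m = np.max(np.abs(x), axis=1, keepdims=True) + eps
    s = (m / 127.0).astype(np.float32)
    q = np.clip(np.round(x / s), -127, 127).astype(np.int8)
    return q.astype(np.float32) * s

def mean_cosine(a, b):
    """Mean row-wise cosine similarity between two (B, dim) matrices."""
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return float((a * b).sum(axis=1).mean())

x = np.random.randn(1000, 768).astype(np.float32)
print(round(mean_cosine(x, rowwise_int8_roundtrip(x)), 4))
```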
    

2 Comments

Thank you very much! Do you maybe have example code for best-practice quantizing and saving to a memory-mapped binary file?
Of course. Two practical options, depending on whether you have FP8 dtypes available. NumPy itself has no built-in float8 types, but the ml_dtypes package provides them (e4m3/e5m2), which is the cleanest way: 1 byte per value, everything stays in a single file, and you can still instantly load any subset of rows. I will update my previous answer with the code so you can check it.
