FYI, I would convert to float16 (2x smaller than float32, with almost no loss).
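But the best practical answer has three parts. You get float32 arrays back at read time, though the values keep only 8-bit precision. Disk usage shrinks by ~4x vs. float32. Random row fetches take milliseconds.
- Quantize to float8 (note: NumPy itself has no native float8 dtype).
- Store everything as a single memory-mapped binary (.npy) file.
- One indexing call then gives you any subset of rows in ~1–2 ms, with no servers and no Polars.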
If you later want it even smaller, combine a light PCA reduction to 384 dims with float8. For the 160M-row example below, that is 160M × 384 dims × 1 byte ≈ 61 GB total. That's the simplest, fastest route to both goals.
1. FP8 (float8) + memory-mapped .npy
import numpy as np
# --- WRITE ---
# FP8 formats: name -> (exponent bits, mantissa bits, exponent bias).
# NumPy has no native float8 dtype (np.float8_e4m3fn does not exist), so FP8 values
# are stored as raw bit patterns in a uint8 .npy file and decoded via a lookup table.
_FP8_FORMATS = {"e4m3": (4, 3, 7), "e5m2": (5, 2, 15)}


def _fp8_decode_table(fp8):
    """Return a (256,) float32 array holding the value of every FP8 byte code.

    "e4m3" uses the finite-only e4m3fn convention: no infinities, codes 0x7F/0xFF
    are NaN, max 448. "e5m2" is IEEE-like: an all-ones exponent means inf when the
    mantissa is 0 and NaN otherwise, max 57344.

    Raises:
        ValueError: if ``fp8`` is not "e4m3" or "e5m2".
    """
    try:
        exp_bits, man_bits, bias = _FP8_FORMATS[fp8]
    except KeyError:
        raise ValueError(f"fp8 must be 'e4m3' or 'e5m2', got {fp8!r}") from None
    codes = np.arange(256)
    exp = (codes >> man_bits) & ((1 << exp_bits) - 1)
    man = codes & ((1 << man_bits) - 1)
    frac = man / (1 << man_bits)
    mag = np.where(exp == 0,
                   frac * 2.0 ** (1 - bias),            # subnormal: 0.m * 2**(1 - bias)
                   (1.0 + frac) * 2.0 ** (exp - bias))  # normal:    1.m * 2**(e - bias)
    exp_all_ones = exp == (1 << exp_bits) - 1
    if fp8 == "e4m3":
        mag = np.where(exp_all_ones & (man == (1 << man_bits) - 1), np.nan, mag)
    else:
        mag = np.where(exp_all_ones, np.where(man == 0, np.inf, np.nan), mag)
    sign = np.where(codes & 0x80, -1.0, 1.0)
    return (sign * mag).astype(np.float32)


def _fp8_encode(x, grid):
    """Round ``x`` to the nearest FP8 byte code (ties to even) and return uint8 codes.

    ``grid`` is the increasing float64 array of non-negative finite FP8 values,
    indexed by byte code. Values beyond the largest magnitude (including +-inf)
    saturate to +-max. NaN maps to 0x7F, which is NaN in both formats.
    """
    x = np.asarray(x, dtype=np.float32)
    is_nan = np.isnan(x)
    mag = np.minimum(np.abs(np.where(is_nan, 0.0, x)), grid[-1])
    hi = np.searchsorted(grid, mag, side="left")  # grid[hi - 1] < mag <= grid[hi]
    lo = np.maximum(hi - 1, 0)
    mid = (grid[lo] + grid[hi]) / 2  # exact: FP8 values carry only a few mantissa bits
    # lo and hi are consecutive codes, so exactly one is even (even mantissa): ties go there.
    code = np.where((mag > mid) | ((mag == mid) & (hi % 2 == 0)), hi, lo).astype(np.uint8)
    code[np.signbit(x)] |= 0x80
    code[is_nan] = 0x7F
    return code


def write_fp8_memmap(batches, out_path, total_rows, dim, fp8='e4m3'):
    """Quantize float rows to FP8 and write them to a (total_rows, dim) .npy memmap.

    The file has dtype uint8 and holds raw FP8 bit patterns (1 byte per value), so
    it loads with plain NumPy. Decode it with ``FP8Memmap`` using the same ``fp8``.

    Rounding is to nearest, ties to even. Out-of-range values saturate to the
    format's max (448 for e4m3, 57344 for e5m2). e4m3 keeps 3 mantissa bits, and
    magnitudes below 2**-6 are subnormal with reduced relative precision.
    Consider scaling rows first if most of your values are that small.

    Args:
        batches: iterable of (B, dim) float arrays (anything ``np.asarray``
            accepts), covering exactly ``total_rows`` rows in order.
        out_path: destination .npy path.
        total_rows: number of rows in the output file.
        dim: values per row.
        fp8: "e4m3" (more precision) or "e5m2" (more range).

    Raises:
        ValueError: if ``fp8`` is unknown, or a batch is not 2-D with ``dim``
            columns. Also raised if the batches hold more or fewer than
            ``total_rows`` rows; unwritten rows would otherwise read back as zeros.
    """
    table = _fp8_decode_table(fp8)  # validates fp8 before the output file is created
    # Codes 0x00-0x7F have the sign bit clear and increase in magnitude with the code;
    # the finite ones form a prefix (only the top codes are inf/NaN).
    positive = table[:128]
    grid = positive[:np.count_nonzero(np.isfinite(positive))].astype(np.float64)

    mm = np.lib.format.open_memmap(out_path, mode='w+', dtype=np.uint8,
                                   shape=(total_rows, dim))
    i = 0
    for x in batches:
        x = np.asarray(x)
        # Reject mis-shaped batches; a (B, 1) batch would otherwise broadcast silently.
        if x.ndim != 2 or x.shape[1] != dim:
            raise ValueError(f"expected a (B, {dim}) batch, got shape {x.shape}")
        b = x.shape[0]
        if i + b > total_rows:
            raise ValueError(f"batches hold more than total_rows={total_rows} rows")
        mm[i:i + b] = _fp8_encode(x, grid)
        i += b
    mm.flush()
    if i != total_rows:
        raise ValueError(f"batches held {i} rows but total_rows={total_rows}; "
                         "unwritten rows would read back as zeros")
# --- EOF WRITE ---

# --- READ ---
class FP8Memmap:
    """Random-access reader for the .npy file written by ``write_fp8_memmap``.

    Rows are decoded to float32 on fetch. ``fp8`` must match the format used when
    writing. Files whose dtype is not uint8 (e.g. float16) are simply cast to
    float32.
    """

    def __init__(self, path, fp8='e4m3'):
        self.mm = np.load(path, mmap_mode='r')
        table = _fp8_decode_table(fp8)  # validates fp8
        # uint8 files hold FP8 bit patterns: decoding is one table lookup per value.
        self._lut = table if self.mm.dtype == np.uint8 else None

    def get_rows(self, ids):
        """Return ``self.mm[ids]`` as float32 (ids: int, slice, list or index array)."""
        rows = self.mm[ids]
        if self._lut is None:
            return rows.astype(np.float32, copy=False)
        return self._lut[rows]
# --- EOF READ ---
-
# Example:
# write_fp8_memmap(iter_batches(), 'emb_fp8.npy', total_rows=160_000_000, dim=768)
# reader = FP8Memmap('emb_fp8.npy')
# subset = reader.get_rows([5, 42, 12345678])
2. Row-wise int8 + per-row scale factors (works anywhere; also 1 byte/value, plus one float32 scale per row)
If FP8 is not available, or you want tighter control, quantize each row separately with a single scale value (max |x| / 127).
# --- WRITE ---
def write_int8_rowwise(batches, data_path, scale_path, total_rows, dim, eps=1e-8):
    """Quantize float rows to int8 with one float32 scale per row, into two .npy memmaps.

    Each row ``x`` is stored as ``q = clip(round(x / s), -127, 127)`` with
    ``s = (max|x| + eps) / 127``. It is reconstructed as ``q * s``.

    Args:
        batches: iterable of (B, dim) float arrays (anything ``np.asarray``
            accepts), covering exactly ``total_rows`` rows in order.
        data_path: output .npy path for the (total_rows, dim) int8 codes.
        scale_path: output .npy path for the (total_rows,) float32 scales.
        total_rows: number of rows to write.
        dim: values per row.
        eps: added to each row's max |x| so an all-zero row gets a non-zero scale.

    Raises:
        ValueError: if a batch is not 2-D with ``dim`` columns. Also raised if
            the batches hold more or fewer than ``total_rows`` rows; unwritten
            rows would otherwise read back as zeros.
    """
    q_mm = np.lib.format.open_memmap(data_path, mode='w+', dtype=np.int8,
                                     shape=(total_rows, dim))
    scl_mm = np.lib.format.open_memmap(scale_path, mode='w+', dtype=np.float32,
                                       shape=(total_rows,))
    i = 0
    for x in batches:
        x = np.asarray(x)
        # Reject mis-shaped batches; a (B, 1) batch would otherwise broadcast silently.
        if x.ndim != 2 or x.shape[1] != dim:
            raise ValueError(f"expected a (B, {dim}) batch, got shape {x.shape}")
        b = x.shape[0]
        if i + b > total_rows:
            raise ValueError(f"batches hold more than total_rows={total_rows} rows")
        m = np.max(np.abs(x), axis=1, keepdims=True) + eps
        s = (m / 127.0).astype(np.float32)
        # Divide by the float32 scale that is actually stored, so q * s reconstructs consistently.
        q_mm[i:i + b] = np.clip(np.round(x / s), -127, 127).astype(np.int8)
        scl_mm[i:i + b] = s[:, 0]
        i += b
    q_mm.flush()
    scl_mm.flush()
    if i != total_rows:
        raise ValueError(f"batches held {i} rows but total_rows={total_rows}; "
                         "unwritten rows would read back as zeros")
# --- EOF WRITE ---
-
# --- READ ---
class Int8RowwiseMemmap:
    """Random-access reader for the int8 codes + per-row scales written by write_int8_rowwise.

    Rows are reconstructed on fetch as ``q * scale`` in float32.
    """

    def __init__(self, data_path, scale_path):
        self.q = np.load(data_path, mmap_mode='r')
        self.scl = np.load(scale_path, mmap_mode='r')
        # Codes and scales come from two files; catch a mismatched pair up front.
        if self.q.ndim != 2 or self.scl.shape != (self.q.shape[0],):
            raise ValueError(f"data shape {self.q.shape} does not match "
                             f"scale shape {self.scl.shape}")

    def get_rows(self, ids):
        """Return dequantized rows ``self.q[ids] * self.scl[ids]`` as float32.

        ``ids`` may be an int (returns one (dim,) row), a slice, a list, or an
        index array of any shape.
        """
        q_rows = np.asarray(self.q[ids], dtype=np.float32)
        scales = np.asarray(self.scl[ids], dtype=np.float32)
        # [..., None] instead of [:, None]: also works for a scalar id (0-d scale).
        return q_rows * scales[..., None]
# --- EOF READ ---
-
# Example:
# write_int8_rowwise(iter_batches(), 'emb_i8.npy', 'emb_i8_scales.npy', 160_000_000, 768)
# reader = Int8RowwiseMemmap('emb_i8.npy', 'emb_i8_scales.npy')
# subset = reader.get_rows([10, 11, 12])