1717
1818#include "distribution_metadata.h"
1919#include "prune_shard_list.h"
20+ #include "create_shards.h"
2021
2122#include <stddef.h>
2223
@@ -69,6 +70,50 @@ static List * BuildRestrictInfoList(List *qualList);
6970static Node * BuildBaseConstraint (Var * column );
7071static void UpdateConstraint (Node * baseConstraint , ShardInterval * shardInterval );
7172
73+ static HTAB * shardPlacementCache ;
74+
75+ typedef struct ShardPlacementEntryCacheEntry
76+ {
77+ Oid tableId ;
78+ List * * placements ;
79+ } ShardPlacementCacheEntry ;
80+
81+ #define MAX_DISTRIBUTED_TABLES 101
82+
83+ static List *
84+ LookupShardPlacementCache (Oid relationId , int shardHashCode )
85+ {
86+ ShardPlacementCacheEntry * entry = NULL ;
87+
88+ if (shardPlacementCache == NULL )
89+ {
90+ HASHCTL info ;
91+ int hashFlags = (HASH_ELEM | HASH_BLOBS | HASH_CONTEXT );
92+
93+ memset (& info , 0 , sizeof (info ));
94+ info .keysize = sizeof (Oid );
95+ info .entrysize = sizeof (ShardPlacementCacheEntry );
96+ info .hcxt = CacheMemoryContext ;
97+
98+ shardPlacementCache = hash_create ("pg_shard placement cache" , MAX_DISTRIBUTED_TABLES , & info , hashFlags );
99+ }
100+ entry = hash_search (shardPlacementCache , & relationId , HASH_FIND , NULL );
101+ return (entry != NULL ) ? entry -> placements [shardHashCode ] : NULL ;
102+ }
103+
104+ static void
105+ AddToShardPlacementCache (Oid relationId , int shardHashCode , int shardCount , List * shardPlacements )
106+ {
107+ MemoryContext oldContext = MemoryContextSwitchTo (CacheMemoryContext );
108+ bool found = false;
109+ ShardPlacementCacheEntry * entry = (ShardPlacementCacheEntry * )hash_search (shardPlacementCache , & relationId , HASH_ENTER , & found );
110+ if (!found )
111+ {
112+ entry -> placements = palloc0 (shardCount * sizeof (List * ));
113+ }
114+ entry -> placements [shardHashCode ] = list_copy (shardPlacements );
115+ MemoryContextSwitchTo (oldContext );
116+ }
72117
73118/*
74119 * PruneShardList prunes shards from given list based on the selection criteria,
@@ -81,7 +126,7 @@ PruneShardList(Oid relationId, List *whereClauseList, List *shardIntervalList)
81126 ListCell * shardIntervalCell = NULL ;
82127 List * restrictInfoList = NIL ;
83128 Node * baseConstraint = NULL ;
84-
129+ int shardHashCode = -1 ;
85130 Var * partitionColumn = PartitionColumn (relationId );
86131 char partitionMethod = PartitionType (relationId );
87132
@@ -97,10 +142,51 @@ PruneShardList(Oid relationId, List *whereClauseList, List *shardIntervalList)
97142
98143 case HASH_PARTITION_TYPE :
99144 {
100- Node * hashedNode = HashableClauseMutator ((Node * ) whereClauseList ,
101- partitionColumn );
145+ Node * hashedNode = NULL ;
146+ List * hashedClauseList = NULL ;
147+ if (whereClauseList && whereClauseList -> length == 1 )
148+ {
149+ Expr * predicate = (Expr * )lfirst (list_head (whereClauseList ));
150+
151+ if (IsA (predicate , OpExpr ))
152+ {
153+ OpExpr * operatorExpression = (OpExpr * ) predicate ;
154+ Oid leftHashFunction = InvalidOid ;
155+ Oid rightHashFunction = InvalidOid ;
156+ if (get_op_hash_functions (operatorExpression -> opno ,
157+ & leftHashFunction ,
158+ & rightHashFunction )
159+ && SimpleOpExpression (predicate )
160+ && OpExpressionContainsColumn (operatorExpression ,
161+ partitionColumn ))
162+ {
163+ Node * leftOperand = get_leftop (predicate );
164+ Node * rightOperand = get_rightop (predicate );
165+ Const * constant = (Const * )(IsA (rightOperand , Const ) ? rightOperand : leftOperand );
166+ TypeCacheEntry * typeEntry = lookup_type_cache (constant -> consttype , TYPECACHE_HASH_PROC_FINFO );
167+ FmgrInfo * hashFunction = & (typeEntry -> hash_proc_finfo );
168+ if (OidIsValid (hashFunction -> fn_oid ))
169+ {
170+ int hashedValue = DatumGetInt32 (FunctionCall1 (hashFunction , constant -> constvalue ));
171+ int shardCount = shardIntervalList -> length ;
172+ uint32 hashTokenIncrement = (uint32 )(HASH_TOKEN_COUNT / shardCount );
173+ shardHashCode = (int )((uint32 )(hashedValue - INT32_MIN ) / hashTokenIncrement );
174+ remainingShardList = LookupShardPlacementCache (relationId , shardHashCode );
175+ if (remainingShardList != NULL )
176+ {
177+ return remainingShardList ;
178+ }
179+ }
180+ }
181+ }
182+ }
183+ hashedNode = HashableClauseMutator ((Node * ) whereClauseList ,
184+ partitionColumn );
185+ hashedClauseList = (List * ) hashedNode ;
186+ restrictInfoList = BuildRestrictInfoList (hashedClauseList );
102187
103- List * hashedClauseList = (List * ) hashedNode ;
188+ /* override the partition column for hash partitioning */
189+ partitionColumn = MakeInt4Column ();
104190 restrictInfoList = BuildRestrictInfoList (hashedClauseList );
105191
106192 /* override the partition column for hash partitioning */
@@ -141,7 +227,10 @@ PruneShardList(Oid relationId, List *whereClauseList, List *shardIntervalList)
141227 remainingShardList = lappend (remainingShardList , & (shardInterval -> id ));
142228 }
143229 }
144-
230+ if (shardHashCode >= 0 )
231+ {
232+ AddToShardPlacementCache (relationId , shardHashCode , shardIntervalList -> length , remainingShardList );
233+ }
145234 return remainingShardList ;
146235}
147236
0 commit comments