@@ -141,10 +141,14 @@ typedef struct
141141 Port * port ;
142142 bool channel_binding_in_use ;
143143
144+ /* State data depending on the hash type */
145+ pg_cryptohash_type hash_type ;
146+ int key_length ;
147+
144148 int iterations ;
145149 char * salt ; /* base64-encoded */
146- uint8 StoredKey [SCRAM_KEY_LEN ];
147- uint8 ServerKey [SCRAM_KEY_LEN ];
150+ uint8 StoredKey [SCRAM_MAX_KEY_LEN ];
151+ uint8 ServerKey [SCRAM_MAX_KEY_LEN ];
148152
149153 /* Fields of the first message from client */
150154 char cbind_flag ;
@@ -155,7 +159,7 @@ typedef struct
155159 /* Fields from the last message from client */
156160 char * client_final_message_without_proof ;
157161 char * client_final_nonce ;
158- char ClientProof [SCRAM_KEY_LEN ];
162+ char ClientProof [SCRAM_MAX_KEY_LEN ];
159163
160164 /* Fields generated in the server */
161165 char * server_first_message ;
@@ -177,12 +181,15 @@ static char *build_server_first_message(scram_state *state);
177181static char * build_server_final_message (scram_state * state );
178182static bool verify_client_proof (scram_state * state );
179183static bool verify_final_nonce (scram_state * state );
180- static void mock_scram_secret (const char * username , int * iterations ,
181- char * * salt , uint8 * stored_key , uint8 * server_key );
184+ static void mock_scram_secret (const char * username , pg_cryptohash_type * hash_type ,
185+ int * iterations , int * key_length , char * * salt ,
186+ uint8 * stored_key , uint8 * server_key );
182187static bool is_scram_printable (char * p );
183188static char * sanitize_char (char c );
184189static char * sanitize_str (const char * s );
185- static char * scram_mock_salt (const char * username );
190+ static char * scram_mock_salt (const char * username ,
191+ pg_cryptohash_type hash_type ,
192+ int key_length );
186193
187194/*
188195 * Get a list of SASL mechanisms that this module supports.
@@ -266,8 +273,11 @@ scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
266273
267274 if (password_type == PASSWORD_TYPE_SCRAM_SHA_256 )
268275 {
269- if (parse_scram_secret (shadow_pass , & state -> iterations , & state -> salt ,
270- state -> StoredKey , state -> ServerKey ))
276+ if (parse_scram_secret (shadow_pass , & state -> iterations ,
277+ & state -> hash_type , & state -> key_length ,
278+ & state -> salt ,
279+ state -> StoredKey ,
280+ state -> ServerKey ))
271281 got_secret = true;
272282 else
273283 {
@@ -310,8 +320,10 @@ scram_init(Port *port, const char *selected_mech, const char *shadow_pass)
310320 */
311321 if (!got_secret )
312322 {
313- mock_scram_secret (state -> port -> user_name , & state -> iterations ,
314- & state -> salt , state -> StoredKey , state -> ServerKey );
323+ mock_scram_secret (state -> port -> user_name , & state -> hash_type ,
324+ & state -> iterations , & state -> key_length ,
325+ & state -> salt ,
326+ state -> StoredKey , state -> ServerKey );
315327 state -> doomed = true;
316328 }
317329
@@ -482,7 +494,8 @@ pg_be_scram_build_secret(const char *password)
482494 (errcode (ERRCODE_INTERNAL_ERROR ),
483495 errmsg ("could not generate random salt" )));
484496
485- result = scram_build_secret (saltbuf , SCRAM_DEFAULT_SALT_LEN ,
497+ result = scram_build_secret (PG_SHA256 , SCRAM_SHA_256_KEY_LEN ,
498+ saltbuf , SCRAM_DEFAULT_SALT_LEN ,
486499 SCRAM_DEFAULT_ITERATIONS , password ,
487500 & errstr );
488501
@@ -505,16 +518,18 @@ scram_verify_plain_password(const char *username, const char *password,
505518 char * salt ;
506519 int saltlen ;
507520 int iterations ;
508- uint8 salted_password [SCRAM_KEY_LEN ];
509- uint8 stored_key [SCRAM_KEY_LEN ];
510- uint8 server_key [SCRAM_KEY_LEN ];
511- uint8 computed_key [SCRAM_KEY_LEN ];
521+ int key_length = 0 ;
522+ pg_cryptohash_type hash_type ;
523+ uint8 salted_password [SCRAM_MAX_KEY_LEN ];
524+ uint8 stored_key [SCRAM_MAX_KEY_LEN ];
525+ uint8 server_key [SCRAM_MAX_KEY_LEN ];
526+ uint8 computed_key [SCRAM_MAX_KEY_LEN ];
512527 char * prep_password ;
513528 pg_saslprep_rc rc ;
514529 const char * errstr = NULL ;
515530
516- if (!parse_scram_secret (secret , & iterations , & encoded_salt ,
517- stored_key , server_key ))
531+ if (!parse_scram_secret (secret , & iterations , & hash_type , & key_length ,
532+ & encoded_salt , stored_key , server_key ))
518533 {
519534 /*
520535 * The password looked like a SCRAM secret, but could not be parsed.
@@ -541,9 +556,11 @@ scram_verify_plain_password(const char *username, const char *password,
541556 password = prep_password ;
542557
543558 /* Compute Server Key based on the user-supplied plaintext password */
544- if (scram_SaltedPassword (password , salt , saltlen , iterations ,
559+ if (scram_SaltedPassword (password , hash_type , key_length ,
560+ salt , saltlen , iterations ,
545561 salted_password , & errstr ) < 0 ||
546- scram_ServerKey (salted_password , computed_key , & errstr ) < 0 )
562+ scram_ServerKey (salted_password , hash_type , key_length ,
563+ computed_key , & errstr ) < 0 )
547564 {
548565 elog (ERROR , "could not compute server key: %s" , errstr );
549566 }
@@ -555,7 +572,7 @@ scram_verify_plain_password(const char *username, const char *password,
555572 * Compare the secret's Server Key with the one computed from the
556573 * user-supplied password.
557574 */
558- return memcmp (computed_key , server_key , SCRAM_KEY_LEN ) == 0 ;
575+ return memcmp (computed_key , server_key , key_length ) == 0 ;
559576}
560577
561578
@@ -565,14 +582,15 @@ scram_verify_plain_password(const char *username, const char *password,
565582 * On success, the iteration count, salt, stored key, and server key are
566583 * extracted from the secret, and returned to the caller. For 'stored_key'
567584 * and 'server_key', the caller must pass pre-allocated buffers of size
568- * SCRAM_KEY_LEN . Salt is returned as a base64-encoded, null-terminated
585+ * SCRAM_MAX_KEY_LEN . Salt is returned as a base64-encoded, null-terminated
569586 * string. The buffer for the salt is palloc'd by this function.
570587 *
571588 * Returns true if the SCRAM secret has been parsed, and false otherwise.
572589 */
573590bool
574- parse_scram_secret (const char * secret , int * iterations , char * * salt ,
575- uint8 * stored_key , uint8 * server_key )
591+ parse_scram_secret (const char * secret , int * iterations ,
592+ pg_cryptohash_type * hash_type , int * key_length ,
593+ char * * salt , uint8 * stored_key , uint8 * server_key )
576594{
577595 char * v ;
578596 char * p ;
@@ -606,6 +624,8 @@ parse_scram_secret(const char *secret, int *iterations, char **salt,
606624 /* Parse the fields */
607625 if (strcmp (scheme_str , "SCRAM-SHA-256" ) != 0 )
608626 goto invalid_secret ;
627+ * hash_type = PG_SHA256 ;
628+ * key_length = SCRAM_SHA_256_KEY_LEN ;
609629
610630 errno = 0 ;
611631 * iterations = strtol (iterations_str , & p , 10 );
@@ -631,17 +651,17 @@ parse_scram_secret(const char *secret, int *iterations, char **salt,
631651 decoded_stored_buf = palloc (decoded_len );
632652 decoded_len = pg_b64_decode (storedkey_str , strlen (storedkey_str ),
633653 decoded_stored_buf , decoded_len );
634- if (decoded_len != SCRAM_KEY_LEN )
654+ if (decoded_len != * key_length )
635655 goto invalid_secret ;
636- memcpy (stored_key , decoded_stored_buf , SCRAM_KEY_LEN );
656+ memcpy (stored_key , decoded_stored_buf , * key_length );
637657
638658 decoded_len = pg_b64_dec_len (strlen (serverkey_str ));
639659 decoded_server_buf = palloc (decoded_len );
640660 decoded_len = pg_b64_decode (serverkey_str , strlen (serverkey_str ),
641661 decoded_server_buf , decoded_len );
642- if (decoded_len != SCRAM_KEY_LEN )
662+ if (decoded_len != * key_length )
643663 goto invalid_secret ;
644- memcpy (server_key , decoded_server_buf , SCRAM_KEY_LEN );
664+ memcpy (server_key , decoded_server_buf , * key_length );
645665
646666 return true;
647667
@@ -655,20 +675,25 @@ parse_scram_secret(const char *secret, int *iterations, char **salt,
655675 *
656676 * In a normal authentication, these are extracted from the secret
657677 * stored in the server. This function generates values that look
658- * realistic, for when there is no stored secret.
678+ * realistic, for when there is no stored secret, using SCRAM-SHA-256 .
659679 *
660680 * Like in parse_scram_secret(), for 'stored_key' and 'server_key', the
661- * caller must pass pre-allocated buffers of size SCRAM_KEY_LEN , and
681+ * caller must pass pre-allocated buffers of size SCRAM_MAX_KEY_LEN , and
662682 * the buffer for the salt is palloc'd by this function.
663683 */
664684static void
665- mock_scram_secret (const char * username , int * iterations , char * * salt ,
685+ mock_scram_secret (const char * username , pg_cryptohash_type * hash_type ,
686+ int * iterations , int * key_length , char * * salt ,
666687 uint8 * stored_key , uint8 * server_key )
667688{
668689 char * raw_salt ;
669690 char * encoded_salt ;
670691 int encoded_len ;
671692
693+ /* Enforce the use of SHA-256, which would be realistic enough */
694+ * hash_type = PG_SHA256 ;
695+ * key_length = SCRAM_SHA_256_KEY_LEN ;
696+
672697 /*
673698 * Generate deterministic salt.
674699 *
@@ -677,7 +702,7 @@ mock_scram_secret(const char *username, int *iterations, char **salt,
677702 * as the salt generated for mock authentication uses the cluster's nonce
678703 * value.
679704 */
680- raw_salt = scram_mock_salt (username );
705+ raw_salt = scram_mock_salt (username , * hash_type , * key_length );
681706 if (raw_salt == NULL )
682707 elog (ERROR , "could not encode salt" );
683708
@@ -695,8 +720,8 @@ mock_scram_secret(const char *username, int *iterations, char **salt,
695720 * iterations = SCRAM_DEFAULT_ITERATIONS ;
696721
697722 /* StoredKey and ServerKey are not used in a doomed authentication */
698- memset (stored_key , 0 , SCRAM_KEY_LEN );
699- memset (server_key , 0 , SCRAM_KEY_LEN );
723+ memset (stored_key , 0 , SCRAM_MAX_KEY_LEN );
724+ memset (server_key , 0 , SCRAM_MAX_KEY_LEN );
700725}
701726
702727/*
@@ -1111,10 +1136,10 @@ verify_final_nonce(scram_state *state)
11111136static bool
11121137verify_client_proof (scram_state * state )
11131138{
1114- uint8 ClientSignature [SCRAM_KEY_LEN ];
1115- uint8 ClientKey [SCRAM_KEY_LEN ];
1116- uint8 client_StoredKey [SCRAM_KEY_LEN ];
1117- pg_hmac_ctx * ctx = pg_hmac_create (PG_SHA256 );
1139+ uint8 ClientSignature [SCRAM_MAX_KEY_LEN ];
1140+ uint8 ClientKey [SCRAM_MAX_KEY_LEN ];
1141+ uint8 client_StoredKey [SCRAM_MAX_KEY_LEN ];
1142+ pg_hmac_ctx * ctx = pg_hmac_create (state -> hash_type );
11181143 int i ;
11191144 const char * errstr = NULL ;
11201145
@@ -1123,7 +1148,7 @@ verify_client_proof(scram_state *state)
11231148 * here even when processing the calculations as this could involve a mock
11241149 * authentication.
11251150 */
1126- if (pg_hmac_init (ctx , state -> StoredKey , SCRAM_KEY_LEN ) < 0 ||
1151+ if (pg_hmac_init (ctx , state -> StoredKey , state -> key_length ) < 0 ||
11271152 pg_hmac_update (ctx ,
11281153 (uint8 * ) state -> client_first_message_bare ,
11291154 strlen (state -> client_first_message_bare )) < 0 ||
@@ -1135,7 +1160,7 @@ verify_client_proof(scram_state *state)
11351160 pg_hmac_update (ctx ,
11361161 (uint8 * ) state -> client_final_message_without_proof ,
11371162 strlen (state -> client_final_message_without_proof )) < 0 ||
1138- pg_hmac_final (ctx , ClientSignature , sizeof ( ClientSignature ) ) < 0 )
1163+ pg_hmac_final (ctx , ClientSignature , state -> key_length ) < 0 )
11391164 {
11401165 elog (ERROR , "could not calculate client signature: %s" ,
11411166 pg_hmac_error (ctx ));
@@ -1144,14 +1169,15 @@ verify_client_proof(scram_state *state)
11441169 pg_hmac_free (ctx );
11451170
11461171 /* Extract the ClientKey that the client calculated from the proof */
1147- for (i = 0 ; i < SCRAM_KEY_LEN ; i ++ )
1172+ for (i = 0 ; i < state -> key_length ; i ++ )
11481173 ClientKey [i ] = state -> ClientProof [i ] ^ ClientSignature [i ];
11491174
11501175 /* Hash it one more time, and compare with StoredKey */
1151- if (scram_H (ClientKey , SCRAM_KEY_LEN , client_StoredKey , & errstr ) < 0 )
1176+ if (scram_H (ClientKey , state -> hash_type , state -> key_length ,
1177+ client_StoredKey , & errstr ) < 0 )
11521178 elog (ERROR , "could not hash stored key: %s" , errstr );
11531179
1154- if (memcmp (client_StoredKey , state -> StoredKey , SCRAM_KEY_LEN ) != 0 )
1180+ if (memcmp (client_StoredKey , state -> StoredKey , state -> key_length ) != 0 )
11551181 return false;
11561182
11571183 return true;
@@ -1349,12 +1375,12 @@ read_client_final_message(scram_state *state, const char *input)
13491375 client_proof_len = pg_b64_dec_len (strlen (value ));
13501376 client_proof = palloc (client_proof_len );
13511377 if (pg_b64_decode (value , strlen (value ), client_proof ,
1352- client_proof_len ) != SCRAM_KEY_LEN )
1378+ client_proof_len ) != state -> key_length )
13531379 ereport (ERROR ,
13541380 (errcode (ERRCODE_PROTOCOL_VIOLATION ),
13551381 errmsg ("malformed SCRAM message" ),
13561382 errdetail ("Malformed proof in client-final-message." )));
1357- memcpy (state -> ClientProof , client_proof , SCRAM_KEY_LEN );
1383+ memcpy (state -> ClientProof , client_proof , state -> key_length );
13581384 pfree (client_proof );
13591385
13601386 if (* p != '\0' )
@@ -1374,13 +1400,13 @@ read_client_final_message(scram_state *state, const char *input)
13741400static char *
13751401build_server_final_message (scram_state * state )
13761402{
1377- uint8 ServerSignature [SCRAM_KEY_LEN ];
1403+ uint8 ServerSignature [SCRAM_MAX_KEY_LEN ];
13781404 char * server_signature_base64 ;
13791405 int siglen ;
1380- pg_hmac_ctx * ctx = pg_hmac_create (PG_SHA256 );
1406+ pg_hmac_ctx * ctx = pg_hmac_create (state -> hash_type );
13811407
13821408 /* calculate ServerSignature */
1383- if (pg_hmac_init (ctx , state -> ServerKey , SCRAM_KEY_LEN ) < 0 ||
1409+ if (pg_hmac_init (ctx , state -> ServerKey , state -> key_length ) < 0 ||
13841410 pg_hmac_update (ctx ,
13851411 (uint8 * ) state -> client_first_message_bare ,
13861412 strlen (state -> client_first_message_bare )) < 0 ||
@@ -1392,19 +1418,19 @@ build_server_final_message(scram_state *state)
13921418 pg_hmac_update (ctx ,
13931419 (uint8 * ) state -> client_final_message_without_proof ,
13941420 strlen (state -> client_final_message_without_proof )) < 0 ||
1395- pg_hmac_final (ctx , ServerSignature , sizeof ( ServerSignature ) ) < 0 )
1421+ pg_hmac_final (ctx , ServerSignature , state -> key_length ) < 0 )
13961422 {
13971423 elog (ERROR , "could not calculate server signature: %s" ,
13981424 pg_hmac_error (ctx ));
13991425 }
14001426
14011427 pg_hmac_free (ctx );
14021428
1403- siglen = pg_b64_enc_len (SCRAM_KEY_LEN );
1429+ siglen = pg_b64_enc_len (state -> key_length );
14041430 /* don't forget the zero-terminator */
14051431 server_signature_base64 = palloc (siglen + 1 );
14061432 siglen = pg_b64_encode ((const char * ) ServerSignature ,
1407- SCRAM_KEY_LEN , server_signature_base64 ,
1433+ state -> key_length , server_signature_base64 ,
14081434 siglen );
14091435 if (siglen < 0 )
14101436 elog (ERROR , "could not encode server signature" );
@@ -1431,10 +1457,11 @@ build_server_final_message(scram_state *state)
14311457 * pointer to a static buffer of size SCRAM_DEFAULT_SALT_LEN, or NULL.
14321458 */
14331459static char *
1434- scram_mock_salt (const char * username )
1460+ scram_mock_salt (const char * username , pg_cryptohash_type hash_type ,
1461+ int key_length )
14351462{
14361463 pg_cryptohash_ctx * ctx ;
1437- static uint8 sha_digest [PG_SHA256_DIGEST_LENGTH ];
1464+ static uint8 sha_digest [SCRAM_MAX_KEY_LEN ];
14381465 char * mock_auth_nonce = GetMockAuthenticationNonce ();
14391466
14401467 /*
@@ -1446,11 +1473,17 @@ scram_mock_salt(const char *username)
14461473 StaticAssertDecl (PG_SHA256_DIGEST_LENGTH >= SCRAM_DEFAULT_SALT_LEN ,
14471474 "salt length greater than SHA256 digest length" );
14481475
1449- ctx = pg_cryptohash_create (PG_SHA256 );
1476+ /*
1477+ * This may be worth refreshing if support for more hash methods is\
1478+ * added.
1479+ */
1480+ Assert (hash_type == PG_SHA256 );
1481+
1482+ ctx = pg_cryptohash_create (hash_type );
14501483 if (pg_cryptohash_init (ctx ) < 0 ||
14511484 pg_cryptohash_update (ctx , (uint8 * ) username , strlen (username )) < 0 ||
14521485 pg_cryptohash_update (ctx , (uint8 * ) mock_auth_nonce , MOCK_AUTH_NONCE_LEN ) < 0 ||
1453- pg_cryptohash_final (ctx , sha_digest , sizeof ( sha_digest ) ) < 0 )
1486+ pg_cryptohash_final (ctx , sha_digest , key_length ) < 0 )
14541487 {
14551488 pg_cryptohash_free (ctx );
14561489 return NULL ;
0 commit comments