@@ -551,22 +551,18 @@ static int derive_retained_key(int hmac, const char *hostnqn,
551551 return -1 ;
552552}
553553
554- static int gen_tls_identity (const char * hostnqn , const char * subsysnqn ,
555- int version , int hmac , char * identity ,
556- unsigned char * retained , size_t key_len )
554+ static int derive_psk_digest (const char * hostnqn , const char * subsysnqn ,
555+ int version , int hmac ,
556+ unsigned char * retained , size_t key_len ,
557+ char * digest , size_t digest_len )
557558{
558- if (version != 0 ) {
559- nvme_msg (NULL , LOG_ERR , "NVMe TLS 2.0 is not supported; "
560- "recompile with OpenSSL support.\n" );
561- errno = ENOTSUP ;
562- return -1 ;
563- }
564- sprintf (identity , "NVMe0R%02d %s %s" ,
565- hmac , hostnqn , subsysnqn );
566- return strlen (identity );
559+ nvme_msg (NULL , LOG_ERR , "NVMe TLS 2.0 is not supported; "
560+ "recompile with OpenSSL support.\n" );
561+ errno = ENOTSUP ;
562+ return -1 ;
567563}
568564
569- static int derive_tls_key (int hmac , const char * identity ,
565+ static int derive_tls_key (int version , int hmac , const char * context ,
570566 unsigned char * retained ,
571567 unsigned char * psk , size_t key_len )
572568{
@@ -662,7 +658,7 @@ static int derive_retained_key(int hmac, const char *hostnqn,
662658 return key_len ;
663659}
664660
665- static int derive_tls_key (int hmac , const char * identity ,
661+ static int derive_tls_key (int version , int hmac , const char * context ,
666662 unsigned char * retained ,
667663 unsigned char * psk , size_t key_len )
668664{
@@ -710,9 +706,20 @@ static int derive_tls_key(int hmac, const char *identity,
710706 errno = ENOKEY ;
711707 return -1 ;
712708 }
709+ if (version == 1 ) {
710+ char hash_str [4 ];
711+
712+ sprintf (hash_str , "%02d " , hmac );
713+ if (EVP_PKEY_CTX_add1_hkdf_info (ctx ,
714+ (const unsigned char * )hash_str ,
715+ strlen (hash_str )) <= 0 ) {
716+ errno = ENOKEY ;
717+ return -1 ;
718+ }
719+ }
713720 if (EVP_PKEY_CTX_add1_hkdf_info (ctx ,
714- (const unsigned char * )identity ,
715- strlen (identity )) <= 0 ) {
721+ (const unsigned char * )context ,
722+ strlen (context )) <= 0 ) {
716723 errno = ENOKEY ;
717724 return -1 ;
718725 }
@@ -792,28 +799,18 @@ int nvme_gen_dhchap_key(char *hostnqn, enum nvme_hmac_alg hmac,
792799 return 0 ;
793800}
794801
795- static int gen_tls_identity (const char * hostnqn , const char * subsysnqn ,
796- int version , int hmac , char * identity ,
797- unsigned char * retained , size_t key_len )
802+ static int derive_psk_digest (const char * hostnqn , const char * subsysnqn ,
803+ int version , int hmac ,
804+ unsigned char * retained , size_t key_len ,
805+ char * digest , size_t digest_len )
798806{
799807 static const char hmac_seed [] = "NVMe-over-Fabrics" ;
800808 size_t hmac_len ;
801809 const EVP_MD * md = select_hmac (hmac , & hmac_len );
802810 _cleanup_hmac_ctx_ HMAC_CTX * hmac_ctx = NULL ;
803811 _cleanup_free_ unsigned char * psk_ctx = NULL ;
804- _cleanup_free_ char * enc_ctx = NULL ;
805812 size_t len ;
806813
807- if (version == 0 ) {
808- sprintf (identity , "NVMe%01dR%02d %s %s" ,
809- version , hmac , hostnqn , subsysnqn );
810- return strlen (identity );
811- }
812- if (version > 1 ) {
813- errno = EINVAL ;
814- return -1 ;
815- }
816-
817814 hmac_ctx = HMAC_CTX_new ();
818815 if (!hmac_ctx ) {
819816 errno = ENOMEM ;
@@ -860,17 +857,19 @@ static int gen_tls_identity(const char *hostnqn, const char *subsysnqn,
860857 errno = ENOKEY ;
861858 return -1 ;
862859 }
863- enc_ctx = malloc (key_len * 2 );
864- memset (enc_ctx , 0 , key_len * 2 );
865- len = base64_encode (psk_ctx , key_len , enc_ctx );
860+ if (key_len * 2 > digest_len ) {
861+ errno = EINVAL ;
862+ return -1 ;
863+ }
864+ memset (digest , 0 , digest_len );
865+ len = base64_encode (psk_ctx , key_len , digest );
866866 if (len < 0 ) {
867867 errno = ENOKEY ;
868868 return len ;
869869 }
870- sprintf (identity , "NVMe%01dR%02d %s %s %s" ,
871- version , hmac , hostnqn , subsysnqn , enc_ctx );
872- return strlen (identity );
870+ return strlen (digest );
873871}
872+
874873#endif /* !CONFIG_OPENSSL_1 */
875874
876875#ifdef CONFIG_OPENSSL_3
@@ -965,9 +964,10 @@ int nvme_gen_dhchap_key(char *hostnqn, enum nvme_hmac_alg hmac,
965964 return 0 ;
966965}
967966
968- static int gen_tls_identity (const char * hostnqn , const char * subsysnqn ,
969- int version , int hmac , char * identity ,
970- unsigned char * retained , size_t key_len )
967+ static int derive_psk_digest (const char * hostnqn , const char * subsysnqn ,
968+ int version , int hmac ,
969+ unsigned char * retained , size_t key_len ,
970+ char * digest , size_t digest_len )
971971{
972972 static const char hmac_seed [] = "NVMe-over-Fabrics" ;
973973 size_t hmac_len ;
@@ -976,21 +976,10 @@ static int gen_tls_identity(const char *hostnqn, const char *subsysnqn,
976976 _cleanup_evp_mac_ctx_ EVP_MAC_CTX * mac_ctx = NULL ;
977977 _cleanup_evp_mac_ EVP_MAC * mac = NULL ;
978978 char * progq = NULL ;
979- char * digest = NULL ;
979+ char * dig = NULL ;
980980 _cleanup_free_ unsigned char * psk_ctx = NULL ;
981- _cleanup_free_ char * enc_ctx = NULL ;
982981 size_t len ;
983982
984- if (version == 0 ) {
985- sprintf (identity , "NVMe%01dR%02d %s %s" ,
986- version , hmac , hostnqn , subsysnqn );
987- return strlen (identity );
988- }
989- if (version > 1 ) {
990- errno = EINVAL ;
991- return -1 ;
992- }
993-
994983 lib_ctx = OSSL_LIB_CTX_new ();
995984 if (!lib_ctx ) {
996985 errno = ENOMEM ;
@@ -1009,19 +998,19 @@ static int gen_tls_identity(const char *hostnqn, const char *subsysnqn,
1009998 }
1010999 switch (hmac ) {
10111000 case NVME_HMAC_ALG_SHA2_256 :
1012- digest = OSSL_DIGEST_NAME_SHA2_256 ;
1001+ dig = OSSL_DIGEST_NAME_SHA2_256 ;
10131002 break ;
10141003 case NVME_HMAC_ALG_SHA2_384 :
1015- digest = OSSL_DIGEST_NAME_SHA2_384 ;
1004+ dig = OSSL_DIGEST_NAME_SHA2_384 ;
10161005 break ;
10171006 default :
10181007 errno = EINVAL ;
10191008 break ;
10201009 }
1021- if (!digest )
1010+ if (!dig )
10221011 return -1 ;
10231012 * p ++ = OSSL_PARAM_construct_utf8_string (OSSL_MAC_PARAM_DIGEST ,
1024- digest , 0 );
1013+ dig , 0 );
10251014 * p = OSSL_PARAM_construct_end ();
10261015
10271016 psk_ctx = malloc (key_len );
@@ -1065,25 +1054,47 @@ static int gen_tls_identity(const char *hostnqn, const char *subsysnqn,
10651054 errno = EMSGSIZE ;
10661055 return -1 ;
10671056 }
1068- enc_ctx = malloc (hmac_len * 2 );
1069- memset (enc_ctx , 0 , hmac_len * 2 );
1070- len = base64_encode (psk_ctx , hmac_len , enc_ctx );
1057+ if (hmac_len * 2 > digest_len ) {
1058+ errno = EINVAL ;
1059+ return -1 ;
1060+ }
1061+ memset (digest , 0 , digest_len );
1062+ len = base64_encode (psk_ctx , hmac_len , digest );
10711063 if (len < 0 ) {
10721064 errno = ENOKEY ;
10731065 return len ;
10741066 }
1067+ return strlen (digest );
1068+ }
1069+ #endif /* !CONFIG_OPENSSL_3 */
1070+
1071+ static int gen_tls_identity (const char * hostnqn , const char * subsysnqn ,
1072+ int version , int hmac , char * digest ,
1073+ char * identity )
1074+ {
1075+ if (version == 0 ) {
1076+ sprintf (identity , "NVMe%01dR%02d %s %s" ,
1077+ version , hmac , hostnqn , subsysnqn );
1078+ return strlen (identity );
1079+ }
1080+ if (version > 1 ) {
1081+ errno = EINVAL ;
1082+ return -1 ;
1083+ }
1084+
10751085 sprintf (identity , "NVMe%01dR%02d %s %s %s" ,
1076- version , hmac , hostnqn , subsysnqn , enc_ctx );
1086+ version , hmac , hostnqn , subsysnqn , digest );
10771087 return strlen (identity );
10781088}
1079- #endif /* !CONFIG_OPENSSL_3 */
10801089
10811090static int derive_nvme_keys (const char * hostnqn , const char * subsysnqn ,
10821091 char * identity , int version ,
10831092 int hmac , unsigned char * configured ,
10841093 unsigned char * psk , int key_len )
10851094{
10861095 _cleanup_free_ unsigned char * retained = NULL ;
1096+ _cleanup_free_ char * digest = NULL ;
1097+ char * context = identity ;
10871098 int ret = -1 ;
10881099
10891100 if (!hostnqn || !subsysnqn || !identity || !psk ) {
@@ -1099,11 +1110,28 @@ static int derive_nvme_keys(const char *hostnqn, const char *subsysnqn,
10991110 ret = derive_retained_key (hmac , hostnqn , configured , retained , key_len );
11001111 if (ret < 0 )
11011112 return ret ;
1113+
1114+ if (version == 1 ) {
1115+ size_t digest_len = 2 * key_len ;
1116+
1117+ digest = malloc (digest_len );
1118+ if (!digest ) {
1119+ errno = ENOMEM ;
1120+ return -1 ;
1121+ }
1122+ ret = derive_psk_digest (hostnqn , subsysnqn , version , hmac ,
1123+ retained , key_len ,
1124+ digest , digest_len );
1125+ if (ret )
1126+ return ret ;
1127+ context = digest ;
1128+ }
11021129 ret = gen_tls_identity (hostnqn , subsysnqn , version , hmac ,
1103- identity , retained , key_len );
1130+ digest , identity );
11041131 if (ret < 0 )
11051132 return ret ;
1106- return derive_tls_key (hmac , identity , retained , psk , key_len );
1133+ return derive_tls_key (version , hmac , context , retained ,
1134+ psk , key_len );
11071135}
11081136
11091137static size_t nvme_identity_len (int hmac , int version , const char * hostnqn ,
0 commit comments