Back to index

nordugrid-arc-nox  1.1.0~rc6
saml_util.cpp
Go to the documentation of this file.
00001 #ifdef HAVE_CONFIG_H
00002 #include <config.h>
00003 #endif
00004 
00005 #ifdef WIN32 
00006 #define NOGDI
00007 #endif 
00008 
00009 #include <iostream>
00010 #include <fstream>
00011 #include <sstream>
00012 
00013 #include <glibmm/fileutils.h>
00014 #include <unistd.h>
00015 #include <cstring>
00016 #include <zlib.h>
00017 
00018 #include <libxml/uri.h>
00019 
00020 #include <xmlsec/base64.h>
00021 #include <xmlsec/errors.h>
00022 #include <xmlsec/xmltree.h>
00023 #include <xmlsec/xmldsig.h>
00024 #include <xmlsec/xmlenc.h>
00025 #include <xmlsec/templates.h>
00026 #include <xmlsec/crypto.h>
00027 #include <xmlsec/openssl/app.h>
00028 #include <xmlsec/openssl/crypto.h>
00029 
00030 #include <openssl/bio.h>
00031 #include <openssl/evp.h>
00032 #include <openssl/sha.h>
00033 #include <openssl/rand.h>
00034 #ifdef CHARSET_EBCDIC
00035 #include <openssl/ebcdic.h>
00036 #endif 
00037 
00038 #include "XmlSecUtils.h"
00039 #include "saml_util.h"
00040 
00041 namespace Arc {
00042 
00043   std::string SignQuery(std::string query, SignatureMethod sign_method, std::string& privkey_file) {
00044  
00045     std::string ret;
00046 
00047     BIO* key_bio = BIO_new_file(privkey_file.c_str(), "rb");
00048     if (key_bio == NULL) {
00049      std::cout<<"Failed to open private key file: "<<privkey_file<<std::endl;
00050      return ret;
00051     }
00052 
00053     /* Add SigAlg */
00054     char *t;
00055     std::string new_query; 
00056     switch (sign_method) {
00057       case RSA_SHA1:
00058         t = (char*)xmlURIEscapeStr(xmlSecHrefRsaSha1, NULL);
00059         new_query.append(query).append("&SigAlg=").append(t);
00060        xmlFree(t);
00061        break;
00062       case DSA_SHA1:
00063        t = (char*)xmlURIEscapeStr(xmlSecHrefDsaSha1, NULL);
00064         new_query.append(query).append("&SigAlg=").append(t);
00065        xmlFree(t);
00066        break;
00067     }
00068 
00069     /* Build buffer digest */
00070     if(new_query.empty()) return ret;
00071    
00072     xmlChar* md;
00073     md = (xmlChar*)(xmlMalloc(20));
00074     char* digest = (char*)SHA1((unsigned char*)(new_query.c_str()), new_query.size(), md);
00075 
00076 
00077     RSA *rsa = NULL;
00078     DSA *dsa = NULL;
00079     unsigned char *sigret = NULL;
00080     unsigned int siglen;
00081     char *b64_sigret = NULL, *e_b64_sigret = NULL;
00082     int status = 0;
00083 
00084     /* Calculate signature value */
00085     if (sign_method == RSA_SHA1) {
00086       rsa = PEM_read_bio_RSAPrivateKey(key_bio, NULL, NULL, NULL);
00087       if (rsa == NULL) {
00088         std::cerr<<"Failed to read rsa key from private key file"<<std::endl;
00089         BIO_free(key_bio); xmlFree(digest); return ret;
00090       }
00091       sigret = (unsigned char *)malloc (RSA_size(rsa));
00092       status = RSA_sign(NID_sha1, (unsigned char*)digest, 20, sigret, &siglen, rsa);
00093       RSA_free(rsa);
00094     } 
00095     else if (sign_method == DSA_SHA1) {
00096       dsa = PEM_read_bio_DSAPrivateKey(key_bio, NULL, NULL, NULL);
00097       if (dsa == NULL) {
00098         std::cerr<<"Failed to read dsa key from private key file"<<std::endl;
00099         BIO_free(key_bio); xmlFree(digest); return ret;
00100       }
00101       sigret = (unsigned char *)malloc (DSA_size(dsa));
00102       status = DSA_sign(NID_sha1, (unsigned char*)digest, 20, sigret, &siglen, dsa);
00103       DSA_free(dsa);
00104     }
00105     
00106     BIO_free(key_bio);
00107    
00108     if (status ==0) { free(sigret); xmlFree(digest); return ret; }
00109 
00110     /* Base64 encode the signature value */
00111     b64_sigret = (char*)xmlSecBase64Encode(sigret, siglen, 0);
00112     /* escape b64_sigret */
00113     e_b64_sigret = (char*)xmlURIEscapeStr((xmlChar*)b64_sigret, NULL);
00114 
00115     /* Add signature */
00116     switch (sign_method) {
00117       case RSA_SHA1:
00118        new_query.append("&Signature=").append(e_b64_sigret);
00119        break;
00120       case DSA_SHA1:
00121         new_query.append("&Signature=").append(e_b64_sigret);
00122        break;
00123     }
00124 
00125     xmlFree(digest);
00126     free(sigret);
00127     xmlFree(b64_sigret);
00128     xmlFree(e_b64_sigret);
00129 
00130     return new_query;
00131   }
00132 
00133   //bool VerifyQuery(const std::string query, const xmlSecKey *sender_public_key) {
00134   bool VerifyQuery(const std::string query, const std::string& sender_cert_str) {
00135 
00136     xmlSecKey* sender_public_key = NULL;
00137     sender_public_key =  get_key_from_certstr(sender_cert_str);
00138     if(sender_public_key == NULL) { 
00139       std::cerr<<"Failed to get public key from the certificate string"<<std::endl; 
00140       return false; 
00141     }
00142 
00143     /* split query, the signature MUST be the last param of the query,
00144      * there could be more params in the URL; but they wouldn't be
00145      * covered by the signature */
00146     size_t f;
00147     f = query.find("&Signature=");
00148     if(f == std::string::npos) { std::cerr<<"Failed to find signature in the query"<<std::endl; return false; }
00149 
00150     std::string str0 = query.substr(0,f);
00151     std::string str1 = query.substr(f+11);
00152 
00153     f = str0.find("&SigAlg=");
00154     if(f == std::string::npos) { std::cerr<<"Failed to find signature alg in the query"<<std::endl; return false; }
00155 
00156     std::string sig_alg = str0.substr(f+8);
00157 
00158     int key_size;
00159     RSA *rsa = NULL;
00160     DSA *dsa = NULL;
00161     char* usig_alg = NULL;
00162     usig_alg = xmlURIUnescapeString(sig_alg.c_str(), 0, NULL);
00163 
00164     if (strcmp(usig_alg, (char*)xmlSecHrefRsaSha1) == 0) {
00165       if (sender_public_key->value->id != xmlSecOpenSSLKeyDataRsaId) {
00166          xmlFree(usig_alg); return false;
00167       }
00168       rsa = xmlSecOpenSSLKeyDataRsaGetRsa(sender_public_key->value);
00169         if (rsa == NULL) {
00170           xmlFree(usig_alg); return false;
00171         }
00172         key_size = RSA_size(rsa);
00173     } 
00174     else if (strcmp(usig_alg, (char*)xmlSecHrefDsaSha1) == 0) {
00175       if (sender_public_key->value->id != xmlSecOpenSSLKeyDataDsaId) {
00176         xmlFree(usig_alg); return false;
00177       }
00178       dsa = xmlSecOpenSSLKeyDataDsaGetDsa(sender_public_key->value);
00179       if (dsa == NULL) {
00180         xmlFree(usig_alg); return false;
00181       }
00182       key_size = DSA_size(dsa);
00183     } 
00184     else {
00185       xmlFree(usig_alg);
00186       return false;
00187     }
00188 
00189     f = str1.find("&");
00190     std::string sig_str = str1.substr(0, f-1);   
00191 
00192     char *b64_signature = NULL;
00193     xmlSecByte *signature = NULL;
00194 
00195     /* get signature (unescape + base64 decode) */
00196     signature = (unsigned char*)(xmlMalloc(key_size+1));
00197     b64_signature = (char*)xmlURIUnescapeString(sig_str.c_str(), 0, NULL);
00198 
00199     xmlSecBase64Decode((xmlChar*)b64_signature, signature, key_size+1);
00200 
00201     /* compute signature digest */
00202     xmlChar* md;
00203     md = (xmlChar*)(xmlMalloc(20));
00204     char* digest = (char*)SHA1((unsigned char*)(str0.c_str()), str0.size(), md);
00205 
00206     if (digest == NULL) {
00207       xmlFree(b64_signature);
00208       xmlFree(signature);
00209       xmlFree(usig_alg);
00210       return false;
00211     }
00212 
00213     int status = 0;
00214 
00215     if (rsa) {
00216       status = RSA_verify(NID_sha1, (unsigned char*)digest, 20, signature, key_size, rsa);
00217     } 
00218     else if (dsa) {
00219       status = DSA_verify(NID_sha1, (unsigned char*)digest, 20, signature, key_size, dsa);
00220     }
00221 
00222     if (status == 0) {
00223       std::cout<<"Signature of the query is not valid"<<std::endl;
00224       xmlFree(b64_signature);
00225       xmlFree(signature);
00226       xmlFree(digest);
00227       xmlFree(usig_alg);
00228       return false;
00229     }
00230 
00231     xmlFree(b64_signature);
00232     xmlFree(signature);
00233     xmlFree(digest);
00234     xmlFree(usig_alg);
00235    
00236     return true;
00237   }
00238 
00239   std::string BuildDeflatedQuery(const XMLNode& node) {
00240     //deflated, b64'ed and url-escaped
00241     std::string encoding("utf-8");
00242     std::string query;
00243     node.GetXML(query, encoding);
00244 
00245     //std::ostringstream oss (std::ostringstream::out);
00246     //node.SaveToStream(oss);
00247     //query = oss.str();
00248    
00249     //XMLNode node1(query);
00250     //std::string query1;
00251     //node1.GetXML(query1, encoding);  
00252 
00253     //std::cout<<"Query:  "<<query<<std::endl;
00254 
00255     std::string deflated_str = DeflateData(query);
00256 
00257     std::string b64_str = Base64Encode(deflated_str);
00258 
00259     std::string escaped_str = URIEscape(b64_str);
00260 
00261     return escaped_str;
00262   }
00263 
00264   std::string Base64Encode(const std::string& data) {
00265     unsigned long len;
00266     xmlChar *b64_out = NULL;
00267     len = data.length();
00268     b64_out = xmlSecBase64Encode((xmlChar*)(data.c_str()), data.length(), 0);
00269     std::string ret;
00270     if(b64_out != NULL) {
00271       ret.append((char*)b64_out);
00272       xmlFree(b64_out);
00273     }
00274     return ret;
00275   }
00276 
00277   std::string Base64Decode(const std::string& data) {
00278     unsigned long len;
00279     xmlChar *out = NULL;
00280     len = data.length();
00281     out = (xmlChar*)(xmlMalloc(len*4));
00282     len = xmlSecBase64Decode((xmlChar*)(data.c_str()), out, len*4);
00283     std::string ret;
00284     if(out != NULL) {
00285       ret.append((char*)out, len);
00286       xmlFree(out);
00287     }
00288     return ret;
00289   }
00290 
00291   std::string URIEscape(const std::string& data) {
00292     xmlChar* out = xmlURIEscapeStr((xmlChar*)(data.c_str()), NULL);
00293     std::string ret;
00294     ret.append((char*)out);
00295     xmlFree(out);
00296     return ret;
00297   }
00298 
00299   std::string URIUnEscape(const std::string& data) {
00300     xmlChar* out = (xmlChar*)xmlURIUnescapeString(data.c_str(), 0,  NULL);
00301     std::string ret;
00302     ret.append((char*)out);
00303     xmlFree(out);
00304     return ret;
00305   }
00306 
00307   std::string DeflateData(const std::string& data) {
00308     unsigned long len;
00309     char *out;
00310     len = data.length();
00311     out = (char*)(malloc(len * 2));
00312     z_stream stream;
00313     stream.next_in = (Bytef*)(data.c_str());
00314     stream.avail_in = len;
00315     stream.next_out = (Bytef*)out;
00316     stream.avail_out = len * 2;
00317 
00318     stream.zalloc = NULL;
00319     stream.zfree = NULL;
00320     stream.opaque = NULL;
00321 
00322     int rc;
00323     /* -MAX_WBITS to disable zib headers */
00324     rc = deflateInit2(&stream, Z_DEFAULT_COMPRESSION, Z_DEFLATED, -MAX_WBITS, 5, 0);
00325     if (rc == Z_OK) {
00326       rc = deflate(&stream, Z_FINISH);
00327       if (rc != Z_STREAM_END) {
00328         deflateEnd(&stream);
00329         if (rc == Z_OK) {
00330           rc = Z_BUF_ERROR;
00331         }
00332       } 
00333       else {
00334         rc = deflateEnd(&stream);
00335       }
00336     }
00337     if (rc != Z_OK) {
00338       free(out);
00339       std::cerr<<"Failed to deflate the data"<<std::endl;
00340       return std::string();
00341     }
00342     std::string ret;
00343     ret.append(out,stream.total_out);
00344 
00345     free(out);
00346     return ret;
00347   }
00348 
00349   std::string InflateData(const std::string& data) {
00350     unsigned long len;
00351     char *out;
00352     len = data.length();
00353     out = (char*)(malloc(len * 10));
00354     z_stream stream;
00355 
00356     stream.zalloc = NULL;
00357     stream.zfree = NULL;
00358     stream.opaque = NULL;
00359 
00360     stream.avail_in = len;
00361     stream.next_in = (Bytef*)(data.c_str());
00362     stream.total_in = 0;
00363     stream.avail_out = len*10;
00364     stream.total_out = 0;
00365     stream.next_out = (Bytef*)out;
00366 
00367     int rc;
00368     rc = inflateInit2(&stream, -MAX_WBITS);
00369     if (rc != Z_OK) {
00370       std::cerr<<"inflateInit failed"<<std::endl;
00371       free(out);
00372       return std::string();
00373     }
00374 
00375     rc = inflate(&stream, Z_FINISH);
00376     if (rc != Z_STREAM_END) {
00377       std::cerr<<"Failed to inflate"<<std::endl;
00378       inflateEnd(&stream);
00379       free(out);
00380       return std::string();
00381     }
00382     out[stream.total_out] = 0;
00383     inflateEnd(&stream);
00384 
00385     std::string ret;
00386     ret.append(out);
00387 
00388     free(out);
00389     return ret;
00390   }
00391 
00392   static bool is_base64(const char *message) {
00393     const char *c;
00394     c = message;
00395     while (*c != 0 && (isalnum(*c) || *c == '+' || *c == '/' || *c == '\n' || *c == '\r')) c++;
00396     while (*c == '=' || *c == '\n' || *c == '\r') c++;
00397     if (*c == 0) return true;
00398     return false;
00399   }
00400 
00401   bool BuildNodefromMsg(const std::string msg, XMLNode& node) {
00402     bool b64 = false;
00403     char* str = (char*)(msg.c_str());
00404     if (is_base64(msg.c_str())) {   
00405       str = (char*)malloc(msg.length());
00406       int r = xmlSecBase64Decode((xmlChar*)(msg.c_str()), (xmlChar*)str, msg.length());
00407       if (r >= 0) b64 = true;
00408       else {
00409         free(str);
00410         str = (char*)(msg.c_str());
00411       }
00412     }
00413 
00414     if (strchr(str, '<')) {
00415       XMLNode nd(str);
00416       if(!nd) { std::cerr<<"Message format unknown"<<std::endl; free(str); return false; }
00417       if (b64) free(str);
00418       nd.New(node);
00419       return true;
00420     }
00421 
00422     //if (strchr(str, '&') || strchr(str, '='))
00423     if(b64) free(str);
00424     return false;
00425   }  
00426 
00427 } //namespace Arc