Fix return values. Check text for qoutes.
authorfox <fox@cb58f699-b61c-0410-a6fe-9272a202ed29>
Thu, 18 Jan 2007 17:12:57 +0000 (17:12 +0000)
committerfox <fox@cb58f699-b61c-0410-a6fe-9272a202ed29>
Thu, 18 Jan 2007 17:12:57 +0000 (17:12 +0000)
git-svn-id: https://svn.middleware.georgetown.edu/cpp-sp/trunk@2135 cb58f699-b61c-0410-a6fe-9272a202ed29

odbc-store/odbc-store.cpp

index 8f61b06..fb365fd 100644 (file)
@@ -388,6 +388,32 @@ private:
         strftime(ret,32,"{ts '%Y-%m-%d %H:%M:%S'}",ptime);
     }
 
+    // make a string safe for SQL command
+    // result to be free'd only if it isn't the input
+    char *makeSafeSQL(const char *src)
+    {
+       int ns = 0;
+       int nc = 0;
+       char *s;
+    
+       // see if any conversion needed
+       for (s=(char*)src; *s; nc++,s++) if (*s=='\''||*s=='\\') ns++;
+       if (ns==0) return ((char*)src);
+    
+       char *safe = (char*) malloc(nc+2*ns+1);
+       for (s=safe; *src; src++) {
+           if (*src=='\''||*src=='\\') *s++ = '\\';
+           *s++ = (char)*src;
+       }
+       *s = '\0';
+       return (safe);
+    }
+
+    void freeSafeSQL(char *safe, const char *src)
+    {
+        if (safe!=src) free(safe);
+    }
+
 };
 
 // class constructor
@@ -425,9 +451,10 @@ ODBCStorageService::~ODBCStorageService()
     delete shutdown_wait;
 }
 
+
 // create 
 
-HRESULT ODBCStorageService::createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration)
+void ODBCStorageService::createRow(const char *table, const char* context, const char* key, const char* value, time_t expiration)
 {
 #ifdef _DEBUG
     xmltooling::NDC ndc("createRow");
@@ -442,24 +469,25 @@ HRESULT ODBCStorageService::createRow(const char *table, const char* context, co
     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
 
     // Prepare and exectute insert statement.
-    string q  = string("INSERT ") + table + " VALUES ('" + context + "','" + key + "','" + value + "'," + timebuf + "')";
+    char *scontext = makeSafeSQL(context);
+    char *svalue = makeSafeSQL(value);
+    string q  = string("INSERT ") + table + " VALUES ('" + scontext + "','" + key + "','" + svalue + "'," + timebuf + "')";
+    freeSafeSQL(scontext, context)
+    freeSafeSQL(svalue, value)
     log->debug("SQL: %s", q.str());
 
-    HRESULT hr=NOERROR;
     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
     if (!SQL_SUCCEEDED(sr)) {
         log->error("insert record failed (t=%s, c=%s, k=%s", table, context, key);
         log_error(hstmt, SQL_HANDLE_STMT);
-        hr=E_FAIL;
     }
 
     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
-    return hr;
 }
 
 // read
 
-HRESULT ODBCStorageService::readRow(const char *table, const char* context, const char* key, string& pvalue, time_t& pexpiration, int maxsize)
+bool ODBCStorageService::readRow(const char *table, const char* context, const char* key, string& pvalue, time_t& pexpiration, int maxsize)
 {
 #ifdef _DEBUG
     xmltooling::NDC ndc("readRow");
@@ -475,8 +503,10 @@ HRESULT ODBCStorageService::readRow(const char *table, const char* context, cons
     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
 
     // Prepare and exectute select statement.
+    char *scontext = makeSafeSQL(context);
     string q = string("SELECT expires,value FROM ") + table +
-               " WHERE context='" + context + "' AND key='" + key + "'";
+               " WHERE context='" + scontext + "' AND key='" + key + "'";
+    freeSafeSQL(scontext, context)
     log->debug("SQL: %s", q.str());
 
     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
@@ -484,22 +514,22 @@ HRESULT ODBCStorageService::readRow(const char *table, const char* context, cons
         log->error("error searching for (t=%s, c=%s, k=%s)", table, context, key);
         log_error(hstmt, SQL_HANDLE_STMT);
         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
-        return E_FAIL;
+        return false;
     }
 
     // retrieve data 
     SQLBindCol(hstmt,1,SQL_C_TYPE_TIMESTAMP,&expires,0,NULL);
-    // SQLBindCol(hstmt,1,SQL_C_CHAR,value,sizeof(value),NULL);
 
     if ((sr=SQLFetch(hstmt)) == SQL_NO_DATA) {
         SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
-        return S_FALSE;
+        return false;
     }
 
     // expire time from bound col
     exp = timeFromTimestamp(expires);
     if (time(NULL)>ezp) {
         log->debug(".. expired");
+        SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
         return false;
     }
     if (pexpiration) pexpiration = exp;
@@ -522,7 +552,7 @@ HRESULT ODBCStorageService::readRow(const char *table, const char* context, cons
             log->error("error retriving value for (t=%s, c=%s, k=%s)", table, context, key);
             log_error(hstmt, SQL_HANDLE_STMT);
             SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
-            return E_FAIL;
+            return false;
         }
     }
     pvalue = string(tvalue);
@@ -530,18 +560,21 @@ HRESULT ODBCStorageService::readRow(const char *table, const char* context, cons
 
     log->debug(".. value found");
 
-    return sr;
+    SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
+    return true;
 }
 
 
 // update 
 
-HRESULT ODBCStorageService::updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration)
+bool ODBCStorageService::updateRow(const char *table, const char* context, const char* key, const char* value, time_t expiration)
 {
 #ifdef _DEBUG
     xmltooling::NDC ndc("updateRow");
 #endif
 
+    bool ret = true;
+
     char timebuf[32];
     timestampFromTime(expiration, timebuf);
 
@@ -555,53 +588,59 @@ HRESULT ODBCStorageService::updateRow(const char *table, const char* context, co
     string expstr = "";
     if (expiration) expstr = string(",expires = '") + timebuf + "' ";
 
-    string q  = string("UPDATE ") + table + " SET value='" + value + "'" + expstr + 
-               " WHERE context='" + context + "' AND key='" + key + "' AND expires > NOW()";
+    char *scontext = makeSafeSQL(context);
+    char *svalue = makeSafeSQL(value);
+    string q  = string("UPDATE ") + table + " SET value='" + svalue + "'" + expstr + 
+               " WHERE context='" + scontext + "' AND key='" + key + "' AND expires > NOW()";
+    freeSafeSQL(scontext, context)
+    freeSafeSQL(svalue, value)
     log->debug("SQL: %s", q.str());
 
-    HRESULT hr=NOERROR;
     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.str().c_str(), SQL_NTS);
     if (!SQL_SUCCEEDED(sr)) {
         log->error("update record failed (t=%s, c=%s, k=%s", table, context, key);
         log_error(hstmt, SQL_HANDLE_STMT);
-        hr=E_FAIL;
+        ret = false;
     }
 
     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
-    return hr;
+    return ret;
 }
 
 
 // delete
 
-HRESULT ODBCStorageService::deleteRow(const char *table, const char *context, const char* key)
+bool ODBCStorageService::deleteRow(const char *table, const char *context, const char* key)
 {
 #ifdef _DEBUG
     xmltooling::NDC ndc("deleteRow");
 #endif
 
+    bool ret = true;
+
     // Get statement handle.
     SQLHSTMT hstmt;
     ODBCConn conn(getHDBC());
     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
 
     // Prepare and execute delete statement.
-    string q = string("DELETE FROM ") + table + " WHERE context='" + context + "' AND key='" + key + "'";
+    char *scontext = makeSafeSQL(context);
+    string q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND key='" + key + "'";
+    freeSafeSQL(scontext, context)
     log->debug("SQL: %s", q.str());
 
     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
  
-    HRESULT hr=NOERROR;
-    if (sr==SQL_NO_DATA)
-        hr=S_FALSE;
-    else if (!SQL_SUCCEEDED(sr)) {
+    if (sr==SQL_NO_DATA) {
+        ret = false;
+    } else if (!SQL_SUCCEEDED(sr)) {
         log->error("error deleting record (t=%s, c=%s, k=%s)", table, context, key);
         log_error(hstmt, SQL_HANDLE_STMT);
-        hr=E_FAIL;
+        ret = false;
     }
 
     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
-    return hr;
+    return ret;
 }
 
 
@@ -668,7 +707,9 @@ void ODBCStorageService::reapRows(const char *table, const char* context)
     // Prepare and execute delete statement.
     string q;
     if (context) {
-        q = string("DELETE FROM ") + table + " WHERE context='" + context + "' AND expires<NOW()";
+        char *scontext = makeSafeSQL(context);
+        q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "' AND expires<NOW()";
+        freeSafeSQL(scontext, context)
     } else {
         q = string("DELETE FROM ") + table + " WHERE expires<NOW()";
     }
@@ -676,13 +717,9 @@ void ODBCStorageService::reapRows(const char *table, const char* context)
 
     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
  
-    HRESULT hr=NOERROR;
-    if (sr==SQL_NO_DATA)
-        hr=S_FALSE;
-    else if (!SQL_SUCCEEDED(sr)) {
+    if ((sr!=SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
         log->error("error expiring records (t=%s, c=%s)", table, context?context:"null");
         log_error(hstmt, SQL_HANDLE_STMT);
-        hr=E_FAIL;
     }
 
     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);
@@ -704,18 +741,16 @@ void ODBCStorageService::deleteCtx(const char *table, const char* context)
     SQLAllocHandle(SQL_HANDLE_STMT,conn,&hstmt);
 
     // Prepare and execute delete statement.
-    string q = string("DELETE FROM ") + table + " WHERE context='" + context + "'";
+    char *scontext = makeSafeSQL(context);
+    string q = string("DELETE FROM ") + table + " WHERE context='" + scontext + "'";
+    freeSafeSQL(scontext, context)
     log->debug("SQL: %s", q.str());
 
     SQLRETURN sr=SQLExecDirect(hstmt, (SQLCHAR*)q.c_str(), SQL_NTS);
  
-    HRESULT hr=NOERROR;
-    if (sr==SQL_NO_DATA)
-        hr=S_FALSE;
-    else if (!SQL_SUCCEEDED(sr)) {
+    if ((sr!=SQL_NO_DATA) && !SQL_SUCCEEDED(sr)) {
         log->error("error deleting context (t=%s, c=%s)", table, context);
         log_error(hstmt, SQL_HANDLE_STMT);
-        hr=E_FAIL;
     }
 
     SQLFreeHandle(SQL_HANDLE_STMT,hstmt);