#include "pch.h"
template<class TInterfaceClass>
class CInternetSessionImpl :
public TInterfaceClass
{
public:
CInternetSessionImpl(IInternetSessionSink * pSink) : m_buffer(NULL),
m_hFile(NULL),
m_pszFileList(NULL),
m_pszFileListData(NULL),
m_nLastErrorCode(0),
m_nLastErrorCodeInThread(0),
m_cBytesRead(0),
m_bAbortDownload(false),
m_hFTPSession(NULL),
m_hFileConnection(NULL),
m_pUpdateSink(pSink)
{
m_szLastError[0] = 0;
m_szErrorInThread[0] = 0;
m_eventResumeDownload = CreateEvent(NULL, FALSE, FALSE, NULL);
m_eventKillDownload = CreateEvent(NULL, TRUE, FALSE, NULL);
m_eventProgress = CreateEvent(NULL, FALSE, FALSE, NULL);
m_eventDownloadTerminated = CreateEvent(NULL, TRUE, FALSE, NULL);
m_eventFileCompleted = CreateEvent(NULL, FALSE, FALSE, NULL);
debugf("Creating download thread.\n");
DWORD dum;
m_threadDownload = CreateThread(NULL, 0, DownloadThread, (void*)this, 0, &dum);
if (m_pUpdateSink && m_threadDownload == NULL)
debugf("Failed to create thread.\n");
}
virtual ~CInternetSessionImpl()
{
KillDownload();
Disconnect();
CloseHandle(m_eventResumeDownload);
CloseHandle(m_eventKillDownload);
CloseHandle(m_eventProgress);
CloseHandle(m_eventDownloadTerminated);
CloseHandle(m_eventFileCompleted);
CloseHandle(m_threadDownload);
if (m_pszFileList)
{
char * * psz = m_pszFileListData + 1; while (psz && *psz)
free(*(psz++));
delete[] m_pszFileListData;
}
}
virtual int GetFileListIncrement() = 0;
void SetSink(IInternetSessionSink * pUpdateSink)
{
m_pUpdateSink = pUpdateSink;
}
bool InitiateDownload(const char * const * pszFileList,
const char * szDestFolder,
bool bDisconnectWhenDone,
int nMaxBufferSize)
{
{
char * * psz = const_cast<char * *>(pszFileList);
int i = 0;
while (psz && *psz)
{
i++;
psz++;
}
m_pszFileListData = new char*[i+2];
m_pszFileListData[0] = (char*)-1; psz = const_cast<char * *>(pszFileList);
i = 1;
while (psz && *psz)
{
m_pszFileListData[i] = _strdup(*psz);
i++;
psz++;
}
m_pszFileListData[i] = NULL;
m_pszFileList = m_pszFileListData;
}
strcpy_s(m_szDestFolder, _MAX_PATH, szDestFolder);
if (m_szDestFolder[0] == '\0' || m_szDestFolder[strlen(m_szDestFolder)-1] != '\\')
{
strcat_s(m_szDestFolder, _MAX_PATH, "\\");
}
m_bAutoDisconnect = bDisconnectWhenDone;
m_nBufferSize = nMaxBufferSize;
if(m_buffer == NULL)
m_buffer = (char*)::VirtualAlloc(NULL, m_nBufferSize, MEM_COMMIT, PAGE_READWRITE);
assert(m_buffer);
m_cTotalBytesRead = 0;
ResetEvent(m_eventResumeDownload);
ResetEvent(m_eventKillDownload);
ResetEvent(m_eventProgress);
ResetEvent(m_eventDownloadTerminated);
ResetEvent(m_eventFileCompleted);
return true;
}
bool Disconnect()
{
FinishCurrentFile(false);
if (m_hInternetSession)
{
if (!InternetCloseHandle(m_hInternetSession))
{
DoError("Disconnect Failed");
return false;
}
m_hInternetSession = NULL;
}
if (m_buffer)
{
::VirtualFree((void*)m_buffer, 0, MEM_RELEASE);
m_buffer = NULL;
}
return true;
}
const char* GetDownloadPath()
{
return m_szDestFolder;
}
const char* GetLastErrorMessage()
{
if (m_szLastError[0] != '\0')
{
return m_szLastError;
}
else return NULL; }
void Abort(bool bAutoDisconnect)
{
KillDownload();
m_bAbortDownload = true;
if (bAutoDisconnect)
Disconnect();
}
bool ContinueDownload()
{
if (m_szLastError[0] != '\0') return false;
if (WaitForSingleObject(m_eventKillDownload, 0) == WAIT_OBJECT_0)
{
if (m_szErrorInThread[0] != 0)
{
SetLastError(m_nLastErrorCodeInThread);
DoError(m_szErrorInThread);
}
else
if (m_bAutoDisconnect)
Disconnect();
return false; }
if (*m_pszFileList != NULL && !m_bAbortDownload) {
if (WaitForSingleObject(m_eventProgress, 0) == WAIT_OBJECT_0)
{
if (m_cBytesRead == m_nBufferSize)
{
if (!FlushDownloadBuffer())
return false;
}
if(m_pUpdateSink)
{
m_pUpdateSink->OnProgress(m_cTotalBytesRead, *m_pszFileList, m_cCurrentFileBytesRead);
}
SetEvent(m_eventResumeDownload);
}
else if(WaitForSingleObject(m_eventFileCompleted, 0) == WAIT_OBJECT_0)
{
FinishCurrentFile(true);
SetEvent(m_eventResumeDownload);
}
return true; }
else {
if (m_pUpdateSink)
m_pUpdateSink->OnTransferFinished();
if (m_bAutoDisconnect)
Disconnect();
return false;
}
}
protected:
enum DOWNLOAD_RESULT
{
DOWNLOAD_ERROR,
DOWNLOAD_PROGRESS,
FILE_COMPLETED,
};
static DWORD WINAPI DownloadThread(LPVOID pThreadParameter)
{
CInternetSessionImpl * pSession = (CInternetSessionImpl *) pThreadParameter;
HANDLE pHandles[] = { pSession->m_eventKillDownload, pSession->m_eventResumeDownload };
while (WaitForMultipleObjects(2, pHandles, FALSE, INFINITE) != WAIT_OBJECT_0)
{
if (pSession->m_hFile == NULL)
{
if (!pSession->StartNextFile())
{
SetEvent(pSession->m_eventKillDownload);
break;
}
}
if (pSession->m_hFile != NULL)
{
DOWNLOAD_RESULT result = DOWNLOAD_ERROR; __try
{
result = pSession->DownloadFileBlock();
}
__except(1)
{
result = DOWNLOAD_ERROR;
}
if (result == DOWNLOAD_PROGRESS)
{
SetEvent(pSession->m_eventProgress);
}
else
if (result == FILE_COMPLETED)
{
SetEvent(pSession->m_eventFileCompleted);
}
else
if (result == DOWNLOAD_ERROR)
{
SetEvent(pSession->m_eventKillDownload);
break;
}
}
}
debugf("Download thread exiting...\n");
SetEvent(pSession->m_eventDownloadTerminated);
ExitThread(0);
return 0;
}
virtual bool StartNextFile()
{
m_cBytesRead = 0;
m_cCurrentFileBytesRead = 0;
++m_pszFileList;
unsigned cTries = 0;
if (*m_pszFileList)
{
while (!(m_hFileConnection = FtpOpenFile(m_hFTPSession, *m_pszFileList, GENERIC_READ, FTP_TRANSFER_TYPE_BINARY | INTERNET_FLAG_RELOAD, 0)))
{
cTries++;
debugf("Failed to open file via FTP for download, try #%d; error code: %d\n", cTries, GetLastError());
Sleep(500);
if(cTries >= 10) {
DoErrorInThread("Failed to open file (%s) for download.", *m_pszFileList);
return false;
}
}
return OpenDownloadFile();
}
return true;
}
bool FinishCurrentFile(bool bCompleted) {
if (m_hFileConnection)
{
if (!InternetCloseHandle(m_hFileConnection))
{
DoError("InternetCloseHandle() Failed for download file");
return false;
}
m_hFileConnection = NULL;
}
if (!FlushDownloadBuffer())
return false;
if (!CloseDownloadFile(bCompleted))
return false;
return true;
}
DOWNLOAD_RESULT DownloadFileBlock()
{
unsigned long cBytesAvail, cBytesJustRead;
if (!InternetQueryDataAvailable(m_hFileConnection, &cBytesAvail, 0, 0))
{
DoErrorInThread("InternetQueryDataAvailable() Failed.");
return DOWNLOAD_ERROR;
}
if (cBytesAvail == 0)
{
return FILE_COMPLETED;
}
unsigned long cBytesAttempted = min(cBytesAvail, m_nBufferSize-m_cBytesRead);
if (!InternetReadFile((void*)m_hFileConnection, (void*)(m_buffer+m_cBytesRead), cBytesAttempted, &cBytesJustRead))
{
DoErrorInThread("InternetReadFile() Failed.");
return DOWNLOAD_ERROR;
}
if (cBytesJustRead == 0) {
return FILE_COMPLETED;
}
m_cBytesRead += cBytesJustRead;
m_cCurrentFileBytesRead += cBytesJustRead;
m_cTotalBytesRead += cBytesJustRead;
return DOWNLOAD_PROGRESS;
}
bool OpenDownloadFile()
{
char szFilename[MAX_PATH+20];
strcpy(szFilename, m_szDestFolder);
strcat(szFilename, *m_pszFileList);
m_hFile = CreateFile(szFilename,
GENERIC_WRITE,
FILE_SHARE_READ,
NULL,
CREATE_ALWAYS,
FILE_ATTRIBUTE_TEMPORARY, NULL);
if (m_hFile == INVALID_HANDLE_VALUE)
{
DoErrorInThread("Failed create file (%s) on local drive.", szFilename);
return false;
}
return true;
}
bool FlushDownloadBuffer()
{
unsigned long cBytesWritten;
if (m_cBytesRead != 0)
{
if(m_pUpdateSink)
{
if(m_pUpdateSink->OnDataReceived((void*)m_buffer, m_cBytesRead) == false)
{
::InternetCloseHandle(m_hFileConnection);
m_hFileConnection = NULL;
::CloseHandle(m_hFile);
m_hFile = NULL;
}
}
if (!WriteFile(m_hFile, (void*)m_buffer, m_cBytesRead, &cBytesWritten, NULL))
{
DoError("Failed to write the file (%s) to local drive : ", *m_pszFileList);
return false;
}
m_cBytesRead = 0;
}
return true;
}
bool CloseDownloadFile(bool bCompleted)
{
if (m_hFile == NULL)
return true;
if (!::CloseHandle(m_hFile))
DoError("Failed to close file %s", *m_pszFileList);
m_hFile = NULL;
if(m_pUpdateSink && bCompleted)
{
if(!m_pUpdateSink->OnFileCompleted(*m_pszFileList))
{
m_pszFileList-= GetFileListIncrement();
m_cTotalBytesRead -= m_cCurrentFileBytesRead;
}
}
return true;
}
void FormatErrorMessage(char *szBuffer, DWORD dwErrorCode)
{
sprintf(szBuffer,"(%d) ", dwErrorCode);
FormatMessage(FORMAT_MESSAGE_FROM_SYSTEM |
FORMAT_MESSAGE_IGNORE_INSERTS,
NULL,
dwErrorCode,
MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
szBuffer + strlen(szBuffer),
128,
NULL
);
strcat(m_szLastError, " ");
unsigned long dummy, size = sizeof(m_szLastError) - strlen(szBuffer) - 2;
InternetGetLastResponseInfo(&dummy, szBuffer + strlen(szBuffer), &size);
}
void DoError(char * szFormat, ...)
{
if (m_szLastError[0] != 0) return;
m_nLastErrorCode = GetLastError();
char szMsg[sizeof(m_szLastError) - 50];
va_list pArg;
va_start(pArg, szFormat);
_vsnprintf_s(szMsg, sizeof(szMsg), sizeof(szMsg), szFormat, pArg);
va_end(pArg);
strcpy(m_szLastError, szMsg);
FormatErrorMessage(m_szLastError + strlen(m_szLastError), m_nLastErrorCode);
CloseDownloadFile(false);
SetLastError(m_nLastErrorCode); if(m_pUpdateSink)
m_pUpdateSink->OnError(m_szLastError);
}
void DoErrorInThread(char * szFormat, ...)
{
if (m_szErrorInThread[0] != 0) return;
m_nLastErrorCodeInThread = GetLastError();
char szMsg[sizeof(m_szErrorInThread) - 50];
va_list pArg;
va_start(pArg, szFormat);
_vsnprintf(szMsg, sizeof(szMsg), szFormat, pArg);
va_end(pArg);
strcpy(m_szErrorInThread, szMsg);
}
void KillDownload()
{
if (m_threadDownload)
{
SetEvent(m_eventKillDownload); int nAwaker = WaitForSingleObject(m_eventDownloadTerminated, 5000); if (nAwaker == WAIT_TIMEOUT)
{
TerminateThread(m_threadDownload, 0);
}
CloseHandle(m_threadDownload);
m_threadDownload = NULL;
}
}
protected:
char * * m_pszFileList;
char * * m_pszFileListData;
char m_szDestFolder[MAX_PATH];
volatile HINTERNET m_hInternetSession;
volatile HINTERNET m_hFTPSession;
volatile HINTERNET m_hFileConnection;
volatile HANDLE m_hFile;
HANDLE m_eventResumeDownload; HANDLE m_eventKillDownload; HANDLE m_eventDownloadTerminated; HANDLE m_eventProgress; HANDLE m_eventFileCompleted; HANDLE m_threadDownload;
volatile char * m_buffer;
volatile unsigned m_cBytesRead; volatile unsigned m_nBufferSize; char m_szLastError[1024];
char m_szErrorInThread[MAX_PATH+100];
int m_nLastErrorCode;
int m_nLastErrorCodeInThread;
bool m_bAutoDisconnect; bool m_bAbortDownload; volatile unsigned long m_cTotalBytesRead; volatile unsigned long m_cCurrentFileBytesRead; IInternetSessionSink * m_pUpdateSink;
};
/// FTP download session.
/// ConnectToSite() logs on and enters a remote directory. InitiateDownload() then fetches
/// a NULL-terminated list of file names from that directory into a local folder.
class CFTPSessionImpl:
    public CInternetSessionImpl<IFTPSession>
{
public:
    CFTPSessionImpl(IFTPSessionUpdateSink * pSink) :
        CInternetSessionImpl<IFTPSession>(pSink)
    {
    }

    virtual ~CFTPSessionImpl()
    {
    }

    /// Opens a WinINet session through the "ftp-gw" proxy, logs on and changes directory.
    /// Clears any previously recorded error first.
    /// @return false (reported through DoError) if any of those steps fails.
    virtual bool ConnectToSite(const char * szFTPSite, const char * szDirectory, const char * szUsername, const char * szPassword)
    {
        m_szLastError[0] = '\0';
        // NOTE(review): calling this again without Disconnect() overwrites (leaks) the
        // handles of the previous connection.
        m_hInternetSession = ::InternetOpen(
            "Microsoft Internet Explorer", INTERNET_OPEN_TYPE_PROXY, "ftp-gw", NULL, 0);
        if (m_hInternetSession == NULL)
        {
            DoError("Failed to initialize FTP stuff.");
            return false;
        }
        // The context parameter is pointer-sized; a DWORD cast truncates `this` on 64-bit.
        m_hFTPSession = ::InternetConnect(
            m_hInternetSession, szFTPSite, INTERNET_INVALID_PORT_NUMBER, szUsername, szPassword,
            INTERNET_SERVICE_FTP, 0, reinterpret_cast<DWORD_PTR>(this));
        if (m_hFTPSession == NULL)
        {
            DoError("Failed to log onto FTP site (%s) : ", szFTPSite);
            return false;
        }
        if (!FtpSetCurrentDirectory(m_hFTPSession, szDirectory))
        {
            DoError("Failed to enter the proper FTP directory (%s) : ", szDirectory);
            return false;
        }
        return true;
    }

    /// Each file consumes one list entry (its remote name, also used as the local name).
    virtual int GetFileListIncrement()
    {
        return 1;
    }

    /// Prepares the transfer, reports zero progress for the first file and starts the worker.
    /// @return false if the base preparation failed (the worker is then not started).
    virtual bool InitiateDownload(const char * const * pszFileList,
                                  const char * szDestFolder,
                                  bool bDisconnectWhenDone = true,
                                  int nMaxBufferSize = 1024*1024)
    {
        // NOTE(review): bDisconnectWhenDone is not forwarded; the base is always told not to
        // auto-disconnect. Confirm whether that is intentional.
        bool bRet = CInternetSessionImpl<IFTPSession>::InitiateDownload(pszFileList, szDestFolder, false, nMaxBufferSize);
        if (!bRet)
            return false;
        // The copied list starts with a placeholder slot, so the first file is at index 1.
        if (m_pUpdateSink && pszFileList != NULL && *pszFileList != NULL)
            m_pUpdateSink->OnProgress(0, *(m_pszFileList + 1), 0);
        SetEvent(m_eventResumeDownload);
        return true;
    }
};
/// HTTP download session.
/// The file list holds (URL, local file name) pairs and is NULL-terminated.
class CHTTPSessionImpl :
    public CInternetSessionImpl<IHTTPSession>
{
public:
    /// Opens the WinINet session immediately; check ConstructionSuccess() afterwards.
    CHTTPSessionImpl(IHTTPSessionSink * pSink) :
        CInternetSessionImpl<IHTTPSession>(pSink)
    {
        m_hInternetSession = ::InternetOpen(
            "Microsoft Internet Explorer", INTERNET_OPEN_TYPE_PRECONFIG, NULL, NULL, 0);
        if (m_hInternetSession == NULL)
            DoError("Failed to initialize HTTP stuff.");
    }

    virtual ~CHTTPSessionImpl()
    {
    }

    /// True when the constructor managed to open the WinINet session.
    bool ConstructionSuccess()
    {
        return m_hInternetSession != NULL;
    }

    /// Each file consumes two list entries: its URL and its local file name.
    virtual int GetFileListIncrement()
    {
        return 2;
    }

    /// Worker-thread step: opens the next URL, then creates its local file.
    /// The cursor is left on the local-name entry.
    /// @return true when the file is open or the list is exhausted; false on failure (error
    ///         recorded by DoErrorInThread) or when a kill was requested during the retries.
    virtual bool StartNextFile()
    {
        m_cBytesRead = 0;
        m_cCurrentFileBytesRead = 0;
        ++m_pszFileList;                // -> URL of the next pair
        if (*m_pszFileList == NULL)
            return true;

        unsigned cTries = 0;
        while (!(m_hFileConnection = InternetOpenUrl(m_hInternetSession, *m_pszFileList, NULL, 0, INTERNET_FLAG_RELOAD | INTERNET_FLAG_NO_CACHE_WRITE, 0)))
        {
            cTries++;
            debugf("Failed to open URL(%s) for download, try #%d\n", *m_pszFileList, cTries);
            if (cTries >= 5)
            {
                DoErrorInThread("Failed to open file for download.");
                return false;
            }
            // Pause between retries, but stop at once if a kill was requested
            // (the kill event is manual-reset, so this wait does not consume it).
            if (WaitForSingleObject(m_eventKillDownload, 500) == WAIT_OBJECT_0)
                return false;
        }
        ++m_pszFileList;                // -> local file name of the pair
        if (*m_pszFileList == NULL)
        {
            DoErrorInThread("FileList has bad format");
            return false;
        }
        return OpenDownloadFile();
    }

    /// Prepares the transfer, reports zero progress for the first local file and starts the worker.
    /// @return false if the base preparation failed (the worker is then not started).
    virtual bool InitiateDownload(const char * const * pszFileList,
                                  const char * szDestFolder,
                                  int nMaxBufferSize = 1024*1024)
    {
        bool bRet = CInternetSessionImpl<IHTTPSession>::InitiateDownload(pszFileList, szDestFolder, false, nMaxBufferSize);
        if (!bRet)
            return false;
        // Copied list: [placeholder, URL1, local1, ..., NULL]. Check slot 1 first so an empty
        // list (where slot 1 is the terminator) never reads past the end of the array.
        if (m_pUpdateSink && m_pszFileList[1] != NULL && m_pszFileList[2] != NULL)
            m_pUpdateSink->OnProgress(0, m_pszFileList[2], 0);
        SetEvent(m_eventResumeDownload);
        return true;
    }
};
/// Creates an HTTP download session.
/// @param pUpdateSink  notification receiver; may be NULL.
/// @return the new session, or NULL if its WinINet session could not be opened. In that
///         case the half-built object is destroyed rather than leaked.
IHTTPSession * CreateHTTPSession(IHTTPSessionSink * pUpdateSink )
{
    CHTTPSessionImpl * pNew = new CHTTPSessionImpl(pUpdateSink);
    if (pNew && !pNew->ConstructionSuccess())
    {
        delete pNew;
        pNew = NULL;
    }
    return pNew;
}
/// Creates an FTP download session. No connection exists yet: ConnectToSite() opens it.
/// @param pUpdateSink  notification receiver; may be NULL.
IFTPSession * CreateFTPSession(IFTPSessionUpdateSink * pUpdateSink )
{
    return new CFTPSessionImpl(pUpdateSink);
}