Commits (2)
......@@ -189,20 +189,29 @@ public:
virtual TypeInfoPtr Next() = 0;
};
class SymbolProvider;
typedef boost::shared_ptr<SymbolProvider> SymbolProviderPtr;
class SymbolEnumerator;
typedef boost::shared_ptr< SymbolEnumerator> SymbolEnumeratorPtr;
class SymbolProvider {
public:
virtual SymbolEnumeratorPtr getSymbolEnumerator(const std::wstring& mask = L"") = 0;
};
class SymbolEnumerator {
public:
virtual std::wstring Next() = 0;
};
typedef boost::shared_ptr<SymbolEnumerator> SymbolEnumeratorPtr;
TypeInfoProviderPtr getTypeInfoProviderFromSource( const std::wstring& source, const std::wstring& opts = L"" );
TypeInfoProviderPtr getTypeInfoProviderFromSource(const std::string& source, const std::string& opts = "");
TypeInfoProviderPtr getTypeInfoProviderFromPdb( const std::wstring& pdbFile, MEMOFFSET_64 loadBase = 0 );
TypeInfoProviderPtr getDefaultTypeInfoProvider();
SymbolEnumeratorPtr getSymbolEnumeratorFromSource(const std::wstring& source, const std::wstring& opts = L"");
SymbolEnumeratorPtr getSymbolEnumeratorFromSource(const std::string& source, const std::string& opts = "");
SymbolProviderPtr getSymbolProviderFromSource(const std::wstring& source, const std::wstring& opts = L"");
SymbolProviderPtr getSymbolProviderFromSource(const std::string& source, const std::string& opts = "");
///////////////////////////////////////////////////////////////////////////////
......
......@@ -907,7 +907,7 @@ TypeInfoProviderPtr getTypeInfoProviderFromSource(const std::string& source, c
///////////////////////////////////////////////////////////////////////////////
SymbolEnumeratorClang::SymbolEnumeratorClang(const std::string& sourceCode, const std::string& compileOptions)
SymbolProviderClang::SymbolProviderClang(const std::string& sourceCode, const std::string& compileOptions)
{
std::vector<std::unique_ptr<ASTUnit>> ASTs;
ASTBuilderAction Action(ASTs);
......@@ -963,26 +963,41 @@ SymbolEnumeratorClang::SymbolEnumeratorClang(const std::string& sourceCode, con
///////////////////////////////////////////////////////////////////////////////
SymbolEnumeratorPtr SymbolProviderClang::getSymbolEnumerator(const std::wstring& mask)
{
return SymbolEnumeratorPtr(new SymbolEnumeratorClang(mask, shared_from_this()));
}
///////////////////////////////////////////////////////////////////////////////
std::wstring SymbolEnumeratorClang::Next()
{
if (m_index < m_symbols.size())
return strToWStr(m_symbols[m_index++]);
const auto& symbols = m_symbolProvider->m_symbols;
while (m_index < symbols.size())
{
const auto& sym = symbols[m_index++];
if (m_mask.empty() || fnmatch(m_mask, sym))
return strToWStr(sym);
}
return L"";
return std::wstring();
}
///////////////////////////////////////////////////////////////////////////////
SymbolEnumeratorPtr getSymbolEnumeratorFromSource(const std::wstring& source, const std::wstring& opts)
SymbolProviderPtr getSymbolProviderFromSource(const std::wstring& source, const std::wstring& opts)
{
return SymbolEnumeratorPtr( new SymbolEnumeratorClang(wstrToStr(source), wstrToStr(opts) ) );
return SymbolProviderPtr( new SymbolProviderClang(wstrToStr(source), wstrToStr(opts) ) );
}
///////////////////////////////////////////////////////////////////////////////
SymbolEnumeratorPtr getSymbolEnumeratorFromSource(const std::string& source, const std::string& opts)
SymbolProviderPtr getSymbolProviderFromSource(const std::string& source, const std::string& opts)
{
return SymbolEnumeratorPtr(new SymbolEnumeratorClang(source, opts));
return SymbolProviderPtr(new SymbolProviderClang(source, opts));
}
///////////////////////////////////////////////////////////////////////////////
......
......@@ -400,12 +400,37 @@ private:
};
class SymbolEnumeratorClang : public SymbolEnumerator, public boost::enable_shared_from_this<SymbolEnumeratorClang>
class SymbolEnumeratorClang;
class SymbolProviderClang : public SymbolProvider, public boost::enable_shared_from_this< SymbolProviderClang>
{
public:
friend SymbolEnumeratorClang;
SymbolProviderClang(const std::string& sourceCode, const std::string& compileOptions);
private:
SymbolEnumeratorPtr getSymbolEnumerator(const std::wstring& mask = L"") override;
ClangASTSessionPtr m_astSession;
std::vector<std::string> m_symbols;
};
class SymbolEnumeratorClang : public SymbolEnumerator
{
public:
SymbolEnumeratorClang(const std::string& sourceCode, const std::string& compileOptions);
SymbolEnumeratorClang(const std::wstring& mask, const boost::shared_ptr<SymbolProviderClang>& clangProvider) :
m_symbolProvider(clangProvider),
m_index(0),
m_mask(wstrToStr(mask))
{}
private:
......@@ -413,9 +438,11 @@ private:
private:
size_t m_index = 0;
size_t m_index;
std::vector<std::string> m_symbols;
std::string m_mask;
boost::shared_ptr<SymbolProviderClang> m_symbolProvider;
};
}
......@@ -53,22 +53,6 @@ static SymbolSessionPtr createSession(
do {
hres = dataSource.CoCreateInstance(__uuidof(DiaSource), NULL, CLSCTX_INPROC_SERVER);
if ( S_OK == hres )
break;
hres = dataSource.CoCreateInstance(__uuidof(DiaSourceAlt), NULL, CLSCTX_INPROC_SERVER);
if (S_OK == hres)
break;
hres = dataSource.CoCreateInstance(MSDIA12_CLASSGUID, NULL, CLSCTX_INPROC_SERVER);
if (S_OK == hres)
break;
hres = dataSource.CoCreateInstance(MSDIA11_CLASSGUID, NULL, CLSCTX_INPROC_SERVER);
if (S_OK == hres)
break;
HMODULE hModule = NULL;
if ( !GetModuleHandleEx(
......@@ -93,23 +77,63 @@ static SymbolSessionPtr createSession(
if (S_OK == hres)
break;
pos = fileName.find_last_of(L'\\');
fileName.replace(pos, fileName.length() - pos, L"\\msdia120.dll");
hres = dataSource.CoCreateInstance(__uuidof(DiaSource), NULL, CLSCTX_INPROC_SERVER);
if ( S_OK == hres )
break;
hres = NoRegCoCreate(fileName.c_str(), MSDIA12_CLASSGUID, __uuidof(IDiaDataSource), (void**)&dataSource);
hres = dataSource.CoCreateInstance(__uuidof(DiaSourceAlt), NULL, CLSCTX_INPROC_SERVER);
if (S_OK == hres)
break;
pos = fileName.find_last_of(L'\\');
fileName.replace(pos, fileName.length() - pos, L"\\msdia110.dll");
hres = dataSource.CoCreateInstance(MSDIA12_CLASSGUID, NULL, CLSCTX_INPROC_SERVER);
if (S_OK == hres)
break;
hres = NoRegCoCreate(fileName.c_str(), MSDIA11_CLASSGUID, __uuidof(IDiaDataSource), (void**)&dataSource);
hres = dataSource.CoCreateInstance(MSDIA11_CLASSGUID, NULL, CLSCTX_INPROC_SERVER);
if (S_OK == hres)
break;
throw DiaException(L"Call ::CoCreateInstance", hres);
//HMODULE hModule = NULL;
//if ( !GetModuleHandleEx(
// GET_MODULE_HANDLE_EX_FLAG_FROM_ADDRESS | GET_MODULE_HANDLE_EX_FLAG_UNCHANGED_REFCOUNT,
// (LPCTSTR)createSession,
// &hModule) )
// throw DiaException(L"failed to load msdia library");
//DWORD fileNameSize = 0x1000;
//
//std::vector<wchar_t> fileNameBuffer(fileNameSize);
//GetModuleFileNameW(hModule, &fileNameBuffer[0], fileNameSize);
//std::wstring fileName(&fileNameBuffer[0], fileNameSize);
//size_t pos = fileName.find_last_of(L'\\');
//fileName.replace(pos, fileName.length() - pos, L"\\msdia140.dll");
//hres = NoRegCoCreate(fileName.c_str(), MSDIA14_CLASSGUID, __uuidof(IDiaDataSource), (void**)&dataSource);
//if (S_OK == hres)
// break;
//pos = fileName.find_last_of(L'\\');
//fileName.replace(pos, fileName.length() - pos, L"\\msdia120.dll");
//hres = NoRegCoCreate(fileName.c_str(), MSDIA12_CLASSGUID, __uuidof(IDiaDataSource), (void**)&dataSource);
//if (S_OK == hres)
// break;
//pos = fileName.find_last_of(L'\\');
//fileName.replace(pos, fileName.length() - pos, L"\\msdia110.dll");
//hres = NoRegCoCreate(fileName.c_str(), MSDIA11_CLASSGUID, __uuidof(IDiaDataSource), (void**)&dataSource);
//if (S_OK == hres)
// break;
//throw DiaException(L"Call ::CoCreateInstance", hres);
} while( FALSE);
......
......@@ -206,14 +206,7 @@ public:
m_globalSymbol( DiaSymbol::fromGlobalScope( globalScope, getScopeName( session, globalScope ) ) ),
m_session( session ),
m_symbolFileName( symbolFile )
{
CoInitialize(NULL);
}
~DiaSession()
{
CoUninitialize();
}
{}
virtual SymbolPtr getSymbolScope() {
return m_globalSymbol;
......
......@@ -83,7 +83,7 @@ DebugManager::~DebugManager()
client->SetOutputCallbacks(NULL);
CoUninitialize();
//CoUninitialize();
}
///////////////////////////////////////////////////////////////////////////////
......
......@@ -568,7 +568,7 @@ TEST_F(ClangTest, EnumFuncNames)
}; \
";
SymbolEnumeratorPtr symEnum = getSymbolEnumeratorFromSource(srcCode);
auto symEnum = getSymbolProviderFromSource(srcCode)->getSymbolEnumerator();
std::wstring symbol;
std::vector<std::wstring> symbols;
......@@ -613,8 +613,16 @@ TEST_F(ClangTest, Func)
EXPECT_EQ(L"Void(__cdecl)()", compileType(srcCode, L"func1")->getName());
EXPECT_EQ(L"Int4B(__cdecl)()", compileType(srcCode, L"func2<2>")->getName());
EXPECT_EQ(L"Char(__cdecl)()", compileType(srcCode, L"testns::func3")->getName());
EXPECT_EQ(L"Void(__cdecl testcls::)(Int4B)", compileType(srcCode, L"testcls::method")->getName());
EXPECT_EQ(L"Void(__cdecl testcls1<int>::)()", compileType(srcCode, L"testcls1<int>::method")->getName());
if (kdlib::is64bitSystem())
{
EXPECT_EQ(L"Void(__cdecl testcls::)(Int4B)", compileType(srcCode, L"testcls::method")->getName());
EXPECT_EQ(L"Void(__cdecl testcls1<int>::)()", compileType(srcCode, L"testcls1<int>::method")->getName());
}
else
{
EXPECT_EQ(L"Void(__thiscall testcls::)(Int4B)", compileType(srcCode, L"testcls::method")->getName());
EXPECT_EQ(L"Void(__thiscall testcls1<int>::)()", compileType(srcCode, L"testcls1<int>::method")->getName());
}
EXPECT_THROW(compileType(srcCode, L"func2"), TypeException);
EXPECT_THROW(compileType(srcCode, L"func3"), TypeException);
......