memsentry项目源码注释

memsentry github地址:https://github.com/vusec/memsentry

主要是三个文件:BenchDomain.cpp、BenchDomainPost.cpp 和 MemSentry.cpp。它们分别负责:向某些特定位置插入 memory access 以实现 benchmark 的目的;去掉之前插入的 memory access 来达到不修改源文件的目的;以及 MemSentry pass 的主体。

BenchDomain.cpp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
/*
* Benchmark performance of domain-based methods by inserting
* domain-switches at specified points. For benchmarking address-based methods,
* this pass is not needed.
*
* Usage:
* This pass inserts memory accesses (tagged as 'safe') at the specified points
* (e.g., at every call/ret). The MemSentry pass will then insert domain
* switches for these memory accesses. Finally, the MemSentryBenchDomainPost
* pass will remove the memory accesses, but leave the domain switches. This can
* thus be used to benchmark the performance of domain-based approaches at
* different frequencies of required switches.
* Passes should thus be used as follows:
* -memsentry-benchdomain -memsentry -memsentry-benchdomain-post
*
* -memsentry-benchdomain-points=[call-ret,icall,libfunc]
* -memsentry-benchdomain-libfunc-file=<file>
*
* call-ret: insert mem access before every call and return
* icall: insert mem access before every indirect call
* libfunc: insert mem access before every library function call from
* specified list.
*/

#define DEBUG_TYPE "memsentry-benchdomain"
#include "utils.h"

#include <set>
#include <fstream>

#include "types.h"
#include "memsentry-pass.h"

using namespace llvm;

// Program points at which the benchmark pass may plant a dummy safe-region
// access; selected with -memsentry-benchdomain-points.
enum points {
CALLRET, // before every call and every return
ICALL, // before every indirect call
LIBFUNC, // before direct calls to functions named in the libfunc file
};

// Command-line option "-memsentry-benchdomain-points": chooses which kind of
// program point is instrumented. Defaults to CALLRET.
cl::opt<points> Points("memsentry-benchdomain-points", // name
cl::desc("What points should be treated as safe-region accesses:"), // description
cl::values( // listed values
clEnumValN(CALLRET, "call-ret", "Every call and return"),
clEnumValN(ICALL, "icall", "Indirect calls"),
clEnumValN(LIBFUNC, "libfunc", "Library functions (syscalls), specify list with -memsentry-benchdomain-libfunc-file."),
clEnumValEnd), cl::init(CALLRET) // default value
);

// Command-line option "-memsentry-benchdomain-libfunc-file": path to a text
// file with one function name per line. Direct calls to those functions are
// instrumented when the "libfunc" point type is selected.
static cl::opt<std::string> LibFuncFile("memsentry-benchdomain-libfunc-file",
cl::desc("Path to file containing (per line) functions that, when called, should be instrumented."));

// 继承ModulePass
struct MemSentryBenchDomain : public ModulePass {
public:
static char ID;
MemSentryBenchDomain() : ModulePass(ID) {}

virtual bool runOnModule(Module &M);

private:
std::set<std::string> libFuncSet;

void initLibFuncs();
void handleInst(Instruction *I);
bool shouldInstrCallRet(Instruction *I);
bool shouldInstrICall(Instruction *I);
bool shouldInstrLibFunc(Instruction *I);
};

/*
 * CALLRET mode predicate.
 *
 * Returns true for every return instruction, and for every call/invoke whose
 * callee is either unknown (getCalledFunction() yields null) or accepted by
 * shouldInstrument(). Everything else yields false.
 */
bool MemSentryBenchDomain::shouldInstrCallRet(Instruction *I) {
  if (isa<ReturnInst>(I))
    return true;

  if (!isa<CallInst>(I) && !isa<InvokeInst>(I))
    return false;

  /* Try to only instrument calls where we can also insert the
   * corresponding switch at the return, to better simulate what a defense
   * could do. */
  // CallSite wraps CallInst and InvokeInst behind one interface.
  CallSite Site(I);
  if (Function *Callee = Site.getCalledFunction())
    return shouldInstrument(*Callee);

  // No statically known callee: always instrument.
  return true;
}

// ICALL mode predicate: true when I is a call/invoke with no statically known
// callee (getCalledFunction() returns null), i.e. a call made through a
// function pointer.
bool MemSentryBenchDomain::shouldInstrICall(Instruction *I) {
  if (!isa<CallInst>(I) && !isa<InvokeInst>(I))
    return false;

  CallSite Site(I);
  return Site.getCalledFunction() == nullptr;
}

// LIBFUNC mode predicate: true when I is a direct call/invoke to a function
// whose name appears in libFuncSet.
bool MemSentryBenchDomain::shouldInstrLibFunc(Instruction *I) {
  if (!isa<CallInst>(I) && !isa<InvokeInst>(I))
    return false;

  CallSite Site(I);
  Function *Callee = Site.getCalledFunction();
  // Calls without a known callee cannot be matched by name.
  if (Callee == nullptr)
    return false;

  return libFuncSet.find(Callee->getName()) != libFuncSet.end();
}

/*
 * Instruments a single instruction.
 *
 * If I matches the point type chosen on the command line, a dummy volatile
 * store of an i8 zero to a null i8* is inserted directly before it. The store
 * is tagged with "memsentry.benchdomain.dummy" metadata (so the post pass can
 * find and delete it) and marked as a safe-region access.
 */
void MemSentryBenchDomain::handleInst(Instruction *I) {
  bool Selected;
  switch (Points) {
  case CALLRET: Selected = shouldInstrCallRet(I); break;
  case ICALL:   Selected = shouldInstrICall(I);   break;
  case LIBFUNC: Selected = shouldInstrLibFunc(I); break;
  default:
    assert(0);
    Selected = true;
    break;
  }
  if (!Selected)
    return;

  // Build the dummy store immediately in front of I.
  IRBuilder<> Builder(I);
  Value *NullBytePtr = Constant::getNullValue(Builder.getInt8PtrTy());
  StoreInst *Marker =
      Builder.CreateStore(Builder.getInt8(0), NullBytePtr, true);

  MDNode *EmptyTag = MDNode::get(I->getContext(), {});
  Marker->setMetadata("memsentry.benchdomain.dummy", EmptyTag);

  memsentry_saferegion_access(Marker);
}

// Loads the function names to instrument from the file named by LibFuncFile
// (one name per line) into libFuncSet, then echoes the loaded names to errs().
//
// A file that cannot be opened is reported instead of silently producing an
// empty set. Blank lines are skipped, since an empty name would match unnamed
// functions. A trailing '\r' is stripped so CRLF files match as well.
void MemSentryBenchDomain::initLibFuncs() {
  const std::string &path = LibFuncFile;
  std::ifstream input(path);
  if (!input.is_open()) {
    errs() << "Warning: cannot open libfunc file '" << path << "'\n";
    return;
  }

  std::string line;
  while (std::getline(input, line)) {
    if (!line.empty() && line.back() == '\r')
      line.pop_back();
    if (line.empty())
      continue;
    libFuncSet.insert(line);
  }

  for (const std::string &name : libFuncSet)
    errs() << " F: " << name << "\n";
}

// Pass entry point: loads the libfunc list when LIBFUNC mode is selected, then
// offers every instruction of every function accepted by shouldInstrument()
// to handleInst(). Always reports the module as modified.
bool MemSentryBenchDomain::runOnModule(Module &M) {
  const bool needLibFuncs = (Points == LIBFUNC);
  if (needLibFuncs)
    initLibFuncs();

  for (Function &Func : M) {
    if (shouldInstrument(Func)) {
      // handleInst only inserts in front of the visited instruction, so
      // walking the blocks in order sees the same instructions as before.
      for (BasicBlock &Block : Func)
        for (Instruction &Inst : Block)
          handleInst(&Inst);
    }
  }

  return true;
}

char MemSentryBenchDomain::ID = 0; // LLVM identifies the pass by the address of ID, so the value itself is irrelevant
// Constructing the static RegisterPass object registers the pass under the
// command-line name "memsentry-benchdomain".
static RegisterPass<MemSentryBenchDomain> X("memsentry-benchdomain",
"MemSentry benchmarking pass for domain-based methods");

BenchDomainPost.cpp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
/*
* Should be used in conjunction with -memsentry-benchdomain, and should run
* after -memsentry.
* See BenchDomain.cpp for more details.
*/

#define DEBUG_TYPE "memsentry-benchdomain-post"
#include "utils.h"

#include "types.h"

using namespace llvm;

// 继承FunctionPass
struct MemSentryBenchDomainPost : public FunctionPass {
public:
static char ID;
MemSentryBenchDomainPost() : FunctionPass(ID) {}
virtual bool runOnFunction(Function &F);
};

// Deletes every instruction in F that carries the
// "memsentry.benchdomain.dummy" metadata. Functions rejected by
// shouldInstrument() are left alone. Returns true iff something was erased.
bool MemSentryBenchDomainPost::runOnFunction(Function &F) {
  if (!shouldInstrument(F))
    return false;

  // Collect first, erase afterwards, so the instruction lists are never
  // modified while they are being walked.
  SmallVector<Instruction *, 16> Doomed;
  for (BasicBlock &Block : F)
    for (Instruction &Inst : Block)
      if (Inst.getMetadata("memsentry.benchdomain.dummy"))
        Doomed.push_back(&Inst);

  for (Instruction *Dead : Doomed)
    Dead->eraseFromParent();

  return !Doomed.empty();
}

char MemSentryBenchDomainPost::ID = 0;
// Constructing the static RegisterPass object registers the pass under the
// command-line name "memsentry-benchdomain-post".
static RegisterPass<MemSentryBenchDomainPost> X("memsentry-benchdomain-post", "MemSentry benchmarking pass for domain-based methods - cleanup pass");

MemSentry.cpp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
/*
* MemSentry: protecting safe regions on commodity hardware.
*
* This pass applies a specified protection to a program. The points that are
* allowed access should be specified by a previous pass (e.g., BenchDomain or a
* defense). An alternative to this is to place all code accessing the safe
* region in its own section, and specify the name of this section to this pass.
*
* For address-based approaches, all read, all writes or both read and writes
* can be instrumented, as specified by the compile-time flag.
* For certain domain-based approaches the safe-region is pre-allocated, and the
* size of this region should be specified at compile time.
*
* -memsentry-prot-method=[sfi*,mpx,vmfunc,mpk,crypt]
* -memsentry-whitelist-section=memsentry_functions
* -memsentry-rw=[r,w,rw*] # For sfi/mpx
* -memsentry-verify-external-call-args=true # For sfi/mpx
* -memsentry-max-region-size=4096 # For crypt
*/

#define DEBUG_TYPE "memsentry"
#include "utils.h"

#include "types.h"
#include "memsentry-pass.h"

using namespace llvm;

// Command-line option "-memsentry-prot-method": selects the protection
// technique applied to the safe region. Defaults to SFI.
cl::opt<prot_method> ProtMethod("memsentry-prot-method",
cl::desc("Method of protecting safe region:"),
cl::values(
clEnumValN(SFI, "sfi", "Software fault isolation (pointer masking)"),
clEnumValN(MPX, "mpx", "Intel MPX (memory protection extensions)"),
clEnumValN(VMFUNC, "vmfunc", "VM-Functions (requires vmfunc-enabled hypervisor like MemSentry's Dune)"),
clEnumValN(MPK, "mpk", "Intel MPK (memory protection keys). Upcoming, implemented as simulation"),
clEnumValN(CRYPT, "crypt", "Encryption (using Intel AES-NI)"),
clEnumValEnd), cl::init(SFI));

// Command-line option "-memsentry-rw": which access kinds (reads, writes or
// both) are instrumented. Defaults to READWRITE.
cl::opt<readwrite> ReadWrite("memsentry-rw",
cl::desc("What type of memory accesses to protect when using address-based approaches:"),
cl::values(
clEnumValN(READWRITE, "rw", "Reads and writes"),
clEnumValN(READ, "r", "Reads only"),
clEnumValN(WRITE, "w", "Writes only"),
clEnumValEnd), cl::init(READWRITE));

// Command-line option "-memsentry-whitelist-section": name of the section
// whose functions may access the safe region. Defaults to
// "memsentry_functions".
static cl::opt<std::string> WhitelistSection("memsentry-whitelist-section",
cl::desc("Functions in this section are allowed access to the safe region"),
cl::init("memsentry_functions"));

// Command-line option "-memsentry-verify-external-call-args": when true (the
// default), address-based methods also check pointer arguments passed to
// external functions.
static cl::opt<bool> VerifyExternalCallArguments(
"memsentry-verify-external-call-args",
cl::desc("For address-based methods, add checks to all pointer-type "
"arguments to external functions (make sure uninstrumented "
"libraries cannot use invalid pointers."),
cl::init(true));

// Command-line option "-memsentry-max-region-size": maximum safe-region size
// for methods that pre-allocate the region. Defaults to 4096; the value is
// written into the module global "_memsentry_max_region_size" by runOnModule.
static cl::opt<unsigned> MaxRegionSize("memsentry-max-region-size",
cl::desc("For methods that need to pre-allocate the entire safe-region,"
" the maximum size that should be supported."),
cl::init(4096));

/*
 * Public hook for other passes: tags instruction I as a permitted safe-region
 * access by attaching an empty metadata node under MemSentrySafeMDName.
 */
void memsentry_saferegion_access(Instruction *I) {
  LLVMContext &Ctx = I->getContext();
  MDNode *EmptyTag = MDNode::get(Ctx, {});
  I->setMetadata(MemSentrySafeMDName, EmptyTag);
}

// Looks up the function called `name` in module M. If the module has no such
// function, prints an error and terminates the process with exit code 1.
static Function *getHelperFunc(Module *M, std::string name) {
  if (Function *Helper = M->getFunction(name))
    return Helper;

  errs() << "Cannot find func '" << name << "'\n";
  exit(1);
}

// Sets the initializer of the module-level global variable `name` to the
// integer `value`, using the global's own element type. If the global does
// not exist, prints an error and terminates the process with exit code 1.
void setGV(Module &M, StringRef name, size_t value) {
  GlobalVariable *Global = M.getNamedGlobal(name);
  if (Global == nullptr) {
    errs() << "Error: no " << name << " global variable found\n";
    exit(1);
  }

  // A global's type is a pointer to its value type.
  Type *ValueTy = Global->getType()->getPointerElementType();
  Global->setInitializer(ConstantInt::get(ValueTy, value));
}

// Returns true when at least one declared parameter of F has pointer type.
static bool hasPointerArg(Function *F) {
  FunctionType *Signature = F->getFunctionType();
  for (FunctionType::param_iterator It = Signature->param_begin(),
                                    End = Signature->param_end();
       It != End; ++It) {
    if ((*It)->isPointerTy())
      return true;
  }
  return false;
}

// Returns true when CS is a direct call to a function placed in the section
// named by WhitelistSection. Calls without a known callee are never
// considered whitelisted.
bool callsIntoWhitelistedFunction(CallSite &CS) {
  if (Function *Callee = CS.getCalledFunction())
    return Callee->getSection() == WhitelistSection;

  return false; // Indirect call
}


/* Tells whether instruction I may touch the safe region, i.e. whether an
 * earlier pass attached MemSentrySafeMDName metadata to it.
 */
bool isAllowedAccess(Instruction *I) {
  return I->getMetadata(MemSentrySafeMDName) != NULL;
}

class Protection;

struct MemSentryPass : public ModulePass {
public:
static char ID;
MemSentryPass() : ModulePass(ID) {}
virtual bool runOnModule(Module &M);

private:
Protection *prot;

void handleInst(Instruction *I);
};

class Protection {
protected:
Module *M;
InlineFunctionInfo *inliningInfo;
enum prot_method protMethod;
std::string protMethodStr;
public:
Protection(Module *M, enum prot_method protMethod) {
this->M = M;
this->inliningInfo = new InlineFunctionInfo();
this->protMethod = protMethod;
this->protMethodStr = prot_method_strings[protMethod];
}
virtual ~Protection() { }

virtual void handleLoadInst(LoadInst *LI) {
handleMemInst(LI);
}
virtual void handleStoreInst(StoreInst *SI) {
handleMemInst(SI);
}
virtual void handleLoadIntrinsic(MemTransferInst *MTI) {
handleMemInst(MTI);
}
virtual void handleStoreIntrinsic(MemIntrinsic *MI) {
handleMemInst(MI);
}
virtual void handleMemInst(Instruction *I) {
assert(0 && "Not implemented");
}

virtual void handleCallInst(CallSite &CS) = 0;


/*
* Inline calls to _memsentry_<protMethod>*. This is done afterwards,
* instead of immediately, so the optimizeBB function can more easily
* see (and optimize) region changes.
*/
void inlineHelperCalls(Function &F) {
bool has_changed;
do {
has_changed = false;
for (inst_iterator it = inst_begin(F), E = inst_end(F); it != E; ++it) {
Instruction *I = &(*it);
CallInst *CI = dyn_cast<CallInst>(I);
// 不是call指令
if (!CI)
continue;
Function *F = CI->getCalledFunction();
if (!F)
continue;
// 是memsentry插入的指令
if (F->getName().startswith("_memsentry_" + protMethodStr)) {
InlineFunction(CI, *inliningInfo);
has_changed = true;
break;
}
}
} while (has_changed);
}

// 将应该处理的函数中的_memsentry_<protMethod>*全部inline
// 此方法在domainprotection中被重载
virtual void postInstrumentation() {
for (Function &F : *M) {
if (!shouldInstrument(F, &WhitelistSection))
continue;
inlineHelperCalls(F);
}
}
};

// addressprotection即通过把地址空间分开,访问时通过pass对指针加mask的方式保护sensitive region
// 在static文件夹中inline实现了两个方法,分别是mpx和sfi
class AddressProtection : public Protection {
protected:
Function *checkFunc; // 具体mpx/sfi函数,其参数为一个指针

Value* verifyPtr(Value *ptrVal, Instruction *I) {
// 不处理
if (isa<Constant>(ptrVal)) {
//LOG_LINE("+ Ignoring constant " << *I);
return ptrVal;
}
//LOG_LINE("Masking " << *I);
// 使用IRBuilder在I前插入语句块
IRBuilder<> B(I);
// 把ptrVal转换成checkFunc的第一个参数的类型
Value *funcArg = B.CreateBitCast(ptrVal, checkFunc->getFunctionType()->getParamType(0));
// 即语句masked = checkFunc(funcArg)
Value *masked = B.CreateCall(checkFunc, { funcArg });
// 把masked转换成ptrVal的值
Value *casted = B.CreateBitCast(masked, ptrVal->getType());

// 这里有个问题,不需要把ptrVal赋值为casted吗?
// 最后ptrVal的值肯定被修改了(否则这个函数就没用了)
// 但是怎么修改的呢?
return casted;
}
public:
// 复用protection构造函数
// checkFunc从module中找_memsentry_xxx,如_memsentry_mpx
AddressProtection(Module *M, enum prot_method protMethod)
: Protection(M, protMethod) {
std::string checkFuncName = "_memsentry_" + protMethodStr;
checkFunc = getHelperFunc(M, checkFuncName);
}

// 下面四个函数很简单,将访问地址的相应指针通过调用verifyPtr加上mask
// 加上mask的指针就能够访问sensitive region了
virtual void handleLoadInst(LoadInst *LI) {
if (!isAllowedAccess(LI))
LI->setOperand(0, verifyPtr(LI->getOperand(0), LI));
}

virtual void handleStoreInst(StoreInst *SI) {
if (!isAllowedAccess(SI))
SI->setOperand(1, verifyPtr(SI->getOperand(1), SI));
}

virtual void handleLoadIntrinsic(MemTransferInst *MTI) {
if (!isAllowedAccess(MTI))
MTI->setSource(verifyPtr(MTI->getRawSource(), MTI));
}

virtual void handleStoreIntrinsic(MemIntrinsic *MI) {
if (!isAllowedAccess(MI))
MI->setDest(verifyPtr(MI->getRawDest(), MI));
}

/* Verify pointer args to external functions if flag is set. */
// 处理call/invoke函数
void handleCallInst(CallSite &CS) {
Function *F = CS.getCalledFunction();
// 用户指定不处理
if (!VerifyExternalCallArguments)
return;
// 调用了白名单中的函数
if (callsIntoWhitelistedFunction(CS))
return;
// 不处理内联asm
if (CS.isInlineAsm())
return;
// 简洁调用
if (!F)
return; /* Indirect call */
// 不处理内部函数
if (!F->isDeclaration() && !F->isDeclarationForLinker())
return; /* Not external */

// 是intrinsic函数而且函数参数有指针,则需要判断
// 只处理memcpy,memmove,memset,vastart,vacopy,vaend
if (F->isIntrinsic() && hasPointerArg(F)) {
switch (F->getIntrinsicID()) {
case Intrinsic::dbg_declare:
case Intrinsic::dbg_value:
case Intrinsic::lifetime_start:
case Intrinsic::lifetime_end:
case Intrinsic::invariant_start:
case Intrinsic::invariant_end:
case Intrinsic::eh_typeid_for:
case Intrinsic::eh_return_i32:
case Intrinsic::eh_return_i64:
case Intrinsic::eh_sjlj_functioncontext:
case Intrinsic::eh_sjlj_setjmp:
case Intrinsic::eh_sjlj_longjmp:
return; /* No masking */
case Intrinsic::memcpy:
case Intrinsic::memmove:
case Intrinsic::memset:
case Intrinsic::vastart:
case Intrinsic::vacopy:
case Intrinsic::vaend:
break; /* Continue with masking */
default:
errs() << "Unhandled intrinsic that takes pointer: " << *F << "\n";
break; /* Do mask to be sure. */
}
}

// 对所有参数中的指针执行verifyPtr进行mask
Instruction *I = CS.getInstruction();
for (unsigned i = 0, n = CS.getNumArgOperands(); i < n; i++) {
Value *Arg = CS.getArgOperand(i);
if (Arg->getType()->isPointerTy()){
Value *MaskedArg = verifyPtr(Arg, I);
CS.getInstruction()->setOperand(i, MaskedArg);
}
}
}

};

// domainprotection在调用是设置相关区域有效(论文是这么说的)
// 通用的接口需要包括begin和end两个函数,一般是用于控制相应区域的active程度,类似于switch开关的实现
class DomainProtection : public Protection {
protected:
Function *beginFunc, *endFunc;
std::string beginFuncName, endFuncName;

// 在instruction前后插入beginFunc和endFunc
void changeDomain(Instruction *I) {
CallInst *CIb = CallInst::Create(beginFunc);
CIb->insertBefore(I);
CallInst *CIe = CallInst::Create(endFunc);
CIe->insertAfter(I);
}

/*
* Optimizes a basicblock by merging regions which have no mem accesses
* or so in between, thus eliminating needless switching of regions.
* Returns true if the function needs to be called again: when a
* modification to a BB is made, it cannot continue iterating over that
* BB.
*/
// 像上面说的,如果两组beginFunc/endFunc之间没有meminst,那么去掉中间的endFunc和beginFunc
bool optimizeBB(BasicBlock &BB) {
bool inMap = false; // 是否在beginFunc和endFunc之间
bool noMemSinceUnmap = false;
Instruction *lastUnmap = NULL;

for (Instruction &II : BB) {
Instruction *I = &II;
LoadInst *LI = dyn_cast<LoadInst>(I);
StoreInst *SI = dyn_cast<StoreInst>(I);
MemIntrinsic *MI = dyn_cast<MemIntrinsic>(I);
CallInst *CI = dyn_cast<CallInst>(I);
if (LI || SI || MI)
noMemSinceUnmap = false;
else if (CI) {
Function *F = CI->getCalledFunction();
if (!F)
continue;
if (F->getName() == beginFuncName) {
assert(!inMap);
inMap = true;
// 之前有一个endFunc,存在lastUnmap中,并且没有LI/SI/MI
// 则把这个beginFunc和上一个endFunc去掉
// 然后直接返回
if (noMemSinceUnmap) {
lastUnmap->eraseFromParent();
I->eraseFromParent();
return true;
}
}
else if (F->getName() == endFuncName) {
assert(inMap);
inMap = false;
noMemSinceUnmap = true;
lastUnmap = I;
}
else {
noMemSinceUnmap = false;
}
}
}
(void)inMap; /* Silence compiler, assert doesn't count as use. */
return false;
}
public:
//
DomainProtection(Module *M, enum prot_method protMethod)
: Protection(M, protMethod) {
beginFuncName = "_memsentry_" + protMethodStr + "_begin";
endFuncName = "_memsentry_" + protMethodStr + "_end";
beginFunc = getHelperFunc(M, beginFuncName);
endFunc = getHelperFunc(M, endFuncName);
}

// 对domain来说load/store/intrinsic都是调用meminst
// 对于每一条需要修改的instruction,调用changedomain
virtual void handleMemInst(Instruction *I) {
if (isAllowedAccess(I))
changeDomain(I);
}

void handleCallInst(CallSite &CS) {
if (callsIntoWhitelistedFunction(CS))
changeDomain(CS.getInstruction());
}

//
virtual void postInstrumentation() {
// Optimize domain-based instrumentation by removing unnecessary
// switches back and forth.
for (Function &F : *M) {
if (!shouldInstrument(F, &WhitelistSection))
continue;

for (BasicBlock &BB : F) {
unsigned cnt = 0;
// 优化直到没有需要合并的domain switch
while (optimizeBB(BB)) cnt++;
//LOG_LINE("Optimized away " << cnt << " domain switches in " << F.getName());
}
inlineHelperCalls(F);
}
}
};


// Factory: returns a newly allocated back-end for protMethod. SFI and MPX get
// an AddressProtection; VMFUNC, MPK and CRYPT get a DomainProtection. Any
// other value trips an assertion and yields NULL.
static Protection* getProtectionInstance(Module *M, enum prot_method protMethod) {
  const bool addressBased = protMethod == SFI || protMethod == MPX;
  if (addressBased)
    return new AddressProtection(M, protMethod);

  const bool domainBased =
      protMethod == VMFUNC || protMethod == MPK || protMethod == CRYPT;
  if (domainBased)
    return new DomainProtection(M, protMethod);

  assert(0 && "Not implemented!");
  return NULL;
}

/*
 * Dispatches one instruction to the active Protection back-end.
 *
 * Four kinds are handled: loads, stores, memory intrinsics and call/invoke.
 * The -memsentry-rw option filters which read and write hooks fire. A memory
 * transfer intrinsic counts as both a read (its source) and a write (its
 * destination). `ifcast` is a macro from utils.h; presumably it is a
 * dyn_cast-and-bind `if`.
 */
void MemSentryPass::handleInst(Instruction *I) {
  const bool protectReads = ReadWrite != WRITE;
  const bool protectWrites = ReadWrite != READ;

  ifcast(LoadInst, LI, I) {
    if (protectReads)
      prot->handleLoadInst(LI);
    return;
  }

  ifcast(StoreInst, SI, I) {
    if (protectWrites)
      prot->handleStoreInst(SI);
    return;
  }

  ifcast(MemIntrinsic, MI, I) {
    if (protectReads) {
      if (MemTransferInst *MTI = dyn_cast<MemTransferInst>(MI))
        prot->handleLoadIntrinsic(MTI);
    }
    if (protectWrites)
      prot->handleStoreIntrinsic(MI);
    return;
  }

  if (isa<CallInst>(I) || isa<InvokeInst>(I)) {
    CallSite CS(I);
    prot->handleCallInst(CS);
  }
}


// Pass entry point.
//
// Publishes the compile-time parameters to the static library through two
// module globals, creates the Protection back-end for the selected method,
// offers every instruction of every function accepted by shouldInstrument()
// to it, and finally runs the back-end's post-instrumentation step. The
// back-end is released before returning, so repeated runs do not leak it.
// Always reports the module as modified.
bool MemSentryPass::runOnModule(Module &M) {
  LOG_LINE("Starting, ProtMethod=" << prot_method_strings[ProtMethod]);

  // Fix up tracking variables so static lib knows compilation params
  setGV(M, "_memsentry_prot_method", ProtMethod);
  setGV(M, "_memsentry_max_region_size", MaxRegionSize);

  // Get right instrumentation class (address-based or domain-based)
  this->prot = getProtectionInstance(&M, ProtMethod);

  for (Function &F : M) {
    // Skip functions rejected by shouldInstrument().
    if (!shouldInstrument(F, &WhitelistSection))
      continue;
    LOG_LINE("Instrumenting " << F.getName());
    for (inst_iterator II = inst_begin(&F), E = inst_end(&F); II != E; ++II) {
      Instruction *I = &*II;
      handleInst(I);
    }
  }

  // Optimize inserted instrumentation further if need be.
  LOG_LINE("Optimizing...");
  this->prot->postInstrumentation();

  // The back-end is only needed for this run.
  delete this->prot;
  this->prot = nullptr;

  return true;
}

char MemSentryPass::ID = 0;
// Constructing the static RegisterPass object registers the pass under the
// command-line name "memsentry".
static RegisterPass<MemSentryPass> X("memsentry", "MemSentry pass");