1040 lines
36 KiB
Python
1040 lines
36 KiB
Python
# Simple attempt to follow obfuscated code. Mostly copied from SwitchOverride.java. Incomplete, please only use as reference.
|
|
#@author Satsuoni
|
|
#@category Deobfuscation
|
|
#@keybinding
|
|
#@menupath
|
|
#@toolbar
|
|
|
|
|
|
from binascii import hexlify
|
|
import logging
|
|
from ghidra.app.emulator import EmulatorHelper
|
|
from ghidra.util.task import ConsoleTaskMonitor
|
|
from ghidra.program.model.pcode import PcodeOp
|
|
from ghidra.program.model.symbol import *
|
|
from ghidra.program.model.pcode import JumpTable
|
|
from java.util import LinkedList, Arrays, ArrayList
|
|
from ghidra.app.cmd.function import CreateFunctionCmd
|
|
from ghidra.app.cmd.disassemble import DisassembleCommand
|
|
from ghidra.program.model.lang import Register
|
|
from ghidra.program.model.lang import OperandType
|
|
from ghidra.program.model.lang import RegisterManager
|
|
|
|
import sys
|
|
|
|
# JSON file (relative to the working directory) that ObfuscatedPath loads
# its collected data from and saves it back to.
DATA_FILE = "runData.json"

# The root logger is used directly; everything from DEBUG upwards is emitted.
logger = logging.getLogger("")
logger.setLevel(logging.DEBUG)

# First handler: mirror log output to standard output.
handler1 = logging.StreamHandler(sys.stdout)
|
|
# from StackOverflow
|
|
class WarnFormatter(logging.Formatter):
    """Formatter that chooses its format string from the record's level.

    DEBUG, INFO, WARNING and ERROR records use the class-level formats below;
    records of any other level use the format passed to the constructor.
    """

    err_fmt = "ERROR: %(msg)s"
    warn_fmt = "Warning: %(msg)s"
    dbg_fmt = "DBG: %(module)s: %(lineno)d: %(msg)s"
    info_fmt = "%(msg)s"

    def __init__(self, fmt="%(levelno)s: %(msg)s"):
        logging.Formatter.__init__(self, fmt)

    def _apply_fmt(self, fmt):
        """Install *fmt* as the active format string.

        Python 2 formats straight from ``self._fmt``; Python 3 formats
        through ``self._style``, so both places are updated when present.
        """
        self._fmt = fmt
        style = getattr(self, "_style", None)
        if style is not None:
            style._fmt = fmt

    def format(self, record):
        """Format *record* with the format string matching its level."""
        # Remember the format configured at construction time so that it can
        # be restored once this record has been rendered.
        format_orig = self._fmt

        if record.levelno == logging.DEBUG:
            self._apply_fmt(WarnFormatter.dbg_fmt)
        elif record.levelno == logging.INFO:
            self._apply_fmt(WarnFormatter.info_fmt)
        elif record.levelno == logging.ERROR:
            self._apply_fmt(WarnFormatter.err_fmt)
        elif record.levelno == logging.WARNING:
            self._apply_fmt(WarnFormatter.warn_fmt)

        try:
            # The base class does the actual rendering.
            return logging.Formatter.format(self, record)
        finally:
            # Restore the configured format even if rendering raised.
            self._apply_fmt(format_orig)
|
|
|
|
# A single level-aware formatter instance is shared by both handlers.
formatter = WarnFormatter('%(message)s')

# Second handler: write the same output to a log file, truncated on open.
handler2 = logging.FileHandler("testlog.log", mode='w', encoding="utf-8")

handler1.setFormatter(formatter)
handler2.setFormatter(formatter)

logger.addHandler(handler1)
logger.addHandler(handler2)
|
|
|
|
def getAddress(offset):
    """Return the address at *offset* in the current program's default address space."""
    default_space = currentProgram.getAddressFactory().getDefaultAddressSpace()
    return default_space.getAddress(offset)
|
|
|
|
def getProgramRegisterList(currentProgram):
    """Return the ``registers`` attribute of the program context of *currentProgram*."""
    return currentProgram.getProgramContext().registers
|
|
|
|
# getState() is not defined in this file; presumably it is supplied by the
# Ghidra scripting environment -- TODO confirm.
state = getState()
currentProgram = state.getCurrentProgram()
listing = currentProgram.getListing()
name = currentProgram.getName()

logger.info("Starting working on {}".format(name))
|
|
def getPossibleConstAddressFromInstruction(instr):
    """Return the address loaded by the first COPY of an 8-byte constant, or None.

    Scans the p-code of *instr* in order. For each COPY op the first input is
    inspected; when it is 8 bytes wide and constant, its offset is converted
    with getAddress() and returned.
    """
    for op in instr.getPcode():
        if op.getOpcode() != PcodeOp.COPY:
            continue
        source = op.getInputs()[0]
        if source.size == 8 and source.isConstant():
            return getAddress(source.getOffset())
    return None
|
|
|
|
def isReturn(instr):
    """Tell whether any p-code op of *instr* is a RETURN (False for None)."""
    if instr is None:
        return False
    return any(op.getOpcode() == PcodeOp.RETURN for op in instr.getPcode())
|
|
|
|
def isComputedBranchInstruction(instr):
    """Tell whether *instr* is a computed jump or a call to a call-fixup function.

    Returns True when the flow type equals RefType.COMPUTED_JUMP, or when the
    instruction is a call and one of its call references points at a function
    whose getCallFixup() is not None. Returns False otherwise, including for
    None.
    """
    if instr is None:
        return False

    flow = instr.getFlowType()
    if flow == RefType.COMPUTED_JUMP:
        return True
    if not flow.isCall():
        return False

    # A call only counts when its target function carries a call fixup.
    for ref in instr.getReferencesFrom():
        if not ref.getReferenceType().isCall():
            continue
        callee = currentProgram.getFunctionManager().getFunctionAt(ref.getToAddress())
        if callee is not None and callee.getCallFixup() is not None:
            return True
    return False
|
|
|
|
def isCallInstruction(instr):
    """Tell whether the flow type of *instr* reports a call (False for None)."""
    if instr is None:
        return False
    return True if instr.getFlowType().isCall() else False
|
|
|
|
def isCallOther(instr):
    """Tell whether any p-code op of *instr* is a CALLOTHER (False for None)."""
    if instr is None:
        return False
    return any(op.getOpcode() == PcodeOp.CALLOTHER for op in instr.getPcode())
|
|
|
|
def getCondmoveInstruction(instr):
    """Describe *instr* if one of its CBRANCH ops targets the next instruction.

    The p-code of *instr* is scanned for CBRANCH ops whose first input is
    8 bytes wide. For the first such op, the input's offset is turned into an
    address and compared with the address of the following instruction:

    * no following instruction -> None
    * addresses match -> ``(address, operand 0 text, operand 1 text)``
    * otherwise scanning continues with the next op

    Returns None for None input or when nothing matches. Judging by the
    name, this pattern is how conditional moves are recognised.
    """
    if instr is None:
        return None

    for op in instr.getPcode():
        if op.getOpcode() != PcodeOp.CBRANCH:
            continue
        destination = op.getInput(0)
        if destination is None or destination.size != 8:
            continue

        candidate = getAddress(destination.getOffset())
        follower = instr.getNext()
        if follower is None:
            return None
        if follower.getAddress() == candidate:
            return (
                instr.getAddress(),
                instr.getDefaultOperandRepresentation(0),
                instr.getDefaultOperandRepresentation(1),
            )
    return None
|
|
|
|
def getPotentialJtableAccess(instr):
    """Recognise a ``movsx*``, ``add``, computed-branch instruction triple.

    Matching is done on mnemonics. *instr* must contain "movsx" in its
    mnemonic, the next instruction must contain "add", and the one after that
    must satisfy isComputedBranchInstruction(). Operand 1 of *instr* must
    consist of exactly three objects, the third of which parses to 4.

    Returns ``(jtableReg, offsetReg, jumpAddr)`` -- the first two operand
    objects and the address of the branch instruction -- or None when the
    pattern does not match. None is also returned for None input (the
    original returned False there, which callers testing ``is not None``
    would then try to index).
    """
    if instr is None:
        return None

    if "movsx" not in instr.getMnemonicString().lower():
        return None

    add_instr = instr.getNext()
    if add_instr is None:
        return None
    if "add" not in add_instr.getMnemonicString().lower():
        return None

    jump_instr = add_instr.getNext()
    if jump_instr is None:
        return None
    if not isComputedBranchInstruction(jump_instr):
        return None

    objects = instr.getOpObjects(1)
    # Expect base, index and a scale of 4 (32-bit table entries).
    if len(objects) != 3 or int(str(objects[2]), 0) != 4:
        return None
    return (objects[0], objects[1], jump_instr.getAddress())
|
|
|
|
def skip(emu):
    """Set the emulator's PC register to the instruction after the current one.

    The current instruction is looked up at the emulator's execution address
    and is not executed.
    """
    current = getInstructionAt(emu.getExecutionAddress())
    following = current.getNext()
    # The address is rendered as text and parsed back as a hexadecimal number.
    target = int("0x{}".format(following.getAddress()), 16)
    emu.writeRegister(emu.getPCRegister(), target)
|
|
|
|
#there can be several cmove in one code step...
|
|
|
|
|
|
class Branch(object):
    """A two-way decision at one address, with bookkeeping for both outcomes.

    ``trueIndex`` / ``falseIndex`` hold the recorded output of each outcome as
    a ``(type, value)`` tuple; ``isTrueTaken`` / ``isFalseTaken`` tell whether
    an output has been recorded for that outcome.
    """

    def __init__(self, addr, to, frm):
        self.address = addr
        # *to* is the register that gets written and also the value source
        # of the "false" outcome; *frm* is the source of the "true" outcome.
        self.target = to
        self.false = to
        self.true = frm
        self.isTrueTaken = False
        self.isFalseTaken = False
        self.trueIndex = None
        self.falseIndex = None
        # Outcome chosen by the most recent take(); None until then.
        self.lastTaken = None

    def take(self, pth, emu):
        """Force outcome *pth* in *emu* and step over the instruction.

        Only a warning is logged when the emulator is not at this branch's
        address.
        """
        if emu.getExecutionAddress() != self.address:
            logger.warning("Trying to take branch {} at address {}".format(self.address,emu.getExecutionAddress()))
            return
        source = self.true if pth else self.false
        emu.writeRegister(self.target, emu.readRegister(source))
        self.lastTaken = pth
        skip(emu)

    def registerOutput(self, output): #output is a tuple (type, value)
        """Record *output* for the outcome chosen by the last take().

        An already recorded output is never replaced; a warning is logged
        when *output* differs from it.
        """
        if self.lastTaken is None:
            logger.warning("Trying to write output for branch {} that was not triggered".format(self.address))
            return

        if self.lastTaken:
            if not self.isTrueTaken:
                self.isTrueTaken = True
                self.trueIndex = output
            elif self.trueIndex[0] != output[0] or self.trueIndex[1] != output[1]:
                logger.warning("Trying to overwrite true output for branch {} ".format(self.address))
            return

        if not self.isFalseTaken:
            self.isFalseTaken = True
            self.falseIndex = output
        elif self.falseIndex[0] != output[0] or self.falseIndex[1] != output[1]:
            logger.warning("Trying to overwrite false output for branch {} ".format(self.address))

    def hasUntakenPaths(self):
        """Return True while at least one outcome has no recorded output."""
        return not (self.isFalseTaken and self.isTrueTaken)

    def getNextUntaken(self):
        """Return the next outcome without output: False first, then True, else None."""
        if not self.isFalseTaken:
            return False
        if not self.isTrueTaken:
            return True
        return None

    def getExpectedOutput(self):
        """Return the recorded output of the last taken outcome (None if never taken)."""
        if self.lastTaken is None:
            return None
        return self.trueIndex if self.lastTaken else self.falseIndex
|
|
class BranchBox(object):
    """The tree of Branch objects met while handling one jump index.

    ``branches[0]`` is the root. The recorded output of a branch is either
    ``("branch", i)``, referring to ``branches[i]``, or ``("index", value)``,
    the index reached after that outcome.
    """

    def __init__(self, index):
        self.branches = []
        self.curbranch = None
        # Remaining outcomes to replay, or None when not replaying.
        self.trace = None
        self.index = index  # number in obfuscated jump
        self.pathTaken = []
        self.pathToBox = []  # a path to retrace if you want to get to this box
        logger.info("Registering new branch box at 0x{}".format(index))

    def hasUntakenPaths(self):
        """Return True if any branch still has an outcome without output."""
        return any(br.hasUntakenPaths() for br in self.branches)

    def registerBranch(self, condmove):  # tuple (address, operand 0, operand 1)
        """Move to, or create, the branch described by *condmove*.

        Returns True when the current branch now corresponds to *condmove*
        and False when the known tree cannot be followed or the box already
        holds more than 1024 branches.
        """
        if self.curbranch is None:
            if len(self.branches) == 0:
                self.branches.append(Branch(condmove[0], condmove[1], condmove[2]))
                self.curbranch = self.branches[0]
                return True
            self.curbranch = self.branches[0]

        if self.curbranch.lastTaken is None:
            # The current branch has not been taken yet; it has to be the
            # branch we are standing on.
            if self.curbranch.address == condmove[0]:
                return True
            logger.warning("Trying to add branch for branch {} that was not triggered".format(self.curbranch.address))
            return False

        eo = self.curbranch.getExpectedOutput()
        if eo is not None:
            # An output is already recorded: follow it or fail.
            if eo[0] == "branch" and self.branches[eo[1]].address == condmove[0]:
                self.curbranch = self.branches[eo[1]]
                return True
            logger.warning("Failed following trace")
            return False

        # New branch: link it as the output of the outcome just taken.
        logger.info("new branch {} {} {}".format(self.index, self.curbranch.lastTaken,condmove[0]))
        br = Branch(condmove[0], condmove[1], condmove[2])
        self.curbranch.registerOutput(("branch", len(self.branches)))
        self.branches.append(br)
        self.curbranch = br
        if len(self.branches) > 1024:
            logger.warning("Too many branches in one box")
            return False
        return True

    def registerIndexOutput(self, index):
        """Record *index* as the output of the current branch's last outcome.

        Returns False (after logging a warning) when there is no current
        branch or it has not been taken.
        """
        if self.curbranch is None:
            logger.warning("Trying to put output to empty branchbox")
            return False
        if self.curbranch.lastTaken is None:
            logger.warning("Trying to add branch for branch {} that was not triggered".format(self.curbranch.address))
            return False
        self.curbranch.registerOutput(("index", index))
        return True

    def reset(self):
        """Forget per-run state and return to the root branch (if any)."""
        self.trace = None
        self.pathTaken = []
        self.curbranch = None
        for br in self.branches:
            br.lastTaken = None
        if len(self.branches) > 0:
            self.curbranch = self.branches[0]

    def loadTrace(self, trace):
        """Set the sequence of outcomes to replay."""
        self.trace = trace

    def isTracing(self):
        """Return True while a trace is loaded."""
        return self.trace is not None

    def searchUntakenInTree(self, inindex):
        """Return True if ``branches[inindex]`` or a descendant has an outcome without output."""
        dbranch = self.branches[inindex]
        if dbranch.hasUntakenPaths():
            return True
        # Both outcomes have outputs here, so both index tuples are set.
        if dbranch.trueIndex[0] == "branch":
            if self.searchUntakenInTree(dbranch.trueIndex[1]):
                return True
        if dbranch.falseIndex[0] == "branch":
            if self.searchUntakenInTree(dbranch.falseIndex[1]):
                return True
        return False

    def findNextUntakenStep(self):
        """Return the outcome to take at the current branch to reach unexplored ground.

        Returns None when neither the current branch nor its subtrees have an
        outcome without output.
        """
        pth = self.curbranch.getNextUntaken()
        if pth is not None:
            return pth
        # Recursive search below each outcome of the current branch.
        if self.curbranch.trueIndex[0] == "branch":
            if self.searchUntakenInTree(self.curbranch.trueIndex[1]):
                return True
        if self.curbranch.falseIndex[0] == "branch":
            if self.searchUntakenInTree(self.curbranch.falseIndex[1]):
                return False
        return None

    def takeUntakenOrTrace(self, emu):
        """Take one outcome at the current branch.

        While tracing, the next outcome of the trace is replayed; otherwise
        the outcome chosen by findNextUntakenStep() is taken. Problems are
        reported through log warnings only.
        """
        if len(self.branches) == 0:
            logger.warning("Branch box not initialized")
            return
        if self.curbranch is None:
            self.curbranch = self.branches[0]
        if self.curbranch.address != emu.getExecutionAddress():
            logger.warning("Branch box broken?")
            return

        if self.isTracing():
            if len(self.trace) == 0:
                logger.warning("Trace bound violation")
                return
            logger.info("Taking trace {}".format(self.trace[0]))
            self.curbranch.take(self.trace[0], emu)
            self.pathTaken.append(self.trace[0])
            out = self.curbranch.getExpectedOutput()
            self.trace = self.trace[1:]
            if out is None:
                # Was logger.waring(...), which raised AttributeError.
                logger.warning("Trace failed...")
            # Moving on to the next branch is left to registerBranch().
            return

        pth = self.findNextUntakenStep()
        if pth is None:
            logger.warning("No untaken paths left for {}".format(self.index))
            return
        self.pathTaken.append(pth)
        logger.info("Taking {}".format(pth))
        self.curbranch.take(pth, emu)
|
|
|
|
|
|
class CPath(object):
    """An ordered list of distinct indices plus a description of how the path ended."""

    def __init__(self):
        self.path = []
        # path can end in branch, return, baddata, basic loop or infinite loop (iloop)
        self.ending = None
        self.endJump = None
        self.multiend = False

    def contains(self, index):
        """Return True if *index* is already part of the path."""
        return index in self.path

    def addIndex(self, index):
        """Append *index* unless the path already contains it."""
        if self.contains(index):
            return
        self.path.append(index)
|
|
def estimateTableLen(table_addr, first_addr):
    """Roughly count the entries of a jump table of negative 32-bit offsets.

    Starting at *table_addr*, 4-byte values are read with getInt(). Counting
    stops at the first value that is not negative, or whose target
    (``table_addr`` plus the value) lies before *first_addr* or after
    *table_addr*. The "switch guard" on top is usually lower.
    """
    count = 0
    entry = table_addr
    while True:
        relative = getInt(entry)
        if not relative < 0:
            break
        target = table_addr.add(relative)
        if target < first_addr or target > table_addr:
            break
        count += 1
        entry = entry.add(4)
    return count
|
|
def tryGetSwitchGuard(instr):  # also mnemonic based, so fragile
    """Return the constant of the ``cmp`` two instructions before *instr*, or -1.

    The instruction two steps back must have the mnemonic "cmp" and its
    operand 1 must consist of exactly one object that parses as an integer.
    -1 is returned in every other case, including when *instr* is None or
    fewer than two preceding instructions exist.
    """
    if instr is None:
        return -1
    prev = instr.getPrevious()
    if prev is None:
        return -1
    prev = prev.getPrevious()
    if prev is None:
        # The original dereferenced this without checking.
        return -1

    if prev.getMnemonicString().lower() != "cmp":
        return -1

    objects = prev.getOpObjects(1)
    if len(objects) != 1:
        return -1
    try:
        return int(str(objects[0]), 0)
    except (ValueError, TypeError):
        # Operand is not a plain number (e.g. a register).
        return -1
|
|
#bVar2 = (**(code **)(*plVar5 + 0x20))(plVar5);
|
|
def hexornot(hon):
    """Return *hon* as upper-case hexadecimal text, or ``str(hon)`` if it cannot be formatted that way."""
    try:
        return "{:X}".format(hon)
    except Exception:
        # Not a bare ``except``: KeyboardInterrupt and SystemExit must still
        # propagate.
        return str(hon)
|
|
class Operand(object):
    """Simplified description of one instruction operand.

    ``type`` is one of "unknown", "const", "register", "stackvar",
    "memconst" or "debugho". ``value`` depends on the type: an int for
    constants, a register name for registers, a hexadecimal offset string
    for stack variables. Equality and hashing use ``(type, value)`` only.
    """

    def getVal(self, emu):
        """Read the operand's current value from *emu*.

        Returns None for types other than const, register, stackvar and
        memconst.
        """
        if self.type == "const":
            return self.value
        elif self.type == "register":
            return emu.readRegister(self.value)
        elif self.type == "stackvar":
            return emu.readStackValue(int(self.value, 0), self.length, False)
        elif self.type == "memconst":
            return 0  # need implementation?

    def __hash__(self):
        return hash((self.type, self.value))

    def __eq__(self, other):
        if not isinstance(other, Operand):
            return NotImplemented
        return (self.type, self.value) == (other.type, other.value)

    def __ne__(self, other):
        return not (self == other)

    @staticmethod
    def _formatOffset(offset):
        """Render *offset* as ``0x..`` / ``-0x..`` so that int(text, 0) and int(text, 16) can parse it back."""
        if offset < 0:
            return "-0x{:x}".format(-offset)
        return "0x{:x}".format(offset)

    def __init__(self, instr, num):
        """Classify operand *num* of *instr*.

        With ``instr`` None, an out-of-range *num* or an operand without
        objects the instance stays "unknown".
        """
        self.type = "unknown"
        # Defaults so that hashing, comparison and getVal() work for every
        # instance, including "unknown" ones and ones filled in by hand.
        self.value = None
        self.length = 1

        if instr is None:
            return
        if num >= instr.getNumOperands():
            return
        tp = instr.getOperandType(num)
        objlist = instr.getOpObjects(num)
        if len(objlist) == 0:
            return

        if tp & OperandType.SCALAR:  # const
            self.type = "const"
            self.value = int(str(objlist[0]), 16)
            return

        if tp & OperandType.REGISTER and len(objlist) == 1:  # pure register
            self.type = "register"
            self.value = str(objlist[0])
            return

        if len(objlist) == 2 and str(objlist[0]) == "RSP":
            self.type = "stackvar"
            self.offset = int(str(objlist[1]), 0)
            self.value = self._formatOffset(self.offset)
            self.len = 1  # legacy attribute name, kept for compatibility
            # The access width is taken from the operand text; when none of
            # the keywords is present the default of 1 stays in place.
            oprep = instr.getDefaultOperandRepresentation(num)
            if "xword" in oprep:
                self.length = 16
            elif "qword" in oprep:
                self.length = 8
            elif "dword" in oprep:
                self.length = 4
            elif "word ptr" in oprep:
                self.length = 2
            elif "byte ptr" in oprep:
                self.length = 1
            return

        if len(objlist) == 2 and tp & OperandType.DYNAMIC and tp & OperandType.ADDRESS:  # dword ptr [RCX + -0x4]
            self.type = "memconst"
            self.value = str(objlist[1])
            self.olist = objlist
            return

        if len(objlist) == 2 and str(objlist[0]) == "GS":
            self.type = "debugho"
            self.value = "__0x{:x}".format(int(str(objlist[1]), 0))
            return

        if len(objlist) == 3 and isinstance(objlist[0], Register):  # not quite correct, but. uff
            self.type = "register"
            self.value = str(objlist[0])
            return

        logger.warning("{} {:x} {} needs operand implementation {}".format(instr,tp,num,objlist))
|
|
|
|
|
|
class BasicDataGraph(object):
    """Coarse data-dependency tracker for a straight run of instructions.

    ``registers`` and ``variables`` map a register name / stack-variable key
    to the set of input names it currently depends on. ``inputs`` collects
    registers and stack variables that were read before being written.
    """

    def __init__(self):
        self.registers = {}
        self.variables = {}
        self.inputs = set([])

    def maybeInput(self, inp):
        """Record *inp* as an input if it is a register/stackvar not written so far."""
        if inp.type != "register" and inp.type != "stackvar":
            return
        if inp.value in self.registers or inp.value in self.variables:
            return
        self.inputs.add(inp.value)

    def getDeps(self, inp):
        """Return a fresh set with the inputs that *inp* currently depends on."""
        value = getattr(inp, "value", None)
        if inp.type == "register" and value in self.registers:
            return set(self.registers[value])
        if inp.type == "stackvar" and value in self.variables:
            return set(self.variables[value])
        if value in self.inputs:
            return set([value])
        return set([])

    def assign(self, outp, inp):
        """Replace the dependencies of *outp* with those of *inp*.

        Only a warning is logged when *outp* is neither a register nor a
        stack variable.
        """
        self.maybeInput(inp)
        ival = self.getDeps(inp)
        if outp.type == "register":
            self.registers[outp.value] = ival
        elif outp.type == "stackvar":
            self.variables[outp.value] = ival
        else:
            # Was ``log.warning`` -- ``log`` is not defined in this file.
            logger.warning("Weird assign to {}".format(outp.type))

    def addDependency(self, recv, incoming):
        """Add the dependencies of every operand in *incoming* to those of *recv*.

        Receivers other than registers and stack variables are ignored.
        """
        if recv.type == "register":
            receiver = self.registers.setdefault(recv.value, set())
        elif recv.type == "stackvar":
            receiver = self.variables.setdefault(recv.value, set())
        else:
            return  # we don't care?

        for inp in incoming:  # add direct, leave propagation for later
            self.maybeInput(inp)
            ival = self.getDeps(inp)
            # consts etc. should not matter
            if inp.type == "register" or inp.type == "stackvar":
                receiver.update(ival)

    def scrambleRegisters(self):  # we lose dependency over call
        """Clear the dependency sets of all tracked registers."""
        for reg in self.registers:
            self.registers[reg] = set([])

    def add(self, instr):
        """Update the graph for one instruction, dispatching on its mnemonic.

        ``call`` clears register dependencies, ``mov``/``movzx`` assign, and
        ``imul``/``sub``/``add``/``and``/``xor`` add dependencies. Other
        mnemonics and None are ignored.
        """
        if instr is None:
            return
        logger.info(instr)
        mn = instr.getMnemonicString().lower()
        if mn == "call":
            self.scrambleRegisters()
        elif mn == "mov" or "movzx" in mn:  # assign
            destination = Operand(instr, 0)
            source = Operand(instr, 1)
            self.assign(destination, source)
        elif mn in ("imul", "sub", "add", "and", "xor"):
            # The operand count has to come from *instr* itself; the original
            # read it from the unrelated module-level ``c_instr``.
            operand_count = instr.getNumOperands()
            destination = Operand(instr, 0)
            sources = [Operand(instr, 1)]
            if operand_count == 3:
                sources.append(Operand(instr, 2))
            self.addDependency(destination, sources)
        # anything else: maybe does not matter

    def getByMnem(self, mnem):
        """Return the set of inputs that the register or variable *mnem* depends on.

        An exact entry in ``registers`` wins. Otherwise, for a program
        register, the entry of its base register is used, or the union over
        the base register's child registers. Otherwise the entry in
        ``variables``, or an empty set.
        """
        root = set([])
        if mnem in self.variables:
            root = set(self.variables[mnem])

        rgs = currentProgram.getRegister(str(mnem))
        if rgs is not None:
            rgs = rgs.getBaseRegister()
            if str(rgs) in self.registers:
                root = self.registers[str(rgs)]
            else:
                root = set([])
                for r in rgs.getChildRegisters():
                    if str(r) in self.registers:
                        root.update(self.registers[str(r)])

        if mnem in self.registers:
            return set(self.registers[mnem])
        return root

    def linksTo(self, mnem):
        """Log the graph state and return getByMnem(*mnem*)."""
        logger.info("mnem")
        logger.info(mnem)
        root = self.getByMnem(mnem)
        logger.info("added")
        logger.info(self.variables)
        logger.info(self.registers)
        logger.info(self.inputs)
        return root
|
|
|
|
class mchain(object):
    """Builds a text key from a primary register and a set of secondary references."""

    def __init__(self, primary):
        self.mainRef = primary
        self.secref = set([])

    def addSecondary(self, mnem):
        """Add *mnem* (a register name or a stack offset string) to the secondary references."""
        self.secref.add(mnem)

    def string(self, emu):
        """Return ``0x<primary>_`` followed by one ``_<offset>_<value>|`` piece per secondary reference.

        A secondary that names a program register is read as a register and
        reported with offset 0. Anything else is parsed as a stack offset and
        read as a 4-byte stack value. Pieces appear in the iteration order of
        the ``secref`` set.
        """
        index = emu.readRegister(self.mainRef)
        pieces = []
        for scr in self.secref:
            reg = currentProgram.getRegister(str(scr))
            if reg is not None:
                ro = 0
                ref = emu.readRegister(reg)
            else:
                ro = int(scr, 0)
                ref = emu.readStackValue(ro, 4, False)
            pieces.append("_{:x}_{:x}|".format(ro, ref))
        return "0x{:x}_{}".format(index, "".join(pieces))
|
|
#readStackValue
|
|
import os
|
|
import json
|
|
#needs checksum detection! checksum has 3 stackvars: checksum, current addr and final addr
|
|
class ObfuscatedPath(object):
|
|
def __init__(self):
|
|
self.switchAddr=None # jmp RAX instruction/ switch that controls flow. should be first cbranch, usually
|
|
self.jtableBreakpoint=None # movsxd instruction
|
|
self.jtableRef=None
|
|
self.jindexRef=None #register that contains jump index at jtableBreakpoint
|
|
self.estimatedTableLen=-1
|
|
self.paths=[]
|
|
self.cpi=-1 #current path index
|
|
self.branchboxes={}
|
|
self.lastindex=-1
|
|
self.loadedBox=None
|
|
self.trace=None
|
|
self.curFullPath=[]
|
|
self.startpoint=state.getCurrentLocation()
|
|
self.cur_fun=getFunctionContaining(self.startpoint.address)
|
|
self.emu=None
|
|
self.monitor=ConsoleTaskMonitor()
|
|
self.loopAvoider=1000
|
|
self.loopAvoid=self.loopAvoider
|
|
self.mchl=1 #ok, memory chain does not really work XD Needs to check for hidden variables/ pathbox...
|
|
self.mch=None
|
|
self.instructionBlock=[]
|
|
self.maxflushes=10
|
|
self.curflushes=0
|
|
self.flushed=False
|
|
self.saveregs= ["RIP", "RAX","RBX", "RCX", "RDX", "RSI", "RDI", "RSP", "RBP", "R15","rflags","R8","R9"]
|
|
self.stackvars=set()
|
|
self.collectedData={}
|
|
self.entry="0x{}".format(self.cur_fun.getEntryPoint())
|
|
if os.path.isfile(DATA_FILE):
|
|
with open(DATA_FILE,"rb") as fl:
|
|
dat=fl.read()
|
|
if len(dat)>0:
|
|
try:
|
|
self.collectedData=json.loads(dat)
|
|
except Exception as e:
|
|
logger.warning("Error reading data file: {}".format(e))
|
|
self.collectedData={}
|
|
if not self.entry in self.collectedData :
|
|
self.collectedData[self.entry]={"step":0}
|
|
|
|
def initEmulator(self):
|
|
if self.emu is not None:
|
|
self.emu.dispose()
|
|
self.emu=EmulatorHelper(currentProgram)
|
|
mainFunctionEntryLong = int("0x{}".format(self.cur_fun.getEntryPoint()), 16)
|
|
self.emu.writeRegister(self.emu.getPCRegister(), mainFunctionEntryLong)
|
|
self.emu.writeRegister("RSP", 0x000000002FFF0000)
|
|
self.emu.writeRegister("RBP", 0x000000002FFF0000)
|
|
def indexInPaths(self, index):
|
|
for pth in self.paths:
|
|
if pth.contains(index): return True
|
|
return False
|
|
def loadBox(self):
|
|
self.loadedBox=None
|
|
si=self.mch.string(self.emu)
|
|
if si in self.branchboxes:
|
|
logger.info("Loaded box {}".format(si))
|
|
self.loadedBox=self.branchboxes[si]
|
|
self.loadedBox.reset()
|
|
def findUntakenBranch(self):
|
|
for bb in self.branchboxes:
|
|
if self.branchboxes[bb].hasUntakenPaths():
|
|
logger.info("Branch box at {} has untaken path, restarting".format(bb))
|
|
return self.branchboxes[bb]
|
|
return None
|
|
def restart(self):
|
|
self.initEmulator()
|
|
self.loadedBox=None
|
|
self.cpi=-1
|
|
self.curFullPath=[]
|
|
self.instructionBlock=[]
|
|
self.flushed=True
|
|
def run(self):
|
|
for a in range(170000):
|
|
if not self.process():
|
|
bb=self.findUntakenBranch()
|
|
if bb is None:
|
|
logger.info("No untaken paths left, stopping")
|
|
return
|
|
self.trace=bb.pathToBox
|
|
logger.info(bb.pathToBox)
|
|
self.restart()
|
|
def initNewPath(self,index):
|
|
ind=len(self.paths)
|
|
npath=CPath()
|
|
npath.addIndex(index)
|
|
self.paths.append(npath)
|
|
self.cpi=ind
|
|
def endPath(self,cause,jmp=None): #iloop, loop, branch, ret, baddata
|
|
if self.cpi==-1:
|
|
logger.warning("Trying to end path as {} when not on path".format(cause))
|
|
return
|
|
logger.info("Ending path {} ({}) as {}".format(self.cpi,self.paths[self.cpi].path,cause))
|
|
pth=self.paths[self.cpi]
|
|
pth.ending=cause
|
|
pth.endJump=jmp
|
|
self.cpi=-1
|
|
def runStraight(self,nl):
|
|
for a in range(nl):
|
|
executionAddress = self.emu.getExecutionAddress()
|
|
logger.info("Address: 0x{} ({})".format(executionAddress, getInstructionAt(executionAddress)))
|
|
ein=getInstructionAt(executionAddress)
|
|
if "{}".format(ein)=="CMP EAX,0x347":
|
|
reg_value = self.emu.readRegister("EAX")
|
|
logger.info("EAX: 0x{:x}".format(reg_value))
|
|
if not self.step(): return
|
|
def flush(self):
|
|
if self.curflushes>=self.maxflushes: return
|
|
self.trace=None
|
|
self.branchboxes={}
|
|
self.paths=[]
|
|
self.restart()
|
|
logger.info("Flushed branches, rebuilding, current secondary: {}".format(self.mch.secref))
|
|
self.curflushes+=1
|
|
def analyzeInstrBlock(self):
|
|
if len(self.instructionBlock)<10: return False
|
|
if self.curflushes>=self.maxflushes: raise NotImplementedError
|
|
if self.flushed:
|
|
self.flushed=False
|
|
return False
|
|
bdg=BasicDataGraph()
|
|
for instr in self.instructionBlock:
|
|
bdg.add(instr)
|
|
affects=bdg.linksTo(self.jindexRef)
|
|
logger.info(affects)
|
|
prelen=len(self.mch.secref)
|
|
for a in affects:
|
|
if currentProgram.getRegister(str(a)) is not None: continue # avoid registers for now
|
|
self.mch.addSecondary(a)
|
|
if len(self.mch.secref)!=prelen:
|
|
return True
|
|
#self.curflushes+=1
|
|
return False
|
|
def save(self):
|
|
with open(DATA_FILE,"wb") as fl:
|
|
fl.write(json.dumps(self.collectedData).encode("utf-8"))
|
|
def detectLongLoops(self):
|
|
marks={}
|
|
prev=None
|
|
maxLoop=10
|
|
window=[]
|
|
cloop=None
|
|
coffs=0
|
|
clc=0
|
|
cnt=0
|
|
loopBreaks=[]
|
|
loops=[]
|
|
for ind in self.collectedData[self.entry]["basicChain"]:
|
|
if cloop is not None:
|
|
if ind != cloop[coffs]:
|
|
#loop break
|
|
logger.warning("Lbreak {} {} ".format(ind,cloop))
|
|
loopBreaks.append(cnt-1)
|
|
if clc>=5:
|
|
loops.append((clc,cloop))
|
|
cloop=None
|
|
clc=0
|
|
coffs=0
|
|
else:
|
|
coffs+=1
|
|
if coffs>=len(cloop):
|
|
clc+=1
|
|
coffs=0
|
|
else:
|
|
if ind in window:
|
|
cloop=window[window.index(ind):]
|
|
while ind in cloop[1:]:
|
|
l=cloop[1:]
|
|
cloop=l[l.index(ind):]
|
|
clc=0
|
|
coffs=1
|
|
window.append(ind)
|
|
if len(window)>maxLoop:
|
|
window=window[1:]
|
|
cnt+=1
|
|
logger.info("Detected loops: {}".format(loops))
|
|
logger.info("Detected loop breaks: {}".format(loopBreaks))
|
|
self.collectedData[self.entry]["basicLoops"]=loops
|
|
self.collectedData[self.entry]["basicLoopBreaks"]=loopBreaks
|
|
def getState(self):
|
|
state={}
|
|
for reg in self.saveregs:
|
|
val=self.emu.readRegister(reg)
|
|
state[reg]=val
|
|
for v in self.stackvars:
|
|
val=self.emu.readStackValue(int(v.value,16),v.length,False)
|
|
state["{}_{}".format(v.value,v.length)]=val
|
|
return state
|
|
def writeState(self,state):
|
|
for reg in self.saveregs:
|
|
if reg in state and reg!="RIP":
|
|
self.emu.writeRegister(reg,state[reg])
|
|
for v in self.stackvars:
|
|
ref="{}_{}".format(v.value,v.length)
|
|
if ref in state:
|
|
self.emu.writeStackValue(int(v.value,16),v.length,state[ref])
|
|
def tryRun(self,nl):
|
|
step=self.collectedData[self.entry]["step"]
|
|
if step==0:
|
|
while self.jtableBreakpoint is None:
|
|
self.process() #replace with initCollect
|
|
self.flush()
|
|
self.collectedData[self.entry]["jtableBreakpoint"]="0x{}".format(self.jtableBreakpoint)
|
|
self.collectedData[self.entry]["jindexRef"]=str(self.jindexRef)
|
|
self.collectedData[self.entry]["estimatedTableLen"]=self.estimatedTableLen
|
|
self.collectedData[self.entry]["stackvars"]=[]
|
|
self.collectedData[self.entry]["callIndices"]=[]
|
|
for sv in self.stackvars:
|
|
self.collectedData[self.entry]["stackvars"].append([sv.type,sv.value,sv.length])
|
|
self.collectedData[self.entry]["step"]=1
|
|
step=1
|
|
self.save()
|
|
else:
|
|
eaddr = self.emu.getExecutionAddress()
|
|
ein=getInstructionAt(eaddr)
|
|
if self.jtableBreakpoint is None:
|
|
self.jtableBreakpoint=getAddress(int(self.collectedData[self.entry]["jtableBreakpoint"],16))
|
|
self.jindexRef=self.collectedData[self.entry]["jindexRef"]
|
|
self.estimatedTableLen=self.collectedData[self.entry]["estimatedTableLen"]
|
|
for sv in self.collectedData[self.entry]["stackvars"]:
|
|
op=Operand(None,0)
|
|
op.type=sv[0]
|
|
op.value=sv[1]
|
|
op.length=sv[2]
|
|
self.stackvars.add(op)
|
|
if step==1: #run till the end, marking indices
|
|
self.collectedData[self.entry]["basicChain"]=[]
|
|
while(True):
|
|
executionAddress = self.emu.getExecutionAddress()
|
|
ein=getInstructionAt(executionAddress)
|
|
if executionAddress==self.jtableBreakpoint:
|
|
#index=self.mch.string(self.emu)
|
|
index=self.emu.readRegister(self.jindexRef)
|
|
self.collectedData[self.entry]["basicChain"].append(index)
|
|
logger.info("At index: {:X}".format(index))
|
|
if index>self.estimatedTableLen:
|
|
logger.info("Hit switch guard")
|
|
break
|
|
if isReturn(ein):
|
|
logger.info("Got to return scessfully")
|
|
break
|
|
branch=getCondmoveInstruction(ein)
|
|
if branch is not None:
|
|
logger.info("Address: 0x{} ({})".format(executionAddress, ein))
|
|
if isCallInstruction(ein):
|
|
logger.info("Skipping call : 0x{} ({})".format(executionAddress, ein))
|
|
self.collectedData[self.entry]["callIndices"].append(index)
|
|
skip(self.emu)
|
|
elif isCallOther(ein):
|
|
logger.info("Skipping : 0x{} ({})".format(executionAddress, ein))
|
|
skip(self.emu)
|
|
else:
|
|
#logger.info("Address: 0x{} ({})".format(executionAddress, getInstructionAt(executionAddress)))
|
|
if not self.step(): break
|
|
self.detectLongLoops()
|
|
self.collectedData[self.entry]["step"]=2
|
|
step=2
|
|
self.save()
|
|
else:
|
|
pass
|
|
self.restart()
|
|
if step==2:
|
|
self.detectLongLoops()
|
|
lbr=self.collectedData[self.entry]["basicLoopBreaks"]
|
|
if len(lbr)>0:
|
|
cnt=0
|
|
while(True):
|
|
executionAddress = self.emu.getExecutionAddress()
|
|
ein=getInstructionAt(executionAddress)
|
|
if executionAddress==self.jtableBreakpoint:
|
|
index=self.emu.readRegister(self.jindexRef)
|
|
if cnt in lbr:
|
|
logger.info("At loop break {:x}".format(index))
|
|
state=self.getState()
|
|
if not "breakStates" in self.entry:
|
|
self.collectedData[self.entry]["breakStates"]={}
|
|
self.collectedData[self.entry]["breakStates"][cnt]={}
|
|
self.collectedData[self.entry]["breakStates"][cnt][index]=state
|
|
if index>self.estimatedTableLen:
|
|
logger.info("Hit switch guard")
|
|
break
|
|
cnt+=1
|
|
if isReturn(ein):
|
|
logger.info("Got to return scessfully")
|
|
break
|
|
if isCallInstruction(ein):
|
|
logger.info("Skipping call : 0x{} ({})".format(executionAddress, ein))
|
|
self.collectedData[self.entry]["callIndices"].append(index)
|
|
skip(self.emu)
|
|
elif isCallOther(ein):
|
|
logger.info("Skipping : 0x{} ({})".format(executionAddress, ein))
|
|
skip(self.emu)
|
|
else:
|
|
if not self.step(): break
|
|
self.collectedData[self.entry]["step"]=3
|
|
step=3
|
|
self.save()
|
|
if step==3: #need to track the indices, jump the loops and jigger the non-loop conditions
|
|
self.restart()
|
|
def process(self): #todo: needs to check for intersections... memory chain, maybe? Nope XD
|
|
eaddr = self.emu.getExecutionAddress()
|
|
ein=getInstructionAt(eaddr)
|
|
if self.jtableBreakpoint is None: #initialization of jump table
|
|
jt=getPotentialJtableAccess(ein)
|
|
# try to find local variables... poorly XD
|
|
if ein is not None:
|
|
for nop in range(ein.getNumOperands()):
|
|
op=Operand(ein,nop)
|
|
if op.type=="stackvar":
|
|
self.stackvars.add(op)
|
|
if jt is not None:
|
|
logger.info("Found assumed jump at 0x{}".format(eaddr))
|
|
self.jtableBreakpoint=eaddr
|
|
self.switchAddr=jt[2]
|
|
self.jtableRef=jt[0]
|
|
jtableAddr=getAddress(self.emu.readRegister(self.jtableRef))
|
|
self.jindexRef=jt[1]
|
|
self.mch=mchain(self.jindexRef)
|
|
self.estimatedTableLen=estimateTableLen(jtableAddr,self.cur_fun.getEntryPoint())
|
|
logger.info("Jump table length estimated to be {}".format(self.estimatedTableLen))
|
|
sGuard=tryGetSwitchGuard(ein)
|
|
if sGuard>0:
|
|
logger.info("Replacing jump table length with switch guard of {}".format(sGuard))
|
|
self.estimatedTableLen=sGuard
|
|
self.jtableBreakpoint=ein.getPrevious().getAddress()
|
|
lval = int("0x{}".format(self.jtableBreakpoint), 16)
|
|
self.emu.writeRegister(self.emu.getPCRegister(), lval)
|
|
logger.info("Moving breakpoint to switch guard at 0x{}".format(self.jtableBreakpoint))
|
|
return self.process()
|
|
return self.step()
|
|
if eaddr==self.switchAddr: # at switch
|
|
self.instructionBlock=[]
|
|
else:
|
|
self.instructionBlock.append(ein)
|
|
if eaddr==self.jtableBreakpoint:
|
|
if self.analyzeInstrBlock(): # see if there are any more hidden variables...
|
|
self.flush()
|
|
return True #to avoid second restart
|
|
self.loopAvoid=self.loopAvoider #did not hit infinite loop on the way
|
|
index=self.mch.string(self.emu)#self.emu.readRegister(self.jindexRef)
|
|
activePath=None
|
|
if self.cpi!=-1:
|
|
activePath=self.paths[self.cpi]
|
|
if self.loadedBox is not None:
|
|
if self.trace is None and self.cpi>=0: #degenerate 1-step loop is possible...
|
|
self.endPath("branch",self.lastindex)
|
|
self.loadedBox.registerIndexOutput(index)
|
|
self.curFullPath[-1].append(self.loadedBox.pathTaken)
|
|
if self.trace is None and self.indexInPaths(index):
|
|
if self.lastindex==index: #degenerate
|
|
npath=CPath()
|
|
npath.ending="loop"
|
|
npath.endJump=index
|
|
self.paths.append(npath)
|
|
else:
|
|
if self.cpi!=-1:
|
|
self.endPath("loop") # most loops end in branches? though not all
|
|
else:
|
|
if activePath is not None:
|
|
activePath.multiend=True
|
|
logger.info("Found potential obfuscated loop to 0x{:x}".format(index))
|
|
return False
|
|
if self.trace is None:
|
|
if self.cpi==-1:
|
|
self.initNewPath(index) #self.mch.string()
|
|
else:
|
|
self.paths[self.cpi].addIndex(index)
|
|
self.curFullPath.append([index])
|
|
logger.info("At index {}".format(index))
|
|
self.lastindex=index
|
|
self.loadBox()
|
|
if self.trace is not None:
|
|
tpt=self.trace[0]
|
|
self.trace=self.trace[1:]
|
|
if len(self.trace)==0:
|
|
self.trace=None
|
|
if tpt[0]!=index:
|
|
logger.warning("Trace error, index mismatch!")
|
|
return False
|
|
if len(tpt)>1:
|
|
if self.loadedBox is None:
|
|
logger.warning("Trace error, should be branch box here!")
|
|
return False
|
|
self.loadedBox.loadTrace(tpt[1])
|
|
if self.emu.readRegister(self.jindexRef)>self.estimatedTableLen:
|
|
logger.info("Hit switch guard, retracing")
|
|
self.endPath("iloop")
|
|
#self.restart()
|
|
return False
|
|
branch=getCondmoveInstruction(ein)
|
|
if branch is not None:
|
|
if self.loadedBox is None:
|
|
bbox=BranchBox(self.lastindex)
|
|
bbox.pathToBox=[list(k) for k in self.curFullPath] #deep copy
|
|
self.branchboxes[self.lastindex]=bbox
|
|
self.loadedBox=bbox
|
|
self.loadedBox.registerBranch(branch)
|
|
self.loadedBox.takeUntakenOrTrace(self.emu)
|
|
return True
|
|
if isCallInstruction(ein):
|
|
logger.info("Skipping call : 0x{} ({})".format(eaddr, ein))
|
|
skip(self.emu)
|
|
return True
|
|
self.loopAvoid-=1
|
|
if self.loopAvoid<0:
|
|
self.endPath("iloop")
|
|
return False
|
|
if isReturn(ein):
|
|
if self.cpi!=-1:
|
|
self.endPath("return")
|
|
return False
|
|
return self.step()
|
|
def step(self):
|
|
success = self.emu.step(self.monitor)
|
|
if (success == False):
|
|
lastError = self.emu.getLastError()
|
|
logger.error("Emulation Error: '{}'".format(lastError))
|
|
if self.cpi!=-1:
|
|
self.endPath("baddata")
|
|
return success
|
|
|
|
# --- Script driver -----------------------------------------------------------
# Prints some diagnostics about the instruction under the cursor, then runs the
# analysis on the function that contains it.
location = state.getCurrentLocation()
print(location.address)
c_instr = getInstructionAt(location.address)
print(c_instr)

if c_instr is None:
    # The diagnostics below all need an instruction; do not let their absence
    # abort the script.
    logger.warning("No instruction at {}, skipping operand diagnostics".format(location.address))
else:
    print(c_instr.getNumOperands())
    first_operand_objects = c_instr.getOpObjects(0)
    if len(first_operand_objects) > 0:
        print(type(first_operand_objects[0]))
    print(c_instr.getOpObjects(1))
    print(listing.getInstructionBefore(c_instr.address))
    print(c_instr.getOperandType(0))

    print("{:x}".format(c_instr.getOperandType(1)))
    print("{:x}".format(c_instr.getOperandType(2)))

cur_fun = getFunctionContaining(location.address)
print(cur_fun)

print(currentProgram.getRegister("ECX"))
print(currentProgram.getRegister("EdCX"))
if c_instr is not None:
    print(c_instr.getDefaultOperandRepresentation(1))

registers = getProgramRegisterList(currentProgram)
reg_filter = [
    "RIP", "RAX"]  # , "RBX", "RCX", "RDX", "RSI", "RDI",
# "RSP", "RBP", "rflags"

if cur_fun is None:
    logger.error("Location {} is not inside a function, nothing to analyse".format(location.address))
else:
    mainFunctionEntryLong = int("0x{}".format(cur_fun.getEntryPoint()), 16)

    op = ObfuscatedPath()
    try:
        op.initEmulator()
        op.tryRun(50000000)
        # op.restart()
        # op.runStraight(5000)
    finally:
        # Release the emulator even when the run raises.
        if op.emu is not None:
            op.emu.dispose()
|
|
|