This commit is contained in:
Jean-Marie 'Histausse' Mineau 2024-02-26 22:35:55 +01:00
parent 4e1c36ad3c
commit 4222dc6354
Signed by: histausse
GPG key ID: B66AEEDA9B645AD2
3 changed files with 161 additions and 81 deletions

View file

@ -7,6 +7,7 @@ use std::collections::HashSet;
use std::hash::{Hash, Hasher}; use std::hash::{Hash, Hasher};
use anyhow::{anyhow, bail, Context}; use anyhow::{anyhow, bail, Context};
use pyo3::class::basic::CompareOp;
use pyo3::prelude::*; use pyo3::prelude::*;
use crate::{scalar::*, DexString, DexValue, Result}; use crate::{scalar::*, DexString, DexValue, Result};
@ -142,10 +143,6 @@ impl IdMethodType {
self.parameters.clone() self.parameters.clone()
} }
pub fn __eq__(&self, other: &Self) -> bool {
self == other
}
pub fn __hash__(&self) -> u64 { pub fn __hash__(&self) -> u64 {
let mut hasher = DefaultHasher::new(); let mut hasher = DefaultHasher::new();
self.hash(&mut hasher); self.hash(&mut hasher);
@ -177,6 +174,10 @@ impl IdMethodType {
protos.insert(self.clone()); protos.insert(self.clone());
protos protos
} }
fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool {
op.matches(self.cmp(other))
}
} }
impl IdMethodType { impl IdMethodType {
@ -570,10 +571,6 @@ impl IdType {
hasher.finish() hasher.finish()
} }
pub fn __eq__(&self, other: &Self) -> bool {
self == other
}
/// Return all strings referenced in the Id. /// Return all strings referenced in the Id.
pub fn get_all_strings(&self) -> HashSet<DexString> { pub fn get_all_strings(&self) -> HashSet<DexString> {
let mut strings = HashSet::new(); let mut strings = HashSet::new();
@ -587,6 +584,10 @@ impl IdType {
types.insert(self.clone()); types.insert(self.clone());
types types
} }
fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool {
op.matches(self.cmp(other))
}
// TODO: TESTS // TODO: TESTS
} }
@ -684,10 +685,6 @@ impl IdField {
format!("IdField(\"{name}\", {ty}, {class})") format!("IdField(\"{name}\", {ty}, {class})")
} }
pub fn __eq__(&self, other: &Self) -> bool {
self == other
}
pub fn __hash__(&self) -> u64 { pub fn __hash__(&self) -> u64 {
let mut hasher = DefaultHasher::new(); let mut hasher = DefaultHasher::new();
self.hash(&mut hasher); self.hash(&mut hasher);
@ -717,6 +714,10 @@ impl IdField {
fields.insert(self.clone()); fields.insert(self.clone());
fields fields
} }
fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool {
op.matches(self.cmp(other))
}
} }
impl Ord for IdField { impl Ord for IdField {
@ -848,10 +849,6 @@ impl IdMethod {
) )
} }
pub fn __eq__(&self, other: &Self) -> bool {
self == other
}
pub fn __hash__(&self) -> u64 { pub fn __hash__(&self) -> u64 {
let mut hasher = DefaultHasher::new(); let mut hasher = DefaultHasher::new();
self.hash(&mut hasher); self.hash(&mut hasher);
@ -888,6 +885,10 @@ impl IdMethod {
method_ids.insert(self.clone()); method_ids.insert(self.clone());
method_ids method_ids
} }
fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool {
op.matches(self.cmp(other))
}
} }
impl Ord for IdMethod { impl Ord for IdMethod {
@ -906,7 +907,7 @@ impl PartialOrd for IdMethod {
} }
#[pyclass] #[pyclass]
#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] #[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, Ord, PartialOrd)]
pub struct IdEnum(pub IdField); pub struct IdEnum(pub IdField);
#[pymethods] #[pymethods]
@ -952,7 +953,7 @@ impl IdEnum {
self.0.get_all_field_ids() self.0.get_all_field_ids()
} }
pub fn __eq__(&self, other: &Self) -> bool { fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool {
self == other op.matches(self.cmp(other))
} }
} }

View file

@ -181,16 +181,6 @@ impl DexString {
self.into() self.into()
} }
fn __richcmp__(&self, other: &PyAny, op: CompareOp, py: Python<'_>) -> PyResult<PyObject> {
let other: Self = other
.extract()
.or(<String as FromPyObject>::extract(other).map(|string| string.into()))?;
match op {
CompareOp::Eq => Ok((self == &other).into_py(py)), // TODO: alpha order?
_ => Ok(py.NotImplemented()),
}
}
pub fn __repr__(&self) -> String { pub fn __repr__(&self) -> String {
self.into() self.into()
} }
@ -207,4 +197,8 @@ impl DexString {
strings.insert(self.clone()); strings.insert(self.clone());
strings strings
} }
fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool {
op.matches(self.0.cmp(&other.0))
}
} }

View file

@ -50,7 +50,7 @@ clazz = dyn_load_apk.classes[clazz_id]
method = clazz.virtual_methods[method_id] method = clazz.virtual_methods[method_id]
code = method.code code = method.code
logging.getLogger().setLevel(logging.WARNING) # logging.getLogger().setLevel(logging.WARNING)
def is_evasion_method(meth: IdMethod) -> bool: def is_evasion_method(meth: IdMethod) -> bool:
@ -95,7 +95,7 @@ print(f"[+] Load bytecode in assets/classes")
with z.ZipFile(DYN_LOAD_APK) as zipf: with z.ZipFile(DYN_LOAD_APK) as zipf:
with zipf.open("assets/classes", "r") as dex_f: with zipf.open("assets/classes", "r") as dex_f:
dex = dex_f.read() dex = dex_f.read()
dyn_load_apk.add_dex_file(dex) # dyn_load_apk.add_dex_file(dex)
mal_cls_in_apk = malicious_class_id in dyn_load_apk.classes mal_cls_in_apk = malicious_class_id in dyn_load_apk.classes
if mal_cls_in_apk: if mal_cls_in_apk:
@ -112,60 +112,145 @@ args = [
IdType("Landroid/content/Context;"), IdType("Landroid/content/Context;"),
IdType("Landroid/content/BroadcastReceiver;"), IdType("Landroid/content/BroadcastReceiver;"),
] ]
malicious_class = dyn_load_apk.classes[malicious_class_id] # malicious_class = dyn_load_apk.classes[malicious_class_id]
#
# potential_meth = []
# for m_id in malicious_class.direct_methods:
# if m_id.name == name and m_id.proto.get_parameters() == args:
# potential_meth.append((m_id, "direct"))
# for m_id in malicious_class.virtual_methods:
# if m_id.name == name and m_id.proto.get_parameters() == args:
# potential_meth.append((m_id, "virtual"))
#
# print("[+] Potential methods:")
# for m_id, t in potential_meth:
# print(f" {m_id}({t})")
#
# m_id, t = potential_meth[0]
#
# new_insns = (
# code.insns[:42]
# + [
# ins.NewInstance(1, IdType("Lcom/example/ut_dyn_load/SmsReceiver;")),
# ins.InvokeVirtual(m_id, [1, 8, 1]),
# ]
# + code.insns[62:]
# )
#
# print(f"[+] New code ")
# for i, inst in enumerate(new_insns):
# if i >= 42 and i < 44:
# print(f"{i:>03} {GREEN}{inst}{ENDC}")
# continue
# match inst:
# case ins.InvokeVirtual(args=args, method=method) if is_evasion_method(method):
# print(f"{i:>03} {RED}{inst}{ENDC}")
# case ins.SGetObject(to=to, field=field):
# print(f"{i:>03} {inst}")
# print(
# f" val: {GREEN}{dyn_load_apk.classes[field.class_].static_fields[field].value}{ENDC}"
# )
# case inst:
# print(f"{i:>03} {inst}")
#
# new_code = Code(code.registers_size, code.ins_size, code.outs_size, new_insns)
# dyn_load_apk.set_method_code(method_id, code)
potential_meth = [] NB_APK = 8666
for m_id in malicious_class.direct_methods:
if m_id.name == name and m_id.proto.get_parameters() == args:
potential_meth.append((m_id, "direct"))
for m_id in malicious_class.virtual_methods:
if m_id.name == name and m_id.proto.get_parameters() == args:
potential_meth.append((m_id, "virtual"))
print("[+] Potential methods:") list_cls = list(dyn_load_apk.classes.keys())
for m_id, t in potential_meth: list_cls.sort()
print(f" {m_id}({t})") print(f"[+] NB classes: {len(list_cls)}")
for cls in list_cls[NB_APK:]:
m_id, t = potential_meth[0] dyn_load_apk.remove_class(cls)
new_insns = (
code.insns[:42]
+ [
ins.NewInstance(1, IdType("Lcom/example/ut_dyn_load/SmsReceiver;")),
ins.InvokeVirtual(m_id, [1, 8, 1]),
]
+ code.insns[62:]
)
print(f"[+] New code ")
for i, inst in enumerate(new_insns):
if i >= 42 and i < 44:
print(f"{i:>03} {GREEN}{inst}{ENDC}")
continue
match inst:
case ins.InvokeVirtual(args=args, method=method) if is_evasion_method(method):
print(f"{i:>03} {RED}{inst}{ENDC}")
case ins.SGetObject(to=to, field=field):
print(f"{i:>03} {inst}")
print(
f" val: {GREEN}{dyn_load_apk.classes[field.class_].static_fields[field].value}{ENDC}"
)
case inst:
print(f"{i:>03} {inst}")
new_code = Code(code.registers_size, code.ins_size, code.outs_size, new_insns)
dyn_load_apk.set_method_code(method_id, code)
print("[+] Recompile") print("[+] Recompile")
dex_raw = dyn_load_apk.gen_raw_dex() dex_raw = dyn_load_apk.gen_raw_dex()
print("[+] Repackage") print("[+] Repackage")
utils.replace_dex( # utils.replace_dex(
DYN_LOAD_APK, # DYN_LOAD_APK,
DYN_LOAD_APK.parent # DYN_LOAD_APK.parent
/ (DYN_LOAD_APK.name.removesuffix(".apk") + "-instrumented.apk"), # / (DYN_LOAD_APK.name.removesuffix(".apk") + "-instrumented.apk"),
dex_raw, # dex_raw,
Path(__file__).parent.parent / "my-release-key.jks", # Path(__file__).parent.parent / "my-release-key.jks",
zipalign=Path.home() / "Android" / "Sdk" / "build-tools" / "34.0.0" / "zipalign", # zipalign=Path.home() / "Android" / "Sdk" / "build-tools" / "34.0.0" / "zipalign",
apksigner=Path.home() / "Android" / "Sdk" / "build-tools" / "34.0.0" / "apksigner", # apksigner=Path.home() / "Android" / "Sdk" / "build-tools" / "34.0.0" / "apksigner",
) # )
MAX_REQ = 3
def cmp(a, b, req=0):
if req > MAX_REQ:
return
if type(a) == dict:
cmp_dict(a, b, req)
elif type(a) == list:
cmp_list(a, b, req)
else:
cmp_other(a, b, req)
def nice_bool(b) -> str:
if b:
return "\033[32mTrue\033[0m"
else:
return "\033[31mFalse\033[0m"
def cmp_other(a, b, req=0):
ident = " " * req
for f in dir(a):
if getattr(getattr(a, f), "__call__", None) is None and (
len(f) < 2 or f[:2] != "__"
):
eq = getattr(a, f) == getattr(b, f)
print(f"{f'{ident}{f}: ':<150}{nice_bool(eq)}")
if not eq:
if "descriptor" in dir(a):
global last_id
last_id = a.descriptor
cmp(getattr(a, f), getattr(b, f), req + 1)
def cmp_dict(a, b, req=0):
ident = " " * req
keys_a = set(a.keys())
keys_b = set(b.keys())
if keys_a != keys_b:
print(f"{ident}a.keys() != b.keys()")
tot = 0
nb_failed = 0
for key in keys_a & keys_b:
eq = a[key] == b[key]
tot += 1
if not eq:
nb_failed += 1
print(f"{f'{ident}{str(key)}: ':<150}{nice_bool(eq)}")
global last_id
last_id = key
cmp(a[key], b[key], req + 1)
print(f"\033[32m{tot-nb_failed}\033[0m + \033[31m{nb_failed}\033[0m = {tot}")
def cmp_list(a, b, req=0):
ident = " " * req
la = len(a)
lb = len(b)
if la != lb:
print(f"{ident}len(a) != len(b)")
for i in range(min(la, lb)):
eq = a[i] == b[i]
print(f"{f'{ident}{str(i)}: ':<150}{nice_bool(eq)}")
if not eq:
cmp(a[i], b[i], req + 1)
instrumented_apk = Apk()
instrumented_apk.add_dex_file(dex_raw[0])
cmp(instrumented_apk, dyn_load_apk)
for i in range(NB_APK - 5, NB_APK + 5):
print(f"{i}: {str(list_cls[i])}")