diff --git a/androscalpel/src/dex_id.rs b/androscalpel/src/dex_id.rs index cebc68b..21d3e73 100644 --- a/androscalpel/src/dex_id.rs +++ b/androscalpel/src/dex_id.rs @@ -7,6 +7,7 @@ use std::collections::HashSet; use std::hash::{Hash, Hasher}; use anyhow::{anyhow, bail, Context}; +use pyo3::class::basic::CompareOp; use pyo3::prelude::*; use crate::{scalar::*, DexString, DexValue, Result}; @@ -142,10 +143,6 @@ impl IdMethodType { self.parameters.clone() } - pub fn __eq__(&self, other: &Self) -> bool { - self == other - } - pub fn __hash__(&self) -> u64 { let mut hasher = DefaultHasher::new(); self.hash(&mut hasher); @@ -177,6 +174,10 @@ impl IdMethodType { protos.insert(self.clone()); protos } + + fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool { + op.matches(self.cmp(other)) + } } impl IdMethodType { @@ -570,10 +571,6 @@ impl IdType { hasher.finish() } - pub fn __eq__(&self, other: &Self) -> bool { - self == other - } - /// Return all strings referenced in the Id. pub fn get_all_strings(&self) -> HashSet { let mut strings = HashSet::new(); @@ -587,6 +584,10 @@ impl IdType { types.insert(self.clone()); types } + + fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool { + op.matches(self.cmp(other)) + } // TODO: TESTS } @@ -684,10 +685,6 @@ impl IdField { format!("IdField(\"{name}\", {ty}, {class})") } - pub fn __eq__(&self, other: &Self) -> bool { - self == other - } - pub fn __hash__(&self) -> u64 { let mut hasher = DefaultHasher::new(); self.hash(&mut hasher); @@ -717,6 +714,10 @@ impl IdField { fields.insert(self.clone()); fields } + + fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool { + op.matches(self.cmp(other)) + } } impl Ord for IdField { @@ -848,10 +849,6 @@ impl IdMethod { ) } - pub fn __eq__(&self, other: &Self) -> bool { - self == other - } - pub fn __hash__(&self) -> u64 { let mut hasher = DefaultHasher::new(); self.hash(&mut hasher); @@ -888,6 +885,10 @@ impl IdMethod { method_ids.insert(self.clone()); method_ids } + + fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool { + op.matches(self.cmp(other)) + } } impl Ord for IdMethod { @@ -906,7 +907,7 @@ impl PartialOrd for IdMethod { } #[pyclass] -#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize)] +#[derive(Debug, Clone, PartialEq, Eq, Deserialize, Serialize, Ord, PartialOrd)] pub struct IdEnum(pub IdField); #[pymethods] @@ -952,7 +953,7 @@ impl IdEnum { self.0.get_all_field_ids() } - pub fn __eq__(&self, other: &Self) -> bool { - self == other + fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool { + op.matches(self.cmp(other)) } } diff --git a/androscalpel/src/dex_string.rs b/androscalpel/src/dex_string.rs index 7115fb9..c86fe4c 100644 --- a/androscalpel/src/dex_string.rs +++ b/androscalpel/src/dex_string.rs @@ -181,16 +181,6 @@ impl DexString { self.into() } - fn __richcmp__(&self, other: &PyAny, op: CompareOp, py: Python<'_>) -> PyResult { - let other: Self = other - .extract() - .or(::extract(other).map(|string| string.into()))?; - match op { - CompareOp::Eq => Ok((self == &other).into_py(py)), // TODO: alpha order? - _ => Ok(py.NotImplemented()), - } - } - pub fn __repr__(&self) -> String { self.into() } @@ -207,4 +197,8 @@ impl DexString { strings.insert(self.clone()); strings } + + fn __richcmp__(&self, other: &Self, op: CompareOp) -> bool { + op.matches(self.0.cmp(&other.0)) + } } diff --git a/tests/test.py b/tests/test.py index e7a8caa..eb6a6f5 100644 --- a/tests/test.py +++ b/tests/test.py @@ -50,7 +50,7 @@ clazz = dyn_load_apk.classes[clazz_id] method = clazz.virtual_methods[method_id] code = method.code -logging.getLogger().setLevel(logging.WARNING) +# logging.getLogger().setLevel(logging.WARNING) def is_evasion_method(meth: IdMethod) -> bool: @@ -95,7 +95,7 @@ print(f"[+] Load bytecode in assets/classes") with z.ZipFile(DYN_LOAD_APK) as zipf: with zipf.open("assets/classes", "r") as dex_f: dex = dex_f.read() - dyn_load_apk.add_dex_file(dex) + # dyn_load_apk.add_dex_file(dex) mal_cls_in_apk = malicious_class_id in dyn_load_apk.classes if mal_cls_in_apk: @@ -112,60 +112,145 @@ args = [ IdType("Landroid/content/Context;"), IdType("Landroid/content/BroadcastReceiver;"), ] -malicious_class = dyn_load_apk.classes[malicious_class_id] +# malicious_class = dyn_load_apk.classes[malicious_class_id] +# +# potential_meth = [] +# for m_id in malicious_class.direct_methods: +# if m_id.name == name and m_id.proto.get_parameters() == args: +# potential_meth.append((m_id, "direct")) +# for m_id in malicious_class.virtual_methods: +# if m_id.name == name and m_id.proto.get_parameters() == args: +# potential_meth.append((m_id, "virtual")) +# +# print("[+] Potential methods:") +# for m_id, t in potential_meth: +# print(f" {m_id}({t})") +# +# m_id, t = potential_meth[0] +# +# new_insns = ( +# code.insns[:42] +# + [ +# ins.NewInstance(1, IdType("Lcom/example/ut_dyn_load/SmsReceiver;")), +# ins.InvokeVirtual(m_id, [1, 8, 1]), +# ] +# + code.insns[62:] +# ) +# +# print(f"[+] New code ") +# for i, inst in enumerate(new_insns): +# if i >= 42 and i < 44: +# print(f"{i:>03} {GREEN}{inst}{ENDC}") +# continue +# match inst: +# case ins.InvokeVirtual(args=args, method=method) if is_evasion_method(method): +# print(f"{i:>03} {RED}{inst}{ENDC}") +# case ins.SGetObject(to=to, field=field): +# print(f"{i:>03} {inst}") +# print( +# f" val: {GREEN}{dyn_load_apk.classes[field.class_].static_fields[field].value}{ENDC}" +# ) +# case inst: +# print(f"{i:>03} {inst}") +# +# new_code = Code(code.registers_size, code.ins_size, code.outs_size, new_insns) +# dyn_load_apk.set_method_code(method_id, code) -potential_meth = [] -for m_id in malicious_class.direct_methods: - if m_id.name == name and m_id.proto.get_parameters() == args: - potential_meth.append((m_id, "direct")) -for m_id in malicious_class.virtual_methods: - if m_id.name == name and m_id.proto.get_parameters() == args: - potential_meth.append((m_id, "virtual")) +NB_APK = 8666 -print("[+] Potential methods:") -for m_id, t in potential_meth: - print(f" {m_id}({t})") - -m_id, t = potential_meth[0] - -new_insns = ( - code.insns[:42] - + [ - ins.NewInstance(1, IdType("Lcom/example/ut_dyn_load/SmsReceiver;")), - ins.InvokeVirtual(m_id, [1, 8, 1]), - ] - + code.insns[62:] -) - -print(f"[+] New code ") -for i, inst in enumerate(new_insns): - if i >= 42 and i < 44: - print(f"{i:>03} {GREEN}{inst}{ENDC}") - continue - match inst: - case ins.InvokeVirtual(args=args, method=method) if is_evasion_method(method): - print(f"{i:>03} {RED}{inst}{ENDC}") - case ins.SGetObject(to=to, field=field): - print(f"{i:>03} {inst}") - print( - f" val: {GREEN}{dyn_load_apk.classes[field.class_].static_fields[field].value}{ENDC}" - ) - case inst: - print(f"{i:>03} {inst}") - -new_code = Code(code.registers_size, code.ins_size, code.outs_size, new_insns) -dyn_load_apk.set_method_code(method_id, code) +list_cls = list(dyn_load_apk.classes.keys()) +list_cls.sort() +print(f"[+] NB classes: {len(list_cls)}") +for cls in list_cls[NB_APK:]: + dyn_load_apk.remove_class(cls) print("[+] Recompile") dex_raw = dyn_load_apk.gen_raw_dex() print("[+] Repackage") -utils.replace_dex( - DYN_LOAD_APK, - DYN_LOAD_APK.parent - / (DYN_LOAD_APK.name.removesuffix(".apk") + "-instrumented.apk"), - dex_raw, - Path(__file__).parent.parent / "my-release-key.jks", - zipalign=Path.home() / "Android" / "Sdk" / "build-tools" / "34.0.0" / "zipalign", - apksigner=Path.home() / "Android" / "Sdk" / "build-tools" / "34.0.0" / "apksigner", -) +# utils.replace_dex( +# DYN_LOAD_APK, +# DYN_LOAD_APK.parent +# / (DYN_LOAD_APK.name.removesuffix(".apk") + "-instrumented.apk"), +# dex_raw, +# Path(__file__).parent.parent / "my-release-key.jks", +# zipalign=Path.home() / "Android" / "Sdk" / "build-tools" / "34.0.0" / "zipalign", +# apksigner=Path.home() / "Android" / "Sdk" / "build-tools" / "34.0.0" / "apksigner", +# ) + + +MAX_REQ = 3 + + +def cmp(a, b, req=0): + if req > MAX_REQ: + return + if type(a) == dict: + cmp_dict(a, b, req) + elif type(a) == list: + cmp_list(a, b, req) + else: + cmp_other(a, b, req) + + +def nice_bool(b) -> str: + if b: + return "\033[32mTrue\033[0m" + else: + return "\033[31mFalse\033[0m" + + +def cmp_other(a, b, req=0): + ident = " " * req + for f in dir(a): + if getattr(getattr(a, f), "__call__", None) is None and ( + len(f) < 2 or f[:2] != "__" + ): + eq = getattr(a, f) == getattr(b, f) + print(f"{f'{ident}{f}: ':<150}{nice_bool(eq)}") + if not eq: + if "descriptor" in dir(a): + global last_id + last_id = a.descriptor + cmp(getattr(a, f), getattr(b, f), req + 1) + + +def cmp_dict(a, b, req=0): + ident = " " * req + keys_a = set(a.keys()) + keys_b = set(b.keys()) + if keys_a != keys_b: + print(f"{ident}a.keys() != b.keys()") + tot = 0 + nb_failed = 0 + for key in keys_a & keys_b: + eq = a[key] == b[key] + tot += 1 + if not eq: + nb_failed += 1 + print(f"{f'{ident}{str(key)}: ':<150}{nice_bool(eq)}") + global last_id + last_id = key + cmp(a[key], b[key], req + 1) + print(f"\033[32m{tot-nb_failed}\033[0m + \033[31m{nb_failed}\033[0m = {tot}") + + +def cmp_list(a, b, req=0): + ident = " " * req + la = len(a) + lb = len(b) + if la != lb: + print(f"{ident}len(a) != len(b)") + for i in range(min(la, lb)): + eq = a[i] == b[i] + print(f"{f'{ident}{str(i)}: ':<150}{nice_bool(eq)}") + if not eq: + cmp(a[i], b[i], req + 1) + + +instrumented_apk = Apk() +instrumented_apk.add_dex_file(dex_raw[0]) + +cmp(instrumented_apk, dyn_load_apk) +for i in range(NB_APK - 5, NB_APK + 5): + print(f"{i}: {str(list_cls[i])}")