diff --git a/androscalpel/src/annotation.rs b/androscalpel/src/annotation.rs index 5fc273c..4e2bf82 100644 --- a/androscalpel/src/annotation.rs +++ b/androscalpel/src/annotation.rs @@ -1,7 +1,7 @@ //! Annotations (for class, fields, methods and parameters alike). +use std::collections::{HashMap, HashSet}; use pyo3::prelude::*; -use std::collections::HashMap; use crate::{dex_id::IdType, value::DexValue, DexString}; @@ -61,6 +61,11 @@ impl DexAnnotationItem { let annotation = self.annotation.__repr__(); format!("AnnotationItem(visibility: {visibility}, {annotation})") } + + /// Return all strings references in the annotation. + pub fn get_all_strings(&self) -> HashSet { + self.annotation.get_all_strings() + } } /// An annotation. @@ -97,4 +102,15 @@ impl DexAnnotation { elts += "}"; format!("Annotation({type_}, {elts})") } + + /// Return all strings references in the annotation. + pub fn get_all_strings(&self) -> HashSet { + let mut strings = HashSet::new(); + strings.extend(self.type_.get_all_strings()); + for (name, value) in &self.elements { + strings.insert(name.clone()); + strings.extend(value.get_all_strings()); + } + strings + } } diff --git a/androscalpel/src/class.rs b/androscalpel/src/class.rs index 0fc2049..07ae4fa 100644 --- a/androscalpel/src/class.rs +++ b/androscalpel/src/class.rs @@ -1,6 +1,6 @@ //! Representation of a class. -use std::collections::HashMap; +use std::collections::{HashMap, HashSet}; use pyo3::prelude::*; @@ -48,7 +48,6 @@ pub struct Class { #[pyo3(get, set)] pub source_file: Option, - // TODO: hash map? /// The static fields #[pyo3(get, set)] pub static_fields: HashMap, @@ -68,7 +67,6 @@ pub struct Class { /// structutres) #[pyo3(get, set)] pub annotations: Vec, - // TODO: mix annotation data to fields / methods / class to make it more practicle } #[pymethods] @@ -127,4 +125,28 @@ impl Class { let name: String = (&self.descriptor.get_name()).into(); format!("Class({name})") } + + /// Return all strings references in the class. + pub fn get_all_strings(&self) -> HashSet { + let mut strings = HashSet::new(); + strings.extend(self.descriptor.get_all_strings()); + if let Some(ty) = &self.superclass { + strings.extend(ty.get_all_strings()); + } + for ty in &self.interfaces { + strings.extend(ty.get_all_strings()); + } + if let Some(string) = &self.source_file { + strings.insert(string.clone()); + } + for (id, field) in &self.static_fields { + strings.extend(id.get_all_strings()); + strings.extend(field.get_all_strings()); + } + //pub instance_fields: HashMap, + //pub direct_methods: HashMap, + //pub virtual_methods: HashMap, + //pub annotations: Vec, + strings + } } diff --git a/androscalpel/src/dex_id.rs b/androscalpel/src/dex_id.rs index 60b1bf1..dd7bf8e 100644 --- a/androscalpel/src/dex_id.rs +++ b/androscalpel/src/dex_id.rs @@ -1,6 +1,7 @@ //! The class identifying dex structure. use std::collections::hash_map::DefaultHasher; +use std::collections::HashSet; use std::hash::{Hash, Hasher}; use anyhow::anyhow; @@ -82,6 +83,17 @@ impl IdMethodType { } shorty.into() } + + /// Return all strings references in the Id. + pub fn get_all_strings(&self) -> HashSet { + let mut strings = HashSet::new(); + strings.insert(self.shorty.clone()); + strings.extend(self.return_type.get_all_strings()); + for ty in &self.parameters { + strings.extend(ty.get_all_strings()); + } + strings + } } /// A type. @@ -337,6 +349,13 @@ impl IdType { self == other } + /// Return all strings references in the Id. + pub fn get_all_strings(&self) -> HashSet { + let mut strings = HashSet::new(); + strings.insert(self.0.clone()); + strings + } + // TODO: TESTS } @@ -394,6 +413,15 @@ impl IdField { self.hash(&mut hasher); hasher.finish() } + + /// Return all strings references in the Id. + pub fn get_all_strings(&self) -> HashSet { + let mut strings = HashSet::new(); + strings.insert(self.name.clone()); + strings.extend(self.type_.get_all_strings()); + strings.extend(self.class_.get_all_strings()); + strings + } } /// The Id of a method. @@ -454,6 +482,15 @@ impl IdMethod { self.hash(&mut hasher); hasher.finish() } + + /// Return all strings references in the Id. + pub fn get_all_strings(&self) -> HashSet { + let mut strings = HashSet::new(); + strings.insert(self.name.clone()); + strings.extend(self.proto.get_all_strings()); + strings.extend(self.class_.get_all_strings()); + strings + } } #[pyclass] @@ -477,4 +514,9 @@ impl IdEnum { pub fn __repr__(&self) -> String { format!("DexEnum({})", self.__str__()) } + + /// Return all strings references in the Id. + pub fn get_all_strings(&self) -> HashSet { + self.0.get_all_strings() + } } diff --git a/androscalpel/src/dex_string.rs b/androscalpel/src/dex_string.rs index 28e8bb2..df0f720 100644 --- a/androscalpel/src/dex_string.rs +++ b/androscalpel/src/dex_string.rs @@ -1,4 +1,6 @@ +use std::cmp::{Ord, Ordering, PartialOrd}; use std::collections::hash_map::DefaultHasher; +use std::collections::HashSet; use std::hash::{Hash, Hasher}; use pyo3::class::basic::CompareOp; @@ -9,6 +11,24 @@ use pyo3::prelude::*; #[derive(Clone, PartialEq, Eq, Debug)] pub struct DexString(pub androscalpel_serializer::StringDataItem); +impl Ord for DexString { + fn cmp(&self, other: &Self) -> Ordering { + self.0 + .data + .cmp(&other.0.data) + .then(self.0.utf16_size.0.cmp(&other.0.utf16_size.0)) + } +} + +impl PartialOrd for DexString { + fn partial_cmp(&self, other: &Self) -> Option { + self.0 + .data + .partial_cmp(&other.0.data) + .map(|ord| ord.then(self.0.utf16_size.0.cmp(&other.0.utf16_size.0))) + } +} + impl From for androscalpel_serializer::StringDataItem { fn from(DexString(string): DexString) -> Self { string @@ -107,4 +127,11 @@ impl DexString { self.hash(&mut hasher); hasher.finish() } + + /// Return all strings references in the value. + pub fn get_all_strings(&self) -> HashSet { + let mut strings = HashSet::new(); + strings.insert(self.clone()); + strings + } } diff --git a/androscalpel/src/dex_writer.rs b/androscalpel/src/dex_writer.rs new file mode 100644 index 0000000..69ec526 --- /dev/null +++ b/androscalpel/src/dex_writer.rs @@ -0,0 +1,402 @@ +//! The structure that generate a .dex from classes. + +use std::collections::HashMap; +use std::io::{Cursor, Write}; + +use crate::Result; +use crate::*; +use androscalpel_serializer::*; + +#[derive(Debug, Clone)] +pub struct DexWriter { + header: HeaderItem, + strings: HashMap, + _types_ids: HashMap, + _proto_ids: HashMap, + _field_ids: HashMap, + _method_ids: HashMap, + // TODO: composite classes need a struct for storing link data + // class_defs: HashMap, + // call_site_ids: // TODO: parsing code insns + // method_handles: + // TODO: other structs in data: + // **map_list**, prbl generate on write + // values + // annotations + // +} + +impl Default for DexWriter { + fn default() -> Self { + Self { + header: HeaderItem { + magic: DexFileMagic { + version: [0x30, 0x33, 0x39], + }, // TODO: find a better default version + checksum: 0, + signature: [0u8; 20], + file_size: 0, + header_size: 0x70, + endian_tag: EndianConstant::EndianConstant, + link_size: 0, + link_off: 0, + map_off: 0, + string_ids_size: 0, + string_ids_off: 0, + type_ids_size: 0, // At most 0xffff + type_ids_off: 0, + proto_ids_size: 0, // At most 0xffff + proto_ids_off: 0, + field_ids_size: 0, + field_ids_off: 0, + method_ids_size: 0, + method_ids_off: 0, + class_defs_size: 0, + class_defs_off: 0, + data_size: 0, // Must be an even multiple of sizeof(uint) -> % 8 = 0 + data_off: 0, + }, + strings: HashMap::new(), + _types_ids: HashMap::new(), + _proto_ids: HashMap::new(), + _field_ids: HashMap::new(), + _method_ids: HashMap::new(), + } + } +} + +impl DexWriter { + pub fn new() -> Self { + Self::default() + } + + pub fn add_class(&mut self, class: &Class) -> Result<()> { + // TODO: check size max + for string in class.get_all_strings() { + self.strings.insert(string, 0); + } + Ok(()) + } + + pub fn gen_dex_file_to_vec(&mut self) -> Result> { + let mut output = Cursor::new(Vec::::new()); + self.write_dex_file(&mut output)?; + Ok(output.into_inner()) + } + + fn write_dex_file(&mut self, writer: &mut dyn Write) -> Result<()> { + let mut section_manager = SectionManager::default(); + section_manager.incr_section_size(Section::HeaderItem, 0x70); + // TODO: + // map_list: + // - [x] header_item + // - [x] string_id_item + // - [ ] type_id_item + // - [ ] proto_id_item + // - [ ] field_id_item + // - [ ] method_id_item + // - [ ] class_def_item + // - [ ] call_site_id_item + // - [ ] method_handle_item + // - [ ] map_list + // - [ ] type_list + // - [ ] annotation_set_ref_list + // - [ ] annotation_set_item + // - [ ] class_data_item + // - [ ] code_item + // - [ ] string_data_item + // - [ ] debug_info_item + // - [ ] annotation_item + // - [ ] encoded_array_item + // - [ ] annotations_directory_item + // - [ ] hiddenapi_class_data_item + // Use section_manager for seting the right size/offset afterward + + let mut string_ids_list: Vec = self.strings.keys().cloned().collect(); + string_ids_list.sort(); + for (idx, string) in string_ids_list.iter().enumerate() { + self.strings + .entry(string.clone()) + .and_modify(|val| *val = idx); + section_manager.add_elt(Section::StringIdItem, None); + section_manager.add_elt(Section::StringDataItem, Some(string.0.size())); + } + + // Populate map_list + let map_item_size = MapItem { + type_: MapItemType::HeaderItem, + unused: 0, + size: 0, + offset: 0, + } + .size(); + // Empty map has a size 4, then we add the size of a MapItem for each element + section_manager.add_elt(Section::MapList, Some(4)); + for section in Section::VARIANT_LIST { + if !section.is_data() && section_manager.get_nb_elt(*section) != 0 { + section_manager.incr_section_size(Section::MapList, map_item_size); + } + } + let mut map_list = MapList::default(); + for section in Section::VARIANT_LIST { + if !section.is_data() && section_manager.get_nb_elt(*section) != 0 { + map_list.list.push(MapItem { + type_: section.get_map_item_type(), + unused: 0, + size: section_manager.get_nb_elt(*section) as u32, + offset: section_manager.get_offset(*section), + }); + } + } + + // Link Header section: + self.header.map_off = section_manager.get_offset(Section::MapList); + self.header.string_ids_size = string_ids_list.len() as u32; + self.header.string_ids_off = section_manager.get_offset(Section::StringIdItem); + self.header.type_ids_size = 0; // TODO + self.header.type_ids_off = section_manager.get_offset(Section::TypeIdItem); + self.header.proto_ids_size = 0; // TODO + self.header.proto_ids_off = section_manager.get_offset(Section::ProtoIdItem); + self.header.field_ids_size = 0; // TODO + self.header.field_ids_off = section_manager.get_offset(Section::FieldIdItem); + self.header.method_ids_size = 0; // TODO + self.header.method_ids_off = section_manager.get_offset(Section::MethodIdItem); + self.header.class_defs_size = 0; // TODO + self.header.class_defs_off = section_manager.get_offset(Section::ClassDefItem); + self.header.data_size = section_manager.get_size(Section::Data); + self.header.data_off = section_manager.get_offset(Section::Data); + + // TODO: compute checksum, hash, ect + self.header.serialize(writer)?; + // StringIdItem section + let mut string_off = section_manager.get_offset(Section::StringDataItem); + for string in string_ids_list { + let str_id = StringIdItem { + string_data_off: string_off, + }; + str_id.serialize(writer)?; + string_off += string.0.size() as u32; + } + // TODO: TypeIdItem + // TODO: ProtoIdItem, + // TODO: FieldIdItem, + // TODO: MethodIdItem, + // TODO: ClassDefItem, + // TODO: CallSiteIdItem, + // TODO: MethodHandleItem, + // TODO: Data, + // MapList, + map_list.serialize(writer)?; + // TODO: TypeList, + // TODO: AnnotationSetRefList, + // TODO: AnnotationSetItem, + // TODO: ClassDataItem, + // TODO: CodeItem, + // TODO: StringDataItem, + + // TODO + + // TODO: DebugInfoItem, + // TODO: AnnotationItem, + // TODO: EncodedArrayItem, + // TODO: AnnotationsDirectoryItem, + // TODO: HiddenapiClassDataItem, + Ok(()) + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +enum Section { + HeaderItem, + StringIdItem, + TypeIdItem, + ProtoIdItem, + FieldIdItem, + MethodIdItem, + ClassDefItem, + CallSiteIdItem, + MethodHandleItem, + Data, + MapList, + TypeList, + AnnotationSetRefList, + AnnotationSetItem, + ClassDataItem, + CodeItem, + StringDataItem, + DebugInfoItem, + AnnotationItem, + EncodedArrayItem, + AnnotationsDirectoryItem, + HiddenapiClassDataItem, +} + +impl Section { + const VARIANT_LIST: &[Self] = &[ + Self::HeaderItem, + Self::StringIdItem, + Self::TypeIdItem, + Self::ProtoIdItem, + Self::FieldIdItem, + Self::MethodIdItem, + Self::ClassDefItem, + Self::CallSiteIdItem, + Self::MethodHandleItem, + Self::Data, + Self::MapList, + Self::TypeList, + Self::AnnotationSetRefList, + Self::AnnotationSetItem, + Self::ClassDataItem, + Self::CodeItem, + Self::StringDataItem, + Self::DebugInfoItem, + Self::AnnotationItem, + Self::EncodedArrayItem, + Self::AnnotationsDirectoryItem, + Self::HiddenapiClassDataItem, + ]; + + fn get_index(&self) -> usize { + match self { + Self::HeaderItem => 0, + Self::StringIdItem => 1, + Self::TypeIdItem => 2, + Self::ProtoIdItem => 3, + Self::FieldIdItem => 4, + Self::MethodIdItem => 5, + Self::ClassDefItem => 6, + Self::CallSiteIdItem => 7, + Self::MethodHandleItem => 8, + Self::Data => 9, + Self::MapList => 10, + Self::TypeList => 11, + Self::AnnotationSetRefList => 12, + Self::AnnotationSetItem => 13, + Self::ClassDataItem => 14, + Self::CodeItem => 15, + Self::StringDataItem => 16, + Self::DebugInfoItem => 17, + Self::AnnotationItem => 18, + Self::EncodedArrayItem => 19, + Self::AnnotationsDirectoryItem => 20, + Self::HiddenapiClassDataItem => 21, + } + } + + fn get_elt_size(&self, default_size: Option) -> usize { + let fixed_size = match self { + Self::HeaderItem => Some(0x70), + Self::StringIdItem => Some(4), + Self::TypeIdItem => Some(4), + Self::ProtoIdItem => Some(0xc), + Self::FieldIdItem => Some(8), + Self::MethodIdItem => Some(8), + Self::ClassDefItem => Some(0x20), + Self::CallSiteIdItem => Some(4), + Self::MethodHandleItem => Some(8), + Self::Data => panic!("Element cannot be inserted in data dirctly"), + Self::MapList => None, + Self::TypeList => None, + Self::AnnotationSetRefList => None, + Self::AnnotationSetItem => None, + Self::ClassDataItem => None, + Self::CodeItem => None, + Self::StringDataItem => None, + Self::DebugInfoItem => None, + Self::AnnotationItem => None, + Self::EncodedArrayItem => None, + Self::AnnotationsDirectoryItem => None, + Self::HiddenapiClassDataItem => None, + }; + if let (Some(fixed_size), Some(default_size)) = (fixed_size, default_size) { + if fixed_size == default_size { + default_size + } else { + panic!( + "Element in {:?} have a size of {}, not {}", + self, fixed_size, default_size + ) + } + } else { + fixed_size.or(default_size).expect(&format!( + "Element of {:?} don't have a fixed size, you need to provide one", + self + )) + } + } + + fn get_map_item_type(&self) -> MapItemType { + match self { + Self::HeaderItem => MapItemType::HeaderItem, + Self::StringIdItem => MapItemType::StringIdItem, + Self::TypeIdItem => MapItemType::TypeIdItem, + Self::ProtoIdItem => MapItemType::ProtoIdItem, + Self::FieldIdItem => MapItemType::FieldIdItem, + Self::MethodIdItem => MapItemType::MethodIdItem, + Self::ClassDefItem => MapItemType::ClassDefItem, + Self::CallSiteIdItem => MapItemType::CallSiteIdItem, + Self::MethodHandleItem => MapItemType::MethodHandleItem, + Self::Data => panic!("Data is not a MatItemType"), + Self::MapList => MapItemType::MapList, + Self::TypeList => MapItemType::TypeList, + Self::AnnotationSetRefList => MapItemType::AnnotationSetRefList, + Self::AnnotationSetItem => MapItemType::AnnotationSetItem, + Self::ClassDataItem => MapItemType::ClassDataItem, + Self::CodeItem => MapItemType::CodeItem, + Self::StringDataItem => MapItemType::StringDataItem, + Self::DebugInfoItem => MapItemType::DebugInfoItem, + Self::AnnotationItem => MapItemType::AnnotationItem, + Self::EncodedArrayItem => MapItemType::EncodedArrayItem, + Self::AnnotationsDirectoryItem => MapItemType::AnnotationsDirectoryItem, + Self::HiddenapiClassDataItem => MapItemType::HiddenapiClassDataItem, + } + } + + fn is_data(&self) -> bool { + match self { + Self::Data => true, + _ => false, + } + } +} + +#[derive(Debug, Default)] +struct SectionManager { + sizes: [u32; Self::NB_SECTION], + nb_elt: [usize; Self::NB_SECTION], +} + +impl SectionManager { + const NB_SECTION: usize = 22; + + fn add_elt(&mut self, section: Section, size: Option) { + if section.is_data() { + panic!("Cannot add element directly in section data"); + } + self.sizes[section.get_index()] += section.get_elt_size(size) as u32; + self.nb_elt[section.get_index()] += 1; + } + + fn incr_section_size(&mut self, section: Section, size: usize) { + self.sizes[section.get_index()] += size as u32; + } + + fn get_offset(&self, section: Section) -> u32 { + // TODO: check alignment + self.sizes[..section.get_index()].iter().sum() + } + + fn get_size(&self, section: Section) -> u32 { + // TODO: check alignment + if section.is_data() { + self.sizes[section.get_index()..].iter().sum() + } else { + self.sizes[section.get_index()] + } + } + + fn get_nb_elt(&self, section: Section) -> usize { + self.nb_elt[section.get_index()] + } +} diff --git a/androscalpel/src/field.rs b/androscalpel/src/field.rs index 5781138..3098a09 100644 --- a/androscalpel/src/field.rs +++ b/androscalpel/src/field.rs @@ -1,8 +1,10 @@ //! Representation of the fields of a class. +use std::collections::HashSet; + use pyo3::prelude::*; -use crate::{DexAnnotationItem, DexValue, IdField}; +use crate::{DexAnnotationItem, DexString, DexValue, IdField}; /// Represent a field. #[pyclass] @@ -116,4 +118,18 @@ impl Field { // TODO: check type match Ok(()) } + + /// Return all strings references in the field. + pub fn get_all_strings(&self) -> HashSet { + let mut strings = HashSet::new(); + + strings.extend(self.descriptor.get_all_strings()); + if let Some(val) = &self.value { + strings.extend(val.get_all_strings()); + } + for annot in &self.annotations { + strings.extend(annot.get_all_strings()); + } + strings + } } diff --git a/androscalpel/src/method_handle.rs b/androscalpel/src/method_handle.rs index cfc3a2c..ef86100 100644 --- a/androscalpel/src/method_handle.rs +++ b/androscalpel/src/method_handle.rs @@ -1,9 +1,12 @@ //! The structure use to reference a method invocation. +use std::collections::HashSet; + use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; use crate::dex_id::*; +use crate::DexString; /// The structure use to reference a method invocation. #[derive(Debug, Clone)] @@ -41,6 +44,11 @@ impl StaticPut { pub fn __repr__(&self) -> String { format!("StaticPut({})", self.0.__str__()) } + + /// Return all strings references in the handle. + pub fn get_all_strings(&self) -> HashSet { + self.0.get_all_strings() + } } #[pyclass] #[derive(Debug, Clone, PartialEq, Eq)] @@ -64,6 +72,11 @@ impl StaticGet { pub fn __repr__(&self) -> String { format!("StaticGet({})", self.0.__str__()) } + + /// Return all strings references in the handle. + pub fn get_all_strings(&self) -> HashSet { + self.0.get_all_strings() + } } #[pyclass] #[derive(Debug, Clone, PartialEq, Eq)] @@ -87,6 +100,11 @@ impl InstancePut { pub fn __repr__(&self) -> String { format!("InstancePut({})", self.0.__str__()) } + + /// Return all strings references in the handle. + pub fn get_all_strings(&self) -> HashSet { + self.0.get_all_strings() + } } #[pyclass] #[derive(Debug, Clone, PartialEq, Eq)] @@ -110,6 +128,11 @@ impl InstanceGet { pub fn __repr__(&self) -> String { format!("InstanceGet({})", self.0.__str__()) } + + /// Return all strings references in the handle. + pub fn get_all_strings(&self) -> HashSet { + self.0.get_all_strings() + } } #[pyclass] @@ -134,6 +157,11 @@ impl InvokeStatic { pub fn __repr__(&self) -> String { format!("InvokeStatic({})", self.0.__str__()) } + + /// Return all strings references in the handle. + pub fn get_all_strings(&self) -> HashSet { + self.0.get_all_strings() + } } #[pyclass] @@ -158,6 +186,11 @@ impl InvokeInstance { pub fn __repr__(&self) -> String { format!("InvokeInstance({})", self.0.__str__()) } + + /// Return all strings references in the handle. + pub fn get_all_strings(&self) -> HashSet { + self.0.get_all_strings() + } } #[pyclass] @@ -182,6 +215,11 @@ impl InvokeConstructor { pub fn __repr__(&self) -> String { format!("InvokeConstructor({})", self.0.__str__()) } + + /// Return all strings references in the handle. + pub fn get_all_strings(&self) -> HashSet { + self.0.get_all_strings() + } } #[pyclass] @@ -206,6 +244,11 @@ impl InvokeDirect { pub fn __repr__(&self) -> String { format!("InvokeDirect({})", self.0.__str__()) } + + /// Return all strings references in the handle. + pub fn get_all_strings(&self) -> HashSet { + self.0.get_all_strings() + } } #[pyclass] @@ -230,6 +273,11 @@ impl InvokeInterface { pub fn __repr__(&self) -> String { format!("InvokeInterface({})", self.0.__str__()) } + + /// Return all strings references in the handle. + pub fn get_all_strings(&self) -> HashSet { + self.0.get_all_strings() + } } impl<'source> FromPyObject<'source> for MethodHandle { @@ -292,4 +340,19 @@ impl MethodHandle { Self::InvokeInterface(val) => val.__str__(), } } + + /// Return all strings references in the Handle. + pub fn get_all_strings(&self) -> HashSet { + match self { + Self::StaticPut(val) => val.get_all_strings(), + Self::StaticGet(val) => val.get_all_strings(), + Self::InstancePut(val) => val.get_all_strings(), + Self::InstanceGet(val) => val.get_all_strings(), + Self::InvokeStatic(val) => val.get_all_strings(), + Self::InvokeInstance(val) => val.get_all_strings(), + Self::InvokeConstructor(val) => val.get_all_strings(), + Self::InvokeDirect(val) => val.get_all_strings(), + Self::InvokeInterface(val) => val.get_all_strings(), + } + } } diff --git a/androscalpel/src/scalar.rs b/androscalpel/src/scalar.rs index db07b8b..0099832 100644 --- a/androscalpel/src/scalar.rs +++ b/androscalpel/src/scalar.rs @@ -1,9 +1,9 @@ //! The class identifying dex structure. -use crate::DexValue; -use pyo3::prelude::*; +use std::collections::HashSet; -// TODO: move DexString here +use crate::{DexString, DexValue}; +use pyo3::prelude::*; #[pyclass] #[derive(Debug, Clone, Copy, PartialEq, Eq)] @@ -26,6 +26,11 @@ impl DexByte { pub fn __repr__(&self) -> String { format!("DexByte({})", self.0) } + + /// Return all strings references in the value. + pub fn get_all_strings(&self) -> HashSet { + HashSet::new() + } } #[pyclass] @@ -49,6 +54,11 @@ impl DexShort { pub fn __repr__(&self) -> String { format!("DexShort({})", self.0) } + + /// Return all strings references in the value. + pub fn get_all_strings(&self) -> HashSet { + HashSet::new() + } } #[pyclass] @@ -72,6 +82,11 @@ impl DexChar { pub fn __repr__(&self) -> String { format!("DexChar({})", self.0) } + + /// Return all strings references in the value. + pub fn get_all_strings(&self) -> HashSet { + HashSet::new() + } } #[pyclass] @@ -95,6 +110,11 @@ impl DexInt { pub fn __repr__(&self) -> String { format!("DexInt({})", self.0) } + + /// Return all strings references in the value. + pub fn get_all_strings(&self) -> HashSet { + HashSet::new() + } } #[pyclass] @@ -118,6 +138,11 @@ impl DexLong { pub fn __repr__(&self) -> String { format!("DexLong({})", self.0) } + + /// Return all strings references in the value. + pub fn get_all_strings(&self) -> HashSet { + HashSet::new() + } } #[pyclass] @@ -141,6 +166,11 @@ impl DexFloat { pub fn __repr__(&self) -> String { format!("DexFloat({})", self.0) } + + /// Return all strings references in the value. + pub fn get_all_strings(&self) -> HashSet { + HashSet::new() + } } #[pyclass] @@ -164,6 +194,11 @@ impl DexDouble { pub fn __repr__(&self) -> String { format!("DexDouble({})", self.0) } + + /// Return all strings references in the value. + pub fn get_all_strings(&self) -> HashSet { + HashSet::new() + } } /* DexString is already define in lib.rs, TODO: move the version in lib.rs here @@ -208,6 +243,11 @@ impl DexNull { pub fn __repr__(&self) -> String { "DexNull".into() } + + /// Return all strings references in the value. + pub fn get_all_strings(&self) -> HashSet { + HashSet::new() + } } #[pyclass] @@ -231,6 +271,11 @@ impl DexBoolean { pub fn __repr__(&self) -> String { format!("DexBoolean({})", self.0) } + + /// Return all strings references in the value. + pub fn get_all_strings(&self) -> HashSet { + HashSet::new() + } } #[pyclass] @@ -261,4 +306,13 @@ impl DexArray { pub fn __repr__(&self) -> String { "DexArray(...)".into() } + + /// Return all strings references in the value. + pub fn get_all_strings(&self) -> HashSet { + let mut strings = HashSet::new(); + for val in &self.0 { + strings.extend(val.get_all_strings()); + } + strings + } } diff --git a/androscalpel/src/value.rs b/androscalpel/src/value.rs index aa0e969..9fc9d56 100644 --- a/androscalpel/src/value.rs +++ b/androscalpel/src/value.rs @@ -1,5 +1,7 @@ //! The class identifying dex structure. +use std::collections::HashSet; + use pyo3::exceptions::PyTypeError; use pyo3::prelude::*; @@ -98,6 +100,30 @@ impl DexValue { DexValue::Boolean(val) => val.__str__(), } } + + /// Return all strings references in the value. + pub fn get_all_strings(&self) -> HashSet { + match self { + DexValue::Byte(val) => val.get_all_strings(), + DexValue::Short(val) => val.get_all_strings(), + DexValue::Char(val) => val.get_all_strings(), + DexValue::Int(val) => val.get_all_strings(), + DexValue::Long(val) => val.get_all_strings(), + DexValue::Float(val) => val.get_all_strings(), + DexValue::Double(val) => val.get_all_strings(), + DexValue::MethodType(val) => val.get_all_strings(), + DexValue::MethodHandle(val) => val.get_all_strings(), + DexValue::String(val) => val.get_all_strings(), + DexValue::Type(val) => val.get_all_strings(), + DexValue::Field(val) => val.get_all_strings(), + DexValue::Method(val) => val.get_all_strings(), + DexValue::Enum(val) => val.get_all_strings(), + DexValue::Array(val) => val.get_all_strings(), + DexValue::Annotation(val) => val.get_all_strings(), + DexValue::Null(val) => val.get_all_strings(), + DexValue::Boolean(val) => val.get_all_strings(), + } + } } impl IntoPy for DexValue {