From 2e7f2bd2004ffe4ef130c67fb6cabbbe03388407 Mon Sep 17 00:00:00 2001 From: BinFlip Date: Sun, 3 May 2026 21:13:15 -0700 Subject: [PATCH 1/6] feat: increased lints for better data integrity and avoiding malicious input --- dotscope/src/analysis/callgraph/graph.rs | 4 +- dotscope/src/analysis/callgraph/resolution.rs | 5 +- dotscope/src/analysis/cfg/graph.rs | 12 +- dotscope/src/analysis/cfg/loops.rs | 92 +++-- dotscope/src/analysis/cfg/semantics.rs | 26 +- dotscope/src/analysis/dataflow/liveness.rs | 17 +- dotscope/src/analysis/dataflow/reaching.rs | 4 +- dotscope/src/analysis/dataflow/sccp.rs | 4 +- dotscope/src/analysis/dataflow/solver.rs | 87 ++-- dotscope/src/analysis/defuse.rs | 12 +- dotscope/src/analysis/ssa/block.rs | 62 ++- dotscope/src/analysis/ssa/builder.rs | 30 +- dotscope/src/analysis/ssa/cfg.rs | 16 +- dotscope/src/analysis/ssa/consts.rs | 8 +- dotscope/src/analysis/ssa/converter.rs | 104 +++-- dotscope/src/analysis/ssa/decompose.rs | 71 ++-- dotscope/src/analysis/ssa/evaluator.rs | 28 +- dotscope/src/analysis/ssa/exception.rs | 2 +- .../src/analysis/ssa/function/canonical.rs | 18 +- dotscope/src/analysis/ssa/function/mod.rs | 19 +- dotscope/src/analysis/ssa/function/queries.rs | 45 ++- dotscope/src/analysis/ssa/function/rebuild.rs | 170 +++++--- .../src/analysis/ssa/function/transforms.rs | 103 +++-- dotscope/src/analysis/ssa/liveness.rs | 9 +- dotscope/src/analysis/ssa/memory.rs | 40 +- dotscope/src/analysis/ssa/ops.rs | 6 +- dotscope/src/analysis/ssa/patterns.rs | 12 +- dotscope/src/analysis/ssa/phis.rs | 6 +- dotscope/src/analysis/ssa/stack.rs | 24 +- dotscope/src/analysis/ssa/symbolic/expr.rs | 6 +- dotscope/src/analysis/ssa/types.rs | 2 +- dotscope/src/analysis/ssa/value.rs | 65 +-- dotscope/src/analysis/ssa/variable.rs | 2 +- dotscope/src/analysis/ssa/verifier.rs | 7 +- dotscope/src/analysis/taint.rs | 4 +- dotscope/src/analysis/x86/cfg.rs | 81 ++-- dotscope/src/analysis/x86/decoder.rs | 92 +++-- dotscope/src/analysis/x86/ssa.rs | 73 ++-- dotscope/src/analysis/x86/types.rs | 2 +- dotscope/src/assembly/builder.rs | 65 ++- dotscope/src/assembly/decoder.rs | 278 ++++++++----- dotscope/src/assembly/encoder.rs | 147 +++++-- dotscope/src/assembly/instruction.rs | 10 +- dotscope/src/assembly/instructions.rs | 57 +-- dotscope/src/cilassembly/builders/method.rs | 5 +- .../src/cilassembly/builders/method_body.rs | 8 +- dotscope/src/cilassembly/changes/assembly.rs | 8 +- dotscope/src/cilassembly/changes/changeref.rs | 37 +- dotscope/src/cilassembly/changes/heap.rs | 15 +- .../src/cilassembly/cleanup/compaction.rs | 21 +- dotscope/src/cilassembly/cleanup/executor.rs | 8 +- dotscope/src/cilassembly/cleanup/orphans.rs | 32 +- .../src/cilassembly/cleanup/references.rs | 20 +- dotscope/src/cilassembly/cleanup/request.rs | 17 +- dotscope/src/cilassembly/cleanup/stats.rs | 27 +- dotscope/src/cilassembly/cleanup/utils.rs | 15 +- dotscope/src/cilassembly/mod.rs | 16 +- dotscope/src/cilassembly/modifications.rs | 43 +- dotscope/src/cilassembly/writer/context.rs | 26 +- dotscope/src/cilassembly/writer/fields.rs | 15 +- dotscope/src/cilassembly/writer/fixups.rs | 317 ++++++++++----- dotscope/src/cilassembly/writer/generator.rs | 346 +++++++++++----- .../src/cilassembly/writer/heaps/rowpatch.rs | 2 +- .../src/cilassembly/writer/heaps/streaming.rs | 212 +++++++--- dotscope/src/cilassembly/writer/methods.rs | 32 +- dotscope/src/cilassembly/writer/output.rs | 112 +++++- .../src/cilassembly/writer/relocations.rs | 49 ++- dotscope/src/cilassembly/writer/remapper.rs | 36 +- dotscope/src/cilassembly/writer/signatures.rs | 6 +- dotscope/src/cilassembly/writer/sizes.rs | 24 +- dotscope/src/compiler/codegen/coalescing.rs | 91 +++-- dotscope/src/compiler/codegen/mod.rs | 156 +++++--- dotscope/src/compiler/events.rs | 10 +- dotscope/src/compiler/passes/algebraic.rs | 4 +- dotscope/src/compiler/passes/blockmerge.rs | 34 +- dotscope/src/compiler/passes/constants/mod.rs | 210 +++++----- dotscope/src/compiler/passes/controlflow.rs | 40 +- dotscope/src/compiler/passes/copying.rs | 12 +- dotscope/src/compiler/passes/deadcode.rs | 67 ++-- dotscope/src/compiler/passes/gvn.rs | 4 +- dotscope/src/compiler/passes/inlining.rs | 88 +++-- dotscope/src/compiler/passes/licm.rs | 9 +- dotscope/src/compiler/passes/loopcanon.rs | 18 +- dotscope/src/compiler/passes/predicates.rs | 51 +-- dotscope/src/compiler/passes/proxy.rs | 21 +- dotscope/src/compiler/passes/ranges.rs | 20 +- dotscope/src/compiler/passes/reassociate.rs | 26 +- dotscope/src/compiler/passes/strength.rs | 29 +- dotscope/src/compiler/scheduler.rs | 70 +++- dotscope/src/compiler/summary.rs | 18 +- dotscope/src/deobfuscation/cleanup.rs | 48 +-- dotscope/src/deobfuscation/engine/analysis.rs | 6 +- dotscope/src/deobfuscation/engine/codegen.rs | 4 +- .../src/deobfuscation/engine/detection.rs | 8 +- dotscope/src/deobfuscation/engine/pipeline.rs | 25 +- .../src/deobfuscation/passes/antidebug.rs | 6 +- .../deobfuscation/passes/bitmono/strings.rs | 10 +- .../deobfuscation/passes/bitmono/unmanaged.rs | 15 +- .../src/deobfuscation/passes/decryption.rs | 41 +- .../src/deobfuscation/passes/delegates.rs | 11 +- .../deobfuscation/passes/jiejienet/arrays.rs | 29 +- .../deobfuscation/passes/jiejienet/typeofs.rs | 21 +- dotscope/src/deobfuscation/passes/native.rs | 20 +- .../passes/netreactor/resolver.rs | 5 +- .../passes/netreactor/rewrite.rs | 4 +- .../src/deobfuscation/passes/neutralize.rs | 6 +- .../src/deobfuscation/passes/opaquefields.rs | 7 +- .../src/deobfuscation/passes/reflection.rs | 71 ++-- .../passes/unflattening/detection.rs | 14 +- .../passes/unflattening/dispatcher.rs | 16 +- .../passes/unflattening/reconstruction.rs | 61 ++- .../passes/unflattening/tracer/context.rs | 32 +- .../passes/unflattening/tracer/engine.rs | 105 +++-- .../passes/unflattening/tracer/helpers.rs | 27 +- dotscope/src/deobfuscation/processcell.rs | 6 +- dotscope/src/deobfuscation/renamer/cascade.rs | 28 +- .../src/deobfuscation/renamer/features.rs | 25 +- dotscope/src/deobfuscation/renamer/mod.rs | 8 +- dotscope/src/deobfuscation/renamer/phases.rs | 61 +-- dotscope/src/deobfuscation/renamer/prompt.rs | 15 +- .../deobfuscation/renamer/providers/local.rs | 15 +- .../deobfuscation/renamer/providers/simple.rs | 20 +- .../src/deobfuscation/renamer/validate.rs | 18 +- dotscope/src/deobfuscation/statemachine.rs | 62 +-- .../deobfuscation/techniques/bitmono/calli.rs | 19 +- .../deobfuscation/techniques/bitmono/hooks.rs | 81 +++- .../deobfuscation/techniques/bitmono/junk.rs | 14 +- .../techniques/bitmono/renamer.rs | 4 +- .../techniques/bitmono/strings.rs | 10 +- .../techniques/bitmono/unmanaged.rs | 20 +- .../techniques/confuserex/constants.rs | 11 +- .../techniques/confuserex/helpers.rs | 25 +- .../techniques/confuserex/marker.rs | 10 +- .../techniques/confuserex/metadata.rs | 107 +++-- .../techniques/confuserex/natives.rs | 2 +- .../techniques/confuserex/proxy.rs | 37 +- .../techniques/confuserex/resources.rs | 6 +- .../techniques/confuserex/statemachine.rs | 67 ++-- .../techniques/confuserex/tamper.rs | 14 +- .../src/deobfuscation/techniques/detection.rs | 6 +- .../techniques/generic/constants.rs | 9 +- .../techniques/generic/delegates.rs | 34 +- .../techniques/generic/flattening.rs | 2 +- .../techniques/generic/handlers.rs | 4 +- .../techniques/generic/metadata.rs | 27 +- .../techniques/generic/strings.rs | 14 +- .../techniques/jiejienet/arrays.rs | 102 +++-- .../techniques/jiejienet/constants.rs | 4 +- .../techniques/jiejienet/resources.rs | 54 ++- .../techniques/jiejienet/strings.rs | 16 +- .../techniques/jiejienet/typeofs.rs | 5 +- .../techniques/netreactor/helpers.rs | 45 ++- .../techniques/netreactor/necrobit.rs | 187 ++++++--- .../techniques/netreactor/resources.rs | 52 ++- .../techniques/obfuscar/strings.rs | 7 +- .../src/deobfuscation/techniques/registry.rs | 33 +- dotscope/src/deobfuscation/template.rs | 24 +- dotscope/src/deobfuscation/utils.rs | 30 +- dotscope/src/deobfuscation/workqueue.rs | 11 +- dotscope/src/emulation/capture/context.rs | 20 +- dotscope/src/emulation/engine/callresolver.rs | 66 ++-- dotscope/src/emulation/engine/controller.rs | 51 +-- dotscope/src/emulation/engine/exhandler.rs | 2 +- .../emulation/engine/interpreter/handlers.rs | 165 ++++---- dotscope/src/emulation/engine/pointer.rs | 6 +- dotscope/src/emulation/engine/stats.rs | 2 +- .../src/emulation/engine/typeops/newobj.rs | 32 +- dotscope/src/emulation/exception/types.rs | 4 +- dotscope/src/emulation/filesystem.rs | 9 +- dotscope/src/emulation/loader/data.rs | 6 +- dotscope/src/emulation/loader/peloader.rs | 176 ++++++--- dotscope/src/emulation/memory/addressspace.rs | 142 ++++--- dotscope/src/emulation/memory/arguments.rs | 61 ++- dotscope/src/emulation/memory/heap/arrays.rs | 20 +- .../src/emulation/memory/heap/collections.rs | 4 +- dotscope/src/emulation/memory/heap/mod.rs | 56 ++- dotscope/src/emulation/memory/heap/streams.rs | 80 ++-- dotscope/src/emulation/memory/locals.rs | 56 ++- dotscope/src/emulation/memory/page.rs | 44 ++- dotscope/src/emulation/memory/region.rs | 76 ++-- dotscope/src/emulation/memory/stack.rs | 14 +- dotscope/src/emulation/memory/unmanaged.rs | 80 ++-- dotscope/src/emulation/process/builder.rs | 33 +- dotscope/src/emulation/process/execution.rs | 5 +- .../src/emulation/runtime/bcl/appdomain.rs | 14 +- .../emulation/runtime/bcl/collections/list.rs | 12 +- .../emulation/runtime/bcl/crypto/hashing.rs | 34 +- .../src/emulation/runtime/bcl/crypto/hmac.rs | 5 +- .../src/emulation/runtime/bcl/crypto/mod.rs | 7 +- .../emulation/runtime/bcl/crypto/symmetric.rs | 54 ++- .../emulation/runtime/bcl/interop/marshal.rs | 133 ++++--- .../emulation/runtime/bcl/io/binaryreader.rs | 67 ++-- .../emulation/runtime/bcl/io/binarywriter.rs | 16 +- .../emulation/runtime/bcl/io/compression.rs | 8 +- .../emulation/runtime/bcl/io/filestream.rs | 68 ++-- .../src/emulation/runtime/bcl/io/stream.rs | 22 +- .../runtime/bcl/reflection/helpers.rs | 2 +- .../runtime/bcl/reflection/members.rs | 4 +- .../runtime/bcl/reflection/methods.rs | 8 +- .../emulation/runtime/bcl/reflection/types.rs | 59 ++- dotscope/src/emulation/runtime/bcl/runtime.rs | 119 +++--- .../src/emulation/runtime/bcl/system/array.rs | 158 +++++--- .../runtime/bcl/system/bitconverter.rs | 87 ++-- .../emulation/runtime/bcl/system/convert.rs | 100 ++--- .../emulation/runtime/bcl/system/datetime.rs | 125 ++++-- .../runtime/bcl/system/environment.rs | 14 +- .../src/emulation/runtime/bcl/system/math.rs | 373 +++++++++--------- .../emulation/runtime/bcl/system/string.rs | 30 +- .../emulation/runtime/bcl/text/encoding.rs | 157 +++----- .../runtime/bcl/text/stringbuilder.rs | 8 +- dotscope/src/emulation/runtime/native.rs | 82 ++-- dotscope/src/emulation/runtime/state.rs | 9 +- dotscope/src/emulation/thread/scheduler.rs | 12 +- dotscope/src/emulation/thread/state.rs | 19 +- dotscope/src/emulation/thread/sync.rs | 10 +- dotscope/src/emulation/tracer/calltree.rs | 12 +- dotscope/src/emulation/value/emvalue.rs | 25 +- dotscope/src/emulation/value/ops/binary.rs | 58 ++- dotscope/src/error.rs | 2 +- dotscope/src/file/mod.rs | 88 +++-- dotscope/src/file/parser.rs | 136 ++++--- dotscope/src/file/pe.rs | 255 +++++------- dotscope/src/file/repair.rs | 280 +++++++++---- dotscope/src/formatting/assembly.rs | 3 +- dotscope/src/formatting/exceptions.rs | 6 +- dotscope/src/formatting/helpers.rs | 24 +- dotscope/src/formatting/method_body.rs | 20 +- dotscope/src/formatting/tokens.rs | 2 +- dotscope/src/lib.rs | 21 +- dotscope/src/metadata/cilassemblyview.rs | 15 +- .../src/metadata/customattributes/parser.rs | 61 +-- .../metadata/customdebuginformation/parser.rs | 6 +- dotscope/src/metadata/dependencies/graph.rs | 15 +- dotscope/src/metadata/exports/builder.rs | 12 +- dotscope/src/metadata/exports/container.rs | 5 +- dotscope/src/metadata/exports/native.rs | 168 ++++++-- dotscope/src/metadata/identity/assembly.rs | 16 +- .../src/metadata/identity/cryptographic.rs | 5 +- dotscope/src/metadata/imports/cil.rs | 10 +- dotscope/src/metadata/imports/container.rs | 4 +- dotscope/src/metadata/imports/native.rs | 309 ++++++++++----- dotscope/src/metadata/loader/graph.rs | 9 +- dotscope/src/metadata/loader/mod.rs | 20 +- dotscope/src/metadata/marshalling/encoder.rs | 4 +- dotscope/src/metadata/marshalling/parser.rs | 6 +- dotscope/src/metadata/method/body.rs | 136 +++++-- dotscope/src/metadata/method/exceptions.rs | 4 +- dotscope/src/metadata/method/iter.rs | 20 +- dotscope/src/metadata/method/mod.rs | 205 +++++----- dotscope/src/metadata/resources/encoder.rs | 86 +++- dotscope/src/metadata/resources/parser.rs | 125 ++++-- dotscope/src/metadata/resources/types.rs | 72 ++-- dotscope/src/metadata/root.rs | 113 ++++-- dotscope/src/metadata/security/encoder.rs | 24 +- .../src/metadata/security/permissionset.rs | 17 +- dotscope/src/metadata/sequencepoints.rs | 22 +- dotscope/src/metadata/signatures/parser.rs | 25 +- dotscope/src/metadata/streams/blob.rs | 25 +- dotscope/src/metadata/streams/guid.rs | 18 +- dotscope/src/metadata/streams/streamheader.rs | 41 +- dotscope/src/metadata/streams/strings.rs | 7 +- dotscope/src/metadata/streams/tablesheader.rs | 31 +- dotscope/src/metadata/streams/userstrings.rs | 44 ++- dotscope/src/metadata/tablefields.rs | 50 +-- dotscope/src/metadata/tables/assembly/raw.rs | 21 +- .../src/metadata/tables/assembly/reader.rs | 2 +- .../src/metadata/tables/assemblyos/reader.rs | 2 +- .../metadata/tables/assemblyprocessor/raw.rs | 2 +- .../tables/assemblyprocessor/reader.rs | 2 +- .../tables/assemblyref/assemblyrefhash.rs | 2 +- .../src/metadata/tables/assemblyref/raw.rs | 21 +- .../src/metadata/tables/assemblyref/reader.rs | 2 +- .../src/metadata/tables/assemblyrefos/raw.rs | 13 +- .../metadata/tables/assemblyrefos/reader.rs | 2 +- .../tables/assemblyrefprocessor/raw.rs | 5 +- .../tables/assemblyrefprocessor/reader.rs | 2 +- .../metadata/tables/classlayout/builder.rs | 2 +- .../src/metadata/tables/classlayout/raw.rs | 7 +- .../src/metadata/tables/classlayout/reader.rs | 2 +- dotscope/src/metadata/tables/constant/raw.rs | 11 +- .../src/metadata/tables/constant/reader.rs | 4 +- .../metadata/tables/customattribute/raw.rs | 7 +- .../metadata/tables/customattribute/reader.rs | 2 +- .../tables/customdebuginformation/raw.rs | 9 +- .../tables/customdebuginformation/reader.rs | 2 +- .../src/metadata/tables/declsecurity/raw.rs | 7 +- .../metadata/tables/declsecurity/reader.rs | 2 +- dotscope/src/metadata/tables/document/raw.rs | 9 +- .../src/metadata/tables/document/reader.rs | 2 +- dotscope/src/metadata/tables/enclog/raw.rs | 4 +- dotscope/src/metadata/tables/enclog/reader.rs | 2 +- dotscope/src/metadata/tables/encmap/raw.rs | 3 +- dotscope/src/metadata/tables/encmap/reader.rs | 2 +- dotscope/src/metadata/tables/event/raw.rs | 7 +- dotscope/src/metadata/tables/event/reader.rs | 2 +- dotscope/src/metadata/tables/eventmap/raw.rs | 16 +- .../src/metadata/tables/eventmap/reader.rs | 2 +- dotscope/src/metadata/tables/eventptr/raw.rs | 1 + .../src/metadata/tables/eventptr/reader.rs | 2 +- .../src/metadata/tables/exportedtype/raw.rs | 11 +- .../metadata/tables/exportedtype/reader.rs | 2 +- dotscope/src/metadata/tables/field/builder.rs | 2 +- dotscope/src/metadata/tables/field/raw.rs | 9 +- dotscope/src/metadata/tables/field/reader.rs | 2 +- .../src/metadata/tables/fieldlayout/raw.rs | 5 +- .../src/metadata/tables/fieldlayout/reader.rs | 2 +- .../src/metadata/tables/fieldmarshal/raw.rs | 5 +- .../metadata/tables/fieldmarshal/reader.rs | 2 +- dotscope/src/metadata/tables/fieldptr/raw.rs | 1 + .../src/metadata/tables/fieldptr/reader.rs | 2 +- dotscope/src/metadata/tables/fieldrva/raw.rs | 5 +- .../src/metadata/tables/fieldrva/reader.rs | 2 +- dotscope/src/metadata/tables/file/raw.rs | 7 +- dotscope/src/metadata/tables/file/reader.rs | 2 +- .../src/metadata/tables/genericparam/raw.rs | 9 +- .../metadata/tables/genericparam/reader.rs | 2 +- .../tables/genericparamconstraint/raw.rs | 5 +- .../tables/genericparamconstraint/reader.rs | 2 +- dotscope/src/metadata/tables/implmap/raw.rs | 9 +- .../src/metadata/tables/implmap/reader.rs | 2 +- .../src/metadata/tables/importscope/raw.rs | 5 +- .../src/metadata/tables/importscope/reader.rs | 2 +- .../src/metadata/tables/interfaceimpl/raw.rs | 5 +- .../metadata/tables/interfaceimpl/reader.rs | 2 +- .../src/metadata/tables/localconstant/raw.rs | 5 +- .../metadata/tables/localconstant/reader.rs | 2 +- .../src/metadata/tables/localscope/owned.rs | 2 +- .../src/metadata/tables/localscope/raw.rs | 106 +++-- .../src/metadata/tables/localscope/reader.rs | 2 +- .../src/metadata/tables/localvariable/raw.rs | 7 +- .../metadata/tables/localvariable/reader.rs | 2 +- .../metadata/tables/manifestresource/raw.rs | 64 ++- .../tables/manifestresource/reader.rs | 2 +- dotscope/src/metadata/tables/memberref/raw.rs | 17 +- .../src/metadata/tables/memberref/reader.rs | 2 +- .../tables/methoddebuginformation/raw.rs | 5 +- .../tables/methoddebuginformation/reader.rs | 2 +- dotscope/src/metadata/tables/methoddef/raw.rs | 24 +- .../src/metadata/tables/methoddef/reader.rs | 2 +- .../src/metadata/tables/methodimpl/raw.rs | 7 +- .../src/metadata/tables/methodimpl/reader.rs | 2 +- dotscope/src/metadata/tables/methodptr/raw.rs | 1 + .../src/metadata/tables/methodptr/reader.rs | 2 +- .../metadata/tables/methodsemantics/mod.rs | 2 +- .../metadata/tables/methodsemantics/raw.rs | 9 +- .../metadata/tables/methodsemantics/reader.rs | 2 +- .../src/metadata/tables/methodspec/builder.rs | 7 +- .../src/metadata/tables/methodspec/mod.rs | 2 +- .../src/metadata/tables/methodspec/raw.rs | 5 +- .../src/metadata/tables/methodspec/reader.rs | 2 +- .../src/metadata/tables/module/builder.rs | 2 +- dotscope/src/metadata/tables/module/mod.rs | 2 +- dotscope/src/metadata/tables/module/raw.rs | 11 +- dotscope/src/metadata/tables/moduleref/raw.rs | 1 + .../src/metadata/tables/moduleref/reader.rs | 2 +- .../src/metadata/tables/nestedclass/raw.rs | 5 +- .../src/metadata/tables/nestedclass/reader.rs | 2 +- dotscope/src/metadata/tables/param/raw.rs | 7 +- dotscope/src/metadata/tables/param/reader.rs | 2 +- dotscope/src/metadata/tables/paramptr/raw.rs | 1 + .../src/metadata/tables/paramptr/reader.rs | 2 +- dotscope/src/metadata/tables/property/raw.rs | 7 +- .../src/metadata/tables/property/reader.rs | 2 +- .../src/metadata/tables/propertymap/raw.rs | 18 +- .../src/metadata/tables/propertymap/reader.rs | 2 +- .../src/metadata/tables/propertyptr/raw.rs | 1 + .../src/metadata/tables/propertyptr/reader.rs | 2 +- .../src/metadata/tables/standalonesig/raw.rs | 5 +- .../metadata/tables/standalonesig/reader.rs | 2 +- .../metadata/tables/statemachinemethod/raw.rs | 5 +- .../tables/statemachinemethod/reader.rs | 2 +- dotscope/src/metadata/tables/typedef/raw.rs | 65 ++- .../src/metadata/tables/typedef/reader.rs | 2 +- dotscope/src/metadata/tables/typeref/raw.rs | 7 +- .../src/metadata/tables/typeref/reader.rs | 2 +- .../tables/types/common/codedindex.rs | 6 +- .../src/metadata/tables/types/common/info.rs | 61 ++- .../src/metadata/tables/types/read/iter.rs | 14 +- .../src/metadata/tables/types/read/table.rs | 6 +- dotscope/src/metadata/tables/typespec/raw.rs | 1 + .../src/metadata/tables/typespec/reader.rs | 2 +- dotscope/src/metadata/typesystem/base.rs | 23 +- dotscope/src/metadata/typesystem/builder.rs | 5 +- dotscope/src/metadata/typesystem/encoder.rs | 22 +- dotscope/src/metadata/typesystem/mod.rs | 4 +- .../src/metadata/typesystem/primitives.rs | 33 +- dotscope/src/metadata/typesystem/registry.rs | 3 +- dotscope/src/metadata/typesystem/resolver.rs | 25 +- dotscope/src/metadata/validation/result.rs | 10 +- .../metadata/validation/shared/references.rs | 6 +- .../src/metadata/validation/shared/schema.rs | 16 +- .../validators/owned/metadata/attribute.rs | 22 +- .../owned/relationships/ownership.rs | 2 +- .../validators/owned/system/assembly.rs | 16 +- .../validators/owned/system/security.rs | 24 +- .../validators/owned/types/circularity.rs | 2 +- .../validators/owned/types/inheritance.rs | 10 +- .../validators/raw/constraints/generic.rs | 7 +- .../validators/raw/constraints/layout.rs | 12 +- .../validators/raw/modification/integrity.rs | 55 ++- .../validators/raw/modification/operation.rs | 14 +- .../validators/raw/structure/heap.rs | 4 +- .../validators/raw/structure/signature.rs | 24 +- .../validators/raw/structure/table.rs | 5 +- dotscope/src/metadata/vtfixup.rs | 78 ++-- dotscope/src/project/loader.rs | 4 +- dotscope/src/project/result.rs | 4 +- dotscope/src/test/analysis/runner.rs | 10 +- dotscope/src/utils/alignment.rs | 5 +- dotscope/src/utils/base64.rs | 51 ++- dotscope/src/utils/bitset.rs | 33 +- dotscope/src/utils/crypto.rs | 72 +++- dotscope/src/utils/decompress.rs | 33 +- dotscope/src/utils/enums.rs | 2 +- dotscope/src/utils/graph/algorithms/cycles.rs | 34 +- .../src/utils/graph/algorithms/dominators.rs | 191 +++++---- dotscope/src/utils/graph/algorithms/scc.rs | 52 ++- .../src/utils/graph/algorithms/topological.rs | 10 +- .../src/utils/graph/algorithms/traversal.rs | 33 +- dotscope/src/utils/graph/directed.rs | 72 ++-- dotscope/src/utils/io.rs | 117 +++--- dotscope/src/utils/lebytes.rs | 10 +- dotscope/src/utils/synchronization.rs | 5 +- dotscope/src/utils/visitedmap.rs | 63 +-- 424 files changed, 8663 insertions(+), 5364 deletions(-) diff --git a/dotscope/src/analysis/callgraph/graph.rs b/dotscope/src/analysis/callgraph/graph.rs index 662db944..527cd0fc 100644 --- a/dotscope/src/analysis/callgraph/graph.rs +++ b/dotscope/src/analysis/callgraph/graph.rs @@ -118,7 +118,7 @@ impl CallGraph { let types = assembly.types(); let mut graph: DirectedGraph = - DirectedGraph::with_capacity(method_count, method_count * 4); + DirectedGraph::with_capacity(method_count, method_count.saturating_mul(4)); let mut token_to_node: HashMap = HashMap::with_capacity(method_count); // First pass: add all internal methods as nodes @@ -882,7 +882,7 @@ impl CallGraph { .count(); let internal_methods = self.nodes().filter(|n| !n.is_external_ref).count(); - let external_refs = self.graph.node_count() - internal_methods; + let external_refs = self.graph.node_count().saturating_sub(internal_methods); CallGraphStats { method_count: internal_methods, diff --git a/dotscope/src/analysis/callgraph/resolution.rs b/dotscope/src/analysis/callgraph/resolution.rs index f81e7204..3e95fa42 100644 --- a/dotscope/src/analysis/callgraph/resolution.rs +++ b/dotscope/src/analysis/callgraph/resolution.rs @@ -360,7 +360,10 @@ impl CallResolver { virtual_methods: self.virtual_dispatch_table.len(), polymorphic_methods, max_targets, - total_types: self.type_subtypes.len() + self.sealed_types.len(), + total_types: self + .type_subtypes + .len() + .saturating_add(self.sealed_types.len()), interface_types: self.interfaces.len(), sealed_types: self.sealed_types.len(), } diff --git a/dotscope/src/analysis/cfg/graph.rs b/dotscope/src/analysis/cfg/graph.rs index 2a7a9297..cc6ad6f3 100644 --- a/dotscope/src/analysis/cfg/graph.rs +++ b/dotscope/src/analysis/cfg/graph.rs @@ -112,7 +112,7 @@ impl ControlFlowGraph<'static> { let block_count = blocks.len(); let mut graph: DirectedGraph = - DirectedGraph::with_capacity(block_count, block_count * 2); + DirectedGraph::with_capacity(block_count, block_count.saturating_mul(2)); // First pass: add all blocks as nodes let node_ids: Vec = blocks @@ -144,7 +144,9 @@ impl ControlFlowGraph<'static> { ))); } - let target_node = node_ids[succ_idx]; + let target_node = *node_ids + .get(succ_idx) + .ok_or_else(|| GraphError(format!("missing node id {succ_idx}")))?; let edge_kind = Self::classify_edge(flow_type, idx, successors.len()); let edge = CfgEdge::new(succ_idx, edge_kind); @@ -153,7 +155,9 @@ impl ControlFlowGraph<'static> { } // Identify entry and exit blocks - let entry = node_ids[0]; // Method entry is always block 0 + let entry = *node_ids + .first() + .ok_or_else(|| GraphError("method has no entry block".into()))?; // Method entry is always block 0 let mut exits: Vec = Vec::new(); for &node_id in &node_ids { let block = graph.node(node_id).ok_or_else(|| { @@ -314,7 +318,7 @@ impl<'a> ControlFlowGraph<'a> { } Some(FlowType::Switch) => { // For switches: last successor is default, others are cases - if successor_index == successor_count - 1 && successor_count > 1 { + if successor_count > 1 && successor_index == successor_count.saturating_sub(1) { CfgEdgeKind::Switch { case_value: None } } else { // Switch case indices are bounded by the number of successors, diff --git a/dotscope/src/analysis/cfg/loops.rs b/dotscope/src/analysis/cfg/loops.rs index e8287292..116d3fea 100644 --- a/dotscope/src/analysis/cfg/loops.rs +++ b/dotscope/src/analysis/cfg/loops.rs @@ -216,7 +216,7 @@ impl LoopInfo { #[must_use] pub fn single_latch(&self) -> Option { if self.latches.len() == 1 { - Some(self.latches[0]) + self.latches.first().copied() } else { None } @@ -360,8 +360,10 @@ impl LoopInfo { // Classic induction variable: 1 init from outside, 1+ updates from inside if outside_ops.len() == 1 && !inside_ops.is_empty() { - let init_op = outside_ops[0]; - let update_op = inside_ops[0]; // Take first inside operand + let (Some(init_op), Some(update_op)) = (outside_ops.first(), inside_ops.first()) + else { + continue; + }; // Try to determine update kind by analyzing the defining instruction let (update_kind, stride) = @@ -464,15 +466,21 @@ impl LoopForest { // Update block-to-loop mapping for all blocks in this loop for block_idx in loop_info.body.iter() { - if block_idx < self.block_to_loop.len() { - // Only update if this is a more deeply nested loop - if let Some(existing_idx) = self.block_to_loop[block_idx] { - if self.loops[existing_idx].depth < loop_info.depth { - self.block_to_loop[block_idx] = Some(loop_idx); + let Some(slot) = self.block_to_loop.get_mut(block_idx) else { + continue; + }; + // Only update if this is a more deeply nested loop + match *slot { + Some(existing_idx) => { + if self + .loops + .get(existing_idx) + .is_some_and(|l| l.depth < loop_info.depth) + { + *slot = Some(loop_idx); } - } else { - self.block_to_loop[block_idx] = Some(loop_idx); } + None => *slot = Some(loop_idx), } } @@ -501,11 +509,8 @@ impl LoopForest { #[must_use] pub fn innermost_loop(&self, block: NodeId) -> Option<&LoopInfo> { let block_idx = block.index(); - if block_idx < self.block_to_loop.len() { - self.block_to_loop[block_idx].map(|idx| &self.loops[idx]) - } else { - None - } + let loop_idx = (*self.block_to_loop.get(block_idx)?)?; + self.loops.get(loop_idx) } /// Returns the loop with the given header. @@ -517,7 +522,8 @@ impl LoopForest { /// Returns the loop depth for a block (0 if not in any loop). #[must_use] pub fn loop_depth(&self, block: NodeId) -> usize { - self.innermost_loop(block).map_or(0, |l| l.depth + 1) + self.innermost_loop(block) + .map_or(0, |l| l.depth.saturating_add(1)) } /// Returns true if a block is in any loop. @@ -720,7 +726,7 @@ where // Preheader exists only if there's exactly one non-loop predecessor loop_info.preheader = if non_loop_preds.len() == 1 { - Some(non_loop_preds[0]) + non_loop_preds.first().copied() } else { None }; @@ -804,43 +810,69 @@ fn compute_nesting(loops: &mut [LoopInfo]) { // For each loop, find its parent (smallest enclosing loop) for i in 0..n { - let header = loops[i].header; + let Some(header) = loops.get(i).map(|l| l.header) else { + continue; + }; // Find all loops that contain this loop's header (except itself) let mut candidates: Vec = (0..n) - .filter(|&j| j != i && loops[j].body.contains(header.index())) + .filter(|&j| { + j != i + && loops + .get(j) + .is_some_and(|l| l.body.contains(header.index())) + }) .collect(); // Parent is the smallest containing loop if !candidates.is_empty() { - candidates.sort_by_key(|&j| loops[j].size()); - let parent_idx = candidates[0]; - loops[i].parent = Some(loops[parent_idx].header); + candidates.sort_by_key(|&j| loops.get(j).map_or(usize::MAX, LoopInfo::size)); + let parent_idx = match candidates.first().copied() { + Some(p) => p, + None => continue, + }; + let parent_header = match loops.get(parent_idx).map(|l| l.header) { + Some(h) => h, + None => continue, + }; + if let Some(loop_i) = loops.get_mut(i) { + loop_i.parent = Some(parent_header); + } } } // Compute children from parent relationships for i in 0..n { - if let Some(parent_header) = loops[i].parent { - if let Some(&parent_idx) = header_to_idx.get(&parent_header) { - loops[parent_idx].children.push(loops[i].header); + let parent_opt = loops.get(i).and_then(|l| l.parent); + let Some(parent_header) = parent_opt else { + continue; + }; + let header_i = match loops.get(i).map(|l| l.header) { + Some(h) => h, + None => continue, + }; + if let Some(&parent_idx) = header_to_idx.get(&parent_header) { + if let Some(parent) = loops.get_mut(parent_idx) { + parent.children.push(header_i); } } } // Compute depths from parent chain for i in 0..n { - let mut depth = 0; - let mut current = loops[i].parent; + let mut depth: usize = 0; + let mut current = loops.get(i).and_then(|l| l.parent); while let Some(parent_header) = current { - depth += 1; + depth = depth.saturating_add(1); if let Some(&parent_idx) = header_to_idx.get(&parent_header) { - current = loops[parent_idx].parent; + current = loops.get(parent_idx).and_then(|l| l.parent); } else { break; } } - loops[i].depth = depth; + if let Some(l) = loops.get_mut(i) { + l.depth = depth; + } } } diff --git a/dotscope/src/analysis/cfg/semantics.rs b/dotscope/src/analysis/cfg/semantics.rs index af0632a4..17efb301 100644 --- a/dotscope/src/analysis/cfg/semantics.rs +++ b/dotscope/src/analysis/cfg/semantics.rs @@ -248,7 +248,11 @@ impl<'a> SemanticAnalyzer<'a> { let semantics = self.compute_block_semantics(block_idx); self.block_cache.insert(block_idx, semantics); } - &self.block_cache[&block_idx] + // Just inserted above if missing; fall back to inserting a default if for any + // reason the entry is gone (cannot happen with current logic, but avoids panics). + self.block_cache + .entry(block_idx) + .or_insert_with(|| BlockSemantics::new(block_idx)) } /// Computes semantic information for a block. @@ -267,11 +271,11 @@ impl<'a> SemanticAnalyzer<'a> { } // Analyze instructions - let mut const_assignments = 0; - let mut add_sub_ops = 0; - let mut call_count = 0; - let mut store_count = 0; - let mut comparison_count = 0; + let mut const_assignments: usize = 0; + let mut add_sub_ops: usize = 0; + let mut call_count: usize = 0; + let mut store_count: usize = 0; + let mut comparison_count: usize = 0; let mut has_return = false; let mut has_branch = false; let mut has_switch = false; @@ -279,26 +283,26 @@ impl<'a> SemanticAnalyzer<'a> { for instr in block.instructions() { match instr.op() { SsaOp::Const { dest, .. } => { - const_assignments += 1; + const_assignments = const_assignments.saturating_add(1); semantics.initialized_vars.push(*dest); } SsaOp::Add { dest, .. } | SsaOp::Sub { dest, .. } => { - add_sub_ops += 1; + add_sub_ops = add_sub_ops.saturating_add(1); semantics.updated_vars.push(*dest); } SsaOp::Call { .. } | SsaOp::CallVirt { .. } | SsaOp::NewObj { .. } => { - call_count += 1; + call_count = call_count.saturating_add(1); semantics.has_side_effects = true; } SsaOp::StoreField { .. } | SsaOp::StoreStaticField { .. } | SsaOp::StoreElement { .. } | SsaOp::StoreIndirect { .. } => { - store_count += 1; + store_count = store_count.saturating_add(1); semantics.has_side_effects = true; } SsaOp::Clt { .. } | SsaOp::Cgt { .. } | SsaOp::Ceq { .. } => { - comparison_count += 1; + comparison_count = comparison_count.saturating_add(1); semantics.has_comparison = true; } SsaOp::Branch { .. } => { diff --git a/dotscope/src/analysis/dataflow/liveness.rs b/dotscope/src/analysis/dataflow/liveness.rs index ed22bcce..d5d7534d 100644 --- a/dotscope/src/analysis/dataflow/liveness.rs +++ b/dotscope/src/analysis/dataflow/liveness.rs @@ -122,10 +122,11 @@ impl LiveVariables { for phi in block.phi_nodes() { for op in phi.operands() { let pred = op.predecessor(); - if pred < use_sets.len() { - if let Some(var_idx) = ssa.var_index(op.value()) { - if !def_sets[pred].contains(var_idx) { - use_sets[pred].insert(var_idx); + if let Some(var_idx) = ssa.var_index(op.value()) { + let already_def = def_sets.get(pred).is_some_and(|s| s.contains(var_idx)); + if !already_def { + if let Some(slot) = use_sets.get_mut(pred) { + slot.insert(var_idx); } } } @@ -189,10 +190,14 @@ impl DataFlowAnalysis for LiveVariables { let mut result = output.live.clone(); // Remove definitions (OUT - DEF) - result.difference_with(&self.def_sets[block_id]); + if let Some(d) = self.def_sets.get(block_id) { + result.difference_with(d); + } // Add uses (USE ∪ ...) - result.union_with(&self.use_sets[block_id]); + if let Some(u) = self.use_sets.get(block_id) { + result.union_with(u); + } LivenessResult { live: result } } diff --git a/dotscope/src/analysis/dataflow/reaching.rs b/dotscope/src/analysis/dataflow/reaching.rs index 6564f87d..2e3205ba 100644 --- a/dotscope/src/analysis/dataflow/reaching.rs +++ b/dotscope/src/analysis/dataflow/reaching.rs @@ -140,7 +140,9 @@ impl DataFlowAnalysis for ReachingDefinitions { ) -> Self::Lattice { // OUT = GEN ∪ IN (no KILL in SSA since each variable is defined once) let mut result = input.defs.clone(); - result.union_with(&self.gen_sets[block_id]); + if let Some(gen) = self.gen_sets.get(block_id) { + result.union_with(gen); + } ReachingDefsResult { defs: result } } } diff --git a/dotscope/src/analysis/dataflow/sccp.rs b/dotscope/src/analysis/dataflow/sccp.rs index cb7db8a8..0a31566c 100644 --- a/dotscope/src/analysis/dataflow/sccp.rs +++ b/dotscope/src/analysis/dataflow/sccp.rs @@ -362,8 +362,8 @@ impl ConstantPropagation { ScalarValue::Constant(c) => { // Known switch value - use checked conversion to handle negative values if let Some(idx) = c.as_i32().and_then(|i| usize::try_from(i).ok()) { - if idx < targets.len() { - self.add_cfg_edge(block_id, targets[idx]); + if let Some(target) = targets.get(idx) { + self.add_cfg_edge(block_id, *target); } else { self.add_cfg_edge(block_id, *default); } diff --git a/dotscope/src/analysis/dataflow/solver.rs b/dotscope/src/analysis/dataflow/solver.rs index 307e6363..c3349c60 100644 --- a/dotscope/src/analysis/dataflow/solver.rs +++ b/dotscope/src/analysis/dataflow/solver.rs @@ -138,16 +138,16 @@ impl DataFlowSolver { Direction::Forward => { // Entry block gets boundary value let entry = cfg.entry().index(); - if entry < num_blocks { - self.in_states[entry] = boundary; + if let Some(slot) = self.in_states.get_mut(entry) { + *slot = boundary; } } Direction::Backward => { // Exit blocks get boundary value for exit in cfg.exits() { let idx = exit.index(); - if idx < num_blocks { - self.out_states[idx] = boundary.clone(); + if let Some(slot) = self.out_states.get_mut(idx) { + *slot = boundary.clone(); } } } @@ -161,9 +161,9 @@ impl DataFlowSolver { for node in order { let idx = node.index(); - if idx < num_blocks { + if let Some(slot) = self.in_worklist.get_mut(idx) { self.worklist.push_back(idx); - self.in_worklist[idx] = true; + *slot = true; } } } @@ -174,8 +174,10 @@ impl DataFlowSolver { A::Lattice: Clone, { while let Some(block_idx) = self.worklist.pop_front() { - self.in_worklist[block_idx] = false; - self.iterations += 1; + if let Some(slot) = self.in_worklist.get_mut(block_idx) { + *slot = false; + } + self.iterations = self.iterations.saturating_add(1); let changed = match A::DIRECTION { Direction::Forward => self.process_forward(block_idx, ssa, cfg), @@ -203,28 +205,35 @@ impl DataFlowSolver { { // Compute input by meeting all predecessor outputs let node = NodeId::new(block_idx); + let Some(current_in) = self.in_states.get(block_idx).cloned() else { + return false; + }; let mut input = if cfg.predecessors(node).next().is_none() { // Entry block or unreachable - keep current in_state - self.in_states[block_idx].clone() + current_in.clone() } else { // Meet all predecessor outputs let mut result: Option = None; for pred in cfg.predecessors(node) { - let pred_out = &self.out_states[pred.index()]; + let Some(pred_out) = self.out_states.get(pred.index()) else { + continue; + }; result = Some(match result { None => pred_out.clone(), Some(acc) => acc.meet(pred_out), }); } - result.unwrap_or_else(|| self.in_states[block_idx].clone()) + result.unwrap_or_else(|| current_in.clone()) }; // Special case: entry block keeps its boundary value if node == cfg.entry() { - input = self.in_states[block_idx].clone(); + input = current_in.clone(); } - self.in_states[block_idx] = input.clone(); + if let Some(slot) = self.in_states.get_mut(block_idx) { + *slot = input.clone(); + } // Apply transfer function let Some(block) = ssa.block(block_idx) else { @@ -233,8 +242,11 @@ impl DataFlowSolver { let output = self.analysis.transfer(block_idx, block, &input, ssa); // Check if output changed - let changed = output != self.out_states[block_idx]; - self.out_states[block_idx] = output; + let Some(out_slot) = self.out_states.get_mut(block_idx) else { + return false; + }; + let changed = output != *out_slot; + *out_slot = output; changed } @@ -253,28 +265,35 @@ impl DataFlowSolver { { // Compute output by meeting all successor inputs let node = NodeId::new(block_idx); + let Some(current_out) = self.out_states.get(block_idx).cloned() else { + return false; + }; let mut output = if cfg.successors(node).next().is_none() { // Exit block or dead end - keep current out_state - self.out_states[block_idx].clone() + current_out.clone() } else { // Meet all successor inputs let mut result: Option = None; for succ in cfg.successors(node) { - let succ_in = &self.in_states[succ.index()]; + let Some(succ_in) = self.in_states.get(succ.index()) else { + continue; + }; result = Some(match result { None => succ_in.clone(), Some(acc) => acc.meet(succ_in), }); } - result.unwrap_or_else(|| self.out_states[block_idx].clone()) + result.unwrap_or_else(|| current_out.clone()) }; // Special case: exit blocks keep their boundary value if cfg.exits().contains(&node) { - output = self.out_states[block_idx].clone(); + output = current_out.clone(); } - self.out_states[block_idx] = output.clone(); + if let Some(slot) = self.out_states.get_mut(block_idx) { + *slot = output.clone(); + } // Apply transfer function (backward: input = transfer(output)) let Some(block) = ssa.block(block_idx) else { @@ -283,8 +302,11 @@ impl DataFlowSolver { let input = self.analysis.transfer(block_idx, block, &output, ssa); // Check if input changed - let changed = input != self.in_states[block_idx]; - self.in_states[block_idx] = input; + let Some(in_slot) = self.in_states.get_mut(block_idx) else { + return false; + }; + let changed = input != *in_slot; + *in_slot = input; changed } @@ -293,25 +315,26 @@ impl DataFlowSolver { fn add_affected_to_worklist(&mut self, block_idx: usize, cfg: &C) { let node = NodeId::new(block_idx); + let enqueue = |idx: usize, list: &mut Vec, work: &mut VecDeque| { + if let Some(slot) = list.get_mut(idx) { + if !*slot { + work.push_back(idx); + *slot = true; + } + } + }; + match A::DIRECTION { Direction::Forward => { // Forward: successors are affected for succ in cfg.successors(node) { - let idx = succ.index(); - if idx < self.in_worklist.len() && !self.in_worklist[idx] { - self.worklist.push_back(idx); - self.in_worklist[idx] = true; - } + enqueue(succ.index(), &mut self.in_worklist, &mut self.worklist); } } Direction::Backward => { // Backward: predecessors are affected for pred in cfg.predecessors(node) { - let idx = pred.index(); - if idx < self.in_worklist.len() && !self.in_worklist[idx] { - self.worklist.push_back(idx); - self.in_worklist[idx] = true; - } + enqueue(pred.index(), &mut self.in_worklist, &mut self.worklist); } } } diff --git a/dotscope/src/analysis/defuse.rs b/dotscope/src/analysis/defuse.rs index 0893ef2b..1f2eb19a 100644 --- a/dotscope/src/analysis/defuse.rs +++ b/dotscope/src/analysis/defuse.rs @@ -163,7 +163,7 @@ impl DefUseIndex { let max_var_idx = ssa .variables() .iter() - .map(|v| v.id().index() + 1) + .map(|v| v.id().index().saturating_add(1)) .max() .unwrap_or(0); let bitset_capacity = max_var_idx.max(variable_count); @@ -591,9 +591,13 @@ impl DefUseIndex { /// The single use site, or `None` if the variable has zero or multiple uses. #[must_use] pub fn single_use_site(&self, var: SsaVarId) -> Option { - self.uses - .get(&var) - .and_then(|uses| if uses.len() == 1 { Some(uses[0]) } else { None }) + self.uses.get(&var).and_then(|uses| { + if uses.len() == 1 { + uses.first().copied() + } else { + None + } + }) } } diff --git a/dotscope/src/analysis/ssa/block.rs b/dotscope/src/analysis/ssa/block.rs index 0691b0e2..411db9a8 100644 --- a/dotscope/src/analysis/ssa/block.rs +++ b/dotscope/src/analysis/ssa/block.rs @@ -68,8 +68,8 @@ impl std::ops::Add for ReplaceResult { type Output = Self; fn add(self, rhs: Self) -> Self { Self { - replaced: self.replaced + rhs.replaced, - skipped: self.skipped + rhs.skipped, + replaced: self.replaced.saturating_add(rhs.replaced), + skipped: self.skipped.saturating_add(rhs.skipped), } } } @@ -375,8 +375,8 @@ impl SsaBlock { /// need to also replace PHI operands (like eliminating trivial PHIs), use /// `replace_uses_including_phis`. pub fn replace_uses(&mut self, old_var: SsaVarId, new_var: SsaVarId) -> ReplaceResult { - let mut replaced = 0; - let mut skipped = 0; + let mut replaced: usize = 0; + let mut skipped: usize = 0; for instr in &mut self.instructions { let op = instr.op_mut(); @@ -384,12 +384,12 @@ impl SsaBlock { if let Some(dest) = op.dest() { if dest == new_var { if op.uses().contains(&old_var) { - skipped += 1; + skipped = skipped.saturating_add(1); } continue; } } - replaced += op.replace_uses(old_var, new_var); + replaced = replaced.saturating_add(op.replace_uses(old_var, new_var)); } ReplaceResult { replaced, skipped } @@ -442,7 +442,7 @@ impl SsaBlock { for operand in phi.operands_mut() { if operand.value() == old_var { *operand = PhiOperand::new(new_var, operand.predecessor()); - result.replaced += 1; + result.replaced = result.replaced.saturating_add(1); } } } @@ -489,7 +489,7 @@ impl SsaBlock { } // That operation must be an unconditional control transfer - match self.instructions[0].op() { + match self.instructions.first()?.op() { SsaOp::Jump { target } | SsaOp::Leave { target } => Some(*target), _ => None, } @@ -572,7 +572,10 @@ impl SsaBlock { // Only for non-terminator instructions let mut def_index: HashMap = HashMap::new(); for &idx in &non_terminator_indices { - if let Some(dest) = self.instructions[idx].def() { + let Some(instr) = self.instructions.get(idx) else { + continue; + }; + if let Some(dest) = instr.def() { def_index.insert(dest, idx); } } @@ -584,7 +587,7 @@ impl SsaBlock { .iter() .map(|phi| phi.result().index()) .max() - .map_or(0, |m| m + 1); + .map_or(0, |m| m.saturating_add(1)); let mut phi_defs = BitSet::new(max_phi_var); for phi in &self.phi_nodes { phi_defs.insert(phi.result().index()); @@ -608,7 +611,9 @@ impl SsaBlock { let mut prev_side_effect_pos: Option = None; for (pos, &idx) in non_terminator_indices.iter().enumerate() { - let instr = &self.instructions[idx]; + let Some(instr) = self.instructions.get(idx) else { + continue; + }; // Add data dependencies (def-use chains) for used in &instr.uses() { @@ -621,8 +626,12 @@ impl SsaBlock { if dep_idx != idx { if let Some(&dep_pos) = idx_to_pos.get(&dep_idx) { // instruction at pos depends on instruction at dep_pos - deps[pos].insert(dep_pos); - rdeps[dep_pos].insert(pos); + if let Some(d) = deps.get_mut(pos) { + d.insert(dep_pos); + } + if let Some(r) = rdeps.get_mut(dep_pos) { + r.insert(pos); + } } } } @@ -634,8 +643,12 @@ impl SsaBlock { if !instr.op().is_pure() { if let Some(prev_pos) = prev_side_effect_pos { // This side-effecting instruction depends on the previous one - deps[pos].insert(prev_pos); - rdeps[prev_pos].insert(pos); + if let Some(d) = deps.get_mut(pos) { + d.insert(prev_pos); + } + if let Some(r) = rdeps.get_mut(prev_pos) { + r.insert(pos); + } } prev_side_effect_pos = Some(pos); } @@ -658,10 +671,15 @@ impl SsaBlock { sorted_positions.push(pos); // Reduce in_degree for dependents - for dep_pos in rdeps[pos].iter() { - in_degree[dep_pos] -= 1; - if in_degree[dep_pos] == 0 { - ready.push_back(dep_pos); + let Some(rd) = rdeps.get(pos) else { + continue; + }; + for dep_pos in rd.iter() { + if let Some(slot) = in_degree.get_mut(dep_pos) { + *slot = slot.saturating_sub(1); + if *slot == 0 { + ready.push_back(dep_pos); + } } } } @@ -678,8 +696,10 @@ impl SsaBlock { // First add non-terminator instructions in sorted order for pos in sorted_positions { - let original_idx = non_terminator_indices[pos]; - if let Some(instr) = temp[original_idx].take() { + let Some(&original_idx) = non_terminator_indices.get(pos) else { + continue; + }; + if let Some(instr) = temp.get_mut(original_idx).and_then(Option::take) { self.instructions.push(instr); } } diff --git a/dotscope/src/analysis/ssa/builder.rs b/dotscope/src/analysis/ssa/builder.rs index 92941557..61826e64 100644 --- a/dotscope/src/analysis/ssa/builder.rs +++ b/dotscope/src/analysis/ssa/builder.rs @@ -135,7 +135,7 @@ impl SsaFunctionBuilder { /// /// Uses `Phi` origin since stack temporaries are not real CIL locals. fn alloc_stack_var(&mut self) -> SsaVarId { - self.next_stack_slot += 1; + self.next_stack_slot = self.next_stack_slot.saturating_add(1); self.alloc_var_with_origin(VariableOrigin::Phi) } @@ -211,25 +211,33 @@ pub struct SsaFunctionContext<'a> { impl SsaFunctionContext<'_> { /// Gets the argument variable at the specified index and sets its type. /// - /// # Panics - /// - /// Panics if `index >= num_args`. + /// If `index >= num_args`, allocates a fresh stack variable instead. This + /// keeps construction non-panicking; misuse will surface during type + /// validation in [`SsaFunctionBuilder::build`]. #[must_use] pub fn arg(&mut self, index: usize, ty: SsaType) -> SsaVarId { - let id = self.builder.arg_vars[index]; - self.builder.variables[id.index()].3 = ty; + let Some(&id) = self.builder.arg_vars.get(index) else { + return self.builder.alloc_stack_var_typed(ty); + }; + if let Some(entry) = self.builder.variables.get_mut(id.index()) { + entry.3 = ty; + } id } /// Gets the local variable at the specified index and sets its type. /// - /// # Panics - /// - /// Panics if `index >= num_locals`. + /// If `index >= num_locals`, allocates a fresh stack variable instead. This + /// keeps construction non-panicking; misuse will surface during type + /// validation in [`SsaFunctionBuilder::build`]. #[must_use] pub fn local(&mut self, index: usize, ty: SsaType) -> SsaVarId { - let id = self.builder.local_vars[index]; - self.builder.variables[id.index()].3 = ty; + let Some(&id) = self.builder.local_vars.get(index) else { + return self.builder.alloc_stack_var_typed(ty); + }; + if let Some(entry) = self.builder.variables.get_mut(id.index()) { + entry.3 = ty; + } id } diff --git a/dotscope/src/analysis/ssa/cfg.rs b/dotscope/src/analysis/ssa/cfg.rs index 657da42f..a28aed05 100644 --- a/dotscope/src/analysis/ssa/cfg.rs +++ b/dotscope/src/analysis/ssa/cfg.rs @@ -119,9 +119,9 @@ impl<'a> SsaCfg<'a> { .unwrap_or_default(); for succ in block_succs { - if succ < block_count { + if let Some(slot) = predecessors.get_mut(succ) { block_succs_list.push(succ); - predecessors[succ].push(block_idx); + slot.push(block_idx); } } } @@ -139,10 +139,16 @@ impl<'a> SsaCfg<'a> { { if handler_start < block_count && try_start < block_count - && !predecessors[handler_start].contains(&try_start) + && !predecessors + .get(handler_start) + .is_some_and(|p| p.contains(&try_start)) { - successors[try_start].push(handler_start); - predecessors[handler_start].push(try_start); + if let Some(slot) = successors.get_mut(try_start) { + slot.push(handler_start); + } + if let Some(slot) = predecessors.get_mut(handler_start) { + slot.push(try_start); + } } } } diff --git a/dotscope/src/analysis/ssa/consts.rs b/dotscope/src/analysis/ssa/consts.rs index df6c3a8c..fb7fd1f3 100644 --- a/dotscope/src/analysis/ssa/consts.rs +++ b/dotscope/src/analysis/ssa/consts.rs @@ -185,11 +185,15 @@ impl<'a> ConstEvaluator<'a> { // Copy needs recursive evaluation that the shared helper cannot provide, // because it resolves a variable rather than performing arithmetic. if let SsaOp::Copy { src, .. } = op { - return self.evaluate_var_depth(*src, depth + 1); + return self.evaluate_var_depth(*src, depth.saturating_add(1)); } let ptr_size = self.pointer_size; - evaluate_const_op(op, |var| self.evaluate_var_depth(var, depth + 1), ptr_size) + evaluate_const_op( + op, + |var| self.evaluate_var_depth(var, depth.saturating_add(1)), + ptr_size, + ) } /// Returns all computed constants. diff --git a/dotscope/src/analysis/ssa/converter.rs b/dotscope/src/analysis/ssa/converter.rs index 9212a253..33858e0a 100644 --- a/dotscope/src/analysis/ssa/converter.rs +++ b/dotscope/src/analysis/ssa/converter.rs @@ -186,7 +186,9 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { /// Returns the rename group ID for a stack slot at the given depth. #[allow(clippy::cast_possible_truncation)] fn stack_group(&self, depth: usize) -> u32 { - self.num_args as u32 + self.num_locals as u32 + depth as u32 + (self.num_args as u32) + .saturating_add(self.num_locals as u32) + .saturating_add(depth as u32) } /// Returns the rename group ID for an argument. @@ -196,14 +198,14 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { /// Returns the rename group ID for a local. fn local_group(&self, idx: u16) -> u32 { - self.num_args as u32 + idx as u32 + (self.num_args as u32).saturating_add(idx as u32) } /// If a group ID corresponds to a stack slot, returns the slot index. fn stack_slot_from_group(&self, group: u32) -> Option { - let base = self.num_args as u32 + self.num_locals as u32; + let base = (self.num_args as u32).saturating_add(self.num_locals as u32); if group >= base { - Some((group - base) as usize) + Some(group.saturating_sub(base) as usize) } else { None } @@ -215,8 +217,8 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { *origin } else if (group as usize) < self.num_args { VariableOrigin::Argument(group as u16) - } else if (group as usize) < self.num_args + self.num_locals { - VariableOrigin::Local((group as usize - self.num_args) as u16) + } else if (group as usize) < self.num_args.saturating_add(self.num_locals) { + VariableOrigin::Local((group as usize).saturating_sub(self.num_args) as u16) } else { VariableOrigin::Phi } @@ -224,7 +226,7 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { /// Returns true if a group ID corresponds to a stack slot. fn is_stack_group(&self, group: u32) -> bool { - group >= self.num_args as u32 + self.num_locals as u32 + group >= (self.num_args as u32).saturating_add(self.num_locals as u32) } /// Returns the SSA type for a variable origin. @@ -617,7 +619,7 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { uses: BTreeMap::new(), version_stacks: BTreeMap::new(), next_version: BTreeMap::new(), - address_taken: BitSet::new(num_args + num_locals), + address_taken: BitSet::new(num_args.saturating_add(num_locals)), group_origins: BTreeMap::new(), load_origins: BTreeMap::new(), exit_stacks: BTreeMap::new(), @@ -827,7 +829,7 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { let pops = if is_newobj { param_count } else { - param_count + usize::from(has_this) + param_count.saturating_add(usize::from(has_this)) }; let pushes = if is_newobj { 1 @@ -838,7 +840,8 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { clippy::cast_possible_truncation, clippy::cast_possible_wrap )] - ((pushes as i32 - pops as i32) + ((pushes as i32) + .saturating_sub(pops as i32) .clamp(i32::from(i8::MIN), i32::from(i8::MAX)) as i8) }, @@ -851,13 +854,16 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { Self::resolve_calli_info(asm, token).map( |(param_count, has_this, has_return)| { // calli pops: param_count + has_this + 1 (function pointer) - let pops = param_count + usize::from(has_this) + 1; + let pops = param_count + .saturating_add(usize::from(has_this)) + .saturating_add(1); let pushes = usize::from(has_return); #[allow( clippy::cast_possible_truncation, clippy::cast_possible_wrap )] - ((pushes as i32 - pops as i32) + ((pushes as i32) + .saturating_sub(pops as i32) .clamp(i32::from(i8::MIN), i32::from(i8::MAX)) as i8) }, @@ -872,7 +878,7 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { if net_effect < 0 { depth = depth.saturating_sub(net_effect.unsigned_abs() as usize); } else { - depth += net_effect as usize; + depth = depth.saturating_add(net_effect as usize); } } @@ -977,7 +983,7 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { // Pass successors only for the last instruction (terminator) // Non-terminator instructions don't need successor information - let instr_successors = if instr_idx == instr_count - 1 { + let instr_successors = if instr_idx.saturating_add(1) == instr_count { successors.as_slice() } else { &[] @@ -1059,11 +1065,12 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { .is_some_and(|instr| instr.op().is_terminator()); if !has_terminator && successors.len() == 1 { - let fallthrough_target = successors[0]; - let jump_instr = SsaInstruction::synthetic(SsaOp::Jump { - target: fallthrough_target, - }); - block.add_instruction(jump_instr); + if let Some(&fallthrough_target) = successors.first() { + let jump_instr = SsaInstruction::synthetic(SsaOp::Jump { + target: fallthrough_target, + }); + block.add_instruction(jump_instr); + } } } @@ -1213,7 +1220,7 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { // newobj doesn't pop 'this' - it creates it param_count } else { - param_count + usize::from(has_this) + param_count.saturating_add(usize::from(has_this)) }; let pushes = if is_newobj { @@ -1276,7 +1283,9 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { Self::resolve_calli_info(assembly, token) { // calli pops: param_count + has_this args + 1 function pointer - let pops = param_count + usize::from(has_this) + 1; + let pops = param_count + .saturating_add(usize::from(has_this)) + .saturating_add(1); let pushes = usize::from(has_return); #[allow(clippy::cast_possible_truncation)] @@ -1508,8 +1517,8 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { group_origins.get(&group).copied().unwrap_or_else(|| { if (group as usize) < num_args { VariableOrigin::Argument(group as u16) - } else if (group as usize) < num_args + num_locals { - VariableOrigin::Local((group as usize - num_args) as u16) + } else if (group as usize) < num_args.saturating_add(num_locals) { + VariableOrigin::Local((group as usize).saturating_sub(num_args) as u16) } else { VariableOrigin::Phi } @@ -1687,7 +1696,7 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { .filter(|&g| g != u32::MAX) .max() .unwrap_or(0) as usize; - let group_capacity = max_phi_group + 1; + let group_capacity = max_phi_group.saturating_add(1); let mut stack_groups_needing_v0 = BitSet::new(group_capacity); for block in self.function.blocks() { for phi in block.phi_nodes() { @@ -1768,7 +1777,8 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { .entry(group) .or_default() .push((0, var_id)); - *scope_pushed.entry(group).or_insert(0) += 1; + let entry = scope_pushed.entry(group).or_insert(0); + *entry = entry.saturating_add(1); } } @@ -1973,7 +1983,7 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { if old_idx != new_idx { index_remap.insert(old_idx, new_idx); } - new_idx += 1; + new_idx = new_idx.saturating_add(1); } } @@ -2027,7 +2037,8 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { var_type: SsaType, ) -> SsaVarId { let version = *self.next_version.get(&group).unwrap_or(&0); - *self.next_version.entry(group).or_insert(0) += 1; + let next = self.next_version.entry(group).or_insert(0); + *next = next.saturating_add(1); let def_site = match instr_idx { Some(idx) => DefSite::instruction(block_idx, idx), @@ -2075,8 +2086,8 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { let slot_idx = slot as usize; // Get the slot value, falling back to TOS for depth mismatch - let stack_slot = if slot_idx < stack.len() { - &stack[slot_idx] + let stack_slot = if let Some(s) = stack.get(slot_idx) { + s } else if let Some(last) = stack.last() { last } else { @@ -2420,7 +2431,10 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { }; // Phi type starts Unknown, resolved by resolve_phi_types() after rename let new_var = self.new_def(origin, group, block_idx, None, SsaType::Unknown); - *pushed_counts.entry(group).or_insert(0) += 1; + { + let entry = pushed_counts.entry(group).or_insert(0); + *entry = entry.saturating_add(1); + } if let Some(slot) = self.stack_slot_from_group(group) { if slot < slots_capacity { @@ -2466,7 +2480,10 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { let group = self.stack_group(0); let new_var = self.new_def(origin, group, block_idx, None, exception_type); - *pushed_counts.entry(group).or_insert(0) += 1; + { + let entry = pushed_counts.entry(group).or_insert(0); + *entry = entry.saturating_add(1); + } rename_map.insert(placeholder, new_var); // Mark slot 0 as handled so the resolution loop below skips it slots_with_phis.insert(0); @@ -2508,7 +2525,10 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { .entry(group) .or_default() .push((0, resolved)); - *pushed_counts.entry(group).or_insert(0) += 1; + { + let entry = pushed_counts.entry(group).or_insert(0); + *entry = entry.saturating_add(1); + } } else { self.try_map_from_predecessors(block_idx, slot, placeholder, rename_map); } @@ -2605,7 +2625,10 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { let var_type = self.type_for_origin(target_origin); let new_var = self.new_def(origin, dest_group, block_idx, Some(instr_idx), var_type); - *pushed_counts.entry(dest_group).or_insert(0) += 1; + { + let entry = pushed_counts.entry(dest_group).or_insert(0); + *entry = entry.saturating_add(1); + } rename_map.insert(sim_var, new_var); @@ -2637,7 +2660,10 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { }; let var_type = self.type_for_origin(origin); let v = self.new_def(origin, group, block_idx, Some(instr_idx), var_type); - *pushed_counts.entry(group).or_insert(0) += 1; + { + let entry = pushed_counts.entry(group).or_insert(0); + *entry = entry.saturating_add(1); + } v } else { // Use the stack depth position recorded during simulation, @@ -2649,7 +2675,10 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { let group = self.stack_group(slot); let var_type = self.infer_instruction_result_type(block_idx, instr_idx); let v = self.new_def(origin, group, block_idx, Some(instr_idx), var_type); - *pushed_counts.entry(group).or_insert(0) += 1; + { + let entry = pushed_counts.entry(group).or_insert(0); + *entry = entry.saturating_add(1); + } v }; @@ -2676,7 +2705,10 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { Some(instr_idx), var_type, ); - *pushed_counts.entry(store_group).or_insert(0) += 1; + { + let entry = pushed_counts.entry(store_group).or_insert(0); + *entry = entry.saturating_add(1); + } } } } diff --git a/dotscope/src/analysis/ssa/decompose.rs b/dotscope/src/analysis/ssa/decompose.rs index 15b34522..19f9ad1b 100644 --- a/dotscope/src/analysis/ssa/decompose.rs +++ b/dotscope/src/analysis/ssa/decompose.rs @@ -619,34 +619,28 @@ fn decompose_standard_instruction( // brfalse.s, brfalse // For conditional branches, successors[0] is the branch target, successors[1] is fallthrough uses.first().and_then(|&condition| { - if successors.len() >= 2 { - // brfalse: jumps to target if false, falls through if true - // successors[0] = branch target (false path), successors[1] = fallthrough (true path) - Some(SsaOp::Branch { - condition, - true_target: successors[1], // fallthrough - false_target: successors[0], // branch target - }) - } else { - None - } + let target = successors.first().copied()?; + let fallthrough = successors.get(1).copied()?; + // brfalse: jumps to target if false, falls through if true + Some(SsaOp::Branch { + condition, + true_target: fallthrough, + false_target: target, + }) }) } 0x2D | 0x3A => { // brtrue.s, brtrue // For conditional branches, successors[0] is the branch target, successors[1] is fallthrough uses.first().and_then(|&condition| { - if successors.len() >= 2 { - // brtrue: jumps to target if true, falls through if false - // successors[0] = branch target (true path), successors[1] = fallthrough (false path) - Some(SsaOp::Branch { - condition, - true_target: successors[0], // branch target - false_target: successors[1], // fallthrough - }) - } else { - None - } + let target = successors.first().copied()?; + let fallthrough = successors.get(1).copied()?; + // brtrue: jumps to target if true, falls through if false + Some(SsaOp::Branch { + condition, + true_target: target, + false_target: fallthrough, + }) }) } @@ -698,8 +692,9 @@ fn decompose_standard_instruction( uses.first().and_then(|&value| { if successors.len() >= 2 { // Last successor is the default, rest are case targets - let default = *successors.last().unwrap_or(&0); - let targets: Vec = successors[..successors.len() - 1].to_vec(); + let last_idx = successors.len().checked_sub(1)?; + let default = successors.get(last_idx).copied().unwrap_or(0); + let targets: Vec = successors.get(..last_idx)?.to_vec(); Some(SsaOp::Switch { value, targets, @@ -804,7 +799,11 @@ fn decompose_standard_instruction( // calli if let Some(signature) = extract_signature_token(&instr.operand) { let (fptr, args) = if let Some(&fptr) = uses.last() { - let args = uses[..uses.len() - 1].to_vec(); + let last_idx = uses.len().saturating_sub(1); + let args = uses + .get(..last_idx) + .map(<[SsaVarId]>::to_vec) + .unwrap_or_default(); (fptr, args) } else { (SsaVarId::from_index(0), vec![]) @@ -1430,18 +1429,16 @@ fn comparison_branch( unsigned: bool, ) -> Option { if let (Some(&left), Some(&right)) = (uses.first(), uses.get(1)) { - if successors.len() >= 2 { - Some(SsaOp::BranchCmp { - left, - right, - cmp, - unsigned, - true_target: successors[0], - false_target: successors[1], - }) - } else { - None - } + let true_target = successors.first().copied()?; + let false_target = successors.get(1).copied()?; + Some(SsaOp::BranchCmp { + left, + right, + cmp, + unsigned, + true_target, + false_target, + }) } else { None } diff --git a/dotscope/src/analysis/ssa/evaluator.rs b/dotscope/src/analysis/ssa/evaluator.rs index 4d6ce744..95b40956 100644 --- a/dotscope/src/analysis/ssa/evaluator.rs +++ b/dotscope/src/analysis/ssa/evaluator.rs @@ -988,7 +988,9 @@ impl<'a> SsaEvaluator<'a> { pub fn evaluate_path(&mut self, path: &[usize]) { for (i, &block_idx) in path.iter().enumerate() { if i > 0 { - self.set_predecessor(Some(path[i - 1])); + if let Some(&prev) = path.get(i.saturating_sub(1)) { + self.set_predecessor(Some(prev)); + } } self.evaluate_block(block_idx); } @@ -1026,17 +1028,21 @@ impl<'a> SsaEvaluator<'a> { // Evaluate all loop blocks for (i, &block_idx) in loop_blocks.iter().enumerate() { if i > 0 { - self.set_predecessor(Some(loop_blocks[i - 1])); + if let Some(&prev) = loop_blocks.get(i.saturating_sub(1)) { + self.set_predecessor(Some(prev)); + } } else if loop_blocks.len() > 1 { // First block - predecessor is the last block (loop back edge) - self.set_predecessor(Some(loop_blocks[loop_blocks.len() - 1])); + if let Some(&last) = loop_blocks.last() { + self.set_predecessor(Some(last)); + } } self.evaluate_block(block_idx); } // Check if values changed if self.values_match(&snapshot) { - return iteration + 1; + return iteration.saturating_add(1); } } @@ -1090,7 +1096,9 @@ impl<'a> SsaEvaluator<'a> { for _ in 0..iterations { for (i, &block_idx) in loop_blocks.iter().enumerate() { if i > 0 { - self.set_predecessor(Some(loop_blocks[i - 1])); + if let Some(&prev) = loop_blocks.get(i.saturating_sub(1)) { + self.set_predecessor(Some(prev)); + } } self.evaluate_block(block_idx); } @@ -1588,7 +1596,9 @@ impl<'a> SsaEvaluator<'a> { // Recursively resolve operands first for operand in op.uses() { if !self.values.contains_key(&operand) { - if let Some(resolved) = self.resolve_with_trace(operand, max_depth - 1) { + if let Some(resolved) = + self.resolve_with_trace(operand, max_depth.saturating_sub(1)) + { self.values.insert(operand, resolved); } } @@ -1646,7 +1656,7 @@ impl<'a> SsaEvaluator<'a> { let Some(instr) = terminator else { // No terminator - fall through to next block if it exists - let next_idx = block_idx + 1; + let next_idx = block_idx.saturating_add(1); if next_idx < self.ssa.block_count() { return ControlFlow::Continue(next_idx); } @@ -1715,8 +1725,8 @@ impl<'a> SsaEvaluator<'a> { Some(v) => { #[allow(clippy::cast_possible_truncation)] let idx = v as usize; - if idx < targets.len() { - ControlFlow::Continue(targets[idx]) + if let Some(&target) = targets.get(idx) { + ControlFlow::Continue(target) } else { ControlFlow::Continue(*default) } diff --git a/dotscope/src/analysis/ssa/exception.rs b/dotscope/src/analysis/ssa/exception.rs index 52c9c911..33ce5582 100644 --- a/dotscope/src/analysis/ssa/exception.rs +++ b/dotscope/src/analysis/ssa/exception.rs @@ -72,7 +72,7 @@ pub struct SsaExceptionHandler { /// When an end-boundary block is removed during canonicalization, we need to find the /// next block that survived to preserve the boundary semantics. fn find_next_surviving(block_remap: &[Option], start: usize) -> Option { - block_remap[start..].iter().find_map(|entry| *entry) + block_remap.get(start..)?.iter().find_map(|entry| *entry) } impl SsaExceptionHandler { diff --git a/dotscope/src/analysis/ssa/function/canonical.rs b/dotscope/src/analysis/ssa/function/canonical.rs index 2d1b2a7e..0be15b08 100644 --- a/dotscope/src/analysis/ssa/function/canonical.rs +++ b/dotscope/src/analysis/ssa/function/canonical.rs @@ -121,7 +121,7 @@ impl SsaFunction { block_remap.push(None); // This block will be removed } else { block_remap.push(Some(new_index)); - new_index += 1; + new_index = new_index.saturating_add(1); } } @@ -131,15 +131,17 @@ impl SsaFunction { let mut redirect_map: BTreeMap = BTreeMap::new(); for old_index in 0..self.blocks.len() { - if block_remap[old_index].is_none() { + if matches!(block_remap.get(old_index), Some(None)) { // This block is being removed - find where it would jump to if let Some(target) = self.find_ultimate_target(old_index, &block_remap) { redirect_map.insert(old_index, target); } else { // Can't find a redirect target - we must keep this block. // Reassign it a new index. - block_remap[old_index] = Some(new_index); - new_index += 1; + if let Some(slot) = block_remap.get_mut(old_index) { + *slot = Some(new_index); + new_index = new_index.saturating_add(1); + } } } } @@ -182,7 +184,9 @@ impl SsaFunction { .filter_map(|&old_pred| block_remap.get(old_pred).and_then(|opt| *opt)) .collect(); - let block = &mut self.blocks[block_idx]; + let Some(block) = self.blocks.get_mut(block_idx) else { + continue; + }; for phi in block.phi_nodes_mut() { // Collect changes first (to avoid borrow issues) let mut changes: Vec<(usize, Option, Vec)> = Vec::new(); @@ -264,7 +268,7 @@ impl SsaFunction { // Phase 6: Remove empty blocks and compact block indices. let mut kept_blocks: Vec = Vec::with_capacity(new_index); for (old_index, block) in self.blocks.drain(..).enumerate() { - if block_remap[old_index].is_some() { + if matches!(block_remap.get(old_index), Some(Some(_))) { kept_blocks.push(block); } } @@ -355,7 +359,7 @@ impl SsaFunction { // In CIL semantics, empty blocks fall through to the next block. if terminator.is_none() && block.instructions().is_empty() { // Try to fall through to the next block - let next_block = current + 1; + let next_block = current.saturating_add(1); if next_block < self.blocks.len() { if let Some(Some(new_idx)) = block_remap.get(next_block) { // Next block exists in new layout diff --git a/dotscope/src/analysis/ssa/function/mod.rs b/dotscope/src/analysis/ssa/function/mod.rs index dc520161..8cbb1c39 100644 --- a/dotscope/src/analysis/ssa/function/mod.rs +++ b/dotscope/src/analysis/ssa/function/mod.rs @@ -418,7 +418,7 @@ impl SsaFunction { let from_vars = self .variables .iter() - .map(|v| v.id().index() + 1) + .map(|v| v.id().index().saturating_add(1)) .max() .unwrap_or(0); let from_blocks = self @@ -439,7 +439,7 @@ impl SsaFunction { phi_ids.chain(instr_ids) }) .max() - .map_or(0, |m| m + 1); + .map_or(0, |m| m.saturating_add(1)); from_vars.max(from_blocks).max(self.variables.len()) } @@ -583,7 +583,8 @@ impl SsaFunction { self.variables.push(var); // Extend rename_groups to keep it in sync (default u32::MAX = no group) if self.rename_groups.len() <= id.index() { - self.rename_groups.resize(id.index() + 1, u32::MAX); + self.rename_groups + .resize(id.index().saturating_add(1), u32::MAX); } id } @@ -656,8 +657,10 @@ impl SsaFunction { let old_id = var.id(); let new_id = SsaVarId::from_index(index); // Carry over the rename group from the old position - if old_id.index() < old_groups.len() { - new_groups[index] = old_groups[old_id.index()]; + if let Some(&old_group) = old_groups.get(old_id.index()) { + if let Some(slot) = new_groups.get_mut(index) { + *slot = old_group; + } } if old_id != new_id { remap.insert(old_id, new_id); @@ -812,9 +815,11 @@ impl SsaFunction { pub(crate) fn set_rename_group(&mut self, var_id: SsaVarId, group: u32) { let idx = var_id.index(); if idx >= self.rename_groups.len() { - self.rename_groups.resize(idx + 1, u32::MAX); + self.rename_groups.resize(idx.saturating_add(1), u32::MAX); + } + if let Some(slot) = self.rename_groups.get_mut(idx) { + *slot = group; } - self.rename_groups[idx] = group; } /// Rebuilds SSA form after CFG modifications (e.g., control flow unflattening). diff --git a/dotscope/src/analysis/ssa/function/queries.rs b/dotscope/src/analysis/ssa/function/queries.rs index c08897cc..fe203de4 100644 --- a/dotscope/src/analysis/ssa/function/queries.rs +++ b/dotscope/src/analysis/ssa/function/queries.rs @@ -639,17 +639,19 @@ impl SsaFunction { // Remainder (state % N) or bitwise AND (state & mask): trace left operand SsaOp::Rem { left, .. } | SsaOp::And { left, .. } => { - self.trace_to_phi_impl(*left, target_block, depth + 1) + self.trace_to_phi_impl(*left, target_block, depth.saturating_add(1)) } // XOR operation (e.g., state ^ key): try both operands SsaOp::Xor { left, right, .. } => { // Try left first - if let Some(phi) = self.trace_to_phi_impl(*left, target_block, depth + 1) { + if let Some(phi) = + self.trace_to_phi_impl(*left, target_block, depth.saturating_add(1)) + { return Some(phi); } // Then try right (XOR is commutative) - self.trace_to_phi_impl(*right, target_block, depth + 1) + self.trace_to_phi_impl(*right, target_block, depth.saturating_add(1)) } // Arithmetic operations (ConfuserEx uses mul/add/sub for state transformation) @@ -658,20 +660,24 @@ impl SsaFunction { | SsaOp::Add { left, right, .. } | SsaOp::Sub { left, right, .. } => { // Try left first (usually where the state variable is) - if let Some(phi) = self.trace_to_phi_impl(*left, target_block, depth + 1) { + if let Some(phi) = + self.trace_to_phi_impl(*left, target_block, depth.saturating_add(1)) + { return Some(phi); } // Then try right - self.trace_to_phi_impl(*right, target_block, depth + 1) + self.trace_to_phi_impl(*right, target_block, depth.saturating_add(1)) } // Shift operations: trace the value operand SsaOp::Shr { value, .. } | SsaOp::Shl { value, .. } => { - self.trace_to_phi_impl(*value, target_block, depth + 1) + self.trace_to_phi_impl(*value, target_block, depth.saturating_add(1)) } // Copy: trace through to source - SsaOp::Copy { src, .. } => self.trace_to_phi_impl(*src, target_block, depth + 1), + SsaOp::Copy { src, .. } => { + self.trace_to_phi_impl(*src, target_block, depth.saturating_add(1)) + } // For other operations (including constants), the variable cannot be traced to a PHI _ => None, @@ -799,8 +805,14 @@ impl SsaFunction { } // Determine block count from successor map keys - let block_count = successor_map.keys().copied().max().map_or(0, |m| m + 1); - let block_count = block_count.max(from + 1).max(to + 1); + let block_count = successor_map + .keys() + .copied() + .max() + .map_or(0, |m| m.saturating_add(1)); + let block_count = block_count + .max(from.saturating_add(1)) + .max(to.saturating_add(1)); let mut visited = BitSet::new(block_count); let mut worklist = vec![from]; @@ -910,14 +922,16 @@ impl SsaFunction { // Count phi node operands for phi in block.phi_nodes() { for operand in phi.operands() { - *counts.entry(operand.value()).or_insert(0) += 1; + let entry = counts.entry(operand.value()).or_insert(0_usize); + *entry = entry.saturating_add(1); } } // Count instruction operands for instr in block.instructions() { for var in instr.op().uses() { - *counts.entry(var).or_insert(0) += 1; + let entry = counts.entry(var).or_insert(0_usize); + *entry = entry.saturating_add(1); } } } @@ -1076,10 +1090,11 @@ impl SsaFunction { // If all returns are the same constant if constants_found.iter().all(Option::is_some) { - let first = &constants_found[0]; - if constants_found.iter().all(|c| c == first) { - if let Some(const_val) = first { - return ReturnInfo::Constant(const_val.clone()); + if let Some(first) = constants_found.first() { + if constants_found.iter().all(|c| c == first) { + if let Some(const_val) = first { + return ReturnInfo::Constant(const_val.clone()); + } } } } diff --git a/dotscope/src/analysis/ssa/function/rebuild.rs b/dotscope/src/analysis/ssa/function/rebuild.rs index d210e3b0..7153d28b 100644 --- a/dotscope/src/analysis/ssa/function/rebuild.rs +++ b/dotscope/src/analysis/ssa/function/rebuild.rs @@ -97,7 +97,7 @@ pub(crate) struct SsaRebuilder<'a> { impl<'a> SsaRebuilder<'a> { pub fn new(ssa: &'a mut SsaFunction) -> Self { - let next_group = ssa.num_args as u32 + ssa.num_locals as u32; + let next_group = (ssa.num_args as u32).saturating_add(ssa.num_locals as u32); let block_count = ssa.blocks.len(); Self { ssa, @@ -314,7 +314,8 @@ impl<'a> SsaRebuilder<'a> { .position(|i| i.is_terminator()) .unwrap_or(entry.instructions().len()); for (i, instr) in to_rescue.into_iter().enumerate() { - entry.instructions_mut().insert(term_idx + i, instr); + let pos = term_idx.saturating_add(i); + entry.instructions_mut().insert(pos, instr); } } @@ -345,8 +346,10 @@ impl<'a> SsaRebuilder<'a> { // Clear unreachable blocks for block_idx in 0..self.ssa.blocks.len() { if !reachable.contains(block_idx) { - self.ssa.blocks[block_idx].instructions_mut().clear(); - self.ssa.blocks[block_idx].phi_nodes_mut().clear(); + if let Some(b) = self.ssa.blocks.get_mut(block_idx) { + b.instructions_mut().clear(); + b.phi_nodes_mut().clear(); + } } } @@ -357,7 +360,9 @@ impl<'a> SsaRebuilder<'a> { if !reachable.contains(block_idx) { continue; } - let block = &mut self.ssa.blocks[block_idx]; + let Some(block) = self.ssa.blocks.get_mut(block_idx) else { + continue; + }; // Remove operands from unreachable predecessors for phi in block.phi_nodes_mut().iter_mut() { @@ -367,11 +372,11 @@ impl<'a> SsaRebuilder<'a> { // Inline trivial phis (0 or 1 unique operand value) block.phi_nodes_mut().retain(|phi| { let operands = phi.operands(); - if operands.is_empty() { + let Some(first_op) = operands.first() else { return false; // Remove empty phi - } + }; + let first = first_op.value(); // Check if all operands resolve to the same value - let first = operands[0].value(); if operands .iter() .all(|op| op.value() == first || op.value() == phi.result()) @@ -430,9 +435,21 @@ impl<'a> SsaRebuilder<'a> { let mut rank: Vec = vec![0; num_vars]; let find = |parent: &mut Vec, mut x: usize| -> usize { - while parent[x] != x { - parent[x] = parent[parent[x]]; // path halving - x = parent[x]; + // Bound the iterations to the union-find size to avoid infinite + // loops if the parent array is corrupted. + for _ in 0..parent.len() { + let Some(&p) = parent.get(x) else { + return x; + }; + if p == x { + return x; + } + // Path halving: parent[x] = parent[parent[x]]. + let pp = parent.get(p).copied().unwrap_or(p); + if let Some(slot) = parent.get_mut(x) { + *slot = pp; + } + x = pp; } x }; @@ -443,13 +460,23 @@ impl<'a> SsaRebuilder<'a> { if ra == rb { return; } - if rank[ra] < rank[rb] { - parent[ra] = rb; - } else if rank[ra] > rank[rb] { - parent[rb] = ra; + let rank_ra = rank.get(ra).copied().unwrap_or(0); + let rank_rb = rank.get(rb).copied().unwrap_or(0); + if rank_ra < rank_rb { + if let Some(slot) = parent.get_mut(ra) { + *slot = rb; + } + } else if rank_ra > rank_rb { + if let Some(slot) = parent.get_mut(rb) { + *slot = ra; + } } else { - parent[rb] = ra; - rank[ra] += 1; + if let Some(slot) = parent.get_mut(rb) { + *slot = ra; + } + if let Some(slot) = rank.get_mut(ra) { + *slot = slot.saturating_add(1); + } } }; @@ -520,7 +547,7 @@ impl<'a> SsaRebuilder<'a> { arg_local_reps.entry(group).or_insert(idx); } VariableOrigin::Local(li) => { - let group = num_args as u32 + li as u32; + let group = (num_args as u32).saturating_add(li as u32); arg_local_reps.entry(group).or_insert(idx); } _ => {} @@ -530,7 +557,7 @@ impl<'a> SsaRebuilder<'a> { for instr in block.instructions() { match instr.op() { SsaOp::LoadLocal { dest, local_index } => { - let group = num_args as u32 + *local_index as u32; + let group = (num_args as u32).saturating_add(*local_index as u32); if let (Some(&dest_idx), Some(&rep_idx)) = (var_to_idx.get(dest), arg_local_reps.get(&group)) { @@ -563,14 +590,14 @@ impl<'a> SsaRebuilder<'a> { // Only split groups that actually have disconnected components. let max_existing = self.ssa.rename_groups.iter().copied().max().unwrap_or(0); let mut next_new_group = if max_existing == u32::MAX { - num_args as u32 + self.ssa.num_locals as u32 + (num_args as u32).saturating_add(self.ssa.num_locals as u32) } else { - max_existing + 1 + max_existing.saturating_add(1) }; let mut updates: Vec<(SsaVarId, u32)> = Vec::new(); - let real_local_limit = num_args as u32 + self.ssa.num_locals as u32; + let real_local_limit = (num_args as u32).saturating_add(self.ssa.num_locals as u32); for (&original_group, members) in &group_members { if members.len() <= 1 { @@ -595,7 +622,9 @@ impl<'a> SsaRebuilder<'a> { let mut canonical_root: Option = None; for (&root, component_members) in &component_roots { for &idx in component_members { - let var = &self.ssa.variables[idx]; + let Some(var) = self.ssa.variables.get(idx) else { + continue; + }; match var.origin() { VariableOrigin::Argument(_) | VariableOrigin::Local(_) => { canonical_root = Some(root); @@ -612,8 +641,9 @@ impl<'a> SsaRebuilder<'a> { // If no canonical root found, pick the largest component. // Break ties deterministically by the smallest variable index // within each component to avoid nondeterministic grouping. - let canonical_root = canonical_root.unwrap_or_else(|| { - *component_roots + let canonical_root = match canonical_root { + Some(r) => r, + None => match component_roots .iter() .max_by(|(_, members_a), (_, members_b)| { members_a.len().cmp(&members_b.len()).then_with(|| { @@ -622,9 +652,13 @@ impl<'a> SsaRebuilder<'a> { min_b.cmp(&min_a) }) }) - .map(|(root, _)| root) - .unwrap() - }); + .map(|(root, _)| *root) + { + Some(r) => r, + // No components — group is empty, nothing to split. + None => continue, + }, + }; // Decide which components to keep vs. split. // @@ -648,7 +682,7 @@ impl<'a> SsaRebuilder<'a> { .into_iter() .flat_map(|members| members.iter()) .filter_map(|&idx| { - let var = &self.ssa.variables[idx]; + let var = self.ssa.variables.get(idx)?; match var.origin() { VariableOrigin::Argument(_) | VariableOrigin::Local(_) => { Some((var.origin(), var.var_type().clone())) @@ -668,7 +702,7 @@ impl<'a> SsaRebuilder<'a> { .into_iter() .flat_map(|members| members.iter()) .filter_map(|&idx| { - let t = self.ssa.variables[idx].var_type(); + let t = self.ssa.variables.get(idx)?.var_type(); if t.is_unknown() { None } else { @@ -688,7 +722,9 @@ impl<'a> SsaRebuilder<'a> { let keep = if has_canonical_origin { // Tier 1: origin + type matching component_members.iter().any(|&idx| { - let var = &self.ssa.variables[idx]; + let Some(var) = self.ssa.variables.get(idx) else { + return false; + }; match var.origin() { VariableOrigin::Argument(_) | VariableOrigin::Local(_) => { canonical_origin_types @@ -704,7 +740,7 @@ impl<'a> SsaRebuilder<'a> { let comp_type: Option = component_members .iter() .filter_map(|&idx| { - let t = self.ssa.variables[idx].var_type(); + let t = self.ssa.variables.get(idx)?.var_type(); if t.is_unknown() { None } else { @@ -726,10 +762,12 @@ impl<'a> SsaRebuilder<'a> { } let new_group = next_new_group; - next_new_group += 1; + next_new_group = next_new_group.saturating_add(1); for &idx in component_members { - let var_id = self.ssa.variables[idx].id(); - updates.push((var_id, new_group)); + let Some(var) = self.ssa.variables.get(idx) else { + continue; + }; + updates.push((var.id(), new_group)); } } } @@ -757,7 +795,7 @@ impl<'a> SsaRebuilder<'a> { .filter(|&g| g != u32::MAX) .max() .unwrap_or(next_new_group); - let mut next_split_group = max_group_so_far + 1; + let mut next_split_group = max_group_so_far.saturating_add(1); let mut split_updates: Vec<(SsaVarId, u32)> = Vec::new(); for block in &self.ssa.blocks { @@ -777,11 +815,11 @@ impl<'a> SsaRebuilder<'a> { if *already_seen { // Third+ definition — also split split_updates.push((dest, next_split_group)); - next_split_group += 1; + next_split_group = next_split_group.saturating_add(1); } else { // Second definition — split this one split_updates.push((dest, next_split_group)); - next_split_group += 1; + next_split_group = next_split_group.saturating_add(1); *already_seen = true; } } else { @@ -837,7 +875,7 @@ impl<'a> SsaRebuilder<'a> { .or_insert(VariableOrigin::Argument(i as u16)); } for i in 0..self.ssa.num_locals { - let group = self.ssa.num_args as u32 + i as u32; + let group = (self.ssa.num_args as u32).saturating_add(i as u32); self.group_origins .entry(group) .or_insert(VariableOrigin::Local(i as u16)); @@ -859,7 +897,7 @@ impl<'a> SsaRebuilder<'a> { // Update next_group to be above all existing groups let max_existing = self.ssa.rename_groups.iter().copied().max().unwrap_or(0); if max_existing != u32::MAX { - self.next_group = self.next_group.max(max_existing + 1); + self.next_group = self.next_group.max(max_existing.saturating_add(1)); } } @@ -871,7 +909,8 @@ impl<'a> SsaRebuilder<'a> { SsaOp::LoadLocal { dest, local_index } => { let dest_group = self.ssa.rename_group(*dest); if dest_group != u32::MAX && !self.group_types.contains_key(&dest_group) { - let local_group = self.ssa.num_args as u32 + *local_index as u32; + let local_group = + (self.ssa.num_args as u32).saturating_add(*local_index as u32); if let Some(local_type) = self.group_types.get(&local_group).cloned() { self.group_types.insert(dest_group, local_type); } @@ -996,7 +1035,7 @@ impl<'a> SsaRebuilder<'a> { .or_insert(VariableOrigin::Phi); if self.ssa.rename_group(var_id) == u32::MAX { let group = self.next_group; - self.next_group += 1; + self.next_group = self.next_group.saturating_add(1); self.ssa.set_rename_group(var_id, group); self.group_origins.insert(group, VariableOrigin::Phi); if let Some(inferred) = inferred_type { @@ -1023,8 +1062,10 @@ impl<'a> SsaRebuilder<'a> { // during rename. for block_idx in 0..self.ssa.blocks.len() { if !self.reachable.contains(block_idx) { - self.ssa.blocks[block_idx].instructions_mut().clear(); - self.ssa.blocks[block_idx].phi_nodes_mut().clear(); + if let Some(b) = self.ssa.blocks.get_mut(block_idx) { + b.instructions_mut().clear(); + b.phi_nodes_mut().clear(); + } } } @@ -1135,12 +1176,16 @@ impl<'a> SsaRebuilder<'a> { // Merge dominance frontiers for b in handler_reachable.iter() { - if b < local_df.len() { - if b >= self.dominance_frontiers.len() { - self.dominance_frontiers - .resize(b + 1, BitSet::new(block_count)); - } - self.dominance_frontiers[b].union_with(&local_df[b]); + let Some(local_b) = local_df.get(b) else { + continue; + }; + if b >= self.dominance_frontiers.len() { + let new_len = b.checked_add(1).unwrap_or(self.dominance_frontiers.len()); + self.dominance_frontiers + .resize(new_len, BitSet::new(block_count)); + } + if let Some(slot) = self.dominance_frontiers.get_mut(b) { + slot.union_with(local_b); } } } @@ -1156,7 +1201,7 @@ impl<'a> SsaRebuilder<'a> { // Only ORIGINAL .NET locals have default-initialization at entry. // `num_locals == original_num_locals` always now (no inflation). for i in 0..self.ssa.num_locals { - let group = self.ssa.num_args as u32 + i as u32; + let group = (self.ssa.num_args as u32).saturating_add(i as u32); self.defs.entry(group).or_default().insert(0); } @@ -1218,7 +1263,7 @@ impl<'a> SsaRebuilder<'a> { SsaOp::LoadLocal { dest, local_index } if consumed_vars.contains(dest.index()) => { - let group = self.ssa.num_args as u32 + *local_index as u32; + let group = (self.ssa.num_args as u32).saturating_add(*local_index as u32); use_sites .entry(group) .or_insert_with(|| BitSet::new(block_count)) @@ -1715,7 +1760,7 @@ impl<'a> SsaRebuilder<'a> { match origin { VariableOrigin::Argument(idx) => idx as u32, VariableOrigin::Local(idx) => { - ctx.num_args as u32 + idx as u32 + (ctx.num_args as u32).saturating_add(idx as u32) } VariableOrigin::Phi => u32::MAX, } @@ -1729,7 +1774,8 @@ impl<'a> SsaRebuilder<'a> { for (i, (group, origin, old_result)) in phi_info.iter().enumerate() { let version = *next_version.get(group).unwrap_or(&0); - *next_version.entry(*group).or_insert(0) += 1; + let entry = next_version.entry(*group).or_insert(0); + *entry = entry.saturating_add(1); let var_type = ctx .group_types @@ -1747,7 +1793,8 @@ impl<'a> SsaRebuilder<'a> { } version_stacks.entry(*group).or_default().push(new_var_id); - *pushed_counts.entry(*group).or_insert(0) += 1; + let pc = pushed_counts.entry(*group).or_insert(0); + *pc = pc.saturating_add(1); if *old_result != new_var_id { rename_map.insert(*old_result, new_var_id); @@ -1781,7 +1828,7 @@ impl<'a> SsaRebuilder<'a> { let load_target_group = match instr.op() { SsaOp::LoadArg { arg_index, .. } => Some(*arg_index as u32), SsaOp::LoadLocal { local_index, .. } => { - Some(ctx.num_args as u32 + *local_index as u32) + Some((ctx.num_args as u32).saturating_add(*local_index as u32)) } _ => None, }; @@ -1839,7 +1886,8 @@ impl<'a> SsaRebuilder<'a> { .entry(dest_group) .or_default() .push(reaching_def); - *pushed_counts.entry(dest_group).or_insert(0) += 1; + let pc = pushed_counts.entry(dest_group).or_insert(0); + *pc = pc.saturating_add(1); } // Convert to Nop since the value is the reaching definition if let Some(block) = ssa.block_mut(block_idx) { @@ -1856,7 +1904,8 @@ impl<'a> SsaRebuilder<'a> { if group != u32::MAX { if let Some(origin) = origin { let version = *next_version.get(&group).unwrap_or(&0); - *next_version.entry(group).or_insert(0) += 1; + let nv = next_version.entry(group).or_insert(0); + *nv = nv.saturating_add(1); // Use per-variable type first (preserves stack-derived local types), // fall back to per-group type @@ -1881,7 +1930,8 @@ impl<'a> SsaRebuilder<'a> { } version_stacks.entry(group).or_default().push(new_var_id); - *pushed_counts.entry(group).or_insert(0) += 1; + let pc = pushed_counts.entry(group).or_insert(0); + *pc = pc.saturating_add(1); if *old_dest != new_var_id { rename_map.insert(*old_dest, new_var_id); @@ -1926,7 +1976,7 @@ impl<'a> SsaRebuilder<'a> { match phi.origin() { VariableOrigin::Argument(idx) => idx as u32, VariableOrigin::Local(idx) => { - ctx.num_args as u32 + idx as u32 + (ctx.num_args as u32).saturating_add(idx as u32) } VariableOrigin::Phi => u32::MAX, } diff --git a/dotscope/src/analysis/ssa/function/transforms.rs b/dotscope/src/analysis/ssa/function/transforms.rs index d42694e8..9c382db8 100644 --- a/dotscope/src/analysis/ssa/function/transforms.rs +++ b/dotscope/src/analysis/ssa/function/transforms.rs @@ -102,7 +102,10 @@ impl SsaFunction { self.blocks .iter_mut() .map(|block| block.replace_uses(old_var, new_var)) - .fold(ReplaceResult::default(), |acc, r| acc + r) + .fold(ReplaceResult::default(), |acc, r| ReplaceResult { + replaced: acc.replaced.saturating_add(r.replaced), + skipped: acc.skipped.saturating_add(r.skipped), + }) } /// Replaces all uses of `old_var` with `new_var`, including in PHI operands. @@ -133,7 +136,10 @@ impl SsaFunction { self.blocks .iter_mut() .map(|block| block.replace_uses_including_phis(old_var, new_var)) - .fold(ReplaceResult::default(), |acc, r| acc + r) + .fold(ReplaceResult::default(), |acc, r| ReplaceResult { + replaced: acc.replaced.saturating_add(r.replaced), + skipped: acc.skipped.saturating_add(r.skipped), + }) } /// Replaces all uses of `old_var` with `new_var` within a specific block. @@ -173,7 +179,7 @@ impl SsaFunction { copies: &BTreeMap, ) -> CopyPropagationResult { let variable_count = self.var_id_capacity(); - let mut total_replaced = 0; + let mut total_replaced: usize = 0; let mut fully_propagated = BitSet::new(variable_count); let mut partially_propagated = BitSet::new(variable_count); @@ -190,7 +196,7 @@ impl SsaFunction { } else { partially_propagated.insert(dest.index()); } - total_replaced += result.replaced; + total_replaced = total_replaced.saturating_add(result.replaced); } } @@ -278,7 +284,7 @@ impl SsaFunction { } } - let mut pruned = 0; + let mut pruned: usize = 0; for block_idx in reachable.iter() { if let Some(block) = self.block_mut(block_idx) { @@ -314,7 +320,7 @@ impl SsaFunction { let mut keep_iter = to_keep.iter(); operands.retain(|_| *keep_iter.next().unwrap_or(&true)); - pruned += original_len - operands.len(); + pruned = pruned.saturating_add(original_len.saturating_sub(operands.len())); } } } @@ -339,7 +345,9 @@ impl SsaFunction { for use_var in instr.op().uses() { if let Some(var) = self.var_index(use_var) { let use_site = UseSite::instruction(block_idx, instr_idx); - self.variables[var].add_use(use_site); + if let Some(slot) = self.variables.get_mut(var) { + slot.add_use(use_site); + } } } } @@ -349,7 +357,9 @@ impl SsaFunction { for operand in phi.operands() { if let Some(var) = self.var_index(operand.value()) { let use_site = UseSite::phi_operand(block_idx, phi_idx); - self.variables[var].add_use(use_site); + if let Some(slot) = self.variables.get_mut(var) { + slot.add_use(use_site); + } } } } @@ -441,7 +451,7 @@ impl SsaFunction { /// /// The number of phis eliminated. pub fn eliminate_trivial_phis(&mut self, options: &TrivialPhiOptions) -> usize { - let mut total_eliminated = 0; + let mut total_eliminated: usize = 0; let block_count = self.blocks.len(); // Precompute reachability data if in reachable mode @@ -495,9 +505,11 @@ impl SsaFunction { .filter(|&v| v != result) .collect(); - if unique_sources.len() == 1 { - let source = *unique_sources.iter().next().unwrap(); - + if let Some(&source) = unique_sources + .iter() + .next() + .filter(|_| unique_sources.len() == 1) + { let is_self_ref = match (&var_def_block, options.reachable) { (Some(vdb), Some(reachable)) => self .would_create_self_reference_reachable( @@ -531,8 +543,11 @@ impl SsaFunction { .filter(|&v| v != result) .collect(); - if unique_reachable.len() == 1 { - let source = *unique_reachable.iter().next().unwrap(); + if let Some(&source) = unique_reachable + .iter() + .next() + .filter(|_| unique_reachable.len() == 1) + { let is_self_ref = match (&var_def_block, options.reachable) { (Some(vdb), Some(reachable)) => self .would_create_self_reference_reachable( @@ -597,7 +612,7 @@ impl SsaFunction { break; } - total_eliminated += trivial_set.count(); + total_eliminated = total_eliminated.saturating_add(trivial_set.count()); for block in &mut self.blocks { block.phi_nodes_mut().retain(|phi| { let idx = phi.result().index(); @@ -620,7 +635,7 @@ impl SsaFunction { for (result, _) in &trivial_phis { trivial_set.insert(result.index()); } - total_eliminated += trivial_set.count(); + total_eliminated = total_eliminated.saturating_add(trivial_set.count()); for block in &mut self.blocks { block.phi_nodes_mut().retain(|phi| { let idx = phi.result().index(); @@ -745,7 +760,7 @@ impl SsaFunction { let remap = self.reassign_dense_ids(); self.remap_var_ids_in_blocks(&remap); self.rebuild_origin_versions(); - original_count - self.variables.len() + original_count.saturating_sub(self.variables.len()) } /// Reassigns all variable IDs to dense contiguous indices (0..N-1) and @@ -772,8 +787,7 @@ impl SsaFunction { let mut remap: BTreeMap<(usize, usize), usize> = BTreeMap::new(); let mut nop_sites: BTreeSet<(usize, usize)> = BTreeSet::new(); - for block_idx in 0..self.blocks.len() { - let block = &mut self.blocks[block_idx]; + for (block_idx, block) in self.blocks.iter_mut().enumerate() { let instructions = block.instructions_mut(); if !instructions.iter().any(|i| matches!(i.op(), SsaOp::Nop)) { @@ -788,7 +802,7 @@ impl SsaFunction { if old_idx != new_idx { remap.insert((block_idx, old_idx), new_idx); } - new_idx += 1; + new_idx = new_idx.saturating_add(1); } } @@ -819,9 +833,11 @@ impl SsaFunction { for var in &mut self.variables { let site = var.def_site(); if let Some(instr_idx) = site.instruction { - if site.block >= block_instr_counts.len() - || instr_idx >= block_instr_counts[site.block] - { + let out_of_bounds = match block_instr_counts.get(site.block) { + Some(&count) => instr_idx >= count, + None => true, + }; + if out_of_bounds { var.set_def_site(DefSite::entry()); } } @@ -998,7 +1014,8 @@ impl SsaFunction { // Determine the actual range of local indices (may exceed num_locals // when stack-originated locals have indices >= original num_locals) - let max_local_idx = used_locals.iter().copied().max().unwrap_or(0) as usize + 1; + let max_local_idx = + (used_locals.iter().copied().max().unwrap_or(0) as usize).saturating_add(1); let remap_size = max_local_idx.max(self.num_locals); // If no optimization needed (all locals used or no locals), return identity mapping @@ -1015,7 +1032,9 @@ impl SsaFunction { for (new_idx, &old_idx) in sorted_used.iter().enumerate() { #[allow(clippy::cast_possible_truncation)] let new_idx_u16 = new_idx as u16; - remap[old_idx as usize] = Some(new_idx_u16); + if let Some(slot) = remap.get_mut(old_idx as usize) { + *slot = Some(new_idx_u16); + } } let new_num_locals = sorted_used.len(); @@ -1023,7 +1042,7 @@ impl SsaFunction { // Phase 3: Update all variable origins for var in &mut self.variables { if let VariableOrigin::Local(idx) = var.origin() { - if let Some(new_idx) = remap[idx as usize] { + if let Some(Some(new_idx)) = remap.get(idx as usize).copied() { var.set_origin(VariableOrigin::Local(new_idx)); } } @@ -1033,7 +1052,7 @@ impl SsaFunction { for block in &mut self.blocks { for phi in block.phi_nodes_mut() { if let VariableOrigin::Local(idx) = phi.origin() { - if let Some(new_idx) = remap[idx as usize] { + if let Some(Some(new_idx)) = remap.get(idx as usize).copied() { phi.set_origin(VariableOrigin::Local(new_idx)); } } @@ -1046,7 +1065,7 @@ impl SsaFunction { match instr.op_mut() { SsaOp::LoadLocal { local_index, .. } | SsaOp::LoadLocalAddr { local_index, .. } => { - if let Some(new_idx) = remap[*local_index as usize] { + if let Some(Some(new_idx)) = remap.get(*local_index as usize).copied() { *local_index = new_idx; } } @@ -1128,8 +1147,8 @@ impl SsaFunction { // First, populate with any provided temporary types (highest priority for temps) for (idx, typ) in temp_types { let idx = *idx as usize; - if idx < local_types.len() { - local_types[idx] = Some(typ.clone()); + if let Some(slot) = local_types.get_mut(idx) { + *slot = Some(typ.clone()); } } @@ -1137,10 +1156,12 @@ impl SsaFunction { for var in &self.variables { if let VariableOrigin::Local(idx) = var.origin() { let idx = idx as usize; - if idx < local_types.len() && local_types[idx].is_none() { - let var_type = var.var_type(); - if !var_type.is_unknown() { - local_types[idx] = Some(var_type.clone()); + if let Some(slot) = local_types.get_mut(idx) { + if slot.is_none() { + let var_type = var.var_type(); + if !var_type.is_unknown() { + *slot = Some(var_type.clone()); + } } } } @@ -1151,11 +1172,13 @@ impl SsaFunction { for phi in block.phi_nodes() { if let VariableOrigin::Local(idx) = phi.origin() { let idx = idx as usize; - if idx < local_types.len() && local_types[idx].is_none() { - if let Some(var) = self.variable(phi.result()) { - let var_type = var.var_type(); - if !var_type.is_unknown() { - local_types[idx] = Some(var_type.clone()); + if let Some(slot) = local_types.get_mut(idx) { + if slot.is_none() { + if let Some(var) = self.variable(phi.result()) { + let var_type = var.var_type(); + if !var_type.is_unknown() { + *slot = Some(var_type.clone()); + } } } } @@ -1262,7 +1285,7 @@ impl SsaFunction { } } - let needed = max_local_idx.map_or(0, |idx| idx as usize + 1); + let needed = max_local_idx.map_or(0, |idx| (idx as usize).saturating_add(1)); self.num_locals = needed.max(self.original_num_locals); } } diff --git a/dotscope/src/analysis/ssa/liveness.rs b/dotscope/src/analysis/ssa/liveness.rs index 6d1a9722..a5661af2 100644 --- a/dotscope/src/analysis/ssa/liveness.rs +++ b/dotscope/src/analysis/ssa/liveness.rs @@ -40,8 +40,8 @@ pub fn compute_live_in_blocks( let mut predecessors: Vec> = vec![Vec::new(); block_count]; for (block_idx, succs) in successors.iter().enumerate() { for &succ in succs { - if succ < block_count { - predecessors[succ].push(block_idx); + if let Some(preds) = predecessors.get_mut(succ) { + preds.push(block_idx); } } } @@ -78,7 +78,10 @@ pub fn compute_live_in_blocks( // a successor and the predecessor doesn't define it (or it's live-in to the // predecessor due to a direct use). while let Some(block_idx) = worklist.pop() { - for &pred in &predecessors[block_idx] { + let Some(preds) = predecessors.get(block_idx) else { + continue; + }; + for &pred in preds { // If predecessor defines the variable, liveness doesn't propagate further // (the definition satisfies the use). But the predecessor itself is NOT // live-in for this variable (the def originates here). diff --git a/dotscope/src/analysis/ssa/memory.rs b/dotscope/src/analysis/ssa/memory.rs index 5cfceb27..9861e7cd 100644 --- a/dotscope/src/analysis/ssa/memory.rs +++ b/dotscope/src/analysis/ssa/memory.rs @@ -501,7 +501,7 @@ impl MemorySsa { fn allocate_version(&mut self, location: &MemoryLocation) -> u32 { let version = self.next_version.entry(location.clone()).or_insert(0); let result = *version; - *version += 1; + *version = version.saturating_add(1); result } @@ -666,7 +666,10 @@ impl MemorySsa { continue; } - for frontier_block in frontiers[node_id.index()].iter() { + let Some(frontier_set) = frontiers.get(node_id.index()) else { + continue; + }; + for frontier_block in frontier_set.iter() { if phi_blocks.insert(frontier_block) { // Add phi at frontier let version = self.allocate_version(&location); @@ -720,16 +723,20 @@ impl MemorySsa { let mut worklist = vec![cfg.entry().index()]; while let Some(block_idx) = worklist.pop() { - if visited[block_idx] { - continue; + match visited.get(block_idx) { + Some(true) => continue, + None => continue, + Some(false) => {} + } + if let Some(slot) = visited.get_mut(block_idx) { + *slot = true; } - visited[block_idx] = true; self.rename_block(block_idx, ssa, cfg, &mut version_stacks); // Add dominated blocks to worklist for child in dom_tree.children(NodeId::new(block_idx)) { - if !visited[child.index()] { + if visited.get(child.index()).copied() == Some(false) { worklist.push(child.index()); } } @@ -969,12 +976,15 @@ pub fn analyze_alias(loc1: &MemoryLocation, loc2: &MemoryLocation) -> AliasResul mod tests { use super::*; - use crate::analysis::ssa::{FieldRef, SsaVarId}; + use crate::{ + analysis::ssa::{FieldRef, SsaVarId}, + metadata::token::Token, + }; #[test] fn test_memory_location_static_field_alias() { - let field1 = FieldRef::new(crate::metadata::token::Token::new(0x04000001)); - let field2 = FieldRef::new(crate::metadata::token::Token::new(0x04000002)); + let field1 = FieldRef::new(Token::new(0x04000001)); + let field2 = FieldRef::new(Token::new(0x04000002)); let loc1 = MemoryLocation::StaticField(field1); let loc2 = MemoryLocation::StaticField(field1); @@ -987,7 +997,7 @@ mod tests { #[test] fn test_memory_location_instance_field_alias() { - let field = FieldRef::new(crate::metadata::token::Token::new(0x04000001)); + let field = FieldRef::new(Token::new(0x04000001)); let obj1 = SsaVarId::from_index(0); let obj2 = SsaVarId::from_index(1); @@ -1030,7 +1040,7 @@ mod tests { #[test] fn test_memory_location_unknown_alias() { - let field = FieldRef::new(crate::metadata::token::Token::new(0x04000001)); + let field = FieldRef::new(Token::new(0x04000001)); let loc1 = MemoryLocation::Unknown; let loc2 = MemoryLocation::StaticField(field); @@ -1040,7 +1050,7 @@ mod tests { #[test] fn test_alias_result() { - let field = FieldRef::new(crate::metadata::token::Token::new(0x04000001)); + let field = FieldRef::new(Token::new(0x04000001)); let loc1 = MemoryLocation::StaticField(field); let loc2 = MemoryLocation::StaticField(field); @@ -1057,7 +1067,7 @@ mod tests { #[test] fn test_memory_state() { let mut state = MemoryState::new(); - let field = FieldRef::new(crate::metadata::token::Token::new(0x04000001)); + let field = FieldRef::new(Token::new(0x04000001)); let loc = MemoryLocation::StaticField(field); let value = SsaVarId::from_index(0); @@ -1072,7 +1082,7 @@ mod tests { #[test] fn test_memory_phi() { - let field = FieldRef::new(crate::metadata::token::Token::new(0x04000001)); + let field = FieldRef::new(Token::new(0x04000001)); let loc = MemoryLocation::StaticField(field); let mut phi = MemoryPhi::new(loc.clone(), 2); @@ -1088,7 +1098,7 @@ mod tests { #[test] fn test_memory_op() { - let field = FieldRef::new(crate::metadata::token::Token::new(0x04000001)); + let field = FieldRef::new(Token::new(0x04000001)); let loc = MemoryLocation::StaticField(field); let dest = SsaVarId::from_index(0); let value = SsaVarId::from_index(1); diff --git a/dotscope/src/analysis/ssa/ops.rs b/dotscope/src/analysis/ssa/ops.rs index f53c8e9d..4b04576f 100644 --- a/dotscope/src/analysis/ssa/ops.rs +++ b/dotscope/src/analysis/ssa/ops.rs @@ -1235,13 +1235,13 @@ impl SsaOp { /// /// The number of replacements made. pub fn replace_uses(&mut self, old_var: SsaVarId, new_var: SsaVarId) -> usize { - let mut count = 0; + let mut count: usize = 0; // Helper closure to replace a variable let mut replace = |var: &mut SsaVarId| { if *var == old_var { *var = new_var; - count += 1; + count = count.saturating_add(1); } }; @@ -2575,7 +2575,7 @@ impl SsaOp { // Indirect call pops args + function pointer // args.len() will never exceed u32 for CIL methods #[allow(clippy::cast_possible_truncation)] - let pops = args.len() as u32 + 1; + let pops = (args.len() as u32).saturating_add(1); let pushes = u32::from(dest.is_some()); (pops, pushes) } diff --git a/dotscope/src/analysis/ssa/patterns.rs b/dotscope/src/analysis/ssa/patterns.rs index 40cd7767..f6987868 100644 --- a/dotscope/src/analysis/ssa/patterns.rs +++ b/dotscope/src/analysis/ssa/patterns.rs @@ -159,8 +159,8 @@ impl<'a> PatternDetector<'a> { let block_count = self.ssa.block_count().max(1); let mut visited = BitSet::new(block_count); let mut queue = vec![from_block]; - let max_depth = 50; // Prevent infinite loops - let mut depth = 0; + let max_depth: u32 = 50; // Prevent infinite loops + let mut depth: u32 = 0; while !queue.is_empty() && depth < max_depth { let mut next_queue = Vec::new(); @@ -181,7 +181,7 @@ impl<'a> PatternDetector<'a> { } queue = next_queue; - depth += 1; + depth = depth.saturating_add(1); } false @@ -506,11 +506,7 @@ impl DispatcherPattern { /// Gets the target block for a specific case index. #[must_use] pub fn target_for_case(&self, case_idx: usize) -> usize { - if case_idx < self.targets.len() { - self.targets[case_idx] - } else { - self.default - } + self.targets.get(case_idx).copied().unwrap_or(self.default) } } diff --git a/dotscope/src/analysis/ssa/phis.rs b/dotscope/src/analysis/ssa/phis.rs index 4683ff4e..acc3fcc7 100644 --- a/dotscope/src/analysis/ssa/phis.rs +++ b/dotscope/src/analysis/ssa/phis.rs @@ -302,7 +302,7 @@ impl<'a> PhiAnalyzer<'a> { } // Get the first operand's constant value - let first_value = evaluator.evaluate_var(operands[0].value())?; + let first_value = evaluator.evaluate_var(operands.first()?.value())?; // Check that all other operands have the same value for operand in operands.iter().skip(1) { @@ -387,8 +387,8 @@ pub(crate) fn place_pruned_phis( while let Some(block_idx) = worklist.pop() { let node_id = NodeId::new(block_idx); - if node_id.index() < dominance_frontiers.len() { - for frontier_idx in dominance_frontiers[node_id.index()].iter() { + if let Some(frontier) = dominance_frontiers.get(node_id.index()) { + for frontier_idx in frontier.iter() { let is_reachable = reachable.is_none_or(|r| r.contains(frontier_idx)); if frontier_idx < block_count && is_reachable && phi_blocks.insert(frontier_idx) { diff --git a/dotscope/src/analysis/ssa/stack.rs b/dotscope/src/analysis/ssa/stack.rs index b56850e5..fb1ba933 100644 --- a/dotscope/src/analysis/ssa/stack.rs +++ b/dotscope/src/analysis/ssa/stack.rs @@ -259,18 +259,18 @@ impl StackSimulator { /// * `num_locals` - Number of local variables #[must_use] pub fn new(num_args: usize, num_locals: usize) -> Self { - let mut next_sim_id = SIMULATION_ID_OFFSET; + let mut next_sim_id: usize = SIMULATION_ID_OFFSET; let mut args = Vec::with_capacity(num_args); for _ in 0..num_args { args.push(VariableState::new(SsaVarId::from_index(next_sim_id))); - next_sim_id += 1; + next_sim_id = next_sim_id.saturating_add(1); } let mut locals = Vec::with_capacity(num_locals); for _ in 0..num_locals { locals.push(VariableState::new(SsaVarId::from_index(next_sim_id))); - next_sim_id += 1; + next_sim_id = next_sim_id.saturating_add(1); } Self { @@ -292,7 +292,7 @@ impl StackSimulator { /// These IDs are temporary placeholders replaced during the SSA rename phase. fn alloc_sim_id(&mut self) -> SsaVarId { let id = SsaVarId::from_index(self.next_sim_id); - self.next_sim_id += 1; + self.next_sim_id = self.next_sim_id.saturating_add(1); id } @@ -453,10 +453,10 @@ impl StackSimulator { let var = self.alloc_sim_id(); let depth = self.stack.len(); #[allow(clippy::cast_possible_truncation)] - let local_idx = self.num_locals as u16 + depth as u16; + let local_idx = (self.num_locals as u16).saturating_add(depth as u16); let origin = VariableOrigin::Local(local_idx); // Track max depth for total_stack_slots() - let depth_count = (depth + 1) as u32; + let depth_count = depth.saturating_add(1) as u32; if depth_count > self.next_stack_slot { self.next_stack_slot = depth_count; } @@ -583,8 +583,8 @@ impl StackSimulator { let new_var = self.alloc_sim_id(); - let state = &mut self.args[index]; - state.version += 1; + let state = self.args.get_mut(index)?; + state.version = state.version.saturating_add(1); state.current_var = new_var; // Return new_var as def to enable Copy op generation for constant propagation @@ -611,8 +611,8 @@ impl StackSimulator { } let new_var = self.alloc_sim_id(); - let state = &mut self.locals[index]; - state.version += 1; + let state = self.locals.get_mut(index)?; + state.version = state.version.saturating_add(1); state.current_var = new_var; // Return new_var as def to enable Copy op generation for constant propagation @@ -737,10 +737,10 @@ impl StackSimulator { } let mut result = Vec::with_capacity(count); - let start_idx = self.stack.len() - count; + let start_idx = self.stack.len().saturating_sub(count); for i in start_idx..self.stack.len() { - result.push(self.stack[i].var); + result.push(self.stack.get(i)?.var); } self.stack.truncate(start_idx); diff --git a/dotscope/src/analysis/ssa/symbolic/expr.rs b/dotscope/src/analysis/ssa/symbolic/expr.rs index 900179fc..816c7614 100644 --- a/dotscope/src/analysis/ssa/symbolic/expr.rs +++ b/dotscope/src/analysis/ssa/symbolic/expr.rs @@ -683,8 +683,10 @@ impl SymbolicExpr { pub fn depth(&self) -> usize { match self { Self::Constant(_) | Self::Variable(_) | Self::NamedVar(_) => 0, - Self::Unary { operand, .. } => 1 + operand.depth(), - Self::Binary { left, right, .. } => 1 + left.depth().max(right.depth()), + Self::Unary { operand, .. } => 1usize.saturating_add(operand.depth()), + Self::Binary { left, right, .. } => { + 1usize.saturating_add(left.depth().max(right.depth())) + } } } } diff --git a/dotscope/src/analysis/ssa/types.rs b/dotscope/src/analysis/ssa/types.rs index 671c03c1..ede4d449 100644 --- a/dotscope/src/analysis/ssa/types.rs +++ b/dotscope/src/analysis/ssa/types.rs @@ -1084,7 +1084,7 @@ impl<'a> TypeContext<'a> { .map_or(SsaType::Object, |dt| SsaType::Class(TypeRef::new(dt.token))); } // Adjust for 'this' offset - if let Some(param) = self.method.signature.params.get(idx - 1) { + if let Some(param) = self.method.signature.params.get(idx.saturating_sub(1)) { let ty = SsaType::from_type_signature(¶m.base, self.assembly); return self.validate_generic_params(ty); } diff --git a/dotscope/src/analysis/ssa/value.rs b/dotscope/src/analysis/ssa/value.rs index 2af37d05..93bc5c13 100644 --- a/dotscope/src/analysis/ssa/value.rs +++ b/dotscope/src/analysis/ssa/value.rs @@ -853,45 +853,44 @@ impl ConstValue { .map(|v| v.mask_native(ptr_size)) } - /// Attempts to divide two constants. + /// Attempts to divide two constants. Uses `checked_div`/`checked_rem` so + /// MIN/-1 overflows fold to `None` rather than wrapping silently. #[must_use] pub fn div(&self, other: &Self, ptr_size: PointerSize) -> Option { match (self, other) { - (Self::I8(a), Self::I8(b)) if *b != 0 => Some(Self::I8(a.wrapping_div(*b))), - (Self::I16(a), Self::I16(b)) if *b != 0 => Some(Self::I16(a.wrapping_div(*b))), - (Self::I32(a), Self::I32(b)) if *b != 0 => Some(Self::I32(a.wrapping_div(*b))), - (Self::I64(a), Self::I64(b)) if *b != 0 => Some(Self::I64(a.wrapping_div(*b))), - (Self::U8(a), Self::U8(b)) if *b != 0 => Some(Self::U8(a / b)), - (Self::U16(a), Self::U16(b)) if *b != 0 => Some(Self::U16(a / b)), - (Self::U32(a), Self::U32(b)) if *b != 0 => Some(Self::U32(a / b)), - (Self::U64(a), Self::U64(b)) if *b != 0 => Some(Self::U64(a / b)), - (Self::NativeInt(a), Self::NativeInt(b)) if *b != 0 => { - Some(Self::NativeInt(a.wrapping_div(*b))) - } - (Self::NativeUInt(a), Self::NativeUInt(b)) if *b != 0 => Some(Self::NativeUInt(a / b)), - (Self::F32(a), Self::F32(b)) => Some(Self::F32(a / b)), // Float div by zero is inf + (Self::I8(a), Self::I8(b)) => a.checked_div(*b).map(Self::I8), + (Self::I16(a), Self::I16(b)) => a.checked_div(*b).map(Self::I16), + (Self::I32(a), Self::I32(b)) => a.checked_div(*b).map(Self::I32), + (Self::I64(a), Self::I64(b)) => a.checked_div(*b).map(Self::I64), + (Self::U8(a), Self::U8(b)) => a.checked_div(*b).map(Self::U8), + (Self::U16(a), Self::U16(b)) => a.checked_div(*b).map(Self::U16), + (Self::U32(a), Self::U32(b)) => a.checked_div(*b).map(Self::U32), + (Self::U64(a), Self::U64(b)) => a.checked_div(*b).map(Self::U64), + (Self::NativeInt(a), Self::NativeInt(b)) => a.checked_div(*b).map(Self::NativeInt), + (Self::NativeUInt(a), Self::NativeUInt(b)) => a.checked_div(*b).map(Self::NativeUInt), + // Float div by zero is inf — IEEE 754 has no panic, no overflow. + (Self::F32(a), Self::F32(b)) => Some(Self::F32(a / b)), (Self::F64(a), Self::F64(b)) => Some(Self::F64(a / b)), _ => None, } .map(|v| v.mask_native(ptr_size)) } - /// Attempts to compute remainder (modulo) of two constants. + /// Attempts to compute remainder (modulo) of two constants. Uses + /// `checked_rem` so MIN%-1 overflows fold to `None`. #[must_use] pub fn rem(&self, other: &Self, ptr_size: PointerSize) -> Option { match (self, other) { - (Self::I8(a), Self::I8(b)) if *b != 0 => Some(Self::I8(a.wrapping_rem(*b))), - (Self::I16(a), Self::I16(b)) if *b != 0 => Some(Self::I16(a.wrapping_rem(*b))), - (Self::I32(a), Self::I32(b)) if *b != 0 => Some(Self::I32(a.wrapping_rem(*b))), - (Self::I64(a), Self::I64(b)) if *b != 0 => Some(Self::I64(a.wrapping_rem(*b))), - (Self::U8(a), Self::U8(b)) if *b != 0 => Some(Self::U8(a % b)), - (Self::U16(a), Self::U16(b)) if *b != 0 => Some(Self::U16(a % b)), - (Self::U32(a), Self::U32(b)) if *b != 0 => Some(Self::U32(a % b)), - (Self::U64(a), Self::U64(b)) if *b != 0 => Some(Self::U64(a % b)), - (Self::NativeInt(a), Self::NativeInt(b)) if *b != 0 => { - Some(Self::NativeInt(a.wrapping_rem(*b))) - } - (Self::NativeUInt(a), Self::NativeUInt(b)) if *b != 0 => Some(Self::NativeUInt(a % b)), + (Self::I8(a), Self::I8(b)) => a.checked_rem(*b).map(Self::I8), + (Self::I16(a), Self::I16(b)) => a.checked_rem(*b).map(Self::I16), + (Self::I32(a), Self::I32(b)) => a.checked_rem(*b).map(Self::I32), + (Self::I64(a), Self::I64(b)) => a.checked_rem(*b).map(Self::I64), + (Self::U8(a), Self::U8(b)) => a.checked_rem(*b).map(Self::U8), + (Self::U16(a), Self::U16(b)) => a.checked_rem(*b).map(Self::U16), + (Self::U32(a), Self::U32(b)) => a.checked_rem(*b).map(Self::U32), + (Self::U64(a), Self::U64(b)) => a.checked_rem(*b).map(Self::U64), + (Self::NativeInt(a), Self::NativeInt(b)) => a.checked_rem(*b).map(Self::NativeInt), + (Self::NativeUInt(a), Self::NativeUInt(b)) => a.checked_rem(*b).map(Self::NativeUInt), (Self::F32(a), Self::F32(b)) => Some(Self::F32(a % b)), (Self::F64(a), Self::F64(b)) => Some(Self::F64(a % b)), _ => None, @@ -1234,7 +1233,9 @@ impl fmt::Display for ConstValue { write!( f, "array[{}x{}]<0x{:08X}>", - data.len() / element_size.max(&1), + data.len() + .checked_div(*element_size.max(&1)) + .unwrap_or(data.len()), element_size, element_type_token.value() ) @@ -1264,7 +1265,7 @@ impl fmt::Display for ConstValue { /// /// # Errors /// -/// Returns [`Error::SsaError`] for non-numeric `ConstValue` variants +/// Returns [`crate::Error::SsaError`] for non-numeric `ConstValue` variants /// (`String`, `DecryptedString`, `Null`, `Type`, `MethodHandle`, `FieldHandle`) /// since these cannot be represented as immediate values. Handle these cases /// with pattern matching before conversion. @@ -1544,8 +1545,10 @@ impl ComputedValue { pub fn normalized(self) -> Self { if self.op.is_commutative() && self.operands.len() == 2 { let mut ops = self.operands; - if ops[0].index() > ops[1].index() { - ops.swap(0, 1); + if let (Some(a), Some(b)) = (ops.first(), ops.get(1)) { + if a.index() > b.index() { + ops.swap(0, 1); + } } Self { op: self.op, diff --git a/dotscope/src/analysis/ssa/variable.rs b/dotscope/src/analysis/ssa/variable.rs index f7dc0ce0..e80a46d1 100644 --- a/dotscope/src/analysis/ssa/variable.rs +++ b/dotscope/src/analysis/ssa/variable.rs @@ -240,7 +240,7 @@ impl FunctionVarAllocator { /// Allocates the next dense variable ID. pub fn alloc(&mut self) -> SsaVarId { let id = SsaVarId::from_index(self.next_id); - self.next_id += 1; + self.next_id = self.next_id.saturating_add(1); id } diff --git a/dotscope/src/analysis/ssa/verifier.rs b/dotscope/src/analysis/ssa/verifier.rs index 7b218f60..3e0e634c 100644 --- a/dotscope/src/analysis/ssa/verifier.rs +++ b/dotscope/src/analysis/ssa/verifier.rs @@ -371,7 +371,7 @@ impl<'a> SsaVerifier<'a> { .flat_map(|phi| phi.operands().iter().map(|op| op.predecessor())) .max() .unwrap_or(0); - let capacity = block_count.max(max_phi_pred + 1).max(1); + let capacity = block_count.max(max_phi_pred.saturating_add(1)).max(1); let mut preds = BitSet::new(capacity); for &p in pred_list { if p < capacity { @@ -463,8 +463,9 @@ impl<'a> SsaVerifier<'a> { .map(|v| v.id().index()) .max() .unwrap_or(0); - let capacity = (max_block_var + 1) - .max(max_reg_var + 1) + let capacity = max_block_var + .saturating_add(1) + .max(max_reg_var.saturating_add(1)) .max(variable_count) .max(1); let mut registered = BitSet::new(capacity); diff --git a/dotscope/src/analysis/taint.rs b/dotscope/src/analysis/taint.rs index 4fef5ebe..6138d07b 100644 --- a/dotscope/src/analysis/taint.rs +++ b/dotscope/src/analysis/taint.rs @@ -312,13 +312,13 @@ impl TaintAnalysis { /// /// * `ssa` - The SSA function to analyze. pub fn propagate(&mut self, ssa: &SsaFunction) { - let mut iterations = 0; + let mut iterations: usize = 0; loop { if iterations >= self.config.max_iterations { break; } - iterations += 1; + iterations = iterations.saturating_add(1); let mut changed = false; diff --git a/dotscope/src/analysis/x86/cfg.rs b/dotscope/src/analysis/x86/cfg.rs index c1dccfea..dc773bd1 100644 --- a/dotscope/src/analysis/x86/cfg.rs +++ b/dotscope/src/analysis/x86/cfg.rs @@ -112,30 +112,40 @@ impl X86Function { let mut leaders = BTreeSet::new(); // First instruction is always a leader - leaders.insert(instructions[0].offset); + if let Some(first) = instructions.first() { + leaders.insert(first.offset); + } // Find jump targets and fallthrough points for (i, instr) in instructions.iter().enumerate() { + let next_idx = i.saturating_add(1); + let next_offset = instructions.get(next_idx).map(|x| x.offset); match &instr.instruction { X86Instruction::Jmp { target } => { - // Target is a leader - leaders.insert(*target - base_address); + if let Some(rel) = target.checked_sub(base_address) { + leaders.insert(rel); + } } X86Instruction::Jcc { target, .. } => { - // Target is a leader - leaders.insert(*target - base_address); + if let Some(rel) = target.checked_sub(base_address) { + leaders.insert(rel); + } // Fallthrough is also a leader (instruction after this one) - if i + 1 < instructions.len() { - leaders.insert(instructions[i + 1].offset); + if let Some(off) = next_offset { + leaders.insert(off); } } // Instruction after call is a leader (call returns) - X86Instruction::Call { .. } if i + 1 < instructions.len() => { - leaders.insert(instructions[i + 1].offset); + X86Instruction::Call { .. } => { + if let Some(off) = next_offset { + leaders.insert(off); + } } // Instruction after ret is a leader (if any) - X86Instruction::Ret if i + 1 < instructions.len() => { - leaders.insert(instructions[i + 1].offset); + X86Instruction::Ret => { + if let Some(off) = next_offset { + leaders.insert(off); + } } _ => {} } @@ -160,7 +170,10 @@ impl X86Function { }; // Find end of this block (exclusive) - let end_offset = leader_list.get(block_idx + 1).copied().unwrap_or(u64::MAX); + let end_offset = leader_list + .get(block_idx.saturating_add(1)) + .copied() + .unwrap_or(u64::MAX); // Collect instructions for this block let mut block_instrs = Vec::new(); @@ -199,7 +212,7 @@ impl X86Function { // Step 5: Build the DirectedGraph let mut graph: DirectedGraph<'static, X86BasicBlock, X86EdgeKind> = - DirectedGraph::with_capacity(blocks.len(), blocks.len() * 2); + DirectedGraph::with_capacity(blocks.len(), blocks.len().saturating_mul(2)); // Add all blocks as nodes for block in blocks { @@ -324,23 +337,31 @@ impl X86Function { in_stack: &mut [bool], ) -> bool { let idx = node.index(); - visited[idx] = true; - in_stack[idx] = true; + if let Some(slot) = visited.get_mut(idx) { + *slot = true; + } + if let Some(slot) = in_stack.get_mut(idx) { + *slot = true; + } for succ in func.graph.successors(node) { let succ_idx = succ.index(); - if in_stack[succ_idx] { + if in_stack.get(succ_idx).copied().unwrap_or(false) { // This is a back edge (n -> succ where succ is on the stack) // For reducibility, succ must dominate node if !doms.dominates(succ, node) { return false; } - } else if !visited[succ_idx] && !dfs_check(succ, func, doms, visited, in_stack) { + } else if !visited.get(succ_idx).copied().unwrap_or(false) + && !dfs_check(succ, func, doms, visited, in_stack) + { return false; } } - in_stack[idx] = false; + if let Some(slot) = in_stack.get_mut(idx) { + *slot = false; + } true } @@ -403,22 +424,24 @@ fn compute_edges( if let Some(term) = block.terminator() { match term { X86Instruction::Jmp { target } => { - let target_offset = target - base_address; - if let Some(&block_idx) = offset_to_block.get(&target_offset) { - edges.push((block_idx, X86EdgeKind::Unconditional)); + if let Some(target_offset) = target.checked_sub(base_address) { + if let Some(&block_idx) = offset_to_block.get(&target_offset) { + edges.push((block_idx, X86EdgeKind::Unconditional)); + } } } X86Instruction::Jcc { target, condition } => { // Conditional jump: two edges // 1. Target (condition true) - let target_offset = target - base_address; - if let Some(&block_idx) = offset_to_block.get(&target_offset) { - edges.push(( - block_idx, - X86EdgeKind::ConditionalTrue { - condition: *condition, - }, - )); + if let Some(target_offset) = target.checked_sub(base_address) { + if let Some(&block_idx) = offset_to_block.get(&target_offset) { + edges.push(( + block_idx, + X86EdgeKind::ConditionalTrue { + condition: *condition, + }, + )); + } } // 2. Fallthrough (condition false) if let Some(&block_idx) = offset_to_block.get(&block.end_offset) { diff --git a/dotscope/src/analysis/x86/decoder.rs b/dotscope/src/analysis/x86/decoder.rs index a5a2354f..8dbc3a97 100644 --- a/dotscope/src/analysis/x86/decoder.rs +++ b/dotscope/src/analysis/x86/decoder.rs @@ -31,7 +31,7 @@ use std::collections::VecDeque; /// /// # Errors /// -/// Returns [`Error::X86Error`] if: +/// Returns [`crate::Error::X86Error`] if: /// - `bytes` is empty /// - `bitness` is not 32 or 64 /// - An invalid instruction is encountered @@ -55,7 +55,10 @@ pub fn x86_decode_all( let mut instructions = Vec::new(); for instr in &mut decoder { - let offset = instr.ip() - base_address; + let offset = instr + .ip() + .checked_sub(base_address) + .ok_or_else(|| Error::X86Error("Instruction IP below base address".into()))?; let length = instr.len(); // Check for invalid instruction @@ -150,7 +153,7 @@ pub struct X86TraversalDecodeResult { /// /// # Errors /// -/// Returns [`Error::X86Error`] if: +/// Returns [`crate::Error::X86Error`] if: /// - `bytes` is empty /// - `bitness` is not 32 or 64 /// @@ -172,7 +175,9 @@ pub fn x86_decode_traversal( } let code_start = base_address; - let code_end = base_address + bytes.len() as u64; + let code_end = base_address + .checked_add(bytes.len() as u64) + .ok_or_else(|| Error::X86Error("Code region end address overflow".into()))?; // Worklist of offsets to decode (relative to base_address) let mut worklist: VecDeque = VecDeque::new(); @@ -185,8 +190,11 @@ pub fn x86_decode_traversal( let mut has_indirect = false; // Start from entry point - worklist.push_back(base_address + entry_offset); - visited.insert(base_address + entry_offset); + let entry_addr = base_address + .checked_add(entry_offset) + .ok_or_else(|| Error::X86Error("Entry address overflow".into()))?; + worklist.push_back(entry_addr); + visited.insert(entry_addr); while let Some(addr) = worklist.pop_front() { // Check if address is within bounds @@ -195,8 +203,10 @@ pub fn x86_decode_traversal( } #[allow(clippy::cast_possible_truncation)] - let offset_in_bytes = (addr - base_address) as usize; - let remaining_bytes = &bytes[offset_in_bytes..]; + let offset_in_bytes = addr.saturating_sub(base_address) as usize; + let Some(remaining_bytes) = bytes.get(offset_in_bytes..) else { + continue; + }; if remaining_bytes.is_empty() { continue; @@ -211,16 +221,16 @@ pub fn x86_decode_traversal( continue; } - let offset = addr - base_address; + let offset = addr.saturating_sub(base_address); let length = instr.len(); // Check if we've already decoded an instruction that overlaps // (this can happen with certain obfuscation tricks) let overlaps = instructions.iter().any(|existing| { let existing_start = existing.offset; - let existing_end = existing.offset + existing.length as u64; + let existing_end = existing.offset.saturating_add(existing.length as u64); let new_start = offset; - let new_end = offset + length as u64; + let new_end = offset.saturating_add(length as u64); // Check for overlap new_start < existing_end && new_end > existing_start }); @@ -239,7 +249,7 @@ pub fn x86_decode_traversal( }; // Determine successors based on instruction type - let next_addr = addr + length as u64; + let next_addr = addr.saturating_add(length as u64); match &converted { X86Instruction::Ret => { @@ -355,7 +365,7 @@ pub fn x86_native_body_size(bytes: &[u8], is_64bit: bool) -> usize { Ok(result) => result .instructions .iter() - .map(|instr| instr.offset as usize + instr.length) + .map(|instr| (instr.offset as usize).saturating_add(instr.length)) .max() .unwrap_or(0), Err(_) => 0, @@ -380,7 +390,7 @@ pub fn x86_native_body_size(bytes: &[u8], is_64bit: bool) -> usize { /// /// # Errors /// -/// Returns [`Error::X86Error`] if: +/// Returns [`crate::Error::X86Error`] if: /// - `bytes` is empty /// - `bitness` is not 32 or 64 /// - `offset` is beyond the end of `bytes` @@ -408,13 +418,13 @@ pub fn x86_decode_single( ))); } - let remaining = &bytes[offset_in_bytes..]; - let mut decoder = Decoder::with_ip( - bitness, - remaining, - base_address + offset, - DecoderOptions::NONE, - ); + let remaining = bytes + .get(offset_in_bytes..) + .ok_or_else(|| Error::X86Error(format!("Invalid instruction at offset 0x{offset:x}")))?; + let ip = base_address + .checked_add(offset) + .ok_or_else(|| Error::X86Error("Instruction IP overflow".into()))?; + let mut decoder = Decoder::with_ip(bitness, remaining, ip, DecoderOptions::NONE); if let Some(instr) = decoder.iter().next() { if instr.is_invalid() { @@ -439,7 +449,10 @@ pub fn x86_decode_single( /// Convert an iced-x86 instruction to our simplified representation. fn convert_instruction(instr: &Instruction, base_address: u64) -> Result { - let offset = instr.ip() - base_address; + let offset = instr + .ip() + .checked_sub(base_address) + .ok_or_else(|| Error::X86Error("Instruction IP below base address".into()))?; match instr.mnemonic() { // Data movement @@ -1081,7 +1094,7 @@ pub fn x86_detect_prologue(bytes: &[u8], bitness: u32) -> X86PrologueInfo { }; } - if bytes.len() >= 20 && bytes[..20] == DYNCIPHER_PROLOGUE { + if bytes.get(..20) == Some(&DYNCIPHER_PROLOGUE[..]) { return X86PrologueInfo { kind: X86PrologueKind::DynCipher, size: 20, @@ -1089,9 +1102,11 @@ pub fn x86_detect_prologue(bytes: &[u8], bitness: u32) -> X86PrologueInfo { }; } + let starts_with = + |pat: &[u8]| -> bool { bytes.get(..pat.len()).is_some_and(|prefix| prefix == pat) }; + // Standard 32-bit prologue: push ebp; mov ebp, esp (MSVC) - if bitness == 32 && bytes.len() >= 3 && bytes[0] == 0x55 && bytes[1] == 0x8B && bytes[2] == 0xEC - { + if bitness == 32 && starts_with(&[0x55, 0x8B, 0xEC]) { return X86PrologueInfo { kind: X86PrologueKind::Standard32, size: 3, @@ -1100,8 +1115,7 @@ pub fn x86_detect_prologue(bytes: &[u8], bitness: u32) -> X86PrologueInfo { } // Standard 32-bit prologue: push ebp; mov ebp, esp (GCC) - if bitness == 32 && bytes.len() >= 3 && bytes[0] == 0x55 && bytes[1] == 0x89 && bytes[2] == 0xE5 - { + if bitness == 32 && starts_with(&[0x55, 0x89, 0xE5]) { return X86PrologueInfo { kind: X86PrologueKind::Standard32, size: 3, @@ -1110,13 +1124,7 @@ pub fn x86_detect_prologue(bytes: &[u8], bitness: u32) -> X86PrologueInfo { } // Standard 64-bit prologue: push rbp; mov rbp, rsp - if bitness == 64 - && bytes.len() >= 4 - && bytes[0] == 0x55 - && bytes[1] == 0x48 - && bytes[2] == 0x89 - && bytes[3] == 0xE5 - { + if bitness == 64 && starts_with(&[0x55, 0x48, 0x89, 0xE5]) { return X86PrologueInfo { kind: X86PrologueKind::Standard64, size: 4, @@ -1132,7 +1140,7 @@ pub fn x86_detect_prologue(bytes: &[u8], bitness: u32) -> X86PrologueInfo { }; for pattern in patterns { - if bytes.len() >= pattern.len() && bytes[..pattern.len()] == **pattern { + if starts_with(pattern) { return X86PrologueInfo { kind: X86PrologueKind::StackFrame { is_64bit: bitness == 64, @@ -1168,31 +1176,35 @@ pub fn x86_detect_epilogue(instructions: &[X86DecodedInstruction]) -> Option Option { let idx = reg.base_index() as usize; - if idx < MAX_REGISTERS { - self.registers[idx] - } else { - None - } + self.registers.get(idx).copied().flatten() } /// Sets the SSA variable for a register. fn set(&mut self, reg: X86Register, var: SsaVarId) { let idx = reg.base_index() as usize; - if idx < MAX_REGISTERS { - self.registers[idx] = Some(var); + if let Some(slot) = self.registers.get_mut(idx) { + *slot = Some(var); } } @@ -203,7 +199,7 @@ impl<'a> X86ToSsaTranslator<'a> { self.analyze_definitions(); // Step 2: Place phi nodes using dominance frontiers - self.place_phi_nodes(); + self.place_phi_nodes()?; // Step 3: Translate blocks in dominator tree order let ssa_blocks = self.translate_blocks()?; @@ -237,7 +233,9 @@ impl<'a> X86ToSsaTranslator<'a> { if let Some(block) = self.func.block(block_idx) { for instr in &block.instructions { if let Some(reg_idx) = get_defined_register(&instr.instruction) { - self.reg_def_blocks[reg_idx].insert(block_idx); + if let Some(set) = self.reg_def_blocks.get_mut(reg_idx) { + set.insert(block_idx); + } } } } @@ -245,16 +243,26 @@ impl<'a> X86ToSsaTranslator<'a> { } /// Places phi nodes at dominance frontiers. - fn place_phi_nodes(&mut self) { + fn place_phi_nodes(&mut self) -> Result<()> { let doms = self.func.dominators(); let block_count = self.func.block_count(); let bitness = self.func.bitness; - let register_count = self.block_exit_states[0].register_count(); + let register_count = self + .block_exit_states + .first() + .ok_or_else(|| Error::SsaError("place_phi_nodes: block_exit_states is empty".into()))? + .register_count(); // For each register that is defined somewhere for reg_idx in 0..register_count { // Clone the def_blocks to avoid borrow issues - let def_blocks: FxHashSet = self.reg_def_blocks[reg_idx].clone(); + let def_blocks: FxHashSet = self + .reg_def_blocks + .get(reg_idx) + .ok_or_else(|| { + Error::SsaError("place_phi_nodes: reg_def_blocks index out of bounds".into()) + })? + .clone(); if def_blocks.is_empty() { continue; } @@ -295,6 +303,7 @@ impl<'a> X86ToSsaTranslator<'a> { } } } + Ok(()) } /// Translates all blocks to SSA form. @@ -393,7 +402,12 @@ impl<'a> X86ToSsaTranslator<'a> { } // Save exit state for phi operand computation - self.block_exit_states[block_idx] = reg_state; + let slot = self.block_exit_states.get_mut(block_idx).ok_or_else(|| { + Error::SsaError(format!( + "translate_block: block_exit_states out of bounds at {block_idx}" + )) + })?; + *slot = reg_state; Ok(ssa_block) } @@ -421,12 +435,18 @@ impl<'a> X86ToSsaTranslator<'a> { let preds: Vec<_> = self.func.predecessors(node).collect(); if preds.len() == 1 { - let pred_idx = preds[0].index(); - // Copy predecessor's exit state for registers without phi nodes - for reg_idx in 0..state.register_count() { - if !self.phi_placement.has(block_idx, reg_idx) { - if let Some(var) = self.block_exit_states[pred_idx].registers[reg_idx] { - state.registers[reg_idx] = Some(var); + if let Some(pred) = preds.first() { + let pred_idx = pred.index(); + // Copy predecessor's exit state for registers without phi nodes + if let Some(pred_state) = self.block_exit_states.get(pred_idx) { + for reg_idx in 0..state.register_count() { + if !self.phi_placement.has(block_idx, reg_idx) { + if let Some(Some(var)) = pred_state.registers.get(reg_idx) { + if let Some(slot) = state.registers.get_mut(reg_idx) { + *slot = Some(*var); + } + } + } } } } @@ -449,8 +469,10 @@ impl<'a> X86ToSsaTranslator<'a> { // Add operand from each predecessor for pred in self.func.predecessors(node) { let pred_idx = pred.index(); - if let Some(var) = self.block_exit_states[pred_idx].registers[reg_idx] { - phi.add_operand(PhiOperand::new(var, pred_idx)); + if let Some(pred_state) = self.block_exit_states.get(pred_idx) { + if let Some(Some(var)) = pred_state.registers.get(reg_idx) { + phi.add_operand(PhiOperand::new(*var, pred_idx)); + } } } } @@ -1355,7 +1377,7 @@ impl<'a> X86ToSsaTranslator<'a> { X86Instruction::Jcc { condition, target } => { let target_block = self.find_block_for_address(*target)?; - let fallthrough_block = block_idx + 1; // Assumes sequential layout + let fallthrough_block = block_idx.saturating_add(1); // Assumes sequential layout // Get comparison operands from flags if let Some((cmp, left, right, unsigned)) = flags.get_branch_operands(*condition) { @@ -2311,7 +2333,12 @@ impl<'a> X86ToSsaTranslator<'a> { /// Finds the block index for a given address. fn find_block_for_address(&self, addr: u64) -> Result { - let offset = addr - self.func.base_address; + let offset = addr.checked_sub(self.func.base_address).ok_or_else(|| { + Error::X86Error(format!( + "Address 0x{addr:x} is below base 0x{:x}", + self.func.base_address + )) + })?; for node_id in self.func.node_ids() { let idx = node_id.index(); diff --git a/dotscope/src/analysis/x86/types.rs b/dotscope/src/analysis/x86/types.rs index d08a1a5b..4876d144 100644 --- a/dotscope/src/analysis/x86/types.rs +++ b/dotscope/src/analysis/x86/types.rs @@ -963,7 +963,7 @@ impl X86DecodedInstruction { #[inline] #[must_use] pub fn end_offset(&self) -> u64 { - self.offset + self.length as u64 + self.offset.saturating_add(self.length as u64) } } diff --git a/dotscope/src/assembly/builder.rs b/dotscope/src/assembly/builder.rs index b48b4f42..4e4daf17 100644 --- a/dotscope/src/assembly/builder.rs +++ b/dotscope/src/assembly/builder.rs @@ -339,9 +339,13 @@ impl InstructionAssembler { handlers.push(ExceptionHandler { flags, try_offset: *try_start, - try_length: try_end - try_start, + try_length: try_end + .checked_sub(*try_start) + .ok_or_else(|| malformed_error!("try region length underflow"))?, handler_offset: *handler_start, - handler_length: handler_end - handler_start, + handler_length: handler_end + .checked_sub(*handler_start) + .ok_or_else(|| malformed_error!("handler region length underflow"))?, handler: class_token, filter_offset, }); @@ -493,20 +497,19 @@ impl InstructionAssembler { /// by indexing into the static `INSTRUCTIONS` / `INSTRUCTIONS_FE` tables. fn lookup_mnemonic(opcode: u16) -> Result<&'static str> { if opcode < u16::from(INSTRUCTIONS_MAX) { - let entry = &INSTRUCTIONS[opcode as usize]; + let entry = INSTRUCTIONS + .get(opcode as usize) + .ok_or_else(|| malformed_error!("opcode 0x{:04X} out of range", opcode))?; if entry.instr.is_empty() { return Err(malformed_error!("unused opcode slot 0x{:02X}", opcode)); } Ok(entry.instr) } else if opcode >= 0xFE00 { let sub = (opcode & 0xFF) as usize; - if sub >= usize::from(INSTRUCTIONS_FE_MAX) { - return Err(malformed_error!( - "extended opcode 0xFE{:02X} out of range", - sub - )); - } - let entry = &INSTRUCTIONS_FE[sub]; + let entry = INSTRUCTIONS_FE + .get(sub) + .filter(|_| sub < usize::from(INSTRUCTIONS_FE_MAX)) + .ok_or_else(|| malformed_error!("extended opcode 0xFE{:02X} out of range", sub))?; if entry.instr.is_empty() { return Err(malformed_error!("unused extended opcode 0xFE{:02X}", sub)); } @@ -522,13 +525,19 @@ impl InstructionAssembler { #[must_use] pub fn is_branch_opcode(opcode: u16) -> bool { let flow = if opcode < u16::from(INSTRUCTIONS_MAX) { - INSTRUCTIONS[opcode as usize].flow + let Some(entry) = INSTRUCTIONS.get(opcode as usize) else { + return false; + }; + entry.flow } else if opcode >= 0xFE00 { let sub = (opcode & 0xFF) as usize; if sub >= usize::from(INSTRUCTIONS_FE_MAX) { return false; } - INSTRUCTIONS_FE[sub].flow + let Some(entry) = INSTRUCTIONS_FE.get(sub) else { + return false; + }; + entry.flow } else { return false; }; @@ -544,13 +553,19 @@ impl InstructionAssembler { #[must_use] pub fn is_token_opcode(opcode: u16) -> bool { let op_type = if opcode < u16::from(INSTRUCTIONS_MAX) { - INSTRUCTIONS[opcode as usize].op_type + let Some(entry) = INSTRUCTIONS.get(opcode as usize) else { + return false; + }; + entry.op_type } else if opcode >= 0xFE00 { let sub = (opcode & 0xFF) as usize; if sub >= usize::from(INSTRUCTIONS_FE_MAX) { return false; } - INSTRUCTIONS_FE[sub].op_type + let Some(entry) = INSTRUCTIONS_FE.get(sub) else { + return false; + }; + entry.op_type } else { return false; }; @@ -589,7 +604,7 @@ impl InstructionAssembler { /// Generate a unique internal label name. fn generate_label(&mut self, prefix: &str) -> String { - self.label_counter += 1; + self.label_counter = self.label_counter.saturating_add(1); format!("__{}_{}", prefix, self.label_counter) } @@ -798,7 +813,9 @@ impl InstructionAssembler { // Update handler (mutable borrow) if let Some(try_block) = self.try_blocks.get_mut(try_name) { - try_block.handlers[handler_idx].end_label = Some(end_label); + if let Some(h) = try_block.handlers.get_mut(handler_idx) { + h.end_label = Some(end_label); + } } Ok(self) @@ -909,7 +926,9 @@ impl InstructionAssembler { // Update handler (mutable borrow) if let Some(try_block) = self.try_blocks.get_mut(try_name) { - try_block.handlers[handler_idx].end_label = Some(end_label); + if let Some(h) = try_block.handlers.get_mut(handler_idx) { + h.end_label = Some(end_label); + } } Ok(self) @@ -1019,7 +1038,9 @@ impl InstructionAssembler { // Update handler (mutable borrow) if let Some(try_block) = self.try_blocks.get_mut(try_name) { - try_block.handlers[handler_idx].end_label = Some(end_label); + if let Some(h) = try_block.handlers.get_mut(handler_idx) { + h.end_label = Some(end_label); + } } Ok(self) @@ -1149,7 +1170,9 @@ impl InstructionAssembler { // Update handler (mutable borrow) if let Some(try_block) = self.try_blocks.get_mut(try_name) { - try_block.handlers[handler_idx].start_label = start_label; + if let Some(h) = try_block.handlers.get_mut(handler_idx) { + h.start_label = start_label; + } } Ok(self) @@ -1194,7 +1217,9 @@ impl InstructionAssembler { // Update handler (mutable borrow) if let Some(try_block) = self.try_blocks.get_mut(try_name) { - try_block.handlers[handler_idx].end_label = Some(end_label); + if let Some(h) = try_block.handlers.get_mut(handler_idx) { + h.end_label = Some(end_label); + } } Ok(self) diff --git a/dotscope/src/assembly/decoder.rs b/dotscope/src/assembly/decoder.rs index 4256134f..6b26d3e8 100644 --- a/dotscope/src/assembly/decoder.rs +++ b/dotscope/src/assembly/decoder.rs @@ -253,56 +253,47 @@ impl<'a> Decoder<'a> { // Create blocks for exception handler entry points // These must be created explicitly as they may not be reachable via normal control flow if let Some(exceptions) = self.exceptions { + // Collect entry-point candidates first (handler/filter/try) to avoid + // borrow conflicts between iterating `self.exceptions` and mutating self. + let rva_base = self.rva_start as u64; + let mut candidates: Vec<(u64, u32)> = + Vec::with_capacity(exceptions.len().saturating_mul(3)); for handler in exceptions { - // Handler entry block (catch/finally/fault) - let handler_rva = self.rva_start as u64 + u64::from(handler.handler_offset); - if !entry_points.contains(&handler_rva) { - let handler_offset = self.offset_start + handler.handler_offset as usize; - if handler_offset < self.parser.len() && !self.visited.get(handler_offset) { - self.blocks.push(BasicBlock::new( - self.blocks.len(), - handler_rva, - handler_offset, - )); - entry_points.insert(handler_rva); - } - } - - // Filter entry block (for filter handlers) + let handler_rva = rva_base + .checked_add(u64::from(handler.handler_offset)) + .ok_or(out_of_bounds_error!())?; + candidates.push((handler_rva, handler.handler_offset)); if handler.filter_offset > 0 { - let filter_rva = self.rva_start as u64 + u64::from(handler.filter_offset); - if !entry_points.contains(&filter_rva) { - let filter_offset = self.offset_start + handler.filter_offset as usize; - if filter_offset < self.parser.len() && !self.visited.get(filter_offset) { - self.blocks.push(BasicBlock::new( - self.blocks.len(), - filter_rva, - filter_offset, - )); - entry_points.insert(filter_rva); - } - } + let filter_rva = rva_base + .checked_add(u64::from(handler.filter_offset)) + .ok_or(out_of_bounds_error!())?; + candidates.push((filter_rva, handler.filter_offset)); } + let try_rva = rva_base + .checked_add(u64::from(handler.try_offset)) + .ok_or(out_of_bounds_error!())?; + candidates.push((try_rva, handler.try_offset)); + } - // Try region entry block - // This must be created explicitly when try_offset > 0, otherwise the - // block starting at method entry will stop at this entry point but there - // will be no block to continue decoding the try region content. - let try_rva = self.rva_start as u64 + u64::from(handler.try_offset); - if !entry_points.contains(&try_rva) { - let try_offset = self.offset_start + handler.try_offset as usize; - if try_offset < self.parser.len() && !self.visited.get(try_offset) { - self.blocks - .push(BasicBlock::new(self.blocks.len(), try_rva, try_offset)); - entry_points.insert(try_rva); - } + for (entry_rva, entry_offset_u32) in candidates { + if entry_points.contains(&entry_rva) { + continue; + } + let entry_offset = self + .offset_start + .checked_add(entry_offset_u32 as usize) + .ok_or(out_of_bounds_error!())?; + if entry_offset < self.parser.len() && !self.visited.get(entry_offset) { + self.blocks + .push(BasicBlock::new(self.blocks.len(), entry_rva, entry_offset)); + entry_points.insert(entry_rva); } } } while self.block_id < self.blocks.len() { self.decode_single_block(&mut entry_points)?; - self.block_id += 1; + self.block_id = self.block_id.checked_add(1).ok_or(out_of_bounds_error!())?; } self.blocks.retain(|b| !b.instructions.is_empty()); @@ -312,9 +303,9 @@ impl<'a> Decoder<'a> { block.id = idx; } - self.process_exception_handlers(); + self.process_exception_handlers()?; self.wire_control_flow_edges(); - self.wire_exception_edges(); + self.wire_exception_edges()?; Ok(()) } @@ -354,25 +345,30 @@ impl<'a> Decoder<'a> { fn decode_single_block(&mut self, entry_points: &mut HashSet) -> Result<()> { let block_id = self.block_id; - if self.blocks[block_id].offset > self.parser.len() { + let (block_offset, block_rva) = { + let block = self.blocks.get(block_id).ok_or(out_of_bounds_error!())?; + (block.offset, block.rva) + }; + + if block_offset > self.parser.len() { return Err(out_of_bounds_error!()); } - if self.visited.get(self.blocks[block_id].offset) { + if self.visited.get(block_offset) { return Ok(()); } - self.parser.seek(self.blocks[block_id].offset)?; + self.parser.seek(block_offset)?; - let mut current_offset = self.blocks[block_id].offset; - let mut current_rva = self.blocks[block_id].rva; + let mut current_offset = block_offset; + let mut current_rva = block_rva; loop { if current_offset >= self.parser.len() { break; } - if current_rva != self.blocks[block_id].rva && entry_points.contains(¤t_rva) { + if current_rva != block_rva && entry_points.contains(¤t_rva) { // We've reached the start of another block - stop here break; } @@ -387,8 +383,17 @@ impl<'a> Decoder<'a> { self.visited.set_range(current_offset, true, instr_size); - self.blocks[block_id].size += instr_size; - self.blocks[block_id].instructions.push(instruction.clone()); + { + let block = self + .blocks + .get_mut(block_id) + .ok_or(out_of_bounds_error!())?; + block.size = block + .size + .checked_add(instr_size) + .ok_or_else(|| malformed_error!("block size overflow"))?; + block.instructions.push(instruction.clone()); + } match instruction.flow_type { FlowType::ConditionalBranch => { @@ -396,7 +401,9 @@ impl<'a> Decoder<'a> { self.add_entry_point(target_rva, entry_points); } - let fall_through_rva = current_rva + instruction.size; + let fall_through_rva = current_rva + .checked_add(instruction.size) + .ok_or_else(|| malformed_error!("fall-through RVA overflow"))?; self.add_entry_point(fall_through_rva, entry_points); break; @@ -413,7 +420,9 @@ impl<'a> Decoder<'a> { self.add_entry_point(target_rva, entry_points); } // Add fall-through as entry point for the default case - let fall_through_rva = current_rva + instruction.size; + let fall_through_rva = current_rva + .checked_add(instruction.size) + .ok_or_else(|| malformed_error!("fall-through RVA overflow"))?; self.add_entry_point(fall_through_rva, entry_points); break; } @@ -425,8 +434,12 @@ impl<'a> Decoder<'a> { } } - current_offset += instr_size; - current_rva += instruction.size; + current_offset = current_offset + .checked_add(instr_size) + .ok_or_else(|| malformed_error!("instruction offset overflow"))?; + current_rva = current_rva + .checked_add(instruction.size) + .ok_or_else(|| malformed_error!("instruction RVA overflow"))?; } Ok(()) @@ -467,10 +480,15 @@ impl<'a> Decoder<'a> { return; } - let Ok(relative_offset) = usize::try_from(rva - self.rva_start as u64) else { + let Some(delta) = rva.checked_sub(self.rva_start as u64) else { + return; + }; + let Ok(relative_offset) = usize::try_from(delta) else { return; // RVA delta too large for this platform }; - let offset = self.offset_start + relative_offset; + let Some(offset) = self.offset_start.checked_add(relative_offset) else { + return; + }; if offset >= self.parser.len() { return; } @@ -508,7 +526,7 @@ impl<'a> Decoder<'a> { return None; } - let block_end_rva = block.rva + block.size as u64; + let block_end_rva = block.rva.checked_add(block.size as u64)?; if rva > block.rva && rva < block_end_rva { for (instr_idx, instr) in block.instructions.iter().enumerate() { if instr.rva == rva { @@ -557,18 +575,22 @@ impl<'a> Decoder<'a> { // Create new block with instructions from split point onwards let mut new_block = BasicBlock::new(self.blocks.len(), rva, offset); - new_block.instructions = self.blocks[block_idx].instructions[split_instr_idx..].to_vec(); + let Some(orig) = self.blocks.get(block_idx) else { + return; + }; + let Some(tail) = orig.instructions.get(split_instr_idx..) else { + return; + }; + new_block.instructions = tail.to_vec(); new_block.size = Self::compute_instructions_size(&new_block.instructions); - new_block - .exceptions - .clone_from(&self.blocks[block_idx].exceptions); + new_block.exceptions.clone_from(&orig.exceptions); // Truncate the original block - self.blocks[block_idx] - .instructions - .truncate(split_instr_idx); - self.blocks[block_idx].size = - Self::compute_instructions_size(&self.blocks[block_idx].instructions); + let Some(orig_mut) = self.blocks.get_mut(block_idx) else { + return; + }; + orig_mut.instructions.truncate(split_instr_idx); + orig_mut.size = Self::compute_instructions_size(&orig_mut.instructions); self.blocks.push(new_block); } @@ -618,9 +640,9 @@ impl<'a> Decoder<'a> { /// This method should be called after all blocks have been decoded and before /// control flow edges are wired, as the exception associations may affect /// control flow analysis. - fn process_exception_handlers(&mut self) { + fn process_exception_handlers(&mut self) -> Result<()> { let Some(exceptions) = self.exceptions else { - return; + return Ok(()); }; // Build a map from RVA to block index for handler entry detection @@ -636,8 +658,12 @@ impl<'a> Decoder<'a> { let base_rva = self.rva_start as u64; for (handler_idx, handler) in exceptions.iter().enumerate() { - let try_start = base_rva + u64::from(handler.try_offset); - let try_end = try_start + u64::from(handler.try_length); + let try_start = base_rva + .checked_add(u64::from(handler.try_offset)) + .ok_or_else(|| malformed_error!("try_offset RVA overflow"))?; + let try_end = try_start + .checked_add(u64::from(handler.try_length)) + .ok_or_else(|| malformed_error!("try region end RVA overflow"))?; // Mark blocks in the try region for block in &mut self.blocks { @@ -647,24 +673,31 @@ impl<'a> Decoder<'a> { } // Mark handler entry block - let handler_rva = base_rva + u64::from(handler.handler_offset); + let handler_rva = base_rva + .checked_add(u64::from(handler.handler_offset)) + .ok_or_else(|| malformed_error!("handler_offset RVA overflow"))?; if let Some(&handler_block_idx) = rva_to_block.get(&handler_rva) { - self.blocks[handler_block_idx].handler_entry = - Some(HandlerEntryInfo::new(handler_idx, handler.flags)); + if let Some(b) = self.blocks.get_mut(handler_block_idx) { + b.handler_entry = Some(HandlerEntryInfo::new(handler_idx, handler.flags)); + } } // Mark filter entry block (for filter handlers) if handler.flags == ExceptionHandlerFlags::FILTER && handler.filter_offset > 0 { - let filter_rva = base_rva + u64::from(handler.filter_offset); + let filter_rva = base_rva + .checked_add(u64::from(handler.filter_offset)) + .ok_or_else(|| malformed_error!("filter_offset RVA overflow"))?; if let Some(&filter_block_idx) = rva_to_block.get(&filter_rva) { - // Filter blocks also receive the exception object - self.blocks[filter_block_idx].handler_entry = Some(HandlerEntryInfo::new( - handler_idx, - ExceptionHandlerFlags::FILTER, - )); + if let Some(b) = self.blocks.get_mut(filter_block_idx) { + b.handler_entry = Some(HandlerEntryInfo::new( + handler_idx, + ExceptionHandlerFlags::FILTER, + )); + } } } } + Ok(()) } /// Wire exception edges from protected blocks to their handler blocks. @@ -679,9 +712,9 @@ impl<'a> Decoder<'a> { /// instruction in a protected region can potentially transfer control to /// the handler. For analysis purposes, we model this as an edge from the /// block to the handler entry block. - fn wire_exception_edges(&mut self) { + fn wire_exception_edges(&mut self) -> Result<()> { let Some(exceptions) = self.exceptions else { - return; + return Ok(()); }; // Build a map from RVA to block index @@ -698,13 +731,19 @@ impl<'a> Decoder<'a> { // For each handler, wire edges from protected blocks to handler blocks for handler in exceptions { - let handler_rva = base_rva + u64::from(handler.handler_offset); + let handler_rva = base_rva + .checked_add(u64::from(handler.handler_offset)) + .ok_or_else(|| malformed_error!("handler_offset RVA overflow"))?; let Some(&handler_block_idx) = rva_to_block.get(&handler_rva) else { continue; }; - let try_start = base_rva + u64::from(handler.try_offset); - let try_end = try_start + u64::from(handler.try_length); + let try_start = base_rva + .checked_add(u64::from(handler.try_offset)) + .ok_or_else(|| malformed_error!("try_offset RVA overflow"))?; + let try_end = try_start + .checked_add(u64::from(handler.try_length)) + .ok_or_else(|| malformed_error!("try region end RVA overflow"))?; for block in &mut self.blocks { if block.rva >= try_start && block.rva < try_end { @@ -715,6 +754,7 @@ impl<'a> Decoder<'a> { } } } + Ok(()) } /// Wire up control flow edges (successors/predecessors) between blocks. @@ -749,11 +789,13 @@ impl<'a> Decoder<'a> { for block_idx in 0..self.blocks.len() { let successors = self.compute_block_successors(block_idx, &rva_to_block); - self.blocks[block_idx].successors.clone_from(&successors); + if let Some(b) = self.blocks.get_mut(block_idx) { + b.successors.clone_from(&successors); + } for &succ_idx in &successors { - if succ_idx < self.blocks.len() { - self.blocks[succ_idx].predecessors.push(block_idx); + if let Some(b) = self.blocks.get_mut(succ_idx) { + b.predecessors.push(block_idx); } } } @@ -791,10 +833,13 @@ impl<'a> Decoder<'a> { block_idx: usize, rva_to_block: &HashMap, ) -> Vec { - let block = &self.blocks[block_idx]; + let Some(block) = self.blocks.get(block_idx) else { + return vec![]; + }; let Some(last_instr) = block.instructions.last() else { return vec![]; }; + let fall_through_rva = block.rva.checked_add(block.size as u64); match last_instr.flow_type { FlowType::Return | FlowType::Throw => { @@ -821,9 +866,10 @@ impl<'a> Decoder<'a> { } // Add fall-through target (instruction immediately after this block) - let fall_through_rva = block.rva + block.size as u64; - if let Some(&fall_through_idx) = rva_to_block.get(&fall_through_rva) { - successors.push(fall_through_idx); + if let Some(rva) = fall_through_rva { + if let Some(&fall_through_idx) = rva_to_block.get(&rva) { + successors.push(fall_through_idx); + } } successors @@ -837,9 +883,10 @@ impl<'a> Decoder<'a> { .collect(); // Add fall-through as the default (last successor) - let fall_through_rva = block.rva + block.size as u64; - if let Some(&fall_through_idx) = rva_to_block.get(&fall_through_rva) { - successors.push(fall_through_idx); + if let Some(rva) = fall_through_rva { + if let Some(&fall_through_idx) = rva_to_block.get(&rva) { + successors.push(fall_through_idx); + } } successors @@ -858,10 +905,9 @@ impl<'a> Decoder<'a> { } FlowType::Sequential | FlowType::Call => { // Fall through to next block - let fall_through_rva = block.rva + block.size as u64; - rva_to_block - .get(&fall_through_rva) - .map(|&idx| vec![idx]) + fall_through_rva + .and_then(|rva| rva_to_block.get(&rva).copied()) + .map(|idx| vec![idx]) .unwrap_or_default() } } @@ -949,10 +995,13 @@ pub(crate) fn decode_method( } let mut parser = Parser::new(file.data()); + let rva_start = rva + .checked_add(body.size_header) + .ok_or_else(|| malformed_error!("rva + size_header overflow"))?; let mut decoder = Decoder::new( &mut parser, code_start, - rva + body.size_header, + rva_start, Some(&body.exception_handlers), shared_visited, )?; @@ -1036,9 +1085,9 @@ pub fn decode_blocks( let effective_data = if let Some(size) = max_size { let end_offset = offset.saturating_add(size).min(data.len()); - &data[offset..end_offset] + data.get(offset..end_offset).ok_or(out_of_bounds_error!())? } else { - &data[offset..] + data.get(offset..).ok_or(out_of_bounds_error!())? }; let mut parser = Parser::new(effective_data); @@ -1123,7 +1172,14 @@ pub fn decode_stream(parser: &mut Parser, rva: u64) -> Result> instructions.push(instruction); - current_rva += (parser.pos() - current_offset) as u64; + let consumed = parser + .pos() + .checked_sub(current_offset) + .ok_or_else(|| malformed_error!("parser position regressed during decode"))? + as u64; + current_rva = current_rva + .checked_add(consumed) + .ok_or_else(|| malformed_error!("instruction stream RVA overflow"))?; } Ok(instructions) @@ -1251,7 +1307,9 @@ pub fn decode_instruction(parser: &mut Parser, rva: u64) -> Result Operand::Switch(targets) } }; - let size = parser.pos() as u64 - offset; + let size = (parser.pos() as u64) + .checked_sub(offset) + .ok_or_else(|| malformed_error!("instruction size underflow"))?; let mut instruction = Instruction { rva, @@ -1265,9 +1323,11 @@ pub fn decode_instruction(parser: &mut Parser, rva: u64) -> Result stack_behavior: StackBehavior { pops: cil_instruction.stack_pops, pushes: cil_instruction.stack_pushes, - // Allow wrapping cast - stack effects can legitimately be negative + // Stack effects are bounded: per ECMA-335 individual instructions push/pop + // a small number of stack slots, so this fits in i8 with `wrapping_sub`. #[allow(clippy::cast_possible_wrap)] - net_effect: cil_instruction.stack_pushes as i8 - cil_instruction.stack_pops as i8, + net_effect: (cil_instruction.stack_pushes as i8) + .wrapping_sub(cil_instruction.stack_pops as i8), }, branch_targets: Vec::new(), operand, @@ -1278,7 +1338,9 @@ pub fn decode_instruction(parser: &mut Parser, rva: u64) -> Result // All branch-type instructions have their target computed from an immediate offset // This includes leave/leave.s which exit protected regions to a specific target if let Operand::Immediate(value) = instruction.operand { - let next_instruction_rva = rva + instruction.size; + let next_instruction_rva = rva + .checked_add(instruction.size) + .ok_or_else(|| malformed_error!("branch instruction RVA overflow"))?; let branch_offset = >::into(value); instruction .branch_targets @@ -1287,7 +1349,9 @@ pub fn decode_instruction(parser: &mut Parser, rva: u64) -> Result } FlowType::Switch => { if let Operand::Switch(targets) = &instruction.operand { - let next_instruction_rva = rva + instruction.size; + let next_instruction_rva = rva + .checked_add(instruction.size) + .ok_or_else(|| malformed_error!("switch instruction RVA overflow"))?; for &target in targets { // Sign-extend i32 offset to i64 for proper signed arithmetic let offset = i64::from(target); diff --git a/dotscope/src/assembly/encoder.rs b/dotscope/src/assembly/encoder.rs index d7465474..42a8af99 100644 --- a/dotscope/src/assembly/encoder.rs +++ b/dotscope/src/assembly/encoder.rs @@ -83,21 +83,23 @@ fn get_mnemonic_lookup( MNEMONIC_TO_OPCODE.get_or_init(|| { let mut map = HashMap::new(); - // Single-byte instructions (0x00 to 0xE0) + // Single-byte instructions (0x00 to 0xE0). The arrays are bounded to + // u8 opcodes by construction; any out-of-range entry is silently + // skipped (cannot be encoded with a 1-byte opcode anyway). for (opcode, instr) in INSTRUCTIONS.iter().enumerate() { if !instr.instr.is_empty() { - let opcode_u8 = u8::try_from(opcode) - .unwrap_or_else(|_| panic!("Opcode {opcode} exceeds u8 range")); - map.insert(instr.instr, (opcode_u8, 0, instr)); + if let Ok(opcode_u8) = u8::try_from(opcode) { + map.insert(instr.instr, (opcode_u8, 0, instr)); + } } } // Extended instructions (0xFE prefix) for (opcode, instr) in INSTRUCTIONS_FE.iter().enumerate() { if !instr.instr.is_empty() { - let opcode_u8 = u8::try_from(opcode) - .unwrap_or_else(|_| panic!("Opcode {opcode} exceeds u8 range")); - map.insert(instr.instr, (opcode_u8, 0xFE, instr)); + if let Ok(opcode_u8) = u8::try_from(opcode) { + map.insert(instr.instr, (opcode_u8, 0xFE, instr)); + } } } @@ -1014,8 +1016,12 @@ impl InstructionEncoder { // Single-byte terminators: ret, throw, endfinally, jmp matches!(last_byte, 0x2A | 0x7A | 0xDC | 0x27) // Two-byte terminator: rethrow (0xFE 0x1A) - || (self.bytecode.len() >= 2 - && self.bytecode[self.bytecode.len() - 2] == 0xFE + || (self + .bytecode + .len() + .checked_sub(2) + .and_then(|i| self.bytecode.get(i).copied()) + == Some(0xFE) && last_byte == 0x1A) } else { false @@ -1110,14 +1116,19 @@ impl InstructionEncoder { .ok_or_else(|| Error::UndefinedLabel(fixup.label.clone()))?; // Calculate relative offset from end of branch instruction to label - let next_instruction_pos = fixup.fixup_position + fixup.offset_size as usize; + let next_instruction_pos = fixup + .fixup_position + .checked_add(fixup.offset_size as usize) + .ok_or_else(|| malformed_error!("Branch instruction end position overflow"))?; let label_pos_i32 = i32::try_from(*label_position) .map_err(|_| malformed_error!("Label position exceeds i32 range"))?; let next_instr_pos_i32 = i32::try_from(next_instruction_pos) .map_err(|_| malformed_error!("Instruction position exceeds i32 range"))?; - let offset = label_pos_i32 - next_instr_pos_i32; + let offset = label_pos_i32 + .checked_sub(next_instr_pos_i32) + .ok_or_else(|| malformed_error!("Branch offset overflow"))?; self.write_branch_offset(offset, fixup)?; } @@ -1138,12 +1149,26 @@ impl InstructionEncoder { .map_err(|_| malformed_error!("Label position exceeds i32 range"))?; // Switch offsets are relative to the end of the switch instruction - let offset = label_pos_i32 - instruction_end_i32; + let offset = label_pos_i32 + .checked_sub(instruction_end_i32) + .ok_or_else(|| malformed_error!("Switch offset overflow"))?; // Write the 4-byte offset at the correct position - let target_pos = switch_fixup.fixup_position + i * 4; + let target_pos = switch_fixup + .fixup_position + .checked_add( + i.checked_mul(4) + .ok_or_else(|| malformed_error!("Switch label index overflow"))?, + ) + .ok_or_else(|| malformed_error!("Switch target position overflow"))?; + let target_end = target_pos + .checked_add(4) + .ok_or_else(|| malformed_error!("Switch target end overflow"))?; let offset_bytes = offset.to_le_bytes(); - self.bytecode[target_pos..target_pos + 4].copy_from_slice(&offset_bytes); + self.bytecode + .get_mut(target_pos..target_end) + .ok_or(out_of_bounds_error!())? + .copy_from_slice(&offset_bytes); } } @@ -1191,7 +1216,10 @@ impl InstructionEncoder { // Calculate offset as if we used short form (1 byte offset instead of 4) // Short form: instruction is 2 bytes (opcode + 1-byte offset) // The offset is relative to the end of the instruction - let short_form_end = fixup.instruction_position + 2; // opcode + 1-byte offset + let short_form_end = fixup + .instruction_position + .checked_add(2) + .ok_or_else(|| malformed_error!("Short-form end overflow"))?; let label_pos_i32 = i32::try_from(*label_position) .map_err(|_| malformed_error!("Label position exceeds i32 range"))?; @@ -1200,7 +1228,9 @@ impl InstructionEncoder { // Calculate what the offset would be with short form // Note: We need to account for the 3-byte savings in positions after this branch - let offset = label_pos_i32 - short_end_i32; + let offset = label_pos_i32 + .checked_sub(short_end_i32) + .ok_or_else(|| malformed_error!("Short-form offset overflow"))?; // Check if offset fits in signed byte (-128 to +127) if (-128..=127).contains(&offset) { @@ -1226,10 +1256,17 @@ impl InstructionEncoder { let mut cumulative = 0i32; for &idx in shrinkable { - let fixup = &self.fixups[idx]; + let Some(fixup) = self.fixups.get(idx) else { + continue; + }; // The adjustment takes effect after this instruction - let instr_end = fixup.fixup_position + fixup.offset_size as usize; - cumulative -= 3; // Shrinking saves 3 bytes + let instr_end = fixup + .fixup_position + .checked_add(fixup.offset_size as usize) + .ok_or_else(|| malformed_error!("Branch end position overflow"))?; + cumulative = cumulative + .checked_sub(3) + .ok_or_else(|| malformed_error!("Branch shrinking adjustment overflow"))?; adjustments.push((instr_end, cumulative)); } @@ -1249,7 +1286,7 @@ impl InstructionEncoder { clippy::cast_possible_wrap, clippy::cast_sign_loss )] - let adjusted = (pos as i32 + adj).max(0) as usize; + let adjusted = (pos as i32).saturating_add(adj).max(0) as usize; adjusted }; @@ -1259,14 +1296,24 @@ impl InstructionEncoder { // Sort fixups by position for sequential processing let mut sorted_indices: Vec = (0..self.fixups.len()).collect(); - sorted_indices.sort_by_key(|&i| self.fixups[i].instruction_position); + sorted_indices.sort_by_key(|&i| { + self.fixups + .get(i) + .map_or(usize::MAX, |f| f.instruction_position) + }); for &idx in &sorted_indices { - let fixup = &self.fixups[idx]; + let Some(fixup) = self.fixups.get(idx) else { + continue; + }; // Copy bytes up to this instruction if src_pos < fixup.instruction_position { - new_bytecode.extend_from_slice(&self.bytecode[src_pos..fixup.instruction_position]); + let chunk = self + .bytecode + .get(src_pos..fixup.instruction_position) + .ok_or(out_of_bounds_error!())?; + new_bytecode.extend_from_slice(chunk); } src_pos = fixup.instruction_position; @@ -1286,18 +1333,29 @@ impl InstructionEncoder { new_bytecode.push(0); // Placeholder for offset // Skip the original instruction (opcode + 4-byte offset) - src_pos = fixup.fixup_position + 4; + src_pos = fixup + .fixup_position + .checked_add(4) + .ok_or_else(|| malformed_error!("Branch fixup position overflow"))?; } else { // Copy original instruction - let instr_end = fixup.fixup_position + fixup.offset_size as usize; - new_bytecode.extend_from_slice(&self.bytecode[src_pos..instr_end]); + let instr_end = fixup + .fixup_position + .checked_add(fixup.offset_size as usize) + .ok_or_else(|| malformed_error!("Branch end position overflow"))?; + let chunk = self + .bytecode + .get(src_pos..instr_end) + .ok_or(out_of_bounds_error!())?; + new_bytecode.extend_from_slice(chunk); src_pos = instr_end; } } // Copy remaining bytes if src_pos < self.bytecode.len() { - new_bytecode.extend_from_slice(&self.bytecode[src_pos..]); + let chunk = self.bytecode.get(src_pos..).ok_or(out_of_bounds_error!())?; + new_bytecode.extend_from_slice(chunk); } // Update labels @@ -1315,7 +1373,9 @@ impl InstructionEncoder { if shrinkable_set.contains(&idx) { // This branch was shrunk fixup.instruction_position = new_instr_pos; - fixup.fixup_position = new_instr_pos + 1; // opcode + offset position + fixup.fixup_position = new_instr_pos + .checked_add(1) + .ok_or_else(|| malformed_error!("Shrunk branch fixup position overflow"))?; fixup.offset_size = 1; fixup.short_form_mnemonic = None; // Already optimized } else { @@ -1485,7 +1545,11 @@ impl InstructionEncoder { } let offset_i8 = i8::try_from(offset) .map_err(|_| malformed_error!("Branch offset exceeds i8 range"))?; - self.bytecode[fixup.fixup_position] = offset_i8.to_le_bytes()[0]; + let slot = self + .bytecode + .get_mut(fixup.fixup_position) + .ok_or(out_of_bounds_error!())?; + *slot = offset_i8.to_le_bytes()[0]; } 2 => { if offset < i32::from(i16::MIN) || offset > i32::from(i16::MAX) { @@ -1496,12 +1560,24 @@ impl InstructionEncoder { let offset_i16 = i16::try_from(offset) .map_err(|_| malformed_error!("Branch offset exceeds i16 range"))?; let bytes = offset_i16.to_le_bytes(); - self.bytecode[fixup.fixup_position..fixup.fixup_position + 2] + let end = fixup + .fixup_position + .checked_add(2) + .ok_or_else(|| malformed_error!("Branch fixup end overflow"))?; + self.bytecode + .get_mut(fixup.fixup_position..end) + .ok_or(out_of_bounds_error!())? .copy_from_slice(&bytes); } 4 => { let bytes = offset.to_le_bytes(); - self.bytecode[fixup.fixup_position..fixup.fixup_position + 4] + let end = fixup + .fixup_position + .checked_add(4) + .ok_or_else(|| malformed_error!("Branch fixup end overflow"))?; + self.bytecode + .get_mut(fixup.fixup_position..end) + .ok_or(out_of_bounds_error!())? .copy_from_slice(&bytes); } _ => { @@ -1528,9 +1604,12 @@ impl InstructionEncoder { /// /// Returns an error if stack underflow would occur (negative stack depth). fn update_stack_depth(&mut self, pops: u8, pushes: u8) -> Result<()> { - // Apply stack effect - let net_effect = i16::from(pushes) - i16::from(pops); - self.current_stack_depth += net_effect; + // Apply stack effect; both pushes/pops are u8 so this fits in i16. + let net_effect = i16::from(pushes).wrapping_sub(i16::from(pops)); + self.current_stack_depth = self + .current_stack_depth + .checked_add(net_effect) + .ok_or_else(|| malformed_error!("Stack depth overflow"))?; // Check for stack underflow - but only in reachable code. // In unreachable code, the stack depth is meaningless, so we don't error. diff --git a/dotscope/src/assembly/instruction.rs b/dotscope/src/assembly/instruction.rs index 75bd9f9a..901c6e75 100644 --- a/dotscope/src/assembly/instruction.rs +++ b/dotscope/src/assembly/instruction.rs @@ -264,7 +264,7 @@ impl From for u64 { /// # Thread Safety /// /// [`Operand`] is [`std::marker::Send`] and [`std::marker::Sync`] as all variants contain thread-safe types. -/// This includes primitives, [`crate::assembly::instruction::Immediate`], [`crate::metadata::token::Token`], and [`std::vec::Vec`]. +/// This includes primitives, [`crate::assembly::instruction::Immediate`], [`Token`], and [`std::vec::Vec`]. #[derive(Debug, Clone)] pub enum Operand { /// No operand present @@ -892,7 +892,7 @@ impl fmt::Debug for Instruction { write!(f, "0x{item:08X}")?; // Limit output for very large switch tables if i >= 5 && items.len() > 6 { - write!(f, ", ...{} more", items.len() - 6)?; + write!(f, ", ...{} more", items.len().saturating_sub(6))?; break; } } @@ -926,7 +926,11 @@ impl fmt::Debug for Instruction { write!(f, "0x{target:08X}")?; // Limit output for instructions with many targets if i >= 3 && self.branch_targets.len() > 4 { - write!(f, ", ...{} more", self.branch_targets.len() - 4)?; + write!( + f, + ", ...{} more", + self.branch_targets.len().saturating_sub(4) + )?; break; } } diff --git a/dotscope/src/assembly/instructions.rs b/dotscope/src/assembly/instructions.rs index b9e58d77..3d4f168a 100644 --- a/dotscope/src/assembly/instructions.rs +++ b/dotscope/src/assembly/instructions.rs @@ -2515,25 +2515,21 @@ pub const INSTRUCTIONS_FE: [CilInstruction; INSTRUCTIONS_FE_MAX as usize] = [ /// ``` #[must_use] pub fn il_instruction_size(il_bytes: &[u8], offset: usize) -> usize { - if offset >= il_bytes.len() { + let Some(&opcode) = il_bytes.get(offset) else { return 0; - } - - let opcode = il_bytes[offset]; + }; // Handle two-byte opcodes (0xFE prefix) if opcode == 0xFE { - if offset + 1 >= il_bytes.len() { + let Some(next) = offset.checked_add(1) else { + return 1; + }; + let Some(&second_byte) = il_bytes.get(next) else { return 1; // Incomplete two-byte opcode - } - - let second_byte = il_bytes[offset + 1]; - if (second_byte as usize) < INSTRUCTIONS_FE.len() { - let operand_size = INSTRUCTIONS_FE[second_byte as usize] - .op_type - .size() - .unwrap_or(0); - return 2 + operand_size; // 2 bytes for opcode prefix + second byte + }; + if let Some(instr) = INSTRUCTIONS_FE.get(second_byte as usize) { + let operand_size = instr.op_type.size().unwrap_or(0); + return 2usize.saturating_add(operand_size); } return 2; // Unknown 0xFE opcode, assume no operand } @@ -2544,9 +2540,9 @@ pub fn il_instruction_size(il_bytes: &[u8], offset: usize) -> usize { } // Single-byte opcodes - if (opcode as usize) < INSTRUCTIONS.len() { - let operand_size = INSTRUCTIONS[opcode as usize].op_type.size().unwrap_or(0); - return 1 + operand_size; + if let Some(instr) = INSTRUCTIONS.get(opcode as usize) { + let operand_size = instr.op_type.size().unwrap_or(0); + return 1usize.saturating_add(operand_size); } 1 // Unknown opcode, assume 1 byte @@ -2567,16 +2563,23 @@ pub fn il_instruction_size(il_bytes: &[u8], offset: usize) -> usize { #[inline] #[must_use] pub fn switch_instruction_size(il_bytes: &[u8], offset: usize) -> usize { - if offset + 5 > il_bytes.len() { + let Some(start) = offset.checked_add(1) else { + return 1; + }; + let Some(end) = offset.checked_add(5) else { + return 1; + }; + let Some(slice) = il_bytes.get(start..end) else { return 1; // Malformed, just return opcode size - } - - let count = u32::from_le_bytes([ - il_bytes[offset + 1], - il_bytes[offset + 2], - il_bytes[offset + 3], - il_bytes[offset + 4], - ]) as usize; + }; + let Ok(arr) = <[u8; 4]>::try_from(slice) else { + return 1; + }; + let count = u32::from_le_bytes(arr) as usize; - 1 + 4 + (count * 4) // opcode + count + targets + // 1 (opcode) + 4 (count) + count*4 (targets) + count + .checked_mul(4) + .and_then(|t| t.checked_add(5)) + .unwrap_or(usize::MAX) } diff --git a/dotscope/src/cilassembly/builders/method.rs b/dotscope/src/cilassembly/builders/method.rs index 9a99564d..4ef774d1 100644 --- a/dotscope/src/cilassembly/builders/method.rs +++ b/dotscope/src/cilassembly/builders/method.rs @@ -844,7 +844,10 @@ impl MethodBuilder { let param_start_index = assembly.next_rid(TableId::Param)?; for (sequence, (name, _param_type)) in parameters.iter().enumerate() { - let param_sequence = u32::try_from(sequence + 1) + let one_based = sequence + .checked_add(1) + .ok_or_else(|| malformed_error!("Parameter sequence overflow"))?; + let param_sequence = u32::try_from(one_based) .map_err(|_| malformed_error!("Parameter sequence exceeds u32 range"))?; // Parameters start at sequence 1 ParamBuilder::new() diff --git a/dotscope/src/cilassembly/builders/method_body.rs b/dotscope/src/cilassembly/builders/method_body.rs index 8b74f012..d59ba6bd 100644 --- a/dotscope/src/cilassembly/builders/method_body.rs +++ b/dotscope/src/cilassembly/builders/method_body.rs @@ -106,8 +106,12 @@ fn resolve_labeled_exception_handler( ))); } - let try_length = try_end_offset - try_start_offset; - let handler_length = handler_end_offset - handler_start_offset; + let try_length = try_end_offset + .checked_sub(try_start_offset) + .ok_or_else(|| Error::ModificationInvalid("try region length underflow".into()))?; + let handler_length = handler_end_offset + .checked_sub(handler_start_offset) + .ok_or_else(|| Error::ModificationInvalid("handler region length underflow".into()))?; // Resolve filter offset for FILTER handlers let filter_offset = if let Some(filter_label) = &labeled_handler.filter_start_label { diff --git a/dotscope/src/cilassembly/changes/assembly.rs b/dotscope/src/cilassembly/changes/assembly.rs index f994d1ed..fe8b640d 100644 --- a/dotscope/src/cilassembly/changes/assembly.rs +++ b/dotscope/src/cilassembly/changes/assembly.rs @@ -461,7 +461,7 @@ impl AssemblyChanges { self.method_bodies.insert(placeholder_rva, body_bytes); // Increment to next placeholder (simple sequential allocation) - self.next_method_placeholder += 1; + self.next_method_placeholder = self.next_method_placeholder.saturating_add(1); placeholder_rva } @@ -518,7 +518,7 @@ impl AssemblyChanges { let size = u32::try_from(body.len()) .map_err(|_| malformed_error!("Method body size exceeds u32 range"))?; // Align each method body to 4-byte boundary - Ok((size + 3) & !3) + Ok(size.saturating_add(3) & !3u32) }) .sum() } @@ -623,7 +623,7 @@ impl AssemblyChanges { self.field_data.insert(placeholder_rva, data); // Increment to next placeholder (simple sequential allocation) - self.next_field_placeholder += 1; + self.next_field_placeholder = self.next_field_placeholder.saturating_add(1); placeholder_rva } @@ -656,7 +656,7 @@ impl AssemblyChanges { let size = u32::try_from(data.len()) .map_err(|_| malformed_error!("Field data size exceeds u32 range"))?; // Align each entry to 4-byte boundary (same as method bodies) - Ok((size + 3) & !3) + Ok(size.saturating_add(3) & !3u32) }) .sum() } diff --git a/dotscope/src/cilassembly/changes/changeref.rs b/dotscope/src/cilassembly/changes/changeref.rs index 71ecf765..0fb28d31 100644 --- a/dotscope/src/cilassembly/changes/changeref.rs +++ b/dotscope/src/cilassembly/changes/changeref.rs @@ -49,7 +49,10 @@ use std::sync::{ Arc, }; -use crate::metadata::{tables::TableId, token::Token}; +use crate::{ + metadata::{tables::TableId, token::Token}, + Error, Result, +}; // Re-export hash functions from utils for backwards compatibility pub use crate::utils::{hash_blob, hash_guid, hash_string}; @@ -222,21 +225,24 @@ impl ChangeRef { /// /// * `token` - The existing metadata token /// - /// # Panics - /// - /// Panics if the token's table type is not recognized. - #[must_use] - pub fn from_token(token: Token) -> Self { - let table_id = - TableId::from_token_type(token.table()).expect("Token has unrecognized table type"); - Self { + /// # Errors + /// + /// Returns an error if the token's table type is not recognized. + pub fn from_token(token: Token) -> Result { + let table_id = TableId::from_token_type(token.table()).ok_or_else(|| { + Error::Other(format!( + "Token has unrecognized table type: 0x{:02X}", + token.table() + )) + })?; + Ok(Self { id: NEXT_CHANGE_ID.fetch_add(1, Ordering::Relaxed), kind: ChangeRefKind::TableRow(table_id), content_hash: 0, resolved: AtomicBool::new(true), resolved_offset: AtomicU32::new(UNRESOLVED), resolved_token: AtomicU32::new(token.value()), - } + }) } /// Returns the unique ID of this change reference. @@ -522,9 +528,12 @@ impl ChangeRef { } /// Creates a reference from an existing token, wrapped in Arc. - #[must_use] - pub fn existing_token(token: Token) -> ChangeRefRc { - Arc::new(Self::from_token(token)) + /// + /// # Errors + /// + /// Returns an error if the token's table type is not recognized. + pub fn existing_token(token: Token) -> Result { + Ok(Arc::new(Self::from_token(token)?)) } } @@ -561,7 +570,7 @@ mod tests { #[test] fn test_changeref_from_existing_token() { let token = Token::new(0x0200_0001); // TypeDef row 1 - let ref1 = ChangeRef::existing_token(token); + let ref1 = ChangeRef::existing_token(token).unwrap(); assert!(ref1.is_resolved()); assert_eq!(ref1.token(), Some(token)); } diff --git a/dotscope/src/cilassembly/changes/heap.rs b/dotscope/src/cilassembly/changes/heap.rs index 7458bd70..742e76ee 100644 --- a/dotscope/src/cilassembly/changes/heap.rs +++ b/dotscope/src/cilassembly/changes/heap.rs @@ -275,7 +275,10 @@ impl HeapChanges { /// /// Each string contributes: UTF-8 byte length + 1 null terminator pub fn binary_string_heap_size(&self) -> usize { - self.appended.iter().map(|(s, _)| s.len() + 1).sum() + self.appended + .iter() + .map(|(s, _)| s.len().saturating_add(1)) + .sum() } /// Calculates the binary size for #US heap additions. @@ -285,12 +288,12 @@ impl HeapChanges { self.appended .iter() .map(|(s, _)| { - let utf16_bytes = s.encode_utf16().count() * 2; - let total_length = utf16_bytes + 1; + let utf16_bytes = s.encode_utf16().count().saturating_mul(2); + let total_length = utf16_bytes.saturating_add(1); // compressed_uint_size returns at most 4, so cast is always safe #[allow(clippy::cast_possible_truncation)] let compressed_length_size = compressed_uint_size(total_length) as usize; - compressed_length_size + total_length + compressed_length_size.saturating_add(total_length) }) .sum() } @@ -325,7 +328,7 @@ impl HeapChanges> { // compressed_uint_size returns at most 4, so cast is always safe #[allow(clippy::cast_possible_truncation)] let compressed_length_size = compressed_uint_size(length) as usize; - compressed_length_size + length + compressed_length_size.saturating_add(length) }) .sum() } @@ -353,7 +356,7 @@ impl HeapChanges<[u8; 16]> { /// /// Each GUID contributes exactly 16 bytes. pub fn binary_guid_heap_size(&self) -> usize { - self.appended.len() * 16 + self.appended.len().saturating_mul(16) } } diff --git a/dotscope/src/cilassembly/cleanup/compaction.rs b/dotscope/src/cilassembly/cleanup/compaction.rs index 52033210..1a57ba2d 100644 --- a/dotscope/src/cilassembly/cleanup/compaction.rs +++ b/dotscope/src/cilassembly/cleanup/compaction.rs @@ -79,7 +79,9 @@ impl CompactionStats { /// This is the sum of removed strings, blobs, and GUIDs. #[must_use] pub fn total_removed(&self) -> usize { - self.strings + self.blobs + self.guids + self.strings + .saturating_add(self.blobs) + .saturating_add(self.guids) } } @@ -140,7 +142,9 @@ pub(crate) fn mark_unreferenced_heap_entries( // Calculate the byte range of this string entry // Safe: .NET heap offsets always fit in u32 #[allow(clippy::cast_possible_truncation)] - let str_end = offset_u32 + content.len() as u32 + 1; // +1 for null terminator + let content_len = content.len() as u32; + // +1 for null terminator. Saturate on overflow (unreachable in practice). + let str_end = offset_u32.saturating_add(content_len).saturating_add(1); // Check if ANY referenced offset falls within this string's range let has_reference = ref_strings @@ -202,17 +206,17 @@ pub(crate) fn mark_unreferenced_heap_entries( // Mark unreferenced entries for removal for offset in unreferenced_strings { assembly.string_remove(offset)?; - stats.strings += 1; + stats.strings = stats.strings.saturating_add(1); } for offset in unreferenced_blobs { assembly.blob_remove(offset)?; - stats.blobs += 1; + stats.blobs = stats.blobs.saturating_add(1); } for index in unreferenced_guids { assembly.guid_remove(index)?; - stats.guids += 1; + stats.guids = stats.guids.saturating_add(1); } Ok(stats) @@ -419,7 +423,7 @@ fn scan_table_data_owned_rows( // Safe: .NET heap offsets always fit in u32 #[allow(clippy::cast_possible_truncation)] - let rid = (idx + 1) as u32; + let rid = idx.saturating_add(1) as u32; let mut offset = 0; if row_data .row_write(&mut row_buffer, &mut offset, rid, table_info) @@ -442,7 +446,10 @@ fn extract_heap_refs_from_row( ref_guids: &mut HashSet, ) { for field in heap_fields { - if field.offset + field.size > row_buffer.len() { + let Some(field_end) = field.offset.checked_add(field.size) else { + continue; + }; + if field_end > row_buffer.len() { continue; } diff --git a/dotscope/src/cilassembly/cleanup/executor.rs b/dotscope/src/cilassembly/cleanup/executor.rs index d37e1796..b537997e 100644 --- a/dotscope/src/cilassembly/cleanup/executor.rs +++ b/dotscope/src/cilassembly/cleanup/executor.rs @@ -372,7 +372,7 @@ pub fn execute_cleanup( for token in dead_methods.iter().rev() { if try_remove(assembly, TableId::MethodDef, token.row()) { removed_methods.insert(*token); - method_count += 1; + method_count = method_count.saturating_add(1); } } stats.add(TableId::MethodDef, method_count); @@ -382,7 +382,7 @@ pub fn execute_cleanup( for token in dead_fields.iter().rev() { if try_remove(assembly, TableId::Field, token.row()) { removed_fields.insert(*token); - field_count += 1; + field_count = field_count.saturating_add(1); } } stats.add(TableId::Field, field_count); @@ -646,11 +646,11 @@ fn remove_empty_types( }; // Remove empty types (in reverse RID order) - let mut removed = 0; + let mut removed: usize = 0; let mut removed_tokens = HashSet::new(); for rid in empty_types.into_iter().rev() { if try_remove(assembly, TableId::TypeDef, rid) { - removed += 1; + removed = removed.saturating_add(1); removed_tokens.insert(Token::from_parts(TableId::TypeDef, rid)); } } diff --git a/dotscope/src/cilassembly/cleanup/orphans.rs b/dotscope/src/cilassembly/cleanup/orphans.rs index fd17befb..225a979f 100644 --- a/dotscope/src/cilassembly/cleanup/orphans.rs +++ b/dotscope/src/cilassembly/cleanup/orphans.rs @@ -149,10 +149,10 @@ where }; // Second pass: remove in reverse order (mutable borrow) - let mut removed_count = 0; + let mut removed_count: usize = 0; for rid in orphan_rids.into_iter().rev() { if try_remove(assembly, T::TABLE_ID, rid) { - removed_count += 1; + removed_count = removed_count.saturating_add(1); } } @@ -203,10 +203,10 @@ pub(super) fn remove_orphan_params(assembly: &mut CilAssembly, ctx: &DeletionCon orphan_params.sort_unstable_by(|a, b| b.cmp(a)); orphan_params.dedup(); - let mut removed_count = 0; + let mut removed_count: usize = 0; for rid in orphan_params { if try_remove(assembly, TableId::Param, rid) { - removed_count += 1; + removed_count = removed_count.saturating_add(1); } } @@ -389,10 +389,10 @@ pub(super) fn remove_orphan_genericparam( let removed_rids: HashSet = orphan_rids.iter().copied().collect(); // Second pass: remove - let mut removed_count = 0; + let mut removed_count: usize = 0; for rid in orphan_rids.into_iter().rev() { if try_remove(assembly, TableId::GenericParam, rid) { - removed_count += 1; + removed_count = removed_count.saturating_add(1); } } @@ -499,11 +499,11 @@ pub(super) fn remove_orphan_events( orphan_events.sort_unstable_by(|a, b| b.cmp(a)); orphan_events.dedup(); - let mut removed = 0; + let mut removed: usize = 0; let mut deleted_tokens = HashSet::new(); for rid in orphan_events { if try_remove(assembly, TableId::Event, rid) { - removed += 1; + removed = removed.saturating_add(1); deleted_tokens.insert(Token::from_parts(TableId::Event, rid)); } } @@ -560,11 +560,11 @@ pub(super) fn remove_orphan_properties( orphan_properties.sort_unstable_by(|a, b| b.cmp(a)); orphan_properties.dedup(); - let mut removed = 0; + let mut removed: usize = 0; let mut deleted_tokens = HashSet::new(); for rid in orphan_properties { if try_remove(assembly, TableId::Property, rid) { - removed += 1; + removed = removed.saturating_add(1); deleted_tokens.insert(Token::from_parts(TableId::Property, rid)); } } @@ -590,10 +590,10 @@ pub(super) fn remove_orphan_standalonesigs( let alive = collect_referenced_standalonesig_rids(assembly); - let mut removed = 0; + let mut removed: usize = 0; for &rid in candidates.iter().rev() { if !alive.contains(&rid) && try_remove(assembly, TableId::StandAloneSig, rid) { - removed += 1; + removed = removed.saturating_add(1); } } @@ -742,10 +742,10 @@ pub(super) fn remove_orphan_exportedtypes(assembly: &mut CilAssembly) -> (usize, }; let mut deleted_rids = HashSet::new(); - let mut removed = 0; + let mut removed: usize = 0; for rid in orphan_rids.into_iter().rev() { if try_remove(assembly, TableId::ExportedType, rid) { - removed += 1; + removed = removed.saturating_add(1); deleted_rids.insert(rid); } } @@ -791,10 +791,10 @@ pub(super) fn remove_orphan_manifestresources(assembly: &mut CilAssembly) -> (us }; let mut deleted_rids = HashSet::new(); - let mut removed = 0; + let mut removed: usize = 0; for rid in orphan_rids.into_iter().rev() { if try_remove(assembly, TableId::ManifestResource, rid) { - removed += 1; + removed = removed.saturating_add(1); deleted_rids.insert(rid); } } diff --git a/dotscope/src/cilassembly/cleanup/references.rs b/dotscope/src/cilassembly/cleanup/references.rs index 940f6e95..89da2a78 100644 --- a/dotscope/src/cilassembly/cleanup/references.rs +++ b/dotscope/src/cilassembly/cleanup/references.rs @@ -228,7 +228,9 @@ fn get_effective_method_rva(assembly: &CilAssembly, rid: u32, original_rva: u32) } TableModifications::Replaced(rows) => { // Full replacement - find the row by index (RID - 1) - if let Some(TableDataOwned::MethodDef(row)) = rows.get((rid - 1) as usize) { + if let Some(TableDataOwned::MethodDef(row)) = + rows.get(rid.saturating_sub(1) as usize) + { return row.rva; } } @@ -255,13 +257,19 @@ fn scan_method_body_bytes(data: &[u8], base_rva: usize, referenced: &mut HashSet // Get the code bytes (after the header) let code_start = body.size_header; - let code_end = code_start + body.size_code; + let Some(code_end) = code_start.checked_add(body.size_code) else { + return; + }; if code_end > data.len() { return; } - let code_data = &data[code_start..code_end]; - let code_rva = base_rva + body.size_header; + let Some(code_data) = data.get(code_start..code_end) else { + return; + }; + let Some(code_rva) = base_rva.checked_add(body.size_header) else { + return; + }; // Helper: extract token operands from decoded blocks let collect = |blocks: &[BasicBlock], out: &mut HashSet| { @@ -287,7 +295,9 @@ fn scan_method_body_bytes(data: &[u8], base_rva: usize, referenced: &mut HashSet for handler in &body.exception_handlers { let h_offset = handler.handler_offset as usize; if h_offset < code_data.len() { - let h_rva = code_rva + h_offset; + let Some(h_rva) = code_rva.checked_add(h_offset) else { + continue; + }; if let Ok(blocks) = decode_blocks(code_data, h_offset, h_rva, None) { collect(&blocks, referenced); } diff --git a/dotscope/src/cilassembly/cleanup/request.rs b/dotscope/src/cilassembly/cleanup/request.rs index 15f55164..0a41a0ab 100644 --- a/dotscope/src/cilassembly/cleanup/request.rs +++ b/dotscope/src/cilassembly/cleanup/request.rs @@ -502,14 +502,15 @@ impl CleanupRequest { /// Returns the total count of items to delete. #[must_use] pub fn deletion_count(&self) -> usize { - self.types.len() - + self.methods.len() - + self.methodspecs.len() - + self.fields.len() - + self.attributes.len() - + self.assemblyrefs.len() - + self.modulerefs.len() - + self.manifest_resources.len() + self.types + .len() + .saturating_add(self.methods.len()) + .saturating_add(self.methodspecs.len()) + .saturating_add(self.fields.len()) + .saturating_add(self.attributes.len()) + .saturating_add(self.assemblyrefs.len()) + .saturating_add(self.modulerefs.len()) + .saturating_add(self.manifest_resources.len()) } /// Checks if a specific token is marked for deletion. diff --git a/dotscope/src/cilassembly/cleanup/stats.rs b/dotscope/src/cilassembly/cleanup/stats.rs index 042a03e7..3bc4a39b 100644 --- a/dotscope/src/cilassembly/cleanup/stats.rs +++ b/dotscope/src/cilassembly/cleanup/stats.rs @@ -43,7 +43,8 @@ impl CleanupStats { /// Adds `count` to the removal counter for the given table. pub fn add(&mut self, table: TableId, count: usize) { if count > 0 { - *self.removals.entry(table).or_insert(0) += count; + let entry = self.removals.entry(table).or_insert(0); + *entry = entry.saturating_add(count); } } @@ -69,7 +70,9 @@ impl CleanupStats { /// This is the sum of compacted strings, blobs, and GUIDs. #[must_use] pub fn heap_entries_compacted(&self) -> usize { - self.strings_compacted + self.blobs_compacted + self.guids_compacted + self.strings_compacted + .saturating_add(self.blobs_compacted) + .saturating_add(self.guids_compacted) } /// Returns the count of primary items removed (types, methods, fields). @@ -78,7 +81,9 @@ impl CleanupStats { /// as opposed to orphaned entries that were removed as a consequence. #[must_use] pub fn primary_removed(&self) -> usize { - self.get(TableId::TypeDef) + self.get(TableId::MethodDef) + self.get(TableId::Field) + self.get(TableId::TypeDef) + .saturating_add(self.get(TableId::MethodDef)) + .saturating_add(self.get(TableId::Field)) } /// Returns the count of orphaned metadata entries removed. @@ -87,7 +92,9 @@ impl CleanupStats { /// primary deletions were applied (e.g., parameters of deleted methods). #[must_use] pub fn orphans_removed(&self) -> usize { - self.total_removed() - self.primary_removed() - self.get(TableId::CustomAttribute) + self.total_removed() + .saturating_sub(self.primary_removed()) + .saturating_sub(self.get(TableId::CustomAttribute)) } /// Merges stats from another cleanup operation into this one. @@ -98,10 +105,14 @@ impl CleanupStats { for (&table, &count) in &other.removals { self.add(table, count); } - self.sections_excluded += other.sections_excluded; - self.strings_compacted += other.strings_compacted; - self.blobs_compacted += other.blobs_compacted; - self.guids_compacted += other.guids_compacted; + self.sections_excluded = self + .sections_excluded + .saturating_add(other.sections_excluded); + self.strings_compacted = self + .strings_compacted + .saturating_add(other.strings_compacted); + self.blobs_compacted = self.blobs_compacted.saturating_add(other.blobs_compacted); + self.guids_compacted = self.guids_compacted.saturating_add(other.guids_compacted); } } diff --git a/dotscope/src/cilassembly/cleanup/utils.rs b/dotscope/src/cilassembly/cleanup/utils.rs index c025b358..d8ad452d 100644 --- a/dotscope/src/cilassembly/cleanup/utils.rs +++ b/dotscope/src/cilassembly/cleanup/utils.rs @@ -30,11 +30,12 @@ pub(super) fn list_range( child_count: u32, get_list_start: impl Fn(u32) -> Option, ) -> std::ops::Range { - let start = get_list_start(owner_rid).unwrap_or(child_count + 1); + let sentinel = child_count.saturating_add(1); + let start = get_list_start(owner_rid).unwrap_or(sentinel); let end = if owner_rid < owner_count { - get_list_start(owner_rid + 1).unwrap_or(child_count + 1) + get_list_start(owner_rid.saturating_add(1)).unwrap_or(sentinel) } else { - child_count + 1 + sentinel }; start..end } @@ -59,11 +60,11 @@ pub(super) fn remove_candidates_not_alive( candidates: &BTreeSet, alive: &HashSet, ) -> (usize, HashSet) { - let mut removed = 0; + let mut removed: usize = 0; let mut deleted_rids = HashSet::new(); for &rid in candidates.iter().rev() { if !alive.contains(&rid) && try_remove(assembly, table_id, rid) { - removed += 1; + removed = removed.saturating_add(1); deleted_rids.insert(rid); } } @@ -89,8 +90,8 @@ pub(crate) fn with_method_body( let Ok(offset) = file.rva_to_offset(effective_rva as usize) else { return; }; - if offset < original_data.len() { - callback(&original_data[offset..], effective_rva as usize); + if let Some(slice) = original_data.get(offset..) { + callback(slice, effective_rva as usize); } } } diff --git a/dotscope/src/cilassembly/mod.rs b/dotscope/src/cilassembly/mod.rs index 8d99a932..3b52191f 100644 --- a/dotscope/src/cilassembly/mod.rs +++ b/dotscope/src/cilassembly/mod.rs @@ -146,7 +146,7 @@ use crate::{ tables::{AssemblyRefRaw, CodedIndex, CodedIndexType, TableDataOwned, TableId}, token::Token, }, - CilObject, Result, ValidationConfig, + CilObject, Error, Result, ValidationConfig, }; mod builders; @@ -553,7 +553,7 @@ impl CilAssembly { // Add the original resource size since new resources are appended after original ones // This ensures ManifestResource.offset_field points to the correct location let original_size = self.view.cor20header().resource_size; - new_offset + original_size + new_offset.saturating_add(original_size) } /// Updates an existing string in the string heap at the specified index. @@ -914,7 +914,7 @@ impl CilAssembly { .changes .table_changes .entry(table_id) - .or_insert_with(|| TableModifications::new_sparse(original_count + 1)); + .or_insert_with(|| TableModifications::new_sparse(original_count.saturating_add(1))); let operation = Operation::Update(rid, new_row); let table_operation = TableOperation::new(operation); @@ -944,7 +944,7 @@ impl CilAssembly { .changes .table_changes .entry(table_id) - .or_insert_with(|| TableModifications::new_sparse(original_count + 1)); + .or_insert_with(|| TableModifications::new_sparse(original_count.saturating_add(1))); let operation = Operation::Delete(rid); let table_operation = TableOperation::new(operation); @@ -977,7 +977,7 @@ impl CilAssembly { .changes .table_changes .entry(table_id) - .or_insert_with(|| TableModifications::new_sparse(original_count + 1)); + .or_insert_with(|| TableModifications::new_sparse(original_count.saturating_add(1))); let new_rid = table_changes.next_rid()?; let operation = Operation::Insert(new_rid, row); @@ -1232,7 +1232,9 @@ impl CilAssembly { modifications.next_rid() } else { // No modifications yet - next RID is original count + 1 - Ok(self.original_table_row_count(table_id) + 1) + self.original_table_row_count(table_id) + .checked_add(1) + .ok_or_else(|| Error::LayoutFailed("table row count overflow".to_string())) } } @@ -1704,7 +1706,7 @@ impl CilAssembly { // Convert 0-based index to 1-based RID return Some(CodedIndex::new( TableId::AssemblyRef, - u32::try_from(index + 1).unwrap_or(u32::MAX), + u32::try_from(index.saturating_add(1)).unwrap_or(u32::MAX), CodedIndexType::Implementation, )); } diff --git a/dotscope/src/cilassembly/modifications.rs b/dotscope/src/cilassembly/modifications.rs index afc8dc23..60bc8dd8 100644 --- a/dotscope/src/cilassembly/modifications.rs +++ b/dotscope/src/cilassembly/modifications.rs @@ -291,28 +291,34 @@ impl TableModifications { // binary_search_by_key returns Ok(index) if found, which would insert BEFORE // existing entries with the same timestamp. We need to insert AFTER all // entries with the same timestamp to maintain insertion order. - let insert_pos = match operations - .binary_search_by_key(&op.timestamp, |o| o.timestamp) - { - Ok(mut pos) => { - // Found an entry with the same timestamp - scan forward to find the - // end of all entries with this timestamp (FIFO ordering) - while pos < operations.len() && operations[pos].timestamp == op.timestamp { - pos += 1; + let insert_pos = + match operations.binary_search_by_key(&op.timestamp, |o| o.timestamp) { + Ok(mut pos) => { + // Found an entry with the same timestamp - scan forward to find the + // end of all entries with this timestamp (FIFO ordering) + while operations + .get(pos) + .is_some_and(|o| o.timestamp == op.timestamp) + { + pos = pos.saturating_add(1); + } + pos } - pos - } - Err(pos) => pos, // Not found - insert at the natural position - }; + Err(pos) => pos, // Not found - insert at the natural position + }; operations.insert(insert_pos, op); // Update auxiliary data structures - let inserted_op = &operations[insert_pos]; + let inserted_op = operations.get(insert_pos).ok_or_else(|| { + crate::malformed_error!("inserted operation index out of bounds") + })?; match &inserted_op.operation { Operation::Insert(rid, _) => { inserted_rows.insert(*rid); if *rid >= *next_rid { - *next_rid = *rid + 1; + *next_rid = rid + .checked_add(1) + .ok_or_else(|| crate::malformed_error!("next_rid overflows u32"))?; } } Operation::Delete(rid) => { @@ -570,12 +576,17 @@ impl TableModifications { Self::Sparse { next_rid, .. } => Ok(*next_rid), Self::Replaced(rows) => { let len = u32::try_from(rows.len()).map_err(|_| { - crate::Error::LayoutFailed(format!( + Error::LayoutFailed(format!( "Table row count {} exceeds maximum u32 value", rows.len() )) })?; - Ok(len + 1) + len.checked_add(1).ok_or_else(|| { + Error::LayoutFailed(format!( + "Table row count {} + 1 exceeds u32::MAX", + rows.len() + )) + }) } } } diff --git a/dotscope/src/cilassembly/writer/context.rs b/dotscope/src/cilassembly/writer/context.rs index 475e84ef..159e0fa9 100644 --- a/dotscope/src/cilassembly/writer/context.rs +++ b/dotscope/src/cilassembly/writer/context.rs @@ -523,7 +523,7 @@ impl<'a> WriteContext<'a> { /// /// * `amount` - The number of bytes to advance pub fn advance(&mut self, amount: u64) { - self.position += amount; + self.position = self.position.saturating_add(amount); if self.position > self.bytes_written { self.bytes_written = self.position; } @@ -538,9 +538,12 @@ impl<'a> WriteContext<'a> { /// /// * `alignment` - The alignment boundary (must be a power of 2) pub fn align_to(&mut self, alignment: u64) { - let remainder = self.position % alignment; + let Some(remainder) = self.position.checked_rem(alignment) else { + return; + }; if remainder != 0 { - self.position += alignment - remainder; + let padding = alignment.saturating_sub(remainder); + self.position = self.position.saturating_add(padding); } } @@ -575,9 +578,13 @@ impl<'a> WriteContext<'a> { /// /// Returns an error if writing the padding bytes fails. pub fn align_to_with_padding(&mut self, alignment: u64) -> Result<()> { - let remainder = self.position % alignment; + let Some(remainder) = self.position.checked_rem(alignment) else { + return Ok(()); + }; if remainder != 0 { - let padding = alignment - remainder; + let padding = alignment.checked_sub(remainder).ok_or_else(|| { + Error::LayoutFailed("Alignment underflow computing padding".to_string()) + })?; // Safety: padding is always < alignment, and alignment is typically 4, 8, or 512 // so this will never exceed usize range let padding_usize = usize::try_from(padding).map_err(|_| { @@ -637,8 +644,11 @@ impl<'a> WriteContext<'a> { /// Returns an error if writing fails. pub fn write_at(&mut self, offset: u64, data: &[u8]) -> Result<()> { self.output.write_at(offset, data)?; - if offset + data.len() as u64 > self.bytes_written { - self.bytes_written = offset + data.len() as u64; + let end = offset + .checked_add(data.len() as u64) + .ok_or_else(|| Error::LayoutFailed("write_at end offset overflow".to_string()))?; + if end > self.bytes_written { + self.bytes_written = end; } Ok(()) } @@ -710,7 +720,7 @@ impl<'a> WriteContext<'a> { if offset >= self.text_section_offset { // In practice, PE files have sections well under 4GB, so this conversion is safe. // If the offset difference somehow exceeds u32, we saturate to avoid panic. - let diff = offset - self.text_section_offset; + let diff = offset.saturating_sub(self.text_section_offset); let diff_u32 = u32::try_from(diff).unwrap_or(u32::MAX); self.text_section_rva.saturating_add(diff_u32) } else { diff --git a/dotscope/src/cilassembly/writer/fields.rs b/dotscope/src/cilassembly/writer/fields.rs index 19ab918e..ae4b5535 100644 --- a/dotscope/src/cilassembly/writer/fields.rs +++ b/dotscope/src/cilassembly/writer/fields.rs @@ -351,11 +351,13 @@ pub fn write_field_data(ctx: &mut WriteContext) -> Result<()> { /// * `buffer` - The raw FieldRVA row bytes /// * `field_data_rva_map` - Mapping from old/placeholder RVAs to actual RVAs pub fn resolve_field_data_rva(buffer: &mut [u8], field_data_rva_map: &HashMap) { - if buffer.len() < 4 { + let Some(rva_bytes) = buffer.get(..4) else { return; - } + }; - let rva = u32::from_le_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]); + let mut bytes = [0u8; 4]; + bytes.copy_from_slice(rva_bytes); + let rva = u32::from_le_bytes(bytes); // RVA 0 means no data if rva == 0 { @@ -365,10 +367,9 @@ pub fn resolve_field_data_rva(buffer: &mut [u8], field_data_rva_map: &HashMap Result<()> { // Fix e_lfanew in DOS header (offset 0x3C points to PE signature) - let pe_sig_offset = u32::try_from(ctx.pe_signature_offset).map_err(|_| { - crate::Error::LayoutFailed("PE signature offset exceeds u32 range".to_string()) - })?; - ctx.write_u32_at(ctx.dos_header_offset + 0x3C, pe_sig_offset)?; + let pe_sig_offset = u32::try_from(ctx.pe_signature_offset) + .map_err(|_| Error::LayoutFailed("PE signature offset exceeds u32 range".to_string()))?; + let lfanew_pos = ctx + .dos_header_offset + .checked_add(0x3C) + .ok_or_else(|| Error::LayoutFailed("DOS header offset overflow".to_string()))?; + ctx.write_u32_at(lfanew_pos, pe_sig_offset)?; // Fix header fields fixup_optional_header(ctx)?; @@ -133,7 +136,7 @@ pub fn fixup_optional_header(ctx: &mut WriteContext) -> Result<()> { ctx.text_section_size, u64::from(ctx.file_alignment), )) - .map_err(|_| crate::Error::LayoutFailed("Text file size exceeds u32 range".to_string()))?; + .map_err(|_| Error::LayoutFailed("Text file size exceeds u32 range".to_string()))?; // Calculate total image size from all active sections let mut end_rva: u32 = 0; @@ -153,24 +156,30 @@ pub fn fixup_optional_header(ctx: &mut WriteContext) -> Result<()> { u64::from(end_rva), u64::from(ctx.section_alignment), )) - .map_err(|_| crate::Error::LayoutFailed("Image size exceeds u32 range".to_string()))?; + .map_err(|_| Error::LayoutFailed("Image size exceeds u32 range".to_string()))?; + + let oh = ctx.optional_header_offset; + let oh_field = |off: u64| -> Result { + oh.checked_add(off) + .ok_or_else(|| Error::LayoutFailed("Optional header offset overflow".to_string())) + }; // SizeOfCode at offset 4 (after magic field) - ctx.write_u32_at(ctx.optional_header_offset + 4, text_file_size)?; + ctx.write_u32_at(oh_field(4)?, text_file_size)?; // AddressOfEntryPoint at offset 16 // This is the RVA of the native entry point stub that jumps to _CorExeMain/_CorDllMain if let Some(entry_rva) = ctx.native_entry_rva { - ctx.write_u32_at(ctx.optional_header_offset + 16, entry_rva)?; + ctx.write_u32_at(oh_field(16)?, entry_rva)?; } // SizeOfImage at offset 56 - ctx.write_u32_at(ctx.optional_header_offset + 56, image_size)?; + ctx.write_u32_at(oh_field(56)?, image_size)?; // SizeOfHeaders at offset 60 let headers_size = u32::try_from(ctx.text_section_offset) - .map_err(|_| crate::Error::LayoutFailed("Headers size exceeds u32 range".to_string()))?; - ctx.write_u32_at(ctx.optional_header_offset + 60, headers_size)?; + .map_err(|_| Error::LayoutFailed("Headers size exceeds u32 range".to_string()))?; + ctx.write_u32_at(oh_field(60)?, headers_size)?; Ok(()) } @@ -209,14 +218,14 @@ pub fn fixup_section_table(ctx: &mut WriteContext) -> Result<()> { u64::from(ctx.file_alignment), )) .map_err(|_| { - crate::Error::LayoutFailed(format!( + Error::LayoutFailed(format!( "Section {} file size exceeds u32 range", section.name )) })?; let offset_u32 = u32::try_from(data_offset).map_err(|_| { - crate::Error::LayoutFailed(format!("Section {} offset exceeds u32 range", section.name)) + Error::LayoutFailed(format!("Section {} offset exceeds u32 range", section.name)) })?; // Build section header @@ -225,28 +234,59 @@ pub fn fixup_section_table(ctx: &mut WriteContext) -> Result<()> { // Name (8 bytes, null-padded) let name_bytes = section.name.as_bytes(); let copy_len = std::cmp::min(name_bytes.len(), 8); - header[..copy_len].copy_from_slice(&name_bytes[..copy_len]); + let name_src = name_bytes.get(..copy_len).ok_or(out_of_bounds_error!())?; + header + .get_mut(..copy_len) + .ok_or(out_of_bounds_error!())? + .copy_from_slice(name_src); // VirtualSize - header[8..12].copy_from_slice(&data_size.to_le_bytes()); + header + .get_mut(8..12) + .ok_or(out_of_bounds_error!())? + .copy_from_slice(&data_size.to_le_bytes()); // VirtualAddress - header[12..16].copy_from_slice(&rva.to_le_bytes()); + header + .get_mut(12..16) + .ok_or(out_of_bounds_error!())? + .copy_from_slice(&rva.to_le_bytes()); // SizeOfRawData - header[16..20].copy_from_slice(&file_size.to_le_bytes()); + header + .get_mut(16..20) + .ok_or(out_of_bounds_error!())? + .copy_from_slice(&file_size.to_le_bytes()); // PointerToRawData - header[20..24].copy_from_slice(&offset_u32.to_le_bytes()); + header + .get_mut(20..24) + .ok_or(out_of_bounds_error!())? + .copy_from_slice(&offset_u32.to_le_bytes()); // COFF relocation and line number fields are always 0 for PE executables. // These are legacy fields from COFF object files used during linking. // PointerToRelocations - header[24..28].copy_from_slice(&0u32.to_le_bytes()); + header + .get_mut(24..28) + .ok_or(out_of_bounds_error!())? + .copy_from_slice(&0u32.to_le_bytes()); // PointerToLinenumbers (deprecated) - header[28..32].copy_from_slice(&0u32.to_le_bytes()); + header + .get_mut(28..32) + .ok_or(out_of_bounds_error!())? + .copy_from_slice(&0u32.to_le_bytes()); // NumberOfRelocations - header[32..34].copy_from_slice(&0u16.to_le_bytes()); + header + .get_mut(32..34) + .ok_or(out_of_bounds_error!())? + .copy_from_slice(&0u16.to_le_bytes()); // NumberOfLinenumbers (deprecated) - header[34..36].copy_from_slice(&0u16.to_le_bytes()); + header + .get_mut(34..36) + .ok_or(out_of_bounds_error!())? + .copy_from_slice(&0u16.to_le_bytes()); // Characteristics - header[36..40].copy_from_slice(§ion.characteristics.to_le_bytes()); + header + .get_mut(36..40) + .ok_or(out_of_bounds_error!())? + .copy_from_slice(§ion.characteristics.to_le_bytes()); section_headers.push(header); } @@ -255,20 +295,36 @@ pub fn fixup_section_table(ctx: &mut WriteContext) -> Result<()> { let mut offset = ctx.section_table_offset; for header in §ion_headers { ctx.write_at(offset, header)?; - offset += SectionTable::SIZE as u64; + offset = offset + .checked_add(SectionTable::SIZE as u64) + .ok_or_else(|| Error::LayoutFailed("Section table offset overflow".to_string()))?; } // Zero any remaining space (from removed sections) - let original_table_size = ctx.sections.len() * SectionTable::SIZE; - let new_table_size = section_headers.len() * SectionTable::SIZE; + let original_table_size = ctx + .sections + .len() + .checked_mul(SectionTable::SIZE) + .ok_or_else(|| Error::LayoutFailed("Section table size overflow".to_string()))?; + let new_table_size = section_headers + .len() + .checked_mul(SectionTable::SIZE) + .ok_or_else(|| Error::LayoutFailed("Section table size overflow".to_string()))?; if new_table_size < original_table_size { - let zeros = vec![0u8; original_table_size - new_table_size]; + let zero_len = original_table_size + .checked_sub(new_table_size) + .ok_or_else(|| Error::LayoutFailed("Section table size underflow".to_string()))?; + let zeros = vec![0u8; zero_len]; ctx.write_at(offset, &zeros)?; } // Update section count in COFF header let new_count = u16::try_from(section_headers.len()).unwrap_or(0); - ctx.write_u16_at(ctx.coff_header_offset + 2, new_count)?; + let coff_count_pos = ctx + .coff_header_offset + .checked_add(2) + .ok_or_else(|| Error::LayoutFailed("COFF header offset overflow".to_string()))?; + ctx.write_u16_at(coff_count_pos, new_count)?; Ok(()) } @@ -291,10 +347,17 @@ pub fn fixup_cor20_header(ctx: &mut WriteContext) -> Result<()> { // MetaData directory (offset 8-15) let metadata_rva = ctx.offset_to_rva(ctx.metadata_offset); let metadata_size = u32::try_from(ctx.metadata_size) - .map_err(|_| crate::Error::LayoutFailed("Metadata size exceeds u32 range".to_string()))?; + .map_err(|_| Error::LayoutFailed("Metadata size exceeds u32 range".to_string()))?; + + let cor20 = ctx.cor20_header_offset; + let cor20_field = |off: u64| -> Result { + cor20 + .checked_add(off) + .ok_or_else(|| Error::LayoutFailed("COR20 header offset overflow".to_string())) + }; - ctx.write_u32_at(ctx.cor20_header_offset + 8, metadata_rva)?; // MetaData RVA - ctx.write_u32_at(ctx.cor20_header_offset + 12, metadata_size)?; // MetaData Size + ctx.write_u32_at(cor20_field(8)?, metadata_rva)?; // MetaData RVA + ctx.write_u32_at(cor20_field(12)?, metadata_size)?; // MetaData Size // Resources directory (offset 24-31) // COR20 header layout: @@ -311,18 +374,17 @@ pub fn fixup_cor20_header(ctx: &mut WriteContext) -> Result<()> { // Entry point token (offset 20-23) - may need remapping if methods were deleted if ctx.entry_point_token != 0 && !ctx.token_remapping.is_empty() { if let Some(&new_token) = ctx.token_remapping.get(&ctx.entry_point_token) { - ctx.write_u32_at(ctx.cor20_header_offset + 20, new_token)?; + ctx.write_u32_at(cor20_field(20)?, new_token)?; } } if ctx.resource_data_size > 0 { let resource_rva = ctx.offset_to_rva(ctx.resource_data_offset); - let resource_size = u32::try_from(ctx.resource_data_size).map_err(|_| { - crate::Error::LayoutFailed("Resource size exceeds u32 range".to_string()) - })?; + let resource_size = u32::try_from(ctx.resource_data_size) + .map_err(|_| Error::LayoutFailed("Resource size exceeds u32 range".to_string()))?; - ctx.write_u32_at(ctx.cor20_header_offset + 24, resource_rva)?; // Resources RVA - ctx.write_u32_at(ctx.cor20_header_offset + 28, resource_size)?; // Resources Size + ctx.write_u32_at(cor20_field(24)?, resource_rva)?; // Resources RVA + ctx.write_u32_at(cor20_field(28)?, resource_size)?; // Resources Size } Ok(()) @@ -357,8 +419,25 @@ pub fn fixup_cor20_header(ctx: &mut WriteContext) -> Result<()> { /// Debug (index 6) and Certificate (index 4) directories are zeroed during the /// write phase because they become invalid after assembly modification. pub fn fixup_data_directories(ctx: &mut WriteContext) -> Result<()> { - let dd_offset = if ctx.is_pe32_plus { 112 } else { 96 }; - let dd_base = ctx.optional_header_offset + dd_offset; + let dd_offset: u64 = if ctx.is_pe32_plus { 112 } else { 96 }; + let dd_base = ctx + .optional_header_offset + .checked_add(dd_offset) + .ok_or_else(|| Error::LayoutFailed("Data directory offset overflow".to_string()))?; + + // Each directory entry is 8 bytes: index N → offset N*8 (RVA), N*8+4 (Size) + let dd_entry = |index: u64| -> Result<(u64, u64)> { + let rva_off = index + .checked_mul(8) + .and_then(|v| dd_base.checked_add(v)) + .ok_or_else(|| { + Error::LayoutFailed("Data directory entry offset overflow".to_string()) + })?; + let size_off = rva_off.checked_add(4).ok_or_else(|| { + Error::LayoutFailed("Data directory entry offset overflow".to_string()) + })?; + Ok((rva_off, size_off)) + }; // IAT (index 12) and CLR Runtime Header (index 14) // When the assembly has native imports (IAT was written), the layout is: @@ -366,82 +445,91 @@ pub fn fixup_data_directories(ctx: &mut WriteContext) -> Result<()> { // When no native imports exist (.NET Core PE32+ without mscoree.dll): // .text start → COR20 header → ... let has_iat = ctx.iat_size > 0; + let (iat_rva_off, iat_size_off) = dd_entry(12)?; + let (clr_rva_off, clr_size_off) = dd_entry(14)?; if has_iat { let iat_rva = ctx.text_section_rva; let iat_size = u32::try_from(ctx.iat_size).unwrap_or(8); - ctx.write_u32_at(dd_base + 12 * 8, iat_rva)?; - ctx.write_u32_at(dd_base + 12 * 8 + 4, iat_size)?; + ctx.write_u32_at(iat_rva_off, iat_rva)?; + ctx.write_u32_at(iat_size_off, iat_size)?; // CLR header sits immediately after IAT - let clr_rva = ctx.text_section_rva + iat_size; - ctx.write_u32_at(dd_base + 14 * 8, clr_rva)?; - ctx.write_u32_at(dd_base + 14 * 8 + 4, COR20_HEADER_SIZE)?; + let clr_rva = ctx + .text_section_rva + .checked_add(iat_size) + .ok_or_else(|| Error::LayoutFailed("CLR header RVA overflow".to_string()))?; + ctx.write_u32_at(clr_rva_off, clr_rva)?; + ctx.write_u32_at(clr_size_off, COR20_HEADER_SIZE)?; } else { // No IAT - zero the IAT data directory - ctx.write_u32_at(dd_base + 12 * 8, 0)?; - ctx.write_u32_at(dd_base + 12 * 8 + 4, 0)?; + ctx.write_u32_at(iat_rva_off, 0)?; + ctx.write_u32_at(iat_size_off, 0)?; // CLR header sits at the very start of .text section let clr_rva = ctx.text_section_rva; - ctx.write_u32_at(dd_base + 14 * 8, clr_rva)?; - ctx.write_u32_at(dd_base + 14 * 8 + 4, COR20_HEADER_SIZE)?; + ctx.write_u32_at(clr_rva_off, clr_rva)?; + ctx.write_u32_at(clr_size_off, COR20_HEADER_SIZE)?; } // Import Table (index 1) + let (imp_rva_off, imp_size_off) = dd_entry(1)?; if let (Some(rva), Some(size)) = (ctx.import_data_rva, ctx.import_data_size) { - ctx.write_u32_at(dd_base + 8, rva)?; - ctx.write_u32_at(dd_base + 8 + 4, size)?; + ctx.write_u32_at(imp_rva_off, rva)?; + ctx.write_u32_at(imp_size_off, size)?; } else { // No import table - zero the directory entry - ctx.write_u32_at(dd_base + 8, 0)?; - ctx.write_u32_at(dd_base + 8 + 4, 0)?; + ctx.write_u32_at(imp_rva_off, 0)?; + ctx.write_u32_at(imp_size_off, 0)?; } // Export Table (index 0) + let (exp_rva_off, exp_size_off) = dd_entry(0)?; if let (Some(rva), Some(size)) = (ctx.export_data_rva, ctx.export_data_size) { - ctx.write_u32_at(dd_base, rva)?; - ctx.write_u32_at(dd_base + 4, size)?; + ctx.write_u32_at(exp_rva_off, rva)?; + ctx.write_u32_at(exp_size_off, size)?; } // Resource Table (index 2) - find .rsrc section or embedded PE resources + let (rsrc_rva_off, rsrc_size_off) = dd_entry(2)?; let rsrc_section = ctx .sections .iter() .find(|s| s.name.starts_with(".rsrc") && !s.removed); if let Some(section) = rsrc_section { if let (Some(rva), Some(size)) = (section.rva, section.data_size) { - ctx.write_u32_at(dd_base + 2 * 8, rva)?; - ctx.write_u32_at(dd_base + 2 * 8 + 4, size)?; + ctx.write_u32_at(rsrc_rva_off, rva)?; + ctx.write_u32_at(rsrc_size_off, size)?; } } else if ctx.pe_resource_size > 0 { // Resources were embedded in .text and carried over let rva = ctx.offset_to_rva(ctx.pe_resource_offset); - ctx.write_u32_at(dd_base + 2 * 8, rva)?; - ctx.write_u32_at(dd_base + 2 * 8 + 4, ctx.pe_resource_size)?; + ctx.write_u32_at(rsrc_rva_off, rva)?; + ctx.write_u32_at(rsrc_size_off, ctx.pe_resource_size)?; } else { // No resources at all - zero the directory entry - ctx.write_u32_at(dd_base + 2 * 8, 0)?; - ctx.write_u32_at(dd_base + 2 * 8 + 4, 0)?; + ctx.write_u32_at(rsrc_rva_off, 0)?; + ctx.write_u32_at(rsrc_size_off, 0)?; } // Base Relocation Table (index 5) - find .reloc section in sections vector + let (reloc_rva_off, reloc_size_off) = dd_entry(5)?; let reloc_section = ctx .sections .iter() .find(|s| s.name.starts_with(".reloc") && !s.removed); if let Some(section) = reloc_section { if let (Some(rva), Some(size)) = (section.rva, section.data_size) { - ctx.write_u32_at(dd_base + 5 * 8, rva)?; - ctx.write_u32_at(dd_base + 5 * 8 + 4, size)?; + ctx.write_u32_at(reloc_rva_off, rva)?; + ctx.write_u32_at(reloc_size_off, size)?; } else { // Section exists but no data written - zero the directory - ctx.write_u32_at(dd_base + 5 * 8, 0)?; - ctx.write_u32_at(dd_base + 5 * 8 + 4, 0)?; + ctx.write_u32_at(reloc_rva_off, 0)?; + ctx.write_u32_at(reloc_size_off, 0)?; } } else { // No reloc section or it was removed - zero out the data directory entry - ctx.write_u32_at(dd_base + 5 * 8, 0)?; - ctx.write_u32_at(dd_base + 5 * 8 + 4, 0)?; + ctx.write_u32_at(reloc_rva_off, 0)?; + ctx.write_u32_at(reloc_size_off, 0)?; } Ok(()) @@ -517,23 +605,34 @@ pub fn fixup_metadata_stream_headers( for (stream_offset, stream_size, name) in &streams { // Calculate offset relative to metadata root - let relative_offset = - u32::try_from(*stream_offset - metadata_root_offset).map_err(|_| { - crate::Error::LayoutFailed("Stream relative offset exceeds u32 range".to_string()) - })?; + let rel = stream_offset + .checked_sub(metadata_root_offset) + .ok_or_else(|| Error::LayoutFailed("Stream offset before metadata root".to_string()))?; + let relative_offset = u32::try_from(rel).map_err(|_| { + Error::LayoutFailed("Stream relative offset exceeds u32 range".to_string()) + })?; let aligned_size = u32::try_from(align_to(*stream_size, 4)).map_err(|_| { - crate::Error::LayoutFailed("Stream aligned size exceeds u32 range".to_string()) + Error::LayoutFailed("Stream aligned size exceeds u32 range".to_string()) })?; // Write offset ctx.write_u32_at(offset, relative_offset)?; // Write size - ctx.write_u32_at(offset + 4, aligned_size)?; + let size_off = offset + .checked_add(4) + .ok_or_else(|| Error::LayoutFailed("Stream header offset overflow".to_string()))?; + ctx.write_u32_at(size_off, aligned_size)?; // Advance past this stream header (offset + size + name with alignment) - let name_with_null = name.len() + 1; + let name_with_null = name + .len() + .checked_add(1) + .ok_or_else(|| Error::LayoutFailed("Stream name length overflow".to_string()))?; let aligned_name = align_to(name_with_null as u64, 4); - offset += 8 + aligned_name; + offset = aligned_name + .checked_add(8) + .and_then(|v| offset.checked_add(v)) + .ok_or_else(|| Error::LayoutFailed("Stream header offset overflow".to_string()))?; } Ok(()) @@ -593,7 +692,10 @@ pub fn zero_stripped_data_regions(ctx: &mut WriteContext) -> Result<()> { // any modification. if let Some((cert_offset, cert_size)) = ctx.original_certificate_dir { let cert_offset_u64 = u64::from(cert_offset); - if cert_offset_u64 + u64::from(cert_size) <= ctx.bytes_written { + let cert_end = cert_offset_u64 + .checked_add(u64::from(cert_size)) + .ok_or_else(|| Error::LayoutFailed("Certificate region offset overflow".to_string()))?; + if cert_end <= ctx.bytes_written { let zeros = vec![0u8; cert_size as usize]; ctx.write_at(cert_offset_u64, &zeros)?; } @@ -627,16 +729,25 @@ pub fn fixup_coff_characteristics(ctx: &mut WriteContext) -> Result<()> { const IMAGE_FILE_RELOCS_STRIPPED: u16 = 0x0001; // Read current characteristics (offset +18 from COFF header) - let chars_offset = ctx.coff_header_offset + 18; + let chars_offset = ctx + .coff_header_offset + .checked_add(18) + .ok_or_else(|| Error::LayoutFailed("COFF header offset overflow".to_string()))?; // Read the current value from the output let current_bytes = ctx.output.as_slice(); #[allow(clippy::cast_possible_truncation)] let chars_offset_usize = chars_offset as usize; - if chars_offset_usize + 2 <= current_bytes.len() { + let chars_end = chars_offset_usize.checked_add(2).ok_or_else(|| { + Error::LayoutFailed("COFF characteristics offset overflow".to_string()) + })?; + if chars_end <= current_bytes.len() { + let chars_slice = current_bytes + .get(chars_offset_usize..chars_end) + .ok_or(out_of_bounds_error!())?; let current = u16::from_le_bytes([ - current_bytes[chars_offset_usize], - current_bytes[chars_offset_usize + 1], + *chars_slice.first().ok_or(out_of_bounds_error!())?, + *chars_slice.get(1).ok_or(out_of_bounds_error!())?, ]); // Set the RELOCS_STRIPPED flag @@ -674,9 +785,12 @@ pub fn fixup_coff_characteristics(ctx: &mut WriteContext) -> Result<()> { /// mmap size. This ensures the checksum matches what the final truncated /// file will contain. pub fn fixup_checksum(ctx: &mut WriteContext) -> Result<()> { - let checksum_offset = ctx.optional_header_offset + 64; + let checksum_offset = ctx + .optional_header_offset + .checked_add(64) + .ok_or_else(|| Error::LayoutFailed("Checksum offset overflow".to_string()))?; let actual_size = usize::try_from(ctx.bytes_written) - .map_err(|_| crate::Error::LayoutFailed("File size exceeds usize range".to_string()))?; + .map_err(|_| Error::LayoutFailed("File size exceeds usize range".to_string()))?; let checksum = calculate_pe_checksum(&ctx.output, checksum_offset, actual_size); ctx.write_u32_at(checksum_offset, checksum)?; @@ -715,39 +829,48 @@ fn calculate_pe_checksum(output: &Output, checksum_offset: u64, actual_size: usi let file_size = actual_size.min(data.len()); // Don't exceed mmap bounds // Safe: checksum_offset is a small PE header offset that always fits in usize let checksum_offset_usize = usize::try_from(checksum_offset).unwrap_or(usize::MAX); + let checksum_end_usize = checksum_offset_usize.saturating_add(4); let mut sum: u64 = 0; - // Process 16-bit words directly from the memory-mapped file - let mut i = 0; - while i + 1 < file_size { + // Process 16-bit words directly from the memory-mapped file. + // The bounds (file_size) are clamped to the slice length above, so all + // index/saturating arithmetic stays within `data`. + let mut i: usize = 0; + while i.saturating_add(1) < file_size { // Skip the checksum field (4 bytes = 2 words) - if i >= checksum_offset_usize && i < checksum_offset_usize + 4 { - i += 2; + if i >= checksum_offset_usize && i < checksum_end_usize { + i = i.saturating_add(2); continue; } - let word = u16::from_le_bytes([data[i], data[i + 1]]); - sum += u64::from(word); - i += 2; + let word = match data + .get(i..i.saturating_add(2)) + .and_then(|s| s.try_into().ok()) + { + Some(arr) => u16::from_le_bytes(arr), + None => break, + }; + sum = sum.saturating_add(u64::from(word)); + i = i.saturating_add(2); } // Handle odd byte at the end of file (if any) - pad with zero - if i < file_size { - // Only include if not in checksum field - if i < checksum_offset_usize || i >= checksum_offset_usize + 4 { - sum += u64::from(data[i]); + if i < file_size && (i < checksum_offset_usize || i >= checksum_end_usize) { + if let Some(&byte) = data.get(i) { + sum = sum.saturating_add(u64::from(byte)); } } // Fold the sum to 16 bits (add carry to low 16 bits) while sum > 0xFFFF { - sum = (sum & 0xFFFF) + (sum >> 16); + sum = (sum & 0xFFFF).saturating_add(sum >> 16); } - // Add file size - safe: sum is folded to fit in u16, and file_size fits in u32 on all platforms + // Add file size - the PE checksum spec defines `(sum & 0xFFFF) + file_size`; + // wrap is part of the algorithm for files near 4 GiB. #[allow(clippy::cast_possible_truncation)] - let checksum = (sum as u32) + (file_size as u32); + let checksum = (sum as u32).wrapping_add(file_size as u32); checksum } diff --git a/dotscope/src/cilassembly/writer/generator.rs b/dotscope/src/cilassembly/writer/generator.rs index aa141380..d58bc025 100644 --- a/dotscope/src/cilassembly/writer/generator.rs +++ b/dotscope/src/cilassembly/writer/generator.rs @@ -435,7 +435,10 @@ impl<'a> PeGenerator<'a> { ctx.align_to_4(); ctx.method_bodies_offset = ctx.pos(); self.write_method_bodies(&mut ctx, changes)?; - ctx.method_bodies_size = ctx.pos() - ctx.method_bodies_offset; + ctx.method_bodies_size = ctx + .pos() + .checked_sub(ctx.method_bodies_offset) + .ok_or_else(|| Error::LayoutFailed("Method bodies size underflow".to_string()))?; // Write field initialization data (FieldRVA entries) write_field_data(&mut ctx)?; @@ -448,7 +451,10 @@ impl<'a> PeGenerator<'a> { ctx.align_to_4(); ctx.metadata_offset = ctx.pos(); self.write_metadata(&mut ctx, changes)?; - ctx.metadata_size = ctx.pos() - ctx.metadata_offset; + ctx.metadata_size = ctx + .pos() + .checked_sub(ctx.metadata_offset) + .ok_or_else(|| Error::LayoutFailed("Metadata size underflow".to_string()))?; // Write import/export data (if present) if self.needs_native_imports() { @@ -461,7 +467,10 @@ impl<'a> PeGenerator<'a> { self.write_embedded_pe_resources(&mut ctx)?; // Calculate .text section size and update sections vector - ctx.text_section_size = ctx.pos() - ctx.text_section_offset; + ctx.text_section_size = ctx + .pos() + .checked_sub(ctx.text_section_offset) + .ok_or_else(|| Error::LayoutFailed(".text section size underflow".to_string()))?; let text_size_u32 = u32::try_from(ctx.text_section_size).unwrap_or(u32::MAX); if let Some(idx) = ctx.find_section_index(".text") { ctx.update_section( @@ -523,11 +532,15 @@ impl<'a> PeGenerator<'a> { // Add 20% buffer for safety let estimated = original_size - + method_bodies_expansion - + field_data_expansion - + heap_expansion - + fieldrva_expansion; - let with_buffer = (estimated * 120) / 100; + .checked_add(method_bodies_expansion) + .and_then(|v| v.checked_add(field_data_expansion)) + .and_then(|v| v.checked_add(heap_expansion)) + .and_then(|v| v.checked_add(fieldrva_expansion)) + .ok_or_else(|| Error::LayoutFailed("File size estimate overflow".to_string()))?; + let with_buffer = estimated + .checked_mul(120) + .map(|v| v / 100) + .ok_or_else(|| Error::LayoutFailed("File size estimate overflow".to_string()))?; // Align to file alignment Ok(align_to(with_buffer, u64::from(FILE_ALIGNMENT_DEFAULT))) @@ -565,20 +578,24 @@ impl<'a> PeGenerator<'a> { // Add space for appended strings for (data, _) in changes.string_heap_changes.appended_iter() { - expansion += data.len() as u64 + 1; // +1 for null terminator + // +1 for null terminator. Saturate on absurd estimate; final emission validates. + expansion = expansion.saturating_add((data.len() as u64).saturating_add(1)); } // Add space for appended blobs for (data, _) in changes.blob_heap_changes.appended_iter() { - expansion += data.len() as u64 + 5; // +5 for max compressed length prefix + // +5 for max compressed length prefix + expansion = expansion.saturating_add((data.len() as u64).saturating_add(5)); } // Add space for appended GUIDs - expansion += (changes.guid_heap_changes.appended_iter().count() * 16) as u64; + let guid_count = changes.guid_heap_changes.appended_iter().count() as u64; + expansion = expansion.saturating_add(guid_count.saturating_mul(16)); // Add space for appended user strings for (data, _) in changes.userstring_heap_changes.appended_iter() { - expansion += data.len() as u64 + 5; // +5 for max compressed length prefix + // +5 for max compressed length prefix + expansion = expansion.saturating_add((data.len() as u64).saturating_add(5)); } expansion @@ -806,7 +823,7 @@ impl<'a> PeGenerator<'a> { let imports = self.build_import_list(ctx)?; // Calculate IAT size - let iat_size = imports.iat_byte_size(ctx.is_pe32_plus); + let iat_size = imports.iat_byte_size(ctx.is_pe32_plus)?; if iat_size == 0 { // No imports - write minimal placeholder (shouldn't happen for .NET) @@ -1104,8 +1121,13 @@ impl<'a> PeGenerator<'a> { } else { // CIL method - parse, remap tokens, and rebuild let offset = file.rva_to_offset(original_rva as usize)?; - let available_data = - file.data_slice(offset, file.data().len() - offset)?; + let remaining = + file.data().len().checked_sub(offset).ok_or_else(|| { + Error::LayoutFailed( + "Method offset exceeds file size".to_string(), + ) + })?; + let available_data = file.data_slice(offset, remaining)?; let rebuilt = rebuild_method_body( available_data, @@ -1115,7 +1137,7 @@ impl<'a> PeGenerator<'a> { )?; // Fat method headers require 4-byte alignment (ECMA-335 §II.25.4.2) - let rebuilt_is_fat = !rebuilt.is_empty() && (rebuilt[0] & 0x3) == 0x3; + let rebuilt_is_fat = rebuilt.first().is_some_and(|&b| (b & 0x3) == 0x3); if rebuilt_is_fat { ctx.align_to_4_with_padding()?; } @@ -1142,7 +1164,7 @@ impl<'a> PeGenerator<'a> { )?; // Fat method headers require 4-byte alignment (ECMA-335 §II.25.4.2) - let rebuilt_is_fat = !rebuilt.is_empty() && (rebuilt[0] & 0x3) == 0x3; + let rebuilt_is_fat = rebuilt.first().is_some_and(|&b| (b & 0x3) == 0x3); if rebuilt_is_fat { ctx.align_to_4_with_padding()?; } @@ -1224,18 +1246,32 @@ impl<'a> PeGenerator<'a> { let entry_start = old_offset as usize; // Read the 4-byte length prefix - if entry_start + 4 > original_data.len() { + let len_end = match entry_start.checked_add(4) { + Some(v) => v, + None => continue, + }; + if len_end > original_data.len() { continue; } - let data_len = u32::from_le_bytes([ - original_data[entry_start], - original_data[entry_start + 1], - original_data[entry_start + 2], - original_data[entry_start + 3], - ]) as usize; - let entry_total = 4 + data_len; - - if entry_start + entry_total > original_data.len() { + let len_slice = match original_data.get(entry_start..len_end) { + Some(s) => s, + None => continue, + }; + let len_arr: [u8; 4] = match len_slice.try_into() { + Ok(a) => a, + Err(_) => continue, + }; + let data_len = u32::from_le_bytes(len_arr) as usize; + let entry_total = match data_len.checked_add(4) { + Some(v) => v, + None => continue, + }; + + let entry_end = match entry_start.checked_add(entry_total) { + Some(v) => v, + None => continue, + }; + if entry_end > original_data.len() { continue; } @@ -1245,9 +1281,16 @@ impl<'a> PeGenerator<'a> { } // Write surviving resource and track offset remapping - ctx.write(&original_data[entry_start..entry_start + entry_total])?; + let entry_data = match original_data.get(entry_start..entry_end) { + Some(s) => s, + None => continue, + }; + ctx.write(entry_data)?; ctx.resource_offset_remap.insert(old_offset, new_offset); - new_offset += entry_total as u32; + new_offset = + new_offset.checked_add(entry_total as u32).ok_or_else(|| { + Error::LayoutFailed("Resource offset overflow".to_string()) + })?; } } } else { @@ -1270,7 +1313,10 @@ impl<'a> PeGenerator<'a> { } // Calculate total size - ctx.resource_data_size = ctx.pos() - ctx.resource_data_offset; + ctx.resource_data_size = ctx + .pos() + .checked_sub(ctx.resource_data_offset) + .ok_or_else(|| Error::LayoutFailed("Resource data size underflow".to_string()))?; Ok(()) } @@ -1339,7 +1385,10 @@ impl<'a> PeGenerator<'a> { fixup_metadata_stream_headers(ctx, metadata_root_offset, stream_headers_offset)?; // Update metadata size - ctx.metadata_size = ctx.pos() - metadata_root_offset; + ctx.metadata_size = ctx + .pos() + .checked_sub(metadata_root_offset) + .ok_or_else(|| Error::LayoutFailed("Metadata size underflow".to_string()))?; // Note: Table ChangeRefs are resolved earlier (before method bodies) // to support methods with local variable signatures @@ -1377,7 +1426,12 @@ impl<'a> PeGenerator<'a> { }) .collect(); - let version_padded_len = (root.version.len() + 3) & !3; + let version_padded_len = root + .version + .len() + .checked_add(3) + .ok_or_else(|| Error::LayoutFailed("Version length overflow".to_string()))? + & !3; let version_len_u32 = u32::try_from(version_padded_len).map_err(|_| { Error::LayoutFailed(format!( "Version length {version_padded_len} exceeds u32 range" @@ -1397,8 +1451,13 @@ impl<'a> PeGenerator<'a> { // Calculate where stream headers will start (after fixed root header) // sig(4) + major(2) + minor(2) + reserved(4) + length(4) + version(padded) + flags(2) + count(2) - let fixed_header_size = 4 + 2 + 2 + 4 + 4 + version_padded_len + 2 + 2; - let stream_headers_offset = ctx.pos() + fixed_header_size as u64; + let fixed_header_size = 20usize + .checked_add(version_padded_len) + .ok_or_else(|| Error::LayoutFailed("Header size overflow".to_string()))?; + let stream_headers_offset = ctx + .pos() + .checked_add(fixed_header_size as u64) + .ok_or_else(|| Error::LayoutFailed("Stream headers offset overflow".to_string()))?; // Write the full root header using its write_to method modified_root.write_to(ctx)?; @@ -1475,8 +1534,14 @@ impl<'a> PeGenerator<'a> { ); // Calculate where table data starts (after tables stream header) - let header_size = 24 + (valid.count_ones() as usize * 4); - let mut table_data_offset = ctx.tables_stream_offset + header_size as u64; + let header_size = (valid.count_ones() as usize) + .checked_mul(4) + .and_then(|v| v.checked_add(24)) + .ok_or_else(|| Error::LayoutFailed("Tables header size overflow".to_string()))?; + let mut table_data_offset = ctx + .tables_stream_offset + .checked_add(header_size as u64) + .ok_or_else(|| Error::LayoutFailed("Tables data offset overflow".to_string()))?; // Clone remapping to avoid borrow issues let strings_remap = ctx.heap_remapping.strings.clone(); @@ -1496,7 +1561,12 @@ impl<'a> PeGenerator<'a> { // Patch each row in the output table for output_idx in 0..output_row_count as usize { - let row_offset = table_data_offset + (output_idx as u64 * row_size as u64); + let row_off_within = (output_idx as u64) + .checked_mul(row_size as u64) + .ok_or_else(|| Error::LayoutFailed("Row offset overflow".to_string()))?; + let row_offset = table_data_offset + .checked_add(row_off_within) + .ok_or_else(|| Error::LayoutFailed("Row offset overflow".to_string()))?; let mut row_buffer = vec![0u8; row_size]; ctx.output.read_at(row_offset, &mut row_buffer)?; @@ -1529,7 +1599,12 @@ impl<'a> PeGenerator<'a> { ctx.write_at(row_offset, &row_buffer)?; } - table_data_offset += u64::from(output_row_count) * (row_size as u64); + let table_bytes = u64::from(output_row_count) + .checked_mul(row_size as u64) + .ok_or_else(|| Error::LayoutFailed("Table size overflow".to_string()))?; + table_data_offset = table_data_offset + .checked_add(table_bytes) + .ok_or_else(|| Error::LayoutFailed("Table data offset overflow".to_string()))?; } Ok(()) @@ -1557,23 +1632,28 @@ impl<'a> PeGenerator<'a> { let heap_fields = get_heap_fields(table_id, table_info); for field in heap_fields { - if field.offset + field.size > row_buffer.len() { + let Some(field_end) = field.offset.checked_add(field.size) else { + continue; + }; + if field_end > row_buffer.len() { continue; } + let Some(field_slice) = row_buffer.get(field.offset..field_end) else { + continue; + }; + // Read the field value let value = if field.size == 4 { - u32::from_le_bytes([ - row_buffer[field.offset], - row_buffer[field.offset + 1], - row_buffer[field.offset + 2], - row_buffer[field.offset + 3], - ]) + let Ok(arr) = <[u8; 4]>::try_from(field_slice) else { + continue; + }; + u32::from_le_bytes(arr) } else { - u32::from(u16::from_le_bytes([ - row_buffer[field.offset], - row_buffer[field.offset + 1], - ])) + let Ok(arr) = <[u8; 2]>::try_from(field_slice) else { + continue; + }; + u32::from(u16::from_le_bytes(arr)) }; // Check if it's a placeholder @@ -1581,17 +1661,19 @@ impl<'a> PeGenerator<'a> { if let Some(change_ref) = changes.lookup_by_placeholder(value) { if let Some(resolved) = change_ref.offset() { // Write the resolved value back + let Some(field_slice_mut) = row_buffer.get_mut(field.offset..field_end) + else { + continue; + }; if field.size == 4 { - row_buffer[field.offset..field.offset + 4] - .copy_from_slice(&resolved.to_le_bytes()); + field_slice_mut.copy_from_slice(&resolved.to_le_bytes()); } else { // Truncate to u16 - this is safe because heap offsets in small // metadata files fit in u16. Overflow would indicate a corrupted state. #[allow(clippy::cast_possible_truncation)] let small_value = u16::try_from(resolved).unwrap_or((resolved & 0xFFFF) as u16); - row_buffer[field.offset..field.offset + 2] - .copy_from_slice(&small_value.to_le_bytes()); + field_slice_mut.copy_from_slice(&small_value.to_le_bytes()); } } } @@ -1612,19 +1694,22 @@ impl<'a> PeGenerator<'a> { let root = view.metadata_root(); // Base header: signature (4) + major (2) + minor (2) + reserved (4) + version_length (4) - let base_size = 16; + let base_size: usize = 16; // Version string aligned to 4 bytes let version_len = root.version.len(); // Safe cast: version_len is a string length which is always small - let aligned_version = - usize::try_from(align_to(version_len as u64, 4)).unwrap_or(version_len + 4); + let aligned_version = usize::try_from(align_to(version_len as u64, 4)) + .unwrap_or_else(|_| version_len.saturating_add(4)); // Flags (2) + stream count (2) - let flags_and_count = 4; + let flags_and_count: usize = 4; // Stream headers: each is offset (4) + size (4) + name (variable, 4-byte aligned) // Estimate 5 streams max with ~12 bytes each for names - let stream_headers = 5 * (8 + 12); + let stream_headers: usize = 5 * (8 + 12); - base_size + aligned_version + flags_and_count + stream_headers + base_size + .saturating_add(aligned_version) + .saturating_add(flags_and_count) + .saturating_add(stream_headers) } /// Writes all heaps using streaming writers. @@ -1876,7 +1961,10 @@ impl<'a> PeGenerator<'a> { self.write_table_data(ctx, table_id, &output_table_info, changes, &remapper)?; } - ctx.tables_stream_size = ctx.pos() - ctx.tables_stream_offset; + ctx.tables_stream_size = ctx + .pos() + .checked_sub(ctx.tables_stream_offset) + .ok_or_else(|| Error::LayoutFailed("Tables stream size underflow".to_string()))?; Ok(()) } @@ -1936,7 +2024,12 @@ impl<'a> PeGenerator<'a> { "Table {table_id:?} deleted count {deleted_count} exceeds u32::MAX" )) })?; - Ok(original_count + added - deleted) + let total = original_count + .checked_add(added) + .ok_or_else(|| Error::LayoutFailed("Row count overflow".to_string()))?; + total + .checked_sub(deleted) + .ok_or_else(|| Error::LayoutFailed("Row count underflow".to_string())) } } } @@ -1991,8 +2084,11 @@ impl<'a> PeGenerator<'a> { if let Some(TableModifications::Replaced(rows)) = table_mod { let mut buffer = vec![0u8; output_row_size]; for (idx, row) in rows.iter().enumerate() { - let rid = u32::try_from(idx + 1).map_err(|_| { - Error::LayoutFailed(format!("Row index {} exceeds u32 range", idx + 1)) + let rid_idx = idx + .checked_add(1) + .ok_or_else(|| Error::LayoutFailed("Row index overflow".to_string()))?; + let rid = u32::try_from(rid_idx).map_err(|_| { + Error::LayoutFailed(format!("Row index {rid_idx} exceeds u32 range")) })?; let mut resolved_row = row.clone(); @@ -2162,10 +2258,14 @@ impl<'a> PeGenerator<'a> { } else if table_id == TableId::ManifestResource && !resource_offset_remap.is_empty() { // ManifestResource row layout: offset_field (4 bytes) is first. // Remap the offset to account for resource data compaction. - if buffer.len() >= 4 { - let old_offset = u32::from_le_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]); - if let Some(&new_offset) = resource_offset_remap.get(&old_offset) { - buffer[..4].copy_from_slice(&new_offset.to_le_bytes()); + if let Some(rva_slice) = buffer.get(..4) { + if let Ok(arr) = <[u8; 4]>::try_from(rva_slice) { + let old_offset = u32::from_le_bytes(arr); + if let Some(&new_offset) = resource_offset_remap.get(&old_offset) { + if let Some(out) = buffer.get_mut(..4) { + out.copy_from_slice(&new_offset.to_le_bytes()); + } + } } } } @@ -2189,11 +2289,13 @@ impl<'a> PeGenerator<'a> { method_body_rva_map: &HashMap, original_rva_delta: i32, ) { - if buffer.len() < 4 { + let Some(rva_slice) = buffer.get(..4) else { return; - } - - let rva = u32::from_le_bytes([buffer[0], buffer[1], buffer[2], buffer[3]]); + }; + let Ok(rva_arr) = <[u8; 4]>::try_from(rva_slice) else { + return; + }; + let rva = u32::from_le_bytes(rva_arr); // RVA 0 means abstract/extern method with no body if rva == 0 { @@ -2206,18 +2308,21 @@ impl<'a> PeGenerator<'a> { mapped_rva } else if rva < 0xF000_0000 { // Original RVA not in map - apply delta as fallback - // (used when method bodies region was copied as a whole) - (rva.cast_signed() + original_rva_delta).cast_unsigned() + // (used when method bodies region was copied as a whole). + // If the delta overflows i32, leave the row unchanged rather than + // writing a corrupt RVA. + match rva.cast_signed().checked_add(original_rva_delta) { + Some(v) => v.cast_unsigned(), + None => return, + } } else { // Unmapped placeholder - keep as is (shouldn't happen in valid code) rva }; - let new_bytes = new_rva.to_le_bytes(); - buffer[0] = new_bytes[0]; - buffer[1] = new_bytes[1]; - buffer[2] = new_bytes[2]; - buffer[3] = new_bytes[3]; + if let Some(out) = buffer.get_mut(..4) { + out.copy_from_slice(&new_rva.to_le_bytes()); + } } /// Builds StandAloneSig deduplication mapping. @@ -2278,7 +2383,7 @@ impl<'a> PeGenerator<'a> { if deleted_rids.contains(&sig.rid) || ctx.standalonesig_skip.contains(&sig.rid) { continue; } - output_rid += 1; + output_rid = output_rid.saturating_add(1); rid_to_output.insert(sig.rid, output_rid); } @@ -2302,7 +2407,7 @@ impl<'a> PeGenerator<'a> { inserts.sort(); for rid in inserts { - output_rid += 1; + output_rid = output_rid.saturating_add(1); rid_to_output.insert(rid, output_rid); } } @@ -2459,7 +2564,9 @@ impl<'a> PeGenerator<'a> { // followed by 8-byte absolute address // But for .NET, the stub is simpler: jmp qword ptr [IAT] // The offset from RIP (after instruction) to IAT - let stub_end_rva = entry_rva + 6; // instruction is 6 bytes + let stub_end_rva = entry_rva + .checked_add(6) + .ok_or_else(|| Error::LayoutFailed("Entry stub RVA overflow".to_string()))?; // instruction is 6 bytes let rel_offset = iat_rva.wrapping_sub(stub_end_rva); let stub: [u8; 6] = [ 0xff, @@ -2473,7 +2580,9 @@ impl<'a> PeGenerator<'a> { } else { // PE32 (x86): Use absolute addressing // ff 25 xx xx xx xx = jmp dword ptr [VA] - let iat_va = image_base + u64::from(iat_rva); + let iat_va = image_base + .checked_add(u64::from(iat_rva)) + .ok_or_else(|| Error::LayoutFailed("IAT VA overflow".to_string()))?; let stub: [u8; 6] = [ 0xff, 0x25, // jmp dword ptr [abs] @@ -2621,7 +2730,10 @@ impl<'a> PeGenerator<'a> { return Ok(()); // Can't resolve, skip }; - let Some(data) = file.data().get(offset..offset + res_size as usize) else { + let Some(end) = offset.checked_add(res_size as usize) else { + return Ok(()); // Overflow, skip + }; + let Some(data) = file.data().get(offset..end) else { return Ok(()); // Out of bounds, skip }; @@ -2659,7 +2771,9 @@ impl<'a> PeGenerator<'a> { let file = view.file(); // Track current end RVA for calculating next section's RVA - let mut current_end_rva = u64::from(ctx.text_section_rva) + ctx.text_section_size; + let mut current_end_rva = u64::from(ctx.text_section_rva) + .checked_add(ctx.text_section_size) + .ok_or_else(|| Error::LayoutFailed(".text end RVA overflow".to_string()))?; // Get original section info for reloc processing let original_text_rva = file @@ -2676,7 +2790,12 @@ impl<'a> PeGenerator<'a> { // Iterate through sections in order for section_idx in 0..ctx.sections.len() { - let section_name = ctx.sections[section_idx].name.clone(); + let section_name = ctx + .sections + .get(section_idx) + .ok_or(out_of_bounds_error!())? + .name + .clone(); // Skip .text - already handled if section_name.starts_with(".text") { @@ -2711,7 +2830,11 @@ impl<'a> PeGenerator<'a> { if data_size > 0 { ctx.update_section(section_idx, data_offset, section_rva, data_size); - current_end_rva = u64::from(section_rva) + u64::from(data_size); + current_end_rva = u64::from(section_rva) + .checked_add(u64::from(data_size)) + .ok_or_else(|| { + Error::LayoutFailed("Section end RVA overflow".to_string()) + })?; } } else if section_name.starts_with(".reloc") { // Write reloc section with filtering @@ -2725,7 +2848,11 @@ impl<'a> PeGenerator<'a> { if let Some(data_size) = result { ctx.update_section(section_idx, data_offset, section_rva, data_size); - current_end_rva = u64::from(section_rva) + u64::from(data_size); + current_end_rva = u64::from(section_rva) + .checked_add(u64::from(data_size)) + .ok_or_else(|| { + Error::LayoutFailed("Section end RVA overflow".to_string()) + })?; } else { // Reloc section was filtered out entirely ctx.mark_section_removed(section_idx); @@ -2737,7 +2864,11 @@ impl<'a> PeGenerator<'a> { if data_size > 0 { ctx.update_section(section_idx, data_offset, section_rva, data_size); - current_end_rva = u64::from(section_rva) + u64::from(data_size); + current_end_rva = u64::from(section_rva) + .checked_add(u64::from(data_size)) + .ok_or_else(|| { + Error::LayoutFailed("Section end RVA overflow".to_string()) + })?; } } @@ -2762,10 +2893,16 @@ impl<'a> PeGenerator<'a> { let view = self.assembly.view(); let file = view.file(); - let Some(data) = file.data().get( - section.pointer_to_raw_data as usize - ..(section.pointer_to_raw_data + section.size_of_raw_data) as usize, - ) else { + let Some(end) = section + .pointer_to_raw_data + .checked_add(section.size_of_raw_data) + else { + return Ok(0); + }; + let Some(data) = file + .data() + .get(section.pointer_to_raw_data as usize..end as usize) + else { return Ok(0); }; @@ -2801,10 +2938,13 @@ impl<'a> PeGenerator<'a> { let file = view.file(); // Get original reloc data if present - let existing_data = file.data().get( - section.pointer_to_raw_data as usize - ..(section.pointer_to_raw_data + section.size_of_raw_data) as usize, - ); + let existing_data = section + .pointer_to_raw_data + .checked_add(section.size_of_raw_data) + .and_then(|end| { + file.data() + .get(section.pointer_to_raw_data as usize..end as usize) + }); // Build relocation configuration let config = RelocationConfig { @@ -2846,10 +2986,16 @@ impl<'a> PeGenerator<'a> { return Ok(0); } - let Some(data) = file.data().get( - section.pointer_to_raw_data as usize - ..(section.pointer_to_raw_data + section.size_of_raw_data) as usize, - ) else { + let Some(end) = section + .pointer_to_raw_data + .checked_add(section.size_of_raw_data) + else { + return Ok(0); + }; + let Some(data) = file + .data() + .get(section.pointer_to_raw_data as usize..end as usize) + else { return Ok(0); }; diff --git a/dotscope/src/cilassembly/writer/heaps/rowpatch.rs b/dotscope/src/cilassembly/writer/heaps/rowpatch.rs index 3aa79bb5..0fff5b16 100644 --- a/dotscope/src/cilassembly/writer/heaps/rowpatch.rs +++ b/dotscope/src/cilassembly/writer/heaps/rowpatch.rs @@ -127,7 +127,7 @@ pub fn patch_row_heap_refs( /// * `size` - Size of the field (2 or 4 bytes) /// * `remap` - Mapping from old heap offsets/indices to new values fn patch_heap_field(row_data: &mut [u8], offset: usize, size: usize, remap: &HashMap) { - if offset + size > row_data.len() { + if offset.saturating_add(size) > row_data.len() { return; } diff --git a/dotscope/src/cilassembly/writer/heaps/streaming.rs b/dotscope/src/cilassembly/writer/heaps/streaming.rs index e6c4e5d7..e8687d99 100644 --- a/dotscope/src/cilassembly/writer/heaps/streaming.rs +++ b/dotscope/src/cilassembly/writer/heaps/streaming.rs @@ -153,13 +153,16 @@ fn emit_orphaned_substrings( continue; } - let delta = (ref_offset - old_offset_u32) as usize; + let delta = ref_offset + .checked_sub(old_offset_u32) + .ok_or_else(|| malformed_error!("Substring delta underflow"))? + as usize; if delta >= original_bytes.len() { continue; } // Extract the null-terminated substring from the original string bytes - let sub_bytes = &original_bytes[delta..]; + let sub_bytes = original_bytes.get(delta..).ok_or(out_of_bounds_error!())?; let sub_str = std::str::from_utf8(sub_bytes).unwrap_or(""); if sub_str.is_empty() { result.remapping.insert(ref_offset, 0); @@ -176,12 +179,21 @@ fn emit_orphaned_substrings( Error::LayoutFailed(format!("Heap position {} exceeds u32 range", *pos)) })?; + let sub_pos = start_offset + .checked_add(*pos) + .ok_or_else(|| malformed_error!("Heap write offset overflow"))?; + let sub_end = sub_pos + .checked_add(sub_bytes.len() as u64) + .ok_or_else(|| malformed_error!("Heap write offset overflow"))?; if let Some(out) = output.as_mut() { - out.write_at(start_offset + *pos, sub_bytes)?; - out.write_at(start_offset + *pos + sub_bytes.len() as u64, &[0u8])?; + out.write_at(sub_pos, sub_bytes)?; + out.write_at(sub_end, &[0u8])?; } - *pos += sub_bytes.len() as u64 + 1; + *pos = pos + .checked_add(sub_bytes.len() as u64) + .and_then(|p| p.checked_add(1)) + .ok_or_else(|| malformed_error!("Heap position overflow"))?; dedup_map.insert(sub_hash, new_sub_offset); result.remapping.insert(ref_offset, new_sub_offset); } @@ -241,7 +253,10 @@ fn process_strings_heap( // This must use the ORIGINAL length, not the modified length, because // other tables may reference substring offsets within the original range. let original_bytes = original_str.as_bytes(); - let original_end = old_offset_u32 + to_u32(original_bytes.len())? + 1; // +1 for null + let original_end = old_offset_u32 + .checked_add(to_u32(original_bytes.len())?) + .and_then(|v| v.checked_add(1)) + .ok_or_else(|| malformed_error!("String range exceeds u32"))?; // +1 for null if changes.is_removed(old_offset_u32) { // Even though this entry is removed, emit any referenced substrings @@ -289,15 +304,25 @@ fn process_strings_heap( Error::LayoutFailed(format!("Heap position {pos} exceeds u32 range")) })?; let str_bytes = final_str.as_bytes(); - let entry_size = str_bytes.len() as u64 + 1; // +1 for null terminator + let entry_size = (str_bytes.len() as u64) + .checked_add(1) + .ok_or_else(|| malformed_error!("Heap entry size overflow"))?; // +1 for null terminator // Write if in write mode + let write_pos = start_offset + .checked_add(pos) + .ok_or_else(|| malformed_error!("Heap write offset overflow"))?; + let null_pos = write_pos + .checked_add(str_bytes.len() as u64) + .ok_or_else(|| malformed_error!("Heap write offset overflow"))?; if let Some(out) = output.as_mut() { - out.write_at(start_offset + pos, str_bytes)?; - out.write_at(start_offset + pos + str_bytes.len() as u64, &[0u8])?; + out.write_at(write_pos, str_bytes)?; + out.write_at(null_pos, &[0u8])?; } - pos += entry_size; + pos = pos + .checked_add(entry_size) + .ok_or_else(|| malformed_error!("Heap position overflow"))?; dedup_map.insert(content_hash, new_offset); // Add primary offset remapping if changed @@ -324,8 +349,12 @@ fn process_strings_heap( // Unmodified: substrings are at the same delta in the new entry for &ref_offset in referenced_offsets { if ref_offset > old_offset_u32 && ref_offset < original_end { - let substring_delta = ref_offset - old_offset_u32; - let new_substring_offset = new_offset + substring_delta; + let substring_delta = ref_offset + .checked_sub(old_offset_u32) + .ok_or_else(|| malformed_error!("Substring delta underflow"))?; + let new_substring_offset = new_offset + .checked_add(substring_delta) + .ok_or_else(|| malformed_error!("Substring offset overflow"))?; result.remapping.insert(ref_offset, new_substring_offset); } } @@ -358,20 +387,32 @@ fn process_strings_heap( // PE and metadata heap sizes are limited to 4GB, so this cast is safe #[allow(clippy::cast_possible_truncation)] while result.remapping.contains_key(&(pos as u32)) { - pos += 1; + pos = pos + .checked_add(1) + .ok_or_else(|| malformed_error!("Heap position overflow"))?; } let new_offset = u32::try_from(pos) .map_err(|_| Error::LayoutFailed(format!("Heap position {pos} exceeds u32 range")))?; let str_bytes = final_str.as_bytes(); - let entry_size = str_bytes.len() as u64 + 1; - + let entry_size = (str_bytes.len() as u64) + .checked_add(1) + .ok_or_else(|| malformed_error!("Heap entry size overflow"))?; + + let write_pos = start_offset + .checked_add(pos) + .ok_or_else(|| malformed_error!("Heap write offset overflow"))?; + let null_pos = write_pos + .checked_add(str_bytes.len() as u64) + .ok_or_else(|| malformed_error!("Heap write offset overflow"))?; if let Some(out) = output.as_mut() { - out.write_at(start_offset + pos, str_bytes)?; - out.write_at(start_offset + pos + str_bytes.len() as u64, &[0u8])?; + out.write_at(write_pos, str_bytes)?; + out.write_at(null_pos, &[0u8])?; } - pos += entry_size; + pos = pos + .checked_add(entry_size) + .ok_or_else(|| malformed_error!("Heap position overflow"))?; dedup_map.insert(content_hash, new_offset); change_ref.resolve_to_offset(new_offset); } @@ -515,11 +556,16 @@ fn process_blob_heap( Error::LayoutFailed(format!("Heap position {pos} exceeds u32 range")) })?; + let write_pos = start_offset + .checked_add(pos) + .ok_or_else(|| malformed_error!("Heap write offset overflow"))?; if let Some(out) = output.as_mut() { // Write compressed length of 0 - out.write_at(start_offset + pos, &[0u8])?; + out.write_at(write_pos, &[0u8])?; } - pos += 1; // Empty blob is just 1 byte (length 0) + pos = pos + .checked_add(1) + .ok_or_else(|| malformed_error!("Heap position overflow"))?; // Empty blob is just 1 byte (length 0) // Only add to remapping if the offset actually changed if old_offset_u32 != new_offset { @@ -539,8 +585,13 @@ fn process_blob_heap( Error::LayoutFailed(format!("Heap position {pos} exceeds u32 range")) })?; let len_size = compressed_uint_size(final_blob.len()); - let entry_size = len_size + final_blob.len() as u64; + let entry_size = len_size + .checked_add(final_blob.len() as u64) + .ok_or_else(|| malformed_error!("Blob entry size overflow"))?; + let write_pos = start_offset + .checked_add(pos) + .ok_or_else(|| malformed_error!("Heap write offset overflow"))?; if let Some(out) = output.as_mut() { let blob_len_u32 = u32::try_from(final_blob.len()).map_err(|_| { Error::LayoutFailed(format!( @@ -550,12 +601,16 @@ fn process_blob_heap( })?; let mut len_bytes = Vec::with_capacity(4); write_compressed_uint(blob_len_u32, &mut len_bytes); - let write_pos = start_offset + pos; + let data_pos = write_pos + .checked_add(len_bytes.len() as u64) + .ok_or_else(|| malformed_error!("Heap write offset overflow"))?; out.write_at(write_pos, &len_bytes)?; - out.write_at(write_pos + len_bytes.len() as u64, final_blob)?; + out.write_at(data_pos, final_blob)?; } - pos += entry_size; + pos = pos + .checked_add(entry_size) + .ok_or_else(|| malformed_error!("Heap position overflow"))?; dedup_map.insert(content_hash, new_offset); // Only add to remapping if the offset actually changed if old_offset_u32 != new_offset { @@ -606,14 +661,21 @@ fn process_blob_heap( // PE and metadata heap sizes are limited to 4GB, so this cast is safe #[allow(clippy::cast_possible_truncation)] while result.remapping.contains_key(&(pos as u32)) { - pos += 1; + pos = pos + .checked_add(1) + .ok_or_else(|| malformed_error!("Heap position overflow"))?; } let new_offset = u32::try_from(pos) .map_err(|_| Error::LayoutFailed(format!("Heap position {pos} exceeds u32 range")))?; let len_size = compressed_uint_size(final_blob.len()); - let entry_size = len_size + final_blob.len() as u64; + let entry_size = len_size + .checked_add(final_blob.len() as u64) + .ok_or_else(|| malformed_error!("Blob entry size overflow"))?; + let write_pos = start_offset + .checked_add(pos) + .ok_or_else(|| malformed_error!("Heap write offset overflow"))?; if let Some(out) = output.as_mut() { let blob_len_u32 = u32::try_from(final_blob.len()).map_err(|_| { Error::LayoutFailed(format!( @@ -623,11 +685,16 @@ fn process_blob_heap( })?; let mut len_bytes = Vec::with_capacity(4); write_compressed_uint(blob_len_u32, &mut len_bytes); - out.write_at(start_offset + pos, &len_bytes)?; - out.write_at(start_offset + pos + len_bytes.len() as u64, final_blob)?; + let data_pos = write_pos + .checked_add(len_bytes.len() as u64) + .ok_or_else(|| malformed_error!("Heap write offset overflow"))?; + out.write_at(write_pos, &len_bytes)?; + out.write_at(data_pos, final_blob)?; } - pos += entry_size; + pos = pos + .checked_add(entry_size) + .ok_or_else(|| malformed_error!("Heap position overflow"))?; dedup_map.insert(content_hash, new_offset); change_ref.resolve_to_offset(new_offset); } @@ -726,7 +793,10 @@ fn process_guid_heap( })?; // For GUIDs, changes use byte offset = (index - 1) * 16 - let byte_offset = (old_index_u32.saturating_sub(1)) * 16; + let byte_offset = old_index_u32 + .saturating_sub(1) + .checked_mul(16) + .ok_or_else(|| malformed_error!("GUID byte offset overflow"))?; // Check if deleted if changes.is_removed(byte_offset) { @@ -749,17 +819,24 @@ fn process_guid_heap( } // Write if in write mode + let write_pos = start_offset + .checked_add(pos) + .ok_or_else(|| malformed_error!("GUID heap write offset overflow"))?; if let Some(out) = output.as_mut() { - out.write_at(start_offset + pos, &final_guid)?; + out.write_at(write_pos, &final_guid)?; } - pos += 16; + pos = pos + .checked_add(16) + .ok_or_else(|| malformed_error!("GUID heap position overflow"))?; dedup_map.insert(final_guid, current_index); // Only add to remapping if the index actually changed if old_index_u32 != current_index { result.remapping.insert(old_index_u32, current_index); } - current_index += 1; + current_index = current_index + .checked_add(1) + .ok_or_else(|| malformed_error!("GUID index overflow"))?; } } @@ -778,14 +855,21 @@ fn process_guid_heap( continue; } + let write_pos = start_offset + .checked_add(pos) + .ok_or_else(|| malformed_error!("GUID heap write offset overflow"))?; if let Some(out) = output.as_mut() { - out.write_at(start_offset + pos, final_guid)?; + out.write_at(write_pos, final_guid)?; } - pos += 16; + pos = pos + .checked_add(16) + .ok_or_else(|| malformed_error!("GUID heap position overflow"))?; dedup_map.insert(*final_guid, current_index); change_ref.resolve_to_offset(current_index); - current_index += 1; + current_index = current_index + .checked_add(1) + .ok_or_else(|| malformed_error!("GUID index overflow"))?; } result.bytes_written = pos; @@ -910,14 +994,19 @@ fn process_userstring_heap( let new_offset = u32::try_from(pos).map_err(|_| { Error::LayoutFailed(format!("Heap position {pos} exceeds u32 range")) })?; - let entry_size = userstring_entry_size(final_str); + let entry_size = userstring_entry_size(final_str)?; // Write if in write mode + let write_pos = start_offset + .checked_add(pos) + .ok_or_else(|| malformed_error!("UserString write offset overflow"))?; if let Some(out) = output.as_mut() { - write_userstring_entry(out, start_offset + pos, final_str)?; + write_userstring_entry(out, write_pos, final_str)?; } - pos += entry_size; + pos = pos + .checked_add(entry_size) + .ok_or_else(|| malformed_error!("UserString heap position overflow"))?; dedup_map.insert(content_hash, new_offset); // Only add to remapping if the offset actually changed if old_offset_u32 != new_offset { @@ -945,13 +1034,18 @@ fn process_userstring_heap( let new_offset = u32::try_from(pos) .map_err(|_| Error::LayoutFailed(format!("Heap position {pos} exceeds u32 range")))?; - let entry_size = userstring_entry_size(final_str); + let entry_size = userstring_entry_size(final_str)?; + let write_pos = start_offset + .checked_add(pos) + .ok_or_else(|| malformed_error!("UserString write offset overflow"))?; if let Some(out) = output.as_mut() { - write_userstring_entry(out, start_offset + pos, final_str)?; + write_userstring_entry(out, write_pos, final_str)?; } - pos += entry_size; + pos = pos + .checked_add(entry_size) + .ok_or_else(|| malformed_error!("UserString heap position overflow"))?; dedup_map.insert(content_hash, new_offset); change_ref.resolve_to_offset(new_offset); } @@ -961,10 +1055,18 @@ fn process_userstring_heap( } /// Calculates the size of a userstring entry without writing. -fn userstring_entry_size(s: &str) -> u64 { - let utf16_len = s.encode_utf16().count() * 2; - let total_len = utf16_len + 1; // +1 for terminal byte - compressed_uint_size(total_len) + total_len as u64 +fn userstring_entry_size(s: &str) -> Result { + let utf16_len = s + .encode_utf16() + .count() + .checked_mul(2) + .ok_or_else(|| malformed_error!("UserString UTF-16 length overflow"))?; + let total_len = utf16_len + .checked_add(1) + .ok_or_else(|| malformed_error!("UserString total length overflow"))?; // +1 for terminal byte + compressed_uint_size(total_len) + .checked_add(total_len as u64) + .ok_or_else(|| malformed_error!("UserString entry size overflow")) } /// Writes a single user string entry to output. @@ -972,7 +1074,10 @@ fn userstring_entry_size(s: &str) -> u64 { /// Format: compressed_length + UTF-16LE bytes + terminal byte fn write_userstring_entry(output: &mut Output, pos: u64, s: &str) -> Result<()> { let utf16_bytes: Vec = s.encode_utf16().flat_map(u16::to_le_bytes).collect(); - let total_len = utf16_bytes.len() + 1; + let total_len = utf16_bytes + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("UserString total length overflow"))?; // Write compressed length let total_len_u32 = u32::try_from(total_len).map_err(|_| { @@ -983,14 +1088,17 @@ fn write_userstring_entry(output: &mut Output, pos: u64, s: &str) -> Result<()> output.write_at(pos, &len_bytes)?; // Write UTF-16LE bytes - output.write_at(pos + len_bytes.len() as u64, &utf16_bytes)?; + let utf16_pos = pos + .checked_add(len_bytes.len() as u64) + .ok_or_else(|| malformed_error!("UserString write offset overflow"))?; + output.write_at(utf16_pos, &utf16_bytes)?; // Write terminal byte (0x01 if any byte has high bit set, 0x00 otherwise) let terminal = u8::from(utf16_bytes.iter().any(|&b| b & 0x80 != 0)); - output.write_at( - pos + len_bytes.len() as u64 + utf16_bytes.len() as u64, - &[terminal], - )?; + let terminal_pos = utf16_pos + .checked_add(utf16_bytes.len() as u64) + .ok_or_else(|| malformed_error!("UserString write offset overflow"))?; + output.write_at(terminal_pos, &[terminal])?; Ok(()) } diff --git a/dotscope/src/cilassembly/writer/methods.rs b/dotscope/src/cilassembly/writer/methods.rs index d24c211b..22d42c3a 100644 --- a/dotscope/src/cilassembly/writer/methods.rs +++ b/dotscope/src/cilassembly/writer/methods.rs @@ -44,7 +44,7 @@ use crate::{ method::{ExceptionHandlerFlags, MethodBody}, tables::TableId, }, - Parser, Result, + Error, Parser, Result, }; /// UserString heap table ID (0x70) - used in ldstr tokens. @@ -80,7 +80,7 @@ pub fn write_method_body( ) -> Result { // Fat headers require 4-byte alignment let start_pos = if body.is_fat { - align_to_4(offset) + offset.saturating_add(3) & !3u64 } else { offset }; @@ -92,12 +92,6 @@ pub fn write_method_body( Ok(bytes_written) } -/// Aligns a position to a 4-byte boundary. -#[inline] -fn align_to_4(pos: u64) -> u64 { - (pos + 3) & !3 -} - /// Remaps IL tokens in place using decode-and-patch approach. /// /// This function handles all token transformations in IL bytecode: @@ -177,10 +171,15 @@ pub fn remap_il_tokens( // Token operand is the last 4 bytes of the instruction // Safe: CIL method body offsets fit in usize #[allow(clippy::cast_possible_truncation)] - let token_offset = instr.offset as usize + instr.size as usize - 4; - if token_offset + 4 <= il_bytes.len() { - il_bytes[token_offset..token_offset + 4] - .copy_from_slice(&new_value.to_le_bytes()); + let instr_offset = instr.offset as usize; + #[allow(clippy::cast_possible_truncation)] + let instr_size = instr.size as usize; + if let Some(instr_end) = instr_offset.checked_add(instr_size) { + if let Some(token_offset) = instr_end.checked_sub(4) { + if let Some(dest) = il_bytes.get_mut(token_offset..instr_end) { + dest.copy_from_slice(&new_value.to_le_bytes()); + } + } } } } @@ -231,7 +230,14 @@ pub fn rebuild_method_body( }; // Extract IL code from the body data - let mut il_code = body_data[body.size_header..body.size_header + body.size_code].to_vec(); + let il_end = body + .size_header + .checked_add(body.size_code) + .ok_or_else(|| Error::LayoutFailed("Method body IL end offset overflow".to_string()))?; + let mut il_code = body_data + .get(body.size_header..il_end) + .ok_or_else(|| Error::LayoutFailed("Method body IL slice out of bounds".to_string()))? + .to_vec(); // Remap IL tokens (metadata tokens, UserString references, placeholders) remap_il_tokens(&mut il_code, token_map, userstring_map, changes)?; diff --git a/dotscope/src/cilassembly/writer/output.rs b/dotscope/src/cilassembly/writer/output.rs index bb3cbadd..39d7ab95 100644 --- a/dotscope/src/cilassembly/writer/output.rs +++ b/dotscope/src/cilassembly/writer/output.rs @@ -273,7 +273,9 @@ impl Output { ))); } - Ok(&mut self.as_mut_slice()[start..end]) + self.as_mut_slice().get_mut(start..end).ok_or_else(|| { + Error::LayoutFailed(format!("Range {start}..{end} exceeds buffer of size {len}")) + }) } /// Gets a mutable slice starting at the given offset with the specified size. @@ -287,7 +289,11 @@ impl Output { /// # Errors /// Returns [`crate::Error::MmapFailed`] if the range is invalid or exceeds buffer bounds. pub fn get_mut_slice(&mut self, start: usize, size: usize) -> Result<&mut [u8]> { - let end = start + size; + let end = start.checked_add(size).ok_or_else(|| { + Error::LayoutFailed(format!( + "Slice end overflows usize: start={start}, size={size}" + )) + })?; let len = self.size(); if end > len { return Err(Error::MmapFailed(format!( @@ -311,7 +317,13 @@ impl Output { let start = usize::try_from(offset).map_err(|_| { Error::MmapFailed(format!("Offset {offset} too large for target architecture")) })?; - let end = start + data.len(); + let end = start.checked_add(data.len()).ok_or_else(|| { + Error::LayoutFailed(format!( + "Write end overflows usize: offset={}, len={}", + offset, + data.len() + )) + })?; let len = self.size(); if end > len { @@ -323,7 +335,12 @@ impl Output { ))); } - self.as_mut_slice()[start..end].copy_from_slice(data); + let dest = self.as_mut_slice().get_mut(start..end).ok_or_else(|| { + Error::LayoutFailed(format!( + "Slice {start}..{end} out of bounds for buffer size {len}" + )) + })?; + dest.copy_from_slice(data); Ok(()) } @@ -342,7 +359,13 @@ impl Output { let start = usize::try_from(offset).map_err(|_| { Error::MmapFailed(format!("Offset {offset} too large for target architecture")) })?; - let end = start + buffer.len(); + let end = start.checked_add(buffer.len()).ok_or_else(|| { + Error::LayoutFailed(format!( + "Read end overflows usize: offset={}, len={}", + offset, + buffer.len() + )) + })?; let len = self.size(); if end > len { @@ -354,7 +377,12 @@ impl Output { ))); } - buffer.copy_from_slice(&self.as_slice()[start..end]); + let src = self.as_slice().get(start..end).ok_or_else(|| { + Error::LayoutFailed(format!( + "Slice {start}..{end} out of bounds for buffer size {len}" + )) + })?; + buffer.copy_from_slice(src); Ok(()) } @@ -387,8 +415,16 @@ impl Output { Error::MmapFailed(format!("Size {size} too large for target architecture")) })?; - let source_end = source_start + copy_size; - let target_end = target_start + copy_size; + let source_end = source_start.checked_add(copy_size).ok_or_else(|| { + Error::LayoutFailed(format!( + "Source range end overflows usize: start={source_start}, size={copy_size}" + )) + })?; + let target_end = target_start.checked_add(copy_size).ok_or_else(|| { + Error::LayoutFailed(format!( + "Target range end overflows usize: start={target_start}, size={copy_size}" + )) + })?; let len = self.size(); // Validate bounds @@ -456,7 +492,10 @@ impl Output { ))); } - self.as_mut_slice()[index] = byte; + let cell = self.as_mut_slice().get_mut(index).ok_or_else(|| { + Error::LayoutFailed(format!("Index {index} out of bounds for buffer size {len}")) + })?; + *cell = byte; Ok(()) } @@ -514,7 +553,13 @@ impl Output { write_compressed_uint(value, &mut buffer); self.write_at(offset, &buffer)?; - Ok(offset + buffer.len() as u64) + offset.checked_add(buffer.len() as u64).ok_or_else(|| { + Error::LayoutFailed(format!( + "Offset {} + buffer length {} overflows u64", + offset, + buffer.len() + )) + }) } /// Writes data with automatic 4-byte alignment padding. @@ -535,10 +580,19 @@ impl Output { pub fn write_aligned_data(&mut self, offset: u64, data: &[u8]) -> Result { // Write the data self.write_at(offset, data)?; - let data_end = offset + data.len() as u64; + let data_end = offset.checked_add(data.len() as u64).ok_or_else(|| { + Error::LayoutFailed(format!( + "Offset {} + data length {} overflows u64", + offset, + data.len() + )) + })?; - // Calculate padding needed for 4-byte alignment - let padding_needed = (4 - (data.len() % 4)) % 4; + // Calculate padding needed for 4-byte alignment. + // `data.len() % 4` is in [0, 3]; subtraction from 4 yields [1, 4]; final + // `% 4` brings it back to [0, 3]. Use checked operations to satisfy lints. + let rem = data.len() % 4; + let padding_needed = 4usize.saturating_sub(rem) % 4; if padding_needed > 0 { // Fill padding with 0xFF bytes to prevent creation of valid heap entries @@ -553,7 +607,11 @@ impl Output { padding_slice.fill(0xFF); } - Ok(data_end + padding_needed as u64) + data_end.checked_add(padding_needed as u64).ok_or_else(|| { + Error::LayoutFailed(format!( + "Aligned end overflows u64: data_end={data_end}, padding={padding_needed}" + )) + }) } /// Writes data and returns the next position for sequential writing. @@ -573,7 +631,13 @@ impl Output { pub fn write_and_advance(&mut self, position: &mut usize, data: &[u8]) -> Result<()> { let slice = self.get_mut_slice(*position, data.len())?; slice.copy_from_slice(data); - *position += data.len(); + *position = position.checked_add(data.len()).ok_or_else(|| { + Error::LayoutFailed(format!( + "Position {} + data length {} overflows usize", + *position, + data.len() + )) + })?; Ok(()) } @@ -613,8 +677,14 @@ impl Output { /// # Errors /// Returns [`crate::Error::MmapFailed`] if the padding would exceed file bounds. pub fn add_heap_padding(&mut self, current_pos: usize, heap_start: usize) -> Result<()> { - let bytes_written = current_pos - heap_start; - let padding_needed = (4 - (bytes_written % 4)) % 4; + let bytes_written = current_pos.checked_sub(heap_start).ok_or_else(|| { + Error::LayoutFailed(format!( + "Heap padding called with current_pos={current_pos} < heap_start={heap_start}" + )) + })?; + // `bytes_written % 4` is in [0, 3]; final modulo brings the result to [0, 3]. + let rem = bytes_written % 4; + let padding_needed = 4usize.saturating_sub(rem) % 4; if padding_needed > 0 { self.fill_region(current_pos as u64, padding_needed, 0xFF)?; @@ -919,7 +989,13 @@ impl std::io::Write for OutputWriter<'_> { self.output .write_at(self.position, buf) .map_err(|e| std::io::Error::other(e.to_string()))?; - self.position += buf.len() as u64; + self.position = self.position.checked_add(buf.len() as u64).ok_or_else(|| { + std::io::Error::other(format!( + "Position {} + buffer length {} overflows u64", + self.position, + buf.len() + )) + })?; Ok(buf.len()) } diff --git a/dotscope/src/cilassembly/writer/relocations.rs b/dotscope/src/cilassembly/writer/relocations.rs index a9655021..58122cd8 100644 --- a/dotscope/src/cilassembly/writer/relocations.rs +++ b/dotscope/src/cilassembly/writer/relocations.rs @@ -266,36 +266,45 @@ fn generate_single_reloc_block(rva: u32, reloc_type: u16) -> Vec { /// Filtered relocation data with .text blocks removed. fn filter_relocation_blocks(reloc_data: &[u8], text_rva_range: (u32, u32)) -> Vec { let mut result = Vec::new(); - let mut offset = 0; + let mut offset: usize = 0; let (text_start, text_end) = text_rva_range; - while offset + 8 <= reloc_data.len() { - // Read block header - let block_va = u32::from_le_bytes([ - reloc_data[offset], - reloc_data[offset + 1], - reloc_data[offset + 2], - reloc_data[offset + 3], - ]); - let block_size = u32::from_le_bytes([ - reloc_data[offset + 4], - reloc_data[offset + 5], - reloc_data[offset + 6], - reloc_data[offset + 7], - ]) as usize; - - // Validate block size - if block_size < 8 || offset + block_size > reloc_data.len() { + while let Some(header) = reloc_data.get(offset..).and_then(|s| s.get(..8)) { + // SAFETY: `header` is exactly 8 bytes from the `get(..8)` above; the + // sub-slices below are statically known to fit. + let Some(va_bytes) = header.get(0..4) else { + break; + }; + let Some(size_bytes) = header.get(4..8) else { + break; + }; + let Ok(va_arr) = <[u8; 4]>::try_from(va_bytes) else { + break; + }; + let Ok(size_arr) = <[u8; 4]>::try_from(size_bytes) else { + break; + }; + let block_va = u32::from_le_bytes(va_arr); + let block_size = u32::from_le_bytes(size_arr) as usize; + + // Validate block size and end offset + if block_size < 8 { break; } + let Some(block_end) = offset.checked_add(block_size) else { + break; + }; + let Some(block) = reloc_data.get(offset..block_end) else { + break; + }; // Keep blocks that don't point to .text let points_to_text = block_va >= text_start && block_va < text_end; if !points_to_text { - result.extend_from_slice(&reloc_data[offset..offset + block_size]); + result.extend_from_slice(block); } - offset += block_size; + offset = block_end; } result diff --git a/dotscope/src/cilassembly/writer/remapper.rs b/dotscope/src/cilassembly/writer/remapper.rs index e57444fb..58a6a32a 100644 --- a/dotscope/src/cilassembly/writer/remapper.rs +++ b/dotscope/src/cilassembly/writer/remapper.rs @@ -49,18 +49,22 @@ use std::collections::{HashMap, HashSet}; use crate::{ cilassembly::{AssemblyChanges, TableModifications}, - metadata::tables::{ - AssemblyOsRaw, AssemblyProcessorRaw, AssemblyRaw, AssemblyRefOsRaw, - AssemblyRefProcessorRaw, AssemblyRefRaw, ClassLayoutRaw, CodedIndex, ConstantRaw, - CustomAttributeRaw, CustomDebugInformationRaw, DeclSecurityRaw, DocumentRaw, EncLogRaw, - EncMapRaw, EventMapRaw, EventPtrRaw, EventRaw, ExportedTypeRaw, FieldLayoutRaw, - FieldMarshalRaw, FieldPtrRaw, FieldRaw, FieldRvaRaw, FileRaw, GenericParamConstraintRaw, - GenericParamRaw, ImplMapRaw, ImportScopeRaw, InterfaceImplRaw, LocalConstantRaw, - LocalScopeRaw, LocalVariableRaw, ManifestResourceRaw, MemberRefRaw, - MethodDebugInformationRaw, MethodDefRaw, MethodImplRaw, MethodPtrRaw, MethodSemanticsRaw, - MethodSpecRaw, ModuleRaw, ModuleRefRaw, NestedClassRaw, ParamPtrRaw, ParamRaw, - PropertyMapRaw, PropertyPtrRaw, PropertyRaw, StandAloneSigRaw, StateMachineMethodRaw, - TableDataOwned, TableId, TypeDefRaw, TypeRefRaw, TypeSpecRaw, + metadata::{ + tables::{ + AssemblyOsRaw, AssemblyProcessorRaw, AssemblyRaw, AssemblyRefOsRaw, + AssemblyRefProcessorRaw, AssemblyRefRaw, ClassLayoutRaw, CodedIndex, ConstantRaw, + CustomAttributeRaw, CustomDebugInformationRaw, DeclSecurityRaw, DocumentRaw, EncLogRaw, + EncMapRaw, EventMapRaw, EventPtrRaw, EventRaw, ExportedTypeRaw, FieldLayoutRaw, + FieldMarshalRaw, FieldPtrRaw, FieldRaw, FieldRvaRaw, FileRaw, + GenericParamConstraintRaw, GenericParamRaw, ImplMapRaw, ImportScopeRaw, + InterfaceImplRaw, LocalConstantRaw, LocalScopeRaw, LocalVariableRaw, + ManifestResourceRaw, MemberRefRaw, MethodDebugInformationRaw, MethodDefRaw, + MethodImplRaw, MethodPtrRaw, MethodSemanticsRaw, MethodSpecRaw, ModuleRaw, + ModuleRefRaw, NestedClassRaw, ParamPtrRaw, ParamRaw, PropertyMapRaw, PropertyPtrRaw, + PropertyRaw, StandAloneSigRaw, StateMachineMethodRaw, TableDataOwned, TableId, + TypeDefRaw, TypeRefRaw, TypeSpecRaw, + }, + token::Token, }, }; @@ -211,7 +215,7 @@ impl RidRemapper { if old_rid != new_rid { remap.insert(old_rid, new_rid); } - new_rid += 1; + new_rid = new_rid.saturating_add(1); } // Second pass: for deleted rows, find the "continuation" RID @@ -235,7 +239,7 @@ impl RidRemapper { // In .NET, tables like MethodDef use param_list/field_list/method_list values // that can point one past the end of the table to indicate "no items". // When rows are deleted, this continuation value must also be remapped. - let old_continuation = original_count + 1; + let old_continuation = original_count.saturating_add(1); if old_continuation != final_new_rid { remap.insert(old_continuation, final_new_rid); } @@ -415,7 +419,7 @@ impl RidRemapper { if let Some(&new_rid) = table_remap.get(&coded_index.row) { coded_index.row = new_rid; // Update the token to reflect the new row - coded_index.token = crate::metadata::token::Token::from_parts(table_id, new_rid); + coded_index.token = Token::from_parts(table_id, new_rid); } } } @@ -1087,7 +1091,7 @@ mod tests { let mut typedef = TypeDefRaw { rid: 1, - token: crate::metadata::token::Token::new(0x02000001), + token: Token::new(0x02000001), offset: 0, flags: 0, type_name: 0, diff --git a/dotscope/src/cilassembly/writer/signatures.rs b/dotscope/src/cilassembly/writer/signatures.rs index 106713f5..892b301d 100644 --- a/dotscope/src/cilassembly/writer/signatures.rs +++ b/dotscope/src/cilassembly/writer/signatures.rs @@ -33,7 +33,7 @@ use crate::{ }, token::Token, }, - Result, + Error, Result, }; /// Remaps TypeDef, TypeRef, and TypeSpec tokens in a signature blob. @@ -70,7 +70,9 @@ pub fn remap_signature_tokens( } // Determine signature type from header byte - let header = signature[0]; + let header = *signature + .first() + .ok_or_else(|| Error::LayoutFailed("Cannot remap signature: empty buffer".to_string()))?; // Check for signature type markers if header == SIGNATURE_HEADER::LOCAL_SIG { diff --git a/dotscope/src/cilassembly/writer/sizes.rs b/dotscope/src/cilassembly/writer/sizes.rs index 18b1f44b..d266c290 100644 --- a/dotscope/src/cilassembly/writer/sizes.rs +++ b/dotscope/src/cilassembly/writer/sizes.rs @@ -48,21 +48,35 @@ pub fn calculate_table_stream_expansion(assembly: &CilAssembly) -> Result { "Table {table_id:?} insert count {insert_count_raw} exceeds u32::MAX" )) })?; - let new_count = original_count + insert_count; + let new_count = original_count.checked_add(insert_count).ok_or_else(|| { + Error::LayoutFailed(format!( + "Table {table_id:?} new row count overflows u32" + )) + })?; (new_count, insert_count) } }; - let expansion_bytes = u64::from(additional_rows) * u64::from(row_size); - total_expansion += expansion_bytes; + let expansion_bytes = u64::from(additional_rows) + .checked_mul(u64::from(row_size)) + .ok_or_else(|| { + Error::LayoutFailed(format!("Table {table_id:?} expansion bytes overflow u64")) + })?; + total_expansion = total_expansion + .checked_add(expansion_bytes) + .ok_or_else(|| { + Error::LayoutFailed("total table expansion overflows u64".to_string()) + })?; if original_count == 0 && new_count > 0 { - header_expansion += 4; + header_expansion = header_expansion.saturating_add(4u64); } } } - Ok(total_expansion + header_expansion) + total_expansion + .checked_add(header_expansion) + .ok_or_else(|| Error::LayoutFailed("table expansion total overflows u64".to_string())) } #[cfg(test)] diff --git a/dotscope/src/compiler/codegen/coalescing.rs b/dotscope/src/compiler/codegen/coalescing.rs index 7b3cf338..4626207e 100644 --- a/dotscope/src/compiler/codegen/coalescing.rs +++ b/dotscope/src/compiler/codegen/coalescing.rs @@ -122,7 +122,7 @@ impl LiveInterval { fn new(pos: usize) -> Self { Self { start: pos, - end: pos + 1, + end: pos.saturating_add(1), } } @@ -206,8 +206,10 @@ impl LocalCoalescer { if let Some(live_out) = results.out_state(block_id) { let live_vars: Vec = live_out.variables().collect(); for (i, &var1) in live_vars.iter().enumerate() { - for &var2 in &live_vars[i + 1..] { - edges.push((var1, var2)); + if let Some(rest) = live_vars.get(i.saturating_add(1)..) { + for &var2 in rest { + edges.push((var1, var2)); + } } } } @@ -216,8 +218,10 @@ impl LocalCoalescer { if let Some(live_in) = results.in_state(block_id) { let live_vars: Vec = live_in.variables().collect(); for (i, &var1) in live_vars.iter().enumerate() { - for &var2 in &live_vars[i + 1..] { - edges.push((var1, var2)); + if let Some(rest) = live_vars.get(i.saturating_add(1)..) { + for &var2 in rest { + edges.push((var1, var2)); + } } } } @@ -255,8 +259,10 @@ impl LocalCoalescer { // All operands from the same predecessor interfere with each other for (_, operands) in operands_by_pred { for (i, &var1) in operands.iter().enumerate() { - for &var2 in &operands[i + 1..] { - edges.push((var1, var2)); + if let Some(rest) = operands.get(i.saturating_add(1)..) { + for &var2 in rest { + edges.push((var1, var2)); + } } } } @@ -397,11 +403,13 @@ impl LocalCoalescer { // Expire old intervals - return their slots to free pool // BUT don't return reserved Local-origin slots to the pool - while let Some(Reverse((end, _slot))) = active.peek() { - if *end > interval.start { + while let Some(&Reverse((end, _))) = active.peek() { + if end > interval.start { break; } - let Reverse((_, slot)) = active.pop().unwrap(); + let Some(Reverse((_, slot))) = active.pop() else { + break; + }; // Only add to free pool if not a reserved slot if !reserved_slots.contains(slot as usize) { let ty = slot_type.get(&slot).cloned().unwrap_or(SsaType::Unknown); @@ -424,7 +432,7 @@ impl LocalCoalescer { let slot = slot.unwrap_or_else(|| { let s = next_local; - next_local += 1; + next_local = next_local.saturating_add(1); s }); @@ -464,7 +472,7 @@ impl LocalCoalescer { let mut idx = 0usize; for block_id in 0..ssa.block_count() { if let Some(block) = ssa.block(block_id) { - idx += block.instructions().len(); + idx = idx.saturating_add(block.instructions().len()); } block_end_idx.push(idx); } @@ -493,13 +501,12 @@ impl LocalCoalescer { // position would leave a gap where the local slot can be reused. for operand in phi.operands() { let pred = operand.predecessor(); - let pred_end = if pred < block_end_idx.len() { - block_end_idx[pred] - } else { - instr_idx + 1 - }; + let pred_end = block_end_idx + .get(pred) + .copied() + .unwrap_or_else(|| instr_idx.saturating_add(1)); // Use the later of: PHI position or predecessor end - let use_point = pred_end.max(instr_idx + 1); + let use_point = pred_end.max(instr_idx.saturating_add(1)); intervals .entry(operand.value()) .or_insert_with(|| LiveInterval::new(instr_idx)) @@ -514,7 +521,7 @@ impl LocalCoalescer { intervals .entry(use_var) .or_insert_with(|| LiveInterval::new(instr_idx)) - .extend_end(instr_idx + 1); + .extend_end(instr_idx.saturating_add(1)); } // Definitions start the interval @@ -525,7 +532,7 @@ impl LocalCoalescer { .extend_start(instr_idx); } - instr_idx += 1; + instr_idx = instr_idx.saturating_add(1); } } @@ -674,15 +681,21 @@ impl LocalCoalescer { for (subtree_id, &root_idx) in roots.iter().enumerate() { let mut stack = vec![root_idx]; while let Some(idx) = stack.pop() { - if instr_to_subtree[idx].is_some() { + let Some(slot) = instr_to_subtree.get_mut(idx) else { + continue; + }; + if slot.is_some() { continue; } - instr_to_subtree[idx] = Some(subtree_id); + *slot = Some(subtree_id); // Follow dependency edges: operands defined in this block - for use_var in instructions[idx].uses() { + let Some(instr) = instructions.get(idx) else { + continue; + }; + for use_var in instr.uses() { if let Some(&dep_idx) = def_map.get(&use_var) { - if instr_to_subtree[dep_idx].is_none() { + if matches!(instr_to_subtree.get(dep_idx), Some(None)) { stack.push(dep_idx); } } @@ -694,17 +707,27 @@ impl LocalCoalescer { let num_subtrees = roots.len(); let mut subtree_vars: Vec> = vec![Vec::new(); num_subtrees]; for (idx, instr) in instructions.iter().enumerate() { - if let (Some(dest), Some(subtree_id)) = (instr.def(), instr_to_subtree[idx]) { - subtree_vars[subtree_id].push(dest); + if let (Some(dest), Some(Some(subtree_id))) = + (instr.def(), instr_to_subtree.get(idx).copied()) + { + if let Some(bucket) = subtree_vars.get_mut(subtree_id) { + bucket.push(dest); + } } } // Add interference edges between all variable pairs from different subtrees let mut edges = Vec::new(); for i in 0..num_subtrees { - for j in (i + 1)..num_subtrees { - for &var_a in &subtree_vars[i] { - for &var_b in &subtree_vars[j] { + let Some(vars_i) = subtree_vars.get(i) else { + continue; + }; + for j in i.saturating_add(1)..num_subtrees { + let Some(vars_j) = subtree_vars.get(j) else { + continue; + }; + for &var_a in vars_i { + for &var_b in vars_j { edges.push((var_a, var_b)); } } @@ -739,7 +762,9 @@ impl LocalCoalescer { // First, collect which Local-origin variables are actually USED. // Phi-origin variables go through graph coloring allocation separately. - let slot_capacity = var_count.max(self.coalescable_vars.len() + 1).max(64); + let slot_capacity = var_count + .max(self.coalescable_vars.len().saturating_add(1)) + .max(64); let mut used_local_var_ids = BitSet::new(slot_capacity); for block in ssa.blocks() { for phi in block.phi_nodes() { @@ -857,7 +882,7 @@ impl LocalCoalescer { .ok_or_else(|| Error::CodegenFailed("Should always find a valid slot".into()))?; var_to_local.insert(var, slot); - next_local = next_local.max(slot + 1); + next_local = next_local.max(slot.saturating_add(1)); } Ok(LocalAllocation { @@ -905,7 +930,7 @@ fn pre_assign_locals(ssa: &SsaFunction, used_local_vars: &[(SsaVarId, u16)]) -> for &(var_id, original_idx) in used_local_vars { let new_slot = *original_to_new.entry(original_idx).or_insert_with(|| { let slot = next_local; - next_local += 1; + next_local = next_local.saturating_add(1); slot }); var_to_local.insert(var_id, new_slot); @@ -931,7 +956,7 @@ fn pre_assign_locals(ssa: &SsaFunction, used_local_vars: &[(SsaVarId, u16)]) -> for original_idx in sorted_load_refs { original_to_new.entry(original_idx).or_insert_with(|| { let slot = next_local; - next_local += 1; + next_local = next_local.saturating_add(1); reserved_slots.insert(slot as usize); slot }); diff --git a/dotscope/src/compiler/codegen/mod.rs b/dotscope/src/compiler/codegen/mod.rs index 66e0b06e..96ffa824 100644 --- a/dotscope/src/compiler/codegen/mod.rs +++ b/dotscope/src/compiler/codegen/mod.rs @@ -549,7 +549,7 @@ impl SsaCodeGenerator { None } }) - .unwrap_or(eh.try_offset + eh.try_length); + .unwrap_or(eh.try_offset.saturating_add(eh.try_length)); let handler_end = eh .handler_end_block @@ -563,7 +563,7 @@ impl SsaCodeGenerator { None } }) - .unwrap_or(eh.handler_offset + eh.handler_length); + .unwrap_or(eh.handler_offset.saturating_add(eh.handler_length)); let filter_offset = if eh.flags == ExceptionHandlerFlags::FILTER { eh.filter_start_block @@ -794,7 +794,7 @@ impl SsaCodeGenerator { let pos_before = encoder.current_position(); self.block_offsets.insert(block.id(), pos_before); - let next_block_idx = block_ids.get(idx + 1).copied(); + let next_block_idx = block_ids.get(idx.saturating_add(1)).copied(); self.generate_block(&mut encoder, ssa, block, block.id(), next_block_idx)?; } @@ -818,7 +818,12 @@ impl SsaCodeGenerator { 0 } else { // We need max_index + 1 because local indices are 0-based - self.used_locals.iter().max().copied().unwrap_or(0) + 1 + self.used_locals + .iter() + .max() + .copied() + .unwrap_or(0) + .saturating_add(1) }; Ok((bytecode, max_stack, num_locals)) } @@ -1658,11 +1663,11 @@ impl SsaCodeGenerator { // Use the coalescer's slot if available, otherwise allocate fresh let local_idx = if let Some(&slot) = coalesced_slots.get(&phi_result) { // Bump next_local past any coalesced slot to avoid conflicts - self.next_local = self.next_local.max(slot + 1); + self.next_local = self.next_local.max(slot.saturating_add(1)); slot } else { let idx = self.next_local; - self.next_local += 1; + self.next_local = self.next_local.saturating_add(1); idx }; self.var_storage @@ -1781,11 +1786,11 @@ impl SsaCodeGenerator { // Use the coalescer's slot if available, otherwise allocate fresh let local_idx = if let Some(&slot) = coalesced_slots.get(&var_id) { // Bump next_local past any coalesced slot to avoid conflicts - self.next_local = self.next_local.max(slot + 1); + self.next_local = self.next_local.max(slot.saturating_add(1)); slot } else { let idx = self.next_local; - self.next_local += 1; + self.next_local = self.next_local.saturating_add(1); idx }; self.var_storage @@ -1923,7 +1928,8 @@ impl SsaCodeGenerator { for block in ssa.blocks() { for instr in block.instructions() { for var in instr.op().uses() { - *use_counts.entry(var).or_insert(0) += 1; + let entry = use_counts.entry(var).or_insert(0); + *entry = entry.saturating_add(1); } } } @@ -1978,7 +1984,8 @@ impl SsaCodeGenerator { for phi in block.phi_nodes() { if live_phis.contains(&phi.result()) { for operand in phi.operands() { - *use_counts.entry(operand.value()).or_insert(0) += 1; + let entry = use_counts.entry(operand.value()).or_insert(0); + *entry = entry.saturating_add(1); } } } @@ -2093,7 +2100,7 @@ impl SsaCodeGenerator { // values onto the stack, burying our value and making it inaccessible. // A more sophisticated analysis could check if intervening instructions // consume their values before our use, but that's complex and error-prone. - if u_idx != d_idx + 1 { + if u_idx != d_idx.saturating_add(1) { return false; } @@ -2102,7 +2109,10 @@ impl SsaCodeGenerator { // Condition 6: The use itself should not be a terminator operand // (terminators trigger spilling, so the value would be stored anyway) - if Self::is_terminator(instrs[u_idx].op()) { + let Some(use_instr) = instrs.get(u_idx) else { + return false; + }; + if Self::is_terminator(use_instr.op()) { return false; } @@ -2155,7 +2165,7 @@ impl SsaCodeGenerator { // Default: allocate a new local let local = self.next_local; - self.next_local += 1; + self.next_local = self.next_local.saturating_add(1); let storage = VarStorage::Local(local); self.var_storage.insert(var, storage); @@ -2210,7 +2220,7 @@ impl SsaCodeGenerator { // Allocate a new local slot let local = self.next_local; - self.next_local += 1; + self.next_local = self.next_local.saturating_add(1); let storage = VarStorage::Local(local); self.var_storage.insert(var, storage); @@ -2413,9 +2423,9 @@ impl SsaCodeGenerator { // If all phi operands trace to the same storage, return that storage if let Some((_, phi)) = ssa.find_phi_defining(var) { let operands = phi.operands(); - if !operands.is_empty() { + if let Some(first_op) = operands.first() { // Trace the first operand to get a reference storage - let first_storage = self.trace_to_storage(operands[0].value(), ssa, visited)?; + let first_storage = self.trace_to_storage(first_op.value(), ssa, visited)?; // Check if all other operands trace to the same storage for operand in operands.iter().skip(1) { @@ -2679,7 +2689,9 @@ impl SsaCodeGenerator { if !defined_in_block.contains(value) { // This Pop consumes a value not defined in this block // (the exception object), so clear its operands - operands_cache[idx].clear(); + if let Some(ops_at_idx) = operands_cache.get_mut(idx) { + ops_at_idx.clear(); + } has_exception_pop = true; } } @@ -2800,7 +2812,8 @@ impl SsaCodeGenerator { let mut in_progress: BTreeSet = BTreeSet::new(); // Track which LoadOperand (idx, operand_idx) pairs are already scheduled let mut scheduled_load: BTreeSet<(usize, usize)> = BTreeSet::new(); - let mut work_stack: Vec = Vec::with_capacity(ctx.ops.len() * 3); + let mut work_stack: Vec = + Vec::with_capacity(ctx.ops.len().saturating_mul(3)); // Preserve original instruction order as much as possible. // For malware analysis and understanding original intent, the order matters. @@ -2822,9 +2835,13 @@ impl SsaCodeGenerator { let mut copy_roots: Vec = Vec::new(); let mut other_roots: Vec = Vec::new(); for &root_idx in roots { - if Self::is_terminator(ctx.ops[root_idx]) { + let op = ctx + .ops + .get(root_idx) + .ok_or_else(|| Error::CodegenFailed("root index out of bounds".to_string()))?; + if Self::is_terminator(op) { terminator_roots.push(root_idx); - } else if matches!(ctx.ops[root_idx], SsaOp::Copy { .. }) { + } else if matches!(op, SsaOp::Copy { .. }) { copy_roots.push(root_idx); } else { other_roots.push(root_idx); @@ -2875,7 +2892,9 @@ impl SsaCodeGenerator { // Only spill when the first operand does NOT match the stack // top, because loading it from storage would bury the tracked // stack values. - let operands = &ctx.operands_cache[idx]; + let operands = ctx.operands_cache.get(idx).ok_or_else(|| { + Error::CodegenFailed("operands_cache index out of bounds".to_string()) + })?; let first_operand_on_stack = operands.first().is_some_and(|first_op| { self.stack_vars.last().is_some_and(|top| *top == *first_op) }); @@ -2903,14 +2922,20 @@ impl SsaCodeGenerator { CodeGenWorkItem::LoadOperand(idx, operand_idx) => { scheduled_load.remove(&(idx, operand_idx)); - let operand = ctx.operands_cache[idx][operand_idx]; + let operand = *ctx + .operands_cache + .get(idx) + .and_then(|ops| ops.get(operand_idx)) + .ok_or_else(|| { + Error::CodegenFailed("operands_cache index out of bounds".to_string()) + })?; if let Some(&dep_idx) = ctx.def_map.get(&operand) { // Operand is defined in this block if generated.contains(&dep_idx) { // Already generated - load from storage or stack self.load_var(encoder, ctx.ssa, operand)?; - } else if let SsaOp::Copy { src, .. } = ctx.ops[dep_idx] { + } else if let Some(SsaOp::Copy { src, .. }) = ctx.ops.get(dep_idx) { // Operand is defined by a Copy that hasn't been generated yet. // This can happen with circular dependencies through Copy chains // (e.g., Add needs Copy result, Copy needs Add result). @@ -2926,12 +2951,14 @@ impl SsaCodeGenerator { // This is a cyclic dependency between non-Copy operations, // which indicates invalid SSA. All operations within a block // should have a valid topological order. + let cur_op = ctx.ops.get(idx); + let dep_op = ctx.ops.get(dep_idx); return Err(Error::Deobfuscation(format!( "Cyclic dependency detected in block {}: \ op {:?} needs {:?} (defined by {:?}), \ but that definition is already being processed. \ This indicates invalid SSA form.", - ctx.current_block_idx, ctx.ops[idx], operand, ctx.ops[dep_idx] + ctx.current_block_idx, cur_op, operand, dep_op ))); } @@ -2963,12 +2990,16 @@ impl SsaCodeGenerator { continue; // Already done } + let cur_op = *ctx.ops.get(idx).ok_or_else(|| { + Error::CodegenFailed("ops index out of bounds".to_string()) + })?; + // Generate the operation (operands should be on stack) self.generate_op_core( encoder, ctx.ssa, ctx.current_block_idx, - ctx.ops[idx], + cur_op, next_block_idx, )?; generated.insert(idx); @@ -2977,9 +3008,9 @@ impl SsaCodeGenerator { // and whether it's used outside this block (cross-block use). // Note: Skip Copy instructions - they handle their own storage // in generate_op_core and don't leave a result on the stack. - let is_copy = matches!(ctx.ops[idx], SsaOp::Copy { .. }); + let is_copy = matches!(cur_op, SsaOp::Copy { .. }); if !is_copy { - if let Some(dest) = ctx.ops[idx].dest() { + if let Some(dest) = cur_op.dest() { let uses = use_counts.get(&dest).copied().unwrap_or(0); // Check if this value is used outside the current block. // If so, it must be stored because stack values don't @@ -2998,8 +3029,9 @@ impl SsaCodeGenerator { let next_loads_dest = work_stack.last().is_some_and(|item| match item { CodeGenWorkItem::LoadOperand(consumer_idx, op_idx) => ctx - .operands_cache[*consumer_idx] - .get(*op_idx) + .operands_cache + .get(*consumer_idx) + .and_then(|ops| ops.get(*op_idx)) .is_some_and(|op| *op == dest), _ => false, }); @@ -3039,16 +3071,22 @@ impl SsaCodeGenerator { // 3. The terminator (e.g., Jump) may emit phi stores that depend on the Copy results for &root_idx in &terminator_roots { if !generated.contains(&root_idx) { + let root_operands = ctx.operands_cache.get(root_idx).ok_or_else(|| { + Error::CodegenFailed("operands_cache index out of bounds".to_string()) + })?; + let root_op = *ctx.ops.get(root_idx).ok_or_else(|| { + Error::CodegenFailed("ops index out of bounds".to_string()) + })?; // Load any operands the terminator needs. // Always use load_var which handles buried values correctly. - for &operand in &ctx.operands_cache[root_idx] { + for &operand in root_operands { self.load_var(encoder, ctx.ssa, operand)?; } self.generate_op_core( encoder, ctx.ssa, ctx.current_block_idx, - ctx.ops[root_idx], + root_op, next_block_idx, )?; generated.insert(root_idx); @@ -3103,7 +3141,10 @@ impl SsaCodeGenerator { if generated.contains(©_idx) { continue; } - if let SsaOp::Copy { dest, src } = ctx.ops[copy_idx] { + let Some(&op) = ctx.ops.get(copy_idx) else { + continue; + }; + if let SsaOp::Copy { dest, src } = op { let dest_storage = self.get_or_allocate_storage(ctx.ssa, *dest)?; copies.push(CopyInfo { src_var: *src, @@ -3227,7 +3268,9 @@ impl SsaCodeGenerator { return Ok(()); } - let operands = &ctx.operands_cache[idx]; + let operands = ctx.operands_cache.get(idx).ok_or_else(|| { + Error::CodegenFailed("operands_cache index out of bounds".to_string()) + })?; for &operand in operands { if let Some(&dep_idx) = ctx.def_map.get(&operand) { @@ -3236,8 +3279,12 @@ impl SsaCodeGenerator { encoder, ctx, dep_idx, generated, visiting, use_counts, )?; + let dep_op = *ctx.ops.get(dep_idx).ok_or_else(|| { + Error::CodegenFailed("ops index out of bounds".to_string()) + })?; + // Skip Copy ops - they'll be handled in the parallel copy phase - if matches!(ctx.ops[dep_idx], SsaOp::Copy { .. }) { + if matches!(dep_op, SsaOp::Copy { .. }) { continue; } @@ -3248,17 +3295,24 @@ impl SsaCodeGenerator { self.spill_stack(encoder, ctx.ssa)?; } + let dep_operands = ctx.operands_cache.get(dep_idx).ok_or_else(|| { + Error::CodegenFailed("operands_cache index out of bounds".to_string()) + })?; + // Generate operands for this dependency - for &dep_operand in &ctx.operands_cache[dep_idx] { + for &dep_operand in dep_operands { if let Some(&dep_dep_idx) = ctx.def_map.get(&dep_operand) { + let dep_dep_op = ctx.ops.get(dep_dep_idx).copied(); if generated.contains(&dep_dep_idx) { // Operand's defining op is generated, load it self.load_var(encoder, ctx.ssa, dep_operand)?; - } else if matches!(ctx.ops[dep_dep_idx], SsaOp::Copy { .. }) { + } else if matches!(dep_dep_op, Some(SsaOp::Copy { .. })) { // Operand is a Copy result that hasn't been generated. // Load the Copy's source instead - it should be available. - let SsaOp::Copy { src, .. } = ctx.ops[dep_dep_idx] else { - unreachable!() + let Some(SsaOp::Copy { src, .. }) = dep_dep_op else { + return Err(Error::CodegenFailed( + "Copy op pattern mismatch".to_string(), + )); }; self.load_var(encoder, ctx.ssa, *src)?; } else { @@ -3273,17 +3327,11 @@ impl SsaCodeGenerator { } } - self.generate_op_core( - encoder, - ctx.ssa, - ctx.current_block_idx, - ctx.ops[dep_idx], - None, - )?; + self.generate_op_core(encoder, ctx.ssa, ctx.current_block_idx, dep_op, None)?; generated.insert(dep_idx); // Handle storage for the result - if let Some(dest) = ctx.ops[dep_idx].dest() { + if let Some(dest) = dep_op.dest() { let uses = use_counts.get(&dest).copied().unwrap_or(0); let used_outside_block = self.cross_block_uses.contains(&dest); @@ -3666,7 +3714,7 @@ impl SsaCodeGenerator { .. } => { // num_args: function pointer + the method arguments - let num_args = u8::try_from(args.len() + 1).unwrap_or(u8::MAX); + let num_args = u8::try_from(args.len().saturating_add(1)).unwrap_or(u8::MAX); let has_result = dest.is_some(); encoder.emit_call( "calli", @@ -4406,7 +4454,7 @@ impl SsaCodeGenerator { } => { let elem_size = element_size.max(&1); #[allow(clippy::cast_possible_truncation)] - let num_elements = data.len() / elem_size; + let num_elements = data.len().checked_div(*elem_size).unwrap_or(0); emitter::emit_ldc_i4(encoder, num_elements as i32)?; encoder.emit_instruction("newarr", Some(Operand::Token(*element_type_token)))?; @@ -4617,7 +4665,7 @@ impl SsaCodeGenerator { // No pooled slot available - allocate a fresh one let slot = self.next_local; - self.next_local += 1; + self.next_local = self.next_local.saturating_add(1); self.local_types.insert(slot, var_type.clone()); slot } @@ -4916,7 +4964,7 @@ impl SsaCodeGenerator { // Use intermediate labels for targets that need phi stores let mut switch_labels: Vec = Vec::with_capacity(targets.len()); for (i, &target) in targets.iter().enumerate() { - if needs_intermediate[i] { + if needs_intermediate.get(i).copied().unwrap_or(false) { switch_labels.push(format!("phi_switch_{current_block_idx}_{i}")); } else { switch_labels.push( @@ -4947,7 +4995,7 @@ impl SsaCodeGenerator { // Emit intermediate blocks for targets that need phi stores for (i, &target) in targets.iter().enumerate() { - if needs_intermediate[i] { + if needs_intermediate.get(i).copied().unwrap_or(false) { let intermediate_label = format!("phi_switch_{current_block_idx}_{i}"); encoder.define_label(&intermediate_label)?; self.emit_phi_stores_for_successor(encoder, ssa, current_block_idx, target)?; @@ -4974,7 +5022,7 @@ impl SsaCodeGenerator { // Emit intermediate blocks for targets that need phi stores for (i, &target) in targets.iter().enumerate() { - if needs_intermediate[i] { + if needs_intermediate.get(i).copied().unwrap_or(false) { let intermediate_label = format!("phi_switch_{current_block_idx}_{i}"); encoder.define_label(&intermediate_label)?; self.emit_phi_stores_for_successor(encoder, ssa, current_block_idx, target)?; @@ -5142,7 +5190,7 @@ impl SsaCodeGenerator { // Find copies whose destination is not read by any other pending copy for i in 0..pending.len() { - let Some((_, dst, _)) = pending[i].as_ref() else { + let Some((_, dst, _)) = pending.get(i).and_then(|p| p.as_ref()) else { continue; }; let dst = *dst; @@ -5161,7 +5209,7 @@ impl SsaCodeGenerator { if !dst_is_read { // Safe to schedule this copy - if let Some(copy) = pending[i].take() { + if let Some(copy) = pending.get_mut(i).and_then(Option::take) { ordered_copies.push(copy); made_progress = true; } diff --git a/dotscope/src/compiler/events.rs b/dotscope/src/compiler/events.rs index 7a6e040a..10d5eb98 100644 --- a/dotscope/src/compiler/events.rs +++ b/dotscope/src/compiler/events.rs @@ -386,9 +386,10 @@ impl EventLog { /// Counts events grouped by kind. #[must_use] pub fn count_by_kind(&self) -> HashMap { - let mut counts = HashMap::new(); + let mut counts: HashMap = HashMap::new(); for (_, event) in &self.events { - *counts.entry(event.kind).or_insert(0) += 1; + let entry = counts.entry(event.kind).or_insert(0); + *entry = entry.saturating_add(1); } counts } @@ -399,10 +400,11 @@ impl EventLog { /// iterating the entire log. #[must_use] pub fn count_by_kind_since(&self, offset: usize) -> HashMap { - let mut counts = HashMap::new(); + let mut counts: HashMap = HashMap::new(); for (idx, event) in &self.events { if idx >= offset { - *counts.entry(event.kind).or_insert(0) += 1; + let entry = counts.entry(event.kind).or_insert(0); + *entry = entry.saturating_add(1); } } counts diff --git a/dotscope/src/compiler/passes/algebraic.rs b/dotscope/src/compiler/passes/algebraic.rs index 9a46663f..095e27a4 100644 --- a/dotscope/src/compiler/passes/algebraic.rs +++ b/dotscope/src/compiler/passes/algebraic.rs @@ -139,7 +139,9 @@ impl AlgebraicSimplificationPass { ) { for candidate in candidates { if let Some(block) = ssa.block_mut(candidate.block_idx) { - let instr = &mut block.instructions_mut()[candidate.instr_idx]; + let Some(instr) = block.instructions_mut().get_mut(candidate.instr_idx) else { + continue; + }; let new_op = match candidate.simplification { Simplification::Constant(value) => SsaOp::Const { dest: candidate.dest, diff --git a/dotscope/src/compiler/passes/blockmerge.rs b/dotscope/src/compiler/passes/blockmerge.rs index b1d21a8e..a1c4f248 100644 --- a/dotscope/src/compiler/passes/blockmerge.rs +++ b/dotscope/src/compiler/passes/blockmerge.rs @@ -121,7 +121,7 @@ impl BlockMergingPass { // Maps: (trampoline, ultimate_target) → [predecessor blocks that were redirected] let mut redirected_preds: BTreeMap<(usize, usize), Vec> = BTreeMap::new(); - let mut redirected = 0; + let mut redirected: usize = 0; // Update all branch targets in all blocks for block_idx in 0..ssa.block_count() { @@ -150,7 +150,7 @@ impl BlockMergingPass { .message(format!( "redirected through trampoline: {old_targets:?} -> {new_targets:?}" )); - redirected += 1; + redirected = redirected.saturating_add(1); } } } @@ -216,7 +216,7 @@ impl BlockMergingPass { method_token: Token, changes: &mut EventLog, ) -> usize { - let mut cleared = 0; + let mut cleared: usize = 0; for &block_idx in trampolines.keys() { if let Some(block) = ssa.block_mut(block_idx) { @@ -226,7 +226,7 @@ impl BlockMergingPass { .record(EventKind::BlockRemoved) .at(method_token, block_idx) .message(format!("cleared trampoline block B{block_idx}")); - cleared += 1; + cleared = cleared.saturating_add(1); } } } @@ -250,7 +250,7 @@ impl BlockMergingPass { Self::redirect_to_ultimate_targets(ssa, &trampolines, method_token, changes); let cleared = Self::clear_trampolines(ssa, &trampolines, method_token, changes); - redirected + cleared + redirected.saturating_add(cleared) } /// Merges blocks connected by a single edge. @@ -275,7 +275,7 @@ impl BlockMergingPass { changes: &mut EventLog, max_iterations: usize, ) -> usize { - let mut merged = 0; + let mut merged: usize = 0; // Collect exception handler boundary blocks. // @@ -312,7 +312,7 @@ impl BlockMergingPass { // Iterate until fixed point. for _ in 0..max_iterations { - let mut iteration_merges = 0; + let mut iteration_merges: usize = 0; // Build predecessor counts for all blocks. let block_count = ssa.block_count(); @@ -326,13 +326,19 @@ impl BlockMergingPass { .unwrap_or_default(); for succ in successors { if succ < block_count { - pred_counts[succ] += 1; - pred_of[succ] = Some(idx); + if let Some(c) = pred_counts.get_mut(succ) { + *c = c.saturating_add(1); + } + if let Some(p) = pred_of.get_mut(succ) { + *p = Some(idx); + } } } } // Entry block has an implicit edge. - pred_counts[0] += 1; + if let Some(c) = pred_counts.get_mut(0) { + *c = c.saturating_add(1); + } // Find mergeable pairs: A -> B where A's terminator is Jump(B), // B has exactly 1 predecessor, and neither is a handler boundary. @@ -349,7 +355,7 @@ impl BlockMergingPass { if b_idx >= block_count || b_idx == a_idx { continue; } - if pred_counts[b_idx] != 1 { + if pred_counts.get(b_idx).copied().unwrap_or(0) != 1 { continue; } if no_merge_from.contains(a_idx) || no_merge_into.contains(b_idx) { @@ -443,10 +449,10 @@ impl BlockMergingPass { .at(method_token, b_idx) .message(format!("coalesced B{b_idx} into B{a_idx}")); - iteration_merges += 1; + iteration_merges = iteration_merges.saturating_add(1); } - merged += iteration_merges; + merged = merged.saturating_add(iteration_merges); if iteration_merges == 0 { break; } @@ -483,7 +489,7 @@ impl BlockMergingPass { let preds = ssa.block_predecessors(target); let target_has_phis = ssa.block(target).is_none_or(|b| !b.phi_nodes().is_empty()); - if preds.len() == 1 && preds[0] == 0 && !target_has_phis { + if preds.len() == 1 && preds.first().copied() == Some(0) && !target_has_phis { // Safe to inline: B_target's only external predecessor is B0 and it // has no phis. Move B_target's instructions into B0. let target_instrs = ssa diff --git a/dotscope/src/compiler/passes/constants/mod.rs b/dotscope/src/compiler/passes/constants/mod.rs index 59504568..77ad63ad 100644 --- a/dotscope/src/compiler/passes/constants/mod.rs +++ b/dotscope/src/compiler/passes/constants/mod.rs @@ -307,7 +307,9 @@ impl ConstantPropagationPass { // Apply the transformations for (block_idx, instr_idx, result) in transformations { if let Some(block) = ssa.block_mut(block_idx) { - let instr = &mut block.instructions_mut()[instr_idx]; + let Some(instr) = block.instructions_mut().get_mut(instr_idx) else { + continue; + }; let old_op_str = format!("{}", instr.op()); match result { @@ -413,18 +415,19 @@ impl ConstantPropagationPass { constants.insert(dest, value.clone()); if let Some(block) = ssa.block_mut(block_idx) { - let instr = &mut block.instructions_mut()[instr_idx]; - let old_op_str = format!("{}", instr.op()); + if let Some(instr) = block.instructions_mut().get_mut(instr_idx) { + let old_op_str = format!("{}", instr.op()); - instr.set_op(SsaOp::Const { - dest, - value: value.clone(), - }); + instr.set_op(SsaOp::Const { + dest, + value: value.clone(), + }); - changes - .record(EventKind::ConstantFolded) - .at(method_token, instr_idx) - .message(format!("{old_op_str} → {value} (conv)")); + changes + .record(EventKind::ConstantFolded) + .at(method_token, instr_idx) + .message(format!("{old_op_str} → {value} (conv)")); + } } } } @@ -589,23 +592,24 @@ impl ConstantPropagationPass { reason, } => { if let Some(block) = ssa.block_mut(block_idx) { - let instr = &mut block.instructions_mut()[instr_idx]; - let old_op_str = format!("{}", instr.op()); + if let Some(instr) = block.instructions_mut().get_mut(instr_idx) { + let old_op_str = format!("{}", instr.op()); - instr.set_op(SsaOp::Conv { - dest, - operand: new_operand, - target: target.clone(), - overflow_check: false, - unsigned, - }); + instr.set_op(SsaOp::Conv { + dest, + operand: new_operand, + target: target.clone(), + overflow_check: false, + unsigned, + }); - changes - .record(EventKind::ConstantFolded) - .at(method_token, instr_idx) - .message(format!( - "{old_op_str} → conv.{target} {new_operand} ({reason})" - )); + changes + .record(EventKind::ConstantFolded) + .at(method_token, instr_idx) + .message(format!( + "{old_op_str} → conv.{target} {new_operand} ({reason})" + )); + } } } ConvTransform::ReplaceWithCopy { @@ -616,15 +620,16 @@ impl ConstantPropagationPass { reason, } => { if let Some(block) = ssa.block_mut(block_idx) { - let instr = &mut block.instructions_mut()[instr_idx]; - let old_op_str = format!("{}", instr.op()); + if let Some(instr) = block.instructions_mut().get_mut(instr_idx) { + let old_op_str = format!("{}", instr.op()); - instr.set_op(SsaOp::Copy { dest, src }); + instr.set_op(SsaOp::Copy { dest, src }); - changes - .record(EventKind::ConstantFolded) - .at(method_token, instr_idx) - .message(format!("{old_op_str} → copy {src} ({reason})")); + changes + .record(EventKind::ConstantFolded) + .at(method_token, instr_idx) + .message(format!("{old_op_str} → copy {src} ({reason})")); + } } } } @@ -763,18 +768,19 @@ impl ConstantPropagationPass { constants.insert(dest, value.clone()); if let Some(block) = ssa.block_mut(block_idx) { - let instr = &mut block.instructions_mut()[instr_idx]; - let old_op_str = format!("{}", instr.op()); + if let Some(instr) = block.instructions_mut().get_mut(instr_idx) { + let old_op_str = format!("{}", instr.op()); - instr.set_op(SsaOp::Const { - dest, - value: value.clone(), - }); + instr.set_op(SsaOp::Const { + dest, + value: value.clone(), + }); - changes - .record(EventKind::ConstantFolded) - .at(method_token, instr_idx) - .message(format!("{old_op_str} → {value} (ovf)")); + changes + .record(EventKind::ConstantFolded) + .at(method_token, instr_idx) + .message(format!("{old_op_str} → {value} (ovf)")); + } } } } @@ -940,7 +946,10 @@ impl ConstantPropagationPass { instr.set_op(SsaOp::Const { dest, value }); changes .record(EventKind::ConstantFolded) - .at(method_token, block_idx * 1000 + instr_idx) + .at( + method_token, + block_idx.saturating_mul(1000).saturating_add(instr_idx), + ) .message("folded pure call with constant arguments"); } } @@ -996,11 +1005,9 @@ impl ConstantPropagationPass { if r != 0 { #[allow(clippy::cast_sign_loss)] let result = if *unsigned { - ((l as u64) % (r as u64)) as i64 - } else if l != i64::MIN || r != -1 { - l % r + (l as u64).checked_rem(r as u64).unwrap_or(0) as i64 } else { - 0 + l.checked_rem(r).unwrap_or(0) }; #[allow(clippy::cast_possible_truncation)] let value = ConstValue::I32(result as i32); @@ -1018,7 +1025,10 @@ impl ConstantPropagationPass { instr.set_op(SsaOp::Const { dest, value }); changes .record(EventKind::ConstantFolded) - .at(method_token, block_idx * 1000 + instr_idx) + .at( + method_token, + block_idx.saturating_mul(1000).saturating_add(instr_idx), + ) .message("folded arithmetic with constant operands"); } } @@ -1111,18 +1121,19 @@ impl ConstantPropagationPass { constants.insert(dest, value.clone()); if let Some(block) = ssa.block_mut(block_idx) { - let instr = &mut block.instructions_mut()[instr_idx]; - let old_op_str = format!("{}", instr.op()); + if let Some(instr) = block.instructions_mut().get_mut(instr_idx) { + let old_op_str = format!("{}", instr.op()); - instr.set_op(SsaOp::Const { - dest, - value: value.clone(), - }); + instr.set_op(SsaOp::Const { + dest, + value: value.clone(), + }); - changes - .record(EventKind::ConstantFolded) - .at(method_token, instr_idx) - .message(format!("{old_op_str} → {value} (string fold)")); + changes + .record(EventKind::ConstantFolded) + .at(method_token, instr_idx) + .message(format!("{old_op_str} → {value} (string fold)")); + } } } } @@ -1208,50 +1219,41 @@ impl ConstantPropagationPass { Some((dest, ConstValue::DecryptedString(result))) } StringFoldOp::SubstringFrom => { - let this_str = constants.get(&args[0])?.as_string_content(assembly)?; + let this_str = constants.get(args.first()?)?.as_string_content(assembly)?; // Bail on non-ASCII: .NET uses UTF-16 char indices, Rust uses bytes. if !this_str.is_ascii() { return None; } - let start = constants.get(&args[1])?.as_i32()? as usize; - if start > this_str.len() { - return None; - } - Some(( - dest, - ConstValue::DecryptedString(this_str[start..].to_string()), - )) + let start = constants.get(args.get(1)?)?.as_i32()? as usize; + let tail = this_str.get(start..)?; + Some((dest, ConstValue::DecryptedString(tail.to_string()))) } StringFoldOp::SubstringRange => { - let this_str = constants.get(&args[0])?.as_string_content(assembly)?; + let this_str = constants.get(args.first()?)?.as_string_content(assembly)?; if !this_str.is_ascii() { return None; } - let start = constants.get(&args[1])?.as_i32()? as usize; - let len = constants.get(&args[2])?.as_i32()? as usize; - if start.saturating_add(len) > this_str.len() { - return None; - } - Some(( - dest, - ConstValue::DecryptedString(this_str[start..start + len].to_string()), - )) + let start = constants.get(args.get(1)?)?.as_i32()? as usize; + let len = constants.get(args.get(2)?)?.as_i32()? as usize; + let end = start.checked_add(len)?; + let slice = this_str.get(start..end)?; + Some((dest, ConstValue::DecryptedString(slice.to_string()))) } StringFoldOp::Replace => { - let this_str = constants.get(&args[0])?.as_string_content(assembly)?; - let old = constants.get(&args[1])?.as_string_content(assembly)?; - let new = constants.get(&args[2])?.as_string_content(assembly)?; + let this_str = constants.get(args.first()?)?.as_string_content(assembly)?; + let old = constants.get(args.get(1)?)?.as_string_content(assembly)?; + let new = constants.get(args.get(2)?)?.as_string_content(assembly)?; Some(( dest, ConstValue::DecryptedString(this_str.replace(&old, &new)), )) } StringFoldOp::ToLower => { - let this_str = constants.get(&args[0])?.as_string_content(assembly)?; + let this_str = constants.get(args.first()?)?.as_string_content(assembly)?; Some((dest, ConstValue::DecryptedString(this_str.to_lowercase()))) } StringFoldOp::ToUpper => { - let this_str = constants.get(&args[0])?.as_string_content(assembly)?; + let this_str = constants.get(args.first()?)?.as_string_content(assembly)?; Some((dest, ConstValue::DecryptedString(this_str.to_uppercase()))) } } @@ -1348,12 +1350,14 @@ impl ConstantPropagationPass { definitions.insert(dest, (block_idx, instr_idx)); } for use_var in instr.op().uses() { - *use_counts.entry(use_var).or_default() += 1; + let slot = use_counts.entry(use_var).or_default(); + *slot = slot.saturating_add(1); } } for phi in ssa.all_phi_nodes() { for operand in phi.operands() { - *use_counts.entry(operand.value()).or_default() += 1; + let slot = use_counts.entry(operand.value()).or_default(); + *slot = slot.saturating_add(1); } } @@ -1428,7 +1432,10 @@ impl ConstantPropagationPass { break; }; - let inner = match def_block_ref.instructions()[def_instr].op() { + let Some(def_instr_ref) = def_block_ref.instructions().get(def_instr) else { + break; + }; + let inner = match def_instr_ref.op() { SsaOp::Neg { dest: d, operand: inner, @@ -1486,22 +1493,25 @@ impl ConstantPropagationPass { // Odd chain: one operation remains, rewrite outermost to use innermost_operand let (b, i) = t.outermost_location; if let Some(block) = ssa.block_mut(b) { - let instr = &mut block.instructions_mut()[i]; - if t.is_neg { - instr.set_op(SsaOp::Neg { - dest: t.outermost_dest, - operand: t.innermost_operand, - }); - } else { - instr.set_op(SsaOp::Not { - dest: t.outermost_dest, - operand: t.innermost_operand, - }); + if let Some(instr) = block.instructions_mut().get_mut(i) { + if t.is_neg { + instr.set_op(SsaOp::Neg { + dest: t.outermost_dest, + operand: t.innermost_operand, + }); + } else { + instr.set_op(SsaOp::Not { + dest: t.outermost_dest, + operand: t.innermost_operand, + }); + } } } // Nop all except the outermost - for &(b, i) in &t.instructions_to_nop[1..] { - ssa.remove_instruction(b, i); + if let Some(rest) = t.instructions_to_nop.get(1..) { + for &(b, i) in rest { + ssa.remove_instruction(b, i); + } } changes .record(EventKind::ConstantFolded) @@ -1541,7 +1551,7 @@ impl ConstantPropagationPass { } // Check for back-edges from blocks with higher indices. - for bi in (block_idx + 1)..ssa.block_count() { + for bi in block_idx.saturating_add(1)..ssa.block_count() { if let Some(block) = ssa.block(bi) { if let Some(op) = block.terminator_op() { let targets_block = match op { diff --git a/dotscope/src/compiler/passes/controlflow.rs b/dotscope/src/compiler/passes/controlflow.rs index 4ee9655d..45f882da 100644 --- a/dotscope/src/compiler/passes/controlflow.rs +++ b/dotscope/src/compiler/passes/controlflow.rs @@ -133,8 +133,12 @@ impl ControlFlowSimplificationPass { .copied() .filter(|&t| t != block_idx) .collect(); - if !non_self.is_empty() && non_self.iter().all(|t| *t == non_self[0]) { - Some((block_idx, non_self[0])) + if let Some(&first) = non_self.first() { + if non_self.iter().all(|t| *t == first) { + Some((block_idx, first)) + } else { + None + } } else { None } @@ -172,7 +176,7 @@ impl ControlFlowSimplificationPass { .map(|&t| (t, resolve_chain(trampolines, t))) .collect(); - let mut threaded_count = 0; + let mut threaded_count: usize = 0; for block_idx in 0..ssa.block_count() { if let Some(block) = ssa.block_mut(block_idx) { @@ -194,7 +198,7 @@ impl ControlFlowSimplificationPass { .record(EventKind::ControlFlowRestructured) .at(method_token, block_idx) .message(format!("jump threaded: {old_targets:?} -> {new_targets:?}")); - threaded_count += 1; + threaded_count = threaded_count.saturating_add(1); } } } @@ -223,7 +227,7 @@ impl ControlFlowSimplificationPass { method_token: Token, changes: &mut EventLog, ) -> usize { - let mut simplified_count = 0; + let mut simplified_count: usize = 0; for &(block_idx, target) in same_target_branches { if let Some(block) = ssa.block_mut(block_idx) { @@ -235,7 +239,7 @@ impl ControlFlowSimplificationPass { .message(format!( "branch to same target simplified: B{block_idx} branch -> jump B{target}" )); - simplified_count += 1; + simplified_count = simplified_count.saturating_add(1); } } } @@ -261,7 +265,7 @@ impl ControlFlowSimplificationPass { method_token: Token, changes: &mut EventLog, ) -> usize { - let mut removed_count = 0; + let mut removed_count: usize = 0; for &(block_idx, start_idx) in dead_tails { if let Some(block) = ssa.block_mut(block_idx) { @@ -269,7 +273,7 @@ impl ControlFlowSimplificationPass { let to_remove = instr_count.saturating_sub(start_idx); for _ in 0..to_remove { block.instructions_mut().pop(); - removed_count += 1; + removed_count = removed_count.saturating_add(1); } if to_remove > 0 { changes @@ -297,29 +301,39 @@ impl ControlFlowSimplificationPass { /// /// The total number of changes made during this iteration. fn run_iteration(ssa: &mut SsaFunction, method_token: Token, changes: &mut EventLog) -> usize { - let mut total_changes = 0; + let mut total_changes: usize = 0; // Step 1: Find and apply jump threading (don't skip entry block) let trampolines = ssa.find_trampoline_blocks(false); if !trampolines.is_empty() { - total_changes += Self::apply_jump_threading(ssa, &trampolines, method_token, changes); + total_changes = total_changes.saturating_add(Self::apply_jump_threading( + ssa, + &trampolines, + method_token, + changes, + )); } // Step 2: Simplify branches to same target (also resolves through trampolines) let same_target_branches = Self::find_same_target_branches(ssa, &trampolines); if !same_target_branches.is_empty() { - total_changes += Self::simplify_same_target_branches( + total_changes = total_changes.saturating_add(Self::simplify_same_target_branches( ssa, &same_target_branches, method_token, changes, - ); + )); } // Step 3: Remove dead tails let dead_tails = find_dead_tails(ssa); if !dead_tails.is_empty() { - total_changes += Self::remove_dead_tails(ssa, &dead_tails, method_token, changes); + total_changes = total_changes.saturating_add(Self::remove_dead_tails( + ssa, + &dead_tails, + method_token, + changes, + )); } total_changes diff --git a/dotscope/src/compiler/passes/copying.rs b/dotscope/src/compiler/passes/copying.rs index 2c0b1f18..15d407b7 100644 --- a/dotscope/src/compiler/passes/copying.rs +++ b/dotscope/src/compiler/passes/copying.rs @@ -93,7 +93,7 @@ impl CopyPropagationPass { /// `rebuild_ssa` calls then assign the entry value (null) to all uses /// of Local(0), corrupting the data flow. fn protect_sole_local_defs(ssa: &SsaFunction, copies: &mut BTreeMap) { - let real_local_limit = (ssa.num_args() + ssa.num_locals()) as u32; + let real_local_limit = ssa.num_args().saturating_add(ssa.num_locals()) as u32; // Count instruction-based definitions per local/argument group. let mut group_def_count: BTreeMap = BTreeMap::new(); @@ -102,7 +102,8 @@ impl CopyPropagationPass { if let Some(dest) = instr.op().dest() { let group = ssa.rename_group(dest); if group < real_local_limit { - *group_def_count.entry(group).or_insert(0) += 1; + let counter = group_def_count.entry(group).or_insert(0); + *counter = counter.saturating_add(1); } } } @@ -112,7 +113,10 @@ impl CopyPropagationPass { // These are the groups at risk: after copy-prop eliminates the bridging // Copy, rebuild_ssa's rename may not be able to reconstruct the correct // reaching definition through phi chains. - let group_bound = ssa.num_locals() + ssa.num_args() + 1; + let group_bound = ssa + .num_locals() + .saturating_add(ssa.num_args()) + .saturating_add(1); let mut groups_in_phis = BitSet::new(group_bound); for block in ssa.blocks() { for phi in block.phi_nodes() { @@ -138,7 +142,7 @@ impl CopyPropagationPass { for block in ssa.blocks() { for instr in block.instructions() { if let SsaOp::LoadLocalAddr { local_index, .. } = instr.op() { - let group = ssa.num_args() as u32 + *local_index as u32; + let group = (ssa.num_args() as u32).saturating_add(*local_index as u32); if group < real_local_limit { address_taken_groups.insert(group as usize); } diff --git a/dotscope/src/compiler/passes/deadcode.rs b/dotscope/src/compiler/passes/deadcode.rs index f2882472..f0793ed0 100644 --- a/dotscope/src/compiler/passes/deadcode.rs +++ b/dotscope/src/compiler/passes/deadcode.rs @@ -60,10 +60,11 @@ pub fn find_dead_tails(ssa: &SsaFunction) -> Vec<(usize, usize)> { ssa.iter_blocks() .filter_map(|(block_idx, block)| { // Find first terminator + let last_idx = block.instruction_count().checked_sub(1)?; for (instr_idx, instr) in block.instructions().iter().enumerate() { - if instr.op().is_terminator() && instr_idx < block.instruction_count() - 1 { + if instr.op().is_terminator() && instr_idx < last_idx { // There are instructions after the terminator - return Some((block_idx, instr_idx + 1)); + return Some((block_idx, instr_idx.saturating_add(1))); } } None @@ -565,9 +566,10 @@ impl DeadCodeEliminationPass { format!("dead {}", instr.mnemonic()) }; instr.set_op(SsaOp::Nop); + let location = block_idx.saturating_mul(1000).saturating_add(instr_idx); changes .record(EventKind::InstructionRemoved) - .at(method_token, block_idx * 1000 + instr_idx) + .at(method_token, location) .message(message); } } @@ -647,7 +649,7 @@ impl DeadCodeEliminationPass { method_token: Token, changes: &mut EventLog, ) -> usize { - let mut simplified = 0; + let mut simplified: usize = 0; // Process in reverse order by phi_idx within each block let mut by_block: BTreeMap)>> = BTreeMap::new(); @@ -670,7 +672,7 @@ impl DeadCodeEliminationPass { .record(EventKind::PhiSimplified) .at(method_token, block_idx) .message(format!("replaced with {replacement_var}")); - simplified += 1; + simplified = simplified.saturating_add(1); } } else { // All self-references - just remove the phi @@ -679,7 +681,7 @@ impl DeadCodeEliminationPass { .record(EventKind::PhiSimplified) .at(method_token, block_idx) .message("removed self-referential phi"); - simplified += 1; + simplified = simplified.saturating_add(1); } } } @@ -710,7 +712,7 @@ impl DeadCodeEliminationPass { method_token: Token, changes: &mut EventLog, ) -> usize { - let mut cleared = 0; + let mut cleared: usize = 0; let total_blocks = ssa.block_count(); for block_idx in 0..total_blocks { @@ -722,7 +724,7 @@ impl DeadCodeEliminationPass { .record(EventKind::BlockRemoved) .at(method_token, block_idx) .message(format!("removed unreachable block {block_idx}")); - cleared += 1; + cleared = cleared.saturating_add(1); } } } @@ -806,7 +808,7 @@ impl DeadCodeEliminationPass { by_block.entry(block_idx).or_default().push(instr_idx); } - let mut removed = 0; + let mut removed: usize = 0; for (block_idx, mut indices) in by_block { // Sort in reverse order to remove from end first (preserves indices) @@ -822,11 +824,12 @@ impl DeadCodeEliminationPass { .map_or("unknown", SsaInstruction::mnemonic); block.instructions_mut().remove(instr_idx); + let location = block_idx.saturating_mul(1000).saturating_add(instr_idx); changes .record(EventKind::InstructionRemoved) - .at(method_token, block_idx * 1000 + instr_idx) + .at(method_token, location) .message(format!("removed op-less instruction: {mnemonic}")); - removed += 1; + removed = removed.saturating_add(1); } } } @@ -857,7 +860,7 @@ impl DeadCodeEliminationPass { method_token: Token, changes: &mut EventLog, ) -> usize { - let mut removed = 0; + let mut removed: usize = 0; for block_idx in reachable.iter() { if let Some(block) = ssa.block_mut(block_idx) { @@ -866,14 +869,14 @@ impl DeadCodeEliminationPass { .instructions_mut() .retain(|instr| !matches!(instr.op(), SsaOp::Nop)); let new_len = block.instructions().len(); - let nops_removed = original_len - new_len; + let nops_removed = original_len.saturating_sub(new_len); if nops_removed > 0 { changes .record(EventKind::InstructionRemoved) .at(method_token, block_idx) .message(format!("removed {nops_removed} Nop instructions")); - removed += nops_removed; + removed = removed.saturating_add(nops_removed); } } } @@ -908,30 +911,50 @@ impl DeadCodeEliminationPass { /// The total number of changes made during this iteration. Zero indicates /// the algorithm has reached a fixed point. fn run_iteration(ssa: &mut SsaFunction, method_token: Token, changes: &mut EventLog) -> usize { - let mut total_changes = 0; + let mut total_changes: usize = 0; // Step 1: Find reachable blocks let reachable = Self::find_reachable_blocks(ssa); // Step 2: Clear unreachable blocks - total_changes += Self::clear_unreachable_blocks(ssa, &reachable, method_token, changes); + total_changes = total_changes.saturating_add(Self::clear_unreachable_blocks( + ssa, + &reachable, + method_token, + changes, + )); // Step 3: Remove op-less instructions (stack simulation artifacts like ldloc/ldarg // that weren't decomposed to SSA operations) let opless = Self::find_opless_instructions(ssa, &reachable); - total_changes += Self::remove_opless_instructions(ssa, &opless, method_token, changes); + total_changes = total_changes.saturating_add(Self::remove_opless_instructions( + ssa, + &opless, + method_token, + changes, + )); // Step 4: Remove Nop instructions (simplifies CFG for block merging) - total_changes += Self::remove_nop_instructions(ssa, &reachable, method_token, changes); + total_changes = total_changes.saturating_add(Self::remove_nop_instructions( + ssa, + &reachable, + method_token, + changes, + )); // Step 5: Prune phi operands from unreachable predecessors - total_changes += ssa.prune_phi_operands(&reachable); + total_changes = total_changes.saturating_add(ssa.prune_phi_operands(&reachable)); let reachable_set: BTreeSet = reachable.iter().collect(); // Step 6: Find and simplify trivial phis (doesn't need liveness) // Trivial phis are identified purely by structure (all operands same or self-referential) let trivial_phis = PhiAnalyzer::new(ssa).find_all_trivial(&reachable_set); - total_changes += Self::simplify_trivial_phis(ssa, &trivial_phis, method_token, changes); + total_changes = total_changes.saturating_add(Self::simplify_trivial_phis( + ssa, + &trivial_phis, + method_token, + changes, + )); // Step 7: Recompute reachability after phi simplification let reachable = Self::find_reachable_blocks(ssa); @@ -956,13 +979,13 @@ impl DeadCodeEliminationPass { } Self::remove_phis(ssa, &dead_phis, method_token, changes); - total_changes += dead_phis.len(); + total_changes = total_changes.saturating_add(dead_phis.len()); // Step 10: Find and remove dead definitions (pure ops with unused results) let dead_defs = Self::find_dead_definitions(ssa, &reachable, &live, &dead_phi_results); let c10 = dead_defs.len(); Self::remove_instructions(ssa, &dead_defs, method_token, changes); - total_changes += c10; + total_changes = total_changes.saturating_add(c10); // Step 10b: Clean up Nops created by remove_instructions (which replaces // dead instructions with Nop to preserve indices). Without this, the next diff --git a/dotscope/src/compiler/passes/gvn.rs b/dotscope/src/compiler/passes/gvn.rs index 54fc6ddf..5830c855 100644 --- a/dotscope/src/compiler/passes/gvn.rs +++ b/dotscope/src/compiler/passes/gvn.rs @@ -164,7 +164,7 @@ impl GlobalValueNumberingPass { // including phi operands, then nop-out the dead instruction. // This prevents ping-ponging with DCE: without nop-out, DCE would find // the dead instruction as "new work" on the next normalization iteration. - let mut total_replaced = 0; + let mut total_replaced: usize = 0; for (redundant_var, original_var, block_idx, instr_idx) in &redundant { let result = ssa.replace_uses_including_phis(*redundant_var, *original_var); if result.replaced > 0 { @@ -175,7 +175,7 @@ impl GlobalValueNumberingPass { "GVN: {redundant_var} → {original_var} ({} uses)", result.replaced )); - total_replaced += result.replaced; + total_replaced = total_replaced.saturating_add(result.replaced); } // Nop-out the redundant instruction so rebuild_ssa's strip_nops // removes it. This avoids leaving dead instructions for DCE to find. diff --git a/dotscope/src/compiler/passes/inlining.rs b/dotscope/src/compiler/passes/inlining.rs index e65a1be2..a864af61 100644 --- a/dotscope/src/compiler/passes/inlining.rs +++ b/dotscope/src/compiler/passes/inlining.rs @@ -405,27 +405,29 @@ impl<'a> InliningContext<'a> { ReturnInfo::Constant(value) => { if let Some(dest_var) = dest { if let Some(block) = self.caller_ssa.block_mut(call_block_idx) { - let instr = &mut block.instructions_mut()[call_instr_idx]; - instr.set_op(SsaOp::Const { - dest: dest_var, - value: value.clone(), - }); - self.changes - .record(EventKind::MethodInlined) - .at(self.caller_token, call_instr_idx) - .message(format!("inlined constant {callee_token:?}")); - return true; + if let Some(instr) = block.instructions_mut().get_mut(call_instr_idx) { + instr.set_op(SsaOp::Const { + dest: dest_var, + value: value.clone(), + }); + self.changes + .record(EventKind::MethodInlined) + .at(self.caller_token, call_instr_idx) + .message(format!("inlined constant {callee_token:?}")); + return true; + } } } else { // Void destination but constant return - just remove the call if let Some(block) = self.caller_ssa.block_mut(call_block_idx) { - let instr = &mut block.instructions_mut()[call_instr_idx]; - instr.set_op(SsaOp::Nop); - self.changes - .record(EventKind::MethodInlined) - .at(self.caller_token, call_instr_idx) - .message(format!("eliminated pure call {callee_token:?}")); - return true; + if let Some(instr) = block.instructions_mut().get_mut(call_instr_idx) { + instr.set_op(SsaOp::Nop); + self.changes + .record(EventKind::MethodInlined) + .at(self.caller_token, call_instr_idx) + .message(format!("eliminated pure call {callee_token:?}")); + return true; + } } } } @@ -433,29 +435,31 @@ impl<'a> InliningContext<'a> { if let Some(dest_var) = dest { if let Some(&src_var) = args.get(param_idx) { if let Some(block) = self.caller_ssa.block_mut(call_block_idx) { - let instr = &mut block.instructions_mut()[call_instr_idx]; - instr.set_op(SsaOp::Copy { - dest: dest_var, - src: src_var, - }); - self.changes - .record(EventKind::MethodInlined) - .at(self.caller_token, call_instr_idx) - .message(format!("inlined passthrough {callee_token:?}")); - return true; + if let Some(instr) = block.instructions_mut().get_mut(call_instr_idx) { + instr.set_op(SsaOp::Copy { + dest: dest_var, + src: src_var, + }); + self.changes + .record(EventKind::MethodInlined) + .at(self.caller_token, call_instr_idx) + .message(format!("inlined passthrough {callee_token:?}")); + return true; + } } } } } ReturnInfo::Void => { if let Some(block) = self.caller_ssa.block_mut(call_block_idx) { - let instr = &mut block.instructions_mut()[call_instr_idx]; - instr.set_op(SsaOp::Nop); - self.changes - .record(EventKind::MethodInlined) - .at(self.caller_token, call_instr_idx) - .message(format!("eliminated void call {callee_token:?}")); - return true; + if let Some(instr) = block.instructions_mut().get_mut(call_instr_idx) { + instr.set_op(SsaOp::Nop); + self.changes + .record(EventKind::MethodInlined) + .at(self.caller_token, call_instr_idx) + .message(format!("eliminated void call {callee_token:?}")); + return true; + } } } ReturnInfo::PureComputation | ReturnInfo::Dynamic | ReturnInfo::Unknown => { @@ -549,16 +553,22 @@ impl<'a> InliningContext<'a> { }; // Replace call instruction with first inlined op (or Nop if empty) - if let Some(first_op) = inlined_ops.first().cloned() { - block.instructions_mut()[call_instr_idx].set_op(first_op); + let first_op = inlined_ops.first().cloned(); + if let Some(instr) = block.instructions_mut().get_mut(call_instr_idx) { + if let Some(op) = first_op { + instr.set_op(op); + } else { + instr.set_op(SsaOp::Nop); + } } else { - block.instructions_mut()[call_instr_idx].set_op(SsaOp::Nop); + return false; } // Insert remaining inlined ops let instructions = block.instructions_mut(); + let base = call_instr_idx.saturating_add(1); for (i, op) in inlined_ops.into_iter().skip(1).enumerate() { - instructions.insert(call_instr_idx + 1 + i, SsaInstruction::synthetic(op)); + instructions.insert(base.saturating_add(i), SsaInstruction::synthetic(op)); } // Handle return value @@ -569,7 +579,7 @@ impl<'a> InliningContext<'a> { let Some(block) = self.caller_ssa.block_mut(call_block_idx) else { return false; }; - let insert_pos = call_instr_idx + 1; + let insert_pos = call_instr_idx.saturating_add(1); block.instructions_mut().insert( insert_pos, SsaInstruction::synthetic(SsaOp::Copy { diff --git a/dotscope/src/compiler/passes/licm.rs b/dotscope/src/compiler/passes/licm.rs index e5da92bc..2e7a7b65 100644 --- a/dotscope/src/compiler/passes/licm.rs +++ b/dotscope/src/compiler/passes/licm.rs @@ -105,7 +105,7 @@ impl SsaPass for LicmPass { return Ok(false); } - let mut total_hoisted = 0; + let mut total_hoisted: usize = 0; // Process loops from innermost to outermost. // This naturally propagates hoists through nesting levels: inner hoists @@ -221,7 +221,8 @@ impl SsaPass for LicmPass { // Count hoistable instructions per block let mut hoist_count_per_block: HashMap = HashMap::new(); for (block_idx, _) in &hoistable { - *hoist_count_per_block.entry(*block_idx).or_insert(0) += 1; + let entry = hoist_count_per_block.entry(*block_idx).or_insert(0); + *entry = entry.saturating_add(1); } // Find blocks that would become trampolines @@ -299,7 +300,7 @@ impl SsaPass for LicmPass { if let Some(preheader_block) = ssa.block_mut(preheader.index()) { let new_instr = SsaInstruction::synthetic(op.clone()); let instrs = preheader_block.instructions_mut(); - instrs.insert(insert_base + i, new_instr); + instrs.insert(insert_base.saturating_add(i), new_instr); } // Remove from original location (replace with Nop) @@ -309,7 +310,7 @@ impl SsaPass for LicmPass { } } - total_hoisted += 1; + total_hoisted = total_hoisted.saturating_add(1); } // Update phi operands at successor blocks. When all non-terminator diff --git a/dotscope/src/compiler/passes/loopcanon.rs b/dotscope/src/compiler/passes/loopcanon.rs index 2c554908..568835b6 100644 --- a/dotscope/src/compiler/passes/loopcanon.rs +++ b/dotscope/src/compiler/passes/loopcanon.rs @@ -107,7 +107,7 @@ impl LoopCanonicalizationPass { method_token: Token, changes: &mut EventLog, ) -> usize { - let mut total_modified = 0; + let mut total_modified: usize = 0; // We need to iterate until no more changes because inserting blocks // can affect loop structure @@ -118,7 +118,7 @@ impl LoopCanonicalizationPass { break; } - let mut modified_this_iteration = 0; + let mut modified_this_iteration: usize = 0; // Process loops from innermost to outermost to avoid invalidating // parent loop analysis when modifying inner loops @@ -134,7 +134,7 @@ impl LoopCanonicalizationPass { method_token, changes, ); - modified_this_iteration += 1; + modified_this_iteration = modified_this_iteration.saturating_add(1); // After inserting a preheader, we need to re-analyze loops break; } @@ -143,13 +143,13 @@ impl LoopCanonicalizationPass { // Check if this loop needs latch unification if !loop_info.has_single_latch() && loop_info.latches.len() > 1 { Self::unify_latches(ssa, loop_info, method_token, changes); - modified_this_iteration += 1; + modified_this_iteration = modified_this_iteration.saturating_add(1); // After unifying latches, we need to re-analyze loops break; } } - total_modified += modified_this_iteration; + total_modified = total_modified.saturating_add(modified_this_iteration); if modified_this_iteration == 0 { break; @@ -296,9 +296,9 @@ impl LoopCanonicalizationPass { // For non-loop values: if there was a phi created in preheader, // reference that phi's result; otherwise reference the single value if !non_loop_values.is_empty() { - if non_loop_values.len() == 1 { + if let [single] = non_loop_values.as_slice() { // Single non-loop predecessor: just update the predecessor - operands.push(PhiOperand::new(non_loop_values[0].value(), preheader_idx)); + operands.push(PhiOperand::new(single.value(), preheader_idx)); } else if let Some(&preheader_var) = preheader_phi_map.get(&origin) { // Multiple non-loop predecessors: use the phi we created in preheader operands.push(PhiOperand::new(preheader_var, preheader_idx)); @@ -369,9 +369,9 @@ impl LoopCanonicalizationPass { } latch_phi_vars.insert(*origin, new_var); unified_latch.phi_nodes_mut().push(latch_phi); - } else if latch_operands.len() == 1 { + } else if let [single] = latch_operands.as_slice() { // Single latch operand - just remember its value - latch_phi_vars.insert(*origin, latch_operands[0].value()); + latch_phi_vars.insert(*origin, single.value()); } } diff --git a/dotscope/src/compiler/passes/predicates.rs b/dotscope/src/compiler/passes/predicates.rs index 771c090e..e17253e6 100644 --- a/dotscope/src/compiler/passes/predicates.rs +++ b/dotscope/src/compiler/passes/predicates.rs @@ -515,7 +515,8 @@ impl OpaquePredicatePass { // Nested analysis if let Some(left_op) = left_def { - let left_result = Self::analyze_predicate_with_cache(left_op, cache, depth + 1); + let left_result = + Self::analyze_predicate_with_cache(left_op, cache, depth.saturating_add(1)); if left_result != PredicateResult::Unknown { if let Some(r) = right_def { if Self::is_one_constant(r) { @@ -1591,12 +1592,12 @@ impl OpaquePredicatePass { for block in ssa.blocks() { for phi in block.phi_nodes() { let operands: Vec<_> = phi.operands().iter().collect(); - if operands.is_empty() { + let Some(first_operand) = operands.first() else { continue; - } + }; // Check if all operands come from the same constant - let first_val = operands[0].value(); + let first_val = first_operand.value(); let mut all_same_const = true; let mut const_value = None; @@ -1813,14 +1814,16 @@ impl SsaPass for OpaquePredicatePass { } else { ConstValue::False }; - block.instructions_mut()[instr_idx].set_op(SsaOp::Const { - dest, - value: const_value, - }); - changes - .record(EventKind::ConstantFolded) - .at(method_token, instr_idx) - .message(format!("opaque predicate → {value}")); + if let Some(instr) = block.instructions_mut().get_mut(instr_idx) { + instr.set_op(SsaOp::Const { + dest, + value: const_value, + }); + changes + .record(EventKind::ConstantFolded) + .at(method_token, instr_idx) + .message(format!("opaque predicate → {value}")); + } } } @@ -1829,18 +1832,22 @@ impl SsaPass for OpaquePredicatePass { if let Some(block) = ssa.block_mut(block_idx) { match simplification { ComparisonSimplification::SimplerOp { new_op, reason } => { - block.instructions_mut()[instr_idx].set_op(new_op); - changes - .record(EventKind::ConstantFolded) - .at(method_token, instr_idx) - .message(reason); + if let Some(instr) = block.instructions_mut().get_mut(instr_idx) { + instr.set_op(new_op); + changes + .record(EventKind::ConstantFolded) + .at(method_token, instr_idx) + .message(reason); + } } ComparisonSimplification::Copy { dest, src, reason } => { - block.instructions_mut()[instr_idx].set_op(SsaOp::Copy { dest, src }); - changes - .record(EventKind::ConstantFolded) - .at(method_token, instr_idx) - .message(reason); + if let Some(instr) = block.instructions_mut().get_mut(instr_idx) { + instr.set_op(SsaOp::Copy { dest, src }); + changes + .record(EventKind::ConstantFolded) + .at(method_token, instr_idx) + .message(reason); + } } } } diff --git a/dotscope/src/compiler/passes/proxy.rs b/dotscope/src/compiler/passes/proxy.rs index acee3ffa..1ad53b3b 100644 --- a/dotscope/src/compiler/passes/proxy.rs +++ b/dotscope/src/compiler/passes/proxy.rs @@ -166,20 +166,20 @@ impl ProxyDevirtualizationPass { // Find the call instruction (Call, CallVirt, or NewObj) let mut call_info: Option<(&MethodRef, &[SsaVarId], Option, ForwardKind)> = None; - let mut call_count = 0; + let mut call_count: usize = 0; for instr in instructions { match instr.op() { SsaOp::Call { method, args, dest } => { - call_count += 1; + call_count = call_count.saturating_add(1); call_info = Some((method, args, *dest, ForwardKind::Call)); } SsaOp::CallVirt { method, args, dest } => { - call_count += 1; + call_count = call_count.saturating_add(1); call_info = Some((method, args, *dest, ForwardKind::CallVirt)); } SsaOp::NewObj { ctor, args, dest } => { - call_count += 1; + call_count = call_count.saturating_add(1); call_info = Some((ctor, args, Some(*dest), ForwardKind::NewObj)); } // These are allowed in proxy methods @@ -648,10 +648,19 @@ impl ProxyDevirtualizationPass { // Insert const instructions before the call site let instrs = block.instructions_mut(); for (i, const_op) in const_ops.into_iter().enumerate() { - instrs.insert(call_instr_idx + i, SsaInstruction::synthetic(const_op)); + let Some(insert_idx) = call_instr_idx.checked_add(i) else { + return false; + }; + instrs.insert(insert_idx, SsaInstruction::synthetic(const_op)); } // The call instruction shifted by the number of inserted consts - instrs[call_instr_idx + num_consts].set_op(new_op); + let Some(call_idx) = call_instr_idx.checked_add(num_consts) else { + return false; + }; + let Some(call_instr) = instrs.get_mut(call_idx) else { + return false; + }; + call_instr.set_op(new_op); changes .record(EventKind::MethodInlined) .at(caller_token, call_instr_idx) diff --git a/dotscope/src/compiler/passes/ranges.rs b/dotscope/src/compiler/passes/ranges.rs index 18b29be3..df385cd3 100644 --- a/dotscope/src/compiler/passes/ranges.rs +++ b/dotscope/src/compiler/passes/ranges.rs @@ -164,10 +164,12 @@ impl SsaPass for ValueRangePropagationPass { } else { ConstValue::False }; - block.instructions_mut()[instr_idx].set_op(SsaOp::Const { - dest, - value: const_value, - }); + if let Some(instr) = block.instructions_mut().get_mut(instr_idx) { + instr.set_op(SsaOp::Const { + dest, + value: const_value, + }); + } changes .record(EventKind::ConstantFolded) .at(method_token, instr_idx) @@ -367,10 +369,10 @@ impl RangeAnalysis { // Iteration limit to prevent infinite loops with widening ranges. // In practice, analysis should converge quickly for most methods. // If we hit this limit, we still have valid (possibly imprecise) results. - let mut iterations = 0; + let mut iterations: usize = 0; loop { - iterations += 1; + iterations = iterations.saturating_add(1); if iterations > self.max_iterations { // Hit iteration limit - return with current results. // This can happen with unbounded widening in loops. @@ -516,8 +518,8 @@ impl RangeAnalysis { if let Some(idx) = range.as_constant().and_then(|i| usize::try_from(i).ok()) { // Known switch value - if idx < targets.len() { - self.add_cfg_edge(block_id, targets[idx]); + if let Some(&target) = targets.get(idx) { + self.add_cfg_edge(block_id, target); } else { self.add_cfg_edge(block_id, *default); } @@ -665,7 +667,7 @@ impl RangeAnalysis { // Positive divisor: result in [0, n-1] if dividend is non-negative let l = self.get_range(*left); if l.is_always_non_negative() { - return ValueRange::bounded(0, n - 1); + return ValueRange::bounded(0, n.saturating_sub(1)); } } } diff --git a/dotscope/src/compiler/passes/reassociate.rs b/dotscope/src/compiler/passes/reassociate.rs index e7687285..c565d628 100644 --- a/dotscope/src/compiler/passes/reassociate.rs +++ b/dotscope/src/compiler/passes/reassociate.rs @@ -382,22 +382,24 @@ impl ReassociationPass { } // Update the inner operation to just use base_var and the combined constant - let inner_instr = &mut block.instructions_mut()[candidate.inner_instr]; - inner_instr.set_op(Self::make_op( - candidate.op_kind, - candidate.inner_dest, - candidate.base_var, - candidate.const1_var, - )); + if let Some(inner_instr) = block.instructions_mut().get_mut(candidate.inner_instr) { + inner_instr.set_op(Self::make_op( + candidate.op_kind, + candidate.inner_dest, + candidate.base_var, + candidate.const1_var, + )); + } } // Replace the outer operation with a Copy from the inner result if let Some(block) = ssa.block_mut(candidate.block_idx) { - let outer_instr = &mut block.instructions_mut()[candidate.instr_idx]; - outer_instr.set_op(SsaOp::Copy { - dest: candidate.dest, - src: candidate.inner_dest, - }); + if let Some(outer_instr) = block.instructions_mut().get_mut(candidate.instr_idx) { + outer_instr.set_op(SsaOp::Copy { + dest: candidate.dest, + src: candidate.inner_dest, + }); + } } modified.insert((candidate.inner_block, candidate.inner_instr)); diff --git a/dotscope/src/compiler/passes/strength.rs b/dotscope/src/compiler/passes/strength.rs index f5105649..9bf68145 100644 --- a/dotscope/src/compiler/passes/strength.rs +++ b/dotscope/src/compiler/passes/strength.rs @@ -194,7 +194,7 @@ impl<'a> ReductionChecker<'a> { }; let value = const_value.as_i64()?; let _exponent = is_power_of_two(value)?; - let mask = value - 1; // 2^n - 1 + let mask = value.checked_sub(1)?; // 2^n - 1 let uses = self.index.use_count(divisor_var); if uses != 1 || self.used_constants.contains(divisor_var.index()) { @@ -363,21 +363,26 @@ impl StrengthReductionPass { for candidate in candidates { // First, update the constant definition if let Some(block) = ssa.block_mut(candidate.const_block) { - let const_instr = &mut block.instructions_mut()[candidate.const_instr]; - const_instr.set_op(SsaOp::Const { - dest: candidate.const_var, - value: candidate.new_const_value, - }); + if let Some(const_instr) = block.instructions_mut().get_mut(candidate.const_instr) { + const_instr.set_op(SsaOp::Const { + dest: candidate.const_var, + value: candidate.new_const_value, + }); + } } // Then, update the operation if let Some(block) = ssa.block_mut(candidate.location.block_idx) { - let instr = &mut block.instructions_mut()[candidate.location.instr_idx]; - instr.set_op(candidate.new_op); - changes - .record(EventKind::StrengthReduced) - .at(method_token, candidate.location.instr_idx) - .message(&candidate.description); + if let Some(instr) = block + .instructions_mut() + .get_mut(candidate.location.instr_idx) + { + instr.set_op(candidate.new_op); + changes + .record(EventKind::StrengthReduced) + .at(method_token, candidate.location.instr_idx) + .message(&candidate.description); + } } } } diff --git a/dotscope/src/compiler/scheduler.rs b/dotscope/src/compiler/scheduler.rs index b0f20519..2fd3ef97 100644 --- a/dotscope/src/compiler/scheduler.rs +++ b/dotscope/src/compiler/scheduler.rs @@ -174,7 +174,7 @@ impl PassScheduler { /// /// # Errors /// - /// Returns [`Error::SsaError`] if a cycle is detected in the capability + /// Returns [`crate::Error::SsaError`] if a cycle is detected in the capability /// dependencies, including the names of the passes involved in the cycle. fn compute_layer_assignment(&self) -> Result> { let n = self.passes.len(); @@ -203,7 +203,9 @@ impl PassScheduler { if let Some(provider_indices) = providers.get(&cap) { for &j in provider_indices { if j != i { - deps[i].push(j); + if let Some(slot) = deps.get_mut(i) { + slot.push(j); + } let _ = graph.add_edge(j, i, ()); } } @@ -214,7 +216,10 @@ impl PassScheduler { // Validate the DAG is acyclic via topological sort if graph.topological_sort().is_none() { if let Some(cycle) = graph.find_any_cycle() { - let names: Vec<&str> = cycle.iter().map(|&i| self.passes[i].0.name()).collect(); + let names: Vec<&str> = cycle + .iter() + .filter_map(|&i| self.passes.get(i).map(|p| p.0.name())) + .collect(); return Err(Error::SsaError(format!( "Cycle detected in pass capability dependencies: {}", names.join(" → ") @@ -232,9 +237,17 @@ impl PassScheduler { while changed { changed = false; for i in 0..n { - for &dep in &deps[i] { - if layer[i] <= layer[dep] { - layer[i] = layer[dep] + 1; + let dep_list = match deps.get(i) { + Some(d) => d.clone(), + None => continue, + }; + for dep in dep_list { + let layer_i = layer.get(i).copied().unwrap_or(0); + let layer_dep = layer.get(dep).copied().unwrap_or(0); + if layer_i <= layer_dep { + if let Some(slot) = layer.get_mut(i) { + *slot = layer_dep.saturating_add(1); + } changed = true; } } @@ -247,14 +260,15 @@ impl PassScheduler { debug!( "Capability scheduling: {} passes across {} layers", n, - max_layer + 1 + max_layer.saturating_add(1) ); for (i, (pass, fallback)) in self.passes.iter().enumerate() { - if layer[i] != *fallback { + let layer_i = layer.get(i).copied().unwrap_or(*fallback); + if layer_i != *fallback { debug!( " pass '{}': layer {} (moved from fallback {})", pass.name(), - layer[i], + layer_i, fallback ); } @@ -500,7 +514,10 @@ impl PassScheduler { iteration_modified: Option<&DashSet>, ) -> Result { for &idx in indices { - all_passes[idx].0.initialize(ctx)?; + let pass_entry = all_passes.get_mut(idx).ok_or_else(|| { + Error::SsaError(format!("scheduler: pass index {idx} out of bounds")) + })?; + pass_entry.0.initialize(ctx)?; } let dirty_set = state.map(|s| &s.method_dirty); @@ -509,14 +526,20 @@ impl PassScheduler { let any_changed = AtomicBool::new(false); for &idx in indices { - let pass = &all_passes[idx].0; + let pass_entry = all_passes.get(idx).ok_or_else(|| { + Error::SsaError(format!("scheduler: pass index {idx} out of bounds")) + })?; + let pass = &pass_entry.0; if pass.is_global() && pass.run_global(ctx, assembly)? { any_changed.store(true, Ordering::Relaxed); } } for &idx in indices { - let pass = &all_passes[idx].0; + let pass_entry = all_passes.get(idx).ok_or_else(|| { + Error::SsaError(format!("scheduler: pass index {idx} out of bounds")) + })?; + let pass = &pass_entry.0; if pass.is_global() { continue; } @@ -536,7 +559,10 @@ impl PassScheduler { } for &idx in indices { - all_passes[idx].0.finalize(ctx)?; + let pass_entry = all_passes.get_mut(idx).ok_or_else(|| { + Error::SsaError(format!("scheduler: pass index {idx} out of bounds")) + })?; + pass_entry.0.finalize(ctx)?; } Ok(any_changed.load(Ordering::Relaxed)) @@ -712,21 +738,27 @@ impl PassScheduler { let layer_assignment = self.compute_layer_assignment()?; // Group pass indices by layer, then discard empty layers - let num_layers = layer_assignment.iter().copied().max().map_or(0, |m| m + 1); + let num_layers = layer_assignment + .iter() + .copied() + .max() + .map_or(0, |m| m.saturating_add(1)); let mut layer_indices: Vec> = vec![vec![]; num_layers]; for (i, &layer) in layer_assignment.iter().enumerate() { - layer_indices[layer].push(i); + if let Some(slot) = layer_indices.get_mut(layer) { + slot.push(i); + } } layer_indices.retain(|layer| !layer.is_empty()); - let mut stable_count = 0; - let mut iterations = 0; + let mut stable_count: usize = 0; + let mut iterations: usize = 0; let max_phase = self.max_phase_iterations; let max_iterations = self.max_iterations; let stable_iterations = self.stable_iterations; for iteration in 0..max_iterations { - iterations = iteration + 1; + iterations = iteration.saturating_add(1); debug!("Pipeline iteration {}/{}", iterations, max_iterations); // Track which methods are modified in this iteration so we can @@ -790,7 +822,7 @@ impl PassScheduler { if iteration_changed { stable_count = 0; } else { - stable_count += 1; + stable_count = stable_count.saturating_add(1); if stable_count >= stable_iterations { debug!("Pipeline stable after {} iterations", iterations); break; diff --git a/dotscope/src/compiler/summary.rs b/dotscope/src/compiler/summary.rs index fff94847..14e960d3 100644 --- a/dotscope/src/compiler/summary.rs +++ b/dotscope/src/compiler/summary.rs @@ -230,42 +230,42 @@ impl MethodSummary { // Return type is string if self.returns_string() { - score += 30; + score = score.saturating_add(30); } // Has int or byte[] parameter if self.has_integer_parameter() { - score += 15; + score = score.saturating_add(15); } if self.has_byte_array_parameter() { - score += 15; + score = score.saturating_add(15); } // Contains XOR operations if self.has_xor_operations { - score += 15; + score = score.saturating_add(15); } // Contains array access if self.has_array_access { - score += 10; + score = score.saturating_add(10); } // Contains encoding calls if self.has_encoding_calls { - score += 20; + score = score.saturating_add(20); } // Called with many distinct constant values if self.distinct_arg_values >= 5 { - score += 15; + score = score.saturating_add(15); } else if self.distinct_arg_values >= 2 { - score += 5; + score = score.saturating_add(5); } // Small method (decryptors tend to be compact) if self.instruction_count > 0 && self.instruction_count <= 100 { - score += 5; + score = score.saturating_add(5); } score.min(100) diff --git a/dotscope/src/deobfuscation/cleanup.rs b/dotscope/src/deobfuscation/cleanup.rs index f4eab5ce..72add578 100644 --- a/dotscope/src/deobfuscation/cleanup.rs +++ b/dotscope/src/deobfuscation/cleanup.rs @@ -89,7 +89,9 @@ pub(crate) fn build_cleanup_request( if !detections.is_detected(tech.id()) { continue; } - let detection = detections.get(tech.id()).unwrap(); + let Some(detection) = detections.get(tech.id()) else { + continue; + }; if let Some(tech_cleanup) = tech.cleanup(detection) { request.merge(&tech_cleanup); } @@ -363,7 +365,7 @@ fn repair_duplicate_assembly_rows(cil_assembly: &mut CilAssembly, ctx: &Analysis } // Remove duplicate rows (keep RID 1, remove RID 2..N) - let duplicates = row_count - 1; + let duplicates = row_count.saturating_sub(1); for rid in (2..=row_count).rev() { if let Err(e) = cil_assembly.table_row_remove(TableId::Assembly, rid) { log::warn!("Failed to remove duplicate Assembly row {rid}: {e}"); @@ -392,7 +394,7 @@ fn repair_duplicate_module_rows(cil_assembly: &mut CilAssembly, ctx: &AnalysisCo return; } - let duplicates = row_count - 1; + let duplicates = row_count.saturating_sub(1); for rid in (2..=row_count).rev() { if let Err(e) = cil_assembly.table_row_remove(TableId::Module, rid) { log::warn!("Failed to remove duplicate Module row {rid}: {e}"); @@ -588,17 +590,18 @@ fn repair_duplicate_typedef_rows( } else { // Check if the duplicate has methods or fields let method_start = row.method_list; + let next_rid = row.rid.saturating_add(1); let method_end = typedef_rows .iter() - .find(|t| t.rid == row.rid + 1) + .find(|t| t.rid == next_rid) .map(|t| t.method_list) - .unwrap_or(method_count + 1); + .unwrap_or_else(|| method_count.saturating_add(1)); let field_start = row.field_list; let field_end = typedef_rows .iter() - .find(|t| t.rid == row.rid + 1) + .find(|t| t.rid == next_rid) .map(|t| t.field_list) - .unwrap_or(field_count + 1); + .unwrap_or_else(|| field_count.saturating_add(1)); let has_methods = method_end > method_start; let has_fields = field_end > field_start; @@ -641,7 +644,7 @@ fn repair_duplicate_typedef_rows( } } - let total = removed_count + cleanup_count; + let total = removed_count.saturating_add(cleanup_count); log::info!( "Repaired TypeDef table: {removed_count} removed, {cleanup_count} scheduled for cleanup ({total} total duplicates, ECMA-335 §22.37)" ); @@ -672,16 +675,17 @@ fn repair_global_field_visibility(cil_assembly: &mut CilAssembly, ctx: &Analysis // Find type (always RID 1) and its field range let typedef_rows: Vec = typedefs.into_iter().collect(); - if typedef_rows.is_empty() { + let Some(module_type) = typedef_rows.first() else { return; - } - let module_type = &typedef_rows[0]; + }; let field_start = module_type.field_list; - let field_end = if typedef_rows.len() > 1 { - typedef_rows[1].field_list + let field_end = if let Some(next) = typedef_rows.get(1) { + next.field_list } else { // If there's only one type, the field end is the total field count + 1 - cil_assembly.original_table_row_count(TableId::Field) + 1 + cil_assembly + .original_table_row_count(TableId::Field) + .saturating_add(1) }; if field_start >= field_end { @@ -693,14 +697,13 @@ fn repair_global_field_visibility(cil_assembly: &mut CilAssembly, ctx: &Analysis return; }; let fields: Vec = fields_table.into_iter().collect(); - let mut repaired = 0; + let mut repaired: usize = 0; for rid in field_start..field_end { - let idx = (rid - 1) as usize; - if idx >= fields.len() { + let idx = rid.saturating_sub(1) as usize; + let Some(field) = fields.get(idx) else { break; - } - let field = &fields[idx]; + }; let access = field.flags & 0x0007; // FieldAccessMask per ECMA-335 §II.23.1.5 // Valid access for global fields: CompilerControlled(0), Private(1), or Public(6) @@ -712,7 +715,7 @@ fn repair_global_field_visibility(cil_assembly: &mut CilAssembly, ctx: &Analysis { log::warn!("Failed to repair field {rid} visibility: {e}"); } else { - repaired += 1; + repaired = repaired.saturating_add(1); } } } @@ -779,8 +782,9 @@ fn sweep_dead_module_methods( let callers = callers_of.get(&method.token); let has_callers = callers.is_some_and(|c| !c.is_empty()); // A method is dead if it has known callers and ALL of them are deleted. - let all_callers_deleted = - has_callers && callers.unwrap().iter().all(|c| deleted_methods.contains(c)); + let all_callers_deleted = callers.is_some_and(|c| { + !c.is_empty() && c.iter().all(|caller| deleted_methods.contains(caller)) + }); // Methods with NO callers in the SSA graph are conservatively kept — // the SSA call graph only covers successfully-converted methods, so diff --git a/dotscope/src/deobfuscation/engine/analysis.rs b/dotscope/src/deobfuscation/engine/analysis.rs index 5c53d2c6..d567d9d5 100644 --- a/dotscope/src/deobfuscation/engine/analysis.rs +++ b/dotscope/src/deobfuscation/engine/analysis.rs @@ -298,8 +298,8 @@ impl DeobfuscationEngine { if is_impure { for &var in &uses { if let Some(param_idx) = ssa.is_parameter_variable(var) { - if param_idx < param_count { - pure_only[param_idx] = false; + if let Some(slot) = pure_only.get_mut(param_idx) { + *slot = false; } } } @@ -309,7 +309,7 @@ impl DeobfuscationEngine { // Finalize pure_usage_only for (i, summary) in summaries.iter_mut().enumerate() { - summary.pure_usage_only = pure_only[i] && summary.is_used; + summary.pure_usage_only = pure_only.get(i).copied().unwrap_or(false) && summary.is_used; } summaries diff --git a/dotscope/src/deobfuscation/engine/codegen.rs b/dotscope/src/deobfuscation/engine/codegen.rs index 6c6bcaac..f3497d91 100644 --- a/dotscope/src/deobfuscation/engine/codegen.rs +++ b/dotscope/src/deobfuscation/engine/codegen.rs @@ -51,7 +51,7 @@ impl DeobfuscationEngine { // Generate code for each processed method let mut codegen = SsaCodeGenerator::new(); - let mut methods_updated = 0; + let mut methods_updated: usize = 0; let mut old_sas_tokens = Vec::new(); for entry in ctx.processed_methods.iter() { @@ -149,7 +149,7 @@ impl DeobfuscationEngine { ctx.events .record(EventKind::CodeRegenerated) .method(method_token); - methods_updated += 1; + methods_updated = methods_updated.saturating_add(1); } // Finalize array types: creates the parent diff --git a/dotscope/src/deobfuscation/engine/detection.rs b/dotscope/src/deobfuscation/engine/detection.rs index 6afc8f75..ae443482 100644 --- a/dotscope/src/deobfuscation/engine/detection.rs +++ b/dotscope/src/deobfuscation/engine/detection.rs @@ -38,7 +38,9 @@ impl DeobfuscationEngine { if ctx.initialized_techniques.contains(tech.id()) { continue; } - let detection = detections.get(tech.id()).unwrap(); + let Some(detection) = detections.get(tech.id()) else { + continue; + }; tech.initialize(ctx, assembly, detection, detections); ctx.initialized_techniques.insert(tech.id()); } @@ -73,7 +75,9 @@ impl DeobfuscationEngine { let Some(phase) = tech.ssa_phase() else { continue; }; - let detection = detections.get(tech.id()).unwrap(); + let Some(detection) = detections.get(tech.id()) else { + continue; + }; // Initialize if not already done if !ctx.initialized_techniques.contains(tech.id()) { diff --git a/dotscope/src/deobfuscation/engine/pipeline.rs b/dotscope/src/deobfuscation/engine/pipeline.rs index b138866c..1d554736 100644 --- a/dotscope/src/deobfuscation/engine/pipeline.rs +++ b/dotscope/src/deobfuscation/engine/pipeline.rs @@ -104,7 +104,7 @@ impl<'a> PipelineRun<'a> { if pipeline_iteration > 0 { info!( "Pipeline iteration {}: ByteTransform requested, re-running byte transforms", - pipeline_iteration + 1 + pipeline_iteration.saturating_add(1) ); assembly = self.run_byte_transforms(assembly)?; self.record_detections(); @@ -174,7 +174,9 @@ impl<'a> PipelineRun<'a> { if self.detections.is_transformed(tech.id()) { continue; } - let detection = self.detections.get(tech.id()).unwrap(); + let Some(detection) = self.detections.get(tech.id()) else { + continue; + }; let tech_start = Instant::now(); let Some(transform_result) = tech.byte_transform(&mut working, detection, &self.detections) @@ -308,8 +310,8 @@ impl<'a> PipelineRun<'a> { } info!( "[technique] SSA re-detected (outer {}, round {}): {}", - outer_iteration + 1, - *detection_round + 1, + outer_iteration.saturating_add(1), + detection_round.saturating_add(1), tech.name() ); self.detections.merge(tech.id(), ssa_det); @@ -320,13 +322,13 @@ impl<'a> PipelineRun<'a> { return Ok(false); } - *detection_round += 1; + *detection_round = detection_round.saturating_add(1); self.record_detections(); let passes_before = scheduler.pass_count(); self.engine .initialize_and_create_passes(ctx, assembly_arc, &self.detections, scheduler); - let passes_added = scheduler.pass_count() - passes_before; + let passes_added = scheduler.pass_count().saturating_sub(passes_before); self.engine.configure_no_inline(ctx); if passes_added > 0 { @@ -479,7 +481,7 @@ impl<'a> PipelineRun<'a> { let round_iterations = scheduler.run_pipeline(ctx, assembly_arc, Some(&ctx.processing_state))?; - self.iterations += round_iterations; + self.iterations = self.iterations.saturating_add(round_iterations); let has_pending_work_items = !ctx.work_queue.is_empty(); let mut detection_added_work = false; @@ -501,7 +503,7 @@ impl<'a> PipelineRun<'a> { if !has_work { debug!( "SSA fixpoint reached after {} iteration(s)", - outer_iteration + 1 + outer_iteration.saturating_add(1) ); break; } @@ -650,8 +652,11 @@ impl<'a> PipelineRun<'a> { } if neutralized { - self.iterations += - scheduler.run_pipeline(ctx, assembly_arc, Some(&ctx.processing_state))?; + self.iterations = self.iterations.saturating_add(scheduler.run_pipeline( + ctx, + assembly_arc, + Some(&ctx.processing_state), + )?); } Ok(()) diff --git a/dotscope/src/deobfuscation/passes/antidebug.rs b/dotscope/src/deobfuscation/passes/antidebug.rs index d468d7de..857dbc55 100644 --- a/dotscope/src/deobfuscation/passes/antidebug.rs +++ b/dotscope/src/deobfuscation/passes/antidebug.rs @@ -262,11 +262,11 @@ impl SsaPass for SentinelTaintRemovalPass { .find(|&&(bi, ii, _)| bi == block_idx && ii == instr_idx) { instr.set_op(SsaOp::Jump { target }); - neutralized += 1; + neutralized = neutralized.saturating_add(1); } } else { instr.set_op(SsaOp::Nop); - neutralized += 1; + neutralized = neutralized.saturating_add(1); } } } @@ -279,7 +279,7 @@ impl SsaPass for SentinelTaintRemovalPass { if let Some(block) = ssa.block_mut(block_idx) { if phi_idx < block.phi_nodes().len() { block.phi_nodes_mut().remove(phi_idx); - neutralized += 1; + neutralized = neutralized.saturating_add(1); } } } diff --git a/dotscope/src/deobfuscation/passes/bitmono/strings.rs b/dotscope/src/deobfuscation/passes/bitmono/strings.rs index 9d1442ae..65850736 100644 --- a/dotscope/src/deobfuscation/passes/bitmono/strings.rs +++ b/dotscope/src/deobfuscation/passes/bitmono/strings.rs @@ -363,9 +363,13 @@ fn find_decryption_sites( let mut all_found = true; for (arg_idx, arg_var) in args.iter().enumerate() { - if let Some(&(blk, idx, token)) = ldsfld_index.get(arg_var) { - ldsfld_locations[arg_idx] = (blk, idx); - field_tokens[arg_idx] = token; + if let (Some(&(blk, idx, token)), Some(loc), Some(tok)) = ( + ldsfld_index.get(arg_var), + ldsfld_locations.get_mut(arg_idx), + field_tokens.get_mut(arg_idx), + ) { + *loc = (blk, idx); + *tok = token; } else { all_found = false; break; diff --git a/dotscope/src/deobfuscation/passes/bitmono/unmanaged.rs b/dotscope/src/deobfuscation/passes/bitmono/unmanaged.rs index e75f3783..1903ccd0 100644 --- a/dotscope/src/deobfuscation/passes/bitmono/unmanaged.rs +++ b/dotscope/src/deobfuscation/passes/bitmono/unmanaged.rs @@ -184,15 +184,14 @@ fn find_unmanaged_string_sites( // traces) should not prevent finding the actual string constructor. // Limit search distance to avoid O(n²) behaviour across many call sites. const MAX_SEARCH_DISTANCE: usize = 20; - let search_end = (i + 1 + MAX_SEARCH_DISTANCE).min(instructions.len()); - for (j, next) in instructions - .iter() - .enumerate() - .skip(i + 1) - .take(search_end - (i + 1)) - { + let start = i.saturating_add(1); + let search_end = start + .saturating_add(MAX_SEARCH_DISTANCE) + .min(instructions.len()); + let take = search_end.saturating_sub(start); + for (j, next) in instructions.iter().enumerate().skip(start).take(take) { if let SsaOp::NewObj { dest, args, .. } = next.op() { - if args.len() == 1 && args[0] == call_dest { + if args.len() == 1 && args.first() == Some(&call_dest) { sites.push(UnmanagedStringSite { call_idx: i, newobj_idx: j, diff --git a/dotscope/src/deobfuscation/passes/decryption.rs b/dotscope/src/deobfuscation/passes/decryption.rs index 1ef9b31a..ae41f6bb 100644 --- a/dotscope/src/deobfuscation/passes/decryption.rs +++ b/dotscope/src/deobfuscation/passes/decryption.rs @@ -454,15 +454,15 @@ impl DecryptionPass { // where the return type is a type parameter. Try to extract the actual value. if let EmValue::ValueType { fields, .. } = em_value { // If the ValueType has a single field that's an ObjectRef, try to get string from it - if fields.len() == 1 { - if let EmValue::ObjectRef(href) = &fields[0] { + if let Some(first_field) = fields.first().filter(|_| fields.len() == 1) { + if let EmValue::ObjectRef(href) = first_field { if let Ok(s) = thread.heap().get_string(*href) { return Some(ConstValue::DecryptedString(s.to_string())); } } // Try primitive conversion on the single field (but not if it's Null) - if !matches!(fields[0], EmValue::Null) { - if let Some(cv) = fields[0].to_const_value() { + if !matches!(first_field, EmValue::Null) { + if let Some(cv) = first_field.to_const_value() { return Some(cv); } } @@ -634,7 +634,14 @@ impl DecryptionPass { } // Simulate the feeding update call itself - let feeding_update = &state_updates[call_site.feeding_update_idx]; + let Some(feeding_update) = state_updates.get(call_site.feeding_update_idx) else { + failures.push(( + call_site.decryptor, + location, + FailureReason::NonConstantArgs, + )); + continue; + }; #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] let flag = resolver @@ -813,7 +820,9 @@ impl DecryptionPass { resolver: &mut ValueResolver<'_>, ) -> bool { for &idx in relevant_updates { - let update = &all_updates[idx]; + let Some(update) = all_updates.get(idx) else { + return false; + }; #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] let flag = resolver @@ -856,7 +865,13 @@ impl DecryptionPass { *changed = true; changes .record(EventKind::InstructionRemoved) - .at(method_token, update.block_idx * 1000 + update.instr_idx) + .at( + method_token, + update + .block_idx + .saturating_mul(1000) + .saturating_add(update.instr_idx), + ) .message(format!( "replaced state machine update call with const (result {:?})", update.dest @@ -885,7 +900,10 @@ impl DecryptionPass { *changed = true; changes .record(EventKind::InstructionRemoved) - .at(method_token, block_idx * 1000 + instr_idx) + .at( + method_token, + block_idx.saturating_mul(1000).saturating_add(instr_idx), + ) .message("removed state machine initialization call"); } } @@ -986,7 +1004,7 @@ impl SsaPass for DecryptionPass { continue; }; - let location = block_idx * 1000 + instr_idx; + let location = block_idx.saturating_mul(1000).saturating_add(instr_idx); if self.decryptors.is_already_decrypted(method_token, location) { continue; @@ -1060,7 +1078,10 @@ impl SsaPass for DecryptionPass { }; changes .record(event_kind) - .at(method_token, block_idx * 1000 + instr_idx) + .at( + method_token, + block_idx.saturating_mul(1000).saturating_add(instr_idx), + ) .message(format!("decrypted: {value}")); ctx.add_known_value(method_token, dest, value.clone()); diff --git a/dotscope/src/deobfuscation/passes/delegates.rs b/dotscope/src/deobfuscation/passes/delegates.rs index 94982651..bfaf36f5 100644 --- a/dotscope/src/deobfuscation/passes/delegates.rs +++ b/dotscope/src/deobfuscation/passes/delegates.rs @@ -430,7 +430,10 @@ impl SsaPass for DelegateProxyResolutionPass { } fn initialize(&mut self, _ctx: &CompilerContext) -> Result<()> { - let remaining = self.affected_methods.len() - self.processed_methods.len(); + let remaining = self + .affected_methods + .len() + .saturating_sub(self.processed_methods.len()); if remaining > 0 { debug!( "Delegate proxy resolution: {} delegate types, {} remaining methods ({} already processed)", @@ -508,7 +511,11 @@ impl SsaPass for DelegateProxyResolutionPass { }; // Build the replacement: drop the last arg (delegate instance) - let new_args: Vec = args[..args.len() - 1].to_vec(); + let drop_to = args.len().saturating_sub(1); + let new_args: Vec = args + .get(..drop_to) + .map(<[SsaVarId]>::to_vec) + .unwrap_or_default(); let target_method = SsaMethodRef::new(target_entry.method_token); let new_op = if target_entry.is_virtual { diff --git a/dotscope/src/deobfuscation/passes/jiejienet/arrays.rs b/dotscope/src/deobfuscation/passes/jiejienet/arrays.rs index 598bccf1..16b69767 100644 --- a/dotscope/src/deobfuscation/passes/jiejienet/arrays.rs +++ b/dotscope/src/deobfuscation/passes/jiejienet/arrays.rs @@ -130,16 +130,15 @@ impl SsaPass for ArrayInitRestorationPass { } // Resolve the index argument to a constant - let Some(ConstValue::I32(index)) = constants.get(&args[0]) else { + let Some(arg0) = args.first() else { continue }; + let Some(ConstValue::I32(index)) = constants.get(arg0) else { continue; }; let index = *index as usize; - if index >= self.field_tokens.len() { + let Some(&field_token) = self.field_tokens.get(index) else { continue; - } - - let field_token = self.field_tokens[index]; + }; // Replace Call with Const(FieldHandle(...)) if let Some(dest) = dest { @@ -162,15 +161,17 @@ impl SsaPass for ArrayInitRestorationPass { if method_token == init_method && args.len() == 3 { // Replace Call(MyInitializeArray, array, handle, xorKey) // with Call(RuntimeHelpers.InitializeArray, array, handle) - replacements.push(( - block_idx, - instr_idx, - SsaOp::Call { - dest: *dest, - method: MethodRef::new(init_target), - args: vec![args[0], args[1]], - }, - )); + if let (Some(&a0), Some(&a1)) = (args.first(), args.get(1)) { + replacements.push(( + block_idx, + instr_idx, + SsaOp::Call { + dest: *dest, + method: MethodRef::new(init_target), + args: vec![a0, a1], + }, + )); + } } } } diff --git a/dotscope/src/deobfuscation/passes/jiejienet/typeofs.rs b/dotscope/src/deobfuscation/passes/jiejienet/typeofs.rs index b0095bd0..593c7faa 100644 --- a/dotscope/src/deobfuscation/passes/jiejienet/typeofs.rs +++ b/dotscope/src/deobfuscation/passes/jiejienet/typeofs.rs @@ -41,7 +41,7 @@ use crate::{ }, compiler::{CompilerContext, EventKind, ModificationScope, SsaPass}, metadata::token::Token, - CilObject, Result, + CilObject, Error, Result, }; /// SSA pass that replaces JIEJIE.NET typeof container calls with @@ -124,22 +124,23 @@ impl SsaPass for TypeOfRestorationPass { } // Resolve the index argument to a constant - let Some(ConstValue::I32(index)) = constants.get(&args[0]) else { + let Some(arg0) = args.first() else { continue }; + let Some(ConstValue::I32(index)) = constants.get(arg0) else { continue; }; let index = *index as usize; - if index >= self.type_tokens.len() { + let Some(&type_token) = self.type_tokens.get(index) else { debug!( "jiejie-typeof: index {} out of bounds (max {}), skipping", index, self.type_tokens.len() ); continue; - } + }; if let Some(dest) = dest { - replacements.push((block_idx, instr_idx, self.type_tokens[index], *dest)); + replacements.push((block_idx, instr_idx, type_token, *dest)); } } } @@ -186,9 +187,15 @@ impl SsaPass for TypeOfRestorationPass { args: vec![handle_var], }); - ssa.blocks_mut()[*block_idx] + let block = ssa.blocks_mut().get_mut(*block_idx).ok_or_else(|| { + Error::Deobfuscation(format!( + "jiejie-typeof: block index {} out of bounds", + *block_idx + )) + })?; + block .instructions_mut() - .insert(instr_idx + 1, call_instr); + .insert(instr_idx.saturating_add(1), call_instr); ctx.events.record(EventKind::ValueResolved); } diff --git a/dotscope/src/deobfuscation/passes/native.rs b/dotscope/src/deobfuscation/passes/native.rs index 73f2d6d7..bd9643e4 100644 --- a/dotscope/src/deobfuscation/passes/native.rs +++ b/dotscope/src/deobfuscation/passes/native.rs @@ -202,10 +202,10 @@ impl NativeMethodConversionPass { for &token in &self.targets { match self.convert_method(assembly, file, token, bitness) { Ok(()) => { - stats.converted += 1; + stats.converted = stats.converted.saturating_add(1); } Err(e) => { - stats.failed += 1; + stats.failed = stats.failed.saturating_add(1); stats.failed_tokens.push(token); stats.errors.push(format!("0x{:08x}: {}", token.value(), e)); } @@ -252,14 +252,26 @@ impl NativeMethodConversionPass { // Step 2: Get x86 bytes from the RVA let offset = file.rva_to_offset(method_row.rva as usize)?; - let x86_bytes = &file.data()[offset..]; + let x86_bytes = file.data().get(offset..).ok_or_else(|| { + Error::X86Error(format!( + "Method 0x{:08x}: file offset {offset} out of bounds", + token.value() + )) + })?; // Step 3: Detect prologue and adjust bytes if needed let (decode_bytes, base_offset) = if self.skip_prologue { let prologue = x86_detect_prologue(x86_bytes, bitness); if prologue.kind == X86PrologueKind::DynCipher { // Skip the prologue - (&x86_bytes[prologue.size..], prologue.size as u64) + let rest = x86_bytes.get(prologue.size..).ok_or_else(|| { + Error::X86Error(format!( + "Method 0x{:08x}: prologue size {} exceeds bytes available", + token.value(), + prologue.size + )) + })?; + (rest, prologue.size as u64) } else { // No recognized prologue, decode from start (x86_bytes, 0u64) diff --git a/dotscope/src/deobfuscation/passes/netreactor/resolver.rs b/dotscope/src/deobfuscation/passes/netreactor/resolver.rs index d527104b..6218e608 100644 --- a/dotscope/src/deobfuscation/passes/netreactor/resolver.rs +++ b/dotscope/src/deobfuscation/passes/netreactor/resolver.rs @@ -85,7 +85,10 @@ impl SsaPass for TokenResolverPass { let Some(dest_var) = dest else { continue; }; - let Some(const_val) = constants.get(&args[0]) else { + let Some(arg0) = args.first() else { + continue; + }; + let Some(const_val) = constants.get(arg0) else { continue; }; let raw_token = match const_val { diff --git a/dotscope/src/deobfuscation/passes/netreactor/rewrite.rs b/dotscope/src/deobfuscation/passes/netreactor/rewrite.rs index 78e0cb41..96608edb 100644 --- a/dotscope/src/deobfuscation/passes/netreactor/rewrite.rs +++ b/dotscope/src/deobfuscation/passes/netreactor/rewrite.rs @@ -133,7 +133,7 @@ impl SsaPass for ResourceShimRewritePass { method: bcl_target, args: args.clone(), }; - shim_rewrites += 1; + shim_rewrites = shim_rewrites.saturating_add(1); Some(new) } SsaOp::Call { dest, method, args } @@ -141,7 +141,7 @@ impl SsaPass for ResourceShimRewritePass { && dest.is_none() && args.is_empty() => { - init_nops += 1; + init_nops = init_nops.saturating_add(1); Some(SsaOp::Nop) } _ => None, diff --git a/dotscope/src/deobfuscation/passes/neutralize.rs b/dotscope/src/deobfuscation/passes/neutralize.rs index bf0a2afb..05ecb887 100644 --- a/dotscope/src/deobfuscation/passes/neutralize.rs +++ b/dotscope/src/deobfuscation/passes/neutralize.rs @@ -259,7 +259,7 @@ impl<'a> NeutralizationPass<'a> { // Find blocks that can reach exit (for fallback target selection) let can_reach_exit = Self::find_blocks_reaching_exit(ssa); - let mut count = 0; + let mut count: usize = 0; // 1. Remove tainted PHI nodes // Collect PHIs to remove (block_idx, phi_idx) sorted in reverse order @@ -271,7 +271,7 @@ impl<'a> NeutralizationPass<'a> { if let Some(block) = ssa.block_mut(block_idx) { if phi_idx < block.phi_nodes().len() { block.phi_nodes_mut().remove(phi_idx); - count += 1; + count = count.saturating_add(1); } } } @@ -350,7 +350,7 @@ impl<'a> NeutralizationPass<'a> { InstrAction::Nop => instr.set_op(SsaOp::Nop), InstrAction::Jump(target) => instr.set_op(SsaOp::Jump { target }), } - count += 1; + count = count.saturating_add(1); } } } diff --git a/dotscope/src/deobfuscation/passes/opaquefields.rs b/dotscope/src/deobfuscation/passes/opaquefields.rs index 0f31c565..09b10e97 100644 --- a/dotscope/src/deobfuscation/passes/opaquefields.rs +++ b/dotscope/src/deobfuscation/passes/opaquefields.rs @@ -311,7 +311,7 @@ impl<'a> FieldResolver<'a> { let mut current_val = static_val; for (i, &field_token) in field_chain.iter().enumerate() { - let is_last = i == field_chain.len() - 1; + let is_last = i == field_chain.len().saturating_sub(1); match ¤t_val { EmValue::ObjectRef(heap_ref) => { @@ -419,7 +419,10 @@ impl SsaPass for OpaqueFieldPredicatePass { } fn initialize(&mut self, _ctx: &CompilerContext) -> Result<()> { - let remaining = self.affected_methods.len() - self.processed_methods.len(); + let remaining = self + .affected_methods + .len() + .saturating_sub(self.processed_methods.len()); if remaining > 0 { debug!( "Opaque field predicate pass: {} unique static fields in {} remaining methods ({} already processed)", diff --git a/dotscope/src/deobfuscation/passes/reflection.rs b/dotscope/src/deobfuscation/passes/reflection.rs index d4960e72..ffaf40aa 100644 --- a/dotscope/src/deobfuscation/passes/reflection.rs +++ b/dotscope/src/deobfuscation/passes/reflection.rs @@ -249,7 +249,7 @@ impl SsaPass for ReflectionDevirtualizationPass { ReflectionSite::FieldAccess { .. } => rewrite_field_access(ssa, site, ctx), }; if success { - count += 1; + count = count.saturating_add(1); } } @@ -281,7 +281,7 @@ impl SsaPass for ReflectionDevirtualizationPass { /// The number of confirmed P1 (ResolveMethod + GetFunctionPointer + calli) sites. pub fn count_resolve_method_calli_sites(ssa: &SsaFunction, assembly: &CilObject) -> usize { let tracer = ChainTracer { ssa, assembly }; - let mut count = 0; + let mut count: usize = 0; for (block_idx, block) in ssa.blocks().iter().enumerate() { for (i, instr) in block.instructions().iter().enumerate() { let SsaOp::CallIndirect { @@ -294,7 +294,7 @@ pub fn count_resolve_method_calli_sites(ssa: &SsaFunction, assembly: &CilObject) .trace_resolve_method_calli(block_idx, i, *dest, *fptr, args) .is_some() { - count += 1; + count = count.saturating_add(1); } } } @@ -421,7 +421,7 @@ impl<'a> ChainTracer<'a> { if args.is_empty() { return None; } - match self.ssa.get_definition(args[0])? { + match self.ssa.get_definition(*args.first()?)? { SsaOp::LoadToken { token, .. } => Some(token.0), _ => None, } @@ -441,7 +441,7 @@ impl<'a> ChainTracer<'a> { return None; }; if rm_args.len() >= 2 && is_method_named(self.assembly, method.token(), "ResolveMethod") { - Some((rm_args[0], rm_args[1])) + Some((*rm_args.first()?, *rm_args.get(1)?)) } else { None } @@ -472,8 +472,9 @@ impl<'a> ChainTracer<'a> { return None; } intermediates.extend(self.def_site(field_info_var)); - let field_token = self.extract_field_token(rf_args[1])?; - intermediates.extend(self.def_site(rf_args[1])); + let token_var = *rf_args.get(1)?; + let field_token = self.extract_field_token(token_var)?; + intermediates.extend(self.def_site(token_var)); Some((field_token, intermediates)) } @@ -486,20 +487,18 @@ impl<'a> ChainTracer<'a> { let Some(SsaOp::CallVirt { args, .. }) = self.ssa.get_definition(module_var) else { return; }; - if args.is_empty() { + let Some(&type_handle_var) = args.first() else { return; - } + }; intermediates.extend(self.def_site(module_var)); - let type_handle_var = args[0]; let Some(SsaOp::Call { args, .. }) = self.ssa.get_definition(type_handle_var) else { return; }; - if args.is_empty() { + let Some(&loadtoken_arg) = args.first() else { return; - } + }; intermediates.extend(self.def_site(type_handle_var)); - let loadtoken_arg = args[0]; if matches!( self.ssa.get_definition(loadtoken_arg), @@ -525,10 +524,10 @@ impl<'a> ChainTracer<'a> { // Fast path: fptr directly defined by Call GetFunctionPointer. if let Some(def) = self.ssa.get_definition(fptr) { if let SsaOp::Call { method, args, .. } = def { - if !args.is_empty() - && is_method_named(self.assembly, method.token(), "GetFunctionPointer") - { - return Some((fptr, method.token(), args[0])); + if let Some(&first) = args.first() { + if is_method_named(self.assembly, method.token(), "GetFunctionPointer") { + return Some((fptr, method.token(), first)); + } } } return None; @@ -539,10 +538,10 @@ impl<'a> ChainTracer<'a> { for operand in phi.operands() { let op_var = operand.value(); if let Some(SsaOp::Call { method, args, .. }) = self.ssa.get_definition(op_var) { - if !args.is_empty() - && is_method_named(self.assembly, method.token(), "GetFunctionPointer") - { - return Some((op_var, method.token(), args[0])); + if let Some(&first) = args.first() { + if is_method_named(self.assembly, method.token(), "GetFunctionPointer") { + return Some((op_var, method.token(), first)); + } } } } @@ -582,12 +581,12 @@ impl<'a> ChainTracer<'a> { continue; } if let Some(SsaOp::CallVirt { method, args, .. }) = self.ssa.get_definition(*src) { - if !args.is_empty() - && is_method_named(self.assembly, method.token(), "get_MethodHandle") - { - intermediates.extend(self.def_site(*src)); - intermediates.push((block_idx, instr_idx)); - return Some(args[0]); + if let Some(&first) = args.first() { + if is_method_named(self.assembly, method.token(), "get_MethodHandle") { + intermediates.extend(self.def_site(*src)); + intermediates.push((block_idx, instr_idx)); + return Some(first); + } } } } @@ -643,7 +642,7 @@ impl<'a> ChainTracer<'a> { && is_method_named(self.assembly, method.token(), "get_MethodHandle") => { intermediates.extend(self.def_site(getfp_arg)); - args[0] + *args.first()? } SsaOp::LoadLocalAddr { local_index, .. } => { let local_idx = *local_index; @@ -697,7 +696,7 @@ impl<'a> ChainTracer<'a> { args: &[SsaVarId], dest: Option, ) -> Option { - let (method_info_var, obj_var, array_var) = (args[0], args[1], args[2]); + let (method_info_var, obj_var, array_var) = (*args.first()?, *args.get(1)?, *args.get(2)?); let mut intermediates: Vec<(usize, usize)> = Vec::new(); let (module_var, token_const_var) = self.trace_resolve_method_call(method_info_var)?; @@ -743,7 +742,7 @@ impl<'a> ChainTracer<'a> { args: &[SsaVarId], dest: Option, ) -> Option { - let (method_info_var, obj_var, array_var) = (args[0], args[1], args[2]); + let (method_info_var, obj_var, array_var) = (*args.first()?, *args.get(1)?, *args.get(2)?); let mut intermediates: Vec<(usize, usize)> = Vec::new(); // method_info_var ← Call/CallVirt GetMethod(type, Const("name")) @@ -756,7 +755,7 @@ impl<'a> ChainTracer<'a> { method, args: a, .. } if a.len() >= 2 && is_method_named(self.assembly, method.token(), "GetMethod") => { intermediates.extend(self.def_site(method_info_var)); - (a[0], a[1]) + (*a.first()?, *a.get(1)?) } _ => return None, }; @@ -825,7 +824,7 @@ impl<'a> ChainTracer<'a> { args: &[SsaVarId], dest: Option, ) -> Option { - let type_var = args[0]; + let type_var = *args.first()?; let mut intermediates = Vec::new(); let type_token = self.trace_type_from_handle(type_var)?; @@ -860,7 +859,7 @@ impl<'a> ChainTracer<'a> { args: &[SsaVarId], dest: Option, ) -> Option { - let (field_info_var, obj_var) = (args[0], args[1]); + let (field_info_var, obj_var) = (*args.first()?, *args.get(1)?); let (field_token, intermediates) = self.trace_resolve_field(field_info_var)?; Some(ReflectionSite::FieldAccess { @@ -896,7 +895,7 @@ impl<'a> ChainTracer<'a> { idx: usize, args: &[SsaVarId], ) -> Option { - let (field_info_var, obj_var, value_var) = (args[0], args[1], args[2]); + let (field_info_var, obj_var, value_var) = (*args.first()?, *args.get(1)?, *args.get(2)?); let (field_token, intermediates) = self.trace_resolve_field(field_info_var)?; // Unwrap Box on the value if present @@ -1088,9 +1087,9 @@ fn unpack_object_array(ssa: &SsaFunction, array_var: SsaVarId) -> Option CffDetector<'a> { // Pre-compute the dominator tree so candidate analysis can be parallel let _ = self.get_dom_tree(); - let dom_tree = self.dom_tree.as_ref().unwrap(); + let Some(dom_tree) = self.dom_tree.as_ref() else { + return Vec::new(); + }; let mut patterns: Vec = candidates .into_par_iter() @@ -450,7 +452,7 @@ impl<'a> CffDetector<'a> { // Pre-compute dominator tree once let _ = self.get_dom_tree(); - let dom_tree = self.dom_tree.as_ref().unwrap(); + let dom_tree = self.dom_tree.as_ref()?; // Score each candidate and pick the best let mut best_pattern: Option = None; @@ -558,7 +560,7 @@ impl<'a> CffDetector<'a> { .instructions() .iter() .any(|i| i.op().successors().contains(&block_idx)); - let effective_preds = pred_count + usize::from(has_self_loop); + let effective_preds = pred_count.saturating_add(usize::from(has_self_loop)); return effective_preds >= 2; } @@ -823,7 +825,7 @@ impl<'a> CffDetector<'a> { } } SsaOp::Copy { src, .. } => { - stack.push((*src, depth - 1)); + stack.push((*src, depth.saturating_sub(1))); } _ => {} } @@ -846,7 +848,7 @@ impl<'a> CffDetector<'a> { // Push operands in reverse so pops happen in original order // — matches the recursive "return first constant" semantics. for op in phi.operands().iter().rev() { - stack.push((op.value(), depth - 1)); + stack.push((op.value(), depth.saturating_sub(1))); } break; } @@ -1140,7 +1142,7 @@ fn can_reach_dispatcher( } for succ in ssa.block_successors(block) { if !visited.contains(succ) { - queue.push_back((succ, depth + 1)); + queue.push_back((succ, depth.saturating_add(1))); } } } diff --git a/dotscope/src/deobfuscation/passes/unflattening/dispatcher.rs b/dotscope/src/deobfuscation/passes/unflattening/dispatcher.rs index a5f779a7..27a877dd 100644 --- a/dotscope/src/deobfuscation/passes/unflattening/dispatcher.rs +++ b/dotscope/src/deobfuscation/passes/unflattening/dispatcher.rs @@ -124,11 +124,7 @@ impl Dispatcher { // Cast to usize for indexing - transform result is always non-negative after modulo/and operations #[allow(clippy::cast_sign_loss)] let index = self.transform.apply(state) as usize; - if index < self.cases.len() { - self.cases[index] - } else { - self.default - } + self.cases.get(index).copied().unwrap_or(self.default) } /// Refreshes variable IDs against the current SSA. @@ -272,13 +268,13 @@ impl StateTransform { Self::Modulo(n) => { // Use unsigned modulo for consistency with CIL rem.un let u_state = state.cast_unsigned(); - (u_state % n).cast_signed() + u_state.checked_rem(*n).unwrap_or(0).cast_signed() } Self::XorModulo { xor_key, divisor } => { // ConfuserEx pattern: (state ^ key) % N let xored = state ^ xor_key; let u_xored = xored.cast_unsigned(); - (u_xored % divisor).cast_signed() + u_xored.checked_rem(*divisor).unwrap_or(0).cast_signed() } Self::And(mask) => state & (*mask).cast_signed(), Self::Shr(amount) => { @@ -432,11 +428,7 @@ impl DispatcherInfo { // Cast to usize for indexing - transform result is always non-negative after modulo/and operations #[allow(clippy::cast_sign_loss)] let index = transform.apply(case_value) as usize; - if index < cases.len() { - Some(cases[index]) - } else { - Some(*default) - } + Some(cases.get(index).copied().unwrap_or(*default)) } Self::IfElseChain { comparisons, diff --git a/dotscope/src/deobfuscation/passes/unflattening/reconstruction.rs b/dotscope/src/deobfuscation/passes/unflattening/reconstruction.rs index 52f3ef5b..2a70a47e 100644 --- a/dotscope/src/deobfuscation/passes/unflattening/reconstruction.rs +++ b/dotscope/src/deobfuscation/passes/unflattening/reconstruction.rs @@ -307,7 +307,21 @@ pub fn merge_patch_plans(plans: Vec) -> PatchPlan { } if plans.len() == 1 { - return plans.into_iter().next().unwrap(); + if let Some(only) = plans.into_iter().next() { + return only; + } + // Unreachable: we just checked len() == 1, but handle defensively. + return PatchPlan { + dispatcher_blocks: Vec::new(), + state_tainted: BitSet::new(0), + redirects: Vec::new(), + state_transition_sources: BTreeSet::new(), + clone_requests: BTreeMap::new(), + execution_order: Vec::new(), + branch_collapses: BTreeMap::new(), + state_transitions_removed: 0, + user_branches_preserved: 0, + }; } // Determine the tainted BitSet size (all plans share the same SSA, so same size) @@ -376,8 +390,12 @@ pub fn merge_patch_plans(plans: Vec) -> PatchPlan { merged.branch_collapses.entry(source).or_insert(target); } - merged.state_transitions_removed += plan.state_transitions_removed; - merged.user_branches_preserved += plan.user_branches_preserved; + merged.state_transitions_removed = merged + .state_transitions_removed + .saturating_add(plan.state_transitions_removed); + merged.user_branches_preserved = merged + .user_branches_preserved + .saturating_add(plan.user_branches_preserved); } merged @@ -533,12 +551,13 @@ fn extract_redirects_from_node( .iter() .position(|&b| b == last) .unwrap_or(node.blocks_visited.len()); - if end_idx > start_idx + 1 { + let interior_start = start_idx.saturating_add(1); + if end_idx > interior_start { let intermediate_blocks: std::collections::BTreeSet = node - .blocks_visited[start_idx + 1..end_idx] - .iter() - .copied() - .collect(); + .blocks_visited + .get(interior_start..end_idx) + .map(|s| s.iter().copied().collect()) + .unwrap_or_default(); for iwv in &node.instructions { if !intermediate_blocks.contains(&iwv.block_idx) { continue; @@ -563,7 +582,10 @@ fn extract_redirects_from_node( // first has a Branch/BranchCmp overflow check: collapse // it to Jump at the next visited block, then use `last` // for the dispatcher bypass redirect. - let next_in_path = node.blocks_visited.get(start_idx + 1).copied(); + let next_in_path = node + .blocks_visited + .get(start_idx.saturating_add(1)) + .copied(); if let Some(next) = next_in_path { collapse_first_branch = Some((first, next)); } @@ -582,7 +604,7 @@ fn extract_redirects_from_node( .position(|&b| b == pred) .and_then(|pos| { if pos > 0 { - Some(node.blocks_visited[pos - 1]) + node.blocks_visited.get(pos.saturating_sub(1)).copied() } else { external_predecessor } @@ -595,7 +617,8 @@ fn extract_redirects_from_node( // This block should redirect to target_block instead of dispatcher plan.add_redirect(pred, *target_block, predecessor_of_pred); plan.state_transition_sources.insert(pred); - plan.state_transitions_removed += 1; + plan.state_transitions_removed = + plan.state_transitions_removed.saturating_add(1); } else if let Some(ext_pred) = external_predecessor { // The sub-trace starts directly at the dispatcher (no preceding user blocks). // This happens when a user branch at method entry sends one path directly @@ -603,7 +626,8 @@ fn extract_redirects_from_node( // dispatcher-targeting edge to the actual target. plan.add_redirect(ext_pred, *target_block, None); plan.state_transition_sources.insert(ext_pred); - plan.state_transitions_removed += 1; + plan.state_transitions_removed = + plan.state_transitions_removed.saturating_add(1); } // If the first visited block is a BranchCmp overflow check, add @@ -630,7 +654,7 @@ fn extract_redirects_from_node( false_branch, .. } => { - plan.user_branches_preserved += 1; + plan.user_branches_preserved = plan.user_branches_preserved.saturating_add(1); // For user branches, the branch block is the predecessor of both sub-traces. // This is crucial for proper merge point detection when a block is both // an entry path target (from the branch) and a CFF case target. @@ -645,7 +669,7 @@ fn extract_redirects_from_node( default, .. } => { - plan.user_branches_preserved += 1; + plan.user_branches_preserved = plan.user_branches_preserved.saturating_add(1); // For user switches, the switch block is the predecessor of all // case sub-traces. The recursive version processed all cases in // order, then the default — to reproduce that order with a LIFO @@ -678,7 +702,8 @@ fn extract_redirects_from_node( if let Some(pred) = external_predecessor { plan.add_redirect(pred, *target_block, external_predecessor); plan.state_transition_sources.insert(pred); - plan.state_transitions_removed += 1; + plan.state_transitions_removed = + plan.state_transitions_removed.saturating_add(1); } plan.add_to_execution_order(*target_block); } @@ -780,7 +805,9 @@ pub fn apply_patch_plan(ssa: &mut SsaFunction, plan: &PatchPlan) -> Reconstructi // Only redirect Jump terminators. Branch/BranchCmp blocks are user // branches that must preserve both targets — cloning them with // set_target would collapse both branch arms to the same target. - let (_, first_target) = paths[0]; + let Some(&(_, first_target)) = paths.first() else { + continue; + }; let is_user_branch = ssa .block(*merge_block) .and_then(|b| b.terminator_op()) @@ -802,7 +829,7 @@ pub fn apply_patch_plan(ssa: &mut SsaFunction, plan: &PatchPlan) -> Reconstructi // For remaining paths, create clones for &(pred, target) in paths.iter().skip(1) { - let new_block_idx = ssa.block_count() + cloned_blocks.len(); + let new_block_idx = ssa.block_count().saturating_add(cloned_blocks.len()); // Track the clone mapping clone_map.insert(new_block_idx, *merge_block); diff --git a/dotscope/src/deobfuscation/passes/unflattening/tracer/context.rs b/dotscope/src/deobfuscation/passes/unflattening/tracer/context.rs index 82ed95b7..089f72c2 100644 --- a/dotscope/src/deobfuscation/passes/unflattening/tracer/context.rs +++ b/dotscope/src/deobfuscation/passes/unflattening/tracer/context.rs @@ -172,8 +172,8 @@ impl<'a> TreeTraceContext<'a> { // Size the case visit counter to fit the dispatcher's switch targets // (+1 for default, which uses targets.len() as its index). - ctx.visited_case_counts = vec![0u8; dispatcher.targets.len() + 1]; - ctx.last_case_state = vec![None; dispatcher.targets.len() + 1]; + ctx.visited_case_counts = vec![0u8; dispatcher.targets.len().saturating_add(1)]; + ctx.last_case_state = vec![None; dispatcher.targets.len().saturating_add(1)]; ctx.dispatcher = Some(dispatcher); ctx } @@ -207,7 +207,7 @@ impl<'a> TreeTraceContext<'a> { /// Allocates and returns the next unique node ID. pub fn next_id(&mut self) -> usize { let id = self.next_node_id; - self.next_node_id += 1; + self.next_node_id = self.next_node_id.saturating_add(1); id } @@ -295,11 +295,11 @@ impl<'a> TreeTraceContext<'a> { /// allowing re-entry from different CFF case paths. fn visit_state(&self) -> i64 { self.current_state().unwrap_or_else(|| { - let count = if self.last_case_index < self.visited_case_counts.len() { - self.visited_case_counts[self.last_case_index] as i64 - } else { - 0 - }; + let count = self + .visited_case_counts + .get(self.last_case_index) + .copied() + .map_or(0, i64::from); (self.last_case_index as i64) .wrapping_mul(256) .wrapping_add(count) @@ -318,7 +318,7 @@ impl<'a> TreeTraceContext<'a> { /// Increments the visit counter and returns true if the budget is exceeded. pub fn check_visit_budget(&mut self) -> bool { - self.total_visits += 1; + self.total_visits = self.total_visits.saturating_add(1); self.total_visits > self.max_block_visits } @@ -330,9 +330,8 @@ impl<'a> TreeTraceContext<'a> { /// Records that the dispatcher dispatched to the given case index. /// Increments the visit count for the case and updates the last case index. pub fn record_case_dispatch(&mut self, case_idx: usize) { - if case_idx < self.visited_case_counts.len() { - self.visited_case_counts[case_idx] = - self.visited_case_counts[case_idx].saturating_add(1); + if let Some(slot) = self.visited_case_counts.get_mut(case_idx) { + *slot = slot.saturating_add(1); } self.last_case_index = case_idx; } @@ -355,8 +354,8 @@ impl<'a> TreeTraceContext<'a> { /// Records the state value used for this dispatch of `case_idx`. pub fn record_case_state(&mut self, case_idx: usize, state: i64) { - if case_idx < self.last_case_state.len() { - self.last_case_state[case_idx] = Some(state); + if let Some(slot) = self.last_case_state.get_mut(case_idx) { + *slot = Some(state); } } @@ -365,8 +364,9 @@ impl<'a> TreeTraceContext<'a> { /// to avoid false positives on small dispatchers. pub fn is_case_loop(&self, case_idx: usize, targets_len: usize) -> bool { let loop_threshold = (targets_len / 2).max(2) as u8; - case_idx < self.visited_case_counts.len() - && self.visited_case_counts[case_idx] >= loop_threshold + self.visited_case_counts + .get(case_idx) + .is_some_and(|count| *count >= loop_threshold) } /// Returns true when the tracer should follow one path instead of forking diff --git a/dotscope/src/deobfuscation/passes/unflattening/tracer/engine.rs b/dotscope/src/deobfuscation/passes/unflattening/tracer/engine.rs index e286a658..6cb47388 100644 --- a/dotscope/src/deobfuscation/passes/unflattening/tracer/engine.rs +++ b/dotscope/src/deobfuscation/passes/unflattening/tracer/engine.rs @@ -121,7 +121,13 @@ pub fn trace_from_block( loop { let Some(item) = work_stack.pop() else { - return current_result.expect("trace work stack empty but no result"); + return current_result.unwrap_or_else(|| { + let mut node = TraceNode::new(0, block_idx); + node.set_terminator(TraceTerminator::Stopped { + reason: StopReason::UnknownControlFlow { block: block_idx }, + }); + node + }); }; match item { @@ -136,9 +142,15 @@ pub fn trace_from_block( to_state, target_block, } => { - let child = current_result - .take() - .expect("StateTransitionLink: missing child"); + let Some(child) = current_result.take() else { + parent_node.set_terminator(TraceTerminator::Stopped { + reason: StopReason::UnknownControlFlow { + block: target_block, + }, + }); + current_result = Some(parent_node); + continue; + }; parent_node.set_terminator(TraceTerminator::StateTransition { from_state, to_state, @@ -158,9 +170,14 @@ pub fn trace_from_block( case_counts_snapshot, is_expr_switch, } => { - let true_node = current_result - .take() - .expect("BranchFalseArm: missing true_node"); + let Some(true_node) = current_result.take() else { + let mut node = parent_node; + node.set_terminator(TraceTerminator::Stopped { + reason: StopReason::UnknownControlFlow { block: block_idx }, + }); + current_result = Some(node); + continue; + }; // Restore context to the state before the true arm. // For expression switches, preserve the visited_case_counts @@ -197,7 +214,7 @@ pub fn trace_from_block( }); work_stack.push(WorkItem::TraceBlock { block: false_target, - depth: depth + 1, + depth: depth.saturating_add(1), }); } @@ -208,9 +225,16 @@ pub fn trace_from_block( true_node, expr_switch_restore, } => { - let false_node = current_result - .take() - .expect("BranchCombine: missing false_node"); + let Some(false_node) = current_result.take() else { + if let Some(saved) = expr_switch_restore { + ctx.exit_expr_switch_false_arm(saved); + } + parent_node.set_terminator(TraceTerminator::Stopped { + reason: StopReason::UnknownControlFlow { block: block_idx }, + }); + current_result = Some(parent_node); + continue; + }; if let Some(saved) = expr_switch_restore { ctx.exit_expr_switch_false_arm(saved); @@ -240,18 +264,16 @@ pub fn trace_from_block( if let Some(prev_result) = current_result.take() { if next_case_index > 0 { #[allow(clippy::cast_possible_wrap)] - let case_value = (next_case_index - 1) as i64; + let case_value = next_case_index.saturating_sub(1) as i64; completed_cases.push((case_value, Box::new(prev_result))); } } - if next_case_index < targets.len() { + if let Some(&target) = targets.get(next_case_index) { // More cases to trace — restore and trace the next one ctx.restore(snapshot.clone_snapshot()); ctx.evaluator_mut().set_predecessor(Some(block_idx)); - let target = targets[next_case_index]; - work_stack.push(WorkItem::SwitchNextCase { parent_node, block_idx, @@ -261,7 +283,7 @@ pub fn trace_from_block( depth, snapshot, completed_cases, - next_case_index: next_case_index + 1, + next_case_index: next_case_index.saturating_add(1), }); work_stack.push(WorkItem::TraceBlock { block: target, @@ -291,9 +313,13 @@ pub fn trace_from_block( value, cases, } => { - let default_node = current_result - .take() - .expect("SwitchCombine: missing default_node"); + let Some(default_node) = current_result.take() else { + parent_node.set_terminator(TraceTerminator::Stopped { + reason: StopReason::UnknownControlFlow { block: block_idx }, + }); + current_result = Some(parent_node); + continue; + }; parent_node.set_terminator(TraceTerminator::UserSwitch { block: block_idx, value, @@ -369,7 +395,7 @@ fn trace_from_block_linear<'a>( }); work_stack.push(WorkItem::TraceBlock { block: true_target, - depth: depth + 1, + depth: depth.saturating_add(1), }); } ForkRequest::Switch { @@ -380,13 +406,16 @@ fn trace_from_block_linear<'a>( snapshot, is_foreign, } => { - let fork_depth = if is_foreign { depth } else { depth + 1 }; - if targets.is_empty() { + let fork_depth = if is_foreign { + depth + } else { + depth.saturating_add(1) + }; + let Some(&first_target) = targets.first() else { ctx.evaluator_mut().set_predecessor(Some(block_idx)); return leaf; - } + }; ctx.evaluator_mut().set_predecessor(Some(block_idx)); - let first_target = targets[0]; work_stack.push(WorkItem::SwitchNextCase { parent_node: leaf, block_idx, @@ -525,10 +554,15 @@ fn trace_from_block_inner<'a>( // Exempt the dispatcher block — it's intentionally revisited as it dispatches // to different case blocks based on the state variable. let is_dispatcher = ctx.is_dispatcher_block(current_block); + let visited_without_last = node + .blocks_visited + .split_last() + .map(|(_, rest)| rest) + .unwrap_or(&[]); if !is_dispatcher && current_block != block_idx && node.blocks_visited.len() > 1 - && node.blocks_visited[..node.blocks_visited.len() - 1].contains(¤t_block) + && visited_without_last.contains(¤t_block) { let state = ctx.current_state().unwrap_or(0); node.set_terminator(TraceTerminator::LoopBack { @@ -539,9 +573,14 @@ fn trace_from_block_inner<'a>( } // Handle dispatcher re-entry: clear stale values and fix predecessor + let visited_without_last_for_reentry = node + .blocks_visited + .split_last() + .map(|(_, rest)| rest) + .unwrap_or(&[]); let is_dispatcher_reentry = is_dispatcher && node.blocks_visited.len() > 1 - && node.blocks_visited[..node.blocks_visited.len() - 1].contains(¤t_block); + && visited_without_last_for_reentry.contains(¤t_block); if is_dispatcher_reentry { if let Some(block) = ssa.block(current_block) { for instr in block.instructions() { @@ -579,8 +618,12 @@ fn trace_from_block_inner<'a>( // Set predecessor for phi evaluation if node.blocks_visited.len() > 1 { - let prev = node.blocks_visited[node.blocks_visited.len() - 2]; - ctx.evaluator_mut().set_predecessor(Some(prev)); + if let Some(&prev) = node + .blocks_visited + .get(node.blocks_visited.len().saturating_sub(2)) + { + ctx.evaluator_mut().set_predecessor(Some(prev)); + } } // Bridge loop-carried phi operand values for dispatcher blocks @@ -1211,11 +1254,7 @@ fn handle_switch<'a>( #[allow(clippy::cast_possible_truncation)] let idx_usize = idx as usize; - let target = if idx_usize < targets.len() { - targets[idx_usize] - } else { - *default - }; + let target = targets.get(idx_usize).copied().unwrap_or(*default); let from_state = ctx.current_state().unwrap_or(0); diff --git a/dotscope/src/deobfuscation/passes/unflattening/tracer/helpers.rs b/dotscope/src/deobfuscation/passes/unflattening/tracer/helpers.rs index 5af83759..1ccb4040 100644 --- a/dotscope/src/deobfuscation/passes/unflattening/tracer/helpers.rs +++ b/dotscope/src/deobfuscation/passes/unflattening/tracer/helpers.rs @@ -59,7 +59,8 @@ pub fn trace_exception_handlers(ctx: &mut TreeTraceContext<'_>) -> Vec) -> Vec Option { /// Computes statistics for a trace tree. pub fn compute_tree_stats(node: &TraceNode, stats: &mut TraceStats, depth: usize) { - stats.node_count += 1; + stats.node_count = stats.node_count.saturating_add(1); stats.max_depth = stats.max_depth.max(depth); + let next_depth = depth.saturating_add(1); match &node.terminator { TraceTerminator::Exit { .. } => { - stats.exit_count += 1; + stats.exit_count = stats.exit_count.saturating_add(1); } TraceTerminator::StateTransition { continues, .. } => { - stats.state_transition_count += 1; - compute_tree_stats(continues, stats, depth + 1); + stats.state_transition_count = stats.state_transition_count.saturating_add(1); + compute_tree_stats(continues, stats, next_depth); } TraceTerminator::UserBranch { true_branch, false_branch, .. } => { - stats.user_branch_count += 1; - compute_tree_stats(true_branch, stats, depth + 1); - compute_tree_stats(false_branch, stats, depth + 1); + stats.user_branch_count = stats.user_branch_count.saturating_add(1); + compute_tree_stats(true_branch, stats, next_depth); + compute_tree_stats(false_branch, stats, next_depth); } TraceTerminator::UserSwitch { cases, default, .. } => { - stats.user_branch_count += 1; + stats.user_branch_count = stats.user_branch_count.saturating_add(1); for (_, case_node) in cases { - compute_tree_stats(case_node, stats, depth + 1); + compute_tree_stats(case_node, stats, next_depth); } - compute_tree_stats(default, stats, depth + 1); + compute_tree_stats(default, stats, next_depth); } TraceTerminator::Stopped { .. } | TraceTerminator::LoopBack { .. } => {} TraceTerminator::PendingStateTransition { .. } => { diff --git a/dotscope/src/deobfuscation/processcell.rs b/dotscope/src/deobfuscation/processcell.rs index 88c9678d..972ddab4 100644 --- a/dotscope/src/deobfuscation/processcell.rs +++ b/dotscope/src/deobfuscation/processcell.rs @@ -78,7 +78,7 @@ impl ProcessCell { /// /// # Errors /// - /// Returns [`Error::LockError`] if the internal `RwLock` is poisoned. + /// Returns [`crate::Error::LockError`] if the internal `RwLock` is poisoned. pub fn ensure_initialized( &self, init_fn: F, @@ -122,7 +122,7 @@ impl ProcessCell { /// /// # Errors /// - /// Returns [`Error::LockError`] if the internal `RwLock` is poisoned. + /// Returns [`crate::Error::LockError`] if the internal `RwLock` is poisoned. pub fn take(&self) -> Result> { let mut guard = self .process @@ -138,7 +138,7 @@ impl ProcessCell { /// /// # Errors /// - /// Returns [`Error::LockError`] if the internal `RwLock` is poisoned. + /// Returns [`crate::Error::LockError`] if the internal `RwLock` is poisoned. pub fn clear(&self) -> Result<()> { let mut guard = self .process diff --git a/dotscope/src/deobfuscation/renamer/cascade.rs b/dotscope/src/deobfuscation/renamer/cascade.rs index 6ccd34b0..352836a9 100644 --- a/dotscope/src/deobfuscation/renamer/cascade.rs +++ b/dotscope/src/deobfuscation/renamer/cascade.rs @@ -833,8 +833,8 @@ impl<'a> CascadeRenamer<'a> { if let Some(method) = self.assembly.method(&method_token) { // param.sequence is 1-based (0 = return type), so index = sequence - 1 let sig_index = (param_sequence as usize).saturating_sub(1); - if sig_index < method.signature.params.len() { - context.dotnet_type = Some(method.signature.params[sig_index].to_string()); + if let Some(param) = method.signature.params.get(sig_index) { + context.dotnet_type = Some(param.to_string()); } } @@ -984,7 +984,7 @@ impl<'a> CascadeRenamer<'a> { self.reserve_name(scope_key, &candidate); return candidate; } - suffix += 1; + suffix = suffix.saturating_add(1); } } @@ -1025,13 +1025,15 @@ fn build_param_owner_map( } // End is next method's param_list or end of table + let next_method_rid = method_rid.saturating_add(1); + let param_end_default = param_row_count.saturating_add(1); let param_end = if method_rid < methoddef_table.row_count { methoddef_table - .get(method_rid + 1) + .get(next_method_rid) .map(|next| next.param_list) - .unwrap_or(param_row_count + 1) + .unwrap_or(param_end_default) } else { - param_row_count + 1 + param_end_default }; for param_rid in param_start..param_end { @@ -1062,13 +1064,15 @@ fn build_member_owner_map( continue; } + let next_type_rid = type_rid.saturating_add(1); + let end_default = member_row_count.saturating_add(1); let end = if type_rid < typedef_table.row_count { typedef_table - .get(type_rid + 1) + .get(next_type_rid) .map(|next| get_list_start(&next)) - .unwrap_or(member_row_count + 1) + .unwrap_or(end_default) } else { - member_row_count + 1 + end_default }; for member_rid in start..end { @@ -1099,12 +1103,10 @@ fn generate_phase_label_from_context( _prefix: &str, _suffix: &str, ) -> Option { - if !phase.call_targets.is_empty() { - // Use the first call target as a label - let first = &phase.call_targets[0]; + if let Some(first) = phase.call_targets.first() { // Extract just the method name part let label = if let Some(idx) = first.rfind("::") { - &first[idx + 2..] + first.get(idx.saturating_add(2)..).unwrap_or(first.as_str()) } else { first.as_str() }; diff --git a/dotscope/src/deobfuscation/renamer/features.rs b/dotscope/src/deobfuscation/renamer/features.rs index 4efb982e..571266ee 100644 --- a/dotscope/src/deobfuscation/renamer/features.rs +++ b/dotscope/src/deobfuscation/renamer/features.rs @@ -148,7 +148,7 @@ pub fn build_opcode_profile(ssa: &SsaFunction) -> OpcodeProfile { match instr.op() { // Calls SsaOp::Call { .. } | SsaOp::CallVirt { .. } | SsaOp::CallIndirect { .. } => { - profile.calls += 1; + profile.calls = profile.calls.saturating_add(1); } // Strings @@ -156,7 +156,7 @@ pub fn build_opcode_profile(ssa: &SsaFunction) -> OpcodeProfile { value: ConstValue::String(_) | ConstValue::DecryptedString(_), .. } => { - profile.strings += 1; + profile.strings = profile.strings.saturating_add(1); } // Field I/O @@ -166,7 +166,7 @@ pub fn build_opcode_profile(ssa: &SsaFunction) -> OpcodeProfile { | SsaOp::StoreStaticField { .. } | SsaOp::LoadFieldAddr { .. } | SsaOp::LoadStaticFieldAddr { .. } => { - profile.field_io += 1; + profile.field_io = profile.field_io.saturating_add(1); } // Bitwise @@ -176,7 +176,7 @@ pub fn build_opcode_profile(ssa: &SsaFunction) -> OpcodeProfile { | SsaOp::Not { .. } | SsaOp::Shl { .. } | SsaOp::Shr { .. } => { - profile.bitwise += 1; + profile.bitwise = profile.bitwise.saturating_add(1); } // Arithmetic @@ -189,7 +189,7 @@ pub fn build_opcode_profile(ssa: &SsaFunction) -> OpcodeProfile { | SsaOp::Div { .. } | SsaOp::Rem { .. } | SsaOp::Neg { .. } => { - profile.arithmetic += 1; + profile.arithmetic = profile.arithmetic.saturating_add(1); } // Array @@ -198,7 +198,7 @@ pub fn build_opcode_profile(ssa: &SsaFunction) -> OpcodeProfile { | SsaOp::StoreElement { .. } | SsaOp::LoadElementAddr { .. } | SsaOp::ArrayLength { .. } => { - profile.array += 1; + profile.array = profile.array.saturating_add(1); } // Comparison @@ -207,12 +207,12 @@ pub fn build_opcode_profile(ssa: &SsaFunction) -> OpcodeProfile { | SsaOp::Cgt { .. } | SsaOp::Branch { .. } | SsaOp::BranchCmp { .. } => { - profile.comparison += 1; + profile.comparison = profile.comparison.saturating_add(1); } // Conversion SsaOp::Conv { .. } => { - profile.conversion += 1; + profile.conversion = profile.conversion.saturating_add(1); } _ => {} @@ -312,8 +312,11 @@ pub fn collect_call_site_context( // the call site (e.g., format strings loaded before String.Format which // comes after the method call whose result is being formatted). let window_start = idx.saturating_sub(5); - let window_end = (idx + 6).min(all_instrs.len()); - for (_, _, nearby_instr) in &all_instrs[window_start..window_end] { + let window_end = idx.saturating_add(6).min(all_instrs.len()); + let Some(window) = all_instrs.get(window_start..window_end) else { + continue; + }; + for (_, _, nearby_instr) in window { if let SsaOp::Const { value, .. } = nearby_instr.op() { match value { ConstValue::DecryptedString(s) @@ -341,7 +344,7 @@ pub fn collect_call_site_context( // Check if the return value feeds into another call if let Some(dest_var) = dest { - for (_, _, later_instr) in all_instrs.iter().skip(idx + 1).take(5) { + for (_, _, later_instr) in all_instrs.iter().skip(idx.saturating_add(1)).take(5) { let (usage_token, usage_args) = match later_instr.op() { SsaOp::Call { method, args, .. } => (method.token(), args), SsaOp::CallVirt { method, args, .. } => (method.token(), args), diff --git a/dotscope/src/deobfuscation/renamer/mod.rs b/dotscope/src/deobfuscation/renamer/mod.rs index de446463..a886f788 100644 --- a/dotscope/src/deobfuscation/renamer/mod.rs +++ b/dotscope/src/deobfuscation/renamer/mod.rs @@ -153,7 +153,7 @@ pub fn renames_apply(cil_assembly: &mut CilAssembly, entries: Vec) return Ok(0); } - let mut renamed_count = 0; + let mut renamed_count: usize = 0; // Track which string offsets have already been renamed and to what name let mut renamed_offsets: HashMap = HashMap::new(); @@ -162,14 +162,14 @@ pub fn renames_apply(cil_assembly: &mut CilAssembly, entries: Vec) if let Some(existing_name) = renamed_offsets.get(&entry.string_index) { if *existing_name == entry.new_name { // Same name as first rename — string_update already covers this - renamed_count += 1; + renamed_count = renamed_count.saturating_add(1); continue; } // Different name at same offset — allocate new string and update the row let change_ref = cil_assembly.string_add(&entry.new_name)?; let placeholder = change_ref.placeholder(); update_row_name_field(cil_assembly, entry.table_id, entry.rid, placeholder)?; - renamed_count += 1; + renamed_count = renamed_count.saturating_add(1); } else { // First rename at this offset — modify in place if cil_assembly @@ -177,7 +177,7 @@ pub fn renames_apply(cil_assembly: &mut CilAssembly, entries: Vec) .is_ok() { renamed_offsets.insert(entry.string_index, entry.new_name.clone()); - renamed_count += 1; + renamed_count = renamed_count.saturating_add(1); } } } diff --git a/dotscope/src/deobfuscation/renamer/phases.rs b/dotscope/src/deobfuscation/renamer/phases.rs index 898504de..2d79e518 100644 --- a/dotscope/src/deobfuscation/renamer/phases.rs +++ b/dotscope/src/deobfuscation/renamer/phases.rs @@ -499,11 +499,10 @@ fn build_phases_from_boundaries( let mut phases = Vec::new(); for (i, &start) in boundaries.iter().enumerate() { - let end = if i + 1 < boundaries.len() { - boundaries[i + 1] - } else { - block_count - }; + let end = boundaries + .get(i.saturating_add(1)) + .copied() + .unwrap_or(block_count); if start >= block_count { continue; @@ -595,24 +594,32 @@ fn detect_back_edges( in_stack: &mut [bool], boundaries: &mut HashSet, ) { - if block_idx >= ssa.blocks().len() { + let blocks = ssa.blocks(); + let Some(block) = blocks.get(block_idx) else { return; + }; + if let Some(slot) = visited.get_mut(block_idx) { + *slot = true; + } + if let Some(slot) = in_stack.get_mut(block_idx) { + *slot = true; } - visited[block_idx] = true; - in_stack[block_idx] = true; - for succ in ssa.blocks()[block_idx].successors() { + let successors = block.successors(); + for succ in successors { if succ < visited.len() { - if in_stack[succ] { + if in_stack.get(succ).copied().unwrap_or(false) { // Back-edge found: succ is a loop header boundaries.insert(succ); - } else if !visited[succ] { + } else if !visited.get(succ).copied().unwrap_or(true) { detect_back_edges(ssa, succ, visited, in_stack, boundaries); } } } - in_stack[block_idx] = false; + if let Some(slot) = in_stack.get_mut(block_idx) { + *slot = false; + } } /// Detects contiguous blocks with heavy bitwise/arithmetic and no calls. @@ -627,12 +634,12 @@ fn detect_back_edges( /// * `boundaries` - Set of boundary block indices to populate. fn detect_transform_boundaries(ssa: &SsaFunction, boundaries: &mut HashSet) { let blocks = ssa.blocks(); - let mut consecutive_transform = 0; + let mut consecutive_transform: u32 = 0; for (block_idx, block) in blocks.iter().enumerate() { let mut has_calls = false; - let mut bitwise_count = 0u32; - let mut arithmetic_count = 0u32; + let mut bitwise_count: u32 = 0; + let mut arithmetic_count: u32 = 0; for instr in block.instructions() { match instr.op() { @@ -645,25 +652,25 @@ fn detect_transform_boundaries(ssa: &SsaFunction, boundaries: &mut HashSet { - bitwise_count += 1; + bitwise_count = bitwise_count.saturating_add(1); } SsaOp::Add { .. } | SsaOp::Sub { .. } | SsaOp::Mul { .. } | SsaOp::Div { .. } | SsaOp::Rem { .. } => { - arithmetic_count += 1; + arithmetic_count = arithmetic_count.saturating_add(1); } _ => {} } } - if !has_calls && (bitwise_count + arithmetic_count) >= 3 { + if !has_calls && bitwise_count.saturating_add(arithmetic_count) >= 3 { if consecutive_transform == 0 { // Start of transform region boundaries.insert(block_idx); } - consecutive_transform += 1; + consecutive_transform = consecutive_transform.saturating_add(1); } else { if consecutive_transform >= 3 { // End of transform region, start new phase @@ -687,19 +694,19 @@ fn detect_transform_boundaries(ssa: &SsaFunction, boundaries: &mut HashSet { - profile.calls += 1; + profile.calls = profile.calls.saturating_add(1); } SsaOp::Const { value: ConstValue::String(_) | ConstValue::DecryptedString(_), .. } => { - profile.strings += 1; + profile.strings = profile.strings.saturating_add(1); } SsaOp::LoadField { .. } | SsaOp::StoreField { .. } | SsaOp::LoadStaticField { .. } | SsaOp::StoreStaticField { .. } => { - profile.field_io += 1; + profile.field_io = profile.field_io.saturating_add(1); } SsaOp::And { .. } | SsaOp::Or { .. } @@ -707,7 +714,7 @@ fn classify_op_into_profile(op: &SsaOp, profile: &mut OpcodeProfile) { | SsaOp::Not { .. } | SsaOp::Shl { .. } | SsaOp::Shr { .. } => { - profile.bitwise += 1; + profile.bitwise = profile.bitwise.saturating_add(1); } SsaOp::Add { .. } | SsaOp::AddOvf { .. } @@ -718,23 +725,23 @@ fn classify_op_into_profile(op: &SsaOp, profile: &mut OpcodeProfile) { | SsaOp::Div { .. } | SsaOp::Rem { .. } | SsaOp::Neg { .. } => { - profile.arithmetic += 1; + profile.arithmetic = profile.arithmetic.saturating_add(1); } SsaOp::NewArr { .. } | SsaOp::LoadElement { .. } | SsaOp::StoreElement { .. } | SsaOp::ArrayLength { .. } => { - profile.array += 1; + profile.array = profile.array.saturating_add(1); } SsaOp::Ceq { .. } | SsaOp::Clt { .. } | SsaOp::Cgt { .. } | SsaOp::Branch { .. } | SsaOp::BranchCmp { .. } => { - profile.comparison += 1; + profile.comparison = profile.comparison.saturating_add(1); } SsaOp::Conv { .. } => { - profile.conversion += 1; + profile.conversion = profile.conversion.saturating_add(1); } _ => {} } diff --git a/dotscope/src/deobfuscation/renamer/prompt.rs b/dotscope/src/deobfuscation/renamer/prompt.rs index 5336cb01..32238ca0 100644 --- a/dotscope/src/deobfuscation/renamer/prompt.rs +++ b/dotscope/src/deobfuscation/renamer/prompt.rs @@ -121,7 +121,11 @@ fn build_method_prompt(context: &RenameContext, max_phases: usize) -> (String, S // Large method: use phase narrative let phases = truncate_phases(&context.phase_narrative, max_phases); for (i, phase) in phases.iter().enumerate() { - prefix.push_str(&format!("// Phase {}: {}\n", i + 1, phase.label)); + prefix.push_str(&format!( + "// Phase {}: {}\n", + i.saturating_add(1), + phase.label + )); if !phase.call_targets.is_empty() { let calls = phase.call_targets.join(", "); prefix.push_str(&format!("// [calls: {calls}]\n")); @@ -462,8 +466,13 @@ fn truncate_phases(phases: &[PhaseInfo], max_phases: usize) -> Vec<&PhaseInfo> { let half = max_phases / 2; let mut result: Vec<&PhaseInfo> = Vec::new(); - result.extend(&phases[..half]); - result.extend(&phases[phases.len() - half..]); + if let Some(front) = phases.get(..half) { + result.extend(front); + } + let tail_start = phases.len().saturating_sub(half); + if let Some(back) = phases.get(tail_start..) { + result.extend(back); + } result } diff --git a/dotscope/src/deobfuscation/renamer/providers/local.rs b/dotscope/src/deobfuscation/renamer/providers/local.rs index 9532edab..e6393b29 100644 --- a/dotscope/src/deobfuscation/renamer/providers/local.rs +++ b/dotscope/src/deobfuscation/renamer/providers/local.rs @@ -28,7 +28,7 @@ use crate::{ deobfuscation::renamer::{ context::RenameContext, prompt, validate, RenameProvider, SmartRenameConfig, }, - Result, + Error, Result, }; /// System prompt for identifier naming via chat API. @@ -161,7 +161,7 @@ impl LocalProvider { let response = state .runtime .block_on(state.model.send_chat_request(request)) - .map_err(|e| crate::Error::Deobfuscation(format!("Model inference failed: {e}")))?; + .map_err(|e| Error::Deobfuscation(format!("Model inference failed: {e}")))?; if let Some(choice) = response.choices.first() { log::debug!( @@ -202,18 +202,17 @@ impl RenameProvider for LocalProvider { /// - The tokio runtime cannot be created fn initialize(&mut self) -> Result<()> { if !self.config.model_path.exists() { - return Err(crate::Error::Deobfuscation(format!( + return Err(Error::Deobfuscation(format!( "Smart rename model not found: {}", self.config.model_path.display() ))); } - let runtime = Runtime::new().map_err(|e| { - crate::Error::Deobfuscation(format!("Failed to create tokio runtime: {e}")) - })?; + let runtime = Runtime::new() + .map_err(|e| Error::Deobfuscation(format!("Failed to create tokio runtime: {e}")))?; let model_path = self.config.model_path.canonicalize().map_err(|e| { - crate::Error::Deobfuscation(format!( + Error::Deobfuscation(format!( "Failed to resolve model path {}: {e}", self.config.model_path.display() )) @@ -239,7 +238,7 @@ impl RenameProvider for LocalProvider { builder .build() .await - .map_err(|e| crate::Error::Deobfuscation(format!("Model load failed: {e}"))) + .map_err(|e| Error::Deobfuscation(format!("Model load failed: {e}"))) })?; log::info!( diff --git a/dotscope/src/deobfuscation/renamer/providers/simple.rs b/dotscope/src/deobfuscation/renamer/providers/simple.rs index 94984d7a..e6b953c9 100644 --- a/dotscope/src/deobfuscation/renamer/providers/simple.rs +++ b/dotscope/src/deobfuscation/renamer/providers/simple.rs @@ -37,7 +37,7 @@ impl SimpleNameGenerator { /// Names are uppercase base-26: `A`, `B`, ..., `Z`, `AA`, `AB`, ... pub fn next_type_name(&mut self) -> String { let name = Self::index_to_name(self.types); - self.types += 1; + self.types = self.types.saturating_add(1); name } @@ -46,7 +46,7 @@ impl SimpleNameGenerator { /// Names are lowercase base-26: `a`, `b`, ..., `z`, `aa`, `ab`, ... pub fn next_method_name(&mut self) -> String { let name = Self::index_to_name_lower(self.methods); - self.methods += 1; + self.methods = self.methods.saturating_add(1); name } @@ -56,7 +56,7 @@ impl SimpleNameGenerator { /// sequence: `f_a`, `f_b`, ..., `f_z`, `f_aa`, ... pub fn next_field_name(&mut self) -> String { let name = format!("f_{}", Self::index_to_name_lower(self.fields)); - self.fields += 1; + self.fields = self.fields.saturating_add(1); name } @@ -66,7 +66,7 @@ impl SimpleNameGenerator { /// sequence: `p_a`, `p_b`, ..., `p_z`, `p_aa`, ... pub fn next_param_name(&mut self) -> String { let name = format!("p_{}", Self::index_to_name_lower(self.params)); - self.params += 1; + self.params = self.params.saturating_add(1); name } @@ -77,13 +77,13 @@ impl SimpleNameGenerator { pub fn index_to_name(mut index: usize) -> String { let mut result = String::new(); loop { - let remainder = index % 26; + let remainder = index.checked_rem(26).unwrap_or(0); #[allow(clippy::cast_possible_truncation)] - result.insert(0, (b'A' + remainder as u8) as char); + result.insert(0, (b'A'.saturating_add(remainder as u8)) as char); if index < 26 { break; } - index = index / 26 - 1; + index = index.checked_div(26).unwrap_or(0).saturating_sub(1); } result } @@ -96,13 +96,13 @@ impl SimpleNameGenerator { pub fn index_to_name_lower(mut index: usize) -> String { let mut result = String::new(); loop { - let remainder = index % 26; + let remainder = index.checked_rem(26).unwrap_or(0); #[allow(clippy::cast_possible_truncation)] - result.insert(0, (b'a' + remainder as u8) as char); + result.insert(0, (b'a'.saturating_add(remainder as u8)) as char); if index < 26 { break; } - index = index / 26 - 1; + index = index.checked_div(26).unwrap_or(0).saturating_sub(1); } result } diff --git a/dotscope/src/deobfuscation/renamer/validate.rs b/dotscope/src/deobfuscation/renamer/validate.rs index 85caab98..7efcca05 100644 --- a/dotscope/src/deobfuscation/renamer/validate.rs +++ b/dotscope/src/deobfuscation/renamer/validate.rs @@ -88,7 +88,9 @@ pub fn is_valid_dotnet_identifier(name: &str) -> bool { let mut chars = name.chars(); // First character: letter or underscore - let first = chars.next().unwrap(); + let Some(first) = chars.next() else { + return false; + }; if !first.is_alphabetic() && first != '_' { return false; } @@ -208,7 +210,9 @@ pub fn to_pascal_case(name: &str) -> String { } let mut chars = name.chars(); - let first = chars.next().unwrap(); + let Some(first) = chars.next() else { + return String::new(); + }; if first.is_lowercase() { let upper: String = first.to_uppercase().collect(); format!("{upper}{}", chars.as_str()) @@ -234,7 +238,9 @@ pub fn to_camel_case(name: &str) -> String { } let mut chars = name.chars(); - let first = chars.next().unwrap(); + let Some(first) = chars.next() else { + return String::new(); + }; if first.is_uppercase() { let lower: String = first.to_lowercase().collect(); format!("{lower}{}", chars.as_str()) @@ -262,9 +268,9 @@ pub fn deconflict_names(proposed: &mut [String], existing: &[String]) { for name in proposed.iter_mut() { let count = used.entry(name.clone()).or_insert(0); - *count += 1; + *count = count.saturating_add(1); if *count > 1 { - let mut suffix = *count; + let mut suffix: usize = *count; loop { let candidate = format!("{name}_{suffix}"); if !used.contains_key(&candidate) { @@ -272,7 +278,7 @@ pub fn deconflict_names(proposed: &mut [String], existing: &[String]) { *name = candidate; break; } - suffix += 1; + suffix = suffix.saturating_add(1); } } } diff --git a/dotscope/src/deobfuscation/statemachine.rs b/dotscope/src/deobfuscation/statemachine.rs index add1c275..cf4b6c60 100644 --- a/dotscope/src/deobfuscation/statemachine.rs +++ b/dotscope/src/deobfuscation/statemachine.rs @@ -107,7 +107,9 @@ impl StateMachineCallSite { /// Returns a location identifier for this call site. #[must_use] pub fn location(&self) -> usize { - self.block_idx * 1000 + self.instr_idx + self.block_idx + .saturating_mul(1000) + .saturating_add(self.instr_idx) } } @@ -310,8 +312,8 @@ pub trait StateMachineProvider: Send + Sync + std::fmt::Debug { return None; } - if seeds.len() == 1 { - return Some(seeds[0].2); + if let [single] = seeds { + return Some(single.2); } // Multiple seeds - find the one that dominates this decryptor @@ -501,7 +503,7 @@ impl StateSlotOperation { SsaOpKind::Rol => left.rotate_left((right & 63) as u32), SsaOpKind::Ror => left.rotate_right((right & 63) as u32), SsaOpKind::Not => !left, - SsaOpKind::Neg => (-left.cast_signed()).cast_unsigned(), + SsaOpKind::Neg => left.wrapping_neg(), } } @@ -522,7 +524,7 @@ impl StateSlotOperation { SsaOpKind::Rol => left.rotate_left(right & 31), SsaOpKind::Ror => left.rotate_right(right & 31), SsaOpKind::Not => !left, - SsaOpKind::Neg => (-left.cast_signed()).cast_unsigned(), + SsaOpKind::Neg => left.wrapping_neg(), } } } @@ -665,11 +667,8 @@ impl StateMachineSemantics { #[must_use] pub fn init_operation(&self, slot: usize) -> Option<&StateSlotOperation> { // If we have fewer init_ops than slots, cycle through them - if self.init_ops.is_empty() { - None - } else { - Some(&self.init_ops[slot % self.init_ops.len()]) - } + let idx = slot.checked_rem(self.init_ops.len())?; + self.init_ops.get(idx) } } @@ -783,20 +782,29 @@ impl StateMachineState { let is_explicit = (flag & (1 << self.semantics.explicit_flag_bit)) != 0; // Ensure slots exist - let update_slot = update_slot % self.slots.len().max(1); - let get_slot = get_slot % self.slots.len().max(1); + let len = self.slots.len(); + let (Some(update_slot), Some(get_slot)) = + (update_slot.checked_rem(len), get_slot.checked_rem(len)) + else { + return 0; + }; // Update the specified slot if is_explicit { // Explicit: set slot to value - self.slots[update_slot] = value; + if let Some(slot) = self.slots.get_mut(update_slot) { + *slot = value; + } } else if let Some(op) = self.semantics.slot_operation(update_slot) { // Incremental: apply operation - self.slots[update_slot] = op.apply(self.slots[update_slot], value); + let current = self.slots.get(update_slot).copied().unwrap_or(0); + if let Some(slot) = self.slots.get_mut(update_slot) { + *slot = op.apply(current, value); + } } // Return value from requested slot - self.slots[get_slot] + self.slots.get(get_slot).copied().unwrap_or(0) } /// Applies the Next operation with u32 values. @@ -810,20 +818,28 @@ impl StateMachineState { let is_explicit = (flag & (1 << self.semantics.explicit_flag_bit)) != 0; // Ensure slots exist - let update_slot = update_slot % self.slots.len().max(1); - let get_slot = get_slot % self.slots.len().max(1); + let len = self.slots.len(); + let (Some(update_slot), Some(get_slot)) = + (update_slot.checked_rem(len), get_slot.checked_rem(len)) + else { + return 0; + }; // Update the specified slot if is_explicit { - self.slots[update_slot] = u64::from(value); + if let Some(slot) = self.slots.get_mut(update_slot) { + *slot = u64::from(value); + } } else if let Some(op) = self.semantics.slot_operation(update_slot) { #[allow(clippy::cast_possible_truncation)] - let current = self.slots[update_slot] as u32; - self.slots[update_slot] = u64::from(op.apply_u32(current, value)); + let current = self.slots.get(update_slot).copied().unwrap_or(0) as u32; + if let Some(slot) = self.slots.get_mut(update_slot) { + *slot = u64::from(op.apply_u32(current, value)); + } } #[allow(clippy::cast_possible_truncation)] - let result = self.slots[get_slot] as u32; + let result = self.slots.get(get_slot).copied().unwrap_or(0) as u32; result } @@ -842,8 +858,8 @@ impl StateMachineState { /// Sets the value of a specific state slot. pub fn set(&mut self, slot: usize, value: u64) { - if slot < self.slots.len() { - self.slots[slot] = value; + if let Some(s) = self.slots.get_mut(slot) { + *s = value; } } diff --git a/dotscope/src/deobfuscation/techniques/bitmono/calli.rs b/dotscope/src/deobfuscation/techniques/bitmono/calli.rs index 716b3a44..d27e6198 100644 --- a/dotscope/src/deobfuscation/techniques/bitmono/calli.rs +++ b/dotscope/src/deobfuscation/techniques/bitmono/calli.rs @@ -85,17 +85,20 @@ impl Technique for BitMonoCalli { let method = method_entry.value(); let instructions: Vec<_> = method.instructions().collect(); - let mut method_sites = 0usize; - let mut i = 0; + let mut method_sites: usize = 0; + let mut i: usize = 0; while i < instructions.len() { - if instructions[i].mnemonic == "calli" { + let Some(instr_at_i) = instructions.get(i) else { + break; + }; + if instr_at_i.mnemonic == "calli" { // Walk backwards up to 12 instructions looking for the // characteristic BitMono trampoline pattern: // ldtoken -> GetTypeFromHandle -> get_Module // -> ldc.i4 -> ResolveMethod -> get_MethodHandle // -> GetFunctionPointer -> calli let window_start = i.saturating_sub(12); - let window = &instructions[window_start..i]; + let window = instructions.get(window_start..i).unwrap_or(&[]); let has_ldtoken = window.iter().any(|instr| instr.mnemonic == "ldtoken"); let has_trampoline_api = window.iter().any(|instr| { @@ -108,15 +111,15 @@ impl Technique for BitMonoCalli { }); if has_ldtoken && has_trampoline_api { - method_sites += 1; + method_sites = method_sites.saturating_add(1); } } - i += 1; + i = i.saturating_add(1); } if method_sites > 0 { method_tokens.insert(method.token); - site_count += method_sites; + site_count = site_count.saturating_add(method_sites); } } @@ -143,7 +146,7 @@ impl Technique for BitMonoCalli { for entry in ctx.ssa_functions.iter() { let count = count_resolve_method_calli_sites(entry.value(), assembly); if count > 0 { - site_count += count; + site_count = site_count.saturating_add(count); method_tokens.insert(*entry.key()); } } diff --git a/dotscope/src/deobfuscation/techniques/bitmono/hooks.rs b/dotscope/src/deobfuscation/techniques/bitmono/hooks.rs index fae40f2d..0ccbf42a 100644 --- a/dotscope/src/deobfuscation/techniques/bitmono/hooks.rs +++ b/dotscope/src/deobfuscation/techniques/bitmono/hooks.rs @@ -47,7 +47,7 @@ use crate::{ token::Token, typesystem::wellknown, }, - CilObject, Result, + CilObject, Error, Result, }; /// Findings from BitMono DotNetHook detection. @@ -129,7 +129,7 @@ impl Technique for BitMonoHooks { } if has_jit_hook_setup && has_marshal_write { - hook_count += 1; + hook_count = hook_count.saturating_add(1); infrastructure_type = Some(cil_type.token); // Identify the RedirectStub method by signature: static void(int32, int32). @@ -332,7 +332,20 @@ impl Technique for BitMonoHooks { let raw_token = call_target.value(); if let Some(&real_target) = redirect_map.get(&raw_token) { // Token operand is the last 4 bytes of the instruction - patches.push((instr.offset + instr.size - 4, real_target)); + let operand_offset = instr + .offset + .checked_add(instr.size) + .and_then(|v| v.checked_sub(4)) + .ok_or_else(|| { + Error::Deobfuscation( + "DotNetHook: instruction offset overflow computing operand position".into(), + ) + }); + let operand_offset = match operand_offset { + Ok(v) => v, + Err(e) => return Some(Err(e)), + }; + patches.push((operand_offset, real_target)); } } continue; @@ -354,7 +367,15 @@ impl Technique for BitMonoHooks { if let Some(&corrected) = stale_correction_map.get(&u) { if corrected != u { // ldc.i4 is 5 bytes: 0x20 + i32 operand - patches.push((instr.offset + 1, corrected)); + let operand_offset = match instr.offset.checked_add(1) { + Some(v) => v, + None => { + return Some(Err(Error::Deobfuscation( + "DotNetHook: ldc.i4 offset overflow computing operand position".into(), + ))); + } + }; + patches.push((operand_offset, corrected)); } } } @@ -537,15 +558,22 @@ fn extract_hook_mappings( continue; } - let arg1 = instructions[i - 2] + let Some(prev2) = instructions.get(i.saturating_sub(2)) else { + continue; + }; + let Some(prev1) = instructions.get(i.saturating_sub(1)) else { + continue; + }; + + let arg1 = prev2 .mnemonic .starts_with("ldc.i4") - .then(|| instructions[i - 2].get_i32_operand()) + .then(|| prev2.get_i32_operand()) .flatten(); - let arg2 = instructions[i - 1] + let arg2 = prev1 .mnemonic .starts_with("ldc.i4") - .then(|| instructions[i - 1].get_i32_operand()) + .then(|| prev1.get_i32_operand()) .flatten(); if let (Some(a1), Some(a2)) = (arg1, arg2) { @@ -621,7 +649,7 @@ fn extract_hook_mappings( let offset = target_offset.unwrap_or(0); let total_methods = assembly.methods().iter().count() as u32; let original_count = if offset > 0 { - (total_methods as i64 - offset) as u32 + (total_methods as i64).saturating_sub(offset).max(0) as u32 } else { total_methods }; @@ -629,7 +657,11 @@ fn extract_hook_mappings( let mut stale_correction_map: HashMap = (1..=original_count) .filter_map(|r| { let stale = 0x0600_0000 | r; - let final_row = (r as i64 + offset) as u32; + let final_row_i64 = (r as i64).saturating_add(offset); + if final_row_i64 < 1 { + return None; + } + let final_row = final_row_i64 as u32; if final_row != r && final_row >= 1 && final_row <= total_methods { Some((stale, 0x0600_0000 | final_row)) } else { @@ -691,22 +723,30 @@ fn extract_hook_mappings( /// - 2 instructions: `ldc.*` or `ldnull`, then `ret` /// - 3 instructions: `ldc.*` + `conv.*` + `ret`, or `ldloca` + `initobj` + `ret` fn is_dummy_body(instructions: &[&crate::assembly::Instruction]) -> bool { - if instructions.is_empty() { + let Some(last) = instructions.last() else { return false; - } - let last = instructions.last().unwrap(); + }; if last.mnemonic != "ret" { return false; } match instructions.len() { 1 => true, // just ret (void or stack-underflow dummy) 2 => { - let m = instructions[0].mnemonic; + let Some(first) = instructions.first() else { + return false; + }; + let m = first.mnemonic; m.starts_with("ldc.") || m == "ldnull" } 3 => { - let m0 = instructions[0].mnemonic; - let m1 = instructions[1].mnemonic; + let Some(i0) = instructions.first() else { + return false; + }; + let Some(i1) = instructions.get(1) else { + return false; + }; + let m0 = i0.mnemonic; + let m1 = i1.mnemonic; // ldc.i4.0 + conv.i8 + ret (int64 return) // ldloca.s + initobj + ret (value type return — rare) (m0.starts_with("ldc.") && m1.starts_with("conv.")) @@ -779,10 +819,11 @@ fn compute_target_offset( }; let stale_row = (stale_target & 0x00FF_FFFF) as i64; - if candidates.len() == 1 { - let final_row = candidates[0].row() as i64; - let offset = final_row - stale_row; - *offset_votes.entry(offset).or_insert(0) += 1; + if let [single] = candidates.as_slice() { + let final_row = single.row() as i64; + let offset = final_row.saturating_sub(stale_row); + let entry = offset_votes.entry(offset).or_insert(0_usize); + *entry = entry.saturating_add(1); } } diff --git a/dotscope/src/deobfuscation/techniques/bitmono/junk.rs b/dotscope/src/deobfuscation/techniques/bitmono/junk.rs index 971db86d..4bc18171 100644 --- a/dotscope/src/deobfuscation/techniques/bitmono/junk.rs +++ b/dotscope/src/deobfuscation/techniques/bitmono/junk.rs @@ -70,13 +70,13 @@ impl Technique for BitMonoJunk { // Check for br.s at method start with a small positive forward offset. // BitMethodDotnet inserts br.s that jumps over 1-10 bytes of junk. - if instructions[0].mnemonic == "br.s" { - let is_small_forward_jump = matches!( - instructions[0].operand, - Operand::Immediate(Immediate::Int8(1..=10)) - ); - if is_small_forward_jump { - junk_method_count += 1; + if let Some(first) = instructions.first() { + if first.mnemonic == "br.s" { + let is_small_forward_jump = + matches!(first.operand, Operand::Immediate(Immediate::Int8(1..=10))); + if is_small_forward_jump { + junk_method_count = junk_method_count.saturating_add(1); + } } } } diff --git a/dotscope/src/deobfuscation/techniques/bitmono/renamer.rs b/dotscope/src/deobfuscation/techniques/bitmono/renamer.rs index c51282c5..a6659c3c 100644 --- a/dotscope/src/deobfuscation/techniques/bitmono/renamer.rs +++ b/dotscope/src/deobfuscation/techniques/bitmono/renamer.rs @@ -47,7 +47,7 @@ impl Technique for BitMonoRenamer { // Check type name for spaces (skip angle-bracket compiler-generated names) if cil_type.name.contains(' ') && !cil_type.name.starts_with('<') { - space_name_count += 1; + space_name_count = space_name_count.saturating_add(1); } // Check method names @@ -59,7 +59,7 @@ impl Technique for BitMonoRenamer { continue; }; if method.name.contains(' ') && !method.name.starts_with('<') { - space_name_count += 1; + space_name_count = space_name_count.saturating_add(1); } } } diff --git a/dotscope/src/deobfuscation/techniques/bitmono/strings.rs b/dotscope/src/deobfuscation/techniques/bitmono/strings.rs index 22125622..770edc99 100644 --- a/dotscope/src/deobfuscation/techniques/bitmono/strings.rs +++ b/dotscope/src/deobfuscation/techniques/bitmono/strings.rs @@ -190,7 +190,7 @@ impl Technique for BitMonoStrings { if !decryptor_set.contains(&call_token) { continue; } - call_site_count += 1; + call_site_count = call_site_count.saturating_add(1); // Trace LoadStaticField args to collect infrastructure field tokens for arg in args { @@ -433,7 +433,9 @@ fn extract_crypto_parameters(ssa: &SsaFunction, assembly: &CilObject) -> CryptoP if let Some(name) = resolve_type_name(assembly, ctor.token()) { if name.contains("Rfc2898DeriveBytes") && args.len() >= 3 { // 3rd arg (index 2) is the iteration count - if let Some(ConstValue::I32(iters)) = const_map.get(&args[2]) { + if let Some(ConstValue::I32(iters)) = + args.get(2).and_then(|a| const_map.get(a)) + { if *iters > 0 { params.iterations = *iters as u32; } @@ -449,7 +451,9 @@ fn extract_crypto_parameters(ssa: &SsaFunction, assembly: &CilObject) -> CryptoP if let Some(name) = assembly.resolve_method_name(method.token()) { if name == "GetBytes" && args.len() == 2 { // args[0] = this (Rfc2898DeriveBytes instance), args[1] = size - if let Some(ConstValue::I32(size)) = const_map.get(&args[1]) { + if let Some(ConstValue::I32(size)) = + args.get(1).and_then(|a| const_map.get(a)) + { if *size > 0 { get_bytes_sizes.push(*size as u32); } diff --git a/dotscope/src/deobfuscation/techniques/bitmono/unmanaged.rs b/dotscope/src/deobfuscation/techniques/bitmono/unmanaged.rs index 231281f1..0cca0eb6 100644 --- a/dotscope/src/deobfuscation/techniques/bitmono/unmanaged.rs +++ b/dotscope/src/deobfuscation/techniques/bitmono/unmanaged.rs @@ -221,11 +221,7 @@ fn extract_native_string( let offset = file.rva_to_offset(rva as usize).ok()?; let data = file.data(); - if offset >= data.len() { - return None; - } - - let native_bytes = &data[offset..]; + let native_bytes = data.get(offset..)?; // Use traversal-based disassembly to find where the code ends. // This follows control flow edges (including trampolines, call targets, @@ -235,17 +231,13 @@ fn extract_native_string( return None; } - let string_bytes = &native_bytes[prefix_len..]; + let string_bytes = native_bytes.get(prefix_len..)?; // Detect encoding: if the second byte is 0x00 and first byte is printable ASCII, // the data is likely UTF-16LE (e.g., "Hello" = 48 00 65 00 6C 00 ...). // This check must come before the ASCII attempt, which would find the 0x00 at // position 1 and incorrectly return just the first character. - let looks_like_utf16 = string_bytes.len() >= 4 - && string_bytes[0] != 0 - && string_bytes[1] == 0 - && string_bytes[2] != 0 - && string_bytes[3] == 0; + let looks_like_utf16 = matches!(string_bytes, [b0, 0, b2, 0, ..] if *b0 != 0 && *b2 != 0); if looks_like_utf16 { // Try UTF-16LE first (char* constructor) @@ -257,8 +249,10 @@ fn extract_native_string( // Try UTF-8/ASCII (sbyte* constructor) if let Some(null_pos) = string_bytes.iter().position(|&b| b == 0) { if null_pos > 0 { - if let Ok(s) = std::str::from_utf8(&string_bytes[..null_pos]) { - return Some(s.to_string()); + if let Some(slice) = string_bytes.get(..null_pos) { + if let Ok(s) = std::str::from_utf8(slice) { + return Some(s.to_string()); + } } } } diff --git a/dotscope/src/deobfuscation/techniques/confuserex/constants.rs b/dotscope/src/deobfuscation/techniques/confuserex/constants.rs index bd5f63bd..6f3eb71e 100644 --- a/dotscope/src/deobfuscation/techniques/confuserex/constants.rs +++ b/dotscope/src/deobfuscation/techniques/confuserex/constants.rs @@ -265,17 +265,22 @@ impl Technique for ConfuserExConstants { let sig = &method.signature; // Check for string(int32) signature + let first_param_is_i4 = sig + .params + .first() + .is_some_and(|p| p.base == TypeSignature::I4); + let is_string_decryptor = sig.param_count_generic == 0 && sig.return_type.base == TypeSignature::String && sig.params.len() == 1 - && sig.params[0].base == TypeSignature::I4; + && first_param_is_i4; // Check for generic T(int32) signature (param_count_generic == 1, // return type is GenericParamMethod(0)) let is_generic_decryptor = sig.param_count_generic == 1 && matches!(sig.return_type.base, TypeSignature::GenericParamMethod(0)) && sig.params.len() == 1 - && sig.params[0].base == TypeSignature::I4; + && first_param_is_i4; if is_string_decryptor || is_generic_decryptor { decryptor_tokens.push(method.token); @@ -296,7 +301,7 @@ impl Technique for ConfuserExConstants { } if let Ok(offset) = file.rva_to_offset(row.rva as usize) { let data = file.data(); - if offset < data.len() && data[offset] == LZMA_MAGIC { + if data.get(offset).is_some_and(|b| *b == LZMA_MAGIC) { has_lzma_fieldrva = true; data_field_tokens.push(Token::from_parts(TableId::Field, row.field)); } diff --git a/dotscope/src/deobfuscation/techniques/confuserex/helpers.rs b/dotscope/src/deobfuscation/techniques/confuserex/helpers.rs index 964c1264..41692170 100644 --- a/dotscope/src/deobfuscation/techniques/confuserex/helpers.rs +++ b/dotscope/src/deobfuscation/techniques/confuserex/helpers.rs @@ -81,18 +81,19 @@ pub(super) fn extract_method_body_at_rva(memory: &[u8], rva: u32) -> Option body_slice.len() { return None; } - let il_code = &body_slice[il_start..il_end]; + let il_code = body_slice.get(il_start..il_end)?; let mut output = Vec::new(); body.write_to(&mut output, il_code).ok()?; @@ -112,7 +113,7 @@ pub(super) fn extract_decrypted_field_data( virtual_image: &[u8], ) -> (Vec<(u32, u32, Vec)>, usize) { let mut fields = Vec::new(); - let mut failed_count = 0; + let mut failed_count: usize = 0; let Some(tables) = assembly.tables() else { return (fields, failed_count); @@ -128,17 +129,21 @@ pub(super) fn extract_decrypted_field_data( } let Some(field_size) = get_field_data_size(assembly, row.field) else { - failed_count += 1; + failed_count = failed_count.saturating_add(1); continue; }; let rva_usize = rva as usize; - if rva_usize + field_size > virtual_image.len() { - failed_count += 1; + let Some(end) = rva_usize.checked_add(field_size) else { + failed_count = failed_count.saturating_add(1); continue; - } + }; + let Some(slice) = virtual_image.get(rva_usize..end) else { + failed_count = failed_count.saturating_add(1); + continue; + }; - let data = virtual_image[rva_usize..rva_usize + field_size].to_vec(); + let data = slice.to_vec(); fields.push((row.rid, rva, data)); } diff --git a/dotscope/src/deobfuscation/techniques/confuserex/marker.rs b/dotscope/src/deobfuscation/techniques/confuserex/marker.rs index 725b3293..c0e1fca6 100644 --- a/dotscope/src/deobfuscation/techniques/confuserex/marker.rs +++ b/dotscope/src/deobfuscation/techniques/confuserex/marker.rs @@ -198,19 +198,19 @@ fn extract_version_from_blob(assembly: &CilObject, row: &CustomAttributeRaw) -> } // Check prolog 0x0001 - if data[0] != 0x01 || data[1] != 0x00 { + if *data.first()? != 0x01 || *data.get(1)? != 0x00 { return None; } // Read packed string length - let (str_len, offset) = read_packed_len(&data[2..])?; - let start = 2 + offset; - let end = start + str_len; + let (str_len, offset) = read_packed_len(data.get(2..)?)?; + let start = 2usize.checked_add(offset)?; + let end = start.checked_add(str_len)?; if end > data.len() { return None; } - let s = std::str::from_utf8(&data[start..end]).ok()?; + let s = std::str::from_utf8(data.get(start..end)?).ok()?; if s.is_empty() { return None; } diff --git a/dotscope/src/deobfuscation/techniques/confuserex/metadata.rs b/dotscope/src/deobfuscation/techniques/confuserex/metadata.rs index e996e3aa..3b0c9567 100644 --- a/dotscope/src/deobfuscation/techniques/confuserex/metadata.rs +++ b/dotscope/src/deobfuscation/techniques/confuserex/metadata.rs @@ -124,12 +124,16 @@ impl Technique for ConfuserExMetadata { if let Some(module_table) = tables.table::() { for row in module_table { if row.name == CONFUSEREX_MARKER || row.name as usize >= strings_size { - findings.invalid_entries += 1; - findings.patches.push(CxMetadataPatch { - offset: row.offset + 2, // name field offset within Module row - size: if strings_size > 0xFFFF { 4 } else { 2 }, - corrected: 0, - }); + findings.invalid_entries = findings.invalid_entries.saturating_add(1); + // Skip rows where the file offset would overflow when adding the + // name field offset within the Module row. + if let Some(name_offset) = row.offset.checked_add(2) { + findings.patches.push(CxMetadataPatch { + offset: name_offset, + size: if strings_size > 0xFFFF { 4 } else { 2 }, + corrected: 0, + }); + } } } } @@ -138,7 +142,7 @@ impl Technique for ConfuserExMetadata { if let Some(assembly_table) = tables.table::() { for row in assembly_table { if row.name == CONFUSEREX_MARKER || row.name as usize >= strings_size { - findings.invalid_entries += 1; + findings.invalid_entries = findings.invalid_entries.saturating_add(1); } } } @@ -147,7 +151,7 @@ impl Technique for ConfuserExMetadata { if let Some(declsec_table) = tables.table::() { for row in declsec_table { if row.action == CONFUSEREX_MARKER_16 || row.action > 0x000E { - findings.invalid_entries += 1; + findings.invalid_entries = findings.invalid_entries.saturating_add(1); } } } @@ -156,7 +160,7 @@ impl Technique for ConfuserExMetadata { if let Some(typeref_table) = tables.table::() { for row in typeref_table { if row.resolution_scope.tag == TableId::Module && row.resolution_scope.row == 0 { - findings.invalid_entries += 1; + findings.invalid_entries = findings.invalid_entries.saturating_add(1); } } } @@ -228,11 +232,11 @@ impl Technique for ConfuserExMetadata { for patch in &findings.patches { match patch.size { 2 => match assembly.write_le::(patch.offset, patch.corrected as u16) { - Ok(_) => patched += 1, + Ok(_) => patched = patched.saturating_add(1), Err(e) => return Some(Err(e)), }, 4 => match assembly.write_le::(patch.offset, patch.corrected) { - Ok(_) => patched += 1, + Ok(_) => patched = patched.saturating_add(1), Err(e) => return Some(Err(e)), }, _ => {} @@ -284,53 +288,88 @@ fn check_duplicate_streams(assembly: &CilObject) -> bool { // Parse stream count from metadata root header. // Metadata root layout: signature(4) + major(2) + minor(2) + reserved(4) + version_len(4) let header_base = metadata_offset; - if header_base + 16 > data.len() { + let Some(header_end) = header_base.checked_add(16) else { + return false; + }; + if header_end > data.len() { return false; } - let version_len = u32::from_le_bytes( - data[header_base + 12..header_base + 16] - .try_into() - .unwrap_or_default(), - ) as usize; + let Some(version_len_start) = header_base.checked_add(12) else { + return false; + }; + let Some(version_bytes) = data.get(version_len_start..header_end) else { + return false; + }; + let version_len = u32::from_le_bytes(version_bytes.try_into().unwrap_or_default()) as usize; // Align version length to 4 bytes - let aligned_len = (version_len + 3) & !3; - let flags_offset = header_base + 16 + aligned_len; + let aligned_len = match version_len.checked_add(3) { + Some(v) => v & !3, + None => return false, + }; + let Some(flags_offset) = header_end.checked_add(aligned_len) else { + return false; + }; - if flags_offset + 4 > data.len() { + let Some(streams_end) = flags_offset.checked_add(4) else { + return false; + }; + if streams_end > data.len() { return false; } // flags(2) + streams(2) - let stream_count = u16::from_le_bytes( - data[flags_offset + 2..flags_offset + 4] - .try_into() - .unwrap_or_default(), - ) as usize; + let Some(stream_count_start) = flags_offset.checked_add(2) else { + return false; + }; + let Some(stream_count_bytes) = data.get(stream_count_start..streams_end) else { + return false; + }; + let stream_count = + u16::from_le_bytes(stream_count_bytes.try_into().unwrap_or_default()) as usize; // Walk stream headers and check for duplicate names. let mut seen_names = std::collections::HashSet::new(); - let mut pos = flags_offset + 4; + let mut pos = streams_end; for _ in 0..stream_count { - if pos + 8 > data.len() { + let Some(after_header) = pos.checked_add(8) else { + break; + }; + if after_header > data.len() { break; } // offset(4) + size(4) + name (null-terminated, 4-byte aligned) - pos += 8; + pos = after_header; // Read stream name let name_start = pos; - while pos < data.len() && data[pos] != 0 { - pos += 1; + while let Some(&byte) = data.get(pos) { + if byte == 0 { + break; + } + let Some(next) = pos.checked_add(1) else { + return false; + }; + pos = next; } if pos >= data.len() { break; } - let name = std::str::from_utf8(&data[name_start..pos]).unwrap_or(""); - pos += 1; // skip null terminator - // Align to 4 bytes - pos = (pos + 3) & !3; + let name_bytes = match data.get(name_start..pos) { + Some(b) => b, + None => break, + }; + let name = std::str::from_utf8(name_bytes).unwrap_or(""); + pos = match pos.checked_add(1) { + Some(v) => v, + None => break, + }; + // Align to 4 bytes + pos = match pos.checked_add(3) { + Some(v) => v & !3, + None => break, + }; if !name.is_empty() && !seen_names.insert(name.to_string()) { return true; diff --git a/dotscope/src/deobfuscation/techniques/confuserex/natives.rs b/dotscope/src/deobfuscation/techniques/confuserex/natives.rs index 7049c2ad..b6e93438 100644 --- a/dotscope/src/deobfuscation/techniques/confuserex/natives.rs +++ b/dotscope/src/deobfuscation/techniques/confuserex/natives.rs @@ -191,7 +191,7 @@ impl Technique for ConfuserExNativeHelpers { log::warn!( "Converted {}/{} native x86 methods to CIL (failures: {})", stats.converted, - stats.converted + stats.failed, + stats.converted.saturating_add(stats.failed), stats.errors.join(", ") ); } diff --git a/dotscope/src/deobfuscation/techniques/confuserex/proxy.rs b/dotscope/src/deobfuscation/techniques/confuserex/proxy.rs index 32629afc..15224478 100644 --- a/dotscope/src/deobfuscation/techniques/confuserex/proxy.rs +++ b/dotscope/src/deobfuscation/techniques/confuserex/proxy.rs @@ -173,13 +173,20 @@ fn is_mild_proxy(instructions: &[&Instruction]) -> bool { } // Last instruction must be ret. - let last = instructions.last().unwrap(); + let Some(last) = instructions.last() else { + return false; + }; if last.mnemonic != "ret" { return false; } // Second-to-last must be a non-virtual call with a token operand. - let call_instr = instructions[instructions.len() - 2]; + let Some(call_idx) = instructions.len().checked_sub(2) else { + return false; + }; + let Some(call_instr) = instructions.get(call_idx) else { + return false; + }; if call_instr.mnemonic != "call" || call_instr.flow_type != FlowType::Call { return false; } @@ -188,7 +195,10 @@ fn is_mild_proxy(instructions: &[&Instruction]) -> bool { } // All preceding instructions must be ldarg variants. - for instr in &instructions[..instructions.len() - 2] { + let Some(prefix) = instructions.get(..call_idx) else { + return false; + }; + for instr in prefix { if !instr.mnemonic.starts_with("ldarg") { return false; } @@ -206,18 +216,28 @@ fn is_strong_proxy(instructions: &[&Instruction], _assembly: &CilObject) -> bool } // First instruction must be ldsfld. - if instructions[0].mnemonic != "ldsfld" { + let Some(first) = instructions.first() else { + return false; + }; + if first.mnemonic != "ldsfld" { return false; } // Last instruction must be ret. - let last = instructions.last().unwrap(); + let Some(last) = instructions.last() else { + return false; + }; if last.mnemonic != "ret" { return false; } // Second-to-last must be callvirt (delegate dispatch). - let call_instr = instructions[instructions.len() - 2]; + let Some(call_idx) = instructions.len().checked_sub(2) else { + return false; + }; + let Some(call_instr) = instructions.get(call_idx) else { + return false; + }; if call_instr.mnemonic != "callvirt" || call_instr.flow_type != FlowType::Call { return false; } @@ -229,7 +249,10 @@ fn is_strong_proxy(instructions: &[&Instruction], _assembly: &CilObject) -> bool // (ldsfld + ldarg* + callvirt + ret) is specific enough to identify proxy stubs. // All instructions between ldsfld and callvirt must be ldarg variants. - for instr in &instructions[1..instructions.len() - 2] { + let Some(middle) = instructions.get(1..call_idx) else { + return false; + }; + for instr in middle { if !instr.mnemonic.starts_with("ldarg") { return false; } diff --git a/dotscope/src/deobfuscation/techniques/confuserex/resources.rs b/dotscope/src/deobfuscation/techniques/confuserex/resources.rs index 1d48f37f..b2779972 100644 --- a/dotscope/src/deobfuscation/techniques/confuserex/resources.rs +++ b/dotscope/src/deobfuscation/techniques/confuserex/resources.rs @@ -248,7 +248,7 @@ impl Technique for ConfuserExResources { } // Insert each extracted resource. - let mut inserted_count = 0; + let mut inserted_count: usize = 0; for resource in &extracted_resources { let builder = ManifestResourceBuilder::new() .name(&resource.name) @@ -257,7 +257,7 @@ impl Technique for ConfuserExResources { match builder.build(&mut cil_assembly) { Ok(_) => { - inserted_count += 1; + inserted_count = inserted_count.saturating_add(1); log::info!( "Inserted resource: {} ({} bytes)", resource.name, @@ -342,7 +342,7 @@ fn try_emulate_resource_handler( let mut resources = Vec::new(); for captured_asm in process.capture().assemblies().iter() { let data = &captured_asm.data; - if data.len() < 2 || data[0] != b'M' || data[1] != b'Z' { + if data.first() != Some(&b'M') || data.get(1) != Some(&b'Z') { continue; } diff --git a/dotscope/src/deobfuscation/techniques/confuserex/statemachine.rs b/dotscope/src/deobfuscation/techniques/confuserex/statemachine.rs index 213952e1..171841f2 100644 --- a/dotscope/src/deobfuscation/techniques/confuserex/statemachine.rs +++ b/dotscope/src/deobfuscation/techniques/confuserex/statemachine.rs @@ -153,7 +153,9 @@ impl StateMachineProvider for ConfuserExStateMachine { SsaOp::Call { method, args, .. } if method.token() == init_method_token && args.len() >= 2 => { - let seed_var = args[1]; + let Some(&seed_var) = args.get(1) else { + continue; + }; if let Some(ConstValue::I32(seed)) = self.trace_to_constant(seed_var, ssa, ctx, method_token) { @@ -165,8 +167,11 @@ impl StateMachineProvider for ConfuserExStateMachine { SsaOp::NewObj { ctor, args, .. } if ctor.token() == init_method_token && args.len() == 1 => { + let Some(&first_arg) = args.first() else { + continue; + }; if let Some(ConstValue::I32(seed)) = - self.trace_to_constant(args[0], ssa, ctx, method_token) + self.trace_to_constant(first_arg, ssa, ctx, method_token) { #[allow(clippy::cast_sign_loss)] seeds.push((block_idx, instr_idx, seed as u32)); @@ -194,16 +199,16 @@ impl StateMachineProvider for ConfuserExStateMachine { { if method.token() == update_method_token { // CFGCtx.Next takes 3 args: &this, flag (byte), increment (uint) - if args.len() >= 3 { - if let Some(dest) = dest { - updates.push(StateUpdateCall { - block_idx, - instr_idx, - dest: *dest, - flag_var: args[1], - increment_var: args[2], - }); - } + if let (Some(&flag_var), Some(&increment_var), Some(dest)) = + (args.get(1), args.get(2), dest.as_ref()) + { + updates.push(StateUpdateCall { + block_idx, + instr_idx, + dest: *dest, + flag_var, + increment_var, + }); } } } @@ -256,7 +261,10 @@ impl StateMachineProvider for ConfuserExStateMachine { // Check if argument comes from XOR, possibly through a // conversion (conv.i4, conv.u4) that the SSA builder may insert. - let arg_def = ssa.get_definition(args[0]); + let Some(&first_arg) = args.first() else { + continue; + }; + let arg_def = ssa.get_definition(first_arg); let xor_def = match arg_def { Some(SsaOp::Xor { .. }) => arg_def, Some(SsaOp::Conv { operand, .. }) => ssa.get_definition(*operand), @@ -299,7 +307,9 @@ impl StateMachineProvider for ConfuserExStateMachine { cfg_info: &CfgInfo<'_>, seed_block: Option, ) -> Vec { - let feeding_update = &all_updates[call_site.feeding_update_idx]; + let Some(feeding_update) = all_updates.get(call_site.feeding_update_idx) else { + return Vec::new(); + }; let target_block = feeding_update.block_idx; if target_block >= cfg_info.node_count { @@ -315,7 +325,12 @@ impl StateMachineProvider for ConfuserExStateMachine { .push(idx); } for indices in updates_by_block.values_mut() { - indices.sort_by_key(|&idx| all_updates[idx].instr_idx); + indices.sort_by_key(|&idx| { + all_updates + .get(idx) + .map(|u| u.instr_idx) + .unwrap_or(usize::MAX) + }); } // Find a path from entry to the feeding update's block. @@ -356,8 +371,10 @@ impl StateMachineProvider for ConfuserExStateMachine { if block_idx == target_block { // Same block: only updates BEFORE the feeding update for &idx in update_indices { - if all_updates[idx].instr_idx < feeding_update.instr_idx { - relevant_updates.push(idx); + if let Some(update) = all_updates.get(idx) { + if update.instr_idx < feeding_update.instr_idx { + relevant_updates.push(idx); + } } } } else { @@ -367,7 +384,9 @@ impl StateMachineProvider for ConfuserExStateMachine { // Sort by path position (entry first), then instruction index relevant_updates.sort_by_key(|&idx| { - let update = &all_updates[idx]; + let Some(update) = all_updates.get(idx) else { + return (usize::MAX, usize::MAX); + }; let pos = block_position .get(&update.block_idx) .copied() @@ -405,10 +424,10 @@ fn find_path_to_block(cfg_info: &CfgInfo<'_>, target: usize) -> Vec { found = true; break; } - if block >= cfg_info.predecessors.len() { + let Some(preds) = cfg_info.predecessors.get(block) else { continue; - } - for &pred in &cfg_info.predecessors[block] { + }; + for &pred in preds { if let Entry::Vacant(e) = parent.entry(pred) { e.insert(block); queue.push_back(pred); @@ -905,11 +924,9 @@ pub fn find_call_sites(assembly: &CilObject, decryptor_tokens: &[Token]) -> Vec< } // Decryptors take a single int32 argument - if args.is_empty() { + let Some(&arg_var) = args.first() else { continue; - } - - let arg_var = args[0]; + }; // Use backward taint analysis to trace the argument's data flow match analyze_argument_dataflow(&ssa, arg_var) { diff --git a/dotscope/src/deobfuscation/techniques/confuserex/tamper.rs b/dotscope/src/deobfuscation/techniques/confuserex/tamper.rs index 01bf4b45..3204fa5a 100644 --- a/dotscope/src/deobfuscation/techniques/confuserex/tamper.rs +++ b/dotscope/src/deobfuscation/techniques/confuserex/tamper.rs @@ -195,7 +195,7 @@ impl Technique for ConfuserExAntiTamper { if let Some(text) = text_section { let text_rva_start = text.virtual_address as usize; - let text_rva_end = text_rva_start + text.virtual_size as usize; + let text_rva_end = text_rva_start.saturating_add(text.virtual_size as usize); for row in method_table { if row.rva == 0 { @@ -213,12 +213,12 @@ impl Technique for ConfuserExAntiTamper { // Method body outside .text section suggests encryption if method_rva < text_rva_start || method_rva >= text_rva_end { - encrypted_count += 1; + encrypted_count = encrypted_count.saturating_add(1); // Identify which section this RVA falls into for section in sections { let sec_start = section.virtual_address as usize; - let sec_end = sec_start + section.virtual_size as usize; + let sec_end = sec_start.saturating_add(section.virtual_size as usize); if method_rva >= sec_start && method_rva < sec_end { encrypted_section_names.insert(section.name.clone()); break; @@ -560,16 +560,16 @@ fn extract_decrypted_bodies( methods: &[Token], ) -> (Vec<(Token, Vec)>, usize) { let mut bodies = Vec::new(); - let mut failed_count = 0; + let mut failed_count: usize = 0; for &token in methods { let Some(rva) = helpers::get_method_rva(assembly, token) else { - failed_count += 1; + failed_count = failed_count.saturating_add(1); continue; }; if rva == 0 || rva as usize >= virtual_image.len() { - failed_count += 1; + failed_count = failed_count.saturating_add(1); continue; } @@ -578,7 +578,7 @@ fn extract_decrypted_bodies( bodies.push((token, body_bytes)); } None => { - failed_count += 1; + failed_count = failed_count.saturating_add(1); } } } diff --git a/dotscope/src/deobfuscation/techniques/detection.rs b/dotscope/src/deobfuscation/techniques/detection.rs index 918791c6..9da3fabd 100644 --- a/dotscope/src/deobfuscation/techniques/detection.rs +++ b/dotscope/src/deobfuscation/techniques/detection.rs @@ -175,7 +175,7 @@ impl Detections { /// Inserts a detection result for a technique. pub fn insert(&mut self, id: impl Into, detection: Detection) { self.entries.insert(id.into(), detection); - self.generation += 1; + self.generation = self.generation.saturating_add(1); } /// Gets the detection result for a technique. @@ -241,7 +241,7 @@ impl Detections { if !detection.detected { return; } - self.generation += 1; + self.generation = self.generation.saturating_add(1); let id = id.into(); match self.entries.get_mut(&id) { Some(existing) if existing.detected => { @@ -265,7 +265,7 @@ impl Detections { /// Uses [`merge`](Self::merge) semantics for each entry: never downgrades /// an existing positive detection. pub fn merge_all(&mut self, other: Detections) { - self.generation += 1; + self.generation = self.generation.saturating_add(1); for (id, detection) in other.entries { if detection.detected { self.merge(id, detection); diff --git a/dotscope/src/deobfuscation/techniques/generic/constants.rs b/dotscope/src/deobfuscation/techniques/generic/constants.rs index cc5151c5..b49d2ef4 100644 --- a/dotscope/src/deobfuscation/techniques/generic/constants.rs +++ b/dotscope/src/deobfuscation/techniques/generic/constants.rs @@ -71,15 +71,16 @@ impl GenericConstants { } let param_count = method.signature.params.len(); + let first_param_base = method.signature.params.first().map(|p| &p.base); // int32(int32) — integer constant accessor let is_int_accessor = param_count == 1 - && matches!(method.signature.params[0].base, TypeSignature::I4) + && matches!(first_param_base, Some(TypeSignature::I4)) && matches!(method.signature.return_type.base, TypeSignature::I4); // object(int32) — generic constant accessor let is_obj_accessor = param_count == 1 - && matches!(method.signature.params[0].base, TypeSignature::I4) + && matches!(first_param_base, Some(TypeSignature::I4)) && matches!(method.signature.return_type.base, TypeSignature::Object); if !is_int_accessor && !is_obj_accessor { @@ -180,7 +181,7 @@ impl Technique for GenericConstants { // Direct match if let Some(c) = counts.get_mut(&token) { - *c += 1; + *c = c.saturating_add(1); continue; } @@ -191,7 +192,7 @@ impl Technique for GenericConstants { .or_insert_with(|| assembly.resolver().resolve_memberref_method(token)); if let Some(resolved_token) = resolved { if let Some(c) = counts.get_mut(resolved_token) { - *c += 1; + *c = c.saturating_add(1); } } } diff --git a/dotscope/src/deobfuscation/techniques/generic/delegates.rs b/dotscope/src/deobfuscation/techniques/generic/delegates.rs index 48f505f9..92b549d7 100644 --- a/dotscope/src/deobfuscation/techniques/generic/delegates.rs +++ b/dotscope/src/deobfuscation/techniques/generic/delegates.rs @@ -120,8 +120,10 @@ fn traces_to_argument(ssa: &SsaFunction, var: SsaVarId) -> bool { if let Some((_block_idx, phi)) = ssa.find_phi_defining(current) { let operands = phi.operands(); if operands.len() == 1 { - current = operands[0].value(); - continue; + if let Some(first) = operands.first() { + current = first.value(); + continue; + } } return false; } @@ -231,7 +233,7 @@ impl Technique for GenericDelegateProxy { if !ty.is_delegate() { continue; } - delegate_count += 1; + delegate_count = delegate_count.saturating_add(1); // Find the singleton static field. Delegate types normally have no // static fields, so any static field is the singleton instance. @@ -243,7 +245,7 @@ impl Technique for GenericDelegateProxy { else { continue; }; - has_static_field += 1; + has_static_field = has_static_field.saturating_add(1); // Find the wrapper method: static (not .cctor/.ctor), whose SSA // shows a CallVirt(Invoke) pattern with the delegate from a parameter. @@ -252,22 +254,20 @@ impl Technique for GenericDelegateProxy { if !method.is_static() || method.is_cctor() || method.is_ctor() { return None; } - has_static_method += 1; + has_static_method = has_static_method.saturating_add(1); // Look up this method's SSA function from the analysis context - let ssa_ref = ctx.ssa_functions.get(&method.token); - if ssa_ref.is_none() { + let Some(ssa) = ctx.ssa_functions.get(&method.token) else { debug!( "Delegate detect: type {}.{} method 0x{:08X} ({}) has no SSA", ty.namespace, ty.name, method.token.value(), method.name ); return None; - } - has_ssa += 1; + }; + has_ssa = has_ssa.saturating_add(1); - let ssa = ssa_ref.unwrap(); if is_delegate_wrapper_ssa(ssa.value(), assembly) { - is_wrapper += 1; + is_wrapper = is_wrapper.saturating_add(1); Some(method.token) } else { debug!( @@ -315,14 +315,14 @@ impl Technique for GenericDelegateProxy { let calli_count = count_resolve_method_calli_sites(ssa, assembly); if calli_count > 0 { reflection_affected.insert(method_token); - reflection_site_count += calli_count; + reflection_site_count = reflection_site_count.saturating_add(calli_count); } // Reflection patterns: Call/CallVirt to reflection APIs (P2, P3, P5, P6) let api_count = count_reflection_api_calls(ssa, assembly); if api_count > 0 { reflection_affected.insert(method_token); - reflection_site_count += api_count; + reflection_site_count = reflection_site_count.saturating_add(api_count); } } @@ -425,7 +425,7 @@ impl Technique for GenericDelegateProxy { /// `FieldInfo.GetValue`, and `FieldInfo.SetValue` where the target comes from /// a traceable reflection chain. fn count_reflection_api_calls(ssa: &SsaFunction, assembly: &CilObject) -> usize { - let mut count = 0; + let mut count: usize = 0; for block in ssa.blocks() { for instr in block.instructions() { let (method_token, arg_count) = match instr.op() { @@ -445,7 +445,7 @@ fn count_reflection_api_calls(ssa: &SsaFunction, assembly: &CilObject) -> usize if is_method_named(assembly, method_token, "MethodBase") || is_method_named(assembly, method_token, "MethodInfo") { - count += 1; + count = count.saturating_add(1); continue; } } @@ -454,7 +454,7 @@ fn count_reflection_api_calls(ssa: &SsaFunction, assembly: &CilObject) -> usize if name.contains("CreateInstance") && is_method_named(assembly, method_token, "Activator") { - count += 1; + count = count.saturating_add(1); continue; } @@ -462,7 +462,7 @@ fn count_reflection_api_calls(ssa: &SsaFunction, assembly: &CilObject) -> usize if (name == "GetValue" || name == "SetValue") && is_method_named(assembly, method_token, "FieldInfo") { - count += 1; + count = count.saturating_add(1); } } } diff --git a/dotscope/src/deobfuscation/techniques/generic/flattening.rs b/dotscope/src/deobfuscation/techniques/generic/flattening.rs index 1a93e898..03bfd0ff 100644 --- a/dotscope/src/deobfuscation/techniques/generic/flattening.rs +++ b/dotscope/src/deobfuscation/techniques/generic/flattening.rs @@ -110,7 +110,7 @@ impl Technique for GenericFlattening { .collect(); if !method_dispatchers.is_empty() { - total_dispatchers += method_dispatchers.len(); + total_dispatchers = total_dispatchers.saturating_add(method_dispatchers.len()); dispatchers_by_method.insert(method_token, method_dispatchers); } } diff --git a/dotscope/src/deobfuscation/techniques/generic/handlers.rs b/dotscope/src/deobfuscation/techniques/generic/handlers.rs index a0b0133d..12bb3034 100644 --- a/dotscope/src/deobfuscation/techniques/generic/handlers.rs +++ b/dotscope/src/deobfuscation/techniques/generic/handlers.rs @@ -84,7 +84,9 @@ impl Technique for GenericHandlers { continue; } - let body_data = &file.data()[offset..offset + available]; + let Some(body_data) = file.data().get(offset..offset.saturating_add(available)) else { + continue; + }; if let Ok((_, filtered)) = MethodBody::from_lenient(body_data) { if filtered > 0 { diff --git a/dotscope/src/deobfuscation/techniques/generic/metadata.rs b/dotscope/src/deobfuscation/techniques/generic/metadata.rs index 383627e8..9b6bf5f4 100644 --- a/dotscope/src/deobfuscation/techniques/generic/metadata.rs +++ b/dotscope/src/deobfuscation/techniques/generic/metadata.rs @@ -114,9 +114,12 @@ impl Technique for GenericMetadata { let name_index = row.name as usize; let is_sentinel = KNOWN_SENTINEL_VALUES.contains(&row.name); if name_index >= strings_size || is_sentinel { - findings.invalid_module_rows += 1; + findings.invalid_module_rows = findings.invalid_module_rows.saturating_add(1); + let Some(field_offset) = row.offset.checked_add(2) else { + continue; + }; findings.patches.push(MetadataPatch { - offset: row.offset + 2, // name field offset within Module row + offset: field_offset, // name field offset within Module row size: if strings_size > 0xFFFF { 4 } else { 2 }, original: row.name, corrected: 0, @@ -132,7 +135,8 @@ impl Technique for GenericMetadata { for row in assembly_table { let is_sentinel = KNOWN_SENTINEL_VALUES.contains(&row.name); if row.name as usize >= strings_size || is_sentinel { - findings.invalid_assembly_rows += 1; + findings.invalid_assembly_rows = + findings.invalid_assembly_rows.saturating_add(1); } } } @@ -144,7 +148,8 @@ impl Technique for GenericMetadata { for row in declsec_table { let is_sentinel = KNOWN_SENTINEL_VALUES_16.contains(&row.action); if row.action > 0x000E || is_sentinel { - findings.invalid_declsecurity_rows += 1; + findings.invalid_declsecurity_rows = + findings.invalid_declsecurity_rows.saturating_add(1); } } } @@ -155,14 +160,16 @@ impl Technique for GenericMetadata { // Resolution scope with tag Module but row 0 is suspicious // (valid Module is row 1) if row.resolution_scope.tag == TableId::Module && row.resolution_scope.row == 0 { - findings.invalid_typeref_scopes += 1; + findings.invalid_typeref_scopes = + findings.invalid_typeref_scopes.saturating_add(1); } } } - let total_invalid = findings.invalid_module_rows - + findings.invalid_assembly_rows - + findings.invalid_declsecurity_rows; + let total_invalid = findings + .invalid_module_rows + .saturating_add(findings.invalid_assembly_rows) + .saturating_add(findings.invalid_declsecurity_rows); if total_invalid > 0 { evidence.push(Evidence::MetadataPattern(format!( @@ -202,13 +209,13 @@ impl Technique for GenericMetadata { if let Err(e) = assembly.write_le::(patch.offset, patch.corrected as u16) { return Some(Err(e)); } - patched += 1; + patched = patched.saturating_add(1); } 4 => { if let Err(e) = assembly.write_le::(patch.offset, patch.corrected) { return Some(Err(e)); } - patched += 1; + patched = patched.saturating_add(1); } _ => {} } diff --git a/dotscope/src/deobfuscation/techniques/generic/strings.rs b/dotscope/src/deobfuscation/techniques/generic/strings.rs index 54366d3e..5f761eb8 100644 --- a/dotscope/src/deobfuscation/techniques/generic/strings.rs +++ b/dotscope/src/deobfuscation/techniques/generic/strings.rs @@ -87,8 +87,8 @@ impl GenericStrings { let matches_signature = match param_count { // string(int32), string(uint32), string(string) 1 => matches!( - method.signature.params[0].base, - TypeSignature::I4 | TypeSignature::U4 | TypeSignature::String + method.signature.params.first().map(|p| &p.base), + Some(TypeSignature::I4 | TypeSignature::U4 | TypeSignature::String) ), // string(int32, int32) — offset+length based 2 => method @@ -113,9 +113,9 @@ impl GenericStrings { /// unsupported signatures. fn default_warmup_args(params: &[SignatureParameter]) -> Option> { match params.len() { - 1 => match params[0].base { - TypeSignature::I4 | TypeSignature::U4 => Some(vec![EmValue::I32(0)]), - TypeSignature::String => Some(vec![EmValue::Null]), + 1 => match params.first().map(|p| &p.base) { + Some(TypeSignature::I4 | TypeSignature::U4) => Some(vec![EmValue::I32(0)]), + Some(TypeSignature::String) => Some(vec![EmValue::Null]), _ => None, }, 2 => { @@ -221,7 +221,7 @@ impl Technique for GenericStrings { // Direct match if let Some(c) = counts.get_mut(&token) { - *c += 1; + *c = c.saturating_add(1); continue; } @@ -232,7 +232,7 @@ impl Technique for GenericStrings { .or_insert_with(|| assembly.resolver().resolve_memberref_method(token)); if let Some(resolved_token) = resolved { if let Some(c) = counts.get_mut(resolved_token) { - *c += 1; + *c = c.saturating_add(1); } } } diff --git a/dotscope/src/deobfuscation/techniques/jiejienet/arrays.rs b/dotscope/src/deobfuscation/techniques/jiejienet/arrays.rs index 2a17b213..6c4c8dbd 100644 --- a/dotscope/src/deobfuscation/techniques/jiejienet/arrays.rs +++ b/dotscope/src/deobfuscation/techniques/jiejienet/arrays.rs @@ -55,7 +55,7 @@ use crate::{ token::Token, typesystem::wellknown, }, - CilObject, Result, + CilObject, Error, Result, }; /// Findings from RuntimeFieldHandleContainer detection. @@ -137,7 +137,10 @@ impl Technique for JiejieNetArrays { // Accessor: static, single int32 param, returns value type (RuntimeFieldHandle) if method.is_static() && sig.params.len() == 1 - && matches!(sig.params[0].base, TypeSignature::I4) + && sig + .params + .first() + .is_some_and(|p| matches!(p.base, TypeSignature::I4)) && matches!(sig.return_type.base, TypeSignature::ValueType(_)) { accessor_token = Some(method.token); @@ -386,17 +389,20 @@ fn find_my_initialize_array(assembly: &CilObject) -> (Option, Option HashMa let offset_map: HashMap = instructions .iter() .enumerate() - .map(|(i, instr)| (instr.offset as u32 - base_offset, i)) + .map(|(i, instr)| ((instr.offset as u32).wrapping_sub(base_offset), i)) .collect(); let mut stack: Vec = Vec::new(); @@ -614,13 +620,16 @@ fn emulate_delta_chain_cctor(assembly: &CilObject, cctor_token: Token) -> HashMa let mut visited = 0u32; // safety counter while pc < instructions.len() { - visited += 1; + visited = visited.saturating_add(1); if visited > 1000 { break; // safety limit } - let instr = instructions[pc]; - pc += 1; + let Some(instr) = instructions.get(pc) else { + break; + }; + let instr = *instr; + pc = pc.saturating_add(1); match instr.mnemonic { "ldc.i8" => { @@ -669,10 +678,10 @@ fn emulate_delta_chain_cctor(assembly: &CilObject, cctor_token: Token) -> HashMa // Relative offset from end of instruction. // IL offset of this instr = file_offset - base_offset // End of instr = IL offset + instr size - let il_off = instr.offset as u32 - base_offset; - let instr_size = instructions - .get(pc) - .map_or(0, |next| next.offset as u32 - instr.offset as u32); + let il_off = (instr.offset as u32).wrapping_sub(base_offset); + let instr_size = instructions.get(pc).map_or(0, |next| { + (next.offset as u32).wrapping_sub(instr.offset as u32) + }); // For the last case (no next), use standard br sizes let size = if instr_size > 0 { instr_size @@ -681,14 +690,18 @@ fn emulate_delta_chain_cctor(assembly: &CilObject, cctor_token: Token) -> HashMa } else { 2 }; - Some((il_off + size).wrapping_add(*rel as u32)) + Some(il_off.wrapping_add(size).wrapping_add(*rel as u32)) } Operand::Immediate(Immediate::Int8(rel)) => { - let il_off = instr.offset as u32 - base_offset; - let instr_size = instructions - .get(pc) - .map_or(2, |next| next.offset as u32 - instr.offset as u32); - Some((il_off + instr_size).wrapping_add(*rel as i32 as u32)) + let il_off = (instr.offset as u32).wrapping_sub(base_offset); + let instr_size = instructions.get(pc).map_or(2, |next| { + (next.offset as u32).wrapping_sub(instr.offset as u32) + }); + Some( + il_off + .wrapping_add(instr_size) + .wrapping_add(*rel as i32 as u32), + ) } _ => None, }; @@ -717,7 +730,9 @@ fn find_preceding_i32_value( ) -> Option { // Search backward from pos-1, limited distance for j in (0..pos).rev() { - let instr = instructions[j]; + let Some(instr) = instructions.get(j) else { + break; + }; // Try direct ldc.i4* constant if let Some(val) = instr.get_ldc_i4_value() { return Some(val); @@ -731,7 +746,7 @@ fn find_preceding_i32_value( } } // Stop searching after a few instructions - if pos - j > 5 { + if pos.saturating_sub(j) > 5 { break; } } @@ -748,7 +763,9 @@ fn find_preceding_get_handle_index( ) -> Option { // Search backward for `call GetHandle` (the accessor) for j in (0..pos).rev() { - let instr = instructions[j]; + let Some(instr) = instructions.get(j) else { + break; + }; if instr.mnemonic == "call" { if let Operand::Token(t) = &instr.operand { if *t == accessor_token { @@ -758,7 +775,7 @@ fn find_preceding_get_handle_index( } } // Don't search too far back - if pos - j > 10 { + if pos.saturating_sub(j) > 10 { break; } } @@ -779,17 +796,17 @@ fn decrypt_field_rva_data_to_bytes( let tables = assembly .tables() - .ok_or_else(|| crate::Error::Other("No metadata tables available".to_string()))?; + .ok_or_else(|| Error::Other("No metadata tables available".to_string()))?; let fieldrva_table = tables .table::() - .ok_or_else(|| crate::Error::Other("No FieldRVA table found".to_string()))?; + .ok_or_else(|| Error::Other("No FieldRVA table found".to_string()))?; let field_rid = field_token.row(); let rva_entry = fieldrva_table .iter() .find(|row| row.field == field_rid) .ok_or_else(|| { - crate::Error::Other(format!( + Error::Other(format!( "No FieldRVA entry for field 0x{:08X}", field_token.value(), )) @@ -826,25 +843,24 @@ fn calculate_field_data_size(assembly: &CilObject, field_rid: u32) -> Result() - .ok_or_else(|| crate::Error::Other("No Field table".to_string()))?; + .ok_or_else(|| Error::Other("No Field table".to_string()))?; let field_row = field_table .iter() .find(|r| r.rid == field_rid) - .ok_or_else(|| crate::Error::Other(format!("Field {field_rid} not found")))?; + .ok_or_else(|| Error::Other(format!("Field {field_rid} not found")))?; let blobs = assembly .blob() - .ok_or_else(|| crate::Error::Other("No blob heap".to_string()))?; + .ok_or_else(|| Error::Other("No blob heap".to_string()))?; let sig_data = blobs .get(field_row.signature as usize) - .map_err(|_| crate::Error::Other(format!("Cannot read signature for field {field_rid}")))?; - let field_sig = parse_field_signature(sig_data).map_err(|e| { - crate::Error::Other(format!("Cannot parse field {field_rid} signature: {e}")) - })?; + .map_err(|_| Error::Other(format!("Cannot read signature for field {field_rid}")))?; + let field_sig = parse_field_signature(sig_data) + .map_err(|e| Error::Other(format!("Cannot parse field {field_rid} signature: {e}")))?; // Try primitive size first let ptr_size = crate::metadata::typesystem::PointerSize::from_pe(assembly.file().pe().is_64bit); @@ -866,7 +882,7 @@ fn calculate_field_data_size(assembly: &CilObject, field_rid: u32) -> Result = method.instructions().collect(); - let mut ldc_i8_count = 0; + let mut ldc_i8_count: usize = 0; let mut has_conv_i4 = false; let mut has_stsfld = false; let mut has_dup = false; for instr in &instructions { match instr.mnemonic { - "ldc.i8" => ldc_i8_count += 1, + "ldc.i8" => ldc_i8_count = ldc_i8_count.saturating_add(1), "conv.i4" => has_conv_i4 = true, "stsfld" => has_stsfld = true, "dup" => has_dup = true, diff --git a/dotscope/src/deobfuscation/techniques/jiejienet/resources.rs b/dotscope/src/deobfuscation/techniques/jiejienet/resources.rs index e95e8cda..55f93003 100644 --- a/dotscope/src/deobfuscation/techniques/jiejienet/resources.rs +++ b/dotscope/src/deobfuscation/techniques/jiejienet/resources.rs @@ -266,7 +266,7 @@ impl Technique for JiejieNetResources { // Insert decrypted resources into the assembly via CilAssembly let mut cil_assembly = cilobject.into_assembly(); - let mut inserted_count = 0; + let mut inserted_count: usize = 0; for (name, data) in &decrypted_resources { let builder = ManifestResourceBuilder::new() .name(name) @@ -275,7 +275,7 @@ impl Technique for JiejieNetResources { match builder.build(&mut cil_assembly) { Ok(_) => { - inserted_count += 1; + inserted_count = inserted_count.saturating_add(1); log::info!( "JIEJIE.NET resources: inserted resource '{}' ({} bytes)", name, @@ -473,7 +473,10 @@ fn detect_resource_structure(assembly: &CilObject) -> Option // SMF_GetContent: static byte[](string) if sig.params.len() == 1 - && matches!(sig.params[0].base, TypeSignature::String) + && sig + .params + .first() + .is_some_and(|p| matches!(p.base, TypeSignature::String)) && matches!(sig.return_type.base, TypeSignature::SzArray(_)) { get_content_method = Some(method.token); @@ -560,7 +563,10 @@ fn extract_xor_key_from_stream(assembly: &CilObject, stream_type_token: Token) - } // Check first param is byte[] - let first_is_array = matches!(&sig.params[0].base, TypeSignature::SzArray(_)); + let first_is_array = sig + .params + .first() + .is_some_and(|p| matches!(p.base, TypeSignature::SzArray(_))); if !first_is_array { continue; } @@ -579,7 +585,9 @@ fn extract_xor_key_from_stream(assembly: &CilObject, stream_type_token: Token) - } for j in (0..i).rev() { - let prev = &instructions[j]; + let Some(prev) = instructions.get(j) else { + break; + }; if prev.mnemonic.starts_with("ldc.i4") { if let Operand::Immediate(imm) = &prev.operand { let key_value = match imm { @@ -593,7 +601,7 @@ fn extract_xor_key_from_stream(assembly: &CilObject, stream_type_token: Token) - } } // Don't search too far back - if i - j > 3 { + if i.saturating_sub(j) > 3 { break; } } @@ -649,12 +657,10 @@ fn extract_resource_entries_ssa(ssa: &SsaFunction, assembly: &CilObject) -> Vec< }; // The true branch (next block) should contain a Call to a static byte[]() method - let successor_idx = block_idx + 1; - if successor_idx >= blocks.len() { + let successor_idx = block_idx.saturating_add(1); + let Some(successor) = blocks.get(successor_idx) else { continue; - } - - let successor = &blocks[successor_idx]; + }; for instr in successor.instructions() { let method_token = match instr.op() { SsaOp::Call { method, .. } => method.token(), @@ -720,7 +726,7 @@ fn trace_to_string_const(ssa: &SsaFunction, var: SsaVarId, assembly: &CilObject) value: ConstValue::DecryptedString(s), .. } => Some(s.clone()), - SsaOp::Copy { src, .. } => trace_impl(ssa, *src, assembly, depth + 1), + SsaOp::Copy { src, .. } => trace_impl(ssa, *src, assembly, depth.saturating_add(1)), _ => None, } } @@ -810,8 +816,28 @@ fn extract_and_decrypt_resource( ))); } - let gzip_len = u32::from_le_bytes([raw_data[0], raw_data[1], raw_data[2], raw_data[3]]); - let payload = &raw_data[4..]; + let header: [u8; 4] = raw_data + .get(..4) + .ok_or_else(|| { + Error::Deobfuscation(format!( + "Resource data missing 4-byte length header for '{}'", + entry.name + )) + })? + .try_into() + .map_err(|_| { + Error::Deobfuscation(format!( + "Resource data missing 4-byte length header for '{}'", + entry.name + )) + })?; + let gzip_len = u32::from_le_bytes(header); + let payload = raw_data.get(4..).ok_or_else(|| { + Error::Deobfuscation(format!( + "Resource data has no payload after header for '{}'", + entry.name + )) + })?; let content = if gzip_len > 0 { let decompressed = decompress_gzip(payload).map_err(|e| { diff --git a/dotscope/src/deobfuscation/techniques/jiejienet/strings.rs b/dotscope/src/deobfuscation/techniques/jiejienet/strings.rs index 5cfc5c2f..3944e515 100644 --- a/dotscope/src/deobfuscation/techniques/jiejienet/strings.rs +++ b/dotscope/src/deobfuscation/techniques/jiejienet/strings.rs @@ -118,10 +118,18 @@ impl Technique for JiejieNetStrings { let sig = &method.signature; // Check for dcsoft signature: string(byte[], int64) + let p0_is_szarray = sig + .params + .first() + .is_some_and(|p| matches!(p.base, TypeSignature::SzArray(_))); + let p1_is_i8 = sig + .params + .get(1) + .is_some_and(|p| matches!(p.base, TypeSignature::I8)); if matches!(sig.return_type.base, TypeSignature::String) && sig.params.len() == 2 - && matches!(sig.params[0].base, TypeSignature::SzArray(_)) - && matches!(sig.params[1].base, TypeSignature::I8) + && p0_is_szarray + && p1_is_i8 && method.is_static() { dcsoft_token = Some(method.token); @@ -357,11 +365,11 @@ fn is_byte_array_data_container(cil_type: &CilType) -> bool { } // Check nested types for ExplicitLayout (RVA-backed data storage) - let mut explicit_layout_count = 0; + let mut explicit_layout_count: usize = 0; for (_, nested_ref) in cil_type.nested_types.iter() { if let Some(nested) = nested_ref.upgrade() { if nested.flags.layout() == crate::metadata::tables::TypeAttributes::EXPLICIT_LAYOUT { - explicit_layout_count += 1; + explicit_layout_count = explicit_layout_count.saturating_add(1); } } } diff --git a/dotscope/src/deobfuscation/techniques/jiejienet/typeofs.rs b/dotscope/src/deobfuscation/techniques/jiejienet/typeofs.rs index 3fc7a338..fd45ade4 100644 --- a/dotscope/src/deobfuscation/techniques/jiejienet/typeofs.rs +++ b/dotscope/src/deobfuscation/techniques/jiejienet/typeofs.rs @@ -110,7 +110,10 @@ impl Technique for JiejieNetTypeOf { // Accessor: static, single int32 param, returns a class (Type) if method.is_static() && sig.params.len() == 1 - && matches!(sig.params[0].base, TypeSignature::I4) + && sig + .params + .first() + .is_some_and(|p| matches!(p.base, TypeSignature::I4)) && matches!(sig.return_type.base, TypeSignature::Class(_)) { accessor_token = Some(method.token); diff --git a/dotscope/src/deobfuscation/techniques/netreactor/helpers.rs b/dotscope/src/deobfuscation/techniques/netreactor/helpers.rs index faa4b5a8..bd52b1e9 100644 --- a/dotscope/src/deobfuscation/techniques/netreactor/helpers.rs +++ b/dotscope/src/deobfuscation/techniques/netreactor/helpers.rs @@ -65,7 +65,7 @@ pub fn scan_stub_methods(assembly: &CilObject) -> StubScanResult { continue; } - total_il_methods += 1; + total_il_methods = total_il_methods.saturating_add(1); // Skip the entry point — it's never encrypted if entry_point_token.is_some_and(|ep| ep == method.token) { @@ -79,12 +79,16 @@ pub fn scan_stub_methods(assembly: &CilObject) -> StubScanResult { } // Check the exact 4-instruction stub pattern - if instrs[0].mnemonic != "nop" || instrs[1].mnemonic != "nop" || instrs[3].mnemonic != "ret" - { + let (Some(i0), Some(i1), Some(i2), Some(i3)) = + (instrs.first(), instrs.get(1), instrs.get(2), instrs.get(3)) + else { + continue; + }; + if i0.mnemonic != "nop" || i1.mnemonic != "nop" || i3.mnemonic != "ret" { continue; } - let kind = match instrs[2].mnemonic { + let kind = match i2.mnemonic { "nop" => StubKind::Void, "ldc.i4.0" => StubKind::Value, "ldnull" => StubKind::Reference, @@ -325,7 +329,12 @@ fn is_nr_private_impl_name(name: &str) -> bool { if !name.starts_with(PREFIX) || !name.ends_with('}') { return false; } - let body = &name[PREFIX.len()..name.len() - 1]; + let Some(end) = name.len().checked_sub(1) else { + return false; + }; + let Some(body) = name.get(PREFIX.len()..end) else { + return false; + }; // Canonical GUID layout: 8-4-4-4-12 hex chars separated by '-' let segments: Vec<&str> = body.split('-').collect(); @@ -356,7 +365,12 @@ fn is_nr_guid_module_name(name: &str) -> bool { if !name.starts_with(PREFIX) || !name.ends_with('}') { return false; } - let body = &name[PREFIX.len()..name.len() - 1]; + let Some(end) = name.len().checked_sub(1) else { + return false; + }; + let Some(body) = name.get(PREFIX.len()..end) else { + return false; + }; let segments: Vec<&str> = body.split('-').collect(); if segments.len() != 5 { @@ -548,8 +562,11 @@ pub fn find_nr_token_resolver(assembly: &CilObject) -> Option if method.signature.params.len() != 1 { continue; } + let Some(first_param) = method.signature.params.first() else { + continue; + }; if !matches!( - method.signature.params[0].base, + first_param.base, crate::metadata::signatures::TypeSignature::I4 ) { continue; @@ -602,15 +619,19 @@ fn classify_token_accessor_body(assembly: &CilObject, method_token: Token) -> Op if instrs.len() != 4 { return None; } - if instrs[0].mnemonic != "ldsflda" - || instrs[1].mnemonic != "ldarg.0" - || instrs[2].mnemonic != "call" - || instrs[3].mnemonic != "ret" + let i0 = instrs.first()?; + let i1 = instrs.get(1)?; + let i2 = instrs.get(2)?; + let i3 = instrs.get(3)?; + if i0.mnemonic != "ldsflda" + || i1.mnemonic != "ldarg.0" + || i2.mnemonic != "call" + || i3.mnemonic != "ret" { return None; } - let call_token = instrs[2].get_token_operand()?; + let call_token = i2.get_token_operand()?; let name = crate::deobfuscation::utils::resolve_qualified_method_name(assembly, call_token)?; if !name.contains("ModuleHandle") { return None; diff --git a/dotscope/src/deobfuscation/techniques/netreactor/necrobit.rs b/dotscope/src/deobfuscation/techniques/netreactor/necrobit.rs index ed7725a2..ecacce54 100644 --- a/dotscope/src/deobfuscation/techniques/netreactor/necrobit.rs +++ b/dotscope/src/deobfuscation/techniques/netreactor/necrobit.rs @@ -537,35 +537,67 @@ fn is_necrobit_data_array(data: &[u8]) -> bool { return false; } - let read_u32 = |off: usize| -> u32 { - u32::from_le_bytes([data[off], data[off + 1], data[off + 2], data[off + 3]]) + let read_u32 = |off: usize| -> Option { + let slice = data.get(off..off.checked_add(4)?)?; + let arr: [u8; 4] = slice.try_into().ok()?; + Some(u32::from_le_bytes(arr)) }; - let first_token = read_u32(0); + let Some(first_token) = read_u32(0) else { + return false; + }; if first_token >> 24 != 0x06 { return false; } - let group_count = read_u32(16) as usize; + let Some(group_count) = read_u32(16) else { + return false; + }; + let group_count = group_count as usize; if group_count > 0 && group_count <= 500 { // Variant A: group entries + method_count + per-method IL records - let header_end = 24 + group_count * 8; - if header_end + 4 > data.len() { + let Some(header_end) = group_count + .checked_mul(8) + .and_then(|n| 24_usize.checked_add(n)) + else { + return false; + }; + let Some(header_end_plus_4) = header_end.checked_add(4) else { + return false; + }; + if header_end_plus_4 > data.len() { return false; } - let method_count = read_u32(header_end) as usize; + let Some(method_count) = read_u32(header_end) else { + return false; + }; + let method_count = method_count as usize; if method_count == 0 || method_count > 5000 { return false; } - header_end + 4 + method_count * 12 <= data.len() + method_count * 8 + let Some(lhs) = method_count + .checked_mul(12) + .and_then(|n| header_end_plus_4.checked_add(n)) + else { + return false; + }; + let Some(rhs) = method_count + .checked_mul(8) + .and_then(|n| data.len().checked_add(n)) + else { + return false; + }; + lhs <= rhs } else if group_count == 0 { // Variant B: complete method bodies start at offset 24. // Validate by checking if the first entry has a valid method body header. if data.len() < 36 { return false; } - let body_header_byte = data[32]; // offset 24 + 8 (RVA + v2) + let Some(&body_header_byte) = data.get(32) else { + return false; + }; // offset 24 + 8 (RVA + v2) let is_fat = body_header_byte & 0x03 == 0x03; let is_tiny = body_header_byte & 0x03 == 0x02; is_fat || is_tiny @@ -601,7 +633,13 @@ fn find_variant_a_blob(process: &EmulationProcess) -> Option> { if !is_necrobit_data_array(&bytes) { continue; } - let group_count = u32::from_le_bytes([bytes[16], bytes[17], bytes[18], bytes[19]]) as usize; + let Some(gc_slice) = bytes.get(16..20) else { + continue; + }; + let Ok(gc_arr) = <[u8; 4]>::try_from(gc_slice) else { + continue; + }; + let group_count = u32::from_le_bytes(gc_arr) as usize; if group_count == 0 { continue; // variant B — wrong path } @@ -621,8 +659,15 @@ fn find_variant_a_blob(process: &EmulationProcess) -> Option> { /// [`MethodBody::from_raw`] (skips bounds-checking against the stub's /// 4-byte code_size) and [`MethodBody::write_to`] for serialization. fn parse_variant_a_blob(data: &[u8], assembly: &CilObject) -> Result)>> { - let read_u32 = |off: usize| -> u32 { - u32::from_le_bytes([data[off], data[off + 1], data[off + 2], data[off + 3]]) + let read_u32 = |off: usize| -> Result { + let end = off + .checked_add(4) + .ok_or_else(|| Error::Deobfuscation("offset overflow in blob read".into()))?; + let slice = data.get(off..end).ok_or(out_of_bounds_error!())?; + let arr: [u8; 4] = slice.try_into().map_err(|_| { + Error::Deobfuscation("blob read: 4-byte slice conversion failed".into()) + })?; + Ok(u32::from_le_bytes(arr)) }; if data.len() < 24 { @@ -631,9 +676,17 @@ fn parse_variant_a_blob(data: &[u8], assembly: &CilObject) -> Result Result = HashMap::new(); for i in 0..group_count { - let rva = read_u32(24 + i * 8); - let val = read_u32(24 + i * 8 + 4); + let entry_off = 24_usize + .checked_add(i.checked_mul(8).ok_or_else(|| { + Error::Deobfuscation("group entry offset overflow in variant-A blob".into()) + })?) + .ok_or_else(|| { + Error::Deobfuscation("group entry offset overflow in variant-A blob".into()) + })?; + let rva = read_u32(entry_off)?; + let val = read_u32(entry_off.checked_add(4).ok_or_else(|| { + Error::Deobfuscation("group entry val offset overflow in variant-A blob".into()) + })?)?; header_patches.insert(rva, val); } - let method_count = read_u32(header_end) as usize; - let mut offset = header_end + 4; + let method_count = read_u32(header_end)? as usize; + let mut offset = header_end + .checked_add(4) + .ok_or_else(|| Error::Deobfuscation("offset overflow after header_end".into()))?; // RVA → MethodDef token. Variant A's per-method records reference the IL // start address (RVA + 1 for tiny, RVA + 12 for fat), so a hit at @@ -666,41 +730,57 @@ fn parse_variant_a_blob(data: &[u8], assembly: &CilObject) -> Result data.len() { + let record_end = offset.checked_add(12).ok_or_else(|| { + Error::Deobfuscation("record offset overflow in variant-A blob".into()) + })?; + if record_end > data.len() { break; } - let il_start_rva = read_u32(offset); - let v2 = read_u32(offset + 4); - let il_size = read_u32(offset + 8) as usize; - offset += 12; + let il_start_rva = read_u32(offset)?; + let v2 = + read_u32(offset.checked_add(4).ok_or_else(|| { + Error::Deobfuscation("v2 offset overflow in variant-A blob".into()) + })?)?; + let il_size = read_u32(offset.checked_add(8).ok_or_else(|| { + Error::Deobfuscation("il_size offset overflow in variant-A blob".into()) + })?)? as usize; + offset = record_end; if il_start_rva == 0 && v2 == 0 && il_size == 0 { break; } - if offset + il_size > data.len() { + let il_end = offset + .checked_add(il_size) + .ok_or_else(|| Error::Deobfuscation("il_end overflow in variant-A blob".into()))?; + if il_end > data.len() { log::warn!( "NecroBit: truncated IL data at RVA 0x{il_start_rva:04X} (need {il_size}, have {})", - data.len() - offset + data.len().saturating_sub(offset) ); break; } - let il_bytes = &data[offset..offset + il_size]; - offset += il_size; - - let (method_token, is_fat) = if let Some(&token) = rva_to_token.get(&(il_start_rva - 1)) { - (token, false) - } else if let Some(&token) = rva_to_token.get(&(il_start_rva - 12)) { - (token, true) - } else { - log::warn!("NecroBit: no method found for IL start RVA 0x{il_start_rva:04X}"); - continue; - }; + let il_bytes = data.get(offset..il_end).ok_or(out_of_bounds_error!())?; + offset = il_end; + + let tiny_key = il_start_rva.checked_sub(1); + let fat_key = il_start_rva.checked_sub(12); + let (method_token, is_fat) = + if let Some(&token) = tiny_key.and_then(|k| rva_to_token.get(&k)) { + (token, false) + } else if let Some(&token) = fat_key.and_then(|k| rva_to_token.get(&k)) { + (token, true) + } else { + log::warn!("NecroBit: no method found for IL start RVA 0x{il_start_rva:04X}"); + continue; + }; let body_bytes = if is_fat { - let method_rva = il_start_rva - 12; + let method_rva = il_start_rva.checked_sub(12).ok_or_else(|| { + Error::Deobfuscation("method_rva underflow for fat method".into()) + })?; let (max_stack, is_init_local, local_var_sig_token) = - resolve_fat_header_metadata(method_rva, v2, &header_patches); + resolve_fat_header_metadata(method_rva, v2, &header_patches)?; let exception_handlers = read_on_disk_exception_handlers(assembly, method_rva); let new_body = MethodBody { @@ -714,7 +794,7 @@ fn parse_variant_a_blob(data: &[u8], assembly: &CilObject) -> Result Result, -) -> (usize, bool, u32) { +) -> Result<(usize, bool, u32)> { // StandAloneSig entries patch the LocalVarSig field (at header + 8). + let sig_key = method_rva.checked_add(8).ok_or_else(|| { + Error::Deobfuscation("method_rva + 8 overflow in fat header lookup".into()) + })?; let local_var_sig_token = header_patches - .get(&(method_rva + 8)) + .get(&sig_key) .copied() .filter(|v| (v >> 24) == 0x11) .unwrap_or(0); @@ -786,7 +869,7 @@ fn resolve_fat_header_metadata( }; let is_init_local = (flags_and_size & 0x0010) != 0; - (maxstack as usize, is_init_local, local_var_sig_token) + Ok((maxstack as usize, is_init_local, local_var_sig_token)) } /// Returns the structured exception handlers from the on-disk stub method @@ -847,7 +930,9 @@ fn extract_bodies_from_image( continue; }; - let addr = image_base + u64::from(rva); + let addr = image_base.checked_add(u64::from(rva)).ok_or_else(|| { + Error::Deobfuscation("addr overflow computing image_base + rva".into()) + })?; let available = image_size.saturating_sub(u64::from(rva)) as usize; if available == 0 { continue; @@ -856,7 +941,7 @@ fn extract_bodies_from_image( let buffer = match addr_space.read(addr, available) { Ok(b) => b, Err(_) => { - read_failures += 1; + read_failures = read_failures.saturating_add(1); continue; } }; @@ -864,7 +949,7 @@ fn extract_bodies_from_image( let body = match MethodBody::from(&buffer) { Ok(b) => b, Err(_) => { - parse_failures += 1; + parse_failures = parse_failures.saturating_add(1); continue; } }; @@ -875,15 +960,15 @@ fn extract_bodies_from_image( // here a variant-B fallback run on a variant-A binary would silently // overwrite real bodies with fat-wrapped stubs. if body.size_code == 4 { - still_stubs += 1; + still_stubs = still_stubs.saturating_add(1); continue; } let total_size = body.size(); - if total_size > buffer.len() { + let Some(body_slice) = buffer.get(..total_size) else { continue; - } - bodies.push((token, buffer[..total_size].to_vec())); + }; + bodies.push((token, body_slice.to_vec())); } if bodies.is_empty() && !stub_tokens.is_empty() { diff --git a/dotscope/src/deobfuscation/techniques/netreactor/resources.rs b/dotscope/src/deobfuscation/techniques/netreactor/resources.rs index b4870e0a..c6a2a339 100644 --- a/dotscope/src/deobfuscation/techniques/netreactor/resources.rs +++ b/dotscope/src/deobfuscation/techniques/netreactor/resources.rs @@ -475,7 +475,7 @@ impl Technique for NetReactorResources { .build(&mut cil_assembly) { Ok(_) => { - injected += 1; + injected = injected.saturating_add(1); events.record(EventKind::ResourceDecrypted).message(format!( "Injected NR-decrypted resource {:?} ({} bytes)", name, @@ -718,35 +718,43 @@ fn is_lazy_init_body( .filter(|i| !matches!(i.mnemonic, "nop" | "br" | "br.s")) .collect(); - if instrs.len() < 7 { + let (Some(i0), Some(i1), Some(i2), Some(i3), Some(i4), Some(i5), Some(i6)) = ( + instrs.first(), + instrs.get(1), + instrs.get(2), + instrs.get(3), + instrs.get(4), + instrs.get(5), + instrs.get(6), + ) else { return false; - } + }; - let Operand::Token(flag_load) = &instrs[0].operand else { + let Operand::Token(flag_load) = &i0.operand else { return false; }; - if instrs[0].mnemonic != "ldsfld" { + if i0.mnemonic != "ldsfld" { return false; } - if !matches!(instrs[1].mnemonic, "brtrue" | "brtrue.s") { + if !matches!(i1.mnemonic, "brtrue" | "brtrue.s") { return false; } - if !matches!(instrs[2].mnemonic, "ldc.i4.1") { + if !matches!(i2.mnemonic, "ldc.i4.1") { return false; } - if instrs[3].mnemonic != "stsfld" { + if i3.mnemonic != "stsfld" { return false; } - let Operand::Token(flag_store) = &instrs[3].operand else { + let Operand::Token(flag_store) = &i3.operand else { return false; }; if flag_load != flag_store { return false; } - if instrs[4].mnemonic != "newobj" { + if i4.mnemonic != "newobj" { return false; } - let Operand::Token(ctor_token) = &instrs[4].operand else { + let Operand::Token(ctor_token) = &i4.operand else { return false; }; // The newobj must target a .ctor on the resolver type. @@ -759,10 +767,10 @@ fn is_lazy_init_body( if declaring.token != cil_type.token { return false; } - if instrs[5].mnemonic != "pop" { + if i5.mnemonic != "pop" { return false; } - if instrs[6].mnemonic != "ret" { + if i6.mnemonic != "ret" { return false; } true @@ -850,19 +858,22 @@ fn classify_purely_injected_cctors(assembly: &CilObject, lazy_init_token: Token) .instructions() .filter(|i| !matches!(i.mnemonic, "nop" | "br" | "br.s")) .collect(); + let (Some(i0), Some(i1)) = (instrs.first(), instrs.get(1)) else { + continue; + }; if instrs.len() != 2 { continue; } - if instrs[0].mnemonic != "call" { + if i0.mnemonic != "call" { continue; } - let Operand::Token(t) = &instrs[0].operand else { + let Operand::Token(t) = &i0.operand else { continue; }; if *t != lazy_init_token { continue; } - if instrs[1].mnemonic != "ret" { + if i1.mnemonic != "ret" { continue; } out.push(method.token); @@ -900,7 +911,10 @@ fn find_get_manifest_resource_names_shims( if method.signature.params.len() != 1 { continue; } - if !matches!(method.signature.params[0].base, TypeSignature::Class(_)) { + let Some(first_param) = method.signature.params.first() else { + continue; + }; + if !matches!(first_param.base, TypeSignature::Class(_)) { continue; } if !method.has_body() { @@ -1029,7 +1043,9 @@ fn find_assembly_load_shim_methods(assembly: &CilObject, cil_type: &CilTypeRc) - if method.signature.params.len() != 1 { continue; } - let param = &method.signature.params[0]; + let Some(param) = method.signature.params.first() else { + continue; + }; let is_byte_array = match ¶m.base { TypeSignature::SzArray(inner) => matches!(*inner.base, TypeSignature::U1), _ => false, diff --git a/dotscope/src/deobfuscation/techniques/obfuscar/strings.rs b/dotscope/src/deobfuscation/techniques/obfuscar/strings.rs index 64c17e71..4dd6ee1b 100644 --- a/dotscope/src/deobfuscation/techniques/obfuscar/strings.rs +++ b/dotscope/src/deobfuscation/techniques/obfuscar/strings.rs @@ -287,8 +287,11 @@ fn extract_xor_key_from_cctor(assembly: &CilObject, cctor_token: Token) -> Optio let instructions: Vec<_> = method.instructions().collect(); for window in instructions.windows(3) { - if window[0].mnemonic == "xor" && window[2].mnemonic == "xor" { - if let Some(val) = window[1].get_i32_operand() { + let (Some(w0), Some(w1), Some(w2)) = (window.first(), window.get(1), window.get(2)) else { + continue; + }; + if w0.mnemonic == "xor" && w2.mnemonic == "xor" { + if let Some(val) = w1.get_i32_operand() { if (0..=255).contains(&val) { #[allow(clippy::cast_sign_loss, clippy::cast_possible_truncation)] return Some(val as u8); diff --git a/dotscope/src/deobfuscation/techniques/registry.rs b/dotscope/src/deobfuscation/techniques/registry.rs index 0278a9b1..e06e39c3 100644 --- a/dotscope/src/deobfuscation/techniques/registry.rs +++ b/dotscope/src/deobfuscation/techniques/registry.rs @@ -203,7 +203,11 @@ impl TechniqueRegistry { // Check cache: if the detection generation hasn't changed, reuse indices. if let Ok(cache) = self.sorted_cache.lock() { if cache.0 == gen { - return cache.1.iter().map(|&i| &*self.techniques[i]).collect(); + return cache + .1 + .iter() + .filter_map(|&i| self.techniques.get(i).map(|t| &**t)) + .collect(); } } @@ -211,7 +215,7 @@ impl TechniqueRegistry { let sorted_indices = self.compute_sorted_indices(detections); let result: Vec<&dyn Technique> = sorted_indices .iter() - .map(|&i| &*self.techniques[i]) + .filter_map(|&i| self.techniques.get(i).map(|t| &**t)) .collect(); if let Ok(mut cache) = self.sorted_cache.lock() { @@ -261,8 +265,12 @@ impl TechniqueRegistry { for &req_id in tech.requires() { if let Some(&req_idx) = id_to_idx.get(req_id) { // req_idx -> idx (req must come before this technique) - dependents[req_idx].push(idx); - in_degree[idx] += 1; + if let Some(deps) = dependents.get_mut(req_idx) { + deps.push(idx); + } + if let Some(d) = in_degree.get_mut(idx) { + *d = d.saturating_add(1); + } } // Missing dependency -> treat as satisfied (may be from a different phase) } @@ -278,11 +286,16 @@ impl TechniqueRegistry { let mut sorted: Vec = Vec::with_capacity(n); while let Some(idx) = queue.pop_front() { - sorted.push(eligible[idx].0); - for &dep_idx in &dependents[idx] { - in_degree[dep_idx] -= 1; - if in_degree[dep_idx] == 0 { - queue.push_back(dep_idx); + if let Some((orig_idx, _)) = eligible.get(idx) { + sorted.push(*orig_idx); + } + let deps_for_idx: Vec = dependents.get(idx).cloned().unwrap_or_default(); + for dep_idx in deps_for_idx { + if let Some(d) = in_degree.get_mut(dep_idx) { + *d = d.saturating_sub(1); + if *d == 0 { + queue.push_back(dep_idx); + } } } } @@ -291,7 +304,7 @@ impl TechniqueRegistry { if sorted.len() < n { log::warn!( "Technique dependency cycle detected: {} techniques could not be topologically sorted", - n - sorted.len() + n.saturating_sub(sorted.len()) ); let sorted_set: HashSet = sorted.iter().copied().collect(); for &(orig_idx, _) in &eligible { diff --git a/dotscope/src/deobfuscation/template.rs b/dotscope/src/deobfuscation/template.rs index 710fd449..3732cd3e 100644 --- a/dotscope/src/deobfuscation/template.rs +++ b/dotscope/src/deobfuscation/template.rs @@ -305,11 +305,11 @@ impl EmulationTemplatePool { // Multi-pass fork-based warmup for the targeted .cctors let mut completed: HashSet = HashSet::new(); let mut permanently_failed: HashSet = HashSet::new(); - let mut pass = 0; + let mut pass: u32 = 0; loop { - pass += 1; - let mut new_completions = 0; + pass = pass.saturating_add(1); + let mut new_completions: u32 = 0; for cctor in cctors { if completed.contains(cctor) || permanently_failed.contains(cctor) { @@ -328,7 +328,7 @@ impl EmulationTemplatePool { ); process = fork; completed.insert(*cctor); - new_completions += 1; + new_completions = new_completions.saturating_add(1); } Ok(EmulationOutcome::UnhandledException { instructions, .. }) => { debug!( @@ -337,7 +337,7 @@ impl EmulationTemplatePool { ); process = fork; permanently_failed.insert(*cctor); - new_completions += 1; + new_completions = new_completions.saturating_add(1); } Ok(EmulationOutcome::LimitReached { ref limit, .. }) => { debug!( @@ -346,7 +346,7 @@ impl EmulationTemplatePool { ); process = fork; permanently_failed.insert(*cctor); - new_completions += 1; + new_completions = new_completions.saturating_add(1); } Ok(outcome) => { debug!( @@ -370,7 +370,7 @@ impl EmulationTemplatePool { if is_resource_limit { process = fork; permanently_failed.insert(*cctor); - new_completions += 1; + new_completions = new_completions.saturating_add(1); } else if pass == 1 { debug!( "Targeted warmup: .cctor 0x{:08X} failed: {} (pass {})", @@ -442,7 +442,7 @@ impl EmulationTemplatePool { let mut permanently_failed = HashSet::new(); for pass in 1..=self.config.emulation.warmup_retry_passes { - let mut new_completions = 0; + let mut new_completions: u32 = 0; for (warmup_token, warmup_args) in &warmup_methods { if completed.contains(warmup_token) || permanently_failed.contains(warmup_token) { @@ -462,7 +462,7 @@ impl EmulationTemplatePool { ); *process = fork; completed.insert(*warmup_token); - new_completions += 1; + new_completions = new_completions.saturating_add(1); } Ok(EmulationOutcome::UnhandledException { instructions, .. }) => { debug!( @@ -471,7 +471,7 @@ impl EmulationTemplatePool { ); *process = fork; permanently_failed.insert(*warmup_token); - new_completions += 1; + new_completions = new_completions.saturating_add(1); } Ok(EmulationOutcome::LimitReached { ref limit, .. }) => { debug!( @@ -480,7 +480,7 @@ impl EmulationTemplatePool { ); *process = fork; permanently_failed.insert(*warmup_token); - new_completions += 1; + new_completions = new_completions.saturating_add(1); } Ok(outcome) => { debug!( @@ -508,7 +508,7 @@ impl EmulationTemplatePool { ); *process = fork; permanently_failed.insert(*warmup_token); - new_completions += 1; + new_completions = new_completions.saturating_add(1); } else if pass == 1 { debug!( "Template warmup: 0x{:08X} failed: {} (pass {})", diff --git a/dotscope/src/deobfuscation/utils.rs b/dotscope/src/deobfuscation/utils.rs index 4c39db73..1da9336b 100644 --- a/dotscope/src/deobfuscation/utils.rs +++ b/dotscope/src/deobfuscation/utils.rs @@ -207,14 +207,14 @@ pub(crate) fn build_call_site_counts( if instr.mnemonic == "call" || instr.mnemonic == "callvirt" { if let Some(token) = instr.get_token_operand() { if let Some(count) = counts.get_mut(&token) { - *count += 1; + *count = count.saturating_add(1); } else if token.is_table(TableId::MemberRef) { let resolved = memberref_cache .entry(token) .or_insert_with(|| assembly.resolver().resolve_memberref_method(token)); if let Some(resolved_token) = resolved { if let Some(count) = counts.get_mut(resolved_token) { - *count += 1; + *count = count.saturating_add(1); } } } @@ -562,24 +562,32 @@ pub(crate) fn build_init_array_map(assembly: &CilObject) -> HashMap= instructions.len() { + let Some(next_idx) = i.checked_add(1) else { + continue; + }; + if i < 1 || next_idx >= instructions.len() { continue; } // Find ldtoken before the call (within 3 instructions back) let mut backing_field_token = None; for j in (0..i).rev() { - if instructions[j].mnemonic == "ldtoken" { - backing_field_token = instructions[j].get_token_operand(); + let Some(prev_instr) = instructions.get(j) else { + break; + }; + if prev_instr.mnemonic == "ldtoken" { + backing_field_token = prev_instr.get_token_operand(); break; } - if i - j > 3 { + if i.saturating_sub(j) > 3 { break; } } // Find stsfld after the call - let stsfld_instr = &instructions[i + 1]; + let Some(stsfld_instr) = instructions.get(next_idx) else { + continue; + }; if stsfld_instr.mnemonic != "stsfld" { continue; } @@ -605,8 +613,12 @@ pub(crate) fn is_guid_name(name: &str) -> bool { return false; } let bytes = name.as_bytes(); - if bytes[8] != b'-' || bytes[13] != b'-' || bytes[18] != b'-' || bytes[23] != b'-' { - return false; + let dash_positions = [8usize, 13, 18, 23]; + for &pos in &dash_positions { + match bytes.get(pos) { + Some(&b'-') => {} + _ => return false, + } } bytes.iter().enumerate().all(|(i, &b)| { if i == 8 || i == 13 || i == 18 || i == 23 { diff --git a/dotscope/src/deobfuscation/workqueue.rs b/dotscope/src/deobfuscation/workqueue.rs index 08d1be1d..fc15d0e3 100644 --- a/dotscope/src/deobfuscation/workqueue.rs +++ b/dotscope/src/deobfuscation/workqueue.rs @@ -202,11 +202,12 @@ impl WorkQueue { /// it counts as 1 if the flag is set. pub fn len(&self) -> usize { let assembly = usize::from(self.redetect_assembly.load(Ordering::Acquire)); - self.build_ssa.len() - + self.inject_ssa.len() - + self.redetect_methods.len() - + self.redetect_types.len() - + assembly + self.build_ssa + .len() + .saturating_add(self.inject_ssa.len()) + .saturating_add(self.redetect_methods.len()) + .saturating_add(self.redetect_types.len()) + .saturating_add(assembly) } } diff --git a/dotscope/src/emulation/capture/context.rs b/dotscope/src/emulation/capture/context.rs index 1256a5a2..7408db70 100644 --- a/dotscope/src/emulation/capture/context.rs +++ b/dotscope/src/emulation/capture/context.rs @@ -369,7 +369,7 @@ impl CaptureContext { } if let Ok(mut stats) = self.stats.write() { - stats.assembly_bytes += len; + stats.assembly_bytes = stats.assembly_bytes.saturating_add(len); } } @@ -450,7 +450,7 @@ impl CaptureContext { } if let Ok(mut stats) = self.stats.write() { - stats.string_count += 1; + stats.string_count = stats.string_count.saturating_add(1); } } @@ -489,7 +489,7 @@ impl CaptureContext { } if let Ok(mut stats) = self.stats.write() { - stats.string_count += 1; + stats.string_count = stats.string_count.saturating_add(1); } } @@ -547,7 +547,7 @@ impl CaptureContext { } if let Ok(mut stats) = self.stats.write() { - stats.buffer_bytes += len; + stats.buffer_bytes = stats.buffer_bytes.saturating_add(len); } } @@ -645,7 +645,7 @@ impl CaptureContext { } if let Ok(mut stats) = self.stats.write() { - stats.file_op_count += 1; + stats.file_op_count = stats.file_op_count.saturating_add(1); } } @@ -694,7 +694,7 @@ impl CaptureContext { } if let Ok(mut stats) = self.stats.write() { - stats.network_op_count += 1; + stats.network_op_count = stats.network_op_count.saturating_add(1); } } @@ -750,7 +750,7 @@ impl CaptureContext { let base = range.start; // Safe: capture memory ranges fit in usize #[allow(clippy::cast_possible_truncation)] - let size = (range.end - range.start) as usize; + let size = range.end.saturating_sub(range.start) as usize; address_space .read(base, size) .ok() @@ -778,7 +778,7 @@ impl CaptureContext { } if let Ok(mut stats) = self.stats.write() { - stats.snapshot_count += 1; + stats.snapshot_count = stats.snapshot_count.saturating_add(1); } } @@ -812,7 +812,7 @@ impl CaptureContext { let base = range.start; // Safe: capture memory ranges fit in usize #[allow(clippy::cast_possible_truncation)] - let size = (range.end - range.start) as usize; + let size = range.end.saturating_sub(range.start) as usize; address_space .read(base, size) .ok() @@ -840,7 +840,7 @@ impl CaptureContext { } if let Ok(mut stats) = self.stats.write() { - stats.snapshot_count += 1; + stats.snapshot_count = stats.snapshot_count.saturating_add(1); } } diff --git a/dotscope/src/emulation/engine/callresolver.rs b/dotscope/src/emulation/engine/callresolver.rs index 4018c26c..d29e7727 100644 --- a/dotscope/src/emulation/engine/callresolver.rs +++ b/dotscope/src/emulation/engine/callresolver.rs @@ -484,7 +484,7 @@ impl CallResolver { let total_args = if is_static { param_types.len() } else { - param_types.len() + 1 + param_types.len().saturating_add(1) }; let arg_values = thread.pop_args(total_args)?; let expects_return = context.method_returns_value(method_token)?; @@ -581,7 +581,7 @@ impl CallResolver { if let Some(member_ref) = context.get_member_ref(method_token) { if let MemberRefSignature::Method(method_sig) = &member_ref.signature { let total_args = if method_sig.has_this { - method_sig.param_count as usize + 1 + (method_sig.param_count as usize).saturating_add(1) } else { method_sig.param_count as usize }; @@ -593,7 +593,9 @@ impl CallResolver { // hook lookup using the runtime type of `this`. if is_virtual && method_sig.has_this && total_args > 0 { let args = thread.peek_args(total_args)?; - let this_arg = &args[0]; + let Some((this_arg, rest_args)) = args.split_first() else { + return Err(out_of_bounds_error!()); + }; if let EmValue::ObjectRef(heap_ref) = this_arg { if let Ok(runtime_type_token) = thread.heap().get_type_token(*heap_ref) @@ -629,7 +631,7 @@ impl CallResolver { let (this_ref, method_args): ( Option<&EmValue>, &[EmValue], - ) = (Some(&args[0]), &args[1..]); + ) = (Some(this_arg), rest_args); let hook_context = HookContext::new( method_token, @@ -756,7 +758,7 @@ impl CallResolver { let param_count = method.signature.params.len(); let is_instance = !context.is_static_method(method_token)?; let total_args = if is_instance { - param_count + 1 + param_count.saturating_add(1) } else { param_count }; @@ -776,16 +778,25 @@ impl CallResolver { // ECMA-335 III.2.1: If the constraint type is a value type and it does NOT // override the method, box the value type and use the boxed ref as 'this'. if resolved == method_token && context.is_value_type(constraint_token) { - if let EmValue::ManagedPtr(ptr) = &arg_values[0] { - if let Ok(value) = typeops::deref_managed_ptr(address_space, thread, ptr) { + let ptr_opt = match arg_values.first() { + Some(EmValue::ManagedPtr(ptr)) => Some(ptr.clone()), + _ => None, + }; + if let Some(ptr) = ptr_opt { + if let Ok(value) = typeops::deref_managed_ptr(address_space, thread, &ptr) { let boxed = thread.heap_mut().alloc_boxed(constraint_token, value)?; - arg_values[0] = EmValue::ObjectRef(boxed); + if let Some(slot) = arg_values.first_mut() { + *slot = EmValue::ObjectRef(boxed); + } } } } resolved } else { - self.resolve_virtual_dispatch(context, thread, method_token, &arg_values[0]) + let Some(this_arg) = arg_values.first() else { + return Err(out_of_bounds_error!()); + }; + self.resolve_virtual_dispatch(context, thread, method_token, this_arg) } } else { method_token @@ -818,7 +829,7 @@ impl CallResolver { && is_instance && !arg_values.is_empty() { - if let EmValue::ObjectRef(href) = &arg_values[0] { + if let Some(EmValue::ObjectRef(href)) = arg_values.first() { match thread.heap().get(*href) { Ok(HeapObject::Delegate { invocation_list, .. @@ -912,12 +923,16 @@ impl CallResolver { }; if should_dispatch { - let delegate_args: Vec = arg_values[1..].to_vec(); + let delegate_args: Vec = + arg_values.get(1..).unwrap_or(&[]).to_vec(); // Set up multicast state if there are more entries if invocation_list.len() > 1 { thread.set_multicast_state(MulticastState { - remaining_entries: invocation_list[1..].to_vec(), + remaining_entries: invocation_list + .get(1..) + .unwrap_or(&[]) + .to_vec(), delegate_args: delegate_args.clone(), dispatch_depth: thread.call_depth(), }); @@ -935,12 +950,16 @@ impl CallResolver { } if target_token.is_table(TableId::MemberRef) { - let delegate_args: Vec = arg_values[1..].to_vec(); + let delegate_args: Vec = + arg_values.get(1..).unwrap_or(&[]).to_vec(); // Set up multicast state if there are more entries if invocation_list.len() > 1 { thread.set_multicast_state(MulticastState { - remaining_entries: invocation_list[1..].to_vec(), + remaining_entries: invocation_list + .get(1..) + .unwrap_or(&[]) + .to_vec(), delegate_args: delegate_args.clone(), dispatch_depth: thread.call_depth(), }); @@ -987,7 +1006,8 @@ impl CallResolver { } else { debug!( "Delegate Invoke on {:?}: this is {:?}, not an ObjectRef", - resolved_method_token, arg_values[0], + resolved_method_token, + arg_values.first(), ); } } @@ -1401,7 +1421,7 @@ impl CallResolver { info: &ResolvedMethodInfo, ) -> Result { let total_args = if info.has_this { - info.param_count + 1 + info.param_count.saturating_add(1) } else { info.param_count }; @@ -1417,12 +1437,14 @@ impl CallResolver { let return_type = context.get_return_type(method_token).ok().flatten(); // Split into this and method args - let (this_ref, method_args): (Option<&EmValue>, &[EmValue]) = - if info.has_this && !args.is_empty() { - (Some(&args[0]), &args[1..]) - } else { - (None, &args[..]) - }; + let (this_ref, method_args): (Option<&EmValue>, &[EmValue]) = if info.has_this { + match args.split_first() { + Some((first, rest)) => (Some(first), rest), + None => (None, args.as_slice()), + } + } else { + (None, args.as_slice()) + }; let hook_context = HookContext::new( method_token, diff --git a/dotscope/src/emulation/engine/controller.rs b/dotscope/src/emulation/engine/controller.rs index 91338e56..86750b92 100644 --- a/dotscope/src/emulation/engine/controller.rs +++ b/dotscope/src/emulation/engine/controller.rs @@ -67,7 +67,7 @@ use crate::{ token::Token, typesystem::CilFlavor, }, - CilObject, Result, + CilObject, Error, Result, }; /// Control flow directive returned by extracted handler methods. @@ -1331,7 +1331,7 @@ impl EmulationController { current_method: Token, err: crate::Error, ) -> Result { - let crate::Error::Emulation(ref emu_err) = err else { + let Error::Emulation(ref emu_err) = err else { return Err(err); }; if !emu_err.is_clr_exception() { @@ -2041,9 +2041,9 @@ impl EmulationController { // Save return info from the current frame before popping it. // The tail-called method must return to our CALLER, not to us. - let mut popped = thread - .pop_frame() - .expect("tail call requires an active frame"); + let mut popped = thread.pop_frame().ok_or(EmulationError::InternalError { + description: "tail call requires an active frame".to_string(), + })?; let return_method = popped.return_method(); let return_offset = popped.return_offset(); let caller_stack = popped.take_caller_stack(); @@ -2100,9 +2100,9 @@ impl EmulationController { } } - let mut popped = thread - .pop_frame() - .expect("tail call requires an active frame"); + let mut popped = thread.pop_frame().ok_or(EmulationError::InternalError { + description: "tail call requires an active frame".to_string(), + })?; let return_method = popped.return_method(); let return_offset = popped.return_offset(); let caller_stack = popped.take_caller_stack(); @@ -2252,7 +2252,7 @@ impl EmulationController { let param_count = method_sig.param_count as usize; let has_this = method_sig.has_this; let total_args = if has_this { - param_count + 1 + param_count.saturating_add(1) } else { param_count }; @@ -2460,7 +2460,7 @@ impl EmulationController { let param_count = target_method.signature.params.len(); let is_instance = !context.is_static_method(method)?; let total_args = if is_instance { - param_count + 1 + param_count.saturating_add(1) } else { param_count }; @@ -2720,18 +2720,25 @@ impl EmulationController { impl Default for EmulationController { fn default() -> Self { - let address_space = Arc::new(AddressSpace::new()); - let fake_objects = SharedFakeObjects::new(address_space.managed_heap()); - let context = Arc::new(ThreadContext::new( - address_space, - Arc::new(RwLock::new(RuntimeState::new())), - Arc::new(CaptureContext::default()), - Arc::new(EmulationConfig::default()), - None, - fake_objects, - Arc::new(VirtualFs::new()), - )); - Self::new(context, None).expect("default EmulationController creation should not fail") + // Loop with a fresh RwLock each iteration. `CallResolver::new` only fails + // on RwLock poison, which is impossible on a freshly-allocated lock; the + // loop is therefore guaranteed to terminate on the first iteration. + loop { + let address_space = Arc::new(AddressSpace::new()); + let fake_objects = SharedFakeObjects::new(address_space.managed_heap()); + let context = Arc::new(ThreadContext::new( + address_space, + Arc::new(RwLock::new(RuntimeState::new())), + Arc::new(CaptureContext::default()), + Arc::new(EmulationConfig::default()), + None, + fake_objects, + Arc::new(VirtualFs::new()), + )); + if let Ok(ctrl) = Self::new(context, None) { + return ctrl; + } + } } } diff --git a/dotscope/src/emulation/engine/exhandler.rs b/dotscope/src/emulation/engine/exhandler.rs index eb94b60c..3f944f4a 100644 --- a/dotscope/src/emulation/engine/exhandler.rs +++ b/dotscope/src/emulation/engine/exhandler.rs @@ -493,7 +493,7 @@ pub fn schedule_finally_blocks( // The last one scheduled will be popped first for (i, (handler_offset, _)) in finally_blocks.iter().enumerate() { // The last finally should have the actual leave target - let target = if i == finally_blocks.len() - 1 { + let target = if i == finally_blocks.len().saturating_sub(1) { Some(leave_target) } else { None diff --git a/dotscope/src/emulation/engine/interpreter/handlers.rs b/dotscope/src/emulation/engine/interpreter/handlers.rs index 564af843..ae5f8291 100644 --- a/dotscope/src/emulation/engine/interpreter/handlers.rs +++ b/dotscope/src/emulation/engine/interpreter/handlers.rs @@ -483,10 +483,8 @@ impl Interpreter { return Err(Self::invalid_operand(instruction, "switch branch targets")); } - if index < branch_targets.len() { - Ok(StepResult::Branch { - target: branch_targets[index], - }) + if let Some(&target) = branch_targets.get(index) { + Ok(StepResult::Branch { target }) } else { // Fall through if index is out of range Ok(StepResult::Continue) @@ -674,13 +672,13 @@ impl Interpreter { } HeapObject::Array { elements, .. } if elements.len() == 1 => { // Check if it's a single-element array containing an integer - if let Some(int_val) = elements[0].try_to_i64() { + if let Some(int_val) = elements.first().and_then(EmValue::try_to_i64) { return Some(int_val); } } HeapObject::MultiArray { elements, .. } if elements.len() == 1 => { // Check if it's a single-element multi-array - if let Some(int_val) = elements[0].try_to_i64() { + if let Some(int_val) = elements.first().and_then(EmValue::try_to_i64) { return Some(int_val); } } @@ -781,9 +779,9 @@ impl Interpreter { .into()); } - let array_idx = usize::try_from(idx).ok().filter(|&i| i < elements.len()); - match array_idx { - Some(i) => elements[i].clone(), + let array_idx = usize::try_from(idx).ok(); + match array_idx.and_then(|i| elements.get(i)) { + Some(v) => v.clone(), None => { return Err(EmulationError::ArrayIndexOutOfBounds { index: idx, @@ -1119,21 +1117,12 @@ impl Interpreter { let local_value = frame .and_then(|f| f.locals().get(usize::from(*idx)).ok()) .cloned(); - match local_value.as_ref() { - Some(EmValue::ObjectRef(href)) => load_from_href(thread, *href), - Some(EmValue::ValueType { .. }) | Some(EmValue::Void) => { - load_from_valuetype(thread, local_value.unwrap()) - } - Some(_) => { - let value = thread.deref_pointer(&ptr)?; - if matches!(value, EmValue::ValueType { .. } | EmValue::Void) { - load_from_valuetype(thread, value) - } else { - thread.push(value)?; - Ok(StepResult::Continue) - } + match local_value { + Some(EmValue::ObjectRef(href)) => load_from_href(thread, href), + Some(v @ (EmValue::ValueType { .. } | EmValue::Void)) => { + load_from_valuetype(thread, v) } - None => { + _ => { let value = thread.deref_pointer(&ptr)?; if matches!(value, EmValue::ValueType { .. } | EmValue::Void) { load_from_valuetype(thread, value) @@ -1151,21 +1140,12 @@ impl Interpreter { let arg_value = frame .and_then(|f| f.arguments().get(usize::from(*idx)).ok()) .cloned(); - match arg_value.as_ref() { - Some(EmValue::ObjectRef(href)) => load_from_href(thread, *href), - Some(EmValue::ValueType { .. }) | Some(EmValue::Void) => { - load_from_valuetype(thread, arg_value.unwrap()) + match arg_value { + Some(EmValue::ObjectRef(href)) => load_from_href(thread, href), + Some(v @ (EmValue::ValueType { .. } | EmValue::Void)) => { + load_from_valuetype(thread, v) } - Some(_) => { - let value = thread.deref_pointer(&ptr)?; - if matches!(value, EmValue::ValueType { .. } | EmValue::Void) { - load_from_valuetype(thread, value) - } else { - thread.push(value)?; - Ok(StepResult::Continue) - } - } - None => { + _ => { let value = thread.deref_pointer(&ptr)?; if matches!(value, EmValue::ValueType { .. } | EmValue::Void) { load_from_valuetype(thread, value) @@ -1288,7 +1268,9 @@ impl Interpreter { while fields.len() <= field_idx { fields.push(EmValue::I32(0)); // Default to 0 } - fields[field_idx] = value; + if let Some(slot) = fields.get_mut(field_idx) { + *slot = value; + } return Some(EmValue::ValueType { type_token, fields }); } } @@ -1323,19 +1305,16 @@ impl Interpreter { .or_else(|| thread.current_frame()); let local_value = frame.and_then(|f| f.locals().get(local_idx).ok()).cloned(); - match local_value.as_ref() { + match local_value { Some(EmValue::ObjectRef(href)) => { let heap = thread.heap_mut(); - heap.set_field(*href, field_token, value)?; + heap.set_field(href, field_token, value)?; Ok(StepResult::Continue) } - Some(EmValue::ValueType { .. }) => { - if let Some(updated) = store_into_valuetype( - thread, - local_value.unwrap(), - field_token, - value.clone(), - ) { + Some(vt @ EmValue::ValueType { .. }) => { + if let Some(updated) = + store_into_valuetype(thread, vt, field_token, value.clone()) + { // Write back to the owning frame thread .resolve_frame_mut(ptr.frame_depth) @@ -1364,19 +1343,16 @@ impl Interpreter { .or_else(|| thread.current_frame()); let arg_value = frame.and_then(|f| f.arguments().get(arg_idx).ok()).cloned(); - match arg_value.as_ref() { + match arg_value { Some(EmValue::ObjectRef(href)) => { let heap = thread.heap_mut(); - heap.set_field(*href, field_token, value)?; + heap.set_field(href, field_token, value)?; Ok(StepResult::Continue) } - Some(EmValue::ValueType { .. }) => { - if let Some(updated) = store_into_valuetype( - thread, - arg_value.unwrap(), - field_token, - value.clone(), - ) { + Some(vt @ EmValue::ValueType { .. }) => { + if let Some(updated) = + store_into_valuetype(thread, vt, field_token, value.clone()) + { thread .resolve_frame_mut(ptr.frame_depth) .ok_or_else(|| EmulationError::InternalError { @@ -1545,77 +1521,90 @@ impl Interpreter { _ => unreachable!(), }; + // Helper: read an exact-sized little-endian array from the address space. + // The address_space.read API returns the requested number of bytes on + // success; a length mismatch indicates an internal invariant violation. + let read_exact = |size: usize| -> Result> { + let bytes = address_space.read(ptr_addr, size)?; + if bytes.len() != size { + return Err(EmulationError::InternalError { + description: format!( + "address space read returned {} bytes, expected {}", + bytes.len(), + size + ), + } + .into()); + } + Ok(bytes) + }; + let read_array_2 = || -> Result<[u8; 2]> { + let bytes = read_exact(2)?; + <[u8; 2]>::try_from(bytes.as_slice()).map_err(|_| out_of_bounds_error!()) + }; + let read_array_4 = || -> Result<[u8; 4]> { + let bytes = read_exact(4)?; + <[u8; 4]>::try_from(bytes.as_slice()).map_err(|_| out_of_bounds_error!()) + }; + let read_array_8 = || -> Result<[u8; 8]> { + let bytes = read_exact(8)?; + <[u8; 8]>::try_from(bytes.as_slice()).map_err(|_| out_of_bounds_error!()) + }; + // Read from address space based on read_size and expected_type let value = match (expected_type, read_size) { // Small integer reads (1 or 2 bytes) that widen to I32 (&CilFlavor::I4, 1) => { - let bytes = address_space.read(ptr_addr, 1)?; + let bytes = read_exact(1)?; + let b0 = *bytes.first().ok_or(out_of_bounds_error!())?; let val = if signed { // Intentional wrap-around for sign extension from u8 to i8 - i32::from(bytes[0].cast_signed()) + i32::from(b0.cast_signed()) } else { - i32::from(bytes[0]) + i32::from(b0) }; EmValue::I32(val) } (&CilFlavor::I4, 2) => { - let bytes = address_space.read(ptr_addr, 2)?; + let arr = read_array_2()?; let val = if signed { - i32::from(i16::from_le_bytes([bytes[0], bytes[1]])) + i32::from(i16::from_le_bytes(arr)) } else { - i32::from(u16::from_le_bytes([bytes[0], bytes[1]])) + i32::from(u16::from_le_bytes(arr)) }; EmValue::I32(val) } (&CilFlavor::I4, _) => { - let bytes = address_space.read(ptr_addr, 4)?; - let val = i32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]); + let val = i32::from_le_bytes(read_array_4()?); EmValue::I32(val) } (&CilFlavor::I8, _) => { - let bytes = address_space.read(ptr_addr, 8)?; - let val = i64::from_le_bytes([ - bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], - bytes[7], - ]); + let val = i64::from_le_bytes(read_array_8()?); EmValue::I64(val) } (&CilFlavor::I, _) => { // Native int is pointer-sized: 4 bytes on PE32, 8 bytes on PE32+ - let bytes = address_space.read(ptr_addr, read_size)?; let val = if read_size == 4 { - i64::from(i32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])) + i64::from(i32::from_le_bytes(read_array_4()?)) } else { - i64::from_le_bytes([ - bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], - bytes[6], bytes[7], - ]) + i64::from_le_bytes(read_array_8()?) }; EmValue::NativeInt(val) } (&CilFlavor::R4, _) => { - let bytes = address_space.read(ptr_addr, 4)?; - let val = f32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]); + let val = f32::from_le_bytes(read_array_4()?); EmValue::F32(val) } (&CilFlavor::R8, _) => { - let bytes = address_space.read(ptr_addr, 8)?; - let val = f64::from_le_bytes([ - bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], - bytes[7], - ]); + let val = f64::from_le_bytes(read_array_8()?); EmValue::F64(val) } (&CilFlavor::Object, _) => { // For object references, treat as pointer-sized value - let bytes = address_space.read(ptr_addr, read_size)?; let val = if read_size == 4 { - u64::from(u32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])) + u64::from(u32::from_le_bytes(read_array_4()?)) } else { - u64::from_le_bytes([ - bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], - bytes[6], bytes[7], - ]) + u64::from_le_bytes(read_array_8()?) }; EmValue::UnmanagedPtr(val) } diff --git a/dotscope/src/emulation/engine/pointer.rs b/dotscope/src/emulation/engine/pointer.rs index 4be66e2f..4c315a58 100644 --- a/dotscope/src/emulation/engine/pointer.rs +++ b/dotscope/src/emulation/engine/pointer.rs @@ -92,18 +92,18 @@ impl InstructionPointer { /// Returns the offset of the next instruction (fall-through address). #[must_use] pub fn next_offset(&self) -> u32 { - self.offset + self.current_size + self.offset.saturating_add(self.current_size) } /// Advances to the next instruction (sequential execution). pub fn advance(&mut self, instruction_size: u32) { - self.offset += instruction_size; + self.offset = self.offset.saturating_add(instruction_size); self.current_size = 0; } /// Advances using the stored current instruction size. pub fn advance_current(&mut self) { - self.offset += self.current_size; + self.offset = self.offset.saturating_add(self.current_size); self.current_size = 0; } diff --git a/dotscope/src/emulation/engine/stats.rs b/dotscope/src/emulation/engine/stats.rs index d617ed20..af0b9170 100644 --- a/dotscope/src/emulation/engine/stats.rs +++ b/dotscope/src/emulation/engine/stats.rs @@ -42,7 +42,7 @@ impl ExecutionStats { /// Increments the instruction counter. pub fn increment_instructions(&mut self) { - self.instructions_executed += 1; + self.instructions_executed = self.instructions_executed.saturating_add(1); } /// Returns the elapsed time since execution started. diff --git a/dotscope/src/emulation/engine/typeops/newobj.rs b/dotscope/src/emulation/engine/typeops/newobj.rs index 5b1f48b0..b32f3b9e 100644 --- a/dotscope/src/emulation/engine/typeops/newobj.rs +++ b/dotscope/src/emulation/engine/typeops/newobj.rs @@ -150,17 +150,17 @@ pub fn resolve_newobj( if !hook_handled && method.is_code_runtime() && arg_values.len() == 2 { if let Some(ref decl_type) = declaring_type { if is_delegate_type(decl_type) { - let target = match &arg_values[0] { - EmValue::ObjectRef(href) => Some(*href), - EmValue::Null => None, + let target = match arg_values.first() { + Some(EmValue::ObjectRef(href)) => Some(*href), + Some(EmValue::Null) => None, _ => None, }; - let method_token_value = match &arg_values[1] { - EmValue::UnmanagedPtr(ptr) => Some(Token::new(*ptr as u32)), - EmValue::I32(v) => Some(Token::new(*v as u32)), - EmValue::I64(v) => Some(Token::new(*v as u32)), - EmValue::NativeInt(v) => Some(Token::new(*v as u32)), + let method_token_value = match arg_values.get(1) { + Some(EmValue::UnmanagedPtr(ptr)) => Some(Token::new(*ptr as u32)), + Some(EmValue::I32(v)) => Some(Token::new(*v as u32)), + Some(EmValue::I64(v)) => Some(Token::new(*v as u32)), + Some(EmValue::NativeInt(v)) => Some(Token::new(*v as u32)), _ => None, }; @@ -289,17 +289,17 @@ fn resolve_newobj_memberref( if args.len() == 2 { if let Some(cil_type) = context.get_type(type_token) { if is_delegate_type(&cil_type) { - let target = match &args[0] { - EmValue::ObjectRef(href) => Some(*href), - EmValue::Null => None, + let target = match args.first() { + Some(EmValue::ObjectRef(href)) => Some(*href), + Some(EmValue::Null) => None, _ => None, }; - let method_token_value = match &args[1] { - EmValue::UnmanagedPtr(ptr) => Some(Token::new(*ptr as u32)), - EmValue::I32(v) => Some(Token::new(*v as u32)), - EmValue::I64(v) => Some(Token::new(*v as u32)), - EmValue::NativeInt(v) => Some(Token::new(*v as u32)), + let method_token_value = match args.get(1) { + Some(EmValue::UnmanagedPtr(ptr)) => Some(Token::new(*ptr as u32)), + Some(EmValue::I32(v)) => Some(Token::new(*v as u32)), + Some(EmValue::I64(v)) => Some(Token::new(*v as u32)), + Some(EmValue::NativeInt(v)) => Some(Token::new(*v as u32)), _ => None, }; diff --git a/dotscope/src/emulation/exception/types.rs b/dotscope/src/emulation/exception/types.rs index 33ec0078..5390304b 100644 --- a/dotscope/src/emulation/exception/types.rs +++ b/dotscope/src/emulation/exception/types.rs @@ -184,7 +184,7 @@ impl ExceptionClause { /// The ending IL offset of the protected region. #[must_use] pub fn try_end(&self) -> u32 { - self.try_offset() + self.try_length() + self.try_offset().saturating_add(self.try_length()) } /// Gets the IL offset where the handler block begins. @@ -226,7 +226,7 @@ impl ExceptionClause { /// The ending IL offset of the handler code. #[must_use] pub fn handler_end(&self) -> u32 { - self.handler_offset() + self.handler_length() + self.handler_offset().saturating_add(self.handler_length()) } /// Checks if an IL offset is within the try block. diff --git a/dotscope/src/emulation/filesystem.rs b/dotscope/src/emulation/filesystem.rs index 15c58a0b..499320e2 100644 --- a/dotscope/src/emulation/filesystem.rs +++ b/dotscope/src/emulation/filesystem.rs @@ -165,10 +165,13 @@ impl VirtualFs { let mut normalized = path.to_lowercase().replace('\\', "/"); // Strip leading drive letter (e.g., "c:/...") if normalized.len() >= 3 - && normalized.as_bytes()[0].is_ascii_alphabetic() - && &normalized[1..3] == ":/" + && normalized + .as_bytes() + .first() + .is_some_and(u8::is_ascii_alphabetic) + && normalized.get(1..3) == Some(":/") { - normalized = normalized[2..].to_string(); + normalized = normalized.get(2..).unwrap_or("").to_string(); } // Strip leading slash for consistency normalized.trim_start_matches('/').to_string() diff --git a/dotscope/src/emulation/loader/data.rs b/dotscope/src/emulation/loader/data.rs index f43bbd18..ef808730 100644 --- a/dotscope/src/emulation/loader/data.rs +++ b/dotscope/src/emulation/loader/data.rs @@ -50,7 +50,7 @@ use crate::{ emulation::memory::{AddressSpace, MemoryProtection, MemoryRegion}, - Result, + Error, Result, }; /// Information about a mapped data region in the emulation address space. @@ -378,8 +378,8 @@ impl DataLoader { address: u64, protection: MemoryProtection, ) -> Result { - let data = std::fs::read(path) - .map_err(|e| crate::Error::Other(format!("Failed to read file: {e}")))?; + let data = + std::fs::read(path).map_err(|e| Error::Other(format!("Failed to read file: {e}")))?; let label = path .file_name() diff --git a/dotscope/src/emulation/loader/peloader.rs b/dotscope/src/emulation/loader/peloader.rs index d59279b7..b8978f1f 100644 --- a/dotscope/src/emulation/loader/peloader.rs +++ b/dotscope/src/emulation/loader/peloader.rs @@ -319,7 +319,8 @@ impl LoadedImage { /// ``` #[must_use] pub fn entry_point_va(&self) -> Option { - self.entry_point.map(|rva| self.base_address + rva) + self.entry_point + .map(|rva| self.base_address.saturating_add(rva)) } /// Converts a Relative Virtual Address (RVA) to an absolute virtual address. @@ -344,7 +345,7 @@ impl LoadedImage { /// ``` #[must_use] pub fn rva_to_va(&self, rva: u32) -> u64 { - self.base_address + u64::from(rva) + self.base_address.saturating_add(u64::from(rva)) } /// Checks whether an RVA falls within the bounds of this image. @@ -392,7 +393,10 @@ impl LoadedImage { #[must_use] pub fn section_for_rva(&self, rva: u32) -> Option<&LoadedSection> { self.sections.iter().find(|s| { - rva >= s.virtual_address && rva < s.virtual_address + s.virtual_size.max(s.raw_size) + rva >= s.virtual_address + && rva + < s.virtual_address + .saturating_add(s.virtual_size.max(s.raw_size)) }) } @@ -692,7 +696,11 @@ impl PeLoader { .optional_header .map_or(0x200, |oh| oh.windows_fields.size_of_headers as usize); if headers_size <= pe_bytes.len() && headers_size <= image_data.len() { - image_data[..headers_size].copy_from_slice(&pe_bytes[..headers_size]); + let dst = image_data + .get_mut(..headers_size) + .ok_or(out_of_bounds_error!())?; + let src = pe_bytes.get(..headers_size).ok_or(out_of_bounds_error!())?; + dst.copy_from_slice(src); } // Map each section @@ -730,11 +738,21 @@ impl PeLoader { let dest_offset = virtual_address as usize; let copy_size = raw_size.min(virtual_size) as usize; - if raw_offset + copy_size <= pe_bytes.len() - && dest_offset + copy_size <= image_data.len() - { - image_data[dest_offset..dest_offset + copy_size] - .copy_from_slice(&pe_bytes[raw_offset..raw_offset + copy_size]); + let raw_end = raw_offset + .checked_add(copy_size) + .ok_or_else(|| malformed_error!("PE section raw range overflow"))?; + let dest_end = dest_offset + .checked_add(copy_size) + .ok_or_else(|| malformed_error!("PE section virtual range overflow"))?; + + if raw_end <= pe_bytes.len() && dest_end <= image_data.len() { + let dst = image_data + .get_mut(dest_offset..dest_end) + .ok_or(out_of_bounds_error!())?; + let src = pe_bytes + .get(raw_offset..raw_end) + .ok_or(out_of_bounds_error!())?; + dst.copy_from_slice(src); } let section_protection = if self.config.apply_permissions { @@ -767,7 +785,10 @@ impl PeLoader { } // Apply relocations if loading at a non-preferred base - let delta = base_address.cast_signed() - preferred_base.cast_signed(); + let delta = base_address + .cast_signed() + .checked_sub(preferred_base.cast_signed()) + .ok_or_else(|| malformed_error!("PE relocation delta overflow"))?; if delta != 0 && self.config.apply_relocations { Self::apply_relocations(&pe, &mut image_data, delta, is_64_bit)?; } @@ -906,7 +927,10 @@ impl PeLoader { let reloc_size = reloc_dir.size as usize; // Ensure relocation data is within bounds - if reloc_rva + reloc_size > image_data.len() { + let end = reloc_rva + .checked_add(reloc_size) + .ok_or_else(|| malformed_error!("PE relocation directory range overflow"))?; + if end > image_data.len() { return Err(Error::Other( "Relocation directory extends beyond image bounds".to_string(), )); @@ -914,77 +938,112 @@ impl PeLoader { // Process relocation blocks let mut offset = reloc_rva; - let end = reloc_rva + reloc_size; - while offset + 8 <= end { + while offset.checked_add(8).is_some_and(|next| next <= end) { // Read block header - let page_rva = u32::from_le_bytes([ - image_data[offset], - image_data[offset + 1], - image_data[offset + 2], - image_data[offset + 3], - ]) as usize; - - let block_size = u32::from_le_bytes([ - image_data[offset + 4], - image_data[offset + 5], - image_data[offset + 6], - image_data[offset + 7], - ]) as usize; + let header_end = offset + .checked_add(8) + .ok_or_else(|| malformed_error!("PE relocation block header overflow"))?; + let header_slice = image_data + .get(offset..header_end) + .ok_or(out_of_bounds_error!())?; + let page_rva_bytes = header_slice.get(0..4).ok_or(out_of_bounds_error!())?; + let block_size_bytes = header_slice.get(4..8).ok_or(out_of_bounds_error!())?; + let page_rva = u32::from_le_bytes( + page_rva_bytes + .try_into() + .map_err(|_| out_of_bounds_error!())?, + ) as usize; + let block_size = u32::from_le_bytes( + block_size_bytes + .try_into() + .map_err(|_| out_of_bounds_error!())?, + ) as usize; // Validate block size - if block_size < 8 || offset + block_size > end { + let block_end = match offset.checked_add(block_size) { + Some(v) => v, + None => break, + }; + if block_size < 8 || block_end > end { break; } // Process entries in this block - let entry_count = (block_size - 8) / 2; + let entry_count = (block_size.saturating_sub(8)) / 2; for i in 0..entry_count { - let entry_offset = offset + 8 + i * 2; - if entry_offset + 2 > image_data.len() { + let entry_offset = match offset + .checked_add(8) + .and_then(|base| i.checked_mul(2).and_then(|delta| base.checked_add(delta))) + { + Some(v) => v, + None => break, + }; + let entry_end = match entry_offset.checked_add(2) { + Some(v) => v, + None => break, + }; + if entry_end > image_data.len() { break; } + let entry_bytes = image_data + .get(entry_offset..entry_end) + .ok_or(out_of_bounds_error!())?; let entry = - u16::from_le_bytes([image_data[entry_offset], image_data[entry_offset + 1]]); + u16::from_le_bytes(entry_bytes.try_into().map_err(|_| out_of_bounds_error!())?); let reloc_type = entry >> 12; let reloc_offset = (entry & 0x0FFF) as usize; - let target_offset = page_rva + reloc_offset; + let target_offset = match page_rva.checked_add(reloc_offset) { + Some(v) => v, + None => break, + }; match reloc_type { // 32-bit fixup reloc_type::IMAGE_REL_BASED_HIGHLOW - if target_offset + 4 <= image_data.len() => + if target_offset + .checked_add(4) + .is_some_and(|t| t <= image_data.len()) => { - let value = u32::from_le_bytes([ - image_data[target_offset], - image_data[target_offset + 1], - image_data[target_offset + 2], - image_data[target_offset + 3], - ]); + let target_end = target_offset + .checked_add(4) + .ok_or_else(|| malformed_error!("PE reloc target overflow"))?; + let val_slice = image_data + .get(target_offset..target_end) + .ok_or(out_of_bounds_error!())?; + let value = u32::from_le_bytes( + val_slice.try_into().map_err(|_| out_of_bounds_error!())?, + ); #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] - let new_value = (i64::from(value) + delta) as u32; - image_data[target_offset..target_offset + 4] - .copy_from_slice(&new_value.to_le_bytes()); + let new_value = i64::from(value).wrapping_add(delta) as u32; + let dst = image_data + .get_mut(target_offset..target_end) + .ok_or(out_of_bounds_error!())?; + dst.copy_from_slice(&new_value.to_le_bytes()); } // 64-bit fixup reloc_type::IMAGE_REL_BASED_DIR64 - if is_64_bit && target_offset + 8 <= image_data.len() => + if is_64_bit + && target_offset + .checked_add(8) + .is_some_and(|t| t <= image_data.len()) => { - let value = u64::from_le_bytes([ - image_data[target_offset], - image_data[target_offset + 1], - image_data[target_offset + 2], - image_data[target_offset + 3], - image_data[target_offset + 4], - image_data[target_offset + 5], - image_data[target_offset + 6], - image_data[target_offset + 7], - ]); - let new_value = (value.cast_signed() + delta).cast_unsigned(); - image_data[target_offset..target_offset + 8] - .copy_from_slice(&new_value.to_le_bytes()); + let target_end = target_offset + .checked_add(8) + .ok_or_else(|| malformed_error!("PE reloc target overflow"))?; + let val_slice = image_data + .get(target_offset..target_end) + .ok_or(out_of_bounds_error!())?; + let value = u64::from_le_bytes( + val_slice.try_into().map_err(|_| out_of_bounds_error!())?, + ); + let new_value = value.cast_signed().wrapping_add(delta).cast_unsigned(); + let dst = image_data + .get_mut(target_offset..target_end) + .ok_or(out_of_bounds_error!())?; + dst.copy_from_slice(&new_value.to_le_bytes()); } _ => { // ABSOLUTE (padding), out-of-bounds, or unsupported relocation type - skip @@ -992,7 +1051,10 @@ impl PeLoader { } } - offset += block_size; + offset = match offset.checked_add(block_size) { + Some(v) => v, + None => break, + }; } Ok(()) diff --git a/dotscope/src/emulation/memory/addressspace.rs b/dotscope/src/emulation/memory/addressspace.rs index 7f1c06e8..2b4ee5b4 100644 --- a/dotscope/src/emulation/memory/addressspace.rs +++ b/dotscope/src/emulation/memory/addressspace.rs @@ -380,8 +380,14 @@ impl AddressSpace { /// /// Returns the new lock count (1 for first acquisition). pub fn monitor_enter(&self, object_id: u64) -> u32 { - let mut locks = self.monitor_locks.write().unwrap(); - let count = locks.get(&object_id).copied().unwrap_or(0) + 1; + let Ok(mut locks) = self.monitor_locks.write() else { + return 0; + }; + let count = locks + .get(&object_id) + .copied() + .unwrap_or(0) + .saturating_add(1); *locks = locks.update(object_id, count); count } @@ -392,10 +398,12 @@ impl AddressSpace { /// held and successfully released, `false` if Exit was called without a /// matching Enter (would throw `SynchronizationLockException` in .NET). pub fn monitor_exit(&self, object_id: u64) -> bool { - let mut locks = self.monitor_locks.write().unwrap(); + let Ok(mut locks) = self.monitor_locks.write() else { + return false; + }; match locks.get(&object_id).copied() { Some(count) if count > 1 => { - *locks = locks.update(object_id, count - 1); + *locks = locks.update(object_id, count.saturating_sub(1)); true } Some(1) => { @@ -467,7 +475,7 @@ impl AddressSpace { /// Returns an error if the mapping fails. pub fn map(&self, region: MemoryRegion) -> Result { let size = region.size(); - let aligned_size = (size + 0xFFF) & !0xFFF; // Page align + let aligned_size = size.saturating_add(0xFFF) & !0xFFF; // Page align // Find next available address let base = self @@ -611,7 +619,7 @@ impl AddressSpace { // Check pinned arrays first if let Ok(pins) = self.pinned_arrays.read() { for entry in pins.values() { - let end = entry.base_addr + entry.byte_length as u64; + let end = entry.base_addr.saturating_add(entry.byte_length as u64); if address >= entry.base_addr && address < end { return true; } @@ -727,14 +735,14 @@ impl AddressSpace { // Calculate end page let end_addr = address.saturating_add(size as u64); - let end_page = (end_addr + Self::PAGE_SIZE - 1) & !(Self::PAGE_SIZE - 1); + let end_page = end_addr.saturating_add(Self::PAGE_SIZE - 1) & !(Self::PAGE_SIZE - 1); // Update protection for all affected pages if let Ok(mut overrides) = self.protection_overrides.write() { let mut page = start_page; while page < end_page { overrides.insert(page, new_protection); - page += Self::PAGE_SIZE; + page = page.saturating_add(Self::PAGE_SIZE); } } @@ -827,7 +835,7 @@ impl AddressSpace { /// Used for pinned arrays where the backing store is the managed heap. /// The returned address is guaranteed not to conflict with other allocations. pub fn reserve_address_range(&self, size: usize) -> u64 { - let aligned_size = (size + 0xFFF) & !0xFFF; // Align to 4KB + let aligned_size = size.saturating_add(0xFFF) & !0xFFF; // Align to 4KB self.next_address .fetch_add(aligned_size as u64, Ordering::SeqCst) } @@ -855,7 +863,11 @@ impl AddressSpace { let entry = PinnedArrayEntry { array_ref, base_addr, - byte_length: element_size * element_count, + byte_length: element_size.checked_mul(element_count).ok_or_else(|| { + EmulationError::InternalError { + description: "pinned array byte length overflow".to_string(), + } + })?, element_size, }; let mut pins = self.pinned_arrays.write().map_err(|_| { @@ -881,8 +893,9 @@ impl AddressSpace { } for entry in pins.values() { - let end = entry.base_addr + entry.byte_length as u64; - if addr >= entry.base_addr && addr + len as u64 <= end { + let end = entry.base_addr.saturating_add(entry.byte_length as u64); + let read_end = addr.saturating_add(len as u64); + if addr >= entry.base_addr && read_end <= end { return Some(self.read_pinned_bytes(entry, addr, len)); } } @@ -896,14 +909,19 @@ impl AddressSpace { addr: u64, len: usize, ) -> Result> { - let byte_offset = (addr - entry.base_addr) as usize; + let byte_offset = + addr.checked_sub(entry.base_addr) + .ok_or_else(|| EmulationError::InvalidAddress { + address: addr, + reason: "pinned array address underflow".to_string(), + })? as usize; let heap = self.managed_heap(); let mut result = vec![0u8; len]; if entry.element_size == 1 { // Byte array fast path: each element is one byte for (i, slot) in result.iter_mut().enumerate().take(len) { - let elem_idx = byte_offset + i; + let elem_idx = byte_offset.saturating_add(i); match heap.get_array_element(entry.array_ref, elem_idx) { Ok(EmValue::I32(v)) => { #[allow(clippy::cast_sign_loss)] @@ -916,8 +934,9 @@ impl AddressSpace { } } else { // Multi-byte element path: deserialize elements to bytes - let start_elem = byte_offset / entry.element_size; - let end_elem = (byte_offset + len).div_ceil(entry.element_size); + let start_elem = byte_offset.checked_div(entry.element_size).unwrap_or(0); + let read_end_offset = byte_offset.saturating_add(len); + let end_elem = read_end_offset.div_ceil(entry.element_size); let mut elem_buf = vec![0u8; entry.element_size]; for elem_idx in start_elem..end_elem { @@ -927,11 +946,13 @@ impl AddressSpace { .unwrap_or(EmValue::I32(0)), &mut elem_buf, ); - let elem_byte_start = elem_idx * entry.element_size; + let elem_byte_start = elem_idx.saturating_mul(entry.element_size); for (j, &b) in elem_buf.iter().enumerate() { - let abs_byte = elem_byte_start + j; - if abs_byte >= byte_offset && abs_byte < byte_offset + len { - result[abs_byte - byte_offset] = b; + let abs_byte = elem_byte_start.saturating_add(j); + if abs_byte >= byte_offset && abs_byte < read_end_offset { + if let Some(slot) = result.get_mut(abs_byte.saturating_sub(byte_offset)) { + *slot = b; + } } } } @@ -953,8 +974,9 @@ impl AddressSpace { } for entry in pins.values() { - let end = entry.base_addr + entry.byte_length as u64; - if addr >= entry.base_addr && addr + data.len() as u64 <= end { + let end = entry.base_addr.saturating_add(entry.byte_length as u64); + let write_end = addr.saturating_add(data.len() as u64); + if addr >= entry.base_addr && write_end <= end { return Some(self.write_pinned_bytes(entry, addr, data)); } } @@ -963,22 +985,28 @@ impl AddressSpace { /// Writes bytes to a pinned array by delegating to the managed heap. fn write_pinned_bytes(&self, entry: &PinnedArrayEntry, addr: u64, data: &[u8]) -> Result<()> { - let byte_offset = (addr - entry.base_addr) as usize; + let byte_offset = + addr.checked_sub(entry.base_addr) + .ok_or_else(|| EmulationError::InvalidAddress { + address: addr, + reason: "pinned array address underflow".to_string(), + })? as usize; let heap = self.managed_heap(); if entry.element_size == 1 { // Byte array fast path: each byte is one element for (i, &byte) in data.iter().enumerate() { - let elem_idx = byte_offset + i; + let elem_idx = byte_offset.saturating_add(i); heap.set_array_element(entry.array_ref, elem_idx, EmValue::I32(i32::from(byte)))?; } } else { // Multi-byte element path: read-modify-write for partial elements - let start_elem = byte_offset / entry.element_size; - let end_elem = (byte_offset + data.len()).div_ceil(entry.element_size); + let start_elem = byte_offset.checked_div(entry.element_size).unwrap_or(0); + let write_end_offset = byte_offset.saturating_add(data.len()); + let end_elem = write_end_offset.div_ceil(entry.element_size); for elem_idx in start_elem..end_elem { - let elem_byte_start = elem_idx * entry.element_size; + let elem_byte_start = elem_idx.saturating_mul(entry.element_size); let mut elem_buf = vec![0u8; entry.element_size]; // Read existing element value @@ -991,9 +1019,11 @@ impl AddressSpace { // Overwrite the affected bytes for (j, byte) in elem_buf.iter_mut().enumerate() { - let abs_byte = elem_byte_start + j; - if abs_byte >= byte_offset && abs_byte < byte_offset + data.len() { - *byte = data[abs_byte - byte_offset]; + let abs_byte = elem_byte_start.saturating_add(j); + if abs_byte >= byte_offset && abs_byte < write_end_offset { + if let Some(&src) = data.get(abs_byte.saturating_sub(byte_offset)) { + *byte = src; + } } } @@ -1007,27 +1037,17 @@ impl AddressSpace { /// Serializes an `EmValue` to little-endian bytes. fn emvalue_to_bytes(value: &EmValue, buf: &mut [u8]) { - match value { - EmValue::I32(v) => { - let bytes = v.to_le_bytes(); - let copy_len = buf.len().min(4); - buf[..copy_len].copy_from_slice(&bytes[..copy_len]); - } - EmValue::I64(v) | EmValue::NativeInt(v) => { - let bytes = v.to_le_bytes(); - let copy_len = buf.len().min(8); - buf[..copy_len].copy_from_slice(&bytes[..copy_len]); - } - EmValue::F32(v) => { - let bytes = v.to_le_bytes(); - let copy_len = buf.len().min(4); - buf[..copy_len].copy_from_slice(&bytes[..copy_len]); - } - EmValue::F64(v) => { - let bytes = v.to_le_bytes(); - let copy_len = buf.len().min(8); - buf[..copy_len].copy_from_slice(&bytes[..copy_len]); + fn copy_le(buf: &mut [u8], bytes: &[u8]) { + let copy_len = buf.len().min(bytes.len()); + if let (Some(dst), Some(src)) = (buf.get_mut(..copy_len), bytes.get(..copy_len)) { + dst.copy_from_slice(src); } + } + match value { + EmValue::I32(v) => copy_le(buf, &v.to_le_bytes()), + EmValue::I64(v) | EmValue::NativeInt(v) => copy_le(buf, &v.to_le_bytes()), + EmValue::F32(v) => copy_le(buf, &v.to_le_bytes()), + EmValue::F64(v) => copy_le(buf, &v.to_le_bytes()), _ => buf.fill(0), } } @@ -1035,12 +1055,22 @@ impl AddressSpace { /// Deserializes little-endian bytes to an `EmValue`. fn bytes_to_emvalue(bytes: &[u8]) -> EmValue { match bytes.len() { - 1 => EmValue::I32(i32::from(bytes[0])), - 2 => EmValue::I32(i32::from(i16::from_le_bytes([bytes[0], bytes[1]]))), - 4 => EmValue::I32(i32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])), - 8 => EmValue::I64(i64::from_le_bytes([ - bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], - ])), + 1 => match bytes.first() { + Some(&b) => EmValue::I32(i32::from(b)), + None => EmValue::I32(0), + }, + 2 => match <[u8; 2]>::try_from(bytes) { + Ok(arr) => EmValue::I32(i32::from(i16::from_le_bytes(arr))), + Err(_) => EmValue::I32(0), + }, + 4 => match <[u8; 4]>::try_from(bytes) { + Ok(arr) => EmValue::I32(i32::from_le_bytes(arr)), + Err(_) => EmValue::I32(0), + }, + 8 => match <[u8; 8]>::try_from(bytes) { + Ok(arr) => EmValue::I64(i64::from_le_bytes(arr)), + Err(_) => EmValue::I32(0), + }, _ => EmValue::I32(0), } } diff --git a/dotscope/src/emulation/memory/arguments.rs b/dotscope/src/emulation/memory/arguments.rs index 0b94430b..e46ac001 100644 --- a/dotscope/src/emulation/memory/arguments.rs +++ b/dotscope/src/emulation/memory/arguments.rs @@ -132,14 +132,10 @@ impl ArgumentStorage { /// /// Returns [`EmulationError::ArgumentIndexOutOfBounds`] if index is out of bounds. pub fn get(&self, index: usize) -> Result<&EmValue> { - if index >= self.values.len() { - return Err(EmulationError::ArgumentIndexOutOfBounds { - index, - count: self.values.len(), - } - .into()); - } - Ok(&self.values[index]) + let count = self.values.len(); + self.values + .get(index) + .ok_or_else(|| EmulationError::ArgumentIndexOutOfBounds { index, count }.into()) } /// Gets a mutable reference to an argument. @@ -152,14 +148,10 @@ impl ArgumentStorage { /// /// Returns [`EmulationError::ArgumentIndexOutOfBounds`] if index is out of bounds. pub fn get_mut(&mut self, index: usize) -> Result<&mut EmValue> { - if index >= self.values.len() { - return Err(EmulationError::ArgumentIndexOutOfBounds { - index, - count: self.values.len(), - } - .into()); - } - Ok(&mut self.values[index]) + let count = self.values.len(); + self.values + .get_mut(index) + .ok_or_else(|| EmulationError::ArgumentIndexOutOfBounds { index, count }.into()) } /// Sets the value of an argument. @@ -175,19 +167,13 @@ impl ArgumentStorage { /// /// Returns error if index is out of bounds or type mismatches. pub fn set(&mut self, index: usize, value: EmValue) -> Result<()> { - if index >= self.values.len() { - return Err(EmulationError::ArgumentIndexOutOfBounds { - index, - count: self.values.len(), - } - .into()); - } - - // Type check (relaxed for compatible types and symbolic values) - let expected = &self.types[index]; + let count = self.values.len(); + let expected = self.types.get(index).ok_or_else(|| -> crate::Error { + EmulationError::ArgumentIndexOutOfBounds { index, count }.into() + })?; let found = value.cil_flavor(); - // Check type compatibility using CilFlavor's stack compatibility rules + // Type check (relaxed for compatible types and symbolic values) if !expected.is_stack_assignable_from(&found) && !value.is_symbolic() { return Err(EmulationError::ArgumentFlavorMismatch { index, @@ -197,7 +183,10 @@ impl ArgumentStorage { .into()); } - self.values[index] = value; + let slot = self.values.get_mut(index).ok_or_else(|| -> crate::Error { + EmulationError::ArgumentIndexOutOfBounds { index, count }.into() + })?; + *slot = value; Ok(()) } @@ -211,14 +200,10 @@ impl ArgumentStorage { /// /// Returns error if index is out of bounds. pub fn get_type(&self, index: usize) -> Result<&CilFlavor> { - if index >= self.types.len() { - return Err(EmulationError::ArgumentIndexOutOfBounds { - index, - count: self.types.len(), - } - .into()); - } - Ok(&self.types[index]) + let count = self.types.len(); + self.types + .get(index) + .ok_or_else(|| EmulationError::ArgumentIndexOutOfBounds { index, count }.into()) } /// Gets the `this` reference for instance methods. @@ -228,8 +213,8 @@ impl ArgumentStorage { /// `Some(&EmValue)` if this is an instance method, `None` otherwise. #[must_use] pub fn this(&self) -> Option<&EmValue> { - if self.has_this && !self.values.is_empty() { - Some(&self.values[0]) + if self.has_this { + self.values.first() } else { None } diff --git a/dotscope/src/emulation/memory/heap/arrays.rs b/dotscope/src/emulation/memory/heap/arrays.rs index e429e026..38d9e9b6 100644 --- a/dotscope/src/emulation/memory/heap/arrays.rs +++ b/dotscope/src/emulation/memory/heap/arrays.rs @@ -94,14 +94,14 @@ impl ManagedHeap { })?; match state.objects.get(&heap_ref.id()) { Some(HeapObject::Array { elements, .. }) => { - if index >= elements.len() { + if let Some(elem) = elements.get(index) { + Ok(elem.clone()) + } else { Err(EmulationError::ArrayIndexOutOfBounds { index: i64::try_from(index).unwrap_or(i64::MAX), length: elements.len(), } .into()) - } else { - Ok(elements[index].clone()) } } Some(other) => Err(EmulationError::HeapTypeMismatch { @@ -134,15 +134,16 @@ impl ManagedHeap { })?; match state.objects.get_mut(&heap_ref.id()) { Some(HeapObject::Array { elements, .. }) => { - if index >= elements.len() { + let len = elements.len(); + if let Some(slot) = elements.get_mut(index) { + *slot = value; + Ok(()) + } else { Err(EmulationError::ArrayIndexOutOfBounds { index: i64::try_from(index).unwrap_or(i64::MAX), - length: elements.len(), + length: len, } .into()) - } else { - elements[index] = value; - Ok(()) } } Some(other) => Err(EmulationError::HeapTypeMismatch { @@ -295,7 +296,8 @@ impl ManagedHeap { let element_size = element_type.element_size(ptr_size); match element_size { Some(element_size) => { - let mut bytes = Vec::with_capacity(elements.len() * element_size); + let mut bytes = + Vec::with_capacity(elements.len().saturating_mul(element_size)); let mut valid = true; for e in elements { match e { diff --git a/dotscope/src/emulation/memory/heap/collections.rs b/dotscope/src/emulation/memory/heap/collections.rs index 017c0b84..07b67ae6 100644 --- a/dotscope/src/emulation/memory/heap/collections.rs +++ b/dotscope/src/emulation/memory/heap/collections.rs @@ -404,8 +404,8 @@ impl ManagedHeap { description: "managed heap", })?; if let Some(HeapObject::List { elements }) = state.objects.get_mut(&heap_ref.id()) { - if index < elements.len() { - elements[index] = value; + if let Some(slot) = elements.get_mut(index) { + *slot = value; } } Ok(()) diff --git a/dotscope/src/emulation/memory/heap/mod.rs b/dotscope/src/emulation/memory/heap/mod.rs index aa5f831a..ceb9d3d9 100644 --- a/dotscope/src/emulation/memory/heap/mod.rs +++ b/dotscope/src/emulation/memory/heap/mod.rs @@ -623,11 +623,18 @@ impl HeapObject { /// object header overhead and data storage, but is not exact. #[must_use] pub fn estimated_size(&self) -> usize { + // All sizes are estimates for in-memory tracking; saturating arithmetic + // is correct here — overflow only matters for pathological inputs and + // saturated `usize::MAX` correctly signals "huge" for limit checks. match self { - HeapObject::String(s) => 24 + s.len() * 2, // Object header + UTF-16 - HeapObject::Array { elements, .. } => 24 + elements.len() * 8, - HeapObject::MultiArray { elements, .. } => 32 + elements.len() * 8, - HeapObject::Object { fields, .. } => 24 + fields.len() * 16, + HeapObject::String(s) => s.len().saturating_mul(2).saturating_add(24), // Object header + UTF-16 + HeapObject::Array { elements, .. } => { + elements.len().saturating_mul(8).saturating_add(24) + } + HeapObject::MultiArray { elements, .. } => { + elements.len().saturating_mul(8).saturating_add(32) + } + HeapObject::Object { fields, .. } => fields.len().saturating_mul(16).saturating_add(24), HeapObject::TypedReference { .. } | HeapObject::BoxedValue { .. } | HeapObject::CryptoAlgorithm { .. } @@ -638,28 +645,38 @@ impl HeapObject { | HeapObject::ReflectionParameter { .. } | HeapObject::DynamicMethod { .. } => 32, HeapObject::ILGenerator { .. } => 128, - HeapObject::CryptoTransform { key, iv, .. } => 48 + key.len() + iv.len(), + HeapObject::CryptoTransform { key, iv, .. } => { + 48usize.saturating_add(key.len()).saturating_add(iv.len()) + } HeapObject::Delegate { .. } => 48, HeapObject::Encoding { .. } => 24, - HeapObject::SymmetricAlgorithm { key, iv, .. } => { - 32 + key.as_ref().map_or(0, Vec::len) + iv.as_ref().map_or(0, Vec::len) + HeapObject::SymmetricAlgorithm { key, iv, .. } => 32usize + .saturating_add(key.as_ref().map_or(0, Vec::len)) + .saturating_add(iv.as_ref().map_or(0, Vec::len)), + HeapObject::Dictionary { entries } => { + entries.len().saturating_mul(32).saturating_add(48) + } + HeapObject::List { elements } => elements.len().saturating_mul(8).saturating_add(32), + HeapObject::StringBuilder { buffer, .. } => 32usize.saturating_add(buffer.len()), + HeapObject::Stack { elements } => elements.len().saturating_mul(8).saturating_add(32), + HeapObject::Queue { elements } => elements.len().saturating_mul(8).saturating_add(32), + HeapObject::HashSet { elements } => { + elements.len().saturating_mul(16).saturating_add(48) } - HeapObject::Dictionary { entries } => 48 + entries.len() * 32, - HeapObject::List { elements } => 32 + elements.len() * 8, - HeapObject::StringBuilder { buffer, .. } => 32 + buffer.len(), - HeapObject::Stack { elements } => 32 + elements.len() * 8, - HeapObject::Queue { elements } => 32 + elements.len() * 8, - HeapObject::HashSet { elements } => 48 + elements.len() * 16, - HeapObject::KeyDerivation { password, salt, .. } => 48 + password.len() + salt.len(), - HeapObject::Stream { data, .. } => 32 + data.len(), + HeapObject::KeyDerivation { password, salt, .. } => 48usize + .saturating_add(password.len()) + .saturating_add(salt.len()), + HeapObject::Stream { data, .. } => 32usize.saturating_add(data.len()), HeapObject::CryptoStream { transformed_data, write_buffer, .. - } => 64 + transformed_data.as_ref().map_or(0, Vec::len) + write_buffer.len(), + } => 64usize + .saturating_add(transformed_data.as_ref().map_or(0, Vec::len)) + .saturating_add(write_buffer.len()), HeapObject::CompressedStream { decompressed_data, .. - } => 48 + decompressed_data.as_ref().map_or(0, Vec::len), + } => 48usize.saturating_add(decompressed_data.as_ref().map_or(0, Vec::len)), } } } @@ -967,7 +984,8 @@ impl ManagedHeap { /// Creates a managed heap with default size (64MB). #[must_use] pub fn default_size() -> Self { - Self::new(64 * 1024 * 1024) + // 64 * 1024 * 1024 = 64 MiB + Self::new(64usize.saturating_mul(1024).saturating_mul(1024)) } /// Checks if allocation would exceed memory limit. @@ -976,7 +994,7 @@ impl ManagedHeap { /// allocation would exceed [`max_size`](Self::max_size). pub(crate) fn check_allocation(&self, size: usize) -> Result<()> { let current = self.current_size.load(Ordering::Relaxed); - if current + size > self.max_size { + if current.saturating_add(size) > self.max_size { return Err(EmulationError::HeapMemoryLimitExceeded { current, limit: self.max_size, diff --git a/dotscope/src/emulation/memory/heap/streams.rs b/dotscope/src/emulation/memory/heap/streams.rs index a126d6a6..fbf77104 100644 --- a/dotscope/src/emulation/memory/heap/streams.rs +++ b/dotscope/src/emulation/memory/heap/streams.rs @@ -78,9 +78,8 @@ impl ManagedHeap { description: "managed heap", })?; if let Some(HeapObject::Stream { data, position }) = state.objects.get_mut(&heap_ref.id()) { - if *position < data.len() { - let byte = data[*position]; - *position += 1; + if let Some(&byte) = data.get(*position) { + *position = position.saturating_add(1); return Ok(Some(byte)); } } @@ -102,11 +101,13 @@ impl ManagedHeap { description: "managed heap", })?; if let Some(HeapObject::Stream { data, position }) = state.objects.get_mut(&heap_ref.id()) { - if *position + N <= data.len() { - let mut buf = [0u8; N]; - buf.copy_from_slice(&data[*position..*position + N]); - *position += N; - return Ok(Some(buf)); + if let Some(end) = position.checked_add(N) { + if let Some(slice) = data.get(*position..end) { + let mut buf = [0u8; N]; + buf.copy_from_slice(slice); + *position = end; + return Ok(Some(buf)); + } } } Ok(None) @@ -130,8 +131,12 @@ impl ManagedHeap { if let Some(HeapObject::Stream { data, position }) = state.objects.get_mut(&heap_ref.id()) { let available = data.len().saturating_sub(*position); let to_read = count.min(available); - let bytes = data[*position..*position + to_read].to_vec(); - *position += to_read; + let end = position.saturating_add(to_read); + let bytes = data + .get(*position..end) + .map(<[u8]>::to_vec) + .unwrap_or_default(); + *position = end; return Ok(Some(bytes)); } Ok(None) @@ -263,19 +268,26 @@ impl ManagedHeap { let write_len = bytes.len(); // Ensure capacity - let required_len = *position + write_len; + let required_len = + position + .checked_add(write_len) + .ok_or(EmulationError::InternalError { + description: "stream write length overflow".into(), + })?; if data.len() < required_len { data.resize(required_len, 0); } // Copy bytes to the stream - data[*position..*position + write_len].copy_from_slice(bytes); + let dst = + data.get_mut(*position..required_len) + .ok_or(EmulationError::InternalError { + description: "stream write slice OOB".into(), + })?; + dst.copy_from_slice(bytes); // Advance position - *position += write_len; - - // Update size estimate - // (We don't track size changes precisely here, but that's acceptable) + *position = required_len; Ok(write_len) } else { @@ -349,10 +361,10 @@ impl ManagedHeap { // Update size tracking atomically if new_size >= old_size { self.current_size - .fetch_add(new_size - old_size, Ordering::Relaxed); + .fetch_add(new_size.saturating_sub(old_size), Ordering::Relaxed); } else { self.current_size - .fetch_sub(old_size - new_size, Ordering::Relaxed); + .fetch_sub(old_size.saturating_sub(new_size), Ordering::Relaxed); } Ok(true) } else { @@ -476,10 +488,10 @@ impl ManagedHeap { // Update size tracking atomically if new_size >= old_size { self.current_size - .fetch_add(new_size - old_size, Ordering::Relaxed); + .fetch_add(new_size.saturating_sub(old_size), Ordering::Relaxed); } else { self.current_size - .fetch_sub(old_size - new_size, Ordering::Relaxed); + .fetch_sub(old_size.saturating_sub(new_size), Ordering::Relaxed); } Ok(true) } else { @@ -560,10 +572,10 @@ impl ManagedHeap { // Update size tracking atomically if new_size >= old_size { self.current_size - .fetch_add(new_size - old_size, Ordering::Relaxed); + .fetch_add(new_size.saturating_sub(old_size), Ordering::Relaxed); } else { self.current_size - .fetch_sub(old_size - new_size, Ordering::Relaxed); + .fetch_sub(old_size.saturating_sub(new_size), Ordering::Relaxed); } Ok(true) } @@ -598,8 +610,12 @@ impl ManagedHeap { { let available = data.len().saturating_sub(*transformed_pos); let to_read = count.min(available); - let result = data[*transformed_pos..*transformed_pos + to_read].to_vec(); - *transformed_pos += to_read; + let end = transformed_pos.saturating_add(to_read); + let result = data + .get(*transformed_pos..end) + .map(<[u8]>::to_vec) + .unwrap_or_default(); + *transformed_pos = end; Ok(Some(result)) } else { Ok(None) @@ -743,10 +759,10 @@ impl ManagedHeap { // Update size tracking atomically if new_size >= old_size { self.current_size - .fetch_add(new_size - old_size, Ordering::Relaxed); + .fetch_add(new_size.saturating_sub(old_size), Ordering::Relaxed); } else { self.current_size - .fetch_sub(old_size - new_size, Ordering::Relaxed); + .fetch_sub(old_size.saturating_sub(new_size), Ordering::Relaxed); } Ok(true) } else { @@ -826,10 +842,10 @@ impl ManagedHeap { // Update size tracking atomically if data_len >= old_cached_size { self.current_size - .fetch_add(data_len - old_cached_size, Ordering::Relaxed); + .fetch_add(data_len.saturating_sub(old_cached_size), Ordering::Relaxed); } else { self.current_size - .fetch_sub(old_cached_size - data_len, Ordering::Relaxed); + .fetch_sub(old_cached_size.saturating_sub(data_len), Ordering::Relaxed); } Ok(()) } @@ -870,8 +886,12 @@ impl ManagedHeap { }) => { let available = data.len().saturating_sub(*read_position); let to_read = count.min(available); - let bytes = data[*read_position..*read_position + to_read].to_vec(); - *read_position += to_read; + let end = read_position.saturating_add(to_read); + let bytes = data + .get(*read_position..end) + .map(<[u8]>::to_vec) + .unwrap_or_default(); + *read_position = end; Some(bytes) } _ => None, diff --git a/dotscope/src/emulation/memory/locals.rs b/dotscope/src/emulation/memory/locals.rs index ec1c71dd..6e345dd6 100644 --- a/dotscope/src/emulation/memory/locals.rs +++ b/dotscope/src/emulation/memory/locals.rs @@ -115,14 +115,10 @@ impl LocalVariables { /// /// Returns [`EmulationError::LocalIndexOutOfBounds`] if index is invalid. pub fn get(&self, index: usize) -> Result<&EmValue> { - if index >= self.values.len() { - return Err(EmulationError::LocalIndexOutOfBounds { - index, - count: self.values.len(), - } - .into()); - } - Ok(&self.values[index]) + let count = self.values.len(); + self.values + .get(index) + .ok_or_else(|| EmulationError::LocalIndexOutOfBounds { index, count }.into()) } /// Gets a mutable reference to a local variable. @@ -135,14 +131,10 @@ impl LocalVariables { /// /// Returns [`EmulationError::LocalIndexOutOfBounds`] if index is invalid. pub fn get_mut(&mut self, index: usize) -> Result<&mut EmValue> { - if index >= self.values.len() { - return Err(EmulationError::LocalIndexOutOfBounds { - index, - count: self.values.len(), - } - .into()); - } - Ok(&mut self.values[index]) + let count = self.values.len(); + self.values + .get_mut(index) + .ok_or_else(|| EmulationError::LocalIndexOutOfBounds { index, count }.into()) } /// Sets the value of a local variable. @@ -167,13 +159,10 @@ impl LocalVariables { /// When a type mismatch is detected, the local's declared type is updated to /// match the stored value, ensuring subsequent loads work correctly. pub fn set(&mut self, index: usize, value: EmValue) -> Result<()> { - if index >= self.values.len() { - return Err(EmulationError::LocalIndexOutOfBounds { - index, - count: self.values.len(), - } - .into()); - } + let count = self.values.len(); + let slot = self.values.get_mut(index).ok_or_else(|| -> crate::Error { + EmulationError::LocalIndexOutOfBounds { index, count }.into() + })?; // Match .NET runtime behavior: accept all types for local stores. // If the stored value's type differs from the declared type, update the @@ -181,12 +170,15 @@ impl LocalVariables { // CFF obfuscation where different code paths store different types. if !value.is_symbolic() { let found = value.cil_flavor(); - if !self.types[index].is_stack_assignable_from(&found) { - self.types[index] = found; + let type_slot = self.types.get_mut(index).ok_or_else(|| -> crate::Error { + EmulationError::LocalIndexOutOfBounds { index, count }.into() + })?; + if !type_slot.is_stack_assignable_from(&found) { + *type_slot = found; } } - self.values[index] = value; + *slot = value; Ok(()) } @@ -200,14 +192,10 @@ impl LocalVariables { /// /// Returns error if index is invalid. pub fn get_type(&self, index: usize) -> Result<&CilFlavor> { - if index >= self.types.len() { - return Err(EmulationError::LocalIndexOutOfBounds { - index, - count: self.types.len(), - } - .into()); - } - Ok(&self.types[index]) + let count = self.types.len(); + self.types + .get(index) + .ok_or_else(|| EmulationError::LocalIndexOutOfBounds { index, count }.into()) } /// Returns the number of local variables. diff --git a/dotscope/src/emulation/memory/page.rs b/dotscope/src/emulation/memory/page.rs index a04d149d..0cb4583c 100644 --- a/dotscope/src/emulation/memory/page.rs +++ b/dotscope/src/emulation/memory/page.rs @@ -96,7 +96,9 @@ impl Page { pub fn from_slice(data: &[u8]) -> Self { let mut page_data = [0u8; PAGE_SIZE]; let copy_len = data.len().min(PAGE_SIZE); - page_data[..copy_len].copy_from_slice(&data[..copy_len]); + if let (Some(dst), Some(src)) = (page_data.get_mut(..copy_len), data.get(..copy_len)) { + dst.copy_from_slice(src); + } Self::new(page_data) } @@ -122,9 +124,15 @@ impl Page { .map_err(|_| EmulationError::LockPoisoned { description: "page local buffer", })?; - Ok(local - .as_ref() - .map_or(self.backing[offset], |data| data[offset])) + let byte = local.as_ref().map_or_else( + || self.backing.get(offset).copied(), + |data| data.get(offset).copied(), + ); + byte.ok_or(EmulationError::PageOutOfBounds { + offset, + size: 1, + page_size: PAGE_SIZE, + }) } /// Reads a range of bytes into the provided buffer. @@ -154,9 +162,15 @@ impl Page { .map_err(|_| EmulationError::LockPoisoned { description: "page local buffer", })?; - let src = local - .as_ref() - .map_or(&self.backing[offset..end], |data| &data[offset..end]); + let src = local.as_ref().map_or_else( + || self.backing.get(offset..end), + |data| data.get(offset..end), + ); + let src = src.ok_or(EmulationError::PageOutOfBounds { + offset, + size: buf.len(), + page_size: PAGE_SIZE, + })?; buf.copy_from_slice(src); Ok(()) } @@ -197,7 +211,12 @@ impl Page { description: "page local buffer", })?; let buf = local.get_or_insert_with(|| Box::new(*self.backing)); - buf[offset] = value; + let slot = buf.get_mut(offset).ok_or(EmulationError::PageOutOfBounds { + offset, + size: 1, + page_size: PAGE_SIZE, + })?; + *slot = value; Ok(()) } @@ -231,7 +250,14 @@ impl Page { description: "page local buffer", })?; let buf = local.get_or_insert_with(|| Box::new(*self.backing)); - buf[offset..end].copy_from_slice(data); + let dst = buf + .get_mut(offset..end) + .ok_or(EmulationError::PageOutOfBounds { + offset, + size: data.len(), + page_size: PAGE_SIZE, + })?; + dst.copy_from_slice(data); Ok(()) } diff --git a/dotscope/src/emulation/memory/region.rs b/dotscope/src/emulation/memory/region.rs index 9c35fce2..bd38e4e4 100644 --- a/dotscope/src/emulation/memory/region.rs +++ b/dotscope/src/emulation/memory/region.rs @@ -303,9 +303,9 @@ impl MemoryRegion { let mut pages = Vec::with_capacity(num_pages); for i in 0..num_pages { - let start = i * PAGE_SIZE; - let end = (start + PAGE_SIZE).min(data.len()); - let chunk = &data[start..end]; + let start = i.saturating_mul(PAGE_SIZE); + let end = start.saturating_add(PAGE_SIZE).min(data.len()); + let chunk = data.get(start..end).unwrap_or(&[]); pages.push(Page::from_slice(chunk)); } @@ -409,7 +409,7 @@ impl MemoryRegion { /// Returns the end address (exclusive) of this region. #[must_use] pub fn end(&self) -> u64 { - self.base + self.size as u64 + self.base.saturating_add(self.size as u64) } /// Returns `true` if the address falls within this region. @@ -421,7 +421,10 @@ impl MemoryRegion { /// Returns `true` if the entire address range falls within this region. #[must_use] pub fn contains_range(&self, address: u64, len: usize) -> bool { - address >= self.base && (address + len as u64) <= self.end() + let Some(range_end) = address.checked_add(len as u64) else { + return false; + }; + address >= self.base && range_end <= self.end() } /// Returns the default protection flags for this region. @@ -470,11 +473,10 @@ impl MemoryRegion { if let Some(ref sections) = self.sections { // Safe: offset within a memory region always fits in u32 #[allow(clippy::cast_possible_truncation)] - let rva = (address - self.base) as u32; + let rva = address.saturating_sub(self.base) as u32; for section in sections.iter() { - if rva >= section.virtual_address - && rva < section.virtual_address + section.virtual_size - { + let section_end = section.virtual_address.saturating_add(section.virtual_size); + if rva >= section.virtual_address && rva < section_end { return Ok(section.protection); } } @@ -507,32 +509,26 @@ impl MemoryRegion { // Safe: offset within a memory region always fits in usize #[allow(clippy::cast_possible_truncation)] - let offset = (address - self.base) as usize; + let offset = address.saturating_sub(self.base) as usize; let mut result = vec![0u8; len]; let mut bytes_read = 0; while bytes_read < len { - let current_offset = offset + bytes_read; + let current_offset = offset.checked_add(bytes_read)?; let page_index = current_offset / PAGE_SIZE; let page_offset = current_offset % PAGE_SIZE; - if page_index >= self.pages.len() { - return None; - } + let page = self.pages.get(page_index)?; - let bytes_in_page = (PAGE_SIZE - page_offset).min(len - bytes_read); - let page = &self.pages[page_index]; + let remaining = len.checked_sub(bytes_read)?; + let bytes_in_page = PAGE_SIZE.saturating_sub(page_offset).min(remaining); + let read_end = bytes_read.checked_add(bytes_in_page)?; + let dest = result.get_mut(bytes_read..read_end)?; - if page - .read( - page_offset, - &mut result[bytes_read..bytes_read + bytes_in_page], - ) - .is_err() - { + if page.read(page_offset, dest).is_err() { return None; } - bytes_read += bytes_in_page; + bytes_read = bytes_read.saturating_add(bytes_in_page); } Some(result) @@ -561,31 +557,35 @@ impl MemoryRegion { // Safe: offset within a memory region always fits in usize #[allow(clippy::cast_possible_truncation)] - let offset = (address - self.base) as usize; + let offset = address.saturating_sub(self.base) as usize; let mut bytes_written = 0; while bytes_written < bytes.len() { - let current_offset = offset + bytes_written; + let Some(current_offset) = offset.checked_add(bytes_written) else { + return false; + }; let page_index = current_offset / PAGE_SIZE; let page_offset = current_offset % PAGE_SIZE; - if page_index >= self.pages.len() { + let Some(page) = self.pages.get(page_index) else { return false; - } + }; - let bytes_in_page = (PAGE_SIZE - page_offset).min(bytes.len() - bytes_written); - let page = &self.pages[page_index]; + let Some(remaining) = bytes.len().checked_sub(bytes_written) else { + return false; + }; + let bytes_in_page = PAGE_SIZE.saturating_sub(page_offset).min(remaining); + let Some(write_end) = bytes_written.checked_add(bytes_in_page) else { + return false; + }; + let Some(src) = bytes.get(bytes_written..write_end) else { + return false; + }; - if page - .write( - page_offset, - &bytes[bytes_written..bytes_written + bytes_in_page], - ) - .is_err() - { + if page.write(page_offset, src).is_err() { return false; } - bytes_written += bytes_in_page; + bytes_written = bytes_written.saturating_add(bytes_in_page); } true diff --git a/dotscope/src/emulation/memory/stack.rs b/dotscope/src/emulation/memory/stack.rs index 33239f15..6b00b105 100644 --- a/dotscope/src/emulation/memory/stack.rs +++ b/dotscope/src/emulation/memory/stack.rs @@ -171,7 +171,15 @@ impl EvaluationStack { if depth >= self.values.len() { return Err(EmulationError::StackUnderflow.into()); } - Ok(&self.values[self.values.len() - 1 - depth]) + let idx = self + .values + .len() + .checked_sub(1) + .and_then(|n| n.checked_sub(depth)) + .ok_or(EmulationError::StackUnderflow)?; + self.values + .get(idx) + .ok_or_else(|| EmulationError::StackUnderflow.into()) } /// Pops a value and verifies it has the expected CIL flavor. @@ -338,7 +346,9 @@ impl EvaluationStack { if len < 2 { return Err(EmulationError::StackUnderflow.into()); } - self.values.swap(len - 1, len - 2); + let top = len.saturating_sub(1); + let next = len.saturating_sub(2); + self.values.swap(top, next); Ok(()) } diff --git a/dotscope/src/emulation/memory/unmanaged.rs b/dotscope/src/emulation/memory/unmanaged.rs index 6a8d91c8..2830968a 100644 --- a/dotscope/src/emulation/memory/unmanaged.rs +++ b/dotscope/src/emulation/memory/unmanaged.rs @@ -172,7 +172,7 @@ impl UnmanagedMemory { /// Returns [`EmulationError::HeapMemoryLimitExceeded`] if allocation would /// exceed the memory limit. pub fn alloc(&mut self, size: usize) -> Result { - if self.current_size + size > self.max_size { + if self.current_size.saturating_add(size) > self.max_size { return Err(EmulationError::HeapMemoryLimitExceeded { current: self.current_size, limit: self.max_size, @@ -181,12 +181,12 @@ impl UnmanagedMemory { } let address = self.next_address; - self.next_address += size as u64; + let next = self.next_address.saturating_add(size as u64); // Align next allocation to 16 bytes - self.next_address = (self.next_address + 15) & !15; + self.next_address = next.saturating_add(15) & !15; self.regions.insert(address, InternalRegion::new(size)); - self.current_size += size; + self.current_size = self.current_size.saturating_add(size); Ok(UnmanagedRef::new(address)) } @@ -234,9 +234,10 @@ impl UnmanagedMemory { // Search for a region that contains this address (slow path) for (&base, region) in &self.regions { - if region.valid && address >= base && address < base + region.size() as u64 { + let region_end = base.saturating_add(region.size() as u64); + if region.valid && address >= base && address < region_end { #[allow(clippy::cast_possible_truncation)] // Offset bounded by region size - let offset = (address - base) as usize; + let offset = address.saturating_sub(base) as usize; return Some((region, offset)); } } @@ -262,7 +263,8 @@ impl UnmanagedMemory { // Search for containing region if found_base.is_none() { for (&base, region) in &self.regions { - if region.valid && address >= base && address < base + region.size() as u64 { + let region_end = base.saturating_add(region.size() as u64); + if region.valid && address >= base && address < region_end { found_base = Some(base); break; } @@ -272,7 +274,7 @@ impl UnmanagedMemory { if let Some(base) = found_base { if let Some(region) = self.regions.get_mut(&base) { #[allow(clippy::cast_possible_truncation)] // Offset bounded by region size - let offset = (address - base) as usize; + let offset = address.saturating_sub(base) as usize; return Some((region, offset)); } } @@ -303,15 +305,21 @@ impl UnmanagedMemory { reason: "address not in any allocated region", })?; - if offset + size > region.size() { - return Err(EmulationError::InvalidPointer { + let end = offset + .checked_add(size) + .ok_or(EmulationError::InvalidPointer { + address, + reason: "read length overflows", + })?; + let slice = region + .data + .get(offset..end) + .ok_or(EmulationError::InvalidPointer { address, reason: "read would exceed region bounds", - } - .into()); - } + })?; - Ok(region.data[offset..offset + size].to_vec()) + Ok(slice.to_vec()) } /// Writes bytes to unmanaged memory. @@ -333,15 +341,21 @@ impl UnmanagedMemory { reason: "address not in any allocated region", })?; - if offset + data.len() > region.size() { - return Err(EmulationError::InvalidPointer { + let end = offset + .checked_add(data.len()) + .ok_or(EmulationError::InvalidPointer { + address, + reason: "write length overflows", + })?; + let dest = region + .data + .get_mut(offset..end) + .ok_or(EmulationError::InvalidPointer { address, reason: "write would exceed region bounds", - } - .into()); - } + })?; - region.data[offset..offset + data.len()].copy_from_slice(data); + dest.copy_from_slice(data); Ok(()) } @@ -398,15 +412,21 @@ impl UnmanagedMemory { reason: "address not in any allocated region", })?; - if offset + size > region.size() { - return Err(EmulationError::InvalidPointer { + let end = offset + .checked_add(size) + .ok_or(EmulationError::InvalidPointer { + address, + reason: "memset length overflows", + })?; + let dest = region + .data + .get_mut(offset..end) + .ok_or(EmulationError::InvalidPointer { address, reason: "memset would exceed region bounds", - } - .into()); - } + })?; - region.data[offset..offset + size].fill(value); + dest.fill(value); Ok(()) } @@ -486,7 +506,7 @@ impl UnmanagedMemory { } let size = data.len(); - if self.current_size + size > self.max_size { + if self.current_size.saturating_add(size) > self.max_size { return Err(EmulationError::HeapMemoryLimitExceeded { current: self.current_size, limit: self.max_size, @@ -497,12 +517,12 @@ impl UnmanagedMemory { let mut region = InternalRegion::new(size); region.data.copy_from_slice(data); self.regions.insert(address, region); - self.current_size += size; + self.current_size = self.current_size.saturating_add(size); // Update next_address if this allocation would conflict - let end_address = address + size as u64; + let end_address = address.saturating_add(size as u64); if end_address > self.next_address { - self.next_address = (end_address + 15) & !15; + self.next_address = end_address.saturating_add(15) & !15; } Ok(UnmanagedRef::new(address)) diff --git a/dotscope/src/emulation/process/builder.rs b/dotscope/src/emulation/process/builder.rs index 011279c4..f59f00b8 100644 --- a/dotscope/src/emulation/process/builder.rs +++ b/dotscope/src/emulation/process/builder.rs @@ -100,7 +100,7 @@ use crate::{ token::Token, typesystem::PointerSize, }, - CilObject, Result, + CilObject, Error, Result, }; /// Pre-populates static fields with data from the FieldRVA table. @@ -156,28 +156,35 @@ fn populate_fieldrva_statics(assembly: &CilObject, address_space: &AddressSpace) continue; }; - if file_offset + field_type_size > pe_data.len() { + let Some(end) = file_offset.checked_add(field_type_size) else { continue; - } - - // Read the bytes from the PE file - let data = &pe_data[file_offset..file_offset + field_type_size]; + }; + let Some(data) = pe_data.get(file_offset..end) else { + continue; + }; // Convert to EmValue based on size let value = match field_type_size { - 1 => EmValue::I32(i32::from(data[0].cast_signed())), + 1 => { + let Some(b) = data.first() else { continue }; + EmValue::I32(i32::from(b.cast_signed())) + } 2 => { - let bytes = [data[0], data[1]]; + let Ok(bytes) = <[u8; 2]>::try_from(data) else { + continue; + }; EmValue::I32(i32::from(i16::from_le_bytes(bytes))) } 4 => { - let bytes = [data[0], data[1], data[2], data[3]]; + let Ok(bytes) = <[u8; 4]>::try_from(data) else { + continue; + }; EmValue::I32(i32::from_le_bytes(bytes)) } 8 => { - let bytes = [ - data[0], data[1], data[2], data[3], data[4], data[5], data[6], data[7], - ]; + let Ok(bytes) = <[u8; 8]>::try_from(data) else { + continue; + }; EmValue::I64(i64::from_le_bytes(bytes)) } _ => continue, @@ -1232,7 +1239,7 @@ impl ProcessBuilder { let mut writer = if let Some(ref path) = config_arc.tracing.output_path { // File-based tracing - propagate errors to caller TraceWriter::new_file(path, context).map_err(|e| { - crate::Error::TracingError(format!( + Error::TracingError(format!( "Failed to create trace file {}: {e}", path.display() )) diff --git a/dotscope/src/emulation/process/execution.rs b/dotscope/src/emulation/process/execution.rs index c7d0f3cd..20435f98 100644 --- a/dotscope/src/emulation/process/execution.rs +++ b/dotscope/src/emulation/process/execution.rs @@ -376,7 +376,10 @@ impl EmulationProcess { /// /// Returns an error if the instruction limit is exceeded. pub fn increment_instructions(&self, count: u64) -> Result<()> { - let new_count = self.instruction_count.fetch_add(count, Ordering::Relaxed) + count; + let new_count = self + .instruction_count + .fetch_add(count, Ordering::Relaxed) + .saturating_add(count); if self.context.config.limits.max_instructions > 0 && new_count > self.context.config.limits.max_instructions diff --git a/dotscope/src/emulation/runtime/bcl/appdomain.rs b/dotscope/src/emulation/runtime/bcl/appdomain.rs index 2a6e1d53..d150c2f5 100644 --- a/dotscope/src/emulation/runtime/bcl/appdomain.rs +++ b/dotscope/src/emulation/runtime/bcl/appdomain.rs @@ -568,16 +568,16 @@ fn assembly_get_manifest_resource_names_pre( fn delegate_ctor_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { // .ctor(object target, native int methodPtr) — args[0] = target, args[1] = method pointer if ctx.args.len() == 2 { - let target = match &ctx.args[0] { - EmValue::ObjectRef(href) => Some(*href), + let target = match ctx.args.first() { + Some(EmValue::ObjectRef(href)) => Some(*href), _ => None, }; - let method_token = match &ctx.args[1] { - EmValue::UnmanagedPtr(ptr) => Some(Token::new(*ptr as u32)), - EmValue::I32(v) => Some(Token::new(*v as u32)), - EmValue::I64(v) => Some(Token::new(*v as u32)), - EmValue::NativeInt(v) => Some(Token::new(*v as u32)), + let method_token = match ctx.args.get(1) { + Some(EmValue::UnmanagedPtr(ptr)) => Some(Token::new(*ptr as u32)), + Some(EmValue::I32(v)) => Some(Token::new(*v as u32)), + Some(EmValue::I64(v)) => Some(Token::new(*v as u32)), + Some(EmValue::NativeInt(v)) => Some(Token::new(*v as u32)), _ => None, }; diff --git a/dotscope/src/emulation/runtime/bcl/collections/list.rs b/dotscope/src/emulation/runtime/bcl/collections/list.rs index 108f76a5..8335d068 100644 --- a/dotscope/src/emulation/runtime/bcl/collections/list.rs +++ b/dotscope/src/emulation/runtime/bcl/collections/list.rs @@ -545,7 +545,7 @@ fn enumerator_move_next_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) let current_pos = try_hook!(thread.heap().get_field(*enum_ref, pos_field)); if let (EmValue::ObjectRef(list_ref), EmValue::I32(pos)) = (list_href, current_pos) { - let new_pos = pos + 1; + let new_pos = pos.saturating_add(1i32); let count = try_hook!(thread.heap().list_count(list_ref)); try_hook!(thread @@ -608,7 +608,8 @@ fn list_copy_to_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreH .map_or(0, |v| v.max(0) as usize); let elements = try_hook!(thread.heap().list_to_vec(*list_ref)); for (i, elem) in elements.into_iter().enumerate() { - try_hook!(thread.heap().set_array_element(*arr_ref, offset + i, elem)); + let idx = offset.saturating_add(i); + try_hook!(thread.heap().set_array_element(*arr_ref, idx, elem)); } } } @@ -626,8 +627,11 @@ fn list_get_range_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> Pr let elements = try_hook!(thread.heap().list_to_vec(*list_ref)); let start_idx = (*start).max(0) as usize; let cnt = (*count).max(0) as usize; - let end = (start_idx + cnt).min(elements.len()); - let sub = elements[start_idx.min(elements.len())..end].to_vec(); + let end = start_idx.saturating_add(cnt).min(elements.len()); + let begin = start_idx.min(elements.len()); + let sub = elements + .get(begin..end) + .map_or_else(Vec::new, <[_]>::to_vec); let new_list = try_hook!(thread.heap().alloc_list_with_elements(sub)); return PreHookResult::Bypass(Some(EmValue::ObjectRef(new_list))); } diff --git a/dotscope/src/emulation/runtime/bcl/crypto/hashing.rs b/dotscope/src/emulation/runtime/bcl/crypto/hashing.rs index f2d350ec..1a618491 100644 --- a/dotscope/src/emulation/runtime/bcl/crypto/hashing.rs +++ b/dotscope/src/emulation/runtime/bcl/crypto/hashing.rs @@ -189,7 +189,7 @@ fn md5_compute_hash_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> return PreHookResult::Bypass(Some(EmValue::Null)); } - if let EmValue::ObjectRef(handle) = &ctx.args[0] { + if let Some(EmValue::ObjectRef(handle)) = ctx.args.first() { if let Some(bytes) = try_hook!(thread.heap().get_byte_array(*handle)) { let hash = compute_md5(&bytes); match thread.heap().alloc_byte_array(&hash) { @@ -236,11 +236,11 @@ fn hash_algorithm_compute_hash_pre( (default_algo.into(), None) }; - if ctx.args.is_empty() { + let Some(first_arg) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::Null)); - } + }; - match &ctx.args[0] { + match first_arg { EmValue::ObjectRef(handle) => { if let Some(bytes) = try_hook!(thread.heap().get_byte_array(*handle)) { let hash = match &*hash_type { @@ -333,11 +333,11 @@ fn sha256_create_pre(_ctx: &HookContext<'_>, thread: &mut EmulationThread) -> Pr /// - `inputStream`: Stream to read and hash (overload 3) #[cfg(feature = "legacy-crypto")] fn sha1_compute_hash_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(first_arg) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::Null)); - } + }; - match &ctx.args[0] { + match first_arg { EmValue::ObjectRef(handle) => { if let Some(bytes) = try_hook!(thread.heap().get_byte_array(*handle)) { let hash = compute_sha1(&bytes); @@ -370,11 +370,11 @@ fn sha1_compute_hash_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> /// - `count`: Number of bytes to hash (overload 2) /// - `inputStream`: Stream to read and hash (overload 3) fn sha256_compute_hash_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(first_arg) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::Null)); - } + }; - match &ctx.args[0] { + match first_arg { EmValue::ObjectRef(handle) => { if let Some(bytes) = try_hook!(thread.heap().get_byte_array(*handle)) { let hash = compute_sha256(&bytes); @@ -427,9 +427,9 @@ fn hash_algorithm_transform_block_pre( _ => input_bytes.len(), }; - let end = (offset + count).min(input_bytes.len()); - let slice = if offset < input_bytes.len() { - &input_bytes[offset..end] + let end = offset.saturating_add(count).min(input_bytes.len()); + let slice: &[u8] = if offset < input_bytes.len() { + input_bytes.get(offset..end).unwrap_or(&[]) } else { &[] }; @@ -446,7 +446,7 @@ fn hash_algorithm_transform_block_pre( for (i, &byte) in slice.iter().enumerate() { try_hook!(thread.heap_mut().set_array_element( *output_handle, - output_offset + i, + output_offset.saturating_add(i), EmValue::I32(i32::from(byte)), )); } @@ -503,9 +503,11 @@ fn hash_algorithm_transform_final_block_pre( _ => input_bytes.len(), }; - let end = (offset + count).min(input_bytes.len()); + let end = offset.saturating_add(count).min(input_bytes.len()); let slice = if offset < input_bytes.len() { - input_bytes[offset..end].to_vec() + input_bytes + .get(offset..end) + .map_or_else(Vec::new, <[u8]>::to_vec) } else { Vec::new() }; diff --git a/dotscope/src/emulation/runtime/bcl/crypto/hmac.rs b/dotscope/src/emulation/runtime/bcl/crypto/hmac.rs index 89d70b76..af8bc3ff 100644 --- a/dotscope/src/emulation/runtime/bcl/crypto/hmac.rs +++ b/dotscope/src/emulation/runtime/bcl/crypto/hmac.rs @@ -149,7 +149,10 @@ fn hmac_compute_hash_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> return PreHookResult::Bypass(Some(EmValue::Null)); } - match &ctx.args[0] { + let Some(arg0) = ctx.args.first() else { + return PreHookResult::Bypass(Some(EmValue::Null)); + }; + match arg0 { EmValue::ObjectRef(handle) => { if let Some(bytes) = try_hook!(thread.heap().get_byte_array(*handle)) { let key = hmac_key.as_deref().unwrap_or(&[]); diff --git a/dotscope/src/emulation/runtime/bcl/crypto/mod.rs b/dotscope/src/emulation/runtime/bcl/crypto/mod.rs index dd71a585..43dc4346 100644 --- a/dotscope/src/emulation/runtime/bcl/crypto/mod.rs +++ b/dotscope/src/emulation/runtime/bcl/crypto/mod.rs @@ -207,9 +207,10 @@ pub(crate) fn resolve_crypto_key_iv( pub(crate) fn extract_xml_element(xml: &str, tag: &str) -> Option { let open = format!("<{tag}>"); let close = format!(""); - let start = xml.find(&open)? + open.len(); - let end = xml[start..].find(&close)? + start; - Some(xml[start..end].trim().to_string()) + let start = xml.find(&open)?.saturating_add(open.len()); + let tail = xml.get(start..)?; + let end = tail.find(&close)?.saturating_add(start); + Some(xml.get(start..end)?.trim().to_string()) } #[cfg(test)] diff --git a/dotscope/src/emulation/runtime/bcl/crypto/symmetric.rs b/dotscope/src/emulation/runtime/bcl/crypto/symmetric.rs index e9e8a96f..a1e0bb0a 100644 --- a/dotscope/src/emulation/runtime/bcl/crypto/symmetric.rs +++ b/dotscope/src/emulation/runtime/bcl/crypto/symmetric.rs @@ -693,7 +693,7 @@ fn transform_block_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> P return PreHookResult::Bypass(Some(EmValue::I32(0))); } - let EmValue::ObjectRef(input_handle) = &ctx.args[0] else { + let Some(EmValue::ObjectRef(input_handle)) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::I32(0))); }; @@ -708,9 +708,15 @@ fn transform_block_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> P _ => (0, input_bytes.len()), }; - let end = (offset + count).min(input_bytes.len()); + let end = offset + .checked_add(count) + .unwrap_or(input_bytes.len()) + .min(input_bytes.len()); let data = if offset < input_bytes.len() { - input_bytes[offset..end].to_vec() + let Some(slice) = input_bytes.get(offset..end) else { + return PreHookResult::Bypass(Some(EmValue::I32(0))); + }; + slice.to_vec() } else { return PreHookResult::Bypass(Some(EmValue::I32(0))); }; @@ -735,9 +741,12 @@ fn transform_block_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> P _ => { // No transform — passthrough: copy input to output for (i, &byte) in data.iter().enumerate() { + let Some(idx) = output_offset.checked_add(i) else { + break; + }; try_hook!(thread.heap_mut().set_array_element( output_handle, - output_offset + i, + idx, EmValue::I32(i32::from(byte)), )); } @@ -786,9 +795,12 @@ fn transform_block_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> P // Write result to output buffer for (i, &byte) in result_bytes.iter().enumerate() { + let Some(idx) = output_offset.checked_add(i) else { + break; + }; try_hook!(thread.heap_mut().set_array_element( output_handle, - output_offset + i, + idx, EmValue::I32(i32::from(byte)), )); } @@ -816,7 +828,7 @@ fn transform_final_block_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread return PreHookResult::Bypass(Some(EmValue::Null)); } - let EmValue::ObjectRef(input_handle) = &ctx.args[0] else { + let Some(EmValue::ObjectRef(input_handle)) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::Null)); }; @@ -830,10 +842,11 @@ fn transform_final_block_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread (Some(EmValue::I32(offset)), Some(EmValue::I32(count))) => { let offset = *offset as usize; let count = *count as usize; - if offset + count <= input_bytes.len() { - input_bytes[offset..offset + count].to_vec() - } else { - input_bytes.clone() + match offset.checked_add(count) { + Some(end) if end <= input_bytes.len() => input_bytes + .get(offset..end) + .map_or_else(|| input_bytes.clone(), <[u8]>::to_vec), + _ => input_bytes.clone(), } } _ => input_bytes.clone(), @@ -983,11 +996,7 @@ fn crypto_stream_read_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) - }; // Use data from the underlying stream's current position, not from offset 0 - let effective_data = if underlying_pos < stream_data.len() { - &stream_data[underlying_pos..] - } else { - &[] - }; + let effective_data: &[u8] = stream_data.get(underlying_pos..).unwrap_or(&[]); let transformed_data = if let Some((algorithm, key, iv, is_encryptor, mode, padding)) = try_hook!(thread.heap().get_crypto_transform_info(transform_ref)) @@ -1038,9 +1047,12 @@ fn crypto_stream_read_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) - }; for (i, &byte) in bytes.iter().enumerate() { + let Some(idx) = offset.checked_add(i) else { + break; + }; try_hook!(thread.heap_mut().set_array_element( buffer_ref, - offset + i, + idx, EmValue::I32(i32::from(byte)), )); } @@ -1090,9 +1102,15 @@ fn crypto_stream_write_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) return PreHookResult::Bypass(None); }; - let end = (offset + count).min(buffer_data.len()); + let end = offset + .checked_add(count) + .unwrap_or(buffer_data.len()) + .min(buffer_data.len()); let bytes_to_write = if offset < buffer_data.len() { - buffer_data[offset..end].to_vec() + let Some(slice) = buffer_data.get(offset..end) else { + return PreHookResult::Bypass(None); + }; + slice.to_vec() } else { return PreHookResult::Bypass(None); }; diff --git a/dotscope/src/emulation/runtime/bcl/interop/marshal.rs b/dotscope/src/emulation/runtime/bcl/interop/marshal.rs index 1500bab3..c92cb49b 100644 --- a/dotscope/src/emulation/runtime/bcl/interop/marshal.rs +++ b/dotscope/src/emulation/runtime/bcl/interop/marshal.rs @@ -413,12 +413,17 @@ fn marshal_get_hinstance_pre( /// - `startIndex`: Starting index in the array /// - `length`: Number of elements to copy fn marshal_copy_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 4 { + let (Some(arg0), Some(arg1), Some(arg2), Some(arg3)) = ( + ctx.args.first(), + ctx.args.get(1), + ctx.args.get(2), + ctx.args.get(3), + ) else { return PreHookResult::Bypass(None); - } + }; // Check first arg type to determine which overload - let src_addr = match &ctx.args[0] { + let src_addr = match arg0 { EmValue::UnmanagedPtr(a) => Some(*a), EmValue::NativeInt(a) => Some((*a).cast_unsigned()), _ => None, @@ -426,24 +431,27 @@ fn marshal_copy_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreH if let Some(src_addr) = src_addr { // Overload: Copy(IntPtr source, byte[] dest, int startIndex, int length) - let dst_ref = match &ctx.args[1] { + let dst_ref = match arg1 { EmValue::ObjectRef(r) => *r, _ => return PreHookResult::Bypass(None), }; - let start_idx = match &ctx.args[2] { + let start_idx = match arg2 { EmValue::I32(v) => (*v).cast_unsigned() as usize, _ => return PreHookResult::Bypass(None), }; - let length = match &ctx.args[3] { + let length = match arg3 { EmValue::I32(v) => (*v).cast_unsigned() as usize, _ => return PreHookResult::Bypass(None), }; if let Ok(bytes) = thread.address_space().read(src_addr, length) { for (i, &byte) in bytes.iter().enumerate() { + let Some(idx) = start_idx.checked_add(i) else { + return PreHookResult::Bypass(None); + }; try_hook!(thread.heap_mut().set_array_element( dst_ref, - start_idx + i, + idx, EmValue::I32(i32::from(byte)), )); } @@ -452,29 +460,32 @@ fn marshal_copy_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreH } // Overload: Copy(byte[] source, int startIndex, IntPtr dest, int length) - let EmValue::ObjectRef(src_ref) = &ctx.args[0] else { + let EmValue::ObjectRef(src_ref) = arg0 else { return PreHookResult::Bypass(None); }; - let start_idx = match &ctx.args[1] { + let start_idx = match arg1 { EmValue::I32(v) => (*v).cast_unsigned() as usize, _ => return PreHookResult::Bypass(None), }; - let dest_addr = match &ctx.args[2] { + let dest_addr = match arg2 { EmValue::UnmanagedPtr(a) => *a, EmValue::NativeInt(a) => (*a).cast_unsigned(), _ => return PreHookResult::Bypass(None), }; - let length = match &ctx.args[3] { + let length = match arg3 { EmValue::I32(v) => (*v).cast_unsigned() as usize, _ => return PreHookResult::Bypass(None), }; let mut bytes = Vec::with_capacity(length); for i in 0..length { + let Some(idx) = start_idx.checked_add(i) else { + return PreHookResult::Bypass(None); + }; #[allow(clippy::cast_possible_truncation)] let byte_val = thread .heap() - .get_array_element(*src_ref, start_idx + i) + .get_array_element(*src_ref, idx) .map(|elem| match elem { EmValue::I32(v) => v.cast_unsigned() as u8, _ => 0, @@ -508,7 +519,8 @@ fn marshal_read_byte_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> }; if let Ok(bytes) = thread.address_space().read(addr, 1) { - PreHookResult::Bypass(Some(EmValue::I32(i32::from(bytes[0])))) + let byte = bytes.first().copied().unwrap_or(0); + PreHookResult::Bypass(Some(EmValue::I32(i32::from(byte)))) } else { PreHookResult::Bypass(Some(EmValue::I32(0))) } @@ -537,7 +549,10 @@ fn marshal_read_int32_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) - }; if let Ok(bytes) = thread.address_space().read(addr, 4) { - let value = i32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]]); + let Some(slice) = bytes.get(0..4).and_then(|s| <[u8; 4]>::try_from(s).ok()) else { + return PreHookResult::Bypass(Some(EmValue::I32(0))); + }; + let value = i32::from_le_bytes(slice); PreHookResult::Bypass(Some(EmValue::I32(value))) } else { PreHookResult::Bypass(Some(EmValue::I32(0))) @@ -561,18 +576,18 @@ fn marshal_read_int32_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) - /// - `val`: Byte value to write /// - `o`: Object in unmanaged memory to write to (overload 3) fn marshal_write_byte_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(None); - } + }; - let addr = match &ctx.args[0] { + let addr = match arg0 { EmValue::UnmanagedPtr(a) => *a, EmValue::NativeInt(a) => (*a).cast_unsigned(), _ => return PreHookResult::Bypass(None), }; #[allow(clippy::cast_possible_truncation)] - let value = match &ctx.args[1] { + let value = match arg1 { EmValue::I32(v) => (*v).cast_unsigned() as u8, _ => return PreHookResult::Bypass(None), }; @@ -598,15 +613,15 @@ fn marshal_write_byte_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) - /// - `val`: 32-bit value to write /// - `o`: Object in unmanaged memory to write to (overload 3) fn marshal_write_int32_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(None); - } + }; - let Some(addr) = resolve_address(&ctx.args[0], thread) else { + let Some(addr) = resolve_address(arg0, thread) else { return PreHookResult::Bypass(None); }; - let value = match &ctx.args[1] { + let value = match arg1 { EmValue::I32(v) => *v, _ => return PreHookResult::Bypass(None), }; @@ -637,9 +652,10 @@ fn marshal_read_int64_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) - let addr = (base_addr as i64).wrapping_add(offset).cast_unsigned(); if let Ok(bytes) = thread.address_space().read(addr, 8) { - let value = i64::from_le_bytes([ - bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], - ]); + let Some(slice) = bytes.get(0..8).and_then(|s| <[u8; 8]>::try_from(s).ok()) else { + return PreHookResult::Bypass(Some(EmValue::I64(0))); + }; + let value = i64::from_le_bytes(slice); PreHookResult::Bypass(Some(EmValue::I64(value))) } else { PreHookResult::Bypass(Some(EmValue::I64(0))) @@ -670,7 +686,10 @@ fn marshal_read_int16_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) - let addr = (base_addr as i64).wrapping_add(offset).cast_unsigned(); if let Ok(bytes) = thread.address_space().read(addr, 2) { - let value = i16::from_le_bytes([bytes[0], bytes[1]]); + let Some(slice) = bytes.get(0..2).and_then(|s| <[u8; 2]>::try_from(s).ok()) else { + return PreHookResult::Bypass(Some(EmValue::I32(0))); + }; + let value = i16::from_le_bytes(slice); PreHookResult::Bypass(Some(EmValue::I32(i32::from(value)))) } else { PreHookResult::Bypass(Some(EmValue::I32(0))) @@ -703,11 +722,15 @@ fn marshal_read_intptr_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) let ptr_size = ctx.pointer_size.bytes(); if let Ok(bytes) = thread.address_space().read(addr, ptr_size) { let value = if ptr_size == 8 { - i64::from_le_bytes([ - bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7], - ]) + let Some(slice) = bytes.get(0..8).and_then(|s| <[u8; 8]>::try_from(s).ok()) else { + return PreHookResult::Bypass(Some(EmValue::NativeInt(0))); + }; + i64::from_le_bytes(slice) } else { - i64::from(i32::from_le_bytes([bytes[0], bytes[1], bytes[2], bytes[3]])) + let Some(slice) = bytes.get(0..4).and_then(|s| <[u8; 4]>::try_from(s).ok()) else { + return PreHookResult::Bypass(Some(EmValue::NativeInt(0))); + }; + i64::from(i32::from_le_bytes(slice)) }; PreHookResult::Bypass(Some(EmValue::NativeInt(value))) } else { @@ -724,25 +747,25 @@ fn marshal_read_intptr_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) /// - `Marshal.WriteInt64(IntPtr, Int64) -> void` /// - `Marshal.WriteInt64(IntPtr, Int32, Int64) -> void` fn marshal_write_int64_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(None); - } + }; - let Some(addr) = resolve_address(&ctx.args[0], thread) else { + let Some(addr) = resolve_address(arg0, thread) else { return PreHookResult::Bypass(None); }; // Two-arg form: WriteInt64(IntPtr, Int64) // Three-arg form: WriteInt64(IntPtr, Int32 offset, Int64) - let (offset, value) = if ctx.args.len() >= 3 { - let ofs = ctx.args[1].as_i32().map(i64::from).unwrap_or(0); - let val = match &ctx.args[2] { + let (offset, value) = if let Some(arg2) = ctx.args.get(2) { + let ofs = arg1.as_i32().map(i64::from).unwrap_or(0); + let val = match arg2 { EmValue::I64(v) | EmValue::NativeInt(v) => *v, _ => return PreHookResult::Bypass(None), }; (ofs, val) } else { - let val = match &ctx.args[1] { + let val = match arg1 { EmValue::I64(v) | EmValue::NativeInt(v) => *v, _ => return PreHookResult::Bypass(None), }; @@ -763,24 +786,24 @@ fn marshal_write_int64_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) /// - `Marshal.WriteInt16(IntPtr, Int16) -> void` /// - `Marshal.WriteInt16(IntPtr, Int32, Int16) -> void` fn marshal_write_int16_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(None); - } + }; - let Some(addr) = resolve_address(&ctx.args[0], thread) else { + let Some(addr) = resolve_address(arg0, thread) else { return PreHookResult::Bypass(None); }; // 3-arg overload: WriteInt16(IntPtr ptr, Int32 offset, Int16 value) // 2-arg overload: WriteInt16(IntPtr ptr, Int16 value) - let (offset, value_arg) = if ctx.args.len() >= 3 { - let off = match &ctx.args[1] { + let (offset, value_arg) = if let Some(arg2) = ctx.args.get(2) { + let off = match arg1 { EmValue::I32(v) => i64::from(*v), _ => 0, }; - (off, &ctx.args[2]) + (off, arg2) } else { - (0i64, &ctx.args[1]) + (0i64, arg1) }; #[allow(clippy::cast_possible_truncation)] @@ -803,19 +826,19 @@ fn marshal_write_int16_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) /// - `Marshal.WriteIntPtr(IntPtr, IntPtr) -> void` /// - `Marshal.WriteIntPtr(IntPtr, Int32, IntPtr) -> void` fn marshal_write_intptr_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(None); - } + }; - let Some(addr) = resolve_address(&ctx.args[0], thread) else { + let Some(addr) = resolve_address(arg0, thread) else { return PreHookResult::Bypass(None); }; // Two-arg form: WriteIntPtr(IntPtr, IntPtr) // Three-arg form: WriteIntPtr(IntPtr, Int32 offset, IntPtr) - let (offset, value) = if ctx.args.len() >= 3 { - let ofs = ctx.args[1].as_i32().map(i64::from).unwrap_or(0); - let val = match &ctx.args[2] { + let (offset, value) = if let Some(arg2) = ctx.args.get(2) { + let ofs = arg1.as_i32().map(i64::from).unwrap_or(0); + let val = match arg2 { EmValue::NativeInt(v) | EmValue::I64(v) => *v, EmValue::NativeUInt(v) | EmValue::UnmanagedPtr(v) => (*v).cast_signed(), EmValue::I32(v) => i64::from(*v), @@ -823,7 +846,7 @@ fn marshal_write_intptr_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) }; (ofs, val) } else { - let val = match &ctx.args[1] { + let val = match arg1 { EmValue::NativeInt(v) | EmValue::I64(v) => *v, EmValue::NativeUInt(v) | EmValue::UnmanagedPtr(v) => (*v).cast_signed(), EmValue::I32(v) => i64::from(*v), @@ -1176,17 +1199,17 @@ fn intptr_op_explicit_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) /// - `pointer`: The pointer to add to /// - `offset`: The offset to add fn intptr_add_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(Some(EmValue::NativeInt(0))); - } + }; - let ptr = match &ctx.args[0] { + let ptr = match arg0 { EmValue::NativeInt(v) => *v, EmValue::UnmanagedPtr(v) => (*v).cast_signed(), _ => return PreHookResult::Bypass(Some(EmValue::NativeInt(0))), }; - let offset = match &ctx.args[1] { + let offset = match arg1 { EmValue::I32(v) => i64::from(*v), EmValue::I64(v) => *v, _ => return PreHookResult::Bypass(Some(EmValue::NativeInt(ptr))), diff --git a/dotscope/src/emulation/runtime/bcl/io/binaryreader.rs b/dotscope/src/emulation/runtime/bcl/io/binaryreader.rs index b6b84ba8..dfef5a06 100644 --- a/dotscope/src/emulation/runtime/bcl/io/binaryreader.rs +++ b/dotscope/src/emulation/runtime/bcl/io/binaryreader.rs @@ -470,18 +470,15 @@ fn binary_reader_read_string_pre( let Some(result) = try_hook!(thread.heap().with_stream(stream_ref, |data, position| { // Read 7-bit encoded length (LEB128-style) let mut length: usize = 0; - let mut shift = 0; + let mut shift: u32 = 0; loop { - if *position >= data.len() { - return None; - } - let byte = data[*position]; - *position += 1; + let &byte = data.get(*position)?; + *position = position.saturating_add(1); length |= ((byte & 0x7F) as usize) << shift; if byte & 0x80 == 0 { break; } - shift += 7; + shift = shift.saturating_add(7); // Prevent infinite loop on malformed data if shift > 35 { break; @@ -489,12 +486,10 @@ fn binary_reader_read_string_pre( } // Read the string bytes — return None if not enough data - if *position + length > data.len() { - return None; - } - - let s = String::from_utf8_lossy(&data[*position..*position + length]).into_owned(); - *position += length; + let end = position.checked_add(length)?; + let slice = data.get(*position..end)?; + let s = String::from_utf8_lossy(slice).into_owned(); + *position = end; Some(s) })) else { return PreHookResult::throw_end_of_stream(); @@ -537,15 +532,17 @@ fn binary_reader_read_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) - 0 => { let Some(result) = try_hook!(thread.heap().with_stream(stream_ref, |data, position| { - if *position >= data.len() { + let Some(remaining) = data.get(*position..) else { + return -1_i32; + }; + if remaining.is_empty() { return -1_i32; } // Read a single UTF-8 character - let remaining = &data[*position..]; let s = String::from_utf8_lossy(remaining); if let Some(ch) = s.chars().next() { - *position += ch.len_utf8(); + *position = position.saturating_add(ch.len_utf8()); #[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)] { ch as i32 @@ -576,9 +573,12 @@ fn binary_reader_read_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) - }; for (i, &byte) in bytes.iter().enumerate() { + let Some(idx) = offset.checked_add(i) else { + break; + }; try_hook!(thread.heap_mut().set_array_element( buffer_ref, - offset + i, + idx, EmValue::I32(i32::from(byte)), )); } @@ -675,15 +675,15 @@ fn binary_reader_read_char_pre( }; let Some(result) = try_hook!(thread.heap().with_stream(stream_ref, |data, position| { - if *position >= data.len() { + let remaining = data.get(*position..)?; + if remaining.is_empty() { return None; } // Decode one UTF-8 character from the stream - let remaining = &data[*position..]; let s = String::from_utf8_lossy(remaining); if let Some(ch) = s.chars().next() { - *position += ch.len_utf8(); + *position = position.saturating_add(ch.len_utf8()); #[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)] Some(ch as i32) } else { @@ -735,14 +735,16 @@ fn binary_reader_read_chars_pre( // Decode `count` UTF-8 characters let mut chars = Vec::with_capacity(count); for _ in 0..count { - if *position >= data.len() { + let Some(remaining) = data.get(*position..) else { + break; + }; + if remaining.is_empty() { break; } - let remaining = &data[*position..]; let s = String::from_utf8_lossy(remaining); if let Some(ch) = s.chars().next() { chars.push(ch); - *position += ch.len_utf8(); + *position = position.saturating_add(ch.len_utf8()); } else { break; } @@ -1023,22 +1025,21 @@ fn binary_reader_read_7bit_encoded_int_pre( } let mut value: u32 = 0; - let mut shift = 0; + let mut shift: u32 = 0; loop { - if *position >= data.len() { - return Read7BitResult::EndOfStream; - } if shift > 35 { // .NET throws FormatException for too many bytes return Read7BitResult::FormatError; } - let byte = data[*position]; - *position += 1; + let Some(&byte) = data.get(*position) else { + return Read7BitResult::EndOfStream; + }; + *position = position.saturating_add(1); value |= u32::from(byte & 0x7F) << shift; if byte & 0x80 == 0 { break; } - shift += 7; + shift = shift.saturating_add(7); } Read7BitResult::Ok(value) })) else { @@ -1080,12 +1081,14 @@ fn binary_reader_peek_char_pre( // Peek without advancing position let Some(result) = try_hook!(thread.heap().with_stream(stream_ref, |data, position| { - if *position >= data.len() { + let Some(remaining) = data.get(*position..) else { + return -1_i32; + }; + if remaining.is_empty() { return -1_i32; } // Decode one UTF-8 character without advancing position - let remaining = &data[*position..]; let s = String::from_utf8_lossy(remaining); if let Some(ch) = s.chars().next() { #[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)] diff --git a/dotscope/src/emulation/runtime/bcl/io/binarywriter.rs b/dotscope/src/emulation/runtime/bcl/io/binarywriter.rs index e0c75c56..06d705ee 100644 --- a/dotscope/src/emulation/runtime/bcl/io/binarywriter.rs +++ b/dotscope/src/emulation/runtime/bcl/io/binarywriter.rs @@ -198,8 +198,8 @@ fn binary_writer_write_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) // Try to use param_types for precise overload dispatch if let Some(param_types) = ctx.param_types { - if param_types.len() == 1 { - match param_types[0] { + if let Some(first_type) = param_types.first().filter(|_| param_types.len() == 1) { + match first_type { // Write(Boolean) — 1 byte: 0 or 1 CilFlavor::Boolean => { let v = match ctx.args.first() { @@ -384,11 +384,11 @@ fn binary_writer_write_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) let offset = *offset as usize; let count = *count as usize; if let Some(bytes) = try_hook!(thread.heap().get_byte_array(*arr_ref)) { - let end = (offset + count).min(bytes.len()); + let end = offset.saturating_add(count).min(bytes.len()); if offset < bytes.len() { - try_hook!(thread - .heap_mut() - .write_to_stream(stream_ref, &bytes[offset..end])); + if let Some(slice) = bytes.get(offset..end) { + try_hook!(thread.heap_mut().write_to_stream(stream_ref, slice)); + } } } } @@ -468,12 +468,12 @@ fn binary_writer_seek_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) - let new_pos = match origin { 1 => { #[allow(clippy::cast_possible_wrap)] - let pos = current_pos as i64 + offset; + let pos = (current_pos as i64).saturating_add(offset); pos.max(0) as usize } 2 => { #[allow(clippy::cast_possible_wrap)] - let pos = length as i64 + offset; + let pos = (length as i64).saturating_add(offset); pos.max(0) as usize } _ => offset.max(0) as usize, diff --git a/dotscope/src/emulation/runtime/bcl/io/compression.rs b/dotscope/src/emulation/runtime/bcl/io/compression.rs index 4279c70c..ecbbdebd 100644 --- a/dotscope/src/emulation/runtime/bcl/io/compression.rs +++ b/dotscope/src/emulation/runtime/bcl/io/compression.rs @@ -199,11 +199,7 @@ fn compressed_stream_read_pre( }; // Use data from the underlying stream's current position, not from offset 0 - let effective_data = if underlying_pos < compressed_data.len() { - &compressed_data[underlying_pos..] - } else { - &[] - }; + let effective_data = compressed_data.get(underlying_pos..).unwrap_or(&[]); let decompressed = match compression_type { 0 => decompress_deflate(effective_data).ok(), @@ -224,7 +220,7 @@ fn compressed_stream_read_pre( for (i, &byte) in bytes.iter().enumerate() { try_hook!(thread.heap_mut().set_array_element( buffer_ref, - offset + i, + offset.saturating_add(i), EmValue::I32(i32::from(byte)), )); } diff --git a/dotscope/src/emulation/runtime/bcl/io/filestream.rs b/dotscope/src/emulation/runtime/bcl/io/filestream.rs index 8f9de37b..dbc78cde 100644 --- a/dotscope/src/emulation/runtime/bcl/io/filestream.rs +++ b/dotscope/src/emulation/runtime/bcl/io/filestream.rs @@ -307,16 +307,19 @@ fn extract_nth_string_arg( /// Extracts the filename portion from a path (after last `\` or `/`). fn path_filename(path: &str) -> &str { - path.rfind(['\\', '/']).map_or(path, |pos| &path[pos + 1..]) + match path.rfind(['\\', '/']) { + Some(pos) => path.get(pos.saturating_add(1)..).unwrap_or(path), + None => path, + } } /// Returns true if a path is rooted (starts with drive letter like `C:` or `\` or `/`). fn is_rooted(path: &str) -> bool { let bytes = path.as_bytes(); - if bytes.is_empty() { + let Some(&first) = bytes.first() else { return false; - } - bytes[0] == b'\\' || bytes[0] == b'/' || (bytes.len() >= 2 && bytes[1] == b':') + }; + first == b'\\' || first == b'/' || bytes.get(1).is_some_and(|&b| b == b':') } /// Allocates a string on the heap and returns a bypass result, or Null on error. @@ -684,14 +687,15 @@ fn path_get_path_root_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) - None => return PreHookResult::Bypass(Some(EmValue::Null)), }; let bytes = path.as_bytes(); - let root = if bytes.len() >= 3 && bytes[1] == b':' && (bytes[2] == b'\\' || bytes[2] == b'/') { + let root = if bytes.get(1) == Some(&b':') && matches!(bytes.get(2), Some(&b'\\') | Some(&b'/')) + { // Drive letter root: "C:\" - &path[..3] - } else if bytes.len() >= 2 && bytes[1] == b':' { + path.get(..3).unwrap_or("") + } else if bytes.get(1) == Some(&b':') { // Drive letter without trailing sep: "C:" - &path[..2] - } else if !bytes.is_empty() && (bytes[0] == b'\\' || bytes[0] == b'/') { - &path[..1] + path.get(..2).unwrap_or("") + } else if matches!(bytes.first(), Some(&b'\\') | Some(&b'/')) { + path.get(..1).unwrap_or("") } else { "" }; @@ -883,11 +887,7 @@ fn streamreader_read_to_end_pre( // Try plain Stream first (most common path) if let Some(text) = try_hook!(thread.heap().with_stream(stream_ref, |data, position| { - let remaining = if *position < data.len() { - &data[*position..] - } else { - &[] - }; + let remaining: &[u8] = data.get(*position..).unwrap_or(&[]); let text = String::from_utf8_lossy(remaining).into_owned(); *position = data.len(); // Advance to end text @@ -909,11 +909,7 @@ fn streamreader_read_to_end_pre( let decrypted = if let Some((data, pos)) = try_hook!(thread.heap().get_crypto_stream_transformed(stream_ref)) { - if pos < data.len() { - data[pos..].to_vec() - } else { - vec![] - } + data.get(pos..).map(<[u8]>::to_vec).unwrap_or_default() } else { // No cached data — perform the crypto transform now let Some((stream_data, underlying_pos)) = @@ -922,11 +918,7 @@ fn streamreader_read_to_end_pre( return PreHookResult::Bypass(Some(EmValue::Null)); }; - let effective_data = if underlying_pos < stream_data.len() { - &stream_data[underlying_pos..] - } else { - &[] - }; + let effective_data: &[u8] = stream_data.get(underlying_pos..).unwrap_or(&[]); let transformed = if let Some((algorithm, key, iv, is_encryptor, cmode, padding)) = try_hook!(thread.heap().get_crypto_transform_info(transform_ref)) @@ -979,21 +971,25 @@ fn streamreader_read_line_pre( return None; // EOF } - let remaining = &data[*position..]; + let remaining: &[u8] = data.get(*position..)?; let (line_bytes, advance) = if let Some(nl_pos) = remaining.iter().position(|&b| b == b'\n') { - let line_end = if nl_pos > 0 && remaining[nl_pos - 1] == b'\r' { - nl_pos - 1 // Strip \r from \r\n - } else { - nl_pos - }; - (&remaining[..line_end], nl_pos + 1) // Skip past the \n + let line_end = + if nl_pos > 0 && remaining.get(nl_pos.saturating_sub(1)).copied() == Some(b'\r') { + nl_pos.saturating_sub(1) // Strip \r from \r\n + } else { + nl_pos + }; + ( + remaining.get(..line_end).unwrap_or(&[]), + nl_pos.saturating_add(1), + ) // Skip past the \n } else { // No newline found — return rest of stream (remaining, remaining.len()) }; - *position += advance; + *position = position.saturating_add(advance); Some(String::from_utf8_lossy(line_bytes).into_owned()) })) { Some(opt) => opt, @@ -1033,11 +1029,7 @@ fn streamreader_peek_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> // Zero-copy peek — read byte at current position without advancing let value = try_hook!(thread.heap().with_stream(stream_ref, |data, position| { - if *position < data.len() { - data[*position] as i32 - } else { - -1 - } + data.get(*position).map_or(-1, |&b| b as i32) })); PreHookResult::Bypass(Some(EmValue::I32(value.unwrap_or(-1)))) diff --git a/dotscope/src/emulation/runtime/bcl/io/stream.rs b/dotscope/src/emulation/runtime/bcl/io/stream.rs index 3755846a..0a29b53c 100644 --- a/dotscope/src/emulation/runtime/bcl/io/stream.rs +++ b/dotscope/src/emulation/runtime/bcl/io/stream.rs @@ -392,7 +392,7 @@ fn stream_read_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHo for (i, &byte) in bytes.iter().enumerate() { try_hook!(thread.heap_mut().set_array_element( buffer_ref, - offset + i, + offset.saturating_add(i), EmValue::I32(i32::from(byte)), )); } @@ -460,9 +460,9 @@ fn stream_write_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreH }; // Extract the slice to write - let end = (offset + count).min(buffer_data.len()); - let bytes_to_write = if offset < buffer_data.len() { - &buffer_data[offset..end] + let end = offset.saturating_add(count).min(buffer_data.len()); + let bytes_to_write: &[u8] = if offset < buffer_data.len() { + buffer_data.get(offset..end).unwrap_or(&[]) } else { &[] }; @@ -640,14 +640,14 @@ fn stream_seek_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHo // Current // Safe: stream positions fit in i64 #[allow(clippy::cast_possible_wrap)] - let pos = current_pos as i64 + offset; + let pos = (current_pos as i64).saturating_add(offset); pos.max(0) as usize } 2 => { // End // Safe: stream positions fit in i64 #[allow(clippy::cast_possible_wrap)] - let pos = length as i64 + offset; + let pos = (length as i64).saturating_add(offset); pos.max(0) as usize } // Begin (0) and default @@ -764,7 +764,7 @@ fn stream_copy_to_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> Pr // original CopyTo behavior — read remaining bytes, advance position. if let Some((data, position)) = try_hook!(thread.heap().get_stream_data(src_ref)) { if position < data.len() { - let remaining = &data[position..]; + let remaining: &[u8] = data.get(position..).unwrap_or(&[]); try_hook!(thread.heap_mut().write_to_stream(dst_ref, remaining)); } let len = data.len(); @@ -787,8 +787,8 @@ fn stream_copy_to_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> Pr else { return PreHookResult::Bypass(None); }; - let effective = if underlying_pos < compressed_data.len() { - &compressed_data[underlying_pos..] + let effective: &[u8] = if underlying_pos < compressed_data.len() { + compressed_data.get(underlying_pos..).unwrap_or(&[]) } else { &[] }; @@ -823,8 +823,8 @@ fn stream_copy_to_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> Pr else { return PreHookResult::Bypass(None); }; - let effective = if underlying_pos < stream_data.len() { - &stream_data[underlying_pos..] + let effective: &[u8] = if underlying_pos < stream_data.len() { + stream_data.get(underlying_pos..).unwrap_or(&[]) } else { &[] }; diff --git a/dotscope/src/emulation/runtime/bcl/reflection/helpers.rs b/dotscope/src/emulation/runtime/bcl/reflection/helpers.rs index c479ca51..bd599d82 100644 --- a/dotscope/src/emulation/runtime/bcl/reflection/helpers.rs +++ b/dotscope/src/emulation/runtime/bcl/reflection/helpers.rs @@ -94,7 +94,7 @@ pub(crate) fn find_method_by_name(asm: &CilObject, type_token: Token, name: &str if let Some(method) = method_weak.upgrade() { if method.name == name { let param_count = method.signature.params.len(); - if best.is_none() || param_count < best.unwrap().1 { + if best.is_none_or(|(_, n)| param_count < n) { best = Some((method.token, param_count)); } } diff --git a/dotscope/src/emulation/runtime/bcl/reflection/members.rs b/dotscope/src/emulation/runtime/bcl/reflection/members.rs index b7f0b352..7a4d87dd 100644 --- a/dotscope/src/emulation/runtime/bcl/reflection/members.rs +++ b/dotscope/src/emulation/runtime/bcl/reflection/members.rs @@ -842,7 +842,7 @@ mod tests { runtime::hook::{HookContext, PreHookResult}, EmValue, }, - metadata::typesystem::PointerSize, + metadata::{token::Token, typesystem::PointerSize}, test::emulation::create_test_thread, }; @@ -851,7 +851,7 @@ mod tests { #[test] fn test_field_get_value_hook() { let ctx = HookContext::new( - crate::metadata::token::Token::new(0x0A000001), + Token::new(0x0A000001), "System.Reflection", "FieldInfo", "GetValue", diff --git a/dotscope/src/emulation/runtime/bcl/reflection/methods.rs b/dotscope/src/emulation/runtime/bcl/reflection/methods.rs index 9aced623..1fe18866 100644 --- a/dotscope/src/emulation/runtime/bcl/reflection/methods.rs +++ b/dotscope/src/emulation/runtime/bcl/reflection/methods.rs @@ -936,13 +936,13 @@ fn il_generator_emit_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> /// operand type from the instruction tables. fn build_immediate_operand(opcode: u16, value: i32) -> Option { let op_type = if opcode < u16::from(INSTRUCTIONS_MAX) { - INSTRUCTIONS[opcode as usize].op_type + INSTRUCTIONS.get(opcode as usize)?.op_type } else if opcode >= 0xFE00 { let sub = (opcode & 0xFF) as usize; if sub >= usize::from(INSTRUCTIONS_FE_MAX) { return None; } - INSTRUCTIONS_FE[sub].op_type + INSTRUCTIONS_FE.get(sub)?.op_type } else { return None; }; @@ -1332,7 +1332,7 @@ mod tests { runtime::hook::{HookContext, PreHookResult}, EmValue, }, - metadata::typesystem::PointerSize, + metadata::{token::Token, typesystem::PointerSize}, test::emulation::create_test_thread, }; @@ -1341,7 +1341,7 @@ mod tests { #[test] fn test_method_invoke_hook() { let ctx = HookContext::new( - crate::metadata::token::Token::new(0x0A000001), + Token::new(0x0A000001), "System.Reflection", "MethodBase", "Invoke", diff --git a/dotscope/src/emulation/runtime/bcl/reflection/types.rs b/dotscope/src/emulation/runtime/bcl/reflection/types.rs index 9134b95e..e21016be 100644 --- a/dotscope/src/emulation/runtime/bcl/reflection/types.rs +++ b/dotscope/src/emulation/runtime/bcl/reflection/types.rs @@ -2244,21 +2244,18 @@ fn delegate_combine_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> let args = ctx.args; // Extract the two delegate arguments - let (a_ref, b_ref) = match args.len() { - 2 => (args[0].as_object_ref(), args[1].as_object_ref()), - _ => return PreHookResult::Bypass(Some(EmValue::Null)), + let (Some(arg0), Some(arg1)) = (args.first(), args.get(1)) else { + return PreHookResult::Bypass(Some(EmValue::Null)); }; + let (a_ref_opt, b_ref_opt) = (arg0.as_object_ref(), arg1.as_object_ref()); // Null-safe: Combine(null, x) = x, Combine(x, null) = x - if a_ref.is_none() { - return PreHookResult::Bypass(Some(args[1].clone())); - } - if b_ref.is_none() { - return PreHookResult::Bypass(Some(args[0].clone())); - } - - let a_ref = a_ref.unwrap(); - let b_ref = b_ref.unwrap(); + let Some(a_ref) = a_ref_opt else { + return PreHookResult::Bypass(Some(arg1.clone())); + }; + let Some(b_ref) = b_ref_opt else { + return PreHookResult::Bypass(Some(arg0.clone())); + }; // Get invocation lists from both delegates let (type_token, mut entries) = match thread.heap().get(a_ref) { @@ -2266,14 +2263,14 @@ fn delegate_combine_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> type_token, invocation_list, }) => (type_token, invocation_list), - _ => return PreHookResult::Bypass(Some(args[0].clone())), + _ => return PreHookResult::Bypass(Some(arg0.clone())), }; let b_entries = match thread.heap().get(b_ref) { Ok(HeapObject::Delegate { invocation_list, .. }) => invocation_list, - _ => return PreHookResult::Bypass(Some(args[0].clone())), + _ => return PreHookResult::Bypass(Some(arg0.clone())), }; // Concatenate: a's entries followed by b's entries @@ -2295,29 +2292,29 @@ fn delegate_combine_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> fn delegate_remove_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { let args = ctx.args; + let (Some(arg0), Some(arg1)) = (args.first(), args.get(1)) else { + return PreHookResult::Bypass(Some(EmValue::Null)); + }; if args.len() != 2 { return PreHookResult::Bypass(Some(EmValue::Null)); } - let source_ref = args[0].as_object_ref(); - let value_ref = args[1].as_object_ref(); + let source_ref_opt = arg0.as_object_ref(); + let value_ref_opt = arg1.as_object_ref(); - if source_ref.is_none() { + let Some(source_ref) = source_ref_opt else { return PreHookResult::Bypass(Some(EmValue::Null)); - } - if value_ref.is_none() { - return PreHookResult::Bypass(Some(args[0].clone())); - } - - let source_ref = source_ref.unwrap(); - let value_ref = value_ref.unwrap(); + }; + let Some(value_ref) = value_ref_opt else { + return PreHookResult::Bypass(Some(arg0.clone())); + }; let (type_token, mut entries) = match thread.heap().get(source_ref) { Ok(HeapObject::Delegate { type_token, invocation_list, }) => (type_token, invocation_list), - _ => return PreHookResult::Bypass(Some(args[0].clone())), + _ => return PreHookResult::Bypass(Some(arg0.clone())), }; let remove_token = match thread.heap().get(value_ref) { @@ -2327,10 +2324,10 @@ fn delegate_remove_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> P if let Some(entry) = invocation_list.last() { entry.method_token } else { - return PreHookResult::Bypass(Some(args[0].clone())); + return PreHookResult::Bypass(Some(arg0.clone())); } } - _ => return PreHookResult::Bypass(Some(args[0].clone())), + _ => return PreHookResult::Bypass(Some(arg0.clone())), }; // Remove last matching entry @@ -2666,7 +2663,7 @@ mod tests { runtime::hook::{HookContext, PreHookResult}, EmValue, }, - metadata::typesystem::PointerSize, + metadata::{token::Token, typesystem::PointerSize}, test::emulation::create_test_thread, }; @@ -2675,7 +2672,7 @@ mod tests { #[test] fn test_get_module_hook() { let ctx = HookContext::new( - crate::metadata::token::Token::new(0x0A000001), + Token::new(0x0A000001), "System", "Type", "get_Module", @@ -2695,7 +2692,7 @@ mod tests { fn test_get_type_from_handle_hook_with_arg() { let args = [EmValue::NativeInt(0x0200_0001)]; let ctx = HookContext::new( - crate::metadata::token::Token::new(0x0A000001), + Token::new(0x0A000001), "System", "Type", "GetTypeFromHandle", @@ -2715,7 +2712,7 @@ mod tests { #[test] fn test_get_type_from_handle_hook_no_arg_throws() { let ctx = HookContext::new( - crate::metadata::token::Token::new(0x0A000001), + Token::new(0x0A000001), "System", "Type", "GetTypeFromHandle", diff --git a/dotscope/src/emulation/runtime/bcl/runtime.rs b/dotscope/src/emulation/runtime/bcl/runtime.rs index fb737c89..a332a23b 100644 --- a/dotscope/src/emulation/runtime/bcl/runtime.rs +++ b/dotscope/src/emulation/runtime/bcl/runtime.rs @@ -339,19 +339,19 @@ fn runtime_helpers_initialize_array_pre( // args[0] = array (ObjectRef) // args[1] = field handle (NativeInt/I32 containing token value) - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(None); - } + }; // Get the array reference - let array_ref = match &ctx.args[0] { + let array_ref = match arg0 { EmValue::ObjectRef(r) => *r, _ => return PreHookResult::Bypass(None), }; // Get the field token from the RuntimeFieldHandle #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] - let field_token = match &ctx.args[1] { + let field_token = match arg1 { EmValue::I32(v) => (*v).cast_unsigned(), EmValue::NativeInt(v) | EmValue::I64(v) => *v as u32, _ => return PreHookResult::Bypass(None), @@ -414,74 +414,75 @@ fn runtime_helpers_initialize_array_pre( return PreHookResult::Bypass(None); }; - // Calculate total bytes to read - let total_bytes = array_len * element_size; - let bytes_to_read = total_bytes.min(pe_data.len() - file_offset); + // Calculate total bytes to read (saturating: out-of-range simply truncates the copy) + let total_bytes = array_len.saturating_mul(element_size); + let remaining = pe_data.len().saturating_sub(file_offset); + let bytes_to_read = total_bytes.min(remaining); // Read the bytes from the PE file - let data = &pe_data[file_offset..file_offset + bytes_to_read]; + let Some(end) = file_offset.checked_add(bytes_to_read) else { + return PreHookResult::Bypass(None); + }; + let Some(data) = pe_data.get(file_offset..end) else { + return PreHookResult::Bypass(None); + }; // Set each element in the array based on element type for i in 0..array_len { - let byte_offset = i * element_size; - if byte_offset + element_size > data.len() { + let Some(byte_offset) = i.checked_mul(element_size) else { break; - } + }; + let Some(end_offset) = byte_offset.checked_add(element_size) else { + break; + }; + let Some(elem_bytes) = data.get(byte_offset..end_offset) else { + break; + }; let value = match element_type { CilFlavor::Boolean | CilFlavor::I1 | CilFlavor::U1 => { - EmValue::I32(i32::from(data[byte_offset])) + let Some(b) = elem_bytes.first() else { + break; + }; + EmValue::I32(i32::from(*b)) } CilFlavor::Char | CilFlavor::I2 | CilFlavor::U2 => { - let bytes = [data[byte_offset], data[byte_offset + 1]]; + let Ok(bytes) = <[u8; 2]>::try_from(elem_bytes) else { + break; + }; EmValue::I32(i32::from(i16::from_le_bytes(bytes))) } CilFlavor::I4 | CilFlavor::U4 => { - let bytes = [ - data[byte_offset], - data[byte_offset + 1], - data[byte_offset + 2], - data[byte_offset + 3], - ]; + let Ok(bytes) = <[u8; 4]>::try_from(elem_bytes) else { + break; + }; EmValue::I32(i32::from_le_bytes(bytes)) } CilFlavor::R4 => { - let bytes = [ - data[byte_offset], - data[byte_offset + 1], - data[byte_offset + 2], - data[byte_offset + 3], - ]; + let Ok(bytes) = <[u8; 4]>::try_from(elem_bytes) else { + break; + }; EmValue::F32(f32::from_le_bytes(bytes)) } CilFlavor::I8 | CilFlavor::U8 => { - let bytes = [ - data[byte_offset], - data[byte_offset + 1], - data[byte_offset + 2], - data[byte_offset + 3], - data[byte_offset + 4], - data[byte_offset + 5], - data[byte_offset + 6], - data[byte_offset + 7], - ]; + let Ok(bytes) = <[u8; 8]>::try_from(elem_bytes) else { + break; + }; EmValue::I64(i64::from_le_bytes(bytes)) } CilFlavor::R8 => { - let bytes = [ - data[byte_offset], - data[byte_offset + 1], - data[byte_offset + 2], - data[byte_offset + 3], - data[byte_offset + 4], - data[byte_offset + 5], - data[byte_offset + 6], - data[byte_offset + 7], - ]; + let Ok(bytes) = <[u8; 8]>::try_from(elem_bytes) else { + break; + }; EmValue::F64(f64::from_le_bytes(bytes)) } // For other types, treat as bytes - _ => EmValue::I32(i32::from(data[byte_offset])), + _ => { + let Some(b) = elem_bytes.first() else { + break; + }; + EmValue::I32(i32::from(*b)) + } }; try_hook!(thread.heap_mut().set_array_element(array_ref, i, value)); @@ -532,12 +533,12 @@ fn runtime_helpers_equals_pre( ctx: &HookContext<'_>, _thread: &mut EmulationThread, ) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(Some(EmValue::I32(0))); // false - } + }; // Reference equality - let equal = match (&ctx.args[0], &ctx.args[1]) { + let equal = match (arg0, arg1) { (EmValue::ObjectRef(a), EmValue::ObjectRef(b)) => a.id() == b.id(), (EmValue::Null, EmValue::Null) => true, _ => false, @@ -636,18 +637,14 @@ fn runtime_helpers_prepare_noop_pre( fn object_equals_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { // Instance overload: this.Equals(other) if let Some(this) = ctx.this { - let other = if ctx.args.is_empty() { - &EmValue::Null - } else { - &ctx.args[0] - }; + let other = ctx.args.first().unwrap_or(&EmValue::Null); let equal = this.clr_equals(other); return PreHookResult::Bypass(Some(EmValue::I32(i32::from(equal)))); } // Static overload: Object.Equals(a, b) - if ctx.args.len() >= 2 { - let equal = ctx.args[0].clr_equals(&ctx.args[1]); + if let (Some(a), Some(b)) = (ctx.args.first(), ctx.args.get(1)) { + let equal = a.clr_equals(b); return PreHookResult::Bypass(Some(EmValue::I32(i32::from(equal)))); } @@ -821,11 +818,7 @@ fn valuetype_equals_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> return PreHookResult::Bypass(Some(EmValue::I32(0))); }; - let other = if ctx.args.is_empty() { - &EmValue::Null - } else { - &ctx.args[0] - }; + let other = ctx.args.first().unwrap_or(&EmValue::Null); // Fast path: direct EmValue comparison (unboxed value types on the stack) if this_val.clr_equals(other) { @@ -965,8 +958,8 @@ fn unsafe_byte_offset_pre(_ctx: &HookContext<'_>, _thread: &mut EmulationThread) /// /// Compares the two arguments for reference equality. fn unsafe_are_same_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() >= 2 { - let same = ctx.args[0].clr_equals(&ctx.args[1]); + if let (Some(a), Some(b)) = (ctx.args.first(), ctx.args.get(1)) { + let same = a.clr_equals(b); PreHookResult::Bypass(Some(EmValue::I32(i32::from(same)))) } else { PreHookResult::Bypass(Some(EmValue::I32(0))) diff --git a/dotscope/src/emulation/runtime/bcl/system/array.rs b/dotscope/src/emulation/runtime/bcl/system/array.rs index 92b4ec45..93c429ca 100644 --- a/dotscope/src/emulation/runtime/bcl/system/array.rs +++ b/dotscope/src/emulation/runtime/bcl/system/array.rs @@ -327,12 +327,12 @@ fn array_set_value_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> P return PreHookResult::Bypass(None); }; - if ctx.args.len() < 2 { + let (Some(value_arg), Some(index_arg)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(None); - } + }; - let value = ctx.args[0].clone(); - let index = usize::try_from(&ctx.args[1]).unwrap_or(0); + let value = value_arg.clone(); + let index = usize::try_from(index_arg).unwrap_or(0); try_hook!(thread.heap_mut().set_array_element(*href, index, value)); PreHookResult::Bypass(None) @@ -372,9 +372,10 @@ fn array_get_dimension_length_pre( let length = match obj { HeapObject::Array { elements, .. } if dimension == 0 => elements.len(), - HeapObject::MultiArray { dimensions, .. } if dimension < dimensions.len() => { - dimensions[dimension] - } + HeapObject::MultiArray { dimensions, .. } => match dimensions.get(dimension) { + Some(d) => *d, + None => return PreHookResult::Bypass(Some(EmValue::I32(0))), + }, _ => return PreHookResult::Bypass(Some(EmValue::I32(0))), }; @@ -467,28 +468,41 @@ fn array_copy_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHoo } let (src_ref, src_index, dst_ref, dst_index, length) = if ctx.args.len() >= 5 { - let src_ref = match &ctx.args[0] { + let (Some(a0), Some(a1), Some(a2), Some(a3), Some(a4)) = ( + ctx.args.first(), + ctx.args.get(1), + ctx.args.get(2), + ctx.args.get(3), + ctx.args.get(4), + ) else { + return PreHookResult::Bypass(None); + }; + let src_ref = match a0 { EmValue::ObjectRef(r) => *r, _ => return PreHookResult::Bypass(None), }; - let src_index = usize::try_from(&ctx.args[1]).unwrap_or(0); - let dst_ref = match &ctx.args[2] { + let src_index = usize::try_from(a1).unwrap_or(0); + let dst_ref = match a2 { EmValue::ObjectRef(r) => *r, _ => return PreHookResult::Bypass(None), }; - let dst_index = usize::try_from(&ctx.args[3]).unwrap_or(0); - let length = usize::try_from(&ctx.args[4]).unwrap_or(0); + let dst_index = usize::try_from(a3).unwrap_or(0); + let length = usize::try_from(a4).unwrap_or(0); (src_ref, src_index, dst_ref, dst_index, length) } else { - let src_ref = match &ctx.args[0] { + let (Some(a0), Some(a1), Some(a2)) = (ctx.args.first(), ctx.args.get(1), ctx.args.get(2)) + else { + return PreHookResult::Bypass(None); + }; + let src_ref = match a0 { EmValue::ObjectRef(r) => *r, _ => return PreHookResult::Bypass(None), }; - let dst_ref = match &ctx.args[1] { + let dst_ref = match a1 { EmValue::ObjectRef(r) => *r, _ => return PreHookResult::Bypass(None), }; - let length = usize::try_from(&ctx.args[2]).unwrap_or(0); + let length = usize::try_from(a2).unwrap_or(0); (src_ref, 0, dst_ref, 0, length) }; @@ -507,8 +521,13 @@ fn array_copy_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHoo try_hook!(thread.heap_mut().with_object_mut(dst_ref, |obj| match obj { HeapObject::Array { elements, .. } => { for (i, elem) in src_elements.into_iter().enumerate() { - if dst_index + i < elements.len() { - elements[dst_index + i] = elem; + let Some(target_index) = dst_index.checked_add(i) else { + return Err(crate::malformed_error!( + "Array.Copy: destination index overflow" + )); + }; + if let Some(slot) = elements.get_mut(target_index) { + *slot = elem; } } Ok(()) @@ -535,17 +554,18 @@ fn array_copy_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHoo /// /// None. Elements are set to the default value for their element type. fn array_clear_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 3 { + let (Some(a0), Some(a1), Some(a2)) = (ctx.args.first(), ctx.args.get(1), ctx.args.get(2)) + else { return PreHookResult::Bypass(None); - } + }; - let array_ref = match &ctx.args[0] { + let array_ref = match a0 { EmValue::ObjectRef(r) => *r, _ => return PreHookResult::Bypass(None), }; - let index = usize::try_from(&ctx.args[1]).unwrap_or(0); - let length = usize::try_from(&ctx.args[2]).unwrap_or(0); + let index = usize::try_from(a1).unwrap_or(0); + let length = usize::try_from(a2).unwrap_or(0); try_hook!(thread .heap_mut() @@ -555,8 +575,14 @@ fn array_clear_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHo element_type, } => { let default = EmValue::default_for_flavor(element_type); - for i in index..(index + length).min(elements.len()) { - elements[i] = default.clone(); + let end = index + .checked_add(length) + .ok_or_else(|| crate::malformed_error!("Array.Clear: index + length overflow"))? + .min(elements.len()); + for i in index..end { + if let Some(slot) = elements.get_mut(i) { + *slot = default.clone(); + } } Ok(()) } @@ -583,11 +609,11 @@ fn array_clear_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHo /// /// None. The array (or section) is reversed in-place. fn array_reverse_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(a0) = ctx.args.first() else { return PreHookResult::Bypass(None); - } + }; - let array_ref = match &ctx.args[0] { + let array_ref = match a0 { EmValue::ObjectRef(r) => *r, _ => return PreHookResult::Bypass(None), }; @@ -627,17 +653,15 @@ fn array_reverse_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> Pre /// /// The zero-based index of the first occurrence, or -1 if not found. fn array_index_of_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(a0), Some(search_value)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(Some(EmValue::I32(-1))); - } + }; - let array_ref = match &ctx.args[0] { + let array_ref = match a0 { EmValue::ObjectRef(r) => *r, _ => return PreHookResult::Bypass(Some(EmValue::I32(-1))), }; - let search_value = &ctx.args[1]; - let Ok(obj) = thread.heap().get(array_ref) else { return PreHookResult::Bypass(Some(EmValue::I32(-1))); }; @@ -675,24 +699,30 @@ fn array_index_of_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> Pr /// copying from an `int[]` with `srcOffset=2` means starting at byte 2 within /// the first integer element, not at the third integer. fn buffer_block_copy_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 5 { + let (Some(a0), Some(a1), Some(a2), Some(a3), Some(a4)) = ( + ctx.args.first(), + ctx.args.get(1), + ctx.args.get(2), + ctx.args.get(3), + ctx.args.get(4), + ) else { return PreHookResult::Bypass(None); - } + }; - let src_ref = match &ctx.args[0] { + let src_ref = match a0 { EmValue::ObjectRef(r) => *r, _ => return PreHookResult::Bypass(None), }; - let src_offset = usize::try_from(&ctx.args[1]).unwrap_or(0); + let src_offset = usize::try_from(a1).unwrap_or(0); - let dst_ref = match &ctx.args[2] { + let dst_ref = match a2 { EmValue::ObjectRef(r) => *r, _ => return PreHookResult::Bypass(None), }; - let dst_offset = usize::try_from(&ctx.args[3]).unwrap_or(0); - let count = usize::try_from(&ctx.args[4]).unwrap_or(0); + let dst_offset = usize::try_from(a3).unwrap_or(0); + let count = usize::try_from(a4).unwrap_or(0); if count == 0 { return PreHookResult::Bypass(None); @@ -724,11 +754,16 @@ fn buffer_block_copy_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> }; // Extract the byte range we need to copy - let end_offset = (src_offset + count).min(src_bytes.len()); + let Some(unbounded_end) = src_offset.checked_add(count) else { + return PreHookResult::Error("Buffer.BlockCopy: src_offset + count overflow".to_string()); + }; + let end_offset = unbounded_end.min(src_bytes.len()); if src_offset >= src_bytes.len() { return PreHookResult::Bypass(None); } - let bytes_to_copy: Vec = src_bytes[src_offset..end_offset].to_vec(); + let Some(bytes_to_copy) = src_bytes.get(src_offset..end_offset).map(<[u8]>::to_vec) else { + return PreHookResult::Bypass(None); + }; // Apply bytes to destination array try_hook!(thread.heap_mut().with_object_mut(dst_ref, |obj| match obj { @@ -739,7 +774,9 @@ fn buffer_block_copy_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> let Some(dst_elem_size) = element_type.element_size(ptr_size) else { return Ok(()); }; - let dst_byte_len = elements.len() * dst_elem_size; + let dst_byte_len = elements.len().checked_mul(dst_elem_size).ok_or_else(|| { + crate::malformed_error!("Buffer.BlockCopy: destination byte length overflow") + })?; if dst_offset >= dst_byte_len { return Ok(()); @@ -752,14 +789,25 @@ fn buffer_block_copy_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> .collect(); // Copy the bytes - let copy_len = bytes_to_copy.len().min(dst_byte_len - dst_offset); - dst_bytes[dst_offset..dst_offset + copy_len] - .copy_from_slice(&bytes_to_copy[..copy_len]); + let remaining = dst_byte_len.checked_sub(dst_offset).ok_or_else(|| { + crate::malformed_error!("Buffer.BlockCopy: dst_offset exceeds buffer") + })?; + let copy_len = bytes_to_copy.len().min(remaining); + let copy_end = dst_offset.checked_add(copy_len).ok_or_else(|| { + crate::malformed_error!("Buffer.BlockCopy: dst_offset + copy_len overflow") + })?; + let dst_slice = dst_bytes + .get_mut(dst_offset..copy_end) + .ok_or(crate::out_of_bounds_error!())?; + let src_slice = bytes_to_copy + .get(..copy_len) + .ok_or(crate::out_of_bounds_error!())?; + dst_slice.copy_from_slice(src_slice); // Convert bytes back to elements for (i, chunk) in dst_bytes.chunks(dst_elem_size).enumerate() { - if i < elements.len() { - elements[i] = EmValue::from_le_bytes(chunk, element_type, ptr_size); + if let Some(slot) = elements.get_mut(i) { + *slot = EmValue::from_le_bytes(chunk, element_type, ptr_size); } } @@ -785,11 +833,11 @@ fn buffer_block_copy_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> /// /// The total size of the array in bytes (element count * element size). fn buffer_byte_length_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(a0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::I32(0))); - } + }; - let array_ref = match &ctx.args[0] { + let array_ref = match a0 { EmValue::ObjectRef(r) => *r, _ => return PreHookResult::Bypass(Some(EmValue::I32(0))), }; @@ -810,7 +858,7 @@ fn buffer_byte_length_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) - CilFlavor::I1 | CilFlavor::U1 | CilFlavor::Boolean => 1, _ => std::mem::size_of::(), }; - elements.len() * element_size + elements.len().saturating_mul(element_size) } _ => return PreHookResult::Bypass(Some(EmValue::I32(0))), }; @@ -829,12 +877,12 @@ fn buffer_byte_length_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) - /// The element type is resolved from the `Type` argument via the emulation thread's /// type token resolution, which uses the assembly's metadata for accurate mapping. fn array_create_instance_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(type_arg), Some(length_arg)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Continue; - } + }; // arg[1] = int length (or int[] for multi-dimensional, which we only handle for 1D) - let length = match &ctx.args[1] { + let length = match length_arg { EmValue::I32(n) => *n as usize, EmValue::I64(n) => *n as usize, EmValue::NativeInt(n) => *n as usize, @@ -843,7 +891,7 @@ fn array_create_instance_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread // arg[0] = Type (ObjectRef to ReflectionType on the heap) // Resolve the element type through the proper metadata path - let element_flavor = match &ctx.args[0] { + let element_flavor = match type_arg { EmValue::ObjectRef(href) => { match thread.heap().get_reflection_type_token(*href) { Ok(Some(type_token)) => { diff --git a/dotscope/src/emulation/runtime/bcl/system/bitconverter.rs b/dotscope/src/emulation/runtime/bcl/system/bitconverter.rs index 371d6636..d0a8d373 100644 --- a/dotscope/src/emulation/runtime/bcl/system/bitconverter.rs +++ b/dotscope/src/emulation/runtime/bcl/system/bitconverter.rs @@ -76,11 +76,11 @@ fn bitconverter_get_bytes_pre( ctx: &HookContext<'_>, thread: &mut EmulationThread, ) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::Null)); - } + }; - let bytes = match &ctx.args[0] { + let bytes = match arg { EmValue::I32(v) => LeBytes::from_4(v.to_le_bytes()), EmValue::I64(v) | EmValue::NativeInt(v) => LeBytes::from_8(v.to_le_bytes()), EmValue::F32(v) => LeBytes::from_4(v.to_le_bytes()), @@ -109,29 +109,29 @@ fn bitconverter_get_bytes_pre( /// - `value`: Byte array containing the bytes to convert /// - `startIndex`: Starting position within the array fn bitconverter_to_int32_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(Some(EmValue::I32(0))); - } + }; - let Some(bytes) = (match &ctx.args[0] { + let Some(bytes) = (match arg0 { EmValue::ObjectRef(handle) => try_hook!(thread.heap().get_byte_array(*handle)), _ => return PreHookResult::Bypass(Some(EmValue::I32(0))), }) else { return PreHookResult::Bypass(Some(EmValue::I32(0))); }; - let start_index = usize::try_from(&ctx.args[1]).unwrap_or(usize::MAX); - if start_index.saturating_add(4) > bytes.len() { + let start_index = usize::try_from(arg1).unwrap_or(usize::MAX); + let Some(end_index) = start_index.checked_add(4) else { return PreHookResult::Bypass(Some(EmValue::I32(0))); - } - - let value = i32::from_le_bytes([ - bytes[start_index], - bytes[start_index + 1], - bytes[start_index + 2], - bytes[start_index + 3], - ]); + }; + let Some(slice) = bytes + .get(start_index..end_index) + .and_then(|s| <[u8; 4]>::try_from(s).ok()) + else { + return PreHookResult::Bypass(Some(EmValue::I32(0))); + }; + let value = i32::from_le_bytes(slice); PreHookResult::Bypass(Some(EmValue::I32(value))) } @@ -148,33 +148,29 @@ fn bitconverter_to_int32_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread /// - `value`: Byte array containing the bytes to convert /// - `startIndex`: Starting position within the array fn bitconverter_to_int64_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(Some(EmValue::I64(0))); - } + }; - let Some(bytes) = (match &ctx.args[0] { + let Some(bytes) = (match arg0 { EmValue::ObjectRef(handle) => try_hook!(thread.heap().get_byte_array(*handle)), _ => return PreHookResult::Bypass(Some(EmValue::I64(0))), }) else { return PreHookResult::Bypass(Some(EmValue::I64(0))); }; - let start_index = usize::try_from(&ctx.args[1]).unwrap_or(usize::MAX); - if start_index.saturating_add(8) > bytes.len() { + let start_index = usize::try_from(arg1).unwrap_or(usize::MAX); + let Some(end_index) = start_index.checked_add(8) else { return PreHookResult::Bypass(Some(EmValue::I64(0))); - } - - let value = i64::from_le_bytes([ - bytes[start_index], - bytes[start_index + 1], - bytes[start_index + 2], - bytes[start_index + 3], - bytes[start_index + 4], - bytes[start_index + 5], - bytes[start_index + 6], - bytes[start_index + 7], - ]); + }; + let Some(slice) = bytes + .get(start_index..end_index) + .and_then(|s| <[u8; 8]>::try_from(s).ok()) + else { + return PreHookResult::Bypass(Some(EmValue::I64(0))); + }; + let value = i64::from_le_bytes(slice); PreHookResult::Bypass(Some(EmValue::I64(value))) } @@ -194,28 +190,29 @@ fn bitconverter_to_uint32_pre( ctx: &HookContext<'_>, thread: &mut EmulationThread, ) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(Some(EmValue::I32(0))); - } + }; - let Some(bytes) = (match &ctx.args[0] { + let Some(bytes) = (match arg0 { EmValue::ObjectRef(handle) => try_hook!(thread.heap().get_byte_array(*handle)), _ => return PreHookResult::Bypass(Some(EmValue::I32(0))), }) else { return PreHookResult::Bypass(Some(EmValue::I32(0))); }; - let start_index = usize::try_from(&ctx.args[1]).unwrap_or(usize::MAX); - if start_index.saturating_add(4) > bytes.len() { + let start_index = usize::try_from(arg1).unwrap_or(usize::MAX); + let Some(end_index) = start_index.checked_add(4) else { return PreHookResult::Bypass(Some(EmValue::I32(0))); - } + }; + let Some(slice) = bytes + .get(start_index..end_index) + .and_then(|s| <[u8; 4]>::try_from(s).ok()) + else { + return PreHookResult::Bypass(Some(EmValue::I32(0))); + }; - let value = u32::from_le_bytes([ - bytes[start_index], - bytes[start_index + 1], - bytes[start_index + 2], - bytes[start_index + 3], - ]); + let value = u32::from_le_bytes(slice); #[allow(clippy::cast_possible_wrap)] let signed_value = value as i32; diff --git a/dotscope/src/emulation/runtime/bcl/system/convert.rs b/dotscope/src/emulation/runtime/bcl/system/convert.rs index 6c6270dd..62ca8313 100644 --- a/dotscope/src/emulation/runtime/bcl/system/convert.rs +++ b/dotscope/src/emulation/runtime/bcl/system/convert.rs @@ -217,11 +217,11 @@ pub fn register(manager: &HookManager) -> Result<()> { /// /// A Base64-encoded string representation of the byte array. fn to_base64_string_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::Null)); - } + }; - if let EmValue::ObjectRef(handle) = &ctx.args[0] { + if let EmValue::ObjectRef(handle) = arg0 { if let Some(bytes) = try_hook!(thread.heap().get_byte_array(*handle)) { let encoded = base64_encode(&bytes); match thread.heap_mut().alloc_string(&encoded) { @@ -252,11 +252,11 @@ fn to_base64_string_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> /// /// A byte array containing the decoded data, or `null` if decoding fails. fn from_base64_string_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::Null)); - } + }; - if let EmValue::ObjectRef(handle) = &ctx.args[0] { + if let EmValue::ObjectRef(handle) = arg0 { if let Ok(s) = thread.heap().get_string(*handle) { if let Some(decoded) = base64_decode(&s) { match thread.heap_mut().alloc_byte_array(&decoded) { @@ -298,10 +298,10 @@ fn from_base64_string_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) - /// /// The converted unsigned 8-bit integer value. fn to_byte_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(0_u8.into())); - } - PreHookResult::Bypass(Some(ctx.args[0].to_u8_cil().into())) + }; + PreHookResult::Bypass(Some(arg0.to_u8_cil().into())) } /// Hook for `System.Convert.ToSByte` method. @@ -331,10 +331,10 @@ fn to_byte_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookR /// /// The converted signed 8-bit integer value. fn to_sbyte_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(0_i8.into())); - } - PreHookResult::Bypass(Some(ctx.args[0].to_i8_cil().into())) + }; + PreHookResult::Bypass(Some(arg0.to_i8_cil().into())) } /// Hook for `System.Convert.ToInt16` method. @@ -364,10 +364,10 @@ fn to_sbyte_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHook /// /// The converted signed 16-bit integer value. fn to_int16_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(0_i16.into())); - } - PreHookResult::Bypass(Some(ctx.args[0].to_i16_cil().into())) + }; + PreHookResult::Bypass(Some(arg0.to_i16_cil().into())) } /// Hook for `System.Convert.ToUInt16` method. @@ -397,10 +397,10 @@ fn to_int16_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHook /// /// The converted unsigned 16-bit integer value. fn to_uint16_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(0_u16.into())); - } - PreHookResult::Bypass(Some(ctx.args[0].to_u16_cil().into())) + }; + PreHookResult::Bypass(Some(arg0.to_u16_cil().into())) } /// Hook for `System.Convert.ToInt32` method. @@ -434,11 +434,11 @@ fn to_uint16_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHoo /// /// The converted signed 32-bit integer value. fn to_int32_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(0_i32.into())); - } + }; - if let EmValue::ObjectRef(handle) = &ctx.args[0] { + if let EmValue::ObjectRef(handle) = arg0 { if let Ok(s) = thread.heap_mut().get_string(*handle) { if let Ok(n) = s.parse::() { return PreHookResult::Bypass(Some(n.into())); @@ -446,7 +446,7 @@ fn to_int32_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookR } return PreHookResult::Bypass(Some(0_i32.into())); } - PreHookResult::Bypass(Some(ctx.args[0].to_i32_cil().into())) + PreHookResult::Bypass(Some(arg0.to_i32_cil().into())) } /// Hook for `System.Convert.ToUInt32` method. @@ -478,13 +478,13 @@ fn to_int32_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookR /// /// The converted unsigned 32-bit integer value. fn to_uint32_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(0_i32.into())); - } + }; // Store as i32 with u32 bit pattern (CIL semantics) // Bit-cast from u32 to i32 preserves the bit pattern - wrapping is intentional #[allow(clippy::cast_possible_wrap)] - let value = ctx.args[0].to_u32_cil() as i32; + let value = arg0.to_u32_cil() as i32; PreHookResult::Bypass(Some(value.into())) } @@ -519,11 +519,11 @@ fn to_uint32_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHoo /// /// The converted signed 64-bit integer value. fn to_int64_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(0_i64.into())); - } + }; - if let EmValue::ObjectRef(handle) = &ctx.args[0] { + if let EmValue::ObjectRef(handle) = arg0 { if let Ok(s) = thread.heap_mut().get_string(*handle) { if let Ok(n) = s.parse::() { return PreHookResult::Bypass(Some(n.into())); @@ -531,7 +531,7 @@ fn to_int64_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookR } return PreHookResult::Bypass(Some(0_i64.into())); } - PreHookResult::Bypass(Some(ctx.args[0].to_i64_cil().into())) + PreHookResult::Bypass(Some(arg0.to_i64_cil().into())) } /// Hook for `System.Convert.ToUInt64` method. @@ -563,13 +563,13 @@ fn to_int64_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookR /// /// The converted unsigned 64-bit integer value. fn to_uint64_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(0_i64.into())); - } + }; // Store as i64 with u64 bit pattern (CIL semantics) // Bit-cast from u64 to i64 preserves the bit pattern - wrapping is intentional #[allow(clippy::cast_possible_wrap)] - let value = ctx.args[0].to_u64_cil() as i64; + let value = arg0.to_u64_cil() as i64; PreHookResult::Bypass(Some(value.into())) } @@ -599,11 +599,11 @@ fn to_uint64_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHoo /// /// The converted 32-bit floating-point value. fn to_single_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(0.0_f32.into())); - } + }; - if let EmValue::ObjectRef(handle) = &ctx.args[0] { + if let EmValue::ObjectRef(handle) = arg0 { if let Ok(s) = thread.heap_mut().get_string(*handle) { if let Ok(f) = s.parse::() { return PreHookResult::Bypass(Some(f.into())); @@ -611,7 +611,7 @@ fn to_single_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHook } return PreHookResult::Bypass(Some(0.0_f32.into())); } - PreHookResult::Bypass(Some(ctx.args[0].to_f32_cil().into())) + PreHookResult::Bypass(Some(arg0.to_f32_cil().into())) } /// Hook for `System.Convert.ToDouble` method. @@ -640,11 +640,11 @@ fn to_single_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHook /// /// The converted 64-bit floating-point value. fn to_double_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(0.0_f64.into())); - } + }; - if let EmValue::ObjectRef(handle) = &ctx.args[0] { + if let EmValue::ObjectRef(handle) = arg0 { if let Ok(s) = thread.heap_mut().get_string(*handle) { if let Ok(f) = s.parse::() { return PreHookResult::Bypass(Some(f.into())); @@ -652,7 +652,7 @@ fn to_double_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHook } return PreHookResult::Bypass(Some(0.0_f64.into())); } - PreHookResult::Bypass(Some(ctx.args[0].to_f64_cil().into())) + PreHookResult::Bypass(Some(arg0.to_f64_cil().into())) } /// Hook for `System.Convert.ToChar` method. @@ -678,11 +678,11 @@ fn to_double_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHook /// /// The converted Unicode character (stored as `UInt16`). fn to_char_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(0_u16.into())); - } + }; // Char is stored as u16 in I32 (CIL semantics) - PreHookResult::Bypass(Some(ctx.args[0].to_u16_cil().into())) + PreHookResult::Bypass(Some(arg0.to_u16_cil().into())) } /// Hook for `System.Convert.ToBoolean` method. @@ -713,12 +713,12 @@ fn to_char_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookR /// /// `true` if the value is non-zero, "true", or "1"; otherwise `false`. fn to_boolean_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(false.into())); - } + }; // String parsing needs special handling - if let EmValue::ObjectRef(handle) = &ctx.args[0] { + if let EmValue::ObjectRef(handle) = arg0 { if let Ok(s) = thread.heap().get_string(*handle) { return PreHookResult::Bypass(Some( (s.eq_ignore_ascii_case("true") || &*s == "1").into(), @@ -726,7 +726,7 @@ fn to_boolean_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHoo } return PreHookResult::Bypass(Some(false.into())); } - PreHookResult::Bypass(Some(ctx.args[0].to_bool_cil().into())) + PreHookResult::Bypass(Some(arg0.to_bool_cil().into())) } /// Hook for `System.Convert.ToString` method. @@ -759,14 +759,14 @@ fn to_boolean_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHoo /// /// The string representation of the value. fn to_string_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return match thread.heap_mut().alloc_string("") { Ok(handle) => PreHookResult::Bypass(Some(EmValue::ObjectRef(handle))), Err(e) => PreHookResult::Error(format!("heap allocation failed: {e}")), }; - } + }; - let s = match &ctx.args[0] { + let s = match arg0 { EmValue::I32(n) => n.to_string(), EmValue::I64(n) => n.to_string(), EmValue::F32(f) => f.to_string(), diff --git a/dotscope/src/emulation/runtime/bcl/system/datetime.rs b/dotscope/src/emulation/runtime/bcl/system/datetime.rs index 99654334..1776311f 100644 --- a/dotscope/src/emulation/runtime/bcl/system/datetime.rs +++ b/dotscope/src/emulation/runtime/bcl/system/datetime.rs @@ -295,8 +295,14 @@ fn extract_ticks_from_this(ctx: &HookContext<'_>, _thread: &EmulationThread) -> /// Converts a Unix timestamp (seconds since 1970-01-01 00:00:00 UTC) to .NET /// ticks (100-nanosecond intervals since 0001-01-01 00:00:00 UTC). +/// +/// Returns [`FALLBACK_TICKS`] if the multiplication or addition overflows +/// (would only happen for absurd timestamp values that don't represent real dates). fn unix_to_ticks(unix_timestamp: u32) -> i64 { - UNIX_EPOCH_TICKS + i64::from(unix_timestamp) * TICKS_PER_SECOND + i64::from(unix_timestamp) + .checked_mul(TICKS_PER_SECOND) + .and_then(|secs_ticks| UNIX_EPOCH_TICKS.checked_add(secs_ticks)) + .unwrap_or(FALLBACK_TICKS) } /// Returns the PE build timestamp as .NET ticks. @@ -324,13 +330,25 @@ fn is_leap_year(year: i32) -> bool { } /// Returns the number of days in the given month (1-based) of the given year. +/// +/// Returns 0 if `month` is out of the valid 1..=12 range (so callers using +/// this for validation will reject the date). fn days_in_month(year: i32, month: i32) -> i32 { let table = if is_leap_year(year) { &DAYS_TO_MONTH_366 } else { &DAYS_TO_MONTH_365 }; - (table[month as usize] - table[(month - 1) as usize]) as i32 + let Ok(m) = usize::try_from(month) else { + return 0; + }; + let Some(prev) = m.checked_sub(1) else { + return 0; + }; + let (Some(end), Some(start)) = (table.get(m), table.get(prev)) else { + return 0; + }; + end.saturating_sub(*start) as i32 } /// Computes .NET ticks from a Gregorian date (year, month, day). @@ -351,10 +369,25 @@ fn date_to_ticks(year: i32, month: i32, day: i32) -> i64 { } else { &DAYS_TO_MONTH_365 }; - let y = year - 1; - let total_days = - y * 365 + y / 4 - y / 100 + y / 400 + table[(month - 1) as usize] as i32 + day - 1; - i64::from(total_days) * TICKS_PER_DAY + // year is validated to 1..=9999, so y is 0..=9998 and all sub-products + // (y*365, y/4, y/100, y/400) are well within i32 range. + let y = year.saturating_sub(1); + let Ok(month_idx) = usize::try_from(month.saturating_sub(1)) else { + return 0; + }; + let Some(table_entry) = table.get(month_idx) else { + return 0; + }; + let table_days = *table_entry as i32; + let total_days = y + .saturating_mul(365) + .saturating_add(y / 4) + .saturating_sub(y / 100) + .saturating_add(y / 400) + .saturating_add(table_days) + .saturating_add(day) + .saturating_sub(1); + i64::from(total_days).saturating_mul(TICKS_PER_DAY) } /// Extracts the Gregorian (year, month, day) from a .NET ticks value. @@ -364,28 +397,36 @@ fn date_to_ticks(year: i32, month: i32, day: i32) -> i64 { /// Gregorian cycle (146,097 days per cycle). fn ticks_to_date(ticks: i64) -> (i32, i32, i32) { let ticks = ticks & 0x3FFF_FFFF_FFFF_FFFF; + // total_days for the supported DateTime range fits comfortably in i32. let total_days = (ticks / TICKS_PER_DAY) as i32; - // Compute year from total days using the 400-year cycle. + // Compute year from total days using the 400-year cycle. All intermediate + // values are bounded by total_days (<= ~3.65M for year 9999), well within + // i32 range, so saturating arithmetic is purely defensive. let y400 = total_days / 146_097; - let mut remaining = total_days - y400 * 146_097; + let mut remaining = total_days.saturating_sub(y400.saturating_mul(146_097)); let mut y100 = remaining / 36_524; if y100 == 4 { y100 = 3; } - remaining -= y100 * 36_524; + remaining = remaining.saturating_sub(y100.saturating_mul(36_524)); let y4 = remaining / 1_461; - remaining -= y4 * 1_461; + remaining = remaining.saturating_sub(y4.saturating_mul(1_461)); let mut y1 = remaining / 365; if y1 == 4 { y1 = 3; } - let year = y400 * 400 + y100 * 100 + y4 * 4 + y1 + 1; - remaining -= y1 * 365; + let year = y400 + .saturating_mul(400) + .saturating_add(y100.saturating_mul(100)) + .saturating_add(y4.saturating_mul(4)) + .saturating_add(y1) + .saturating_add(1); + remaining = remaining.saturating_sub(y1.saturating_mul(365)); let table = if is_leap_year(year) { &DAYS_TO_MONTH_366 @@ -393,11 +434,12 @@ fn ticks_to_date(ticks: i64) -> (i32, i32, i32) { &DAYS_TO_MONTH_365 }; - let mut month = 1; - while month < 12 && remaining >= table[month] as i32 { - month += 1; + let mut month: usize = 1; + while month < 12 && table.get(month).is_some_and(|&t| remaining >= t as i32) { + month = month.saturating_add(1); } - let day = remaining - table[month - 1] as i32 + 1; + let prev_month_days = table.get(month.saturating_sub(1)).copied().unwrap_or(0) as i32; + let day = remaining.saturating_sub(prev_month_days).saturating_add(1); (year, month as i32, day) } @@ -413,17 +455,17 @@ fn datetime_ctor_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> Pr // .ctor(long ticks) — `this` is arg[0] for value type ctors, ticks is arg[1] // or ticks is the only arg if this is passed separately 1 => { - if let Some(ticks) = extract_ticks(&ctx.args[0]) { + if let Some(ticks) = ctx.args.first().and_then(extract_ticks) { return PreHookResult::Bypass(Some(EmValue::I64(ticks))); } PreHookResult::Continue } 2 => { // Could be .ctor(this, ticks) or .ctor(ticks, DateTimeKind) - if let Some(ticks) = extract_ticks(&ctx.args[0]) { + if let Some(ticks) = ctx.args.first().and_then(extract_ticks) { return PreHookResult::Bypass(Some(EmValue::I64(ticks))); } - if let Some(ticks) = extract_ticks(&ctx.args[1]) { + if let Some(ticks) = ctx.args.get(1).and_then(extract_ticks) { return PreHookResult::Bypass(Some(EmValue::I64(ticks))); } PreHookResult::Continue @@ -433,15 +475,15 @@ fn datetime_ctor_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> Pr let (y, m, d) = if n >= 4 { // this, year, month, day ( - extract_i32(&ctx.args[1]), - extract_i32(&ctx.args[2]), - extract_i32(&ctx.args[3]), + ctx.args.get(1).and_then(extract_i32), + ctx.args.get(2).and_then(extract_i32), + ctx.args.get(3).and_then(extract_i32), ) } else { ( - extract_i32(&ctx.args[0]), - extract_i32(&ctx.args[1]), - extract_i32(&ctx.args[2]), + ctx.args.first().and_then(extract_i32), + ctx.args.get(1).and_then(extract_i32), + ctx.args.get(2).and_then(extract_i32), ) }; if let (Some(year), Some(month), Some(day)) = (y, m, d) { @@ -530,7 +572,10 @@ fn datetime_add_f64( }) .unwrap_or(0.0); let delta = (amount * ticks_per_unit as f64) as i64; - PreHookResult::Bypass(Some(EmValue::I64(this_ticks + delta))) + // .NET DateTime arithmetic throws OverflowException on tick overflow; we + // saturate here to keep the hook fast-path usable instead of aborting. + let result = this_ticks.saturating_add(delta); + PreHookResult::Bypass(Some(EmValue::I64(result))) } /// Hook for `DateTime.AddDays(double)`. @@ -557,7 +602,7 @@ fn datetime_add_seconds_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) fn datetime_add_ticks_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { let this_ticks = extract_ticks_from_this(ctx, thread).unwrap_or(0); let delta = ctx.args.last().and_then(extract_ticks).unwrap_or(0); - PreHookResult::Bypass(Some(EmValue::I64(this_ticks + delta))) + PreHookResult::Bypass(Some(EmValue::I64(this_ticks.saturating_add(delta)))) } /// Hook for `DateTime.op_Subtraction(DateTime, DateTime) -> TimeSpan`. @@ -568,13 +613,12 @@ fn datetime_op_subtraction_pre( ctx: &HookContext<'_>, _thread: &mut EmulationThread, ) -> PreHookResult { - if ctx.args.len() >= 2 { - let a = extract_ticks(&ctx.args[0]).unwrap_or(0); - let b = extract_ticks(&ctx.args[1]).unwrap_or(0); - PreHookResult::Bypass(Some(EmValue::I64(a - b))) - } else { - PreHookResult::Bypass(Some(EmValue::I64(0))) - } + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { + return PreHookResult::Bypass(Some(EmValue::I64(0))); + }; + let a = extract_ticks(arg0).unwrap_or(0); + let b = extract_ticks(arg1).unwrap_or(0); + PreHookResult::Bypass(Some(EmValue::I64(a.saturating_sub(b)))) } /// Shared implementation for DateTime comparison operators. @@ -582,13 +626,12 @@ fn datetime_op_subtraction_pre( /// Extracts ticks from both arguments, applies the given comparison function, /// and returns `I32(1)` for true or `I32(0)` for false. fn datetime_cmp(ctx: &HookContext<'_>, op: fn(i64, i64) -> bool) -> PreHookResult { - if ctx.args.len() >= 2 { - let a = extract_ticks(&ctx.args[0]).unwrap_or(0); - let b = extract_ticks(&ctx.args[1]).unwrap_or(0); - PreHookResult::Bypass(Some(EmValue::I32(i32::from(op(a, b))))) - } else { - PreHookResult::Bypass(Some(EmValue::I32(0))) - } + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { + return PreHookResult::Bypass(Some(EmValue::I32(0))); + }; + let a = extract_ticks(arg0).unwrap_or(0); + let b = extract_ticks(arg1).unwrap_or(0); + PreHookResult::Bypass(Some(EmValue::I32(i32::from(op(a, b))))) } /// Hook for `DateTime.op_GreaterThan`. diff --git a/dotscope/src/emulation/runtime/bcl/system/environment.rs b/dotscope/src/emulation/runtime/bcl/system/environment.rs index caada9c7..d2a6b177 100644 --- a/dotscope/src/emulation/runtime/bcl/system/environment.rs +++ b/dotscope/src/emulation/runtime/bcl/system/environment.rs @@ -279,7 +279,12 @@ fn get_tick_count_pre(_ctx: &HookContext<'_>, thread: &mut EmulationThread) -> P let config = &thread.config().environment; let divisor = config.tick_count_divisor.max(1); // avoid division by zero #[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)] - let ticks = config.tick_count_base + (thread.instructions_executed() / divisor) as i64; + let ticks = config.tick_count_base.saturating_add( + thread + .instructions_executed() + .checked_div(divisor) + .unwrap_or(0) as i64, + ); PreHookResult::Bypass(Some(EmValue::I32(ticks as i32))) } @@ -297,7 +302,12 @@ fn get_tick_count64_pre(_ctx: &HookContext<'_>, thread: &mut EmulationThread) -> let config = &thread.config().environment; let divisor = config.tick_count_divisor.max(1); // avoid division by zero #[allow(clippy::cast_possible_wrap, clippy::cast_possible_truncation)] - let ticks = config.tick_count_base + (thread.instructions_executed() / divisor) as i64; + let ticks = config.tick_count_base.saturating_add( + thread + .instructions_executed() + .checked_div(divisor) + .unwrap_or(0) as i64, + ); PreHookResult::Bypass(Some(EmValue::I64(ticks))) } diff --git a/dotscope/src/emulation/runtime/bcl/system/math.rs b/dotscope/src/emulation/runtime/bcl/system/math.rs index 2ff0f282..fb0fcbdb 100644 --- a/dotscope/src/emulation/runtime/bcl/system/math.rs +++ b/dotscope/src/emulation/runtime/bcl/system/math.rs @@ -384,11 +384,11 @@ pub fn register(manager: &HookManager) -> Result<()> { /// /// - `value`: The number whose absolute value is to be found fn system_math_abs_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(0_i32.into())); - } + }; - let result = match &ctx.args[0] { + let result = match arg0 { EmValue::I32(n) => n.abs().into(), EmValue::I64(n) => n.abs().into(), EmValue::F32(f) => f.abs().into(), @@ -422,16 +422,16 @@ fn system_math_abs_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> /// - `val1`: The first of two values to compare /// - `val2`: The second of two values to compare fn system_math_min_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(ctx.args.first().cloned()); - } + }; - let result = match (&ctx.args[0], &ctx.args[1]) { + let result = match (arg0, arg1) { (EmValue::I32(a), EmValue::I32(b)) => (*a.min(b)).into(), (EmValue::I64(a), EmValue::I64(b)) => (*a.min(b)).into(), (EmValue::F32(a), EmValue::F32(b)) => a.min(*b).into(), (EmValue::F64(a), EmValue::F64(b)) => a.min(*b).into(), - _ => ctx.args[0].clone(), + _ => arg0.clone(), }; PreHookResult::Bypass(Some(result)) @@ -460,16 +460,16 @@ fn system_math_min_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> /// - `val1`: The first of two values to compare /// - `val2`: The second of two values to compare fn system_math_max_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(ctx.args.first().cloned()); - } + }; - let result = match (&ctx.args[0], &ctx.args[1]) { + let result = match (arg0, arg1) { (EmValue::I32(a), EmValue::I32(b)) => EmValue::I32(*a.max(b)), (EmValue::I64(a), EmValue::I64(b)) => EmValue::I64(*a.max(b)), (EmValue::F32(a), EmValue::F32(b)) => EmValue::F32(a.max(*b)), (EmValue::F64(a), EmValue::F64(b)) => EmValue::F64(a.max(*b)), - _ => ctx.args[0].clone(), + _ => arg0.clone(), }; PreHookResult::Bypass(Some(result)) @@ -493,11 +493,11 @@ fn system_math_max_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> /// /// - `value`: A signed number fn system_math_sign_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::I32(0))); - } + }; - let result = match &ctx.args[0] { + let result = match arg0 { EmValue::I32(n) => n.signum(), EmValue::I64(n) => n.signum() as i32, EmValue::F32(f) => { @@ -552,11 +552,12 @@ fn system_math_sign_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> /// - `min`: The lower bound of the result /// - `max`: The upper bound of the result fn system_math_clamp_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 3 { + let (Some(arg0), Some(arg1), Some(arg2)) = (ctx.args.first(), ctx.args.get(1), ctx.args.get(2)) + else { return PreHookResult::Bypass(ctx.args.first().cloned()); - } + }; - let result = match (&ctx.args[0], &ctx.args[1], &ctx.args[2]) { + let result = match (arg0, arg1, arg2) { (EmValue::I32(val), EmValue::I32(min), EmValue::I32(max)) => { EmValue::I32(*val.max(min).min(max)) } @@ -569,7 +570,7 @@ fn system_math_clamp_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) - (EmValue::F64(val), EmValue::F64(min), EmValue::F64(max)) => { EmValue::F64(val.max(*min).min(*max)) } - _ => ctx.args[0].clone(), + _ => arg0.clone(), }; PreHookResult::Bypass(Some(result)) @@ -592,31 +593,29 @@ fn system_math_clamp_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) - /// - `b`: The divisor /// - `result` (out): The remainder (for legacy overloads) fn system_math_divrem_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(Some(EmValue::I32(0))); - } + }; - let (quotient, remainder) = match (&ctx.args[0], &ctx.args[1]) { - (EmValue::I32(a), EmValue::I32(b)) => { - if *b == 0 { - (EmValue::I32(0), EmValue::I32(0)) - } else { - (EmValue::I32(a / b), EmValue::I32(a % b)) - } - } - (EmValue::I64(a), EmValue::I64(b)) => { - if *b == 0 { - (EmValue::I64(0), EmValue::I64(0)) - } else { - (EmValue::I64(a / b), EmValue::I64(a % b)) - } - } + // Use checked arithmetic: division by zero or i32::MIN/-1 overflow falls + // back to (0, 0). The hook intentionally degrades gracefully rather than + // raising DivideByZeroException/OverflowException, matching the existing + // lenient style used by other Math hooks here. + let (quotient, remainder) = match (arg0, arg1) { + (EmValue::I32(a), EmValue::I32(b)) => match (a.checked_div(*b), a.checked_rem(*b)) { + (Some(q), Some(r)) => (EmValue::I32(q), EmValue::I32(r)), + _ => (EmValue::I32(0), EmValue::I32(0)), + }, + (EmValue::I64(a), EmValue::I64(b)) => match (a.checked_div(*b), a.checked_rem(*b)) { + (Some(q), Some(r)) => (EmValue::I64(q), EmValue::I64(r)), + _ => (EmValue::I64(0), EmValue::I64(0)), + }, _ => (EmValue::I32(0), EmValue::I32(0)), }; // Store remainder through the out parameter if provided - if ctx.args.len() >= 3 { - if let Some(ptr) = ctx.args[2].as_managed_ptr() { + if let Some(arg2) = ctx.args.get(2) { + if let Some(ptr) = arg2.as_managed_ptr() { try_hook!(thread.store_through_pointer(ptr, remainder)); } } @@ -637,11 +636,11 @@ fn system_math_divrem_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) - /// /// - `d`: A double-precision floating-point number fn system_math_floor_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F64(0.0))); - } + }; - let result = match &ctx.args[0] { + let result = match arg0 { EmValue::F32(f) => EmValue::F64(f64::from(f.floor())), EmValue::F64(f) => EmValue::F64(f.floor()), _ => EmValue::F64(0.0), @@ -663,11 +662,11 @@ fn system_math_floor_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) - /// /// - `d`: A double-precision floating-point number fn system_math_ceiling_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F64(0.0))); - } + }; - let result = match &ctx.args[0] { + let result = match arg0 { EmValue::F32(f) => EmValue::F64(f64::from(f.ceil())), EmValue::F64(f) => EmValue::F64(f.ceil()), _ => EmValue::F64(0.0), @@ -697,23 +696,19 @@ fn system_math_ceiling_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) /// - `digits`: The number of fractional digits in the return value (optional) /// - `mode`: Specification for how to round value if it is midway (optional) fn system_math_round_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F64(0.0))); - } + }; - let value = match &ctx.args[0] { + let value = match arg0 { EmValue::F32(f) => f64::from(*f), EmValue::F64(f) => *f, _ => return PreHookResult::Bypass(Some(EmValue::F64(0.0))), }; - let decimals = if ctx.args.len() > 1 { - match &ctx.args[1] { - EmValue::I32(n) => (*n).clamp(0, 15), - _ => 0, - } - } else { - 0 + let decimals = match ctx.args.get(1) { + Some(EmValue::I32(n)) => (*n).clamp(0, 15), + _ => 0, }; let multiplier = 10_f64.powi(decimals); @@ -735,11 +730,11 @@ fn system_math_round_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) - /// /// - `d`: A number to truncate fn system_math_truncate_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F64(0.0))); - } + }; - let result = match &ctx.args[0] { + let result = match arg0 { EmValue::F32(f) => EmValue::F64(f64::from(f.trunc())), EmValue::F64(f) => EmValue::F64(f.trunc()), _ => EmValue::F64(0.0), @@ -761,12 +756,12 @@ fn system_math_truncate_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread /// - `x`: A double-precision floating-point number to be raised to a power /// - `y`: A double-precision floating-point number that specifies a power fn system_math_pow_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(Some(EmValue::F64(0.0))); - } + }; - let base = ctx.args[0].to_f64_cil(); - let exp = ctx.args[1].to_f64_cil(); + let base = arg0.to_f64_cil(); + let exp = arg1.to_f64_cil(); PreHookResult::Bypass(Some(EmValue::F64(base.powf(exp)))) } @@ -783,11 +778,11 @@ fn system_math_pow_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> /// /// - `d`: The number whose square root is to be found fn system_math_sqrt_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F64(0.0))); - } + }; - let value = ctx.args[0].to_f64_cil(); + let value = arg0.to_f64_cil(); PreHookResult::Bypass(Some(EmValue::F64(value.sqrt()))) } @@ -805,14 +800,14 @@ fn system_math_sqrt_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> /// - `d`: The number whose logarithm is to be found /// - `newBase`: The base of the logarithm (optional, defaults to e) fn system_math_log_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F64(f64::NEG_INFINITY))); - } + }; - let value = ctx.args[0].to_f64_cil(); + let value = arg0.to_f64_cil(); - let result = if ctx.args.len() > 1 { - let base = ctx.args[1].to_f64_cil(); + let result = if let Some(arg1) = ctx.args.get(1) { + let base = arg1.to_f64_cil(); value.log(base) } else { value.ln() @@ -833,11 +828,11 @@ fn system_math_log_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> /// /// - `d`: A number whose logarithm is to be found fn system_math_log10_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F64(f64::NEG_INFINITY))); - } + }; - let value = ctx.args[0].to_f64_cil(); + let value = arg0.to_f64_cil(); PreHookResult::Bypass(Some(EmValue::F64(value.log10()))) } @@ -853,11 +848,11 @@ fn system_math_log10_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) - /// /// - `d`: A number specifying a power fn system_math_exp_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F64(1.0))); - } + }; - let value = ctx.args[0].to_f64_cil(); + let value = arg0.to_f64_cil(); PreHookResult::Bypass(Some(EmValue::F64(value.exp()))) } @@ -873,11 +868,11 @@ fn system_math_exp_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> /// /// - `a`: An angle, measured in radians fn system_math_sin_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F64(0.0))); - } + }; - let value = ctx.args[0].to_f64_cil(); + let value = arg0.to_f64_cil(); PreHookResult::Bypass(Some(EmValue::F64(value.sin()))) } @@ -893,11 +888,11 @@ fn system_math_sin_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> /// /// - `d`: An angle, measured in radians fn system_math_cos_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F64(1.0))); - } + }; - let value = ctx.args[0].to_f64_cil(); + let value = arg0.to_f64_cil(); PreHookResult::Bypass(Some(EmValue::F64(value.cos()))) } @@ -913,11 +908,11 @@ fn system_math_cos_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> /// /// - `a`: An angle, measured in radians fn system_math_tan_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F64(0.0))); - } + }; - let value = ctx.args[0].to_f64_cil(); + let value = arg0.to_f64_cil(); PreHookResult::Bypass(Some(EmValue::F64(value.tan()))) } @@ -933,11 +928,11 @@ fn system_math_tan_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> /// /// - `d`: A number representing a sine, where d must be >= -1 and <= 1 fn system_math_asin_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F64(0.0))); - } + }; - let value = ctx.args[0].to_f64_cil(); + let value = arg0.to_f64_cil(); PreHookResult::Bypass(Some(EmValue::F64(value.asin()))) } @@ -953,11 +948,11 @@ fn system_math_asin_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> /// /// - `d`: A number representing a cosine, where d must be >= -1 and <= 1 fn system_math_acos_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F64(std::f64::consts::FRAC_PI_2))); - } + }; - let value = ctx.args[0].to_f64_cil(); + let value = arg0.to_f64_cil(); PreHookResult::Bypass(Some(EmValue::F64(value.acos()))) } @@ -973,11 +968,11 @@ fn system_math_acos_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> /// /// - `d`: A number representing a tangent fn system_math_atan_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F64(0.0))); - } + }; - let value = ctx.args[0].to_f64_cil(); + let value = arg0.to_f64_cil(); PreHookResult::Bypass(Some(EmValue::F64(value.atan()))) } @@ -994,12 +989,12 @@ fn system_math_atan_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> /// - `y`: The y coordinate of a point /// - `x`: The x coordinate of a point fn system_math_atan2_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(Some(EmValue::F64(0.0))); - } + }; - let y = ctx.args[0].to_f64_cil(); - let x = ctx.args[1].to_f64_cil(); + let y = arg0.to_f64_cil(); + let x = arg1.to_f64_cil(); PreHookResult::Bypass(Some(EmValue::F64(y.atan2(x)))) } @@ -1026,11 +1021,11 @@ fn system_numerics_bitoperations_popcount_pre( ctx: &HookContext<'_>, _thread: &mut EmulationThread, ) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::I32(0))); - } + }; - let count = match &ctx.args[0] { + let count = match arg0 { EmValue::I32(n) => (*n as u32).count_ones() as i32, EmValue::I64(n) => (*n as u64).count_ones() as i32, _ => 0, @@ -1061,11 +1056,11 @@ fn system_numerics_bitoperations_leadingzerocount_pre( ctx: &HookContext<'_>, _thread: &mut EmulationThread, ) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::I32(32))); - } + }; - let count = match &ctx.args[0] { + let count = match arg0 { EmValue::I32(n) => (*n as u32).leading_zeros() as i32, EmValue::I64(n) => (*n as u64).leading_zeros() as i32, _ => 32, @@ -1099,11 +1094,11 @@ fn system_numerics_bitoperations_trailingzerocount_pre( ctx: &HookContext<'_>, _thread: &mut EmulationThread, ) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::I32(32))); - } + }; - let count = match &ctx.args[0] { + let count = match arg0 { EmValue::I32(n) => (*n as u32).trailing_zeros() as i32, EmValue::I64(n) => (*n as u64).trailing_zeros() as i32, _ => 32, @@ -1131,19 +1126,19 @@ fn system_numerics_bitoperations_rotateleft_pre( ctx: &HookContext<'_>, _thread: &mut EmulationThread, ) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(ctx.args.first().cloned()); - } + }; - let shift = match &ctx.args[1] { + let shift = match arg1 { EmValue::I32(n) => *n as u32, - _ => return PreHookResult::Bypass(Some(ctx.args[0].clone())), + _ => return PreHookResult::Bypass(Some(arg0.clone())), }; - let result = match &ctx.args[0] { + let result = match arg0 { EmValue::I32(n) => EmValue::I32((*n as u32).rotate_left(shift) as i32), EmValue::I64(n) => EmValue::I64((*n as u64).rotate_left(shift) as i64), - _ => ctx.args[0].clone(), + _ => arg0.clone(), }; PreHookResult::Bypass(Some(result)) @@ -1168,19 +1163,19 @@ fn system_numerics_bitoperations_rotateright_pre( ctx: &HookContext<'_>, _thread: &mut EmulationThread, ) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(ctx.args.first().cloned()); - } + }; - let shift = match &ctx.args[1] { + let shift = match arg1 { EmValue::I32(n) => *n as u32, - _ => return PreHookResult::Bypass(Some(ctx.args[0].clone())), + _ => return PreHookResult::Bypass(Some(arg0.clone())), }; - let result = match &ctx.args[0] { + let result = match arg0 { EmValue::I32(n) => EmValue::I32((*n as u32).rotate_right(shift) as i32), EmValue::I64(n) => EmValue::I64((*n as u64).rotate_right(shift) as i64), - _ => ctx.args[0].clone(), + _ => arg0.clone(), }; PreHookResult::Bypass(Some(result)) @@ -1198,108 +1193,108 @@ fn system_numerics_bitoperations_rotateright_pre( /// /// - `x`: The number whose base-2 logarithm is to be found fn system_math_log2_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F64(f64::NEG_INFINITY))); - } + }; - let value = ctx.args[0].to_f64_cil(); + let value = arg0.to_f64_cil(); PreHookResult::Bypass(Some(EmValue::F64(value.log2()))) } /// Hook for `System.MathF.Abs` — absolute value (single-precision). fn system_mathf_abs_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F32(0.0))); - } - PreHookResult::Bypass(Some(EmValue::F32(ctx.args[0].to_f32_cil().abs()))) + }; + PreHookResult::Bypass(Some(EmValue::F32(arg0.to_f32_cil().abs()))) } /// Hook for `System.MathF.Sin` — sine (single-precision). fn system_mathf_sin_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F32(0.0))); - } - PreHookResult::Bypass(Some(EmValue::F32(ctx.args[0].to_f32_cil().sin()))) + }; + PreHookResult::Bypass(Some(EmValue::F32(arg0.to_f32_cil().sin()))) } /// Hook for `System.MathF.Cos` — cosine (single-precision). fn system_mathf_cos_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F32(1.0))); - } - PreHookResult::Bypass(Some(EmValue::F32(ctx.args[0].to_f32_cil().cos()))) + }; + PreHookResult::Bypass(Some(EmValue::F32(arg0.to_f32_cil().cos()))) } /// Hook for `System.MathF.Sqrt` — square root (single-precision). fn system_mathf_sqrt_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F32(0.0))); - } - PreHookResult::Bypass(Some(EmValue::F32(ctx.args[0].to_f32_cil().sqrt()))) + }; + PreHookResult::Bypass(Some(EmValue::F32(arg0.to_f32_cil().sqrt()))) } /// Hook for `System.MathF.Floor` — round down (single-precision). fn system_mathf_floor_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F32(0.0))); - } - PreHookResult::Bypass(Some(EmValue::F32(ctx.args[0].to_f32_cil().floor()))) + }; + PreHookResult::Bypass(Some(EmValue::F32(arg0.to_f32_cil().floor()))) } /// Hook for `System.MathF.Ceiling` — round up (single-precision). fn system_mathf_ceiling_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F32(0.0))); - } - PreHookResult::Bypass(Some(EmValue::F32(ctx.args[0].to_f32_cil().ceil()))) + }; + PreHookResult::Bypass(Some(EmValue::F32(arg0.to_f32_cil().ceil()))) } /// Hook for `System.MathF.Round` — round to nearest (single-precision). fn system_mathf_round_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F32(0.0))); - } - PreHookResult::Bypass(Some(EmValue::F32(ctx.args[0].to_f32_cil().round()))) + }; + PreHookResult::Bypass(Some(EmValue::F32(arg0.to_f32_cil().round()))) } /// Hook for `System.MathF.Min` — minimum of two values (single-precision). fn system_mathf_min_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(ctx.args.first().cloned()); - } - let a = ctx.args[0].to_f32_cil(); - let b = ctx.args[1].to_f32_cil(); + }; + let a = arg0.to_f32_cil(); + let b = arg1.to_f32_cil(); PreHookResult::Bypass(Some(EmValue::F32(a.min(b)))) } /// Hook for `System.MathF.Max` — maximum of two values (single-precision). fn system_mathf_max_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(ctx.args.first().cloned()); - } - let a = ctx.args[0].to_f32_cil(); - let b = ctx.args[1].to_f32_cil(); + }; + let a = arg0.to_f32_cil(); + let b = arg1.to_f32_cil(); PreHookResult::Bypass(Some(EmValue::F32(a.max(b)))) } /// Hook for `System.MathF.Pow` — power (single-precision). fn system_mathf_pow_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(Some(EmValue::F32(0.0))); - } - let base = ctx.args[0].to_f32_cil(); - let exp = ctx.args[1].to_f32_cil(); + }; + let base = arg0.to_f32_cil(); + let exp = arg1.to_f32_cil(); PreHookResult::Bypass(Some(EmValue::F32(base.powf(exp)))) } /// Hook for `System.MathF.Log` — natural or custom base logarithm (single-precision). fn system_mathf_log_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F32(f32::NEG_INFINITY))); - } - let value = ctx.args[0].to_f32_cil(); - let result = if ctx.args.len() > 1 { - let base = ctx.args[1].to_f32_cil(); + }; + let value = arg0.to_f32_cil(); + let result = if let Some(arg1) = ctx.args.get(1) { + let base = arg1.to_f32_cil(); value.log(base) } else { value.ln() @@ -1309,68 +1304,68 @@ fn system_mathf_log_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> /// Hook for `System.MathF.Log2` — base-2 logarithm (single-precision). fn system_mathf_log2_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F32(f32::NEG_INFINITY))); - } - PreHookResult::Bypass(Some(EmValue::F32(ctx.args[0].to_f32_cil().log2()))) + }; + PreHookResult::Bypass(Some(EmValue::F32(arg0.to_f32_cil().log2()))) } /// Hook for `System.MathF.Log10` — base-10 logarithm (single-precision). fn system_mathf_log10_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F32(f32::NEG_INFINITY))); - } - PreHookResult::Bypass(Some(EmValue::F32(ctx.args[0].to_f32_cil().log10()))) + }; + PreHookResult::Bypass(Some(EmValue::F32(arg0.to_f32_cil().log10()))) } /// Hook for `System.MathF.Tan` — tangent (single-precision). fn system_mathf_tan_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F32(0.0))); - } - PreHookResult::Bypass(Some(EmValue::F32(ctx.args[0].to_f32_cil().tan()))) + }; + PreHookResult::Bypass(Some(EmValue::F32(arg0.to_f32_cil().tan()))) } /// Hook for `System.MathF.Asin` — inverse sine (single-precision). fn system_mathf_asin_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F32(0.0))); - } - PreHookResult::Bypass(Some(EmValue::F32(ctx.args[0].to_f32_cil().asin()))) + }; + PreHookResult::Bypass(Some(EmValue::F32(arg0.to_f32_cil().asin()))) } /// Hook for `System.MathF.Acos` — inverse cosine (single-precision). fn system_mathf_acos_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F32(std::f32::consts::FRAC_PI_2))); - } - PreHookResult::Bypass(Some(EmValue::F32(ctx.args[0].to_f32_cil().acos()))) + }; + PreHookResult::Bypass(Some(EmValue::F32(arg0.to_f32_cil().acos()))) } /// Hook for `System.MathF.Atan` — inverse tangent (single-precision). fn system_mathf_atan_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F32(0.0))); - } - PreHookResult::Bypass(Some(EmValue::F32(ctx.args[0].to_f32_cil().atan()))) + }; + PreHookResult::Bypass(Some(EmValue::F32(arg0.to_f32_cil().atan()))) } /// Hook for `System.MathF.Atan2` — two-argument arctangent (single-precision). fn system_mathf_atan2_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(Some(EmValue::F32(0.0))); - } - let y = ctx.args[0].to_f32_cil(); - let x = ctx.args[1].to_f32_cil(); + }; + let y = arg0.to_f32_cil(); + let x = arg1.to_f32_cil(); PreHookResult::Bypass(Some(EmValue::F32(y.atan2(x)))) } /// Hook for `System.MathF.Exp` — e raised to a power (single-precision). fn system_mathf_exp_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F32(1.0))); - } - PreHookResult::Bypass(Some(EmValue::F32(ctx.args[0].to_f32_cil().exp()))) + }; + PreHookResult::Bypass(Some(EmValue::F32(arg0.to_f32_cil().exp()))) } /// Hook for `System.MathF.Truncate` — remove fractional part (single-precision). @@ -1378,19 +1373,19 @@ fn system_mathf_truncate_pre( ctx: &HookContext<'_>, _thread: &mut EmulationThread, ) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::F32(0.0))); - } - PreHookResult::Bypass(Some(EmValue::F32(ctx.args[0].to_f32_cil().trunc()))) + }; + PreHookResult::Bypass(Some(EmValue::F32(arg0.to_f32_cil().trunc()))) } /// Hook for `System.MathF.Sign` — returns -1, 0, or 1 (single-precision input). #[allow(clippy::cast_possible_truncation)] fn system_mathf_sign_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(arg0) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::I32(0))); - } - let f = ctx.args[0].to_f32_cil(); + }; + let f = arg0.to_f32_cil(); let result = if f.is_nan() { 0 } else if f > 0.0 { diff --git a/dotscope/src/emulation/runtime/bcl/system/string.rs b/dotscope/src/emulation/runtime/bcl/system/string.rs index e74208eb..abf15f95 100644 --- a/dotscope/src/emulation/runtime/bcl/system/string.rs +++ b/dotscope/src/emulation/runtime/bcl/system/string.rs @@ -275,8 +275,8 @@ fn string_substring_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> .and_then(|r| r.ok()) .unwrap_or(0); - let substring: String = if ctx.args.len() > 1 { - let length = usize::try_from(&ctx.args[1]).unwrap_or(0); + let substring: String = if let Some(len_arg) = ctx.args.get(1) { + let length = usize::try_from(len_arg).unwrap_or(0); s.chars().skip(start).take(length).collect() } else { s.chars().skip(start).collect() @@ -429,16 +429,16 @@ fn string_replace_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> Pr return PreHookResult::Bypass(Some(EmValue::Null)); }; - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(Some(EmValue::ObjectRef(*href))); - } + }; let s = match thread.heap().get_string(*href) { Ok(s) => s.to_string(), Err(e) => return PreHookResult::Error(format!("heap allocation failed: {e}")), }; - let result = match (&ctx.args[0], &ctx.args[1]) { + let result = match (arg0, arg1) { (EmValue::Char(old), EmValue::Char(new)) => s.replace(*old, &new.to_string()), (EmValue::ObjectRef(old_ref), EmValue::ObjectRef(new_ref)) => { let old_str = thread @@ -734,7 +734,8 @@ fn string_pad_left_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> P let result = if s.len() >= total_width { s } else { - let padding: String = std::iter::repeat_n(pad_char, total_width - s.len()).collect(); + let padding: String = + std::iter::repeat_n(pad_char, total_width.saturating_sub(s.len())).collect(); format!("{padding}{s}") }; @@ -790,7 +791,8 @@ fn string_pad_right_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> let result = if s.len() >= total_width { s } else { - let padding: String = std::iter::repeat_n(pad_char, total_width - s.len()).collect(); + let padding: String = + std::iter::repeat_n(pad_char, total_width.saturating_sub(s.len())).collect(); format!("{s}{padding}") }; @@ -847,11 +849,11 @@ fn string_is_null_or_empty_pre( /// /// A single string with elements separated by the separator fn string_join_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.len() < 2 { + let (Some(arg0), Some(arg1)) = (ctx.args.first(), ctx.args.get(1)) else { return PreHookResult::Bypass(Some(EmValue::Null)); - } + }; - let separator = if let EmValue::ObjectRef(href) = &ctx.args[0] { + let separator = if let EmValue::ObjectRef(href) = arg0 { thread .heap() .get_string(*href) @@ -861,7 +863,7 @@ fn string_join_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHo String::new() }; - let array_ref = match &ctx.args[1] { + let array_ref = match arg1 { EmValue::ObjectRef(r) => *r, _ => return PreHookResult::Bypass(Some(EmValue::Null)), }; @@ -909,11 +911,11 @@ fn string_join_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHo /// /// A formatted string with placeholders replaced by argument values fn string_format_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> PreHookResult { - if ctx.args.is_empty() { + let Some(first_arg) = ctx.args.first() else { return PreHookResult::Bypass(Some(EmValue::Null)); - } + }; - let format_str = match &ctx.args[0] { + let format_str = match first_arg { EmValue::ObjectRef(href) => match thread.heap().get_string(*href) { Ok(s) => s.to_string(), Err(e) => return PreHookResult::Error(format!("heap allocation failed: {e}")), diff --git a/dotscope/src/emulation/runtime/bcl/text/encoding.rs b/dotscope/src/emulation/runtime/bcl/text/encoding.rs index 41a3b203..1512348e 100644 --- a/dotscope/src/emulation/runtime/bcl/text/encoding.rs +++ b/dotscope/src/emulation/runtime/bcl/text/encoding.rs @@ -273,7 +273,7 @@ fn encoding_get_bytes_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) - }) .unwrap_or(EncodingType::Utf8); - if let EmValue::ObjectRef(handle) = &ctx.args[0] { + if let Some(EmValue::ObjectRef(handle)) = ctx.args.first() { if let Ok(s) = thread.heap().get_string(*handle) { let bytes = encode_string(&s, encoding_type); match thread.heap_mut().alloc_byte_array(&bytes) { @@ -336,29 +336,10 @@ fn encoding_get_string_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) .unwrap_or(EncodingType::Utf8); // Handle both GetString(byte[]) and GetString(byte[], int, int) overloads - if let EmValue::ObjectRef(handle) = &ctx.args[0] { + if let Some(EmValue::ObjectRef(handle)) = ctx.args.first() { if let Some(all_bytes) = try_hook!(thread.heap().get_byte_array(*handle)) { - let bytes = if ctx.args.len() >= 3 { - // Safe: value validated as non-negative - #[allow(clippy::cast_sign_loss)] - let offset = match &ctx.args[1] { - EmValue::I32(o) => *o as usize, - _ => 0, - }; - // Safe: value validated as non-negative - #[allow(clippy::cast_sign_loss)] - let count = match &ctx.args[2] { - EmValue::I32(c) => *c as usize, - _ => all_bytes.len(), - }; - if offset + count <= all_bytes.len() { - all_bytes[offset..offset + count].to_vec() - } else { - all_bytes - } - } else { - all_bytes - }; + let bytes = slice_bytes_with_offset_count(&all_bytes, ctx.args.get(1), ctx.args.get(2)) + .unwrap_or(all_bytes); let s = decode_bytes(&bytes, encoding_type); @@ -382,6 +363,33 @@ fn encoding_get_string_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) PreHookResult::Bypass(Some(EmValue::Null)) } +/// Applies a `(byte[], int offset, int count)` slice spec to `all_bytes`. +/// +/// Returns `Some(slice)` when both `offset_arg` and `count_arg` are present +/// and `offset.checked_add(count) <= all_bytes.len()`. Returns `None` when +/// the offset/count overload was not used (either argument missing) or when +/// the requested range is invalid — callers fall back to the full buffer in +/// both cases, matching the legacy behaviour. +fn slice_bytes_with_offset_count( + all_bytes: &[u8], + offset_arg: Option<&EmValue>, + count_arg: Option<&EmValue>, +) -> Option> { + let (offset_arg, count_arg) = (offset_arg?, count_arg?); + // Negative or non-`I32` offset/count is treated as "no slice", matching + // the legacy behaviour of falling back to the full buffer. + let offset = match offset_arg { + EmValue::I32(o) => usize::try_from(*o).ok()?, + _ => 0, + }; + let count = match count_arg { + EmValue::I32(c) => usize::try_from(*c).ok()?, + _ => all_bytes.len(), + }; + let end = offset.checked_add(count)?; + all_bytes.get(offset..end).map(<[u8]>::to_vec) +} + /// Hook for `System.Text.Encoding.GetByteCount` method. /// /// Inspects the `this` encoding instance to determine the actual encoding @@ -423,7 +431,7 @@ fn encoding_get_byte_count_pre( }) .unwrap_or(EncodingType::Utf8); - if let EmValue::ObjectRef(handle) = &ctx.args[0] { + if let Some(EmValue::ObjectRef(handle)) = ctx.args.first() { if let Ok(s) = thread.heap().get_string(*handle) { let byte_len = encode_string(&s, encoding_type).len(); let len = i32::try_from(byte_len).unwrap_or(i32::MAX); @@ -472,7 +480,7 @@ fn encoding_get_char_count_pre( }) .unwrap_or(EncodingType::Utf8); - if let EmValue::ObjectRef(handle) = &ctx.args[0] { + if let Some(EmValue::ObjectRef(handle)) = ctx.args.first() { if let Some(bytes) = try_hook!(thread.heap().get_byte_array(*handle)) { let s = decode_bytes(&bytes, encoding_type); let count = i32::try_from(s.chars().count()).unwrap_or(i32::MAX); @@ -649,7 +657,7 @@ fn utf8_get_bytes_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> Pr return PreHookResult::Bypass(Some(EmValue::Null)); } - if let EmValue::ObjectRef(handle) = &ctx.args[0] { + if let Some(EmValue::ObjectRef(handle)) = ctx.args.first() { if let Ok(s) = thread.heap().get_string(*handle) { let bytes: Vec = s.as_bytes().to_vec(); match thread.heap_mut().alloc_byte_array(&bytes) { @@ -687,29 +695,10 @@ fn utf8_get_string_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> P } // Handle both GetString(byte[]) and GetString(byte[], int, int) overloads - if let EmValue::ObjectRef(handle) = &ctx.args[0] { + if let Some(EmValue::ObjectRef(handle)) = ctx.args.first() { if let Some(all_bytes) = try_hook!(thread.heap().get_byte_array(*handle)) { - let bytes = if ctx.args.len() >= 3 { - // Safe: value validated as non-negative - #[allow(clippy::cast_sign_loss)] - let offset = match &ctx.args[1] { - EmValue::I32(o) => *o as usize, - _ => 0, - }; - // Safe: value validated as non-negative - #[allow(clippy::cast_sign_loss)] - let count = match &ctx.args[2] { - EmValue::I32(c) => *c as usize, - _ => all_bytes.len(), - }; - if offset + count <= all_bytes.len() { - all_bytes[offset..offset + count].to_vec() - } else { - all_bytes - } - } else { - all_bytes - }; + let bytes = slice_bytes_with_offset_count(&all_bytes, ctx.args.get(1), ctx.args.get(2)) + .unwrap_or(all_bytes); let s = String::from_utf8_lossy(&bytes).into_owned(); @@ -754,7 +743,7 @@ fn ascii_get_bytes_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> P return PreHookResult::Bypass(Some(EmValue::Null)); } - if let EmValue::ObjectRef(handle) = &ctx.args[0] { + if let Some(EmValue::ObjectRef(handle)) = ctx.args.first() { if let Ok(s) = thread.heap().get_string(*handle) { // ASCII encoding - replace non-ASCII with '?' let bytes: Vec = s @@ -797,29 +786,10 @@ fn ascii_get_string_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> } // Handle both GetString(byte[]) and GetString(byte[], int, int) overloads - if let EmValue::ObjectRef(handle) = &ctx.args[0] { + if let Some(EmValue::ObjectRef(handle)) = ctx.args.first() { if let Some(all_bytes) = try_hook!(thread.heap().get_byte_array(*handle)) { - let bytes = if ctx.args.len() >= 3 { - // Safe: value validated as non-negative - #[allow(clippy::cast_sign_loss)] - let offset = match &ctx.args[1] { - EmValue::I32(o) => *o as usize, - _ => 0, - }; - // Safe: value validated as non-negative - #[allow(clippy::cast_sign_loss)] - let count = match &ctx.args[2] { - EmValue::I32(c) => *c as usize, - _ => all_bytes.len(), - }; - if offset + count <= all_bytes.len() { - all_bytes[offset..offset + count].to_vec() - } else { - all_bytes - } - } else { - all_bytes - }; + let bytes = slice_bytes_with_offset_count(&all_bytes, ctx.args.get(1), ctx.args.get(2)) + .unwrap_or(all_bytes); // ASCII decoding - only keep valid ASCII let s: String = bytes @@ -867,7 +837,7 @@ fn unicode_get_bytes_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> return PreHookResult::Bypass(Some(EmValue::Null)); } - if let EmValue::ObjectRef(handle) = &ctx.args[0] { + if let Some(EmValue::ObjectRef(handle)) = ctx.args.first() { if let Ok(s) = thread.heap().get_string(*handle) { // UTF-16 LE encoding let bytes: Vec = s.encode_utf16().flat_map(u16::to_le_bytes).collect(); @@ -906,38 +876,20 @@ fn unicode_get_string_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) - } // Handle both GetString(byte[]) and GetString(byte[], int, int) overloads - if let EmValue::ObjectRef(handle) = &ctx.args[0] { + if let Some(EmValue::ObjectRef(handle)) = ctx.args.first() { if let Some(all_bytes) = try_hook!(thread.heap().get_byte_array(*handle)) { - let bytes = if ctx.args.len() >= 3 { - // Safe: value validated as non-negative - #[allow(clippy::cast_sign_loss)] - let offset = match &ctx.args[1] { - EmValue::I32(o) => *o as usize, - _ => 0, - }; - // Safe: value validated as non-negative - #[allow(clippy::cast_sign_loss)] - let count = match &ctx.args[2] { - EmValue::I32(c) => *c as usize, - _ => all_bytes.len(), - }; - if offset + count <= all_bytes.len() { - all_bytes[offset..offset + count].to_vec() - } else { - all_bytes - } - } else { - all_bytes - }; + let bytes = slice_bytes_with_offset_count(&all_bytes, ctx.args.get(1), ctx.args.get(2)) + .unwrap_or(all_bytes); // UTF-16 LE decoding - if bytes.len() % 2 != 0 { + if !bytes.len().is_multiple_of(2) { return PreHookResult::Bypass(Some(EmValue::Null)); } let u16s: Vec = bytes .chunks_exact(2) - .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]])) + .filter_map(|chunk| <[u8; 2]>::try_from(chunk).ok()) + .map(u16::from_le_bytes) .collect(); let s = String::from_utf16_lossy(&u16s); @@ -1011,7 +963,8 @@ fn decode_bytes(bytes: &[u8], encoding_type: EncodingType) -> String { } let u16s: Vec = bytes .chunks_exact(2) - .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]])) + .filter_map(|chunk| <[u8; 2]>::try_from(chunk).ok()) + .map(u16::from_le_bytes) .collect(); String::from_utf16_lossy(&u16s) } @@ -1021,7 +974,8 @@ fn decode_bytes(bytes: &[u8], encoding_type: EncodingType) -> String { } let u16s: Vec = bytes .chunks_exact(2) - .map(|chunk| u16::from_be_bytes([chunk[0], chunk[1]])) + .filter_map(|chunk| <[u8; 2]>::try_from(chunk).ok()) + .map(u16::from_be_bytes) .collect(); String::from_utf16_lossy(&u16s) } @@ -1031,10 +985,9 @@ fn decode_bytes(bytes: &[u8], encoding_type: EncodingType) -> String { } bytes .chunks_exact(4) - .filter_map(|chunk| { - let code = u32::from_le_bytes([chunk[0], chunk[1], chunk[2], chunk[3]]); - char::from_u32(code) - }) + .filter_map(|chunk| <[u8; 4]>::try_from(chunk).ok()) + .map(u32::from_le_bytes) + .filter_map(char::from_u32) .collect() } } diff --git a/dotscope/src/emulation/runtime/bcl/text/stringbuilder.rs b/dotscope/src/emulation/runtime/bcl/text/stringbuilder.rs index a2a0e658..3a69be27 100644 --- a/dotscope/src/emulation/runtime/bcl/text/stringbuilder.rs +++ b/dotscope/src/emulation/runtime/bcl/text/stringbuilder.rs @@ -295,7 +295,7 @@ fn stringbuilder_remove_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) .map_or(buffer.len(), |(p, _)| p); let byte_end = buffer .char_indices() - .nth(start_idx + len) + .nth(start_idx.saturating_add(len)) .map_or(buffer.len(), |(p, _)| p); buffer.drain(byte_start..byte_end); } @@ -399,7 +399,7 @@ fn stringbuilder_set_length_pre( buffer.chars().take(target).collect() } else { let mut s = buffer; - for _ in 0..(target - current) { + for _ in 0..target.saturating_sub(current) { s.push('\0'); } s @@ -441,8 +441,8 @@ fn stringbuilder_set_chars_pre( { let idx = (*index).max(0) as usize; let mut chars: Vec = buffer.chars().collect(); - if idx < chars.len() { - chars[idx] = *ch; + if let Some(slot) = chars.get_mut(idx) { + *slot = *ch; let new_buffer: String = chars.into_iter().collect(); try_hook!(write_sb(thread, *sb_ref, new_buffer, capacity)); } diff --git a/dotscope/src/emulation/runtime/native.rs b/dotscope/src/emulation/runtime/native.rs index 684972b3..156dc7b7 100644 --- a/dotscope/src/emulation/runtime/native.rs +++ b/dotscope/src/emulation/runtime/native.rs @@ -255,34 +255,41 @@ fn register_virtual_protect(manager: &HookManager) -> Result<()> { } // Validate lpAddress is not null - let lp_address = match &args[0] { - EmValue::UnmanagedPtr(addr) if *addr != 0 => *addr, - EmValue::NativeInt(addr) if *addr > 0 => (*addr).cast_unsigned(), - EmValue::NativeUInt(addr) if *addr > 0 => *addr, + let lp_address = match args.first() { + Some(EmValue::UnmanagedPtr(addr)) if *addr != 0 => *addr, + Some(EmValue::NativeInt(addr)) if *addr > 0 => (*addr).cast_unsigned(), + Some(EmValue::NativeUInt(addr)) if *addr > 0 => *addr, _ => return PreHookResult::Bypass(Some(EmValue::I32(0))), }; // Validate dwSize is non-zero #[allow(clippy::cast_possible_truncation)] - let dw_size = match &args[1] { - EmValue::I32(size) if *size > 0 => (*size).cast_unsigned() as usize, - EmValue::NativeInt(size) if *size > 0 => (*size).cast_unsigned() as usize, - EmValue::NativeUInt(size) if *size > 0 => *size as usize, + let dw_size = match args.get(1) { + Some(EmValue::I32(size)) if *size > 0 => (*size).cast_unsigned() as usize, + Some(EmValue::NativeInt(size)) if *size > 0 => (*size).cast_unsigned() as usize, + Some(EmValue::NativeUInt(size)) if *size > 0 => *size as usize, _ => return PreHookResult::Bypass(Some(EmValue::I32(0))), }; // Get the new protection value #[allow(clippy::cast_possible_truncation)] - let fl_new_protect = match &args[2] { - EmValue::I32(p) => (*p).cast_unsigned(), - EmValue::NativeInt(p) => (*p).cast_unsigned() as u32, - EmValue::NativeUInt(p) => *p as u32, + let fl_new_protect = match args.get(2) { + Some(EmValue::I32(p)) => (*p).cast_unsigned(), + Some(EmValue::NativeInt(p)) => (*p).cast_unsigned() as u32, + Some(EmValue::NativeUInt(p)) => *p as u32, _ => return PreHookResult::Bypass(Some(EmValue::I32(0))), }; // Check if address range is valid let space = thread.address_space(); - if !space.is_valid(lp_address) || !space.is_valid(lp_address + dw_size as u64 - 1) { + let last_addr = match (dw_size as u64) + .checked_sub(1) + .and_then(|s| lp_address.checked_add(s)) + { + Some(v) => v, + None => return PreHookResult::Bypass(Some(EmValue::I32(0))), + }; + if !space.is_valid(lp_address) || !space.is_valid(last_addr) { return PreHookResult::Bypass(Some(EmValue::I32(0))); } @@ -297,15 +304,18 @@ fn register_virtual_protect(manager: &HookManager) -> Result<()> { // Write old protection value to the out parameter let old_protect_value = EmValue::I32(old_protect_windows.cast_signed()); - match &args[3] { - EmValue::ManagedPtr(ptr) + match args.get(3) { + None => { + // missing arg - treat as null pointer (allowed) + } + Some(EmValue::ManagedPtr(ptr)) if thread .store_through_pointer(ptr, old_protect_value) .is_err() => { return PreHookResult::Bypass(Some(EmValue::I32(0))); } - EmValue::UnmanagedPtr(addr) if *addr != 0 => { + Some(EmValue::UnmanagedPtr(addr)) if *addr != 0 => { let old_protect_bytes = old_protect_windows.to_le_bytes(); if thread .address_space() @@ -315,7 +325,7 @@ fn register_virtual_protect(manager: &HookManager) -> Result<()> { return PreHookResult::Bypass(Some(EmValue::I32(0))); } } - EmValue::NativeInt(addr) if *addr > 0 => { + Some(EmValue::NativeInt(addr)) if *addr > 0 => { let old_protect_bytes = old_protect_windows.to_le_bytes(); if thread .address_space() @@ -325,7 +335,7 @@ fn register_virtual_protect(manager: &HookManager) -> Result<()> { return PreHookResult::Bypass(Some(EmValue::I32(0))); } } - EmValue::NativeUInt(addr) if *addr > 0 => { + Some(EmValue::NativeUInt(addr)) if *addr > 0 => { let old_protect_bytes = old_protect_windows.to_le_bytes(); if thread .address_space() @@ -404,16 +414,11 @@ fn register_virtual_free(manager: &HookManager) -> Result<()> { // Returns BOOL: 1 = success, 0 = failure let args = ctx.args; - // Validate we have enough arguments - if args.is_empty() { - return PreHookResult::Bypass(Some(EmValue::I32(0))); - } - // Validate lpAddress is not null - let lp_address = match &args[0] { - EmValue::UnmanagedPtr(addr) if *addr != 0 => *addr, - EmValue::NativeInt(addr) if *addr > 0 => (*addr).cast_unsigned(), - EmValue::NativeUInt(addr) if *addr > 0 => *addr, + let lp_address = match args.first() { + Some(EmValue::UnmanagedPtr(addr)) if *addr != 0 => *addr, + Some(EmValue::NativeInt(addr)) if *addr > 0 => (*addr).cast_unsigned(), + Some(EmValue::NativeUInt(addr)) if *addr > 0 => *addr, _ => return PreHookResult::Bypass(Some(EmValue::I32(0))), }; @@ -612,25 +617,20 @@ fn register_check_remote_debugger_present(manager: &HookManager) -> Result<()> { // Returns BOOL: 1 = success, 0 = failure let args = ctx.args; - // Validate we have enough arguments - if args.len() < 2 { - return PreHookResult::Bypass(Some(EmValue::I32(0))); - } - // Validate hProcess is a valid handle (-1 for current process, or non-null) - match &args[0] { - EmValue::NativeInt(-1) => {} // Current process pseudo-handle is valid - EmValue::NativeInt(h) if *h > 0 => {} - EmValue::NativeUInt(h) if *h > 0 => {} - EmValue::UnmanagedPtr(h) if *h > 0 => {} + match args.first() { + Some(EmValue::NativeInt(-1)) => {} // Current process pseudo-handle is valid + Some(EmValue::NativeInt(h)) if *h > 0 => {} + Some(EmValue::NativeUInt(h)) if *h > 0 => {} + Some(EmValue::UnmanagedPtr(h)) if *h > 0 => {} _ => return PreHookResult::Bypass(Some(EmValue::I32(0))), } // Get output pointer address - must be valid - let output_addr = match &args[1] { - EmValue::UnmanagedPtr(addr) if *addr != 0 => *addr, - EmValue::NativeInt(addr) if *addr > 0 => (*addr).cast_unsigned(), - EmValue::NativeUInt(addr) if *addr > 0 => *addr, + let output_addr = match args.get(1) { + Some(EmValue::UnmanagedPtr(addr)) if *addr != 0 => *addr, + Some(EmValue::NativeInt(addr)) if *addr > 0 => (*addr).cast_unsigned(), + Some(EmValue::NativeUInt(addr)) if *addr > 0 => *addr, _ => return PreHookResult::Bypass(Some(EmValue::I32(0))), }; diff --git a/dotscope/src/emulation/runtime/state.rs b/dotscope/src/emulation/runtime/state.rs index 512b71b2..4c091492 100644 --- a/dotscope/src/emulation/runtime/state.rs +++ b/dotscope/src/emulation/runtime/state.rs @@ -213,13 +213,16 @@ impl RuntimeState { // Register default BCL hooks based on configuration if config.stubs.bcl_stubs { - bcl::register(&hooks).expect("BCL hook registration should not fail at startup"); + if let Err(e) = bcl::register(&hooks) { + log::error!("BCL hook registration failed at startup: {e}"); + } } // Register default native P/Invoke hooks based on configuration if config.stubs.pinvoke_stubs { - native::register(&hooks, &native_functions) - .expect("Native hook registration should not fail at startup"); + if let Err(e) = native::register(&hooks, &native_functions) { + log::error!("Native hook registration failed at startup: {e}"); + } } Self { diff --git a/dotscope/src/emulation/thread/scheduler.rs b/dotscope/src/emulation/thread/scheduler.rs index 5486c202..2eebce7f 100644 --- a/dotscope/src/emulation/thread/scheduler.rs +++ b/dotscope/src/emulation/thread/scheduler.rs @@ -306,7 +306,7 @@ impl ThreadScheduler { /// A unique thread ID that has not been used before in this scheduler. pub fn allocate_thread_id(&mut self) -> ThreadId { let id = ThreadId::new(self.next_thread_id); - self.next_thread_id += 1; + self.next_thread_id = self.next_thread_id.saturating_add(1); id } @@ -324,7 +324,9 @@ impl ThreadScheduler { /// in the ready queue. #[must_use] pub fn ready_count(&self) -> usize { - self.ready_queue.len() + usize::from(self.current.is_some()) + self.ready_queue + .len() + .saturating_add(usize::from(self.current.is_some())) } /// Returns the ID of the currently running thread, if any. @@ -392,7 +394,7 @@ impl ThreadScheduler { /// Enqueue a thread as ready to run. fn enqueue_ready(&mut self, id: ThreadId, priority: ThreadPriority) { let sequence = self.next_sequence; - self.next_sequence += 1; + self.next_sequence = self.next_sequence.saturating_add(1); self.ready_queue.push(ScheduleEntry { priority, thread_id: id, @@ -480,8 +482,8 @@ impl ThreadScheduler { /// } /// ``` pub fn record_instruction(&mut self) -> bool { - self.total_instructions += 1; - self.quantum_used += 1; + self.total_instructions = self.total_instructions.saturating_add(1); + self.quantum_used = self.quantum_used.saturating_add(1); if let Some(id) = self.current { if let Some(thread) = self.threads.get_mut(&id) { diff --git a/dotscope/src/emulation/thread/state.rs b/dotscope/src/emulation/thread/state.rs index d71a7209..6a9656e1 100644 --- a/dotscope/src/emulation/thread/state.rs +++ b/dotscope/src/emulation/thread/state.rs @@ -360,7 +360,7 @@ impl ThreadCallFrame { /// /// Typically called after executing an instruction to move to the next one. pub fn advance_ip(&mut self, delta: u32) { - self.instruction_offset += delta; + self.instruction_offset = self.instruction_offset.saturating_add(delta); } /// Returns whether the caller expects a return value. @@ -1027,7 +1027,8 @@ impl EmulationThread { // Arguments are pushed left-to-right, so arg0 is at depth (count-1) // and argN is at depth 0 (top of stack) for i in 0..count { - args.push(self.eval_stack.peek_at(count - 1 - i)?.clone()); + let depth = count.saturating_sub(1).saturating_sub(i); + args.push(self.eval_stack.peek_at(depth)?.clone()); } Ok(args) } @@ -1113,7 +1114,9 @@ impl EmulationThread { } }; - let total_size = length * elem_size; + let total_size = length + .checked_mul(elem_size) + .ok_or(EmulationError::ArithmeticOverflow)?; let base_addr = self .context .address_space @@ -1124,7 +1127,13 @@ impl EmulationThread { .address_space .register_pinned_array(base_addr, array, elem_size, length)?; - let element_addr = base_addr + (index * elem_size) as u64 + u64::from(offset); + let index_offset = index + .checked_mul(elem_size) + .ok_or(EmulationError::ArithmeticOverflow)?; + let element_addr = base_addr + .checked_add(index_offset as u64) + .and_then(|a| a.checked_add(u64::from(offset))) + .ok_or(EmulationError::ArithmeticOverflow)?; Ok(element_addr) } @@ -1149,7 +1158,7 @@ impl EmulationThread { /// /// Called by the interpreter after each instruction. pub fn increment_instructions(&mut self) { - self.instructions_executed += 1; + self.instructions_executed = self.instructions_executed.saturating_add(1); } /// Returns the method token of the currently executing method. diff --git a/dotscope/src/emulation/thread/sync.rs b/dotscope/src/emulation/thread/sync.rs index 3e168246..7423e22d 100644 --- a/dotscope/src/emulation/thread/sync.rs +++ b/dotscope/src/emulation/thread/sync.rs @@ -196,7 +196,7 @@ impl MonitorState { true } Some(owner) if owner == thread_id => { - self.recursion_count += 1; + self.recursion_count = self.recursion_count.saturating_add(1); true } _ => false, @@ -224,7 +224,7 @@ impl MonitorState { pub fn exit(&mut self, thread_id: ThreadId) -> Result { match self.owner { Some(owner) if owner == thread_id => { - self.recursion_count -= 1; + self.recursion_count = self.recursion_count.saturating_sub(1); if self.recursion_count == 0 { self.owner = None; Ok(true) @@ -290,7 +290,7 @@ impl MutexState { true } Some(owner) if owner == thread_id => { - self.recursion_count += 1; + self.recursion_count = self.recursion_count.saturating_add(1); true } _ => false, @@ -318,7 +318,7 @@ impl MutexState { pub fn release(&mut self, thread_id: ThreadId) -> Result { match self.owner { Some(owner) if owner == thread_id => { - self.recursion_count -= 1; + self.recursion_count = self.recursion_count.saturating_sub(1); if self.recursion_count == 0 { self.owner = None; Ok(true) @@ -415,7 +415,7 @@ impl SemaphoreState { /// was zero. pub fn try_acquire(&mut self) -> bool { if self.count > 0 { - self.count -= 1; + self.count = self.count.saturating_sub(1); true } else { false diff --git a/dotscope/src/emulation/tracer/calltree.rs b/dotscope/src/emulation/tracer/calltree.rs index 693ac4b4..72f19bea 100644 --- a/dotscope/src/emulation/tracer/calltree.rs +++ b/dotscope/src/emulation/tracer/calltree.rs @@ -165,7 +165,7 @@ impl CallTreeBuilder { } TraceEvent::Instruction { .. } => { if let Some(top) = state.stack.last_mut() { - top.instruction_count += 1; + top.instruction_count = top.instruction_count.saturating_add(1); } } TraceEvent::ExceptionThrow { @@ -266,12 +266,12 @@ impl CallTreeNode { /// Total instruction count including all descendants. #[must_use] pub fn total_instructions(&self) -> u64 { - self.instruction_count - + self - .children + self.instruction_count.saturating_add( + self.children .iter() .map(|c| c.total_instructions()) - .sum::() + .fold(0u64, u64::saturating_add), + ) } /// Converts the call tree to a JSON string. @@ -321,7 +321,7 @@ impl CallTreeNode { self.exceptions.len() )?; for child in &self.children { - child.fmt_indented(f, indent + 1)?; + child.fmt_indented(f, indent.saturating_add(1))?; } Ok(()) } diff --git a/dotscope/src/emulation/value/emvalue.rs b/dotscope/src/emulation/value/emvalue.rs index 89418e3f..e8032453 100644 --- a/dotscope/src/emulation/value/emvalue.rs +++ b/dotscope/src/emulation/value/emvalue.rs @@ -1575,12 +1575,16 @@ impl EmValue { EmValue::I32(i32::from(bytes.first().copied().unwrap_or(0))) } CilFlavor::I2 | CilFlavor::U2 | CilFlavor::Char => { - let arr: [u8; 2] = bytes[..2.min(bytes.len())].try_into().unwrap_or([0, 0]); + let arr: [u8; 2] = bytes + .get(..2.min(bytes.len())) + .and_then(|s| s.try_into().ok()) + .unwrap_or([0, 0]); EmValue::I32(i32::from(i16::from_le_bytes(arr))) } CilFlavor::I4 | CilFlavor::U4 | CilFlavor::R4 => { - let arr: [u8; 4] = bytes[..4.min(bytes.len())] - .try_into() + let arr: [u8; 4] = bytes + .get(..4.min(bytes.len())) + .and_then(|s| s.try_into().ok()) .unwrap_or([0, 0, 0, 0]); if matches!(flavor, CilFlavor::R4) { EmValue::F32(f32::from_le_bytes(arr)) @@ -1589,8 +1593,9 @@ impl EmValue { } } CilFlavor::I8 | CilFlavor::U8 | CilFlavor::R8 => { - let arr: [u8; 8] = bytes[..8.min(bytes.len())] - .try_into() + let arr: [u8; 8] = bytes + .get(..8.min(bytes.len())) + .and_then(|s| s.try_into().ok()) .unwrap_or([0, 0, 0, 0, 0, 0, 0, 0]); if matches!(flavor, CilFlavor::R8) { EmValue::F64(f64::from_le_bytes(arr)) @@ -1600,8 +1605,9 @@ impl EmValue { } CilFlavor::I | CilFlavor::U => match ptr_size { PointerSize::Bit32 => { - let arr: [u8; 4] = bytes[..4.min(bytes.len())] - .try_into() + let arr: [u8; 4] = bytes + .get(..4.min(bytes.len())) + .and_then(|s| s.try_into().ok()) .unwrap_or([0, 0, 0, 0]); if matches!(flavor, CilFlavor::I) { EmValue::NativeInt(i64::from(i32::from_le_bytes(arr))) @@ -1610,8 +1616,9 @@ impl EmValue { } } PointerSize::Bit64 => { - let arr: [u8; 8] = bytes[..8.min(bytes.len())] - .try_into() + let arr: [u8; 8] = bytes + .get(..8.min(bytes.len())) + .and_then(|s| s.try_into().ok()) .unwrap_or([0, 0, 0, 0, 0, 0, 0, 0]); if matches!(flavor, CilFlavor::I) { EmValue::NativeInt(i64::from_le_bytes(arr)) diff --git a/dotscope/src/emulation/value/ops/binary.rs b/dotscope/src/emulation/value/ops/binary.rs index a0eb34ec..e8ba3800 100644 --- a/dotscope/src/emulation/value/ops/binary.rs +++ b/dotscope/src/emulation/value/ops/binary.rs @@ -646,35 +646,50 @@ impl EmValue { match (self, other) { (EmValue::I32(a), EmValue::I32(b)) => { if unsigned { - Ok(EmValue::I32(((*a as u32) / (*b as u32)) as i32)) + let q = (*a as u32) + .checked_div(*b as u32) + .ok_or(EmulationError::DivisionByZero)?; + Ok(EmValue::I32(q as i32)) } else { // Handle MIN / -1 overflow case if *a == i32::MIN && *b == -1 { Ok(EmValue::I32(i32::MIN)) // Wrapping behavior } else { - Ok(EmValue::I32(a / b)) + let q = a.checked_div(*b).ok_or(EmulationError::DivisionByZero)?; + Ok(EmValue::I32(q)) } } } (EmValue::I64(a), EmValue::I64(b)) => { if unsigned { - Ok(EmValue::I64(((*a as u64) / (*b as u64)) as i64)) + let q = (*a as u64) + .checked_div(*b as u64) + .ok_or(EmulationError::DivisionByZero)?; + Ok(EmValue::I64(q as i64)) } else if *a == i64::MIN && *b == -1 { Ok(EmValue::I64(i64::MIN)) } else { - Ok(EmValue::I64(a / b)) + let q = a.checked_div(*b).ok_or(EmulationError::DivisionByZero)?; + Ok(EmValue::I64(q)) } } (EmValue::NativeInt(a), EmValue::NativeInt(b)) => { if unsigned { - Ok(EmValue::NativeInt(((*a as u64) / (*b as u64)) as i64)) + let q = (*a as u64) + .checked_div(*b as u64) + .ok_or(EmulationError::DivisionByZero)?; + Ok(EmValue::NativeInt(q as i64)) } else if *a == i64::MIN && *b == -1 { Ok(EmValue::NativeInt(i64::MIN)) } else { - Ok(EmValue::NativeInt(a / b)) + let q = a.checked_div(*b).ok_or(EmulationError::DivisionByZero)?; + Ok(EmValue::NativeInt(q)) } } - (EmValue::NativeUInt(a), EmValue::NativeUInt(b)) => Ok(EmValue::NativeUInt(a / b)), + (EmValue::NativeUInt(a), EmValue::NativeUInt(b)) => { + let q = a.checked_div(*b).ok_or(EmulationError::DivisionByZero)?; + Ok(EmValue::NativeUInt(q)) + } (EmValue::F32(a), EmValue::F32(b)) => Ok(EmValue::F32(a / b)), (EmValue::F64(a), EmValue::F64(b)) => Ok(EmValue::F64(a / b)), (a, b) => Err(EmulationError::InvalidOperationTypes { @@ -702,35 +717,50 @@ impl EmValue { match (self, other) { (EmValue::I32(a), EmValue::I32(b)) => { if unsigned { - Ok(EmValue::I32(((*a as u32) % (*b as u32)) as i32)) + let r = (*a as u32) + .checked_rem(*b as u32) + .ok_or(EmulationError::DivisionByZero)?; + Ok(EmValue::I32(r as i32)) } else { // Handle MIN % -1 case (result is 0) if *a == i32::MIN && *b == -1 { Ok(EmValue::I32(0)) } else { - Ok(EmValue::I32(a % b)) + let r = a.checked_rem(*b).ok_or(EmulationError::DivisionByZero)?; + Ok(EmValue::I32(r)) } } } (EmValue::I64(a), EmValue::I64(b)) => { if unsigned { - Ok(EmValue::I64(((*a as u64) % (*b as u64)) as i64)) + let r = (*a as u64) + .checked_rem(*b as u64) + .ok_or(EmulationError::DivisionByZero)?; + Ok(EmValue::I64(r as i64)) } else if *a == i64::MIN && *b == -1 { Ok(EmValue::I64(0)) } else { - Ok(EmValue::I64(a % b)) + let r = a.checked_rem(*b).ok_or(EmulationError::DivisionByZero)?; + Ok(EmValue::I64(r)) } } (EmValue::NativeInt(a), EmValue::NativeInt(b)) => { if unsigned { - Ok(EmValue::NativeInt(((*a as u64) % (*b as u64)) as i64)) + let r = (*a as u64) + .checked_rem(*b as u64) + .ok_or(EmulationError::DivisionByZero)?; + Ok(EmValue::NativeInt(r as i64)) } else if *a == i64::MIN && *b == -1 { Ok(EmValue::NativeInt(0)) } else { - Ok(EmValue::NativeInt(a % b)) + let r = a.checked_rem(*b).ok_or(EmulationError::DivisionByZero)?; + Ok(EmValue::NativeInt(r)) } } - (EmValue::NativeUInt(a), EmValue::NativeUInt(b)) => Ok(EmValue::NativeUInt(a % b)), + (EmValue::NativeUInt(a), EmValue::NativeUInt(b)) => { + let r = a.checked_rem(*b).ok_or(EmulationError::DivisionByZero)?; + Ok(EmValue::NativeUInt(r)) + } (EmValue::F32(a), EmValue::F32(b)) => Ok(EmValue::F32(a % b)), (EmValue::F64(a), EmValue::F64(b)) => Ok(EmValue::F64(a % b)), (a, b) => Err(EmulationError::InvalidOperationTypes { diff --git a/dotscope/src/error.rs b/dotscope/src/error.rs index 1a6967a2..31bba1a0 100644 --- a/dotscope/src/error.rs +++ b/dotscope/src/error.rs @@ -306,7 +306,7 @@ pub enum Error { /// This error occurs when looking up a type by token that doesn't /// exist in the loaded metadata or type system registry. /// - /// The associated [`crate::metadata::token::Token`] identifies which type was not found. + /// The associated [`Token`] identifies which type was not found. #[error("Failed to find type in TypeSystem - {0}")] TypeNotFound(Token), diff --git a/dotscope/src/file/mod.rs b/dotscope/src/file/mod.rs index 6bf8772c..e98df9f3 100644 --- a/dotscope/src/file/mod.rs +++ b/dotscope/src/file/mod.rs @@ -879,10 +879,7 @@ impl File { pub fn data_slice(&self, offset: usize, len: usize) -> Result<&[u8]> { let base = self.data.data(); let end = offset.checked_add(len).ok_or(out_of_bounds_error!())?; - if end > base.len() { - return Err(out_of_bounds_error!()); - } - Ok(&base[offset..end]) + base.get(offset..end).ok_or(out_of_bounds_error!()) } /// Converts a virtual address (VA) to a file offset. @@ -918,11 +915,7 @@ impl File { /// ``` pub fn va_to_offset(&self, va: usize) -> Result { let ib = self.imagebase(); - if ib > va as u64 { - return Err(out_of_bounds_error!()); - } - - let rva_u64 = va as u64 - ib; + let rva_u64 = (va as u64).checked_sub(ib).ok_or(out_of_bounds_error!())?; let rva = usize::try_from(rva_u64) .map_err(|_| malformed_error!("RVA too large to fit in usize: {}", rva_u64))?; self.rva_to_offset(rva) @@ -974,9 +967,12 @@ impl File { let rva_u32 = u32::try_from(rva) .map_err(|_| malformed_error!("RVA too large to fit in u32: {}", rva))?; if section.virtual_address <= rva_u32 && section_max > rva_u32 { - return Ok( - (rva - section.virtual_address as usize) + section.pointer_to_raw_data as usize - ); + let delta = rva + .checked_sub(section.virtual_address as usize) + .ok_or_else(|| malformed_error!("RVA underflow vs section base"))?; + return delta + .checked_add(section.pointer_to_raw_data as usize) + .ok_or_else(|| malformed_error!("RVA-to-offset overflow")); } } @@ -1033,8 +1029,12 @@ impl File { let offset_u32 = u32::try_from(offset) .map_err(|_| malformed_error!("Offset too large to fit in u32: {}", offset))?; if section.pointer_to_raw_data <= offset_u32 && section_max > offset_u32 { - return Ok((offset - section.pointer_to_raw_data as usize) - + section.virtual_address as usize); + let delta = offset + .checked_sub(section.pointer_to_raw_data as usize) + .ok_or_else(|| malformed_error!("Offset underflow vs section base"))?; + return delta + .checked_add(section.virtual_address as usize) + .ok_or_else(|| malformed_error!("Offset-to-RVA overflow")); } } @@ -1089,12 +1089,13 @@ impl File { return false; }; - if clr_data.len() < 12 { + let Some(rva_bytes) = clr_data + .get(8..12) + .and_then(|s| <[u8; 4]>::try_from(s).ok()) + else { return false; - } - - let meta_data_rva = - u32::from_le_bytes([clr_data[8], clr_data[9], clr_data[10], clr_data[11]]); + }; + let meta_data_rva = u32::from_le_bytes(rva_bytes); if meta_data_rva == 0 { return false; // No metadata @@ -1105,7 +1106,10 @@ impl File { if current_section_name == section_name { let section_start = section.virtual_address; - let section_end = section.virtual_address + section.virtual_size; + let Some(section_end) = section.virtual_address.checked_add(section.virtual_size) + else { + return false; + }; return meta_data_rva >= section_start && meta_data_rva < section_end; } } @@ -1259,7 +1263,11 @@ impl File { } // PE offset is at offset 0x3C in DOS header - let pe_offset = u32::from_le_bytes([data[60], data[61], data[62], data[63]]); + let bytes = data + .get(60..64) + .and_then(|s| <[u8; 4]>::try_from(s).ok()) + .ok_or_else(|| LayoutFailed("DOS header truncated".to_string()))?; + let pe_offset = u32::from_le_bytes(bytes); Ok(u64::from(pe_offset)) } @@ -1280,24 +1288,40 @@ impl File { let pe_sig_offset = self.pe_signature_offset()?; let data = self.data(); - let coff_header_offset = pe_sig_offset + 4; // Skip PE signature + let coff_header_offset = pe_sig_offset + .checked_add(4) + .ok_or_else(|| LayoutFailed("COFF header offset overflow".to_string()))?; // Skip PE signature - #[allow(clippy::cast_possible_truncation)] - if data.len() < (coff_header_offset + 20) as usize { + let coff_end = coff_header_offset + .checked_add(20) + .ok_or_else(|| LayoutFailed("COFF header end overflow".to_string()))?; + let coff_end_usize = usize::try_from(coff_end) + .map_err(|_| LayoutFailed("COFF header end exceeds usize".to_string()))?; + if data.len() < coff_end_usize { return Err(LayoutFailed( "File too small to contain COFF header".to_string(), )); } // Optional header size is at offset 16 in COFF header - let opt_header_size_offset = coff_header_offset + 16; - #[allow(clippy::cast_possible_truncation)] - let opt_header_size = u16::from_le_bytes([ - data[opt_header_size_offset as usize], - data[opt_header_size_offset as usize + 1], - ]); - - Ok(4 + 20 + u64::from(opt_header_size)) // PE sig + COFF + Optional header + let opt_header_size_offset = coff_header_offset + .checked_add(16) + .ok_or_else(|| LayoutFailed("Optional header size offset overflow".to_string()))?; + let opt_offset_usize = usize::try_from(opt_header_size_offset) + .map_err(|_| LayoutFailed("Optional header offset exceeds usize".to_string()))?; + let opt_end_usize = opt_offset_usize + .checked_add(2) + .ok_or_else(|| LayoutFailed("Optional header offset+2 overflow".to_string()))?; + let opt_bytes = data + .get(opt_offset_usize..opt_end_usize) + .and_then(|s| <[u8; 2]>::try_from(s).ok()) + .ok_or_else(|| LayoutFailed("Optional header size truncated".to_string()))?; + let opt_header_size = u16::from_le_bytes(opt_bytes); + + // PE sig + COFF + Optional header + 24u64 + .checked_add(u64::from(opt_header_size)) + .ok_or_else(|| LayoutFailed("PE headers size overflow".to_string())) } /// Aligns an offset to this file's PE file alignment boundary. diff --git a/dotscope/src/file/parser.rs b/dotscope/src/file/parser.rs index 74598a96..e31949af 100644 --- a/dotscope/src/file/parser.rs +++ b/dotscope/src/file/parser.rs @@ -298,11 +298,14 @@ impl<'a> Parser<'a> { /// # Ok::<(), dotscope::Error>(()) /// ``` pub fn advance_by(&mut self, step: usize) -> Result<()> { - if self.position + step > self.data.len() { + let new_pos = self + .position + .checked_add(step) + .ok_or(out_of_bounds_error!())?; + if new_pos > self.data.len() { return Err(out_of_bounds_error!()); } - - self.position += step; + self.position = new_pos; Ok(()) } @@ -360,10 +363,10 @@ impl<'a> Parser<'a> { /// # Ok::<(), dotscope::Error>(()) /// ``` pub fn peek_byte(&self) -> Result { - if self.position >= self.data.len() { - return Err(out_of_bounds_error!()); - } - Ok(self.data[self.position]) + self.data + .get(self.position) + .copied() + .ok_or(out_of_bounds_error!()) } /// Peek at a value of type `T` in little-endian format without advancing the position. @@ -472,11 +475,22 @@ impl<'a> Parser<'a> { /// # Ok::<(), dotscope::Error>(()) /// ``` pub fn align(&mut self, alignment: usize) -> Result<()> { - let padding = (alignment - (self.position % alignment)) % alignment; - if self.position + padding > self.data.len() { + let rem = self + .position + .checked_rem(alignment) + .ok_or(out_of_bounds_error!())?; + let padding = alignment + .wrapping_sub(rem) + .checked_rem(alignment) + .ok_or(out_of_bounds_error!())?; + let new_pos = self + .position + .checked_add(padding) + .ok_or(out_of_bounds_error!())?; + if new_pos > self.data.len() { return Err(out_of_bounds_error!()); } - self.position += padding; + self.position = new_pos; Ok(()) } @@ -611,8 +625,11 @@ impl<'a> Parser<'a> { result } else { #[allow(clippy::cast_possible_wrap)] - let result = -((unsigned >> 1) as i32 + 1); - result + let half = (unsigned >> 1) as i32; + // -(half + 1): reproduces the ECMA-335 negative encoding while + // matching the existing release-mode wrapping behaviour for + // out-of-range inputs (half == i32::MAX). + half.wrapping_add(1).wrapping_neg() }; Ok(signed) @@ -665,8 +682,15 @@ impl<'a> Parser<'a> { }; let table_index = compressed_token >> 2; + let token = table.checked_add(table_index).ok_or_else(|| { + malformed_error!( + "Compressed token index overflows table base: {} + {}", + table, + table_index + ) + })?; - Ok(Token::new(table + table_index)) + Ok(Token::new(token)) } /// Read a 7-bit encoded integer (used in .NET for variable-length encoding). @@ -697,18 +721,16 @@ impl<'a> Parser<'a> { /// ``` pub fn read_7bit_encoded_int(&mut self) -> Result { let mut value = 0u32; - let mut shift = 0; + let mut shift: u32 = 0; loop { - if self.position >= self.data.len() { - return Err(out_of_bounds_error!()); - } - - let byte = self.data[self.position]; - self.position += 1; + let byte = *self.data.get(self.position).ok_or(out_of_bounds_error!())?; + self.position = self.position.checked_add(1).ok_or(out_of_bounds_error!())?; value |= u32::from(byte & 0x7F) << shift; - shift += 7; + shift = shift + .checked_add(7) + .ok_or_else(|| malformed_error!("7-bit encoded integer overflow"))?; if (byte & 0x80) == 0 { break; @@ -755,17 +777,20 @@ impl<'a> Parser<'a> { let start = self.position; let mut end = start; - while end < self.data.len() && self.data[end] != 0 { - end += 1; + while let Some(&b) = self.data.get(end) { + if b == 0 { + break; + } + end = end.checked_add(1).ok_or(out_of_bounds_error!())?; } // Handle two cases: // 1. Found null terminator (end < data.len()): normal null-terminated string // 2. Reached end of data (end == data.len()): string without null terminator (valid case) - let string_data = &self.data[start..end]; + let string_data = self.data.get(start..end).ok_or(out_of_bounds_error!())?; if end < self.data.len() { - self.position = end + 1; + self.position = end.checked_add(1).ok_or(out_of_bounds_error!())?; } else { self.position = end; } @@ -804,19 +829,20 @@ impl<'a> Parser<'a> { /// ``` pub fn read_prefixed_string_utf8(&mut self) -> Result { let length = self.read_7bit_encoded_int()? as usize; + let end = self + .position + .checked_add(length) + .ok_or(out_of_bounds_error!())?; - if self.position + length > self.data.len() { - return Err(out_of_bounds_error!()); - } - - let string_data = &self.data[self.position..self.position + length]; - self.position += length; + let start = self.position; + let string_data = self.data.get(start..end).ok_or(out_of_bounds_error!())?; + self.position = end; String::from_utf8(string_data.to_vec()).map_err(|e| { malformed_error!( "Invalid UTF-8 string at offset {}-{}: {}", - self.position - length, - self.position, + start, + end, e.utf8_error() ) }) @@ -858,19 +884,17 @@ impl<'a> Parser<'a> { /// ``` pub fn read_prefixed_string_utf8_ref(&mut self) -> Result<&'a str> { let length = self.read_7bit_encoded_int()? as usize; + let start = self.position; + let end = start.checked_add(length).ok_or(out_of_bounds_error!())?; - if self.position + length > self.data.len() { - return Err(out_of_bounds_error!()); - } - - let string_data = &self.data[self.position..self.position + length]; - self.position += length; + let string_data = self.data.get(start..end).ok_or(out_of_bounds_error!())?; + self.position = end; std::str::from_utf8(string_data).map_err(|_| { malformed_error!( "Invalid UTF-8 string at position {} - {} - {:?}", - self.position - length, - self.position, + start, + end, string_data ) }) @@ -901,19 +925,17 @@ impl<'a> Parser<'a> { /// ``` pub fn read_compressed_string_utf8(&mut self) -> Result { let length = self.read_compressed_uint()? as usize; + let start = self.position; + let end = start.checked_add(length).ok_or(out_of_bounds_error!())?; - if self.position + length > self.data.len() { - return Err(out_of_bounds_error!()); - } - - let string_data = &self.data[self.position..self.position + length]; - self.position += length; + let string_data = self.data.get(start..end).ok_or(out_of_bounds_error!())?; + self.position = end; String::from_utf8(string_data.to_vec()).map_err(|e| { malformed_error!( "Invalid UTF-8 compressed string at offset {}-{}: {}", - self.position - length, - self.position, + start, + end, e.utf8_error() ) }) @@ -1038,7 +1060,10 @@ impl<'a> Parser<'a> { /// ``` pub fn read_bytes(&mut self, length: usize) -> Result<&'a [u8]> { let end = self.calc_end_position(length)?; - let bytes = &self.data[self.position..end]; + let bytes = self + .data + .get(self.position..end) + .ok_or(out_of_bounds_error!())?; self.position = end; Ok(bytes) } @@ -1068,7 +1093,11 @@ impl<'a> Parser<'a> { /// ``` pub fn read_prefixed_string_utf16(&mut self) -> Result { let length = self.read_7bit_encoded_int()? as usize; - if self.position + length > self.data.len() { + let end = self + .position + .checked_add(length) + .ok_or(out_of_bounds_error!())?; + if end > self.data.len() { return Err(out_of_bounds_error!()); } @@ -1076,8 +1105,9 @@ impl<'a> Parser<'a> { return Err(malformed_error!("Invalid UTF-16 length - {}", length)); } - let mut utf16_chars: Vec = Vec::with_capacity(length / 2); - for _ in 0..length / 2 { + let char_count = length / 2; + let mut utf16_chars: Vec = Vec::with_capacity(char_count); + for _ in 0..char_count { let char = self.read_le::()?; utf16_chars.push(char); } diff --git a/dotscope/src/file/pe.rs b/dotscope/src/file/pe.rs index 4d1c12fd..dc485ff2 100644 --- a/dotscope/src/file/pe.rs +++ b/dotscope/src/file/pe.rs @@ -38,7 +38,10 @@ //! owned_pe.write_section_table(&mut buffer)?; //! ``` -use crate::{Error, Result}; +use crate::{ + utils::{read_le_at, write_le_at}, + Error, Result, +}; use std::collections::HashMap; use std::fmt; use std::io::Write; @@ -216,7 +219,7 @@ pub mod constants { pub const CLR_MAJOR_RUNTIME_VERSION: u16 = 2; } -use constants::{IMAGE_RESOURCE_DIRECTORY_SIZE, RESOURCE_DATA_ENTRY_SIZE, RESOURCE_ENTRY_SIZE}; +use constants::{IMAGE_RESOURCE_DIRECTORY_SIZE, RESOURCE_ENTRY_SIZE}; /// Owned PE file representation that doesn't require borrowing from source data. /// @@ -660,7 +663,8 @@ impl Pe { .as_ref() .map_or(0, |_| u64::from(self.coff_header.size_of_optional_header)); - 4 + CoffHeader::SIZE as u64 + optional_header_size + 4u64.saturating_add(CoffHeader::SIZE as u64) + .saturating_add(optional_header_size) } /// Calculates the total size of all file headers (DOS header + PE headers). @@ -672,7 +676,7 @@ impl Pe { /// Total size in bytes of DOS header + PE headers #[must_use] pub fn calculate_total_file_headers_size(&self) -> u64 { - DosHeader::SIZE as u64 + self.calculate_headers_size() + (DosHeader::SIZE as u64).saturating_add(self.calculate_headers_size()) } /// Calculates the total size of all current sections' raw data. @@ -1225,17 +1229,29 @@ impl OptionalHeader { /// * `is_pe32_plus` - `true` for PE32+ (64-bit), `false` for PE32 (32-bit) #[must_use] pub const fn size_for_format(is_pe32_plus: bool) -> usize { - StandardFields::SIZE_FOR_FORMAT[is_pe32_plus as usize] - + WindowsFields::SIZE_FOR_FORMAT[is_pe32_plus as usize] - + DataDirectories::SIZE + let std_size = if is_pe32_plus { + StandardFields::SIZE_PE32_PLUS + } else { + StandardFields::SIZE_PE32 + }; + let win_size = if is_pe32_plus { + WindowsFields::SIZE_PE32_PLUS + } else { + WindowsFields::SIZE_PE32 + }; + std_size + .saturating_add(win_size) + .saturating_add(DataDirectories::SIZE) } } impl StandardFields { - /// Size of standard fields: [PE32, PE32+] - /// - PE32: 28 bytes (includes base_of_data) - /// - PE32+: 24 bytes (no base_of_data) - pub const SIZE_FOR_FORMAT: [usize; 2] = [28, 24]; + /// Size of standard fields for PE32 (32-bit): 28 bytes (includes base_of_data). + pub const SIZE_PE32: usize = 28; + /// Size of standard fields for PE32+ (64-bit): 24 bytes (no base_of_data). + pub const SIZE_PE32_PLUS: usize = 24; + /// Size of standard fields, indexed by `is_pe32_plus as usize`. + pub const SIZE_FOR_FORMAT: [usize; 2] = [Self::SIZE_PE32, Self::SIZE_PE32_PLUS]; fn from_goblin(goblin_sf: &goblin::pe::optional_header::StandardFields) -> Result { Ok(Self { @@ -1296,10 +1312,12 @@ impl StandardFields { } impl WindowsFields { - /// Size of Windows-specific fields: [PE32, PE32+] - /// - PE32: 68 bytes (4-byte image_base and stack/heap sizes) - /// - PE32+: 88 bytes (8-byte image_base and stack/heap sizes) - pub const SIZE_FOR_FORMAT: [usize; 2] = [68, 88]; + /// Size of Windows-specific fields for PE32 (32-bit): 68 bytes (4-byte image_base and stack/heap sizes). + pub const SIZE_PE32: usize = 68; + /// Size of Windows-specific fields for PE32+ (64-bit): 88 bytes (8-byte image_base and stack/heap sizes). + pub const SIZE_PE32_PLUS: usize = 88; + /// Size of Windows-specific fields, indexed by `is_pe32_plus as usize`. + pub const SIZE_FOR_FORMAT: [usize; 2] = [Self::SIZE_PE32, Self::SIZE_PE32_PLUS]; fn from_goblin(goblin_wf: &goblin::pe::optional_header::WindowsFields) -> Self { Self { @@ -1605,7 +1623,7 @@ impl SectionTable { /// Total size in bytes for the section table #[must_use] pub fn calculate_table_size(section_count: usize) -> u64 { - (section_count * Self::SIZE) as u64 + section_count.saturating_mul(Self::SIZE) as u64 } /// Creates a SectionTable from layout information. @@ -1727,7 +1745,9 @@ impl SectionTable { let mut name_bytes = [0u8; 8]; let name_str = self.name.as_bytes(); let copy_len = std::cmp::min(name_str.len(), 8); - name_bytes[..copy_len].copy_from_slice(&name_str[..copy_len]); + if let (Some(dst), Some(src)) = (name_bytes.get_mut(..copy_len), name_str.get(..copy_len)) { + dst.copy_from_slice(src); + } writer.write_all(&name_bytes)?; writer.write_all(&self.virtual_size.to_le_bytes())?; @@ -1831,55 +1851,32 @@ pub struct ImageResourceDirectory { impl ImageResourceDirectory { /// Reads an `ImageResourceDirectory` from a byte slice at the given offset. pub fn read_from(data: &[u8], offset: usize) -> Result { - if offset + IMAGE_RESOURCE_DIRECTORY_SIZE > data.len() { - return Err(malformed_error!( - "Resource directory at offset {:#x} exceeds bounds", - offset - )); - } - + let mut pos = offset; Ok(Self { - characteristics: u32::from_le_bytes([ - data[offset], - data[offset + 1], - data[offset + 2], - data[offset + 3], - ]), - time_date_stamp: u32::from_le_bytes([ - data[offset + 4], - data[offset + 5], - data[offset + 6], - data[offset + 7], - ]), - major_version: u16::from_le_bytes([data[offset + 8], data[offset + 9]]), - minor_version: u16::from_le_bytes([data[offset + 10], data[offset + 11]]), - number_of_named_entries: u16::from_le_bytes([data[offset + 12], data[offset + 13]]), - number_of_id_entries: u16::from_le_bytes([data[offset + 14], data[offset + 15]]), + characteristics: read_le_at::(data, &mut pos)?, + time_date_stamp: read_le_at::(data, &mut pos)?, + major_version: read_le_at::(data, &mut pos)?, + minor_version: read_le_at::(data, &mut pos)?, + number_of_named_entries: read_le_at::(data, &mut pos)?, + number_of_id_entries: read_le_at::(data, &mut pos)?, }) } /// Returns the total number of entries (named + ID). #[inline] pub fn entry_count(&self) -> usize { - self.number_of_named_entries as usize + self.number_of_id_entries as usize + (self.number_of_named_entries as usize).saturating_add(self.number_of_id_entries as usize) } /// Writes this `ImageResourceDirectory` to a byte slice at the given offset. pub fn write_to(&self, data: &mut [u8], offset: usize) -> Result<()> { - if offset + IMAGE_RESOURCE_DIRECTORY_SIZE > data.len() { - return Err(malformed_error!( - "Resource directory at offset {:#x} exceeds bounds for write", - offset - )); - } - - data[offset..offset + 4].copy_from_slice(&self.characteristics.to_le_bytes()); - data[offset + 4..offset + 8].copy_from_slice(&self.time_date_stamp.to_le_bytes()); - data[offset + 8..offset + 10].copy_from_slice(&self.major_version.to_le_bytes()); - data[offset + 10..offset + 12].copy_from_slice(&self.minor_version.to_le_bytes()); - data[offset + 12..offset + 14].copy_from_slice(&self.number_of_named_entries.to_le_bytes()); - data[offset + 14..offset + 16].copy_from_slice(&self.number_of_id_entries.to_le_bytes()); - + let mut pos = offset; + write_le_at::(data, &mut pos, self.characteristics)?; + write_le_at::(data, &mut pos, self.time_date_stamp)?; + write_le_at::(data, &mut pos, self.major_version)?; + write_le_at::(data, &mut pos, self.minor_version)?; + write_le_at::(data, &mut pos, self.number_of_named_entries)?; + write_le_at::(data, &mut pos, self.number_of_id_entries)?; Ok(()) } } @@ -1905,26 +1902,10 @@ pub struct ResourceEntry { impl ResourceEntry { /// Reads a `ResourceEntry` from a byte slice at the given offset. pub fn read_from(data: &[u8], offset: usize) -> Result { - if offset + RESOURCE_ENTRY_SIZE > data.len() { - return Err(malformed_error!( - "Resource entry at offset {:#x} exceeds bounds", - offset - )); - } - + let mut pos = offset; Ok(Self { - name_or_id: u32::from_le_bytes([ - data[offset], - data[offset + 1], - data[offset + 2], - data[offset + 3], - ]), - offset_to_data_or_directory: u32::from_le_bytes([ - data[offset + 4], - data[offset + 5], - data[offset + 6], - data[offset + 7], - ]), + name_or_id: read_le_at::(data, &mut pos)?, + offset_to_data_or_directory: read_le_at::(data, &mut pos)?, }) } @@ -1944,17 +1925,9 @@ impl ResourceEntry { /// Writes this `ResourceEntry` to a byte slice at the given offset. pub fn write_to(self, data: &mut [u8], offset: usize) -> Result<()> { - if offset + RESOURCE_ENTRY_SIZE > data.len() { - return Err(malformed_error!( - "Resource entry at offset {:#x} exceeds bounds for write", - offset - )); - } - - data[offset..offset + 4].copy_from_slice(&self.name_or_id.to_le_bytes()); - data[offset + 4..offset + 8] - .copy_from_slice(&self.offset_to_data_or_directory.to_le_bytes()); - + let mut pos = offset; + write_le_at::(data, &mut pos, self.name_or_id)?; + write_le_at::(data, &mut pos, self.offset_to_data_or_directory)?; Ok(()) } } @@ -1986,60 +1959,22 @@ pub struct ResourceDataEntry { impl ResourceDataEntry { /// Reads a `ResourceDataEntry` from a byte slice at the given offset. pub fn read_from(data: &[u8], offset: usize) -> Result { - if offset + RESOURCE_DATA_ENTRY_SIZE > data.len() { - return Err(malformed_error!( - "Resource data entry at offset {:#x} exceeds bounds", - offset - )); - } - + let mut pos = offset; Ok(Self { - offset_to_data: u32::from_le_bytes([ - data[offset], - data[offset + 1], - data[offset + 2], - data[offset + 3], - ]), - size: u32::from_le_bytes([ - data[offset + 4], - data[offset + 5], - data[offset + 6], - data[offset + 7], - ]), - code_page: u32::from_le_bytes([ - data[offset + 8], - data[offset + 9], - data[offset + 10], - data[offset + 11], - ]), - reserved: u32::from_le_bytes([ - data[offset + 12], - data[offset + 13], - data[offset + 14], - data[offset + 15], - ]), + offset_to_data: read_le_at::(data, &mut pos)?, + size: read_le_at::(data, &mut pos)?, + code_page: read_le_at::(data, &mut pos)?, + reserved: read_le_at::(data, &mut pos)?, }) } /// Writes this `ResourceDataEntry` to a byte slice at the given offset. pub fn write_to(&self, data: &mut [u8], offset: usize) -> Result<()> { - if offset + RESOURCE_DATA_ENTRY_SIZE > data.len() { - return Err(malformed_error!( - "Resource data entry at offset {:#x} exceeds bounds", - offset - )); - } - - let rva_bytes = self.offset_to_data.to_le_bytes(); - let size_bytes = self.size.to_le_bytes(); - let code_page_bytes = self.code_page.to_le_bytes(); - let reserved_bytes = self.reserved.to_le_bytes(); - - data[offset..offset + 4].copy_from_slice(&rva_bytes); - data[offset + 4..offset + 8].copy_from_slice(&size_bytes); - data[offset + 8..offset + 12].copy_from_slice(&code_page_bytes); - data[offset + 12..offset + 16].copy_from_slice(&reserved_bytes); - + let mut pos = offset; + write_le_at::(data, &mut pos, self.offset_to_data)?; + write_le_at::(data, &mut pos, self.size)?; + write_le_at::(data, &mut pos, self.code_page)?; + write_le_at::(data, &mut pos, self.reserved)?; Ok(()) } } @@ -2079,7 +2014,11 @@ pub fn relocate_resource_section(data: &mut [u8], old_rva: u32, new_rva: u32) -> return Ok(()); // No relocation needed } - let delta = i64::from(new_rva) - i64::from(old_rva); + // i64 has plenty of headroom for u32 - u32; checked_sub is overkill but + // satisfies clippy's arithmetic_side_effects lint. + let delta = i64::from(new_rva) + .checked_sub(i64::from(old_rva)) + .ok_or_else(|| malformed_error!("Resource RVA delta overflow"))?; // Process the root directory at offset 0 relocate_resource_directory(data, 0, delta) @@ -2089,41 +2028,39 @@ pub fn relocate_resource_section(data: &mut [u8], old_rva: u32, new_rva: u32) -> fn relocate_resource_directory(data: &mut [u8], offset: usize, delta: i64) -> Result<()> { // Read the directory header let dir = ImageResourceDirectory::read_from(data, offset)?; - let entries_offset = offset + IMAGE_RESOURCE_DIRECTORY_SIZE; + let entries_offset = offset + .checked_add(IMAGE_RESOURCE_DIRECTORY_SIZE) + .ok_or_else(|| malformed_error!("Resource directory entries offset overflow"))?; // Process each entry for i in 0..dir.entry_count() { - let entry_offset = entries_offset + i * RESOURCE_ENTRY_SIZE; + let scaled = i + .checked_mul(RESOURCE_ENTRY_SIZE) + .ok_or_else(|| malformed_error!("Resource entry index overflow"))?; + let entry_offset = entries_offset + .checked_add(scaled) + .ok_or_else(|| malformed_error!("Resource entry offset overflow"))?; let entry = ResourceEntry::read_from(data, entry_offset)?; if entry.is_directory() { // Entry points to another directory - recurse relocate_resource_directory(data, entry.target_offset(), delta)?; } else { - // Entry points to a ResourceDataEntry - adjust the RVA in-place - // The RVA is the first 4 bytes of the ResourceDataEntry structure + // Entry points to a ResourceDataEntry - adjust the RVA in-place. + // The RVA is the first 4 bytes of the ResourceDataEntry structure. let data_entry_offset = entry.target_offset(); - if data_entry_offset + 4 > data.len() { - return Err(malformed_error!( - "Resource data entry at offset {:#x} exceeds bounds", - data_entry_offset - )); - } - let old_data_rva = u32::from_le_bytes([ - data[data_entry_offset], - data[data_entry_offset + 1], - data[data_entry_offset + 2], - data[data_entry_offset + 3], - ]); - let new_data_rva = u32::try_from(i64::from(old_data_rva) + delta).map_err(|_| { - malformed_error!( - "Resource RVA relocation overflow: old_rva={:#x}, delta={}", - old_data_rva, - delta - ) - })?; - data[data_entry_offset..data_entry_offset + 4] - .copy_from_slice(&new_data_rva.to_le_bytes()); + let mut pos = data_entry_offset; + let old_data_rva: u32 = read_le_at(data, &mut pos)?; + let new_data_rva = u32::try_from(i64::from(old_data_rva).saturating_add(delta)) + .map_err(|_| { + malformed_error!( + "Resource RVA relocation overflow: old_rva={:#x}, delta={}", + old_data_rva, + delta + ) + })?; + let mut pos = data_entry_offset; + write_le_at::(data, &mut pos, new_data_rva)?; } } diff --git a/dotscope/src/file/repair.rs b/dotscope/src/file/repair.rs index 97507280..66e0d380 100644 --- a/dotscope/src/file/repair.rs +++ b/dotscope/src/file/repair.rs @@ -116,13 +116,16 @@ pub fn repair_pe(bytes: &mut [u8]) -> RepairResult { let mut result = RepairResult::default(); // Bail fast: not a PE file or too small for a DOS header - if bytes.len() < 64 || bytes[0] != b'M' || bytes[1] != b'Z' { + if bytes.len() < 64 || bytes.first() != Some(&b'M') || bytes.get(1) != Some(&b'Z') { return result; } let Some(pe_off) = read_u32(bytes, 0x3C).map(|v| v as usize) else { return result; }; - if pe_off + 4 > bytes.len() { + let Some(end) = pe_off.checked_add(4) else { + return result; + }; + if end > bytes.len() { return result; } @@ -148,13 +151,16 @@ pub fn repair_pe_cow(cowfile: &CowFile) -> RepairResult { return result; } let base = cowfile.data(); - if base[0] != b'M' || base[1] != b'Z' { + if base.first() != Some(&b'M') || base.get(1) != Some(&b'Z') { return result; } let Some(pe_off) = cowfile.read_le::(0x3C).ok().map(|v| v as usize) else { return result; }; - if pe_off + 4 > base.len() { + let Some(end) = pe_off.checked_add(4) else { + return result; + }; + if end > base.len() { return result; } @@ -184,15 +190,20 @@ fn repair_pe_signature_cow(cowfile: &CowFile, pe_off: usize, result: &mut Repair } fn repair_pe_optional_header_cow(cowfile: &CowFile, pe_off: usize, result: &mut RepairResult) { - let optional_header_offset = pe_off + 24; + let Some(optional_header_offset) = pe_off.checked_add(24) else { + return; + }; let Some(magic) = cowfile.read_le::(optional_header_offset).ok() else { return; }; let num_rva_sizes_offset = match magic { - PE32_MAGIC => optional_header_offset + 0x5C, - PE32PLUS_MAGIC => optional_header_offset + 0x6C, + PE32_MAGIC => optional_header_offset.checked_add(0x5C), + PE32PLUS_MAGIC => optional_header_offset.checked_add(0x6C), _ => return, }; + let Some(num_rva_sizes_offset) = num_rva_sizes_offset else { + return; + }; if let Ok(num_rva_sizes) = cowfile.read_le::(num_rva_sizes_offset) { if num_rva_sizes == BITMONO_INFLATED_DATA_DIRECTORY_COUNT { @@ -208,17 +219,26 @@ fn repair_pe_optional_header_cow(cowfile: &CowFile, pe_off: usize, result: &mut } } - let data_dir_start = num_rva_sizes_offset + 4; - let clr_dir_offset = data_dir_start + CLR_DIRECTORY_INDEX * 8; + let Some(data_dir_start) = num_rva_sizes_offset.checked_add(4) else { + return; + }; + let Some(clr_offset_in_table) = CLR_DIRECTORY_INDEX.checked_mul(8) else { + return; + }; + let Some(clr_dir_offset) = data_dir_start.checked_add(clr_offset_in_table) else { + return; + }; + let Some(clr_size_offset) = clr_dir_offset.checked_add(4) else { + return; + }; let clr_rva = cowfile.read_le::(clr_dir_offset).ok(); - let clr_size = cowfile.read_le::(clr_dir_offset + 4).ok(); - if clr_rva.is_some_and(|r| r != 0) && clr_size == Some(0) { - let rva = clr_rva.unwrap(); + let clr_size = cowfile.read_le::(clr_size_offset).ok(); + if let (Some(rva), Some(0)) = (clr_rva.filter(|&r| r != 0), clr_size) { debug!( ".NET directory size repair: 0 → {} (RVA=0x{:X})", COR20_HEADER_SIZE, rva ); - let _ = cowfile.write_le::(clr_dir_offset + 4, COR20_HEADER_SIZE); + let _ = cowfile.write_le::(clr_size_offset, COR20_HEADER_SIZE); result.repairs.push(RepairAction::DotNetDirectorySize { original: 0, restored: COR20_HEADER_SIZE, @@ -227,17 +247,29 @@ fn repair_pe_optional_header_cow(cowfile: &CowFile, pe_off: usize, result: &mut } fn repair_clr_header_cow(cowfile: &CowFile, pe_off: usize, result: &mut RepairResult) { - let optional_header_offset = pe_off + 24; + let Some(optional_header_offset) = pe_off.checked_add(24) else { + return; + }; let Some(magic) = cowfile.read_le::(optional_header_offset).ok() else { return; }; let num_rva_sizes_offset = match magic { - PE32_MAGIC => optional_header_offset + 0x5C, - PE32PLUS_MAGIC => optional_header_offset + 0x6C, + PE32_MAGIC => optional_header_offset.checked_add(0x5C), + PE32PLUS_MAGIC => optional_header_offset.checked_add(0x6C), _ => return, }; - let data_dir_start = num_rva_sizes_offset + 4; - let clr_dir_offset = data_dir_start + CLR_DIRECTORY_INDEX * 8; + let Some(num_rva_sizes_offset) = num_rva_sizes_offset else { + return; + }; + let Some(data_dir_start) = num_rva_sizes_offset.checked_add(4) else { + return; + }; + let Some(clr_offset_in_table) = CLR_DIRECTORY_INDEX.checked_mul(8) else { + return; + }; + let Some(clr_dir_offset) = data_dir_start.checked_add(clr_offset_in_table) else { + return; + }; let Some(clr_rva) = cowfile.read_le::(clr_dir_offset).ok() else { return; }; @@ -250,6 +282,15 @@ fn repair_clr_header_cow(cowfile: &CowFile, pe_off: usize, result: &mut RepairRe return; }; let clr_offset = clr_file_offset as usize; + let Some(version_offset) = clr_offset.checked_add(4) else { + return; + }; + let Some(metadata_rva_offset) = clr_offset.checked_add(8) else { + return; + }; + let Some(metadata_size_offset) = clr_offset.checked_add(12) else { + return; + }; if let Ok(0) = cowfile.read_le::(clr_offset) { debug!("CLR header size repair: 0 → {}", COR20_HEADER_SIZE); @@ -260,20 +301,20 @@ fn repair_clr_header_cow(cowfile: &CowFile, pe_off: usize, result: &mut RepairRe }); } - if let Ok(0) = cowfile.read_le::(clr_offset + 4) { + if let Ok(0) = cowfile.read_le::(version_offset) { debug!( "CLR runtime version repair: 0 → {}", CLR_MAJOR_RUNTIME_VERSION ); - let _ = cowfile.write_le::(clr_offset + 4, CLR_MAJOR_RUNTIME_VERSION); + let _ = cowfile.write_le::(version_offset, CLR_MAJOR_RUNTIME_VERSION); result.repairs.push(RepairAction::ClrHeaderVersion { original_major: 0, restored_major: CLR_MAJOR_RUNTIME_VERSION, }); } - let metadata_rva = cowfile.read_le::(clr_offset + 8).ok(); - let metadata_size = cowfile.read_le::(clr_offset + 12).ok(); + let metadata_rva = cowfile.read_le::(metadata_rva_offset).ok(); + let metadata_size = cowfile.read_le::(metadata_size_offset).ok(); if metadata_rva == Some(0) || metadata_size == Some(0) { if let Some((found_rva, found_size)) = find_metadata_by_bsjb_scan(base, pe_off) { @@ -281,8 +322,8 @@ fn repair_clr_header_cow(cowfile: &CowFile, pe_off: usize, result: &mut RepairRe "CLR metadata repair via BSJB scan: RVA=0x{:X}, size=0x{:X}", found_rva, found_size ); - let _ = cowfile.write_le::(clr_offset + 8, found_rva); - let _ = cowfile.write_le::(clr_offset + 12, found_size); + let _ = cowfile.write_le::(metadata_rva_offset, found_rva); + let _ = cowfile.write_le::(metadata_size_offset, found_size); result.repairs.push(RepairAction::ClrMetadataRva { restored_rva: found_rva, restored_size: found_size, @@ -323,15 +364,20 @@ fn repair_pe_signature(bytes: &mut [u8], pe_off: usize, result: &mut RepairResul } fn repair_pe_optional_header(bytes: &mut [u8], pe_off: usize, result: &mut RepairResult) { - let optional_header_offset = pe_off + 24; + let Some(optional_header_offset) = pe_off.checked_add(24) else { + return; + }; let Some(magic) = read_u16(bytes, optional_header_offset) else { return; }; let num_rva_sizes_offset = match magic { - PE32_MAGIC => optional_header_offset + 0x5C, - PE32PLUS_MAGIC => optional_header_offset + 0x6C, + PE32_MAGIC => optional_header_offset.checked_add(0x5C), + PE32PLUS_MAGIC => optional_header_offset.checked_add(0x6C), _ => return, }; + let Some(num_rva_sizes_offset) = num_rva_sizes_offset else { + return; + }; if let Some(num_rva_sizes) = read_u32(bytes, num_rva_sizes_offset) { if num_rva_sizes == BITMONO_INFLATED_DATA_DIRECTORY_COUNT { @@ -348,17 +394,26 @@ fn repair_pe_optional_header(bytes: &mut [u8], pe_off: usize, result: &mut Repai } } - let data_dir_start = num_rva_sizes_offset + 4; - let clr_dir_offset = data_dir_start + CLR_DIRECTORY_INDEX * 8; + let Some(data_dir_start) = num_rva_sizes_offset.checked_add(4) else { + return; + }; + let Some(clr_offset_in_table) = CLR_DIRECTORY_INDEX.checked_mul(8) else { + return; + }; + let Some(clr_dir_offset) = data_dir_start.checked_add(clr_offset_in_table) else { + return; + }; + let Some(clr_size_offset) = clr_dir_offset.checked_add(4) else { + return; + }; let clr_rva = read_u32(bytes, clr_dir_offset); - let clr_size = read_u32(bytes, clr_dir_offset + 4); - if clr_rva.is_some_and(|r| r != 0) && clr_size == Some(0) { - let rva = clr_rva.unwrap(); + let clr_size = read_u32(bytes, clr_size_offset); + if let (Some(rva), Some(0)) = (clr_rva.filter(|&r| r != 0), clr_size) { debug!( ".NET directory size repair: 0 → {} (RVA=0x{:X})", COR20_HEADER_SIZE, rva ); - let mut off = clr_dir_offset + 4; + let mut off = clr_size_offset; let _ = write_le_at(bytes, &mut off, COR20_HEADER_SIZE); result.repairs.push(RepairAction::DotNetDirectorySize { original: 0, @@ -368,17 +423,29 @@ fn repair_pe_optional_header(bytes: &mut [u8], pe_off: usize, result: &mut Repai } fn repair_clr_header(bytes: &mut [u8], pe_off: usize, result: &mut RepairResult) { - let optional_header_offset = pe_off + 24; + let Some(optional_header_offset) = pe_off.checked_add(24) else { + return; + }; let Some(magic) = read_u16(bytes, optional_header_offset) else { return; }; let num_rva_sizes_offset = match magic { - PE32_MAGIC => optional_header_offset + 0x5C, - PE32PLUS_MAGIC => optional_header_offset + 0x6C, + PE32_MAGIC => optional_header_offset.checked_add(0x5C), + PE32PLUS_MAGIC => optional_header_offset.checked_add(0x6C), _ => return, }; - let data_dir_start = num_rva_sizes_offset + 4; - let clr_dir_offset = data_dir_start + CLR_DIRECTORY_INDEX * 8; + let Some(num_rva_sizes_offset) = num_rva_sizes_offset else { + return; + }; + let Some(data_dir_start) = num_rva_sizes_offset.checked_add(4) else { + return; + }; + let Some(clr_offset_in_table) = CLR_DIRECTORY_INDEX.checked_mul(8) else { + return; + }; + let Some(clr_dir_offset) = data_dir_start.checked_add(clr_offset_in_table) else { + return; + }; let Some(clr_rva) = read_u32(bytes, clr_dir_offset) else { return; }; @@ -390,6 +457,15 @@ fn repair_clr_header(bytes: &mut [u8], pe_off: usize, result: &mut RepairResult) return; }; let clr_offset = clr_file_offset as usize; + let Some(version_offset) = clr_offset.checked_add(4) else { + return; + }; + let Some(metadata_rva_offset) = clr_offset.checked_add(8) else { + return; + }; + let Some(metadata_size_offset) = clr_offset.checked_add(12) else { + return; + }; if let Some(0) = read_u32(bytes, clr_offset) { debug!("CLR header size repair: 0 → {}", COR20_HEADER_SIZE); @@ -401,12 +477,12 @@ fn repair_clr_header(bytes: &mut [u8], pe_off: usize, result: &mut RepairResult) }); } - if let Some(0) = read_u16(bytes, clr_offset + 4) { + if let Some(0) = read_u16(bytes, version_offset) { debug!( "CLR runtime version repair: 0 → {}", CLR_MAJOR_RUNTIME_VERSION ); - let mut off = clr_offset + 4; + let mut off = version_offset; let _ = write_le_at(bytes, &mut off, CLR_MAJOR_RUNTIME_VERSION); result.repairs.push(RepairAction::ClrHeaderVersion { original_major: 0, @@ -414,8 +490,8 @@ fn repair_clr_header(bytes: &mut [u8], pe_off: usize, result: &mut RepairResult) }); } - let metadata_rva = read_u32(bytes, clr_offset + 8); - let metadata_size = read_u32(bytes, clr_offset + 12); + let metadata_rva = read_u32(bytes, metadata_rva_offset); + let metadata_size = read_u32(bytes, metadata_size_offset); if metadata_rva == Some(0) || metadata_size == Some(0) { if let Some((found_rva, found_size)) = find_metadata_by_bsjb_scan(bytes, pe_off) { @@ -423,9 +499,9 @@ fn repair_clr_header(bytes: &mut [u8], pe_off: usize, result: &mut RepairResult) "CLR metadata repair via BSJB scan: RVA=0x{:X}, size=0x{:X}", found_rva, found_size ); - let mut off = clr_offset + 8; + let mut off = metadata_rva_offset; let _ = write_le_at(bytes, &mut off, found_rva); - let mut off = clr_offset + 12; + let mut off = metadata_size_offset; let _ = write_le_at(bytes, &mut off, found_size); result.repairs.push(RepairAction::ClrMetadataRva { restored_rva: found_rva, @@ -442,8 +518,10 @@ fn rva_to_file_offset(bytes: &[u8], pe_off: usize, rva: u32) -> Option { for section in §ions { let section_end = section.virtual_address.checked_add(section.virtual_size)?; if rva >= section.virtual_address && rva < section_end { - let offset_within_section = rva - section.virtual_address; - return Some(section.pointer_to_raw_data + offset_within_section); + let offset_within_section = rva.checked_sub(section.virtual_address)?; + return section + .pointer_to_raw_data + .checked_add(offset_within_section); } } @@ -459,8 +537,8 @@ fn file_offset_to_rva(bytes: &[u8], pe_off: usize, offset: u32) -> Option { .pointer_to_raw_data .checked_add(section.size_of_raw_data)?; if offset >= section.pointer_to_raw_data && offset < section_end { - let offset_within_section = offset - section.pointer_to_raw_data; - return Some(section.virtual_address + offset_within_section); + let offset_within_section = offset.checked_sub(section.pointer_to_raw_data)?; + return section.virtual_address.checked_add(offset_within_section); } } @@ -470,41 +548,48 @@ fn file_offset_to_rva(bytes: &[u8], pe_off: usize, offset: u32) -> Option { /// Parses the section table from raw PE bytes into [`SectionTable`] entries. fn parse_section_table(bytes: &[u8], pe_off: usize) -> Option> { // COFF header starts at pe_off + 4 - let coff_offset = pe_off + 4; + let coff_offset = pe_off.checked_add(4)?; // Number of sections at COFF offset + 2 - let num_sections = read_u16(bytes, coff_offset + 2)? as usize; + let num_sections = read_u16(bytes, coff_offset.checked_add(2)?)? as usize; // Size of optional header at COFF offset + 16 - let opt_header_size = read_u16(bytes, coff_offset + 16)? as usize; + let opt_header_size = read_u16(bytes, coff_offset.checked_add(16)?)? as usize; // Section table starts after PE sig (4) + COFF header (20) + optional header - let section_table_offset = pe_off + 4 + 20 + opt_header_size; + let section_table_offset = pe_off + .checked_add(4)? + .checked_add(20)? + .checked_add(opt_header_size)?; let mut sections = Vec::with_capacity(num_sections); for i in 0..num_sections { // Each section header is 40 bytes (SectionTable::SIZE) - let entry_offset = section_table_offset + i * SectionTable::SIZE; + let entry_offset = section_table_offset.checked_add(i.checked_mul(SectionTable::SIZE)?)?; + + let read_field_u32 = |delta: usize| read_u32(bytes, entry_offset.checked_add(delta)?); + let read_field_u16 = |delta: usize| read_u16(bytes, entry_offset.checked_add(delta)?); // Name at offset 0 (8 bytes) + let name_end = entry_offset.checked_add(8)?; let name = bytes - .get(entry_offset..entry_offset + 8) + .get(entry_offset..name_end) .and_then(|b| std::str::from_utf8(b).ok()) .map(|s| s.trim_end_matches('\0').to_string()) .unwrap_or_default(); sections.push(SectionTable { name, - virtual_size: read_u32(bytes, entry_offset + 8)?, - virtual_address: read_u32(bytes, entry_offset + 12)?, - size_of_raw_data: read_u32(bytes, entry_offset + 16)?, - pointer_to_raw_data: read_u32(bytes, entry_offset + 20)?, - pointer_to_relocations: read_u32(bytes, entry_offset + 24)?, - pointer_to_line_numbers: read_u32(bytes, entry_offset + 28)?, - number_of_relocations: read_u16(bytes, entry_offset + 32)?, - number_of_line_numbers: read_u16(bytes, entry_offset + 34)?, - characteristics: read_u32(bytes, entry_offset + 36)?, + virtual_size: read_field_u32(8)?, + virtual_address: read_field_u32(12)?, + size_of_raw_data: read_field_u32(16)?, + pointer_to_raw_data: read_field_u32(20)?, + pointer_to_relocations: read_field_u32(24)?, + pointer_to_line_numbers: read_field_u32(28)?, + number_of_relocations: read_field_u16(32)?, + number_of_line_numbers: read_field_u16(34)?, + characteristics: read_field_u32(36)?, }); } @@ -520,8 +605,12 @@ fn find_metadata_by_bsjb_scan(bytes: &[u8], pe_off: usize) -> Option<(u32, u32)> return None; } - for i in 0..len - 3 { - if bytes[i..i + 4] == bsjb_bytes { + for i in 0..len.saturating_sub(3) { + let end = i.checked_add(4)?; + let Some(window) = bytes.get(i..end) else { + break; + }; + if window == bsjb_bytes { let file_offset = i as u32; // Convert file offset to RVA @@ -552,32 +641,50 @@ fn find_metadata_by_bsjb_scan(bytes: &[u8], pe_off: usize) -> Option<(u32, u32)> /// /// Parses the stream headers to find the maximum extent of metadata. fn estimate_metadata_size(bytes: &[u8], bsjb_offset: usize) -> u32 { + const FALLBACK: u32 = 0x1000; // Minimum: at least the fixed header fields let base = bsjb_offset; // Read version length at offset 12 - let Some(version_length) = read_u32(bytes, base + 12) else { - return 0x1000; // Fallback estimate + let Some(version_offset) = base.checked_add(12) else { + return FALLBACK; + }; + let Some(version_length) = read_u32(bytes, version_offset) else { + return FALLBACK; }; // Version string follows (padded to 4-byte boundary) - let padded_version_len = ((version_length + 3) & !3) as usize; - let streams_header_offset = base + 16 + padded_version_len; + let padded_version_len = (version_length.saturating_add(3) & !3) as usize; + let Some(streams_header_offset) = base + .checked_add(16) + .and_then(|v| v.checked_add(padded_version_len)) + else { + return FALLBACK; + }; // Read number of streams - let Some(num_streams) = read_u16(bytes, streams_header_offset + 2) else { - return 0x1000; + let Some(num_streams_offset) = streams_header_offset.checked_add(2) else { + return FALLBACK; + }; + let Some(num_streams) = read_u16(bytes, num_streams_offset) else { + return FALLBACK; }; // Parse stream headers to find maximum extent let mut max_extent: u32 = 0; - let mut cursor = streams_header_offset + 4; // Skip flags (2) + num_streams (2) + // Skip flags (2) + num_streams (2) + let Some(mut cursor) = streams_header_offset.checked_add(4) else { + return FALLBACK; + }; for _ in 0..num_streams { let Some(stream_offset) = read_u32(bytes, cursor) else { break; }; - let Some(stream_size) = read_u32(bytes, cursor + 4) else { + let Some(size_cursor) = cursor.checked_add(4) else { + break; + }; + let Some(stream_size) = read_u32(bytes, size_cursor) else { break; }; @@ -587,21 +694,36 @@ fn estimate_metadata_size(bytes: &[u8], bsjb_offset: usize) -> u32 { } // Skip past offset (4) + size (4) + name (null-terminated, padded to 4) - cursor += 8; + let Some(after_size) = cursor.checked_add(8) else { + break; + }; + cursor = after_size; // Read stream name (scan for null terminator) - while cursor < bytes.len() && bytes[cursor] != 0 { - cursor += 1; + while let Some(&b) = bytes.get(cursor) { + if b == 0 { + break; + } + let Some(next) = cursor.checked_add(1) else { + return if max_extent > 0 { max_extent } else { FALLBACK }; + }; + cursor = next; } // Skip null terminator - cursor += 1; + let Some(after_null) = cursor.checked_add(1) else { + break; + }; + cursor = after_null; // Align to 4-byte boundary - cursor = (cursor + 3) & !3; + let Some(aligned) = cursor.checked_add(3) else { + break; + }; + cursor = aligned & !3; } if max_extent > 0 { max_extent } else { - 0x1000 // Fallback + FALLBACK } } diff --git a/dotscope/src/formatting/assembly.rs b/dotscope/src/formatting/assembly.rs index 07e7bb6c..f72e484f 100644 --- a/dotscope/src/formatting/assembly.rs +++ b/dotscope/src/formatting/assembly.rs @@ -325,7 +325,8 @@ pub(super) fn format_data_directives(w: &mut dyn Write, asm: &CilObject) -> io:: // Uninitialized (BSS) data lives beyond the section's raw data but // within its virtual size. let is_initialized = find_section_for_rva(sections, rva).is_none_or(|s| { - (rva as u64) < u64::from(s.virtual_address) + u64::from(s.size_of_raw_data) + (rva as u64) + < u64::from(s.virtual_address).saturating_add(u64::from(s.size_of_raw_data)) }); if is_initialized { diff --git a/dotscope/src/formatting/exceptions.rs b/dotscope/src/formatting/exceptions.rs index 568ca35a..ec526d04 100644 --- a/dotscope/src/formatting/exceptions.rs +++ b/dotscope/src/formatting/exceptions.rs @@ -77,7 +77,7 @@ impl ExceptionBlockLayout { // Group handlers by try region (try_offset, try_end) let mut try_groups: BTreeMap<(u32, u32), Vec<&ExceptionHandler>> = BTreeMap::new(); for handler in handlers { - let try_end = handler.try_offset + handler.try_length; + let try_end = handler.try_offset.saturating_add(handler.try_length); try_groups .entry((handler.try_offset, try_end)) .or_default() @@ -103,7 +103,9 @@ impl ExceptionBlockLayout { // For each handler in this try group for handler in group { - let handler_end = handler.handler_offset + handler.handler_length; + let handler_end = handler + .handler_offset + .saturating_add(handler.handler_length); let kind = handler_kind(handler, asm); // Filter block opens at filter_offset (before handler) diff --git a/dotscope/src/formatting/helpers.rs b/dotscope/src/formatting/helpers.rs index f2de8418..21b1f763 100644 --- a/dotscope/src/formatting/helpers.rs +++ b/dotscope/src/formatting/helpers.rs @@ -32,14 +32,18 @@ static ILASM_RESERVED: LazyLock> = LazyLock::new(|| { let mut set = HashSet::with_capacity(512); // All CIL instruction mnemonics from the opcode tables - for instr in &INSTRUCTIONS[..INSTRUCTIONS_MAX as usize] { - if !instr.instr.is_empty() { - set.insert(instr.instr); + if let Some(slice) = INSTRUCTIONS.get(..INSTRUCTIONS_MAX as usize) { + for instr in slice { + if !instr.instr.is_empty() { + set.insert(instr.instr); + } } } - for instr in &INSTRUCTIONS_FE[..INSTRUCTIONS_FE_MAX as usize] { - if !instr.instr.is_empty() { - set.insert(instr.instr); + if let Some(slice) = INSTRUCTIONS_FE.get(..INSTRUCTIONS_FE_MAX as usize) { + for instr in slice { + if !instr.instr.is_empty() { + set.insert(instr.instr); + } } } @@ -381,8 +385,12 @@ pub(super) fn assembly_scoped_name(cil_type: &CilType, asm: &CilObject) -> Strin /// Strip a trailing generic arity suffix (`` `N ``) from a type name. fn strip_arity(name: &str) -> &str { if let Some(pos) = name.rfind('`') { - if name[pos + 1..].chars().all(|c| c.is_ascii_digit()) { - return &name[..pos]; + if let Some(rest) = name.get(pos.saturating_add(1)..) { + if rest.chars().all(|c| c.is_ascii_digit()) { + if let Some(prefix) = name.get(..pos) { + return prefix; + } + } } } name diff --git a/dotscope/src/formatting/method_body.rs b/dotscope/src/formatting/method_body.rs index edd4c269..f7c2f8a8 100644 --- a/dotscope/src/formatting/method_body.rs +++ b/dotscope/src/formatting/method_body.rs @@ -51,7 +51,7 @@ pub(super) fn format_method_body( 0 }; - let code_start_rva = u64::from(method.rva.unwrap_or(0)) + header_size; + let code_start_rva = u64::from(method.rva.unwrap_or(0)).saturating_add(header_size); // Build exception block layout for interleaving let layout = method @@ -87,7 +87,7 @@ pub(super) fn format_method_body( | BlockEvent::HandlerOpen { .. } | BlockEvent::FilterOpen ) { - nesting_depth += 1; + nesting_depth = nesting_depth.saturating_add(1); } } } @@ -110,7 +110,7 @@ pub(super) fn format_method_body_raw( asm: &CilObject, ) -> io::Result<()> { let header_size = method.body.get().map_or(0, |b| b.size_header as u64); - let code_start_rva = u64::from(method.rva.unwrap_or(0)) + header_size; + let code_start_rva = u64::from(method.rva.unwrap_or(0)).saturating_add(header_size); for instruction in method.instructions() { format_instruction(opts, w, instruction, code_start_rva, asm, 2)?; } @@ -173,7 +173,7 @@ fn format_locals(w: &mut dyn Write, method: &Method, asm: &CilObject) -> io::Res let pinned = if local.is_pinned { " pinned" } else { "" }; let byref = if local.is_byref { "&" } else { "" }; - let comma = if i < count - 1 { "," } else { "" }; + let comma = if i.saturating_add(1) < count { "," } else { "" }; writeln!(w, " [{i}] {type_name}{pinned}{byref} V_{i}{comma}")?; } @@ -286,11 +286,15 @@ fn format_operand( } else { "???".to_string() }; - let suffix = if i == offsets.len() - 1 { ")" } else { "," }; + let suffix = if i.saturating_add(1) == offsets.len() { + ")" + } else { + "," + }; result.push_str(&format!(" {label}{suffix}\n")); } // Remove trailing newline — writeln in format_instruction adds one - result.truncate(result.len() - 1); + result.truncate(result.len().saturating_sub(1)); result } } @@ -326,8 +330,8 @@ fn format_instruction_bytes(w: &mut dyn Write, instruction: &Instruction) -> io: Operand::Target(_) => { // CIL instruction sizes are small (max ~10 bytes), truncation is safe #[allow(clippy::cast_possible_truncation)] - let operand_size = - instruction.size as usize - if instruction.prefix != 0 { 2 } else { 1 }; + let operand_size = (instruction.size as usize) + .saturating_sub(if instruction.prefix != 0 { 2 } else { 1 }); bytes.extend(std::iter::repeat_n(0x00, operand_size)); } Operand::Token(tok) => { diff --git a/dotscope/src/formatting/tokens.rs b/dotscope/src/formatting/tokens.rs index 84aea731..5e1fe8f3 100644 --- a/dotscope/src/formatting/tokens.rs +++ b/dotscope/src/formatting/tokens.rs @@ -100,7 +100,7 @@ pub(super) fn resolve_token(assembly: &CilObject, token: Token) -> Option CilAssemblyViewData<'a> { return Err(out_of_bounds_error!()); } - let cor20_header = Cor20Header::read(&data[clr_offset..clr_end])?; + let cor20_header = Cor20Header::read( + data.get(clr_offset..clr_end) + .ok_or(out_of_bounds_error!())?, + )?; debug!( "PE header: CLR {}.{}, metadata at RVA 0x{:X}", cor20_header.major_runtime_version, @@ -231,7 +234,9 @@ impl<'a> CilAssemblyViewData<'a> { return Err(out_of_bounds_error!()); } - let metadata_slice = &data[metadata_offset..metadata_end]; + let metadata_slice = data + .get(metadata_offset..metadata_end) + .ok_or(out_of_bounds_error!())?; let metadata_root = Root::read(metadata_slice)?; let stream_names: Vec<&str> = metadata_root .stream_headers @@ -266,7 +271,9 @@ impl<'a> CilAssemblyViewData<'a> { return Err(out_of_bounds_error!()); } - let stream_data = &metadata_slice[stream_offset..stream_end]; + let stream_data = metadata_slice + .get(stream_offset..stream_end) + .ok_or(out_of_bounds_error!())?; match stream.name.as_str() { "#~" | "#-" => { @@ -773,7 +780,7 @@ impl CilAssemblyView { // before returning the owned `file: Arc` — leaving count = 1. let heads = self.into_heads(); Arc::try_unwrap(heads.file).map_err(|_| { - crate::Error::Other( + Error::Other( "CilAssemblyView::into_file: Arc has unexpected extra owners".to_string(), ) }) diff --git a/dotscope/src/metadata/customattributes/parser.rs b/dotscope/src/metadata/customattributes/parser.rs index d5e77f37..83962918 100644 --- a/dotscope/src/metadata/customattributes/parser.rs +++ b/dotscope/src/metadata/customattributes/parser.rs @@ -499,29 +499,30 @@ impl<'a> CustomAttributeParser<'a> { let fixed_args = self.parse_fixed_arguments(params)?; // Parse named arguments using explicit type tags - let named_args = - if self.parser.has_more_data() && self.parser.len() >= self.parser.pos() + 2 { - let num_named = self.parser.read_le::()?; - if num_named > MAX_NAMED_ARGS { - return Err(malformed_error!( - "Custom attribute has too many named arguments: {} (max: {})", - num_named, - MAX_NAMED_ARGS - )); - } + let named_args = if self.parser.has_more_data() + && self.parser.len() >= self.parser.pos().saturating_add(2) + { + let num_named = self.parser.read_le::()?; + if num_named > MAX_NAMED_ARGS { + return Err(malformed_error!( + "Custom attribute has too many named arguments: {} (max: {})", + num_named, + MAX_NAMED_ARGS + )); + } - let mut args = Vec::with_capacity(num_named as usize); - for _ in 0..num_named { - if let Some(arg) = self.parse_named_argument()? { - args.push(arg); - } else { - break; - } + let mut args = Vec::with_capacity(num_named as usize); + for _ in 0..num_named { + if let Some(arg) = self.parse_named_argument()? { + args.push(arg); + } else { + break; } - args - } else { - vec![] - }; + } + args + } else { + vec![] + }; Ok(CustomAttributeValue { fixed_args, @@ -1083,7 +1084,7 @@ impl<'a> CustomAttributeParser<'a> { work_stack.push(WorkItem::ParseTag(type_tag)); while let Some(work) = work_stack.pop() { - if work_stack.len() + result_stack.len() > MAX_NESTING_DEPTH { + if work_stack.len().saturating_add(result_stack.len()) > MAX_NESTING_DEPTH { return Err(DepthLimitExceeded(MAX_NESTING_DEPTH)); } @@ -1218,7 +1219,7 @@ impl<'a> CustomAttributeParser<'a> { } // Elements are on stack in correct order (last parsed = last in array) - let start_idx = result_stack.len() - count_usize; + let start_idx = result_stack.len().saturating_sub(count_usize); let elements = result_stack.drain(start_idx..).collect(); result_stack.push(CustomAttributeArgument::Array(elements)); } @@ -1278,7 +1279,7 @@ impl<'a> CustomAttributeParser<'a> { let length = p.read_compressed_uint()?; // Check if we have enough data for the string - let remaining_data = p.len() - p.pos(); + let remaining_data = p.len().saturating_sub(p.pos()); if length as usize > remaining_data { return Err(malformed_error!("not enough data")); } @@ -1297,7 +1298,11 @@ impl<'a> CustomAttributeParser<'a> { } // Check if the bytes are valid UTF-8 - let string_bytes = &p.data()[p.pos()..p.pos() + length as usize]; + let start = p.pos(); + let end = start + .checked_add(length as usize) + .ok_or_else(|| malformed_error!("string length overflow"))?; + let string_bytes = p.data().get(start..end).ok_or(out_of_bounds_error!())?; let s = std::str::from_utf8(string_bytes) .map_err(|_| malformed_error!("invalid UTF-8"))?; let result = s.to_string(); @@ -1351,7 +1356,7 @@ impl<'a> CustomAttributeParser<'a> { // Not a null string, parse as normal compressed uint + data let length = self.parser.read_compressed_uint()?; - let available_data = self.parser.len() - self.parser.pos(); + let available_data = self.parser.len().saturating_sub(self.parser.pos()); if length == 0 { Ok(String::new()) @@ -1363,7 +1368,7 @@ impl<'a> CustomAttributeParser<'a> { String::from_utf8(bytes).map_err(|e| { malformed_error!( "Invalid UTF-8 in custom attribute string at position {}: {}", - self.parser.pos() - length as usize, + self.parser.pos().saturating_sub(length as usize), e.utf8_error() ) }) @@ -1372,7 +1377,7 @@ impl<'a> CustomAttributeParser<'a> { "String length {} exceeds available data {} (blob context: pos={}, len={}, first_byte=0x{:02X})", length, available_data, - self.parser.pos() - 1, // subtract 1 because we already read the length + self.parser.pos().saturating_sub(1), // subtract 1 because we already read the length self.parser.len(), first_byte )) diff --git a/dotscope/src/metadata/customdebuginformation/parser.rs b/dotscope/src/metadata/customdebuginformation/parser.rs index ea70d537..1592abe0 100644 --- a/dotscope/src/metadata/customdebuginformation/parser.rs +++ b/dotscope/src/metadata/customdebuginformation/parser.rs @@ -378,7 +378,11 @@ impl<'a> CustomDebugParser<'a> { return Vec::new(); } - self.parser.data()[pos..len].to_vec() + self.parser + .data() + .get(pos..len) + .map(<[u8]>::to_vec) + .unwrap_or_default() } } diff --git a/dotscope/src/metadata/dependencies/graph.rs b/dotscope/src/metadata/dependencies/graph.rs index 9a64443f..cf1fd2b9 100644 --- a/dotscope/src/metadata/dependencies/graph.rs +++ b/dotscope/src/metadata/dependencies/graph.rs @@ -399,9 +399,13 @@ impl AssemblyDependencyGraph { // Add reversed edges between SCCs (for loading order: dependencies first) for node_idx in 0..graph.node_count() { - let source_scc = node_to_scc[node_idx]; + let Some(&source_scc) = node_to_scc.get(node_idx) else { + continue; + }; for successor in graph.successors(NodeId::new(node_idx)) { - let target_scc = node_to_scc[successor.index()]; + let Some(&target_scc) = node_to_scc.get(successor.index()) else { + continue; + }; if source_scc != target_scc { // target_scc should come before source_scc in loading order // So we add edge: target_scc -> source_scc @@ -428,7 +432,10 @@ impl AssemblyDependencyGraph { let mut result = Vec::with_capacity(graph.node_count()); for scc_node_id in scc_order { let scc_idx = scc_node_id.index(); - for &node_id in &sccs[scc_idx] { + let Some(scc) = sccs.get(scc_idx) else { + continue; + }; + for &node_id in scc { if let Some(identity) = indexed_graph.get_key(node_id) { result.push(identity.clone()); } @@ -593,7 +600,7 @@ impl AssemblyDependencyGraph { 1 } else { // Different assemblies: count each new one - usize::from(source_is_new) + usize::from(target_is_new) + usize::from(source_is_new).saturating_add(usize::from(target_is_new)) }; if new_assemblies > 0 { diff --git a/dotscope/src/metadata/exports/builder.rs b/dotscope/src/metadata/exports/builder.rs index b5e7a7a6..0819a815 100644 --- a/dotscope/src/metadata/exports/builder.rs +++ b/dotscope/src/metadata/exports/builder.rs @@ -103,7 +103,7 @@ impl NativeExportsBuilder { pub fn add_function(mut self, name: impl Into, ordinal: u16, address: u32) -> Self { self.functions.push((name.into(), ordinal, address)); if ordinal >= self.next_ordinal { - self.next_ordinal = ordinal + 1; + self.next_ordinal = ordinal.saturating_add(1); } self } @@ -134,7 +134,7 @@ impl NativeExportsBuilder { pub fn add_function_auto(mut self, name: impl Into, address: u32) -> Self { let ordinal = self.next_ordinal; self.functions.push((name.into(), ordinal, address)); - self.next_ordinal += 1; + self.next_ordinal = self.next_ordinal.saturating_add(1); self } @@ -165,7 +165,7 @@ impl NativeExportsBuilder { pub fn add_function_by_ordinal(mut self, ordinal: u16, address: u32) -> Self { self.ordinal_functions.push((ordinal, address)); if ordinal >= self.next_ordinal { - self.next_ordinal = ordinal + 1; + self.next_ordinal = ordinal.saturating_add(1); } self } @@ -195,7 +195,7 @@ impl NativeExportsBuilder { pub fn add_function_by_ordinal_auto(mut self, address: u32) -> Self { let ordinal = self.next_ordinal; self.ordinal_functions.push((ordinal, address)); - self.next_ordinal += 1; + self.next_ordinal = self.next_ordinal.saturating_add(1); self } @@ -232,7 +232,7 @@ impl NativeExportsBuilder { ) -> Self { self.forwarders.push((name.into(), ordinal, target.into())); if ordinal >= self.next_ordinal { - self.next_ordinal = ordinal + 1; + self.next_ordinal = ordinal.saturating_add(1); } self } @@ -267,7 +267,7 @@ impl NativeExportsBuilder { ) -> Self { let ordinal = self.next_ordinal; self.forwarders.push((name.into(), ordinal, target.into())); - self.next_ordinal += 1; + self.next_ordinal = self.next_ordinal.saturating_add(1); self } diff --git a/dotscope/src/metadata/exports/container.rs b/dotscope/src/metadata/exports/container.rs index f09cbb0b..b2336f3f 100644 --- a/dotscope/src/metadata/exports/container.rs +++ b/dotscope/src/metadata/exports/container.rs @@ -416,7 +416,10 @@ impl UnifiedExportContainer { /// println!("Total exports: {}", container.total_count()); /// ``` pub fn total_count(&self) -> usize { - self.cil.len() + self.native.function_count() + self.native.forwarder_count() + self.cil + .len() + .saturating_add(self.native.function_count()) + .saturating_add(self.native.forwarder_count()) } /// Add a native function export. diff --git a/dotscope/src/metadata/exports/native.rs b/dotscope/src/metadata/exports/native.rs index 8ada8b13..e28fbd32 100644 --- a/dotscope/src/metadata/exports/native.rs +++ b/dotscope/src/metadata/exports/native.rs @@ -419,7 +419,9 @@ impl NativeExports { // Update next ordinal if ordinal >= self.next_ordinal { - self.next_ordinal = ordinal + 1; + self.next_ordinal = ordinal + .checked_add(1) + .ok_or_else(|| malformed_error!("Ordinal counter overflow"))?; } Ok(()) @@ -479,7 +481,9 @@ impl NativeExports { // Update next ordinal if ordinal >= self.next_ordinal { - self.next_ordinal = ordinal + 1; + self.next_ordinal = ordinal + .checked_add(1) + .ok_or_else(|| malformed_error!("Ordinal counter overflow"))?; } Ok(()) @@ -568,11 +572,18 @@ impl NativeExports { self.name_to_ordinal.insert(name.to_owned(), ordinal); } - self.directory.function_count = to_u32(self.functions.len() + self.forwarders.len())?; + self.directory.function_count = to_u32( + self.functions + .len() + .checked_add(self.forwarders.len()) + .ok_or_else(|| malformed_error!("Function/forwarder count overflow"))?, + )?; self.directory.name_count = to_u32(self.name_to_ordinal.len())?; if ordinal >= self.next_ordinal { - self.next_ordinal = ordinal + 1; + self.next_ordinal = ordinal + .checked_add(1) + .ok_or_else(|| malformed_error!("Ordinal counter overflow"))?; } Ok(()) @@ -610,7 +621,7 @@ impl NativeExports { /// ``` #[must_use] pub fn function_count(&self) -> usize { - self.functions.len() + self.forwarders.len() + self.functions.len().saturating_add(self.forwarders.len()) } /// Get the number of forwarder exports. @@ -918,34 +929,76 @@ impl NativeExports { // EAT must cover from base_ordinal to highest ordinal let eat_entry_count = if max_ordinal >= self.directory.base_ordinal { - u32::from(max_ordinal - self.directory.base_ordinal + 1) + let span = max_ordinal + .checked_sub(self.directory.base_ordinal) + .and_then(|v| v.checked_add(1)) + .ok_or_else(|| malformed_error!("EAT entry count overflow"))?; + u32::from(span) } else { 0 }; - let eat_size = eat_entry_count * 4; // 4 bytes per address - let name_table_size = self.directory.name_count * 4; // 4 bytes per name RVA - let ordinal_table_size = self.directory.name_count * 2; // 2 bytes per ordinal - - let eat_rva = base_rva + export_dir_size; - let name_table_rva = eat_rva + eat_size; - let ordinal_table_rva = name_table_rva + name_table_size; - let strings_rva = ordinal_table_rva + ordinal_table_size; + let eat_size = eat_entry_count + .checked_mul(4) + .ok_or_else(|| malformed_error!("EAT size overflow"))?; + let name_table_size = self + .directory + .name_count + .checked_mul(4) + .ok_or_else(|| malformed_error!("Name table size overflow"))?; + let ordinal_table_size = self + .directory + .name_count + .checked_mul(2) + .ok_or_else(|| malformed_error!("Ordinal table size overflow"))?; + + let eat_rva = base_rva + .checked_add(export_dir_size) + .ok_or_else(|| malformed_error!("EAT RVA overflow"))?; + let name_table_rva = eat_rva + .checked_add(eat_size) + .ok_or_else(|| malformed_error!("Name table RVA overflow"))?; + let ordinal_table_rva = name_table_rva + .checked_add(name_table_size) + .ok_or_else(|| malformed_error!("Ordinal table RVA overflow"))?; + let strings_rva = ordinal_table_rva + .checked_add(ordinal_table_size) + .ok_or_else(|| malformed_error!("Strings RVA overflow"))?; // Calculate total size needed for strings - let mut total_strings_size = self.directory.dll_name.len() + 1; // DLL name + null + let mut total_strings_size: usize = self + .directory + .dll_name + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("DLL name size overflow"))?; for name in self.name_to_ordinal.keys() { - total_strings_size += name.len() + 1; // name + null + let name_size = name + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("Name size overflow"))?; + total_strings_size = total_strings_size + .checked_add(name_size) + .ok_or_else(|| malformed_error!("Strings size overflow"))?; } for forwarder in self.forwarders.values() { - total_strings_size += forwarder.target.len() + 1; // target + null + let target_size = forwarder + .target + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("Forwarder target size overflow"))?; + total_strings_size = total_strings_size + .checked_add(target_size) + .ok_or_else(|| malformed_error!("Strings size overflow"))?; } + let total_strings_size_u32 = to_u32(total_strings_size)?; let total_size = export_dir_size - + eat_size - + name_table_size - + ordinal_table_size - + to_u32(total_strings_size)?; + .checked_add(eat_size) + .and_then(|s| s.checked_add(name_table_size)) + .and_then(|s| s.checked_add(ordinal_table_size)) + .and_then(|s| s.checked_add(total_strings_size_u32)) + .ok_or_else(|| malformed_error!("Export table total size overflow"))?; let mut data = vec![0u8; total_size as usize]; let mut offset = 0; @@ -976,13 +1029,31 @@ impl NativeExports { // Calculate string offsets for forwarders let mut forwarder_string_offsets = HashMap::new(); - let mut current_forwarder_offset = self.directory.dll_name.len() + 1; // After DLL name + let mut current_forwarder_offset = self + .directory + .dll_name + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("DLL name size overflow"))?; // After DLL name for (name, _) in &named_exports { - current_forwarder_offset += name.len() + 1; // +1 for null terminator + let name_size = name + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("Name size overflow"))?; + current_forwarder_offset = current_forwarder_offset + .checked_add(name_size) + .ok_or_else(|| malformed_error!("Forwarder offset overflow"))?; } for forwarder in self.forwarders.values() { forwarder_string_offsets.insert(forwarder.ordinal, current_forwarder_offset); - current_forwarder_offset += forwarder.target.len() + 1; + let target_size = forwarder + .target + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("Forwarder target size overflow"))?; + current_forwarder_offset = current_forwarder_offset + .checked_add(target_size) + .ok_or_else(|| malformed_error!("Forwarder offset overflow"))?; } // Write Export Address Table (EAT) @@ -995,36 +1066,65 @@ impl NativeExports { // Go back and populate known entries let mut temp_offset = eat_start_offset; for ordinal_index in 0..eat_entry_count { - #[allow(clippy::cast_possible_truncation)] - let ordinal = self.directory.base_ordinal + (ordinal_index as u16); + let ordinal_index_u16 = u16::try_from(ordinal_index) + .map_err(|_| malformed_error!("Ordinal index exceeds u16 range"))?; + let ordinal = self + .directory + .base_ordinal + .checked_add(ordinal_index_u16) + .ok_or_else(|| malformed_error!("Ordinal computation overflow"))?; + + let temp_end = temp_offset + .checked_add(4) + .ok_or_else(|| malformed_error!("EAT entry offset overflow"))?; if let Some(function) = self.functions.get(&ordinal) { // Regular function - write address - data[temp_offset..temp_offset + 4].copy_from_slice(&function.address.to_le_bytes()); + data.get_mut(temp_offset..temp_end) + .ok_or(out_of_bounds_error!())? + .copy_from_slice(&function.address.to_le_bytes()); } else if let Some(_forwarder) = self.forwarders.get(&ordinal) { // Forwarder - write RVA to forwarder string if let Some(&string_offset) = forwarder_string_offsets.get(&ordinal) { - let forwarder_rva = strings_rva + to_u32(string_offset)?; - data[temp_offset..temp_offset + 4] + let forwarder_rva = strings_rva + .checked_add(to_u32(string_offset)?) + .ok_or_else(|| malformed_error!("Forwarder RVA overflow"))?; + data.get_mut(temp_offset..temp_end) + .ok_or(out_of_bounds_error!())? .copy_from_slice(&forwarder_rva.to_le_bytes()); } } // Otherwise leave as 0 (no function at this ordinal) - temp_offset += 4; + temp_offset = temp_end; } // Write Export Name Table - let mut name_string_offset = self.directory.dll_name.len() + 1; // After DLL name + let mut name_string_offset = self + .directory + .dll_name + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("DLL name size overflow"))?; // After DLL name for (name, _) in &named_exports { - let name_rva = strings_rva + to_u32(name_string_offset)?; + let name_rva = strings_rva + .checked_add(to_u32(name_string_offset)?) + .ok_or_else(|| malformed_error!("Name RVA overflow"))?; write_le_at(&mut data, &mut offset, name_rva)?; - name_string_offset += name.len() + 1; // +1 for null terminator + let name_size = name + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("Name size overflow"))?; + name_string_offset = name_string_offset + .checked_add(name_size) + .ok_or_else(|| malformed_error!("Name offset overflow"))?; } // Write Export Ordinal Table for (_, ordinal) in &named_exports { - let adjusted_ordinal = ordinal - self.directory.base_ordinal; + let adjusted_ordinal = ordinal + .checked_sub(self.directory.base_ordinal) + .ok_or_else(|| malformed_error!("Ordinal adjustment underflow"))?; write_le_at(&mut data, &mut offset, adjusted_ordinal)?; } diff --git a/dotscope/src/metadata/identity/assembly.rs b/dotscope/src/metadata/identity/assembly.rs index affadd20..03e47f04 100644 --- a/dotscope/src/metadata/identity/assembly.rs +++ b/dotscope/src/metadata/identity/assembly.rs @@ -569,11 +569,10 @@ impl AssemblyIdentity { let parts: Vec<&str> = display_name.split(',').map(str::trim).collect(); - if parts.is_empty() { - return Err(malformed_error!("Empty assembly display name")); - } - - let name = parts[0].to_string(); + let name = parts + .first() + .ok_or_else(|| malformed_error!("Empty assembly display name"))? + .to_string(); if name.is_empty() { return Err(malformed_error!("Assembly name cannot be empty")); } @@ -650,7 +649,7 @@ impl AssemblyIdentity { pub fn display_name(&self) -> String { // Pre-allocate with estimated capacity to minimize reallocations // Typical format: "Name, Version=x.x.x.x, Culture=neutral, PublicKeyToken=xxxxxxxxxxxxxxxx" - let mut result = String::with_capacity(self.name.len() + 80); + let mut result = String::with_capacity(self.name.len().saturating_add(80)); result.push_str(&self.name); @@ -1045,7 +1044,10 @@ impl AssemblyVersion { let mut components = [0u16; 4]; for (i, part) in parts.iter().enumerate() { - components[i] = part + let slot = components + .get_mut(i) + .ok_or_else(|| malformed_error!("Invalid version format: {}", version_str))?; + *slot = part .parse::() .map_err(|_| malformed_error!("Invalid version component: {}", part))?; } diff --git a/dotscope/src/metadata/identity/cryptographic.rs b/dotscope/src/metadata/identity/cryptographic.rs index 64aef68c..6e0e2434 100644 --- a/dotscope/src/metadata/identity/cryptographic.rs +++ b/dotscope/src/metadata/identity/cryptographic.rs @@ -347,7 +347,10 @@ impl Identity { } }; // Token is the last 8 bytes of the hash as little-endian u64 - read_le::(&hash[hash.len() - 8..]) + let start = hash.len().saturating_sub(8); + read_le::(hash.get(start..).ok_or_else(|| { + malformed_error!("Hash output is too short to extract public key token") + })?) } } diff --git a/dotscope/src/metadata/imports/cil.rs b/dotscope/src/metadata/imports/cil.rs index f0f77e96..cfe45ce0 100644 --- a/dotscope/src/metadata/imports/cil.rs +++ b/dotscope/src/metadata/imports/cil.rs @@ -905,7 +905,7 @@ impl Imports { /// Get an iterator over all imports in the container. /// /// Returns an iterator that yields [`crossbeam_skiplist::map::Entry`] instances, - /// each containing a ([`crate::metadata::token::Token`], [`crate::metadata::imports::ImportRc`]) pair. + /// each containing a ([`Token`], [`crate::metadata::imports::ImportRc`]) pair. /// The iteration order is sorted by token value due to the skip map's ordering properties. /// /// # Iterator Properties @@ -993,8 +993,8 @@ impl Imports { /// This method is thread-safe and can be called concurrently from multiple threads. pub fn by_name(&self, name: &str) -> Option { if let Some(tokens) = self.by_name.get(name) { - if !tokens.is_empty() { - if let Some(token) = self.data.get(&tokens[0]) { + if let Some(first) = tokens.first() { + if let Some(token) = self.data.get(first) { return Some(token.value().clone()); } } @@ -1076,8 +1076,8 @@ impl Imports { /// ``` pub fn by_fullname(&self, name: &str) -> Option { if let Some(tokens) = self.by_fullname.get(name) { - if !tokens.is_empty() { - if let Some(token) = self.data.get(&tokens[0]) { + if let Some(first) = tokens.first() { + if let Some(token) = self.data.get(first) { return Some(token.value().clone()); } } diff --git a/dotscope/src/metadata/imports/container.rs b/dotscope/src/metadata/imports/container.rs index dee71458..aae8a08e 100644 --- a/dotscope/src/metadata/imports/container.rs +++ b/dotscope/src/metadata/imports/container.rs @@ -353,7 +353,9 @@ impl UnifiedImportContainer { /// println!("Total imports: {}", container.total_count()); /// ``` pub fn total_count(&self) -> usize { - self.cil.len() + self.native.total_function_count() + self.cil + .len() + .saturating_add(self.native.total_function_count()) } /// Add a native function import. diff --git a/dotscope/src/metadata/imports/native.rs b/dotscope/src/metadata/imports/native.rs index bea024b0..392b8fd8 100644 --- a/dotscope/src/metadata/imports/native.rs +++ b/dotscope/src/metadata/imports/native.rs @@ -106,7 +106,7 @@ use std::collections::HashMap; use crate::{ file::pe::Import, - utils::{write_le_at, write_string_at}, + utils::{to_u32, write_le_at, write_string_at}, Result, }; @@ -473,7 +473,7 @@ impl NativeImports { )); } - let iat_rva = self.allocate_iat_rva(); + let iat_rva = self.allocate_iat_rva()?; let function = Import { dll: dll_name.to_owned(), @@ -556,7 +556,7 @@ impl NativeImports { )); } - let iat_rva = self.allocate_iat_rva(); + let iat_rva = self.allocate_iat_rva()?; let descriptor = self .descriptors .get_mut(dll_name) @@ -772,57 +772,20 @@ impl NativeImports { let mut updated_entries = HashMap::new(); for (old_rva, mut entry) in self.iat_entries.drain() { - #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] - let new_rva = if rva_delta >= 0 { - old_rva.checked_add(rva_delta as u32) - } else { - old_rva.checked_sub((-rva_delta) as u32) - }; - - match new_rva { - Some(rva) => { - entry.rva = rva; - updated_entries.insert(rva, entry); - } - None => { - return Err(malformed_error!("RVA delta would cause overflow")); - } - } + let new_rva = adjust_rva(old_rva, rva_delta)?; + entry.rva = new_rva; + updated_entries.insert(new_rva, entry); } self.iat_entries = updated_entries; for descriptor in self.descriptors.values_mut() { for function in &mut descriptor.functions { - #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] - let new_rva = if rva_delta >= 0 { - function.rva.checked_add(rva_delta as u32) - } else { - function.rva.checked_sub((-rva_delta) as u32) - }; - - match new_rva { - Some(rva) => function.rva = rva, - None => { - return Err(malformed_error!("RVA delta would cause overflow")); - } - } + function.rva = adjust_rva(function.rva, rva_delta)?; } } - #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)] - let new_next_rva = if rva_delta >= 0 { - self.next_iat_rva.checked_add(rva_delta as u32) - } else { - self.next_iat_rva.checked_sub((-rva_delta) as u32) - }; - - match new_next_rva { - Some(rva) => self.next_iat_rva = rva, - None => { - return Err(malformed_error!("RVA delta would cause overflow")); - } - } + self.next_iat_rva = adjust_rva(self.next_iat_rva, rva_delta)?; Ok(()) } @@ -832,10 +795,16 @@ impl NativeImports { /// Returns the next available RVA for IAT allocation and increments /// the internal counter by the appropriate entry size (4 bytes for PE32, /// 8 bytes for PE32+). Used internally when adding new function imports. - fn allocate_iat_rva(&mut self) -> u32 { + /// + /// # Errors + /// Returns an error if the IAT RVA counter would overflow. + fn allocate_iat_rva(&mut self) -> Result { let rva = self.next_iat_rva; - self.next_iat_rva += self.iat_entry_size(); - rva + self.next_iat_rva = self + .next_iat_rva + .checked_add(self.iat_entry_size()) + .ok_or_else(|| malformed_error!("IAT RVA counter overflow"))?; + Ok(rva) } /// Calculate the total size of the Import Address Table (IAT) in bytes. @@ -848,17 +817,28 @@ impl NativeImports { /// /// # Returns /// Total IAT size in bytes. - #[must_use] - pub fn iat_byte_size(&self, is_pe32_plus: bool) -> usize { - let entry_size = if is_pe32_plus { 8 } else { 4 }; - let mut total_entries = 0; + /// + /// # Errors + /// Returns an error if the computed size would overflow `usize`. + pub fn iat_byte_size(&self, is_pe32_plus: bool) -> Result { + let entry_size: usize = if is_pe32_plus { 8 } else { 4 }; + let mut total_entries: usize = 0; for descriptor in self.descriptors.values() { // Each DLL needs: function entries + 1 null terminator - total_entries += descriptor.functions.len() + 1; + let dll_entries = descriptor + .functions + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("IAT entry count overflow"))?; + total_entries = total_entries + .checked_add(dll_entries) + .ok_or_else(|| malformed_error!("IAT entry count overflow"))?; } - total_entries * entry_size + total_entries + .checked_mul(entry_size) + .ok_or_else(|| malformed_error!("IAT byte size overflow")) } /// Build the IAT (Import Address Table) bytes for .NET PE generation. @@ -882,8 +862,8 @@ impl NativeImports { return Ok(Vec::new()); } - let entry_size = if is_pe32_plus { 8 } else { 4 }; - let mut iat_bytes = Vec::with_capacity(self.iat_byte_size(is_pe32_plus)); + let entry_size: usize = if is_pe32_plus { 8 } else { 4 }; + let mut iat_bytes = Vec::with_capacity(self.iat_byte_size(is_pe32_plus)?); // Sort descriptors for deterministic ordering (mscoree.dll should be first when building import list) let mut descriptors_sorted: Vec<_> = self.descriptors.values().collect(); @@ -891,28 +871,52 @@ impl NativeImports { // Calculate where strings will be in the import table // Layout: descriptors + null descriptor + ILT entries + strings - let descriptor_size = (self.descriptors.len() + 1) * 20; // +1 for null terminator - - let mut total_ilt_entries = 0; + let descriptor_count_with_null = self + .descriptors + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("Descriptor count overflow"))?; + let descriptor_size = descriptor_count_with_null + .checked_mul(20) + .ok_or_else(|| malformed_error!("Descriptor table size overflow"))?; + + let mut total_ilt_entries: usize = 0; for desc in &descriptors_sorted { - total_ilt_entries += desc.functions.len() + 1; // +1 for null terminator per DLL + let dll_entries = desc + .functions + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("ILT entry count overflow"))?; + total_ilt_entries = total_ilt_entries + .checked_add(dll_entries) + .ok_or_else(|| malformed_error!("ILT entry count overflow"))?; } - let ilt_size = total_ilt_entries * entry_size; + let ilt_size = total_ilt_entries + .checked_mul(entry_size) + .ok_or_else(|| malformed_error!("ILT byte size overflow"))?; // Strings start after descriptors and ILT - // Safe: PE import table sizes always fit in u32 - #[allow(clippy::cast_possible_truncation)] - let strings_start_rva = import_table_rva + (descriptor_size + ilt_size) as u32; + let header_size = descriptor_size + .checked_add(ilt_size) + .ok_or_else(|| malformed_error!("Import table header size overflow"))?; + let strings_start_rva = import_table_rva + .checked_add(to_u32(header_size)?) + .ok_or_else(|| malformed_error!("Strings start RVA overflow"))?; // Calculate hint/name RVAs for each function let mut current_string_rva = strings_start_rva; // First pass: calculate DLL name RVAs (they come first in strings) let mut dll_name_end_rva = current_string_rva; - #[allow(clippy::cast_possible_truncation)] for desc in &descriptors_sorted { - // Safe: PE import table sizes always fit in u32 - dll_name_end_rva += (desc.dll_name.len() + 1) as u32; // +1 for null terminator + let dll_name_size = desc + .dll_name + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("DLL name size overflow"))?; + dll_name_end_rva = dll_name_end_rva + .checked_add(to_u32(dll_name_size)?) + .ok_or_else(|| malformed_error!("DLL name RVA overflow"))?; } // Function hint/names come after DLL names @@ -947,12 +951,18 @@ impl NativeImports { } // Advance string RVA for named imports - #[allow(clippy::cast_possible_truncation)] if let Some(function_name) = function.name.as_ref() { - current_string_rva += 2; // hint (2 bytes) - // Safe: PE import table sizes always fit in u32 - current_string_rva += (function_name.len() + 1) as u32; - // name + null + // hint (2 bytes) + name + null + let name_size = function_name + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("Function name size overflow"))?; + let advance = name_size + .checked_add(2) + .ok_or_else(|| malformed_error!("Hint/name advance overflow"))?; + current_string_rva = current_string_rva + .checked_add(to_u32(advance)?) + .ok_or_else(|| malformed_error!("String RVA overflow"))?; } } @@ -994,74 +1004,119 @@ impl NativeImports { return Ok(Vec::new()); } - let entry_size = if is_pe32_plus { 8 } else { 4 }; + let entry_size: usize = if is_pe32_plus { 8 } else { 4 }; // Sort descriptors for deterministic ordering let mut descriptors_sorted: Vec<_> = self.descriptors.values().collect(); descriptors_sorted.sort_by_key(|d| d.dll_name.to_lowercase()); // Calculate layout sizes - let descriptor_table_size = (descriptors_sorted.len() + 1) * 20; // +1 for null terminator + let descriptor_count_with_null = descriptors_sorted + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("Descriptor count overflow"))?; + let descriptor_table_size = descriptor_count_with_null + .checked_mul(20) + .ok_or_else(|| malformed_error!("Descriptor table size overflow"))?; // Calculate total ILT size - let mut total_ilt_entries = 0; + let mut total_ilt_entries: usize = 0; for desc in &descriptors_sorted { - total_ilt_entries += desc.functions.len() + 1; // +1 for null terminator per DLL + let dll_entries = desc + .functions + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("ILT entry count overflow"))?; + total_ilt_entries = total_ilt_entries + .checked_add(dll_entries) + .ok_or_else(|| malformed_error!("ILT entry count overflow"))?; } - let ilt_size = total_ilt_entries * entry_size; + let ilt_size = total_ilt_entries + .checked_mul(entry_size) + .ok_or_else(|| malformed_error!("ILT byte size overflow"))?; // Calculate total string size - let mut total_string_size = 0; + let mut total_string_size: usize = 0; for desc in &descriptors_sorted { - total_string_size += desc.dll_name.len() + 1; // DLL name + null + // DLL name + null + let dll_size = desc + .dll_name + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("DLL name size overflow"))?; + total_string_size = total_string_size + .checked_add(dll_size) + .ok_or_else(|| malformed_error!("String table size overflow"))?; for func in &desc.functions { if let Some(ref name) = func.name { - total_string_size += 2 + name.len() + 1; // hint + name + null + // hint + name + null + let name_size = name + .len() + .checked_add(3) + .ok_or_else(|| malformed_error!("Function name size overflow"))?; + total_string_size = total_string_size + .checked_add(name_size) + .ok_or_else(|| malformed_error!("String table size overflow"))?; } } } - // Allocate buffer - let total_size = descriptor_table_size + ilt_size + total_string_size + 16; // +16 for alignment padding + // Allocate buffer (+16 for alignment padding) + let total_size = descriptor_table_size + .checked_add(ilt_size) + .and_then(|s| s.checked_add(total_string_size)) + .and_then(|s| s.checked_add(16)) + .ok_or_else(|| malformed_error!("Import table total size overflow"))?; let mut data = vec![0u8; total_size]; let mut offset = 0; // Calculate RVAs - // Safe: PE import table sizes always fit in u32 - #[allow(clippy::cast_possible_truncation)] - let ilt_start_rva = table_rva + descriptor_table_size as u32; - // Safe: PE import table sizes always fit in u32 - #[allow(clippy::cast_possible_truncation)] - let strings_start_rva = ilt_start_rva + ilt_size as u32; + let ilt_start_rva = table_rva + .checked_add(to_u32(descriptor_table_size)?) + .ok_or_else(|| malformed_error!("ILT start RVA overflow"))?; + let strings_start_rva = ilt_start_rva + .checked_add(to_u32(ilt_size)?) + .ok_or_else(|| malformed_error!("Strings start RVA overflow"))?; // Build ILT offset map and string RVAs let mut ilt_rva = ilt_start_rva; - let mut iat_offset = 0u32; // Offset within IAT for each DLL + let mut iat_offset: u32 = 0; // Offset within IAT for each DLL // Pre-calculate DLL name RVAs let mut dll_name_rvas = Vec::with_capacity(descriptors_sorted.len()); let mut current_dll_name_rva = strings_start_rva; - #[allow(clippy::cast_possible_truncation)] for desc in &descriptors_sorted { dll_name_rvas.push(current_dll_name_rva); - // Safe: PE import table sizes always fit in u32 - current_dll_name_rva += (desc.dll_name.len() + 1) as u32; + let dll_size = desc + .dll_name + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("DLL name size overflow"))?; + current_dll_name_rva = current_dll_name_rva + .checked_add(to_u32(dll_size)?) + .ok_or_else(|| malformed_error!("DLL name RVA overflow"))?; } // Pre-calculate function name RVAs let mut current_func_name_rva = current_dll_name_rva; let mut func_name_rvas: Vec> = Vec::with_capacity(descriptors_sorted.len()); - #[allow(clippy::cast_possible_truncation)] for desc in &descriptors_sorted { let mut rvas = Vec::with_capacity(desc.functions.len()); for func in &desc.functions { if let Some(function_name) = func.name.as_ref() { rvas.push(u64::from(current_func_name_rva)); - current_func_name_rva += 2; // hint - // Safe: PE import table sizes always fit in u32 - current_func_name_rva += (function_name.len() + 1) as u32; - // name + null + // hint (2 bytes) + name + null + let name_size = function_name + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("Function name size overflow"))?; + let advance = name_size + .checked_add(2) + .ok_or_else(|| malformed_error!("Hint/name advance overflow"))?; + current_func_name_rva = current_func_name_rva + .checked_add(to_u32(advance)?) + .ok_or_else(|| malformed_error!("Function name RVA overflow"))?; } else { rvas.push(0); // Will use ordinal } @@ -1070,10 +1125,13 @@ impl NativeImports { } // Write import descriptors - #[allow(clippy::cast_possible_truncation)] for (i, desc) in descriptors_sorted.iter().enumerate() { let desc_ilt_rva = ilt_rva; - let desc_iat_rva = iat_rva + iat_offset; + let desc_iat_rva = iat_rva + .checked_add(iat_offset) + .ok_or_else(|| malformed_error!("IAT RVA overflow"))?; + + let dll_name_rva = *dll_name_rvas.get(i).ok_or(out_of_bounds_error!())?; // OriginalFirstThunk (ILT RVA) write_le_at::(&mut data, &mut offset, desc_ilt_rva)?; @@ -1082,15 +1140,26 @@ impl NativeImports { // ForwarderChain write_le_at::(&mut data, &mut offset, 0)?; // Name (DLL name RVA) - write_le_at::(&mut data, &mut offset, dll_name_rvas[i])?; + write_le_at::(&mut data, &mut offset, dll_name_rva)?; // FirstThunk (IAT RVA - points to external IAT) write_le_at::(&mut data, &mut offset, desc_iat_rva)?; - // Update offsets for next descriptor - // Safe: PE import table sizes always fit in u32 - let entries_for_dll = desc.functions.len() + 1; // +1 for null terminator - ilt_rva += (entries_for_dll * entry_size) as u32; - iat_offset += (entries_for_dll * entry_size) as u32; + // Update offsets for next descriptor (+1 for null terminator) + let entries_for_dll = desc + .functions + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("ILT entry count overflow"))?; + let dll_size = entries_for_dll + .checked_mul(entry_size) + .ok_or_else(|| malformed_error!("ILT DLL size overflow"))?; + let dll_size_u32 = to_u32(dll_size)?; + ilt_rva = ilt_rva + .checked_add(dll_size_u32) + .ok_or_else(|| malformed_error!("ILT RVA overflow"))?; + iat_offset = iat_offset + .checked_add(dll_size_u32) + .ok_or_else(|| malformed_error!("IAT offset overflow"))?; } // Write null terminator descriptor @@ -1100,6 +1169,7 @@ impl NativeImports { // Write ILT entries for (i, desc) in descriptors_sorted.iter().enumerate() { + let dll_func_rvas = func_name_rvas.get(i).ok_or(out_of_bounds_error!())?; for (j, func) in desc.functions.iter().enumerate() { let ilt_value = if func.name.is_none() { // Ordinal import @@ -1114,7 +1184,7 @@ impl NativeImports { } } else { // Named import - use pre-calculated RVA - func_name_rvas[i][j] + *dll_func_rvas.get(j).ok_or(out_of_bounds_error!())? }; if is_pe32_plus { @@ -1151,10 +1221,12 @@ impl NativeImports { // Align to 4 bytes while offset % 4 != 0 { - if offset < data.len() { - data[offset] = 0; + if let Some(slot) = data.get_mut(offset) { + *slot = 0; } - offset += 1; + offset = offset + .checked_add(1) + .ok_or_else(|| malformed_error!("Alignment offset overflow"))?; } // Truncate to actual size @@ -1170,6 +1242,25 @@ impl Default for NativeImports { } } +/// Apply a signed delta to a u32 RVA, returning an error on overflow. +fn adjust_rva(rva: u32, delta: i64) -> Result { + if delta >= 0 { + let abs_delta = + u32::try_from(delta).map_err(|_| malformed_error!("RVA delta exceeds u32 range"))?; + rva.checked_add(abs_delta) + .ok_or_else(|| malformed_error!("RVA delta would cause overflow")) + } else { + // Negate without overflow even when delta == i64::MIN + let abs_delta_i64 = delta + .checked_neg() + .ok_or_else(|| malformed_error!("RVA delta magnitude overflow"))?; + let abs_delta = u32::try_from(abs_delta_i64) + .map_err(|_| malformed_error!("RVA delta exceeds u32 range"))?; + rva.checked_sub(abs_delta) + .ok_or_else(|| malformed_error!("RVA delta would cause overflow")) + } +} + #[cfg(test)] mod tests { use super::*; diff --git a/dotscope/src/metadata/loader/graph.rs b/dotscope/src/metadata/loader/graph.rs index 1fcffcd9..0502882a 100644 --- a/dotscope/src/metadata/loader/graph.rs +++ b/dotscope/src/metadata/loader/graph.rs @@ -107,7 +107,7 @@ impl<'a> LoaderGraph<'a> { let key = LoaderKey::Special { sequence: self.special_counter, }; - self.special_counter += 1; + self.special_counter = self.special_counter.saturating_add(1); key }; @@ -377,12 +377,15 @@ impl<'a> LoaderGraph<'a> { let _ = writeln!(result, "Level {level_idx}: ["); for loader in level { // Find the LoaderKey for this loader - let loader_key = self + let Some(loader_key) = self .loaders .iter() .find(|(_, &l)| std::ptr::eq(*loader, l)) .map(|(key, _)| key) - .expect("Loader not found in graph"); + else { + let _ = writeln!(result, " (depends on: ?)"); + continue; + }; let deps = self.dependencies.get(loader_key).map_or_else( || "None".to_string(), diff --git a/dotscope/src/metadata/loader/mod.rs b/dotscope/src/metadata/loader/mod.rs index 8b0f45e6..139dee79 100644 --- a/dotscope/src/metadata/loader/mod.rs +++ b/dotscope/src/metadata/loader/mod.rs @@ -173,11 +173,21 @@ use std::{sync::LazyLock, time::Instant}; /// /// These conditions indicate programming errors that should be caught during development. static EXECUTION_LEVELS: LazyLock>> = LazyLock::new(|| { - let graph = build_dependency_graph(&LOADERS) - .expect("Static loader dependency graph must be valid - check for missing loaders or circular dependencies"); - graph.topological_levels().expect( - "Static loader dependency graph must be acyclic - check loader dependencies for cycles", - ) + let graph = match build_dependency_graph(&LOADERS) { + Ok(g) => g, + Err(e) => { + log::error!( + "Static loader dependency graph is invalid - check for missing loaders or circular dependencies: {e}" + ); + return Vec::new(); + } + }; + graph.topological_levels().unwrap_or_else(|e| { + log::error!( + "Static loader dependency graph must be acyclic - check loader dependencies for cycles: {e}" + ); + Vec::new() + }) }); /// Trait for metadata table loaders. diff --git a/dotscope/src/metadata/marshalling/encoder.rs b/dotscope/src/metadata/marshalling/encoder.rs index d88dbacc..b30d1699 100644 --- a/dotscope/src/metadata/marshalling/encoder.rs +++ b/dotscope/src/metadata/marshalling/encoder.rs @@ -183,7 +183,7 @@ impl MarshallingEncoder { /// - The recursion depth exceeds [`MAX_RECURSION_DEPTH`] /// - A nested type fails to encode pub fn encode_native_type(&mut self, native_type: &NativeType) -> Result<()> { - self.depth += 1; + self.depth = self.depth.saturating_add(1); if self.depth >= MAX_RECURSION_DEPTH { return Err(RecursionLimit(MAX_RECURSION_DEPTH)); } @@ -324,7 +324,7 @@ impl MarshallingEncoder { } } - self.depth -= 1; + self.depth = self.depth.saturating_sub(1); Ok(()) } diff --git a/dotscope/src/metadata/marshalling/parser.rs b/dotscope/src/metadata/marshalling/parser.rs index 3985fa24..33fbac79 100644 --- a/dotscope/src/metadata/marshalling/parser.rs +++ b/dotscope/src/metadata/marshalling/parser.rs @@ -217,14 +217,14 @@ impl<'a> MarshallingParser<'a> { /// # Errors /// Returns an error if the native type cannot be parsed or recursion limit is exceeded pub fn parse_native_type(&mut self) -> Result { - self.depth += 1; + self.depth = self.depth.saturating_add(1); if self.depth >= MAX_RECURSION_DEPTH { - self.depth -= 1; + self.depth = self.depth.saturating_sub(1); return Err(RecursionLimit(MAX_RECURSION_DEPTH)); } let result = self.parse_native_type_inner(); - self.depth -= 1; + self.depth = self.depth.saturating_sub(1); result } diff --git a/dotscope/src/metadata/method/body.rs b/dotscope/src/metadata/method/body.rs index 097eaa6b..aefbacca 100644 --- a/dotscope/src/metadata/method/body.rs +++ b/dotscope/src/metadata/method/body.rs @@ -416,7 +416,9 @@ impl MethodBody { return Ok(false); } } else { - *filtered_count += 1; + *filtered_count = filtered_count + .checked_add(1) + .ok_or_else(|| malformed_error!("filtered_count overflow"))?; } } EhValidationMode::Raw => { @@ -440,7 +442,10 @@ impl MethodBody { match MethodBodyFlags::new(u16::from(first_byte & 0b_00000011_u8)) { MethodBodyFlags::TINY_FORMAT => { let size_code = (first_byte >> 2) as usize; - if size_code + 1 > data.len() { + let needed = size_code + .checked_add(1) + .ok_or_else(|| malformed_error!("tiny method body size overflow"))?; + if needed > data.len() { return Err(out_of_bounds_error!()); } @@ -465,15 +470,20 @@ impl MethodBody { let first_duo = read_le::(data)?; - let size_header = (first_duo >> 12) * 4; - let size_code = read_le::(&data[4..])?; - if data.len() < (size_code as usize + size_header as usize) { + let size_header = (first_duo >> 12).wrapping_mul(4); + let size_code = read_le::(data.get(4..).ok_or(out_of_bounds_error!())?)?; + let total = (size_code as usize) + .checked_add(size_header as usize) + .ok_or_else(|| malformed_error!("fat method body total size overflow"))?; + if data.len() < total { return Err(out_of_bounds_error!()); } - let local_var_sig_token = read_le::(&data[8..])?; + let local_var_sig_token = + read_le::(data.get(8..).ok_or(out_of_bounds_error!())?)?; let flags_header = MethodBodyFlags::new(first_duo & 0b_0000111111111111_u16); - let max_stack = read_le::(&data[2..])? as usize; + let max_stack = + read_le::(data.get(2..).ok_or(out_of_bounds_error!())?)? as usize; let is_init_local = flags_header.contains(MethodBodyFlags::INIT_LOCALS); @@ -483,34 +493,58 @@ impl MethodBody { let mut filtered_count: usize = 0; if flags_header.contains(MethodBodyFlags::MORE_SECTS) { // Set cursor to the end of the header + body, to process exception tables - let mut cursor = size_header as usize + size_code as usize; - cursor = (cursor + 3) & !3; - - while data.len() > (cursor + 4) { - let method_data_section_flags = - SectionFlags::new(read_le::(&data[cursor..])?); + let mut cursor = (size_header as usize) + .checked_add(size_code as usize) + .ok_or_else(|| malformed_error!("EH cursor overflow"))?; + cursor = cursor + .checked_add(3) + .ok_or_else(|| malformed_error!("EH alignment overflow"))? + & !3usize; + + loop { + let cursor_plus_4 = cursor + .checked_add(4) + .ok_or_else(|| malformed_error!("EH cursor+4 overflow"))?; + if data.len() <= cursor_plus_4 { + break; + } + let method_data_section_flags = SectionFlags::new(read_le::( + data.get(cursor..).ok_or(out_of_bounds_error!())?, + )?); if !method_data_section_flags.contains(SectionFlags::EHTABLE) { break; } if method_data_section_flags.contains(SectionFlags::FAT_FORMAT) { - let method_data_section_size = - read_le::(&data[cursor + 1..])? & 0x00FF_FFFF; + let cursor_plus_1 = cursor + .checked_add(1) + .ok_or_else(|| malformed_error!("EH cursor+1 overflow"))?; + let method_data_section_size = read_le::( + data.get(cursor_plus_1..).ok_or(out_of_bounds_error!())?, + )? & 0x00FF_FFFF; // ECMA-335 says DataSize includes the 4-byte section header, // but some tools emit DataSize without the header. Using // DataSize/24 handles both cases via integer division // (matches Mono runtime behavior in metadata.c:parse_section_data). let handler_count = method_data_section_size / 24; - let needed = 4 + handler_count as usize * 24; - if handler_count == 0 || data.len() < cursor + needed { + let needed = (handler_count as usize) + .checked_mul(24) + .and_then(|v| v.checked_add(4)) + .ok_or_else(|| malformed_error!("EH FAT section size overflow"))?; + let cursor_end = cursor + .checked_add(needed) + .ok_or_else(|| malformed_error!("EH FAT cursor+needed overflow"))?; + if handler_count == 0 || data.len() < cursor_end { break; } if !Self::check_handler_count(handler_count, eh_mode)? { break; } - cursor += 4; + cursor = cursor.checked_add(4).ok_or_else(|| { + malformed_error!("EH FAT cursor advance overflow") + })?; for handler_idx in 0..handler_count { let flags_u32 = read_le_at::(data, &mut cursor)?; @@ -549,8 +583,12 @@ impl MethodBody { } } } else { - let method_data_section_size = - u32::from(read_le::(&data[cursor + 1..])?); + let cursor_plus_1 = cursor + .checked_add(1) + .ok_or_else(|| malformed_error!("EH small cursor+1 overflow"))?; + let method_data_section_size = u32::from(read_le::( + data.get(cursor_plus_1..).ok_or(out_of_bounds_error!())?, + )?); // ECMA-335 says DataSize includes the 4-byte section header, // but some tools (e.g. AsmResolver used by BitMono) emit @@ -558,15 +596,25 @@ impl MethodBody { // cases via integer division (matches Mono runtime behavior // in metadata.c:parse_section_data). let handler_count = method_data_section_size / 12; - let needed = 4 + handler_count as usize * 12; - if handler_count == 0 || data.len() < cursor + needed { + let needed = (handler_count as usize) + .checked_mul(12) + .and_then(|v| v.checked_add(4)) + .ok_or_else(|| { + malformed_error!("EH small section size overflow") + })?; + let cursor_end = cursor.checked_add(needed).ok_or_else(|| { + malformed_error!("EH small cursor+needed overflow") + })?; + if handler_count == 0 || data.len() < cursor_end { break; } if !Self::check_handler_count(handler_count, eh_mode)? { break; } - cursor += 4; + cursor = cursor.checked_add(4).ok_or_else(|| { + malformed_error!("EH small cursor advance overflow") + })?; for handler_idx in 0..handler_count { let handler = ExceptionHandler { flags: ExceptionHandlerFlags::new(read_le_at::( @@ -638,14 +686,14 @@ impl MethodBody { /// The total serialized size in bytes. #[must_use] pub fn size(&self) -> usize { - let base_size = self.size_header + self.size_code; + let base_size = self.size_header.saturating_add(self.size_code); if self.exception_handlers.is_empty() { return base_size; } // Exception handlers require 4-byte alignment after the code - let aligned_base = (base_size + 3) & !3; + let aligned_base = base_size.saturating_add(3) & !3usize; // Determine if we need fat or small exception handler format let needs_fat_format = self.exception_handlers.iter().any(|h| { @@ -655,15 +703,14 @@ impl MethodBody { || h.handler_length > 0xFF }); - let section_size = if needs_fat_format { - // Fat format: 4-byte header + 24 bytes per handler - 4 + (self.exception_handlers.len() * 24) - } else { - // Small format: 4-byte header + 12 bytes per handler - 4 + (self.exception_handlers.len() * 12) - }; + let per_handler = if needs_fat_format { 24 } else { 12 }; + let section_size = self + .exception_handlers + .len() + .saturating_mul(per_handler) + .saturating_add(4); - aligned_base + section_size + aligned_base.saturating_add(section_size) } /// Writes the method body to a writer. @@ -738,7 +785,9 @@ impl MethodBody { })?; let header_byte = (code_size_u8 << 2) | 0x02; writer.write_all(&[header_byte])?; - bytes_written += 1; + bytes_written = bytes_written + .checked_add(1) + .ok_or_else(|| malformed_error!("method body byte count overflow"))?; } else { // Fat format: 12-byte header let code_size_u32 = u32::try_from(code_size) @@ -764,26 +813,35 @@ impl MethodBody { writer.write_all(&max_stack_u16.to_le_bytes())?; writer.write_all(&code_size_u32.to_le_bytes())?; writer.write_all(&self.local_var_sig_token.to_le_bytes())?; - bytes_written += 12; + bytes_written = bytes_written + .checked_add(12) + .ok_or_else(|| malformed_error!("method body byte count overflow"))?; } // Write IL code writer.write_all(il_code)?; - bytes_written += code_size as u64; + bytes_written = bytes_written + .checked_add(code_size as u64) + .ok_or_else(|| malformed_error!("method body byte count overflow"))?; // Write exception handlers if present if has_exceptions { // Align to 4-byte boundary - let padding = (4 - (bytes_written % 4)) % 4; + let rem = bytes_written.checked_rem(4).unwrap_or(0); + let padding = 4u64.wrapping_sub(rem).checked_rem(4).unwrap_or(0); if padding > 0 { writer.write_all(&vec![0u8; padding as usize])?; - bytes_written += padding; + bytes_written = bytes_written + .checked_add(padding) + .ok_or_else(|| malformed_error!("method body byte count overflow"))?; } // Use the shared exception handler encoding let exception_bytes = encode_exception_handlers(&self.exception_handlers)?; writer.write_all(&exception_bytes)?; - bytes_written += exception_bytes.len() as u64; + bytes_written = bytes_written + .checked_add(exception_bytes.len() as u64) + .ok_or_else(|| malformed_error!("method body byte count overflow"))?; } Ok(bytes_written) diff --git a/dotscope/src/metadata/method/exceptions.rs b/dotscope/src/metadata/method/exceptions.rs index 38ef80c8..8363f26d 100644 --- a/dotscope/src/metadata/method/exceptions.rs +++ b/dotscope/src/metadata/method/exceptions.rs @@ -735,7 +735,7 @@ pub fn encode_exception_handlers(handlers: &[ExceptionHandler]) -> Result Result Iterator for InstructionIterator<'a> { let block = self.blocks.get(self.current_block)?; if self.current_instruction < block.instructions.len() { - let instruction = &block.instructions[self.current_instruction]; - self.current_instruction += 1; + let instruction = block.instructions.get(self.current_instruction)?; + self.current_instruction = self.current_instruction.saturating_add(1); return Some(instruction); } - self.current_block += 1; + self.current_block = self.current_block.saturating_add(1); self.current_instruction = 0; } } @@ -324,16 +324,18 @@ impl<'a> Iterator for InstructionIterator<'a> { /// # Ok::<(), dotscope::Error>(()) /// ``` fn size_hint(&self) -> (usize, Option) { - let mut remaining = 0; + let mut remaining: usize = 0; for i in self.current_block..self.blocks.len() { if let Some(block) = self.blocks.get(i) { if i == self.current_block { - remaining += block - .instructions - .len() - .saturating_sub(self.current_instruction); + remaining = remaining.saturating_add( + block + .instructions + .len() + .saturating_sub(self.current_instruction), + ); } else { - remaining += block.instructions.len(); + remaining = remaining.saturating_add(block.instructions.len()); } } } diff --git a/dotscope/src/metadata/method/mod.rs b/dotscope/src/metadata/method/mod.rs index 4a1cd747..0aab95ea 100644 --- a/dotscope/src/metadata/method/mod.rs +++ b/dotscope/src/metadata/method/mod.rs @@ -259,18 +259,19 @@ impl MethodRef { self.weak_ref.upgrade() } - /// Get a strong reference to the method, panicking if the method has been dropped. + /// Get a strong reference to the method, returning an error if it has been dropped. /// - /// Use this when you're certain the method should still exist. This provides - /// a convenient way to access the method without handling the `Option` case. + /// Use this when you expect the method to still exist. The provided message is + /// included in the returned error if the underlying weak reference can no longer + /// be upgraded. /// /// # Arguments /// - /// * `msg` - Error message to display if the method has been dropped + /// * `msg` - Diagnostic message included in the returned error if the method has been dropped /// - /// # Panics + /// # Errors /// - /// Panics if the method has been dropped and the weak reference cannot be upgraded. + /// Returns [`crate::Error::Malformed`] if the underlying method has been dropped. /// /// # Examples /// @@ -282,16 +283,16 @@ impl MethodRef { /// if let Some(entry) = assembly.methods().iter().next() { /// let method = entry.value(); /// let method_ref = dotscope::metadata::method::MethodRef::new(&method); - /// - /// // Use expect when you're certain the method should exist - /// let method = method_ref.expect("Method should still be available"); + /// + /// let method = method_ref.expect("Method should still be available")?; /// println!("Accessed method: {}", method.name); /// } /// # Ok::<(), dotscope::Error>(()) /// ``` - #[must_use] - pub fn expect(&self, msg: &str) -> MethodRc { - self.weak_ref.upgrade().expect(msg) + pub fn expect(&self, msg: &str) -> Result { + self.weak_ref + .upgrade() + .ok_or_else(|| malformed_error!("{}", msg)) } /// Check if the referenced method is still alive. @@ -1166,7 +1167,11 @@ impl Method { )); } - let mut body = MethodBody::from(&file.data()[method_offset..])?; + let mut body = MethodBody::from( + file.data() + .get(method_offset..) + .ok_or(out_of_bounds_error!())?, + )?; if body.local_var_sig_token != 0 { self.parse_local_variables(&body, blobs, sigs, types)?; } @@ -1291,7 +1296,11 @@ impl Method { )?; } else { // Regular parameters (1-indexed in metadata) - let index = (parameter.sequence - 1) as usize; + let index = parameter + .sequence + .checked_sub(1) + .ok_or_else(|| malformed_error!("Parameter sequence underflow"))? + as usize; if let Some(param_signature) = self.signature.params.get(index) { parameter.apply_signature( param_signature, @@ -1450,7 +1459,11 @@ impl Method { })?; // Compute argument and local counts - let num_args = self.signature.params.len() + usize::from(self.signature.has_this); + let num_args = self + .signature + .params + .len() + .saturating_add(usize::from(self.signature.has_this)); let declared_locals = self.local_vars.count(); // If no locals are declared (e.g., Tiny Format method body), infer from IL usage. @@ -1476,68 +1489,73 @@ impl Method { // Block offsets are absolute file offsets. Get base offset from first block. let base_offset = blocks.first().map_or(0, |b| b.offset); - let ssa_handlers: Vec = body - .exception_handlers - .iter() - .enumerate() - .map(|(handler_idx, eh)| { - // Map offsets to block indices (add base_offset to convert relative to absolute) - let try_start_block = Self::find_block_at_offset( - blocks, - base_offset + eh.try_offset as usize, - ); - let try_end_block = Self::find_block_at_offset( - blocks, - base_offset + (eh.try_offset + eh.try_length) as usize, - ); - - // For handler blocks, use the handler_entry info from decoder if available - let handler_start_block = - Self::find_handler_entry_block(blocks, handler_idx).or_else(|| { - Self::find_block_at_offset( - blocks, - base_offset + eh.handler_offset as usize, - ) - }); - let handler_end_block = Self::find_block_at_offset( - blocks, - base_offset + (eh.handler_offset + eh.handler_length) as usize, - ); - - let filter_start_block = if eh.flags == ExceptionHandlerFlags::FILTER { - Self::find_block_at_offset( - blocks, - base_offset + eh.filter_offset as usize, - ) - } else { - None - }; - - // Get the class token for catch handlers - let class_token_or_filter = if eh.flags == ExceptionHandlerFlags::EXCEPTION - { - eh.handler.as_ref().map_or(0, |t| t.token.value()) - } else if eh.flags == ExceptionHandlerFlags::FILTER { - eh.filter_offset - } else { - 0 - }; - - SsaExceptionHandler { - flags: eh.flags, - try_offset: eh.try_offset, - try_length: eh.try_length, - handler_offset: eh.handler_offset, - handler_length: eh.handler_length, - class_token_or_filter, - try_start_block, - try_end_block, - handler_start_block, - handler_end_block, - filter_start_block, - } - }) - .collect(); + let mut ssa_handlers: Vec = + Vec::with_capacity(body.exception_handlers.len()); + for (handler_idx, eh) in body.exception_handlers.iter().enumerate() { + let try_offset_abs = base_offset + .checked_add(eh.try_offset as usize) + .ok_or_else(|| malformed_error!("Exception handler try_offset overflow"))?; + let try_end_rel = (eh.try_offset as usize) + .checked_add(eh.try_length as usize) + .ok_or_else(|| malformed_error!("Exception handler try region overflow"))?; + let try_end_abs = base_offset + .checked_add(try_end_rel) + .ok_or_else(|| malformed_error!("Exception handler try end overflow"))?; + let handler_offset_abs = base_offset + .checked_add(eh.handler_offset as usize) + .ok_or_else(|| { + malformed_error!("Exception handler handler_offset overflow") + })?; + let handler_end_rel = (eh.handler_offset as usize) + .checked_add(eh.handler_length as usize) + .ok_or_else(|| malformed_error!("Exception handler region overflow"))?; + let handler_end_abs = base_offset + .checked_add(handler_end_rel) + .ok_or_else(|| malformed_error!("Exception handler end overflow"))?; + + // Map offsets to block indices (add base_offset to convert relative to absolute) + let try_start_block = Self::find_block_at_offset(blocks, try_offset_abs); + let try_end_block = Self::find_block_at_offset(blocks, try_end_abs); + + // For handler blocks, use the handler_entry info from decoder if available + let handler_start_block = Self::find_handler_entry_block(blocks, handler_idx) + .or_else(|| Self::find_block_at_offset(blocks, handler_offset_abs)); + let handler_end_block = Self::find_block_at_offset(blocks, handler_end_abs); + + let filter_start_block = if eh.flags == ExceptionHandlerFlags::FILTER { + let filter_offset_abs = base_offset + .checked_add(eh.filter_offset as usize) + .ok_or_else(|| { + malformed_error!("Exception handler filter_offset overflow") + })?; + Self::find_block_at_offset(blocks, filter_offset_abs) + } else { + None + }; + + // Get the class token for catch handlers + let class_token_or_filter = if eh.flags == ExceptionHandlerFlags::EXCEPTION { + eh.handler.as_ref().map_or(0, |t| t.token.value()) + } else if eh.flags == ExceptionHandlerFlags::FILTER { + eh.filter_offset + } else { + 0 + }; + + ssa_handlers.push(SsaExceptionHandler { + flags: eh.flags, + try_offset: eh.try_offset, + try_length: eh.try_length, + handler_offset: eh.handler_offset, + handler_length: eh.handler_length, + class_token_or_filter, + try_start_block, + try_end_block, + handler_start_block, + handler_end_block, + filter_start_block, + }); + } ssa.set_exception_handlers(ssa_handlers); } @@ -1596,7 +1614,7 @@ impl Method { } } - max_index.map_or(0, |i| i + 1) + max_index.map_or(0, |i| i.saturating_add(1)) } /// Finds the block index that starts at or contains the given offset. @@ -1609,7 +1627,7 @@ impl Method { // If no exact match, find block containing the offset blocks .iter() - .position(|b| offset >= b.offset && offset < b.offset + b.size) + .position(|b| offset >= b.offset && offset < b.offset.saturating_add(b.size)) } /// Finds the block that is marked as an entry point for the given handler index. @@ -1735,7 +1753,12 @@ impl Method { // If fat format with exceptions, align to 4 bytes and add exception handlers if has_exceptions { // Align to 4-byte boundary - let padding = (4 - (result.len() % 4)) % 4; + let remainder = result.len() % 4; + let padding = if remainder == 0 { + 0 + } else { + 4usize.saturating_sub(remainder) + }; result.extend(std::iter::repeat_n(0u8, padding)); // Encode and append exception handlers @@ -1765,7 +1788,8 @@ impl Method { for block in blocks { for instruction in &block.instructions { // instruction.size is u64, safely convert to usize - code_size += usize::try_from(instruction.size).ok()?; + let size = usize::try_from(instruction.size).ok()?; + code_size = code_size.checked_add(size)?; } } @@ -1774,14 +1798,14 @@ impl Method { let has_locals = body.local_var_sig_token != 0; let needs_fat = code_size > 63 || body.max_stack > 8 || has_locals || has_exceptions; - let header_size = if needs_fat { 12 } else { 1 }; + let header_size: usize = if needs_fat { 12 } else { 1 }; - let mut total = header_size + code_size; + let mut total = header_size.checked_add(code_size)?; // Add exception handler section size if needed if has_exceptions { // Align to 4 bytes - total = (total + 3) & !3; + total = total.checked_add(3)? & !3; // Calculate exception section size let handler_count = body.exception_handlers.len(); @@ -1794,13 +1818,12 @@ impl Method { || h.handler_length > 0xFF }); - if needs_fat_exceptions { - // Fat format: 4-byte header + 24 bytes per handler - total += 4 + handler_count * 24; - } else { - // Small format: 4-byte header + 12 bytes per handler - total += 4 + handler_count * 12; - } + let per_handler = if needs_fat_exceptions { 24 } else { 12 }; + // Fat format: 4-byte header + 24 bytes per handler + // Small format: 4-byte header + 12 bytes per handler + let handlers_size = handler_count.checked_mul(per_handler)?; + let section_size = handlers_size.checked_add(4)?; + total = total.checked_add(section_size)?; } Some(total) diff --git a/dotscope/src/metadata/resources/encoder.rs b/dotscope/src/metadata/resources/encoder.rs index 254d828a..3abc2962 100644 --- a/dotscope/src/metadata/resources/encoder.rs +++ b/dotscope/src/metadata/resources/encoder.rs @@ -866,9 +866,19 @@ impl DotNetResourceEncoder { buffer.extend_from_slice(resource_set_type.as_bytes()); // Calculate header size and update placeholder - let header_size = buffer.len() - header_size_pos - 4; + let header_size = buffer + .len() + .checked_sub(header_size_pos) + .and_then(|v| v.checked_sub(4)) + .ok_or_else(|| malformed_error!("Resource header size underflow"))?; let header_size_bytes = to_u32(header_size)?.to_le_bytes(); - buffer[header_size_pos..header_size_pos + 4].copy_from_slice(&header_size_bytes); + let header_end = header_size_pos + .checked_add(4) + .ok_or_else(|| malformed_error!("Resource header position overflow"))?; + buffer + .get_mut(header_size_pos..header_end) + .ok_or(out_of_bounds_error!())? + .copy_from_slice(&header_size_bytes); // Runtime Resource Reader Header buffer.extend_from_slice(&self.version.to_le_bytes()); // RR version @@ -903,15 +913,26 @@ impl DotNetResourceEncoder { let mut name_section_layout = Vec::new(); let mut name_offset = 0u32; for (_, resource_index) in &name_hashes { - let (name, _) = &self.resources[*resource_index]; + let (name, _) = self + .resources + .get(*resource_index) + .ok_or(out_of_bounds_error!())?; let name_utf16: Vec = name.encode_utf16().collect(); - let byte_count = name_utf16.len() * 2; + let byte_count = name_utf16 + .len() + .checked_mul(2) + .ok_or_else(|| malformed_error!("Resource name byte count overflow"))?; #[allow(clippy::cast_possible_truncation)] // compressed_uint_size returns at most 4 let len_size = compressed_uint_size(byte_count) as u32; - let entry_size = len_size + to_u32(byte_count)? + 4; + let entry_size = len_size + .checked_add(to_u32(byte_count)?) + .and_then(|v| v.checked_add(4)) + .ok_or_else(|| malformed_error!("Resource entry size overflow"))?; name_section_layout.push(name_offset); - name_offset += entry_size; + name_offset = name_offset + .checked_add(entry_size) + .ok_or_else(|| malformed_error!("Resource name offset overflow"))?; } // Write position table (in sorted hash order) @@ -923,7 +944,10 @@ impl DotNetResourceEncoder { let mut data_offsets = Vec::new(); let mut data_offset = 0u32; for (_, resource_index) in &name_hashes { - let (_, resource_type) = &self.resources[*resource_index]; + let (_, resource_type) = self + .resources + .get(*resource_index) + .ok_or(out_of_bounds_error!())?; data_offsets.push(data_offset); @@ -936,7 +960,10 @@ impl DotNetResourceEncoder { }; let data_size = resource_type.data_size().ok_or(Error::NotSupported)?; - data_offset += type_code_size + data_size; + data_offset = data_offset + .checked_add(type_code_size) + .and_then(|v| v.checked_add(data_size)) + .ok_or_else(|| malformed_error!("Resource data offset overflow"))?; } // Reserve space for data section offset - we'll update it after writing the name section @@ -945,9 +972,15 @@ impl DotNetResourceEncoder { // Write resource names and data offsets (in sorted hash order) for (i, (_, resource_index)) in name_hashes.iter().enumerate() { - let (name, _) = &self.resources[*resource_index]; + let (name, _) = self + .resources + .get(*resource_index) + .ok_or(out_of_bounds_error!())?; let name_utf16: Vec = name.encode_utf16().collect(); - let byte_count = name_utf16.len() * 2; + let byte_count = name_utf16 + .len() + .checked_mul(2) + .ok_or_else(|| malformed_error!("Resource name byte count overflow"))?; // Write byte count, not character count write_compressed_uint(to_u32(byte_count)?, &mut buffer); @@ -956,7 +989,8 @@ impl DotNetResourceEncoder { buffer.extend_from_slice(&utf16_char.to_le_bytes()); } - buffer.extend_from_slice(&data_offsets[i].to_le_bytes()); + let data_offset_value = data_offsets.get(i).ok_or(out_of_bounds_error!())?; + buffer.extend_from_slice(&data_offset_value.to_le_bytes()); } // Calculate the actual data section offset following Microsoft's ResourceWriter exactly @@ -964,18 +998,35 @@ impl DotNetResourceEncoder { // Standard .NET convention: offset is relative to magic number position, requiring +4 adjustment in parser // For embedded resources, we need to be careful about the offset calculation // The offset should point to where the data actually starts in the file - let actual_data_section_offset = buffer.len() - 4; // -4 to account for size prefix + let actual_data_section_offset = buffer + .len() + .checked_sub(4) + .ok_or_else(|| malformed_error!("Resource data section offset underflow"))?; // -4 to account for size prefix let data_section_offset_value = to_u32(actual_data_section_offset)?.to_le_bytes(); - buffer[data_section_offset_pos..data_section_offset_pos + 4] + let data_section_end = data_section_offset_pos + .checked_add(4) + .ok_or_else(|| malformed_error!("Resource data section position overflow"))?; + buffer + .get_mut(data_section_offset_pos..data_section_end) + .ok_or(out_of_bounds_error!())? .copy_from_slice(&data_section_offset_value); // Write resource data (in sorted hash order) self.write_resource_data_sorted(&mut buffer, &name_hashes)?; // Update the size field at the beginning - let total_size = buffer.len() - 4; // Exclude the size field itself + let total_size = buffer + .len() + .checked_sub(4) + .ok_or_else(|| malformed_error!("Resource total size underflow"))?; // Exclude the size field itself let size_bytes = to_u32(total_size)?.to_le_bytes(); - buffer[size_placeholder_pos..size_placeholder_pos + 4].copy_from_slice(&size_bytes); + let size_end = size_placeholder_pos + .checked_add(4) + .ok_or_else(|| malformed_error!("Resource size placeholder position overflow"))?; + buffer + .get_mut(size_placeholder_pos..size_end) + .ok_or(out_of_bounds_error!())? + .copy_from_slice(&size_bytes); Ok(buffer) } @@ -1024,7 +1075,10 @@ impl DotNetResourceEncoder { name_hashes: &[(u32, usize)], ) -> Result<()> { for (_, resource_index) in name_hashes { - let (_, resource_type) = &self.resources[*resource_index]; + let (_, resource_type) = self + .resources + .get(*resource_index) + .ok_or(out_of_bounds_error!())?; // Use Microsoft's ResourceTypeCode enum values exactly let type_code = match resource_type { diff --git a/dotscope/src/metadata/resources/parser.rs b/dotscope/src/metadata/resources/parser.rs index 9cad18ea..ae828a05 100644 --- a/dotscope/src/metadata/resources/parser.rs +++ b/dotscope/src/metadata/resources/parser.rs @@ -527,7 +527,8 @@ impl Resource { if second_u32 == RESOURCE_MAGIC { // Embedded resource format: [size][magic][header...] let size = first_u32 as usize; - if size > (data.len() - 4) || size < 8 { + let max_size = data.len().saturating_sub(4); + if size > max_size || size < 8 { return Err(malformed_error!("Invalid embedded resource size: {}", size)); } Ok(true) @@ -640,7 +641,7 @@ impl Resource { } // Check for debug string in V2 debug builds ("***DEBUG***") - if res.rr_version == 2 && (data.len() - parser.pos()) >= 11 { + if res.rr_version == 2 && data.len().saturating_sub(parser.pos()) >= 11 { res.is_debug = Self::try_parse_debug_marker(parser); } @@ -745,18 +746,23 @@ impl Resource { /// /// The total number of padding bytes skipped (alignment + PAD patterns). fn skip_padding(parser: &mut Parser, data: &[u8]) -> Result { - let mut padding_count = 0; + let mut padding_count: usize = 0; // Standard 8-byte alignment let align_bytes = parser.pos() & 7; if align_bytes != 0 { - let padding_to_skip = 8 - align_bytes; - padding_count += padding_to_skip; + let padding_to_skip = 8usize.wrapping_sub(align_bytes); + padding_count = padding_count + .checked_add(padding_to_skip) + .ok_or_else(|| malformed_error!("padding count overflow"))?; parser.advance_by(padding_to_skip)?; } // Check for additional explicit PAD patterns (some .NET implementations add these) - padding_count += Self::skip_pad_patterns(parser, data)?; + let pad_pattern_count = Self::skip_pad_patterns(parser, data)?; + padding_count = padding_count + .checked_add(pad_pattern_count) + .ok_or_else(|| malformed_error!("padding count overflow"))?; Ok(padding_count) } @@ -784,11 +790,17 @@ impl Resource { /// /// The total number of PAD pattern bytes skipped. fn skip_pad_patterns(parser: &mut Parser, data: &[u8]) -> Result { - let mut padding_count = 0; + let mut padding_count: usize = 0; - while parser.pos() + 4 <= data.len() { + loop { let pos = parser.pos(); - let remaining = data.len() - pos; + let Some(end4) = pos.checked_add(4) else { + break; + }; + if end4 > data.len() { + break; + } + let remaining = data.len().saturating_sub(pos); // Need at least 3 bytes to check for "PAD" if remaining < 3 { @@ -796,16 +808,23 @@ impl Resource { } // Check for "PAD" pattern - if data[pos] == b'P' && data[pos + 1] == b'A' && data[pos + 2] == b'D' { + let p0 = data.get(pos).copied(); + let p1 = pos.checked_add(1).and_then(|i| data.get(i).copied()); + let p2 = pos.checked_add(2).and_then(|i| data.get(i).copied()); + if p0 == Some(b'P') && p1 == Some(b'A') && p2 == Some(b'D') { parser.advance_by(3)?; - padding_count += 3; + padding_count = padding_count + .checked_add(3) + .ok_or_else(|| malformed_error!("PAD pattern padding overflow"))?; // Check for additional padding byte after PAD ('P' or '\0') if parser.pos() < data.len() { - let next_byte = data[parser.pos()]; - if next_byte == b'P' || next_byte == 0 { + let next_byte = data.get(parser.pos()).copied(); + if next_byte == Some(b'P') || next_byte == Some(0) { parser.advance()?; - padding_count += 1; + padding_count = padding_count + .checked_add(1) + .ok_or_else(|| malformed_error!("PAD pattern padding overflow"))?; } } } else { @@ -942,7 +961,15 @@ impl Resource { let mut parser = Parser::new(data); for i in 0..count { - let name_pos = self.name_section_offset + self.name_positions[i] as usize; + let name_pos_offset = *self + .name_positions + .get(i) + .ok_or_else(|| malformed_error!("name_positions index {} out of bounds", i))? + as usize; + let name_pos = self + .name_section_offset + .checked_add(name_pos_offset) + .ok_or_else(|| malformed_error!("name position overflow"))?; parser.seek(name_pos)?; let name = parser.read_prefixed_string_utf16()?; @@ -950,10 +977,15 @@ impl Resource { let data_pos = if self.is_embedded_resource { // Embedded resources: offset calculated from magic number position, need +4 for size field - self.data_section_offset + type_offset as usize + 4 + self.data_section_offset + .checked_add(type_offset as usize) + .and_then(|v| v.checked_add(4)) + .ok_or_else(|| malformed_error!("embedded resource data position overflow"))? } else { // Standalone .resources files: use direct offset - self.data_section_offset + type_offset as usize + self.data_section_offset + .checked_add(type_offset as usize) + .ok_or_else(|| malformed_error!("standalone resource data position overflow"))? }; // Validate data position bounds @@ -972,8 +1004,7 @@ impl Resource { if type_index == u32::MAX { // -1 encoded as 7-bit represents null ResourceType::Null - } else if (type_index as usize) < self.type_names.len() { - let type_name = &self.type_names[type_index as usize]; + } else if let Some(type_name) = self.type_names.get(type_index as usize) { ResourceType::from_type_name(type_name, &mut parser)? } else { return Err(malformed_error!("Invalid type index: {}", type_index)); @@ -987,20 +1018,20 @@ impl Resource { // No type table - this file uses only primitive types (direct type codes) // Common in resource files that contain only strings/primitives ResourceType::from_type_byte(type_code, &mut parser)? + } else if let Some(type_name) = self.type_names.get(type_code as usize) { + ResourceType::from_type_name(type_name, &mut parser)? } else { - // Has type table - type code is an index into the type table - if (type_code as usize) < self.type_names.len() { - let type_name = &self.type_names[type_code as usize]; - ResourceType::from_type_name(type_name, &mut parser)? - } else { - return Err(malformed_error!("Invalid type index: {}", type_code)); - } + return Err(malformed_error!("Invalid type index: {}", type_code)); } }; + let name_hash = *self + .name_hashes + .get(i) + .ok_or_else(|| malformed_error!("name_hashes index {} out of bounds", i))?; let result = ResourceEntry { name: name.clone(), - name_hash: self.name_hashes[i], + name_hash, data: resource_data, }; @@ -1133,18 +1164,29 @@ impl Resource { let mut parser = Parser::new(data); for i in 0..count { - let name_pos = self.name_section_offset + self.name_positions[i] as usize; + let name_pos_offset = *self + .name_positions + .get(i) + .ok_or_else(|| malformed_error!("name_positions index {} out of bounds", i))? + as usize; + let name_pos = self + .name_section_offset + .checked_add(name_pos_offset) + .ok_or_else(|| malformed_error!("name position overflow"))?; parser.seek(name_pos)?; let name = parser.read_prefixed_string_utf16()?; let type_offset = parser.read_le::()?; let data_pos = if self.is_embedded_resource { - // Embedded resources: offset calculated from magic number position, need +4 for size field - self.data_section_offset + type_offset as usize + 4 + self.data_section_offset + .checked_add(type_offset as usize) + .and_then(|v| v.checked_add(4)) + .ok_or_else(|| malformed_error!("embedded resource data position overflow"))? } else { - // Standalone .resources files: use direct offset - self.data_section_offset + type_offset as usize + self.data_section_offset + .checked_add(type_offset as usize) + .ok_or_else(|| malformed_error!("standalone resource data position overflow"))? }; // Validate data position bounds @@ -1163,8 +1205,7 @@ impl Resource { if type_index == u32::MAX { // -1 encoded as 7-bit represents null ResourceTypeRef::Null - } else if (type_index as usize) < self.type_names.len() { - let type_name = &self.type_names[type_index as usize]; + } else if let Some(type_name) = self.type_names.get(type_index as usize) { ResourceTypeRef::from_type_name_ref(type_name, &mut parser, data)? } else { return Err(malformed_error!("Invalid type index: {}", type_index)); @@ -1177,20 +1218,20 @@ impl Resource { if self.type_names.is_empty() { // No type table - this file uses only primitive types (direct type codes) ResourceTypeRef::from_type_byte_ref(type_code, &mut parser, data)? + } else if let Some(type_name) = self.type_names.get(type_code as usize) { + ResourceTypeRef::from_type_name_ref(type_name, &mut parser, data)? } else { - // Has type table - type code is an index into the type table - if (type_code as usize) < self.type_names.len() { - let type_name = &self.type_names[type_code as usize]; - ResourceTypeRef::from_type_name_ref(type_name, &mut parser, data)? - } else { - return Err(malformed_error!("Invalid type index: {}", type_code)); - } + return Err(malformed_error!("Invalid type index: {}", type_code)); } }; + let name_hash = *self + .name_hashes + .get(i) + .ok_or_else(|| malformed_error!("name_hashes index {} out of bounds", i))?; let result = ResourceEntryRef { name: name.clone(), - name_hash: self.name_hashes[i], + name_hash, data: resource_data, }; diff --git a/dotscope/src/metadata/resources/types.rs b/dotscope/src/metadata/resources/types.rs index 176827bc..fb152d3e 100644 --- a/dotscope/src/metadata/resources/types.rs +++ b/dotscope/src/metadata/resources/types.rs @@ -457,7 +457,7 @@ impl ResourceType { let utf8_byte_count = s.len(); let utf8_size = u32::try_from(utf8_byte_count).ok()?; let prefix_size = u32::try_from(compressed_uint_size(utf8_size as usize)).ok()?; - Some(prefix_size + utf8_size) + prefix_size.checked_add(utf8_size) } ResourceType::Boolean(_) | ResourceType::Byte(_) | ResourceType::SByte(_) => Some(1), // Single byte ResourceType::Char(_) | ResourceType::Int16(_) | ResourceType::UInt16(_) => Some(2), // 2 bytes @@ -474,7 +474,7 @@ impl ResourceType { // Per .NET specification: ByteArray and Stream use fixed 4-byte LE length // Note: Type code is NOT included here - encoder adds type_code_size separately let data_size = u32::try_from(data.len()).ok()?; - Some(4 + data_size) + 4_u32.checked_add(data_size) } // Types without .NET equivalents ResourceType::Null | ResourceType::StartOfUserTypes => None, @@ -596,13 +596,19 @@ impl ResourceType { 0x20 => { let length = parser.read_le::()?; let start_pos = parser.pos(); - let end_pos = start_pos + length as usize; - - if end_pos > parser.data().len() { - return Err(out_of_bounds_error!()); - } - - let data = parser.data()[start_pos..end_pos].to_vec(); + let end_pos = start_pos.checked_add(length as usize).ok_or_else(|| { + malformed_error!( + "ByteArray resource end offset overflow: start={} length={}", + start_pos, + length + ) + })?; + + let data = parser + .data() + .get(start_pos..end_pos) + .ok_or(out_of_bounds_error!())? + .to_vec(); if end_pos < parser.data().len() { parser.seek(end_pos)?; } @@ -611,13 +617,19 @@ impl ResourceType { 0x21 => { let length = parser.read_le::()?; let start_pos = parser.pos(); - let end_pos = start_pos + length as usize; - - if end_pos > parser.data().len() { - return Err(out_of_bounds_error!()); - } - - let data = parser.data()[start_pos..end_pos].to_vec(); + let end_pos = start_pos.checked_add(length as usize).ok_or_else(|| { + malformed_error!( + "Stream resource end offset overflow: start={} length={}", + start_pos, + length + ) + })?; + + let data = parser + .data() + .get(start_pos..end_pos) + .ok_or(out_of_bounds_error!())? + .to_vec(); if end_pos < parser.data().len() { parser.seek(end_pos)?; } @@ -1032,33 +1044,41 @@ impl<'a> ResourceTypeRef<'a> { 0x20 => { let length = parser.read_le::()?; let start_pos = parser.pos(); - let end_pos = start_pos + length as usize; + let end_pos = start_pos.checked_add(length as usize).ok_or_else(|| { + malformed_error!( + "ByteArray resource end offset overflow: start={} length={}", + start_pos, + length + ) + })?; - if end_pos > data.len() { - return Err(out_of_bounds_error!()); - } + let slice = data.get(start_pos..end_pos).ok_or(out_of_bounds_error!())?; if end_pos < data.len() { parser.seek(end_pos)?; } - Ok(ResourceTypeRef::ByteArray(&data[start_pos..end_pos])) + Ok(ResourceTypeRef::ByteArray(slice)) } 0x21 => { let length = parser.read_le::()?; let start_pos = parser.pos(); - let end_pos = start_pos + length as usize; + let end_pos = start_pos.checked_add(length as usize).ok_or_else(|| { + malformed_error!( + "Stream resource end offset overflow: start={} length={}", + start_pos, + length + ) + })?; - if end_pos > data.len() { - return Err(out_of_bounds_error!()); - } + let slice = data.get(start_pos..end_pos).ok_or(out_of_bounds_error!())?; if end_pos < data.len() { parser.seek(end_pos)?; } // Stream uses same format as ByteArray, just different type code - Ok(ResourceTypeRef::Stream(&data[start_pos..end_pos])) + Ok(ResourceTypeRef::Stream(slice)) } 0x40..=0xFF => { // User types - these require a type table for resolution diff --git a/dotscope/src/metadata/root.rs b/dotscope/src/metadata/root.rs index b992f4c8..f902c6ed 100644 --- a/dotscope/src/metadata/root.rs +++ b/dotscope/src/metadata/root.rs @@ -384,10 +384,10 @@ impl Root { let mut version_string: String = String::with_capacity(version_string_length as usize); for counter in 0..version_string_length { - version_string.push(char::from(read_le_at::( - data, - &mut (usize::from(VERSION_STRING_OFFSET) + counter as usize), - )?)); + let mut pos = usize::from(VERSION_STRING_OFFSET) + .checked_add(counter as usize) + .ok_or_else(|| malformed_error!("version string offset overflow"))?; + version_string.push(char::from(read_le_at::(data, &mut pos)?)); } // Validate version string format and content @@ -415,15 +415,18 @@ impl Root { } // Stream count is located after: version_string + FLAGS_FIELD_SIZE - let mut stream_count_offset = - version_string.len() + usize::from(VERSION_STRING_OFFSET) + FLAGS_FIELD_SIZE; + let mut stream_count_offset = version_string + .len() + .checked_add(usize::from(VERSION_STRING_OFFSET)) + .and_then(|v| v.checked_add(FLAGS_FIELD_SIZE)) + .ok_or_else(|| malformed_error!("stream count offset overflow"))?; let stream_count = read_le_at::(data, &mut stream_count_offset)?; // Validate stream count: must have at least one stream, no more than MAX_STREAM_COUNT - if stream_count == 0 - || stream_count > MAX_STREAM_COUNT - || (stream_count as usize * MIN_STREAM_HEADER_SIZE) > data.len() - { + let stream_count_size = (stream_count as usize) + .checked_mul(MIN_STREAM_HEADER_SIZE) + .ok_or_else(|| malformed_error!("stream count size overflow"))?; + if stream_count == 0 || stream_count > MAX_STREAM_COUNT || stream_count_size > data.len() { return Err(malformed_error!( "Root: invalid stream count {} (must be 1-{}) [ECMA-335 §II.24.2.1]", stream_count, @@ -433,10 +436,12 @@ impl Root { let mut streams = Vec::with_capacity(stream_count as usize); // Stream directory starts after: version_string + FLAGS_FIELD_SIZE + STREAM_COUNT_FIELD_SIZE - let mut stream_offset = version_string.len() - + usize::from(VERSION_STRING_OFFSET) - + FLAGS_FIELD_SIZE - + STREAM_COUNT_FIELD_SIZE; + let mut stream_offset = version_string + .len() + .checked_add(usize::from(VERSION_STRING_OFFSET)) + .and_then(|v| v.checked_add(FLAGS_FIELD_SIZE)) + .and_then(|v| v.checked_add(STREAM_COUNT_FIELD_SIZE)) + .ok_or_else(|| malformed_error!("stream directory offset overflow"))?; let mut streams_seen = [false; MAX_STREAM_COUNT as usize]; for _i in 0..stream_count { @@ -444,7 +449,8 @@ impl Root { return Err(out_of_bounds_error!()); } - let new_stream = StreamHeader::from(&data[stream_offset..])?; + let stream_data = data.get(stream_offset..).ok_or(out_of_bounds_error!())?; + let new_stream = StreamHeader::from(stream_data)?; if new_stream.offset as usize > data.len() || new_stream.size as usize > data.len() || new_stream.name.len() > MAX_STREAM_NAME_LENGTH @@ -469,25 +475,44 @@ impl Root { } let stream_index = match new_stream.name.as_str() { - "#Strings" => 0, + "#Strings" => 0usize, "#US" => 1, "#Blob" => 2, "#GUID" => 3, "#~" => 4, "#-" => 5, - _ => unreachable!("StreamHeader::from() should have validated the name"), + _ => { + return Err(malformed_error!( + "Root: unrecognized stream name '{}' [ECMA-335 §II.24.2.2]", + new_stream.name + )) + } }; - if streams_seen[stream_index] { + if *streams_seen + .get(stream_index) + .ok_or(out_of_bounds_error!())? + { return Err(malformed_error!( "Root: duplicate stream '{}' found [ECMA-335 §II.24.2.2]", new_stream.name )); } - streams_seen[stream_index] = true; - - let name_aligned = ((new_stream.name.len() + 1) + 3) & !3; - stream_offset += STREAM_HEADER_FIXED_SIZE + name_aligned; + *streams_seen + .get_mut(stream_index) + .ok_or(out_of_bounds_error!())? = true; + + let name_aligned = new_stream + .name + .len() + .checked_add(1) + .and_then(|v| v.checked_add(3)) + .ok_or_else(|| malformed_error!("stream name alignment overflow"))? + & !3usize; + stream_offset = stream_offset + .checked_add(STREAM_HEADER_FIXED_SIZE) + .and_then(|v| v.checked_add(name_aligned)) + .ok_or_else(|| malformed_error!("stream offset overflow"))?; streams.push(new_stream); } @@ -498,17 +523,28 @@ impl Root { )); } + let flags_offset = usize::from(VERSION_STRING_OFFSET) + .checked_add(version_string.len()) + .ok_or_else(|| malformed_error!("flags offset overflow"))?; + Ok(Root { signature, - major_version: read_le::(&data[FIELD_OFFSET_MAJOR_VERSION..])?, - minor_version: read_le::(&data[FIELD_OFFSET_MINOR_VERSION..])?, - reserved: read_le::(&data[FIELD_OFFSET_RESERVED..])?, + major_version: read_le::( + data.get(FIELD_OFFSET_MAJOR_VERSION..) + .ok_or(out_of_bounds_error!())?, + )?, + minor_version: read_le::( + data.get(FIELD_OFFSET_MINOR_VERSION..) + .ok_or(out_of_bounds_error!())?, + )?, + reserved: read_le::( + data.get(FIELD_OFFSET_RESERVED..) + .ok_or(out_of_bounds_error!())?, + )?, length: u32::try_from(version_string.len()).map_err(|_| { malformed_error!("Root: version string length too large [ECMA-335 §II.24.2.1]") })?, - flags: read_le::( - &data[usize::from(VERSION_STRING_OFFSET) + version_string.len()..], - )?, + flags: read_le::(data.get(flags_offset..).ok_or(out_of_bounds_error!())?)?, stream_number: u16::try_from(streams.len()) .map_err(|_| malformed_error!("Root: too many streams [ECMA-335 §II.24.2.1]"))?, stream_headers: streams, @@ -599,7 +635,8 @@ impl Root { } for (i, &(start1, end1, name1)) in stream_ranges.iter().enumerate() { - for &(start2, end2, name2) in stream_ranges.iter().skip(i + 1) { + let skip = i.saturating_add(1); + for &(start2, end2, name2) in stream_ranges.iter().skip(skip) { if start1 < end2 && start2 < end1 { return Err(malformed_error!( "Stream '{}' ({}..{}) overlaps with stream '{}' ({}..{})", @@ -689,7 +726,11 @@ impl Root { // Version string length (padded to 4-byte boundary) let version_bytes = self.version.as_bytes(); - let padded_len = (version_bytes.len() + 3) & !3; + let padded_len = version_bytes + .len() + .checked_add(3) + .ok_or_else(|| malformed_error!("Version string padded length overflow"))? + & !3usize; let padded_len_u32 = u32::try_from(padded_len).map_err(|_| { malformed_error!( "Version string padded length {} exceeds u32 range", @@ -700,7 +741,7 @@ impl Root { // Version string + null padding to 4-byte boundary writer.write_all(version_bytes)?; - let padding = padded_len - version_bytes.len(); + let padding = padded_len.saturating_sub(version_bytes.len()); if padding > 0 { writer.write_all(&vec![0u8; padding])?; } @@ -755,19 +796,21 @@ impl Root { #[must_use] pub fn serialized_size(&self) -> usize { // Fixed fields: signature(4) + major(2) + minor(2) + reserved(4) + length(4) + flags(2) + stream_number(2) - let fixed_size = 20; + let fixed_size = 20usize; // Version string padded to 4-byte boundary - let version_padded = (self.version.len() + 3) & !3; + let version_padded = self.version.len().saturating_add(3) & !3usize; // Stream headers let streams_size: usize = self .stream_headers .iter() .map(crate::StreamHeader::serialized_size) - .sum(); + .fold(0usize, usize::saturating_add); - fixed_size + version_padded + streams_size + fixed_size + .saturating_add(version_padded) + .saturating_add(streams_size) } } diff --git a/dotscope/src/metadata/security/encoder.rs b/dotscope/src/metadata/security/encoder.rs index a6937dad..2b86fbc6 100644 --- a/dotscope/src/metadata/security/encoder.rs +++ b/dotscope/src/metadata/security/encoder.rs @@ -345,27 +345,27 @@ impl PermissionSetEncoder { if !string_table.contains_key(&permission.class_name) { string_table.insert(permission.class_name.clone(), next_index); string_list.push(permission.class_name.clone()); - next_index += 1; + next_index = next_index.saturating_add(1); } if !string_table.contains_key(&permission.assembly_name) { string_table.insert(permission.assembly_name.clone(), next_index); string_list.push(permission.assembly_name.clone()); - next_index += 1; + next_index = next_index.saturating_add(1); } for arg in &permission.named_arguments { if !string_table.contains_key(&arg.name) { string_table.insert(arg.name.clone(), next_index); string_list.push(arg.name.clone()); - next_index += 1; + next_index = next_index.saturating_add(1); } if let ArgumentValue::String(ref value) = arg.value { if !string_table.contains_key(value) { string_table.insert(value.clone(), next_index); string_list.push(value.clone()); - next_index += 1; + next_index = next_index.saturating_add(1); } } } @@ -380,15 +380,21 @@ impl PermissionSetEncoder { write_compressed_uint(to_u32(permissions.len())?, &mut self.buffer); for permission in permissions { - let class_name_index = string_table[&permission.class_name]; - let assembly_name_index = string_table[&permission.assembly_name]; + let class_name_index = *string_table + .get(&permission.class_name) + .ok_or_else(|| malformed_error!("Class name not in string table"))?; + let assembly_name_index = *string_table + .get(&permission.assembly_name) + .ok_or_else(|| malformed_error!("Assembly name not in string table"))?; write_compressed_uint(class_name_index, &mut self.buffer); write_compressed_uint(assembly_name_index, &mut self.buffer); write_compressed_uint(to_u32(permission.named_arguments.len())?, &mut self.buffer); for arg in &permission.named_arguments { - let name_index = string_table[&arg.name]; + let name_index = *string_table + .get(&arg.name) + .ok_or_else(|| malformed_error!("Argument name not in string table"))?; write_compressed_uint(name_index, &mut self.buffer); @@ -456,7 +462,9 @@ impl PermissionSetEncoder { self.buffer.extend_from_slice(&value.to_le_bytes()); } ArgumentValue::String(value) | ArgumentValue::Type(value) => { - let value_index = string_table[value]; + let value_index = *string_table + .get(value) + .ok_or_else(|| malformed_error!("String value not in string table"))?; write_compressed_uint(value_index, &mut self.buffer); } _ => { diff --git a/dotscope/src/metadata/security/permissionset.rs b/dotscope/src/metadata/security/permissionset.rs index 582483d5..e4b8a5e4 100644 --- a/dotscope/src/metadata/security/permissionset.rs +++ b/dotscope/src/metadata/security/permissionset.rs @@ -428,12 +428,12 @@ impl PermissionSet { /// # Errors /// Returns an error if the data is empty or malformed. pub fn new(data: &[u8]) -> Result { - if data.is_empty() { - return Err(malformed_error!("PermissionSet data is empty")); - } + let first = *data + .first() + .ok_or_else(|| malformed_error!("PermissionSet data is empty"))?; // Determine format from first byte - let (format, permissions) = match data[0] { + let (format, permissions) = match first { /* '.' - Binary format marker */ 0x2E => Self::parse_binary_format(data)?, /* '<' - XML format marker */ @@ -674,7 +674,7 @@ impl PermissionSet { } let xml_start = b" = class_name.split('.').collect(); - if parts.len() >= 2 { - format!("{}.{}", parts[0], parts[1]) + if let (Some(p0), Some(p1)) = (parts.first(), parts.get(1)) { + format!("{p0}.{p1}") } else { "mscorlib".to_string() // Default to mscorlib for unrecognized types } diff --git a/dotscope/src/metadata/sequencepoints.rs b/dotscope/src/metadata/sequencepoints.rs index bc1b2925..03dca0dd 100644 --- a/dotscope/src/metadata/sequencepoints.rs +++ b/dotscope/src/metadata/sequencepoints.rs @@ -282,7 +282,7 @@ impl SequencePoints { let il_offset_value = if is_first { point.il_offset } else { - point.il_offset - prev_il_offset + point.il_offset.saturating_sub(prev_il_offset) }; write_compressed_uint(il_offset_value, &mut buffer); @@ -291,7 +291,7 @@ impl SequencePoints { write_compressed_uint(point.start_line, &mut buffer); } else { #[allow(clippy::cast_possible_wrap)] - let delta = point.start_line as i32 - prev_start_line as i32; + let delta = (point.start_line as i32).wrapping_sub(prev_start_line as i32); write_compressed_int(delta, &mut buffer); } @@ -299,16 +299,16 @@ impl SequencePoints { if is_first { write_compressed_uint(u32::from(point.start_col), &mut buffer); } else { - let delta = i32::from(point.start_col) - i32::from(prev_start_col); + let delta = i32::from(point.start_col).wrapping_sub(i32::from(prev_start_col)); write_compressed_int(delta, &mut buffer); } // End Line Delta (unsigned delta from start line) - let end_line_delta = point.end_line - point.start_line; + let end_line_delta = point.end_line.saturating_sub(point.start_line); write_compressed_uint(end_line_delta, &mut buffer); // End Column Delta (unsigned delta from start column) - let end_col_delta = point.end_col - point.start_col; + let end_col_delta = point.end_col.saturating_sub(point.start_col); write_compressed_uint(u32::from(end_col_delta), &mut buffer); // Update previous values for next iteration @@ -348,7 +348,9 @@ pub fn parse_sequence_points(blob: &[u8]) -> Result { il_offset = if first { il_offset_delta } else { - il_offset + il_offset_delta + il_offset + .checked_add(il_offset_delta) + .ok_or_else(|| malformed_error!("IL offset overflow in sequence points"))? }; let start_line_delta = if first { @@ -385,8 +387,12 @@ pub fn parse_sequence_points(blob: &[u8]) -> Result { let end_line_delta = parser.read_compressed_uint()?; #[allow(clippy::cast_possible_truncation)] let end_col_delta = parser.read_compressed_uint()? as u16; - let end_line = start_line + end_line_delta; - let end_col = start_col + end_col_delta; + let end_line = start_line + .checked_add(end_line_delta) + .ok_or_else(|| malformed_error!("End line overflow in sequence points"))?; + let end_col = start_col + .checked_add(end_col_delta) + .ok_or_else(|| malformed_error!("End column overflow in sequence points"))?; let is_hidden = start_line == HIDDEN_SEQUENCE_POINT_LINE; points.push(SequencePoint { diff --git a/dotscope/src/metadata/signatures/parser.rs b/dotscope/src/metadata/signatures/parser.rs index d60e7ae8..44c16b18 100644 --- a/dotscope/src/metadata/signatures/parser.rs +++ b/dotscope/src/metadata/signatures/parser.rs @@ -463,7 +463,7 @@ impl<'a> SignatureParser<'a> { while let Some(work) = work_stack.pop() { // Check nesting depth limit - if work_stack.len() + result_stack.len() > MAX_NESTING_DEPTH { + if work_stack.len().saturating_add(result_stack.len()) > MAX_NESTING_DEPTH { return Err(DepthLimitExceeded(MAX_NESTING_DEPTH)); } @@ -632,13 +632,23 @@ impl<'a> SignatureParser<'a> { } ELEMENT_TYPE::CMOD_REQD => { // We consumed the CMOD_REQD byte, go back so parse_custom_mods can handle it - self.parser.seek(self.parser.pos() - 1)?; + let prev_pos = self + .parser + .pos() + .checked_sub(1) + .ok_or_else(|| malformed_error!("Parser position underflow"))?; + self.parser.seek(prev_pos)?; let modifiers = self.parse_custom_mods()?; result_stack.push(TypeSignature::ModifiedRequired(modifiers)); } ELEMENT_TYPE::CMOD_OPT => { // We consumed the CMOD_OPT byte, go back so parse_custom_mods can handle it - self.parser.seek(self.parser.pos() - 1)?; + let prev_pos = self + .parser + .pos() + .checked_sub(1) + .ok_or_else(|| malformed_error!("Parser position underflow"))?; + self.parser.seek(prev_pos)?; let modifiers = self.parse_custom_mods()?; result_stack.push(TypeSignature::ModifiedOptional(modifiers)); } @@ -682,7 +692,7 @@ impl<'a> SignatureParser<'a> { WorkItem::BuildGenericInst { arg_count } => { // Stack has: base_type, arg1, arg2, ..., argN (top is argN) // We need to pop argN...arg1 in reverse, then base_type - if result_stack.len() < (arg_count as usize + 1) { + if result_stack.len() < (arg_count as usize).saturating_add(1) { return Err(malformed_error!( "Insufficient types on stack for GENERICINST" )); @@ -801,7 +811,12 @@ impl<'a> SignatureParser<'a> { Ok(()) } ELEMENT_TYPE::CMOD_REQD | ELEMENT_TYPE::CMOD_OPT => { - self.parser.seek(self.parser.pos() - 1)?; + let prev_pos = self + .parser + .pos() + .checked_sub(1) + .ok_or_else(|| malformed_error!("Parser position underflow"))?; + self.parser.seek(prev_pos)?; let _ = self.parse_custom_mods()?; Ok(()) } diff --git a/dotscope/src/metadata/streams/blob.rs b/dotscope/src/metadata/streams/blob.rs index 1ecedde4..e2f387aa 100644 --- a/dotscope/src/metadata/streams/blob.rs +++ b/dotscope/src/metadata/streams/blob.rs @@ -352,11 +352,10 @@ impl<'a> Blob<'a> { /// - [`iter`](Self::iter): Iterate over all blobs in the heap /// - [ECMA-335 II.24.2.4](https://ecma-international.org/wp-content/uploads/ECMA-335_6th_edition_june_2012.pdf): Blob heap specification pub fn from(data: &'a [u8]) -> Result> { - if data.is_empty() || data[0] != 0 { - return Err(malformed_error!("Invalid memory for #Blob heap")); + match data.first() { + Some(0) => Ok(Blob { data }), + _ => Err(malformed_error!("Invalid memory for #Blob heap")), } - - Ok(Blob { data }) } /// Retrieves a blob from the heap by its offset. @@ -441,7 +440,7 @@ impl<'a> Blob<'a> { return Err(out_of_bounds_error!()); } - let mut parser = Parser::new(&self.data[index..]); + let mut parser = Parser::new(self.data.get(index..).ok_or(out_of_bounds_error!())?); let len = parser.read_compressed_uint()? as usize; let skip = parser.pos(); @@ -453,11 +452,9 @@ impl<'a> Blob<'a> { return Err(out_of_bounds_error!()); }; - if data_start > self.data.len() || data_end > self.data.len() { - return Err(out_of_bounds_error!()); - } - - Ok(&self.data[data_start..data_end]) + self.data + .get(data_start..data_end) + .ok_or(out_of_bounds_error!()) } /// Returns an iterator over all blobs in the heap. @@ -826,10 +823,14 @@ impl<'a> Iterator for BlobIterator<'a> { let start_position = self.position; match self.blob.get(self.position) { Ok(blob_data) => { - let mut parser = Parser::new(&self.blob.data[self.position..]); + let remainder = self.blob.data.get(self.position..)?; + let mut parser = Parser::new(remainder); if parser.read_compressed_uint().is_ok() { let length_bytes = parser.pos(); - self.position += length_bytes + blob_data.len(); + self.position = self + .position + .saturating_add(length_bytes) + .saturating_add(blob_data.len()); Some((start_position, blob_data)) } else { None diff --git a/dotscope/src/metadata/streams/guid.rs b/dotscope/src/metadata/streams/guid.rs index eb704d22..c625e21e 100644 --- a/dotscope/src/metadata/streams/guid.rs +++ b/dotscope/src/metadata/streams/guid.rs @@ -501,15 +501,21 @@ impl<'a> Guid<'a> { /// - [`uguid::Guid`]: The returned GUID type with formatting and comparison methods /// - [ECMA-335 II.24.2.5](https://ecma-international.org/wp-content/uploads/ECMA-335_6th_edition_june_2012.pdf): GUID heap specification pub fn get(&self, index: usize) -> Result { - if index < 1 || (index - 1) * 16 + 16 > self.data.len() { + if index < 1 { return Err(out_of_bounds_error!()); } + let offset_start = index + .checked_sub(1) + .and_then(|i| i.checked_mul(16)) + .ok_or(out_of_bounds_error!())?; + let offset_end = offset_start.checked_add(16).ok_or(out_of_bounds_error!())?; - let offset_start = (index - 1) * 16; - let offset_end = offset_start + 16; - + let bytes = self + .data + .get(offset_start..offset_end) + .ok_or(out_of_bounds_error!())?; let mut buffer = [0u8; 16]; - buffer.copy_from_slice(&self.data[offset_start..offset_end]); + buffer.copy_from_slice(bytes); Ok(uguid::Guid::from_bytes(buffer)) } @@ -973,7 +979,7 @@ impl Iterator for GuidIterator<'_> { match self.guid.get(self.index) { Ok(guid) => { let current_index = self.index; - self.index += 1; + self.index = self.index.saturating_add(1); Some((current_index, guid)) } Err(_) => None, diff --git a/dotscope/src/metadata/streams/streamheader.rs b/dotscope/src/metadata/streams/streamheader.rs index b521db72..cd3ad89e 100644 --- a/dotscope/src/metadata/streams/streamheader.rs +++ b/dotscope/src/metadata/streams/streamheader.rs @@ -159,7 +159,7 @@ //! - **ECMA-335 II.24.2.2**: Stream header format and directory structure //! - **ECMA-335 II.24.2**: Complete metadata stream architecture overview -use crate::{utils::read_le, Result}; +use crate::{utils::read_le_at, Result}; use std::io::Write; /// ECMA-335 compliant stream header providing metadata stream location and identification. @@ -509,8 +509,9 @@ impl StreamHeader { return Err(out_of_bounds_error!()); } - let offset = read_le::(data)?; - let size = read_le::(&data[4..])?; + let mut cursor = 0_usize; + let offset = read_le_at::(data, &mut cursor)?; + let size = read_le_at::(data, &mut cursor)?; // ECMA-335 II.24.2.2 says stream sizes "shall be a multiple of 4", but many real-world // .NET tools (AsmResolver, Mono's writer, etc.) produce unaligned sizes. The .NET runtime @@ -533,9 +534,13 @@ impl StreamHeader { )); } + // After the 8-byte header, parse the name (max 32 chars or until end of data). + // `cursor` is currently 8 because we read two u32s. + let name_remaining = data.len().saturating_sub(cursor); + let max_name_bytes = std::cmp::min(32, name_remaining); let mut name = String::with_capacity(32); - for counter in 0..std::cmp::min(32, data.len() - 8) { - let name_char = read_le::(&data[8 + counter..])?; + for _ in 0..max_name_bytes { + let name_char = read_le_at::(data, &mut cursor)?; if name_char == 0 { break; } @@ -611,9 +616,22 @@ impl StreamHeader { // Pad to 4-byte boundary // Name length + 1 (null terminator), padded to multiple of 4 - let name_with_null = self.name.len() + 1; - let padded_len = (name_with_null + 3) & !3; - let padding = padded_len - name_with_null; + let name_with_null = self.name.len().checked_add(1).ok_or_else(|| { + malformed_error!("StreamHeader name length overflow: {}", self.name.len()) + })?; + let padded_len = name_with_null.checked_add(3).ok_or_else(|| { + malformed_error!( + "StreamHeader padded length overflow: {} + 3", + name_with_null + ) + })? & !3; + let padding = padded_len.checked_sub(name_with_null).ok_or_else(|| { + malformed_error!( + "StreamHeader padding underflow: padded={} name_with_null={}", + padded_len, + name_with_null + ) + })?; if padding > 0 { writer.write_all(&vec![0u8; padding])?; } @@ -653,9 +671,10 @@ impl StreamHeader { /// ``` #[must_use] pub fn serialized_size(&self) -> usize { - let name_with_null = self.name.len() + 1; - let padded_name = (name_with_null + 3) & !3; - 8 + padded_name // offset (4) + size (4) + padded name + let name_with_null = self.name.len().saturating_add(1); + let padded_name = name_with_null.saturating_add(3) & !3; + // offset (4) + size (4) + padded name + 8_usize.saturating_add(padded_name) } } diff --git a/dotscope/src/metadata/streams/strings.rs b/dotscope/src/metadata/streams/strings.rs index 5b358fd3..a9546f59 100644 --- a/dotscope/src/metadata/streams/strings.rs +++ b/dotscope/src/metadata/streams/strings.rs @@ -532,7 +532,7 @@ impl<'a> Strings<'a> { /// - [`Strings::iter`]: Sequential iteration over all heap strings /// - [ECMA-335 II.24.2.3](https://ecma-international.org/wp-content/uploads/ECMA-335_6th_edition_june_2012.pdf): Strings heap format specification pub fn from(data: &[u8]) -> Result> { - if data.is_empty() || data[0] != 0 { + if data.first() != Some(&0) { return Err(malformed_error!("Provided #String heap is empty")); } @@ -773,7 +773,8 @@ impl<'a> Strings<'a> { // If profiling shows this is a bottleneck, callers should: // 1. Cache results at the call site for repeated access // 2. Use the iterator for bulk processing (avoids repeated bounds checks) - match CStr::from_bytes_until_nul(&self.data[index..]) { + let slice = self.data.get(index..).ok_or(out_of_bounds_error!())?; + match CStr::from_bytes_until_nul(slice) { Ok(result) => match result.to_str() { Ok(result) => Ok(result), Err(_) => Err(malformed_error!("Invalid string at index - {}", index)), @@ -1295,7 +1296,7 @@ impl<'a> Iterator for StringsIterator<'a> { match self.strings.get(self.position) { Ok(string) => { // Move position past this string and its null terminator - self.position += string.len() + 1; + self.position = self.position.saturating_add(string.len()).saturating_add(1); Some((start_position, string)) } Err(_) => None, diff --git a/dotscope/src/metadata/streams/tablesheader.rs b/dotscope/src/metadata/streams/tablesheader.rs index 8ec3983a..2839e5fa 100644 --- a/dotscope/src/metadata/streams/tablesheader.rs +++ b/dotscope/src/metadata/streams/tablesheader.rs @@ -1012,34 +1012,41 @@ impl<'a> TablesHeader<'a> { return Err(out_of_bounds_error!()); } - let valid_bitvec = read_le::(&data[8..])?; + let valid_bitvec = read_le::(data.get(8..).ok_or(out_of_bounds_error!())?)?; if valid_bitvec == 0 { return Err(malformed_error!("No valid rows in any of the tables")); } + let tables_offset = 24usize + .checked_add((valid_bitvec.count_ones() as usize).saturating_mul(4)) + .ok_or_else(|| malformed_error!("Tables header offset overflow"))?; + let table_capacity = (TableId::CustomDebugInformation as usize).saturating_add(1); + let mut tables_header = TablesHeader { - major_version: read_le::(&data[4..])?, - minor_version: read_le::(&data[5..])?, + major_version: read_le::(data.get(4..).ok_or(out_of_bounds_error!())?)?, + minor_version: read_le::(data.get(5..).ok_or(out_of_bounds_error!())?)?, valid: valid_bitvec, - sorted: read_le::(&data[16..])?, + sorted: read_le::(data.get(16..).ok_or(out_of_bounds_error!())?)?, info: Arc::new(TableInfo::new(data, valid_bitvec)?), - tables_offset: (24 + valid_bitvec.count_ones() * 4) as usize, - tables: Vec::with_capacity(TableId::CustomDebugInformation as usize + 1), + tables_offset, + tables: Vec::with_capacity(table_capacity), }; // with_capacity has allocated the buffer, but we can't 'insert' elements, only push // to make the vector grow - as .insert doesn't adjust length, only push does. - tables_header - .tables - .resize_with(TableId::CustomDebugInformation as usize + 1, || None); + tables_header.tables.resize_with(table_capacity, || None); - let mut current_offset = tables_header.tables_offset as usize; + let mut current_offset = tables_header.tables_offset; for table_id in TableId::iter() { if current_offset > data.len() { return Err(out_of_bounds_error!()); } - tables_header.add_table(&data[current_offset..], table_id, &mut current_offset)?; + tables_header.add_table( + data.get(current_offset..).ok_or(out_of_bounds_error!())?, + table_id, + &mut current_offset, + )?; } Ok(tables_header) @@ -1546,7 +1553,7 @@ impl<'a> TablesHeader<'a> { pub fn header_size(&self) -> usize { // Fixed header: 24 bytes // Row counts: 4 bytes per present table - 24 + (self.valid.count_ones() as usize * 4) + 24usize.saturating_add((self.valid.count_ones() as usize).saturating_mul(4)) } } diff --git a/dotscope/src/metadata/streams/userstrings.rs b/dotscope/src/metadata/streams/userstrings.rs index a452b4e4..30363240 100644 --- a/dotscope/src/metadata/streams/userstrings.rs +++ b/dotscope/src/metadata/streams/userstrings.rs @@ -124,7 +124,7 @@ impl<'a> UserStrings<'a> { /// # Ok::<(), dotscope::Error>(()) /// ``` pub fn from(data: &'a [u8]) -> Result> { - if data.is_empty() || data[0] != 0 { + if data.first().copied() != Some(0) { return Err(out_of_bounds_error!()); } @@ -174,7 +174,9 @@ impl<'a> UserStrings<'a> { } let (total_bytes, compressed_length_size) = read_compressed_int_at(self.data, index)?; - let data_start = index + compressed_length_size; + let data_start = index + .checked_add(compressed_length_size) + .ok_or_else(|| malformed_error!("User string offset overflow at index {}", index))?; if total_bytes == 0 { return Err(malformed_error!( @@ -190,9 +192,13 @@ impl<'a> UserStrings<'a> { // Total bytes includes UTF-16 data + terminator byte (1 byte) // So actual UTF-16 data is total_bytes - 1 - let utf16_length = total_bytes - 1; + let utf16_length = total_bytes + .checked_sub(1) + .ok_or_else(|| malformed_error!("User string length underflow at index {}", index))?; - let total_data_end = data_start + total_bytes; + let total_data_end = data_start + .checked_add(total_bytes) + .ok_or_else(|| malformed_error!("User string end overflow at index {}", index))?; if total_data_end > self.data.len() { return Err(out_of_bounds_error!()); } @@ -201,8 +207,13 @@ impl<'a> UserStrings<'a> { return Err(malformed_error!("Invalid UTF-16 length at index {}", index)); } - let utf16_data_end = data_start + utf16_length; - let utf16_data = &self.data[data_start..utf16_data_end]; + let utf16_data_end = data_start + .checked_add(utf16_length) + .ok_or_else(|| malformed_error!("User string data end overflow at index {}", index))?; + let utf16_data = self + .data + .get(data_start..utf16_data_end) + .ok_or(out_of_bounds_error!())?; // Convert byte slice to u16 slice for UTF-16 string construction. // @@ -472,30 +483,33 @@ impl<'a> Iterator for UserStringsIterator<'a> { read_compressed_int(self.user_strings.data, &mut self.position) { // Reset position since read_compressed_int advanced it - self.position -= consumed; + self.position = self.position.saturating_sub(consumed); (length, consumed) } else { // Try to skip over bad data by advancing one byte and trying again - self.position += 1; - recovery_attempts += 1; + self.position = self.position.saturating_add(1); + recovery_attempts = recovery_attempts.saturating_add(1); continue; }; // Handle zero-length entries (invalid according to .NET spec, but may exist in malformed data) if total_bytes == 0 { - self.position += compressed_length_size; - recovery_attempts += 1; + self.position = self.position.saturating_add(compressed_length_size); + recovery_attempts = recovery_attempts.saturating_add(1); continue; } let Ok(string) = self.user_strings.get(start_position) else { - // Skip over the malformed entry - self.position += compressed_length_size + total_bytes; - recovery_attempts += 1; + // Skip over the malformed entry. On overflow, bail out cleanly. + let skip = compressed_length_size.checked_add(total_bytes)?; + let next = self.position.checked_add(skip)?; + self.position = next; + recovery_attempts = recovery_attempts.saturating_add(1); continue; }; - let new_position = self.position + compressed_length_size + total_bytes; + let skip = compressed_length_size.checked_add(total_bytes)?; + let new_position = self.position.checked_add(skip)?; self.position = new_position; return Some((start_position, string)); diff --git a/dotscope/src/metadata/tablefields.rs b/dotscope/src/metadata/tablefields.rs index d5a81f18..d497be12 100644 --- a/dotscope/src/metadata/tablefields.rs +++ b/dotscope/src/metadata/tablefields.rs @@ -119,19 +119,19 @@ pub fn get_heap_fields(table_id: TableId, table_info: &TableInfoRef) -> Vec Vec Vec Vec Vec Vec Vec { let coded_size = table_info.coded_index_bytes(CodedIndexType::HasConstant) as usize; fields.push(HeapFieldDescriptor { - offset: 2 + coded_size, + offset: 2usize.saturating_add(coded_size), size: blob_size, heap_type: HeapType::Blob, }); @@ -255,7 +255,7 @@ pub fn get_heap_fields(table_id: TableId, table_info: &TableInfoRef) -> Vec Vec { let coded_size = table_info.coded_index_bytes(CodedIndexType::HasDeclSecurity) as usize; fields.push(HeapFieldDescriptor { - offset: 2 + coded_size, + offset: 2usize.saturating_add(coded_size), size: blob_size, heap_type: HeapType::Blob, }); @@ -318,7 +318,7 @@ pub fn get_heap_fields(table_id: TableId, table_info: &TableInfoRef) -> Vec { let coded_size = table_info.coded_index_bytes(CodedIndexType::MemberForwarded) as usize; fields.push(HeapFieldDescriptor { - offset: 2 + coded_size, + offset: 2usize.saturating_add(coded_size), size: str_size, heap_type: HeapType::String, }); @@ -333,13 +333,13 @@ pub fn get_heap_fields(table_id: TableId, table_info: &TableInfoRef) -> Vec Vec Vec Vec Vec { let coded_size = table_info.coded_index_bytes(CodedIndexType::TypeOrMethodDef) as usize; fields.push(HeapFieldDescriptor { - offset: 4 + coded_size, + offset: 4usize.saturating_add(coded_size), size: str_size, heap_type: HeapType::String, }); @@ -452,19 +452,19 @@ pub fn get_heap_fields(table_id: TableId, table_info: &TableInfoRef) -> Vec Vec Vec u32 { u32::from( - /* hash_alg_id */ 4 + - /* major_version */ 2 + - /* minor_version */ 2 + - /* build_number */ 2 + - /* revision_number */ 2 + - /* flags */ 4 + - /* public_key */ sizes.blob_bytes() + - /* name */ sizes.str_bytes() + - /* culture */ sizes.str_bytes() + /* hash_alg_id */ 4u8 + /* major_version */ .saturating_add(2) + /* minor_version */ .saturating_add(2) + /* build_number */ .saturating_add(2) + /* revision_number */ .saturating_add(2) + /* flags */ .saturating_add(4) + /* public_key */ .saturating_add(sizes.blob_bytes()) + /* name */ .saturating_add(sizes.str_bytes()) + /* culture */ .saturating_add(sizes.str_bytes()) + ) } } diff --git a/dotscope/src/metadata/tables/assembly/reader.rs b/dotscope/src/metadata/tables/assembly/reader.rs index 4692f422..12f588e9 100644 --- a/dotscope/src/metadata/tables/assembly/reader.rs +++ b/dotscope/src/metadata/tables/assembly/reader.rs @@ -73,7 +73,7 @@ impl RowReadable for AssemblyRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(AssemblyRaw { rid, - token: Token::new(0x2000_0000 + rid), + token: Token::new(0x2000_0000u32.saturating_add(rid)), offset: *offset, hash_alg_id: read_le_at::(data, offset)?, major_version: u32::from(read_le_at::(data, offset)?), diff --git a/dotscope/src/metadata/tables/assemblyos/reader.rs b/dotscope/src/metadata/tables/assemblyos/reader.rs index 78edd0b2..11f2d736 100644 --- a/dotscope/src/metadata/tables/assemblyos/reader.rs +++ b/dotscope/src/metadata/tables/assemblyos/reader.rs @@ -69,7 +69,7 @@ impl RowReadable for AssemblyOsRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, _sizes: &TableInfoRef) -> Result { Ok(AssemblyOsRaw { rid, - token: Token::new(0x2200_0000 + rid), + token: Token::new(0x2200_0000u32.saturating_add(rid)), offset: *offset, os_platform_id: read_le_at::(data, offset)?, os_major_version: read_le_at::(data, offset)?, diff --git a/dotscope/src/metadata/tables/assemblyprocessor/raw.rs b/dotscope/src/metadata/tables/assemblyprocessor/raw.rs index 40666766..d8cbbcc1 100644 --- a/dotscope/src/metadata/tables/assemblyprocessor/raw.rs +++ b/dotscope/src/metadata/tables/assemblyprocessor/raw.rs @@ -72,7 +72,7 @@ use crate::{ /// /// All fields contain direct integer values rather than heap indexes: /// - No string heap references -/// - No blob heap references +/// - No blob heap references /// - All data is self-contained within the table row /// /// # Architecture Identifiers diff --git a/dotscope/src/metadata/tables/assemblyprocessor/reader.rs b/dotscope/src/metadata/tables/assemblyprocessor/reader.rs index a75ed19d..0dd82aae 100644 --- a/dotscope/src/metadata/tables/assemblyprocessor/reader.rs +++ b/dotscope/src/metadata/tables/assemblyprocessor/reader.rs @@ -67,7 +67,7 @@ impl RowReadable for AssemblyProcessorRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, _sizes: &TableInfoRef) -> Result { Ok(AssemblyProcessorRaw { rid, - token: Token::new(0x2100_0000 + rid), + token: Token::new(0x2100_0000u32.saturating_add(rid)), offset: *offset, processor: read_le_at::(data, offset)?, }) diff --git a/dotscope/src/metadata/tables/assemblyref/assemblyrefhash.rs b/dotscope/src/metadata/tables/assemblyref/assemblyrefhash.rs index a6e045d8..3c696136 100644 --- a/dotscope/src/metadata/tables/assemblyref/assemblyrefhash.rs +++ b/dotscope/src/metadata/tables/assemblyref/assemblyrefhash.rs @@ -83,7 +83,7 @@ use std::fmt::Write; /// # Performance /// Pre-allocates output string with exact capacity to avoid reallocations. fn bytes_to_hex(bytes: &[u8]) -> String { - let mut hex_string = String::with_capacity(bytes.len() * 2); + let mut hex_string = String::with_capacity(bytes.len().saturating_mul(2)); for byte in bytes { let _ = write!(&mut hex_string, "{byte:02x}"); } diff --git a/dotscope/src/metadata/tables/assemblyref/raw.rs b/dotscope/src/metadata/tables/assemblyref/raw.rs index 038c24cf..7254af53 100644 --- a/dotscope/src/metadata/tables/assemblyref/raw.rs +++ b/dotscope/src/metadata/tables/assemblyref/raw.rs @@ -21,7 +21,7 @@ //! //! The `AssemblyRef` table (0x23) contains zero or more rows with these fields: //! - **`MajorVersion`** (2 bytes): Major version number -//! - **`MinorVersion`** (2 bytes): Minor version number +//! - **`MinorVersion`** (2 bytes): Minor version number //! - **`BuildNumber`** (2 bytes): Build number //! - **`RevisionNumber`** (2 bytes): Revision number //! - **Flags** (4 bytes): Assembly flags bitmask @@ -240,15 +240,16 @@ impl TableRow for AssemblyRefRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* major_version */ 2 + - /* minor_version */ 2 + - /* build_number */ 2 + - /* revision_number */ 2 + - /* flags */ 4 + - /* public_key_or_token */ sizes.blob_bytes() + - /* name */ sizes.str_bytes() + - /* culture */ sizes.str_bytes() + - /* hash_value */ sizes.blob_bytes() + /* major_version */ 2u8 + /* minor_version */ .saturating_add(2) + /* build_number */ .saturating_add(2) + /* revision_number */ .saturating_add(2) + /* flags */ .saturating_add(4) + /* public_key_or_token */ .saturating_add(sizes.blob_bytes()) + /* name */ .saturating_add(sizes.str_bytes()) + /* culture */ .saturating_add(sizes.str_bytes()) + /* hash_value */ .saturating_add(sizes.blob_bytes()) + ) } } diff --git a/dotscope/src/metadata/tables/assemblyref/reader.rs b/dotscope/src/metadata/tables/assemblyref/reader.rs index 460e5b16..16be8b20 100644 --- a/dotscope/src/metadata/tables/assemblyref/reader.rs +++ b/dotscope/src/metadata/tables/assemblyref/reader.rs @@ -74,7 +74,7 @@ impl RowReadable for AssemblyRefRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(AssemblyRefRaw { rid, - token: Token::new(0x2300_0000 + rid), + token: Token::new(0x2300_0000u32.saturating_add(rid)), offset: *offset, major_version: u32::from(read_le_at::(data, offset)?), minor_version: u32::from(read_le_at::(data, offset)?), diff --git a/dotscope/src/metadata/tables/assemblyrefos/raw.rs b/dotscope/src/metadata/tables/assemblyrefos/raw.rs index 499b79df..b732fc5d 100644 --- a/dotscope/src/metadata/tables/assemblyrefos/raw.rs +++ b/dotscope/src/metadata/tables/assemblyrefos/raw.rs @@ -129,7 +129,7 @@ pub struct AssemblyRefOsRaw { /// /// Specifies the target operating system family. Common values include: /// - 1: Windows 32-bit - /// - 2: Windows 64-bit + /// - 2: Windows 64-bit /// - Other values for various platform types pub os_platform_id: u32, @@ -229,7 +229,7 @@ impl AssemblyRefOsRaw { /// /// Returns [`crate::Error`] if: /// - The referenced `AssemblyRef` entry cannot be found in the provided map - /// - The assembly reference token is invalid or malformed + /// - The assembly reference token is invalid or malformed /// - The `AssemblyRef` table index is out of bounds /// /// # Thread Safety @@ -277,10 +277,11 @@ impl TableRow for AssemblyRefOsRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* os_platform_id */ 4 + - /* os_major_version */ 4 + - /* os_minor_version */ 4 + - /* assembly_ref */ sizes.table_index_bytes(TableId::AssemblyRef) + /* os_platform_id */ 4u8 + /* os_major_version */ .saturating_add(4) + /* os_minor_version */ .saturating_add(4) + /* assembly_ref */ .saturating_add(sizes.table_index_bytes(TableId::AssemblyRef)) + ) } } diff --git a/dotscope/src/metadata/tables/assemblyrefos/reader.rs b/dotscope/src/metadata/tables/assemblyrefos/reader.rs index 60a795a4..b1e95424 100644 --- a/dotscope/src/metadata/tables/assemblyrefos/reader.rs +++ b/dotscope/src/metadata/tables/assemblyrefos/reader.rs @@ -55,7 +55,7 @@ impl RowReadable for AssemblyRefOsRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(AssemblyRefOsRaw { rid, - token: Token::new(0x2500_0000 + rid), + token: Token::new(0x2500_0000u32.saturating_add(rid)), offset: *offset, os_platform_id: read_le_at::(data, offset)?, os_major_version: read_le_at::(data, offset)?, diff --git a/dotscope/src/metadata/tables/assemblyrefprocessor/raw.rs b/dotscope/src/metadata/tables/assemblyrefprocessor/raw.rs index 78decd9d..30a1d528 100644 --- a/dotscope/src/metadata/tables/assemblyrefprocessor/raw.rs +++ b/dotscope/src/metadata/tables/assemblyrefprocessor/raw.rs @@ -263,8 +263,9 @@ impl TableRow for AssemblyRefProcessorRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* processor */ 4 + - /* assembly_ref */ sizes.table_index_bytes(TableId::AssemblyRef) + /* processor */ 4u8 + /* assembly_ref */ .saturating_add(sizes.table_index_bytes(TableId::AssemblyRef)) + ) } } diff --git a/dotscope/src/metadata/tables/assemblyrefprocessor/reader.rs b/dotscope/src/metadata/tables/assemblyrefprocessor/reader.rs index 4de43497..f515691b 100644 --- a/dotscope/src/metadata/tables/assemblyrefprocessor/reader.rs +++ b/dotscope/src/metadata/tables/assemblyrefprocessor/reader.rs @@ -53,7 +53,7 @@ impl RowReadable for AssemblyRefProcessorRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(AssemblyRefProcessorRaw { rid, - token: Token::new(0x2400_0000 + rid), + token: Token::new(0x2400_0000u32.saturating_add(rid)), offset: *offset, processor: read_le_at::(data, offset)?, assembly_ref: read_le_at_dyn(data, offset, sizes.is_large(TableId::AssemblyRef))?, diff --git a/dotscope/src/metadata/tables/classlayout/builder.rs b/dotscope/src/metadata/tables/classlayout/builder.rs index b90aca95..d2f137dc 100644 --- a/dotscope/src/metadata/tables/classlayout/builder.rs +++ b/dotscope/src/metadata/tables/classlayout/builder.rs @@ -274,7 +274,7 @@ impl ClassLayoutBuilder { .parent .ok_or_else(|| Error::ModificationInvalid("Parent type is required".to_string()))?; - if packing_size != 0 && (packing_size & (packing_size - 1)) != 0 { + if packing_size != 0 && (packing_size & packing_size.saturating_sub(1)) != 0 { return Err(Error::ModificationInvalid(format!( "Packing size must be 0 or a power of 2, got {packing_size}" ))); diff --git a/dotscope/src/metadata/tables/classlayout/raw.rs b/dotscope/src/metadata/tables/classlayout/raw.rs index 16c721fc..f3d35cd5 100644 --- a/dotscope/src/metadata/tables/classlayout/raw.rs +++ b/dotscope/src/metadata/tables/classlayout/raw.rs @@ -280,9 +280,10 @@ impl TableRow for ClassLayoutRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* packing_size */ 2 + - /* class_size */ 4 + - /* parent */ sizes.table_index_bytes(TableId::TypeDef) + /* packing_size */ 2u8 + /* class_size */ .saturating_add(4) + /* parent */ .saturating_add(sizes.table_index_bytes(TableId::TypeDef)) + ) } } diff --git a/dotscope/src/metadata/tables/classlayout/reader.rs b/dotscope/src/metadata/tables/classlayout/reader.rs index 4ad0a5d7..0c40e439 100644 --- a/dotscope/src/metadata/tables/classlayout/reader.rs +++ b/dotscope/src/metadata/tables/classlayout/reader.rs @@ -19,7 +19,7 @@ impl RowReadable for ClassLayoutRaw { Ok(ClassLayoutRaw { rid, - token: Token::new(0x0F00_0000 + rid), + token: Token::new(0x0F00_0000u32.saturating_add(rid)), offset: offset_org, packing_size, class_size, diff --git a/dotscope/src/metadata/tables/constant/raw.rs b/dotscope/src/metadata/tables/constant/raw.rs index 75db30e4..0b3bcb60 100644 --- a/dotscope/src/metadata/tables/constant/raw.rs +++ b/dotscope/src/metadata/tables/constant/raw.rs @@ -10,7 +10,7 @@ //! The Constant table (0x0B) contains zero or more rows with these fields: //! - **Type** (1 byte): Element type of the constant (`ELEMENT_TYPE`_* enumeration) //! - **Padding** (1 byte): Reserved padding byte (must be zero) -//! - **Parent** (2/4 bytes): `HasConstant` coded index into Field, Property, or Param tables +//! - **Parent** (2/4 bytes): `HasConstant` coded index into Field, Property, or Param tables //! - **Value** (2/4 bytes): Blob heap index containing the constant's binary data //! //! # Reference @@ -241,10 +241,11 @@ impl TableRow for ConstantRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* base */ 1 + - /* padding */ 1 + - /* parent */ sizes.coded_index_bytes(CodedIndexType::HasConstant) + - /* value */ sizes.blob_bytes() + /* base */ 1u8 + /* padding */ .saturating_add(1) + /* parent */ .saturating_add(sizes.coded_index_bytes(CodedIndexType::HasConstant)) + /* value */ .saturating_add(sizes.blob_bytes()) + ) } } diff --git a/dotscope/src/metadata/tables/constant/reader.rs b/dotscope/src/metadata/tables/constant/reader.rs index 2b7e04a7..7511c883 100644 --- a/dotscope/src/metadata/tables/constant/reader.rs +++ b/dotscope/src/metadata/tables/constant/reader.rs @@ -14,11 +14,11 @@ impl RowReadable for ConstantRaw { let offset_org = *offset; let c_type = read_le_at::(data, offset)?; - *offset += 1; // Padding + *offset = offset.saturating_add(1); // Padding Ok(ConstantRaw { rid, - token: Token::new(0x0B00_0000 + rid), + token: Token::new(0x0B00_0000u32.saturating_add(rid)), offset: offset_org, base: c_type, parent: CodedIndex::read(data, offset, sizes, CodedIndexType::HasConstant)?, diff --git a/dotscope/src/metadata/tables/customattribute/raw.rs b/dotscope/src/metadata/tables/customattribute/raw.rs index de813dc1..a9c47906 100644 --- a/dotscope/src/metadata/tables/customattribute/raw.rs +++ b/dotscope/src/metadata/tables/customattribute/raw.rs @@ -333,9 +333,10 @@ impl TableRow for CustomAttributeRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* parent */ sizes.coded_index_bytes(CodedIndexType::HasCustomAttribute) + - /* constructor */ sizes.coded_index_bytes(CodedIndexType::CustomAttributeType) + - /* value */ sizes.blob_bytes() + /* parent */ sizes.coded_index_bytes(CodedIndexType::HasCustomAttribute) + /* constructor */ .saturating_add(sizes.coded_index_bytes(CodedIndexType::CustomAttributeType)) + /* value */ .saturating_add(sizes.blob_bytes()) + ) } } diff --git a/dotscope/src/metadata/tables/customattribute/reader.rs b/dotscope/src/metadata/tables/customattribute/reader.rs index 2c759191..2ac88fc6 100644 --- a/dotscope/src/metadata/tables/customattribute/reader.rs +++ b/dotscope/src/metadata/tables/customattribute/reader.rs @@ -15,7 +15,7 @@ impl RowReadable for CustomAttributeRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(CustomAttributeRaw { rid, - token: Token::new(0x0C00_0000 + rid), + token: Token::new(0x0C00_0000u32.saturating_add(rid)), offset: *offset, parent: CodedIndex::read(data, offset, sizes, CodedIndexType::HasCustomAttribute)?, constructor: CodedIndex::read( diff --git a/dotscope/src/metadata/tables/customdebuginformation/raw.rs b/dotscope/src/metadata/tables/customdebuginformation/raw.rs index 62089acc..061b5b0e 100644 --- a/dotscope/src/metadata/tables/customdebuginformation/raw.rs +++ b/dotscope/src/metadata/tables/customdebuginformation/raw.rs @@ -163,7 +163,7 @@ impl CustomDebugInformationRaw { /// /// Returns [`crate::Error`] if: /// - The GUID heap index is invalid or out of bounds - /// - The blob heap index is invalid or out of bounds + /// - The blob heap index is invalid or out of bounds /// - The blob data cannot be parsed for known debug info types pub fn to_owned( &self, @@ -206,9 +206,10 @@ impl TableRow for CustomDebugInformationRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* parent */ sizes.coded_index_bytes(CodedIndexType::HasCustomDebugInformation) + - /* kind */ sizes.guid_bytes() + - /* value */ sizes.blob_bytes() + /* parent */ sizes.coded_index_bytes(CodedIndexType::HasCustomDebugInformation) + /* kind */ .saturating_add(sizes.guid_bytes()) + /* value */ .saturating_add(sizes.blob_bytes()) + ) } } diff --git a/dotscope/src/metadata/tables/customdebuginformation/reader.rs b/dotscope/src/metadata/tables/customdebuginformation/reader.rs index 4cf657f6..bb1689dc 100644 --- a/dotscope/src/metadata/tables/customdebuginformation/reader.rs +++ b/dotscope/src/metadata/tables/customdebuginformation/reader.rs @@ -44,7 +44,7 @@ impl RowReadable for CustomDebugInformationRaw { Ok(CustomDebugInformationRaw { rid, - token: Token::new(0x3700_0000 + rid), + token: Token::new(0x3700_0000u32.saturating_add(rid)), offset: offset_org, parent, kind, diff --git a/dotscope/src/metadata/tables/declsecurity/raw.rs b/dotscope/src/metadata/tables/declsecurity/raw.rs index cee00c38..9a3c07db 100644 --- a/dotscope/src/metadata/tables/declsecurity/raw.rs +++ b/dotscope/src/metadata/tables/declsecurity/raw.rs @@ -247,9 +247,10 @@ impl TableRow for DeclSecurityRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* action */ 2 + - /* parent */ sizes.coded_index_bytes(CodedIndexType::HasDeclSecurity) + - /* permission_set */ sizes.blob_bytes() + /* action */ 2u8 + /* parent */ .saturating_add(sizes.coded_index_bytes(CodedIndexType::HasDeclSecurity)) + /* permission_set */ .saturating_add(sizes.blob_bytes()) + ) } } diff --git a/dotscope/src/metadata/tables/declsecurity/reader.rs b/dotscope/src/metadata/tables/declsecurity/reader.rs index 24a83b0a..c7369823 100644 --- a/dotscope/src/metadata/tables/declsecurity/reader.rs +++ b/dotscope/src/metadata/tables/declsecurity/reader.rs @@ -34,7 +34,7 @@ impl RowReadable for DeclSecurityRaw { Ok(DeclSecurityRaw { rid, - token: Token::new(0x0E00_0000 + rid), + token: Token::new(0x0E00_0000u32.saturating_add(rid)), offset: offset_org, action, parent: CodedIndex::read(data, offset, sizes, CodedIndexType::HasDeclSecurity)?, diff --git a/dotscope/src/metadata/tables/document/raw.rs b/dotscope/src/metadata/tables/document/raw.rs index b7e3e8e9..788f79a3 100644 --- a/dotscope/src/metadata/tables/document/raw.rs +++ b/dotscope/src/metadata/tables/document/raw.rs @@ -186,10 +186,11 @@ impl TableRow for DocumentRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - sizes.blob_bytes() + // name - sizes.guid_bytes() + // hash_algorithm - sizes.blob_bytes() + // hash - sizes.guid_bytes() // language + /* name */ sizes.blob_bytes() + /* hash_algorithm */ .saturating_add(sizes.guid_bytes()) + /* hash */ .saturating_add(sizes.blob_bytes()) + /* language */ .saturating_add(sizes.guid_bytes()) + ) } } diff --git a/dotscope/src/metadata/tables/document/reader.rs b/dotscope/src/metadata/tables/document/reader.rs index 3dbc5e7f..e05d5f3a 100644 --- a/dotscope/src/metadata/tables/document/reader.rs +++ b/dotscope/src/metadata/tables/document/reader.rs @@ -13,7 +13,7 @@ impl RowReadable for DocumentRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(DocumentRaw { rid, - token: Token::new(0x3000_0000 + rid), + token: Token::new(0x3000_0000u32.saturating_add(rid)), offset: *offset, name: read_le_at_dyn(data, offset, sizes.is_large_blob())?, hash_algorithm: read_le_at_dyn(data, offset, sizes.is_large_guid())?, diff --git a/dotscope/src/metadata/tables/enclog/raw.rs b/dotscope/src/metadata/tables/enclog/raw.rs index dfcd64a2..0c8bf20d 100644 --- a/dotscope/src/metadata/tables/enclog/raw.rs +++ b/dotscope/src/metadata/tables/enclog/raw.rs @@ -66,7 +66,7 @@ use crate::{ /// /// All fields contain direct integer values rather than heap indexes: /// - No string heap references -/// - No blob heap references +/// - No blob heap references /// - All data is self-contained within the table row /// /// # Reference @@ -101,7 +101,7 @@ pub struct EncLogRaw { /// /// 4-byte value specifying what type of Edit-and-Continue operation was performed: /// - 0: Create - New metadata item added during edit session - /// - 1: Update - Existing metadata item modified during edit session + /// - 1: Update - Existing metadata item modified during edit session /// - 2: Delete - Metadata item marked for deletion during edit session pub func_code: u32, } diff --git a/dotscope/src/metadata/tables/enclog/reader.rs b/dotscope/src/metadata/tables/enclog/reader.rs index 83a3bb1f..1eac56f1 100644 --- a/dotscope/src/metadata/tables/enclog/reader.rs +++ b/dotscope/src/metadata/tables/enclog/reader.rs @@ -27,7 +27,7 @@ impl RowReadable for EncLogRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, _sizes: &TableInfoRef) -> Result { Ok(EncLogRaw { rid, - token: Token::new(0x1E00_0000 + rid), + token: Token::new(0x1E00_0000u32.saturating_add(rid)), offset: *offset, token_value: read_le_at::(data, offset)?, func_code: read_le_at::(data, offset)?, diff --git a/dotscope/src/metadata/tables/encmap/raw.rs b/dotscope/src/metadata/tables/encmap/raw.rs index a1343ff7..2d8167af 100644 --- a/dotscope/src/metadata/tables/encmap/raw.rs +++ b/dotscope/src/metadata/tables/encmap/raw.rs @@ -157,7 +157,8 @@ impl TableRow for EncMapRaw { /// /// ## Returns /// Always returns 4 bytes for the fixed token field. + #[rustfmt::skip] fn row_size(_sizes: &TableInfoRef) -> u32 { - 4 // Token field (4 bytes) + 4 // token } } diff --git a/dotscope/src/metadata/tables/encmap/reader.rs b/dotscope/src/metadata/tables/encmap/reader.rs index 02e473b1..8998d2ec 100644 --- a/dotscope/src/metadata/tables/encmap/reader.rs +++ b/dotscope/src/metadata/tables/encmap/reader.rs @@ -30,7 +30,7 @@ impl RowReadable for EncMapRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, _sizes: &TableInfoRef) -> Result { Ok(EncMapRaw { rid, - token: Token::new(0x1F00_0000 + rid), + token: Token::new(0x1F00_0000u32.saturating_add(rid)), offset: *offset, original_token: Token::new(read_le_at::(data, offset)?), }) diff --git a/dotscope/src/metadata/tables/event/raw.rs b/dotscope/src/metadata/tables/event/raw.rs index 9ebd5f92..5ea7f50c 100644 --- a/dotscope/src/metadata/tables/event/raw.rs +++ b/dotscope/src/metadata/tables/event/raw.rs @@ -177,9 +177,10 @@ impl TableRow for EventRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* flags */ 2 + - /* name */ sizes.str_bytes() + - /* event_type */ sizes.coded_index_bytes(CodedIndexType::TypeDefOrRef) + /* flags */ 2u8 + /* name */ .saturating_add(sizes.str_bytes()) + /* event_type */ .saturating_add(sizes.coded_index_bytes(CodedIndexType::TypeDefOrRef)) + ) } } diff --git a/dotscope/src/metadata/tables/event/reader.rs b/dotscope/src/metadata/tables/event/reader.rs index f3a88e8a..c693208b 100644 --- a/dotscope/src/metadata/tables/event/reader.rs +++ b/dotscope/src/metadata/tables/event/reader.rs @@ -19,7 +19,7 @@ impl RowReadable for EventRaw { Ok(EventRaw { rid, - token: Token::new(0x1400_0000 + rid), + token: Token::new(0x1400_0000u32.saturating_add(rid)), offset: offset_org, flags, name, diff --git a/dotscope/src/metadata/tables/eventmap/raw.rs b/dotscope/src/metadata/tables/eventmap/raw.rs index 035169d7..ecabd5c2 100644 --- a/dotscope/src/metadata/tables/eventmap/raw.rs +++ b/dotscope/src/metadata/tables/eventmap/raw.rs @@ -133,10 +133,13 @@ impl EventMapRaw { return Ok(Arc::new(boxcar::Vec::new())); } - let next_row_id = self.rid + 1; + let next_row_id = self + .rid + .checked_add(1) + .ok_or_else(|| malformed_error!("EventMap rid overflow: {}", self.rid))?; let start = self.event_list as usize; let end = if next_row_id > map.row_count { - events.len() + 1 + events.len().saturating_add(1) } else { match map.get(next_row_id) { Some(next_row) => next_row.event_list as usize, @@ -149,11 +152,11 @@ impl EventMapRaw { } }; - if start > events.len() || end > (events.len() + 1) || end < start { + if start > events.len() || end > events.len().saturating_add(1) || end < start { return Ok(Arc::new(boxcar::Vec::new())); } - let event_list = Arc::new(boxcar::Vec::with_capacity(end - start)); + let event_list = Arc::new(boxcar::Vec::with_capacity(end.saturating_sub(start))); for counter in start..end { let actual_event_token = if event_ptr.is_empty() { let token_value = counter | 0x1400_0000; @@ -326,8 +329,9 @@ impl TableRow for EventMapRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* parent */ sizes.table_index_bytes(TableId::TypeDef) + - /* event_list */ sizes.table_index_bytes(TableId::Event) + /* parent */ sizes.table_index_bytes(TableId::TypeDef) + /* event_list */ .saturating_add(sizes.table_index_bytes(TableId::Event)) + ) } } diff --git a/dotscope/src/metadata/tables/eventmap/reader.rs b/dotscope/src/metadata/tables/eventmap/reader.rs index 6c31a64a..c8553dfc 100644 --- a/dotscope/src/metadata/tables/eventmap/reader.rs +++ b/dotscope/src/metadata/tables/eventmap/reader.rs @@ -41,7 +41,7 @@ impl RowReadable for EventMapRaw { Ok(EventMapRaw { rid, - token: Token::new(0x1200_0000 + rid), + token: Token::new(0x1200_0000u32.saturating_add(rid)), offset: offset_org, parent, event_list, diff --git a/dotscope/src/metadata/tables/eventptr/raw.rs b/dotscope/src/metadata/tables/eventptr/raw.rs index 0b058627..2cc4522e 100644 --- a/dotscope/src/metadata/tables/eventptr/raw.rs +++ b/dotscope/src/metadata/tables/eventptr/raw.rs @@ -157,6 +157,7 @@ impl TableRow for EventPtrRaw { fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( /* event */ sizes.table_index_bytes(TableId::Event) + ) } } diff --git a/dotscope/src/metadata/tables/eventptr/reader.rs b/dotscope/src/metadata/tables/eventptr/reader.rs index 463a2096..0bce0b0c 100644 --- a/dotscope/src/metadata/tables/eventptr/reader.rs +++ b/dotscope/src/metadata/tables/eventptr/reader.rs @@ -36,7 +36,7 @@ impl RowReadable for EventPtrRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(EventPtrRaw { rid, - token: Token::new(0x1300_0000 + rid), + token: Token::new(0x1300_0000u32.saturating_add(rid)), offset: *offset, event: read_le_at_dyn(data, offset, sizes.is_large(TableId::Event))?, }) diff --git a/dotscope/src/metadata/tables/exportedtype/raw.rs b/dotscope/src/metadata/tables/exportedtype/raw.rs index ce22afd1..eda1e8c6 100644 --- a/dotscope/src/metadata/tables/exportedtype/raw.rs +++ b/dotscope/src/metadata/tables/exportedtype/raw.rs @@ -269,11 +269,12 @@ impl TableRow for ExportedTypeRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* flags */ 4 + - /* type_def_id */ 4 + - /* type_name */ sizes.str_bytes() + - /* type_namespace */ sizes.str_bytes() + - /* implementation */ sizes.coded_index_bytes(CodedIndexType::Implementation) + /* flags */ 4u8 + /* type_def_id */ .saturating_add(4) + /* type_name */ .saturating_add(sizes.str_bytes()) + /* type_namespace */ .saturating_add(sizes.str_bytes()) + /* implementation */ .saturating_add(sizes.coded_index_bytes(CodedIndexType::Implementation)) + ) } } diff --git a/dotscope/src/metadata/tables/exportedtype/reader.rs b/dotscope/src/metadata/tables/exportedtype/reader.rs index 5b29c455..4ebd9aeb 100644 --- a/dotscope/src/metadata/tables/exportedtype/reader.rs +++ b/dotscope/src/metadata/tables/exportedtype/reader.rs @@ -37,7 +37,7 @@ impl RowReadable for ExportedTypeRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(ExportedTypeRaw { rid, - token: Token::new(0x2700_0000 + rid), + token: Token::new(0x2700_0000u32.saturating_add(rid)), offset: *offset, flags: read_le_at::(data, offset)?, type_def_id: read_le_at::(data, offset)?, diff --git a/dotscope/src/metadata/tables/field/builder.rs b/dotscope/src/metadata/tables/field/builder.rs index 34c6d442..3243fa98 100644 --- a/dotscope/src/metadata/tables/field/builder.rs +++ b/dotscope/src/metadata/tables/field/builder.rs @@ -136,7 +136,7 @@ impl FieldBuilder { /// /// # Returns /// - /// A [`crate::metadata::token::Token`] representing the newly created field, or an error if + /// A [`Token`] representing the newly created field, or an error if /// validation fails or required fields are missing. /// /// # Errors diff --git a/dotscope/src/metadata/tables/field/raw.rs b/dotscope/src/metadata/tables/field/raw.rs index 82275979..f907de9c 100644 --- a/dotscope/src/metadata/tables/field/raw.rs +++ b/dotscope/src/metadata/tables/field/raw.rs @@ -72,7 +72,7 @@ pub struct FieldRaw { /// Common values: /// - `0x0001`: `CompilerControlled` /// - `0x0002`: Private - /// - `0x0007`: Public + /// - `0x0007`: Public /// - `0x0010`: Static /// - `0x0020`: Literal /// - `0x0080`: `HasFieldRVA` @@ -165,9 +165,10 @@ impl TableRow for FieldRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* flags */ 2 + - /* name */ sizes.str_bytes() + - /* signature */ sizes.blob_bytes() + /* flags */ 2u8 + /* name */ .saturating_add(sizes.str_bytes()) + /* signature */ .saturating_add(sizes.blob_bytes()) + ) } } diff --git a/dotscope/src/metadata/tables/field/reader.rs b/dotscope/src/metadata/tables/field/reader.rs index b764ae8c..a35fcf07 100644 --- a/dotscope/src/metadata/tables/field/reader.rs +++ b/dotscope/src/metadata/tables/field/reader.rs @@ -13,7 +13,7 @@ impl RowReadable for FieldRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(FieldRaw { rid, - token: Token::new(0x0400_0000 + rid), + token: Token::new(0x0400_0000u32.saturating_add(rid)), offset: *offset, flags: u32::from(read_le_at::(data, offset)?), name: read_le_at_dyn(data, offset, sizes.is_large_str())?, diff --git a/dotscope/src/metadata/tables/fieldlayout/raw.rs b/dotscope/src/metadata/tables/fieldlayout/raw.rs index e6a119f9..bdd5b56e 100644 --- a/dotscope/src/metadata/tables/fieldlayout/raw.rs +++ b/dotscope/src/metadata/tables/fieldlayout/raw.rs @@ -174,8 +174,9 @@ impl TableRow for FieldLayoutRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* field_offset */ 4 + - /* field */ sizes.table_index_bytes(TableId::Field) + /* field_offset */ 4u8 + /* field */ .saturating_add(sizes.table_index_bytes(TableId::Field)) + ) } } diff --git a/dotscope/src/metadata/tables/fieldlayout/reader.rs b/dotscope/src/metadata/tables/fieldlayout/reader.rs index 4a53124e..d4bcf657 100644 --- a/dotscope/src/metadata/tables/fieldlayout/reader.rs +++ b/dotscope/src/metadata/tables/fieldlayout/reader.rs @@ -33,7 +33,7 @@ impl RowReadable for FieldLayoutRaw { Ok(FieldLayoutRaw { rid, - token: Token::new(0x1000_0000 + rid), + token: Token::new(0x1000_0000u32.saturating_add(rid)), offset: offset_org, field_offset, field, diff --git a/dotscope/src/metadata/tables/fieldmarshal/raw.rs b/dotscope/src/metadata/tables/fieldmarshal/raw.rs index 2a1daa5e..0961b20d 100644 --- a/dotscope/src/metadata/tables/fieldmarshal/raw.rs +++ b/dotscope/src/metadata/tables/fieldmarshal/raw.rs @@ -223,8 +223,9 @@ impl TableRow for FieldMarshalRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* parent */ sizes.coded_index_bytes(CodedIndexType::HasFieldMarshal) + - /* native_type */ sizes.blob_bytes() + /* parent */ sizes.coded_index_bytes(CodedIndexType::HasFieldMarshal) + /* native_type */ .saturating_add(sizes.blob_bytes()) + ) } } diff --git a/dotscope/src/metadata/tables/fieldmarshal/reader.rs b/dotscope/src/metadata/tables/fieldmarshal/reader.rs index 826c3bf4..5e880294 100644 --- a/dotscope/src/metadata/tables/fieldmarshal/reader.rs +++ b/dotscope/src/metadata/tables/fieldmarshal/reader.rs @@ -32,7 +32,7 @@ impl RowReadable for FieldMarshalRaw { Ok(FieldMarshalRaw { rid, - token: Token::new(0x0D00_0000 + rid), + token: Token::new(0x0D00_0000u32.saturating_add(rid)), offset: offset_org, parent: CodedIndex::read(data, offset, sizes, CodedIndexType::HasFieldMarshal)?, native_type: read_le_at_dyn(data, offset, sizes.is_large_blob())?, diff --git a/dotscope/src/metadata/tables/fieldptr/raw.rs b/dotscope/src/metadata/tables/fieldptr/raw.rs index abfedd5d..2506e0bc 100644 --- a/dotscope/src/metadata/tables/fieldptr/raw.rs +++ b/dotscope/src/metadata/tables/fieldptr/raw.rs @@ -140,6 +140,7 @@ impl TableRow for FieldPtrRaw { fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( /* field */ sizes.table_index_bytes(TableId::Field) + ) } } diff --git a/dotscope/src/metadata/tables/fieldptr/reader.rs b/dotscope/src/metadata/tables/fieldptr/reader.rs index 115aefe7..064747b3 100644 --- a/dotscope/src/metadata/tables/fieldptr/reader.rs +++ b/dotscope/src/metadata/tables/fieldptr/reader.rs @@ -13,7 +13,7 @@ impl RowReadable for FieldPtrRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(FieldPtrRaw { rid, - token: Token::new(0x0300_0000 + rid), + token: Token::new(0x0300_0000u32.saturating_add(rid)), offset: *offset, field: read_le_at_dyn(data, offset, sizes.is_large(TableId::Field))?, }) diff --git a/dotscope/src/metadata/tables/fieldrva/raw.rs b/dotscope/src/metadata/tables/fieldrva/raw.rs index 51dad537..003b09f6 100644 --- a/dotscope/src/metadata/tables/fieldrva/raw.rs +++ b/dotscope/src/metadata/tables/fieldrva/raw.rs @@ -165,8 +165,9 @@ impl TableRow for FieldRvaRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* rva */ 4 + - /* field */ sizes.table_index_bytes(TableId::Field) + /* rva */ 4u8 + /* field */ .saturating_add(sizes.table_index_bytes(TableId::Field)) + ) } } diff --git a/dotscope/src/metadata/tables/fieldrva/reader.rs b/dotscope/src/metadata/tables/fieldrva/reader.rs index 98b22731..974a11a2 100644 --- a/dotscope/src/metadata/tables/fieldrva/reader.rs +++ b/dotscope/src/metadata/tables/fieldrva/reader.rs @@ -13,7 +13,7 @@ impl RowReadable for FieldRvaRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(FieldRvaRaw { rid, - token: Token::new(0x1D00_0000 + rid), + token: Token::new(0x1D00_0000u32.saturating_add(rid)), offset: *offset, rva: read_le_at::(data, offset)?, field: read_le_at_dyn(data, offset, sizes.is_large(TableId::Field))?, diff --git a/dotscope/src/metadata/tables/file/raw.rs b/dotscope/src/metadata/tables/file/raw.rs index bc714923..dda5cf0a 100644 --- a/dotscope/src/metadata/tables/file/raw.rs +++ b/dotscope/src/metadata/tables/file/raw.rs @@ -178,9 +178,10 @@ impl TableRow for FileRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* flags */ 4 + - /* name */ sizes.str_bytes() + - /* hash_value */ sizes.blob_bytes() + /* flags */ 4u8 + /* name */ .saturating_add(sizes.str_bytes()) + /* hash_value */ .saturating_add(sizes.blob_bytes()) + ) } } diff --git a/dotscope/src/metadata/tables/file/reader.rs b/dotscope/src/metadata/tables/file/reader.rs index 768bf7b6..d6886da7 100644 --- a/dotscope/src/metadata/tables/file/reader.rs +++ b/dotscope/src/metadata/tables/file/reader.rs @@ -13,7 +13,7 @@ impl RowReadable for FileRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(FileRaw { rid, - token: Token::new(0x2600_0000 + rid), + token: Token::new(0x2600_0000u32.saturating_add(rid)), offset: *offset, flags: read_le_at::(data, offset)?, name: read_le_at_dyn(data, offset, sizes.is_large_str())?, diff --git a/dotscope/src/metadata/tables/genericparam/raw.rs b/dotscope/src/metadata/tables/genericparam/raw.rs index bd9fdd58..0bb0eac8 100644 --- a/dotscope/src/metadata/tables/genericparam/raw.rs +++ b/dotscope/src/metadata/tables/genericparam/raw.rs @@ -194,10 +194,11 @@ impl TableRow for GenericParamRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* number */ 2 + - /* flags */ 2 + - /* owner */ sizes.coded_index_bytes(CodedIndexType::TypeOrMethodDef) + - /* name */ sizes.str_bytes() + /* number */ 2u8 + /* flags */ .saturating_add(2) + /* owner */ .saturating_add(sizes.coded_index_bytes(CodedIndexType::TypeOrMethodDef)) + /* name */ .saturating_add(sizes.str_bytes()) + ) } } diff --git a/dotscope/src/metadata/tables/genericparam/reader.rs b/dotscope/src/metadata/tables/genericparam/reader.rs index fbb0a167..59c355c1 100644 --- a/dotscope/src/metadata/tables/genericparam/reader.rs +++ b/dotscope/src/metadata/tables/genericparam/reader.rs @@ -13,7 +13,7 @@ impl RowReadable for GenericParamRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(GenericParamRaw { rid, - token: Token::new(0x2A00_0000 + rid), + token: Token::new(0x2A00_0000u32.saturating_add(rid)), offset: *offset, number: u32::from(read_le_at::(data, offset)?), flags: u32::from(read_le_at::(data, offset)?), diff --git a/dotscope/src/metadata/tables/genericparamconstraint/raw.rs b/dotscope/src/metadata/tables/genericparamconstraint/raw.rs index ef55a649..244fe67c 100644 --- a/dotscope/src/metadata/tables/genericparamconstraint/raw.rs +++ b/dotscope/src/metadata/tables/genericparamconstraint/raw.rs @@ -235,8 +235,9 @@ impl TableRow for GenericParamConstraintRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* owner */ sizes.table_index_bytes(TableId::GenericParam) + - /* constraint */ sizes.coded_index_bytes(CodedIndexType::TypeDefOrRef) + /* owner */ sizes.table_index_bytes(TableId::GenericParam) + /* constraint */ .saturating_add(sizes.coded_index_bytes(CodedIndexType::TypeDefOrRef)) + ) } } diff --git a/dotscope/src/metadata/tables/genericparamconstraint/reader.rs b/dotscope/src/metadata/tables/genericparamconstraint/reader.rs index a1d9cbef..c1eb35a1 100644 --- a/dotscope/src/metadata/tables/genericparamconstraint/reader.rs +++ b/dotscope/src/metadata/tables/genericparamconstraint/reader.rs @@ -16,7 +16,7 @@ impl RowReadable for GenericParamConstraintRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(GenericParamConstraintRaw { rid, - token: Token::new(0x2C00_0000 + rid), + token: Token::new(0x2C00_0000u32.saturating_add(rid)), offset: *offset, owner: read_le_at_dyn(data, offset, sizes.is_large(TableId::GenericParam))?, constraint: CodedIndex::read(data, offset, sizes, CodedIndexType::TypeDefOrRef)?, diff --git a/dotscope/src/metadata/tables/implmap/raw.rs b/dotscope/src/metadata/tables/implmap/raw.rs index 9786d7be..c7be38bd 100644 --- a/dotscope/src/metadata/tables/implmap/raw.rs +++ b/dotscope/src/metadata/tables/implmap/raw.rs @@ -245,10 +245,11 @@ impl TableRow for ImplMapRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* mapping_flags */ 2 + - /* member_forwarded */ sizes.coded_index_bytes(CodedIndexType::MemberForwarded) + - /* import_name */ sizes.str_bytes() + - /* import_scope */ sizes.table_index_bytes(TableId::ModuleRef) + /* mapping_flags */ 2u8 + /* member_forwarded */ .saturating_add(sizes.coded_index_bytes(CodedIndexType::MemberForwarded)) + /* import_name */ .saturating_add(sizes.str_bytes()) + /* import_scope */ .saturating_add(sizes.table_index_bytes(TableId::ModuleRef)) + ) } } diff --git a/dotscope/src/metadata/tables/implmap/reader.rs b/dotscope/src/metadata/tables/implmap/reader.rs index b0125e26..cec637a6 100644 --- a/dotscope/src/metadata/tables/implmap/reader.rs +++ b/dotscope/src/metadata/tables/implmap/reader.rs @@ -28,7 +28,7 @@ impl RowReadable for ImplMapRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(ImplMapRaw { rid, - token: Token::new(0x1C00_0000 + rid), + token: Token::new(0x1C00_0000u32.saturating_add(rid)), offset: *offset, mapping_flags: u32::from(read_le_at::(data, offset)?), member_forwarded: CodedIndex::read( diff --git a/dotscope/src/metadata/tables/importscope/raw.rs b/dotscope/src/metadata/tables/importscope/raw.rs index 4af2112b..34c9fb18 100644 --- a/dotscope/src/metadata/tables/importscope/raw.rs +++ b/dotscope/src/metadata/tables/importscope/raw.rs @@ -121,8 +121,9 @@ impl TableRow for ImportScopeRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* parent */ sizes.table_index_bytes(TableId::ImportScope) + - /* imports */ sizes.blob_bytes() + /* parent */ sizes.table_index_bytes(TableId::ImportScope) + /* imports */ .saturating_add(sizes.blob_bytes()) + ) } } diff --git a/dotscope/src/metadata/tables/importscope/reader.rs b/dotscope/src/metadata/tables/importscope/reader.rs index 6743c1c1..28fae389 100644 --- a/dotscope/src/metadata/tables/importscope/reader.rs +++ b/dotscope/src/metadata/tables/importscope/reader.rs @@ -13,7 +13,7 @@ impl RowReadable for ImportScopeRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(ImportScopeRaw { rid, - token: Token::new(0x3500_0000 + rid), + token: Token::new(0x3500_0000u32.saturating_add(rid)), offset: *offset, parent: read_le_at_dyn(data, offset, sizes.is_large(TableId::ImportScope))?, imports: read_le_at_dyn(data, offset, sizes.is_large_blob())?, diff --git a/dotscope/src/metadata/tables/interfaceimpl/raw.rs b/dotscope/src/metadata/tables/interfaceimpl/raw.rs index d58ad542..5722641d 100644 --- a/dotscope/src/metadata/tables/interfaceimpl/raw.rs +++ b/dotscope/src/metadata/tables/interfaceimpl/raw.rs @@ -179,8 +179,9 @@ impl TableRow for InterfaceImplRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* class */ sizes.table_index_bytes(TableId::TypeDef) + - /* interface */ sizes.coded_index_bytes(CodedIndexType::TypeDefOrRef) + /* class */ sizes.table_index_bytes(TableId::TypeDef) + /* interface */ .saturating_add(sizes.coded_index_bytes(CodedIndexType::TypeDefOrRef)) + ) } } diff --git a/dotscope/src/metadata/tables/interfaceimpl/reader.rs b/dotscope/src/metadata/tables/interfaceimpl/reader.rs index df2126a1..1a30f69d 100644 --- a/dotscope/src/metadata/tables/interfaceimpl/reader.rs +++ b/dotscope/src/metadata/tables/interfaceimpl/reader.rs @@ -30,7 +30,7 @@ impl RowReadable for InterfaceImplRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(InterfaceImplRaw { rid, - token: Token::new(0x0900_0000 + rid), + token: Token::new(0x0900_0000u32.saturating_add(rid)), offset: *offset, class: read_le_at_dyn(data, offset, sizes.is_large(TableId::TypeDef))?, interface: CodedIndex::read(data, offset, sizes, CodedIndexType::TypeDefOrRef)?, diff --git a/dotscope/src/metadata/tables/localconstant/raw.rs b/dotscope/src/metadata/tables/localconstant/raw.rs index 7e69999a..a0f0f4d2 100644 --- a/dotscope/src/metadata/tables/localconstant/raw.rs +++ b/dotscope/src/metadata/tables/localconstant/raw.rs @@ -130,8 +130,9 @@ impl TableRow for LocalConstantRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* name */ sizes.str_bytes() + - /* signature */ sizes.blob_bytes() + /* name */ sizes.str_bytes() + /* signature */ .saturating_add(sizes.blob_bytes()) + ) } } diff --git a/dotscope/src/metadata/tables/localconstant/reader.rs b/dotscope/src/metadata/tables/localconstant/reader.rs index 937ffe1b..d23bbca1 100644 --- a/dotscope/src/metadata/tables/localconstant/reader.rs +++ b/dotscope/src/metadata/tables/localconstant/reader.rs @@ -13,7 +13,7 @@ impl RowReadable for LocalConstantRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(LocalConstantRaw { rid, - token: Token::new(0x3400_0000 + rid), + token: Token::new(0x3400_0000u32.saturating_add(rid)), offset: *offset, name: read_le_at_dyn(data, offset, sizes.is_large_str())?, signature: read_le_at_dyn(data, offset, sizes.is_large_blob())?, diff --git a/dotscope/src/metadata/tables/localscope/owned.rs b/dotscope/src/metadata/tables/localscope/owned.rs index 1d6f4594..b5027c04 100644 --- a/dotscope/src/metadata/tables/localscope/owned.rs +++ b/dotscope/src/metadata/tables/localscope/owned.rs @@ -98,7 +98,7 @@ impl LocalScope { /// ``` #[must_use] pub fn end_offset(&self) -> u32 { - self.start_offset + self.length + self.start_offset.saturating_add(self.length) } /// Checks if this scope contains any local variables diff --git a/dotscope/src/metadata/tables/localscope/raw.rs b/dotscope/src/metadata/tables/localscope/raw.rs index 395cd392..57f7b130 100644 --- a/dotscope/src/metadata/tables/localscope/raw.rs +++ b/dotscope/src/metadata/tables/localscope/raw.rs @@ -29,7 +29,7 @@ use std::sync::Arc; /// /// Each `LocalScope` table entry consists of: /// - Method: Simple index into `MethodDef` table -/// - `ImportScope`: Simple index into `ImportScope` table +/// - `ImportScope`: Simple index into `ImportScope` table /// - `VariableList`: Simple index into `LocalVariable` table /// - `ConstantList`: Simple index into `LocalConstant` table /// - `StartOffset`: 4-byte unsigned integer (IL offset) @@ -112,50 +112,76 @@ impl LocalScopeRaw { constants: &LocalConstantMap, scope_table: &MetadataTable, ) -> Result { - let method_token = Token::new(0x0600_0000 + self.method); + let method_token = Token::new( + 0x0600_0000_u32 + .checked_add(self.method) + .ok_or_else(|| malformed_error!("Method token overflow: {}", self.method))?, + ); let method = methods .get(&method_token) .ok_or_else(|| malformed_error!("Invalid method index {} in LocalScope", self.method))? .value() .clone(); - let import_scope = if self.import_scope == 0 { - None - } else { - let import_token = Token::new(0x3500_0000 + self.import_scope); - Some( - import_scopes - .get(&import_token) - .ok_or_else(|| { - malformed_error!( - "Invalid import scope index {} in LocalScope", - self.import_scope - ) - })? - .value() - .clone(), + let import_scope = + if self.import_scope == 0 { + None + } else { + let import_token = + Token::new(0x3500_0000_u32.checked_add(self.import_scope).ok_or_else( + || malformed_error!("ImportScope token overflow: {}", self.import_scope), + )?); + Some( + import_scopes + .get(&import_token) + .ok_or_else(|| { + malformed_error!( + "Invalid import scope index {} in LocalScope", + self.import_scope + ) + })? + .value() + .clone(), + ) + }; + + let next_rid = self.rid.checked_add(1).ok_or_else(|| { + malformed_error!( + "LocalScope rid overflow when computing next rid: {}", + self.rid ) - }; + })?; let variables = if self.variable_list == 0 { Arc::new(boxcar::Vec::new()) } else { let start = self.variable_list; - #[allow(clippy::cast_possible_truncation)] - let end = if let Some(next_scope) = scope_table.get(self.rid + 1) { + let end = if let Some(next_scope) = scope_table.get(next_rid) { if next_scope.variable_list != 0 { next_scope.variable_list } else { - variables.len() as u32 + 1 + let len = u32::try_from(variables.len()).map_err(|_| { + malformed_error!("LocalVariable count exceeds u32: {}", variables.len()) + })?; + len.checked_add(1).ok_or_else(|| { + malformed_error!("LocalVariable end index overflow: {}", len) + })? } } else { - variables.len() as u32 + 1 + let len = u32::try_from(variables.len()).map_err(|_| { + malformed_error!("LocalVariable count exceeds u32: {}", variables.len()) + })?; + len.checked_add(1) + .ok_or_else(|| malformed_error!("LocalVariable end index overflow: {}", len))? }; let list = Arc::new(boxcar::Vec::new()); for i in start..end { - let var_token = Token::new(0x3300_0000 + i); + let var_token_value = 0x3300_0000_u32 + .checked_add(i) + .ok_or_else(|| malformed_error!("LocalVariable token overflow: {}", i))?; + let var_token = Token::new(var_token_value); if let Some(var_entry) = variables.get(&var_token) { list.push(var_entry.value().clone()); } @@ -168,20 +194,31 @@ impl LocalScopeRaw { } else { let start = self.constant_list; - #[allow(clippy::cast_possible_truncation)] - let end = if let Some(next_scope) = scope_table.get(self.rid + 1) { + let end = if let Some(next_scope) = scope_table.get(next_rid) { if next_scope.constant_list != 0 { next_scope.constant_list } else { - constants.len() as u32 + 1 + let len = u32::try_from(constants.len()).map_err(|_| { + malformed_error!("LocalConstant count exceeds u32: {}", constants.len()) + })?; + len.checked_add(1).ok_or_else(|| { + malformed_error!("LocalConstant end index overflow: {}", len) + })? } } else { - constants.len() as u32 + 1 + let len = u32::try_from(constants.len()).map_err(|_| { + malformed_error!("LocalConstant count exceeds u32: {}", constants.len()) + })?; + len.checked_add(1) + .ok_or_else(|| malformed_error!("LocalConstant end index overflow: {}", len))? }; let list = Arc::new(boxcar::Vec::new()); for i in start..end { - let const_token = Token::new(0x3400_0000 + i); + let const_token_value = 0x3400_0000_u32 + .checked_add(i) + .ok_or_else(|| malformed_error!("LocalConstant token overflow: {}", i))?; + let const_token = Token::new(const_token_value); if let Some(const_entry) = constants.get(&const_token) { list.push(const_entry.value().clone()); } @@ -220,12 +257,13 @@ impl TableRow for LocalScopeRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* method */ sizes.table_index_bytes(TableId::MethodDef) + - /* import_scope */ sizes.table_index_bytes(TableId::ImportScope) + - /* variable_list */ sizes.table_index_bytes(TableId::LocalVariable) + - /* constant_list */ sizes.table_index_bytes(TableId::LocalConstant) + - /* start_offset */ 4 + - /* length */ 4 + /* method */ sizes.table_index_bytes(TableId::MethodDef) + /* import_scope */ .saturating_add(sizes.table_index_bytes(TableId::ImportScope)) + /* variable_list */ .saturating_add(sizes.table_index_bytes(TableId::LocalVariable)) + /* constant_list */ .saturating_add(sizes.table_index_bytes(TableId::LocalConstant)) + /* start_offset */ .saturating_add(4) + /* length */ .saturating_add(4) + ) } } diff --git a/dotscope/src/metadata/tables/localscope/reader.rs b/dotscope/src/metadata/tables/localscope/reader.rs index d9264894..54c02f53 100644 --- a/dotscope/src/metadata/tables/localscope/reader.rs +++ b/dotscope/src/metadata/tables/localscope/reader.rs @@ -13,7 +13,7 @@ impl RowReadable for LocalScopeRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(LocalScopeRaw { rid, - token: Token::new(0x3200_0000 + rid), + token: Token::new(0x3200_0000u32.saturating_add(rid)), offset: *offset, method: read_le_at_dyn(data, offset, sizes.is_large(TableId::MethodDef))?, import_scope: read_le_at_dyn(data, offset, sizes.is_large(TableId::ImportScope))?, diff --git a/dotscope/src/metadata/tables/localvariable/raw.rs b/dotscope/src/metadata/tables/localvariable/raw.rs index 0ea62557..1d8f7d4b 100644 --- a/dotscope/src/metadata/tables/localvariable/raw.rs +++ b/dotscope/src/metadata/tables/localvariable/raw.rs @@ -131,9 +131,10 @@ impl TableRow for LocalVariableRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - 2 + // attributes (always 2 bytes) - 2 + // index (always 2 bytes) - sizes.str_bytes() // name (strings heap index) + /* attributes */ 2u8 + /* index */ .saturating_add(2) + /* name */ .saturating_add(sizes.str_bytes()) + ) } } diff --git a/dotscope/src/metadata/tables/localvariable/reader.rs b/dotscope/src/metadata/tables/localvariable/reader.rs index 8bc9d186..41a40d5f 100644 --- a/dotscope/src/metadata/tables/localvariable/reader.rs +++ b/dotscope/src/metadata/tables/localvariable/reader.rs @@ -16,7 +16,7 @@ impl RowReadable for LocalVariableRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(LocalVariableRaw { rid, - token: Token::new(0x3300_0000 + rid), + token: Token::new(0x3300_0000u32.saturating_add(rid)), offset: *offset, attributes: read_le_at::(data, offset)?, index: read_le_at::(data, offset)?, diff --git a/dotscope/src/metadata/tables/manifestresource/raw.rs b/dotscope/src/metadata/tables/manifestresource/raw.rs index 2b982bb7..986df167 100644 --- a/dotscope/src/metadata/tables/manifestresource/raw.rs +++ b/dotscope/src/metadata/tables/manifestresource/raw.rs @@ -135,24 +135,63 @@ impl ManifestResourceRaw { // The resource format is: [4-byte length prefix][data bytes] // offset_field points to the length prefix let section_start = file.rva_to_offset(cor20.resource_rva as usize)?; - let length_prefix_offset = section_start + self.offset_field as usize; + let length_prefix_offset = section_start + .checked_add(self.offset_field as usize) + .ok_or_else(|| { + malformed_error!( + "ManifestResource length-prefix offset overflow: {} + {}", + section_start, + self.offset_field + ) + })?; // Read the 4-byte length prefix to get actual data size if let Ok(len_bytes) = file.data_slice(length_prefix_offset, 4) { - data_size = - u32::from_le_bytes([len_bytes[0], len_bytes[1], len_bytes[2], len_bytes[3]]) - as usize; + let b0 = *len_bytes.first().ok_or(out_of_bounds_error!())?; + let b1 = *len_bytes.get(1).ok_or(out_of_bounds_error!())?; + let b2 = *len_bytes.get(2).ok_or(out_of_bounds_error!())?; + let b3 = *len_bytes.get(3).ok_or(out_of_bounds_error!())?; + data_size = u32::from_le_bytes([b0, b1, b2, b3]) as usize; } else { // Fallback: calculate from offset differences (includes length prefix) - data_size = if let Some(next_res) = table.get(self.rid + 1) { - (next_res.offset_field as usize - self.offset_field as usize).saturating_sub(4) + let next_rid = self.rid.checked_add(1).ok_or_else(|| { + malformed_error!( + "ManifestResource rid overflow when computing next rid: {}", + self.rid + ) + })?; + data_size = if let Some(next_res) = table.get(next_rid) { + (next_res.offset_field as usize) + .checked_sub(self.offset_field as usize) + .ok_or_else(|| { + malformed_error!( + "ManifestResource offsets out of order: next={} current={}", + next_res.offset_field, + self.offset_field + ) + })? + .saturating_sub(4) } else { - (cor20.resource_size as usize - self.offset_field as usize).saturating_sub(4) + (cor20.resource_size as usize) + .checked_sub(self.offset_field as usize) + .ok_or_else(|| { + malformed_error!( + "ManifestResource offset {} exceeds resource section size {}", + self.offset_field, + cor20.resource_size + ) + })? + .saturating_sub(4) }; } // data_offset points to actual data (after the 4-byte length prefix) - data_offset = length_prefix_offset + 4; + data_offset = length_prefix_offset.checked_add(4).ok_or_else(|| { + malformed_error!( + "ManifestResource data offset overflow: {} + 4", + length_prefix_offset + ) + })?; None } else { let implementation = get_ref(&self.implementation); @@ -208,10 +247,11 @@ impl TableRow for ManifestResourceRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* offset_field */ 4 + - /* flags */ 4 + - /* name */ sizes.str_bytes() + - /* implementation */ sizes.coded_index_bytes(CodedIndexType::Implementation) + /* offset */ 4u8 + /* flags */ .saturating_add(4) + /* name */ .saturating_add(sizes.str_bytes()) + /* implementation */ .saturating_add(sizes.coded_index_bytes(CodedIndexType::Implementation)) + ) } } diff --git a/dotscope/src/metadata/tables/manifestresource/reader.rs b/dotscope/src/metadata/tables/manifestresource/reader.rs index fe84d601..efb3dd79 100644 --- a/dotscope/src/metadata/tables/manifestresource/reader.rs +++ b/dotscope/src/metadata/tables/manifestresource/reader.rs @@ -15,7 +15,7 @@ impl RowReadable for ManifestResourceRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(ManifestResourceRaw { rid, - token: Token::new(0x2800_0000 + rid), + token: Token::new(0x2800_0000u32.saturating_add(rid)), offset: *offset, offset_field: read_le_at::(data, offset)?, flags: read_le_at::(data, offset)?, diff --git a/dotscope/src/metadata/tables/memberref/raw.rs b/dotscope/src/metadata/tables/memberref/raw.rs index ce8ec0fa..7fa04f16 100644 --- a/dotscope/src/metadata/tables/memberref/raw.rs +++ b/dotscope/src/metadata/tables/memberref/raw.rs @@ -121,7 +121,9 @@ impl MemberRefRaw { method_sig: &SignatureMethod, _strings: &Strings, ) -> Arc> { - let params = Arc::new(boxcar::Vec::with_capacity(method_sig.params.len() + 1)); + let params = Arc::new(boxcar::Vec::with_capacity( + method_sig.params.len().saturating_add(1), + )); // Create return parameter (sequence 0) let return_param = Arc::new(Param { @@ -148,7 +150,7 @@ impl MemberRefRaw { offset: 0, flags: ParamAttributes::ZERO, #[allow(clippy::cast_possible_truncation)] - sequence: (index + 1) as u32, // Parameter sequence starts at 1 + sequence: index.saturating_add(1) as u32, // Parameter sequence starts at 1 name: None, // MemberRef parameters don't have names from metadata default: OnceLock::new(), marshal: OnceLock::new(), @@ -225,7 +227,7 @@ impl MemberRefRaw { return Err(malformed_error!("Invalid signature data")); } - let (signature, params) = if signature_data[0] == 0x6 { + let (signature, params) = if *signature_data.first().ok_or(out_of_bounds_error!())? == 0x6 { ( MemberRefSignature::Field(parse_field_signature(signature_data)?), Arc::new(boxcar::Vec::new()), @@ -245,7 +247,7 @@ impl MemberRefRaw { )?; } else { // Regular parameter - let index = (param.sequence - 1) as usize; + let index = param.sequence.saturating_sub(1) as usize; if let Some(param_signature) = method_sig.params.get(index) { param.apply_signature( param_signature, @@ -337,9 +339,10 @@ impl TableRow for MemberRefRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* class */ sizes.coded_index_bytes(CodedIndexType::MemberRefParent) + - /* name */ sizes.str_bytes() + - /* signature */ sizes.blob_bytes() + /* class */ sizes.coded_index_bytes(CodedIndexType::MemberRefParent) + /* name */ .saturating_add(sizes.str_bytes()) + /* signature */ .saturating_add(sizes.blob_bytes()) + ) } } diff --git a/dotscope/src/metadata/tables/memberref/reader.rs b/dotscope/src/metadata/tables/memberref/reader.rs index 013afc61..5c5202a9 100644 --- a/dotscope/src/metadata/tables/memberref/reader.rs +++ b/dotscope/src/metadata/tables/memberref/reader.rs @@ -13,7 +13,7 @@ impl RowReadable for MemberRefRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(MemberRefRaw { rid, - token: Token::new(0x0A00_0000 + rid), + token: Token::new(0x0A00_0000u32.saturating_add(rid)), offset: *offset, class: CodedIndex::read(data, offset, sizes, CodedIndexType::MemberRefParent)?, name: read_le_at_dyn(data, offset, sizes.is_large_str())?, diff --git a/dotscope/src/metadata/tables/methoddebuginformation/raw.rs b/dotscope/src/metadata/tables/methoddebuginformation/raw.rs index 8e41aefe..a60760c4 100644 --- a/dotscope/src/metadata/tables/methoddebuginformation/raw.rs +++ b/dotscope/src/metadata/tables/methoddebuginformation/raw.rs @@ -222,8 +222,9 @@ impl TableRow for MethodDebugInformationRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - sizes.table_index_bytes(TableId::Document) + // document - sizes.blob_bytes() // sequence_points + /* document */ sizes.table_index_bytes(TableId::Document) + /* sequence_points */ .saturating_add(sizes.blob_bytes()) + ) } } diff --git a/dotscope/src/metadata/tables/methoddebuginformation/reader.rs b/dotscope/src/metadata/tables/methoddebuginformation/reader.rs index 4cfdd395..19bde398 100644 --- a/dotscope/src/metadata/tables/methoddebuginformation/reader.rs +++ b/dotscope/src/metadata/tables/methoddebuginformation/reader.rs @@ -13,7 +13,7 @@ impl RowReadable for MethodDebugInformationRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(MethodDebugInformationRaw { rid, - token: Token::new(0x3100_0000 + rid), + token: Token::new(0x3100_0000u32.saturating_add(rid)), offset: *offset, document: read_le_at_dyn(data, offset, sizes.is_large(TableId::Document))?, sequence_points: read_le_at_dyn(data, offset, sizes.is_large_blob())?, diff --git a/dotscope/src/metadata/tables/methoddef/raw.rs b/dotscope/src/metadata/tables/methoddef/raw.rs index 901bd130..0e293916 100644 --- a/dotscope/src/metadata/tables/methoddef/raw.rs +++ b/dotscope/src/metadata/tables/methoddef/raw.rs @@ -246,10 +246,13 @@ impl MethodDefRaw { let type_params = if self.param_list == 0 || params_map.is_empty() { Arc::new(boxcar::Vec::new()) } else { - let next_row_id = self.rid + 1; + let next_row_id = self + .rid + .checked_add(1) + .ok_or_else(|| malformed_error!("MethodDef rid overflow: {}", self.rid))?; let start = self.param_list as usize; let end = if next_row_id > table.row_count { - params_map.len() + 1 + params_map.len().saturating_add(1) } else { match table.get(next_row_id) { Some(next_row) => next_row.param_list as usize, @@ -262,10 +265,10 @@ impl MethodDefRaw { } }; - if start > params_map.len() || end > (params_map.len() + 1) || end < start { + if start > params_map.len() || end > params_map.len().saturating_add(1) || end < start { Arc::new(boxcar::Vec::new()) } else { - let type_params = Arc::new(boxcar::Vec::with_capacity(end - start)); + let type_params = Arc::new(boxcar::Vec::with_capacity(end.saturating_sub(start))); for counter in start..end { let actual_param_token = if param_ptr_map.is_empty() { let token_value = u32::try_from(counter | 0x0800_0000).map_err(|_| { @@ -382,12 +385,13 @@ impl TableRow for MethodDefRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* rva */ 4 + - /* impl_flags */ 2 + - /* flags */ 2 + - /* name */ sizes.str_bytes() + - /* signature */ sizes.blob_bytes() + - /* param_list */ sizes.table_index_bytes(TableId::Param) + /* rva */ 4u8 + /* impl_flags */ .saturating_add(2) + /* flags */ .saturating_add(2) + /* name */ .saturating_add(sizes.str_bytes()) + /* signature */ .saturating_add(sizes.blob_bytes()) + /* param_list */ .saturating_add(sizes.table_index_bytes(TableId::Param)) + ) } } diff --git a/dotscope/src/metadata/tables/methoddef/reader.rs b/dotscope/src/metadata/tables/methoddef/reader.rs index 13ceaf22..1d50a68b 100644 --- a/dotscope/src/metadata/tables/methoddef/reader.rs +++ b/dotscope/src/metadata/tables/methoddef/reader.rs @@ -13,7 +13,7 @@ impl RowReadable for MethodDefRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(MethodDefRaw { rid, - token: Token::new(0x0600_0000 + rid), + token: Token::new(0x0600_0000u32.saturating_add(rid)), offset: *offset, rva: read_le_at::(data, offset)?, impl_flags: u32::from(read_le_at::(data, offset)?), diff --git a/dotscope/src/metadata/tables/methodimpl/raw.rs b/dotscope/src/metadata/tables/methodimpl/raw.rs index 0cf59669..939f36eb 100644 --- a/dotscope/src/metadata/tables/methodimpl/raw.rs +++ b/dotscope/src/metadata/tables/methodimpl/raw.rs @@ -302,9 +302,10 @@ impl TableRow for MethodImplRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* class */ sizes.table_index_bytes(TableId::TypeDef) + - /* method_body */ sizes.coded_index_bytes(CodedIndexType::MethodDefOrRef) + - /* method_declaration */ sizes.coded_index_bytes(CodedIndexType::MethodDefOrRef) + /* class */ sizes.table_index_bytes(TableId::TypeDef) + /* method_body */ .saturating_add(sizes.coded_index_bytes(CodedIndexType::MethodDefOrRef)) + /* method_declaration */ .saturating_add(sizes.coded_index_bytes(CodedIndexType::MethodDefOrRef)) + ) } } diff --git a/dotscope/src/metadata/tables/methodimpl/reader.rs b/dotscope/src/metadata/tables/methodimpl/reader.rs index 9d3a2688..9d328428 100644 --- a/dotscope/src/metadata/tables/methodimpl/reader.rs +++ b/dotscope/src/metadata/tables/methodimpl/reader.rs @@ -13,7 +13,7 @@ impl RowReadable for MethodImplRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(MethodImplRaw { rid, - token: Token::new(0x1900_0000 + rid), + token: Token::new(0x1900_0000u32.saturating_add(rid)), offset: *offset, class: read_le_at_dyn(data, offset, sizes.is_large(TableId::TypeDef))?, method_body: CodedIndex::read(data, offset, sizes, CodedIndexType::MethodDefOrRef)?, diff --git a/dotscope/src/metadata/tables/methodptr/raw.rs b/dotscope/src/metadata/tables/methodptr/raw.rs index 9a7af0c9..353b1f60 100644 --- a/dotscope/src/metadata/tables/methodptr/raw.rs +++ b/dotscope/src/metadata/tables/methodptr/raw.rs @@ -149,6 +149,7 @@ impl TableRow for MethodPtrRaw { fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( /* method */ sizes.table_index_bytes(TableId::MethodDef) + ) } } diff --git a/dotscope/src/metadata/tables/methodptr/reader.rs b/dotscope/src/metadata/tables/methodptr/reader.rs index 81aed724..118bdf42 100644 --- a/dotscope/src/metadata/tables/methodptr/reader.rs +++ b/dotscope/src/metadata/tables/methodptr/reader.rs @@ -13,7 +13,7 @@ impl RowReadable for MethodPtrRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(MethodPtrRaw { rid, - token: Token::new(0x0500_0000 + rid), + token: Token::new(0x0500_0000u32.saturating_add(rid)), offset: *offset, method: read_le_at_dyn(data, offset, sizes.is_large(TableId::MethodDef))?, }) diff --git a/dotscope/src/metadata/tables/methodsemantics/mod.rs b/dotscope/src/metadata/tables/methodsemantics/mod.rs index 8271cde7..828a3fef 100644 --- a/dotscope/src/metadata/tables/methodsemantics/mod.rs +++ b/dotscope/src/metadata/tables/methodsemantics/mod.rs @@ -66,7 +66,7 @@ pub(crate) use loader::*; pub use owned::*; pub use raw::*; -/// Thread-safe map holding the mapping of [`crate::metadata::token::Token`] to parsed [`MethodSemantics`] entries. +/// Thread-safe map holding the mapping of [`Token`] to parsed [`MethodSemantics`] entries. /// /// This concurrent skip list provides efficient O(log n) access to method semantics entries /// by their metadata token, supporting multiple concurrent readers and writers. diff --git a/dotscope/src/metadata/tables/methodsemantics/raw.rs b/dotscope/src/metadata/tables/methodsemantics/raw.rs index 1988dfb9..91a5ea37 100644 --- a/dotscope/src/metadata/tables/methodsemantics/raw.rs +++ b/dotscope/src/metadata/tables/methodsemantics/raw.rs @@ -163,7 +163,7 @@ pub struct MethodSemanticsRaw { /// /// 2-byte value defining the method's semantic role using [`MethodSemanticsAttributes`]: /// - `SETTER` (0x0001) - Property setter method - /// - `GETTER` (0x0002) - Property getter method + /// - `GETTER` (0x0002) - Property getter method /// - `OTHER` (0x0004) - Other property/event method /// - `ADD_ON` (0x0008) - Event add method /// - `REMOVE_ON` (0x0010) - Event remove method @@ -369,9 +369,10 @@ impl TableRow for MethodSemanticsRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* semantics */ 2 + - /* method */ sizes.table_index_bytes(TableId::MethodDef) + - /* association */ sizes.coded_index_bytes(CodedIndexType::HasSemantics) + /* semantics */ 2u8 + /* method */ .saturating_add(sizes.table_index_bytes(TableId::MethodDef)) + /* association */ .saturating_add(sizes.coded_index_bytes(CodedIndexType::HasSemantics)) + ) } } diff --git a/dotscope/src/metadata/tables/methodsemantics/reader.rs b/dotscope/src/metadata/tables/methodsemantics/reader.rs index 622f515e..4631ea5a 100644 --- a/dotscope/src/metadata/tables/methodsemantics/reader.rs +++ b/dotscope/src/metadata/tables/methodsemantics/reader.rs @@ -36,7 +36,7 @@ impl RowReadable for MethodSemanticsRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(MethodSemanticsRaw { rid, - token: Token::new(0x1800_0000 + rid), + token: Token::new(0x1800_0000u32.saturating_add(rid)), offset: *offset, semantics: u32::from(read_le_at::(data, offset)?), method: read_le_at_dyn(data, offset, sizes.is_large(TableId::MethodDef))?, diff --git a/dotscope/src/metadata/tables/methodspec/builder.rs b/dotscope/src/metadata/tables/methodspec/builder.rs index c84b6015..93bf0119 100644 --- a/dotscope/src/metadata/tables/methodspec/builder.rs +++ b/dotscope/src/metadata/tables/methodspec/builder.rs @@ -317,7 +317,12 @@ impl MethodSpecBuilder { )); } - let arg_count = instantiation[0]; + let arg_count = *instantiation.first().ok_or_else(|| { + Error::ModificationInvalid( + "Instantiation signature must contain at least the generic argument count" + .to_string(), + ) + })?; if arg_count == 0 { return Err(Error::ModificationInvalid( "Generic argument count cannot be zero".to_string(), diff --git a/dotscope/src/metadata/tables/methodspec/mod.rs b/dotscope/src/metadata/tables/methodspec/mod.rs index 848513f8..c84eda49 100644 --- a/dotscope/src/metadata/tables/methodspec/mod.rs +++ b/dotscope/src/metadata/tables/methodspec/mod.rs @@ -72,7 +72,7 @@ pub(crate) use loader::*; pub use owned::*; pub use raw::*; -/// Thread-safe map holding the mapping of [`crate::metadata::token::Token`] to parsed [`MethodSpec`] entries. +/// Thread-safe map holding the mapping of [`Token`] to parsed [`MethodSpec`] entries. /// /// This concurrent skip list provides efficient O(log n) access to method specification entries /// by their metadata token, supporting multiple concurrent readers and writers. diff --git a/dotscope/src/metadata/tables/methodspec/raw.rs b/dotscope/src/metadata/tables/methodspec/raw.rs index 30533e32..680ae24c 100644 --- a/dotscope/src/metadata/tables/methodspec/raw.rs +++ b/dotscope/src/metadata/tables/methodspec/raw.rs @@ -202,8 +202,9 @@ impl TableRow for MethodSpecRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* method */ sizes.coded_index_bytes(CodedIndexType::MethodDefOrRef) + - /* instantiation */ sizes.blob_bytes() + /* method */ sizes.coded_index_bytes(CodedIndexType::MethodDefOrRef) + /* instantiation */ .saturating_add(sizes.blob_bytes()) + ) } } diff --git a/dotscope/src/metadata/tables/methodspec/reader.rs b/dotscope/src/metadata/tables/methodspec/reader.rs index b52c2df9..f07444d1 100644 --- a/dotscope/src/metadata/tables/methodspec/reader.rs +++ b/dotscope/src/metadata/tables/methodspec/reader.rs @@ -33,7 +33,7 @@ impl RowReadable for MethodSpecRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(MethodSpecRaw { rid, - token: Token::new(0x2B00_0000 + rid), + token: Token::new(0x2B00_0000u32.saturating_add(rid)), offset: *offset, method: CodedIndex::read(data, offset, sizes, CodedIndexType::MethodDefOrRef)?, instantiation: read_le_at_dyn(data, offset, sizes.is_large_blob())?, diff --git a/dotscope/src/metadata/tables/module/builder.rs b/dotscope/src/metadata/tables/module/builder.rs index 434cb63e..f8b41a0f 100644 --- a/dotscope/src/metadata/tables/module/builder.rs +++ b/dotscope/src/metadata/tables/module/builder.rs @@ -312,7 +312,7 @@ impl ModuleBuilder { .name .ok_or_else(|| Error::ModificationInvalid("name field is required".to_string()))?; - let existing_count = assembly.next_rid(TableId::Module)? - 1; + let existing_count = assembly.next_rid(TableId::Module)?.saturating_sub(1); if existing_count > 0 { return Err(Error::ModificationInvalid( "Module table already contains an entry. Only one module per assembly is allowed." diff --git a/dotscope/src/metadata/tables/module/mod.rs b/dotscope/src/metadata/tables/module/mod.rs index 26644b48..7aa5229d 100644 --- a/dotscope/src/metadata/tables/module/mod.rs +++ b/dotscope/src/metadata/tables/module/mod.rs @@ -72,7 +72,7 @@ pub(crate) use loader::*; pub use owned::*; pub use raw::*; -/// Thread-safe map holding the mapping of [`crate::metadata::token::Token`] to parsed [`Module`] entries. +/// Thread-safe map holding the mapping of [`Token`] to parsed [`Module`] entries. /// /// This concurrent skip list provides efficient O(log n) access to module entries /// by their metadata token. Since the Module table contains only one entry, this diff --git a/dotscope/src/metadata/tables/module/raw.rs b/dotscope/src/metadata/tables/module/raw.rs index 91ce593c..1c70e30f 100644 --- a/dotscope/src/metadata/tables/module/raw.rs +++ b/dotscope/src/metadata/tables/module/raw.rs @@ -181,11 +181,12 @@ impl TableRow for ModuleRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* generation */ 2 + - /* name */ sizes.str_bytes() + - /* mvid */ sizes.guid_bytes() + - /* encid */ sizes.guid_bytes() + - /* encbaseid */ sizes.guid_bytes() + /* generation */ 2u8 + /* name */ .saturating_add(sizes.str_bytes()) + /* mvid */ .saturating_add(sizes.guid_bytes()) + /* encid */ .saturating_add(sizes.guid_bytes()) + /* encbaseid */ .saturating_add(sizes.guid_bytes()) + ) } } diff --git a/dotscope/src/metadata/tables/moduleref/raw.rs b/dotscope/src/metadata/tables/moduleref/raw.rs index ea8793e1..323a4983 100644 --- a/dotscope/src/metadata/tables/moduleref/raw.rs +++ b/dotscope/src/metadata/tables/moduleref/raw.rs @@ -142,6 +142,7 @@ impl TableRow for ModuleRefRaw { fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( /* name */ sizes.str_bytes() + ) } } diff --git a/dotscope/src/metadata/tables/moduleref/reader.rs b/dotscope/src/metadata/tables/moduleref/reader.rs index 736af8f1..dcf552de 100644 --- a/dotscope/src/metadata/tables/moduleref/reader.rs +++ b/dotscope/src/metadata/tables/moduleref/reader.rs @@ -82,7 +82,7 @@ impl RowReadable for ModuleRefRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(ModuleRefRaw { rid, - token: Token::new(0x1A00_0000 + rid), + token: Token::new(0x1A00_0000u32.saturating_add(rid)), offset: *offset, name: read_le_at_dyn(data, offset, sizes.is_large_str())?, }) diff --git a/dotscope/src/metadata/tables/nestedclass/raw.rs b/dotscope/src/metadata/tables/nestedclass/raw.rs index 78eb363c..b225c0af 100644 --- a/dotscope/src/metadata/tables/nestedclass/raw.rs +++ b/dotscope/src/metadata/tables/nestedclass/raw.rs @@ -207,8 +207,9 @@ impl TableRow for NestedClassRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* nested_class */ sizes.table_index_bytes(TableId::TypeDef) + - /* enclosing_class */ sizes.table_index_bytes(TableId::TypeDef) + /* nested_class */ sizes.table_index_bytes(TableId::TypeDef) + /* enclosing_class */ .saturating_add(sizes.table_index_bytes(TableId::TypeDef)) + ) } } diff --git a/dotscope/src/metadata/tables/nestedclass/reader.rs b/dotscope/src/metadata/tables/nestedclass/reader.rs index edccec1f..625d8a59 100644 --- a/dotscope/src/metadata/tables/nestedclass/reader.rs +++ b/dotscope/src/metadata/tables/nestedclass/reader.rs @@ -83,7 +83,7 @@ impl RowReadable for NestedClassRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(NestedClassRaw { rid, - token: Token::new(0x2900_0000 + rid), + token: Token::new(0x2900_0000u32.saturating_add(rid)), offset: *offset, nested_class: read_le_at_dyn(data, offset, sizes.is_large(TableId::TypeDef))?, enclosing_class: read_le_at_dyn(data, offset, sizes.is_large(TableId::TypeDef))?, diff --git a/dotscope/src/metadata/tables/param/raw.rs b/dotscope/src/metadata/tables/param/raw.rs index 4ed9f931..a12b025b 100644 --- a/dotscope/src/metadata/tables/param/raw.rs +++ b/dotscope/src/metadata/tables/param/raw.rs @@ -181,9 +181,10 @@ impl TableRow for ParamRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* flags */ 2 + - /* sequence */ 2 + - /* name */ sizes.str_bytes() + /* flags */ 2u8 + /* sequence */ .saturating_add(2) + /* name */ .saturating_add(sizes.str_bytes()) + ) } } diff --git a/dotscope/src/metadata/tables/param/reader.rs b/dotscope/src/metadata/tables/param/reader.rs index 20a77ca9..a2e93374 100644 --- a/dotscope/src/metadata/tables/param/reader.rs +++ b/dotscope/src/metadata/tables/param/reader.rs @@ -95,7 +95,7 @@ impl RowReadable for ParamRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(ParamRaw { rid, - token: Token::new(0x0800_0000 + rid), + token: Token::new(0x0800_0000u32.saturating_add(rid)), offset: *offset, flags: u32::from(read_le_at::(data, offset)?), sequence: u32::from(read_le_at::(data, offset)?), diff --git a/dotscope/src/metadata/tables/paramptr/raw.rs b/dotscope/src/metadata/tables/paramptr/raw.rs index b31d33a9..805283e1 100644 --- a/dotscope/src/metadata/tables/paramptr/raw.rs +++ b/dotscope/src/metadata/tables/paramptr/raw.rs @@ -148,6 +148,7 @@ impl TableRow for ParamPtrRaw { fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( /* param */ sizes.table_index_bytes(TableId::Param) + ) } } diff --git a/dotscope/src/metadata/tables/paramptr/reader.rs b/dotscope/src/metadata/tables/paramptr/reader.rs index c640d948..795cf227 100644 --- a/dotscope/src/metadata/tables/paramptr/reader.rs +++ b/dotscope/src/metadata/tables/paramptr/reader.rs @@ -84,7 +84,7 @@ impl RowReadable for ParamPtrRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(ParamPtrRaw { rid, - token: Token::new(0x0700_0000 + rid), + token: Token::new(0x0700_0000u32.saturating_add(rid)), offset: *offset, param: read_le_at_dyn(data, offset, sizes.is_large(TableId::Param))?, }) diff --git a/dotscope/src/metadata/tables/property/raw.rs b/dotscope/src/metadata/tables/property/raw.rs index b9bb639e..8ec31291 100644 --- a/dotscope/src/metadata/tables/property/raw.rs +++ b/dotscope/src/metadata/tables/property/raw.rs @@ -167,9 +167,10 @@ impl TableRow for PropertyRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* flags */ 2 + - /* name */ sizes.str_bytes() + - /* type_signature */ sizes.blob_bytes() + /* flags */ 2u8 + /* name */ .saturating_add(sizes.str_bytes()) + /* type_signature */ .saturating_add(sizes.blob_bytes()) + ) } } diff --git a/dotscope/src/metadata/tables/property/reader.rs b/dotscope/src/metadata/tables/property/reader.rs index 6d44f4a7..6a299100 100644 --- a/dotscope/src/metadata/tables/property/reader.rs +++ b/dotscope/src/metadata/tables/property/reader.rs @@ -87,7 +87,7 @@ impl RowReadable for PropertyRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(PropertyRaw { rid, - token: Token::new(0x1700_0000 + rid), + token: Token::new(0x1700_0000u32.saturating_add(rid)), offset: *offset, flags: u32::from(read_le_at::(data, offset)?), name: read_le_at_dyn(data, offset, sizes.is_large_str())?, diff --git a/dotscope/src/metadata/tables/propertymap/raw.rs b/dotscope/src/metadata/tables/propertymap/raw.rs index c4b32d27..11c935f8 100644 --- a/dotscope/src/metadata/tables/propertymap/raw.rs +++ b/dotscope/src/metadata/tables/propertymap/raw.rs @@ -102,10 +102,13 @@ impl PropertyMapRaw { return Ok(Arc::new(boxcar::Vec::new())); } - let next_row_id = self.rid + 1; + let next_row_id = self + .rid + .checked_add(1) + .ok_or_else(|| malformed_error!("PropertyMap rid overflow: {}", self.rid))?; let start = self.property_list as usize; let end = if next_row_id > map.row_count { - properties.len() + 1 + properties.len().saturating_add(1) } else { match map.get(next_row_id) { Some(next_row) => next_row.property_list as usize, @@ -118,11 +121,11 @@ impl PropertyMapRaw { } }; - if start > properties.len() || end > (properties.len() + 1) || end < start { + if start > properties.len() || end > properties.len().saturating_add(1) || end < start { return Ok(Arc::new(boxcar::Vec::new())); } - let property_list = Arc::new(boxcar::Vec::with_capacity(end - start)); + let property_list = Arc::new(boxcar::Vec::with_capacity(end.saturating_sub(start))); for counter in start..end { let actual_property_token = if property_ptr.is_empty() { let token_value = counter | 0x1700_0000; @@ -187,7 +190,7 @@ impl PropertyMapRaw { /// ## Arguments /// * `types` - The [`crate::metadata::typesystem::TypeRegistry`] for resolving parent types /// * `properties` - Map of all resolved `Property` entries for lookup - /// * `property_ptr` - Map of `PropertyPtr` entries for indirection resolution + /// * `property_ptr` - Map of `PropertyPtr` entries for indirection resolution /// * `map` - The `PropertyMap` table for determining property ranges /// /// ## Returns @@ -291,8 +294,9 @@ impl TableRow for PropertyMapRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* parent */ sizes.table_index_bytes(TableId::TypeDef) + - /* property_list */ sizes.table_index_bytes(TableId::Property) + /* parent */ sizes.table_index_bytes(TableId::TypeDef) + /* property_list */ .saturating_add(sizes.table_index_bytes(TableId::Property)) + ) } } diff --git a/dotscope/src/metadata/tables/propertymap/reader.rs b/dotscope/src/metadata/tables/propertymap/reader.rs index 30efa00c..e4aa06e9 100644 --- a/dotscope/src/metadata/tables/propertymap/reader.rs +++ b/dotscope/src/metadata/tables/propertymap/reader.rs @@ -91,7 +91,7 @@ impl RowReadable for PropertyMapRaw { Ok(PropertyMapRaw { rid, - token: Token::new(0x1500_0000 + rid), + token: Token::new(0x1500_0000u32.saturating_add(rid)), offset: offset_org, parent, property_list, diff --git a/dotscope/src/metadata/tables/propertyptr/raw.rs b/dotscope/src/metadata/tables/propertyptr/raw.rs index ad1d8057..fc26facf 100644 --- a/dotscope/src/metadata/tables/propertyptr/raw.rs +++ b/dotscope/src/metadata/tables/propertyptr/raw.rs @@ -123,6 +123,7 @@ impl TableRow for PropertyPtrRaw { fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( /* property */ sizes.table_index_bytes(TableId::Property) + ) } } diff --git a/dotscope/src/metadata/tables/propertyptr/reader.rs b/dotscope/src/metadata/tables/propertyptr/reader.rs index 4bb4e5d9..781a30d0 100644 --- a/dotscope/src/metadata/tables/propertyptr/reader.rs +++ b/dotscope/src/metadata/tables/propertyptr/reader.rs @@ -76,7 +76,7 @@ impl RowReadable for PropertyPtrRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(PropertyPtrRaw { rid, - token: Token::new(0x1600_0000 + rid), + token: Token::new(0x1600_0000u32.saturating_add(rid)), offset: *offset, property: read_le_at_dyn(data, offset, sizes.is_large(TableId::Property))?, }) diff --git a/dotscope/src/metadata/tables/standalonesig/raw.rs b/dotscope/src/metadata/tables/standalonesig/raw.rs index 9fd8b08b..9c64fba8 100644 --- a/dotscope/src/metadata/tables/standalonesig/raw.rs +++ b/dotscope/src/metadata/tables/standalonesig/raw.rs @@ -105,7 +105,9 @@ impl StandAloneSigRaw { )); } - let first_byte = sig_data[0]; + let first_byte = *sig_data.first().ok_or_else(|| { + malformed_error!("StandAloneSig blob is empty at index {}", self.signature) + })?; let parsed_signature = match first_byte { SIGNATURE_HEADER::LOCAL_SIG => { let locals = parse_local_var_signature(sig_data)?; @@ -179,6 +181,7 @@ impl TableRow for StandAloneSigRaw { fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( /* signature */ sizes.blob_bytes() + ) } } diff --git a/dotscope/src/metadata/tables/standalonesig/reader.rs b/dotscope/src/metadata/tables/standalonesig/reader.rs index 34ebc5d1..1e2b7dbd 100644 --- a/dotscope/src/metadata/tables/standalonesig/reader.rs +++ b/dotscope/src/metadata/tables/standalonesig/reader.rs @@ -81,7 +81,7 @@ impl RowReadable for StandAloneSigRaw { Ok(StandAloneSigRaw { rid, - token: Token::new(0x1100_0000 + rid), + token: Token::new(0x1100_0000u32.saturating_add(rid)), offset: offset_org, signature, }) diff --git a/dotscope/src/metadata/tables/statemachinemethod/raw.rs b/dotscope/src/metadata/tables/statemachinemethod/raw.rs index b593a00d..9f4d8c26 100644 --- a/dotscope/src/metadata/tables/statemachinemethod/raw.rs +++ b/dotscope/src/metadata/tables/statemachinemethod/raw.rs @@ -146,8 +146,9 @@ impl TableRow for StateMachineMethodRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - sizes.table_index_bytes(TableId::MethodDef) + // move_next_method (MethodDef table index) - sizes.table_index_bytes(TableId::MethodDef) // kickoff_method (MethodDef table index) + /* move_next_method */ sizes.table_index_bytes(TableId::MethodDef) + /* kickoff_method */ .saturating_add(sizes.table_index_bytes(TableId::MethodDef)) + ) } } diff --git a/dotscope/src/metadata/tables/statemachinemethod/reader.rs b/dotscope/src/metadata/tables/statemachinemethod/reader.rs index c791c9e8..7e87c304 100644 --- a/dotscope/src/metadata/tables/statemachinemethod/reader.rs +++ b/dotscope/src/metadata/tables/statemachinemethod/reader.rs @@ -53,7 +53,7 @@ impl RowReadable for StateMachineMethodRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(StateMachineMethodRaw { rid, - token: Token::new(0x3600_0000 + rid), + token: Token::new(0x3600_0000u32.saturating_add(rid)), offset: *offset, move_next_method: read_le_at_dyn(data, offset, sizes.is_large(TableId::MethodDef))?, kickoff_method: read_le_at_dyn(data, offset, sizes.is_large(TableId::MethodDef))?, diff --git a/dotscope/src/metadata/tables/typedef/raw.rs b/dotscope/src/metadata/tables/typedef/raw.rs index 51f0cda9..a6e2e957 100644 --- a/dotscope/src/metadata/tables/typedef/raw.rs +++ b/dotscope/src/metadata/tables/typedef/raw.rs @@ -114,7 +114,7 @@ impl TypeDefRaw { /// /// ## Arguments /// * `get_ref` - Closure to resolve coded indexes to type references - /// * `strings` - The #String heap for resolving names and namespaces + /// * `strings` - The #String heap for resolving names and namespaces /// * `fields` - Map of all processed Field entries indexed by token /// * `field_ptr` - Map of `FieldPtr` entries for indirection resolution /// * `methods` - Map of all processed Method entries indexed by token @@ -149,30 +149,53 @@ impl TypeDefRaw { where F: Fn(&CodedIndex) -> CilTypeReference, { - let (end_fields, end_methods) = if self.rid + 1 > defs.row_count { - (fields.len() + 1, methods.len() + 1) + let next_rid = self.rid.checked_add(1).ok_or_else(|| { + malformed_error!("TypeDef rid overflow when computing next rid: {}", self.rid) + })?; + + let (end_fields, end_methods) = if next_rid > defs.row_count { + let fields_end = fields + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("Field count overflow: {}", fields.len()))?; + let methods_end = methods + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("Method count overflow: {}", methods.len()))?; + (fields_end, methods_end) } else { - match defs.get(self.rid + 1) { + match defs.get(next_rid) { Some(next_row) => (next_row.field_list as usize, next_row.method_list as usize), None => { return Err(malformed_error!( "Failed to resolve fields_end from next row - {}", - self.rid + 1 + next_rid )) } } }; let start_fields = self.field_list as usize; + let fields_len_plus_one = fields + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("Field count overflow: {}", fields.len()))?; let type_fields = if self.field_list == 0 || fields.is_empty() - || end_fields > fields.len() + 1 + || end_fields > fields_len_plus_one || start_fields > fields.len() || end_fields <= start_fields { Arc::new(boxcar::Vec::new()) } else { - let type_fields = Arc::new(boxcar::Vec::with_capacity(end_fields - start_fields)); + let capacity = end_fields.checked_sub(start_fields).ok_or_else(|| { + malformed_error!( + "TypeDef field range invalid: end {} < start {}", + end_fields, + start_fields + ) + })?; + let type_fields = Arc::new(boxcar::Vec::with_capacity(capacity)); for counter in start_fields..end_fields { let actual_field_token = if field_ptr.is_empty() { Token::new(u32::try_from(counter | 0x0400_0000).map_err(|_| { @@ -226,15 +249,26 @@ impl TypeDefRaw { }; let start_methods = self.method_list as usize; + let methods_len_plus_one = methods + .len() + .checked_add(1) + .ok_or_else(|| malformed_error!("Method count overflow: {}", methods.len()))?; let type_methods = if self.method_list == 0 || methods.is_empty() - || end_methods > methods.len() + 1 + || end_methods > methods_len_plus_one || start_methods > methods.len() || end_methods < start_methods { Arc::new(boxcar::Vec::new()) } else { - let type_methods = Arc::new(boxcar::Vec::with_capacity(end_methods - start_methods)); + let capacity = end_methods.checked_sub(start_methods).ok_or_else(|| { + malformed_error!( + "TypeDef method range invalid: end {} < start {}", + end_methods, + start_methods + ) + })?; + let type_methods = Arc::new(boxcar::Vec::with_capacity(capacity)); for counter in start_methods..end_methods { let actual_method_token = if method_ptr.is_empty() { Token::new(u32::try_from(counter | 0x0600_0000).map_err(|_| { @@ -393,12 +427,13 @@ impl TableRow for TypeDefRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* flags */ 4 + - /* type_name */ sizes.str_bytes() + - /* type_namespace */ sizes.str_bytes() + - /* extends */ sizes.coded_index_bytes(CodedIndexType::TypeDefOrRef) + - /* field_list */ sizes.table_index_bytes(TableId::Field) + - /* method_list */ sizes.table_index_bytes(TableId::MethodDef) + /* flags */ 4u8 + /* type_name */ .saturating_add(sizes.str_bytes()) + /* type_namespace */ .saturating_add(sizes.str_bytes()) + /* extends */ .saturating_add(sizes.coded_index_bytes(CodedIndexType::TypeDefOrRef)) + /* field_list */ .saturating_add(sizes.table_index_bytes(TableId::Field)) + /* method_list */ .saturating_add(sizes.table_index_bytes(TableId::MethodDef)) + ) } } diff --git a/dotscope/src/metadata/tables/typedef/reader.rs b/dotscope/src/metadata/tables/typedef/reader.rs index e08f21cc..b1848cc6 100644 --- a/dotscope/src/metadata/tables/typedef/reader.rs +++ b/dotscope/src/metadata/tables/typedef/reader.rs @@ -81,7 +81,7 @@ impl RowReadable for TypeDefRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(TypeDefRaw { rid, - token: Token::new(0x0200_0000 + rid), + token: Token::new(0x0200_0000u32.saturating_add(rid)), offset: *offset, flags: read_le_at::(data, offset)?, type_name: read_le_at_dyn(data, offset, sizes.is_large_str())?, diff --git a/dotscope/src/metadata/tables/typeref/raw.rs b/dotscope/src/metadata/tables/typeref/raw.rs index 262ea065..88370c9c 100644 --- a/dotscope/src/metadata/tables/typeref/raw.rs +++ b/dotscope/src/metadata/tables/typeref/raw.rs @@ -176,9 +176,10 @@ impl TableRow for TypeRefRaw { #[rustfmt::skip] fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( - /* resolution_scope */ sizes.coded_index_bytes(CodedIndexType::ResolutionScope) + - /* type_namespace */ sizes.str_bytes() + - /* type_name */ sizes.str_bytes() + /* resolution_scope */ sizes.coded_index_bytes(CodedIndexType::ResolutionScope) + /* type_namespace */ .saturating_add(sizes.str_bytes()) + /* type_name */ .saturating_add(sizes.str_bytes()) + ) } } diff --git a/dotscope/src/metadata/tables/typeref/reader.rs b/dotscope/src/metadata/tables/typeref/reader.rs index 6ffb21b7..d5708eda 100644 --- a/dotscope/src/metadata/tables/typeref/reader.rs +++ b/dotscope/src/metadata/tables/typeref/reader.rs @@ -72,7 +72,7 @@ impl RowReadable for TypeRefRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(TypeRefRaw { rid, - token: Token::new(0x0100_0000 + rid), + token: Token::new(0x0100_0000u32.saturating_add(rid)), offset: *offset, resolution_scope: CodedIndex::read( data, diff --git a/dotscope/src/metadata/tables/types/common/codedindex.rs b/dotscope/src/metadata/tables/types/common/codedindex.rs index 9813d954..133372a4 100644 --- a/dotscope/src/metadata/tables/types/common/codedindex.rs +++ b/dotscope/src/metadata/tables/types/common/codedindex.rs @@ -451,7 +451,11 @@ impl CodedIndex { pub fn null(ci_type: CodedIndexType) -> CodedIndex { // The first table in the coded index type's table list has tag 0, // so (row=0, tag=0) encodes to 0 in the binary format - let first_table = ci_type.tables()[0]; + let first_table = ci_type + .tables() + .first() + .copied() + .unwrap_or(TableId::TypeDef); CodedIndex::new(first_table, 0, ci_type) } diff --git a/dotscope/src/metadata/tables/types/common/info.rs b/dotscope/src/metadata/tables/types/common/info.rs index f15c6c4b..2222bb3d 100644 --- a/dotscope/src/metadata/tables/types/common/info.rs +++ b/dotscope/src/metadata/tables/types/common/info.rs @@ -96,7 +96,7 @@ impl TableRowInfo { } else { let zeros = rows.leading_zeros(); // Safe: 32 - zeros is always <= 32, fits in u8 - (32 - zeros) as u8 + 32u8.saturating_sub(zeros as u8) }; Self { @@ -219,8 +219,10 @@ impl TableInfo { /// /// * [ECMA-335 Partition II, Section 24.2.6](https://ecma-international.org/wp-content/uploads/ECMA-335_6th_edition_june_2012.pdf) - #~ Stream pub fn new(data: &[u8], valid_bitvec: u64) -> Result { - let mut table_info = - vec![TableRowInfo::default(); TableId::CustomDebugInformation as usize + 1]; + let table_info_len = (TableId::CustomDebugInformation as usize) + .checked_add(1) + .ok_or_else(|| malformed_error!("Table info size overflow"))?; + let mut table_info = vec![TableRowInfo::default(); table_info_len]; let mut next_row_offset = 24; for table_id in TableId::iter() { @@ -239,10 +241,13 @@ impl TableInfo { continue; } - table_info[table_id as usize] = TableRowInfo::new(row_count); + let slot = table_info + .get_mut(table_id as usize) + .ok_or(out_of_bounds_error!())?; + *slot = TableRowInfo::new(row_count); } - let heap_size_flags = read_le::(&data[6..])?; + let heap_size_flags = read_le::(data.get(6..).ok_or(out_of_bounds_error!())?)?; let mut table_info = TableInfo { rows: table_info, coded_indexes: vec![0; CodedIndexType::COUNT], @@ -346,7 +351,7 @@ impl TableInfo { clippy::cast_precision_loss )] let tag_bits = (tables.len() as f32).log2().ceil() as u8; - let tag_mask = (1 << tag_bits) - 1; + let tag_mask = (1u32 << tag_bits).saturating_sub(1); let tag = value & tag_mask; let index = value >> tag_bits; @@ -355,7 +360,8 @@ impl TableInfo { return Err(out_of_bounds_error!()); } - Ok((tables[tag as usize], index)) + let table = *tables.get(tag as usize).ok_or(out_of_bounds_error!())?; + Ok((table, index)) } /// Encodes a table identifier and row index into a coded index value. @@ -456,7 +462,7 @@ impl TableInfo { /// `false` if 2-byte indices are sufficient. #[must_use] pub fn is_large(&self, id: TableId) -> bool { - self.rows[id as usize].is_large + self.rows.get(id as usize).is_some_and(|info| info.is_large) } /// Returns whether the #String heap uses large (4-byte) indices. @@ -554,7 +560,15 @@ impl TableInfo { /// A reference to the [`TableRowInfo`] for the specified table. #[must_use] pub fn get(&self, table: TableId) -> &TableRowInfo { - &self.rows[table as usize] + // The `rows` vector is sized to cover every TableId variant, so this lookup + // is always in range. Fall back to a static default to keep the lint happy + // and to avoid panics on malformed inputs. + static DEFAULT: TableRowInfo = TableRowInfo { + rows: 0, + bits: 1, + is_large: false, + }; + self.rows.get(table as usize).unwrap_or(&DEFAULT) } /// Returns the number of bits required to represent an index into a specific table. @@ -572,7 +586,7 @@ impl TableInfo { /// The number of bits required to represent table indices (1-32). #[must_use] pub fn table_index_bits(&self, table_id: TableId) -> u8 { - self.rows[table_id as usize].bits + self.rows.get(table_id as usize).map_or(1, |info| info.bits) } /// Returns the number of bytes required to represent an index into a specific table. @@ -589,7 +603,8 @@ impl TableInfo { /// Either `2` for small tables or `4` for large tables. #[must_use] pub fn table_index_bytes(&self, table_id: TableId) -> u8 { - if self.rows[table_id as usize].bits > 16 { + let bits = self.rows.get(table_id as usize).map_or(1, |info| info.bits); + if bits > 16 { 4 } else { 2 @@ -611,7 +626,10 @@ impl TableInfo { /// The number of bits required to represent coded indices of this type. #[must_use] pub fn coded_index_bits(&self, coded_index_type: CodedIndexType) -> u8 { - self.coded_indexes[coded_index_type as usize] + self.coded_indexes + .get(coded_index_type as usize) + .copied() + .unwrap_or(0) } /// Returns the cached byte size for a specific coded index type. @@ -628,7 +646,12 @@ impl TableInfo { /// Either `2` for coded indices that fit in 16 bits or `4` for larger coded indices. #[must_use] pub fn coded_index_bytes(&self, coded_index_type: CodedIndexType) -> u8 { - if self.coded_indexes[coded_index_type as usize] > 16 { + let bits = self + .coded_indexes + .get(coded_index_type as usize) + .copied() + .unwrap_or(0); + if bits > 16 { 4 } else { 2 @@ -663,7 +686,7 @@ impl TableInfo { clippy::cast_precision_loss )] let tag_bits = (tables.len() as f32).log2().ceil() as u8; - max_bits + tag_bits + max_bits.saturating_add(tag_bits) } /// Calculates and caches the bit sizes required for all coded index types. @@ -674,7 +697,9 @@ impl TableInfo { fn calculate_coded_index_bits(&mut self) { for coded_index in CodedIndexType::iter() { let size = self.calculate_coded_index_size(coded_index); - self.coded_indexes[coded_index as usize] = size; + if let Some(slot) = self.coded_indexes.get_mut(coded_index as usize) { + *slot = size; + } } } @@ -722,7 +747,7 @@ impl TableInfo { /// The number of rows in the specified table (0 if table is not present). #[must_use] pub fn row_count(&self, table_id: TableId) -> u32 { - self.rows[table_id as usize].rows + self.rows.get(table_id as usize).map_or(0, |info| info.rows) } /// Creates a new `TableInfo` with modified row counts for specified tables. @@ -759,7 +784,9 @@ impl TableInfo { ) -> TableInfo { let mut rows = self.rows.clone(); for (table_id, count) in new_counts { - rows[table_id as usize] = TableRowInfo::new(count); + if let Some(slot) = rows.get_mut(table_id as usize) { + *slot = TableRowInfo::new(count); + } } let mut new_info = TableInfo { diff --git a/dotscope/src/metadata/tables/types/read/iter.rs b/dotscope/src/metadata/tables/types/read/iter.rs index 988aef91..acbf745d 100644 --- a/dotscope/src/metadata/tables/types/read/iter.rs +++ b/dotscope/src/metadata/tables/types/read/iter.rs @@ -72,11 +72,11 @@ impl Iterator for TableIterator<'_, T> { match T::row_read( self.table.data, &mut self.current_offset, - self.current_row + 1, + self.current_row.saturating_add(1), &self.table.sizes, ) { Ok(row) => { - self.current_row += 1; + self.current_row = self.current_row.saturating_add(1); Some(row) } Err(_) => None, @@ -245,7 +245,7 @@ impl<'a, T: RowReadable + Send + Sync> rayon::iter::plumbing::Producer for Table fn split_at(self, index: usize) -> (Self, Self) { // Index represents table row positions which are expected to fit in u32 #[allow(clippy::cast_possible_truncation)] - let mid = self.range.start + index as u32; + let mid = self.range.start.saturating_add(index as u32); let left = TableProducer { table: self.table, range: self.range.start..mid, @@ -292,11 +292,11 @@ impl Iterator for TableProducerIterator<'_, T> { } let row_index = self.range.start; - self.range.start += 1; + self.range.start = self.range.start.saturating_add(1); // Get the row directly from the table // +1 because row indices start at 1 - self.table.get(row_index + 1) + self.table.get(row_index.saturating_add(1)) } fn size_hint(&self) -> (usize, Option) { @@ -314,10 +314,10 @@ impl DoubleEndedIterator for TableProducerIterator return None; } - self.range.end -= 1; + self.range.end = self.range.end.saturating_sub(1); // Get the row directly from the table // +1 because row indices start at 1 - self.table.get(self.range.end + 1) + self.table.get(self.range.end.saturating_add(1)) } } diff --git a/dotscope/src/metadata/tables/types/read/table.rs b/dotscope/src/metadata/tables/types/read/table.rs index 1b07eb04..eaa0920d 100644 --- a/dotscope/src/metadata/tables/types/read/table.rs +++ b/dotscope/src/metadata/tables/types/read/table.rs @@ -160,7 +160,7 @@ impl<'a, T: RowReadable> MetadataTable<'a, T> { /// The total size in bytes as a `u64` to accommodate large tables. #[must_use] pub fn size(&self) -> u64 { - u64::from(self.row_count) * u64::from(self.row_size) + u64::from(self.row_count).saturating_mul(u64::from(self.row_size)) } /// Retrieves a specific row by its 1-based index. @@ -185,7 +185,9 @@ impl<'a, T: RowReadable> MetadataTable<'a, T> { T::row_read( self.data, - &mut ((index as usize - 1) * self.row_size as usize), + &mut (index as usize) + .saturating_sub(1) + .saturating_mul(self.row_size as usize), index, &self.sizes, ) diff --git a/dotscope/src/metadata/tables/typespec/raw.rs b/dotscope/src/metadata/tables/typespec/raw.rs index 78bd64b5..87357c77 100644 --- a/dotscope/src/metadata/tables/typespec/raw.rs +++ b/dotscope/src/metadata/tables/typespec/raw.rs @@ -194,6 +194,7 @@ impl TableRow for TypeSpecRaw { fn row_size(sizes: &TableInfoRef) -> u32 { u32::from( /* signature */ sizes.blob_bytes() + ) } } diff --git a/dotscope/src/metadata/tables/typespec/reader.rs b/dotscope/src/metadata/tables/typespec/reader.rs index 2adbbe9e..0c6650eb 100644 --- a/dotscope/src/metadata/tables/typespec/reader.rs +++ b/dotscope/src/metadata/tables/typespec/reader.rs @@ -70,7 +70,7 @@ impl RowReadable for TypeSpecRaw { fn row_read(data: &[u8], offset: &mut usize, rid: u32, sizes: &TableInfoRef) -> Result { Ok(TypeSpecRaw { rid, - token: Token::new(0x1B00_0000 + rid), + token: Token::new(0x1B00_0000u32.saturating_add(rid)), offset: *offset, signature: read_le_at_dyn(data, offset, sizes.is_large_blob())?, }) diff --git a/dotscope/src/metadata/typesystem/base.rs b/dotscope/src/metadata/typesystem/base.rs index 4265d57c..08b48cc9 100644 --- a/dotscope/src/metadata/typesystem/base.rs +++ b/dotscope/src/metadata/typesystem/base.rs @@ -265,10 +265,13 @@ impl CilTypeRef { /// /// A strong reference to the type. /// - /// ## Panics + /// ## Returns /// - /// Panics if the referenced type has been dropped and the weak reference - /// cannot be upgraded to a strong reference. + /// Returns `Some` with a strong reference if the type is still alive, or + /// `None` if the type has been dropped (i.e. the weak reference cannot + /// be upgraded). The `msg` argument is kept for backwards compatibility + /// but no longer used for panicking; the crate forbids panics in + /// production code. /// /// ## Example /// @@ -280,14 +283,14 @@ impl CilTypeRef { /// # let my_type: Arc = unimplemented!(); /// let type_ref = CilTypeRef::new(&my_type); /// - /// // This will panic if my_type has been dropped - /// let strong_ref = type_ref.expect("Type should still be alive"); - /// println!("Type: {}", strong_ref.name); + /// if let Some(strong_ref) = type_ref.expect("Type should still be alive") { + /// println!("Type: {}", strong_ref.name); + /// } /// # } /// ``` #[must_use] - pub fn expect(&self, msg: &str) -> CilTypeRc { - self.weak_ref.upgrade().expect(msg) + pub fn expect(&self, _msg: &str) -> Option { + self.weak_ref.upgrade() } /// Checks if the referenced type is still alive. @@ -505,7 +508,7 @@ impl<'a> Iterator for CilTypeRefListIter<'a> { if self.index < self.list.count() { // boxcar::Vec returns Option from get(), and we need to handle it let result = self.list.get(self.index); - self.index += 1; + self.index = self.index.saturating_add(1); result } else { None @@ -1580,7 +1583,7 @@ impl CilFlavor { CilFlavor::I4 | CilFlavor::U4 | CilFlavor::R4 => Some(4), CilFlavor::I8 | CilFlavor::U8 | CilFlavor::R8 => Some(8), CilFlavor::I | CilFlavor::U => Some(ptr_size.bytes()), - CilFlavor::TypedRef { .. } => Some(ptr_size.bytes() * 2), + CilFlavor::TypedRef { .. } => Some(ptr_size.bytes().saturating_mul(2)), // Reference types are pointer-sized CilFlavor::Object | CilFlavor::String diff --git a/dotscope/src/metadata/typesystem/builder.rs b/dotscope/src/metadata/typesystem/builder.rs index a3764c92..87158b27 100644 --- a/dotscope/src/metadata/typesystem/builder.rs +++ b/dotscope/src/metadata/typesystem/builder.rs @@ -754,7 +754,7 @@ impl TypeBuilder { let dimension_part = if rank <= 1 { "[]".to_string() } else { - format!("[{}]", ",".repeat(rank as usize - 1)) + format!("[{}]", ",".repeat((rank as usize).saturating_sub(1))) }; let name = format!("{}{}", base_type.name, dimension_part); @@ -901,7 +901,8 @@ impl TypeBuilder { // Create a dummy method specification for the type argument let rid = u32::try_from(index) .map_err(|_| malformed_error!("Generic argument index too large"))? - + 1; + .checked_add(1) + .ok_or_else(|| malformed_error!("Generic argument rid overflow"))?; let token_value = 0x2B00_0000_u32 .checked_add(u32::try_from(index).map_err(|_| { diff --git a/dotscope/src/metadata/typesystem/encoder.rs b/dotscope/src/metadata/typesystem/encoder.rs index f5c240fc..86867627 100644 --- a/dotscope/src/metadata/typesystem/encoder.rs +++ b/dotscope/src/metadata/typesystem/encoder.rs @@ -224,19 +224,23 @@ impl TypeSignatureEncoder { // Reference and pointer types TypeSignature::ByRef(inner) => { buffer.push(0x10); // ELEMENT_TYPE_BYREF - Self::encode_type_signature_internal(inner, buffer, depth + 1)?; + Self::encode_type_signature_internal(inner, buffer, depth.saturating_add(1))?; } TypeSignature::Ptr(pointer) => { buffer.push(0x0F); // ELEMENT_TYPE_PTR // Encode custom modifiers Self::encode_custom_modifiers(&pointer.modifiers, buffer)?; - Self::encode_type_signature_internal(&pointer.base, buffer, depth + 1)?; + Self::encode_type_signature_internal( + &pointer.base, + buffer, + depth.saturating_add(1), + )?; } TypeSignature::Pinned(inner) => { buffer.push(0x45); // ELEMENT_TYPE_PINNED - Self::encode_type_signature_internal(inner, buffer, depth + 1)?; + Self::encode_type_signature_internal(inner, buffer, depth.saturating_add(1))?; } // Array types @@ -244,12 +248,12 @@ impl TypeSignatureEncoder { buffer.push(0x1D); // ELEMENT_TYPE_SZARRAY // Encode custom modifiers Self::encode_custom_modifiers(&array.modifiers, buffer)?; - Self::encode_type_signature_internal(&array.base, buffer, depth + 1)?; + Self::encode_type_signature_internal(&array.base, buffer, depth.saturating_add(1))?; } TypeSignature::Array(array) => { buffer.push(0x14); // ELEMENT_TYPE_ARRAY - Self::encode_type_signature_internal(&array.base, buffer, depth + 1)?; + Self::encode_type_signature_internal(&array.base, buffer, depth.saturating_add(1))?; write_compressed_uint(array.rank, buffer); // Collect sizes and lower bounds from dimensions @@ -296,7 +300,7 @@ impl TypeSignatureEncoder { // Generic type instantiation TypeSignature::GenericInst(base_type, type_args) => { buffer.push(0x15); // ELEMENT_TYPE_GENERICINST - Self::encode_type_signature_internal(base_type, buffer, depth + 1)?; + Self::encode_type_signature_internal(base_type, buffer, depth.saturating_add(1))?; write_compressed_uint( u32::try_from(type_args.len()).map_err(|_| { malformed_error!( @@ -307,7 +311,11 @@ impl TypeSignatureEncoder { buffer, ); for type_arg in type_args { - Self::encode_type_signature_internal(type_arg, buffer, depth + 1)?; + Self::encode_type_signature_internal( + type_arg, + buffer, + depth.saturating_add(1), + )?; } } diff --git a/dotscope/src/metadata/typesystem/mod.rs b/dotscope/src/metadata/typesystem/mod.rs index 15ff9b31..da7b628b 100644 --- a/dotscope/src/metadata/typesystem/mod.rs +++ b/dotscope/src/metadata/typesystem/mod.rs @@ -720,10 +720,10 @@ impl CilType { // Look up the inheritance chain without computing flavors (avoid infinite recursion) let mut current = base_type.base(); - let mut depth = 0; + let mut depth: usize = 0; while let Some(ancestor) = current { - depth += 1; + depth = depth.saturating_add(1); if depth > MAX_INHERITANCE_DEPTH { break; } diff --git a/dotscope/src/metadata/typesystem/primitives.rs b/dotscope/src/metadata/typesystem/primitives.rs index 2342a0bb..5aac1417 100644 --- a/dotscope/src/metadata/typesystem/primitives.rs +++ b/dotscope/src/metadata/typesystem/primitives.rs @@ -402,20 +402,17 @@ impl CilPrimitiveData { pub fn from_bytes(type_byte: u8, data: &[u8]) -> Result { match type_byte { ELEMENT_TYPE::BOOLEAN => { - if data.is_empty() { - Err(out_of_bounds_error!()) - } else { - Ok(CilPrimitiveData::Boolean(data[0] != 0)) - } + let b = data.first().ok_or(out_of_bounds_error!())?; + Ok(CilPrimitiveData::Boolean(*b != 0)) } ELEMENT_TYPE::CHAR => { - if data.len() < 2 { - Err(out_of_bounds_error!()) - } else { - let code = u16::from_le_bytes([data[0], data[1]]); - // .NET System.Char is a UTF-16 code unit, so any u16 value is valid - Ok(CilPrimitiveData::Char(code)) - } + let bytes = data.get(0..2).ok_or(out_of_bounds_error!())?; + let code = u16::from_le_bytes([ + *bytes.first().ok_or(out_of_bounds_error!())?, + *bytes.get(1).ok_or(out_of_bounds_error!())?, + ]); + // .NET System.Char is a UTF-16 code unit, so any u16 value is valid + Ok(CilPrimitiveData::Char(code)) } ELEMENT_TYPE::I1 => Ok(CilPrimitiveData::I1(read_le::(data)?)), ELEMENT_TYPE::U1 => Ok(CilPrimitiveData::U1(read_le::(data)?)), @@ -441,10 +438,12 @@ impl CilPrimitiveData { )); } - let utf16_chars: Vec = data - .chunks_exact(2) - .map(|chunk| u16::from_le_bytes([chunk[0], chunk[1]])) - .collect(); + let mut utf16_chars: Vec = Vec::with_capacity(data.len() / 2); + for chunk in data.chunks_exact(2) { + let b0 = *chunk.first().ok_or(out_of_bounds_error!())?; + let b1 = *chunk.get(1).ok_or(out_of_bounds_error!())?; + utf16_chars.push(u16::from_le_bytes([b0, b1])); + } match String::from_utf16(&utf16_chars) { Ok(utf_string) => Ok(CilPrimitiveData::String(utf_string)), @@ -1335,7 +1334,7 @@ impl CilPrimitive { CilPrimitiveData::I(value) => value.to_le_bytes().to_vec(), CilPrimitiveData::String(value) => { let utf16_chars: Vec = value.encode_utf16().collect(); - let mut bytes = Vec::with_capacity(utf16_chars.len() * 2); + let mut bytes = Vec::with_capacity(utf16_chars.len().saturating_mul(2)); for ch in utf16_chars { bytes.extend_from_slice(&ch.to_le_bytes()); } diff --git a/dotscope/src/metadata/typesystem/registry.rs b/dotscope/src/metadata/typesystem/registry.rs index 3395d3c6..79ea8745 100644 --- a/dotscope/src/metadata/typesystem/registry.rs +++ b/dotscope/src/metadata/typesystem/registry.rs @@ -1537,7 +1537,8 @@ impl TypeRegistry { for (index, arg_type) in generic_args.iter().enumerate() { let rid = u32::try_from(index) .map_err(|_| malformed_error!("Generic argument index too large"))? - + 1; + .checked_add(1) + .ok_or_else(|| malformed_error!("Generic argument rid overflow"))?; let token_value = 0x2B00_0000_u32 .checked_add( u32::try_from(index) diff --git a/dotscope/src/metadata/typesystem/resolver.rs b/dotscope/src/metadata/typesystem/resolver.rs index 57c7a6c9..bf689a6d 100644 --- a/dotscope/src/metadata/typesystem/resolver.rs +++ b/dotscope/src/metadata/typesystem/resolver.rs @@ -630,7 +630,7 @@ impl TypeResolver { TypeSignature::Array(array) => { let mut token_init = self.token_init.take(); - let element_type = self.resolve_with_depth(&array.base, depth + 1)?; + let element_type = self.resolve_with_depth(&array.base, depth.saturating_add(1))?; let array_flavor = CilFlavor::Array { element_type: Box::new(element_type.flavor().clone()), @@ -643,11 +643,8 @@ impl TypeResolver { let name = if array.rank == 1 { format!("{}[]", element_type.name) } else { - format!( - "{}[{}]", - element_type.name, - ",".repeat(array.rank as usize - 1) - ) + let commas = (array.rank as usize).saturating_sub(1); + format!("{}[{}]", element_type.name, ",".repeat(commas)) }; let array_type = self.registry.get_or_create_type(&CompleteTypeSpec { @@ -668,7 +665,8 @@ impl TypeResolver { TypeSignature::SzArray(szarray) => { let mut token_init = self.token_init.take(); - let element_type = self.resolve_with_depth(&szarray.base, depth + 1)?; + let element_type = + self.resolve_with_depth(&szarray.base, depth.saturating_add(1))?; let namespace = element_type.namespace.clone(); let name = format!("{}[]", element_type.name); @@ -709,7 +707,7 @@ impl TypeResolver { TypeSignature::Ptr(ptr) => { let mut token_init = self.token_init.take(); - let pointed_type = self.resolve_with_depth(&ptr.base, depth + 1)?; + let pointed_type = self.resolve_with_depth(&ptr.base, depth.saturating_add(1))?; let namespace = pointed_type.namespace.clone(); let name = format!("{}*", pointed_type.name); @@ -741,7 +739,7 @@ impl TypeResolver { TypeSignature::ByRef(type_sig) => { let mut token_init = self.token_init.take(); - let ref_type = self.resolve_with_depth(type_sig, depth + 1)?; + let ref_type = self.resolve_with_depth(type_sig, depth.saturating_add(1))?; let namespace = ref_type.namespace.clone(); let name = format!("{}&", ref_type.name); @@ -781,7 +779,7 @@ impl TypeResolver { TypeSignature::Pinned(type_sig) => { let mut token_init = self.token_init.take(); - let pinned_type = self.resolve_with_depth(type_sig, depth + 1)?; + let pinned_type = self.resolve_with_depth(type_sig, depth.saturating_add(1))?; let namespace = pinned_type.namespace.clone(); let name = format!("pinned {}", pinned_type.name); @@ -803,7 +801,7 @@ impl TypeResolver { TypeSignature::GenericInst(base_sig, type_args) => { let mut token_init = self.token_init.take(); - let base_type = self.resolve_with_depth(base_sig, depth + 1)?; + let base_type = self.resolve_with_depth(base_sig, depth.saturating_add(1))?; let namespace = base_type.namespace.clone(); let name = Self::format_generic_name(&base_type.name, type_args.len()); @@ -831,14 +829,15 @@ impl TypeResolver { let mut generic_args = Vec::with_capacity(type_args.len()); for arg_sig in type_args { - let arg_type = self.resolve_with_depth(arg_sig, depth + 1)?; + let arg_type = self.resolve_with_depth(arg_sig, depth.saturating_add(1))?; generic_args.push(arg_type); } for (index, arg_type) in generic_args.into_iter().enumerate() { let rid = u32::try_from(index) .map_err(|_| malformed_error!("Generic argument index too large"))? - + 1; + .checked_add(1) + .ok_or_else(|| malformed_error!("Generic argument rid overflow"))?; let token_value = 0x2B00_0000_u32 .checked_add(u32::try_from(index).map_err(|_| { diff --git a/dotscope/src/metadata/validation/result.rs b/dotscope/src/metadata/validation/result.rs index c55d707a..46330967 100644 --- a/dotscope/src/metadata/validation/result.rs +++ b/dotscope/src/metadata/validation/result.rs @@ -277,14 +277,14 @@ impl ValidationResult { #[must_use] pub fn combine(results: Vec) -> Self { let mut combined_outcomes = Vec::new(); - let mut total_validator_count = 0; + let mut total_validator_count: usize = 0; let mut total_duration = Duration::ZERO; let mut overall_success = true; for result in results { combined_outcomes.extend(result.outcomes); - total_validator_count += result.validator_count; - total_duration += result.duration; + total_validator_count = total_validator_count.saturating_add(result.validator_count); + total_duration = total_duration.saturating_add(result.duration); overall_success = overall_success && result.success; } @@ -617,13 +617,13 @@ impl TwoStageValidationResult { /// Sets the Stage 1 validation result. pub fn set_stage1_result(&mut self, result: ValidationResult) { - self.total_duration += result.duration(); + self.total_duration = self.total_duration.saturating_add(result.duration()); self.stage1_result = Some(result); } /// Sets the Stage 2 validation result. pub fn set_stage2_result(&mut self, result: ValidationResult) { - self.total_duration += result.duration(); + self.total_duration = self.total_duration.saturating_add(result.duration()); self.stage2_result = Some(result); } diff --git a/dotscope/src/metadata/validation/shared/references.rs b/dotscope/src/metadata/validation/shared/references.rs index 56def34e..000d0816 100644 --- a/dotscope/src/metadata/validation/shared/references.rs +++ b/dotscope/src/metadata/validation/shared/references.rs @@ -335,8 +335,10 @@ impl<'a> ReferenceValidator<'a> { .references_from(token) .map_or(0, std::collections::HashSet::len); - analysis.total_tokens += 1; - analysis.total_references += incoming_count + outgoing_count; + analysis.total_tokens = analysis.total_tokens.saturating_add(1); + analysis.total_references = analysis + .total_references + .saturating_add(incoming_count.saturating_add(outgoing_count)); if incoming_count == 0 { analysis.orphaned_tokens.insert(token); diff --git a/dotscope/src/metadata/validation/shared/schema.rs b/dotscope/src/metadata/validation/shared/schema.rs index 786b9681..9d6f7690 100644 --- a/dotscope/src/metadata/validation/shared/schema.rs +++ b/dotscope/src/metadata/validation/shared/schema.rs @@ -66,6 +66,7 @@ use crate::{ metadata::{ tables::TableId, + token::Token, validation::{ scanner::{HeapSizes, ReferenceScanner}, ScannerStatistics, @@ -334,19 +335,26 @@ impl<'a> SchemaValidator<'a> { // The exact decoding depends on the specific coded index type // This is a simplified validation - real implementation would decode properly let table_bits = allowed_tables.len().next_power_of_two().trailing_zeros(); - let table_index = coded_index & ((1 << table_bits) - 1); - let rid = coded_index >> table_bits; + let mask = 1u32.checked_shl(table_bits).unwrap_or(0).saturating_sub(1); + let table_index = coded_index & mask; + let rid = coded_index.checked_shr(table_bits).unwrap_or(0); // Validate table index is within allowed range if (table_index as usize) >= allowed_tables.len() { return Err(Error::InvalidToken { - token: crate::metadata::token::Token::new(coded_index), + token: Token::new(coded_index), message: format!("Table index {table_index} not in allowed range"), }); } // Validate RID for the decoded table - let table_id = allowed_tables[table_index as usize]; + let table_id = + *allowed_tables + .get(table_index as usize) + .ok_or_else(|| Error::InvalidToken { + token: Token::new(coded_index), + message: format!("Table index {table_index} not in allowed range"), + })?; self.validate_rid(table_id, rid) } diff --git a/dotscope/src/metadata/validation/validators/owned/metadata/attribute.rs b/dotscope/src/metadata/validation/validators/owned/metadata/attribute.rs index 38b8b4b3..d6f6e290 100644 --- a/dotscope/src/metadata/validation/validators/owned/metadata/attribute.rs +++ b/dotscope/src/metadata/validation/validators/owned/metadata/attribute.rs @@ -457,14 +457,16 @@ impl OwnedAttributeValidator { } if custom_attr.named_args.len() > 20 { - let mut similar_names = 0; + let mut similar_names: usize = 0; for i in 0..custom_attr.named_args.len() { - for j in (i + 1)..custom_attr.named_args.len() { - if Self::are_similar_names( - &custom_attr.named_args[i].name, - &custom_attr.named_args[j].name, - ) { - similar_names += 1; + for j in i.saturating_add(1)..custom_attr.named_args.len() { + let (Some(arg_i), Some(arg_j)) = + (custom_attr.named_args.get(i), custom_attr.named_args.get(j)) + else { + continue; + }; + if Self::are_similar_names(&arg_i.name, &arg_j.name) { + similar_names = similar_names.saturating_add(1); if similar_names > 5 { return true; } @@ -506,7 +508,7 @@ impl OwnedAttributeValidator { constructor: CilTypeReference::None, blob_index: 0, }; - if self.has_deep_array_nesting(&temp_attr, depth + 1) { + if self.has_deep_array_nesting(&temp_attr, depth.saturating_add(1)) { return true; } } @@ -536,10 +538,10 @@ impl OwnedAttributeValidator { return false; } - let mut differences = 0; + let mut differences: usize = 0; for (c1, c2) in name1.chars().zip(name2.chars()) { if c1 != c2 { - differences += 1; + differences = differences.saturating_add(1); if differences > 1 { return false; } diff --git a/dotscope/src/metadata/validation/validators/owned/relationships/ownership.rs b/dotscope/src/metadata/validation/validators/owned/relationships/ownership.rs index 835da630..b021782a 100644 --- a/dotscope/src/metadata/validation/validators/owned/relationships/ownership.rs +++ b/dotscope/src/metadata/validation/validators/owned/relationships/ownership.rs @@ -435,7 +435,7 @@ impl OwnedOwnershipValidator { self.validate_nested_type_circularity_deep( &nested_type, recursion_stack, - depth + 1, + depth.saturating_add(1), )?; } } diff --git a/dotscope/src/metadata/validation/validators/owned/system/assembly.rs b/dotscope/src/metadata/validation/validators/owned/system/assembly.rs index 37a5ff6d..494c2e4e 100644 --- a/dotscope/src/metadata/validation/validators/owned/system/assembly.rs +++ b/dotscope/src/metadata/validation/validators/owned/system/assembly.rs @@ -257,14 +257,20 @@ impl OwnedAssemblyValidator { match parts.len() { 1 => { // Language only (e.g., "en", "fr") - parts[0].len() == 2 && parts[0].chars().all(|c| c.is_ascii_lowercase()) + let Some(lang) = parts.first() else { + return false; + }; + lang.len() == 2 && lang.chars().all(|c| c.is_ascii_lowercase()) } 2 => { // Language-Country (e.g., "en-US", "fr-FR") - parts[0].len() == 2 - && parts[0].chars().all(|c| c.is_ascii_lowercase()) - && parts[1].len() == 2 - && parts[1].chars().all(|c| c.is_ascii_uppercase()) + let (Some(lang), Some(country)) = (parts.first(), parts.get(1)) else { + return false; + }; + lang.len() == 2 + && lang.chars().all(|c| c.is_ascii_lowercase()) + && country.len() == 2 + && country.chars().all(|c| c.is_ascii_uppercase()) } _ => false, } diff --git a/dotscope/src/metadata/validation/validators/owned/system/security.rs b/dotscope/src/metadata/validation/validators/owned/system/security.rs index 54a0a2f6..3af85f85 100644 --- a/dotscope/src/metadata/validation/validators/owned/system/security.rs +++ b/dotscope/src/metadata/validation/validators/owned/system/security.rs @@ -275,12 +275,24 @@ impl OwnedSecurityValidator { // Simplified extraction - real implementation would parse XML properly if let Some(start) = permission_set.find(&format!("<{action}")) { - if let Some(end) = permission_set[start..].find('>') { - let section = &permission_set[start..start + end]; - if let Some(class_start) = section.find("class=\"") { - if let Some(class_end) = section[class_start + 7..].find('"') { - let class_name = §ion[class_start + 7..class_start + 7 + class_end]; - permissions.push(class_name.to_string()); + if let Some(rest) = permission_set.get(start..) { + if let Some(end) = rest.find('>') { + let section_end = start.saturating_add(end); + if let Some(section) = permission_set.get(start..section_end) { + if let Some(class_start) = section.find("class=\"") { + let class_value_start = class_start.saturating_add(7); + if let Some(after_class) = section.get(class_value_start..) { + if let Some(class_end) = after_class.find('"') { + let class_value_end = + class_value_start.saturating_add(class_end); + if let Some(class_name) = + section.get(class_value_start..class_value_end) + { + permissions.push(class_name.to_string()); + } + } + } + } } } } diff --git a/dotscope/src/metadata/validation/validators/owned/types/circularity.rs b/dotscope/src/metadata/validation/validators/owned/types/circularity.rs index ee1c4b99..e561d0b4 100644 --- a/dotscope/src/metadata/validation/validators/owned/types/circularity.rs +++ b/dotscope/src/metadata/validation/validators/owned/types/circularity.rs @@ -243,7 +243,7 @@ impl OwnedTypeCircularityValidator { visited, visiting, context, - depth + 1, + depth.saturating_add(1), )?; } diff --git a/dotscope/src/metadata/validation/validators/owned/types/inheritance.rs b/dotscope/src/metadata/validation/validators/owned/types/inheritance.rs index f2ab7855..ca2004c6 100644 --- a/dotscope/src/metadata/validation/validators/owned/types/inheritance.rs +++ b/dotscope/src/metadata/validation/validators/owned/types/inheritance.rs @@ -224,7 +224,13 @@ impl OwnedInheritanceValidator { visiting.insert(type_ptr); if let Some(base_type) = type_entry.base() { - self.check_inheritance_cycles(&base_type, visited, visiting, context, depth + 1)?; + self.check_inheritance_cycles( + &base_type, + visited, + visiting, + context, + depth.saturating_add(1), + )?; } for (_, entry) in type_entry.interfaces.iter() { @@ -234,7 +240,7 @@ impl OwnedInheritanceValidator { visited, visiting, context, - depth + 1, + depth.saturating_add(1), )?; } } diff --git a/dotscope/src/metadata/validation/validators/raw/constraints/generic.rs b/dotscope/src/metadata/validation/validators/raw/constraints/generic.rs index c44e0a7f..e54a11cc 100644 --- a/dotscope/src/metadata/validation/validators/raw/constraints/generic.rs +++ b/dotscope/src/metadata/validation/validators/raw/constraints/generic.rs @@ -315,10 +315,9 @@ impl RawGenericConstraintValidator { if let Some(constraint_table) = tables.table::() { for constraint in constraint_table { let constraint_tables = constraint.constraint.ci_type.tables(); - let constraint_table_type = if constraint_tables.len() == 1 { - constraint_tables[0] - } else { - continue; + let constraint_table_type = match constraint_tables { + [single] if constraint_tables.len() == 1 => *single, + _ => continue, }; let constraint_row = constraint.constraint.row; diff --git a/dotscope/src/metadata/validation/validators/raw/constraints/layout.rs b/dotscope/src/metadata/validation/validators/raw/constraints/layout.rs index 47f99833..f18a1d7c 100644 --- a/dotscope/src/metadata/validation/validators/raw/constraints/layout.rs +++ b/dotscope/src/metadata/validation/validators/raw/constraints/layout.rs @@ -336,8 +336,12 @@ impl RawLayoutConstraintValidator { for (index, typedef_entry) in typedef_rows.iter().enumerate() { let start_field = typedef_entry.field_list; - let end_field = if index + 1 < typedef_rows.len() { - typedef_rows[index + 1].field_list + let next_index = index.saturating_add(1); + let end_field = if next_index < typedef_rows.len() { + typedef_rows + .get(next_index) + .ok_or(out_of_bounds_error!())? + .field_list } else { u32::MAX }; @@ -557,8 +561,8 @@ impl RawLayoutConstraintValidator { fields.sort_by_key(|f| f.field_offset); for window in fields.windows(2) { - let field1 = &window[0]; - let field2 = &window[1]; + let field1 = window.first().ok_or(out_of_bounds_error!())?; + let field2 = window.get(1).ok_or(out_of_bounds_error!())?; let gap = field2.field_offset.saturating_sub(field1.field_offset); if gap > 1_048_576 { diff --git a/dotscope/src/metadata/validation/validators/raw/modification/integrity.rs b/dotscope/src/metadata/validation/validators/raw/modification/integrity.rs index aa1cce16..f939a604 100644 --- a/dotscope/src/metadata/validation/validators/raw/modification/integrity.rs +++ b/dotscope/src/metadata/validation/validators/raw/modification/integrity.rs @@ -181,8 +181,8 @@ impl RawChangeIntegrityValidator { if let Some(&max_rid) = final_rids.iter().max() { let expected_min_count = - u32::try_from(final_rids.len() * 7 / 10).unwrap_or(0); - if max_rid > expected_min_count.max(1) * 2 { + u32::try_from(final_rids.len().saturating_mul(7) / 10).unwrap_or(0); + if max_rid > expected_min_count.max(1).saturating_mul(2) { return Err(malformed_error!( "Table {:?} integrity violation: RID sequence too sparse - max RID {} with only {} rows (>70% gaps)", table_id, @@ -410,8 +410,20 @@ impl RawChangeIntegrityValidator { for (table_id, modifications) in table_changes { if let TableModifications::Sparse { operations, .. } = modifications { for window in operations.windows(2) { - let curr_time = window[0].timestamp; - let next_time = window[1].timestamp; + let curr = window.first().ok_or_else(|| { + malformed_error!( + "Operation window missing first element for table {:?}", + table_id + ) + })?; + let next = window.get(1).ok_or_else(|| { + malformed_error!( + "Operation window missing second element for table {:?}", + table_id + ) + })?; + let curr_time = curr.timestamp; + let next_time = next.timestamp; if curr_time > next_time { return Err(malformed_error!( @@ -503,7 +515,11 @@ impl RawChangeIntegrityValidator { TableModifications::Replaced(rows) => { // For replaced tables, we have all TypeDef data for (i, row_data) in rows.iter().enumerate() { - let rid = u32::try_from(i + 1).map_err(|_| Error::ValidationRawFailed { + let rid_usize = i.checked_add(1).ok_or_else(|| Error::ValidationRawFailed { + validator: "integrity".to_string(), + message: "Table row index overflow when computing RID".to_string(), + })?; + let rid = u32::try_from(rid_usize).map_err(|_| Error::ValidationRawFailed { validator: "integrity".to_string(), message: "Table row index exceeds u32 range".to_string(), })?; @@ -521,14 +537,35 @@ impl RawChangeIntegrityValidator { // Validate each TypeDef's field range for i in 0..typedef_field_lists.len() { - let (typedef_rid, field_list_start) = typedef_field_lists[i]; + let (typedef_rid, field_list_start) = *typedef_field_lists.get(i).ok_or_else(|| { + malformed_error!( + "TypeDef field-list index {} out of bounds (len {})", + i, + typedef_field_lists.len() + ) + })?; // Determine the end of this type's field range - let field_list_end = if i + 1 < typedef_field_lists.len() { - typedef_field_lists[i + 1].1 // Next type's field_list + let next_index = i + .checked_add(1) + .ok_or_else(|| malformed_error!("TypeDef field-list index overflow at {}", i))?; + let field_list_end = if next_index < typedef_field_lists.len() { + typedef_field_lists + .get(next_index) + .ok_or_else(|| { + malformed_error!( + "TypeDef field-list next index {} out of bounds (len {})", + next_index, + typedef_field_lists.len() + ) + })? + .1 // Next type's field_list } else { // For the last type, use the maximum field RID + 1 - field_rids.iter().max().map_or(1, |max| max + 1) + field_rids + .iter() + .max() + .map_or(1, |max| max.saturating_add(1)) }; // Validate that all fields in this range exist diff --git a/dotscope/src/metadata/validation/validators/raw/modification/operation.rs b/dotscope/src/metadata/validation/validators/raw/modification/operation.rs index f1a8f980..92c9df58 100644 --- a/dotscope/src/metadata/validation/validators/raw/modification/operation.rs +++ b/dotscope/src/metadata/validation/validators/raw/modification/operation.rs @@ -190,7 +190,7 @@ impl RawOperationValidator { } // Validate RID allocation is sequential from next_rid - if *rid >= *next_rid + 1000 { + if *rid >= next_rid.saturating_add(1000) { return Err(malformed_error!( "Insert operation for table {:?} has RID {} too far ahead of next available RID {} - potential RID exhaustion", table_id, @@ -292,8 +292,8 @@ impl RawOperationValidator { } // Track multiple updates to the same RID (allowed with timestamp ordering) - let update_count = update_rids.entry(*rid).or_insert(0); - *update_count += 1; + let update_count: &mut u32 = update_rids.entry(*rid).or_insert(0); + *update_count = update_count.saturating_add(1); if *update_count > 10 { return Err(malformed_error!( @@ -430,12 +430,14 @@ impl RawOperationValidator { if let TableModifications::Sparse { operations, .. } = modifications { // Validate operations are chronologically ordered for window in operations.windows(2) { - if window[0].timestamp > window[1].timestamp { + let first = window.first().ok_or(out_of_bounds_error!())?; + let second = window.get(1).ok_or(out_of_bounds_error!())?; + if first.timestamp > second.timestamp { return Err(malformed_error!( "Operations for table {:?} are not chronologically ordered - timestamp {} > {}", table_id, - window[0].timestamp, - window[1].timestamp + first.timestamp, + second.timestamp )); } } diff --git a/dotscope/src/metadata/validation/validators/raw/structure/heap.rs b/dotscope/src/metadata/validation/validators/raw/structure/heap.rs index 2f1895bd..0e9ed73d 100644 --- a/dotscope/src/metadata/validation/validators/raw/structure/heap.rs +++ b/dotscope/src/metadata/validation/validators/raw/structure/heap.rs @@ -398,12 +398,12 @@ impl RawHeapValidator { /// - Individual GUID access fails unexpectedly fn validate_guid_heap_content(assembly_view: &CilAssemblyView) -> Result<()> { if let Some(guids) = assembly_view.guids() { - let mut guid_count = 0; + let mut guid_count: usize = 0; // Validate accessibility through iteration // Note: The GUID iterator returns (1-based index, GUID), not byte offsets for (one_based_index, guid_data) in guids.iter() { - guid_count += 1; + guid_count = guid_count.saturating_add(1); // Verify GUID data is properly accessible let guid_bytes = guid_data.to_bytes(); diff --git a/dotscope/src/metadata/validation/validators/raw/structure/signature.rs b/dotscope/src/metadata/validation/validators/raw/structure/signature.rs index 824f7550..bf252bae 100644 --- a/dotscope/src/metadata/validation/validators/raw/structure/signature.rs +++ b/dotscope/src/metadata/validation/validators/raw/structure/signature.rs @@ -203,21 +203,20 @@ impl RawSignatureValidator { message: format!("Signature blob index {blob_index} exceeds blob heap bounds"), })?; - if blob_data.is_empty() { + let Some((&calling_convention, rest)) = blob_data.split_first() else { return Err(Error::ValidationRawFailed { validator: "RawSignatureValidator".to_string(), message: format!("Signature blob at index {blob_index} is empty"), }); - } + }; - let calling_convention = blob_data[0]; Self::validate_calling_convention(calling_convention, expected_kind, blob_index)?; if matches!( expected_kind, SignatureKind::Method | SignatureKind::LocalVar | SignatureKind::Property ) { - if blob_data.len() < 2 { + if rest.is_empty() { return Err(Error::ValidationRawFailed { validator: "RawSignatureValidator".to_string(), message: format!( @@ -226,7 +225,7 @@ impl RawSignatureValidator { }); } - Self::validate_compressed_integer(&blob_data[1..], blob_index)?; + Self::validate_compressed_integer(rest, blob_index)?; } Self::validate_blob_bounds(blob_data, blob_index)?; @@ -332,14 +331,12 @@ impl RawSignatureValidator { /// /// Returns validation error if compressed integer encoding is malformed. fn validate_compressed_integer(data: &[u8], blob_index: u32) -> Result<()> { - if data.is_empty() { + let Some(&first_byte) = data.first() else { return Err(Error::ValidationRawFailed { validator: "RawSignatureValidator".to_string(), message: format!("Insufficient data for compressed integer in blob {blob_index}"), }); - } - - let first_byte = data[0]; + }; if (first_byte & 0x80) == 0 { // 1-byte encoding: 0bbbbbbb @@ -449,8 +446,7 @@ impl RawValidator for RawSignatureValidator { for method in table { if let Some(blob_heap) = assembly_view.blobs() { if let Ok(blob_data) = blob_heap.get(method.signature as usize) { - if !blob_data.is_empty() { - let calling_convention = blob_data[0]; + if let Some(&calling_convention) = blob_data.first() { let signature_kind = match calling_convention { 0x06 => SignatureKind::Field, _ => SignatureKind::Method, @@ -497,8 +493,7 @@ impl RawValidator for RawSignatureValidator { for standalone_sig in table { if let Some(blob_heap) = assembly_view.blobs() { if let Ok(blob_data) = blob_heap.get(standalone_sig.signature as usize) { - if !blob_data.is_empty() { - let calling_convention = blob_data[0]; + if let Some(&calling_convention) = blob_data.first() { let signature_kind = match calling_convention { 0x07 => SignatureKind::LocalVar, 0x06 => SignatureKind::Field, @@ -530,8 +525,7 @@ impl RawValidator for RawSignatureValidator { for member_ref in table { if let Some(blob_heap) = assembly_view.blobs() { if let Ok(blob_data) = blob_heap.get(member_ref.signature as usize) { - if !blob_data.is_empty() { - let calling_convention = blob_data[0]; + if let Some(&calling_convention) = blob_data.first() { let signature_kind = match calling_convention { 0x06 => SignatureKind::Field, _ => SignatureKind::Method, diff --git a/dotscope/src/metadata/validation/validators/raw/structure/table.rs b/dotscope/src/metadata/validation/validators/raw/structure/table.rs index 18be0039..d291acfe 100644 --- a/dotscope/src/metadata/validation/validators/raw/structure/table.rs +++ b/dotscope/src/metadata/validation/validators/raw/structure/table.rs @@ -246,7 +246,8 @@ impl RawTableValidator { (tables.table::(), tables.table::()) { for typedef_row in typedef_table { - if typedef_row.field_list != 0 && typedef_row.field_list > field_table.row_count + 1 + if typedef_row.field_list != 0 + && typedef_row.field_list > field_table.row_count.saturating_add(1) { return Err(malformed_error!( "TypeDef RID {} references field list starting at RID {} but Field table only has {} rows", @@ -263,7 +264,7 @@ impl RawTableValidator { { for typedef_row in typedef_table { if typedef_row.method_list != 0 - && typedef_row.method_list > method_table.row_count + 1 + && typedef_row.method_list > method_table.row_count.saturating_add(1) { return Err(malformed_error!( "TypeDef RID {} references method list starting at RID {} but MethodDef table only has {} rows", diff --git a/dotscope/src/metadata/vtfixup.rs b/dotscope/src/metadata/vtfixup.rs index 3e28ed63..d125eccd 100644 --- a/dotscope/src/metadata/vtfixup.rs +++ b/dotscope/src/metadata/vtfixup.rs @@ -115,49 +115,37 @@ pub fn parse(asm: &CilObject) -> Option { let mut entries = Vec::with_capacity(num_entries); for i in 0..num_entries { - let base = i * 8; - if base + 8 > data.len() { - break; - } - let entry_rva = - u32::from_le_bytes([data[base], data[base + 1], data[base + 2], data[base + 3]]); - let count = u16::from_le_bytes([data[base + 4], data[base + 5]]); - let flags = u16::from_le_bytes([data[base + 6], data[base + 7]]); + let base = i.checked_mul(8)?; + let end = base.checked_add(8)?; + let chunk = data.get(base..end)?; + let entry_rva = u32::from_le_bytes(chunk.get(0..4)?.try_into().ok()?); + let count = u16::from_le_bytes(chunk.get(4..6)?.try_into().ok()?); + let flags = u16::from_le_bytes(chunk.get(6..8)?.try_into().ok()?); let slot_size: usize = if flags & COR_VTABLE_64BIT != 0 { 8 } else { 4 }; // Read method tokens at the entry's RVA let mut tokens = Vec::with_capacity(count as usize); if let Ok(tok_offset) = file.rva_to_offset(entry_rva as usize) { - let tok_data_len = (count as usize) * slot_size; + let tok_data_len = (count as usize).saturating_mul(slot_size); if let Ok(tok_data) = file.data_slice(tok_offset, tok_data_len) { for j in 0..count as usize { - let slot_base = j * slot_size; + let Some(slot_base) = j.checked_mul(slot_size) else { + break; + }; let token = if slot_size == 8 { // 64-bit slot: read u64, truncate to u32 (high 32 bits are padding) - if slot_base + 8 <= tok_data.len() { - u64::from_le_bytes([ - tok_data[slot_base], - tok_data[slot_base + 1], - tok_data[slot_base + 2], - tok_data[slot_base + 3], - tok_data[slot_base + 4], - tok_data[slot_base + 5], - tok_data[slot_base + 6], - tok_data[slot_base + 7], - ]) as u32 - } else { - 0 - } - } else if slot_base + 4 <= tok_data.len() { - u32::from_le_bytes([ - tok_data[slot_base], - tok_data[slot_base + 1], - tok_data[slot_base + 2], - tok_data[slot_base + 3], - ]) + slot_base + .checked_add(8) + .and_then(|end| tok_data.get(slot_base..end)) + .and_then(|s| <[u8; 8]>::try_from(s).ok()) + .map_or(0u32, |b| u64::from_le_bytes(b) as u32) } else { - 0 + slot_base + .checked_add(4) + .and_then(|end| tok_data.get(slot_base..end)) + .and_then(|s| <[u8; 4]>::try_from(s).ok()) + .map_or(0u32, u32::from_le_bytes) }; tokens.push(token); } @@ -177,7 +165,12 @@ pub fn parse(asm: &CilObject) -> Option { for (i, entry) in entries.iter().enumerate() { for (j, &token) in entry.tokens.iter().enumerate() { if token != 0 { - vtentry_map.entry(token).or_default().push((i + 1, j + 1)); + let entry_idx = i.checked_add(1)?; + let slot_idx = j.checked_add(1)?; + vtentry_map + .entry(token) + .or_default() + .push((entry_idx, slot_idx)); } } } @@ -202,12 +195,19 @@ pub fn parse(asm: &CilObject) -> Option { .rva .saturating_add(u32::from(entry.count).saturating_mul(slot_size)); if addr >= entry.rva && addr < range_end { - let slot_offset = addr - entry.rva; - if slot_offset % slot_size == 0 { - let slot_idx = (slot_offset / slot_size) as usize; - if let Some(&token) = entry.tokens.get(slot_idx) { - if token != 0 { - export_map.insert(token, (func.ordinal, func.name.clone())); + let Some(slot_offset) = addr.checked_sub(entry.rva) else { + break; + }; + let Some(rem) = slot_offset.checked_rem(slot_size) else { + break; + }; + if rem == 0 { + if let Some(slot_div) = slot_offset.checked_div(slot_size) { + let slot_idx = slot_div as usize; + if let Some(&token) = entry.tokens.get(slot_idx) { + if token != 0 { + export_map.insert(token, (func.ordinal, func.name.clone())); + } } } } diff --git a/dotscope/src/project/loader.rs b/dotscope/src/project/loader.rs index 63a05160..2db87be4 100644 --- a/dotscope/src/project/loader.rs +++ b/dotscope/src/project/loader.rs @@ -256,7 +256,9 @@ impl ProjectLoader { info!( "Project loaded: {}/{} assemblies", result.success_count(), - result.success_count() + result.failure_count() + result + .success_count() + .saturating_add(result.failure_count()) ); Ok(result) } diff --git a/dotscope/src/project/result.rs b/dotscope/src/project/result.rs index 7004fe9e..c6621223 100644 --- a/dotscope/src/project/result.rs +++ b/dotscope/src/project/result.rs @@ -176,14 +176,14 @@ impl ProjectResult { if let Some(identity) = identity { self.loaded_assemblies.push(identity); } - self.loaded_count += 1; + self.loaded_count = self.loaded_count.saturating_add(1); } /// Record a failed assembly load. pub(crate) fn record_failure(&mut self, file_path: String, error_message: String) { self.failed_loads.push((file_path.clone(), error_message)); self.missing_dependencies.push(file_path); - self.failed_count += 1; + self.failed_count = self.failed_count.saturating_add(1); } /// Record a version mismatch between required and actual assembly. diff --git a/dotscope/src/test/analysis/runner.rs b/dotscope/src/test/analysis/runner.rs index 334d4b98..f1c7ea6d 100644 --- a/dotscope/src/test/analysis/runner.rs +++ b/dotscope/src/test/analysis/runner.rs @@ -20,7 +20,7 @@ use crate::{ }, mono::{compilation::compile_debug, Architecture, TestCapabilities}, }, - CilObject, Result, + CilObject, Error, Result, }; /// Path to Mono 4.8 framework assemblies for dependency resolution. @@ -124,13 +124,13 @@ impl AnalysisTestRunner { let capabilities = TestCapabilities::detect(); if !capabilities.can_test() { - return Err(crate::Error::Other( + return Err(Error::Other( "No C# compiler available for analysis tests".to_string(), )); } let temp_dir = tempfile::TempDir::new() - .map_err(|e| crate::Error::Other(format!("Failed to create temp dir: {}", e)))?; + .map_err(|e| Error::Other(format!("Failed to create temp dir: {}", e)))?; Ok(Self { capabilities, @@ -172,7 +172,7 @@ impl AnalysisTestRunner { )?; if !result.is_success() { - return Err(crate::Error::Other(format!( + return Err(Error::Other(format!( "Compilation failed: {}", result.error.unwrap_or_else(|| "Unknown error".to_string()) ))); @@ -212,7 +212,7 @@ impl AnalysisTestRunner { let assembly = project_result .project .get_primary() - .ok_or_else(|| crate::Error::Other("Failed to get primary assembly".to_string()))?; + .ok_or_else(|| Error::Other("Failed to get primary assembly".to_string()))?; self.assembly = Some(assembly.clone()); Ok(assembly) diff --git a/dotscope/src/utils/alignment.rs b/dotscope/src/utils/alignment.rs index 56875afb..eb654b24 100644 --- a/dotscope/src/utils/alignment.rs +++ b/dotscope/src/utils/alignment.rs @@ -59,7 +59,7 @@ /// ``` #[inline] pub fn align_to_4_bytes(value: u64) -> u64 { - (value + 3) & !3 + value.saturating_add(3) & !3 } /// Aligns a value to an arbitrary power-of-2 boundary for PE sections and memory layout. @@ -106,7 +106,8 @@ pub fn align_to(value: u64, alignment: u64) -> u64 { alignment.is_power_of_two(), "alignment must be a power of 2, got {alignment}" ); - (value + alignment - 1) & !(alignment - 1) + let mask = alignment.saturating_sub(1); + value.saturating_add(mask) & !mask } #[cfg(test)] diff --git a/dotscope/src/utils/base64.rs b/dotscope/src/utils/base64.rs index 057be063..6c3a1d43 100644 --- a/dotscope/src/utils/base64.rs +++ b/dotscope/src/utils/base64.rs @@ -30,31 +30,37 @@ pub fn base64_encode(data: &[u8]) -> String { const ALPHABET: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; let mut result = String::new(); - let mut i = 0; + let mut i: usize = 0; while i < data.len() { - let b0 = data[i]; - let b1 = data.get(i + 1).copied().unwrap_or(0); - let b2 = data.get(i + 2).copied().unwrap_or(0); + // Safe: loop condition guarantees i < data.len() + let b0 = data.get(i).copied().unwrap_or(0); + let b1 = data.get(i.saturating_add(1)).copied().unwrap_or(0); + let b2 = data.get(i.saturating_add(2)).copied().unwrap_or(0); let n = u32::from(b0) << 16 | u32::from(b1) << 8 | u32::from(b2); - result.push(char::from(ALPHABET[(n >> 18 & 0x3F) as usize])); - result.push(char::from(ALPHABET[(n >> 12 & 0x3F) as usize])); + // ALPHABET has 64 entries; mask `& 0x3F` keeps the index in [0, 63]. + let idx0 = ((n >> 18) & 0x3F) as usize; + let idx1 = ((n >> 12) & 0x3F) as usize; + result.push(char::from(ALPHABET.get(idx0).copied().unwrap_or(b'A'))); + result.push(char::from(ALPHABET.get(idx1).copied().unwrap_or(b'A'))); - if i + 1 < data.len() { - result.push(char::from(ALPHABET[(n >> 6 & 0x3F) as usize])); + if i.saturating_add(1) < data.len() { + let idx2 = ((n >> 6) & 0x3F) as usize; + result.push(char::from(ALPHABET.get(idx2).copied().unwrap_or(b'A'))); } else { result.push('='); } - if i + 2 < data.len() { - result.push(char::from(ALPHABET[(n & 0x3F) as usize])); + if i.saturating_add(2) < data.len() { + let idx3 = (n & 0x3F) as usize; + result.push(char::from(ALPHABET.get(idx3).copied().unwrap_or(b'A'))); } else { result.push('='); } - i += 3; + i = i.saturating_add(3); } result @@ -104,21 +110,26 @@ pub fn base64_decode(s: &str) -> Option> { return None; } - let mut i = 0; + let mut i: usize = 0; while i < bytes.len() { let mut n: u32 = 0; - let mut pad_count = 0; - - for j in 0..4 { - let b = bytes[i + j]; + let mut pad_count: u32 = 0; + + // Pre-computed shifts so we don't perform `18 - j * 6` arithmetic at runtime. + const SHIFTS: [u32; 4] = [18, 12, 6, 0]; + for (j, &shift) in SHIFTS.iter().enumerate() { + // Safe: i + j < bytes.len() because length is multiple of 4 + // and i is incremented in steps of 4 within `i < bytes.len()`. + let idx = i.checked_add(j)?; + let b = *bytes.get(idx)?; if b == b'=' { - pad_count += 1; + pad_count = pad_count.saturating_add(1); continue; } if b >= 128 { return None; } - let val = DECODE_TABLE[usize::from(b)]; + let val = *DECODE_TABLE.get(usize::from(b))?; if val < 0 { return None; } @@ -126,7 +137,7 @@ pub fn base64_decode(s: &str) -> Option> { // Sign loss is intentional: we've validated val >= 0 #[allow(clippy::cast_sign_loss)] let val_u32 = val as u32; - n |= val_u32 << (18 - j * 6); + n |= val_u32 << shift; } // Truncation to u8 is intentional - we're extracting individual bytes @@ -141,7 +152,7 @@ pub fn base64_decode(s: &str) -> Option> { } } - i += 4; + i = i.saturating_add(4); } Some(result) diff --git a/dotscope/src/utils/bitset.rs b/dotscope/src/utils/bitset.rs index 47e2c345..acc11218 100644 --- a/dotscope/src/utils/bitset.rs +++ b/dotscope/src/utils/bitset.rs @@ -61,7 +61,7 @@ impl BitSet { // Clear the excess bits in the last word if !capacity.is_multiple_of(64) { if let Some(last) = words.last_mut() { - *last = (1u64 << (capacity % 64)) - 1; + *last = (1u64 << (capacity % 64)).saturating_sub(1); } } @@ -95,8 +95,11 @@ impl BitSet { let word = index / 64; let bit = index % 64; let mask = 1u64 << bit; - let was_set = self.words[word] & mask != 0; - self.words[word] |= mask; + let Some(slot) = self.words.get_mut(word) else { + return false; + }; + let was_set = *slot & mask != 0; + *slot |= mask; !was_set } @@ -109,7 +112,9 @@ impl BitSet { assert!(index < self.len, "index out of bounds"); let word = index / 64; let bit = index % 64; - self.words[word] &= !(1u64 << bit); + if let Some(slot) = self.words.get_mut(word) { + *slot &= !(1u64 << bit); + } } /// Returns `true` if the bit at the given index is set. @@ -122,7 +127,9 @@ impl BitSet { assert!(index < self.len, "index out of bounds"); let word = index / 64; let bit = index % 64; - (self.words[word] & (1u64 << bit)) != 0 + self.words + .get(word) + .is_some_and(|w| (w & (1u64 << bit)) != 0) } /// Returns the number of bits set. @@ -146,7 +153,7 @@ impl BitSet { // Clear excess bits in last word if !self.len.is_multiple_of(64) { if let Some(last) = self.words.last_mut() { - *last = (1u64 << (self.len % 64)) - 1; + *last = (1u64 << (self.len % 64)).saturating_sub(1); } } } @@ -231,18 +238,22 @@ impl Iterator for BitSetIter<'_> { fn next(&mut self) -> Option { while self.word_idx < self.set.words.len() { - let word = self.set.words[self.word_idx]; + let word = *self.set.words.get(self.word_idx)?; while self.bit_idx < 64 { - let idx = self.word_idx * 64 + self.bit_idx; + let idx = self + .word_idx + .checked_mul(64) + .and_then(|v| v.checked_add(self.bit_idx))?; if idx >= self.set.len { return None; } - self.bit_idx += 1; - if (word & (1u64 << (self.bit_idx - 1))) != 0 { + let bit = self.bit_idx; + self.bit_idx = self.bit_idx.saturating_add(1); + if (word & (1u64 << bit)) != 0 { return Some(idx); } } - self.word_idx += 1; + self.word_idx = self.word_idx.saturating_add(1); self.bit_idx = 0; } None diff --git a/dotscope/src/utils/crypto.rs b/dotscope/src/utils/crypto.rs index eda04cad..07b4641d 100644 --- a/dotscope/src/utils/crypto.rs +++ b/dotscope/src/utils/crypto.rs @@ -343,12 +343,12 @@ pub fn derive_pbkdf1_key(password: &[u8], salt: &[u8], iterations: u32, key_len: if key_len <= 20 { // Simple case: just return the first key_len bytes - base_key[..key_len].to_vec() + base_key.get(..key_len).unwrap_or(&[]).to_vec() } else { // Extended output using .NET's proprietary extension // For each additional block: SHA1(counter_string || base_key) let mut result = base_key.to_vec(); - let mut counter = 1u32; + let mut counter: u32 = 1; while result.len() < key_len { let mut hasher = sha1::Sha1::new(); @@ -357,7 +357,7 @@ pub fn derive_pbkdf1_key(password: &[u8], salt: &[u8], iterations: u32, key_len: hasher.update(base_key); let block = hasher.finalize(); result.extend_from_slice(&block); - counter += 1; + counter = counter.saturating_add(1); } result.truncate(key_len); @@ -495,10 +495,18 @@ where return None; } if is_encryptor { - let cipher = E::new_from_slices(key, &iv[..block_size]).ok()?; - let padded_len = ((data.len() / block_size) + 1) * block_size; + let iv_slice = iv.get(..block_size)?; + let cipher = E::new_from_slices(key, iv_slice).ok()?; + // block_size is small (8 or 16); checked arithmetic guards against overflow + // for absurdly large `data.len()` from adversarial input. + let padded_len = data + .len() + .checked_div(block_size)? + .checked_add(1)? + .checked_mul(block_size)?; let mut buf = vec![0u8; padded_len]; - buf[..data.len()].copy_from_slice(data); + let buf_prefix = buf.get_mut(..data.len())?; + buf_prefix.copy_from_slice(data); let result = match padding { 1 => cipher .encrypt_padded::(&mut buf, data.len()) @@ -517,7 +525,8 @@ where if data.is_empty() { return None; } - let cipher = D::new_from_slices(key, &iv[..block_size]).ok()?; + let iv_slice = iv.get(..block_size)?; + let cipher = D::new_from_slices(key, iv_slice).ok()?; let mut buf = data.to_vec(); let result = match padding { 1 | 3 => cipher.decrypt_padded::(&mut buf).ok()?, @@ -541,9 +550,16 @@ where { if is_encryptor { let cipher = E::new_from_slice(key).ok()?; - let padded_len = ((data.len() / block_size) + 1) * block_size; + // block_size is small (8 or 16); checked arithmetic guards against overflow + // for absurdly large `data.len()` from adversarial input. + let padded_len = data + .len() + .checked_div(block_size)? + .checked_add(1)? + .checked_mul(block_size)?; let mut buf = vec![0u8; padded_len]; - buf[..data.len()].copy_from_slice(data); + let buf_prefix = buf.get_mut(..data.len())?; + buf_prefix.copy_from_slice(data); let result = match padding { 1 => cipher .encrypt_padded::(&mut buf, data.len()) @@ -679,17 +695,32 @@ pub fn verify_rsa_pkcs1v15( return false; } let mut em = vec![0u8; k]; - em[k - m_bytes.len()..].copy_from_slice(&m_bytes); + // k >= m_bytes.len() was just verified above; subtraction is safe. + let Some(em_offset) = k.checked_sub(m_bytes.len()) else { + return false; + }; + let Some(em_tail) = em.get_mut(em_offset..) else { + return false; + }; + em_tail.copy_from_slice(&m_bytes); // RFC 8017 §9.2: EMSA-PKCS1-v1.5 verification // Expected: 0x00 0x01 [0xFF padding] 0x00 [DigestInfo prefix] [hash] - let t_len = digest_prefix.len() + hash.len(); - if k < t_len + 11 { + let Some(t_len) = digest_prefix.len().checked_add(hash.len()) else { + return false; + }; + let Some(min_k) = t_len.checked_add(11) else { + return false; + }; + if k < min_k { return false; } // Build expected encoding and compare in constant position. - let ps_len = k - t_len - 3; + // k >= t_len + 11 implies k - t_len - 3 >= 8, so this subtraction is safe. + let Some(ps_len) = k.checked_sub(t_len).and_then(|v| v.checked_sub(3)) else { + return false; + }; let mut expected = Vec::with_capacity(k); expected.push(0x00); expected.push(0x01); @@ -755,7 +786,10 @@ pub fn derive_key_iv( salt: &[u8], params: &CryptoParameters, ) -> (Vec, Vec) { - let output_len = params.key_size + params.iv_size; + // Sizes come from configuration extracted via SSA from a decryptor body. + // Saturating addition avoids overflow on absurd values (in which case + // we degrade to empty key/iv rather than panicking). + let output_len = params.key_size.saturating_add(params.iv_size); let derived = derive_pbkdf2_key( password, salt, @@ -763,10 +797,12 @@ pub fn derive_key_iv( output_len, params.hash_algorithm, ); - ( - derived[..params.key_size].to_vec(), - derived[params.key_size..output_len].to_vec(), - ) + let key = derived.get(..params.key_size).unwrap_or(&[]).to_vec(); + let iv = derived + .get(params.key_size..output_len) + .unwrap_or(&[]) + .to_vec(); + (key, iv) } #[cfg(test)] diff --git a/dotscope/src/utils/decompress.rs b/dotscope/src/utils/decompress.rs index df9f0c3a..5c066131 100644 --- a/dotscope/src/utils/decompress.rs +++ b/dotscope/src/utils/decompress.rs @@ -72,7 +72,9 @@ pub fn is_confuserex_lzma(data: &[u8]) -> bool { // LZMA properties byte: encodes lc, lp, pb parameters // Valid range: 0-224 (9 * 5 * 5 - 1) // ConfuserEx typically uses default settings: lc=3, lp=0, pb=2 -> 0x5D - let props_byte = data[0]; + let Some(&props_byte) = data.first() else { + return false; + }; if props_byte > 224 { return false; } @@ -83,27 +85,34 @@ pub fn is_confuserex_lzma(data: &[u8]) -> bool { // Check dictionary size (bytes 1-4 of LZMA properties) // ConfuserEx typically uses 1MB dictionary (0x00100000) - let dict_size = u32::from_le_bytes([data[1], data[2], data[3], data[4]]); + let Some(dict_bytes) = data.get(1..5).and_then(|s| <[u8; 4]>::try_from(s).ok()) else { + return false; + }; + let dict_size = u32::from_le_bytes(dict_bytes); // Dictionary size must be reasonable: 1KB to 16MB // Too small or too large suggests this isn't LZMA - if !(1024..=16 * 1024 * 1024).contains(&dict_size) { + if !(1024u32..=16u32.saturating_mul(1024).saturating_mul(1024)).contains(&dict_size) { return false; } // Bytes 5-8: Uncompressed size (little-endian i32) // For ConfuserEx, this is typically a small positive number (the decrypted constants) - let uncompressed_size = i32::from_le_bytes([data[5], data[6], data[7], data[8]]); + let Some(size_bytes) = data.get(5..9).and_then(|s| <[u8; 4]>::try_from(s).ok()) else { + return false; + }; + let uncompressed_size = i32::from_le_bytes(size_bytes); // Uncompressed size should be positive and reasonable (< 10MB) // Negative or very large sizes indicate this isn't LZMA data - if uncompressed_size <= 0 || uncompressed_size > 10 * 1024 * 1024 { + if uncompressed_size <= 0 || uncompressed_size > 10i32.saturating_mul(1024).saturating_mul(1024) + { return false; } // Additional check: compressed data should be smaller than uncompressed // (otherwise why compress it?) - let compressed_data_len = data.len() - 9; // minus header + let compressed_data_len = data.len().saturating_sub(9); // minus header if compressed_data_len > uncompressed_size.cast_unsigned() as usize { // Compressed larger than uncompressed - suspicious return false; @@ -135,13 +144,17 @@ pub fn decompress_confuserex_lzma(data: &[u8]) -> DecompressResult> { } // Parse header - let props = &data[0..5]; - let uncompressed_size = i32::from_le_bytes([data[5], data[6], data[7], data[8]]); - let compressed = &data[9..]; + let props = data.get(0..5).ok_or(DecompressError::BufferTooSmall)?; + let size_bytes = data + .get(5..9) + .and_then(|s| <[u8; 4]>::try_from(s).ok()) + .ok_or(DecompressError::BufferTooSmall)?; + let uncompressed_size = i32::from_le_bytes(size_bytes); + let compressed = data.get(9..).ok_or(DecompressError::BufferTooSmall)?; // Build LZMA stream header for lzma-rs // lzma-rs expects: 5 bytes props + 8 bytes uncompressed size (little-endian u64) + compressed data - let mut lzma_stream = Vec::with_capacity(13 + compressed.len()); + let mut lzma_stream = Vec::with_capacity(compressed.len().saturating_add(13)); lzma_stream.extend_from_slice(props); // Convert i32 to u64 for lzma-rs format diff --git a/dotscope/src/utils/enums.rs b/dotscope/src/utils/enums.rs index 42fa1583..20d2002a 100644 --- a/dotscope/src/utils/enums.rs +++ b/dotscope/src/utils/enums.rs @@ -194,7 +194,7 @@ impl EnumUtils { return Self::check_enum_inheritance_with_registry_recursive( &resolved_base, registry, - depth + 1, + depth.saturating_add(1), ); } } diff --git a/dotscope/src/utils/graph/algorithms/cycles.rs b/dotscope/src/utils/graph/algorithms/cycles.rs index 242b7a3c..d64d9ad4 100644 --- a/dotscope/src/utils/graph/algorithms/cycles.rs +++ b/dotscope/src/utils/graph/algorithms/cycles.rs @@ -76,18 +76,22 @@ fn has_cycle_dfs( ) -> bool { let idx = node.index(); - if in_stack[idx] { + if in_stack.get(idx).copied().unwrap_or(false) { // Found a back edge - cycle detected return true; } - if visited[idx] { + if visited.get(idx).copied().unwrap_or(false) { // Already processed this node in a different path, no cycle here return false; } - visited[idx] = true; - in_stack[idx] = true; + if let Some(slot) = visited.get_mut(idx) { + *slot = true; + } + if let Some(slot) = in_stack.get_mut(idx) { + *slot = true; + } for successor in graph.successors(node) { if has_cycle_dfs(graph, successor, visited, in_stack) { @@ -95,7 +99,9 @@ fn has_cycle_dfs( } } - in_stack[idx] = false; + if let Some(slot) = in_stack.get_mut(idx) { + *slot = false; + } false } @@ -163,20 +169,24 @@ fn find_cycle_dfs( ) -> Option> { let idx = node.index(); - if in_stack[idx] { + if in_stack.get(idx).copied().unwrap_or(false) { // Found a back edge - extract the cycle let cycle_start_pos = path.iter().position(|&n| n == node)?; - let mut cycle: Vec = path[cycle_start_pos..].to_vec(); + let mut cycle: Vec = path.get(cycle_start_pos..)?.to_vec(); cycle.push(node); // Close the cycle return Some(cycle); } - if visited[idx] { + if visited.get(idx).copied().unwrap_or(false) { return None; } - visited[idx] = true; - in_stack[idx] = true; + if let Some(slot) = visited.get_mut(idx) { + *slot = true; + } + if let Some(slot) = in_stack.get_mut(idx) { + *slot = true; + } path.push(node); for successor in graph.successors(node) { @@ -186,7 +196,9 @@ fn find_cycle_dfs( } path.pop(); - in_stack[idx] = false; + if let Some(slot) = in_stack.get_mut(idx) { + *slot = false; + } None } diff --git a/dotscope/src/utils/graph/algorithms/dominators.rs b/dotscope/src/utils/graph/algorithms/dominators.rs index c5341d4b..619f0bf7 100644 --- a/dotscope/src/utils/graph/algorithms/dominators.rs +++ b/dotscope/src/utils/graph/algorithms/dominators.rs @@ -78,19 +78,16 @@ impl DominatorTree { self.entry } - /// Returns the immediate dominator of a node, or `None` for the entry node. + /// Returns the immediate dominator of a node, or `None` for the entry node + /// or for nodes whose index is out of bounds. /// /// The immediate dominator is the closest strict dominator of the node. - /// - /// # Panics - /// - /// Panics if the node index is out of bounds. #[inline] pub fn immediate_dominator(&self, node: NodeId) -> Option { if node == self.entry { None } else { - Some(self.idom[node.index()]) + self.idom.get(node.index()).copied() } } @@ -115,10 +112,9 @@ impl DominatorTree { let mut current = b; while current != self.entry { // Check for unreachable nodes (sentinel value) or out-of-bounds - if current.index() >= self.node_count { + let Some(&idom) = self.idom.get(current.index()) else { return false; - } - let idom = self.idom[current.index()]; + }; if idom == a { return true; } @@ -171,13 +167,21 @@ impl DominatorTree { /// Returns the depth of a node in the dominator tree. /// - /// The entry node has depth 0. + /// The entry node has depth 0. Returns 0 for nodes whose index is out of + /// bounds or that are unreachable from the entry. pub fn depth(&self, node: NodeId) -> usize { - let mut depth = 0; + let mut depth: usize = 0; let mut current = node; while current != self.entry { - current = self.idom[current.index()]; - depth += 1; + let Some(&idom) = self.idom.get(current.index()) else { + return depth; + }; + // Sentinel idom (== current) means unreachable: stop walking. + if idom == current { + return depth; + } + current = idom; + depth = depth.saturating_add(1); } depth } @@ -190,11 +194,7 @@ impl DominatorTree { /// /// O(1) — children are pre-computed during dominator tree construction. pub fn children(&self, node: NodeId) -> &[NodeId] { - if node.index() < self.children.len() { - &self.children[node.index()] - } else { - &[] - } + self.children.get(node.index()).map_or(&[], Vec::as_slice) } /// Returns the number of nodes in the dominator tree. @@ -220,7 +220,7 @@ impl Iterator for DominatorIterator<'_> { self.current = None; Some(current) } else { - self.current = Some(self.tree.idom[current.index()]); + self.current = self.tree.idom.get(current.index()).copied(); Some(current) } } @@ -312,14 +312,17 @@ where lt.compute(graph, &predecessors); // Build children list in a single O(V) pass from the idom array - let mut children = vec![Vec::new(); node_count]; + let mut children: Vec> = vec![Vec::new(); node_count]; for i in 0..node_count { let node = NodeId::new(i); - if node != entry { - let parent = lt.idom[i]; - if parent.index() < node_count { - children[parent.index()].push(node); - } + if node == entry { + continue; + } + let Some(parent) = lt.idom.get(i).copied() else { + continue; + }; + if let Some(slot) = children.get_mut(parent.index()) { + slot.push(node); } } @@ -346,11 +349,13 @@ where /// Returns a vector where `result[i]` contains all predecessors of node `i`. fn precompute_predecessors(graph: &G) -> Vec> { let n = graph.node_count(); - let mut preds = vec![Vec::new(); n]; + let mut preds: Vec> = vec![Vec::new(); n]; for i in 0..n { let v = NodeId::new(i); for succ in graph.successors(v) { - preds[succ.index()].push(v); + if let Some(slot) = preds.get_mut(succ.index()) { + slot.push(v); + } } } preds @@ -383,6 +388,29 @@ struct LengauerTarjan { } impl LengauerTarjan { + /// Sentinel NodeId representing "uninitialized" or "out-of-graph" values. + #[inline] + fn sentinel() -> NodeId { + NodeId::new(usize::MAX) + } + + #[inline] + fn get_node(slice: &[NodeId], i: usize) -> NodeId { + slice.get(i).copied().unwrap_or(Self::sentinel()) + } + + #[inline] + fn get_dfnum(&self, n: NodeId) -> usize { + self.dfnum.get(n.index()).copied().unwrap_or(0) + } + + #[inline] + fn set_node(slice: &mut [NodeId], i: usize, value: NodeId) { + if let Some(slot) = slice.get_mut(i) { + *slot = value; + } + } + fn new(n: usize, entry: NodeId) -> Self { let sentinel = NodeId::new(usize::MAX); Self { @@ -406,58 +434,69 @@ impl LengauerTarjan { // Process nodes in reverse DFS order (excluding entry) for i in (1..self.dfs_counter).rev() { - let w = self.vertex[i]; - let parent_w = self.parent[w.index()]; + let w = Self::get_node(&self.vertex, i); + let parent_w = Self::get_node(&self.parent, w.index()); // Phase 2: Compute semidominators // semi(w) = min { v : v -> w is a CFG edge and dfnum(v) < dfnum(w) } ∪ // { semi(u) : u -> w via tree edges where dfnum(u) > dfnum(w) } - for v in &predecessors[w.index()] { + let preds_w: &[NodeId] = predecessors.get(w.index()).map_or(&[], Vec::as_slice); + for v in preds_w { let v = *v; - if self.dfnum[v.index()] == 0 { + if self.get_dfnum(v) == 0 { // v is unreachable from entry, skip continue; } let u = self.eval(v); - if self.dfnum[self.semi[u.index()].index()] - < self.dfnum[self.semi[w.index()].index()] - { - self.semi[w.index()] = self.semi[u.index()]; + let semi_u = Self::get_node(&self.semi, u.index()); + let semi_w = Self::get_node(&self.semi, w.index()); + if self.get_dfnum(semi_u) < self.get_dfnum(semi_w) { + Self::set_node(&mut self.semi, w.index(), semi_u); } } // Add w to bucket of its semidominator - let semi_w = self.semi[w.index()]; - self.bucket[semi_w.index()].push(w); + let semi_w = Self::get_node(&self.semi, w.index()); + if let Some(bucket) = self.bucket.get_mut(semi_w.index()) { + bucket.push(w); + } // Link w into the forest self.link(parent_w, w); // Phase 3: Implicitly compute immediate dominators // Process bucket of parent(w) - let bucket = std::mem::take(&mut self.bucket[parent_w.index()]); + let bucket = self + .bucket + .get_mut(parent_w.index()) + .map_or_else(Vec::new, std::mem::take); for v in bucket { let u = self.eval(v); - if self.semi[u.index()] == self.semi[v.index()] { + let semi_u = Self::get_node(&self.semi, u.index()); + let semi_v = Self::get_node(&self.semi, v.index()); + if semi_u == semi_v { // idom(v) = semi(v) = parent(w) - self.idom[v.index()] = parent_w; + Self::set_node(&mut self.idom, v.index(), parent_w); } else { // idom(v) = idom(u) (will be computed later) - self.idom[v.index()] = u; + Self::set_node(&mut self.idom, v.index(), u); } } } // Phase 4: Explicitly compute immediate dominators for i in 1..self.dfs_counter { - let w = self.vertex[i]; - if self.idom[w.index()] != self.semi[w.index()] { - self.idom[w.index()] = self.idom[self.idom[w.index()].index()]; + let w = Self::get_node(&self.vertex, i); + let idom_w = Self::get_node(&self.idom, w.index()); + let semi_w = Self::get_node(&self.semi, w.index()); + if idom_w != semi_w { + let idom_idom = Self::get_node(&self.idom, idom_w.index()); + Self::set_node(&mut self.idom, w.index(), idom_idom); } } // Entry node dominates itself - self.idom[self.entry.index()] = self.entry; + Self::set_node(&mut self.idom, self.entry.index(), self.entry); } /// DFS traversal to assign DFS numbers and build DFS tree. @@ -471,17 +510,21 @@ impl LengauerTarjan { continue; } - if self.dfnum[idx] != 0 { + if self.dfnum.get(idx).copied().unwrap_or(0) != 0 { continue; } - self.dfs_counter += 1; - self.dfnum[idx] = self.dfs_counter; - self.vertex[self.dfs_counter - 1] = node; + self.dfs_counter = self.dfs_counter.saturating_add(1); + if let Some(slot) = self.dfnum.get_mut(idx) { + *slot = self.dfs_counter; + } + // dfs_counter is at least 1 here, so subtracting 1 is safe. + let vertex_idx = self.dfs_counter.saturating_sub(1); + Self::set_node(&mut self.vertex, vertex_idx, node); for succ in graph.successors(node) { - if self.dfnum[succ.index()] == 0 { - self.parent[succ.index()] = node; + if self.get_dfnum(succ) == 0 { + Self::set_node(&mut self.parent, succ.index(), node); stack.push((succ, false)); } } @@ -490,18 +533,18 @@ impl LengauerTarjan { /// Link v as a child of w in the spanning forest. fn link(&mut self, w: NodeId, v: NodeId) { - self.ancestor[v.index()] = w; + Self::set_node(&mut self.ancestor, v.index(), w); } /// Evaluate: find the node with minimum semidominator on the path to the root. fn eval(&mut self, v: NodeId) -> NodeId { - let sentinel = NodeId::new(usize::MAX); - if self.ancestor[v.index()] == sentinel { + let sentinel = Self::sentinel(); + if Self::get_node(&self.ancestor, v.index()) == sentinel { return v; } self.compress(v); - self.best[v.index()] + Self::get_node(&self.best, v.index()) } /// Path compression for the forest (iterative). @@ -511,32 +554,38 @@ impl LengauerTarjan { /// This avoids O(V) recursion depth that can overflow the stack on large /// CFF-obfuscated CFGs (500+ blocks). fn compress(&mut self, v: NodeId) { - let sentinel = NodeId::new(usize::MAX); + let sentinel = Self::sentinel(); // Phase 1: collect the path from v upward until we reach a node // whose ancestor is the forest root (ancestor == sentinel). let mut path = Vec::new(); let mut u = v; - while self.ancestor[self.ancestor[u.index()].index()] != sentinel { + loop { + let anc_u = Self::get_node(&self.ancestor, u.index()); + let anc_anc_u = Self::get_node(&self.ancestor, anc_u.index()); + if anc_anc_u == sentinel { + break; + } path.push(u); - u = self.ancestor[u.index()]; + u = anc_u; } // Phase 2: walk the path in reverse (top-down) to propagate best // values and flatten ancestor pointers — same semantics as the // recursive version's post-order updates. for &node in path.iter().rev() { - let ancestor_node = self.ancestor[node.index()]; - let best_ancestor = self.best[ancestor_node.index()]; - let best_node = self.best[node.index()]; - - if self.dfnum[self.semi[best_ancestor.index()].index()] - < self.dfnum[self.semi[best_node.index()].index()] - { - self.best[node.index()] = best_ancestor; + let ancestor_node = Self::get_node(&self.ancestor, node.index()); + let best_ancestor = Self::get_node(&self.best, ancestor_node.index()); + let best_node = Self::get_node(&self.best, node.index()); + + let semi_ba = Self::get_node(&self.semi, best_ancestor.index()); + let semi_bn = Self::get_node(&self.semi, best_node.index()); + if self.get_dfnum(semi_ba) < self.get_dfnum(semi_bn) { + Self::set_node(&mut self.best, node.index(), best_ancestor); } - self.ancestor[node.index()] = self.ancestor[ancestor_node.index()]; + let new_anc = Self::get_node(&self.ancestor, ancestor_node.index()); + Self::set_node(&mut self.ancestor, node.index(), new_anc); } } } @@ -614,7 +663,9 @@ where let mut runner = pred; // Guard against unreachable nodes (their index may be invalid/sentinel) while Some(runner) != idom_node && runner != dom_tree.entry() && runner.index() < n { - frontiers[runner.index()].insert(node.index()); + if let Some(slot) = frontiers.get_mut(runner.index()) { + slot.insert(node.index()); + } if let Some(idom) = dom_tree.immediate_dominator(runner) { // Check for sentinel value (unreachable node) if idom.index() >= n { @@ -627,7 +678,9 @@ where } // Also check entry if needed (guard against invalid index) if Some(runner) != idom_node && runner == dom_tree.entry() && runner.index() < n { - frontiers[runner.index()].insert(node.index()); + if let Some(slot) = frontiers.get_mut(runner.index()) { + slot.insert(node.index()); + } } } } diff --git a/dotscope/src/utils/graph/algorithms/scc.rs b/dotscope/src/utils/graph/algorithms/scc.rs index 1cfc8d58..ce3965fc 100644 --- a/dotscope/src/utils/graph/algorithms/scc.rs +++ b/dotscope/src/utils/graph/algorithms/scc.rs @@ -95,7 +95,7 @@ where // Run Tarjan's algorithm from each unvisited node for i in 0..node_count { let node = NodeId::new(i); - if state.index[i].is_none() { + if state.index.get(i).copied().flatten().is_none() { state.strongconnect(graph, node); } } @@ -135,35 +135,49 @@ impl TarjanState { let v_idx = v.index(); // Set the depth index for v - self.index[v_idx] = Some(self.current_index); - self.lowlink[v_idx] = self.current_index; - self.current_index += 1; + if let Some(slot) = self.index.get_mut(v_idx) { + *slot = Some(self.current_index); + } + if let Some(slot) = self.lowlink.get_mut(v_idx) { + *slot = self.current_index; + } + self.current_index = self.current_index.saturating_add(1); self.stack.push(v); - self.on_stack[v_idx] = true; + if let Some(slot) = self.on_stack.get_mut(v_idx) { + *slot = true; + } // Consider successors of v for w in graph.successors(v) { let w_idx = w.index(); - if self.index[w_idx].is_none() { + let w_index_visited = self.index.get(w_idx).copied().flatten(); + if w_index_visited.is_none() { // Successor w has not yet been visited; recurse self.strongconnect(graph, w); - self.lowlink[v_idx] = self.lowlink[v_idx].min(self.lowlink[w_idx]); - } else if self.on_stack[w_idx] { + let lw = self.lowlink.get(w_idx).copied().unwrap_or(usize::MAX); + if let Some(slot) = self.lowlink.get_mut(v_idx) { + *slot = (*slot).min(lw); + } + } else if self.on_stack.get(w_idx).copied().unwrap_or(false) { // Successor w is on stack and hence in the current SCC - // Note: index[w] is valid here because w has been visited - if let Some(idx) = self.index[w_idx] { - self.lowlink[v_idx] = self.lowlink[v_idx].min(idx); + if let Some(idx) = w_index_visited { + if let Some(slot) = self.lowlink.get_mut(v_idx) { + *slot = (*slot).min(idx); + } } } } // If v is a root node, pop the stack and generate an SCC - if let Some(idx) = self.index[v_idx] { - if self.lowlink[v_idx] == idx { + let v_index = self.index.get(v_idx).copied().flatten(); + if let Some(idx) = v_index { + if self.lowlink.get(v_idx).copied().unwrap_or(usize::MAX) == idx { let mut scc = Vec::new(); while let Some(w) = self.stack.pop() { - self.on_stack[w.index()] = false; + if let Some(slot) = self.on_stack.get_mut(w.index()) { + *slot = false; + } scc.push(w); if w == v { break; @@ -222,10 +236,12 @@ where let node_count = graph.node_count(); // Build mapping from node to SCC index - let mut node_to_scc = vec![0; node_count]; + let mut node_to_scc = vec![0usize; node_count]; for (scc_idx, scc) in sccs.iter().enumerate() { for &node in scc { - node_to_scc[node.index()] = scc_idx; + if let Some(slot) = node_to_scc.get_mut(node.index()) { + *slot = scc_idx; + } } } @@ -235,10 +251,10 @@ where for i in 0..node_count { let from_node = NodeId::new(i); - let from_scc = node_to_scc[i]; + let from_scc = node_to_scc.get(i).copied().unwrap_or(0); for to_node in graph.successors(from_node) { - let to_scc = node_to_scc[to_node.index()]; + let to_scc = node_to_scc.get(to_node.index()).copied().unwrap_or(0); if from_scc != to_scc && seen_edges.insert((from_scc, to_scc)) { edges.push((from_scc, to_scc)); diff --git a/dotscope/src/utils/graph/algorithms/topological.rs b/dotscope/src/utils/graph/algorithms/topological.rs index b08f3808..1e6cab3d 100644 --- a/dotscope/src/utils/graph/algorithms/topological.rs +++ b/dotscope/src/utils/graph/algorithms/topological.rs @@ -104,14 +104,15 @@ where let mut in_degree: Vec = vec![0; node_count]; for node in graph.node_ids() { for _ in graph.predecessors(node) { - in_degree[node.index()] += 1; + let slot = in_degree.get_mut(node.index())?; + *slot = slot.saturating_add(1); } } // Initialize queue with nodes having in-degree 0 let mut queue: VecDeque = VecDeque::new(); for node in graph.node_ids() { - if in_degree[node.index()] == 0 { + if *in_degree.get(node.index())? == 0 { queue.push_back(node); } } @@ -122,8 +123,9 @@ where result.push(node); for successor in graph.successors(node) { - in_degree[successor.index()] -= 1; - if in_degree[successor.index()] == 0 { + let slot = in_degree.get_mut(successor.index())?; + *slot = slot.saturating_sub(1); + if *slot == 0 { queue.push_back(successor); } } diff --git a/dotscope/src/utils/graph/algorithms/traversal.rs b/dotscope/src/utils/graph/algorithms/traversal.rs index 175e3321..d74915de 100644 --- a/dotscope/src/utils/graph/algorithms/traversal.rs +++ b/dotscope/src/utils/graph/algorithms/traversal.rs @@ -67,7 +67,9 @@ impl<'g, G: Successors> DfsIterator<'g, G> { } let mut visited = vec![false; node_count]; - visited[start.index()] = true; + if let Some(slot) = visited.get_mut(start.index()) { + *slot = true; + } DfsIterator { graph, @@ -90,9 +92,11 @@ impl Iterator for DfsIterator<'_, G> { // so that they are visited in the original order let successors: Vec = self.graph.successors(node).collect(); for &succ in successors.iter().rev() { - if !self.visited[succ.index()] { - self.visited[succ.index()] = true; - self.stack.push(succ); + if let Some(slot) = self.visited.get_mut(succ.index()) { + if !*slot { + *slot = true; + self.stack.push(succ); + } } } @@ -173,7 +177,9 @@ impl<'g, G: Successors> BfsIterator<'g, G> { } let mut visited = vec![false; node_count]; - visited[start.index()] = true; + if let Some(slot) = visited.get_mut(start.index()) { + *slot = true; + } let mut queue = VecDeque::new(); queue.push_back(start); @@ -197,9 +203,11 @@ impl Iterator for BfsIterator<'_, G> { // Enqueue unvisited successors for succ in self.graph.successors(node) { - if !self.visited[succ.index()] { - self.visited[succ.index()] = true; - self.queue.push_back(succ); + if let Some(slot) = self.visited.get_mut(succ.index()) { + if !*slot { + *slot = true; + self.queue.push_back(succ); + } } } @@ -312,10 +320,13 @@ pub fn postorder(graph: &G, start: NodeId) -> Vec { while let Some((node, state)) = stack.pop() { match state { State::Enter => { - if visited[node.index()] { + let Some(slot) = visited.get_mut(node.index()) else { + continue; + }; + if *slot { continue; } - visited[node.index()] = true; + *slot = true; // Push exit state for this node (will be processed after children) stack.push((node, State::Exit)); @@ -323,7 +334,7 @@ pub fn postorder(graph: &G, start: NodeId) -> Vec { // Push children in reverse order so they're processed in order let successors: Vec = graph.successors(node).collect(); for &succ in successors.iter().rev() { - if !visited[succ.index()] { + if let Some(false) = visited.get(succ.index()).copied() { stack.push((succ, State::Enter)); } } diff --git a/dotscope/src/utils/graph/directed.rs b/dotscope/src/utils/graph/directed.rs index 73343aac..60b93175 100644 --- a/dotscope/src/utils/graph/directed.rs +++ b/dotscope/src/utils/graph/directed.rs @@ -457,7 +457,7 @@ impl<'a, N: Clone, E> DirectedGraph<'a, N, E> { /// /// # Errors /// - /// Returns [`Error::GraphError`] if either `source` or `target` node does not exist + /// Returns [`crate::Error::GraphError`] if either `source` or `target` node does not exist /// in the graph. pub fn add_edge(&mut self, source: NodeId, target: NodeId, data: E) -> Result { if source.index() >= self.nodes.len() { @@ -482,8 +482,22 @@ impl<'a, N: Clone, E> DirectedGraph<'a, N, E> { data, }); - self.outgoing[source.index()].push(id); - self.incoming[target.index()].push(id); + self.outgoing + .get_mut(source.index()) + .ok_or_else(|| { + Error::GraphError(format!( + "outgoing adjacency missing for source node {source}" + )) + })? + .push(id); + self.incoming + .get_mut(target.index()) + .ok_or_else(|| { + Error::GraphError(format!( + "incoming adjacency missing for target node {target}" + )) + })? + .push(id); Ok(id) } @@ -636,9 +650,11 @@ impl<'a, N: Clone, E> DirectedGraph<'a, N, E> { /// assert_eq!(successors.len(), 2); /// ``` pub fn successors(&self, node: NodeId) -> impl Iterator + '_ { - self.outgoing[node.index()] - .iter() - .map(|&edge_id| self.edges[edge_id.index()].target) + self.outgoing + .get(node.index()) + .into_iter() + .flatten() + .filter_map(|&edge_id| self.edges.get(edge_id.index()).map(|e| e.target)) } /// Returns an iterator over the predecessors of the given node. @@ -674,9 +690,11 @@ impl<'a, N: Clone, E> DirectedGraph<'a, N, E> { /// assert_eq!(predecessors.len(), 2); /// ``` pub fn predecessors(&self, node: NodeId) -> impl Iterator + '_ { - self.incoming[node.index()] - .iter() - .map(|&edge_id| self.edges[edge_id.index()].source) + self.incoming + .get(node.index()) + .into_iter() + .flatten() + .filter_map(|&edge_id| self.edges.get(edge_id.index()).map(|e| e.source)) } /// Returns an iterator over outgoing edges from the given node. @@ -712,9 +730,11 @@ impl<'a, N: Clone, E> DirectedGraph<'a, N, E> { /// } /// ``` pub fn outgoing_edges(&self, node: NodeId) -> impl Iterator + '_ { - self.outgoing[node.index()] - .iter() - .map(|&edge_id| (edge_id, &self.edges[edge_id.index()].data)) + self.outgoing + .get(node.index()) + .into_iter() + .flatten() + .filter_map(|&edge_id| self.edges.get(edge_id.index()).map(|e| (edge_id, &e.data))) } /// Returns an iterator over incoming edges to the given node. @@ -734,9 +754,11 @@ impl<'a, N: Clone, E> DirectedGraph<'a, N, E> { /// /// Panics if `node` is not a valid node in the graph. pub fn incoming_edges(&self, node: NodeId) -> impl Iterator + '_ { - self.incoming[node.index()] - .iter() - .map(|&edge_id| (edge_id, &self.edges[edge_id.index()].data)) + self.incoming + .get(node.index()) + .into_iter() + .flatten() + .filter_map(|&edge_id| self.edges.get(edge_id.index()).map(|e| (edge_id, &e.data))) } /// Returns the out-degree (number of outgoing edges) of a node. @@ -771,7 +793,7 @@ impl<'a, N: Clone, E> DirectedGraph<'a, N, E> { /// ``` #[must_use] pub fn out_degree(&self, node: NodeId) -> usize { - self.outgoing[node.index()].len() + self.outgoing.get(node.index()).map_or(0, Vec::len) } /// Returns the in-degree (number of incoming edges) of a node. @@ -806,7 +828,7 @@ impl<'a, N: Clone, E> DirectedGraph<'a, N, E> { /// ``` #[must_use] pub fn in_degree(&self, node: NodeId) -> usize { - self.incoming[node.index()].len() + self.incoming.get(node.index()).map_or(0, Vec::len) } /// Returns `true` if the graph contains no nodes. @@ -946,18 +968,22 @@ impl GraphBase for DirectedGraph<'_, N, E> { // Implement the Successors trait impl Successors for DirectedGraph<'_, N, E> { fn successors(&self, node: NodeId) -> impl Iterator { - self.outgoing[node.index()] - .iter() - .map(|&edge_id| self.edges[edge_id.index()].target) + self.outgoing + .get(node.index()) + .into_iter() + .flatten() + .filter_map(|&edge_id| self.edges.get(edge_id.index()).map(|e| e.target)) } } // Implement the Predecessors trait impl Predecessors for DirectedGraph<'_, N, E> { fn predecessors(&self, node: NodeId) -> impl Iterator { - self.incoming[node.index()] - .iter() - .map(|&edge_id| self.edges[edge_id.index()].source) + self.incoming + .get(node.index()) + .into_iter() + .flatten() + .filter_map(|&edge_id| self.edges.get(edge_id.index()).map(|e| e.source)) } } diff --git a/dotscope/src/utils/io.rs b/dotscope/src/utils/io.rs index e48c0da5..081f4a1f 100644 --- a/dotscope/src/utils/io.rs +++ b/dotscope/src/utils/io.rs @@ -536,15 +536,13 @@ pub fn read_le(data: &[u8]) -> Result { /// Note that the offset parameter is modified, so each thread should use its own offset variable. pub fn read_le_at(data: &[u8], offset: &mut usize) -> Result { let type_len = std::mem::size_of::(); - if (type_len + *offset) > data.len() { - return Err(out_of_bounds_error!()); - } - - let Ok(read) = data[*offset..*offset + type_len].try_into() else { + let end = offset.checked_add(type_len).ok_or(out_of_bounds_error!())?; + let slice = data.get(*offset..end).ok_or(out_of_bounds_error!())?; + let Ok(read) = slice.try_into() else { return Err(out_of_bounds_error!()); }; - *offset += type_len; + *offset = end; Ok(T::from_le_bytes(read)) } @@ -671,15 +669,13 @@ pub fn read_be(data: &[u8]) -> Result { /// Note that the offset parameter is modified, so each thread should use its own offset variable. pub fn read_be_at(data: &[u8], offset: &mut usize) -> Result { let type_len = std::mem::size_of::(); - if (type_len + *offset) > data.len() { - return Err(out_of_bounds_error!()); - } - - let Ok(read) = data[*offset..*offset + type_len].try_into() else { + let end = offset.checked_add(type_len).ok_or(out_of_bounds_error!())?; + let slice = data.get(*offset..end).ok_or(out_of_bounds_error!())?; + let Ok(read) = slice.try_into() else { return Err(out_of_bounds_error!()); }; - *offset += type_len; + *offset = end; Ok(T::from_be_bytes(read)) } @@ -809,13 +805,11 @@ pub fn write_le(data: &mut [u8], value: T) -> Result<()> { /// Note that the offset parameter is modified, so each thread should use its own offset variable. pub fn write_le_at(data: &mut [u8], offset: &mut usize, value: T) -> Result<()> { let type_len = std::mem::size_of::(); - if (type_len + *offset) > data.len() { - return Err(out_of_bounds_error!()); - } - + let end = offset.checked_add(type_len).ok_or(out_of_bounds_error!())?; let bytes = value.to_le_bytes(); - data[*offset..*offset + type_len].copy_from_slice(bytes.as_ref()); - *offset += type_len; + let dst = data.get_mut(*offset..end).ok_or(out_of_bounds_error!())?; + dst.copy_from_slice(bytes.as_ref()); + *offset = end; Ok(()) } @@ -952,13 +946,11 @@ pub fn write_be(data: &mut [u8], value: T) -> Result<()> { /// Note that the offset parameter is modified, so each thread should use its own offset variable. pub fn write_be_at(data: &mut [u8], offset: &mut usize, value: T) -> Result<()> { let type_len = std::mem::size_of::(); - if (type_len + *offset) > data.len() { - return Err(out_of_bounds_error!()); - } - + let end = offset.checked_add(type_len).ok_or(out_of_bounds_error!())?; let bytes = value.to_be_bytes(); - data[*offset..*offset + type_len].copy_from_slice(bytes.as_ref()); - *offset += type_len; + let dst = data.get_mut(*offset..end).ok_or(out_of_bounds_error!())?; + dst.copy_from_slice(bytes.as_ref()); + *offset = end; Ok(()) } @@ -1094,9 +1086,12 @@ pub fn write_compressed_uint(value: u32, buffer: &mut Vec) { #[allow(clippy::cast_sign_loss)] pub fn write_compressed_int(value: i32, buffer: &mut Vec) { let unsigned_value = if value >= 0 { - (value as u32) << 1 + (value as u32).wrapping_shl(1) } else { - (((-value - 1) as u32) << 1) | 1 + // value is negative, so -value-1 fits in u32 without overflow + // (covers i32::MIN: -i32::MIN - 1 == i32::MAX, fits in u32) + let magnitude = value.wrapping_neg().wrapping_sub(1) as u32; + magnitude.wrapping_shl(1) | 1 }; write_compressed_uint(unsigned_value, buffer); } @@ -1201,7 +1196,9 @@ pub fn write_prefixed_string_utf8(value: &str, buffer: &mut Vec) { #[allow(clippy::cast_possible_truncation)] pub fn write_prefixed_string_utf16(value: &str, buffer: &mut Vec) { let utf16_chars: Vec = value.encode_utf16().collect(); - let byte_length = utf16_chars.len() * 2; + // saturating_mul: the byte length is then bounded; the 7-bit prefix is u32 + // and overflow only matters for >2GiB strings, which we cannot encode anyway. + let byte_length = utf16_chars.len().saturating_mul(2); write_7bit_encoded_int(byte_length as u32, buffer); @@ -1223,18 +1220,14 @@ pub fn write_prefixed_string_utf16(value: &str, buffer: &mut Vec) { /// only a null terminator. #[must_use] pub fn decode_utf16le(bytes: &[u8]) -> Option { - if bytes.len() < 2 { - return None; - } let mut utf16_chars = Vec::new(); - let mut i = 0; - while i + 1 < bytes.len() { - let ch = u16::from_le_bytes([bytes[i], bytes[i + 1]]); + for pair in bytes.chunks_exact(2) { + let arr: [u8; 2] = pair.try_into().ok()?; + let ch = u16::from_le_bytes(arr); if ch == 0 { break; } utf16_chars.push(ch); - i += 2; } if utf16_chars.is_empty() { return None; @@ -1278,20 +1271,19 @@ pub fn decode_utf16le(bytes: &[u8]) -> Option { /// Note that the offset parameter is modified, so each thread should use its own offset variable. pub fn write_string_at(data: &mut [u8], offset: &mut usize, value: &str) -> Result<()> { let string_bytes = value.as_bytes(); - let total_length = string_bytes.len() + 1; // +1 for null terminator + let after_str = offset + .checked_add(string_bytes.len()) + .ok_or(out_of_bounds_error!())?; + let after_null = after_str.checked_add(1).ok_or(out_of_bounds_error!())?; - // Check bounds - if *offset + total_length > data.len() { - return Err(out_of_bounds_error!()); - } + let dst = data + .get_mut(*offset..after_str) + .ok_or(out_of_bounds_error!())?; + dst.copy_from_slice(string_bytes); - // Write string bytes - data[*offset..*offset + string_bytes.len()].copy_from_slice(string_bytes); - *offset += string_bytes.len(); + *data.get_mut(after_str).ok_or(out_of_bounds_error!())? = 0; - // Write null terminator - data[*offset] = 0; - *offset += 1; + *offset = after_null; Ok(()) } @@ -1336,35 +1328,32 @@ pub fn write_string_at(data: &mut [u8], offset: &mut usize, value: &str) -> Resu /// # Ok::<(), dotscope::Error>(()) /// ``` pub fn read_compressed_int(data: &[u8], offset: &mut usize) -> Result<(usize, usize)> { - if *offset >= data.len() { - return Err(out_of_bounds_error!()); - } - - let first_byte = data[*offset]; + let first_byte = *data.get(*offset).ok_or(out_of_bounds_error!())?; if first_byte & 0x80 == 0 { // Single byte: 0xxxxxxx - *offset += 1; + *offset = offset.checked_add(1).ok_or(out_of_bounds_error!())?; Ok((first_byte as usize, 1)) } else if first_byte & 0xC0 == 0x80 { // Two bytes: 10xxxxxx xxxxxxxx - if *offset + 1 >= data.len() { - return Err(out_of_bounds_error!()); - } - let second_byte = data[*offset + 1]; + let next = offset.checked_add(1).ok_or(out_of_bounds_error!())?; + let second_byte = *data.get(next).ok_or(out_of_bounds_error!())?; let value = (((first_byte & 0x3F) as usize) << 8) | (second_byte as usize); - *offset += 2; + *offset = offset.checked_add(2).ok_or(out_of_bounds_error!())?; Ok((value, 2)) } else { // Four bytes: 110xxxxx xxxxxxxx xxxxxxxx xxxxxxxx - if *offset + 3 >= data.len() { - return Err(out_of_bounds_error!()); - } + let o1 = offset.checked_add(1).ok_or(out_of_bounds_error!())?; + let o2 = offset.checked_add(2).ok_or(out_of_bounds_error!())?; + let o3 = offset.checked_add(3).ok_or(out_of_bounds_error!())?; + let b1 = *data.get(o1).ok_or(out_of_bounds_error!())?; + let b2 = *data.get(o2).ok_or(out_of_bounds_error!())?; + let b3 = *data.get(o3).ok_or(out_of_bounds_error!())?; let mut value = ((first_byte & 0x1F) as usize) << 24; - value |= (data[*offset + 1] as usize) << 16; - value |= (data[*offset + 2] as usize) << 8; - value |= data[*offset + 3] as usize; - *offset += 4; + value |= (b1 as usize) << 16; + value |= (b2 as usize) << 8; + value |= b3 as usize; + *offset = offset.checked_add(4).ok_or(out_of_bounds_error!())?; Ok((value, 4)) } } @@ -1416,7 +1405,7 @@ pub fn read_compressed_int_at(data: &[u8], offset: usize) -> Result<(usize, usiz /// (`0xFF`), empty input, or truncated data. #[must_use] pub fn read_packed_len(data: &[u8]) -> Option<(usize, usize)> { - if data.is_empty() || data[0] == 0xFF { + if *data.first()? == 0xFF { return None; } read_compressed_int_at(data, 0).ok() diff --git a/dotscope/src/utils/lebytes.rs b/dotscope/src/utils/lebytes.rs index 110e4971..88758706 100644 --- a/dotscope/src/utils/lebytes.rs +++ b/dotscope/src/utils/lebytes.rs @@ -83,7 +83,9 @@ impl LeBytes { /// Returns the bytes as a slice. #[inline] pub fn as_slice(&self) -> &[u8] { - &self.buf[..self.len as usize] + // SAFETY-equivalent: `self.len` is always <= 8 (clamped at construction). + let len = (self.len as usize).min(self.buf.len()); + self.buf.get(..len).unwrap_or(&[]) } /// Returns the number of bytes. @@ -152,8 +154,8 @@ impl Iterator for LeBytesIter { #[inline] fn next(&mut self) -> Option { if self.pos < self.bytes.len { - let b = self.bytes.buf[self.pos as usize]; - self.pos += 1; + let b = *self.bytes.buf.get(self.pos as usize)?; + self.pos = self.pos.saturating_add(1u8); Some(b) } else { None @@ -162,7 +164,7 @@ impl Iterator for LeBytesIter { #[inline] fn size_hint(&self) -> (usize, Option) { - let remaining = (self.bytes.len - self.pos) as usize; + let remaining = self.bytes.len.saturating_sub(self.pos) as usize; (remaining, Some(remaining)) } } diff --git a/dotscope/src/utils/synchronization.rs b/dotscope/src/utils/synchronization.rs index c706dfca..843ebc2c 100644 --- a/dotscope/src/utils/synchronization.rs +++ b/dotscope/src/utils/synchronization.rs @@ -170,7 +170,10 @@ impl FailFastBarrier { } } - let arrived_count = self.arrived.fetch_add(1, Ordering::AcqRel) + 1; + let arrived_count = self + .arrived + .fetch_add(1, Ordering::AcqRel) + .saturating_add(1); if arrived_count == self.count { // Last thread to arrive - wake everyone up diff --git a/dotscope/src/utils/visitedmap.rs b/dotscope/src/utils/visitedmap.rs index 29bae623..4c60a537 100644 --- a/dotscope/src/utils/visitedmap.rs +++ b/dotscope/src/utils/visitedmap.rs @@ -95,10 +95,11 @@ pub struct VisitedMap { data: Vec, /// Number of byte positions that can be tracked elements: usize, - /// Size of each bitfield element in bits - bitfield_size: usize, } +/// Size of each bitfield element in bits (always non-zero on supported targets). +const BITFIELD_SIZE: usize = std::mem::size_of::() * 8; + impl VisitedMap { /// Creates a new [`crate::utils::VisitedMap`] for tracking the specified number of bytes. /// @@ -121,19 +122,14 @@ impl VisitedMap { /// # Ok::<(), dotscope::Error>(()) /// ``` pub fn new(elements: usize) -> VisitedMap { - let bitfield_size = std::mem::size_of::() * 8; - let num_bitfields = elements.div_ceil(bitfield_size); + let num_bitfields = elements.div_ceil(BITFIELD_SIZE); let mut data = Vec::with_capacity(num_bitfields); for _ in 0..num_bitfields { data.push(AtomicUsize::new(0)); } - VisitedMap { - data, - elements, - bitfield_size, - } + VisitedMap { data, elements } } /// Returns the maximum number of elements this instance can track. @@ -210,8 +206,8 @@ impl VisitedMap { return false; } - if let Some(bitfield) = self.data.get(element / self.bitfield_size) { - let shift_amount = u32::try_from(element % self.bitfield_size).unwrap_or(0); + if let Some(bitfield) = self.data.get(element / BITFIELD_SIZE) { + let shift_amount = u32::try_from(element % BITFIELD_SIZE).unwrap_or(0); let current_value = bitfield.load(Ordering::Acquire); return (current_value.wrapping_shr(shift_amount) & 1_usize) != 0; } @@ -259,15 +255,18 @@ impl VisitedMap { let mut counter = 0; - while let Some(bitfield) = self.data.get((element + counter) / self.bitfield_size) { + while let Some(bitfield) = self + .data + .get(element.saturating_add(counter) / BITFIELD_SIZE) + { let current_value = bitfield.load(Ordering::Acquire); if current_value == usize::MAX { - counter += self.bitfield_size; + counter = counter.saturating_add(BITFIELD_SIZE); } else { let shift_amount = - u32::try_from((element + counter) % self.bitfield_size).unwrap_or(0); + u32::try_from(element.saturating_add(counter) % BITFIELD_SIZE).unwrap_or(0); if (current_value.wrapping_shr(shift_amount) & 1_usize) == 0 { - counter += 1; + counter = counter.saturating_add(1); } else { break; } @@ -313,17 +312,17 @@ impl VisitedMap { pub fn get_first(&self, visited: bool) -> usize { let mut counter = 0; - while let Some(bitfield) = self.data.get(counter / self.bitfield_size) { + while let Some(bitfield) = self.data.get(counter / BITFIELD_SIZE) { let current_value = bitfield.load(Ordering::Acquire); if visited { if current_value == usize::MAX { return counter; } else if current_value == 0 { - counter += self.bitfield_size; + counter = counter.saturating_add(BITFIELD_SIZE); } else { - let shift_amount = u32::try_from(counter % self.bitfield_size).unwrap_or(0); + let shift_amount = u32::try_from(counter % BITFIELD_SIZE).unwrap_or(0); if (current_value.wrapping_shr(shift_amount) & 1_usize) == 0 { - counter += 1; + counter = counter.saturating_add(1); } else { return counter; } @@ -331,13 +330,13 @@ impl VisitedMap { } else if current_value == 0 { return counter; } else if current_value == usize::MAX { - counter += self.bitfield_size; + counter = counter.saturating_add(BITFIELD_SIZE); } else if (current_value - .wrapping_shr(u32::try_from(counter % self.bitfield_size).unwrap_or(0)) + .wrapping_shr(u32::try_from(counter % BITFIELD_SIZE).unwrap_or(0)) & 1_usize) != 0 { - counter += 1; + counter = counter.saturating_add(1); } else { return counter; } @@ -411,28 +410,32 @@ impl VisitedMap { /// assert!(!visited_map.get(15)); /// ``` pub fn set_range(&self, element: usize, state: bool, len: usize) { - if element > self.elements || (element + len) > self.elements { + let Some(end) = element.checked_add(len) else { + debug_assert!(false, "Invalid element!"); + return; + }; + if element > self.elements || end > self.elements { debug_assert!(false, "Invalid element!"); return; } let mut counter = 0; while counter < len { - let current_pos = element + counter; - let bit_in_field = current_pos % self.bitfield_size; - let remaining = len - counter; + let current_pos = element.saturating_add(counter); + let bit_in_field = current_pos % BITFIELD_SIZE; + let remaining = len.saturating_sub(counter); - if let Some(bitfield) = self.data.get(current_pos / self.bitfield_size) { + if let Some(bitfield) = self.data.get(current_pos / BITFIELD_SIZE) { // Only use bulk set if: // 1. We're at a bitfield boundary (bit_in_field == 0), AND // 2. We have at least one full bitfield worth of bits remaining - if bit_in_field == 0 && remaining >= self.bitfield_size { + if bit_in_field == 0 && remaining >= BITFIELD_SIZE { if state { bitfield.store(usize::MAX, Ordering::Release); } else { bitfield.store(0, Ordering::Release); } - counter += self.bitfield_size; + counter = counter.saturating_add(BITFIELD_SIZE); } else { let shift_amount = u32::try_from(bit_in_field).unwrap_or(0); let bit_mask = 1_usize.wrapping_shl(shift_amount); @@ -443,7 +446,7 @@ impl VisitedMap { bitfield.fetch_and(!bit_mask, Ordering::AcqRel); } - counter += 1; + counter = counter.saturating_add(1); } } else { debug_assert!(false); From 8620c286f99fce45115480fe889245bdcbacbc7c Mon Sep 17 00:00:00 2001 From: BinFlip Date: Sat, 9 May 2026 07:00:48 -0700 Subject: [PATCH 2/6] feat: extracted ssa functionality into standalone 'analyssa' crate --- Cargo.lock | 72 +- dotscope/Cargo.toml | 15 + dotscope/benches/assembly.rs | 422 +-- dotscope/benches/cilassemblyview.rs | 23 +- dotscope/benches/cilobject.rs | 20 +- dotscope/benches/cor20.rs | 16 +- dotscope/benches/method_body.rs | 135 +- dotscope/benches/resources.rs | 20 +- dotscope/benches/security.rs | 31 +- dotscope/benches/signatures.rs | 87 +- dotscope/benches/streams.rs | 149 +- dotscope/examples/analysis.rs | 32 +- dotscope/examples/basic.rs | 13 +- dotscope/examples/comprehensive.rs | 47 +- dotscope/examples/decode_blocks.rs | 25 +- dotscope/examples/deobfuscate.rs | 8 +- dotscope/examples/disassembly.rs | 102 +- dotscope/examples/injectcode.rs | 19 +- dotscope/examples/lowlevel.rs | 52 +- dotscope/examples/metadata.rs | 107 +- dotscope/examples/method_analysis.rs | 57 +- dotscope/examples/modify.rs | 13 +- dotscope/examples/project_loader.rs | 36 +- dotscope/examples/raw_assembly_view.rs | 4 +- dotscope/examples/types.rs | 135 +- dotscope/src/analysis/algebraic.rs | 458 --- dotscope/src/analysis/callgraph/graph.rs | 12 +- dotscope/src/analysis/cfg/analyzer.rs | 330 -- dotscope/src/analysis/cfg/graph.rs | 17 +- dotscope/src/analysis/cfg/loops.rs | 936 ----- dotscope/src/analysis/cfg/mod.rs | 28 +- dotscope/src/analysis/cfg/semantics.rs | 15 +- dotscope/src/analysis/dataflow/framework.rs | 286 -- dotscope/src/analysis/dataflow/lattice.rs | 302 -- dotscope/src/analysis/dataflow/liveness.rs | 336 -- dotscope/src/analysis/dataflow/mod.rs | 104 +- dotscope/src/analysis/dataflow/reaching.rs | 274 -- dotscope/src/analysis/dataflow/sccp.rs | 767 ---- dotscope/src/analysis/dataflow/solver.rs | 403 --- dotscope/src/analysis/defuse.rs | 979 ----- dotscope/src/analysis/mod.rs | 40 +- dotscope/src/analysis/range.rs | 1285 ------- dotscope/src/analysis/ssa/block.rs | 1073 ------ dotscope/src/analysis/ssa/builder.rs | 65 +- dotscope/src/analysis/ssa/cfg.rs | 466 --- dotscope/src/analysis/ssa/constraints.rs | 251 -- dotscope/src/analysis/ssa/consts.rs | 520 --- dotscope/src/analysis/ssa/converter.rs | 42 +- dotscope/src/analysis/ssa/decompose.rs | 43 +- dotscope/src/analysis/ssa/evaluator.rs | 3098 ---------------- dotscope/src/analysis/ssa/exception.rs | 248 +- dotscope/src/analysis/ssa/function.rs | 313 ++ .../src/analysis/ssa/function/canonical.rs | 430 --- .../src/analysis/ssa/function/duplication.rs | 318 -- dotscope/src/analysis/ssa/function/mod.rs | 1610 --------- dotscope/src/analysis/ssa/function/queries.rs | 1225 ------- dotscope/src/analysis/ssa/function/rebuild.rs | 2118 ----------- dotscope/src/analysis/ssa/function/repair.rs | 192 - .../src/analysis/ssa/function/semantics.rs | 115 - .../src/analysis/ssa/function/transforms.rs | 1291 ------- dotscope/src/analysis/ssa/instruction.rs | 425 --- dotscope/src/analysis/ssa/liveness.rs | 241 -- dotscope/src/analysis/ssa/memory.rs | 1126 ------ dotscope/src/analysis/ssa/mod.rs | 128 +- dotscope/src/analysis/ssa/ops.rs | 3148 +---------------- dotscope/src/analysis/ssa/patterns.rs | 672 ---- dotscope/src/analysis/ssa/phi.rs | 435 --- dotscope/src/analysis/ssa/phis.rs | 833 ----- dotscope/src/analysis/ssa/resolver.rs | 106 +- .../src/analysis/ssa/symbolic/evaluator.rs | 369 -- dotscope/src/analysis/ssa/symbolic/expr.rs | 1134 ------ dotscope/src/analysis/ssa/symbolic/mod.rs | 30 +- dotscope/src/analysis/ssa/symbolic/ops.rs | 159 - dotscope/src/analysis/ssa/symbolic/solver.rs | 12 +- dotscope/src/analysis/ssa/target.rs | 559 +++ dotscope/src/analysis/ssa/types.rs | 10 +- dotscope/src/analysis/ssa/value.rs | 1815 +--------- dotscope/src/analysis/ssa/variable.rs | 713 ---- dotscope/src/analysis/ssa/verifier.rs | 1011 ------ dotscope/src/analysis/taint.rs | 729 +--- dotscope/src/analysis/x86/cfg.rs | 30 +- dotscope/src/analysis/x86/decoder.rs | 4 +- dotscope/src/analysis/x86/mod.rs | 14 +- dotscope/src/analysis/x86/ssa.rs | 88 +- dotscope/src/cilassembly/changes/heap.rs | 12 - dotscope/src/cilassembly/writer/fields.rs | 2 +- dotscope/src/cilassembly/writer/tables.rs | 8 - dotscope/src/compiler/codegen/coalescing.rs | 3 +- dotscope/src/compiler/codegen/mod.rs | 40 +- dotscope/src/compiler/codegen/tests.rs | 23 + dotscope/src/compiler/context.rs | 52 +- dotscope/src/compiler/events.rs | 903 ----- dotscope/src/compiler/host.rs | 181 + dotscope/src/compiler/mod.rs | 20 +- dotscope/src/compiler/pass.rs | 342 +- dotscope/src/compiler/passes/algebraic.rs | 310 -- dotscope/src/compiler/passes/blockmerge.rs | 908 ----- dotscope/src/compiler/passes/constants/mod.rs | 52 +- .../src/compiler/passes/constants/tests.rs | 110 +- dotscope/src/compiler/passes/controlflow.rs | 872 ----- dotscope/src/compiler/passes/copying.rs | 1046 +----- dotscope/src/compiler/passes/deadcode.rs | 1714 +-------- dotscope/src/compiler/passes/gvn.rs | 418 --- dotscope/src/compiler/passes/inlining.rs | 22 +- dotscope/src/compiler/passes/licm.rs | 704 ---- dotscope/src/compiler/passes/loopcanon.rs | 694 ---- dotscope/src/compiler/passes/mod.rs | 39 +- dotscope/src/compiler/passes/predicates.rs | 2308 ------------ dotscope/src/compiler/passes/proxy.rs | 46 +- dotscope/src/compiler/passes/ranges.rs | 1012 ------ dotscope/src/compiler/passes/reassociate.rs | 585 --- dotscope/src/compiler/passes/strength.rs | 406 +-- dotscope/src/compiler/passes/threading.rs | 545 --- dotscope/src/compiler/passes/utils.rs | 109 - dotscope/src/compiler/scheduler.rs | 1055 +----- dotscope/src/compiler/summary.rs | 13 +- dotscope/src/deobfuscation/context.rs | 14 +- dotscope/src/deobfuscation/engine/api.rs | 3 +- dotscope/src/deobfuscation/engine/pipeline.rs | 34 +- dotscope/src/deobfuscation/engine/tests.rs | 36 +- .../src/deobfuscation/passes/antidebug.rs | 26 +- .../deobfuscation/passes/bitmono/strings.rs | 17 +- .../deobfuscation/passes/bitmono/unmanaged.rs | 10 +- .../src/deobfuscation/passes/decryption.rs | 61 +- .../src/deobfuscation/passes/delegates.rs | 30 +- .../deobfuscation/passes/jiejienet/arrays.rs | 10 +- .../passes/jiejienet/resources.rs | 10 +- .../deobfuscation/passes/jiejienet/typeofs.rs | 20 +- .../passes/netreactor/resolver.rs | 10 +- .../passes/netreactor/rewrite.rs | 35 +- .../src/deobfuscation/passes/neutralize.rs | 18 +- .../src/deobfuscation/passes/opaquefields.rs | 34 +- .../src/deobfuscation/passes/reflection.rs | 25 +- .../src/deobfuscation/passes/staticfields.rs | 24 +- .../passes/unflattening/detection.rs | 22 +- .../deobfuscation/passes/unflattening/mod.rs | 39 +- .../passes/unflattening/reconstruction.rs | 5 +- .../passes/unflattening/tracer/context.rs | 10 +- .../passes/unflattening/tracer/helpers.rs | 2 +- .../passes/unflattening/tracer/types.rs | 3 +- dotscope/src/deobfuscation/renamer/cascade.rs | 42 +- dotscope/src/deobfuscation/renamer/config.rs | 5 +- dotscope/src/deobfuscation/renamer/mod.rs | 4 +- dotscope/src/deobfuscation/renamer/phases.rs | 29 +- dotscope/src/deobfuscation/renamer/prompt.rs | 13 +- dotscope/src/deobfuscation/statemachine.rs | 15 +- .../deobfuscation/techniques/bitmono/debug.rs | 8 +- .../techniques/bitmono/strings.rs | 16 +- .../techniques/bitmono/unmanaged.rs | 6 +- .../techniques/confuserex/debug.rs | 5 +- .../techniques/confuserex/statemachine.rs | 2 +- .../src/deobfuscation/techniques/detection.rs | 5 +- .../techniques/generic/delegates.rs | 8 +- .../techniques/generic/flattening.rs | 5 +- .../techniques/generic/opaquefields.rs | 6 +- .../techniques/jiejienet/arrays.rs | 15 +- .../techniques/jiejienet/constants.rs | 5 +- .../techniques/jiejienet/resources.rs | 6 +- .../techniques/jiejienet/strings.rs | 22 +- .../techniques/jiejienet/typeofs.rs | 5 +- dotscope/src/deobfuscation/techniques/mod.rs | 5 +- .../techniques/netreactor/antitamp.rs | 10 +- .../techniques/netreactor/resources.rs | 68 +- dotscope/src/emulation/filesystem.rs | 5 +- dotscope/src/emulation/memory/heap/mod.rs | 2 +- dotscope/src/emulation/process/builder.rs | 2 +- dotscope/src/emulation/runtime/appdomain.rs | 3 +- dotscope/src/emulation/runtime/bcl/runtime.rs | 1 + dotscope/src/emulation/value/emvalue.rs | 36 + dotscope/src/emulation/value/ops/binary.rs | 4 +- dotscope/src/error.rs | 16 +- dotscope/src/formatting/mod.rs | 5 +- dotscope/src/lib.rs | 16 +- dotscope/src/metadata/cilobject.rs | 6 +- dotscope/src/metadata/dependencies/graph.rs | 9 +- dotscope/src/metadata/loader/graph.rs | 17 +- dotscope/src/metadata/resolver.rs | 7 +- dotscope/src/metadata/typesystem/base.rs | 88 +- dotscope/src/metadata/vtfixup.rs | 14 +- dotscope/src/test/analysis/templates.rs | 14 +- dotscope/src/utils/bitset.rs | 384 -- dotscope/src/utils/graph/algorithms/cycles.rs | 418 --- .../src/utils/graph/algorithms/dominators.rs | 1095 ------ dotscope/src/utils/graph/algorithms/mod.rs | 99 - dotscope/src/utils/graph/algorithms/scc.rs | 664 ---- .../src/utils/graph/algorithms/topological.rs | 336 -- .../src/utils/graph/algorithms/traversal.rs | 748 ---- dotscope/src/utils/graph/directed.rs | 1381 -------- dotscope/src/utils/graph/edge.rs | 280 -- dotscope/src/utils/graph/indexed.rs | 420 --- dotscope/src/utils/graph/mod.rs | 122 - dotscope/src/utils/graph/node.rs | 273 -- dotscope/src/utils/graph/traits.rs | 335 -- dotscope/src/utils/mod.rs | 5 - dotscope/tests/bitmono.rs | 9 + dotscope/tests/common/compatibility.rs | 9 + dotscope/tests/common/framework.rs | 9 + dotscope/tests/common/verification.rs | 9 + dotscope/tests/confuserex.rs | 9 + dotscope/tests/crafted_1.rs | 3 + dotscope/tests/deobfuscation.rs | 9 + dotscope/tests/fuzzer.rs | 4 + dotscope/tests/jiejie.rs | 9 + dotscope/tests/modify_add.rs | 9 + dotscope/tests/modify_basic.rs | 9 + dotscope/tests/modify_heaps.rs | 9 + dotscope/tests/modify_impexp.rs | 9 + dotscope/tests/modify_roundtrips_crafted2.rs | 9 + dotscope/tests/modify_roundtrips_method.rs | 9 + dotscope/tests/modify_roundtrips_wbdll.rs | 9 + dotscope/tests/netreactor.rs | 9 + dotscope/tests/obfuscar.rs | 9 + dotscope/tests/roundtrip_asm.rs | 9 + dotscope/tests/ssa.rs | 13 +- 214 files changed, 4021 insertions(+), 53735 deletions(-) delete mode 100644 dotscope/src/analysis/algebraic.rs delete mode 100644 dotscope/src/analysis/cfg/analyzer.rs delete mode 100644 dotscope/src/analysis/cfg/loops.rs delete mode 100644 dotscope/src/analysis/dataflow/framework.rs delete mode 100644 dotscope/src/analysis/dataflow/lattice.rs delete mode 100644 dotscope/src/analysis/dataflow/liveness.rs delete mode 100644 dotscope/src/analysis/dataflow/reaching.rs delete mode 100644 dotscope/src/analysis/dataflow/sccp.rs delete mode 100644 dotscope/src/analysis/dataflow/solver.rs delete mode 100644 dotscope/src/analysis/defuse.rs delete mode 100644 dotscope/src/analysis/range.rs delete mode 100644 dotscope/src/analysis/ssa/block.rs delete mode 100644 dotscope/src/analysis/ssa/cfg.rs delete mode 100644 dotscope/src/analysis/ssa/constraints.rs delete mode 100644 dotscope/src/analysis/ssa/consts.rs delete mode 100644 dotscope/src/analysis/ssa/evaluator.rs create mode 100644 dotscope/src/analysis/ssa/function.rs delete mode 100644 dotscope/src/analysis/ssa/function/canonical.rs delete mode 100644 dotscope/src/analysis/ssa/function/duplication.rs delete mode 100644 dotscope/src/analysis/ssa/function/mod.rs delete mode 100644 dotscope/src/analysis/ssa/function/queries.rs delete mode 100644 dotscope/src/analysis/ssa/function/rebuild.rs delete mode 100644 dotscope/src/analysis/ssa/function/repair.rs delete mode 100644 dotscope/src/analysis/ssa/function/semantics.rs delete mode 100644 dotscope/src/analysis/ssa/function/transforms.rs delete mode 100644 dotscope/src/analysis/ssa/instruction.rs delete mode 100644 dotscope/src/analysis/ssa/liveness.rs delete mode 100644 dotscope/src/analysis/ssa/memory.rs delete mode 100644 dotscope/src/analysis/ssa/patterns.rs delete mode 100644 dotscope/src/analysis/ssa/phi.rs delete mode 100644 dotscope/src/analysis/ssa/phis.rs delete mode 100644 dotscope/src/analysis/ssa/symbolic/evaluator.rs delete mode 100644 dotscope/src/analysis/ssa/symbolic/expr.rs delete mode 100644 dotscope/src/analysis/ssa/symbolic/ops.rs create mode 100644 dotscope/src/analysis/ssa/target.rs delete mode 100644 dotscope/src/analysis/ssa/variable.rs delete mode 100644 dotscope/src/analysis/ssa/verifier.rs delete mode 100644 dotscope/src/compiler/events.rs create mode 100644 dotscope/src/compiler/host.rs delete mode 100644 dotscope/src/compiler/passes/algebraic.rs delete mode 100644 dotscope/src/compiler/passes/blockmerge.rs delete mode 100644 dotscope/src/compiler/passes/controlflow.rs delete mode 100644 dotscope/src/compiler/passes/gvn.rs delete mode 100644 dotscope/src/compiler/passes/licm.rs delete mode 100644 dotscope/src/compiler/passes/loopcanon.rs delete mode 100644 dotscope/src/compiler/passes/predicates.rs delete mode 100644 dotscope/src/compiler/passes/ranges.rs delete mode 100644 dotscope/src/compiler/passes/reassociate.rs delete mode 100644 dotscope/src/compiler/passes/threading.rs delete mode 100644 dotscope/src/compiler/passes/utils.rs delete mode 100644 dotscope/src/utils/bitset.rs delete mode 100644 dotscope/src/utils/graph/algorithms/cycles.rs delete mode 100644 dotscope/src/utils/graph/algorithms/dominators.rs delete mode 100644 dotscope/src/utils/graph/algorithms/mod.rs delete mode 100644 dotscope/src/utils/graph/algorithms/scc.rs delete mode 100644 dotscope/src/utils/graph/algorithms/topological.rs delete mode 100644 dotscope/src/utils/graph/algorithms/traversal.rs delete mode 100644 dotscope/src/utils/graph/directed.rs delete mode 100644 dotscope/src/utils/graph/edge.rs delete mode 100644 dotscope/src/utils/graph/indexed.rs delete mode 100644 dotscope/src/utils/graph/mod.rs delete mode 100644 dotscope/src/utils/graph/node.rs delete mode 100644 dotscope/src/utils/graph/traits.rs diff --git a/Cargo.lock b/Cargo.lock index 737b13df..82876d96 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -97,6 +97,19 @@ version = "0.2.21" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" +[[package]] +name = "analyssa" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dd57908e45c301501605d7a6cd7b6449f7d0020a95c57ec2c29180be8689c78d" +dependencies = [ + "boxcar", + "dashmap", + "log", + "rayon", + "thiserror 2.0.18", +] + [[package]] name = "android_system_properties" version = "0.1.5" @@ -1428,9 +1441,9 @@ dependencies = [ [[package]] name = "digest" -version = "0.11.2" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4850db49bf08e663084f7fb5c87d202ef91a3907271aff24a94eb97ff039153c" +checksum = "f1dd6dbb5841937940781866fa1281a1ff7bd3bf827091440879f9994983d5c2" dependencies = [ "block-buffer 0.12.0", "const-oid", @@ -1502,6 +1515,7 @@ name = "dotscope" version = "0.7.0" dependencies = [ "aes", + "analyssa", "boxcar", "cbc", "clap", @@ -2320,9 +2334,9 @@ dependencies = [ [[package]] name = "h2" -version = "0.4.13" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2f44da3a8150a6703ed5d34e164b875fd14c2cdab9af1252a9a1020bde2bdc54" +checksum = "171fefbc92fe4a4de27e0698d6a5b392d6a0e333506bc49133760b3bcf948733" dependencies = [ "atomic-waker", "bytes", @@ -2447,7 +2461,7 @@ version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6303bc9732ae41b04cb554b844a762b4115a61bfaa81e3e83050991eeb56863f" dependencies = [ - "digest 0.11.2", + "digest 0.11.3", ] [[package]] @@ -3269,7 +3283,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "69b6441f590336821bb897fb28fc622898ccceb1d6cea3fde5ea86b090c4de98" dependencies = [ "cfg-if", - "digest 0.11.2", + "digest 0.11.3", ] [[package]] @@ -3692,9 +3706,9 @@ dependencies = [ [[package]] name = "no_std_io2" -version = "0.9.3" +version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b51ed7824b6e07d354605f4abb3d9d300350701299da96642ee084f5ce631550" +checksum = "418abd1b6d34fbf6cae440dc874771b0525a604428704c76e48b29a5e67b8003" dependencies = [ "memchr", ] @@ -4097,7 +4111,7 @@ version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "112d82ceb8c5bf524d9af484d4e4970c9fd5a0cc15ba14ad93dccd28873b0629" dependencies = [ - "digest 0.11.2", + "digest 0.11.3", "hmac", ] @@ -4307,18 +4321,18 @@ dependencies = [ [[package]] name = "profiling" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3eb8486b569e12e2c32ad3e204dbaba5e4b5b216e9367044f25f1dba42341773" +checksum = "3d595e54a326bc53c1c197b32d295e14b169e3cfeaa8dc82b529f947fba6bcf5" dependencies = [ "profiling-procmacros", ] [[package]] name = "profiling-procmacros" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "52717f9a02b6965224f95ca2a81e2e0c5c43baacd28ca057577988930b6c3d5b" +checksum = "4488a4a36b9a4ba6b9334a32a39971f77c1436ec82c38707bce707699cc3bbcb" dependencies = [ "quote", "syn 2.0.117", @@ -4384,9 +4398,9 @@ checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" [[package]] name = "quick-xml" -version = "0.39.2" +version = "0.39.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "958f21e8e7ceb5a1aa7fa87fab28e7c75976e0bfe7e23ff069e0a260f894067d" +checksum = "721da970c312655cde9b4ffe0547f20a8494866a4af5ff51f18b7c633d0c870b" dependencies = [ "memchr", ] @@ -5347,9 +5361,9 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.18.0" +version = "3.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd5414fad8e6907dbdd5bc441a50ae8d6e26151a03b1de04d89a5576de61d01f" +checksum = "f05839ce67618e14a09b286535c0d9c94e85ef25469b0e13cb4f844e5593eb19" dependencies = [ "base64 0.22.1", "chrono", @@ -5366,9 +5380,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.18.0" +version = "3.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d3db8978e608f1fe7357e211969fd9abdcae80bac1ba7a3369bb7eb6b404eb65" +checksum = "cf2ebbe86054f9b45bc3881e865683ccfaccce97b9b4cb53f3039d67f355a334" dependencies = [ "darling 0.23.0", "proc-macro2", @@ -5404,7 +5418,7 @@ checksum = "aacc4cc499359472b4abe1bf11d0b12e688af9a805fa5e3016f9a386dc2d0214" dependencies = [ "cfg-if", "cpufeatures 0.3.0", - "digest 0.11.2", + "digest 0.11.3", ] [[package]] @@ -5426,7 +5440,7 @@ checksum = "446ba717509524cb3f22f17ecc096f10f4822d76ab5c0b9822c5f9c284e825f4" dependencies = [ "cfg-if", "cpufeatures 0.3.0", - "digest 0.11.2", + "digest 0.11.3", ] [[package]] @@ -5521,9 +5535,9 @@ checksum = "e3a9fe34e3e7a50316060351f37187a3f546bce95496156754b601a5fa71b76e" [[package]] name = "siphasher" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b2aa850e253778c88a04c3d7323b043aeda9d3e30d5971937c1855769763678e" +checksum = "8ee5873ec9cce0195efcb7a4e9507a04cd49aec9c83d0389df45b1ef7ba2e649" [[package]] name = "slab" @@ -6177,9 +6191,9 @@ dependencies = [ [[package]] name = "tokio" -version = "1.52.1" +version = "1.52.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b67dee974fe86fd92cc45b7a95fdd2f99a36a6d7b0d431a231178d3d670bbcc6" +checksum = "110a78583f19d5cdb2c5ccf321d1290344e71313c6c37d43520d386027d18386" dependencies = [ "bytes", "libc", @@ -6637,9 +6651,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" [[package]] name = "utoipa" -version = "5.4.0" +version = "5.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fcc29c80c21c31608227e0912b2d7fddba57ad76b606890627ba8ee7964e993" +checksum = "8bde15df68e80b16c7d16b9616e80770ad158988daa56a27dccd1e55558b0160" dependencies = [ "indexmap 2.14.0", "serde", @@ -6649,9 +6663,9 @@ dependencies = [ [[package]] name = "utoipa-gen" -version = "5.4.0" +version = "5.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d79d08d92ab8af4c5e8a6da20c47ae3f61a0f1dabc1997cdf2d082b757ca08b" +checksum = "6ba0b99ee52df3028635d93840c797102da61f8a7bb3cf751032455895b52ef8" dependencies = [ "proc-macro2", "quote", diff --git a/dotscope/Cargo.toml b/dotscope/Cargo.toml index 33ffe079..fac1f8bc 100644 --- a/dotscope/Cargo.toml +++ b/dotscope/Cargo.toml @@ -25,6 +25,20 @@ exclude = [ all-features = true rustdoc-args = ["--cfg", "docsrs"] +# dotscope is used in malware-analysis pipelines where every input byte +# is adversarial. These lints prevent the parser from panicking on +# malformed input. The `cfg_attr(test, allow(...))` escape hatch in +# `src/lib.rs` lets tests still use `unwrap`/`expect`/`panic` for terseness. +[lints.rust] +missing_docs = "deny" + +[lints.clippy] +unwrap_used = "deny" +expect_used = "deny" +panic = "deny" +arithmetic_side_effects = "deny" +indexing_slicing = "deny" + [dependencies] thiserror = "2.0.18" uguid = "2.2.1" @@ -55,6 +69,7 @@ hex = "0.4.3" num-bigint = { version = "0.4.6", optional = true } log = "0.4.29" flate2 = "1.1.9" +analyssa = "0.1.0" lzma-rs = "0.3.0" z3 = { version = "0.20.0", optional = true } iced-x86 = { version = "1.21.0", default-features = false, features = ["std", "decoder", "instr_info"], optional = true } diff --git a/dotscope/benches/assembly.rs b/dotscope/benches/assembly.rs index ce9f4149..bc415802 100644 --- a/dotscope/benches/assembly.rs +++ b/dotscope/benches/assembly.rs @@ -1,71 +1,157 @@ -#![allow(unused)] +//! Benchmarks for CIL assembly and disassembly. +//! +//! Exercises both the fluent [`InstructionAssembler`] API and the lower-level +//! [`InstructionEncoder`] / [`decode_stream`] paths, plus assemble-disassemble +//! roundtrips. + extern crate dotscope; use criterion::{criterion_group, criterion_main, Criterion}; -use dotscope::assembly::{ - decode_instruction, decode_stream, InstructionAssembler, InstructionEncoder, -}; +use dotscope::assembly::{decode_stream, InstructionAssembler, InstructionEncoder}; use dotscope::metadata::token::Token; -use dotscope::Parser; +use dotscope::Result; +use std::hint::black_box; + +fn assemble_simple() -> Result<( + Vec, + u16, + Vec, +)> { + let mut asm = InstructionAssembler::new(); + asm.ldarg_1()?.ldarg_2()?.add()?.ret()?; + asm.finish() +} + +fn assemble_complex() -> Result<( + Vec, + u16, + Vec, +)> { + let mut asm = InstructionAssembler::new(); + asm.ldc_i4_0()? + .stloc_0()? + .br("loop_condition")? + .label("loop_start")? + .ldarg_0()? + .ldloc_0()? + .ldarg_1()? + .stelem_i4()? + .ldloc_0()? + .ldc_i4_1()? + .add()? + .stloc_0()? + .label("loop_condition")? + .ldloc_0()? + .ldc_i4_const(10)? + .clt()? + .brtrue("loop_start")? + .ret()?; + asm.finish() +} + +fn assemble_object( + field_token: Token, + method_token: Token, +) -> Result<( + Vec, + u16, + Vec, +)> { + let mut asm = InstructionAssembler::new(); + asm.ldarg_0()? + .ldfld(field_token)? + .ldnull()? + .ceq()? + .brfalse("not_null")? + .ldarg_0()? + .newobj(method_token)? + .stfld(field_token)? + .label("not_null")? + .ldarg_0()? + .ldfld(field_token)? + .callvirt(method_token)? + .ret()?; + asm.finish() +} +fn assemble_with_optimizations() -> Result<( + Vec, + u16, + Vec, +)> { + let mut asm = InstructionAssembler::new(); + asm.ldc_i4_const(0)? + .ldc_i4_const(1)? + .ldc_i4_const(127)? + .ldc_i4_const(1000)? + .add()? + .add()? + .add()? + .ret()?; + asm.finish() +} + +fn assemble_manual_selection() -> Result<( + Vec, + u16, + Vec, +)> { + let mut asm = InstructionAssembler::new(); + asm.ldc_i4_0()? + .ldc_i4_1()? + .ldc_i4_s(127)? + .ldc_i4(1000)? + .add()? + .add()? + .add()? + .ret()?; + asm.finish() +} + +fn assemble_large_method() -> Result<( + Vec, + u16, + Vec, +)> { + let mut asm = InstructionAssembler::new(); + for i in 0i32..50 { + asm.ldarg_0()? + .ldc_i4_const(i)? + .ceq()? + .brtrue(&format!("case_{i}"))?; + } + asm.ldc_i4_m1()?.ret()?; + for i in 0i32..50 { + asm.label(&format!("case_{i}"))? + .ldc_i4_const(i.saturating_mul(2))? + .ret()?; + } + asm.finish() +} + +fn encode_simple_direct() -> Result<(Vec, u16, std::collections::HashMap)> { + let mut encoder = InstructionEncoder::new(); + encoder.emit_instruction("ldarg.1", None)?; + encoder.emit_instruction("ldarg.2", None)?; + encoder.emit_instruction("add", None)?; + encoder.emit_instruction("ret", None)?; + encoder.finalize() +} + +/// Benchmark CIL assembly and disassembly across simple, complex, and large +/// method shapes plus assemble-disassemble roundtrips. pub fn criterion_benchmark(c: &mut Criterion) { // Simple method: basic arithmetic c.bench_function("bench_assemble_simple_method", |b| { b.iter(|| { - let mut asm = InstructionAssembler::new(); - asm.ldarg_1() - .unwrap() - .ldarg_2() - .unwrap() - .add() - .unwrap() - .ret() - .unwrap(); - asm.finish().unwrap() + black_box(assemble_simple().ok()); }); }); // Complex method: loops, branches, and object operations c.bench_function("bench_assemble_complex_method", |b| { b.iter(|| { - let mut asm = InstructionAssembler::new(); - asm.ldc_i4_0() - .unwrap() // int i = 0 - .stloc_0() - .unwrap() - .br("loop_condition") - .unwrap() - .label("loop_start") - .unwrap() - .ldarg_0() - .unwrap() // Load array - .ldloc_0() - .unwrap() // Load index - .ldarg_1() - .unwrap() // Load value - .stelem_i4() - .unwrap() // array[i] = value - .ldloc_0() - .unwrap() // i++ - .ldc_i4_1() - .unwrap() - .add() - .unwrap() - .stloc_0() - .unwrap() - .label("loop_condition") - .unwrap() - .ldloc_0() - .unwrap() // if (i < 10) - .ldc_i4_const(10) - .unwrap() - .clt() - .unwrap() - .brtrue("loop_start") - .unwrap() - .ret() - .unwrap(); - asm.finish().unwrap() + black_box(assemble_complex().ok()); }); }); @@ -73,172 +159,41 @@ pub fn criterion_benchmark(c: &mut Criterion) { c.bench_function("bench_assemble_object_method", |b| { let field_token = Token::new(0x04000001); let method_token = Token::new(0x06000001); - let type_token = Token::new(0x02000001); + let _type_token = Token::new(0x02000001); b.iter(|| { - let mut asm = InstructionAssembler::new(); - asm.ldarg_0() - .unwrap() // this - .ldfld(field_token) - .unwrap() // Load field - .ldnull() - .unwrap() // Compare with null - .ceq() - .unwrap() - .brfalse("not_null") - .unwrap() - .ldarg_0() - .unwrap() // Create new object - .newobj(method_token) - .unwrap() - .stfld(field_token) - .unwrap() - .label("not_null") - .unwrap() - .ldarg_0() - .unwrap() // Return field value - .ldfld(field_token) - .unwrap() - .callvirt(method_token) - .unwrap() - .ret() - .unwrap(); - asm.finish().unwrap() + black_box(assemble_object(field_token, method_token).ok()); }); }); // Low-level encoder benchmark c.bench_function("bench_assemble_encoder_direct", |b| { b.iter(|| { - let mut encoder = InstructionEncoder::new(); - encoder.emit_instruction("ldarg.1", None).unwrap(); - encoder.emit_instruction("ldarg.2", None).unwrap(); - encoder.emit_instruction("add", None).unwrap(); - encoder.emit_instruction("ret", None).unwrap(); - encoder.finalize().unwrap().0 + black_box(encode_simple_direct().ok()); }); }); // Roundtrip benchmark: assemble then disassemble - let simple_bytecode = { - let mut asm = InstructionAssembler::new(); - asm.ldarg_1() - .unwrap() - .ldarg_2() - .unwrap() - .add() - .unwrap() - .ret() - .unwrap(); - asm.finish().unwrap().0 - }; + let simple_bytecode = assemble_simple().map(|(b, _, _)| b).unwrap_or_default(); + let complex_bytecode = assemble_complex().map(|(b, _, _)| b).unwrap_or_default(); c.bench_function("bench_roundtrip_simple", |b| { b.iter(|| { - // Assemble - let mut asm = InstructionAssembler::new(); - asm.ldarg_1() - .unwrap() - .ldarg_2() - .unwrap() - .add() - .unwrap() - .ret() - .unwrap(); - let (bytecode, _max_stack, _) = asm.finish().unwrap(); - - // Disassemble - let mut parser = dotscope::Parser::new(&bytecode); - decode_stream(&mut parser, 0x1000).unwrap() + let _: Result<_> = (|| { + let (bytecode, _max_stack, _) = assemble_simple()?; + let mut parser = dotscope::Parser::new(&bytecode); + decode_stream(&mut parser, 0x1000) + })(); }); }); - let complex_bytecode = { - let mut asm = InstructionAssembler::new(); - asm.ldc_i4_0() - .unwrap() - .stloc_0() - .unwrap() - .br("loop_condition") - .unwrap() - .label("loop_start") - .unwrap() - .ldarg_0() - .unwrap() - .ldloc_0() - .unwrap() - .ldarg_1() - .unwrap() - .stelem_i4() - .unwrap() - .ldloc_0() - .unwrap() - .ldc_i4_1() - .unwrap() - .add() - .unwrap() - .stloc_0() - .unwrap() - .label("loop_condition") - .unwrap() - .ldloc_0() - .unwrap() - .ldc_i4_const(10) - .unwrap() - .clt() - .unwrap() - .brtrue("loop_start") - .unwrap() - .ret() - .unwrap(); - asm.finish().unwrap().0 - }; - c.bench_function("bench_roundtrip_complex", |b| { b.iter(|| { - // Assemble - let mut asm = InstructionAssembler::new(); - asm.ldc_i4_0() - .unwrap() - .stloc_0() - .unwrap() - .br("loop_condition") - .unwrap() - .label("loop_start") - .unwrap() - .ldarg_0() - .unwrap() - .ldloc_0() - .unwrap() - .ldarg_1() - .unwrap() - .stelem_i4() - .unwrap() - .ldloc_0() - .unwrap() - .ldc_i4_1() - .unwrap() - .add() - .unwrap() - .stloc_0() - .unwrap() - .label("loop_condition") - .unwrap() - .ldloc_0() - .unwrap() - .ldc_i4_const(10) - .unwrap() - .clt() - .unwrap() - .brtrue("loop_start") - .unwrap() - .ret() - .unwrap(); - let (bytecode, _max_stack, _) = asm.finish().unwrap(); - - // Disassemble - let mut parser = dotscope::Parser::new(&bytecode); - decode_stream(&mut parser, 0x1000).unwrap() + let _: Result<_> = (|| { + let (bytecode, _max_stack, _) = assemble_complex()?; + let mut parser = dotscope::Parser::new(&bytecode); + decode_stream(&mut parser, 0x1000) + })(); }); }); @@ -246,93 +201,34 @@ pub fn criterion_benchmark(c: &mut Criterion) { c.bench_function("bench_disassemble_simple", |b| { b.iter(|| { let mut parser = dotscope::Parser::new(&simple_bytecode); - decode_stream(&mut parser, 0x1000).unwrap() + black_box(decode_stream(&mut parser, 0x1000).ok()); }); }); c.bench_function("bench_disassemble_complex", |b| { b.iter(|| { let mut parser = dotscope::Parser::new(&complex_bytecode); - decode_stream(&mut parser, 0x1000).unwrap() + black_box(decode_stream(&mut parser, 0x1000).ok()); }); }); // Optimization benchmark: compare ldc_i4_const vs manual selection c.bench_function("bench_assemble_with_optimizations", |b| { b.iter(|| { - let mut asm = InstructionAssembler::new(); - asm.ldc_i4_const(0) - .unwrap() // Should use ldc.i4.0 - .ldc_i4_const(1) - .unwrap() // Should use ldc.i4.1 - .ldc_i4_const(127) - .unwrap() // Should use ldc.i4.s - .ldc_i4_const(1000) - .unwrap() // Should use ldc.i4 - .add() - .unwrap() - .add() - .unwrap() - .add() - .unwrap() - .ret() - .unwrap(); - asm.finish().unwrap() + black_box(assemble_with_optimizations().ok()); }); }); c.bench_function("bench_assemble_manual_selection", |b| { b.iter(|| { - let mut asm = InstructionAssembler::new(); - asm.ldc_i4_0() - .unwrap() - .ldc_i4_1() - .unwrap() - .ldc_i4_s(127) - .unwrap() - .ldc_i4(1000) - .unwrap() - .add() - .unwrap() - .add() - .unwrap() - .add() - .unwrap() - .ret() - .unwrap(); - asm.finish().unwrap() + black_box(assemble_manual_selection().ok()); }); }); // Memory-intensive benchmark: large method with many labels c.bench_function("bench_assemble_large_method", |b| { b.iter(|| { - let mut asm = InstructionAssembler::new(); - - // Create a method with many branches and labels - for i in 0..50 { - asm.ldarg_0() - .unwrap() - .ldc_i4_const(i) - .unwrap() - .ceq() - .unwrap() - .brtrue(&format!("case_{i}")) - .unwrap(); - } - - asm.ldc_i4_m1().unwrap().ret().unwrap(); - - for i in 0..50 { - asm.label(&format!("case_{i}")) - .unwrap() - .ldc_i4_const(i * 2) - .unwrap() - .ret() - .unwrap(); - } - - asm.finish().unwrap() + black_box(assemble_large_method().ok()); }); }); } diff --git a/dotscope/benches/cilassemblyview.rs b/dotscope/benches/cilassemblyview.rs index 95579624..a25e52c0 100644 --- a/dotscope/benches/cilassemblyview.rs +++ b/dotscope/benches/cilassemblyview.rs @@ -1,28 +1,23 @@ -#![allow(unused)] +//! Benchmarks for [`CilAssemblyView`] loading with and without metadata validation. + extern crate dotscope; use criterion::{criterion_group, criterion_main, Criterion}; use dotscope::{CilAssemblyView, ValidationConfig}; -use std::path::{Path, PathBuf}; +use std::path::PathBuf; +/// Benchmark loading a `CilAssemblyView` with and without metadata validation. pub fn criterion_benchmark(c: &mut Criterion) { - // // Set rayon to use only 1 thread for this benchmark to profile - // rayon::ThreadPoolBuilder::new() - // .num_threads(1) - // .build_global() - // .unwrap(); - let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WindowsBase.dll"); c.bench_function("bench_cilassemblyview", |b| { - b.iter({ || CilAssemblyView::from_path(&path).unwrap() }); + b.iter(|| { + let _ = CilAssemblyView::from_path(&path); + }); }); c.bench_function("bench_cilassemblyview_validation", |b| { - b.iter({ - || { - CilAssemblyView::from_path_with_validation(&path, ValidationConfig::strict()) - .unwrap() - } + b.iter(|| { + let _ = CilAssemblyView::from_path_with_validation(&path, ValidationConfig::strict()); }); }); } diff --git a/dotscope/benches/cilobject.rs b/dotscope/benches/cilobject.rs index 19db51e4..3aae6236 100644 --- a/dotscope/benches/cilobject.rs +++ b/dotscope/benches/cilobject.rs @@ -1,28 +1,24 @@ -#![allow(unused)] +//! Benchmarks for [`CilObject`] loading with and without metadata validation. + extern crate dotscope; use criterion::{criterion_group, criterion_main, Criterion}; use dotscope::{metadata::cilobject::CilObject, ValidationConfig}; -use std::path::{Path, PathBuf}; +use std::path::PathBuf; +/// Benchmark loading a `CilObject` with and without metadata validation. pub fn criterion_benchmark(c: &mut Criterion) { - // // Set rayon to use only 1 thread for this benchmark to profile - // rayon::ThreadPoolBuilder::new() - // .num_threads(1) - // .build_global() - // .unwrap(); - let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/mono_4.8/mscorlib.dll"); c.bench_function("bench_cilobject", |b| { - b.iter({ - || CilObject::from_path_with_validation(&path, ValidationConfig::disabled()).unwrap() + b.iter(|| { + let _ = CilObject::from_path_with_validation(&path, ValidationConfig::disabled()); }); }); c.bench_function("bench_cilobject_validation", |b| { - b.iter({ - || CilObject::from_path_with_validation(&path, ValidationConfig::strict()).unwrap() + b.iter(|| { + let _ = CilObject::from_path_with_validation(&path, ValidationConfig::strict()); }); }); } diff --git a/dotscope/benches/cor20.rs b/dotscope/benches/cor20.rs index ad0a030c..338f1c2f 100644 --- a/dotscope/benches/cor20.rs +++ b/dotscope/benches/cor20.rs @@ -17,17 +17,25 @@ use std::{fs, hint::black_box, path::PathBuf}; fn bench_cor20_header_parse(c: &mut Criterion) { let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WB_COR20_HEADER.bin"); - let data = fs::read(&path).expect("Failed to read COR20 header file"); + let Ok(data) = fs::read(&path) else { + eprintln!( + "Skipping cor20 benchmark: failed to read {}", + path.display() + ); + return; + }; let file_size = data.len(); - assert_eq!(file_size, 72, "COR20 header must be exactly 72 bytes"); + if file_size != 72 { + eprintln!("COR20 header must be exactly 72 bytes, got {file_size}; skipping"); + return; + } let mut group = c.benchmark_group("cor20_header"); group.throughput(Throughput::Bytes(file_size as u64)); group.bench_function("parse", |b| { b.iter(|| { - let header = Cor20Header::read(black_box(&data)).unwrap(); - black_box(header) + black_box(Cor20Header::read(black_box(&data)).ok()); }); }); group.finish(); diff --git a/dotscope/benches/method_body.rs b/dotscope/benches/method_body.rs index f12a6abb..a182c25f 100644 --- a/dotscope/benches/method_body.rs +++ b/dotscope/benches/method_body.rs @@ -11,130 +11,91 @@ use criterion::{criterion_group, criterion_main, Criterion, Throughput}; use dotscope::metadata::method::MethodBody; use std::{fs, hint::black_box, path::PathBuf}; -/// Benchmark parsing a tiny method header. -/// -/// Tiny headers are 1 byte and can represent methods up to 63 bytes of IL code. -/// This is the fastest path for simple methods. -fn bench_parse_method_tiny(c: &mut Criterion) { - let path = - PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WB_METHOD_TINY_0600032D.bin"); - - let data = fs::read(&path).expect("Failed to read tiny method file"); +/// Run a method-body parsing benchmark over a sample file. Skips with a +/// diagnostic message when the file cannot be read. +fn bench_method_file(c: &mut Criterion, group_name: &str, sample: &str) { + let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(sample); + + let Ok(data) = fs::read(&path) else { + eprintln!("Skipping {group_name}: failed to read {}", path.display()); + return; + }; let file_size = data.len(); - let mut group = c.benchmark_group("method_body_tiny"); + let mut group = c.benchmark_group(group_name); group.throughput(Throughput::Bytes(file_size as u64)); group.bench_function("parse", |b| { b.iter(|| { - let body = MethodBody::from(black_box(&data)).unwrap(); - black_box(body) + black_box(MethodBody::from(black_box(&data)).ok()); }); }); group.finish(); } +/// Benchmark parsing a tiny method header. +/// +/// Tiny headers are 1 byte and can represent methods up to 63 bytes of IL code. +/// This is the fastest path for simple methods. +fn bench_parse_method_tiny(c: &mut Criterion) { + bench_method_file( + c, + "method_body_tiny", + "tests/samples/WB_METHOD_TINY_0600032D.bin", + ); +} + /// Benchmark parsing a fat method header. /// /// Fat headers are 12+ bytes and support complex methods with local variables, /// exception handlers, and large code sizes. fn bench_parse_method_fat(c: &mut Criterion) { - let path = - PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WB_METHOD_FAT_0600033E.bin"); - - let data = fs::read(&path).expect("Failed to read fat method file"); - let file_size = data.len(); - - let mut group = c.benchmark_group("method_body_fat"); - group.throughput(Throughput::Bytes(file_size as u64)); - group.bench_function("parse", |b| { - b.iter(|| { - let body = MethodBody::from(black_box(&data)).unwrap(); - black_box(body) - }); - }); - group.finish(); + bench_method_file( + c, + "method_body_fat", + "tests/samples/WB_METHOD_FAT_0600033E.bin", + ); } /// Benchmark parsing a method with a single exception handler. /// /// Tests the overhead of parsing exception handling sections. fn bench_parse_method_with_exception(c: &mut Criterion) { - let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) - .join("tests/samples/WB_METHOD_FAT_EXCEPTION_06000341.bin"); - - let data = fs::read(&path).expect("Failed to read method with exception file"); - let file_size = data.len(); - - let mut group = c.benchmark_group("method_body_exception_single"); - group.throughput(Throughput::Bytes(file_size as u64)); - group.bench_function("parse", |b| { - b.iter(|| { - let body = MethodBody::from(black_box(&data)).unwrap(); - black_box(body) - }); - }); - group.finish(); + bench_method_file( + c, + "method_body_exception_single", + "tests/samples/WB_METHOD_FAT_EXCEPTION_06000341.bin", + ); } /// Benchmark parsing a method with local variables and exception handlers. /// /// Tests a more realistic complex method scenario. fn bench_parse_method_with_locals_and_exception(c: &mut Criterion) { - let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) - .join("tests/samples/WB_METHOD_FAT_EXCEPTION_N1_2LOCALS_060001AA.bin"); - - let data = fs::read(&path).expect("Failed to read method with locals file"); - let file_size = data.len(); - - let mut group = c.benchmark_group("method_body_with_locals"); - group.throughput(Throughput::Bytes(file_size as u64)); - group.bench_function("parse", |b| { - b.iter(|| { - let body = MethodBody::from(black_box(&data)).unwrap(); - black_box(body) - }); - }); - group.finish(); + bench_method_file( + c, + "method_body_with_locals", + "tests/samples/WB_METHOD_FAT_EXCEPTION_N1_2LOCALS_060001AA.bin", + ); } /// Benchmark parsing a method with multiple exception handlers. /// /// Tests parsing of complex exception handling with multiple try/catch/finally blocks. fn bench_parse_method_multiple_exceptions(c: &mut Criterion) { - let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) - .join("tests/samples/WB_METHOD_FAT_EXCEPTION_N2_06000421.bin"); - - let data = fs::read(&path).expect("Failed to read method with multiple exceptions file"); - let file_size = data.len(); - - let mut group = c.benchmark_group("method_body_exception_multiple"); - group.throughput(Throughput::Bytes(file_size as u64)); - group.bench_function("parse", |b| { - b.iter(|| { - let body = MethodBody::from(black_box(&data)).unwrap(); - black_box(body) - }); - }); - group.finish(); + bench_method_file( + c, + "method_body_exception_multiple", + "tests/samples/WB_METHOD_FAT_EXCEPTION_N2_06000421.bin", + ); } /// Benchmark parsing another complex method with nested exception handlers. fn bench_parse_method_complex_exceptions(c: &mut Criterion) { - let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) - .join("tests/samples/WB_METHOD_FAT_EXCEPTION_N2_06000D54.bin"); - - let data = fs::read(&path).expect("Failed to read complex exception method file"); - let file_size = data.len(); - - let mut group = c.benchmark_group("method_body_exception_complex"); - group.throughput(Throughput::Bytes(file_size as u64)); - group.bench_function("parse", |b| { - b.iter(|| { - let body = MethodBody::from(black_box(&data)).unwrap(); - black_box(body) - }); - }); - group.finish(); + bench_method_file( + c, + "method_body_exception_complex", + "tests/samples/WB_METHOD_FAT_EXCEPTION_N2_06000D54.bin", + ); } criterion_group!( diff --git a/dotscope/benches/resources.rs b/dotscope/benches/resources.rs index c8bfd6a8..e8c57987 100644 --- a/dotscope/benches/resources.rs +++ b/dotscope/benches/resources.rs @@ -1,4 +1,5 @@ -#![allow(unused)] +//! Benchmarks for owned vs. zero-copy parsing of standalone `.resources` files. + extern crate dotscope; use criterion::{criterion_group, criterion_main, Criterion, Throughput}; @@ -13,8 +14,13 @@ fn bench_parse_resources_file(c: &mut Criterion) { let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) .join("tests/samples/WB_FxResources.WindowsBase.SR.resources.bin"); - // Load the standalone .resources file - let data = fs::read(&path).expect("Failed to read resources file"); + let Ok(data) = fs::read(&path) else { + eprintln!( + "Skipping resources benchmark: failed to read {}", + path.display() + ); + return; + }; let file_size = data.len(); println!( @@ -28,8 +34,8 @@ fn bench_parse_resources_file(c: &mut Criterion) { group.throughput(Throughput::Bytes(file_size as u64)); group.bench_function("parse_dotnet_resource", |b| { b.iter(|| { - let parsed = parse_dotnet_resource(black_box(&data)).unwrap(); - black_box(parsed) + let parsed = parse_dotnet_resource(black_box(&data)); + black_box(parsed.ok()); }); }); group.finish(); @@ -39,8 +45,8 @@ fn bench_parse_resources_file(c: &mut Criterion) { group.throughput(Throughput::Bytes(file_size as u64)); group.bench_function("parse_dotnet_resource_ref", |b| { b.iter(|| { - let parsed = parse_dotnet_resource_ref(black_box(&data)).unwrap(); - black_box(parsed) + let parsed = parse_dotnet_resource_ref(black_box(&data)); + black_box(parsed.ok()); }); }); group.finish(); diff --git a/dotscope/benches/security.rs b/dotscope/benches/security.rs index 9531f6da..c5831d27 100644 --- a/dotscope/benches/security.rs +++ b/dotscope/benches/security.rs @@ -17,15 +17,20 @@ fn bench_permission_set_parse(c: &mut Criterion) { let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WB_DeclSecurity_1.bin"); - let data = fs::read(&path).expect("Failed to read security declaration file"); + let Ok(data) = fs::read(&path) else { + eprintln!( + "Skipping security benchmark: failed to read {}", + path.display() + ); + return; + }; let file_size = data.len(); let mut group = c.benchmark_group("permission_set"); group.throughput(Throughput::Bytes(file_size as u64)); group.bench_function("parse_binary", |b| { b.iter(|| { - let perm_set = PermissionSet::new(black_box(&data)).unwrap(); - black_box(perm_set) + black_box(PermissionSet::new(black_box(&data)).ok()); }); }); group.finish(); @@ -49,8 +54,7 @@ fn bench_permission_set_minimal(c: &mut Criterion) { c.bench_function("permission_set_minimal", |b| { b.iter(|| { - let perm_set = PermissionSet::new(black_box(&data)).unwrap(); - black_box(perm_set) + black_box(PermissionSet::new(black_box(&data)).ok()); }); }); } @@ -64,8 +68,7 @@ fn bench_permission_set_xml_minimal(c: &mut Criterion) { c.bench_function("permission_set_xml_minimal", |b| { b.iter(|| { - let perm_set = PermissionSet::new(black_box(data)).unwrap(); - black_box(perm_set) + black_box(PermissionSet::new(black_box(data)).ok()); }); }); } @@ -78,8 +81,7 @@ fn bench_permission_set_xml_with_permission(c: &mut Criterion) { c.bench_function("permission_set_xml_with_permission", |b| { b.iter(|| { - let perm_set = PermissionSet::new(black_box(data)).unwrap(); - black_box(perm_set) + black_box(PermissionSet::new(black_box(data)).ok()); }); }); } @@ -89,15 +91,20 @@ fn bench_permission_set_repeated(c: &mut Criterion) { let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WB_DeclSecurity_1.bin"); - let data = fs::read(&path).expect("Failed to read security declaration file"); + let Ok(data) = fs::read(&path) else { + eprintln!( + "Skipping security benchmark: failed to read {}", + path.display() + ); + return; + }; let mut group = c.benchmark_group("permission_set"); group.throughput(Throughput::Elements(100)); group.bench_function("parse_100x", |b| { b.iter(|| { for _ in 0..100 { - let perm_set = PermissionSet::new(black_box(&data)).unwrap(); - black_box(perm_set); + black_box(PermissionSet::new(black_box(&data)).ok()); } }); }); diff --git a/dotscope/benches/signatures.rs b/dotscope/benches/signatures.rs index 469713d0..710d0383 100644 --- a/dotscope/benches/signatures.rs +++ b/dotscope/benches/signatures.rs @@ -25,8 +25,7 @@ fn bench_method_signature_void_no_params(c: &mut Criterion) { c.bench_function("sig_method_void_no_params", |b| { b.iter(|| { - let sig = parse_method_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_method_signature(black_box(&signature)).ok()); }); }); } @@ -39,8 +38,7 @@ fn bench_method_signature_primitives(c: &mut Criterion) { c.bench_function("sig_method_primitives", |b| { b.iter(|| { - let sig = parse_method_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_method_signature(black_box(&signature)).ok()); }); }); } @@ -53,8 +51,7 @@ fn bench_method_signature_instance(c: &mut Criterion) { c.bench_function("sig_method_instance", |b| { b.iter(|| { - let sig = parse_method_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_method_signature(black_box(&signature)).ok()); }); }); } @@ -67,8 +64,7 @@ fn bench_method_signature_generic(c: &mut Criterion) { c.bench_function("sig_method_generic", |b| { b.iter(|| { - let sig = parse_method_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_method_signature(black_box(&signature)).ok()); }); }); } @@ -81,8 +77,7 @@ fn bench_method_signature_multi_generic(c: &mut Criterion) { c.bench_function("sig_method_multi_generic", |b| { b.iter(|| { - let sig = parse_method_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_method_signature(black_box(&signature)).ok()); }); }); } @@ -95,8 +90,7 @@ fn bench_method_signature_byref(c: &mut Criterion) { c.bench_function("sig_method_byref", |b| { b.iter(|| { - let sig = parse_method_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_method_signature(black_box(&signature)).ok()); }); }); } @@ -109,8 +103,7 @@ fn bench_method_signature_array_return(c: &mut Criterion) { c.bench_function("sig_method_array_return", |b| { b.iter(|| { - let sig = parse_method_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_method_signature(black_box(&signature)).ok()); }); }); } @@ -125,8 +118,7 @@ fn bench_method_signature_many_params(c: &mut Criterion) { c.bench_function("sig_method_many_params", |b| { b.iter(|| { - let sig = parse_method_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_method_signature(black_box(&signature)).ok()); }); }); } @@ -139,8 +131,7 @@ fn bench_field_signature_primitive(c: &mut Criterion) { c.bench_function("sig_field_primitive", |b| { b.iter(|| { - let sig = parse_field_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_field_signature(black_box(&signature)).ok()); }); }); } @@ -153,8 +144,7 @@ fn bench_field_signature_string(c: &mut Criterion) { c.bench_function("sig_field_string", |b| { b.iter(|| { - let sig = parse_field_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_field_signature(black_box(&signature)).ok()); }); }); } @@ -167,8 +157,7 @@ fn bench_field_signature_array(c: &mut Criterion) { c.bench_function("sig_field_array", |b| { b.iter(|| { - let sig = parse_field_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_field_signature(black_box(&signature)).ok()); }); }); } @@ -181,8 +170,7 @@ fn bench_field_signature_generic_param(c: &mut Criterion) { c.bench_function("sig_field_generic_param", |b| { b.iter(|| { - let sig = parse_field_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_field_signature(black_box(&signature)).ok()); }); }); } @@ -195,8 +183,7 @@ fn bench_field_signature_class(c: &mut Criterion) { c.bench_function("sig_field_class", |b| { b.iter(|| { - let sig = parse_field_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_field_signature(black_box(&signature)).ok()); }); }); } @@ -209,8 +196,7 @@ fn bench_property_signature_simple(c: &mut Criterion) { c.bench_function("sig_property_simple", |b| { b.iter(|| { - let sig = parse_property_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_property_signature(black_box(&signature)).ok()); }); }); } @@ -223,8 +209,7 @@ fn bench_property_signature_indexer(c: &mut Criterion) { c.bench_function("sig_property_indexer", |b| { b.iter(|| { - let sig = parse_property_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_property_signature(black_box(&signature)).ok()); }); }); } @@ -237,8 +222,7 @@ fn bench_property_signature_static(c: &mut Criterion) { c.bench_function("sig_property_static", |b| { b.iter(|| { - let sig = parse_property_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_property_signature(black_box(&signature)).ok()); }); }); } @@ -251,8 +235,7 @@ fn bench_local_var_signature_single(c: &mut Criterion) { c.bench_function("sig_localvar_single", |b| { b.iter(|| { - let sig = parse_local_var_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_local_var_signature(black_box(&signature)).ok()); }); }); } @@ -265,8 +248,7 @@ fn bench_local_var_signature_multiple(c: &mut Criterion) { c.bench_function("sig_localvar_multiple", |b| { b.iter(|| { - let sig = parse_local_var_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_local_var_signature(black_box(&signature)).ok()); }); }); } @@ -279,8 +261,7 @@ fn bench_local_var_signature_byref(c: &mut Criterion) { c.bench_function("sig_localvar_byref", |b| { b.iter(|| { - let sig = parse_local_var_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_local_var_signature(black_box(&signature)).ok()); }); }); } @@ -293,8 +274,7 @@ fn bench_local_var_signature_pinned(c: &mut Criterion) { c.bench_function("sig_localvar_pinned", |b| { b.iter(|| { - let sig = parse_local_var_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_local_var_signature(black_box(&signature)).ok()); }); }); } @@ -319,8 +299,7 @@ fn bench_local_var_signature_many(c: &mut Criterion) { c.bench_function("sig_localvar_many", |b| { b.iter(|| { - let sig = parse_local_var_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_local_var_signature(black_box(&signature)).ok()); }); }); } @@ -333,8 +312,7 @@ fn bench_type_spec_generic_simple(c: &mut Criterion) { c.bench_function("sig_typespec_generic_simple", |b| { b.iter(|| { - let sig = parse_type_spec_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_type_spec_signature(black_box(&signature)).ok()); }); }); } @@ -347,8 +325,7 @@ fn bench_type_spec_generic_multi_arg(c: &mut Criterion) { c.bench_function("sig_typespec_generic_multi_arg", |b| { b.iter(|| { - let sig = parse_type_spec_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_type_spec_signature(black_box(&signature)).ok()); }); }); } @@ -361,8 +338,7 @@ fn bench_type_spec_array(c: &mut Criterion) { c.bench_function("sig_typespec_array", |b| { b.iter(|| { - let sig = parse_type_spec_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_type_spec_signature(black_box(&signature)).ok()); }); }); } @@ -375,8 +351,7 @@ fn bench_type_spec_pointer(c: &mut Criterion) { c.bench_function("sig_typespec_pointer", |b| { b.iter(|| { - let sig = parse_type_spec_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_type_spec_signature(black_box(&signature)).ok()); }); }); } @@ -389,8 +364,7 @@ fn bench_type_spec_generic_param(c: &mut Criterion) { c.bench_function("sig_typespec_generic_param", |b| { b.iter(|| { - let sig = parse_type_spec_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_type_spec_signature(black_box(&signature)).ok()); }); }); } @@ -403,8 +377,7 @@ fn bench_method_spec_single(c: &mut Criterion) { c.bench_function("sig_methodspec_single", |b| { b.iter(|| { - let sig = parse_method_spec_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_method_spec_signature(black_box(&signature)).ok()); }); }); } @@ -417,8 +390,7 @@ fn bench_method_spec_multiple(c: &mut Criterion) { c.bench_function("sig_methodspec_multiple", |b| { b.iter(|| { - let sig = parse_method_spec_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_method_spec_signature(black_box(&signature)).ok()); }); }); } @@ -431,8 +403,7 @@ fn bench_method_spec_nested_generic(c: &mut Criterion) { c.bench_function("sig_methodspec_nested_generic", |b| { b.iter(|| { - let sig = parse_method_spec_signature(black_box(&signature)).unwrap(); - black_box(sig) + black_box(parse_method_spec_signature(black_box(&signature)).ok()); }); }); } diff --git a/dotscope/benches/streams.rs b/dotscope/benches/streams.rs index 833f14e3..61870e57 100644 --- a/dotscope/benches/streams.rs +++ b/dotscope/benches/streams.rs @@ -14,19 +14,35 @@ use criterion::{criterion_group, criterion_main, Criterion, Throughput}; use dotscope::metadata::streams::{Blob, Guid, Strings, UserStrings}; use std::{fs, hint::black_box, path::PathBuf}; +/// Read a sample file from the workspace samples directory. Returns `None` +/// (with a stderr diagnostic) if the file is missing, so benchmarks can be +/// skipped without panicking. +fn read_sample(name: &str) -> Option> { + let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(name); + match fs::read(&path) { + Ok(data) => Some(data), + Err(err) => { + eprintln!( + "Skipping benchmark: failed to read {}: {err}", + path.display() + ); + None + } + } +} + /// Benchmark parsing the complete #Strings heap. fn bench_strings_heap_parse(c: &mut Criterion) { - let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WB_STRINGS.bin"); - - let data = fs::read(&path).expect("Failed to read strings heap file"); + let Some(data) = read_sample("tests/samples/WB_STRINGS.bin") else { + return; + }; let file_size = data.len(); let mut group = c.benchmark_group("strings_heap"); group.throughput(Throughput::Bytes(file_size as u64)); group.bench_function("parse", |b| { b.iter(|| { - let strings = Strings::from(black_box(&data)).unwrap(); - black_box(strings) + black_box(Strings::from(black_box(&data)).ok()); }); }); group.finish(); @@ -34,12 +50,14 @@ fn bench_strings_heap_parse(c: &mut Criterion) { /// Benchmark iterating over all strings in the heap. fn bench_strings_heap_iterate(c: &mut Criterion) { - let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WB_STRINGS.bin"); - - let data = fs::read(&path).expect("Failed to read strings heap file"); - let strings = Strings::from(&data).expect("Failed to parse strings heap"); + let Some(data) = read_sample("tests/samples/WB_STRINGS.bin") else { + return; + }; + let Ok(strings) = Strings::from(&data) else { + eprintln!("Skipping iterate: failed to parse strings heap"); + return; + }; - // Count strings for throughput calculation let string_count = strings.iter().count(); let mut group = c.benchmark_group("strings_heap"); @@ -55,12 +73,14 @@ fn bench_strings_heap_iterate(c: &mut Criterion) { /// Benchmark random access to strings by offset. fn bench_strings_heap_random_access(c: &mut Criterion) { - let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WB_STRINGS.bin"); + let Some(data) = read_sample("tests/samples/WB_STRINGS.bin") else { + return; + }; + let Ok(strings) = Strings::from(&data) else { + eprintln!("Skipping random_access: failed to parse strings heap"); + return; + }; - let data = fs::read(&path).expect("Failed to read strings heap file"); - let strings = Strings::from(&data).expect("Failed to parse strings heap"); - - // Collect valid offsets for random access testing let offsets: Vec = strings.iter().map(|(offset, _)| offset).collect(); let mut group = c.benchmark_group("strings_heap"); @@ -77,17 +97,16 @@ fn bench_strings_heap_random_access(c: &mut Criterion) { /// Benchmark parsing the complete #Blob heap. fn bench_blob_heap_parse(c: &mut Criterion) { - let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WB_BLOB.bin"); - - let data = fs::read(&path).expect("Failed to read blob heap file"); + let Some(data) = read_sample("tests/samples/WB_BLOB.bin") else { + return; + }; let file_size = data.len(); let mut group = c.benchmark_group("blob_heap"); group.throughput(Throughput::Bytes(file_size as u64)); group.bench_function("parse", |b| { b.iter(|| { - let blob = Blob::from(black_box(&data)).unwrap(); - black_box(blob) + black_box(Blob::from(black_box(&data)).ok()); }); }); group.finish(); @@ -95,12 +114,14 @@ fn bench_blob_heap_parse(c: &mut Criterion) { /// Benchmark iterating over all blobs in the heap. fn bench_blob_heap_iterate(c: &mut Criterion) { - let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WB_BLOB.bin"); - - let data = fs::read(&path).expect("Failed to read blob heap file"); - let blob = Blob::from(&data).expect("Failed to parse blob heap"); + let Some(data) = read_sample("tests/samples/WB_BLOB.bin") else { + return; + }; + let Ok(blob) = Blob::from(&data) else { + eprintln!("Skipping iterate: failed to parse blob heap"); + return; + }; - // Count blobs for throughput calculation let blob_count = blob.iter().count(); let mut group = c.benchmark_group("blob_heap"); @@ -116,12 +137,14 @@ fn bench_blob_heap_iterate(c: &mut Criterion) { /// Benchmark random access to blobs by offset. fn bench_blob_heap_random_access(c: &mut Criterion) { - let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WB_BLOB.bin"); + let Some(data) = read_sample("tests/samples/WB_BLOB.bin") else { + return; + }; + let Ok(blob) = Blob::from(&data) else { + eprintln!("Skipping random_access: failed to parse blob heap"); + return; + }; - let data = fs::read(&path).expect("Failed to read blob heap file"); - let blob = Blob::from(&data).expect("Failed to parse blob heap"); - - // Collect valid offsets for random access testing let offsets: Vec = blob.iter().map(|(offset, _)| offset).collect(); let mut group = c.benchmark_group("blob_heap"); @@ -138,17 +161,16 @@ fn bench_blob_heap_random_access(c: &mut Criterion) { /// Benchmark parsing the complete #US (User Strings) heap. fn bench_userstrings_heap_parse(c: &mut Criterion) { - let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WB_US.bin"); - - let data = fs::read(&path).expect("Failed to read user strings heap file"); + let Some(data) = read_sample("tests/samples/WB_US.bin") else { + return; + }; let file_size = data.len(); let mut group = c.benchmark_group("userstrings_heap"); group.throughput(Throughput::Bytes(file_size as u64)); group.bench_function("parse", |b| { b.iter(|| { - let us = UserStrings::from(black_box(&data)).unwrap(); - black_box(us) + black_box(UserStrings::from(black_box(&data)).ok()); }); }); group.finish(); @@ -156,12 +178,14 @@ fn bench_userstrings_heap_parse(c: &mut Criterion) { /// Benchmark iterating over all user strings in the heap. fn bench_userstrings_heap_iterate(c: &mut Criterion) { - let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WB_US.bin"); + let Some(data) = read_sample("tests/samples/WB_US.bin") else { + return; + }; + let Ok(us) = UserStrings::from(&data) else { + eprintln!("Skipping iterate: failed to parse user strings heap"); + return; + }; - let data = fs::read(&path).expect("Failed to read user strings heap file"); - let us = UserStrings::from(&data).expect("Failed to parse user strings heap"); - - // Count strings for throughput calculation let string_count = us.iter().count(); let mut group = c.benchmark_group("userstrings_heap"); @@ -177,12 +201,14 @@ fn bench_userstrings_heap_iterate(c: &mut Criterion) { /// Benchmark random access to user strings by offset. fn bench_userstrings_heap_random_access(c: &mut Criterion) { - let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WB_US.bin"); - - let data = fs::read(&path).expect("Failed to read user strings heap file"); - let us = UserStrings::from(&data).expect("Failed to parse user strings heap"); + let Some(data) = read_sample("tests/samples/WB_US.bin") else { + return; + }; + let Ok(us) = UserStrings::from(&data) else { + eprintln!("Skipping random_access: failed to parse user strings heap"); + return; + }; - // Collect valid offsets for random access testing let offsets: Vec = us.iter().map(|(offset, _)| offset).collect(); let mut group = c.benchmark_group("userstrings_heap"); @@ -199,17 +225,16 @@ fn bench_userstrings_heap_random_access(c: &mut Criterion) { /// Benchmark parsing the complete #GUID heap. fn bench_guid_heap_parse(c: &mut Criterion) { - let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WB_GUID.bin"); - - let data = fs::read(&path).expect("Failed to read GUID heap file"); + let Some(data) = read_sample("tests/samples/WB_GUID.bin") else { + return; + }; let file_size = data.len(); let mut group = c.benchmark_group("guid_heap"); group.throughput(Throughput::Bytes(file_size as u64)); group.bench_function("parse", |b| { b.iter(|| { - let guid = Guid::from(black_box(&data)).unwrap(); - black_box(guid) + black_box(Guid::from(black_box(&data)).ok()); }); }); group.finish(); @@ -217,12 +242,14 @@ fn bench_guid_heap_parse(c: &mut Criterion) { /// Benchmark iterating over all GUIDs in the heap. fn bench_guid_heap_iterate(c: &mut Criterion) { - let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WB_GUID.bin"); - - let data = fs::read(&path).expect("Failed to read GUID heap file"); - let guid = Guid::from(&data).expect("Failed to parse GUID heap"); + let Some(data) = read_sample("tests/samples/WB_GUID.bin") else { + return; + }; + let Ok(guid) = Guid::from(&data) else { + eprintln!("Skipping iterate: failed to parse GUID heap"); + return; + }; - // Count GUIDs for throughput calculation let guid_count = guid.iter().count(); let mut group = c.benchmark_group("guid_heap"); @@ -238,12 +265,14 @@ fn bench_guid_heap_iterate(c: &mut Criterion) { /// Benchmark random access to GUIDs by index (1-based as per ECMA-335). fn bench_guid_heap_random_access(c: &mut Criterion) { - let path = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("tests/samples/WB_GUID.bin"); - - let data = fs::read(&path).expect("Failed to read GUID heap file"); - let guid = Guid::from(&data).expect("Failed to parse GUID heap"); + let Some(data) = read_sample("tests/samples/WB_GUID.bin") else { + return; + }; + let Ok(guid) = Guid::from(&data) else { + eprintln!("Skipping random_access: failed to parse GUID heap"); + return; + }; - // Collect valid indices for random access testing (1-based) let indices: Vec = guid.iter().map(|(idx, _)| idx).collect(); let mut group = c.benchmark_group("guid_heap"); diff --git a/dotscope/examples/analysis.rs b/dotscope/examples/analysis.rs index 62f43e28..4a5eb5a2 100644 --- a/dotscope/examples/analysis.rs +++ b/dotscope/examples/analysis.rs @@ -209,22 +209,22 @@ fn display_diagnostics(diagnostics: &Arc) { } eprintln!("\n=== Loading Diagnostics ==="); - let mut error_count = 0; - let mut warning_count = 0; - let mut info_count = 0; + let mut error_count: u64 = 0; + let mut warning_count: u64 = 0; + let mut info_count: u64 = 0; for entry in diagnostics.iter() { let prefix = match entry.severity { DiagnosticSeverity::Error => { - error_count += 1; + error_count = error_count.saturating_add(1); "ERROR" } DiagnosticSeverity::Warning => { - warning_count += 1; + warning_count = warning_count.saturating_add(1); "WARNING" } DiagnosticSeverity::Info => { - info_count += 1; + info_count = info_count.saturating_add(1); "INFO" } }; @@ -261,7 +261,7 @@ fn list_methods(assembly: &CilObject) -> Result<(), Box> println!("\n=== Methods in Assembly ===\n"); let methods = assembly.methods(); - let mut count = 0; + let mut count: u64 = 0; for entry in methods { let method = entry.value(); @@ -269,7 +269,7 @@ fn list_methods(assembly: &CilObject) -> Result<(), Box> if rva > 0 { let type_name = get_method_type_name(assembly, method.token); println!(" 0x{:08X} {}::{}", rva, type_name, method.name); - count += 1; + count = count.saturating_add(1); } } } @@ -327,13 +327,16 @@ fn find_method_by_name( match matches.len() { 0 => Err(format!("No method found matching '{}'", name).into()), - 1 => Ok(matches.into_iter().next().unwrap().1), + 1 => match matches.into_iter().next() { + Some((_, method)) => Ok(method), + None => Err("Internal error: no matches after len() == 1".into()), + }, _ => { eprintln!("Multiple methods match '{}':", name); for (i, (full_name, method)) in matches.iter().enumerate() { eprintln!( " {}. {} (RVA: 0x{:08X})", - i + 1, + i.saturating_add(1), full_name, method.rva.unwrap_or(0) ); @@ -362,7 +365,8 @@ fn display_disasm( ) -> Result<(), Box> { let type_name = get_method_type_name(assembly, method.token); let rva = method.rva.unwrap_or(0); - let num_args = method.signature.param_count as usize + usize::from(method.signature.has_this); + let num_args = (method.signature.param_count as usize) + .saturating_add(usize::from(method.signature.has_this)); let num_locals = method.local_vars.count(); println!("\n{}", "=".repeat(80)); @@ -402,7 +406,8 @@ fn display_ssa_method( ) -> Result<(), Box> { let type_name = get_method_type_name(assembly, method.token); let rva = method.rva.unwrap_or(0); - let num_args = method.signature.param_count as usize + usize::from(method.signature.has_this); + let num_args = (method.signature.param_count as usize) + .saturating_add(usize::from(method.signature.has_this)); let num_locals = method.local_vars.count(); println!("\n{}", "=".repeat(80)); @@ -441,7 +446,8 @@ fn display_ssa_deobfuscated( }; let type_name = get_method_type_name(&assembly, method.token); let rva = method.rva.unwrap_or(0); - let num_args = method.signature.param_count as usize + usize::from(method.signature.has_this); + let num_args = (method.signature.param_count as usize) + .saturating_add(usize::from(method.signature.has_this)); let num_locals = method.local_vars.count(); println!("\n{}", "=".repeat(80)); diff --git a/dotscope/examples/basic.rs b/dotscope/examples/basic.rs index 62a0c8ac..24edf5fa 100644 --- a/dotscope/examples/basic.rs +++ b/dotscope/examples/basic.rs @@ -21,8 +21,9 @@ use std::{env, path::Path}; fn main() -> Result<()> { // Get the path from command line arguments or use a default let args: Vec = env::args().collect(); - if args.len() < 2 { - eprintln!("Usage: {} ", args[0]); + let prog = args.first().map_or("basic", String::as_str); + let Some(path_arg) = args.get(1) else { + eprintln!("Usage: {prog} "); eprintln!(); eprintln!("This example demonstrates basic .NET assembly analysis patterns:"); eprintln!(" • Loading assemblies with error handling"); @@ -30,9 +31,9 @@ fn main() -> Result<()> { eprintln!(" • Iterating through methods safely"); eprintln!(" • Using the prelude for clean, consistent code"); return Ok(()); - } + }; - let path = Path::new(&args[1]); + let path = Path::new(path_arg); // Load and analyze a .NET assembly using dotscope // This demonstrates the primary entry point for assembly analysis @@ -53,7 +54,7 @@ fn main() -> Result<()> { eprintln!(" • File is not in PE format"); eprintln!(); eprintln!("Try with a known good .NET assembly like:"); - eprintln!(" {} tests/samples/WindowsBase.dll", args[0]); + eprintln!(" {prog} tests/samples/WindowsBase.dll"); return Err(e); } }; @@ -71,7 +72,7 @@ fn main() -> Result<()> { let method = entry.value(); println!( "{}. Method: {} (Token: 0x{:08X})", - count + 1, + count.saturating_add(1), method.name, token.value() ); diff --git a/dotscope/examples/comprehensive.rs b/dotscope/examples/comprehensive.rs index 5c6c04df..7e11bccf 100644 --- a/dotscope/examples/comprehensive.rs +++ b/dotscope/examples/comprehensive.rs @@ -23,8 +23,9 @@ use std::{env, path::Path}; fn main() -> Result<()> { let args: Vec = env::args().collect(); - if args.len() < 2 { - eprintln!("Usage: {} ", args[0]); + let prog = args.first().map_or("comprehensive", String::as_str); + let Some(path_arg) = args.get(1) else { + eprintln!("Usage: {prog} "); eprintln!(); eprintln!("This example demonstrates advanced dotscope capabilities:"); eprintln!(" • Complete metadata analysis"); @@ -33,9 +34,9 @@ fn main() -> Result<()> { eprintln!(" • Import/export analysis"); eprintln!(" • Instruction-level analysis"); return Ok(()); - } + }; - let path = Path::new(&args[1]); + let path = Path::new(path_arg); println!("🔍 Advanced analysis of: {}", path.display()); // Load assembly using the prelude's CilObject @@ -106,7 +107,8 @@ fn print_type_analysis(assembly: &CilObject) { &type_def.namespace }; - *namespaces.entry(namespace.to_string()).or_insert(0) += 1; + let entry = namespaces.entry(namespace.to_string()).or_insert(0u64); + *entry = entry.saturating_add(1); } println!(" Top namespaces:"); @@ -138,15 +140,15 @@ fn print_method_analysis(assembly: &CilObject) { for entry in methods.iter().take(10) { let method = entry.value(); - method_stats.total += 1; + method_stats.total = method_stats.total.saturating_add(1); // Check if it's a static method (simplified check) if method.name.starts_with("op_") || method.name == ".cctor" { - method_stats.static_methods += 1; + method_stats.static_methods = method_stats.static_methods.saturating_add(1); } if method.name.starts_with("get_") || method.name.starts_with("set_") { - method_stats.properties += 1; + method_stats.properties = method_stats.properties.saturating_add(1); } println!( @@ -156,7 +158,7 @@ fn print_method_analysis(assembly: &CilObject) { ); if let Some(body) = method.body.get() { - method_stats.with_body += 1; + method_stats.with_body = method_stats.with_body.saturating_add(1); println!( " - IL size: {} bytes, Max stack: {}", @@ -196,15 +198,15 @@ fn print_import_analysis(assembly: &CilObject) { println!(" Sample imports:"); // Now we can iterate over all imports! - let mut method_imports = 0; - let mut type_imports = 0; + let mut method_imports: u64 = 0; + let mut type_imports: u64 = 0; for entry in imports.cil().iter().take(10) { let (token, import) = (entry.key(), entry.value()); match &import.import { dotscope::metadata::imports::ImportType::Method(_) => { - method_imports += 1; + method_imports = method_imports.saturating_add(1u64); if import.namespace.is_empty() { println!( " Method: {} (Token: 0x{:08X})", @@ -221,7 +223,7 @@ fn print_import_analysis(assembly: &CilObject) { } } dotscope::metadata::imports::ImportType::Type(_) => { - type_imports += 1; + type_imports = type_imports.saturating_add(1u64); if import.namespace.is_empty() { println!(" Type: {} (Token: 0x{:08X})", import.name, token.value()); } else { @@ -237,7 +239,10 @@ fn print_import_analysis(assembly: &CilObject) { } if imports.total_count() > 10 { - println!(" ... and {} more imports", imports.total_count() - 10); + println!( + " ... and {} more imports", + imports.total_count().saturating_sub(10) + ); } println!(" Import summary:"); @@ -257,17 +262,17 @@ fn print_instruction_analysis(assembly: &CilObject) { // Find methods with IL to analyze let methods = assembly.methods(); - let mut instruction_count = 0; - let mut total_il_bytes = 0; - let mut methods_analyzed = 0; + let mut instruction_count: usize = 0; + let mut total_il_bytes: usize = 0; + let mut methods_analyzed: u64 = 0; for entry in methods.iter().take(5) { let method = entry.value(); if let Some(body) = method.body.get() { println!(" Analyzing method: {}", method.name); - total_il_bytes += body.size_code; - methods_analyzed += 1; + total_il_bytes = total_il_bytes.saturating_add(body.size_code); + methods_analyzed = methods_analyzed.saturating_add(1u64); // Access basic blocks - they are automatically decoded when method is loaded let blocks: Vec<_> = method.blocks().collect(); @@ -279,7 +284,7 @@ fn print_instruction_analysis(assembly: &CilObject) { // Show first few instructions from first block if let Some((_, first_block)) = blocks.first() { let inst_count = first_block.instructions.len(); - instruction_count += inst_count; + instruction_count = instruction_count.saturating_add(inst_count); println!(" - First block has {inst_count} instructions"); for (i, instruction) in first_block.instructions.iter().take(3).enumerate() { @@ -291,7 +296,7 @@ fn print_instruction_analysis(assembly: &CilObject) { if first_block.instructions.len() > 3 { println!( " ... and {} more instructions", - first_block.instructions.len() - 3 + first_block.instructions.len().saturating_sub(3) ); } } diff --git a/dotscope/examples/decode_blocks.rs b/dotscope/examples/decode_blocks.rs index b8ab39ae..1d735967 100644 --- a/dotscope/examples/decode_blocks.rs +++ b/dotscope/examples/decode_blocks.rs @@ -25,15 +25,14 @@ fn main() -> Result<()> { let code = [0x00, 0x2A]; // nop, ret let blocks = decode_blocks(&code, 0, 0x1000, None)?; println!("Number of basic blocks: {}", blocks.len()); - println!( - "Instructions in first block: {}", - blocks[0].instructions.len() - ); - for (i, instruction) in blocks[0].instructions.iter().enumerate() { - println!( - " {}: {} (RVA: 0x{:X})", - i, instruction.mnemonic, instruction.rva - ); + if let Some(first) = blocks.first() { + println!("Instructions in first block: {}", first.instructions.len()); + for (i, instruction) in first.instructions.iter().enumerate() { + println!( + " {}: {} (RVA: 0x{:X})", + i, instruction.mnemonic, instruction.rva + ); + } } // Example: Conditional branch @@ -68,9 +67,11 @@ fn main() -> Result<()> { ]; let blocks = decode_blocks(&code, 0, 0x3000, Some(2))?; println!("Number of basic blocks: {}", blocks.len()); - println!("Instructions in block: {}", blocks[0].instructions.len()); - for instruction in &blocks[0].instructions { - println!(" {} (RVA: 0x{:X})", instruction.mnemonic, instruction.rva); + if let Some(first) = blocks.first() { + println!("Instructions in block: {}", first.instructions.len()); + for instruction in &first.instructions { + println!(" {} (RVA: 0x{:X})", instruction.mnemonic, instruction.rva); + } } println!("\n✅ All examples completed successfully!"); diff --git a/dotscope/examples/deobfuscate.rs b/dotscope/examples/deobfuscate.rs index 095a8b95..8c0e3f38 100644 --- a/dotscope/examples/deobfuscate.rs +++ b/dotscope/examples/deobfuscate.rs @@ -128,7 +128,13 @@ fn main() -> Result<(), Box> { } if handle.is_finished() { - break handle.join().expect("Deobfuscation thread panicked"); + match handle.join() { + Ok(value) => break value, + Err(_) => { + eprintln!("Deobfuscation thread panicked"); + std::process::exit(1); + } + } } thread::sleep(std::time::Duration::from_millis(100)); diff --git a/dotscope/examples/disassembly.rs b/dotscope/examples/disassembly.rs index 7837e95c..531f9cbd 100644 --- a/dotscope/examples/disassembly.rs +++ b/dotscope/examples/disassembly.rs @@ -23,8 +23,9 @@ use std::{env, path::Path}; fn main() -> Result<()> { let args: Vec = env::args().collect(); - if args.len() < 2 { - eprintln!("Usage: {} ", args[0]); + let prog = args.first().map_or("disassembly", String::as_str); + let Some(path_arg) = args.get(1) else { + eprintln!("Usage: {prog} "); eprintln!(); eprintln!("This example demonstrates IL disassembly and method analysis:"); eprintln!(" • CIL instruction decoding with full operand support"); @@ -32,9 +33,9 @@ fn main() -> Result<()> { eprintln!(" • Exception handler examination"); eprintln!(" • Stack and local variable analysis"); return Ok(()); - } + }; - let path = Path::new(&args[1]); + let path = Path::new(path_arg); println!("⚙️ IL Disassembly analysis of: {}", path.display()); let assembly = CilObject::from_path(path)?; @@ -66,37 +67,39 @@ fn print_method_body_analysis(assembly: &CilObject) { for entry in methods.iter().take(20) { let method = entry.value(); - stats.total_methods += 1; + stats.total_methods = stats.total_methods.saturating_add(1); if let Some(body) = method.body.get() { - stats.methods_with_body += 1; - stats.total_il_bytes += body.size_code; + stats.methods_with_body = stats.methods_with_body.saturating_add(1); + stats.total_il_bytes = stats.total_il_bytes.saturating_add(body.size_code); if body.max_stack > stats.max_stack_size { stats.max_stack_size = body.max_stack; } if body.local_var_sig_token != 0 { - stats.methods_with_locals += 1; + stats.methods_with_locals = stats.methods_with_locals.saturating_add(1); } if !body.exception_handlers.is_empty() { - stats.methods_with_exceptions += 1; - stats.total_exception_handlers += body.exception_handlers.len(); + stats.methods_with_exceptions = stats.methods_with_exceptions.saturating_add(1); + stats.total_exception_handlers = stats + .total_exception_handlers + .saturating_add(body.exception_handlers.len()); } // Analyze method characteristics if body.is_init_local { - stats.init_locals += 1; + stats.init_locals = stats.init_locals.saturating_add(1); } if body.size_code < 64 { - stats.tiny_methods += 1; + stats.tiny_methods = stats.tiny_methods.saturating_add(1); } else { - stats.fat_methods += 1; + stats.fat_methods = stats.fat_methods.saturating_add(1); } } else { - stats.abstract_or_extern += 1; + stats.abstract_or_extern = stats.abstract_or_extern.saturating_add(1); } } @@ -154,12 +157,12 @@ fn print_instruction_analysis(assembly: &CilObject) { ); // Display actual disassembled instructions from blocks - let mut total_instructions = 0; - let mut block_count = 0; + let mut total_instructions: usize = 0; + let mut block_count: u64 = 0; // Access blocks - blocks are automatically populated when method is loaded for (block_id, block) in method.blocks() { - block_count += 1; + block_count = block_count.saturating_add(1); if block_count <= 3 && !block.instructions.is_empty() { println!( " Block {} (RVA: 0x{:X}, {} instructions):", @@ -177,14 +180,17 @@ fn print_instruction_analysis(assembly: &CilObject) { ); // Update instruction statistics - instruction_stats.total_instructions += 1; + instruction_stats.total_instructions = + instruction_stats.total_instructions.saturating_add(1); match instruction.flow_type { dotscope::assembly::FlowType::ConditionalBranch | dotscope::assembly::FlowType::UnconditionalBranch => { - instruction_stats.branch_instructions += 1; + instruction_stats.branch_instructions = + instruction_stats.branch_instructions.saturating_add(1); } dotscope::assembly::FlowType::Call => { - instruction_stats.call_instructions += 1; + instruction_stats.call_instructions = + instruction_stats.call_instructions.saturating_add(1); } _ => {} } @@ -192,28 +198,34 @@ fn print_instruction_analysis(assembly: &CilObject) { if instruction.mnemonic.starts_with("ld") || instruction.mnemonic.starts_with("st") { - instruction_stats.load_store_instructions += 1; + instruction_stats.load_store_instructions = + instruction_stats.load_store_instructions.saturating_add(1); } } } if block.instructions.len() > 5 { println!( " ... ({} more instructions)", - block.instructions.len() - 5 + block.instructions.len().saturating_sub(5) ); } } - total_instructions += block.instructions.len(); + total_instructions = + total_instructions.saturating_add(block.instructions.len()); } println!(" Basic blocks: {block_count}"); if block_count > 3 { - println!(" ... ({} more blocks)", block_count - 3); + println!( + " ... ({} more blocks)", + block_count.saturating_sub(3) + ); } println!(" Total instructions: {total_instructions}"); - instruction_stats.methods_analyzed += 1; + instruction_stats.methods_analyzed = + instruction_stats.methods_analyzed.saturating_add(1); } if instruction_stats.methods_analyzed >= 3 { @@ -256,18 +268,35 @@ fn print_exception_analysis(assembly: &CilObject) { if let Some(body) = method.body.get() { if !body.exception_handlers.is_empty() { - exception_stats.methods_with_handlers += 1; + exception_stats.methods_with_handlers = + exception_stats.methods_with_handlers.saturating_add(1); for handler in &body.exception_handlers { - exception_stats.total_handlers += 1; + exception_stats.total_handlers = + exception_stats.total_handlers.saturating_add(1); // Analyze handler types based on flags match handler.flags { - ExceptionHandlerFlags::EXCEPTION => exception_stats.catch_handlers += 1, - ExceptionHandlerFlags::FILTER => exception_stats.filter_handlers += 1, - ExceptionHandlerFlags::FINALLY => exception_stats.finally_handlers += 1, - ExceptionHandlerFlags::FAULT => exception_stats.fault_handlers += 1, - _ => exception_stats.unknown_handlers += 1, + ExceptionHandlerFlags::EXCEPTION => { + exception_stats.catch_handlers = + exception_stats.catch_handlers.saturating_add(1) + } + ExceptionHandlerFlags::FILTER => { + exception_stats.filter_handlers = + exception_stats.filter_handlers.saturating_add(1) + } + ExceptionHandlerFlags::FINALLY => { + exception_stats.finally_handlers = + exception_stats.finally_handlers.saturating_add(1) + } + ExceptionHandlerFlags::FAULT => { + exception_stats.fault_handlers = + exception_stats.fault_handlers.saturating_add(1) + } + _ => { + exception_stats.unknown_handlers = + exception_stats.unknown_handlers.saturating_add(1) + } } // Track protected region sizes @@ -292,7 +321,7 @@ fn print_exception_analysis(assembly: &CilObject) { i, handler_type, handler.try_offset, - handler.try_offset + handler.try_length, + handler.try_offset.saturating_add(handler.try_length), handler.handler_offset ); } @@ -330,10 +359,10 @@ fn print_stack_analysis(assembly: &CilObject) { let method = entry.value(); if let Some(body) = method.body.get() { - local_stats.methods_analyzed += 1; + local_stats.methods_analyzed = local_stats.methods_analyzed.saturating_add(1); if body.local_var_sig_token != 0 { - local_stats.methods_with_locals += 1; + local_stats.methods_with_locals = local_stats.methods_with_locals.saturating_add(1); // In a real implementation, you would parse the local variable signature // to determine the exact types and count of local variables @@ -355,7 +384,8 @@ fn print_stack_analysis(assembly: &CilObject) { // Check init_locals flag if body.is_init_local { - local_stats.methods_with_init_locals += 1; + local_stats.methods_with_init_locals = + local_stats.methods_with_init_locals.saturating_add(1); } } } diff --git a/dotscope/examples/injectcode.rs b/dotscope/examples/injectcode.rs index 07ef1380..95007e05 100644 --- a/dotscope/examples/injectcode.rs +++ b/dotscope/examples/injectcode.rs @@ -34,8 +34,9 @@ use std::{env, path::Path}; fn main() -> Result<()> { let args: Vec = env::args().collect(); - if args.len() != 3 { - eprintln!("Usage: {} ", args[0]); + let prog = args.first().map_or("injectcode", String::as_str); + let (Some(input_arg), Some(output_arg)) = (args.get(1), args.get(2)) else { + eprintln!("Usage: {prog} "); eprintln!(); eprintln!("This example demonstrates .NET assembly code injection:"); eprintln!(" - Finding or creating external assembly references"); @@ -46,7 +47,7 @@ fn main() -> Result<()> { eprintln!(" - Complete workflow with validation and PE generation"); eprintln!(); eprintln!("Example:"); - eprintln!(" {} input.dll injected.dll", args[0]); + eprintln!(" {prog} input.dll injected.dll"); eprintln!(); eprintln!("The injected method will be:"); eprintln!(" public static void PrintHelloWorld()"); @@ -54,10 +55,10 @@ fn main() -> Result<()> { eprintln!(" System.Console.WriteLine(\"Hello World from dotscope!\");"); eprintln!(" }}"); return Ok(()); - } + }; - let input_path = Path::new(&args[1]); - let output_path = Path::new(&args[2]); + let input_path = Path::new(input_arg); + let output_path = Path::new(output_arg); println!(".NET Assembly Code Injection Tool"); println!("Input: {}", input_path.display()); @@ -136,9 +137,9 @@ fn main() -> Result<()> { .build(&mut assembly)?; // Get placeholder token for use in IL instructions - let console_writeline_token = console_writeline_ref - .placeholder_token() - .expect("Console.WriteLine ChangeRef should be a table row"); + let console_writeline_token = console_writeline_ref.placeholder_token().ok_or_else(|| { + dotscope::Error::Other("Console.WriteLine ChangeRef should be a table row".to_string()) + })?; println!(" Created mscorlib reference"); println!(" Created Console.WriteLine reference"); println!(); diff --git a/dotscope/examples/lowlevel.rs b/dotscope/examples/lowlevel.rs index 822e942e..1bf4d680 100644 --- a/dotscope/examples/lowlevel.rs +++ b/dotscope/examples/lowlevel.rs @@ -23,8 +23,9 @@ use std::{env, fs, path::Path}; fn main() -> Result<()> { let args: Vec = env::args().collect(); - if args.len() < 2 { - eprintln!("Usage: {} ", args[0]); + let prog = args.first().map_or("lowlevel", String::as_str); + let Some(path_arg) = args.get(1) else { + eprintln!("Usage: {prog} "); eprintln!(); eprintln!("This example demonstrates low-level API usage for binary parsing:"); eprintln!(" • Direct PE structure parsing"); @@ -32,11 +33,11 @@ fn main() -> Result<()> { eprintln!(" • Working with byte buffers"); eprintln!(" • Understanding dotscope internals"); eprintln!(); - eprintln!("Recommended: {} tests/samples/WindowsBase.dll", args[0]); + eprintln!("Recommended: {prog} tests/samples/WindowsBase.dll"); return Ok(()); - } + }; - let path = Path::new(&args[1]); + let path = Path::new(path_arg); println!("🔧 Low-level analysis of: {}", path.display()); // Step 1: Load the entire assembly into memory as Vec @@ -64,7 +65,9 @@ fn main() -> Result<()> { // Step 3: Parse CLR metadata using low-level Cor20Header struct println!("\n=== Step 3: Parsing CLR Header using Cor20Header ==="); - let (clr_rva, clr_size) = file.clr().expect("File should have CLR runtime header"); + let (clr_rva, clr_size) = file + .clr() + .ok_or_else(|| dotscope::Error::Other("File should have CLR runtime header".into()))?; println!("CLR Runtime Header: RVA=0x{clr_rva:08X}, Size={clr_size} bytes"); // Convert RVA to file offset and read CLR header @@ -127,8 +130,12 @@ fn main() -> Result<()> { .find(|stream| stream.name == "#Strings") { println!("\n--- #Strings Stream ---"); - let strings_data = &metadata_data[strings_stream.offset as usize - ..(strings_stream.offset + strings_stream.size) as usize]; + let start = strings_stream.offset as usize; + let end = start.saturating_add(strings_stream.size as usize); + let Some(strings_data) = metadata_data.get(start..end) else { + println!("#Strings stream extends past metadata buffer; skipping"); + return Ok(()); + }; match Strings::from(strings_data) { Ok(strings) => { @@ -158,8 +165,12 @@ fn main() -> Result<()> { .find(|stream| stream.name == "#Blob") { println!("\n--- #Blob Stream ---"); - let blob_data = &metadata_data - [blob_stream.offset as usize..(blob_stream.offset + blob_stream.size) as usize]; + let start = blob_stream.offset as usize; + let end = start.saturating_add(blob_stream.size as usize); + let Some(blob_data) = metadata_data.get(start..end) else { + println!("#Blob stream extends past metadata buffer; skipping"); + return Ok(()); + }; match Blob::from(blob_data) { Ok(blob) => { @@ -186,8 +197,12 @@ fn main() -> Result<()> { .find(|stream| stream.name == "#US") { println!("\n--- #US Stream (User Strings) ---"); - let us_data = - &metadata_data[us_stream.offset as usize..(us_stream.offset + us_stream.size) as usize]; + let start = us_stream.offset as usize; + let end = start.saturating_add(us_stream.size as usize); + let Some(us_data) = metadata_data.get(start..end) else { + println!("#US stream extends past metadata buffer; skipping"); + return Ok(()); + }; match UserStrings::from(us_data) { Ok(user_strings) => { @@ -214,8 +229,12 @@ fn main() -> Result<()> { .find(|stream| stream.name == "#~") { println!("\n--- #~ Stream (Metadata Tables) using TablesHeader struct ---"); - let tables_data = &metadata_data - [tables_stream.offset as usize..(tables_stream.offset + tables_stream.size) as usize]; + let start = tables_stream.offset as usize; + let end = start.saturating_add(tables_stream.size as usize); + let Some(tables_data) = metadata_data.get(start..end) else { + println!("#~ stream extends past metadata buffer; skipping"); + return Ok(()); + }; match TablesHeader::from(tables_data) { Ok(tables_header) => { @@ -240,7 +259,10 @@ fn main() -> Result<()> { println!(" {:?}: {} rows", summary.table_id, summary.row_count); } if summaries.len() > 10 { - println!(" ... and {} more tables", summaries.len() - 10); + println!( + " ... and {} more tables", + summaries.len().saturating_sub(10) + ); } } Err(e) => println!("Failed to parse TablesHeader: {e}"), diff --git a/dotscope/examples/metadata.rs b/dotscope/examples/metadata.rs index cb3ce22c..d0969693 100644 --- a/dotscope/examples/metadata.rs +++ b/dotscope/examples/metadata.rs @@ -24,8 +24,9 @@ use std::{collections::HashMap, env, path::Path}; fn main() -> Result<()> { let args: Vec = env::args().collect(); - if args.len() < 2 { - eprintln!("Usage: {} ", args[0]); + let prog = args.first().map_or("metadata", String::as_str); + let Some(path_arg) = args.get(1) else { + eprintln!("Usage: {prog} "); eprintln!(); eprintln!("This example explores metadata tables and streams in detail:"); eprintln!(" • Raw metadata table access and analysis"); @@ -33,9 +34,9 @@ fn main() -> Result<()> { eprintln!(" • Cross-table relationship analysis"); eprintln!(" • Assembly dependency tracking"); return Ok(()); - } + }; - let path = Path::new(&args[1]); + let path = Path::new(path_arg); println!("🔬 Metadata exploration of: {}", path.display()); let assembly = CilObject::from_path(path)?; @@ -95,14 +96,14 @@ fn print_heap_analysis(assembly: &CilObject) { // String heap analysis with iterator demonstration if let Some(strings) = assembly.strings() { - let mut string_count = 0; - let mut total_length = 0; + let mut string_count: u64 = 0; + let mut total_length: usize = 0; let mut sample_strings = Vec::new(); println!(" String heap analysis:"); for (offset, string) in strings.iter().take(1000) { - string_count += 1; - total_length += string.len(); + string_count = string_count.saturating_add(1u64); + total_length = total_length.saturating_add(string.len()); // Collect interesting samples if sample_strings.len() < 5 && !string.is_empty() && string.len() > 3 { @@ -141,15 +142,15 @@ fn print_heap_analysis(assembly: &CilObject) { // Blob heap analysis with iterator demonstration if let Some(blob) = assembly.blob() { - let mut blob_count = 0; - let mut total_size = 0; + let mut blob_count: u64 = 0; + let mut total_size: usize = 0; let mut size_histogram: HashMap = HashMap::new(); println!(" Blob heap analysis:"); for (offset, blob_data) in blob.iter().take(500) { // Limit to avoid overwhelming output - blob_count += 1; - total_size += blob_data.len(); + blob_count = blob_count.saturating_add(1u64); + total_size = total_size.saturating_add(blob_data.len()); // Categorize by size let size_category = match blob_data.len() { @@ -159,7 +160,10 @@ fn print_heap_analysis(assembly: &CilObject) { 65..=256 => "large (65-256 bytes)", _ => "huge (>256 bytes)", }; - *size_histogram.entry(size_category.to_string()).or_insert(0) += 1; + let entry = size_histogram + .entry(size_category.to_string()) + .or_insert(0usize); + *entry = entry.saturating_add(1); // Show a sample of the first few blobs if blob_count <= 3 && !blob_data.is_empty() { @@ -195,13 +199,13 @@ fn print_heap_analysis(assembly: &CilObject) { // User strings heap analysis with iterator demonstration if let Some(user_strings) = assembly.userstrings() { - let mut string_count = 0; + let mut string_count: u64 = 0; let mut sample_user_strings = Vec::new(); println!(" User strings heap analysis:"); for (offset, string) in user_strings.iter().take(100) { // Limit for readability - string_count += 1; + string_count = string_count.saturating_add(1); // Collect interesting samples if sample_user_strings.len() < 3 { @@ -217,7 +221,7 @@ fn print_heap_analysis(assembly: &CilObject) { println!(" Sample user strings:"); for (offset, string) in sample_user_strings { let truncated = if string.len() > 50 { - format!("{}...", &string[..47]) + format!("{}...", string.get(..47).unwrap_or(string.as_str())) } else { string }; @@ -242,23 +246,25 @@ fn print_type_system_analysis(assembly: &CilObject) { type_def.namespace.clone() }; - *namespace_stats.entry(namespace).or_insert(0) += 1; + let entry = namespace_stats.entry(namespace).or_insert(0usize); + *entry = entry.saturating_add(1); - // Categorize by common patterns - if type_def.name.ends_with("Attribute") { - *type_kind_stats.entry("Attributes").or_insert(0) += 1; + let category = if type_def.name.ends_with("Attribute") { + "Attributes" } else if type_def.name.ends_with("Exception") { - *type_kind_stats.entry("Exceptions").or_insert(0) += 1; + "Exceptions" } else if type_def.name.ends_with("EventArgs") { - *type_kind_stats.entry("EventArgs").or_insert(0) += 1; + "EventArgs" } else if type_def.name.starts_with('I') && type_def.name.len() > 1 - && type_def.name.chars().nth(1).unwrap().is_uppercase() + && type_def.name.chars().nth(1).is_some_and(char::is_uppercase) { - *type_kind_stats.entry("Interfaces").or_insert(0) += 1; + "Interfaces" } else { - *type_kind_stats.entry("Classes").or_insert(0) += 1; - } + "Classes" + }; + let entry = type_kind_stats.entry(category).or_insert(0usize); + *entry = entry.saturating_add(1); } // Display namespace statistics @@ -282,7 +288,7 @@ fn print_custom_attributes_analysis(assembly: &CilObject) { // Show custom attributes from Types println!(" Custom attributes on Types:"); let types = assembly.types(); - let mut type_count = 0; + let mut type_count: u64 = 0; for entry in types.iter().take(20) { let type_def = entry.value(); let custom_attrs = &type_def.custom_attributes; @@ -292,11 +298,14 @@ fn print_custom_attributes_analysis(assembly: &CilObject) { if attr_count > 0 && type_count < 2 { println!(" Type: {}", type_def.fullname()); for (i, attr) in custom_attrs.iter().take(5) { - print_custom_attribute_info(i + 1, attr); + print_custom_attribute_info(i.saturating_add(1), attr); } - type_count += 1; + type_count = type_count.saturating_add(1u64); if attr_count > 5 { - println!(" ... and {} more attributes", attr_count - 5); + println!( + " ... and {} more attributes", + attr_count.saturating_sub(5) + ); } } } @@ -304,7 +313,7 @@ fn print_custom_attributes_analysis(assembly: &CilObject) { // Show custom attributes from Methods println!(" Custom attributes on Methods:"); let methods = assembly.methods(); - let mut method_count = 0; + let mut method_count: u64 = 0; for entry in methods.iter().take(50) { let method = entry.value(); let custom_attrs = &method.custom_attributes; @@ -314,18 +323,21 @@ fn print_custom_attributes_analysis(assembly: &CilObject) { if attr_count > 0 && method_count < 2 { println!(" Method: {}", method.name); for (i, attr) in custom_attrs.iter().take(5) { - print_custom_attribute_info(i + 1, attr); + print_custom_attribute_info(i.saturating_add(1), attr); } - method_count += 1; + method_count = method_count.saturating_add(1u64); if attr_count > 5 { - println!(" ... and {} more attributes", attr_count - 5); + println!( + " ... and {} more attributes", + attr_count.saturating_sub(5) + ); } } } // Show custom attributes from Events println!(" Custom attributes on Events:"); - let mut event_count = 0; + let mut event_count: u64 = 0; for entry in types.iter().take(20) { let type_def = entry.value(); let events = &type_def.events; @@ -337,11 +349,14 @@ fn print_custom_attributes_analysis(assembly: &CilObject) { if attr_count > 0 && event_count < 2 { println!(" Event: {}", event.name); for (i, attr) in custom_attrs.iter().take(5) { - print_custom_attribute_info(i + 1, attr); + print_custom_attribute_info(i.saturating_add(1), attr); } - event_count += 1; + event_count = event_count.saturating_add(1u64); if attr_count > 5 { - println!(" ... and {} more attributes", attr_count - 5); + println!( + " ... and {} more attributes", + attr_count.saturating_sub(5) + ); } } } @@ -426,7 +441,12 @@ fn print_dependency_analysis(assembly: &CilObject) { flag_descriptions.join(", ") }; - println!(" {}. {} v{}", i + 1, assembly_ref.name, version); + println!( + " {}. {} v{}", + i.saturating_add(1), + assembly_ref.name, + version + ); println!(" Culture: {culture}, Flags: {flags_str}"); // Show identifier information if available @@ -450,7 +470,10 @@ fn print_dependency_analysis(assembly: &CilObject) { } } if assembly_refs.len() > 10 { - println!(" ... and {} more", assembly_refs.len() - 10); + println!( + " ... and {} more", + assembly_refs.len().saturating_sub(10) + ); } } @@ -462,10 +485,10 @@ fn print_dependency_analysis(assembly: &CilObject) { println!(" Referenced modules:"); for (i, entry) in module_refs.iter().take(10).enumerate() { let module_ref = entry.value(); - println!(" {}. {}", i + 1, module_ref.name); + println!(" {}. {}", i.saturating_add(1), module_ref.name); } if module_refs.len() > 10 { - println!(" ... and {} more", module_refs.len() - 10); + println!(" ... and {} more", module_refs.len().saturating_sub(10)); } } diff --git a/dotscope/examples/method_analysis.rs b/dotscope/examples/method_analysis.rs index 80854bf2..a865febb 100644 --- a/dotscope/examples/method_analysis.rs +++ b/dotscope/examples/method_analysis.rs @@ -74,8 +74,9 @@ fn format_impl_options(options: &MethodImplOptions) -> String { fn main() -> Result<()> { let args: Vec = env::args().collect(); - if args.len() < 2 { - eprintln!("Usage: {} ", args[0]); + let prog = args.first().map_or("method_analysis", String::as_str); + let Some(path_arg) = args.get(1) else { + eprintln!("Usage: {prog} "); eprintln!(); eprintln!("This example performs exhaustive analysis of a selected method:"); eprintln!(" • Complete method metadata examination"); @@ -85,9 +86,9 @@ fn main() -> Result<()> { eprintln!(); eprintln!("The example will automatically select a suitable method with IL code."); return Ok(()); - } + }; - let path = Path::new(&args[1]); + let path = Path::new(path_arg); println!("🔍 Comprehensive Method Analysis of: {}", path.display()); let assembly = CilObject::from_path(path)?; @@ -339,11 +340,8 @@ fn print_method_parameters(method: &Method) { vararg.modifiers.count() ); for (j, modifier) in vararg.modifiers.iter() { - println!( - " [{}]: Token 0x{:08X}", - j, - modifier.token().unwrap().value() - ); + let token_value = modifier.token().map_or(0, |t| t.value()); + println!(" [{j}]: Token 0x{token_value:08X}"); } } } @@ -545,7 +543,7 @@ fn print_basic_block_analysis(method: &Method) { } if block_count > 10 { - println!(" ... ({} more blocks)", block_count - 10); + println!(" ... ({} more blocks)", block_count.saturating_sub(10)); } } @@ -565,17 +563,20 @@ fn print_instruction_stream_analysis(method: &Method) -> Result<()> { for (i, instruction) in method.instructions().enumerate() { // Count by mnemonic - *instruction_stats + let entry = instruction_stats .entry(instruction.mnemonic.to_string()) - .or_insert(0) += 1; + .or_insert(0u64); + *entry = entry.saturating_add(1); // Count by category let category_name = format!("{:?}", instruction.category); - *category_stats.entry(category_name).or_insert(0) += 1; + let entry = category_stats.entry(category_name).or_insert(0u64); + *entry = entry.saturating_add(1); // Count by flow type let flow_name = format!("{:?}", instruction.flow_type); - *flow_type_stats.entry(flow_name).or_insert(0) += 1; + let entry = flow_type_stats.entry(flow_name).or_insert(0u64); + *entry = entry.saturating_add(1); // Collect stack effects stack_effects.push(instruction.stack_behavior.net_effect); @@ -604,7 +605,10 @@ fn print_instruction_stream_analysis(method: &Method) -> Result<()> { operand_display ); } else if i == 15 { - println!(" ... ({} more instructions)", total_instructions - 15); + println!( + " ... ({} more instructions)", + total_instructions.saturating_sub(15) + ); } } @@ -655,7 +659,10 @@ fn print_instruction_stream_analysis(method: &Method) -> Result<()> { if sorted_targets.len() <= 5 { println!(" Targets: {sorted_targets:?}"); } else { - println!(" First 5 targets: {:?}...", &sorted_targets[0..5]); + println!( + " First 5 targets: {:?}...", + sorted_targets.get(0..5).unwrap_or(&[]) + ); } } @@ -672,17 +679,17 @@ fn print_control_flow_analysis(method: &Method) { return; } - let mut entry_blocks = 0; - let mut exit_blocks = 0; - let mut branch_blocks = 0; - let mut simple_blocks = 0; + let mut entry_blocks: u64 = 0; + let mut exit_blocks: u64 = 0; + let mut branch_blocks: u64 = 0; + let mut simple_blocks: u64 = 0; for (_, block) in method.blocks() { match (block.predecessors.len(), block.successors.len()) { - (0, _) => entry_blocks += 1, - (_, 0) => exit_blocks += 1, - (_, n) if n > 1 => branch_blocks += 1, - _ => simple_blocks += 1, + (0, _) => entry_blocks = entry_blocks.saturating_add(1u64), + (_, 0) => exit_blocks = exit_blocks.saturating_add(1u64), + (_, n) if n > 1 => branch_blocks = branch_blocks.saturating_add(1u64), + _ => simple_blocks = simple_blocks.saturating_add(1u64), } } @@ -697,7 +704,7 @@ fn print_control_flow_analysis(method: &Method) { .blocks() .map(|(_, block)| block.successors.len().saturating_sub(1)) .sum::() - + 1; + .saturating_add(1); println!("\n Complexity Metrics:"); println!(" Cyclomatic Complexity: {cyclomatic_complexity}"); diff --git a/dotscope/examples/modify.rs b/dotscope/examples/modify.rs index 0f3fc8f3..e1524cc8 100644 --- a/dotscope/examples/modify.rs +++ b/dotscope/examples/modify.rs @@ -32,8 +32,9 @@ use std::{env, path::Path}; fn main() -> Result<()> { let args: Vec = env::args().collect(); - if args.len() < 3 { - eprintln!("Usage: {} ", args[0]); + let prog = args.first().map_or("modify", String::as_str); + let (Some(source_arg), Some(output_arg)) = (args.get(1), args.get(2)) else { + eprintln!("Usage: {prog} "); eprintln!(); eprintln!("This example demonstrates comprehensive .NET assembly modification:"); eprintln!(" - Adding strings, blobs, GUIDs, and user strings to heaps"); @@ -42,12 +43,12 @@ fn main() -> Result<()> { eprintln!(" - Validating changes and writing modified assembly"); eprintln!(); eprintln!("Example:"); - eprintln!(" {} input.dll modified.dll", args[0]); + eprintln!(" {prog} input.dll modified.dll"); return Ok(()); - } + }; - let source_path = Path::new(&args[1]); - let output_path = Path::new(&args[2]); + let source_path = Path::new(source_arg); + let output_path = Path::new(output_arg); println!(".NET Assembly Modification Tool"); println!("Source: {}", source_path.display()); diff --git a/dotscope/examples/project_loader.rs b/dotscope/examples/project_loader.rs index eb9b5d9f..bde1e1c5 100644 --- a/dotscope/examples/project_loader.rs +++ b/dotscope/examples/project_loader.rs @@ -30,11 +30,11 @@ use std::env; fn main() -> dotscope::Result<()> { let args: Vec = env::args().collect(); + let prog = args.first().map_or("project_loader", String::as_str); - if args.len() < 2 { + let Some(assembly_path) = args.get(1) else { eprintln!( - "Usage: {} [--search-path ] [--search-path ] ...", - args[0] + "Usage: {prog} [--search-path ] [--search-path ] ..." ); eprintln!(); eprintln!("Options:"); @@ -42,26 +42,20 @@ fn main() -> dotscope::Result<()> { eprintln!(" (can be specified multiple times)"); eprintln!(); eprintln!("Examples:"); - eprintln!(" {} MyApp.exe", args[0]); - eprintln!(" {} MyApp.exe --search-path /usr/lib/mono/4.5", args[0]); - eprintln!( - " {} tests/samples/crafted_2.exe --search-path tests/samples/mono_4.8", - args[0] - ); + eprintln!(" {prog} MyApp.exe"); + eprintln!(" {prog} MyApp.exe --search-path /usr/lib/mono/4.5"); + eprintln!(" {prog} tests/samples/crafted_2.exe --search-path tests/samples/mono_4.8"); std::process::exit(1); - } - - let assembly_path = &args[1]; + }; // Parse --search-path arguments let mut search_paths: Vec = Vec::new(); - let mut i = 2; - while i < args.len() { - if args[i] == "--search-path" && i + 1 < args.len() { - search_paths.push(args[i + 1].clone()); - i += 2; - } else { - i += 1; + let mut iter = args.iter().skip(2); + while let Some(arg) = iter.next() { + if arg == "--search-path" { + if let Some(path) = iter.next() { + search_paths.push(path.clone()); + } } } @@ -171,7 +165,7 @@ fn main() -> dotscope::Result<()> { println!(" - {} ({:?})", ciltype.fullname(), ciltype.flavor()); } if types.len() > 10 { - println!(" ... and {} more", types.len() - 10); + println!(" ... and {} more", types.len().saturating_sub(10)); } } } @@ -215,7 +209,7 @@ fn main() -> dotscope::Result<()> { println!(" - {} in {}", ciltype.fullname(), identity.name); } if definitions.len() > 5 { - println!(" ... and {} more", definitions.len() - 5); + println!(" ... and {} more", definitions.len().saturating_sub(5)); } } diff --git a/dotscope/examples/raw_assembly_view.rs b/dotscope/examples/raw_assembly_view.rs index 1b668078..668a5084 100644 --- a/dotscope/examples/raw_assembly_view.rs +++ b/dotscope/examples/raw_assembly_view.rs @@ -82,7 +82,7 @@ fn display_streams(view: &CilAssemblyView) { println!("{}", "-".repeat(40)); for (idx, stream) in view.streams().iter().enumerate() { - println!("{}. {} stream:", idx + 1, stream.name); + println!("{}. {} stream:", idx.saturating_add(1), stream.name); println!(" • Offset: 0x{:08X}", stream.offset); println!(" • Size: {} bytes", stream.size); @@ -191,7 +191,7 @@ fn demonstrate_blob_access(view: &CilAssemblyView) -> Result<()> { " • Offset: {} - Size: {} bytes - Data: {:02X?}...", offset, data.len(), - &data[..data.len().min(8)] + data.get(..data.len().min(8)).unwrap_or(&[]) ); } } else { diff --git a/dotscope/examples/types.rs b/dotscope/examples/types.rs index 9bfeaeac..49222c9c 100644 --- a/dotscope/examples/types.rs +++ b/dotscope/examples/types.rs @@ -23,8 +23,9 @@ use std::{collections::HashMap, env, path::Path}; fn main() -> Result<()> { let args: Vec = env::args().collect(); - if args.len() < 2 { - eprintln!("Usage: {} ", args[0]); + let prog = args.first().map_or("types", String::as_str); + let Some(path_arg) = args.get(1) else { + eprintln!("Usage: {prog} "); eprintln!(); eprintln!("This example explores the .NET type system in detail:"); eprintln!(" • Type categorization and analysis"); @@ -32,9 +33,9 @@ fn main() -> Result<()> { eprintln!(" • Inheritance hierarchy mapping"); eprintln!(" • Interface implementation tracking"); return Ok(()); - } + }; - let path = Path::new(&args[1]); + let path = Path::new(path_arg); println!("🏗️ Type system analysis of: {}", path.display()); let assembly = CilObject::from_path(path)?; @@ -67,65 +68,82 @@ fn print_type_registry_analysis(assembly: &CilObject) { // Analyze type definitions for type_def in &types.all_types() { - type_categories.total_types += 1; + type_categories.total_types = type_categories.total_types.saturating_add(1); // Categorize by visibility match type_def.flags.bits() & 0x07 { // TypeAttributes.VisibilityMask - 0 => type_categories.not_public += 1, // NotPublic - 1 => type_categories.public += 1, // Public - 2 => type_categories.nested_public += 1, // NestedPublic - 3 => type_categories.nested_private += 1, // NestedPrivate - 4 => type_categories.nested_family += 1, // NestedFamily - 5 => type_categories.nested_assembly += 1, // NestedAssembly - 6 => type_categories.nested_fam_and_assem += 1, // NestedFamANDAssem - 7 => type_categories.nested_fam_or_assem += 1, // NestedFamORAssem + 0 => type_categories.not_public = type_categories.not_public.saturating_add(1), // NotPublic + 1 => type_categories.public = type_categories.public.saturating_add(1), // Public + 2 => type_categories.nested_public = type_categories.nested_public.saturating_add(1), // NestedPublic + 3 => type_categories.nested_private = type_categories.nested_private.saturating_add(1), // NestedPrivate + 4 => type_categories.nested_family = type_categories.nested_family.saturating_add(1), // NestedFamily + 5 => { + type_categories.nested_assembly = type_categories.nested_assembly.saturating_add(1) + } // NestedAssembly + 6 => { + type_categories.nested_fam_and_assem = + type_categories.nested_fam_and_assem.saturating_add(1) + } // NestedFamANDAssem + 7 => { + type_categories.nested_fam_or_assem = + type_categories.nested_fam_or_assem.saturating_add(1) + } // NestedFamORAssem _ => {} } // Categorize by layout match type_def.flags.bits() & 0x18 { // TypeAttributes.LayoutMask - 0x00 => type_categories.auto_layout += 1, // AutoLayout - 0x08 => type_categories.sequential_layout += 1, // SequentialLayout - 0x10 => type_categories.explicit_layout += 1, // ExplicitLayout + 0x00 => type_categories.auto_layout = type_categories.auto_layout.saturating_add(1), // AutoLayout + 0x08 => { + type_categories.sequential_layout = + type_categories.sequential_layout.saturating_add(1) + } // SequentialLayout + 0x10 => { + type_categories.explicit_layout = type_categories.explicit_layout.saturating_add(1) + } // ExplicitLayout _ => {} } // Categorize by semantics match type_def.flags.bits() & 0x20 { // TypeAttributes.ClassSemanticsMask - 0x00 => type_categories.class_types += 1, // Class - 0x20 => type_categories.interface_types += 1, // Interface + 0x00 => type_categories.class_types = type_categories.class_types.saturating_add(1), // Class + 0x20 => { + type_categories.interface_types = type_categories.interface_types.saturating_add(1) + } // Interface _ => {} } // Check for special types if type_def.flags.bits() & 0x80 != 0 { // Abstract - type_categories.abstract_types += 1; + type_categories.abstract_types = type_categories.abstract_types.saturating_add(1); } if type_def.flags.bits() & 0x100 != 0 { // Sealed - type_categories.sealed_types += 1; + type_categories.sealed_types = type_categories.sealed_types.saturating_add(1); } if type_def.flags.bits() & 0x400 != 0 { // Serializable - type_categories.serializable_types += 1; + type_categories.serializable_types = + type_categories.serializable_types.saturating_add(1); } // Analyze by naming patterns if type_def.name.ends_with("Attribute") { - type_categories.attribute_types += 1; + type_categories.attribute_types = type_categories.attribute_types.saturating_add(1); } else if type_def.name.ends_with("Exception") { - type_categories.exception_types += 1; + type_categories.exception_types = type_categories.exception_types.saturating_add(1); } else if type_def.name.ends_with("EventArgs") { - type_categories.event_arg_types += 1; + type_categories.event_arg_types = type_categories.event_arg_types.saturating_add(1); } else if type_def.name.starts_with('I') && type_def.name.len() > 1 - && type_def.name.chars().nth(1).unwrap().is_uppercase() + && type_def.name.chars().nth(1).is_some_and(char::is_uppercase) { - type_categories.interface_named_types += 1; + type_categories.interface_named_types = + type_categories.interface_named_types.saturating_add(1); } } @@ -172,12 +190,19 @@ fn print_generic_analysis(assembly: &CilObject) { for type_def in types.all_types().iter().take(100) { // Look for generic type indicators if type_def.name.contains('`') { - generic_stats.generic_types += 1; + generic_stats.generic_types = generic_stats.generic_types.saturating_add(1); // Extract generic parameter count if let Some(backtick_pos) = type_def.name.rfind('`') { - if let Ok(param_count) = type_def.name[backtick_pos + 1..].parse::() { - generic_stats.total_type_parameters += param_count; + if let Ok(param_count) = type_def + .name + .get(backtick_pos.saturating_add(1)..) + .unwrap_or("") + .parse::() + { + generic_stats.total_type_parameters = generic_stats + .total_type_parameters + .saturating_add(param_count); if param_count > generic_stats.max_type_parameters { generic_stats.max_type_parameters = param_count; generic_stats.most_generic_type = type_def.name.clone(); @@ -198,7 +223,8 @@ fn print_generic_analysis(assembly: &CilObject) { // Simple heuristic: methods with generic naming patterns if method.name.contains('<') || method.name.contains("Generic") { - generic_stats.potentially_generic_methods += 1; + generic_stats.potentially_generic_methods = + generic_stats.potentially_generic_methods.saturating_add(1); } } @@ -230,16 +256,18 @@ fn print_inheritance_analysis(assembly: &CilObject) { let mut base_class_counts: HashMap = HashMap::new(); for type_def in types.all_types().iter().take(50) { - inheritance_stats.total_types += 1; + inheritance_stats.total_types = inheritance_stats.total_types.saturating_add(1); // Check if type has a base class (extends something) if let Some(base_type) = type_def.base() { - inheritance_stats.types_with_base_class += 1; + inheritance_stats.types_with_base_class = + inheritance_stats.types_with_base_class.saturating_add(1); let base_class_name = format!("{}:{}", base_type.namespace, base_type.name); - *base_class_counts.entry(base_class_name).or_insert(0) += 1; + let entry = base_class_counts.entry(base_class_name).or_insert(0u32); + *entry = entry.saturating_add(1); } else { - inheritance_stats.root_types += 1; + inheritance_stats.root_types = inheritance_stats.root_types.saturating_add(1); } // Check for interface implementations @@ -247,7 +275,9 @@ fn print_inheritance_analysis(assembly: &CilObject) { if type_def.flags.bits() & 0x20 == 0 { // Not an interface itself // This is a placeholder - real implementation would check InterfaceImpl table - inheritance_stats.types_implementing_interfaces += 1; + inheritance_stats.types_implementing_interfaces = inheritance_stats + .types_implementing_interfaces + .saturating_add(1); } } @@ -284,19 +314,21 @@ fn print_interface_analysis(assembly: &CilObject) { for type_def in types.all_types().iter().take(100) { if type_def.flags.bits() & 0x20 != 0 { // Interface flag - interface_stats.interface_count += 1; + interface_stats.interface_count = interface_stats.interface_count.saturating_add(1); interface_names.push(format!("{}.{}", type_def.namespace, type_def.name)); // Analyze interface naming patterns if type_def.name.starts_with('I') && type_def.name.len() > 1 - && type_def.name.chars().nth(1).unwrap().is_uppercase() + && type_def.name.chars().nth(1).is_some_and(char::is_uppercase) { - interface_stats.conventionally_named += 1; + interface_stats.conventionally_named = + interface_stats.conventionally_named.saturating_add(1); } if type_def.namespace.starts_with("System") { - interface_stats.system_interfaces += 1; + interface_stats.system_interfaces = + interface_stats.system_interfaces.saturating_add(1); } } } @@ -331,30 +363,37 @@ fn print_signature_analysis(assembly: &CilObject) { for entry in methods.iter().take(50) { let method = entry.value(); - signature_stats.methods_analyzed += 1; + signature_stats.methods_analyzed = signature_stats.methods_analyzed.saturating_add(1); // Analyze method naming patterns for signature complexity - let param_count = - method.name.matches(',').count() + if method.name.contains('(') { 1 } else { 0 }; + let param_count = method + .name + .matches(',') + .count() + .saturating_add(usize::from(method.name.contains('('))); if param_count > signature_stats.max_parameters { signature_stats.max_parameters = param_count; signature_stats.most_complex_method = method.name.clone(); } - signature_stats.total_parameters += param_count; + signature_stats.total_parameters = + signature_stats.total_parameters.saturating_add(param_count); // Check for special method types if method.name.starts_with("get_") || method.name.starts_with("set_") { - signature_stats.property_accessors += 1; + signature_stats.property_accessors = + signature_stats.property_accessors.saturating_add(1); } else if method.name.starts_with("add_") || method.name.starts_with("remove_") { - signature_stats.event_accessors += 1; + signature_stats.event_accessors = signature_stats.event_accessors.saturating_add(1); } else if method.name.starts_with("op_") { - signature_stats.operator_overloads += 1; + signature_stats.operator_overloads = + signature_stats.operator_overloads.saturating_add(1); } else if method.name == ".ctor" { - signature_stats.constructors += 1; + signature_stats.constructors = signature_stats.constructors.saturating_add(1); } else if method.name == ".cctor" { - signature_stats.static_constructors += 1; + signature_stats.static_constructors = + signature_stats.static_constructors.saturating_add(1); } } diff --git a/dotscope/src/analysis/algebraic.rs b/dotscope/src/analysis/algebraic.rs deleted file mode 100644 index f4140cca..00000000 --- a/dotscope/src/analysis/algebraic.rs +++ /dev/null @@ -1,458 +0,0 @@ -//! Algebraic identity simplification. -//! -//! This module provides shared logic for detecting algebraic identities -//! in SSA operations. It checks for patterns like: -//! -//! - `x xor x = 0` (self-cancelling) -//! - `x xor 0 = x` (identity element) -//! - `x * 0 = 0` (absorbing element) -//! - `x * 1 = x` (identity element) -//! - etc. -//! -//! # Usage -//! -//! ```rust,ignore -//! use dotscope::analysis::{simplify_op, SimplifyResult}; -//! -//! let constants = ssa.find_constants(); -//! match simplify_op(op, &constants) { -//! SimplifyResult::Constant(value) => { /* replace with constant */ } -//! SimplifyResult::Copy(var) => { /* replace with copy of var */ } -//! SimplifyResult::None => { /* no simplification */ } -//! } -//! ``` - -use std::collections::BTreeMap; - -use crate::analysis::{ConstValue, SsaOp, SsaVarId}; - -/// Result of checking an operation for algebraic simplification. -#[derive(Debug, Clone, PartialEq)] -pub enum SimplifyResult { - /// The operation simplifies to a constant value. - Constant(ConstValue), - /// The operation simplifies to copying another variable. - Copy(SsaVarId), - /// No simplification possible. - None, -} - -impl SimplifyResult { - /// Returns true if a simplification is possible. - #[must_use] - pub fn is_some(&self) -> bool { - !matches!(self, Self::None) - } - - /// Returns true if no simplification is possible. - #[must_use] - pub fn is_none(&self) -> bool { - matches!(self, Self::None) - } -} - -/// Check if an SSA operation can be algebraically simplified. -/// -/// This function checks for common algebraic identities that allow -/// an operation to be replaced with a simpler form (constant or copy). -#[must_use] -pub fn simplify_op(op: &SsaOp, constants: &BTreeMap) -> SimplifyResult { - match op { - // XOR: x ^ x = 0, x ^ 0 = x - SsaOp::Xor { left, right, .. } => { - if left == right { - return SimplifyResult::Constant(ConstValue::I32(0)); - } - if constants.get(right).is_some_and(ConstValue::is_zero) { - return SimplifyResult::Copy(*left); - } - if constants.get(left).is_some_and(ConstValue::is_zero) { - return SimplifyResult::Copy(*right); - } - SimplifyResult::None - } - - // OR: x | x = x, x | 0 = x, x | -1 = -1 - SsaOp::Or { left, right, .. } => { - if left == right { - return SimplifyResult::Copy(*left); - } - if let Some(c) = constants.get(right) { - if c.is_zero() { - return SimplifyResult::Copy(*left); - } - if c.is_all_ones() { - return SimplifyResult::Constant(c.clone()); - } - } - if let Some(c) = constants.get(left) { - if c.is_zero() { - return SimplifyResult::Copy(*right); - } - if c.is_all_ones() { - return SimplifyResult::Constant(c.clone()); - } - } - SimplifyResult::None - } - - // AND: x & x = x, x & 0 = 0, x & -1 = x - SsaOp::And { left, right, .. } => { - if left == right { - return SimplifyResult::Copy(*left); - } - if let Some(c) = constants.get(right) { - if c.is_zero() { - return SimplifyResult::Constant(c.zero_of_same_type()); - } - if c.is_all_ones() { - return SimplifyResult::Copy(*left); - } - } - if let Some(c) = constants.get(left) { - if c.is_zero() { - return SimplifyResult::Constant(c.zero_of_same_type()); - } - if c.is_all_ones() { - return SimplifyResult::Copy(*right); - } - } - SimplifyResult::None - } - - // ADD: x + 0 = x - SsaOp::Add { left, right, .. } => { - if constants.get(right).is_some_and(ConstValue::is_zero) { - return SimplifyResult::Copy(*left); - } - if constants.get(left).is_some_and(ConstValue::is_zero) { - return SimplifyResult::Copy(*right); - } - SimplifyResult::None - } - - // SUB: x - 0 = x, x - x = 0 - SsaOp::Sub { left, right, .. } => { - if left == right { - return SimplifyResult::Constant(ConstValue::I32(0)); - } - if constants.get(right).is_some_and(ConstValue::is_zero) { - return SimplifyResult::Copy(*left); - } - SimplifyResult::None - } - - // MUL: x * 0 = 0, x * 1 = x - SsaOp::Mul { left, right, .. } => { - if let Some(c) = constants.get(right) { - if c.is_zero() { - return SimplifyResult::Constant(c.clone()); - } - if c.is_one() { - return SimplifyResult::Copy(*left); - } - } - if let Some(c) = constants.get(left) { - if c.is_zero() { - return SimplifyResult::Constant(c.clone()); - } - if c.is_one() { - return SimplifyResult::Copy(*right); - } - } - SimplifyResult::None - } - - // DIV: x / 1 = x, 0 / x = 0 - SsaOp::Div { left, right, .. } => { - if constants.get(right).is_some_and(ConstValue::is_one) { - return SimplifyResult::Copy(*left); - } - if let Some(c) = constants.get(left) { - if c.is_zero() { - return SimplifyResult::Constant(c.clone()); - } - } - SimplifyResult::None - } - - // REM: 0 % x = 0, x % 1 = 0 - SsaOp::Rem { left, right, .. } => { - if let Some(c) = constants.get(left) { - if c.is_zero() { - return SimplifyResult::Constant(c.clone()); - } - } - if let Some(c) = constants.get(right) { - if c.is_one() { - return SimplifyResult::Constant(c.zero_of_same_type()); - } - } - SimplifyResult::None - } - - // SHL/SHR: x << 0 = x, x >> 0 = x - SsaOp::Shl { value, amount, .. } | SsaOp::Shr { value, amount, .. } => { - if constants.get(amount).is_some_and(ConstValue::is_zero) { - return SimplifyResult::Copy(*value); - } - SimplifyResult::None - } - - // Comparisons: x == x → true, x < x → false, x > x → false - SsaOp::Ceq { left, right, .. } => { - if left == right { - return SimplifyResult::Constant(ConstValue::I32(1)); - } - SimplifyResult::None - } - - SsaOp::Clt { left, right, .. } | SsaOp::Cgt { left, right, .. } => { - if left == right { - return SimplifyResult::Constant(ConstValue::I32(0)); - } - SimplifyResult::None - } - - _ => SimplifyResult::None, - } -} - -#[cfg(test)] -mod tests { - use std::collections::BTreeMap; - - use super::*; - - fn make_constants(pairs: &[(SsaVarId, ConstValue)]) -> BTreeMap { - pairs.iter().cloned().collect() - } - - #[test] - fn test_xor_self_cancels() { - let v1 = SsaVarId::from_index(0); - let dest = SsaVarId::from_index(1); - let op = SsaOp::Xor { - dest, - left: v1, - right: v1, - }; - assert_eq!( - simplify_op(&op, &BTreeMap::new()), - SimplifyResult::Constant(ConstValue::I32(0)) - ); - } - - #[test] - fn test_xor_zero_identity() { - let v1 = SsaVarId::from_index(0); - let v2 = SsaVarId::from_index(1); - let dest = SsaVarId::from_index(2); - let op = SsaOp::Xor { - dest, - left: v1, - right: v2, - }; - let constants = make_constants(&[(v2, ConstValue::I32(0))]); - assert_eq!(simplify_op(&op, &constants), SimplifyResult::Copy(v1)); - } - - #[test] - fn test_mul_zero_absorbs() { - let v1 = SsaVarId::from_index(0); - let v2 = SsaVarId::from_index(1); - let dest = SsaVarId::from_index(2); - let op = SsaOp::Mul { - dest, - left: v1, - right: v2, - }; - let constants = make_constants(&[(v2, ConstValue::I32(0))]); - assert_eq!( - simplify_op(&op, &constants), - SimplifyResult::Constant(ConstValue::I32(0)) - ); - } - - #[test] - fn test_mul_one_identity() { - let v1 = SsaVarId::from_index(0); - let v2 = SsaVarId::from_index(1); - let dest = SsaVarId::from_index(2); - let op = SsaOp::Mul { - dest, - left: v1, - right: v2, - }; - let constants = make_constants(&[(v2, ConstValue::I32(1))]); - assert_eq!(simplify_op(&op, &constants), SimplifyResult::Copy(v1)); - } - - #[test] - fn test_add_zero_identity() { - let v1 = SsaVarId::from_index(0); - let v2 = SsaVarId::from_index(1); - let dest = SsaVarId::from_index(2); - let op = SsaOp::Add { - dest, - left: v1, - right: v2, - }; - let constants = make_constants(&[(v2, ConstValue::I32(0))]); - assert_eq!(simplify_op(&op, &constants), SimplifyResult::Copy(v1)); - } - - #[test] - fn test_sub_self_cancels() { - let v1 = SsaVarId::from_index(0); - let dest = SsaVarId::from_index(1); - let op = SsaOp::Sub { - dest, - left: v1, - right: v1, - }; - assert_eq!( - simplify_op(&op, &BTreeMap::new()), - SimplifyResult::Constant(ConstValue::I32(0)) - ); - } - - #[test] - fn test_and_zero_absorbs() { - let v1 = SsaVarId::from_index(0); - let v2 = SsaVarId::from_index(1); - let dest = SsaVarId::from_index(2); - let op = SsaOp::And { - dest, - left: v1, - right: v2, - }; - let constants = make_constants(&[(v2, ConstValue::I32(0))]); - assert_eq!( - simplify_op(&op, &constants), - SimplifyResult::Constant(ConstValue::I32(0)) - ); - } - - #[test] - fn test_or_zero_identity() { - let v1 = SsaVarId::from_index(0); - let v2 = SsaVarId::from_index(1); - let dest = SsaVarId::from_index(2); - let op = SsaOp::Or { - dest, - left: v1, - right: v2, - }; - let constants = make_constants(&[(v2, ConstValue::I32(0))]); - assert_eq!(simplify_op(&op, &constants), SimplifyResult::Copy(v1)); - } - - #[test] - fn test_div_one_identity() { - let v1 = SsaVarId::from_index(0); - let v2 = SsaVarId::from_index(1); - let dest = SsaVarId::from_index(2); - let op = SsaOp::Div { - dest, - left: v1, - right: v2, - unsigned: false, - }; - let constants = make_constants(&[(v2, ConstValue::I32(1))]); - assert_eq!(simplify_op(&op, &constants), SimplifyResult::Copy(v1)); - } - - #[test] - fn test_shl_zero_identity() { - let v1 = SsaVarId::from_index(0); - let v2 = SsaVarId::from_index(1); - let dest = SsaVarId::from_index(2); - let op = SsaOp::Shl { - dest, - value: v1, - amount: v2, - }; - let constants = make_constants(&[(v2, ConstValue::I32(0))]); - assert_eq!(simplify_op(&op, &constants), SimplifyResult::Copy(v1)); - } - - #[test] - fn test_no_simplification() { - let v1 = SsaVarId::from_index(0); - let v2 = SsaVarId::from_index(1); - let dest = SsaVarId::from_index(2); - let op = SsaOp::Add { - dest, - left: v1, - right: v2, - }; - // No constants - no simplification - assert_eq!(simplify_op(&op, &BTreeMap::new()), SimplifyResult::None); - } - - #[test] - fn test_rem_one_zero() { - let v1 = SsaVarId::from_index(0); - let v2 = SsaVarId::from_index(1); - let dest = SsaVarId::from_index(2); - let op = SsaOp::Rem { - dest, - left: v1, - right: v2, - unsigned: false, - }; - let constants = make_constants(&[(v2, ConstValue::I32(1))]); - assert_eq!( - simplify_op(&op, &constants), - SimplifyResult::Constant(ConstValue::I32(0)) - ); - } - - #[test] - fn test_ceq_self_true() { - let v1 = SsaVarId::from_index(0); - let dest = SsaVarId::from_index(1); - let op = SsaOp::Ceq { - dest, - left: v1, - right: v1, - }; - assert_eq!( - simplify_op(&op, &BTreeMap::new()), - SimplifyResult::Constant(ConstValue::I32(1)) - ); - } - - #[test] - fn test_clt_self_false() { - let v1 = SsaVarId::from_index(0); - let dest = SsaVarId::from_index(1); - let op = SsaOp::Clt { - dest, - left: v1, - right: v1, - unsigned: false, - }; - assert_eq!( - simplify_op(&op, &BTreeMap::new()), - SimplifyResult::Constant(ConstValue::I32(0)) - ); - } - - #[test] - fn test_cgt_self_false() { - let v1 = SsaVarId::from_index(0); - let dest = SsaVarId::from_index(1); - let op = SsaOp::Cgt { - dest, - left: v1, - right: v1, - unsigned: false, - }; - assert_eq!( - simplify_op(&op, &BTreeMap::new()), - SimplifyResult::Constant(ConstValue::I32(0)) - ); - } -} diff --git a/dotscope/src/analysis/callgraph/graph.rs b/dotscope/src/analysis/callgraph/graph.rs index 527cd0fc..03d0b900 100644 --- a/dotscope/src/analysis/callgraph/graph.rs +++ b/dotscope/src/analysis/callgraph/graph.rs @@ -21,15 +21,13 @@ use crate::{ token::Token, typesystem::{CilTypeReference, TypeRegistry}, }, - utils::{ - escape_dot, - graph::{ - algorithms::{self, strongly_connected_components}, - DirectedGraph, NodeId, - }, - }, + utils::escape_dot, CilObject, Result, }; +use analyssa::graph::{ + algorithms::{self, strongly_connected_components}, + DirectedGraph, NodeId, +}; /// Inter-procedural call graph for a .NET assembly. /// diff --git a/dotscope/src/analysis/cfg/analyzer.rs b/dotscope/src/analysis/cfg/analyzer.rs deleted file mode 100644 index 7ab69534..00000000 --- a/dotscope/src/analysis/cfg/analyzer.rs +++ /dev/null @@ -1,330 +0,0 @@ -//! Loop analyzer for computing comprehensive loop information from SSA. -//! -//! This module provides the [`LoopAnalyzer`] which computes full [`LoopInfo`] -//! structures from an SSA function, including preheaders, latches, exits, -//! and loop type classification. - -use crate::{ - analysis::{ - cfg::{detect_loops, LoopForest}, - SsaCfg, SsaFunction, - }, - utils::graph::{algorithms, RootedGraph}, -}; - -/// Analyzes loops in an SSA function. -/// -/// The analyzer computes: -/// - Natural loops using dominance-based back edge detection -/// - Preheader identification for each loop -/// - Latch (back edge source) identification -/// - Exit edge detection -/// - Loop type classification -/// - Loop nesting relationships -/// -/// This is a thin wrapper around the generic `detect_loops` function, -/// providing a convenient SSA-specific interface. -pub struct LoopAnalyzer<'a> { - cfg: SsaCfg<'a>, -} - -impl<'a> LoopAnalyzer<'a> { - /// Creates a new loop analyzer for the given SSA function. - #[must_use] - pub fn new(ssa: &'a SsaFunction) -> Self { - let cfg = SsaCfg::from_ssa(ssa); - Self { cfg } - } - - /// Analyzes all loops and returns a [`LoopForest`]. - /// - /// Uses the shared `detect_loops` function which implements dominance-based - /// back edge detection and computes preheaders, exits, loop types, and nesting. - #[must_use] - pub fn analyze(&self) -> LoopForest { - let dominators = algorithms::compute_dominators(&self.cfg, self.cfg.entry()); - detect_loops(&self.cfg, &dominators) - } -} - -/// Extension trait for SSA functions to easily access loop analysis. -pub trait SsaLoopAnalysis { - /// Analyzes loops in this function. - fn analyze_loops(&self) -> LoopForest; -} - -impl SsaLoopAnalysis for SsaFunction { - fn analyze_loops(&self) -> LoopForest { - LoopAnalyzer::new(self).analyze() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use crate::{ - analysis::{cfg::loops::LoopType, SsaFunctionBuilder, SsaVarId}, - utils::graph::NodeId, - }; - - #[test] - fn test_find_condition_in_body() { - // Create a simple loop with a condition inside: - // B0 (entry) -> B1 (header) - // B1: jump to B2 - // B2 (condition): branch cond, B3, B4 - // B3 (body): jump to B1 (back edge) - // B4 (exit): ret - - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - // B0: entry, jump to header - f.block(0, |b| b.jump(1)); - // B1: header (dispatcher-like), jump to condition - f.block(1, |b| b.jump(2)); - // B2: condition block with branch - f.block(2, |b| { - let cond = b.const_true(); - b.branch(cond, 3, 4); - }); - // B3: body, jump back to header (back edge to B1) - f.block(3, |b| b.jump(1)); - // B4: exit - f.block(4, |b| b.ret()); - }) - .unwrap(); - - let forest = ssa.analyze_loops(); - - assert_eq!(forest.len(), 1, "Should have one loop"); - - let loop_info = &forest.loops()[0]; - assert_eq!(loop_info.header, NodeId::new(1), "Header should be B1"); - - // The condition block should be B2 (the one with Branch) - let condition = loop_info.find_condition_in_body(&ssa); - assert_eq!( - condition, - Some(NodeId::new(2)), - "Condition block should be B2" - ); - } - - #[test] - fn test_find_all_conditions_in_body() { - // Create a loop with multiple conditional branches - // B0 -> B1 (header) - // B1: jump to B2 - // B2: branch cond1, B3, B4 - // B3: branch cond2, B5, B1 (early exit or continue) - // B4: branch cond3, B1, B5 (back edge or exit) - // B5: ret - - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| b.jump(1)); - f.block(1, |b| b.jump(2)); - f.block(2, |b| { - let cond = b.const_true(); - b.branch(cond, 3, 4); - }); - f.block(3, |b| { - let cond = b.const_true(); - b.branch(cond, 5, 1); - }); - f.block(4, |b| { - let cond = b.const_true(); - b.branch(cond, 1, 5); - }); - f.block(5, |b| b.ret()); - }) - .unwrap(); - - let forest = ssa.analyze_loops(); - assert!(!forest.is_empty(), "Should have at least one loop"); - - let loop_info = &forest.loops()[0]; - let conditions = loop_info.find_all_conditions_in_body(&ssa); - - // Should find multiple condition blocks in the loop body - assert!( - !conditions.is_empty(), - "Should find at least one condition block" - ); - } - - #[test] - fn test_simple_while_loop() { - // Create a simple while loop: - // B0 (entry) -> B1 (header) - // B1: branch cond, B2, B3 - // B2 (body) -> B1 (back edge) - // B3 (exit) - - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - // B0: entry, jump to header - f.block(0, |b| b.jump(1)); - // B1: header with conditional branch - f.block(1, |b| { - let cond = b.const_true(); - b.branch(cond, 2, 3); - }); - // B2: body, jump back to header - f.block(2, |b| b.jump(1)); - // B3: exit - f.block(3, |b| b.ret()); - }) - .unwrap(); - - let forest = ssa.analyze_loops(); - - assert_eq!(forest.len(), 1); - - let loop_info = &forest.loops()[0]; - assert_eq!(loop_info.header, NodeId::new(1)); - assert!(loop_info.contains(NodeId::new(1))); - assert!(loop_info.contains(NodeId::new(2))); - assert!(!loop_info.contains(NodeId::new(0))); - assert!(!loop_info.contains(NodeId::new(3))); - - // Should have single latch - assert!(loop_info.has_single_latch()); - assert_eq!(loop_info.single_latch(), Some(NodeId::new(2))); - - // Should have preheader (B0) - assert!(loop_info.has_preheader()); - assert_eq!(loop_info.preheader, Some(NodeId::new(0))); - - // Should be pre-tested (exit from header) - assert_eq!(loop_info.loop_type, LoopType::PreTested); - assert!(loop_info.is_canonical()); - } - - #[test] - fn test_do_while_loop() { - // Create a do-while loop: - // B0 (entry) -> B1 (header/body) - // B1: branch cond, B1, B2 (back edge is latch with exit) - // B2 (exit) - - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - // B0: entry, jump to header - f.block(0, |b| b.jump(1)); - // B1: header/body with conditional back edge - f.block(1, |b| { - let cond = b.const_true(); - b.branch(cond, 1, 2); // back edge to 1, exit to 2 - }); - // B2: exit - f.block(2, |b| b.ret()); - }) - .unwrap(); - - let forest = ssa.analyze_loops(); - - assert_eq!(forest.len(), 1); - - let loop_info = &forest.loops()[0]; - assert_eq!(loop_info.header, NodeId::new(1)); - - // Latch is the header itself (self-loop) - assert!(loop_info.has_single_latch()); - assert_eq!(loop_info.single_latch(), Some(NodeId::new(1))); - - // Exit is from latch, so this is post-tested - assert_eq!(loop_info.loop_type, LoopType::PostTested); - } - - #[test] - fn test_nested_loops() { - // Create nested loops: - // B0 -> B1 (outer header) - // B1 -> B2 (inner header) - // B2 -> B2 (inner back edge), B3 - // B3 -> B1 (outer back edge), B4 - // B4 (exit) - - let ssa = { - let mut cond_out = SsaVarId::from_index(0); - SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - // B0: entry - f.block(0, |b| b.jump(1)); - // B1: outer header - f.block(1, |b| b.jump(2)); - // B2: inner header with self-loop - f.block(2, |b| { - let c = b.const_true(); - cond_out = c; - b.branch(c, 2, 3); // inner back edge to 2, exit to 3 - }); - // B3: between inner and outer, branches back to outer header or exits - f.block(3, |b| b.branch(cond_out, 1, 4)); // outer back edge to 1, exit to 4 - // B4: exit - f.block(4, |b| b.ret()); - }) - .unwrap() - }; - - let forest = ssa.analyze_loops(); - - assert_eq!(forest.len(), 2); - - // Find inner and outer loops - let inner = forest.loop_for_header(NodeId::new(2)).unwrap(); - let outer = forest.loop_for_header(NodeId::new(1)).unwrap(); - - // Inner loop should be nested in outer - assert_eq!(inner.parent, Some(NodeId::new(1))); - assert!(outer.children.contains(&NodeId::new(2))); - - // Depths - assert_eq!(outer.depth, 0); - assert_eq!(inner.depth, 1); - - // Block 2 should be in inner loop - assert_eq!(forest.loop_depth(NodeId::new(2)), 2); - } - - #[test] - fn test_induction_variable_api() { - // Test that the induction variable detection API works correctly. - // We use an existing loop structure and verify the method can be called. - - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - // B0: entry, jump to header - f.block(0, |b| b.jump(1)); - // B1: header with conditional branch - f.block(1, |b| { - let cond = b.const_true(); - b.branch(cond, 2, 3); - }); - // B2: body, jump back to header - f.block(2, |b| b.jump(1)); - // B3: exit - f.block(3, |b| b.ret()); - }) - .unwrap(); - - let forest = ssa.analyze_loops(); - - assert_eq!(forest.len(), 1, "Should have one loop"); - - let loop_info = &forest.loops()[0]; - - // Call the induction variable detection method - // (may return empty since our simple test loop has no phi nodes) - let induction_vars = loop_info.find_induction_vars(&ssa); - - // This simple loop has no phi nodes at the header, so no induction vars - // The test verifies the API works without panicking - assert!( - induction_vars.is_empty(), - "Simple loop without phi should have no induction vars" - ); - } -} diff --git a/dotscope/src/analysis/cfg/graph.rs b/dotscope/src/analysis/cfg/graph.rs index cc6ad6f3..61f5ea8e 100644 --- a/dotscope/src/analysis/cfg/graph.rs +++ b/dotscope/src/analysis/cfg/graph.rs @@ -5,17 +5,18 @@ use std::{fmt::Write, sync::OnceLock}; +use analyssa::{ + graph::{ + algorithms::{self, DominatorTree}, + DirectedGraph, EdgeId, GraphBase, NodeId, Predecessors, RootedGraph, Successors, + }, + BitSet, +}; + use crate::{ analysis::cfg::{detect_loops, CfgEdge, CfgEdgeKind, LoopForest, LoopInfo}, assembly::{BasicBlock, FlowType, Operand}, - utils::{ - escape_dot, - graph::{ - algorithms::{self, DominatorTree}, - DirectedGraph, EdgeId, GraphBase, NodeId, Predecessors, RootedGraph, Successors, - }, - BitSet, - }, + utils::escape_dot, Error::GraphError, Result, }; diff --git a/dotscope/src/analysis/cfg/loops.rs b/dotscope/src/analysis/cfg/loops.rs deleted file mode 100644 index 116d3fea..00000000 --- a/dotscope/src/analysis/cfg/loops.rs +++ /dev/null @@ -1,936 +0,0 @@ -//! Extended loop analysis infrastructure. -//! -//! This module provides comprehensive loop analysis beyond basic natural loop detection, -//! including preheader identification, latch detection, exit analysis, and loop -//! classification. -//! -//! # Loop Structure -//! -//! A well-formed loop has the following structure: -//! -//! ```text -//! [preheader] <- Single entry predecessor (optional, may need insertion) -//! | -//! v -//! [header] <------+ <- Single entry point, dominates all loop nodes -//! | | -//! v | -//! [body ...] | <- Loop body nodes -//! | | -//! v | -//! [latch] --------+ <- Back edge source(s) -//! | -//! v -//! [exit ...] <- Exit blocks (outside loop, have predecessor in loop) -//! ``` -//! -//! # Loop Types -//! -//! Loops are classified into: -//! - **Pre-tested** (while): Condition checked at header before body -//! - **Post-tested** (do-while): Condition checked at latch after body -//! - **Infinite**: No exit condition (or condition always true) -//! - **Complex**: Multiple back edges or irregular structure -//! -//! # Canonicalization -//! -//! Canonical loops have: -//! - Single preheader (non-loop predecessor to header) -//! - Single latch (single back edge to header) -//! - Dedicated exit blocks (exits have single predecessor) -//! -//! # Generic Loop Detection -//! -//! The [`detect_loops`] function can analyze any graph implementing the required -//! traits (`GraphBase`, `Successors`, `Predecessors`). This enables loop detection -//! on CIL CFGs, SSA CFGs, x86 CFGs, and any other graph structure. -//! -//! ```rust,ignore -//! use dotscope::analysis::cfg::detect_loops; -//! use dotscope::utils::graph::algorithms::compute_dominators; -//! -//! // Works with any graph implementing the traits -//! let dominators = compute_dominators(&graph, entry); -//! let forest = detect_loops(&graph, &dominators); -//! ``` - -use std::collections::HashMap; - -use crate::{ - analysis::{SsaFunction, SsaOp, SsaVarId}, - utils::{ - graph::{algorithms::DominatorTree, GraphBase, NodeId, Predecessors, Successors}, - BitSet, - }, -}; - -/// Classification of loop types based on structure. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum LoopType { - /// Pre-tested loop (while): exit condition at header. - /// ```text - /// while (cond) { body } - /// ``` - PreTested, - - /// Post-tested loop (do-while): exit condition at latch. - /// ```text - /// do { body } while (cond) - /// ``` - PostTested, - - /// Infinite loop: no exit edges from loop body. - /// ```text - /// while (true) { body } - /// ``` - Infinite, - - /// Complex loop: multiple latches, irregular exits, or irreducible. - Complex, -} - -/// Exit edge information for a loop. -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct LoopExit { - /// The block inside the loop that branches out. - pub exiting_block: NodeId, - /// The block outside the loop that is the exit target. - pub exit_block: NodeId, -} - -/// Classification of induction variable update operations. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum InductionUpdateKind { - /// `i = i + stride` (increment) - Add, - /// `i = i - stride` (decrement) - Sub, - /// `i = i * stride` (scaling) - Mul, - /// Unknown or complex update pattern - Unknown, -} - -/// Represents an induction variable in a loop. -/// -/// An induction variable is a variable whose value changes by a fixed amount -/// on each iteration of a loop. Classic examples include loop counters (`i++`). -/// -/// # Structure -/// -/// An induction variable has: -/// - A phi node at the loop header that merges the initial and updated values -/// - An initial value from outside the loop (preheader) -/// - An updated value computed inside the loop (typically in the latch) -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct InductionVar { - /// The phi node result variable at the loop header. - pub phi_result: SsaVarId, - /// The initial value (from preheader or outside loop). - pub init_value: SsaVarId, - /// The block providing the initial value. - pub init_block: NodeId, - /// The updated value (from inside loop, typically latch). - pub update_value: SsaVarId, - /// The block providing the updated value. - pub update_block: NodeId, - /// The type of update operation. - pub update_kind: InductionUpdateKind, - /// The stride (constant value added/subtracted per iteration), if known. - pub stride: Option, -} - -/// Comprehensive loop information. -/// -/// This extends `NaturalLoop` with additional structural information needed -/// for loop canonicalization and optimization. -#[derive(Debug, Clone)] -pub struct LoopInfo { - /// The header block (single entry point, dominates all loop nodes). - pub header: NodeId, - - /// All blocks in the loop body (including header). - pub body: BitSet, - - /// Back edge sources (blocks that jump to the header from within the loop). - pub latches: Vec, - - /// Preheader block if one exists (single non-loop predecessor of header). - /// `None` if header has multiple non-loop predecessors or none. - pub preheader: Option, - - /// Exit edges from the loop. - pub exits: Vec, - - /// Loop nesting depth (0 = outermost). - pub depth: usize, - - /// Classification of the loop type. - pub loop_type: LoopType, - - /// Parent loop header, if this loop is nested. - pub parent: Option, - - /// Immediate child loop headers. - pub children: Vec, -} - -impl LoopInfo { - /// Creates a new `LoopInfo` with the given header. - #[must_use] - pub fn new(header: NodeId, node_count: usize) -> Self { - let mut body = BitSet::new(node_count); - body.insert(header.index()); - Self { - header, - body, - latches: Vec::new(), - preheader: None, - exits: Vec::new(), - depth: 0, - loop_type: LoopType::Complex, - parent: None, - children: Vec::new(), - } - } - - /// Returns true if this loop contains the given block. - #[must_use] - pub fn contains(&self, node: NodeId) -> bool { - self.body.contains(node.index()) - } - - /// Returns the number of blocks in the loop. - #[must_use] - pub fn size(&self) -> usize { - self.body.count() - } - - /// Returns true if the loop has a single latch (canonical form). - #[must_use] - pub fn has_single_latch(&self) -> bool { - self.latches.len() == 1 - } - - /// Returns the single latch if there is exactly one. - #[must_use] - pub fn single_latch(&self) -> Option { - if self.latches.len() == 1 { - self.latches.first().copied() - } else { - None - } - } - - /// Returns true if the loop has a preheader (canonical form). - #[must_use] - pub fn has_preheader(&self) -> bool { - self.preheader.is_some() - } - - /// Returns true if the loop is in canonical form. - /// - /// A canonical loop has: - /// - A single preheader - /// - A single latch - #[must_use] - pub fn is_canonical(&self) -> bool { - self.has_preheader() && self.has_single_latch() - } - - /// Returns true if this is an innermost loop (no children). - #[must_use] - pub fn is_innermost(&self) -> bool { - self.children.is_empty() - } - - /// Returns true if this is an outermost loop (no parent). - #[must_use] - pub fn is_outermost(&self) -> bool { - self.parent.is_none() - } - - /// Returns all exit blocks (blocks outside loop reachable from inside). - pub fn exit_blocks(&self) -> impl Iterator + '_ { - self.exits.iter().map(|e| e.exit_block) - } - - /// Returns all exiting blocks (blocks inside loop that branch out). - pub fn exiting_blocks(&self) -> impl Iterator + '_ { - self.exits.iter().map(|e| e.exiting_block) - } - - /// Returns the number of exits from this loop. - #[must_use] - pub fn exit_count(&self) -> usize { - self.exits.len() - } - - /// Returns true if the header is also an exiting block. - /// - /// This indicates a pre-tested loop (condition at entry). - #[must_use] - pub fn header_is_exiting(&self) -> bool { - self.exits.iter().any(|e| e.exiting_block == self.header) - } - - /// Returns true if a latch is also an exiting block. - /// - /// This indicates a post-tested loop (condition at end). - #[must_use] - pub fn latch_is_exiting(&self) -> bool { - self.exits - .iter() - .any(|e| self.latches.contains(&e.exiting_block)) - } - - /// Finds the condition block inside the loop body. - /// - /// For control-flow flattened code (e.g., ConfuserEx), the actual loop - /// condition is often inside a case block rather than at the dispatcher - /// header. This method searches for blocks with `Branch` instructions - /// within the loop body. - /// - /// # Returns - /// - /// - `Some(NodeId)` - The first block found with a conditional branch - /// - `None` - No conditional branch found in the loop body - #[must_use] - pub fn find_condition_in_body(&self, ssa: &SsaFunction) -> Option { - for block_idx in self.body.iter() { - if let Some(block) = ssa.block(block_idx) { - if matches!(block.terminator_op(), Some(SsaOp::Branch { .. })) { - return Some(NodeId::new(block_idx)); - } - } - } - None - } - - /// Finds all conditional blocks within the loop body. - /// - /// Unlike `find_condition_in_body`, this returns all blocks with - /// conditional branches, useful for complex loops with multiple exit points. - #[must_use] - pub fn find_all_conditions_in_body(&self, ssa: &SsaFunction) -> Vec { - self.body - .iter() - .filter(|&block_idx| { - ssa.block(block_idx) - .is_some_and(|b| matches!(b.terminator_op(), Some(SsaOp::Branch { .. }))) - }) - .map(NodeId::new) - .collect() - } - - /// Identifies induction variables in this loop. - /// - /// An induction variable is identified by finding phi nodes at the loop - /// header where: - /// - One operand comes from outside the loop (initial value) - /// - One operand comes from inside the loop (updated value) - /// - /// The method attempts to classify the update kind (add, sub, etc.) by - /// analyzing the instruction that produces the update value. - /// - /// # Returns - /// - /// A vector of [`InductionVar`] structures describing each induction variable. - #[must_use] - pub fn find_induction_vars(&self, ssa: &SsaFunction) -> Vec { - let mut induction_vars = Vec::new(); - - // Get phi nodes at the header - let Some(header_block) = ssa.block(self.header.index()) else { - return induction_vars; - }; - - for phi in header_block.phi_nodes() { - let operands = phi.operands(); - - // Need at least 2 operands (init + update) - if operands.len() < 2 { - continue; - } - - // Find operands from inside vs outside the loop - let (inside_ops, outside_ops): (Vec<&_>, Vec<&_>) = operands - .iter() - .partition(|op| self.body.contains(op.predecessor())); - - // Classic induction variable: 1 init from outside, 1+ updates from inside - if outside_ops.len() == 1 && !inside_ops.is_empty() { - let (Some(init_op), Some(update_op)) = (outside_ops.first(), inside_ops.first()) - else { - continue; - }; - - // Try to determine update kind by analyzing the defining instruction - let (update_kind, stride) = - Self::analyze_update_instruction(ssa, update_op.value(), phi.result()); - - induction_vars.push(InductionVar { - phi_result: phi.result(), - init_value: init_op.value(), - init_block: NodeId::new(init_op.predecessor()), - update_value: update_op.value(), - update_block: NodeId::new(update_op.predecessor()), - update_kind, - stride, - }); - } - } - - induction_vars - } - - /// Analyzes an instruction to determine if it's an induction update. - /// - /// Looks for patterns like `v = phi_result + const` or `v = phi_result - const`. - fn analyze_update_instruction( - ssa: &SsaFunction, - update_var: SsaVarId, - phi_result: SsaVarId, - ) -> (InductionUpdateKind, Option) { - // Find the instruction that defines update_var - let Some(var) = ssa.variable(update_var) else { - return (InductionUpdateKind::Unknown, None); - }; - let def_site = var.def_site(); - - if def_site.is_phi() { - return (InductionUpdateKind::Unknown, None); - } - - let Some(block) = ssa.block(def_site.block) else { - return (InductionUpdateKind::Unknown, None); - }; - - let Some(instr_idx) = def_site.instruction else { - return (InductionUpdateKind::Unknown, None); - }; - - let Some(instr) = block.instruction(instr_idx) else { - return (InductionUpdateKind::Unknown, None); - }; - - // Check for Add/Sub patterns - match instr.op() { - // Check if one operand is the phi result - SsaOp::Add { left, right, .. } if *left == phi_result || *right == phi_result => { - let other = if *left == phi_result { *right } else { *left }; - let stride = ssa.try_constant_value(other).and_then(|v| v.as_i64()); - return (InductionUpdateKind::Add, stride); - } - // For subtraction, left should be phi_result - SsaOp::Sub { left, right, .. } if *left == phi_result => { - let stride = ssa.try_constant_value(*right).and_then(|v| v.as_i64()); - return (InductionUpdateKind::Sub, stride); - } - SsaOp::Mul { left, right, .. } if *left == phi_result || *right == phi_result => { - let other = if *left == phi_result { *right } else { *left }; - let stride = ssa.try_constant_value(other).and_then(|v| v.as_i64()); - return (InductionUpdateKind::Mul, stride); - } - _ => {} - } - - (InductionUpdateKind::Unknown, None) - } -} - -/// Loop forest containing all loops in a function. -/// -/// Provides efficient queries for loop membership, nesting, and iteration. -#[derive(Debug, Clone)] -pub struct LoopForest { - /// All loops indexed by their header block. - loops: Vec, - /// Map from block to the innermost loop containing it. - block_to_loop: Vec>, -} - -impl LoopForest { - /// Creates an empty loop forest. - #[must_use] - pub fn new(block_count: usize) -> Self { - Self { - loops: Vec::new(), - block_to_loop: vec![None; block_count], - } - } - - /// Adds a loop to the forest. - pub fn add_loop(&mut self, loop_info: LoopInfo) { - let loop_idx = self.loops.len(); - - // Update block-to-loop mapping for all blocks in this loop - for block_idx in loop_info.body.iter() { - let Some(slot) = self.block_to_loop.get_mut(block_idx) else { - continue; - }; - // Only update if this is a more deeply nested loop - match *slot { - Some(existing_idx) => { - if self - .loops - .get(existing_idx) - .is_some_and(|l| l.depth < loop_info.depth) - { - *slot = Some(loop_idx); - } - } - None => *slot = Some(loop_idx), - } - } - - self.loops.push(loop_info); - } - - /// Returns all loops in the forest. - #[must_use] - pub fn loops(&self) -> &[LoopInfo] { - &self.loops - } - - /// Returns the number of loops. - #[must_use] - pub fn len(&self) -> usize { - self.loops.len() - } - - /// Returns true if there are no loops. - #[must_use] - pub fn is_empty(&self) -> bool { - self.loops.is_empty() - } - - /// Returns the innermost loop containing the given block. - #[must_use] - pub fn innermost_loop(&self, block: NodeId) -> Option<&LoopInfo> { - let block_idx = block.index(); - let loop_idx = (*self.block_to_loop.get(block_idx)?)?; - self.loops.get(loop_idx) - } - - /// Returns the loop with the given header. - #[must_use] - pub fn loop_for_header(&self, header: NodeId) -> Option<&LoopInfo> { - self.loops.iter().find(|l| l.header == header) - } - - /// Returns the loop depth for a block (0 if not in any loop). - #[must_use] - pub fn loop_depth(&self, block: NodeId) -> usize { - self.innermost_loop(block) - .map_or(0, |l| l.depth.saturating_add(1)) - } - - /// Returns true if a block is in any loop. - #[must_use] - pub fn is_in_loop(&self, block: NodeId) -> bool { - self.innermost_loop(block).is_some() - } - - /// Iterates over all loops in the forest. - pub fn iter(&self) -> impl Iterator { - self.loops.iter() - } - - /// Returns loops sorted by depth (outermost first). - #[must_use] - pub fn by_depth_ascending(&self) -> Vec<&LoopInfo> { - let mut sorted: Vec<_> = self.loops.iter().collect(); - sorted.sort_by_key(|l| l.depth); - sorted - } - - /// Returns loops sorted by depth (innermost first). - #[must_use] - pub fn by_depth_descending(&self) -> Vec<&LoopInfo> { - let mut sorted: Vec<_> = self.loops.iter().collect(); - sorted.sort_by_key(|l| std::cmp::Reverse(l.depth)); - sorted - } -} - -/// Detects all natural loops in a graph using dominance-based back edge detection. -/// -/// This is the primary entry point for loop detection. It works with any graph -/// implementing the required traits, enabling loop analysis on various graph types -/// (CIL CFGs, SSA CFGs, x86 CFGs, etc.). -/// -/// # Algorithm -/// -/// The detection algorithm: -/// 1. Finds back edges using dominance (n -> h where h dominates n) -/// 2. For each back edge, computes the natural loop body -/// 3. Computes preheaders, exits, and loop types -/// 4. Establishes nesting relationships -/// -/// # Arguments -/// -/// * `graph` - Any graph implementing `GraphBase + Successors + Predecessors` -/// * `dominators` - Pre-computed dominator tree for the graph -/// -/// # Returns -/// -/// A [`LoopForest`] containing all detected loops with their full analysis. -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::analysis::cfg::detect_loops; -/// use dotscope::utils::graph::algorithms::compute_dominators; -/// -/// let dominators = compute_dominators(&graph, graph.entry()); -/// let forest = detect_loops(&graph, &dominators); -/// -/// for loop_info in forest.loops() { -/// println!("Loop at {:?} with {} blocks", loop_info.header, loop_info.size()); -/// } -/// ``` -#[must_use] -pub fn detect_loops(graph: &G, dominators: &DominatorTree) -> LoopForest -where - G: GraphBase + Successors + Predecessors, -{ - let block_count = graph.node_count(); - let mut forest = LoopForest::new(block_count); - - // Collect loops by header - let mut loops_by_header: HashMap = HashMap::new(); - - // Find all back edges: edge (n -> h) where h dominates n - for node in graph.node_ids() { - for succ in graph.successors(node) { - // Check if successor dominates current node (back edge) - if dominators.dominates(succ, node) { - // Found back edge: node -> succ (succ is loop header) - let header = succ; - - let loop_info = loops_by_header - .entry(header) - .or_insert_with(|| LoopInfo::new(header, block_count)); - - loop_info.latches.push(node); - expand_loop_body(graph, loop_info, node); - } - } - } - - // Compute additional loop information for each loop - for loop_info in loops_by_header.values_mut() { - compute_preheader(graph, loop_info); - compute_exits(graph, loop_info); - loop_info.loop_type = classify_loop(loop_info); - } - - // Convert to Vec and compute nesting relationships - let mut loops: Vec = loops_by_header.into_values().collect(); - compute_nesting(&mut loops); - - // Sort by header for deterministic ordering - loops.sort_by_key(|l| l.header.index()); - - // Add all loops to forest - for loop_info in loops { - forest.add_loop(loop_info); - } - - forest -} - -/// Checks if a graph has any back edges (loops). -/// -/// This is a fast check that returns as soon as the first back edge is found, -/// without building the full loop forest. Use this when you only need to know -/// whether loops exist, not their detailed structure. -/// -/// # Arguments -/// -/// * `graph` - Any graph implementing `GraphBase + Successors` -/// * `dominators` - Pre-computed dominator tree for the graph -/// -/// # Returns -/// -/// `true` if at least one back edge exists, `false` otherwise. -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::analysis::cfg::has_back_edges; -/// use dotscope::utils::graph::algorithms::compute_dominators; -/// -/// let dominators = compute_dominators(&graph, graph.entry()); -/// if has_back_edges(&graph, &dominators) { -/// println!("Graph contains loops"); -/// } -/// ``` -#[must_use] -pub fn has_back_edges(graph: &G, dominators: &DominatorTree) -> bool -where - G: GraphBase + Successors, -{ - for node in graph.node_ids() { - for succ in graph.successors(node) { - if dominators.dominates(succ, node) { - return true; - } - } - } - false -} - -/// Expands the loop body to include all nodes that can reach the latch. -/// -/// Uses a worklist algorithm: starting from the latch, we add -/// predecessors that aren't the header until we've found all loop body nodes. -fn expand_loop_body(graph: &G, loop_info: &mut LoopInfo, latch: NodeId) -where - G: Predecessors, -{ - if loop_info.body.contains(latch.index()) { - return; - } - - let mut worklist = vec![latch]; - - while let Some(node) = worklist.pop() { - if loop_info.body.insert(node.index()) { - // Node wasn't in body yet, add its predecessors - for pred in graph.predecessors(node) { - if pred != loop_info.header && !loop_info.body.contains(pred.index()) { - worklist.push(pred); - } - } - } - } -} - -/// Identifies the preheader for a loop. -/// -/// A preheader is a single predecessor of the header that is outside the loop. -/// If the header has multiple non-loop predecessors, there is no preheader. -fn compute_preheader(graph: &G, loop_info: &mut LoopInfo) -where - G: Predecessors, -{ - let mut non_loop_preds: Vec = Vec::new(); - - for pred in graph.predecessors(loop_info.header) { - if !loop_info.body.contains(pred.index()) { - non_loop_preds.push(pred); - } - } - - // Preheader exists only if there's exactly one non-loop predecessor - loop_info.preheader = if non_loop_preds.len() == 1 { - non_loop_preds.first().copied() - } else { - None - }; -} - -/// Computes exit edges for a loop. -/// -/// An exit edge goes from a block inside the loop to a block outside the loop. -fn compute_exits(graph: &G, loop_info: &mut LoopInfo) -where - G: Successors, -{ - loop_info.exits.clear(); - - for body_block_idx in loop_info.body.iter() { - let body_block = NodeId::new(body_block_idx); - for succ in graph.successors(body_block) { - if !loop_info.body.contains(succ.index()) { - loop_info.exits.push(LoopExit { - exiting_block: body_block, - exit_block: succ, - }); - } - } - } -} - -/// Classifies the loop type based on structure. -fn classify_loop(loop_info: &LoopInfo) -> LoopType { - // Check for infinite loop (no exits) - if loop_info.exits.is_empty() { - return LoopType::Infinite; - } - - // Check for multiple latches (complex) - if loop_info.latches.len() > 1 { - return LoopType::Complex; - } - - // Get the single latch - let latch = loop_info.single_latch(); - - // Check if all exits are from the latch (post-tested / do-while loop) - if let Some(latch) = latch { - let latch_exits = loop_info - .exits - .iter() - .filter(|e| e.exiting_block == latch) - .count(); - - if latch_exits == loop_info.exits.len() && latch_exits > 0 { - return LoopType::PostTested; - } - } - - // Check if header is the only exiting block (pre-tested / while loop) - let header_exits = loop_info - .exits - .iter() - .filter(|e| e.exiting_block == loop_info.header) - .count(); - - if header_exits == loop_info.exits.len() && header_exits > 0 { - return LoopType::PreTested; - } - - // Mixed or irregular exit structure - LoopType::Complex -} - -/// Computes loop nesting relationships and depths. -fn compute_nesting(loops: &mut [LoopInfo]) { - let n = loops.len(); - - // Build header-to-index mapping - let header_to_idx: HashMap = loops - .iter() - .enumerate() - .map(|(i, l)| (l.header, i)) - .collect(); - - // For each loop, find its parent (smallest enclosing loop) - for i in 0..n { - let Some(header) = loops.get(i).map(|l| l.header) else { - continue; - }; - - // Find all loops that contain this loop's header (except itself) - let mut candidates: Vec = (0..n) - .filter(|&j| { - j != i - && loops - .get(j) - .is_some_and(|l| l.body.contains(header.index())) - }) - .collect(); - - // Parent is the smallest containing loop - if !candidates.is_empty() { - candidates.sort_by_key(|&j| loops.get(j).map_or(usize::MAX, LoopInfo::size)); - let parent_idx = match candidates.first().copied() { - Some(p) => p, - None => continue, - }; - let parent_header = match loops.get(parent_idx).map(|l| l.header) { - Some(h) => h, - None => continue, - }; - if let Some(loop_i) = loops.get_mut(i) { - loop_i.parent = Some(parent_header); - } - } - } - - // Compute children from parent relationships - for i in 0..n { - let parent_opt = loops.get(i).and_then(|l| l.parent); - let Some(parent_header) = parent_opt else { - continue; - }; - let header_i = match loops.get(i).map(|l| l.header) { - Some(h) => h, - None => continue, - }; - if let Some(&parent_idx) = header_to_idx.get(&parent_header) { - if let Some(parent) = loops.get_mut(parent_idx) { - parent.children.push(header_i); - } - } - } - - // Compute depths from parent chain - for i in 0..n { - let mut depth: usize = 0; - let mut current = loops.get(i).and_then(|l| l.parent); - while let Some(parent_header) = current { - depth = depth.saturating_add(1); - if let Some(&parent_idx) = header_to_idx.get(&parent_header) { - current = loops.get(parent_idx).and_then(|l| l.parent); - } else { - break; - } - } - if let Some(l) = loops.get_mut(i) { - l.depth = depth; - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_loop_info_creation() { - let header = NodeId::new(0); - let loop_info = LoopInfo::new(header, 10); - - assert_eq!(loop_info.header, header); - assert!(loop_info.contains(header)); - assert_eq!(loop_info.size(), 1); - assert!(!loop_info.has_single_latch()); - assert!(!loop_info.has_preheader()); - assert!(!loop_info.is_canonical()); - } - - #[test] - fn test_loop_info_canonical() { - let header = NodeId::new(1); - let mut loop_info = LoopInfo::new(header, 10); - - loop_info.preheader = Some(NodeId::new(0)); - loop_info.latches.push(NodeId::new(2)); - - assert!(loop_info.has_preheader()); - assert!(loop_info.has_single_latch()); - assert!(loop_info.is_canonical()); - } - - #[test] - fn test_loop_forest() { - let mut forest = LoopForest::new(10); - - let mut outer_loop = LoopInfo::new(NodeId::new(1), 10); - outer_loop.body.insert(2); - outer_loop.body.insert(3); - outer_loop.depth = 0; - - let mut inner_loop = LoopInfo::new(NodeId::new(2), 10); - inner_loop.body.insert(3); - inner_loop.depth = 1; - - forest.add_loop(outer_loop); - forest.add_loop(inner_loop); - - assert_eq!(forest.len(), 2); - - // Block 3 should be in inner loop (depth 1) - assert_eq!(forest.loop_depth(NodeId::new(3)), 2); - - // Block 1 should be in outer loop only - assert_eq!(forest.loop_depth(NodeId::new(1)), 1); - - // Block 0 should not be in any loop - assert_eq!(forest.loop_depth(NodeId::new(0)), 0); - } -} diff --git a/dotscope/src/analysis/cfg/mod.rs b/dotscope/src/analysis/cfg/mod.rs index e025b5c4..bd983df7 100644 --- a/dotscope/src/analysis/cfg/mod.rs +++ b/dotscope/src/analysis/cfg/mod.rs @@ -12,8 +12,11 @@ //! # Key Components //! //! - [`ControlFlowGraph`] - The main CFG structure wrapping basic blocks -//! - [`CfgEdge`] - Edge representation with control flow semantics -//! - [`CfgEdgeKind`] - Classification of edge types (unconditional, conditional, etc.) +//! - [`CfgEdge`] / [`CfgEdgeKind`] - Edges and their control-flow classification +//! - [`LoopAnalyzer`] / [`LoopForest`] / [`LoopInfo`] - Loop detection (re-exported +//! from `analyssa::analysis::loops` / `analyssa::analysis::loop_analyzer`) +//! - [`SemanticAnalyzer`] / [`BlockSemantics`] / [`LoopSemantics`] - Higher-level +//! block- and loop-role classification used by deobfuscation passes //! //! # Edge Types //! @@ -82,18 +85,23 @@ //! access after construction. The lazy-initialized dominator tree and loop info //! use [`std::sync::OnceLock`] for thread-safe initialization. -mod analyzer; mod edge; mod graph; -mod loops; mod semantics; -pub use analyzer::LoopAnalyzer; -#[cfg(feature = "compiler")] -pub use analyzer::SsaLoopAnalysis; pub use edge::{CfgEdge, CfgEdgeKind}; pub use graph::ControlFlowGraph; -#[cfg(feature = "x86")] -pub use loops::has_back_edges; -pub use loops::{detect_loops, InductionVar, LoopForest, LoopInfo}; pub use semantics::{BlockSemantics, LoopSemantics, SemanticAnalyzer}; + +// `LoopAnalyzer` and the extended-loop primitives live analyssa-side. CIL +// callers reach them through these aliases / re-exports. +use crate::analysis::ssa::CilTarget; + +#[cfg(feature = "compiler")] +pub use analyssa::analysis::loop_analyzer::SsaLoopAnalysis; +#[cfg(feature = "x86")] +pub use analyssa::analysis::loops::has_back_edges; +pub use analyssa::analysis::loops::{detect_loops, InductionVar, LoopForest, LoopInfo}; + +/// CIL-defaulted alias of [`analyssa::analysis::loop_analyzer::LoopAnalyzer`]. +pub type LoopAnalyzer<'a, T = CilTarget> = analyssa::analysis::loop_analyzer::LoopAnalyzer<'a, T>; diff --git a/dotscope/src/analysis/cfg/semantics.rs b/dotscope/src/analysis/cfg/semantics.rs index 17efb301..a92834cd 100644 --- a/dotscope/src/analysis/cfg/semantics.rs +++ b/dotscope/src/analysis/cfg/semantics.rs @@ -43,9 +43,10 @@ use std::collections::HashMap; -use crate::{ - analysis::{cfg::InductionVar, LoopInfo, SsaFunction, SsaOp, SsaVarId}, - utils::BitSet, +use analyssa::BitSet; + +use crate::analysis::{ + cfg::InductionVar, CilTarget, LoopInfo, SsaFunction, SsaOp, SsaVarId, Target, }; /// Semantic role of a basic block. @@ -216,18 +217,18 @@ impl LoopSemantics { } /// Analyzes semantic roles of blocks in an SSA function. -pub struct SemanticAnalyzer<'a> { - ssa: &'a SsaFunction, +pub struct SemanticAnalyzer<'a, T: Target = CilTarget> { + ssa: &'a SsaFunction, /// Cache of block semantics. block_cache: HashMap, /// Known dispatcher blocks. dispatcher_blocks: BitSet, } -impl<'a> SemanticAnalyzer<'a> { +impl<'a, T: Target> SemanticAnalyzer<'a, T> { /// Creates a new semantic analyzer for the given SSA function. #[must_use] - pub fn new(ssa: &'a SsaFunction) -> Self { + pub fn new(ssa: &'a SsaFunction) -> Self { Self { ssa, block_cache: HashMap::new(), diff --git a/dotscope/src/analysis/dataflow/framework.rs b/dotscope/src/analysis/dataflow/framework.rs deleted file mode 100644 index a36cf69f..00000000 --- a/dotscope/src/analysis/dataflow/framework.rs +++ /dev/null @@ -1,286 +0,0 @@ -//! Data flow analysis framework trait and direction. -//! -//! This module defines the core abstraction for data flow analyses. Any -//! specific analysis (reaching definitions, liveness, constant propagation) -//! implements the [`DataFlowAnalysis`] trait to work with the solver. - -use std::fmt::Debug; - -use crate::{ - analysis::{ - dataflow::lattice::MeetSemiLattice, ControlFlowGraph, SsaBlock, SsaCfg, SsaFunction, - }, - utils::graph::{NodeId, Predecessors, RootedGraph, Successors}, -}; - -/// Trait for control flow graphs usable with the dataflow solver. -/// -/// This trait abstracts over different CFG implementations, allowing the solver -/// to work with both [`ControlFlowGraph`] (CIL-level) and [`SsaCfg`] (SSA-level). -/// -/// [`ControlFlowGraph`]: crate::analysis::ControlFlowGraph -/// [`SsaCfg`]: crate::analysis::SsaCfg -pub trait DataFlowCfg: Predecessors + Successors { - /// Returns the entry node of the CFG. - fn entry(&self) -> NodeId; - - /// Returns the exit nodes of the CFG. - fn exits(&self) -> Vec; - - /// Returns nodes in postorder (for backward analysis). - fn postorder(&self) -> Vec; - - /// Returns nodes in reverse postorder (for forward analysis). - fn reverse_postorder(&self) -> Vec; -} - -/// Direction of data flow analysis. -/// -/// The direction determines how information propagates through the CFG -/// and which operation (meet or join) is used at control flow merge points. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Direction { - /// Information flows forward, from entry to exit. - /// - /// At join points (blocks with multiple predecessors), values from - /// all predecessors are combined using the meet operation. - /// - /// Examples: reaching definitions, available expressions, constant propagation. - Forward, - - /// Information flows backward, from exit to entry. - /// - /// At split points (blocks with multiple successors), values from - /// all successors are combined. - /// - /// Examples: live variables, very busy expressions. - Backward, -} - -/// A data flow analysis that can be run on SSA form. -/// -/// This trait defines the interface for a data flow analysis. Implementations -/// provide the transfer function and boundary conditions; the solver handles -/// iteration to a fixpoint. -/// -/// # Type Parameters -/// -/// * `L` - The lattice type representing abstract values at each program point -/// -/// # Direction -/// -/// The `DIRECTION` constant specifies whether this is a forward or backward -/// analysis. The solver uses this to determine iteration order and how to -/// combine values at control flow merge points. -/// -/// # Transfer Functions -/// -/// The core of any data flow analysis is the transfer function, which -/// describes how flowing through a basic block transforms the abstract state. -/// -/// For forward analyses: `out[B] = transfer(B, in[B])` -/// For backward analyses: `in[B] = transfer(B, out[B])` -/// -/// # Example -/// -/// ```rust,ignore -/// use dotscope::analysis::{DataFlowAnalysis, Direction, MeetSemiLattice}; -/// -/// struct MyAnalysis; -/// -/// impl DataFlowAnalysis for MyAnalysis { -/// type Lattice = MyLattice; -/// const DIRECTION: Direction = Direction::Forward; -/// -/// fn boundary(&self, _ssa: &SsaFunction) -> Self::Lattice { -/// MyLattice::initial_at_entry() -/// } -/// -/// fn initial(&self, _ssa: &SsaFunction) -> Self::Lattice { -/// MyLattice::top() -/// } -/// -/// fn transfer( -/// &self, -/// block_id: usize, -/// block: &SsaBlock, -/// input: &Self::Lattice, -/// ssa: &SsaFunction, -/// ) -> Self::Lattice { -/// // Compute the output state from the input state -/// // by applying the block's effects -/// todo!() -/// } -/// } -/// ``` -pub trait DataFlowAnalysis { - /// The lattice type for this analysis. - /// - /// This must implement `MeetSemiLattice` to support combining values - /// at control flow merge points. - type Lattice: MeetSemiLattice; - - /// The direction of this analysis. - const DIRECTION: Direction; - - /// Returns the initial value at the boundary of the function. - /// - /// For forward analyses, this is the value at function entry. - /// For backward analyses, this is the value at function exit(s). - /// - /// This often represents the "known" information at the boundary, - /// such as "all parameters are defined" for reaching definitions. - fn boundary(&self, ssa: &SsaFunction) -> Self::Lattice; - - /// Returns the initial value for interior blocks. - /// - /// This is the value used to initialize all non-boundary blocks - /// before iteration begins. For most analyses, this is the top - /// element of the lattice (no information). - fn initial(&self, ssa: &SsaFunction) -> Self::Lattice; - - /// Computes the transfer function for a basic block. - /// - /// Given the input state to a block, computes the output state - /// after flowing through the block. - /// - /// # Arguments - /// - /// * `block_id` - The index of the block being processed - /// * `block` - The SSA block - /// * `input` - The abstract state flowing into (forward) or out of (backward) the block - /// * `ssa` - The complete SSA function for context - /// - /// # Returns - /// - /// The abstract state after flowing through the block. - fn transfer( - &self, - block_id: usize, - block: &SsaBlock, - input: &Self::Lattice, - ssa: &SsaFunction, - ) -> Self::Lattice; - - /// Called when analysis is complete. - /// - /// This hook allows analyses to perform post-processing, such as - /// computing per-instruction results from block-level results. - /// - /// The default implementation does nothing. - fn finalize( - &mut self, - _in_states: &[Self::Lattice], - _out_states: &[Self::Lattice], - _ssa: &SsaFunction, - ) { - // Default: no post-processing - } -} - -/// Results of a data flow analysis. -/// -/// This provides access to the computed abstract values at block boundaries. -#[derive(Debug, Clone)] -pub struct AnalysisResults { - /// Input state for each block (before transfer function). - pub in_states: Vec, - /// Output state for each block (after transfer function). - pub out_states: Vec, -} - -impl AnalysisResults { - /// Creates new analysis results with the given states. - /// - /// # Arguments - /// - /// * `in_states` - The input states for each block - /// * `out_states` - The output states for each block - /// - /// # Returns - /// - /// A new [`AnalysisResults`] instance. - #[must_use] - pub fn new(in_states: Vec, out_states: Vec) -> Self { - Self { - in_states, - out_states, - } - } - - /// Returns the input state for a block. - /// - /// # Arguments - /// - /// * `block` - The block index to query - /// - /// # Returns - /// - /// The input state for the block, or `None` if the index is out of bounds. - #[must_use] - pub fn in_state(&self, block: usize) -> Option<&L> { - self.in_states.get(block) - } - - /// Returns the output state for a block. - /// - /// # Arguments - /// - /// * `block` - The block index to query - /// - /// # Returns - /// - /// The output state for the block, or `None` if the index is out of bounds. - #[must_use] - pub fn out_state(&self, block: usize) -> Option<&L> { - self.out_states.get(block) - } - - /// Returns the number of blocks. - /// - /// # Returns - /// - /// The total number of blocks in the analysis results. - #[must_use] - pub fn block_count(&self) -> usize { - self.in_states.len() - } -} - -// Implement DataFlowCfg for ControlFlowGraph -impl DataFlowCfg for ControlFlowGraph<'_> { - fn entry(&self) -> NodeId { - RootedGraph::entry(self) - } - - fn exits(&self) -> Vec { - self.exits().to_vec() - } - - fn postorder(&self) -> Vec { - self.postorder() - } - - fn reverse_postorder(&self) -> Vec { - self.reverse_postorder() - } -} - -// Implement DataFlowCfg for SsaCfg -impl DataFlowCfg for SsaCfg<'_> { - fn entry(&self) -> NodeId { - RootedGraph::entry(self) - } - - fn exits(&self) -> Vec { - self.exits() - } - - fn postorder(&self) -> Vec { - self.postorder() - } - - fn reverse_postorder(&self) -> Vec { - self.reverse_postorder() - } -} diff --git a/dotscope/src/analysis/dataflow/lattice.rs b/dotscope/src/analysis/dataflow/lattice.rs deleted file mode 100644 index 63850b5f..00000000 --- a/dotscope/src/analysis/dataflow/lattice.rs +++ /dev/null @@ -1,302 +0,0 @@ -//! Lattice traits for data flow analysis. -//! -//! A lattice is a mathematical structure that defines how abstract values -//! combine at control flow join points. This module provides the fundamental -//! traits that analysis domains must implement. -//! -//! # Lattice Theory Background -//! -//! For data flow analysis, we use lattices with the following properties: -//! -//! - **Partial Order**: Elements can be compared (≤) -//! - **Meet (∧)**: Greatest lower bound of two elements -//! - **Join (∨)**: Least upper bound of two elements -//! - **Top (⊤)**: Greatest element (no information) -//! - **Bottom (⊥)**: Least element (conflicting/all information) -//! -//! # Forward vs Backward Analysis -//! -//! - **Forward analysis** (e.g., reaching definitions): Uses meet at join points -//! - **Backward analysis** (e.g., liveness): Uses join at split points -//! -//! The solver automatically selects the appropriate operation based on -//! analysis direction. - -use std::fmt::Debug; - -use crate::utils::BitSet; - -/// A meet semi-lattice with a meet (greatest lower bound) operation. -/// -/// The meet operation combines information from multiple control flow paths. -/// It must satisfy: -/// -/// - **Idempotent**: `x.meet(x) = x` -/// - **Commutative**: `x.meet(y) = y.meet(x)` -/// - **Associative**: `x.meet(y.meet(z)) = (x.meet(y)).meet(z)` -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::analysis::MeetSemiLattice; -/// -/// impl MeetSemiLattice for ConstantLattice { -/// fn meet(&self, other: &Self) -> Self { -/// match (self, other) { -/// (Self::Top, x) | (x, Self::Top) => x.clone(), -/// (Self::Const(a), Self::Const(b)) if a == b => Self::Const(*a), -/// _ => Self::Bottom, -/// } -/// } -/// } -/// ``` -pub trait MeetSemiLattice: Clone + Debug + PartialEq { - /// Computes the meet (greatest lower bound) of two lattice elements. - /// - /// The meet represents combining information from two paths that merge. - #[must_use] - fn meet(&self, other: &Self) -> Self; - - /// Returns `true` if this is the bottom element. - /// - /// The bottom element represents "all information" or "conflict". - /// Once bottom is reached, further meets cannot change the value. - fn is_bottom(&self) -> bool; -} - -/// A join semi-lattice with a join (least upper bound) operation. -/// -/// The join operation combines information when paths split (for backward analysis) -/// or when we want to widen the approximation. -/// -/// It must satisfy: -/// -/// - **Idempotent**: `x.join(x) = x` -/// - **Commutative**: `x.join(y) = y.join(x)` -/// - **Associative**: `x.join(y.join(z)) = (x.join(y)).join(z)` -pub trait JoinSemiLattice: Clone + Debug + PartialEq { - /// Computes the join (least upper bound) of two lattice elements. - /// - /// The join represents the least specific value that covers both inputs. - #[must_use] - fn join(&self, other: &Self) -> Self; - - /// Returns `true` if this is the top element. - /// - /// The top element represents "no information" or "unknown". - /// It is the identity for meet: `x.meet(top) = x`. - fn is_top(&self) -> bool; -} - -/// A complete lattice with both meet and join operations. -/// -/// Most data flow analyses operate over complete lattices, which have -/// both a greatest and least element, plus meet and join operations. -/// -/// # Required Properties -/// -/// - All properties of `MeetSemiLattice` and `JoinSemiLattice` -/// - **Absorption**: `x.meet(x.join(y)) = x` and `x.join(x.meet(y)) = x` -pub trait Lattice: MeetSemiLattice + JoinSemiLattice { - /// Returns the top (⊤) element of the lattice. - /// - /// Top represents "no information" and is the identity for meet. - fn top() -> Self; - - /// Returns the bottom (⊥) element of the lattice. - /// - /// Bottom represents "all information" or "conflict". - fn bottom() -> Self; -} - -// Lattice trait implementations for BitSet (defined in crate::utils::bitset) - -impl MeetSemiLattice for BitSet { - /// Meet is union for reaching definitions (may analysis). - fn meet(&self, other: &Self) -> Self { - let mut result = self.clone(); - result.union_with(other); - result - } - - fn is_bottom(&self) -> bool { - // For may analysis, bottom is full set - self.count() == self.len() - } -} - -impl JoinSemiLattice for BitSet { - /// Join is intersection for reaching definitions. - fn join(&self, other: &Self) -> Self { - let mut result = self.clone(); - result.intersect_with(other); - result - } - - fn is_top(&self) -> bool { - self.is_empty() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_bitset_meet_union() { - let mut a = BitSet::new(10); - a.insert(1); - a.insert(3); - - let mut b = BitSet::new(10); - b.insert(2); - b.insert(3); - - let result = a.meet(&b); - - // Meet is union: {1, 3} ∪ {2, 3} = {1, 2, 3} - assert!(result.contains(1)); - assert!(result.contains(2)); - assert!(result.contains(3)); - assert!(!result.contains(0)); - assert!(!result.contains(4)); - } - - #[test] - fn test_bitset_meet_idempotent() { - let mut a = BitSet::new(10); - a.insert(1); - a.insert(5); - - let result = a.meet(&a); - - // Idempotent: x.meet(x) = x - assert_eq!(a, result); - } - - #[test] - fn test_bitset_meet_commutative() { - let mut a = BitSet::new(10); - a.insert(1); - a.insert(3); - - let mut b = BitSet::new(10); - b.insert(2); - b.insert(4); - - // Commutative: x.meet(y) = y.meet(x) - assert_eq!(a.meet(&b), b.meet(&a)); - } - - #[test] - fn test_bitset_meet_associative() { - let mut a = BitSet::new(10); - a.insert(1); - - let mut b = BitSet::new(10); - b.insert(2); - - let mut c = BitSet::new(10); - c.insert(3); - - // Associative: x.meet(y.meet(z)) = (x.meet(y)).meet(z) - let left = a.meet(&b.meet(&c)); - let right = a.meet(&b).meet(&c); - assert_eq!(left, right); - } - - #[test] - fn test_bitset_join_intersection() { - let mut a = BitSet::new(10); - a.insert(1); - a.insert(2); - a.insert(3); - - let mut b = BitSet::new(10); - b.insert(2); - b.insert(3); - b.insert(4); - - let result = a.join(&b); - - // Join is intersection: {1, 2, 3} ∩ {2, 3, 4} = {2, 3} - assert!(!result.contains(1)); - assert!(result.contains(2)); - assert!(result.contains(3)); - assert!(!result.contains(4)); - } - - #[test] - fn test_bitset_join_idempotent() { - let mut a = BitSet::new(10); - a.insert(1); - a.insert(5); - - let result = a.join(&a); - - // Idempotent: x.join(x) = x - assert_eq!(a, result); - } - - #[test] - fn test_bitset_join_commutative() { - let mut a = BitSet::new(10); - a.insert(1); - a.insert(3); - - let mut b = BitSet::new(10); - b.insert(2); - b.insert(3); - - // Commutative: x.join(y) = y.join(x) - assert_eq!(a.join(&b), b.join(&a)); - } - - #[test] - fn test_bitset_is_top_empty() { - let empty = BitSet::new(10); - assert!(empty.is_top()); - - let mut non_empty = BitSet::new(10); - non_empty.insert(0); - assert!(!non_empty.is_top()); - } - - #[test] - fn test_bitset_is_bottom_full() { - let full = BitSet::full(10); - assert!(full.is_bottom()); - - let mut partial = BitSet::new(10); - partial.insert(0); - assert!(!partial.is_bottom()); - } - - #[test] - fn test_bitset_meet_with_empty() { - let empty = BitSet::new(10); - - let mut a = BitSet::new(10); - a.insert(1); - a.insert(2); - - // Meet with empty (top) should give the other set - let result = a.meet(&empty); - assert!(result.contains(1)); - assert!(result.contains(2)); - assert_eq!(result.count(), 2); - } - - #[test] - fn test_bitset_join_with_empty() { - let empty = BitSet::new(10); - - let mut a = BitSet::new(10); - a.insert(1); - a.insert(2); - - // Join with empty (top) should give empty - let result = a.join(&empty); - assert!(result.is_empty()); - } -} diff --git a/dotscope/src/analysis/dataflow/liveness.rs b/dotscope/src/analysis/dataflow/liveness.rs deleted file mode 100644 index d5d7534d..00000000 --- a/dotscope/src/analysis/dataflow/liveness.rs +++ /dev/null @@ -1,336 +0,0 @@ -//! Live variable analysis. -//! -//! A variable is *live* at a program point if there exists a path from that -//! point to a use of the variable that doesn't pass through a definition of -//! the variable. In SSA form, since each variable is defined exactly once, -//! this simplifies to: a variable is live if it will be used on some path -//! from this point. -//! -//! # Uses -//! -//! Live variable analysis is essential for: -//! - **Dead code elimination**: If a definition's result is never live, it's dead -//! - **Register allocation**: Variables live at the same time need different registers -//! - **Debugging**: Determine which variables can be inspected at a breakpoint -//! -//! # Algorithm -//! -//! This is a backward data flow analysis: -//! -//! - `USE[B]` = variables used in B before any definition -//! - `DEF[B]` = variables defined in B -//! - `OUT[B]` = ∪{IN[S] | S is a successor of B} -//! - `IN[B]` = USE[B] ∪ (OUT[B] - DEF[B]) -//! -//! In SSA form, DEF[B] kills only the variables defined in B (which have -//! unique definitions anyway), so the analysis tracks which uses are live. - -use crate::{ - analysis::{ - dataflow::{ - framework::{DataFlowAnalysis, Direction}, - lattice::MeetSemiLattice, - }, - SsaBlock, SsaFunction, SsaVarId, - }, - utils::BitSet, -}; - -/// Live variable analysis. -/// -/// Computes which variables are live at each program point. -/// A variable is live if its value may be used on some path from -/// that point forward. -/// -/// # Example -/// -/// ```rust,ignore -/// use dotscope::analysis::{LiveVariables, DataFlowSolver}; -/// -/// let analysis = LiveVariables::new(&ssa); -/// let mut solver = DataFlowSolver::new(analysis); -/// let results = solver.solve(&ssa, &graph); -/// -/// // Check which variables are live at block exit -/// if let Some(live) = results.out_state(block_id) { -/// for var_id in live.variables() { -/// println!("Variable {} is live at exit of block {}", var_id, block_id); -/// } -/// } -/// ``` -pub struct LiveVariables { - /// Number of variables in the function. - num_vars: usize, - /// USE sets for each block (variables used before definition). - use_sets: Vec, - /// DEF sets for each block (variables defined). - def_sets: Vec, -} - -impl LiveVariables { - /// Creates a new live variables analysis for the given SSA function. - #[must_use] - pub fn new(ssa: &SsaFunction) -> Self { - let num_vars = ssa.variable_count(); - let num_blocks = ssa.block_count(); - - let mut use_sets = Vec::with_capacity(num_blocks); - let mut def_sets = Vec::with_capacity(num_blocks); - - // Phase 1: Initialize use/def sets without PHI operands - for block in ssa.blocks() { - let mut uses = BitSet::new(num_vars); - let mut defs = BitSet::new(num_vars); - - // Process phi nodes: they define variables - for phi in block.phi_nodes() { - if let Some(def_idx) = ssa.var_index(phi.result()) { - defs.insert(def_idx); - } - } - - // Process instructions in forward order - for instr in block.instructions() { - // Uses first (before def, since this is the "USE before DEF" set) - for &use_var in &instr.uses() { - if let Some(var_idx) = ssa.var_index(use_var) { - if !defs.contains(var_idx) { - uses.insert(var_idx); - } - } - } - - // Then definition - if let Some(def) = instr.def() { - if let Some(def_idx) = ssa.var_index(def) { - defs.insert(def_idx); - } - } - } - - use_sets.push(uses); - def_sets.push(defs); - } - - // Phase 2: Add PHI operand uses to their PREDECESSOR blocks. - // A PHI operand `v<-B_pred` means variable v is used at the END - // of B_pred (ECMA-335 / SSA semantics: phi copies happen at the - // predecessor's outgoing edge). Placing the use in the predecessor - // ensures backward dataflow propagates liveness from the predecessor - // back through all intermediate blocks to the definition. - for block in ssa.blocks() { - for phi in block.phi_nodes() { - for op in phi.operands() { - let pred = op.predecessor(); - if let Some(var_idx) = ssa.var_index(op.value()) { - let already_def = def_sets.get(pred).is_some_and(|s| s.contains(var_idx)); - if !already_def { - if let Some(slot) = use_sets.get_mut(pred) { - slot.insert(var_idx); - } - } - } - } - } - } - - Self { - num_vars, - use_sets, - def_sets, - } - } - - /// Returns the number of variables being tracked. - #[must_use] - pub const fn num_variables(&self) -> usize { - self.num_vars - } - - /// Returns the USE set for a block. - #[must_use] - pub fn use_set(&self, block: usize) -> Option<&BitSet> { - self.use_sets.get(block) - } - - /// Returns the DEF set for a block. - #[must_use] - pub fn def_set(&self, block: usize) -> Option<&BitSet> { - self.def_sets.get(block) - } -} - -impl DataFlowAnalysis for LiveVariables { - type Lattice = LivenessResult; - const DIRECTION: Direction = Direction::Backward; - - fn boundary(&self, _ssa: &SsaFunction) -> Self::Lattice { - // At function exit, no variables are live - // (unless we're tracking return values, which we could add) - LivenessResult { - live: BitSet::new(self.num_vars), - } - } - - fn initial(&self, _ssa: &SsaFunction) -> Self::Lattice { - // Initially, no variables are live - LivenessResult { - live: BitSet::new(self.num_vars), - } - } - - fn transfer( - &self, - block_id: usize, - _block: &SsaBlock, - output: &Self::Lattice, - _ssa: &SsaFunction, - ) -> Self::Lattice { - // For backward analysis: IN = USE ∪ (OUT - DEF) - let mut result = output.live.clone(); - - // Remove definitions (OUT - DEF) - if let Some(d) = self.def_sets.get(block_id) { - result.difference_with(d); - } - - // Add uses (USE ∪ ...) - if let Some(u) = self.use_sets.get(block_id) { - result.union_with(u); - } - - LivenessResult { live: result } - } -} - -/// Result of live variable analysis for a single program point. -#[derive(Debug, Clone, PartialEq)] -pub struct LivenessResult { - /// Bit vector of live variables (indexed by `SsaVarId`). - live: BitSet, -} - -impl LivenessResult { - /// Creates a new empty result. - #[must_use] - pub fn new(num_vars: usize) -> Self { - Self { - live: BitSet::new(num_vars), - } - } - - /// Returns `true` if the given variable is live at this point. - #[must_use] - pub fn is_live(&self, var: SsaVarId) -> bool { - let idx = var.index(); - idx < self.live.len() && self.live.contains(idx) - } - - /// Returns an iterator over all live variables. - pub fn variables(&self) -> impl Iterator + '_ { - self.live.iter().map(SsaVarId::from_index) - } - - /// Returns the number of live variables. - #[must_use] - pub fn count(&self) -> usize { - self.live.count() - } - - /// Returns `true` if no variables are live. - #[must_use] - pub fn is_empty(&self) -> bool { - self.live.is_empty() - } - - /// Marks a variable as live. - pub fn add(&mut self, var: SsaVarId) { - let idx = var.index(); - if idx < self.live.len() { - self.live.insert(idx); - } - } - - /// Marks a variable as not live. - pub fn remove(&mut self, var: SsaVarId) { - let idx = var.index(); - if idx < self.live.len() { - self.live.remove(idx); - } - } - - /// Returns the underlying bit set. - #[must_use] - pub const fn as_bitset(&self) -> &BitSet { - &self.live - } -} - -impl MeetSemiLattice for LivenessResult { - /// Meet is union (a variable is live if it's live on ANY successor path). - fn meet(&self, other: &Self) -> Self { - let mut result = self.live.clone(); - result.union_with(&other.live); - Self { live: result } - } - - fn is_bottom(&self) -> bool { - // Bottom is when all variables are live (full set) - self.live.count() == self.live.len() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_liveness_result() { - let mut result = LivenessResult::new(10); - assert!(result.is_empty()); - - result.add(SsaVarId::from_index(0)); - result.add(SsaVarId::from_index(5)); - - assert!(!result.is_empty()); - assert_eq!(result.count(), 2); - assert!(result.is_live(SsaVarId::from_index(0))); - assert!(result.is_live(SsaVarId::from_index(5))); - assert!(!result.is_live(SsaVarId::from_index(1))); - - result.remove(SsaVarId::from_index(0)); - assert!(!result.is_live(SsaVarId::from_index(0))); - assert_eq!(result.count(), 1); - } - - #[test] - fn test_liveness_meet() { - let mut a = LivenessResult::new(10); - let mut b = LivenessResult::new(10); - - a.add(SsaVarId::from_index(0)); - a.add(SsaVarId::from_index(1)); - b.add(SsaVarId::from_index(1)); - b.add(SsaVarId::from_index(2)); - - let result = a.meet(&b); - assert!(result.is_live(SsaVarId::from_index(0))); - assert!(result.is_live(SsaVarId::from_index(1))); - assert!(result.is_live(SsaVarId::from_index(2))); - assert_eq!(result.count(), 3); - } - - #[test] - fn test_liveness_iterator() { - let mut result = LivenessResult::new(100); - result.add(SsaVarId::from_index(5)); - result.add(SsaVarId::from_index(42)); - result.add(SsaVarId::from_index(99)); - - let vars: Vec<_> = result.variables().collect(); - assert_eq!(vars.len(), 3); - assert!(vars.contains(&SsaVarId::from_index(5))); - assert!(vars.contains(&SsaVarId::from_index(42))); - assert!(vars.contains(&SsaVarId::from_index(99))); - } -} diff --git a/dotscope/src/analysis/dataflow/mod.rs b/dotscope/src/analysis/dataflow/mod.rs index 213f0ebf..391aff7b 100644 --- a/dotscope/src/analysis/dataflow/mod.rs +++ b/dotscope/src/analysis/dataflow/mod.rs @@ -1,67 +1,53 @@ -//! Data flow analysis framework for SSA form. -//! -//! This module provides a generic framework for computing properties that -//! propagate along control flow edges. It supports both forward and backward -//! analyses using a worklist-based solver. -//! -//! # Architecture -//! -//! The framework is built around three core abstractions: -//! -//! - **Lattice**: Defines the domain of abstract values with meet/join operations -//! - **Analysis**: Specifies transfer functions and boundary conditions -//! - **Solver**: Iteratively computes fixpoints using a worklist algorithm -//! -//! # Analyses Provided -//! -//! - [`ReachingDefinitions`]: Tracks which definitions may reach each program point -//! - [`LiveVariables`]: Determines which variables are live at each program point -//! - [`ConstantPropagation`]: Sparse conditional constant propagation (SCCP) -//! -//! # Example -//! -//! ```rust,ignore -//! use dotscope::analysis::{ConstantPropagation, DataFlowSolver, SsaFunction}; -//! -//! // Build SSA form -//! let ssa = SsaConverter::build(&graph, num_args, num_locals, resolver)?; -//! -//! // Run constant propagation -//! let analysis = ConstantPropagation::new(PointerSize::Bit64); -//! let mut solver = DataFlowSolver::new(analysis, &ssa, &graph); -//! solver.solve(); -//! -//! // Query results -//! for var in ssa.variables() { -//! if let Some(value) = solver.get_value(var.id()) { -//! println!("{}: {}", var.id(), value); -//! } -//! } -//! ``` -//! -//! # Thread Safety -//! -//! All types in this module are `Send` and `Sync`. - -mod framework; -mod lattice; -mod liveness; -mod reaching; -mod sccp; -mod solver; - -pub use framework::{AnalysisResults, DataFlowAnalysis, Direction}; -pub use liveness::{LiveVariables, LivenessResult}; -pub use reaching::ReachingDefinitions; -pub use sccp::{ConstantPropagation, ScalarValue, SccpResult}; -pub use solver::DataFlowSolver; +//! Re-export shim — dataflow framework + analyses live in +//! `analyssa::analysis::dataflow`. CIL-defaulted aliases preserve historical +//! `dotscope::analysis::*` API. + +use analyssa::graph::{NodeId, RootedGraph}; + +use crate::analysis::{ssa::CilTarget, ControlFlowGraph}; + +pub use analyssa::analysis::dataflow::{ + framework::{AnalysisResults, DataFlowAnalysis, DataFlowCfg, Direction}, + liveness::{LiveVariables, LivenessResult}, + reaching::ReachingDefinitions, +}; + +/// CIL-defaulted alias of [`analyssa::analysis::dataflow::sccp::ConstantPropagation`]. +pub type ConstantPropagation = + analyssa::analysis::dataflow::sccp::ConstantPropagation; +/// CIL-defaulted alias of [`analyssa::analysis::dataflow::sccp::ScalarValue`]. +pub type ScalarValue = analyssa::analysis::dataflow::sccp::ScalarValue; +/// CIL-defaulted alias of [`analyssa::analysis::dataflow::sccp::SccpResult`]. +pub type SccpResult = analyssa::analysis::dataflow::sccp::SccpResult; +/// CIL-defaulted alias of [`analyssa::analysis::dataflow::solver::DataFlowSolver`]. +pub type DataFlowSolver = + analyssa::analysis::dataflow::solver::DataFlowSolver; + +// `DataFlowCfg` impl for the CIL `ControlFlowGraph`. The SsaCfg impl lives +// analyssa-side. Both must be available so dataflow analyses run on either CFG. +impl DataFlowCfg for ControlFlowGraph<'_> { + fn entry(&self) -> NodeId { + RootedGraph::entry(self) + } + + fn exits(&self) -> Vec { + self.exits().to_vec() + } + fn postorder(&self) -> Vec { + self.postorder() + } + + fn reverse_postorder(&self) -> Vec { + self.reverse_postorder() + } +} #[cfg(test)] mod tests { use super::*; use crate::{ - analysis::{cfg::ControlFlowGraph, ssa::SsaConverter}, + analysis::{cfg::ControlFlowGraph, ssa::SsaConverter, SsaFunction}, assembly::{decode_blocks, InstructionAssembler}, metadata::typesystem::PointerSize, test::TestTypeProvider, @@ -80,7 +66,7 @@ mod tests { assembler: InstructionAssembler, num_args: usize, num_locals: usize, - ) -> (crate::analysis::SsaFunction, ControlFlowGraph<'static>) { + ) -> (SsaFunction, ControlFlowGraph<'static>) { let cfg = build_cfg(assembler); let ssa = SsaConverter::build( &cfg, diff --git a/dotscope/src/analysis/dataflow/reaching.rs b/dotscope/src/analysis/dataflow/reaching.rs deleted file mode 100644 index 2e3205ba..00000000 --- a/dotscope/src/analysis/dataflow/reaching.rs +++ /dev/null @@ -1,274 +0,0 @@ -//! Reaching definitions analysis. -//! -//! Reaching definitions computes, for each program point, which variable -//! definitions may reach that point without being killed by an intervening -//! definition of the same variable. -//! -//! # SSA Form -//! -//! In SSA form, each variable is defined exactly once, so reaching definitions -//! is simplified: a definition reaches a use if and only if there's a path -//! from the definition to the use. This is always true in well-formed SSA. -//! -//! However, this analysis is still useful for: -//! - Validating SSA construction -//! - Computing def-use chains -//! - Detecting dead definitions -//! -//! # Algorithm -//! -//! For each block B: -//! - `GEN[B]` = definitions created in B -//! - `KILL[B]` = definitions killed in B (in SSA: none, since each var is defined once) -//! - `IN[B]` = ∪{OUT[P] | P is a predecessor of B} -//! - `OUT[B]` = GEN[B] ∪ (IN[B] - KILL[B]) -//! -//! Since SSA has no kills, this simplifies to: -//! - `OUT[B]` = GEN[B] ∪ IN[B] - -use crate::{ - analysis::{ - dataflow::{ - framework::{DataFlowAnalysis, Direction}, - lattice::MeetSemiLattice, - }, - SsaBlock, SsaFunction, SsaVarId, - }, - utils::BitSet, -}; - -/// Reaching definitions analysis. -/// -/// Computes which variable definitions may reach each block. -/// -/// # Example -/// -/// ```rust,ignore -/// use dotscope::analysis::{ReachingDefinitions, DataFlowSolver}; -/// -/// let analysis = ReachingDefinitions::new(&ssa); -/// let mut solver = DataFlowSolver::new(analysis); -/// let results = solver.solve(&ssa, &graph); -/// -/// // Check which definitions reach a block -/// if let Some(reaching) = results.in_state(block_id) { -/// for var_id in reaching.definitions() { -/// println!("Definition {} reaches block {}", var_id, block_id); -/// } -/// } -/// ``` -pub struct ReachingDefinitions { - /// Number of variables in the function. - num_vars: usize, - /// GEN sets for each block (definitions created in the block). - gen_sets: Vec, -} - -impl ReachingDefinitions { - /// Creates a new reaching definitions analysis for the given SSA function. - #[must_use] - pub fn new(ssa: &SsaFunction) -> Self { - let num_vars = ssa.variable_count(); - let num_blocks = ssa.block_count(); - - // Compute GEN sets - let mut gen_sets = Vec::with_capacity(num_blocks); - - for block in ssa.blocks() { - let mut gen = BitSet::new(num_vars); - - // Phi nodes define variables - for phi in block.phi_nodes() { - if let Some(idx) = ssa.var_index(phi.result()) { - gen.insert(idx); - } - } - - // Instructions may define variables - for instr in block.instructions() { - if let Some(def) = instr.def() { - if let Some(idx) = ssa.var_index(def) { - gen.insert(idx); - } - } - } - - gen_sets.push(gen); - } - - Self { num_vars, gen_sets } - } - - /// Returns the number of variables being tracked. - #[must_use] - pub const fn num_variables(&self) -> usize { - self.num_vars - } -} - -impl DataFlowAnalysis for ReachingDefinitions { - type Lattice = ReachingDefsResult; - const DIRECTION: Direction = Direction::Forward; - - fn boundary(&self, ssa: &SsaFunction) -> Self::Lattice { - // At function entry, the initial definitions of arguments and locals reach - let mut defs = BitSet::new(self.num_vars); - - // Arguments and locals have initial definitions (version 0) - for (idx, var) in ssa.variables().iter().enumerate() { - if var.version() == 0 && (var.origin().is_argument() || var.origin().is_local()) { - defs.insert(idx); - } - } - - ReachingDefsResult { defs } - } - - fn initial(&self, _ssa: &SsaFunction) -> Self::Lattice { - // Initially, no definitions reach interior blocks - ReachingDefsResult { - defs: BitSet::new(self.num_vars), - } - } - - fn transfer( - &self, - block_id: usize, - _block: &SsaBlock, - input: &Self::Lattice, - _ssa: &SsaFunction, - ) -> Self::Lattice { - // OUT = GEN ∪ IN (no KILL in SSA since each variable is defined once) - let mut result = input.defs.clone(); - if let Some(gen) = self.gen_sets.get(block_id) { - result.union_with(gen); - } - ReachingDefsResult { defs: result } - } -} - -/// Result of reaching definitions analysis for a single program point. -#[derive(Debug, Clone, PartialEq)] -pub struct ReachingDefsResult { - /// Bit vector of reaching definitions (indexed by `SsaVarId`). - defs: BitSet, -} - -impl ReachingDefsResult { - /// Creates a new empty result. - #[must_use] - pub fn new(num_vars: usize) -> Self { - Self { - defs: BitSet::new(num_vars), - } - } - - /// Returns `true` if the given variable's definition reaches this point. - #[must_use] - pub fn reaches(&self, var: SsaVarId) -> bool { - let idx = var.index(); - idx < self.defs.len() && self.defs.contains(idx) - } - - /// Returns an iterator over all reaching definitions. - pub fn definitions(&self) -> impl Iterator + '_ { - self.defs.iter().map(SsaVarId::from_index) - } - - /// Returns the number of reaching definitions. - #[must_use] - pub fn count(&self) -> usize { - self.defs.count() - } - - /// Returns `true` if no definitions reach this point. - #[must_use] - pub fn is_empty(&self) -> bool { - self.defs.is_empty() - } - - /// Adds a definition to the reaching set. - pub fn add(&mut self, var: SsaVarId) { - let idx = var.index(); - if idx < self.defs.len() { - self.defs.insert(idx); - } - } - - /// Removes a definition from the reaching set. - pub fn remove(&mut self, var: SsaVarId) { - let idx = var.index(); - if idx < self.defs.len() { - self.defs.remove(idx); - } - } -} - -impl MeetSemiLattice for ReachingDefsResult { - /// Meet is union (may analysis: a definition reaches if it reaches from ANY predecessor). - fn meet(&self, other: &Self) -> Self { - let mut result = self.defs.clone(); - result.union_with(&other.defs); - Self { defs: result } - } - - fn is_bottom(&self) -> bool { - // Bottom is when all definitions reach (full set) - self.defs.count() == self.defs.len() - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_reaching_defs_result() { - let mut result = ReachingDefsResult::new(10); - assert!(result.is_empty()); - - result.add(SsaVarId::from_index(0)); - result.add(SsaVarId::from_index(5)); - - assert!(!result.is_empty()); - assert_eq!(result.count(), 2); - assert!(result.reaches(SsaVarId::from_index(0))); - assert!(result.reaches(SsaVarId::from_index(5))); - assert!(!result.reaches(SsaVarId::from_index(1))); - - result.remove(SsaVarId::from_index(0)); - assert!(!result.reaches(SsaVarId::from_index(0))); - assert_eq!(result.count(), 1); - } - - #[test] - fn test_reaching_defs_meet() { - let mut a = ReachingDefsResult::new(10); - let mut b = ReachingDefsResult::new(10); - - a.add(SsaVarId::from_index(0)); - a.add(SsaVarId::from_index(1)); - b.add(SsaVarId::from_index(1)); - b.add(SsaVarId::from_index(2)); - - let result = a.meet(&b); - assert!(result.reaches(SsaVarId::from_index(0))); - assert!(result.reaches(SsaVarId::from_index(1))); - assert!(result.reaches(SsaVarId::from_index(2))); - assert_eq!(result.count(), 3); - } - - #[test] - fn test_reaching_defs_iterator() { - let mut result = ReachingDefsResult::new(100); - result.add(SsaVarId::from_index(5)); - result.add(SsaVarId::from_index(42)); - result.add(SsaVarId::from_index(99)); - - let defs: Vec<_> = result.definitions().collect(); - assert_eq!(defs.len(), 3); - assert!(defs.contains(&SsaVarId::from_index(5))); - assert!(defs.contains(&SsaVarId::from_index(42))); - assert!(defs.contains(&SsaVarId::from_index(99))); - } -} diff --git a/dotscope/src/analysis/dataflow/sccp.rs b/dotscope/src/analysis/dataflow/sccp.rs deleted file mode 100644 index 0a31566c..00000000 --- a/dotscope/src/analysis/dataflow/sccp.rs +++ /dev/null @@ -1,767 +0,0 @@ -//! Sparse Conditional Constant Propagation (SCCP). -//! -//! SCCP is a powerful constant propagation algorithm that combines: -//! -//! 1. **Sparse analysis**: Works directly on SSA def-use chains rather than -//! iterating over all program points -//! 2. **Conditional propagation**: Uses constant branch conditions to prune -//! unreachable code paths -//! -//! # Algorithm Overview -//! -//! SCCP maintains two lattices: -//! - **Value lattice**: For each SSA variable, tracks whether it's Top (unknown), -//! Constant (known value), or Bottom (multiple values) -//! - **CFG reachability**: Tracks which CFG edges are executable -//! -//! The algorithm uses two worklists: -//! - **SSA worklist**: Variables whose values have changed -//! - **CFG worklist**: CFG edges that have become executable -//! -//! # Edge-Based Phi Evaluation -//! -//! A key insight from Wegman & Zadeck is that phi nodes should be evaluated -//! based on which **edges** are executable, not which blocks are reachable. -//! This is critical for precision: a block may be reachable via multiple edges, -//! but only some of those edges may have been discovered yet. -//! -//! For example, in a diamond CFG: -//! ```text -//! B0 -//! / \ -//! B1 B2 -//! \ / -//! B3 -//! ``` -//! If only the edge B0→B1→B3 is executable (because the branch in B0 is constant), -//! the phi in B3 should only consider the operand from B1, not B2. -//! -//! # Differences from Standard Solver -//! -//! Unlike the generic solver, SCCP doesn't use block-level transfer functions. -//! Instead, it processes individual SSA instructions and phi nodes directly, -//! which is more efficient for sparse analyses. -//! -//! # Reference -//! -//! Wegman & Zadeck, "Constant Propagation with Conditional Branches", 1991. - -use std::collections::{BTreeMap, BTreeSet, VecDeque}; - -use crate::{ - analysis::{ - dataflow::lattice::MeetSemiLattice, ssa::evaluate_const_op, ConstValue, PhiNode, SsaBlock, - SsaFunction, SsaOp, SsaVarId, - }, - metadata::typesystem::PointerSize, - utils::{ - graph::{NodeId, RootedGraph, Successors}, - BitSet, - }, -}; - -/// Sparse Conditional Constant Propagation analysis. -/// -/// This analysis computes which SSA variables have constant values, -/// taking into account that some branches may never be taken. -/// -/// # Example -/// -/// ```rust,ignore -/// use dotscope::analysis::{ConstantPropagation, ScalarValue}; -/// -/// let mut sccp = ConstantPropagation::new(PointerSize::Bit64); -/// let results = sccp.analyze(&ssa, &graph); -/// -/// // Check if a variable is constant -/// if let Some(ScalarValue::Constant(c)) = results.get_value(var_id) { -/// println!("{} = {}", var_id, c); -/// } -/// ``` -pub struct ConstantPropagation { - /// Current value for each SSA variable. - values: BTreeMap, - /// Executable CFG edges. - executable_edges: BTreeSet<(usize, usize)>, - /// Blocks that have been marked executable. - executable_blocks: BitSet, - /// SSA worklist: variables whose values have changed. - ssa_worklist: VecDeque, - /// CFG worklist: edges that have become executable. - cfg_worklist: VecDeque<(usize, usize)>, - /// Back edges: edges where the target was already executable when the edge was added. - /// These represent loop back edges and their values should be treated as unknown. - back_edges: BTreeSet<(usize, usize)>, - /// Target pointer size for native int/uint masking. - pointer_size: PointerSize, -} - -impl ConstantPropagation { - /// Creates a new constant propagation analysis. - /// - /// # Arguments - /// - /// * `ptr_size` - Target pointer size for native int/uint masking. - #[must_use] - pub fn new(ptr_size: PointerSize) -> Self { - Self { - values: BTreeMap::new(), - executable_edges: BTreeSet::new(), - executable_blocks: BitSet::new(0), - ssa_worklist: VecDeque::new(), - cfg_worklist: VecDeque::new(), - back_edges: BTreeSet::new(), - pointer_size: ptr_size, - } - } - - /// Runs the SCCP algorithm on the given SSA function. - /// - /// The CFG parameter can be any type that implements the required graph traits: - /// - `RootedGraph` for the entry point - /// - `Successors` for traversing outgoing edges - /// - /// This allows using both `ControlFlowGraph` (from CIL blocks) and `SsaCfg` - /// (from SSA function terminators). - /// - /// Returns the analysis results containing the value for each variable. - pub fn analyze(&mut self, ssa: &SsaFunction, cfg: &G) -> SccpResult - where - G: RootedGraph + Successors, - { - self.initialize(ssa, cfg); - self.propagate(ssa, cfg); - - SccpResult { - values: self.values.clone(), - executable_blocks: self.executable_blocks.clone(), - } - } - - /// Initializes the analysis state. - fn initialize(&mut self, ssa: &SsaFunction, cfg: &G) - where - G: RootedGraph + Successors, - { - self.values.clear(); - self.executable_edges.clear(); - self.executable_blocks = BitSet::new(ssa.block_count()); - self.ssa_worklist.clear(); - self.cfg_worklist.clear(); - self.back_edges.clear(); - - // Initialize variable values: - // - Argument variables (version 0, defined at entry) start as Bottom (unknown input) - // - All other variables start as Top (no information yet) - // - // This distinction is critical: arguments are external inputs that could be anything, - // while other variables are defined by instructions that SCCP will evaluate. - // Without this, branch conditions depending on arguments stay at Top forever - // (since no instruction defines them), causing the branch to never add edges. - for var in ssa.variables() { - let initial_value = if var.origin().is_argument() - && var.version() == 0 - && var.def_site().instruction.is_none() - { - // This is the initial definition of an argument - it's an unknown input - ScalarValue::Bottom - } else { - // Regular variable - will be evaluated by instructions - ScalarValue::Top - }; - self.values.insert(var.id(), initial_value); - } - - // Mark entry block as executable - let entry = cfg.entry().index(); - self.executable_blocks.insert(entry); - - // Add entry block's outgoing edges to CFG worklist - // For unconditional edges or first visit, add all successors - for succ in cfg.successors(cfg.entry()) { - self.cfg_worklist.push_back((entry, succ.index())); - } - - // Process entry block definitions immediately - if let Some(block) = ssa.block(entry) { - self.process_block_definitions(block); - } - } - - /// Main propagation loop. - fn propagate(&mut self, ssa: &SsaFunction, cfg: &G) - where - G: RootedGraph + Successors, - { - // Process until both worklists are empty - loop { - // Process CFG worklist first (to discover new blocks) - while let Some((from, to)) = self.cfg_worklist.pop_front() { - if self.executable_edges.insert((from, to)) { - // Detect back edges: if the target block was already executable - // when this edge is being added, it's a back edge (loop). - // PHI operands from back edges represent values that change - // across loop iterations and should be treated as unknown. - if self.executable_blocks.contains(to) { - self.back_edges.insert((from, to)); - } - // This edge became executable - self.process_edge(from, to, ssa, cfg); - } - } - - // Process SSA worklist - if let Some(var) = self.ssa_worklist.pop_front() { - self.process_variable_uses(var, ssa, cfg); - } else { - // Both worklists empty - break; - } - } - } - - /// Processes a newly executable CFG edge. - /// - /// When an edge `(from, to)` becomes executable: - /// 1. If this is the first edge reaching `to`, mark the block executable and - /// process all its definitions - /// 2. Re-evaluate all phi nodes in `to` since they may now have a new operand - /// from the `from` block - /// 3. If first visit, propagate outgoing edges based on the terminator - fn process_edge(&mut self, from: usize, to: usize, ssa: &SsaFunction, cfg: &G) - where - G: RootedGraph + Successors, - { - let first_visit = !self.executable_blocks.contains(to); - - if first_visit { - self.executable_blocks.insert(to); - - // Process all definitions in the block - if let Some(block) = ssa.block(to) { - self.process_block_definitions(block); - } - } - - // Re-evaluate phi nodes in the target block. - // The new edge (from, to) may contribute a new operand value. - if let Some(block) = ssa.block(to) { - for phi in block.phi_nodes() { - // Only re-evaluate if this phi has an operand from the `from` block - if phi.operand_from(from).is_some() { - let new_value = self.evaluate_phi(phi, to); - self.update_value(phi.result(), &new_value); - } - } - } - - // If first visit, propagate outgoing edges based on terminator - if first_visit { - if let Some(block) = ssa.block(to) { - self.propagate_outgoing_edges(to, block, cfg); - } - } - } - - /// Processes all definitions in a block (non-phi instructions). - /// - /// This evaluates each instruction and updates the value lattice for any - /// variables defined by the instruction. - fn process_block_definitions(&mut self, block: &SsaBlock) { - for instr in block.instructions() { - if let Some(def) = instr.def() { - let value = self.evaluate_instruction(instr.op()); - self.update_value(def, &value); - } - } - } - - /// Processes uses of a variable whose value changed. - fn process_variable_uses(&mut self, var: SsaVarId, ssa: &SsaFunction, cfg: &G) - where - G: RootedGraph + Successors, - { - // Find all uses of this variable - if let Some(ssa_var) = ssa.variable(var) { - for use_site in ssa_var.uses() { - let block_id = use_site.block; - - // Skip if block is not executable - if !self.executable_blocks.contains(block_id) { - continue; - } - - if use_site.is_phi_operand { - // Re-evaluate the phi node - if let Some(block) = ssa.block(block_id) { - if let Some(phi) = block.phi(use_site.instruction) { - let new_value = self.evaluate_phi(phi, block_id); - self.update_value(phi.result(), &new_value); - } - } - } else { - // Re-evaluate the instruction - if let Some(block) = ssa.block(block_id) { - if let Some(instr) = block.instruction(use_site.instruction) { - if let Some(def) = instr.def() { - let value = self.evaluate_instruction(instr.op()); - self.update_value(def, &value); - } - - // Check if this is a branch instruction - if instr.is_terminator() { - self.propagate_outgoing_edges(block_id, block, cfg); - } - } - } - } - } - } - } - - /// Propagates outgoing edges from a block based on terminator. - fn propagate_outgoing_edges(&mut self, block_id: usize, block: &SsaBlock, cfg: &G) - where - G: RootedGraph + Successors, - { - // Find the terminator instruction - match block.terminator_op() { - Some(SsaOp::Branch { - condition, - true_target, - false_target, - }) => { - // Conditional branch - check if condition is constant - match self.get_value(*condition) { - ScalarValue::Constant(c) => { - // Known branch direction - let target = if c.as_bool() == Some(true) { - *true_target - } else { - *false_target - }; - self.add_cfg_edge(block_id, target); - } - ScalarValue::Top => { - // Unknown - don't add edges yet - } - ScalarValue::Bottom => { - // Could go either way - add both edges - self.add_cfg_edge(block_id, *true_target); - self.add_cfg_edge(block_id, *false_target); - } - } - } - Some(SsaOp::Switch { - value, - targets, - default, - }) => { - // Switch statement - match self.get_value(*value) { - ScalarValue::Constant(c) => { - // Known switch value - use checked conversion to handle negative values - if let Some(idx) = c.as_i32().and_then(|i| usize::try_from(i).ok()) { - if let Some(target) = targets.get(idx) { - self.add_cfg_edge(block_id, *target); - } else { - self.add_cfg_edge(block_id, *default); - } - } else { - self.add_cfg_edge(block_id, *default); - } - } - ScalarValue::Top | ScalarValue::Bottom => { - // Unknown or could be anything - conservatively add all edges. - // This is critical for control flow obfuscation where the switch - // value is computed dynamically and cannot be statically determined. - for &target in targets { - self.add_cfg_edge(block_id, target); - } - self.add_cfg_edge(block_id, *default); - } - } - } - Some(SsaOp::Jump { target }) => { - // Unconditional jump - self.add_cfg_edge(block_id, *target); - } - Some(SsaOp::Return { .. } | SsaOp::Throw { .. } | SsaOp::Rethrow) => { - // No successors - } - _ => { - // Fall through or unknown terminator - add all CFG successors - let node = NodeId::new(block_id); - for succ in cfg.successors(node) { - self.add_cfg_edge(block_id, succ.index()); - } - } - } - } - - /// Adds a CFG edge to the worklist if not already executable. - fn add_cfg_edge(&mut self, from: usize, to: usize) { - if !self.executable_edges.contains(&(from, to)) { - self.cfg_worklist.push_back((from, to)); - } - } - - /// Evaluates a phi node to get its current value. - /// - /// This is the key to SCCP's precision: we only consider operands from - /// **executable edges**, not just reachable blocks. This allows us to - /// propagate constants through conditional branches more precisely. - /// - /// For example, if we have: - /// ```text - /// B0: if (true) goto B1 else goto B2 - /// B1: x = 5; goto B3 - /// B2: x = 10; goto B3 - /// B3: y = phi(x from B1, x from B2) - /// ``` - /// Even though B3 is reachable, only the edge B1→B3 is executable (because - /// the branch condition is constant true). So y = 5, not bottom. - /// - /// # Arguments - /// - /// * `phi` - The phi node to evaluate - /// * `block_id` - The block containing this phi node (needed to check edge executability) - fn evaluate_phi(&self, phi: &PhiNode, block_id: usize) -> ScalarValue { - let mut result = ScalarValue::Top; - let mut has_executable_operand = false; - - for operand in phi.operands() { - let pred = operand.predecessor(); - - // The key SCCP insight: only consider this operand if the specific - // edge (pred -> block_id) is executable, not just if pred is reachable. - if !self.executable_edges.contains(&(pred, block_id)) { - continue; - } - - has_executable_operand = true; - - // For back edges (loop edges), treat the operand value as Bottom. - // Back edge values represent loop-carried dependencies that change - // across iterations. Using the first-iteration value would incorrectly - // mark the PHI as constant when it's actually varying. - // - // Example: Fibonacci loop where b = phi(1, temp) - // - First iteration: temp = 0 + 1 = 1, so b = phi(1, 1) looks constant - // - But iteration 2: temp = 1 + 1 = 2, so b should be 2 - // Without this check, SCCP would incorrectly conclude b is always 1. - let op_value = if self.back_edges.contains(&(pred, block_id)) { - ScalarValue::Bottom - } else { - self.get_value(operand.value()) - }; - result = result.meet(&op_value); - - // Early exit if already bottom - if result.is_bottom() { - break; - } - } - - // If no operands were from executable edges, return Top (no information yet) - if !has_executable_operand { - return ScalarValue::Top; - } - - result - } - - /// Evaluates an SSA instruction to get its result value. - /// - /// This performs abstract interpretation of the instruction, computing - /// what value the result would have given the current lattice values - /// of the operands. Delegates to [`evaluate_const_op`] for arithmetic - /// dispatch, while handling lattice Top/Bottom propagation locally. - fn evaluate_instruction(&self, op: &SsaOp) -> ScalarValue { - // Copy propagates the source's lattice value directly. - if let SsaOp::Copy { src, .. } = op { - return self.get_value(*src); - } - - // Delegate arithmetic to the shared constant evaluator. - // Track whether any operand was Top (unknown) vs Bottom (varying) - // so the lattice result is correct. - let mut saw_top = false; - let ptr_size = self.pointer_size; - let result = evaluate_const_op( - op, - |var| match self.get_value(var) { - ScalarValue::Constant(c) => Some(c), - ScalarValue::Top => { - saw_top = true; - None - } - ScalarValue::Bottom => None, - }, - ptr_size, - ); - - match result { - Some(c) => ScalarValue::Constant(c), - // If the shared evaluator returned None but an operand was Top, - // the result is still unknown (Top), not varying (Bottom). - None if saw_top => ScalarValue::Top, - None => ScalarValue::Bottom, - } - } - - /// Gets the current value of a variable. - fn get_value(&self, var: SsaVarId) -> ScalarValue { - self.values.get(&var).cloned().unwrap_or_default() - } - - /// Updates a variable's value and adds it to the worklist if changed. - fn update_value(&mut self, var: SsaVarId, new_value: &ScalarValue) { - let old_value = self.values.get(&var).cloned().unwrap_or_default(); - - // Apply meet to move down the lattice (values can only decrease) - let final_value = old_value.meet(new_value); - - if final_value != old_value { - self.values.insert(var, final_value); - self.ssa_worklist.push_back(var); - } - } -} - -/// Scalar value in the SCCP lattice. -/// -/// This forms a simple three-level lattice: -/// - Top: No information (might be any value) -/// - Constant: Known compile-time constant -/// - Bottom: Not a constant (multiple possible values) -#[derive(Debug, Clone, PartialEq, Default)] -pub enum ScalarValue { - /// No information yet (top of lattice). - #[default] - Top, - /// Known constant value. - Constant(ConstValue), - /// Multiple possible values (bottom of lattice). - Bottom, -} - -impl ScalarValue { - /// Returns `true` if this is the top element. - #[must_use] - pub const fn is_top(&self) -> bool { - matches!(self, Self::Top) - } - - /// Returns `true` if this is the bottom element. - #[must_use] - pub const fn is_bottom(&self) -> bool { - matches!(self, Self::Bottom) - } - - /// Returns `true` if this is a known constant. - #[must_use] - pub const fn is_constant(&self) -> bool { - matches!(self, Self::Constant(_)) - } - - /// Returns the constant value if this is a constant. - #[must_use] - pub const fn as_constant(&self) -> Option<&ConstValue> { - match self { - Self::Constant(c) => Some(c), - _ => None, - } - } -} - -impl MeetSemiLattice for ScalarValue { - fn meet(&self, other: &Self) -> Self { - match (self, other) { - // Top meets anything yields the other - (Self::Top, x) | (x, Self::Top) => x.clone(), - - // Same constants stay constant - (Self::Constant(a), Self::Constant(b)) if a == b => Self::Constant(a.clone()), - - // Different constants or anything with bottom yields bottom - _ => Self::Bottom, - } - } - - fn is_bottom(&self) -> bool { - matches!(self, Self::Bottom) - } -} - -/// Results of SCCP analysis. -#[derive(Debug, Clone)] -pub struct SccpResult { - /// Value for each SSA variable. - values: BTreeMap, - /// Blocks determined to be executable. - executable_blocks: BitSet, -} - -impl SccpResult { - /// Creates an empty SCCP result. - /// - /// This is useful for testing or when no analysis has been performed. - #[must_use] - pub fn empty() -> Self { - Self { - values: BTreeMap::new(), - executable_blocks: BitSet::new(0), - } - } - - /// Gets the value of an SSA variable. - #[must_use] - pub fn get_value(&self, var: SsaVarId) -> Option<&ScalarValue> { - self.values.get(&var) - } - - /// Returns `true` if a variable is known to be constant. - #[must_use] - pub fn is_constant(&self, var: SsaVarId) -> bool { - self.values - .get(&var) - .is_some_and(|v| matches!(v, ScalarValue::Constant(_))) - } - - /// Returns the constant value of a variable if known. - #[must_use] - pub fn constant_value(&self, var: SsaVarId) -> Option<&ConstValue> { - self.values.get(&var).and_then(|v| match v { - ScalarValue::Constant(c) => Some(c), - _ => None, - }) - } - - /// Returns `true` if a block is executable (reachable). - #[must_use] - pub fn is_block_executable(&self, block: usize) -> bool { - self.executable_blocks.contains(block) - } - - /// Returns an iterator over all constant variables. - pub fn constants(&self) -> impl Iterator { - self.values.iter().filter_map(|(var, val)| match val { - ScalarValue::Constant(c) => Some((*var, c)), - _ => None, - }) - } - - /// Returns an iterator over all executable blocks. - pub fn executable_blocks(&self) -> impl Iterator + '_ { - self.executable_blocks.iter() - } - - /// Returns the number of variables found to be constant. - #[must_use] - pub fn constant_count(&self) -> usize { - self.values - .values() - .filter(|v| matches!(v, ScalarValue::Constant(_))) - .count() - } - - /// Returns the number of executable blocks. - #[must_use] - pub fn executable_block_count(&self) -> usize { - self.executable_blocks.count() - } -} - -#[cfg(test)] -mod tests { - use std::collections::BTreeMap; - - use super::*; - - use crate::analysis::{dataflow::lattice::MeetSemiLattice, ConstValue, SsaVarId}; - - #[test] - fn test_scalar_value_meet() { - // Top meets anything yields the other - assert_eq!( - ScalarValue::Top.meet(&ScalarValue::Constant(ConstValue::I32(5))), - ScalarValue::Constant(ConstValue::I32(5)) - ); - - // Same constants stay constant - assert_eq!( - ScalarValue::Constant(ConstValue::I32(5)) - .meet(&ScalarValue::Constant(ConstValue::I32(5))), - ScalarValue::Constant(ConstValue::I32(5)) - ); - - // Different constants become bottom - assert_eq!( - ScalarValue::Constant(ConstValue::I32(5)) - .meet(&ScalarValue::Constant(ConstValue::I32(10))), - ScalarValue::Bottom - ); - - // Bottom meets anything yields bottom - assert_eq!( - ScalarValue::Bottom.meet(&ScalarValue::Constant(ConstValue::I32(5))), - ScalarValue::Bottom - ); - } - - #[test] - fn test_scalar_value_accessors() { - let top = ScalarValue::Top; - let const_val = ScalarValue::Constant(ConstValue::I32(42)); - let bottom = ScalarValue::Bottom; - - assert!(top.is_top()); - assert!(!top.is_constant()); - assert!(!top.is_bottom()); - - assert!(!const_val.is_top()); - assert!(const_val.is_constant()); - assert!(!const_val.is_bottom()); - assert_eq!(const_val.as_constant(), Some(&ConstValue::I32(42))); - - assert!(!bottom.is_top()); - assert!(!bottom.is_constant()); - assert!(bottom.is_bottom()); - } - - #[test] - fn test_sccp_result() { - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let v2 = SsaVarId::from_index(2); - - let mut values = BTreeMap::new(); - values.insert(v0, ScalarValue::Constant(ConstValue::I32(42))); - values.insert(v1, ScalarValue::Bottom); - values.insert(v2, ScalarValue::Top); - - let mut executable_blocks = BitSet::new(3); - executable_blocks.insert(0); - executable_blocks.insert(1); - - let result = SccpResult { - values, - executable_blocks, - }; - - assert!(result.is_constant(v0)); - assert!(!result.is_constant(v1)); - assert!(!result.is_constant(v2)); - - assert_eq!(result.constant_value(v0), Some(&ConstValue::I32(42))); - assert_eq!(result.constant_value(v1), None); - - assert!(result.is_block_executable(0)); - assert!(result.is_block_executable(1)); - assert!(!result.is_block_executable(2)); - - assert_eq!(result.constant_count(), 1); - assert_eq!(result.executable_block_count(), 2); - } -} diff --git a/dotscope/src/analysis/dataflow/solver.rs b/dotscope/src/analysis/dataflow/solver.rs deleted file mode 100644 index c3349c60..00000000 --- a/dotscope/src/analysis/dataflow/solver.rs +++ /dev/null @@ -1,403 +0,0 @@ -//! Worklist-based data flow solver. -//! -//! This module provides the iterative solver that computes fixpoints for -//! data flow analyses. It uses a worklist algorithm with reverse postorder -//! traversal for efficiency. -//! -//! # Algorithm -//! -//! The solver iterates until a fixpoint is reached: -//! -//! 1. Initialize all blocks with the initial value -//! 2. Set the boundary value at entry (forward) or exits (backward) -//! 3. Add all blocks to the worklist in reverse postorder -//! 4. While the worklist is non-empty: -//! a. Remove a block from the worklist -//! b. Compute the input by meeting values from predecessors/successors -//! c. Apply the transfer function to get the output -//! d. If the output changed, add affected blocks to the worklist -//! 5. Call the finalize hook for post-processing -//! -//! # Complexity -//! -//! For most analyses on reducible CFGs, the solver converges in O(n) iterations -//! where n is the number of blocks. The total work is O(n * h) where h is the -//! lattice height (number of times a value can decrease before hitting bottom). - -use std::collections::VecDeque; - -use crate::{ - analysis::{ - dataflow::{ - framework::{AnalysisResults, DataFlowAnalysis, DataFlowCfg, Direction}, - lattice::MeetSemiLattice, - }, - SsaFunction, - }, - utils::graph::NodeId, -}; - -/// Worklist-based data flow solver. -/// -/// This solver computes fixpoints for data flow analyses using an iterative -/// worklist algorithm. It supports both forward and backward analyses. -/// -/// # Usage -/// -/// ```rust,ignore -/// use dotscope::analysis::{DataFlowSolver, ReachingDefinitions}; -/// -/// let analysis = ReachingDefinitions::new(&ssa); -/// let mut solver = DataFlowSolver::new(analysis); -/// let results = solver.solve(&ssa, &graph); -/// -/// // Access results -/// let in_state = results.in_state(block_id); -/// ``` -pub struct DataFlowSolver { - /// The analysis being solved. - analysis: A, - /// Input state for each block. - in_states: Vec, - /// Output state for each block. - out_states: Vec, - /// Worklist of blocks to process. - worklist: VecDeque, - /// Whether each block is currently in the worklist (for deduplication). - in_worklist: Vec, - /// Number of iterations performed. - iterations: usize, -} - -impl DataFlowSolver { - /// Creates a new solver for the given analysis. - #[must_use] - pub fn new(analysis: A) -> Self { - Self { - analysis, - in_states: Vec::new(), - out_states: Vec::new(), - worklist: VecDeque::new(), - in_worklist: Vec::new(), - iterations: 0, - } - } - - /// Solves the data flow analysis to a fixpoint. - /// - /// Returns the analysis results containing input and output states - /// for each basic block. - pub fn solve( - mut self, - ssa: &SsaFunction, - cfg: &C, - ) -> AnalysisResults - where - A::Lattice: Clone, - { - let num_blocks = ssa.block_count(); - if num_blocks == 0 { - return AnalysisResults::new(Vec::new(), Vec::new()); - } - - // Initialize states - self.initialize(ssa, cfg); - - // Main iteration loop - self.iterate(ssa, cfg); - - // Finalize - self.analysis - .finalize(&self.in_states, &self.out_states, ssa); - - AnalysisResults::new(self.in_states, self.out_states) - } - - /// Returns the number of iterations performed. - #[must_use] - pub const fn iterations(&self) -> usize { - self.iterations - } - - /// Initializes the solver state. - fn initialize(&mut self, ssa: &SsaFunction, cfg: &C) - where - A::Lattice: Clone, - { - let num_blocks = ssa.block_count(); - let initial = self.analysis.initial(ssa); - let boundary = self.analysis.boundary(ssa); - - // Initialize all blocks with the initial value - self.in_states = vec![initial.clone(); num_blocks]; - self.out_states = vec![initial; num_blocks]; - self.in_worklist = vec![false; num_blocks]; - - // Set boundary conditions based on direction - match A::DIRECTION { - Direction::Forward => { - // Entry block gets boundary value - let entry = cfg.entry().index(); - if let Some(slot) = self.in_states.get_mut(entry) { - *slot = boundary; - } - } - Direction::Backward => { - // Exit blocks get boundary value - for exit in cfg.exits() { - let idx = exit.index(); - if let Some(slot) = self.out_states.get_mut(idx) { - *slot = boundary.clone(); - } - } - } - } - - // Add all blocks to worklist in appropriate order - let order = match A::DIRECTION { - Direction::Forward => cfg.reverse_postorder(), - Direction::Backward => cfg.postorder(), - }; - - for node in order { - let idx = node.index(); - if let Some(slot) = self.in_worklist.get_mut(idx) { - self.worklist.push_back(idx); - *slot = true; - } - } - } - - /// Main iteration loop. - fn iterate(&mut self, ssa: &SsaFunction, cfg: &C) - where - A::Lattice: Clone, - { - while let Some(block_idx) = self.worklist.pop_front() { - if let Some(slot) = self.in_worklist.get_mut(block_idx) { - *slot = false; - } - self.iterations = self.iterations.saturating_add(1); - - let changed = match A::DIRECTION { - Direction::Forward => self.process_forward(block_idx, ssa, cfg), - Direction::Backward => self.process_backward(block_idx, ssa, cfg), - }; - - if changed { - // Add affected blocks to worklist - self.add_affected_to_worklist(block_idx, cfg); - } - } - } - - /// Processes a block in forward direction. - /// - /// Returns `true` if the output state changed. - fn process_forward( - &mut self, - block_idx: usize, - ssa: &SsaFunction, - cfg: &C, - ) -> bool - where - A::Lattice: Clone, - { - // Compute input by meeting all predecessor outputs - let node = NodeId::new(block_idx); - let Some(current_in) = self.in_states.get(block_idx).cloned() else { - return false; - }; - let mut input = if cfg.predecessors(node).next().is_none() { - // Entry block or unreachable - keep current in_state - current_in.clone() - } else { - // Meet all predecessor outputs - let mut result: Option = None; - for pred in cfg.predecessors(node) { - let Some(pred_out) = self.out_states.get(pred.index()) else { - continue; - }; - result = Some(match result { - None => pred_out.clone(), - Some(acc) => acc.meet(pred_out), - }); - } - result.unwrap_or_else(|| current_in.clone()) - }; - - // Special case: entry block keeps its boundary value - if node == cfg.entry() { - input = current_in.clone(); - } - - if let Some(slot) = self.in_states.get_mut(block_idx) { - *slot = input.clone(); - } - - // Apply transfer function - let Some(block) = ssa.block(block_idx) else { - return false; - }; - let output = self.analysis.transfer(block_idx, block, &input, ssa); - - // Check if output changed - let Some(out_slot) = self.out_states.get_mut(block_idx) else { - return false; - }; - let changed = output != *out_slot; - *out_slot = output; - - changed - } - - /// Processes a block in backward direction. - /// - /// Returns `true` if the input state changed. - fn process_backward( - &mut self, - block_idx: usize, - ssa: &SsaFunction, - cfg: &C, - ) -> bool - where - A::Lattice: Clone, - { - // Compute output by meeting all successor inputs - let node = NodeId::new(block_idx); - let Some(current_out) = self.out_states.get(block_idx).cloned() else { - return false; - }; - let mut output = if cfg.successors(node).next().is_none() { - // Exit block or dead end - keep current out_state - current_out.clone() - } else { - // Meet all successor inputs - let mut result: Option = None; - for succ in cfg.successors(node) { - let Some(succ_in) = self.in_states.get(succ.index()) else { - continue; - }; - result = Some(match result { - None => succ_in.clone(), - Some(acc) => acc.meet(succ_in), - }); - } - result.unwrap_or_else(|| current_out.clone()) - }; - - // Special case: exit blocks keep their boundary value - if cfg.exits().contains(&node) { - output = current_out.clone(); - } - - if let Some(slot) = self.out_states.get_mut(block_idx) { - *slot = output.clone(); - } - - // Apply transfer function (backward: input = transfer(output)) - let Some(block) = ssa.block(block_idx) else { - return false; - }; - let input = self.analysis.transfer(block_idx, block, &output, ssa); - - // Check if input changed - let Some(in_slot) = self.in_states.get_mut(block_idx) else { - return false; - }; - let changed = input != *in_slot; - *in_slot = input; - - changed - } - - /// Adds affected blocks to the worklist after a change. - fn add_affected_to_worklist(&mut self, block_idx: usize, cfg: &C) { - let node = NodeId::new(block_idx); - - let enqueue = |idx: usize, list: &mut Vec, work: &mut VecDeque| { - if let Some(slot) = list.get_mut(idx) { - if !*slot { - work.push_back(idx); - *slot = true; - } - } - }; - - match A::DIRECTION { - Direction::Forward => { - // Forward: successors are affected - for succ in cfg.successors(node) { - enqueue(succ.index(), &mut self.in_worklist, &mut self.worklist); - } - } - Direction::Backward => { - // Backward: predecessors are affected - for pred in cfg.predecessors(node) { - enqueue(pred.index(), &mut self.in_worklist, &mut self.worklist); - } - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::analysis::SsaBlock; - - /// A simple constant lattice for testing. - #[derive(Debug, Clone, PartialEq)] - enum TestLattice { - Top, - Value(i32), - Bottom, - } - - impl MeetSemiLattice for TestLattice { - fn meet(&self, other: &Self) -> Self { - match (self, other) { - (Self::Top, x) | (x, Self::Top) => x.clone(), - (Self::Value(a), Self::Value(b)) if a == b => Self::Value(*a), - _ => Self::Bottom, - } - } - - fn is_bottom(&self) -> bool { - matches!(self, Self::Bottom) - } - } - - /// A trivial analysis that just propagates values unchanged. - struct TrivialAnalysis; - - impl DataFlowAnalysis for TrivialAnalysis { - type Lattice = TestLattice; - const DIRECTION: Direction = Direction::Forward; - - fn boundary(&self, _ssa: &SsaFunction) -> Self::Lattice { - TestLattice::Value(42) - } - - fn initial(&self, _ssa: &SsaFunction) -> Self::Lattice { - TestLattice::Top - } - - fn transfer( - &self, - _block_id: usize, - _block: &SsaBlock, - input: &Self::Lattice, - _ssa: &SsaFunction, - ) -> Self::Lattice { - input.clone() - } - } - - #[test] - fn test_solver_iterations() { - // This is a basic sanity test - full integration tests are elsewhere - let solver = DataFlowSolver::new(TrivialAnalysis); - assert_eq!(solver.iterations(), 0); - } -} diff --git a/dotscope/src/analysis/defuse.rs b/dotscope/src/analysis/defuse.rs deleted file mode 100644 index 1f2eb19a..00000000 --- a/dotscope/src/analysis/defuse.rs +++ /dev/null @@ -1,979 +0,0 @@ -//! Def-Use Index for efficient SSA variable lookup. -//! -//! This module provides [`DefUseIndex`], a shared index structure that enables -//! efficient queries about variable definitions and uses across an SSA function. -//! -//! # Purpose -//! -//! While each `SsaVariable` tracks its own definition site and use sites, this -//! index provides additional views: -//! -//! - **Definitions by location**: What variables are defined in block B at instruction I? -//! - **Uses by location**: What variables are used in block B at instruction I? -//! - **All definitions in block**: All variables defined in block B -//! - **Unused variables**: Variables with no uses (candidates for elimination) -//! - **Defining operations**: What operation defines variable V? (with `build_with_ops`) -//! -//! # Basic Usage -//! -//! ```rust,ignore -//! use dotscope::analysis::{DefUseIndex, SsaFunction}; -//! -//! let ssa: SsaFunction = /* ... */; -//! let index = DefUseIndex::build(&ssa); -//! -//! // Find all uses of variable v0 -//! if let Some(uses) = index.uses_of(v0) { -//! for use_site in uses { -//! println!("v0 used at block {}, instr {}", use_site.block, use_site.instruction); -//! } -//! } -//! -//! // Find variables defined at a specific instruction -//! for var_id in index.defs_at(block_idx, instr_idx) { -//! println!("Variable {} defined here", var_id); -//! } -//! -//! // Check if a variable is dead (unused) -//! if index.is_unused(var_id) { -//! println!("Variable {} can be eliminated", var_id); -//! } -//! ``` -//! -//! # Building with Operations -//! -//! For passes that need to analyze the defining operation (e.g., constant folding, -//! pattern matching), use [`DefUseIndex::build_with_ops`]: -//! -//! ```rust,ignore -//! let index = DefUseIndex::build_with_ops(&ssa); -//! -//! // Get the defining operation for a variable -//! if let Some(op) = index.def_op(var_id) { -//! match op { -//! SsaOp::Add { left, right, .. } => { /* analyze operands */ } -//! SsaOp::Const { value, .. } => { /* it's a constant */ } -//! _ => {} -//! } -//! } -//! -//! // Or get everything at once: (block, instruction, operation) -//! if let Some((block, instr, op)) = index.full_definition(var_id) { -//! println!("Defined at B{}:{} by {:?}", block, instr, op); -//! } -//! ``` - -use std::collections::{BTreeMap, BTreeSet}; - -use crate::{ - analysis::ssa::{DefSite, SsaFunction, SsaOp, SsaVarId, UseSite}, - utils::BitSet, -}; - -/// Location in the SSA function (block + instruction). -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub struct Location { - /// Block index. - pub block: usize, - /// Instruction index within the block. - pub instruction: usize, -} - -impl Location { - /// Creates a new location. - #[must_use] - pub const fn new(block: usize, instruction: usize) -> Self { - Self { block, instruction } - } -} - -/// Index for efficient def-use queries on an SSA function. -/// -/// This structure is built once from an `SsaFunction` and provides O(1) or O(k) -/// access to various def-use relationships (where k is the result size). -/// -/// # Building with Operations -/// -/// Use [`build_with_ops`](Self::build_with_ops) to also index the defining operations, -/// enabling efficient lookups via [`def_op`](Self::def_op) and -/// [`full_definition`](Self::full_definition). -/// -/// ```rust,ignore -/// let index = DefUseIndex::build_with_ops(&ssa); -/// -/// // Get block, instruction, and operation in one call -/// if let Some((block, instr, op)) = index.full_definition(var_id) { -/// println!("Defined at B{}:{} by {:?}", block, instr, op); -/// } -/// ``` -#[derive(Debug, Clone, Default)] -pub struct DefUseIndex { - /// Map from variable ID to its definition site. - definitions: BTreeMap, - - /// Map from variable ID to its use sites. - uses: BTreeMap>, - - /// Map from location to variables defined there. - /// Key: (block_idx, instr_idx), Value: variables defined at that instruction. - defs_at_location: BTreeMap>, - - /// Map from location to variables used there. - /// Key: (block_idx, instr_idx), Value: variables used at that instruction. - uses_at_location: BTreeMap>, - - /// Variables defined in each block (including phi nodes). - defs_in_block: BTreeMap>, - - /// Variables defined by phi nodes. - phi_defs: BitSet, - - /// Variables with no uses (dead variables). - unused_vars: BitSet, - - /// Total variable count. - var_count: usize, - - /// Optional: defining operations for each variable. - /// Populated when built with [`build_with_ops`](Self::build_with_ops). - def_ops: Option>, -} - -impl DefUseIndex { - /// Builds a def-use index from an SSA function. - /// - /// This is an O(n) operation where n is the total number of instructions - /// and phi nodes in the function. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to index. - /// - /// # Returns - /// - /// A new `DefUseIndex` with all relationships computed. - #[must_use] - pub fn build(ssa: &SsaFunction) -> Self { - let mut definitions = BTreeMap::new(); - let mut uses: BTreeMap> = BTreeMap::new(); - let mut defs_at_location: BTreeMap> = BTreeMap::new(); - let mut uses_at_location: BTreeMap> = BTreeMap::new(); - let mut defs_in_block: BTreeMap> = BTreeMap::new(); - let variable_count = ssa.variable_count(); - let max_var_idx = ssa - .variables() - .iter() - .map(|v| v.id().index().saturating_add(1)) - .max() - .unwrap_or(0); - let bitset_capacity = max_var_idx.max(variable_count); - let mut phi_defs = BitSet::new(bitset_capacity); - - // Collect from SsaVariables (the authoritative source) - for var in ssa.variables() { - let var_id = var.id(); - let def_site = var.def_site(); - - definitions.insert(var_id, def_site); - - // Track phi definitions - if def_site.is_phi() { - phi_defs.insert(var_id.index()); - } - - // Track definitions by location - if let Some(instr_idx) = def_site.instruction { - let loc = Location::new(def_site.block, instr_idx); - defs_at_location.entry(loc).or_default().push(var_id); - } - defs_in_block - .entry(def_site.block) - .or_default() - .push(var_id); - - // Collect uses from the variable - let var_uses: Vec = var.uses().to_vec(); - for use_site in &var_uses { - let loc = Location::new(use_site.block, use_site.instruction); - uses_at_location.entry(loc).or_default().push(var_id); - } - uses.insert(var_id, var_uses); - } - - // Identify unused variables - let mut unused_vars = BitSet::new(bitset_capacity); - for (var_id, use_sites) in &uses { - if use_sites.is_empty() { - unused_vars.insert(var_id.index()); - } - } - - Self { - definitions, - uses, - defs_at_location, - uses_at_location, - defs_in_block, - phi_defs, - unused_vars, - var_count: variable_count, - def_ops: None, - } - } - - /// Builds a def-use index with defining operations stored internally. - /// - /// This version indexes the defining operation for each variable, enabling - /// efficient lookups via [`def_op`](Self::def_op) and - /// [`full_definition`](Self::full_definition). - /// - /// Use this when passes need to analyze the defining operation alongside - /// the definition site. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to index. - /// - /// # Returns - /// - /// A `DefUseIndex` with operations indexed. - /// - /// # Example - /// - /// ```rust,ignore - /// let index = DefUseIndex::build_with_ops(&ssa); - /// - /// // Get the defining operation - /// if let Some(op) = index.def_op(var_id) { - /// match op { - /// SsaOp::Add { left, right, .. } => { /* analyze add */ } - /// _ => {} - /// } - /// } - /// - /// // Or get everything at once - /// if let Some((block, instr, op)) = index.full_definition(var_id) { - /// println!("B{}:{} {:?}", block, instr, op); - /// } - /// ``` - #[must_use] - pub fn build_with_ops(ssa: &SsaFunction) -> Self { - let mut index = Self::build(ssa); - - // Collect defining operations - let mut def_ops = BTreeMap::new(); - for (_block_idx, _instr_idx, instr) in ssa.iter_instructions() { - let op = instr.op(); - if let Some(dest) = op.dest() { - def_ops.insert(dest, op.clone()); - } - } - index.def_ops = Some(def_ops); - - index - } - - /// Builds a def-use index with operations, also returning a separate map. - /// - /// This is a compatibility method for code that needs both the index - /// and a separate operation map. Prefer [`build_with_ops`](Self::build_with_ops) - /// for new code. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to index. - /// - /// # Returns - /// - /// A tuple of (`DefUseIndex`, operation map). - #[must_use] - pub fn build_with_ops_map(ssa: &SsaFunction) -> (Self, BTreeMap) { - let index = Self::build_with_ops(ssa); - let ops = index.def_ops.clone().unwrap_or_default(); - (index, ops) - } - - /// Returns whether this index has operation information. - /// - /// Returns `true` if built with [`build_with_ops`](Self::build_with_ops). - #[must_use] - pub fn has_ops(&self) -> bool { - self.def_ops.is_some() - } - - /// Returns the defining operation for a variable. - /// - /// This method requires the index to be built with - /// [`build_with_ops`](Self::build_with_ops). - /// - /// # Arguments - /// - /// * `var` - The variable ID to look up. - /// - /// # Returns - /// - /// The defining operation, or `None` if: - /// - The variable is unknown - /// - The variable is defined by a phi node (no operation) - /// - The index was not built with operations - #[must_use] - pub fn def_op(&self, var: SsaVarId) -> Option<&SsaOp> { - self.def_ops.as_ref()?.get(&var) - } - - /// Returns full definition information: block, instruction index, and operation. - /// - /// This is a convenience method for passes that need all three pieces of - /// information together. Requires the index to be built with - /// [`build_with_ops`](Self::build_with_ops). - /// - /// # Arguments - /// - /// * `var` - The variable ID to look up. - /// - /// # Returns - /// - /// A tuple of `(block_index, instruction_index, operation)`, or `None` if: - /// - The variable is unknown - /// - The variable is defined by a phi node (no instruction index) - /// - The index was not built with operations - /// - /// # Example - /// - /// ```rust,ignore - /// let index = DefUseIndex::build_with_ops(&ssa); - /// - /// if let Some((block, instr, op)) = index.full_definition(var_id) { - /// // Check if this is an add of two constants - /// if let SsaOp::Add { left, right, .. } = op { - /// let left_const = index.def_op(*left); - /// let right_const = index.def_op(*right); - /// // ... - /// } - /// } - /// ``` - #[must_use] - pub fn full_definition(&self, var: SsaVarId) -> Option<(usize, usize, &SsaOp)> { - let site = self.def_site(var)?; - let instr = site.instruction?; // None for phi nodes - let op = self.def_op(var)?; - Some((site.block, instr, op)) - } - - /// Returns the definition site for a variable. - /// - /// # Arguments - /// - /// * `var` - The variable ID to look up. - /// - /// # Returns - /// - /// The definition site, or `None` if the variable is unknown. - #[must_use] - pub fn def_site(&self, var: SsaVarId) -> Option { - self.definitions.get(&var).copied() - } - - /// Returns all use sites for a variable. - /// - /// # Arguments - /// - /// * `var` - The variable ID to look up. - /// - /// # Returns - /// - /// A slice of use sites, or `None` if the variable is unknown. - #[must_use] - pub fn uses_of(&self, var: SsaVarId) -> Option<&[UseSite]> { - self.uses.get(&var).map(Vec::as_slice) - } - - /// Returns the number of uses for a variable. - /// - /// # Arguments - /// - /// * `var` - The variable ID to count uses for. - /// - /// # Returns - /// - /// The use count, or 0 if the variable is unknown. - #[must_use] - pub fn use_count(&self, var: SsaVarId) -> usize { - self.uses.get(&var).map_or(0, Vec::len) - } - - /// Checks if a variable has any uses. - /// - /// # Arguments - /// - /// * `var` - The variable ID to check. - /// - /// # Returns - /// - /// `true` if the variable has at least one use. - #[must_use] - pub fn has_uses(&self, var: SsaVarId) -> bool { - self.uses.get(&var).is_some_and(|u| !u.is_empty()) - } - - /// Checks if a variable is unused (dead). - /// - /// # Arguments - /// - /// * `var` - The variable ID to check. - /// - /// # Returns - /// - /// `true` if the variable has no uses. - #[must_use] - pub fn is_unused(&self, var: SsaVarId) -> bool { - var.index() < self.unused_vars.len() && self.unused_vars.contains(var.index()) - } - - /// Checks if a variable is defined by a phi node. - /// - /// # Arguments - /// - /// * `var` - The variable ID to check. - /// - /// # Returns - /// - /// `true` if the variable is defined by a phi node. - #[must_use] - pub fn is_phi_def(&self, var: SsaVarId) -> bool { - var.index() < self.phi_defs.len() && self.phi_defs.contains(var.index()) - } - - /// Returns variables defined at a specific location. - /// - /// # Arguments - /// - /// * `block` - The block index. - /// * `instruction` - The instruction index within the block. - /// - /// # Returns - /// - /// A slice of variable IDs defined at that location. - #[must_use] - pub fn defs_at(&self, block: usize, instruction: usize) -> &[SsaVarId] { - let loc = Location::new(block, instruction); - self.defs_at_location.get(&loc).map_or(&[], Vec::as_slice) - } - - /// Returns variables used at a specific location. - /// - /// # Arguments - /// - /// * `block` - The block index. - /// * `instruction` - The instruction index within the block. - /// - /// # Returns - /// - /// A slice of variable IDs used at that location. - #[must_use] - pub fn uses_at(&self, block: usize, instruction: usize) -> &[SsaVarId] { - let loc = Location::new(block, instruction); - self.uses_at_location.get(&loc).map_or(&[], Vec::as_slice) - } - - /// Returns all variables defined in a block. - /// - /// This includes both phi node definitions and instruction definitions. - /// - /// # Arguments - /// - /// * `block` - The block index. - /// - /// # Returns - /// - /// A slice of variable IDs defined in the block. - #[must_use] - pub fn defs_in_block(&self, block: usize) -> &[SsaVarId] { - self.defs_in_block.get(&block).map_or(&[], Vec::as_slice) - } - - /// Returns all unused (dead) variables. - /// - /// These are candidates for dead code elimination. - /// - /// # Returns - /// - /// A reference to the set of unused variable IDs. - #[must_use] - pub fn unused_variables(&self) -> &BitSet { - &self.unused_vars - } - - /// Returns all phi-defined variables. - /// - /// # Returns - /// - /// A reference to the set of phi-defined variable IDs. - #[must_use] - pub fn phi_definitions(&self) -> &BitSet { - &self.phi_defs - } - - /// Returns the total number of variables indexed. - #[must_use] - pub fn variable_count(&self) -> usize { - self.var_count - } - - /// Returns the number of unused variables. - #[must_use] - pub fn unused_count(&self) -> usize { - self.unused_vars.count() - } - - /// Checks if a variable has a single use. - /// - /// Single-use variables are good candidates for inlining. - /// - /// # Arguments - /// - /// * `var` - The variable ID to check. - /// - /// # Returns - /// - /// `true` if the variable has exactly one use. - #[must_use] - pub fn is_single_use(&self, var: SsaVarId) -> bool { - self.use_count(var) == 1 - } - - /// Checks if a variable is only used in phi nodes. - /// - /// # Arguments - /// - /// * `var` - The variable ID to check. - /// - /// # Returns - /// - /// `true` if all uses are phi node operands. - #[must_use] - pub fn only_used_in_phis(&self, var: SsaVarId) -> bool { - self.uses - .get(&var) - .is_some_and(|uses| !uses.is_empty() && uses.iter().all(|u| u.is_phi_operand)) - } - - /// Returns all variables used in a block. - /// - /// This is computed by scanning all use locations in the block. - /// - /// # Arguments - /// - /// * `block` - The block index. - /// - /// # Returns - /// - /// A set of variable IDs used anywhere in the block. - #[must_use] - pub fn uses_in_block(&self, block: usize) -> BTreeSet { - let mut result = BTreeSet::new(); - let start = Location::new(block, 0); - let end = Location::new(block, usize::MAX); - for (_loc, vars) in self.uses_at_location.range(start..=end) { - result.extend(vars.iter().copied()); - } - result - } - - /// Finds the unique use site if a variable has exactly one use. - /// - /// # Arguments - /// - /// * `var` - The variable ID to check. - /// - /// # Returns - /// - /// The single use site, or `None` if the variable has zero or multiple uses. - #[must_use] - pub fn single_use_site(&self, var: SsaVarId) -> Option { - self.uses.get(&var).and_then(|uses| { - if uses.len() == 1 { - uses.first().copied() - } else { - None - } - }) - } -} - -#[cfg(test)] -mod tests { - use crate::analysis::ssa::{ - ConstValue, DefSite, SsaBlock, SsaFunction, SsaInstruction, SsaOp, SsaType, SsaVarId, - SsaVariable, UseSite, VariableOrigin, - }; - - use super::DefUseIndex; - - /// Helper to create test SSA and return variable IDs for assertions - fn make_test_ssa() -> (SsaFunction, SsaVarId, SsaVarId) { - // Create a simple SSA function: - // Block 0: - // v0 = const 42 - // v1 = add v0, v0 - // ret v1 - let mut ssa = SsaFunction::new(0, 0); - - // Create variables first to get their IDs - let mut v0 = SsaVariable::new( - SsaVarId::from_index(0), - VariableOrigin::Local(0), - 0, - DefSite::instruction(0, 0), - SsaType::Unknown, - ); - let id0 = v0.id(); - v0.add_use(UseSite::instruction(0, 1)); - v0.add_use(UseSite::instruction(0, 1)); // Used twice in add - ssa.variables_mut().push(v0); - - let mut v1 = SsaVariable::new( - SsaVarId::from_index(1), - VariableOrigin::Local(1), - 0, - DefSite::instruction(0, 1), - SsaType::Unknown, - ); - let id1 = v1.id(); - v1.add_use(UseSite::instruction(0, 2)); - ssa.variables_mut().push(v1); - - // Now create block with instructions using the auto-allocated IDs - let mut block = SsaBlock::new(0); - - // v0 = const 42 - block.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: id0, - value: ConstValue::I32(42), - })); - - // v1 = add v0, v0 - block.add_instruction(SsaInstruction::synthetic(SsaOp::Add { - dest: id1, - left: id0, - right: id0, - })); - - // ret v1 - block.add_instruction(SsaInstruction::synthetic(SsaOp::Return { - value: Some(id1), - })); - - ssa.add_block(block); - - (ssa, id0, id1) - } - - #[test] - fn test_build_index() { - let (ssa, _id0, _id1) = make_test_ssa(); - let index = DefUseIndex::build(&ssa); - - assert_eq!(index.variable_count(), 2); - } - - #[test] - fn test_def_site_lookup() { - let (ssa, id0, id1) = make_test_ssa(); - let index = DefUseIndex::build(&ssa); - - let def0 = index.def_site(id0).unwrap(); - assert_eq!(def0.block, 0); - assert_eq!(def0.instruction, Some(0)); - - let def1 = index.def_site(id1).unwrap(); - assert_eq!(def1.block, 0); - assert_eq!(def1.instruction, Some(1)); - } - - #[test] - fn test_uses_of() { - let (ssa, id0, id1) = make_test_ssa(); - let index = DefUseIndex::build(&ssa); - - // v0 is used twice - let uses0 = index.uses_of(id0).unwrap(); - assert_eq!(uses0.len(), 2); - - // v1 is used once - let uses1 = index.uses_of(id1).unwrap(); - assert_eq!(uses1.len(), 1); - } - - #[test] - fn test_use_count() { - let (ssa, id0, id1) = make_test_ssa(); - let index = DefUseIndex::build(&ssa); - - assert_eq!(index.use_count(id0), 2); - assert_eq!(index.use_count(id1), 1); - assert_eq!(index.use_count(SsaVarId::from_index(999999)), 0); // Unknown var - } - - #[test] - fn test_defs_at_location() { - let (ssa, id0, id1) = make_test_ssa(); - let index = DefUseIndex::build(&ssa); - - let defs_0_0 = index.defs_at(0, 0); - assert_eq!(defs_0_0.len(), 1); - assert!(defs_0_0.contains(&id0)); - - let defs_0_1 = index.defs_at(0, 1); - assert_eq!(defs_0_1.len(), 1); - assert!(defs_0_1.contains(&id1)); - - // No defs at ret instruction - let defs_0_2 = index.defs_at(0, 2); - assert!(defs_0_2.is_empty()); - } - - #[test] - fn test_uses_at_location() { - let (ssa, id0, id1) = make_test_ssa(); - let index = DefUseIndex::build(&ssa); - - // v0 used at instruction 1 - let uses_0_1 = index.uses_at(0, 1); - assert!(uses_0_1.contains(&id0)); - - // v1 used at instruction 2 - let uses_0_2 = index.uses_at(0, 2); - assert!(uses_0_2.contains(&id1)); - } - - #[test] - fn test_defs_in_block() { - let (ssa, id0, id1) = make_test_ssa(); - let index = DefUseIndex::build(&ssa); - - let defs = index.defs_in_block(0); - assert_eq!(defs.len(), 2); - assert!(defs.contains(&id0)); - assert!(defs.contains(&id1)); - } - - #[test] - fn test_unused_variables() { - // Create SSA with an unused variable - let mut ssa = SsaFunction::new(0, 0); - let mut block = SsaBlock::new(0); - - let dest0 = SsaVarId::from_index(0); - let dest1 = SsaVarId::from_index(1); - block.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: dest0, - value: ConstValue::I32(42), - })); - block.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: dest1, - value: ConstValue::I32(0), - })); - block.add_instruction(SsaInstruction::synthetic(SsaOp::Return { - value: Some(dest1), - })); - - ssa.add_block(block); - - // v0: defined but never used - let v0 = SsaVariable::new( - SsaVarId::from_index(2), - VariableOrigin::Local(0), - 0, - DefSite::instruction(0, 0), - SsaType::Unknown, - ); - let v0_id = v0.id(); - ssa.variables_mut().push(v0); - - // v1: defined and used - let mut v1 = SsaVariable::new( - SsaVarId::from_index(3), - VariableOrigin::Local(1), - 0, - DefSite::instruction(0, 1), - SsaType::Unknown, - ); - let v1_id = v1.id(); - v1.add_use(UseSite::instruction(0, 2)); - ssa.variables_mut().push(v1); - - let index = DefUseIndex::build(&ssa); - - assert!(index.is_unused(v0_id)); - assert!(!index.is_unused(v1_id)); - assert_eq!(index.unused_count(), 1); - assert!(index.unused_variables().contains(v0_id.index())); - } - - #[test] - fn test_single_use() { - let (ssa, v0_id, v1_id) = make_test_ssa(); - let index = DefUseIndex::build(&ssa); - - // v0 has 2 uses -> not single use - assert!(!index.is_single_use(v0_id)); - - // v1 has 1 use -> single use - assert!(index.is_single_use(v1_id)); - - // Get the single use site - let use_site = index.single_use_site(v1_id).unwrap(); - assert_eq!(use_site.block, 0); - assert_eq!(use_site.instruction, 2); - - // v0 doesn't have single use site - assert!(index.single_use_site(v0_id).is_none()); - } - - #[test] - fn test_phi_definitions() { - let mut ssa = SsaFunction::new(0, 0); - let block = SsaBlock::new(0); - ssa.add_block(block); - - // v0: phi definition - let v0 = SsaVariable::new( - SsaVarId::from_index(0), - VariableOrigin::Phi, - 0, - DefSite::phi(0), - SsaType::Unknown, - ); - let v0_id = v0.id(); - ssa.variables_mut().push(v0); - - // v1: instruction definition - let v1 = SsaVariable::new( - SsaVarId::from_index(1), - VariableOrigin::Local(0), - 0, - DefSite::instruction(0, 0), - SsaType::Unknown, - ); - let v1_id = v1.id(); - ssa.variables_mut().push(v1); - - let index = DefUseIndex::build(&ssa); - - assert!(index.is_phi_def(v0_id)); - assert!(!index.is_phi_def(v1_id)); - assert!(index.phi_definitions().contains(v0_id.index())); - } - - #[test] - fn test_uses_in_block() { - let (ssa, v0_id, v1_id) = make_test_ssa(); - let index = DefUseIndex::build(&ssa); - - let uses = index.uses_in_block(0); - assert!(uses.contains(&v0_id)); - assert!(uses.contains(&v1_id)); - } - - #[test] - fn test_default() { - let index = DefUseIndex::default(); - assert_eq!(index.variable_count(), 0); - assert_eq!(index.unused_count(), 0); - } - - #[test] - fn test_build_without_ops() { - let (ssa, id0, _id1) = make_test_ssa(); - let index = DefUseIndex::build(&ssa); - - // Index built without ops should not have operations - assert!(!index.has_ops()); - assert!(index.def_op(id0).is_none()); - assert!(index.full_definition(id0).is_none()); - } - - #[test] - fn test_build_with_ops() { - let (ssa, id0, id1) = make_test_ssa(); - let index = DefUseIndex::build_with_ops(&ssa); - - // Index built with ops should have operations - assert!(index.has_ops()); - - // Check v0's defining operation (const 42) - let op0 = index.def_op(id0).unwrap(); - assert!(matches!(op0, SsaOp::Const { value, .. } if value.as_i32() == Some(42))); - - // Check v1's defining operation (add) - let op1 = index.def_op(id1).unwrap(); - assert!(matches!(op1, SsaOp::Add { .. })); - } - - #[test] - fn test_full_definition() { - let (ssa, id0, id1) = make_test_ssa(); - let index = DefUseIndex::build_with_ops(&ssa); - - // v0 = const 42 at block 0, instruction 0 - let (block0, instr0, op0) = index.full_definition(id0).unwrap(); - assert_eq!(block0, 0); - assert_eq!(instr0, 0); - assert!(matches!(op0, SsaOp::Const { .. })); - - // v1 = add at block 0, instruction 1 - let (block1, instr1, op1) = index.full_definition(id1).unwrap(); - assert_eq!(block1, 0); - assert_eq!(instr1, 1); - assert!(matches!(op1, SsaOp::Add { .. })); - } - - #[test] - fn test_full_definition_phi_returns_none() { - let mut ssa = SsaFunction::new(0, 0); - let block = SsaBlock::new(0); - ssa.add_block(block); - - // v0: phi definition (no instruction index) - let v0 = SsaVariable::new( - SsaVarId::from_index(0), - VariableOrigin::Phi, - 0, - DefSite::phi(0), - SsaType::Unknown, - ); - let v0_id = v0.id(); - ssa.variables_mut().push(v0); - - let index = DefUseIndex::build_with_ops(&ssa); - - // Phi definitions should return None from full_definition - // (because there's no instruction index) - assert!(index.full_definition(v0_id).is_none()); - - // But def_site still works - let site = index.def_site(v0_id).unwrap(); - assert_eq!(site.block, 0); - assert!(site.instruction.is_none()); - } - - #[test] - fn test_build_with_ops_map_compatibility() { - let (ssa, id0, id1) = make_test_ssa(); - let (index, ops) = DefUseIndex::build_with_ops_map(&ssa); - - // The index should have ops internally - assert!(index.has_ops()); - - // The returned map should have the same ops - assert!(ops.contains_key(&id0)); - assert!(ops.contains_key(&id1)); - - // Both should match - let op0_from_index = index.def_op(id0).unwrap(); - let op0_from_map = ops.get(&id0).unwrap(); - assert!(matches!(op0_from_index, SsaOp::Const { .. })); - assert!(matches!(op0_from_map, SsaOp::Const { .. })); - } -} diff --git a/dotscope/src/analysis/mod.rs b/dotscope/src/analysis/mod.rs index e9ee59f4..960aadc9 100644 --- a/dotscope/src/analysis/mod.rs +++ b/dotscope/src/analysis/mod.rs @@ -69,12 +69,9 @@ //! } //! ``` -mod algebraic; mod callgraph; mod cfg; mod dataflow; -mod defuse; -mod range; mod ssa; mod taint; @@ -82,8 +79,6 @@ mod taint; mod x86; // Re-export primary public types at module level -pub use crate::utils::graph::NodeId; -pub use algebraic::{simplify_op, SimplifyResult}; pub use callgraph::{ CallGraph, CallGraphNode, CallGraphStats, CallResolver, CallSite, CallTarget, CallType, ResolverStats, @@ -93,18 +88,30 @@ pub use dataflow::{ AnalysisResults, ConstantPropagation, DataFlowAnalysis, DataFlowSolver, Direction, LiveVariables, LivenessResult, ReachingDefinitions, ScalarValue, SccpResult, }; -pub use defuse::{DefUseIndex, Location}; -pub use range::ValueRange; + +// Direct re-exports from analyssa for the formerly-shimmed analyses. +pub use analyssa::analysis::algebraic::simplify_op; +pub use analyssa::analysis::defuse::Location; + +/// CIL-defaulted alias of [`analyssa::analysis::algebraic::SimplifyResult`]. +pub type SimplifyResult = analyssa::analysis::algebraic::SimplifyResult; +/// CIL-defaulted alias of [`analyssa::analysis::defuse::DefUseIndex`]. +pub type DefUseIndex = analyssa::analysis::defuse::DefUseIndex; +/// CIL-defaulted alias of [`analyssa::analysis::range::ValueRange`]. +pub type ValueRange = analyssa::analysis::range::ValueRange; +pub use analyssa::graph::NodeId; #[cfg(feature = "z3")] pub use ssa::Z3Solver; pub use ssa::{ - resolve_corelib_valuetype, AbstractValue, BinaryOpKind, CmpKind, ConstValue, ControlFlow, - DefSite, FieldRef, MethodPurity, MethodRef, PhiAnalyzer, PhiNode, PhiOperand, ReturnInfo, - SsaBlock, SsaCfg, SsaConverter, SsaEvaluator, SsaExceptionHandler, SsaFunction, - SsaFunctionBuilder, SsaInstruction, SsaOp, SsaType, SsaVarId, SsaVariable, SymbolicEvaluator, - SymbolicExpr, SymbolicOp, TypeClass, TypeContext, TypeProvider, TypeRef, UnaryOpKind, UseSite, - ValueResolver, VariableOrigin, + resolve_corelib_valuetype, AbstractValue, BinaryOpKind, CilTarget, CmpKind, ConstEvaluator, + ConstValue, ConstValueCilExt, ControlFlow, DefSite, FieldRef, MethodPurity, MethodRef, + PhiAnalyzer, PhiNode, PhiOperand, ReturnInfo, SsaBlock, SsaCfg, SsaConverter, SsaEvaluator, + SsaExceptionHandler, SsaExceptionHandlerCilExt, SsaFunction, SsaFunctionBuilder, + SsaInstruction, SsaOp, SsaOpCilExt, SsaType, SsaVarId, SsaVariable, SymbolicEvaluator, + SymbolicExpr, SymbolicOp, Target, TypeClass, TypeContext, TypeProvider, TypeRef, UnaryOpKind, + UseSite, ValueResolver, VariableOrigin, }; +pub use ssa::{SsaFunctionCilExt, SsaFunctionSemanticsExt}; pub use taint::{ cff_taint_config, find_token_dependencies, PhiTaintMode, TaintAnalysis, TaintConfig, TokenTaintBuilder, @@ -112,6 +119,7 @@ pub use taint::{ // Re-export crate-internal types (used by other crate modules via crate::analysis::X) #[cfg(feature = "compiler")] +#[allow(unused_imports)] pub(crate) use cfg::SsaLoopAnalysis; #[cfg(feature = "x86")] @@ -124,10 +132,12 @@ pub use x86::{ #[cfg(test)] mod tests { + use analyssa::graph::NodeId; + use crate::{ analysis::{CfgEdgeKind, ControlFlowGraph, SsaConverter, SsaFunction, VariableOrigin}, assembly::{decode_blocks, InstructionAssembler}, - utils::graph::NodeId, + test::TestTypeProvider, }; /// Helper to build bytecode and decode it into a CFG. @@ -144,7 +154,7 @@ mod tests { cfg, num_args, num_locals, - &crate::test::TestTypeProvider::new(num_args, num_locals), + &TestTypeProvider::new(num_args, num_locals), ) .expect("SSA construction failed") } diff --git a/dotscope/src/analysis/range.rs b/dotscope/src/analysis/range.rs deleted file mode 100644 index faea5874..00000000 --- a/dotscope/src/analysis/range.rs +++ /dev/null @@ -1,1285 +0,0 @@ -//! Value range analysis for SSA variables. -//! -//! This module provides interval-based range analysis for tracking the possible -//! values of integer variables. It supports: -//! -//! - **Constant ranges**: `[5, 5]` — exact value known -//! - **Bounded ranges**: `[0, 255]` — value within bounds -//! - **Half-open ranges**: `[0, +∞)` — non-negative values -//! - **Union ranges**: `[0, 10] ∪ [20, 30]` — disjoint intervals -//! -//! # Lattice Structure -//! -//! The `ValueRange` forms a lattice for dataflow analysis: -//! -//! ```text -//! Top (all values) -//! | -//! [MIN, MAX] full range -//! / \ -//! [a, b] [c, d] bounded ranges -//! \ / -//! [x, x] constant (singleton) -//! | -//! Bottom (no values - unreachable) -//! ``` -//! -//! # Usage -//! -//! ```rust,no_run -//! use dotscope::analysis::ValueRange; -//! -//! // Create ranges -//! let constant = ValueRange::constant(42); -//! let non_negative = ValueRange::non_negative(); -//! let byte_range = ValueRange::bounded(0, 255); -//! -//! // Query ranges -//! assert!(constant.is_constant()); -//! assert!(non_negative.is_always_non_negative()); -//! assert!(byte_range.always_less_than(256) == Some(true)); -//! -//! // Range arithmetic -//! let sum = byte_range.add(&ValueRange::constant(1)); -//! assert_eq!(sum.min(), Some(1)); -//! assert_eq!(sum.max(), Some(256)); -//! ``` - -use std::cmp::{max, min}; -use std::fmt; - -/// A range of possible integer values for analysis. -/// -/// Represents the set of values a variable might hold at runtime. -/// Used for opaque predicate detection, bounds check elimination, -/// and general range-based optimization. -#[derive(Clone, PartialEq, Eq, Hash, Default)] -pub enum ValueRange { - /// No possible values (unreachable code). - Bottom, - - /// A single contiguous interval `[min, max]`. - Interval(IntervalRange), - - /// A union of disjoint intervals (for precision). - /// Intervals are sorted by min value and non-overlapping. - Union(Vec), - - /// All values possible (no information). - #[default] - Top, -} - -/// A single contiguous interval `[min, max]`. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub struct IntervalRange { - /// Minimum value (inclusive). `None` means negative infinity. - pub min: Option, - /// Maximum value (inclusive). `None` means positive infinity. - pub max: Option, -} - -impl IntervalRange { - /// Creates a new interval range. - #[must_use] - pub const fn new(min: Option, max: Option) -> Self { - Self { min, max } - } - - /// Creates a constant (singleton) interval. - #[must_use] - pub const fn constant(value: i64) -> Self { - Self { - min: Some(value), - max: Some(value), - } - } - - /// Creates a bounded interval `[min, max]`. - #[must_use] - pub const fn bounded(min: i64, max: i64) -> Self { - Self { - min: Some(min), - max: Some(max), - } - } - - /// Creates a non-negative interval `[0, +∞)`. - #[must_use] - pub const fn non_negative() -> Self { - Self { - min: Some(0), - max: None, - } - } - - /// Creates an interval from min to infinity `[min, +∞)`. - #[must_use] - pub const fn at_least(min: i64) -> Self { - Self { - min: Some(min), - max: None, - } - } - - /// Creates an interval from negative infinity to max `(-∞, max]`. - #[must_use] - pub const fn at_most(max: i64) -> Self { - Self { - min: None, - max: Some(max), - } - } - - /// Creates a full interval `(-∞, +∞)`. - #[must_use] - pub const fn full() -> Self { - Self { - min: None, - max: None, - } - } - - /// Returns `true` if this is a constant (singleton) range. - #[must_use] - pub fn is_constant(&self) -> bool { - matches!((self.min, self.max), (Some(a), Some(b)) if a == b) - } - - /// Returns the constant value if this is a singleton. - #[must_use] - pub fn as_constant(&self) -> Option { - match (self.min, self.max) { - (Some(a), Some(b)) if a == b => Some(a), - _ => None, - } - } - - /// Returns `true` if all values in this range are non-negative. - #[must_use] - pub fn is_always_non_negative(&self) -> bool { - self.min.is_some_and(|m| m >= 0) - } - - /// Returns `true` if all values in this range are positive. - #[must_use] - pub fn is_always_positive(&self) -> bool { - self.min.is_some_and(|m| m > 0) - } - - /// Returns `true` if all values in this range are negative. - #[must_use] - pub fn is_always_negative(&self) -> bool { - self.max.is_some_and(|m| m < 0) - } - - /// Returns `true` if all values in this range are non-positive. - #[must_use] - pub fn is_always_non_positive(&self) -> bool { - self.max.is_some_and(|m| m <= 0) - } - - /// Checks if all values in this range are less than `value`. - #[must_use] - pub fn always_less_than(&self, value: i64) -> Option { - if let Some(max_val) = self.max { - if max_val < value { - return Some(true); - } - // If max >= value, then at least one value (max) is NOT < value - return Some(false); - } - if let Some(min_val) = self.min { - if min_val >= value { - return Some(false); - } - } - None - } - - /// Checks if all values in this range are greater than `value`. - #[must_use] - pub fn always_greater_than(&self, value: i64) -> Option { - if let Some(min_val) = self.min { - if min_val > value { - return Some(true); - } - // If min <= value, then at least one value (min) is NOT > value - return Some(false); - } - if let Some(max_val) = self.max { - if max_val <= value { - return Some(false); - } - } - None - } - - /// Checks if all values in this range are less than or equal to `value`. - #[must_use] - pub fn always_less_equal(&self, value: i64) -> Option { - if let Some(max_val) = self.max { - if max_val <= value { - return Some(true); - } - } - if let Some(min_val) = self.min { - if min_val > value { - return Some(false); - } - } - None - } - - /// Checks if all values in this range are greater than or equal to `value`. - #[must_use] - pub fn always_greater_equal(&self, value: i64) -> Option { - if let Some(min_val) = self.min { - if min_val >= value { - return Some(true); - } - } - if let Some(max_val) = self.max { - if max_val < value { - return Some(false); - } - } - None - } - - /// Checks if all values in this range equal `value`. - #[must_use] - pub fn always_equal_to(&self, value: i64) -> Option { - match (self.min, self.max) { - (Some(min_val), Some(max_val)) => { - if min_val == max_val && min_val == value { - Some(true) - } else if value < min_val || value > max_val { - Some(false) - } else { - None - } - } - _ => None, - } - } - - /// Checks if this range can possibly contain `value`. - #[must_use] - pub fn may_contain(&self, value: i64) -> bool { - let above_min = self.min.is_none_or(|m| value >= m); - let below_max = self.max.is_none_or(|m| value <= m); - above_min && below_max - } - - /// Returns `true` if this interval overlaps with another. - #[must_use] - pub fn overlaps(&self, other: &Self) -> bool { - // Check if intervals are disjoint - let self_below_other = match (self.max, other.min) { - (Some(self_max), Some(other_min)) => self_max < other_min, - _ => false, - }; - let other_below_self = match (other.max, self.min) { - (Some(other_max), Some(self_min)) => other_max < self_min, - _ => false, - }; - !self_below_other && !other_below_self - } - - /// Returns `true` if this interval is adjacent to another (can be merged). - #[must_use] - pub fn adjacent(&self, other: &Self) -> bool { - // Check if max + 1 == other.min or other.max + 1 == self.min - if let (Some(self_max), Some(other_min)) = (self.max, other.min) { - if self_max.checked_add(1) == Some(other_min) { - return true; - } - } - if let (Some(other_max), Some(self_min)) = (other.max, self.min) { - if other_max.checked_add(1) == Some(self_min) { - return true; - } - } - false - } - - /// Meet operation: intersection of ranges. - #[must_use] - pub fn meet(&self, other: &Self) -> Option { - let new_min = match (self.min, other.min) { - (Some(a), Some(b)) => Some(max(a, b)), - (Some(a), None) => Some(a), - (None, Some(b)) => Some(b), - (None, None) => None, - }; - - let new_max = match (self.max, other.max) { - (Some(a), Some(b)) => Some(min(a, b)), - (Some(a), None) => Some(a), - (None, Some(b)) => Some(b), - (None, None) => None, - }; - - // Check if result is empty - if let (Some(min_val), Some(max_val)) = (new_min, new_max) { - if min_val > max_val { - return None; // Empty intersection - } - } - - Some(Self { - min: new_min, - max: new_max, - }) - } - - /// Join operation: union/hull of ranges (may lose precision). - #[must_use] - pub fn join(&self, other: &Self) -> Self { - let new_min = match (self.min, other.min) { - (Some(a), Some(b)) => Some(min(a, b)), - _ => None, // Either unbounded -> result unbounded - }; - - let new_max = match (self.max, other.max) { - (Some(a), Some(b)) => Some(max(a, b)), - _ => None, // Either unbounded -> result unbounded - }; - - Self { - min: new_min, - max: new_max, - } - } - - /// Widen operation for loop fixpoint computation. - /// - /// If the bound is growing, extend to infinity. - #[must_use] - pub fn widen(&self, other: &Self) -> Self { - let new_min = match (self.min, other.min) { - (Some(a), Some(b)) if b < a => None, // Growing down -> -∞ - (min_val, _) => min_val, // Keep current - }; - - let new_max = match (self.max, other.max) { - (Some(a), Some(b)) if b > a => None, // Growing up -> +∞ - (max_val, _) => max_val, // Keep current - }; - - Self { - min: new_min, - max: new_max, - } - } - - /// Addition of two ranges. - #[must_use] - pub fn add(&self, other: &Self) -> Self { - let new_min = match (self.min, other.min) { - (Some(a), Some(b)) => a.checked_add(b), - _ => None, - }; - - let new_max = match (self.max, other.max) { - (Some(a), Some(b)) => a.checked_add(b), - _ => None, - }; - - Self { - min: new_min, - max: new_max, - } - } - - /// Subtraction of two ranges. - #[must_use] - pub fn sub(&self, other: &Self) -> Self { - // [a, b] - [c, d] = [a - d, b - c] - let new_min = match (self.min, other.max) { - (Some(a), Some(d)) => a.checked_sub(d), - _ => None, - }; - - let new_max = match (self.max, other.min) { - (Some(b), Some(c)) => b.checked_sub(c), - _ => None, - }; - - Self { - min: new_min, - max: new_max, - } - } - - /// Multiplication of two ranges. - #[must_use] - pub fn mul(&self, other: &Self) -> Self { - // For multiplication, we need to consider all corner combinations - // because signs matter - match (self.min, self.max, other.min, other.max) { - (Some(a), Some(b), Some(c), Some(d)) => { - // Compute all four products - let products = [ - a.checked_mul(c), - a.checked_mul(d), - b.checked_mul(c), - b.checked_mul(d), - ]; - - // If any overflowed, return unbounded - if products.iter().any(std::option::Option::is_none) { - return Self::full(); - } - - let products: Vec = products.iter().filter_map(|&p| p).collect(); - let new_min = products.iter().copied().min(); - let new_max = products.iter().copied().max(); - - Self { - min: new_min, - max: new_max, - } - } - _ => Self::full(), - } - } - - /// Bitwise AND with constant mask. - #[must_use] - pub fn and_constant(&self, mask: i64) -> Self { - if mask >= 0 { - // AND with non-negative mask always produces [0, mask] - Self::bounded(0, mask) - } else { - Self::full() - } - } - - /// Bitwise OR with constant. - #[must_use] - pub fn or_constant(&self, value: i64) -> Self { - // Hard to compute precisely, be conservative - if self.is_always_non_negative() && value >= 0 { - // Both non-negative: result >= max(self.min, value) - let new_min = max(self.min.unwrap_or(0), value); - Self { - min: Some(new_min), - max: None, - } - } else { - Self::full() - } - } -} - -impl Default for IntervalRange { - fn default() -> Self { - Self::full() - } -} - -impl fmt::Display for IntervalRange { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match (self.min, self.max) { - (Some(a), Some(b)) if a == b => write!(f, "{a}"), - (Some(a), Some(b)) => write!(f, "[{a}, {b}]"), - (Some(a), None) => write!(f, "[{a}, +∞)"), - (None, Some(b)) => write!(f, "(-∞, {b}]"), - (None, None) => write!(f, "(-∞, +∞)"), - } - } -} - -impl ValueRange { - /// Creates the bottom element (empty set, unreachable). - #[must_use] - pub const fn bottom() -> Self { - Self::Bottom - } - - /// Creates the top element (all values, no information). - #[must_use] - pub const fn top() -> Self { - Self::Top - } - - /// Creates a constant (singleton) range. - #[must_use] - pub fn constant(value: i64) -> Self { - Self::Interval(IntervalRange::constant(value)) - } - - /// Creates a bounded interval `[min, max]`. - #[must_use] - pub fn bounded(min_val: i64, max_val: i64) -> Self { - if min_val > max_val { - Self::Bottom - } else { - Self::Interval(IntervalRange::bounded(min_val, max_val)) - } - } - - /// Creates a non-negative interval `[0, +∞)`. - #[must_use] - pub fn non_negative() -> Self { - Self::Interval(IntervalRange::non_negative()) - } - - /// Creates an interval from min to infinity `[min, +∞)`. - #[must_use] - pub fn at_least(min_val: i64) -> Self { - Self::Interval(IntervalRange::at_least(min_val)) - } - - /// Creates an interval from negative infinity to max `(-∞, max]`. - #[must_use] - pub fn at_most(max_val: i64) -> Self { - Self::Interval(IntervalRange::at_most(max_val)) - } - - /// Creates a range for non-null references. - /// - /// This is semantically different from numeric ranges — it indicates - /// that a reference is known to be non-null. - #[must_use] - pub fn non_null() -> Self { - // For references, we use a special marker value - // In practice, non-null is tracked separately, but we can use Top - // to indicate "some reference value exists" - Self::Top - } - - /// Creates a union of two ranges. - /// - /// If the ranges overlap or are adjacent, they are merged. - #[must_use] - pub fn union(a: Self, b: Self) -> Self { - match (a, b) { - (Self::Bottom, other) | (other, Self::Bottom) => other, - (Self::Top, _) | (_, Self::Top) => Self::Top, - (Self::Interval(ia), Self::Interval(ib)) => { - if ia.overlaps(&ib) || ia.adjacent(&ib) { - Self::Interval(ia.join(&ib)) - } else { - // Sort intervals - let (first, second) = if ia.min <= ib.min { (ia, ib) } else { (ib, ia) }; - Self::Union(vec![first, second]) - } - } - (Self::Union(mut intervals), Self::Interval(i)) - | (Self::Interval(i), Self::Union(mut intervals)) => { - intervals.push(i); - Self::normalize_union(intervals) - } - (Self::Union(mut a_intervals), Self::Union(b_intervals)) => { - a_intervals.extend(b_intervals); - Self::normalize_union(a_intervals) - } - } - } - - /// Normalizes a union of intervals (sorts, merges overlapping/adjacent). - fn normalize_union(mut intervals: Vec) -> Self { - if intervals.is_empty() { - return Self::Bottom; - } - if intervals.len() == 1 { - return Self::Interval(intervals.remove(0)); - } - - // Sort by min value - intervals.sort_by(|a, b| match (a.min, b.min) { - (Some(a_min), Some(b_min)) => a_min.cmp(&b_min), - (None, Some(_)) => std::cmp::Ordering::Less, - (Some(_), None) => std::cmp::Ordering::Greater, - (None, None) => std::cmp::Ordering::Equal, - }); - - // Merge overlapping/adjacent intervals - let mut merged: Vec = Vec::new(); - for interval in intervals { - if let Some(last) = merged.last_mut() { - if last.overlaps(&interval) || last.adjacent(&interval) { - *last = last.join(&interval); - continue; - } - } - merged.push(interval); - } - - if merged.len() == 1 { - Self::Interval(merged.remove(0)) - } else { - Self::Union(merged) - } - } - - /// Returns `true` if this is the bottom element (empty set). - #[must_use] - pub const fn is_bottom(&self) -> bool { - matches!(self, Self::Bottom) - } - - /// Returns `true` if this is the top element (all values). - #[must_use] - pub const fn is_top(&self) -> bool { - matches!(self, Self::Top) - } - - /// Returns `true` if this is a constant (singleton) range. - #[must_use] - pub fn is_constant(&self) -> bool { - match self { - Self::Interval(i) => i.is_constant(), - _ => false, - } - } - - /// Returns the constant value if this is a singleton. - #[must_use] - pub fn as_constant(&self) -> Option { - match self { - Self::Interval(i) => i.as_constant(), - _ => None, - } - } - - /// Returns the minimum value if bounded below. - #[must_use] - pub fn min(&self) -> Option { - match self { - Self::Bottom | Self::Top => None, - Self::Interval(i) => i.min, - Self::Union(intervals) => intervals.first().and_then(|i| i.min), - } - } - - /// Returns the maximum value if bounded above. - #[must_use] - pub fn max(&self) -> Option { - match self { - Self::Bottom | Self::Top => None, - Self::Interval(i) => i.max, - Self::Union(intervals) => intervals.last().and_then(|i| i.max), - } - } - - /// Returns `true` if all values are non-negative. - #[must_use] - pub fn is_always_non_negative(&self) -> bool { - match self { - Self::Bottom => true, // Vacuously true - Self::Top => false, - Self::Interval(i) => i.is_always_non_negative(), - Self::Union(intervals) => intervals.iter().all(IntervalRange::is_always_non_negative), - } - } - - /// Returns `true` if all values are positive. - #[must_use] - pub fn is_always_positive(&self) -> bool { - match self { - Self::Bottom => true, - Self::Top => false, - Self::Interval(i) => i.is_always_positive(), - Self::Union(intervals) => intervals.iter().all(IntervalRange::is_always_positive), - } - } - - /// Checks if all values are less than `value`. - #[must_use] - pub fn always_less_than(&self, value: i64) -> Option { - match self { - Self::Bottom => Some(true), // Vacuously true - Self::Top => None, - Self::Interval(i) => i.always_less_than(value), - Self::Union(intervals) => { - // All intervals must satisfy - let results: Vec<_> = intervals - .iter() - .map(|i| i.always_less_than(value)) - .collect(); - if results.iter().all(|r| *r == Some(true)) { - Some(true) - } else if results.contains(&Some(false)) { - Some(false) - } else { - None - } - } - } - } - - /// Checks if all values are greater than `value`. - #[must_use] - pub fn always_greater_than(&self, value: i64) -> Option { - match self { - Self::Bottom => Some(true), - Self::Top => None, - Self::Interval(i) => i.always_greater_than(value), - Self::Union(intervals) => { - let results: Vec<_> = intervals - .iter() - .map(|i| i.always_greater_than(value)) - .collect(); - if results.iter().all(|r| *r == Some(true)) { - Some(true) - } else if results.contains(&Some(false)) { - Some(false) - } else { - None - } - } - } - } - - /// Checks if all values equal `value`. - #[must_use] - pub fn always_equal_to(&self, value: i64) -> Option { - match self { - Self::Bottom => Some(true), // Vacuously true - Self::Interval(i) => i.always_equal_to(value), - // Top represents all possible values, Union has multiple disjoint intervals - Self::Top | Self::Union(_) => None, - } - } - - /// Checks if value might be contained in this range. - #[must_use] - pub fn may_contain(&self, value: i64) -> bool { - match self { - Self::Bottom => false, - Self::Top => true, - Self::Interval(i) => i.may_contain(value), - Self::Union(intervals) => intervals.iter().any(|i| i.may_contain(value)), - } - } - - /// Meet operation (intersection) — greatest lower bound. - /// - /// Returns the range of values that are in both `self` and `other`. - #[must_use] - pub fn meet(&self, other: &Self) -> Self { - match (self, other) { - // Bottom absorbs - (Self::Bottom, _) | (_, Self::Bottom) => Self::Bottom, - - // Top is identity - (Self::Top, x) | (x, Self::Top) => x.clone(), - - // Interval meet - (Self::Interval(a), Self::Interval(b)) => match a.meet(b) { - Some(result) => Self::Interval(result), - None => Self::Bottom, - }, - - // Union meet: intersect each pair - (Self::Union(intervals), Self::Interval(i)) - | (Self::Interval(i), Self::Union(intervals)) => { - let results: Vec<_> = intervals.iter().filter_map(|ui| ui.meet(i)).collect(); - Self::from_intervals(results) - } - - (Self::Union(a), Self::Union(b)) => { - let mut results = Vec::new(); - for ai in a { - for bi in b { - if let Some(r) = ai.meet(bi) { - results.push(r); - } - } - } - Self::from_intervals(results) - } - } - } - - /// Join operation (union hull) — least upper bound. - /// - /// Returns a range containing all values from both `self` and `other`. - /// May lose precision (the hull may contain values not in either input). - #[must_use] - pub fn join(&self, other: &Self) -> Self { - match (self, other) { - // Bottom is identity - (Self::Bottom, x) | (x, Self::Bottom) => x.clone(), - - // Top absorbs - (Self::Top, _) | (_, Self::Top) => Self::Top, - - // Interval join: take hull - (Self::Interval(a), Self::Interval(b)) => Self::Interval(a.join(b)), - - // Union join - (Self::Union(intervals), Self::Interval(i)) - | (Self::Interval(i), Self::Union(intervals)) => { - let mut all = intervals.clone(); - all.push(*i); - Self::normalize_union(all) - } - - (Self::Union(a), Self::Union(b)) => { - let mut all = a.clone(); - all.extend(b.iter().copied()); - Self::normalize_union(all) - } - } - } - - /// Widen operation for loop fixpoint computation. - /// - /// If bounds are growing, extends to infinity to ensure termination. - #[must_use] - pub fn widen(&self, other: &Self) -> Self { - match (self, other) { - (Self::Bottom, x) | (x, Self::Bottom) => x.clone(), - (Self::Top, _) | (_, Self::Top) => Self::Top, - (Self::Interval(a), Self::Interval(b)) => Self::Interval(a.widen(b)), - // For unions, widen the hull - _ => { - let self_hull = self.hull(); - let other_hull = other.hull(); - match (self_hull, other_hull) { - (Self::Interval(a), Self::Interval(b)) => Self::Interval(a.widen(&b)), - _ => Self::Top, - } - } - } - } - - /// Returns the convex hull (single interval containing all values). - #[must_use] - pub fn hull(&self) -> Self { - match self { - Self::Bottom => Self::Bottom, - Self::Top => Self::Top, - Self::Interval(_) => self.clone(), - Self::Union(intervals) => { - if intervals.is_empty() { - Self::Bottom - } else { - let min_val = intervals.first().and_then(|i| i.min); - let max_val = intervals.last().and_then(|i| i.max); - Self::Interval(IntervalRange::new(min_val, max_val)) - } - } - } - } - - /// Addition of two ranges. - #[must_use] - pub fn add(&self, other: &Self) -> Self { - self.binary_op(other, IntervalRange::add) - } - - /// Subtraction of two ranges. - #[must_use] - pub fn sub(&self, other: &Self) -> Self { - self.binary_op(other, IntervalRange::sub) - } - - /// Multiplication of two ranges. - #[must_use] - pub fn mul(&self, other: &Self) -> Self { - self.binary_op(other, IntervalRange::mul) - } - - /// Bitwise AND with a constant mask. - #[must_use] - pub fn and_constant(&self, mask: i64) -> Self { - match self { - Self::Bottom => Self::Bottom, - Self::Top | Self::Interval(_) | Self::Union(_) => { - if mask >= 0 { - Self::bounded(0, mask) - } else { - Self::Top - } - } - } - } - - /// Helper for binary operations on ranges. - fn binary_op(&self, other: &Self, op: F) -> Self - where - F: Fn(&IntervalRange, &IntervalRange) -> IntervalRange, - { - match (self, other) { - (Self::Bottom, _) | (_, Self::Bottom) => Self::Bottom, - (Self::Top, _) | (_, Self::Top) => Self::Top, - (Self::Interval(a), Self::Interval(b)) => Self::Interval(op(a, b)), - // For unions, operate on hulls (loses precision but correct) - _ => { - let a_hull = self.hull(); - let b_hull = other.hull(); - match (a_hull, b_hull) { - (Self::Interval(a), Self::Interval(b)) => Self::Interval(op(&a, &b)), - _ => Self::Top, - } - } - } - } - - /// Creates a range from a list of intervals. - fn from_intervals(intervals: Vec) -> Self { - if intervals.is_empty() { - Self::Bottom - } else { - Self::normalize_union(intervals) - } - } -} - -impl fmt::Debug for ValueRange { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Bottom => write!(f, "⊥"), - Self::Top => write!(f, "⊤"), - Self::Interval(i) => write!(f, "{i}"), - Self::Union(intervals) => { - write!(f, "(")?; - for (i, interval) in intervals.iter().enumerate() { - if i > 0 { - write!(f, " ∪ ")?; - } - write!(f, "{interval}")?; - } - write!(f, ")") - } - } - } -} - -impl fmt::Display for ValueRange { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - fmt::Debug::fmt(self, f) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_interval_constant() { - let r = IntervalRange::constant(42); - assert!(r.is_constant()); - assert_eq!(r.as_constant(), Some(42)); - assert_eq!(r.min, Some(42)); - assert_eq!(r.max, Some(42)); - } - - #[test] - fn test_interval_bounded() { - let r = IntervalRange::bounded(0, 255); - assert!(!r.is_constant()); - assert_eq!(r.as_constant(), None); - assert!(r.is_always_non_negative()); - assert!(!r.is_always_positive()); - assert!(r.may_contain(0)); - assert!(r.may_contain(255)); - assert!(!r.may_contain(256)); - assert!(!r.may_contain(-1)); - } - - #[test] - fn test_interval_non_negative() { - let r = IntervalRange::non_negative(); - assert!(r.is_always_non_negative()); - assert!(!r.is_always_positive()); - assert_eq!(r.always_less_than(0), Some(false)); - assert_eq!(r.always_greater_equal(0), Some(true)); - } - - #[test] - fn test_interval_comparisons() { - let r = IntervalRange::bounded(5, 10); - - // always_less_than: "are ALL values in range < X?" - assert_eq!(r.always_less_than(11), Some(true)); // All of [5,10] < 11 - assert_eq!(r.always_less_than(10), Some(false)); // 10 is not < 10 - assert_eq!(r.always_less_than(5), Some(false)); // 5,6,7,8,9,10 are not all < 5 - assert_eq!(r.always_less_than(8), Some(false)); // 8,9,10 are not < 8 - - // always_greater_than: "are ALL values in range > X?" - assert_eq!(r.always_greater_than(4), Some(true)); // All of [5,10] > 4 - assert_eq!(r.always_greater_than(5), Some(false)); // 5 is not > 5 - assert_eq!(r.always_greater_than(10), Some(false)); // 5,6,7,8,9 are not > 10 - assert_eq!(r.always_greater_than(7), Some(false)); // 5,6,7 are not > 7 - - // always_equal_to - let c = IntervalRange::constant(5); - assert_eq!(c.always_equal_to(5), Some(true)); - assert_eq!(c.always_equal_to(6), Some(false)); - assert_eq!(r.always_equal_to(5), None); // Range [5,10] might equal 5 (but not always) - - // Unbounded ranges - let unbounded = IntervalRange::full(); - assert_eq!(unbounded.always_less_than(0), None); // Cannot determine - assert_eq!(unbounded.always_greater_than(0), None); - } - - #[test] - fn test_interval_meet() { - let a = IntervalRange::bounded(0, 10); - let b = IntervalRange::bounded(5, 15); - let meet = a.meet(&b).unwrap(); - assert_eq!(meet.min, Some(5)); - assert_eq!(meet.max, Some(10)); - - // Disjoint intervals - let c = IntervalRange::bounded(0, 5); - let d = IntervalRange::bounded(10, 15); - assert!(c.meet(&d).is_none()); - } - - #[test] - fn test_interval_join() { - let a = IntervalRange::bounded(0, 10); - let b = IntervalRange::bounded(5, 15); - let join = a.join(&b); - assert_eq!(join.min, Some(0)); - assert_eq!(join.max, Some(15)); - } - - #[test] - fn test_interval_widen() { - let a = IntervalRange::bounded(0, 10); - let b = IntervalRange::bounded(0, 20); // Growing up - let widen = a.widen(&b); - assert_eq!(widen.min, Some(0)); - assert_eq!(widen.max, None); // Widened to +∞ - - let c = IntervalRange::bounded(5, 10); - let d = IntervalRange::bounded(0, 10); // Growing down - let widen2 = c.widen(&d); - assert_eq!(widen2.min, None); // Widened to -∞ - assert_eq!(widen2.max, Some(10)); - } - - #[test] - fn test_interval_arithmetic() { - let a = IntervalRange::bounded(1, 5); - let b = IntervalRange::bounded(2, 3); - - // Add: [1,5] + [2,3] = [3,8] - let sum = a.add(&b); - assert_eq!(sum.min, Some(3)); - assert_eq!(sum.max, Some(8)); - - // Sub: [1,5] - [2,3] = [1-3, 5-2] = [-2, 3] - let diff = a.sub(&b); - assert_eq!(diff.min, Some(-2)); - assert_eq!(diff.max, Some(3)); - - // Mul: [1,5] * [2,3] = [2, 15] - let prod = a.mul(&b); - assert_eq!(prod.min, Some(2)); - assert_eq!(prod.max, Some(15)); - } - - #[test] - fn test_interval_mul_with_negatives() { - let a = IntervalRange::bounded(-2, 3); - let b = IntervalRange::bounded(-1, 4); - - // Products: (-2)*(-1)=2, (-2)*4=-8, 3*(-1)=-3, 3*4=12 - // Range: [-8, 12] - let prod = a.mul(&b); - assert_eq!(prod.min, Some(-8)); - assert_eq!(prod.max, Some(12)); - } - - #[test] - fn test_interval_and_constant() { - let r = IntervalRange::bounded(0, 1000); - let masked = r.and_constant(0xFF); - assert_eq!(masked.min, Some(0)); - assert_eq!(masked.max, Some(255)); - } - - #[test] - fn test_range_constant() { - let r = ValueRange::constant(42); - assert!(r.is_constant()); - assert_eq!(r.as_constant(), Some(42)); - assert!(!r.is_bottom()); - assert!(!r.is_top()); - } - - #[test] - fn test_range_bounded() { - let r = ValueRange::bounded(0, 255); - assert!(!r.is_constant()); - assert!(r.is_always_non_negative()); - assert_eq!(r.min(), Some(0)); - assert_eq!(r.max(), Some(255)); - } - - #[test] - fn test_range_invalid_bounded() { - let r = ValueRange::bounded(10, 5); // min > max - assert!(r.is_bottom()); - } - - #[test] - fn test_range_lattice_meet() { - // Top is identity - let a = ValueRange::bounded(0, 10); - assert_eq!(a.meet(&ValueRange::top()), a); - assert_eq!(ValueRange::top().meet(&a), a); - - // Bottom absorbs - assert!(a.meet(&ValueRange::bottom()).is_bottom()); - assert!(ValueRange::bottom().meet(&a).is_bottom()); - - // Interval intersection - let b = ValueRange::bounded(5, 15); - let meet = a.meet(&b); - assert_eq!(meet.min(), Some(5)); - assert_eq!(meet.max(), Some(10)); - - // Disjoint -> bottom - let c = ValueRange::bounded(20, 30); - assert!(a.meet(&c).is_bottom()); - } - - #[test] - fn test_range_lattice_join() { - // Bottom is identity - let a = ValueRange::bounded(0, 10); - assert_eq!(a.join(&ValueRange::bottom()), a); - assert_eq!(ValueRange::bottom().join(&a), a); - - // Top absorbs - assert!(a.join(&ValueRange::top()).is_top()); - assert!(ValueRange::top().join(&a).is_top()); - - // Interval hull - let b = ValueRange::bounded(5, 15); - let join = a.join(&b); - assert_eq!(join.min(), Some(0)); - assert_eq!(join.max(), Some(15)); - } - - #[test] - fn test_range_union() { - let a = ValueRange::bounded(0, 5); - let b = ValueRange::bounded(10, 15); - let union = ValueRange::union(a, b); - - // Should be a union, not merged - assert!(matches!(union, ValueRange::Union(_))); - assert!(union.may_contain(3)); - assert!(union.may_contain(12)); - assert!(!union.may_contain(7)); // Gap - - // Adjacent ranges merge - let c = ValueRange::bounded(0, 5); - let d = ValueRange::bounded(6, 10); - let merged = ValueRange::union(c, d); - assert!(matches!(merged, ValueRange::Interval(_))); - assert_eq!(merged.min(), Some(0)); - assert_eq!(merged.max(), Some(10)); - - // Overlapping ranges merge - let e = ValueRange::bounded(0, 7); - let f = ValueRange::bounded(5, 10); - let merged2 = ValueRange::union(e, f); - assert!(matches!(merged2, ValueRange::Interval(_))); - assert_eq!(merged2.min(), Some(0)); - assert_eq!(merged2.max(), Some(10)); - } - - #[test] - fn test_range_widen() { - let a = ValueRange::bounded(0, 10); - let b = ValueRange::bounded(0, 20); // Growing up - let widen = a.widen(&b); - assert_eq!(widen.min(), Some(0)); - assert_eq!(widen.max(), None); // Widened to +∞ - - let c = ValueRange::bounded(5, 10); - let d = ValueRange::bounded(0, 10); // Growing down - let widen2 = c.widen(&d); - assert_eq!(widen2.min(), None); // Widened to -∞ - assert_eq!(widen2.max(), Some(10)); - } - - #[test] - fn test_range_arithmetic() { - let a = ValueRange::bounded(1, 5); - let b = ValueRange::bounded(2, 3); - - let sum = a.add(&b); - assert_eq!(sum.min(), Some(3)); - assert_eq!(sum.max(), Some(8)); - - let diff = a.sub(&b); - assert_eq!(diff.min(), Some(-2)); - assert_eq!(diff.max(), Some(3)); - - let prod = a.mul(&b); - assert_eq!(prod.min(), Some(2)); - assert_eq!(prod.max(), Some(15)); - } - - #[test] - fn test_range_and_constant_mask() { - let r = ValueRange::bounded(0, 1000); - let masked = r.and_constant(0xFF); - assert_eq!(masked.min(), Some(0)); - assert_eq!(masked.max(), Some(255)); - - // Top with mask - let top_masked = ValueRange::top().and_constant(0x0F); - assert_eq!(top_masked.min(), Some(0)); - assert_eq!(top_masked.max(), Some(15)); - } - - #[test] - fn test_range_comparison_queries() { - let r = ValueRange::bounded(5, 10); - - // always_less_than: "are ALL values in range < X?" - assert_eq!(r.always_less_than(11), Some(true)); // All of [5,10] < 11 - assert_eq!(r.always_less_than(5), Some(false)); // 5..10 are not all < 5 - assert_eq!(r.always_less_than(8), Some(false)); // 8,9,10 are not < 8 - - // always_greater_than: "are ALL values in range > X?" - assert_eq!(r.always_greater_than(4), Some(true)); // All of [5,10] > 4 - assert_eq!(r.always_greater_than(10), Some(false)); // 5..10 are not all > 10 - assert_eq!(r.always_greater_than(7), Some(false)); // 5,6,7 are not > 7 - } - - #[test] - fn test_range_always_non_negative() { - assert!(ValueRange::non_negative().is_always_non_negative()); - assert!(ValueRange::bounded(0, 100).is_always_non_negative()); - assert!(ValueRange::at_least(5).is_always_non_negative()); - assert!(!ValueRange::bounded(-5, 100).is_always_non_negative()); - assert!(!ValueRange::top().is_always_non_negative()); - assert!(ValueRange::bottom().is_always_non_negative()); // Vacuously true - } - - #[test] - fn test_range_display() { - assert_eq!(format!("{}", ValueRange::bottom()), "⊥"); - assert_eq!(format!("{}", ValueRange::top()), "⊤"); - assert_eq!(format!("{}", ValueRange::constant(42)), "42"); - assert_eq!(format!("{}", ValueRange::bounded(0, 10)), "[0, 10]"); - assert_eq!(format!("{}", ValueRange::non_negative()), "[0, +∞)"); - assert_eq!(format!("{}", ValueRange::at_most(10)), "(-∞, 10]"); - } - - #[test] - fn test_range_union_display() { - let union = ValueRange::union(ValueRange::bounded(0, 5), ValueRange::bounded(10, 15)); - let s = format!("{union}"); - assert!(s.contains("∪")); - assert!(s.contains("[0, 5]")); - assert!(s.contains("[10, 15]")); - } -} diff --git a/dotscope/src/analysis/ssa/block.rs b/dotscope/src/analysis/ssa/block.rs deleted file mode 100644 index 411db9a8..00000000 --- a/dotscope/src/analysis/ssa/block.rs +++ /dev/null @@ -1,1073 +0,0 @@ -//! SSA basic blocks containing phi nodes and instructions. -//! -//! An SSA block is the SSA-form representation of a CFG basic block. It contains: -//! -//! - **Phi nodes**: At the block entry, merging values from predecessors -//! - **Instructions**: SSA-form instructions with explicit def/use -//! -//! # Block Structure -//! -//! ```text -//! Block B: -//! // Phi nodes (executed "simultaneously" at block entry) -//! v3 = phi(v1 from B0, v2 from B1) -//! v6 = phi(v4 from B0, v5 from B1) -//! -//! // Instructions (executed sequentially) -//! v7 = add v3, v6 -//! v8 = mul v7, v3 -//! br B2 -//! ``` -//! -//! # Semantics -//! -//! Phi nodes are evaluated at block entry before any instructions execute. -//! Conceptually, all phi nodes in a block read their operands simultaneously, -//! then all write their results simultaneously. This avoids ordering issues -//! when one phi's result is used as another phi's operand. -//! -//! # Thread Safety -//! -//! All types in this module are `Send` and `Sync`. - -use std::{ - collections::{HashMap, VecDeque}, - fmt, -}; - -use crate::{ - analysis::ssa::{PhiNode, PhiOperand, SsaInstruction, SsaOp, SsaVarId}, - utils::BitSet, -}; - -/// Result of a variable replacement operation. -/// -/// When `replace_uses` encounters an instruction whose destination equals -/// `new_var`, it skips that instruction to avoid creating self-referential -/// operations (e.g., `v5 = add(v5, v3)`). This struct reports both the -/// successful replacements and the skipped ones, allowing callers to make -/// informed decisions without post-hoc scanning. -#[derive(Debug, Clone, Copy, Default)] -pub struct ReplaceResult { - /// Number of uses successfully replaced. - pub replaced: usize, - /// Number of uses skipped due to the self-referential guard - /// (instruction's dest == new_var, replacement would create self-reference). - pub skipped: usize, -} - -impl ReplaceResult { - /// Returns true if all uses were replaced (nothing was skipped). - #[must_use] - pub const fn is_complete(&self) -> bool { - self.skipped == 0 - } -} - -impl std::ops::Add for ReplaceResult { - type Output = Self; - fn add(self, rhs: Self) -> Self { - Self { - replaced: self.replaced.saturating_add(rhs.replaced), - skipped: self.skipped.saturating_add(rhs.skipped), - } - } -} - -/// An SSA basic block with phi nodes and instructions. -/// -/// This represents a basic block in SSA form. It maintains a parallel structure -/// to the CFG blocks but with explicit variable information. -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::analysis::{SsaBlock, PhiNode, SsaInstruction, SsaVarId, VariableOrigin}; -/// -/// let mut block = SsaBlock::new(0); -/// -/// // Add a phi node -/// let v1 = SsaVarId::from_index(0); -/// let v2 = SsaVarId::from_index(1); -/// let result = SsaVarId::from_index(2); -/// let mut phi = PhiNode::new(result, VariableOrigin::Local(0)); -/// phi.set_operand(0, v1); -/// phi.set_operand(1, v2); -/// block.add_phi(phi); -/// -/// // Add instructions -/// block.add_instruction(some_ssa_instruction); -/// ``` -#[derive(Debug, Clone)] -pub struct SsaBlock { - /// Block index (matches CFG block index). - id: usize, - - /// Phi nodes at block entry. - /// - /// These are evaluated "simultaneously" before any instructions. - phi_nodes: Vec, - - /// SSA instructions in execution order. - instructions: Vec, -} - -impl SsaBlock { - /// Creates a new empty SSA block. - /// - /// # Arguments - /// - /// * `id` - The block index (should match the corresponding CFG block) - #[must_use] - pub fn new(id: usize) -> Self { - Self { - id, - phi_nodes: Vec::new(), - instructions: Vec::new(), - } - } - - /// Creates a new SSA block with pre-allocated capacity. - /// - /// # Arguments - /// - /// * `id` - The block index - /// * `phi_capacity` - Expected number of phi nodes - /// * `instr_capacity` - Expected number of instructions - #[must_use] - pub fn with_capacity(id: usize, phi_capacity: usize, instr_capacity: usize) -> Self { - Self { - id, - phi_nodes: Vec::with_capacity(phi_capacity), - instructions: Vec::with_capacity(instr_capacity), - } - } - - /// Returns the block index. - #[must_use] - pub const fn id(&self) -> usize { - self.id - } - - /// Sets the block index. - /// - /// This is used during canonicalization when blocks are renumbered - /// after empty blocks are removed. - pub fn set_id(&mut self, id: usize) { - self.id = id; - } - - /// Returns the phi nodes in this block. - #[must_use] - pub fn phi_nodes(&self) -> &[PhiNode] { - &self.phi_nodes - } - - /// Returns a mutable reference to the phi nodes. - pub fn phi_nodes_mut(&mut self) -> &mut Vec { - &mut self.phi_nodes - } - - /// Returns the instructions in this block. - #[must_use] - pub fn instructions(&self) -> &[SsaInstruction] { - &self.instructions - } - - /// Returns a mutable reference to the instructions. - pub fn instructions_mut(&mut self) -> &mut Vec { - &mut self.instructions - } - - /// Returns the number of phi nodes. - #[must_use] - pub fn phi_count(&self) -> usize { - self.phi_nodes.len() - } - - /// Returns the number of instructions. - #[must_use] - pub fn instruction_count(&self) -> usize { - self.instructions.len() - } - - /// Returns `true` if this block has no phi nodes. - #[must_use] - pub fn has_no_phis(&self) -> bool { - self.phi_nodes.is_empty() - } - - /// Returns `true` if this block has no instructions. - #[must_use] - pub fn has_no_instructions(&self) -> bool { - self.instructions.is_empty() - } - - /// Returns `true` if this block is completely empty. - #[must_use] - pub fn is_empty(&self) -> bool { - self.phi_nodes.is_empty() && self.instructions.is_empty() - } - - /// Clears all phi nodes and instructions from this block. - /// - /// After calling this method, `is_empty()` will return `true`. - /// The block ID is preserved. - pub fn clear(&mut self) { - self.phi_nodes.clear(); - self.instructions.clear(); - } - - /// Adds a phi node to this block. - pub fn add_phi(&mut self, phi: PhiNode) { - self.phi_nodes.push(phi); - } - - /// Adds an instruction to this block. - pub fn add_instruction(&mut self, instr: SsaInstruction) { - self.instructions.push(instr); - } - - /// Gets a phi node by index. - #[must_use] - pub fn phi(&self, index: usize) -> Option<&PhiNode> { - self.phi_nodes.get(index) - } - - /// Gets a mutable phi node by index. - pub fn phi_mut(&mut self, index: usize) -> Option<&mut PhiNode> { - self.phi_nodes.get_mut(index) - } - - /// Gets an instruction by index. - #[must_use] - pub fn instruction(&self, index: usize) -> Option<&SsaInstruction> { - self.instructions.get(index) - } - - /// Gets a mutable instruction by index. - pub fn instruction_mut(&mut self, index: usize) -> Option<&mut SsaInstruction> { - self.instructions.get_mut(index) - } - - /// Gets the terminator instruction (last instruction in the block). - /// - /// In well-formed SSA, the last instruction should be a control flow - /// instruction (Jump, Branch, Switch, Return, etc.). - #[must_use] - pub fn terminator(&self) -> Option<&SsaInstruction> { - self.instructions.last() - } - - /// Gets the terminator operation if the block has a terminator instruction. - /// - /// This is a convenience method combining `terminator()` and `op()`. - #[must_use] - pub fn terminator_op(&self) -> Option<&SsaOp> { - self.instructions.last().map(SsaInstruction::op) - } - - /// Returns the successor block indices for this block. - /// - /// The successors are determined by the terminator instruction: - /// - Jump/Leave: single target - /// - Branch/BranchCmp: true and false targets - /// - Switch: all case targets plus default - /// - Return/Throw/etc: no successors - #[must_use] - pub fn successors(&self) -> Vec { - self.terminator_op() - .map_or_else(Vec::new, super::SsaOp::successors) - } - - /// Redirects control flow targets from `old_target` to `new_target`. - /// - /// This modifies the block's terminator instruction in-place, redirecting any - /// occurrences of `old_target` to `new_target`. Works with all control flow - /// instructions: `Jump`, `Leave`, `Branch`, `BranchCmp`, and `Switch`. - /// - /// # Arguments - /// - /// * `old_target` - The block index to redirect from - /// * `new_target` - The block index to redirect to - /// - /// # Returns - /// - /// `true` if any target was changed, `false` otherwise. - pub fn redirect_target(&mut self, old_target: usize, new_target: usize) -> bool { - if let Some(terminator) = self.instructions.last_mut() { - return terminator.op_mut().redirect_target(old_target, new_target); - } - false - } - - /// Sets all control flow targets to a single destination. - /// - /// This forces the block to unconditionally transfer control to `target`, - /// regardless of any branch conditions. For branches, both targets are set - /// to the same value. For other terminators (like `Return` or `Throw`), - /// the terminator is replaced with an unconditional `Jump`. - /// - /// If the block has no terminator, a `Jump` instruction is added. - /// - /// # Arguments - /// - /// * `target` - The block index to jump to - pub fn set_target(&mut self, target: usize) { - if let Some(terminator) = self.instructions.last_mut() { - match terminator.op_mut() { - SsaOp::Jump { target: t } | SsaOp::Leave { target: t } => { - *t = target; - } - SsaOp::Branch { - true_target, - false_target, - .. - } - | SsaOp::BranchCmp { - true_target, - false_target, - .. - } => { - *true_target = target; - *false_target = target; - } - SsaOp::Switch { - targets, default, .. - } => { - *default = target; - for t in targets.iter_mut() { - *t = target; - } - } - _ => { - // Other terminators (Return, Throw, etc.) - replace with Jump - *terminator = SsaInstruction::synthetic(SsaOp::Jump { target }); - } - } - } else { - // No terminator - add a Jump - self.instructions - .push(SsaInstruction::synthetic(SsaOp::Jump { target })); - } - } - - /// Replaces all uses of `old_var` with `new_var` within this block. - /// - /// This replaces uses in both instructions and phi node operands. Instructions - /// that would become self-referential (where the destination equals `new_var`) - /// are skipped to maintain SSA validity. - /// - /// # Arguments - /// - /// * `old_var` - The variable ID to find and replace - /// * `new_var` - The variable ID to replace with - /// - /// # Returns - /// - /// The number of uses that were replaced. - /// - /// # Note - /// - /// This method only replaces uses in instructions, not in PHI operands. - /// This is the safe default that avoids creating cross-origin PHI operand - /// references which can break `rebuild_ssa`. For internal operations that - /// need to also replace PHI operands (like eliminating trivial PHIs), use - /// `replace_uses_including_phis`. - pub fn replace_uses(&mut self, old_var: SsaVarId, new_var: SsaVarId) -> ReplaceResult { - let mut replaced: usize = 0; - let mut skipped: usize = 0; - - for instr in &mut self.instructions { - let op = instr.op_mut(); - // Skip if this would create a self-referential instruction - if let Some(dest) = op.dest() { - if dest == new_var { - if op.uses().contains(&old_var) { - skipped = skipped.saturating_add(1); - } - continue; - } - } - replaced = replaced.saturating_add(op.replace_uses(old_var, new_var)); - } - - ReplaceResult { replaced, skipped } - } - - /// Replaces all uses of `old_var` with `new_var`, including in PHI operands. - /// - /// Unlike [`replace_uses`](Self::replace_uses), this method also replaces uses - /// in PHI node operands. This is necessary for internal SSA operations that - /// eliminate PHI nodes and need to forward their values through other PHIs. - /// - /// # Arguments - /// - /// * `old_var` - The variable ID to find and replace. - /// * `new_var` - The variable ID to use as the replacement. - /// - /// # Returns - /// - /// The number of uses that were replaced (in both instructions and PHI operands). - /// - /// # Safety - /// - /// This method is `pub(crate)` because it can create cross-origin PHI operand - /// references if misused. The issue: `rebuild_ssa` uses a `phi_operand_origins` - /// map that can only store ONE origin per variable. If a variable becomes a PHI - /// operand for PHIs with different origins (e.g., Local(0) and Local(1)), only - /// one origin is stored, causing incorrect def site classification and broken - /// PHI placement. - /// - /// # When to Use - /// - /// Only use this method for: - /// - **Trivial PHI elimination**: When removing a PHI like `v10 = phi(v5, v5)`, - /// we need to replace uses of `v10` with `v5` everywhere, including in other - /// PHI operands. - /// - **Copy propagation within PHIs**: When a copy's destination is a PHI result - /// and we're eliminating that PHI. - /// - /// For optimization passes (copy propagation, GVN, etc.), use [`replace_uses`] - /// instead, which safely skips PHI operands. - pub(crate) fn replace_uses_including_phis( - &mut self, - old_var: SsaVarId, - new_var: SsaVarId, - ) -> ReplaceResult { - let mut result = self.replace_uses(old_var, new_var); - - // Replace in phi node operands - for phi in &mut self.phi_nodes { - for operand in phi.operands_mut() { - if operand.value() == old_var { - *operand = PhiOperand::new(new_var, operand.predecessor()); - result.replaced = result.replaced.saturating_add(1); - } - } - } - - result - } - - /// Finds a phi node that defines the given variable. - #[must_use] - pub fn find_phi_defining(&self, var: SsaVarId) -> Option<&PhiNode> { - self.phi_nodes.iter().find(|phi| phi.result() == var) - } - - /// Checks if this block is a trampoline block. - /// - /// A trampoline block is one that: - /// - Has no phi nodes (doesn't merge values from multiple predecessors) - /// - Contains only a single unconditional control transfer (`Jump` or `Leave`) - /// - /// Trampoline blocks can be bypassed by redirecting predecessors directly - /// to their targets. - /// - /// # Returns - /// - /// `Some(target)` if this block is a trampoline to `target`, `None` otherwise. - /// - /// # Example - /// - /// ```ignore - /// if let Some(target) = block.is_trampoline() { - /// // Block is a trampoline to `target` - /// } - /// ``` - #[must_use] - pub fn is_trampoline(&self) -> Option { - // Blocks with phi nodes cannot be trampolines (they merge values) - if !self.phi_nodes.is_empty() { - return None; - } - - // Must have exactly one operation - if self.instructions.len() != 1 { - return None; - } - - // That operation must be an unconditional control transfer - match self.instructions.first()?.op() { - SsaOp::Jump { target } | SsaOp::Leave { target } => Some(*target), - _ => None, - } - } - - /// Returns all variables defined in this block. - /// - /// This includes phi node results and instruction defs. - pub fn defined_variables(&self) -> impl Iterator + '_ { - let phi_defs = self.phi_nodes.iter().map(PhiNode::result); - let instr_defs = self.instructions.iter().filter_map(SsaInstruction::def); - phi_defs.chain(instr_defs) - } - - /// Returns all variables used in this block. - /// - /// This includes phi operands and instruction uses. - pub fn used_variables(&self) -> impl Iterator + '_ { - let phi_uses = self.phi_nodes.iter().flat_map(PhiNode::used_variables); - let instr_uses = self.instructions.iter().flat_map(SsaInstruction::uses); - phi_uses.chain(instr_uses) - } - - /// Sorts instructions within this block in topological order based on data dependencies. - /// - /// After sorting, if instruction A uses a value defined by instruction B (within this block), - /// then B will appear before A in the instruction list. - /// - /// # Algorithm - /// - /// Uses Kahn's algorithm for topological sorting: - /// 1. Build a dependency graph: instruction -> instructions it depends on - /// 2. Start with instructions that have no dependencies within the block - /// 3. Process in order, adding instructions whose dependencies are satisfied - /// - /// # Stability - /// - /// For instructions with no ordering constraints between them, the original - /// relative order is preserved where possible. - /// - /// # Returns - /// - /// `true` if sorting succeeded, `false` if there was a cyclic dependency - /// (which indicates invalid SSA). When a cycle is detected, the block is - /// left unchanged. - /// - /// # Example - /// - /// ```rust,ignore - /// // Before: v2 = use(v1); v1 = define(); (invalid order) - /// let sorted = block.sort_instructions_topologically(); - /// assert!(sorted); - /// // After: v1 = define(); v2 = use(v1); (valid order) - /// ``` - pub fn sort_instructions_topologically(&mut self) -> bool { - if self.instructions.len() <= 1 { - return true; - } - - // IMPORTANT: Terminators must always be at the end of the block. - // Extract terminator instructions first, sort non-terminators, then append terminators. - // This prevents the sorting algorithm from moving terminators to the middle. - let mut terminators: Vec<(usize, SsaInstruction)> = Vec::new(); - let mut non_terminator_indices: Vec = Vec::new(); - - for (idx, instr) in self.instructions.iter().enumerate() { - if instr.is_terminator() { - terminators.push((idx, instr.clone())); - } else { - non_terminator_indices.push(idx); - } - } - - // If all instructions are terminators or there's nothing to sort, we're done - if non_terminator_indices.is_empty() { - return true; - } - - // Build map of var_id -> instruction index that defines it (within this block) - // Only for non-terminator instructions - let mut def_index: HashMap = HashMap::new(); - for &idx in &non_terminator_indices { - let Some(instr) = self.instructions.get(idx) else { - continue; - }; - if let Some(dest) = instr.def() { - def_index.insert(dest, idx); - } - } - - // Also include phi node definitions as "available from the start" - // Find max variable index for BitSet sizing - let max_phi_var = self - .phi_nodes - .iter() - .map(|phi| phi.result().index()) - .max() - .map_or(0, |m| m.saturating_add(1)); - let mut phi_defs = BitSet::new(max_phi_var); - for phi in &self.phi_nodes { - phi_defs.insert(phi.result().index()); - } - - // Build dependency graph for non-terminator instructions only - // Map from original index to position in non_terminator_indices - let idx_to_pos: HashMap = non_terminator_indices - .iter() - .enumerate() - .map(|(pos, &idx)| (idx, pos)) - .collect(); - - let n = non_terminator_indices.len(); - let mut deps: Vec = (0..n).map(|_| BitSet::new(n)).collect(); - let mut rdeps: Vec = (0..n).map(|_| BitSet::new(n)).collect(); - - // Track the previous side-effecting instruction position to preserve ordering. - // Side-effecting operations (Call, CallVirt, Stfld, etc.) must execute in their - // original order to preserve program semantics (I/O ordering, memory effects). - let mut prev_side_effect_pos: Option = None; - - for (pos, &idx) in non_terminator_indices.iter().enumerate() { - let Some(instr) = self.instructions.get(idx) else { - continue; - }; - - // Add data dependencies (def-use chains) - for used in &instr.uses() { - // Skip if defined by phi (always available) - if used.index() < phi_defs.len() && phi_defs.contains(used.index()) { - continue; - } - // Skip if not defined in this block - if let Some(&dep_idx) = def_index.get(used) { - if dep_idx != idx { - if let Some(&dep_pos) = idx_to_pos.get(&dep_idx) { - // instruction at pos depends on instruction at dep_pos - if let Some(d) = deps.get_mut(pos) { - d.insert(dep_pos); - } - if let Some(r) = rdeps.get_mut(dep_pos) { - r.insert(pos); - } - } - } - } - } - - // Add ordering dependency for side-effecting operations. - // Each side-effecting instruction depends on the previous one to preserve - // the original execution order of operations like Console.WriteLine calls. - if !instr.op().is_pure() { - if let Some(prev_pos) = prev_side_effect_pos { - // This side-effecting instruction depends on the previous one - if let Some(d) = deps.get_mut(pos) { - d.insert(prev_pos); - } - if let Some(r) = rdeps.get_mut(prev_pos) { - r.insert(pos); - } - } - prev_side_effect_pos = Some(pos); - } - } - - // Kahn's algorithm: process instructions with no unsatisfied dependencies - let mut in_degree: Vec = deps.iter().map(BitSet::count).collect(); - let mut ready: VecDeque = VecDeque::new(); - - // Find instructions with no dependencies (in_degree == 0) - // Process in original order for stability - for (pos, °) in in_degree.iter().enumerate() { - if deg == 0 { - ready.push_back(pos); - } - } - - let mut sorted_positions: Vec = Vec::with_capacity(n); - while let Some(pos) = ready.pop_front() { - sorted_positions.push(pos); - - // Reduce in_degree for dependents - let Some(rd) = rdeps.get(pos) else { - continue; - }; - for dep_pos in rd.iter() { - if let Some(slot) = in_degree.get_mut(dep_pos) { - *slot = slot.saturating_sub(1); - if *slot == 0 { - ready.push_back(dep_pos); - } - } - } - } - - // Check for cycles - if sorted_positions.len() != n { - // Cycle detected - this shouldn't happen in valid SSA - // Leave the block unchanged and return false - return false; - } - - // Reorder instructions: non-terminators in sorted order, then terminators at end - let mut temp: Vec> = self.instructions.drain(..).map(Some).collect(); - - // First add non-terminator instructions in sorted order - for pos in sorted_positions { - let Some(&original_idx) = non_terminator_indices.get(pos) else { - continue; - }; - if let Some(instr) = temp.get_mut(original_idx).and_then(Option::take) { - self.instructions.push(instr); - } - } - - // Then add terminators at the end (in their original relative order) - // Sort terminators by their original index to preserve order - terminators.sort_by_key(|(idx, _)| *idx); - for (_, instr) in terminators { - self.instructions.push(instr); - } - - true - } -} - -impl fmt::Display for SsaBlock { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!(f, "B{}:", self.id)?; - - for phi in &self.phi_nodes { - writeln!(f, " {phi}")?; - } - - for instr in &self.instructions { - writeln!(f, " {instr}")?; - } - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use crate::{ - analysis::{ - ssa::{PhiNode, PhiOperand, SsaInstruction, SsaOp, SsaVarId, VariableOrigin}, - SsaFunctionBuilder, - }, - assembly::{FlowType, Instruction, InstructionCategory, Operand, StackBehavior}, - }; - - fn make_test_cil_instruction(mnemonic: &'static str, pops: u8, pushes: u8) -> Instruction { - Instruction { - rva: 0x1000, - offset: 0, - size: 1, - opcode: 0x00, - prefix: 0, - mnemonic, - category: InstructionCategory::Arithmetic, - flow_type: FlowType::Sequential, - operand: Operand::None, - stack_behavior: StackBehavior { - pops, - pushes, - net_effect: i8::try_from(i16::from(pushes) - i16::from(pops)).unwrap_or(0), - }, - branch_targets: vec![], - } - } - - #[test] - fn test_ssa_block_creation() { - let block = SsaBlock::new(5); - assert_eq!(block.id(), 5); - assert!(block.is_empty()); - assert!(block.has_no_phis()); - assert!(block.has_no_instructions()); - } - - #[test] - fn test_ssa_block_with_capacity() { - let block = SsaBlock::with_capacity(0, 2, 10); - assert_eq!(block.id(), 0); - assert!(block.is_empty()); - } - - #[test] - fn test_ssa_block_add_phi() { - let mut block = SsaBlock::new(0); - - let result = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let v2 = SsaVarId::from_index(2); - let mut phi = PhiNode::new(result, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(v1, 0)); - phi.add_operand(PhiOperand::new(v2, 1)); - - block.add_phi(phi); - - assert!(!block.has_no_phis()); - assert_eq!(block.phi_count(), 1); - assert!(block.phi(0).is_some()); - assert_eq!(block.phi(0).unwrap().result(), result); - } - - #[test] - fn test_ssa_block_add_instruction() { - let mut block = SsaBlock::new(0); - - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let v2 = SsaVarId::from_index(2); - let cil = make_test_cil_instruction("add", 2, 1); - let instr = SsaInstruction::new( - cil, - SsaOp::Add { - dest: v2, - left: v0, - right: v1, - }, - ); - - block.add_instruction(instr); - - assert!(!block.has_no_instructions()); - assert_eq!(block.instruction_count(), 1); - assert!(block.instruction(0).is_some()); - } - - #[test] - fn test_ssa_block_phi_access() { - let mut block = SsaBlock::new(0); - - let r1 = SsaVarId::from_index(0); - let r2 = SsaVarId::from_index(1); - block.add_phi(PhiNode::new(r1, VariableOrigin::Local(0))); - block.add_phi(PhiNode::new(r2, VariableOrigin::Local(1))); - - assert_eq!(block.phi_count(), 2); - assert!(block.phi(0).is_some()); - assert!(block.phi(1).is_some()); - assert!(block.phi(2).is_none()); - } - - #[test] - fn test_ssa_block_instruction_access() { - let mut block = SsaBlock::new(0); - - let cil1 = make_test_cil_instruction("nop", 0, 0); - let cil2 = make_test_cil_instruction("ret", 0, 0); - - block.add_instruction(SsaInstruction::new(cil1, SsaOp::Nop)); - block.add_instruction(SsaInstruction::new(cil2, SsaOp::Return { value: None })); - - assert_eq!(block.instruction_count(), 2); - assert!(block.instruction(0).is_some()); - assert!(block.instruction(1).is_some()); - assert!(block.instruction(2).is_none()); - } - - #[test] - fn test_ssa_block_find_phi_defining() { - let mut block = SsaBlock::new(0); - - let r1 = SsaVarId::from_index(0); - let r2 = SsaVarId::from_index(1); - let other = SsaVarId::from_index(2); - block.add_phi(PhiNode::new(r1, VariableOrigin::Local(0))); - block.add_phi(PhiNode::new(r2, VariableOrigin::Local(1))); - - assert!(block.find_phi_defining(r1).is_some()); - assert!(block.find_phi_defining(r2).is_some()); - assert!(block.find_phi_defining(other).is_none()); - } - - #[test] - fn test_ssa_block_defined_variables() { - let mut block = SsaBlock::new(0); - - let phi_result = SsaVarId::from_index(0); - let v0 = SsaVarId::from_index(1); - let v1 = SsaVarId::from_index(2); - let instr_result = SsaVarId::from_index(3); - let v2 = SsaVarId::from_index(4); - - // Add phi defining phi_result - block.add_phi(PhiNode::new(phi_result, VariableOrigin::Local(0))); - - // Add instruction defining instr_result - let cil = make_test_cil_instruction("add", 2, 1); - let instr = SsaInstruction::new( - cil, - SsaOp::Add { - dest: instr_result, - left: v0, - right: v1, - }, - ); - block.add_instruction(instr); - - // Add instruction with no def - let cil2 = make_test_cil_instruction("pop", 1, 0); - block.add_instruction(SsaInstruction::new(cil2, SsaOp::Pop { value: v2 })); - - let defs: Vec<_> = block.defined_variables().collect(); - assert_eq!(defs.len(), 2); - assert!(defs.contains(&phi_result)); - assert!(defs.contains(&instr_result)); - } - - #[test] - fn test_ssa_block_used_variables() { - let mut block = SsaBlock::new(0); - - let phi_result = SsaVarId::from_index(0); - let phi_op1 = SsaVarId::from_index(1); - let phi_op2 = SsaVarId::from_index(2); - let instr_op1 = SsaVarId::from_index(3); - let instr_op2 = SsaVarId::from_index(4); - let instr_result = SsaVarId::from_index(5); - - // Add phi using phi_op1, phi_op2 - let mut phi = PhiNode::new(phi_result, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(phi_op1, 0)); - phi.add_operand(PhiOperand::new(phi_op2, 1)); - block.add_phi(phi); - - // Add instruction using instr_op1, instr_op2 - let cil = make_test_cil_instruction("add", 2, 1); - let instr = SsaInstruction::new( - cil, - SsaOp::Add { - dest: instr_result, - left: instr_op1, - right: instr_op2, - }, - ); - block.add_instruction(instr); - - let uses: Vec<_> = block.used_variables().collect(); - assert_eq!(uses.len(), 4); - assert!(uses.contains(&phi_op1)); - assert!(uses.contains(&phi_op2)); - assert!(uses.contains(&instr_op1)); - assert!(uses.contains(&instr_op2)); - } - - #[test] - fn test_ssa_block_display_empty() { - let block = SsaBlock::new(3); - let display = format!("{block}"); - assert_eq!(display, "B3:\n"); - } - - #[test] - fn test_ssa_block_display_with_content() { - let mut block = SsaBlock::new(1); - - // Add phi - let mut phi = PhiNode::new(SsaVarId::from_index(3), VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(SsaVarId::from_index(1), 0)); - phi.add_operand(PhiOperand::new(SsaVarId::from_index(2), 2)); - block.add_phi(phi); - - // Add instruction - let cil = make_test_cil_instruction("add", 2, 1); - let instr = SsaInstruction::new( - cil, - SsaOp::Add { - dest: SsaVarId::from_index(5), - left: SsaVarId::from_index(3), - right: SsaVarId::from_index(4), - }, - ); - block.add_instruction(instr); - - let display = format!("{block}"); - assert!(display.contains("B1:")); - assert!(display.contains("v3 = phi(v1 from B0, v2 from B2)")); - assert!(display.contains("v5 = add v3, v4")); - } - - #[test] - fn test_ssa_block_mutable_access() { - let mut block = SsaBlock::new(0); - - let result = SsaVarId::from_index(0); - let operand = SsaVarId::from_index(1); - block.add_phi(PhiNode::new(result, VariableOrigin::Local(0))); - - // Modify phi through mutable access - if let Some(phi) = block.phi_mut(0) { - phi.add_operand(PhiOperand::new(operand, 1)); - } - - assert_eq!(block.phi(0).unwrap().operand_count(), 1); - } - - #[test] - fn test_is_trampoline_unconditional_jump() { - let ssa = SsaFunctionBuilder::new(2, 0) - .build_with(|f| { - f.block(0, |b| b.jump(1)); - f.block(1, |b| b.ret()); - }) - .unwrap(); - - // Block with single jump is a trampoline - assert_eq!(ssa.block(0).unwrap().is_trampoline(), Some(1)); - // Block with return is not a trampoline - assert_eq!(ssa.block(1).unwrap().is_trampoline(), None); - } - - #[test] - fn test_is_trampoline_leave_instruction() { - let ssa = SsaFunctionBuilder::new(2, 0) - .build_with(|f| { - f.block(0, |b| b.leave(1)); - f.block(1, |b| b.ret()); - }) - .unwrap(); - - // Leave is also an unconditional transfer - valid trampoline - assert_eq!(ssa.block(0).unwrap().is_trampoline(), Some(1)); - } - - #[test] - fn test_is_trampoline_blocked_by_phi_nodes() { - let mut ssa = SsaFunctionBuilder::new(2, 0) - .build_with(|f| { - f.block(0, |b| b.jump(1)); - f.block(1, |b| b.ret()); - }) - .unwrap(); - - // Adding a phi node makes it not a trampoline (it merges values) - if let Some(block) = ssa.block_mut(0) { - block.add_phi(PhiNode::new( - SsaVarId::from_index(0), - VariableOrigin::Local(0), - )); - } - - assert_eq!(ssa.block(0).unwrap().is_trampoline(), None); - } - - #[test] - fn test_is_trampoline_blocked_by_multiple_instructions() { - let ssa = SsaFunctionBuilder::new(2, 0) - .build_with(|f| { - f.block(0, |b| { - let _ = b.const_i32(42); // Extra instruction before jump - b.jump(1); - }); - f.block(1, |b| b.ret()); - }) - .unwrap(); - - // Multiple instructions means not a pure trampoline - assert_eq!(ssa.block(0).unwrap().is_trampoline(), None); - } - - #[test] - fn test_is_trampoline_conditional_branch_not_trampoline() { - let ssa = SsaFunctionBuilder::new(2, 0) - .build_with(|f| { - f.block(0, |b| { - let cond = b.const_true(); - b.branch(cond, 1, 1); - }); - f.block(1, |b| b.ret()); - }) - .unwrap(); - - // Conditional branch is not an unconditional transfer - assert_eq!(ssa.block(0).unwrap().is_trampoline(), None); - } -} diff --git a/dotscope/src/analysis/ssa/builder.rs b/dotscope/src/analysis/ssa/builder.rs index 61826e64..8c3bb634 100644 --- a/dotscope/src/analysis/ssa/builder.rs +++ b/dotscope/src/analysis/ssa/builder.rs @@ -33,8 +33,8 @@ use std::collections::HashMap; use crate::analysis::ssa::{ - ConstValue, DefSite, FunctionVarAllocator, MethodRef, PhiNode, PhiOperand, SsaBlock, - SsaFunction, SsaInstruction, SsaOp, SsaType, SsaVarId, VariableOrigin, + ConstValue, ConstValueCilExt, DefSite, FunctionVarAllocator, MethodRef, PhiNode, PhiOperand, + SsaBlock, SsaFunction, SsaInstruction, SsaOp, SsaType, SsaVarId, VariableOrigin, }; /// Builder for constructing SSA functions programmatically. @@ -401,7 +401,12 @@ impl SsaBlockBuilder<'_> { #[must_use] pub fn add(&mut self, left: SsaVarId, right: SsaVarId) -> SsaVarId { let dest = self.builder.alloc_stack_var_typed(SsaType::I32); - let op = SsaOp::Add { dest, left, right }; + let op = SsaOp::Add { + dest, + left, + right, + flags: None, + }; self.block.add_instruction(SsaInstruction::synthetic(op)); dest } @@ -410,7 +415,12 @@ impl SsaBlockBuilder<'_> { #[must_use] pub fn sub(&mut self, left: SsaVarId, right: SsaVarId) -> SsaVarId { let dest = self.builder.alloc_stack_var_typed(SsaType::I32); - let op = SsaOp::Sub { dest, left, right }; + let op = SsaOp::Sub { + dest, + left, + right, + flags: None, + }; self.block.add_instruction(SsaInstruction::synthetic(op)); dest } @@ -419,7 +429,12 @@ impl SsaBlockBuilder<'_> { #[must_use] pub fn mul(&mut self, left: SsaVarId, right: SsaVarId) -> SsaVarId { let dest = self.builder.alloc_stack_var_typed(SsaType::I32); - let op = SsaOp::Mul { dest, left, right }; + let op = SsaOp::Mul { + dest, + left, + right, + flags: None, + }; self.block.add_instruction(SsaInstruction::synthetic(op)); dest } @@ -433,6 +448,7 @@ impl SsaBlockBuilder<'_> { left, right, unsigned: false, + flags: None, }; self.block.add_instruction(SsaInstruction::synthetic(op)); dest @@ -447,6 +463,7 @@ impl SsaBlockBuilder<'_> { left, right, unsigned: true, + flags: None, }; self.block.add_instruction(SsaInstruction::synthetic(op)); dest @@ -461,6 +478,7 @@ impl SsaBlockBuilder<'_> { left, right, unsigned: false, + flags: None, }; self.block.add_instruction(SsaInstruction::synthetic(op)); dest @@ -475,6 +493,7 @@ impl SsaBlockBuilder<'_> { left, right, unsigned: true, + flags: None, }; self.block.add_instruction(SsaInstruction::synthetic(op)); dest @@ -484,7 +503,12 @@ impl SsaBlockBuilder<'_> { #[must_use] pub fn and(&mut self, left: SsaVarId, right: SsaVarId) -> SsaVarId { let dest = self.builder.alloc_stack_var_typed(SsaType::I32); - let op = SsaOp::And { dest, left, right }; + let op = SsaOp::And { + dest, + left, + right, + flags: None, + }; self.block.add_instruction(SsaInstruction::synthetic(op)); dest } @@ -493,7 +517,12 @@ impl SsaBlockBuilder<'_> { #[must_use] pub fn or(&mut self, left: SsaVarId, right: SsaVarId) -> SsaVarId { let dest = self.builder.alloc_stack_var_typed(SsaType::I32); - let op = SsaOp::Or { dest, left, right }; + let op = SsaOp::Or { + dest, + left, + right, + flags: None, + }; self.block.add_instruction(SsaInstruction::synthetic(op)); dest } @@ -502,7 +531,12 @@ impl SsaBlockBuilder<'_> { #[must_use] pub fn xor(&mut self, left: SsaVarId, right: SsaVarId) -> SsaVarId { let dest = self.builder.alloc_stack_var_typed(SsaType::I32); - let op = SsaOp::Xor { dest, left, right }; + let op = SsaOp::Xor { + dest, + left, + right, + flags: None, + }; self.block.add_instruction(SsaInstruction::synthetic(op)); dest } @@ -515,6 +549,7 @@ impl SsaBlockBuilder<'_> { dest, value, amount, + flags: None, }; self.block.add_instruction(SsaInstruction::synthetic(op)); dest @@ -529,6 +564,7 @@ impl SsaBlockBuilder<'_> { value, amount, unsigned: false, + flags: None, }; self.block.add_instruction(SsaInstruction::synthetic(op)); dest @@ -543,6 +579,7 @@ impl SsaBlockBuilder<'_> { value, amount, unsigned: true, + flags: None, }; self.block.add_instruction(SsaInstruction::synthetic(op)); dest @@ -552,7 +589,11 @@ impl SsaBlockBuilder<'_> { #[must_use] pub fn neg(&mut self, operand: SsaVarId) -> SsaVarId { let dest = self.builder.alloc_stack_var_typed(SsaType::I32); - let op = SsaOp::Neg { dest, operand }; + let op = SsaOp::Neg { + dest, + operand, + flags: None, + }; self.block.add_instruction(SsaInstruction::synthetic(op)); dest } @@ -561,7 +602,11 @@ impl SsaBlockBuilder<'_> { #[must_use] pub fn not(&mut self, operand: SsaVarId) -> SsaVarId { let dest = self.builder.alloc_stack_var_typed(SsaType::I32); - let op = SsaOp::Not { dest, operand }; + let op = SsaOp::Not { + dest, + operand, + flags: None, + }; self.block.add_instruction(SsaInstruction::synthetic(op)); dest } diff --git a/dotscope/src/analysis/ssa/cfg.rs b/dotscope/src/analysis/ssa/cfg.rs deleted file mode 100644 index a28aed05..00000000 --- a/dotscope/src/analysis/ssa/cfg.rs +++ /dev/null @@ -1,466 +0,0 @@ -//! Control flow graph view of SSA functions. -//! -//! This module provides [`SsaCfg`], a lightweight CFG view that can be constructed -//! directly from an [`SsaFunction`]. Unlike the CIL-based [`ControlFlowGraph`] which -//! is built from basic blocks, this CFG is derived from SSA block terminators. -//! -//! # Purpose -//! -//! The primary use case is to enable dataflow analyses (like SCCP) that require -//! a CFG to work on SSA functions during deobfuscation passes. Since passes only -//! receive `SsaFunction` (not the original CFG), this module bridges that gap. -//! -//! # Design -//! -//! `SsaCfg` implements the standard graph traits: -//! - [`GraphBase`] - Node count and iteration -//! - [`Successors`] - Forward edge traversal (from terminators) -//! - [`Predecessors`] - Backward edge traversal (computed from successors) -//! - [`RootedGraph`] - Entry node (block 0) -//! -//! This allows it to be used with the existing dataflow analysis infrastructure, -//! particularly the SCCP algorithm in [`crate::analysis::dataflow::sccp`]. -//! -//! # Construction -//! -//! The CFG is constructed on-demand from the SSA function: -//! -//! ```rust,ignore -//! use dotscope::analysis::{SsaCfg, SsaFunction}; -//! -//! let ssa: SsaFunction = /* ... */; -//! let cfg = SsaCfg::from_ssa(&ssa); -//! -//! // Use with SCCP -//! let mut sccp = ConstantPropagation::new(PointerSize::Bit64); -//! let results = sccp.analyze(&ssa, &cfg); -//! ``` -//! -//! [`ControlFlowGraph`]: crate::analysis::ControlFlowGraph - -use crate::{ - analysis::ssa::SsaFunction, - utils::graph::{ - algorithms::{postorder, reverse_postorder}, - GraphBase, NodeId, Predecessors, RootedGraph, Successors, - }, -}; - -/// A lightweight control flow graph view of an SSA function. -/// -/// This struct provides a CFG interface over an existing [`SsaFunction`], -/// extracting control flow edges from block terminators. It's designed to -/// enable dataflow analyses that require a CFG without duplicating the -/// underlying SSA structure. -/// -/// # Performance -/// -/// The CFG computes and caches predecessor lists on construction. This is -/// an O(E) operation where E is the number of edges (typically similar to -/// the number of blocks). Once constructed, all queries are O(1) or O(k) -/// where k is the number of adjacent nodes. -/// -/// # Lifetime -/// -/// The CFG holds a reference to the SSA function it was created from. -/// The CFG must not outlive the SSA function. -#[derive(Debug)] -pub struct SsaCfg<'a> { - /// Reference to the SSA function. - ssa: &'a SsaFunction, - /// Precomputed successor lists for each block (includes exception handler edges). - successors: Vec>, - /// Precomputed predecessor lists for each block (includes exception handler edges). - predecessors: Vec>, -} - -impl<'a> SsaCfg<'a> { - /// Creates a CFG view from an SSA function. - /// - /// This extracts control flow edges by examining the terminator of each - /// SSA block. Predecessors are computed and cached for efficient backward - /// traversal. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to create a CFG view of. - /// - /// # Returns - /// - /// A new `SsaCfg` view of the given function. - /// - /// # Example - /// - /// ```rust,ignore - /// let cfg = SsaCfg::from_ssa(&ssa_function); - /// assert_eq!(cfg.node_count(), ssa_function.block_count()); - /// ``` - #[must_use] - pub fn from_ssa(ssa: &'a SsaFunction) -> Self { - let block_count = ssa.block_count(); - let mut successors = vec![Vec::new(); block_count]; - let mut predecessors = vec![Vec::new(); block_count]; - - // Build successor/predecessor lists from block terminators - for (block_idx, block_succs_list) in successors.iter_mut().enumerate() { - if let Some(block) = ssa.block(block_idx) { - let block_succs = block - .instructions() - .iter() - .rev() - .find_map(|instr| { - let op = instr.op(); - if op.is_terminator() { - Some(op.successors()) - } else { - None - } - }) - .unwrap_or_default(); - - for succ in block_succs { - if let Some(slot) = predecessors.get_mut(succ) { - block_succs_list.push(succ); - slot.push(block_idx); - } - } - } - } - - // Add synthetic edges for exception handlers. Handler blocks are only - // reachable via runtime exceptions, not explicit branches, so they - // appear disconnected in the terminator-based CFG. We add an edge from - // the try region's entry block to the handler entry block so that - // analyses (dominator computation, reachability, etc.) treat them as - // connected. - for handler in ssa.exception_handlers() { - if let (Some(try_start), Some(handler_start)) = - (handler.try_start_block, handler.handler_start_block) - { - if handler_start < block_count - && try_start < block_count - && !predecessors - .get(handler_start) - .is_some_and(|p| p.contains(&try_start)) - { - if let Some(slot) = successors.get_mut(try_start) { - slot.push(handler_start); - } - if let Some(slot) = predecessors.get_mut(handler_start) { - slot.push(try_start); - } - } - } - } - - Self { - ssa, - successors, - predecessors, - } - } - - /// Returns the underlying SSA function. - /// - /// This can be used to access block and instruction data while - /// traversing the CFG. - #[must_use] - pub const fn ssa(&self) -> &'a SsaFunction { - self.ssa - } - - /// Returns the number of blocks in the CFG. - #[must_use] - pub fn block_count(&self) -> usize { - self.ssa.block_count() - } - - /// Returns true if the CFG has no blocks. - #[must_use] - pub fn is_empty(&self) -> bool { - self.ssa.is_empty() - } - - /// Returns the successor block indices for a given block. - /// - /// Includes both terminator-derived edges and synthetic exception handler - /// edges (try entry → handler entry). - /// - /// # Arguments - /// - /// * `block_idx` - The block index to query. - /// - /// # Returns - /// - /// A slice of successor block indices. Empty if the block has no - /// successors (e.g., return, throw) or doesn't exist. - #[must_use] - pub fn block_successors(&self, block_idx: usize) -> &[usize] { - self.successors.get(block_idx).map_or(&[], Vec::as_slice) - } - - /// Returns the predecessor block indices for a given block. - /// - /// # Arguments - /// - /// * `block_idx` - The block index to query. - /// - /// # Returns - /// - /// A slice of predecessor block indices. - #[must_use] - pub fn block_predecessors(&self, block_idx: usize) -> &[usize] { - self.predecessors.get(block_idx).map_or(&[], Vec::as_slice) - } - - /// Returns the exit nodes of the CFG. - /// - /// Exit nodes are blocks with no successors (blocks that end in return, - /// throw, or other terminating instructions). - /// - /// # Returns - /// - /// A vector of exit node IDs. - #[must_use] - pub fn exits(&self) -> Vec { - let mut exits = Vec::new(); - for idx in 0..self.ssa.block_count() { - if self.block_successors(idx).is_empty() { - exits.push(NodeId::new(idx)); - } - } - exits - } - - /// Returns blocks in postorder traversal. - /// - /// Postorder is useful for backward data flow analysis. - /// - /// # Returns - /// - /// A vector of node IDs in postorder. - #[must_use] - pub fn postorder(&self) -> Vec { - postorder(self, self.entry()) - } - - /// Returns blocks in reverse postorder traversal. - /// - /// Reverse postorder is useful for forward data flow analysis. - /// - /// # Returns - /// - /// A vector of node IDs in reverse postorder. - #[must_use] - pub fn reverse_postorder(&self) -> Vec { - reverse_postorder(self, self.entry()) - } -} - -impl GraphBase for SsaCfg<'_> { - fn node_count(&self) -> usize { - self.ssa.block_count() - } - - fn node_ids(&self) -> impl Iterator { - (0..self.ssa.block_count()).map(NodeId::new) - } -} - -impl Successors for SsaCfg<'_> { - fn successors(&self, node: NodeId) -> impl Iterator { - self.block_successors(node.index()) - .iter() - .copied() - .map(NodeId::new) - } -} - -impl Predecessors for SsaCfg<'_> { - fn predecessors(&self, node: NodeId) -> impl Iterator { - self.block_predecessors(node.index()) - .iter() - .copied() - .map(NodeId::new) - } -} - -impl RootedGraph for SsaCfg<'_> { - fn entry(&self) -> NodeId { - NodeId::new(0) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use crate::{ - analysis::ssa::{SsaBlock, SsaInstruction, SsaOp, SsaVarId}, - utils::graph::{GraphBase, NodeId, Predecessors, RootedGraph, Successors}, - }; - - /// Creates a simple test SSA function with the given block structure. - fn create_test_ssa(terminators: Vec) -> SsaFunction { - let mut ssa = SsaFunction::new(0, 0); - - for (idx, terminator) in terminators.into_iter().enumerate() { - let mut block = SsaBlock::new(idx); - block.add_instruction(SsaInstruction::synthetic(terminator)); - ssa.add_block(block); - } - - ssa - } - - #[test] - fn test_empty_ssa() { - let ssa = SsaFunction::new(0, 0); - let cfg = SsaCfg::from_ssa(&ssa); - - assert!(cfg.is_empty()); - assert_eq!(cfg.node_count(), 0); - } - - #[test] - fn test_single_block() { - let ssa = create_test_ssa(vec![SsaOp::Return { value: None }]); - let cfg = SsaCfg::from_ssa(&ssa); - - assert_eq!(cfg.node_count(), 1); - assert_eq!(cfg.entry(), NodeId::new(0)); - assert!(cfg.block_successors(0).is_empty()); - assert!(cfg.block_predecessors(0).is_empty()); - } - - #[test] - fn test_linear_blocks() { - // B0 -> B1 -> B2 (return) - let ssa = create_test_ssa(vec![ - SsaOp::Jump { target: 1 }, - SsaOp::Jump { target: 2 }, - SsaOp::Return { value: None }, - ]); - let cfg = SsaCfg::from_ssa(&ssa); - - assert_eq!(cfg.node_count(), 3); - - // Check successors - assert_eq!(cfg.block_successors(0), vec![1]); - assert_eq!(cfg.block_successors(1), vec![2]); - assert!(cfg.block_successors(2).is_empty()); - - // Check predecessors - assert!(cfg.block_predecessors(0).is_empty()); - assert_eq!(cfg.block_predecessors(1), vec![0]); - assert_eq!(cfg.block_predecessors(2), vec![1]); - } - - #[test] - fn test_diamond_cfg() { - // B0 (branch) -> B1, B2 -> B3 (return) - let cond = SsaVarId::from_index(0); - let ssa = create_test_ssa(vec![ - SsaOp::Branch { - condition: cond, - true_target: 1, - false_target: 2, - }, - SsaOp::Jump { target: 3 }, - SsaOp::Jump { target: 3 }, - SsaOp::Return { value: None }, - ]); - let cfg = SsaCfg::from_ssa(&ssa); - - assert_eq!(cfg.node_count(), 4); - - // B0 has two successors - let b0_succs = cfg.block_successors(0); - assert_eq!(b0_succs.len(), 2); - assert!(b0_succs.contains(&1)); - assert!(b0_succs.contains(&2)); - - // B3 has two predecessors - let b3_preds = cfg.block_predecessors(3); - assert_eq!(b3_preds.len(), 2); - assert!(b3_preds.contains(&1)); - assert!(b3_preds.contains(&2)); - } - - #[test] - fn test_loop_cfg() { - // B0 -> B1 (loop) -> B1 (back edge) or B2 (exit) - let cond = SsaVarId::from_index(0); - let ssa = create_test_ssa(vec![ - SsaOp::Jump { target: 1 }, - SsaOp::Branch { - condition: cond, - true_target: 1, // back edge - false_target: 2, - }, - SsaOp::Return { value: None }, - ]); - let cfg = SsaCfg::from_ssa(&ssa); - - assert_eq!(cfg.node_count(), 3); - - // B1 has itself as a predecessor (back edge) - let b1_preds = cfg.block_predecessors(1); - assert_eq!(b1_preds.len(), 2); - assert!(b1_preds.contains(&0)); - assert!(b1_preds.contains(&1)); // self-loop - } - - #[test] - fn test_switch_cfg() { - // B0 (switch) -> B1, B2, B3 (cases), B4 (default) - let val = SsaVarId::from_index(0); - let ssa = create_test_ssa(vec![ - SsaOp::Switch { - value: val, - targets: vec![1, 2, 3], - default: 4, - }, - SsaOp::Return { value: None }, - SsaOp::Return { value: None }, - SsaOp::Return { value: None }, - SsaOp::Return { value: None }, - ]); - let cfg = SsaCfg::from_ssa(&ssa); - - assert_eq!(cfg.node_count(), 5); - - // B0 has 4 successors (3 cases + default) - let b0_succs = cfg.block_successors(0); - assert_eq!(b0_succs.len(), 4); - assert!(b0_succs.contains(&1)); - assert!(b0_succs.contains(&2)); - assert!(b0_succs.contains(&3)); - assert!(b0_succs.contains(&4)); - } - - #[test] - fn test_graph_traits() { - let ssa = create_test_ssa(vec![ - SsaOp::Jump { target: 1 }, - SsaOp::Return { value: None }, - ]); - let cfg = SsaCfg::from_ssa(&ssa); - - // Test GraphBase - assert_eq!(GraphBase::node_count(&cfg), 2); - let node_ids: Vec<_> = GraphBase::node_ids(&cfg).collect(); - assert_eq!(node_ids, vec![NodeId::new(0), NodeId::new(1)]); - - // Test Successors - let succs: Vec<_> = Successors::successors(&cfg, NodeId::new(0)).collect(); - assert_eq!(succs, vec![NodeId::new(1)]); - - // Test Predecessors - let preds: Vec<_> = Predecessors::predecessors(&cfg, NodeId::new(1)).collect(); - assert_eq!(preds, vec![NodeId::new(0)]); - - // Test RootedGraph - assert_eq!(RootedGraph::entry(&cfg), NodeId::new(0)); - } -} diff --git a/dotscope/src/analysis/ssa/constraints.rs b/dotscope/src/analysis/ssa/constraints.rs deleted file mode 100644 index b1de67f7..00000000 --- a/dotscope/src/analysis/ssa/constraints.rs +++ /dev/null @@ -1,251 +0,0 @@ -//! Constraint types for SSA path analysis. -//! -//! This module provides constraint types used during path-aware SSA evaluation. -//! These constraints represent facts learned from branch conditions while -//! traversing specific paths through the control flow graph. -//! -//! # Constraint Types -//! -//! - [`Constraint`]: A constraint on a single value (e.g., `x == 5`, `x > 10`) -//! - [`PathConstraint`]: Associates a constraint with a specific SSA variable -//! -//! # Usage with SsaEvaluator -//! -//! The [`SsaEvaluator`](super::SsaEvaluator) tracks path constraints during evaluation. -//! When taking a branch, the evaluator records constraints that must hold on that path. -//! -//! For example, after taking the true branch of `if (x == 5)`: -//! - We know `x == 5` on this path -//! - This is recorded as `PathConstraint { variable: x, constraint: Constraint::Equal(ConstValue::I32(5)) }` -//! -//! # Use Cases -//! -//! - Constraint solving with Z3 -//! - Value range analysis -//! - Dead code detection -//! - Path-sensitive constant propagation - -use crate::{ - analysis::ssa::{ConstValue, SsaVarId}, - metadata::typesystem::PointerSize, -}; - -/// A constraint on a variable's value derived from branch conditions. -/// -/// When following a specific branch path, we can derive facts about variable values. -/// For example, after taking the true branch of `if (x == 5)`, we know `x == 5`. -#[derive(Debug, Clone, PartialEq)] -pub enum Constraint { - /// Variable equals a concrete value: `x == value` - Equal(ConstValue), - /// Variable does not equal a concrete value: `x != value` - NotEqual(ConstValue), - /// Variable is greater than a value (signed): `x > value` - GreaterThan(ConstValue), - /// Variable is less than a value (signed): `x < value` - LessThan(ConstValue), - /// Variable is greater than or equal (signed): `x >= value` - GreaterOrEqual(ConstValue), - /// Variable is less than or equal (signed): `x <= value` - LessOrEqual(ConstValue), - /// Variable is greater than (unsigned): `(uint)x > value` - GreaterThanUnsigned(ConstValue), - /// Variable is less than (unsigned): `(uint)x < value` - LessThanUnsigned(ConstValue), -} - -impl Constraint { - /// Checks if a concrete value satisfies this constraint. - /// - /// Uses typed comparison methods from `ConstValue`. - #[must_use] - pub fn is_satisfied_by(&self, value: &ConstValue) -> bool { - match self { - Self::Equal(v) => value.ceq(v).is_some_and(|r| !r.is_zero()), - Self::NotEqual(v) => value.ceq(v).is_some_and(|r| r.is_zero()), - Self::GreaterThan(v) => value.cgt(v).is_some_and(|r| !r.is_zero()), - Self::LessThan(v) => value.clt(v).is_some_and(|r| !r.is_zero()), - Self::GreaterOrEqual(v) => value.clt(v).is_some_and(|r| r.is_zero()), - Self::LessOrEqual(v) => value.cgt(v).is_some_and(|r| r.is_zero()), - Self::GreaterThanUnsigned(v) => value.cgt_un(v).is_some_and(|r| !r.is_zero()), - Self::LessThanUnsigned(v) => value.clt_un(v).is_some_and(|r| !r.is_zero()), - } - } - - /// Returns the concrete value if this is an equality constraint. - #[must_use] - pub fn as_equal(&self) -> Option<&ConstValue> { - match self { - Self::Equal(v) => Some(v), - _ => None, - } - } - - /// Checks if this constraint conflicts with another (both can't be true). - /// - /// # Arguments - /// - /// * `other` - The other constraint to check against. - /// * `ptr_size` - Target pointer size for native int/uint masking. - #[must_use] - pub fn conflicts_with(&self, other: &Constraint, ptr_size: PointerSize) -> bool { - match (self, other) { - // x == a conflicts with x == b (if a != b) - (Self::Equal(a), Self::Equal(b)) => a != b, - // x == a conflicts with x != a - (Self::Equal(a), Self::NotEqual(b)) | (Self::NotEqual(b), Self::Equal(a)) => a == b, - // x == a conflicts with x > b when a <= b (i.e., a is not greater than b) - (Self::Equal(a), Self::GreaterThan(b)) | (Self::GreaterThan(b), Self::Equal(a)) => { - // a <= b means a is not strictly greater than b - a.cgt(b).is_none_or(|r| r.is_zero()) - } - // x == a conflicts with x < b when a >= b (i.e., a is not less than b) - (Self::Equal(a), Self::LessThan(b)) | (Self::LessThan(b), Self::Equal(a)) => { - // a >= b means a is not strictly less than b - a.clt(b).is_none_or(|r| r.is_zero()) - } - // x > a conflicts with x < b if ranges don't overlap - (Self::GreaterThan(a), Self::LessThan(b)) - | (Self::LessThan(b), Self::GreaterThan(a)) => { - // x > a AND x < b requires b > a + 1 (there must be room for at least one integer) - // Conflicts when b <= a + 1 - let one = ConstValue::I32(1); - a.add(&one, ptr_size) - .and_then(|a_plus_1| b.cgt(&a_plus_1)) - .is_none_or(|r| r.is_zero()) - } - _ => false, - } - } -} - -/// A constraint on a path derived from branch conditions. -/// -/// Associates a [`Constraint`] with a specific SSA variable. These constraints -/// are accumulated during path evaluation and can be used for constraint -/// solving with Z3. -#[derive(Debug, Clone, PartialEq)] -pub struct PathConstraint { - /// The variable this constraint applies to. - pub variable: SsaVarId, - /// The constraint on the variable's value. - pub constraint: Constraint, -} - -impl PathConstraint { - /// Creates a new equality constraint. - #[must_use] - pub fn equal(variable: SsaVarId, value: ConstValue) -> Self { - Self { - variable, - constraint: Constraint::Equal(value), - } - } - - /// Creates a new inequality constraint. - #[must_use] - pub fn not_equal(variable: SsaVarId, value: ConstValue) -> Self { - Self { - variable, - constraint: Constraint::NotEqual(value), - } - } - - /// Creates a new less-than constraint. - #[must_use] - pub fn less_than(variable: SsaVarId, value: ConstValue) -> Self { - Self { - variable, - constraint: Constraint::LessThan(value), - } - } - - /// Creates a new greater-than constraint. - #[must_use] - pub fn greater_than(variable: SsaVarId, value: ConstValue) -> Self { - Self { - variable, - constraint: Constraint::GreaterThan(value), - } - } - - /// Checks if a concrete value satisfies this constraint. - #[must_use] - pub fn is_satisfied_by(&self, value: &ConstValue) -> bool { - self.constraint.is_satisfied_by(value) - } -} - -#[cfg(test)] -mod tests { - use crate::{ - analysis::ssa::{ - constraints::{Constraint, PathConstraint}, - ConstValue, SsaVarId, - }, - metadata::typesystem::PointerSize, - }; - - #[test] - fn test_constraint_satisfied() { - let c = Constraint::Equal(ConstValue::I32(5)); - assert!(c.is_satisfied_by(&ConstValue::I32(5))); - assert!(!c.is_satisfied_by(&ConstValue::I32(6))); - - let c = Constraint::NotEqual(ConstValue::I32(5)); - assert!(!c.is_satisfied_by(&ConstValue::I32(5))); - assert!(c.is_satisfied_by(&ConstValue::I32(6))); - - let c = Constraint::GreaterThan(ConstValue::I32(5)); - assert!(c.is_satisfied_by(&ConstValue::I32(10))); - assert!(!c.is_satisfied_by(&ConstValue::I32(5))); - - let c = Constraint::LessThan(ConstValue::I32(10)); - assert!(c.is_satisfied_by(&ConstValue::I32(5))); - assert!(!c.is_satisfied_by(&ConstValue::I32(10))); - } - - #[test] - fn test_constraint_conflicts() { - let c1 = Constraint::Equal(ConstValue::I32(5)); - let c2 = Constraint::Equal(ConstValue::I32(10)); - assert!(c1.conflicts_with(&c2, PointerSize::Bit64)); - - let c3 = Constraint::NotEqual(ConstValue::I32(5)); - assert!(c1.conflicts_with(&c3, PointerSize::Bit64)); - - let c4 = Constraint::GreaterThan(ConstValue::I32(5)); - assert!(c1.conflicts_with(&c4, PointerSize::Bit64)); // 5 is not > 5 - - let c5 = Constraint::GreaterThan(ConstValue::I32(4)); - assert!(!c1.conflicts_with(&c5, PointerSize::Bit64)); // 5 > 4 is ok - } - - #[test] - fn test_path_constraint_satisfied() { - let var = SsaVarId::from_index(0); - - let eq = PathConstraint::equal(var, ConstValue::I32(5)); - assert!(eq.is_satisfied_by(&ConstValue::I32(5))); - assert!(!eq.is_satisfied_by(&ConstValue::I32(6))); - - let ne = PathConstraint::not_equal(var, ConstValue::I32(5)); - assert!(!ne.is_satisfied_by(&ConstValue::I32(5))); - assert!(ne.is_satisfied_by(&ConstValue::I32(6))); - } - - #[test] - fn test_path_constraint_kinds() { - let var = SsaVarId::from_index(0); - - assert!(PathConstraint::less_than(var, ConstValue::I32(10)) - .is_satisfied_by(&ConstValue::I32(5))); - assert!(!PathConstraint::less_than(var, ConstValue::I32(10)) - .is_satisfied_by(&ConstValue::I32(10))); - - assert!(PathConstraint::greater_than(var, ConstValue::I32(5)) - .is_satisfied_by(&ConstValue::I32(10))); - assert!(!PathConstraint::greater_than(var, ConstValue::I32(5)) - .is_satisfied_by(&ConstValue::I32(5))); - } -} diff --git a/dotscope/src/analysis/ssa/consts.rs b/dotscope/src/analysis/ssa/consts.rs deleted file mode 100644 index fb7fd1f3..00000000 --- a/dotscope/src/analysis/ssa/consts.rs +++ /dev/null @@ -1,520 +0,0 @@ -//! Constant evaluation for SSA operations. -//! -//! This module provides unified constant folding capabilities for SSA analysis. -//! The [`ConstEvaluator`] can be used by multiple passes (unflattening, decryption, -//! SCCP, etc.) to evaluate SSA operations to constant values. -//! -//! # Example -//! -//! ```rust,ignore -//! use dotscope::analysis::{ConstEvaluator, SsaFunction}; -//! -//! let mut evaluator = ConstEvaluator::new(&ssa, PointerSize::Bit64); -//! -//! // Inject known values from external analysis -//! evaluator.set_known(state_var, ConstValue::I32(42)); -//! -//! // Evaluate a variable -//! if let Some(value) = evaluator.evaluate_var(some_var) { -//! println!("Variable evaluates to: {:?}", value); -//! } -//! -//! // Get all computed constants -//! let constants = evaluator.into_results(); -//! ``` - -use std::collections::HashMap; - -use crate::{ - analysis::ssa::{ConstValue, SsaFunction, SsaOp, SsaVarId}, - metadata::typesystem::PointerSize, - utils::BitSet, -}; - -/// Evaluates SSA operations to constant values. -/// -/// This provides a unified implementation of constant folding that can be -/// used by multiple passes (unflattening, decryption, SCCP, etc.). -/// -/// # Features -/// -/// - Caches results for efficiency -/// - Detects cycles to prevent infinite recursion -/// - Supports injecting known values from external analysis -/// - Configurable depth limit -pub struct ConstEvaluator<'a> { - /// Reference to the SSA function being analyzed. - ssa: &'a SsaFunction, - - /// Cache of evaluated constants. - /// `Some(value)` means the variable evaluates to that constant. - /// `None` means the variable was evaluated but is not constant. - cache: HashMap>, - - /// Variables currently being evaluated (for cycle detection). - visiting: BitSet, - - /// Maximum recursion depth. - max_depth: usize, - - /// Target pointer size for native int/uint masking. - pointer_size: PointerSize, -} - -impl<'a> ConstEvaluator<'a> { - /// Default maximum recursion depth. - const DEFAULT_MAX_DEPTH: usize = 20; - - /// Creates a new evaluator with default settings. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to evaluate. - /// * `ptr_size` - Target pointer size for native int/uint masking. - #[must_use] - pub fn new(ssa: &'a SsaFunction, ptr_size: PointerSize) -> Self { - Self::with_max_depth(ssa, Self::DEFAULT_MAX_DEPTH, ptr_size) - } - - /// Creates an evaluator with a custom depth limit. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to evaluate. - /// * `max_depth` - Maximum recursion depth for evaluation. - /// * `ptr_size` - Target pointer size for native int/uint masking. - #[must_use] - pub fn with_max_depth(ssa: &'a SsaFunction, max_depth: usize, ptr_size: PointerSize) -> Self { - Self { - ssa, - cache: HashMap::new(), - visiting: BitSet::new(ssa.variable_count().max(1)), - max_depth, - pointer_size: ptr_size, - } - } - - /// Injects a known value from external analysis. - /// - /// This allows passes to provide values discovered through other means - /// (e.g., from `ctx.known_values` in decryption). Injected values take - /// precedence over computed values. - /// - /// # Arguments - /// - /// * `var` - The variable to set. - /// * `value` - The known constant value. - pub fn set_known(&mut self, var: SsaVarId, value: ConstValue) { - self.cache.insert(var, Some(value)); - } - - /// Evaluates a variable to a constant if possible. - /// - /// Results are cached, so repeated calls with the same variable are O(1). - /// - /// # Arguments - /// - /// * `var` - The SSA variable to evaluate. - /// - /// # Returns - /// - /// The constant value if the variable can be evaluated, `None` otherwise. - pub fn evaluate_var(&mut self, var: SsaVarId) -> Option { - self.evaluate_var_depth(var, 0) - } - - /// Internal evaluation with depth tracking. - fn evaluate_var_depth(&mut self, var: SsaVarId, depth: usize) -> Option { - // Check depth limit - if depth > self.max_depth { - return None; - } - - // Check cache first - if let Some(cached) = self.cache.get(&var) { - return cached.clone(); - } - - // Cycle detection - if var.index() < self.visiting.len() && self.visiting.contains(var.index()) { - return None; - } - - // Mark as visiting - if var.index() < self.visiting.len() { - self.visiting.insert(var.index()); - } - - // Get definition and evaluate - let result = self - .ssa - .get_definition(var) - .and_then(|op| self.evaluate_op_depth(op, depth)); - - // Remove from visiting set - if var.index() < self.visiting.len() { - self.visiting.remove(var.index()); - } - - // Cache the result - self.cache.insert(var, result.clone()); - - result - } - - /// Evaluates an SSA operation to a constant if possible. - /// - /// # Arguments - /// - /// * `op` - The SSA operation to evaluate. - /// - /// # Returns - /// - /// The constant value if the operation can be evaluated, `None` otherwise. - pub fn evaluate_op(&mut self, op: &SsaOp) -> Option { - self.evaluate_op_depth(op, 0) - } - - /// Internal operation evaluation with depth tracking. - fn evaluate_op_depth(&mut self, op: &SsaOp, depth: usize) -> Option { - // Check depth limit - if depth > self.max_depth { - return None; - } - - // Copy needs recursive evaluation that the shared helper cannot provide, - // because it resolves a variable rather than performing arithmetic. - if let SsaOp::Copy { src, .. } = op { - return self.evaluate_var_depth(*src, depth.saturating_add(1)); - } - - let ptr_size = self.pointer_size; - evaluate_const_op( - op, - |var| self.evaluate_var_depth(var, depth.saturating_add(1)), - ptr_size, - ) - } - - /// Returns all computed constants. - /// - /// This consumes the evaluator and returns a map of all variables - /// that were successfully evaluated to constants. - #[must_use] - pub fn into_results(self) -> HashMap { - self.cache - .into_iter() - .filter_map(|(var, opt)| opt.map(|val| (var, val))) - .collect() - } - - /// Returns a reference to the SSA function being evaluated. - #[must_use] - pub fn ssa(&self) -> &SsaFunction { - self.ssa - } - - /// Clears the evaluation cache. - /// - /// This is useful if the SSA function has been modified and - /// cached results are no longer valid. - pub fn clear_cache(&mut self) { - self.cache.clear(); - } -} - -/// Evaluates an SSA operation to a constant value using the provided operand resolver. -/// -/// This is the shared arithmetic dispatch for constant evaluation. It handles all -/// pure arithmetic, bitwise, comparison, overflow-checked, and conversion operations. -/// Callers provide a `get_const` closure that resolves an [`SsaVarId`] to its constant -/// value (if known). -/// -/// # Operations not handled -/// -/// - `Copy` — requires variable-level resolution (trace-through), not arithmetic. -/// Callers should handle `Copy` before calling this function. -/// - Calls, loads, stores, and other side-effecting operations — always returns `None`. -/// -/// # Arguments -/// -/// * `op` - The SSA operation to evaluate. -/// * `get_const` - Closure that resolves a variable to its constant value. -/// * `ptr_size` - Target pointer size for native int/uint masking. -/// -/// # Returns -/// -/// The constant result if all operands resolve and the operation succeeds, `None` otherwise. -pub fn evaluate_const_op( - op: &SsaOp, - mut get_const: impl FnMut(SsaVarId) -> Option, - ptr_size: PointerSize, -) -> Option { - match op { - SsaOp::Const { value, .. } => Some(value.clone()), - - // Binary arithmetic - SsaOp::Add { left, right, .. } => { - let l = get_const(*left)?; - let r = get_const(*right)?; - l.add(&r, ptr_size) - } - SsaOp::Sub { left, right, .. } => { - let l = get_const(*left)?; - let r = get_const(*right)?; - l.sub(&r, ptr_size) - } - SsaOp::Mul { left, right, .. } => { - let l = get_const(*left)?; - let r = get_const(*right)?; - l.mul(&r, ptr_size) - } - SsaOp::Div { left, right, .. } => { - let l = get_const(*left)?; - let r = get_const(*right)?; - l.div(&r, ptr_size) - } - SsaOp::Rem { left, right, .. } => { - let l = get_const(*left)?; - let r = get_const(*right)?; - l.rem(&r, ptr_size) - } - - // Bitwise - SsaOp::Xor { left, right, .. } => { - let l = get_const(*left)?; - let r = get_const(*right)?; - l.bitwise_xor(&r, ptr_size) - } - SsaOp::And { left, right, .. } => { - let l = get_const(*left)?; - let r = get_const(*right)?; - l.bitwise_and(&r, ptr_size) - } - SsaOp::Or { left, right, .. } => { - let l = get_const(*left)?; - let r = get_const(*right)?; - l.bitwise_or(&r, ptr_size) - } - - // Shifts - SsaOp::Shl { value, amount, .. } => { - let v = get_const(*value)?; - let a = get_const(*amount)?; - v.shl(&a, ptr_size) - } - SsaOp::Shr { - value, - amount, - unsigned, - .. - } => { - let v = get_const(*value)?; - let a = get_const(*amount)?; - v.shr(&a, *unsigned, ptr_size) - } - - // Unary - SsaOp::Neg { operand, .. } => { - let v = get_const(*operand)?; - v.negate(ptr_size) - } - SsaOp::Not { operand, .. } => { - let v = get_const(*operand)?; - v.bitwise_not(ptr_size) - } - - // Comparisons - SsaOp::Ceq { left, right, .. } => { - let l = get_const(*left)?; - let r = get_const(*right)?; - l.ceq(&r) - } - SsaOp::Clt { - left, - right, - unsigned, - .. - } => { - let l = get_const(*left)?; - let r = get_const(*right)?; - if *unsigned { - l.clt_un(&r) - } else { - l.clt(&r) - } - } - SsaOp::Cgt { - left, - right, - unsigned, - .. - } => { - let l = get_const(*left)?; - let r = get_const(*right)?; - if *unsigned { - l.cgt_un(&r) - } else { - l.cgt(&r) - } - } - - // Overflow-checked arithmetic - SsaOp::AddOvf { - left, - right, - unsigned, - .. - } => { - let l = get_const(*left)?; - let r = get_const(*right)?; - l.add_checked(&r, *unsigned, ptr_size) - } - SsaOp::SubOvf { - left, - right, - unsigned, - .. - } => { - let l = get_const(*left)?; - let r = get_const(*right)?; - l.sub_checked(&r, *unsigned, ptr_size) - } - SsaOp::MulOvf { - left, - right, - unsigned, - .. - } => { - let l = get_const(*left)?; - let r = get_const(*right)?; - l.mul_checked(&r, *unsigned, ptr_size) - } - - // Type conversion - SsaOp::Conv { - operand, - target, - overflow_check, - unsigned, - .. - } => { - let v = get_const(*operand)?; - if *overflow_check { - v.convert_to_checked(target, *unsigned, ptr_size) - } else { - v.convert_to(target, *unsigned, ptr_size) - } - } - - // All other operations cannot be evaluated to constants - _ => None, - } -} - -#[cfg(test)] -mod tests { - use crate::{ - analysis::ssa::{ - ConstEvaluator, ConstValue, DefSite, SsaBlock, SsaFunction, SsaInstruction, SsaOp, - SsaType, SsaVarId, VariableOrigin, - }, - metadata::typesystem::PointerSize, - }; - - #[test] - fn test_evaluate_constant() { - let mut ssa = SsaFunction::new(0, 0); - let mut block = SsaBlock::new(0); - - // Create a constant: v0 = 42 - let var_id = ssa.create_variable( - VariableOrigin::Local(0), - 0, - DefSite::instruction(0, 0), - SsaType::Unknown, - ); - - block.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: var_id, - value: ConstValue::I32(42), - })); - block.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - ssa.add_block(block); - - let mut evaluator = ConstEvaluator::new(&ssa, PointerSize::Bit64); - let result = evaluator.evaluate_var(var_id); - - assert_eq!(result, Some(ConstValue::I32(42))); - } - - #[test] - fn test_evaluate_copy_chain() { - let mut ssa = SsaFunction::new(0, 0); - let mut block = SsaBlock::new(0); - - // v0 = 100 - let v0_id = ssa.create_variable( - VariableOrigin::Local(0), - 0, - DefSite::instruction(0, 0), - SsaType::Unknown, - ); - - // v1 = v0 (copy) - let v1_id = ssa.create_variable( - VariableOrigin::Local(1), - 0, - DefSite::instruction(0, 1), - SsaType::Unknown, - ); - - block.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: v0_id, - value: ConstValue::I32(100), - })); - block.add_instruction(SsaInstruction::synthetic(SsaOp::Copy { - dest: v1_id, - src: v0_id, - })); - block.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - ssa.add_block(block); - - let mut evaluator = ConstEvaluator::new(&ssa, PointerSize::Bit64); - let result = evaluator.evaluate_var(v1_id); - - assert_eq!(result, Some(ConstValue::I32(100))); - } - - #[test] - fn test_set_known_value() { - let ssa = SsaFunction::new(0, 0); - let var_id = SsaVarId::from_index(0); - - let mut evaluator = ConstEvaluator::new(&ssa, PointerSize::Bit64); - evaluator.set_known(var_id, ConstValue::I32(999)); - - let result = evaluator.evaluate_var(var_id); - assert_eq!(result, Some(ConstValue::I32(999))); - } - - #[test] - fn test_into_results() { - let ssa = SsaFunction::new(0, 0); - let var1 = SsaVarId::from_index(0); - let var2 = SsaVarId::from_index(1); - - let mut evaluator = ConstEvaluator::new(&ssa, PointerSize::Bit64); - evaluator.set_known(var1, ConstValue::I32(1)); - evaluator.set_known(var2, ConstValue::I32(2)); - - // Force evaluation to populate cache - evaluator.evaluate_var(var1); - evaluator.evaluate_var(var2); - - let results = evaluator.into_results(); - assert_eq!(results.len(), 2); - assert_eq!(results.get(&var1), Some(&ConstValue::I32(1))); - assert_eq!(results.get(&var2), Some(&ConstValue::I32(2))); - } -} diff --git a/dotscope/src/analysis/ssa/converter.rs b/dotscope/src/analysis/ssa/converter.rs index 33858e0a..8d6b4fe7 100644 --- a/dotscope/src/analysis/ssa/converter.rs +++ b/dotscope/src/analysis/ssa/converter.rs @@ -40,7 +40,7 @@ use crate::{ analysis::{ cfg::ControlFlowGraph, ssa::{ - decompose::decompose_instruction, liveness, phis::place_pruned_phis, + decompose::decompose_instruction, liveness, place_pruned_phis, resolve_corelib_valuetype, ConstValue, DefSite, PhiNode, SimulationResult, SsaBlock, SsaFunction, SsaInstruction, SsaOp, SsaType, SsaVarId, StackSimulator, StackSlot, StackSlotSource, TypeProvider, UseSite, VariableOrigin, @@ -53,12 +53,12 @@ use crate::{ token::Token, typesystem::CilTypeReference, }, - utils::{ - graph::{algorithms::DominatorTree, NodeId}, - BitSet, - }, CilObject, Error, Result, }; +use analyssa::{ + graph::{algorithms::DominatorTree, NodeId}, + BitSet, +}; /// A variable definition record during SSA construction. #[derive(Debug, Clone)] @@ -344,6 +344,25 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { | SsaOp::Shr { .. } | SsaOp::Neg { .. } | SsaOp::Not { .. } + // Rotate operations + | SsaOp::Rol { .. } + | SsaOp::Ror { .. } + | SsaOp::Rcl { .. } + | SsaOp::Rcr { .. } + // Bit manipulation operations + | SsaOp::BSwap { .. } + | SsaOp::BRev { .. } + | SsaOp::BitScanForward { .. } + | SsaOp::BitScanReverse { .. } + | SsaOp::Popcount { .. } + | SsaOp::Parity { .. } + // Atomic read-modify-write operations + | SsaOp::CmpXchg { .. } + | SsaOp::AtomicRmw { .. } + // Select produces values of varying types + | SsaOp::Select { .. } + // ReadFlags reads flag bits into an integer + | SsaOp::ReadFlags { .. } // Sizeof produces int32 | SsaOp::SizeOf { .. } => SsaType::I32, @@ -493,6 +512,10 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { | SsaOp::Volatile | SsaOp::Unaligned { .. } | SsaOp::TailPrefix + | SsaOp::BranchFlags { .. } + | SsaOp::Fence { .. } + | SsaOp::InterruptReturn + | SsaOp::Unreachable | SsaOp::Readonly => SsaType::Unknown, } } @@ -2780,6 +2803,8 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { mod tests { use super::*; + use std::collections::BTreeSet; + use crate::{ assembly::{decode_blocks, InstructionAssembler}, test::TestTypeProvider, @@ -3402,8 +3427,7 @@ mod tests { .expect("SSA construction failed"); // For each phi node, verify all operand values reference valid variables - let var_ids: std::collections::BTreeSet<_> = - ssa.variables().iter().map(|v| v.id()).collect(); + let var_ids: BTreeSet<_> = ssa.variables().iter().map(|v| v.id()).collect(); for phi in ssa.all_phi_nodes() { for operand in phi.operands() { @@ -3903,8 +3927,8 @@ mod tests { .expect("SSA construction failed"); // Collect all variables used in the SSA - let mut all_uses = std::collections::BTreeSet::new(); - let mut all_defs = std::collections::BTreeSet::new(); + let mut all_uses = BTreeSet::new(); + let mut all_defs = BTreeSet::new(); for block in ssa.blocks() { for instr in block.instructions() { diff --git a/dotscope/src/analysis/ssa/decompose.rs b/dotscope/src/analysis/ssa/decompose.rs index 19f9ad1b..90a6fba6 100644 --- a/dotscope/src/analysis/ssa/decompose.rs +++ b/dotscope/src/analysis/ssa/decompose.rs @@ -217,16 +217,19 @@ fn decompose_standard_instruction( dest, left, right, + flags: None, }), // add 0x59 => binary_op(uses, def, |dest, left, right| SsaOp::Sub { dest, left, right, + flags: None, }), // sub 0x5A => binary_op(uses, def, |dest, left, right| SsaOp::Mul { dest, left, right, + flags: None, }), // mul 0x5B => binary_op(uses, def, |dest, left, right| SsaOp::Div { // div @@ -234,6 +237,7 @@ fn decompose_standard_instruction( left, right, unsigned: false, + flags: None, }), 0x5C => binary_op(uses, def, |dest, left, right| SsaOp::Div { // div.un @@ -241,6 +245,7 @@ fn decompose_standard_instruction( left, right, unsigned: true, + flags: None, }), 0x5D => binary_op(uses, def, |dest, left, right| SsaOp::Rem { // rem @@ -248,6 +253,7 @@ fn decompose_standard_instruction( left, right, unsigned: false, + flags: None, }), 0x5E => binary_op(uses, def, |dest, left, right| SsaOp::Rem { // rem.un @@ -255,10 +261,19 @@ fn decompose_standard_instruction( left, right, unsigned: true, + flags: None, }), - 0x65 => unary_op(uses, def, |dest, operand| SsaOp::Neg { dest, operand }), // neg - 0x66 => unary_op(uses, def, |dest, operand| SsaOp::Not { dest, operand }), // not + 0x65 => unary_op(uses, def, |dest, operand| SsaOp::Neg { + dest, + operand, + flags: None, + }), // neg + 0x66 => unary_op(uses, def, |dest, operand| SsaOp::Not { + dest, + operand, + flags: None, + }), // not // Overflow checking arithmetic 0xD6 => binary_op(uses, def, |dest, left, right| SsaOp::AddOvf { @@ -267,6 +282,7 @@ fn decompose_standard_instruction( left, right, unsigned: false, + flags: None, }), 0xD7 => binary_op(uses, def, |dest, left, right| SsaOp::AddOvf { // add.ovf.un @@ -274,6 +290,7 @@ fn decompose_standard_instruction( left, right, unsigned: true, + flags: None, }), 0xD8 => binary_op(uses, def, |dest, left, right| SsaOp::MulOvf { // mul.ovf @@ -281,6 +298,7 @@ fn decompose_standard_instruction( left, right, unsigned: false, + flags: None, }), 0xD9 => binary_op(uses, def, |dest, left, right| SsaOp::MulOvf { // mul.ovf.un @@ -288,6 +306,7 @@ fn decompose_standard_instruction( left, right, unsigned: true, + flags: None, }), 0xDA => binary_op(uses, def, |dest, left, right| SsaOp::SubOvf { // sub.ovf @@ -295,6 +314,7 @@ fn decompose_standard_instruction( left, right, unsigned: false, + flags: None, }), 0xDB => binary_op(uses, def, |dest, left, right| SsaOp::SubOvf { // sub.ovf.un @@ -302,27 +322,32 @@ fn decompose_standard_instruction( left, right, unsigned: true, + flags: None, }), 0x5F => binary_op(uses, def, |dest, left, right| SsaOp::And { dest, left, right, + flags: None, }), // and 0x60 => binary_op(uses, def, |dest, left, right| SsaOp::Or { dest, left, right, + flags: None, }), // or 0x61 => binary_op(uses, def, |dest, left, right| SsaOp::Xor { dest, left, right, + flags: None, }), // xor 0x62 => binary_op(uses, def, |dest, value, amount| SsaOp::Shl { // shl dest, value, amount, + flags: None, }), 0x63 => binary_op(uses, def, |dest, value, amount| SsaOp::Shr { // shr @@ -330,6 +355,7 @@ fn decompose_standard_instruction( value, amount, unsigned: false, + flags: None, }), 0x64 => binary_op(uses, def, |dest, value, amount| SsaOp::Shr { // shr.un @@ -337,6 +363,7 @@ fn decompose_standard_instruction( value, amount, unsigned: true, + flags: None, }), 0x67 => unary_op(uses, def, |dest, operand| SsaOp::Conv { // conv.i1 @@ -1689,7 +1716,13 @@ mod tests { let op = decompose_instruction(&instr, &uses, def, &[], None); assert!(op.is_ok()); - if let Ok(SsaOp::Add { dest, left, right }) = op { + if let Ok(SsaOp::Add { + dest, + left, + right, + flags: None, + }) = op + { assert_eq!(dest, v2); assert_eq!(left, v0); assert_eq!(right, v1); @@ -1936,7 +1969,7 @@ mod tests { let def = Some(v1); let op = decompose_instruction(&instr, &uses, def, &[], None); - assert!(matches!(op, Ok(SsaOp::Neg { .. }))); + assert!(matches!(op, Ok(SsaOp::Neg { flags: None, .. }))); } #[test] @@ -2055,7 +2088,7 @@ mod tests { let def = Some(v1); let op = decompose_instruction(&instr, &uses, def, &[], None); - assert!(matches!(op, Ok(SsaOp::Not { .. }))); + assert!(matches!(op, Ok(SsaOp::Not { flags: None, .. }))); } #[test] diff --git a/dotscope/src/analysis/ssa/evaluator.rs b/dotscope/src/analysis/ssa/evaluator.rs deleted file mode 100644 index 95b40956..00000000 --- a/dotscope/src/analysis/ssa/evaluator.rs +++ /dev/null @@ -1,3098 +0,0 @@ -//! SSA evaluator for computing values. -//! -//! This module provides an interpreter for SSA operations that can evaluate -//! arithmetic and logical operations given known input values. It supports: -//! -//! - **Concrete values**: Known integer constants (fast, direct evaluation) -//! - **Symbolic values**: Expressions depending on unknown inputs (enables Z3 solving) -//! - **Unknown values**: Values that cannot be determined statically (represented as `None`) -//! -//! # Use Cases -//! -//! - Control flow unflattening (computing state transitions) -//! - Constant propagation verification -//! - Opaque predicate detection -//! - Symbolic execution of small code fragments -//! -//! # Design -//! -//! The evaluator operates directly on SSA form without needing full CIL emulation -//! infrastructure. Values are represented as [`SymbolicExpr`], where: -//! - `SymbolicExpr::Constant(v)` represents a known concrete value -//! - Other `SymbolicExpr` variants represent symbolic expressions -//! - `None` (absence from the value map) represents unknown values -//! -//! # CIL Semantics -//! -//! All arithmetic operations use 32-bit wrapping semantics as per ECMA-335. -//! Values are stored as i64 for convenience, but operations intentionally -//! truncate to i32/u32 to match CLR behavior. -//! -//! # Path-Aware Evaluation -//! -//! The evaluator supports path-aware phi node evaluation. When traversing a specific -//! path through the CFG, use [`set_predecessor`](SsaEvaluator::set_predecessor) before -//! evaluating a block to select the correct phi operand. -//! -//! # Usage -//! -//! ```rust,ignore -//! use dotscope::analysis::{SsaEvaluator, SymbolicExpr}; -//! -//! let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); -//! -//! // Set known concrete values -//! eval.set_concrete(state_var, initial_state); -//! -//! // Or mark as symbolic -//! eval.set_symbolic(arg_var, "arg0"); -//! -//! // Evaluate a block's instructions -//! eval.evaluate_block(block_idx); -//! -//! // Get computed result -//! match eval.get(result_var) { -//! Some(expr) if expr.is_constant() => println!("Known: {}", expr.as_constant().unwrap()), -//! Some(expr) => println!("Symbolic: {}", expr), -//! None => println!("Cannot determine"), -//! } -//! -//! // Or use convenience method for concrete values -//! if let Some(next_state) = eval.get_concrete(result_var) { -//! println!("Next state: {}", next_state); -//! } -//! ``` - -use std::collections::BTreeMap; - -use crate::{ - analysis::ssa::{ - constraints::Constraint, - memory::{MemoryLocation, MemoryState}, - symbolic::{SymbolicExpr, SymbolicOp}, - CmpKind, ConstValue, SsaFunction, SsaOp, SsaType, SsaVarId, VariableOrigin, - }, - metadata::typesystem::PointerSize, -}; - -/// Result of evaluating a control flow decision. -/// -/// This represents the outcome of analyzing a terminator instruction to -/// determine the next block to execute. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum ControlFlow { - /// Continue to the specified block. - Continue(usize), - /// Terminal instruction - execution ends here (return, throw, etc.). - Terminal, - /// Cannot determine the next block - condition is unknown or symbolic. - Unknown, -} - -impl ControlFlow { - /// Returns the target block if this is a `Continue` result. - #[must_use] - pub fn target(&self) -> Option { - match self { - Self::Continue(block) => Some(*block), - _ => None, - } - } - - /// Returns `true` if this is a terminal result. - #[must_use] - pub fn is_terminal(&self) -> bool { - matches!(self, Self::Terminal) - } - - /// Returns `true` if the control flow cannot be determined. - #[must_use] - pub fn is_unknown(&self) -> bool { - matches!(self, Self::Unknown) - } -} - -/// Configuration for SSA evaluators. -/// -/// This struct controls the behavior of [`SsaEvaluator`], allowing it to be -/// configured for different use cases like general evaluation, path-aware -/// analysis, or CFF deobfuscation. -#[derive(Debug, Clone, Default)] -pub struct EvaluatorConfig { - /// Track the execution path (sequence of visited blocks). - pub track_path: bool, - /// Track memory state (field loads/stores). - pub track_memory: bool, - /// Require predecessor for phi evaluation (strict path-aware mode). - pub strict_phi: bool, -} - -impl EvaluatorConfig { - /// Creates a new default configuration. - #[must_use] - pub fn new() -> Self { - Self::default() - } - - /// Creates a configuration for path-aware analysis. - /// - /// Enables path tracking, memory tracking, and strict phi evaluation. - #[must_use] - pub fn path_aware() -> Self { - Self { - track_path: true, - track_memory: true, - strict_phi: true, - } - } - - /// Creates a configuration with memory tracking only. - #[must_use] - pub fn with_memory() -> Self { - Self { - track_path: false, - track_memory: true, - strict_phi: false, - } - } - - /// Enables path tracking. - #[must_use] - pub fn with_path_tracking(mut self) -> Self { - self.track_path = true; - self - } - - /// Enables memory state tracking. - #[must_use] - pub fn with_memory_tracking(mut self) -> Self { - self.track_memory = true; - self - } - - /// Enables strict phi evaluation (requires predecessor). - #[must_use] - pub fn with_strict_phi(mut self) -> Self { - self.strict_phi = true; - self - } -} - -/// Records the execution trace of SSA evaluation. -/// -/// This struct tracks the sequence of blocks visited during SSA evaluation, -/// along with optional state values at each step. This is essential for -/// CFF (Control Flow Flattening) deobfuscation, where we need to record -/// the dispatcher state transitions to reconstruct the original control flow. -#[derive(Debug, Clone)] -pub struct ExecutionTrace { - /// Sequence of block indices visited. - blocks: Vec, - /// Optional state values captured at each block (for state machines). - states: Vec>, - /// Whether execution completed normally (reached terminal). - completed: bool, - /// Maximum blocks to trace before stopping (prevents infinite loops). - limit: usize, -} - -impl ExecutionTrace { - /// Creates a new execution trace with the given block limit. - #[must_use] - pub fn new(limit: usize) -> Self { - Self { - blocks: Vec::new(), - states: Vec::new(), - completed: false, - limit, - } - } - - /// Returns the blocks visited during execution. - #[must_use] - pub fn blocks(&self) -> &[usize] { - &self.blocks - } - - /// Returns the state values captured during execution. - #[must_use] - pub fn states(&self) -> &[Option] { - &self.states - } - - /// Returns `true` if execution completed (reached a terminal instruction). - #[must_use] - pub fn is_complete(&self) -> bool { - self.completed - } - - /// Returns the number of blocks visited. - #[must_use] - pub fn len(&self) -> usize { - self.blocks.len() - } - - /// Returns `true` if no blocks were visited. - #[must_use] - pub fn is_empty(&self) -> bool { - self.blocks.is_empty() - } - - /// Returns the last visited block, if any. - #[must_use] - pub fn last_block(&self) -> Option { - self.blocks.last().copied() - } - - /// Returns `true` if the trace reached the block limit. - #[must_use] - pub fn hit_limit(&self) -> bool { - self.blocks.len() >= self.limit - } - - /// Records a block visit. - fn record_block(&mut self, block_idx: usize, state: Option) { - self.blocks.push(block_idx); - self.states.push(state); - } - - /// Marks execution as complete. - fn mark_complete(&mut self) { - self.completed = true; - } -} - -/// SSA evaluator with hybrid concrete/symbolic value tracking. -/// -/// This evaluator interprets SSA operations to compute values without needing -/// full CIL emulation. Values are represented as [`SymbolicExpr`]: -/// -/// - **Concrete**: `SymbolicExpr::Constant(v)` - Known integer values -/// - **Symbolic**: Other `SymbolicExpr` variants - Expressions depending on unknown inputs -/// - **Unknown**: `None` (not in the values map) - Values that cannot be determined -/// -/// # Value Representation -/// -/// Values are represented as `i64` internally to accommodate both 32-bit and 64-bit -/// integer operations. For 32-bit operations, the evaluator applies appropriate -/// wrapping/truncation semantics. -#[derive(Debug, Clone)] -pub struct SsaEvaluator<'a> { - /// Reference to the SSA function being evaluated. - ssa: &'a SsaFunction, - /// Tracked values for variables. Missing entries represent unknown values. - values: BTreeMap, - /// Current predecessor block for path-aware phi evaluation. - /// When set, phi nodes will select the operand from this predecessor. - predecessor: Option, - /// Constraints on variable values derived from branch conditions. - /// Used to detect dead code and propagate information after branches. - constraints: BTreeMap>, - /// Evaluator configuration controlling behavior. - config: EvaluatorConfig, - /// Execution path (sequence of visited blocks). Only populated if `config.track_path`. - path: Vec, - /// Memory state tracking for fields. Only used if `config.track_memory`. - memory: MemoryState, - /// Target pointer size for native int/uint masking. - pointer_size: PointerSize, - /// Current value of CIL local variables, indexed by local_index. - /// Updated whenever a variable with `Local(N)` origin receives a value, - /// and read by `LoadLocal` instructions. - local_state: BTreeMap, - /// Current value of CIL arguments, indexed by arg_index. - /// Updated whenever a variable with `Argument(N)` origin receives a value, - /// and read by `LoadArg` instructions. - arg_state: BTreeMap, -} - -impl<'a> SsaEvaluator<'a> { - /// Creates a new evaluator for the given SSA function. - /// - /// The evaluator starts with no known values. Use the `set_*` methods - /// to provide initial values for input variables. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to evaluate. - /// * `ptr_size` - Target pointer size for native int/uint masking. - #[must_use] - pub fn new(ssa: &'a SsaFunction, ptr_size: PointerSize) -> Self { - Self::with_config(ssa, EvaluatorConfig::default(), ptr_size) - } - - /// Creates an evaluator with the specified configuration. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to evaluate. - /// * `config` - Configuration controlling evaluator behavior. - /// * `ptr_size` - Target pointer size for native int/uint masking. - #[must_use] - pub fn with_config( - ssa: &'a SsaFunction, - config: EvaluatorConfig, - ptr_size: PointerSize, - ) -> Self { - Self { - ssa, - values: BTreeMap::new(), - predecessor: None, - constraints: BTreeMap::new(), - config, - path: Vec::new(), - memory: MemoryState::new(), - pointer_size: ptr_size, - local_state: BTreeMap::new(), - arg_state: BTreeMap::new(), - } - } - - /// Creates a path-aware evaluator with memory tracking. - /// - /// This is equivalent to `PathAwareEvaluator::with_memory_tracking()` and is - /// the recommended configuration for CFF deobfuscation. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to evaluate. - /// * `ptr_size` - Target pointer size for native int/uint masking. - #[must_use] - pub fn path_aware(ssa: &'a SsaFunction, ptr_size: PointerSize) -> Self { - Self::with_config(ssa, EvaluatorConfig::path_aware(), ptr_size) - } - - /// Creates an evaluator with pre-populated concrete values. - /// - /// Useful when you already have a set of known constants from SCCP or - /// other analyses. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to evaluate. - /// * `values` - Pre-populated concrete values. - /// * `ptr_size` - Target pointer size for native int/uint masking. - #[must_use] - pub fn with_values( - ssa: &'a SsaFunction, - values: BTreeMap, - ptr_size: PointerSize, - ) -> Self { - let exprs = values - .into_iter() - .map(|(k, v)| (k, SymbolicExpr::constant(v))) - .collect(); - Self { - ssa, - values: exprs, - predecessor: None, - constraints: BTreeMap::new(), - config: EvaluatorConfig::default(), - path: Vec::new(), - memory: MemoryState::new(), - pointer_size: ptr_size, - local_state: BTreeMap::new(), - arg_state: BTreeMap::new(), - } - } - - /// Returns the target pointer size. - #[must_use] - pub fn pointer_size(&self) -> PointerSize { - self.pointer_size - } - - /// Returns a reference to the underlying SSA function. - #[must_use] - pub fn ssa(&self) -> &SsaFunction { - self.ssa - } - - /// Returns a reference to the evaluator configuration. - #[must_use] - pub fn config(&self) -> &EvaluatorConfig { - &self.config - } - - /// Returns the execution path if path tracking is enabled. - #[must_use] - pub fn path(&self) -> &[usize] { - &self.path - } - - /// Clears the recorded execution path. - pub fn clear_path(&mut self) { - self.path.clear(); - } - - /// Returns whether memory tracking is enabled. - #[must_use] - pub fn memory_tracking_enabled(&self) -> bool { - self.config.track_memory - } - - // Value Setting - - /// Sets a concrete (known) value for a variable. - /// - /// The caller is responsible for providing the correct `ConstValue` type - /// that matches the variable's type in the SSA function. - pub fn set_concrete(&mut self, var: SsaVarId, value: ConstValue) { - let expr = SymbolicExpr::constant(value); - self.values.insert(var, expr.clone()); - self.track_origin_state(var, &expr); - } - - /// Sets a symbolic value for a variable using a named expression. - /// - /// This is useful for marking method arguments or other external inputs - /// as symbolic with descriptive names. - pub fn set_symbolic(&mut self, var: SsaVarId, name: impl Into) { - let expr = SymbolicExpr::named(name); - self.values.insert(var, expr.clone()); - self.track_origin_state(var, &expr); - } - - /// Sets a symbolic value for a variable using an expression. - pub fn set_symbolic_expr(&mut self, var: SsaVarId, expr: SymbolicExpr) { - self.values.insert(var, expr.clone()); - self.track_origin_state(var, &expr); - } - - /// Sets a variable as unknown by removing it from the values map. - pub fn set_unknown(&mut self, var: SsaVarId) { - self.values.remove(&var); - } - - /// Sets an expression for a variable. - pub fn set(&mut self, var: SsaVarId, value: SymbolicExpr) { - self.values.insert(var, value.clone()); - self.track_origin_state(var, &value); - } - - // Value Getting - - /// Gets the expression for a variable. - /// - /// Returns `None` if the variable hasn't been assigned a value (unknown). - #[must_use] - pub fn get(&self, var: SsaVarId) -> Option<&SymbolicExpr> { - self.values.get(&var) - } - - /// Gets the typed constant value for a variable, if it's a constant. - /// - /// Returns `None` if the variable is symbolic, unknown, or not set. - /// Use [`ConstValue`] methods to extract specific types (e.g., `as_i64()`, `as_i32()`). - #[must_use] - pub fn get_concrete(&self, var: SsaVarId) -> Option<&ConstValue> { - self.values.get(&var).and_then(SymbolicExpr::as_constant) - } - - /// Gets the symbolic expression for a variable, if it's not a constant. - #[must_use] - pub fn get_symbolic(&self, var: SsaVarId) -> Option<&SymbolicExpr> { - self.values.get(&var).filter(|e| !e.is_constant()) - } - - /// Checks if a variable has a concrete (constant) value. - #[must_use] - pub fn is_concrete(&self, var: SsaVarId) -> bool { - self.values.get(&var).is_some_and(SymbolicExpr::is_constant) - } - - /// Checks if a variable has a symbolic (non-constant) value. - #[must_use] - pub fn is_symbolic(&self, var: SsaVarId) -> bool { - self.values.get(&var).is_some_and(|e| !e.is_constant()) - } - - /// Checks if a variable is unknown (not in the values map). - #[must_use] - pub fn is_unknown(&self, var: SsaVarId) -> bool { - !self.values.contains_key(&var) - } - - /// Returns all tracked values as expressions. - #[must_use] - pub fn values(&self) -> &BTreeMap { - &self.values - } - - /// Returns all concrete values as a map of i64 values. - /// - /// This is useful for compatibility with code that expects `HashMap`. - /// Values that can't be converted to i64 are skipped. - #[must_use] - pub fn concrete_values(&self) -> BTreeMap { - self.values - .iter() - .filter_map(|(k, v)| v.as_i64().map(|c| (*k, c))) - .collect() - } - - /// Returns all concrete values as typed `ConstValue`. - #[must_use] - pub fn const_values(&self) -> BTreeMap { - self.values - .iter() - .filter_map(|(k, v)| v.as_constant().map(|c| (*k, c.clone()))) - .collect() - } - - /// Clears all tracked values. - pub fn clear(&mut self) { - self.values.clear(); - self.predecessor = None; - self.constraints.clear(); - } - - // Constraint Management - - /// Adds a constraint on a variable. - /// - /// If the constraint is an equality constraint, also sets the variable's value - /// to concrete. This allows constraint propagation to directly affect evaluation. - pub fn add_constraint(&mut self, var: SsaVarId, constraint: Constraint) { - // If it's an equality constraint, we can directly set the value - if let Constraint::Equal(ref v) = constraint { - let expr = SymbolicExpr::constant(v.clone()); - self.values.insert(var, expr.clone()); - self.track_origin_state(var, &expr); - } - - self.constraints.entry(var).or_default().push(constraint); - } - - /// Gets all constraints on a variable. - #[must_use] - pub fn constraints(&self, var: SsaVarId) -> &[Constraint] { - self.constraints.get(&var).map_or(&[], |v| v.as_slice()) - } - - /// Checks if a variable has any constraints. - #[must_use] - pub fn has_constraints(&self, var: SsaVarId) -> bool { - self.constraints.get(&var).is_some_and(|v| !v.is_empty()) - } - - /// Clears constraints for a specific variable. - pub fn clear_constraints(&mut self, var: SsaVarId) { - self.constraints.remove(&var); - } - - /// Applies constraints derived from taking a specific branch. - /// - /// When we know which branch was taken, we can derive facts about the condition - /// variable. For example, if we took the true branch of `if (ceq x, 5)`, we know x == 5. - /// - /// # Arguments - /// - /// * `condition` - The variable used as the branch condition - /// * `took_true_branch` - Whether we followed the true or false branch - /// - /// # Returns - /// - /// `true` if constraints were successfully derived, `false` otherwise. - pub fn apply_branch_constraint(&mut self, condition: SsaVarId, took_true_branch: bool) -> bool { - // Find the definition of the condition variable to understand what comparison it represents - let Some(ssa_var) = self.ssa.variable(condition) else { - return false; - }; - - let def_site = ssa_var.def_site(); - let Some(block) = self.ssa.block(def_site.block) else { - return false; - }; - - let Some(instr_idx) = def_site.instruction else { - return false; - }; - - let Some(instr) = block.instruction(instr_idx) else { - return false; - }; - - self.derive_constraints_from_comparison(instr.op(), took_true_branch) - } - - /// Derives constraints from a comparison operation. - fn derive_constraints_from_comparison(&mut self, op: &SsaOp, took_true_branch: bool) -> bool { - match op { - SsaOp::Ceq { left, right, .. } => { - // ceq: true branch means left == right, false means left != right - let left_val = self.get(*left).cloned(); - let right_val = self.get(*right).cloned(); - - if took_true_branch { - // left == right - match (&left_val, &right_val) { - (Some(l), None) => { - if let Some(v) = l.as_constant() { - self.add_constraint(*right, Constraint::Equal(v.clone())); - true - } else { - false - } - } - (None, Some(r)) => { - if let Some(v) = r.as_constant() { - self.add_constraint(*left, Constraint::Equal(v.clone())); - true - } else { - false - } - } - (Some(l), Some(r)) if l.as_constant() == r.as_constant() => { - // Both concrete and equal - constraint is satisfied - true - } - _ => false, - } - } else { - // left != right - match (&left_val, &right_val) { - (Some(l), None) => { - if let Some(v) = l.as_constant() { - self.add_constraint(*right, Constraint::NotEqual(v.clone())); - true - } else { - false - } - } - (None, Some(r)) => { - if let Some(v) = r.as_constant() { - self.add_constraint(*left, Constraint::NotEqual(v.clone())); - true - } else { - false - } - } - _ => false, - } - } - } - - SsaOp::Cgt { - left, - right, - unsigned, - .. - } => { - // cgt: true branch means left > right - let right_val = self.get(*right).and_then(|e| e.as_constant().cloned()); - - if took_true_branch { - // left > right - if let Some(v) = right_val { - if *unsigned { - self.add_constraint(*left, Constraint::GreaterThanUnsigned(v)); - } else { - self.add_constraint(*left, Constraint::GreaterThan(v)); - } - return true; - } - } else { - // left <= right - if let Some(v) = right_val { - self.add_constraint(*left, Constraint::LessOrEqual(v)); - return true; - } - } - false - } - - SsaOp::Clt { - left, - right, - unsigned, - .. - } => { - // clt: true branch means left < right - let right_val = self.get(*right).and_then(|e| e.as_constant().cloned()); - - if took_true_branch { - // left < right - if let Some(v) = right_val { - if *unsigned { - self.add_constraint(*left, Constraint::LessThanUnsigned(v)); - } else { - self.add_constraint(*left, Constraint::LessThan(v)); - } - return true; - } - } else { - // left >= right - if let Some(v) = right_val { - self.add_constraint(*left, Constraint::GreaterOrEqual(v)); - return true; - } - } - false - } - - _ => false, - } - } - - /// Checks if the current constraints imply that a condition is always true or false. - /// - /// This is useful for detecting dead code after branch conditions. - /// - /// # Returns - /// - /// - `Some(true)` if the condition is always true given current constraints - /// - `Some(false)` if the condition is always false given current constraints - /// - `None` if the condition cannot be determined - #[must_use] - pub fn evaluate_condition_with_constraints(&self, condition: SsaVarId) -> Option { - if let Some(v) = self.get_concrete(condition) { - return Some(!v.is_zero()); - } - - // Check if constraints imply a value - // For now, we handle the case where we have conflicting constraints - // which would indicate dead code - let ssa_var = self.ssa.variable(condition)?; - let def_site = ssa_var.def_site(); - let block = self.ssa.block(def_site.block)?; - let instr_idx = def_site.instruction?; - let instr = block.instruction(instr_idx)?; - self.check_condition_against_constraints(instr.op()) - } - - /// Checks if a comparison's result can be determined from constraints. - fn check_condition_against_constraints(&self, op: &SsaOp) -> Option { - match op { - SsaOp::Ceq { left, right, .. } => { - // Check if we know both operands are equal or not equal - let left_constraints = self.constraints(*left); - let right_val = self.get_concrete(*right)?; - - for constraint in left_constraints { - match constraint { - Constraint::Equal(v) => { - // v == right_val means ceq is true - return Some(v.ceq(right_val).is_some_and(|r| !r.is_zero())); - } - Constraint::NotEqual(v) - // If v == right_val, then left != right_val, so ceq is false - if v.ceq(right_val).is_some_and(|r| !r.is_zero()) => - { - return Some(false); - } - Constraint::GreaterThan(v) - // left > v, so if right_val <= v, then left != right_val - if right_val.cgt(v).is_none_or(|r| r.is_zero()) => - { - return Some(false); - } - Constraint::LessThan(v) - // left < v, so if right_val >= v, then left != right_val - if right_val.clt(v).is_none_or(|r| r.is_zero()) => - { - return Some(false); - } - _ => {} - } - } - None - } - - SsaOp::Cgt { left, right, .. } => { - let left_constraints = self.constraints(*left); - let right_val = self.get_concrete(*right)?; - - for constraint in left_constraints { - match constraint { - Constraint::GreaterThan(v) - // left > v, so if v >= right_val, then left > right_val - if v.clt(right_val).is_none_or(|r| r.is_zero()) => - { - return Some(true); - } - Constraint::LessOrEqual(v) - // left <= v, so if v <= right_val, then left <= right_val - if v.cgt(right_val).is_none_or(|r| r.is_zero()) => - { - return Some(false); - } - Constraint::LessThan(v) - // left < v, so if v <= right_val, then left < right_val <= right_val - if v.cgt(right_val).is_none_or(|r| r.is_zero()) => - { - return Some(false); - } - Constraint::Equal(v) => { - // left == v, so return v > right_val - return Some(v.cgt(right_val).is_some_and(|r| !r.is_zero())); - } - _ => {} - } - } - None - } - - SsaOp::Clt { left, right, .. } => { - let left_constraints = self.constraints(*left); - let right_val = self.get_concrete(*right)?; - - for constraint in left_constraints { - match constraint { - Constraint::LessThan(v) - // left < v, so if v <= right_val, then left < right_val - if v.cgt(right_val).is_none_or(|r| r.is_zero()) => - { - return Some(true); - } - Constraint::GreaterOrEqual(v) - // left >= v, so if v >= right_val, then left >= right_val - if v.clt(right_val).is_none_or(|r| r.is_zero()) => - { - return Some(false); - } - Constraint::GreaterThan(v) - // left > v, so if v >= right_val, then left > right_val >= right_val - if v.clt(right_val).is_none_or(|r| r.is_zero()) => - { - return Some(false); - } - Constraint::Equal(v) => { - // left == v, so return v < right_val - return Some(v.clt(right_val).is_some_and(|r| !r.is_zero())); - } - _ => {} - } - } - None - } - - _ => None, - } - } - - // Path-Aware Evaluation - - /// Sets the predecessor block for path-aware phi evaluation. - /// - /// When evaluating a block, phi nodes will select the operand that - /// corresponds to this predecessor. This enables accurate evaluation - /// when following a specific path through the CFG. - /// - /// # Arguments - /// - /// * `pred` - The predecessor block index, or `None` to clear. - pub fn set_predecessor(&mut self, pred: Option) { - self.predecessor = pred; - } - - /// Gets the current predecessor for phi evaluation. - #[must_use] - pub fn predecessor(&self) -> Option { - self.predecessor - } - - // Block Evaluation - - /// Evaluates all phi nodes in a block. - /// - /// **REQUIRES** a predecessor to be set via [`set_predecessor`](Self::set_predecessor). - /// If no predecessor is set, phi results will be `None` (removed from value map). - /// - /// # Phi Node Semantics - /// - /// Phi nodes execute "simultaneously" - all source values are read BEFORE any - /// results are written. This is critical for correct swap semantics: - /// - /// ```text - /// v1 = phi(v2 from pred) - /// v2 = phi(v1 from pred) - /// ``` - /// - /// This swaps v1 and v2. If we wrote v1 before reading for v2, we'd get the - /// wrong value. The implementation uses a two-phase approach: - /// 1. Read all source values into a temporary buffer - /// 2. Write all results from the buffer - pub fn evaluate_phis(&mut self, block_idx: usize) { - let Some(block) = self.ssa.block(block_idx) else { - return; - }; - - // Phase 1: Read all phi source values BEFORE any writes - // This ensures correct "simultaneous" phi semantics (no swap problem) - let phi_results: Vec<(SsaVarId, Option)> = block - .phi_nodes() - .iter() - .map(|phi| { - let result = phi.result(); - // REQUIRE predecessor - no fallback, no merging - let value = self.predecessor.and_then(|pred| { - phi.operands() - .iter() - .find(|op| op.predecessor() == pred) - .and_then(|op| self.values.get(&op.value()).cloned()) - }); - (result, value) - }) - .collect(); - - // Phase 2: Write all results - for (result, value) in phi_results { - if let Some(v) = value { - self.values.insert(result, v.clone()); - self.track_origin_state(result, &v); - } else { - // No predecessor or no operand from predecessor = no value - self.values.remove(&result); - } - } - } - - /// Evaluates all instructions in a block, updating tracked values. - /// - /// This evaluates phi nodes first (if predecessor is set), then - /// evaluates all other instructions in order. - pub fn evaluate_block(&mut self, block_idx: usize) { - // Record path if tracking is enabled - if self.config.track_path { - self.path.push(block_idx); - } - - // First evaluate phi nodes - self.evaluate_phis(block_idx); - - // Then evaluate instructions - let Some(block) = self.ssa.block(block_idx) else { - return; - }; - - for instr in block.instructions() { - self.evaluate_op(instr.op()); - } - } - - /// Evaluates a sequence of blocks in order. - /// - /// This is useful for evaluating a path through the CFG. - /// Note: This does not set predecessors automatically. - pub fn evaluate_blocks(&mut self, block_indices: &[usize]) { - for &block_idx in block_indices { - self.evaluate_block(block_idx); - } - } - - /// Evaluates a sequence of blocks along a path. - /// - /// For each block after the first, sets the predecessor to the previous - /// block before evaluation. This enables accurate phi node evaluation. - pub fn evaluate_path(&mut self, path: &[usize]) { - for (i, &block_idx) in path.iter().enumerate() { - if i > 0 { - if let Some(&prev) = path.get(i.saturating_sub(1)) { - self.set_predecessor(Some(prev)); - } - } - self.evaluate_block(block_idx); - } - } - - // Fixed-Point Iteration for Loops - - /// Evaluates a loop until values reach a fixed point. - /// - /// This is useful for analyzing loops where variable values may change each - /// iteration until they stabilize. The method iterates up to `max_iterations` - /// times, or until all tracked values stop changing. - /// - /// # Arguments - /// - /// * `loop_blocks` - The blocks that form the loop body (in execution order) - /// * `max_iterations` - Maximum number of iterations before giving up - /// - /// # Returns - /// - /// The number of iterations performed before reaching fixed point (or max). - pub fn evaluate_loop_to_fixpoint( - &mut self, - loop_blocks: &[usize], - max_iterations: usize, - ) -> usize { - if loop_blocks.is_empty() { - return 0; - } - - for iteration in 0..max_iterations { - // Snapshot current values - let snapshot: BTreeMap = self.values.clone(); - - // Evaluate all loop blocks - for (i, &block_idx) in loop_blocks.iter().enumerate() { - if i > 0 { - if let Some(&prev) = loop_blocks.get(i.saturating_sub(1)) { - self.set_predecessor(Some(prev)); - } - } else if loop_blocks.len() > 1 { - // First block - predecessor is the last block (loop back edge) - if let Some(&last) = loop_blocks.last() { - self.set_predecessor(Some(last)); - } - } - self.evaluate_block(block_idx); - } - - // Check if values changed - if self.values_match(&snapshot) { - return iteration.saturating_add(1); - } - } - - // Didn't reach fixed point - mark variables that changed as widened - self.widen_unstable_values(loop_blocks); - max_iterations - } - - /// Checks if current values match a snapshot. - fn values_match(&self, snapshot: &BTreeMap) -> bool { - self.values == *snapshot - } - - /// Widens values that didn't stabilize in a loop to Unknown. - /// - /// This is called when fixed-point iteration doesn't converge. Variables - /// defined in loop blocks that still have different values are marked Unknown. - fn widen_unstable_values(&mut self, loop_blocks: &[usize]) { - // Find all variables defined in the loop - for &block_idx in loop_blocks { - let Some(block) = self.ssa.block(block_idx) else { - continue; - }; - - // Mark phi results as unknown (they depend on loop iteration) - for phi in block.phi_nodes() { - self.values.remove(&phi.result()); - } - - // Check instructions for variables that might not have stabilized - for instr in block.instructions() { - // If this op defines a variable, consider widening it - if let Some(dest) = instr.op().dest() { - // Keep concrete values if they're stable, widen symbolic to unknown - if let Some(expr) = self.values.get(&dest) { - if !expr.is_constant() { - // Symbolic values that didn't stabilize become unknown - self.values.remove(&dest); - } - } - } - } - } - } - - /// Evaluates a loop with a specific iteration count. - /// - /// This is useful when you know exactly how many times a loop should run - /// (e.g., from a constant loop bound). - pub fn evaluate_loop_iterations(&mut self, loop_blocks: &[usize], iterations: usize) { - for _ in 0..iterations { - for (i, &block_idx) in loop_blocks.iter().enumerate() { - if i > 0 { - if let Some(&prev) = loop_blocks.get(i.saturating_sub(1)) { - self.set_predecessor(Some(prev)); - } - } - self.evaluate_block(block_idx); - } - } - } - - /// Evaluates a single SSA operation, updating tracked values. - /// - /// Returns the computed expression for operations that produce a result, - /// or `None` for operations without results (stores, branches, etc.) or - /// when the result is unknown. - pub fn evaluate_op(&mut self, op: &SsaOp) -> Option { - match op { - SsaOp::Const { dest, value } => { - let expr = SymbolicExpr::constant(value.clone()); - self.values.insert(*dest, expr.clone()); - self.track_origin_state(*dest, &expr); - Some(expr) - } - - SsaOp::Copy { dest, src } => { - let value = self.values.get(src).cloned(); - if let Some(v) = value { - self.values.insert(*dest, v.clone()); - self.track_origin_state(*dest, &v); - Some(v) - } else { - self.values.remove(dest); - None - } - } - - SsaOp::LoadLocal { dest, local_index } => { - let value = self.local_state.get(local_index).cloned(); - if let Some(v) = value { - self.values.insert(*dest, v.clone()); - Some(v) - } else { - self.values.remove(dest); - None - } - } - - SsaOp::LoadArg { dest, arg_index } => { - let value = self.arg_state.get(arg_index).cloned(); - if let Some(v) = value { - self.values.insert(*dest, v.clone()); - Some(v) - } else { - self.values.remove(dest); - None - } - } - - SsaOp::Add { dest, left, right } => { - self.eval_binary_op(*dest, *left, *right, SymbolicOp::Add) - } - - SsaOp::Sub { dest, left, right } => { - self.eval_binary_op(*dest, *left, *right, SymbolicOp::Sub) - } - - SsaOp::Mul { dest, left, right } => { - self.eval_binary_op(*dest, *left, *right, SymbolicOp::Mul) - } - - SsaOp::Div { - dest, - left, - right, - unsigned, - } => { - let op = if *unsigned { - SymbolicOp::DivU - } else { - SymbolicOp::DivS - }; - self.eval_binary_op(*dest, *left, *right, op) - } - - SsaOp::Rem { - dest, - left, - right, - unsigned, - } => { - let op = if *unsigned { - SymbolicOp::RemU - } else { - SymbolicOp::RemS - }; - self.eval_binary_op(*dest, *left, *right, op) - } - - SsaOp::Xor { dest, left, right } => { - self.eval_binary_op(*dest, *left, *right, SymbolicOp::Xor) - } - - SsaOp::And { dest, left, right } => { - self.eval_binary_op(*dest, *left, *right, SymbolicOp::And) - } - - SsaOp::Or { dest, left, right } => { - self.eval_binary_op(*dest, *left, *right, SymbolicOp::Or) - } - - SsaOp::Shl { - dest, - value, - amount, - } => self.eval_binary_op(*dest, *value, *amount, SymbolicOp::Shl), - - SsaOp::Shr { - dest, - value, - amount, - unsigned, - } => { - let op = if *unsigned { - SymbolicOp::ShrU - } else { - SymbolicOp::ShrS - }; - self.eval_binary_op(*dest, *value, *amount, op) - } - - SsaOp::Neg { dest, operand } => self.eval_unary_op(*dest, *operand, SymbolicOp::Neg), - - SsaOp::Not { dest, operand } => self.eval_unary_op(*dest, *operand, SymbolicOp::Not), - - SsaOp::Ceq { dest, left, right } => { - self.eval_binary_op(*dest, *left, *right, SymbolicOp::Eq) - } - - SsaOp::Cgt { - dest, - left, - right, - unsigned, - } => { - let op = if *unsigned { - SymbolicOp::GtU - } else { - SymbolicOp::GtS - }; - self.eval_binary_op(*dest, *left, *right, op) - } - - SsaOp::Clt { - dest, - left, - right, - unsigned, - } => { - let op = if *unsigned { - SymbolicOp::LtU - } else { - SymbolicOp::LtS - }; - self.eval_binary_op(*dest, *left, *right, op) - } - - SsaOp::Conv { - dest, - operand, - target, - unsigned, - .. - } => { - let value = self.values.get(operand).cloned(); - if let Some(expr) = value { - if let Some(v) = expr.as_i64() { - // Apply proper truncation/extension and store with correct type - let converted = self.apply_conversion(v, target, *unsigned); - let result = SymbolicExpr::constant(converted); - self.values.insert(*dest, result.clone()); - self.track_origin_state(*dest, &result); - Some(result) - } else { - // Symbolic/Unknown pass through (conversions don't change symbolic structure) - self.values.insert(*dest, expr.clone()); - self.track_origin_state(*dest, &expr); - Some(expr) - } - } else { - self.values.remove(dest); - None - } - } - - // Operations with Option dest that produce unknown results - SsaOp::Call { dest, .. } - | SsaOp::CallVirt { dest, .. } - | SsaOp::CallIndirect { dest, .. } => { - if let Some(d) = dest { - self.values.remove(d); - } - None - } - - // Memory operations (when tracking is enabled) - SsaOp::LoadStaticField { dest, field } => { - if self.config.track_memory { - let location = MemoryLocation::StaticField(*field); - if let Some(stored_var) = self.memory.load(&location) { - // Propagate the stored value - if let Some(expr) = self.values.get(&stored_var).cloned() { - self.values.insert(*dest, expr.clone()); - return Some(expr); - } - self.values - .insert(*dest, SymbolicExpr::variable(stored_var)); - return Some(SymbolicExpr::variable(stored_var)); - } - } - self.values.remove(dest); - None - } - - SsaOp::StoreStaticField { value, field } => { - if self.config.track_memory { - let location = MemoryLocation::StaticField(*field); - // Use 0 as version for simple tracking (version not critical for evaluation) - self.memory.store(location, *value, 0); - } - None - } - - SsaOp::LoadField { - dest, - object, - field, - } => { - if self.config.track_memory { - let location = MemoryLocation::InstanceField(*object, *field); - if let Some(stored_var) = self.memory.load(&location) { - if let Some(expr) = self.values.get(&stored_var).cloned() { - self.values.insert(*dest, expr.clone()); - return Some(expr); - } - self.values - .insert(*dest, SymbolicExpr::variable(stored_var)); - return Some(SymbolicExpr::variable(stored_var)); - } - } - self.values.remove(dest); - None - } - - SsaOp::StoreField { - object, - field, - value, - } => { - if self.config.track_memory { - let location = MemoryLocation::InstanceField(*object, *field); - self.memory.store(location, *value, 0); - } - None - } - - // Address-of-local/arg: the address is taken, meaning external code - // can write to this local/arg through the pointer. Invalidate the - // tracked state so subsequent LoadLocal/LoadArg returns Unknown - // instead of the stale pre-address-taken value. - // - // This is critical for patterns like `Monitor.Enter(obj, ref bool)` - // where the CLR writes `true` to the bool via the by-reference - // parameter. Without invalidation, the evaluator would see the - // initial value (false/0) and incorrectly fold branches that check - // the lock flag. - SsaOp::LoadLocalAddr { - dest, local_index, .. - } => { - self.values.remove(dest); - self.local_state.remove(local_index); - None - } - SsaOp::LoadArgAddr { - dest, arg_index, .. - } => { - self.values.remove(dest); - self.arg_state.remove(arg_index); - None - } - - // Operations with SsaVarId dest that produce unknown results - SsaOp::NewObj { dest, .. } - | SsaOp::NewArr { dest, .. } - | SsaOp::LoadElement { dest, .. } - | SsaOp::LoadIndirect { dest, .. } - | SsaOp::Box { dest, .. } - | SsaOp::Unbox { dest, .. } - | SsaOp::UnboxAny { dest, .. } - | SsaOp::CastClass { dest, .. } - | SsaOp::IsInst { dest, .. } - | SsaOp::ArrayLength { dest, .. } - | SsaOp::LoadToken { dest, .. } - | SsaOp::SizeOf { dest, .. } - | SsaOp::Ckfinite { dest, .. } - | SsaOp::LocalAlloc { dest, .. } - | SsaOp::LoadFunctionPtr { dest, .. } - | SsaOp::LoadVirtFunctionPtr { dest, .. } - | SsaOp::LoadFieldAddr { dest, .. } - | SsaOp::LoadStaticFieldAddr { dest, .. } - | SsaOp::LoadElementAddr { dest, .. } - | SsaOp::LoadObj { dest, .. } => { - self.values.remove(dest); - None - } - - // Operations without results (stores, branches, etc.) - _ => None, - } - } - - /// Helper to evaluate a binary operation. - fn eval_binary_op( - &mut self, - dest: SsaVarId, - left: SsaVarId, - right: SsaVarId, - op: SymbolicOp, - ) -> Option { - let left_expr = self.values.get(&left)?; - let right_expr = self.values.get(&right)?; - - // Build expression and simplify (handles constant folding automatically) - let result = SymbolicExpr::binary(op, left_expr.clone(), right_expr.clone()) - .simplify(self.pointer_size); - - // Mask native int/uint results to target pointer width - let result = self.mask_symbolic_native(result); - - self.values.insert(dest, result.clone()); - self.track_origin_state(dest, &result); - Some(result) - } - - /// Helper to evaluate a unary operation. - fn eval_unary_op( - &mut self, - dest: SsaVarId, - operand: SsaVarId, - op: SymbolicOp, - ) -> Option { - let operand_expr = self.values.get(&operand)?; - - // Build expression and simplify - let result = SymbolicExpr::unary(op, operand_expr.clone()).simplify(self.pointer_size); - - // Mask native int/uint results to target pointer width - let result = self.mask_symbolic_native(result); - - self.values.insert(dest, result.clone()); - self.track_origin_state(dest, &result); - Some(result) - } - - /// Updates local/arg state when a variable with Local(N) or Argument(N) - /// origin receives a value. This enables `LoadLocal`/`LoadArg` to read - /// the most recently written value for the corresponding CIL local/arg. - fn track_origin_state(&mut self, var: SsaVarId, value: &SymbolicExpr) { - if let Some(ssa_var) = self.ssa.variable(var) { - match ssa_var.origin() { - VariableOrigin::Local(idx) => { - self.local_state.insert(idx, value.clone()); - } - VariableOrigin::Argument(idx) => { - self.arg_state.insert(idx, value.clone()); - } - _ => {} - } - } - } - - /// Masks a `SymbolicExpr` constant to the target pointer width if it contains - /// a `NativeInt` or `NativeUInt` value. - fn mask_symbolic_native(&self, expr: SymbolicExpr) -> SymbolicExpr { - if let Some(cv) = expr.as_constant() { - match cv { - ConstValue::NativeInt(_) | ConstValue::NativeUInt(_) => { - SymbolicExpr::constant(cv.clone().mask_native(self.pointer_size)) - } - _ => expr, - } - } else { - expr - } - } - - /// Applies a CIL type conversion to a value, returning the properly typed ConstValue. - /// - /// This handles truncation and sign/zero extension according to ECMA-335 semantics, - /// and returns a ConstValue with the correct type variant. - #[allow( - clippy::cast_possible_truncation, - clippy::cast_sign_loss, - clippy::cast_possible_wrap - )] - fn apply_conversion(&self, value: i64, target: &SsaType, unsigned: bool) -> ConstValue { - match target { - SsaType::I8 => { - if unsigned { - ConstValue::I8((value as u8) as i8) - } else { - ConstValue::I8(value as i8) - } - } - SsaType::U8 | SsaType::Bool => ConstValue::U8(value as u8), - SsaType::I16 => { - if unsigned { - ConstValue::I16((value as u16) as i16) - } else { - ConstValue::I16(value as i16) - } - } - SsaType::U16 => ConstValue::U16(value as u16), - SsaType::I32 => { - if unsigned { - ConstValue::I32((value as u32) as i32) - } else { - ConstValue::I32(value as i32) - } - } - SsaType::U32 => ConstValue::U32(value as u32), - SsaType::NativeInt => match self.pointer_size { - PointerSize::Bit32 => { - if unsigned { - ConstValue::NativeInt(i64::from((value as u32) as i32)) - } else { - ConstValue::NativeInt(i64::from(value as i32)) - } - } - PointerSize::Bit64 => ConstValue::NativeInt(value), - }, - SsaType::NativeUInt => match self.pointer_size { - PointerSize::Bit32 => ConstValue::NativeUInt(u64::from(value as u32)), - PointerSize::Bit64 => ConstValue::NativeUInt(value as u64), - }, - SsaType::U64 => ConstValue::U64(value as u64), - // Safe: precision loss is acceptable for integer-to-float conversion - #[allow(clippy::cast_precision_loss)] - SsaType::F32 => { - let float_val = if unsigned { - (value as u64) as f32 - } else { - value as f32 - }; - ConstValue::F32(float_val) - } - // Safe: precision loss is acceptable for integer-to-float conversion - #[allow(clippy::cast_precision_loss)] - SsaType::F64 => { - let float_val = if unsigned { - (value as u64) as f64 - } else { - value as f64 - }; - ConstValue::F64(float_val) - } - // For other types, default to I64 - _ => ConstValue::I64(value), - } - } - - /// Tries to resolve a variable's value by tracing back through its definition. - /// - /// This is useful when a variable's value depends on earlier computations - /// that haven't been evaluated yet. It recursively evaluates dependencies. - /// - /// # Arguments - /// - /// * `var` - The variable to resolve - /// * `max_depth` - Maximum recursion depth to prevent infinite loops - pub fn resolve_with_trace(&mut self, var: SsaVarId, max_depth: usize) -> Option { - // Already known? - if let Some(v) = self.values.get(&var) { - return Some(v.clone()); - } - - if max_depth == 0 { - return None; - } - - // Find the definition of this variable - let ssa_var = self.ssa.variable(var)?; - let def_site = ssa_var.def_site(); - let block = self.ssa.block(def_site.block)?; - // Is it defined by a phi node? Without path context, it's unknown - let instr_idx = def_site.instruction?; - let instr = block.instruction(instr_idx)?; - let op = instr.op(); - - // Recursively resolve operands first - for operand in op.uses() { - if !self.values.contains_key(&operand) { - if let Some(resolved) = - self.resolve_with_trace(operand, max_depth.saturating_sub(1)) - { - self.values.insert(operand, resolved); - } - } - } - - // Now evaluate this operation - self.evaluate_op(op) - } - - /// Tries to evaluate a variable by tracing back through its definition. - /// - /// Alias for [`resolve_with_trace`](Self::resolve_with_trace) that returns - /// `Option` for API compatibility. - pub fn evaluate_with_trace(&mut self, var: SsaVarId, max_depth: usize) -> Option { - self.resolve_with_trace(var, max_depth) - .and_then(|e| e.as_i64()) - } - - /// Determines the next block to execute based on the terminator of the given block. - /// - /// This is the core method for control flow analysis. It evaluates the terminating - /// instruction of a block and determines which block(s) execution should continue to. - /// - /// # Returns - /// - /// - `ControlFlow::Continue(block)` - Continue to the specified block - /// - `ControlFlow::Terminal` - No successor (return, throw, etc.) - /// - `ControlFlow::Unknown` - Cannot determine (condition is unknown/symbolic) - /// - /// # Example - /// - /// ```rust,ignore - /// let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - /// eval.set_concrete(state_var, initial_state); - /// eval.evaluate_block(0); - /// - /// match eval.next_block(0) { - /// ControlFlow::Continue(next) => { /* continue to next */ } - /// ControlFlow::Terminal => { /* execution ends */ } - /// ControlFlow::Unknown => { /* cannot determine */ } - /// } - /// ``` - #[must_use] - pub fn next_block(&self, block_idx: usize) -> ControlFlow { - let Some(block) = self.ssa.block(block_idx) else { - return ControlFlow::Unknown; - }; - - // Find the terminating instruction - let terminator = block - .instructions() - .iter() - .rev() - .find(|instr| instr.op().is_terminator()); - - let Some(instr) = terminator else { - // No terminator - fall through to next block if it exists - let next_idx = block_idx.saturating_add(1); - if next_idx < self.ssa.block_count() { - return ControlFlow::Continue(next_idx); - } - return ControlFlow::Unknown; - }; - - self.evaluate_control_flow(instr.op()) - } - - /// Evaluates a control flow operation to determine the next block. - /// - /// Uses typed `ConstValue` operations for comparisons and truthiness checks. - fn evaluate_control_flow(&self, op: &SsaOp) -> ControlFlow { - match op { - // Unconditional jumps - SsaOp::Jump { target } | SsaOp::Leave { target } => ControlFlow::Continue(*target), - - // Conditional branch (bool condition) - SsaOp::Branch { - condition, - true_target, - false_target, - } => match self.get_concrete(*condition) { - Some(v) => { - // Non-zero is true in CIL - if v.is_zero() { - ControlFlow::Continue(*false_target) - } else { - ControlFlow::Continue(*true_target) - } - } - None => ControlFlow::Unknown, - }, - - // Compare and branch - SsaOp::BranchCmp { - left, - right, - cmp, - unsigned, - true_target, - false_target, - } => { - let left_val = self.get_concrete(*left); - let right_val = self.get_concrete(*right); - - match (left_val, right_val) { - (Some(l), Some(r)) => { - let result = Self::evaluate_comparison(l, r, *cmp, *unsigned); - if result { - ControlFlow::Continue(*true_target) - } else { - ControlFlow::Continue(*false_target) - } - } - _ => ControlFlow::Unknown, - } - } - - // Switch - needs a non-negative integer index - SsaOp::Switch { - value, - targets, - default, - } => match self.get_concrete(*value).and_then(ConstValue::as_u64) { - Some(v) => { - #[allow(clippy::cast_possible_truncation)] - let idx = v as usize; - if let Some(&target) = targets.get(idx) { - ControlFlow::Continue(target) - } else { - ControlFlow::Continue(*default) - } - } - None => ControlFlow::Unknown, - }, - - // Terminal instructions - SsaOp::Return { .. } - | SsaOp::Throw { .. } - | SsaOp::Rethrow - | SsaOp::EndFinally - | SsaOp::EndFilter { .. } => ControlFlow::Terminal, - - // Not a control flow operation - _ => ControlFlow::Unknown, - } - } - - /// Evaluates a comparison between two typed constant values. - /// - /// Uses the typed comparison methods on `ConstValue` which properly - /// handle signedness based on the operand types. - pub fn evaluate_comparison( - left: &ConstValue, - right: &ConstValue, - cmp: CmpKind, - unsigned: bool, - ) -> bool { - match cmp { - CmpKind::Eq => left.ceq(right).is_some_and(|v| !v.is_zero()), - CmpKind::Ne => left.ceq(right).is_some_and(|v| v.is_zero()), - CmpKind::Lt => if unsigned { - left.clt_un(right) - } else { - left.clt(right) - } - .is_some_and(|v| !v.is_zero()), - CmpKind::Le => { - // x <= y is !(x > y) - if unsigned { - left.cgt_un(right) - } else { - left.cgt(right) - } - .is_some_and(|v| v.is_zero()) - } - CmpKind::Gt => if unsigned { - left.cgt_un(right) - } else { - left.cgt(right) - } - .is_some_and(|v| !v.is_zero()), - CmpKind::Ge => { - // x >= y is !(x < y) - if unsigned { - left.clt_un(right) - } else { - left.clt(right) - } - .is_some_and(|v| v.is_zero()) - } - } - } - - /// Executes the SSA function starting from a given block and records the trace. - /// - /// This method steps through the SSA, evaluating each block and following - /// control flow decisions based on computed values. It records the sequence - /// of blocks visited and optionally captures state values at each step. - /// - /// # Arguments - /// - /// * `start_block` - The block to start execution from - /// * `state_var` - Optional variable to capture state values (for CFF analysis) - /// * `max_steps` - Maximum number of blocks to visit (prevents infinite loops) - /// - /// # Returns - /// - /// An [`ExecutionTrace`] containing the visited blocks and state values. - /// - /// # Example - /// - /// ```rust,ignore - /// let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - /// eval.set_concrete(state_var, initial_state); - /// - /// let trace = eval.execute(0, Some(state_var), 1000); - /// for (block, state) in trace.blocks().iter().zip(trace.states()) { - /// println!("Block {}: state = {:?}", block, state); - /// } - /// ``` - pub fn execute( - &mut self, - start_block: usize, - state_var: Option, - max_steps: usize, - ) -> ExecutionTrace { - let mut trace = ExecutionTrace::new(max_steps); - let mut current_block = start_block; - - loop { - // Check if we've hit the limit - if trace.hit_limit() { - break; - } - - // Record the current state before evaluation - let state = state_var.and_then(|v| self.get_concrete(v).cloned()); - trace.record_block(current_block, state); - - // Set predecessor for phi evaluation - if let Some(prev) = trace.blocks().iter().rev().nth(1) { - self.set_predecessor(Some(*prev)); - } - - // Evaluate the block - self.evaluate_block(current_block); - - // Determine next block - match self.next_block(current_block) { - ControlFlow::Continue(next) => { - current_block = next; - } - ControlFlow::Terminal => { - trace.mark_complete(); - break; - } - ControlFlow::Unknown => { - // Can't determine next block - stop execution - break; - } - } - } - - trace - } - - /// Executes starting from block 0 with default settings. - /// - /// This is a convenience method for simple cases where you want to execute - /// from the entry block without state tracking. - pub fn execute_from_entry(&mut self, max_steps: usize) -> ExecutionTrace { - self.execute(0, None, max_steps) - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::analysis::ssa::{PhiNode, PhiOperand, SsaBlock, SsaInstruction, VariableOrigin}; - use crate::analysis::{SsaFunctionBuilder, SsaType}; - - #[test] - fn test_const_evaluation() { - let (ssa, v0) = { - let mut v0_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - v0_out = b.const_i32(42); - b.ret(); - }); - }) - .unwrap(); - (ssa, v0_out) - }; - - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - eval.evaluate_block(0); - - assert_eq!(eval.get_concrete(v0).and_then(ConstValue::as_i32), Some(42)); - } - - #[test] - fn test_add_evaluation() { - let (ssa, v2) = { - let mut v2_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(10); - let v1 = b.const_i32(32); - v2_out = b.add(v0, v1); - b.ret(); - }); - }) - .unwrap(); - (ssa, v2_out) - }; - - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - eval.evaluate_block(0); - - assert_eq!(eval.get_concrete(v2).and_then(ConstValue::as_i32), Some(42)); - } - - #[test] - fn test_xor_mul_pattern() { - // Test the typical ConfuserEx-style state computation: - // next_state = (current_state * mul_const) ^ xor_const - let (ssa, current_state, next_state) = { - let mut current_state_out = SsaVarId::from_index(0); - let mut next_state_out = SsaVarId::from_index(1); - let ssa = SsaFunctionBuilder::new(1, 0) - .build_with(|f| { - // Use an argument as current_state (so it's "external" input) - let current_state = f.arg(0, SsaType::I32); - current_state_out = current_state; - f.block(0, |b| { - let mul_const = b.const_i32(785121953); - let xor_const = b.const_i32(-934590555); - let mul_result = b.mul(current_state, mul_const); - next_state_out = b.xor(mul_result, xor_const); - b.ret(); - }); - }) - .unwrap(); - (ssa, current_state_out, next_state_out) - }; - - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - // Set the input state (this would come from the dispatcher) - eval.set_concrete(current_state, ConstValue::I32(120931986)); // The XORed state value - - eval.evaluate_block(0); - - // Verify we can compute the next state - assert!(eval.get_concrete(next_state).is_some()); - } - - #[test] - fn test_rem_un_evaluation() { - let (ssa, v2) = { - let mut v2_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(120931986); // Some large positive number - let v1 = b.const_i32(13); - v2_out = b.rem_un(v0, v1); - b.ret(); - }); - }) - .unwrap(); - (ssa, v2_out) - }; - - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - eval.evaluate_block(0); - - // 120931986 % 13 = 6 - assert_eq!(eval.get_concrete(v2).and_then(ConstValue::as_i32), Some(6)); - } - - #[test] - fn test_wrapping_mul() { - let (ssa, v2) = { - let mut v2_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - // Test overflow wrapping - let v0 = b.const_i32(i32::MAX); - let v1 = b.const_i32(2); - v2_out = b.mul(v0, v1); - b.ret(); - }); - }) - .unwrap(); - (ssa, v2_out) - }; - - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - eval.evaluate_block(0); - - // Should wrap around - assert_eq!( - eval.get_concrete(v2).and_then(ConstValue::as_i32), - Some(i32::MAX.wrapping_mul(2)) - ); - } - - #[test] - fn test_comparison_ops() { - let (ssa, ceq_result, clt_result, cgt_result) = { - let mut ceq_out = SsaVarId::from_index(0); - let mut clt_out = SsaVarId::from_index(1); - let mut cgt_out = SsaVarId::from_index(2); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(5); - let v1 = b.const_i32(10); - ceq_out = b.ceq(v0, v1); - clt_out = b.clt(v0, v1); - cgt_out = b.cgt(v0, v1); - b.ret(); - }); - }) - .unwrap(); - (ssa, ceq_out, clt_out, cgt_out) - }; - - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - eval.evaluate_block(0); - - assert_eq!( - eval.get_concrete(ceq_result).and_then(ConstValue::as_i32), - Some(0) - ); // 5 != 10 - assert_eq!( - eval.get_concrete(clt_result).and_then(ConstValue::as_i32), - Some(1) - ); // 5 < 10 - assert_eq!( - eval.get_concrete(cgt_result).and_then(ConstValue::as_i32), - Some(0) - ); // 5 !> 10 - } - - #[test] - fn test_set_value_manual() { - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| b.ret()); - }) - .unwrap(); - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - - let v0 = SsaVarId::from_index(0); - eval.set_concrete(v0, ConstValue::I32(12345)); - - assert_eq!( - eval.get_concrete(v0).and_then(ConstValue::as_i32), - Some(12345) - ); - assert_eq!( - eval.get_concrete(v0).and_then(ConstValue::as_i32), - Some(12345) - ); - } - - #[test] - fn test_unknown_operand_returns_unknown() { - let (ssa, v2) = { - let mut v2_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(1, 0) - .build_with(|f| { - // Use an argument as unknown variable (it won't have a value set) - let unknown = f.arg(0, SsaType::I32); - f.block(0, |b| { - let v1 = b.const_i32(10); - v2_out = b.add(unknown, v1); - b.ret(); - }); - }) - .unwrap(); - (ssa, v2_out) - }; - - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - eval.evaluate_block(0); - - // Result should be unknown (not set) - assert!(eval.is_unknown(v2)); - assert_eq!(eval.get_concrete(v2), None); - } - - #[test] - fn test_symbolic_evaluation() { - let (ssa, arg0, v2) = { - let mut arg0_out = SsaVarId::from_index(0); - let mut v2_out = SsaVarId::from_index(1); - let ssa = SsaFunctionBuilder::new(1, 0) - .build_with(|f| { - arg0_out = f.arg(0, SsaType::I32); - f.block(0, |b| { - let v1 = b.const_i32(10); - v2_out = b.add(arg0_out, v1); - b.ret(); - }); - }) - .unwrap(); - (ssa, arg0_out, v2_out) - }; - - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - // Mark the argument as symbolic - eval.set_symbolic(arg0, "arg0"); - eval.evaluate_block(0); - - // Result should be symbolic (arg0 + 10) - assert!(eval.is_symbolic(v2)); - let expr = eval.get(v2).unwrap(); - // Check the expression contains our named variable - assert!(format!("{}", expr).contains("arg0")); - } - - #[test] - fn test_xor_rem_pattern() { - // Test the ConfuserEx-style dispatch computation: - // switch_idx = (state ^ xor_const) % modulo - let (ssa, state_var, result_var) = { - let mut state_out = SsaVarId::from_index(0); - let mut result_out = SsaVarId::from_index(1); - let ssa = SsaFunctionBuilder::new(1, 0) - .build_with(|f| { - state_out = f.arg(0, SsaType::I32); - f.block(0, |b| { - let xor_const = b.const_i32(-557527955); - let modulo = b.const_i32(13); - let xored = b.xor(state_out, xor_const); - result_out = b.rem_un(xored, modulo); - b.ret(); - }); - }) - .unwrap(); - (ssa, state_out, result_out) - }; - - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - // Mark state as symbolic - eval.set_symbolic(state_var, "state"); - eval.evaluate_block(0); - - // Result should be symbolic - assert!(eval.is_symbolic(result_var)); - - // Now with concrete value - let mut eval2 = SsaEvaluator::new(&ssa, PointerSize::Bit64); - eval2.set_concrete(state_var, ConstValue::I32(-638481665_i32)); - eval2.evaluate_block(0); - - // Result should be concrete and correct - assert!(eval2.is_concrete(result_var)); - assert_eq!( - eval2.get_concrete(result_var).and_then(ConstValue::as_i32), - Some(6) - ); - } - - #[test] - fn test_mixed_operations() { - // Test: (arg0 * const1) ^ const2 - let (ssa, arg0, result) = { - let mut arg0_out = SsaVarId::from_index(0); - let mut result_out = SsaVarId::from_index(1); - let ssa = SsaFunctionBuilder::new(1, 0) - .build_with(|f| { - arg0_out = f.arg(0, SsaType::I32); - f.block(0, |b| { - let const1 = b.const_i32(785121953); - let const2 = b.const_i32(-934590555); - let mul_result = b.mul(arg0_out, const1); - result_out = b.xor(mul_result, const2); - b.ret(); - }); - }) - .unwrap(); - (ssa, arg0_out, result_out) - }; - - // With symbolic input - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - eval.set_symbolic(arg0, "state"); - eval.evaluate_block(0); - assert!(eval.is_symbolic(result)); - - // With concrete input - should produce concrete result - let mut eval2 = SsaEvaluator::new(&ssa, PointerSize::Bit64); - eval2.set_concrete(arg0, ConstValue::I32(120931986)); - eval2.evaluate_block(0); - assert!(eval2.is_concrete(result)); - } - - #[test] - fn test_with_values_constructor() { - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| b.ret()); - }) - .unwrap(); - - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let mut initial = BTreeMap::new(); - initial.insert(v0, ConstValue::I64(42)); - initial.insert(v1, ConstValue::I64(100)); - - let eval = SsaEvaluator::with_values(&ssa, initial, PointerSize::Bit64); - - assert_eq!(eval.get_concrete(v0).and_then(ConstValue::as_i64), Some(42)); - assert_eq!( - eval.get_concrete(v1).and_then(ConstValue::as_i64), - Some(100) - ); - } - - #[test] - fn test_concrete_values_extraction() { - let (ssa, v0, v1) = { - let mut v0_out = SsaVarId::from_index(0); - let mut v1_out = SsaVarId::from_index(1); - let ssa = SsaFunctionBuilder::new(1, 0) - .build_with(|f| { - let arg0 = f.arg(0, SsaType::I32); - v0_out = arg0; - f.block(0, |b| { - v1_out = b.const_i32(42); - b.ret(); - }); - }) - .unwrap(); - (ssa, v0_out, v1_out) - }; - - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - eval.set_symbolic(v0, "arg0"); // symbolic - eval.evaluate_block(0); // v1 = 42 (concrete) - - let concrete = eval.concrete_values(); - assert!(!concrete.contains_key(&v0)); // symbolic not included - assert_eq!(concrete.get(&v1), Some(&42)); // concrete included - } - - #[test] - fn test_conversion_truncation() { - // Test that conversions properly truncate values - let (ssa, v1) = { - let mut v1_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(256); // Larger than byte - v1_out = b.conv_un(v0, SsaType::U8); - b.ret(); - }); - }) - .unwrap(); - (ssa, v1_out) - }; - - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - eval.evaluate_block(0); - - // 256 truncated to u8 should be 0 - assert_eq!(eval.get_concrete(v1).and_then(ConstValue::as_i32), Some(0)); - } - - #[test] - fn test_conversion_sign_extension() { - // Test that signed conversions sign-extend - let (ssa, v1) = { - let mut v1_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(255); // 0xFF as u8, -1 as i8 - v1_out = b.conv(v0, SsaType::I8); - b.ret(); - }); - }) - .unwrap(); - (ssa, v1_out) - }; - - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - eval.evaluate_block(0); - - // 255 as i8 is -1, sign-extended to i32 - assert_eq!(eval.get_concrete(v1).and_then(ConstValue::as_i32), Some(-1)); - } - - #[test] - fn test_constraint_equal() { - // Test that Equal constraint propagates value - let (ssa, arg0) = { - let mut arg0_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(1, 0) - .build_with(|f| { - arg0_out = f.arg(0, SsaType::I32); - f.block(0, |b| { - b.ret(); - }); - }) - .unwrap(); - (ssa, arg0_out) - }; - - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - assert!(eval.is_unknown(arg0)); - - // Add equality constraint - eval.add_constraint(arg0, Constraint::Equal(ConstValue::I32(42))); - - // Now we should know the value - assert!(eval.is_concrete(arg0)); - assert_eq!( - eval.get_concrete(arg0).and_then(ConstValue::as_i32), - Some(42) - ); - } - - #[test] - fn test_constraint_conflicts() { - // Test constraint conflict detection - let c1 = Constraint::Equal(ConstValue::I32(5)); - let c2 = Constraint::Equal(ConstValue::I32(10)); - assert!(c1.conflicts_with(&c2, PointerSize::Bit64)); - - let c3 = Constraint::NotEqual(ConstValue::I32(5)); - assert!(c1.conflicts_with(&c3, PointerSize::Bit64)); - - let c4 = Constraint::GreaterThan(ConstValue::I32(5)); - assert!(c1.conflicts_with(&c4, PointerSize::Bit64)); // 5 is not > 5 - - let c5 = Constraint::GreaterThan(ConstValue::I32(4)); - assert!(!c1.conflicts_with(&c5, PointerSize::Bit64)); // 5 > 4 is ok - } - - #[test] - fn test_evaluate_loop_to_fixpoint() { - // Test loop fixed-point iteration with a simple loop structure: - // B0: entry, jump to B1 - // B1: header with computation, branch to B2 or B3 - // B2: body, jump back to B1 (back edge) - // B3: exit, ret - - let (ssa, v0, v1) = { - let mut v0_out = SsaVarId::from_index(0); - let mut v1_out = SsaVarId::from_index(1); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - // B0: entry with initial value - f.block(0, |b| { - v0_out = b.const_i32(0); - b.jump(1); - }); - // B1: header with conditional - f.block(1, |b| { - // Increment the value (simulating an induction variable) - let one = b.const_i32(1); - v1_out = b.add(v0_out, one); - let cond = b.const_true(); - b.branch(cond, 2, 3); - }); - // B2: body, jump back - f.block(2, |b| b.jump(1)); - // B3: exit - f.block(3, |b| b.ret()); - }) - .unwrap(); - (ssa, v0_out, v1_out) - }; - - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - - // First evaluate entry block to establish initial values - eval.evaluate_block(0); - assert_eq!(eval.get_concrete(v0).and_then(ConstValue::as_i32), Some(0)); - - // Now evaluate loop blocks with fixed-point iteration - // Loop body is blocks 1 and 2 - let iterations = eval.evaluate_loop_to_fixpoint(&[1, 2], 5); - - // Should terminate (either reaching fixed point or max iterations) - assert!(iterations > 0); - assert!(iterations <= 5); - - // Value v1 should have been computed (0 + 1 = 1) - assert_eq!(eval.get_concrete(v1).and_then(ConstValue::as_i32), Some(1)); - } - - #[test] - fn test_evaluate_loop_to_fixpoint_empty() { - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| b.ret()); - }) - .unwrap(); - - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - - // Empty loop blocks should return 0 iterations - let iterations = eval.evaluate_loop_to_fixpoint(&[], 10); - assert_eq!(iterations, 0); - } - - #[test] - fn test_phi_simple_path_aware() { - // Test basic phi evaluation with predecessor - // - // B0: v0 = 10, jump B2 - // B1: v1 = 20, jump B2 - // B2: v2 = phi(v0 from B0, v1 from B1), ret - // - // Coming from B0, v2 should be 10 - // Coming from B1, v2 should be 20 - - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let v2 = SsaVarId::from_index(2); - - let mut ssa = SsaFunction::new(0, 0); - - // Block 0: v0 = 10, jump B2 - let mut b0 = SsaBlock::new(0); - b0.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: v0, - value: ConstValue::I32(10), - })); - b0.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 2 })); - ssa.add_block(b0); - - // Block 1: v1 = 20, jump B2 - let mut b1 = SsaBlock::new(1); - b1.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: v1, - value: ConstValue::I32(20), - })); - b1.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 2 })); - ssa.add_block(b1); - - // Block 2: v2 = phi(v0 from B0, v1 from B1), ret - let mut b2 = SsaBlock::new(2); - let mut phi = PhiNode::new(v2, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(v0, 0)); - phi.add_operand(PhiOperand::new(v1, 1)); - b2.add_phi(phi); - b2.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: Some(v2) })); - ssa.add_block(b2); - - // Evaluate B0 first to set v0 - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - eval.evaluate_block(0); - assert_eq!(eval.get_concrete(v0).and_then(ConstValue::as_i32), Some(10)); - - // Now evaluate B2 coming from B0 - eval.set_predecessor(Some(0)); - eval.evaluate_block(2); - assert_eq!( - eval.get_concrete(v2).and_then(ConstValue::as_i32), - Some(10), - "v2 should be 10 when coming from B0" - ); - - // Fresh evaluator: evaluate B1 first to set v1 - let mut eval2 = SsaEvaluator::new(&ssa, PointerSize::Bit64); - eval2.evaluate_block(1); - assert_eq!( - eval2.get_concrete(v1).and_then(ConstValue::as_i32), - Some(20) - ); - - // Now evaluate B2 coming from B1 - eval2.set_predecessor(Some(1)); - eval2.evaluate_block(2); - assert_eq!( - eval2.get_concrete(v2).and_then(ConstValue::as_i32), - Some(20), - "v2 should be 20 when coming from B1" - ); - } - - #[test] - fn test_phi_swap_semantics() { - // CRITICAL TEST: Verify phi nodes execute "simultaneously" - // - // This is the "swap problem" - phi nodes must read all values BEFORE - // writing any results. - // - // B0: v1 = 10, v2 = 20, jump B1 - // B1: v1' = phi(v2 from B0), v2' = phi(v1 from B0), ret - // - // After evaluating B1 coming from B0: - // v1' should be 20 (original v2) - // v2' should be 10 (original v1) - // - // If phi nodes executed sequentially (bug), we'd get: - // v1' = 20 (correct) - // v2' = 20 (WRONG - read v1' instead of v1) - - let v1 = SsaVarId::from_index(0); - let v2 = SsaVarId::from_index(1); - let v1_prime = SsaVarId::from_index(2); - let v2_prime = SsaVarId::from_index(3); - - let mut ssa = SsaFunction::new(0, 0); - - // Block 0: v1 = 10, v2 = 20, jump B1 - let mut b0 = SsaBlock::new(0); - b0.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: v1, - value: ConstValue::I32(10), - })); - b0.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: v2, - value: ConstValue::I32(20), - })); - b0.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 1 })); - ssa.add_block(b0); - - // Block 1: phi nodes that swap v1 and v2 - let mut b1 = SsaBlock::new(1); - // v1' = phi(v2 from B0) - reads v2 - let mut phi1 = PhiNode::new(v1_prime, VariableOrigin::Local(0)); - phi1.add_operand(PhiOperand::new(v2, 0)); - // v2' = phi(v1 from B0) - reads v1 - let mut phi2 = PhiNode::new(v2_prime, VariableOrigin::Local(1)); - phi2.add_operand(PhiOperand::new(v1, 0)); - b1.add_phi(phi1); - b1.add_phi(phi2); - b1.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - ssa.add_block(b1); - - // Evaluate B0 to set v1=10, v2=20 - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - eval.evaluate_block(0); - assert_eq!(eval.get_concrete(v1).and_then(ConstValue::as_i32), Some(10)); - assert_eq!(eval.get_concrete(v2).and_then(ConstValue::as_i32), Some(20)); - - // Now evaluate B1 coming from B0 - this should swap the values - eval.set_predecessor(Some(0)); - eval.evaluate_block(1); - - // CRITICAL ASSERTIONS: values should be swapped - assert_eq!( - eval.get_concrete(v1_prime).and_then(ConstValue::as_i32), - Some(20), - "v1' should be 20 (swapped from v2)" - ); - assert_eq!( - eval.get_concrete(v2_prime).and_then(ConstValue::as_i32), - Some(10), - "v2' should be 10 (swapped from v1)" - ); - } - - #[test] - fn test_phi_triple_rotate() { - // Test a three-way rotation: a, b, c = c, a, b - // - // B0: a = 1, b = 2, c = 3, jump B1 - // B1: a' = phi(c), b' = phi(a), c' = phi(b), ret - // - // After: a' = 3, b' = 1, c' = 2 - - let a = SsaVarId::from_index(0); - let b = SsaVarId::from_index(1); - let c = SsaVarId::from_index(2); - let a_prime = SsaVarId::from_index(3); - let b_prime = SsaVarId::from_index(4); - let c_prime = SsaVarId::from_index(5); - - let mut ssa = SsaFunction::new(0, 0); - - // Block 0: a = 1, b = 2, c = 3 - let mut blk0 = SsaBlock::new(0); - blk0.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: a, - value: ConstValue::I32(1), - })); - blk0.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: b, - value: ConstValue::I32(2), - })); - blk0.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: c, - value: ConstValue::I32(3), - })); - blk0.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 1 })); - ssa.add_block(blk0); - - // Block 1: rotate a, b, c = c, a, b - let mut blk1 = SsaBlock::new(1); - let mut phi_a = PhiNode::new(a_prime, VariableOrigin::Local(0)); - phi_a.add_operand(PhiOperand::new(c, 0)); - blk1.add_phi(phi_a); - let mut phi_b = PhiNode::new(b_prime, VariableOrigin::Local(1)); - phi_b.add_operand(PhiOperand::new(a, 0)); - blk1.add_phi(phi_b); - let mut phi_c = PhiNode::new(c_prime, VariableOrigin::Local(2)); - phi_c.add_operand(PhiOperand::new(b, 0)); - blk1.add_phi(phi_c); - blk1.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - ssa.add_block(blk1); - - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - eval.evaluate_block(0); - eval.set_predecessor(Some(0)); - eval.evaluate_block(1); - - assert_eq!( - eval.get_concrete(a_prime).and_then(ConstValue::as_i32), - Some(3), - "a' should be 3 (from c)" - ); - assert_eq!( - eval.get_concrete(b_prime).and_then(ConstValue::as_i32), - Some(1), - "b' should be 1 (from a)" - ); - assert_eq!( - eval.get_concrete(c_prime).and_then(ConstValue::as_i32), - Some(2), - "c' should be 2 (from b)" - ); - } - - #[test] - fn test_phi_self_reference_blocked() { - // Test that phi reading from itself doesn't cause issues - // (This would be a malformed SSA but we should handle it gracefully) - // - // B0: v1 = 10, jump B1 - // B1: v2 = phi(v1 from B0, v2 from B1), ret - // - // Coming from B0: v2 = 10 - // Coming from B1: v2 should remain as its current value (or unknown if not set) - - let v1 = SsaVarId::from_index(0); - let v2 = SsaVarId::from_index(1); - - let mut ssa = SsaFunction::new(0, 0); - - // Block 0: v1 = 10, jump B1 - let mut blk0 = SsaBlock::new(0); - blk0.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: v1, - value: ConstValue::I32(10), - })); - blk0.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 1 })); - ssa.add_block(blk0); - - // Block 1: v2 = phi(v1 from B0, v2 from B1) - let mut blk1 = SsaBlock::new(1); - let mut phi = PhiNode::new(v2, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(v1, 0)); - phi.add_operand(PhiOperand::new(v2, 1)); - blk1.add_phi(phi); - blk1.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: Some(v2) })); - ssa.add_block(blk1); - - // Coming from B0 - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - eval.evaluate_block(0); - eval.set_predecessor(Some(0)); - eval.evaluate_block(1); - assert_eq!( - eval.get_concrete(v2).and_then(ConstValue::as_i32), - Some(10), - "v2 should be 10 from B0" - ); - - // Coming from B1 (self-reference) - v2 should keep its value - eval.set_predecessor(Some(1)); - eval.evaluate_block(1); - assert_eq!( - eval.get_concrete(v2).and_then(ConstValue::as_i32), - Some(10), - "v2 should still be 10 (self-reference)" - ); - } - - #[test] - fn test_phi_merge_same_values() { - // Test phi evaluation with predecessor set (Phase 1: no phi merging without predecessor) - // - // B0: v0 = 42, jump B2 - // B1: v1 = 42, jump B2 - // B2: v2 = phi(v0 from B0, v1 from B1), ret - // - // With predecessor set to B0, should get 42 from v0 - - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let v2 = SsaVarId::from_index(2); - - let mut ssa = SsaFunction::new(0, 0); - - let mut blk0 = SsaBlock::new(0); - blk0.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: v0, - value: ConstValue::I32(42), - })); - blk0.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 2 })); - ssa.add_block(blk0); - - let mut blk1 = SsaBlock::new(1); - blk1.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: v1, - value: ConstValue::I32(42), - })); - blk1.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 2 })); - ssa.add_block(blk1); - - let mut blk2 = SsaBlock::new(2); - let mut phi = PhiNode::new(v2, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(v0, 0)); - phi.add_operand(PhiOperand::new(v1, 1)); - blk2.add_phi(phi); - blk2.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: Some(v2) })); - ssa.add_block(blk2); - - // Phase 1: Require predecessor context for phi evaluation - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - eval.evaluate_block(0); - eval.evaluate_block(1); - - // Without predecessor set - should NOT merge (returns None) - eval.evaluate_block(2); - assert_eq!( - eval.get_concrete(v2), - None, - "phi should NOT merge without predecessor (Phase 1 behavior)" - ); - - // With predecessor set - should get 42 from v0 - let mut eval2 = SsaEvaluator::new(&ssa, PointerSize::Bit64); - eval2.evaluate_block(0); - eval2.set_predecessor(Some(0)); - eval2.evaluate_block(2); - assert_eq!( - eval2.get_concrete(v2).and_then(ConstValue::as_i32), - Some(42), - "phi should get 42 from predecessor B0" - ); - } - - #[test] - fn test_phi_merge_different_values_unknown() { - // Test phi merge when operands have different values (no predecessor set) - // - // B0: v0 = 10, jump B2 - // B1: v1 = 20, jump B2 - // B2: v2 = phi(v0 from B0, v1 from B1), ret - // - // Without predecessor, should be unknown since 10 != 20 - - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let v2 = SsaVarId::from_index(2); - - let mut ssa = SsaFunction::new(0, 0); - - let mut blk0 = SsaBlock::new(0); - blk0.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: v0, - value: ConstValue::I32(10), - })); - blk0.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 2 })); - ssa.add_block(blk0); - - let mut blk1 = SsaBlock::new(1); - blk1.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: v1, - value: ConstValue::I32(20), - })); - blk1.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 2 })); - ssa.add_block(blk1); - - let mut blk2 = SsaBlock::new(2); - let mut phi = PhiNode::new(v2, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(v0, 0)); - phi.add_operand(PhiOperand::new(v1, 1)); - blk2.add_phi(phi); - blk2.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: Some(v2) })); - ssa.add_block(blk2); - - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - eval.evaluate_block(0); - eval.evaluate_block(1); - // No predecessor set - should be unknown - eval.evaluate_block(2); - assert!( - eval.is_unknown(v2), - "phi should be unknown when operands differ and no predecessor set" - ); - } - - #[test] - fn test_next_block_jump() { - // Test unconditional jump - // B0: jump B2 - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| b.jump(2)); - f.block(1, |b| b.ret()); - f.block(2, |b| b.ret()); - }) - .unwrap(); - - let eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - let result = eval.next_block(0); - assert_eq!(result, ControlFlow::Continue(2)); - } - - #[test] - fn test_next_block_return() { - // Test return (terminal) - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| b.ret()); - }) - .unwrap(); - - let eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - let result = eval.next_block(0); - assert_eq!(result, ControlFlow::Terminal); - } - - #[test] - fn test_next_block_branch_known_true() { - // Test conditional branch with known true condition - // B0: cond = true, branch(cond, B1, B2) - let (ssa, cond) = { - let mut cond_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - cond_out = b.const_true(); - b.branch(cond_out, 1, 2); - }); - f.block(1, |b| b.ret()); - f.block(2, |b| b.ret()); - }) - .unwrap(); - (ssa, cond_out) - }; - - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - eval.evaluate_block(0); // Evaluate to set cond = true - let result = eval.next_block(0); - assert_eq!(result, ControlFlow::Continue(1)); - assert!(eval.is_concrete(cond)); - } - - #[test] - fn test_next_block_branch_known_false() { - // Test conditional branch with known false condition - // B0: cond = false, branch(cond, B1, B2) - let (ssa, _cond) = { - let mut cond_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - cond_out = b.const_false(); - b.branch(cond_out, 1, 2); - }); - f.block(1, |b| b.ret()); - f.block(2, |b| b.ret()); - }) - .unwrap(); - (ssa, cond_out) - }; - - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - eval.evaluate_block(0); // Evaluate to set cond = false - let result = eval.next_block(0); - assert_eq!(result, ControlFlow::Continue(2)); - } - - #[test] - fn test_next_block_branch_unknown() { - // Test conditional branch with unknown condition - let (ssa, cond) = { - let mut cond_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(1, 0) - .build_with(|f| { - cond_out = f.arg(0, SsaType::I32); // Argument is unknown - f.block(0, |b| { - b.branch(cond_out, 1, 2); - }); - f.block(1, |b| b.ret()); - f.block(2, |b| b.ret()); - }) - .unwrap(); - (ssa, cond_out) - }; - - let eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - // Don't evaluate - cond is unknown - assert!(eval.is_unknown(cond)); - let result = eval.next_block(0); - assert_eq!(result, ControlFlow::Unknown); - } - - #[test] - fn test_execute_simple_path() { - // Test execute with a simple linear path - // B0: v0 = 10, jump B1 - // B1: v1 = v0 + 5, jump B2 - // B2: ret - let (ssa, v0, v1) = { - let mut v0_out = SsaVarId::from_index(0); - let mut v1_out = SsaVarId::from_index(1); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - v0_out = b.const_i32(10); - b.jump(1); - }); - f.block(1, |b| { - let five = b.const_i32(5); - v1_out = b.add(v0_out, five); - b.jump(2); - }); - f.block(2, |b| b.ret()); - }) - .unwrap(); - (ssa, v0_out, v1_out) - }; - - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - let trace = eval.execute(0, None, 100); - - assert!(trace.is_complete()); - assert_eq!(trace.blocks(), &[0, 1, 2]); - assert_eq!(eval.get_concrete(v0).and_then(ConstValue::as_i32), Some(10)); - assert_eq!(eval.get_concrete(v1).and_then(ConstValue::as_i32), Some(15)); - } - - #[test] - fn test_execute_with_branch() { - // Test execute with a branch - // B0: state = 5, jump B1 - // B1: cmp = state == 5, branch(cmp, B2, B3) - // B2: ret (true path) - // B3: ret (false path) - let (ssa, state, cmp_result) = { - let mut state_out = SsaVarId::from_index(0); - let mut cmp_out = SsaVarId::from_index(1); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - state_out = b.const_i32(5); - b.jump(1); - }); - f.block(1, |b| { - let five = b.const_i32(5); - cmp_out = b.ceq(state_out, five); - b.branch(cmp_out, 2, 3); - }); - f.block(2, |b| b.ret()); - f.block(3, |b| b.ret()); - }) - .unwrap(); - (ssa, state_out, cmp_out) - }; - - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - let trace = eval.execute(0, Some(state), 100); - - assert!(trace.is_complete()); - // Should go B0 -> B1 -> B2 (true branch because state == 5) - assert_eq!(trace.blocks(), &[0, 1, 2]); - assert_eq!( - eval.get_concrete(cmp_result).and_then(ConstValue::as_i32), - Some(1) - ); // true - } - - #[test] - fn test_execute_max_steps() { - // Test that execute respects max_steps limit - // Infinite loop: B0 -> B0 -> B0 -> ... - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| b.jump(0)); // Self-loop - }) - .unwrap(); - - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - let trace = eval.execute(0, None, 5); - - assert!(!trace.is_complete()); // Did not complete - assert!(trace.hit_limit()); // Hit the limit - assert_eq!(trace.len(), 5); // Exactly 5 blocks visited - } - - #[test] - fn test_execute_state_tracking() { - // Test state variable tracking - captures state at start of each block - // We track a single variable that's only assigned in B0 - // B0: state = 10, jump B1 - // B1: jump B2 (state unchanged) - // B2: ret - let (ssa, state) = { - let mut state_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - state_out = b.const_i32(10); - b.jump(1); - }); - f.block(1, |b| { - b.jump(2); - }); - f.block(2, |b| b.ret()); - }) - .unwrap(); - (ssa, state_out) - }; - - let mut eval = SsaEvaluator::new(&ssa, PointerSize::Bit64); - let trace = eval.execute(0, Some(state), 100); - - assert!(trace.is_complete()); - assert_eq!(trace.blocks(), &[0, 1, 2]); - // State tracking captures state at the START of each block BEFORE evaluation - // B0: state not yet set (None) - // B1: state was set to 10 in B0 (and remains 10) - // B2: state still 10 - assert_eq!(trace.states()[0], None); // Before B0 evaluation - assert_eq!( - trace.states()[1].as_ref().and_then(ConstValue::as_i32), - Some(10) - ); // After B0 evaluation, before B1 - assert_eq!( - trace.states()[2].as_ref().and_then(ConstValue::as_i32), - Some(10) - ); // Still 10 after B1, before B2 - } - - #[test] - fn test_control_flow_result_helpers() { - // Test ControlFlow helper methods - let cont = ControlFlow::Continue(5); - assert_eq!(cont.target(), Some(5)); - assert!(!cont.is_terminal()); - assert!(!cont.is_unknown()); - - let term = ControlFlow::Terminal; - assert_eq!(term.target(), None); - assert!(term.is_terminal()); - assert!(!term.is_unknown()); - - let unknown = ControlFlow::Unknown; - assert_eq!(unknown.target(), None); - assert!(!unknown.is_terminal()); - assert!(unknown.is_unknown()); - } - - #[test] - fn test_execution_trace_helpers() { - let mut trace = ExecutionTrace::new(100); - assert!(trace.is_empty()); - assert!(!trace.is_complete()); - assert!(!trace.hit_limit()); - assert_eq!(trace.last_block(), None); - - trace.record_block(0, Some(ConstValue::I32(10))); - trace.record_block(1, Some(ConstValue::I32(20))); - trace.record_block(2, None); - - assert_eq!(trace.len(), 3); - assert!(!trace.is_empty()); - assert_eq!(trace.blocks(), &[0, 1, 2]); - assert_eq!( - trace.states(), - &[Some(ConstValue::I32(10)), Some(ConstValue::I32(20)), None] - ); - assert_eq!(trace.last_block(), Some(2)); - - trace.mark_complete(); - assert!(trace.is_complete()); - } -} diff --git a/dotscope/src/analysis/ssa/exception.rs b/dotscope/src/analysis/ssa/exception.rs index 33ce5582..422748b5 100644 --- a/dotscope/src/analysis/ssa/exception.rs +++ b/dotscope/src/analysis/ssa/exception.rs @@ -1,210 +1,68 @@ -//! Exception handler representation in SSA form. +//! Re-export shim — generic SSA exception handlers live in `analyssa::ir::exception`. //! -//! This module provides the [`SsaExceptionHandler`] type which preserves exception handler -//! information from the original method body through SSA transformations, enabling accurate -//! exception handler emission during code generation. -//! -//! # Exception Handler Preservation -//! -//! When converting CIL to SSA form, exception handlers need special treatment: -//! -//! 1. **Original offsets**: The original IL byte offsets for try/handler regions are preserved -//! 2. **Block mapping**: During SSA construction, we map these offsets to SSA block IDs -//! 3. **Offset remapping**: During code generation, we use block offsets to compute new IL offsets -//! -//! This ensures exception handlers remain valid even when SSA optimizations change instruction -//! sizes or reorder blocks. - -use crate::metadata::{ - method::{ExceptionHandler, ExceptionHandlerFlags}, - token::Token, +//! CIL-specific construction (`from_exception_handler`) and the +//! `class_token` accessor are provided here, since they reference dotscope +//! metadata types that analyssa doesn't see. + +use analyssa::ir::exception::SsaExceptionHandler as AnalyssaSsaExceptionHandler; + +use crate::{ + analysis::ssa::target::CilTarget, + metadata::{ + method::{ExceptionHandler, ExceptionHandlerFlags}, + token::Token, + }, }; -/// Exception handler information preserved in SSA form. -/// -/// This structure preserves the original exception handler data from the method body -/// so it can be accurately emitted during code generation. The offsets are preserved -/// as-is from the original method body, and are remapped during code generation -/// based on the new block layout. -/// -/// Additionally, block indices can be set to enable offset remapping when the -/// code is regenerated with different instruction sizes. -#[derive(Debug, Clone)] -pub struct SsaExceptionHandler { - /// Exception handler type flags (EXCEPTION, FILTER, FINALLY, or FAULT). - pub flags: ExceptionHandlerFlags, - - /// Original byte offset of the protected try block start. - pub try_offset: u32, - - /// Length of the protected try block in bytes. - pub try_length: u32, - - /// Original byte offset of the exception handler code start. - pub handler_offset: u32, - - /// Length of the exception handler code in bytes. - pub handler_length: u32, +/// CIL-defaulted alias of `analyssa::ir::exception::SsaExceptionHandler`. +pub type SsaExceptionHandler = AnalyssaSsaExceptionHandler; - /// For EXCEPTION handlers: the class token for the caught exception type. - /// For FILTER handlers: the offset of the filter expression. - pub class_token_or_filter: u32, - - /// Block ID where the try region starts (set during SSA construction). - pub try_start_block: Option, - - /// Block ID where the try region ends (exclusive, set during SSA construction). - pub try_end_block: Option, - - /// Block ID where the handler region starts (set during SSA construction). - pub handler_start_block: Option, - - /// Block ID where the handler region ends (exclusive, set during SSA construction). - pub handler_end_block: Option, - - /// Block ID where the filter expression starts (for FILTER handlers). - pub filter_start_block: Option, -} - -/// Finds the next surviving block index at or after `start` in the remap table. +/// Creates a new SSA exception handler from the original CIL exception handler. /// -/// Used for exclusive end-block boundaries (`try_end_block`, `handler_end_block`). -/// When an end-boundary block is removed during canonicalization, we need to find the -/// next block that survived to preserve the boundary semantics. -fn find_next_surviving(block_remap: &[Option], start: usize) -> Option { - block_remap.get(start..)?.iter().find_map(|entry| *entry) -} - -impl SsaExceptionHandler { - /// Creates a new SSA exception handler from the original exception handler. - #[must_use] - pub fn from_exception_handler(handler: &ExceptionHandler) -> Self { - // For EXCEPTION handlers, get the class token from filter_offset (which stores it) - // For FILTER handlers, this is the actual filter offset - let class_token_or_filter = if handler.flags == ExceptionHandlerFlags::EXCEPTION { - // Try to get token from handler type, otherwise use filter_offset - handler - .handler - .as_ref() - .map_or(handler.filter_offset, |t| t.token.value()) - } else { - handler.filter_offset - }; - - Self { - flags: handler.flags, - try_offset: handler.try_offset, - try_length: handler.try_length, - handler_offset: handler.handler_offset, - handler_length: handler.handler_length, - class_token_or_filter, - try_start_block: None, - try_end_block: None, - handler_start_block: None, - handler_end_block: None, - filter_start_block: None, - } +/// CIL-specific factory; callers historically used +/// `SsaExceptionHandler::from_exception_handler(...)` (an inherent method on +/// the CIL impl). After the analyssa extraction it's a free function because +/// orphan rules forbid inherent impls on foreign types. +#[must_use] +pub fn from_exception_handler(handler: &ExceptionHandler) -> SsaExceptionHandler { + let class_token_or_filter = if handler.flags == ExceptionHandlerFlags::EXCEPTION { + handler + .handler + .as_ref() + .map_or(handler.filter_offset, |t| t.token.value()) + } else { + handler.filter_offset + }; + + SsaExceptionHandler { + flags: handler.flags, + try_offset: handler.try_offset, + try_length: handler.try_length, + handler_offset: handler.handler_offset, + handler_length: handler.handler_length, + class_token_or_filter, + try_start_block: None, + try_end_block: None, + handler_start_block: None, + handler_end_block: None, + filter_start_block: None, } +} +/// CIL-specific extension methods on `SsaExceptionHandler`. +pub trait SsaExceptionHandlerCilExt { /// Returns the class token for EXCEPTION handlers. - #[must_use] - pub fn class_token(&self) -> Option { + fn class_token(&self) -> Option; +} + +impl SsaExceptionHandlerCilExt for AnalyssaSsaExceptionHandler { + fn class_token(&self) -> Option { if self.flags == ExceptionHandlerFlags::EXCEPTION { Some(Token::new(self.class_token_or_filter)) } else { None } } - - /// Returns the filter offset for FILTER handlers. - #[must_use] - pub fn filter_offset(&self) -> Option { - if self.flags == ExceptionHandlerFlags::FILTER { - Some(self.class_token_or_filter) - } else { - None - } - } - - /// Checks if block indices have been set for offset remapping. - #[must_use] - pub fn has_block_mapping(&self) -> bool { - self.try_start_block.is_some() && self.handler_start_block.is_some() - } - - /// Remaps block indices using the provided block remapping. - /// - /// This method updates all block index fields (`try_start_block`, `try_end_block`, - /// `handler_start_block`, `handler_end_block`, `filter_start_block`) to reflect - /// block renumbering that occurs during SSA canonicalization. - /// - /// # Arguments - /// - /// * `block_remap` - A slice where `block_remap[old_idx]` contains: - /// - `Some(new_idx)` if the block at `old_idx` was kept and is now at `new_idx` - /// - `None` if the block at `old_idx` was removed - /// - /// # Behavior - /// - /// For each block index field: - /// - If the field is `None`, it remains `None` - /// - If the field contains an index that maps to `Some(new_idx)`, it's updated to `new_idx` - /// - If the field contains an index that maps to `None` (block removed), the field becomes `None` - /// - If the index is out of bounds in `block_remap`, the field becomes `None` - /// - /// # Why This Is Necessary - /// - /// During SSA canonicalization, empty blocks may be removed and remaining blocks - /// are renumbered to maintain contiguous indices. Without remapping, exception - /// handler block indices would become stale, causing code generation to: - /// 1. Fail to find block offsets (falling back to original IL offsets) - /// 2. Produce incorrect exception handler regions - /// 3. Generate invalid IL that crashes at runtime - /// - /// # Example - /// - /// ```text - /// // Before canonicalization: blocks [0, 1, 2, 3, 4] - /// // Block 1 is empty and removed - /// // After canonicalization: blocks [0, 2, 3, 4] → renumbered to [0, 1, 2, 3] - /// - /// // block_remap = [Some(0), None, Some(1), Some(2), Some(3)] - /// - /// // If handler_start_block was Some(3), it becomes Some(2) - /// // If try_start_block was Some(1), it becomes None (block removed) - /// ``` - pub fn remap_block_indices(&mut self, block_remap: &[Option]) { - // Start blocks: must map exactly (protected, so should always survive) - self.try_start_block = self - .try_start_block - .and_then(|idx| block_remap.get(idx).copied().flatten()); - - // End blocks (exclusive boundaries): if removed, use next surviving block - self.try_end_block = self.try_end_block.and_then(|idx| { - block_remap - .get(idx) - .copied() - .flatten() - .or_else(|| find_next_surviving(block_remap, idx)) - }); - - self.handler_start_block = self - .handler_start_block - .and_then(|idx| block_remap.get(idx).copied().flatten()); - - // End block (exclusive boundary): if removed, use next surviving block - self.handler_end_block = self.handler_end_block.and_then(|idx| { - block_remap - .get(idx) - .copied() - .flatten() - .or_else(|| find_next_surviving(block_remap, idx)) - }); - - self.filter_start_block = self - .filter_start_block - .and_then(|idx| block_remap.get(idx).copied().flatten()); - } } #[cfg(test)] @@ -213,6 +71,10 @@ mod tests { use super::*; + // Lock T to CilTarget for tests; they construct `SsaExceptionHandler` with + // CIL flags directly. + type SsaExceptionHandler = super::SsaExceptionHandler; + #[test] fn test_remap_block_indices_basic() { let mut handler = SsaExceptionHandler { diff --git a/dotscope/src/analysis/ssa/function.rs b/dotscope/src/analysis/ssa/function.rs new file mode 100644 index 00000000..de677930 --- /dev/null +++ b/dotscope/src/analysis/ssa/function.rs @@ -0,0 +1,313 @@ +//! CIL-pinned extension traits on `SsaFunction`. +//! +//! - [`SsaFunctionCilExt`] adds CIL-typed local-variable utilities +//! (`optimize_locals`, `generate_local_signature`, `infer_local_type`) that +//! reference `SsaType` / `SignatureLocalVariable` and so cannot live inside +//! analyssa. +//! - [`SsaFunctionSemanticsExt`] delegates block- and loop-classification to +//! [`SemanticAnalyzer`](crate::analysis::cfg::SemanticAnalyzer), which itself +//! depends on dotscope's CFG-loop machinery (`LoopInfo`, dominator-based +//! loop detection). +//! +//! Both traits exist because Rust's orphan rules forbid inherent impls on the +//! foreign `analyssa::ir::function::SsaFunction` type. + +use std::collections::{BTreeMap, BTreeSet, HashMap}; + +use analyssa::{ + ir::{ + function::SsaFunction, + ops::SsaOp, + variable::{SsaVarId, VariableOrigin}, + }, + target::Target, +}; + +use crate::{ + analysis::{ + cfg::{BlockSemantics, LoopSemantics, SemanticAnalyzer}, + ssa::{target::CilTarget, types::SsaType}, + LoopInfo, + }, + metadata::signatures::{CustomModifiers, SignatureLocalVariable, SignatureLocalVariables}, + Error, Result, +}; + +/// CIL-specific extension methods on `SsaFunction`. +/// +/// `optimize_locals` / `generate_local_signature` / `infer_local_type` use +/// `SsaType` and `SignatureLocalVariable` which are CIL-side only. +pub trait SsaFunctionCilExt { + /// Optimizes local variables by removing unused ones and compacting indices. + fn optimize_locals(&mut self) -> Vec>; + + /// Generates a `SignatureLocalVariables` from the function's locals. + fn generate_local_signature( + &self, + override_count: Option, + temporary_types: Option<&BTreeMap>, + ) -> Result; + + /// Infers a CIL `SsaType` for a local index from its SSA variable definitions. + fn infer_local_type(&self, local_idx: usize) -> Option; +} + +impl SsaFunctionCilExt for SsaFunction { + fn optimize_locals(&mut self) -> Vec> { + // Collect all used local indices. + let mut used_locals: BTreeSet = BTreeSet::new(); + + // From variables + for var in self.variables() { + if let VariableOrigin::Local(idx) = var.origin() { + used_locals.insert(idx); + } + } + + // From phi nodes + for block in self.blocks() { + for phi in block.phi_nodes() { + if let VariableOrigin::Local(idx) = phi.origin() { + used_locals.insert(idx); + } + } + } + + // From LoadLocal and LoadLocalAddr instructions + for block in self.blocks() { + for instr in block.instructions() { + match instr.op() { + SsaOp::LoadLocal { local_index, .. } + | SsaOp::LoadLocalAddr { local_index, .. } => { + used_locals.insert(*local_index); + } + _ => {} + } + } + } + + // Determine the actual range of local indices (may exceed num_locals + // if SSA construction allocated extras). + let max_idx = used_locals.iter().copied().max().unwrap_or(0); + let max_known = u16::try_from(self.num_locals()) + .unwrap_or(u16::MAX) + .saturating_sub(1) + .max(max_idx); + + // Build remapping (old → new). + let mut remap: Vec> = vec![None; usize::from(max_known) + 1]; + let mut new_idx: u16 = 0; + for old_idx in 0..=max_known { + if used_locals.contains(&old_idx) { + if let Some(slot) = remap.get_mut(usize::from(old_idx)) { + *slot = Some(new_idx); + } + new_idx = new_idx.saturating_add(1); + } + } + + let new_count = new_idx; + + // Apply remap to variables. + let mut new_origins: Vec<(SsaVarId, VariableOrigin)> = Vec::new(); + for var in self.variables() { + if let VariableOrigin::Local(old) = var.origin() { + if let Some(Some(new)) = remap.get(usize::from(old)) { + new_origins.push((var.id(), VariableOrigin::Local(*new))); + } + } + } + for (id, origin) in new_origins { + if let Some(v) = self.variable_mut(id) { + v.set_origin(origin); + } + } + + // Apply remap to phi nodes + for block in self.blocks_mut() { + for phi in block.phi_nodes_mut() { + if let VariableOrigin::Local(old) = phi.origin() { + if let Some(Some(new)) = remap.get(usize::from(old)) { + phi.set_origin(VariableOrigin::Local(*new)); + } + } + } + } + + // Apply remap to LoadLocal/LoadLocalAddr instructions + for block in self.blocks_mut() { + for instr in block.instructions_mut() { + let op = instr.op_mut(); + let new_index = match op { + SsaOp::LoadLocal { local_index, .. } + | SsaOp::LoadLocalAddr { local_index, .. } => { + remap.get(usize::from(*local_index)).copied().flatten() + } + _ => None, + }; + if let Some(new) = new_index { + match op { + SsaOp::LoadLocal { local_index, .. } + | SsaOp::LoadLocalAddr { local_index, .. } => *local_index = new, + _ => {} + } + } + } + } + + let original = self.original_num_locals(); + self.set_num_locals(usize::from(new_count), original); + remap + } + + fn generate_local_signature( + &self, + override_count: Option, + temporary_types: Option<&BTreeMap>, + ) -> Result { + let empty_temps = BTreeMap::new(); + let temps = temporary_types.unwrap_or(&empty_temps); + + let local_count = override_count + .map(usize::from) + .unwrap_or_else(|| self.num_locals()); + + // Path 1: original local types preserved from source assembly + if let Some(orig) = self.original_local_types() { + let mut locals: Vec = Vec::with_capacity(local_count); + for sig in orig.iter().take(local_count) { + locals.push(SignatureLocalVariable { + modifiers: sig.modifiers.clone(), + is_byref: sig.is_byref, + is_pinned: sig.is_pinned, + base: sig.base.clone(), + }); + } + // Pad with inferred types if needed + while locals.len() < local_count { + let idx = u16::try_from(locals.len()).unwrap_or(u16::MAX); + let ty = temps + .get(&idx) + .cloned() + .or_else(|| self.infer_local_type(usize::from(idx))); + let Some(ty) = ty else { + return Err(Error::SsaError(format!( + "missing type info for local index {idx} during signature generation" + ))); + }; + locals.push(SignatureLocalVariable { + modifiers: CustomModifiers::default(), + is_byref: false, + is_pinned: false, + base: ty.to_type_signature(), + }); + } + return Ok(SignatureLocalVariables { locals }); + } + + // Path 2: temporary_types map + SSA inference + let mut local_types: Vec> = vec![None; local_count]; + for (idx, ty) in temps.iter() { + if let Some(slot) = local_types.get_mut(usize::from(*idx)) { + *slot = Some(ty.clone()); + } + } + for (idx, slot) in local_types.iter_mut().enumerate() { + if slot.is_none() { + *slot = self.infer_local_type(idx); + } + } + + let mut locals: Vec = Vec::with_capacity(local_types.len()); + for (idx, ty) in local_types.iter().enumerate() { + let Some(ty) = ty else { + return Err(Error::SsaError(format!( + "missing type info for local index {idx} during signature generation" + ))); + }; + locals.push(SignatureLocalVariable { + modifiers: CustomModifiers::default(), + is_byref: false, + is_pinned: false, + base: ty.to_type_signature(), + }); + } + + Ok(SignatureLocalVariables { locals }) + } + + fn infer_local_type(&self, local_idx: usize) -> Option { + let target_origin = VariableOrigin::Local(u16::try_from(local_idx).ok()?); + + // Search variables for the first concrete type for this local origin + for var in self.variables() { + if var.origin() == target_origin { + let ty = var.var_type(); + if !CilTarget::is_unknown(ty) { + return Some(ty.clone()); + } + } + } + None + } +} + +/// Block- and loop-semantic-analysis extension methods on `SsaFunction`. +pub trait SsaFunctionSemanticsExt { + /// Analyzes the semantic role of a specific block. + fn analyze_block_semantics(&self, block_idx: usize) -> BlockSemantics; + + /// Analyzes semantic roles of multiple blocks. + fn analyze_blocks_semantics(&self, blocks: &[usize]) -> HashMap; + + /// Analyzes the semantic structure of a structural loop. + fn analyze_loop_semantics(&self, loop_info: &LoopInfo) -> LoopSemantics; + + /// Recovers loop semantics from flattened dispatcher case blocks. + fn recover_loop_from_cases( + &self, + case_blocks: &[usize], + dispatcher_block: Option, + ) -> LoopSemantics; + + /// Creates a semantic analyzer for this function (cache-friendly for + /// multiple analyses). + fn semantic_analyzer(&self) -> SemanticAnalyzer<'_, T>; +} + +impl SsaFunctionSemanticsExt for SsaFunction { + fn analyze_block_semantics(&self, block_idx: usize) -> BlockSemantics { + let mut analyzer = SemanticAnalyzer::new(self); + analyzer.analyze_block(block_idx).clone() + } + + fn analyze_blocks_semantics(&self, blocks: &[usize]) -> HashMap { + let mut analyzer = SemanticAnalyzer::new(self); + let mut results = HashMap::new(); + for &block in blocks { + results.insert(block, analyzer.analyze_block(block).clone()); + } + results + } + + fn analyze_loop_semantics(&self, loop_info: &LoopInfo) -> LoopSemantics { + let mut analyzer = SemanticAnalyzer::new(self); + analyzer.analyze_loop(loop_info) + } + + fn recover_loop_from_cases( + &self, + case_blocks: &[usize], + dispatcher_block: Option, + ) -> LoopSemantics { + let mut analyzer = SemanticAnalyzer::new(self); + if let Some(disp) = dispatcher_block { + analyzer.mark_dispatcher(disp); + } + analyzer.recover_loop_from_cases(case_blocks) + } + + fn semantic_analyzer(&self) -> SemanticAnalyzer<'_, T> { + SemanticAnalyzer::new(self) + } +} diff --git a/dotscope/src/analysis/ssa/function/canonical.rs b/dotscope/src/analysis/ssa/function/canonical.rs deleted file mode 100644 index 0be15b08..00000000 --- a/dotscope/src/analysis/ssa/function/canonical.rs +++ /dev/null @@ -1,430 +0,0 @@ -//! Canonicalization of SSA functions for clean code generation. -//! -//! Strips Nops, removes empty blocks, compacts block indices, and -//! ensures valid terminators after deobfuscation passes. - -use std::collections::{BTreeMap, BTreeSet}; - -use crate::{ - analysis::ssa::{PhiOperand, SsaBlock, SsaFunction, SsaInstruction, SsaOp, SsaVarId}, - utils::BitSet, -}; - -/// Finds kept predecessors of a removed block during canonicalization. -/// -/// When a block is removed, we need to find the actual predecessor blocks -/// (that are being kept) which would flow into the removed block. This is -/// used to properly update PHI node predecessors. -/// -/// The function follows predecessor chains through removed blocks until it -/// finds blocks that are being kept (have entries in `block_remap`). -pub(crate) fn find_kept_predecessors( - removed_block: usize, - predecessors: &BTreeMap>, - block_remap: &[Option], - redirect_map: &BTreeMap, -) -> Vec { - let mut result = Vec::new(); - let block_count = block_remap.len(); - let mut visited = BitSet::new(block_count); - let mut queue = vec![removed_block]; - - while let Some(current) = queue.pop() { - if current >= block_count || !visited.insert(current) { - continue; - } - - if let Some(preds) = predecessors.get(¤t) { - for &pred in preds { - if let Some(Some(new_idx)) = block_remap.get(pred) { - // This predecessor is kept - add its new index - result.push(*new_idx); - } else if redirect_map.contains_key(&pred) { - // This predecessor is also removed - follow the chain - queue.push(pred); - } - } - } - } - - result -} - -impl SsaFunction { - /// Canonicalizes the SSA function for clean code generation. - /// - /// This method performs final cleanup after deobfuscation passes: - /// - /// 1. **Strip Nop instructions**: Removes all `SsaOp::Nop` instructions - /// 2. **Identify empty blocks**: Marks blocks with no instructions or phi nodes for removal - /// 3. **Build redirect map**: For removed blocks, finds their ultimate jump targets - /// 4. **Update branch targets**: Retargets jumps to skip removed empty blocks - /// 5. **Update PHI predecessors**: Fixes PHI node operands when predecessor blocks are removed - /// 6. **Compact blocks**: Removes empty blocks and renumbers remaining blocks contiguously - /// - /// This should be called after all deobfuscation passes complete, before - /// code generation. The resulting SSA is cleaner and easier to convert to IL. - pub fn canonicalize(&mut self) { - // Phase 1: Strip Nop instructions from all blocks - for block in &mut self.blocks { - block - .instructions_mut() - .retain(|instr| !matches!(instr.op(), SsaOp::Nop)); - } - - // Collect blocks that must be preserved: - // - Exception handler entry blocks - // - Leave targets (exception handler exit blocks) - let block_count = self.blocks.len(); - let mut protected_blocks = BitSet::new(block_count); - - // Protect exception handler entry blocks - for handler in &self.exception_handlers { - if let Some(try_block) = handler.try_start_block { - if try_block < block_count { - protected_blocks.insert(try_block); - } - } - if let Some(handler_block) = handler.handler_start_block { - if handler_block < block_count { - protected_blocks.insert(handler_block); - } - } - if let Some(filter_block) = handler.filter_start_block { - if filter_block < block_count { - protected_blocks.insert(filter_block); - } - } - } - - // Protect Leave targets (exception handler exit blocks) - for block in &self.blocks { - if let Some(SsaOp::Leave { target }) = block.terminator_op() { - if *target < block_count { - protected_blocks.insert(*target); - } - } - } - - // Phase 2: Identify empty blocks and build remapping. - // An empty block has no instructions AND no phi nodes. - // Exception: Keep block 0 (entry) and protected exception handler blocks even if empty. - let mut block_remap: Vec> = Vec::with_capacity(block_count); - let mut new_index = 0usize; - - for (old_index, block) in self.blocks.iter().enumerate() { - let is_empty = block.instructions().is_empty() && block.phi_nodes().is_empty(); - let is_entry = old_index == 0; - let is_protected = protected_blocks.contains(old_index); - - if is_empty && !is_entry && !is_protected { - block_remap.push(None); // This block will be removed - } else { - block_remap.push(Some(new_index)); - new_index = new_index.saturating_add(1); - } - } - - // Phase 3: Build redirect map for removed blocks. - // For each removed block, find its ultimate jump target (following jump chains). - // If we can't find a redirect for a block, we must keep it instead of removing it. - let mut redirect_map: BTreeMap = BTreeMap::new(); - - for old_index in 0..self.blocks.len() { - if matches!(block_remap.get(old_index), Some(None)) { - // This block is being removed - find where it would jump to - if let Some(target) = self.find_ultimate_target(old_index, &block_remap) { - redirect_map.insert(old_index, target); - } else { - // Can't find a redirect target - we must keep this block. - // Reassign it a new index. - if let Some(slot) = block_remap.get_mut(old_index) { - *slot = Some(new_index); - new_index = new_index.saturating_add(1); - } - } - } - } - - // Build predecessor map for PHI updates (needed for Phase 5). - // For each block, collect all blocks that have edges TO it. - let mut predecessors: BTreeMap> = BTreeMap::new(); - for (block_idx, block) in self.blocks.iter().enumerate() { - for target in block.successors() { - predecessors.entry(target).or_default().push(block_idx); - } - } - - // Phase 4: Update all branch targets in remaining blocks. - for block in &mut self.blocks { - for instr in block.instructions_mut() { - Self::remap_branch_targets(instr.op_mut(), &block_remap, &redirect_map); - } - } - - // Phase 5: Update PHI node predecessors. - // When a predecessor block is removed, we find the kept blocks that would have - // flowed into the removed block and use those as the new predecessors. - // - // Special case: Some PHI operands may reference orphaned blocks (blocks with no - // predecessors). This happens when deobfuscation passes modify the CFG without - // properly updating PHI predecessors. We try to recover these by assigning - // orphaned values to unaccounted-for predecessors. - - // Process each block's PHI nodes - for block_idx in 0..self.blocks.len() { - // Get the predecessors of THIS block (the one containing the PHI) - // These are the OLD indices of blocks that jump to this block. - let phi_block_preds: Vec = - predecessors.get(&block_idx).cloned().unwrap_or_default(); - - // Also compute the NEW indices of kept predecessors - let kept_phi_block_preds: Vec = phi_block_preds - .iter() - .filter_map(|&old_pred| block_remap.get(old_pred).and_then(|opt| *opt)) - .collect(); - - let Some(block) = self.blocks.get_mut(block_idx) else { - continue; - }; - for phi in block.phi_nodes_mut() { - // Collect changes first (to avoid borrow issues) - let mut changes: Vec<(usize, Option, Vec)> = Vec::new(); - // Track orphaned values (removed operands with no replacement) - let mut orphaned_values: Vec = Vec::new(); - - for (op_idx, operand) in phi.operands().iter().enumerate() { - let old_pred = operand.predecessor(); - let value = operand.value(); - - if redirect_map.contains_key(&old_pred) { - // This predecessor was removed. Find all kept blocks that flow into it. - let kept_preds = find_kept_predecessors( - old_pred, - &predecessors, - &block_remap, - &redirect_map, - ); - - if kept_preds.is_empty() { - // Orphaned operand - track the value for potential recovery below - orphaned_values.push(value); - } - - let replacements: Vec = kept_preds - .into_iter() - .map(|new_pred| PhiOperand::new(value, new_pred)) - .collect(); - - // None = remove this operand, replacements = add these instead - changes.push((op_idx, None, replacements)); - } else if let Some(Some(new_pred)) = block_remap.get(old_pred) { - // Predecessor was kept but renumbered - update in place - changes.push((op_idx, Some(PhiOperand::new(value, *new_pred)), Vec::new())); - } - } - - // Apply changes in reverse order (to preserve indices when removing) - for (op_idx, replacement, additions) in changes.into_iter().rev() { - if let Some(new_op) = replacement { - // Update in place - if let Some(operand) = phi.operands_mut().get_mut(op_idx) { - *operand = new_op; - } - } else { - // Remove the operand - phi.operands_mut().remove(op_idx); - // Add replacement operands - for op in additions { - phi.add_operand(op); - } - } - } - - // Post-processing: try to recover orphaned values by assigning them - // to unaccounted-for predecessors. - if !orphaned_values.is_empty() { - // Get the predecessors that are currently accounted for in the PHI - let accounted_preds: BTreeSet = - phi.operands().iter().map(PhiOperand::predecessor).collect(); - - // Find predecessors that are NOT accounted for - let unaccounted_preds: Vec = kept_phi_block_preds - .iter() - .copied() - .filter(|pred| !accounted_preds.contains(pred)) - .collect(); - - // Assign orphaned values to unaccounted predecessors - for (orphan_val, &unaccounted_pred) in - orphaned_values.iter().zip(unaccounted_preds.iter()) - { - phi.add_operand(PhiOperand::new(*orphan_val, unaccounted_pred)); - } - } - } - } - - // Phase 6: Remove empty blocks and compact block indices. - let mut kept_blocks: Vec = Vec::with_capacity(new_index); - for (old_index, block) in self.blocks.drain(..).enumerate() { - if matches!(block_remap.get(old_index), Some(Some(_))) { - kept_blocks.push(block); - } - } - - // Update block indices in kept blocks - for (new_idx, block) in kept_blocks.iter_mut().enumerate() { - block.set_id(new_idx); - } - - self.blocks = kept_blocks; - - // Phase 7: Remap exception handler block indices. - for handler in &mut self.exception_handlers { - handler.remap_block_indices(&block_remap); - } - - // Phase 8: Ensure the method has a valid terminator. - self.ensure_valid_terminator(); - } - - /// Ensures the function has a valid terminator path from the entry block. - /// - /// This handles the case where all meaningful code has been removed (e.g., after - /// neutralizing 100% protection code in a module .cctor), leaving only Jumps to - /// empty blocks. In such cases, we replace the entry block's terminator with a - /// Return instruction to produce valid IL. - fn ensure_valid_terminator(&mut self) { - // Check if the method effectively does nothing useful: - // - Only has Jump instructions (no actual code) - // - All Jump targets lead to empty blocks or more Jumps - let has_useful_code = self.blocks.iter().any(|block| { - block.instructions().iter().any(|instr| { - match instr.op() { - // Jumps and Nops don't count as useful - they're just control flow - SsaOp::Jump { .. } | SsaOp::Nop => false, - // Any other instruction (including returns, throws) is useful code - _ => true, - } - }) - }); - - // If there's no useful code, replace entry block with just a Return - if !has_useful_code { - if let Some(entry_block) = self.blocks.first_mut() { - entry_block.instructions_mut().clear(); - entry_block.phi_nodes_mut().clear(); - entry_block - .add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - } - } - } - - /// Finds the ultimate jump target for a block, following jump chains. - /// - /// Used during canonicalization to find where an empty block would - /// ultimately transfer control to. - fn find_ultimate_target( - &self, - block_idx: usize, - block_remap: &[Option], - ) -> Option { - let mut visited = BitSet::new(self.blocks.len()); - let mut current = block_idx; - - while visited.insert(current) { - let block = self.blocks.get(current)?; - - // Get the terminator's target - let terminator = block.terminator_op(); - let target = terminator.and_then(|op| match op { - SsaOp::Jump { target } => Some(*target), - // For branches, we can't simplify - the block isn't truly empty - _ => None, - }); - - // Handle the target - match target { - Some(t) if block_remap.get(t).copied().flatten().is_some() => { - // Target exists in new layout - return block_remap.get(t).copied().flatten(); - } - Some(t) => { - // Target is also being removed, follow the chain - current = t; - } - None => { - // No explicit jump target. Check if block is truly empty (no terminator). - // In CIL semantics, empty blocks fall through to the next block. - if terminator.is_none() && block.instructions().is_empty() { - // Try to fall through to the next block - let next_block = current.saturating_add(1); - if next_block < self.blocks.len() { - if let Some(Some(new_idx)) = block_remap.get(next_block) { - // Next block exists in new layout - return Some(*new_idx); - } else if block_remap.get(next_block).is_some() { - // Next block is also being removed, follow the chain - current = next_block; - continue; - } - } - } - // No simple jump target and no fall-through, can't redirect - return None; - } - } - } - - None // Cycle detected - } - - /// Remaps branch targets according to the block remapping. - fn remap_branch_targets( - op: &mut SsaOp, - block_remap: &[Option], - redirect_map: &BTreeMap, - ) { - // Helper closure to remap a single target - let remap_target = |target: &mut usize| { - // First try redirect_map (for removed blocks with known targets) - if let Some(&new_target) = redirect_map.get(target) { - *target = new_target; - return; - } - // Then try block_remap (for kept blocks) - if let Some(Some(new_target)) = block_remap.get(*target) { - *target = *new_target; - } - }; - - match op { - SsaOp::Jump { target } | SsaOp::Leave { target } => { - remap_target(target); - } - SsaOp::Branch { - true_target, - false_target, - .. - } - | SsaOp::BranchCmp { - true_target, - false_target, - .. - } => { - remap_target(true_target); - remap_target(false_target); - } - SsaOp::Switch { - targets, default, .. - } => { - for target in targets.iter_mut() { - remap_target(target); - } - remap_target(default); - } - _ => {} - } - } -} diff --git a/dotscope/src/analysis/ssa/function/duplication.rs b/dotscope/src/analysis/ssa/function/duplication.rs deleted file mode 100644 index 04fea23e..00000000 --- a/dotscope/src/analysis/ssa/function/duplication.rs +++ /dev/null @@ -1,318 +0,0 @@ -//! Block duplication and cloning for SSA functions. -//! -//! These methods handle allocating fresh variables, cloning blocks with -//! remapped variable IDs, and updating branch targets. - -use std::collections::HashMap; - -use crate::analysis::ssa::{ - PhiNode, PhiOperand, SsaBlock, SsaFunction, SsaInstruction, SsaOp, SsaVarId, -}; - -impl SsaFunction { - /// Allocates fresh SSA variable IDs for all variables defined in a block. - /// - /// This creates new unique IDs for each variable defined by: - /// - Phi nodes in the block - /// - Instructions that produce results - /// - /// The returned mapping maps old variable IDs to their fresh replacements. - /// Variables that are only used (not defined) in the block are not included - /// in the mapping - they should reference the original variables. - /// - /// # Arguments - /// - /// * `block_idx` - The index of the block to analyze - /// - /// # Returns - /// - /// A mapping from original variable IDs to fresh IDs, or `None` if the - /// block index is invalid. - #[must_use] - pub fn allocate_fresh_variables_for_block( - &mut self, - block_idx: usize, - ) -> Option> { - let block = self.block(block_idx)?; - let mut mapping = HashMap::new(); - - // Collect IDs to allocate (can't borrow self mutably while reading block) - let phi_ids: Vec = block.phi_nodes().iter().map(|phi| phi.result()).collect(); - let instr_dests: Vec = block - .instructions() - .iter() - .filter_map(|instr| instr.op().dest()) - .collect(); - - // Allocate fresh IDs for phi node results - for old_id in phi_ids { - let new_id = self.var_allocator.alloc(); - mapping.insert(old_id, new_id); - } - - // Allocate fresh IDs for instruction defs - for dest in instr_dests { - let new_id = self.var_allocator.alloc(); - mapping.insert(dest, new_id); - } - - Some(mapping) - } - - /// Clones a block with remapped variable IDs. - /// - /// This creates a deep copy of the block where all variable references - /// are transformed through the provided remapping function. Variables - /// not in the mapping are left unchanged (allowing references to - /// variables defined outside the block). - /// - /// The new block is assigned the specified ID but is NOT automatically - /// added to the function - the caller must add it explicitly. - /// - /// # Arguments - /// - /// * `block_idx` - The index of the block to clone - /// * `new_block_id` - The ID to assign to the cloned block - /// * `var_remap` - Mapping from old variable IDs to new ones - /// * `pred_remap` - Optional mapping for predecessor block indices in phi nodes - /// - /// # Returns - /// - /// A new `SsaBlock` with remapped variables, or `None` if the block doesn't exist. - #[must_use] - pub fn clone_block_with_remap( - &self, - block_idx: usize, - new_block_id: usize, - var_remap: &HashMap, - pred_remap: Option<&HashMap>, - ) -> Option { - let block = self.block(block_idx)?; - - let mut new_block = - SsaBlock::with_capacity(new_block_id, block.phi_count(), block.instruction_count()); - - // Clone phi nodes with remapped variables and predecessors - for phi in block.phi_nodes() { - let new_result = var_remap - .get(&phi.result()) - .copied() - .unwrap_or(phi.result()); - let mut new_phi = PhiNode::with_capacity(new_result, phi.origin(), phi.operand_count()); - - for operand in phi.operands() { - let new_value = var_remap - .get(&operand.value()) - .copied() - .unwrap_or(operand.value()); - let new_pred = pred_remap - .and_then(|m| m.get(&operand.predecessor()).copied()) - .unwrap_or(operand.predecessor()); - new_phi.add_operand(PhiOperand::new(new_value, new_pred)); - } - - new_block.add_phi(new_phi); - } - - // Clone instructions with remapped variables - for instr in block.instructions() { - let new_instr = Self::clone_instruction_with_remap(instr, var_remap); - new_block.add_instruction(new_instr); - } - - Some(new_block) - } - - /// Clones an instruction with remapped variable IDs. - /// - /// Creates a copy of the instruction where all variable references are - /// transformed through the provided mapping. The original CIL instruction - /// is preserved (cloned) but the SSA operation uses new variable IDs. - /// - /// # Arguments - /// - /// * `instr` - The instruction to clone - /// * `var_remap` - Mapping from old variable IDs to new ones - /// - /// # Returns - /// - /// A new `SsaInstruction` with remapped variables. - fn clone_instruction_with_remap( - instr: &SsaInstruction, - var_remap: &HashMap, - ) -> SsaInstruction { - let original = instr.original().clone(); - - // Use the remap_variables method on SsaOp - let new_op = instr - .op() - .remap_variables(|old_id| var_remap.get(&old_id).copied()); - let mut new_instr = SsaInstruction::new(original, new_op); - new_instr.set_result_type(instr.result_type().cloned()); - new_instr - } - - /// Duplicates a block, creating a complete copy with fresh variables. - /// - /// This is a high-level method that: - /// 1. Allocates fresh variable IDs for all definitions in the block - /// 2. Creates corresponding `SsaVariable` entries for each new ID - /// 3. Clones the block with the remapped variables - /// 4. Adds the new block to the function - /// - /// The new block is assigned the next available block ID. - /// - /// # Arguments - /// - /// * `block_idx` - The index of the block to duplicate - /// - /// # Returns - /// - /// A tuple of (new_block_id, variable_mapping), or `None` if the block doesn't exist. - /// The variable_mapping maps original variable IDs to their duplicated counterparts. - pub fn duplicate_block( - &mut self, - block_idx: usize, - ) -> Option<(usize, HashMap)> { - // Allocate fresh variables - let var_remap = self.allocate_fresh_variables_for_block(block_idx)?; - - // Create SsaVariable entries for each new variable via create_variable. - // We need to collect the info first to avoid borrow conflicts. - let var_info: Vec<_> = var_remap - .iter() - .filter_map(|(&old_id, _)| { - self.variable(old_id).map(|v| { - ( - old_id, - v.origin(), - v.version(), - v.def_site(), - v.var_type().clone(), - ) - }) - }) - .collect(); - // Now create new variables — they get fresh IDs from the allocator - // which match the pre-allocated IDs from allocate_fresh_variables_for_block - for (_old_id, origin, version, def_site, var_type) in var_info { - self.create_variable(origin, version, def_site, var_type); - } - - // Clone the block with new ID - let new_block_id = self.blocks.len(); - let new_block = self.clone_block_with_remap(block_idx, new_block_id, &var_remap, None)?; - self.add_block(new_block); - - Some((new_block_id, var_remap)) - } - - /// Updates branch targets in a block to point to new destinations. - /// - /// This modifies the terminator instruction of the specified block, - /// remapping any target block indices according to the provided mapping. - /// Targets not in the mapping are left unchanged. - /// - /// # Arguments - /// - /// * `block_idx` - The block whose terminator to update - /// * `target_remap` - Mapping from old target indices to new ones - /// - /// # Returns - /// - /// `true` if any targets were updated, `false` otherwise. - pub fn remap_block_targets( - &mut self, - block_idx: usize, - target_remap: &HashMap, - ) -> bool { - let Some(block) = self.block_mut(block_idx) else { - return false; - }; - let Some(last) = block.instructions_mut().last_mut() else { - return false; - }; - let new_op = match last.op() { - SsaOp::Jump { target } => { - if let Some(&new_target) = target_remap.get(target) { - SsaOp::Jump { target: new_target } - } else { - return false; - } - } - SsaOp::Branch { - condition, - true_target, - false_target, - } => { - let new_true = target_remap - .get(true_target) - .copied() - .unwrap_or(*true_target); - let new_false = target_remap - .get(false_target) - .copied() - .unwrap_or(*false_target); - if new_true == *true_target && new_false == *false_target { - return false; - } - SsaOp::Branch { - condition: *condition, - true_target: new_true, - false_target: new_false, - } - } - SsaOp::BranchCmp { - left, - right, - cmp, - unsigned, - true_target, - false_target, - } => { - let new_true = target_remap - .get(true_target) - .copied() - .unwrap_or(*true_target); - let new_false = target_remap - .get(false_target) - .copied() - .unwrap_or(*false_target); - if new_true == *true_target && new_false == *false_target { - return false; - } - SsaOp::BranchCmp { - left: *left, - right: *right, - cmp: *cmp, - unsigned: *unsigned, - true_target: new_true, - false_target: new_false, - } - } - SsaOp::Switch { - value, - targets, - default, - } => { - let new_targets: Vec = targets - .iter() - .map(|&t| target_remap.get(&t).copied().unwrap_or(t)) - .collect(); - let new_default = target_remap.get(default).copied().unwrap_or(*default); - if new_targets == *targets && new_default == *default { - return false; - } - SsaOp::Switch { - value: *value, - targets: new_targets, - default: new_default, - } - } - _ => return false, - }; - - last.set_op(new_op); - true - } -} diff --git a/dotscope/src/analysis/ssa/function/mod.rs b/dotscope/src/analysis/ssa/function/mod.rs deleted file mode 100644 index 8cbb1c39..00000000 --- a/dotscope/src/analysis/ssa/function/mod.rs +++ /dev/null @@ -1,1610 +0,0 @@ -//! SSA function representation - a complete method in SSA form. -//! -//! An `SsaFunction` is the top-level container for a method's SSA representation. -//! It holds all SSA blocks, variables, and maintains the relationship to the -//! underlying control flow graph. -//! -//! # Structure -//! -//! ```text -//! SsaFunction -//! ├── blocks: Vec // SSA blocks (1:1 with CFG blocks) -//! ├── variables: Vec // All SSA variables -//! ├── num_args: usize // Number of method arguments -//! └── num_locals: usize // Number of local variables -//! ``` -//! -//! # Construction -//! -//! An `SsaFunction` is built by the `SsaConverter` which: -//! 1. Simulates the stack to create explicit variables -//! 2. Places phi nodes at dominance frontiers -//! 3. Renames variables to achieve single-assignment form -//! -//! # Thread Safety -//! -//! `SsaFunction` is `Send` and `Sync` once constructed. - -mod canonical; -mod duplication; -mod queries; -mod rebuild; -mod repair; -mod semantics; -mod transforms; - -pub use queries::{MethodPurity, ReturnInfo}; -pub use transforms::TrivialPhiOptions; - -use std::{ - collections::{BTreeMap, BTreeSet}, - fmt, -}; - -use crate::{ - analysis::ssa::{ - exception::SsaExceptionHandler, - verifier::{SsaVerifier, VerifyLevel}, - DefSite, FunctionVarAllocator, PhiNode, PhiOperand, SsaBlock, SsaInstruction, SsaOp, - SsaType, SsaVarId, SsaVariable, VariableOrigin, - }, - metadata::signatures::SignatureLocalVariable, -}; - -/// A method in SSA (Static Single Assignment) form. -/// -/// This is the complete SSA representation of a CIL method, containing: -/// - All basic blocks with phi nodes and SSA instructions -/// - All SSA variables with their metadata -/// - Method signature information (argument/local counts) -/// - Exception handlers from the original method body -/// -/// # Examples -/// -/// ```rust,no_run -/// use dotscope::analysis::{SsaFunction, SsaBlock, SsaVarId}; -/// -/// // Create an SSA function with 2 args, 1 local, and 3 blocks -/// let mut func = SsaFunction::new(2, 1); -/// -/// // Add blocks -/// func.add_block(SsaBlock::new(0)); -/// func.add_block(SsaBlock::new(1)); -/// func.add_block(SsaBlock::new(2)); -/// -/// // Query variables -/// for var in func.variables() { -/// println!("Variable: {}", var); -/// } -/// ``` -#[derive(Debug, Clone)] -pub struct SsaFunction { - /// SSA basic blocks, indexed by block ID. - blocks: Vec, - - /// All SSA variables in this function, densely indexed by `SsaVarId`. - /// - /// Invariant: `variables[i].id().index() == i` for all valid indices. - /// This is maintained by `add_variable()` (which assigns dense IDs) and - /// `compact_variables()` (which re-establishes density after removals). - variables: Vec, - - /// Per-function allocator for dense variable IDs. - var_allocator: FunctionVarAllocator, - - /// Maps each origin to its variable IDs, ordered by version. - /// - /// This enables O(1) lookup of all versions of a given origin - /// (e.g., "all versions of Local(3)") without scanning all variables. - origin_versions: BTreeMap>, - - /// Maps each variable origin to its canonical type. - /// - /// Populated during SSA construction from method signatures and instruction - /// type inference. Used by [`create_variable_for_origin()`](Self::create_variable_for_origin) - /// to ensure new variable versions always get proper types. - origin_types: BTreeMap, - - /// Number of method arguments. - num_args: usize, - - /// Number of local variables. - num_locals: usize, - - /// Number of locals from the original method signature. - original_num_locals: usize, - - /// Variables that control input-dependent control flow. - /// Switches using these variables should not be simplified to jumps - /// even if the value appears to be constant on some paths. - preserved_dispatch_vars: BTreeSet, - - /// Original local variable types from the method signature. - /// These are preserved during SSA construction so they can be used - /// during code generation to maintain correct type information. - original_local_types: Option>, - - /// Exception handlers from the original method body. - /// These are preserved during SSA construction and remapped during - /// code generation based on the new instruction layout. - exception_handlers: Vec, - - /// Rename group for each variable, indexed by `SsaVarId::index()`. - /// - /// During SSA construction and rebuild, variables that share the same - /// "version stack" for phi placement and renaming are assigned the same - /// group ID. This separates the rename-grouping concern from - /// `VariableOrigin`, which tracks provenance only. - /// - /// Group assignment (by converter/rebuild): - /// - `Argument(i)` → group `i` - /// - `Local(i)` → group `num_args + i` - /// - Stack temp at depth D → group `num_args + num_locals + D` - /// - Orphan/pass-created → auto-incrementing from max group + 1 - rename_groups: Vec, -} - -impl SsaFunction { - /// Creates a new empty SSA function. - /// - /// # Arguments - /// - /// * `num_args` - Number of method arguments (including `this` for instance methods) - /// * `num_locals` - Number of local variables declared in the method - /// - /// # Returns - /// - /// A new empty [`SsaFunction`] with no blocks or variables. - #[must_use] - pub fn new(num_args: usize, num_locals: usize) -> Self { - Self { - blocks: Vec::new(), - variables: Vec::new(), - var_allocator: FunctionVarAllocator::new(), - origin_versions: BTreeMap::new(), - origin_types: BTreeMap::new(), - num_args, - num_locals, - original_num_locals: num_locals, - preserved_dispatch_vars: BTreeSet::new(), - original_local_types: None, - exception_handlers: Vec::new(), - rename_groups: Vec::new(), - } - } - - /// Creates a new SSA function with pre-allocated capacity. - /// - /// # Arguments - /// - /// * `num_args` - Number of method arguments - /// * `num_locals` - Number of local variables - /// * `block_capacity` - Expected number of blocks - /// * `var_capacity` - Expected number of SSA variables - /// - /// # Returns - /// - /// A new empty [`SsaFunction`] with pre-allocated storage. - #[must_use] - pub fn with_capacity( - num_args: usize, - num_locals: usize, - block_capacity: usize, - var_capacity: usize, - ) -> Self { - Self { - blocks: Vec::with_capacity(block_capacity), - variables: Vec::with_capacity(var_capacity), - var_allocator: FunctionVarAllocator::new(), - origin_versions: BTreeMap::new(), - origin_types: BTreeMap::new(), - num_args, - num_locals, - original_num_locals: num_locals, - preserved_dispatch_vars: BTreeSet::new(), - original_local_types: None, - exception_handlers: Vec::new(), - rename_groups: Vec::with_capacity(var_capacity), - } - } - - /// Returns the SSA blocks. - /// - /// # Returns - /// - /// A slice of all [`SsaBlock`]s in this function. - #[must_use] - pub fn blocks(&self) -> &[SsaBlock] { - &self.blocks - } - - /// Returns an iterator over blocks with their indices. - /// - /// This is a convenience method that pairs each block with its index, - /// avoiding the common `for block_idx in 0..ssa.block_count()` pattern. - /// - /// # Example - /// - /// ```ignore - /// for (block_idx, block) in ssa.iter_blocks() { - /// println!("Block {}: {} instructions", block_idx, block.instruction_count()); - /// } - /// ``` - pub fn iter_blocks(&self) -> impl Iterator { - self.blocks.iter().enumerate() - } - - /// Returns an iterator over all instructions with their block and instruction indices. - /// - /// This flattens the nested block/instruction structure into a single iterator, - /// which is useful for passes that need to scan all instructions. - /// - /// # Example - /// - /// ```ignore - /// for (block_idx, instr_idx, instr) in ssa.iter_instructions() { - /// { let op = instr.op(); - /// // Process instruction at (block_idx, instr_idx) - /// } - /// } - /// ``` - pub fn iter_instructions(&self) -> impl Iterator { - self.blocks - .iter() - .enumerate() - .flat_map(|(block_idx, block)| { - block - .instructions() - .iter() - .enumerate() - .map(move |(instr_idx, instr)| (block_idx, instr_idx, instr)) - }) - } - - /// Returns a mutable iterator over all instructions with their block and instruction indices. - /// - /// This is the mutable counterpart to [`iter_instructions`], allowing passes to - /// modify instructions while iterating. Note that structural changes (adding/removing - /// instructions) require collecting the modifications and applying them separately. - /// - /// # Example - /// - /// ```ignore - /// // Replace all uses of old_var with new_var - /// for (block_idx, instr_idx, instr) in ssa.iter_instructions_mut() { - /// instr.op_mut().replace_uses(old_var, new_var); - /// } - /// ``` - /// - /// # Note - /// - /// For passes that need to add or remove instructions, use [`blocks_mut`] to access - /// the blocks directly, as the iterator cannot handle structural modifications. - /// - /// [`iter_instructions`]: Self::iter_instructions - /// [`blocks_mut`]: Self::blocks_mut - pub fn iter_instructions_mut( - &mut self, - ) -> impl Iterator { - self.blocks - .iter_mut() - .enumerate() - .flat_map(|(block_idx, block)| { - block - .instructions_mut() - .iter_mut() - .enumerate() - .map(move |(instr_idx, instr)| (block_idx, instr_idx, instr)) - }) - } - - /// Returns an iterator over all phi nodes with their block and phi indices. - /// - /// This flattens the nested block/phi structure into a single iterator, - /// which is useful for passes that need to analyze all phi nodes. - /// - /// # Example - /// - /// ```ignore - /// for (block_idx, phi_idx, phi) in ssa.iter_phis() { - /// println!("Phi {} in block {} defines {}", phi_idx, block_idx, phi.result()); - /// } - /// ``` - pub fn iter_phis(&self) -> impl Iterator { - self.blocks - .iter() - .enumerate() - .flat_map(|(block_idx, block)| { - block - .phi_nodes() - .iter() - .enumerate() - .map(move |(phi_idx, phi)| (block_idx, phi_idx, phi)) - }) - } - - /// Returns a mutable reference to the blocks. - /// - /// # Returns - /// - /// A mutable reference to the vector of [`SsaBlock`]s. - pub fn blocks_mut(&mut self) -> &mut Vec { - &mut self.blocks - } - - /// Returns the SSA variables. - /// - /// # Returns - /// - /// A slice of all [`SsaVariable`]s in this function. - #[must_use] - pub fn variables(&self) -> &[SsaVariable] { - &self.variables - } - - /// Returns a mutable reference to the variables. - /// - /// # Returns - /// - /// A mutable reference to the vector of [`SsaVariable`]s. - pub fn variables_mut(&mut self) -> &mut Vec { - &mut self.variables - } - - /// Returns the number of method arguments. - /// - /// # Returns - /// - /// The count of method arguments, including `this` for instance methods. - #[must_use] - pub const fn num_args(&self) -> usize { - self.num_args - } - - /// Returns the number of local variables. - /// - /// # Returns - /// - /// The count of local variables declared in the method. - #[must_use] - pub const fn num_locals(&self) -> usize { - self.num_locals - } - - /// Returns the number of locals from the original method signature. - /// - /// With the group-based rename system, this is always equal to `num_locals` - /// since stack temporaries use `Phi` origin instead of inflated local indices. - #[must_use] - pub const fn original_num_locals(&self) -> usize { - self.original_num_locals - } - - /// Sets the total number of local variables. - pub(crate) fn set_num_locals(&mut self, num_locals: usize, original_num_locals: usize) { - self.num_locals = num_locals; - self.original_num_locals = original_num_locals; - } - - /// Returns the number of blocks. - /// - /// # Returns - /// - /// The count of basic blocks in this function. - #[must_use] - pub fn block_count(&self) -> usize { - self.blocks.len() - } - - /// Returns the number of variables. - /// - /// # Returns - /// - /// The count of SSA variables in this function. - #[must_use] - pub fn variable_count(&self) -> usize { - self.variables.len() - } - - /// Returns the minimum BitSet capacity needed to index all variable IDs - /// that appear in this function (in the variables vec, block instructions, - /// and phi nodes). - /// - /// This handles cases where variable IDs don't match their position in - /// the variables vector (e.g., in test code using `SsaVarId::from_index` - /// without registering via `create_variable`). - #[must_use] - pub fn var_id_capacity(&self) -> usize { - let from_vars = self - .variables - .iter() - .map(|v| v.id().index().saturating_add(1)) - .max() - .unwrap_or(0); - let from_blocks = self - .blocks - .iter() - .flat_map(|b| { - let phi_ids = b.phi_nodes().iter().flat_map(|p| { - std::iter::once(p.result().index()) - .chain(p.operands().iter().map(|op| op.value().index())) - }); - let instr_ids = b.instructions().iter().flat_map(|i| { - i.op() - .dest() - .into_iter() - .chain(i.op().uses()) - .map(|v| v.index()) - }); - phi_ids.chain(instr_ids) - }) - .max() - .map_or(0, |m| m.saturating_add(1)); - from_vars.max(from_blocks).max(self.variables.len()) - } - - /// Returns all variable IDs for a given origin, ordered by creation. - /// - /// This is O(1) via the version registry. For example, - /// `versions_of(VariableOrigin::Local(3))` returns all SSA versions - /// of local variable 3. - #[must_use] - pub fn versions_of(&self, origin: VariableOrigin) -> &[SsaVarId] { - self.origin_versions - .get(&origin) - .map_or(&[], |v| v.as_slice()) - } - - /// Returns the most recently created variable ID for a given origin. - #[must_use] - pub fn latest_version(&self, origin: VariableOrigin) -> Option { - self.origin_versions - .get(&origin) - .and_then(|v| v.last().copied()) - } - - /// Gets the local index for a variable ID. - /// - /// With dense IDs, this is always O(1) — the index equals `id.index()`. - /// - /// # Arguments - /// - /// * `id` - The variable ID to look up - /// - /// # Returns - /// - /// The local index (0-based), or `None` if the variable is not in this function. - #[must_use] - pub fn var_index(&self, id: SsaVarId) -> Option { - let idx = id.index(); - if idx < self.variables.len() { - Some(idx) - } else { - None - } - } - - /// Returns `true` if this function has no blocks. - /// - /// # Returns - /// - /// `true` if the function contains no blocks, `false` otherwise. - #[must_use] - pub fn is_empty(&self) -> bool { - self.blocks.is_empty() - } - - /// Gets a block by index. - /// - /// # Arguments - /// - /// * `index` - The block index to retrieve - /// - /// # Returns - /// - /// A reference to the block, or `None` if the index is out of bounds. - #[must_use] - pub fn block(&self, index: usize) -> Option<&SsaBlock> { - self.blocks.get(index) - } - - /// Gets a mutable block by index. - /// - /// # Arguments - /// - /// * `index` - The block index to retrieve - /// - /// # Returns - /// - /// A mutable reference to the block, or `None` if the index is out of bounds. - pub fn block_mut(&mut self, index: usize) -> Option<&mut SsaBlock> { - self.blocks.get_mut(index) - } - - /// Gets a variable by ID. O(1) via dense indexing. - /// - /// # Arguments - /// - /// * `id` - The variable ID to look up - /// - /// # Returns - /// - /// A reference to the variable, or `None` if the ID is out of bounds. - #[must_use] - pub fn variable(&self, id: SsaVarId) -> Option<&SsaVariable> { - self.variables.get(id.index()) - } - - /// Gets a mutable variable by ID. O(1) via dense indexing. - /// - /// # Arguments - /// - /// * `id` - The variable ID to look up - /// - /// # Returns - /// - /// A mutable reference to the variable, or `None` if the ID is out of bounds. - pub fn variable_mut(&mut self, id: SsaVarId) -> Option<&mut SsaVariable> { - self.variables.get_mut(id.index()) - } - - /// Adds a block to this function. - /// - /// # Arguments - /// - /// * `block` - The block to add - pub fn add_block(&mut self, block: SsaBlock) { - self.blocks.push(block); - } - - /// Creates a new variable with a dense ID allocated by this function. - /// - /// This is the **only** way to create variables. The ID is guaranteed to be - /// dense (equal to the variable's index in the variables Vec), enabling - /// O(1) lookup via direct indexing. - /// - /// If `var_type` is not `Unknown`, it is automatically registered in the - /// origin type registry for future lookups. - pub fn create_variable( - &mut self, - origin: VariableOrigin, - version: u32, - def_site: DefSite, - var_type: SsaType, - ) -> SsaVarId { - let id = self.var_allocator.alloc(); - let var = SsaVariable::new(id, origin, version, def_site, var_type.clone()); - debug_assert_eq!(id.index(), self.variables.len()); - self.origin_versions.entry(origin).or_default().push(id); - // Register origin type if known (first concrete type wins) - if !var_type.is_unknown() && !self.origin_types.contains_key(&origin) { - self.origin_types.insert(origin, var_type); - } - self.variables.push(var); - // Extend rename_groups to keep it in sync (default u32::MAX = no group) - if self.rename_groups.len() <= id.index() { - self.rename_groups - .resize(id.index().saturating_add(1), u32::MAX); - } - id - } - - /// Creates a new variable, inferring its type from the origin type registry. - /// - /// This is a convenience method for creating new versions of variables - /// whose origin type was previously registered. If no type is registered - /// for the origin, the variable gets `SsaType::Unknown`. - pub fn create_variable_for_origin( - &mut self, - origin: VariableOrigin, - version: u32, - def_site: DefSite, - ) -> SsaVarId { - let var_type = self.origin_type(origin); - self.create_variable(origin, version, def_site, var_type) - } - - /// Registers the canonical type for a variable origin. - /// - /// Only registers if the type is not `Unknown`. If a type is already - /// registered for this origin, it is not overwritten (first wins). - pub fn register_origin_type(&mut self, origin: VariableOrigin, var_type: SsaType) { - if !var_type.is_unknown() && !self.origin_types.contains_key(&origin) { - self.origin_types.insert(origin, var_type); - } - } - - /// Returns the registered type for a variable origin, or `SsaType::Unknown`. - #[must_use] - pub fn origin_type(&self, origin: VariableOrigin) -> SsaType { - self.origin_types - .get(&origin) - .cloned() - .unwrap_or(SsaType::Unknown) - } - - /// Returns the origin type registry. - #[must_use] - pub fn origin_types(&self) -> &BTreeMap { - &self.origin_types - } - - /// Rebuilds the origin_versions registry from the current variables list. - /// - /// Called after operations that modify the variables list (compact, reindex). - fn rebuild_origin_versions(&mut self) { - self.origin_versions.clear(); - for var in &self.variables { - self.origin_versions - .entry(var.origin()) - .or_default() - .push(var.id()); - } - } - - /// Reassigns dense variable IDs after variable removal. - /// - /// This must be called after removing variables from `self.variables` to restore - /// the dense indexing invariant (`variables[i].id().index() == i`). - /// - /// Returns a mapping from old IDs to new IDs for updating references. - fn reassign_dense_ids(&mut self) -> BTreeMap { - let mut remap = BTreeMap::new(); - let old_groups = std::mem::take(&mut self.rename_groups); - self.var_allocator = FunctionVarAllocator::starting_from(self.variables.len()); - let mut new_groups = vec![u32::MAX; self.variables.len()]; - for (index, var) in self.variables.iter_mut().enumerate() { - let old_id = var.id(); - let new_id = SsaVarId::from_index(index); - // Carry over the rename group from the old position - if let Some(&old_group) = old_groups.get(old_id.index()) { - if let Some(slot) = new_groups.get_mut(index) { - *slot = old_group; - } - } - if old_id != new_id { - remap.insert(old_id, new_id); - var.set_id(new_id); - } - } - self.rename_groups = new_groups; - remap - } - - /// Remaps all variable ID references in blocks (instructions, phi nodes, terminators) - /// using the given old-to-new ID mapping. - fn remap_var_ids_in_blocks(&mut self, remap: &BTreeMap) { - if remap.is_empty() { - return; - } - let lookup = |id: SsaVarId| -> Option { remap.get(&id).copied() }; - let resolve = |id: SsaVarId| -> SsaVarId { remap.get(&id).copied().unwrap_or(id) }; - - for block in &mut self.blocks { - // Remap phi nodes - for phi in block.phi_nodes_mut() { - let old_result = phi.result(); - phi.set_result(resolve(old_result)); - for operand in phi.operands_mut() { - let old_value = operand.value(); - *operand = PhiOperand::new(resolve(old_value), operand.predecessor()); - } - } - // Remap instructions using existing remap_variables - for instr in block.instructions_mut() { - let new_op = instr.op().remap_variables(lookup); - instr.set_op(new_op); - } - } - // Remap preserved_dispatch_vars - let remapped_dispatch: BTreeSet = self - .preserved_dispatch_vars - .iter() - .map(|id| resolve(*id)) - .collect(); - self.preserved_dispatch_vars = remapped_dispatch; - } - - /// Marks a variable as a preserved dispatch variable. - /// - /// Preserved dispatch variables control input-dependent control flow - /// (e.g., switches that depend on runtime input rather than constants). - /// Optimization passes should not simplify switches using these variables - /// even if the value appears constant on some paths. - /// - /// # Arguments - /// - /// * `var` - The variable ID to mark as preserved. - pub fn mark_preserved_dispatch_var(&mut self, var: SsaVarId) { - self.preserved_dispatch_vars.insert(var); - } - - /// Checks if a variable is a preserved dispatch variable. - /// - /// # Arguments - /// - /// * `var` - The variable ID to check. - /// - /// # Returns - /// - /// `true` if this variable controls input-dependent control flow. - #[must_use] - pub fn is_preserved_dispatch_var(&self, var: SsaVarId) -> bool { - self.preserved_dispatch_vars.contains(&var) - } - - /// Checks if any preserved dispatch variables are set. - /// - /// # Returns - /// - /// `true` if there are any preserved dispatch variables. - #[must_use] - pub fn has_preserved_dispatch_vars(&self) -> bool { - !self.preserved_dispatch_vars.is_empty() - } - - /// Sets the original local variable types from the method signature. - /// - /// These types are preserved so they can be used during code generation - /// to maintain correct type information in the output assembly. - /// - /// # Arguments - /// - /// * `types` - The original local variable types from the method signature. - pub fn set_original_local_types(&mut self, types: Vec) { - self.original_local_types = Some(types); - } - - /// Returns the original local variable types if set. - /// - /// # Returns - /// - /// The original local types, or `None` if not set. - #[must_use] - pub fn original_local_types(&self) -> Option<&[SignatureLocalVariable]> { - self.original_local_types.as_deref() - } - - /// Sets the exception handlers for this function. - /// - /// These are preserved from the original method body and will be - /// remapped during code generation based on the new instruction layout. - /// - /// # Arguments - /// - /// * `handlers` - The exception handlers from the original method body. - pub fn set_exception_handlers(&mut self, handlers: Vec) { - self.exception_handlers = handlers; - } - - /// Returns the exception handlers for this function. - /// - /// # Returns - /// - /// A slice of exception handlers, or an empty slice if none are set. - #[must_use] - pub fn exception_handlers(&self) -> &[SsaExceptionHandler] { - &self.exception_handlers - } - - /// Returns whether this function has any exception handlers. - /// - /// # Returns - /// - /// `true` if the function has at least one exception handler. - #[must_use] - pub fn has_exception_handlers(&self) -> bool { - !self.exception_handlers.is_empty() - } - - /// Returns the rename group for a variable. - /// - /// Returns `u32::MAX` if no group has been assigned (the variable was - /// created without a rename group, e.g. by a compiler pass). - #[must_use] - pub(crate) fn rename_group(&self, var_id: SsaVarId) -> u32 { - self.rename_groups - .get(var_id.index()) - .copied() - .unwrap_or(u32::MAX) - } - - /// Sets the rename group for a variable. - /// - /// Extends the `rename_groups` vector with `u32::MAX` if needed. - pub(crate) fn set_rename_group(&mut self, var_id: SsaVarId, group: u32) { - let idx = var_id.index(); - if idx >= self.rename_groups.len() { - self.rename_groups.resize(idx.saturating_add(1), u32::MAX); - } - if let Some(slot) = self.rename_groups.get_mut(idx) { - *slot = group; - } - } - - /// Rebuilds SSA form after CFG modifications (e.g., control flow unflattening). - /// - /// This method performs a complete SSA reconstruction using the standard - /// Cytron et al. algorithm. See [`rebuild::SsaRebuilder`] for the - /// individual phases. - /// - /// This is necessary because after passes like control flow unflattening, - /// the CFG structure changes significantly and PHI nodes may reference - /// variables from removed blocks or have incorrect operands. - pub fn rebuild_ssa(&mut self) -> crate::Result<()> { - if self.blocks.is_empty() { - return Ok(()); - } - rebuild::SsaRebuilder::new(self).rebuild() - } - - /// Sorts instructions in all blocks in topological order. - /// - /// This ensures that within each block, if instruction A uses a value defined - /// by instruction B, then B appears before A. - /// - /// This is called automatically by [`rebuild_ssa`](Self::rebuild_ssa) but can - /// also be called manually after passes that may have disrupted instruction order. - /// - /// # Returns - /// - /// `true` if all blocks were successfully sorted, `false` if any block has - /// cyclic dependencies (which indicates invalid SSA). - pub fn sort_all_blocks_topologically(&mut self) -> bool { - let mut all_sorted = true; - for block in &mut self.blocks { - if !block.sort_instructions_topologically() { - all_sorted = false; - } - } - all_sorted - } - - /// Validates that no meaningfully-used variable has `SsaType::Unknown`. - /// - /// This ensures that all variables whose values are actually consumed have a - /// concrete type. Variables are considered NOT meaningfully used if: - /// - They have no uses at all (dead variables, stripped by DCE) - /// - Their only uses are in `Pop` instructions (value is discarded) - /// - Their only uses are as phi operands where the phi result is also unused - /// - /// # Errors - /// - /// Returns `Err` with a description listing the first Unknown-typed - /// variable that has meaningful uses. - pub fn validate_types(&self) -> Result<(), String> { - for var in &self.variables { - if !var.var_type().is_unknown() || var.uses().is_empty() { - continue; - } - - // Check if all uses are in Pop instructions (value is discarded) - let has_meaningful_use = var.uses().iter().any(|use_site| { - if use_site.is_phi_operand { - // Phi operand — only meaningful if the phi result has a known type. - // If the phi result is also Unknown, this is just Unknown feeding - // Unknown (e.g., uninitialized locals in a loop), not a real error. - if let Some(block) = self.block(use_site.block) { - if let Some(phi) = block.phi(use_site.instruction) { - if let Some(result_var) = self.variable(phi.result()) { - return !result_var.var_type().is_unknown(); - } - } - } - return false; - } - if let Some(block) = self.block(use_site.block) { - if let Some(instr) = block.instruction(use_site.instruction) { - return !matches!(instr.op(), SsaOp::Pop { .. }); - } - } - true // Conservative: assume meaningful if we can't check - }); - - if has_meaningful_use { - // Collect details about the meaningful uses for debugging - let use_details: Vec = var - .uses() - .iter() - .map(|use_site| { - if use_site.is_phi_operand { - return format!("phi in block {}", use_site.block); - } - if let Some(block) = self.block(use_site.block) { - if let Some(instr) = block.instruction(use_site.instruction) { - return format!( - "block {} instr {}: {:?}", - use_site.block, - use_site.instruction, - instr.op() - ); - } - } - format!( - "block {} instr {}: ", - use_site.block, use_site.instruction - ) - }) - .collect(); - return Err(format!( - "Variable {} (origin={:?}) has Unknown type but is used ({} uses): [{}]", - var.id(), - var.origin(), - var.uses().len(), - use_details.join(", ") - )); - } - } - Ok(()) - } - - /// Validates that the SSA function is well-formed. - /// - /// This checks several SSA invariants: - /// - /// 1. **No cyclic dependencies within a block** - Operations must have a valid - /// topological order. If operation A uses the result of operation B, then B - /// must come before A in the instruction list. - /// - /// 2. **Single definition** - Each variable should be defined at most once - /// (the defining property of SSA form). - /// - /// 3. **Phi nodes at block start** - Phi nodes should only appear at the - /// beginning of blocks, not mixed with regular instructions. - /// - /// # Errors - /// - /// Returns `Err` with a description of the problem if any SSA invariant is violated, - /// such as cyclic dependencies, duplicate definitions, or misplaced terminators. - /// - /// # Example - /// - /// ```rust,ignore - /// let ssa = build_ssa_from_method(&method)?; - /// ssa.validate()?; // Returns error if SSA is malformed - /// - /// // After running a pass - /// some_pass.run(&mut ssa); - /// ssa.validate()?; // Check the pass didn't break SSA invariants - /// ``` - pub fn validate(&self) -> Result<(), String> { - let errors = SsaVerifier::new(self).verify(VerifyLevel::Standard); - if errors.is_empty() { - Ok(()) - } else { - Err(errors - .iter() - .map(|e| e.to_string()) - .collect::>() - .join("; ")) - } - } - - /// Checks if the SSA function is valid without returning detailed errors. - /// - /// This is a convenience method that returns `true` if [`validate`](Self::validate) - /// would return `Ok(())`. - /// - /// # Returns - /// - /// `true` if the SSA is well-formed, `false` otherwise. - #[must_use] - pub fn is_valid(&self) -> bool { - self.validate().is_ok() - } -} - -impl fmt::Display for SsaFunction { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - writeln!( - f, - "SSA Function ({} args, {} locals):", - self.num_args, self.num_locals - )?; - writeln!(f, " Variables: {}", self.variables.len())?; - writeln!(f, " Blocks: {}", self.blocks.len())?; - writeln!(f)?; - - for block in &self.blocks { - write!(f, "{block}")?; - } - - Ok(()) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use crate::{ - analysis::{ - ssa::{ - ConstValue, DefSite, PhiNode, PhiOperand, SsaBlock, SsaInstruction, SsaOp, SsaType, - SsaVarId, UseSite, VariableOrigin, - }, - SsaFunctionBuilder, - }, - assembly::{FlowType, Instruction, InstructionCategory, Operand, StackBehavior}, - }; - - fn make_test_cil_instruction(mnemonic: &'static str) -> Instruction { - Instruction { - rva: 0x1000, - offset: 0, - size: 1, - opcode: 0x00, - prefix: 0, - mnemonic, - category: InstructionCategory::Misc, - flow_type: FlowType::Sequential, - operand: Operand::None, - stack_behavior: StackBehavior { - pops: 0, - pushes: 0, - net_effect: 0, - }, - branch_targets: vec![], - } - } - - #[test] - fn test_ssa_function_creation() { - let func = SsaFunction::new(2, 3); - assert_eq!(func.num_args(), 2); - assert_eq!(func.num_locals(), 3); - assert!(func.is_empty()); - assert_eq!(func.block_count(), 0); - assert_eq!(func.variable_count(), 0); - } - - #[test] - fn test_ssa_function_with_capacity() { - let func = SsaFunction::with_capacity(2, 1, 10, 50); - assert_eq!(func.num_args(), 2); - assert_eq!(func.num_locals(), 1); - assert!(func.is_empty()); - } - - #[test] - fn test_ssa_function_add_block() { - let mut func = SsaFunction::new(0, 0); - - func.add_block(SsaBlock::new(0)); - func.add_block(SsaBlock::new(1)); - - assert!(!func.is_empty()); - assert_eq!(func.block_count(), 2); - assert!(func.block(0).is_some()); - assert!(func.block(1).is_some()); - assert!(func.block(2).is_none()); - } - - #[test] - fn test_ssa_function_add_variable() { - let mut func = SsaFunction::new(1, 0); - - let id1 = func.create_variable( - VariableOrigin::Argument(0), - 0, - DefSite::phi(0), - SsaType::Unknown, - ); - - let id2 = func.create_variable( - VariableOrigin::Local(0), - 0, - DefSite::instruction(0, 0), - SsaType::Unknown, - ); - - // IDs should be different - assert_ne!(id1, id2); - assert_eq!(func.variable_count(), 2); - } - - #[test] - fn test_ssa_function_variable_access() { - let mut func = SsaFunction::new(1, 0); - - let id = func.create_variable( - VariableOrigin::Argument(0), - 0, - DefSite::phi(0), - SsaType::Unknown, - ); - - assert!(func.variable(id).is_some()); - assert_eq!( - func.variable(id).unwrap().origin(), - VariableOrigin::Argument(0) - ); - } - - #[test] - fn test_ssa_function_argument_variables() { - let mut func = SsaFunction::new(2, 1); - - // Add arg0 version 0 - func.create_variable( - VariableOrigin::Argument(0), - 0, - DefSite::phi(0), - SsaType::Unknown, - ); - - // Add arg1 version 0 - func.create_variable( - VariableOrigin::Argument(1), - 0, - DefSite::phi(0), - SsaType::Unknown, - ); - - // Add arg0 version 1 (redefinition) - func.create_variable( - VariableOrigin::Argument(0), - 1, - DefSite::instruction(1, 0), - SsaType::Unknown, - ); - - // Add local0 version 0 - func.create_variable( - VariableOrigin::Local(0), - 0, - DefSite::phi(0), - SsaType::Unknown, - ); - - let args: Vec<_> = func.argument_variables().collect(); - assert_eq!(args.len(), 2); // Only version 0 of each arg - } - - #[test] - fn test_ssa_function_local_variables() { - let mut func = SsaFunction::new(0, 2); - - func.create_variable( - VariableOrigin::Local(0), - 0, - DefSite::phi(0), - SsaType::Unknown, - ); - - func.create_variable( - VariableOrigin::Local(1), - 0, - DefSite::phi(0), - SsaType::Unknown, - ); - - // Add a non-local variable (phi origin) - should not be counted - func.create_variable(VariableOrigin::Phi, 0, DefSite::phi(0), SsaType::Unknown); - - let locals: Vec<_> = func.local_variables().collect(); - assert_eq!(locals.len(), 2); - } - - #[test] - fn test_ssa_function_variables_from_argument() { - let mut func = SsaFunction::new(2, 0); - - func.create_variable( - VariableOrigin::Argument(0), - 0, - DefSite::phi(0), - SsaType::Unknown, - ); - - func.create_variable( - VariableOrigin::Argument(0), - 1, - DefSite::instruction(1, 0), - SsaType::Unknown, - ); - - func.create_variable( - VariableOrigin::Argument(1), - 0, - DefSite::phi(0), - SsaType::Unknown, - ); - - let arg0_vars: Vec<_> = func.variables_from_argument(0).collect(); - assert_eq!(arg0_vars.len(), 2); - - let arg1_vars: Vec<_> = func.variables_from_argument(1).collect(); - assert_eq!(arg1_vars.len(), 1); - } - - #[test] - fn test_ssa_function_total_phi_count() { - let mut func = SsaFunction::new(0, 0); - - let mut block0 = SsaBlock::new(0); - block0.add_phi(PhiNode::new( - SsaVarId::from_index(0), - VariableOrigin::Local(0), - )); - block0.add_phi(PhiNode::new( - SsaVarId::from_index(1), - VariableOrigin::Local(1), - )); - func.add_block(block0); - - let mut block1 = SsaBlock::new(1); - block1.add_phi(PhiNode::new( - SsaVarId::from_index(2), - VariableOrigin::Local(0), - )); - func.add_block(block1); - - func.add_block(SsaBlock::new(2)); // No phis - - assert_eq!(func.phi_count(), 3); - } - - #[test] - fn test_ssa_function_total_instruction_count() { - let mut func = SsaFunction::new(0, 0); - - let mut block0 = SsaBlock::new(0); - block0.add_instruction(SsaInstruction::new( - make_test_cil_instruction("nop"), - SsaOp::Nop, - )); - block0.add_instruction(SsaInstruction::new( - make_test_cil_instruction("nop"), - SsaOp::Nop, - )); - func.add_block(block0); - - let mut block1 = SsaBlock::new(1); - block1.add_instruction(SsaInstruction::new( - make_test_cil_instruction("ret"), - SsaOp::Return { value: None }, - )); - func.add_block(block1); - - assert_eq!(func.instruction_count(), 3); - } - - #[test] - fn test_ssa_function_all_phi_nodes() { - let mut func = SsaFunction::new(0, 0); - - let phi_result = SsaVarId::from_index(0); - let phi_operand = SsaVarId::from_index(1); - let mut block0 = SsaBlock::new(0); - let mut phi = PhiNode::new(phi_result, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(phi_operand, 1)); - block0.add_phi(phi); - func.add_block(block0); - - let phis: Vec<_> = func.all_phi_nodes().collect(); - assert_eq!(phis.len(), 1); - assert_eq!(phis[0].result(), phi_result); - } - - #[test] - fn test_ssa_function_dead_variables() { - let mut func = SsaFunction::new(0, 0); - - // Variable with no uses (dead) - func.create_variable( - VariableOrigin::Local(0), - 0, - DefSite::instruction(0, 0), - SsaType::Unknown, - ); - - // Variable with uses (live) - let live_id = func.create_variable( - VariableOrigin::Local(1), - 0, - DefSite::instruction(0, 1), - SsaType::Unknown, - ); - func.variable_mut(live_id) - .unwrap() - .add_use(UseSite::instruction(0, 2)); - - let dead: Vec<_> = func.dead_variables().collect(); - assert_eq!(dead.len(), 1); - assert_eq!(func.dead_variable_count(), 1); - } - - #[test] - fn test_ssa_function_display() { - let mut func = SsaFunction::new(1, 1); - func.add_block(SsaBlock::new(0)); - - let display = format!("{func}"); - assert!(display.contains("SSA Function")); - assert!(display.contains("1 args")); - assert!(display.contains("1 locals")); - assert!(display.contains("B0:")); - } - - #[test] - fn test_compact_variables_removes_orphaned() { - let mut func = SsaFunction::new(0, 0); - - // Create a variable via create_variable (dense ID 0) - let defined_id = func.create_variable( - VariableOrigin::Local(0), - 0, - DefSite::instruction(0, 0), - SsaType::Unknown, - ); - - // Add a block with an instruction that defines the variable - let mut block = SsaBlock::new(0); - let instr = SsaInstruction::new( - make_test_cil_instruction("nop"), - SsaOp::Const { - dest: defined_id, - value: ConstValue::I32(42), - }, - ); - block.add_instruction(instr); - - // Add return - let ret = SsaInstruction::new( - make_test_cil_instruction("ret"), - SsaOp::Return { value: None }, - ); - block.add_instruction(ret); - func.add_block(block); - - // Add an orphaned variable (not defined by any instruction, version > 0 so not entry) - func.create_variable( - VariableOrigin::Local(1), - 1, - DefSite::instruction(0, 99), - SsaType::Unknown, - ); - - assert_eq!(func.variable_count(), 2); - - // Compact should remove the orphaned variable - let removed = func.compact_variables(); - assert_eq!(removed, 1); - assert_eq!(func.variable_count(), 1); - - // The remaining variable should be the defined one (may have been remapped to index 0) - assert!(func.variable(SsaVarId::from_index(0)).is_some()); - } - - #[test] - fn test_compact_variables_preserves_entry_vars() { - let mut func = SsaFunction::new(1, 1); - - // Add arg0 version 0 (entry definition - should be preserved even without instruction) - let arg_id = func.create_variable( - VariableOrigin::Argument(0), - 0, - DefSite::entry(), - SsaType::Unknown, - ); - - // Add local0 version 0 (entry definition - should be preserved) - let local_id = func.create_variable( - VariableOrigin::Local(0), - 0, - DefSite::entry(), - SsaType::Unknown, - ); - - // Add an orphaned variable (version > 0, not defined by any instruction) - func.create_variable( - VariableOrigin::Local(2), - 1, - DefSite::instruction(0, 99), - SsaType::Unknown, - ); - - // Add an empty block - let mut block = SsaBlock::new(0); - let ret = SsaInstruction::new( - make_test_cil_instruction("ret"), - SsaOp::Return { value: None }, - ); - block.add_instruction(ret); - func.add_block(block); - - assert_eq!(func.variable_count(), 3); - - // Compact should preserve arg and local (entry definitions) but remove orphaned - let removed = func.compact_variables(); - assert_eq!(removed, 1); - assert_eq!(func.variable_count(), 2); - - // After compaction, dense IDs are reassigned: arg_id=0, local_id=1 - // arg_id was originally 0 and local_id was originally 1, so they stay the same - assert!(func.variable(arg_id).is_some()); - assert!(func.variable(local_id).is_some()); - } - - #[test] - fn test_find_constants_collects_all_const_instructions() { - let ssa = SsaFunctionBuilder::new(1, 0) - .build_with(|f| { - f.block(0, |b| { - let c1 = b.const_i32(42); - let c2 = b.const_i32(100); - let _ = b.add(c1, c2); - b.ret(); - }); - }) - .unwrap(); - - let constants = ssa.find_constants(); - assert_eq!(constants.len(), 2); - - // Verify we can look up constants by their variable IDs - let values: Vec<_> = constants.values().collect(); - assert!(values.iter().any(|v| **v == ConstValue::I32(42))); - assert!(values.iter().any(|v| **v == ConstValue::I32(100))); - } - - #[test] - fn test_find_constants_across_multiple_blocks() { - let ssa = SsaFunctionBuilder::new(2, 0) - .build_with(|f| { - f.block(0, |b| { - let _ = b.const_i32(1); - b.jump(1); - }); - f.block(1, |b| { - let _ = b.const_i32(2); - let _ = b.const_i32(3); - b.ret(); - }); - }) - .unwrap(); - - let constants = ssa.find_constants(); - assert_eq!(constants.len(), 3); - } - - #[test] - fn test_find_constants_empty_when_no_constants() { - let ssa = SsaFunctionBuilder::new(1, 0) - .build_with(|f| { - f.block(0, |b| { - b.ret(); - }); - }) - .unwrap(); - - let constants = ssa.find_constants(); - assert!(constants.is_empty()); - } - - #[test] - fn test_find_trampoline_blocks_in_chain() { - let ssa = SsaFunctionBuilder::new(4, 0) - .build_with(|f| { - f.block(0, |b| b.jump(1)); // trampoline -> 1 - f.block(1, |b| b.jump(2)); // trampoline -> 2 - f.block(2, |b| b.jump(3)); // trampoline -> 3 - f.block(3, |b| b.ret()); // not a trampoline - }) - .unwrap(); - - // With skip_entry = true, block 0 is excluded - let trampolines = ssa.find_trampoline_blocks(true); - assert_eq!(trampolines.len(), 2); - assert_eq!(trampolines.get(&1), Some(&2)); - assert_eq!(trampolines.get(&2), Some(&3)); - assert!(!trampolines.contains_key(&0)); - - // With skip_entry = false, block 0 is included - let trampolines = ssa.find_trampoline_blocks(false); - assert_eq!(trampolines.len(), 3); - assert_eq!(trampolines.get(&0), Some(&1)); - assert_eq!(trampolines.get(&1), Some(&2)); - assert_eq!(trampolines.get(&2), Some(&3)); - } - - #[test] - fn test_find_trampoline_blocks_mixed_control_flow() { - let ssa = SsaFunctionBuilder::new(4, 0) - .build_with(|f| { - f.block(0, |b| { - let cond = b.const_true(); - b.branch(cond, 1, 2); // conditional - not a trampoline - }); - f.block(1, |b| b.jump(3)); // trampoline -> 3 - f.block(2, |b| { - let _ = b.const_i32(42); - b.jump(3); // has extra instruction - not a trampoline - }); - f.block(3, |b| b.ret()); - }) - .unwrap(); - - let trampolines = ssa.find_trampoline_blocks(false); - assert_eq!(trampolines.len(), 1); - assert_eq!(trampolines.get(&1), Some(&3)); - } - - #[test] - fn test_find_trampoline_blocks_empty_result() { - let ssa = SsaFunctionBuilder::new(2, 0) - .build_with(|f| { - f.block(0, |b| { - let _ = b.const_i32(1); - b.ret(); - }); - f.block(1, |b| b.ret()); - }) - .unwrap(); - - // No trampolines in this function - let trampolines = ssa.find_trampoline_blocks(false); - assert!(trampolines.is_empty()); - } - - #[test] - fn test_iter_instructions_mut() { - let mut ssa = SsaFunctionBuilder::new(1, 0) - .build_with(|f| { - f.block(0, |b| { - let c1 = b.const_i32(10); - let c2 = b.const_i32(20); - let _ = b.add(c1, c2); - b.ret(); - }); - }) - .unwrap(); - - // Count total instructions - let count = ssa.iter_instructions().count(); - assert_eq!(count, 4); // 2 consts + 1 add + 1 ret - - // Use iter_instructions_mut to count and verify positions - let mut positions: Vec<(usize, usize)> = Vec::new(); - for (block_idx, instr_idx, _instr) in ssa.iter_instructions_mut() { - positions.push((block_idx, instr_idx)); - } - - // All instructions should be in block 0 - assert_eq!(positions.len(), 4); - assert_eq!(positions[0], (0, 0)); - assert_eq!(positions[1], (0, 1)); - assert_eq!(positions[2], (0, 2)); - assert_eq!(positions[3], (0, 3)); - } - - #[test] - fn test_iter_instructions_mut_across_blocks() { - let mut ssa = SsaFunctionBuilder::new(2, 0) - .build_with(|f| { - f.block(0, |b| { - let _ = b.const_i32(1); - b.jump(1); - }); - f.block(1, |b| { - let _ = b.const_i32(2); - b.ret(); - }); - }) - .unwrap(); - - let positions: Vec<(usize, usize)> = ssa - .iter_instructions_mut() - .map(|(b, i, _)| (b, i)) - .collect(); - - assert_eq!(positions.len(), 4); - // Block 0: const, jump - assert_eq!(positions[0], (0, 0)); - assert_eq!(positions[1], (0, 1)); - // Block 1: const, ret - assert_eq!(positions[2], (1, 0)); - assert_eq!(positions[3], (1, 1)); - } -} diff --git a/dotscope/src/analysis/ssa/function/queries.rs b/dotscope/src/analysis/ssa/function/queries.rs deleted file mode 100644 index fe203de4..00000000 --- a/dotscope/src/analysis/ssa/function/queries.rs +++ /dev/null @@ -1,1225 +0,0 @@ -//! Read-only query methods for SSA functions. -//! -//! These methods analyze SSA functions without modifying them, providing -//! information about variables, control flow, return behavior, and purity. - -use std::collections::BTreeMap; - -use crate::{ - analysis::ssa::{ - ConstValue, PhiNode, SsaBlock, SsaFunction, SsaInstruction, SsaOp, SsaVarId, SsaVariable, - VariableOrigin, - }, - utils::BitSet, -}; - -/// What a method returns. -#[derive(Debug, Clone, PartialEq)] -pub enum ReturnInfo { - /// Always returns this constant. - Constant(ConstValue), - - /// Returns parameter N unchanged (pass-through). - PassThrough(usize), - - /// Returns a pure computation of parameters (potentially foldable if params are known). - PureComputation, - - /// Has varying return value (depends on state, input, etc.). - Dynamic, - - /// Void method (no return value). - Void, - - /// Return behavior is unknown. - Unknown, -} - -impl ReturnInfo { - /// Checks if the return value is known at compile time. - /// - /// # Returns - /// - /// `true` if the return value is a constant or void. - #[must_use] - pub fn is_known(&self) -> bool { - matches!(self, Self::Constant(_) | Self::Void) - } - - /// Checks if the return value might be foldable with known inputs. - /// - /// # Returns - /// - /// `true` if the return could be computed at compile time given known inputs. - #[must_use] - pub fn is_potentially_foldable(&self) -> bool { - matches!( - self, - Self::Constant(_) | Self::PassThrough(_) | Self::PureComputation - ) - } -} - -/// Purity classification of a method. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum MethodPurity { - /// Method has no side effects - safe to inline, eliminate, or reorder. - Pure, - - /// Method only reads fields but doesn't modify state. - ReadOnly, - - /// Method modifies local state only (fields of `this` object). - LocalMutation, - - /// Method has global side effects (I/O, static fields, exceptions, etc.). - Impure, - - /// Purity is unknown (calls external methods, uses reflection, etc.). - Unknown, -} - -impl MethodPurity { - /// Checks if the method can be safely eliminated if its result is unused. - /// - /// # Returns - /// - /// `true` if the method has no observable side effects. - #[must_use] - pub fn can_eliminate_if_unused(&self) -> bool { - matches!(self, Self::Pure | Self::ReadOnly) - } - - /// Checks if the method can be safely inlined. - /// - /// Pure and ReadOnly methods can always be inlined. LocalMutation can - /// be inlined but requires care with the `this` reference. - /// - /// # Returns - /// - /// `true` if the method is safe to inline. - #[must_use] - pub fn can_inline(&self) -> bool { - // Pure and ReadOnly methods can always be inlined - // LocalMutation can be inlined but requires care with `this` - matches!(self, Self::Pure | Self::ReadOnly | Self::LocalMutation) - } - - /// Checks if calls to this method can be safely reordered. - /// - /// # Returns - /// - /// `true` if calls to this method can be reordered with respect to other calls. - #[must_use] - pub fn can_reorder(&self) -> bool { - matches!(self, Self::Pure) - } -} - -impl SsaFunction { - /// Returns an iterator over argument variables (version 0). - /// - /// These are the initial SSA versions of arguments at method entry. - /// Uses the version registry for O(1) lookup per argument. - /// - /// # Returns - /// - /// An iterator over argument variables with version 0. - pub fn argument_variables(&self) -> impl Iterator { - (0..self.num_args() as u16).filter_map(|idx| { - let origin = VariableOrigin::Argument(idx); - self.versions_of(origin) - .first() - .and_then(|&id| self.variable(id)) - .filter(|v| v.version() == 0) - }) - } - - /// Returns an iterator over local variables (version 0). - /// - /// These are the initial SSA versions of locals at method entry. - /// Uses the version registry for O(1) lookup per local. - /// - /// # Returns - /// - /// An iterator over local variables with version 0. - pub fn local_variables(&self) -> impl Iterator { - (0..self.num_locals() as u16).filter_map(|idx| { - let origin = VariableOrigin::Local(idx); - self.versions_of(origin) - .first() - .and_then(|&id| self.variable(id)) - .filter(|v| v.version() == 0) - }) - } - - /// Finds all variables originating from a specific argument. - /// - /// Uses the version registry for O(1) lookup. - /// - /// # Arguments - /// - /// * `arg_index` - The argument index to filter by - /// - /// # Returns - /// - /// An iterator over all SSA versions of the specified argument. - pub fn variables_from_argument(&self, arg_index: u16) -> impl Iterator { - let origin = VariableOrigin::Argument(arg_index); - self.versions_of(origin) - .iter() - .filter_map(|&id| self.variable(id)) - } - - /// Finds all variables originating from a specific local. - /// - /// Uses the version registry for O(1) lookup. - /// - /// # Arguments - /// - /// * `local_index` - The local variable index to filter by - /// - /// # Returns - /// - /// An iterator over all SSA versions of the specified local variable. - pub fn variables_from_local(&self, local_index: u16) -> impl Iterator { - let origin = VariableOrigin::Local(local_index); - self.versions_of(origin) - .iter() - .filter_map(|&id| self.variable(id)) - } - - /// Returns the total number of phi nodes across all blocks. - /// - /// # Returns - /// - /// The sum of phi node counts in all blocks. - pub fn phi_count(&self) -> usize { - self.blocks().iter().map(SsaBlock::phi_count).sum() - } - - /// Returns the total number of instructions across all blocks. - /// - /// # Returns - /// - /// The sum of instruction counts in all blocks. - pub fn instruction_count(&self) -> usize { - self.blocks().iter().map(SsaBlock::instruction_count).sum() - } - - /// Returns an iterator over all phi nodes in the function. - /// - /// # Returns - /// - /// An iterator yielding references to all [`PhiNode`]s across all blocks. - pub fn all_phi_nodes(&self) -> impl Iterator { - self.blocks().iter().flat_map(SsaBlock::phi_nodes) - } - - /// Returns an iterator over all instructions in the function. - /// - /// # Returns - /// - /// An iterator yielding references to all [`SsaInstruction`]s across all blocks. - pub fn all_instructions(&self) -> impl Iterator { - self.blocks().iter().flat_map(SsaBlock::instructions) - } - - /// Finds dead variables (variables with no uses). - /// - /// # Returns - /// - /// An iterator over variables that have no uses recorded. - pub fn dead_variables(&self) -> impl Iterator { - self.variables().iter().filter(|v| v.is_dead()) - } - - /// Counts dead variables. - /// - /// # Returns - /// - /// The number of variables with no uses. - #[must_use] - pub fn dead_variable_count(&self) -> usize { - self.variables().iter().filter(|v| v.is_dead()).count() - } - - /// Checks if a parameter at the given index is used in the function. - #[must_use] - #[allow(clippy::cast_possible_truncation)] - pub fn is_parameter_used(&self, param_index: usize) -> bool { - // Parameter indices > u16::MAX are not possible in practice - self.variables_from_argument(param_index as u16) - .any(|v| v.use_count() > 0) - } - - /// Returns the use count for a parameter. - #[must_use] - #[allow(clippy::cast_possible_truncation)] - pub fn parameter_use_count(&self, param_index: usize) -> usize { - // Parameter indices > u16::MAX are not possible in practice - self.variables_from_argument(param_index as u16) - .map(SsaVariable::use_count) - .sum() - } - - /// Checks if the function has any XOR operations. - #[must_use] - pub fn has_xor_operations(&self) -> bool { - self.all_instructions() - .any(|instr| matches!(instr.op(), SsaOp::Xor { .. })) - } - - /// Checks if the function has any array element access operations. - #[must_use] - pub fn has_array_element_access(&self) -> bool { - self.all_instructions().any(|instr| { - matches!( - instr.op(), - SsaOp::LoadElement { .. } | SsaOp::StoreElement { .. } - ) - }) - } - - /// Checks if the function has any field store operations. - #[must_use] - pub fn has_field_stores(&self) -> bool { - self.all_instructions().any(|instr| { - matches!( - instr.op(), - SsaOp::StoreField { .. } | SsaOp::StoreStaticField { .. } - ) - }) - } - - /// Checks if the function accesses any static fields. - #[must_use] - pub fn has_static_field_access(&self) -> bool { - self.all_instructions().any(|instr| { - matches!( - instr.op(), - SsaOp::LoadStaticField { .. } - | SsaOp::StoreStaticField { .. } - | SsaOp::LoadStaticFieldAddr { .. } - ) - }) - } - - /// Checks if the function has any field load operations. - #[must_use] - pub fn has_field_loads(&self) -> bool { - self.all_instructions().any(|instr| { - matches!( - instr.op(), - SsaOp::LoadField { .. } | SsaOp::LoadStaticField { .. } - ) - }) - } - - /// Returns the target count of the largest switch in the function, if any. - #[must_use] - pub fn largest_switch_target_count(&self) -> Option { - self.all_instructions() - .filter_map(|instr| { - if let SsaOp::Switch { targets, .. } = instr.op() { - Some(targets.len()) - } else { - None - } - }) - .max() - } - - /// Checks if the function returns void (no return value). - #[must_use] - pub fn is_void_return(&self) -> bool { - self.all_instructions() - .any(|instr| matches!(instr.op(), SsaOp::Return { value: None })) - } - - /// Gets the instruction operation that defines a variable. - /// - /// Searches through all blocks and instructions to find where the given - /// variable is defined (appears as a destination). - /// - /// **Note**: This only returns definitions from instructions, not phi nodes. - /// For phi node definitions, use [`find_phi_defining()`](Self::find_phi_defining). - /// - /// # Arguments - /// - /// * `var` - The SSA variable to look up. - /// - /// # Returns - /// - /// The defining `SsaOp` if found in an instruction, or `None` if the variable - /// is defined by a phi node or not found. - #[must_use] - pub fn get_definition(&self, var: SsaVarId) -> Option<&SsaOp> { - // Fast path: O(1) via the variable's DefSite - if let Some(variable) = self.variable(var) { - let def = variable.def_site(); - if let Some(instr_idx) = def.instruction { - if let Some(block) = self.block(def.block) { - if let Some(instr) = block.instructions().get(instr_idx) { - let op = instr.op(); - if op.dest() == Some(var) { - return Some(op); - } - } - } - } - } - - // Slow path: O(n) scan (DefSite may be stale after transforms or from builder) - for block in self.blocks() { - for instr in block.instructions() { - let op = instr.op(); - if op.dest() == Some(var) { - return Some(op); - } - } - } - None - } - - /// Gets the instruction that defines a variable. - /// - /// Like [`get_definition()`](Self::get_definition) but returns the full - /// `SsaInstruction` instead of just the `SsaOp`. This is needed by codegen - /// to access `instr.result_type()`. - /// - /// # Arguments - /// - /// * `var` - The SSA variable to look up. - /// - /// # Returns - /// - /// The defining `SsaInstruction` if found, or `None` if the variable - /// is defined by a phi node or not found. - #[must_use] - pub fn get_definition_instruction(&self, var: SsaVarId) -> Option<&SsaInstruction> { - // Fast path: O(1) via the variable's DefSite - if let Some(variable) = self.variable(var) { - let def = variable.def_site(); - if let Some(instr_idx) = def.instruction { - if let Some(block) = self.block(def.block) { - if let Some(instr) = block.instructions().get(instr_idx) { - let op = instr.op(); - if op.dest() == Some(var) { - return Some(instr); - } - } - } - } - } - - // Slow path: O(n) scan (DefSite may be stale after transforms or from builder) - for block in self.blocks() { - for instr in block.instructions() { - if instr.op().dest() == Some(var) { - return Some(instr); - } - } - } - None - } - - /// Checks whether replacing `result` with `source` in all uses would - /// create a self-referential instruction (i.e., `source = f(..., source, ...)`). - /// - /// This happens when `source` is defined by an instruction that uses `result`. - /// In such cases, eliminating a trivial phi `result = phi(source, result)` by - /// replacing `result → source` would create a self-referential cycle. - /// - /// # Arguments - /// - /// * `source` - The variable that would become the replacement. - /// * `result` - The variable being replaced (e.g., a trivial phi result). - /// - /// # Returns - /// - /// `true` if the replacement would create a self-referential instruction. - #[must_use] - pub fn would_create_self_reference(&self, source: SsaVarId, result: SsaVarId) -> bool { - self.get_definition(source) - .is_some_and(|op| op.uses().contains(&result)) - } - - /// Like [`would_create_self_reference`](Self::would_create_self_reference), but only - /// considers definitions in reachable blocks. Definitions in unreachable blocks will - /// be cleared by DCE, so they don't create real self-referential cycles. - /// - /// # Arguments - /// - /// * `source` - The variable that would become the replacement. - /// * `result` - The variable being replaced. - /// * `var_def_block` - Map from variable to the block that defines it. - /// * `reachable` - Set of reachable block indices. - #[must_use] - pub fn would_create_self_reference_reachable( - &self, - source: SsaVarId, - result: SsaVarId, - var_def_block: &BTreeMap, - reachable: &BitSet, - ) -> bool { - if let Some(&def_block) = var_def_block.get(&source) { - if reachable.contains(def_block) { - return self.would_create_self_reference(source, result); - } - } - false - } - - /// Checks if a variable is defined by a constant instruction. - /// - /// This is useful for analysis passes that need to identify compile-time - /// constant values vs. runtime-computed values. - /// - /// # Arguments - /// - /// * `var` - The SSA variable to check. - /// - /// # Returns - /// - /// `true` if the variable is defined by a `Const` instruction. - #[must_use] - pub fn is_var_constant(&self, var: SsaVarId) -> bool { - self.get_definition(var) - .is_some_and(|op| matches!(op, SsaOp::Const { .. })) - } - - /// Gets the constant value if a variable is defined by a constant instruction. - /// - /// # Arguments - /// - /// * `var` - The SSA variable to check. - /// - /// # Returns - /// - /// The constant value if the variable is defined by a `Const` instruction, - /// `None` otherwise. - #[must_use] - pub fn get_var_constant(&self, var: SsaVarId) -> Option<&ConstValue> { - match self.get_definition(var) { - Some(SsaOp::Const { value, .. }) => Some(value), - _ => None, - } - } - - /// Returns the constant value of a variable if it was defined by a `Const` operation. - /// - /// Uses the variable's [`DefSite`] for O(1) lookup without a fallback scan. - /// Returns `None` for phi-defined variables or non-constant definitions. - /// - /// # Arguments - /// - /// * `var` - The SSA variable to check. - /// - /// # Returns - /// - /// The constant value if the variable is defined by a `Const` instruction, - /// `None` otherwise. - #[must_use] - pub fn try_constant_value(&self, var: SsaVarId) -> Option { - let variable = self.variable(var)?; - let def_site = variable.def_site(); - - if def_site.is_phi() { - return None; - } - - let block = self.block(def_site.block)?; - let instr = block.instruction(def_site.instruction?)?; - - match instr.op() { - SsaOp::Const { value, .. } => Some(value.clone()), - _ => None, - } - } - - /// Finds the PHI node that defines a variable. - /// - /// Uses O(1) lookup via the variable's definition site when available, - /// falling back to O(n) scan across all blocks otherwise. - /// - /// # Arguments - /// - /// * `var` - The SSA variable ID to find the defining PHI for. - /// - /// # Returns - /// - /// `Some((block_idx, &PhiNode))` if the variable is defined by a PHI node, - /// `None` if the variable is not defined by a PHI or doesn't exist. - #[must_use] - pub fn find_phi_defining(&self, var: SsaVarId) -> Option<(usize, &PhiNode)> { - // Try O(1) lookup via the variable's definition site - if let Some(variable) = self.variable(var) { - let def_site = variable.def_site(); - if def_site.is_phi() { - // Variable is defined by a phi - look in that block - if let Some(block) = self.block(def_site.block) { - for phi in block.phi_nodes() { - if phi.result() == var { - return Some((def_site.block, phi)); - } - } - } - } - // Variable exists but is not defined by a phi - return None; - } - - // Fallback: O(n) scan if variable not defined by a phi in its block - for (block_idx, block) in self.iter_blocks() { - for phi in block.phi_nodes() { - if phi.result() == var { - return Some((block_idx, phi)); - } - } - } - - None - } - - /// Traces a variable backward through arithmetic operations to find a PHI source. - /// - /// This is useful for control flow unflattening where a switch variable may be - /// computed from a state PHI through operations like `(state ^ key) % N`. - /// - /// The tracing follows these operations backward: - /// - `Rem` (remainder): traces the left operand - /// - `Xor`: tries both operands (XOR is commutative) - /// - `And` (bitwise AND): traces the left operand - /// - `Shr`/`Shl` (shifts): traces the value operand - /// - `Copy`: traces the source - /// - /// # Arguments - /// - /// * `var` - The variable to trace backward from. - /// * `target_block` - Optional block where the PHI should be defined. - /// - /// # Returns - /// - /// The PHI variable that is the ultimate source, or `None` if no PHI is found. - #[must_use] - pub fn trace_to_phi(&self, var: SsaVarId, target_block: Option) -> Option { - self.trace_to_phi_impl(var, target_block, 0) - } - - /// Internal implementation with depth limit to prevent infinite recursion. - fn trace_to_phi_impl( - &self, - var: SsaVarId, - target_block: Option, - depth: usize, - ) -> Option { - // Prevent infinite recursion - const MAX_DEPTH: usize = 20; - if depth > MAX_DEPTH { - return None; - } - - // First check if this variable is directly defined by a phi node - if let Some((phi_block, phi)) = self.find_phi_defining(var) { - // If target_block specified, check if phi is in that block - if target_block.is_none_or(|target| phi_block == target) { - return Some(phi.result()); - } - // If not in target block, still return it as a valid PHI - return Some(phi.result()); - } - - // Get the definition of var - let def = self.get_definition(var)?; - - match def { - // If it's a phi node defined as instruction, use its dest - SsaOp::Phi { dest, .. } => Some(*dest), - - // Remainder (state % N) or bitwise AND (state & mask): trace left operand - SsaOp::Rem { left, .. } | SsaOp::And { left, .. } => { - self.trace_to_phi_impl(*left, target_block, depth.saturating_add(1)) - } - - // XOR operation (e.g., state ^ key): try both operands - SsaOp::Xor { left, right, .. } => { - // Try left first - if let Some(phi) = - self.trace_to_phi_impl(*left, target_block, depth.saturating_add(1)) - { - return Some(phi); - } - // Then try right (XOR is commutative) - self.trace_to_phi_impl(*right, target_block, depth.saturating_add(1)) - } - - // Arithmetic operations (ConfuserEx uses mul/add/sub for state transformation) - // e.g., new_state = (state * 529374418) ^ key - SsaOp::Mul { left, right, .. } - | SsaOp::Add { left, right, .. } - | SsaOp::Sub { left, right, .. } => { - // Try left first (usually where the state variable is) - if let Some(phi) = - self.trace_to_phi_impl(*left, target_block, depth.saturating_add(1)) - { - return Some(phi); - } - // Then try right - self.trace_to_phi_impl(*right, target_block, depth.saturating_add(1)) - } - - // Shift operations: trace the value operand - SsaOp::Shr { value, .. } | SsaOp::Shl { value, .. } => { - self.trace_to_phi_impl(*value, target_block, depth.saturating_add(1)) - } - - // Copy: trace through to source - SsaOp::Copy { src, .. } => { - self.trace_to_phi_impl(*src, target_block, depth.saturating_add(1)) - } - - // For other operations (including constants), the variable cannot be traced to a PHI - _ => None, - } - } - - /// Checks if a block has a specific successor in the control flow graph. - /// - /// This checks if control can flow from block `from_block` to block `to_block` - /// through any terminator instruction (Jump, Branch, Switch, etc.). - /// - /// # Arguments - /// - /// * `from_block` - The source block index. - /// * `to_block` - The target block index to check for. - /// - /// # Returns - /// - /// `true` if `to_block` is a successor of `from_block`. - #[must_use] - pub fn block_has_successor(&self, from_block: usize, to_block: usize) -> bool { - let Some(block) = self.block(from_block) else { - return false; - }; - let Some(op) = block.terminator_op() else { - return false; - }; - - op.successors().contains(&to_block) - } - - /// Gets all predecessor blocks that can jump to the given block. - /// - /// This scans all blocks and returns those whose terminator instruction - /// has `block_idx` as a successor. - /// - /// # Arguments - /// - /// * `block_idx` - The target block index. - /// - /// # Returns - /// - /// A vector of block indices that can transfer control to `block_idx`. - #[must_use] - pub fn block_predecessors(&self, block_idx: usize) -> Vec { - let mut preds: Vec = self - .iter_blocks() - .filter(|&(idx, _)| idx != block_idx) - .filter_map(|(idx, block)| { - block - .terminator_op() - .filter(|op| op.successors().contains(&block_idx)) - .map(|_| idx) - }) - .collect(); - - // Include synthetic exception handler edges: try_start -> handler_start. - // This matches SsaCfg::from_ssa() which also adds these edges so that - // handler blocks appear connected in the CFG. - for handler in self.exception_handlers() { - if handler.handler_start_block == Some(block_idx) { - if let Some(try_start) = handler.try_start_block { - if try_start < self.blocks.len() && !preds.contains(&try_start) { - preds.push(try_start); - } - } - } - } - - preds - } - - /// Gets all successor blocks that a given block can jump to. - /// - /// # Arguments - /// - /// * `block_idx` - The source block index. - /// - /// # Returns - /// - /// A vector of block indices that `block_idx` can transfer control to. - #[must_use] - pub fn block_successors(&self, block_idx: usize) -> Vec { - let Some(block) = self.block(block_idx) else { - return Vec::new(); - }; - let Some(op) = block.terminator_op() else { - return Vec::new(); - }; - - let mut succs = op.successors(); - - // Include synthetic exception handler edges: try_start -> handler_start. - // This matches SsaCfg::from_ssa() which also adds these edges so that - // handler blocks appear connected in the CFG. - for handler in self.exception_handlers() { - if handler.try_start_block == Some(block_idx) { - if let Some(handler_start) = handler.handler_start_block { - if handler_start < self.blocks.len() && !succs.contains(&handler_start) { - succs.push(handler_start); - } - } - } - } - - succs - } - - /// Checks if one block can reach another through the CFG. - /// - /// Uses a simple BFS to determine reachability. - /// - /// # Arguments - /// - /// * `from` - The source block index. - /// * `to` - The target block index. - /// * `successor_map` - Precomputed successor map for efficiency. - /// - /// # Returns - /// - /// `true` if there is a path from `from` to `to`, `false` otherwise. - fn block_reaches(from: usize, to: usize, successor_map: &BTreeMap>) -> bool { - if from == to { - return true; - } - - // Determine block count from successor map keys - let block_count = successor_map - .keys() - .copied() - .max() - .map_or(0, |m| m.saturating_add(1)); - let block_count = block_count - .max(from.saturating_add(1)) - .max(to.saturating_add(1)); - let mut visited = BitSet::new(block_count); - let mut worklist = vec![from]; - - while let Some(block_idx) = worklist.pop() { - if block_idx == to { - return true; - } - if block_idx >= block_count || !visited.insert(block_idx) { - continue; - } - if let Some(succs) = successor_map.get(&block_idx) { - worklist.extend(succs.iter().copied()); - } - } - - false - } - - /// Checks if a variable is a parameter variable. - /// - /// In SSA form, parameters are typically mapped to specific variable ranges - /// at the function entry. This method checks if the given variable ID - /// corresponds to a parameter. - /// - /// # Arguments - /// - /// * `var` - The SSA variable to check. - /// - /// # Returns - /// - /// The parameter index if this is a parameter variable, `None` otherwise. - #[must_use] - pub fn is_parameter_variable(&self, var: SsaVarId) -> Option { - let variable_count = self.var_id_capacity(); - self.is_parameter_variable_impl(var, &mut BitSet::new(variable_count), variable_count) - } - - /// Internal implementation with visited set to prevent infinite recursion on cycles. - fn is_parameter_variable_impl( - &self, - var: SsaVarId, - visited: &mut BitSet, - variable_count: usize, - ) -> Option { - // Prevent infinite recursion on cycles - let idx = var.index(); - if idx >= variable_count || !visited.insert(idx) { - return None; - } - - // Parameters are typically assigned at function entry to the first N variables - // where N is the parameter count. The exact mapping depends on SSA construction. - - // Check if this variable's definition is from a parameter load - // or if it's in the initial argument range - let idx = var.index(); - if idx < self.num_args() { - return Some(idx); - } - - // Also check if defined by argument loading - for block in self.blocks() { - for instr in block.instructions() { - let op = instr.op(); - if op.dest() == Some(var) { - // Check if this is loading from an argument - if let SsaOp::Const { .. } = op { - // Not a parameter - return None; - } - // Check for patterns like copy from parameter variable - if let SsaOp::Copy { src, .. } = op { - // Recursively check if source is a parameter - return self.is_parameter_variable_impl(*src, visited, variable_count); - } - } - } - } - - None - } - - /// Counts how many times each variable is used across all blocks. - /// - /// This scans all phi node operands and instruction operands to build - /// a map of variable use counts. This is useful for optimization passes - /// that need to know whether a variable has multiple uses (e.g., for - /// deciding whether to inline an expression). - /// - /// # Returns - /// - /// A map from each used variable ID to its use count. - /// - /// # Example - /// - /// ```ignore - /// let use_counts = ssa.count_uses(); - /// if use_counts.get(&var_id).copied().unwrap_or(0) == 1 { - /// // Variable has single use - safe to inline - /// } - /// ``` - #[must_use] - pub fn count_uses(&self) -> BTreeMap { - let mut counts = BTreeMap::new(); - - for block in self.blocks() { - // Count phi node operands - for phi in block.phi_nodes() { - for operand in phi.operands() { - let entry = counts.entry(operand.value()).or_insert(0_usize); - *entry = entry.saturating_add(1); - } - } - - // Count instruction operands - for instr in block.instructions() { - for var in instr.op().uses() { - let entry = counts.entry(var).or_insert(0_usize); - *entry = entry.saturating_add(1); - } - } - } - - counts - } - - /// Finds all trampoline blocks in this SSA function. - /// - /// A trampoline block is one that has no phi nodes and contains only a single - /// unconditional control transfer (`Jump` or `Leave`). These blocks can be - /// bypassed by redirecting predecessors directly to their targets. - /// - /// # Arguments - /// - /// * `skip_entry` - If true, skips block 0 (entry block). - /// - /// # Returns - /// - /// A map from trampoline block index to its target block index. - /// - /// # Example - /// - /// ```ignore - /// let trampolines = ssa.find_trampoline_blocks(true); - /// for (trampoline_idx, target_idx) in trampolines { - /// // trampoline_idx jumps unconditionally to target_idx - /// } - /// ``` - #[must_use] - pub fn find_trampoline_blocks(&self, skip_entry: bool) -> BTreeMap { - self.iter_blocks() - .filter(|&(block_idx, _)| !skip_entry || block_idx != 0) - .filter_map(|(block_idx, block)| { - block.is_trampoline().map(|target| (block_idx, target)) - }) - .collect() - } - - /// Finds all constant definitions in this SSA function. - /// - /// Scans all blocks for `Const` instructions and returns a mapping from - /// the destination variable to its constant value. - /// - /// # Returns - /// - /// A map from variable ID to its constant value. - /// - /// # Example - /// - /// ```ignore - /// let constants = ssa.find_constants(); - /// if let Some(value) = constants.get(&var_id) { - /// // var_id is defined as a constant with this value - /// } - /// ``` - #[must_use] - pub fn find_constants(&self) -> BTreeMap { - let mut constants = BTreeMap::new(); - - for block in self.blocks() { - for instr in block.instructions() { - if let SsaOp::Const { dest, value } = instr.op() { - constants.insert(*dest, value.clone()); - } - } - } - - constants - } - - /// Finds all blocks that use a given variable. - /// - /// Scans instructions and phi nodes across all blocks to find blocks - /// that reference the specified variable. - /// - /// # Arguments - /// - /// * `var` - The variable ID to search for. - /// * `exclude_block` - Optional block to exclude from results. - /// - /// # Returns - /// - /// A vector of block indices where the variable is used. - #[must_use] - pub fn find_var_user_blocks(&self, var: SsaVarId, exclude_block: Option) -> Vec { - self.iter_blocks() - .filter(|&(block_idx, _)| exclude_block != Some(block_idx)) - .filter(|(_, block)| { - // Check instructions - block.instructions().iter().any(|instr| instr.uses().contains(&var)) - // Check phi operands - || block.phi_nodes().iter().any(|phi| { - phi.operands().iter().any(|op| op.value() == var) - }) - }) - .map(|(block_idx, _)| block_idx) - .collect() - } - - /// Analyzes what this method returns. - /// - /// Examines all return instructions in the SSA function to determine: - /// - If returns a constant - /// - If returns null - /// - If returns "this" parameter - /// - If returns a parameter directly - /// - Otherwise Unknown - #[must_use] - pub fn return_info(&self) -> ReturnInfo { - let mut return_values: Vec> = Vec::new(); - - for block in self.blocks() { - for instr in block.instructions() { - if let SsaOp::Return { value } = instr.op() { - return_values.push(*value); - } - } - } - - // If no returns found, assume void - if return_values.is_empty() { - return ReturnInfo::Void; - } - - // Check if all returns are void (None) - if return_values.iter().all(Option::is_none) { - return ReturnInfo::Void; - } - - // If there's only one return with a value, trace what it is - let non_void_returns: Vec<_> = return_values.iter().filter_map(|v| *v).collect(); - - if non_void_returns.is_empty() { - return ReturnInfo::Void; - } - - // Try to determine what all returns have in common - // Check if they all return the same constant - let mut constants_found: Vec> = Vec::new(); - - for &ret_var in &non_void_returns { - // Find the definition of this variable - let def = self.get_definition(ret_var); - - match def { - Some(SsaOp::Const { value, .. }) => { - // Const includes null values (ConstValue::Null) - constants_found.push(Some(value.clone())); - } - _ => { - constants_found.push(None); - } - } - } - - // If all returns are the same constant - if constants_found.iter().all(Option::is_some) { - if let Some(first) = constants_found.first() { - if constants_found.iter().all(|c| c == first) { - if let Some(const_val) = first { - return ReturnInfo::Constant(const_val.clone()); - } - } - } - } - - // Check if returns a specific parameter (pass-through) - for &ret_var in &non_void_returns { - if let Some(param_idx) = self.is_parameter_variable(ret_var) { - if non_void_returns.len() == 1 { - return ReturnInfo::PassThrough(param_idx); - } - } - } - - // Check if all returns come from pure computations - let all_pure = non_void_returns.iter().all(|&var| { - if let Some(def) = self.get_definition(var) { - def.is_pure() - } else { - false - } - }); - - if all_pure { - return ReturnInfo::PureComputation; - } - - // Returns depend on state or have complex control flow - ReturnInfo::Dynamic - } - - /// Analyzes method purity (side effects). - /// - /// Examines the SSA function for various side effects: - /// - Field stores (instance or static) - /// - Indirect stores (via pointers) - /// - Array element stores - /// - Calls to potentially impure methods - /// - Exception throwing - /// - /// Returns: - /// - `Pure` if the method has no observable side effects - /// - `ReadOnly` if the method only reads state, no writes - /// - `Impure` if the method has definite side effects - /// - `Unknown` if purity cannot be determined - #[must_use] - pub fn purity(&self) -> MethodPurity { - let mut has_writes = false; - let mut has_reads = false; - let mut has_unknown_calls = false; - let mut has_indirect_access = false; - let mut has_throws = false; - - for block in self.blocks() { - for instr in block.instructions() { - match instr.op() { - // Definite writes - impure - SsaOp::StoreField { .. } - | SsaOp::StoreStaticField { .. } - | SsaOp::StoreElement { .. } - | SsaOp::StoreIndirect { .. } - | SsaOp::InitObj { .. } - | SsaOp::InitBlk { .. } - | SsaOp::CopyBlk { .. } => { - has_writes = true; - } - - // Reads from external state - SsaOp::LoadField { .. } - | SsaOp::LoadStaticField { .. } - | SsaOp::LoadElement { .. } - | SsaOp::LoadIndirect { .. } - | SsaOp::LoadObj { .. } => { - has_reads = true; - } - - // Address-of operations might lead to indirect access - SsaOp::LoadFieldAddr { .. } - | SsaOp::LoadStaticFieldAddr { .. } - | SsaOp::LoadElementAddr { .. } => { - has_indirect_access = true; - } - - // Calls need deeper analysis - assume unknown - SsaOp::Call { .. } - | SsaOp::CallVirt { .. } - | SsaOp::CallIndirect { .. } - | SsaOp::NewObj { .. } => { - has_unknown_calls = true; - } - - // Throws are a form of side effect (control flow) - SsaOp::Throw { .. } | SsaOp::Rethrow => { - has_throws = true; - } - - // Everything else is either pure or doesn't affect state - _ => {} - } - } - } - - // Determine purity level based on what we found - if has_writes { - return MethodPurity::Impure; - } - - if has_unknown_calls { - // Calls to unknown methods could be impure - return MethodPurity::Unknown; - } - - if has_throws { - // Throwing exceptions is a side effect (abnormal control flow) - return MethodPurity::Impure; - } - - if has_indirect_access { - // Address-of operations could enable writes we can't track - return MethodPurity::Unknown; - } - - if has_reads { - return MethodPurity::ReadOnly; - } - - MethodPurity::Pure - } -} diff --git a/dotscope/src/analysis/ssa/function/rebuild.rs b/dotscope/src/analysis/ssa/function/rebuild.rs deleted file mode 100644 index 7153d28b..00000000 --- a/dotscope/src/analysis/ssa/function/rebuild.rs +++ /dev/null @@ -1,2118 +0,0 @@ -//! SSA rebuild: reconstructs SSA form after CFG modifications. -//! -//! After passes like control flow unflattening modify the CFG, PHI nodes may -//! reference variables from removed blocks or have incorrect operands. This -//! module provides a structured rebuilder that performs a complete SSA -//! reconstruction using the standard Cytron et al. algorithm. -//! -//! The rebuild is split into named phases, each operating on explicit -//! intermediate state stored in `SsaRebuilder`. This makes the pipeline -//! individually testable and easier to debug. - -use std::collections::{BTreeMap, BTreeSet, HashSet}; - -use crate::{ - analysis::ssa::{ - liveness, - phis::place_pruned_phis, - verifier::{SsaVerifier, VerifierError, VerifyLevel}, - DefSite, PhiOperand, SsaBlock, SsaCfg, SsaFunction, SsaInstruction, SsaOp, SsaType, - SsaVarId, TrivialPhiOptions, VariableOrigin, - }, - utils::{ - graph::{ - algorithms::{compute_dominance_frontiers, compute_dominators, DominatorTree}, - NodeId, RootedGraph, - }, - BitSet, - }, - Error, Result, -}; - -/// Immutable context for SSA variable renaming. -/// -/// Bundles precomputed data structures needed during the rename phase of SSA -/// construction/rebuild. These are all immutable references that are passed -/// unchanged through recursive calls. -struct RenameContext<'a> { - /// Maps variable IDs to their origins (Argument, Local, Phi) - var_origins: &'a BTreeMap, - /// Maps group IDs to their SSA types (for preserving type information) - group_types: &'a BTreeMap, - /// Maps group IDs to their VariableOrigin (for creating variables) - group_origins: &'a BTreeMap, - /// Maps variable IDs to their types (per-variable, for stack-derived locals - /// where different variables at the same origin can have different types) - var_types: &'a BTreeMap, - /// CFG successor map for filling PHI operands - successor_map: &'a BTreeMap>, - /// Dominator tree children for recursive traversal - dom_children: &'a BTreeMap>, - /// Maps block index to ordered list of rename groups for its phi nodes. - /// Built from `place_pruned_phis` return values so rename can associate - /// each phi with its group (needed when multiple groups share `Phi` origin). - phi_groups: &'a BTreeMap>, - /// Number of method arguments (for group ID computation) - num_args: usize, -} - -/// Structured SSA rebuilder. -/// -/// Each phase of SSA reconstruction is a named method that reads from -/// and writes to explicit fields. This replaces the former 935-line -/// monolithic `rebuild_ssa()` function. -pub(crate) struct SsaRebuilder<'a> { - ssa: &'a mut SsaFunction, - - // Phase 1 output: variable origins and types - var_origins: BTreeMap, - /// Maps group ID to SSA type (for preserving type information) - group_types: BTreeMap, - /// Maps group ID to its VariableOrigin (for creating phi nodes) - group_origins: BTreeMap, - /// Per-variable types: preserves the exact type of each variable across rebuild. - /// This is needed because stack-derived locals at the same origin can have different - /// types at different definition points. - var_types: BTreeMap, - - // Phase 2 output: CFG analysis - reachable: BitSet, - dominance_frontiers: Vec, - successor_map: BTreeMap>, - dom_children: BTreeMap>, - - // Phase 3 output: definition sites (keyed by group ID) - defs: BTreeMap>, - - // Phase 3b output: liveness (keyed by group ID) - live_in: BTreeMap, - - // Phase 4 output: per-block phi group mapping - /// Maps block index to ordered list of rename groups for its phi nodes. - phi_groups: BTreeMap>, - - /// Next auto-incrementing group ID for orphans - next_group: u32, -} - -impl<'a> SsaRebuilder<'a> { - pub fn new(ssa: &'a mut SsaFunction) -> Self { - let next_group = (ssa.num_args as u32).saturating_add(ssa.num_locals as u32); - let block_count = ssa.blocks.len(); - Self { - ssa, - var_origins: BTreeMap::new(), - group_types: BTreeMap::new(), - group_origins: BTreeMap::new(), - var_types: BTreeMap::new(), - reachable: BitSet::new(block_count), - dominance_frontiers: Vec::new(), - successor_map: BTreeMap::new(), - dom_children: BTreeMap::new(), - defs: BTreeMap::new(), - live_in: BTreeMap::new(), - phi_groups: BTreeMap::new(), - next_group, - } - } - - /// Computes the set of reachable block indices via BFS from entry + exception handler roots. - fn compute_reachable_blocks(ssa: &SsaFunction, cfg: &SsaCfg<'_>) -> BitSet { - let block_count = ssa.blocks.len(); - let mut reachable = BitSet::new(block_count); - let mut worklist = vec![0usize]; - while let Some(block_idx) = worklist.pop() { - if reachable.insert(block_idx) { - for &succ in cfg.block_successors(block_idx) { - if succ < block_count { - worklist.push(succ); - } - } - } - } - - // Include exception handler entries as roots - for handler in &ssa.exception_handlers { - for block in [handler.handler_start_block, handler.filter_start_block] - .into_iter() - .flatten() - { - if block < block_count && !reachable.contains(block) { - worklist.push(block); - while let Some(b) = worklist.pop() { - if reachable.insert(b) { - for &succ in cfg.block_successors(b) { - if succ < block_count { - worklist.push(succ); - } - } - } - } - } - } - } - - reachable - } - - /// Runs the full SSA rebuild pipeline. - pub fn rebuild(&mut self) -> Result<()> { - // Stage 1: Pre-clean - self.pre_clean_unreachable(); // Phase 1 - self.recompute_groups_from_connectivity(); // Phase 2 - - // Stage 2: Type & origin collection - self.collect_origins(); // Phase 3 - self.propagate_types(); // Phase 4 - self.propagate_instruction_types(); // Phase 5 - self.assign_orphan_origins(); // Phase 6 - - // Stage 3: CFG analysis - self.compute_cfg(); // Phase 7 - self.collect_defs(); // Phase 8 - self.collect_uses_and_liveness(); // Phase 9 - - // Stage 4: Phi placement & rename - self.clear_all_phis(); // Phase 10 - self.place_phis(); // Phase 11 - self.rename(); // Phase 12 - - // Stage 5: Cleanup & compaction - self.eliminate_trivial_phis(); // Phase 13 - self.ssa.strip_nops(); // Phase 14 - self.ssa.compact_variables(); // Phase 15 - self.remove_orphan_pops(); // Phase 16 - self.ssa.reindex_variables(); // Phase 17 - // reindex can cause stale phi operand refs to collide with new IDs - self.eliminate_trivial_phis(); // Phase 18 - self.ssa.shrink_num_locals(); // Phase 19 - - // Verification - self.verify() - } - - /// Validates the rebuilt SSA, filtering to reachable blocks only. - /// - /// Unreachable blocks (e.g., dead CFF dispatcher remnants) may contain stale - /// variable references that weren't processed by the rename phase. - fn verify(&self) -> Result<()> { - let errors = SsaVerifier::new(self.ssa).verify(VerifyLevel::Standard); - let reachable_errors: Vec<&VerifierError> = errors - .iter() - .filter(|e| { - let block = match e { - VerifierError::UndefinedUse { block, .. } - | VerifierError::MissingPhiOperand { block, .. } - | VerifierError::ExtraPhiOperand { block, .. } - | VerifierError::MissingTerminator { block } - | VerifierError::PhiInEntryBlock { block, .. } - | VerifierError::TerminatorNotLast { block, .. } - | VerifierError::IntraBlockCycle { block, .. } - | VerifierError::PlaceholderVariable { block, .. } - | VerifierError::SelfReferentialInstruction { block, .. } => Some(*block), - VerifierError::DominanceViolation { use_block, .. } => Some(*use_block), - VerifierError::DuplicateDefinition { .. } - | VerifierError::UnregisteredVariable { .. } => None, - // OrphanVariable is cosmetic — the variable exists without a - // definition but isn't harmful. Typically caused by v0 entry - // variables from stack temp groups that survive compaction. - VerifierError::OrphanVariable { .. } => return false, - }; - // Keep errors for reachable blocks (or block-independent errors) - block.is_none_or(|b| self.reachable.contains(b)) - }) - .collect(); - - if !reachable_errors.is_empty() { - let msg = reachable_errors - .iter() - .map(|e| e.to_string()) - .collect::>() - .join("; "); - return Err(Error::SsaError(format!( - "SSA rebuild validation failed ({} blocks, {} vars): {}", - self.ssa.blocks().len(), - self.ssa.variables.len(), - msg - ))); - } - - Ok(()) - } - - /// Clones side-effect-free defs from unreachable blocks into the entry - /// block so their downstream uses survive block clearing. - /// - /// Called from `pre_clean_unreachable` before the clear step. Only - /// `Const`, `LoadToken`, and `LoadStaticField` are eligible — these - /// have no operands (or only compile-time-constant operands) and no - /// side effects, so relocating them to the entry is always safe. - fn rescue_orphaned_pure_defs(ssa: &mut SsaFunction, reachable: &BitSet) { - if ssa.blocks.is_empty() { - return; - } - - // Collect every var used from reachable code — instruction operands - // and phi operand values. - let mut reachable_uses: BTreeSet = BTreeSet::new(); - for bi in reachable.iter() { - let Some(block) = ssa.block(bi) else { - continue; - }; - for phi in block.phi_nodes() { - for op in phi.operands() { - reachable_uses.insert(op.value()); - } - } - for instr in block.instructions() { - for u in instr.op().uses() { - reachable_uses.insert(u); - } - } - } - - if reachable_uses.is_empty() { - return; - } - - // Collect rescuable instructions from unreachable blocks. - let mut to_rescue: Vec = Vec::new(); - for bi in 0..ssa.block_count() { - if reachable.contains(bi) { - continue; - } - let Some(block) = ssa.block(bi) else { - continue; - }; - for instr in block.instructions() { - let Some(dest) = instr.def() else { - continue; - }; - if !reachable_uses.contains(&dest) { - continue; - } - if !matches!( - instr.op(), - SsaOp::Const { .. } | SsaOp::LoadToken { .. } | SsaOp::LoadStaticField { .. } - ) { - continue; - } - to_rescue.push(instr.clone()); - } - } - - if to_rescue.is_empty() { - return; - } - - let Some(entry) = ssa.block_mut(0) else { - return; - }; - let term_idx = entry - .instructions() - .iter() - .position(|i| i.is_terminator()) - .unwrap_or(entry.instructions().len()); - for (i, instr) in to_rescue.into_iter().enumerate() { - let pos = term_idx.saturating_add(i); - entry.instructions_mut().insert(pos, instr); - } - } - - /// Removes unreachable blocks and simplifies stale phi operands. - /// - /// Must run BEFORE `recompute_groups_from_connectivity` so that stale phi - /// operands from unreachable predecessors don't create false disconnected - /// components in the union-find. Without this, passes like anti-debug - /// removal that make blocks unreachable (but don't clean up phis in - /// successor blocks) cause phi results to be split from their operands, - /// leading to orphan groups with no definitions after phi clearing. - /// - /// Before clearing, this rescues side-effect-free defs from unreachable - /// blocks whose result variable is still referenced from reachable code. - /// CFG-modifying passes (CFF unflattening, jump threading, dead branch - /// removal, block merging) can redirect control flow past a block that - /// held a loop-invariant `Const` / `LoadToken` / `LoadStaticField` - /// hoisted there by an earlier LICM pass. Wiping the block would orphan - /// the use site and leave codegen reading an uninitialized local slot. - /// Cloning the def into the entry block keeps it dominating every - /// reachable use. - fn pre_clean_unreachable(&mut self) { - let cfg = SsaCfg::from_ssa(self.ssa); - let reachable = Self::compute_reachable_blocks(self.ssa, &cfg); - - Self::rescue_orphaned_pure_defs(self.ssa, &reachable); - - // Clear unreachable blocks - for block_idx in 0..self.ssa.blocks.len() { - if !reachable.contains(block_idx) { - if let Some(b) = self.ssa.blocks.get_mut(block_idx) { - b.instructions_mut().clear(); - b.phi_nodes_mut().clear(); - } - } - } - - // Remove phi operands from unreachable predecessors and collect - // trivial phi replacements (phi with 0 or 1 unique operand) - let mut replacements: Vec<(SsaVarId, SsaVarId)> = Vec::new(); - for block_idx in 0..self.ssa.blocks.len() { - if !reachable.contains(block_idx) { - continue; - } - let Some(block) = self.ssa.blocks.get_mut(block_idx) else { - continue; - }; - - // Remove operands from unreachable predecessors - for phi in block.phi_nodes_mut().iter_mut() { - phi.retain_operands(|pred| reachable.contains(pred)); - } - - // Inline trivial phis (0 or 1 unique operand value) - block.phi_nodes_mut().retain(|phi| { - let operands = phi.operands(); - let Some(first_op) = operands.first() else { - return false; // Remove empty phi - }; - let first = first_op.value(); - // Check if all operands resolve to the same value - if operands - .iter() - .all(|op| op.value() == first || op.value() == phi.result()) - { - replacements.push((phi.result(), first)); - return false; - } - true - }); - } - - // Apply replacements: substitute phi result uses with the single operand - if !replacements.is_empty() { - let replacement_map: BTreeMap = replacements.into_iter().collect(); - for block in &mut self.ssa.blocks { - for instr in block.instructions_mut() { - for (&old_var, &new_var) in &replacement_map { - instr.op_mut().replace_uses(old_var, new_var); - } - } - for phi in block.phi_nodes_mut() { - for (&old_var, &new_var) in &replacement_map { - for op in phi.operands_mut() { - if op.value() == old_var { - *op = PhiOperand::new(new_var, op.predecessor()); - } - } - } - } - } - } - } - - /// Splits stale rename groups that contain disconnected components. - /// - /// After CFG modifications (e.g., CFF unflattening), stale rename groups may - /// cause unrelated variables to share the same group. For example, CFF stores - /// different values (Calculator instance, format string) in the same local slot - /// across different switch cases. After unflattening removes the dispatcher phi - /// that merged them, they should be in separate groups — but the original groups - /// from SSA construction still link them. - /// - /// This method conservatively splits only groups that have multiple disconnected - /// components (based on phi/copy/load connectivity). Groups that are already - /// a single connected component are left unchanged. This avoids the regression - /// of splitting groups that are legitimately connected through the dominator tree - /// but lack explicit phi/copy edges. - fn recompute_groups_from_connectivity(&mut self) { - let num_vars = self.ssa.variables.len(); - if num_vars == 0 { - return; - } - - // Union-find structure (parent array, initially each variable is its own root) - let mut parent: Vec = (0..num_vars).collect(); - let mut rank: Vec = vec![0; num_vars]; - - let find = |parent: &mut Vec, mut x: usize| -> usize { - // Bound the iterations to the union-find size to avoid infinite - // loops if the parent array is corrupted. - for _ in 0..parent.len() { - let Some(&p) = parent.get(x) else { - return x; - }; - if p == x { - return x; - } - // Path halving: parent[x] = parent[parent[x]]. - let pp = parent.get(p).copied().unwrap_or(p); - if let Some(slot) = parent.get_mut(x) { - *slot = pp; - } - x = pp; - } - x - }; - - let union = |parent: &mut Vec, rank: &mut Vec, a: usize, b: usize| { - let ra = find(parent, a); - let rb = find(parent, b); - if ra == rb { - return; - } - let rank_ra = rank.get(ra).copied().unwrap_or(0); - let rank_rb = rank.get(rb).copied().unwrap_or(0); - if rank_ra < rank_rb { - if let Some(slot) = parent.get_mut(ra) { - *slot = rb; - } - } else if rank_ra > rank_rb { - if let Some(slot) = parent.get_mut(rb) { - *slot = ra; - } - } else { - if let Some(slot) = parent.get_mut(rb) { - *slot = ra; - } - if let Some(slot) = rank.get_mut(ra) { - *slot = slot.saturating_add(1); - } - } - }; - - // Build a mapping from SsaVarId to index in the variables array - let mut var_to_idx: BTreeMap = BTreeMap::new(); - for (idx, var) in self.ssa.variables.iter().enumerate() { - var_to_idx.insert(var.id(), idx); - } - - // Union phi operands with their phi result to maintain group connectivity. - // - // We only process phis in REACHABLE blocks — unreachable blocks were - // cleaned in pre_clean_unreachable (Phase 1), so any remaining phis - // are genuine. This is critical: block-merging's trampoline elimination - // updates phi operands to reference new predecessors, but the new - // operand variables may be in different rename groups than the phi - // result. Without unconditional union here, the group splits, causing - // phi placement to skip the entry-only group → switch phis collapse - // → CFF dispatchers are incorrectly constant-folded. - // - // The original same-group restriction was added to avoid false - // connectivity from stale phis. Phase 1's unreachable block cleanup - // eliminates stale phis, making the restriction unnecessary for - // reachable blocks. - let cfg_for_reach = SsaCfg::from_ssa(self.ssa); - let reachable_here = Self::compute_reachable_blocks(self.ssa, &cfg_for_reach); - for block in &self.ssa.blocks { - let block_idx = block.id(); - if !reachable_here.contains(block_idx) { - continue; - } - for phi in block.phi_nodes() { - let phi_result = phi.result(); - if self.ssa.rename_group(phi_result) == u32::MAX { - continue; - } - if let Some(&result_idx) = var_to_idx.get(&phi_result) { - for operand in phi.operands() { - if let Some(&operand_idx) = var_to_idx.get(&operand.value()) { - union(&mut parent, &mut rank, result_idx, operand_idx); - } - } - } - } - } - - // Union copy sources with their destinations - for block in &self.ssa.blocks { - for instr in block.instructions() { - if let SsaOp::Copy { dest, src } = instr.op() { - if let (Some(&dest_idx), Some(&src_idx)) = - (var_to_idx.get(dest), var_to_idx.get(src)) - { - union(&mut parent, &mut rank, dest_idx, src_idx); - } - } - } - } - - // Union LoadLocal/LoadArg destinations with their respective arg/local - // group representatives, so loads from the same slot stay connected. - let num_args = self.ssa.num_args; - let mut arg_local_reps: BTreeMap = BTreeMap::new(); - for (idx, var) in self.ssa.variables.iter().enumerate() { - match var.origin() { - VariableOrigin::Argument(ai) => { - let group = ai as u32; - arg_local_reps.entry(group).or_insert(idx); - } - VariableOrigin::Local(li) => { - let group = (num_args as u32).saturating_add(li as u32); - arg_local_reps.entry(group).or_insert(idx); - } - _ => {} - } - } - for block in &self.ssa.blocks { - for instr in block.instructions() { - match instr.op() { - SsaOp::LoadLocal { dest, local_index } => { - let group = (num_args as u32).saturating_add(*local_index as u32); - if let (Some(&dest_idx), Some(&rep_idx)) = - (var_to_idx.get(dest), arg_local_reps.get(&group)) - { - union(&mut parent, &mut rank, dest_idx, rep_idx); - } - } - SsaOp::LoadArg { dest, arg_index } => { - let group = *arg_index as u32; - if let (Some(&dest_idx), Some(&rep_idx)) = - (var_to_idx.get(dest), arg_local_reps.get(&group)) - { - union(&mut parent, &mut rank, dest_idx, rep_idx); - } - } - _ => {} - } - } - } - - // Collect variables by their CURRENT rename group - let mut group_members: BTreeMap> = BTreeMap::new(); - for (idx, var) in self.ssa.variables.iter().enumerate() { - let group = self.ssa.rename_group(var.id()); - if group != u32::MAX { - group_members.entry(group).or_default().push(idx); - } - } - - // For each existing group, check if it has multiple disconnected components. - // Only split groups that actually have disconnected components. - let max_existing = self.ssa.rename_groups.iter().copied().max().unwrap_or(0); - let mut next_new_group = if max_existing == u32::MAX { - (num_args as u32).saturating_add(self.ssa.num_locals as u32) - } else { - max_existing.saturating_add(1) - }; - - let mut updates: Vec<(SsaVarId, u32)> = Vec::new(); - - let real_local_limit = (num_args as u32).saturating_add(self.ssa.num_locals as u32); - - for (&original_group, members) in &group_members { - if members.len() <= 1 { - continue; // Single-variable groups can't have disconnected components - } - - // Find the distinct connected components within this group - let mut component_roots: BTreeMap> = BTreeMap::new(); - for &idx in members { - let root = find(&mut parent, idx); - component_roots.entry(root).or_default().push(idx); - } - - if component_roots.len() <= 1 { - continue; // Single component — group is fine as-is - } - - // Multiple components detected — split this group. - // Keep the original group ID for the component that contains a variable - // with Argument/Local origin (the canonical component). Assign new group - // IDs to the other components. - let mut canonical_root: Option = None; - for (&root, component_members) in &component_roots { - for &idx in component_members { - let Some(var) = self.ssa.variables.get(idx) else { - continue; - }; - match var.origin() { - VariableOrigin::Argument(_) | VariableOrigin::Local(_) => { - canonical_root = Some(root); - break; - } - _ => {} - } - } - if canonical_root.is_some() { - break; - } - } - - // If no canonical root found, pick the largest component. - // Break ties deterministically by the smallest variable index - // within each component to avoid nondeterministic grouping. - let canonical_root = match canonical_root { - Some(r) => r, - None => match component_roots - .iter() - .max_by(|(_, members_a), (_, members_b)| { - members_a.len().cmp(&members_b.len()).then_with(|| { - let min_a = members_a.iter().copied().min().unwrap_or(usize::MAX); - let min_b = members_b.iter().copied().min().unwrap_or(usize::MAX); - min_b.cmp(&min_a) - }) - }) - .map(|(root, _)| *root) - { - Some(r) => r, - // No components — group is empty, nothing to split. - None => continue, - }, - }; - - // Decide which components to keep vs. split. - // - // Two tiers: - // - // 1. When the canonical component has Argument/Local-origin - // variables, keep only components that share an (origin, type) - // pair. This handles the CFF case where different if/else - // branches assign to the same local with the same type, while - // still splitting when CFF reuses a local slot for different - // types (e.g., Calculator Object vs format String). - // - // 2. When ALL variables in the group are Phi-origin (common after - // a previous rebuild) AND the group represents a real CIL - // local/argument (group ID < num_args + num_locals), fall back - // to type-only comparison: same-type components stay together, - // different-type components are split. Stack temp groups - // (group >= num_args + num_locals) always split. - let canonical_origin_types: HashSet<(VariableOrigin, SsaType)> = component_roots - .get(&canonical_root) - .into_iter() - .flat_map(|members| members.iter()) - .filter_map(|&idx| { - let var = self.ssa.variables.get(idx)?; - match var.origin() { - VariableOrigin::Argument(_) | VariableOrigin::Local(_) => { - Some((var.origin(), var.var_type().clone())) - } - _ => None, - } - }) - .collect(); - - let has_canonical_origin = !canonical_origin_types.is_empty(); - - // Precompute the canonical component's type for type-based - // split decisions (Tier 2 and stack temp groups). - let canonical_type: Option = if !has_canonical_origin { - component_roots - .get(&canonical_root) - .into_iter() - .flat_map(|members| members.iter()) - .filter_map(|&idx| { - let t = self.ssa.variables.get(idx)?.var_type(); - if t.is_unknown() { - None - } else { - Some(t.clone()) - } - }) - .next() - } else { - None - }; - - for (&root, component_members) in &component_roots { - if root == canonical_root { - continue; // Keep original group ID - } - - let keep = if has_canonical_origin { - // Tier 1: origin + type matching - component_members.iter().any(|&idx| { - let Some(var) = self.ssa.variables.get(idx) else { - return false; - }; - match var.origin() { - VariableOrigin::Argument(_) | VariableOrigin::Local(_) => { - canonical_origin_types - .contains(&(var.origin(), var.var_type().clone())) - } - _ => false, - } - }) - } else if original_group < real_local_limit { - // Tier 2: type-only matching for real local/argument groups - // where all variables are Phi-origin (from a previous - // rebuild's rename phase). - let comp_type: Option = component_members - .iter() - .filter_map(|&idx| { - let t = self.ssa.variables.get(idx)?.var_type(); - if t.is_unknown() { - None - } else { - Some(t.clone()) - } - }) - .next(); - - match (&comp_type, &canonical_type) { - (Some(ct), Some(cano_t)) => ct == cano_t, - _ => true, // Unknown types: keep together (conservative) - } - } else { - false // Stack temp group — always split to avoid aliasing - }; - - if keep { - continue; // Same origin+type or same type — keep in canonical group - } - - let new_group = next_new_group; - next_new_group = next_new_group.saturating_add(1); - for &idx in component_members { - let Some(var) = self.ssa.variables.get(idx) else { - continue; - }; - updates.push((var.id(), new_group)); - } - } - } - - for (var_id, new_group) in updates { - self.ssa.set_rename_group(var_id, new_group); - } - - // Split groups that have multiple instruction-level definitions in the - // same block. After block merging or constant hoisting, a single block - // can end up with several Const/Copy definitions in the same rename - // group. The rename phase processes them sequentially and only the LAST - // definition is visible to successors (last-writer-wins). If the first - // definition is the semantically correct one (e.g., the CFF initial - // state), it gets shadowed by a later hoisted constant. - // - // Fix: for each block, keep the FIRST definition of each group and - // assign new groups to subsequent definitions. This makes each hoisted - // constant independent, preventing it from shadowing the original. - let max_group_so_far = self - .ssa - .rename_groups - .iter() - .copied() - .filter(|&g| g != u32::MAX) - .max() - .unwrap_or(next_new_group); - let mut next_split_group = max_group_so_far.saturating_add(1); - let mut split_updates: Vec<(SsaVarId, u32)> = Vec::new(); - - for block in &self.ssa.blocks { - // Track which groups already have a definition in this block. - // Key: group ID, Value: true if first def already seen. - let mut seen_groups: BTreeMap = BTreeMap::new(); - - for instr in block.instructions() { - let Some(dest) = instr.op().dest() else { - continue; - }; - let group = self.ssa.rename_group(dest); - if group == u32::MAX { - continue; - } - if let Some(already_seen) = seen_groups.get_mut(&group) { - if *already_seen { - // Third+ definition — also split - split_updates.push((dest, next_split_group)); - next_split_group = next_split_group.saturating_add(1); - } else { - // Second definition — split this one - split_updates.push((dest, next_split_group)); - next_split_group = next_split_group.saturating_add(1); - *already_seen = true; - } - } else { - // First definition — keep in original group - seen_groups.insert(group, false); - } - } - } - - for (var_id, new_group) in split_updates { - self.ssa.set_rename_group(var_id, new_group); - } - } - - /// Builds var_id → origin map, group → type map, and var_id → type map. - fn collect_origins(&mut self) { - self.var_origins = self - .ssa - .variables - .iter() - .map(|v| (v.id(), v.origin())) - .collect(); - - // Build group_origins from existing rename_groups. - // Prefer Local/Argument origins over Phi — after CFF reconstruction - // clears dispatcher phis, the union-find may group Local-origin and - // Phi-origin variables together. The group must retain the Local origin - // so the codegen allocates a CIL local slot instead of a stack temporary. - for var in &self.ssa.variables { - let group = self.ssa.rename_group(var.id()); - if group != u32::MAX { - self.group_origins - .entry(group) - .and_modify(|existing| { - if matches!(existing, VariableOrigin::Phi) - && matches!( - var.origin(), - VariableOrigin::Local(_) | VariableOrigin::Argument(_) - ) - { - *existing = var.origin(); - } - }) - .or_insert(var.origin()); - } - } - - // Also register arg/local groups - for i in 0..self.ssa.num_args { - let group = i as u32; - self.group_origins - .entry(group) - .or_insert(VariableOrigin::Argument(i as u16)); - } - for i in 0..self.ssa.num_locals { - let group = (self.ssa.num_args as u32).saturating_add(i as u32); - self.group_origins - .entry(group) - .or_insert(VariableOrigin::Local(i as u16)); - } - - // Track the best type for each group (prefer non-unknown types) - // and per-variable types - for var in &self.ssa.variables { - let var_type = var.var_type(); - if !var_type.is_unknown() { - let group = self.ssa.rename_group(var.id()); - if group != u32::MAX { - self.group_types.entry(group).or_insert(var_type.clone()); - } - self.var_types.insert(var.id(), var_type.clone()); - } - } - - // Update next_group to be above all existing groups - let max_existing = self.ssa.rename_groups.iter().copied().max().unwrap_or(0); - if max_existing != u32::MAX { - self.next_group = self.next_group.max(max_existing.saturating_add(1)); - } - } - - /// Propagates types from LoadLocal/LoadArg to their dest's group. - fn propagate_types(&mut self) { - for block in &self.ssa.blocks { - for instr in block.instructions() { - match instr.op() { - SsaOp::LoadLocal { dest, local_index } => { - let dest_group = self.ssa.rename_group(*dest); - if dest_group != u32::MAX && !self.group_types.contains_key(&dest_group) { - let local_group = - (self.ssa.num_args as u32).saturating_add(*local_index as u32); - if let Some(local_type) = self.group_types.get(&local_group).cloned() { - self.group_types.insert(dest_group, local_type); - } - } - } - SsaOp::LoadArg { dest, arg_index } => { - let dest_group = self.ssa.rename_group(*dest); - if dest_group != u32::MAX && !self.group_types.contains_key(&dest_group) { - let arg_group = *arg_index as u32; - if let Some(arg_type) = self.group_types.get(&arg_group).cloned() { - self.group_types.insert(dest_group, arg_type); - } - } - } - _ => {} - } - } - } - } - - /// Infers types for instructions whose destination variable has no type in `group_types`. - /// - /// Runs after `propagate_types()` and before `assign_orphan_origins()`. For each - /// instruction with a destination variable, if the variable's origin has no entry - /// in `group_types`, checks the instruction's `result_type` first (set during SSA - /// construction with full TypeContext), then falls back to `SsaOp::infer_result_type()`. - fn propagate_instruction_types(&mut self) { - for block in &self.ssa.blocks { - for instr in block.instructions() { - if let Some(dest) = instr.op().dest() { - let group = self.ssa.rename_group(dest); - if group != u32::MAX && !self.group_types.contains_key(&group) { - // Priority: instruction result_type (from converter with TypeContext) - // > op structural inference - if let Some(rt) = instr.result_type() { - if !rt.is_unknown() { - self.group_types.insert(group, rt.clone()); - continue; - } - } - if let Some(inferred) = instr.op().infer_result_type() { - self.group_types.insert(group, inferred); - } - } - } - } - } - } - - /// Assigns origins and groups to orphan variables not in self.variables. - /// - /// Orphan variables are created by passes. They need origins and groups so they - /// can be renamed. PHI origins/groups are propagated first, then remaining orphans - /// get `Phi` origin with unique group IDs. - fn assign_orphan_origins(&mut self) { - // First pass: propagate PHI origins/groups to ALL operands. - // This ensures that phi operands use the same origin as the phi during rename, - // so they end up on the same version stack and properly fill phi operands. - // Collect group propagations first to avoid borrow conflicts. - let mut group_propagations: Vec<(SsaVarId, u32)> = Vec::new(); - for block in &self.ssa.blocks { - for phi in block.phi_nodes() { - let phi_origin = self - .var_origins - .get(&phi.result()) - .copied() - .unwrap_or_else(|| phi.origin()); - - let phi_group = self.ssa.rename_group(phi.result()); - - // Assign the PHI's origin to its result if orphan - self.var_origins.entry(phi.result()).or_insert(phi_origin); - - // Assign the phi's origin and group to ORPHAN operands only. - // IMPORTANT: Do NOT overwrite existing origins for non-orphan variables. - for operand in phi.operands() { - let op_var = operand.value(); - self.var_origins.entry(op_var).or_insert(phi_origin); - // Propagate the phi's group to orphan operands - if phi_group != u32::MAX && self.ssa.rename_group(op_var) == u32::MAX { - group_propagations.push((op_var, phi_group)); - } - } - } - } - for (var_id, group) in group_propagations { - self.ssa.set_rename_group(var_id, group); - } - - // Second pass: assign Phi origin and unique group IDs to remaining orphan - // variables. Orphans must participate in rename (with version stacks) so that - // their defs and uses get proper new IDs. Each orphan variable gets its own - // group, mirroring the old behavior where orphans got unique Local(next_idx) - // origins. - // Collect orphan var IDs first to avoid borrow conflicts. - let mut orphan_vars: Vec<(SsaVarId, Option)> = Vec::new(); - for block in &self.ssa.blocks { - for instr in block.instructions() { - for use_var in instr.uses().iter().copied() { - if !self.var_origins.contains_key(&use_var) { - orphan_vars.push((use_var, None)); - } - } - if let Some(dest) = instr.def() { - if !self.var_origins.contains_key(&dest) { - // Prefer instruction result_type (from converter), fall back - // to structural inference - let inferred_type = instr - .result_type() - .filter(|rt| !rt.is_unknown()) - .cloned() - .or_else(|| instr.op().infer_result_type()); - orphan_vars.push((dest, inferred_type)); - } - } - } - } - - for (var_id, inferred_type) in orphan_vars { - self.var_origins - .entry(var_id) - .or_insert(VariableOrigin::Phi); - if self.ssa.rename_group(var_id) == u32::MAX { - let group = self.next_group; - self.next_group = self.next_group.saturating_add(1); - self.ssa.set_rename_group(var_id, group); - self.group_origins.insert(group, VariableOrigin::Phi); - if let Some(inferred) = inferred_type { - self.group_types.entry(group).or_insert(inferred); - } - } - } - - // No num_locals inflation — orphans use Phi origin - } - - /// Computes reachability, dominators, dominance frontiers, and successor/children maps. - fn compute_cfg(&mut self) { - // First pass: compute reachability from the raw SSA - { - let cfg = SsaCfg::from_ssa(self.ssa); - self.reachable = Self::compute_reachable_blocks(self.ssa, &cfg); - } - - // Clear unreachable blocks to remove phantom CFG edges. - // Without this, unreachable blocks (e.g., dead crash code after anti-debug - // removal) keep outgoing edges that pollute the predecessor graph. This - // causes incorrect dominator computation and stale variable references - // during rename. - for block_idx in 0..self.ssa.blocks.len() { - if !self.reachable.contains(block_idx) { - if let Some(b) = self.ssa.blocks.get_mut(block_idx) { - b.instructions_mut().clear(); - b.phi_nodes_mut().clear(); - } - } - } - - // Second pass: rebuild CFG from cleaned-up SSA (no phantom edges). - // The cfg borrow of self.ssa must end before merge_handler_dom_trees - // borrows &mut self, so we scope it in a block. - let (dom_tree, entry_node) = { - let cfg = SsaCfg::from_ssa(self.ssa); - let dom_tree = compute_dominators(&cfg, cfg.entry()); - self.dominance_frontiers = compute_dominance_frontiers(&cfg, &dom_tree); - - // Extract successor map (only for reachable blocks) - for i in self.reachable.iter() { - self.successor_map - .insert(i, cfg.block_successors(i).to_vec()); - } - - // Extract dominator tree children (only for reachable blocks) - let entry_node = cfg.entry(); - for i in self.reachable.iter() { - self.dom_children.insert( - i, - dom_tree - .children(NodeId::new(i)) - .iter() - .filter(|n| { - n.index() < self.ssa.blocks.len() && self.reachable.contains(n.index()) - }) - .map(|n| n.index()) - .collect(), - ); - } - - (dom_tree, entry_node) - }; - - self.merge_handler_dom_trees(&dom_tree, entry_node); - } - - /// Merges local dominator trees for exception handler roots into the main structures. - /// - /// Handler/filter entries not dominated by the entry block are reachable via - /// exception flow but not via the normal dominator tree. This computes local - /// dom trees for each handler root and merges their `dom_children` and - /// `dominance_frontiers` into the main structures so that rename covers all blocks. - fn merge_handler_dom_trees(&mut self, dom_tree: &DominatorTree, entry: NodeId) { - let cfg = SsaCfg::from_ssa(self.ssa); - // IMPORTANT: The handler BFS must NOT cross into blocks already in the - // main dom tree. If it did, the handler's local dom tree could create - // parent→child relationships that conflict with the main tree, introducing - // cycles in dom_children and causing rename_block_recursive to loop forever. - let block_count = self.ssa.blocks.len(); - let mut main_dom_blocks = BitSet::new(block_count); - for b in self.reachable.iter() { - if dom_tree.dominates(entry, NodeId::new(b)) { - main_dom_blocks.insert(b); - } - } - - let mut handler_roots: Vec = Vec::new(); - for handler in &self.ssa.exception_handlers { - for block in [handler.handler_start_block, handler.filter_start_block] - .into_iter() - .flatten() - { - if block < block_count - && self.reachable.contains(block) - && !main_dom_blocks.contains(block) - { - handler_roots.push(block); - } - } - } - - for &root in &handler_roots { - let local_dom = compute_dominators(&cfg, NodeId::new(root)); - let local_df = compute_dominance_frontiers(&cfg, &local_dom); - - // Collect handler-reachable blocks via BFS from root, stopping at - // blocks already in the main dom tree to prevent cycle creation. - let mut handler_reachable = BitSet::new(block_count); - let mut wl = vec![root]; - while let Some(b) = wl.pop() { - if handler_reachable.insert(b) { - for &succ in cfg.block_successors(b) { - if succ < block_count - && self.reachable.contains(succ) - && !main_dom_blocks.contains(succ) - { - wl.push(succ); - } - } - } - } - - // Merge dom_children (only for handler-reachable blocks) - for b in handler_reachable.iter() { - let children: Vec = local_dom - .children(NodeId::new(b)) - .iter() - .filter(|n| n.index() < block_count && handler_reachable.contains(n.index())) - .map(|n| n.index()) - .collect(); - if !children.is_empty() { - self.dom_children.entry(b).or_default().extend(children); - } - } - - // Merge dominance frontiers - for b in handler_reachable.iter() { - let Some(local_b) = local_df.get(b) else { - continue; - }; - if b >= self.dominance_frontiers.len() { - let new_len = b.checked_add(1).unwrap_or(self.dominance_frontiers.len()); - self.dominance_frontiers - .resize(new_len, BitSet::new(block_count)); - } - if let Some(slot) = self.dominance_frontiers.get_mut(b) { - slot.union_with(local_b); - } - } - } - } - - /// Collects definition sites from reachable blocks (before clearing PHIs). - fn collect_defs(&mut self) { - // Arguments are always defined at entry (block 0) — they have values from the caller. - for i in 0..self.ssa.num_args { - let group = i as u32; - self.defs.entry(group).or_default().insert(0); - } - // Only ORIGINAL .NET locals have default-initialization at entry. - // `num_locals == original_num_locals` always now (no inflation). - for i in 0..self.ssa.num_locals { - let group = (self.ssa.num_args as u32).saturating_add(i as u32); - self.defs.entry(group).or_default().insert(0); - } - - // Collect defs from instructions using group IDs. - for block in &self.ssa.blocks { - let block_idx = block.id(); - if !self.reachable.contains(block_idx) { - continue; - } - for instr in block.instructions() { - if let Some(dest) = instr.def() { - let group = self.ssa.rename_group(dest); - if group != u32::MAX { - self.defs.entry(group).or_default().insert(block_idx); - } - } - } - } - } - - /// Collects use sites and computes liveness for pruned SSA phi placement. - fn collect_uses_and_liveness(&mut self) { - let block_count = self.ssa.blocks.len(); - let variable_count = self.ssa.var_id_capacity(); - - // Pre-compute which variables are consumed by non-Nop instructions. - let mut consumed_vars = BitSet::new(variable_count); - for block in &self.ssa.blocks { - if !self.reachable.contains(block.id()) { - continue; - } - for instr in block.instructions() { - if !matches!(instr.op(), SsaOp::Nop) { - for &use_var in instr.uses().iter() { - consumed_vars.insert(use_var.index()); - } - } - } - } - - let mut use_sites: BTreeMap = BTreeMap::new(); - for block in &self.ssa.blocks { - let block_idx = block.id(); - if !self.reachable.contains(block_idx) { - continue; - } - for instr in block.instructions() { - for use_var in instr.uses().iter().copied() { - let group = self.ssa.rename_group(use_var); - if group != u32::MAX { - use_sites - .entry(group) - .or_insert_with(|| BitSet::new(block_count)) - .insert(block_idx); - } - } - // Track implicit uses from LoadLocal/LoadArg - match instr.op() { - SsaOp::LoadLocal { dest, local_index } - if consumed_vars.contains(dest.index()) => - { - let group = (self.ssa.num_args as u32).saturating_add(*local_index as u32); - use_sites - .entry(group) - .or_insert_with(|| BitSet::new(block_count)) - .insert(block_idx); - } - SsaOp::LoadArg { dest, arg_index } if consumed_vars.contains(dest.index()) => { - let group = *arg_index as u32; - use_sites - .entry(group) - .or_insert_with(|| BitSet::new(block_count)) - .insert(block_idx); - } - _ => {} - } - } - } - - // Build successors list for liveness analysis - let successors_list: Vec> = (0..block_count) - .map(|i| self.successor_map.get(&i).cloned().unwrap_or_default()) - .collect(); - - // Convert defs to BTreeMap for liveness - let defs_for_liveness: BTreeMap = self - .defs - .iter() - .map(|(group, blocks)| { - let mut bs = BitSet::new(block_count); - for &b in blocks { - bs.insert(b); - } - (*group, bs) - }) - .collect(); - - self.live_in = liveness::compute_live_in_blocks( - &defs_for_liveness, - &use_sites, - &successors_list, - block_count, - ); - } - - /// Clears all phi nodes from all blocks before fresh placement. - fn clear_all_phis(&mut self) { - for block in &mut self.ssa.blocks { - block.phi_nodes_mut().clear(); - } - } - - /// Places PHI nodes for all groups using iterated dominance frontiers. - fn place_phis(&mut self) { - // Leave target resolver for exception handler phi placement - let leave_target_fn = |block_idx: usize, blocks: &[SsaBlock]| -> Option { - blocks - .get(block_idx) - .and_then(|block| match block.terminator_op() { - Some(SsaOp::Leave { target }) => Some(*target), - _ => None, - }) - }; - - let block_count = self.ssa.blocks.len(); - - // Convert defs to BTreeMap and filter to skip single-entry-only groups - let filtered_defs: BTreeMap = self - .defs - .iter() - .filter(|(_, def_blocks)| !(def_blocks.len() == 1 && def_blocks.contains(&0))) - .map(|(k, v)| { - let mut bs = BitSet::new(block_count); - for &b in v { - bs.insert(b); - } - (*k, bs) - }) - .collect(); - - let group_origins = self.group_origins.clone(); - let num_args = self.ssa.num_args; - - let placements = place_pruned_phis( - &mut self.ssa.blocks, - &filtered_defs, - &self.live_in, - &self.dominance_frontiers, - Some(&self.reachable), - &|_| true, // Process all groups - &|group| { - group_origins - .get(&group) - .copied() - .unwrap_or(if (group as usize) < num_args { - VariableOrigin::Argument(group as u16) - } else { - VariableOrigin::Phi - }) - }, - Some(&leave_target_fn), - ); - - // Build per-block phi group mapping from placement info - for (block_idx, group) in placements { - self.phi_groups.entry(block_idx).or_default().push(group); - } - } - - /// Renames variables after PHI placement during SSA rebuild. - fn rename(&mut self) { - let ctx = RenameContext { - var_origins: &self.var_origins, - group_types: &self.group_types, - group_origins: &self.group_origins, - var_types: &self.var_types, - successor_map: &self.successor_map, - dom_children: &self.dom_children, - phi_groups: &self.phi_groups, - num_args: self.ssa.num_args, - }; - - // Version stacks: for each group, track the current reaching definition - let mut version_stacks: BTreeMap> = BTreeMap::new(); - let mut next_version: BTreeMap = BTreeMap::new(); - - // Initialize with arguments and locals version 0 from existing variables. - // Only use version-0 variables that have an entry-point def_site (no specific - // instruction). Variables with instruction-specific def_sites are actual - // definitions in specific blocks and must NOT be used as the initial reaching - // definition for the entry block, as that would create use-before-def when the - // entry block references a variable defined in a later block. - for var in &self.ssa.variables { - let group = self.ssa.rename_group(var.id()); - if group != u32::MAX { - match var.origin() { - VariableOrigin::Argument(_) | VariableOrigin::Local(_) - if var.version() == 0 && var.def_site().instruction.is_none() => - { - version_stacks.entry(group).or_default().push(var.id()); - next_version.insert(group, 1); - } - _ => {} - } - } - } - - // Ensure all groups that have definitions get a version 0 entry. - // Without version 0 entries, the rename step will leave stale - // references when the version stack is empty (which causes - // apply_rename_map to create use-before-def errors). - for &group in self.defs.keys() { - if !version_stacks.contains_key(&group) { - let origin = self - .group_origins - .get(&group) - .copied() - .unwrap_or(VariableOrigin::Phi); - let var_type = self - .group_types - .get(&group) - .cloned() - .unwrap_or(SsaType::Unknown); - let id = self - .ssa - .create_variable(origin, 0, DefSite::entry(), var_type); - self.ssa.set_rename_group(id, group); - version_stacks.entry(group).or_default().push(id); - next_version.insert(group, 1); - } - } - - let mut rename_map: BTreeMap = BTreeMap::new(); - - // Rename from entry block — the dom tree now covers handler blocks - // via local dom trees computed in compute_cfg(). - Self::rename_block_recursive( - self.ssa, - 0, - &ctx, - &mut version_stacks, - &mut next_version, - &mut rename_map, - ); - - // Rename handler roots that are not reachable from the entry's dom tree. - // With the augmented dom tree, handler body blocks are dom_children of their - // handler root, so rename_block_recursive from the root covers them. - let block_count = self.ssa.blocks.len(); - let mut dom_tree_reachable = BitSet::new(block_count); - let mut dom_stack = vec![0usize]; - while let Some(block_idx) = dom_stack.pop() { - if dom_tree_reachable.insert(block_idx) { - if let Some(children) = ctx.dom_children.get(&block_idx) { - dom_stack.extend(children.iter().copied()); - } - } - } - - for handler in self.ssa.exception_handlers.clone() { - for block in [handler.handler_start_block, handler.filter_start_block] - .into_iter() - .flatten() - { - if self.reachable.contains(block) && !dom_tree_reachable.contains(block) { - Self::rename_block_recursive( - self.ssa, - block, - &ctx, - &mut version_stacks, - &mut next_version, - &mut rename_map, - ); - // Mark this subtree as reachable so we don't re-visit - let mut sub_stack = vec![block]; - while let Some(b) = sub_stack.pop() { - if dom_tree_reachable.insert(b) { - if let Some(children) = ctx.dom_children.get(&b) { - sub_stack.extend(children.iter().copied()); - } - } - } - } - } - } - - // Apply renames to all variable uses - Self::apply_rename_map(self.ssa, &rename_map); - - // Fill missing phi operands: if rename didn't provide an operand for a - // predecessor (e.g., the predecessor was not visited or the version stack - // was empty), fill it with the version 0 entry variable for that origin. - { - // Collect all needed fixups first (to avoid borrow conflicts) - let fixups: Vec<(usize, usize, usize, SsaVarId)> = { - let cfg = SsaCfg::from_ssa(self.ssa); - let mut fixes = Vec::new(); - for block_idx in 0..self.ssa.blocks.len() { - let preds: Vec = cfg.block_predecessors(block_idx).to_vec(); - if preds.is_empty() { - continue; - } - if let Some(block) = self.ssa.block(block_idx) { - for (phi_idx, phi) in block.phi_nodes().iter().enumerate() { - let mut existing = BitSet::new(block_count); - for op in phi.operands() { - existing.insert(op.predecessor()); - } - let group = self.ssa.rename_group(phi.result()); - for &pred in &preds { - if !existing.contains(pred) { - if let Some(&v0) = - version_stacks.get(&group).and_then(|stack| stack.first()) - { - fixes.push((block_idx, phi_idx, pred, v0)); - } - } - } - } - } - } - fixes - }; - - // Apply fixups - for (block_idx, phi_idx, pred, v0) in fixups { - if let Some(block) = self.ssa.block_mut(block_idx) { - if let Some(phi) = block.phi_nodes_mut().get_mut(phi_idx) { - phi.set_operand(pred, v0); - } - } - } - } - - // Final cleanup: Remove Pop instructions that use undefined variables - { - let variable_count = self.ssa.var_id_capacity(); - let mut defined_vars = BitSet::new(variable_count); - for v in &self.ssa.variables { - let idx = v.id().index(); - if idx < variable_count { - defined_vars.insert(idx); - } - } - for block in &mut self.ssa.blocks { - block.instructions_mut().retain(|instr| { - if let SsaOp::Pop { value } = instr.op() { - let idx = value.index(); - return idx < variable_count && defined_vars.contains(idx); - } - true - }); - } - } - - // Eliminate trivial PHIs from rename - self.eliminate_trivial_phis(); - } - - /// Eliminates trivial PHI nodes (all operands resolve to the same value). - /// - /// This also handles phis that become trivial when only considering operands - /// from reachable predecessors. Unreachable predecessors may provide stale - /// version-0 operands that make a phi look non-trivial when it's actually - /// trivial for the reachable control flow. Without this, DCE would have to - /// prune unreachable operands on every iteration, causing a ping-pong cycle. - fn eliminate_trivial_phis(&mut self) { - self.ssa.eliminate_trivial_phis(&TrivialPhiOptions { - reachable: Some(&self.reachable), - }); - - // Also eliminate dead phis: phis whose result is never used. - // This prevents oscillation with DCE: rebuild places phis for all - // variables with multiple defs, but some may have no consumers. - self.ssa.eliminate_dead_phis(); - } - - /// Removes Pop instructions that reference variables removed by - /// `eliminate_trivial_phis` or `compact_variables`. - fn remove_orphan_pops(&mut self) { - let variable_count = self.ssa.var_id_capacity(); - let defined_vars: BitSet = { - let mut d = BitSet::new(variable_count); - for b in self.ssa.blocks() { - for phi in b.phi_nodes() { - let idx = phi.result().index(); - if idx < variable_count { - d.insert(idx); - } - } - for instr in b.instructions() { - if let Some(dest) = instr.op().dest() { - let idx = dest.index(); - if idx < variable_count { - d.insert(idx); - } - } - } - } - d - }; - - // Collect exception/filter handler entry blocks — their Pop instructions - // consume the runtime-pushed exception object which has no SSA definition. - let block_count = self.ssa.blocks.len(); - let mut handler_entry_blocks = BitSet::new(block_count); - for h in &self.ssa.exception_handlers { - if let Some(b) = h.handler_start_block { - if b < block_count { - handler_entry_blocks.insert(b); - } - } - if let Some(b) = h.filter_start_block { - if b < block_count { - handler_entry_blocks.insert(b); - } - } - } - - for block in &mut self.ssa.blocks { - let is_handler_entry = handler_entry_blocks.contains(block.id()); - block.instructions_mut().retain(|instr| { - if let SsaOp::Pop { value } = instr.op() { - // Preserve Pops in handler entry blocks — the exception object - // is pushed by the runtime and has no SSA definition. - if is_handler_entry { - return true; - } - let idx = value.index(); - return idx < variable_count && defined_vars.contains(idx); - } - true - }); - } - } - - /// Iteratively renames variables in a block and its dominated children. - fn rename_block_recursive( - ssa: &mut SsaFunction, - entry_block_idx: usize, - ctx: &RenameContext<'_>, - version_stacks: &mut BTreeMap>, - next_version: &mut BTreeMap, - rename_map: &mut BTreeMap, - ) { - enum RenameAction { - Enter(usize), - Exit(BTreeMap), - } - - let mut work_stack = vec![RenameAction::Enter(entry_block_idx)]; - let mut visited = BitSet::new(ssa.blocks.len()); - - while let Some(action) = work_stack.pop() { - match action { - RenameAction::Exit(pushed_counts) => { - for (group, count) in pushed_counts { - if let Some(stack) = version_stacks.get_mut(&group) { - for _ in 0..count { - stack.pop(); - } - } - } - } - RenameAction::Enter(block_idx) => { - // Guard against cycles in dom_children (can occur when - // multiple exception handlers share blocks outside the - // main dominator tree). - if !visited.insert(block_idx) { - continue; - } - - let pushed_counts = Self::rename_block_process( - ssa, - block_idx, - ctx, - version_stacks, - next_version, - rename_map, - ); - - let children = ctx - .dom_children - .get(&block_idx) - .cloned() - .unwrap_or_default(); - - work_stack.push(RenameAction::Exit(pushed_counts)); - - for child in children.into_iter().rev() { - work_stack.push(RenameAction::Enter(child)); - } - } - } - } - } - - /// Processes a single block during rebuild rename. - fn rename_block_process( - ssa: &mut SsaFunction, - block_idx: usize, - ctx: &RenameContext<'_>, - version_stacks: &mut BTreeMap>, - next_version: &mut BTreeMap, - rename_map: &mut BTreeMap, - ) -> BTreeMap { - let mut pushed_counts: BTreeMap = BTreeMap::new(); - - Self::rename_phis( - ssa, - block_idx, - ctx, - version_stacks, - next_version, - rename_map, - &mut pushed_counts, - ); - Self::rename_instructions( - ssa, - block_idx, - ctx, - version_stacks, - next_version, - rename_map, - &mut pushed_counts, - ); - Self::fill_successor_phi_operands(ssa, block_idx, ctx, version_stacks); - - pushed_counts - } - - /// Renames PHI node results in a block during rebuild rename. - fn rename_phis( - ssa: &mut SsaFunction, - block_idx: usize, - ctx: &RenameContext<'_>, - version_stacks: &mut BTreeMap>, - next_version: &mut BTreeMap, - rename_map: &mut BTreeMap, - pushed_counts: &mut BTreeMap, - ) { - // Look up the group for each phi from the phi_groups mapping built during placement. - let phi_info: Vec<(u32, VariableOrigin, SsaVarId)> = { - let block_phi_groups = ctx.phi_groups.get(&block_idx); - ssa.block(block_idx) - .map(|b| { - b.phi_nodes() - .iter() - .enumerate() - .map(|(i, phi)| { - let origin = phi.origin(); - let group = block_phi_groups - .and_then(|groups| groups.get(i).copied()) - .unwrap_or_else(|| { - // Fallback: derive group from origin for Argument/Local - match origin { - VariableOrigin::Argument(idx) => idx as u32, - VariableOrigin::Local(idx) => { - (ctx.num_args as u32).saturating_add(idx as u32) - } - VariableOrigin::Phi => u32::MAX, - } - }); - (group, origin, phi.result()) - }) - .collect() - }) - .unwrap_or_default() - }; - - for (i, (group, origin, old_result)) in phi_info.iter().enumerate() { - let version = *next_version.get(group).unwrap_or(&0); - let entry = next_version.entry(*group).or_insert(0); - *entry = entry.saturating_add(1); - - let var_type = ctx - .group_types - .get(group) - .cloned() - .unwrap_or(SsaType::Unknown); - let new_var_id = - ssa.create_variable(*origin, version, DefSite::phi(block_idx), var_type); - ssa.set_rename_group(new_var_id, *group); - - if let Some(block) = ssa.block_mut(block_idx) { - if let Some(phi) = block.phi_nodes_mut().get_mut(i) { - phi.set_result(new_var_id); - } - } - - version_stacks.entry(*group).or_default().push(new_var_id); - let pc = pushed_counts.entry(*group).or_insert(0); - *pc = pc.saturating_add(1); - - if *old_result != new_var_id { - rename_map.insert(*old_result, new_var_id); - } - } - } - - /// Renames instruction uses and definitions in a block during rebuild rename. - fn rename_instructions( - ssa: &mut SsaFunction, - block_idx: usize, - ctx: &RenameContext<'_>, - version_stacks: &mut BTreeMap>, - next_version: &mut BTreeMap, - rename_map: &mut BTreeMap, - pushed_counts: &mut BTreeMap, - ) { - // Collect instruction info including load targets for LoadArg/LoadLocal. - // A load_target of Some(group) means the instruction is a LoadArg/LoadLocal - // that reads from the given arg/local group. During rename, these are - // resolved to the current reaching definition instead of creating new versions, - // ensuring multiple loads of the same arg/local produce the same SSA variable. - type InstrRenameInfo = (usize, Vec, Option, Option); - let instr_info: Vec = ssa - .block(block_idx) - .map(|b| { - b.instructions() - .iter() - .enumerate() - .map(|(i, instr)| { - let load_target_group = match instr.op() { - SsaOp::LoadArg { arg_index, .. } => Some(*arg_index as u32), - SsaOp::LoadLocal { local_index, .. } => { - Some((ctx.num_args as u32).saturating_add(*local_index as u32)) - } - _ => None, - }; - (i, instr.uses(), instr.def(), load_target_group) - }) - .collect() - }) - .unwrap_or_default(); - - for (instr_idx, old_uses, opt_def, load_target_group) in &instr_info { - // Apply use renames directly to the instruction - let mut use_renames: Vec<(SsaVarId, SsaVarId)> = Vec::new(); - for &old_use in old_uses { - let group = ssa.rename_group(old_use); - if group != u32::MAX { - if let Some(reaching_def) = version_stacks - .get(&group) - .and_then(|stack| stack.last().copied()) - { - if reaching_def != old_use { - use_renames.push((old_use, reaching_def)); - } - } - } - } - - if !use_renames.is_empty() { - if let Some(block) = ssa.block_mut(block_idx) { - if let Some(instr) = block.instructions_mut().get_mut(*instr_idx) { - let op = instr.op_mut(); - for (old_use, new_use) in &use_renames { - op.replace_uses(*old_use, *new_use); - } - } - } - } - - if let Some(old_dest) = opt_def { - // LoadArg/LoadLocal: resolve dest to the current reaching definition - // for the arg/local instead of creating a new version. This ensures - // that multiple loads of the same arg/local produce the same SSA - // variable, enabling patterns like `x - x = 0` to be recognized. - if let Some(target_group) = load_target_group { - if let Some(reaching_def) = version_stacks - .get(target_group) - .and_then(|stack| stack.last().copied()) - { - rename_map.insert(*old_dest, reaching_def); - // Also push the reaching def onto the dest's group stack - // so that within-block uses (which resolve via version_stacks, not - // rename_map) also see the correct reaching definition. - let dest_group = ssa.rename_group(*old_dest); - if dest_group != u32::MAX { - version_stacks - .entry(dest_group) - .or_default() - .push(reaching_def); - let pc = pushed_counts.entry(dest_group).or_insert(0); - *pc = pc.saturating_add(1); - } - // Convert to Nop since the value is the reaching definition - if let Some(block) = ssa.block_mut(block_idx) { - if let Some(instr) = block.instructions_mut().get_mut(*instr_idx) { - instr.set_op(SsaOp::Nop); - } - } - continue; - } - } - - let group = ssa.rename_group(*old_dest); - let origin = ctx.var_origins.get(old_dest).copied(); - if group != u32::MAX { - if let Some(origin) = origin { - let version = *next_version.get(&group).unwrap_or(&0); - let nv = next_version.entry(group).or_insert(0); - *nv = nv.saturating_add(1); - - // Use per-variable type first (preserves stack-derived local types), - // fall back to per-group type - let var_type = ctx - .var_types - .get(old_dest) - .or_else(|| ctx.group_types.get(&group)) - .cloned() - .unwrap_or(SsaType::Unknown); - let new_var_id = ssa.create_variable( - origin, - version, - DefSite::instruction(block_idx, *instr_idx), - var_type, - ); - ssa.set_rename_group(new_var_id, group); - - if let Some(block) = ssa.block_mut(block_idx) { - if let Some(instr) = block.instructions_mut().get_mut(*instr_idx) { - instr.op_mut().set_dest(new_var_id); - } - } - - version_stacks.entry(group).or_default().push(new_var_id); - let pc = pushed_counts.entry(group).or_insert(0); - *pc = pc.saturating_add(1); - - if *old_dest != new_var_id { - rename_map.insert(*old_dest, new_var_id); - } - } - } - } - } - } - - /// Fills PHI operands in successor blocks with current reaching definitions. - fn fill_successor_phi_operands( - ssa: &mut SsaFunction, - block_idx: usize, - ctx: &RenameContext<'_>, - version_stacks: &mut BTreeMap>, - ) { - let successors = ctx - .successor_map - .get(&block_idx) - .cloned() - .unwrap_or_default(); - for succ_idx in successors { - // Collect each successor phi's group from the phi_groups mapping - let phi_updates: Vec<(usize, u32)> = { - let succ_phi_groups = ctx.phi_groups.get(&succ_idx); - ssa.block(succ_idx) - .map(|b| { - b.phi_nodes() - .iter() - .enumerate() - .map(|(i, phi)| { - let group = succ_phi_groups - .and_then(|groups| groups.get(i).copied()) - .unwrap_or_else(|| { - // Fallback for phis created outside place_pruned_phis - // (e.g., by rename itself). Use the phi result's group. - let result = phi.result(); - if result.index() < ssa.rename_groups.len() { - ssa.rename_group(result) - } else { - match phi.origin() { - VariableOrigin::Argument(idx) => idx as u32, - VariableOrigin::Local(idx) => { - (ctx.num_args as u32).saturating_add(idx as u32) - } - VariableOrigin::Phi => u32::MAX, - } - } - }); - (i, group) - }) - .collect() - }) - .unwrap_or_default() - }; - - for (phi_idx, group) in phi_updates { - if group == u32::MAX { - continue; - } - // Determine the reaching definition. Using stack.last() is almost - // always correct — it's the most recent definition of the group - // that dominates this edge. But for a back-edge where the loop - // body doesn't redefine the group, stack.last() is the successor - // block's OWN PHI result (pushed when rename entered the - // successor), which would produce a self-referential PHI operand. - // - // Self-referential operands are technically valid SSA (meaning - // "no change on this edge") but they destroy per-edge attribution: - // an earlier LICM pass may have legitimately hoisted a case-block - // state-update Const to a preheader and left the PHI operand - // pointing at the hoisted variable. That hoisted variable still - // dominates this edge, so the existing operand value is the - // correct reaching def — but rename's dominator-tree walk only - // sees the header PHI on top of the stack because case blocks - // don't redefine the group locally. Overwriting with the header - // PHI result loses the hoisted value. - // - // When the most-recent reaching def would be the successor's own - // PHI result, walk down the version stack to find the next - // reaching def that ISN'T a self-reference. If none exists, skip - // the fill — preserving whatever operand is already there (e.g., - // a LICM-maintained reference to the hoisted variable). - let succ_phi_result = ssa - .block(succ_idx) - .and_then(|b| b.phi_nodes().get(phi_idx)) - .map(|phi| phi.result()); - let reaching_def = version_stacks.get(&group).and_then(|stack| { - stack - .iter() - .rev() - .find(|&&v| Some(v) != succ_phi_result) - .copied() - }); - if let Some(reaching_def) = reaching_def { - if let Some(succ_block) = ssa.block_mut(succ_idx) { - if let Some(phi) = succ_block.phi_nodes_mut().get_mut(phi_idx) { - phi.set_operand(block_idx, reaching_def); - } - } - } - } - } - } - - /// Applies the rename map to all variable uses in the function. - fn apply_rename_map(ssa: &mut SsaFunction, rename_map: &BTreeMap) { - if rename_map.is_empty() { - return; - } - - let resolve = |var: SsaVarId| -> SsaVarId { - let mut current = var; - let mut visited = BTreeSet::new(); - while let Some(&new_var) = rename_map.get(¤t) { - if !visited.insert(current) { - break; - } - current = new_var; - } - current - }; - - // Collect all phi operand updates first - let mut phi_updates: Vec<(usize, usize, usize, SsaVarId)> = Vec::new(); - for block in &ssa.blocks { - let block_idx = block.id(); - for (phi_idx, phi) in block.phi_nodes().iter().enumerate() { - for op in phi.operands() { - let old_val = op.value(); - let new_val = resolve(old_val); - if new_val != old_val { - phi_updates.push((block_idx, phi_idx, op.predecessor(), new_val)); - } - } - } - } - - for (block_idx, phi_idx, pred, new_val) in phi_updates { - if let Some(block) = ssa.block_mut(block_idx) { - if let Some(phi) = block.phi_nodes_mut().get_mut(phi_idx) { - phi.set_operand(pred, new_val); - } - } - } - - // Collect all instruction use updates - let mut instr_updates: Vec<(usize, usize, SsaVarId, SsaVarId)> = Vec::new(); - for block in &ssa.blocks { - let block_idx = block.id(); - for (instr_idx, instr) in block.instructions().iter().enumerate() { - let mut seen = std::collections::BTreeSet::new(); - for &old_use in &instr.uses() { - if seen.insert(old_use) { - let new_use = resolve(old_use); - if new_use != old_use { - instr_updates.push((block_idx, instr_idx, old_use, new_use)); - } - } - } - } - } - - for (block_idx, instr_idx, old_var, new_var) in instr_updates { - if let Some(block) = ssa.block_mut(block_idx) { - if let Some(instr) = block.instructions_mut().get_mut(instr_idx) { - instr.op_mut().replace_uses(old_var, new_var); - } - } - } - - // NOTE: We intentionally do NOT sort instructions topologically here. - // The rename phase processes instructions in their current order and - // produces definitions before their uses within each block, so the - // result is already in valid topological order. - // - // Sorting here would reorder pure instructions (like Const) earlier - // than their original program position. When rebuild_ssa() is called - // again (e.g., after normalization passes), the rename phase would - // then process the sorted order, causing incorrect reaching-definition - // assignments for variables sharing the same origin (stack slot). - } -} diff --git a/dotscope/src/analysis/ssa/function/repair.rs b/dotscope/src/analysis/ssa/function/repair.rs deleted file mode 100644 index a0c99e63..00000000 --- a/dotscope/src/analysis/ssa/function/repair.rs +++ /dev/null @@ -1,192 +0,0 @@ -//! Lightweight SSA repair for non-CFG-modifying passes. -//! -//! After passes that only modify instructions (e.g., constant propagation, -//! copy propagation, DCE, algebraic simplification), SSA form may need minor -//! cleanup but does NOT need full reconstruction. The CFG topology, dominator -//! tree, and phi placement are all still valid — only instruction-level -//! artifacts need attention. -//! -//! `repair_ssa` is the lightweight alternative to `rebuild_ssa` for passes -//! classified as `ModificationScope::InstructionsOnly` or `UsesOnly`. It -//! performs: -//! -//! 1. **Nop stripping** — removes Nop instructions and reindexes DefSites -//! 2. **Trivial phi elimination** — phis where all operands resolve to one value -//! 3. **Dead phi elimination** — phis whose result has no consumers -//! 4. **Variable compaction** — removes orphaned variables and reindexes IDs -//! -//! What it does NOT do (saving significant overhead): -//! - Recompute dominators or dominance frontiers -//! - Recompute liveness -//! - Re-place phi nodes -//! - Full variable renaming -//! - Orphan origin assignment - -use crate::analysis::ssa::{SsaFunction, TrivialPhiOptions}; - -impl SsaFunction { - /// Lightweight SSA repair for passes that don't modify CFG structure. - /// - /// This is the fast path alternative to [`rebuild_ssa`](Self::rebuild_ssa) - /// for passes classified as `InstructionsOnly` or `UsesOnly`. It assumes - /// the CFG topology is unchanged and only cleans up instruction-level - /// artifacts. - /// - /// # What this does - /// - /// 1. Strips Nop instructions and reindexes variable DefSites - /// 2. Eliminates trivial phi nodes (all operands resolve to one value) - /// 3. Eliminates dead phi nodes (result never used) - /// 4. Compacts orphaned variables and reindexes IDs - /// - /// # When to use - /// - /// Use this instead of `rebuild_ssa` when the pass only: - /// - Replaces instruction opcodes/operands - /// - Converts instructions to Nops (for DCE) - /// - Substitutes variable uses (copy propagation, GVN) - /// - /// Do NOT use this if the pass: - /// - Adds, removes, or reorders blocks - /// - Changes branch targets (changes predecessor lists) - /// - Converts branches to jumps (changes CFG edges) - pub fn repair_ssa(&mut self) { - if self.blocks.is_empty() { - return; - } - - self.strip_nops(); - self.eliminate_trivial_phis(&TrivialPhiOptions { reachable: None }); - self.eliminate_dead_phis(); - self.compact_variables(); - self.reindex_variables(); - } -} - -#[cfg(test)] -mod tests { - use crate::analysis::{ssa::SsaOp, SsaFunctionBuilder}; - - #[test] - fn test_repair_strips_nops() { - let mut ssa = SsaFunctionBuilder::new(1, 0) - .build_with(|f| { - f.block(0, |b| { - let c1 = b.const_i32(42); - b.nop(); - b.nop(); - let c2 = b.const_i32(100); - let _ = b.add(c1, c2); - b.ret(); - }); - }) - .unwrap(); - - let pre_count = ssa.blocks()[0].instructions().len(); - ssa.repair_ssa(); - let post_count = ssa.blocks()[0].instructions().len(); - - // Should have removed the 2 nops - assert_eq!(pre_count - post_count, 2); - assert!(ssa.validate().is_ok()); - } - - #[test] - fn test_repair_eliminates_trivial_phis() { - let mut ssa = SsaFunctionBuilder::new(2, 1) - .build_with(|f| { - f.block(0, |b| { - let _ = b.const_i32(42); - b.jump(1); - }); - f.block(1, |b| { - b.ret(); - }); - }) - .unwrap(); - - // repair_ssa should handle trivial phi elimination without panicking - // and preserve valid SSA. - ssa.repair_ssa(); - assert!(ssa.validate().is_ok()); - } - - #[test] - fn test_repair_strips_nops_and_compacts() { - let mut ssa = SsaFunctionBuilder::new(1, 0) - .build_with(|f| { - f.block(0, |b| { - let c1 = b.const_i32(42); - let c2 = b.const_i32(100); - let _ = b.add(c1, c2); - b.ret(); - }); - }) - .unwrap(); - - // Nop out one instruction — the Nop should be stripped and the - // result should still be valid SSA (the add will have a dangling - // use but compact handles that). - let pre_instr_count = ssa.blocks()[0].instructions().len(); - if let Some(block) = ssa.block_mut(0) { - if let Some(instr) = block.instructions_mut().get_mut(1) { - instr.set_op(SsaOp::Nop); - } - } - - ssa.repair_ssa(); - - // The Nop should be removed - let post_instr_count = ssa.blocks()[0].instructions().len(); - assert_eq!(pre_instr_count - post_instr_count, 1); - } - - #[test] - fn test_repair_preserves_valid_ssa() { - let mut ssa = SsaFunctionBuilder::new(3, 1) - .build_with(|f| { - f.block(0, |b| { - let c = b.const_true(); - b.branch(c, 1, 2); - }); - f.block(1, |b| { - let _ = b.const_i32(10); - b.jump(2); - }); - f.block(2, |b| { - b.ret(); - }); - }) - .unwrap(); - - // repair_ssa should maintain valid SSA form - ssa.repair_ssa(); - assert!(ssa.validate().is_ok()); - } - - #[test] - fn test_repair_is_idempotent() { - let mut ssa = SsaFunctionBuilder::new(2, 0) - .build_with(|f| { - f.block(0, |b| { - let c1 = b.const_i32(1); - let c2 = b.const_i32(2); - let _ = b.add(c1, c2); - b.jump(1); - }); - f.block(1, |b| { - b.ret(); - }); - }) - .unwrap(); - - // Running repair twice should produce the same result - ssa.repair_ssa(); - let vars_after_first = ssa.variable_count(); - let blocks_after_first = ssa.block_count(); - - ssa.repair_ssa(); - assert_eq!(ssa.variable_count(), vars_after_first); - assert_eq!(ssa.block_count(), blocks_after_first); - } -} diff --git a/dotscope/src/analysis/ssa/function/semantics.rs b/dotscope/src/analysis/ssa/function/semantics.rs deleted file mode 100644 index 569420f0..00000000 --- a/dotscope/src/analysis/ssa/function/semantics.rs +++ /dev/null @@ -1,115 +0,0 @@ -//! Block and loop semantic analysis for SSA functions. -//! -//! These methods delegate to the `SemanticAnalyzer` to classify blocks -//! by their role: initialization, condition, body, latch, exit, etc. - -use std::collections::HashMap; - -use crate::analysis::{ - cfg::{BlockSemantics, LoopSemantics, SemanticAnalyzer}, - ssa::SsaFunction, - LoopInfo, -}; - -impl SsaFunction { - /// Analyzes the semantic role of a specific block. - /// - /// Uses the `SemanticAnalyzer` to determine what a block does: - /// initialization, condition testing, loop body work, variable updates, etc. - /// - /// # Arguments - /// - /// * `block_idx` - The block index to analyze - /// - /// # Returns - /// - /// Semantic information about the block including its role and characteristics. - #[must_use] - pub fn analyze_block_semantics(&self, block_idx: usize) -> BlockSemantics { - let mut analyzer = SemanticAnalyzer::new(self); - analyzer.analyze_block(block_idx).clone() - } - - /// Analyzes semantic roles of multiple blocks. - /// - /// # Arguments - /// - /// * `blocks` - The block indices to analyze - /// - /// # Returns - /// - /// A map of block index to semantic information. - #[must_use] - pub fn analyze_blocks_semantics(&self, blocks: &[usize]) -> HashMap { - let mut analyzer = SemanticAnalyzer::new(self); - let mut results = HashMap::new(); - - for &block in blocks { - results.insert(block, analyzer.analyze_block(block).clone()); - } - - results - } - - /// Analyzes the semantic structure of a structural loop. - /// - /// Given a `LoopInfo` from dominance-based loop detection, this method - /// classifies each block within the loop by its semantic role: - /// init, condition, body, latch, exit. - /// - /// # Arguments - /// - /// * `loop_info` - Structural loop information from `LoopForest` - /// - /// # Returns - /// - /// Semantic loop information with classified blocks and execution order. - #[must_use] - pub fn analyze_loop_semantics(&self, loop_info: &LoopInfo) -> LoopSemantics { - let mut analyzer = SemanticAnalyzer::new(self); - analyzer.analyze_loop(loop_info) - } - - /// Recovers loop semantics from flattened dispatcher case blocks. - /// - /// This is the key method for control flow unflattening. Given the target - /// blocks from a switch dispatcher, it analyzes each block's semantic role - /// to reconstruct the original loop structure. - /// - /// # Arguments - /// - /// * `case_blocks` - Block indices that are case targets of the dispatcher - /// * `dispatcher_block` - Optional index of the dispatcher block to exclude - /// - /// # Returns - /// - /// Semantic loop structure with blocks classified and ordered correctly. - #[must_use] - pub fn recover_loop_from_cases( - &self, - case_blocks: &[usize], - dispatcher_block: Option, - ) -> LoopSemantics { - let mut analyzer = SemanticAnalyzer::new(self); - - // Mark dispatcher as known if provided - if let Some(disp) = dispatcher_block { - analyzer.mark_dispatcher(disp); - } - - analyzer.recover_loop_from_cases(case_blocks) - } - - /// Creates a semantic analyzer for this function. - /// - /// Use this when you need to perform multiple semantic analyses - /// and want to benefit from caching. - /// - /// # Returns - /// - /// A new `SemanticAnalyzer` instance for this function. - #[must_use] - pub fn semantic_analyzer(&self) -> SemanticAnalyzer<'_> { - SemanticAnalyzer::new(self) - } -} diff --git a/dotscope/src/analysis/ssa/function/transforms.rs b/dotscope/src/analysis/ssa/function/transforms.rs deleted file mode 100644 index 9c382db8..00000000 --- a/dotscope/src/analysis/ssa/function/transforms.rs +++ /dev/null @@ -1,1291 +0,0 @@ -//! Mutation/transform methods for SSA functions. -//! -//! These methods modify SSA functions: replacing uses, simplifying phis, -//! folding constants, compacting variables, optimizing locals, and -//! generating local signatures. -//! -//! # `replace_uses` Architecture -//! -//! The SSA function provides two variable-replacement primitives with different -//! safety characteristics: -//! -//! - **`replace_uses(old, new)`** — replaces uses in instructions only, leaving -//! phi operands unchanged. This is the safe default for compiler passes because -//! it avoids creating cross-origin phi operand references (where a variable -//! defined at one phi origin appears as an operand of a phi with a different -//! origin). Cross-origin references can break `rebuild_ssa`'s assumption that -//! each variable flows to at most one phi origin. -//! -//! - **`replace_uses_including_phis(old, new)`** — also replaces uses in phi -//! operands. This is `pub(crate)` and intended only for SSA infrastructure -//! operations like trivial phi elimination, where the phi being eliminated -//! and its forwarding target share the same origin context. -//! -//! ## The Self-Referential Guard -//! -//! Both methods skip replacements where the instruction's destination equals -//! `new_var`, preventing the creation of self-referential instructions like -//! `v0 = add(v0, v1)`. This guard is necessary because replacing `v1 → v0` -//! in `v0 = add(v1, v2)` would make `v0` both the definition and a use, which -//! is invalid in SSA form. -//! -//! The [`ReplaceResult`] type makes this guard transparent: `replaced` counts -//! successful substitutions, `skipped` counts uses that were left unchanged. -//! Callers use `result.is_complete()` to determine whether all uses were -//! handled, eliminating the need for post-hoc scanning. -//! -//! ## High-Level Operations -//! -//! Compiler passes should prefer the high-level operations that encapsulate -//! the replace-and-check pattern: -//! -//! - **`propagate_copies(copies)`** — batch copy propagation. Takes a -//! dest→src map, replaces uses via `replace_uses` (not including phis), -//! and reports which copies were fully vs. partially propagated. -//! -//! - **`eliminate_trivial_phis(options)`** — finds and removes trivial phi -//! nodes (all non-self operands are the same value). Replaces uses via -//! `replace_uses_including_phis` and iterates to fixpoint. -//! -//! - **`prune_phi_operands(reachable)`** — removes stale phi operands from -//! unreachable predecessors after CFG changes. - -use std::collections::{BTreeMap, BTreeSet}; - -use crate::{ - analysis::ssa::{ - ConstValue, DefSite, ReplaceResult, SsaFunction, SsaOp, SsaType, SsaVarId, UseSite, - VariableOrigin, - }, - metadata::signatures::{CustomModifiers, SignatureLocalVariable, SignatureLocalVariables}, - utils::BitSet, - Error, -}; - -/// Options for trivial phi elimination. -pub struct TrivialPhiOptions<'a> { - /// If set, only consider phis in reachable blocks and use reachability-aware - /// self-referential checks. Unreachable predecessor operands are filtered out - /// as a second-pass check. All trivial phis are removed unconditionally. - /// - /// If `None`, all blocks are considered. Chain resolution is applied, and only - /// fully propagated phis (no skipped uses from the self-ref guard) are removed. - pub reachable: Option<&'a BitSet>, -} - -/// Result of batch copy propagation. -pub struct CopyPropagationResult { - /// Total number of uses replaced across all copies. - pub total_replaced: usize, - /// Set of copy destinations that were fully propagated (all uses replaced). - /// These copies can safely be Nop'd by the caller. - /// Stored as a BitSet indexed by `SsaVarId::index()`. - pub fully_propagated: BitSet, - /// Set of copy destinations that still have remaining instruction uses - /// (due to self-referential guard). These copies must be kept alive. - /// Stored as a BitSet indexed by `SsaVarId::index()`. - pub partially_propagated: BitSet, -} - -impl SsaFunction { - /// Replaces all uses of `old_var` with `new_var` throughout the function. - /// - /// This is the core operation for copy propagation - when we know that - /// `v1 = v0` (a copy), we can replace all uses of `v1` with `v0`. - /// - /// # Note - /// - /// This method only replaces uses in instructions, not in PHI operands. - /// For internal operations that need to also replace PHI operands, use - /// `replace_uses_including_phis`. - pub fn replace_uses(&mut self, old_var: SsaVarId, new_var: SsaVarId) -> ReplaceResult { - self.blocks - .iter_mut() - .map(|block| block.replace_uses(old_var, new_var)) - .fold(ReplaceResult::default(), |acc, r| ReplaceResult { - replaced: acc.replaced.saturating_add(r.replaced), - skipped: acc.skipped.saturating_add(r.skipped), - }) - } - - /// Replaces all uses of `old_var` with `new_var`, including in PHI operands. - /// - /// Unlike [`replace_uses`](Self::replace_uses), this method also replaces uses - /// in PHI node operands across all blocks. This is necessary for internal SSA - /// operations that eliminate PHI nodes and need to forward their values through - /// other PHIs. - /// - /// # Safety - /// - /// This method is `pub(crate)` because it can create cross-origin PHI operand - /// references if misused. - /// - /// # When to Use - /// - /// Only use this method for: - /// - **Trivial PHI elimination**: When removing a PHI like `v10 = phi(v5, v5)`, - /// we need to replace uses of `v10` with `v5` everywhere, including in other - /// PHI operands. - /// - **Copy propagation within PHIs**: When a copy's destination is a PHI result - /// and we're eliminating that PHI. - pub(crate) fn replace_uses_including_phis( - &mut self, - old_var: SsaVarId, - new_var: SsaVarId, - ) -> ReplaceResult { - self.blocks - .iter_mut() - .map(|block| block.replace_uses_including_phis(old_var, new_var)) - .fold(ReplaceResult::default(), |acc, r| ReplaceResult { - replaced: acc.replaced.saturating_add(r.replaced), - skipped: acc.skipped.saturating_add(r.skipped), - }) - } - - /// Replaces all uses of `old_var` with `new_var` within a specific block. - /// - /// This is a targeted version of `replace_uses` that only affects instructions - /// within the specified block (not PHI operands). - pub fn replace_uses_in_block( - &mut self, - block_idx: usize, - old_var: SsaVarId, - new_var: SsaVarId, - ) -> ReplaceResult { - self.block_mut(block_idx) - .map_or(ReplaceResult::default(), |block| { - block.replace_uses(old_var, new_var) - }) - } - - /// Propagates a batch of copy mappings (dest → src) through all instructions. - /// - /// For each mapping, replaces all uses of `dest` with `src` in instructions - /// (NOT in phi operands — this is the safe default that avoids cross-origin - /// phi references). Reports which copies were fully propagated vs. which - /// still have remaining uses due to the self-referential guard. - /// - /// # Usage - /// - /// ```rust,ignore - /// let result = ssa.propagate_copies(&resolved_copies); - /// // Nop only the fully propagated copies: - /// for dest in &result.fully_propagated { - /// ssa.nop_copy_defining(*dest); - /// } - /// ``` - pub fn propagate_copies( - &mut self, - copies: &BTreeMap, - ) -> CopyPropagationResult { - let variable_count = self.var_id_capacity(); - let mut total_replaced: usize = 0; - let mut fully_propagated = BitSet::new(variable_count); - let mut partially_propagated = BitSet::new(variable_count); - - for (dest, src) in copies { - if dest == src { - continue; - } - - let result = self.replace_uses(*dest, *src); - - if result.replaced > 0 { - if result.is_complete() { - fully_propagated.insert(dest.index()); - } else { - partially_propagated.insert(dest.index()); - } - total_replaced = total_replaced.saturating_add(result.replaced); - } - } - - CopyPropagationResult { - total_replaced, - fully_propagated, - partially_propagated, - } - } - - /// Neutralizes Copy instructions that define the given variable by - /// replacing them with Nop. - /// - /// This is used after copy propagation to eliminate dead copy instructions - /// whose destination has been fully propagated to all use sites. Without - /// this, rebuild_ssa's rename would re-create versions for the Copy's origin, - /// shadowing the source variable and undoing the propagation. - pub fn nop_copy_defining(&mut self, dest: SsaVarId) { - for block in &mut self.blocks { - for instr in block.instructions_mut() { - if let SsaOp::Copy { dest: d, .. } = instr.op() { - if *d == dest { - instr.set_op(SsaOp::Nop); - return; - } - } - } - } - } - - /// Prunes phi operands from non-existent or unreachable predecessors. - /// - /// After block removal or CFG changes, phi nodes may reference predecessors - /// that no longer exist or are unreachable. This method removes those stale - /// operands, ensuring phi nodes only reference valid predecessors with - /// defined values. - /// - /// Returns the number of operands pruned. - pub fn prune_phi_operands(&mut self, reachable: &BitSet) -> usize { - let variable_count = self.var_id_capacity(); - - // Build a set of all defined variables in reachable blocks - let mut defined_vars = BitSet::new(variable_count); - - for block_idx in reachable.iter() { - if let Some(block) = self.block(block_idx) { - for phi in block.phi_nodes() { - let idx = phi.result().index(); - if idx < variable_count { - defined_vars.insert(idx); - } - } - for instr in block.instructions() { - if let Some(def) = instr.def() { - let idx = def.index(); - if idx < variable_count { - defined_vars.insert(idx); - } - } - } - } - } - - // Include argument variables (implicitly defined at function entry) - for var in &self.variables { - if var.origin().is_argument() { - let idx = var.id().index(); - if idx < variable_count { - defined_vars.insert(idx); - } - } - } - - // Compute actual predecessors from the CFG - let block_count = self.blocks.len(); - let mut actual_predecessors: BTreeMap = BTreeMap::new(); - for block_idx in reachable.iter() { - if let Some(block) = self.block(block_idx) { - for successor in block.successors() { - actual_predecessors - .entry(successor) - .or_insert_with(|| BitSet::new(block_count)) - .insert(block_idx); - } - } - } - - let mut pruned: usize = 0; - - for block_idx in reachable.iter() { - if let Some(block) = self.block_mut(block_idx) { - let preds = actual_predecessors.get(&block_idx); - - for phi in block.phi_nodes_mut() { - let operands = phi.operands_mut(); - let original_len = operands.len(); - - if original_len == 0 { - continue; - } - - let to_keep: Vec = operands - .iter() - .map(|op| { - let pred = op.predecessor(); - let value = op.value(); - let pred_ok = - pred < block_count && preds.is_some_and(|p| p.contains(pred)); - let val_ok = value.index() < variable_count - && defined_vars.contains(value.index()); - pred_ok && val_ok - }) - .collect(); - - // Never leave a PHI completely empty - let keep_count = to_keep.iter().filter(|&&k| k).count(); - if keep_count == 0 { - continue; - } - - let mut keep_iter = to_keep.iter(); - operands.retain(|_| *keep_iter.next().unwrap_or(&true)); - - pruned = pruned.saturating_add(original_len.saturating_sub(operands.len())); - } - } - } - - pruned - } - - /// Recomputes all use information from scratch. - /// - /// This should be called after SSA transformations that may have invalidated - /// the use tracking. - pub fn recompute_uses(&mut self) { - // Step 1: Clear all existing uses - for var in &mut self.variables { - var.clear_uses(); - } - - // Step 2: Scan instructions to record uses - for (block_idx, block) in self.blocks.iter().enumerate() { - // Record uses from instructions - for (instr_idx, instr) in block.instructions().iter().enumerate() { - for use_var in instr.op().uses() { - if let Some(var) = self.var_index(use_var) { - let use_site = UseSite::instruction(block_idx, instr_idx); - if let Some(slot) = self.variables.get_mut(var) { - slot.add_use(use_site); - } - } - } - } - - // Record uses from phi nodes - for (phi_idx, phi) in block.phi_nodes().iter().enumerate() { - for operand in phi.operands() { - if let Some(var) = self.var_index(operand.value()) { - let use_site = UseSite::phi_operand(block_idx, phi_idx); - if let Some(slot) = self.variables.get_mut(var) { - slot.add_use(use_site); - } - } - } - } - } - } - - /// Replaces the operation of an instruction at a specific location. - pub fn replace_instruction_op( - &mut self, - block_idx: usize, - instr_idx: usize, - new_op: SsaOp, - ) -> bool { - if let Some(block) = self.blocks.get_mut(block_idx) { - if let Some(instr) = block.instructions_mut().get_mut(instr_idx) { - instr.set_op(new_op); - return true; - } - } - false - } - - /// Removes an instruction by replacing it with a Nop. - pub fn remove_instruction(&mut self, block_idx: usize, instr_idx: usize) -> bool { - self.replace_instruction_op(block_idx, instr_idx, SsaOp::Nop) - } - - /// Simplifies a phi node by converting it to a copy operation. - /// - /// When a phi node has all identical operands (excluding self-references), - /// it can be converted to a simple copy operation: `phi_result = source`. - pub fn simplify_phi_to_copy( - &mut self, - block_idx: usize, - phi_idx: usize, - source: SsaVarId, - ) -> bool { - let Some(block) = self.blocks.get_mut(block_idx) else { - return false; - }; - - let Some(phi) = block.phi_nodes().get(phi_idx) else { - return false; - }; - - let dest = phi.result(); - - // Remove the phi node - block.phi_nodes_mut().remove(phi_idx); - - // Replace all uses of `dest` with `source` - self.replace_uses_including_phis(dest, source); - - true - } - - /// Removes a phi node by index without any validation. - pub fn remove_phi_unchecked(&mut self, block_idx: usize, phi_idx: usize) -> bool { - if let Some(block) = self.blocks.get_mut(block_idx) { - if phi_idx < block.phi_nodes().len() { - block.phi_nodes_mut().remove(phi_idx); - return true; - } - } - false - } - - /// Eliminates trivial phi nodes where all non-self operands resolve to a - /// single value. Iterates to fixpoint (cascading simplification). - /// - /// A phi is trivial when, excluding self-references, all operands provide - /// the same value. The phi result is replaced by that value everywhere - /// (including other phi operands). - /// - /// # Modes - /// - /// When `options.reachable` is `Some`: - /// - Uses reachability-aware self-referential checks (definitions in unreachable - /// blocks don't count as creating cycles). - /// - Performs a second-pass check filtering operands from unreachable predecessors. - /// - All trivial phis are removed unconditionally (suitable for rebuild_ssa). - /// - /// When `options.reachable` is `None`: - /// - Uses basic self-referential checks. - /// - Resolves chains among trivial phis to avoid stale references. - /// - Only fully propagated phis (no skipped uses) are removed (suitable for repair_ssa). - /// - /// # Returns - /// - /// The number of phis eliminated. - pub fn eliminate_trivial_phis(&mut self, options: &TrivialPhiOptions) -> usize { - let mut total_eliminated: usize = 0; - let block_count = self.blocks.len(); - - // Precompute reachability data if in reachable mode - let reachable_preds: Option> = options.reachable.map(|reachable| { - let mut map = BTreeMap::new(); - for block in &self.blocks { - let block_idx = block.id(); - if !reachable.contains(block_idx) { - continue; - } - let mut preds = BitSet::new(block_count); - for p in self.block_predecessors(block_idx) { - if reachable.contains(p) { - preds.insert(p); - } - } - map.insert(block_idx, preds); - } - map - }); - - let var_def_block: Option> = options.reachable.map(|_| { - let mut map = BTreeMap::new(); - for block in &self.blocks { - let block_idx = block.id(); - for instr in block.instructions() { - if let Some(dest) = instr.op().dest() { - map.insert(dest, block_idx); - } - } - } - map - }); - - loop { - let mut trivial_phis: Vec<(SsaVarId, SsaVarId)> = Vec::new(); - - for block in &self.blocks { - let block_idx = block.id(); - let block_reachable_preds = - reachable_preds.as_ref().and_then(|rp| rp.get(&block_idx)); - - for phi in block.phi_nodes() { - let result = phi.result(); - - // Collect unique non-self operands - let unique_sources: BTreeSet = phi - .operands() - .iter() - .map(|op| op.value()) - .filter(|&v| v != result) - .collect(); - - if let Some(&source) = unique_sources - .iter() - .next() - .filter(|_| unique_sources.len() == 1) - { - let is_self_ref = match (&var_def_block, options.reachable) { - (Some(vdb), Some(reachable)) => self - .would_create_self_reference_reachable( - source, result, vdb, reachable, - ), - _ => self.would_create_self_reference(source, result), - }; - - if !is_self_ref { - trivial_phis.push((result, source)); - continue; - } - } else if unique_sources.is_empty() && !phi.operands().is_empty() { - // Fully self-referential phi - trivial_phis.push((result, result)); - continue; - } - - // Reachable-only second pass: filter out operands from - // unreachable predecessors and check triviality again - if unique_sources.len() > 1 { - if let Some(rpreds) = block_reachable_preds { - let unique_reachable: BTreeSet = phi - .operands() - .iter() - .filter(|op| { - let pred = op.predecessor(); - pred < block_count && rpreds.contains(pred) - }) - .map(|op| op.value()) - .filter(|&v| v != result) - .collect(); - - if let Some(&source) = unique_reachable - .iter() - .next() - .filter(|_| unique_reachable.len() == 1) - { - let is_self_ref = match (&var_def_block, options.reachable) { - (Some(vdb), Some(reachable)) => self - .would_create_self_reference_reachable( - source, result, vdb, reachable, - ), - _ => self.would_create_self_reference(source, result), - }; - if !is_self_ref { - trivial_phis.push((result, source)); - } - } else if unique_reachable.is_empty() - && phi.operands().iter().any(|op| { - let pred = op.predecessor(); - pred < block_count && rpreds.contains(pred) - }) - { - trivial_phis.push((result, result)); - } - } - } - } - } - - if trivial_phis.is_empty() { - break; - } - - let variable_count = self.var_id_capacity(); - - if options.reachable.is_none() { - // Repair mode: resolve chains among trivial phis. - let trivial_map: BTreeMap = - trivial_phis.iter().copied().collect(); - for entry in &mut trivial_phis { - if entry.0 == entry.1 { - continue; - } - let mut current = entry.1; - let mut visited = BTreeSet::new(); - while let Some(&next) = trivial_map.get(¤t) { - if next == current || !visited.insert(current) { - break; - } - current = next; - } - entry.1 = current; - } - - // Replace uses and track which phis were fully propagated. - let mut trivial_set = BitSet::new(variable_count); - for (phi_result, source) in &trivial_phis { - if *phi_result != *source { - let result = self.replace_uses_including_phis(*phi_result, *source); - if result.is_complete() { - trivial_set.insert(phi_result.index()); - } - } else { - trivial_set.insert(phi_result.index()); - } - } - if trivial_set.is_empty() { - break; - } - - total_eliminated = total_eliminated.saturating_add(trivial_set.count()); - for block in &mut self.blocks { - block.phi_nodes_mut().retain(|phi| { - let idx = phi.result().index(); - idx >= variable_count || !trivial_set.contains(idx) - }); - } - self.variables.retain(|v| { - let idx = v.id().index(); - idx >= variable_count || !trivial_set.contains(idx) - }); - } else { - // Rebuild mode: replace uses and remove unconditionally. - for (phi_result, source) in &trivial_phis { - if *phi_result != *source { - self.replace_uses_including_phis(*phi_result, *source); - } - } - - let mut trivial_set = BitSet::new(variable_count); - for (result, _) in &trivial_phis { - trivial_set.insert(result.index()); - } - total_eliminated = total_eliminated.saturating_add(trivial_set.count()); - for block in &mut self.blocks { - block.phi_nodes_mut().retain(|phi| { - let idx = phi.result().index(); - idx >= variable_count || !trivial_set.contains(idx) - }); - } - self.variables.retain(|v| { - let idx = v.id().index(); - idx >= variable_count || !trivial_set.contains(idx) - }); - } - } - - total_eliminated - } - - /// Folds a constant operation, replacing its uses with the computed value. - pub fn fold_constant(&mut self, block_idx: usize, instr_idx: usize, value: ConstValue) -> bool { - if let Some(block) = self.blocks.get_mut(block_idx) { - if let Some(instr) = block.instructions_mut().get_mut(instr_idx) { - if let Some(dest) = instr.op().dest() { - instr.set_op(SsaOp::Const { dest, value }); - return true; - } - } - } - false - } - - /// Compacts the variable table by removing orphaned variables. - /// - /// A variable is considered orphaned if: - /// - It's not defined by any instruction in any block - /// - It's not defined by any phi node in any block - /// - /// # Returns - /// - /// The number of variables that were removed. - pub fn compact_variables(&mut self) -> usize { - let variable_count = self.var_id_capacity(); - - // Phase 1: Collect all variables that still have active definitions - let mut defined_vars = BitSet::new(variable_count); - - for block in &self.blocks { - // From instructions - for instr in block.instructions() { - let op = instr.op(); - // Skip Nop instructions - they have no definition - if matches!(op, SsaOp::Nop) { - continue; - } - if let Some(dest) = op.dest() { - let idx = dest.index(); - if idx < variable_count { - defined_vars.insert(idx); - } - } - } - - // From phi nodes - for phi in block.phi_nodes() { - let idx = phi.result().index(); - if idx < variable_count { - defined_vars.insert(idx); - } - } - } - - // Also keep version-0 entry-point variables. These have no instruction - // def but are implicitly defined at function entry: - // - Argument/Local v0: method parameters and default-initialized locals - // - Phi v0 with entry def_site: placeholder reaching defs for stack temp - // groups created during SSA rebuild - for var in &self.variables { - if var.version() == 0 && var.def_site().instruction.is_none() { - let idx = var.id().index(); - if idx < variable_count { - defined_vars.insert(idx); - } - } - } - - // Also keep variables that are still referenced by non-Nop instructions. - // This can happen when replace_uses skips replacements due to the - // self-referential guard (dest == new_var), leaving uses behind after - // the definition was Nop'd or eliminated. - for block in &self.blocks { - for instr in block.instructions() { - if matches!(instr.op(), SsaOp::Nop) { - continue; - } - for u in instr.uses() { - let idx = u.index(); - if idx < variable_count { - defined_vars.insert(idx); - } - } - } - - // Also keep variables referenced by phi operands. A phi may - // reference a variable whose defining instruction was Nop'd by - // an optimization pass — without this, compact would remove the - // variable and the phi operand would become a dangling reference. - for phi in block.phi_nodes() { - for op in phi.operands() { - let idx = op.value().index(); - if idx < variable_count { - defined_vars.insert(idx); - } - } - } - } - - // Phase 2: Remove orphaned variables - let original_count = self.variables.len(); - self.variables.retain(|v| { - let idx = v.id().index(); - idx < variable_count && defined_vars.contains(idx) - }); - // Reassign dense IDs and rebuild registries - let remap = self.reassign_dense_ids(); - self.remap_var_ids_in_blocks(&remap); - self.rebuild_origin_versions(); - original_count.saturating_sub(self.variables.len()) - } - - /// Reassigns all variable IDs to dense contiguous indices (0..N-1) and - /// remaps all references in blocks. - /// - /// **Warning**: This invalidates any externally-held `SsaVarId` references. - pub fn reindex_variables(&mut self) -> usize { - let remap = self.reassign_dense_ids(); - let remapped = remap.len(); - self.remap_var_ids_in_blocks(&remap); - self.rebuild_origin_versions(); - remapped - } - - /// Strips Nop instructions from all blocks and reindexes variable DefSites. - /// - /// This is the shared implementation used by both `repair_ssa` and - /// `rebuild_ssa`. After stripping Nops: - /// - /// 1. Non-Nop instructions that shifted get their DefSites remapped - /// 2. Variables whose defining instruction was a Nop get reset to entry DefSite - /// 3. Any remaining out-of-bounds DefSites are reset to entry DefSite - pub fn strip_nops(&mut self) { - let mut remap: BTreeMap<(usize, usize), usize> = BTreeMap::new(); - let mut nop_sites: BTreeSet<(usize, usize)> = BTreeSet::new(); - - for (block_idx, block) in self.blocks.iter_mut().enumerate() { - let instructions = block.instructions_mut(); - - if !instructions.iter().any(|i| matches!(i.op(), SsaOp::Nop)) { - continue; - } - - let mut new_idx = 0usize; - for (old_idx, instr) in instructions.iter().enumerate() { - if matches!(instr.op(), SsaOp::Nop) { - nop_sites.insert((block_idx, old_idx)); - } else { - if old_idx != new_idx { - remap.insert((block_idx, old_idx), new_idx); - } - new_idx = new_idx.saturating_add(1); - } - } - - instructions.retain(|instr| !matches!(instr.op(), SsaOp::Nop)); - } - - // Update variable DefSites to reflect new instruction positions. - // Variables whose defining instruction was a Nop get reset to entry. - if !remap.is_empty() || !nop_sites.is_empty() { - for var in &mut self.variables { - let site = var.def_site(); - if let Some(old_instr) = site.instruction { - if nop_sites.contains(&(site.block, old_instr)) { - var.set_def_site(DefSite::entry()); - } else if let Some(&new_instr) = remap.get(&(site.block, old_instr)) { - var.set_def_site(DefSite::instruction(site.block, new_instr)); - } - } - } - } - - // Validate remaining DefSites are in-bounds. Catches stale DefSites - // that existed before strip_nops was called (e.g., from passes that - // modified instructions without updating DefSites). - let block_instr_counts: Vec = - self.blocks.iter().map(|b| b.instructions().len()).collect(); - - for var in &mut self.variables { - let site = var.def_site(); - if let Some(instr_idx) = site.instruction { - let out_of_bounds = match block_instr_counts.get(site.block) { - Some(&count) => instr_idx >= count, - None => true, - }; - if out_of_bounds { - var.set_def_site(DefSite::entry()); - } - } - } - } - - /// Eliminates dead phi nodes whose result is never used. - /// - /// A phi is dead if its result variable has no consumers (no instruction - /// or other phi uses it). Handles dead phi cycles (A uses B, B uses A, - /// neither used elsewhere) via liveness propagation. - /// - /// Also bridges implicit uses from `LoadLocal`/`LoadArg` instructions - /// to the corresponding phi nodes for that local/arg origin, ensuring - /// phis that are read by index-based loads are not incorrectly eliminated. - pub fn eliminate_dead_phis(&mut self) { - let variable_count = self.var_id_capacity(); - let mut all_phi_results = BitSet::new(variable_count); - for block in &self.blocks { - for phi in block.phi_nodes() { - let idx = phi.result().index(); - if idx < variable_count { - all_phi_results.insert(idx); - } - } - } - - if all_phi_results.is_empty() { - return; - } - - // Build map from phi origin to phi result IDs for LoadLocal/LoadArg bridging. - let mut origin_to_phi_results: BTreeMap> = BTreeMap::new(); - for block in &self.blocks { - for phi in block.phi_nodes() { - origin_to_phi_results - .entry(phi.origin()) - .or_default() - .push(phi.result()); - } - } - - // Phase 1: Mark phis as live if used by any non-phi instruction - let mut live_phis = BitSet::new(variable_count); - for block in &self.blocks { - for instr in block.instructions() { - // Direct SSA uses - for u in instr.uses() { - let idx = u.index(); - if idx < variable_count && all_phi_results.contains(idx) { - live_phis.insert(idx); - } - } - - // Implicit uses via LoadLocal/LoadArg (index-based reads). - // These don't appear in uses() but create a dependency on - // the corresponding PHI node for that local/arg origin. - match instr.op() { - SsaOp::LoadLocal { local_index, .. } => { - let origin = VariableOrigin::Local(*local_index); - if let Some(phi_results) = origin_to_phi_results.get(&origin) { - for &phi_result in phi_results { - let idx = phi_result.index(); - if idx < variable_count { - live_phis.insert(idx); - } - } - } - } - SsaOp::LoadArg { arg_index, .. } => { - let origin = VariableOrigin::Argument(*arg_index); - if let Some(phi_results) = origin_to_phi_results.get(&origin) { - for &phi_result in phi_results { - let idx = phi_result.index(); - if idx < variable_count { - live_phis.insert(idx); - } - } - } - } - _ => {} - } - } - } - - // Phase 2: Propagate liveness through phi operands - let mut changed = true; - while changed { - changed = false; - for block in &self.blocks { - for phi in block.phi_nodes() { - let result_idx = phi.result().index(); - if result_idx < variable_count && live_phis.contains(result_idx) { - for op in phi.operands() { - let val_idx = op.value().index(); - if val_idx < variable_count - && all_phi_results.contains(val_idx) - && live_phis.insert(val_idx) - { - changed = true; - } - } - } - } - } - } - - // Phase 3: Remove dead phis (all_phi_results - live_phis) - let mut dead_phis = all_phi_results.clone(); - dead_phis.difference_with(&live_phis); - - if dead_phis.is_empty() { - return; - } - - for block in &mut self.blocks { - block.phi_nodes_mut().retain(|phi| { - let idx = phi.result().index(); - idx >= variable_count || !dead_phis.contains(idx) - }); - } - - self.variables.retain(|v| { - let idx = v.id().index(); - idx >= variable_count || !dead_phis.contains(idx) - }); - } - - /// Optimizes local variables by removing unused ones and compacting indices. - /// - /// This method: - /// 1. Identifies which local indices are actually used - /// 2. Creates a compact remapping (old index -> new index) - /// 3. Updates all `VariableOrigin::Local` references - /// 4. Updates all `SsaOp::LoadLocal` and `SsaOp::LoadLocalAddr` indices - /// 5. Updates `num_locals` to the new count - /// - /// # Returns - /// - /// A vector where `result[old_index]` contains `Some(new_index)` for used locals, - /// or `None` for unused locals. - pub fn optimize_locals(&mut self) -> Vec> { - // Phase 1: Collect all used local indices - let mut used_locals: BTreeSet = BTreeSet::new(); - - // From variables - for var in &self.variables { - if let VariableOrigin::Local(idx) = var.origin() { - used_locals.insert(idx); - } - } - - // From phi nodes - for block in &self.blocks { - for phi in block.phi_nodes() { - if let VariableOrigin::Local(idx) = phi.origin() { - used_locals.insert(idx); - } - } - } - - // From LoadLocal and LoadLocalAddr instructions - for block in &self.blocks { - for instr in block.instructions() { - match instr.op() { - SsaOp::LoadLocal { local_index, .. } - | SsaOp::LoadLocalAddr { local_index, .. } => { - used_locals.insert(*local_index); - } - _ => {} - } - } - } - - // Determine the actual range of local indices (may exceed num_locals - // when stack-originated locals have indices >= original num_locals) - let max_local_idx = - (used_locals.iter().copied().max().unwrap_or(0) as usize).saturating_add(1); - let remap_size = max_local_idx.max(self.num_locals); - - // If no optimization needed (all locals used or no locals), return identity mapping - if used_locals.len() == remap_size || remap_size == 0 { - #[allow(clippy::cast_possible_truncation)] - return (0..remap_size as u16).map(Some).collect(); - } - - // Phase 2: Build remapping (old index -> new index) - let mut remap: Vec> = vec![None; remap_size]; - let mut sorted_used: Vec = used_locals.into_iter().collect(); - sorted_used.sort_unstable(); - - for (new_idx, &old_idx) in sorted_used.iter().enumerate() { - #[allow(clippy::cast_possible_truncation)] - let new_idx_u16 = new_idx as u16; - if let Some(slot) = remap.get_mut(old_idx as usize) { - *slot = Some(new_idx_u16); - } - } - - let new_num_locals = sorted_used.len(); - - // Phase 3: Update all variable origins - for var in &mut self.variables { - if let VariableOrigin::Local(idx) = var.origin() { - if let Some(Some(new_idx)) = remap.get(idx as usize).copied() { - var.set_origin(VariableOrigin::Local(new_idx)); - } - } - } - - // Phase 4: Update phi nodes - for block in &mut self.blocks { - for phi in block.phi_nodes_mut() { - if let VariableOrigin::Local(idx) = phi.origin() { - if let Some(Some(new_idx)) = remap.get(idx as usize).copied() { - phi.set_origin(VariableOrigin::Local(new_idx)); - } - } - } - } - - // Phase 5: Update LoadLocal and LoadLocalAddr instructions - for block in &mut self.blocks { - for instr in block.instructions_mut() { - match instr.op_mut() { - SsaOp::LoadLocal { local_index, .. } - | SsaOp::LoadLocalAddr { local_index, .. } => { - if let Some(Some(new_idx)) = remap.get(*local_index as usize).copied() { - *local_index = new_idx; - } - } - _ => {} - } - } - } - - // Phase 6: Update num_locals - self.num_locals = new_num_locals; - - remap - } - - /// Generates a local variable signature from the SSA variable types. - /// - /// This creates a signature based on the types of locals in the SSA, combining - /// information from multiple sources in order of priority: - /// - /// 1. **Original types from CilObject** - preserved from source assembly - /// 2. **Temporary types map** - for codegen-allocated locals - /// 3. **SSA inference** - from SSA variables with `VariableOrigin::Local` - /// - /// # Errors - /// - /// Returns an error if type information is missing for any local variable. - pub fn generate_local_signature( - &self, - override_count: Option, - temporary_types: Option<&BTreeMap>, - ) -> crate::Result { - // Use empty map if none provided - let empty_temps = BTreeMap::new(); - let temp_types = temporary_types.unwrap_or(&empty_temps); - - // Use override count if provided, otherwise use the SSA's num_locals - let local_count = override_count.map_or(self.num_locals, |c| c as usize); - - // If we have original local types (from CilObject), use them as the base - if let Some(original_types) = &self.original_local_types { - let mut locals: Vec = Vec::with_capacity(local_count); - - for (idx, orig) in original_types.iter().enumerate() { - if idx >= local_count { - break; - } - locals.push(orig.clone()); - } - - // For any additional locals (temporaries allocated by codegen) - for idx in original_types.len()..local_count { - #[allow(clippy::cast_possible_truncation)] - let idx_u16 = idx as u16; - let local_type = temp_types - .get(&idx_u16) - .cloned() - .or_else(|| self.infer_local_type(idx)) - .ok_or_else(|| { - Error::CodegenFailed(format!( - "no type information for local {idx} \ - (original_num_locals={}, num_locals={})", - self.original_num_locals, self.num_locals, - )) - })?; - locals.push(SignatureLocalVariable { - modifiers: CustomModifiers::default(), - is_pinned: false, - is_byref: false, - base: local_type.to_type_signature(), - }); - } - - return Ok(SignatureLocalVariables { locals }); - } - - // No original types - fall back to inference for all locals - let mut local_types: Vec> = vec![None; local_count]; - - // First, populate with any provided temporary types (highest priority for temps) - for (idx, typ) in temp_types { - let idx = *idx as usize; - if let Some(slot) = local_types.get_mut(idx) { - *slot = Some(typ.clone()); - } - } - - // Get type from SSA variables with Local origin - for var in &self.variables { - if let VariableOrigin::Local(idx) = var.origin() { - let idx = idx as usize; - if let Some(slot) = local_types.get_mut(idx) { - if slot.is_none() { - let var_type = var.var_type(); - if !var_type.is_unknown() { - *slot = Some(var_type.clone()); - } - } - } - } - } - - // Also check phi nodes for type information - for block in &self.blocks { - for phi in block.phi_nodes() { - if let VariableOrigin::Local(idx) = phi.origin() { - let idx = idx as usize; - if let Some(slot) = local_types.get_mut(idx) { - if slot.is_none() { - if let Some(var) = self.variable(phi.result()) { - let var_type = var.var_type(); - if !var_type.is_unknown() { - *slot = Some(var_type.clone()); - } - } - } - } - } - } - } - - // Build the signature locals - let mut locals: Vec = Vec::with_capacity(local_types.len()); - for (idx, opt_type) in local_types.into_iter().enumerate() { - let base_type = opt_type.ok_or_else(|| { - Error::CodegenFailed(format!( - "no type information for local {idx} \ - (original_num_locals={}, num_locals={})", - self.original_num_locals, self.num_locals, - )) - })?; - locals.push(SignatureLocalVariable { - modifiers: CustomModifiers::default(), - is_pinned: false, - is_byref: false, - base: base_type.to_type_signature(), - }); - } - - Ok(SignatureLocalVariables { locals }) - } - - /// Infers the type for a local variable from SSA information. - /// - /// Returns `None` if no type information is available for the given local index. - fn infer_local_type(&self, local_idx: usize) -> Option { - // Try to find type from variables with this Local origin - for var in &self.variables { - if let VariableOrigin::Local(idx) = var.origin() { - if idx as usize == local_idx { - let var_type = var.var_type(); - if !var_type.is_unknown() { - return Some(var_type.clone()); - } - } - } - } - - // Try phi nodes - for block in &self.blocks { - for phi in block.phi_nodes() { - if let VariableOrigin::Local(idx) = phi.origin() { - if idx as usize == local_idx { - if let Some(var) = self.variable(phi.result()) { - let var_type = var.var_type(); - if !var_type.is_unknown() { - return Some(var_type.clone()); - } - } - } - } - } - } - - None - } - - /// Shrinks `num_locals` to the actual maximum local index in use. - /// - /// After `compact_variables()` removes unused variables, `num_locals` may - /// exceed the actual maximum local index referenced. This scans all - /// `VariableOrigin::Local(idx)` references (variables, phi nodes, and - /// `LoadLocal`/`LoadLocalAddr` instructions) to find the true maximum, then - /// sets `num_locals = max(max_used + 1, original_num_locals)`. - /// - /// The `original_num_locals` floor ensures we never drop below the method's - /// declared local count (those locals have default-initialization semantics). - pub fn shrink_num_locals(&mut self) { - let mut max_local_idx: Option = None; - - // From variables - for var in &self.variables { - if let VariableOrigin::Local(idx) = var.origin() { - max_local_idx = Some(max_local_idx.map_or(idx, |cur| cur.max(idx))); - } - } - - // From phi nodes - for block in &self.blocks { - for phi in block.phi_nodes() { - if let VariableOrigin::Local(idx) = phi.origin() { - max_local_idx = Some(max_local_idx.map_or(idx, |cur| cur.max(idx))); - } - } - } - - // From LoadLocal and LoadLocalAddr instructions - for block in &self.blocks { - for instr in block.instructions() { - match instr.op() { - SsaOp::LoadLocal { local_index, .. } - | SsaOp::LoadLocalAddr { local_index, .. } => { - max_local_idx = - Some(max_local_idx.map_or(*local_index, |cur| cur.max(*local_index))); - } - _ => {} - } - } - } - - let needed = max_local_idx.map_or(0, |idx| (idx as usize).saturating_add(1)); - self.num_locals = needed.max(self.original_num_locals); - } -} diff --git a/dotscope/src/analysis/ssa/instruction.rs b/dotscope/src/analysis/ssa/instruction.rs deleted file mode 100644 index b83f2ecc..00000000 --- a/dotscope/src/analysis/ssa/instruction.rs +++ /dev/null @@ -1,425 +0,0 @@ -//! SSA-form instructions with explicit def/use information. -//! -//! This module provides the SSA representation of CIL instructions. Unlike -//! stack-based CIL where operands are implicit on the evaluation stack, -//! SSA instructions have explicit operands (uses) and results (defs). -//! -//! # Design -//! -//! Each SSA instruction contains: -//! -//! - **Original**: The original CIL instruction (for debugging/display) -//! - **Op**: The decomposed SSA operation in `result = op(operands)` form -//! -//! The `SsaOp` is the primary representation for analysis passes, while -//! the original CIL instruction is retained for debugging and to maintain -//! the connection to source locations. -//! -//! This explicit representation enables: -//! - Direct construction of def-use chains -//! - Easy identification of dead code (def with no uses) -//! - Straightforward data flow analysis -//! - Pattern matching on decomposed operations -//! -//! # Thread Safety -//! -//! All types in this module are `Send` and `Sync`. - -use std::fmt; - -use crate::{ - analysis::ssa::{SsaOp, SsaType, SsaVarId}, - assembly::{FlowType, Instruction, InstructionCategory, Operand, StackBehavior}, -}; - -/// An instruction in SSA form with explicit operands. -/// -/// This contains both the original CIL instruction (for debugging) and a -/// decomposed SSA operation for analysis. The `SsaOp` provides a clean -/// `result = op(operands)` form suitable for optimization passes. -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::analysis::{SsaInstruction, SsaOp, SsaVarId}; -/// use dotscope::assembly::Instruction; -/// -/// // An add instruction: result = left + right -/// let left = SsaVarId::from_index(0); -/// let right = SsaVarId::from_index(1); -/// let result = SsaVarId::from_index(2); -/// let instr = SsaInstruction::new( -/// cil_instr, -/// SsaOp::Add { dest: result, left, right }, -/// ); -/// ``` -#[derive(Debug, Clone)] -pub struct SsaInstruction { - /// The original CIL instruction (retained for debugging and source mapping). - original: Instruction, - - /// The decomposed SSA operation. - /// - /// This is the authoritative representation used by analysis passes. It provides - /// a clean `result = op(operands)` form where all data dependencies are explicit. - op: SsaOp, - - /// Resolved result type from the converter's TypeContext. - /// - /// This captures the precise type information available during initial SSA - /// construction (when the full assembly metadata is available). It survives - /// through deobfuscation transforms and is used by rebuild and codegen to - /// recover types that cannot be inferred structurally from the op alone - /// (e.g., Call return types, LoadField types, LoadArg/LoadLocal types). - result_type: Option, -} - -impl SsaInstruction { - /// Creates a new SSA instruction with a decomposed operation. - /// - /// # Arguments - /// - /// * `original` - The original CIL instruction - /// * `op` - The decomposed SSA operation - #[must_use] - pub fn new(original: Instruction, op: SsaOp) -> Self { - Self { - original, - op, - result_type: None, - } - } - - /// Creates an SSA instruction with only a decomposed operation (no CIL instruction). - /// - /// This is useful for synthetic instructions like phi nodes that don't - /// correspond to any CIL instruction. - #[must_use] - pub fn synthetic(op: SsaOp) -> Self { - // Create a dummy instruction for synthetic ops - let dummy = Instruction { - rva: 0, - offset: 0, - size: 0, - opcode: 0, - prefix: 0, - mnemonic: "synthetic", - category: InstructionCategory::Misc, - flow_type: FlowType::Sequential, - operand: Operand::None, - stack_behavior: StackBehavior { - pops: 0, - pushes: 0, - net_effect: 0, - }, - branch_targets: vec![], - }; - Self { - original: dummy, - op, - result_type: None, - } - } - - /// Returns a reference to the original CIL instruction. - #[must_use] - pub const fn original(&self) -> &Instruction { - &self.original - } - - /// Returns the decomposed SSA operation. - #[must_use] - pub const fn op(&self) -> &SsaOp { - &self.op - } - - /// Returns a mutable reference to the decomposed SSA operation. - pub fn op_mut(&mut self) -> &mut SsaOp { - &mut self.op - } - - /// Sets the decomposed SSA operation. - /// - /// Clears `result_type` because the new op may have a different result type. - /// Callers that know the type should call `set_result_type()` afterwards. - pub fn set_op(&mut self, op: SsaOp) { - self.op = op; - self.result_type = None; - } - - /// Returns the resolved result type, if set during SSA construction. - #[must_use] - pub fn result_type(&self) -> Option<&SsaType> { - self.result_type.as_ref() - } - - /// Sets the resolved result type. - pub fn set_result_type(&mut self, ty: Option) { - self.result_type = ty; - } - - /// Builder pattern: sets the result type and returns self. - #[must_use] - pub fn with_result_type(mut self, ty: SsaType) -> Self { - self.result_type = Some(ty); - self - } - - /// Returns `true` if this instruction is a terminator. - /// - /// Terminators are instructions that end a basic block (jumps, branches, returns, throws). - #[must_use] - pub fn is_terminator(&self) -> bool { - self.op.is_terminator() - } - - /// Returns `true` if this instruction may throw an exception. - #[must_use] - pub fn may_throw(&self) -> bool { - self.op.may_throw() - } - - /// Returns `true` if this instruction is pure (has no side effects). - /// - /// Pure instructions can be eliminated if their result is unused. - #[must_use] - pub fn is_pure(&self) -> bool { - self.op.is_pure() - } - - /// Returns the SSA variables used (read) by this instruction. - #[must_use] - pub fn uses(&self) -> Vec { - self.op.uses() - } - - /// Returns the SSA variable defined by this instruction, if any. - #[must_use] - pub fn def(&self) -> Option { - self.op.dest() - } - - /// Returns `true` if this instruction defines a value. - #[must_use] - pub fn has_def(&self) -> bool { - self.op.dest().is_some() - } - - /// Returns `true` if this instruction has no uses. - #[must_use] - pub fn has_no_uses(&self) -> bool { - self.op.uses().is_empty() - } - - /// Returns the instruction's mnemonic. - #[must_use] - pub fn mnemonic(&self) -> &'static str { - self.original.mnemonic - } - - /// Returns the instruction's RVA. - #[must_use] - pub const fn rva(&self) -> u64 { - self.original.rva - } - - /// Returns all SSA variables referenced by this instruction. - /// - /// This includes both uses and the def (if present). - #[must_use] - pub fn all_variables(&self) -> Vec { - let mut vars = self.op.uses(); - if let Some(def) = self.op.dest() { - vars.push(def); - } - vars - } -} - -impl fmt::Display for SsaInstruction { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}", self.op) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use crate::{ - analysis::ssa::{value::ConstValue, SsaOp, SsaVarId}, - assembly::{FlowType, Instruction, InstructionCategory, Operand, StackBehavior}, - }; - - fn make_test_instruction(mnemonic: &'static str, pops: u8, pushes: u8) -> Instruction { - Instruction { - rva: 0x1000, - offset: 0, - size: 1, - opcode: 0x58, // add - prefix: 0, - mnemonic, - category: InstructionCategory::Arithmetic, - flow_type: FlowType::Sequential, - operand: Operand::None, - stack_behavior: StackBehavior { - pops, - pushes, - net_effect: pushes as i8 - pops as i8, - }, - branch_targets: vec![], - } - } - - #[test] - fn test_ssa_instruction_new() { - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let v2 = SsaVarId::from_index(2); - let cil = make_test_instruction("add", 2, 1); - let op = SsaOp::Add { - dest: v2, - left: v0, - right: v1, - }; - let instr = SsaInstruction::new(cil, op); - - assert_eq!(instr.uses().len(), 2); - assert_eq!(instr.def(), Some(v2)); - assert!(instr.has_def()); - assert!(!instr.has_no_uses()); - } - - #[test] - fn test_ssa_instruction_uses() { - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let v2 = SsaVarId::from_index(2); - let cil = make_test_instruction("add", 2, 1); - let op = SsaOp::Add { - dest: v2, - left: v0, - right: v1, - }; - let instr = SsaInstruction::new(cil, op); - - let uses = instr.uses(); - assert_eq!(uses.len(), 2); - assert!(uses.contains(&v0)); - assert!(uses.contains(&v1)); - } - - #[test] - fn test_ssa_instruction_all_variables() { - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let v2 = SsaVarId::from_index(2); - let cil = make_test_instruction("add", 2, 1); - let op = SsaOp::Add { - dest: v2, - left: v0, - right: v1, - }; - let instr = SsaInstruction::new(cil, op); - - let vars = instr.all_variables(); - assert_eq!(vars.len(), 3); - assert!(vars.contains(&v0)); - assert!(vars.contains(&v1)); - assert!(vars.contains(&v2)); - } - - #[test] - fn test_ssa_instruction_all_variables_no_def() { - let v = SsaVarId::from_index(0); - let cil = make_test_instruction("pop", 1, 0); - let op = SsaOp::Pop { value: v }; - let instr = SsaInstruction::new(cil, op); - - let vars = instr.all_variables(); - assert_eq!(vars.len(), 1); - assert!(vars.contains(&v)); - } - - #[test] - fn test_ssa_instruction_display() { - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let v2 = SsaVarId::from_index(2); - let cil = make_test_instruction("add", 2, 1); - let op = SsaOp::Add { - dest: v2, - left: v0, - right: v1, - }; - let instr = SsaInstruction::new(cil, op); - - assert_eq!(format!("{instr}"), "v2 = add v0, v1"); - } - - #[test] - fn test_ssa_instruction_display_no_def() { - let v = SsaVarId::from_index(5); - let cil = make_test_instruction("pop", 1, 0); - let op = SsaOp::Pop { value: v }; - let instr = SsaInstruction::new(cil, op); - - assert_eq!(format!("{instr}"), "pop v5"); - } - - #[test] - fn test_ssa_instruction_display_const() { - let v = SsaVarId::from_index(3); - let cil = make_test_instruction("ldc.i4", 0, 1); - let op = SsaOp::Const { - dest: v, - value: ConstValue::I32(42), - }; - let instr = SsaInstruction::new(cil, op); - - assert_eq!(format!("{instr}"), "v3 = 42"); - } - - #[test] - fn test_ssa_instruction_synthetic() { - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let v2 = SsaVarId::from_index(2); - let op = SsaOp::Add { - dest: v2, - left: v0, - right: v1, - }; - let instr = SsaInstruction::synthetic(op); - - assert_eq!(instr.uses().len(), 2); - assert_eq!(instr.def(), Some(v2)); - assert_eq!(instr.mnemonic(), "synthetic"); - } - - #[test] - fn test_ssa_instruction_set_op() { - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let v2 = SsaVarId::from_index(2); - let v3 = SsaVarId::from_index(3); - let cil = make_test_instruction("add", 2, 1); - let op = SsaOp::Add { - dest: v2, - left: v0, - right: v1, - }; - let mut instr = SsaInstruction::new(cil, op); - - // Replace with a different operation - let new_op = SsaOp::Sub { - dest: v3, - left: v0, - right: v1, - }; - instr.set_op(new_op); - - assert_eq!(instr.def(), Some(v3)); - assert!(matches!(instr.op(), SsaOp::Sub { .. })); - } -} diff --git a/dotscope/src/analysis/ssa/liveness.rs b/dotscope/src/analysis/ssa/liveness.rs deleted file mode 100644 index a5661af2..00000000 --- a/dotscope/src/analysis/ssa/liveness.rs +++ /dev/null @@ -1,241 +0,0 @@ -//! Liveness analysis for pruned SSA phi placement. -//! -//! Computes live-in blocks for each variable group using backward dataflow. -//! A variable is live-in to block B if: -//! - B contains a use of the variable, OR -//! - B has a successor where the variable is live-in AND B does not define it -//! -//! This information is used to prune phi placement: phi nodes are only placed -//! at dominance frontier blocks where the variable is actually live, avoiding -//! dead-on-arrival phi nodes that DCE would need to clean up. -//! -//! Based on LLVM's mem2reg `ComputeLiveInBlocks()` approach. - -use std::collections::BTreeMap; - -use crate::utils::BitSet; - -/// Computes live-in blocks for each variable group. -/// -/// Given the definition sites and use sites for each group ID, computes which -/// blocks each variable is live-in to using backward dataflow. -/// -/// # Arguments -/// * `defs` - For each group ID, the set of blocks that contain a definition -/// * `uses` - For each group ID, the set of blocks that contain a use -/// * `successors` - CFG successors for each block index -/// * `block_count` - Total number of blocks in the CFG -/// -/// # Returns -/// For each group ID, the set of blocks where the variable is live-in. -pub fn compute_live_in_blocks( - defs: &BTreeMap, - uses: &BTreeMap, - successors: &[Vec], - block_count: usize, -) -> BTreeMap { - let mut live_in: BTreeMap = BTreeMap::new(); - - // Pre-compute predecessors from successors (avoids recomputation per group) - let mut predecessors: Vec> = vec![Vec::new(); block_count]; - for (block_idx, succs) in successors.iter().enumerate() { - for &succ in succs { - if let Some(preds) = predecessors.get_mut(succ) { - preds.push(block_idx); - } - } - } - - // For each group that has both defs and uses, compute liveness - for (group, def_blocks) in defs { - let use_blocks = match uses.get(group) { - Some(blocks) => blocks, - None => continue, // No uses → variable is dead everywhere → no phis needed - }; - - // Backward dataflow: start from use blocks, propagate backward - let mut live_in_set = BitSet::new(block_count); - let mut worklist: Vec = Vec::new(); - - // Seed: blocks that contain a use and don't define the variable before the use. - // For simplicity (and matching LLVM's approach), we treat a block as a use block - // if it contains any use, regardless of whether it also defines the variable. - // The key insight: if a block both defines and uses a variable, the use might - // refer to a previous definition from outside the block. - // - // Conservative approach: if a block uses the variable but also defines it, - // it's only live-in if the use comes before the def. Since we don't track - // instruction ordering here, we conservatively mark use-only blocks as live-in - // and blocks that both use and define as live-in (slightly over-approximate - // but safe — may place a few extra phis that trivial phi elimination removes). - for use_block in use_blocks.iter() { - if live_in_set.insert(use_block) { - worklist.push(use_block); - } - } - - // Propagate backward: variable is live-in to predecessor if it's live-in to - // a successor and the predecessor doesn't define it (or it's live-in to the - // predecessor due to a direct use). - while let Some(block_idx) = worklist.pop() { - let Some(preds) = predecessors.get(block_idx) else { - continue; - }; - for &pred in preds { - // If predecessor defines the variable, liveness doesn't propagate further - // (the definition satisfies the use). But the predecessor itself is NOT - // live-in for this variable (the def originates here). - if def_blocks.contains(pred) { - continue; - } - if live_in_set.insert(pred) { - worklist.push(pred); - } - } - } - - live_in.insert(*group, live_in_set); - } - - live_in -} - -#[cfg(test)] -mod tests { - use std::collections::BTreeMap; - - use crate::utils::BitSet; - - use super::compute_live_in_blocks; - - fn bitset_from(cap: usize, indices: &[usize]) -> BitSet { - let mut bs = BitSet::new(cap); - for &i in indices { - bs.insert(i); - } - bs - } - - /// Diamond CFG: - /// 0 → 1, 2 - /// 1 → 3 - /// 2 → 3 - /// - /// def at 0, use at 3 → live through 1 and 2 - #[test] - fn test_diamond_liveness() { - let mut defs = BTreeMap::new(); - let mut uses = BTreeMap::new(); - let group: u32 = 0; - - defs.insert(group, bitset_from(4, &[0])); - uses.insert(group, bitset_from(4, &[3])); - - let successors = vec![ - vec![1, 2], // block 0 - vec![3], // block 1 - vec![3], // block 2 - vec![], // block 3 - ]; - - let live_in = compute_live_in_blocks(&defs, &uses, &successors, 4); - - let live = live_in.get(&group).unwrap(); - assert!(live.contains(3), "use block should be live-in"); - assert!(live.contains(1), "path block 1 should be live-in"); - assert!(live.contains(2), "path block 2 should be live-in"); - assert!(!live.contains(0), "def block should NOT be live-in"); - } - - /// Loop CFG: - /// 0 → 1 - /// 1 → 2, 3 - /// 2 → 1 - /// 3 (exit) - /// - /// def at 0 and 2, use at 1 → live in 1 and 2 - #[test] - fn test_loop_liveness() { - let mut defs = BTreeMap::new(); - let mut uses = BTreeMap::new(); - let group: u32 = 0; - - defs.insert(group, bitset_from(4, &[0, 2])); - uses.insert(group, bitset_from(4, &[1])); - - let successors = vec![ - vec![1], // block 0 - vec![2, 3], // block 1 (loop header) - vec![1], // block 2 (loop body) - vec![], // block 3 (exit) - ]; - - let live_in = compute_live_in_blocks(&defs, &uses, &successors, 4); - - let live = live_in.get(&group).unwrap(); - assert!(live.contains(1), "use/header block should be live-in"); - assert!(!live.contains(0), "def block 0 should NOT be live-in"); - assert!(!live.contains(2), "def block 2 should NOT be live-in"); - } - - /// No uses → no liveness entry at all - #[test] - fn test_no_uses() { - let mut defs = BTreeMap::new(); - let uses = BTreeMap::new(); - let group: u32 = 0; - - defs.insert(group, bitset_from(2, &[0])); - - let successors = vec![vec![1], vec![]]; - - let live_in = compute_live_in_blocks(&defs, &uses, &successors, 2); - - assert!( - !live_in.contains_key(&group), - "no uses means no liveness entry" - ); - } - - /// Nested if-else: - /// 0 → 1, 2 - /// 1 → 3, 4 - /// 2 → 5 - /// 3 → 5 - /// 4 → 5 - /// - /// def at 1 and 2, use at 5 - #[test] - fn test_nested_if_liveness() { - let mut defs = BTreeMap::new(); - let mut uses = BTreeMap::new(); - let group: u32 = 0; - - defs.insert(group, bitset_from(6, &[1, 2])); - uses.insert(group, bitset_from(6, &[5])); - - let successors = vec![ - vec![1, 2], // block 0 - vec![3, 4], // block 1 - vec![5], // block 2 - vec![5], // block 3 - vec![5], // block 4 - vec![], // block 5 - ]; - - let live_in = compute_live_in_blocks(&defs, &uses, &successors, 6); - - let live = live_in.get(&group).unwrap(); - assert!(live.contains(5), "use block should be live-in"); - assert!( - live.contains(3), - "block 3 should be live-in (no def, on path)" - ); - assert!( - live.contains(4), - "block 4 should be live-in (no def, on path)" - ); - assert!(!live.contains(1), "def block 1 should NOT be live-in"); - assert!(!live.contains(2), "def block 2 should NOT be live-in"); - } -} diff --git a/dotscope/src/analysis/ssa/memory.rs b/dotscope/src/analysis/ssa/memory.rs deleted file mode 100644 index 9861e7cd..00000000 --- a/dotscope/src/analysis/ssa/memory.rs +++ /dev/null @@ -1,1126 +0,0 @@ -//! Memory SSA (MSSA) for tracking versioned memory locations. -//! -//! This module extends SSA to track state stored in fields, arrays, and heap locations. -//! Memory SSA is essential for precise analysis of obfuscated code that stores state -//! in memory rather than local variables. -//! -//! # Architecture -//! -//! Memory SSA builds on top of traditional SSA by: -//! -//! 1. **Memory Locations**: Abstract representation of memory (fields, arrays, pointers) -//! 2. **Memory Versioning**: Each store creates a new version, each load reads a version -//! 3. **Memory Phi Nodes**: At control flow merges, memory versions are merged -//! -//! ```text -//! Traditional SSA: Memory SSA: -//! -//! v1 = x v1 = x -//! obj.field = v1 mem[obj.field]₁ = v1 -//! ... ... -//! v2 = obj.field v2 = mem[obj.field]₁ -//! ``` -//! -//! # Memory Location Hierarchy -//! -//! Memory locations form a hierarchy for alias analysis: -//! -//! ```text -//! Unknown (may alias anything) -//! ├── StaticField(token) - Specific static field -//! ├── InstanceField(obj, token) - Specific instance field -//! ├── ArrayElement(arr, idx) - Specific array element -//! │ ├── ArrayElement(arr, Constant(i)) - Known index -//! │ └── ArrayElement(arr, Variable(v)) - Unknown index (may alias) -//! └── Indirect(addr) - Pointer dereference -//! ``` -//! -//! # Usage -//! -//! ```rust,ignore -//! use dotscope::analysis::{MemorySsa, SsaFunction, SsaCfg}; -//! -//! let cfg = SsaCfg::from_ssa(&ssa); -//! let mem_ssa = MemorySsa::build(&ssa, &cfg); -//! -//! // Query memory version at a specific point -//! let loc = MemoryLocation::StaticField(field_token); -//! if let Some(version) = mem_ssa.version_at_block(&loc, block_idx) { -//! println!("Memory version: {}", version); -//! } -//! ``` -//! -//! # References -//! -//! - Chow et al., "Effective Representation of Aliases and Indirect Memory -//! Operations in SSA Form", CC 1996 - -use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet, VecDeque}; - -use crate::{ - analysis::ssa::{FieldRef, SsaCfg, SsaFunction, SsaOp, SsaVarId}, - utils::graph::{ - algorithms::{compute_dominance_frontiers, compute_dominators}, - GraphBase, NodeId, RootedGraph, Successors, - }, -}; - -/// Represents an abstract memory location. -/// -/// Memory locations are used to track which memory is being accessed by -/// load/store operations. The granularity varies by location type: -/// -/// - Static fields are precise (one location per field) -/// - Instance fields depend on object identity (may alias if objects may alias) -/// - Array elements depend on both array identity and index -#[derive(Debug, Clone, Hash, PartialEq, Eq)] -pub enum MemoryLocation { - /// Instance field access: `object.field` - /// - /// The `SsaVarId` identifies the object, and `FieldRef` identifies the field. - /// Two instance field locations may alias if the objects may alias. - InstanceField(SsaVarId, FieldRef), - - /// Static field access: `ClassName.field` - /// - /// Static fields are uniquely identified by their token. Two static field - /// locations alias iff they have the same token. - StaticField(FieldRef), - - /// Array element access: `array[index]` - /// - /// The `SsaVarId` identifies the array, and `ArrayIndex` identifies the index. - /// Array element locations may alias based on array identity and index overlap. - ArrayElement(SsaVarId, ArrayIndex), - - /// Indirect memory access through a pointer: `*ptr` - /// - /// The `SsaVarId` is the pointer variable. Indirect accesses are the most - /// conservative - they may alias anything the pointer could point to. - Indirect(SsaVarId), - - /// Unknown/escaped memory. - /// - /// Used when we can't determine the exact location (e.g., after a call - /// that may modify memory, or for volatile accesses). - Unknown, -} - -impl MemoryLocation { - /// Returns the base object variable, if any. - /// - /// For instance fields and arrays, this is the object/array variable. - /// For static fields and unknown locations, returns `None`. - #[must_use] - pub fn base_object(&self) -> Option { - match self { - Self::InstanceField(obj, _) => Some(*obj), - Self::ArrayElement(arr, _) => Some(*arr), - Self::Indirect(ptr) => Some(*ptr), - Self::StaticField(_) | Self::Unknown => None, - } - } - - /// Returns `true` if this location may alias the other location. - /// - /// This is a conservative analysis - if we can't prove non-aliasing, - /// we assume aliasing is possible. - #[must_use] - pub fn may_alias(&self, other: &Self) -> bool { - match (self, other) { - // Unknown aliases everything; Indirect may alias any concrete location - (Self::Unknown, _) - | (_, Self::Unknown) - | ( - Self::Indirect(_), - Self::InstanceField(..) | Self::ArrayElement(..) | Self::StaticField(_), - ) - | ( - Self::InstanceField(..) | Self::ArrayElement(..) | Self::StaticField(_), - Self::Indirect(_), - ) => true, - - // Static fields alias iff same field - (Self::StaticField(f1), Self::StaticField(f2)) => f1 == f2, - - // Static fields don't alias instance fields or arrays; - // Instance fields don't alias array elements (different memory types) - (Self::StaticField(_), Self::InstanceField(..) | Self::ArrayElement(..)) - | (Self::InstanceField(..) | Self::ArrayElement(..), Self::StaticField(_)) - | (Self::InstanceField(..), Self::ArrayElement(..)) - | (Self::ArrayElement(..), Self::InstanceField(..)) => false, - - // Instance fields alias if same object AND same field - // Conservative: different objects assumed to not alias - (Self::InstanceField(obj1, f1), Self::InstanceField(obj2, f2)) => { - obj1 == obj2 && f1 == f2 - } - - // Array elements alias if same array AND indices may overlap - (Self::ArrayElement(arr1, idx1), Self::ArrayElement(arr2, idx2)) => { - arr1 == arr2 && idx1.may_overlap(idx2) - } - - // Indirect access may alias anything with same pointer - (Self::Indirect(p1), Self::Indirect(p2)) => p1 == p2, - } - } - - /// Returns `true` if this location must alias the other location. - /// - /// This is a more precise analysis - returns `true` only if we can - /// prove the locations definitely refer to the same memory. - #[must_use] - pub fn must_alias(&self, other: &Self) -> bool { - match (self, other) { - // Static fields must-alias iff same field - (Self::StaticField(f1), Self::StaticField(f2)) => f1 == f2, - - // Instance fields must-alias iff same object AND same field - (Self::InstanceField(obj1, f1), Self::InstanceField(obj2, f2)) => { - obj1 == obj2 && f1 == f2 - } - - // Array elements must-alias iff same array AND same constant index - (Self::ArrayElement(arr1, idx1), Self::ArrayElement(arr2, idx2)) => { - arr1 == arr2 && idx1.must_equal(idx2) - } - - // Indirect must-alias iff same pointer - (Self::Indirect(p1), Self::Indirect(p2)) => p1 == p2, - - // Unknown never must-aliases (not precise enough) - _ => false, - } - } -} - -/// Represents an array index for array element locations. -#[derive(Debug, Clone, Hash, PartialEq, Eq)] -pub enum ArrayIndex { - /// A constant index value. - Constant(i64), - /// A variable index. - Variable(SsaVarId), - /// Unknown index (could be any value). - Unknown, -} - -impl ArrayIndex { - /// Returns `true` if these indices may refer to the same element. - #[must_use] - pub fn may_overlap(&self, other: &Self) -> bool { - match (self, other) { - // Unknown overlaps everything; Variable indices may overlap (conservative) - (Self::Unknown | Self::Variable(_), _) | (_, Self::Unknown | Self::Variable(_)) => true, - // Constants overlap iff equal - (Self::Constant(i1), Self::Constant(i2)) => i1 == i2, - } - } - - /// Returns `true` if these indices must refer to the same element. - #[must_use] - pub fn must_equal(&self, other: &Self) -> bool { - match (self, other) { - (Self::Constant(i1), Self::Constant(i2)) => i1 == i2, - (Self::Variable(v1), Self::Variable(v2)) => v1 == v2, - _ => false, - } - } -} - -/// A memory operation (load or store). -#[derive(Debug, Clone)] -pub enum MemoryOp { - /// A memory load operation. - Load { - /// The memory location being loaded. - location: MemoryLocation, - /// The SSA variable receiving the loaded value. - dest: SsaVarId, - /// Block containing this operation. - block: usize, - /// Instruction index within the block. - instr: usize, - }, - /// A memory store operation. - Store { - /// The memory location being stored to. - location: MemoryLocation, - /// The SSA variable being stored. - value: SsaVarId, - /// Block containing this operation. - block: usize, - /// Instruction index within the block. - instr: usize, - }, -} - -impl MemoryOp { - /// Returns the memory location accessed by this operation. - #[must_use] - pub fn location(&self) -> &MemoryLocation { - match self { - Self::Load { location, .. } | Self::Store { location, .. } => location, - } - } - - /// Returns the block index containing this operation. - #[must_use] - pub fn block(&self) -> usize { - match self { - Self::Load { block, .. } | Self::Store { block, .. } => *block, - } - } - - /// Returns the instruction index within the block. - #[must_use] - pub fn instr(&self) -> usize { - match self { - Self::Load { instr, .. } | Self::Store { instr, .. } => *instr, - } - } - - /// Returns `true` if this is a store operation. - #[must_use] - pub fn is_store(&self) -> bool { - matches!(self, Self::Store { .. }) - } - - /// Returns `true` if this is a load operation. - #[must_use] - pub fn is_load(&self) -> bool { - matches!(self, Self::Load { .. }) - } -} - -/// A phi node for memory locations. -/// -/// Memory phi nodes are placed at control flow merge points where different -/// memory versions from different predecessors need to be merged. -#[derive(Debug, Clone)] -pub struct MemoryPhi { - /// The memory location this phi node is for. - pub location: MemoryLocation, - /// The result version number produced by this phi. - pub result_version: u32, - /// The operands from each predecessor. - pub operands: Vec, -} - -impl MemoryPhi { - /// Creates a new memory phi node. - #[must_use] - pub fn new(location: MemoryLocation, result_version: u32) -> Self { - Self { - location, - result_version, - operands: Vec::new(), - } - } - - /// Adds an operand from a predecessor block. - pub fn add_operand(&mut self, predecessor: usize, version: u32) { - self.operands.push(MemoryPhiOperand { - predecessor, - version, - }); - } - - /// Returns the operand from a specific predecessor, if present. - #[must_use] - pub fn operand_from(&self, predecessor: usize) -> Option<&MemoryPhiOperand> { - self.operands - .iter() - .find(|op| op.predecessor == predecessor) - } -} - -/// An operand of a memory phi node. -#[derive(Debug, Clone)] -pub struct MemoryPhiOperand { - /// The predecessor block this operand comes from. - pub predecessor: usize, - /// The memory version from that predecessor. - pub version: u32, -} - -/// Memory version identifier. -/// -/// Combines a memory location with a version number to uniquely identify -/// a specific "value" of that memory location. -#[derive(Debug, Clone, Hash, PartialEq, Eq)] -pub struct MemoryVersion { - /// The memory location. - pub location: MemoryLocation, - /// The version number. - pub version: u32, -} - -impl MemoryVersion { - /// Creates a new memory version. - #[must_use] - pub fn new(location: MemoryLocation, version: u32) -> Self { - Self { location, version } - } -} - -/// Definition site for a memory version. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum MemoryDefSite { - /// Defined at entry (initial version). - Entry, - /// Defined by a store instruction. - Store { - /// The block containing the store. - block: usize, - /// The instruction index within the block. - instr: usize, - }, - /// Defined by a memory phi node. - Phi { - /// The block containing the phi node. - block: usize, - }, -} - -/// Memory SSA representation. -/// -/// This structure tracks versioned memory locations throughout a function, -/// enabling precise tracking of memory state for analysis. -#[derive(Debug)] -pub struct MemorySsa { - /// Next version number for each memory location. - next_version: HashMap, - - /// Memory phi nodes at each block. - /// Key is block index, value is list of memory phi nodes. - memory_phis: BTreeMap>, - - /// Definition sites for each memory version. - definitions: HashMap, - - /// Memory version at block entry for each location. - /// Key is (location, block), value is version. - entry_versions: HashMap<(MemoryLocation, usize), u32>, - - /// Memory version at block exit for each location. - /// Key is (location, block), value is version. - exit_versions: HashMap<(MemoryLocation, usize), u32>, - - /// All identified memory operations. - operations: Vec, - - /// All unique memory locations in the function. - locations: HashSet, -} - -impl MemorySsa { - /// Creates an empty Memory SSA structure. - #[must_use] - pub fn new() -> Self { - Self { - next_version: HashMap::new(), - memory_phis: BTreeMap::new(), - definitions: HashMap::new(), - entry_versions: HashMap::new(), - exit_versions: HashMap::new(), - operations: Vec::new(), - locations: HashSet::new(), - } - } - - /// Builds Memory SSA from an SSA function. - /// - /// This performs the full Memory SSA construction: - /// 1. Identify all memory operations - /// 2. Place memory phi nodes at dominance frontiers - /// 3. Rename memory versions using dominator tree traversal - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to analyze. - /// * `cfg` - The control flow graph of the function. - /// - /// # Returns - /// - /// A complete Memory SSA representation. - #[must_use] - pub fn build(ssa: &SsaFunction, cfg: &SsaCfg<'_>) -> Self { - let mut mem_ssa = Self::new(); - - // Phase 1: Identify all memory operations - mem_ssa.identify_memory_operations(ssa); - - // Phase 2: Place memory phi nodes - mem_ssa.place_memory_phis(cfg); - - // Phase 3: Rename memory versions - mem_ssa.rename_memory_versions(ssa, cfg); - - mem_ssa - } - - /// Returns the memory phi nodes at a block. - #[must_use] - pub fn memory_phis(&self, block: usize) -> &[MemoryPhi] { - self.memory_phis.get(&block).map_or(&[], Vec::as_slice) - } - - /// Returns all memory operations. - #[must_use] - pub fn operations(&self) -> &[MemoryOp] { - &self.operations - } - - /// Returns all unique memory locations. - #[must_use] - pub fn locations(&self) -> &HashSet { - &self.locations - } - - /// Returns the memory version at block entry for a location. - #[must_use] - pub fn version_at_entry(&self, location: &MemoryLocation, block: usize) -> Option { - self.entry_versions.get(&(location.clone(), block)).copied() - } - - /// Returns the memory version at block exit for a location. - #[must_use] - pub fn version_at_exit(&self, location: &MemoryLocation, block: usize) -> Option { - self.exit_versions.get(&(location.clone(), block)).copied() - } - - /// Returns the definition site for a memory version. - #[must_use] - pub fn definition(&self, version: &MemoryVersion) -> Option { - self.definitions.get(version).copied() - } - - /// Returns the next version number for a location (and increments it). - fn allocate_version(&mut self, location: &MemoryLocation) -> u32 { - let version = self.next_version.entry(location.clone()).or_insert(0); - let result = *version; - *version = version.saturating_add(1); - result - } - - /// Returns the current version number for a location without incrementing. - fn current_version(&self, location: &MemoryLocation) -> u32 { - self.next_version - .get(location) - .copied() - .unwrap_or(0) - .saturating_sub(1) - } - - /// Phase 1: Identify all memory operations in the SSA function. - fn identify_memory_operations(&mut self, ssa: &SsaFunction) { - for (block_idx, instr_idx, instr) in ssa.iter_instructions() { - if let Some(mem_op) = Self::classify_memory_operation(instr.op(), block_idx, instr_idx) - { - self.locations.insert(mem_op.location().clone()); - self.operations.push(mem_op); - } - } - } - - /// Classifies an SSA operation as a memory operation, if applicable. - fn classify_memory_operation(op: &SsaOp, block: usize, instr: usize) -> Option { - match op { - SsaOp::LoadField { - dest, - object, - field, - } => { - let location = MemoryLocation::InstanceField(*object, *field); - Some(MemoryOp::Load { - location, - dest: *dest, - block, - instr, - }) - } - SsaOp::StoreField { - object, - field, - value, - } => { - let location = MemoryLocation::InstanceField(*object, *field); - Some(MemoryOp::Store { - location, - value: *value, - block, - instr, - }) - } - SsaOp::LoadStaticField { dest, field } => { - let location = MemoryLocation::StaticField(*field); - Some(MemoryOp::Load { - location, - dest: *dest, - block, - instr, - }) - } - SsaOp::StoreStaticField { field, value } => { - let location = MemoryLocation::StaticField(*field); - Some(MemoryOp::Store { - location, - value: *value, - block, - instr, - }) - } - SsaOp::LoadElement { - dest, array, index, .. - } => { - let idx = Self::resolve_array_index(*index); - let location = MemoryLocation::ArrayElement(*array, idx); - Some(MemoryOp::Load { - location, - dest: *dest, - block, - instr, - }) - } - SsaOp::StoreElement { - array, - index, - value, - .. - } => { - let idx = Self::resolve_array_index(*index); - let location = MemoryLocation::ArrayElement(*array, idx); - Some(MemoryOp::Store { - location, - value: *value, - block, - instr, - }) - } - SsaOp::LoadIndirect { dest, addr, .. } => { - let location = MemoryLocation::Indirect(*addr); - Some(MemoryOp::Load { - location, - dest: *dest, - block, - instr, - }) - } - SsaOp::StoreIndirect { addr, value, .. } => { - let location = MemoryLocation::Indirect(*addr); - Some(MemoryOp::Store { - location, - value: *value, - block, - instr, - }) - } - _ => None, - } - } - - /// Resolves an array index to an `ArrayIndex` abstraction. - fn resolve_array_index(index_var: SsaVarId) -> ArrayIndex { - // For now, treat all variable indices as unknown - // Could be improved with constant propagation - ArrayIndex::Variable(index_var) - } - - /// Phase 2: Place memory phi nodes at dominance frontiers. - fn place_memory_phis(&mut self, cfg: &SsaCfg<'_>) { - let block_count = cfg.node_count(); - if block_count == 0 { - return; - } - - // Compute dominators and dominance frontiers - let dom_tree = compute_dominators(cfg, cfg.entry()); - let frontiers = compute_dominance_frontiers(cfg, &dom_tree); - - // For each memory location, find blocks that define it (stores) - let mut def_blocks: HashMap> = HashMap::new(); - for op in &self.operations { - if op.is_store() { - def_blocks - .entry(op.location().clone()) - .or_default() - .insert(op.block()); - } - } - - // Standard phi placement algorithm (iterated dominance frontier) - for (location, defs) in def_blocks { - let mut phi_blocks: BTreeSet = BTreeSet::new(); - let mut worklist: VecDeque = defs.iter().copied().collect(); - let mut processed: BTreeSet = BTreeSet::new(); - - while let Some(block) = worklist.pop_front() { - if !processed.insert(block) { - continue; - } - - let node_id = NodeId::new(block); - if node_id.index() >= frontiers.len() { - continue; - } - - let Some(frontier_set) = frontiers.get(node_id.index()) else { - continue; - }; - for frontier_block in frontier_set.iter() { - if phi_blocks.insert(frontier_block) { - // Add phi at frontier - let version = self.allocate_version(&location); - let phi = MemoryPhi::new(location.clone(), version); - self.memory_phis - .entry(frontier_block) - .or_default() - .push(phi); - self.definitions.insert( - MemoryVersion::new(location.clone(), version), - MemoryDefSite::Phi { - block: frontier_block, - }, - ); - worklist.push_back(frontier_block); - } - } - } - } - } - - /// Phase 3: Rename memory versions using dominator tree traversal. - fn rename_memory_versions(&mut self, ssa: &SsaFunction, cfg: &SsaCfg<'_>) { - let block_count = cfg.node_count(); - if block_count == 0 { - return; - } - - // Compute dominators for traversal order - let dom_tree = compute_dominators(cfg, cfg.entry()); - - // Stack of versions for each location - let mut version_stacks: HashMap> = HashMap::new(); - - // Initialize all locations with version 0 (entry version) - let locations: Vec<_> = self.locations.iter().cloned().collect(); - for location in locations { - let entry_version = self.allocate_version(&location); - version_stacks - .entry(location.clone()) - .or_default() - .push(entry_version); - self.definitions.insert( - MemoryVersion::new(location, entry_version), - MemoryDefSite::Entry, - ); - } - - // Rename in dominator tree order (preorder) - let mut visited = vec![false; block_count]; - let mut worklist = vec![cfg.entry().index()]; - - while let Some(block_idx) = worklist.pop() { - match visited.get(block_idx) { - Some(true) => continue, - None => continue, - Some(false) => {} - } - if let Some(slot) = visited.get_mut(block_idx) { - *slot = true; - } - - self.rename_block(block_idx, ssa, cfg, &mut version_stacks); - - // Add dominated blocks to worklist - for child in dom_tree.children(NodeId::new(block_idx)) { - if visited.get(child.index()).copied() == Some(false) { - worklist.push(child.index()); - } - } - } - } - - /// Renames memory versions within a single block. - fn rename_block( - &mut self, - block_idx: usize, - ssa: &SsaFunction, - cfg: &SsaCfg<'_>, - version_stacks: &mut HashMap>, - ) { - // Record entry versions - for location in self.locations.clone() { - if let Some(&version) = version_stacks.get(&location).and_then(|s| s.last()) { - self.entry_versions - .insert((location.clone(), block_idx), version); - } - } - - // Process memory phi nodes - they define new versions - if let Some(phis) = self.memory_phis.get(&block_idx).cloned() { - for phi in phis { - version_stacks - .entry(phi.location.clone()) - .or_default() - .push(phi.result_version); - } - } - - // Process instructions in the block - let Some(block) = ssa.block(block_idx) else { - return; - }; - - for (instr_idx, instr) in block.instructions().iter().enumerate() { - // Handle stores - create new version - if let Some(mem_op) = Self::classify_memory_operation(instr.op(), block_idx, instr_idx) - { - if mem_op.is_store() { - let location = mem_op.location().clone(); - let new_version = self.allocate_version(&location); - version_stacks - .entry(location.clone()) - .or_default() - .push(new_version); - self.definitions.insert( - MemoryVersion::new(location, new_version), - MemoryDefSite::Store { - block: block_idx, - instr: instr_idx, - }, - ); - } - } - } - - // Record exit versions - for location in self.locations.clone() { - if let Some(&version) = version_stacks.get(&location).and_then(|s| s.last()) { - self.exit_versions - .insert((location.clone(), block_idx), version); - } - } - - // Fill in phi operands for successors - for succ_id in cfg.successors(NodeId::new(block_idx)) { - let succ_idx = succ_id.index(); - if let Some(phis) = self.memory_phis.get_mut(&succ_idx) { - for phi in phis { - if let Some(&version) = version_stacks.get(&phi.location).and_then(|s| s.last()) - { - phi.add_operand(block_idx, version); - } - } - } - } - } - - /// Returns statistics about the Memory SSA. - #[must_use] - pub fn stats(&self) -> MemorySsaStats { - let total_phis = self.memory_phis.values().map(Vec::len).sum(); - let store_count = self.operations.iter().filter(|op| op.is_store()).count(); - let load_count = self.operations.iter().filter(|op| op.is_load()).count(); - - MemorySsaStats { - location_count: self.locations.len(), - memory_phi_count: total_phis, - store_count, - load_count, - version_count: self.definitions.len(), - } - } -} - -impl Default for MemorySsa { - fn default() -> Self { - Self::new() - } -} - -/// Statistics about Memory SSA. -#[derive(Debug, Clone, Copy)] -pub struct MemorySsaStats { - /// Number of unique memory locations tracked. - pub location_count: usize, - /// Number of memory phi nodes placed. - pub memory_phi_count: usize, - /// Number of store operations. - pub store_count: usize, - /// Number of load operations. - pub load_count: usize, - /// Total number of memory versions. - pub version_count: usize, -} - -/// Memory state tracker for path-aware evaluation. -/// -/// This tracks the memory values along a specific execution path, enabling -/// precise tracking of memory contents during symbolic or concrete evaluation. -#[derive(Debug, Clone)] -pub struct MemoryState { - /// Current memory values: location -> (version, value as SSA variable). - values: HashMap, - /// Reference to the Memory SSA for version lookups. - mem_ssa: Option>, -} - -impl MemoryState { - /// Creates a new empty memory state. - #[must_use] - pub fn new() -> Self { - Self { - values: HashMap::new(), - mem_ssa: None, - } - } - - /// Creates a memory state with a reference to Memory SSA. - #[must_use] - pub fn with_mem_ssa(mem_ssa: std::sync::Arc) -> Self { - Self { - values: HashMap::new(), - mem_ssa: Some(mem_ssa), - } - } - - /// Records a memory store. - pub fn store(&mut self, location: MemoryLocation, value: SsaVarId, version: u32) { - self.values.insert(location, (version, value)); - } - - /// Loads from a memory location. - /// - /// Returns the SSA variable holding the value, if known. - #[must_use] - pub fn load(&self, location: &MemoryLocation) -> Option { - // Direct match - if let Some((_, value)) = self.values.get(location) { - return Some(*value); - } - - // Check for aliasing locations - for (loc, (_, value)) in &self.values { - if location.must_alias(loc) { - return Some(*value); - } - } - - None - } - - /// Returns the current version for a location, if known. - #[must_use] - pub fn version(&self, location: &MemoryLocation) -> Option { - self.values.get(location).map(|(v, _)| *v) - } - - /// Checks if any stored location may alias the given location. - #[must_use] - pub fn has_may_alias(&self, location: &MemoryLocation) -> bool { - self.values.keys().any(|loc| loc.may_alias(location)) - } - - /// Clears all memory state. - pub fn clear(&mut self) { - self.values.clear(); - } - - /// Returns the number of tracked locations. - #[must_use] - pub fn len(&self) -> usize { - self.values.len() - } - - /// Returns `true` if no memory is being tracked. - #[must_use] - pub fn is_empty(&self) -> bool { - self.values.is_empty() - } -} - -impl Default for MemoryState { - fn default() -> Self { - Self::new() - } -} - -/// Alias analysis result for a pair of memory locations. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -#[allow(clippy::enum_variant_names)] -pub enum AliasResult { - /// The locations definitely do not alias. - NoAlias, - /// The locations may alias (conservative). - MayAlias, - /// The locations definitely alias (same memory). - MustAlias, -} - -/// Performs alias analysis between two memory locations. -#[must_use] -pub fn analyze_alias(loc1: &MemoryLocation, loc2: &MemoryLocation) -> AliasResult { - if loc1.must_alias(loc2) { - AliasResult::MustAlias - } else if loc1.may_alias(loc2) { - AliasResult::MayAlias - } else { - AliasResult::NoAlias - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use crate::{ - analysis::ssa::{FieldRef, SsaVarId}, - metadata::token::Token, - }; - - #[test] - fn test_memory_location_static_field_alias() { - let field1 = FieldRef::new(Token::new(0x04000001)); - let field2 = FieldRef::new(Token::new(0x04000002)); - - let loc1 = MemoryLocation::StaticField(field1); - let loc2 = MemoryLocation::StaticField(field1); - let loc3 = MemoryLocation::StaticField(field2); - - assert!(loc1.must_alias(&loc2)); - assert!(loc1.may_alias(&loc2)); - assert!(!loc1.may_alias(&loc3)); - } - - #[test] - fn test_memory_location_instance_field_alias() { - let field = FieldRef::new(Token::new(0x04000001)); - let obj1 = SsaVarId::from_index(0); - let obj2 = SsaVarId::from_index(1); - - let loc1 = MemoryLocation::InstanceField(obj1, field); - let loc2 = MemoryLocation::InstanceField(obj1, field); - let loc3 = MemoryLocation::InstanceField(obj2, field); - - assert!(loc1.must_alias(&loc2)); - assert!(loc1.may_alias(&loc2)); - assert!(!loc1.may_alias(&loc3)); // Different objects - } - - #[test] - fn test_array_index_overlap() { - let idx1 = ArrayIndex::Constant(5); - let idx2 = ArrayIndex::Constant(5); - let idx3 = ArrayIndex::Constant(10); - let idx4 = ArrayIndex::Unknown; - - assert!(idx1.may_overlap(&idx2)); - assert!(idx1.must_equal(&idx2)); - assert!(!idx1.may_overlap(&idx3)); - assert!(idx1.may_overlap(&idx4)); // Unknown overlaps everything - } - - #[test] - fn test_memory_location_array_element_alias() { - let arr = SsaVarId::from_index(0); - let idx1 = ArrayIndex::Constant(5); - let idx2 = ArrayIndex::Constant(5); - let idx3 = ArrayIndex::Constant(10); - - let loc1 = MemoryLocation::ArrayElement(arr, idx1); - let loc2 = MemoryLocation::ArrayElement(arr, idx2); - let loc3 = MemoryLocation::ArrayElement(arr, idx3); - - assert!(loc1.must_alias(&loc2)); - assert!(!loc1.may_alias(&loc3)); - } - - #[test] - fn test_memory_location_unknown_alias() { - let field = FieldRef::new(Token::new(0x04000001)); - let loc1 = MemoryLocation::Unknown; - let loc2 = MemoryLocation::StaticField(field); - - assert!(loc1.may_alias(&loc2)); // Unknown aliases everything - assert!(!loc1.must_alias(&loc2)); // But doesn't must-alias - } - - #[test] - fn test_alias_result() { - let field = FieldRef::new(Token::new(0x04000001)); - let loc1 = MemoryLocation::StaticField(field); - let loc2 = MemoryLocation::StaticField(field); - - assert_eq!(analyze_alias(&loc1, &loc2), AliasResult::MustAlias); - - let arr1 = SsaVarId::from_index(0); - let arr2 = SsaVarId::from_index(1); - let loc3 = MemoryLocation::ArrayElement(arr1, ArrayIndex::Constant(0)); - let loc4 = MemoryLocation::ArrayElement(arr2, ArrayIndex::Constant(0)); - - assert_eq!(analyze_alias(&loc3, &loc4), AliasResult::NoAlias); - } - - #[test] - fn test_memory_state() { - let mut state = MemoryState::new(); - let field = FieldRef::new(Token::new(0x04000001)); - let loc = MemoryLocation::StaticField(field); - let value = SsaVarId::from_index(0); - - state.store(loc.clone(), value, 1); - assert_eq!(state.load(&loc), Some(value)); - assert_eq!(state.version(&loc), Some(1)); - assert_eq!(state.len(), 1); - - state.clear(); - assert!(state.is_empty()); - } - - #[test] - fn test_memory_phi() { - let field = FieldRef::new(Token::new(0x04000001)); - let loc = MemoryLocation::StaticField(field); - - let mut phi = MemoryPhi::new(loc.clone(), 2); - phi.add_operand(0, 0); - phi.add_operand(1, 1); - - assert_eq!(phi.result_version, 2); - assert_eq!(phi.operands.len(), 2); - assert_eq!(phi.operand_from(0).unwrap().version, 0); - assert_eq!(phi.operand_from(1).unwrap().version, 1); - assert!(phi.operand_from(2).is_none()); - } - - #[test] - fn test_memory_op() { - let field = FieldRef::new(Token::new(0x04000001)); - let loc = MemoryLocation::StaticField(field); - let dest = SsaVarId::from_index(0); - let value = SsaVarId::from_index(1); - - let load = MemoryOp::Load { - location: loc.clone(), - dest, - block: 0, - instr: 5, - }; - assert!(load.is_load()); - assert!(!load.is_store()); - assert_eq!(load.block(), 0); - assert_eq!(load.instr(), 5); - - let store = MemoryOp::Store { - location: loc, - value, - block: 1, - instr: 3, - }; - assert!(!store.is_load()); - assert!(store.is_store()); - } -} diff --git a/dotscope/src/analysis/ssa/mod.rs b/dotscope/src/analysis/ssa/mod.rs index 3c6dff48..c58100eb 100644 --- a/dotscope/src/analysis/ssa/mod.rs +++ b/dotscope/src/analysis/ssa/mod.rs @@ -7,17 +7,21 @@ //! //! # Architecture //! -//! The SSA module is organized into focused sub-modules: -//! -//! - [`variable`] - SSA variable representation and identifiers -//! - [`phi`] - Phi node representation for control flow merges -//! - [`instruction`] - SSA-form instructions with explicit def/use chains -//! - [`block`] - SSA basic blocks containing phi nodes and instructions -//! - [`function`] - Complete SSA representation of a method -//! - [`builder`] - SSA construction algorithm (Cytron et al.) -//! - [`types`] - SSA type system for CIL types -//! - [`value`] - Value tracking for constant propagation and CSE -//! - [`ops`] - Decomposed SSA operations +//! The SSA primitives ([`SsaVarId`], [`PhiNode`], [`SsaBlock`], +//! [`SsaFunction`], [`SsaOp`]) and analyses ([`SsaCfg`], [`PhiAnalyzer`], +//! [`SsaEvaluator`], …) live in `analyssa::ir` / `analyssa::analysis`. The +//! files in this directory are CIL-side boundary code: +//! +//! - [`builder`] - SSA construction driver (Cytron et al.) for CIL +//! - [`converter`] - CIL → SSA conversion +//! - [`decompose`] - CIL instruction decomposition into SSA ops +//! - [`stack`] - CIL stack-typing simulator +//! - [`types`] - CIL type system (`SsaType`, `TypeRef`, `MethodRef`, …) +//! - [`target`] - `CilTarget` impl of `analyssa::Target` +//! - [`exception`] - CIL exception handler bridge +//! - [`value`], [`ops`] - CIL extension impls on `analyssa::ConstValue` / `analyssa::SsaOp` +//! - [`function`] - CIL-pinned `SsaFunctionCilExt`/`Semantics` extensions +//! - [`resolver`] - CIL-side value resolver for inline values //! //! # CIL to SSA Transformation //! @@ -70,51 +74,103 @@ //! Control Dependence Graph", ACM TOPLAS 1991 //! - Cooper & Torczon, "Engineering a Compiler", Chapter 9 -mod block; +// CIL-bound boundary code stays in dotscope (CIL → SSA conversion, stack +// typing, codegen extensions on analyssa types). mod builder; -mod cfg; -mod constraints; -mod consts; mod converter; mod decompose; -mod evaluator; mod exception; mod function; -mod instruction; -mod liveness; -mod memory; mod ops; -mod patterns; -mod phi; -mod phis; mod resolver; mod stack; mod symbolic; +mod target; mod types; mod value; -mod variable; -pub(crate) mod verifier; -pub use block::{ReplaceResult, SsaBlock}; +// Generic SSA primitives + analyses live in analyssa. The thin re-export +// shims that used to mediate via `mod cfg/consts/phi/...` collapsed into +// the `pub use` block below. + pub use builder::SsaFunctionBuilder; -pub use cfg::SsaCfg; -pub use consts::{evaluate_const_op, ConstEvaluator}; pub use converter::SsaConverter; -pub use evaluator::{ControlFlow, SsaEvaluator}; -pub use exception::SsaExceptionHandler; -pub use function::{MethodPurity, ReturnInfo, SsaFunction, TrivialPhiOptions}; -pub use instruction::SsaInstruction; -pub use ops::{BinaryOpKind, CmpKind, SsaOp, UnaryOpKind}; -pub use phi::{PhiNode, PhiOperand}; -pub use phis::PhiAnalyzer; +pub use exception::{SsaExceptionHandler, SsaExceptionHandlerCilExt}; +pub use function::{SsaFunctionCilExt, SsaFunctionSemanticsExt}; +pub use ops::{BinaryOpKind, CmpKind, SsaOp, SsaOpCilExt, UnaryOpKind}; + +// `SsaFunction`/`ReturnInfo`/`MethodPurity` live in `analyssa::ir::function`. +pub use analyssa::ir::function::MethodPurity; +/// CIL-defaulted alias of [`analyssa::ir::function::SsaFunction`]. +pub type SsaFunction = analyssa::ir::function::SsaFunction; +/// CIL-defaulted alias of [`analyssa::ir::function::ReturnInfo`]. +pub type ReturnInfo = analyssa::ir::function::ReturnInfo; pub use resolver::ValueResolver; pub use stack::{SimulationResult, StackSimulator, StackSlot, StackSlotSource}; #[cfg(feature = "z3")] pub use symbolic::Z3Solver; pub use symbolic::{SymbolicEvaluator, SymbolicExpr, SymbolicOp}; +pub use target::CilTarget; pub use types::{ resolve_corelib_valuetype, FieldRef, MethodRef, SsaType, TypeClass, TypeContext, TypeProvider, TypeRef, }; -pub use value::{AbstractValue, ConstValue}; -pub use variable::{DefSite, FunctionVarAllocator, SsaVarId, SsaVariable, UseSite, VariableOrigin}; +pub use value::{AbstractValue, ConstValue, ConstValueCilExt}; + +// Direct re-exports from analyssa for the now-collapsed shim files. Each line +// here used to be a one-line module file in `dotscope/src/analysis/ssa/`. +pub use analyssa::ir::phi::{PhiNode, PhiOperand}; +pub use analyssa::ir::variable::{ + DefSite, FunctionVarAllocator, SsaVarId, UseSite, VariableOrigin, +}; +pub use analyssa::Target; + +#[allow(unused_imports)] +pub use analyssa::analysis::consts::evaluate_const_op; +pub use analyssa::analysis::evaluator::ControlFlow; +pub use analyssa::analysis::phis::{place_pruned_phis, PhiAnalyzer}; + +/// CIL-defaulted alias of [`analyssa::ir::block::SsaBlock`]. +pub type SsaBlock = analyssa::ir::block::SsaBlock; +/// CIL-defaulted alias of [`analyssa::ir::instruction::SsaInstruction`]. +pub type SsaInstruction = analyssa::ir::instruction::SsaInstruction; +/// CIL-defaulted alias of [`analyssa::ir::variable::SsaVariable`]. +pub type SsaVariable = analyssa::ir::variable::SsaVariable; +/// CIL-defaulted alias of [`analyssa::analysis::SsaCfg`]. +pub type SsaCfg<'a, T = CilTarget> = analyssa::analysis::cfg::SsaCfg<'a, T>; +/// CIL-defaulted alias of [`analyssa::analysis::consts::ConstEvaluator`]. +pub type ConstEvaluator<'a, T = CilTarget> = analyssa::analysis::consts::ConstEvaluator<'a, T>; +/// CIL-defaulted alias of [`analyssa::analysis::evaluator::SsaEvaluator`]. +pub type SsaEvaluator<'a, T = CilTarget> = analyssa::analysis::evaluator::SsaEvaluator<'a, T>; +/// CIL-defaulted alias of [`analyssa::analysis::evaluator::ExecutionTrace`]. +pub type ExecutionTrace = analyssa::analysis::evaluator::ExecutionTrace; +/// CIL-defaulted alias of [`analyssa::analysis::patterns::PatternDetector`]. +pub type PatternDetector<'a, T = CilTarget> = analyssa::analysis::patterns::PatternDetector<'a, T>; +/// CIL-defaulted alias of [`analyssa::analysis::patterns::DispatcherPattern`]. +pub type DispatcherPattern = analyssa::analysis::patterns::DispatcherPattern; +/// CIL-defaulted alias of [`analyssa::analysis::patterns::SourceBlock`]. +pub type SourceBlock = analyssa::analysis::patterns::SourceBlock; +/// CIL-defaulted alias of [`analyssa::analysis::patterns::OpaquePredicatePattern`]. +pub type OpaquePredicatePattern = + analyssa::analysis::patterns::OpaquePredicatePattern; +/// CIL-defaulted alias of [`analyssa::analysis::patterns::PredicateResolution`]. +pub type PredicateResolution = analyssa::analysis::patterns::PredicateResolution; +/// CIL-defaulted alias of [`analyssa::analysis::constraints::Constraint`]. +pub type Constraint = analyssa::analysis::constraints::Constraint; +/// CIL-defaulted alias of [`analyssa::analysis::constraints::PathConstraint`]. +pub type PathConstraint = analyssa::analysis::constraints::PathConstraint; +/// CIL-defaulted alias of [`analyssa::analysis::memory::MemoryLocation`]. +pub type MemoryLocation = analyssa::analysis::memory::MemoryLocation; +/// CIL-defaulted alias of [`analyssa::analysis::memory::MemoryOp`]. +pub type MemoryOp = analyssa::analysis::memory::MemoryOp; +/// CIL-defaulted alias of [`analyssa::analysis::memory::MemoryPhi`]. +pub type MemoryPhi = analyssa::analysis::memory::MemoryPhi; +/// CIL-defaulted alias of [`analyssa::analysis::memory::MemoryVersion`]. +pub type MemoryVersion = analyssa::analysis::memory::MemoryVersion; +/// CIL-defaulted alias of [`analyssa::analysis::verifier::SsaVerifier`]. +pub type SsaVerifier<'a, T = CilTarget> = analyssa::analysis::verifier::SsaVerifier<'a, T>; + +/// Liveness analysis lifted to analyssa. Shim kept for back-compat. +pub mod liveness { + pub use analyssa::analysis::liveness::*; +} diff --git a/dotscope/src/analysis/ssa/ops.rs b/dotscope/src/analysis/ssa/ops.rs index 4b04576f..bb65f85c 100644 --- a/dotscope/src/analysis/ssa/ops.rs +++ b/dotscope/src/analysis/ssa/ops.rs @@ -1,3037 +1,62 @@ -//! Decomposed SSA operations. +//! Re-export shim — generic SSA ops live in `analyssa::ir::ops`. //! -//! This module defines `SsaOp`, the decomposed operation representation that -//! converts complex CIL instructions into simple `result = op(operands)` form. -//! -//! # Design Goals -//! -//! - **Single assignment**: Each operation produces at most one result -//! - **Explicit operands**: All data dependencies are explicit SSA variables -//! - **Pattern matching**: Enum variants enable easy pattern matching for analysis -//! - **Type safety**: Operations are typed where possible -//! -//! # Operation Categories -//! -//! - **Constants**: Load constant values -//! - **Arithmetic**: Binary and unary math operations -//! - **Bitwise**: And, or, xor, shifts -//! - **Comparison**: Equality and relational comparisons -//! - **Conversion**: Type conversions -//! - **Control flow**: Branches, jumps, returns -//! - **Memory**: Field, array, and indirect access -//! - **Objects**: Allocation, casting, boxing -//! - **Calls**: Method invocations -//! -//! # Field Documentation -//! -//! The struct fields in this module follow a consistent naming convention: -//! - `dest`: The destination SSA variable for the operation result -//! - `left`, `right`: Binary operands (left and right hand side) -//! - `operand`: Unary operand -//! - `value`: A value being stored or used -//! - `object`: The object instance for field/method operations -//! - `array`, `index`: Array and index for element operations -//! - `addr`: Address for indirect memory operations -//! - `target`, `true_target`, `false_target`: Branch targets (block indices) -//! - `unsigned`: Whether the operation treats values as unsigned -//! - `overflow_check`: Whether the operation checks for overflow - -#![allow(missing_docs)] - -use std::fmt; - -use crate::{ - analysis::ssa::{ - types::{FieldRef, MethodRef, SigRef, SsaType, TypeRef}, - value::ConstValue, - SsaVarId, - }, - metadata::token::Token, -}; - -/// Comparison kind for `BranchCmp` operations. -/// -/// Represents the comparison operator used in combined compare-and-branch -/// operations like `blt`, `beq`, etc. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum CmpKind { - /// Equal: `left == right` - Eq, - /// Not equal: `left != right` - Ne, - /// Less than: `left < right` - Lt, - /// Less than or equal: `left <= right` - Le, - /// Greater than: `left > right` - Gt, - /// Greater than or equal: `left >= right` - Ge, -} - -impl fmt::Display for CmpKind { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Eq => write!(f, "=="), - Self::Ne => write!(f, "!="), - Self::Lt => write!(f, "<"), - Self::Le => write!(f, "<="), - Self::Gt => write!(f, ">"), - Self::Ge => write!(f, ">="), - } - } -} - -/// Kind of binary operation for extracted binary op info. -/// -/// This enum categorizes all binary operations in `SsaOp` for uniform -/// handling in optimization passes. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum BinaryOpKind { - /// Addition: `left + right` - Add, - /// Addition with overflow check - AddOvf, - /// Subtraction: `left - right` - Sub, - /// Subtraction with overflow check - SubOvf, - /// Multiplication: `left * right` - Mul, - /// Multiplication with overflow check - MulOvf, - /// Division: `left / right` - Div, - /// Remainder: `left % right` - Rem, - /// Bitwise AND: `left & right` - And, - /// Bitwise OR: `left | right` - Or, - /// Bitwise XOR: `left ^ right` - Xor, - /// Shift left: `value << amount` - Shl, - /// Shift right: `value >> amount` - Shr, - /// Compare equal: `left == right` - Ceq, - /// Compare less than: `left < right` - Clt, - /// Compare greater than: `left > right` - Cgt, -} - -impl fmt::Display for BinaryOpKind { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Add => write!(f, "add"), - Self::AddOvf => write!(f, "add.ovf"), - Self::Sub => write!(f, "sub"), - Self::SubOvf => write!(f, "sub.ovf"), - Self::Mul => write!(f, "mul"), - Self::MulOvf => write!(f, "mul.ovf"), - Self::Div => write!(f, "div"), - Self::Rem => write!(f, "rem"), - Self::And => write!(f, "and"), - Self::Or => write!(f, "or"), - Self::Xor => write!(f, "xor"), - Self::Shl => write!(f, "shl"), - Self::Shr => write!(f, "shr"), - Self::Ceq => write!(f, "ceq"), - Self::Clt => write!(f, "clt"), - Self::Cgt => write!(f, "cgt"), - } - } -} - -impl BinaryOpKind { - /// Returns `true` if this operation is commutative (`a op b == b op a`). - /// - /// Commutative operations can have their operands swapped without changing - /// the result. This is useful for normalization in optimizations like GVN. - /// - /// # Commutative Operations - /// - /// - Arithmetic: `Add`, `AddOvf`, `Mul`, `MulOvf` - /// - Bitwise: `And`, `Or`, `Xor` - /// - Comparison: `Ceq` (equality is symmetric) - #[must_use] - pub const fn is_commutative(self) -> bool { - matches!( - self, - Self::Add - | Self::AddOvf - | Self::Mul - | Self::MulOvf - | Self::And - | Self::Or - | Self::Xor - | Self::Ceq - ) - } - - /// Returns `true` if this is a comparison operation. - /// - /// Comparison operations produce a boolean result (0 or 1) based on - /// comparing two operands. - #[must_use] - pub const fn is_comparison(self) -> bool { - matches!(self, Self::Ceq | Self::Clt | Self::Cgt) - } - - /// Returns the operation with swapped operand semantics, if applicable. - /// - /// For comparison operations: - /// - `Clt` (less than) becomes `Cgt` (greater than) when operands swap - /// - `Cgt` (greater than) becomes `Clt` (less than) when operands swap - /// - `Ceq` (equal) stays the same (symmetric) - /// - /// For non-comparison operations, returns `self` unchanged. - /// - /// # Example - /// - /// ```ignore - /// // a < b is equivalent to b > a - /// assert_eq!(BinaryOpKind::Clt.swapped(), BinaryOpKind::Cgt); - /// ``` - #[must_use] - pub const fn swapped(self) -> Self { - match self { - Self::Clt => Self::Cgt, - Self::Cgt => Self::Clt, - other => other, - } - } - - /// Returns `true` if signedness affects the operation's semantics. - /// - /// Operations where the `unsigned` flag changes behavior: - /// - `Div`, `Rem`: Signed vs unsigned division/remainder - /// - `Shr`: Arithmetic (signed) vs logical (unsigned) shift - /// - `Clt`, `Cgt`: Signed vs unsigned comparison - /// - /// For other operations, the unsigned flag has no effect. - #[must_use] - pub const fn is_signedness_sensitive(self) -> bool { - matches!( - self, - Self::Div | Self::Rem | Self::Shr | Self::Clt | Self::Cgt - ) - } -} - -/// Information about a binary operation extracted from an `SsaOp`. -/// -/// This provides a uniform view of binary operations for optimization passes, -/// allowing them to handle all binary ops generically without matching on -/// each variant individually. -/// -/// # Example -/// -/// ```ignore -/// if let Some(info) = op.as_binary_op() { -/// // Handle all binary ops uniformly -/// println!("{} = {} {} {}", info.dest, info.left, info.kind, info.right); -/// } -/// ``` -#[derive(Debug, Clone, Copy, PartialEq)] -pub struct BinaryOpInfo { - /// The kind of binary operation. - pub kind: BinaryOpKind, - /// Destination variable for the result. - pub dest: SsaVarId, - /// Left operand. - pub left: SsaVarId, - /// Right operand. - pub right: SsaVarId, - /// Whether the operation treats operands as unsigned. - pub unsigned: bool, -} +//! `referenced_token` is the only CIL-specific method; it lives here as the +//! [`SsaOpCilExt`] extension trait (orphan rule prevents inherent impls on +//! foreign types). -impl BinaryOpInfo { - /// Returns a normalized version of this operation for value numbering. - /// - /// For commutative operations, this ensures operands are in a canonical - /// order (smaller variable index first). For non-commutative comparisons - /// like `Clt` and `Cgt`, swapping operands also swaps the operation kind. - /// - /// This is useful for Global Value Numbering (GVN) where `a + b` and `b + a` - /// should hash to the same value. - /// - /// # Example - /// - /// ```ignore - /// let info = BinaryOpInfo { kind: BinaryOpKind::Add, left: v5, right: v2, ... }; - /// let normalized = info.normalized(); - /// // normalized.left = v2, normalized.right = v5 (swapped for canonical order) - /// ``` - #[must_use] - pub fn normalized(self) -> Self { - // Only normalize if right operand should come first - if self.right.index() < self.left.index() { - if self.kind.is_commutative() { - // Commutative: just swap operands - Self { - left: self.right, - right: self.left, - ..self - } - } else if self.kind.is_comparison() { - // Non-commutative comparison: swap operands AND operation - Self { - kind: self.kind.swapped(), - left: self.right, - right: self.left, - ..self - } - } else { - // Non-commutative, non-comparison: don't normalize - self - } - } else { - self - } - } +use analyssa::ir::ops::SsaOp as AnalyssaSsaOp; - /// Returns a tuple suitable for use as a hash key in value numbering. - /// - /// The tuple includes all semantically relevant fields: - /// - Operation kind - /// - Unsigned flag (only if the operation is signedness-sensitive) - /// - Left and right operands - /// - /// For operations where signedness doesn't matter, the unsigned field - /// is normalized to `false` to ensure consistent hashing. - #[must_use] - pub fn value_key(self) -> (BinaryOpKind, bool, SsaVarId, SsaVarId) { - let unsigned = if self.kind.is_signedness_sensitive() { - self.unsigned - } else { - false // Normalize for consistent hashing - }; - (self.kind, unsigned, self.left, self.right) - } -} +use crate::{analysis::ssa::target::CilTarget, metadata::token::Token}; -/// Kind of unary operation for extracted unary op info. -/// -/// This enum categorizes all unary operations in `SsaOp` for uniform -/// handling in optimization passes. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum UnaryOpKind { - /// Negation: `-operand` - Neg, - /// Bitwise NOT: `~operand` - Not, - /// Check finite (throws if NaN or infinity) - Ckfinite, -} +// `BinaryOpInfo`/`UnaryOpInfo` aren't re-exported (the original dotscope +// `ops.rs` didn't surface them either; direct callers go through +// `analyssa::ir::ops` if they need them). +pub use analyssa::ir::ops::{BinaryOpKind, CmpKind, UnaryOpKind}; -impl fmt::Display for UnaryOpKind { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Neg => write!(f, "neg"), - Self::Not => write!(f, "not"), - Self::Ckfinite => write!(f, "ckfinite"), - } - } -} +/// CIL-defaulted alias of `analyssa::ir::ops::SsaOp`. +pub type SsaOp = AnalyssaSsaOp; -/// Information about a unary operation extracted from an `SsaOp`. +/// CIL-specific extension methods on `SsaOp`. /// -/// This provides a uniform view of unary operations for optimization passes, -/// allowing them to handle all unary ops generically without matching on -/// each variant individually. -/// -/// # Example -/// -/// ```ignore -/// if let Some(info) = op.as_unary_op() { -/// // Handle all unary ops uniformly -/// println!("{} = {} {}", info.dest, info.kind, info.operand); -/// } -/// ``` -#[derive(Debug, Clone, Copy, PartialEq)] -pub struct UnaryOpInfo { - /// The kind of unary operation. - pub kind: UnaryOpKind, - /// Destination variable for the result. - pub dest: SsaVarId, - /// The operand. - pub operand: SsaVarId, -} - -/// A decomposed SSA operation. -/// -/// Each variant represents a single operation with explicit inputs and outputs. -/// This enables clean pattern matching for optimization and analysis passes. -/// -/// # Conventions -/// -/// - For operations that produce a result, the first `SsaVarId` is the destination -/// - Operands follow in the order they appear on the CIL stack (first pushed = first operand) -/// - Optional results use `Option` (e.g., calls that may not return a value) -#[derive(Debug, Clone, PartialEq)] -pub enum SsaOp { - /// Load a constant value. - /// - /// `dest = const value` - Const { dest: SsaVarId, value: ConstValue }, - - /// Addition: `dest = left + right` - Add { - dest: SsaVarId, - left: SsaVarId, - right: SsaVarId, - }, - - /// Addition with overflow check: `dest = left + right` (throws on overflow) - AddOvf { - dest: SsaVarId, - left: SsaVarId, - right: SsaVarId, - unsigned: bool, - }, - - /// Subtraction: `dest = left - right` - Sub { - dest: SsaVarId, - left: SsaVarId, - right: SsaVarId, - }, - - /// Subtraction with overflow check: `dest = left - right` - SubOvf { - dest: SsaVarId, - left: SsaVarId, - right: SsaVarId, - unsigned: bool, - }, - - /// Multiplication: `dest = left * right` - Mul { - dest: SsaVarId, - left: SsaVarId, - right: SsaVarId, - }, - - /// Multiplication with overflow check: `dest = left * right` - MulOvf { - dest: SsaVarId, - left: SsaVarId, - right: SsaVarId, - unsigned: bool, - }, - - /// Division: `dest = left / right` - Div { - dest: SsaVarId, - left: SsaVarId, - right: SsaVarId, - unsigned: bool, - }, - - /// Remainder: `dest = left % right` - Rem { - dest: SsaVarId, - left: SsaVarId, - right: SsaVarId, - unsigned: bool, - }, - - /// Negation: `dest = -operand` - Neg { dest: SsaVarId, operand: SsaVarId }, - - /// Bitwise AND: `dest = left & right` - And { - dest: SsaVarId, - left: SsaVarId, - right: SsaVarId, - }, - - /// Bitwise OR: `dest = left | right` - Or { - dest: SsaVarId, - left: SsaVarId, - right: SsaVarId, - }, - - /// Bitwise XOR: `dest = left ^ right` - Xor { - dest: SsaVarId, - left: SsaVarId, - right: SsaVarId, - }, - - /// Bitwise NOT: `dest = ~operand` - Not { dest: SsaVarId, operand: SsaVarId }, - - /// Shift left: `dest = value << amount` - Shl { - dest: SsaVarId, - value: SsaVarId, - amount: SsaVarId, - }, - - /// Shift right: `dest = value >> amount` - Shr { - dest: SsaVarId, - value: SsaVarId, - amount: SsaVarId, - unsigned: bool, - }, - - /// Compare equal: `dest = (left == right) ? 1 : 0` - Ceq { - dest: SsaVarId, - left: SsaVarId, - right: SsaVarId, - }, - - /// Compare less than: `dest = (left < right) ? 1 : 0` - Clt { - dest: SsaVarId, - left: SsaVarId, - right: SsaVarId, - unsigned: bool, - }, - - /// Compare greater than: `dest = (left > right) ? 1 : 0` - Cgt { - dest: SsaVarId, - left: SsaVarId, - right: SsaVarId, - unsigned: bool, - }, - - /// Type conversion: `dest = (target_type)operand` - Conv { - dest: SsaVarId, - operand: SsaVarId, - target: SsaType, - overflow_check: bool, - unsigned: bool, - }, - - /// Unconditional jump to a block. - Jump { target: usize }, - - /// Conditional branch: if condition is true, go to true_target, else false_target. - Branch { - condition: SsaVarId, - true_target: usize, - false_target: usize, - }, - - /// Compare and branch: if (left cmp right) goto true_target else false_target. - /// - /// This represents CIL comparison branch instructions like `beq`, `blt`, `bgt`, etc. - /// These are combined compare-and-branch operations that don't produce an intermediate - /// comparison result. - BranchCmp { - left: SsaVarId, - right: SsaVarId, - cmp: CmpKind, - unsigned: bool, - true_target: usize, - false_target: usize, - }, - - /// Switch statement: jump to `targets[value]` or default if out of range. - Switch { - value: SsaVarId, - targets: Vec, - default: usize, - }, - - /// Return from method with optional value. - Return { value: Option }, - - /// Load instance field: `dest = object.field` - LoadField { - dest: SsaVarId, - object: SsaVarId, - field: FieldRef, - }, - - /// Store instance field: `object.field = value` - StoreField { - object: SsaVarId, - field: FieldRef, - value: SsaVarId, - }, - - /// Load static field: `dest = ClassName.field` - LoadStaticField { dest: SsaVarId, field: FieldRef }, - - /// Store static field: `ClassName.field = value` - StoreStaticField { field: FieldRef, value: SsaVarId }, - - /// Load field address: `dest = &object.field` - LoadFieldAddr { - dest: SsaVarId, - object: SsaVarId, - field: FieldRef, - }, - - /// Load static field address: `dest = &ClassName.field` - LoadStaticFieldAddr { dest: SsaVarId, field: FieldRef }, - - /// Load array element: `dest = array[index]` - LoadElement { - dest: SsaVarId, - array: SsaVarId, - index: SsaVarId, - elem_type: SsaType, - }, - - /// Store array element: `array[index] = value` - StoreElement { - array: SsaVarId, - index: SsaVarId, - value: SsaVarId, - elem_type: SsaType, - }, - - /// Load array element address: `dest = &array[index]` - LoadElementAddr { - dest: SsaVarId, - array: SsaVarId, - index: SsaVarId, - elem_type: TypeRef, - }, - - /// Get array length: `dest = array.Length` - ArrayLength { dest: SsaVarId, array: SsaVarId }, - - /// Load through pointer: `dest = *ptr` - LoadIndirect { - dest: SsaVarId, - addr: SsaVarId, - value_type: SsaType, - }, - - /// Store through pointer: `*ptr = value` - StoreIndirect { - addr: SsaVarId, - value: SsaVarId, - value_type: SsaType, - }, - - /// Create new object: `dest = new Type(args...)` - NewObj { - dest: SsaVarId, - ctor: MethodRef, - args: Vec, - }, - - /// Create new array: `dest = new Type[length]` - NewArr { - dest: SsaVarId, - elem_type: TypeRef, - length: SsaVarId, - }, - - /// Cast object to type (throws if invalid): `dest = (Type)obj` - CastClass { - dest: SsaVarId, - object: SsaVarId, - target_type: TypeRef, - }, - - /// Type check (returns null if invalid): `dest = obj as Type` - IsInst { - dest: SsaVarId, - object: SsaVarId, - target_type: TypeRef, - }, - - /// Box value type: `dest = (object)value` - Box { - dest: SsaVarId, - value: SsaVarId, - value_type: TypeRef, - }, - - /// Unbox to pointer: `dest = &((ValueType)obj)` - Unbox { - dest: SsaVarId, - object: SsaVarId, - value_type: TypeRef, - }, - - /// Unbox and copy: `dest = (ValueType)obj` - UnboxAny { - dest: SsaVarId, - object: SsaVarId, - value_type: TypeRef, - }, - - /// Get size of value type: `dest = sizeof(Type)` - SizeOf { dest: SsaVarId, value_type: TypeRef }, - - /// Load runtime type token: `dest = typeof(Type).TypeHandle` - LoadToken { dest: SsaVarId, token: TypeRef }, - - /// Direct method call: `dest = method(args...)` - Call { - dest: Option, - method: MethodRef, - args: Vec, - }, - - /// Virtual method call: `dest = obj.method(args...)` - CallVirt { - dest: Option, - method: MethodRef, - args: Vec, - }, - - /// Indirect call through function pointer: `dest = fptr(args...)` - CallIndirect { - dest: Option, - fptr: SsaVarId, - signature: SigRef, - args: Vec, - }, - - /// Load function pointer: `dest = &method` - LoadFunctionPtr { dest: SsaVarId, method: MethodRef }, - - /// Load virtual function pointer: `dest = &obj.method` - LoadVirtFunctionPtr { - dest: SsaVarId, - object: SsaVarId, - method: MethodRef, - }, - - /// Load argument value: `dest = argN` - LoadArg { dest: SsaVarId, arg_index: u16 }, - - /// Load local value: `dest = localN` - LoadLocal { dest: SsaVarId, local_index: u16 }, - - /// Load argument address: `dest = &argN` - LoadArgAddr { dest: SsaVarId, arg_index: u16 }, - - /// Load local address: `dest = &localN` - LoadLocalAddr { dest: SsaVarId, local_index: u16 }, - - /// Copy value (from dup): `dest = src` - Copy { dest: SsaVarId, src: SsaVarId }, - - /// Pop value from stack (value is discarded, but we track the use) - Pop { value: SsaVarId }, - - /// Throw exception: `throw obj` - Throw { exception: SsaVarId }, - - /// Rethrow current exception (in catch handler) - Rethrow, - - /// End finally block - EndFinally, - - /// End filter block with result - EndFilter { result: SsaVarId }, - - /// Leave protected region - Leave { target: usize }, - - /// Initialize block of memory to zero - InitBlk { - dest_addr: SsaVarId, - value: SsaVarId, - size: SsaVarId, - }, - - /// Copy block of memory - CopyBlk { - dest_addr: SsaVarId, - src_addr: SsaVarId, - size: SsaVarId, - }, - - /// Initialize object (for value types): `*dest = default(T)` - InitObj { - dest_addr: SsaVarId, - value_type: TypeRef, - }, - - /// Copy object (for value types): `*dest = *src` - CopyObj { - dest_addr: SsaVarId, - src_addr: SsaVarId, - value_type: TypeRef, - }, - - /// Load object (value type copy): `dest = *src` - LoadObj { - dest: SsaVarId, - src_addr: SsaVarId, - value_type: TypeRef, - }, - - /// Store object (value type copy): `*dest = value` - StoreObj { - dest_addr: SsaVarId, - value: SsaVarId, - value_type: TypeRef, - }, - - /// No operation (for nop instructions) - Nop, - - /// Breakpoint trap - Break, - - /// Check for finite floating point: throws if not finite - Ckfinite { dest: SsaVarId, operand: SsaVarId }, - - /// Localloc: allocate stack space - LocalAlloc { dest: SsaVarId, size: SsaVarId }, - - /// Constrained virtual call prefix (affects next callvirt) - Constrained { constraint_type: TypeRef }, - - /// Volatile prefix (next memory access must not be reordered/cached) - Volatile, - - /// Unaligned prefix (next memory access may be unaligned) - Unaligned { alignment: u8 }, - - /// Tail call prefix (next call is a tail call) - TailPrefix, - - /// Readonly prefix (next ldelema returns a controlled-mutability managed pointer) - Readonly, - - /// Phi node: merges values from different predecessors. - /// - /// This is placed at the beginning of blocks with multiple predecessors. - Phi { - dest: SsaVarId, - operands: Vec<(usize, SsaVarId)>, - }, -} - -impl SsaOp { - /// Returns the destination variable if this operation produces one. - #[must_use] - pub fn dest(&self) -> Option { - match self { - Self::Const { dest, .. } - | Self::Add { dest, .. } - | Self::AddOvf { dest, .. } - | Self::Sub { dest, .. } - | Self::SubOvf { dest, .. } - | Self::Mul { dest, .. } - | Self::MulOvf { dest, .. } - | Self::Div { dest, .. } - | Self::Rem { dest, .. } - | Self::Neg { dest, .. } - | Self::And { dest, .. } - | Self::Or { dest, .. } - | Self::Xor { dest, .. } - | Self::Not { dest, .. } - | Self::Shl { dest, .. } - | Self::Shr { dest, .. } - | Self::Ceq { dest, .. } - | Self::Clt { dest, .. } - | Self::Cgt { dest, .. } - | Self::Conv { dest, .. } - | Self::LoadField { dest, .. } - | Self::LoadStaticField { dest, .. } - | Self::LoadFieldAddr { dest, .. } - | Self::LoadStaticFieldAddr { dest, .. } - | Self::LoadElement { dest, .. } - | Self::LoadElementAddr { dest, .. } - | Self::ArrayLength { dest, .. } - | Self::LoadIndirect { dest, .. } - | Self::NewObj { dest, .. } - | Self::NewArr { dest, .. } - | Self::CastClass { dest, .. } - | Self::IsInst { dest, .. } - | Self::Box { dest, .. } - | Self::Unbox { dest, .. } - | Self::UnboxAny { dest, .. } - | Self::SizeOf { dest, .. } - | Self::LoadToken { dest, .. } - | Self::LoadFunctionPtr { dest, .. } - | Self::LoadVirtFunctionPtr { dest, .. } - | Self::LoadArg { dest, .. } - | Self::LoadLocal { dest, .. } - | Self::LoadArgAddr { dest, .. } - | Self::LoadLocalAddr { dest, .. } - | Self::Copy { dest, .. } - | Self::Ckfinite { dest, .. } - | Self::LocalAlloc { dest, .. } - | Self::LoadObj { dest, .. } - | Self::Phi { dest, .. } => Some(*dest), - - Self::Call { dest, .. } - | Self::CallVirt { dest, .. } - | Self::CallIndirect { dest, .. } => *dest, - - // Operations that don't produce a result - Self::StoreField { .. } - | Self::StoreStaticField { .. } - | Self::StoreElement { .. } - | Self::StoreIndirect { .. } - | Self::Jump { .. } - | Self::Branch { .. } - | Self::BranchCmp { .. } - | Self::Switch { .. } - | Self::Return { .. } - | Self::Pop { .. } - | Self::Throw { .. } - | Self::Rethrow - | Self::EndFinally - | Self::EndFilter { .. } - | Self::Leave { .. } - | Self::InitBlk { .. } - | Self::CopyBlk { .. } - | Self::InitObj { .. } - | Self::CopyObj { .. } - | Self::StoreObj { .. } - | Self::Nop - | Self::Break - | Self::Constrained { .. } - | Self::Volatile - | Self::Unaligned { .. } - | Self::TailPrefix - | Self::Readonly => None, - } - } - - /// Sets the destination variable for operations that produce a result. - /// - /// This is used during SSA renaming to update the dest after assigning - /// new SSA variable IDs. Returns `true` if the dest was updated. - /// - /// # Arguments - /// - /// * `new_dest` - The new destination variable ID - pub fn set_dest(&mut self, new_dest: SsaVarId) -> bool { - match self { - Self::Const { dest, .. } - | Self::Add { dest, .. } - | Self::AddOvf { dest, .. } - | Self::Sub { dest, .. } - | Self::SubOvf { dest, .. } - | Self::Mul { dest, .. } - | Self::MulOvf { dest, .. } - | Self::Div { dest, .. } - | Self::Rem { dest, .. } - | Self::Neg { dest, .. } - | Self::And { dest, .. } - | Self::Or { dest, .. } - | Self::Xor { dest, .. } - | Self::Not { dest, .. } - | Self::Shl { dest, .. } - | Self::Shr { dest, .. } - | Self::Ceq { dest, .. } - | Self::Clt { dest, .. } - | Self::Cgt { dest, .. } - | Self::Conv { dest, .. } - | Self::LoadField { dest, .. } - | Self::LoadStaticField { dest, .. } - | Self::LoadFieldAddr { dest, .. } - | Self::LoadStaticFieldAddr { dest, .. } - | Self::LoadElement { dest, .. } - | Self::LoadElementAddr { dest, .. } - | Self::ArrayLength { dest, .. } - | Self::LoadIndirect { dest, .. } - | Self::NewObj { dest, .. } - | Self::NewArr { dest, .. } - | Self::CastClass { dest, .. } - | Self::IsInst { dest, .. } - | Self::Box { dest, .. } - | Self::Unbox { dest, .. } - | Self::UnboxAny { dest, .. } - | Self::SizeOf { dest, .. } - | Self::LoadToken { dest, .. } - | Self::LoadFunctionPtr { dest, .. } - | Self::LoadVirtFunctionPtr { dest, .. } - | Self::LoadArg { dest, .. } - | Self::LoadLocal { dest, .. } - | Self::LoadArgAddr { dest, .. } - | Self::LoadLocalAddr { dest, .. } - | Self::Copy { dest, .. } - | Self::Ckfinite { dest, .. } - | Self::LocalAlloc { dest, .. } - | Self::LoadObj { dest, .. } - | Self::Phi { dest, .. } => { - *dest = new_dest; - true - } - - Self::Call { dest, .. } - | Self::CallVirt { dest, .. } - | Self::CallIndirect { dest, .. } => { - *dest = Some(new_dest); - true - } - - // Operations that don't produce a result - cannot set dest - Self::StoreField { .. } - | Self::StoreStaticField { .. } - | Self::StoreElement { .. } - | Self::StoreIndirect { .. } - | Self::Jump { .. } - | Self::Branch { .. } - | Self::BranchCmp { .. } - | Self::Switch { .. } - | Self::Return { .. } - | Self::Pop { .. } - | Self::Throw { .. } - | Self::Rethrow - | Self::EndFinally - | Self::EndFilter { .. } - | Self::Leave { .. } - | Self::InitBlk { .. } - | Self::CopyBlk { .. } - | Self::InitObj { .. } - | Self::CopyObj { .. } - | Self::StoreObj { .. } - | Self::Nop - | Self::Break - | Self::Constrained { .. } - | Self::Volatile - | Self::Unaligned { .. } - | Self::TailPrefix - | Self::Readonly => false, - } - } - - /// Returns all variables used by this operation. - #[must_use] - #[allow(clippy::match_same_arms)] // Kept separate for clarity by operation category - pub fn uses(&self) -> Vec { - match self { - Self::Const { .. } => vec![], - - Self::Add { left, right, .. } - | Self::AddOvf { left, right, .. } - | Self::Sub { left, right, .. } - | Self::SubOvf { left, right, .. } - | Self::Mul { left, right, .. } - | Self::MulOvf { left, right, .. } - | Self::Div { left, right, .. } - | Self::Rem { left, right, .. } - | Self::And { left, right, .. } - | Self::Or { left, right, .. } - | Self::Xor { left, right, .. } - | Self::Ceq { left, right, .. } - | Self::Clt { left, right, .. } - | Self::Cgt { left, right, .. } => vec![*left, *right], - - Self::Shl { value, amount, .. } | Self::Shr { value, amount, .. } => { - vec![*value, *amount] - } - - Self::Neg { operand, .. } - | Self::Not { operand, .. } - | Self::Conv { operand, .. } - | Self::Ckfinite { operand, .. } => vec![*operand], - - Self::Branch { condition, .. } => vec![*condition], - Self::BranchCmp { left, right, .. } => vec![*left, *right], - Self::Switch { value, .. } => vec![*value], - Self::Return { value } => value.iter().copied().collect(), - - Self::LoadField { object, .. } => vec![*object], - Self::StoreField { object, value, .. } => vec![*object, *value], - Self::LoadStaticField { .. } => vec![], - Self::StoreStaticField { value, .. } => vec![*value], - Self::LoadFieldAddr { object, .. } => vec![*object], - Self::LoadStaticFieldAddr { .. } => vec![], - - Self::LoadElement { array, index, .. } | Self::LoadElementAddr { array, index, .. } => { - vec![*array, *index] - } - Self::StoreElement { - array, - index, - value, - .. - } => vec![*array, *index, *value], - Self::ArrayLength { array, .. } => vec![*array], - - Self::LoadIndirect { addr, .. } => vec![*addr], - Self::StoreIndirect { addr, value, .. } => vec![*addr, *value], - - Self::NewObj { args, .. } => args.clone(), - Self::NewArr { length, .. } => vec![*length], - Self::CastClass { object, .. } - | Self::IsInst { object, .. } - | Self::Unbox { object, .. } - | Self::UnboxAny { object, .. } => vec![*object], - Self::Box { value, .. } => vec![*value], - Self::SizeOf { .. } | Self::LoadToken { .. } => vec![], - - Self::Call { args, .. } | Self::CallVirt { args, .. } => args.clone(), - Self::CallIndirect { fptr, args, .. } => { - let mut uses = vec![*fptr]; - uses.extend(args); - uses - } - - Self::LoadFunctionPtr { .. } => vec![], - Self::LoadVirtFunctionPtr { object, .. } => vec![*object], - - Self::LoadArg { .. } - | Self::LoadLocal { .. } - | Self::LoadArgAddr { .. } - | Self::LoadLocalAddr { .. } => vec![], - - Self::Copy { src, .. } => vec![*src], - Self::Pop { value } => vec![*value], - - Self::Throw { exception } => vec![*exception], - Self::EndFilter { result } => vec![*result], - - Self::InitBlk { - dest_addr, - value, - size, - } - | Self::CopyBlk { - dest_addr, - src_addr: value, - size, - } => vec![*dest_addr, *value, *size], - - Self::InitObj { dest_addr, .. } => vec![*dest_addr], - Self::CopyObj { - dest_addr, - src_addr, - .. - } => vec![*dest_addr, *src_addr], - Self::LoadObj { src_addr, .. } => vec![*src_addr], - Self::StoreObj { - dest_addr, value, .. - } => vec![*dest_addr, *value], - - Self::LocalAlloc { size, .. } => vec![*size], - - Self::Phi { operands, .. } => operands.iter().map(|(_, v)| *v).collect(), - - Self::Jump { .. } - | Self::Rethrow - | Self::EndFinally - | Self::Leave { .. } - | Self::Nop - | Self::Break - | Self::Constrained { .. } - | Self::Volatile - | Self::Unaligned { .. } - | Self::TailPrefix - | Self::Readonly => vec![], - } - } - - /// Returns `true` if this operation is a terminator (ends a basic block). - #[must_use] - pub const fn is_terminator(&self) -> bool { - matches!( - self, - Self::Jump { .. } - | Self::Branch { .. } - | Self::BranchCmp { .. } - | Self::Switch { .. } - | Self::Return { .. } - | Self::Throw { .. } - | Self::Rethrow - | Self::Leave { .. } - | Self::EndFinally - | Self::EndFilter { .. } - ) - } - - /// Returns `true` if this operation may throw an exception. - #[must_use] - pub const fn may_throw(&self) -> bool { - matches!( - self, - Self::Div { .. } - | Self::Rem { .. } - | Self::AddOvf { .. } - | Self::SubOvf { .. } - | Self::MulOvf { .. } - | Self::Conv { - overflow_check: true, - .. - } - | Self::LoadField { .. } - | Self::StoreField { .. } - | Self::LoadElement { .. } - | Self::StoreElement { .. } - | Self::LoadElementAddr { .. } - | Self::LoadIndirect { .. } - | Self::StoreIndirect { .. } - | Self::NewObj { .. } - | Self::NewArr { .. } - | Self::CastClass { .. } - | Self::Unbox { .. } - | Self::UnboxAny { .. } - | Self::Call { .. } - | Self::CallVirt { .. } - | Self::CallIndirect { .. } - | Self::Throw { .. } - | Self::Ckfinite { .. } - ) - } - - /// Returns `true` if this operation is pure (has no side effects). - /// - /// Pure operations can be eliminated if their result is unused. - #[must_use] - pub const fn is_pure(&self) -> bool { - matches!( - self, - Self::Const { .. } - | Self::Add { .. } - | Self::Sub { .. } - | Self::Mul { .. } - | Self::Neg { .. } - | Self::And { .. } - | Self::Or { .. } - | Self::Xor { .. } - | Self::Not { .. } - | Self::Shl { .. } - | Self::Shr { .. } - | Self::Ceq { .. } - | Self::Clt { .. } - | Self::Cgt { .. } - | Self::Conv { - overflow_check: false, - .. - } - | Self::Copy { .. } - | Self::SizeOf { .. } - | Self::LoadToken { .. } - | Self::LoadArg { .. } - | Self::LoadLocal { .. } - | Self::LoadArgAddr { .. } - | Self::LoadLocalAddr { .. } - | Self::Phi { .. } - | Self::Nop - | Self::Pop { .. } - ) - } - - /// Replaces all uses of `old_var` with `new_var` in this operation. - /// - /// This is used for copy propagation and other variable substitution transformations. - /// - /// # Arguments - /// - /// * `old_var` - The variable to replace. - /// * `new_var` - The variable to use instead. - /// - /// # Returns - /// - /// The number of replacements made. - pub fn replace_uses(&mut self, old_var: SsaVarId, new_var: SsaVarId) -> usize { - let mut count: usize = 0; - - // Helper closure to replace a variable - let mut replace = |var: &mut SsaVarId| { - if *var == old_var { - *var = new_var; - count = count.saturating_add(1); - } - }; - - match self { - // Binary arithmetic and comparison branches - Self::Add { left, right, .. } - | Self::AddOvf { left, right, .. } - | Self::Sub { left, right, .. } - | Self::SubOvf { left, right, .. } - | Self::Mul { left, right, .. } - | Self::MulOvf { left, right, .. } - | Self::Div { left, right, .. } - | Self::Rem { left, right, .. } - | Self::And { left, right, .. } - | Self::Or { left, right, .. } - | Self::Xor { left, right, .. } - | Self::Ceq { left, right, .. } - | Self::Clt { left, right, .. } - | Self::Cgt { left, right, .. } - | Self::BranchCmp { left, right, .. } => { - replace(left); - replace(right); - } - - // Unary operations and conversion - Self::Neg { operand, .. } - | Self::Not { operand, .. } - | Self::Ckfinite { operand, .. } - | Self::Conv { operand, .. } => { - replace(operand); - } - - // Shift operations - Self::Shl { value, amount, .. } | Self::Shr { value, amount, .. } => { - replace(value); - replace(amount); - } - - // Copy operation - Self::Copy { src, .. } => { - replace(src); - } - - // Control flow - Self::Branch { condition, .. } => { - replace(condition); - } - Self::Switch { value, .. } - | Self::StoreStaticField { value, .. } - | Self::Pop { value } => { - replace(value); - } - Self::Return { value: Some(v) } => { - replace(v); - } - - // Object/field operations - Self::LoadField { object, .. } - | Self::LoadFieldAddr { object, .. } - | Self::CastClass { object, .. } - | Self::IsInst { object, .. } - | Self::Box { value: object, .. } - | Self::Unbox { object, .. } - | Self::UnboxAny { object, .. } - | Self::LoadVirtFunctionPtr { object, .. } => { - replace(object); - } - Self::StoreField { object, value, .. } => { - replace(object); - replace(value); - } - - // Array operations - Self::LoadElement { array, index, .. } | Self::LoadElementAddr { array, index, .. } => { - replace(array); - replace(index); - } - Self::StoreElement { - array, - index, - value, - .. - } => { - replace(array); - replace(index); - replace(value); - } - Self::NewArr { length, .. } => { - replace(length); - } - Self::ArrayLength { array, .. } => { - replace(array); - } - - // Indirect load/store - Self::LoadIndirect { addr, .. } => { - replace(addr); - } - Self::StoreIndirect { addr, value, .. } => { - replace(addr); - replace(value); - } - - // Calls - Self::Call { args, .. } | Self::CallVirt { args, .. } | Self::NewObj { args, .. } => { - for arg in args { - replace(arg); - } - } - Self::CallIndirect { fptr, args, .. } => { - replace(fptr); - for arg in args { - replace(arg); - } - } - - // Other - Self::Throw { exception } => { - replace(exception); - } - Self::EndFilter { result } => { - replace(result); - } - Self::Phi { operands, .. } => { - for (_, operand) in operands { - replace(operand); - } - } - Self::StoreObj { - dest_addr, value, .. - } => { - replace(dest_addr); - replace(value); - } - Self::LoadObj { src_addr, .. } => { - replace(src_addr); - } - Self::LocalAlloc { size, .. } => { - replace(size); - } - Self::InitObj { dest_addr, .. } => { - replace(dest_addr); - } - Self::CopyObj { - dest_addr, - src_addr, - .. - } => { - replace(dest_addr); - replace(src_addr); - } - Self::CopyBlk { - dest_addr, - src_addr, - size, - } => { - replace(dest_addr); - replace(src_addr); - replace(size); - } - Self::InitBlk { - dest_addr, - value, - size, - } => { - replace(dest_addr); - replace(value); - replace(size); - } - - // Operations without variable uses - Self::Const { .. } - | Self::LoadStaticField { .. } - | Self::LoadStaticFieldAddr { .. } - | Self::Jump { .. } - | Self::Return { value: None } - | Self::Rethrow - | Self::EndFinally - | Self::Leave { .. } - | Self::SizeOf { .. } - | Self::LoadToken { .. } - | Self::LoadArg { .. } - | Self::LoadLocal { .. } - | Self::LoadArgAddr { .. } - | Self::LoadLocalAddr { .. } - | Self::LoadFunctionPtr { .. } - | Self::Nop - | Self::Break - | Self::Constrained { .. } - | Self::Volatile - | Self::Unaligned { .. } - | Self::TailPrefix - | Self::Readonly => {} - } - - count - } - - /// Remaps branch target block indices using the provided mapping function. - /// - /// This is used to translate RVA-based targets (from CIL instructions) to - /// sequential block indices (used by the SSA representation). - /// - /// # Arguments - /// - /// * `remap` - A function that maps old block indices to new block indices. - /// Returns `None` if the target should remain unchanged. - pub fn remap_branch_targets(&mut self, remap: F) - where - F: Fn(usize) -> Option, - { - match self { - Self::Jump { target } | Self::Leave { target } => { - if let Some(new_target) = remap(*target) { - *target = new_target; - } - } - Self::Branch { - true_target, - false_target, - .. - } - | Self::BranchCmp { - true_target, - false_target, - .. - } => { - if let Some(new_target) = remap(*true_target) { - *true_target = new_target; - } - if let Some(new_target) = remap(*false_target) { - *false_target = new_target; - } - } - Self::Switch { - targets, default, .. - } => { - for target in targets.iter_mut() { - if let Some(new_target) = remap(*target) { - *target = new_target; - } - } - if let Some(new_target) = remap(*default) { - *default = new_target; - } - } - // All other operations don't have branch targets - _ => {} - } - } - - /// Returns the successor block indices for this operation. - /// - /// For control flow operations (terminators), this returns the indices of - /// all possible successor blocks: - /// - `Jump` and `Leave`: single target block - /// - `Branch`: true and false target blocks - /// - `Switch`: all case targets plus the default target - /// - /// For non-terminator operations, returns an empty vector. - /// - /// # Returns - /// - /// A vector of successor block indices. Empty for non-branching operations. - /// - /// # Example - /// - /// ```ignore - /// let op = SsaOp::Branch { - /// condition: var, - /// true_target: 1, - /// false_target: 2, - /// }; - /// assert_eq!(op.successors(), vec![1, 2]); - /// ``` - #[must_use] - pub fn successors(&self) -> Vec { - match self { - Self::Jump { target } | Self::Leave { target } => vec![*target], - Self::Branch { - true_target, - false_target, - .. - } - | Self::BranchCmp { - true_target, - false_target, - .. - } => vec![*true_target, *false_target], - Self::Switch { - targets, default, .. - } => { - let mut succs = targets.clone(); - succs.push(*default); - succs - } - // Return, Throw, Rethrow, EndFinally, EndFilter have no successors - _ => vec![], - } - } - - /// Redirects control flow targets from `old_target` to `new_target`. - /// - /// This method modifies branch/jump targets in-place. It handles all control - /// flow operations: `Jump`, `Leave`, `Branch`, `BranchCmp`, and `Switch`. - /// - /// # Arguments - /// - /// * `old_target` - The block index to redirect from - /// * `new_target` - The block index to redirect to - /// - /// # Returns - /// - /// `true` if any target was changed, `false` otherwise. - /// - /// # Example - /// - /// ```ignore - /// // Redirect all jumps to block 2 to instead go to block 5 - /// if op.redirect_target(2, 5) { - /// println!("Target redirected"); - /// } - /// ``` - pub fn redirect_target(&mut self, old_target: usize, new_target: usize) -> bool { - if old_target == new_target { - return false; - } - - match self { - Self::Jump { target } | Self::Leave { target } if *target == old_target => { - *target = new_target; - true - } - Self::Branch { - true_target, - false_target, - .. - } - | Self::BranchCmp { - true_target, - false_target, - .. - } => { - let mut changed = false; - if *true_target == old_target { - *true_target = new_target; - changed = true; - } - if *false_target == old_target { - *false_target = new_target; - changed = true; - } - changed - } - Self::Switch { - targets, default, .. - } => { - let mut changed = false; - if *default == old_target { - *default = new_target; - changed = true; - } - for target in targets.iter_mut() { - if *target == old_target { - *target = new_target; - changed = true; - } - } - changed - } - _ => false, - } - } - +/// Import this trait to call `op.referenced_token()` as before. +pub trait SsaOpCilExt { /// Returns the metadata token referenced by this operation, if any. /// - /// This extracts the token from operations that reference metadata entities - /// such as methods, fields, or types. Used for cleanup operations to identify + /// Extracts the token from operations that reference metadata entities + /// (methods, fields, types). Used for cleanup operations to identify /// which SSA operations reference tokens that are being removed. - /// - /// # Returns - /// - /// - `Some(Token)` if the operation references a method, field, or type token - /// - `None` if the operation doesn't reference any metadata token - /// - /// # Operations That Return Tokens - /// - /// - `Call`, `CallVirt`: Method token - /// - `NewObj`: Constructor method token - /// - `LoadField`, `StoreField`, `LoadFieldAddr`: Field token - /// - `LoadStaticField`, `StoreStaticField`, `LoadStaticFieldAddr`: Field token - /// - `Box`, `Unbox`, `UnboxAny`, `InitObj`, `SizeOf`: Value type token - /// - `IsInst`, `CastClass`: Target type token - /// - `NewArr`: Element type token - /// - `LoadToken`: The loaded token - #[must_use] - pub fn referenced_token(&self) -> Option { - match self { - Self::Call { method, .. } - | Self::CallVirt { method, .. } - | Self::LoadFunctionPtr { method, .. } - | Self::LoadVirtFunctionPtr { method, .. } => Some(method.token()), - Self::NewObj { ctor, .. } => Some(ctor.token()), - Self::LoadField { field, .. } - | Self::StoreField { field, .. } - | Self::LoadFieldAddr { field, .. } - | Self::LoadStaticField { field, .. } - | Self::StoreStaticField { field, .. } - | Self::LoadStaticFieldAddr { field, .. } => Some(field.token()), - Self::Box { value_type, .. } - | Self::Unbox { value_type, .. } - | Self::UnboxAny { value_type, .. } - | Self::InitObj { value_type, .. } - | Self::SizeOf { value_type, .. } - | Self::CopyObj { value_type, .. } - | Self::LoadObj { value_type, .. } - | Self::StoreObj { value_type, .. } => Some(value_type.token()), - Self::IsInst { target_type, .. } | Self::CastClass { target_type, .. } => { - Some(target_type.token()) - } - Self::NewArr { elem_type, .. } | Self::LoadElementAddr { elem_type, .. } => { - Some(elem_type.token()) - } - Self::LoadToken { token, .. } => Some(token.token()), - Self::Constrained { constraint_type } => Some(constraint_type.token()), - _ => None, - } - } - - /// Creates a clone of this operation with all variable IDs remapped. - /// - /// This is used for block duplication where all variable references - /// (both destinations and uses) need to be updated to fresh IDs. - /// - /// # Arguments - /// - /// * `remap` - A function that maps old variable IDs to new ones. - /// If the function returns `None`, the original ID is kept. - /// - /// # Returns - /// - /// A new `SsaOp` with all variable IDs remapped. - #[must_use] - pub fn remap_variables(&self, remap: F) -> Self - where - F: Fn(SsaVarId) -> Option, - { - // Helper to remap a single variable - let r = |var: SsaVarId| remap(var).unwrap_or(var); - - match self.clone() { - Self::Const { dest, value } => Self::Const { - dest: r(dest), - value, - }, - - Self::Add { dest, left, right } => Self::Add { - dest: r(dest), - left: r(left), - right: r(right), - }, - Self::AddOvf { - dest, - left, - right, - unsigned, - } => Self::AddOvf { - dest: r(dest), - left: r(left), - right: r(right), - unsigned, - }, - Self::Sub { dest, left, right } => Self::Sub { - dest: r(dest), - left: r(left), - right: r(right), - }, - Self::SubOvf { - dest, - left, - right, - unsigned, - } => Self::SubOvf { - dest: r(dest), - left: r(left), - right: r(right), - unsigned, - }, - Self::Mul { dest, left, right } => Self::Mul { - dest: r(dest), - left: r(left), - right: r(right), - }, - Self::MulOvf { - dest, - left, - right, - unsigned, - } => Self::MulOvf { - dest: r(dest), - left: r(left), - right: r(right), - unsigned, - }, - Self::Div { - dest, - left, - right, - unsigned, - } => Self::Div { - dest: r(dest), - left: r(left), - right: r(right), - unsigned, - }, - Self::Rem { - dest, - left, - right, - unsigned, - } => Self::Rem { - dest: r(dest), - left: r(left), - right: r(right), - unsigned, - }, - - Self::Neg { dest, operand } => Self::Neg { - dest: r(dest), - operand: r(operand), - }, - Self::And { dest, left, right } => Self::And { - dest: r(dest), - left: r(left), - right: r(right), - }, - Self::Or { dest, left, right } => Self::Or { - dest: r(dest), - left: r(left), - right: r(right), - }, - Self::Xor { dest, left, right } => Self::Xor { - dest: r(dest), - left: r(left), - right: r(right), - }, - Self::Not { dest, operand } => Self::Not { - dest: r(dest), - operand: r(operand), - }, - - Self::Shl { - dest, - value, - amount, - } => Self::Shl { - dest: r(dest), - value: r(value), - amount: r(amount), - }, - Self::Shr { - dest, - value, - amount, - unsigned, - } => Self::Shr { - dest: r(dest), - value: r(value), - amount: r(amount), - unsigned, - }, - - Self::Ceq { dest, left, right } => Self::Ceq { - dest: r(dest), - left: r(left), - right: r(right), - }, - Self::Clt { - dest, - left, - right, - unsigned, - } => Self::Clt { - dest: r(dest), - left: r(left), - right: r(right), - unsigned, - }, - Self::Cgt { - dest, - left, - right, - unsigned, - } => Self::Cgt { - dest: r(dest), - left: r(left), - right: r(right), - unsigned, - }, - - Self::Conv { - dest, - operand, - target, - overflow_check, - unsigned, - } => Self::Conv { - dest: r(dest), - operand: r(operand), - target, - overflow_check, - unsigned, - }, - Self::Ckfinite { dest, operand } => Self::Ckfinite { - dest: r(dest), - operand: r(operand), - }, - - // Control flow - no dests, may have uses - Self::Jump { target } => Self::Jump { target }, - Self::Branch { - condition, - true_target, - false_target, - } => Self::Branch { - condition: r(condition), - true_target, - false_target, - }, - Self::BranchCmp { - left, - right, - cmp, - unsigned, - true_target, - false_target, - } => Self::BranchCmp { - left: r(left), - right: r(right), - cmp, - unsigned, - true_target, - false_target, - }, - Self::Switch { - value, - targets, - default, - } => Self::Switch { - value: r(value), - targets, - default, - }, - Self::Return { value } => Self::Return { - value: value.map(&r), - }, - Self::Leave { target } => Self::Leave { target }, - - // Field operations - Self::LoadField { - dest, - object, - field, - } => Self::LoadField { - dest: r(dest), - object: r(object), - field, - }, - Self::StoreField { - object, - field, - value, - } => Self::StoreField { - object: r(object), - field, - value: r(value), - }, - Self::LoadStaticField { dest, field } => Self::LoadStaticField { - dest: r(dest), - field, - }, - Self::StoreStaticField { field, value } => Self::StoreStaticField { - field, - value: r(value), - }, - Self::LoadFieldAddr { - dest, - object, - field, - } => Self::LoadFieldAddr { - dest: r(dest), - object: r(object), - field, - }, - Self::LoadStaticFieldAddr { dest, field } => Self::LoadStaticFieldAddr { - dest: r(dest), - field, - }, - - // Array operations - Self::LoadElement { - dest, - array, - index, - elem_type, - } => Self::LoadElement { - dest: r(dest), - array: r(array), - index: r(index), - elem_type, - }, - Self::StoreElement { - array, - index, - value, - elem_type, - } => Self::StoreElement { - array: r(array), - index: r(index), - value: r(value), - elem_type, - }, - Self::LoadElementAddr { - dest, - array, - index, - elem_type, - } => Self::LoadElementAddr { - dest: r(dest), - array: r(array), - index: r(index), - elem_type, - }, - Self::ArrayLength { dest, array } => Self::ArrayLength { - dest: r(dest), - array: r(array), - }, - - // Indirect operations - Self::LoadIndirect { - dest, - addr, - value_type, - } => Self::LoadIndirect { - dest: r(dest), - addr: r(addr), - value_type, - }, - Self::StoreIndirect { - addr, - value, - value_type, - } => Self::StoreIndirect { - addr: r(addr), - value: r(value), - value_type, - }, - - // Object operations - Self::NewObj { dest, ctor, args } => Self::NewObj { - dest: r(dest), - ctor, - args: args.into_iter().map(&r).collect(), - }, - Self::NewArr { - dest, - elem_type, - length, - } => Self::NewArr { - dest: r(dest), - elem_type, - length: r(length), - }, - Self::CastClass { - dest, - object, - target_type, - } => Self::CastClass { - dest: r(dest), - object: r(object), - target_type, - }, - Self::IsInst { - dest, - object, - target_type, - } => Self::IsInst { - dest: r(dest), - object: r(object), - target_type, - }, - Self::Box { - dest, - value, - value_type, - } => Self::Box { - dest: r(dest), - value: r(value), - value_type, - }, - Self::Unbox { - dest, - object, - value_type, - } => Self::Unbox { - dest: r(dest), - object: r(object), - value_type, - }, - Self::UnboxAny { - dest, - object, - value_type, - } => Self::UnboxAny { - dest: r(dest), - object: r(object), - value_type, - }, - Self::SizeOf { dest, value_type } => Self::SizeOf { - dest: r(dest), - value_type, - }, - Self::LoadToken { dest, token } => Self::LoadToken { - dest: r(dest), - token, - }, - - // Call operations - Self::Call { dest, method, args } => Self::Call { - dest: dest.map(&r), - method, - args: args.into_iter().map(&r).collect(), - }, - Self::CallVirt { dest, method, args } => Self::CallVirt { - dest: dest.map(&r), - method, - args: args.into_iter().map(&r).collect(), - }, - Self::CallIndirect { - dest, - fptr, - signature, - args, - } => Self::CallIndirect { - dest: dest.map(&r), - fptr: r(fptr), - signature, - args: args.into_iter().map(&r).collect(), - }, - - // Function pointer operations - Self::LoadFunctionPtr { dest, method } => Self::LoadFunctionPtr { - dest: r(dest), - method, - }, - Self::LoadVirtFunctionPtr { - dest, - object, - method, - } => Self::LoadVirtFunctionPtr { - dest: r(dest), - object: r(object), - method, - }, - - // Value and address loading - Self::LoadArg { dest, arg_index } => Self::LoadArg { - dest: r(dest), - arg_index, - }, - Self::LoadLocal { dest, local_index } => Self::LoadLocal { - dest: r(dest), - local_index, - }, - Self::LoadArgAddr { dest, arg_index } => Self::LoadArgAddr { - dest: r(dest), - arg_index, - }, - Self::LoadLocalAddr { dest, local_index } => Self::LoadLocalAddr { - dest: r(dest), - local_index, - }, - - // Misc operations - Self::Copy { dest, src } => Self::Copy { - dest: r(dest), - src: r(src), - }, - Self::Pop { value } => Self::Pop { value: r(value) }, - Self::Throw { exception } => Self::Throw { - exception: r(exception), - }, - Self::Rethrow => Self::Rethrow, - Self::EndFilter { result } => Self::EndFilter { result: r(result) }, - Self::EndFinally => Self::EndFinally, - Self::Nop => Self::Nop, - Self::Break => Self::Break, - - // Memory block operations - Self::LocalAlloc { dest, size } => Self::LocalAlloc { - dest: r(dest), - size: r(size), - }, - Self::InitObj { - dest_addr, - value_type, - } => Self::InitObj { - dest_addr: r(dest_addr), - value_type, - }, - Self::LoadObj { - dest, - src_addr, - value_type, - } => Self::LoadObj { - dest: r(dest), - src_addr: r(src_addr), - value_type, - }, - Self::StoreObj { - dest_addr, - value, - value_type, - } => Self::StoreObj { - dest_addr: r(dest_addr), - value: r(value), - value_type, - }, - Self::CopyObj { - dest_addr, - src_addr, - value_type, - } => Self::CopyObj { - dest_addr: r(dest_addr), - src_addr: r(src_addr), - value_type, - }, - Self::CopyBlk { - dest_addr, - src_addr, - size, - } => Self::CopyBlk { - dest_addr: r(dest_addr), - src_addr: r(src_addr), - size: r(size), - }, - Self::InitBlk { - dest_addr, - value, - size, - } => Self::InitBlk { - dest_addr: r(dest_addr), - value: r(value), - size: r(size), - }, - - // Phi operations - Self::Phi { dest, operands } => Self::Phi { - dest: r(dest), - operands: operands.into_iter().map(|(p, v)| (p, r(v))).collect(), - }, - - Self::Constrained { constraint_type } => Self::Constrained { constraint_type }, - Self::Volatile => Self::Volatile, - Self::Unaligned { alignment } => Self::Unaligned { alignment }, - Self::TailPrefix => Self::TailPrefix, - Self::Readonly => Self::Readonly, - } - } - - /// Extracts binary operation information if this is a binary operation. - /// - /// This method provides a uniform view of all binary operations (arithmetic, - /// bitwise, comparison, shifts) for optimization passes that need to handle - /// them generically. - /// - /// # Returns - /// - /// - `Some(BinaryOpInfo)` if this is a binary operation - /// - `None` for all other operations - /// - /// # Supported Operations - /// - /// - Arithmetic: `Add`, `AddOvf`, `Sub`, `SubOvf`, `Mul`, `MulOvf`, `Div`, `Rem` - /// - Bitwise: `And`, `Or`, `Xor` - /// - Shifts: `Shl`, `Shr` - /// - Comparisons: `Ceq`, `Clt`, `Cgt` - /// - /// # Example - /// - /// ```ignore - /// match op.as_binary_op() { - /// Some(info) if info.kind == BinaryOpKind::Add => { - /// // Handle addition - /// } - /// Some(info) => { - /// // Handle other binary ops - /// } - /// None => { - /// // Not a binary operation - /// } - /// } - /// ``` - #[must_use] - pub fn as_binary_op(&self) -> Option { - match *self { - Self::Add { dest, left, right } => Some(BinaryOpInfo { - kind: BinaryOpKind::Add, - dest, - left, - right, - unsigned: false, - }), - Self::AddOvf { - dest, - left, - right, - unsigned, - } => Some(BinaryOpInfo { - kind: BinaryOpKind::AddOvf, - dest, - left, - right, - unsigned, - }), - Self::Sub { dest, left, right } => Some(BinaryOpInfo { - kind: BinaryOpKind::Sub, - dest, - left, - right, - unsigned: false, - }), - Self::SubOvf { - dest, - left, - right, - unsigned, - } => Some(BinaryOpInfo { - kind: BinaryOpKind::SubOvf, - dest, - left, - right, - unsigned, - }), - Self::Mul { dest, left, right } => Some(BinaryOpInfo { - kind: BinaryOpKind::Mul, - dest, - left, - right, - unsigned: false, - }), - Self::MulOvf { - dest, - left, - right, - unsigned, - } => Some(BinaryOpInfo { - kind: BinaryOpKind::MulOvf, - dest, - left, - right, - unsigned, - }), - Self::Div { - dest, - left, - right, - unsigned, - } => Some(BinaryOpInfo { - kind: BinaryOpKind::Div, - dest, - left, - right, - unsigned, - }), - Self::Rem { - dest, - left, - right, - unsigned, - } => Some(BinaryOpInfo { - kind: BinaryOpKind::Rem, - dest, - left, - right, - unsigned, - }), - Self::And { dest, left, right } => Some(BinaryOpInfo { - kind: BinaryOpKind::And, - dest, - left, - right, - unsigned: false, - }), - Self::Or { dest, left, right } => Some(BinaryOpInfo { - kind: BinaryOpKind::Or, - dest, - left, - right, - unsigned: false, - }), - Self::Xor { dest, left, right } => Some(BinaryOpInfo { - kind: BinaryOpKind::Xor, - dest, - left, - right, - unsigned: false, - }), - Self::Shl { - dest, - value, - amount, - } => Some(BinaryOpInfo { - kind: BinaryOpKind::Shl, - dest, - left: value, - right: amount, - unsigned: false, - }), - Self::Shr { - dest, - value, - amount, - unsigned, - } => Some(BinaryOpInfo { - kind: BinaryOpKind::Shr, - dest, - left: value, - right: amount, - unsigned, - }), - Self::Ceq { dest, left, right } => Some(BinaryOpInfo { - kind: BinaryOpKind::Ceq, - dest, - left, - right, - unsigned: false, - }), - Self::Clt { - dest, - left, - right, - unsigned, - } => Some(BinaryOpInfo { - kind: BinaryOpKind::Clt, - dest, - left, - right, - unsigned, - }), - Self::Cgt { - dest, - left, - right, - unsigned, - } => Some(BinaryOpInfo { - kind: BinaryOpKind::Cgt, - dest, - left, - right, - unsigned, - }), - _ => None, - } - } - - /// Extracts unary operation information if this is a unary operation. - /// - /// This method provides a uniform view of all unary operations for - /// optimization passes that need to handle them generically. - /// - /// # Returns - /// - /// - `Some(UnaryOpInfo)` if this is a unary operation - /// - `None` for all other operations - /// - /// # Supported Operations - /// - /// - `Neg`: Negation - /// - `Not`: Bitwise NOT - /// - `Ckfinite`: Check finite - /// - /// # Note - /// - /// `Conv` is not included because it requires additional type information - /// that doesn't fit the simple unary pattern. - /// - /// # Example - /// - /// ```ignore - /// if let Some(info) = op.as_unary_op() { - /// println!("Unary {} on {}", info.kind, info.operand); - /// } - /// ``` - #[must_use] - pub fn as_unary_op(&self) -> Option { - match *self { - Self::Neg { dest, operand } => Some(UnaryOpInfo { - kind: UnaryOpKind::Neg, - dest, - operand, - }), - Self::Not { dest, operand } => Some(UnaryOpInfo { - kind: UnaryOpKind::Not, - dest, - operand, - }), - Self::Ckfinite { dest, operand } => Some(UnaryOpInfo { - kind: UnaryOpKind::Ckfinite, - dest, - operand, - }), - _ => None, - } - } - - /// Returns the stack effect (pops, pushes) for this SSA operation. - /// - /// This represents the net effect on the evaluation stack when the operation - /// is executed, assuming operands have already been loaded. The effect is: - /// - pops: number of values consumed from the stack - /// - pushes: number of values produced to the stack - /// - /// Note: This tracks the operation's own effect, not the loading of operands - /// (which is tracked separately during codegen). - #[must_use] - pub fn stack_effect(&self) -> (u32, u32) { - match self { - // Binary arithmetic, comparisons, and array access - pop 2, push 1 - Self::Add { .. } - | Self::Sub { .. } - | Self::Mul { .. } - | Self::Div { .. } - | Self::Rem { .. } - | Self::AddOvf { .. } - | Self::SubOvf { .. } - | Self::MulOvf { .. } - | Self::And { .. } - | Self::Or { .. } - | Self::Xor { .. } - | Self::Shl { .. } - | Self::Shr { .. } - | Self::Ceq { .. } - | Self::Clt { .. } - | Self::Cgt { .. } - | Self::LoadElement { .. } - | Self::LoadElementAddr { .. } => (2, 1), - - // Control flow - Self::Return { value } => { - if value.is_some() { - (1, 0) // pop return value - } else { - (0, 0) // void return - } - } - // No stack effect (0, 0) - Self::Jump { .. } - | Self::Rethrow - | Self::Leave { .. } - | Self::EndFinally - | Self::Copy { .. } - | Self::Nop - | Self::Break - | Self::Constrained { .. } - | Self::Volatile - | Self::Unaligned { .. } - | Self::TailPrefix - | Self::Readonly - | Self::Phi { .. } => (0, 0), - - // Pop 1, push 0 (1, 0) - Self::Branch { .. } - | Self::Switch { .. } - | Self::Throw { .. } - | Self::EndFilter { .. } - | Self::Pop { .. } - | Self::StoreStaticField { .. } - | Self::InitObj { .. } => (1, 0), - - // Pop 2, push 0 (2, 0) - Self::BranchCmp { .. } - | Self::StoreField { .. } - | Self::StoreIndirect { .. } - | Self::StoreObj { .. } - | Self::CopyObj { .. } => (2, 0), - - // Pop 3, push 0 (3, 0) - Self::StoreElement { .. } | Self::InitBlk { .. } | Self::CopyBlk { .. } => (3, 0), - - // Pop 0, push 1 (0, 1) - Self::LoadStaticField { .. } - | Self::LoadStaticFieldAddr { .. } - | Self::SizeOf { .. } - | Self::LoadToken { .. } - | Self::LoadArg { .. } - | Self::LoadLocal { .. } - | Self::LoadArgAddr { .. } - | Self::LoadLocalAddr { .. } - | Self::LoadFunctionPtr { .. } - | Self::Const { .. } => (0, 1), - - // Pop 1, push 1 (1, 1) - Self::Neg { .. } - | Self::Not { .. } - | Self::Conv { .. } - | Self::Ckfinite { .. } - | Self::LoadField { .. } - | Self::LoadFieldAddr { .. } - | Self::ArrayLength { .. } - | Self::NewArr { .. } - | Self::LoadIndirect { .. } - | Self::LoadObj { .. } - | Self::Box { .. } - | Self::Unbox { .. } - | Self::UnboxAny { .. } - | Self::CastClass { .. } - | Self::IsInst { .. } - | Self::LoadVirtFunctionPtr { .. } - | Self::LocalAlloc { .. } => (1, 1), - - // Call operations - stack effect depends on args and return type - Self::Call { dest, args, .. } | Self::CallVirt { dest, args, .. } => { - // args.len() will never exceed u32 for CIL methods - #[allow(clippy::cast_possible_truncation)] - let pops = args.len() as u32; - let pushes = u32::from(dest.is_some()); - (pops, pushes) - } - Self::CallIndirect { dest, args, .. } => { - // Indirect call pops args + function pointer - // args.len() will never exceed u32 for CIL methods - #[allow(clippy::cast_possible_truncation)] - let pops = (args.len() as u32).saturating_add(1); - let pushes = u32::from(dest.is_some()); - (pops, pushes) - } - Self::NewObj { args, .. } => { - // newobj pops constructor args, always pushes new instance - // args.len() will never exceed u32 for CIL methods - #[allow(clippy::cast_possible_truncation)] - let pops = args.len() as u32; - (pops, 1) - } - } - } - - /// Tries to infer the result type of this SSA operation. - /// - /// Returns `Some(type)` for operations whose result type can be determined - /// structurally (constants, conversions, comparisons, arithmetic, etc.), - /// or `None` for operations that don't produce values or need metadata. - #[must_use] - pub fn infer_result_type(&self) -> Option { - match self { - // Constants - infer type from the constant value - Self::Const { value, .. } => Some(match value { - ConstValue::I8(_) => SsaType::I8, - ConstValue::I16(_) => SsaType::I16, - ConstValue::I32(_) => SsaType::I32, - ConstValue::I64(_) => SsaType::I64, - ConstValue::U8(_) => SsaType::U8, - ConstValue::U16(_) => SsaType::U16, - ConstValue::U32(_) => SsaType::U32, - ConstValue::U64(_) => SsaType::U64, - ConstValue::NativeInt(_) => SsaType::NativeInt, - ConstValue::NativeUInt(_) => SsaType::NativeUInt, - ConstValue::F32(_) => SsaType::F32, - ConstValue::F64(_) => SsaType::F64, - ConstValue::String(_) | ConstValue::DecryptedString(_) => SsaType::String, - ConstValue::DecryptedArray { .. } => SsaType::Object, - ConstValue::Null => SsaType::Null, - ConstValue::True | ConstValue::False => SsaType::Bool, - ConstValue::Type(_) | ConstValue::MethodHandle(_) | ConstValue::FieldHandle(_) => { - SsaType::Object - } - }), - // Type conversions have explicit target type - Self::Conv { target, .. } => Some(target.clone()), - // Comparisons produce bool (represented as I32 on CIL stack, but - // semantically boolean for type inference purposes) - Self::Ceq { .. } | Self::Clt { .. } | Self::Cgt { .. } => Some(SsaType::Bool), - // Arithmetic/bitwise ops and SizeOf produce I32 per CIL spec - Self::Add { .. } - | Self::Sub { .. } - | Self::Mul { .. } - | Self::Div { .. } - | Self::Rem { .. } - | Self::And { .. } - | Self::Or { .. } - | Self::Xor { .. } - | Self::Shl { .. } - | Self::Shr { .. } - | Self::Neg { .. } - | Self::Not { .. } - | Self::AddOvf { .. } - | Self::SubOvf { .. } - | Self::MulOvf { .. } - | Self::SizeOf { .. } => Some(SsaType::I32), - // UnboxAny/LoadObj — use the embedded value_type - Self::UnboxAny { value_type, .. } | Self::LoadObj { value_type, .. } => { - Some(SsaType::ValueType(*value_type)) - } - // Context-dependent ops — return None; resolved from - // SsaInstruction.result_type() which is set during SSA construction - // with full TypeContext metadata. - Self::LoadField { .. } - | Self::LoadStaticField { .. } - | Self::Call { dest: Some(_), .. } - | Self::CallVirt { dest: Some(_), .. } - | Self::CallIndirect { dest: Some(_), .. } - | Self::LoadArg { .. } - | Self::LoadLocal { .. } => None, - // Box, NewObj, NewArr, CastClass/IsInst produce object references - Self::Box { .. } - | Self::NewObj { .. } - | Self::NewArr { .. } - | Self::CastClass { .. } - | Self::IsInst { .. } => Some(SsaType::Object), - // Array length and LocalAlloc return native int - Self::ArrayLength { .. } | Self::LocalAlloc { .. } => Some(SsaType::NativeInt), - // Ckfinite operates on F64 stack type - Self::Ckfinite { .. } => Some(SsaType::F64), - // Function pointer loads produce native int - Self::LoadFunctionPtr { .. } | Self::LoadVirtFunctionPtr { .. } => { - Some(SsaType::NativeInt) - } - // Load element — type embedded in the op - Self::LoadElement { elem_type, .. } => Some(elem_type.clone()), - // Load indirect — type embedded in the op - Self::LoadIndirect { value_type, .. } => Some(value_type.clone()), - // Load token — resolving to the correct corelib - // `valuetype [mscorlib]System.Runtime*Handle` requires metadata - // access (see `TypeProvider::runtime_handle_type`). This op - // inference is assembly-free, so return `None` and let the - // caller fall back to the variable's declared type set during - // SSA construction. - Self::LoadToken { .. } => None, - // Unbox produces ByRef to the embedded value type - Self::Unbox { value_type, .. } => { - Some(SsaType::ByRef(Box::new(SsaType::ValueType(*value_type)))) - } - // LoadElementAddr produces ByRef to the embedded element type - Self::LoadElementAddr { elem_type, .. } => { - Some(SsaType::ByRef(Box::new(SsaType::Class(*elem_type)))) - } - // Context-dependent address loads — return None; resolved from - // SsaInstruction.result_type() - Self::LoadFieldAddr { .. } - | Self::LoadStaticFieldAddr { .. } - | Self::LoadArgAddr { .. } - | Self::LoadLocalAddr { .. } => None, - // Operations that don't produce values or have complex types - _ => None, - } - } + fn referenced_token(&self) -> Option; } -impl fmt::Display for SsaOp { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { +impl SsaOpCilExt for AnalyssaSsaOp { + fn referenced_token(&self) -> Option { match self { - Self::Const { dest, value } => write!(f, "{dest} = {value}"), - Self::Add { dest, left, right } => write!(f, "{dest} = add {left}, {right}"), - Self::AddOvf { - dest, - left, - right, - unsigned, - } => { - let suffix = if *unsigned { ".un" } else { "" }; - write!(f, "{dest} = add.ovf{suffix} {left}, {right}") - } - Self::Sub { dest, left, right } => write!(f, "{dest} = sub {left}, {right}"), - Self::SubOvf { - dest, - left, - right, - unsigned, - } => { - let suffix = if *unsigned { ".un" } else { "" }; - write!(f, "{dest} = sub.ovf{suffix} {left}, {right}") - } - Self::Mul { dest, left, right } => write!(f, "{dest} = mul {left}, {right}"), - Self::MulOvf { - dest, - left, - right, - unsigned, - } => { - let suffix = if *unsigned { ".un" } else { "" }; - write!(f, "{dest} = mul.ovf{suffix} {left}, {right}") - } - Self::Div { - dest, - left, - right, - unsigned, - } => { - let suffix = if *unsigned { ".un" } else { "" }; - write!(f, "{dest} = div{suffix} {left}, {right}") - } - Self::Rem { - dest, - left, - right, - unsigned, - } => { - let suffix = if *unsigned { ".un" } else { "" }; - write!(f, "{dest} = rem{suffix} {left}, {right}") - } - Self::Neg { dest, operand } => write!(f, "{dest} = neg {operand}"), - Self::And { dest, left, right } => write!(f, "{dest} = and {left}, {right}"), - Self::Or { dest, left, right } => write!(f, "{dest} = or {left}, {right}"), - Self::Xor { dest, left, right } => write!(f, "{dest} = xor {left}, {right}"), - Self::Not { dest, operand } => write!(f, "{dest} = not {operand}"), - Self::Shl { - dest, - value, - amount, - } => write!(f, "{dest} = shl {value}, {amount}"), - Self::Shr { - dest, - value, - amount, - unsigned, - } => { - let suffix = if *unsigned { ".un" } else { "" }; - write!(f, "{dest} = shr{suffix} {value}, {amount}") - } - Self::Ceq { dest, left, right } => write!(f, "{dest} = ceq {left}, {right}"), - Self::Clt { - dest, - left, - right, - unsigned, - } => { - let suffix = if *unsigned { ".un" } else { "" }; - write!(f, "{dest} = clt{suffix} {left}, {right}") - } - Self::Cgt { - dest, - left, - right, - unsigned, - } => { - let suffix = if *unsigned { ".un" } else { "" }; - write!(f, "{dest} = cgt{suffix} {left}, {right}") - } - Self::Conv { - dest, - operand, - target, - .. - } => write!(f, "{dest} = conv.{target} {operand}"), - Self::Jump { target } => write!(f, "jump B{target}"), - Self::Branch { - condition, - true_target, - false_target, - } => write!(f, "branch {condition}, B{true_target}, B{false_target}"), - Self::BranchCmp { - left, - right, - cmp, - unsigned, - true_target, - false_target, - } => { - let suffix = if *unsigned { ".un" } else { "" }; - write!( - f, - "branchcmp{suffix} {left} {cmp} {right}, B{true_target}, B{false_target}" - ) - } - Self::Switch { - value, - targets, - default, - } => { - write!(f, "switch {value}, [")?; - for (i, t) in targets.iter().enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "B{t}")?; - } - write!(f, "], B{default}") - } - Self::Return { value: Some(v) } => write!(f, "ret {v}"), - Self::Return { value: None } => write!(f, "ret"), - Self::LoadField { - dest, - object, - field, - } => { - write!(f, "{dest} = ldfld {field}, {object}") - } - Self::StoreField { - object, - field, - value, - } => write!(f, "stfld {field}, {object}, {value}"), - Self::LoadStaticField { dest, field } => write!(f, "{dest} = ldsfld {field}"), - Self::StoreStaticField { field, value } => write!(f, "stsfld {field}, {value}"), - Self::LoadFieldAddr { - dest, - object, - field, - } => { - write!(f, "{dest} = ldflda {field}, {object}") - } - Self::LoadStaticFieldAddr { dest, field } => write!(f, "{dest} = ldsflda {field}"), - Self::LoadElement { - dest, - array, - index, - elem_type, - } => write!(f, "{dest} = ldelem.{elem_type} {array}[{index}]"), - Self::StoreElement { - array, - index, - value, - elem_type, - } => write!(f, "stelem.{elem_type} {array}[{index}], {value}"), - Self::LoadElementAddr { - dest, array, index, .. - } => write!(f, "{dest} = ldelema {array}[{index}]"), - Self::ArrayLength { dest, array } => write!(f, "{dest} = ldlen {array}"), - Self::LoadIndirect { - dest, - addr, - value_type, - } => write!(f, "{dest} = ldind.{value_type} {addr}"), - Self::StoreIndirect { - addr, - value, - value_type, - } => write!(f, "stind.{value_type} {addr}, {value}"), - Self::NewObj { dest, ctor, args } => { - write!(f, "{dest} = newobj {ctor}(")?; - for (i, arg) in args.iter().enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{arg}")?; - } - write!(f, ")") - } - Self::NewArr { - dest, - elem_type, - length, - } => write!(f, "{dest} = newarr {elem_type}[{length}]"), - Self::CastClass { - dest, - object, - target_type, - } => write!(f, "{dest} = castclass {target_type}, {object}"), - Self::IsInst { - dest, - object, - target_type, - } => write!(f, "{dest} = isinst {target_type}, {object}"), - Self::Box { - dest, - value, - value_type, - } => write!(f, "{dest} = box {value_type}, {value}"), - Self::Unbox { - dest, - object, - value_type, - } => write!(f, "{dest} = unbox {value_type}, {object}"), - Self::UnboxAny { - dest, - object, - value_type, - } => write!(f, "{dest} = unbox.any {value_type}, {object}"), - Self::SizeOf { dest, value_type } => write!(f, "{dest} = sizeof {value_type}"), - Self::LoadToken { dest, token } => write!(f, "{dest} = ldtoken {token}"), - Self::Call { dest, method, args } => { - if let Some(d) = dest { - write!(f, "{d} = ")?; - } - write!(f, "call {method}(")?; - for (i, arg) in args.iter().enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{arg}")?; - } - write!(f, ")") - } - Self::CallVirt { dest, method, args } => { - if let Some(d) = dest { - write!(f, "{d} = ")?; - } - write!(f, "callvirt {method}(")?; - for (i, arg) in args.iter().enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{arg}")?; - } - write!(f, ")") - } - Self::CallIndirect { - dest, fptr, args, .. - } => { - if let Some(d) = dest { - write!(f, "{d} = ")?; - } - write!(f, "calli {fptr}(")?; - for (i, arg) in args.iter().enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{arg}")?; - } - write!(f, ")") - } - Self::LoadFunctionPtr { dest, method } => write!(f, "{dest} = ldftn {method}"), - Self::LoadVirtFunctionPtr { - dest, - object, - method, - } => write!(f, "{dest} = ldvirtftn {method}, {object}"), - Self::LoadArg { dest, arg_index } => write!(f, "{dest} = ldarg {arg_index}"), - Self::LoadLocal { dest, local_index } => write!(f, "{dest} = ldloc {local_index}"), - Self::LoadArgAddr { dest, arg_index } => write!(f, "{dest} = ldarga {arg_index}"), - Self::LoadLocalAddr { dest, local_index } => { - write!(f, "{dest} = ldloca {local_index}") - } - Self::Copy { dest, src } => write!(f, "{dest} = {src}"), - Self::Pop { value } => write!(f, "pop {value}"), - Self::Throw { exception } => write!(f, "throw {exception}"), - Self::Rethrow => write!(f, "rethrow"), - Self::EndFinally => write!(f, "endfinally"), - Self::EndFilter { result } => write!(f, "endfilter {result}"), - Self::Leave { target } => write!(f, "leave B{target}"), - Self::InitBlk { - dest_addr, - value, - size, - } => write!(f, "initblk {dest_addr}, {value}, {size}"), - Self::CopyBlk { - dest_addr, - src_addr, - size, - } => write!(f, "cpblk {dest_addr}, {src_addr}, {size}"), - Self::InitObj { - dest_addr, - value_type, - } => write!(f, "initobj {value_type}, {dest_addr}"), - Self::CopyObj { - dest_addr, - src_addr, - value_type, - } => write!(f, "cpobj {value_type}, {dest_addr}, {src_addr}"), - Self::LoadObj { - dest, - src_addr, - value_type, - } => write!(f, "{dest} = ldobj {value_type}, {src_addr}"), - Self::StoreObj { - dest_addr, - value, - value_type, - } => write!(f, "stobj {value_type}, {dest_addr}, {value}"), - Self::LocalAlloc { dest, size } => write!(f, "{dest} = localloc {size}"), - Self::Constrained { constraint_type } => { - write!(f, "constrained. {constraint_type}") - } - Self::Volatile => write!(f, "volatile."), - Self::Unaligned { alignment } => write!(f, "unaligned. {alignment}"), - Self::TailPrefix => write!(f, "tail."), - Self::Readonly => write!(f, "readonly."), - Self::Ckfinite { dest, operand } => write!(f, "{dest} = ckfinite {operand}"), - Self::Nop => write!(f, "nop"), - Self::Break => write!(f, "break"), - Self::Phi { dest, operands } => { - write!(f, "{dest} = phi(")?; - for (i, (block, var)) in operands.iter().enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "B{block}: {var}")?; - } - write!(f, ")") - } + AnalyssaSsaOp::Call { method, .. } + | AnalyssaSsaOp::CallVirt { method, .. } + | AnalyssaSsaOp::LoadFunctionPtr { method, .. } + | AnalyssaSsaOp::LoadVirtFunctionPtr { method, .. } => Some(method.token()), + AnalyssaSsaOp::NewObj { ctor, .. } => Some(ctor.token()), + AnalyssaSsaOp::LoadField { field, .. } + | AnalyssaSsaOp::StoreField { field, .. } + | AnalyssaSsaOp::LoadFieldAddr { field, .. } + | AnalyssaSsaOp::LoadStaticField { field, .. } + | AnalyssaSsaOp::StoreStaticField { field, .. } + | AnalyssaSsaOp::LoadStaticFieldAddr { field, .. } => Some(field.token()), + AnalyssaSsaOp::Box { value_type, .. } + | AnalyssaSsaOp::Unbox { value_type, .. } + | AnalyssaSsaOp::UnboxAny { value_type, .. } + | AnalyssaSsaOp::InitObj { value_type, .. } + | AnalyssaSsaOp::SizeOf { value_type, .. } + | AnalyssaSsaOp::CopyObj { value_type, .. } + | AnalyssaSsaOp::LoadObj { value_type, .. } + | AnalyssaSsaOp::StoreObj { value_type, .. } => Some(value_type.token()), + AnalyssaSsaOp::IsInst { target_type, .. } + | AnalyssaSsaOp::CastClass { target_type, .. } => Some(target_type.token()), + AnalyssaSsaOp::NewArr { elem_type, .. } + | AnalyssaSsaOp::LoadElementAddr { elem_type, .. } => Some(elem_type.token()), + AnalyssaSsaOp::LoadToken { token, .. } => Some(token.token()), + AnalyssaSsaOp::Constrained { constraint_type } => Some(constraint_type.token()), + _ => None, } } } @@ -3040,20 +65,31 @@ impl fmt::Display for SsaOp { mod tests { use crate::{ analysis::ssa::{ - ops::{BinaryOpKind, SsaOp, UnaryOpKind}, + ops::{BinaryOpKind, UnaryOpKind}, + target::CilTarget, types::{FieldRef, MethodRef}, - value::ConstValue, + value::ConstValue as RawConstValue, SsaVarId, }, metadata::token::Token, }; + // Lock the type parameter to CilTarget for the test module so unit-only + // variant constructions like `SsaOp::Add { ... }` infer cleanly. + type SsaOp = super::SsaOp; + type ConstValue = RawConstValue; + #[test] fn test_dest_extraction() { let dest = SsaVarId::from_index(0); let left = SsaVarId::from_index(1); let right = SsaVarId::from_index(2); - let op = SsaOp::Add { dest, left, right }; + let op = SsaOp::Add { + dest, + left, + right, + flags: None, + }; assert_eq!(op.dest(), Some(dest)); let op = SsaOp::Jump { target: 1 }; @@ -3080,7 +116,12 @@ mod tests { let dest = SsaVarId::from_index(0); let left = SsaVarId::from_index(1); let right = SsaVarId::from_index(2); - let op = SsaOp::Add { dest, left, right }; + let op = SsaOp::Add { + dest, + left, + right, + flags: None, + }; assert_eq!(op.uses(), vec![left, right]); let const_dest = SsaVarId::from_index(3); @@ -3119,7 +160,13 @@ mod tests { assert!(SsaOp::Throw { exception: exc }.is_terminator()); assert!(!SsaOp::Nop.is_terminator()); - assert!(!SsaOp::Add { dest, left, right }.is_terminator()); + assert!(!SsaOp::Add { + dest, + left, + right, + flags: None + } + .is_terminator()); } #[test] @@ -3131,7 +178,13 @@ mod tests { let object = SsaVarId::from_index(4); let value = SsaVarId::from_index(5); - assert!(SsaOp::Add { dest, left, right }.is_pure()); + assert!(SsaOp::Add { + dest, + left, + right, + flags: None + } + .is_pure()); assert!(SsaOp::Const { dest: const_dest, value: ConstValue::I32(42) @@ -3160,6 +213,7 @@ mod tests { dest: SsaVarId::from_index(2), left: SsaVarId::from_index(0), right: SsaVarId::from_index(1), + flags: None, }; assert_eq!(format!("{op}"), "v2 = add v0, v1"); @@ -3231,7 +285,12 @@ mod tests { assert!(op.successors().is_empty()); // Non-terminators have no successors - let op = SsaOp::Add { dest, left, right }; + let op = SsaOp::Add { + dest, + left, + right, + flags: None, + }; assert!(op.successors().is_empty()); let op = SsaOp::Nop; @@ -3245,7 +304,12 @@ mod tests { let right = SsaVarId::from_index(2); // Add is a binary operation - let op = SsaOp::Add { dest, left, right }; + let op = SsaOp::Add { + dest, + left, + right, + flags: None, + }; let info = op.as_binary_op().expect("Add should be binary op"); assert_eq!(info.kind, BinaryOpKind::Add); assert_eq!(info.dest, dest); @@ -3259,6 +323,7 @@ mod tests { left, right, unsigned: true, + flags: None, }; let info = op.as_binary_op().expect("Div should be binary op"); assert_eq!(info.kind, BinaryOpKind::Div); @@ -3271,6 +336,7 @@ mod tests { dest, value, amount, + flags: None, }; let info = op.as_binary_op().expect("Shl should be binary op"); assert_eq!(info.kind, BinaryOpKind::Shl); @@ -3293,7 +359,8 @@ mod tests { assert!(SsaOp::Jump { target: 1 }.as_binary_op().is_none()); assert!(SsaOp::Neg { dest, - operand: left + operand: left, + flags: None, } .as_binary_op() .is_none()); @@ -3311,14 +378,22 @@ mod tests { let operand = SsaVarId::from_index(1); // Neg is a unary operation - let op = SsaOp::Neg { dest, operand }; + let op = SsaOp::Neg { + dest, + operand, + flags: None, + }; let info = op.as_unary_op().expect("Neg should be unary op"); assert_eq!(info.kind, UnaryOpKind::Neg); assert_eq!(info.dest, dest); assert_eq!(info.operand, operand); // Not is a unary operation - let op = SsaOp::Not { dest, operand }; + let op = SsaOp::Not { + dest, + operand, + flags: None, + }; let info = op.as_unary_op().expect("Not should be unary op"); assert_eq!(info.kind, UnaryOpKind::Not); @@ -3333,7 +408,14 @@ mod tests { let left = SsaVarId::from_index(2); let right = SsaVarId::from_index(3); - assert!(SsaOp::Add { dest, left, right }.as_unary_op().is_none()); + assert!(SsaOp::Add { + dest, + left, + right, + flags: None + } + .as_unary_op() + .is_none()); assert!(SsaOp::Const { dest, diff --git a/dotscope/src/analysis/ssa/patterns.rs b/dotscope/src/analysis/ssa/patterns.rs deleted file mode 100644 index f6987868..00000000 --- a/dotscope/src/analysis/ssa/patterns.rs +++ /dev/null @@ -1,672 +0,0 @@ -//! Generic pattern detection for SSA-based analysis. -//! -//! This module provides pattern detection for common obfuscation constructs -//! without being tied to any specific obfuscator. The patterns detected include: -//! -//! - **Dispatcher patterns**: Switch-based state machines used for control flow flattening -//! - **Source blocks**: Blocks that eventually reach a dispatcher -//! - **Opaque predicates**: Conditional branches with statically determinable outcomes -//! -//! # Design Philosophy -//! -//! The pattern detection is designed to be: -//! -//! - **Generic**: Not tied to specific obfuscators (ConfuserEx, etc.) -//! - **SSA-based**: Leverages SSA form for accurate data flow analysis -//! - **Composable**: Patterns can be combined for more sophisticated analysis -//! -//! # Architecture -//! -//! Pattern detection works in phases: -//! -//! 1. **Structural analysis**: Identify potential patterns based on CFG structure -//! 2. **Data flow analysis**: Use tracked evaluation to understand value flow -//! 3. **Validation**: Verify patterns meet expected criteria -//! -//! # Example -//! -//! ```rust,ignore -//! use dotscope::analysis::{PatternDetector, SsaFunction}; -//! use dotscope::metadata::typesystem::PointerSize; -//! -//! let detector = PatternDetector::new(&ssa, PointerSize::Bit32); -//! -//! // Find dispatcher patterns -//! let dispatchers = detector.find_dispatchers(); -//! -//! for dispatcher in &dispatchers { -//! println!("Dispatcher at block {}", dispatcher.block); -//! -//! // Find all source blocks -//! let sources = detector.find_sources(dispatcher); -//! for source in &sources { -//! println!(" Source block {} -> case {}", source.block, source.target_case); -//! } -//! } -//! ``` - -use std::collections::HashMap; - -use crate::{ - analysis::ssa::{ - evaluator::SsaEvaluator, symbolic::SymbolicExpr, ConstValue, SsaFunction, SsaOp, SsaVarId, - }, - metadata::typesystem::PointerSize, - utils::BitSet, -}; - -/// Detects common obfuscation patterns in SSA form. -/// -/// The detector analyzes the SSA function to identify structural patterns -/// commonly used in obfuscation, such as: -/// -/// - Control flow flattening dispatchers -/// - Opaque predicates -/// - Dead code regions -#[derive(Debug)] -pub struct PatternDetector<'a> { - ssa: &'a SsaFunction, - pointer_size: PointerSize, -} - -impl<'a> PatternDetector<'a> { - /// Creates a new pattern detector for the given SSA function. - #[must_use] - pub fn new(ssa: &'a SsaFunction, pointer_size: PointerSize) -> Self { - Self { ssa, pointer_size } - } - - /// Returns the underlying SSA function. - #[must_use] - pub fn ssa(&self) -> &SsaFunction { - self.ssa - } - - // Dispatcher Detection - - /// Finds all potential dispatcher patterns in the function. - /// - /// A dispatcher is characterized by: - /// - A switch instruction with multiple targets - /// - Targets that eventually loop back to the dispatcher - /// - A computed switch index (not just a simple variable) - /// - /// # Returns - /// - /// A vector of detected dispatcher patterns, sorted by block index. - #[must_use] - pub fn find_dispatchers(&self) -> Vec { - let mut dispatchers: Vec<_> = (0..self.ssa.block_count()) - .filter_map(|block_idx| self.analyze_potential_dispatcher(block_idx)) - .collect(); - dispatchers.sort_by_key(|d| d.block); - dispatchers - } - - /// Analyzes a block to determine if it's a dispatcher. - fn analyze_potential_dispatcher(&self, block_idx: usize) -> Option { - let block = self.ssa.block(block_idx)?; - - // Look for Switch instruction at the end - let terminator = block.terminator()?; - let (switch_var, targets, default) = match terminator.op() { - SsaOp::Switch { - value, - targets, - default, - } => (*value, targets.clone(), *default), - _ => return None, - }; - - // Must have multiple targets to be a dispatcher - if targets.len() < 2 { - return None; - } - - // Check if any targets loop back to this block - let has_loopback = targets - .iter() - .any(|&target| self.reaches_block(target, block_idx)) - || self.reaches_block(default, block_idx); - - if !has_loopback { - return None; - } - - // Try to build the dispatch expression - let dispatch_expr = self.build_dispatch_expression(block_idx, switch_var); - - // Identify state variables (inputs to the dispatch computation) - let state_vars = dispatch_expr - .as_ref() - .map(|e| e.variables().into_iter().collect()) - .unwrap_or_default(); - - Some(DispatcherPattern { - block: block_idx, - switch_var, - targets, - default, - dispatch_expr, - state_vars, - }) - } - - /// Checks if there's a path from `from_block` that reaches `target_block`. - /// - /// Uses BFS with a depth limit to avoid infinite loops. - fn reaches_block(&self, from_block: usize, target_block: usize) -> bool { - let block_count = self.ssa.block_count().max(1); - let mut visited = BitSet::new(block_count); - let mut queue = vec![from_block]; - let max_depth: u32 = 50; // Prevent infinite loops - let mut depth: u32 = 0; - - while !queue.is_empty() && depth < max_depth { - let mut next_queue = Vec::new(); - - for block_idx in queue { - if block_idx == target_block { - return true; - } - - if block_idx >= block_count || !visited.insert(block_idx) { - continue; - } - - // Get successors - if let Some(successors) = self.block_successors(block_idx) { - next_queue.extend(successors); - } - } - - queue = next_queue; - depth = depth.saturating_add(1); - } - - false - } - - /// Gets the successor blocks of a given block. - fn block_successors(&self, block_idx: usize) -> Option> { - let block = self.ssa.block(block_idx)?; - block.terminator()?; - Some(block.successors()) - } - - /// Builds a symbolic expression for how the switch index is computed. - fn build_dispatch_expression( - &self, - block_idx: usize, - switch_var: SsaVarId, - ) -> Option { - let mut eval = SsaEvaluator::new(self.ssa, self.pointer_size); - - // Mark all phi results as symbolic (they come from different paths) - if let Some(block) = self.ssa.block(block_idx) { - for phi in block.phi_nodes() { - let name = format!("phi_{}", phi.result().index()); - eval.set_symbolic(phi.result(), name); - } - } - - // Evaluate the block - eval.evaluate_block(block_idx); - - // Get the switch variable's value as a symbolic expression - eval.get(switch_var).cloned() - } - - // Source Block Detection - - /// Finds all source blocks for a given dispatcher. - /// - /// A source block is one that: - /// - Sets a state value that determines which case is taken - /// - Eventually reaches the dispatcher (directly or through intermediates) - /// - /// # Arguments - /// - /// * `dispatcher` - The dispatcher pattern to find sources for. - /// - /// # Returns - /// - /// A vector of source blocks with their analyzed state values. - #[must_use] - pub fn find_sources(&self, dispatcher: &DispatcherPattern) -> Vec { - // Find all blocks that reach the dispatcher - let reaching_blocks = self.find_reaching_blocks(dispatcher.block); - - reaching_blocks - .iter() - .filter(|&block_idx| block_idx != dispatcher.block) - .filter_map(|block_idx| self.analyze_source_block(block_idx, dispatcher)) - .collect() - } - - /// Finds all blocks that can reach the dispatcher block. - fn find_reaching_blocks(&self, dispatcher_block: usize) -> BitSet { - let block_count = self.ssa.block_count().max(1); - let mut reaching = BitSet::new(block_count); - - // Build reverse CFG (predecessors) - let mut predecessors: HashMap> = HashMap::new(); - for block_idx in 0..self.ssa.block_count() { - if let Some(succs) = self.block_successors(block_idx) { - for succ in succs { - predecessors.entry(succ).or_default().push(block_idx); - } - } - } - - // BFS backwards from dispatcher - let mut queue = vec![dispatcher_block]; - while let Some(block_idx) = queue.pop() { - if block_idx >= block_count || !reaching.insert(block_idx) { - continue; - } - - if let Some(preds) = predecessors.get(&block_idx) { - queue.extend(preds.iter().copied()); - } - } - - reaching - } - - /// Analyzes a block to determine if it's a source for the dispatcher. - fn analyze_source_block( - &self, - block_idx: usize, - dispatcher: &DispatcherPattern, - ) -> Option { - let block = self.ssa.block(block_idx)?; - - // Check if this block has a jump or branch that leads to dispatcher - let terminator = block.terminator()?; - let (leads_to_dispatcher, is_conditional) = match terminator.op() { - SsaOp::Jump { target } => (*target == dispatcher.block, false), - SsaOp::Branch { - true_target, - false_target, - .. - } => { - let leads = *true_target == dispatcher.block || *false_target == dispatcher.block; - (leads, true) - } - _ => return None, - }; - - if !leads_to_dispatcher { - return None; - } - - // Try to determine what state value this block sets - let state_value = self.compute_state_value(block_idx, dispatcher); - - // Try to determine which case this leads to - let target_case = self.compute_target_case(state_value.as_ref(), dispatcher); - - Some(SourceBlock { - block: block_idx, - state_value, - target_case, - is_conditional, - }) - } - - /// Computes the state value set by a block. - fn compute_state_value( - &self, - block_idx: usize, - dispatcher: &DispatcherPattern, - ) -> Option { - let mut eval = SsaEvaluator::new(self.ssa, self.pointer_size); - - // Set phi nodes in this block as symbolic - if let Some(block) = self.ssa.block(block_idx) { - for phi in block.phi_nodes() { - let name = format!("phi_{}", phi.result().index()); - eval.set_symbolic(phi.result(), name); - } - } - - // Evaluate the block - eval.evaluate_block(block_idx); - - // Get the value of the state variable - // We need to find which variable in this block contributes to the dispatcher's state - if let Some(state_var) = dispatcher.state_vars.first() { - // The state var in dispatcher comes from a phi, we need to find - // what value this block provides to that phi - if let Some(disp_block) = self.ssa.block(dispatcher.block) { - for phi in disp_block.phi_nodes() { - if phi.result() == *state_var { - // Find operand from our block - for operand in phi.operands() { - if operand.predecessor() == block_idx { - return eval.get(operand.value()).cloned(); - } - } - } - } - } - } - - None - } - - /// Computes which case a state value leads to. - fn compute_target_case( - &self, - state_value: Option<&SymbolicExpr>, - dispatcher: &DispatcherPattern, - ) -> Option { - // If we have a concrete state value and a dispatch expression, - // we can compute the target case - let concrete_state = state_value.and_then(SymbolicExpr::as_constant)?; - let dispatch_expr = dispatcher.dispatch_expr.as_ref()?; - - // Build bindings for evaluation - // First, create owned names so we can borrow them - let state_var_names: Vec = dispatcher - .state_vars - .iter() - .map(|v| format!("phi_{}", v.index())) - .collect(); - - let mut bindings: HashMap<&str, ConstValue> = HashMap::new(); - for name in &state_var_names { - bindings.insert(name.as_str(), concrete_state.clone()); - } - - // Also try with just "state" as a generic name - bindings.insert("state", concrete_state.clone()); - - // Evaluate the dispatch expression - let case_idx = dispatch_expr.evaluate_named(&bindings, self.pointer_size)?; - - // Convert to usize and check bounds - let idx = case_idx.as_i64().and_then(|v| usize::try_from(v).ok())?; - if idx < dispatcher.targets.len() { - Some(idx) - } else { - None // Out of bounds -> default case - } - } - - // Opaque Predicate Detection - - /// Finds potential opaque predicates in the function. - /// - /// An opaque predicate is a conditional branch where the condition - /// can be statically determined to always be true or always false. - #[must_use] - pub fn find_opaque_predicates(&self) -> Vec { - (0..self.ssa.block_count()) - .filter_map(|block_idx| self.analyze_opaque_predicate(block_idx)) - .collect() - } - - /// Analyzes a block to see if it contains an opaque predicate. - fn analyze_opaque_predicate(&self, block_idx: usize) -> Option { - let block = self.ssa.block(block_idx)?; - - // Look for Branch instruction - let terminator = block.terminator()?; - let (condition_var, true_target, false_target) = match terminator.op() { - SsaOp::Branch { - condition, - true_target, - false_target, - } => (*condition, *true_target, *false_target), - _ => return None, - }; - - // Evaluate the block to see if condition is determinable - let mut eval = SsaEvaluator::new(self.ssa, self.pointer_size); - - // Set phi nodes as symbolic - for phi in block.phi_nodes() { - let name = format!("phi_{}", phi.result().index()); - eval.set_symbolic(phi.result(), name); - } - - eval.evaluate_block(block_idx); - - let condition_value = eval.get(condition_var); - - let resolution = match condition_value { - Some(expr) if expr.is_constant() => { - if expr.as_constant().is_some_and(ConstValue::is_zero) { - PredicateResolution::AlwaysFalse - } else { - PredicateResolution::AlwaysTrue - } - } - Some(expr) => PredicateResolution::Symbolic(expr.clone()), - None => PredicateResolution::Unknown, - }; - - // Only report if it's always true or always false (actual opaque predicate) - if matches!( - resolution, - PredicateResolution::AlwaysTrue | PredicateResolution::AlwaysFalse - ) { - Some(OpaquePredicatePattern { - block: block_idx, - condition_var, - true_target, - false_target, - resolution, - }) - } else { - None - } - } -} - -// Pattern Types - -/// A detected dispatcher pattern (switch-based state machine). -/// -/// Dispatchers are the core of control flow flattening. They use a computed -/// switch index to dispatch to different case blocks, with state variables -/// controlling the execution flow. -#[derive(Debug, Clone)] -pub struct DispatcherPattern { - /// Block index containing the switch instruction. - pub block: usize, - - /// The SSA variable used as the switch condition. - pub switch_var: SsaVarId, - - /// Target blocks for each switch case (indexed by case value). - pub targets: Vec, - - /// Default target when case is out of range. - pub default: usize, - - /// The symbolic expression computing the switch index, if determinable. - /// This is typically something like `(state ^ const) % num_cases`. - pub dispatch_expr: Option, - - /// State variables that control the dispatch (phi node results that - /// feed into the dispatch expression). - pub state_vars: Vec, -} - -impl DispatcherPattern { - /// Returns the number of cases in this dispatcher. - #[must_use] - pub fn case_count(&self) -> usize { - self.targets.len() - } - - /// Gets the target block for a specific case index. - #[must_use] - pub fn target_for_case(&self, case_idx: usize) -> usize { - self.targets.get(case_idx).copied().unwrap_or(self.default) - } -} - -/// A source block that feeds into a dispatcher. -/// -/// Source blocks set state values that determine which case the dispatcher -/// will execute next. -#[derive(Debug, Clone)] -pub struct SourceBlock { - /// Block index. - pub block: usize, - - /// The state value this block sets. - /// `None` means the value could not be determined (unknown). - pub state_value: Option, - - /// The target case this state value leads to, if determinable. - pub target_case: Option, - - /// Whether this is a conditional source (branch vs jump). - pub is_conditional: bool, -} - -/// A detected opaque predicate. -/// -/// Opaque predicates are conditionals that always evaluate to the same result, -/// used to confuse analysis and add fake branches. -#[derive(Debug, Clone)] -pub struct OpaquePredicatePattern { - /// Block index containing the branch. - pub block: usize, - - /// The SSA variable holding the condition. - pub condition_var: SsaVarId, - - /// Target if condition is true. - pub true_target: usize, - - /// Target if condition is false. - pub false_target: usize, - - /// How the predicate resolves. - pub resolution: PredicateResolution, -} - -impl OpaquePredicatePattern { - /// Returns the target that will always be taken. - #[must_use] - pub fn actual_target(&self) -> Option { - match self.resolution { - PredicateResolution::AlwaysTrue => Some(self.true_target), - PredicateResolution::AlwaysFalse => Some(self.false_target), - _ => None, - } - } - - /// Returns the target that will never be taken. - #[must_use] - pub fn dead_target(&self) -> Option { - match self.resolution { - PredicateResolution::AlwaysTrue => Some(self.false_target), - PredicateResolution::AlwaysFalse => Some(self.true_target), - _ => None, - } - } -} - -/// How an opaque predicate resolves. -#[derive(Debug, Clone)] -pub enum PredicateResolution { - /// Always evaluates to true. - AlwaysTrue, - - /// Always evaluates to false. - AlwaysFalse, - - /// Depends on symbolic values (not truly opaque). - Symbolic(SymbolicExpr), - - /// Cannot determine. - Unknown, -} - -#[cfg(test)] -mod tests { - use super::*; - - use crate::analysis::SsaFunctionBuilder; - use crate::metadata::typesystem::PointerSize; - - #[test] - fn test_pattern_detector_creation() { - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| b.ret()); - }) - .unwrap(); - - let detector = PatternDetector::new(&ssa, PointerSize::Bit32); - assert_eq!(detector.ssa().block_count(), 1); - } - - #[test] - fn test_find_no_dispatchers_in_simple_function() { - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v = b.const_i32(42); - b.ret_val(v); - }); - }) - .unwrap(); - - let detector = PatternDetector::new(&ssa, PointerSize::Bit32); - let dispatchers = detector.find_dispatchers(); - assert!(dispatchers.is_empty()); - } - - #[test] - fn test_dispatcher_pattern_methods() { - // Test the helper methods on DispatcherPattern - let pattern = DispatcherPattern { - block: 0, - switch_var: SsaVarId::from_index(0), - targets: vec![1, 2, 3], - default: 4, - dispatch_expr: None, - state_vars: vec![], - }; - - assert_eq!(pattern.case_count(), 3); - assert_eq!(pattern.target_for_case(0), 1); - assert_eq!(pattern.target_for_case(1), 2); - assert_eq!(pattern.target_for_case(2), 3); - assert_eq!(pattern.target_for_case(10), 4); // Out of bounds -> default - } - - #[test] - fn test_opaque_predicate_methods() { - let cond = SsaVarId::from_index(0); - let pattern = OpaquePredicatePattern { - block: 0, - condition_var: cond, - true_target: 1, - false_target: 2, - resolution: PredicateResolution::AlwaysTrue, - }; - - assert_eq!(pattern.actual_target(), Some(1)); - assert_eq!(pattern.dead_target(), Some(2)); - - let pattern2 = OpaquePredicatePattern { - block: 0, - condition_var: cond, - true_target: 1, - false_target: 2, - resolution: PredicateResolution::AlwaysFalse, - }; - - assert_eq!(pattern2.actual_target(), Some(2)); - assert_eq!(pattern2.dead_target(), Some(1)); - } -} diff --git a/dotscope/src/analysis/ssa/phi.rs b/dotscope/src/analysis/ssa/phi.rs deleted file mode 100644 index 68277b5a..00000000 --- a/dotscope/src/analysis/ssa/phi.rs +++ /dev/null @@ -1,435 +0,0 @@ -//! Phi node representation for SSA form. -//! -//! Phi nodes are the cornerstone of SSA form - they represent the merging of -//! values at control flow join points. When multiple control flow paths converge, -//! a phi node selects which value to use based on which path was taken. -//! -//! # Semantics -//! -//! A phi node `v3 = phi(v1 from B1, v2 from B2)` means: -//! - If control came from block B1, use value v1 -//! - If control came from block B2, use value v2 -//! -//! Phi nodes are not real instructions - they are evaluated "instantaneously" -//! at the entry of a basic block, before any real instructions execute. -//! -//! # Placement -//! -//! Phi nodes are placed at dominance frontiers during SSA construction. -//! A block B needs a phi node for variable V if: -//! 1. B is in the dominance frontier of some block that defines V -//! 2. V is live at the entry of B -//! -//! # Thread Safety -//! -//! All types in this module are `Send` and `Sync`. - -use std::fmt; - -use crate::analysis::ssa::{SsaVarId, VariableOrigin}; - -/// An operand of a phi node - a value coming from a specific predecessor block. -/// -/// Each phi operand represents one possible value that could be selected, -/// associated with the predecessor block from which that value comes. -/// -/// # Examples -/// -/// ```rust,no_run -/// use dotscope::analysis::{PhiOperand, SsaVarId}; -/// -/// // Value coming from block 1 -/// let var = SsaVarId::from_index(0); -/// let operand = PhiOperand::new(var, 1); -/// ``` -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct PhiOperand { - /// The SSA variable providing the value. - value: SsaVarId, - /// The predecessor block from which this value comes. - predecessor: usize, -} - -impl PhiOperand { - /// Creates a new phi operand. - /// - /// # Arguments - /// - /// * `value` - The SSA variable providing the value - /// * `predecessor` - The block index from which this value comes - #[must_use] - pub const fn new(value: SsaVarId, predecessor: usize) -> Self { - Self { value, predecessor } - } - - /// Returns the SSA variable providing the value. - #[must_use] - pub const fn value(&self) -> SsaVarId { - self.value - } - - /// Returns the predecessor block index. - #[must_use] - pub const fn predecessor(&self) -> usize { - self.predecessor - } - - /// Updates the predecessor block index. - /// - /// Used when block merging redirects edges, requiring phi operands - /// to reference the new predecessor instead of the eliminated trampoline. - pub fn set_predecessor(&mut self, predecessor: usize) { - self.predecessor = predecessor; - } -} - -impl fmt::Display for PhiOperand { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{} from B{}", self.value, self.predecessor) - } -} - -/// A phi node that merges values at a control flow join point. -/// -/// Phi nodes are placed at the beginning of basic blocks where control flow -/// from multiple predecessors converges. They select which value to use based -/// on which predecessor block was executed. -/// -/// # Invariants -/// -/// - Each phi node has exactly one operand for each predecessor of its block -/// - All operands must have the same type (enforced by SSA construction) -/// - The result variable is defined by this phi node -/// -/// # Examples -/// -/// ```rust,no_run -/// use dotscope::analysis::{PhiNode, PhiOperand, SsaVarId, VariableOrigin}; -/// -/// // Create phi: result = phi(v1 from B1, v2 from B2) -/// let v1 = SsaVarId::from_index(0); -/// let v2 = SsaVarId::from_index(1); -/// let result = SsaVarId::from_index(2); -/// let mut phi = PhiNode::new(result, VariableOrigin::Local(0)); -/// phi.add_operand(PhiOperand::new(v1, 1)); -/// phi.add_operand(PhiOperand::new(v2, 2)); -/// ``` -#[derive(Debug, Clone)] -pub struct PhiNode { - /// The SSA variable defined by this phi node. - result: SsaVarId, - /// The original variable this phi merges (Argument or Local index). - origin: VariableOrigin, - /// Operands from each predecessor block. - operands: Vec, -} - -impl PhiNode { - /// Creates a new phi node for the given variable origin. - /// - /// The phi node is created with no operands - they must be added - /// during SSA construction as predecessor blocks are processed. - /// - /// # Arguments - /// - /// * `result` - The SSA variable that this phi node defines - /// * `origin` - The original variable (Argument or Local) this phi merges - #[must_use] - pub fn new(result: SsaVarId, origin: VariableOrigin) -> Self { - Self { - result, - origin, - operands: Vec::new(), - } - } - - /// Creates a new phi node with pre-allocated operand capacity. - /// - /// Use this when the number of predecessors is known in advance - /// to avoid reallocations. - /// - /// # Arguments - /// - /// * `result` - The SSA variable that this phi node defines - /// * `origin` - The original variable this phi merges - /// * `predecessor_count` - Expected number of predecessor blocks - #[must_use] - pub fn with_capacity( - result: SsaVarId, - origin: VariableOrigin, - predecessor_count: usize, - ) -> Self { - Self { - result, - origin, - operands: Vec::with_capacity(predecessor_count), - } - } - - /// Returns the SSA variable defined by this phi node. - #[must_use] - pub const fn result(&self) -> SsaVarId { - self.result - } - - /// Returns the original variable origin this phi merges. - #[must_use] - pub const fn origin(&self) -> VariableOrigin { - self.origin - } - - /// Sets the SSA variable defined by this phi node. - /// - /// Used during SSA construction when renaming variables. - pub fn set_result(&mut self, var: SsaVarId) { - self.result = var; - } - - /// Returns the operands of this phi node. - #[must_use] - pub fn operands(&self) -> &[PhiOperand] { - &self.operands - } - - /// Returns a mutable reference to the operands. - pub fn operands_mut(&mut self) -> &mut Vec { - &mut self.operands - } - - /// Adds an operand to this phi node. - /// - /// # Arguments - /// - /// * `operand` - The phi operand to add - pub fn add_operand(&mut self, operand: PhiOperand) { - self.operands.push(operand); - } - - /// Returns the number of operands. - #[must_use] - pub fn operand_count(&self) -> usize { - self.operands.len() - } - - /// Returns `true` if this phi node has no operands. - /// - /// A phi node with no operands is incomplete and should not appear - /// in a fully-constructed SSA form. - #[must_use] - pub fn is_empty(&self) -> bool { - self.operands.is_empty() - } - - /// Finds the operand coming from the specified predecessor block. - /// - /// # Arguments - /// - /// * `predecessor` - The block index to look up - /// - /// # Returns - /// - /// The phi operand if found, or `None` if no operand comes from that predecessor. - #[must_use] - pub fn operand_from(&self, predecessor: usize) -> Option<&PhiOperand> { - self.operands - .iter() - .find(|op| op.predecessor == predecessor) - } - - /// Returns all the SSA variables used by this phi node. - /// - /// This is useful for building def-use chains and liveness analysis. - pub fn used_variables(&self) -> impl Iterator + '_ { - self.operands.iter().map(|op| op.value) - } - - /// Retains only operands whose predecessor satisfies the predicate. - pub fn retain_operands bool>(&mut self, pred: F) { - self.operands.retain(|op| pred(op.predecessor)); - } - - /// Sets the origin of this phi node. - /// - /// This is used during local variable optimization to update indices - /// after unused locals are removed. - pub fn set_origin(&mut self, origin: VariableOrigin) { - self.origin = origin; - } - - /// Sets the operand value for a specific predecessor. - /// - /// If an operand from that predecessor already exists, it is updated. - /// Otherwise, a new operand is added. - /// - /// # Arguments - /// - /// * `predecessor` - The predecessor block index - /// * `value` - The SSA variable value from that predecessor - pub fn set_operand(&mut self, predecessor: usize, value: SsaVarId) { - if let Some(existing) = self - .operands - .iter_mut() - .find(|op| op.predecessor == predecessor) - { - existing.value = value; - } else { - self.operands.push(PhiOperand::new(value, predecessor)); - } - } -} - -impl fmt::Display for PhiNode { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{} = phi(", self.result)?; - for (i, operand) in self.operands.iter().enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{operand}")?; - } - write!(f, ")") - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_phi_operand_creation() { - let v = SsaVarId::from_index(0); - let operand = PhiOperand::new(v, 2); - assert_eq!(operand.value(), v); - assert_eq!(operand.predecessor(), 2); - } - - #[test] - fn test_phi_operand_display() { - let operand = PhiOperand::new(SsaVarId::from_index(3), 1); - assert_eq!(format!("{operand}"), "v3 from B1"); - } - - #[test] - fn test_phi_node_creation() { - let result = SsaVarId::from_index(0); - let phi = PhiNode::new(result, VariableOrigin::Local(0)); - assert_eq!(phi.result(), result); - assert_eq!(phi.origin(), VariableOrigin::Local(0)); - assert!(phi.is_empty()); - assert_eq!(phi.operand_count(), 0); - } - - #[test] - fn test_phi_node_with_capacity() { - let result = SsaVarId::from_index(0); - let phi = PhiNode::with_capacity(result, VariableOrigin::Argument(1), 3); - assert_eq!(phi.result(), result); - assert_eq!(phi.origin(), VariableOrigin::Argument(1)); - assert!(phi.is_empty()); - } - - #[test] - fn test_phi_node_add_operands() { - let result = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let v2 = SsaVarId::from_index(2); - let mut phi = PhiNode::new(result, VariableOrigin::Local(0)); - - phi.add_operand(PhiOperand::new(v1, 0)); - phi.add_operand(PhiOperand::new(v2, 1)); - - assert!(!phi.is_empty()); - assert_eq!(phi.operand_count(), 2); - - let ops = phi.operands(); - assert_eq!(ops[0].value(), v1); - assert_eq!(ops[0].predecessor(), 0); - assert_eq!(ops[1].value(), v2); - assert_eq!(ops[1].predecessor(), 1); - } - - #[test] - fn test_phi_node_operand_from() { - let result = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let v3 = SsaVarId::from_index(2); - let mut phi = PhiNode::new(result, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(v1, 2)); - phi.add_operand(PhiOperand::new(v3, 4)); - - assert!(phi.operand_from(2).is_some()); - assert_eq!(phi.operand_from(2).unwrap().value(), v1); - - assert!(phi.operand_from(4).is_some()); - assert_eq!(phi.operand_from(4).unwrap().value(), v3); - - assert!(phi.operand_from(0).is_none()); - assert!(phi.operand_from(99).is_none()); - } - - #[test] - fn test_phi_node_set_operand_new() { - let result = SsaVarId::from_index(0); - let v10 = SsaVarId::from_index(1); - let mut phi = PhiNode::new(result, VariableOrigin::Local(0)); - - phi.set_operand(1, v10); - assert_eq!(phi.operand_count(), 1); - assert_eq!(phi.operand_from(1).unwrap().value(), v10); - } - - #[test] - fn test_phi_node_set_operand_update() { - let result = SsaVarId::from_index(0); - let v10 = SsaVarId::from_index(1); - let v20 = SsaVarId::from_index(2); - let mut phi = PhiNode::new(result, VariableOrigin::Local(0)); - - phi.set_operand(1, v10); - phi.set_operand(1, v20); // Update existing - - assert_eq!(phi.operand_count(), 1); - assert_eq!(phi.operand_from(1).unwrap().value(), v20); - } - - #[test] - fn test_phi_node_used_variables() { - let result = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let v2 = SsaVarId::from_index(2); - let v3 = SsaVarId::from_index(3); - let mut phi = PhiNode::new(result, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(v1, 0)); - phi.add_operand(PhiOperand::new(v2, 1)); - phi.add_operand(PhiOperand::new(v3, 2)); - - let used: Vec<_> = phi.used_variables().collect(); - assert_eq!(used.len(), 3); - assert!(used.contains(&v1)); - assert!(used.contains(&v2)); - assert!(used.contains(&v3)); - } - - #[test] - fn test_phi_node_display() { - let mut phi = PhiNode::new(SsaVarId::from_index(5), VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(SsaVarId::from_index(1), 0)); - phi.add_operand(PhiOperand::new(SsaVarId::from_index(2), 1)); - - let display = format!("{phi}"); - assert_eq!(display, "v5 = phi(v1 from B0, v2 from B1)"); - } - - #[test] - fn test_phi_node_display_empty() { - let phi = PhiNode::new(SsaVarId::from_index(3), VariableOrigin::Local(0)); - assert_eq!(format!("{phi}"), "v3 = phi()"); - } - - #[test] - fn test_phi_node_display_single_operand() { - let mut phi = PhiNode::new(SsaVarId::from_index(7), VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(SsaVarId::from_index(4), 2)); - assert_eq!(format!("{phi}"), "v7 = phi(v4 from B2)"); - } -} diff --git a/dotscope/src/analysis/ssa/phis.rs b/dotscope/src/analysis/ssa/phis.rs deleted file mode 100644 index acc3fcc7..00000000 --- a/dotscope/src/analysis/ssa/phis.rs +++ /dev/null @@ -1,833 +0,0 @@ -//! PHI node analysis and placement utilities. -//! -//! This module provides utilities for PHI nodes in SSA form: -//! -//! - [`PhiAnalyzer`] - Identifies patterns like trivial PHIs (single unique source), -//! uniform constants (all operands resolve to the same value), and finding PHI definitions. -//! - [`place_pruned_phis`] - Shared pruned phi placement algorithm used by both -//! `SsaConverter` (initial construction) and `SsaRebuilder` (repair). -//! -//! # Example -//! -//! ```rust,ignore -//! use dotscope::analysis::{PhiAnalyzer, ConstEvaluator, SsaFunction}; -//! -//! let analyzer = PhiAnalyzer::new(&ssa); -//! -//! // Check if a PHI is trivial (has single unique non-self source) -//! if let Some(source) = analyzer.is_trivial(phi) { -//! println!("PHI can be replaced with copy from {:?}", source); -//! } -//! -//! // Check if all PHI operands resolve to the same constant -//! let mut evaluator = ConstEvaluator::new(&ssa, PointerSize::Bit64); -//! if let Some(value) = analyzer.uniform_constant(phi, &mut evaluator) { -//! println!("PHI always produces: {:?}", value); -//! } -//! ``` - -use std::collections::{BTreeMap, BTreeSet}; - -use crate::{ - analysis::ssa::{ - ConstEvaluator, ConstValue, PhiNode, PhiOperand, SsaBlock, SsaFunction, SsaOp, SsaVarId, - VariableOrigin, - }, - utils::{graph::NodeId, BitSet}, -}; - -/// Analyzes PHI nodes for various patterns. -/// -/// This struct provides methods for common PHI node analysis tasks: -/// - Detecting trivial PHIs that can be replaced with copies -/// - Finding PHIs where all operands resolve to the same constant -/// - Looking up PHI operands by predecessor block -/// - Finding the PHI node that defines a variable -pub struct PhiAnalyzer<'a> { - /// Reference to the SSA function being analyzed. - ssa: &'a SsaFunction, -} - -impl<'a> PhiAnalyzer<'a> { - /// Creates a new PHI analyzer for the given SSA function. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to analyze. - #[must_use] - pub fn new(ssa: &'a SsaFunction) -> Self { - Self { ssa } - } - - /// Returns a reference to the SSA function being analyzed. - #[must_use] - pub fn ssa(&self) -> &SsaFunction { - self.ssa - } - - /// Checks if a PHI is trivial (has a single unique non-self source). - /// - /// A trivial PHI can be replaced with a simple copy operation. - /// This occurs when all non-self-referential operands point to the - /// same source variable. - /// - /// # Arguments - /// - /// * `phi` - The PHI node to analyze. - /// - /// # Returns - /// - /// `Some(source)` if the PHI has exactly one unique non-self source, - /// `None` otherwise. - /// - /// # Examples - /// - /// ```text - /// // Trivial PHI (can be replaced with: result = v1) - /// result = phi(v1, v1, result) // Returns Some(v1) - /// - /// // Non-trivial PHI (multiple different sources) - /// result = phi(v1, v2) // Returns None - /// - /// // Non-trivial PHI (only self-references, unreachable) - /// result = phi(result, result) // Returns None - /// ``` - #[must_use] - pub fn is_trivial(&self, phi: &PhiNode) -> Option { - let result = phi.result(); - - // Collect non-self-referential operands - let unique_sources: BTreeSet = phi - .operands() - .iter() - .map(PhiOperand::value) - .filter(|&v| v != result) - .collect(); - - // Trivial if exactly one unique non-self source - if unique_sources.len() == 1 { - let source = unique_sources.into_iter().next()?; - - // Check if replacing result with source would create a self-referential instruction. - if self.ssa.would_create_self_reference(source, result) { - return None; - } - - Some(source) - } else { - None - } - } - - /// Checks if a PHI is fully self-referential (all operands reference the PHI's result). - /// - /// A fully self-referential PHI indicates unreachable code or undefined behavior, - /// since there's no external value entering the PHI. Such PHIs can be safely removed. - /// - /// # Arguments - /// - /// * `phi` - The PHI node to analyze. - /// - /// # Returns - /// - /// `true` if all operands reference the PHI's own result variable, `false` otherwise. - /// - /// # Examples - /// - /// ```text - /// // Fully self-referential (returns true) - /// result = phi(result, result) - /// - /// // Not fully self-referential (returns false) - /// result = phi(v1, result) - /// result = phi(v1, v2) - /// ``` - #[must_use] - pub fn is_fully_self_referential(&self, phi: &PhiNode) -> bool { - let result = phi.result(); - !phi.operands().is_empty() && phi.operands().iter().all(|op| op.value() == result) - } - - /// Analyzes a PHI to determine its trivial status. - /// - /// This is the comprehensive analysis method that distinguishes between: - /// - Trivial PHIs with a single replacement value - /// - Fully self-referential PHIs that should be removed - /// - Non-trivial PHIs that must be kept - /// - /// # Arguments - /// - /// * `phi` - The PHI node to analyze. - /// - /// # Returns - /// - /// - `Some(Some(var))` - PHI is trivial, can be replaced with `var` - /// - `Some(None)` - PHI is fully self-referential, can be removed - /// - `None` - PHI is not trivial, must be kept - #[must_use] - pub fn analyze_trivial(&self, phi: &PhiNode) -> Option> { - // Check if trivial with a replacement value - if let Some(source) = self.is_trivial(phi) { - return Some(Some(source)); - } - - // Check if fully self-referential (can be removed) - if self.is_fully_self_referential(phi) { - return Some(None); - } - - // Not trivial - None - } - - /// Finds all trivial PHI nodes in the SSA function. - /// - /// Scans all reachable blocks for PHI nodes that are either: - /// - Trivial with a single replacement value - /// - Fully self-referential and can be removed - /// - /// # Arguments - /// - /// * `reachable` - Set of reachable block indices to scan. - /// - /// # Returns - /// - /// A vector of `(block_idx, phi_idx, replacement)` tuples where: - /// - `replacement = Some(var)` - PHI can be replaced with `var` - /// - `replacement = None` - PHI is fully self-referential and can be removed - #[must_use] - pub fn find_all_trivial( - &self, - reachable: &BTreeSet, - ) -> Vec<(usize, usize, Option)> { - let mut trivial = Vec::new(); - - for &block_idx in reachable { - if let Some(block) = self.ssa.block(block_idx) { - for (phi_idx, phi) in block.phi_nodes().iter().enumerate() { - if let Some(replacement) = self.analyze_trivial(phi) { - trivial.push((block_idx, phi_idx, replacement)); - } - } - } - } - - trivial - } - - /// Collects all copy-like operations in the SSA function. - /// - /// This method identifies all operations that are effectively copies: - /// - Explicit `Copy` instructions: `dest = copy src` - /// - Trivial phi nodes: `dest = phi(src, src, ...)` where all non-self operands are identical - /// - /// This is the unified entry point for copy detection, used by copy propagation - /// and other optimizations that need to identify copy relationships. - /// - /// # Returns - /// - /// A map from each copy destination to its immediate source. - /// - /// # Example - /// - /// ```text - /// // Given: - /// v1 = copy v0 // Explicit copy - /// v2 = phi(v0, v0) // Trivial phi (all same source) - /// v3 = phi(v0, v3) // Trivial phi (self-ref excluded) - /// v4 = phi(v0, v1) // Non-trivial (different sources) - /// - /// // Returns: {v1 → v0, v2 → v0, v3 → v0} - /// ``` - #[must_use] - pub fn collect_all_copies(&self) -> BTreeMap { - let mut copies = BTreeMap::new(); - - for block in self.ssa.blocks() { - // Collect explicit copy instructions - for instr in block.instructions() { - if let SsaOp::Copy { dest, src } = instr.op() { - copies.insert(*dest, *src); - } - } - - // Collect trivial phi nodes (effectively copies) - for phi in block.phi_nodes() { - if let Some(source) = self.is_trivial(phi) { - copies.insert(phi.result(), source); - } - } - } - - copies - } - - /// Checks if all PHI operands resolve to the same constant. - /// - /// This is useful for detecting PHIs that always produce the same value, - /// which can be replaced with a constant assignment. - /// - /// # Arguments - /// - /// * `phi` - The PHI node to analyze. - /// * `evaluator` - A constant evaluator for resolving operand values. - /// - /// # Returns - /// - /// `Some(value)` if all operands evaluate to the same constant, - /// `None` if operands differ, cannot be evaluated, or PHI is empty. - /// - /// # Examples - /// - /// ```text - /// // Given: v1 = 42, v2 = 42 - /// result = phi(v1, v2) // Returns Some(42) - /// - /// // Given: v1 = 42, v2 = 99 - /// result = phi(v1, v2) // Returns None (values differ) - /// - /// // Given: v1 = 42, v2 = unknown - /// result = phi(v1, v2) // Returns None (v2 not constant) - /// ``` - pub fn uniform_constant( - &self, - phi: &PhiNode, - evaluator: &mut ConstEvaluator, - ) -> Option { - let operands = phi.operands(); - - // Empty PHI has no uniform value - if operands.is_empty() { - return None; - } - - // Get the first operand's constant value - let first_value = evaluator.evaluate_var(operands.first()?.value())?; - - // Check that all other operands have the same value - for operand in operands.iter().skip(1) { - let value = evaluator.evaluate_var(operand.value())?; - if value != first_value { - return None; - } - } - - Some(first_value) - } - - /// Finds the PHI node that defines a variable. - /// - /// This delegates to [`SsaFunction::find_phi_defining`] for the actual lookup, - /// which uses O(1) lookup via the variable's definition site when available. - /// - /// # Arguments - /// - /// * `var` - The SSA variable ID to find the defining PHI for. - /// - /// # Returns - /// - /// `Some((block_idx, &PhiNode))` if the variable is defined by a PHI node, - /// `None` if the variable is not defined by a PHI or doesn't exist. - #[must_use] - pub fn find_phi_defining(&self, var: SsaVarId) -> Option<(usize, &PhiNode)> { - self.ssa.find_phi_defining(var) - } -} - -/// Callback type for resolving Leave targets in exception handler blocks. -pub(crate) type LeaveTargetFn<'a> = dyn Fn(usize, &[SsaBlock]) -> Option + 'a; - -/// Places phi nodes at iterated dominance frontier blocks, pruned by liveness. -/// -/// This implements the standard IDF phi placement algorithm from Cytron et al., -/// with pruning based on liveness analysis (only placing phis where the variable -/// is live-in). -/// -/// Data structures are keyed by `u32` group IDs. The `group_to_origin` mapping -/// translates group IDs to `VariableOrigin` values for the created phi nodes. -/// -/// # Arguments -/// -/// * `blocks` - Mutable slice of SSA blocks to insert phi nodes into -/// * `defs` - Definition sites for each group ID -/// * `live_in` - Liveness information: for each group ID, the set of blocks where it's live-in -/// * `dominance_frontiers` - Precomputed dominance frontiers -/// * `reachable` - Set of reachable block indices (if `None`, all blocks are considered reachable) -/// * `group_filter` - Predicate to select which group IDs to process -/// * `group_to_origin` - Maps group ID to the `VariableOrigin` to use for the phi node -/// * `leave_target_fn` - Optional function to get Leave targets for exception handler blocks -/// (used to add extra phi placement points for handler exits) -/// -/// # Returns -/// A list of `(block_idx, group)` pairs for each phi node placed, in the order -/// they were added to each block. This allows callers to associate phi nodes -/// with their rename groups during the rename phase. -#[allow(clippy::too_many_arguments)] -pub(crate) fn place_pruned_phis( - blocks: &mut [SsaBlock], - defs: &BTreeMap, - live_in: &BTreeMap, - dominance_frontiers: &[BitSet], - reachable: Option<&BitSet>, - group_filter: &dyn Fn(u32) -> bool, - group_to_origin: &dyn Fn(u32) -> VariableOrigin, - leave_target_fn: Option<&LeaveTargetFn<'_>>, -) -> Vec<(usize, u32)> { - let block_count = blocks.len(); - let mut placements: Vec<(usize, u32)> = Vec::new(); - - for (&group, def_blocks) in defs { - if !group_filter(group) { - continue; - } - - // Compute iterated dominance frontier - let mut phi_blocks = BitSet::new(block_count); - let mut worklist: Vec = def_blocks.iter().collect(); - - while let Some(block_idx) = worklist.pop() { - let node_id = NodeId::new(block_idx); - if let Some(frontier) = dominance_frontiers.get(node_id.index()) { - for frontier_idx in frontier.iter() { - let is_reachable = reachable.is_none_or(|r| r.contains(frontier_idx)); - if frontier_idx < block_count && is_reachable && phi_blocks.insert(frontier_idx) - { - worklist.push(frontier_idx); - } - } - } - - // For exception handler blocks, use Leave targets as phi placement points - if let Some(leave_fn) = leave_target_fn { - if let Some(target) = leave_fn(block_idx, blocks) { - let is_reachable = reachable.is_none_or(|r| r.contains(target)); - if target < block_count && is_reachable && phi_blocks.insert(target) { - worklist.push(target); - } - } - } - } - - // Pruned SSA: only place phi if variable is live at the frontier block. - // If no liveness data for this group, place unconditionally (used by - // converter for Stack origins that don't have liveness tracking). - let group_live_in = live_in.get(&group); - - for phi_block_idx in phi_blocks.iter() { - if let Some(live_set) = group_live_in { - if !live_set.contains(phi_block_idx) { - continue; - } - } - - if let Some(block) = blocks.get_mut(phi_block_idx) { - let origin = group_to_origin(group); - // Phi result IDs are temporary placeholders; they will be replaced - // during the rename phase with properly allocated variable IDs. - let phi = PhiNode::new(SsaVarId::PLACEHOLDER, origin); - block.add_phi(phi); - placements.push((phi_block_idx, group)); - } - } - } - - placements -} - -#[cfg(test)] -mod tests { - use std::collections::BTreeSet; - - use crate::{ - analysis::ssa::{ - ConstEvaluator, ConstValue, DefSite, PhiAnalyzer, PhiNode, PhiOperand, SsaBlock, - SsaFunction, SsaInstruction, SsaOp, SsaType, SsaVarId, VariableOrigin, - }, - metadata::typesystem::PointerSize, - }; - - #[test] - fn test_phi_analyzer_creation() { - let ssa = SsaFunction::new(0, 0); - let analyzer = PhiAnalyzer::new(&ssa); - - // Basic sanity check - assert_eq!(analyzer.ssa().num_args(), 0); - assert_eq!(analyzer.ssa().num_locals(), 0); - } - - #[test] - fn test_is_trivial_single_source() { - let ssa = SsaFunction::new(0, 0); - let analyzer = PhiAnalyzer::new(&ssa); - - let result = SsaVarId::from_index(0); - let source = SsaVarId::from_index(1); - - // phi(v1, v1) - trivial, single unique source - let mut phi = PhiNode::new(result, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(source, 0)); - phi.add_operand(PhiOperand::new(source, 1)); - - assert_eq!(analyzer.is_trivial(&phi), Some(source)); - } - - #[test] - fn test_is_trivial_with_self_reference() { - let ssa = SsaFunction::new(0, 0); - let analyzer = PhiAnalyzer::new(&ssa); - - let result = SsaVarId::from_index(0); - let source = SsaVarId::from_index(1); - - // phi(v1, result, v1) - trivial, self-references are ignored - let mut phi = PhiNode::new(result, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(source, 0)); - phi.add_operand(PhiOperand::new(result, 1)); // self-reference - phi.add_operand(PhiOperand::new(source, 2)); - - assert_eq!(analyzer.is_trivial(&phi), Some(source)); - } - - #[test] - fn test_is_trivial_multiple_sources() { - let ssa = SsaFunction::new(0, 0); - let analyzer = PhiAnalyzer::new(&ssa); - - let result = SsaVarId::from_index(0); - let source1 = SsaVarId::from_index(1); - let source2 = SsaVarId::from_index(2); - - // phi(v1, v2) - not trivial, multiple different sources - let mut phi = PhiNode::new(result, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(source1, 0)); - phi.add_operand(PhiOperand::new(source2, 1)); - - assert_eq!(analyzer.is_trivial(&phi), None); - } - - #[test] - fn test_is_trivial_only_self_references() { - let ssa = SsaFunction::new(0, 0); - let analyzer = PhiAnalyzer::new(&ssa); - - let result = SsaVarId::from_index(0); - - // phi(result, result) - not trivial, only self-references (unreachable) - let mut phi = PhiNode::new(result, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(result, 0)); - phi.add_operand(PhiOperand::new(result, 1)); - - assert_eq!(analyzer.is_trivial(&phi), None); - } - - #[test] - fn test_uniform_constant_same_values() { - let mut ssa = SsaFunction::new(0, 0); - let mut block = SsaBlock::new(0); - - // v1 = 42 - let v1_id = ssa.create_variable( - VariableOrigin::Local(0), - 0, - DefSite::instruction(0, 0), - SsaType::Unknown, - ); - - // v2 = 42 - let v2_id = ssa.create_variable( - VariableOrigin::Local(1), - 0, - DefSite::instruction(0, 1), - SsaType::Unknown, - ); - - block.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: v1_id, - value: ConstValue::I32(42), - })); - block.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: v2_id, - value: ConstValue::I32(42), - })); - block.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - ssa.add_block(block); - - let analyzer = PhiAnalyzer::new(&ssa); - let mut evaluator = ConstEvaluator::new(&ssa, PointerSize::Bit64); - - // phi(v1, v2) where both are 42 - let phi_result = SsaVarId::from_index(0); - let mut phi = PhiNode::new(phi_result, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(v1_id, 0)); - phi.add_operand(PhiOperand::new(v2_id, 1)); - - assert_eq!( - analyzer.uniform_constant(&phi, &mut evaluator), - Some(ConstValue::I32(42)) - ); - } - - #[test] - fn test_uniform_constant_different_values() { - let mut ssa = SsaFunction::new(0, 0); - let mut block = SsaBlock::new(0); - - // v1 = 42 - let v1_id = ssa.create_variable( - VariableOrigin::Local(0), - 0, - DefSite::instruction(0, 0), - SsaType::Unknown, - ); - - // v2 = 99 - let v2_id = ssa.create_variable( - VariableOrigin::Local(1), - 0, - DefSite::instruction(0, 1), - SsaType::Unknown, - ); - - block.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: v1_id, - value: ConstValue::I32(42), - })); - block.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: v2_id, - value: ConstValue::I32(99), - })); - block.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - ssa.add_block(block); - - let analyzer = PhiAnalyzer::new(&ssa); - let mut evaluator = ConstEvaluator::new(&ssa, PointerSize::Bit64); - - // phi(v1, v2) where v1=42 and v2=99 - let phi_result = SsaVarId::from_index(0); - let mut phi = PhiNode::new(phi_result, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(v1_id, 0)); - phi.add_operand(PhiOperand::new(v2_id, 1)); - - assert_eq!(analyzer.uniform_constant(&phi, &mut evaluator), None); - } - - #[test] - fn test_uniform_constant_empty_phi() { - let ssa = SsaFunction::new(0, 0); - let analyzer = PhiAnalyzer::new(&ssa); - let mut evaluator = ConstEvaluator::new(&ssa, PointerSize::Bit64); - - // Empty PHI - let phi_result = SsaVarId::from_index(0); - let phi = PhiNode::new(phi_result, VariableOrigin::Local(0)); - - assert_eq!(analyzer.uniform_constant(&phi, &mut evaluator), None); - } - - #[test] - fn test_uniform_constant_non_constant_operand() { - let ssa = SsaFunction::new(0, 0); - let analyzer = PhiAnalyzer::new(&ssa); - let mut evaluator = ConstEvaluator::new(&ssa, PointerSize::Bit64); - - // phi(v1, v2) where neither is defined (not constant) - let phi_result = SsaVarId::from_index(0); - let v1_id = SsaVarId::from_index(1); - let v2_id = SsaVarId::from_index(2); - - let mut phi = PhiNode::new(phi_result, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(v1_id, 0)); - phi.add_operand(PhiOperand::new(v2_id, 1)); - - assert_eq!(analyzer.uniform_constant(&phi, &mut evaluator), None); - } - - #[test] - fn test_find_defining_phi() { - let mut ssa = SsaFunction::new(0, 0); - - // Create a variable defined by a PHI - let phi_result_id = ssa.create_variable( - VariableOrigin::Local(0), - 0, - DefSite::phi(0), - SsaType::Unknown, - ); - - // Create block with PHI node - let mut block = SsaBlock::new(0); - let mut phi = PhiNode::new(phi_result_id, VariableOrigin::Local(0)); - let operand_id = SsaVarId::from_index(0); - phi.add_operand(PhiOperand::new(operand_id, 1)); - block.add_phi(phi); - block.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - ssa.add_block(block); - - let analyzer = PhiAnalyzer::new(&ssa); - - // Should find the PHI - let result = analyzer.find_phi_defining(phi_result_id); - assert!(result.is_some()); - let (block_idx, found_phi) = result.unwrap(); - assert_eq!(block_idx, 0); - assert_eq!(found_phi.result(), phi_result_id); - } - - #[test] - fn test_find_defining_not_phi() { - let mut ssa = SsaFunction::new(0, 0); - let mut block = SsaBlock::new(0); - - // Create a variable defined by a regular instruction (not PHI) - let var_id = ssa.create_variable( - VariableOrigin::Local(0), - 0, - DefSite::instruction(0, 0), - SsaType::Unknown, - ); - - block.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: var_id, - value: ConstValue::I32(42), - })); - block.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - ssa.add_block(block); - - let analyzer = PhiAnalyzer::new(&ssa); - - // Should not find a PHI (variable is defined by Const, not PHI) - assert!(analyzer.find_phi_defining(var_id).is_none()); - } - - #[test] - fn test_is_fully_self_referential_true() { - let ssa = SsaFunction::new(0, 0); - let analyzer = PhiAnalyzer::new(&ssa); - - let result = SsaVarId::from_index(0); - - // phi(result, result) - fully self-referential - let mut phi = PhiNode::new(result, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(result, 0)); - phi.add_operand(PhiOperand::new(result, 1)); - - assert!(analyzer.is_fully_self_referential(&phi)); - } - - #[test] - fn test_is_fully_self_referential_false() { - let ssa = SsaFunction::new(0, 0); - let analyzer = PhiAnalyzer::new(&ssa); - - let result = SsaVarId::from_index(0); - let source = SsaVarId::from_index(1); - - // phi(source, result) - not fully self-referential - let mut phi = PhiNode::new(result, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(source, 0)); - phi.add_operand(PhiOperand::new(result, 1)); - - assert!(!analyzer.is_fully_self_referential(&phi)); - } - - #[test] - fn test_is_fully_self_referential_empty() { - let ssa = SsaFunction::new(0, 0); - let analyzer = PhiAnalyzer::new(&ssa); - - let result = SsaVarId::from_index(0); - - // Empty phi - not fully self-referential - let phi = PhiNode::new(result, VariableOrigin::Local(0)); - - assert!(!analyzer.is_fully_self_referential(&phi)); - } - - #[test] - fn test_analyze_trivial_with_replacement() { - let ssa = SsaFunction::new(0, 0); - let analyzer = PhiAnalyzer::new(&ssa); - - let result = SsaVarId::from_index(0); - let source = SsaVarId::from_index(1); - - // phi(source, source) - trivial with replacement - let mut phi = PhiNode::new(result, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(source, 0)); - phi.add_operand(PhiOperand::new(source, 1)); - - assert_eq!(analyzer.analyze_trivial(&phi), Some(Some(source))); - } - - #[test] - fn test_analyze_trivial_self_referential_removal() { - let ssa = SsaFunction::new(0, 0); - let analyzer = PhiAnalyzer::new(&ssa); - - let result = SsaVarId::from_index(0); - - // phi(result, result) - fully self-referential, should be removed - let mut phi = PhiNode::new(result, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(result, 0)); - phi.add_operand(PhiOperand::new(result, 1)); - - assert_eq!(analyzer.analyze_trivial(&phi), Some(None)); - } - - #[test] - fn test_analyze_trivial_not_trivial() { - let ssa = SsaFunction::new(0, 0); - let analyzer = PhiAnalyzer::new(&ssa); - - let result = SsaVarId::from_index(0); - let source1 = SsaVarId::from_index(1); - let source2 = SsaVarId::from_index(2); - - // phi(source1, source2) - not trivial (different sources) - let mut phi = PhiNode::new(result, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(source1, 0)); - phi.add_operand(PhiOperand::new(source2, 1)); - - assert_eq!(analyzer.analyze_trivial(&phi), None); - } - - #[test] - fn test_find_all_trivial() { - let mut ssa = SsaFunction::new(0, 0); - - // Block 0: entry with trivial phi - let mut block0 = SsaBlock::new(0); - let phi_result1 = SsaVarId::from_index(0); - let source1 = SsaVarId::from_index(1); - let mut phi1 = PhiNode::new(phi_result1, VariableOrigin::Local(0)); - phi1.add_operand(PhiOperand::new(source1, 1)); // trivial: single source - block0.add_phi(phi1); - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 1 })); - ssa.add_block(block0); - - // Block 1: self-referential phi - let mut block1 = SsaBlock::new(1); - let phi_result2 = SsaVarId::from_index(2); - let mut phi2 = PhiNode::new(phi_result2, VariableOrigin::Local(1)); - phi2.add_operand(PhiOperand::new(phi_result2, 0)); // self-referential - phi2.add_operand(PhiOperand::new(phi_result2, 1)); - block1.add_phi(phi2); - block1.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - ssa.add_block(block1); - - let analyzer = PhiAnalyzer::new(&ssa); - let reachable: BTreeSet = [0, 1].iter().copied().collect(); - - let trivial = analyzer.find_all_trivial(&reachable); - - // Should find 2 trivial PHIs - assert_eq!(trivial.len(), 2); - - // Block 0, phi 0: trivial with replacement source1 - assert!(trivial.contains(&(0, 0, Some(source1)))); - - // Block 1, phi 0: self-referential, no replacement - assert!(trivial.contains(&(1, 0, None))); - } -} diff --git a/dotscope/src/analysis/ssa/resolver.rs b/dotscope/src/analysis/ssa/resolver.rs index 22acbc5b..bcd8d2e7 100644 --- a/dotscope/src/analysis/ssa/resolver.rs +++ b/dotscope/src/analysis/ssa/resolver.rs @@ -1,9 +1,10 @@ //! Unified constant resolution for SSA variables. //! -//! The [`ValueResolver`] composes [`ConstEvaluator`], [`PhiAnalyzer`], and optionally -//! [`SsaEvaluator`] into a single reusable entry point for demand-driven constant -//! resolution. It replaces ad-hoc tracing logic (like the former `trace_to_constant`) -//! with a three-tier fallback strategy: +//! [`ValueResolver`] composes [`ConstEvaluator`](crate::analysis::ssa::ConstEvaluator), +//! [`PhiAnalyzer`](crate::analysis::ssa::PhiAnalyzer), and optionally +//! [`SsaEvaluator`](crate::analysis::ssa::SsaEvaluator) into a single reusable entry point +//! for demand-driven constant resolution. It replaces ad-hoc tracing logic (like the +//! former `trace_to_constant`) with a three-tier fallback strategy: //! //! 1. **ConstEvaluator** — handles all instruction-defined ops (arithmetic, bitwise, //! comparisons, conversions) with caching and cycle detection. @@ -28,7 +29,10 @@ //! ``` use crate::{ - analysis::ssa::{ConstEvaluator, ConstValue, PhiAnalyzer, SsaEvaluator, SsaFunction, SsaVarId}, + analysis::ssa::{ + target::{CilTarget, Target}, + ConstEvaluator, ConstValue, PhiAnalyzer, SsaEvaluator, SsaFunction, SsaVarId, + }, metadata::typesystem::PointerSize, }; #[cfg(feature = "compiler")] @@ -37,22 +41,23 @@ use crate::{compiler::CompilerContext, metadata::token::Token}; /// Demand-driven constant resolver composing multiple analysis components. /// /// Provides a unified API for resolving SSA variables to constant values, -/// combining the strengths of [`ConstEvaluator`] (instruction folding), -/// [`PhiAnalyzer`] (uniform PHI detection), and optionally [`SsaEvaluator`] +/// combining the strengths of [`ConstEvaluator`](crate::analysis::ssa::ConstEvaluator) +/// (instruction folding), [`PhiAnalyzer`](crate::analysis::ssa::PhiAnalyzer) (uniform +/// PHI detection), and optionally [`SsaEvaluator`](crate::analysis::ssa::SsaEvaluator) /// (path-aware tracing for variables that pure constant folding can't handle). -pub struct ValueResolver<'a> { - ssa: &'a SsaFunction, - evaluator: ConstEvaluator<'a>, - phi: PhiAnalyzer<'a>, +pub struct ValueResolver<'a, T: Target = CilTarget> { + ssa: &'a SsaFunction, + evaluator: ConstEvaluator<'a, T>, + phi: PhiAnalyzer<'a, T>, resolve_phis: bool, path_aware_fallback: bool, ptr_size: PointerSize, } -impl<'a> ValueResolver<'a> { +impl<'a, T: Target> ValueResolver<'a, T> { /// Creates a new resolver with PHI resolution enabled and path-aware fallback disabled. #[must_use] - pub fn new(ssa: &'a SsaFunction, ptr_size: PointerSize) -> Self { + pub fn new(ssa: &'a SsaFunction, ptr_size: PointerSize) -> Self { Self { ssa, evaluator: ConstEvaluator::new(ssa, ptr_size), @@ -73,28 +78,19 @@ impl<'a> ValueResolver<'a> { self } - /// Bulk-injects all known values for a method from the [`CompilerContext`]. - /// - /// This loads values discovered by earlier passes (e.g., constant propagation) - /// into the inner [`ConstEvaluator`] so they're available during resolution. - #[cfg(feature = "compiler")] - pub fn load_known_values(&mut self, ctx: &CompilerContext, method_token: Token) { - ctx.for_each_known_value(method_token, |var, val| { - self.evaluator.set_known(var, val.clone()); - }); - } - /// Injects a single known value into the resolver. - pub fn set_known(&mut self, var: SsaVarId, value: ConstValue) { + pub fn set_known(&mut self, var: SsaVarId, value: ConstValue) { self.evaluator.set_known(var, value); } /// Resolves a variable to a constant using a three-tier fallback strategy. /// - /// 1. Try [`ConstEvaluator`] (all instruction-defined ops). - /// 2. If PHI-defined, check for uniform constant via [`PhiAnalyzer`]. - /// 3. If path-aware fallback is enabled, try [`SsaEvaluator::resolve_with_trace`]. - pub fn resolve(&mut self, var: SsaVarId) -> Option { + /// 1. Try [`ConstEvaluator`](crate::analysis::ssa::ConstEvaluator) (all instruction-defined ops). + /// 2. If PHI-defined, check for uniform constant via + /// [`PhiAnalyzer`](crate::analysis::ssa::PhiAnalyzer). + /// 3. If path-aware fallback is enabled, try + /// [`SsaEvaluator::resolve_with_trace`](crate::analysis::ssa::SsaEvaluator). + pub fn resolve(&mut self, var: SsaVarId) -> Option> { // 1. Try ConstEvaluator (handles Const, arithmetic, bitwise, etc. with caching) if let Some(val) = self.evaluator.evaluate_var(var) { return Some(val); @@ -126,7 +122,7 @@ impl<'a> ValueResolver<'a> { } /// Resolves all variables to constants. Returns `None` if any variable can't be resolved. - pub fn resolve_all(&mut self, vars: &[SsaVarId]) -> Option> { + pub fn resolve_all(&mut self, vars: &[SsaVarId]) -> Option>> { let mut result = Vec::with_capacity(vars.len()); for &var in vars { result.push(self.resolve(var)?); @@ -135,9 +131,27 @@ impl<'a> ValueResolver<'a> { } } +#[cfg(feature = "compiler")] +impl<'a> ValueResolver<'a, CilTarget> { + /// Bulk-injects all known values for a method from the [`CompilerContext`]. + /// + /// This loads values discovered by earlier passes (e.g., constant propagation) + /// into the inner [`ConstEvaluator`](crate::analysis::ssa::ConstEvaluator) so they're + /// available during resolution. + /// CIL-specific because `CompilerContext` stores `ConstValue` values. + pub fn load_known_values(&mut self, ctx: &CompilerContext, method_token: Token) { + ctx.for_each_known_value(method_token, |var, val| { + self.evaluator.set_known(var, val.clone()); + }); + } +} + #[cfg(test)] mod tests { - use super::ValueResolver; + use super::*; + + #[cfg(feature = "compiler")] + use std::sync::Arc; use crate::{ analysis::ssa::{ @@ -146,10 +160,12 @@ mod tests { }, metadata::typesystem::PointerSize, }; + #[cfg(feature = "compiler")] + use crate::{analysis::CallGraph, compiler::CompilerContext, metadata::token::Token}; #[test] fn test_resolve_const() { - let mut ssa = SsaFunction::new(0, 0); + let mut ssa: SsaFunction = SsaFunction::new(0, 0); let mut block = SsaBlock::new(0); let var_id = ssa.create_variable( @@ -172,7 +188,7 @@ mod tests { #[test] fn test_resolve_arithmetic_chain() { - let mut ssa = SsaFunction::new(0, 0); + let mut ssa: SsaFunction = SsaFunction::new(0, 0); let mut block = SsaBlock::new(0); // v0 = 10 @@ -211,6 +227,7 @@ mod tests { dest: v2_id, left: v0_id, right: v1_id, + flags: None, })); block.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); ssa.add_block(block); @@ -221,7 +238,7 @@ mod tests { #[test] fn test_resolve_phi_uniform() { - let mut ssa = SsaFunction::new(0, 0); + let mut ssa: SsaFunction = SsaFunction::new(0, 0); // Block 0: v0 = 42, jump to block 2 let v0_id = ssa.create_variable( @@ -277,7 +294,7 @@ mod tests { #[test] fn test_resolve_phi_non_uniform() { - let mut ssa = SsaFunction::new(0, 0); + let mut ssa: SsaFunction = SsaFunction::new(0, 0); // v0 = 42 let v0_id = ssa.create_variable( @@ -333,7 +350,7 @@ mod tests { #[test] fn test_set_known() { - let ssa = SsaFunction::new(0, 0); + let ssa: SsaFunction = SsaFunction::new(0, 0); let var_id = SsaVarId::from_index(0); let mut resolver = ValueResolver::new(&ssa, PointerSize::Bit64); @@ -345,11 +362,7 @@ mod tests { #[test] #[cfg(feature = "compiler")] fn test_load_known_values() { - use std::sync::Arc; - - use crate::{analysis::CallGraph, compiler::CompilerContext, metadata::token::Token}; - - let ssa = SsaFunction::new(0, 0); + let ssa: SsaFunction = SsaFunction::new(0, 0); let ctx = CompilerContext::new(Arc::new(CallGraph::new())); let method = Token::new(0x06000001); @@ -367,7 +380,7 @@ mod tests { #[test] fn test_resolve_all_success() { - let ssa = SsaFunction::new(0, 0); + let ssa: SsaFunction = SsaFunction::new(0, 0); let var1 = SsaVarId::from_index(0); let var2 = SsaVarId::from_index(1); @@ -382,7 +395,7 @@ mod tests { #[test] fn test_resolve_all_partial_failure() { - let ssa = SsaFunction::new(0, 0); + let ssa: SsaFunction = SsaFunction::new(0, 0); let var1 = SsaVarId::from_index(0); let var2 = SsaVarId::from_index(1); @@ -396,7 +409,7 @@ mod tests { #[test] fn test_resolve_all_empty() { - let ssa = SsaFunction::new(0, 0); + let ssa: SsaFunction = SsaFunction::new(0, 0); let mut resolver = ValueResolver::new(&ssa, PointerSize::Bit64); assert_eq!(resolver.resolve_all(&[]), Some(vec![])); @@ -404,7 +417,7 @@ mod tests { #[test] fn test_resolve_unknown_var() { - let ssa = SsaFunction::new(0, 0); + let ssa: SsaFunction = SsaFunction::new(0, 0); let unknown = SsaVarId::from_index(0); let mut resolver = ValueResolver::new(&ssa, PointerSize::Bit64); @@ -413,7 +426,7 @@ mod tests { #[test] fn test_resolve_xor_both_const() { - let mut ssa = SsaFunction::new(0, 0); + let mut ssa: SsaFunction = SsaFunction::new(0, 0); let mut block = SsaBlock::new(0); let v0_id = ssa.create_variable( @@ -449,6 +462,7 @@ mod tests { dest: v2_id, left: v0_id, right: v1_id, + flags: None, })); block.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); ssa.add_block(block); diff --git a/dotscope/src/analysis/ssa/symbolic/evaluator.rs b/dotscope/src/analysis/ssa/symbolic/evaluator.rs deleted file mode 100644 index 5fbc736e..00000000 --- a/dotscope/src/analysis/ssa/symbolic/evaluator.rs +++ /dev/null @@ -1,369 +0,0 @@ -//! Symbolic evaluator for building expression trees from SSA operations. -//! -//! This module provides [`SymbolicEvaluator`], which tracks SSA operations -//! symbolically to build [`SymbolicExpr`] trees. Unlike concrete evaluation, -//! symbolic evaluation preserves the relationship between operations, enabling -//! constraint solving with Z3. - -use std::collections::HashMap; - -use crate::{ - analysis::ssa::{ - symbolic::{expr::SymbolicExpr, ops::SymbolicOp}, - ConstValue, SsaFunction, SsaOp, SsaVarId, - }, - metadata::typesystem::PointerSize, -}; - -/// Symbolic evaluator that builds expression trees from SSA operations. -/// -/// Unlike `SsaEvaluator` which computes concrete values, `SymbolicEvaluator` -/// tracks operations symbolically, building `SymbolicExpr` trees that can -/// later be solved using Z3. -#[derive(Debug)] -pub struct SymbolicEvaluator<'a> { - ssa: &'a SsaFunction, - /// Expressions computed for each variable. - expressions: HashMap, - /// Target pointer size for native int/uint masking. - pointer_size: PointerSize, -} - -impl<'a> SymbolicEvaluator<'a> { - /// Creates a new symbolic evaluator for the given SSA function. - /// - /// The evaluator starts with no known expressions. Use [`set_symbolic`](Self::set_symbolic) - /// or [`set_constant`](Self::set_constant) to initialize variable values before - /// evaluating blocks. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to evaluate. - /// * `ptr_size` - Target pointer size for native int/uint masking. - /// - /// # Returns - /// - /// A new evaluator with no expressions. - #[must_use] - pub fn new(ssa: &'a SsaFunction, ptr_size: PointerSize) -> Self { - Self { - ssa, - expressions: HashMap::new(), - pointer_size: ptr_size, - } - } - - /// Sets a variable to a named symbolic value. - /// - /// Named symbolic values represent external inputs whose concrete values - /// are unknown. Use this to mark the "state" variable in control flow - /// unflattening. - /// - /// # Arguments - /// - /// * `var` - The SSA variable to set. - /// * `name` - The symbolic name (e.g., "state"). - pub fn set_symbolic(&mut self, var: SsaVarId, name: impl Into) { - self.expressions.insert(var, SymbolicExpr::named(name)); - } - - /// Sets a variable to a typed constant value. - /// - /// Use this to provide known initial values for variables with type preservation. - /// The caller is responsible for providing the correct `ConstValue` type. - /// - /// # Arguments - /// - /// * `var` - The SSA variable to set. - /// * `value` - The typed constant value. - pub fn set_constant(&mut self, var: SsaVarId, value: ConstValue) { - self.expressions.insert(var, SymbolicExpr::constant(value)); - } - - /// Gets the expression for a variable, if known. - /// - /// # Arguments - /// - /// * `var` - The SSA variable to look up. - /// - /// # Returns - /// - /// The expression for the variable, or `None` if not yet evaluated. - #[must_use] - pub fn get_expression(&self, var: SsaVarId) -> Option<&SymbolicExpr> { - self.expressions.get(&var) - } - - /// Gets the expression for a variable, simplified. - /// - /// Returns a copy of the expression with constant folding and algebraic - /// simplifications applied. - /// - /// # Arguments - /// - /// * `var` - The SSA variable to look up. - /// - /// # Returns - /// - /// A simplified copy of the expression, or `None` if not yet evaluated. - #[must_use] - pub fn get_simplified(&self, var: SsaVarId) -> Option { - self.expressions - .get(&var) - .map(|e| e.simplify(self.pointer_size)) - } - - /// Returns all computed expressions. - /// - /// # Returns - /// - /// A reference to the map from variable IDs to their symbolic expressions. - #[must_use] - pub fn expressions(&self) -> &HashMap { - &self.expressions - } - - /// Evaluates all instructions in a block symbolically. - /// - /// Processes each instruction in the block, building symbolic expressions - /// for any variables they define. - /// - /// # Arguments - /// - /// * `block_idx` - The index of the block to evaluate. - pub fn evaluate_block(&mut self, block_idx: usize) { - let Some(block) = self.ssa.block(block_idx) else { - return; - }; - - for instr in block.instructions() { - self.evaluate_op(instr.op()); - } - } - - /// Evaluates a sequence of blocks in order. - /// - /// # Arguments - /// - /// * `block_indices` - The indices of blocks to evaluate, in order. - pub fn evaluate_blocks(&mut self, block_indices: &[usize]) { - for &block_idx in block_indices { - self.evaluate_block(block_idx); - } - } - - /// Evaluates a single SSA operation symbolically. - /// - /// Builds a symbolic expression for the operation's result based on its - /// operands. If operands have known expressions, those are used; otherwise, - /// the operands are treated as symbolic variables. - /// - /// # Arguments - /// - /// * `op` - The SSA operation to evaluate. - pub fn evaluate_op(&mut self, op: &SsaOp) { - match op { - SsaOp::Const { dest, value } => { - self.expressions - .insert(*dest, SymbolicExpr::constant(value.clone())); - } - - SsaOp::Copy { dest, src } => { - if let Some(expr) = self.expressions.get(src) { - self.expressions.insert(*dest, expr.clone()); - } else { - self.expressions.insert(*dest, SymbolicExpr::variable(*src)); - } - } - - SsaOp::Add { dest, left, right } => { - self.eval_binary(*dest, *left, *right, SymbolicOp::Add); - } - SsaOp::Sub { dest, left, right } => { - self.eval_binary(*dest, *left, *right, SymbolicOp::Sub); - } - SsaOp::Mul { dest, left, right } => { - self.eval_binary(*dest, *left, *right, SymbolicOp::Mul); - } - SsaOp::Div { - dest, - left, - right, - unsigned, - } => { - let op = if *unsigned { - SymbolicOp::DivU - } else { - SymbolicOp::DivS - }; - self.eval_binary(*dest, *left, *right, op); - } - SsaOp::Rem { - dest, - left, - right, - unsigned, - } => { - let op = if *unsigned { - SymbolicOp::RemU - } else { - SymbolicOp::RemS - }; - self.eval_binary(*dest, *left, *right, op); - } - SsaOp::Xor { dest, left, right } => { - self.eval_binary(*dest, *left, *right, SymbolicOp::Xor); - } - SsaOp::And { dest, left, right } => { - self.eval_binary(*dest, *left, *right, SymbolicOp::And); - } - SsaOp::Or { dest, left, right } => { - self.eval_binary(*dest, *left, *right, SymbolicOp::Or); - } - SsaOp::Shl { - dest, - value, - amount, - } => { - self.eval_binary(*dest, *value, *amount, SymbolicOp::Shl); - } - SsaOp::Shr { - dest, - value, - amount, - unsigned, - } => { - let op = if *unsigned { - SymbolicOp::ShrU - } else { - SymbolicOp::ShrS - }; - self.eval_binary(*dest, *value, *amount, op); - } - SsaOp::Neg { dest, operand } => { - self.eval_unary(*dest, *operand, SymbolicOp::Neg); - } - SsaOp::Not { dest, operand } => { - self.eval_unary(*dest, *operand, SymbolicOp::Not); - } - SsaOp::Ceq { dest, left, right } => { - self.eval_binary(*dest, *left, *right, SymbolicOp::Eq); - } - SsaOp::Cgt { - dest, - left, - right, - unsigned, - } => { - let op = if *unsigned { - SymbolicOp::GtU - } else { - SymbolicOp::GtS - }; - self.eval_binary(*dest, *left, *right, op); - } - SsaOp::Clt { - dest, - left, - right, - unsigned, - } => { - let op = if *unsigned { - SymbolicOp::LtU - } else { - SymbolicOp::LtS - }; - self.eval_binary(*dest, *left, *right, op); - } - SsaOp::Conv { dest, operand, .. } => { - if let Some(expr) = self.expressions.get(operand) { - self.expressions.insert(*dest, expr.clone()); - } else { - self.expressions - .insert(*dest, SymbolicExpr::variable(*operand)); - } - } - - // LoadArg/LoadLocal: propagate expression from the source argument/local variable - SsaOp::LoadArg { dest, arg_index } => { - // Find the argument variable (version 0) and propagate its expression - if let Some(arg_var) = self - .ssa - .variables_from_argument(*arg_index) - .find(|v| v.version() == 0) - { - let src = arg_var.id(); - if let Some(expr) = self.expressions.get(&src).cloned() { - self.expressions.insert(*dest, expr); - } else { - self.expressions.insert(*dest, SymbolicExpr::variable(src)); - } - } - } - SsaOp::LoadLocal { dest, local_index } => { - if let Some(local_var) = self - .ssa - .variables_from_local(*local_index) - .find(|v| v.version() == 0) - { - let src = local_var.id(); - if let Some(expr) = self.expressions.get(&src).cloned() { - self.expressions.insert(*dest, expr); - } else { - self.expressions.insert(*dest, SymbolicExpr::variable(src)); - } - } - } - - _ => {} - } - } - - /// Evaluates a binary operation and stores the result expression. - /// - /// Looks up expressions for the operands (or creates variable references - /// if unknown), combines them with the operation, simplifies, and stores. - /// - /// # Arguments - /// - /// * `dest` - The destination variable for the result. - /// * `left` - The left operand variable. - /// * `right` - The right operand variable. - /// * `op` - The binary operation to apply. - fn eval_binary(&mut self, dest: SsaVarId, left: SsaVarId, right: SsaVarId, op: SymbolicOp) { - let left_expr = self - .expressions - .get(&left) - .cloned() - .unwrap_or_else(|| SymbolicExpr::variable(left)); - let right_expr = self - .expressions - .get(&right) - .cloned() - .unwrap_or_else(|| SymbolicExpr::variable(right)); - - let result = SymbolicExpr::binary(op, left_expr, right_expr).simplify(self.pointer_size); - self.expressions.insert(dest, result); - } - - /// Evaluates a unary operation and stores the result expression. - /// - /// Looks up the expression for the operand (or creates a variable reference - /// if unknown), applies the operation, simplifies, and stores. - /// - /// # Arguments - /// - /// * `dest` - The destination variable for the result. - /// * `operand` - The operand variable. - /// * `op` - The unary operation to apply. - fn eval_unary(&mut self, dest: SsaVarId, operand: SsaVarId, op: SymbolicOp) { - let operand_expr = self - .expressions - .get(&operand) - .cloned() - .unwrap_or_else(|| SymbolicExpr::variable(operand)); - - let result = SymbolicExpr::unary(op, operand_expr).simplify(self.pointer_size); - self.expressions.insert(dest, result); - } -} diff --git a/dotscope/src/analysis/ssa/symbolic/expr.rs b/dotscope/src/analysis/ssa/symbolic/expr.rs deleted file mode 100644 index 816c7614..00000000 --- a/dotscope/src/analysis/ssa/symbolic/expr.rs +++ /dev/null @@ -1,1134 +0,0 @@ -//! Symbolic expression tree representation. -//! -//! This module defines [`SymbolicExpr`], an intermediate representation for -//! symbolic values that can contain variables, constants, and operations. -//! Expressions map directly to SSA operations and can be translated to Z3 -//! for constraint solving. - -use std::{ - collections::{HashMap, HashSet}, - fmt, -}; - -use crate::{ - analysis::ssa::{symbolic::ops::SymbolicOp, ConstValue, SsaVarId}, - metadata::typesystem::PointerSize, -}; - -/// A symbolic expression that can contain variables, constants, and operations. -/// -/// This is our intermediate representation for symbolic values. It maps directly -/// to SSA operations and can be translated to Z3 for constraint solving. -#[derive(Debug, Clone, PartialEq)] -pub enum SymbolicExpr { - /// A typed constant value preserving CIL type information. - Constant(ConstValue), - - /// A symbolic variable (identified by SSA variable ID). - Variable(SsaVarId), - - /// A named symbolic variable (for external inputs like "state"). - NamedVar(String), - - /// A unary operation. - Unary { - /// The operation to perform. - op: SymbolicOp, - /// The operand. - operand: Box, - }, - - /// A binary operation. - Binary { - /// The operation to perform. - op: SymbolicOp, - /// The left operand. - left: Box, - /// The right operand. - right: Box, - }, -} - -impl SymbolicExpr { - /// Creates a constant expression from a typed `ConstValue`. - /// - /// This is the preferred constructor as it preserves type information. - /// - /// # Arguments - /// - /// * `value` - The typed constant value. - /// - /// # Returns - /// - /// A new [`SymbolicExpr::Constant`] containing the value. - #[must_use] - pub fn constant(value: ConstValue) -> Self { - Self::Constant(value) - } - - /// Creates a constant expression from an i64 value. - /// - /// The value is stored as `ConstValue::I64`. For type-preserving operations, - /// use [`constant`](Self::constant) with an explicit `ConstValue` instead. - /// - /// # Arguments - /// - /// * `value` - The integer value. - /// - /// # Returns - /// - /// A new [`SymbolicExpr::Constant`] containing the value as I64. - #[must_use] - pub fn constant_i64(value: i64) -> Self { - Self::Constant(ConstValue::I64(value)) - } - - /// Creates a constant expression from an i32 value. - /// - /// # Arguments - /// - /// * `value` - The integer value. - /// - /// # Returns - /// - /// A new [`SymbolicExpr::Constant`] containing the value as I32. - #[must_use] - pub fn constant_i32(value: i32) -> Self { - Self::Constant(ConstValue::I32(value)) - } - - /// Creates a variable expression from an SSA variable ID. - /// - /// # Arguments - /// - /// * `var` - The SSA variable identifier. - /// - /// # Returns - /// - /// A new [`SymbolicExpr::Variable`] referencing the given variable. - #[must_use] - pub const fn variable(var: SsaVarId) -> Self { - Self::Variable(var) - } - - /// Creates a named variable expression. - /// - /// Named variables are used for external inputs like "state" that aren't - /// tied to a specific SSA variable ID. - /// - /// # Arguments - /// - /// * `name` - The variable name (e.g., "state"). - /// - /// # Returns - /// - /// A new [`SymbolicExpr::NamedVar`] with the given name. - #[must_use] - pub fn named(name: impl Into) -> Self { - Self::NamedVar(name.into()) - } - - /// Creates a unary operation expression. - /// - /// # Arguments - /// - /// * `op` - The unary operation (Neg or Not). - /// * `operand` - The operand expression. - /// - /// # Returns - /// - /// A new [`SymbolicExpr::Unary`] applying the operation to the operand. - #[must_use] - pub fn unary(op: SymbolicOp, operand: Self) -> Self { - Self::Unary { - op, - operand: Box::new(operand), - } - } - - /// Creates a binary operation expression. - /// - /// # Arguments - /// - /// * `op` - The binary operation (Add, Sub, Mul, etc.). - /// * `left` - The left operand expression. - /// * `right` - The right operand expression. - /// - /// # Returns - /// - /// A new [`SymbolicExpr::Binary`] applying the operation to the operands. - #[must_use] - pub fn binary(op: SymbolicOp, left: Self, right: Self) -> Self { - Self::Binary { - op, - left: Box::new(left), - right: Box::new(right), - } - } - - /// Checks if this expression is a constant. - /// - /// # Returns - /// - /// `true` if this is a [`SymbolicExpr::Constant`]. - #[must_use] - pub const fn is_constant(&self) -> bool { - matches!(self, Self::Constant(_)) - } - - /// Checks if this expression is a variable. - /// - /// # Returns - /// - /// `true` if this is a [`SymbolicExpr::Variable`] or [`SymbolicExpr::NamedVar`]. - #[must_use] - pub const fn is_variable(&self) -> bool { - matches!(self, Self::Variable(_) | Self::NamedVar(_)) - } - - /// Returns the typed constant value if this is a constant expression. - /// - /// # Returns - /// - /// `Some(&ConstValue)` if this is a constant, `None` otherwise. - #[must_use] - pub const fn as_constant(&self) -> Option<&ConstValue> { - match self { - Self::Constant(v) => Some(v), - _ => None, - } - } - - /// Returns the constant as i64 if this is a constant expression. - /// - /// This extracts the raw i64 value regardless of the underlying type. - /// For type-preserving operations, use [`as_constant`](Self::as_constant) instead. - /// - /// # Returns - /// - /// `Some(i64)` if this is a constant with an integer value, `None` otherwise. - #[must_use] - pub fn as_i64(&self) -> Option { - match self { - Self::Constant(v) => v.as_i64(), - _ => None, - } - } - - /// Returns the SSA variable ID if this is a variable expression. - /// - /// # Returns - /// - /// `Some(var_id)` if this is a [`SymbolicExpr::Variable`], `None` otherwise. - /// Note: Returns `None` for [`SymbolicExpr::NamedVar`]. - #[must_use] - pub const fn as_variable(&self) -> Option { - match self { - Self::Variable(v) => Some(*v), - _ => None, - } - } - - /// Collects all SSA variables referenced in this expression. - /// - /// Recursively traverses the expression tree to find all variable references. - /// - /// # Returns - /// - /// A set of all [`SsaVarId`]s referenced in this expression. - #[must_use] - pub fn variables(&self) -> HashSet { - match self { - Self::Constant(_) | Self::NamedVar(_) => HashSet::new(), - Self::Variable(v) => { - let mut vars = HashSet::new(); - vars.insert(*v); - vars - } - Self::Unary { operand, .. } => operand.variables(), - Self::Binary { left, right, .. } => { - let mut vars = left.variables(); - vars.extend(right.variables()); - vars - } - } - } - - /// Collects all named variables referenced in this expression. - /// - /// Recursively traverses the expression tree to find all named variable references. - /// - /// # Returns - /// - /// A set of all variable names referenced in this expression. - #[must_use] - pub fn named_variables(&self) -> HashSet { - let mut vars = HashSet::new(); - self.collect_named_variables(&mut vars); - vars - } - - /// Recursively collects named variables into the provided set. - /// - /// # Arguments - /// - /// * `vars` - The set to collect variable names into. - fn collect_named_variables(&self, vars: &mut HashSet) { - match self { - Self::Constant(_) | Self::Variable(_) => {} - Self::NamedVar(name) => { - vars.insert(name.clone()); - } - Self::Unary { operand, .. } => operand.collect_named_variables(vars), - Self::Binary { left, right, .. } => { - left.collect_named_variables(vars); - right.collect_named_variables(vars); - } - } - } - - /// Evaluates the expression with the given SSA variable bindings. - /// - /// Recursively evaluates the expression tree, substituting bound variables - /// with their values and computing operations. Returns the result as a - /// typed `ConstValue`. - /// - /// # Arguments - /// - /// * `bindings` - Map from SSA variable IDs to their concrete values. - /// * `ptr_size` - Target pointer size for native int/uint masking. - /// - /// # Returns - /// - /// `Some(ConstValue)` if evaluation succeeds, `None` if any variable is unbound, - /// a named variable is encountered, or division by zero occurs. - #[must_use] - pub fn evaluate( - &self, - bindings: &HashMap, - ptr_size: PointerSize, - ) -> Option { - match self { - Self::Constant(v) => Some(v.clone()), - Self::Variable(var) => bindings.get(var).cloned(), - Self::NamedVar(_) => None, - Self::Unary { op, operand } => { - let v = operand.evaluate(bindings, ptr_size)?; - evaluate_unary_typed(*op, &v, ptr_size) - } - Self::Binary { op, left, right } => { - let l = left.evaluate(bindings, ptr_size)?; - let r = right.evaluate(bindings, ptr_size)?; - evaluate_binary_typed(*op, &l, &r, ptr_size) - } - } - } - - /// Evaluates the expression with named variable bindings. - /// - /// Similar to [`evaluate`](Self::evaluate), but uses string names instead - /// of SSA variable IDs. Useful for evaluating expressions with external inputs. - /// - /// # Arguments - /// - /// * `bindings` - Map from variable names to their concrete values. - /// * `ptr_size` - Target pointer size for native int/uint masking. - /// - /// # Returns - /// - /// `Some(ConstValue)` if evaluation succeeds, `None` if any named variable is - /// unbound, an SSA variable is encountered, or division by zero occurs. - #[must_use] - pub fn evaluate_named( - &self, - bindings: &HashMap<&str, ConstValue>, - ptr_size: PointerSize, - ) -> Option { - match self { - Self::Constant(v) => Some(v.clone()), - Self::Variable(_) => None, - Self::NamedVar(name) => bindings.get(name.as_str()).cloned(), - Self::Unary { op, operand } => { - let v = operand.evaluate_named(bindings, ptr_size)?; - evaluate_unary_typed(*op, &v, ptr_size) - } - Self::Binary { op, left, right } => { - let l = left.evaluate_named(bindings, ptr_size)?; - let r = right.evaluate_named(bindings, ptr_size)?; - evaluate_binary_typed(*op, &l, &r, ptr_size) - } - } - } - - /// Substitutes an SSA variable with a replacement expression. - /// - /// Creates a new expression tree where all occurrences of `var` are - /// replaced with `replacement`. - /// - /// # Arguments - /// - /// * `var` - The SSA variable ID to replace. - /// * `replacement` - The expression to substitute in place of the variable. - /// - /// # Returns - /// - /// A new expression with the substitution applied. - #[must_use] - pub fn substitute(&self, var: SsaVarId, replacement: &Self) -> Self { - match self { - Self::Constant(v) => Self::Constant(v.clone()), - Self::Variable(v) if *v == var => replacement.clone(), - Self::Variable(v) => Self::Variable(*v), - Self::NamedVar(name) => Self::NamedVar(name.clone()), - Self::Unary { op, operand } => Self::Unary { - op: *op, - operand: Box::new(operand.substitute(var, replacement)), - }, - Self::Binary { op, left, right } => Self::Binary { - op: *op, - left: Box::new(left.substitute(var, replacement)), - right: Box::new(right.substitute(var, replacement)), - }, - } - } - - /// Substitutes a named variable with a constant value. - /// - /// Creates a new expression tree where all occurrences of the named - /// variable are replaced with the constant value, then simplifies. - /// - /// # Arguments - /// - /// * `name` - The name of the variable to replace (e.g., "state"). - /// * `value` - The constant value to substitute. - /// - /// # Returns - /// - /// A simplified expression with the substitution applied. - /// - /// # Example - /// - /// ```rust,ignore - /// let expr = SymbolicExpr::binary( - /// SymbolicOp::Xor, - /// SymbolicExpr::named("state"), - /// SymbolicExpr::constant_i64(0x12345678), - /// ); - /// let result = expr.substitute_named("state", 100); - /// assert_eq!(result.as_i64(), Some(100 ^ 0x12345678)); - /// ``` - #[must_use] - pub fn substitute_named(&self, name: &str, value: i64, ptr_size: PointerSize) -> Self { - self.substitute_named_expr(name, &Self::Constant(ConstValue::I64(value))) - .simplify(ptr_size) - } - - /// Substitutes a named variable with a replacement expression. - /// - /// Creates a new expression tree where all occurrences of the named - /// variable are replaced with the replacement expression. - /// - /// # Arguments - /// - /// * `name` - The name of the variable to replace. - /// * `replacement` - The expression to substitute. - /// - /// # Returns - /// - /// A new expression with the substitution applied. - #[must_use] - pub fn substitute_named_expr(&self, name: &str, replacement: &Self) -> Self { - match self { - Self::Constant(v) => Self::Constant(v.clone()), - Self::Variable(v) => Self::Variable(*v), - Self::NamedVar(n) if n == name => replacement.clone(), - Self::NamedVar(n) => Self::NamedVar(n.clone()), - Self::Unary { op, operand } => Self::Unary { - op: *op, - operand: Box::new(operand.substitute_named_expr(name, replacement)), - }, - Self::Binary { op, left, right } => Self::Binary { - op: *op, - left: Box::new(left.substitute_named_expr(name, replacement)), - right: Box::new(right.substitute_named_expr(name, replacement)), - }, - } - } - - /// Simplifies the expression by evaluating constant subexpressions. - /// - /// Performs constant folding and applies algebraic identities: - /// - Folds constant operations (e.g., `5 + 3` → `8`) - /// - Removes identity operations (e.g., `x + 0` → `x`, `x * 1` → `x`) - /// - Simplifies zero multiplications (e.g., `x * 0` → `0`) - /// - Self-cancellation patterns (e.g., `x ^ x = 0`, `x - x = 0`) - /// - Double operation cancellation (e.g., `--x = x`, `~~x = x`) - /// - XOR constant cancellation (e.g., `(x ^ c) ^ c = x`) - /// - /// # Returns - /// - /// A simplified expression that is semantically equivalent to this one. - #[must_use] - #[allow(clippy::match_same_arms)] // Documents distinct algebraic identities: x*0=0 vs x&0=0 - pub fn simplify(&self, ptr_size: PointerSize) -> Self { - match self { - Self::Constant(_) | Self::Variable(_) | Self::NamedVar(_) => self.clone(), - Self::Unary { op, operand } => { - let simplified = operand.simplify(ptr_size); - - // Constant folding using typed operations - if let Self::Constant(v) = &simplified { - if let Some(result) = evaluate_unary_typed(*op, v, ptr_size) { - return Self::Constant(result); - } - } - - // Double operation cancellation: --x = x, ~~x = x - if let Self::Unary { - op: inner_op, - operand: inner_operand, - } = &simplified - { - if op == inner_op { - match op { - // --x = x (double negation) - SymbolicOp::Neg => return (**inner_operand).clone(), - // ~~x = x (double NOT) - SymbolicOp::Not => return (**inner_operand).clone(), - _ => {} - } - } - } - - Self::Unary { - op: *op, - operand: Box::new(simplified), - } - } - Self::Binary { op, left, right } => { - let left_simp = left.simplify(ptr_size); - let right_simp = right.simplify(ptr_size); - - // Both constants - evaluate using typed operations - if let (Self::Constant(l), Self::Constant(r)) = (&left_simp, &right_simp) { - if let Some(result) = evaluate_binary_typed(*op, l, r, ptr_size) { - return Self::Constant(result); - } - } - - // Self-cancellation patterns (when left == right) - if left_simp == right_simp { - match op { - // x ^ x = 0 - SymbolicOp::Xor => return Self::Constant(ConstValue::I32(0)), - // x - x = 0 - SymbolicOp::Sub => return Self::Constant(ConstValue::I32(0)), - // x | x = x - SymbolicOp::Or => return left_simp, - // x & x = x - SymbolicOp::And => return left_simp, - _ => {} - } - } - - // XOR constant cancellation: (x ^ c) ^ c = x - // This is critical for deobfuscation - many obfuscators use XOR with same constant - if *op == SymbolicOp::Xor { - if let Self::Constant(c1) = &right_simp { - if let Self::Binary { - op: SymbolicOp::Xor, - left: inner_left, - right: inner_right, - } = &left_simp - { - // (x ^ c1) ^ c1 = x - if let Self::Constant(c2) = inner_right.as_ref() { - if c1 == c2 { - return (**inner_left).clone(); - } - } - // (c1 ^ x) ^ c1 = x - if let Self::Constant(c2) = inner_left.as_ref() { - if c1 == c2 { - return (**inner_right).clone(); - } - } - } - } - // Also handle c ^ (x ^ c) = x - if let Self::Constant(c1) = &left_simp { - if let Self::Binary { - op: SymbolicOp::Xor, - left: inner_left, - right: inner_right, - } = &right_simp - { - // c1 ^ (x ^ c1) = x - if let Self::Constant(c2) = inner_right.as_ref() { - if c1 == c2 { - return (**inner_left).clone(); - } - } - // c1 ^ (c1 ^ x) = x - if let Self::Constant(c2) = inner_left.as_ref() { - if c1 == c2 { - return (**inner_right).clone(); - } - } - } - } - } - - // Identity simplifications - check if constant is zero/one - if let Self::Constant(r) = &right_simp { - if r.is_zero() { - match op { - // x + 0 = x, x - 0 = x - SymbolicOp::Add | SymbolicOp::Sub => return left_simp, - // x * 0 = 0 - SymbolicOp::Mul => return Self::Constant(ConstValue::I32(0)), - // x ^ 0 = x, x | 0 = x - SymbolicOp::Xor | SymbolicOp::Or => return left_simp, - // x & 0 = 0 - SymbolicOp::And => return Self::Constant(ConstValue::I32(0)), - _ => {} - } - } else if r.is_one() { - match op { - // x * 1 = x, x / 1 = x - SymbolicOp::Mul | SymbolicOp::DivS | SymbolicOp::DivU => { - return left_simp - } - _ => {} - } - } else if r.is_all_ones() { - match op { - // x & -1 = x - SymbolicOp::And => return left_simp, - // x | -1 = -1 - SymbolicOp::Or => return Self::Constant(r.clone()), - // x ^ -1 = ~x - SymbolicOp::Xor => { - return Self::Unary { - op: SymbolicOp::Not, - operand: Box::new(left_simp), - } - } - _ => {} - } - } - } - - if let Self::Constant(l) = &left_simp { - if l.is_zero() { - match op { - // 0 + x = x - SymbolicOp::Add => return right_simp, - // 0 - x = -x - SymbolicOp::Sub => { - return Self::Unary { - op: SymbolicOp::Neg, - operand: Box::new(right_simp), - } - } - // 0 * x = 0 - SymbolicOp::Mul => return Self::Constant(ConstValue::I32(0)), - // 0 ^ x = x, 0 | x = x - SymbolicOp::Xor | SymbolicOp::Or => return right_simp, - // 0 & x = 0 - SymbolicOp::And => return Self::Constant(ConstValue::I32(0)), - _ => {} - } - } else if l.is_one() { - // 1 * x = x - if *op == SymbolicOp::Mul { - return right_simp; - } - } else if l.is_all_ones() { - match op { - // -1 & x = x - SymbolicOp::And => return right_simp, - // -1 | x = -1 - SymbolicOp::Or => return Self::Constant(l.clone()), - // -1 ^ x = ~x - SymbolicOp::Xor => { - return Self::Unary { - op: SymbolicOp::Not, - operand: Box::new(right_simp), - } - } - _ => {} - } - } - } - - Self::Binary { - op: *op, - left: Box::new(left_simp), - right: Box::new(right_simp), - } - } - } - } - - /// Returns the depth of the expression tree. - /// - /// The depth is the length of the longest path from the root to a leaf. - /// Constants and variables have depth 0. - /// - /// # Returns - /// - /// The maximum nesting depth of operations in this expression. - #[must_use] - pub fn depth(&self) -> usize { - match self { - Self::Constant(_) | Self::Variable(_) | Self::NamedVar(_) => 0, - Self::Unary { operand, .. } => 1usize.saturating_add(operand.depth()), - Self::Binary { left, right, .. } => { - 1usize.saturating_add(left.depth().max(right.depth())) - } - } - } -} - -impl fmt::Display for SymbolicExpr { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Constant(v) => write!(f, "{v}"), - Self::Variable(var) => write!(f, "v{}", var.index()), - Self::NamedVar(name) => write!(f, "{name}"), - Self::Unary { op, operand } => write!(f, "({op}{operand})"), - Self::Binary { op, left, right } => write!(f, "({left} {op} {right})"), - } - } -} - -impl From for SymbolicExpr { - fn from(value: ConstValue) -> Self { - Self::Constant(value) - } -} - -impl From for SymbolicExpr { - fn from(value: i32) -> Self { - Self::Constant(ConstValue::I32(value)) - } -} - -impl From for SymbolicExpr { - fn from(value: i64) -> Self { - Self::Constant(ConstValue::I64(value)) - } -} - -/// Evaluates a unary operation on a typed constant value. -/// -/// Uses the type-preserving operations on `ConstValue`. -/// -/// # Arguments -/// -/// * `op` - The unary operation to perform (Neg or Not). -/// * `value` - The typed operand value. -/// -/// # Returns -/// -/// The result of the operation as a `ConstValue`, or `None` if the operation fails. -pub fn evaluate_unary_typed( - op: SymbolicOp, - value: &ConstValue, - ptr_size: PointerSize, -) -> Option { - match op { - SymbolicOp::Neg => value.negate(ptr_size), - SymbolicOp::Not => value.bitwise_not(ptr_size), - _ => None, - } -} - -/// Evaluates a binary operation on typed constant values. -/// -/// Uses the type-preserving operations on `ConstValue`. -/// -/// # Arguments -/// -/// * `op` - The binary operation to perform. -/// * `left` - The typed left operand value. -/// * `right` - The typed right operand value. -/// -/// # Returns -/// -/// The result of the operation as a `ConstValue`, or `None` if the operation -/// fails (e.g., division by zero, type mismatch). -pub fn evaluate_binary_typed( - op: SymbolicOp, - left: &ConstValue, - right: &ConstValue, - ptr_size: PointerSize, -) -> Option { - match op { - SymbolicOp::Add => left.add(right, ptr_size), - SymbolicOp::Sub => left.sub(right, ptr_size), - SymbolicOp::Mul => left.mul(right, ptr_size), - // div/rem handle signedness based on ConstValue's underlying type - SymbolicOp::DivS | SymbolicOp::DivU => left.div(right, ptr_size), - SymbolicOp::RemS | SymbolicOp::RemU => left.rem(right, ptr_size), - SymbolicOp::And => left.bitwise_and(right, ptr_size), - SymbolicOp::Or => left.bitwise_or(right, ptr_size), - SymbolicOp::Xor => left.bitwise_xor(right, ptr_size), - SymbolicOp::Shl => left.shl(right, ptr_size), - SymbolicOp::ShrS => left.shr(right, false, ptr_size), - SymbolicOp::ShrU => left.shr(right, true, ptr_size), - SymbolicOp::Eq => left.ceq(right), - SymbolicOp::Ne => left.ceq(right).map(|v| { - // Negate the equality result - if v.is_zero() { - ConstValue::I32(1) - } else { - ConstValue::I32(0) - } - }), - SymbolicOp::LtS => left.clt(right), - SymbolicOp::LtU => left.clt_un(right), - SymbolicOp::GtS => left.cgt(right), - SymbolicOp::GtU => left.cgt_un(right), - SymbolicOp::LeS => { - // x <= y is !(x > y) - left.cgt(right).map(|v| { - if v.is_zero() { - ConstValue::I32(1) - } else { - ConstValue::I32(0) - } - }) - } - SymbolicOp::LeU => { - // x <=u y is !(x >u y) - left.cgt_un(right).map(|v| { - if v.is_zero() { - ConstValue::I32(1) - } else { - ConstValue::I32(0) - } - }) - } - SymbolicOp::GeS => { - // x >= y is !(x < y) - left.clt(right).map(|v| { - if v.is_zero() { - ConstValue::I32(1) - } else { - ConstValue::I32(0) - } - }) - } - SymbolicOp::GeU => { - // x >=u y is !(x None, - } -} - -#[cfg(test)] -mod tests { - use std::collections::HashMap; - - use super::*; - - use crate::{ - analysis::ssa::{ConstValue, SsaVarId}, - metadata::typesystem::PointerSize, - }; - - #[test] - fn test_constant_expression() { - let expr = SymbolicExpr::constant_i32(42); - assert!(expr.is_constant()); - assert_eq!(expr.as_constant(), Some(&ConstValue::I32(42))); - assert_eq!( - expr.evaluate(&HashMap::new(), PointerSize::Bit64), - Some(ConstValue::I32(42)) - ); - } - - #[test] - fn test_variable_expression() { - let var = SsaVarId::from_index(0); - let expr = SymbolicExpr::variable(var); - assert!(expr.is_variable()); - assert_eq!(expr.as_variable(), Some(var)); - - let mut bindings = HashMap::new(); - assert_eq!(expr.evaluate(&bindings, PointerSize::Bit64), None); - - bindings.insert(var, ConstValue::I32(100)); - assert_eq!( - expr.evaluate(&bindings, PointerSize::Bit64), - Some(ConstValue::I32(100)) - ); - } - - #[test] - fn test_simplify_constant_fold() { - let expr = SymbolicExpr::binary( - SymbolicOp::Add, - SymbolicExpr::constant_i32(10), - SymbolicExpr::constant_i32(20), - ); - let simplified = expr.simplify(PointerSize::Bit64); - assert_eq!(simplified, SymbolicExpr::constant(ConstValue::I32(30))); - } - - #[test] - fn test_simplify_identity() { - let var = SsaVarId::from_index(0); - let expr = SymbolicExpr::binary( - SymbolicOp::Add, - SymbolicExpr::variable(var), - SymbolicExpr::constant_i32(0), - ); - let simplified = expr.simplify(PointerSize::Bit64); - assert_eq!(simplified, SymbolicExpr::variable(var)); - } - - #[test] - fn test_expression_display() { - let expr = SymbolicExpr::binary( - SymbolicOp::RemU, - SymbolicExpr::binary( - SymbolicOp::Xor, - SymbolicExpr::named("state"), - SymbolicExpr::constant_i32(0x1234), - ), - SymbolicExpr::constant_i32(13), - ); - - let display = format!("{}", expr); - assert!(display.contains("state")); - assert!(display.contains("^")); - assert!(display.contains("%u")); - } - - #[test] - fn test_simplify_xor_self_cancellation() { - // x ^ x = 0 - let var = SsaVarId::from_index(0); - let expr = SymbolicExpr::binary( - SymbolicOp::Xor, - SymbolicExpr::variable(var), - SymbolicExpr::variable(var), - ); - assert_eq!( - expr.simplify(PointerSize::Bit64), - SymbolicExpr::constant(ConstValue::I32(0)) - ); - } - - #[test] - fn test_simplify_sub_self_cancellation() { - // x - x = 0 - let var = SsaVarId::from_index(0); - let expr = SymbolicExpr::binary( - SymbolicOp::Sub, - SymbolicExpr::variable(var), - SymbolicExpr::variable(var), - ); - assert_eq!( - expr.simplify(PointerSize::Bit64), - SymbolicExpr::constant(ConstValue::I32(0)) - ); - } - - #[test] - fn test_simplify_or_self_idempotent() { - // x | x = x - let var = SsaVarId::from_index(0); - let expr = SymbolicExpr::binary( - SymbolicOp::Or, - SymbolicExpr::variable(var), - SymbolicExpr::variable(var), - ); - assert_eq!( - expr.simplify(PointerSize::Bit64), - SymbolicExpr::variable(var) - ); - } - - #[test] - fn test_simplify_and_self_idempotent() { - // x & x = x - let var = SsaVarId::from_index(0); - let expr = SymbolicExpr::binary( - SymbolicOp::And, - SymbolicExpr::variable(var), - SymbolicExpr::variable(var), - ); - assert_eq!( - expr.simplify(PointerSize::Bit64), - SymbolicExpr::variable(var) - ); - } - - #[test] - fn test_simplify_double_negation() { - // --x = x - let var = SsaVarId::from_index(0); - let expr = SymbolicExpr::unary( - SymbolicOp::Neg, - SymbolicExpr::unary(SymbolicOp::Neg, SymbolicExpr::variable(var)), - ); - assert_eq!( - expr.simplify(PointerSize::Bit64), - SymbolicExpr::variable(var) - ); - } - - #[test] - fn test_simplify_double_not() { - // ~~x = x - let var = SsaVarId::from_index(0); - let expr = SymbolicExpr::unary( - SymbolicOp::Not, - SymbolicExpr::unary(SymbolicOp::Not, SymbolicExpr::variable(var)), - ); - assert_eq!( - expr.simplify(PointerSize::Bit64), - SymbolicExpr::variable(var) - ); - } - - #[test] - fn test_simplify_xor_constant_cancellation() { - // (x ^ c) ^ c = x - let var = SsaVarId::from_index(0); - let const_val = ConstValue::I32(0x12345678_u32 as i32); - let expr = SymbolicExpr::binary( - SymbolicOp::Xor, - SymbolicExpr::binary( - SymbolicOp::Xor, - SymbolicExpr::variable(var), - SymbolicExpr::constant(const_val.clone()), - ), - SymbolicExpr::constant(const_val), - ); - assert_eq!( - expr.simplify(PointerSize::Bit64), - SymbolicExpr::variable(var) - ); - } - - #[test] - fn test_simplify_xor_constant_cancellation_reversed() { - // c ^ (x ^ c) = x - let var = SsaVarId::from_index(0); - let const_val = ConstValue::I64(0xDEADBEEF_u32 as i64); - let expr = SymbolicExpr::binary( - SymbolicOp::Xor, - SymbolicExpr::constant(const_val.clone()), - SymbolicExpr::binary( - SymbolicOp::Xor, - SymbolicExpr::variable(var), - SymbolicExpr::constant(const_val), - ), - ); - assert_eq!( - expr.simplify(PointerSize::Bit64), - SymbolicExpr::variable(var) - ); - } - - #[test] - fn test_simplify_and_all_ones() { - // x & -1 = x - let var = SsaVarId::from_index(0); - let expr = SymbolicExpr::binary( - SymbolicOp::And, - SymbolicExpr::variable(var), - SymbolicExpr::constant_i32(-1), - ); - assert_eq!( - expr.simplify(PointerSize::Bit64), - SymbolicExpr::variable(var) - ); - } - - #[test] - fn test_simplify_or_all_ones() { - // x | -1 = -1 - let var = SsaVarId::from_index(0); - let expr = SymbolicExpr::binary( - SymbolicOp::Or, - SymbolicExpr::variable(var), - SymbolicExpr::constant_i32(-1), - ); - // Result should have all ones - let simplified = expr.simplify(PointerSize::Bit64); - if let SymbolicExpr::Constant(v) = simplified { - assert!(v.is_all_ones()); - } else { - panic!("Expected constant result"); - } - } - - #[test] - fn test_simplify_xor_all_ones() { - // x ^ -1 = ~x - let var = SsaVarId::from_index(0); - let expr = SymbolicExpr::binary( - SymbolicOp::Xor, - SymbolicExpr::variable(var), - SymbolicExpr::constant_i32(-1), - ); - let simplified = expr.simplify(PointerSize::Bit64); - // Should be ~x (NOT operation) - assert!(matches!( - simplified, - SymbolicExpr::Unary { - op: SymbolicOp::Not, - .. - } - )); - } - - #[test] - fn test_simplify_confuserex_state_pattern() { - // ConfuserEx uses: ((state * mul) ^ xor_key) % mod_val - // After XOR cancellation: (state * mul) % mod_val - // This tests that XOR with same constant cancels out - let state = SymbolicExpr::named("state"); - let mul_const = ConstValue::I32(0x1234); - let xor_key = ConstValue::I32(0xABCD_u32 as i32); - - // Build: ((state * mul) ^ xor) ^ xor - let expr = SymbolicExpr::binary( - SymbolicOp::Xor, - SymbolicExpr::binary( - SymbolicOp::Xor, - SymbolicExpr::binary( - SymbolicOp::Mul, - state.clone(), - SymbolicExpr::constant(mul_const), - ), - SymbolicExpr::constant(xor_key.clone()), - ), - SymbolicExpr::constant(xor_key), - ); - - let simplified = expr.simplify(PointerSize::Bit64); - - // Should simplify to: state * mul - assert!(matches!( - simplified, - SymbolicExpr::Binary { - op: SymbolicOp::Mul, - .. - } - )); - } -} diff --git a/dotscope/src/analysis/ssa/symbolic/mod.rs b/dotscope/src/analysis/ssa/symbolic/mod.rs index 69d50b17..cbe9d32a 100644 --- a/dotscope/src/analysis/ssa/symbolic/mod.rs +++ b/dotscope/src/analysis/ssa/symbolic/mod.rs @@ -20,10 +20,16 @@ //! //! # Module Structure //! -//! - [`ops`] - Symbolic operation types ([`SymbolicOp`]: add, xor, comparison, etc.) -//! - [`expr`] - Symbolic expression tree representation ([`SymbolicExpr`]) -//! - [`solver`] - Z3-based constraint solver ([`Z3Solver`]) -//! - [`evaluator`] - Builds expressions from SSA operations ([`SymbolicEvaluator`]) +//! The target-agnostic IR + evaluator now live in `analyssa::analysis::symbolic`: +//! +//! - [`SymbolicOp`] - Symbolic operation kinds (add, xor, comparison, etc.) — `analyssa::analysis::symbolic::ops` +//! - [`SymbolicExpr`] - Symbolic expression tree — `analyssa::analysis::symbolic::expr` +//! - [`SymbolicEvaluator`] - Builds expressions from SSA operations — `analyssa::analysis::symbolic::evaluator` +//! +//! The Z3 binding stays in dotscope under the `z3` Cargo feature (analyssa +//! intentionally avoids the Z3 dependency): +//! +//! - [`Z3Solver`] - Z3-based constraint solver (this crate, gated on `z3`) //! //! # Use Cases //! @@ -81,19 +87,13 @@ //! [`Z3Solver`] and reuse it. The solver uses 32-bit bitvectors (`BV(32)`) for //! all computations, matching CIL's `int32` semantics. -mod evaluator; -mod expr; -mod ops; +// `SymbolicEvaluator`, `SymbolicExpr`, `SymbolicOp` live in +// `analyssa::analysis::symbolic`. The Z3 binding stays here behind the `z3` +// feature flag — Z3 is a dotscope-side dependency that analyssa intentionally +// doesn't take. +pub use analyssa::analysis::symbolic::{SymbolicEvaluator, SymbolicExpr, SymbolicOp}; -// The solver module requires the z3 dependency #[cfg(feature = "z3")] mod solver; - -// Re-export public types - SymbolicExpr, SymbolicOp, SymbolicEvaluator are always available -pub use evaluator::SymbolicEvaluator; -pub use expr::SymbolicExpr; -pub use ops::SymbolicOp; - -// Z3Solver is only available with the z3 feature #[cfg(feature = "z3")] pub use solver::Z3Solver; diff --git a/dotscope/src/analysis/ssa/symbolic/ops.rs b/dotscope/src/analysis/ssa/symbolic/ops.rs deleted file mode 100644 index b81fde51..00000000 --- a/dotscope/src/analysis/ssa/symbolic/ops.rs +++ /dev/null @@ -1,159 +0,0 @@ -//! Symbolic operation types. -//! -//! This module defines [`SymbolicOp`], the set of operations supported in -//! symbolic expressions. These operations map directly to CIL arithmetic -//! and logical operations, using 32-bit semantics. -//! -//! Operations are categorized as: -//! - **Arithmetic**: Add, Sub, Mul, Div, Rem, Neg -//! - **Bitwise**: And, Or, Xor, Not, Shl, Shr -//! - **Comparison**: Eq, Ne, Lt, Gt, Le, Ge (with signed/unsigned variants) - -use std::fmt; - -/// A symbolic operation in an expression tree. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum SymbolicOp { - // Arithmetic operations - /// Addition. - Add, - /// Subtraction. - Sub, - /// Multiplication. - Mul, - /// Signed division. - DivS, - /// Unsigned division. - DivU, - /// Signed remainder (modulo). - RemS, - /// Unsigned remainder (modulo). - RemU, - /// Negation. - Neg, - - // Bitwise operations - /// Bitwise AND. - And, - /// Bitwise OR. - Or, - /// Bitwise XOR. - Xor, - /// Bitwise NOT. - Not, - /// Shift left. - Shl, - /// Arithmetic shift right (preserves sign). - ShrS, - /// Logical shift right (zero-fill). - ShrU, - - // Comparison operations (return 0 or 1) - /// Equal. - Eq, - /// Not equal. - Ne, - /// Signed less than. - LtS, - /// Unsigned less than. - LtU, - /// Signed greater than. - GtS, - /// Unsigned greater than. - GtU, - /// Signed less than or equal. - LeS, - /// Unsigned less than or equal. - LeU, - /// Signed greater than or equal. - GeS, - /// Unsigned greater than or equal. - GeU, -} - -impl SymbolicOp { - /// Checks if this operation is commutative. - /// - /// Commutative operations produce the same result regardless of operand order: - /// `a op b == b op a`. This property is useful for expression canonicalization. - /// - /// # Returns - /// - /// `true` if the operation is commutative (Add, Mul, And, Or, Xor, Eq, Ne). - #[must_use] - pub const fn is_commutative(self) -> bool { - matches!( - self, - Self::Add | Self::Mul | Self::And | Self::Or | Self::Xor | Self::Eq | Self::Ne - ) - } - - /// Checks if this operation is a comparison. - /// - /// Comparison operations return 0 or 1 based on the relationship between operands. - /// - /// # Returns - /// - /// `true` if this is a comparison operation (Eq, Ne, Lt*, Gt*, Le*, Ge*). - #[must_use] - pub const fn is_comparison(self) -> bool { - matches!( - self, - Self::Eq - | Self::Ne - | Self::LtS - | Self::LtU - | Self::GtS - | Self::GtU - | Self::LeS - | Self::LeU - | Self::GeS - | Self::GeU - ) - } - - /// Checks if this is a unary operation. - /// - /// Unary operations take a single operand, unlike binary operations which take two. - /// - /// # Returns - /// - /// `true` if this is a unary operation (Neg, Not). - #[must_use] - pub const fn is_unary(self) -> bool { - matches!(self, Self::Neg | Self::Not) - } -} - -impl fmt::Display for SymbolicOp { - #[allow(clippy::match_same_arms)] // Sub and Neg are semantically different (binary vs unary) - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Add => write!(f, "+"), - Self::Sub => write!(f, "-"), - Self::Mul => write!(f, "*"), - Self::DivS => write!(f, "/"), - Self::DivU => write!(f, "/u"), - Self::RemS => write!(f, "%"), - Self::RemU => write!(f, "%u"), - Self::Neg => write!(f, "-"), - Self::And => write!(f, "&"), - Self::Or => write!(f, "|"), - Self::Xor => write!(f, "^"), - Self::Not => write!(f, "~"), - Self::Shl => write!(f, "<<"), - Self::ShrS => write!(f, ">>"), - Self::ShrU => write!(f, ">>>"), - Self::Eq => write!(f, "=="), - Self::Ne => write!(f, "!="), - Self::LtS => write!(f, "<"), - Self::LtU => write!(f, " write!(f, ">"), - Self::GtU => write!(f, ">u"), - Self::LeS => write!(f, "<="), - Self::LeU => write!(f, "<=u"), - Self::GeS => write!(f, ">="), - Self::GeU => write!(f, ">=u"), - } - } -} diff --git a/dotscope/src/analysis/ssa/symbolic/solver.rs b/dotscope/src/analysis/ssa/symbolic/solver.rs index e07907c3..c25dd270 100644 --- a/dotscope/src/analysis/ssa/symbolic/solver.rs +++ b/dotscope/src/analysis/ssa/symbolic/solver.rs @@ -6,10 +6,8 @@ use std::collections::HashMap; -use crate::{ - analysis::ssa::symbolic::{expr::SymbolicExpr, ops::SymbolicOp}, - metadata::typesystem::PointerSize, -}; +use crate::metadata::typesystem::PointerSize; +use analyssa::analysis::symbolic::{expr::SymbolicExpr, ops::SymbolicOp}; /// Z3-based constraint solver for symbolic expressions. /// @@ -521,10 +519,8 @@ impl Z3Solver { #[cfg(test)] mod tests { - use crate::{ - analysis::ssa::symbolic::{expr::SymbolicExpr, ops::SymbolicOp, solver::Z3Solver}, - metadata::typesystem::PointerSize, - }; + use crate::{analysis::ssa::symbolic::Z3Solver, metadata::typesystem::PointerSize}; + use analyssa::analysis::symbolic::{expr::SymbolicExpr, ops::SymbolicOp}; #[test] fn test_z3_simple_solve() { diff --git a/dotscope/src/analysis/ssa/target.rs b/dotscope/src/analysis/ssa/target.rs new file mode 100644 index 00000000..60e02339 --- /dev/null +++ b/dotscope/src/analysis/ssa/target.rs @@ -0,0 +1,559 @@ +//! `CilTarget` — the .NET CIL host's concrete impl of `analyssa::Target`. +//! +//! The trait + `MockTarget` live in `analyssa::target`; this file plugs CIL +//! semantics into them and is the only place in dotscope that knows the +//! mapping from `Target` associated types to dotscope's metadata types. +//! +//! Conversion helpers (`cil_convert_const`, `cil_convert_const_checked`, +//! `cil_evaluator_apply_conversion`) live here too so the `Target` impl can +//! delegate without forming an `impl ConstValue` ↔ `impl Target for +//! CilTarget` cycle. + +use analyssa::{ir::value::ConstValue, PointerSize}; + +use crate::{ + analysis::ssa::types::{FieldRef, MethodRef, SigRef, SsaType, TypeRef}, + assembly::{FlowType, Instruction, InstructionCategory, Operand, StackBehavior}, + compiler::CilCapability, + metadata::{method::ExceptionHandlerFlags, signatures::SignatureLocalVariable}, +}; + +// Re-export so existing `crate::analysis::ssa::target::Target` import paths +// in the rest of dotscope continue to resolve. The trait itself lives in +// `analyssa::target`. +pub use analyssa::target::Target; + +/// `Target` impl for .NET CIL. +/// +/// Instances carry the pointer width chosen at construction (4 for 32-bit +/// hosts, 8 for 64-bit). The associated types alias the existing dotscope +/// metadata types so the rest of the crate continues to compile unchanged. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub struct CilTarget { + ptr_bytes: u32, +} + +impl CilTarget { + /// 64-bit CIL target (8-byte pointers). The default for x86_64 hosts. + #[must_use] + pub const fn x64() -> Self { + Self { ptr_bytes: 8 } + } + + /// 32-bit CIL target (4-byte pointers). + #[must_use] + pub const fn x86() -> Self { + Self { ptr_bytes: 4 } + } + + /// Construct a `CilTarget` with an explicit pointer width. + /// + /// `ptr_bytes` must be 4 or 8. Other values are accepted but undefined for + /// pointer-sized type inference downstream. + #[must_use] + pub const fn with_ptr_bytes(ptr_bytes: u32) -> Self { + Self { ptr_bytes } + } +} + +impl Default for CilTarget { + fn default() -> Self { + Self::x64() + } +} + +impl Target for CilTarget { + type TypeRef = TypeRef; + type MethodRef = MethodRef; + type FieldRef = FieldRef; + type SigRef = SigRef; + type ExceptionKind = ExceptionHandlerFlags; + type Type = SsaType; + type OriginalInstruction = Instruction; + type LocalSignature = SignatureLocalVariable; + type Capability = CilCapability; + + fn ptr_bytes(&self) -> u32 { + self.ptr_bytes + } + + fn synthetic_instruction() -> Self::OriginalInstruction { + Instruction { + rva: 0, + offset: 0, + size: 0, + opcode: 0, + prefix: 0, + mnemonic: "synthetic", + category: InstructionCategory::Misc, + flow_type: FlowType::Sequential, + operand: Operand::None, + stack_behavior: StackBehavior { + pops: 0, + pushes: 0, + net_effect: 0, + }, + branch_targets: vec![], + } + } + + fn unknown_type() -> Self::Type { + SsaType::Unknown + } + + fn is_integer(t: &Self::Type) -> bool { + t.is_integer() + } + + fn is_floating(t: &Self::Type) -> bool { + t.is_float() + } + + fn is_signed(t: &Self::Type) -> bool { + matches!( + t, + SsaType::I8 + | SsaType::I16 + | SsaType::I32 + | SsaType::I64 + | SsaType::NativeInt + | SsaType::F32 + | SsaType::F64 + ) + } + + fn is_pointer(t: &Self::Type) -> bool { + t.is_pointer() + } + + fn is_reference(t: &Self::Type) -> bool { + t.is_reference() + } + + fn is_unknown(t: &Self::Type) -> bool { + t.is_unknown() + } + + fn bit_width(t: &Self::Type) -> Option { + t.size_bytes().map(|b| b.saturating_mul(8)) + } + + fn instruction_mnemonic(instr: &Self::OriginalInstruction) -> &'static str { + instr.mnemonic + } + + fn instruction_rva(instr: &Self::OriginalInstruction) -> u64 { + instr.rva + } + + fn is_filter_handler(flags: &Self::ExceptionKind) -> bool { + *flags == ExceptionHandlerFlags::FILTER + } + + fn result_type_for_const(value: &ConstValue) -> Option { + Some(match value { + ConstValue::I8(_) => SsaType::I8, + ConstValue::I16(_) => SsaType::I16, + ConstValue::I32(_) => SsaType::I32, + ConstValue::I64(_) => SsaType::I64, + ConstValue::U8(_) => SsaType::U8, + ConstValue::U16(_) => SsaType::U16, + ConstValue::U32(_) => SsaType::U32, + ConstValue::U64(_) => SsaType::U64, + ConstValue::NativeInt(_) => SsaType::NativeInt, + ConstValue::NativeUInt(_) => SsaType::NativeUInt, + ConstValue::F32(_) => SsaType::F32, + ConstValue::F64(_) => SsaType::F64, + ConstValue::String(_) | ConstValue::DecryptedString(_) => SsaType::String, + ConstValue::DecryptedArray { .. } => SsaType::Object, + ConstValue::Null => SsaType::Null, + ConstValue::True | ConstValue::False => SsaType::Bool, + ConstValue::Type(_) | ConstValue::MethodHandle(_) | ConstValue::FieldHandle(_) => { + SsaType::Object + } + }) + } + + fn comparison_result_type() -> Option { + Some(SsaType::Bool) + } + + fn arithmetic_result_type() -> Option { + Some(SsaType::I32) + } + + fn native_int_result_type() -> Option { + Some(SsaType::NativeInt) + } + + fn ckfinite_result_type() -> Option { + Some(SsaType::F64) + } + + fn function_ptr_result_type() -> Option { + Some(SsaType::NativeInt) + } + + fn object_result_type() -> Option { + Some(SsaType::Object) + } + + fn value_type_from_ref(r: &Self::TypeRef) -> Option { + Some(SsaType::ValueType(*r)) + } + + fn byref_value_type_from_ref(r: &Self::TypeRef) -> Option { + Some(SsaType::ByRef(Box::new(SsaType::ValueType(*r)))) + } + + fn byref_class_type_from_ref(r: &Self::TypeRef) -> Option { + Some(SsaType::ByRef(Box::new(SsaType::Class(*r)))) + } + + fn convert_const( + value: &ConstValue, + target_type: &Self::Type, + unsigned_source: bool, + ptr_bytes: u32, + ) -> Option> { + let ptr_size = if ptr_bytes == 4 { + PointerSize::Bit32 + } else { + PointerSize::Bit64 + }; + cil_convert_const(value, target_type, unsigned_source, ptr_size) + } + + fn convert_const_checked( + value: &ConstValue, + target_type: &Self::Type, + unsigned_source: bool, + ptr_bytes: u32, + ) -> Option> { + let ptr_size = if ptr_bytes == 4 { + PointerSize::Bit32 + } else { + PointerSize::Bit64 + }; + cil_convert_const_checked(value, target_type, unsigned_source, ptr_size) + } + + fn evaluate_int_conv( + value: i64, + target: &Self::Type, + unsigned: bool, + ptr_bytes: u32, + ) -> Option> { + let ptr_size = if ptr_bytes == 4 { + PointerSize::Bit32 + } else { + PointerSize::Bit64 + }; + Some(cil_evaluator_apply_conversion( + value, target, unsigned, ptr_size, + )) + } +} + +fn cil_convert_const( + value: &ConstValue, + target: &SsaType, + unsigned_source: bool, + ptr_size: PointerSize, +) -> Option> { + let (signed_val, unsigned_val) = if unsigned_source { + let u = value.as_u64()?; + (i64::from_ne_bytes(u.to_ne_bytes()), u) + } else { + let s = value.as_i64()?; + (s, u64::from_ne_bytes(s.to_ne_bytes())) + }; + + #[allow(clippy::cast_possible_truncation)] + let converted = match target { + SsaType::I8 => ConstValue::I8(signed_val as i8), + SsaType::U8 => ConstValue::U8(unsigned_val as u8), + SsaType::I16 => ConstValue::I16(signed_val as i16), + SsaType::U16 | SsaType::Char => ConstValue::U16(unsigned_val as u16), + SsaType::I32 => ConstValue::I32(signed_val as i32), + SsaType::U32 => ConstValue::U32(unsigned_val as u32), + SsaType::I64 => ConstValue::I64(signed_val), + SsaType::U64 => ConstValue::U64(unsigned_val), + SsaType::NativeInt => ConstValue::NativeInt(signed_val), + SsaType::NativeUInt => ConstValue::NativeUInt(unsigned_val), + SsaType::F32 => + { + #[allow(clippy::cast_precision_loss)] + if unsigned_source { + ConstValue::F32(unsigned_val as f32) + } else { + ConstValue::F32(signed_val as f32) + } + } + SsaType::F64 => + { + #[allow(clippy::cast_precision_loss)] + if unsigned_source { + ConstValue::F64(unsigned_val as f64) + } else { + ConstValue::F64(signed_val as f64) + } + } + SsaType::Bool => ConstValue::from_bool(signed_val != 0), + _ => return None, + }; + Some(converted.mask_native(ptr_size)) +} + +fn cil_convert_const_checked( + value: &ConstValue, + target: &SsaType, + unsigned_source: bool, + ptr_size: PointerSize, +) -> Option> { + let (signed_val, unsigned_val) = if unsigned_source { + let u = value.as_u64()?; + (i64::from_ne_bytes(u.to_ne_bytes()), u) + } else { + let s = value.as_i64()?; + (s, u64::from_ne_bytes(s.to_ne_bytes())) + }; + + let fits = match target { + SsaType::I8 => i8::try_from(signed_val).is_ok(), + SsaType::U8 => u8::try_from(unsigned_val).is_ok() && signed_val >= 0, + SsaType::I16 => i16::try_from(signed_val).is_ok(), + SsaType::U16 => u16::try_from(unsigned_val).is_ok() && signed_val >= 0, + SsaType::I32 => i32::try_from(signed_val).is_ok(), + SsaType::U32 => u32::try_from(unsigned_val).is_ok() && signed_val >= 0, + SsaType::I64 + | SsaType::NativeInt + | SsaType::Bool + | SsaType::Char + | SsaType::F32 + | SsaType::F64 => true, + SsaType::U64 | SsaType::NativeUInt => signed_val >= 0, + _ => return None, + }; + + if !fits { + return None; + } + cil_convert_const(value, target, unsigned_source, ptr_size) +} + +/// CIL-side `evaluate_int_conv` body. Mirrors the legacy +/// `SsaEvaluator::apply_conversion` semantics (raw `as`-casts; `Bool` truncates +/// the low byte rather than booleanizing) so generifying the evaluator does +/// not change behavior. +#[allow( + clippy::cast_possible_truncation, + clippy::cast_sign_loss, + clippy::cast_possible_wrap, + clippy::cast_precision_loss +)] +fn cil_evaluator_apply_conversion( + value: i64, + target: &SsaType, + unsigned: bool, + ptr_size: PointerSize, +) -> ConstValue { + match target { + SsaType::I8 => { + if unsigned { + ConstValue::I8((value as u8) as i8) + } else { + ConstValue::I8(value as i8) + } + } + SsaType::U8 | SsaType::Bool => ConstValue::U8(value as u8), + SsaType::I16 => { + if unsigned { + ConstValue::I16((value as u16) as i16) + } else { + ConstValue::I16(value as i16) + } + } + SsaType::U16 => ConstValue::U16(value as u16), + SsaType::I32 => { + if unsigned { + ConstValue::I32((value as u32) as i32) + } else { + ConstValue::I32(value as i32) + } + } + SsaType::U32 => ConstValue::U32(value as u32), + SsaType::NativeInt => match ptr_size { + PointerSize::Bit32 => { + if unsigned { + ConstValue::NativeInt(i64::from((value as u32) as i32)) + } else { + ConstValue::NativeInt(i64::from(value as i32)) + } + } + PointerSize::Bit64 => ConstValue::NativeInt(value), + PointerSize::Bit8 | PointerSize::Bit16 | PointerSize::Bit128 => { + ConstValue::NativeInt(value) + } + }, + SsaType::NativeUInt => match ptr_size { + PointerSize::Bit32 => ConstValue::NativeUInt(u64::from(value as u32)), + PointerSize::Bit64 => ConstValue::NativeUInt(value as u64), + PointerSize::Bit8 | PointerSize::Bit16 | PointerSize::Bit128 => { + ConstValue::NativeUInt(value as u64) + } + }, + SsaType::U64 => ConstValue::U64(value as u64), + SsaType::F32 => { + let float_val = if unsigned { + (value as u64) as f32 + } else { + value as f32 + }; + ConstValue::F32(float_val) + } + SsaType::F64 => { + let float_val = if unsigned { + (value as u64) as f64 + } else { + value as f64 + }; + ConstValue::F64(float_val) + } + _ => ConstValue::I64(value), + } +} + +#[cfg(test)] +mod tests { + use super::*; + + use analyssa::{MockTarget, MockType}; + + use crate::analysis::ssa::{ + value::ConstValue, DefSite, SsaBlock, SsaFunction, SsaInstruction, SsaOp, SsaVarId, + VariableOrigin, + }; + + #[test] + fn cil_target_ptr_bytes() { + assert_eq!(CilTarget::x64().ptr_bytes(), 8); + assert_eq!(CilTarget::x86().ptr_bytes(), 4); + assert_eq!(CilTarget::with_ptr_bytes(8).ptr_bytes(), 8); + assert_eq!(CilTarget::default().ptr_bytes(), 8); + } + + #[test] + fn cil_target_type_queries() { + assert!(CilTarget::is_integer(&SsaType::I32)); + assert!(CilTarget::is_integer(&SsaType::U64)); + assert!(!CilTarget::is_integer(&SsaType::F32)); + + assert!(CilTarget::is_floating(&SsaType::F32)); + assert!(CilTarget::is_floating(&SsaType::F64)); + assert!(!CilTarget::is_floating(&SsaType::I32)); + + assert!(CilTarget::is_signed(&SsaType::I32)); + assert!(!CilTarget::is_signed(&SsaType::U32)); + + assert!(CilTarget::is_pointer(&SsaType::Pointer(Box::new( + SsaType::I32 + )))); + assert!(CilTarget::is_pointer(&SsaType::ByRef(Box::new( + SsaType::I32 + )))); + assert!(!CilTarget::is_pointer(&SsaType::I32)); + + assert!(CilTarget::is_reference(&SsaType::Object)); + assert!(CilTarget::is_reference(&SsaType::String)); + assert!(!CilTarget::is_reference(&SsaType::I32)); + + assert!(CilTarget::is_unknown(&SsaType::Unknown)); + assert!(CilTarget::is_unknown(&SsaType::Varying)); + assert!(!CilTarget::is_unknown(&SsaType::I32)); + + assert_eq!(CilTarget::bit_width(&SsaType::I32), Some(32)); + assert_eq!(CilTarget::bit_width(&SsaType::I64), Some(64)); + assert_eq!(CilTarget::bit_width(&SsaType::Bool), Some(8)); + assert_eq!(CilTarget::bit_width(&SsaType::NativeInt), None); + } + + #[test] + fn cil_target_unknown() { + assert_eq!(CilTarget::unknown_type(), SsaType::Unknown); + } + + #[test] + fn cil_target_synthetic_instruction() { + let i = CilTarget::synthetic_instruction(); + assert_eq!(i.mnemonic, "synthetic"); + assert_eq!(i.size, 0); + } + + /// End-to-end IR-core smoke test using `MockTarget`. + /// + /// Constructs an `SsaFunction` end-to-end using the generic + /// IR API. Lives in dotscope (rather than analyssa) until `SsaFunction` + /// itself moves to analyssa. Once that happens, this test should migrate + /// to a analyssa-side integration test. + #[test] + fn mock_target_builds_generic_ir() { + // 1. Empty function with the mock target. + let mut func = SsaFunction::::new(1, 1); + assert_eq!(func.num_args(), 1); + assert_eq!(func.num_locals(), 1); + assert!(func.is_empty()); + + // 2. Allocate a few variables. + let arg0 = func.create_variable( + VariableOrigin::Argument(0), + 0, + DefSite::phi(0), + MockType::I32, + ); + let const_var = func.create_variable( + VariableOrigin::Local(0), + 0, + DefSite::instruction(0, 0), + MockType::I32, + ); + let sum = func.create_variable( + VariableOrigin::Local(1), + 0, + DefSite::instruction(0, 1), + MockType::I32, + ); + assert_eq!(func.variable_count(), 3); + assert_eq!(func.variable(arg0).unwrap().var_type(), &MockType::I32); + + // 3. Build a block: const = 42, sum = arg0 + const, return sum. + let mut block: SsaBlock = SsaBlock::new(0); + block.add_instruction(SsaInstruction::synthetic(SsaOp::Const { + dest: const_var, + value: ConstValue::I32(42), + })); + block.add_instruction(SsaInstruction::synthetic(SsaOp::Add { + dest: sum, + left: arg0, + right: const_var, + flags: None, + })); + block.add_instruction(SsaInstruction::synthetic(SsaOp::Return { + value: Some(sum), + })); + assert_eq!(block.instruction_count(), 3); + func.add_block(block); + + // 4. Iterate over the function's instructions. + let instr_dests: Vec> = func + .iter_instructions() + .map(|(_, _, instr)| instr.def()) + .collect(); + assert_eq!(instr_dests, vec![Some(const_var), Some(sum), None]); + + // 5. Confirm the IR carries `T::Type` values opaquely. + for var in func.variables() { + assert!(MockTarget::is_integer(var.var_type())); + } + } +} diff --git a/dotscope/src/analysis/ssa/types.rs b/dotscope/src/analysis/ssa/types.rs index ede4d449..c57a2b70 100644 --- a/dotscope/src/analysis/ssa/types.rs +++ b/dotscope/src/analysis/ssa/types.rs @@ -77,6 +77,12 @@ impl MethodRef { } } +impl From for MethodRef { + fn from(token: Token) -> Self { + Self(token) + } +} + impl fmt::Display for MethodRef { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { write!(f, "MethodRef({})", self.0) @@ -152,7 +158,6 @@ impl fmt::Display for SigRef { /// ``` #[derive(Debug, Clone, PartialEq, Eq, Hash)] pub enum SsaType { - // ========== Primitives ========== /// No return value (void). Void, @@ -198,7 +203,6 @@ pub enum SsaType { /// Unicode character (System.Char). Char, - // ========== Reference Types ========== /// System.Object reference. Object, @@ -227,7 +231,6 @@ pub enum SsaType { /// Managed reference (byref) to a type. ByRef(Box), - // ========== Special Types ========== /// Typed reference (System.TypedReference). TypedReference, @@ -244,7 +247,6 @@ pub enum SsaType { /// Function pointer type. FnPtr(Box), - // ========== Analysis Types ========== /// Known null constant (more precise than Object). Null, diff --git a/dotscope/src/analysis/ssa/value.rs b/dotscope/src/analysis/ssa/value.rs index 93bc5c13..173c6a05 100644 --- a/dotscope/src/analysis/ssa/value.rs +++ b/dotscope/src/analysis/ssa/value.rs @@ -1,1702 +1,129 @@ -//! Value tracking for SSA variables. +//! Re-export shim + CIL-specific extension surface for `ConstValue`. //! -//! This module provides abstract value representation for SSA variables, -//! enabling constant propagation, value numbering, and range analysis. +//! The generic `ConstValue`, `AbstractValue`, `ComputedValue`, and +//! `ComputedOp` types live in `analyssa::ir::value`. This module: //! -//! # Lattice Structure -//! -//! The `AbstractValue` type forms a lattice for dataflow analysis: -//! -//! ```text -//! Top (no information) -//! | -//! +---------+---------+ -//! | | | -//! Const Range NonNull -//! | | | -//! +---------+---------+ -//! | -//! Bottom (conflicting info) -//! ``` -//! -//! - `Top`: No information known yet (initial state) -//! - `Constant`: Known compile-time constant -//! - `Range`: Value in a bounded range -//! - `NonNull`: Known to be non-null (for references) -//! - `Bottom`: Multiple conflicting values (cannot be constant) +//! - Re-exports them with type aliases so callers writing `ConstValue` (no +//! `T` parameter) keep getting `ConstValue` for back-compat. +//! - Adds the CIL-specific extension trait [`ConstValueCilExt`] for +//! `ssa_type` / `as_string_content`, which need `SsaType` / `CilObject` +//! (foreign types from analyssa's perspective). +//! - Adds `TryFrom<&ConstValue> for Immediate` for instruction +//! encoding (orphan-rule allowed: `Immediate` is local). -use std::fmt; +use analyssa::ir::value::ConstValue as AnalyssaConstValue; use crate::{ - analysis::ssa::{ - types::{FieldRef, MethodRef, TypeRef}, - SsaType, SsaVarId, - }, + analysis::ssa::{target::CilTarget, types::SsaType}, assembly::Immediate, - metadata::{token::Token, typesystem::PointerSize}, CilObject, Error, }; -/// Constant values that can appear in SSA form. -/// -/// These represent compile-time constants that can be tracked through -/// the SSA graph for constant propagation and folding. -#[derive(Debug, Clone, PartialEq)] -pub enum ConstValue { - /// 8-bit signed integer. - I8(i8), - - /// 16-bit signed integer. - I16(i16), - - /// 32-bit signed integer. - I32(i32), - - /// 64-bit signed integer. - I64(i64), - - /// 8-bit unsigned integer. - U8(u8), - - /// 16-bit unsigned integer. - U16(u16), - - /// 32-bit unsigned integer. - U32(u32), - - /// 64-bit unsigned integer. - U64(u64), - - /// Native integer (pointer-sized). - NativeInt(i64), +// Type aliases preserve the `T = CilTarget` default so existing callers +// writing `ConstValue` (no params) compile. `ComputedValue`/`ComputedOp` +// aren't re-exported because the original dotscope `value.rs` didn't +// re-export them either — direct callers use `analyssa::ir::value::ComputedOp`. - /// Native unsigned integer (pointer-sized). - NativeUInt(u64), +/// CIL-defaulted alias of `analyssa::ir::value::ConstValue`. +pub type ConstValue = AnalyssaConstValue; - /// 32-bit floating point. - F32(f32), - - /// 64-bit floating point. - F64(f64), - - /// String constant (index into #US heap). - String(u32), - - /// Decrypted string value (actual string content, not a heap index). - /// Used by deobfuscation passes to store strings that were decrypted at analysis time. - DecryptedString(String), - - /// Null reference. - Null, - - /// Boolean true. - True, - - /// Boolean false. - False, - - /// Runtime type handle (typeof result). - Type(TypeRef), - - /// Runtime method handle. - MethodHandle(MethodRef), - - /// Runtime field handle. - FieldHandle(FieldRef), - - /// Decrypted array data (raw bytes and element type token). - /// Used by deobfuscation passes to store arrays that were decrypted at analysis time. - /// The codegen emits `newarr` + element stores to reconstruct the array. - DecryptedArray { - /// Raw bytes of the array data in little-endian layout. - data: Vec, - /// Metadata token of the element type (TypeRef/TypeDef from the assembly). - element_type_token: Token, - /// Size of each element in bytes (1 for byte, 4 for int, etc.). - element_size: usize, - }, -} - -impl ConstValue { - /// Returns `true` if this is the null constant. - #[must_use] - pub const fn is_null(&self) -> bool { - matches!(self, Self::Null) - } - - /// Returns `true` if this is a boolean constant. - #[must_use] - pub const fn is_bool(&self) -> bool { - matches!(self, Self::True | Self::False) - } - - /// Returns `true` if this is an integer constant (signed or unsigned). - #[must_use] - pub const fn is_integer(&self) -> bool { - matches!( - self, - Self::I8(_) - | Self::I16(_) - | Self::I32(_) - | Self::I64(_) - | Self::U8(_) - | Self::U16(_) - | Self::U32(_) - | Self::U64(_) - | Self::NativeInt(_) - | Self::NativeUInt(_) - ) - } - - /// Returns `true` if this is a signed integer constant. - #[must_use] - pub const fn is_signed(&self) -> bool { - matches!( - self, - Self::I8(_) | Self::I16(_) | Self::I32(_) | Self::I64(_) | Self::NativeInt(_) - ) - } - - /// Returns `true` if this is an unsigned integer constant. - #[must_use] - pub const fn is_unsigned(&self) -> bool { - matches!( - self, - Self::U8(_) | Self::U16(_) | Self::U32(_) | Self::U64(_) | Self::NativeUInt(_) - ) - } - - /// Returns `true` if this is a floating-point constant. - #[must_use] - pub const fn is_float(&self) -> bool { - matches!(self, Self::F32(_) | Self::F64(_)) - } +/// CIL-defaulted alias of `analyssa::ir::value::AbstractValue`. +pub type AbstractValue = analyssa::ir::value::AbstractValue; +/// CIL-specific extension methods on `ConstValue`. +/// +/// These can't be inherent impls (orphan rule: `ConstValue` is in analyssa) so +/// they're a trait. Import this trait to call `value.ssa_type()` and +/// `value.as_string_content(&assembly)` as before. +pub trait ConstValueCilExt { /// Returns the SSA type corresponding to this constant value. - #[must_use] - pub const fn ssa_type(&self) -> SsaType { - match self { - Self::I8(_) => SsaType::I8, - Self::I16(_) => SsaType::I16, - Self::I32(_) => SsaType::I32, - Self::I64(_) => SsaType::I64, - Self::U8(_) => SsaType::U8, - Self::U16(_) => SsaType::U16, - Self::U32(_) => SsaType::U32, - Self::U64(_) => SsaType::U64, - Self::F32(_) => SsaType::F32, - Self::F64(_) => SsaType::F64, - Self::NativeInt(_) | Self::Type(_) | Self::MethodHandle(_) | Self::FieldHandle(_) => { - SsaType::NativeInt - } - Self::NativeUInt(_) => SsaType::NativeUInt, - Self::True | Self::False => SsaType::Bool, - Self::Null - | Self::String(_) - | Self::DecryptedString(_) - | Self::DecryptedArray { .. } => SsaType::Object, - } - } - - /// Returns `true` if this value is a string (`String` or `DecryptedString`). - #[must_use] - pub const fn is_string_like(&self) -> bool { - matches!(self, Self::String(_) | Self::DecryptedString(_)) - } - - /// Returns the constant as an i32 if applicable. - #[must_use] - pub const fn as_i32(&self) -> Option { - match self { - Self::I8(v) => Some(*v as i32), - Self::I16(v) => Some(*v as i32), - Self::I32(v) => Some(*v), - Self::U8(v) => Some(*v as i32), - Self::U16(v) => Some(*v as i32), - Self::True => Some(1), - Self::False => Some(0), - _ => None, - } - } - - /// Returns the constant as an i64 if applicable. - #[must_use] - #[allow(clippy::match_same_arms)] // NativeInt is semantically different from I64 - pub const fn as_i64(&self) -> Option { - match self { - Self::I8(v) => Some(*v as i64), - Self::I16(v) => Some(*v as i64), - Self::I32(v) => Some(*v as i64), - Self::I64(v) => Some(*v), - Self::U8(v) => Some(*v as i64), - Self::U16(v) => Some(*v as i64), - Self::U32(v) => Some(*v as i64), - Self::NativeInt(v) => Some(*v), - Self::True => Some(1), - Self::False => Some(0), - _ => None, - } - } - - /// Returns the constant as a u64 if applicable (for unsigned operations). - #[must_use] - #[allow(clippy::cast_sign_loss)] // Guarded by >= 0 checks - #[allow(clippy::match_same_arms)] // NativeUInt is semantically different from U64 - pub const fn as_u64(&self) -> Option { - match self { - Self::U8(v) => Some(*v as u64), - Self::U16(v) => Some(*v as u64), - Self::U32(v) => Some(*v as u64), - Self::U64(v) => Some(*v), - Self::NativeUInt(v) => Some(*v), - Self::I8(v) if *v >= 0 => Some(*v as u64), - Self::I16(v) if *v >= 0 => Some(*v as u64), - Self::I32(v) if *v >= 0 => Some(*v as u64), - Self::I64(v) if *v >= 0 => Some(*v as u64), - Self::True => Some(1), - Self::False => Some(0), - _ => None, - } - } - - /// Returns the constant as a u32 if applicable (for unsigned operations). - #[must_use] - #[allow(clippy::cast_sign_loss)] // Guarded by >= 0 checks - pub const fn as_u32(&self) -> Option { - match self { - Self::U8(v) => Some(*v as u32), - Self::U16(v) => Some(*v as u32), - Self::U32(v) => Some(*v), - Self::I8(v) if *v >= 0 => Some(*v as u32), - Self::I16(v) if *v >= 0 => Some(*v as u32), - Self::I32(v) if *v >= 0 => Some(*v as u32), - Self::True => Some(1), - Self::False => Some(0), - _ => None, - } - } - - /// Returns the constant as an f32 if it's stored as F32. - #[must_use] - pub const fn as_f32(&self) -> Option { - match self { - Self::F32(v) => Some(*v), - _ => None, - } - } - - /// Returns the constant as an f64 if it's stored as F64. - #[must_use] - pub const fn as_f64(&self) -> Option { - match self { - Self::F64(v) => Some(*v), - _ => None, - } - } - - /// Returns the constant as a bool if applicable. - #[must_use] - pub const fn as_bool(&self) -> Option { - match self { - Self::False - | Self::Null - | Self::I8(0) - | Self::I16(0) - | Self::I32(0) - | Self::I64(0) - | Self::U8(0) - | Self::U16(0) - | Self::U32(0) - | Self::U64(0) => Some(false), - Self::True - | Self::I8(_) - | Self::I16(_) - | Self::I32(_) - | Self::I64(_) - | Self::U8(_) - | Self::U16(_) - | Self::U32(_) - | Self::U64(_) => Some(true), - _ => None, - } - } - - /// Creates a boolean constant from a bool value. - #[must_use] - pub const fn from_bool(value: bool) -> Self { - if value { - Self::True - } else { - Self::False - } - } - - /// Returns the string content if this is a `DecryptedString`. - /// - /// For `String` variants (heap references), use [`as_string_content`](Self::as_string_content) - /// which resolves the heap index via the assembly. - #[must_use] - pub fn as_decrypted_string(&self) -> Option<&str> { - match self { - Self::DecryptedString(s) => Some(s.as_str()), - _ => None, - } - } + fn ssa_type(&self) -> SsaType; /// Returns string content, resolving `#US` heap indices via the assembly. /// - /// Returns `Some` for `DecryptedString` (directly) and `String` (via heap lookup). - /// Returns `None` for all other variants. - #[must_use] - pub fn as_string_content(&self, assembly: &CilObject) -> Option { + /// Returns `Some` for `DecryptedString` (directly) and `String` (via heap + /// lookup). Returns `None` for all other variants. + fn as_string_content(&self, assembly: &CilObject) -> Option; +} + +impl ConstValueCilExt for AnalyssaConstValue { + fn ssa_type(&self) -> SsaType { + match self { + AnalyssaConstValue::I8(_) => SsaType::I8, + AnalyssaConstValue::I16(_) => SsaType::I16, + AnalyssaConstValue::I32(_) => SsaType::I32, + AnalyssaConstValue::I64(_) => SsaType::I64, + AnalyssaConstValue::U8(_) => SsaType::U8, + AnalyssaConstValue::U16(_) => SsaType::U16, + AnalyssaConstValue::U32(_) => SsaType::U32, + AnalyssaConstValue::U64(_) => SsaType::U64, + AnalyssaConstValue::F32(_) => SsaType::F32, + AnalyssaConstValue::F64(_) => SsaType::F64, + AnalyssaConstValue::NativeInt(_) + | AnalyssaConstValue::Type(_) + | AnalyssaConstValue::MethodHandle(_) + | AnalyssaConstValue::FieldHandle(_) => SsaType::NativeInt, + AnalyssaConstValue::NativeUInt(_) => SsaType::NativeUInt, + AnalyssaConstValue::True | AnalyssaConstValue::False => SsaType::Bool, + AnalyssaConstValue::Null + | AnalyssaConstValue::String(_) + | AnalyssaConstValue::DecryptedString(_) + | AnalyssaConstValue::DecryptedArray { .. } => SsaType::Object, + } + } + + fn as_string_content(&self, assembly: &CilObject) -> Option { match self { - Self::DecryptedString(s) => Some(s.clone()), - Self::String(idx) => assembly + AnalyssaConstValue::DecryptedString(s) => Some(s.clone()), + AnalyssaConstValue::String(idx) => assembly .userstrings() .and_then(|us| us.get(*idx as usize).ok()) .map(|s| s.to_string_lossy()), _ => None, } } - - /// Returns `true` if this constant represents zero. - /// - /// This includes all numeric zero values and `False`. - /// Useful for opaque predicate detection where `x ^ x`, `x - x`, `x * 0`, etc. produce zero. - #[must_use] - pub const fn is_zero(&self) -> bool { - matches!( - self, - Self::I8(0) - | Self::I16(0) - | Self::I32(0) - | Self::I64(0) - | Self::U8(0) - | Self::U16(0) - | Self::U32(0) - | Self::U64(0) - | Self::NativeInt(0) - | Self::NativeUInt(0) - | Self::False - ) - } - - /// Returns `true` if this constant represents one. - /// - /// This includes all numeric one values and `True`. - /// Useful for identity operations and opaque predicate detection. - #[must_use] - pub const fn is_one(&self) -> bool { - matches!( - self, - Self::I8(1) - | Self::I16(1) - | Self::I32(1) - | Self::I64(1) - | Self::U8(1) - | Self::U16(1) - | Self::U32(1) - | Self::U64(1) - | Self::NativeInt(1) - | Self::NativeUInt(1) - | Self::True - ) - } - - /// Returns `true` if this constant represents negative one (-1). - /// - /// This is useful for detecting `x | -1 = -1` patterns in opaque predicates. - #[must_use] - pub const fn is_minus_one(&self) -> bool { - matches!( - self, - Self::I8(-1) | Self::I16(-1) | Self::I32(-1) | Self::I64(-1) | Self::NativeInt(-1) - ) - } - - /// Returns `true` if this constant has all bits set (e.g., -1 for signed, MAX for unsigned). - /// - /// This is useful for detecting `x & -1 = x` and `x | -1 = -1` patterns. - #[must_use] - pub const fn is_all_ones(&self) -> bool { - matches!( - self, - Self::I8(-1) - | Self::I16(-1) - | Self::I32(-1) - | Self::I64(-1) - | Self::NativeInt(-1) - | Self::U8(u8::MAX) - | Self::U16(u16::MAX) - | Self::U32(u32::MAX) - | Self::U64(u64::MAX) - | Self::NativeUInt(u64::MAX) - ) - } - - /// Returns a zero constant of the same type as this constant. - /// - /// Useful for algebraic simplifications like `x * 0 = 0` where the result - /// should preserve the type of the operands. - #[must_use] - pub const fn zero_of_same_type(&self) -> Self { - match self { - Self::I8(_) => Self::I8(0), - Self::I16(_) => Self::I16(0), - Self::I64(_) => Self::I64(0), - Self::U8(_) => Self::U8(0), - Self::U16(_) => Self::U16(0), - Self::U32(_) => Self::U32(0), - Self::U64(_) => Self::U64(0), - Self::NativeInt(_) => Self::NativeInt(0), - Self::NativeUInt(_) => Self::NativeUInt(0), - Self::F32(_) => Self::F32(0.0), - Self::F64(_) => Self::F64(0.0), - // For non-numeric types (including I32), default to i32 - _ => Self::I32(0), - } - } - - /// Attempts to negate this constant. - #[must_use] - pub fn negate(&self, ptr_size: PointerSize) -> Option { - match self { - Self::I8(v) => Some(Self::I8(v.wrapping_neg())), - Self::I16(v) => Some(Self::I16(v.wrapping_neg())), - Self::I32(v) => Some(Self::I32(v.wrapping_neg())), - Self::I64(v) => Some(Self::I64(v.wrapping_neg())), - Self::NativeInt(v) => Some(Self::NativeInt(v.wrapping_neg())), - Self::F32(v) => Some(Self::F32(-v)), - Self::F64(v) => Some(Self::F64(-v)), - // Unsigned negation wraps - Self::U8(v) => Some(Self::U8(v.wrapping_neg())), - Self::U16(v) => Some(Self::U16(v.wrapping_neg())), - Self::U32(v) => Some(Self::U32(v.wrapping_neg())), - Self::U64(v) => Some(Self::U64(v.wrapping_neg())), - Self::NativeUInt(v) => Some(Self::NativeUInt(v.wrapping_neg())), - _ => None, - } - .map(|v| v.mask_native(ptr_size)) - } - - /// Attempts to perform bitwise NOT on this constant. - #[must_use] - pub fn bitwise_not(&self, ptr_size: PointerSize) -> Option { - match self { - Self::I8(v) => Some(Self::I8(!v)), - Self::I16(v) => Some(Self::I16(!v)), - Self::I32(v) => Some(Self::I32(!v)), - Self::I64(v) => Some(Self::I64(!v)), - Self::U8(v) => Some(Self::U8(!v)), - Self::U16(v) => Some(Self::U16(!v)), - Self::U32(v) => Some(Self::U32(!v)), - Self::U64(v) => Some(Self::U64(!v)), - Self::NativeInt(v) => Some(Self::NativeInt(!v)), - Self::NativeUInt(v) => Some(Self::NativeUInt(!v)), - _ => None, - } - .map(|v| v.mask_native(ptr_size)) - } - - /// Attempts to perform bitwise AND on two constants. - #[must_use] - pub fn bitwise_and(&self, other: &Self, ptr_size: PointerSize) -> Option { - match (self, other) { - (Self::I8(a), Self::I8(b)) => Some(Self::I8(a & b)), - (Self::I16(a), Self::I16(b)) => Some(Self::I16(a & b)), - (Self::I32(a), Self::I32(b)) => Some(Self::I32(a & b)), - (Self::I64(a), Self::I64(b)) => Some(Self::I64(a & b)), - (Self::U8(a), Self::U8(b)) => Some(Self::U8(a & b)), - (Self::U16(a), Self::U16(b)) => Some(Self::U16(a & b)), - (Self::U32(a), Self::U32(b)) => Some(Self::U32(a & b)), - (Self::U64(a), Self::U64(b)) => Some(Self::U64(a & b)), - (Self::NativeInt(a), Self::NativeInt(b)) => Some(Self::NativeInt(a & b)), - (Self::NativeUInt(a), Self::NativeUInt(b)) => Some(Self::NativeUInt(a & b)), - // Cross-type: promote to i64 for mixed signed operations - (Self::I32(a), Self::I64(b)) | (Self::I64(b), Self::I32(a)) => { - Some(Self::I64(i64::from(*a) & b)) - } - (Self::U32(a), Self::U64(b)) | (Self::U64(b), Self::U32(a)) => { - Some(Self::U64(u64::from(*a) & b)) - } - _ => None, - } - .map(|v| v.mask_native(ptr_size)) - } - - /// Attempts to perform bitwise OR on two constants. - #[must_use] - pub fn bitwise_or(&self, other: &Self, ptr_size: PointerSize) -> Option { - match (self, other) { - (Self::I8(a), Self::I8(b)) => Some(Self::I8(a | b)), - (Self::I16(a), Self::I16(b)) => Some(Self::I16(a | b)), - (Self::I32(a), Self::I32(b)) => Some(Self::I32(a | b)), - (Self::I64(a), Self::I64(b)) => Some(Self::I64(a | b)), - (Self::U8(a), Self::U8(b)) => Some(Self::U8(a | b)), - (Self::U16(a), Self::U16(b)) => Some(Self::U16(a | b)), - (Self::U32(a), Self::U32(b)) => Some(Self::U32(a | b)), - (Self::U64(a), Self::U64(b)) => Some(Self::U64(a | b)), - (Self::NativeInt(a), Self::NativeInt(b)) => Some(Self::NativeInt(a | b)), - (Self::NativeUInt(a), Self::NativeUInt(b)) => Some(Self::NativeUInt(a | b)), - (Self::I32(a), Self::I64(b)) | (Self::I64(b), Self::I32(a)) => { - Some(Self::I64(i64::from(*a) | b)) - } - (Self::U32(a), Self::U64(b)) | (Self::U64(b), Self::U32(a)) => { - Some(Self::U64(u64::from(*a) | b)) - } - _ => None, - } - .map(|v| v.mask_native(ptr_size)) - } - - /// Attempts to perform bitwise XOR on two constants. - #[must_use] - pub fn bitwise_xor(&self, other: &Self, ptr_size: PointerSize) -> Option { - match (self, other) { - (Self::I8(a), Self::I8(b)) => Some(Self::I8(a ^ b)), - (Self::I16(a), Self::I16(b)) => Some(Self::I16(a ^ b)), - (Self::I32(a), Self::I32(b)) => Some(Self::I32(a ^ b)), - (Self::I64(a), Self::I64(b)) => Some(Self::I64(a ^ b)), - (Self::U8(a), Self::U8(b)) => Some(Self::U8(a ^ b)), - (Self::U16(a), Self::U16(b)) => Some(Self::U16(a ^ b)), - (Self::U32(a), Self::U32(b)) => Some(Self::U32(a ^ b)), - (Self::U64(a), Self::U64(b)) => Some(Self::U64(a ^ b)), - (Self::NativeInt(a), Self::NativeInt(b)) => Some(Self::NativeInt(a ^ b)), - (Self::NativeUInt(a), Self::NativeUInt(b)) => Some(Self::NativeUInt(a ^ b)), - (Self::I32(a), Self::I64(b)) | (Self::I64(b), Self::I32(a)) => { - Some(Self::I64(i64::from(*a) ^ b)) - } - (Self::U32(a), Self::U64(b)) | (Self::U64(b), Self::U32(a)) => { - Some(Self::U64(u64::from(*a) ^ b)) - } - _ => None, - } - .map(|v| v.mask_native(ptr_size)) - } - - /// Attempts to shift left. - #[must_use] - #[allow(clippy::cast_sign_loss)] // Shift amounts are non-negative by convention - pub fn shl(&self, amount: &Self, ptr_size: PointerSize) -> Option { - let shift = amount.as_i32()? as u32; - match self { - Self::I8(v) => Some(Self::I8(v.wrapping_shl(shift))), - Self::I16(v) => Some(Self::I16(v.wrapping_shl(shift))), - Self::I32(v) => Some(Self::I32(v.wrapping_shl(shift))), - Self::I64(v) => Some(Self::I64(v.wrapping_shl(shift))), - Self::U8(v) => Some(Self::U8(v.wrapping_shl(shift))), - Self::U16(v) => Some(Self::U16(v.wrapping_shl(shift))), - Self::U32(v) => Some(Self::U32(v.wrapping_shl(shift))), - Self::U64(v) => Some(Self::U64(v.wrapping_shl(shift))), - Self::NativeInt(v) => Some(Self::NativeInt(v.wrapping_shl(shift))), - Self::NativeUInt(v) => Some(Self::NativeUInt(v.wrapping_shl(shift))), - _ => None, - } - .map(|v| v.mask_native(ptr_size)) - } - - /// Attempts to shift right (arithmetic for signed, logical for unsigned). - #[must_use] - #[allow(clippy::cast_sign_loss)] // Shift amounts and unsigned shifts use intentional casts - #[allow(clippy::cast_possible_wrap)] // Wrapping is expected for logical shift operations - pub fn shr(&self, amount: &Self, unsigned: bool, ptr_size: PointerSize) -> Option { - let shift = amount.as_i32()? as u32; - match self { - Self::I8(v) => { - if unsigned { - Some(Self::I8((*v as u8).wrapping_shr(shift) as i8)) - } else { - Some(Self::I8(v.wrapping_shr(shift))) - } - } - Self::I16(v) => { - if unsigned { - Some(Self::I16((*v as u16).wrapping_shr(shift) as i16)) - } else { - Some(Self::I16(v.wrapping_shr(shift))) - } - } - Self::I32(v) => { - if unsigned { - Some(Self::I32((*v as u32).wrapping_shr(shift) as i32)) - } else { - Some(Self::I32(v.wrapping_shr(shift))) - } - } - Self::I64(v) => { - if unsigned { - Some(Self::I64((*v as u64).wrapping_shr(shift) as i64)) - } else { - Some(Self::I64(v.wrapping_shr(shift))) - } - } - Self::U8(v) => Some(Self::U8(v.wrapping_shr(shift))), - Self::U16(v) => Some(Self::U16(v.wrapping_shr(shift))), - Self::U32(v) => Some(Self::U32(v.wrapping_shr(shift))), - Self::U64(v) => Some(Self::U64(v.wrapping_shr(shift))), - Self::NativeInt(v) => { - if unsigned { - Some(Self::NativeInt((*v as u64).wrapping_shr(shift) as i64)) - } else { - Some(Self::NativeInt(v.wrapping_shr(shift))) - } - } - Self::NativeUInt(v) => Some(Self::NativeUInt(v.wrapping_shr(shift))), - _ => None, - } - .map(|v| v.mask_native(ptr_size)) - } - - /// Attempts to add two constants. - #[must_use] - pub fn add(&self, other: &Self, ptr_size: PointerSize) -> Option { - match (self, other) { - (Self::I8(a), Self::I8(b)) => Some(Self::I8(a.wrapping_add(*b))), - (Self::I16(a), Self::I16(b)) => Some(Self::I16(a.wrapping_add(*b))), - (Self::I32(a), Self::I32(b)) => Some(Self::I32(a.wrapping_add(*b))), - (Self::I64(a), Self::I64(b)) => Some(Self::I64(a.wrapping_add(*b))), - (Self::U8(a), Self::U8(b)) => Some(Self::U8(a.wrapping_add(*b))), - (Self::U16(a), Self::U16(b)) => Some(Self::U16(a.wrapping_add(*b))), - (Self::U32(a), Self::U32(b)) => Some(Self::U32(a.wrapping_add(*b))), - (Self::U64(a), Self::U64(b)) => Some(Self::U64(a.wrapping_add(*b))), - (Self::NativeInt(a), Self::NativeInt(b)) => Some(Self::NativeInt(a.wrapping_add(*b))), - (Self::NativeUInt(a), Self::NativeUInt(b)) => { - Some(Self::NativeUInt(a.wrapping_add(*b))) - } - (Self::F32(a), Self::F32(b)) => Some(Self::F32(a + b)), - (Self::F64(a), Self::F64(b)) => Some(Self::F64(a + b)), - // Cross-type promotions - (Self::I32(a), Self::I64(b)) | (Self::I64(b), Self::I32(a)) => { - Some(Self::I64(i64::from(*a).wrapping_add(*b))) - } - (Self::U32(a), Self::U64(b)) | (Self::U64(b), Self::U32(a)) => { - Some(Self::U64(u64::from(*a).wrapping_add(*b))) - } - _ => None, - } - .map(|v| v.mask_native(ptr_size)) - } - - /// Attempts to subtract two constants. - #[must_use] - pub fn sub(&self, other: &Self, ptr_size: PointerSize) -> Option { - match (self, other) { - (Self::I8(a), Self::I8(b)) => Some(Self::I8(a.wrapping_sub(*b))), - (Self::I16(a), Self::I16(b)) => Some(Self::I16(a.wrapping_sub(*b))), - (Self::I32(a), Self::I32(b)) => Some(Self::I32(a.wrapping_sub(*b))), - (Self::I64(a), Self::I64(b)) => Some(Self::I64(a.wrapping_sub(*b))), - (Self::U8(a), Self::U8(b)) => Some(Self::U8(a.wrapping_sub(*b))), - (Self::U16(a), Self::U16(b)) => Some(Self::U16(a.wrapping_sub(*b))), - (Self::U32(a), Self::U32(b)) => Some(Self::U32(a.wrapping_sub(*b))), - (Self::U64(a), Self::U64(b)) => Some(Self::U64(a.wrapping_sub(*b))), - (Self::NativeInt(a), Self::NativeInt(b)) => Some(Self::NativeInt(a.wrapping_sub(*b))), - (Self::NativeUInt(a), Self::NativeUInt(b)) => { - Some(Self::NativeUInt(a.wrapping_sub(*b))) - } - (Self::F32(a), Self::F32(b)) => Some(Self::F32(a - b)), - (Self::F64(a), Self::F64(b)) => Some(Self::F64(a - b)), - (Self::I32(a), Self::I64(b)) => Some(Self::I64(i64::from(*a).wrapping_sub(*b))), - (Self::I64(a), Self::I32(b)) => Some(Self::I64(a.wrapping_sub(i64::from(*b)))), - (Self::U32(a), Self::U64(b)) => Some(Self::U64(u64::from(*a).wrapping_sub(*b))), - (Self::U64(a), Self::U32(b)) => Some(Self::U64(a.wrapping_sub(u64::from(*b)))), - _ => None, - } - .map(|v| v.mask_native(ptr_size)) - } - - /// Attempts to multiply two constants. - #[must_use] - pub fn mul(&self, other: &Self, ptr_size: PointerSize) -> Option { - match (self, other) { - (Self::I8(a), Self::I8(b)) => Some(Self::I8(a.wrapping_mul(*b))), - (Self::I16(a), Self::I16(b)) => Some(Self::I16(a.wrapping_mul(*b))), - (Self::I32(a), Self::I32(b)) => Some(Self::I32(a.wrapping_mul(*b))), - (Self::I64(a), Self::I64(b)) => Some(Self::I64(a.wrapping_mul(*b))), - (Self::U8(a), Self::U8(b)) => Some(Self::U8(a.wrapping_mul(*b))), - (Self::U16(a), Self::U16(b)) => Some(Self::U16(a.wrapping_mul(*b))), - (Self::U32(a), Self::U32(b)) => Some(Self::U32(a.wrapping_mul(*b))), - (Self::U64(a), Self::U64(b)) => Some(Self::U64(a.wrapping_mul(*b))), - (Self::NativeInt(a), Self::NativeInt(b)) => Some(Self::NativeInt(a.wrapping_mul(*b))), - (Self::NativeUInt(a), Self::NativeUInt(b)) => { - Some(Self::NativeUInt(a.wrapping_mul(*b))) - } - (Self::F32(a), Self::F32(b)) => Some(Self::F32(a * b)), - (Self::F64(a), Self::F64(b)) => Some(Self::F64(a * b)), - (Self::I32(a), Self::I64(b)) | (Self::I64(b), Self::I32(a)) => { - Some(Self::I64(i64::from(*a).wrapping_mul(*b))) - } - (Self::U32(a), Self::U64(b)) | (Self::U64(b), Self::U32(a)) => { - Some(Self::U64(u64::from(*a).wrapping_mul(*b))) - } - _ => None, - } - .map(|v| v.mask_native(ptr_size)) - } - - /// Attempts to add two constants with overflow checking. - /// - /// Returns `None` if the addition would overflow. - /// When `unsigned` is true, operands are treated as unsigned for overflow detection. - #[must_use] - pub fn add_checked(&self, other: &Self, unsigned: bool, ptr_size: PointerSize) -> Option { - if unsigned { - // Unsigned overflow check - match (self, other) { - (Self::I32(a), Self::I32(b)) => (*a) - .cast_unsigned() - .checked_add((*b).cast_unsigned()) - .map(|r| Self::I32(r.cast_signed())), - (Self::I64(a), Self::I64(b)) => (*a) - .cast_unsigned() - .checked_add((*b).cast_unsigned()) - .map(|r| Self::I64(r.cast_signed())), - (Self::U8(a), Self::U8(b)) => a.checked_add(*b).map(Self::U8), - (Self::U16(a), Self::U16(b)) => a.checked_add(*b).map(Self::U16), - (Self::U32(a), Self::U32(b)) => a.checked_add(*b).map(Self::U32), - (Self::U64(a), Self::U64(b)) => a.checked_add(*b).map(Self::U64), - (Self::NativeUInt(a), Self::NativeUInt(b)) => { - a.checked_add(*b).map(Self::NativeUInt) - } - _ => None, - } - } else { - // Signed overflow check - match (self, other) { - (Self::I8(a), Self::I8(b)) => a.checked_add(*b).map(Self::I8), - (Self::I16(a), Self::I16(b)) => a.checked_add(*b).map(Self::I16), - (Self::I32(a), Self::I32(b)) => a.checked_add(*b).map(Self::I32), - (Self::I64(a), Self::I64(b)) => a.checked_add(*b).map(Self::I64), - (Self::NativeInt(a), Self::NativeInt(b)) => a.checked_add(*b).map(Self::NativeInt), - _ => None, - } - } - .map(|v| v.mask_native(ptr_size)) - } - - /// Attempts to subtract two constants with overflow checking. - /// - /// Returns `None` if the subtraction would overflow. - /// When `unsigned` is true, operands are treated as unsigned for overflow detection. - #[must_use] - pub fn sub_checked(&self, other: &Self, unsigned: bool, ptr_size: PointerSize) -> Option { - if unsigned { - // Unsigned overflow check - match (self, other) { - (Self::I32(a), Self::I32(b)) => (*a) - .cast_unsigned() - .checked_sub((*b).cast_unsigned()) - .map(|r| Self::I32(r.cast_signed())), - (Self::I64(a), Self::I64(b)) => (*a) - .cast_unsigned() - .checked_sub((*b).cast_unsigned()) - .map(|r| Self::I64(r.cast_signed())), - (Self::U8(a), Self::U8(b)) => a.checked_sub(*b).map(Self::U8), - (Self::U16(a), Self::U16(b)) => a.checked_sub(*b).map(Self::U16), - (Self::U32(a), Self::U32(b)) => a.checked_sub(*b).map(Self::U32), - (Self::U64(a), Self::U64(b)) => a.checked_sub(*b).map(Self::U64), - (Self::NativeUInt(a), Self::NativeUInt(b)) => { - a.checked_sub(*b).map(Self::NativeUInt) - } - _ => None, - } - } else { - // Signed overflow check - match (self, other) { - (Self::I8(a), Self::I8(b)) => a.checked_sub(*b).map(Self::I8), - (Self::I16(a), Self::I16(b)) => a.checked_sub(*b).map(Self::I16), - (Self::I32(a), Self::I32(b)) => a.checked_sub(*b).map(Self::I32), - (Self::I64(a), Self::I64(b)) => a.checked_sub(*b).map(Self::I64), - (Self::NativeInt(a), Self::NativeInt(b)) => a.checked_sub(*b).map(Self::NativeInt), - _ => None, - } - } - .map(|v| v.mask_native(ptr_size)) - } - - /// Attempts to multiply two constants with overflow checking. - /// - /// Returns `None` if the multiplication would overflow. - /// When `unsigned` is true, operands are treated as unsigned for overflow detection. - #[must_use] - pub fn mul_checked(&self, other: &Self, unsigned: bool, ptr_size: PointerSize) -> Option { - if unsigned { - // Unsigned overflow check - match (self, other) { - (Self::I32(a), Self::I32(b)) => (*a) - .cast_unsigned() - .checked_mul((*b).cast_unsigned()) - .map(|r| Self::I32(r.cast_signed())), - (Self::I64(a), Self::I64(b)) => (*a) - .cast_unsigned() - .checked_mul((*b).cast_unsigned()) - .map(|r| Self::I64(r.cast_signed())), - (Self::U8(a), Self::U8(b)) => a.checked_mul(*b).map(Self::U8), - (Self::U16(a), Self::U16(b)) => a.checked_mul(*b).map(Self::U16), - (Self::U32(a), Self::U32(b)) => a.checked_mul(*b).map(Self::U32), - (Self::U64(a), Self::U64(b)) => a.checked_mul(*b).map(Self::U64), - (Self::NativeUInt(a), Self::NativeUInt(b)) => { - a.checked_mul(*b).map(Self::NativeUInt) - } - _ => None, - } - } else { - // Signed overflow check - match (self, other) { - (Self::I8(a), Self::I8(b)) => a.checked_mul(*b).map(Self::I8), - (Self::I16(a), Self::I16(b)) => a.checked_mul(*b).map(Self::I16), - (Self::I32(a), Self::I32(b)) => a.checked_mul(*b).map(Self::I32), - (Self::I64(a), Self::I64(b)) => a.checked_mul(*b).map(Self::I64), - (Self::NativeInt(a), Self::NativeInt(b)) => a.checked_mul(*b).map(Self::NativeInt), - _ => None, - } - } - .map(|v| v.mask_native(ptr_size)) - } - - /// Attempts to divide two constants. Uses `checked_div`/`checked_rem` so - /// MIN/-1 overflows fold to `None` rather than wrapping silently. - #[must_use] - pub fn div(&self, other: &Self, ptr_size: PointerSize) -> Option { - match (self, other) { - (Self::I8(a), Self::I8(b)) => a.checked_div(*b).map(Self::I8), - (Self::I16(a), Self::I16(b)) => a.checked_div(*b).map(Self::I16), - (Self::I32(a), Self::I32(b)) => a.checked_div(*b).map(Self::I32), - (Self::I64(a), Self::I64(b)) => a.checked_div(*b).map(Self::I64), - (Self::U8(a), Self::U8(b)) => a.checked_div(*b).map(Self::U8), - (Self::U16(a), Self::U16(b)) => a.checked_div(*b).map(Self::U16), - (Self::U32(a), Self::U32(b)) => a.checked_div(*b).map(Self::U32), - (Self::U64(a), Self::U64(b)) => a.checked_div(*b).map(Self::U64), - (Self::NativeInt(a), Self::NativeInt(b)) => a.checked_div(*b).map(Self::NativeInt), - (Self::NativeUInt(a), Self::NativeUInt(b)) => a.checked_div(*b).map(Self::NativeUInt), - // Float div by zero is inf — IEEE 754 has no panic, no overflow. - (Self::F32(a), Self::F32(b)) => Some(Self::F32(a / b)), - (Self::F64(a), Self::F64(b)) => Some(Self::F64(a / b)), - _ => None, - } - .map(|v| v.mask_native(ptr_size)) - } - - /// Attempts to compute remainder (modulo) of two constants. Uses - /// `checked_rem` so MIN%-1 overflows fold to `None`. - #[must_use] - pub fn rem(&self, other: &Self, ptr_size: PointerSize) -> Option { - match (self, other) { - (Self::I8(a), Self::I8(b)) => a.checked_rem(*b).map(Self::I8), - (Self::I16(a), Self::I16(b)) => a.checked_rem(*b).map(Self::I16), - (Self::I32(a), Self::I32(b)) => a.checked_rem(*b).map(Self::I32), - (Self::I64(a), Self::I64(b)) => a.checked_rem(*b).map(Self::I64), - (Self::U8(a), Self::U8(b)) => a.checked_rem(*b).map(Self::U8), - (Self::U16(a), Self::U16(b)) => a.checked_rem(*b).map(Self::U16), - (Self::U32(a), Self::U32(b)) => a.checked_rem(*b).map(Self::U32), - (Self::U64(a), Self::U64(b)) => a.checked_rem(*b).map(Self::U64), - (Self::NativeInt(a), Self::NativeInt(b)) => a.checked_rem(*b).map(Self::NativeInt), - (Self::NativeUInt(a), Self::NativeUInt(b)) => a.checked_rem(*b).map(Self::NativeUInt), - (Self::F32(a), Self::F32(b)) => Some(Self::F32(a % b)), - (Self::F64(a), Self::F64(b)) => Some(Self::F64(a % b)), - _ => None, - } - .map(|v| v.mask_native(ptr_size)) - } - - /// Attempts to compare two constants for equality. - #[must_use] - #[allow(clippy::float_cmp)] // Exact comparison is correct for constant propagation - #[allow(clippy::match_same_arms)] // NativeInt/NativeUInt are semantically different - pub fn ceq(&self, other: &Self) -> Option { - match (self, other) { - (Self::I8(a), Self::I8(b)) => Some(Self::from_bool(a == b)), - (Self::I16(a), Self::I16(b)) => Some(Self::from_bool(a == b)), - (Self::I32(a), Self::I32(b)) => Some(Self::from_bool(a == b)), - (Self::I64(a), Self::I64(b)) => Some(Self::from_bool(a == b)), - (Self::U8(a), Self::U8(b)) => Some(Self::from_bool(a == b)), - (Self::U16(a), Self::U16(b)) => Some(Self::from_bool(a == b)), - (Self::U32(a), Self::U32(b)) => Some(Self::from_bool(a == b)), - (Self::U64(a), Self::U64(b)) => Some(Self::from_bool(a == b)), - (Self::NativeInt(a), Self::NativeInt(b)) => Some(Self::from_bool(a == b)), - (Self::NativeUInt(a), Self::NativeUInt(b)) => Some(Self::from_bool(a == b)), - (Self::F32(a), Self::F32(b)) => Some(Self::from_bool(a == b)), - (Self::F64(a), Self::F64(b)) => Some(Self::from_bool(a == b)), - (Self::Null, Self::Null) | (Self::True, Self::True) | (Self::False, Self::False) => { - Some(Self::True) - } - (Self::True, Self::False) | (Self::False, Self::True) => Some(Self::False), - // Cross-type comparisons with promotion - (Self::I32(a), Self::I64(b)) | (Self::I64(b), Self::I32(a)) => { - Some(Self::from_bool(i64::from(*a) == *b)) - } - (Self::U32(a), Self::U64(b)) | (Self::U64(b), Self::U32(a)) => { - Some(Self::from_bool(u64::from(*a) == *b)) - } - _ => None, - } - } - - /// Attempts to compare two constants for less-than (signed). - #[must_use] - #[allow(clippy::match_same_arms)] // NativeInt/NativeUInt are semantically different - pub fn clt(&self, other: &Self) -> Option { - match (self, other) { - (Self::I8(a), Self::I8(b)) => Some(Self::from_bool(a < b)), - (Self::I16(a), Self::I16(b)) => Some(Self::from_bool(a < b)), - (Self::I32(a), Self::I32(b)) => Some(Self::from_bool(a < b)), - (Self::I64(a), Self::I64(b)) => Some(Self::from_bool(a < b)), - (Self::U8(a), Self::U8(b)) => Some(Self::from_bool(a < b)), - (Self::U16(a), Self::U16(b)) => Some(Self::from_bool(a < b)), - (Self::U32(a), Self::U32(b)) => Some(Self::from_bool(a < b)), - (Self::U64(a), Self::U64(b)) => Some(Self::from_bool(a < b)), - (Self::NativeInt(a), Self::NativeInt(b)) => Some(Self::from_bool(a < b)), - (Self::NativeUInt(a), Self::NativeUInt(b)) => Some(Self::from_bool(a < b)), - (Self::F32(a), Self::F32(b)) => Some(Self::from_bool(a < b)), - (Self::F64(a), Self::F64(b)) => Some(Self::from_bool(a < b)), - (Self::I32(a), Self::I64(b)) => Some(Self::from_bool(i64::from(*a) < *b)), - (Self::I64(a), Self::I32(b)) => Some(Self::from_bool(*a < i64::from(*b))), - (Self::U32(a), Self::U64(b)) => Some(Self::from_bool(u64::from(*a) < *b)), - (Self::U64(a), Self::U32(b)) => Some(Self::from_bool(*a < u64::from(*b))), - _ => None, - } - } - - /// Attempts to compare two constants for less-than (unsigned). - #[must_use] - #[allow(clippy::cast_sign_loss)] // Unsigned comparison requires interpreting bits as unsigned - #[allow(clippy::match_same_arms)] // NativeInt/NativeUInt are semantically different - pub fn clt_un(&self, other: &Self) -> Option { - // Treat values as unsigned for comparison - match (self, other) { - (Self::I8(a), Self::I8(b)) => Some(Self::from_bool((*a as u8) < (*b as u8))), - (Self::I16(a), Self::I16(b)) => Some(Self::from_bool((*a as u16) < (*b as u16))), - (Self::I32(a), Self::I32(b)) => Some(Self::from_bool((*a as u32) < (*b as u32))), - (Self::I64(a), Self::I64(b)) => Some(Self::from_bool((*a as u64) < (*b as u64))), - (Self::U8(a), Self::U8(b)) => Some(Self::from_bool(a < b)), - (Self::U16(a), Self::U16(b)) => Some(Self::from_bool(a < b)), - (Self::U32(a), Self::U32(b)) => Some(Self::from_bool(a < b)), - (Self::U64(a), Self::U64(b)) => Some(Self::from_bool(a < b)), - (Self::NativeInt(a), Self::NativeInt(b)) => { - Some(Self::from_bool((*a as u64) < (*b as u64))) - } - (Self::NativeUInt(a), Self::NativeUInt(b)) => Some(Self::from_bool(a < b)), - // For floats, clt.un checks for unordered (NaN) or less than - (Self::F32(a), Self::F32(b)) => { - Some(Self::from_bool(a.is_nan() || b.is_nan() || a < b)) - } - (Self::F64(a), Self::F64(b)) => { - Some(Self::from_bool(a.is_nan() || b.is_nan() || a < b)) - } - _ => None, - } - } - - /// Attempts to compare two constants for greater-than (signed). - #[must_use] - #[allow(clippy::match_same_arms)] // NativeInt/NativeUInt are semantically different - pub fn cgt(&self, other: &Self) -> Option { - match (self, other) { - (Self::I8(a), Self::I8(b)) => Some(Self::from_bool(a > b)), - (Self::I16(a), Self::I16(b)) => Some(Self::from_bool(a > b)), - (Self::I32(a), Self::I32(b)) => Some(Self::from_bool(a > b)), - (Self::I64(a), Self::I64(b)) => Some(Self::from_bool(a > b)), - (Self::U8(a), Self::U8(b)) => Some(Self::from_bool(a > b)), - (Self::U16(a), Self::U16(b)) => Some(Self::from_bool(a > b)), - (Self::U32(a), Self::U32(b)) => Some(Self::from_bool(a > b)), - (Self::U64(a), Self::U64(b)) => Some(Self::from_bool(a > b)), - (Self::NativeInt(a), Self::NativeInt(b)) => Some(Self::from_bool(a > b)), - (Self::NativeUInt(a), Self::NativeUInt(b)) => Some(Self::from_bool(a > b)), - (Self::F32(a), Self::F32(b)) => Some(Self::from_bool(a > b)), - (Self::F64(a), Self::F64(b)) => Some(Self::from_bool(a > b)), - (Self::I32(a), Self::I64(b)) => Some(Self::from_bool(i64::from(*a) > *b)), - (Self::I64(a), Self::I32(b)) => Some(Self::from_bool(*a > i64::from(*b))), - (Self::U32(a), Self::U64(b)) => Some(Self::from_bool(u64::from(*a) > *b)), - (Self::U64(a), Self::U32(b)) => Some(Self::from_bool(*a > u64::from(*b))), - _ => None, - } - } - - /// Attempts to compare two constants for greater-than (unsigned). - #[must_use] - #[allow(clippy::cast_sign_loss)] // Unsigned comparison requires interpreting bits as unsigned - #[allow(clippy::match_same_arms)] // NativeInt/NativeUInt are semantically different - pub fn cgt_un(&self, other: &Self) -> Option { - match (self, other) { - (Self::I8(a), Self::I8(b)) => Some(Self::from_bool((*a as u8) > (*b as u8))), - (Self::I16(a), Self::I16(b)) => Some(Self::from_bool((*a as u16) > (*b as u16))), - (Self::I32(a), Self::I32(b)) => Some(Self::from_bool((*a as u32) > (*b as u32))), - (Self::I64(a), Self::I64(b)) => Some(Self::from_bool((*a as u64) > (*b as u64))), - (Self::U8(a), Self::U8(b)) => Some(Self::from_bool(a > b)), - (Self::U16(a), Self::U16(b)) => Some(Self::from_bool(a > b)), - (Self::U32(a), Self::U32(b)) => Some(Self::from_bool(a > b)), - (Self::U64(a), Self::U64(b)) => Some(Self::from_bool(a > b)), - (Self::NativeInt(a), Self::NativeInt(b)) => { - Some(Self::from_bool((*a as u64) > (*b as u64))) - } - (Self::NativeUInt(a), Self::NativeUInt(b)) => Some(Self::from_bool(a > b)), - // For floats, cgt.un checks for unordered (NaN) or greater than - (Self::F32(a), Self::F32(b)) => { - Some(Self::from_bool(a.is_nan() || b.is_nan() || a > b)) - } - (Self::F64(a), Self::F64(b)) => { - Some(Self::from_bool(a.is_nan() || b.is_nan() || a > b)) - } - _ => None, - } - } - - /// Converts this constant to a different type. - /// - /// This implements CIL type conversion semantics (conv.* instructions). - /// For overflow checking conversions, use `convert_to_checked`. - /// - /// # Arguments - /// - /// * `target` - The target SSA type to convert to. - /// * `unsigned_source` - If true, treat the source value as unsigned for conversion. - /// - /// # Returns - /// - /// The converted constant, or `None` if conversion is not possible. - /// - /// # Example - /// - /// ```rust,no_run - /// use dotscope::analysis::{ConstValue, SsaType}; - /// use dotscope::metadata::typesystem::PointerSize; - /// - /// let value = ConstValue::I32(42); - /// let converted = value.convert_to(&SsaType::I64, false, PointerSize::Bit64); - /// assert_eq!(converted, Some(ConstValue::I64(42))); - /// ``` - #[must_use] - pub fn convert_to( - &self, - target: &SsaType, - unsigned_source: bool, - ptr_size: PointerSize, - ) -> Option { - // For unsigned source interpretation, get the raw bits as u64 - // For signed source, get as i64 - let (signed_val, unsigned_val) = if unsigned_source { - let u = self.as_u64()?; - // Reinterpret as signed for operations that need it - (i64::from_ne_bytes(u.to_ne_bytes()), u) - } else { - let s = self.as_i64()?; - // Reinterpret as unsigned for operations that need it - (s, u64::from_ne_bytes(s.to_ne_bytes())) - }; - - // These casts are intentional truncations for type conversion - #[allow(clippy::cast_possible_truncation)] - Some(match target { - // Truncating conversions - use wrapping to get low bits - SsaType::I8 => Self::I8(signed_val as i8), - SsaType::U8 => Self::U8(unsigned_val as u8), - SsaType::I16 => Self::I16(signed_val as i16), - SsaType::U16 | SsaType::Char => Self::U16(unsigned_val as u16), - SsaType::I32 => Self::I32(signed_val as i32), - SsaType::U32 => Self::U32(unsigned_val as u32), - // Non-truncating conversions - SsaType::I64 => Self::I64(signed_val), - SsaType::U64 => Self::U64(unsigned_val), - SsaType::NativeInt => Self::NativeInt(signed_val), - SsaType::NativeUInt => Self::NativeUInt(unsigned_val), - // Float conversions - interpretation matters - SsaType::F32 => - { - #[allow(clippy::cast_precision_loss)] - if unsigned_source { - Self::F32(unsigned_val as f32) - } else { - Self::F32(signed_val as f32) - } - } - SsaType::F64 => - { - #[allow(clippy::cast_precision_loss)] - if unsigned_source { - Self::F64(unsigned_val as f64) - } else { - Self::F64(signed_val as f64) - } - } - SsaType::Bool => Self::from_bool(signed_val != 0), - _ => return None, - }) - .map(|v| v.mask_native(ptr_size)) - } - - /// Converts this constant to a different type with overflow checking. - /// - /// This implements CIL overflow-checked conversion semantics (conv.ovf.* instructions). - /// Returns `None` if the value would overflow the target type. - /// - /// # Arguments - /// - /// * `target` - The target SSA type to convert to. - /// * `unsigned_source` - If true, treat the source value as unsigned for conversion. - /// - /// # Returns - /// - /// The converted constant if no overflow, or `None` if conversion would overflow - /// or is not possible. - /// - /// # Example - /// - /// ```rust,no_run - /// use dotscope::analysis::{ConstValue, SsaType}; - /// use dotscope::metadata::typesystem::PointerSize; - /// - /// let value = ConstValue::I32(1000); - /// // 1000 doesn't fit in i8 (-128 to 127) - /// assert_eq!(value.convert_to_checked(&SsaType::I8, false, PointerSize::Bit64), None); - /// - /// let small = ConstValue::I32(42); - /// assert_eq!(small.convert_to_checked(&SsaType::I8, false, PointerSize::Bit64), Some(ConstValue::I8(42))); - /// ``` - #[must_use] - pub fn convert_to_checked( - &self, - target: &SsaType, - unsigned_source: bool, - ptr_size: PointerSize, - ) -> Option { - // Get both signed and unsigned interpretations - let (signed_val, unsigned_val) = if unsigned_source { - let u = self.as_u64()?; - (i64::from_ne_bytes(u.to_ne_bytes()), u) - } else { - let s = self.as_i64()?; - (s, u64::from_ne_bytes(s.to_ne_bytes())) - }; - - // Check if value fits in target type using try_into for range checking - let fits = match target { - SsaType::I8 => i8::try_from(signed_val).is_ok(), - SsaType::U8 => u8::try_from(unsigned_val).is_ok() && signed_val >= 0, - SsaType::I16 => i16::try_from(signed_val).is_ok(), - SsaType::U16 => u16::try_from(unsigned_val).is_ok() && signed_val >= 0, - SsaType::I32 => i32::try_from(signed_val).is_ok(), - SsaType::U32 => u32::try_from(unsigned_val).is_ok() && signed_val >= 0, - SsaType::I64 - | SsaType::NativeInt - | SsaType::Bool - | SsaType::Char - | SsaType::F32 - | SsaType::F64 => true, - SsaType::U64 | SsaType::NativeUInt => signed_val >= 0, - _ => return None, - }; - - if !fits { - return None; // Would overflow - } - - // Perform the conversion (same as convert_to) - self.convert_to(target, unsigned_source, ptr_size) - } - - /// Masks a `ConstValue` to the target pointer width. - /// - /// For `NativeInt`, sign-extends from 32-bit on `Bit32`. - /// For `NativeUInt`, zero-extends from 32-bit on `Bit32`. - /// All other variants are returned unchanged. - #[must_use] - pub fn mask_native(self, ptr_size: PointerSize) -> Self { - match self { - Self::NativeInt(v) => Self::NativeInt(ptr_size.mask_signed(v)), - Self::NativeUInt(v) => Self::NativeUInt(ptr_size.mask_unsigned(v)), - other => other, - } - } -} - -impl fmt::Display for ConstValue { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::I8(v) => write!(f, "{v}i8"), - Self::I16(v) => write!(f, "{v}i16"), - Self::I32(v) => write!(f, "{v}"), - Self::I64(v) => write!(f, "{v}L"), - Self::U8(v) => write!(f, "{v}u8"), - Self::U16(v) => write!(f, "{v}u16"), - Self::U32(v) => write!(f, "{v}u"), - Self::U64(v) => write!(f, "{v}UL"), - Self::NativeInt(v) => write!(f, "{v}n"), - Self::NativeUInt(v) => write!(f, "{v}un"), - Self::F32(v) => write!(f, "{v}f"), - Self::F64(v) => write!(f, "{v}"), - Self::String(idx) => write!(f, "str@{idx}"), - Self::DecryptedString(s) => write!(f, "\"{}\"", s.escape_default()), - Self::DecryptedArray { - data, - element_type_token, - element_size, - } => { - write!( - f, - "array[{}x{}]<0x{:08X}>", - data.len() - .checked_div(*element_size.max(&1)) - .unwrap_or(data.len()), - element_size, - element_type_token.value() - ) - } - Self::Null => write!(f, "null"), - Self::True => write!(f, "true"), - Self::False => write!(f, "false"), - Self::Type(t) => write!(f, "typeof({t})"), - Self::MethodHandle(m) => write!(f, "methodof({m})"), - Self::FieldHandle(fl) => write!(f, "fieldof({fl})"), - } - } } -/// Attempts to convert a `ConstValue` to an `Immediate` for CIL instruction encoding. -/// -/// This conversion handles the numeric `ConstValue` variants, mapping them to -/// their corresponding `Immediate` representations. The conversion follows -/// CIL semantics where: -/// -/// - Signed integers map directly to their signed `Immediate` variants -/// - Unsigned integers use bit-preserving casts to signed types (since CIL -/// doesn't distinguish unsigned at the instruction level for most operations) -/// - Floating-point values map directly -/// - Boolean values map to `Int32` (1 for true, 0 for false) -/// - Native integers use 64-bit representations -/// -/// # Errors -/// -/// Returns [`crate::Error::SsaError`] for non-numeric `ConstValue` variants -/// (`String`, `DecryptedString`, `Null`, `Type`, `MethodHandle`, `FieldHandle`) -/// since these cannot be represented as immediate values. Handle these cases -/// with pattern matching before conversion. -/// -/// # Example -/// -/// ```rust,ignore -/// use dotscope::analysis::ConstValue; -/// use dotscope::assembly::Immediate; -/// use std::convert::TryFrom; -/// -/// let const_val = ConstValue::I32(42); -/// let immediate = Immediate::try_from(&const_val)?; -/// assert!(matches!(immediate, Immediate::Int32(42))); -/// -/// // Non-numeric values return an error -/// let null_val = ConstValue::Null; -/// assert!(Immediate::try_from(&null_val).is_err()); -/// ``` -impl TryFrom<&ConstValue> for Immediate { +/// Attempts to convert a `ConstValue` to an `Immediate` for CIL instruction +/// encoding. Mirrors the original CIL-specific TryFrom impl from when +/// `ConstValue` lived in dotscope. +impl TryFrom<&AnalyssaConstValue> for Immediate { type Error = Error; #[allow(clippy::cast_possible_wrap)] // Intentional bit-preserving casts for CIL semantics - fn try_from(value: &ConstValue) -> Result { + fn try_from(value: &AnalyssaConstValue) -> Result { match value { - // Signed integers - direct mapping - ConstValue::I8(v) => Ok(Immediate::Int8(*v)), - ConstValue::I16(v) => Ok(Immediate::Int16(*v)), - ConstValue::I32(v) => Ok(Immediate::Int32(*v)), - - // Unsigned integers - use signed Immediate variants with bit-preserving casts. - // CIL instructions don't distinguish signed/unsigned for most operations; - // the bit pattern is what matters. - ConstValue::U8(v) => Ok(Immediate::Int8(*v as i8)), - ConstValue::U16(v) => Ok(Immediate::Int16(*v as i16)), - ConstValue::U32(v) => Ok(Immediate::Int32(*v as i32)), - - // 64-bit integers and native integers use Int64 representation - // (NativeInt is semantically different but has identical representation) - ConstValue::I64(v) | ConstValue::NativeInt(v) => Ok(Immediate::Int64(*v)), - ConstValue::U64(v) | ConstValue::NativeUInt(v) => Ok(Immediate::Int64(*v as i64)), - - // Floating point - direct mapping - ConstValue::F32(v) => Ok(Immediate::Float32(*v)), - ConstValue::F64(v) => Ok(Immediate::Float64(*v)), - - // Boolean values - map to Int32 (CIL uses int32 for booleans on stack) - ConstValue::True => Ok(Immediate::Int32(1)), - ConstValue::False => Ok(Immediate::Int32(0)), - - // Non-numeric types cannot be converted to immediates - ConstValue::String(_) - | ConstValue::DecryptedString(_) - | ConstValue::DecryptedArray { .. } - | ConstValue::Null - | ConstValue::Type(_) - | ConstValue::MethodHandle(_) - | ConstValue::FieldHandle(_) => Err(Error::SsaError(format!( - "Cannot convert {value:?} to Immediate - use pattern matching to handle this case" - ))), - } - } -} - -/// Abstract value for dataflow analysis. -/// -/// This represents the abstract state of an SSA variable during analysis. -/// It forms a lattice where values can be refined as more information is gathered. -#[derive(Debug, Clone, PartialEq, Default)] -pub enum AbstractValue { - /// No information yet (top of lattice). - /// - /// This is the initial state before any analysis. - #[default] - Top, - - /// Known constant value. - Constant(ConstValue), + AnalyssaConstValue::I8(v) => Ok(Immediate::Int8(*v)), + AnalyssaConstValue::I16(v) => Ok(Immediate::Int16(*v)), + AnalyssaConstValue::I32(v) => Ok(Immediate::Int32(*v)), - /// Known to be non-null (for reference types). - NonNull, - - /// Value in a bounded range [min, max]. - Range { - /// Minimum value (inclusive). - min: i64, - /// Maximum value (inclusive). - max: i64, - }, - - /// Same value as another SSA variable. - /// - /// Used for copy propagation. - SameAs(SsaVarId), + AnalyssaConstValue::U8(v) => Ok(Immediate::Int8(*v as i8)), + AnalyssaConstValue::U16(v) => Ok(Immediate::Int16(*v as i16)), + AnalyssaConstValue::U32(v) => Ok(Immediate::Int32(*v as i32)), - /// Result of a specific computation (for CSE). - Computed(ComputedValue), - - /// Multiple possible values (bottom of lattice for constants). - /// - /// This means the value cannot be determined at compile time. - Bottom, -} - -impl AbstractValue { - /// Returns `true` if this is the top element (no information). - #[must_use] - pub const fn is_top(&self) -> bool { - matches!(self, Self::Top) - } - - /// Returns `true` if this is the bottom element (conflicting info). - #[must_use] - pub const fn is_bottom(&self) -> bool { - matches!(self, Self::Bottom) - } - - /// Returns `true` if this is a known constant. - #[must_use] - pub const fn is_constant(&self) -> bool { - matches!(self, Self::Constant(_)) - } - - /// Returns the constant value if this is a constant. - #[must_use] - pub const fn as_constant(&self) -> Option<&ConstValue> { - match self { - Self::Constant(c) => Some(c), - _ => None, - } - } - - /// Returns `true` if this value is known to be non-null. - #[must_use] - pub const fn is_non_null(&self) -> bool { - matches!(self, Self::NonNull | Self::Constant(_)) - } - - /// Meet operation for the lattice (used at control flow joins). - /// - /// Returns the greatest lower bound of `self` and `other`. - #[must_use] - #[allow(clippy::match_same_arms)] // Arms kept separate for lattice documentation clarity - pub fn meet(&self, other: &Self) -> Self { - match (self, other) { - // Top meets anything yields the other - (Self::Top, x) | (x, Self::Top) => x.clone(), - - // Bottom meets anything yields Bottom - (Self::Bottom, _) | (_, Self::Bottom) => Self::Bottom, - - // Same constants stay constant - (Self::Constant(a), Self::Constant(b)) if a == b => Self::Constant(a.clone()), - - // Different constants become Bottom - (Self::Constant(_), Self::Constant(_)) => Self::Bottom, - - // NonNull meets NonNull stays NonNull - (Self::NonNull, Self::NonNull) => Self::NonNull, - - // NonNull meets Constant stays Constant (constants are non-null if not null) - (Self::NonNull, Self::Constant(c)) | (Self::Constant(c), Self::NonNull) => { - if c.is_null() { - Self::Bottom // null is not non-null - } else { - Self::Constant(c.clone()) - } + AnalyssaConstValue::I64(v) | AnalyssaConstValue::NativeInt(v) => { + Ok(Immediate::Int64(*v)) } - - // Ranges can be merged - ( - Self::Range { - min: a_min, - max: a_max, - }, - Self::Range { - min: b_min, - max: b_max, - }, - ) => { - let new_min = (*a_min).min(*b_min); - let new_max = (*a_max).max(*b_max); - Self::Range { - min: new_min, - max: new_max, - } + AnalyssaConstValue::U64(v) | AnalyssaConstValue::NativeUInt(v) => { + Ok(Immediate::Int64(*v as i64)) } - // SameAs values must match - (Self::SameAs(a), Self::SameAs(b)) if a == b => Self::SameAs(*a), - - // Computed values must match exactly - (Self::Computed(a), Self::Computed(b)) if a == b => Self::Computed(a.clone()), - - // Otherwise, Bottom - _ => Self::Bottom, - } - } - - /// Join operation for the lattice. - /// - /// Returns the least upper bound of `self` and `other`. - #[must_use] - pub fn join(&self, other: &Self) -> Self { - match (self, other) { - // Bottom joins anything yields the other - (Self::Bottom, x) | (x, Self::Bottom) => x.clone(), - - // Top joins anything yields Top - (Self::Top, _) | (_, Self::Top) => Self::Top, + AnalyssaConstValue::F32(v) => Ok(Immediate::Float32(*v)), + AnalyssaConstValue::F64(v) => Ok(Immediate::Float64(*v)), - // Same values stay the same - (a, b) if a == b => a.clone(), + AnalyssaConstValue::True => Ok(Immediate::Int32(1)), + AnalyssaConstValue::False => Ok(Immediate::Int32(0)), - // Otherwise, Top - _ => Self::Top, - } - } -} - -impl fmt::Display for AbstractValue { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Top => write!(f, "⊤"), - Self::Constant(c) => write!(f, "{c}"), - Self::NonNull => write!(f, "!null"), - Self::Range { min, max } => write!(f, "[{min}..{max}]"), - Self::SameAs(v) => write!(f, "={v}"), - Self::Computed(c) => write!(f, "{c}"), - Self::Bottom => write!(f, "⊥"), - } - } -} - -/// Computed value for common subexpression elimination (CSE). -/// -/// This represents the result of a computation, enabling recognition -/// of equivalent expressions that can be eliminated. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -pub struct ComputedValue { - /// The operation that produced this value. - pub op: ComputedOp, - /// The operands to the operation. - pub operands: Vec, -} - -impl ComputedValue { - /// Creates a new computed value. - #[must_use] - pub fn new(op: ComputedOp, operands: Vec) -> Self { - Self { op, operands } - } - - /// Creates a unary computed value. - #[must_use] - pub fn unary(op: ComputedOp, operand: SsaVarId) -> Self { - Self { - op, - operands: vec![operand], - } - } - - /// Creates a binary computed value. - #[must_use] - pub fn binary(op: ComputedOp, left: SsaVarId, right: SsaVarId) -> Self { - Self { - op, - operands: vec![left, right], - } - } - - /// Normalizes commutative operations for better CSE. - /// - /// For commutative ops like add/mul, orders operands consistently - /// so that `a + b` and `b + a` have the same computed value. - #[must_use] - pub fn normalized(self) -> Self { - if self.op.is_commutative() && self.operands.len() == 2 { - let mut ops = self.operands; - if let (Some(a), Some(b)) = (ops.first(), ops.get(1)) { - if a.index() > b.index() { - ops.swap(0, 1); - } - } - Self { - op: self.op, - operands: ops, - } - } else { - self - } - } -} - -impl fmt::Display for ComputedValue { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}(", self.op)?; - for (i, op) in self.operands.iter().enumerate() { - if i > 0 { - write!(f, ", ")?; - } - write!(f, "{op}")?; + AnalyssaConstValue::String(_) + | AnalyssaConstValue::DecryptedString(_) + | AnalyssaConstValue::DecryptedArray { .. } + | AnalyssaConstValue::Null + | AnalyssaConstValue::Type(_) + | AnalyssaConstValue::MethodHandle(_) + | AnalyssaConstValue::FieldHandle(_) => Err(Error::SsaError(format!( + "Cannot convert {value:?} to Immediate - use pattern matching to handle this case" + ))), } - write!(f, ")") - } -} - -/// Operations that can be tracked for CSE. -/// -/// These represent the pure operations whose results can be reused -/// when the same operation is performed with the same operands. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum ComputedOp { - // Arithmetic - /// Addition - Add, - /// Subtraction - Sub, - /// Multiplication - Mul, - /// Division - Div, - /// Remainder (modulo) - Rem, - /// Negation - Neg, - - // Bitwise - /// Bitwise AND - And, - /// Bitwise OR - Or, - /// Bitwise XOR - Xor, - /// Bitwise NOT - Not, - /// Shift left - Shl, - /// Shift right - Shr, - - // Comparison - /// Compare equal - Ceq, - /// Compare not equal - Cne, - /// Compare less than - Clt, - /// Compare greater than - Cgt, - /// Compare less than or equal - Cle, - /// Compare greater than or equal - Cge, - - // Conversion - /// Convert to int8 - ConvI1, - /// Convert to int16 - ConvI2, - /// Convert to int32 - ConvI4, - /// Convert to int64 - ConvI8, - /// Convert to uint8 - ConvU1, - /// Convert to uint16 - ConvU2, - /// Convert to uint32 - ConvU4, - /// Convert to uint64 - ConvU8, - /// Convert to float32 - ConvR4, - /// Convert to float64 - ConvR8, -} - -impl ComputedOp { - /// Returns `true` if this operation is commutative. - #[must_use] - pub const fn is_commutative(&self) -> bool { - matches!( - self, - Self::Add | Self::Mul | Self::And | Self::Or | Self::Xor | Self::Ceq | Self::Cne - ) - } - - /// Returns `true` if this is a comparison operation. - #[must_use] - pub const fn is_comparison(&self) -> bool { - matches!( - self, - Self::Ceq | Self::Cne | Self::Clt | Self::Cgt | Self::Cle | Self::Cge - ) - } -} - -impl fmt::Display for ComputedOp { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - let s = match self { - Self::Add => "add", - Self::Sub => "sub", - Self::Mul => "mul", - Self::Div => "div", - Self::Rem => "rem", - Self::Neg => "neg", - Self::And => "and", - Self::Or => "or", - Self::Xor => "xor", - Self::Not => "not", - Self::Shl => "shl", - Self::Shr => "shr", - Self::Ceq => "ceq", - Self::Cne => "cne", - Self::Clt => "clt", - Self::Cgt => "cgt", - Self::Cle => "cle", - Self::Cge => "cge", - Self::ConvI1 => "conv.i1", - Self::ConvI2 => "conv.i2", - Self::ConvI4 => "conv.i4", - Self::ConvI8 => "conv.i8", - Self::ConvU1 => "conv.u1", - Self::ConvU2 => "conv.u2", - Self::ConvU4 => "conv.u4", - Self::ConvU8 => "conv.u8", - Self::ConvR4 => "conv.r4", - Self::ConvR8 => "conv.r8", - }; - write!(f, "{s}") } } @@ -1704,8 +131,18 @@ impl fmt::Display for ComputedOp { mod tests { use super::*; + use analyssa::ir::value::{ComputedOp, ComputedValue}; + use analyssa::ir::variable::SsaVarId; + use crate::metadata::typesystem::PointerSize; + // Tests construct unit/numeric variants like `ConstValue::I32(_)` that + // don't constrain the `T` type parameter. Type-parameter defaults aren't + // used to break inference ambiguity in expression position, so we shadow + // the names here to lock T to CilTarget for the entire test module. + type ConstValue = super::ConstValue; + type AbstractValue = super::AbstractValue; + #[test] fn test_const_arithmetic() { let a = ConstValue::I32(10); @@ -1856,21 +293,21 @@ mod tests { // i32 -> i64 (sign extends) let v = ConstValue::I32(42); assert_eq!( - v.convert_to(&SsaType::I64, false, PointerSize::Bit64), + v.convert_to(&SsaType::I64, false, 8), Some(ConstValue::I64(42)) ); // i32 -> i64 with negative let v = ConstValue::I32(-42); assert_eq!( - v.convert_to(&SsaType::I64, false, PointerSize::Bit64), + v.convert_to(&SsaType::I64, false, 8), Some(ConstValue::I64(-42)) ); // u32 -> u64 (zero extends) let v = ConstValue::U32(42); assert_eq!( - v.convert_to(&SsaType::U64, false, PointerSize::Bit64), + v.convert_to(&SsaType::U64, false, 8), Some(ConstValue::U64(42)) ); } @@ -1880,7 +317,7 @@ mod tests { // i32 -> i8 (truncates) let v = ConstValue::I32(42); assert_eq!( - v.convert_to(&SsaType::I8, false, PointerSize::Bit64), + v.convert_to(&SsaType::I8, false, 8), Some(ConstValue::I8(42)) ); @@ -1888,14 +325,14 @@ mod tests { let v = ConstValue::I32(1000); // 1000 = 0x3E8, truncated to i8 = 0xE8 = -24 (signed) assert_eq!( - v.convert_to(&SsaType::I8, false, PointerSize::Bit64), + v.convert_to(&SsaType::I8, false, 8), Some(ConstValue::I8(-24)) ); // i64 -> i32 (truncates) let v = ConstValue::I64(0x1_0000_0042); assert_eq!( - v.convert_to(&SsaType::I32, false, PointerSize::Bit64), + v.convert_to(&SsaType::I32, false, 8), Some(ConstValue::I32(0x42)) ); } @@ -1905,21 +342,21 @@ mod tests { // i32 -> f32 let v = ConstValue::I32(42); assert_eq!( - v.convert_to(&SsaType::F32, false, PointerSize::Bit64), + v.convert_to(&SsaType::F32, false, 8), Some(ConstValue::F32(42.0)) ); // i32 -> f64 let v = ConstValue::I32(42); assert_eq!( - v.convert_to(&SsaType::F64, false, PointerSize::Bit64), + v.convert_to(&SsaType::F64, false, 8), Some(ConstValue::F64(42.0)) ); // Unsigned source to float let v = ConstValue::U32(42); assert_eq!( - v.convert_to(&SsaType::F32, true, PointerSize::Bit64), + v.convert_to(&SsaType::F32, true, 8), Some(ConstValue::F32(42.0)) ); } @@ -1929,14 +366,14 @@ mod tests { // Non-zero -> true let v = ConstValue::I32(42); assert_eq!( - v.convert_to(&SsaType::Bool, false, PointerSize::Bit64), + v.convert_to(&SsaType::Bool, false, 8), Some(ConstValue::True) ); // Zero -> false let v = ConstValue::I32(0); assert_eq!( - v.convert_to(&SsaType::Bool, false, PointerSize::Bit64), + v.convert_to(&SsaType::Bool, false, 8), Some(ConstValue::False) ); } @@ -1946,20 +383,20 @@ mod tests { // Value fits in target let v = ConstValue::I32(100); assert_eq!( - v.convert_to_checked(&SsaType::I8, false, PointerSize::Bit64), + v.convert_to_checked(&SsaType::I8, false, 8), Some(ConstValue::I8(100)) ); // Value at boundary let v = ConstValue::I32(127); assert_eq!( - v.convert_to_checked(&SsaType::I8, false, PointerSize::Bit64), + v.convert_to_checked(&SsaType::I8, false, 8), Some(ConstValue::I8(127)) ); let v = ConstValue::I32(-128); assert_eq!( - v.convert_to_checked(&SsaType::I8, false, PointerSize::Bit64), + v.convert_to_checked(&SsaType::I8, false, 8), Some(ConstValue::I8(-128)) ); } @@ -1968,25 +405,13 @@ mod tests { fn test_convert_to_checked_overflow() { // Value overflows target let v = ConstValue::I32(1000); - assert_eq!( - v.convert_to_checked(&SsaType::I8, false, PointerSize::Bit64), - None - ); + assert_eq!(v.convert_to_checked(&SsaType::I8, false, 8), None); // Negative to unsigned let v = ConstValue::I32(-1); - assert_eq!( - v.convert_to_checked(&SsaType::U8, false, PointerSize::Bit64), - None - ); - assert_eq!( - v.convert_to_checked(&SsaType::U32, false, PointerSize::Bit64), - None - ); - assert_eq!( - v.convert_to_checked(&SsaType::U64, false, PointerSize::Bit64), - None - ); + assert_eq!(v.convert_to_checked(&SsaType::U8, false, 8), None); + assert_eq!(v.convert_to_checked(&SsaType::U32, false, 8), None); + assert_eq!(v.convert_to_checked(&SsaType::U64, false, 8), None); } #[test] @@ -1994,7 +419,7 @@ mod tests { // i32 -> char (u16) let v = ConstValue::I32(65); // 'A' assert_eq!( - v.convert_to(&SsaType::Char, false, PointerSize::Bit64), + v.convert_to(&SsaType::Char, false, 8), Some(ConstValue::U16(65)) ); } @@ -2004,14 +429,14 @@ mod tests { // i32 -> NativeInt let v = ConstValue::I32(42); assert_eq!( - v.convert_to(&SsaType::NativeInt, false, PointerSize::Bit64), + v.convert_to(&SsaType::NativeInt, false, 8), Some(ConstValue::NativeInt(42)) ); // i32 -> NativeUInt let v = ConstValue::I32(42); assert_eq!( - v.convert_to(&SsaType::NativeUInt, false, PointerSize::Bit64), + v.convert_to(&SsaType::NativeUInt, false, 8), Some(ConstValue::NativeUInt(42)) ); } diff --git a/dotscope/src/analysis/ssa/variable.rs b/dotscope/src/analysis/ssa/variable.rs deleted file mode 100644 index e80a46d1..00000000 --- a/dotscope/src/analysis/ssa/variable.rs +++ /dev/null @@ -1,713 +0,0 @@ -//! SSA variable representation and identifiers. -//! -//! This module defines the core types for representing variables in SSA form. -//! Each SSA variable has a unique identifier and is assigned exactly once, -//! enabling precise tracking of data flow through the program. -//! -//! # Design Rationale -//! -//! ## Variable Identification -//! -//! SSA variables are identified by a simple index ([`SsaVarId`]) into a variable -//! table. This provides O(1) lookup and minimal memory overhead. The ID encodes -//! no semantic information - all variable metadata is stored in [`SsaVariable`]. -//! -//! ## Variable Origins -//! -//! CIL has three primary sources of values that become SSA variables: -//! -//! 1. **Arguments** - Method parameters passed by the caller -//! 2. **Locals** - Local variables declared in the method -//! 3. **Stack temporaries** - Values pushed/popped during evaluation -//! -//! Additionally, phi nodes at control flow merge points create new variables. -//! -//! ## Address-Taken Variables -//! -//! Variables whose address is taken (`ldarga`, `ldloca`) are marked specially. -//! These variables may be modified through pointers and thus cannot participate -//! in certain SSA optimizations. We track this conservatively. -//! -//! # Thread Safety -//! -//! All types in this module are `Send` and `Sync` when their generic parameters -//! (if any) are also `Send` and `Sync`. - -use std::fmt; - -use crate::analysis::ssa::SsaType; - -/// Unique identifier for an SSA variable. -/// -/// This is a lightweight handle into the variable table, providing O(1) access -/// to variable metadata. Variable IDs are dense and sequential within each -/// [`SsaFunction`](crate::analysis::SsaFunction) (0, 1, 2, ...), enabling -/// direct indexing into the variables vector. -/// -/// # Memory Layout -/// -/// Uses `usize` internally to match native indexing, avoiding conversions -/// when accessing variable tables. -/// -/// # Construction -/// -/// Variable IDs are allocated by [`FunctionVarAllocator`] through -/// [`SsaFunction::create_variable()`](crate::analysis::SsaFunction::create_variable). -/// Use [`SsaVarId::from_index()`] only to reconstruct IDs from stored indices. -#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] -pub struct SsaVarId(usize); - -impl SsaVarId { - /// A sentinel value representing an uninitialized or placeholder variable ID. - /// - /// This is used during phi placement and other construction phases where a - /// real variable ID hasn't been assigned yet. Placeholder IDs must be replaced - /// with real IDs before the SSA function is finalized. - pub const PLACEHOLDER: Self = Self(usize::MAX); - - /// Returns `true` if this is the placeholder sentinel value. - #[must_use] - pub const fn is_placeholder(self) -> bool { - self.0 == usize::MAX - } - - /// Creates an `SsaVarId` from an index value. - /// - /// This is the primary way to construct variable IDs. In production code, - /// IDs are allocated by [`FunctionVarAllocator`] to ensure dense, sequential - /// numbering within each function. This method is also used to reconstruct - /// IDs from stored indices (e.g., in BitSets). - /// - /// # Arguments - /// - /// * `index` - The index value for this variable ID - #[must_use] - pub const fn from_index(index: usize) -> Self { - Self(index) - } - - /// Returns the underlying index. - /// - /// In production code, this index is dense and contiguous within a function - /// (0, 1, 2, ...), enabling O(1) lookup via `variables[id.index()]`. - #[must_use] - pub const fn index(self) -> usize { - self.0 - } -} - -impl fmt::Debug for SsaVarId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "v{}", self.0) - } -} - -impl fmt::Display for SsaVarId { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "v{}", self.0) - } -} - -/// The origin of an SSA variable - where it came from in the original CIL. -/// -/// This enum tracks the semantic source of each SSA variable, which is useful -/// for debugging, optimization decisions, and mapping back to source code. -/// -/// # CIL Variable Mapping -/// -/// | CIL Instruction | Variable Origin | -/// |-----------------|-----------------| -/// | `ldarg.N`, `ldarg.s`, `ldarg` | `Argument(N)` | -/// | `ldloc.N`, `ldloc.s`, `ldloc` | `Local(N)` | -/// | Stack operations (add, call, etc.) | `Local(num_locals + K)` | -/// | Phi node result | `Phi` | -/// -/// # Examples -/// -/// ```rust,no_run -/// use dotscope::analysis::VariableOrigin; -/// -/// let arg_origin = VariableOrigin::Argument(0); // First method argument -/// let local_origin = VariableOrigin::Local(2); // Third local variable -/// let phi_origin = VariableOrigin::Phi; // From phi node -/// ``` -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] -pub enum VariableOrigin { - /// Method argument (parameter). - /// - /// The index corresponds to the argument's position in the method signature. - /// For instance methods, argument 0 is `this`. - Argument(u16), - - /// Local variable declared in the method. - /// - /// The index corresponds to the local's position in the local variable - /// signature (accessed via `ldloc`/`stloc`). `Local(idx)` always refers - /// to a real CIL local. Stack temporaries and other synthetics use - /// `Phi` origin instead. - Local(u16), - - /// Result of a phi node at a control flow merge. - /// - /// Phi nodes are synthetic - they don't correspond to any CIL instruction - /// but rather represent the merging of values from different control flow paths. - Phi, -} - -impl VariableOrigin { - /// Returns `true` if this is an argument origin. - #[must_use] - pub const fn is_argument(&self) -> bool { - matches!(self, Self::Argument(_)) - } - - /// Returns `true` if this is a local variable origin. - #[must_use] - pub const fn is_local(&self) -> bool { - matches!(self, Self::Local(_)) - } - - /// Returns `true` if this is a phi node result. - #[must_use] - pub const fn is_phi(&self) -> bool { - matches!(self, Self::Phi) - } - - /// Returns the argument index if this is an argument origin. - #[must_use] - pub const fn argument_index(&self) -> Option { - match self { - Self::Argument(idx) => Some(*idx), - _ => None, - } - } - - /// Returns the local index if this is a local variable origin. - #[must_use] - pub const fn local_index(&self) -> Option { - match self { - Self::Local(idx) => Some(*idx), - _ => None, - } - } -} - -impl fmt::Display for VariableOrigin { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Argument(idx) => write!(f, "arg{idx}"), - Self::Local(idx) => write!(f, "loc{idx}"), - Self::Phi => write!(f, "phi"), - } - } -} - -/// Per-function allocator for dense, contiguous SSA variable IDs. -/// -/// Unlike the global `SsaVarId::from_index(0)` counter, this allocator produces IDs -/// starting from 0 that are contiguous within a single function. This enables -/// O(1) variable lookup via direct vector indexing: `variables[id.index()]`. -/// -/// # Usage -/// -/// ```rust,ignore -/// let mut alloc = FunctionVarAllocator::new(); -/// let id0 = alloc.alloc(); // SsaVarId(0) -/// let id1 = alloc.alloc(); // SsaVarId(1) -/// assert_eq!(alloc.count(), 2); -/// ``` -#[derive(Debug, Clone)] -pub struct FunctionVarAllocator { - next_id: usize, -} - -impl FunctionVarAllocator { - /// Creates a new allocator starting from ID 0. - #[must_use] - pub fn new() -> Self { - Self { next_id: 0 } - } - - /// Creates a new allocator starting from a specific ID. - /// - /// Used when resuming allocation after compaction or when - /// variables already exist with IDs 0..start_id. - #[must_use] - pub fn starting_from(start_id: usize) -> Self { - Self { next_id: start_id } - } - - /// Allocates the next dense variable ID. - pub fn alloc(&mut self) -> SsaVarId { - let id = SsaVarId::from_index(self.next_id); - self.next_id = self.next_id.saturating_add(1); - id - } - - /// Returns the number of IDs allocated so far. - #[must_use] - pub fn count(&self) -> usize { - self.next_id - } -} - -impl Default for FunctionVarAllocator { - fn default() -> Self { - Self::new() - } -} - -/// Definition site of an SSA variable. -/// -/// Records where in the program a variable is defined. For most variables, -/// this is a specific instruction within a block. For phi nodes, the definition -/// is at the block entry. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct DefSite { - /// The block where this variable is defined. - pub block: usize, - /// The instruction index within the block, or `None` for phi nodes. - /// - /// Phi nodes are considered to be defined at the "top" of the block, - /// before any real instructions execute. - pub instruction: Option, -} - -impl DefSite { - /// Creates a definition site for a regular instruction. - #[must_use] - pub const fn instruction(block: usize, instr_idx: usize) -> Self { - Self { - block, - instruction: Some(instr_idx), - } - } - - /// Creates a definition site for a phi node (at block entry). - #[must_use] - pub const fn phi(block: usize) -> Self { - Self { - block, - instruction: None, - } - } - - /// Creates a definition site for function entry (arguments and initialized locals). - /// - /// These are defined at the entry block (block 0) before any instructions. - #[must_use] - pub const fn entry() -> Self { - Self { - block: 0, - instruction: None, - } - } - - /// Returns `true` if this is a phi node definition. - #[must_use] - pub const fn is_phi(&self) -> bool { - self.instruction.is_none() - } -} - -/// Use site of an SSA variable. -/// -/// Records where in the program a variable is used (read). -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct UseSite { - /// The block where this variable is used. - pub block: usize, - /// The instruction index within the block. - /// - /// For phi node operands, this refers to the phi node's index in the - /// block's phi node list (not the instruction list). - pub instruction: usize, - /// Whether this use is in a phi node operand. - pub is_phi_operand: bool, -} - -impl UseSite { - /// Creates a use site for a regular instruction. - #[must_use] - pub const fn instruction(block: usize, instr_idx: usize) -> Self { - Self { - block, - instruction: instr_idx, - is_phi_operand: false, - } - } - - /// Creates a use site for a phi node operand. - #[must_use] - pub const fn phi_operand(block: usize, phi_idx: usize) -> Self { - Self { - block, - instruction: phi_idx, - is_phi_operand: true, - } - } -} - -/// Complete metadata for an SSA variable. -/// -/// Each SSA variable has exactly one definition point and zero or more use -/// points. This structure tracks all the metadata needed for analysis and -/// optimization. -/// -/// # Construction -/// -/// Variables are created exclusively through [`SsaFunction::create_variable()`](crate::analysis::SsaFunction::create_variable), -/// which ensures dense ID allocation and proper type assignment. -#[derive(Debug, Clone)] -pub struct SsaVariable { - /// Unique identifier for this variable. - id: SsaVarId, - - /// Where this variable originated in the CIL. - origin: VariableOrigin, - - /// SSA version number for this variable. - /// - /// For arguments and locals, multiple versions exist (one per assignment). - /// The version number distinguishes between them. Version 0 is typically - /// the initial value at method entry. - version: u32, - - /// Where this variable is defined. - def_site: DefSite, - - /// The type of this variable. - /// - /// This is inferred from the operation that defines the variable. - /// Initially `SsaType::Unknown` if type inference hasn't been performed. - var_type: SsaType, - - /// All places where this variable is used. - /// - /// This is computed during SSA construction and enables dead code - /// elimination and other use-based analyses. - uses: Vec, - - /// Whether this variable's address has been taken. - /// - /// If `true`, this variable may be modified through a pointer and - /// cannot participate in certain optimizations. Set when `ldarga` - /// or `ldloca` is encountered for the corresponding argument/local. - address_taken: bool, -} - -impl SsaVariable { - /// Creates a new SSA variable with a pre-allocated ID and type. - /// - /// This is `pub(crate)` because variables should only be created through - /// [`SsaFunction::create_variable()`](crate::analysis::SsaFunction::create_variable) - /// which ensures dense ID allocation via [`FunctionVarAllocator`]. - /// - /// # Arguments - /// - /// * `id` - The dense variable ID from [`FunctionVarAllocator`] - /// * `origin` - Where this variable came from in the CIL - /// * `version` - SSA version number - /// * `def_site` - Where this variable is defined - /// * `var_type` - The type of this variable - #[must_use] - pub(crate) fn new( - id: SsaVarId, - origin: VariableOrigin, - version: u32, - def_site: DefSite, - var_type: SsaType, - ) -> Self { - Self { - id, - origin, - version, - def_site, - var_type, - uses: Vec::new(), - address_taken: false, - } - } - - /// Returns the variable's unique identifier. - #[must_use] - pub const fn id(&self) -> SsaVarId { - self.id - } - - /// Returns where this variable originated in the CIL. - #[must_use] - pub const fn origin(&self) -> VariableOrigin { - self.origin - } - - /// Returns the SSA version number. - #[must_use] - pub const fn version(&self) -> u32 { - self.version - } - - /// Returns where this variable is defined. - #[must_use] - pub const fn def_site(&self) -> DefSite { - self.def_site - } - - /// Returns the type of this variable. - /// - /// Returns `SsaType::Unknown` if type inference hasn't been performed. - #[must_use] - pub fn var_type(&self) -> &SsaType { - &self.var_type - } - - /// Updates where this variable is defined. - pub fn set_def_site(&mut self, site: DefSite) { - self.def_site = site; - } - - /// Sets the type of this variable. - /// - /// This is typically called during type inference or when resolving - /// phi node types. - pub fn set_type(&mut self, var_type: SsaType) { - self.var_type = var_type; - } - - /// Returns `true` if the variable's type is known (not Unknown). - #[must_use] - pub fn has_known_type(&self) -> bool { - !matches!(self.var_type, SsaType::Unknown) - } - - /// Returns all use sites for this variable. - #[must_use] - pub fn uses(&self) -> &[UseSite] { - &self.uses - } - - /// Returns `true` if this variable's address has been taken. - #[must_use] - pub const fn is_address_taken(&self) -> bool { - self.address_taken - } - - /// Returns `true` if this variable has no uses (dead). - #[must_use] - pub fn is_dead(&self) -> bool { - self.uses.is_empty() - } - - /// Returns the number of uses for this variable. - #[must_use] - pub fn use_count(&self) -> usize { - self.uses.len() - } - - /// Adds a use site for this variable. - pub fn add_use(&mut self, use_site: UseSite) { - self.uses.push(use_site); - } - - /// Clears all use sites for this variable. - /// - /// This is used when recomputing use information after SSA transformations - /// that may have invalidated the use tracking. - pub fn clear_uses(&mut self) { - self.uses.clear(); - } - - /// Marks this variable as having its address taken. - pub fn set_address_taken(&mut self) { - self.address_taken = true; - } - - /// Sets the origin of this variable. - /// - /// This is used during local variable optimization to update indices - /// after unused locals are removed. - pub fn set_origin(&mut self, origin: VariableOrigin) { - self.origin = origin; - } - - /// Sets the variable's ID. - /// - /// Used during variable compaction to reassign dense IDs. - pub fn set_id(&mut self, id: SsaVarId) { - self.id = id; - } -} - -impl fmt::Display for SsaVariable { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "{}_{}", self.origin, self.version) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use crate::analysis::ssa::SsaType; - - #[test] - fn test_ssa_var_id_creation() { - let id = SsaVarId::from_index(42); - assert_eq!(id.index(), 42); - } - - #[test] - fn test_ssa_var_id_display() { - let id = SsaVarId::from_index(0); - let expected = format!("v{}", id.index()); - assert_eq!(format!("{id}"), expected); - assert_eq!(format!("{id:?}"), expected); - } - - #[test] - fn test_ssa_var_id_equality() { - let id1 = SsaVarId::from_index(10); - let id2 = SsaVarId::from_index(10); - let id3 = SsaVarId::from_index(20); - - assert_eq!(id1, id2); - assert_ne!(id1, id3); - } - - #[test] - fn test_variable_origin_argument() { - let origin = VariableOrigin::Argument(0); - assert!(origin.is_argument()); - assert!(!origin.is_local()); - assert!(!origin.is_phi()); - assert_eq!(origin.argument_index(), Some(0)); - assert_eq!(origin.local_index(), None); - assert_eq!(format!("{origin}"), "arg0"); - } - - #[test] - fn test_variable_origin_local() { - let origin = VariableOrigin::Local(3); - assert!(!origin.is_argument()); - assert!(origin.is_local()); - assert!(!origin.is_phi()); - assert_eq!(origin.argument_index(), None); - assert_eq!(origin.local_index(), Some(3)); - assert_eq!(format!("{origin}"), "loc3"); - } - - #[test] - fn test_variable_origin_phi() { - let origin = VariableOrigin::Phi; - assert!(!origin.is_argument()); - assert!(!origin.is_local()); - assert!(origin.is_phi()); - assert_eq!(format!("{origin}"), "phi"); - } - - #[test] - fn test_def_site_instruction() { - let site = DefSite::instruction(2, 5); - assert_eq!(site.block, 2); - assert_eq!(site.instruction, Some(5)); - assert!(!site.is_phi()); - } - - #[test] - fn test_def_site_phi() { - let site = DefSite::phi(3); - assert_eq!(site.block, 3); - assert_eq!(site.instruction, None); - assert!(site.is_phi()); - } - - #[test] - fn test_use_site_instruction() { - let site = UseSite::instruction(1, 4); - assert_eq!(site.block, 1); - assert_eq!(site.instruction, 4); - assert!(!site.is_phi_operand); - } - - #[test] - fn test_use_site_phi_operand() { - let site = UseSite::phi_operand(2, 0); - assert_eq!(site.block, 2); - assert_eq!(site.instruction, 0); - assert!(site.is_phi_operand); - } - - #[test] - fn test_ssa_variable_creation() { - let var = SsaVariable::new( - SsaVarId::from_index(0), - VariableOrigin::Argument(0), - 0, - DefSite::phi(0), - SsaType::Unknown, - ); - - // ID is now auto-allocated - assert_eq!(var.origin(), VariableOrigin::Argument(0)); - assert_eq!(var.version(), 0); - assert!(var.def_site().is_phi()); - assert!(var.uses().is_empty()); - assert!(!var.is_address_taken()); - assert!(var.is_dead()); - } - - #[test] - fn test_ssa_variable_add_use() { - let mut var = SsaVariable::new( - SsaVarId::from_index(0), - VariableOrigin::Local(0), - 1, - DefSite::instruction(0, 0), - SsaType::Unknown, - ); - - assert!(var.is_dead()); - - var.add_use(UseSite::instruction(0, 5)); - var.add_use(UseSite::instruction(1, 2)); - - assert!(!var.is_dead()); - assert_eq!(var.uses().len(), 2); - } - - #[test] - fn test_ssa_variable_address_taken() { - let mut var = SsaVariable::new( - SsaVarId::from_index(0), - VariableOrigin::Local(1), - 0, - DefSite::phi(0), - SsaType::Unknown, - ); - - assert!(!var.is_address_taken()); - var.set_address_taken(); - assert!(var.is_address_taken()); - } - - #[test] - fn test_ssa_variable_display() { - let var = SsaVariable::new( - SsaVarId::from_index(0), - VariableOrigin::Argument(2), - 3, - DefSite::phi(0), - SsaType::Unknown, - ); - assert_eq!(format!("{var}"), "arg2_3"); - - let var2 = SsaVariable::new( - SsaVarId::from_index(1), - VariableOrigin::Local(0), - 1, - DefSite::instruction(1, 2), - SsaType::Unknown, - ); - assert_eq!(format!("{var2}"), "loc0_1"); - } -} diff --git a/dotscope/src/analysis/ssa/verifier.rs b/dotscope/src/analysis/ssa/verifier.rs deleted file mode 100644 index 3e0e634c..00000000 --- a/dotscope/src/analysis/ssa/verifier.rs +++ /dev/null @@ -1,1011 +0,0 @@ -//! SSA verifier for validating SSA invariants. -//! -//! Provides comprehensive verification of SSA form at three levels: -//! -//! - **Quick**: Single-definition property + block structure (O(n)) -//! - **Standard**: + def-use chains + phi operand coverage (O(n*m)) -//! - **Full**: + dominance checking (O(n^2) worst case) -//! -//! The verifier is the safety net for all SSA transformations. It catches -//! invariant violations early, preventing silent corruption that would -//! manifest as broken codegen or incorrect deobfuscation. - -use std::collections::HashMap; - -use crate::{ - analysis::ssa::{cfg::SsaCfg, DefSite, SsaFunction, SsaVarId}, - utils::{ - graph::{ - algorithms::{compute_dominators, DominatorTree}, - NodeId, RootedGraph, - }, - BitSet, - }, -}; - -/// Definition site for verifier error reporting. -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct VerifierDefSite { - pub block: usize, - pub kind: DefKind, -} - -/// What kind of definition produced a variable. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum DefKind { - Phi(usize), - Instruction(usize), -} - -/// Errors detected by the SSA verifier. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum VerifierError { - /// A variable is used but never defined. - UndefinedUse { - block: usize, - instr_idx: usize, - var: SsaVarId, - }, - /// A phi node is missing an operand for a CFG predecessor. - MissingPhiOperand { - block: usize, - phi_idx: usize, - missing_pred: usize, - }, - /// A phi node has an operand for a non-predecessor block. - ExtraPhiOperand { - block: usize, - phi_idx: usize, - extra_pred: usize, - }, - /// A variable is defined more than once. - DuplicateDefinition { - var: SsaVarId, - def1: VerifierDefSite, - def2: VerifierDefSite, - }, - /// A variable exists in the variables vec but has no definition in any block. - OrphanVariable { var: SsaVarId }, - /// A variable appears in an instruction but is not in the variables vec. - UnregisteredVariable { var: SsaVarId }, - /// A block has successors but no terminator instruction. - MissingTerminator { block: usize }, - /// A phi node appears in the entry block (block 0), which has no predecessors - /// in a well-formed CFG. - PhiInEntryBlock { block: usize, phi_idx: usize }, - /// A variable is used in a block not dominated by its definition block. - DominanceViolation { - var: SsaVarId, - def_block: usize, - use_block: usize, - }, - /// A terminator instruction is not the last instruction in its block. - TerminatorNotLast { - block: usize, - instr_idx: usize, - instr_count: usize, - }, - /// An instruction uses a variable defined later in the same block (cycle). - IntraBlockCycle { - block: usize, - use_instr: usize, - def_instr: usize, - var: SsaVarId, - }, - /// A placeholder variable ID (usize::MAX) remains in finalized SSA. - PlaceholderVariable { block: usize, location: String }, - /// An instruction's destination appears in its own operands (self-referential). - SelfReferentialInstruction { - block: usize, - instr_idx: usize, - var: SsaVarId, - }, -} - -impl std::fmt::Display for VerifierError { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Self::UndefinedUse { - block, - instr_idx, - var, - } => write!( - f, - "Block {block}: instruction {instr_idx} uses undefined variable {var:?}" - ), - Self::MissingPhiOperand { - block, - phi_idx, - missing_pred, - } => write!( - f, - "Block {block}: phi {phi_idx} missing operand for predecessor {missing_pred}" - ), - Self::ExtraPhiOperand { - block, - phi_idx, - extra_pred, - } => write!( - f, - "Block {block}: phi {phi_idx} has operand for non-predecessor {extra_pred}" - ), - Self::DuplicateDefinition { var, def1, def2 } => write!( - f, - "Variable {var:?} defined twice: at block {} ({:?}) and block {} ({:?})", - def1.block, def1.kind, def2.block, def2.kind - ), - Self::OrphanVariable { var } => { - write!(f, "Variable {var:?} in variables vec but not defined in any block") - } - Self::UnregisteredVariable { var } => write!( - f, - "Variable {var:?} used in instruction but not in variables vec" - ), - Self::MissingTerminator { block } => { - write!(f, "Block {block}: has successors but no terminator") - } - Self::PhiInEntryBlock { block, phi_idx } => { - write!(f, "Block {block}: phi {phi_idx} in entry block") - } - Self::DominanceViolation { - var, - def_block, - use_block, - } => write!( - f, - "Variable {var:?}: def in block {def_block} does not dominate use in block {use_block}" - ), - Self::TerminatorNotLast { - block, - instr_idx, - instr_count, - } => write!( - f, - "Block {block}: terminator at position {instr_idx}/{instr_count} is not last" - ), - Self::IntraBlockCycle { - block, - use_instr, - def_instr, - var, - } => write!( - f, - "Block {block}: instruction {use_instr} uses {var:?} defined at instruction {def_instr}" - ), - Self::PlaceholderVariable { block, location } => write!( - f, - "Block {block}: placeholder variable ID (usize::MAX) at {location}" - ), - Self::SelfReferentialInstruction { - block, - instr_idx, - var, - } => write!( - f, - "Block {block}: instruction {instr_idx} has self-referential use of {var:?}" - ), - } - } -} - -impl std::error::Error for VerifierError {} - -/// Verification depth levels. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub enum VerifyLevel { - /// Single-definition + block structure checks (O(n)). - Quick, - /// + def-use chains + phi operand coverage (O(n*m)). - Standard, - /// + dominance checking (O(n^2) worst case). - Full, -} - -/// SSA verifier that validates invariants at configurable depth. -pub struct SsaVerifier<'a> { - ssa: &'a SsaFunction, - errors: Vec, -} - -impl<'a> SsaVerifier<'a> { - /// Creates a new verifier for the given SSA function. - #[must_use] - pub fn new(ssa: &'a SsaFunction) -> Self { - Self { - ssa, - errors: Vec::new(), - } - } - - /// Runs verification at the specified level and returns all errors found. - pub fn verify(mut self, level: VerifyLevel) -> Vec { - self.errors.clear(); - - // Quick checks (always run) - self.check_single_definition(); - self.check_block_structure(); - self.check_no_placeholders_or_self_refs(); - - if level >= VerifyLevel::Standard { - let cfg = SsaCfg::from_ssa(self.ssa); - let definitions = self.collect_definitions(); - self.check_phi_operands(&cfg); - self.check_defined_before_use(&definitions); - self.check_registered_variables(); - - if level >= VerifyLevel::Full { - let dom_tree = compute_dominators(&cfg, cfg.entry()); - self.check_dominance(&cfg, &dom_tree, &definitions); - } - } - - self.errors - } - - /// Verifies that every variable is defined at most once (the fundamental SSA property). - fn check_single_definition(&mut self) { - let mut definitions: HashMap = HashMap::new(); - - for (block_idx, block) in self.ssa.blocks().iter().enumerate() { - for (phi_idx, phi) in block.phi_nodes().iter().enumerate() { - let var = phi.result(); - let site = VerifierDefSite { - block: block_idx, - kind: DefKind::Phi(phi_idx), - }; - if let Some(prev) = definitions.get(&var) { - self.errors.push(VerifierError::DuplicateDefinition { - var, - def1: prev.clone(), - def2: site, - }); - } else { - definitions.insert(var, site); - } - } - - for (instr_idx, instr) in block.instructions().iter().enumerate() { - if let Some(dest) = instr.op().dest() { - let site = VerifierDefSite { - block: block_idx, - kind: DefKind::Instruction(instr_idx), - }; - if let Some(prev) = definitions.get(&dest) { - self.errors.push(VerifierError::DuplicateDefinition { - var: dest, - def1: prev.clone(), - def2: site, - }); - } else { - definitions.insert(dest, site); - } - } - } - } - } - - /// Checks block structural invariants: - /// - Every block with successors has a terminator - /// - Terminators are the last instruction - /// - No intra-block cycles (use before def) - fn check_block_structure(&mut self) { - for (block_idx, block) in self.ssa.blocks().iter().enumerate() { - let instrs = block.instructions(); - let instr_count = instrs.len(); - - // Check terminator placement - for (instr_idx, instr) in instrs.iter().enumerate() { - if instr.op().is_terminator() && instr_idx < instr_count.saturating_sub(1) { - self.errors.push(VerifierError::TerminatorNotLast { - block: block_idx, - instr_idx, - instr_count, - }); - } - } - - // Check for intra-block use-before-def cycles - let mut def_indices: HashMap = HashMap::new(); - for (instr_idx, instr) in instrs.iter().enumerate() { - if let Some(dest) = instr.op().dest() { - def_indices.insert(dest, instr_idx); - } - } - - for (instr_idx, instr) in instrs.iter().enumerate() { - for used_var in instr.op().uses() { - if let Some(&def_idx) = def_indices.get(&used_var) { - if def_idx >= instr_idx { - self.errors.push(VerifierError::IntraBlockCycle { - block: block_idx, - use_instr: instr_idx, - def_instr: def_idx, - var: used_var, - }); - } - } - } - } - } - } - - /// Collects all variable definitions into a map: var_id -> (block, def_site). - fn collect_definitions(&self) -> HashMap { - let mut defs: HashMap = HashMap::new(); - - // Variables from the variables vec (includes entry-block defs for args/locals) - for var in self.ssa.variables() { - defs.insert(var.id(), (var.def_site().block, var.def_site())); - } - - // Also collect from actual block contents (may differ after transforms) - for (block_idx, block) in self.ssa.blocks().iter().enumerate() { - for phi in block.phi_nodes() { - defs.entry(phi.result()) - .or_insert((block_idx, DefSite::phi(block_idx))); - } - for (instr_idx, instr) in block.instructions().iter().enumerate() { - if let Some(dest) = instr.op().dest() { - defs.entry(dest) - .or_insert((block_idx, DefSite::instruction(block_idx, instr_idx))); - } - } - } - - defs - } - - /// Checks that every phi node has the correct operand set: - /// - One operand per CFG predecessor - /// - No operands from non-predecessor blocks - /// - No phis in the entry block (which has no predecessors) - fn check_phi_operands(&mut self, cfg: &SsaCfg<'_>) { - let block_count = self.ssa.block_count(); - for (block_idx, block) in self.ssa.blocks().iter().enumerate() { - let pred_list = cfg.block_predecessors(block_idx); - // Capacity must cover both actual predecessors and phi operand predecessors - // (which may reference non-existent blocks in malformed SSA) - let max_phi_pred = block - .phi_nodes() - .iter() - .flat_map(|phi| phi.operands().iter().map(|op| op.predecessor())) - .max() - .unwrap_or(0); - let capacity = block_count.max(max_phi_pred.saturating_add(1)).max(1); - let mut preds = BitSet::new(capacity); - for &p in pred_list { - if p < capacity { - preds.insert(p); - } - } - - for (phi_idx, phi) in block.phi_nodes().iter().enumerate() { - // Entry block should not have phis (no predecessors) - if block_idx == 0 && preds.is_empty() { - self.errors.push(VerifierError::PhiInEntryBlock { - block: block_idx, - phi_idx, - }); - continue; - } - - let mut operand_preds = BitSet::new(capacity); - for op in phi.operands() { - let pred = op.predecessor(); - operand_preds.insert(pred); - } - - // Check for missing predecessors - for pred in preds.iter() { - if !operand_preds.contains(pred) { - self.errors.push(VerifierError::MissingPhiOperand { - block: block_idx, - phi_idx, - missing_pred: pred, - }); - } - } - - // Check for extra (non-predecessor) operands - for op_pred in operand_preds.iter() { - if !preds.contains(op_pred) { - self.errors.push(VerifierError::ExtraPhiOperand { - block: block_idx, - phi_idx, - extra_pred: op_pred, - }); - } - } - } - } - } - - /// Checks that every variable used in an instruction or phi operand is defined - /// somewhere (either in the variables vec or in a block). - fn check_defined_before_use(&mut self, definitions: &HashMap) { - for (block_idx, block) in self.ssa.blocks().iter().enumerate() { - for (instr_idx, instr) in block.instructions().iter().enumerate() { - for used_var in instr.op().uses() { - if !definitions.contains_key(&used_var) { - self.errors.push(VerifierError::UndefinedUse { - block: block_idx, - instr_idx, - var: used_var, - }); - } - } - } - } - } - - /// Checks that every variable used in blocks is registered in the variables vec. - fn check_registered_variables(&mut self) { - let variable_count = self.ssa.variable_count(); - // Capacity must cover all variable IDs that appear in blocks (may exceed variable_count) - let max_block_var = self - .ssa - .blocks() - .iter() - .flat_map(|b| { - let phi_ids = b.phi_nodes().iter().map(|p| p.result().index()); - let instr_ids = b - .instructions() - .iter() - .filter_map(|i| i.op().dest().map(|d| d.index())); - phi_ids.chain(instr_ids) - }) - .max() - .unwrap_or(0); - let max_reg_var = self - .ssa - .variables() - .iter() - .map(|v| v.id().index()) - .max() - .unwrap_or(0); - let capacity = max_block_var - .saturating_add(1) - .max(max_reg_var.saturating_add(1)) - .max(variable_count) - .max(1); - let mut registered = BitSet::new(capacity); - for v in self.ssa.variables() { - registered.insert(v.id().index()); - } - - // Check variables defined in blocks but not in variables vec - for block in self.ssa.blocks() { - for phi in block.phi_nodes() { - let idx = phi.result().index(); - if idx >= capacity || !registered.contains(idx) { - self.errors - .push(VerifierError::UnregisteredVariable { var: phi.result() }); - } - } - for instr in block.instructions() { - if let Some(dest) = instr.op().dest() { - let idx = dest.index(); - if idx >= capacity || !registered.contains(idx) { - self.errors - .push(VerifierError::UnregisteredVariable { var: dest }); - } - } - } - } - - // Check for orphan variables (in variables vec but not defined in any block) - let mut block_defined = BitSet::new(capacity); - for block in self.ssa.blocks() { - for phi in block.phi_nodes() { - let idx = phi.result().index(); - if idx < capacity { - block_defined.insert(idx); - } - } - for instr in block.instructions() { - if let Some(dest) = instr.op().dest() { - let idx = dest.index(); - if idx < capacity { - block_defined.insert(idx); - } - } - } - } - - for var in self.ssa.variables() { - // Version 0 entry-point variables are defined at function entry, not - // in blocks. This includes args, locals, and Phi-origin placeholder - // variables created during SSA rebuild for stack temp groups. - if var.version() == 0 && var.def_site().instruction.is_none() { - continue; - } - if !block_defined.contains(var.id().index()) { - self.errors - .push(VerifierError::OrphanVariable { var: var.id() }); - } - } - } - - /// Checks for placeholder variable IDs (usize::MAX) that should have been - /// replaced during construction. Also checks for self-referential instructions - /// where an instruction's destination appears in its own operands. - fn check_no_placeholders_or_self_refs(&mut self) { - for (block_idx, block) in self.ssa.blocks().iter().enumerate() { - // Check phi nodes for placeholder IDs - for (phi_idx, phi) in block.phi_nodes().iter().enumerate() { - if phi.result().is_placeholder() { - self.errors.push(VerifierError::PlaceholderVariable { - block: block_idx, - location: format!("phi {phi_idx} result"), - }); - } - for operand in phi.operands() { - if operand.value().is_placeholder() { - self.errors.push(VerifierError::PlaceholderVariable { - block: block_idx, - location: format!( - "phi {phi_idx} operand from B{}", - operand.predecessor() - ), - }); - } - } - } - - // Check instructions for placeholder IDs and self-referential uses - for (instr_idx, instr) in block.instructions().iter().enumerate() { - let op = instr.op(); - if let Some(dest) = op.dest() { - if dest.is_placeholder() { - self.errors.push(VerifierError::PlaceholderVariable { - block: block_idx, - location: format!("instruction {instr_idx} dest"), - }); - } - // Check for self-referential instruction (dest appears in uses) - if op.uses().contains(&dest) { - self.errors.push(VerifierError::SelfReferentialInstruction { - block: block_idx, - instr_idx, - var: dest, - }); - } - } - for used_var in op.uses() { - if used_var.is_placeholder() { - self.errors.push(VerifierError::PlaceholderVariable { - block: block_idx, - location: format!("instruction {instr_idx} operand"), - }); - } - } - } - } - } - - /// Checks dominance: every use of a variable must be dominated by its definition. - /// - /// For phi operands, the use is considered to be at the end of the predecessor - /// block (not at the phi's block), following standard SSA semantics. - fn check_dominance( - &mut self, - cfg: &SsaCfg<'_>, - dom_tree: &DominatorTree, - definitions: &HashMap, - ) { - // Compute reachable blocks - let block_count = self.ssa.block_count().max(1); - let mut reachable = BitSet::new(block_count); - let mut worklist = vec![0usize]; - while let Some(block_idx) = worklist.pop() { - if block_idx < block_count && reachable.insert(block_idx) { - for &succ in cfg.block_successors(block_idx) { - if succ < block_count { - worklist.push(succ); - } - } - } - } - - // Check instruction uses - for (block_idx, block) in self.ssa.blocks().iter().enumerate() { - if !reachable.contains(block_idx) { - continue; - } - - for instr in block.instructions() { - for used_var in instr.op().uses() { - if let Some(&(def_block, _)) = definitions.get(&used_var) { - if !reachable.contains(def_block) { - continue; - } - // Definition must dominate use block - let def_node = NodeId::new(def_block); - let use_node = NodeId::new(block_idx); - if def_node.index() < dom_tree.node_count() - && use_node.index() < dom_tree.node_count() - && !dom_tree.dominates(def_node, use_node) - { - self.errors.push(VerifierError::DominanceViolation { - var: used_var, - def_block, - use_block: block_idx, - }); - } - } - } - } - - // Check phi operand uses: the use is at the end of the predecessor - for phi in block.phi_nodes() { - for operand in phi.operands() { - let used_var = operand.value(); - let pred_block = operand.predecessor(); - if let Some(&(def_block, _)) = definitions.get(&used_var) { - if !reachable.contains(def_block) || !reachable.contains(pred_block) { - continue; - } - // Definition must dominate the predecessor block - let def_node = NodeId::new(def_block); - let pred_node = NodeId::new(pred_block); - if def_node.index() < dom_tree.node_count() - && pred_node.index() < dom_tree.node_count() - && !dom_tree.dominates(def_node, pred_node) - { - self.errors.push(VerifierError::DominanceViolation { - var: used_var, - def_block, - use_block: pred_block, - }); - } - } - } - } - } - } -} - -#[cfg(test)] -mod tests { - use super::*; - - use crate::analysis::ssa::{ - ConstValue, DefSite, PhiNode, PhiOperand, SsaBlock, SsaFunction, SsaInstruction, SsaOp, - SsaType, SsaVarId, VariableOrigin, - }; - - /// Helper: create a minimal SSA function for testing. - fn make_empty_ssa() -> SsaFunction { - SsaFunction::new(0, 0) - } - - /// Helper: create an SSA with a single block containing a return. - fn make_single_block_ssa() -> SsaFunction { - let mut ssa = SsaFunction::new(0, 0); - let mut block = SsaBlock::new(0); - block.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - ssa.add_block(block); - ssa - } - - #[test] - fn test_empty_ssa_passes_all_levels() { - let ssa = make_empty_ssa(); - assert!(SsaVerifier::new(&ssa).verify(VerifyLevel::Quick).is_empty()); - assert!(SsaVerifier::new(&ssa) - .verify(VerifyLevel::Standard) - .is_empty()); - assert!(SsaVerifier::new(&ssa).verify(VerifyLevel::Full).is_empty()); - } - - #[test] - fn test_single_block_passes() { - let ssa = make_single_block_ssa(); - assert!(SsaVerifier::new(&ssa) - .verify(VerifyLevel::Standard) - .is_empty()); - } - - #[test] - fn test_duplicate_definition_detected() { - let mut ssa = SsaFunction::new(0, 0); - let var_id = SsaVarId::from_index(0); - - let mut block = SsaBlock::new(0); - // Define the same var twice - block.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: var_id, - value: ConstValue::I32(1), - })); - block.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: var_id, - value: ConstValue::I32(2), - })); - block.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - ssa.add_block(block); - - let errors = SsaVerifier::new(&ssa).verify(VerifyLevel::Quick); - assert!( - errors.iter().any( - |e| matches!(e, VerifierError::DuplicateDefinition { var, .. } if *var == var_id) - ), - "should detect duplicate definition: {errors:?}" - ); - } - - #[test] - fn test_terminator_not_last_detected() { - let mut ssa = SsaFunction::new(0, 0); - let mut block = SsaBlock::new(0); - // Terminator followed by another instruction - block.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - block.add_instruction(SsaInstruction::synthetic(SsaOp::Nop)); - ssa.add_block(block); - - let errors = SsaVerifier::new(&ssa).verify(VerifyLevel::Quick); - assert!( - errors - .iter() - .any(|e| matches!(e, VerifierError::TerminatorNotLast { block: 0, .. })), - "should detect terminator not last: {errors:?}" - ); - } - - #[test] - fn test_intra_block_cycle_detected() { - let mut ssa = SsaFunction::new(0, 0); - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - - let mut block = SsaBlock::new(0); - // v0 uses v1, but v1 is defined after v0 - block.add_instruction(SsaInstruction::synthetic(SsaOp::Copy { dest: v0, src: v1 })); - block.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: v1, - value: ConstValue::I32(42), - })); - block.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - ssa.add_block(block); - - let errors = SsaVerifier::new(&ssa).verify(VerifyLevel::Quick); - assert!( - errors - .iter() - .any(|e| matches!(e, VerifierError::IntraBlockCycle { .. })), - "should detect intra-block cycle: {errors:?}" - ); - } - - #[test] - fn test_missing_phi_operand_detected() { - let mut ssa = SsaFunction::new(0, 0); - - // Register variables first (dense IDs: 0, 1, 2) - let cond = ssa.create_variable( - VariableOrigin::Local(0), - 0, - DefSite::instruction(0, 0), - SsaType::Unknown, - ); - let v1 = ssa.create_variable( - VariableOrigin::Local(0), - 1, - DefSite::instruction(1, 0), - SsaType::Unknown, - ); - let phi_result = ssa.create_variable( - VariableOrigin::Local(0), - 2, - DefSite::phi(2), - SsaType::Unknown, - ); - - // Block 0: branch to 1 or 2 - let mut b0 = SsaBlock::new(0); - b0.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: cond, - value: ConstValue::I32(1), - })); - b0.add_instruction(SsaInstruction::synthetic(SsaOp::Branch { - condition: cond, - true_target: 1, - false_target: 2, - })); - - // Block 1: jump to 2 - let mut b1 = SsaBlock::new(1); - b1.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: v1, - value: ConstValue::I32(10), - })); - b1.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 2 })); - - // Block 2: phi with only one operand (missing pred 0) - let mut b2 = SsaBlock::new(2); - let mut phi = PhiNode::new(phi_result, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(v1, 1)); // only from block 1, missing block 0 - b2.add_phi(phi); - b2.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - - ssa.add_block(b0); - ssa.add_block(b1); - ssa.add_block(b2); - - let errors = SsaVerifier::new(&ssa).verify(VerifyLevel::Standard); - assert!( - errors.iter().any(|e| matches!( - e, - VerifierError::MissingPhiOperand { - block: 2, - missing_pred: 0, - .. - } - )), - "should detect missing phi operand from block 0: {errors:?}" - ); - } - - #[test] - fn test_well_formed_diamond_passes_full() { - let mut ssa = SsaFunction::new(0, 1); - - // Register all variables first (dense IDs: 0, 1, 2, 3) - let cond = ssa.create_variable( - VariableOrigin::Local(0), - 0, - DefSite::instruction(0, 0), - SsaType::Unknown, - ); - let v1 = ssa.create_variable( - VariableOrigin::Local(0), - 1, - DefSite::instruction(1, 0), - SsaType::Unknown, - ); - let v2 = ssa.create_variable( - VariableOrigin::Local(0), - 2, - DefSite::instruction(2, 0), - SsaType::Unknown, - ); - let phi_result = ssa.create_variable( - VariableOrigin::Local(0), - 3, - DefSite::phi(3), - SsaType::Unknown, - ); - - // Block 0: branch - let mut b0 = SsaBlock::new(0); - b0.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: cond, - value: ConstValue::I32(1), - })); - b0.add_instruction(SsaInstruction::synthetic(SsaOp::Branch { - condition: cond, - true_target: 1, - false_target: 2, - })); - - // Block 1: define v1 - let mut b1 = SsaBlock::new(1); - b1.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: v1, - value: ConstValue::I32(10), - })); - b1.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 3 })); - - // Block 2: define v2 - let mut b2 = SsaBlock::new(2); - b2.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: v2, - value: ConstValue::I32(20), - })); - b2.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 3 })); - - // Block 3: merge with phi - let mut b3 = SsaBlock::new(3); - let mut phi = PhiNode::new(phi_result, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(v1, 1)); - phi.add_operand(PhiOperand::new(v2, 2)); - b3.add_phi(phi); - b3.add_instruction(SsaInstruction::synthetic(SsaOp::Return { - value: Some(phi_result), - })); - - ssa.add_block(b0); - ssa.add_block(b1); - ssa.add_block(b2); - ssa.add_block(b3); - - let errors = SsaVerifier::new(&ssa).verify(VerifyLevel::Full); - assert!( - errors.is_empty(), - "well-formed diamond should pass Full verification: {errors:?}" - ); - } - - #[test] - fn test_undefined_use_detected() { - let mut ssa = SsaFunction::new(0, 0); - let undefined_var = SsaVarId::from_index(0); - - let mut block = SsaBlock::new(0); - block.add_instruction(SsaInstruction::synthetic(SsaOp::Return { - value: Some(undefined_var), - })); - ssa.add_block(block); - // Note: undefined_var is NOT registered as a variable - - let errors = SsaVerifier::new(&ssa).verify(VerifyLevel::Standard); - assert!( - errors.iter().any( - |e| matches!(e, VerifierError::UndefinedUse { var, .. } if *var == undefined_var) - ), - "should detect undefined use: {errors:?}" - ); - } - - #[test] - fn test_verify_level_ordering() { - assert!(VerifyLevel::Quick < VerifyLevel::Standard); - assert!(VerifyLevel::Standard < VerifyLevel::Full); - } - - #[test] - fn test_extra_phi_operand_detected() { - let mut ssa = SsaFunction::new(0, 0); - - // Register variables first (dense IDs: 0, 1, 2) - let v1 = ssa.create_variable( - VariableOrigin::Local(0), - 0, - DefSite::instruction(0, 0), - SsaType::Unknown, - ); - let v_extra = ssa.create_variable( - VariableOrigin::Local(0), - 1, - DefSite::instruction(0, 0), - SsaType::Unknown, - ); - let phi_result = ssa.create_variable( - VariableOrigin::Local(0), - 2, - DefSite::phi(1), - SsaType::Unknown, - ); - - // Block 0: jump to block 1 - let mut b0 = SsaBlock::new(0); - b0.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: v1, - value: ConstValue::I32(1), - })); - b0.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 1 })); - - // Block 1: phi with operand from block 0 AND from non-predecessor block 5 - let mut b1 = SsaBlock::new(1); - let mut phi = PhiNode::new(phi_result, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(v1, 0)); - phi.add_operand(PhiOperand::new(v_extra, 5)); // block 5 doesn't exist - b1.add_phi(phi); - b1.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - - ssa.add_block(b0); - ssa.add_block(b1); - - let errors = SsaVerifier::new(&ssa).verify(VerifyLevel::Standard); - assert!( - errors.iter().any(|e| matches!( - e, - VerifierError::ExtraPhiOperand { - block: 1, - extra_pred: 5, - .. - } - )), - "should detect extra phi operand from non-predecessor: {errors:?}" - ); - } -} diff --git a/dotscope/src/analysis/taint.rs b/dotscope/src/analysis/taint.rs index 6138d07b..d3c9e068 100644 --- a/dotscope/src/analysis/taint.rs +++ b/dotscope/src/analysis/taint.rs @@ -1,683 +1,28 @@ -//! Generic taint analysis for SSA functions. -//! -//! This module provides a reusable taint analysis framework that can propagate -//! taint information through SSA variables and instructions. It supports: -//! -//! - **Forward propagation**: If an instruction uses a tainted variable, its -//! output becomes tainted (the result depends on tainted data). -//! - **Backward propagation**: If an instruction's output is tainted, its -//! inputs become tainted (they contribute to tainted data). -//! - **PHI handling**: Configurable modes for how taint flows through PHI nodes. -//! -//! # Use Cases -//! -//! - **CFF Unflattening**: Track state variables to identify dispatcher machinery -//! - **Cleanup Neutralization**: Identify instructions dependent on removed tokens -//! - **Security Analysis**: Track data flow from untrusted sources -//! -//! # Example -//! -//! ```rust,ignore -//! use dotscope::analysis::{TaintAnalysis, TaintConfig, PhiTaintMode, SsaFunction}; -//! -//! let ssa: SsaFunction = /* ... */; -//! -//! let config = TaintConfig { -//! forward: true, -//! backward: true, -//! phi_mode: PhiTaintMode::TaintAllOperands, -//! max_iterations: 100, -//! }; -//! -//! let mut taint = TaintAnalysis::new(config); -//! taint.add_tainted_var(some_var_id); -//! taint.propagate(&ssa); -//! -//! // Check what's tainted -//! if taint.is_var_tainted(other_var_id) { -//! println!("Variable is tainted!"); -//! } -//! ``` +//! Re-export shim — generic taint analysis lives in +//! `analyssa::analysis::taint`. This file keeps CIL-specific glue +//! (`TokenTaintBuilder`, `find_token_dependencies`) that uses +//! `op.referenced_token()` and the `Token` opaque-id type — both of which +//! analyssa intentionally doesn't see. use std::collections::HashSet; +#[allow(unused_imports)] +pub use analyssa::analysis::taint::{ + cff_taint_config, find_blocks_jumping_to, PhiTaintMode, TaintAnalysis, TaintConfig, TaintStats, +}; + use crate::{ - analysis::ssa::{SsaFunction, SsaOp, SsaVarId, VariableOrigin}, + analysis::{SsaFunction, SsaOpCilExt}, metadata::token::Token, }; -/// How to handle PHI nodes during taint propagation. -/// -/// PHI nodes are control flow merge points where values from different -/// predecessors come together. The taint mode determines how taint -/// flows through these merge points. -#[derive(Debug, Clone, PartialEq, Eq)] -pub enum PhiTaintMode { - /// If the PHI result is tainted, all operands become tainted. - /// - /// Use this for backward analysis where you need to find all sources - /// that could contribute to a tainted value. - TaintAllOperands, - - /// If any operand is tainted, the PHI result becomes tainted. - /// - /// Use this for forward analysis where taint should flow from any - /// predecessor path. - TaintIfAnyOperand, - - /// Only taint operands from specific predecessor blocks. - /// - /// Use this for path-sensitive analysis where only certain control - /// flow paths should propagate taint. - TaintFromPredecessors(HashSet), - - /// Selective backward taint through PHI chains for CFF analysis. - /// - /// This mode is specifically designed for control flow flattening (CFF) - /// analysis where we need to trace state values back through PHI chains. - /// - /// When a PHI result is tainted: - /// - Check if the PHI's origin matches (if origin filter is Some) - /// - Only taint operands from predecessors in the set - /// - Recursively trace through intermediate PHIs with the same origin - SelectivePhi { - /// Set of predecessor blocks whose operands should be tainted. - /// For CFF, this is typically the set of blocks that jump to the dispatcher. - predecessors: HashSet, - /// Optional `VariableOrigin` to filter PHI chains. - /// Only PHIs with matching origin will be traversed. - origin_filter: Option, - }, - - /// Don't propagate taint through PHI nodes. - /// - /// Use this when PHIs represent control flow merge points that - /// should act as taint barriers. - NoPropagation, -} - -/// Configuration for taint analysis. -/// -/// Controls how taint propagates through the SSA graph. -#[derive(Debug, Clone)] -pub struct TaintConfig { - /// Propagate forward (input tainted → output tainted). - /// - /// When enabled, if an instruction uses a tainted variable, its - /// defined variable (if any) becomes tainted. - pub forward: bool, - - /// Propagate backward (output tainted → inputs tainted). - /// - /// When enabled, if an instruction's defined variable is tainted, - /// all variables it uses become tainted. - pub backward: bool, - - /// How to handle PHI nodes. - pub phi_mode: PhiTaintMode, - - /// Maximum iterations for fixpoint computation. - /// - /// Prevents infinite loops in pathological cases. - pub max_iterations: usize, -} - -impl Default for TaintConfig { - fn default() -> Self { - Self { - forward: true, - backward: false, - phi_mode: PhiTaintMode::TaintIfAnyOperand, - max_iterations: 100, - } - } -} - -impl TaintConfig { - /// Creates a config for forward-only propagation. - /// - /// Suitable for tracking what variables depend on a taint source. - #[must_use] - pub fn forward_only() -> Self { - Self { - forward: true, - backward: false, - phi_mode: PhiTaintMode::TaintIfAnyOperand, - max_iterations: 100, - } - } - - /// Creates a config for bidirectional propagation. - /// - /// Suitable for cleanup neutralization where we need to find all - /// instructions connected to removed tokens. - #[must_use] - pub fn bidirectional() -> Self { - Self { - forward: true, - backward: true, - phi_mode: PhiTaintMode::TaintAllOperands, - max_iterations: 100, - } - } -} - -/// Statistics about taint analysis execution. -#[derive(Debug, Clone, Default)] -pub struct TaintStats { - /// Number of iterations to reach fixpoint. - pub iterations: usize, - /// Number of tainted variables. - pub tainted_vars: usize, - /// Number of tainted instructions. - pub tainted_instrs: usize, - /// Number of tainted PHI nodes. - pub tainted_phis: usize, -} - -/// Generic taint analysis for SSA functions. -/// -/// This struct tracks which variables and instructions are "tainted" - meaning -/// they are connected to some set of taint sources through data flow. +/// Builder for taint analysis that finds instructions referencing specific +/// CIL metadata tokens. /// -/// The analysis runs to a fixpoint, propagating taint through the SSA graph -/// according to the configuration. -#[derive(Debug, Clone)] -pub struct TaintAnalysis { - /// Tainted SSA variables. - tainted_vars: HashSet, - - /// Tainted instructions: (block_idx, instr_idx). - tainted_instrs: HashSet<(usize, usize)>, - - /// Tainted PHI nodes: (block_idx, phi_idx). - tainted_phis: HashSet<(usize, usize)>, - - /// Configuration. - config: TaintConfig, - - /// Statistics from the last propagation. - stats: TaintStats, -} - -impl TaintAnalysis { - /// Creates a new taint analysis with the given configuration. - /// - /// # Arguments - /// - /// * `config` - Configuration controlling propagation behavior. - /// - /// # Returns - /// - /// A new `TaintAnalysis` with empty taint sets. - #[must_use] - pub fn new(config: TaintConfig) -> Self { - Self { - tainted_vars: HashSet::new(), - tainted_instrs: HashSet::new(), - tainted_phis: HashSet::new(), - config, - stats: TaintStats::default(), - } - } - - /// Creates a taint analysis with default forward-only configuration. - #[must_use] - pub fn forward_only() -> Self { - Self::new(TaintConfig::forward_only()) - } - - /// Creates a taint analysis with bidirectional configuration. - #[must_use] - pub fn bidirectional() -> Self { - Self::new(TaintConfig::bidirectional()) - } - - /// Adds a variable as a taint source. - /// - /// # Arguments - /// - /// * `var` - The variable ID to mark as tainted. - pub fn add_tainted_var(&mut self, var: SsaVarId) { - self.tainted_vars.insert(var); - } - - /// Adds multiple variables as taint sources. - /// - /// # Arguments - /// - /// * `vars` - Iterator of variable IDs to mark as tainted. - pub fn add_tainted_vars(&mut self, vars: impl IntoIterator) { - self.tainted_vars.extend(vars); - } - - /// Adds an instruction as a taint source. - /// - /// Also taints the instruction's defined variable (if any) and its uses - /// (for backward propagation from instructions without defs like stores). - /// - /// # Arguments - /// - /// * `block` - Block index containing the instruction. - /// * `instr` - Instruction index within the block. - /// * `ssa` - The SSA function for looking up the instruction's def/uses. - pub fn add_tainted_instr(&mut self, block: usize, instr: usize, ssa: &SsaFunction) { - self.tainted_instrs.insert((block, instr)); - - if let Some(block_data) = ssa.block(block) { - if let Some(instruction) = block_data.instructions().get(instr) { - // Taint the instruction's defined variable (for forward propagation) - if let Some(def) = instruction.def() { - self.tainted_vars.insert(def); - } - - // Also taint the instruction's uses (for backward propagation). - // This is critical for instructions like StoreStaticField that have - // no def - we need to taint what feeds into them. - if self.config.backward { - for use_var in instruction.uses() { - self.tainted_vars.insert(use_var); - } - } - } - } - } - - /// Adds a PHI node as a taint source. - /// - /// Also taints the PHI's result variable. - /// - /// # Arguments - /// - /// * `block` - Block index containing the PHI. - /// * `phi_idx` - PHI index within the block. - /// * `ssa` - The SSA function for looking up the PHI's result. - pub fn add_tainted_phi(&mut self, block: usize, phi_idx: usize, ssa: &SsaFunction) { - self.tainted_phis.insert((block, phi_idx)); - - // Also taint the PHI's result variable - if let Some(block_data) = ssa.block(block) { - if let Some(phi) = block_data.phi_nodes().get(phi_idx) { - self.tainted_vars.insert(phi.result()); - } - } - } - - /// Runs taint propagation to fixpoint. - /// - /// Iteratively propagates taint through the SSA graph until no more - /// changes occur or the iteration limit is reached. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to analyze. - pub fn propagate(&mut self, ssa: &SsaFunction) { - let mut iterations: usize = 0; - - loop { - if iterations >= self.config.max_iterations { - break; - } - iterations = iterations.saturating_add(1); - - let mut changed = false; - - // Process PHI nodes first - changed |= self.propagate_phis(ssa); - - // Process instructions - changed |= self.propagate_instructions(ssa); - - if !changed { - break; - } - } - - // Update statistics - self.stats = TaintStats { - iterations, - tainted_vars: self.tainted_vars.len(), - tainted_instrs: self.tainted_instrs.len(), - tainted_phis: self.tainted_phis.len(), - }; - } - - /// Propagates taint through PHI nodes. - /// - /// Returns `true` if any changes were made. - fn propagate_phis(&mut self, ssa: &SsaFunction) -> bool { - let mut changed = false; - - for (block_idx, block) in ssa.blocks().iter().enumerate() { - for (phi_idx, phi) in block.phi_nodes().iter().enumerate() { - let result = phi.result(); - let result_tainted = self.tainted_vars.contains(&result); - - match &self.config.phi_mode { - PhiTaintMode::TaintAllOperands => { - // If result is tainted, all operands become tainted - if result_tainted { - for operand in phi.operands() { - if self.tainted_vars.insert(operand.value()) { - changed = true; - } - } - if self.tainted_phis.insert((block_idx, phi_idx)) { - changed = true; - } - } - } - PhiTaintMode::TaintIfAnyOperand => { - // If any operand is tainted, result becomes tainted - let any_operand_tainted = phi - .operands() - .iter() - .any(|op| self.tainted_vars.contains(&op.value())); - - if any_operand_tainted { - if self.tainted_vars.insert(result) { - changed = true; - } - if self.tainted_phis.insert((block_idx, phi_idx)) { - changed = true; - } - } - } - PhiTaintMode::TaintFromPredecessors(preds) => { - // Only taint operands from specific predecessors - if result_tainted { - for operand in phi.operands() { - if preds.contains(&operand.predecessor()) - && self.tainted_vars.insert(operand.value()) - { - changed = true; - } - } - if self.tainted_phis.insert((block_idx, phi_idx)) { - changed = true; - } - } - } - PhiTaintMode::SelectivePhi { - predecessors, - origin_filter, - } => { - // Selective backward taint for CFF analysis - if result_tainted { - // Check if this PHI's origin matches the filter - let should_follow = origin_filter - .as_ref() - .is_none_or(|filter| phi.origin() == *filter); - - if should_follow { - for operand in phi.operands() { - if predecessors.contains(&operand.predecessor()) - && self.tainted_vars.insert(operand.value()) - { - changed = true; - } - } - if self.tainted_phis.insert((block_idx, phi_idx)) { - changed = true; - } - } - } - } - PhiTaintMode::NoPropagation => { - // Don't propagate through PHIs - } - } - } - } - - changed - } - - /// Propagates taint through instructions. - /// - /// Returns `true` if any changes were made. - fn propagate_instructions(&mut self, ssa: &SsaFunction) -> bool { - let mut changed = false; - - for (block_idx, instr_idx, instr) in ssa.iter_instructions() { - let def = instr.def(); - let uses = instr.uses(); - - // Forward propagation: if any USE is tainted, DEF becomes tainted - if self.config.forward { - if let Some(def_var) = def { - let uses_tainted = uses.iter().any(|u| self.tainted_vars.contains(u)); - if uses_tainted { - if self.tainted_vars.insert(def_var) { - changed = true; - } - if self.tainted_instrs.insert((block_idx, instr_idx)) { - changed = true; - } - } - } - } - - // Backward propagation: if DEF is tainted, all USEs become tainted - if self.config.backward { - let def_tainted = def.is_some_and(|d| self.tainted_vars.contains(&d)); - if def_tainted { - for use_var in &uses { - if self.tainted_vars.insert(*use_var) { - changed = true; - } - } - if self.tainted_instrs.insert((block_idx, instr_idx)) { - changed = true; - } - } - } - - // Array-aware propagation: if an array is tainted, all StoreElement - // operations to that array are also tainted (they're preparing dead data). - // This is critical for cleanup neutralization where protection code fills - // arrays that are passed to removed methods. - if self.config.backward { - if let SsaOp::StoreElement { array, .. } = instr.op() { - if self.tainted_vars.contains(array) - && self.tainted_instrs.insert((block_idx, instr_idx)) - { - changed = true; - // Also taint the value and index being stored - they feed into dead code - for use_var in &uses { - if self.tainted_vars.insert(*use_var) { - changed = true; - } - } - } - } - } - - // Mark instruction as tainted if it uses tainted vars (even without def) - let uses_tainted = uses.iter().any(|u| self.tainted_vars.contains(u)); - if uses_tainted && self.tainted_instrs.insert((block_idx, instr_idx)) { - changed = true; - } - } - - changed - } - - /// Checks if a variable is tainted. - /// - /// # Arguments - /// - /// * `var` - The variable ID to check. - /// - /// # Returns - /// - /// `true` if the variable is tainted. - #[must_use] - pub fn is_var_tainted(&self, var: SsaVarId) -> bool { - self.tainted_vars.contains(&var) - } - - /// Checks if an instruction is tainted. - /// - /// # Arguments - /// - /// * `block` - Block index. - /// * `instr` - Instruction index within the block. - /// - /// # Returns - /// - /// `true` if the instruction is tainted. - #[must_use] - pub fn is_instr_tainted(&self, block: usize, instr: usize) -> bool { - self.tainted_instrs.contains(&(block, instr)) - } - - /// Checks if a PHI node is tainted. - /// - /// # Arguments - /// - /// * `block` - Block index. - /// * `phi_idx` - PHI index within the block. - /// - /// # Returns - /// - /// `true` if the PHI is tainted. - #[must_use] - pub fn is_phi_tainted(&self, block: usize, phi_idx: usize) -> bool { - self.tainted_phis.contains(&(block, phi_idx)) - } - - /// Returns all tainted variables. - #[must_use] - pub fn tainted_variables(&self) -> &HashSet { - &self.tainted_vars - } - - /// Returns all tainted instructions. - #[must_use] - pub fn tainted_instructions(&self) -> &HashSet<(usize, usize)> { - &self.tainted_instrs - } - - /// Returns all tainted PHI nodes. - #[must_use] - pub fn tainted_phis(&self) -> &HashSet<(usize, usize)> { - &self.tainted_phis - } - - /// Returns statistics from the last propagation. - #[must_use] - pub fn stats(&self) -> &TaintStats { - &self.stats - } - - /// Returns the number of tainted variables. - #[must_use] - pub fn tainted_var_count(&self) -> usize { - self.tainted_vars.len() - } - - /// Returns the number of tainted instructions. - #[must_use] - pub fn tainted_instr_count(&self) -> usize { - self.tainted_instrs.len() - } - - /// Clears all taint information. - pub fn clear(&mut self) { - self.tainted_vars.clear(); - self.tainted_instrs.clear(); - self.tainted_phis.clear(); - self.stats = TaintStats::default(); - } -} - -/// Finds all blocks that have a direct jump/branch to the target block. -/// -/// This is useful for CFF analysis where we need to identify which blocks -/// set the state variable (those that jump back to the dispatcher). -/// -/// # Arguments -/// -/// * `ssa` - The SSA function to analyze. -/// * `target` - The target block index to find jumpers to. -/// -/// # Returns -/// -/// A set of block indices that have a control flow edge to `target`. -#[must_use] -pub fn find_blocks_jumping_to(ssa: &SsaFunction, target: usize) -> HashSet { - let mut jumpers = HashSet::new(); - - for block in ssa.blocks() { - if let Some(terminator) = block.instructions().last() { - let jumps_to_target = match terminator.op() { - SsaOp::Jump { target: t } | SsaOp::Leave { target: t } => *t == target, - SsaOp::Branch { - true_target, - false_target, - .. - } - | SsaOp::BranchCmp { - true_target, - false_target, - .. - } => *true_target == target || *false_target == target, - SsaOp::Switch { - targets, default, .. - } => *default == target || targets.contains(&target), - _ => false, - }; - - if jumps_to_target { - jumpers.insert(block.id()); - } - } - } - - jumpers -} - -/// Creates a CFF-specific taint configuration for state variable analysis. -/// -/// This configuration is designed for control flow flattening analysis where: -/// - Forward propagation is enabled (derived values from state are tainted) -/// - Backward propagation is disabled (too aggressive, taints loop counters) -/// - PHI taint uses selective mode (only from blocks jumping to dispatcher) -/// -/// # Arguments -/// -/// * `ssa` - The SSA function being analyzed. -/// * `dispatcher_block` - The block index of the CFF dispatcher. -/// * `state_origin` - Optional `VariableOrigin` to filter PHI chains. -/// -/// # Returns -/// -/// A `TaintConfig` configured for CFF state tracking. -#[must_use] -pub fn cff_taint_config( - ssa: &SsaFunction, - dispatcher_block: usize, - state_origin: Option, -) -> TaintConfig { - let predecessors = find_blocks_jumping_to(ssa, dispatcher_block); - - TaintConfig { - forward: true, - backward: false, - phi_mode: PhiTaintMode::SelectivePhi { - predecessors, - origin_filter: state_origin, - }, - max_iterations: 100, - } -} - -/// Builder for taint analysis that finds instructions referencing specific tokens. -/// -/// This is a convenience builder for the common pattern of finding all instructions -/// that reference a set of tokens (methods, types, fields) and then propagating -/// taint from those instructions. +/// Convenience wrapper around the analyssa-side [`TaintAnalysis`] for the +/// common pattern of finding all instructions that reference a set of +/// tokens (methods, types, fields) and propagating taint from those +/// instructions. #[derive(Debug)] pub struct TokenTaintBuilder { /// Tokens to find references to. @@ -688,10 +33,6 @@ pub struct TokenTaintBuilder { impl TokenTaintBuilder { /// Creates a new token taint builder. - /// - /// # Arguments - /// - /// * `tokens` - Tokens to find references to. #[must_use] pub fn new(tokens: impl IntoIterator) -> Self { Self { @@ -709,21 +50,13 @@ impl TokenTaintBuilder { /// Builds and runs the taint analysis on the given SSA function. /// - /// Finds all instructions that reference the target tokens, marks them - /// as taint sources, and propagates to fixpoint. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to analyze. - /// - /// # Returns - /// - /// The completed taint analysis. + /// Finds all instructions that reference the target tokens (via + /// `SsaOpCilExt::referenced_token`), marks them as taint sources, and + /// propagates to fixpoint. #[must_use] pub fn analyze(self, ssa: &SsaFunction) -> TaintAnalysis { let mut taint = TaintAnalysis::new(self.config); - // Find instructions that reference target tokens for (block_idx, instr_idx, instr) in ssa.iter_instructions() { if let Some(token) = instr.op().referenced_token() { if self.target_tokens.contains(&token) { @@ -732,27 +65,16 @@ impl TokenTaintBuilder { } } - // Propagate taint taint.propagate(ssa); - taint } } -/// Convenience function to find instructions referencing removed tokens. -/// -/// This is the main entry point for cleanup neutralization. It finds all -/// instructions that reference the given tokens and propagates taint to -/// find all dependent instructions. -/// -/// # Arguments +/// Find instructions referencing removed CIL tokens. /// -/// * `ssa` - The SSA function to analyze. -/// * `removed_tokens` - Tokens that will be removed. -/// -/// # Returns -/// -/// A taint analysis with all dependent instructions marked. +/// Main entry point for cleanup neutralization. Finds all instructions that +/// reference the given tokens and propagates taint to find all dependent +/// instructions. #[must_use] pub fn find_token_dependencies( ssa: &SsaFunction, @@ -760,7 +82,6 @@ pub fn find_token_dependencies( ) -> TaintAnalysis { TokenTaintBuilder::new(removed_tokens).analyze(ssa) } - #[cfg(test)] mod tests { use std::collections::HashSet; @@ -807,6 +128,7 @@ mod tests { dest: v2, left: v0, right: v1, + flags: None, })); b0.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 1 })); ssa.add_block(b0); @@ -817,6 +139,7 @@ mod tests { dest: v3, left: v2, right: v0, + flags: None, })); b1.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: Some(v3) })); ssa.add_block(b1); diff --git a/dotscope/src/analysis/x86/cfg.rs b/dotscope/src/analysis/x86/cfg.rs index dc773bd1..4b92a729 100644 --- a/dotscope/src/analysis/x86/cfg.rs +++ b/dotscope/src/analysis/x86/cfg.rs @@ -4,15 +4,14 @@ //! and their successor relationships. It leverages the existing graph infrastructure //! from [`crate::utils::graph`] for efficient analysis. -use crate::{ - analysis::{ - cfg::has_back_edges, - x86::types::{X86DecodedInstruction, X86EdgeKind, X86Instruction}, - }, - utils::graph::{ - algorithms::{compute_dominators, DominatorTree}, - DirectedGraph, GraphBase, NodeId, Predecessors, RootedGraph, Successors, - }, +use analyssa::graph::{ + algorithms::{compute_dominators, DominatorTree}, + DirectedGraph, GraphBase, NodeId, Predecessors, RootedGraph, Successors, +}; + +use crate::analysis::{ + cfg::has_back_edges, + x86::types::{X86DecodedInstruction, X86EdgeKind, X86Instruction}, }; use rustc_hash::FxHashMap; use std::{collections::BTreeSet, sync::OnceLock}; @@ -482,13 +481,12 @@ fn compute_edges( #[cfg(test)] mod tests { - use crate::{ - analysis::x86::{ - cfg::X86Function, - decoder::x86_decode_all, - types::{X86Condition, X86EdgeKind}, - }, - utils::graph::NodeId, + use analyssa::graph::NodeId; + + use crate::analysis::x86::{ + cfg::X86Function, + decoder::x86_decode_all, + types::{X86Condition, X86EdgeKind}, }; #[test] diff --git a/dotscope/src/analysis/x86/decoder.rs b/dotscope/src/analysis/x86/decoder.rs index 8dbc3a97..7581edb6 100644 --- a/dotscope/src/analysis/x86/decoder.rs +++ b/dotscope/src/analysis/x86/decoder.rs @@ -87,7 +87,7 @@ pub fn x86_decode_all( /// Result of traversal-based decoding. /// -/// This struct is returned by [`x86_decode_function_traversal`] and contains +/// This struct is returned by [`x86_decode_traversal`] and contains /// all instructions discovered by following control flow from the entry point. /// /// # Completeness @@ -329,7 +329,7 @@ pub fn x86_decode_traversal( /// Determines the byte size of a native x86/x64 function body. /// -/// Uses traversal-based decoding ([`x86_decode_function_traversal`]) to follow +/// Uses traversal-based decoding ([`x86_decode_traversal`]) to follow /// control-flow edges from the entry point, making it robust against /// anti-disassembly tricks (junk bytes, embedded data, overlapping /// instructions) because only code reachable through actual control flow is diff --git a/dotscope/src/analysis/x86/mod.rs b/dotscope/src/analysis/x86/mod.rs index fffa146f..83e1f770 100644 --- a/dotscope/src/analysis/x86/mod.rs +++ b/dotscope/src/analysis/x86/mod.rs @@ -28,7 +28,7 @@ //! Decodes instructions sequentially from the start until a `RET` instruction. //! Fast and simple, but vulnerable to anti-disassembly tricks. //! -//! ## Traversal-Based Decoding ([`x86_decode_function_traversal`]) +//! ## Traversal-Based Decoding ([`x86_decode_traversal`]) //! //! Follows control flow edges from the entry point, only decoding reachable code. //! More robust against: @@ -63,10 +63,10 @@ //! # Traversal-Based Example //! //! ```rust,ignore -//! use dotscope::analysis::{x86_decode_function_traversal, X86Function}; +//! use dotscope::analysis::{x86_decode_traversal, X86Function}; //! //! // Decode using control-flow following (more robust) -//! let result = x86_decode_function_traversal(bytes, 32, 0x1000, 0)?; +//! let result = x86_decode_traversal(bytes, 32, 0x1000, 0)?; //! println!("Decoded {} instructions", result.instructions.len()); //! println!("Has indirect jumps: {}", result.has_indirect_control_flow); //! @@ -300,14 +300,6 @@ mod tests { } } - // ======================================================================== - // ConfuserEx x86 native stub tests - // - // Bytecodes extracted from test samples in tests/samples/packers/confuserex/1.6.0/. - // All stubs use DynCipher calling convention: body starts after 20-byte - // prologue. These bytes are the body only (without prologue). - // ======================================================================== - /// Helper: decode body bytes, build CFG, translate to SSA, evaluate with /// concrete input, and return the i32 result. fn eval_x86_stub(body: &[u8], input: i32) -> i32 { diff --git a/dotscope/src/analysis/x86/ssa.rs b/dotscope/src/analysis/x86/ssa.rs index 6a3a468d..8e858bea 100644 --- a/dotscope/src/analysis/x86/ssa.rs +++ b/dotscope/src/analysis/x86/ssa.rs @@ -39,6 +39,9 @@ //! let ssa_function = translator.translate()?; //! ``` +use analyssa::graph::NodeId; +use rustc_hash::{FxHashMap, FxHashSet}; + use crate::{ analysis::{ ssa::{ @@ -47,17 +50,15 @@ use crate::{ }, x86::{ cfg::X86Function, - flags::{ArithmeticKind, ConditionEval, FlagState, FlagTestSource}, + flags::{ArithmeticKind, ConditionEval, FlagProducer, FlagState, FlagTestSource}, types::{ X86Condition, X86DecodedInstruction, X86Instruction, X86Memory, X86Operand, X86Register, }, }, }, - utils::graph::NodeId, Error, Result, }; -use rustc_hash::{FxHashMap, FxHashSet}; /// Number of registers tracked (0-15 GPRs for x64, 16-21 segment registers). const MAX_REGISTERS: usize = 22; @@ -574,6 +575,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: res_var, left: dst_var, right: src_var, + flags: None, })); self.set_operand_value(dst, res_var, reg_state, &mut result, block_idx)?; flags.set_arithmetic(res_var, dst_var, src_var, ArithmeticKind::Add); @@ -604,6 +606,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: res_var, left: dst_var, right: src_var, + flags: None, })); self.set_operand_value(dst, res_var, reg_state, &mut result, block_idx)?; flags.set_arithmetic(res_var, dst_var, src_var, ArithmeticKind::Sub); @@ -639,6 +642,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: res_var, left: src_var, right: multiplier, + flags: None, })); reg_state.set(*dst, res_var); flags.set_arithmetic(res_var, src_var, multiplier, ArithmeticKind::Other); @@ -658,6 +662,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: res_var, left: eax, right: src_var, + flags: None, })); reg_state.set(X86Register::Eax, res_var); flags.set_arithmetic(res_var, eax, src_var, ArithmeticKind::Other); @@ -673,6 +678,7 @@ impl<'a> X86ToSsaTranslator<'a> { result.push(SsaInstruction::synthetic(SsaOp::Neg { dest: res_var, operand: dst_var, + flags: None, })); self.set_operand_value(dst, res_var, reg_state, &mut result, block_idx)?; flags.set_arithmetic_unary(res_var); @@ -702,6 +708,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: neg_cf, left: cf, right: one, + flags: None, })); flags.set_carry(neg_cf); } @@ -718,6 +725,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: res_var, left: dst_var, right: one, + flags: None, })); self.set_operand_value(dst, res_var, reg_state, &mut result, block_idx)?; // INC sets ZF/SF/OF but does NOT modify CF @@ -736,6 +744,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: res_var, left: dst_var, right: one, + flags: None, })); self.set_operand_value(dst, res_var, reg_state, &mut result, block_idx)?; // DEC sets ZF/SF/OF but does NOT modify CF @@ -760,6 +769,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: temp_var, left: dst_var, right: src_var, + flags: None, })); // result = temp + CF let res_var = self.create_variable( @@ -771,6 +781,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: res_var, left: temp_var, right: cf_var, + flags: None, })); self.set_operand_value(dst, res_var, reg_state, &mut result, block_idx)?; flags.set_arithmetic(res_var, dst_var, src_var, ArithmeticKind::Add); @@ -806,6 +817,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: new_cf, left: cf1, right: cf2, + flags: None, })); flags.set_carry(new_cf); } @@ -828,6 +840,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: temp_var, left: dst_var, right: src_var, + flags: None, })); // result = temp - CF let res_var = self.create_variable( @@ -839,6 +852,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: res_var, left: temp_var, right: cf_var, + flags: None, })); self.set_operand_value(dst, res_var, reg_state, &mut result, block_idx)?; flags.set_arithmetic(res_var, dst_var, src_var, ArithmeticKind::Sub); @@ -874,6 +888,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: new_cf, left: cf1, right: cf2, + flags: None, })); flags.set_carry(new_cf); } @@ -893,6 +908,7 @@ impl<'a> X86ToSsaTranslator<'a> { left: eax, right: src_var, unsigned: true, + flags: None, })); reg_state.set(X86Register::Eax, quot_var); // Remainder → EDX @@ -906,6 +922,7 @@ impl<'a> X86ToSsaTranslator<'a> { left: eax, right: src_var, unsigned: true, + flags: None, })); reg_state.set(X86Register::Edx, rem_var); flags.set_arithmetic(quot_var, eax, src_var, ArithmeticKind::Other); @@ -926,6 +943,7 @@ impl<'a> X86ToSsaTranslator<'a> { left: eax, right: src_var, unsigned: false, + flags: None, })); reg_state.set(X86Register::Eax, quot_var); // Remainder → EDX @@ -939,6 +957,7 @@ impl<'a> X86ToSsaTranslator<'a> { left: eax, right: src_var, unsigned: false, + flags: None, })); reg_state.set(X86Register::Edx, rem_var); flags.set_arithmetic(quot_var, eax, src_var, ArithmeticKind::Other); @@ -969,6 +988,7 @@ impl<'a> X86ToSsaTranslator<'a> { value: val, amount: c24, unsigned: true, + flags: None, })); let byte0 = self.create_variable( VariableOrigin::Phi, @@ -979,6 +999,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: byte0, left: shr24, right: mask_ff, + flags: None, })); // byte1 = (val >> 8) & 0xFF00 @@ -992,6 +1013,7 @@ impl<'a> X86ToSsaTranslator<'a> { value: val, amount: c8, unsigned: true, + flags: None, })); let byte1 = self.create_variable( VariableOrigin::Phi, @@ -1002,6 +1024,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: byte1, left: shr8, right: mask_ff00, + flags: None, })); // byte2 = (val << 8) & 0xFF0000 @@ -1014,6 +1037,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: shl8, value: val, amount: c8, + flags: None, })); let byte2 = self.create_variable( VariableOrigin::Phi, @@ -1024,6 +1048,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: byte2, left: shl8, right: mask_ff0000, + flags: None, })); // byte3 = (val << 24) & 0xFF000000 @@ -1036,6 +1061,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: shl24, value: val, amount: c24, + flags: None, })); let byte3 = self.create_variable( VariableOrigin::Phi, @@ -1046,6 +1072,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: byte3, left: shl24, right: mask_ff000000, + flags: None, })); // Combine: result = byte0 | byte1 | byte2 | byte3 @@ -1058,6 +1085,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: or01, left: byte0, right: byte1, + flags: None, })); let or012 = self.create_variable( VariableOrigin::Phi, @@ -1068,6 +1096,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: or012, left: or01, right: byte2, + flags: None, })); let res_var = self.create_variable( VariableOrigin::Phi, @@ -1078,6 +1107,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: res_var, left: or012, right: byte3, + flags: None, })); reg_state.set(*dst, res_var); @@ -1107,6 +1137,7 @@ impl<'a> X86ToSsaTranslator<'a> { result.push(SsaInstruction::synthetic(SsaOp::Neg { dest: neg_cond, operand: cond_var, + flags: None, })); let xor_diff = self.create_variable( VariableOrigin::Phi, @@ -1117,6 +1148,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: xor_diff, left: src_var, right: dst_var, + flags: None, })); let masked = self.create_variable( VariableOrigin::Phi, @@ -1127,6 +1159,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: masked, left: xor_diff, right: neg_cond, + flags: None, })); let res_var = self.create_variable( VariableOrigin::Phi, @@ -1137,6 +1170,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: res_var, left: dst_var, right: masked, + flags: None, })); reg_state.set(*dst, res_var); // Cmovcc doesn't modify flags @@ -1179,6 +1213,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: res_var, left: dst_var, right: src_var, + flags: None, })); self.set_operand_value(dst, res_var, reg_state, &mut result, block_idx)?; self.set_operand_value(src, dst_var, reg_state, &mut result, block_idx)?; @@ -1211,6 +1246,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: res_var, left: dst_var, right: src_var, + flags: None, })); self.set_operand_value(dst, res_var, reg_state, &mut result, block_idx)?; flags.set_arithmetic(res_var, dst_var, src_var, ArithmeticKind::LogicalOp); @@ -1231,6 +1267,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: res_var, left: dst_var, right: src_var, + flags: None, })); self.set_operand_value(dst, res_var, reg_state, &mut result, block_idx)?; flags.set_arithmetic(res_var, dst_var, src_var, ArithmeticKind::LogicalOp); @@ -1251,6 +1288,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: res_var, left: dst_var, right: src_var, + flags: None, })); self.set_operand_value(dst, res_var, reg_state, &mut result, block_idx)?; flags.set_arithmetic(res_var, dst_var, src_var, ArithmeticKind::LogicalOp); @@ -1269,6 +1307,7 @@ impl<'a> X86ToSsaTranslator<'a> { result.push(SsaInstruction::synthetic(SsaOp::Not { dest: res_var, operand: dst_var, + flags: None, })); self.set_operand_value(dst, res_var, reg_state, &mut result, block_idx)?; // NOT doesn't affect flags (except in some specific cases) @@ -1286,6 +1325,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: res_var, value: dst_var, amount: count_var, + flags: None, })); self.set_operand_value(dst, res_var, reg_state, &mut result, block_idx)?; flags.set_arithmetic(res_var, dst_var, count_var, ArithmeticKind::Other); @@ -1304,6 +1344,7 @@ impl<'a> X86ToSsaTranslator<'a> { value: dst_var, amount: count_var, unsigned: true, + flags: None, })); self.set_operand_value(dst, res_var, reg_state, &mut result, block_idx)?; flags.set_arithmetic(res_var, dst_var, count_var, ArithmeticKind::Other); @@ -1322,6 +1363,7 @@ impl<'a> X86ToSsaTranslator<'a> { value: dst_var, amount: count_var, unsigned: false, // SAR is signed shift right + flags: None, })); self.set_operand_value(dst, res_var, reg_state, &mut result, block_idx)?; flags.set_arithmetic(res_var, dst_var, count_var, ArithmeticKind::Other); @@ -1382,10 +1424,7 @@ impl<'a> X86ToSsaTranslator<'a> { // Get comparison operands from flags if let Some((cmp, left, right, unsigned)) = flags.get_branch_operands(*condition) { // Handle TEST special case - if matches!( - flags.producer(), - Some(crate::analysis::x86::flags::FlagProducer::Test { .. }) - ) { + if matches!(flags.producer(), Some(FlagProducer::Test { .. })) { // For TEST + JE/JNE, we need to compare (left & right) with 0 let and_result = self.create_variable( VariableOrigin::Phi, @@ -1396,6 +1435,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: and_result, left, right, + flags: None, })); let zero = self.get_zero_constant(&mut result, block_idx); result.push(SsaInstruction::synthetic(SsaOp::BranchCmp { @@ -1474,6 +1514,7 @@ impl<'a> X86ToSsaTranslator<'a> { value: eax, amount: thirty_one, unsigned: false, + flags: None, })); reg_state.set(X86Register::Edx, edx_var); } @@ -1633,6 +1674,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: and_result, left, right, + flags: None, })); let zero = self.get_zero_constant(instrs, block_idx); self.emit_comparison(cmp, and_result, zero, false, instrs, block_idx) @@ -1698,6 +1740,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: result, left: eq_result, right: one, + flags: None, })); } CmpKind::Lt => { @@ -1734,6 +1777,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: result, left: gt_result, right: one, + flags: None, })); } CmpKind::Ge => { @@ -1754,6 +1798,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: result, left: lt_result, right: one, + flags: None, })); } } @@ -1780,6 +1825,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: result, left, right, + flags: None, })); result } @@ -1793,6 +1839,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: result, left, right, + flags: None, })); result } @@ -1824,6 +1871,7 @@ impl<'a> X86ToSsaTranslator<'a> { value, amount: shift_const, unsigned: true, + flags: None, })); // sign_bit = sign_shifted & 1 @@ -1837,6 +1885,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: sign_bit, left: sign_shifted, right: one, + flags: None, })); if negated { @@ -1851,6 +1900,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: result, left: sign_bit, right: one2, + flags: None, })); Ok(result) } else { @@ -1897,6 +1947,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: r, left, right, + flags: None, })); r }; @@ -1911,6 +1962,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: xor_lr, left, right, + flags: None, })); let xor_ls = self.create_variable( @@ -1922,6 +1974,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: xor_ls, left, right: sub_result, + flags: None, })); let and_both = self.create_variable( @@ -1933,6 +1986,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: and_both, left: xor_lr, right: xor_ls, + flags: None, })); let shift_amount = if self.func.bitness == 64 { 63 } else { 31 }; @@ -1947,6 +2001,7 @@ impl<'a> X86ToSsaTranslator<'a> { value: and_both, amount: shift_const, unsigned: true, + flags: None, })); let one = self.get_constant(1, instrs, block_idx); @@ -1959,6 +2014,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: of_bit, left: shifted, right: one, + flags: None, })); of_bit } @@ -1975,6 +2031,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: xor_lr, left, right: add_result, + flags: None, })); let xor_rr = self.create_variable( @@ -1986,6 +2043,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: xor_rr, left: right, right: add_result, + flags: None, })); let and_both = self.create_variable( @@ -1997,6 +2055,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: and_both, left: xor_lr, right: xor_rr, + flags: None, })); let shift_amount = if self.func.bitness == 64 { 63 } else { 31 }; @@ -2011,6 +2070,7 @@ impl<'a> X86ToSsaTranslator<'a> { value: and_both, amount: shift_const, unsigned: true, + flags: None, })); let one = self.get_constant(1, instrs, block_idx); @@ -2023,6 +2083,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: of_bit, left: shifted, right: one, + flags: None, })); of_bit } @@ -2060,6 +2121,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: result_var, left: of, right: one, + flags: None, })); Ok(result_var) } else { @@ -2091,6 +2153,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: b, left: value, right: mask_ff, + flags: None, })); // b ^= b >> 4 @@ -2105,6 +2168,7 @@ impl<'a> X86ToSsaTranslator<'a> { value: b, amount: c4, unsigned: true, + flags: None, })); let b1 = self.create_variable( VariableOrigin::Phi, @@ -2115,6 +2179,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: b1, left: b, right: shr4, + flags: None, })); // b ^= b >> 2 @@ -2129,6 +2194,7 @@ impl<'a> X86ToSsaTranslator<'a> { value: b1, amount: c2, unsigned: true, + flags: None, })); let b2 = self.create_variable( VariableOrigin::Phi, @@ -2139,6 +2205,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: b2, left: b1, right: shr2, + flags: None, })); // b ^= b >> 1 @@ -2153,6 +2220,7 @@ impl<'a> X86ToSsaTranslator<'a> { value: b2, amount: c1, unsigned: true, + flags: None, })); let b3 = self.create_variable( VariableOrigin::Phi, @@ -2163,6 +2231,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: b3, left: b2, right: shr1, + flags: None, })); // odd_parity = b3 & 1 @@ -2176,6 +2245,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: odd_parity, left: b3, right: one, + flags: None, })); if negated { @@ -2193,6 +2263,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: result, left: odd_parity, right: one2, + flags: None, })); Ok(result) } @@ -2231,6 +2302,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: scaled_var, left: index_val, right: scale_const, + flags: None, })); scaled_var }; @@ -2245,6 +2317,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: new_addr, left: addr, right: scaled, + flags: None, })); addr = new_addr; } @@ -2261,6 +2334,7 @@ impl<'a> X86ToSsaTranslator<'a> { dest: new_addr, left: addr, right: disp_const, + flags: None, })); addr = new_addr; } diff --git a/dotscope/src/cilassembly/changes/heap.rs b/dotscope/src/cilassembly/changes/heap.rs index 742e76ee..ff1fa3ce 100644 --- a/dotscope/src/cilassembly/changes/heap.rs +++ b/dotscope/src/cilassembly/changes/heap.rs @@ -245,10 +245,6 @@ impl HeapChanges { } } -// ============================================================================= -// String heap specialization -// ============================================================================= - impl HeapChanges { /// Creates a new string heap changes tracker. pub fn new_strings() -> Self { @@ -299,10 +295,6 @@ impl HeapChanges { } } -// ============================================================================= -// Blob heap specialization -// ============================================================================= - impl HeapChanges> { /// Creates a new blob heap changes tracker. pub fn new_blobs() -> Self { @@ -334,10 +326,6 @@ impl HeapChanges> { } } -// ============================================================================= -// GUID heap specialization -// ============================================================================= - impl HeapChanges<[u8; 16]> { /// Creates a new GUID heap changes tracker. pub fn new_guids() -> Self { diff --git a/dotscope/src/cilassembly/writer/fields.rs b/dotscope/src/cilassembly/writer/fields.rs index ae4b5535..cbab965e 100644 --- a/dotscope/src/cilassembly/writer/fields.rs +++ b/dotscope/src/cilassembly/writer/fields.rs @@ -316,7 +316,7 @@ pub fn write_field_data(ctx: &mut WriteContext) -> Result<()> { let view = ctx.assembly.view(); let file = view.file(); let changes = ctx.changes; - let ptr_size = PointerSize::from_pe(file.pe().is_64bit); + let ptr_size = PointerSize::from_is_64bit(file.pe().is_64bit); let entries = collect_field_data(view, file, changes, ptr_size)?; diff --git a/dotscope/src/cilassembly/writer/tables.rs b/dotscope/src/cilassembly/writer/tables.rs index cd70964b..8a4433e3 100644 --- a/dotscope/src/cilassembly/writer/tables.rs +++ b/dotscope/src/cilassembly/writer/tables.rs @@ -127,10 +127,6 @@ pub trait ResolvePlaceholders { fn resolve_placeholders(&mut self, changes: &AssemblyChanges); } -// ============================================================================ -// ResolvePlaceholders implementations for each table type -// ============================================================================ - impl ResolvePlaceholders for ModuleRaw { fn resolve_placeholders(&mut self, changes: &AssemblyChanges) { // String heap refs @@ -456,10 +452,6 @@ impl ResolvePlaceholders for GenericParamConstraintRaw { } } -// ============================================================================ -// Helper function to resolve placeholders by table ID -// ============================================================================ - /// Resolves placeholders in a boxed table row based on its table ID. /// /// This is a dispatch function that casts the row to the appropriate type diff --git a/dotscope/src/compiler/codegen/coalescing.rs b/dotscope/src/compiler/codegen/coalescing.rs index 4626207e..c03c2f25 100644 --- a/dotscope/src/compiler/codegen/coalescing.rs +++ b/dotscope/src/compiler/codegen/coalescing.rs @@ -33,12 +33,13 @@ use std::{ use rayon::prelude::*; +use analyssa::BitSet; + use crate::{ analysis::{ AnalysisResults, DataFlowSolver, LiveVariables, LivenessResult, SsaCfg, SsaFunction, SsaOp, SsaType, SsaVarId, VariableOrigin, }, - utils::BitSet, Error, Result, }; diff --git a/dotscope/src/compiler/codegen/mod.rs b/dotscope/src/compiler/codegen/mod.rs index 96ffa824..fe7d5d98 100644 --- a/dotscope/src/compiler/codegen/mod.rs +++ b/dotscope/src/compiler/codegen/mod.rs @@ -80,7 +80,7 @@ use crate::{ /// /// Contains the CIL bytecode, stack depth, local variable signatures, and remapped /// exception handlers. This is the bridge between code generation and -/// [`crate::cilassembly::builders::MethodBodyBuilder::from_compilation`]. +/// [`MethodBodyBuilder::from_compilation`](crate::cilassembly::MethodBodyBuilder::from_compilation). pub struct CompilationResult { /// CIL bytecode. pub bytecode: Vec, @@ -381,7 +381,7 @@ impl SsaCodeGenerator { /// /// This is the high-level entry point that wraps [`generate_with_assembly`](Self::generate_with_assembly) /// and adds local variable signature building and exception handler remapping. - /// The result contains everything needed for [`crate::cilassembly::builders::MethodBodyBuilder::from_compilation`] + /// The result contains everything needed for [`MethodBodyBuilder::from_compilation`](crate::cilassembly::MethodBodyBuilder::from_compilation) /// to assemble the final method body. /// /// # Arguments @@ -853,14 +853,14 @@ impl SsaCodeGenerator { value: ConstValue::DecryptedArray { data, - element_type_token, + element_type_ref, element_size, }, .. } if !self.interned_arrays.contains_key(data) => { if let Some(info) = self.intern_array_data( data, - *element_type_token, + element_type_ref.token(), *element_size, assembly, )? { @@ -3448,7 +3448,8 @@ impl SsaCodeGenerator { | SsaOp::Leave { .. } | SsaOp::EndFinally | SsaOp::EndFilter { .. } - | SsaOp::BranchCmp { .. } => { + | SsaOp::BranchCmp { .. } + | SsaOp::BranchFlags { .. } => { self.generate_branch_op(encoder, ssa, current_block_idx, op, next_block_idx)?; } @@ -3494,9 +3495,31 @@ impl SsaCodeGenerator { encoder.emit_instruction("readonly.", None)?; } - SsaOp::Phi { .. } => { + SsaOp::Phi { .. } + | SsaOp::Fence { .. } + | SsaOp::InterruptReturn + | SsaOp::Unreachable => { // Phi nodes are eliminated during code generation - no instruction emitted } + // Rotate and bit manipulation operations - not emitted as CIL primitives + SsaOp::Rol { .. } + | SsaOp::Ror { .. } + | SsaOp::Rcl { .. } + | SsaOp::Rcr { .. } + | SsaOp::BSwap { .. } + | SsaOp::BRev { .. } + | SsaOp::BitScanForward { .. } + | SsaOp::BitScanReverse { .. } + | SsaOp::Popcount { .. } + | SsaOp::Parity { .. } + | SsaOp::Select { .. } + | SsaOp::ReadFlags { .. } + | SsaOp::CmpXchg { .. } + | SsaOp::AtomicRmw { .. } => { + // These operations may appear in the shared SSA but are not + // directly expressible in CIL; they should have been lowered + // before code generation. + } } Ok(()) } @@ -4449,7 +4472,7 @@ impl SsaCodeGenerator { // call InitializeArray ; pop 2 (+1) — net: 1 value on stack ConstValue::DecryptedArray { data, - element_type_token, + element_type_ref, element_size, } => { let elem_size = element_size.max(&1); @@ -4457,7 +4480,8 @@ impl SsaCodeGenerator { let num_elements = data.len().checked_div(*elem_size).unwrap_or(0); emitter::emit_ldc_i4(encoder, num_elements as i32)?; - encoder.emit_instruction("newarr", Some(Operand::Token(*element_type_token)))?; + encoder + .emit_instruction("newarr", Some(Operand::Token(element_type_ref.token())))?; if let Some(info) = self.interned_arrays.get(data) { // Compact: dup + ldtoken + call InitializeArray diff --git a/dotscope/src/compiler/codegen/tests.rs b/dotscope/src/compiler/codegen/tests.rs index f6d0e6a3..b7853bf5 100644 --- a/dotscope/src/compiler/codegen/tests.rs +++ b/dotscope/src/compiler/codegen/tests.rs @@ -129,6 +129,7 @@ fn test_optimized_return() { dest: v2, left: v0, right: v1, + flags: None, })); block.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: Some(v2) })); ssa.add_block(block); @@ -380,6 +381,7 @@ fn test_add_ovf_signed() { left: a, right: b, unsigned: false, + flags: None, }); blk.ret_val(dest); }); @@ -400,6 +402,7 @@ fn test_add_ovf_unsigned() { left: a, right: b, unsigned: true, + flags: None, }); blk.ret_val(dest); }); @@ -420,6 +423,7 @@ fn test_sub_ovf_signed() { left: a, right: b, unsigned: false, + flags: None, }); blk.ret_val(dest); }); @@ -440,6 +444,7 @@ fn test_sub_ovf_unsigned() { left: a, right: b, unsigned: true, + flags: None, }); blk.ret_val(dest); }); @@ -460,6 +465,7 @@ fn test_mul_ovf_signed() { left: a, right: b, unsigned: false, + flags: None, }); blk.ret_val(dest); }); @@ -480,6 +486,7 @@ fn test_mul_ovf_unsigned() { left: a, right: b, unsigned: true, + flags: None, }); blk.ret_val(dest); }); @@ -1596,6 +1603,7 @@ fn test_is_immediately_consumed_add_left() { dest: v2, left: var, right: v1, + flags: None, }; assert!(gen.is_immediately_consumed(var, Some(&next))); } @@ -1614,6 +1622,7 @@ fn test_is_immediately_consumed_right_operand_simple_left() { dest: v2, left: other_var, right: var, + flags: None, }; // Should be consumed because left is a simple load assert!(gen.is_immediately_consumed(var, Some(&next))); @@ -1631,6 +1640,7 @@ fn test_is_immediately_consumed_right_operand_complex_left() { dest: v2, left: other_var, right: var, + flags: None, }; // Should NOT be consumed because left is not a simple load assert!(!gen.is_immediately_consumed(var, Some(&next))); @@ -1649,6 +1659,7 @@ fn test_dup_optimization_same_arg_twice() { dest: result, left: x, right: x, + flags: None, }); b.ret_val(result); }); @@ -1670,6 +1681,7 @@ fn test_dup_optimization_same_local_twice() { dest: result, left: x, right: x, + flags: None, }); b.ret_val(result); }); @@ -1692,6 +1704,7 @@ fn test_no_dup_for_different_vars() { dest: result, left: x, right: y, + flags: None, }); b.ret_val(result); }); @@ -1802,6 +1815,7 @@ fn test_loop_variable_phi_nodes() { dest: i_next_id, left: i_phi_id, right: one_id, + flags: None, })); block2.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 1 })); ssa.add_block(block2); @@ -2019,6 +2033,7 @@ fn test_fibonacci_style_loop() { dest: temp_id, left: a_phi_id, right: b_phi_id, + flags: None, })); // i_next = i + 1 block2.add_instruction(SsaInstruction::synthetic(SsaOp::Const { @@ -2029,6 +2044,7 @@ fn test_fibonacci_style_loop() { dest: i_next_id, left: i_phi_id, right: one_id, + flags: None, })); block2.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 1 })); ssa.add_block(block2); @@ -2380,11 +2396,13 @@ fn test_factorial_style_loop() { dest: result_next_id, left: result_phi_id, right: i_phi_id, + flags: None, })); block2.add_instruction(SsaInstruction::synthetic(SsaOp::Sub { dest: i_next_id, left: i_phi_id, right: one_id, + flags: None, })); block2.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 1 })); ssa.add_block(block2); @@ -2624,11 +2642,13 @@ fn test_nested_loop_pattern() { dest: sum_next_id, left: sum_inner_id, right: one_id, + flags: None, })); block4.add_instruction(SsaInstruction::synthetic(SsaOp::Add { dest: j_next_id, left: j_phi_id, right: one_id, + flags: None, })); block4.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 3 })); ssa.add_block(block4); @@ -2643,6 +2663,7 @@ fn test_nested_loop_pattern() { dest: i_next_id, left: i_phi_id, right: one_id, + flags: None, })); block5.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 1 })); ssa.add_block(block5); @@ -2801,6 +2822,7 @@ fn test_accumulator_with_early_exit() { dest: sum_next_id, left: sum_phi_id, right: i_phi_id, + flags: None, })); block2.add_instruction(SsaInstruction::synthetic(SsaOp::Const { dest: hundred_id, @@ -2829,6 +2851,7 @@ fn test_accumulator_with_early_exit() { dest: i_next_id, left: i_phi_id, right: one_id, + flags: None, })); block3.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 1 })); ssa.add_block(block3); diff --git a/dotscope/src/compiler/context.rs b/dotscope/src/compiler/context.rs index 0f347d13..7f3d38d7 100644 --- a/dotscope/src/compiler/context.rs +++ b/dotscope/src/compiler/context.rs @@ -5,7 +5,7 @@ use std::{ collections::{BTreeMap, BTreeSet}, - sync::Arc, + sync::{Arc, RwLock}, time::{Duration, Instant}, }; @@ -13,12 +13,15 @@ use dashmap::{DashMap, DashSet}; use rayon::prelude::*; use crate::{ - analysis::{CallGraph, ConstValue, SsaFunction, SsaOp, SsaVarId, ValueRange}, + analysis::{ + CallGraph, ConstValue, SsaFunction, SsaFunctionCilExt, SsaOp, SsaVarId, ValueRange, + }, compiler::{ - events::EventLog, summary::{CallSiteInfo, MethodSummary}, + EventLog, ProcessingState, }, metadata::token::Token, + CilObject, }; /// Compiler context for the SSA pipeline phase. @@ -85,6 +88,20 @@ pub struct CompilerContext { /// Local variable remappings after optimization, per method. local_remappings: DashMap>>, + /// Assembly under analysis. Set by the pipeline before running passes + /// via [`set_assembly`](Self::set_assembly). Test contexts may leave + /// this `None`. Stored as a `parking_lot::RwLock>>`-like + /// behavior via interior mutability through `arc_swap`-style atomic + /// pointer; here we use a once-set `RwLock>` for + /// simplicity since the assembly set rarely changes after init. + assembly_slot: RwLock>>, + + /// Per-method dirty-tracking state for incremental fixpoint + /// scheduling. Embedded so [`crate::compiler::CompilerContext`] can + /// implement [`analyssa::DirtySet`] without + /// holding an external reference. + pub processing_state: ProcessingState, + /// When analysis started. start_time: Instant, } @@ -109,10 +126,39 @@ impl CompilerContext { known_values: DashMap::new(), known_ranges: DashMap::new(), local_remappings: DashMap::new(), + assembly_slot: RwLock::new(None), + processing_state: ProcessingState::new(), start_time: Instant::now(), } } + /// Sets the assembly under analysis. Idempotent; the pipeline calls + /// this before scheduler dispatch so passes can read the assembly + /// via [`assembly`](Self::assembly). + pub fn set_assembly(&self, assembly: Arc) { + if let Ok(mut slot) = self.assembly_slot.write() { + *slot = Some(assembly); + } + } + + /// Returns the assembly currently under analysis, if set. Passes + /// that need access to .NET metadata reach through this. Returns + /// `None` in test contexts where no assembly was supplied. + #[must_use] + pub fn assembly(&self) -> Option> { + self.assembly_slot.read().ok().and_then(|s| s.clone()) + } + + /// Releases the assembly handle held in this context. Required + /// before unwrapping the assembly from its `Arc` (e.g. for code + /// generation), since `set_assembly` adds a reference that would + /// otherwise prevent the strong-count from reaching one. + pub fn clear_assembly(&self) { + if let Ok(mut slot) = self.assembly_slot.write() { + *slot = None; + } + } + /// Returns the elapsed time since analysis started. #[must_use] pub fn elapsed(&self) -> Duration { diff --git a/dotscope/src/compiler/events.rs b/dotscope/src/compiler/events.rs deleted file mode 100644 index 10d5eb98..00000000 --- a/dotscope/src/compiler/events.rs +++ /dev/null @@ -1,903 +0,0 @@ -//! Unified event logging for the SSA compiler pipeline. -//! -//! This module provides a flexible event logging system that captures all -//! activity during SSA compilation - from individual instruction changes to -//! engine-level decisions. Events can be inspected for debugging or safely -//! ignored when not needed. -//! -//! # Architecture -//! -//! The system is built around three main types: -//! -//! - [`Event`] - A single recorded event (change, warning, info, etc.) -//! - [`EventLog`] - Collection of events with query and summary capabilities -//! - [`EventBuilder`] - Fluent API for creating events -//! -//! # Example -//! -//! ```rust,ignore -//! use dotscope::compiler::{EventLog, EventKind}; -//! -//! let mut log = EventLog::new(); -//! -//! // Record a string decryption -//! log.record(EventKind::StringDecrypted) -//! .at(method_token, 0x42) -//! .message("decrypted: \"hello world\""); -//! -//! // Get summary statistics -//! println!("{}", log.summary()); -//! ``` - -use std::{ - collections::{HashMap, HashSet}, - fmt, - time::Duration, -}; - -use crate::metadata::token::Token; - -/// Categories of events that can be logged. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum EventKind { - /// A string was decrypted and inlined. - StringDecrypted, - /// A constant value was decrypted via emulation of a decryptor method. - ConstantDecrypted, - /// A constant value was folded/propagated. - ConstantFolded, - /// A conditional branch was simplified to unconditional. - BranchSimplified, - /// An instruction was removed. - InstructionRemoved, - /// A basic block was removed. - BlockRemoved, - /// A method call was inlined. - MethodInlined, - /// A phi node was simplified. - PhiSimplified, - /// An unknown value was resolved to a constant. - ValueResolved, - /// A method was marked as dead (unreachable). - MethodMarkedDead, - /// Control flow was restructured (e.g., unflattening). - ControlFlowRestructured, - /// An opaque predicate was identified and removed. - OpaquePredicateRemoved, - /// A copy operation was propagated away. - CopyPropagated, - /// An array was decrypted. - ArrayDecrypted, - /// An expensive operation was strength-reduced. - StrengthReduced, - /// Orphaned variables were removed from the variable table. - VariablesCompacted, - /// An encrypted method body was decrypted (anti-tamper). - MethodBodyDecrypted, - /// An encrypted manifest resource was decrypted and re-injected as a real - /// `ManifestResource` row (e.g. .NET Reactor Stage 7 resource encryption). - ResourceDecrypted, - /// Anti-tamper protection was removed. - AntiTamperRemoved, - /// An obfuscation artifact was removed (method, type, metadata). - ArtifactRemoved, - - /// Code regeneration completed. - CodeRegenerated, -} - -impl EventKind { - /// Returns a human-readable description of this event kind. - #[must_use] - pub fn description(&self) -> &'static str { - match self { - // Transformations - Self::StringDecrypted => "string decrypted", - Self::ConstantDecrypted => "constant decrypted", - Self::ConstantFolded => "constant folded", - Self::BranchSimplified => "branch simplified", - Self::InstructionRemoved => "instruction removed", - Self::BlockRemoved => "block removed", - Self::MethodInlined => "method inlined", - Self::PhiSimplified => "phi simplified", - Self::ValueResolved => "value resolved", - Self::MethodMarkedDead => "method marked dead", - Self::ControlFlowRestructured => "control flow restructured", - Self::OpaquePredicateRemoved => "opaque predicate removed", - Self::CopyPropagated => "copy propagated", - Self::ArrayDecrypted => "array decrypted", - Self::StrengthReduced => "strength reduced", - Self::VariablesCompacted => "variables compacted", - Self::MethodBodyDecrypted => "method body decrypted", - Self::ResourceDecrypted => "resource decrypted", - Self::AntiTamperRemoved => "anti-tamper removed", - Self::ArtifactRemoved => "artifact removed", - Self::CodeRegenerated => "code regenerated", - } - } - - /// Returns true if this event represents a code transformation. - #[must_use] - pub fn is_transformation(&self) -> bool { - !matches!(self, Self::CodeRegenerated) - } -} - -impl fmt::Display for EventKind { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(self.description()) - } -} - -/// A single logged event. -#[derive(Debug, Clone)] -pub struct Event { - /// The type of event. - pub kind: EventKind, - /// The method where the event occurred (if applicable). - pub method: Option, - /// Location within the method (offset or block ID). - pub location: Option, - /// Human-readable description. - pub message: String, - /// Associated pass name (if from a pass). - pub pass: Option, -} - -impl Event { - /// Creates a new event with the given kind and message. - fn new(kind: EventKind, message: impl Into) -> Self { - Self { - kind, - method: None, - location: None, - message: message.into(), - pass: None, - } - } -} - -impl fmt::Display for Event { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "[{}] {}", self.kind, self.message) - } -} - -/// Builder for creating events with a fluent API. -/// -/// Created by [`EventLog::record`]. The event is automatically added -/// to the log when the builder is dropped. -/// -/// # Example -/// -/// ```rust,ignore -/// log.record(EventKind::StringDecrypted) -/// .at(method, 0x42) -/// .message("decrypted: \"hello\""); -/// ``` -pub struct EventBuilder<'a> { - log: &'a EventLog, - kind: EventKind, - method: Option, - location: Option, - message: Option, - pass: Option, -} - -impl<'a> EventBuilder<'a> { - fn new(log: &'a EventLog, kind: EventKind) -> Self { - Self { - log, - kind, - method: None, - location: None, - message: None, - pass: None, - } - } - - /// Sets the method and location where the event occurred. - pub fn at(mut self, method: Token, location: usize) -> Self { - self.method = Some(method); - self.location = Some(location); - self - } - - /// Sets only the method (for method-level events without specific location). - pub fn method(mut self, method: Token) -> Self { - self.method = Some(method); - self - } - - /// Sets the location (for when method is already set or not applicable). - pub fn location(mut self, location: usize) -> Self { - self.location = Some(location); - self - } - - /// Sets a custom message describing the event. - pub fn message(mut self, msg: impl Into) -> Self { - self.message = Some(msg.into()); - self - } - - /// Associates this event with a specific pass. - pub fn pass(mut self, pass_name: impl Into) -> Self { - self.pass = Some(pass_name.into()); - self - } -} - -impl Drop for EventBuilder<'_> { - fn drop(&mut self) { - let message = self - .message - .take() - .unwrap_or_else(|| self.kind.description().to_string()); - - let event = Event { - kind: self.kind, - method: self.method.take(), - location: self.location.take(), - message, - pass: self.pass.take(), - }; - - self.log.events.push(event); - } -} - -/// Collection of events from deobfuscation. -/// -/// Provides methods for recording events, querying them, and generating -/// summaries. Statistics are derived from the events rather than tracked -/// separately. -/// -/// This type is thread-safe: events can be appended concurrently from -/// multiple threads using shared references (`&self`). -#[derive(Debug)] -pub struct EventLog { - events: boxcar::Vec, -} - -impl Default for EventLog { - fn default() -> Self { - Self { - events: boxcar::Vec::new(), - } - } -} - -impl Clone for EventLog { - fn clone(&self) -> Self { - let new_log = Self::new(); - for (_, event) in &self.events { - new_log.events.push(event.clone()); - } - new_log - } -} - -impl EventLog { - /// Creates an empty event log. - #[must_use] - pub fn new() -> Self { - Self { - events: boxcar::Vec::new(), - } - } - - /// Returns true if no events have been logged. - #[must_use] - pub fn is_empty(&self) -> bool { - self.events.count() == 0 - } - - /// Returns the total number of events. - #[must_use] - pub fn len(&self) -> usize { - self.events.count() - } - - /// Starts building a new event of the given kind. - /// - /// The event is automatically added when the builder is dropped. - /// - /// # Example - /// - /// ```rust,ignore - /// log.record(EventKind::ConstantFolded) - /// .at(method, location) - /// .message("42 + 0 → 42"); - /// ``` - pub fn record(&self, kind: EventKind) -> EventBuilder<'_> { - EventBuilder::new(self, kind) - } - - /// Merges another event log into this one. - pub fn merge(&self, other: &EventLog) { - for (_, event) in &other.events { - self.events.push(event.clone()); - } - } - - /// Returns true if any event of the given kind exists. - #[must_use] - pub fn has(&self, kind: EventKind) -> bool { - self.events.iter().any(|(_, e)| e.kind == kind) - } - - /// Returns true if any of the given event kinds exist. - #[must_use] - pub fn has_any(&self, kinds: &[EventKind]) -> bool { - self.events.iter().any(|(_, e)| kinds.contains(&e.kind)) - } - - /// Counts events of the given kind. - #[must_use] - pub fn count_kind(&self, kind: EventKind) -> usize { - self.events.iter().filter(|(_, e)| e.kind == kind).count() - } - - /// Returns an iterator over all events. - pub fn iter(&self) -> impl Iterator { - self.events.iter().map(|(_, e)| e) - } - - /// Returns an iterator over events of a specific kind. - pub fn filter_kind(&self, kind: EventKind) -> impl Iterator + '_ { - self.events - .iter() - .filter_map(move |(_, e)| if e.kind == kind { Some(e) } else { None }) - } - - /// Takes ownership of the events by cloning into a new EventLog. - /// - /// This is useful when the context is being consumed and you need to - /// extract the events. Since `boxcar::Vec` is append-only and doesn't - /// support draining, this creates a clone. - #[must_use] - pub fn take(&self) -> EventLog { - self.clone() - } - - /// Returns an iterator over events for a specific method. - pub fn filter_method(&self, method: Token) -> impl Iterator + '_ { - self.events.iter().filter_map(move |(_, e)| { - if e.method == Some(method) { - Some(e) - } else { - None - } - }) - } - - /// Returns an iterator over transformation events only. - pub fn transformations(&self) -> impl Iterator + '_ { - self.events.iter().filter_map(|(_, e)| { - if e.kind.is_transformation() { - Some(e) - } else { - None - } - }) - } - - /// Counts events grouped by kind. - #[must_use] - pub fn count_by_kind(&self) -> HashMap { - let mut counts: HashMap = HashMap::new(); - for (_, event) in &self.events { - let entry = counts.entry(event.kind).or_insert(0); - *entry = entry.saturating_add(1); - } - counts - } - - /// Counts events grouped by kind, starting from the given offset. - /// - /// Used by the scheduler to compute per-pass event deltas without - /// iterating the entire log. - #[must_use] - pub fn count_by_kind_since(&self, offset: usize) -> HashMap { - let mut counts: HashMap = HashMap::new(); - for (idx, event) in &self.events { - if idx >= offset { - let entry = counts.entry(event.kind).or_insert(0); - *entry = entry.saturating_add(1); - } - } - counts - } - - /// Returns the number of transformation events. - #[must_use] - pub fn transformation_count(&self) -> usize { - self.events - .iter() - .filter(|(_, e)| e.kind.is_transformation()) - .count() - } - - /// Returns the number of unique methods with events. - #[must_use] - pub fn methods_affected(&self) -> usize { - self.events - .iter() - .filter_map(|(_, e)| e.method) - .collect::>() - .len() - } - - /// Generates a human-readable summary of all events. - #[must_use] - pub fn summary(&self) -> String { - if self.is_empty() { - return "no events".to_string(); - } - - let counts = self.count_by_kind(); - - // Only show transformation counts in summary - let mut parts: Vec = counts - .iter() - .filter(|(k, _)| k.is_transformation()) - .map(|(kind, count)| format!("{} {}", count, kind.description())) - .collect(); - - if parts.is_empty() { - return format!("{} events", self.len()); - } - - parts.sort(); - parts.join(", ") - } -} - -/// Iterator wrapper for EventLog that yields &Event -pub struct EventLogIter<'a> { - inner: boxcar::Iter<'a, Event>, -} - -impl<'a> Iterator for EventLogIter<'a> { - type Item = &'a Event; - - fn next(&mut self) -> Option { - self.inner.next().map(|(_, e)| e) - } -} - -impl<'a> IntoIterator for &'a EventLog { - type Item = &'a Event; - type IntoIter = EventLogIter<'a>; - - fn into_iter(self) -> Self::IntoIter { - EventLogIter { - inner: self.events.iter(), - } - } -} - -impl Extend for EventLog { - fn extend>(&mut self, iter: T) { - for event in iter { - self.events.push(event); - } - } -} - -impl FromIterator for EventLog { - fn from_iter>(iter: T) -> Self { - let log = Self::new(); - for event in iter { - log.events.push(event); - } - log - } -} - -/// Statistics derived from an EventLog. -/// -/// This replaces manual stat tracking - all numbers are computed from events. -#[derive(Debug, Clone, Default)] -pub struct DerivedStats { - /// Number of methods that had any transformations. - pub methods_transformed: usize, - /// Number of strings decrypted. - pub strings_decrypted: usize, - /// Number of arrays decrypted. - pub arrays_decrypted: usize, - /// Number of constants folded. - pub constants_folded: usize, - /// Number of constants decrypted. - pub constants_decrypted: usize, - /// Number of instructions removed. - pub instructions_removed: usize, - /// Number of blocks removed. - pub blocks_removed: usize, - /// Number of branches simplified. - pub branches_simplified: usize, - /// Number of opaque predicates removed. - pub opaque_predicates_removed: usize, - /// Number of methods inlined. - pub methods_inlined: usize, - /// Number of methods marked dead. - pub methods_marked_dead: usize, - /// Number of methods with code regenerated. - pub methods_regenerated: usize, - /// Number of artifacts removed (methods, types, metadata). - pub artifacts_removed: usize, - /// Number of pass iterations. - pub iterations: usize, - /// Processing time. - pub total_time: Duration, -} - -impl DerivedStats { - /// Computes statistics from an event log. - #[must_use] - pub fn from_log(log: &EventLog) -> Self { - let counts = log.count_by_kind(); - let get = |kind: EventKind| counts.get(&kind).copied().unwrap_or(0); - - Self { - methods_transformed: log.methods_affected(), - strings_decrypted: get(EventKind::StringDecrypted), - arrays_decrypted: get(EventKind::ArrayDecrypted), - constants_folded: get(EventKind::ConstantFolded), - constants_decrypted: get(EventKind::ConstantDecrypted), - instructions_removed: get(EventKind::InstructionRemoved), - blocks_removed: get(EventKind::BlockRemoved), - branches_simplified: get(EventKind::BranchSimplified), - opaque_predicates_removed: get(EventKind::OpaquePredicateRemoved), - methods_inlined: get(EventKind::MethodInlined), - methods_marked_dead: get(EventKind::MethodMarkedDead), - methods_regenerated: get(EventKind::CodeRegenerated), - artifacts_removed: get(EventKind::ArtifactRemoved), - iterations: 0, - total_time: Duration::ZERO, - } - } - - /// Sets the total processing time. - #[must_use] - pub fn with_time(mut self, time: Duration) -> Self { - self.total_time = time; - self - } - - /// Sets the number of iterations. - #[must_use] - pub fn with_iterations(mut self, iterations: usize) -> Self { - self.iterations = iterations; - self - } - - /// Generates a human-readable summary. - #[must_use] - pub fn summary(&self) -> String { - let mut parts = Vec::new(); - - // Methods affected - if self.methods_transformed > 0 { - parts.push(format!("{} methods", self.methods_transformed)); - } - - // Decryption stats (grouped) - if self.strings_decrypted > 0 { - parts.push(format!("{} strings decrypted", self.strings_decrypted)); - } - if self.arrays_decrypted > 0 { - parts.push(format!("{} arrays decrypted", self.arrays_decrypted)); - } - if self.constants_decrypted > 0 { - parts.push(format!("{} constants decrypted", self.constants_decrypted)); - } - - // Optimization stats - if self.constants_folded > 0 { - parts.push(format!("{} constants folded", self.constants_folded)); - } - if self.instructions_removed > 0 { - parts.push(format!( - "{} instructions removed", - self.instructions_removed - )); - } - if self.blocks_removed > 0 { - parts.push(format!("{} blocks removed", self.blocks_removed)); - } - if self.branches_simplified > 0 { - parts.push(format!("{} branches simplified", self.branches_simplified)); - } - if self.methods_inlined > 0 { - parts.push(format!("{} inlined", self.methods_inlined)); - } - if self.opaque_predicates_removed > 0 { - parts.push(format!( - "{} opaque predicates", - self.opaque_predicates_removed - )); - } - - // Cleanup stats - if self.methods_marked_dead > 0 { - parts.push(format!("{} dead methods", self.methods_marked_dead)); - } - if self.methods_regenerated > 0 { - parts.push(format!("{} regenerated", self.methods_regenerated)); - } - if self.artifacts_removed > 0 { - parts.push(format!("{} artifacts removed", self.artifacts_removed)); - } - - let stats = if parts.is_empty() { - "no transformations".to_string() - } else { - parts.join(", ") - }; - - if self.total_time.as_millis() > 0 { - format!( - "{} in {:?} ({} iterations)", - stats, self.total_time, self.iterations - ) - } else { - stats - } - } -} - -impl fmt::Display for DerivedStats { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - f.write_str(&self.summary()) - } -} - -/// Truncates a string for display, adding ellipsis if needed. -#[must_use] -pub fn truncate_string(s: &str, max_len: usize) -> String { - if s.len() <= max_len { - s.to_string() - } else { - let end = max_len.saturating_sub(3); - let split = s - .char_indices() - .map(|(i, _)| i) - .take_while(|&i| i <= end) - .last() - .unwrap_or(0); - format!("{}...", &s[..split]) - } -} - -#[cfg(test)] -mod tests { - use std::{sync::Arc, thread}; - - use crate::{ - compiler::events::{DerivedStats, EventKind, EventLog}, - metadata::token::Token, - }; - - #[test] - fn test_empty_log() { - let log = EventLog::new(); - assert!(log.is_empty()); - assert_eq!(log.len(), 0); - assert!(!log.has(EventKind::StringDecrypted)); - } - - #[test] - fn test_record_event() { - let log = EventLog::new(); - let method = Token::new(0x06000001); - - log.record(EventKind::StringDecrypted) - .at(method, 0x10) - .message("decrypted: \"hello\""); - - assert!(!log.is_empty()); - assert_eq!(log.len(), 1); - assert!(log.has(EventKind::StringDecrypted)); - - let event = log.iter().next().unwrap(); - assert_eq!(event.method, Some(method)); - assert_eq!(event.location, Some(0x10)); - assert_eq!(event.message, "decrypted: \"hello\""); - } - - #[test] - fn test_multiple_events() { - let log = EventLog::new(); - let method = Token::new(0x06000001); - - log.record(EventKind::StringDecrypted) - .at(method, 0x10) - .message("first"); - log.record(EventKind::ConstantFolded) - .at(method, 0x20) - .message("second"); - - assert_eq!(log.len(), 2); - assert!(log.has(EventKind::StringDecrypted)); - assert!(log.has(EventKind::ConstantFolded)); - assert!(!log.has(EventKind::BlockRemoved)); - } - - #[test] - fn test_has_any() { - let log = EventLog::new(); - let method = Token::new(0x06000001); - - log.record(EventKind::StringDecrypted).at(method, 0x10); - - assert!(log.has_any(&[EventKind::StringDecrypted, EventKind::ArrayDecrypted])); - assert!(!log.has_any(&[EventKind::BlockRemoved, EventKind::MethodInlined])); - } - - #[test] - fn test_merge() { - let log1 = EventLog::new(); - let log2 = EventLog::new(); - let method = Token::new(0x06000001); - - log1.record(EventKind::StringDecrypted).at(method, 0x10); - log2.record(EventKind::ConstantFolded).at(method, 0x20); - - log1.merge(&log2); - - assert_eq!(log1.len(), 2); - assert!(log1.has(EventKind::StringDecrypted)); - assert!(log1.has(EventKind::ConstantFolded)); - } - - #[test] - fn test_summary() { - let log = EventLog::new(); - let method = Token::new(0x06000001); - - log.record(EventKind::StringDecrypted).at(method, 0x10); - log.record(EventKind::StringDecrypted).at(method, 0x20); - log.record(EventKind::ConstantFolded).at(method, 0x30); - - let summary = log.summary(); - assert!(summary.contains("2 string decrypted")); - assert!(summary.contains("1 constant folded")); - } - - #[test] - fn test_count_by_kind() { - let log = EventLog::new(); - let method = Token::new(0x06000001); - - log.record(EventKind::StringDecrypted).at(method, 0x10); - log.record(EventKind::StringDecrypted).at(method, 0x20); - log.record(EventKind::ConstantFolded).at(method, 0x30); - - let counts = log.count_by_kind(); - assert_eq!(counts.get(&EventKind::StringDecrypted), Some(&2)); - assert_eq!(counts.get(&EventKind::ConstantFolded), Some(&1)); - assert_eq!(counts.get(&EventKind::BlockRemoved), None); - } - - #[test] - fn test_count_by_kind_since() { - let log = EventLog::new(); - let method = Token::new(0x06000001); - - log.record(EventKind::StringDecrypted).at(method, 0x10); - log.record(EventKind::StringDecrypted).at(method, 0x20); - - let offset = log.len(); - - log.record(EventKind::ConstantFolded).at(method, 0x30); - log.record(EventKind::ConstantFolded).at(method, 0x40); - log.record(EventKind::StringDecrypted).at(method, 0x50); - - let counts = log.count_by_kind_since(offset); - assert_eq!(counts.get(&EventKind::ConstantFolded), Some(&2)); - assert_eq!(counts.get(&EventKind::StringDecrypted), Some(&1)); - assert_eq!(counts.get(&EventKind::BlockRemoved), None); - - // Offset 0 should count everything - let all = log.count_by_kind_since(0); - assert_eq!(all.get(&EventKind::StringDecrypted), Some(&3)); - assert_eq!(all.get(&EventKind::ConstantFolded), Some(&2)); - } - - #[test] - fn test_derived_stats() { - let log = EventLog::new(); - let method1 = Token::new(0x06000001); - let method2 = Token::new(0x06000002); - - log.record(EventKind::StringDecrypted).at(method1, 0x10); - log.record(EventKind::StringDecrypted).at(method2, 0x20); - log.record(EventKind::ConstantFolded).at(method1, 0x30); - - let stats = DerivedStats::from_log(&log); - assert_eq!(stats.methods_transformed, 2); - assert_eq!(stats.strings_decrypted, 2); - assert_eq!(stats.constants_folded, 1); - } - - #[test] - fn test_filter_methods() { - let log = EventLog::new(); - let method1 = Token::new(0x06000001); - let method2 = Token::new(0x06000002); - - log.record(EventKind::StringDecrypted).at(method1, 0x10); - log.record(EventKind::ConstantFolded).at(method2, 0x20); - log.record(EventKind::BlockRemoved).at(method1, 0x30); - - let method1_events: Vec<_> = log.filter_method(method1).collect(); - assert_eq!(method1_events.len(), 2); - } - - #[test] - fn test_transformations_filter() { - let log = EventLog::new(); - let method = Token::new(0x06000001); - - log.record(EventKind::StringDecrypted).at(method, 0x10); - log.record(EventKind::BlockRemoved).at(method, 0x20); - - let transformations: Vec<_> = log.transformations().collect(); - assert_eq!(transformations.len(), 2); - } - - #[test] - fn test_event_with_pass() { - let log = EventLog::new(); - let method = Token::new(0x06000001); - - log.record(EventKind::ConstantFolded) - .at(method, 0x10) - .pass("ConstantFolding") - .message("42 + 0 → 42"); - - let event = log.iter().next().unwrap(); - assert_eq!(event.pass.as_deref(), Some("ConstantFolding")); - } - - #[test] - fn test_default_message() { - let log = EventLog::new(); - let method = Token::new(0x06000001); - - // No explicit message - should use default from kind - log.record(EventKind::StringDecrypted).at(method, 0x10); - - let event = log.iter().next().unwrap(); - assert_eq!(event.message, "string decrypted"); - } - - #[test] - fn test_thread_safe_append() { - let log = Arc::new(EventLog::new()); - let mut handles = vec![]; - - // Spawn multiple threads that append to the same log - for i in 0..4 { - let log_clone = Arc::clone(&log); - handles.push(thread::spawn(move || { - for j in 0..100 { - let method = Token::new(0x06000000 + (i * 100 + j) as u32); - log_clone - .record(EventKind::StringDecrypted) - .at(method, j) - .message(format!("thread {} event {}", i, j)); - } - })); - } - - for handle in handles { - handle.join().unwrap(); - } - - // All 400 events should be present - assert_eq!(log.len(), 400); - } -} diff --git a/dotscope/src/compiler/host.rs b/dotscope/src/compiler/host.rs new file mode 100644 index 00000000..b7b9c81d --- /dev/null +++ b/dotscope/src/compiler/host.rs @@ -0,0 +1,181 @@ +//! `SsaPassHost` + CIL extension trait impls on [`CompilerContext`]. +//! +//! [`CompilerContext`] is the dotscope-side adapter that the analyssa +//! [`PassScheduler`](analyssa::scheduling::PassScheduler) drives. This +//! module wires it into the analyssa host trait family +//! ([`World`](analyssa::World), [`SsaStore`](analyssa::SsaStore), +//! [`DirtySet`](analyssa::DirtySet), +//! [`SsaPassHost`](analyssa::scheduling::SsaPassHost)) and adds the CIL +//! extension trait [`CilHost`] which surfaces the assembly handle and +//! other .NET-specific accessors that target-generic analyssa passes don't +//! need to know about. + +use std::sync::Arc; + +use analyssa::{ + events::EventLog, + host::{DirtySet, SsaStore}, + ir::function::SsaFunction, + scheduling::SsaPassHost, + world::World, + PointerSize, +}; + +use crate::{ + analysis::{CilTarget, MethodRef}, + compiler::CompilerContext, + metadata::token::Token, + CilObject, +}; + +/// CIL-side host extension trait. Used by passes that need access to +/// CIL-specific facilities (the assembly handle, member resolution, +/// etc.) that don't make sense in the target-agnostic analyssa host +/// surface. +/// +/// `CompilerContext` is the canonical impl. Passes that bound their +/// generics on `H: CilHost` get full access to the CIL host while +/// remaining storable in `Box>`. +pub trait CilHost: SsaPassHost { + /// The assembly currently under analysis. Returns `None` in test + /// contexts where [`CompilerContext::set_assembly`] was never + /// called. + fn assembly(&self) -> Option>; + + /// Convenience accessor for the underlying `CompilerContext`. Lets + /// CIL pass impls reach the rich CIL-only fields (`no_inline`, + /// `summaries`, `known_values`, etc.) that aren't surfaced through + /// the analyssa traits. + fn ctx(&self) -> &CompilerContext; +} + +impl World for CompilerContext { + fn all_methods(&self) -> Vec { + self.ssa_functions + .iter() + .map(|e| MethodRef::new(*e.key())) + .collect() + } + + fn entry_points(&self) -> Vec { + self.entry_points + .iter() + .map(|t| MethodRef::new(*t)) + .collect() + } + + fn callees(&self, method: &MethodRef) -> Vec { + // Read direct callees from the call graph. + self.call_graph + .callees(method.0) + .into_iter() + .map(MethodRef::new) + .collect() + } + + fn is_dead(&self, method: &MethodRef) -> bool { + self.dead_methods.contains(&method.0) + } + + fn mark_dead(&self, method: &MethodRef) { + self.dead_methods.insert(method.0); + } + + fn methods_reverse_topological(&self) -> Vec { + CompilerContext::methods_reverse_topological(self) + .into_iter() + .map(MethodRef::new) + .collect() + } +} + +impl SsaStore for CompilerContext { + fn contains(&self, method: &MethodRef) -> bool { + self.ssa_functions.contains_key(&method.0) + } + + fn take_ssa(&self, method: &MethodRef) -> Option> { + self.ssa_functions.remove(&method.0).map(|(_, ssa)| ssa) + } + + fn insert_ssa(&self, method: MethodRef, ssa: SsaFunction) { + self.ssa_functions.insert(method.0, ssa); + } + + fn clone_ssa(&self, method: &MethodRef) -> Option> { + self.ssa_functions.get(&method.0).map(|r| r.clone()) + } + + fn iter_methods(&self) -> Vec { + self.ssa_functions + .iter() + .map(|e| MethodRef::new(*e.key())) + .collect() + } +} + +impl DirtySet for CompilerContext { + fn mark_dirty(&self, method: &MethodRef) { + self.processing_state.mark_method_dirty(method.0); + } + + fn is_dirty(&self, method: &MethodRef) -> bool { + self.processing_state.method_dirty.contains(&method.0) + } + + fn dirty_snapshot(&self) -> Vec { + self.processing_state + .method_dirty + .iter() + .map(|t| MethodRef::new(*t)) + .collect() + } + + fn clear_dirty_for(&self, method: &MethodRef) { + self.processing_state.mark_method_stable(method.0); + } + + fn mark_processed(&self, method: &MethodRef) { + self.processed_methods.insert(method.0); + } + + fn is_processed(&self, method: &MethodRef) -> bool { + self.processed_methods.contains(&method.0) + } +} + +impl SsaPassHost for CompilerContext { + fn events(&self) -> &EventLog { + &self.events + } + + fn ptr_size(&self) -> PointerSize { + // If the assembly is set, derive from its PE header. Test + // contexts without an assembly default to 64-bit (matches the + // existing dotscope default in `analysis::ssa::CilTarget::x64()`). + match self.assembly() { + Some(asm) => PointerSize::from_is_64bit(asm.file().pe().is_64bit), + None => PointerSize::Bit64, + } + } +} + +impl CilHost for CompilerContext { + fn assembly(&self) -> Option> { + Self::assembly(self) + } + + fn ctx(&self) -> &CompilerContext { + self + } +} + +// Helper: convert between CIL `Token` and analyssa `MethodRef`. Both are +// transparent newtypes around `u32`; the `From` impl is provided by the +// existing `MethodRef::from(Token)` definition in `analysis/ssa/types.rs`. +#[allow(dead_code)] +fn _token_methodref_compat() { + let _: MethodRef = MethodRef::from(Token::new(0)); + // Reverse direction (MethodRef -> Token) is via the `.0` field. + let _: Token = MethodRef::new(Token::new(0)).0; +} diff --git a/dotscope/src/compiler/mod.rs b/dotscope/src/compiler/mod.rs index a7db2b86..06ee592a 100644 --- a/dotscope/src/compiler/mod.rs +++ b/dotscope/src/compiler/mod.rs @@ -50,17 +50,31 @@ mod codegen; mod context; -mod events; +mod host; mod pass; mod passes; mod scheduler; mod state; mod summary; +use crate::analysis::CilTarget; + +pub use analyssa::events::{DerivedStats, EventKind, EventListener, NullListener}; pub use codegen::{CompilationResult, SsaCodeGenerator}; pub use context::CompilerContext; -pub use events::{DerivedStats, Event, EventKind, EventLog}; -pub use pass::{ModificationScope, PassCapability, PassPhase, SsaPass}; + +/// CIL-defaulted alias of [`analyssa::events::Event`]. +pub type Event = analyssa::events::Event; +/// CIL-defaulted alias of [`analyssa::events::EventLog`]. +pub type EventLog = analyssa::events::EventLog; +/// CIL-defaulted alias of [`analyssa::events::EventBuilder`]. +pub type EventBuilder<'a, L = EventLog> = analyssa::events::EventBuilder<'a, CilTarget, L>; + +pub use host::CilHost; +pub use pass::{ + CilCapability, DeobfuscationCapability, ModificationScope, PassCapability, PassPhase, SsaPass, + SsaPassHost, +}; pub use passes::{ AlgebraicSimplificationPass, BlockMergingPass, ConstantPropagationPass, ControlFlowSimplificationPass, CopyPropagationPass, DeadCodeEliminationPass, diff --git a/dotscope/src/compiler/pass.rs b/dotscope/src/compiler/pass.rs index 9b8d8eb7..07f78362 100644 --- a/dotscope/src/compiler/pass.rs +++ b/dotscope/src/compiler/pass.rs @@ -1,55 +1,74 @@ -//! Pass traits and infrastructure for the SSA optimization pipeline. +//! CIL-side pass infrastructure — re-exports analyssa's target-agnostic +//! [`SsaPass`] trait and supporting types, plus the CIL-specific +//! [`CilCapability`] enum and the [`PassPhase`] convention dotscope uses +//! to organize its deobfuscation pipeline. //! -//! This module defines the `SsaPass` trait that all SSA transformation passes implement, -//! along with the [`PassCapability`] enum used for capability-based pass scheduling. -//! -//! # Capability-Based Scheduling -//! -//! Passes can declare what they [`provides`](SsaPass::provides) and -//! [`requires`](SsaPass::requires) using [`PassCapability`] values. The scheduler -//! uses these declarations to build a dependency graph and topologically sort -//! passes into execution layers. Passes that don't declare capabilities fall back -//! to their assigned phase ordering. -//! -//! # Modification Scope -//! -//! Each pass declares a [`ModificationScope`] that describes the extent of its -//! modifications to the SSA function. The scheduler uses this to select the -//! appropriate repair strategy after each pass: -//! -//! - [`ModificationScope::UsesOnly`] — No repair needed (SSA invariants preserved) -//! - [`ModificationScope::InstructionsOnly`] — Lightweight repair (recompute def-use, clean up) -//! - [`ModificationScope::CfgModifying`] — Full `rebuild_ssa()` (recompute dominators, phis, etc.) +//! The actual trait and scheduling engine live in +//! [`analyssa::scheduling`]; this module specializes them for CIL. + +pub use analyssa::scheduling::{DeobfuscationCapability, ModificationScope, SsaPass, SsaPassHost}; + +/// CIL-side capability tag. +/// +/// Today this is a flat mirror of analyssa's [`DeobfuscationCapability`] +/// vocabulary so existing CIL passes can keep declaring +/// `PassCapability::DecryptedStrings` etc. without ceremony. Future +/// CIL-only milestones (e.g. .NET-specific tags not shared with x86/MIPS +/// hosts) land here as new variants. The +/// [`From`] impl bridges analyssa-side passes that +/// declare provides/requires using the shared vocabulary. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum CilCapability { + /// Static field values have been resolved to concrete constants. + ResolvedStaticFields, + /// Encrypted strings have been decrypted. + DecryptedStrings, + /// Control flow flattening has been reversed. + RestoredControlFlow, + /// Opaque predicates have been simplified or removed. + SimplifiedPredicates, + /// Proxy or virtual calls have been devirtualized. + DevirtualizedCalls, + /// Small or pure methods have been inlined at their call sites. + InlinedMethods, +} -use crate::{ - analysis::SsaFunction, compiler::CompilerContext, metadata::token::Token, CilObject, Result, -}; +impl From for CilCapability { + fn from(cap: DeobfuscationCapability) -> Self { + match cap { + DeobfuscationCapability::ResolvedStaticFields => Self::ResolvedStaticFields, + DeobfuscationCapability::DecryptedStrings => Self::DecryptedStrings, + DeobfuscationCapability::RestoredControlFlow => Self::RestoredControlFlow, + DeobfuscationCapability::SimplifiedPredicates => Self::SimplifiedPredicates, + DeobfuscationCapability::DevirtualizedCalls => Self::DevirtualizedCalls, + DeobfuscationCapability::InlinedMethods => Self::InlinedMethods, + } + } +} -/// Execution phase for an SSA pass. +/// Execution phase for a CIL pass — fallback layer assignment when the +/// scheduler can't derive ordering from capability dependencies. /// -/// Determines when in the pipeline a pass runs. The scheduler groups passes -/// by phase and executes them in layer order: `Structure` → `Value` → -/// `Simplify` → `Inline`. `Normalize` passes run between every layer's -/// fixpoint iterations rather than as a layer themselves. +/// Convention: `Structure=0`, `Value=1`, `Simplify=2`, `Inline=3`. +/// `Normalize` passes don't participate in layered scheduling and run +/// between every layer's fixpoint iterations (registered via +/// [`analyssa::scheduling::PassScheduler::add_normalize`]). #[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] pub enum PassPhase { - /// Structural transformations (e.g., control flow unflattening). + /// Structural transformations (e.g. control-flow unflattening). Structure, - /// Value-level transformations (e.g., constant decryption, string decryption). + /// Value-level transformations (e.g. constant decryption). Value, - /// Simplification passes (e.g., proxy resolution, anti-debug neutralization). + /// Simplification passes (e.g. proxy resolution). Simplify, - /// Inlining passes (e.g., delegate inlining). + /// Inlining passes (e.g. delegate inlining). Inline, - /// Normalization passes (e.g., nop removal, dead code elimination). + /// Normalization passes (e.g. nop removal, dead-code elimination). Normalize, } impl PassPhase { /// Returns the fallback scheduler layer for this phase. - /// - /// Convention: Structure=0, Value=1, Simplify=2, Inline=3. - /// Normalize passes don't participate in layered scheduling. #[must_use] pub fn as_layer(self) -> usize { match self { @@ -62,250 +81,5 @@ impl PassPhase { } } -/// Capability that a pass can provide or require. -/// -/// The scheduler uses these to build a dependency graph: if pass A provides -/// `ResolvedStaticFields` and pass B requires it, A is scheduled before B. -/// Passes that don't declare any capabilities fall back to phase-based ordering. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum PassCapability { - /// Static field values have been resolved to concrete constants. - ResolvedStaticFields, - /// Encrypted strings have been decrypted. - DecryptedStrings, - /// Control flow flattening has been reversed. - RestoredControlFlow, - /// Opaque predicates have been simplified/removed. - SimplifiedPredicates, - /// Proxy/delegate calls have been devirtualized. - DevirtualizedCalls, - /// Small methods have been inlined at call sites. - InlinedMethods, -} - -/// Describes the extent of modifications a pass makes to the SSA function. -/// -/// The scheduler uses this to select the minimum repair necessary after a pass -/// runs, avoiding expensive full SSA reconstruction when it isn't needed. -/// -/// Passes should declare the **tightest** scope that covers all their -/// modifications. For example, a pass that only forwards uses should declare -/// `UsesOnly`, not `CfgModifying`. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] -pub enum ModificationScope { - /// The pass only replaces uses of variables with other existing variables. - /// - /// Examples: GVN (forwarding redundant uses to earlier definitions). - /// - /// SSA invariants are preserved automatically — no repair needed. - /// The pass does not create new variables, does not change instruction - /// opcodes (except possibly Nop-ing dead copies as a side effect), - /// and does not modify the CFG. - UsesOnly, - - /// The pass replaces or removes instructions but does not change the CFG. - /// - /// Examples: constant propagation (replacing ops with Const), copy - /// propagation (Nop-ing propagated copies), DCE (Nop-ing dead - /// instructions), algebraic simplification, strength reduction. - /// - /// After this scope, a lightweight repair is needed to: - /// - Strip Nop instructions and reindex DefSites - /// - Recompute variable metadata from surviving instructions - /// - Eliminate trivial phis and compact variables - /// - /// No dominator/dominance frontier recomputation is needed since - /// the CFG structure is unchanged. - InstructionsOnly, - - /// The pass may add or remove blocks, change successors/predecessors, - /// or otherwise modify control flow. - /// - /// Examples: control-flow unflattening, jump threading, block merging, - /// loop canonicalization, inlining. - /// - /// After this scope, a full `rebuild_ssa()` is required to restore - /// SSA invariants (recompute dominators, place phis, rename variables). - CfgModifying, -} - -/// An SSA transformation pass that operates on SSA form. -/// -/// All passes must be thread-safe (Send + Sync) to allow parallel execution. -/// Passes receive mutable access to the SSA function and shared access to -/// the analysis context. -/// -/// # Pipeline Integration -/// -/// Passes don't declare their own priority or triggers. Instead, the scheduler -/// runs passes in a fixed pipeline order based on a canonical optimization -/// sequence: -/// -/// 1. **Normalize**: ADCE, GVN, constant folding (loop until stable) -/// 2. **Opaque predicates**: Range analysis, predicate removal -/// 3. **CFG recovery**: Structuring, loop identification -/// 4. **Unflattening**: Control-flow unflattening -/// 5. **Proxy inlining**: Delegate/proxy method inlining -/// 6. **Decryption**: String and constant decryption -/// 7. **Devirtualization**: VM handler recovery (if present) -/// 8. **Cleanup**: Final DCE, GVN, small function inlining -/// -/// # Assembly Access -/// -/// Passes that need access to the assembly (e.g., for emulation) receive it -/// as a parameter. The assembly flows linearly through the pipeline with clear -/// ownership semantics - it is NOT stored in the context. -pub trait SsaPass: Send + Sync { - /// Unique name for logging and debugging. - fn name(&self) -> &'static str; - - /// Should this pass run on a specific method? - /// - /// Called before `run_on_method`. Override to skip methods that - /// don't need this pass (e.g., already processed, too simple). - /// - /// NOTE: Dead method skipping is NOT done here. Dead method detection - /// can be inaccurate for obfuscated code (e.g., CFF hides call sites). - /// All methods with SSA are processed; dead method filtering is handled - /// during code generation. - fn should_run(&self, _method_token: Token, _ctx: &CompilerContext) -> bool { - true - } - - /// Run the pass on a single method's SSA. - /// - /// This is the main entry point for per-method passes. - /// Returns `true` if any changes were made, `false` otherwise. - /// Events should be recorded directly to `ctx.events`. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to transform. - /// * `method_token` - The metadata token of the method. - /// * `ctx` - The compiler context (thread-safe, uses shared reference). - /// * `assembly` - Shared reference to the assembly (for emulation, lookups, etc.). - /// - /// # Errors - /// - /// Returns an error if the pass fails to process the method. - fn run_on_method( - &self, - ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - assembly: &CilObject, - ) -> Result; - - /// Run on the entire assembly (for interprocedural passes). - /// - /// Override this for passes that need to see all methods at once, - /// like dead method detection or whole-program constant propagation. - /// Returns `true` if any changes were made, `false` otherwise. - /// Events should be recorded directly to `ctx.events`. - /// - /// # Arguments - /// - /// * `ctx` - The compiler context (thread-safe, uses shared reference). - /// * `assembly` - Shared reference to the assembly. - /// - /// # Errors - /// - /// Returns an error if the pass fails to process the assembly. - fn run_global(&self, _ctx: &CompilerContext, _assembly: &CilObject) -> Result { - Ok(false) - } - - /// Does this pass operate globally (across all methods)? - /// - /// Global passes have their `run_global` called instead of - /// iterating over methods with `run_on_method`. - fn is_global(&self) -> bool { - false - } - - /// Called once before the pass runs in a phase. - /// - /// Use this to initialize pass-specific state or caches. - /// - /// # Errors - /// - /// Returns an error if initialization fails. - fn initialize(&mut self, _ctx: &CompilerContext) -> Result<()> { - Ok(()) - } - - /// Called once after the pass completes in a phase. - /// - /// Use this to clean up pass-specific state. - /// - /// # Errors - /// - /// Returns an error if finalization fails. - fn finalize(&mut self, _ctx: &CompilerContext) -> Result<()> { - Ok(()) - } - - /// Declares the extent of modifications this pass makes to the SSA function. - /// - /// The scheduler uses this to select the appropriate repair strategy: - /// - /// - [`ModificationScope::UsesOnly`] — No repair needed - /// - [`ModificationScope::InstructionsOnly`] — Lightweight [`SsaFunction::repair_ssa`] - /// - [`ModificationScope::CfgModifying`] — Full [`SsaFunction::rebuild_ssa`] - /// - /// The default is `CfgModifying` (conservative). Override this to declare - /// a tighter scope for passes that don't modify the CFG. - fn modification_scope(&self) -> ModificationScope { - ModificationScope::CfgModifying - } - - /// Get a description of what this pass does. - fn description(&self) -> &'static str { - "No description available" - } - - /// Capabilities this pass provides after successful execution. - /// - /// The scheduler uses this to determine which passes can run after this one. - /// Passes that don't override this return an empty slice and are scheduled - /// based on their fallback phase. - fn provides(&self) -> &[PassCapability] { - &[] - } - - /// Capabilities this pass requires before it can run. - /// - /// The scheduler ensures all providers of required capabilities are - /// scheduled in earlier layers. If no provider is registered for a - /// required capability, the requirement is ignored (allows the pass - /// to run at its fallback layer). - fn requires(&self) -> &[PassCapability] { - &[] - } - - /// Whether this pass reads other methods' SSA during `run_on_method`. - /// - /// Passes like inlining and proxy devirtualization look up callee SSA - /// via `ctx.with_ssa()`. During parallel execution, the scheduler must - /// keep each method's SSA visible in the DashMap so other threads can - /// read it. When this returns `true`, the scheduler clones the SSA - /// before processing instead of removing it, ensuring concurrent - /// visibility. Passes that only modify their own method should return - /// `false` (the default) to avoid the clone overhead. - fn reads_peer_ssa(&self) -> bool { - false - } - - /// Whether this pass requires a full scan of all methods every iteration. - /// - /// If `true`, the scheduler calls `run_on_method` for every method with SSA, - /// regardless of dirty tracking state. If `false` (default), the scheduler - /// only processes methods in the dirty set. - /// - /// Most passes operate on individual methods independently and should use - /// the default. Only passes that read other methods' SSA or need whole-program - /// visibility should return `true`. - fn requires_full_scan(&self) -> bool { - false - } -} +/// Backwards-compatible alias of [`CilCapability`]. +pub type PassCapability = CilCapability; diff --git a/dotscope/src/compiler/passes/algebraic.rs b/dotscope/src/compiler/passes/algebraic.rs deleted file mode 100644 index 095e27a4..00000000 --- a/dotscope/src/compiler/passes/algebraic.rs +++ /dev/null @@ -1,310 +0,0 @@ -//! Algebraic simplifications pass. -//! -//! This pass transforms algebraic identities into simpler forms: -//! -//! ## Self-canceling operations -//! - `x xor x` → `0` -//! - `x sub x` → `0` -//! -//! ## Idempotent operations -//! - `x or x` → `x` -//! - `x and x` → `x` -//! -//! ## Identity operations (with constant 0) -//! - `x add 0` / `0 add x` → `x` -//! - `x sub 0` → `x` -//! - `x xor 0` / `0 xor x` → `x` -//! - `x or 0` / `0 or x` → `x` -//! -//! ## Absorbing operations (with constant 0) -//! - `x mul 0` / `0 mul x` → `0` -//! - `x and 0` / `0 and x` → `0` -//! -//! ## Identity operations (with constant 1) -//! - `x mul 1` / `1 mul x` → `x` -//! -//! ## All-bits-set identity (with constant -1) -//! - `x and -1` / `-1 and x` → `x` -//! - `x or -1` / `-1 or x` → `-1` -//! -//! These simplifications are essential for deobfuscation because obfuscators -//! often insert redundant operations like `x xor x xor y` to compute `y`. - -use std::collections::BTreeMap; - -use crate::{ - analysis::{simplify_op, ConstValue, SimplifyResult, SsaFunction, SsaOp, SsaVarId}, - compiler::{ - pass::{ModificationScope, SsaPass}, - CompilerContext, EventKind, EventLog, - }, - metadata::token::Token, - CilObject, Result, -}; - -/// Algebraic simplifications pass that transforms redundant operations. -/// -/// This pass identifies patterns like `x xor x`, `x or x`, and operations -/// with identity elements (0 for add/xor/or, 1 for mul) and simplifies them. -pub struct AlgebraicSimplificationPass; - -impl Default for AlgebraicSimplificationPass { - fn default() -> Self { - Self::new() - } -} - -/// The type of simplification applied. -#[derive(Debug, Clone)] -enum Simplification { - /// Replace with a constant value - Constant(ConstValue), - /// Replace with a copy from another variable - Copy(SsaVarId), -} - -/// Information about a simplification candidate. -#[derive(Debug)] -struct SimplificationCandidate { - /// Block index - block_idx: usize, - /// Instruction index within block - instr_idx: usize, - /// The destination variable - dest: SsaVarId, - /// The simplification to apply - simplification: Simplification, - /// Description for logging - description: String, -} - -impl AlgebraicSimplificationPass { - /// Creates a new algebraic simplification pass. - #[must_use] - pub fn new() -> Self { - Self - } - - /// Identifies simplification candidates in the SSA function. - fn find_candidates( - ssa: &SsaFunction, - constants: &BTreeMap, - ) -> Vec { - let mut candidates = Vec::new(); - - for (block_idx, instr_idx, instr) in ssa.iter_instructions() { - let op = instr.op(); - if let Some(candidate) = Self::check_simplification(op, block_idx, instr_idx, constants) - { - candidates.push(candidate); - } - } - - candidates - } - - /// Checks if an operation can be algebraically simplified. - fn check_simplification( - op: &SsaOp, - block_idx: usize, - instr_idx: usize, - constants: &BTreeMap, - ) -> Option { - let dest = op.dest()?; - match simplify_op(op, constants) { - SimplifyResult::Constant(value) => Some(SimplificationCandidate { - block_idx, - instr_idx, - dest, - simplification: Simplification::Constant(value), - description: "algebraic → const".to_string(), - }), - SimplifyResult::Copy(src) => Some(SimplificationCandidate { - block_idx, - instr_idx, - dest, - simplification: Simplification::Copy(src), - description: "algebraic → copy".to_string(), - }), - SimplifyResult::None => None, - } - } - - /// Applies the simplifications to the SSA function. - fn apply_simplifications( - ssa: &mut SsaFunction, - candidates: Vec, - method_token: Token, - changes: &mut EventLog, - ) { - for candidate in candidates { - if let Some(block) = ssa.block_mut(candidate.block_idx) { - let Some(instr) = block.instructions_mut().get_mut(candidate.instr_idx) else { - continue; - }; - let new_op = match candidate.simplification { - Simplification::Constant(value) => SsaOp::Const { - dest: candidate.dest, - value, - }, - Simplification::Copy(src) => SsaOp::Copy { - dest: candidate.dest, - src, - }, - }; - instr.set_op(new_op); - changes - .record(EventKind::ConstantFolded) - .at(method_token, candidate.instr_idx) - .message(&candidate.description); - } - } - } -} - -impl SsaPass for AlgebraicSimplificationPass { - fn name(&self) -> &'static str { - "algebraic-simplification" - } - - fn description(&self) -> &'static str { - "Simplify algebraic identities (x xor x = 0, x or x = x, etc.)" - } - - fn modification_scope(&self) -> ModificationScope { - ModificationScope::InstructionsOnly - } - - fn run_on_method( - &self, - ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - _assembly: &CilObject, - ) -> Result { - let mut changes = EventLog::new(); - - // Find all constant definitions - let constants = ssa.find_constants(); - - // Find simplification candidates - let candidates = Self::find_candidates(ssa, &constants); - - // Apply simplifications - Self::apply_simplifications(ssa, candidates, method_token, &mut changes); - - let changed = !changes.is_empty(); - if changed { - ctx.events.merge(&changes); - } - Ok(changed) - } -} - -#[cfg(test)] -mod tests { - use std::collections::BTreeMap; - - use crate::{ - analysis::{ConstValue, SsaOp, SsaVarId}, - compiler::passes::algebraic::{AlgebraicSimplificationPass, Simplification}, - }; - - #[test] - fn test_div_by_one() { - let left = SsaVarId::from_index(0); - let right = SsaVarId::from_index(1); - let dest = SsaVarId::from_index(2); - let constants: BTreeMap = [(right, ConstValue::I32(1))].into(); - let op = SsaOp::Div { - dest, - left, - right, - unsigned: false, - }; - let result = AlgebraicSimplificationPass::check_simplification(&op, 0, 0, &constants); - assert!(result.is_some()); - let candidate = result.unwrap(); - assert!(matches!(candidate.simplification, Simplification::Copy(v) if v == left)); - } - - #[test] - fn test_rem_by_one() { - let left = SsaVarId::from_index(0); - let right = SsaVarId::from_index(1); - let dest = SsaVarId::from_index(2); - let constants: BTreeMap = [(right, ConstValue::I32(1))].into(); - let op = SsaOp::Rem { - dest, - left, - right, - unsigned: false, - }; - let result = AlgebraicSimplificationPass::check_simplification(&op, 0, 0, &constants); - assert!(result.is_some()); - let candidate = result.unwrap(); - assert!(matches!( - candidate.simplification, - Simplification::Constant(ConstValue::I32(0)) - )); - } - - #[test] - fn test_ceq_same_var() { - let x = SsaVarId::from_index(0); - let dest = SsaVarId::from_index(1); - let constants: BTreeMap = BTreeMap::new(); - let op = SsaOp::Ceq { - dest, - left: x, - right: x, - }; - let result = AlgebraicSimplificationPass::check_simplification(&op, 0, 0, &constants); - assert!(result.is_some()); - let candidate = result.unwrap(); - assert!(matches!( - candidate.simplification, - Simplification::Constant(ConstValue::I32(1)) - )); - } - - #[test] - fn test_clt_same_var() { - let x = SsaVarId::from_index(0); - let dest = SsaVarId::from_index(1); - let constants: BTreeMap = BTreeMap::new(); - let op = SsaOp::Clt { - dest, - left: x, - right: x, - unsigned: false, - }; - let result = AlgebraicSimplificationPass::check_simplification(&op, 0, 0, &constants); - assert!(result.is_some()); - let candidate = result.unwrap(); - assert!(matches!( - candidate.simplification, - Simplification::Constant(ConstValue::I32(0)) - )); - } - - #[test] - fn test_cgt_same_var() { - let x = SsaVarId::from_index(0); - let dest = SsaVarId::from_index(1); - let constants: BTreeMap = BTreeMap::new(); - let op = SsaOp::Cgt { - dest, - left: x, - right: x, - unsigned: false, - }; - let result = AlgebraicSimplificationPass::check_simplification(&op, 0, 0, &constants); - assert!(result.is_some()); - let candidate = result.unwrap(); - assert!(matches!( - candidate.simplification, - Simplification::Constant(ConstValue::I32(0)) - )); - } -} diff --git a/dotscope/src/compiler/passes/blockmerge.rs b/dotscope/src/compiler/passes/blockmerge.rs deleted file mode 100644 index a1c4f248..00000000 --- a/dotscope/src/compiler/passes/blockmerge.rs +++ /dev/null @@ -1,908 +0,0 @@ -//! Block merging pass for simplifying control flow. -//! -//! Two optimizations: -//! -//! 1. **Trampoline elimination** — removes blocks containing only an unconditional -//! jump by redirecting predecessors to the ultimate target. -//! -//! 2. **Block coalescing** — merges a block into its sole predecessor when the -//! predecessor's only successor is that block. This eliminates unnecessary -//! block boundaries in straight-line code (common after CFF reconstruction). -//! Phi nodes in the successor are converted to Copy instructions since they -//! have exactly one incoming edge. -//! -//! # Entry Block Handling -//! -//! The entry block (B0) is handled specially because it has no predecessors to -//! redirect. When B0 is a trampoline, the target block is inlined into B0 -//! (if safe) or the method is marked for code regeneration. This generically -//! handles anti-disassembly patterns where obfuscators inject junk bytes after -//! an unconditional branch at method start (e.g., `br.s +N` followed by garbage). -//! The SSA builder never decodes the unreachable junk, so regenerating the IL -//! from SSA produces clean output. -//! -//! # Trampoline Elimination Example -//! -//! Before: -//! ```text -//! B0: jump B1 -//! B1: jump B4 -//! B4: ... actual code ... -//! ``` -//! -//! After: -//! ```text -//! B0: jump B4 -//! B4: ... actual code ... -//! ``` -//! -//! # Block Coalescing Example -//! -//! Before (after CFF reconstruction): -//! ```text -//! B5: v1 = call Foo() -//! jump B6 -//! B6: callvirt Bar(v1) -//! jump B7 -//! ``` -//! -//! After: -//! ```text -//! B5: v1 = call Foo() -//! callvirt Bar(v1) -//! jump B7 -//! ``` - -use std::collections::BTreeMap; - -use crate::{ - analysis::{PhiOperand, SsaFunction, SsaInstruction, SsaOp}, - compiler::{pass::SsaPass, passes::utils::resolve_chain, CompilerContext, EventKind, EventLog}, - metadata::token::Token, - utils::BitSet, - CilObject, Result, -}; - -/// Block merging pass for eliminating trampoline blocks. -/// -/// A trampoline block is a block that: -/// - Has no phi nodes -/// - Contains only a single unconditional jump instruction -/// -/// This pass redirects all edges that go through trampolines directly to their -/// ultimate targets, simplifying the control flow graph. -pub struct BlockMergingPass { - /// Maximum fixpoint iterations before stopping. - max_iterations: usize, -} - -impl BlockMergingPass { - /// Creates a new block merging pass. - /// - /// # Arguments - /// - /// * `max_iterations` - Maximum fixpoint iterations for both trampoline - /// elimination and block coalescing phases. The default config value is 50. - #[must_use] - pub fn new(max_iterations: usize) -> Self { - Self { max_iterations } - } - - /// Redirects all jumps that go to trampolines to their ultimate targets. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to modify. - /// * `trampolines` - Map of trampoline blocks to their direct targets. - /// * `method_token` - Token for change tracking. - /// * `changes` - Event log for recording changes. - /// - /// # Returns - /// - /// The number of redirections performed. - fn redirect_to_ultimate_targets( - ssa: &mut SsaFunction, - trampolines: &BTreeMap, - method_token: Token, - changes: &mut EventLog, - ) -> usize { - if trampolines.is_empty() { - return 0; - } - - // Precompute ultimate targets for all trampolines using shared utility - let ultimate_targets: BTreeMap = trampolines - .keys() - .map(|&t| (t, resolve_chain(trampolines, t))) - .collect(); - - // Collect which blocks redirect through which trampolines, so we can - // update phi operands at the ultimate target afterwards. - // Maps: (trampoline, ultimate_target) → [predecessor blocks that were redirected] - let mut redirected_preds: BTreeMap<(usize, usize), Vec> = BTreeMap::new(); - - let mut redirected: usize = 0; - - // Update all branch targets in all blocks - for block_idx in 0..ssa.block_count() { - if let Some(block) = ssa.block_mut(block_idx) { - for instr in block.instructions_mut() { - let op = instr.op_mut(); - let old_targets = op.successors(); - - // Redirect each trampoline to its ultimate target - let mut changed = false; - for (&trampoline, &ultimate) in &ultimate_targets { - if op.redirect_target(trampoline, ultimate) { - redirected_preds - .entry((trampoline, ultimate)) - .or_default() - .push(block_idx); - changed = true; - } - } - - if changed { - let new_targets = op.successors(); - changes - .record(EventKind::BranchSimplified) - .at(method_token, block_idx) - .message(format!( - "redirected through trampoline: {old_targets:?} -> {new_targets:?}" - )); - redirected = redirected.saturating_add(1); - } - } - } - } - - // Update phi operands at ultimate target blocks. When B_src was redirected - // from B_trampoline → B_target, phi operands at B_target that referenced - // B_trampoline must be updated to reference B_src instead. If multiple - // blocks were redirected through the same trampoline, the single phi - // operand is duplicated for each new predecessor. - for (&(trampoline, ultimate), preds) in &redirected_preds { - if let Some(target_block) = ssa.block_mut(ultimate) { - for phi in target_block.phi_nodes_mut() { - // Find the operand that came from the trampoline - let trampoline_operand = phi - .operands() - .iter() - .find(|op| op.predecessor() == trampoline) - .map(|op| op.value()); - - if let Some(value) = trampoline_operand { - // Update the existing operand to point to the first new predecessor - if let Some(&first_pred) = preds.first() { - for operand in phi.operands_mut() { - if operand.predecessor() == trampoline { - operand.set_predecessor(first_pred); - break; - } - } - } - - // Add duplicate operands for any additional predecessors - // (same value, different predecessor) - for &pred in preds.iter().skip(1) { - phi.add_operand(PhiOperand::new(value, pred)); - } - } - } - } - } - - redirected - } - - /// Clears trampoline blocks that are no longer referenced. - /// - /// After redirecting all edges away from trampolines, they become unreachable - /// and can be cleared. This is done by the DCE pass, but we record the event. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to modify. - /// * `trampolines` - The trampoline blocks to clear. - /// * `method_token` - Token for change tracking. - /// * `changes` - Event log for recording changes. - /// - /// # Returns - /// - /// The number of blocks cleared. - fn clear_trampolines( - ssa: &mut SsaFunction, - trampolines: &BTreeMap, - method_token: Token, - changes: &mut EventLog, - ) -> usize { - let mut cleared: usize = 0; - - for &block_idx in trampolines.keys() { - if let Some(block) = ssa.block_mut(block_idx) { - if !block.instructions().is_empty() { - block.instructions_mut().clear(); - changes - .record(EventKind::BlockRemoved) - .at(method_token, block_idx) - .message(format!("cleared trampoline block B{block_idx}")); - cleared = cleared.saturating_add(1); - } - } - } - - cleared - } - - /// Runs a single iteration of block merging. - /// - /// # Returns - /// - /// The number of changes made (redirections + cleared blocks). - fn run_iteration(ssa: &mut SsaFunction, method_token: Token, changes: &mut EventLog) -> usize { - let trampolines = ssa.find_trampoline_blocks(true); - - if trampolines.is_empty() { - return 0; - } - - let redirected = - Self::redirect_to_ultimate_targets(ssa, &trampolines, method_token, changes); - let cleared = Self::clear_trampolines(ssa, &trampolines, method_token, changes); - - redirected.saturating_add(cleared) - } - - /// Merges blocks connected by a single edge. - /// - /// When Block A's only successor is Block B (via `Jump`) and Block B's only - /// predecessor is Block A, the two blocks can be merged: Block A's terminator - /// is replaced by Block B's instructions. Any phi nodes in Block B are - /// converted to Copy instructions (they have exactly one incoming edge). - /// - /// Blocks involved in exception handler boundaries are excluded because - /// merging them would break the handler region structure. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to modify. - /// * `method_token` - Token of the method being processed (for event logging). - /// * `changes` - Event log to record merge operations. - /// * `max_iterations` - Maximum fixpoint iterations for the merge loop. - fn coalesce_blocks( - ssa: &mut SsaFunction, - method_token: Token, - changes: &mut EventLog, - max_iterations: usize, - ) -> usize { - let mut merged: usize = 0; - - // Collect exception handler boundary blocks. - // - // Region *start* blocks (try_start, handler_start, filter_start) must not - // be used as the MERGE TARGET because absorbing a predecessor outside the - // region would pull non-region code into the region. - // - // Region *end* blocks (try_end, handler_end) must not be used as the MERGE - // SOURCE because absorbing a successor outside the region would extend the - // region past its intended boundary. - // - // Merging within a region is safe: if A is a try_start and B is the next - // block inside the same try body, merging B into A keeps the try region - // starting at A. - let mut no_merge_into = BitSet::new(ssa.block_count()); - let mut no_merge_from = BitSet::new(ssa.block_count()); - for handler in ssa.exception_handlers() { - if let Some(b) = handler.try_start_block { - no_merge_into.insert(b); - } - if let Some(b) = handler.try_end_block { - no_merge_from.insert(b); - } - if let Some(b) = handler.handler_start_block { - no_merge_into.insert(b); - } - if let Some(b) = handler.handler_end_block { - no_merge_from.insert(b); - } - if let Some(b) = handler.filter_start_block { - no_merge_into.insert(b); - } - } - - // Iterate until fixed point. - for _ in 0..max_iterations { - let mut iteration_merges: usize = 0; - - // Build predecessor counts for all blocks. - let block_count = ssa.block_count(); - let mut pred_counts: Vec = vec![0; block_count]; - let mut pred_of: Vec> = vec![None; block_count]; - for idx in 0..block_count { - let successors = ssa - .block(idx) - .and_then(|b| b.terminator_op()) - .map(SsaOp::successors) - .unwrap_or_default(); - for succ in successors { - if succ < block_count { - if let Some(c) = pred_counts.get_mut(succ) { - *c = c.saturating_add(1); - } - if let Some(p) = pred_of.get_mut(succ) { - *p = Some(idx); - } - } - } - } - // Entry block has an implicit edge. - if let Some(c) = pred_counts.get_mut(0) { - *c = c.saturating_add(1); - } - - // Find mergeable pairs: A -> B where A's terminator is Jump(B), - // B has exactly 1 predecessor, and neither is a handler boundary. - let mut pairs: Vec<(usize, usize)> = Vec::new(); - let mut consumed = BitSet::new(block_count); - for a_idx in 0..block_count { - if consumed.contains(a_idx) { - continue; - } - let b_idx = match ssa.block(a_idx).and_then(|b| b.terminator_op()) { - Some(SsaOp::Jump { target }) => *target, - _ => continue, - }; - if b_idx >= block_count || b_idx == a_idx { - continue; - } - if pred_counts.get(b_idx).copied().unwrap_or(0) != 1 { - continue; - } - if no_merge_from.contains(a_idx) || no_merge_into.contains(b_idx) { - continue; - } - // B must have instructions (not already cleared). - let b_empty = ssa.block(b_idx).is_none_or(|b| b.instructions().is_empty()); - if b_empty { - continue; - } - pairs.push((a_idx, b_idx)); - consumed.insert(a_idx); - consumed.insert(b_idx); - } - - for (a_idx, b_idx) in pairs { - // Convert B's phi nodes to Copy instructions. - let phi_copies: Vec = ssa - .block(b_idx) - .map(|b| { - b.phi_nodes() - .iter() - .filter_map(|phi| { - // Single predecessor → exactly one operand. - let operand = phi.operands().first()?; - let dest = phi.result(); - let src = operand.value(); - if dest == src { - return None; // Self-copy, skip. - } - Some(SsaInstruction::synthetic(SsaOp::Copy { dest, src })) - }) - .collect() - }) - .unwrap_or_default(); - - // Take B's instructions. - let b_instrs: Vec = ssa - .block(b_idx) - .map(|b| b.instructions().to_vec()) - .unwrap_or_default(); - - // Remove A's terminator (the Jump) and append phi copies + B's instructions. - if let Some(a_block) = ssa.block_mut(a_idx) { - // Pop the Jump terminator. - let instrs = a_block.instructions_mut(); - if instrs - .last() - .is_some_and(|i| matches!(i.op(), SsaOp::Jump { .. })) - { - instrs.pop(); - } - // Append phi copies then B's instructions. - instrs.extend(phi_copies); - instrs.extend(b_instrs); - } - - // Update B's internal self-references to point to A. - if let Some(a_block) = ssa.block_mut(a_idx) { - for instr in a_block.instructions_mut() { - instr.op_mut().redirect_target(b_idx, a_idx); - } - } - - // Clear B. - if let Some(b_block) = ssa.block_mut(b_idx) { - b_block.phi_nodes_mut().clear(); - b_block.instructions_mut().clear(); - } - - // Redirect any other block that referenced B to now reference A. - // This handles the case where B had successors that now become A's - // successors — their phi operands need predecessor updates. - for phi_block_idx in 0..block_count { - if phi_block_idx == a_idx || phi_block_idx == b_idx { - continue; - } - if let Some(block) = ssa.block_mut(phi_block_idx) { - for phi in block.phi_nodes_mut() { - for operand in phi.operands_mut() { - if operand.predecessor() == b_idx { - *operand = PhiOperand::new(operand.value(), a_idx); - } - } - } - } - } - - changes - .record(EventKind::BlockRemoved) - .at(method_token, b_idx) - .message(format!("coalesced B{b_idx} into B{a_idx}")); - - iteration_merges = iteration_merges.saturating_add(1); - } - - merged = merged.saturating_add(iteration_merges); - if iteration_merges == 0 { - break; - } - } - - merged - } - - /// Simplifies an entry block that is just a trampoline (unconditional jump). - /// - /// Non-entry trampolines are handled by `run_iteration` which redirects - /// predecessors and clears the block. The entry block (B0) has no - /// predecessors, so that approach doesn't work — there's nothing to - /// redirect. - /// - /// Instead, when B0 is a trampoline to B_target: - /// - /// - If B_target has exactly one predecessor (B0) and no phi nodes, we - /// inline B_target's instructions into B0 and clear B_target. - /// - Otherwise, we just mark the method as modified so codegen regenerates - /// clean IL without the original junk bytes (e.g., anti-disassembly - /// garbage injected by obfuscators like BitMono's junk prefix). - fn simplify_entry_trampoline( - ssa: &mut SsaFunction, - method_token: Token, - changes: &mut EventLog, - ) { - // Check if B0 is a trampoline - let target = match ssa.block(0).and_then(|b| b.is_trampoline()) { - Some(t) => t, - None => return, - }; - - let preds = ssa.block_predecessors(target); - let target_has_phis = ssa.block(target).is_none_or(|b| !b.phi_nodes().is_empty()); - - if preds.len() == 1 && preds.first().copied() == Some(0) && !target_has_phis { - // Safe to inline: B_target's only external predecessor is B0 and it - // has no phis. Move B_target's instructions into B0. - let target_instrs = ssa - .block(target) - .map(|b| b.instructions().to_vec()) - .unwrap_or_default(); - - if let Some(entry) = ssa.block_mut(0) { - entry.instructions_mut().clear(); - *entry.instructions_mut() = target_instrs; - - // Redirect any self-references: if B_target had a back-edge to - // itself (e.g., a loop), those now need to point to B0 since - // B_target's content lives in B0. - for instr in entry.instructions_mut() { - instr.op_mut().redirect_target(target, 0); - } - } - - if let Some(target_block) = ssa.block_mut(target) { - target_block.instructions_mut().clear(); - } - - changes - .record(EventKind::BlockRemoved) - .at(method_token, 0) - .message(format!( - "inlined entry trampoline: B0 jump to B{target} merged into B0" - )); - } else { - // Can't inline (multiple predecessors or phis), but mark as modified - // so codegen regenerates clean IL without original junk bytes. - changes - .record(EventKind::BranchSimplified) - .at(method_token, 0) - .message(format!( - "entry block is trampoline to B{target} (regenerating clean IL)" - )); - } - } -} - -impl SsaPass for BlockMergingPass { - fn name(&self) -> &'static str { - "block-merging" - } - - fn description(&self) -> &'static str { - "Eliminates trampoline blocks and coalesces single-edge block pairs" - } - - fn run_on_method( - &self, - ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - _assembly: &CilObject, - ) -> Result { - let mut changes = EventLog::new(); - - // Phase 1: Eliminate trampoline blocks (jump-only blocks). - // Trampoline elimination updates phi operands at target blocks to - // reference the new predecessors, maintaining SSA invariants. - for _ in 0..self.max_iterations { - let iteration_changes = Self::run_iteration(ssa, method_token, &mut changes); - - if iteration_changes == 0 { - break; - } - } - - // Phase 2: Handle entry block trampoline — B0 has no predecessors so the - // redirect-and-clear approach above can't handle it. Instead, inline - // the target block when safe, or just mark for regeneration. - Self::simplify_entry_trampoline(ssa, method_token, &mut changes); - - // Phase 3: Coalesce non-trivial blocks connected by a single edge. - // After trampoline elimination and CFF reconstruction, there may be - // blocks with actual instructions connected by unconditional jumps - // where the successor has a single predecessor. Merging these produces - // larger blocks, reducing cross-block stores in the codegen. - Self::coalesce_blocks(ssa, method_token, &mut changes, self.max_iterations); - - let changed = !changes.is_empty(); - if changed { - ctx.events.merge(&changes); - } - Ok(changed) - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use crate::{ - analysis::{CallGraph, SsaFunctionBuilder, SsaOp}, - compiler::{passes::blockmerge::BlockMergingPass, CompilerContext, SsaPass}, - metadata::token::Token, - test::helpers::test_assembly_arc, - }; - - #[test] - fn test_redirect_simple() { - let mut ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - // B0: entry, jump to B1 (trampoline) - f.block(0, |b| b.jump(1)); - // B1: trampoline to B2 - f.block(1, |b| b.jump(2)); - // B2: actual code - f.block(2, |b| b.ret()); - }) - .unwrap(); - - let pass = BlockMergingPass::new(50); - let ctx = CompilerContext::new(Arc::new(CallGraph::new())); - let assembly = test_assembly_arc(); - - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &assembly) - .unwrap(); - assert!(changed); - - // B0 should now jump directly to B2 - if let Some(block) = ssa.block(0) { - if let Some(instr) = block.instructions().first() { - if let SsaOp::Jump { target } = instr.op() { - assert_eq!(*target, 2); - } - } - } - - // B1 should be cleared (empty) - if let Some(block) = ssa.block(1) { - assert!( - block.instructions().is_empty(), - "B1 should be cleared, but has {} instructions", - block.instructions().len() - ); - } - } - - #[test] - fn test_chain_of_trampolines() { - // B0 -> B1 -> B2 -> B3 (actual code) - let mut ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| b.jump(1)); - f.block(1, |b| b.jump(2)); - f.block(2, |b| b.jump(3)); - f.block(3, |b| b.ret()); - }) - .unwrap(); - - let pass = BlockMergingPass::new(50); - let ctx = CompilerContext::new(Arc::new(CallGraph::new())); - let assembly = test_assembly_arc(); - - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &assembly) - .unwrap(); - assert!(changed); - - // B0 should jump directly to B3 (following the chain) - if let Some(block) = ssa.block(0) { - if let Some(instr) = block.instructions().first() { - if let SsaOp::Jump { target } = instr.op() { - assert_eq!(*target, 3, "B0 should jump to B3, not B{}", *target); - } - } - } - - // B1 and B2 should be cleared - for i in 1..=2 { - if let Some(block) = ssa.block(i) { - assert!(block.instructions().is_empty(), "B{} should be cleared", i); - } - } - } - - #[test] - fn test_entry_trampoline_inlined() { - // B0 is a trampoline to B1, B1 has only B0 as predecessor. - // B0's content should be replaced with B1's instructions. - let mut ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| b.jump(1)); - f.block(1, |b| b.ret()); - }) - .unwrap(); - - let pass = BlockMergingPass::new(50); - let ctx = CompilerContext::new(Arc::new(CallGraph::new())); - let assembly = test_assembly_arc(); - - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &assembly) - .unwrap(); - assert!(changed, "entry trampoline should trigger a change"); - - // B0 should now contain B1's ret instruction - let block0 = ssa.block(0).unwrap(); - assert_eq!(block0.instruction_count(), 1); - assert!( - matches!(block0.instructions()[0].op(), SsaOp::Return { .. }), - "B0 should contain ret after inlining, got {:?}", - block0.instructions()[0].op() - ); - - // B1 should be cleared - let block1 = ssa.block(1).unwrap(); - assert!( - block1.instructions().is_empty(), - "B1 should be cleared after inlining" - ); - } - - #[test] - fn test_entry_trampoline_with_loop() { - // B0: jump B1, B1: branch(cond, B2, B3), B2: jump B1 (loop), B3: ret. - // The non-entry pass redirects B1's branch from B2 to B1 (self-loop), - // then the entry trampoline logic inlines B1 into B0. - // The self-reference to B1 should be redirected to B0. - let mut ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| b.jump(1)); - f.block(1, |b| { - let cond = b.const_i32(1); - b.branch(cond, 2, 3); - }); - f.block(2, |b| b.jump(1)); // back-edge to B1 - f.block(3, |b| b.ret()); - }) - .unwrap(); - - let pass = BlockMergingPass::new(50); - let ctx = CompilerContext::new(Arc::new(CallGraph::new())); - let assembly = test_assembly_arc(); - - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &assembly) - .unwrap(); - assert!(changed); - - // B0 should now contain B1's code (const + branch) with the - // self-reference redirected from B1 to B0 - let block0 = ssa.block(0).unwrap(); - assert_eq!( - block0.instruction_count(), - 2, - "B0 should have const + branch" - ); - if let SsaOp::Branch { - true_target, - false_target, - .. - } = block0.instructions()[1].op() - { - assert_eq!(*true_target, 0, "self-loop should point to B0 after inline"); - assert_eq!(*false_target, 3, "exit should still point to B3"); - } else { - panic!("expected Branch in B0"); - } - } - - #[test] - fn test_entry_trampoline_not_inlined_multi_pred() { - // B0: jump B1, B1: code, B2: jump B1 (B1 has preds B0 AND B2). - // Can't inline B1 into B0, but method should be marked as changed. - let mut ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| b.jump(1)); - f.block(1, |b| { - let cond = b.const_i32(1); - b.branch(cond, 2, 3); - }); - f.block(2, |b| { - // Not a trampoline — has nop + jump (2 instructions) - b.nop(); - b.jump(1); - }); - f.block(3, |b| b.ret()); - }) - .unwrap(); - - let pass = BlockMergingPass::new(50); - let ctx = CompilerContext::new(Arc::new(CallGraph::new())); - let assembly = test_assembly_arc(); - - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &assembly) - .unwrap(); - assert!( - changed, - "entry trampoline should mark as changed even when target can't be inlined" - ); - - // B0 should still be a jump (B1 has 2 predecessors: B0 and B2) - let block0 = ssa.block(0).unwrap(); - assert_eq!(block0.instruction_count(), 1); - assert!( - matches!(block0.instructions()[0].op(), SsaOp::Jump { .. }), - "B0 should remain a jump when target has multiple external predecessors" - ); - } - - #[test] - fn test_no_entry_trampoline() { - // B0 has actual code — not a trampoline. Should report no changes. - let mut ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| b.ret()); - }) - .unwrap(); - - let pass = BlockMergingPass::new(50); - let ctx = CompilerContext::new(Arc::new(CallGraph::new())); - let assembly = test_assembly_arc(); - - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &assembly) - .unwrap(); - assert!(!changed, "non-trampoline entry should report no changes"); - } - - #[test] - fn test_coalesce_single_edge_blocks() { - // B0: const + jump B1, B1: const + jump B2, B2: ret. - // B0→B1 and B1→B2 are single-edge pairs that should be coalesced. - let mut ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let _ = b.const_i32(42); - b.jump(1); - }); - f.block(1, |b| { - let _ = b.const_i32(99); - b.jump(2); - }); - f.block(2, |b| b.ret()); - }) - .unwrap(); - - let pass = BlockMergingPass::new(50); - let ctx = CompilerContext::new(Arc::new(CallGraph::new())); - let assembly = test_assembly_arc(); - - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &assembly) - .unwrap(); - assert!(changed, "block coalescing should trigger changes"); - - // B0 should now contain all instructions: two consts + ret - let block0 = ssa.block(0).unwrap(); - assert!( - block0.instruction_count() >= 3, - "B0 should have at least 3 instructions after coalescing, got {}", - block0.instruction_count() - ); - assert!( - matches!( - block0.instructions().last().map(|i| i.op()), - Some(SsaOp::Return { .. }) - ), - "B0's last instruction should be ret" - ); - - // B1 and B2 should be cleared - for i in 1..=2 { - if let Some(block) = ssa.block(i) { - assert!( - block.instructions().is_empty(), - "B{i} should be cleared after coalescing" - ); - } - } - } - - #[test] - fn test_coalesce_preserves_multi_predecessor_blocks() { - // B0: branch(cond, B1, B2), B1: jump B3, B2: jump B3, B3: ret. - // B3 has two predecessors — should NOT be coalesced. - let mut ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let c = b.const_i32(1); - b.branch(c, 1, 2); - }); - f.block(1, |b| { - let _ = b.const_i32(10); - b.jump(3); - }); - f.block(2, |b| { - let _ = b.const_i32(20); - b.jump(3); - }); - f.block(3, |b| b.ret()); - }) - .unwrap(); - - let pass = BlockMergingPass::new(50); - let ctx = CompilerContext::new(Arc::new(CallGraph::new())); - let assembly = test_assembly_arc(); - - pass.run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &assembly) - .unwrap(); - - // B3 should still exist with instructions (not merged) - let block3 = ssa.block(3).unwrap(); - assert!( - !block3.instructions().is_empty(), - "B3 should NOT be coalesced (has 2 predecessors)" - ); - } -} diff --git a/dotscope/src/compiler/passes/constants/mod.rs b/dotscope/src/compiler/passes/constants/mod.rs index 77ad63ad..ace94bd9 100644 --- a/dotscope/src/compiler/passes/constants/mod.rs +++ b/dotscope/src/compiler/passes/constants/mod.rs @@ -35,18 +35,20 @@ use std::collections::BTreeMap; +use analyssa::BitSet; + use crate::{ analysis::{ - simplify_op, CmpKind, ConstValue, ConstantPropagation, MethodRef, SccpResult, - SimplifyResult, SsaCfg, SsaEvaluator, SsaFunction, SsaOp, SsaType, SsaVarId, + simplify_op, CilTarget, CmpKind, ConstValue, ConstValueCilExt, ConstantPropagation, + MethodRef, SccpResult, SimplifyResult, SsaCfg, SsaEvaluator, SsaFunction, SsaOp, SsaType, + SsaVarId, }, compiler::{ pass::{ModificationScope, SsaPass}, CompilerContext, EventKind, EventLog, }, - metadata::{token::Token, typesystem::PointerSize}, - utils::BitSet, - CilObject, Result, + metadata::{tables::TableId, token::Token, typesystem::PointerSize}, + CilObject, }; /// Checks whether `token` resolves to a method whose declaring type name contains `type_name`. @@ -397,10 +399,11 @@ impl ConstantPropagationPass { } = instr.op() { if let Some(operand_val) = constants.get(operand) { + let ptr_bytes = ptr_size.bytes() as u32; let result = if *overflow_check { - operand_val.convert_to_checked(target, *unsigned, ptr_size) + operand_val.convert_to_checked(target, *unsigned, ptr_bytes) } else { - operand_val.convert_to(target, *unsigned, ptr_size) + operand_val.convert_to(target, *unsigned, ptr_bytes) }; if let Some(result) = result { new_constants.push((*dest, result, block_idx, instr_idx)); @@ -798,6 +801,7 @@ impl ConstantPropagationPass { left, right, unsigned, + .. } => { let l = constants.get(left)?; let r = constants.get(right)?; @@ -822,6 +826,7 @@ impl ConstantPropagationPass { left, right, unsigned, + .. } => { let l = constants.get(left)?; let r = constants.get(right)?; @@ -846,6 +851,7 @@ impl ConstantPropagationPass { left, right, unsigned, + .. } => { let l = constants.get(left)?; let r = constants.get(right)?; @@ -906,7 +912,7 @@ impl ConstantPropagationPass { }; // Only handle MethodDef tokens (same-assembly methods) - if !callee_token.is_table(crate::metadata::tables::TableId::MethodDef) { + if !callee_token.is_table(TableId::MethodDef) { continue; } @@ -984,6 +990,7 @@ impl ConstantPropagationPass { left, right, unsigned, + .. } = instr.op() { let lval = constants @@ -1392,8 +1399,8 @@ impl ConstantPropagationPass { for (block_idx, instr_idx, instr) in ssa.iter_instructions() { let (dest, operand, is_neg) = match instr.op() { - SsaOp::Neg { dest, operand } => (*dest, *operand, true), - SsaOp::Not { dest, operand } => (*dest, *operand, false), + SsaOp::Neg { dest, operand, .. } => (*dest, *operand, true), + SsaOp::Not { dest, operand, .. } => (*dest, *operand, false), _ => continue, }; @@ -1439,10 +1446,12 @@ impl ConstantPropagationPass { SsaOp::Neg { dest: d, operand: inner, + .. } if is_neg => (*d, *inner), SsaOp::Not { dest: d, operand: inner, + .. } if !is_neg => (*d, *inner), _ => break, }; @@ -1498,11 +1507,13 @@ impl ConstantPropagationPass { instr.set_op(SsaOp::Neg { dest: t.outermost_dest, operand: t.innermost_operand, + flags: None, }); } else { instr.set_op(SsaOp::Not { dest: t.outermost_dest, operand: t.innermost_operand, + flags: None, }); } } @@ -1738,7 +1749,7 @@ impl ConstantPropagationPass { } } -impl SsaPass for ConstantPropagationPass { +impl SsaPass for ConstantPropagationPass { fn name(&self) -> &'static str { "constant-propagation" } @@ -1754,12 +1765,15 @@ impl SsaPass for ConstantPropagationPass { fn run_on_method( &self, ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - assembly: &CilObject, - ) -> Result { + method: &MethodRef, + host: &CompilerContext, + ) -> analyssa::Result { + let assembly = host + .assembly() + .ok_or_else(|| analyssa::Error::new("ConstantPropagationPass requires an assembly"))?; + let method_token = method.0; let mut changes = EventLog::new(); - let ptr_size = PointerSize::from_pe(assembly.file().pe().is_64bit); + let ptr_size = PointerSize::from_is_64bit(assembly.file().pe().is_64bit); // Run constant propagation and transformation let constants = Self::run_constant_propagation( @@ -1768,17 +1782,17 @@ impl SsaPass for ConstantPropagationPass { &mut changes, ptr_size, self.max_iterations, - assembly, + &assembly, ); // Cache the constants we found for other passes for (var, value) in &constants { - ctx.add_known_value(method_token, *var, value.clone()); + host.add_known_value(method_token, *var, value.clone()); } let changed = !changes.is_empty(); if changed { - ctx.events.merge(&changes); + host.events.merge(&changes); } Ok(changed) } diff --git a/dotscope/src/compiler/passes/constants/tests.rs b/dotscope/src/compiler/passes/constants/tests.rs index 8da3f500..a3283010 100644 --- a/dotscope/src/compiler/passes/constants/tests.rs +++ b/dotscope/src/compiler/passes/constants/tests.rs @@ -1,7 +1,10 @@ use std::{collections::BTreeMap, sync::Arc}; use crate::{ - analysis::{CallGraph, ConstValue, MethodRef, SsaFunctionBuilder, SsaOp, SsaType, SsaVarId}, + analysis::{ + CallGraph, ConstValue, ConstValueCilExt, MethodRef, SsaFunctionBuilder, SsaOp, SsaType, + SsaVarId, + }, compiler::{ passes::constants::{AlgebraicResult, ConstantPropagationPass}, CompilerContext, EventLog, SsaPass, @@ -10,10 +13,15 @@ use crate::{ test::helpers::test_assembly_arc, }; -/// Creates a test compiler context. +/// Creates a test compiler context with the standard test assembly +/// pre-attached. The analyssa pass scheduler reaches the assembly through +/// [`CompilerContext::assembly`], so passes that need it will fail +/// without this setup. fn test_context() -> CompilerContext { let call_graph = Arc::new(CallGraph::new()); - CompilerContext::new(call_graph) + let ctx = CompilerContext::new(call_graph); + ctx.set_assembly(test_assembly_arc()); + ctx } #[test] @@ -31,145 +39,142 @@ fn test_pass_default() { #[test] fn test_conv_i32_to_i8() { - let operand = ConstValue::I32(42); + let operand: ConstValue = ConstValue::I32(42); assert_eq!( - operand.convert_to(&SsaType::I8, false, PointerSize::Bit64), + operand.convert_to(&SsaType::I8, false, 8), Some(ConstValue::I8(42)) ); } #[test] fn test_conv_i32_to_i8_truncate() { - let operand = ConstValue::I32(1000); + let operand: ConstValue = ConstValue::I32(1000); // 1000 truncated to i8 is -24 (1000 & 0xFF = 232, as signed = -24) assert_eq!( - operand.convert_to(&SsaType::I8, false, PointerSize::Bit64), + operand.convert_to(&SsaType::I8, false, 8), Some(ConstValue::I8(-24)) ); } #[test] fn test_conv_i32_to_i64() { - let operand = ConstValue::I32(-42); + let operand: ConstValue = ConstValue::I32(-42); assert_eq!( - operand.convert_to(&SsaType::I64, false, PointerSize::Bit64), + operand.convert_to(&SsaType::I64, false, 8), Some(ConstValue::I64(-42)) ); } #[test] fn test_conv_to_bool_nonzero() { - let operand = ConstValue::I32(42); + let operand: ConstValue = ConstValue::I32(42); assert_eq!( - operand.convert_to(&SsaType::Bool, false, PointerSize::Bit64), + operand.convert_to(&SsaType::Bool, false, 8), Some(ConstValue::True) ); } #[test] fn test_conv_to_bool_zero() { - let operand = ConstValue::I32(0); + let operand: ConstValue = ConstValue::I32(0); assert_eq!( - operand.convert_to(&SsaType::Bool, false, PointerSize::Bit64), + operand.convert_to(&SsaType::Bool, false, 8), Some(ConstValue::False) ); } #[test] fn test_conv_to_f32() { - let operand = ConstValue::I32(42); + let operand: ConstValue = ConstValue::I32(42); assert_eq!( - operand.convert_to(&SsaType::F32, false, PointerSize::Bit64), + operand.convert_to(&SsaType::F32, false, 8), Some(ConstValue::F32(42.0)) ); } #[test] fn test_conv_ovf_in_range() { - let operand = ConstValue::I32(100); + let operand: ConstValue = ConstValue::I32(100); assert_eq!( - operand.convert_to_checked(&SsaType::I8, false, PointerSize::Bit64), + operand.convert_to_checked(&SsaType::I8, false, 8), Some(ConstValue::I8(100)) ); } #[test] fn test_conv_ovf_out_of_range() { - let operand = ConstValue::I32(1000); - assert_eq!( - operand.convert_to_checked(&SsaType::I8, false, PointerSize::Bit64), - None - ); // Would overflow + let operand: ConstValue = ConstValue::I32(1000); + assert_eq!(operand.convert_to_checked(&SsaType::I8, false, 8), None); // Would overflow } #[test] fn test_conv_u8() { - let operand = ConstValue::I32(200); + let operand: ConstValue = ConstValue::I32(200); assert_eq!( - operand.convert_to(&SsaType::U8, false, PointerSize::Bit64), + operand.convert_to(&SsaType::U8, false, 8), Some(ConstValue::U8(200)) ); } #[test] fn test_conv_u16() { - let operand = ConstValue::I32(50000); + let operand: ConstValue = ConstValue::I32(50000); assert_eq!( - operand.convert_to(&SsaType::U16, false, PointerSize::Bit64), + operand.convert_to(&SsaType::U16, false, 8), Some(ConstValue::U16(50000)) ); } #[test] fn test_conv_u32() { - let operand = ConstValue::I64(3_000_000_000); + let operand: ConstValue = ConstValue::I64(3_000_000_000); assert_eq!( - operand.convert_to(&SsaType::U32, false, PointerSize::Bit64), + operand.convert_to(&SsaType::U32, false, 8), Some(ConstValue::U32(3_000_000_000)) ); } #[test] fn test_conv_u64() { - let operand = ConstValue::I32(42); + let operand: ConstValue = ConstValue::I32(42); assert_eq!( - operand.convert_to(&SsaType::U64, false, PointerSize::Bit64), + operand.convert_to(&SsaType::U64, false, 8), Some(ConstValue::U64(42)) ); } #[test] fn test_conv_f64() { - let operand = ConstValue::I32(42); + let operand: ConstValue = ConstValue::I32(42); assert_eq!( - operand.convert_to(&SsaType::F64, false, PointerSize::Bit64), + operand.convert_to(&SsaType::F64, false, 8), Some(ConstValue::F64(42.0)) ); } #[test] fn test_conv_native_int() { - let operand = ConstValue::I32(42); + let operand: ConstValue = ConstValue::I32(42); assert_eq!( - operand.convert_to(&SsaType::NativeInt, false, PointerSize::Bit64), + operand.convert_to(&SsaType::NativeInt, false, 8), Some(ConstValue::NativeInt(42)) ); } #[test] fn test_conv_native_uint() { - let operand = ConstValue::I32(42); + let operand: ConstValue = ConstValue::I32(42); assert_eq!( - operand.convert_to(&SsaType::NativeUInt, false, PointerSize::Bit64), + operand.convert_to(&SsaType::NativeUInt, false, 8), Some(ConstValue::NativeUInt(42)) ); } #[test] fn test_conv_char() { - let operand = ConstValue::I32(65); // 'A' + let operand: ConstValue = ConstValue::I32(65); // 'A' assert_eq!( - operand.convert_to(&SsaType::Char, false, PointerSize::Bit64), + operand.convert_to(&SsaType::Char, false, 8), Some(ConstValue::U16(65)) ); } @@ -187,6 +192,7 @@ fn test_identity_add_zero() { dest: v2, left: v0, right: v1, + flags: None, }; let result = ConstantPropagationPass::check_algebraic_identity(&op, &constants); @@ -213,6 +219,7 @@ fn test_identity_mul_one() { dest: v2, left: v0, right: v1, + flags: None, }; let result = ConstantPropagationPass::check_algebraic_identity(&op, &constants); @@ -238,6 +245,7 @@ fn test_identity_and_minus_one() { dest: v2, left: v0, right: v1, + flags: None, }; let result = ConstantPropagationPass::check_algebraic_identity(&op, &constants); @@ -263,6 +271,7 @@ fn test_absorbing_mul_zero() { dest: v2, left: v0, right: v1, + flags: None, }; let result = ConstantPropagationPass::check_algebraic_identity(&op, &constants); @@ -288,6 +297,7 @@ fn test_absorbing_and_zero() { dest: v2, left: v0, right: v1, + flags: None, }; let result = ConstantPropagationPass::check_algebraic_identity(&op, &constants); @@ -313,6 +323,7 @@ fn test_absorbing_or_minus_one() { dest: v2, left: v0, right: v1, + flags: None, }; let result = ConstantPropagationPass::check_algebraic_identity(&op, &constants); @@ -339,6 +350,7 @@ fn test_add_ovf_no_overflow() { left: v0, right: v1, unsigned: false, + flags: None, }; let result = ConstantPropagationPass::check_overflow_op(&op, &constants, PointerSize::Bit64); @@ -359,6 +371,7 @@ fn test_mul_ovf_with_zero() { left: v0, right: v1, unsigned: false, + flags: None, }; // x * 0 = 0, even with overflow check @@ -373,7 +386,7 @@ fn test_pass_empty_function() { let method_token = Token::new(0x0600_0001); let ctx = test_context(); - let result = pass.run_on_method(&mut ssa, method_token, &ctx, &test_assembly_arc()); + let result = pass.run_on_method(&mut ssa, &MethodRef::from(method_token), &ctx); assert!(result.is_ok()); assert!(!result.unwrap()); } @@ -396,7 +409,7 @@ fn test_pass_simple_folding() { let method_token = Token::new(0x0600_0001); let ctx = test_context(); - let result = pass.run_on_method(&mut ssa, method_token, &ctx, &test_assembly_arc()); + let result = pass.run_on_method(&mut ssa, &MethodRef::from(method_token), &ctx); assert!(result.is_ok()); // Check that v2 was folded to constant 8 @@ -431,7 +444,7 @@ fn test_pass_branch_simplification() { let method_token = Token::new(0x0600_0001); let ctx = test_context(); - let result = pass.run_on_method(&mut ssa, method_token, &ctx, &test_assembly_arc()); + let result = pass.run_on_method(&mut ssa, &MethodRef::from(method_token), &ctx); assert!(result.is_ok()); // Check that branch was simplified to jump @@ -464,7 +477,7 @@ fn test_pass_switch_simplification() { let method_token = Token::new(0x0600_0001); let ctx = test_context(); - let result = pass.run_on_method(&mut ssa, method_token, &ctx, &test_assembly_arc()); + let result = pass.run_on_method(&mut ssa, &MethodRef::from(method_token), &ctx); assert!(result.is_ok()); // Check that switch was simplified to jump to target 2 (index 1) @@ -496,7 +509,7 @@ fn test_constants_cached_in_context() { let method_token = Token::new(0x0600_0001); let ctx = test_context(); - let result = pass.run_on_method(&mut ssa, method_token, &ctx, &test_assembly_arc()); + let result = pass.run_on_method(&mut ssa, &MethodRef::from(method_token), &ctx); assert!(result.is_ok()); // Check that constant was cached @@ -516,6 +529,7 @@ fn test_identity_shl_zero() { dest: v2, value: v0, amount: v1, + flags: None, }; let result = ConstantPropagationPass::check_algebraic_identity(&op, &constants); @@ -542,6 +556,7 @@ fn test_identity_shr_zero() { value: v0, amount: v1, unsigned: false, + flags: None, }; let result = ConstantPropagationPass::check_algebraic_identity(&op, &constants); @@ -567,6 +582,7 @@ fn test_identity_xor_zero() { dest: v2, left: v0, right: v1, + flags: None, }; let result = ConstantPropagationPass::check_algebraic_identity(&op, &constants); @@ -593,6 +609,7 @@ fn test_identity_div_one() { left: v0, right: v1, unsigned: false, + flags: None, }; let result = ConstantPropagationPass::check_algebraic_identity(&op, &constants); @@ -618,6 +635,7 @@ fn test_identity_sub_zero() { dest: v2, left: v0, right: v1, + flags: None, }; let result = ConstantPropagationPass::check_algebraic_identity(&op, &constants); @@ -658,7 +676,7 @@ fn test_chained_constant_folding() { let method_token = Token::new(0x0600_0001); let ctx = test_context(); - let result = pass.run_on_method(&mut ssa, method_token, &ctx, &test_assembly_arc()); + let result = pass.run_on_method(&mut ssa, &MethodRef::from(method_token), &ctx); assert!(result.is_ok()); // Check that v4 was folded to 10 @@ -683,7 +701,7 @@ fn test_branch_false_condition() { let method_token = Token::new(0x0600_0001); let ctx = test_context(); - let result = pass.run_on_method(&mut ssa, method_token, &ctx, &test_assembly_arc()); + let result = pass.run_on_method(&mut ssa, &MethodRef::from(method_token), &ctx); assert!(result.is_ok()); // Check that branch was simplified to jump to false branch (target 2) @@ -716,7 +734,7 @@ fn test_switch_out_of_range_uses_default() { let method_token = Token::new(0x0600_0001); let ctx = test_context(); - let result = pass.run_on_method(&mut ssa, method_token, &ctx, &test_assembly_arc()); + let result = pass.run_on_method(&mut ssa, &MethodRef::from(method_token), &ctx); assert!(result.is_ok()); // Check that switch was simplified to jump to default (target 4) diff --git a/dotscope/src/compiler/passes/controlflow.rs b/dotscope/src/compiler/passes/controlflow.rs deleted file mode 100644 index 45f882da..00000000 --- a/dotscope/src/compiler/passes/controlflow.rs +++ /dev/null @@ -1,872 +0,0 @@ -//! Control flow simplification pass. -//! -//! Simplifies the control flow graph through several transformations: -//! -//! 1. **Jump threading**: Skip intermediate trampoline blocks -//! 2. **Branch canonicalization**: Simplify `branch cond, B, B` to `jump B` -//! 3. **Unreachable tail removal**: Remove code after unconditional exits -//! -//! Uses an iterative fixed-point algorithm to handle cascading simplifications. -//! -//! ## Example -//! -//! Before: -//! ```text -//! B0: jump B1 -//! B1: jump B2 -//! B2: ret -//! ``` -//! -//! After: -//! ```text -//! B0: jump B2 // Directly to B2 -//! B1: jump B2 // Will be eliminated by DCE -//! B2: ret -//! ``` -//! - -use std::collections::BTreeMap; - -use crate::{ - analysis::{SsaFunction, SsaOp}, - compiler::{ - pass::SsaPass, - passes::{deadcode::find_dead_tails, utils::resolve_chain}, - CompilerContext, EventKind, EventLog, - }, - metadata::token::Token, - CilObject, Result, -}; - -/// Control flow simplification pass. -/// -/// Performs iterative control flow simplification including: -/// - Jump threading through trampoline blocks -/// - Branch-to-same-target simplification -/// - Dead tail removal (code after terminators) -/// -/// The pass iterates until no more changes are made (fixed point). -pub struct ControlFlowSimplificationPass { - /// Maximum fixpoint iterations before stopping. - max_iterations: usize, -} - -impl ControlFlowSimplificationPass { - /// Creates a new control flow simplification pass. - /// - /// # Arguments - /// - /// * `max_iterations` - Maximum fixpoint iterations for jump threading, branch - /// simplification, and dead tail removal. The default config value is 20. - #[must_use] - pub fn new(max_iterations: usize) -> Self { - Self { max_iterations } - } - - /// Finds branches where both targets resolve to the same block. - /// - /// A branch `branch cond, B, B` can be simplified to `jump B` since - /// the condition doesn't affect the control flow. Also detects cases - /// where targets are different blocks but resolve to the same ultimate - /// destination through trampoline chains. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to analyze. - /// * `trampolines` - Map of trampoline blocks to their targets. - /// - /// # Returns - /// - /// A vector of (block index, target block) pairs for branches to simplify. - fn find_same_target_branches( - ssa: &SsaFunction, - trampolines: &BTreeMap, - ) -> Vec<(usize, usize)> { - ssa.iter_blocks() - .filter_map(|(block_idx, block)| { - block.terminator_op().and_then(|op| match op { - SsaOp::Branch { - true_target, - false_target, - .. - } - | SsaOp::BranchCmp { - true_target, - false_target, - .. - } => { - if true_target == false_target { - return Some((block_idx, *true_target)); - } - // Resolve through trampoline chains to catch convergent targets - let true_ultimate = resolve_chain(trampolines, *true_target); - let false_ultimate = resolve_chain(trampolines, *false_target); - if true_ultimate == false_ultimate { - Some((block_idx, true_ultimate)) - } else { - None - } - } - SsaOp::Switch { - targets, default, .. - } => { - if targets.iter().all(|t| *t == *default) { - return Some((block_idx, *default)); - } - // Resolve through trampoline chains - let default_ultimate = resolve_chain(trampolines, *default); - if targets - .iter() - .all(|t| resolve_chain(trampolines, *t) == default_ultimate) - { - return Some((block_idx, default_ultimate)); - } - // Self-loop elimination: if all cases except one are - // self-loops (target == block_idx), the switch degenerates - // to a jump to the single non-self target. This handles - // residual CFF in exception handlers where the obfuscator - // creates an irresolvable cycle with endfinally as the - // only exit. - let non_self: Vec = targets - .iter() - .chain(std::iter::once(default)) - .copied() - .filter(|&t| t != block_idx) - .collect(); - if let Some(&first) = non_self.first() { - if non_self.iter().all(|t| *t == first) { - Some((block_idx, first)) - } else { - None - } - } else { - None - } - } - _ => None, - }) - }) - .collect() - } - - /// Applies jump threading to all control flow instructions. - /// - /// Updates jumps, branches, and switches to skip trampoline blocks - /// and go directly to their ultimate targets. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to modify. - /// * `trampolines` - The map of trampoline blocks. - /// * `method_token` - The method token for change tracking. - /// * `changes` - The change set to record modifications. - /// - /// # Returns - /// - /// The number of control flow instructions that were updated. - fn apply_jump_threading( - ssa: &mut SsaFunction, - trampolines: &BTreeMap, - method_token: Token, - changes: &mut EventLog, - ) -> usize { - // Precompute ultimate targets for all trampolines - let ultimate_targets: BTreeMap = trampolines - .keys() - .map(|&t| (t, resolve_chain(trampolines, t))) - .collect(); - - let mut threaded_count: usize = 0; - - for block_idx in 0..ssa.block_count() { - if let Some(block) = ssa.block_mut(block_idx) { - if let Some(last) = block.instructions_mut().last_mut() { - let op = last.op_mut(); - let old_targets = op.successors(); - - // Redirect each trampoline to its ultimate target - let mut changed = false; - for (&trampoline, &ultimate) in &ultimate_targets { - if op.redirect_target(trampoline, ultimate) { - changed = true; - } - } - - if changed { - let new_targets = op.successors(); - changes - .record(EventKind::ControlFlowRestructured) - .at(method_token, block_idx) - .message(format!("jump threaded: {old_targets:?} -> {new_targets:?}")); - threaded_count = threaded_count.saturating_add(1); - } - } - } - } - - threaded_count - } - - /// Simplifies branches where both targets are the same. - /// - /// Converts `branch cond, B, B` to `jump B`. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to modify. - /// * `same_target_branches` - The branches to simplify. - /// * `method_token` - The method token for change tracking. - /// * `changes` - The change set to record modifications. - /// - /// # Returns - /// - /// The number of branches that were simplified. - fn simplify_same_target_branches( - ssa: &mut SsaFunction, - same_target_branches: &[(usize, usize)], - method_token: Token, - changes: &mut EventLog, - ) -> usize { - let mut simplified_count: usize = 0; - - for &(block_idx, target) in same_target_branches { - if let Some(block) = ssa.block_mut(block_idx) { - if let Some(last) = block.instructions_mut().last_mut() { - last.set_op(SsaOp::Jump { target }); - changes - .record(EventKind::BranchSimplified) - .at(method_token, block_idx) - .message(format!( - "branch to same target simplified: B{block_idx} branch -> jump B{target}" - )); - simplified_count = simplified_count.saturating_add(1); - } - } - } - - simplified_count - } - - /// Removes dead code tails (instructions after terminators). - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to modify. - /// * `dead_tails` - The dead tails to remove. - /// * `method_token` - The method token for change tracking. - /// * `changes` - The change set to record modifications. - /// - /// # Returns - /// - /// The number of instructions removed. - fn remove_dead_tails( - ssa: &mut SsaFunction, - dead_tails: &[(usize, usize)], - method_token: Token, - changes: &mut EventLog, - ) -> usize { - let mut removed_count: usize = 0; - - for &(block_idx, start_idx) in dead_tails { - if let Some(block) = ssa.block_mut(block_idx) { - let instr_count = block.instruction_count(); - let to_remove = instr_count.saturating_sub(start_idx); - for _ in 0..to_remove { - block.instructions_mut().pop(); - removed_count = removed_count.saturating_add(1); - } - if to_remove > 0 { - changes - .record(EventKind::InstructionRemoved) - .at(method_token, block_idx) - .message(format!( - "removed {to_remove} dead instructions after terminator in B{block_idx}" - )); - } - } - } - - removed_count - } - - /// Runs a single iteration of control flow simplification. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to modify. - /// * `method_token` - The method token for change tracking. - /// * `changes` - The change set to record modifications. - /// - /// # Returns - /// - /// The total number of changes made during this iteration. - fn run_iteration(ssa: &mut SsaFunction, method_token: Token, changes: &mut EventLog) -> usize { - let mut total_changes: usize = 0; - - // Step 1: Find and apply jump threading (don't skip entry block) - let trampolines = ssa.find_trampoline_blocks(false); - if !trampolines.is_empty() { - total_changes = total_changes.saturating_add(Self::apply_jump_threading( - ssa, - &trampolines, - method_token, - changes, - )); - } - - // Step 2: Simplify branches to same target (also resolves through trampolines) - let same_target_branches = Self::find_same_target_branches(ssa, &trampolines); - if !same_target_branches.is_empty() { - total_changes = total_changes.saturating_add(Self::simplify_same_target_branches( - ssa, - &same_target_branches, - method_token, - changes, - )); - } - - // Step 3: Remove dead tails - let dead_tails = find_dead_tails(ssa); - if !dead_tails.is_empty() { - total_changes = total_changes.saturating_add(Self::remove_dead_tails( - ssa, - &dead_tails, - method_token, - changes, - )); - } - - total_changes - } -} - -impl SsaPass for ControlFlowSimplificationPass { - fn name(&self) -> &'static str { - "control-flow-simplification" - } - - fn description(&self) -> &'static str { - "Simplifies control flow by threading jumps and eliminating trampolines" - } - - fn run_on_method( - &self, - ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - _assembly: &CilObject, - ) -> Result { - let mut changes = EventLog::new(); - - // Iterate until fixed point - for _ in 0..self.max_iterations { - let iteration_changes = Self::run_iteration(ssa, method_token, &mut changes); - if iteration_changes == 0 { - break; - } - } - - let changed = !changes.is_empty(); - if changed { - ctx.events.merge(&changes); - } - Ok(changed) - } -} -#[cfg(test)] -mod tests { - use std::collections::BTreeMap; - use std::sync::Arc; - - use crate::{ - analysis::{ - CallGraph, ConstValue, SsaBlock, SsaFunction, SsaFunctionBuilder, SsaInstruction, - SsaOp, SsaVarId, - }, - compiler::{ - passes::{controlflow::ControlFlowSimplificationPass, deadcode::find_dead_tails}, - CompilerContext, SsaPass, - }, - metadata::token::Token, - test::helpers::test_assembly_arc, - }; - - /// Helper to create a minimal analysis context for testing. - fn test_context() -> CompilerContext { - let call_graph = Arc::new(CallGraph::new()); - CompilerContext::new(call_graph) - } - - #[test] - fn test_find_same_target_branches_none() { - let ssa = SsaFunctionBuilder::new(3, 0) - .build_with(|f| { - f.block(0, |b| { - let cond = b.const_true(); - b.branch(cond, 1, 2); // Different targets - }); - }) - .unwrap(); - - let same_targets = - ControlFlowSimplificationPass::find_same_target_branches(&ssa, &BTreeMap::new()); - assert!(same_targets.is_empty()); - } - - #[test] - fn test_find_same_target_branches_found() { - let ssa = SsaFunctionBuilder::new(2, 0) - .build_with(|f| { - f.block(0, |b| { - let cond = b.const_true(); - b.branch(cond, 1, 1); // Same target! - }); - f.block(1, |b| b.ret()); - }) - .unwrap(); - - let same_targets = - ControlFlowSimplificationPass::find_same_target_branches(&ssa, &BTreeMap::new()); - assert_eq!(same_targets.len(), 1); - assert_eq!(same_targets[0], (0, 1)); - } - - #[test] - fn test_find_same_target_branches_multiple() { - let ssa = SsaFunctionBuilder::new(4, 0) - .build_with(|f| { - f.block(0, |b| { - let cond = b.const_true(); - b.branch(cond, 2, 2); - }); - f.block(1, |b| { - let cond = b.const_true(); - b.branch(cond, 3, 3); - }); - f.block(2, |b| b.ret()); - f.block(3, |b| b.ret()); - }) - .unwrap(); - - let same_targets = - ControlFlowSimplificationPass::find_same_target_branches(&ssa, &BTreeMap::new()); - assert_eq!(same_targets.len(), 2); - } - - #[test] - fn test_find_same_target_branches_convergent_trampolines() { - // Branch(cond, 1, 2) where both 1 and 2 are trampolines to 3 - let ssa = SsaFunctionBuilder::new(4, 0) - .build_with(|f| { - f.block(0, |b| { - let cond = b.const_true(); - b.branch(cond, 1, 2); // Different targets... - }); - f.block(1, |b| b.jump(3)); // ...but both trampoline to 3 - f.block(2, |b| b.jump(3)); - f.block(3, |b| b.ret()); - }) - .unwrap(); - - // Without trampoline info, targets look different - let same_targets = - ControlFlowSimplificationPass::find_same_target_branches(&ssa, &BTreeMap::new()); - assert!(same_targets.is_empty()); - - // With trampoline info, convergence is detected - let trampolines = ssa.find_trampoline_blocks(false); - let same_targets = - ControlFlowSimplificationPass::find_same_target_branches(&ssa, &trampolines); - assert_eq!(same_targets.len(), 1); - assert_eq!(same_targets[0], (0, 3)); - } - - #[test] - fn test_find_same_target_branches_one_trampoline() { - // Branch(cond, 2, 1) where 1 is a trampoline to 2 - let ssa = SsaFunctionBuilder::new(3, 0) - .build_with(|f| { - f.block(0, |b| { - let cond = b.const_true(); - b.branch(cond, 2, 1); // true goes direct, false via trampoline - }); - f.block(1, |b| b.jump(2)); // trampoline to 2 - f.block(2, |b| b.ret()); - }) - .unwrap(); - - let trampolines = ssa.find_trampoline_blocks(false); - let same_targets = - ControlFlowSimplificationPass::find_same_target_branches(&ssa, &trampolines); - assert_eq!(same_targets.len(), 1); - assert_eq!(same_targets[0], (0, 2)); - } - - #[test] - fn test_find_dead_tails_empty() { - let ssa = SsaFunctionBuilder::new(0, 0).build_with(|_f| {}).unwrap(); - let dead_tails = find_dead_tails(&ssa); - assert!(dead_tails.is_empty()); - } - - #[test] - fn test_find_dead_tails_with_dead_code() { - // Need to use manual construction here since builder won't allow - // instructions after a terminator - let mut ssa = SsaFunction::new(1, 0); - - let mut block0 = SsaBlock::new(0); - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - // Dead code after return - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: SsaVarId::from_index(0), - value: ConstValue::I32(42), - })); - ssa.add_block(block0); - - let dead_tails = find_dead_tails(&ssa); - assert_eq!(dead_tails.len(), 1); - assert_eq!(dead_tails[0], (0, 1)); - } - - #[test] - fn test_find_dead_tails_no_dead_code() { - let ssa = SsaFunctionBuilder::new(1, 0) - .build_with(|f| { - f.block(0, |b| { - let _ = b.const_i32(42); - b.ret(); - }); - }) - .unwrap(); - - let dead_tails = find_dead_tails(&ssa); - assert!(dead_tails.is_empty()); - } - - #[test] - fn test_find_dead_tails_multiple_dead_instructions() { - // Need to use manual construction here since builder won't allow - // instructions after a terminator - let mut ssa = SsaFunction::new(1, 0); - - let mut block0 = SsaBlock::new(0); - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: SsaVarId::from_index(0), - value: ConstValue::I32(1), - })); - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: SsaVarId::from_index(1), - value: ConstValue::I32(2), - })); - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: SsaVarId::from_index(2), - value: ConstValue::I32(3), - })); - ssa.add_block(block0); - - let dead_tails = find_dead_tails(&ssa); - assert_eq!(dead_tails.len(), 1); - assert_eq!(dead_tails[0], (0, 1)); // Start at index 1 - } - - #[test] - fn test_pass_empty_function() { - let pass = ControlFlowSimplificationPass::new(20); - let ctx = test_context(); - let mut ssa = SsaFunctionBuilder::new(0, 0).build_with(|_f| {}).unwrap(); - - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - assert!(!changed); - } - - #[test] - fn test_pass_no_simplification_needed() { - let pass = ControlFlowSimplificationPass::new(20); - let ctx = test_context(); - let mut ssa = SsaFunctionBuilder::new(1, 0) - .build_with(|f| { - f.block(0, |b| { - let _ = b.const_i32(42); - b.ret(); - }); - }) - .unwrap(); - - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - assert!(!changed); - } - - #[test] - fn test_pass_jump_threading() { - let pass = ControlFlowSimplificationPass::new(20); - let ctx = test_context(); - // Block 0: jump to trampoline - // Block 1: trampoline to block 2 - // Block 2: return - let mut ssa = SsaFunctionBuilder::new(3, 0) - .build_with(|f| { - f.block(0, |b| b.jump(1)); - f.block(1, |b| b.jump(2)); - f.block(2, |b| b.ret()); - }) - .unwrap(); - - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - assert!(changed); - - // Verify block 0 now jumps directly to block 2 - if let Some(block) = ssa.block(0) { - if let Some(SsaOp::Jump { target }) = block.terminator_op() { - assert_eq!(*target, 2); - } - } - } - - #[test] - fn test_pass_leave_threading() { - let pass = ControlFlowSimplificationPass::new(20); - let ctx = test_context(); - // Block 0: leave to trampoline - // Block 1: trampoline (leave) to block 2 - // Block 2: return - let mut ssa = SsaFunctionBuilder::new(3, 0) - .build_with(|f| { - f.block(0, |b| b.leave(1)); - f.block(1, |b| b.leave(2)); - f.block(2, |b| b.ret()); - }) - .unwrap(); - - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - assert!(changed); - - // Verify block 0 now leaves directly to block 2 - if let Some(block) = ssa.block(0) { - if let Some(SsaOp::Leave { target }) = block.terminator_op() { - assert_eq!(*target, 2); - } - } - } - - #[test] - fn test_pass_branch_threading() { - let pass = ControlFlowSimplificationPass::new(20); - let ctx = test_context(); - // Block 0: branch to trampolines - // Block 1: trampoline to block 3 - // Block 2: trampoline to block 4 - // Block 3, 4: return - let mut ssa = SsaFunctionBuilder::new(5, 0) - .build_with(|f| { - f.block(0, |b| { - let cond = b.const_true(); - b.branch(cond, 1, 2); - }); - f.block(1, |b| b.jump(3)); - f.block(2, |b| b.jump(4)); - f.block(3, |b| b.ret()); - f.block(4, |b| b.ret()); - }) - .unwrap(); - - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - assert!(changed); - - // Verify branch targets were threaded - if let Some(block) = ssa.block(0) { - if let Some(SsaOp::Branch { - true_target, - false_target, - .. - }) = block.terminator_op() - { - assert_eq!(*true_target, 3); - assert_eq!(*false_target, 4); - } - } - } - - #[test] - fn test_pass_switch_threading() { - let pass = ControlFlowSimplificationPass::new(20); - let ctx = test_context(); - // Block 0: switch with trampoline targets - // Blocks 1, 2, 3: trampolines to block 4 - // Block 4: return - let mut ssa = SsaFunctionBuilder::new(5, 0) - .build_with(|f| { - f.block(0, |b| { - let val = b.const_i32(0); - b.switch(val, vec![1, 2], 3); - }); - f.block(1, |b| b.jump(4)); - f.block(2, |b| b.jump(4)); - f.block(3, |b| b.jump(4)); - f.block(4, |b| b.ret()); - }) - .unwrap(); - - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - assert!(changed); - - // Verify switch targets were threaded - if let Some(block) = ssa.block(0) { - if let Some(SsaOp::Switch { - targets, default, .. - }) = block.terminator_op() - { - assert!(targets.iter().all(|&t| t == 4)); - assert_eq!(*default, 4); - } - } - } - - #[test] - fn test_pass_same_target_branch_simplification() { - let pass = ControlFlowSimplificationPass::new(20); - let ctx = test_context(); - // Block 0: branch to same target - let mut ssa = SsaFunctionBuilder::new(2, 0) - .build_with(|f| { - f.block(0, |b| { - let cond = b.const_true(); - b.branch(cond, 1, 1); - }); - f.block(1, |b| b.ret()); - }) - .unwrap(); - - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - assert!(changed); - - // Verify branch was converted to jump - if let Some(block) = ssa.block(0) { - assert!(matches!( - block.terminator_op(), - Some(SsaOp::Jump { target: 1 }) - )); - } - } - - #[test] - fn test_pass_dead_tail_removal() { - // Need to use manual construction here since builder won't allow - // instructions after a terminator - let pass = ControlFlowSimplificationPass::new(20); - let ctx = test_context(); - let mut ssa = SsaFunction::new(1, 0); - - let mut block0 = SsaBlock::new(0); - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: SsaVarId::from_index(0), - value: ConstValue::I32(42), - })); - ssa.add_block(block0); - - assert_eq!(ssa.block(0).unwrap().instruction_count(), 2); - - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - assert!(changed); - - assert_eq!(ssa.block(0).unwrap().instruction_count(), 1); - } - - #[test] - fn test_pass_iterative_convergence() { - let pass = ControlFlowSimplificationPass::new(20); - let ctx = test_context(); - // Create a chain: 0 -> 1 -> 2 -> 3 -> 4 - let mut ssa = SsaFunctionBuilder::new(5, 0) - .build_with(|f| { - f.block(0, |b| b.jump(1)); - f.block(1, |b| b.jump(2)); - f.block(2, |b| b.jump(3)); - f.block(3, |b| b.jump(4)); - f.block(4, |b| b.ret()); - }) - .unwrap(); - - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - assert!(changed); - - // All jumps should now go directly to block 4 - for i in 0..4 { - if let Some(block) = ssa.block(i) { - if let Some(SsaOp::Jump { target }) = block.terminator_op() { - assert_eq!(*target, 4); - } - } - } - } - - #[test] - fn test_pass_combined_simplifications() { - // Need to use manual construction here since builder won't allow - // instructions after a terminator (for the dead tail test case) - let pass = ControlFlowSimplificationPass::new(20); - let ctx = test_context(); - let mut ssa = SsaFunction::new(4, 0); - - // Block 0: branch to same trampoline target - let mut block0 = SsaBlock::new(0); - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Branch { - condition: SsaVarId::from_index(0), - true_target: 1, - false_target: 1, - })); - ssa.add_block(block0); - - // Block 1: trampoline to block 2 - let mut block1 = SsaBlock::new(1); - block1.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 2 })); - ssa.add_block(block1); - - // Block 2: trampoline to block 3 - let mut block2 = SsaBlock::new(2); - block2.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 3 })); - ssa.add_block(block2); - - // Block 3: return with dead tail - let mut block3 = SsaBlock::new(3); - block3.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - block3.add_instruction(SsaInstruction::synthetic(SsaOp::Nop)); - ssa.add_block(block3); - - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - assert!(changed); - - // Block 0 should be a jump to block 3 - if let Some(block) = ssa.block(0) { - assert!(matches!( - block.terminator_op(), - Some(SsaOp::Jump { target: 3 }) - )); - } - - // Block 3 should have no dead tail - assert_eq!(ssa.block(3).unwrap().instruction_count(), 1); - } -} diff --git a/dotscope/src/compiler/passes/copying.rs b/dotscope/src/compiler/passes/copying.rs index 15d407b7..662a313f 100644 --- a/dotscope/src/compiler/passes/copying.rs +++ b/dotscope/src/compiler/passes/copying.rs @@ -1,71 +1,33 @@ -//! Copy propagation pass. +//! Copy propagation pass — thin wrapper. //! -//! This pass eliminates redundant copy operations by replacing uses of -//! copy destinations with their sources. This simplifies the SSA graph -//! and enables further optimizations. +//! Pure-SSA transformation logic lives in [`analyssa::passes::copying`]. This +//! file contributes: //! -//! # Example -//! -//! Before: -//! ```text -//! v1 = v0 // Copy -//! v2 = add v1, 5 -//! ret v1 -//! ``` -//! -//! After (with v1 replaced by v0): -//! ```text -//! v1 = v0 // Can now be eliminated by DCE -//! v2 = add v0, 5 -//! ret v0 -//! ``` -//! -//! # Algorithm -//! -//! The pass uses an iterative fixed-point algorithm: -//! -//! 1. Build a map of all copy-like operations: -//! - Explicit `Copy` instructions -//! - Trivial phi nodes (all operands identical after excluding self-references) -//! 2. Resolve copy chains to find ultimate sources (v2 → v1 → v0 becomes v2 → v0) -//! 3. Replace all uses of copy destinations with their ultimate sources -//! 4. Repeat until no more changes (fixed point) -//! -//! Dead code elimination will then remove the now-unused copy instructions. -//! -//! # Complexity -//! -//! - Time: O(n × m) where n is the number of variables and m is the number of iterations -//! - Space: O(n) for the copy map -//! -//! In practice, the algorithm converges quickly (usually 1-3 iterations). +//! 1. The [`SsaPass`] trait impl that the dotscope scheduler consumes. +//! 2. `propagate_local_types` — the CIL-side post-step that runs once per +//! iteration and propagates `SsaType` from local-origin destinations to +//! their ultimate sources. It uses `ssa.original_local_types()` (CIL +//! signature data) and `SsaType::from_type_signature(..., assembly)`, +//! so it cannot move into analyssa. use std::collections::BTreeMap; +use analyssa::passes::copying; + use crate::{ - analysis::{PhiAnalyzer, SsaFunction, SsaOp, SsaType, SsaVarId, VariableOrigin}, + analysis::{CilTarget, MethodRef, SsaFunction, SsaType, SsaVarId, VariableOrigin}, compiler::{ pass::{ModificationScope, SsaPass}, - passes::utils::resolve_chain, - CompilerContext, EventKind, EventLog, + CompilerContext, }, - metadata::token::Token, - utils::BitSet, - CilObject, Result, + CilObject, }; /// Copy propagation pass. /// -/// Tracks copy operations and propagates the source to all uses of the copy. -/// Uses an iterative fixed-point algorithm to handle cascading copies and -/// newly exposed opportunities after each round of propagation. -/// -/// # Handled Cases -/// -/// - Direct copy instructions: `v1 = copy v0` -/// - Trivial phi nodes: `v1 = phi(v0, v0, v0)` (all operands identical) -/// - Self-referential phis: `v1 = phi(v0, v1)` → `v1 = v0` -/// - Copy chains: `v2 = v1; v1 = v0` → both map to `v0` +/// Tracks copy operations and propagates the source to all uses of the +/// copy. Uses an iterative fixed-point algorithm to handle cascading copies +/// and newly exposed opportunities after each round of propagation. pub struct CopyPropagationPass { /// Maximum fixpoint iterations before stopping. max_iterations: usize, @@ -74,278 +36,15 @@ pub struct CopyPropagationPass { impl CopyPropagationPass { /// Creates a new copy propagation pass. /// - /// # Arguments - /// - /// * `max_iterations` - Maximum fixpoint iterations before stopping. Copy chains - /// converge in ~3 iterations; the default config value is 15. + /// `max_iterations` caps the inner fixpoint loop. Copy chains converge + /// in ~3 iterations; the default config value is 15. #[must_use] pub fn new(max_iterations: usize) -> Self { Self { max_iterations } } - - /// Removes copies from the map when they are the sole instruction-based - /// definition of a local or argument group and the source is in a different - /// group. - /// - /// After CFF unflattening, `stloc.0` becomes a Copy from a stack-temp - /// variable to a Local(0) variable. If copy propagation eliminates this - /// Copy, Local(0)'s group loses its only definition. Subsequent - /// `rebuild_ssa` calls then assign the entry value (null) to all uses - /// of Local(0), corrupting the data flow. - fn protect_sole_local_defs(ssa: &SsaFunction, copies: &mut BTreeMap) { - let real_local_limit = ssa.num_args().saturating_add(ssa.num_locals()) as u32; - - // Count instruction-based definitions per local/argument group. - let mut group_def_count: BTreeMap = BTreeMap::new(); - for block in ssa.blocks() { - for instr in block.instructions() { - if let Some(dest) = instr.op().dest() { - let group = ssa.rename_group(dest); - if group < real_local_limit { - let counter = group_def_count.entry(group).or_insert(0); - *counter = counter.saturating_add(1); - } - } - } - } - - // Identify local/arg groups whose variables appear as phi operands. - // These are the groups at risk: after copy-prop eliminates the bridging - // Copy, rebuild_ssa's rename may not be able to reconstruct the correct - // reaching definition through phi chains. - let group_bound = ssa - .num_locals() - .saturating_add(ssa.num_args()) - .saturating_add(1); - let mut groups_in_phis = BitSet::new(group_bound); - for block in ssa.blocks() { - for phi in block.phi_nodes() { - for operand in phi.operands() { - let group = ssa.rename_group(operand.value()); - if group < real_local_limit { - groups_in_phis.insert(group as usize); - } - } - let result_group = ssa.rename_group(phi.result()); - if result_group < real_local_limit { - groups_in_phis.insert(result_group as usize); - } - } - } - - // Identify address-taken local groups (locals accessed via ldloca). - // These locals are read through pointers, not through SSA variables. - // If copy-prop eliminates the stloc Copy, the local is never - // initialized and reads through the pointer see the default value - // (0/null) instead of the stored constant. - let mut address_taken_groups = BitSet::new(group_bound); - for block in ssa.blocks() { - for instr in block.instructions() { - if let SsaOp::LoadLocalAddr { local_index, .. } = instr.op() { - let group = (ssa.num_args() as u32).saturating_add(*local_index as u32); - if group < real_local_limit { - address_taken_groups.insert(group as usize); - } - } - } - } - - // Collect copy dests to protect: dest is in a local/arg group with - // the source in a different group, AND either: - // (a) the group has exactly 1 instruction def and participates in - // phi nodes (cross-block flow — removing the sole def would - // leave the group undefined), OR - // (b) the group is address-taken (accessed via ldloca). ALL stores - // to address-taken locals must be preserved regardless of def - // count, because the runtime reads the actual memory location - // through the pointer. Removing any store causes the pointer - // read to see stale/default values (e.g., Monitor.Enter reads - // lockTaken=true from a previous lock instead of the fresh - // false initialization). - let mut protected = BitSet::new(ssa.var_id_capacity()); - for (&dest, &src) in copies.iter() { - let dest_group = ssa.rename_group(dest); - if dest_group >= real_local_limit { - continue; // Not a local/arg group - } - - if ssa.rename_group(src) == dest_group { - continue; // Same group — safe to propagate - } - let def_count = group_def_count.get(&dest_group).copied().unwrap_or(0); - if address_taken_groups.contains(dest_group as usize) - || (def_count <= 1 && groups_in_phis.contains(dest_group as usize)) - { - protected.insert(dest.index()); - } - } - - if !protected.is_empty() { - copies.retain(|dest, _| !protected.contains(dest.index())); - } - } - - /// Resolves all copy chains to their ultimate sources. - /// - /// Uses the shared `resolve_chain` utility to follow each copy to its - /// ultimate source, handling cycles correctly. - /// - /// # Arguments - /// - /// * `copies` - Map of direct copies (dest → immediate source). - /// - /// # Returns - /// - /// Map of each copy destination to its ultimate source. - fn resolve_chains(copies: &BTreeMap) -> BTreeMap { - copies - .iter() - .map(|(&dest, &src)| (dest, resolve_chain(copies, src))) - .collect() - } - - /// Runs a single iteration of copy propagation. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to modify. - /// * `method_token` - The method token for change tracking. - /// * `changes` - The change set to record modifications. - /// * `assembly` - The assembly context for type resolution. - /// - /// # Returns - /// - /// The number of uses that were replaced. - fn run_iteration( - ssa: &mut SsaFunction, - method_token: Token, - changes: &mut EventLog, - assembly: &CilObject, - ) -> usize { - // Step 1: Collect all copy-like operations - let mut copies = PhiAnalyzer::new(ssa).collect_all_copies(); - - if copies.is_empty() { - return 0; - } - - // Step 1b: Protect cross-group copies that are the sole definition - // of a local/argument group. Without this, propagating the copy - // replaces all uses of the local-group variable with the source - // (from a different group), and DCE removes the now-dead Copy. - // Subsequent rebuild_ssa then finds the local group has no - // instruction-based definitions, producing the entry value (null/0) - // instead of the actual stored value. - // - // This specifically prevents the JIEJIE.NET Issue 13 pattern where - // `stloc.0` creates a Copy bridging a stack-temp group to Local(0), - // and propagating it disconnects the local from its value. - Self::protect_sole_local_defs(ssa, &mut copies); - - // Step 2: Resolve chains to ultimate sources - let resolved = Self::resolve_chains(&copies); - - // Step 3: Propagate types from Local-origin destinations to their sources - Self::propagate_local_types(ssa, &resolved, assembly); - - // Step 4: Apply propagations and record events - let result = ssa.propagate_copies(&resolved); - - for dest_idx in result - .fully_propagated - .iter() - .chain(result.partially_propagated.iter()) - { - let dest = SsaVarId::from_index(dest_idx); - if let Some(src) = resolved.get(&dest) { - changes - .record(EventKind::CopyPropagated) - .method(method_token) - .message(format!("{dest} → {src}")); - } - } - - // Step 5: Neutralize dead Copy instructions whose dests were fully propagated - for dest_idx in result.fully_propagated.iter() { - ssa.nop_copy_defining(SsaVarId::from_index(dest_idx)); - } - - result.total_replaced - } - - /// Propagates types from Local-origin destinations to their ultimate sources. - /// - /// When a Local-origin variable is a copy destination (e.g., `local_0 = copy phi_result`), - /// the source variable should inherit the local's original type. This ensures that - /// after copy propagation eliminates the intermediate copy, the source retains the - /// correct type information for code generation. - /// - /// This follows the .NET JIT's approach of keeping local slot types (`lvType`) - /// separate from IR/computational types (`gtType`), ensuring original types are - /// preserved through optimization. - fn propagate_local_types( - ssa: &mut SsaFunction, - resolved: &BTreeMap, - assembly: &CilObject, - ) { - // Get the original local types from the SSA function - let original_types = match ssa.original_local_types() { - Some(types) => types.to_vec(), - None => return, - }; - - // Collect type assignments first (can't borrow mutably while iterating) - let mut type_assignments: Vec<(SsaVarId, SsaType)> = Vec::new(); - - for (dest, src) in resolved { - if dest == src { - continue; - } - - // Check if the destination is a Local-origin variable - let Some(dest_var) = ssa.variable(*dest) else { - continue; - }; - let VariableOrigin::Local(local_idx) = dest_var.origin() else { - continue; - }; - - // Get the original type for this local - let local_type = match original_types.get(local_idx as usize) { - Some(sig) => &sig.base, - None => continue, - }; - - // Convert to SsaType - let ssa_type = SsaType::from_type_signature(local_type, assembly); - - // Only propagate if the type is known (not Unknown/I32) - if ssa_type.is_unknown() || matches!(ssa_type, SsaType::I32) { - continue; - } - - // Check if the source variable currently has Unknown type - // Only propagate if we're improving the type information - let should_propagate = match ssa.variable(*src) { - Some(src_var) => src_var.var_type().is_unknown(), - None => false, - }; - - if should_propagate { - type_assignments.push((*src, ssa_type)); - } - } - - // Apply type assignments - for (var_id, ssa_type) in type_assignments { - if let Some(var) = ssa.variable_mut(var_id) { - var.set_type(ssa_type); - } - } - } } -impl SsaPass for CopyPropagationPass { +impl SsaPass for CopyPropagationPass { fn name(&self) -> &'static str { "copy-propagation" } @@ -361,677 +60,82 @@ impl SsaPass for CopyPropagationPass { fn run_on_method( &self, ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - assembly: &CilObject, - ) -> Result { - let mut changes = EventLog::new(); - - // Iterate until fixed point - for _ in 0..self.max_iterations { - let replaced = Self::run_iteration(ssa, method_token, &mut changes, assembly); - - if replaced == 0 { - break; - } - } - - let changed = !changes.is_empty(); - if changed { - ctx.events.merge(&changes); - } + method: &MethodRef, + host: &CompilerContext, + ) -> analyssa::Result { + let assembly = host + .assembly() + .ok_or_else(|| analyssa::Error::new("CopyPropagationPass requires an assembly"))?; + let changed = copying::run_with_hook( + ssa, + method, + &host.events, + self.max_iterations, + |ssa, resolved| propagate_local_types(ssa, resolved, &assembly), + ); Ok(changed) } } -#[cfg(test)] -mod tests { - use std::{collections::BTreeMap, sync::Arc}; - - use crate::{ - analysis::{ - CallGraph, ConstValue, DefSite, PhiAnalyzer, PhiNode, PhiOperand, SsaBlock, - SsaFunction, SsaFunctionBuilder, SsaInstruction, SsaOp, SsaType, SsaVarId, - VariableOrigin, - }, - compiler::CompilerContext, - compiler::{CopyPropagationPass, SsaPass}, - metadata::token::Token, - test::helpers::test_assembly_arc, +/// CIL-side post-step: propagates types from Local-origin destinations to +/// their ultimate sources. +/// +/// When a Local-origin variable is a copy destination (e.g. +/// `local_0 = copy phi_result`), the source variable should inherit the +/// local's original type. This ensures that after copy propagation +/// eliminates the intermediate copy, the source retains the correct type +/// information for code generation. +/// +/// Mirrors the .NET JIT's approach of keeping local slot types (`lvType`) +/// separate from IR/computational types (`gtType`). +fn propagate_local_types( + ssa: &mut SsaFunction, + resolved: &BTreeMap, + assembly: &CilObject, +) { + let original_types = match ssa.original_local_types() { + Some(types) => types.to_vec(), + None => return, }; - /// Helper to create a minimal analysis context for testing. - fn test_context() -> CompilerContext { - let call_graph = Arc::new(CallGraph::new()); - CompilerContext::new(call_graph) - } - - #[test] - fn test_collect_empty_function() { - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| b.ret()); - }) - .unwrap(); - let copies = PhiAnalyzer::new(&ssa).collect_all_copies(); - assert!(copies.is_empty()); - } - - #[test] - fn test_collect_single_copy() { - let (ssa, v0, v1) = { - let mut v0_out = SsaVarId::from_index(0); - let mut v1_out = SsaVarId::from_index(1); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(42); - let v1 = b.copy(v0); - v0_out = v0; - v1_out = v1; - b.ret(); - }); - }) - .unwrap(); - (ssa, v0_out, v1_out) - }; - - let copies = PhiAnalyzer::new(&ssa).collect_all_copies(); - assert_eq!(copies.len(), 1); - assert_eq!(copies.get(&v1), Some(&v0)); - } - - #[test] - fn test_collect_multiple_copies() { - let (ssa, v0, v1, v2) = { - let mut v0_out = SsaVarId::from_index(0); - let mut v1_out = SsaVarId::from_index(1); - let mut v2_out = SsaVarId::from_index(2); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(42); - let v1 = b.copy(v0); - let v2 = b.copy(v1); - v0_out = v0; - v1_out = v1; - v2_out = v2; - b.ret(); - }); - }) - .unwrap(); - (ssa, v0_out, v1_out, v2_out) - }; + let mut type_assignments: Vec<(SsaVarId, SsaType)> = Vec::new(); - let copies = PhiAnalyzer::new(&ssa).collect_all_copies(); - assert_eq!(copies.len(), 2); - assert_eq!(copies.get(&v1), Some(&v0)); - assert_eq!(copies.get(&v2), Some(&v1)); - } - - #[test] - fn test_collect_trivial_phi_all_same() { - let (ssa, v0, v_phi) = { - let mut v0_out = SsaVarId::from_index(0); - let mut v_phi_out = SsaVarId::from_index(1); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(42); - let cond = b.const_true(); - v0_out = v0; - b.branch(cond, 1, 2); - }); - f.block(1, |b| b.jump(3)); - f.block(2, |b| b.jump(3)); - f.block(3, |b| { - // phi with all same operands (v0 from both paths) - let phi_result = b.phi(&[(1, v0_out), (2, v0_out)]); - v_phi_out = phi_result; - b.ret_val(phi_result); - }); - }) - .unwrap(); - (ssa, v0_out, v_phi_out) - }; - - let copies = PhiAnalyzer::new(&ssa).collect_all_copies(); - assert_eq!(copies.len(), 1); - assert_eq!(copies.get(&v_phi), Some(&v0)); - } - - #[test] - fn test_collect_trivial_phi_with_self_reference() { - // Self-referential phi where the phi references itself (for loop back-edges) - // We need to manually construct this since the builder can't create self-references - let mut ssa = SsaFunction::new(0, 0); - - // Create variables - let v0 = ssa.create_variable( - VariableOrigin::Local(0), - 0, - DefSite::instruction(0, 0), - SsaType::Unknown, - ); - let phi_var = ssa.create_variable( - VariableOrigin::Local(1), - 0, - DefSite::phi(1), - SsaType::Unknown, - ); - - // Block 0: entry, defines v0, jumps to block 1 - let mut block0 = SsaBlock::new(0); - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: v0, - value: ConstValue::I32(42), - })); - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 1 })); - ssa.add_block(block0); - - // Block 1: loop header with self-referential phi - // phi_var = phi(v0 from block 0, phi_var from block 1) - let mut block1 = SsaBlock::new(1); - let mut phi = PhiNode::new(phi_var, VariableOrigin::Local(1)); - phi.add_operand(PhiOperand::new(v0, 0)); // from block 0 - phi.add_operand(PhiOperand::new(phi_var, 1)); // from block 1 (self-reference) - block1.add_phi(phi); - block1.add_instruction(SsaInstruction::synthetic(SsaOp::Return { - value: Some(phi_var), - })); - ssa.add_block(block1); - - let copies = PhiAnalyzer::new(&ssa).collect_all_copies(); - // phi_var = phi(v0, phi_var) should be detected as trivial (phi_var → v0) - assert_eq!(copies.len(), 1); - assert_eq!(copies.get(&phi_var), Some(&v0)); - } - - #[test] - fn test_collect_non_trivial_phi() { - let ssa = { - let mut v0_out = SsaVarId::from_index(0); - let mut v1_out = SsaVarId::from_index(1); - SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let cond = b.const_true(); - b.branch(cond, 1, 2); - }); - f.block(1, |b| { - let v0 = b.const_i32(10); - v0_out = v0; - b.jump(3); - }); - f.block(2, |b| { - let v1 = b.const_i32(20); - v1_out = v1; - b.jump(3); - }); - f.block(3, |b| { - // phi with different operands - let phi_result = b.phi(&[(1, v0_out), (2, v1_out)]); - b.ret_val(phi_result); - }); - }) - .unwrap() - }; - - let copies = PhiAnalyzer::new(&ssa).collect_all_copies(); - // Non-trivial phi should not be collected - assert!(copies.is_empty()); - } - - #[test] - fn test_resolve_simple_chain() { - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let v2 = SsaVarId::from_index(2); - let mut copies = BTreeMap::new(); - // v2 → v1 → v0 - copies.insert(v2, v1); - copies.insert(v1, v0); - - let resolved = CopyPropagationPass::resolve_chains(&copies); - - // Both should resolve to v0 - assert_eq!(resolved.get(&v1), Some(&v0)); - assert_eq!(resolved.get(&v2), Some(&v0)); - } - - #[test] - fn test_resolve_long_chain() { - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let v2 = SsaVarId::from_index(2); - let v3 = SsaVarId::from_index(3); - let v4 = SsaVarId::from_index(4); - let mut copies = BTreeMap::new(); - // v4 → v3 → v2 → v1 → v0 - copies.insert(v4, v3); - copies.insert(v3, v2); - copies.insert(v2, v1); - copies.insert(v1, v0); - - let resolved = CopyPropagationPass::resolve_chains(&copies); - - // All should resolve to v0 - assert_eq!(resolved.get(&v1), Some(&v0)); - assert_eq!(resolved.get(&v2), Some(&v0)); - assert_eq!(resolved.get(&v3), Some(&v0)); - assert_eq!(resolved.get(&v4), Some(&v0)); - } - - #[test] - fn test_resolve_cycle() { - let v1 = SsaVarId::from_index(0); - let v2 = SsaVarId::from_index(1); - let mut copies = BTreeMap::new(); - // v1 → v2 → v1 (cycle) - copies.insert(v1, v2); - copies.insert(v2, v1); - - let resolved = CopyPropagationPass::resolve_chains(&copies); - - // Should handle cycle gracefully (stop at some point in the cycle) - assert!(resolved.contains_key(&v1)); - assert!(resolved.contains_key(&v2)); - } - - #[test] - fn test_resolve_multiple_independent_chains() { - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let v2 = SsaVarId::from_index(2); - let v3 = SsaVarId::from_index(3); - let v4 = SsaVarId::from_index(4); - let v5 = SsaVarId::from_index(5); - let mut copies = BTreeMap::new(); - // Chain 1: v2 → v1 → v0 - copies.insert(v2, v1); - copies.insert(v1, v0); - // Chain 2: v5 → v4 → v3 - copies.insert(v5, v4); - copies.insert(v4, v3); - - let resolved = CopyPropagationPass::resolve_chains(&copies); - - // Chain 1 resolves to v0 - assert_eq!(resolved.get(&v1), Some(&v0)); - assert_eq!(resolved.get(&v2), Some(&v0)); - // Chain 2 resolves to v3 - assert_eq!(resolved.get(&v4), Some(&v3)); - assert_eq!(resolved.get(&v5), Some(&v3)); - } - - #[test] - fn test_trivial_phi_single_operand() { - let ssa = SsaFunction::new(0, 0); - let analyzer = PhiAnalyzer::new(&ssa); - - let result = SsaVarId::from_index(0); - let v0 = SsaVarId::from_index(1); - let mut phi = PhiNode::new(result, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(v0, 0)); - - let source = analyzer.is_trivial(&phi); - assert_eq!(source, Some(v0)); - } - - #[test] - fn test_trivial_phi_all_same_operands() { - let ssa = SsaFunction::new(0, 0); - let analyzer = PhiAnalyzer::new(&ssa); - - let result = SsaVarId::from_index(0); - let v0 = SsaVarId::from_index(1); - let mut phi = PhiNode::new(result, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(v0, 0)); - phi.add_operand(PhiOperand::new(v0, 1)); - phi.add_operand(PhiOperand::new(v0, 2)); - - let source = analyzer.is_trivial(&phi); - assert_eq!(source, Some(v0)); - } - - #[test] - fn test_trivial_phi_with_self_references() { - let ssa = SsaFunction::new(0, 0); - let analyzer = PhiAnalyzer::new(&ssa); - - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let mut phi = PhiNode::new(v1, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(v0, 0)); // Non-self - phi.add_operand(PhiOperand::new(v1, 1)); // Self-reference - phi.add_operand(PhiOperand::new(v1, 2)); // Self-reference - - let source = analyzer.is_trivial(&phi); - assert_eq!(source, Some(v0)); - } - - #[test] - fn test_non_trivial_phi_different_operands() { - let ssa = SsaFunction::new(0, 0); - let analyzer = PhiAnalyzer::new(&ssa); - - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let v5 = SsaVarId::from_index(2); - let mut phi = PhiNode::new(v5, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(v0, 0)); - phi.add_operand(PhiOperand::new(v1, 1)); - - let source = analyzer.is_trivial(&phi); - assert_eq!(source, None); - } - - #[test] - fn test_trivial_phi_all_self_references() { - let ssa = SsaFunction::new(0, 0); - let analyzer = PhiAnalyzer::new(&ssa); - - let v1 = SsaVarId::from_index(0); - let mut phi = PhiNode::new(v1, VariableOrigin::Local(0)); - phi.add_operand(PhiOperand::new(v1, 0)); // Self - phi.add_operand(PhiOperand::new(v1, 1)); // Self - - let source = analyzer.is_trivial(&phi); - // All self-references means no unique source - assert_eq!(source, None); - } - - #[test] - fn test_propagate_single_copy() { - let (mut ssa, v0, _v1) = { - let mut v0_out = SsaVarId::from_index(0); - let mut v1_out = SsaVarId::from_index(1); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(42); - let v1 = b.copy(v0); - let _v2 = b.add(v1, v1); - v0_out = v0; - v1_out = v1; - b.ret_val(v1); - }); - }) - .unwrap(); - (ssa, v0_out, v1_out) - }; - - // Run pass - let pass = CopyPropagationPass::new(15); - let ctx = test_context(); - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - - // Should have propagated v1 → v0 - assert!(changed); - - // Verify: add should now use v0 - let block = ssa.block(0).unwrap(); - let add_instr = &block.instructions()[2]; - if let SsaOp::Add { left, right, .. } = add_instr.op() { - assert_eq!(*left, v0); - assert_eq!(*right, v0); - } else { - panic!("Expected Add instruction"); + for (dest, src) in resolved { + if dest == src { + continue; } - // Verify: return should now use v0 - let ret_instr = &block.instructions()[3]; - if let SsaOp::Return { value } = ret_instr.op() { - assert_eq!(*value, Some(v0)); - } else { - panic!("Expected Return instruction"); - } - } - - #[test] - fn test_propagate_copy_chain() { - let (mut ssa, v0) = { - let mut v0_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(42); - let v1 = b.copy(v0); - let v2 = b.copy(v1); - let v3 = b.copy(v2); - v0_out = v0; - b.ret_val(v3); - }); - }) - .unwrap(); - (ssa, v0_out) + let Some(dest_var) = ssa.variable(*dest) else { + continue; }; - - // Run pass - let pass = CopyPropagationPass::new(15); - let ctx = test_context(); - let _changes = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - - // Verify: return should now use v0 (ultimate source) - let block = ssa.block(0).unwrap(); - if let Some(SsaOp::Return { value }) = block.terminator_op() { - assert_eq!(*value, Some(v0)); - } else { - panic!("Expected Return instruction"); - } - } - - #[test] - fn test_propagate_trivial_phi() { - let (mut ssa, v0) = { - let mut v0_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - // Block 0: entry - f.block(0, |b| { - v0_out = b.const_i32(42); - let cond = b.const_true(); - b.branch(cond, 1, 2); - }); - // Block 1 - f.block(1, |b| b.jump(3)); - // Block 2 - f.block(2, |b| b.jump(3)); - // Block 3: trivial phi (both operands are v0) - f.block(3, |b| { - let phi_result = b.phi(&[(1, v0_out), (2, v0_out)]); - // Use phi result - let _ = b.add(phi_result, phi_result); - b.ret_val(phi_result); - }); - }) - .unwrap(); - (ssa, v0_out) + let VariableOrigin::Local(local_idx) = dest_var.origin() else { + continue; }; - // Run pass - let pass = CopyPropagationPass::new(15); - let ctx = test_context(); - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - - assert!(changed); - - // Verify: uses of phi should be replaced with v0 - let block3 = ssa.block(3).unwrap(); - let add_instr = &block3.instructions()[0]; - if let SsaOp::Add { left, right, .. } = add_instr.op() { - assert_eq!(*left, v0); - assert_eq!(*right, v0); - } - } - - #[test] - fn test_no_propagation_needed() { - let mut ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(42); - b.ret_val(v0); // no copies - }); - }) - .unwrap(); - - // Run pass - let pass = CopyPropagationPass::new(15); - let ctx = test_context(); - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - - // No copies, no changes - assert!(!changed); - } - - #[test] - fn test_iterative_convergence() { - // Test that the pass converges even with complex copy patterns - let (mut ssa, v0) = { - let mut v0_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(42); - v0_out = v0; - // Create a chain: v1 = v0, v2 = v1, v3 = v2 - let v1 = b.copy(v0); - let v2 = b.copy(v1); - let v3 = b.copy(v2); - // Use all copies - let v10 = b.add(v1, v2); - let v11 = b.add(v10, v3); - b.ret_val(v11); - }); - }) - .unwrap(); - (ssa, v0_out) + let local_type = match original_types.get(local_idx as usize) { + Some(sig) => &sig.base, + None => continue, }; - // Run pass - let pass = CopyPropagationPass::new(15); - let ctx = test_context(); - let result = - pass.run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()); - - // Should complete without error (convergence) - assert!(result.is_ok()); + let ssa_type = SsaType::from_type_signature(local_type, assembly); - // Verify all uses point to v0 - let block = ssa.block(0).unwrap(); - - // Check first add: should be add v0, v0 - let add1 = &block.instructions()[4]; - if let SsaOp::Add { left, right, .. } = add1.op() { - assert_eq!(*left, v0); - assert_eq!(*right, v0); + if ssa_type.is_unknown() || matches!(ssa_type, SsaType::I32) { + continue; } - // Check second add: right should be v0 - let add2 = &block.instructions()[5]; - if let SsaOp::Add { right, .. } = add2.op() { - assert_eq!(*right, v0); - } - } - - #[test] - fn test_copy_not_propagated_to_definition() { - // Ensure we don't replace the copy's own definition - let (mut ssa, v0, v1) = { - let mut v0_out = SsaVarId::from_index(0); - let mut v1_out = SsaVarId::from_index(1); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(42); - v0_out = v0; - let v1 = b.copy(v0); - v1_out = v1; - b.ret_val(v1); - }); - }) - .unwrap(); - (ssa, v0_out, v1_out) + let should_propagate = match ssa.variable(*src) { + Some(src_var) => src_var.var_type().is_unknown(), + None => false, }; - // Run pass - let pass = CopyPropagationPass::new(15); - let ctx = test_context(); - let _changes = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - - // The copy instruction itself should remain unchanged (dest is still v1) - let block = ssa.block(0).unwrap(); - let copy_instr = &block.instructions()[1]; - if let SsaOp::Copy { dest, src } = copy_instr.op() { - assert_eq!(*dest, v1); - assert_eq!(*src, v0); + if should_propagate { + type_assignments.push((*src, ssa_type)); } } - #[test] - fn test_phi_operands_preserved() { - // Test that copy propagation does NOT replace PHI operands. - // This is intentional: replacing PHI operands can create cross-origin - // references that break rebuild_ssa's assumption that each variable - // flows to at most one PHI origin. - let (mut ssa, _v0, v1, v2) = { - let mut v0_out = SsaVarId::from_index(0); - let mut v1_out = SsaVarId::from_index(1); - let mut v2_out = SsaVarId::from_index(2); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - // Block 0: entry - f.block(0, |b| { - let v0 = b.const_i32(42); - v0_out = v0; - let v1 = b.copy(v0); - v1_out = v1; - let cond = b.const_true(); - b.branch(cond, 1, 2); - }); - // Block 1: defines v2 - f.block(1, |b| { - v2_out = b.const_i32(100); - b.jump(3); - }); - // Block 2: just jumps - f.block(2, |b| b.jump(3)); - // Block 3: phi using v1 (copy of v0) and v2 - f.block(3, |b| { - let phi_result = b.phi(&[(1, v2_out), (2, v1_out)]); - b.ret_val(phi_result); - }); - }) - .unwrap(); - (ssa, v0_out, v1_out, v2_out) - }; - - // Run pass - let pass = CopyPropagationPass::new(15); - let ctx = test_context(); - let _changes = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - - // Verify: phi operands should be PRESERVED (not replaced) - // v1 should remain in the PHI, not be replaced with v0 - let block3 = ssa.block(3).unwrap(); - let phi = &block3.phi_nodes()[0]; - let operand_values: Vec<_> = phi.operands().iter().map(|op| op.value()).collect(); - - // One operand should be v2, the other should still be v1 (preserved) - assert!(operand_values.contains(&v2)); - assert!(operand_values.contains(&v1)); // v1 is preserved, not replaced with v0 + for (var_id, ssa_type) in type_assignments { + if let Some(var) = ssa.variable_mut(var_id) { + var.set_type(ssa_type); + } } } diff --git a/dotscope/src/compiler/passes/deadcode.rs b/dotscope/src/compiler/passes/deadcode.rs index f0793ed0..2ec6cc08 100644 --- a/dotscope/src/compiler/passes/deadcode.rs +++ b/dotscope/src/compiler/passes/deadcode.rs @@ -1,1073 +1,31 @@ -//! Dead code elimination pass. +//! Global dead-method elimination — CIL adapter. //! -//! This pass performs comprehensive dead code elimination including: -//! -//! 1. **Unreachable block elimination**: Remove blocks that cannot be reached -//! 2. **Dead instruction elimination**: Remove instructions whose results are unused -//! 3. **Trivial phi elimination**: Remove phi nodes with only one unique operand -//! 4. **Dead phi elimination**: Remove phi nodes whose results are never used -//! 5. **Phi operand pruning**: Remove stale operands from unreachable predecessors -//! 6. **Self-referential phi simplification**: Simplify phis like `v1 = phi(v1, v2)` to `v1 = v2` -//! -//! # Algorithm -//! -//! The pass uses an iterative worklist algorithm: -//! 1. Mark entry block and exception handlers as roots -//! 2. Compute reachable blocks via control flow traversal -//! 3. Prune phi operands from unreachable predecessors -//! 4. Compute live variables via reverse dataflow -//! 5. Remove dead definitions and trivial phis -//! 6. Repeat until no changes (fixed point) -//! -//! # Prerequisites -//! -//! This pass works best after constant propagation and branch simplification, -//! as those passes may expose more dead code. +//! Per-method DCE uses analyssa's blanket +//! [`analyssa::passes::DeadCodeEliminationPass`] directly. This file keeps +//! [`DeadMethodEliminationPass`], which uses a CIL-specific [`CtxWorld`] +//! adapter that combines the SSA-derived call graph +//! (`ctx.build_ssa_call_graph()`) with the static call graph +//! (`ctx.call_graph`) so the reachability walk sees both. The combined +//! view is richer than what `World` on `CompilerContext` +//! provides today (which only reads the static call graph), so the +//! adapter stays. + +use std::collections::{BTreeMap, BTreeSet}; -use std::collections::{BTreeMap, BTreeSet, VecDeque}; +use analyssa::{passes::deadcode, World}; use crate::{ - analysis::{ - PhiAnalyzer, PhiNode, SsaCfg, SsaFunction, SsaInstruction, SsaOp, SsaVarId, VariableOrigin, - }, - compiler::{ - pass::{ModificationScope, SsaPass}, - CompilerContext, EventKind, EventLog, - }, + analysis::{CilTarget, MethodRef, SsaFunction}, + compiler::{pass::SsaPass, CompilerContext}, metadata::token::Token, - utils::{ - graph::{algorithms, NodeId}, - BitSet, - }, - CilObject, Result, }; -/// Finds blocks that have dead code after terminator instructions. -/// -/// Identifies blocks where a return, throw, or other terminator is followed -/// by unreachable instructions that should be removed. This can happen when -/// control flow simplification leaves behind unreachable code. -/// -/// # Arguments -/// -/// * `ssa` - The SSA function to analyze. -/// -/// # Returns -/// -/// A vector of (block index, first dead instruction index) pairs. -#[must_use] -pub fn find_dead_tails(ssa: &SsaFunction) -> Vec<(usize, usize)> { - ssa.iter_blocks() - .filter_map(|(block_idx, block)| { - // Find first terminator - let last_idx = block.instruction_count().checked_sub(1)?; - for (instr_idx, instr) in block.instructions().iter().enumerate() { - if instr.op().is_terminator() && instr_idx < last_idx { - // There are instructions after the terminator - return Some((block_idx, instr_idx.saturating_add(1))); - } - } - None - }) - .collect() -} - -/// Dead code elimination pass. -/// -/// Removes unreachable blocks and unused definitions to simplify the SSA graph. -/// Uses an iterative algorithm to handle cascading dead code. -pub struct DeadCodeEliminationPass { - /// Maximum fixpoint iterations before stopping. - max_iterations: usize, -} - -impl DeadCodeEliminationPass { - /// Creates a new dead code elimination pass. - /// - /// # Arguments - /// - /// * `max_iterations` - Maximum fixpoint iterations before stopping. Typical - /// convergence is 2–3 iterations; the default config value is 20. - #[must_use] - pub fn new(max_iterations: usize) -> Self { - Self { max_iterations } - } - - /// Finds all reachable blocks starting from entry and exception handlers. - /// - /// Uses the graph infrastructure's BFS traversal to find blocks reachable from: - /// - Block 0 (the entry block) - /// - Exception handler blocks from SSA exception handler info - /// - Fallback: blocks starting with `EndFinally` or `Rethrow` instructions - /// - /// Exception handlers are treated as additional roots since they may be - /// reachable via implicit exception edges that are not explicitly represented - /// in the SSA graph. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to analyze. - /// - /// # Returns - /// - /// A set of block indices that are reachable from the entry point or exception handlers. - fn find_reachable_blocks(ssa: &SsaFunction) -> BitSet { - if ssa.block_count() == 0 { - return BitSet::new(0); - } - - // Build CFG view for graph traversal - let cfg = SsaCfg::from_ssa(ssa); - - // Use BFS from entry block (block 0) to find reachable nodes - let mut reachable = BitSet::new(ssa.block_count()); - for n in algorithms::bfs(&cfg, NodeId::new(0)) { - let n: NodeId = n; - reachable.insert(n.index()); - } - - // Collect exception handler entry blocks from SSA exception handler info - let mut exception_roots = BitSet::new(ssa.block_count()); - for handler in ssa.exception_handlers() { - // Add handler start block as a root - if let Some(handler_block) = handler.handler_start_block { - if !reachable.contains(handler_block) { - exception_roots.insert(handler_block); - } - } - // Add filter start block for FILTER handlers - if let Some(filter_block) = handler.filter_start_block { - if !reachable.contains(filter_block) { - exception_roots.insert(filter_block); - } - } - } - - // Fallback: find exception handler blocks by instruction patterns - // (for methods where exception handler info wasn't preserved) - for (block_idx, block) in ssa.iter_blocks() { - if reachable.contains(block_idx) || exception_roots.contains(block_idx) { - continue; - } - // Check if this block starts with exception handling instructions - if let Some(first_instr) = block.instructions().first() { - // EndFinally and Rethrow indicate exception handler blocks - if matches!(first_instr.op(), SsaOp::EndFinally | SsaOp::Rethrow) { - exception_roots.insert(block_idx); - } - } - } - - // Traverse from each exception handler root - for root in exception_roots.iter() { - for node in algorithms::bfs(&cfg, NodeId::new(root)) { - let node: NodeId = node; - reachable.insert(node.index()); - } - } - - reachable - } - - /// Computes reverse post-order of reachable blocks for efficient dataflow traversal. - /// - /// Reverse post-order (RPO) is an ordering where each block appears before its - /// successors (except for back edges in loops). This ordering is optimal for - /// forward dataflow analysis as it minimizes the number of iterations needed - /// to reach a fixed point. - /// - /// Uses the graph infrastructure's `reverse_postorder` algorithm. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function containing the blocks. - /// * `reachable` - The set of reachable block indices to include in the ordering. - /// - /// # Returns - /// - /// A vector of block indices in reverse post-order. The entry block appears first, - /// and exit blocks appear last. - fn compute_reverse_postorder(ssa: &SsaFunction, reachable: &BitSet) -> Vec { - if ssa.block_count() == 0 || reachable.is_empty() { - return Vec::new(); - } - - // Build CFG view for graph traversal - let cfg = SsaCfg::from_ssa(ssa); - - // Use the graph infrastructure's reverse_postorder from entry block - let mut rpo: Vec = algorithms::reverse_postorder(&cfg, NodeId::new(0)) - .into_iter() - .map(|n: NodeId| n.index()) - .filter(|idx| reachable.contains(*idx)) - .collect(); - - // Handle any remaining reachable blocks (exception handlers) not covered - // by traversal from entry. Add them in sorted order for determinism. - let mut in_rpo = BitSet::new(ssa.block_count()); - for &idx in &rpo { - in_rpo.insert(idx); - } - let mut additional: Vec = reachable - .iter() - .filter(|idx| !in_rpo.contains(*idx)) - .collect(); - additional.sort_unstable(); - - // For exception handlers, compute RPO from each root - for &root in &additional { - let handler_rpo: Vec = algorithms::reverse_postorder(&cfg, NodeId::new(root)) - .into_iter() - .map(|n: NodeId| n.index()) - .filter(|idx| reachable.contains(*idx) && !rpo.contains(idx)) - .collect(); - rpo.extend(handler_rpo); - } - - rpo - } - - /// Computes the set of live variables using reverse dataflow analysis. - /// - /// A variable is considered live if any of the following conditions hold: - /// 1. It's used by a side-effectful instruction (calls, stores, etc.) - /// 2. It's used as a return value - /// 3. It's used as a thrown exception value - /// 4. It's transitively used by another live variable's definition - /// - /// The algorithm uses a two-phase approach: - /// 1. Mark initial live variables (roots) from side-effectful uses - /// 2. Propagate liveness backwards through the def-use chain - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to analyze. - /// * `reachable` - The set of reachable block indices. - /// * `rpo` - The blocks in reverse post-order for efficient traversal. - /// - /// # Returns - /// - /// A set of SSA variable IDs that are live (used by observable operations). - fn compute_live_variables(ssa: &SsaFunction, reachable: &BitSet, rpo: &[usize]) -> BitSet { - let mut live = BitSet::new(ssa.var_id_capacity()); - let mut worklist = VecDeque::new(); - - // Phase 1: Mark variables used by side-effectful operations as live - for &block_idx in rpo { - if !reachable.contains(block_idx) { - continue; - } - - if let Some(block) = ssa.block(block_idx) { - for instr in block.instructions() { - let op = instr.op(); - // Instructions with side effects make their operands live - if !op.is_pure() { - for var in op.uses() { - if live.insert(var.index()) { - worklist.push_back(var); - } - } - } - - // Return values are live - if let SsaOp::Return { value: Some(v) } = op { - if live.insert(v.index()) { - worklist.push_back(*v); - } - } - - // Thrown exceptions are live - if let SsaOp::Throw { exception } = op { - if live.insert(exception.index()) { - worklist.push_back(*exception); - } - } - } - } - } - - // Phase 2: Propagate liveness backwards through definitions - // Build def-to-uses map for efficiency - let mut def_uses: BTreeMap> = BTreeMap::new(); - - // Collect all definitions per Local/Arg origin, and all LoadLocal/LoadArg - // instructions. LoadLocal/LoadArg reference locals/args by index (not SSA - // variable), creating an implicit dependency that needs bridging. - let mut origin_defs: BTreeMap> = BTreeMap::new(); - let mut load_local_info: Vec<(SsaVarId, VariableOrigin)> = Vec::new(); - - for &block_idx in rpo { - if !reachable.contains(block_idx) { - continue; - } - - if let Some(block) = ssa.block(block_idx) { - // Phi nodes: build def_uses and collect Local/Arg defs - for phi in block.phi_nodes() { - let def = phi.result(); - for operand in phi.operands() { - def_uses.entry(def).or_default().push(operand.value()); - } - let origin = phi.origin(); - if matches!( - origin, - VariableOrigin::Local(_) | VariableOrigin::Argument(_) - ) { - origin_defs.entry(origin).or_default().push(def); - } - } - - // Instructions: build def_uses, collect LoadLocal/LoadArg, track defs - for instr in block.instructions() { - let op = instr.op(); - if let Some(def) = op.dest() { - for use_var in op.uses() { - def_uses.entry(def).or_default().push(use_var); - } - - // Track Local/Arg-origin definitions (from stloc/starg Copies) - if let Some(var) = ssa.variable(def) { - let origin = var.origin(); - if matches!( - origin, - VariableOrigin::Local(_) | VariableOrigin::Argument(_) - ) { - origin_defs.entry(origin).or_default().push(def); - } - } - } - - match op { - SsaOp::LoadLocal { dest, local_index } => { - load_local_info.push((*dest, VariableOrigin::Local(*local_index))); - } - SsaOp::LoadLocalAddr { dest, local_index } => { - load_local_info.push((*dest, VariableOrigin::Local(*local_index))); - } - SsaOp::LoadArg { dest, arg_index } => { - load_local_info.push((*dest, VariableOrigin::Argument(*arg_index))); - } - _ => {} - } - } - } - } - - // Worklist algorithm: if a variable is live, its defining uses are live - while let Some(var) = worklist.pop_front() { - if let Some(uses) = def_uses.get(&var) { - for &use_var in uses { - if live.insert(use_var.index()) { - worklist.push_back(use_var); - } - } - } - } - - // Phase 3: Bridge the LoadLocal/LoadArg gap. - // For each live LoadLocal/LoadArg dest, mark ALL definitions of the - // corresponding Local(K)/Arg(K) origin as live. This is conservative - // (keeps all versions alive) but correct — the specific reaching def - // would require dominator tree traversal. - // - // This must loop because re-propagation can make new LoadLocal/LoadArg - // dests live (e.g., v6=Copy(v5) makes v5 live, and v5=LoadLocal(0) - // needs to bridge to defs of Local(0)). - loop { - let mut newly_live = false; - for (dest, origin) in &load_local_info { - if !live.contains(dest.index()) { - continue; - } - if let Some(defs) = origin_defs.get(origin) { - for &def_var in defs { - if live.insert(def_var.index()) { - worklist.push_back(def_var); - newly_live = true; - } - } - } - } - - // Re-propagate liveness from newly-live variables - while let Some(var) = worklist.pop_front() { - if let Some(uses) = def_uses.get(&var) { - for &use_var in uses { - if live.insert(use_var.index()) { - worklist.push_back(use_var); - } - } - } - } - - if !newly_live { - break; - } - } - - live - } - - /// Finds dead definitions (pure instructions whose results are never used). - /// - /// An instruction is dead if: - /// 1. It defines a variable that is not in the live set - /// 2. It has no side effects (is pure) - /// - /// Instructions with side effects (calls, stores, etc.) are never considered - /// dead, even if their result is unused, because removing them would change - /// program behavior. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to analyze. - /// * `reachable` - The set of reachable block indices. - /// * `live` - The set of live variable IDs. - /// - /// # Returns - /// - /// A vector of `(block_idx, instruction_idx)` tuples identifying dead instructions. - fn find_dead_definitions( - ssa: &SsaFunction, - reachable: &BitSet, - live: &BitSet, - dead_phi_results: &BitSet, - ) -> Vec<(usize, usize)> { - // Track dead variables for Pop elimination - let mut dead_vars = BitSet::new(ssa.var_id_capacity()); - let mut dead = Vec::new(); - - for block_idx in reachable.iter() { - if let Some(block) = ssa.block(block_idx) { - for (instr_idx, instr) in block.instructions().iter().enumerate() { - let op = instr.op(); - // Skip instructions with side effects - if !op.is_pure() { - continue; - } - - // Skip Pop in first pass - handled below - if matches!(op, SsaOp::Pop { .. }) { - continue; - } - - match op.dest() { - None => { - // Pure instruction with no dest (like Nop) is always dead - dead.push((block_idx, instr_idx)); - } - Some(def) => { - if !live.contains(def.index()) { - dead.push((block_idx, instr_idx)); - dead_vars.insert(def.index()); - } - } - } - } - } - } - - // Second pass: find dead Pop instructions - // Pop is dead if its operand's definer is being removed in this iteration. - // Note: We intentionally don't check for definers Nop'd in previous iterations - // because that can cause stack depth mismatches with complex control flow. - // The basic check (same-iteration removal) handles the common case correctly. - for block_idx in reachable.iter() { - if let Some(block) = ssa.block(block_idx) { - for (instr_idx, instr) in block.instructions().iter().enumerate() { - if let SsaOp::Pop { value } = instr.op() { - // Check if operand's definer is being removed this iteration - let instr_definer_being_removed = dead_vars.contains(value.index()); - let phi_definer_being_removed = dead_phi_results.contains(value.index()); - - if instr_definer_being_removed || phi_definer_being_removed { - dead.push((block_idx, instr_idx)); - } - } - } - } - } - - dead - } - - /// Finds dead phi nodes (phi nodes whose results are never used). - /// - /// A phi node is dead if its result variable is not in the live set. - /// Unlike regular instructions, phi nodes never have side effects, - /// so they can always be removed if their result is unused. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to analyze. - /// * `reachable` - The set of reachable block indices. - /// * `live` - The set of live variable IDs. - /// - /// # Returns - /// - /// A vector of `(block_idx, phi_idx)` tuples identifying dead phi nodes. - fn find_dead_phis(ssa: &SsaFunction, reachable: &BitSet, live: &BitSet) -> Vec<(usize, usize)> { - let mut dead = Vec::new(); - - for block_idx in reachable.iter() { - if let Some(block) = ssa.block(block_idx) { - for (phi_idx, phi) in block.phi_nodes().iter().enumerate() { - if !live.contains(phi.result().index()) { - dead.push((block_idx, phi_idx)); - } - } - } - } - - dead - } - - /// Removes dead instructions by replacing them with `Nop` operations. - /// - /// Instructions are processed in reverse order within each block to preserve - /// indices during removal. Rather than actually removing instructions (which - /// would shift indices), this function replaces them with `Nop` operations - /// to maintain the instruction array structure. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to modify. - /// * `dead_defs` - A slice of `(block_idx, instruction_idx)` tuples identifying - /// instructions to remove. - /// * `method_token` - The metadata token of the method, used for change tracking. - /// * `changes` - The change set to record modifications in. - fn remove_instructions( - ssa: &mut SsaFunction, - dead_defs: &[(usize, usize)], - method_token: Token, - changes: &mut EventLog, - ) { - // Group by block - let mut by_block: BTreeMap> = BTreeMap::new(); - for &(block_idx, instr_idx) in dead_defs { - by_block.entry(block_idx).or_default().push(instr_idx); - } - - for (block_idx, mut indices) in by_block { - // Sort in reverse order to remove from end first - indices.sort_by(|a, b| b.cmp(a)); - - if let Some(block) = ssa.block_mut(block_idx) { - for instr_idx in indices { - if instr_idx < block.instructions().len() { - // Replace with Nop instead of removing to preserve indices - // during the same iteration - if let Some(instr) = block.instructions_mut().get_mut(instr_idx) { - // Log with appropriate message based on instruction type - let message = if let Some(dest) = instr.op().dest() { - format!("dead definition {dest}") - } else { - format!("dead {}", instr.mnemonic()) - }; - instr.set_op(SsaOp::Nop); - let location = block_idx.saturating_mul(1000).saturating_add(instr_idx); - changes - .record(EventKind::InstructionRemoved) - .at(method_token, location) - .message(message); - } - } - } - } - } - } - - /// Removes dead phi nodes from their blocks. - /// - /// Phi nodes are processed in reverse order within each block to preserve - /// indices during removal. Unlike instructions, phi nodes are actually - /// removed from the block rather than replaced with a placeholder. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to modify. - /// * `dead_phis` - A slice of `(block_idx, phi_idx)` tuples identifying phi nodes - /// to remove. - /// * `method_token` - The metadata token of the method, used for change tracking. - /// * `changes` - The change set to record modifications in. - fn remove_phis( - ssa: &mut SsaFunction, - dead_phis: &[(usize, usize)], - method_token: Token, - changes: &mut EventLog, - ) { - // Group by block - let mut by_block: BTreeMap> = BTreeMap::new(); - for &(block_idx, phi_idx) in dead_phis { - by_block.entry(block_idx).or_default().push(phi_idx); - } - - for (block_idx, mut indices) in by_block { - // Sort in reverse order - indices.sort_by(|a, b| b.cmp(a)); - - if let Some(block) = ssa.block_mut(block_idx) { - for phi_idx in indices { - if phi_idx < block.phi_nodes().len() { - block.phi_nodes_mut().remove(phi_idx); - changes - .record(EventKind::PhiSimplified) - .at(method_token, block_idx) - .message("removed dead phi node"); - } - } - } - } - } - - /// Simplifies trivial phi nodes by performing copy propagation. - /// - /// For each trivial phi identified by [`PhiAnalyzer::find_all_trivial`]: - /// - If a replacement value is provided, the phi is converted to a copy and - /// all uses of the phi's result are replaced with the replacement value. - /// - If no replacement value is provided (fully self-referential phi), the - /// phi is simply removed as it represents undefined/unreachable code. - /// - /// Phi nodes are processed in reverse order within each block to preserve - /// indices during modification. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to modify. - /// * `trivial_phis` - A slice of `(block_idx, phi_idx, replacement)` tuples - /// from [`PhiAnalyzer::find_all_trivial`]. - /// * `method_token` - The metadata token of the method, used for change tracking. - /// * `changes` - The change set to record modifications in. - /// - /// # Returns - /// - /// The number of phi nodes that were simplified. - fn simplify_trivial_phis( - ssa: &mut SsaFunction, - trivial_phis: &[(usize, usize, Option)], - method_token: Token, - changes: &mut EventLog, - ) -> usize { - let mut simplified: usize = 0; - - // Process in reverse order by phi_idx within each block - let mut by_block: BTreeMap)>> = BTreeMap::new(); - for &(block_idx, phi_idx, replacement) in trivial_phis { - by_block - .entry(block_idx) - .or_default() - .push((phi_idx, replacement)); - } - - for (block_idx, mut phis) in by_block { - // Sort by phi_idx in reverse order - phis.sort_by_key(|p| std::cmp::Reverse(p.0)); - - for (phi_idx, replacement) in phis { - if let Some(replacement_var) = replacement { - // Use the existing simplify_phi_to_copy which handles use replacement - if ssa.simplify_phi_to_copy(block_idx, phi_idx, replacement_var) { - changes - .record(EventKind::PhiSimplified) - .at(method_token, block_idx) - .message(format!("replaced with {replacement_var}")); - simplified = simplified.saturating_add(1); - } - } else { - // All self-references - just remove the phi - if ssa.remove_phi_unchecked(block_idx, phi_idx) { - changes - .record(EventKind::PhiSimplified) - .at(method_token, block_idx) - .message("removed self-referential phi"); - simplified = simplified.saturating_add(1); - } - } - } - } - - simplified - } - - /// Clears all instructions and phi nodes from unreachable blocks. - /// - /// Unreachable blocks are emptied rather than removed to preserve block indices - /// throughout the SSA graph. This is important because branch targets and phi - /// operand predecessors reference blocks by index. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to modify. - /// * `reachable` - The set of reachable block indices. - /// * `method_token` - The metadata token of the method, used for change tracking. - /// * `changes` - The change set to record modifications in. - /// - /// # Returns - /// - /// The number of blocks that were cleared. - fn clear_unreachable_blocks( - ssa: &mut SsaFunction, - reachable: &BitSet, - method_token: Token, - changes: &mut EventLog, - ) -> usize { - let mut cleared: usize = 0; - let total_blocks = ssa.block_count(); - - for block_idx in 0..total_blocks { - if !reachable.contains(block_idx) { - if let Some(block) = ssa.block_mut(block_idx) { - if !block.is_empty() { - block.clear(); - changes - .record(EventKind::BlockRemoved) - .at(method_token, block_idx) - .message(format!("removed unreachable block {block_idx}")); - cleared = cleared.saturating_add(1); - } - } - } - } - - cleared - } - - /// Finds instructions without SSA operations (stack simulation artifacts). - /// - /// During SSA construction, some CIL instructions (like `ldloc`, `ldarg`) don't - /// create new SSA definitions - they just read existing variables. These instructions - /// remain in the instruction list with `op = None` but serve no purpose in SSA form. - /// - /// This function identifies such instructions for removal. Only non-terminator - /// instructions without an SSA operation are considered dead. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to analyze. - /// * `reachable` - The set of reachable block indices. - /// - /// # Returns - /// - /// A vector of `(block_idx, instruction_idx)` tuples identifying op-less instructions. - fn find_opless_instructions(ssa: &SsaFunction, reachable: &BitSet) -> Vec<(usize, usize)> { - let mut opless = Vec::new(); - - for block_idx in reachable.iter() { - if let Some(block) = ssa.block(block_idx) { - let instr_count = block.instructions().len(); - for (instr_idx, instr) in block.instructions().iter().enumerate() { - // Skip the last instruction if it might be a terminator - // (terminators should always have an op, but be safe) - let is_last = instr_idx == instr_count.saturating_sub(1); - - // An instruction with Nop op is a stack simulation artifact - if matches!(instr.op(), SsaOp::Nop) { - // Don't remove the last instruction if the block would become empty - // (this preserves block structure for terminators) - if !is_last || instr_count > 1 { - opless.push((block_idx, instr_idx)); - } - } - } - } - } - - opless - } - - /// Removes op-less instructions (stack simulation artifacts). - /// - /// These instructions have no SSA operation and serve no purpose after - /// SSA construction. They are removed by actually deleting them from the - /// instruction list rather than replacing with Nop. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to modify. - /// * `opless` - A slice of `(block_idx, instruction_idx)` tuples. - /// * `method_token` - The metadata token of the method. - /// * `changes` - The change set to record modifications. - /// - /// # Returns - /// - /// The number of instructions removed. - fn remove_opless_instructions( - ssa: &mut SsaFunction, - opless: &[(usize, usize)], - method_token: Token, - changes: &mut EventLog, - ) -> usize { - if opless.is_empty() { - return 0; - } - - // Group by block - let mut by_block: BTreeMap> = BTreeMap::new(); - for &(block_idx, instr_idx) in opless { - by_block.entry(block_idx).or_default().push(instr_idx); - } - - let mut removed: usize = 0; - - for (block_idx, mut indices) in by_block { - // Sort in reverse order to remove from end first (preserves indices) - indices.sort_by(|a, b| b.cmp(a)); - - if let Some(block) = ssa.block_mut(block_idx) { - for instr_idx in indices { - if instr_idx < block.instructions().len() { - // Get mnemonic for logging before removal - let mnemonic = block - .instructions() - .get(instr_idx) - .map_or("unknown", SsaInstruction::mnemonic); - - block.instructions_mut().remove(instr_idx); - let location = block_idx.saturating_mul(1000).saturating_add(instr_idx); - changes - .record(EventKind::InstructionRemoved) - .at(method_token, location) - .message(format!("removed op-less instruction: {mnemonic}")); - removed = removed.saturating_add(1); - } - } - } - } - - removed - } - - /// Removes all Nop instructions from reachable blocks. - /// - /// Nop instructions are dead code that should be removed to simplify - /// the CFG. This is done before block merging so trampolines can be - /// properly detected. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to modify. - /// * `reachable` - Set of reachable block indices. - /// * `method_token` - The method token for change tracking. - /// * `changes` - Event log for recording changes. - /// - /// # Returns - /// - /// The number of Nop instructions removed. - fn remove_nop_instructions( - ssa: &mut SsaFunction, - reachable: &BitSet, - method_token: Token, - changes: &mut EventLog, - ) -> usize { - let mut removed: usize = 0; - - for block_idx in reachable.iter() { - if let Some(block) = ssa.block_mut(block_idx) { - let original_len = block.instructions().len(); - block - .instructions_mut() - .retain(|instr| !matches!(instr.op(), SsaOp::Nop)); - let new_len = block.instructions().len(); - let nops_removed = original_len.saturating_sub(new_len); - - if nops_removed > 0 { - changes - .record(EventKind::InstructionRemoved) - .at(method_token, block_idx) - .message(format!("removed {nops_removed} Nop instructions")); - removed = removed.saturating_add(nops_removed); - } - } - } - - removed - } - - /// Runs a single iteration of the dead code elimination algorithm. - /// - /// Each iteration performs the following steps: - /// 1. Find reachable blocks from entry and exception handlers - /// 2. Clear unreachable blocks - /// 3. Remove op-less instructions (stack simulation artifacts) - /// 4. Remove Nop instructions - /// 5. Prune phi operands from unreachable predecessors - /// 6. Find and simplify trivial phi nodes - /// 7. Recompute reachability (may change after phi simplification) - /// 8. Compute liveness via reverse dataflow - /// 9. Remove dead phi nodes - /// 10. Remove dead definitions - /// - /// The algorithm is run iteratively until no more changes are made (fixed point). - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to modify. - /// * `method_token` - The metadata token of the method, used for change tracking. - /// * `changes` - The change set to record modifications in. - /// - /// # Returns - /// - /// The total number of changes made during this iteration. Zero indicates - /// the algorithm has reached a fixed point. - fn run_iteration(ssa: &mut SsaFunction, method_token: Token, changes: &mut EventLog) -> usize { - let mut total_changes: usize = 0; - - // Step 1: Find reachable blocks - let reachable = Self::find_reachable_blocks(ssa); - - // Step 2: Clear unreachable blocks - total_changes = total_changes.saturating_add(Self::clear_unreachable_blocks( - ssa, - &reachable, - method_token, - changes, - )); - - // Step 3: Remove op-less instructions (stack simulation artifacts like ldloc/ldarg - // that weren't decomposed to SSA operations) - let opless = Self::find_opless_instructions(ssa, &reachable); - total_changes = total_changes.saturating_add(Self::remove_opless_instructions( - ssa, - &opless, - method_token, - changes, - )); - - // Step 4: Remove Nop instructions (simplifies CFG for block merging) - total_changes = total_changes.saturating_add(Self::remove_nop_instructions( - ssa, - &reachable, - method_token, - changes, - )); - - // Step 5: Prune phi operands from unreachable predecessors - total_changes = total_changes.saturating_add(ssa.prune_phi_operands(&reachable)); - let reachable_set: BTreeSet = reachable.iter().collect(); - - // Step 6: Find and simplify trivial phis (doesn't need liveness) - // Trivial phis are identified purely by structure (all operands same or self-referential) - let trivial_phis = PhiAnalyzer::new(ssa).find_all_trivial(&reachable_set); - total_changes = total_changes.saturating_add(Self::simplify_trivial_phis( - ssa, - &trivial_phis, - method_token, - changes, - )); - - // Step 7: Recompute reachability after phi simplification - let reachable = Self::find_reachable_blocks(ssa); - - // Step 8: Compute reverse post-order and liveness for dead code analysis - let rpo = Self::compute_reverse_postorder(ssa, &reachable); - let live = Self::compute_live_variables(ssa, &reachable, &rpo); - - // Step 9: Find and remove dead phi nodes (unused results) - let dead_phis = Self::find_dead_phis(ssa, &reachable, &live); - - // Collect dead phi results for Pop elimination - let mut dead_phi_results = BitSet::new(ssa.var_id_capacity()); - for &(block_idx, phi_idx) in &dead_phis { - if let Some(result) = ssa - .block(block_idx) - .and_then(|b| b.phi_nodes().get(phi_idx)) - .map(PhiNode::result) - { - dead_phi_results.insert(result.index()); - } - } - - Self::remove_phis(ssa, &dead_phis, method_token, changes); - total_changes = total_changes.saturating_add(dead_phis.len()); - - // Step 10: Find and remove dead definitions (pure ops with unused results) - let dead_defs = Self::find_dead_definitions(ssa, &reachable, &live, &dead_phi_results); - let c10 = dead_defs.len(); - Self::remove_instructions(ssa, &dead_defs, method_token, changes); - total_changes = total_changes.saturating_add(c10); - - // Step 10b: Clean up Nops created by remove_instructions (which replaces - // dead instructions with Nop to preserve indices). Without this, the next - // iteration would find these Nops as opless instructions, creating an - // oscillation: dead_def → Nop → opless → dead_def → ... - if c10 > 0 { - for block_idx in reachable.iter() { - if let Some(block) = ssa.block_mut(block_idx) { - block - .instructions_mut() - .retain(|instr| !matches!(instr.op(), SsaOp::Nop)); - } - } - } - - total_changes - } -} - -impl SsaPass for DeadCodeEliminationPass { - fn name(&self) -> &'static str { - "dead-code-elimination" - } - - fn description(&self) -> &'static str { - "Eliminates unreachable code and unused definitions" - } - - fn modification_scope(&self) -> ModificationScope { - ModificationScope::InstructionsOnly - } - - fn run_on_method( - &self, - ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - _assembly: &CilObject, - ) -> Result { - let mut changes = EventLog::new(); - - // Iterate until fixed point - for _ in 0..self.max_iterations { - if Self::run_iteration(ssa, method_token, &mut changes) == 0 { - break; - } - } - - // After removing dead instructions and phis, compact the variable table - // to remove orphaned variable entries. - let changed = !changes.is_empty(); - if changed { - ssa.compact_variables(); - ctx.events.merge(&changes); - } - - Ok(changed) - } -} - /// Global dead method elimination pass. /// -/// This pass operates at the assembly level to identify and mark methods that -/// are never called and are not entry points. Unlike [`DeadCodeEliminationPass`] -/// which operates within a single method, this pass analyzes the call graph -/// across the entire assembly. -/// -/// # Algorithm -/// -/// The pass uses a worklist algorithm starting from entry points: -/// 1. Initialize the live set with all entry point methods -/// 2. For each live method, add all its callees to the worklist -/// 3. Continue until the worklist is empty -/// 4. Mark all methods not in the live set as dead -/// -/// # Entry Points -/// -/// Entry points typically include: -/// - The `Main` method -/// - Event handlers -/// - Methods invoked via reflection -/// - Virtual method implementations that may be called polymorphically +/// Operates at the assembly level to identify methods that are never +/// called and not entry points. Delegates the reachability walk to +/// [`analyssa::passes::deadcode::run_global`] via an internal `World`-trait +/// adapter that combines SSA-derived and static call edges. pub struct DeadMethodEliminationPass; impl Default for DeadMethodEliminationPass { @@ -1078,17 +36,13 @@ impl Default for DeadMethodEliminationPass { impl DeadMethodEliminationPass { /// Creates a new dead method elimination pass. - /// - /// # Returns - /// - /// A new instance of `DeadMethodEliminationPass`. #[must_use] pub fn new() -> Self { Self } } -impl SsaPass for DeadMethodEliminationPass { +impl SsaPass for DeadMethodEliminationPass { fn name(&self) -> &'static str { "dead-method-elimination" } @@ -1108,597 +62,77 @@ impl SsaPass for DeadMethodEliminationPass { fn run_on_method( &self, _ssa: &mut SsaFunction, - _method_token: Token, - _ctx: &CompilerContext, - _assembly: &CilObject, - ) -> Result { - // This is a global pass, so run_on_method is not used + _method: &MethodRef, + _host: &CompilerContext, + ) -> analyssa::Result { + // Global pass — never invoked per-method. Ok(false) } - fn run_global(&self, ctx: &CompilerContext, _assembly: &CilObject) -> Result { - let changes = EventLog::new(); - - // Build a live call graph from actual SSA calls (not the static call graph). - // This accounts for inlining and other transformations that may have removed calls. - let ssa_callees = ctx.build_ssa_call_graph(); - - // Methods that are definitely live (entry points and their transitive callees) - let mut live_methods: BTreeSet = ctx.entry_points.iter().map(|e| *e).collect(); - let mut worklist: VecDeque = live_methods.iter().copied().collect(); - - // Mark transitive callees as live using SSA-derived call information - while let Some(method) = worklist.pop_front() { - // Use SSA callees if available, otherwise fall back to call graph - let callees = if let Some(ssa_calls) = ssa_callees.get(&method) { - ssa_calls.iter().copied().collect::>() - } else { - ctx.call_graph.callees(method) - }; - - for callee in callees { - if !live_methods.contains(&callee) { - live_methods.insert(callee); - worklist.push_back(callee); - } - } - } - - // Mark all methods not in live set as dead - for method in ctx.all_methods() { - if !live_methods.contains(&method) && !ctx.is_dead(method) { - ctx.mark_dead(method); - changes - .record(EventKind::MethodMarkedDead) - .method(method) - .message("method has no live callers"); - } - } - - let changed = !changes.is_empty(); - if changed { - ctx.events.merge(&changes); - } - Ok(changed) + fn run_global(&self, host: &CompilerContext) -> analyssa::Result { + let world = CtxWorld::new(host); + Ok(deadcode::run_global::( + &world, + &host.events, + )) } } -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use crate::{ - analysis::{ - CallGraph, ConstValue, DefSite, MethodRef, PhiNode, PhiOperand, SsaBlock, SsaFunction, - SsaFunctionBuilder, SsaInstruction, SsaOp, SsaType, SsaVarId, VariableOrigin, - }, - compiler::{passes::deadcode::DeadCodeEliminationPass, CompilerContext, SsaPass}, - metadata::token::Token, - test::helpers::test_assembly_arc, - }; - - // Helper to create a minimal analysis context for testing - fn test_context() -> CompilerContext { - let call_graph = Arc::new(CallGraph::new()); - CompilerContext::new(call_graph) - } - - #[test] - fn test_successor_extraction() { - // Test jump - let op = SsaOp::Jump { target: 5 }; - assert_eq!(op.successors(), vec![5]); - - // Test branch - let cond = SsaVarId::from_index(0); - let op = SsaOp::Branch { - condition: cond, - true_target: 1, - false_target: 2, - }; - assert_eq!(op.successors(), vec![1, 2]); - - // Test switch - let val = SsaVarId::from_index(1); - let op = SsaOp::Switch { - value: val, - targets: vec![1, 2, 3], - default: 4, - }; - assert_eq!(op.successors(), vec![1, 2, 3, 4]); - - // Test return (no successors) - let op = SsaOp::Return { value: None }; - assert!(op.successors().is_empty()); - - // Test leave - let op = SsaOp::Leave { target: 3 }; - assert_eq!(op.successors(), vec![3]); - } - - #[test] - fn test_empty_function() { - let ssa = SsaFunction::new(0, 0); - - let reachable = DeadCodeEliminationPass::find_reachable_blocks(&ssa); - assert!(reachable.is_empty()); - } - - #[test] - fn test_single_block_reachable() { - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| b.ret()); - }) - .unwrap(); - - let reachable = DeadCodeEliminationPass::find_reachable_blocks(&ssa); - assert_eq!(reachable.count(), 1); - assert!(reachable.contains(0)); - } - - #[test] - fn test_unreachable_block_detection() { - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - // Block 0: entry, jumps to block 1 - f.block(0, |b| b.jump(1)); - // Block 1: reachable from block 0 - f.block(1, |b| b.ret()); - // Block 2: unreachable - f.block(2, |b| b.ret()); - }) - .unwrap(); - - let reachable = DeadCodeEliminationPass::find_reachable_blocks(&ssa); - assert_eq!(reachable.count(), 2); - assert!(reachable.contains(0)); - assert!(reachable.contains(1)); - assert!(!reachable.contains(2)); - } - - #[test] - fn test_cascading_dead_code() { - // Test that iterative DCE removes cascading dead definitions: - // v1 = 5 (dead after v2 removed) - // v2 = v1 + 3 (dead after v3 removed) - // v3 = v2 * 2 (dead - unused) - // return - let mut ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v1 = b.const_i32(5); - let three = b.const_i32(3); - let v2 = b.add(v1, three); - let two = b.const_i32(2); - let _ = b.mul(v2, two); - b.ret(); // return (no value - nothing is live) - }); - }) - .unwrap(); - - // Run DCE - let pass = DeadCodeEliminationPass::new(20); - let ctx = test_context(); - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - - // All pure operations should be marked as dead - assert!(changed); - - // Only the return should remain as non-Nop - let block = ssa.block(0).unwrap(); - let non_nop_count = block - .instructions() - .iter() - .filter(|i| !matches!(i.op(), SsaOp::Nop)) - .count(); - - assert_eq!(non_nop_count, 1); // Only return - } - - #[test] - fn test_dead_phi_elimination() { - // Test that unused phi nodes are removed - let mut ssa = { - let mut v1_out = SsaVarId::from_index(0); - let mut v2_out = SsaVarId::from_index(1); - SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - // Block 0: entry, branch - f.block(0, |b| { - let cond = b.const_true(); - b.branch(cond, 1, 2); - }); - // Block 1: defines v1 - f.block(1, |b| { - v1_out = b.const_i32(10); - b.jump(3); - }); - // Block 2: defines v2 - f.block(2, |b| { - v2_out = b.const_i32(20); - b.jump(3); - }); - // Block 3: merge with phi (but result is unused!) - f.block(3, |b| { - let _ = b.phi(&[(1, v1_out), (2, v2_out)]); - b.ret(); // Phi result not used! - }); - }) - .unwrap() - }; - - // Run DCE - let pass = DeadCodeEliminationPass::new(20); - let ctx = test_context(); - let _changes = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - - // The phi should be removed since its result is never used - let block3 = ssa.block(3).unwrap(); - assert_eq!(block3.phi_nodes().len(), 0); - } - - #[test] - fn test_trivial_phi_single_operand() { - // Test that phi with single operand is simplified - let (mut ssa, v1) = { - let mut v1_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - // Block 0: defines v1, jumps to block 1 - f.block(0, |b| { - v1_out = b.const_i32(42); - b.jump(1); - }); - // Block 1: phi with single operand (trivial) - f.block(1, |b| { - let phi_result = b.phi(&[(0, v1_out)]); - b.ret_val(phi_result); - }); - }) - .unwrap(); - (ssa, v1_out) - }; - - // Run DCE - let pass = DeadCodeEliminationPass::new(20); - let ctx = test_context(); - let _changes = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - - // Phi should be simplified - uses of phi_result should be replaced with v1 - let block1 = ssa.block(1).unwrap(); - assert_eq!(block1.phi_nodes().len(), 0); - - // Return should now use v1 - if let Some(SsaOp::Return { value }) = block1.terminator_op() { - assert_eq!(*value, Some(v1)); - } - } - - #[test] - fn test_trivial_phi_all_same() { - // Test that phi where all operands are the same is simplified - let (mut ssa, v1) = { - let mut v1_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - // Block 0: entry, branch - f.block(0, |b| { - let cond = b.const_true(); - v1_out = b.const_i32(42); - b.branch(cond, 1, 2); - }); - // Block 1: jumps to merge - f.block(1, |b| b.jump(3)); - // Block 2: jumps to merge - f.block(2, |b| b.jump(3)); - // Block 3: phi with all same operands (both from v1) - f.block(3, |b| { - let phi_result = b.phi(&[(1, v1_out), (2, v1_out)]); - b.ret_val(phi_result); - }); - }) - .unwrap(); - (ssa, v1_out) - }; - - // Run DCE - let pass = DeadCodeEliminationPass::new(20); - let ctx = test_context(); - let _changes = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - - // Phi should be simplified - let block3 = ssa.block(3).unwrap(); - assert_eq!(block3.phi_nodes().len(), 0); - - // Return should now use v1 - if let Some(SsaOp::Return { value }) = block3.terminator_op() { - assert_eq!(*value, Some(v1)); - } - } - - #[test] - fn test_self_referential_phi() { - // Test phi like phi_var = phi(phi_var, v2) simplifies to phi_var = v2 - // We need to manually construct this since the builder can't create self-references - let mut ssa = SsaFunction::new(0, 0); - - // Create variables - let v2 = ssa.create_variable( - VariableOrigin::Local(0), - 0, - DefSite::instruction(0, 0), - SsaType::Unknown, - ); - let phi_var = ssa.create_variable( - VariableOrigin::Local(1), - 0, - DefSite::phi(1), - SsaType::Unknown, - ); - - // Block 0: entry, defines v2, jumps to block 1 - let mut block0 = SsaBlock::new(0); - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: v2, - value: ConstValue::I32(42), - })); - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 1 })); - ssa.add_block(block0); - - // Block 1: loop header with self-referential phi - // phi_var = phi(v2 from block 0, phi_var from block 1) - let mut block1 = SsaBlock::new(1); - let mut phi = PhiNode::new(phi_var, VariableOrigin::Local(1)); - phi.add_operand(PhiOperand::new(v2, 0)); // from block 0 - phi.add_operand(PhiOperand::new(phi_var, 1)); // from block 1 (self-reference) - block1.add_phi(phi); - block1.add_instruction(SsaInstruction::synthetic(SsaOp::Return { - value: Some(phi_var), - })); - ssa.add_block(block1); - - // Run DCE - let pass = DeadCodeEliminationPass::new(20); - let ctx = test_context(); - let _changes = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - - // Phi should be simplified - phi_var = phi(v2, phi_var) becomes phi_var = v2 - let block1 = ssa.block(1).unwrap(); - assert_eq!(block1.phi_nodes().len(), 0); +/// `World` adapter over `CompilerContext`. +/// +/// Built once per `run_global` invocation. Combines the SSA-derived call +/// graph (from `ctx.build_ssa_call_graph()`) with the static call graph +/// (`ctx.call_graph`) so [`World::callees`] returns a single unified view — +/// SSA-derived edges win when present, static edges fill in for methods we +/// haven't built SSA for. +struct CtxWorld<'a> { + ctx: &'a CompilerContext, + ssa_callees: BTreeMap>, + methods: Vec, + entries: Vec, +} - // Return should now use v2 - if let Some(SsaOp::Return { value }) = block1.terminator_op() { - assert_eq!(*value, Some(v2)); +impl<'a> CtxWorld<'a> { + fn new(ctx: &'a CompilerContext) -> Self { + let ssa_callees = ctx.build_ssa_call_graph(); + let methods: Vec = ctx.all_methods().collect(); + let entries: Vec = ctx.entry_points.iter().map(|e| *e).collect(); + Self { + ctx, + ssa_callees, + methods, + entries, } } +} - #[test] - fn test_phi_operand_pruning() { - // Test that phi operands from unreachable blocks are pruned - let mut ssa = { - let mut v1_out = SsaVarId::from_index(0); - let mut v2_out = SsaVarId::from_index(1); - SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - // Block 0: entry, always jumps to block 1 (block 2 is unreachable) - f.block(0, |b| { - v1_out = b.const_i32(10); - b.jump(1); // Always goes to 1 - }); - // Block 1: reachable merge - f.block(1, |b| { - let phi_result = b.phi(&[(0, v1_out), (2, v2_out)]); // v2 from unreachable block 2 - b.ret_val(phi_result); - }); - // Block 2: unreachable - f.block(2, |b| { - v2_out = b.const_i32(20); - b.jump(1); - }); - }) - .unwrap() - }; - - // Run DCE - let pass = DeadCodeEliminationPass::new(20); - let ctx = test_context(); - let _changes = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - - // Block 2 should be cleared - let block2 = ssa.block(2).unwrap(); - assert!(block2.instructions().is_empty()); - - // Phi in block 1 should be simplified (only one valid operand after pruning) - let block1 = ssa.block(1).unwrap(); - // After pruning, the phi becomes trivial and should be simplified - assert_eq!(block1.phi_nodes().len(), 0); - } - - #[test] - fn test_side_effect_preservation() { - // Test that side-effectful operations are not removed - let (mut ssa, v1) = { - let mut v1_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - // v1 = some value (will be used by call) - v1_out = b.const_i32(42); - // Call (side effect - should not be removed even if result unused) - let method = MethodRef::new(Token::new(0x06000002)); - let _ = b.call(method, &[v1_out], SsaType::I32); - // Return without using call result - b.ret(); - }); - }) - .unwrap(); - (ssa, v1_out) - }; - - // Run DCE - let pass = DeadCodeEliminationPass::new(20); - let ctx = test_context(); - let _changes = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - - // Call should still be there (side effect) - let block = ssa.block(0).unwrap(); - let has_call = block - .instructions() - .iter() - .any(|i| matches!(i.op(), SsaOp::Call { .. })); - assert!(has_call); - - // v1 should also be preserved (used by call) - let has_const = block - .instructions() - .iter() - .any(|i| matches!(i.op(), SsaOp::Const { dest, .. } if *dest == v1)); - assert!(has_const); +impl World for CtxWorld<'_> { + fn all_methods(&self) -> Vec { + self.methods.iter().copied().map(MethodRef::from).collect() } - #[test] - fn test_reverse_postorder() { - // Create a diamond CFG: 0 -> {1, 2} -> 3 - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let cond = b.const_true(); - b.branch(cond, 1, 2); - }); - f.block(1, |b| b.jump(3)); - f.block(2, |b| b.jump(3)); - f.block(3, |b| b.ret()); - }) - .unwrap(); - - let reachable = DeadCodeEliminationPass::find_reachable_blocks(&ssa); - let rpo = DeadCodeEliminationPass::compute_reverse_postorder(&ssa, &reachable); - - // RPO should have entry first, exit last - assert_eq!(rpo[0], 0); // Entry - assert_eq!(*rpo.last().unwrap(), 3); // Exit (merge point) - assert_eq!(rpo.len(), 4); + fn entry_points(&self) -> Vec { + self.entries.iter().copied().map(MethodRef::from).collect() } - #[test] - fn test_live_variable_computation() { - let (ssa, v1, v2) = { - let mut v1_out = SsaVarId::from_index(0); - let mut v2_out = SsaVarId::from_index(1); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - // v1 = 10 (live - used by return) - v1_out = b.const_i32(10); - // v2 = 20 (dead - not used) - v2_out = b.const_i32(20); - // return v1 - b.ret_val(v1_out); - }); - }) - .unwrap(); - (ssa, v1_out, v2_out) - }; - - let reachable = DeadCodeEliminationPass::find_reachable_blocks(&ssa); - let rpo = DeadCodeEliminationPass::compute_reverse_postorder(&ssa, &reachable); - let live = DeadCodeEliminationPass::compute_live_variables(&ssa, &reachable, &rpo); - - assert!(live.contains(v1.index())); // v1 is live (returned) - assert!(!live.contains(v2.index())); // v2 is dead + fn callees(&self, method: &MethodRef) -> Vec { + let token = method.token(); + if let Some(ssa_calls) = self.ssa_callees.get(&token) { + return ssa_calls.iter().copied().map(MethodRef::from).collect(); + } + self.ctx + .call_graph + .callees(token) + .into_iter() + .map(MethodRef::from) + .collect() } - #[test] - fn test_transitive_liveness() { - // Test that liveness propagates transitively - // v1 = 5 - // v2 = v1 + 1 - // v3 = v2 * 2 - // return v3 - // All should be live! - - let (ssa, v1, v2, v3, c1, c2) = { - let mut v1_out = SsaVarId::from_index(0); - let mut v2_out = SsaVarId::from_index(1); - let mut v3_out = SsaVarId::from_index(2); - let mut c1_out = SsaVarId::from_index(3); - let mut c2_out = SsaVarId::from_index(4); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - v1_out = b.const_i32(5); - c1_out = b.const_i32(1); - v2_out = b.add(v1_out, c1_out); - c2_out = b.const_i32(2); - v3_out = b.mul(v2_out, c2_out); - b.ret_val(v3_out); - }); - }) - .unwrap(); - (ssa, v1_out, v2_out, v3_out, c1_out, c2_out) - }; - - let reachable = DeadCodeEliminationPass::find_reachable_blocks(&ssa); - let rpo = DeadCodeEliminationPass::compute_reverse_postorder(&ssa, &reachable); - let live = DeadCodeEliminationPass::compute_live_variables(&ssa, &reachable, &rpo); - - // All should be live transitively - assert!(live.contains(v1.index())); - assert!(live.contains(v2.index())); - assert!(live.contains(v3.index())); - assert!(live.contains(c1.index())); - assert!(live.contains(c2.index())); + fn is_dead(&self, method: &MethodRef) -> bool { + self.ctx.is_dead(method.token()) } - #[test] - fn test_iterative_convergence() { - // Test that the algorithm converges (doesn't infinite loop) - let mut ssa = { - let mut v0_out = SsaVarId::from_index(0); - let mut phi_out = SsaVarId::from_index(1); - SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - // Create a loop structure - f.block(0, |b| { - v0_out = b.const_i32(0); - b.jump(1); - }); - f.block(1, |b| { - // phi: from entry (v0) and from back edge (v2) - phi_out = b.phi(&[(0, v0_out), (1, phi_out)]); - // v2 = phi + 1 (unused, becomes back edge value) - let one = b.const_i32(1); - let _ = b.add(phi_out, one); - // Condition to exit loop - let cond = b.const_true(); - b.branch(cond, 2, 1); - }); - f.block(2, |b| b.ret()); - }) - .unwrap() - }; - - // Run DCE - should converge - let pass = DeadCodeEliminationPass::new(20); - let ctx = test_context(); - let result = - pass.run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()); - - assert!(result.is_ok()); + fn mark_dead(&self, method: &MethodRef) { + self.ctx.mark_dead(method.token()); } } diff --git a/dotscope/src/compiler/passes/gvn.rs b/dotscope/src/compiler/passes/gvn.rs deleted file mode 100644 index 5830c855..00000000 --- a/dotscope/src/compiler/passes/gvn.rs +++ /dev/null @@ -1,418 +0,0 @@ -//! Global Value Numbering (GVN) pass. -//! -//! This pass eliminates redundant computations by detecting when the same -//! expression is computed multiple times with the same operands, and replacing -//! later uses with the earlier result. -//! -//! # Example -//! -//! Before: -//! ```text -//! v1 = add v0, 5 -//! v2 = add v0, 5 // Redundant - same operation -//! v3 = mul v1, v2 -//! ``` -//! -//! After: -//! ```text -//! v1 = add v0, 5 -//! v2 = add v0, 5 // Now dead (DCE will remove) -//! v3 = mul v1, v1 // v2 replaced with v1 -//! ``` -//! -//! # Algorithm -//! -//! The pass uses hash-based value numbering: -//! -//! 1. For each pure operation, create a hashable key from its opcode and operands -//! 2. Normalize commutative operations (e.g., `add v1, v0` → `add v0, v1`) -//! 3. If the key was seen before, replace uses of the new result with the old one -//! 4. Otherwise, record the key with this operation's result -//! -//! # Limitations -//! -//! - Only handles pure operations (no side effects) -//! - Does not perform code motion (dominator-based GVN would be more powerful) -//! - Works within a single method - -use std::collections::HashMap; - -use crate::{ - analysis::{BinaryOpKind, SsaFunction, SsaOp, SsaVarId, UnaryOpKind}, - compiler::{ - pass::{ModificationScope, SsaPass}, - CompilerContext, EventKind, EventLog, - }, - metadata::token::Token, - CilObject, Result, -}; - -/// A hashable key representing an operation for value numbering. -/// -/// This captures the "value" of an expression - the operation type and operands, -/// but not the destination. Two operations with the same key compute the same value. -/// -/// Uses the centralized `BinaryOpKind` and `UnaryOpKind` types from the SSA module. -#[derive(Debug, Clone, PartialEq, Eq, Hash)] -enum ValueKey { - /// Binary operation: (kind, unsigned_flag, left, right) - /// - /// The unsigned flag is included for operations where signedness affects - /// semantics (Div, Rem, Shr, Clt, Cgt). For other operations, it's normalized - /// to `false` for consistent hashing. - Binary(BinaryOpKind, bool, SsaVarId, SsaVarId), - - /// Unary operation: (kind, operand) - Unary(UnaryOpKind, SsaVarId), - - /// Argument load: loading the same argument always produces the same value. - LoadArg(usize), -} - -impl ValueKey { - /// Creates a normalized value key from an SSA operation. - /// - /// Returns `None` for operations that shouldn't be value-numbered - /// (impure operations, constants, control flow, etc.). - /// - /// Uses `as_binary_op()` and `as_unary_op()` for extraction, then applies - /// normalization for commutative operations. - fn from_op(op: &SsaOp) -> Option<(Self, SsaVarId)> { - // Try binary operations first - if let Some(info) = op.as_binary_op() { - // Skip overflow-checked operations (they may throw) - if matches!( - info.kind, - BinaryOpKind::AddOvf | BinaryOpKind::SubOvf | BinaryOpKind::MulOvf - ) { - return None; - } - - // Normalize for consistent hashing - let normalized = info.normalized(); - let (kind, unsigned, left, right) = normalized.value_key(); - return Some((Self::Binary(kind, unsigned, left, right), normalized.dest)); - } - - // Try unary operations - if let Some(info) = op.as_unary_op() { - // Skip Ckfinite (it may throw) - if info.kind == UnaryOpKind::Ckfinite { - return None; - } - return Some((Self::Unary(info.kind, info.operand), info.dest)); - } - - // LoadArg: loading the same argument twice produces the same value. - // This enables downstream passes (e.g., opaque predicate removal) to - // detect self-comparisons like `ceq(arg0, arg0)`. - if let SsaOp::LoadArg { dest, arg_index } = op { - return Some((Self::LoadArg(*arg_index as usize), *dest)); - } - - // Skip everything else (constants, loads, stores, calls, control flow, etc.) - None - } -} - -/// Global Value Numbering pass. -/// -/// Eliminates redundant computations by detecting equivalent expressions -/// and replacing later occurrences with references to earlier results. -pub struct GlobalValueNumberingPass; - -impl Default for GlobalValueNumberingPass { - fn default() -> Self { - Self::new() - } -} - -impl GlobalValueNumberingPass { - /// Creates a new GVN pass. - #[must_use] - pub fn new() -> Self { - Self - } - - /// Runs GVN on a single SSA function. - /// - /// Returns the number of redundant expressions eliminated. - fn run_gvn(ssa: &mut SsaFunction, method_token: Token, changes: &mut EventLog) -> usize { - // Map from value key to the first variable that computed it - let mut value_map: HashMap = HashMap::new(); - - // Collect redundant definitions: (redundant_var, original_var, block_idx, instr_idx) - let mut redundant: Vec<(SsaVarId, SsaVarId, usize, usize)> = Vec::new(); - - // First pass: identify redundant computations - for block in ssa.blocks() { - let block_idx = block.id(); - for (instr_idx, instr) in block.instructions().iter().enumerate() { - if let Some((key, dest)) = ValueKey::from_op(instr.op()) { - if let Some(&original) = value_map.get(&key) { - // This computation is redundant - redundant.push((dest, original, block_idx, instr_idx)); - } else { - // First time seeing this value - value_map.insert(key, dest); - } - } - } - } - - // Second pass: replace uses of redundant variables with originals, - // including phi operands, then nop-out the dead instruction. - // This prevents ping-ponging with DCE: without nop-out, DCE would find - // the dead instruction as "new work" on the next normalization iteration. - let mut total_replaced: usize = 0; - for (redundant_var, original_var, block_idx, instr_idx) in &redundant { - let result = ssa.replace_uses_including_phis(*redundant_var, *original_var); - if result.replaced > 0 { - changes - .record(EventKind::ConstantFolded) // Reuse existing kind for expression elimination - .method(method_token) - .message(format!( - "GVN: {redundant_var} → {original_var} ({} uses)", - result.replaced - )); - total_replaced = total_replaced.saturating_add(result.replaced); - } - // Nop-out the redundant instruction so rebuild_ssa's strip_nops - // removes it. This avoids leaving dead instructions for DCE to find. - ssa.remove_instruction(*block_idx, *instr_idx); - } - - total_replaced - } -} - -impl SsaPass for GlobalValueNumberingPass { - fn name(&self) -> &'static str { - "global-value-numbering" - } - - fn description(&self) -> &'static str { - "Eliminates redundant computations using value numbering" - } - - fn modification_scope(&self) -> ModificationScope { - ModificationScope::UsesOnly - } - - fn run_on_method( - &self, - ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - _assembly: &CilObject, - ) -> Result { - let mut changes = EventLog::new(); - - // GVN is a single-pass algorithm - no iteration needed - // (unlike copy propagation which needs to resolve chains) - Self::run_gvn(ssa, method_token, &mut changes); - - let changed = !changes.is_empty(); - if changed { - ctx.events.merge(&changes); - } - Ok(changed) - } -} - -#[cfg(test)] -mod tests { - use crate::{ - analysis::{BinaryOpKind, ConstValue, SsaFunctionBuilder, SsaOp, SsaVarId, UnaryOpKind}, - compiler::{ - passes::gvn::{GlobalValueNumberingPass, ValueKey}, - EventLog, - }, - metadata::token::Token, - }; - - #[test] - fn test_value_key_binary_commutative() { - // Test that commutative operations with swapped operands produce the same key - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let v2 = SsaVarId::from_index(2); - let v3 = SsaVarId::from_index(3); - - // Add is commutative - should normalize to same key - let add_op1 = SsaOp::Add { - dest: v2, - left: v0, - right: v1, - }; - let add_op2 = SsaOp::Add { - dest: v3, - left: v1, - right: v0, - }; // Swapped operands - - let (key1, _) = ValueKey::from_op(&add_op1).unwrap(); - let (key2, _) = ValueKey::from_op(&add_op2).unwrap(); - assert_eq!(key1, key2, "Add should be commutative"); - - // Sub is not commutative - different keys - let sub_op1 = SsaOp::Sub { - dest: v2, - left: v0, - right: v1, - }; - let sub_op2 = SsaOp::Sub { - dest: v3, - left: v1, - right: v0, - }; - - let (key3, _) = ValueKey::from_op(&sub_op1).unwrap(); - let (key4, _) = ValueKey::from_op(&sub_op2).unwrap(); - assert_ne!(key3, key4, "Sub should NOT be commutative"); - } - - #[test] - fn test_value_key_from_op() { - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let v2 = SsaVarId::from_index(2); - - // Add operation - let add_op = SsaOp::Add { - dest: v2, - left: v0, - right: v1, - }; - let (key, dest) = ValueKey::from_op(&add_op).unwrap(); - assert_eq!(dest, v2); - assert!(matches!(key, ValueKey::Binary(BinaryOpKind::Add, _, _, _))); - - // Neg operation - let neg_op = SsaOp::Neg { - dest: v1, - operand: v0, - }; - let (key, dest) = ValueKey::from_op(&neg_op).unwrap(); - assert_eq!(dest, v1); - assert!(matches!(key, ValueKey::Unary(UnaryOpKind::Neg, _))); - - // Const should return None (not value-numbered) - let const_op = SsaOp::Const { - dest: v0, - value: ConstValue::I32(42), - }; - assert!(ValueKey::from_op(&const_op).is_none()); - } - - #[test] - fn test_gvn_eliminates_redundant() { - // Build SSA: - // v2 = add v0, v1 - // v3 = add v0, v1 <- redundant - // v4 = mul v2, v3 - let (mut ssa, v2) = { - let mut v2_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(10); - let v1 = b.const_i32(20); - let v2 = b.add(v0, v1); - v2_out = v2; - let v3 = b.add(v0, v1); // Redundant - let v4 = b.mul(v2, v3); - b.ret_val(v4); - }); - }) - .unwrap(); - (ssa, v2_out) - }; - - let mut changes = EventLog::new(); - let replaced = - GlobalValueNumberingPass::run_gvn(&mut ssa, Token::new(0x06000001), &mut changes); - - // Should have replaced uses of v3 with v2 - assert!(replaced > 0); - assert!(!changes.is_empty()); - - // The mul should now use v2 twice - let block = ssa.block(0).unwrap(); - let mul_instr = &block.instructions()[4]; // After 2 consts and 2 adds - if let SsaOp::Mul { left, right, .. } = mul_instr.op() { - assert_eq!(*left, v2); - assert_eq!(*right, v2); - } else { - panic!("Expected Mul instruction"); - } - } - - #[test] - fn test_gvn_commutative_order() { - // Build SSA: - // v2 = add v0, v1 - // v3 = add v1, v0 <- same as v2 (commutative) - let (mut ssa, v2) = { - let mut v2_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(10); - let v1 = b.const_i32(20); - let v2 = b.add(v0, v1); - v2_out = v2; - let v3 = b.add(v1, v0); // Swapped - b.ret_val(v3); - }); - }) - .unwrap(); - (ssa, v2_out) - }; - - let mut changes = EventLog::new(); - let replaced = - GlobalValueNumberingPass::run_gvn(&mut ssa, Token::new(0x06000001), &mut changes); - - // Should detect that add v1, v0 == add v0, v1 - assert!(replaced > 0); - - // Return should now use v2 - let block = ssa.block(0).unwrap(); - let ret_instr = &block.instructions()[4]; // After 2 consts and 2 adds - if let SsaOp::Return { - value: Some(ret_val), - } = ret_instr.op() - { - assert_eq!(*ret_val, v2); - } else { - panic!("Expected Return instruction"); - } - } - - #[test] - fn test_gvn_non_commutative_preserved() { - // Build SSA: - // v2 = sub v0, v1 - // v3 = sub v1, v0 <- NOT the same (sub is not commutative) - let mut ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(10); - let v1 = b.const_i32(20); - let _v2 = b.sub(v0, v1); - let v3 = b.sub(v1, v0); - b.ret_val(v3); - }); - }) - .unwrap(); - - let mut changes = EventLog::new(); - let replaced = - GlobalValueNumberingPass::run_gvn(&mut ssa, Token::new(0x06000001), &mut changes); - - // Should NOT replace - these are different values - assert_eq!(replaced, 0); - assert!(changes.is_empty()); - } -} diff --git a/dotscope/src/compiler/passes/inlining.rs b/dotscope/src/compiler/passes/inlining.rs index a864af61..cde080c4 100644 --- a/dotscope/src/compiler/passes/inlining.rs +++ b/dotscope/src/compiler/passes/inlining.rs @@ -34,14 +34,15 @@ use std::collections::BTreeMap; use crate::{ analysis::{ - DefSite, ReturnInfo, SsaFunction, SsaInstruction, SsaOp, SsaType, SsaVarId, VariableOrigin, + CilTarget, DefSite, MethodRef, ReturnInfo, SsaFunction, SsaInstruction, SsaOp, SsaType, + SsaVarId, VariableOrigin, }, compiler::{ pass::{PassCapability, SsaPass}, CompilerContext, EventKind, EventLog, }, metadata::{tables::MemberRefSignature, token::Token, typesystem::CilTypeReference}, - CilObject, Result, + CilObject, }; /// A candidate call site for inlining. @@ -705,7 +706,7 @@ impl InliningPass { } } -impl SsaPass for InliningPass { +impl SsaPass for InliningPass { fn name(&self) -> &'static str { "InliningPass" } @@ -729,12 +730,15 @@ impl SsaPass for InliningPass { fn run_on_method( &self, ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - assembly: &CilObject, - ) -> Result { + method: &MethodRef, + host: &CompilerContext, + ) -> analyssa::Result { + let assembly = host + .assembly() + .ok_or_else(|| analyssa::Error::new("InliningPass requires an assembly"))?; + // Create method-specific inlining context - let mut inline_ctx = InliningContext::new(self, ssa, method_token, ctx, assembly); + let mut inline_ctx = InliningContext::new(self, ssa, method.0, host, &assembly); // Find all candidates let candidates = inline_ctx.find_candidates(); @@ -750,7 +754,7 @@ impl SsaPass for InliningPass { // Merge changes back to the analysis context let changed = inline_ctx.has_changes(); if changed { - ctx.events.merge(&inline_ctx.into_changes()); + host.events.merge(&inline_ctx.into_changes()); } Ok(changed) } diff --git a/dotscope/src/compiler/passes/licm.rs b/dotscope/src/compiler/passes/licm.rs deleted file mode 100644 index 2e7a7b65..00000000 --- a/dotscope/src/compiler/passes/licm.rs +++ /dev/null @@ -1,704 +0,0 @@ -//! Loop Invariant Code Motion (LICM) Pass. -//! -//! This pass moves computations that produce the same value on every iteration -//! out of loops. This is useful for: -//! -//! - Performance optimization -//! - Cleaning up loop-based obfuscation patterns -//! -//! # Algorithm -//! -//! An instruction is loop-invariant if: -//! 1. All its operands are defined outside the loop, OR -//! 2. All its operands are defined by loop-invariant instructions -//! -//! An instruction can be hoisted if: -//! 1. It is loop-invariant -//! 2. It has no side effects (pure computation) -//! 3. The loop has a preheader where we can place the hoisted code -//! -//! # Example -//! -//! ```text -//! // Before LICM -//! preheader: -//! a = 5 -//! b = 10 -//! jump header -//! -//! header: -//! i = phi(0, i') -//! x = a + b // Loop invariant! -//! use(x) -//! i' = i + 1 -//! branch (i < 10), header, exit -//! -//! // After LICM -//! preheader: -//! a = 5 -//! b = 10 -//! x = a + b // Hoisted -//! jump header -//! -//! header: -//! i = phi(0, i') -//! use(x) -//! i' = i + 1 -//! branch (i < 10), header, exit -//! ``` - -use std::collections::{HashMap, HashSet, VecDeque}; - -use crate::{ - analysis::{LoopAnalyzer, LoopInfo, SsaFunction, SsaInstruction, SsaOp, SsaVarId}, - compiler::{ - pass::{ModificationScope, SsaPass}, - CompilerContext, EventKind, - }, - metadata::token::Token, - utils::BitSet, - CilObject, Result, -}; - -/// Loop Invariant Code Motion Pass. -/// -/// Moves loop-invariant computations to the loop preheader. -pub struct LicmPass; - -impl Default for LicmPass { - fn default() -> Self { - Self::new() - } -} - -impl LicmPass { - /// Creates a new LICM pass. - #[must_use] - pub fn new() -> Self { - Self - } -} - -impl SsaPass for LicmPass { - fn name(&self) -> &'static str { - "licm" - } - - fn description(&self) -> &'static str { - "Moves loop-invariant computations to loop preheaders" - } - - fn modification_scope(&self) -> ModificationScope { - ModificationScope::InstructionsOnly - } - - fn run_on_method( - &self, - ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - _assembly: &CilObject, - ) -> Result { - let forest = LoopAnalyzer::new(ssa).analyze(); - - if forest.is_empty() { - return Ok(false); - } - - let mut total_hoisted: usize = 0; - - // Process loops from innermost to outermost. - // This naturally propagates hoists through nesting levels: inner hoists - // move instructions to the inner preheader (inside the outer loop body), - // and outer loop processing then hoists them further out if invariant. - // The Nop skip in find_loop_invariants prevents exponential blowup by - // ignoring the Nops left behind by inner hoists. - for loop_info in forest.by_depth_descending() { - // Skip loops without preheaders - we need somewhere to hoist to - let Some(preheader) = loop_info.preheader else { - continue; - }; - - // Validate that the preheader is an actual immediate predecessor of - // the loop header. If it isn't (e.g., for loops inside CFF dispatchers - // where the loop analyzer may pick a distant dominator as preheader), - // hoisting would move definitions to a block that isn't a CFG - // predecessor of the header, making phi operand updates invalid. - let header_idx = loop_info.header.index(); - let preheader_is_pred = ssa - .block(preheader.index()) - .map(|b| { - b.instructions() - .last() - .map(|i| i.op().successors().contains(&header_idx)) - .unwrap_or(false) - }) - .unwrap_or(false); - if !preheader_is_pred { - continue; - } - - // Skip loops whose header is a switch-based dispatcher (CFF pattern). - // Hoisting into the preheader of such loops adds variable definitions - // that change the stack layout at the preheader's exit. When a - // subsequent CfgModifying pass triggers rebuild_ssa(), the phi nodes - // at the switch header are reconstructed from scratch using stack-slot - // analysis. The extra definitions shift which variable occupies each - // stack slot, causing the switch to read the wrong phi — e.g., a - // hoisted constant instead of the CFF state variable. This misroutes - // the CFF tracer and drops entire code phases. - let header_has_switch = ssa - .block(header_idx) - .and_then(|b| b.terminator_op()) - .is_some_and(|op| matches!(op, SsaOp::Switch { .. })); - if header_has_switch { - continue; - } - - // Find invariant instructions - let invariants = find_loop_invariants(ssa, loop_info); - - if invariants.is_empty() { - continue; - } - - // Filter to hoistable instructions - let mut hoistable: Vec<_> = invariants - .into_iter() - .filter(|(block_idx, instr_idx)| can_hoist(ssa, loop_info, *block_idx, *instr_idx)) - .collect(); - - // Second filter: ensure all operands of hoistable instructions are - // either defined outside the loop or by other hoistable instructions. - // Without this, an instruction like Conv(v10) would be hoisted but its - // operand v10 (from ArrayLength, which is invariant but not hoistable - // due to side effects) would remain in the loop body, producing a - // use-before-def. - let mut outside_defs = BitSet::new(ssa.var_id_capacity()); - for v in ssa.variables() { - if !loop_info.body.contains(v.def_site().block) { - outside_defs.insert(v.id().index()); - } - } - - loop { - let mut hoistable_defs = BitSet::new(ssa.var_id_capacity()); - for (block_idx, instr_idx) in hoistable.iter() { - if let Some(def) = ssa - .block(*block_idx) - .and_then(|b| b.instruction(*instr_idx)) - .and_then(|i| i.def()) - { - hoistable_defs.insert(def.index()); - } - } - - let before = hoistable.len(); - hoistable.retain(|(block_idx, instr_idx)| { - let Some(block) = ssa.block(*block_idx) else { - return false; - }; - let Some(instr) = block.instruction(*instr_idx) else { - return false; - }; - instr.op().uses().iter().all(|operand| { - outside_defs.contains(operand.index()) - || hoistable_defs.contains(operand.index()) - }) - }); - - if hoistable.len() == before { - break; - } - } - - // Guard: if hoisting ALL non-terminator instructions from a block - // would make it a trampoline AND that block feeds phis at a successor, - // skip hoisting from that block entirely. Making such blocks trampolines - // causes block-merging to clear them, and subsequent rebuild_ssa may - // not correctly reconnect the phi with the preheader's definitions. - { - // Count hoistable instructions per block - let mut hoist_count_per_block: HashMap = HashMap::new(); - for (block_idx, _) in &hoistable { - let entry = hoist_count_per_block.entry(*block_idx).or_insert(0); - *entry = entry.saturating_add(1); - } - - // Find blocks that would become trampolines - let mut trampoline_blocks = BitSet::new(ssa.block_count()); - for (&block_idx, &hoist_count) in &hoist_count_per_block { - if let Some(block) = ssa.block(block_idx) { - let non_term = block - .instructions() - .iter() - .filter(|i| !i.is_terminator() && !matches!(i.op(), SsaOp::Nop)) - .count(); - if hoist_count >= non_term { - // This block would become a trampoline — check if it feeds phis - if let Some(term) = block.terminator_op() { - for succ in term.successors() { - if let Some(succ_block) = ssa.block(succ) { - if !succ_block.phi_nodes().is_empty() { - trampoline_blocks.insert(block_idx); - } - } - } - } - } - } - } - - if !trampoline_blocks.is_empty() { - hoistable.retain(|(block_idx, _)| !trampoline_blocks.contains(*block_idx)); - } - } - - if hoistable.is_empty() { - continue; - } - - // Collect instructions to hoist (we need to clone them before mutation) - let mut to_hoist: Vec<(usize, usize, SsaOp)> = Vec::new(); - for (block_idx, instr_idx) in &hoistable { - if let Some(block) = ssa.block(*block_idx) { - if let Some(instr) = block.instruction(*instr_idx) { - to_hoist.push((*block_idx, *instr_idx, instr.op().clone())); - } - } - } - - // Sort hoistable instructions by their dependency order. - // Instructions must be hoisted in the order they were originally defined - // to maintain correct dependencies. Sort by (block_idx, instr_idx). - to_hoist.sort_by_key(|(block_idx, instr_idx, _)| (*block_idx, *instr_idx)); - - // Find the insertion point in the preheader (before the terminator) - let insert_base = if let Some(preheader_block) = ssa.block(preheader.index()) { - let instrs = preheader_block.instructions(); - if instrs.is_empty() { - 0 - } else if instrs.last().is_some_and(SsaInstruction::is_terminator) { - instrs.len().saturating_sub(1) - } else { - instrs.len() - } - } else { - 0 - }; - - // Track which source blocks had ALL non-terminator instructions hoisted. - // These blocks become trampolines, and their successor phis need - // predecessor updates from the source block to the preheader. - let mut hoisted_from = BitSet::new(ssa.block_count()); - - // Apply hoisting - insert all at once to maintain order - for (i, (block_idx, instr_idx, op)) in to_hoist.iter().enumerate() { - hoisted_from.insert(*block_idx); - - // Add to preheader - if let Some(preheader_block) = ssa.block_mut(preheader.index()) { - let new_instr = SsaInstruction::synthetic(op.clone()); - let instrs = preheader_block.instructions_mut(); - instrs.insert(insert_base.saturating_add(i), new_instr); - } - - // Remove from original location (replace with Nop) - if let Some(block) = ssa.block_mut(*block_idx) { - if let Some(instr) = block.instructions_mut().get_mut(*instr_idx) { - instr.set_op(SsaOp::Nop); - } - } - - total_hoisted = total_hoisted.saturating_add(1); - } - - // Update phi operands at successor blocks. When all non-terminator - // instructions were hoisted from a block to the preheader, the block - // becomes a trampoline (just a Jump). Phi operands at successors that - // referenced the source block now need to reference the preheader - // (where the definitions live after hoisting). - let preheader_idx = preheader.index(); - for source_block in hoisted_from.iter() { - // Check if all non-terminator instructions were hoisted (block is now a trampoline) - let is_trampoline = ssa.block(source_block).is_some_and(|b| { - b.instructions() - .iter() - .all(|i| i.is_terminator() || matches!(i.op(), SsaOp::Nop)) - }); - - if !is_trampoline { - continue; - } - - // Get successors of the source block - let successors: Vec = ssa - .block(source_block) - .map(|b| { - b.instructions() - .last() - .map(|i| i.op().successors()) - .unwrap_or_default() - }) - .unwrap_or_default(); - - for succ in successors { - if let Some(succ_block) = ssa.block_mut(succ) { - for phi in succ_block.phi_nodes_mut() { - for operand in phi.operands_mut() { - if operand.predecessor() == source_block { - operand.set_predecessor(preheader_idx); - } - } - } - } - } - } - } - - if total_hoisted > 0 { - ctx.events - .record(EventKind::InstructionRemoved) - .at(method_token, 0) - .message(format!( - "LICM: hoisted {total_hoisted} loop-invariant instructions" - )); - } - - Ok(total_hoisted > 0) - } -} - -/// Finds all loop-invariant instructions in a loop. -/// -/// An instruction is loop-invariant if all its operands are: -/// - Defined outside the loop, OR -/// - Defined by loop-invariant instructions -/// -/// IMPORTANT: PHI nodes at the loop HEADER define induction variables that change -/// each iteration. Instructions using these values are NOT loop-invariant. -fn find_loop_invariants(ssa: &SsaFunction, loop_info: &LoopInfo) -> Vec<(usize, usize)> { - let mut invariants: HashSet<(usize, usize)> = HashSet::new(); - let mut invariant_defs = BitSet::new(ssa.var_id_capacity()); - - // Collect PHI-defined variables from the loop HEADER only. - // These are loop induction variables that change each iteration. - // PHIs at other loop body blocks are path merge points and don't affect invariance. - let mut header_phi_defs = BitSet::new(ssa.var_id_capacity()); - if let Some(header_block) = ssa.block(loop_info.header.index()) { - for phi in header_block.phi_nodes() { - header_phi_defs.insert(phi.result().index()); - } - } - - // Build map of variables defined outside the loop - let mut outside_defs = BitSet::new(ssa.var_id_capacity()); - for var in ssa.variables() { - let def_site = var.def_site(); - if !loop_info.body.contains(def_site.block) { - outside_defs.insert(var.id().index()); - } - } - - let mut changed = true; - while changed { - changed = false; - - for block_idx in loop_info.body.iter() { - if let Some(block) = ssa.block(block_idx) { - for (instr_idx, instr) in block.instructions().iter().enumerate() { - // Skip if already marked invariant - if invariants.contains(&(block_idx, instr_idx)) { - continue; - } - - // Skip terminators - if instr.is_terminator() { - continue; - } - - // Skip Nop instructions - they have no effect and hoisting them - // causes exponential blowup when processing nested loops - // (inner hoists create Nops which outer loops then re-hoist) - if matches!(instr.op(), SsaOp::Nop) { - continue; - } - - // Check if instruction is invariant - if is_instruction_invariant( - instr, - &outside_defs, - &invariant_defs, - &header_phi_defs, - ) { - invariants.insert((block_idx, instr_idx)); - if let Some(def) = instr.def() { - invariant_defs.insert(def.index()); - } - changed = true; - } - } - } - } - } - - invariants.into_iter().collect() -} - -/// Checks if an instruction is loop-invariant. -/// -/// An instruction is NOT loop-invariant if it uses any loop header PHI-defined variable, -/// since those represent induction variables that change each iteration. -fn is_instruction_invariant( - instr: &SsaInstruction, - outside_defs: &BitSet, - invariant_defs: &BitSet, - header_phi_defs: &BitSet, -) -> bool { - // Use the built-in uses() method to get all operands - for operand in instr.op().uses() { - // If the operand is defined by a PHI at the loop header, it's loop-varying - if header_phi_defs.contains(operand.index()) { - return false; - } - // Otherwise check if it's defined outside the loop or by an invariant instruction - if !outside_defs.contains(operand.index()) && !invariant_defs.contains(operand.index()) { - return false; - } - } - - true -} - -/// Checks if an instruction can be safely hoisted. -fn can_hoist(ssa: &SsaFunction, loop_info: &LoopInfo, block_idx: usize, instr_idx: usize) -> bool { - let Some(block) = ssa.block(block_idx) else { - return false; - }; - - let Some(instr) = block.instruction(instr_idx) else { - return false; - }; - - // Only hoist instructions that define a value - hoisting effectless - // instructions (like Nop) is pointless and causes exponential blowup - if instr.def().is_none() { - return false; - } - - // Only hoist pure computations (is_pure is defined on SsaOp) - if !instr.op().is_pure() { - return false; - } - - // Don't hoist if there's no preheader - if loop_info.preheader.is_none() { - return false; - } - - // CRITICAL: Don't hoist if this instruction's result feeds a PHI's back-edge operand. - // Hoisting such instructions would make the PHI's back-edge operand orphaned or - // self-referential, breaking the loop structure. - if let Some(dest) = instr.def() { - if feeds_phi_back_edge(ssa, loop_info, dest) { - return false; - } - } - - true -} - -/// Checks if a variable (directly or indirectly) feeds a PHI operand on an -/// intra-loop edge — i.e. a phi at any block in the loop body whose operand -/// comes from another loop body block. -/// -/// Hoisting a def in this category is unsafe because the def's value is -/// attributed to a specific CFG edge (pred → phi-block). Moving it to the -/// shared preheader makes it dominate every intra-loop edge; when -/// `rebuild_ssa` recomputes phi operands by reaching definitions, every -/// such edge sees the hoisted value, collapsing per-edge attribution. -/// -/// Typical break case: CFF state-machine dispatcher. Case blocks push -/// different `Const` values (state 5, 4, 0, 2, ...) into a phi at the -/// dispatcher. Without this guard LICM hoists all of them into the -/// single preheader, and SSA rebuild rewrites every phi operand to point -/// at whichever hoisted const happens to be on top of the version stack, -/// destroying the state machine. -fn feeds_phi_back_edge(ssa: &SsaFunction, loop_info: &LoopInfo, var: SsaVarId) -> bool { - let mut worklist: VecDeque = VecDeque::new(); - let mut visited = BitSet::new(ssa.var_id_capacity()); - - worklist.push_back(var); - visited.insert(var.index()); - - while let Some(current) = worklist.pop_front() { - // Check phis at any block in the loop body (including the header). - // An intra-loop edge is one where both the phi's block and the - // operand's predecessor are in the loop body. - for phi_block_idx in loop_info.body.iter() { - let Some(phi_block) = ssa.block(phi_block_idx) else { - continue; - }; - for phi in phi_block.phi_nodes() { - for operand in phi.operands() { - if operand.value() == current && loop_info.body.contains(operand.predecessor()) - { - return true; - } - } - } - } - - // Find instructions that use this variable and add their dests to the worklist - for body_block_idx in loop_info.body.iter() { - if let Some(body_block) = ssa.block(body_block_idx) { - for instr in body_block.instructions() { - if instr.op().uses().contains(¤t) { - if let Some(dest) = instr.def() { - if visited.insert(dest.index()) { - worklist.push_back(dest); - } - } - } - } - } - } - } - - false -} - -#[cfg(test)] -mod tests { - use crate::{ - analysis::{ConstValue, LoopAnalyzer, MethodRef, SsaFunctionBuilder, SsaOp, SsaVarId}, - compiler::{LicmPass, SsaPass}, - metadata::token::Token, - }; - - #[test] - fn test_pass_metadata() { - let pass = LicmPass::new(); - assert_eq!(pass.name(), "licm"); - assert!(!pass.description().is_empty()); - } - - #[test] - fn test_op_is_pure() { - let add_op = SsaOp::Add { - dest: SsaVarId::from_index(0), - left: SsaVarId::from_index(1), - right: SsaVarId::from_index(2), - }; - assert!(add_op.is_pure()); - - let const_op = SsaOp::Const { - dest: SsaVarId::from_index(3), - value: ConstValue::I32(42), - }; - assert!(const_op.is_pure()); - - let call_op = SsaOp::Call { - dest: Some(SsaVarId::from_index(4)), - method: MethodRef::new(Token::new(0x06000001)), - args: vec![], - }; - assert!(!call_op.is_pure()); - } - - #[test] - fn test_op_uses() { - let v1 = SsaVarId::from_index(0); - let v2 = SsaVarId::from_index(1); - let dest = SsaVarId::from_index(2); - - let op = SsaOp::Add { - dest, - left: v1, - right: v2, - }; - let uses = op.uses(); - assert_eq!(uses.len(), 2); - assert!(uses.contains(&v1)); - assert!(uses.contains(&v2)); - - let const_op = SsaOp::Const { - dest, - value: ConstValue::I32(42), - }; - assert!(const_op.uses().is_empty()); - } - - #[test] - fn test_no_loops() { - // Function with no loops should return false - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let _ = b.const_i32(42); - b.ret(); - }); - }) - .unwrap(); - - let forest = LoopAnalyzer::new(&ssa).analyze(); - assert!(forest.is_empty()); - } - - #[test] - fn test_loop_without_preheader() { - // Loop without preheader (multiple entry edges) can't be optimized - // This creates a function where the loop header has multiple predecessors - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - // B0: entry with branch to different blocks - f.block(0, |b| { - let cond = b.const_true(); - b.branch(cond, 1, 2); - }); - // B1: goes to loop header - f.block(1, |b| b.jump(3)); - // B2: also goes to loop header (no single preheader) - f.block(2, |b| b.jump(3)); - // B3: loop header - f.block(3, |b| { - let cond = b.const_true(); - b.branch(cond, 3, 4); // self-loop - }); - // B4: exit - f.block(4, |b| b.ret()); - }) - .unwrap(); - - let forest = LoopAnalyzer::new(&ssa).analyze(); - assert!(!forest.is_empty()); - - let loop_info = &forest.loops()[0]; - // This loop has multiple entry edges so no preheader - assert!(!loop_info.has_preheader()); - } - - #[test] - fn test_simple_loop_has_preheader() { - // Create a loop with a single preheader - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - // B0: preheader - f.block(0, |b| b.jump(1)); - // B1: header with self-loop - f.block(1, |b| { - let cond = b.const_true(); - b.branch(cond, 1, 2); - }); - // B2: exit - f.block(2, |b| b.ret()); - }) - .unwrap(); - - let forest = LoopAnalyzer::new(&ssa).analyze(); - assert_eq!(forest.len(), 1); - - let loop_info = &forest.loops()[0]; - assert!(loop_info.has_preheader()); - } -} diff --git a/dotscope/src/compiler/passes/loopcanon.rs b/dotscope/src/compiler/passes/loopcanon.rs deleted file mode 100644 index 568835b6..00000000 --- a/dotscope/src/compiler/passes/loopcanon.rs +++ /dev/null @@ -1,694 +0,0 @@ -//! Loop canonicalization pass. -//! -//! This pass transforms loops into canonical form to enable more effective -//! analysis and optimization. Canonical loops have: -//! -//! - **Single preheader**: A unique block that dominates the loop header and -//! through which all loop entries pass -//! - **Single latch**: A unique back edge source block -//! -//! # Why Canonicalization Matters -//! -//! Many loop optimizations (loop-invariant code motion, induction variable -//! analysis, loop unrolling) require loops to be in canonical form: -//! -//! ```text -//! Non-canonical: Canonical: -//! -//! A B A B -//! \ / \ / -//! \ / \ / -//! v v v v -//! [header] <--+ [preheader] -//! | | | -//! v | v -//! [body] | [header] <--+ -//! / \ | | | -//! v \ | v | -//! [exit] [latch1] [body] | -//! | | / \ | -//! | | v \ | -//! [latch2] [exit] [latch] -//! | -//! +---+ -//! ``` -//! -//! # Transformations -//! -//! ## Preheader Insertion -//! -//! When a loop header has multiple non-loop predecessors, we insert a preheader: -//! -//! 1. Create a new block with a jump to the header -//! 2. Redirect all non-loop predecessors to the preheader -//! 3. Update phi nodes in the header to receive values from the preheader -//! -//! ## Latch Unification -//! -//! When a loop has multiple back edges (latches), we unify them: -//! -//! 1. Create a new unified latch block with a jump to the header -//! 2. Redirect all original latches to the unified latch -//! 3. Insert phi nodes in the unified latch to merge values from original latches -//! 4. Update phi nodes in the header to receive from the unified latch -//! -//! # Phi Node Handling -//! -//! The pass carefully maintains SSA form by: -//! - Splitting phi operands when inserting preheaders -//! - Merging phi operands when unifying latches -//! - Creating new phi nodes where necessary to preserve value flow - -use std::collections::HashMap; - -use crate::{ - analysis::{ - DefSite, LoopInfo, PhiNode, PhiOperand, SsaBlock, SsaFunction, SsaInstruction, - SsaLoopAnalysis, SsaOp, SsaVarId, VariableOrigin, - }, - compiler::{pass::SsaPass, CompilerContext, EventKind, EventLog}, - metadata::token::Token, - CilObject, Result, -}; - -/// Loop canonicalization pass. -/// -/// Transforms loops into canonical form with single preheaders and single latches. -/// This enables more effective loop analysis and optimization. -/// -/// # Example -/// -/// ```rust,ignore -/// use dotscope::compiler::LoopCanonicalizationPass; -/// -/// let pass = LoopCanonicalizationPass::new(); -/// let changes = pass.run_on_method(&mut ssa, method_token, &ctx)?; -/// ``` -pub struct LoopCanonicalizationPass; - -impl Default for LoopCanonicalizationPass { - fn default() -> Self { - Self::new() - } -} - -impl LoopCanonicalizationPass { - /// Creates a new loop canonicalization pass. - #[must_use] - pub fn new() -> Self { - Self - } - - /// Canonicalizes all loops in the SSA function. - /// - /// Returns the number of loops that were modified. - fn canonicalize_loops( - ssa: &mut SsaFunction, - method_token: Token, - changes: &mut EventLog, - ) -> usize { - let mut total_modified: usize = 0; - - // We need to iterate until no more changes because inserting blocks - // can affect loop structure - loop { - let forest = ssa.analyze_loops(); - - if forest.is_empty() { - break; - } - - let mut modified_this_iteration: usize = 0; - - // Process loops from innermost to outermost to avoid invalidating - // parent loop analysis when modifying inner loops - for loop_info in forest.by_depth_descending() { - // Check if this loop needs a preheader - if !loop_info.has_preheader() { - let non_loop_preds = Self::get_non_loop_predecessors(ssa, loop_info); - if non_loop_preds.len() > 1 { - Self::insert_preheader( - ssa, - loop_info, - &non_loop_preds, - method_token, - changes, - ); - modified_this_iteration = modified_this_iteration.saturating_add(1); - // After inserting a preheader, we need to re-analyze loops - break; - } - } - - // Check if this loop needs latch unification - if !loop_info.has_single_latch() && loop_info.latches.len() > 1 { - Self::unify_latches(ssa, loop_info, method_token, changes); - modified_this_iteration = modified_this_iteration.saturating_add(1); - // After unifying latches, we need to re-analyze loops - break; - } - } - - total_modified = total_modified.saturating_add(modified_this_iteration); - - if modified_this_iteration == 0 { - break; - } - } - - total_modified - } - - /// Gets the non-loop predecessor block indices for a loop header. - fn get_non_loop_predecessors(ssa: &SsaFunction, loop_info: &LoopInfo) -> Vec { - let header_idx = loop_info.header.index(); - let mut non_loop_preds = Vec::new(); - - // Find all blocks that jump to the header - for (block_idx, block) in ssa.iter_blocks() { - if let Some(op) = block.terminator_op() { - let targets = Self::get_targets(op); - if targets.contains(&header_idx) && !loop_info.body.contains(block_idx) { - non_loop_preds.push(block_idx); - } - } - } - - non_loop_preds - } - - /// Extracts all target block indices from a terminator operation. - fn get_targets(op: &SsaOp) -> Vec { - match op { - SsaOp::Jump { target } | SsaOp::Leave { target } => vec![*target], - SsaOp::Branch { - true_target, - false_target, - .. - } => vec![*true_target, *false_target], - SsaOp::Switch { - targets, default, .. - } => { - let mut all = targets.clone(); - all.push(*default); - all - } - _ => vec![], - } - } - - /// Inserts a preheader block for a loop. - /// - /// Creates a new block that becomes the single entry point into the loop, - /// redirecting all non-loop predecessors through it. - fn insert_preheader( - ssa: &mut SsaFunction, - loop_info: &LoopInfo, - non_loop_preds: &[usize], - method_token: Token, - changes: &mut EventLog, - ) { - let header_idx = loop_info.header.index(); - let preheader_idx = ssa.block_count(); - - // Step 1: Create the preheader block with a jump to the header - let mut preheader = SsaBlock::new(preheader_idx); - preheader.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { - target: header_idx, - })); - - // Step 2: If the header has phi nodes, we need to handle them carefully. - // The preheader needs to forward values from non-loop predecessors. - // We'll create phi nodes in the preheader if there are multiple non-loop preds, - // or just forward the single value if there's only one. - // - // Collect phi info first, then allocate variables (needs &mut ssa). - let phi_info: Vec<(VariableOrigin, Vec)> = ssa - .block(header_idx) - .map(|header| { - header - .phi_nodes() - .iter() - .filter_map(|phi| { - let non_loop_operands: Vec<_> = phi - .operands() - .iter() - .filter(|op| non_loop_preds.contains(&op.predecessor())) - .copied() - .collect(); - if non_loop_operands.len() > 1 { - Some((phi.origin(), non_loop_operands)) - } else { - None - } - }) - .collect() - }) - .unwrap_or_default(); - - for (origin, operands) in &phi_info { - let new_var = ssa.create_variable_for_origin(*origin, 0, DefSite::phi(preheader_idx)); - let mut preheader_phi = PhiNode::new(new_var, *origin); - for op in operands { - preheader_phi.add_operand(*op); - } - preheader.phi_nodes_mut().push(preheader_phi); - } - - // Step 3: Add the preheader to the function - ssa.add_block(preheader); - - // Step 4: Redirect all non-loop predecessors to the preheader - for &pred_idx in non_loop_preds { - Self::redirect_targets(ssa, pred_idx, header_idx, preheader_idx); - } - - // Step 5: First, collect information about preheader phis - let preheader_phi_map: HashMap = ssa - .block(preheader_idx) - .map(|b| { - b.phi_nodes() - .iter() - .map(|p| (p.origin(), p.result())) - .collect() - }) - .unwrap_or_default(); - - // Step 6: Update phi nodes in the header - if let Some(header) = ssa.block_mut(header_idx) { - for phi in header.phi_nodes_mut() { - let origin = phi.origin(); - let operands = phi.operands_mut(); - let mut loop_operands: Vec = Vec::new(); - let mut non_loop_values: Vec = Vec::new(); - - for op in operands.drain(..) { - if non_loop_preds.contains(&op.predecessor()) { - non_loop_values.push(op); - } else { - loop_operands.push(op); - } - } - - // Keep loop operands as-is - operands.extend(loop_operands); - - // For non-loop values: if there was a phi created in preheader, - // reference that phi's result; otherwise reference the single value - if !non_loop_values.is_empty() { - if let [single] = non_loop_values.as_slice() { - // Single non-loop predecessor: just update the predecessor - operands.push(PhiOperand::new(single.value(), preheader_idx)); - } else if let Some(&preheader_var) = preheader_phi_map.get(&origin) { - // Multiple non-loop predecessors: use the phi we created in preheader - operands.push(PhiOperand::new(preheader_var, preheader_idx)); - } - } - } - } - - changes - .record(EventKind::ControlFlowRestructured) - .at(method_token, preheader_idx) - .message(format!( - "Inserted preheader B{preheader_idx} for loop at B{header_idx}" - )); - } - - /// Unifies multiple latch blocks into a single latch. - /// - /// Creates a new unified latch block and redirects all original latches to it. - fn unify_latches( - ssa: &mut SsaFunction, - loop_info: &LoopInfo, - method_token: Token, - changes: &mut EventLog, - ) { - let header_idx = loop_info.header.index(); - let latches: Vec = loop_info.latches.iter().map(|n| n.index()).collect(); - let unified_latch_idx = ssa.block_count(); - - // Step 1: Create the unified latch block - let mut unified_latch = SsaBlock::new(unified_latch_idx); - unified_latch.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { - target: header_idx, - })); - - // Step 2: If the header has phi nodes with operands from multiple latches, - // we need to create phi nodes in the unified latch to merge those values. - // Collect phi info first, then allocate variables (needs &mut ssa). - let mut latch_phi_vars: HashMap = HashMap::new(); - - let phi_info: Vec<(VariableOrigin, Vec)> = ssa - .block(header_idx) - .map(|header| { - header - .phi_nodes() - .iter() - .map(|phi| { - let latch_operands: Vec<_> = phi - .operands() - .iter() - .filter(|op| latches.contains(&op.predecessor())) - .copied() - .collect(); - (phi.origin(), latch_operands) - }) - .collect() - }) - .unwrap_or_default(); - - for (origin, latch_operands) in &phi_info { - if latch_operands.len() > 1 { - // Need a phi node in the unified latch — allocate a real variable - let new_var = - ssa.create_variable_for_origin(*origin, 0, DefSite::phi(unified_latch_idx)); - let mut latch_phi = PhiNode::new(new_var, *origin); - for op in latch_operands { - latch_phi.add_operand(*op); - } - latch_phi_vars.insert(*origin, new_var); - unified_latch.phi_nodes_mut().push(latch_phi); - } else if let [single] = latch_operands.as_slice() { - // Single latch operand - just remember its value - latch_phi_vars.insert(*origin, single.value()); - } - } - - // Step 3: Add the unified latch to the function - ssa.add_block(unified_latch); - - // Step 4: Redirect all original latches to the unified latch instead of header - for &latch_idx in &latches { - Self::redirect_targets(ssa, latch_idx, header_idx, unified_latch_idx); - } - - // Step 5: Update phi nodes in the header to reference the unified latch - if let Some(header) = ssa.block_mut(header_idx) { - for phi in header.phi_nodes_mut() { - let origin = phi.origin(); - let operands = phi.operands_mut(); - - // Remove operands from original latches - operands.retain(|op| !latches.contains(&op.predecessor())); - - // Add operand from unified latch - if let Some(&var) = latch_phi_vars.get(&origin) { - operands.push(PhiOperand::new(var, unified_latch_idx)); - } - } - } - - changes - .record(EventKind::ControlFlowRestructured) - .at(method_token, unified_latch_idx) - .message(format!( - "Unified {} latches into B{} for loop at B{}", - latches.len(), - unified_latch_idx, - header_idx - )); - } - - /// Redirects branch targets in a block from old_target to new_target. - fn redirect_targets( - ssa: &mut SsaFunction, - block_idx: usize, - old_target: usize, - new_target: usize, - ) { - if let Some(block) = ssa.block_mut(block_idx) { - if let Some(last) = block.instructions_mut().last_mut() { - let new_op = match last.op() { - SsaOp::Jump { target } if *target == old_target => { - Some(SsaOp::Jump { target: new_target }) - } - SsaOp::Leave { target } if *target == old_target => { - Some(SsaOp::Leave { target: new_target }) - } - SsaOp::Branch { - condition, - true_target, - false_target, - } => { - let new_true = if *true_target == old_target { - new_target - } else { - *true_target - }; - let new_false = if *false_target == old_target { - new_target - } else { - *false_target - }; - if new_true != *true_target || new_false != *false_target { - Some(SsaOp::Branch { - condition: *condition, - true_target: new_true, - false_target: new_false, - }) - } else { - None - } - } - SsaOp::Switch { - value, - targets, - default, - } => { - let new_targets: Vec<_> = targets - .iter() - .map(|&t| if t == old_target { new_target } else { t }) - .collect(); - let new_default = if *default == old_target { - new_target - } else { - *default - }; - if new_targets != *targets || new_default != *default { - Some(SsaOp::Switch { - value: *value, - targets: new_targets, - default: new_default, - }) - } else { - None - } - } - _ => None, - }; - - if let Some(new_op) = new_op { - last.set_op(new_op); - } - } - } - } -} - -impl SsaPass for LoopCanonicalizationPass { - fn name(&self) -> &'static str { - "LoopCanonicalization" - } - - fn description(&self) -> &'static str { - "Transforms loops into canonical form with single preheaders and single latches" - } - - fn run_on_method( - &self, - ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - _assembly: &CilObject, - ) -> Result { - // Skip very small functions (no loops possible) - if ssa.block_count() < 2 { - return Ok(false); - } - - let mut changes = EventLog::new(); - let modified = Self::canonicalize_loops(ssa, method_token, &mut changes); - - if modified > 0 { - // Canonicalize the function to clean up and renumber blocks - ssa.canonicalize(); - ctx.events.merge(&changes); - return Ok(true); - } - - Ok(false) - } -} - -#[cfg(test)] -mod tests { - use crate::{ - analysis::{SsaFunctionBuilder, SsaLoopAnalysis}, - compiler::{passes::loopcanon::LoopCanonicalizationPass, EventLog}, - metadata::token::Token, - }; - - #[test] - fn test_preheader_insertion() { - // Create a loop with two entry points: - // B0 (entry) -> B1 or B2 - // B1 -> B3 (header) - // B2 -> B3 (header) - // B3 -> B4 (body) - // B4 -> B3 (back edge) or B5 (exit) - // B5: return - - let mut ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - // B0: entry - branch to B1 or B2 - f.block(0, |b| { - let cond0 = b.const_true(); - b.branch(cond0, 1, 2); - }); - // B1: first path to header - f.block(1, |b| b.jump(3)); - // B2: second path to header - f.block(2, |b| b.jump(3)); - // B3: header, jump to body - f.block(3, |b| b.jump(4)); - // B4: body with back edge or exit - f.block(4, |b| { - let cond1 = b.const_true(); - b.branch(cond1, 3, 5); // back edge to 3, exit to 5 - }); - // B5: exit - f.block(5, |b| b.ret()); - }) - .unwrap(); - - // Verify loop exists but doesn't have preheader - let forest = ssa.analyze_loops(); - assert_eq!(forest.len(), 1); - let loop_info = &forest.loops()[0]; - assert!(!loop_info.has_preheader()); - - // Run canonicalization - let mut changes = EventLog::new(); - let modified = - LoopCanonicalizationPass::canonicalize_loops(&mut ssa, Token::new(0), &mut changes); - - assert!(modified > 0); - - // Verify loop now has preheader - let forest = ssa.analyze_loops(); - assert_eq!(forest.len(), 1); - let loop_info = &forest.loops()[0]; - assert!(loop_info.has_preheader()); - } - - #[test] - fn test_latch_unification() { - // Create a loop with two back edges: - // B0 -> B1 (header) - // B1 -> B2, B3 - // B2 -> B1 (back edge 1) - // B3 -> B1 (back edge 2), B4 - // B4: exit - - let mut ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - // B0: entry - f.block(0, |b| b.jump(1)); - // B1: header - f.block(1, |b| { - let cond1 = b.const_true(); - b.branch(cond1, 2, 3); - }); - // B2: latch 1 (back to header) - f.block(2, |b| b.jump(1)); - // B3: latch 2 or exit - f.block(3, |b| { - let cond2 = b.const_true(); - b.branch(cond2, 1, 4); // back edge to 1, exit to 4 - }); - // B4: exit - f.block(4, |b| b.ret()); - }) - .unwrap(); - - // Verify loop has multiple latches - let forest = ssa.analyze_loops(); - assert_eq!(forest.len(), 1); - let loop_info = &forest.loops()[0]; - assert!(!loop_info.has_single_latch()); - assert!(loop_info.latches.len() >= 2); - - // Run canonicalization - let mut changes = EventLog::new(); - let modified = - LoopCanonicalizationPass::canonicalize_loops(&mut ssa, Token::new(0), &mut changes); - - assert!(modified > 0); - - // Verify loop now has single latch - let forest = ssa.analyze_loops(); - assert_eq!(forest.len(), 1); - let loop_info = &forest.loops()[0]; - assert!(loop_info.has_single_latch()); - } - - #[test] - fn test_already_canonical_loop() { - // Create a canonical loop: - // B0 -> B1 (header) - single entry - // B1 -> B2 - // B2 -> B1 (single back edge), B3 - // B3: exit - - let mut ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - // B0: preheader - f.block(0, |b| b.jump(1)); - // B1: header - f.block(1, |b| b.jump(2)); - // B2: body/latch - f.block(2, |b| { - let cond = b.const_true(); - b.branch(cond, 1, 3); // back edge to 1, exit to 3 - }); - // B3: exit - f.block(3, |b| b.ret()); - }) - .unwrap(); - - // Verify loop is already canonical - let forest = ssa.analyze_loops(); - assert_eq!(forest.len(), 1); - let loop_info = &forest.loops()[0]; - assert!(loop_info.is_canonical()); - - // Run canonicalization - should make no changes - let mut changes = EventLog::new(); - let modified = - LoopCanonicalizationPass::canonicalize_loops(&mut ssa, Token::new(0), &mut changes); - - assert_eq!(modified, 0); - } - - #[test] - fn test_no_loops() { - // Linear flow: B0 -> B1 -> B2 - let mut ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| b.jump(1)); - f.block(1, |b| b.jump(2)); - f.block(2, |b| b.ret()); - }) - .unwrap(); - - let forest = ssa.analyze_loops(); - assert!(forest.is_empty()); - - let mut changes = EventLog::new(); - let modified = - LoopCanonicalizationPass::canonicalize_loops(&mut ssa, Token::new(0), &mut changes); - - assert_eq!(modified, 0); - } -} diff --git a/dotscope/src/compiler/passes/mod.rs b/dotscope/src/compiler/passes/mod.rs index 02902081..ba0bb4a7 100644 --- a/dotscope/src/compiler/passes/mod.rs +++ b/dotscope/src/compiler/passes/mod.rs @@ -77,38 +77,31 @@ //! - **Loop Analysis** ([`LoopAnalyzer`](crate::analysis::LoopAnalyzer)): Used by //! loop canonicalization to identify and restructure loops -mod algebraic; -mod blockmerge; mod constants; -mod controlflow; mod copying; mod deadcode; -mod gvn; mod inlining; -mod licm; -mod loopcanon; -mod predicates; mod proxy; -mod ranges; -mod reassociate; mod strength; -mod threading; -mod utils; -// Re-export passes for public API (may not be used internally but exposed for crate users) -pub use self::algebraic::AlgebraicSimplificationPass; -pub use self::blockmerge::BlockMergingPass; +// Re-export the analyssa-supplied target-agnostic pass structs at the +// dotscope namespace so existing callers (registered via +// `compiler::PassScheduler`) keep working. +pub use analyssa::passes::{ + AlgebraicSimplificationPass, BlockMergingPass, ControlFlowSimplificationPass, + DeadCodeEliminationPass, GlobalValueNumberingPass, JumpThreadingPass, LicmPass, + LoopCanonicalizationPass, OpaquePredicatePass, PredicateResult, ReassociationPass, + ValueRangePropagationPass, +}; + +// CIL-specific pass impls remain dotscope-side. `copying` and `strength` +// keep their custom hooks (CIL local-type propagation, interprocedural +// value ranges); `deadcode` keeps its CIL-call-graph-aware +// `DeadMethodEliminationPass`; `constants`, `inlining`, `proxy` are +// CIL-specific in their entirety. pub use self::constants::ConstantPropagationPass; -pub use self::controlflow::ControlFlowSimplificationPass; pub use self::copying::CopyPropagationPass; -pub use self::deadcode::{DeadCodeEliminationPass, DeadMethodEliminationPass}; -pub use self::gvn::GlobalValueNumberingPass; +pub use self::deadcode::DeadMethodEliminationPass; pub use self::inlining::InliningPass; -pub use self::licm::LicmPass; -pub use self::loopcanon::LoopCanonicalizationPass; -pub use self::predicates::{OpaquePredicatePass, PredicateResult}; pub use self::proxy::ProxyDevirtualizationPass; -pub use self::ranges::ValueRangePropagationPass; -pub use self::reassociate::ReassociationPass; pub use self::strength::StrengthReductionPass; -pub use self::threading::JumpThreadingPass; diff --git a/dotscope/src/compiler/passes/predicates.rs b/dotscope/src/compiler/passes/predicates.rs deleted file mode 100644 index e17253e6..00000000 --- a/dotscope/src/compiler/passes/predicates.rs +++ /dev/null @@ -1,2308 +0,0 @@ -//! Opaque predicate detection and removal pass. -//! -//! Opaque predicates are conditional expressions that always evaluate to the same -//! value at runtime, but appear complex to static analysis. Obfuscators use them -//! to confuse decompilers and analysis tools. -//! -//! # Detection Strategies -//! -//! ## Basic Patterns -//! - **Self-comparison**: `x == x`, `x != x`, `x < x`, `x > x` -//! - **Identity operations**: `x ^ x == 0`, `x - x == 0` -//! - **Zero operations**: `x * 0`, `x & 0`, `x % 1` -//! -//! ## Number-Theoretic Predicates -//! - **Consecutive integers**: `(x * (x + 1)) % 2 == 0` (always true) -//! - **Square properties**: `x² >= 0` (always true for integers) -//! - **Modular arithmetic**: `(x² - x) % 2 == 0` (always true) -//! -//! ## Type-Based Predicates -//! - **Null checks**: `obj != null` after `newobj` (always true) -//! - **Array length**: `arr.Length >= 0` (always true) -//! -//! ## Range-Based Predicates -//! - **Unsigned bounds**: `unsigned_x >= 0` (always true) -//! - **Correlated conditions**: `if (x > 5) { if (x < 3) { dead } }` -//! -//! # Example -//! -//! Before: -//! ```text -//! v0 = 5 -//! v1 = ceq v0, v0 // Always true -//! branch v1, B1, B2 // Always goes to B1 -//! ``` -//! -//! After: -//! ```text -//! v0 = 5 -//! v1 = true -//! jump B1 -//! ``` - -use std::collections::BTreeMap; - -use crate::{ - analysis::{ - ConstValue, DefUseIndex, SsaEvaluator, SsaFunction, SsaInstruction, SsaOp, SsaVarId, - ValueRange, - }, - compiler::{ - pass::{PassCapability, SsaPass}, - CompilerContext, EventKind, EventLog, - }, - metadata::{token::Token, typesystem::PointerSize}, - utils::BitSet, - CilObject, Result, -}; - -/// Result of analyzing a potential opaque predicate. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum PredicateResult { - /// The predicate always evaluates to true. - AlwaysTrue, - /// The predicate always evaluates to false. - AlwaysFalse, - /// Cannot determine the predicate's value. - Unknown, -} - -impl PredicateResult { - /// Converts to an optional boolean. - /// - /// # Returns - /// - /// `Some(true)` for `AlwaysTrue`, `Some(false)` for `AlwaysFalse`, `None` for `Unknown`. - #[must_use] - pub fn as_bool(self) -> Option { - match self { - Self::AlwaysTrue => Some(true), - Self::AlwaysFalse => Some(false), - Self::Unknown => None, - } - } - - /// Negates the predicate result. - /// - /// `AlwaysTrue` becomes `AlwaysFalse` and vice versa. `Unknown` stays `Unknown`. - #[must_use] - pub fn negate(self) -> Self { - match self { - Self::AlwaysTrue => Self::AlwaysFalse, - Self::AlwaysFalse => Self::AlwaysTrue, - Self::Unknown => Self::Unknown, - } - } -} - -/// Result of analyzing a comparison for algebraic simplification. -/// -/// Unlike `PredicateResult` which determines if a comparison is always true/false, -/// this enum represents transformations that simplify comparisons while preserving -/// their runtime behavior. -#[derive(Debug, Clone)] -enum ComparisonSimplification { - /// Replace with a simpler comparison operation. - SimplerOp { new_op: SsaOp, reason: &'static str }, - /// Replace with a copy of another variable (e.g., `(cmp) == 1` → `cmp`). - Copy { - dest: SsaVarId, - src: SsaVarId, - reason: &'static str, - }, -} - -/// Cached definition information for efficient predicate analysis. -/// -/// Wraps [`DefUseIndex`] for basic definition lookups and augments it with -/// specialized tracking that `DefUseIndex` does not provide: phi-defined -/// variables, non-null provenance (from `NewObj`/`NewArr`/`Box`/`LoadToken`), -/// array-length provenance, and computed [`ValueRange`]s for constant and -/// non-negative variables. -struct DefinitionCache { - /// Index for definition lookups (block, instruction, operation). - index: DefUseIndex, - /// Variables defined by phi nodes. - phi_defs: BitSet, - /// Variables that are known to be non-null (produced by `NewObj`, `NewArr`, `Box`, or `LoadToken`). - non_null_vars: BitSet, - /// Variables that come from `ArrayLength` operations. - array_length_vars: BitSet, - /// Computed value ranges for variables (constants get exact ranges, array lengths get non-negative). - ranges: BTreeMap, -} - -impl DefinitionCache { - /// Builds the definition cache from an SSA function. - /// - /// Performs a single pass over all blocks to populate: - /// - `index`: delegated to [`DefUseIndex::build_with_ops`] for var-to-op mapping. - /// - `phi_defs`: bitset of variables defined by phi nodes (not tracked by `DefUseIndex`). - /// - `non_null_vars`: bitset of variables produced by `NewObj`, `NewArr`, `Box`, or `LoadToken`. - /// - `array_length_vars`: bitset of variables from `ArrayLength` ops. - /// - `ranges`: [`ValueRange::constant`] for `Const` ops, [`ValueRange::non_negative`] for - /// `ArrayLength` ops. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to build the cache from. - /// - /// # Returns - /// - /// A fully populated `DefinitionCache` with definition index, phi-def tracking, - /// non-null provenance, array-length provenance, and value ranges. - fn build(ssa: &SsaFunction) -> Self { - // Use DefUseIndex for basic definition tracking - let index = DefUseIndex::build_with_ops(ssa); - - let var_count = ssa.variable_count(); - let mut phi_defs = BitSet::new(var_count); - let mut non_null_vars = BitSet::new(var_count); - let mut array_length_vars = BitSet::new(var_count); - let mut ranges = BTreeMap::new(); - - for (_block_idx, block) in ssa.iter_blocks() { - // Process phi nodes (not covered by DefUseIndex) - for phi in block.phi_nodes() { - phi_defs.insert(phi.result().index()); - } - - // Process instructions for specialized tracking - for instr in block.instructions() { - let op = instr.op(); - if let Some(dest) = op.dest() { - // Track non-null producing operations and value ranges - match op { - SsaOp::NewObj { .. } - | SsaOp::NewArr { .. } - | SsaOp::Box { .. } - | SsaOp::LoadToken { .. } => { - // Non-null tracked separately (not a numeric range) - non_null_vars.insert(dest.index()); - } - SsaOp::ArrayLength { .. } => { - array_length_vars.insert(dest.index()); - ranges.insert(dest, ValueRange::non_negative()); - } - SsaOp::Const { value, .. } => { - if let Some(v) = value.as_i64() { - ranges.insert(dest, ValueRange::constant(v)); - } - } - _ => {} - } - } - } - } - - Self { - index, - phi_defs, - non_null_vars, - array_length_vars, - ranges, - } - } - - /// Gets the defining operation for a variable. - fn get_definition(&self, var: SsaVarId) -> Option<&SsaOp> { - self.index.def_op(var) - } - - /// Checks if a variable is defined by a phi node. - fn is_phi_defined(&self, var: SsaVarId) -> bool { - self.phi_defs.contains(var.index()) - } - - /// Checks if a variable is known to be non-null. - fn is_non_null(&self, var: SsaVarId) -> bool { - self.non_null_vars.contains(var.index()) - } - - /// Gets the value range for a variable. - fn get_range(&self, var: SsaVarId) -> Option<&ValueRange> { - self.ranges.get(&var) - } -} - -/// Opaque predicate detection and removal pass. -pub struct OpaquePredicatePass; - -impl Default for OpaquePredicatePass { - fn default() -> Self { - Self::new() - } -} - -impl OpaquePredicatePass { - /// Creates a new opaque predicate pass. - #[must_use] - pub fn new() -> Self { - Self - } - - /// Maximum recursion depth for nested predicate analysis. - /// - /// Each level corresponds to one SSA instruction defining a comparison result - /// that feeds into another comparison. 16 levels handles deeply nested opaque - /// predicates from advanced obfuscators (e.g., PureLogs multi-level chains) - /// while preventing stack overflow on pathological inputs. - const MAX_PREDICATE_DEPTH: usize = 16; - - /// Analyzes a predicate operation with full context, dispatching by `SsaOp` kind. - /// - /// Pattern-matching cascade: - /// 1. **Self-comparison** (`Ceq`/`Clt`/`Cgt` where `left == right`): immediate result. - /// 2. **Equality analysis** (`Ceq`): delegates to [`analyze_equality`](Self::analyze_equality) - /// for XOR==0, SUB==0, MUL*0==0, AND&0==0, number-theoretic, constant, null, and nested patterns. - /// 3. **Less-than / greater-than** (`Clt`/`Cgt`): delegates to range-based and constant analysis. - /// 4. **Zero-producing ops** (`Xor`/`Sub` with `left==right`): returns `Unknown` since the - /// result only becomes meaningful when used in a comparison (handled at that level). - /// 5. **Remainder / Multiplication / And**: delegates to specialized analyzers. - /// - /// Supports recursion up to [`MAX_PREDICATE_DEPTH`](Self::MAX_PREDICATE_DEPTH) for nested - /// predicates (e.g., `ceq(ceq(x, x), 1)`). - /// - /// # Arguments - /// - /// * `op` - The SSA operation to analyze (typically a comparison or arithmetic op). - /// * `cache` - Pre-built definition cache for efficient variable resolution. - /// * `depth` - Current recursion depth (0 at the top level). - /// - /// # Returns - /// - /// [`PredicateResult::AlwaysTrue`] or [`AlwaysFalse`](PredicateResult::AlwaysFalse) if the - /// predicate can be statically determined, [`Unknown`](PredicateResult::Unknown) otherwise. - fn analyze_predicate_with_cache( - op: &SsaOp, - cache: &DefinitionCache, - depth: usize, - ) -> PredicateResult { - if depth > Self::MAX_PREDICATE_DEPTH { - return PredicateResult::Unknown; - } - - match op { - // Self-comparison patterns - SsaOp::Ceq { left, right, .. } => { - if left == right { - return PredicateResult::AlwaysTrue; - } - Self::analyze_equality(*left, *right, cache, depth) - } - - SsaOp::Clt { - left, - right, - unsigned, - .. - } => { - if left == right { - return PredicateResult::AlwaysFalse; - } - Self::analyze_less_than(*left, *right, *unsigned, cache, depth) - } - - SsaOp::Cgt { - left, - right, - unsigned, - .. - } => { - if left == right { - return PredicateResult::AlwaysFalse; - } - Self::analyze_greater_than(*left, *right, *unsigned, cache, depth) - } - - // Operations that produce zero - SsaOp::Xor { left, right, .. } if left == right => { - // x ^ x = 0, handled when used in comparison - PredicateResult::Unknown - } - - SsaOp::Sub { left, right, .. } if left == right => { - // x - x = 0, handled when used in comparison - PredicateResult::Unknown - } - - SsaOp::Rem { left, right, .. } => Self::analyze_remainder(*left, *right, cache, depth), - - SsaOp::Mul { left, right, .. } => { - Self::analyze_multiplication(*left, *right, cache, depth) - } - - SsaOp::And { left, right, .. } => Self::analyze_and(*left, *right, cache, depth), - - _ => PredicateResult::Unknown, - } - } - - /// Analyzes an equality comparison (`Ceq`) for opaque predicate patterns. - /// - /// Checks the following patterns (each with symmetric left/right variants): - /// - `(x ^ x) == 0` -- XOR self-cancellation, always true. - /// - `(x - x) == 0` -- subtraction self-cancellation, always true. - /// - `(x * 0) == 0` or `(0 * x) == 0` -- zero-producing multiplication, always true. - /// - `(x & 0) == 0` or `(0 & x) == 0` -- zero-producing AND, always true. - /// - Number-theoretic: `(x*(x+1)) % 2 == 0` and factored forms, always true. - /// - Constant equality: both sides are constants with known values. - /// - Null checks: non-null variable (from `NewObj` etc.) compared to null, always false. - /// - **Nested analysis fallback**: if the left operand is itself a predicate (comparison), - /// recursively analyzes it. If the result is known and compared to 1, returns that result; - /// if compared to 0, returns the negation. - /// - /// # Arguments - /// - /// * `left` - Left operand of the `Ceq`. - /// * `right` - Right operand of the `Ceq`. - /// * `cache` - Definition cache for resolving variable definitions. - /// * `depth` - Current recursion depth for nested predicate analysis. - /// - /// # Returns - /// - /// [`PredicateResult::AlwaysTrue`] or [`AlwaysFalse`](PredicateResult::AlwaysFalse) if the - /// equality can be statically determined, [`Unknown`](PredicateResult::Unknown) otherwise. - fn analyze_equality( - left: SsaVarId, - right: SsaVarId, - cache: &DefinitionCache, - depth: usize, - ) -> PredicateResult { - let left_def = cache.get_definition(left); - let right_def = cache.get_definition(right); - - // Check for (x ^ x) == 0 pattern - if let Some(SsaOp::Xor { - left: xl, - right: xr, - .. - }) = left_def - { - if xl == xr { - if let Some(r) = right_def { - if Self::is_zero_constant(r) { - return PredicateResult::AlwaysTrue; - } - } - } - } - - // Symmetric check - if let Some(SsaOp::Xor { - left: xl, - right: xr, - .. - }) = right_def - { - if xl == xr { - if let Some(l) = left_def { - if Self::is_zero_constant(l) { - return PredicateResult::AlwaysTrue; - } - } - } - } - - // Check for (x - x) == 0 pattern - if let Some(SsaOp::Sub { - left: sl, - right: sr, - .. - }) = left_def - { - if sl == sr { - if let Some(r) = right_def { - if Self::is_zero_constant(r) { - return PredicateResult::AlwaysTrue; - } - } - } - } - - // Symmetric check - if let Some(SsaOp::Sub { - left: sl, - right: sr, - .. - }) = right_def - { - if sl == sr { - if let Some(l) = left_def { - if Self::is_zero_constant(l) { - return PredicateResult::AlwaysTrue; - } - } - } - } - - // Check for (x * 0) == 0 pattern - if Self::is_zero_producing_mul(left_def, cache) { - if let Some(r) = right_def { - if Self::is_zero_constant(r) { - return PredicateResult::AlwaysTrue; - } - } - } - - // Symmetric check - if Self::is_zero_producing_mul(right_def, cache) { - if let Some(l) = left_def { - if Self::is_zero_constant(l) { - return PredicateResult::AlwaysTrue; - } - } - } - - // Check for (x & 0) == 0 pattern - if Self::is_zero_producing_and(left_def, cache) { - if let Some(r) = right_def { - if Self::is_zero_constant(r) { - return PredicateResult::AlwaysTrue; - } - } - } - - // Symmetric check - if Self::is_zero_producing_and(right_def, cache) { - if let Some(l) = left_def { - if Self::is_zero_constant(l) { - return PredicateResult::AlwaysTrue; - } - } - } - - // Check for number-theoretic predicates that always evaluate to zero: - // (x * (x + 1)) % 2 == 0 — consecutive integer product is always even - // (x * x - x) % 2 == 0 — x²-x = x(x-1), consecutive product factored - if Self::is_always_even_expression(left_def, cache) { - if let Some(r) = right_def { - if Self::is_zero_constant(r) { - return PredicateResult::AlwaysTrue; - } - } - } - - // Check constant equality - if let (Some(SsaOp::Const { value: lval, .. }), Some(SsaOp::Const { value: rval, .. })) = - (left_def, right_def) - { - if let (Some(l), Some(r)) = (lval.as_i64(), rval.as_i64()) { - return if l == r { - PredicateResult::AlwaysTrue - } else { - PredicateResult::AlwaysFalse - }; - } - } - - // Check non-null equality with null - if cache.is_non_null(left) { - if let Some(r) = right_def { - if Self::is_null_constant(r) { - return PredicateResult::AlwaysFalse; - } - } - } - - if cache.is_non_null(right) { - if let Some(l) = left_def { - if Self::is_null_constant(l) { - return PredicateResult::AlwaysFalse; - } - } - } - - // Nested analysis - if let Some(left_op) = left_def { - let left_result = - Self::analyze_predicate_with_cache(left_op, cache, depth.saturating_add(1)); - if left_result != PredicateResult::Unknown { - if let Some(r) = right_def { - if Self::is_one_constant(r) { - return left_result; - } - if Self::is_zero_constant(r) { - return left_result.negate(); - } - } - } - } - - PredicateResult::Unknown - } - - /// Analyzes a less-than comparison (`Clt`) for opaque predicate patterns. - /// - /// Checks in order: - /// 1. **Constant comparison**: both operands are constants, evaluate directly (signed or unsigned). - /// 2. **Range-based**: if both operands have known [`ValueRange`]s, checks whether - /// `left.max < right.min` (always true) or `left.min >= right.max` (always false). - /// 3. **Left range vs. constant right**: uses [`ValueRange::always_less_than`]. - /// 4. **Unsigned bounds**: `x <.un 0` is always false (no unsigned value is less than zero). - /// 5. **Non-negative check**: if `left` is known non-negative (e.g., `ArrayLength`), - /// then `left < 0` is always false. - /// - /// # Arguments - /// - /// * `left` - Left operand of the comparison. - /// * `right` - Right operand of the comparison. - /// * `unsigned` - Whether this is an unsigned comparison (`clt.un`). - /// * `cache` - Definition cache for resolving variable definitions and ranges. - /// * `_depth` - Unused (less-than analysis is non-recursive). - /// - /// # Returns - /// - /// [`PredicateResult::AlwaysTrue`] or [`AlwaysFalse`](PredicateResult::AlwaysFalse) if the - /// less-than comparison can be statically determined, - /// [`Unknown`](PredicateResult::Unknown) otherwise. - fn analyze_less_than( - left: SsaVarId, - right: SsaVarId, - unsigned: bool, - cache: &DefinitionCache, - _depth: usize, - ) -> PredicateResult { - let left_def = cache.get_definition(left); - let right_def = cache.get_definition(right); - - // Constant comparison - if let (Some(SsaOp::Const { value: lval, .. }), Some(SsaOp::Const { value: rval, .. })) = - (left_def, right_def) - { - if unsigned { - if let (Some(l), Some(r)) = (lval.as_u64(), rval.as_u64()) { - return if l < r { - PredicateResult::AlwaysTrue - } else { - PredicateResult::AlwaysFalse - }; - } - } else if let (Some(l), Some(r)) = (lval.as_i64(), rval.as_i64()) { - return if l < r { - PredicateResult::AlwaysTrue - } else { - PredicateResult::AlwaysFalse - }; - } - } - - // Range-based analysis - if let Some(left_range) = cache.get_range(left) { - if let Some(right_range) = cache.get_range(right) { - // left.max < right.min => always true - if let (Some(l_max), Some(r_min)) = (left_range.max(), right_range.min()) { - if l_max < r_min { - return PredicateResult::AlwaysTrue; - } - } - // left.min >= right.max => always false - if let (Some(l_min), Some(r_max)) = (left_range.min(), right_range.max()) { - if l_min >= r_max { - return PredicateResult::AlwaysFalse; - } - } - } - - // Check if left < constant - if let Some(SsaOp::Const { value: rval, .. }) = right_def { - if let Some(r) = rval.as_i64() { - if let Some(result) = left_range.always_less_than(r) { - return if result { - PredicateResult::AlwaysTrue - } else { - PredicateResult::AlwaysFalse - }; - } - } - } - } - - // Unsigned comparison: x < 0 is always false - if unsigned { - if let Some(SsaOp::Const { value: rval, .. }) = right_def { - if rval.as_u64() == Some(0) { - return PredicateResult::AlwaysFalse; - } - } - } - - // Non-negative < 0 is always false - if let Some(left_range) = cache.get_range(left) { - if left_range.is_always_non_negative() { - if let Some(SsaOp::Const { value: rval, .. }) = right_def { - if rval.as_i64() == Some(0) { - return PredicateResult::AlwaysFalse; - } - } - } - } - - PredicateResult::Unknown - } - - /// Analyzes a greater-than comparison (`Cgt`) for opaque predicate patterns. - /// - /// Checks in order: - /// 1. **Constant comparison**: both operands are constants, evaluate directly (signed or unsigned). - /// 2. **Range-based**: if both operands have known [`ValueRange`]s, checks whether - /// `left.min > right.max` (always true) or `left.max <= right.min` (always false). - /// 3. **Left range vs. constant right**: uses [`ValueRange::always_greater_than`]. - /// 4. **Unsigned bounds**: `0 >.un x` is always false (zero is never greater than any unsigned value). - /// 5. **Non-negative vs. negative**: if `left` is known non-negative and `right` is a - /// negative constant, returns always true. - /// - /// # Arguments - /// - /// * `left` - Left operand of the comparison. - /// * `right` - Right operand of the comparison. - /// * `unsigned` - Whether this is an unsigned comparison (`cgt.un`). - /// * `cache` - Definition cache for resolving variable definitions and ranges. - /// * `_depth` - Unused (greater-than analysis is non-recursive). - /// - /// # Returns - /// - /// [`PredicateResult::AlwaysTrue`] or [`AlwaysFalse`](PredicateResult::AlwaysFalse) if the - /// greater-than comparison can be statically determined, - /// [`Unknown`](PredicateResult::Unknown) otherwise. - fn analyze_greater_than( - left: SsaVarId, - right: SsaVarId, - unsigned: bool, - cache: &DefinitionCache, - _depth: usize, - ) -> PredicateResult { - let left_def = cache.get_definition(left); - let right_def = cache.get_definition(right); - - // Constant comparison - if let (Some(SsaOp::Const { value: lval, .. }), Some(SsaOp::Const { value: rval, .. })) = - (left_def, right_def) - { - if unsigned { - if let (Some(l), Some(r)) = (lval.as_u64(), rval.as_u64()) { - return if l > r { - PredicateResult::AlwaysTrue - } else { - PredicateResult::AlwaysFalse - }; - } - } else if let (Some(l), Some(r)) = (lval.as_i64(), rval.as_i64()) { - return if l > r { - PredicateResult::AlwaysTrue - } else { - PredicateResult::AlwaysFalse - }; - } - } - - // Range-based analysis - if let Some(left_range) = cache.get_range(left) { - if let Some(right_range) = cache.get_range(right) { - // left.min > right.max => always true - if let (Some(l_min), Some(r_max)) = (left_range.min(), right_range.max()) { - if l_min > r_max { - return PredicateResult::AlwaysTrue; - } - } - // left.max <= right.min => always false - if let (Some(l_max), Some(r_min)) = (left_range.max(), right_range.min()) { - if l_max <= r_min { - return PredicateResult::AlwaysFalse; - } - } - } - - // Check if left > constant - if let Some(SsaOp::Const { value: rval, .. }) = right_def { - if let Some(r) = rval.as_i64() { - if let Some(result) = left_range.always_greater_than(r) { - return if result { - PredicateResult::AlwaysTrue - } else { - PredicateResult::AlwaysFalse - }; - } - } - } - } - - // Unsigned: 0 > x is always false - if unsigned { - if let Some(SsaOp::Const { value: lval, .. }) = left_def { - if lval.as_u64() == Some(0) { - return PredicateResult::AlwaysFalse; - } - } - } - - // Non-negative value >= 0 is always true (x > -1 equivalent) - if let Some(left_range) = cache.get_range(left) { - if left_range.is_always_non_negative() { - if let Some(SsaOp::Const { value: rval, .. }) = right_def { - if rval.as_i64().is_some_and(|r| r < 0) { - return PredicateResult::AlwaysTrue; - } - } - } - } - - PredicateResult::Unknown - } - - /// Analyzes a remainder operation (`Rem`) for `x % 1` which always produces zero. - /// - /// Returns `Unknown` rather than `AlwaysTrue`/`AlwaysFalse` because the zero result - /// is only meaningful when subsequently compared (handled by the equality analysis). - /// - /// # Arguments - /// - /// * `_left` - Left operand of the remainder (unused; any value mod 1 is zero). - /// * `right` - Right operand of the remainder (checked for constant 1). - /// * `cache` - Definition cache for resolving variable definitions. - /// * `_depth` - Unused (remainder analysis is non-recursive). - /// - /// # Returns - /// - /// Always returns [`PredicateResult::Unknown`] because the zero result is only - /// meaningful when used in a subsequent comparison. - fn analyze_remainder( - _left: SsaVarId, - right: SsaVarId, - cache: &DefinitionCache, - _depth: usize, - ) -> PredicateResult { - // x % 1 == 0 is always true - if let Some(SsaOp::Const { value: rval, .. }) = cache.get_definition(right) { - if rval.as_i64() == Some(1) { - // Result is always 0 - return PredicateResult::Unknown; // Handled when compared to 0 - } - } - PredicateResult::Unknown - } - - /// Analyzes a multiplication (`Mul`) for zero-producing patterns (`x * 0` or `0 * x`). - /// - /// Returns `Unknown` because the zero result is only meaningful when subsequently - /// compared (handled by the equality analysis at the comparison level). - /// - /// # Arguments - /// - /// * `left` - Left operand of the multiplication. - /// * `right` - Right operand of the multiplication. - /// * `cache` - Definition cache for resolving variable definitions. - /// * `_depth` - Unused (multiplication analysis is non-recursive). - /// - /// # Returns - /// - /// Always returns [`PredicateResult::Unknown`] because the zero result is only - /// meaningful when used in a subsequent comparison. - fn analyze_multiplication( - left: SsaVarId, - right: SsaVarId, - cache: &DefinitionCache, - _depth: usize, - ) -> PredicateResult { - // x * 0 = 0 - if let Some(SsaOp::Const { value: lval, .. }) = cache.get_definition(left) { - if lval.is_zero() { - return PredicateResult::Unknown; // Result is 0 - } - } - if let Some(SsaOp::Const { value: rval, .. }) = cache.get_definition(right) { - if rval.is_zero() { - return PredicateResult::Unknown; // Result is 0 - } - } - PredicateResult::Unknown - } - - /// Analyzes a bitwise AND (`And`) for zero-producing patterns (`x & 0` or `0 & x`). - /// - /// Returns `Unknown` because the zero result is only meaningful when subsequently - /// compared (handled by the equality analysis at the comparison level). - /// - /// # Arguments - /// - /// * `left` - Left operand of the AND. - /// * `right` - Right operand of the AND. - /// * `cache` - Definition cache for resolving variable definitions. - /// * `_depth` - Unused (AND analysis is non-recursive). - /// - /// # Returns - /// - /// Always returns [`PredicateResult::Unknown`] because the zero result is only - /// meaningful when used in a subsequent comparison. - fn analyze_and( - left: SsaVarId, - right: SsaVarId, - cache: &DefinitionCache, - _depth: usize, - ) -> PredicateResult { - // x & 0 = 0 - if let Some(SsaOp::Const { value: lval, .. }) = cache.get_definition(left) { - if lval.is_zero() { - return PredicateResult::Unknown; - } - } - if let Some(SsaOp::Const { value: rval, .. }) = cache.get_definition(right) { - if rval.is_zero() { - return PredicateResult::Unknown; - } - } - PredicateResult::Unknown - } - - /// Checks if an operation produces a constant zero. - fn is_zero_constant(op: &SsaOp) -> bool { - matches!(op, SsaOp::Const { value, .. } if value.is_zero()) - } - - /// Checks if an operation produces a constant one. - fn is_one_constant(op: &SsaOp) -> bool { - matches!(op, SsaOp::Const { value, .. } if value.is_one()) - } - - /// Checks if an operation produces a null constant. - fn is_null_constant(op: &SsaOp) -> bool { - matches!(op, SsaOp::Const { value, .. } if value.is_null()) - } - - /// Checks if an operation produces a constant -1. - fn is_minus_one_constant(op: &SsaOp) -> bool { - matches!(op, SsaOp::Const { value, .. } if value.is_minus_one()) - } - - /// Returns `true` if the operation is a `Mul` where either operand is a constant zero. - /// - /// # Arguments - /// - /// * `op` - The operation to check, or `None` if the variable has no definition. - /// * `cache` - Definition cache for resolving the multiplication operands. - /// - /// # Returns - /// - /// `true` if `op` is a `Mul` with at least one constant-zero operand, `false` otherwise. - fn is_zero_producing_mul(op: Option<&SsaOp>, cache: &DefinitionCache) -> bool { - if let Some(SsaOp::Mul { left, right, .. }) = op { - if let Some(l) = cache.get_definition(*left) { - if Self::is_zero_constant(l) { - return true; - } - } - if let Some(r) = cache.get_definition(*right) { - if Self::is_zero_constant(r) { - return true; - } - } - } - false - } - - /// Returns `true` if the operation is an `And` where either operand is a constant zero. - /// - /// # Arguments - /// - /// * `op` - The operation to check, or `None` if the variable has no definition. - /// * `cache` - Definition cache for resolving the AND operands. - /// - /// # Returns - /// - /// `true` if `op` is an `And` with at least one constant-zero operand, `false` otherwise. - fn is_zero_producing_and(op: Option<&SsaOp>, cache: &DefinitionCache) -> bool { - if let Some(SsaOp::And { left, right, .. }) = op { - if let Some(l) = cache.get_definition(*left) { - if Self::is_zero_constant(l) { - return true; - } - } - if let Some(r) = cache.get_definition(*right) { - if Self::is_zero_constant(r) { - return true; - } - } - } - false - } - - /// Checks if an operation is an expression modulo 2 that always evaluates to 0. - /// - /// Detects number-theoretic opaque predicates based on the mathematical - /// property that the product of two consecutive integers is always even: - /// - /// - `(x * (x + 1)) % 2` — direct consecutive product - /// - `(x * (x - 1)) % 2` — reversed consecutive product - /// - `(x * x - x) % 2` — factored form: x^2-x = x(x-1) - /// - `(x * x + x) % 2` — factored form: x^2+x = x(x+1) - /// - /// # Arguments - /// - /// * `op` - The operation to check, or `None` if the variable has no definition. - /// * `cache` - Definition cache for resolving operand definitions. - /// - /// # Returns - /// - /// `true` if the expression is a `Rem` by 2 whose dividend is always even, `false` otherwise. - fn is_always_even_expression(op: Option<&SsaOp>, cache: &DefinitionCache) -> bool { - let Some(SsaOp::Rem { - left: rem_left, - right: rem_right, - .. - }) = op - else { - return false; - }; - - // Divisor must be 2 - let is_mod2 = cache - .get_definition(*rem_right) - .is_some_and(|d| matches!(d, SsaOp::Const { value, .. } if value.as_i64() == Some(2))); - if !is_mod2 { - return false; - } - - let dividend_def = cache.get_definition(*rem_left); - - // Pattern 1: x * (x +/- 1) — consecutive product - if let Some(SsaOp::Mul { - left: mul_left, - right: mul_right, - .. - }) = dividend_def - { - if Self::is_consecutive_pair(*mul_left, *mul_right, cache) { - return true; - } - } - - // Pattern 2: (x * x) -/+ x — factored consecutive product - // x^2-x = x(x-1), x^2+x = x(x+1), both always even - if let Some(SsaOp::Sub { - left: op_left, - right: op_right, - .. - }) - | Some(SsaOp::Add { - left: op_left, - right: op_right, - .. - }) = dividend_def - { - if Self::is_self_square(*op_left, *op_right, cache) - || Self::is_self_square(*op_right, *op_left, cache) - { - return true; - } - } - - false - } - - /// Checks if `square_var` is defined as `other * other` (i.e., `other^2`). - /// - /// # Arguments - /// - /// * `square_var` - The variable suspected to be a square. - /// * `other` - The variable that should appear as both operands of the multiplication. - /// * `cache` - Definition cache for resolving the definition of `square_var`. - /// - /// # Returns - /// - /// `true` if `square_var` is defined as `Mul { left: other, right: other }`, `false` otherwise. - fn is_self_square(square_var: SsaVarId, other: SsaVarId, cache: &DefinitionCache) -> bool { - matches!( - cache.get_definition(square_var), - Some(SsaOp::Mul { left, right, .. }) if *left == other && *right == other - ) - } - - /// Checks if two variables form a consecutive integer pair (`n` and `n+1`). - /// - /// Performs three symmetric checks: - /// - `b = a + 1` (either operand order of the `Add`). - /// - `a = b + 1` (symmetric: `a` is the incremented one). - /// - `b = a - (-1)` (subtraction of -1 is equivalent to adding 1). - /// - /// # Arguments - /// - /// * `a` - First variable of the potential consecutive pair. - /// * `b` - Second variable of the potential consecutive pair. - /// * `cache` - Definition cache for resolving variable definitions. - /// - /// # Returns - /// - /// `true` if one variable is defined as the other plus one, `false` otherwise. - fn is_consecutive_pair(a: SsaVarId, b: SsaVarId, cache: &DefinitionCache) -> bool { - // Check if b = a + 1 - if let Some(SsaOp::Add { - left: add_left, - right: add_right, - .. - }) = cache.get_definition(b) - { - if *add_left == a { - if let Some(SsaOp::Const { value: rval, .. }) = cache.get_definition(*add_right) { - if rval.as_i64() == Some(1) { - return true; - } - } - } - if *add_right == a { - if let Some(SsaOp::Const { value: lval, .. }) = cache.get_definition(*add_left) { - if lval.as_i64() == Some(1) { - return true; - } - } - } - } - - // Check if a = b + 1 (symmetric) - if let Some(SsaOp::Add { - left: add_left, - right: add_right, - .. - }) = cache.get_definition(a) - { - if *add_left == b { - if let Some(SsaOp::Const { value: rval, .. }) = cache.get_definition(*add_right) { - if rval.as_i64() == Some(1) { - return true; - } - } - } - if *add_right == b { - if let Some(SsaOp::Const { value: lval, .. }) = cache.get_definition(*add_left) { - if lval.as_i64() == Some(1) { - return true; - } - } - } - } - - // Check if b = a - (-1) which is also a + 1 - if let Some(SsaOp::Sub { - left: sub_left, - right: sub_right, - .. - }) = cache.get_definition(b) - { - if *sub_left == a { - if let Some(SsaOp::Const { value: rval, .. }) = cache.get_definition(*sub_right) { - if rval.as_i64() == Some(-1) { - return true; - } - } - } - } - - false - } - - /// Analyzes a branch condition variable to determine if it is an opaque predicate. - /// - /// Follows `Copy` chains iteratively with a [`BitSet`]-based cycle detector to handle - /// SSA copies from phi nodes or obfuscated control flow. At each step: - /// 1. If the variable has no definition in `DefUseIndex` but is phi-defined, checks - /// range info for a constant-zero equivalence (all-zero phi = always false). - /// 2. Delegates to [`analyze_predicate_with_cache`](Self::analyze_predicate_with_cache) for - /// comparison operations. - /// 3. For `Copy` ops, advances to the source variable and continues the loop. - /// 4. For all other operations, falls back to [`analyze_branch_op`](Self::analyze_branch_op) - /// which checks direct truthiness of arithmetic and constant operations. - /// - /// # Arguments - /// - /// * `condition` - The SSA variable used as the branch condition. - /// * `cache` - Definition cache for resolving the condition's definition chain. - /// - /// # Returns - /// - /// [`PredicateResult::AlwaysTrue`] if the branch always takes the true path, - /// [`AlwaysFalse`](PredicateResult::AlwaysFalse) if it always takes the false path, - /// [`Unknown`](PredicateResult::Unknown) if the condition cannot be statically resolved. - fn analyze_branch(condition: SsaVarId, cache: &DefinitionCache) -> PredicateResult { - // Follow Copy chain iteratively with cycle detection to prevent infinite recursion. - // This is needed because SSA can have Copy cycles (e.g., from phi nodes or - // obfuscated control flow patterns). - let mut current = condition; - let mut visited = BitSet::new(cache.phi_defs.len()); - - loop { - // Cycle detection: if we've seen this variable before, bail out - if !visited.insert(current.index()) { - return PredicateResult::Unknown; - } - - let Some(cond_op) = cache.get_definition(current) else { - // Check if it's a phi node - analyze all operands - if cache.is_phi_defined(current) { - // For phi nodes, we'd need to check if all operands lead to the same result - // This is complex, so we return Unknown for now unless we have range info - if let Some(range) = cache.get_range(current) { - if let Some(result) = range.always_equal_to(0) { - return if result { - PredicateResult::AlwaysFalse - } else { - PredicateResult::AlwaysTrue - }; - } - } - } - return PredicateResult::Unknown; - }; - - // First, check if it's a direct comparison predicate - let predicate_result = Self::analyze_predicate_with_cache(cond_op, cache, 0); - if predicate_result != PredicateResult::Unknown { - return predicate_result; - } - - // Check if the condition is a Copy - trace through to the source iteratively - if let SsaOp::Copy { src, .. } = cond_op { - current = *src; - continue; - } - - // Not a Copy, break out and analyze the operation - return Self::analyze_branch_op(cond_op, cache); - } - } - - /// Analyzes a non-Copy, non-comparison operation for direct truthiness in a branch. - /// - /// Called after the `Copy` chain has been resolved by [`analyze_branch`](Self::analyze_branch). - /// Checks: - /// - `x ^ x` = 0 (always false in `brtrue`). - /// - `x - x` = 0 (always false). - /// - `x & 0` or `x * 0` = 0 (always false). - /// - `x | -1` = all-bits-set (always true). - /// - `Const`: zero/null/false is always false; non-zero numeric or `true` is always true. - /// - /// # Arguments - /// - /// * `cond_op` - The resolved operation producing the branch condition value. - /// * `cache` - Definition cache for resolving operands of the condition operation. - /// - /// # Returns - /// - /// [`PredicateResult::AlwaysTrue`] if the operation always produces a non-zero value, - /// [`AlwaysFalse`](PredicateResult::AlwaysFalse) if it always produces zero, - /// [`Unknown`](PredicateResult::Unknown) if the truthiness cannot be determined. - fn analyze_branch_op(cond_op: &SsaOp, cache: &DefinitionCache) -> PredicateResult { - // Check operations that produce known zero values - match cond_op { - // x ^ x = 0, so brtrue on this result never jumps - SsaOp::Xor { left, right, .. } if left == right => PredicateResult::AlwaysFalse, - - // x - x = 0, so brtrue on this result never jumps - SsaOp::Sub { left, right, .. } if left == right => PredicateResult::AlwaysFalse, - - // x & 0 = 0, x * 0 = 0 - SsaOp::And { left, right, .. } | SsaOp::Mul { left, right, .. } => { - let is_left_zero = cache - .get_definition(*left) - .is_some_and(Self::is_zero_constant); - let is_right_zero = cache - .get_definition(*right) - .is_some_and(Self::is_zero_constant); - - if is_left_zero || is_right_zero { - PredicateResult::AlwaysFalse - } else { - PredicateResult::Unknown - } - } - - // x | -1 = -1 (all bits set), so brtrue always jumps - SsaOp::Or { left, right, .. } => { - let is_left_minus_one = cache - .get_definition(*left) - .is_some_and(Self::is_minus_one_constant); - let is_right_minus_one = cache - .get_definition(*right) - .is_some_and(Self::is_minus_one_constant); - - if is_left_minus_one || is_right_minus_one { - PredicateResult::AlwaysTrue - } else { - PredicateResult::Unknown - } - } - - // Constant values: 0/null/false is always false, non-zero is always true - SsaOp::Const { value, .. } => { - if value.is_zero() || value.is_null() { - PredicateResult::AlwaysFalse - } else if value.as_i64().is_some() || value.as_bool().is_some() { - // Non-zero numeric or true boolean - PredicateResult::AlwaysTrue - } else { - PredicateResult::Unknown - } - } - - // All other operations have unknown truthiness - // Note: ArrayLength is always >= 0, but we can't prove non-empty - _ => PredicateResult::Unknown, - } - } - - /// Analyzes a comparison operation for algebraic simplification opportunities. - /// - /// This checks for patterns like: - /// - `(x - y) == 0` → `x == y` - /// - `(x - y) < 0` → `x < y` - /// - `(x - y) > 0` → `x > y` - /// - `(x ^ y) == 0` → `x == y` - /// - `(cmp) == 1` → `cmp` - /// - /// # Arguments - /// - /// * `op` - The SSA comparison operation to analyze (`Ceq`, `Clt`, or `Cgt`). - /// * `cache` - Definition cache for resolving operand definitions. - /// - /// # Returns - /// - /// `Some(ComparisonSimplification)` if the comparison can be algebraically simplified, - /// `None` if no simplification applies or the operation is not a comparison. - fn analyze_comparison_simplification( - op: &SsaOp, - cache: &DefinitionCache, - ) -> Option { - match op { - SsaOp::Ceq { dest, left, right } => { - Self::analyze_ceq_simplification(*dest, *left, *right, cache) - } - SsaOp::Clt { - dest, - left, - right, - unsigned, - } => Self::analyze_clt_simplification(*dest, *left, *right, *unsigned, cache), - SsaOp::Cgt { - dest, - left, - right, - unsigned, - } => Self::analyze_cgt_simplification(*dest, *left, *right, *unsigned, cache), - _ => None, - } - } - - /// Checks if a variable is defined as a constant zero. - fn is_zero_var(var: SsaVarId, cache: &DefinitionCache) -> bool { - cache - .get_definition(var) - .is_some_and(Self::is_zero_constant) - } - - /// Checks if a variable is defined as a constant with value 1. - fn is_one_var(var: SsaVarId, cache: &DefinitionCache) -> bool { - cache.get_definition(var).is_some_and(Self::is_one_constant) - } - - /// Analyzes a `Ceq` operation for algebraic simplification. - /// - /// Detects three patterns: - /// - `(x - y) == 0` simplifies to `x == y` (subtraction-zero, skip self-subtraction). - /// - `(x ^ y) == 0` simplifies to `x == y` (XOR-zero, skip self-XOR). - /// - `(cmp) == 1` simplifies to `Copy(cmp)` when the other operand is a `Ceq`/`Clt`/`Cgt` - /// result, since CIL comparisons already produce 0 or 1. - /// - /// # Arguments - /// - /// * `dest` - Destination variable of the `Ceq` (preserved in the simplified op). - /// * `left` - Left operand of the `Ceq`. - /// * `right` - Right operand of the `Ceq`. - /// * `cache` - Definition cache for resolving operand definitions. - /// - /// # Returns - /// - /// `Some(ComparisonSimplification)` if a simplification pattern matches, `None` otherwise. - fn analyze_ceq_simplification( - dest: SsaVarId, - left: SsaVarId, - right: SsaVarId, - cache: &DefinitionCache, - ) -> Option { - // Check if comparing to zero - let (other_var, is_comparing_to_zero) = if Self::is_zero_var(right, cache) { - (left, true) - } else if Self::is_zero_var(left, cache) { - (right, true) - } else { - (left, false) - }; - - if is_comparing_to_zero { - if let Some(def_op) = cache.get_definition(other_var) { - // Pattern: (x - y) == 0 → x == y - if let SsaOp::Sub { - left: sub_left, - right: sub_right, - .. - } = def_op - { - // Skip self-subtraction - that's handled by PredicateResult (always true) - if sub_left != sub_right { - return Some(ComparisonSimplification::SimplerOp { - new_op: SsaOp::Ceq { - dest, - left: *sub_left, - right: *sub_right, - }, - reason: "(x - y) == 0 simplified to x == y", - }); - } - } - - // Pattern: (x ^ y) == 0 → x == y - if let SsaOp::Xor { - left: xor_left, - right: xor_right, - .. - } = def_op - { - // Skip self-XOR - that's handled by PredicateResult (always true) - if xor_left != xor_right { - return Some(ComparisonSimplification::SimplerOp { - new_op: SsaOp::Ceq { - dest, - left: *xor_left, - right: *xor_right, - }, - reason: "(x ^ y) == 0 simplified to x == y", - }); - } - } - } - } - - // Check if comparing to one (true in CIL) - let (other_var, is_comparing_to_one) = if Self::is_one_var(right, cache) { - (left, true) - } else if Self::is_one_var(left, cache) { - (right, true) - } else { - (left, false) - }; - - if is_comparing_to_one { - if let Some(def_op) = cache.get_definition(other_var) { - // Pattern: (cmp) == 1 → copy cmp - if matches!( - def_op, - SsaOp::Ceq { .. } | SsaOp::Clt { .. } | SsaOp::Cgt { .. } - ) { - return Some(ComparisonSimplification::Copy { - dest, - src: other_var, - reason: "(cmp) == 1 simplified to cmp", - }); - } - } - } - - None - } - - /// Analyzes a `Clt` operation for algebraic simplification. - /// - /// Detects `(x - y) < 0` and simplifies to `x < y` (signed only; unsigned subtraction - /// has different overflow semantics). Self-subtraction is skipped since it is handled - /// as a constant predicate (`AlwaysFalse`). - /// - /// # Arguments - /// - /// * `dest` - Destination variable of the `Clt` (preserved in the simplified op). - /// * `left` - Left operand of the `Clt`. - /// * `right` - Right operand of the `Clt`. - /// * `unsigned` - Whether this is an unsigned comparison; if `true`, no simplification is attempted. - /// * `cache` - Definition cache for resolving operand definitions. - /// - /// # Returns - /// - /// `Some(ComparisonSimplification::SimplerOp)` if `(x - y) < 0` is detected, `None` otherwise. - fn analyze_clt_simplification( - dest: SsaVarId, - left: SsaVarId, - right: SsaVarId, - unsigned: bool, - cache: &DefinitionCache, - ) -> Option { - // Only handle signed comparisons for subtraction patterns - // (unsigned subtraction has different overflow semantics) - if unsigned { - return None; - } - - // Pattern: (x - y) < 0 → x < y - if Self::is_zero_var(right, cache) { - if let Some(SsaOp::Sub { - left: sub_left, - right: sub_right, - .. - }) = cache.get_definition(left) - { - // Skip self-subtraction - that's handled by PredicateResult (always false) - if sub_left != sub_right { - return Some(ComparisonSimplification::SimplerOp { - new_op: SsaOp::Clt { - dest, - left: *sub_left, - right: *sub_right, - unsigned, - }, - reason: "(x - y) < 0 simplified to x < y", - }); - } - } - } - - None - } - - /// Analyzes a `Cgt` operation for algebraic simplification. - /// - /// Detects `(x - y) > 0` and simplifies to `x > y` (signed only; unsigned subtraction - /// has different overflow semantics). Self-subtraction is skipped since it is handled - /// as a constant predicate (`AlwaysFalse`). - /// - /// # Arguments - /// - /// * `dest` - Destination variable of the `Cgt` (preserved in the simplified op). - /// * `left` - Left operand of the `Cgt`. - /// * `right` - Right operand of the `Cgt`. - /// * `unsigned` - Whether this is an unsigned comparison; if `true`, no simplification is attempted. - /// * `cache` - Definition cache for resolving operand definitions. - /// - /// # Returns - /// - /// `Some(ComparisonSimplification::SimplerOp)` if `(x - y) > 0` is detected, `None` otherwise. - fn analyze_cgt_simplification( - dest: SsaVarId, - left: SsaVarId, - right: SsaVarId, - unsigned: bool, - cache: &DefinitionCache, - ) -> Option { - // Only handle signed comparisons for subtraction patterns - if unsigned { - return None; - } - - // Pattern: (x - y) > 0 → x > y - if Self::is_zero_var(right, cache) { - if let Some(SsaOp::Sub { - left: sub_left, - right: sub_right, - .. - }) = cache.get_definition(left) - { - // Skip self-subtraction - that's handled by PredicateResult (always false) - if sub_left != sub_right { - return Some(ComparisonSimplification::SimplerOp { - new_op: SsaOp::Cgt { - dest, - left: *sub_left, - right: *sub_right, - unsigned, - }, - reason: "(x - y) > 0 simplified to x > y", - }); - } - } - } - - None - } - - /// Fallback evaluator for branch conditions that pattern matching cannot resolve. - /// - /// Creates an [`SsaEvaluator`] and runs a forward pass over all blocks from 0 through - /// `block_idx`, accumulating concrete values via dataflow propagation. If the condition - /// variable resolves to a constant after evaluation, returns `AlwaysTrue` (non-zero) or - /// `AlwaysFalse` (zero). - /// - /// This catches predicates that require multi-step constant propagation across blocks - /// (e.g., a value computed in block 0 that flows through assignments to block 5's branch). - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to evaluate. - /// * `condition` - The branch condition variable to resolve. - /// * `block_idx` - The block containing the branch (evaluation covers blocks 0..=block_idx). - /// * `ptr_size` - Pointer size for the evaluator (affects address arithmetic). - /// - /// # Returns - /// - /// [`PredicateResult::AlwaysTrue`] or [`AlwaysFalse`](PredicateResult::AlwaysFalse) if the - /// evaluator resolves the condition to a constant. [`Unknown`](PredicateResult::Unknown) if - /// the value depends on runtime inputs, loops, or unresolvable phi operands. - fn evaluate_with_tracked( - ssa: &SsaFunction, - condition: SsaVarId, - block_idx: usize, - ptr_size: PointerSize, - ) -> PredicateResult { - let mut evaluator = SsaEvaluator::new(ssa, ptr_size); - - // Evaluate all blocks up to and including the current block. - // We use a simple forward pass - in complex cases with loops, - // this may not capture all values, but it handles linear flows. - for idx in 0..=block_idx { - // For blocks that precede our target, we can evaluate them - // to build up the value state - evaluator.evaluate_block(idx); - } - - // Check if we have a concrete value for the condition - match evaluator.get(condition) { - Some(expr) if expr.is_constant() => { - if expr.as_constant().is_some_and(ConstValue::is_zero) { - PredicateResult::AlwaysFalse - } else { - PredicateResult::AlwaysTrue - } - } - Some(_) | None => PredicateResult::Unknown, - } - } - - /// Detects phi nodes where every operand resolves to the same constant value. - /// - /// Iterates over all phi nodes in all blocks. For each phi, looks up each operand's - /// defining operation: if all are `Const` with identical values, records the mapping - /// from the phi result variable to that constant. These entries are later used to - /// replace the phi with a `Const` instruction and to resolve branch conditions that - /// depend on phi-defined variables. - /// - /// # Arguments - /// - /// * `ssa` - The SSA function whose phi nodes are analyzed. - /// - /// # Returns - /// - /// A map from phi result variable to the constant value that all operands agree on. - /// Empty if no phi nodes have all-constant, all-identical operands. - fn analyze_phi_constants(ssa: &SsaFunction) -> BTreeMap { - let mut phi_constants = BTreeMap::new(); - - for block in ssa.blocks() { - for phi in block.phi_nodes() { - let operands: Vec<_> = phi.operands().iter().collect(); - let Some(first_operand) = operands.first() else { - continue; - }; - - // Check if all operands come from the same constant - let first_val = first_operand.value(); - let mut all_same_const = true; - let mut const_value = None; - - for operand in &operands { - let var = operand.value(); - // Look up the definition - if let Some(op) = ssa.get_definition(var) { - if let SsaOp::Const { value, .. } = op { - if const_value.is_none() { - const_value = Some(value.clone()); - } else if const_value.as_ref() != Some(value) { - all_same_const = false; - break; - } - } else { - all_same_const = false; - break; - } - } else if var != first_val { - all_same_const = false; - break; - } - } - - if all_same_const { - if let Some(value) = const_value { - phi_constants.insert(phi.result(), value); - } - } - } - } - - phi_constants - } -} - -impl SsaPass for OpaquePredicatePass { - fn name(&self) -> &'static str { - "opaque-predicate-removal" - } - - fn description(&self) -> &'static str { - "Detects and removes opaque predicates (always-true/false conditions)" - } - - fn provides(&self) -> &[PassCapability] { - &[PassCapability::SimplifiedPredicates] - } - - fn requires(&self) -> &[PassCapability] { - &[PassCapability::RestoredControlFlow] - } - - /// Runs opaque predicate detection and removal on a single method. - /// - /// Operates in two phases: - /// - /// **Collection phase** (read-only scan of all blocks): - /// 1. **Branch simplifications**: for each `Branch` terminator, checks phi-constants first, - /// then pattern-matching via [`analyze_branch`](OpaquePredicatePass::analyze_branch), then - /// [`evaluate_with_tracked`](OpaquePredicatePass::evaluate_with_tracked) as a dataflow fallback. - /// 2. **Comparison replacements**: comparison instructions (`Ceq`/`Clt`/`Cgt`) that are - /// opaque predicates get replaced with `Const(true)` or `Const(false)`. - /// 3. **Comparison simplifications**: algebraic rewrites like `(x-y)==0` to `x==y`. - /// 4. **Phi replacements**: phi nodes where all operands are the same constant. - /// - /// **Apply phase** (mutates the SSA in collected order): - /// - Branch terminators become `Jump` to the always-taken target. - /// - Comparison instructions become `Const` or `Copy` operations. - /// - Phi nodes become `Const` instructions (inserted at block start, phi removed). - /// - /// # Arguments - /// - /// * `ssa` - The SSA function to analyze and mutate. - /// * `method_token` - Token of the method (for event logging). - /// * `ctx` - Compiler context for event recording. - /// * `assembly` - Assembly reference for pointer size detection (used by `SsaEvaluator`). - /// - /// # Returns - /// - /// `Ok(true)` if structural changes were made (branches simplified, comparisons resolved). - /// `Ok(false)` if only phi-only changes occurred or no changes at all. Phi-only changes - /// return `false` to avoid triggering an immediate SSA rebuild that would recreate the phi. - /// - /// # Errors - /// - /// Returns an error if SSA mutation fails (should not occur in practice). - fn run_on_method( - &self, - ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - assembly: &CilObject, - ) -> Result { - let changes = EventLog::new(); - - // Build definition cache for efficient lookup - let cache = DefinitionCache::build(ssa); - - // Analyze phi nodes for constant values - let phi_constants = Self::analyze_phi_constants(ssa); - - // Collect branches to simplify - let mut branch_simplifications: Vec<(usize, usize, bool)> = Vec::new(); - - // Collect comparison replacements (opaque predicates that become constant true/false) - let mut comparison_replacements: Vec<(usize, usize, SsaVarId, bool)> = Vec::new(); - - // Collect comparison simplifications (algebraic simplifications like (x-y)==0 → x==y) - let mut comparison_simplifications: Vec<(usize, usize, ComparisonSimplification)> = - Vec::new(); - - // Collect phi replacements - let mut phi_replacements: Vec<(usize, usize, SsaVarId, ConstValue)> = Vec::new(); - - // Analyze each block - for (block_idx, block) in ssa.iter_blocks() { - // Analyze branch terminators - if let Some(SsaOp::Branch { - condition, - true_target, - false_target, - }) = block.terminator_op() - { - // Check phi constants first - if let Some(const_val) = phi_constants.get(condition) { - let is_true = const_val.as_bool().unwrap_or(false) - || const_val.as_i64().is_some_and(|v| v != 0); - if is_true { - branch_simplifications.push((block_idx, *true_target, true)); - } else { - branch_simplifications.push((block_idx, *false_target, false)); - } - // Can't use continue with iter_blocks in a for loop, collect the data - } else { - let mut result = Self::analyze_branch(*condition, &cache); - - // If pattern matching couldn't determine the result, - // try using SsaEvaluator for dataflow-based analysis - if result == PredicateResult::Unknown { - let ptr_size = PointerSize::from_pe(assembly.file().pe().is_64bit); - result = Self::evaluate_with_tracked(ssa, *condition, block_idx, ptr_size); - } - - match result { - PredicateResult::AlwaysTrue => { - branch_simplifications.push((block_idx, *true_target, true)); - } - PredicateResult::AlwaysFalse => { - branch_simplifications.push((block_idx, *false_target, false)); - } - PredicateResult::Unknown => {} - } - } - } - - // Analyze comparison instructions - for (instr_idx, instr) in block.instructions().iter().enumerate() { - let op = instr.op(); - // First check for opaque predicates (constant true/false) - let result = Self::analyze_predicate_with_cache(op, &cache, 0); - if let Some(value) = result.as_bool() { - if let Some(dest) = op.dest() { - comparison_replacements.push((block_idx, instr_idx, dest, value)); - continue; // Don't also check for simplification - } - } - - // Then check for algebraic simplifications - if let Some(simplification) = Self::analyze_comparison_simplification(op, &cache) { - comparison_simplifications.push((block_idx, instr_idx, simplification)); - } - } - - // Check for phi nodes that can be replaced with constants - for (phi_idx, phi) in block.phi_nodes().iter().enumerate() { - if let Some(const_val) = phi_constants.get(&phi.result()) { - phi_replacements.push((block_idx, phi_idx, phi.result(), const_val.clone())); - } - } - } - - // Track structural changes (branches, comparisons) vs phi-only changes. - // Phi constant replacement doesn't warrant rebuild_ssa (see comment below). - let has_structural = !branch_simplifications.is_empty() - || !comparison_replacements.is_empty() - || !comparison_simplifications.is_empty(); - - // Apply branch simplifications - for (block_idx, target, is_true) in branch_simplifications { - if let Some(block) = ssa.block_mut(block_idx) { - if let Some(last_instr) = block.instructions_mut().last_mut() { - last_instr.set_op(SsaOp::Jump { target }); - changes - .record(EventKind::OpaquePredicateRemoved) - .at(method_token, block_idx) - .message(format!( - "removed opaque predicate (always {})", - if is_true { "true" } else { "false" } - )); - changes - .record(EventKind::BranchSimplified) - .at(method_token, block_idx) - .message(format!("simplified to unconditional branch to {target}")); - } - } - } - - // Apply comparison replacements (opaque predicates → constant true/false) - for (block_idx, instr_idx, dest, value) in comparison_replacements { - if let Some(block) = ssa.block_mut(block_idx) { - let const_value = if value { - ConstValue::True - } else { - ConstValue::False - }; - if let Some(instr) = block.instructions_mut().get_mut(instr_idx) { - instr.set_op(SsaOp::Const { - dest, - value: const_value, - }); - changes - .record(EventKind::ConstantFolded) - .at(method_token, instr_idx) - .message(format!("opaque predicate → {value}")); - } - } - } - - // Apply comparison simplifications (algebraic transformations) - for (block_idx, instr_idx, simplification) in comparison_simplifications { - if let Some(block) = ssa.block_mut(block_idx) { - match simplification { - ComparisonSimplification::SimplerOp { new_op, reason } => { - if let Some(instr) = block.instructions_mut().get_mut(instr_idx) { - instr.set_op(new_op); - changes - .record(EventKind::ConstantFolded) - .at(method_token, instr_idx) - .message(reason); - } - } - ComparisonSimplification::Copy { dest, src, reason } => { - if let Some(instr) = block.instructions_mut().get_mut(instr_idx) { - instr.set_op(SsaOp::Copy { dest, src }); - changes - .record(EventKind::ConstantFolded) - .at(method_token, instr_idx) - .message(reason); - } - } - } - } - } - - // Apply phi replacements: PHIs where all operands are the same constant - // We replace the PHI with a constant instruction and remove the PHI. - // Process in reverse order to handle phi_idx correctly when removing. - let mut phi_removals: Vec<(usize, usize)> = Vec::new(); - for (block_idx, phi_idx, phi_result, const_value) in phi_replacements { - // Create a constant instruction with the same destination as the PHI - let const_instr = SsaInstruction::synthetic(SsaOp::Const { - dest: phi_result, - value: const_value.clone(), - }); - - // Insert at the beginning of the block's instructions - if let Some(block) = ssa.block_mut(block_idx) { - block.instructions_mut().insert(0, const_instr); - } - - // Mark this phi for removal - phi_removals.push((block_idx, phi_idx)); - - changes - .record(EventKind::ConstantFolded) - .at(method_token, block_idx) - .message(format!("phi with constant operands → {const_value:?}")); - } - - // Remove the PHIs (in reverse order to maintain correct indices) - phi_removals.sort_by(|a, b| b.cmp(a)); // Sort descending by (block_idx, phi_idx) - for (block_idx, phi_idx) in phi_removals { - if let Some(block) = ssa.block_mut(block_idx) { - if phi_idx < block.phi_nodes().len() { - block.phi_nodes_mut().remove(phi_idx); - } - } - } - - let changed = !changes.is_empty(); - if changed { - ctx.events.merge(&changes); - } - - // Only report structural changes for rebuild_ssa. Phi constant - // replacement is a valid transformation, but triggering rebuild - // immediately undoes it: rebuild re-creates the phi from the original - // variable definitions that still exist in other blocks. By not - // triggering rebuild for phi-only changes, DCE can first clean up - // the now-dead original definitions. When rebuild eventually runs - // (triggered by a structural change), the original defs are gone - // and the phi won't be recreated. - Ok(has_structural) - } -} - -#[cfg(test)] -mod tests { - use crate::{ - analysis::{ConstValue, MethodRef, SsaFunctionBuilder, SsaOp, SsaVarId, ValueRange}, - compiler::passes::predicates::{DefinitionCache, OpaquePredicatePass, PredicateResult}, - metadata::token::Token, - }; - - #[test] - fn test_predicate_result() { - assert_eq!(PredicateResult::AlwaysTrue.as_bool(), Some(true)); - assert_eq!(PredicateResult::AlwaysFalse.as_bool(), Some(false)); - assert_eq!(PredicateResult::Unknown.as_bool(), None); - - assert_eq!( - PredicateResult::AlwaysTrue.negate(), - PredicateResult::AlwaysFalse - ); - assert_eq!( - PredicateResult::AlwaysFalse.negate(), - PredicateResult::AlwaysTrue - ); - assert_eq!(PredicateResult::Unknown.negate(), PredicateResult::Unknown); - } - - #[test] - fn test_self_equality() { - let (ssa, v0, v1) = { - let mut v0_out = SsaVarId::from_index(0); - let mut v1_out = SsaVarId::from_index(1); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v = b.const_i32(42); - v0_out = v; - v1_out = b.ceq(v, v); // v1 = ceq v0, v0 (always true) - b.ret(); - }); - }) - .unwrap(); - (ssa, v0_out, v1_out) - }; - - let cache = DefinitionCache::build(&ssa); - let op = SsaOp::Ceq { - dest: v1, - left: v0, - right: v0, - }; - - assert_eq!( - OpaquePredicatePass::analyze_predicate_with_cache(&op, &cache, 0), - PredicateResult::AlwaysTrue - ); - } - - #[test] - fn test_self_less_than() { - let (ssa, v0, v1) = { - let mut v0_out = SsaVarId::from_index(0); - let mut v1_out = SsaVarId::from_index(1); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v = b.const_i32(42); - v0_out = v; - v1_out = b.clt(v, v); - b.ret(); - }); - }) - .unwrap(); - (ssa, v0_out, v1_out) - }; - - let cache = DefinitionCache::build(&ssa); - - // x < x is always false - let op = SsaOp::Clt { - dest: v1, - left: v0, - right: v0, - unsigned: false, - }; - - assert_eq!( - OpaquePredicatePass::analyze_predicate_with_cache(&op, &cache, 0), - PredicateResult::AlwaysFalse - ); - } - - #[test] - fn test_xor_self_equals_zero() { - let (ssa, v1, v2, v3) = { - let mut v1_out = SsaVarId::from_index(0); - let mut v2_out = SsaVarId::from_index(1); - let mut v3_out = SsaVarId::from_index(2); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(42); - let v1 = b.xor(v0, v0); // v1 = v0 ^ v0 (always 0) - v1_out = v1; - let v2 = b.const_i32(0); - v2_out = v2; - v3_out = b.ceq(v1, v2); // v3 = ceq v1, v2 (always true) - b.ret(); - }); - }) - .unwrap(); - (ssa, v1_out, v2_out, v3_out) - }; - - let cache = DefinitionCache::build(&ssa); - let op = SsaOp::Ceq { - dest: v3, - left: v1, - right: v2, - }; - - assert_eq!( - OpaquePredicatePass::analyze_predicate_with_cache(&op, &cache, 0), - PredicateResult::AlwaysTrue - ); - } - - #[test] - fn test_constant_comparison() { - let (ssa, v0, v1, v2) = { - let mut v0_out = SsaVarId::from_index(0); - let mut v1_out = SsaVarId::from_index(1); - let mut v2_out = SsaVarId::from_index(2); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(5); - v0_out = v0; - let v1 = b.const_i32(10); - v1_out = v1; - v2_out = b.clt(v0, v1); - b.ret(); - }); - }) - .unwrap(); - (ssa, v0_out, v1_out, v2_out) - }; - - let cache = DefinitionCache::build(&ssa); - - // 5 < 10 is always true - let op = SsaOp::Clt { - dest: v2, - left: v0, - right: v1, - unsigned: false, - }; - - assert_eq!( - OpaquePredicatePass::analyze_predicate_with_cache(&op, &cache, 0), - PredicateResult::AlwaysTrue - ); - - // 5 > 10 is always false - let op = SsaOp::Cgt { - dest: v2, - left: v0, - right: v1, - unsigned: false, - }; - - assert_eq!( - OpaquePredicatePass::analyze_predicate_with_cache(&op, &cache, 0), - PredicateResult::AlwaysFalse - ); - } - - #[test] - fn test_unsigned_comparison() { - let (ssa, v0, v1) = { - let mut v0_out = SsaVarId::from_index(0); - let mut v1_out = SsaVarId::from_index(1); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_val(ConstValue::U32(5)); - v0_out = v0; - let v1 = b.const_val(ConstValue::U32(0)); - v1_out = v1; - b.ret(); - }); - }) - .unwrap(); - (ssa, v0_out, v1_out) - }; - - let cache = DefinitionCache::build(&ssa); - - // unsigned x < 0 is always false - let dest = SsaVarId::from_index(2); - let op = SsaOp::Clt { - dest, - left: v0, - right: v1, - unsigned: true, - }; - - assert_eq!( - OpaquePredicatePass::analyze_predicate_with_cache(&op, &cache, 0), - PredicateResult::AlwaysFalse - ); - } - - #[test] - fn test_newobj_non_null() { - let (ssa, v0, v1) = { - let mut v0_out = SsaVarId::from_index(0); - let mut v1_out = SsaVarId::from_index(1); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - // v0 = newobj (always non-null) - let v0 = b.newobj(MethodRef::new(Token::new(0x06000001)), &[]); - v0_out = v0; - // v1 = null - let v1 = b.const_null(); - v1_out = v1; - b.ret(); - }); - }) - .unwrap(); - (ssa, v0_out, v1_out) - }; - - let cache = DefinitionCache::build(&ssa); - - // newobj result == null is always false - let dest = SsaVarId::from_index(2); - let op = SsaOp::Ceq { - dest, - left: v0, - right: v1, - }; - - assert_eq!( - OpaquePredicatePass::analyze_predicate_with_cache(&op, &cache, 0), - PredicateResult::AlwaysFalse - ); - } - - #[test] - fn test_array_length_non_negative() { - let (ssa, v1, v2) = { - let mut v1_out = SsaVarId::from_index(0); - let mut v2_out = SsaVarId::from_index(1); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - // v0 = some array (placeholder) - let v0 = b.const_null(); - // v1 = array.Length (always >= 0) - let v1 = b.array_length(v0); - v1_out = v1; - // v2 = 0 - let v2 = b.const_i32(0); - v2_out = v2; - b.ret(); - }); - }) - .unwrap(); - (ssa, v1_out, v2_out) - }; - - let cache = DefinitionCache::build(&ssa); - - // array.Length < 0 is always false - let dest = SsaVarId::from_index(2); - let op = SsaOp::Clt { - dest, - left: v1, - right: v2, - unsigned: false, - }; - - assert_eq!( - OpaquePredicatePass::analyze_predicate_with_cache(&op, &cache, 0), - PredicateResult::AlwaysFalse - ); - } - - #[test] - fn test_multiply_by_zero() { - let (ssa, v1, v2) = { - let mut v1_out = SsaVarId::from_index(0); - let mut v2_out = SsaVarId::from_index(1); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(42); - let v1 = b.const_i32(0); - v1_out = v1; - let v2 = b.mul(v0, v1); // v2 = v0 * v1 (always 0) - v2_out = v2; - let _ = b.ceq(v2, v1); // v3 = ceq v2, v1 (always true) - b.ret(); - }); - }) - .unwrap(); - (ssa, v1_out, v2_out) - }; - - let cache = DefinitionCache::build(&ssa); - let dest = SsaVarId::from_index(2); - let op = SsaOp::Ceq { - dest, - left: v2, - right: v1, - }; - - assert_eq!( - OpaquePredicatePass::analyze_predicate_with_cache(&op, &cache, 0), - PredicateResult::AlwaysTrue - ); - } - - #[test] - fn test_value_range() { - let range = ValueRange::constant(5); - assert_eq!(range.always_less_than(10), Some(true)); - assert_eq!(range.always_less_than(5), Some(false)); - assert_eq!(range.always_less_than(3), Some(false)); - - assert_eq!(range.always_greater_than(3), Some(true)); - assert_eq!(range.always_greater_than(5), Some(false)); - assert_eq!(range.always_greater_than(10), Some(false)); - - assert_eq!(range.always_equal_to(5), Some(true)); - assert_eq!(range.always_equal_to(3), Some(false)); - - let non_neg = ValueRange::non_negative(); - assert!(non_neg.is_always_non_negative()); - assert_eq!(non_neg.always_less_than(0), Some(false)); - } - - #[test] - fn test_consecutive_pair_detection() { - let (ssa, v0, v2) = { - let mut v0_out = SsaVarId::from_index(0); - let mut v2_out = SsaVarId::from_index(1); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(5); - v0_out = v0; - let v1 = b.const_i32(1); - let v2 = b.add(v0, v1); // v2 = v0 + v1 (x + 1) - v2_out = v2; - b.ret(); - }); - }) - .unwrap(); - (ssa, v0_out, v2_out) - }; - - let cache = DefinitionCache::build(&ssa); - - // v0 and v2 should be detected as consecutive pair (x and x+1) - assert!(OpaquePredicatePass::is_consecutive_pair(v0, v2, &cache)); - } - - #[test] - fn test_phi_constant_analysis() { - let (ssa, phi_var) = { - let mut c0_out = SsaVarId::from_index(0); - let mut c1_out = SsaVarId::from_index(1); - let mut phi_out = SsaVarId::from_index(2); - - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - c0_out = b.const_i32(42); - b.jump(2); - }); - f.block(1, |b| { - c1_out = b.const_i32(42); // Same constant - b.jump(2); - }); - f.block(2, |b| { - phi_out = b.phi(&[(0, c0_out), (1, c1_out)]); - }); - }) - .unwrap(); - - (ssa, phi_out) - }; - - let phi_constants = OpaquePredicatePass::analyze_phi_constants(&ssa); - - // The phi should be recognized as constant since both operands are 42 - assert!(phi_constants.contains_key(&phi_var)); - assert_eq!(phi_constants.get(&phi_var), Some(&ConstValue::I32(42))); - } -} diff --git a/dotscope/src/compiler/passes/proxy.rs b/dotscope/src/compiler/passes/proxy.rs index 1ad53b3b..3469f0bf 100644 --- a/dotscope/src/compiler/passes/proxy.rs +++ b/dotscope/src/compiler/passes/proxy.rs @@ -44,12 +44,12 @@ use crate::{ analysis::{ - ConstValue, DefSite, MethodRef, ReturnInfo, SsaFunction, SsaInstruction, SsaOp, SsaVarId, - VariableOrigin, + CilTarget, ConstValue, ConstValueCilExt, DefSite, MethodRef, ReturnInfo, SsaFunction, + SsaInstruction, SsaOp, SsaVarId, VariableOrigin, }, compiler::{CompilerContext, EventKind, EventLog, ModificationScope, PassCapability, SsaPass}, metadata::{tables::MemberRefSignature, token::Token, typesystem::CilTypeReference}, - CilObject, Result, + CilObject, }; /// How the proxy method forwards to its target. @@ -765,7 +765,7 @@ impl ProxyDevirtualizationPass { } } -impl SsaPass for ProxyDevirtualizationPass { +impl SsaPass for ProxyDevirtualizationPass { fn name(&self) -> &'static str { "proxy-devirtualization" } @@ -789,11 +789,14 @@ impl SsaPass for ProxyDevirtualizationPass { fn run_on_method( &self, ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - assembly: &CilObject, - ) -> Result { - let candidates = Self::find_candidates(ssa, method_token, ctx, assembly); + method: &MethodRef, + host: &CompilerContext, + ) -> analyssa::Result { + let assembly = host.assembly().ok_or_else(|| { + analyssa::Error::new("ProxyDevirtualizationPass requires an assembly") + })?; + let method_token = method.0; + let candidates = Self::find_candidates(ssa, method_token, host, &assembly); if candidates.is_empty() { return Ok(false); } @@ -802,12 +805,12 @@ impl SsaPass for ProxyDevirtualizationPass { // Process candidates in reverse order to maintain valid indices for candidate in candidates.into_iter().rev() { - Self::process_candidate(ssa, &candidate, method_token, ctx, &mut changes); + Self::process_candidate(ssa, &candidate, method_token, host, &mut changes); } let changed = !changes.is_empty(); if changed { - ctx.events.merge(&changes); + host.events.merge(&changes); } Ok(changed) } @@ -825,15 +828,12 @@ mod tests { }, metadata::token::Token, test::helpers::test_assembly_arc, - CilObject, }; fn test_context() -> CompilerContext { - CompilerContext::new(Arc::new(CallGraph::new())) - } - - fn test_assembly() -> Arc { - test_assembly_arc() + let ctx = CompilerContext::new(Arc::new(CallGraph::new())); + ctx.set_assembly(test_assembly_arc()); + ctx } #[test] @@ -985,9 +985,8 @@ mod tests { ctx.set_ssa(proxy_token, proxy_ssa); let pass = ProxyDevirtualizationPass::new(); - let assembly = test_assembly(); let changed = pass - .run_on_method(&mut caller_ssa, caller_token, &ctx, &assembly) + .run_on_method(&mut caller_ssa, &MethodRef::from(caller_token), &ctx) .unwrap(); assert!(changed, "Should have made changes"); @@ -1034,9 +1033,8 @@ mod tests { ctx.set_ssa(noop_token, noop_ssa); let pass = ProxyDevirtualizationPass::new(); - let assembly = test_assembly(); let changed = pass - .run_on_method(&mut caller_ssa, caller_token, &ctx, &assembly) + .run_on_method(&mut caller_ssa, &MethodRef::from(caller_token), &ctx) .unwrap(); assert!(changed, "Should have made changes"); @@ -1080,9 +1078,8 @@ mod tests { ctx.set_ssa(const_token, const_ssa); let pass = ProxyDevirtualizationPass::new(); - let assembly = test_assembly(); let changed = pass - .run_on_method(&mut caller_ssa, caller_token, &ctx, &assembly) + .run_on_method(&mut caller_ssa, &MethodRef::from(caller_token), &ctx) .unwrap(); assert!(changed, "Should have made changes"); @@ -1119,9 +1116,8 @@ mod tests { ctx.set_ssa(self_token, self_ssa); let pass = ProxyDevirtualizationPass::new(); - let assembly = test_assembly(); let changed = pass - .run_on_method(&mut caller_ssa, self_token, &ctx, &assembly) + .run_on_method(&mut caller_ssa, &MethodRef::from(self_token), &ctx) .unwrap(); assert!(!changed, "Should not devirtualize self-recursive calls"); diff --git a/dotscope/src/compiler/passes/ranges.rs b/dotscope/src/compiler/passes/ranges.rs deleted file mode 100644 index df385cd3..00000000 --- a/dotscope/src/compiler/passes/ranges.rs +++ /dev/null @@ -1,1012 +0,0 @@ -//! Value Range Propagation Pass. -//! -//! This pass performs dataflow-based range analysis to track the possible -//! values of integer variables throughout the control flow graph. It strengthens -//! opaque predicate detection by proving comparisons based on value ranges. -//! -//! # Algorithm -//! -//! Uses a sparse worklist algorithm similar to SCCP: -//! 1. Initialize all variables to `Top` (unknown range) -//! 2. Process definitions to narrow ranges based on operations -//! 3. At conditional branches, narrow ranges for the taken path -//! 4. Use ranges to simplify always-true/false comparisons -//! -//! # Improvements Over Pattern Matching -//! -//! While the `OpaquePredicatePass` uses local pattern matching, this pass -//! propagates ranges through the CFG to catch cases like: -//! -//! ```text -//! B0: x = 5 -//! jump B1 -//! -//! B1: y = x + 10 // y ∈ [15, 15] -//! jump B2 -//! -//! B2: if (y > 100) // Always false: 15 > 100 is false -//! ... -//! ``` -//! -//! The pattern matcher in `OpaquePredicatePass` can't see through the add, -//! but range propagation tracks y = 15 through the CFG. - -use std::collections::{HashMap, HashSet, VecDeque}; - -use crate::{ - analysis::{ConstValue, PhiNode, SsaBlock, SsaCfg, SsaFunction, SsaOp, SsaVarId, ValueRange}, - compiler::{ - pass::{ModificationScope, SsaPass}, - CompilerContext, EventKind, EventLog, - }, - metadata::token::Token, - utils::{ - graph::{NodeId, RootedGraph, Successors}, - BitSet, - }, - CilObject, Result, -}; - -/// Value Range Propagation Pass. -/// -/// Performs dataflow-based range analysis to strengthen opaque predicate -/// detection and simplify comparisons that can be proven always-true or -/// always-false based on value ranges. -pub struct ValueRangePropagationPass { - /// Maximum worklist iterations for the dataflow solver. - max_iterations: usize, -} - -impl ValueRangePropagationPass { - /// Creates a new value range propagation pass. - /// - /// # Arguments - /// - /// * `max_iterations` - Maximum worklist iterations for the dataflow solver. - /// In practice analysis converges quickly; the default config value is 10,000. - #[must_use] - pub fn new(max_iterations: usize) -> Self { - Self { max_iterations } - } -} - -impl SsaPass for ValueRangePropagationPass { - fn name(&self) -> &'static str { - "value-range-propagation" - } - - fn description(&self) -> &'static str { - "Propagates value ranges through CFG to detect opaque predicates" - } - - fn modification_scope(&self) -> ModificationScope { - ModificationScope::CfgModifying - } - - fn run_on_method( - &self, - ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - _assembly: &CilObject, - ) -> Result { - // Run range analysis - let mut analysis = RangeAnalysis::new(self.max_iterations); - let result = analysis.analyze(ssa); - - // Collect transformations to apply - let changes = EventLog::new(); - let mut branch_simplifications: Vec<(usize, usize, bool)> = Vec::new(); - let mut comparison_replacements: Vec<(usize, usize, SsaVarId, bool)> = Vec::new(); - - // Find branches and comparisons that can be simplified - for (block_idx, block) in ssa.iter_blocks() { - // Check branch terminator - if let Some(SsaOp::Branch { - condition, - true_target, - false_target, - }) = block.terminator_op() - { - if let Some(range) = result.get_range(*condition) { - // Check if range proves the condition - if let Some(is_true) = range.always_equal_to(0) { - // always_equal_to(0) being true means always false - // always_equal_to(0) being false means possibly non-zero - if is_true { - // Condition is always 0 (false) - branch_simplifications.push((block_idx, *false_target, false)); - } - } - - // Check if range is a known non-zero constant - if let Some(val) = range.as_constant() { - if val != 0 { - branch_simplifications.push((block_idx, *true_target, true)); - } - } - } - } - - // Check comparison instructions - for (instr_idx, instr) in block.instructions().iter().enumerate() { - if let Some((dest, value)) = Self::try_simplify_comparison(instr.op(), &result) { - comparison_replacements.push((block_idx, instr_idx, dest, value)); - } - } - } - - // Apply branch simplifications - for (block_idx, target, is_true) in branch_simplifications { - if let Some(block) = ssa.block_mut(block_idx) { - if let Some(last_instr) = block.instructions_mut().last_mut() { - last_instr.set_op(SsaOp::Jump { target }); - changes - .record(EventKind::OpaquePredicateRemoved) - .at(method_token, block_idx) - .message(format!( - "range analysis: condition always {}", - if is_true { "true" } else { "false" } - )); - changes - .record(EventKind::BranchSimplified) - .at(method_token, block_idx) - .message(format!("simplified to unconditional jump to {target}")); - } - } - } - - // Apply comparison replacements - for (block_idx, instr_idx, dest, value) in comparison_replacements { - if let Some(block) = ssa.block_mut(block_idx) { - let const_value = if value { - ConstValue::True - } else { - ConstValue::False - }; - if let Some(instr) = block.instructions_mut().get_mut(instr_idx) { - instr.set_op(SsaOp::Const { - dest, - value: const_value, - }); - } - changes - .record(EventKind::ConstantFolded) - .at(method_token, instr_idx) - .message(format!("range analysis: comparison → {value}")); - } - } - - let changed = !changes.is_empty(); - if changed { - ctx.events.merge(&changes); - } - Ok(changed) - } -} - -impl ValueRangePropagationPass { - /// Tries to simplify a comparison operation using range information. - /// - /// Returns `Some((dest, value))` if the comparison can be proven to always - /// have a constant result, where `value` is the boolean result. - fn try_simplify_comparison(op: &SsaOp, result: &RangeResult) -> Option<(SsaVarId, bool)> { - match op { - SsaOp::Clt { - dest, - left, - right, - unsigned: _, - } => { - let left_range = result.get_range(*left)?; - let right_range = result.get_range(*right)?; - - // Check if left.max < right.min (always true) - // or left.min >= right.max (always false) - if let (Some(l_max), Some(r_min)) = (left_range.max(), right_range.min()) { - if l_max < r_min { - return Some((*dest, true)); - } - } - if let (Some(l_min), Some(r_max)) = (left_range.min(), right_range.max()) { - if l_min >= r_max { - return Some((*dest, false)); - } - } - None - } - - SsaOp::Cgt { - dest, - left, - right, - unsigned: _, - } => { - let left_range = result.get_range(*left)?; - let right_range = result.get_range(*right)?; - - // Check if left.min > right.max (always true) - // or left.max <= right.min (always false) - if let (Some(l_min), Some(r_max)) = (left_range.min(), right_range.max()) { - if l_min > r_max { - return Some((*dest, true)); - } - } - if let (Some(l_max), Some(r_min)) = (left_range.max(), right_range.min()) { - if l_max <= r_min { - return Some((*dest, false)); - } - } - None - } - - SsaOp::Ceq { dest, left, right } => { - let left_range = result.get_range(*left)?; - let right_range = result.get_range(*right)?; - - // If both are constants and equal - if let (Some(l), Some(r)) = (left_range.as_constant(), right_range.as_constant()) { - return Some((*dest, l == r)); - } - - // If ranges don't overlap, they can never be equal - if !Self::ranges_overlap(left_range, right_range) { - return Some((*dest, false)); - } - - None - } - - _ => None, - } - } - - /// Checks if two ranges have any overlap. - fn ranges_overlap(a: &ValueRange, b: &ValueRange) -> bool { - // If either is Top, they might overlap - if a.is_top() || b.is_top() { - return true; - } - // If either is Bottom, they don't overlap (empty set) - if a.is_bottom() || b.is_bottom() { - return false; - } - - // Check if a.max >= b.min && a.min <= b.max - match (a.max(), a.min(), b.max(), b.min()) { - (Some(a_max), Some(a_min), Some(b_max), Some(b_min)) => { - a_max >= b_min && a_min <= b_max - } - // If any bound is unbounded, they might overlap - _ => true, - } - } -} - -/// Sparse range propagation analysis. -/// -/// Uses a worklist algorithm similar to SCCP but tracks value ranges -/// instead of just constants. -struct RangeAnalysis { - /// Current range for each SSA variable. - ranges: HashMap, - /// Executable CFG edges. - executable_edges: HashSet<(usize, usize)>, - /// Blocks that have been marked executable. - executable_blocks: BitSet, - /// SSA worklist: variables whose ranges have changed. - ssa_worklist: VecDeque, - /// CFG worklist: edges that have become executable. - cfg_worklist: VecDeque<(usize, usize)>, - /// Maximum worklist iterations for the dataflow solver. - max_iterations: usize, -} - -impl RangeAnalysis { - /// Creates a new range analysis. - /// - /// # Arguments - /// - /// * `max_iterations` - Maximum worklist iterations for the dataflow solver. - fn new(max_iterations: usize) -> Self { - Self { - ranges: HashMap::new(), - executable_edges: HashSet::new(), - executable_blocks: BitSet::new(0), - ssa_worklist: VecDeque::new(), - cfg_worklist: VecDeque::new(), - max_iterations, - } - } - - /// Runs the range propagation algorithm. - fn analyze(&mut self, ssa: &SsaFunction) -> RangeResult { - let cfg = SsaCfg::from_ssa(ssa); - self.initialize(ssa, &cfg); - self.propagate(ssa, &cfg); - - RangeResult { - ranges: self.ranges.clone(), - } - } - - /// Initializes the analysis state. - fn initialize(&mut self, ssa: &SsaFunction, cfg: &G) - where - G: RootedGraph + Successors, - { - self.ranges.clear(); - self.executable_edges.clear(); - self.executable_blocks = BitSet::new(ssa.block_count()); - self.ssa_worklist.clear(); - self.cfg_worklist.clear(); - - // All variables start as Top (unknown range) - for var in ssa.variables() { - self.ranges.insert(var.id(), ValueRange::top()); - } - - // Mark entry block as executable - let entry = cfg.entry().index(); - self.executable_blocks.insert(entry); - - // Add entry block's outgoing edges - for succ in cfg.successors(cfg.entry()) { - self.cfg_worklist.push_back((entry, succ.index())); - } - - // Process entry block definitions - if let Some(block) = ssa.block(entry) { - self.process_block_definitions(block); - } - } - - /// Main propagation loop. - fn propagate(&mut self, ssa: &SsaFunction, cfg: &G) - where - G: RootedGraph + Successors, - { - // Iteration limit to prevent infinite loops with widening ranges. - // In practice, analysis should converge quickly for most methods. - // If we hit this limit, we still have valid (possibly imprecise) results. - let mut iterations: usize = 0; - - loop { - iterations = iterations.saturating_add(1); - if iterations > self.max_iterations { - // Hit iteration limit - return with current results. - // This can happen with unbounded widening in loops. - break; - } - - // Process CFG worklist first - while let Some((from, to)) = self.cfg_worklist.pop_front() { - if self.executable_edges.insert((from, to)) { - self.process_edge(from, to, ssa, cfg); - } - } - - // Process SSA worklist - if let Some(var) = self.ssa_worklist.pop_front() { - self.process_variable_uses(var, ssa, cfg); - } else { - break; - } - } - } - - /// Processes a newly executable CFG edge. - fn process_edge(&mut self, from: usize, to: usize, ssa: &SsaFunction, cfg: &G) - where - G: RootedGraph + Successors, - { - let first_visit = !self.executable_blocks.contains(to); - - if first_visit { - self.executable_blocks.insert(to); - - if let Some(block) = ssa.block(to) { - self.process_block_definitions(block); - } - } - - // Re-evaluate phi nodes in target block - if let Some(block) = ssa.block(to) { - for phi in block.phi_nodes() { - if phi.operand_from(from).is_some() { - let new_range = self.evaluate_phi(phi, to); - self.update_range(phi.result(), &new_range); - } - } - } - - // If first visit, propagate outgoing edges - if first_visit { - if let Some(block) = ssa.block(to) { - self.propagate_outgoing_edges(to, block, cfg); - } - } - } - - /// Processes all definitions in a block. - fn process_block_definitions(&mut self, block: &SsaBlock) { - for instr in block.instructions() { - if let Some(def) = instr.def() { - let range = self.evaluate_instruction(instr.op()); - self.update_range(def, &range); - } - } - } - - /// Processes uses of a variable whose range changed. - fn process_variable_uses(&mut self, var: SsaVarId, ssa: &SsaFunction, cfg: &G) - where - G: RootedGraph + Successors, - { - if let Some(ssa_var) = ssa.variable(var) { - for use_site in ssa_var.uses() { - let block_id = use_site.block; - - if !self.executable_blocks.contains(block_id) { - continue; - } - - if use_site.is_phi_operand { - if let Some(block) = ssa.block(block_id) { - if let Some(phi) = block.phi(use_site.instruction) { - let new_range = self.evaluate_phi(phi, block_id); - self.update_range(phi.result(), &new_range); - } - } - } else if let Some(block) = ssa.block(block_id) { - if let Some(instr) = block.instruction(use_site.instruction) { - if let Some(def) = instr.def() { - let range = self.evaluate_instruction(instr.op()); - self.update_range(def, &range); - } - - if instr.is_terminator() { - self.propagate_outgoing_edges(block_id, block, cfg); - } - } - } - } - } - } - - /// Propagates outgoing edges based on terminator. - fn propagate_outgoing_edges(&mut self, block_id: usize, block: &SsaBlock, cfg: &G) - where - G: RootedGraph + Successors, - { - match block.terminator_op() { - Some(SsaOp::Branch { - condition, - true_target, - false_target, - }) => { - let range = self.get_range(*condition); - - // Check if we can determine the branch direction - if let Some(val) = range.as_constant() { - if val != 0 { - self.add_cfg_edge(block_id, *true_target); - } else { - self.add_cfg_edge(block_id, *false_target); - } - } else if range.always_equal_to(0) == Some(true) { - // Always zero -> always false - self.add_cfg_edge(block_id, *false_target); - } else if range.is_always_positive() { - // Always positive -> always true (non-zero) - self.add_cfg_edge(block_id, *true_target); - } else if range.is_top() { - // Unknown - don't add edges yet - } else { - // Could go either way - self.add_cfg_edge(block_id, *true_target); - self.add_cfg_edge(block_id, *false_target); - } - } - - Some(SsaOp::Switch { - value, - targets, - default, - }) => { - let range = self.get_range(*value); - - if let Some(idx) = range.as_constant().and_then(|i| usize::try_from(i).ok()) { - // Known switch value - if let Some(&target) = targets.get(idx) { - self.add_cfg_edge(block_id, target); - } else { - self.add_cfg_edge(block_id, *default); - } - } else { - // Unknown - add all edges - for &target in targets { - self.add_cfg_edge(block_id, target); - } - self.add_cfg_edge(block_id, *default); - } - } - - Some(SsaOp::Jump { target }) => { - self.add_cfg_edge(block_id, *target); - } - - Some(SsaOp::Return { .. } | SsaOp::Throw { .. } | SsaOp::Rethrow) => { - // No successors - } - - _ => { - // Fall through - add all CFG successors - let node = NodeId::new(block_id); - for succ in cfg.successors(node) { - self.add_cfg_edge(block_id, succ.index()); - } - } - } - } - - /// Adds a CFG edge to the worklist. - fn add_cfg_edge(&mut self, from: usize, to: usize) { - if !self.executable_edges.contains(&(from, to)) { - self.cfg_worklist.push_back((from, to)); - } - } - - /// Evaluates a phi node to get its current range. - fn evaluate_phi(&self, phi: &PhiNode, block_id: usize) -> ValueRange { - let mut result = ValueRange::bottom(); - let mut has_executable_operand = false; - - for operand in phi.operands() { - let pred = operand.predecessor(); - - if !self.executable_edges.contains(&(pred, block_id)) { - continue; - } - - has_executable_operand = true; - let op_range = self.get_range(operand.value()); - - // Join ranges at merge point - result = result.join(&op_range); - - // Early exit if we've lost all precision - if result.is_top() { - break; - } - } - - if !has_executable_operand { - return ValueRange::top(); - } - - result - } - - /// Evaluates an instruction to get the range of its result. - fn evaluate_instruction(&self, op: &SsaOp) -> ValueRange { - match op { - SsaOp::Const { value, .. } => { - if let Some(v) = value.as_i64() { - ValueRange::constant(v) - } else { - ValueRange::top() - } - } - - SsaOp::Copy { src, .. } => self.get_range(*src), - - SsaOp::Add { left, right, .. } => { - let l = self.get_range(*left); - let r = self.get_range(*right); - l.add(&r) - } - - SsaOp::Sub { left, right, .. } => { - let l = self.get_range(*left); - let r = self.get_range(*right); - l.sub(&r) - } - - SsaOp::Mul { left, right, .. } => { - let l = self.get_range(*left); - let r = self.get_range(*right); - l.mul(&r) - } - - SsaOp::And { left, right, .. } => { - // AND with a constant produces a bounded range - let r = self.get_range(*right); - if let Some(mask) = r.as_constant() { - ValueRange::bounded(0, mask.max(0)) - } else { - let l = self.get_range(*left); - if let Some(mask) = l.as_constant() { - ValueRange::bounded(0, mask.max(0)) - } else { - ValueRange::top() - } - } - } - - SsaOp::Shr { - value, - amount, - unsigned, - .. - } => { - let val_range = self.get_range(*value); - let amt_range = self.get_range(*amount); - - // If shifting by a known amount - if let Some(amt) = amt_range.as_constant() { - if (0..64).contains(&amt) && *unsigned && val_range.is_always_non_negative() { - // Unsigned right shift of non-negative preserves non-negative - // and reduces the range - if let (Some(min), Some(max)) = (val_range.min(), val_range.max()) { - let new_min = min >> amt; - let new_max = max >> amt; - return ValueRange::bounded(new_min, new_max); - } - } - } - ValueRange::top() - } - - SsaOp::Rem { left, right, .. } => { - // x % n produces values in [-(n-1), n-1] for signed - // or [0, n-1] for unsigned - let r = self.get_range(*right); - if let Some(n) = r.as_constant() { - if n > 0 { - // Positive divisor: result in [0, n-1] if dividend is non-negative - let l = self.get_range(*left); - if l.is_always_non_negative() { - return ValueRange::bounded(0, n.saturating_sub(1)); - } - } - } - ValueRange::top() - } - - SsaOp::ArrayLength { .. } => { - // Array length is always >= 0 - ValueRange::non_negative() - } - - SsaOp::NewArr { .. } - | SsaOp::NewObj { .. } - | SsaOp::Box { .. } - | SsaOp::LoadToken { .. } => { - // References - don't track as numeric ranges - ValueRange::top() - } - - // Comparisons produce 0 or 1 - SsaOp::Ceq { .. } | SsaOp::Clt { .. } | SsaOp::Cgt { .. } => ValueRange::bounded(0, 1), - - // All other operations - unknown range - _ => ValueRange::top(), - } - } - - /// Gets the current range of a variable. - fn get_range(&self, var: SsaVarId) -> ValueRange { - self.ranges.get(&var).cloned().unwrap_or_default() - } - - /// Updates a variable's range using meet (intersection). - fn update_range(&mut self, var: SsaVarId, new_range: &ValueRange) { - let old_range = self.ranges.get(&var).cloned().unwrap_or_default(); - - // For range analysis, we use meet (intersection) to narrow ranges - // But we need to be careful: at merge points we use join, not meet - // The evaluate functions already handle this correctly - - // Only update if the range changed - if *new_range != old_range { - self.ranges.insert(var, new_range.clone()); - self.ssa_worklist.push_back(var); - } - } -} - -/// Results of range analysis. -#[derive(Debug)] -struct RangeResult { - /// Range for each SSA variable. - ranges: HashMap, -} - -impl RangeResult { - /// Gets the range of an SSA variable. - fn get_range(&self, var: SsaVarId) -> Option<&ValueRange> { - self.ranges.get(&var) - } -} - -#[cfg(test)] -mod tests { - use std::collections::HashMap; - - use crate::{ - analysis::{SsaFunctionBuilder, SsaOp, SsaVarId, ValueRange}, - compiler::{ - passes::ranges::{RangeAnalysis, RangeResult, ValueRangePropagationPass}, - SsaPass, - }, - }; - - #[test] - fn test_pass_metadata() { - let pass = ValueRangePropagationPass::new(10_000); - assert_eq!(pass.name(), "value-range-propagation"); - assert!(!pass.description().is_empty()); - } - - #[test] - fn test_constant_range() { - let (ssa, v0) = { - let mut v0_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - v0_out = b.const_i32(42); - b.ret(); - }); - }) - .unwrap(); - (ssa, v0_out) - }; - - let mut analysis = RangeAnalysis::new(10_000); - let result = analysis.analyze(&ssa); - - let range = result.get_range(v0).unwrap(); - assert!(range.is_constant()); - assert_eq!(range.as_constant(), Some(42)); - } - - #[test] - fn test_add_range() { - let (ssa, v2) = { - let mut v2_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(5); - let v1 = b.const_i32(10); - v2_out = b.add(v0, v1); // 5 + 10 = 15 - b.ret(); - }); - }) - .unwrap(); - (ssa, v2_out) - }; - - let mut analysis = RangeAnalysis::new(10_000); - let result = analysis.analyze(&ssa); - - let range = result.get_range(v2).unwrap(); - assert_eq!(range.as_constant(), Some(15)); - } - - #[test] - fn test_sub_range() { - let (ssa, v2) = { - let mut v2_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(20); - let v1 = b.const_i32(7); - v2_out = b.sub(v0, v1); // 20 - 7 = 13 - b.ret(); - }); - }) - .unwrap(); - (ssa, v2_out) - }; - - let mut analysis = RangeAnalysis::new(10_000); - let result = analysis.analyze(&ssa); - - let range = result.get_range(v2).unwrap(); - assert_eq!(range.as_constant(), Some(13)); - } - - #[test] - fn test_and_range() { - let (ssa, v2) = { - let mut v2_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(1000); - let v1 = b.const_i32(0xFF); // Mask to byte range - v2_out = b.and(v0, v1); - b.ret(); - }); - }) - .unwrap(); - (ssa, v2_out) - }; - - let mut analysis = RangeAnalysis::new(10_000); - let result = analysis.analyze(&ssa); - - let range = result.get_range(v2).unwrap(); - // AND with 0xFF produces range [0, 255] - assert_eq!(range.min(), Some(0)); - assert_eq!(range.max(), Some(255)); - } - - #[test] - fn test_array_length_range() { - let (ssa, v1) = { - let mut v1_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_null(); // Placeholder array - v1_out = b.array_length(v0); - b.ret(); - }); - }) - .unwrap(); - (ssa, v1_out) - }; - - let mut analysis = RangeAnalysis::new(10_000); - let result = analysis.analyze(&ssa); - - let range = result.get_range(v1).unwrap(); - assert!(range.is_always_non_negative()); - } - - #[test] - fn test_comparison_range() { - let (ssa, v2) = { - let mut v2_out = SsaVarId::from_index(0); - let ssa = SsaFunctionBuilder::new(0, 0) - .build_with(|f| { - f.block(0, |b| { - let v0 = b.const_i32(5); - let v1 = b.const_i32(10); - v2_out = b.clt(v0, v1); // 5 < 10 - b.ret(); - }); - }) - .unwrap(); - (ssa, v2_out) - }; - - let mut analysis = RangeAnalysis::new(10_000); - let result = analysis.analyze(&ssa); - - let range = result.get_range(v2).unwrap(); - // Comparison produces 0 or 1 - assert_eq!(range.min(), Some(0)); - assert_eq!(range.max(), Some(1)); - } - - #[test] - fn test_ranges_overlap() { - // Non-overlapping ranges - let a = ValueRange::bounded(0, 5); - let b = ValueRange::bounded(10, 15); - assert!(!ValueRangePropagationPass::ranges_overlap(&a, &b)); - - // Overlapping ranges - let c = ValueRange::bounded(0, 10); - let d = ValueRange::bounded(5, 15); - assert!(ValueRangePropagationPass::ranges_overlap(&c, &d)); - - // Same range - let e = ValueRange::bounded(5, 10); - assert!(ValueRangePropagationPass::ranges_overlap(&e, &e)); - - // Top overlaps with everything - let top = ValueRange::top(); - assert!(ValueRangePropagationPass::ranges_overlap(&top, &a)); - - // Bottom doesn't overlap - let bottom = ValueRange::bottom(); - assert!(!ValueRangePropagationPass::ranges_overlap(&bottom, &a)); - } - - #[test] - fn test_try_simplify_clt() { - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let dest = SsaVarId::from_index(2); - - let mut ranges = HashMap::new(); - ranges.insert(v0, ValueRange::bounded(0, 5)); // [0, 5] - ranges.insert(v1, ValueRange::bounded(10, 20)); // [10, 20] - - let result = RangeResult { ranges }; - - // v0 < v1 should always be true (5 < 10) - let op = SsaOp::Clt { - dest, - left: v0, - right: v1, - unsigned: false, - }; - let simplified = ValueRangePropagationPass::try_simplify_comparison(&op, &result); - assert_eq!(simplified, Some((dest, true))); - } - - #[test] - fn test_try_simplify_cgt() { - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let dest = SsaVarId::from_index(2); - - let mut ranges = HashMap::new(); - ranges.insert(v0, ValueRange::bounded(100, 200)); // [100, 200] - ranges.insert(v1, ValueRange::bounded(0, 50)); // [0, 50] - - let result = RangeResult { ranges }; - - // v0 > v1 should always be true (100 > 50) - let op = SsaOp::Cgt { - dest, - left: v0, - right: v1, - unsigned: false, - }; - let simplified = ValueRangePropagationPass::try_simplify_comparison(&op, &result); - assert_eq!(simplified, Some((dest, true))); - } - - #[test] - fn test_try_simplify_ceq_never() { - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let dest = SsaVarId::from_index(2); - - let mut ranges = HashMap::new(); - ranges.insert(v0, ValueRange::bounded(0, 5)); // [0, 5] - ranges.insert(v1, ValueRange::bounded(10, 20)); // [10, 20] - - let result = RangeResult { ranges }; - - // v0 == v1 should always be false (ranges don't overlap) - let op = SsaOp::Ceq { - dest, - left: v0, - right: v1, - }; - let simplified = ValueRangePropagationPass::try_simplify_comparison(&op, &result); - assert_eq!(simplified, Some((dest, false))); - } - - #[test] - fn test_try_simplify_ceq_constants() { - let v0 = SsaVarId::from_index(0); - let v1 = SsaVarId::from_index(1); - let dest = SsaVarId::from_index(2); - - let mut ranges = HashMap::new(); - ranges.insert(v0, ValueRange::constant(42)); - ranges.insert(v1, ValueRange::constant(42)); - - let result = RangeResult { ranges }; - - // v0 == v1 should always be true (both are 42) - let op = SsaOp::Ceq { - dest, - left: v0, - right: v1, - }; - let simplified = ValueRangePropagationPass::try_simplify_comparison(&op, &result); - assert_eq!(simplified, Some((dest, true))); - } -} diff --git a/dotscope/src/compiler/passes/reassociate.rs b/dotscope/src/compiler/passes/reassociate.rs deleted file mode 100644 index c565d628..00000000 --- a/dotscope/src/compiler/passes/reassociate.rs +++ /dev/null @@ -1,585 +0,0 @@ -//! Reassociation pass. -//! -//! This pass reorders operations to enable better constant folding. For example: -//! -//! ```text -//! (x + 5) + 3 → x + (5 + 3) → x + 8 -//! ``` -//! -//! # Transformed Patterns -//! -//! ## Associative operations (`add`, `mul`, `and`, `or`, `xor`) -//! -//! For these operations, constants combine using the same operation: -//! - `(x + c1) + c2` → `x + (c1 + c2)` -//! - `(x * c1) * c2` → `x * (c1 * c2)` -//! - `(x ^ c1) ^ c2` → `x ^ (c1 ^ c2)` (common in obfuscation) -//! -//! ## Subtraction chains -//! -//! Subtraction is not associative, but constants combine with addition: -//! - `(x - c1) - c2` → `x - (c1 + c2)` -//! -//! ## Shift chains -//! -//! Shift amounts combine with addition: -//! - `(x << c1) << c2` → `x << (c1 + c2)` -//! - `(x >> c1) >> c2` → `x >> (c1 + c2)` (preserves signedness) -//! -//! # Implementation Strategy -//! -//! The pass works in three phases: -//! 1. Find constants and their defining instructions -//! 2. Identify operations where one operand is the result of another same-op with a constant -//! 3. Combine the constants and rewrite to `x op combined_const` -//! -//! The SCCP pass will then fold the combined constants in the next iteration. - -use std::collections::{BTreeMap, HashSet}; - -use crate::{ - analysis::{ConstValue, DefUseIndex, SsaFunction, SsaOp, SsaVarId}, - compiler::{ - pass::{ModificationScope, SsaPass}, - CompilerContext, EventKind, EventLog, - }, - metadata::{token::Token, typesystem::PointerSize}, - CilObject, Result, -}; - -/// Reassociation pass that reorders operations to enable constant folding. -pub struct ReassociationPass; - -impl Default for ReassociationPass { - fn default() -> Self { - Self::new() - } -} - -/// A candidate for reassociation. -#[derive(Debug)] -struct ReassociationCandidate { - /// Block containing the outer operation - block_idx: usize, - /// Instruction index of the outer operation - instr_idx: usize, - /// The destination variable - dest: SsaVarId, - /// The non-constant operand (x in `(x op c1) op c2`) - base_var: SsaVarId, - /// The first constant variable (c1) - const1_var: SsaVarId, - /// The second constant variable (c2) - const2_var: SsaVarId, - /// The first constant value - const1_value: ConstValue, - /// The second constant value - const2_value: ConstValue, - /// Block and instruction of the inner operation (to mark for removal) - inner_block: usize, - inner_instr: usize, - inner_dest: SsaVarId, - /// The type of operation - op_kind: OpKind, -} - -/// The kind of operation that can be reassociated. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -enum OpKind { - /// Addition: constants combine with add - Add, - /// Subtraction: constants combine with add - Sub, - /// Multiplication: constants combine with mul - Mul, - /// Bitwise AND: constants combine with and - And, - /// Bitwise OR: constants combine with or - Or, - /// Bitwise XOR: constants combine with xor - Xor, - /// Shift left: shift amounts combine with add - Shl, - /// Shift right: shift amounts combine with add (preserves signedness) - Shr { unsigned: bool }, -} - -impl OpKind { - /// Combines two constants for reassociation. - /// - /// For associative operations (Add, Mul, And, Or, Xor), the combining - /// operation is the same as the main operation. - /// - /// For non-associative operations that still benefit from reassociation: - /// - Sub: `(x - c1) - c2` → `x - (c1 + c2)` (combine with add) - /// - Shl/Shr: `(x << c1) << c2` → `x << (c1 + c2)` (combine with add) - fn combine( - self, - c1: &ConstValue, - c2: &ConstValue, - ptr_size: PointerSize, - ) -> Option { - match self { - // Associative: combine with same operation - OpKind::Add | OpKind::Sub | OpKind::Shl | OpKind::Shr { .. } => c1.add(c2, ptr_size), - OpKind::Mul => c1.mul(c2, ptr_size), - OpKind::And => c1.bitwise_and(c2, ptr_size), - OpKind::Or => c1.bitwise_or(c2, ptr_size), - OpKind::Xor => c1.bitwise_xor(c2, ptr_size), - } - } - - /// Returns a description of the operation. - fn name(self) -> &'static str { - match self { - OpKind::Add => "add", - OpKind::Sub => "sub", - OpKind::Mul => "mul", - OpKind::And => "and", - OpKind::Or => "or", - OpKind::Xor => "xor", - OpKind::Shl => "shl", - OpKind::Shr { unsigned: false } => "shr", - OpKind::Shr { unsigned: true } => "shr.un", - } - } - - /// Returns the name of the combining operation for logging. - fn combine_name(self) -> &'static str { - match self { - // Associative: combine with same operation - OpKind::Add | OpKind::Mul | OpKind::And | OpKind::Or | OpKind::Xor => self.name(), - // Non-associative: combine with addition - OpKind::Sub | OpKind::Shl | OpKind::Shr { .. } => "add", - } - } - - /// Returns true if this operation is commutative. - /// - /// For non-commutative operations, the constant must be on the right side. - const fn is_commutative(self) -> bool { - match self { - OpKind::Add | OpKind::Mul | OpKind::And | OpKind::Or | OpKind::Xor => true, - OpKind::Sub | OpKind::Shl | OpKind::Shr { .. } => false, - } - } -} - -impl ReassociationPass { - /// Creates a new reassociation pass. - #[must_use] - pub fn new() -> Self { - Self - } - - /// Gets the OpKind if the operation can be reassociated. - /// - /// Returns (kind, dest, left_operand, right_operand). - /// For shift operations, left is the value being shifted, right is the amount. - fn get_op_kind(op: &SsaOp) -> Option<(OpKind, SsaVarId, SsaVarId, SsaVarId)> { - match op { - SsaOp::Add { dest, left, right } => Some((OpKind::Add, *dest, *left, *right)), - SsaOp::Sub { dest, left, right } => Some((OpKind::Sub, *dest, *left, *right)), - SsaOp::Mul { dest, left, right } => Some((OpKind::Mul, *dest, *left, *right)), - SsaOp::And { dest, left, right } => Some((OpKind::And, *dest, *left, *right)), - SsaOp::Or { dest, left, right } => Some((OpKind::Or, *dest, *left, *right)), - SsaOp::Xor { dest, left, right } => Some((OpKind::Xor, *dest, *left, *right)), - SsaOp::Shl { - dest, - value, - amount, - } => Some((OpKind::Shl, *dest, *value, *amount)), - SsaOp::Shr { - dest, - value, - amount, - unsigned, - } => Some(( - OpKind::Shr { - unsigned: *unsigned, - }, - *dest, - *value, - *amount, - )), - _ => None, - } - } - - /// Creates a new operation with the given operands. - /// - /// For shift operations, `left` is the value being shifted and `right` is the amount. - fn make_op(kind: OpKind, dest: SsaVarId, left: SsaVarId, right: SsaVarId) -> SsaOp { - match kind { - OpKind::Add => SsaOp::Add { dest, left, right }, - OpKind::Sub => SsaOp::Sub { dest, left, right }, - OpKind::Mul => SsaOp::Mul { dest, left, right }, - OpKind::And => SsaOp::And { dest, left, right }, - OpKind::Or => SsaOp::Or { dest, left, right }, - OpKind::Xor => SsaOp::Xor { dest, left, right }, - OpKind::Shl => SsaOp::Shl { - dest, - value: left, - amount: right, - }, - OpKind::Shr { unsigned } => SsaOp::Shr { - dest, - value: left, - amount: right, - unsigned, - }, - } - } - - /// Finds reassociation candidates. - fn find_candidates( - ssa: &SsaFunction, - constants: &BTreeMap, - index: &DefUseIndex, - uses: &BTreeMap, - ) -> Vec { - let mut candidates = Vec::new(); - - for (block_idx, instr_idx, instr) in ssa.iter_instructions() { - if let Some(candidate) = - Self::check_reassociation(instr.op(), block_idx, instr_idx, constants, index, uses) - { - candidates.push(candidate); - } - } - - candidates - } - - /// Checks if an operation can be reassociated. - fn check_reassociation( - op: &SsaOp, - block_idx: usize, - instr_idx: usize, - constants: &BTreeMap, - index: &DefUseIndex, - uses: &BTreeMap, - ) -> Option { - // Get the outer operation's kind and operands - let (outer_kind, dest, outer_left, outer_right) = Self::get_op_kind(op)?; - - // Try: (inner_result op c2) where inner_result = (x op c1) - // Check right operand is a constant - let c2_value = constants.get(&outer_right)?; - - // Check left operand is the result of a same-kind operation - // Use DefUseIndex to get (block, instruction, operation) in one call - let (inner_block, inner_instr, inner_op) = index.full_definition(outer_left)?; - let (inner_kind, inner_dest, inner_left, inner_right) = Self::get_op_kind(inner_op)?; - - // Must be the same operation kind - if inner_kind != outer_kind { - return None; - } - - // The inner result should only be used once (by this outer operation) - // Otherwise we'd create extra computation - let inner_uses = uses.get(&inner_dest).copied().unwrap_or(0); - if inner_uses > 1 { - return None; - } - - // Try to find a constant in the inner operation - // Case 1: inner_right is a constant (works for all operations) - // Pattern: (x op c1) op c2 → x op (c1 combine c2) - if let Some(c1_value) = constants.get(&inner_right) { - return Some(ReassociationCandidate { - block_idx, - instr_idx, - dest, - base_var: inner_left, - const1_var: inner_right, - const2_var: outer_right, - const1_value: c1_value.clone(), - const2_value: c2_value.clone(), - inner_block, - inner_instr, - inner_dest, - op_kind: outer_kind, - }); - } - - // Case 2: inner_left is a constant (only for commutative ops) - // Pattern: (c1 op x) op c2 → x op (c1 combine c2) - // This doesn't work for sub/shl/shr: (c1 - x) - c2 ≠ x - (c1 + c2) - if outer_kind.is_commutative() { - if let Some(c1_value) = constants.get(&inner_left) { - return Some(ReassociationCandidate { - block_idx, - instr_idx, - dest, - base_var: inner_right, - const1_var: inner_left, - const2_var: outer_right, - const1_value: c1_value.clone(), - const2_value: c2_value.clone(), - inner_block, - inner_instr, - inner_dest, - op_kind: outer_kind, - }); - } - } - - None - } - - /// Applies the reassociation transformations. - /// - /// Candidates from a single pass can overlap: in a chain `(A op B) op C) op K`, - /// the pass emits a candidate for `(Op(A,B), Op(t1,C))` *and* one for - /// `(Op(t1,C), Op(t2,K))`. Applying both blindly is unsafe because the second - /// candidate's captured `inner_instr` was rewritten by the first (into a - /// `Copy`), and unconditionally overwriting that `Copy` with a new binary op - /// re-introduces a stale operand — causing the middle constant to be applied - /// twice and cancelling itself under XOR. Skip any candidate whose inner or - /// outer instruction position has already been mutated in this pass. - fn apply_reassociations( - ssa: &mut SsaFunction, - candidates: Vec, - method_token: Token, - changes: &mut EventLog, - ptr_size: PointerSize, - ) { - let mut modified: HashSet<(usize, usize)> = HashSet::new(); - - for candidate in candidates { - // Skip if either position has already been rewritten by a prior - // overlapping candidate (see doc comment above). - if modified.contains(&(candidate.inner_block, candidate.inner_instr)) - || modified.contains(&(candidate.block_idx, candidate.instr_idx)) - { - continue; - } - - // Combine the constants - let Some(combined) = candidate.op_kind.combine( - &candidate.const1_value, - &candidate.const2_value, - ptr_size, - ) else { - continue; - }; - - // Update the first constant definition to the combined value - if let Some(block) = ssa.block_mut(candidate.inner_block) { - // Find and update the const1 definition - for instr in block.instructions_mut() { - if let SsaOp::Const { dest, value: _ } = instr.op() { - if *dest == candidate.const1_var { - instr.set_op(SsaOp::Const { - dest: *dest, - value: combined.clone(), - }); - break; - } - } - } - - // Update the inner operation to just use base_var and the combined constant - if let Some(inner_instr) = block.instructions_mut().get_mut(candidate.inner_instr) { - inner_instr.set_op(Self::make_op( - candidate.op_kind, - candidate.inner_dest, - candidate.base_var, - candidate.const1_var, - )); - } - } - - // Replace the outer operation with a Copy from the inner result - if let Some(block) = ssa.block_mut(candidate.block_idx) { - if let Some(outer_instr) = block.instructions_mut().get_mut(candidate.instr_idx) { - outer_instr.set_op(SsaOp::Copy { - dest: candidate.dest, - src: candidate.inner_dest, - }); - } - } - - modified.insert((candidate.inner_block, candidate.inner_instr)); - modified.insert((candidate.block_idx, candidate.instr_idx)); - - changes - .record(EventKind::ConstantFolded) - .at(method_token, candidate.instr_idx) - .message(format!( - "reassociate: (x {} c1) {} c2 → x {} (c1 {} c2)", - candidate.op_kind.name(), - candidate.op_kind.name(), - candidate.op_kind.name(), - candidate.op_kind.combine_name() - )); - } - } -} - -impl SsaPass for ReassociationPass { - fn name(&self) -> &'static str { - "reassociation" - } - - fn description(&self) -> &'static str { - "Reorder operations to enable constant folding (add, sub, mul, and, or, xor, shl, shr)" - } - - fn modification_scope(&self) -> ModificationScope { - ModificationScope::InstructionsOnly - } - - fn run_on_method( - &self, - ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - assembly: &CilObject, - ) -> Result { - let ptr_size = PointerSize::from_pe(assembly.file().pe().is_64bit); - let mut changes = EventLog::new(); - - // Gather information - let constants = ssa.find_constants(); - let index = DefUseIndex::build_with_ops(ssa); - let uses = ssa.count_uses(); - - // Find and apply reassociations - let candidates = Self::find_candidates(ssa, &constants, &index, &uses); - Self::apply_reassociations(ssa, candidates, method_token, &mut changes, ptr_size); - - let changed = !changes.is_empty(); - if changed { - ctx.events.merge(&changes); - } - Ok(changed) - } -} - -#[cfg(test)] -mod tests { - use crate::{ - analysis::ConstValue, compiler::passes::reassociate::OpKind, - metadata::typesystem::PointerSize, - }; - - #[test] - fn test_op_kind_combine_add() { - let c1 = ConstValue::I32(5); - let c2 = ConstValue::I32(3); - let result = OpKind::Add.combine(&c1, &c2, PointerSize::Bit64); - assert_eq!(result, Some(ConstValue::I32(8))); - } - - #[test] - fn test_op_kind_combine_xor() { - let c1 = ConstValue::I32(0xF0); - let c2 = ConstValue::I32(0x0F); - let result = OpKind::Xor.combine(&c1, &c2, PointerSize::Bit64); - assert_eq!(result, Some(ConstValue::I32(0xFF))); - } - - #[test] - fn test_op_kind_combine_mul() { - let c1 = ConstValue::I32(7); - let c2 = ConstValue::I32(11); - let result = OpKind::Mul.combine(&c1, &c2, PointerSize::Bit64); - assert_eq!(result, Some(ConstValue::I32(77))); - } - - #[test] - fn test_op_kind_combine_and() { - let c1 = ConstValue::I32(0xFF); - let c2 = ConstValue::I32(0x0F); - let result = OpKind::And.combine(&c1, &c2, PointerSize::Bit64); - assert_eq!(result, Some(ConstValue::I32(0x0F))); - } - - #[test] - fn test_op_kind_combine_or() { - let c1 = ConstValue::I32(0xF0); - let c2 = ConstValue::I32(0x0F); - let result = OpKind::Or.combine(&c1, &c2, PointerSize::Bit64); - assert_eq!(result, Some(ConstValue::I32(0xFF))); - } - - #[test] - fn test_op_kind_combine_sub() { - // (x - 5) - 3 → x - (5 + 3) → x - 8 - let c1 = ConstValue::I32(5); - let c2 = ConstValue::I32(3); - let result = OpKind::Sub.combine(&c1, &c2, PointerSize::Bit64); - assert_eq!(result, Some(ConstValue::I32(8))); - } - - #[test] - fn test_op_kind_combine_shl() { - // (x << 2) << 3 → x << (2 + 3) → x << 5 - let c1 = ConstValue::I32(2); - let c2 = ConstValue::I32(3); - let result = OpKind::Shl.combine(&c1, &c2, PointerSize::Bit64); - assert_eq!(result, Some(ConstValue::I32(5))); - } - - #[test] - fn test_op_kind_combine_shr() { - // (x >> 4) >> 2 → x >> (4 + 2) → x >> 6 - let c1 = ConstValue::I32(4); - let c2 = ConstValue::I32(2); - let result = OpKind::Shr { unsigned: false }.combine(&c1, &c2, PointerSize::Bit64); - assert_eq!(result, Some(ConstValue::I32(6))); - } - - #[test] - fn test_op_kind_combine_shr_unsigned() { - // (x >>> 4) >>> 2 → x >>> (4 + 2) → x >>> 6 - let c1 = ConstValue::I32(4); - let c2 = ConstValue::I32(2); - let result = OpKind::Shr { unsigned: true }.combine(&c1, &c2, PointerSize::Bit64); - assert_eq!(result, Some(ConstValue::I32(6))); - } - - #[test] - fn test_op_kind_is_commutative() { - assert!(OpKind::Add.is_commutative()); - assert!(OpKind::Mul.is_commutative()); - assert!(OpKind::And.is_commutative()); - assert!(OpKind::Or.is_commutative()); - assert!(OpKind::Xor.is_commutative()); - assert!(!OpKind::Sub.is_commutative()); - assert!(!OpKind::Shl.is_commutative()); - assert!(!OpKind::Shr { unsigned: false }.is_commutative()); - assert!(!OpKind::Shr { unsigned: true }.is_commutative()); - } - - #[test] - fn test_op_kind_combine_name() { - // Associative operations combine with themselves - assert_eq!(OpKind::Add.combine_name(), "add"); - assert_eq!(OpKind::Mul.combine_name(), "mul"); - assert_eq!(OpKind::And.combine_name(), "and"); - assert_eq!(OpKind::Or.combine_name(), "or"); - assert_eq!(OpKind::Xor.combine_name(), "xor"); - // Non-associative operations combine with add - assert_eq!(OpKind::Sub.combine_name(), "add"); - assert_eq!(OpKind::Shl.combine_name(), "add"); - assert_eq!(OpKind::Shr { unsigned: false }.combine_name(), "add"); - assert_eq!(OpKind::Shr { unsigned: true }.combine_name(), "add"); - } - - #[test] - fn test_op_kind_name() { - assert_eq!(OpKind::Add.name(), "add"); - assert_eq!(OpKind::Sub.name(), "sub"); - assert_eq!(OpKind::Mul.name(), "mul"); - assert_eq!(OpKind::And.name(), "and"); - assert_eq!(OpKind::Or.name(), "or"); - assert_eq!(OpKind::Xor.name(), "xor"); - assert_eq!(OpKind::Shl.name(), "shl"); - assert_eq!(OpKind::Shr { unsigned: false }.name(), "shr"); - assert_eq!(OpKind::Shr { unsigned: true }.name(), "shr.un"); - } -} diff --git a/dotscope/src/compiler/passes/strength.rs b/dotscope/src/compiler/passes/strength.rs index 9bf68145..93a8e051 100644 --- a/dotscope/src/compiler/passes/strength.rs +++ b/dotscope/src/compiler/passes/strength.rs @@ -1,55 +1,22 @@ -//! Strength reduction pass. +//! Strength reduction pass — thin wrapper. //! -//! This pass transforms expensive operations into cheaper equivalents: -//! -//! - **Multiplication by power of 2**: `x * 2^n` → `x << n` -//! - **Unsigned division by power of 2**: `x / 2^n` → `x >> n` (unsigned only) -//! - **Unsigned modulo by power of 2**: `x % 2^n` → `x & (2^n - 1)` (unsigned only) -//! -//! # Safety -//! -//! Signed division and modulo are NOT transformed because: -//! - Signed division rounds toward zero, shifts round toward negative infinity -//! - `-5 / 2 = -2` but `-5 >> 1 = -3` -//! - The transformation is only safe when the value is provably non-negative -//! -//! # Implementation Strategy -//! -//! The pass works by: -//! 1. Finding constant definitions and tracking their use counts -//! 2. Identifying reducible operations (mul/div/rem with power-of-2 constant) -//! 3. For single-use constants: modify the constant value in-place and transform the op -//! 4. For multi-use constants: skip (would require inserting new instructions) -//! -//! # Example -//! -//! Before: -//! ```text -//! v1 = const 8 -//! v2 = mul v0, v1 -//! ``` -//! -//! After: -//! ```text -//! v1 = const 3 -//! v2 = shl v0, v1 -//! ``` +//! Transformation logic lives in [`analyssa::passes::strength`]. The wrapper +//! supplies the host-side "is this variable provably non-negative?" +//! predicate — looked up via `CompilerContext::with_known_range` against +//! the range-analysis cache. + +use analyssa::passes::strength; use crate::{ - analysis::{ConstValue, DefUseIndex, SsaFunction, SsaOp, SsaVarId, ValueRange}, + analysis::{CilTarget, MethodRef, SsaFunction, ValueRange}, compiler::{ pass::{ModificationScope, SsaPass}, - CompilerContext, EventKind, EventLog, + CompilerContext, }, - metadata::token::Token, - utils::{is_power_of_two, BitSet}, - CilObject, Result, }; -/// Strength reduction pass that transforms expensive operations to cheaper equivalents. -/// -/// This pass identifies multiplication, division, and modulo operations where one -/// operand is a power of two, and transforms them to equivalent shift/mask operations. +/// Strength reduction pass that transforms expensive operations to cheaper +/// equivalents. pub struct StrengthReductionPass; impl Default for StrengthReductionPass { @@ -58,337 +25,15 @@ impl Default for StrengthReductionPass { } } -/// Location of an instruction in SSA form. -#[derive(Debug, Clone, Copy)] -struct InstrLocation { - /// Block index containing the instruction - block_idx: usize, - /// Instruction index within the block - instr_idx: usize, -} - -/// Helper for checking strength reduction candidates. -/// -/// Bundles the def-use index and used constants tracking to avoid -/// passing them through every reduction check function. -struct ReductionChecker<'a> { - /// Def-use index for the SSA function - index: &'a DefUseIndex, - /// Constants already used in other reductions (to avoid double-transform) - used_constants: &'a BitSet, -} - -impl<'a> ReductionChecker<'a> { - /// Creates a new reduction checker. - fn new(index: &'a DefUseIndex, used_constants: &'a BitSet) -> Self { - Self { - index, - used_constants, - } - } - - /// Tries to create a multiplication reduction candidate: `x * 2^n` → `x << n` - fn try_mul_reduction( - &self, - dest: SsaVarId, - value_var: SsaVarId, - const_var: SsaVarId, - location: InstrLocation, - ) -> Option { - // Check if const_var is a constant using DefUseIndex - let (const_block, const_instr, const_op) = self.index.full_definition(const_var)?; - let SsaOp::Const { - value: const_value, .. - } = const_op - else { - return None; - }; - - // Check if it's a power of two - let value = const_value.as_i64()?; - let exponent = is_power_of_two(value)?; - - // Check if constant is single-use (or we skip this reduction) - let uses = self.index.use_count(const_var); - if uses != 1 || self.used_constants.contains(const_var.index()) { - return None; - } - - Some(ReductionCandidate { - location, - const_var, - const_block, - const_instr, - new_const_value: ConstValue::I32(i32::from(exponent)), - new_op: SsaOp::Shl { - dest, - value: value_var, - amount: const_var, - }, - description: format!("mul x, {value} → shl x, {exponent}"), - }) - } - - /// Tries to create a division reduction candidate: `x / 2^n` → `x >> n` - fn try_div_reduction( - &self, - dest: SsaVarId, - dividend: SsaVarId, - divisor_var: SsaVarId, - unsigned: bool, - location: InstrLocation, - ) -> Option { - let (const_block, const_instr, const_op) = self.index.full_definition(divisor_var)?; - let SsaOp::Const { - value: const_value, .. - } = const_op - else { - return None; - }; - let value = const_value.as_i64()?; - let exponent = is_power_of_two(value)?; - - let uses = self.index.use_count(divisor_var); - if uses != 1 || self.used_constants.contains(divisor_var.index()) { - return None; - } - - let desc = if unsigned { - format!("div.un x, {value} → shr.un x, {exponent}") - } else { - format!("div x, {value} → shr x, {exponent} (x >= 0)") - }; - - Some(ReductionCandidate { - location, - const_var: divisor_var, - const_block, - const_instr, - new_const_value: ConstValue::I32(i32::from(exponent)), - new_op: SsaOp::Shr { - dest, - value: dividend, - amount: divisor_var, - unsigned, - }, - description: desc, - }) - } - - /// Tries to create a remainder reduction candidate: `x % 2^n` → `x & (2^n - 1)` - #[allow(clippy::cast_possible_truncation)] // mask fits in i32 for typical divisors - fn try_rem_reduction( - &self, - dest: SsaVarId, - dividend: SsaVarId, - divisor_var: SsaVarId, - unsigned: bool, - location: InstrLocation, - ) -> Option { - let (const_block, const_instr, const_op) = self.index.full_definition(divisor_var)?; - let SsaOp::Const { - value: const_value, .. - } = const_op - else { - return None; - }; - let value = const_value.as_i64()?; - let _exponent = is_power_of_two(value)?; - let mask = value.checked_sub(1)?; // 2^n - 1 - - let uses = self.index.use_count(divisor_var); - if uses != 1 || self.used_constants.contains(divisor_var.index()) { - return None; - } - - let desc = if unsigned { - format!("rem.un x, {value} → and x, {mask}") - } else { - format!("rem x, {value} → and x, {mask} (x >= 0)") - }; - - Some(ReductionCandidate { - location, - const_var: divisor_var, - const_block, - const_instr, - new_const_value: ConstValue::I32(mask as i32), - new_op: SsaOp::And { - dest, - left: dividend, - right: divisor_var, - }, - description: desc, - }) - } -} - -/// Information about a potential reduction. -#[derive(Debug)] -struct ReductionCandidate { - /// Location of the operation to reduce - location: InstrLocation, - /// The constant variable (power of 2) - const_var: SsaVarId, - /// Block where the constant is defined - const_block: usize, - /// Instruction index where the constant is defined - const_instr: usize, - /// The new constant value (shift amount or mask) - new_const_value: ConstValue, - /// The new operation to replace with - new_op: SsaOp, - /// Description for logging - description: String, -} - impl StrengthReductionPass { /// Creates a new strength reduction pass. #[must_use] pub fn new() -> Self { Self } - - /// Identifies reduction candidates in the SSA function. - fn find_candidates( - ssa: &SsaFunction, - index: &DefUseIndex, - ctx: &CompilerContext, - method_token: Token, - ) -> Vec { - let mut candidates = Vec::new(); - - // Set of constant variables that are already being transformed - // (to avoid transforming the same constant twice if used in multiple reductions) - let mut used_constants = BitSet::new(ssa.var_id_capacity()); - - for (block_idx, instr_idx, instr) in ssa.iter_instructions() { - let checker = ReductionChecker::new(index, &used_constants); - let location = InstrLocation { - block_idx, - instr_idx, - }; - if let Some(candidate) = - Self::check_reduction(instr.op(), location, &checker, ctx, method_token) - { - used_constants.insert(candidate.const_var.index()); - candidates.push(candidate); - } - } - - candidates - } - - /// Checks if an operation can be strength-reduced. - fn check_reduction( - op: &SsaOp, - location: InstrLocation, - checker: &ReductionChecker<'_>, - ctx: &CompilerContext, - method_token: Token, - ) -> Option { - match op { - // Multiplication: x * 2^n → x << n - SsaOp::Mul { dest, left, right } => { - // Try right operand first (more common: x * 8) - if let Some(candidate) = checker.try_mul_reduction(*dest, *left, *right, location) { - return Some(candidate); - } - // Try left operand (less common: 8 * x) - checker.try_mul_reduction(*dest, *right, *left, location) - } - - // Unsigned division: x / 2^n → x >> n (unsigned only) - SsaOp::Div { - dest, - left, - right, - unsigned: true, - } => checker.try_div_reduction(*dest, *left, *right, true, location), - - // Signed division: only when dividend is provably non-negative - SsaOp::Div { - dest, - left, - right, - unsigned: false, - } => { - // Check if left operand is provably non-negative - if Self::is_provably_non_negative(*left, ctx, method_token) { - checker.try_div_reduction(*dest, *left, *right, false, location) - } else { - None - } - } - - // Unsigned modulo: x % 2^n → x & (2^n - 1) - SsaOp::Rem { - dest, - left, - right, - unsigned: true, - } => checker.try_rem_reduction(*dest, *left, *right, true, location), - - // Signed modulo: only when dividend is provably non-negative - SsaOp::Rem { - dest, - left, - right, - unsigned: false, - } => { - if Self::is_provably_non_negative(*left, ctx, method_token) { - checker.try_rem_reduction(*dest, *left, *right, false, location) - } else { - None - } - } - - _ => None, - } - } - - /// Checks if a variable is provably non-negative via range analysis. - fn is_provably_non_negative(var: SsaVarId, ctx: &CompilerContext, method_token: Token) -> bool { - ctx.with_known_range(method_token, var, ValueRange::is_always_non_negative) - .unwrap_or(false) - } - - /// Applies the reduction candidates to the SSA function. - fn apply_reductions( - ssa: &mut SsaFunction, - candidates: Vec, - method_token: Token, - changes: &mut EventLog, - ) { - for candidate in candidates { - // First, update the constant definition - if let Some(block) = ssa.block_mut(candidate.const_block) { - if let Some(const_instr) = block.instructions_mut().get_mut(candidate.const_instr) { - const_instr.set_op(SsaOp::Const { - dest: candidate.const_var, - value: candidate.new_const_value, - }); - } - } - - // Then, update the operation - if let Some(block) = ssa.block_mut(candidate.location.block_idx) { - if let Some(instr) = block - .instructions_mut() - .get_mut(candidate.location.instr_idx) - { - instr.set_op(candidate.new_op); - changes - .record(EventKind::StrengthReduced) - .at(method_token, candidate.location.instr_idx) - .message(&candidate.description); - } - } - } - } } -impl SsaPass for StrengthReductionPass { +impl SsaPass for StrengthReductionPass { fn name(&self) -> &'static str { "strength-reduction" } @@ -404,23 +49,14 @@ impl SsaPass for StrengthReductionPass { fn run_on_method( &self, ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - _assembly: &CilObject, - ) -> Result { - let mut changes = EventLog::new(); - // Build DefUseIndex for definition and use tracking - let index = DefUseIndex::build_with_ops(ssa); - - let candidates = Self::find_candidates(ssa, &index, ctx, method_token); - - // Apply reductions - Self::apply_reductions(ssa, candidates, method_token, &mut changes); - - let changed = !changes.is_empty(); - if changed { - ctx.events.merge(&changes); - } - Ok(changed) + method: &MethodRef, + host: &CompilerContext, + ) -> analyssa::Result { + let token = method.0; + let is_non_negative = |var| { + host.with_known_range(token, var, ValueRange::is_always_non_negative) + .unwrap_or(false) + }; + Ok(strength::run(ssa, method, &host.events, &is_non_negative)) } } diff --git a/dotscope/src/compiler/passes/threading.rs b/dotscope/src/compiler/passes/threading.rs deleted file mode 100644 index ef85e410..00000000 --- a/dotscope/src/compiler/passes/threading.rs +++ /dev/null @@ -1,545 +0,0 @@ -//! Jump threading pass for semantic control flow simplification. -//! -//! This pass threads branches when the condition value can be determined from -//! a specific predecessor path. Unlike the basic trampoline threading in -//! [`ControlFlowSimplificationPass`], this pass evaluates branch conditions -//! based on known values. -//! -//! # Motivation -//! -//! After control-flow unflattening, we often have patterns like: -//! -//! ```text -//! B0: state = 5 -//! jump B1 -//! -//! B1: if (state > 0) goto B2 else goto B3 -//! ``` -//! -//! Jump threading recognizes that coming from B0, the condition `state > 0` -//! is always true (since state=5), and threads B0 directly to B2: -//! -//! ```text -//! B0: state = 5 -//! jump B2 // Threaded! -//! -//! B1: if (state > 0) goto B2 else goto B3 -//! ``` -//! -//! DCE will later clean up B1 if it becomes unreachable. -//! -//! # Algorithm -//! -//! For each block with a branch terminator: -//! 1. For each predecessor of the block -//! 2. Use path-aware evaluation to determine the condition value from that predecessor -//! 3. If the condition evaluates to a known constant, thread the predecessor -//! directly to the taken branch target -//! -//! [`ControlFlowSimplificationPass`]: super::ControlFlowSimplificationPass - -use crate::{ - analysis::{ConstValue, SsaCfg, SsaEvaluator, SsaFunction, SsaOp, SsaVarId}, - compiler::{pass::SsaPass, CompilerContext, EventKind, EventLog}, - metadata::{token::Token, typesystem::PointerSize}, - CilObject, Result, -}; - -/// Jump threading pass for semantic branch elimination. -/// -/// This pass evaluates branch conditions based on values known from specific -/// incoming paths, and threads predecessors directly to the taken target when -/// the branch outcome can be determined. -pub struct JumpThreadingPass; - -impl Default for JumpThreadingPass { - fn default() -> Self { - Self::new() - } -} - -impl JumpThreadingPass { - /// Creates a new jump threading pass. - #[must_use] - pub fn new() -> Self { - Self - } - - /// Evaluates a branch condition from a specific predecessor using path-aware evaluation. - /// - /// This uses `SsaEvaluator` to: - /// 1. Evaluate the predecessor block to establish known values - /// 2. Set the predecessor for phi node resolution - /// 3. Evaluate the branch block's phi nodes - /// 4. Resolve the condition value - /// - /// Returns the target block if the condition can be determined, None otherwise. - fn try_thread( - ssa: &SsaFunction, - pred_block: usize, - branch_block: usize, - condition: SsaVarId, - true_target: usize, - false_target: usize, - ptr_size: PointerSize, - ) -> Option { - let mut eval = SsaEvaluator::new(ssa, ptr_size); - - // Evaluate the predecessor block to establish any constant values - eval.evaluate_block(pred_block); - - // Set predecessor for phi resolution in the branch block - eval.set_predecessor(Some(pred_block)); - - // Evaluate phi nodes in the branch block (this resolves phis using the predecessor) - eval.evaluate_phis(branch_block); - - // Try to resolve the condition value with tracing - let cond_value = eval - .get_concrete(condition) - .and_then(ConstValue::as_i64) - .or_else(|| { - eval.resolve_with_trace(condition, 10) - .and_then(|e| e.as_i64()) - })?; - - if cond_value != 0 { - Some(true_target) - } else { - Some(false_target) - } - } - - /// Applies threading by updating the predecessor's terminator. - fn apply_threading( - ssa: &mut SsaFunction, - pred_block: usize, - _branch_block: usize, - new_target: usize, - method_token: Token, - changes: &mut EventLog, - ) -> bool { - let Some(block) = ssa.block_mut(pred_block) else { - return false; - }; - - let Some(last) = block.instructions_mut().last_mut() else { - return false; - }; - - match last.op().clone() { - SsaOp::Jump { target } if target != new_target => { - last.set_op(SsaOp::Jump { target: new_target }); - changes - .record(EventKind::ControlFlowRestructured) - .at(method_token, pred_block) - .message(format!( - "jump threaded: B{pred_block} now jumps to B{new_target} (was B{target})" - )); - true - } - SsaOp::Branch { - condition, - true_target, - false_target, - } => { - // For branches, we thread to the known target - // We convert the branch to a jump since we know which way it goes - let old_target = if new_target == true_target { - false_target - } else { - true_target - }; - last.set_op(SsaOp::Jump { target: new_target }); - changes - .record(EventKind::BranchSimplified) - .at(method_token, pred_block) - .message(format!( - "branch threaded: B{pred_block} condition on {condition:?} resolved to B{new_target} (eliminated B{old_target})" - )); - true - } - SsaOp::Leave { target } if target != new_target => { - last.set_op(SsaOp::Leave { target: new_target }); - changes - .record(EventKind::ControlFlowRestructured) - .at(method_token, pred_block) - .message(format!( - "leave threaded: B{pred_block} now leaves to B{new_target} (was B{target})" - )); - true - } - _ => false, - } - } - - /// Runs jump threading on the SSA function. - fn run_threading( - ssa: &mut SsaFunction, - method_token: Token, - changes: &mut EventLog, - ptr_size: PointerSize, - ) -> bool { - if ssa.is_empty() { - return false; - } - - let cfg = SsaCfg::from_ssa(ssa); - - // Collect threading opportunities first (to avoid borrow issues) - let mut threadings: Vec<(usize, usize, usize)> = Vec::new(); - - for (block_idx, block) in ssa.iter_blocks() { - // Look for branch terminators - let Some(SsaOp::Branch { - condition, - true_target, - false_target, - }) = block.terminator_op() - else { - continue; - }; - - // For each predecessor, check if we can thread - for pred_idx in cfg.block_predecessors(block_idx) { - if let Some(target) = Self::try_thread( - ssa, - *pred_idx, - block_idx, - *condition, - *true_target, - *false_target, - ptr_size, - ) { - // Only thread if we're actually changing the control flow - // (i.e., the predecessor doesn't already go directly to target) - let pred_target = ssa.block(*pred_idx).and_then(|b| { - b.terminator_op().and_then(|op| match op { - SsaOp::Jump { target } | SsaOp::Leave { target } => Some(*target), - _ => None, - }) - }); - - if pred_target != Some(target) { - threadings.push((*pred_idx, block_idx, target)); - } - } - } - } - - // Apply all threadings - let mut changed = false; - for (pred_block, branch_block, new_target) in threadings { - if Self::apply_threading( - ssa, - pred_block, - branch_block, - new_target, - method_token, - changes, - ) { - changed = true; - } - } - - changed - } -} - -impl SsaPass for JumpThreadingPass { - fn name(&self) -> &'static str { - "jump-threading" - } - - fn description(&self) -> &'static str { - "Threads branches when condition is known from predecessor path" - } - - fn run_on_method( - &self, - ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - assembly: &CilObject, - ) -> Result { - let ptr_size = PointerSize::from_pe(assembly.file().pe().is_64bit); - let mut changes = EventLog::new(); - let changed = Self::run_threading(ssa, method_token, &mut changes, ptr_size); - - if changed { - ctx.events.merge(&changes); - } - - Ok(changed) - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use crate::{ - analysis::{CallGraph, ConstValue, SsaBlock, SsaFunction, SsaInstruction, SsaOp, SsaVarId}, - compiler::{passes::threading::JumpThreadingPass, CompilerContext, SsaPass}, - metadata::token::Token, - test::helpers::test_assembly_arc, - }; - - fn test_context() -> CompilerContext { - let call_graph = Arc::new(CallGraph::new()); - CompilerContext::new(call_graph) - } - - #[test] - fn test_empty_function() { - let pass = JumpThreadingPass::new(); - let ctx = test_context(); - let mut ssa = SsaFunction::new(0, 0); - - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - - assert!(!changed); - } - - #[test] - fn test_no_branches() { - let pass = JumpThreadingPass::new(); - let ctx = test_context(); - - // B0: jump to B1 - // B1: return - let mut block0 = SsaBlock::new(0); - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 1 })); - - let mut block1 = SsaBlock::new(1); - block1.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - - let mut ssa = SsaFunction::new(0, 0); - ssa.add_block(block0); - ssa.add_block(block1); - - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - - assert!(!changed); - } - - #[test] - fn test_thread_with_constant_true() { - let pass = JumpThreadingPass::new(); - let ctx = test_context(); - - // B0: cond = true; jump B1 - // B1: if cond goto B2 else B3 - // Should thread B0 directly to B2 - let cond_var = SsaVarId::from_index(0); - - // Block 0: const true, jump to 1 - let mut block0 = SsaBlock::new(0); - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: cond_var, - value: ConstValue::True, - })); - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 1 })); - - // Block 1: branch on cond_var - let mut block1 = SsaBlock::new(1); - block1.add_instruction(SsaInstruction::synthetic(SsaOp::Branch { - condition: cond_var, - true_target: 2, - false_target: 3, - })); - - // Block 2 and 3: return - let mut block2 = SsaBlock::new(2); - block2.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - - let mut block3 = SsaBlock::new(3); - block3.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - - let mut ssa = SsaFunction::new(0, 0); - ssa.add_block(block0); - ssa.add_block(block1); - ssa.add_block(block2); - ssa.add_block(block3); - - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - - assert!(changed); - - // Verify B0 now jumps directly to B2 - if let Some(block) = ssa.block(0) { - assert!(matches!( - block.terminator_op(), - Some(SsaOp::Jump { target: 2 }) - )); - } - } - - #[test] - fn test_thread_with_constant_false() { - let pass = JumpThreadingPass::new(); - let ctx = test_context(); - - let cond_var = SsaVarId::from_index(0); - - // Block 0: const false, jump to 1 - let mut block0 = SsaBlock::new(0); - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: cond_var, - value: ConstValue::False, - })); - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 1 })); - - // Block 1: branch on cond_var - let mut block1 = SsaBlock::new(1); - block1.add_instruction(SsaInstruction::synthetic(SsaOp::Branch { - condition: cond_var, - true_target: 2, - false_target: 3, - })); - - // Block 2 and 3: return - let mut block2 = SsaBlock::new(2); - block2.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - - let mut block3 = SsaBlock::new(3); - block3.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - - let mut ssa = SsaFunction::new(0, 0); - ssa.add_block(block0); - ssa.add_block(block1); - ssa.add_block(block2); - ssa.add_block(block3); - - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - - assert!(changed); - - // Verify B0 now jumps directly to B3 (false branch) - if let Some(block) = ssa.block(0) { - assert!(matches!( - block.terminator_op(), - Some(SsaOp::Jump { target: 3 }) - )); - } - } - - #[test] - fn test_thread_comparison_greater() { - let pass = JumpThreadingPass::new(); - let ctx = test_context(); - - let x_var = SsaVarId::from_index(0); - let zero_var = SsaVarId::from_index(1); - let cmp_var = SsaVarId::from_index(2); - - // Block 0: x = 5; zero = 0; cmp = (x > zero); jump to 1 - let mut block0 = SsaBlock::new(0); - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: x_var, - value: ConstValue::I32(5), - })); - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Const { - dest: zero_var, - value: ConstValue::I32(0), - })); - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Cgt { - dest: cmp_var, - left: x_var, - right: zero_var, - unsigned: false, - })); - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 1 })); - - // Block 1: branch on cmp_var - let mut block1 = SsaBlock::new(1); - block1.add_instruction(SsaInstruction::synthetic(SsaOp::Branch { - condition: cmp_var, - true_target: 2, - false_target: 3, - })); - - // Block 2 and 3: return - let mut block2 = SsaBlock::new(2); - block2.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - - let mut block3 = SsaBlock::new(3); - block3.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - - let mut ssa = SsaFunction::new(0, 0); - ssa.add_block(block0); - ssa.add_block(block1); - ssa.add_block(block2); - ssa.add_block(block3); - - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - - assert!(changed); - - // Verify B0 now jumps directly to B2 (true branch, since 5 > 0) - if let Some(block) = ssa.block(0) { - assert!(matches!( - block.terminator_op(), - Some(SsaOp::Jump { target: 2 }) - )); - } - } - - #[test] - fn test_no_thread_unknown_condition() { - let pass = JumpThreadingPass::new(); - let ctx = test_context(); - - // Block 0: jump to 1 - // Block 1: branch on x (which has no known definition) - // Should NOT thread since x is unknown - let x_var = SsaVarId::from_index(0); - - // x has no definition - simulating an argument or external value - let mut block0 = SsaBlock::new(0); - block0.add_instruction(SsaInstruction::synthetic(SsaOp::Jump { target: 1 })); - - let mut block1 = SsaBlock::new(1); - block1.add_instruction(SsaInstruction::synthetic(SsaOp::Branch { - condition: x_var, - true_target: 2, - false_target: 3, - })); - - let mut block2 = SsaBlock::new(2); - block2.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - - let mut block3 = SsaBlock::new(3); - block3.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); - - let mut ssa = SsaFunction::new(1, 0); - ssa.add_block(block0); - ssa.add_block(block1); - ssa.add_block(block2); - ssa.add_block(block3); - - let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000001), &ctx, &test_assembly_arc()) - .unwrap(); - - // Should NOT change since condition is unknown - assert!(!changed); - } - - #[test] - fn test_pass_name_and_description() { - let pass = JumpThreadingPass::new(); - assert_eq!(pass.name(), "jump-threading"); - assert!(!pass.description().is_empty()); - } -} diff --git a/dotscope/src/compiler/passes/utils.rs b/dotscope/src/compiler/passes/utils.rs deleted file mode 100644 index c807909d..00000000 --- a/dotscope/src/compiler/passes/utils.rs +++ /dev/null @@ -1,109 +0,0 @@ -//! Shared utilities for SSA transformation passes. -//! -//! This module contains common functionality used by multiple passes to avoid -//! code duplication and ensure consistent behavior. - -use std::collections::{BTreeMap, BTreeSet}; - -/// Follows a chain of mappings to find the ultimate target. -/// -/// Given a map where keys point to values, and values may also be keys, -/// this function follows the chain until reaching a value that is not a key. -/// Handles cycles by stopping when a previously visited key is encountered. -/// -/// This is useful for: -/// - Following trampoline chains (block → block → block) -/// - Resolving copy chains (var → var → var) -/// - Any other transitive closure computation -/// -/// # Arguments -/// -/// * `map` - The mapping to follow. -/// * `start` - The starting key. -/// -/// # Returns -/// -/// The ultimate target after following the chain. -/// -/// # Example -/// -/// ```ignore -/// let mut trampolines = BTreeMap::new(); -/// trampolines.insert(1, 2); -/// trampolines.insert(2, 3); -/// // 1 -> 2 -> 3 -/// assert_eq!(resolve_chain(&trampolines, 1), 3); -/// ``` -#[must_use] -pub fn resolve_chain(map: &BTreeMap, start: K) -> K -where - K: Copy + Ord, -{ - let mut current = start; - let mut visited = BTreeSet::new(); - - while let Some(&next) = map.get(¤t) { - if !visited.insert(current) { - // Cycle detected - return current position - break; - } - current = next; - } - - current -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_resolve_chain_follows_mappings() { - let mut map = BTreeMap::new(); - map.insert(1, 2); - map.insert(2, 3); - map.insert(3, 4); - - // Should follow the full chain: 1 -> 2 -> 3 -> 4 - assert_eq!(resolve_chain(&map, 1), 4); - assert_eq!(resolve_chain(&map, 2), 4); - assert_eq!(resolve_chain(&map, 3), 4); - // Value not in map returns itself - assert_eq!(resolve_chain(&map, 4), 4); - } - - #[test] - fn test_resolve_chain_handles_cycles() { - let mut map = BTreeMap::new(); - map.insert(1, 2); - map.insert(2, 1); // Cycle: 1 -> 2 -> 1 - - // Should terminate without infinite loop - let result = resolve_chain(&map, 1); - assert!(result == 1 || result == 2); - } - - #[test] - fn test_resolve_chain_single_step() { - let mut map = BTreeMap::new(); - map.insert(5, 10); - - assert_eq!(resolve_chain(&map, 5), 10); - } - - #[test] - fn test_resolve_chain_empty_map() { - let map: BTreeMap = BTreeMap::new(); - // Value not in map returns itself - assert_eq!(resolve_chain(&map, 42), 42); - } - - #[test] - fn test_resolve_chain_self_loop() { - let mut map = BTreeMap::new(); - map.insert(1, 1); // Self-loop - - // Should terminate - assert_eq!(resolve_chain(&map, 1), 1); - } -} diff --git a/dotscope/src/compiler/scheduler.rs b/dotscope/src/compiler/scheduler.rs index 2fd3ef97..228e3672 100644 --- a/dotscope/src/compiler/scheduler.rs +++ b/dotscope/src/compiler/scheduler.rs @@ -1,62 +1,39 @@ -//! Pass scheduler for orchestrating SSA pass execution. +//! CIL-side pass scheduler — thin wrapper around +//! [`analyssa::scheduling::PassScheduler`]. //! -//! The [`PassScheduler`] manages the execution of SSA optimization passes using -//! capability-based dependency scheduling. Passes declare what they provide and -//! require via [`PassCapability`](super::PassCapability), and the scheduler -//! topologically sorts them into execution layers. Each layer runs to fixpoint -//! with normalization between iterations. +//! All scheduling logic (capability layer assignment, fixpoint +//! iteration, parallel per-method dispatch, modification-scope-driven +//! repair) lives in analyssa. This module provides a CIL-flavored facade +//! that: //! -//! # Layer Computation +//! - Maps [`PassPhase`] enum values to fallback layer numbers. +//! - Sets the assembly handle on the context before invoking the +//! underlying scheduler. +//! - Bridges error types between analyssa and dotscope. //! -//! Passes that don't declare capabilities fall back to a numeric layer derived -//! from their original phase assignment (Structure=0, Value=1, Simplify=2, -//! Inline=3). Passes that declare capabilities may be moved to a later layer -//! to satisfy their dependencies. -//! -//! # Normalization -//! -//! Normalize passes (DCE, constant propagation, GVN, etc.) are separate from -//! the layered passes. They run between every layer's fixpoint iterations, -//! cleaning up after each round of structural changes to expose new -//! optimization opportunities. +//! Hosts targeting other ISAs use analyssa's scheduler directly. -use std::{ - collections::HashMap, - sync::{ - atomic::{AtomicBool, AtomicUsize, Ordering}, - Arc, - }, -}; +use std::sync::Arc; -use dashmap::DashSet; -use log::debug; -use rayon::prelude::*; +use analyssa::scheduling::PassScheduler as AnalyssaPassScheduler; use crate::{ + analysis::CilTarget, compiler::{ context::CompilerContext, - events::EventKind, - pass::{ModificationScope, PassCapability, PassPhase, SsaPass}, + pass::{PassPhase, SsaPass}, state::ProcessingState, }, - metadata::token::Token, - utils::graph::IndexedGraph, CilObject, Error, Result, }; -/// Orchestrates SSA pass execution using capability-based scheduling. +/// Orchestrates CIL SSA pass execution. /// -/// Passes are organized into execution layers computed from their declared -/// capabilities. Each layer runs all its passes to fixpoint with normalization -/// between iterations. The entire pipeline then repeats until global fixpoint -/// or max iterations. -/// -/// # Layer Computation -/// -/// 1. Each pass starts at a fallback layer based on its phase assignment. -/// 2. If pass A provides capability X and pass B requires X, B is pushed -/// to a layer strictly after A. -/// 3. Cycles in the dependency graph are detected and reported as errors. +/// Wraps [`analyssa::scheduling::PassScheduler`] +/// with CIL-specific phase mapping. Passes are added with a +/// [`PassPhase`]; layered passes go through capability-based scheduling +/// in analyssa, normalize passes interleave between layer fixpoint +/// iterations. /// /// # Example /// @@ -68,17 +45,7 @@ use crate::{ /// scheduler.run_pipeline(&ctx, &assembly, None)?; /// ``` pub struct PassScheduler { - /// Maximum iterations for the entire pipeline. - max_iterations: usize, - /// Number of stable iterations before stopping. - stable_iterations: usize, - /// Maximum iterations for a single layer before moving on. - max_phase_iterations: usize, - /// All non-normalize passes with their fallback layer number. - passes: Vec<(Box, usize)>, - /// Normalization passes (DCE, GVN, const/copy propagation). - /// Run after each layer to clean up before the next. - normalize: Vec>, + inner: AnalyssaPassScheduler, } impl Default for PassScheduler { @@ -96,10 +63,6 @@ impl PassScheduler { /// * `stable_iterations` - Stop early if no changes for this many consecutive iterations. /// * `max_phase_iterations` - Maximum fixpoint iterations for a single layer before /// moving to the next. - /// - /// # Returns - /// - /// A new `PassScheduler` with no passes registered. #[must_use] pub fn new( max_iterations: usize, @@ -107,977 +70,65 @@ impl PassScheduler { max_phase_iterations: usize, ) -> Self { Self { - max_iterations, - stable_iterations, - max_phase_iterations, - passes: Vec::new(), - normalize: Vec::new(), + inner: AnalyssaPassScheduler::new( + max_iterations, + stable_iterations, + max_phase_iterations, + ), } } /// Returns the number of non-normalize passes registered. #[must_use] pub fn pass_count(&self) -> usize { - self.passes.len() + self.inner.pass_count() } /// Returns the number of normalization passes registered. #[must_use] pub fn normalize_count(&self) -> usize { - self.normalize.len() + self.inner.normalize_count() } /// Adds a pass to the scheduler with its execution phase. /// - /// Layered passes (`Structure`, `Value`, `Simplify`, `Inline`) are placed - /// into execution layers based on their phase and capability dependencies. - /// `Normalize` passes run between every layer's fixpoint iterations and are - /// excluded from the capability dependency graph. - /// - /// If a layered pass declares capabilities (via [`SsaPass::provides`] / - /// [`SsaPass::requires`]), the scheduler may place it in a later layer - /// to satisfy dependency constraints. - /// - /// # Arguments - /// - /// * `pass` - The SSA pass to register. - /// * `phase` - The execution phase determining when this pass runs. - pub fn add(&mut self, pass: Box, phase: PassPhase) { + /// `Normalize` passes go to the analyssa normalize-pass list (run + /// between every layer's fixpoint iterations). All other phases map + /// to their numeric fallback layer via [`PassPhase::as_layer`]. + pub fn add(&mut self, pass: Box>, phase: PassPhase) { match phase { - PassPhase::Normalize => self.normalize.push(pass), - _ => self.passes.push((pass, phase.as_layer())), - } - } - - /// Computes execution layer assignments from capability dependencies. - /// - /// The algorithm has three phases: - /// - /// 1. **Graph construction**: Builds a directed graph using [`IndexedGraph`] - /// where an edge from pass A to pass B means "A must run before B" - /// (A provides a capability that B requires). - /// - /// 2. **Cycle validation**: Runs topological sort on the graph. If it fails, - /// the graph contains a cycle and the passes cannot be scheduled. - /// - /// 3. **Layer assignment**: Each pass starts at its fallback layer, then - /// Bellman-Ford relaxation pushes passes forward until every dependency - /// constraint `layer[dependent] > layer[provider]` is satisfied. - /// - /// Unsatisfied requirements (no provider registered for a required capability) - /// are silently ignored — the pass stays at its fallback layer. This allows - /// e.g. CFF to run without `Int32ValueContainer` when JIEJIE.NET is not detected. - /// - /// # Returns - /// - /// A `Vec` where element `i` is the layer number for `self.passes[i]`. - /// - /// # Errors - /// - /// Returns [`crate::Error::SsaError`] if a cycle is detected in the capability - /// dependencies, including the names of the passes involved in the cycle. - fn compute_layer_assignment(&self) -> Result> { - let n = self.passes.len(); - if n == 0 { - return Ok(vec![]); - } - - // Build capability -> provider indices map - let mut providers: HashMap> = HashMap::new(); - for (i, (pass, _)) in self.passes.iter().enumerate() { - for &cap in pass.provides() { - providers.entry(cap).or_default().push(i); - } - } - - // Build dependency graph: edge from provider → dependent - let mut graph: IndexedGraph = IndexedGraph::with_capacity(n, n); - for i in 0..n { - graph.add_node(i); - } - - // deps[i] = indices of passes that must run before pass i - let mut deps: Vec> = vec![vec![]; n]; - for (i, (pass, _)) in self.passes.iter().enumerate() { - for &cap in pass.requires() { - if let Some(provider_indices) = providers.get(&cap) { - for &j in provider_indices { - if j != i { - if let Some(slot) = deps.get_mut(i) { - slot.push(j); - } - let _ = graph.add_edge(j, i, ()); - } - } - } - } - } - - // Validate the DAG is acyclic via topological sort - if graph.topological_sort().is_none() { - if let Some(cycle) = graph.find_any_cycle() { - let names: Vec<&str> = cycle - .iter() - .filter_map(|&i| self.passes.get(i).map(|p| p.0.name())) - .collect(); - return Err(Error::SsaError(format!( - "Cycle detected in pass capability dependencies: {}", - names.join(" → ") - ))); - } - return Err(Error::SsaError( - "Cycle detected in pass capability dependencies".to_string(), - )); - } - - // Bellman-Ford relaxation: push layers forward to satisfy dependencies. - // Invariant: after convergence, layer[i] > layer[dep] for all deps of i. - let mut layer: Vec = self.passes.iter().map(|(_, fallback)| *fallback).collect(); - let mut changed = true; - while changed { - changed = false; - for i in 0..n { - let dep_list = match deps.get(i) { - Some(d) => d.clone(), - None => continue, - }; - for dep in dep_list { - let layer_i = layer.get(i).copied().unwrap_or(0); - let layer_dep = layer.get(dep).copied().unwrap_or(0); - if layer_i <= layer_dep { - if let Some(slot) = layer.get_mut(i) { - *slot = layer_dep.saturating_add(1); - } - changed = true; - } - } - } - } - - // Log any passes that were moved from their fallback layer - if !deps.iter().all(Vec::is_empty) { - let max_layer = layer.iter().copied().max().unwrap_or(0); - debug!( - "Capability scheduling: {} passes across {} layers", - n, - max_layer.saturating_add(1) - ); - for (i, (pass, fallback)) in self.passes.iter().enumerate() { - let layer_i = layer.get(i).copied().unwrap_or(*fallback); - if layer_i != *fallback { - debug!( - " pass '{}': layer {} (moved from fallback {})", - pass.name(), - layer_i, - fallback - ); - } - } - } - - Ok(layer) - } - - /// Runs normalization passes repeatedly until no pass reports changes. - /// - /// Each iteration runs all normalize passes once. If any pass makes changes, - /// another iteration begins. Stops when a full iteration produces no changes - /// or `max_phase_iterations` is reached. - /// - /// # Arguments - /// - /// * `ctx` - The compiler context (shared state, SSA functions, events). - /// * `passes` - The normalization passes to run. - /// * `max_phase_iterations` - Maximum fixpoint iterations before giving up. - /// * `assembly` - Shared reference to the assembly for pass lookups. - /// - /// # Returns - /// - /// `true` if any pass made changes across all iterations, `false` otherwise. - /// - /// # Errors - /// - /// Returns an error if any pass fails during execution. - fn normalize_to_fixpoint( - ctx: &CompilerContext, - passes: &mut [Box], - max_phase_iterations: usize, - assembly: &Arc, - state: Option<&ProcessingState>, - iteration_modified: Option<&DashSet>, - ) -> Result { - let mut any_changed = false; - - for _ in 0..max_phase_iterations { - let changed = Self::run_passes_once(ctx, passes, assembly, state, iteration_modified)?; - - if !changed { - break; - } - - any_changed = true; - } - - Ok(any_changed) - } - - /// Runs a single execution layer to fixpoint with normalization. - /// - /// Each fixpoint iteration: - /// 1. Runs all passes in the layer once across all methods. - /// 2. If any pass made changes, runs normalization to fixpoint. - /// 3. Repeats until no layer pass makes changes or `max_phase_iterations` - /// is reached. - /// - /// # Arguments - /// - /// * `ctx` - The compiler context. - /// * `all_passes` - The full pass list (layer passes are selected by index). - /// * `layer_indices` - Indices into `all_passes` for this layer's passes. - /// * `normalize_passes` - Normalization passes to run between iterations. - /// * `max_phase_iterations` - Maximum fixpoint iterations for this layer. - /// * `assembly` - Shared reference to the assembly. - /// - /// # Returns - /// - /// `true` if any pass made changes during this layer's execution. - /// - /// # Errors - /// - /// Returns an error if any pass fails during execution. - #[allow(clippy::too_many_arguments)] - fn layer_to_fixpoint( - ctx: &CompilerContext, - all_passes: &mut [(Box, usize)], - layer_indices: &[usize], - normalize_passes: &mut [Box], - max_phase_iterations: usize, - assembly: &Arc, - state: Option<&ProcessingState>, - iteration_modified: Option<&DashSet>, - ) -> Result { - if layer_indices.is_empty() { - return Ok(false); - } - - let mut phase_changed = false; - - for _ in 0..max_phase_iterations { - let pass_changed = Self::run_layer_passes_once( - ctx, - all_passes, - layer_indices, - assembly, - state, - iteration_modified, - )?; - - if !pass_changed { - // Layer converged. Run normalize one final time to clean up - // any modifications the last layer iteration made to SSA - // (e.g., CFF unflattening rebuilds SSA, proxy devirt needs - // to see the final state). - if phase_changed && !normalize_passes.is_empty() { - Self::normalize_to_fixpoint( - ctx, - normalize_passes, - max_phase_iterations, - assembly, - state, - iteration_modified, - )?; - } - break; - } - - phase_changed = true; - - if !normalize_passes.is_empty() { - Self::normalize_to_fixpoint( - ctx, - normalize_passes, - max_phase_iterations, - assembly, - state, - iteration_modified, - )?; - } - } - - Ok(phase_changed) - } - - /// Runs a contiguous slice of passes once over all methods. - /// - /// Used for normalization passes, which are stored as a contiguous - /// `Vec>`. For layer passes (which are a subset of - /// the full pass list), use [`run_layer_passes_once`](Self::run_layer_passes_once). - /// - /// The execution order is: - /// 1. Initialize all passes ([`SsaPass::initialize`]). - /// 2. Run global passes sequentially ([`SsaPass::run_global`]). - /// 3. For each non-global pass, run it across all methods in parallel - /// via [`run_single_pass`](Self::run_single_pass). - /// 4. Finalize all passes ([`SsaPass::finalize`]). - /// - /// # Arguments - /// - /// * `ctx` - The compiler context. - /// * `passes` - The passes to execute (typically normalization passes). - /// * `assembly` - Shared reference to the assembly. - /// - /// # Returns - /// - /// `true` if any pass made changes, `false` otherwise. - /// - /// # Errors - /// - /// Returns an error if any pass fails during initialization, execution, - /// or finalization. - fn run_passes_once( - ctx: &CompilerContext, - passes: &mut [Box], - assembly: &Arc, - state: Option<&ProcessingState>, - iteration_modified: Option<&DashSet>, - ) -> Result { - for pass in passes.iter_mut() { - pass.initialize(ctx)?; - } - - // Dirty filtering: non-full-scan passes see only dirty methods - let dirty_set = state.map(|s| &s.method_dirty); - let all_methods = Self::method_order(ctx, None); - let dirty_methods = Self::method_order(ctx, dirty_set); - let any_changed = AtomicBool::new(false); - - for pass in passes.iter() { - if pass.is_global() && pass.run_global(ctx, assembly)? { - any_changed.store(true, Ordering::Relaxed); - } - } - - for pass in passes.iter() { - if pass.is_global() { - continue; - } - let methods = if pass.requires_full_scan() { - &all_methods - } else { - &dirty_methods - }; - Self::run_single_pass( - pass.as_ref(), - ctx, - methods, - assembly, - &any_changed, - iteration_modified, - ); - } - - for pass in passes.iter_mut() { - pass.finalize(ctx)?; - } - - Ok(any_changed.load(Ordering::Relaxed)) - } - - /// Runs a subset of passes (identified by indices) once over all methods. - /// - /// Used for layer execution, where the passes to run are a non-contiguous - /// subset of `all_passes` identified by `indices`. The execution follows - /// the same init → global → per-method → finalize order as - /// [`run_passes_once`](Self::run_passes_once). - /// - /// # Arguments - /// - /// * `ctx` - The compiler context. - /// * `all_passes` - The full pass list (with fallback layer metadata). - /// * `indices` - Indices into `all_passes` selecting this layer's passes. - /// * `assembly` - Shared reference to the assembly. - /// - /// # Returns - /// - /// `true` if any pass made changes, `false` otherwise. - /// - /// # Errors - /// - /// Returns an error if any pass fails during initialization, execution, - /// or finalization. - fn run_layer_passes_once( - ctx: &CompilerContext, - all_passes: &mut [(Box, usize)], - indices: &[usize], - assembly: &Arc, - state: Option<&ProcessingState>, - iteration_modified: Option<&DashSet>, - ) -> Result { - for &idx in indices { - let pass_entry = all_passes.get_mut(idx).ok_or_else(|| { - Error::SsaError(format!("scheduler: pass index {idx} out of bounds")) - })?; - pass_entry.0.initialize(ctx)?; - } - - let dirty_set = state.map(|s| &s.method_dirty); - let all_methods = Self::method_order(ctx, None); - let dirty_methods = Self::method_order(ctx, dirty_set); - let any_changed = AtomicBool::new(false); - - for &idx in indices { - let pass_entry = all_passes.get(idx).ok_or_else(|| { - Error::SsaError(format!("scheduler: pass index {idx} out of bounds")) - })?; - let pass = &pass_entry.0; - if pass.is_global() && pass.run_global(ctx, assembly)? { - any_changed.store(true, Ordering::Relaxed); - } - } - - for &idx in indices { - let pass_entry = all_passes.get(idx).ok_or_else(|| { - Error::SsaError(format!("scheduler: pass index {idx} out of bounds")) - })?; - let pass = &pass_entry.0; - if pass.is_global() { - continue; - } - let methods = if pass.requires_full_scan() { - &all_methods - } else { - &dirty_methods - }; - Self::run_single_pass( - pass.as_ref(), - ctx, - methods, - assembly, - &any_changed, - iteration_modified, - ); - } - - for &idx in indices { - let pass_entry = all_passes.get_mut(idx).ok_or_else(|| { - Error::SsaError(format!("scheduler: pass index {idx} out of bounds")) - })?; - pass_entry.0.finalize(ctx)?; - } - - Ok(any_changed.load(Ordering::Relaxed)) - } - - /// Computes the method processing order for parallel pass execution. - /// - /// Returns methods sorted in reverse topological order of the call graph - /// (callees before callers), filtered to only methods that have SSA - /// representations. When `dirty_only` is provided, further filters to - /// only methods in the dirty set. - /// - /// Falls back to arbitrary iteration order if the call graph has no - /// topological ordering (e.g., due to recursion). - fn method_order(ctx: &CompilerContext, dirty_only: Option<&DashSet>) -> Vec { - let topo = ctx.methods_reverse_topological(); - let order: Vec<_> = if topo.is_empty() { - ctx.all_methods().collect() - } else { - topo - }; - order - .into_iter() - .filter(|token| ctx.ssa_functions.contains_key(token)) - .filter(|token| dirty_only.is_none_or(|dirty| dirty.contains(token))) - .collect() - } - - /// Runs a single pass across all methods in parallel, tracking changes. - /// - /// Methods are processed in parallel using rayon. For each method: - /// 1. Checks [`SsaPass::should_run`] to skip inapplicable methods. - /// 2. Removes the SSA from the concurrent map (brief lock). - /// 3. Calls [`SsaPass::run_on_method`] with no locks held. - /// 4. If changes were made, repairs or rebuilds SSA based on the pass's - /// [`ModificationScope`]: - /// - [`UsesOnly`](ModificationScope::UsesOnly) / - /// [`InstructionsOnly`](ModificationScope::InstructionsOnly): lightweight - /// [`repair_ssa`](crate::analysis::SsaFunction::repair_ssa) - /// - [`CfgModifying`](ModificationScope::CfgModifying): full - /// [`rebuild_ssa`](crate::analysis::SsaFunction::rebuild_ssa) - /// 5. Reinserts the SSA and marks the method as processed. - /// - /// # Arguments - /// - /// * `pass` - The pass to execute (shared reference, must be `Send + Sync`). - /// * `ctx` - The compiler context containing SSA functions and events. - /// * `methods` - Method tokens to process, in the order from [`method_order`](Self::method_order). - /// * `assembly` - Shared reference to the assembly for pass lookups. - /// * `any_changed` - Atomic flag set to `true` if any method was modified. - fn run_single_pass( - pass: &dyn SsaPass, - ctx: &CompilerContext, - methods: &[Token], - assembly: &Arc, - any_changed: &AtomicBool, - iteration_modified: Option<&DashSet>, - ) { - let event_snapshot = ctx.events.len(); - let pass_change_count = AtomicUsize::new(0); - - // Passes that read other methods' SSA (e.g., inlining, proxy devirt) - // need peer SSAs to remain visible in the DashMap during parallel - // execution. For these passes, we clone the SSA before processing so - // the original stays readable by other threads. Passes that only - // modify their own method use the faster remove/insert path. - let clone_for_visibility = pass.reads_peer_ssa(); - - methods.par_iter().for_each(|&method_token| { - if !pass.should_run(method_token, ctx) { - return; - } - - let mut ssa = if clone_for_visibility { - let Some(ssa_ref) = ctx.ssa_functions.get(&method_token) else { - return; - }; - ssa_ref.clone() - } else { - let Some((_, ssa)) = ctx.ssa_functions.remove(&method_token) else { - return; - }; - ssa - }; - - let result = pass.run_on_method(&mut ssa, method_token, ctx, assembly); - - if let Ok(true) = result { - match pass.modification_scope() { - ModificationScope::UsesOnly | ModificationScope::InstructionsOnly => { - ssa.repair_ssa(); - } - ModificationScope::CfgModifying => { - if let Err(e) = ssa.rebuild_ssa() { - log::warn!("SSA rebuild failed for {}: {}", method_token, e); - } - } - } - } - - ctx.ssa_functions.insert(method_token, ssa); - - if let Ok(true) = result { - any_changed.store(true, Ordering::Relaxed); - pass_change_count.fetch_add(1, Ordering::Relaxed); - ctx.processed_methods.insert(method_token); - if let Some(modified) = iteration_modified { - modified.insert(method_token); - } - } - }); - - let count = pass_change_count.load(Ordering::Relaxed); - if count > 0 { - let event_delta = ctx.events.count_by_kind_since(event_snapshot); - if event_delta.is_empty() { - debug!(" pass '{}' changed {} methods", pass.name(), count); - } else { - let summary = format_event_delta(&event_delta); - if summary.is_empty() { - debug!(" pass '{}' changed {} methods", pass.name(), count); - } else { - debug!( - " pass '{}' changed {} methods ({})", - pass.name(), - count, - summary - ); - } - } + PassPhase::Normalize => self.inner.add_normalize(pass), + other => self.inner.add_at_layer(pass, other.as_layer()), } } /// Runs the complete deobfuscation pipeline. /// - /// Execution proceeds as follows: - /// - /// 1. **Layer computation**: Calls [`compute_layer_assignment`](Self::compute_layer_assignment) - /// to build the capability DAG and assign each pass to an execution layer. - /// - /// 2. **Outer loop** (up to `max_iterations`): For each iteration: - /// a. Run each layer to fixpoint via [`layer_to_fixpoint`](Self::layer_to_fixpoint). - /// b. On the first iteration only, if no layer made changes, run normalization - /// to ensure cleanup passes execute at least once. - /// c. Track stability: stop early if no changes for `stable_iterations` - /// consecutive iterations. - /// - /// Layer assignments are recomputed at the start of each call to `run_pipeline`, - /// so passes added between calls (e.g., by the detection re-scan loop) are - /// incorporated automatically. - /// - /// # Arguments - /// - /// * `ctx` - The compiler context (thread-safe, shared across all passes). - /// * `assembly` - Shared reference to the assembly being processed. - /// - /// # Returns - /// - /// The number of outer iterations completed. Pass-level events are - /// accumulated in `ctx.events`. + /// `assembly` is stored on `ctx` via [`CompilerContext::set_assembly`] + /// so passes that need it (inlining, proxy devirt, constant folding) + /// can reach it through the host. `state` is currently a no-op + /// parameter retained for source-compatibility; dirty tracking + /// flows through `ctx.processing_state`. /// /// # Errors /// - /// Returns an error if: - /// - A cycle is detected in the capability dependency graph. - /// - Any pass fails during execution. + /// Returns an error if a cycle is detected in the capability + /// dependency graph or any pass fails. pub fn run_pipeline( &mut self, ctx: &CompilerContext, assembly: &Arc, - state: Option<&ProcessingState>, + _state: Option<&ProcessingState>, ) -> Result { - let layer_assignment = self.compute_layer_assignment()?; - - // Group pass indices by layer, then discard empty layers - let num_layers = layer_assignment - .iter() - .copied() - .max() - .map_or(0, |m| m.saturating_add(1)); - let mut layer_indices: Vec> = vec![vec![]; num_layers]; - for (i, &layer) in layer_assignment.iter().enumerate() { - if let Some(slot) = layer_indices.get_mut(layer) { - slot.push(i); - } - } - layer_indices.retain(|layer| !layer.is_empty()); - - let mut stable_count: usize = 0; - let mut iterations: usize = 0; - let max_phase = self.max_phase_iterations; - let max_iterations = self.max_iterations; - let stable_iterations = self.stable_iterations; - - for iteration in 0..max_iterations { - iterations = iteration.saturating_add(1); - debug!("Pipeline iteration {}/{}", iterations, max_iterations); - - // Track which methods are modified in this iteration so we can - // transition unmodified methods from dirty → stable at the end. - let iteration_modified = DashSet::new(); - let modified_ref = state.map(|_| &iteration_modified); - let mut iteration_changed = false; - - for layer in &layer_indices { - if Self::layer_to_fixpoint( - ctx, - &mut self.passes, - layer, - &mut self.normalize, - max_phase, - assembly, - state, - modified_ref, - )? { - iteration_changed = true; - } - } - - // Ensure normalize runs at least once even if no layer pass makes changes - if iteration == 0 && !iteration_changed && !self.normalize.is_empty() { - iteration_changed = Self::normalize_to_fixpoint( - ctx, - &mut self.normalize, - max_phase, - assembly, - state, - modified_ref, - )?; - } - - // Update dirty/stable tracking at iteration boundary - if let Some(state) = state { - if iteration_changed { - // Move unmodified dirty methods to stable - let dirty: Vec = state.method_dirty.iter().map(|t| *t).collect(); - for token in dirty { - if !iteration_modified.contains(&token) { - state.mark_method_stable(token); - } - } - // Methods modified during this iteration stay dirty for - // subsequent passes to see them (already in method_dirty - // or re-marked dirty by mark_method_dirty in the pass). - for token in iteration_modified.iter() { - state.mark_method_dirty(*token); - } - } else { - // No changes at all — all dirty methods are now stable - let dirty: Vec = state.method_dirty.iter().map(|t| *t).collect(); - for token in dirty { - state.mark_method_stable(token); - } - } - } - - if iteration_changed { - stable_count = 0; - } else { - stable_count = stable_count.saturating_add(1); - if stable_count >= stable_iterations { - debug!("Pipeline stable after {} iterations", iterations); - break; - } - } - } - - Ok(iterations) - } -} - -/// Formats an event-kind delta map into a compact summary string. -/// -/// Example: "93 strings decrypted, 115 constants folded" -fn format_event_delta(delta: &HashMap) -> String { - let mut parts: Vec = delta - .iter() - .filter(|(kind, _)| kind.is_transformation()) - .map(|(kind, count)| format!("{} {}", count, kind.description())) - .collect(); - parts.sort(); - parts.join(", ") -} - -#[cfg(test)] -mod tests { - use crate::{ - analysis::SsaFunction, - compiler::{ - context::CompilerContext, - pass::{PassCapability, PassPhase, SsaPass}, - EventKind, PassScheduler, - }, - metadata::token::Token, - CilObject, Result, - }; - - /// A minimal [`SsaPass`] implementation for testing. - /// - /// Reports changes for `changes_to_make` iterations, then stops. - struct TestPass { - name: &'static str, - changes_to_make: usize, - } - - impl TestPass { - fn new(name: &'static str, changes: usize) -> Self { - Self { - name, - changes_to_make: changes, - } - } - } - - impl SsaPass for TestPass { - fn name(&self) -> &'static str { - self.name - } - - fn run_on_method( - &self, - _ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - _assembly: &CilObject, - ) -> Result { - for i in 0..self.changes_to_make { - ctx.events - .record(EventKind::ConstantFolded) - .at(method_token, i) - .message("test"); - } - Ok(self.changes_to_make > 0) - } - } - - /// A test pass that declares [`PassCapability`] provides/requires. - struct CapabilityPass { - name: &'static str, - provides: Vec, - requires: Vec, - } - - impl SsaPass for CapabilityPass { - fn name(&self) -> &'static str { - self.name - } - - fn run_on_method( - &self, - _ssa: &mut SsaFunction, - _method_token: Token, - _ctx: &CompilerContext, - _assembly: &CilObject, - ) -> Result { - Ok(false) - } - - fn provides(&self) -> &[PassCapability] { - &self.provides - } - - fn requires(&self) -> &[PassCapability] { - &self.requires - } - } - - #[test] - fn test_scheduler_iteration_limits() { - let scheduler = PassScheduler::new(10, 3, 5); - assert_eq!(scheduler.max_iterations, 10); - assert_eq!(scheduler.stable_iterations, 3); - assert_eq!(scheduler.max_phase_iterations, 5); - } - - #[test] - fn test_default_scheduler() { - let scheduler = PassScheduler::default(); - assert_eq!(scheduler.max_iterations, 5); - assert_eq!(scheduler.stable_iterations, 2); - assert_eq!(scheduler.max_phase_iterations, 15); - } - - #[test] - fn test_pass_names() { - let passes: Vec> = vec![ - Box::new(TestPass::new("pass1", 0)), - Box::new(TestPass::new("pass2", 0)), - ]; - - assert_eq!(passes.len(), 2); - assert_eq!(passes[0].name(), "pass1"); - assert_eq!(passes[1].name(), "pass2"); - } - - #[test] - fn test_add_pass() { - let mut scheduler = PassScheduler::new(5, 2, 15); - scheduler.add( - Box::new(TestPass::new("structure_pass", 0)), - PassPhase::Structure, - ); - scheduler.add(Box::new(TestPass::new("value_pass", 0)), PassPhase::Value); - scheduler.add( - Box::new(TestPass::new("simplify_pass", 0)), - PassPhase::Simplify, - ); - assert_eq!(scheduler.pass_count(), 3); - } - - /// Verifies that capability dependencies push passes to later layers. - /// - /// Setup: - /// - value-resolver (Value=1) provides `ResolvedStaticFields` - /// - cff-reconstruction (Structure=0) requires `ResolvedStaticFields` → pushed to layer 2 - /// - opaque-predicates (Simplify=2) requires `RestoredControlFlow` → pushed to layer 3 - #[test] - fn test_capability_layer_computation() { - let mut scheduler = PassScheduler::new(5, 2, 15); - - scheduler.add( - Box::new(CapabilityPass { - name: "value-resolver", - provides: vec![PassCapability::ResolvedStaticFields], - requires: vec![], - }), - PassPhase::Value, - ); - - scheduler.add( - Box::new(CapabilityPass { - name: "cff-reconstruction", - provides: vec![PassCapability::RestoredControlFlow], - requires: vec![PassCapability::ResolvedStaticFields], - }), - PassPhase::Structure, - ); - - scheduler.add( - Box::new(CapabilityPass { - name: "opaque-predicates", - provides: vec![PassCapability::SimplifiedPredicates], - requires: vec![PassCapability::RestoredControlFlow], - }), - PassPhase::Simplify, - ); - - let layers = scheduler.compute_layer_assignment().unwrap(); - assert_eq!(layers[0], 1); // value-resolver stays at 1 - assert_eq!(layers[1], 2); // cff-reconstruction pushed from 0 to 2 - assert_eq!(layers[2], 3); // opaque-predicates pushed from 2 to 3 - } - - /// Verifies that passes without capabilities stay at their fallback layers. - #[test] - fn test_no_capabilities_uses_fallback() { - let mut scheduler = PassScheduler::new(5, 2, 15); - - scheduler.add( - Box::new(TestPass::new("structure", 0)), - PassPhase::Structure, - ); - scheduler.add(Box::new(TestPass::new("value", 0)), PassPhase::Value); - scheduler.add(Box::new(TestPass::new("simplify", 0)), PassPhase::Simplify); - - let layers = scheduler.compute_layer_assignment().unwrap(); - assert_eq!(layers[0], 0); - assert_eq!(layers[1], 1); - assert_eq!(layers[2], 2); - } - - /// Verifies that a pass requiring a capability with no provider stays at fallback. - /// - /// This is the ConfuserEx scenario: CFF requires `ResolvedStaticFields` but - /// no `StaticFieldResolutionPass` is registered (no JIEJIE.NET detected). - #[test] - fn test_missing_provider_uses_fallback() { - let mut scheduler = PassScheduler::new(5, 2, 15); - - scheduler.add( - Box::new(CapabilityPass { - name: "cff", - provides: vec![PassCapability::RestoredControlFlow], - requires: vec![PassCapability::ResolvedStaticFields], - }), - PassPhase::Structure, - ); - - let layers = scheduler.compute_layer_assignment().unwrap(); - assert_eq!(layers[0], 0); - } - - /// Verifies that mutual capability dependencies are detected as a cycle. - #[test] - fn test_cycle_detection() { - let mut scheduler = PassScheduler::new(5, 2, 15); - - scheduler.add( - Box::new(CapabilityPass { - name: "pass-a", - provides: vec![PassCapability::ResolvedStaticFields], - requires: vec![PassCapability::RestoredControlFlow], - }), - PassPhase::Structure, - ); - scheduler.add( - Box::new(CapabilityPass { - name: "pass-b", - provides: vec![PassCapability::RestoredControlFlow], - requires: vec![PassCapability::ResolvedStaticFields], - }), - PassPhase::Structure, - ); - - let result = scheduler.compute_layer_assignment(); - assert!(result.is_err()); + ctx.set_assembly(assembly.clone()); + let result = self + .inner + .run_pipeline(ctx) + .map_err(|e| Error::SsaError(e.0)); + // Release the in-context assembly handle so callers can unwrap + // the `Arc` for code generation. Without this, the + // strong-count never drops to one and `Arc::try_unwrap` fails. + ctx.clear_assembly(); + result } } diff --git a/dotscope/src/compiler/summary.rs b/dotscope/src/compiler/summary.rs index 14e960d3..35bb86ef 100644 --- a/dotscope/src/compiler/summary.rs +++ b/dotscope/src/compiler/summary.rs @@ -420,6 +420,7 @@ impl CallSiteInfo { #[cfg(test)] mod tests { use super::*; + use crate::analysis::CilTarget; #[test] fn test_method_summary_default() { @@ -433,13 +434,13 @@ mod tests { #[test] fn test_return_info() { - assert!(ReturnInfo::Void.is_known()); - assert!(ReturnInfo::Constant(ConstValue::I32(42)).is_known()); - assert!(!ReturnInfo::Dynamic.is_known()); + assert!(ReturnInfo::::Void.is_known()); + assert!(ReturnInfo::::Constant(ConstValue::I32(42)).is_known()); + assert!(!ReturnInfo::::Dynamic.is_known()); - assert!(ReturnInfo::PassThrough(0).is_potentially_foldable()); - assert!(ReturnInfo::PureComputation.is_potentially_foldable()); - assert!(!ReturnInfo::Dynamic.is_potentially_foldable()); + assert!(ReturnInfo::::PassThrough(0).is_potentially_foldable()); + assert!(ReturnInfo::::PureComputation.is_potentially_foldable()); + assert!(!ReturnInfo::::Dynamic.is_potentially_foldable()); } #[test] diff --git a/dotscope/src/deobfuscation/context.rs b/dotscope/src/deobfuscation/context.rs index 4c11a3f1..de9b9e7d 100644 --- a/dotscope/src/deobfuscation/context.rs +++ b/dotscope/src/deobfuscation/context.rs @@ -13,6 +13,7 @@ use std::{ use dashmap::DashSet; use crate::{ + analysis::CallGraph, compiler::{CompilerContext, ProcessingState}, deobfuscation::{ config::EngineConfig, @@ -138,12 +139,12 @@ impl Deref for AnalysisContext { impl AnalysisContext { /// Creates a new analysis context with default configuration. - pub fn new(call_graph: Arc) -> Self { + pub fn new(call_graph: Arc) -> Self { Self::with_config(call_graph, EngineConfig::default()) } /// Creates a new analysis context with custom configuration. - pub fn with_config(call_graph: Arc, config: EngineConfig) -> Self { + pub fn with_config(call_graph: Arc, config: EngineConfig) -> Self { Self { compiler: CompilerContext::new(call_graph), decryptors: Arc::new(DecryptorContext::new()), @@ -305,15 +306,16 @@ impl AnalysisContext { #[cfg(test)] mod tests { + use std::{sync::Arc, thread}; + + use super::*; + use crate::{ analysis::{CallGraph, ConstValue, SsaVarId}, compiler::CallSiteInfo, - deobfuscation::context::AnalysisContext, metadata::token::Token, }; - use std::sync::Arc; - #[test] fn test_call_site_info() { let info = CallSiteInfo { @@ -392,8 +394,6 @@ mod tests { #[test] fn test_thread_safe_access() { - use std::thread; - let call_graph = Arc::new(CallGraph::new()); let ctx = Arc::new(AnalysisContext::new(call_graph)); diff --git a/dotscope/src/deobfuscation/engine/api.rs b/dotscope/src/deobfuscation/engine/api.rs index b2caaba6..4fffdbfe 100644 --- a/dotscope/src/deobfuscation/engine/api.rs +++ b/dotscope/src/deobfuscation/engine/api.rs @@ -54,7 +54,8 @@ impl DeobfuscationEngine { /// Runs technique detection on an assembly (IL-level and SSA-level). /// /// This is a detection-only API — no transforms are applied. It runs both - /// IL-level [`Technique::detect`] and SSA-level [`Technique::detect_ssa`] + /// IL-level [`Technique::detect`](crate::deobfuscation::techniques::Technique::detect) + /// and SSA-level [`Technique::detect_ssa`](crate::deobfuscation::techniques::Technique::detect_ssa) /// on each technique, catching patterns that require cross-block def-use /// chain analysis (e.g., BitMono string encryption, opaque field predicates, /// delegate proxies). diff --git a/dotscope/src/deobfuscation/engine/pipeline.rs b/dotscope/src/deobfuscation/engine/pipeline.rs index 1d554736..583c6559 100644 --- a/dotscope/src/deobfuscation/engine/pipeline.rs +++ b/dotscope/src/deobfuscation/engine/pipeline.rs @@ -11,9 +11,12 @@ use std::{ use log::{debug, info, warn}; +use analyssa::scheduling::SsaPass as AnalyssaSsaPass; + use crate::{ + analysis::{CilTarget, MethodRef}, cilassembly::{expand_type_tokens, CleanupRequest}, - compiler::{DeadMethodEliminationPass, EventLog, PassScheduler, SsaPass}, + compiler::{CompilerContext, DeadMethodEliminationPass, EventLog, PassScheduler}, deobfuscation::{ cleanup::{build_cleanup_request, execute_cleanup}, context::AnalysisContext, @@ -559,7 +562,14 @@ impl<'a> PipelineRun<'a> { ) -> Result<(CilObject, DeobfuscationResult)> { if ctx.config.cleanup.remove_unused_methods { let dead_method_pass = DeadMethodEliminationPass::new(); - let _ = dead_method_pass.run_global(&ctx, &assembly_arc)?; + // Ensure the assembly handle is set on the context so the + // global pass can reach it via the analyssa host trait. + ctx.compiler.set_assembly(assembly_arc.clone()); + let _ = AnalyssaSsaPass::::run_global( + &dead_method_pass, + &ctx.compiler, + ) + .map_err(|e| Error::SsaError(e.0))?; } let ssa_call_graph = ctx.build_ssa_call_graph(); @@ -638,9 +648,20 @@ impl<'a> PipelineRun<'a> { let pass = NeutralizationPass::new(&all_tokens); let mut neutralized = false; let method_tokens: Vec = ctx.ssa_functions.iter().map(|e| *e.key()).collect(); + // Ensure the assembly handle is set on the context so passes can + // reach it through the analyssa host trait. + ctx.set_assembly(assembly_arc.clone()); for method_token in &method_tokens { if let Some(mut ssa) = ctx.ssa_functions.get_mut(method_token) { - if pass.run_on_method(&mut ssa, *method_token, ctx, assembly_arc)? { + let method_ref = MethodRef::from(*method_token); + if AnalyssaSsaPass::::run_on_method( + &pass, + &mut ssa, + &method_ref, + ctx, + ) + .map_err(|e| Error::SsaError(e.0))? + { neutralized = true; // Ensure neutralized methods get code-generated. Without this, // methods modified only by neutralization keep their original IL @@ -670,6 +691,13 @@ impl<'a> PipelineRun<'a> { if let Some(pool) = ctx.template_pool.get() { pool.release(); } + // Drop the assembly reference held by `CompilerContext` so the + // strong-count can drop to one. The analyssa scheduler's + // `run_pipeline` already clears it on its own exit, but nested + // global-pass invocations (e.g. dead-method elimination outside + // the scheduler) and re-entrant pipeline iterations can leave a + // stray reference here. + ctx.compiler.clear_assembly(); Arc::try_unwrap(assembly_arc).map_err(|_| { Error::Deobfuscation("Cannot unwrap assembly - still has other references".into()) }) diff --git a/dotscope/src/deobfuscation/engine/tests.rs b/dotscope/src/deobfuscation/engine/tests.rs index b895c22a..d2a6aea0 100644 --- a/dotscope/src/deobfuscation/engine/tests.rs +++ b/dotscope/src/deobfuscation/engine/tests.rs @@ -79,7 +79,7 @@ fn test_pipeline_passes_selective() { #[test] fn test_analyze_return_void() { // Create SSA with void return - let mut ssa = SsaFunction::new(0, 0); + let mut ssa: SsaFunction = SsaFunction::new(0, 0); let mut block = SsaBlock::new(0); block.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); ssa.add_block(block); @@ -91,7 +91,7 @@ fn test_analyze_return_void() { #[test] fn test_analyze_return_constant() { // Create SSA that returns a constant - let mut ssa = SsaFunction::new(0, 0); + let mut ssa: SsaFunction = SsaFunction::new(0, 0); let mut block = SsaBlock::new(0); // Define a constant @@ -114,7 +114,7 @@ fn test_analyze_return_constant() { #[test] fn test_analyze_return_no_returns_is_void() { // Create SSA with no return statements (unusual but possible) - let mut ssa = SsaFunction::new(0, 0); + let mut ssa: SsaFunction = SsaFunction::new(0, 0); let block = SsaBlock::new(0); ssa.add_block(block); @@ -125,7 +125,7 @@ fn test_analyze_return_no_returns_is_void() { #[test] fn test_analyze_purity_pure() { // Create SSA with only pure operations - let mut ssa = SsaFunction::new(0, 0); + let mut ssa: SsaFunction = SsaFunction::new(0, 0); let mut block = SsaBlock::new(0); // Pure arithmetic operation @@ -136,6 +136,7 @@ fn test_analyze_purity_pure() { dest, left: src1, right: src2, + flags: None, })); block.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: Some(dest), @@ -149,7 +150,7 @@ fn test_analyze_purity_pure() { #[test] fn test_analyze_purity_impure_store_field() { // Create SSA with a field store - let mut ssa = SsaFunction::new(0, 0); + let mut ssa: SsaFunction = SsaFunction::new(0, 0); let mut block = SsaBlock::new(0); let obj = SsaVarId::from_index(0); @@ -169,7 +170,7 @@ fn test_analyze_purity_impure_store_field() { #[test] fn test_analyze_purity_impure_throw() { // Create SSA with a throw - let mut ssa = SsaFunction::new(0, 0); + let mut ssa: SsaFunction = SsaFunction::new(0, 0); let mut block = SsaBlock::new(0); let exc = SsaVarId::from_index(0); @@ -183,7 +184,7 @@ fn test_analyze_purity_impure_throw() { #[test] fn test_analyze_purity_readonly() { // Create SSA with only field reads - let mut ssa = SsaFunction::new(0, 0); + let mut ssa: SsaFunction = SsaFunction::new(0, 0); let mut block = SsaBlock::new(0); let dest = SsaVarId::from_index(0); @@ -208,7 +209,7 @@ fn test_analyze_purity_readonly() { #[test] fn test_analyze_purity_unknown_calls() { // Create SSA with a call - let mut ssa = SsaFunction::new(0, 0); + let mut ssa: SsaFunction = SsaFunction::new(0, 0); let mut block = SsaBlock::new(0); let dest = SsaVarId::from_index(0); @@ -232,13 +233,18 @@ fn test_analyze_purity_unknown_calls() { #[test] fn test_detect_string_decryptor_xor() { // Create small SSA with XOR operations (typical of string decryption) - let mut ssa = SsaFunction::new(0, 0); + let mut ssa: SsaFunction = SsaFunction::new(0, 0); let mut block = SsaBlock::new(0); let dest = SsaVarId::from_index(0); let left = SsaVarId::from_index(1); let right = SsaVarId::from_index(2); - block.add_instruction(SsaInstruction::synthetic(SsaOp::Xor { dest, left, right })); + block.add_instruction(SsaInstruction::synthetic(SsaOp::Xor { + dest, + left, + right, + flags: None, + })); block.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: Some(dest), })); @@ -251,7 +257,7 @@ fn test_detect_string_decryptor_xor() { #[test] fn test_detect_string_decryptor_large_method() { // Create large SSA (over 200 instructions) - let mut ssa = SsaFunction::new(0, 0); + let mut ssa: SsaFunction = SsaFunction::new(0, 0); let mut block = SsaBlock::new(0); // Add 250 instructions @@ -273,7 +279,7 @@ fn test_detect_string_decryptor_large_method() { #[test] fn test_detect_dispatcher_with_switch() { // Create SSA with a switch having 5+ targets - let mut ssa = SsaFunction::new(0, 0); + let mut ssa: SsaFunction = SsaFunction::new(0, 0); let mut block = SsaBlock::new(0); let value = SsaVarId::from_index(0); @@ -291,7 +297,7 @@ fn test_detect_dispatcher_with_switch() { #[test] fn test_detect_dispatcher_small_switch() { // Create SSA with a small switch (< 5 targets) - let mut ssa = SsaFunction::new(0, 0); + let mut ssa: SsaFunction = SsaFunction::new(0, 0); let mut block = SsaBlock::new(0); let value = SsaVarId::from_index(0); @@ -309,7 +315,7 @@ fn test_detect_dispatcher_small_switch() { #[test] fn test_detect_dispatcher_no_switch() { // Create SSA without switch - let mut ssa = SsaFunction::new(0, 0); + let mut ssa: SsaFunction = SsaFunction::new(0, 0); let mut block = SsaBlock::new(0); block.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: None })); ssa.add_block(block); @@ -323,7 +329,7 @@ fn test_compute_method_summary() { let engine = DeobfuscationEngine::default(); // Create a simple pure method with constant return - let mut ssa = SsaFunction::new(0, 0); + let mut ssa: SsaFunction = SsaFunction::new(0, 0); let mut block = SsaBlock::new(0); let var = SsaVarId::from_index(0); diff --git a/dotscope/src/deobfuscation/passes/antidebug.rs b/dotscope/src/deobfuscation/passes/antidebug.rs index 857dbc55..a6c88db9 100644 --- a/dotscope/src/deobfuscation/passes/antidebug.rs +++ b/dotscope/src/deobfuscation/passes/antidebug.rs @@ -53,11 +53,13 @@ use std::collections::HashSet; use crate::{ - analysis::{PhiTaintMode, SsaFunction, SsaOp, TaintConfig, TokenTaintBuilder}, + analysis::{ + CilTarget, MethodRef, PhiTaintMode, SsaFunction, SsaOp, TaintConfig, TokenTaintBuilder, + }, compiler::{CompilerContext, EventKind, ModificationScope, SsaPass}, deobfuscation::utils::resolve_qualified_method_name, metadata::token::Token, - CilObject, Result, + CilObject, }; /// Controls how many sentinel patterns must co-occur in a method for the pass @@ -153,7 +155,7 @@ impl SentinelTaintRemovalPass { } } -impl SsaPass for SentinelTaintRemovalPass { +impl SsaPass for SentinelTaintRemovalPass { fn name(&self) -> &'static str { self.pass_name } @@ -166,20 +168,24 @@ impl SsaPass for SentinelTaintRemovalPass { ModificationScope::CfgModifying } - fn should_run(&self, method_token: Token, _ctx: &CompilerContext) -> bool { - self.target_methods.is_empty() || self.target_methods.contains(&method_token) + fn should_run(&self, method: &MethodRef, _host: &CompilerContext) -> bool { + self.target_methods.is_empty() || self.target_methods.contains(&method.0) } fn run_on_method( &self, ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - assembly: &CilObject, - ) -> Result { + method: &MethodRef, + host: &CompilerContext, + ) -> analyssa::Result { + let assembly = host + .assembly() + .ok_or_else(|| analyssa::Error::new("SentinelTaintRemovalPass requires an assembly"))?; + let method_token = method.0; + let ctx = host; // Step 1: Find sentinel tokens and check co-occurrence condition. let (sentinel_tokens, distinct_count) = - find_sentinel_tokens(ssa, assembly, &self.sentinel_patterns); + find_sentinel_tokens(ssa, &assembly, &self.sentinel_patterns); if !self .condition .is_satisfied(distinct_count, self.sentinel_patterns.len()) diff --git a/dotscope/src/deobfuscation/passes/bitmono/strings.rs b/dotscope/src/deobfuscation/passes/bitmono/strings.rs index 65850736..208e3294 100644 --- a/dotscope/src/deobfuscation/passes/bitmono/strings.rs +++ b/dotscope/src/deobfuscation/passes/bitmono/strings.rs @@ -47,7 +47,7 @@ use std::{ }; use crate::{ - analysis::{ConstValue, SsaFunction, SsaOp, SsaVarId}, + analysis::{CilTarget, ConstValue, MethodRef, SsaFunction, SsaOp, SsaVarId}, compiler::{CompilerContext, EventKind, ModificationScope, SsaPass}, deobfuscation::{ techniques::BitMonoStringFindings, @@ -185,7 +185,7 @@ impl StringDecryptionPass { } } -impl SsaPass for StringDecryptionPass { +impl SsaPass for StringDecryptionPass { fn name(&self) -> &'static str { "BitMonoStringDecryption" } @@ -201,10 +201,15 @@ impl SsaPass for StringDecryptionPass { fn run_on_method( &self, ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - assembly: &CilObject, - ) -> Result { + method: &MethodRef, + host: &CompilerContext, + ) -> analyssa::Result { + let assembly_arc = host + .assembly() + .ok_or_else(|| analyssa::Error::new("StringDecryptionPass requires an assembly"))?; + let assembly: &CilObject = &assembly_arc; + let ctx = host; + let method_token = method.0; let mut changed = false; // Build LoadStaticField index once for the entire method diff --git a/dotscope/src/deobfuscation/passes/bitmono/unmanaged.rs b/dotscope/src/deobfuscation/passes/bitmono/unmanaged.rs index 1903ccd0..ab2c3f97 100644 --- a/dotscope/src/deobfuscation/passes/bitmono/unmanaged.rs +++ b/dotscope/src/deobfuscation/passes/bitmono/unmanaged.rs @@ -31,10 +31,9 @@ use std::collections::HashMap; use crate::{ - analysis::{ConstValue, SsaFunction, SsaOp, SsaVarId}, + analysis::{CilTarget, ConstValue, MethodRef, SsaFunction, SsaOp, SsaVarId}, compiler::{CompilerContext, EventKind, ModificationScope, SsaPass}, metadata::token::Token, - CilObject, Result, }; /// SSA pass that replaces UnmanagedString call+newobj patterns with string constants. @@ -57,7 +56,7 @@ pub struct UnmanagedStringReversalPass { pub(crate) native_string_map: HashMap, } -impl SsaPass for UnmanagedStringReversalPass { +impl SsaPass for UnmanagedStringReversalPass { fn name(&self) -> &'static str { "BitMonoUnmanagedString" } @@ -73,10 +72,9 @@ impl SsaPass for UnmanagedStringReversalPass { fn run_on_method( &self, ssa: &mut SsaFunction, - _method_token: Token, + _method: &MethodRef, ctx: &CompilerContext, - _assembly: &CilObject, - ) -> Result { + ) -> analyssa::Result { let mut changed = false; for block_idx in 0..ssa.blocks().len() { diff --git a/dotscope/src/deobfuscation/passes/decryption.rs b/dotscope/src/deobfuscation/passes/decryption.rs index ae41f6bb..9812de0f 100644 --- a/dotscope/src/deobfuscation/passes/decryption.rs +++ b/dotscope/src/deobfuscation/passes/decryption.rs @@ -67,7 +67,10 @@ use std::{ use rayon::prelude::*; use crate::{ - analysis::{ConstValue, SsaCfg, SsaFunction, SsaOp, SsaVarId, ValueResolver}, + analysis::{ + CilTarget, ConstValue, MethodRef, SsaCfg, SsaFunction, SsaOp, SsaVarId, TypeRef, + ValueResolver, + }, compiler::{CompilerContext, EventKind, EventLog, ModificationScope, PassCapability, SsaPass}, deobfuscation::{ context::AnalysisContext, @@ -82,12 +85,12 @@ use crate::{ token::Token, typesystem::{CilFlavor, CilPrimitive, CilPrimitiveKind, PointerSize}, }, - utils::graph::{ - algorithms::{compute_dominators, DominatorTree}, - GraphBase, NodeId, RootedGraph, - }, CilObject, Error, Result, }; +use analyssa::graph::{ + algorithms::{compute_dominators, DominatorTree}, + GraphBase, NodeId, RootedGraph, +}; /// Decryption pass for obfuscated constants and strings. /// @@ -101,9 +104,10 @@ use crate::{ /// /// # Emulation Template /// -/// The pass uses the shared [`EmulationTemplatePool`] for O(1) CoW forks. -/// The pool is warmed up once during engine initialization; this pass simply -/// calls [`EmulationTemplatePool::fork()`] for each decryption attempt. +/// The pass uses the shared emulation-template pool (an internal CoW snapshot +/// store, see `crate::deobfuscation::context`) for O(1) forks. The pool is +/// warmed up once during engine initialization; this pass forks a fresh +/// emulator state from it for each decryption attempt. /// /// If the pool is unavailable (no techniques registered emulation hooks), /// decryption is silently skipped. @@ -149,7 +153,7 @@ impl DecryptionPass { /// Creates a new constant decryption pass from an analysis context. /// /// Captures shared references to the context's decryptors and state machine - /// providers. Emulation is backed by the shared [`EmulationTemplatePool`] + /// providers. Emulation is backed by the shared emulation-template pool /// stored in the context (if available). /// /// # Arguments @@ -442,7 +446,7 @@ impl DecryptionPass { if let Some(token) = Self::resolve_flavor_to_typeref(&elem_flavor, thread) { return Some(ConstValue::DecryptedArray { data: bytes, - element_type_token: token, + element_type_ref: TypeRef::new(token), element_size: elem_size, }); } @@ -911,7 +915,7 @@ impl DecryptionPass { } } -impl SsaPass for DecryptionPass { +impl SsaPass for DecryptionPass { fn name(&self) -> &'static str { "decryption" } @@ -928,13 +932,13 @@ impl SsaPass for DecryptionPass { &[PassCapability::DecryptedStrings] } - fn initialize(&mut self, _ctx: &CompilerContext) -> Result<()> { + fn initialize(&mut self, _host: &CompilerContext) -> analyssa::Result<()> { // This pass relies on DecryptorContext being populated by obfuscator // detection before it runs. No additional setup needed. Ok(()) } - fn finalize(&mut self, _ctx: &CompilerContext) -> Result<()> { + fn finalize(&mut self, _host: &CompilerContext) -> analyssa::Result<()> { // Pool lifecycle is managed by the engine — nothing to release here. Ok(()) } @@ -942,16 +946,21 @@ impl SsaPass for DecryptionPass { fn run_on_method( &self, ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - assembly: &CilObject, - ) -> Result { + method: &MethodRef, + host: &CompilerContext, + ) -> analyssa::Result { + let assembly_arc = host + .assembly() + .ok_or_else(|| analyssa::Error::new("DecryptionPass requires an assembly"))?; + let assembly: &CilObject = &assembly_arc; + let ctx = host; + let method_token = method.0; // Early exit if no decryptors are registered if !self.decryptors.has_decryptors() { return Ok(false); } - let ptr_size = PointerSize::from_pe(assembly.file().pe().is_64bit); + let ptr_size = PointerSize::from_is_64bit(assembly.file().pe().is_64bit); // Track whether any mode made changes let mut any_changed = false; @@ -1134,7 +1143,9 @@ mod tests { /// Helper to create a minimal analysis context for testing. fn create_test_context() -> AnalysisContext { let call_graph = Arc::new(CallGraph::new()); - AnalysisContext::new(call_graph) + let ctx = AnalysisContext::new(call_graph); + ctx.compiler.set_assembly(test_assembly_arc()); + ctx } #[test] @@ -1162,7 +1173,7 @@ mod tests { // No decryptors registered, should return false (no changes) let changed = pass - .run_on_method(&mut ssa, method_token, &ctx, &test_assembly_arc()) + .run_on_method(&mut ssa, &MethodRef::from(method_token), &ctx.compiler) .unwrap(); assert!(!changed); } @@ -1189,7 +1200,7 @@ mod tests { // Has decryptors but no calls to them let changed = pass - .run_on_method(&mut ssa, method_token, &ctx, &test_assembly_arc()) + .run_on_method(&mut ssa, &MethodRef::from(method_token), &ctx.compiler) .unwrap(); assert!(!changed); } @@ -1217,7 +1228,7 @@ mod tests { // Call without destination, can't replace let changed = pass - .run_on_method(&mut ssa, method_token, &ctx, &test_assembly_arc()) + .run_on_method(&mut ssa, &MethodRef::from(method_token), &ctx.compiler) .unwrap(); assert!(!changed); } @@ -1246,7 +1257,7 @@ mod tests { // Argument is not in known_values, should fail let changed = pass - .run_on_method(&mut ssa, method_token, &ctx, &test_assembly_arc()) + .run_on_method(&mut ssa, &MethodRef::from(method_token), &ctx.compiler) .unwrap(); assert!(!changed); @@ -1277,7 +1288,7 @@ mod tests { .unwrap(); let changed = pass - .run_on_method(&mut ssa, method_token, &ctx, &test_assembly_arc()) + .run_on_method(&mut ssa, &MethodRef::from(method_token), &ctx.compiler) .unwrap(); assert!(!changed); } @@ -1313,7 +1324,7 @@ mod tests { // This will try to decrypt but fail (no assembly for emulation) // The point is that it recognized the MethodSpec as a decryptor call - let _ = pass.run_on_method(&mut ssa, method_token, &ctx, &test_assembly_arc()); + let _ = pass.run_on_method(&mut ssa, &MethodRef::from(method_token), &ctx.compiler); // Verify the MethodSpec was resolved to the decryptor and a failure was recorded // (failure because there's no assembly to emulate against) diff --git a/dotscope/src/deobfuscation/passes/delegates.rs b/dotscope/src/deobfuscation/passes/delegates.rs index bfaf36f5..9f6893bc 100644 --- a/dotscope/src/deobfuscation/passes/delegates.rs +++ b/dotscope/src/deobfuscation/passes/delegates.rs @@ -29,14 +29,14 @@ use std::{ collections::{HashMap, HashSet}, - sync::Arc, + sync::{Arc, RwLockReadGuard}, }; use dashmap::{DashMap, DashSet}; use log::debug; use crate::{ - analysis::{MethodRef as SsaMethodRef, SsaFunction, SsaOp, SsaVarId}, + analysis::{CilTarget, MethodRef, MethodRef as SsaMethodRef, SsaFunction, SsaOp, SsaVarId}, assembly::{FlowType, Instruction, Operand}, compiler::{CompilerContext, EventKind, ModificationScope, SsaPass}, deobfuscation::{utils::build_def_map, EmulationTemplatePool, ProcessCell}, @@ -173,9 +173,7 @@ impl DelegateProxyResolutionPass { /// Delegates to [`ProcessCell::ensure_initialized`] with a fork + /// targeted warmup closure and a post-init callback that extracts delegate /// targets from the emulated state. - fn ensure_initialized( - &self, - ) -> Result>> { + fn ensure_initialized(&self) -> Result>> { self.lazy_process.ensure_initialized( || self.create_process_from_pool(), |process| self.extract_targets(process), @@ -411,7 +409,7 @@ fn resolve_synthetic_target(instructions: &[Instruction]) -> Option<(Token, bool call_token.map(|token| (token, is_callvirt)) } -impl SsaPass for DelegateProxyResolutionPass { +impl SsaPass for DelegateProxyResolutionPass { fn name(&self) -> &'static str { "delegate-proxy-resolution" } @@ -424,12 +422,11 @@ impl SsaPass for DelegateProxyResolutionPass { ModificationScope::InstructionsOnly } - fn should_run(&self, method_token: Token, _ctx: &CompilerContext) -> bool { - self.affected_methods.contains(&method_token) - && !self.processed_methods.contains(&method_token) + fn should_run(&self, method: &MethodRef, _host: &CompilerContext) -> bool { + self.affected_methods.contains(&method.0) && !self.processed_methods.contains(&method.0) } - fn initialize(&mut self, _ctx: &CompilerContext) -> Result<()> { + fn initialize(&mut self, _host: &CompilerContext) -> analyssa::Result<()> { let remaining = self .affected_methods .len() @@ -448,10 +445,11 @@ impl SsaPass for DelegateProxyResolutionPass { fn run_on_method( &self, ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - _assembly: &CilObject, - ) -> Result { + method: &MethodRef, + host: &CompilerContext, + ) -> analyssa::Result { + let ctx = host; + let method_token = method.0; let guard = self.ensure_initialized()?; let Some(_process) = guard.as_ref() else { return Ok(false); @@ -605,8 +603,8 @@ impl SsaPass for DelegateProxyResolutionPass { Ok(changed) } - fn finalize(&mut self, _ctx: &CompilerContext) -> Result<()> { + fn finalize(&mut self, _host: &CompilerContext) -> analyssa::Result<()> { // Clear the emulation process to release its Arc reference. - self.lazy_process.clear() + self.lazy_process.clear().map_err(Into::into) } } diff --git a/dotscope/src/deobfuscation/passes/jiejienet/arrays.rs b/dotscope/src/deobfuscation/passes/jiejienet/arrays.rs index 16b69767..d0fb0b94 100644 --- a/dotscope/src/deobfuscation/passes/jiejienet/arrays.rs +++ b/dotscope/src/deobfuscation/passes/jiejienet/arrays.rs @@ -39,10 +39,9 @@ //! ``` use crate::{ - analysis::{ConstValue, FieldRef, MethodRef, SsaFunction, SsaOp}, + analysis::{CilTarget, ConstValue, FieldRef, MethodRef, SsaFunction, SsaOp}, compiler::{CompilerContext, EventKind, ModificationScope, SsaPass}, metadata::token::Token, - CilObject, Result, }; /// SSA pass that replaces JIEJIE.NET field handle container calls with direct field handle constants @@ -87,7 +86,7 @@ impl ArrayInitRestorationPass { } } -impl SsaPass for ArrayInitRestorationPass { +impl SsaPass for ArrayInitRestorationPass { fn name(&self) -> &'static str { "jiejie-array-init-restore" } @@ -103,10 +102,9 @@ impl SsaPass for ArrayInitRestorationPass { fn run_on_method( &self, ssa: &mut SsaFunction, - _method_token: Token, + _method: &MethodRef, ctx: &CompilerContext, - _assembly: &CilObject, - ) -> Result { + ) -> analyssa::Result { if self.field_tokens.is_empty() { return Ok(false); } diff --git a/dotscope/src/deobfuscation/passes/jiejienet/resources.rs b/dotscope/src/deobfuscation/passes/jiejienet/resources.rs index a5bd90fd..d03c3358 100644 --- a/dotscope/src/deobfuscation/passes/jiejienet/resources.rs +++ b/dotscope/src/deobfuscation/passes/jiejienet/resources.rs @@ -34,10 +34,9 @@ use std::collections::HashMap; use log::info; use crate::{ - analysis::{MethodRef, SsaFunction, SsaOp}, + analysis::{CilTarget, MethodRef, SsaFunction, SsaOp}, compiler::{CompilerContext, EventKind, ModificationScope, SsaPass}, metadata::token::Token, - CilObject, Result, }; /// Target information for a resource interception method. @@ -77,7 +76,7 @@ impl ResourceRestorationPass { } } -impl SsaPass for ResourceRestorationPass { +impl SsaPass for ResourceRestorationPass { fn name(&self) -> &'static str { "jiejie-resource-restoration" } @@ -93,10 +92,9 @@ impl SsaPass for ResourceRestorationPass { fn run_on_method( &self, ssa: &mut SsaFunction, - _method_token: Token, + _method: &MethodRef, ctx: &CompilerContext, - _assembly: &CilObject, - ) -> Result { + ) -> analyssa::Result { if self.redirects.is_empty() { return Ok(false); } diff --git a/dotscope/src/deobfuscation/passes/jiejienet/typeofs.rs b/dotscope/src/deobfuscation/passes/jiejienet/typeofs.rs index 593c7faa..d6c634cb 100644 --- a/dotscope/src/deobfuscation/passes/jiejienet/typeofs.rs +++ b/dotscope/src/deobfuscation/passes/jiejienet/typeofs.rs @@ -36,12 +36,12 @@ use log::debug; use crate::{ analysis::{ - resolve_corelib_valuetype, ConstValue, DefSite, MethodRef, SsaFunction, SsaInstruction, - SsaOp, SsaVarId, TypeRef, VariableOrigin, + resolve_corelib_valuetype, CilTarget, ConstValue, DefSite, MethodRef, SsaFunction, + SsaInstruction, SsaOp, SsaVarId, TypeRef, VariableOrigin, }, compiler::{CompilerContext, EventKind, ModificationScope, SsaPass}, metadata::token::Token, - CilObject, Error, Result, + CilObject, Error, }; /// SSA pass that replaces JIEJIE.NET typeof container calls with @@ -81,7 +81,7 @@ impl TypeOfRestorationPass { } } -impl SsaPass for TypeOfRestorationPass { +impl SsaPass for TypeOfRestorationPass { fn name(&self) -> &'static str { "jiejie-typeof-restore" } @@ -97,10 +97,14 @@ impl SsaPass for TypeOfRestorationPass { fn run_on_method( &self, ssa: &mut SsaFunction, - _method_token: Token, - ctx: &CompilerContext, - assembly: &CilObject, - ) -> Result { + _method: &MethodRef, + host: &CompilerContext, + ) -> analyssa::Result { + let assembly_arc = host + .assembly() + .ok_or_else(|| analyssa::Error::new("TypeOfRestorationPass requires an assembly"))?; + let assembly: &CilObject = &assembly_arc; + let ctx = host; if self.type_tokens.is_empty() { return Ok(false); } diff --git a/dotscope/src/deobfuscation/passes/netreactor/resolver.rs b/dotscope/src/deobfuscation/passes/netreactor/resolver.rs index 6218e608..bbf55605 100644 --- a/dotscope/src/deobfuscation/passes/netreactor/resolver.rs +++ b/dotscope/src/deobfuscation/passes/netreactor/resolver.rs @@ -19,10 +19,9 @@ use std::collections::HashSet; use crate::{ - analysis::{ConstValue, SsaFunction, SsaOp, SsaVarId, TypeRef}, + analysis::{CilTarget, ConstValue, MethodRef, SsaFunction, SsaOp, SsaVarId, TypeRef}, compiler::{CompilerContext, EventKind, ModificationScope, SsaPass}, metadata::token::Token, - CilObject, Result, }; /// Folds `accessor()` calls into `ldtoken X` for the NR @@ -45,7 +44,7 @@ impl TokenResolverPass { } } -impl SsaPass for TokenResolverPass { +impl SsaPass for TokenResolverPass { fn name(&self) -> &'static str { "netreactor-token-resolver" } @@ -61,10 +60,9 @@ impl SsaPass for TokenResolverPass { fn run_on_method( &self, ssa: &mut SsaFunction, - _method_token: Token, + _method: &MethodRef, ctx: &CompilerContext, - _assembly: &CilObject, - ) -> Result { + ) -> analyssa::Result { if self.accessor_tokens.is_empty() { return Ok(false); } diff --git a/dotscope/src/deobfuscation/passes/netreactor/rewrite.rs b/dotscope/src/deobfuscation/passes/netreactor/rewrite.rs index 96608edb..cd4a1a41 100644 --- a/dotscope/src/deobfuscation/passes/netreactor/rewrite.rs +++ b/dotscope/src/deobfuscation/passes/netreactor/rewrite.rs @@ -35,10 +35,9 @@ use std::collections::HashSet; use crate::{ - analysis::{MethodRef, SsaFunction, SsaOp}, + analysis::{CilTarget, MethodRef, SsaFunction, SsaOp}, compiler::{CompilerContext, EventKind, ModificationScope, SsaPass}, metadata::token::Token, - CilObject, Result, }; /// Rewrites NR resource-resolver shim calls in user code. @@ -78,7 +77,7 @@ impl ResourceShimRewritePass { } } -impl SsaPass for ResourceShimRewritePass { +impl SsaPass for ResourceShimRewritePass { fn name(&self) -> &'static str { "netreactor-resource-shim-rewrite" } @@ -95,10 +94,10 @@ impl SsaPass for ResourceShimRewritePass { fn run_on_method( &self, ssa: &mut SsaFunction, - method_token: Token, + method: &MethodRef, ctx: &CompilerContext, - _assembly: &CilObject, - ) -> Result { + ) -> analyssa::Result { + let method_token = method.0; // Skip work for the resolver's own methods — the cleanup pipeline // deletes them wholesale, so rewriting here would be wasted. if self.shim_method_tokens.contains(&method_token) || self.lazy_init_token == method_token { @@ -186,7 +185,9 @@ mod tests { }; fn make_ctx() -> AnalysisContext { - AnalysisContext::new(Arc::new(CallGraph::new())) + let ctx = AnalysisContext::new(Arc::new(CallGraph::new())); + ctx.compiler.set_assembly(test_assembly_arc()); + ctx } #[test] @@ -208,7 +209,11 @@ mod tests { let pass = ResourceShimRewritePass::new(vec![shim], lazy_init, bcl); let ctx = make_ctx(); let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000003), &ctx, &test_assembly_arc()) + .run_on_method( + &mut ssa, + &MethodRef::from(Token::new(0x06000003)), + &ctx.compiler, + ) .unwrap(); assert!(changed); let block = ssa.block(0).unwrap(); @@ -239,7 +244,11 @@ mod tests { let pass = ResourceShimRewritePass::new(vec![shim], lazy_init, bcl); let ctx = make_ctx(); let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000003), &ctx, &test_assembly_arc()) + .run_on_method( + &mut ssa, + &MethodRef::from(Token::new(0x06000003)), + &ctx.compiler, + ) .unwrap(); assert!(changed); let block = ssa.block(0).unwrap(); @@ -265,7 +274,7 @@ mod tests { let ctx = make_ctx(); // Running on the lazy_init method itself should NOT touch its body. let changed = pass - .run_on_method(&mut ssa, lazy_init, &ctx, &test_assembly_arc()) + .run_on_method(&mut ssa, &MethodRef::from(lazy_init), &ctx.compiler) .unwrap(); assert!(!changed); } @@ -289,7 +298,11 @@ mod tests { let pass = ResourceShimRewritePass::new(vec![shim], lazy_init, bcl_unset); let ctx = make_ctx(); let changed = pass - .run_on_method(&mut ssa, Token::new(0x06000003), &ctx, &test_assembly_arc()) + .run_on_method( + &mut ssa, + &MethodRef::from(Token::new(0x06000003)), + &ctx.compiler, + ) .unwrap(); // Shim left intact — fallback is to leave the call alone. assert!(!changed); diff --git a/dotscope/src/deobfuscation/passes/neutralize.rs b/dotscope/src/deobfuscation/passes/neutralize.rs index 05ecb887..efd21bef 100644 --- a/dotscope/src/deobfuscation/passes/neutralize.rs +++ b/dotscope/src/deobfuscation/passes/neutralize.rs @@ -39,10 +39,9 @@ use std::collections::{HashMap, HashSet}; use crate::{ - analysis::{find_token_dependencies, SsaFunction, SsaOp}, + analysis::{find_token_dependencies, CilTarget, ConstValue, MethodRef, SsaFunction, SsaOp}, compiler::{CompilerContext, EventKind, SsaPass}, metadata::token::Token, - CilObject, Result, }; /// Action to perform on a tainted instruction during neutralization. @@ -291,7 +290,7 @@ impl<'a> NeutralizationPass<'a> { if matches!( instr.op(), SsaOp::Const { - value: crate::analysis::ConstValue::DecryptedString(_), + value: ConstValue::DecryptedString(_), .. } ) { @@ -359,7 +358,7 @@ impl<'a> NeutralizationPass<'a> { } } -impl SsaPass for NeutralizationPass<'_> { +impl SsaPass for NeutralizationPass<'_> { fn name(&self) -> &'static str { "neutralization" } @@ -371,14 +370,14 @@ impl SsaPass for NeutralizationPass<'_> { fn run_on_method( &self, ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - _assembly: &CilObject, - ) -> Result { + method: &MethodRef, + host: &CompilerContext, + ) -> analyssa::Result { + let method_token = method.0; let neutralized = self.neutralize_method(ssa); if neutralized > 0 { - ctx.events + host.events .record(EventKind::InstructionRemoved) .method(method_token) .message(format!( @@ -441,6 +440,7 @@ mod tests { dest: v2, left: v0, right: v1, + flags: None, })); b0.add_instruction(SsaInstruction::synthetic(SsaOp::Return { value: Some(v2) })); diff --git a/dotscope/src/deobfuscation/passes/opaquefields.rs b/dotscope/src/deobfuscation/passes/opaquefields.rs index 09b10e97..53801b77 100644 --- a/dotscope/src/deobfuscation/passes/opaquefields.rs +++ b/dotscope/src/deobfuscation/passes/opaquefields.rs @@ -41,14 +41,14 @@ use std::{ collections::{HashMap, HashSet}, - sync::Arc, + sync::{Arc, RwLockReadGuard}, }; use dashmap::DashSet; use log::debug; use crate::{ - analysis::{ConstValue, SsaFunction, SsaOp, SsaVarId}, + analysis::{CilTarget, ConstValue, MethodRef, SsaFunction, SsaOp, SsaVarId}, compiler::{CompilerContext, EventKind, ModificationScope, PassCapability, SsaPass}, deobfuscation::{utils::build_def_map, EmulationTemplatePool, ProcessCell}, emulation::{EmValue, EmulationProcess, HeapRef}, @@ -190,9 +190,7 @@ impl OpaqueFieldPredicatePass { /// /// Delegates to [`ProcessCell::ensure_initialized`] with a fork + /// targeted warmup closure. - fn ensure_initialized( - &self, - ) -> Result>> { + fn ensure_initialized(&self) -> Result>> { self.lazy_process .ensure_initialized(|| self.create_process_from_pool(), |_| {}) } @@ -396,7 +394,7 @@ fn trace_field_chain( None // Exceeded max depth } -impl SsaPass for OpaqueFieldPredicatePass { +impl SsaPass for OpaqueFieldPredicatePass { fn name(&self) -> &'static str { "opaque-field-predicate-removal" } @@ -413,12 +411,11 @@ impl SsaPass for OpaqueFieldPredicatePass { &[PassCapability::ResolvedStaticFields] } - fn should_run(&self, method_token: Token, _ctx: &CompilerContext) -> bool { - self.affected_methods.contains(&method_token) - && !self.processed_methods.contains(&method_token) + fn should_run(&self, method: &MethodRef, _host: &CompilerContext) -> bool { + self.affected_methods.contains(&method.0) && !self.processed_methods.contains(&method.0) } - fn initialize(&mut self, _ctx: &CompilerContext) -> Result<()> { + fn initialize(&mut self, _host: &CompilerContext) -> analyssa::Result<()> { let remaining = self .affected_methods .len() @@ -438,10 +435,15 @@ impl SsaPass for OpaqueFieldPredicatePass { fn run_on_method( &self, ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - assembly: &CilObject, - ) -> Result { + method: &MethodRef, + host: &CompilerContext, + ) -> analyssa::Result { + let assembly_arc = host + .assembly() + .ok_or_else(|| analyssa::Error::new("OpaqueFieldPredicatePass requires an assembly"))?; + let assembly: &CilObject = &assembly_arc; + let ctx = host; + let method_token = method.0; let guard = self.ensure_initialized()?; let Some(process) = guard.as_ref() else { return Ok(false); @@ -679,9 +681,9 @@ impl SsaPass for OpaqueFieldPredicatePass { Ok(changed) } - fn finalize(&mut self, _ctx: &CompilerContext) -> Result<()> { + fn finalize(&mut self, _host: &CompilerContext) -> analyssa::Result<()> { // Clear the emulation process to release its Arc reference. // This is needed so the assembly can be unwrapped for code generation. - self.lazy_process.clear() + self.lazy_process.clear().map_err(Into::into) } } diff --git a/dotscope/src/deobfuscation/passes/reflection.rs b/dotscope/src/deobfuscation/passes/reflection.rs index ffaf40aa..ddd7fb7e 100644 --- a/dotscope/src/deobfuscation/passes/reflection.rs +++ b/dotscope/src/deobfuscation/passes/reflection.rs @@ -82,11 +82,13 @@ use std::collections::HashSet; use dashmap::DashSet; use crate::{ - analysis::{ConstValue, FieldRef, MethodRef, SsaFunction, SsaOp, SsaVarId, VariableOrigin}, + analysis::{ + CilTarget, ConstValue, FieldRef, MethodRef, SsaFunction, SsaOp, SsaVarId, VariableOrigin, + }, compiler::{CompilerContext, EventKind, ModificationScope, SsaPass}, deobfuscation::utils::is_method_named, metadata::token::Token, - CilObject, Result, + CilObject, }; /// A detected reflection call site with its resolved target and location. @@ -201,7 +203,7 @@ impl ReflectionDevirtualizationPass { } } -impl SsaPass for ReflectionDevirtualizationPass { +impl SsaPass for ReflectionDevirtualizationPass { fn name(&self) -> &'static str { "reflection-devirtualization" } @@ -214,17 +216,22 @@ impl SsaPass for ReflectionDevirtualizationPass { ModificationScope::InstructionsOnly } - fn should_run(&self, method_token: Token, _ctx: &CompilerContext) -> bool { - self.target_methods.contains(&method_token) && !self.processed.contains(&method_token) + fn should_run(&self, method: &MethodRef, _host: &CompilerContext) -> bool { + self.target_methods.contains(&method.0) && !self.processed.contains(&method.0) } fn run_on_method( &self, ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - assembly: &CilObject, - ) -> Result { + method: &MethodRef, + host: &CompilerContext, + ) -> analyssa::Result { + let assembly_arc = host.assembly().ok_or_else(|| { + analyssa::Error::new("ReflectionDevirtualizationPass requires an assembly") + })?; + let assembly: &CilObject = &assembly_arc; + let ctx = host; + let method_token = method.0; let sites = find_reflection_sites(ssa, assembly); if sites.is_empty() { self.processed.insert(method_token); diff --git a/dotscope/src/deobfuscation/passes/staticfields.rs b/dotscope/src/deobfuscation/passes/staticfields.rs index cd775576..fde150e1 100644 --- a/dotscope/src/deobfuscation/passes/staticfields.rs +++ b/dotscope/src/deobfuscation/passes/staticfields.rs @@ -47,12 +47,12 @@ use std::{ use log::{debug, info, warn}; use crate::{ - analysis::{ConstValue, SsaFunction, SsaOp, SsaVarId}, + analysis::{CilTarget, ConstValue, MethodRef, SsaFunction, SsaOp, SsaVarId}, compiler::{CompilerContext, EventKind, ModificationScope, PassCapability, SsaPass}, deobfuscation::{EmulationTemplatePool, ProcessCell}, emulation::{EmValue, EmulationProcess}, metadata::token::Token, - CilObject, Error, Result, + CilObject, Error, }; /// Extracts an SSA [`ConstValue`] from a raw emulator value. @@ -411,7 +411,7 @@ impl StaticFieldResolutionPass { } } -impl SsaPass for StaticFieldResolutionPass { +impl SsaPass for StaticFieldResolutionPass { fn name(&self) -> &'static str { self.pass_name } @@ -428,7 +428,7 @@ impl SsaPass for StaticFieldResolutionPass { &self.capabilities } - fn initialize(&mut self, _ctx: &CompilerContext) -> Result<()> { + fn initialize(&mut self, _host: &CompilerContext) -> analyssa::Result<()> { // Eagerly initialize field values before parallel method processing. // This avoids a race condition where parallel run_on_method calls // compete on lazy initialization — losing threads would see empty @@ -440,10 +440,14 @@ impl SsaPass for StaticFieldResolutionPass { fn run_on_method( &self, ssa: &mut SsaFunction, - _method_token: Token, - ctx: &CompilerContext, - assembly: &CilObject, - ) -> Result { + _method: &MethodRef, + host: &CompilerContext, + ) -> analyssa::Result { + let assembly_arc = host.assembly().ok_or_else(|| { + analyssa::Error::new("StaticFieldResolutionPass requires an assembly") + })?; + let assembly: &CilObject = &assembly_arc; + let ctx = host; if !self.ensure_initialized() { return Ok(false); } @@ -499,8 +503,8 @@ impl SsaPass for StaticFieldResolutionPass { Ok(true) } - fn finalize(&mut self, _ctx: &CompilerContext) -> Result<()> { + fn finalize(&mut self, _host: &CompilerContext) -> analyssa::Result<()> { // Release the emulation process to free the Arc reference - self.lazy_process.clear() + self.lazy_process.clear().map_err(Into::into) } } diff --git a/dotscope/src/deobfuscation/passes/unflattening/detection.rs b/dotscope/src/deobfuscation/passes/unflattening/detection.rs index 78470328..21b5f754 100644 --- a/dotscope/src/deobfuscation/passes/unflattening/detection.rs +++ b/dotscope/src/deobfuscation/passes/unflattening/detection.rs @@ -17,7 +17,10 @@ //! The confidence score combines these signals to distinguish CFF from normal //! loops or state machines. -use std::collections::{HashSet, VecDeque}; +use std::{ + cmp::Ordering, + collections::{HashSet, VecDeque}, +}; use rayon::prelude::*; @@ -28,13 +31,13 @@ use crate::{ statevar::{identify_state_variable, StateVariable}, UnflattenConfig, }, - utils::{ - graph::{ - algorithms::{compute_dominators, DominatorTree}, - GraphBase, NodeId, Successors, - }, - BitSet, +}; +use analyssa::{ + graph::{ + algorithms::{compute_dominators, DominatorTree}, + GraphBase, NodeId, Successors, }, + BitSet, }; /// Entry point into a CFF region. @@ -351,7 +354,7 @@ impl<'a> CffDetector<'a> { patterns.sort_by(|a, b| { b.confidence .partial_cmp(&a.confidence) - .unwrap_or(std::cmp::Ordering::Equal) + .unwrap_or(Ordering::Equal) }); patterns @@ -1151,10 +1154,11 @@ fn can_reach_dispatcher( #[cfg(test)] mod tests { + use analyssa::BitSet; + use crate::{ analysis::SsaVarId, deobfuscation::passes::unflattening::dispatcher::{DispatcherInfo, StateTransform}, - utils::BitSet, }; use super::{CffPattern, EntryCondition, EntryPoint}; diff --git a/dotscope/src/deobfuscation/passes/unflattening/mod.rs b/dotscope/src/deobfuscation/passes/unflattening/mod.rs index 1fe31719..8f4031a5 100644 --- a/dotscope/src/deobfuscation/passes/unflattening/mod.rs +++ b/dotscope/src/deobfuscation/passes/unflattening/mod.rs @@ -53,7 +53,7 @@ use rayon::prelude::*; use std::collections::HashMap; use crate::{ - analysis::SsaFunction, + analysis::{CilTarget, MethodRef, SsaFunction}, compiler::{CompilerContext, PassCapability, SsaPass}, deobfuscation::{ config::DetectionWeights, @@ -61,7 +61,7 @@ use crate::{ passes::unflattening::tracer::{trace_for_dispatcher, TracedDispatcher}, }, metadata::{token::Token, typesystem::PointerSize}, - CilObject, Result, + CilObject, }; /// High-level API: Unflatten a method using tree-based tracing and patching. @@ -319,12 +319,10 @@ impl UnflattenConfig { /// obfuscation technique that converts structured control flow into /// a state machine. /// -/// When [`FlatteningFindings`] are available from the detection phase, -/// the pass uses pre-computed dispatchers directly instead of re-running -/// detection. This avoids duplicate structural analysis (dominance, SCCs, -/// confidence scoring) for methods already analyzed during `detect_ssa()`. -/// -/// [`FlatteningFindings`]: crate::deobfuscation::techniques::generic::flattening::FlatteningFindings +/// When detection-phase findings are available, the pass uses pre-computed +/// dispatchers directly instead of re-running detection. This avoids duplicate +/// structural analysis (dominance, SCCs, confidence scoring) for methods +/// already analyzed during the detection phase's SSA pass. pub struct CffReconstructionPass { config: UnflattenConfig, /// Successfully unflattened dispatcher methods (shared with deob engine). @@ -375,7 +373,7 @@ impl CffReconstructionPass { /// Sets the pre-detected dispatchers from the detection phase. /// /// When set, `run_on_method` uses these dispatchers directly instead of - /// re-running [`CffDetector`] analysis, avoiding duplicate work. + /// re-running the internal CFF detector, avoiding duplicate work. /// /// # Arguments /// @@ -401,7 +399,7 @@ impl CffReconstructionPass { } } -impl SsaPass for CffReconstructionPass { +impl SsaPass for CffReconstructionPass { fn name(&self) -> &'static str { "cff-reconstruction" } @@ -418,25 +416,30 @@ impl SsaPass for CffReconstructionPass { &[PassCapability::ResolvedStaticFields] } - fn should_run(&self, method_token: Token, _ctx: &CompilerContext) -> bool { + fn should_run(&self, method: &MethodRef, _host: &CompilerContext) -> bool { // Skip methods already unflattened and also skip methods that were never // detected as having dispatchers. - !self.unflattened_dispatchers.contains(&method_token) + !self.unflattened_dispatchers.contains(&method.0) && self .pre_detected - .get(&method_token) + .get(&method.0) .is_some_and(|d| !d.is_empty()) } fn run_on_method( &self, ssa: &mut SsaFunction, - method_token: Token, - ctx: &CompilerContext, - assembly: &CilObject, - ) -> Result { + method: &MethodRef, + host: &CompilerContext, + ) -> analyssa::Result { + let assembly_arc = host + .assembly() + .ok_or_else(|| analyssa::Error::new("CffReconstructionPass requires an assembly"))?; + let assembly: &CilObject = &assembly_arc; + let ctx = host; + let method_token = method.0; let mut config = self.config.clone(); - config.pointer_size = PointerSize::from_pe(assembly.file().pe().is_64bit); + config.pointer_size = PointerSize::from_is_64bit(assembly.file().pe().is_64bit); // Use pre-detected dispatcher block indices from detect_ssa phase, but // refresh variable IDs from the current SSA. Earlier passes (opaque field diff --git a/dotscope/src/deobfuscation/passes/unflattening/reconstruction.rs b/dotscope/src/deobfuscation/passes/unflattening/reconstruction.rs index 2a70a47e..a0a73e03 100644 --- a/dotscope/src/deobfuscation/passes/unflattening/reconstruction.rs +++ b/dotscope/src/deobfuscation/passes/unflattening/reconstruction.rs @@ -48,10 +48,11 @@ use std::collections::{BTreeMap, BTreeSet}; +use analyssa::BitSet; + use crate::{ analysis::{SsaBlock, SsaFunction, SsaInstruction, SsaOp, SsaVarId}, deobfuscation::passes::unflattening::tracer::{TraceNode, TraceTerminator, TraceTree}, - utils::BitSet, }; type PhiOperands = Vec<(usize, SsaVarId)>; @@ -553,7 +554,7 @@ fn extract_redirects_from_node( .unwrap_or(node.blocks_visited.len()); let interior_start = start_idx.saturating_add(1); if end_idx > interior_start { - let intermediate_blocks: std::collections::BTreeSet = node + let intermediate_blocks: BTreeSet = node .blocks_visited .get(interior_start..end_idx) .map(|s| s.iter().copied().collect()) diff --git a/dotscope/src/deobfuscation/passes/unflattening/tracer/context.rs b/dotscope/src/deobfuscation/passes/unflattening/tracer/context.rs index 089f72c2..e1f7b8ce 100644 --- a/dotscope/src/deobfuscation/passes/unflattening/tracer/context.rs +++ b/dotscope/src/deobfuscation/passes/unflattening/tracer/context.rs @@ -14,13 +14,17 @@ use std::{collections::BTreeSet, mem}; +use analyssa::BitSet; + use crate::{ analysis::{ cff_taint_config, ConstValue, SsaEvaluator, SsaFunction, SsaOp, SsaVarId, SsaVariable, TaintAnalysis, }, - deobfuscation::passes::unflattening::{tracer::types::TracedDispatcher, UnflattenConfig}, - utils::BitSet, + deobfuscation::passes::unflattening::{ + tracer::{helpers, types::TracedDispatcher}, + UnflattenConfig, + }, CilObject, }; @@ -274,7 +278,7 @@ impl<'a> TreeTraceContext<'a> { /// Encapsulates the borrow of both `ssa` and `state_tainted` within one method /// to avoid split-borrow issues at call sites. pub fn propagate_taint_forward(&mut self) { - super::helpers::propagate_taint_forward(self.ssa, &mut self.state_tainted); + helpers::propagate_taint_forward(self.ssa, &mut self.state_tainted); } /// Gets the current CFF state value (if we can determine it). diff --git a/dotscope/src/deobfuscation/passes/unflattening/tracer/helpers.rs b/dotscope/src/deobfuscation/passes/unflattening/tracer/helpers.rs index 1ccb4040..0368fb46 100644 --- a/dotscope/src/deobfuscation/passes/unflattening/tracer/helpers.rs +++ b/dotscope/src/deobfuscation/passes/unflattening/tracer/helpers.rs @@ -30,9 +30,9 @@ use crate::{ types::{HandlerTrace, TraceNode, TraceStats, TraceTerminator}, }, metadata::{token::Token, typesystem::PointerSize}, - utils::BitSet, CilObject, }; +use analyssa::BitSet; /// Traces exception handler entry blocks that were not visited by the main trace. /// diff --git a/dotscope/src/deobfuscation/passes/unflattening/tracer/types.rs b/dotscope/src/deobfuscation/passes/unflattening/tracer/types.rs index f1ae8089..639b9f3b 100644 --- a/dotscope/src/deobfuscation/passes/unflattening/tracer/types.rs +++ b/dotscope/src/deobfuscation/passes/unflattening/tracer/types.rs @@ -12,8 +12,9 @@ use std::collections::BTreeMap; +use analyssa::BitSet; + use crate::analysis::{SsaInstruction, SsaVarId}; -use crate::utils::BitSet; /// Information about the dispatcher found during tracing. #[derive(Debug, Clone)] diff --git a/dotscope/src/deobfuscation/renamer/cascade.rs b/dotscope/src/deobfuscation/renamer/cascade.rs index 352836a9..84c7be1a 100644 --- a/dotscope/src/deobfuscation/renamer/cascade.rs +++ b/dotscope/src/deobfuscation/renamer/cascade.rs @@ -25,7 +25,7 @@ use crate::{ utils::{is_obfuscated_name, is_special_name}, }, metadata::{ - tables::{FieldRaw, MethodDefRaw, ParamRaw, TableId, TypeDefRaw}, + tables::{FieldRaw, MetadataTable, MethodDefRaw, ParamRaw, TableId, TypeDefRaw}, token::Token, }, CilObject, Result, @@ -1010,7 +1010,7 @@ impl<'a> CascadeRenamer<'a> { /// belonging to that method. The range extends to the next method's `param_list` /// or the end of the Param table. fn build_param_owner_map( - methoddef_table: &crate::metadata::tables::MetadataTable<'_, MethodDefRaw>, + methoddef_table: &MetadataTable<'_, MethodDefRaw>, param_row_count: u32, ) -> HashMap { let mut map = HashMap::new(); @@ -1049,7 +1049,7 @@ fn build_param_owner_map( /// Works for both MethodDef and Field tables by using a closure to extract /// the list-start column (`method_list` or `field_list`) from each TypeDef row. fn build_member_owner_map( - typedef_table: &crate::metadata::tables::MetadataTable<'_, TypeDefRaw>, + typedef_table: &MetadataTable<'_, TypeDefRaw>, member_row_count: u32, get_list_start: fn(&TypeDefRaw) -> u32, ) -> HashMap { @@ -1128,6 +1128,8 @@ fn generate_phase_label_from_context( #[cfg(test)] mod tests { + use std::path::Path; + use crate::{ deobfuscation::{ renamer::{ @@ -1238,7 +1240,7 @@ mod tests { #[test] fn test_cascade_on_bitmono_sample() { let path = "tests/samples/packers/bitmono/0.39.0/bitmono_renamer.exe"; - if !std::path::Path::new(path).exists() { + if !Path::new(path).exists() { eprintln!("Skipping: sample not found"); return; } @@ -1266,7 +1268,7 @@ mod tests { #[test] fn test_cascade_rename_patterns() { let path = "tests/samples/packers/bitmono/0.39.0/bitmono_renamer.exe"; - if !std::path::Path::new(path).exists() { + if !Path::new(path).exists() { eprintln!("Skipping: sample not found"); return; } @@ -1337,7 +1339,7 @@ mod tests { #[test] fn test_cascade_preserves_known_names() { let path = "tests/samples/packers/bitmono/0.39.0/original.exe"; - if !std::path::Path::new(path).exists() { + if !Path::new(path).exists() { eprintln!("Skipping: original sample not found"); return; } @@ -1364,7 +1366,7 @@ mod tests { #[test] fn test_cascade_entry_counts() { let path = "tests/samples/packers/bitmono/0.39.0/bitmono_renamer.exe"; - if !std::path::Path::new(path).exists() { + if !Path::new(path).exists() { eprintln!("Skipping: sample not found"); return; } @@ -1416,7 +1418,7 @@ mod tests { #[test] fn test_cascade_respects_config() { let path = "tests/samples/packers/bitmono/0.39.0/bitmono_renamer.exe"; - if !std::path::Path::new(path).exists() { + if !Path::new(path).exists() { eprintln!("Skipping: sample not found"); return; } @@ -1458,7 +1460,7 @@ mod tests { #[test] fn test_cascade_ssa_context_populated() { let path = "tests/samples/packers/confuserex/1.6.0/original.exe"; - if !std::path::Path::new(path).exists() { + if !Path::new(path).exists() { eprintln!("Skipping: sample not found"); return; } @@ -1508,7 +1510,7 @@ mod tests { #[test] fn test_cascade_context_quality_on_obfuscated() { let path = "tests/samples/packers/confuserex/1.6.0/mkaring_maximum.exe"; - if !std::path::Path::new(path).exists() { + if !Path::new(path).exists() { eprintln!("Skipping: sample not found"); return; } @@ -2014,9 +2016,7 @@ mod tests { eprintln!(" DIAGNOSTIC CONTEXT DUMP — original.exe (clean, unobfuscated)"); eprintln!("========================================================================"); - // --------------------------------------------------------------- // 1. Enumerate ALL types, methods, fields, params in metadata - // --------------------------------------------------------------- let tables = assembly.tables().expect("assembly should have tables"); let strings = assembly.strings().expect("assembly should have strings"); @@ -2086,9 +2086,7 @@ mod tests { } } - // --------------------------------------------------------------- // 2. Build infrastructure (SSA + call graph) via CascadeRenamer - // --------------------------------------------------------------- let provider = SimpleProvider::new(); let fallback = SimpleProvider::new(); let config = SmartRenameConfig::default(); @@ -2137,9 +2135,7 @@ mod tests { } } - // --------------------------------------------------------------- // 3. For EACH method with SSA: dump all extracted features - // --------------------------------------------------------------- eprintln!("\n========================================================================"); eprintln!(" PER-METHOD SSA FEATURE EXTRACTION"); eprintln!("========================================================================"); @@ -2270,9 +2266,7 @@ mod tests { } } - // --------------------------------------------------------------- // 4. build_method_context() for each method - // --------------------------------------------------------------- eprintln!("\n========================================================================"); eprintln!(" FULL RENAME CONTEXT (build_method_context)"); eprintln!("========================================================================"); @@ -2329,9 +2323,7 @@ mod tests { } } - // --------------------------------------------------------------- // 5. Build param contexts - // --------------------------------------------------------------- eprintln!("\n========================================================================"); eprintln!(" PARAMETER CONTEXTS"); eprintln!("========================================================================"); @@ -2367,9 +2359,7 @@ mod tests { } } - // --------------------------------------------------------------- // 6. Build field contexts - // --------------------------------------------------------------- eprintln!("\n========================================================================"); eprintln!(" FIELD CONTEXTS"); eprintln!("========================================================================"); @@ -2395,9 +2385,7 @@ mod tests { } } - // --------------------------------------------------------------- // 7. Build type contexts - // --------------------------------------------------------------- eprintln!("\n========================================================================"); eprintln!(" TYPE CONTEXTS"); eprintln!("========================================================================"); @@ -2424,9 +2412,7 @@ mod tests { } } - // --------------------------------------------------------------- // 8. Run full cascade and dump entries - // --------------------------------------------------------------- eprintln!("\n========================================================================"); eprintln!(" FULL CASCADE EXECUTION"); eprintln!("========================================================================"); @@ -2449,9 +2435,7 @@ mod tests { eprintln!(" (No entries -- clean assembly has no obfuscated names, as expected)"); } - // --------------------------------------------------------------- // 9. Call graph analysis - // --------------------------------------------------------------- eprintln!("\n========================================================================"); eprintln!(" CALL GRAPH ANALYSIS"); eprintln!("========================================================================"); @@ -2471,9 +2455,7 @@ mod tests { eprintln!(" (No call graph available)"); } - // --------------------------------------------------------------- // 10. Summary statistics - // --------------------------------------------------------------- eprintln!("\n========================================================================"); eprintln!(" SUMMARY"); eprintln!("========================================================================"); diff --git a/dotscope/src/deobfuscation/renamer/config.rs b/dotscope/src/deobfuscation/renamer/config.rs index d9c368c8..695efe88 100644 --- a/dotscope/src/deobfuscation/renamer/config.rs +++ b/dotscope/src/deobfuscation/renamer/config.rs @@ -8,8 +8,9 @@ use std::path::PathBuf; /// Configuration for the smart renaming pipeline. /// /// Controls the local inference backend (mistral.rs) and the cascade -/// renaming pipeline parameters. When `None` in [`CleanupConfig`], the -/// simple sequential renamer is used instead. +/// renaming pipeline parameters. When `None` in +/// [`CleanupConfig`](crate::deobfuscation::CleanupConfig), the simple +/// sequential renamer is used instead. /// /// # Feature Gate /// diff --git a/dotscope/src/deobfuscation/renamer/mod.rs b/dotscope/src/deobfuscation/renamer/mod.rs index a886f788..adc0fb6c 100644 --- a/dotscope/src/deobfuscation/renamer/mod.rs +++ b/dotscope/src/deobfuscation/renamer/mod.rs @@ -396,6 +396,8 @@ fn update_row_name_field( #[cfg(test)] mod tests { use std::collections::HashSet; + #[cfg(feature = "smart-rename")] + use std::path::PathBuf; use crate::{ cilassembly::{CilAssembly, GeneratorConfig}, @@ -747,8 +749,6 @@ mod tests { #[ignore] #[cfg(feature = "smart-rename")] fn test_smart_rename_llm() { - use std::path::PathBuf; - let model_path = PathBuf::from(env!("CARGO_MANIFEST_DIR")) .join("../qwen2.5-coder-3b-instruct-q4_k_m.gguf"); assert!( diff --git a/dotscope/src/deobfuscation/renamer/phases.rs b/dotscope/src/deobfuscation/renamer/phases.rs index 2d79e518..4f29763a 100644 --- a/dotscope/src/deobfuscation/renamer/phases.rs +++ b/dotscope/src/deobfuscation/renamer/phases.rs @@ -154,7 +154,9 @@ pub fn build_call_site_skeleton(ssa: &SsaFunction, assembly: &CilObject) -> Opti } // Arithmetic - SsaOp::Add { dest, left, right } + SsaOp::Add { + dest, left, right, .. + } | SsaOp::AddOvf { dest, left, right, .. } => { @@ -165,7 +167,9 @@ pub fn build_call_site_skeleton(ssa: &SsaFunction, assembly: &CilObject) -> Opti right.index() )); } - SsaOp::Sub { dest, left, right } + SsaOp::Sub { + dest, left, right, .. + } | SsaOp::SubOvf { dest, left, right, .. } => { @@ -176,7 +180,9 @@ pub fn build_call_site_skeleton(ssa: &SsaFunction, assembly: &CilObject) -> Opti right.index() )); } - SsaOp::Mul { dest, left, right } + SsaOp::Mul { + dest, left, right, .. + } | SsaOp::MulOvf { dest, left, right, .. } => { @@ -207,7 +213,7 @@ pub fn build_call_site_skeleton(ssa: &SsaFunction, assembly: &CilObject) -> Opti right.index() )); } - SsaOp::Neg { dest, operand } => { + SsaOp::Neg { dest, operand, .. } => { lines.push(format!( " var_{} = -var_{};", dest.index(), @@ -216,7 +222,9 @@ pub fn build_call_site_skeleton(ssa: &SsaFunction, assembly: &CilObject) -> Opti } // Bitwise - SsaOp::And { dest, left, right } => { + SsaOp::And { + dest, left, right, .. + } => { lines.push(format!( " var_{} = var_{} & var_{};", dest.index(), @@ -224,7 +232,9 @@ pub fn build_call_site_skeleton(ssa: &SsaFunction, assembly: &CilObject) -> Opti right.index() )); } - SsaOp::Or { dest, left, right } => { + SsaOp::Or { + dest, left, right, .. + } => { lines.push(format!( " var_{} = var_{} | var_{};", dest.index(), @@ -232,7 +242,9 @@ pub fn build_call_site_skeleton(ssa: &SsaFunction, assembly: &CilObject) -> Opti right.index() )); } - SsaOp::Xor { dest, left, right } => { + SsaOp::Xor { + dest, left, right, .. + } => { lines.push(format!( " var_{} = var_{} ^ var_{};", dest.index(), @@ -240,7 +252,7 @@ pub fn build_call_site_skeleton(ssa: &SsaFunction, assembly: &CilObject) -> Opti right.index() )); } - SsaOp::Not { dest, operand } => { + SsaOp::Not { dest, operand, .. } => { lines.push(format!( " var_{} = ~var_{};", dest.index(), @@ -251,6 +263,7 @@ pub fn build_call_site_skeleton(ssa: &SsaFunction, assembly: &CilObject) -> Opti dest, value, amount, + .. } => { lines.push(format!( " var_{} = var_{} << var_{};", diff --git a/dotscope/src/deobfuscation/renamer/prompt.rs b/dotscope/src/deobfuscation/renamer/prompt.rs index 32238ca0..3c1833b5 100644 --- a/dotscope/src/deobfuscation/renamer/prompt.rs +++ b/dotscope/src/deobfuscation/renamer/prompt.rs @@ -7,7 +7,7 @@ //! Each [`IdentifierKind`] has a distinct template optimized for the //! information most useful for that kind of rename. -use crate::deobfuscation::renamer::context::{IdentifierKind, PhaseInfo, RenameContext}; +use crate::deobfuscation::renamer::context::{IdentifierKind, ParamInfo, PhaseInfo, RenameContext}; /// Builds a FIM prompt from a rename context. /// @@ -416,7 +416,7 @@ fn build_parameter_prompt(context: &RenameContext) -> (String, String) { /// # Returns /// /// A comma-separated parameter string (e.g., `"string path, int param_1"`). -fn format_params(params: &[crate::deobfuscation::renamer::context::ParamInfo]) -> String { +fn format_params(params: &[ParamInfo]) -> String { if params.is_empty() { return String::new(); } @@ -478,9 +478,10 @@ fn truncate_phases(phases: &[PhaseInfo], max_phases: usize) -> Vec<&PhaseInfo> { #[cfg(test)] mod tests { - use crate::deobfuscation::renamer::{ - context::{IdentifierKind, ParamInfo, PhaseInfo, RenameContext}, - prompt::{build_fim_prompt, build_phase_label_prompt}, + use super::*; + + use crate::deobfuscation::renamer::context::{ + ApiCallInfo, IdentifierKind, ParamInfo, PhaseInfo, RenameContext, }; /// Default max phases used in tests. @@ -713,8 +714,6 @@ mod tests { /// Parameter prompt should include call targets from the owning method. #[test] fn test_prompt_parameter_with_method_context() { - use crate::deobfuscation::renamer::context::ApiCallInfo; - let ctx = RenameContext { kind: Some(IdentifierKind::Parameter), dotnet_type: Some("byte[]".to_string()), diff --git a/dotscope/src/deobfuscation/statemachine.rs b/dotscope/src/deobfuscation/statemachine.rs index cf4b6c60..fc45665c 100644 --- a/dotscope/src/deobfuscation/statemachine.rs +++ b/dotscope/src/deobfuscation/statemachine.rs @@ -42,13 +42,18 @@ //! - `init_ops = [Mul, Mul, Mul, Mul]` //! - `slot_ops = [Xor, Add, Xor, Sub]` (incremental mode operations) -use std::{collections::HashSet, sync::Arc}; +use std::{ + collections::HashSet, + fmt::{self, Debug, Display, Formatter}, + sync::Arc, +}; + +use analyssa::graph::{algorithms::DominatorTree, NodeId}; use crate::{ analysis::{ConstValue, SsaFunction, SsaOp, SsaVarId}, compiler::CompilerContext, metadata::token::Token, - utils::graph::{algorithms::DominatorTree, NodeId}, CilObject, }; @@ -160,7 +165,7 @@ pub struct CfgInfo<'a> { /// } /// } /// ``` -pub trait StateMachineProvider: Send + Sync + std::fmt::Debug { +pub trait StateMachineProvider: Send + Sync + Debug { /// Returns the name of this state machine provider (for diagnostics). fn name(&self) -> &'static str; @@ -394,8 +399,8 @@ pub enum SsaOpKind { Neg, } -impl std::fmt::Display for SsaOpKind { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl Display for SsaOpKind { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { Self::Xor => write!(f, "xor"), Self::Add => write!(f, "add"), diff --git a/dotscope/src/deobfuscation/techniques/bitmono/debug.rs b/dotscope/src/deobfuscation/techniques/bitmono/debug.rs index 83c6b29c..37b3937d 100644 --- a/dotscope/src/deobfuscation/techniques/bitmono/debug.rs +++ b/dotscope/src/deobfuscation/techniques/bitmono/debug.rs @@ -51,11 +51,13 @@ use std::{collections::HashSet, sync::Arc}; use crate::{ - compiler::{PassPhase, SsaPass}, + analysis::CilTarget, + compiler::{CompilerContext, PassPhase, SsaPass}, deobfuscation::{ context::AnalysisContext, passes::{SentinelCondition, SentinelTaintRemovalPass}, techniques::{Detection, Evidence, Technique, TechniqueCategory}, + utils::find_methods_calling_apis, }, metadata::token::Token, CilObject, @@ -95,7 +97,7 @@ impl Technique for BitMonoAntiDebug { fn detect(&self, assembly: &CilObject) -> Detection { let patterns = &["get_UtcNow", "op_Subtraction", "get_TotalMilliseconds"]; - let matches = crate::deobfuscation::utils::find_methods_calling_apis(assembly, patterns); + let matches = find_methods_calling_apis(assembly, patterns); // Require all three sentinel APIs in the same method let method_tokens: HashSet = matches @@ -131,7 +133,7 @@ impl Technique for BitMonoAntiDebug { _ctx: &AnalysisContext, detection: &Detection, _assembly: &Arc, - ) -> Vec> { + ) -> Vec>> { let Some(findings) = detection.findings::() else { return Vec::new(); }; diff --git a/dotscope/src/deobfuscation/techniques/bitmono/strings.rs b/dotscope/src/deobfuscation/techniques/bitmono/strings.rs index 770edc99..82f58539 100644 --- a/dotscope/src/deobfuscation/techniques/bitmono/strings.rs +++ b/dotscope/src/deobfuscation/techniques/bitmono/strings.rs @@ -54,12 +54,14 @@ use std::{ #[cfg(feature = "legacy-crypto")] use crate::deobfuscation::passes::bitmono::StringDecryptionPass; use crate::{ - analysis::{ConstValue, SsaFunction, SsaOp, SsaVarId}, + analysis::{CilTarget, ConstValue, SsaFunction, SsaOp, SsaVarId}, cilassembly::CleanupRequest, - compiler::{PassPhase, SsaPass}, + compiler::{CompilerContext, PassPhase, SsaPass}, deobfuscation::{ + config::EngineConfig, context::AnalysisContext, techniques::{Detection, Evidence, Technique, TechniqueCategory}, + utils::build_init_array_map, }, metadata::{ tables::{MemberRefRaw, TableId, TypeDefRaw, TypeRefRaw}, @@ -117,12 +119,12 @@ impl Technique for BitMonoStrings { } #[cfg(feature = "legacy-crypto")] - fn enabled(&self, _config: &crate::deobfuscation::config::EngineConfig) -> bool { + fn enabled(&self, _config: &EngineConfig) -> bool { true } #[cfg(not(feature = "legacy-crypto"))] - fn enabled(&self, _config: &crate::deobfuscation::config::EngineConfig) -> bool { + fn enabled(&self, _config: &EngineConfig) -> bool { false } @@ -244,7 +246,7 @@ impl Technique for BitMonoStrings { _ctx: &AnalysisContext, detection: &Detection, _assembly: &Arc, - ) -> Vec> { + ) -> Vec>> { let Some(findings) = detection.findings::() else { return Vec::new(); }; @@ -260,7 +262,7 @@ impl Technique for BitMonoStrings { _ctx: &AnalysisContext, _detection: &Detection, _assembly: &Arc, - ) -> Vec> { + ) -> Vec>> { Vec::new() } @@ -318,7 +320,7 @@ fn collect_constant_data_tokens( } // Build mapping: byte_array_field_token → backing_field_token - let init_map = crate::deobfuscation::utils::build_init_array_map(assembly); + let init_map = build_init_array_map(assembly); if init_map.is_empty() { return (constant_data_fields, constant_data_types); } diff --git a/dotscope/src/deobfuscation/techniques/bitmono/unmanaged.rs b/dotscope/src/deobfuscation/techniques/bitmono/unmanaged.rs index 0cca0eb6..07c1c641 100644 --- a/dotscope/src/deobfuscation/techniques/bitmono/unmanaged.rs +++ b/dotscope/src/deobfuscation/techniques/bitmono/unmanaged.rs @@ -39,9 +39,9 @@ use std::{any::Any, collections::HashMap, sync::Arc}; use crate::{ - analysis::x86_native_body_size, + analysis::{x86_native_body_size, CilTarget}, cilassembly::CleanupRequest, - compiler::{PassPhase, SsaPass}, + compiler::{CompilerContext, PassPhase, SsaPass}, deobfuscation::{ context::AnalysisContext, passes::bitmono::UnmanagedStringReversalPass, @@ -161,7 +161,7 @@ impl Technique for BitMonoUnmanaged { _ctx: &AnalysisContext, detection: &Detection, _assembly: &Arc, - ) -> Vec> { + ) -> Vec>> { let Some(findings) = detection.findings::() else { return Vec::new(); }; diff --git a/dotscope/src/deobfuscation/techniques/confuserex/debug.rs b/dotscope/src/deobfuscation/techniques/confuserex/debug.rs index dcac22eb..9e2ae3d0 100644 --- a/dotscope/src/deobfuscation/techniques/confuserex/debug.rs +++ b/dotscope/src/deobfuscation/techniques/confuserex/debug.rs @@ -74,8 +74,9 @@ use std::{collections::HashSet, sync::Arc}; use crate::{ + analysis::CilTarget, cilassembly::CleanupRequest, - compiler::{PassPhase, SsaPass}, + compiler::{CompilerContext, PassPhase, SsaPass}, deobfuscation::{ context::AnalysisContext, passes::{SentinelCondition, SentinelTaintRemovalPass}, @@ -246,7 +247,7 @@ impl Technique for ConfuserExAntiDebug { _ctx: &AnalysisContext, detection: &Detection, _assembly: &Arc, - ) -> Vec> { + ) -> Vec>> { let Some(findings) = detection.findings::() else { return Vec::new(); }; diff --git a/dotscope/src/deobfuscation/techniques/confuserex/statemachine.rs b/dotscope/src/deobfuscation/techniques/confuserex/statemachine.rs index 171841f2..6f46bff0 100644 --- a/dotscope/src/deobfuscation/techniques/confuserex/statemachine.rs +++ b/dotscope/src/deobfuscation/techniques/confuserex/statemachine.rs @@ -50,9 +50,9 @@ use crate::{ typesystem::CilType, }, prelude::FlowType, - utils::graph::NodeId, CilObject, }; +use analyssa::graph::NodeId; /// Information about a call site to a decryptor method. pub struct DetectedCallSite { diff --git a/dotscope/src/deobfuscation/techniques/detection.rs b/dotscope/src/deobfuscation/techniques/detection.rs index 9da3fabd..c78bc6b4 100644 --- a/dotscope/src/deobfuscation/techniques/detection.rs +++ b/dotscope/src/deobfuscation/techniques/detection.rs @@ -165,8 +165,9 @@ impl Detections { /// Returns the current generation counter. /// /// Incremented on any mutation that could affect technique sorting - /// (insert, merge, merge_all). Used by [`TechniqueRegistry`] to - /// invalidate its sorted cache. + /// (insert, merge, merge_all). Used by + /// [`TechniqueRegistry`](crate::deobfuscation::techniques::TechniqueRegistry) + /// to invalidate its sorted cache. #[must_use] pub fn generation(&self) -> u64 { self.generation diff --git a/dotscope/src/deobfuscation/techniques/generic/delegates.rs b/dotscope/src/deobfuscation/techniques/generic/delegates.rs index 92b549d7..675b23be 100644 --- a/dotscope/src/deobfuscation/techniques/generic/delegates.rs +++ b/dotscope/src/deobfuscation/techniques/generic/delegates.rs @@ -40,9 +40,9 @@ use std::{ use log::debug; use crate::{ - analysis::{SsaFunction, SsaOp, SsaVarId}, + analysis::{CilTarget, SsaFunction, SsaOp, SsaVarId}, cilassembly::CleanupRequest, - compiler::{PassPhase, SsaPass}, + compiler::{CompilerContext, PassPhase, SsaPass}, deobfuscation::{ context::AnalysisContext, passes::{ @@ -374,12 +374,12 @@ impl Technique for GenericDelegateProxy { ctx: &AnalysisContext, detection: &Detection, _assembly: &Arc, - ) -> Vec> { + ) -> Vec>> { let Some(combined) = detection.findings::() else { return Vec::new(); }; - let mut passes: Vec> = Vec::new(); + let mut passes: Vec>> = Vec::new(); // Delegate proxy resolution pass (emulation-based) let delegate = &combined.delegate; diff --git a/dotscope/src/deobfuscation/techniques/generic/flattening.rs b/dotscope/src/deobfuscation/techniques/generic/flattening.rs index 03bfd0ff..4baf140d 100644 --- a/dotscope/src/deobfuscation/techniques/generic/flattening.rs +++ b/dotscope/src/deobfuscation/techniques/generic/flattening.rs @@ -26,7 +26,8 @@ use std::{collections::HashMap, sync::Arc}; use crate::{ - compiler::{PassPhase, SsaPass}, + analysis::CilTarget, + compiler::{CompilerContext, PassPhase, SsaPass}, deobfuscation::{ context::AnalysisContext, passes::{CffDetector, CffReconstructionPass, Dispatcher, UnflattenConfig}, @@ -141,7 +142,7 @@ impl Technique for GenericFlattening { ctx: &AnalysisContext, detection: &Detection, _assembly: &Arc, - ) -> Vec> { + ) -> Vec>> { let cff_config = UnflattenConfig { max_states: ctx.config.unflattening.max_states_per_case, max_tree_depth: ctx.config.unflattening.max_trace_iterations, diff --git a/dotscope/src/deobfuscation/techniques/generic/opaquefields.rs b/dotscope/src/deobfuscation/techniques/generic/opaquefields.rs index ee318726..36c6a5ce 100644 --- a/dotscope/src/deobfuscation/techniques/generic/opaquefields.rs +++ b/dotscope/src/deobfuscation/techniques/generic/opaquefields.rs @@ -52,9 +52,9 @@ use std::{ }; use crate::{ - analysis::{SsaFunction, SsaOp, SsaVarId}, + analysis::{CilTarget, SsaFunction, SsaOp, SsaVarId}, cilassembly::CleanupRequest, - compiler::{PassPhase, SsaPass}, + compiler::{CompilerContext, PassPhase, SsaPass}, deobfuscation::{ context::AnalysisContext, passes::OpaqueFieldPredicatePass, @@ -436,7 +436,7 @@ impl Technique for GenericOpaquePredicates { ctx: &AnalysisContext, detection: &Detection, _assembly: &Arc, - ) -> Vec> { + ) -> Vec>> { let Some(pool) = ctx.template_pool.get().cloned() else { return Vec::new(); }; diff --git a/dotscope/src/deobfuscation/techniques/jiejienet/arrays.rs b/dotscope/src/deobfuscation/techniques/jiejienet/arrays.rs index 6c4c8dbd..20dfd8aa 100644 --- a/dotscope/src/deobfuscation/techniques/jiejienet/arrays.rs +++ b/dotscope/src/deobfuscation/techniques/jiejienet/arrays.rs @@ -40,8 +40,9 @@ use std::{any::Any, collections::HashMap, sync::Arc}; use crate::{ - assembly::{Immediate, Operand}, - compiler::{EventLog, PassPhase, SsaPass}, + analysis::CilTarget, + assembly::{Immediate, Instruction, Operand}, + compiler::{CompilerContext, EventLog, PassPhase, SsaPass}, deobfuscation::{ context::AnalysisContext, passes::jiejienet::ArrayInitRestorationPass, @@ -53,7 +54,7 @@ use crate::{ signatures::TypeSignature, tables::{FieldRvaRaw, TableId}, token::Token, - typesystem::wellknown, + typesystem::{wellknown, PointerSize}, }, CilObject, Error, Result, }; @@ -300,7 +301,7 @@ impl Technique for JiejieNetArrays { _ctx: &AnalysisContext, detection: &Detection, assembly: &Arc, - ) -> Vec> { + ) -> Vec>> { let Some(findings) = detection.findings::() else { return Vec::new(); }; @@ -724,7 +725,7 @@ fn emulate_delta_chain_cctor(assembly: &CilObject, cctor_token: Token) -> HashMa /// Finds the i32 value from the nearest `ldc.i4*` or `ldsfld` (Int32ValueContainer) /// instruction preceding `pos`. fn find_preceding_i32_value( - instructions: &[&crate::assembly::Instruction], + instructions: &[&Instruction], pos: usize, container_values: &HashMap, ) -> Option { @@ -756,7 +757,7 @@ fn find_preceding_i32_value( /// Finds the i32 index constant preceding a `call GetHandle` instruction /// that occurs before `pos`. fn find_preceding_get_handle_index( - instructions: &[&crate::assembly::Instruction], + instructions: &[&Instruction], pos: usize, accessor_token: Token, container_values: &HashMap, @@ -863,7 +864,7 @@ fn calculate_field_data_size(assembly: &CilObject, field_rid: u32) -> Result, - ) -> Vec> { + ) -> Vec>> { let Some(findings) = detection.findings::() else { return Vec::new(); }; diff --git a/dotscope/src/deobfuscation/techniques/jiejienet/resources.rs b/dotscope/src/deobfuscation/techniques/jiejienet/resources.rs index 55f93003..e959ed8f 100644 --- a/dotscope/src/deobfuscation/techniques/jiejienet/resources.rs +++ b/dotscope/src/deobfuscation/techniques/jiejienet/resources.rs @@ -47,10 +47,10 @@ use std::{any::Any, collections::HashMap, sync::Arc}; use log::debug; use crate::{ - analysis::{ConstValue, SsaFunction, SsaOp, SsaVarId}, + analysis::{CilTarget, ConstValue, SsaFunction, SsaOp, SsaVarId}, assembly::{Immediate, Operand}, cilassembly::GeneratorConfig, - compiler::{EventKind, EventLog, PassPhase, SsaPass}, + compiler::{CompilerContext, EventKind, EventLog, PassPhase, SsaPass}, deobfuscation::{ context::AnalysisContext, passes::jiejienet::{ResourceRestorationPass, ResourceTarget}, @@ -324,7 +324,7 @@ impl Technique for JiejieNetResources { _ctx: &AnalysisContext, detection: &Detection, _assembly: &Arc, - ) -> Vec> { + ) -> Vec>> { let Some(findings) = detection.findings::() else { return Vec::new(); }; diff --git a/dotscope/src/deobfuscation/techniques/jiejienet/strings.rs b/dotscope/src/deobfuscation/techniques/jiejienet/strings.rs index 3944e515..56e174de 100644 --- a/dotscope/src/deobfuscation/techniques/jiejienet/strings.rs +++ b/dotscope/src/deobfuscation/techniques/jiejienet/strings.rs @@ -31,7 +31,8 @@ use std::{any::Any, sync::Arc}; use crate::{ - compiler::{PassPhase, SsaPass}, + analysis::CilTarget, + compiler::{CompilerContext, PassPhase, SsaPass}, deobfuscation::{ context::AnalysisContext, passes::{StaticFieldResolutionPass, StringExtractor}, @@ -39,6 +40,7 @@ use crate::{ }, metadata::{ signatures::TypeSignature, + tables::TypeAttributes, token::Token, typesystem::{wellknown, CilType}, }, @@ -300,7 +302,7 @@ impl Technique for JiejieNetStrings { ctx: &AnalysisContext, detection: &Detection, _assembly: &Arc, - ) -> Vec> { + ) -> Vec>> { let Some(findings) = detection.findings::() else { return Vec::new(); }; @@ -368,7 +370,7 @@ fn is_byte_array_data_container(cil_type: &CilType) -> bool { let mut explicit_layout_count: usize = 0; for (_, nested_ref) in cil_type.nested_types.iter() { if let Some(nested) = nested_ref.upgrade() { - if nested.flags.layout() == crate::metadata::tables::TypeAttributes::EXPLICIT_LAYOUT { + if nested.flags.layout() == TypeAttributes::EXPLICIT_LAYOUT { explicit_layout_count = explicit_layout_count.saturating_add(1); } } @@ -379,15 +381,15 @@ fn is_byte_array_data_container(cil_type: &CilType) -> bool { #[cfg(test)] mod tests { + use std::sync::Arc; + + use super::*; + use crate::{ - deobfuscation::techniques::{ - jiejienet::strings::{JiejieNetStrings, StringFindings}, - Technique, - }, - emulation::{EmValue, EmulationOutcome}, + deobfuscation::techniques::Technique, + emulation::{EmValue, EmulationOutcome, ProcessBuilder}, test::helpers::load_sample, }; - use std::sync::Arc; #[test] fn test_detect_positive_strings_only() { @@ -462,8 +464,6 @@ mod tests { /// first (dependency), then the string class .cctor, and checks the results. #[test] fn test_emulate_string_cctor() { - use crate::emulation::ProcessBuilder; - let asm = load_sample("tests/samples/packers/jiejie/source/jiejie_strings_only.exe"); let technique = JiejieNetStrings; let detection = technique.detect(&asm); diff --git a/dotscope/src/deobfuscation/techniques/jiejienet/typeofs.rs b/dotscope/src/deobfuscation/techniques/jiejienet/typeofs.rs index fd45ade4..73ba5945 100644 --- a/dotscope/src/deobfuscation/techniques/jiejienet/typeofs.rs +++ b/dotscope/src/deobfuscation/techniques/jiejienet/typeofs.rs @@ -22,8 +22,9 @@ use std::{any::Any, sync::Arc}; use crate::{ + analysis::CilTarget, assembly::Operand, - compiler::{PassPhase, SsaPass}, + compiler::{CompilerContext, PassPhase, SsaPass}, deobfuscation::{ context::AnalysisContext, passes::jiejienet::TypeOfRestorationPass, @@ -197,7 +198,7 @@ impl Technique for JiejieNetTypeOf { _ctx: &AnalysisContext, detection: &Detection, assembly: &Arc, - ) -> Vec> { + ) -> Vec>> { let Some(findings) = detection.findings::() else { return Vec::new(); }; diff --git a/dotscope/src/deobfuscation/techniques/mod.rs b/dotscope/src/deobfuscation/techniques/mod.rs index 75d257f4..26c49764 100644 --- a/dotscope/src/deobfuscation/techniques/mod.rs +++ b/dotscope/src/deobfuscation/techniques/mod.rs @@ -73,8 +73,9 @@ pub(crate) use bitmono::StringFindings as BitMonoStringFindings; use std::sync::Arc; use crate::{ + analysis::CilTarget, cilassembly::CleanupRequest, - compiler::{EventLog, PassPhase, SsaPass}, + compiler::{CompilerContext, EventLog, PassPhase, SsaPass}, deobfuscation::{config::EngineConfig, context::AnalysisContext}, CilObject, Result, }; @@ -265,7 +266,7 @@ pub trait Technique: Send + Sync { _ctx: &AnalysisContext, _detection: &Detection, _assembly: &Arc, - ) -> Vec> { + ) -> Vec>> { Vec::new() } diff --git a/dotscope/src/deobfuscation/techniques/netreactor/antitamp.rs b/dotscope/src/deobfuscation/techniques/netreactor/antitamp.rs index 323a6e94..25983b89 100644 --- a/dotscope/src/deobfuscation/techniques/netreactor/antitamp.rs +++ b/dotscope/src/deobfuscation/techniques/netreactor/antitamp.rs @@ -55,7 +55,8 @@ use std::{any::Any, sync::Arc}; use crate::{ - compiler::{PassPhase, SsaPass}, + analysis::CilTarget, + compiler::{CompilerContext, PassPhase, SsaPass}, deobfuscation::{ context::AnalysisContext, passes::netreactor::TokenResolverPass, @@ -233,7 +234,7 @@ impl Technique for NetReactorAntiTamp { _ctx: &AnalysisContext, detection: &Detection, _assembly: &Arc, - ) -> Vec> { + ) -> Vec>> { let Some(findings) = detection.findings::() else { return Vec::new(); }; @@ -280,11 +281,14 @@ fn collect_runtime_method_tokens(assembly: &CilObject, runtime_type: Option Option { let path = format!("tests/samples/packers/netreactor/7.5.0/{name}"); - if !std::path::Path::new(&path).exists() { + if !Path::new(&path).exists() { eprintln!("Skipping test: sample not found at {path}"); return None; } diff --git a/dotscope/src/deobfuscation/techniques/netreactor/resources.rs b/dotscope/src/deobfuscation/techniques/netreactor/resources.rs index c6a2a339..5af304ec 100644 --- a/dotscope/src/deobfuscation/techniques/netreactor/resources.rs +++ b/dotscope/src/deobfuscation/techniques/netreactor/resources.rs @@ -88,16 +88,22 @@ use std::{any::Any, collections::HashSet, sync::Arc}; use crate::{ + analysis::CilTarget, assembly::Operand, cilassembly::GeneratorConfig, - compiler::{EventKind, EventLog}, - deobfuscation::techniques::{ - netreactor::{helpers::find_resources_referenced_by_methods, hooks}, - Detection, Detections, Evidence, Technique, TechniqueCategory, WorkingAssembly, + compiler::{CompilerContext, EventKind, EventLog, PassPhase, SsaPass}, + deobfuscation::{ + context::AnalysisContext, + passes::netreactor::ResourceShimRewritePass, + techniques::{ + netreactor::{helpers::find_resources_referenced_by_methods, hooks}, + Detection, Detections, Evidence, Technique, TechniqueCategory, WorkingAssembly, + }, }, - emulation::{EmulationOutcome, ProcessBuilder}, + emulation::{CapturedAssembly, EmulationOutcome, ProcessBuilder}, error::Error, metadata::{ + method::MethodRc, signatures::TypeSignature, tables::{ManifestResourceBuilder, TableId, TypeRefRaw}, token::Token, @@ -269,10 +275,10 @@ impl Technique for NetReactorResources { fn create_pass( &self, - _ctx: &crate::deobfuscation::context::AnalysisContext, + _ctx: &AnalysisContext, detection: &Detection, _assembly: &Arc, - ) -> Vec> { + ) -> Vec>> { let Some(findings) = detection.findings::() else { return Vec::new(); }; @@ -289,22 +295,20 @@ impl Technique for NetReactorResources { // NOT the reflective `Assembly.Load` wrappers (those are // emulator-only and are intercepted by the runtime hook in // `byte_transform`). - vec![Box::new( - crate::deobfuscation::passes::netreactor::ResourceShimRewritePass::new( - findings - .get_manifest_resource_names_shim_tokens - .iter() - .copied(), - findings.lazy_init_token, - findings.bcl_get_manifest_resource_names, - ), - )] + vec![Box::new(ResourceShimRewritePass::new( + findings + .get_manifest_resource_names_shim_tokens + .iter() + .copied(), + findings.lazy_init_token, + findings.bcl_get_manifest_resource_names, + ))] } - fn ssa_phase(&self) -> Option { + fn ssa_phase(&self) -> Option { // Same phase as the other NR shim folders — runs alongside the // value-folding stage so the rewrites land before final cleanup. - Some(crate::compiler::PassPhase::Value) + Some(PassPhase::Value) } fn byte_transform( @@ -530,7 +534,7 @@ impl Technique for NetReactorResources { /// (offset out of range, external implementation) and empty bodies are /// skipped — they would not round-trip back into the deobfuscated /// output. -fn harvest_resources(captured: &[crate::emulation::CapturedAssembly]) -> Vec<(String, Vec)> { +fn harvest_resources(captured: &[CapturedAssembly]) -> Vec<(String, Vec)> { let mut out = Vec::new(); for cap in captured { let parsed = match CilObject::from_mem_with_validation( @@ -706,11 +710,7 @@ fn find_lazy_init(assembly: &CilObject, cil_type: &CilTypeRc) -> Option { } /// Returns true if `method`'s IL matches the lazy-init shape exactly. -fn is_lazy_init_body( - assembly: &CilObject, - method: &crate::metadata::method::MethodRc, - cil_type: &CilTypeRc, -) -> bool { +fn is_lazy_init_body(assembly: &CilObject, method: &MethodRc, cil_type: &CilTypeRc) -> bool { // Strip nop/br.s noise to match the canonical shape regardless of whether // the protector inserted padding. let instrs: Vec<_> = method @@ -1084,10 +1084,7 @@ fn find_assembly_load_shim_methods(assembly: &CilObject, cil_type: &CilTypeRc) - /// method that already passed the `(uint8[]) -> object` signature gate is /// strong enough — no legitimate static `(byte[]) -> object` method on a /// resolver-shape type does both. -fn is_assembly_load_reflection_shim( - method: &crate::metadata::method::MethodRc, - assembly: &CilObject, -) -> bool { +fn is_assembly_load_reflection_shim(method: &MethodRc, assembly: &CilObject) -> bool { let mut saw_assembly_token = false; let mut saw_invoke_call = false; for instr in method.instructions() { @@ -1132,9 +1129,14 @@ fn encrypted_resource_names(assembly: &CilObject, findings: &ResourceFindings) - #[cfg(test)] mod tests { - use std::sync::Arc; - use super::*; + + use std::{ + env, + path::{Path, PathBuf}, + sync::Arc, + }; + use crate::{ emulation::{EmulationOutcome, ProcessBuilder, TracingConfig}, metadata::validation::ValidationConfig, @@ -1142,7 +1144,7 @@ mod tests { fn try_load_sample(name: &str) -> Option { let path = format!("tests/samples/packers/netreactor/7.5.0/{name}"); - if !std::path::Path::new(&path).exists() { + if !Path::new(&path).exists() { eprintln!("Skipping test: sample not found at {path}"); return None; } @@ -1274,7 +1276,7 @@ mod tests { ); let cilobject = Arc::new(assembly); - let trace_path = std::env::var("NR_TRACE").ok().map(std::path::PathBuf::from); + let trace_path = env::var("NR_TRACE").ok().map(PathBuf::from); let mut builder = ProcessBuilder::new() .assembly_arc(Arc::clone(&cilobject)) diff --git a/dotscope/src/emulation/filesystem.rs b/dotscope/src/emulation/filesystem.rs index 499320e2..cfc98b89 100644 --- a/dotscope/src/emulation/filesystem.rs +++ b/dotscope/src/emulation/filesystem.rs @@ -194,9 +194,11 @@ impl Default for VirtualFs { #[cfg(test)] mod tests { + use std::io::Write; + use cowfile::CowFile; - use crate::emulation::filesystem::VirtualFs; + use super::*; #[test] fn test_normalize_path() { @@ -237,7 +239,6 @@ mod tests { #[test] fn test_fork_from_disk() { - use std::io::Write; let mut tmpfile = tempfile::NamedTempFile::new().unwrap(); tmpfile.write_all(&[0xDE, 0xAD]).unwrap(); tmpfile.flush().unwrap(); diff --git a/dotscope/src/emulation/memory/heap/mod.rs b/dotscope/src/emulation/memory/heap/mod.rs index ceb9d3d9..dfd62468 100644 --- a/dotscope/src/emulation/memory/heap/mod.rs +++ b/dotscope/src/emulation/memory/heap/mod.rs @@ -1051,7 +1051,7 @@ impl ManagedHeap { /// /// This enables virtual dispatch and type checks on BCL wrapper objects /// that were allocated via factory methods. The token is stored in - /// `original_types` and takes priority in [`get_type_token`]. + /// `original_types` and takes priority in [`Self::get_type_token`]. /// /// # Errors /// diff --git a/dotscope/src/emulation/process/builder.rs b/dotscope/src/emulation/process/builder.rs index f59f00b8..7eaeada2 100644 --- a/dotscope/src/emulation/process/builder.rs +++ b/dotscope/src/emulation/process/builder.rs @@ -136,7 +136,7 @@ fn populate_fieldrva_statics(assembly: &CilObject, address_space: &AddressSpace) let types = assembly.types(); let file = assembly.file(); let pe_data = file.data(); - let ptr_size = PointerSize::from_pe(file.pe().is_64bit); + let ptr_size = PointerSize::from_is_64bit(file.pe().is_64bit); for row in fieldrva_table { if row.rva == 0 { diff --git a/dotscope/src/emulation/runtime/appdomain.rs b/dotscope/src/emulation/runtime/appdomain.rs index 756f59ef..c80f7b18 100644 --- a/dotscope/src/emulation/runtime/appdomain.rs +++ b/dotscope/src/emulation/runtime/appdomain.rs @@ -440,7 +440,8 @@ impl AppDomainState { /// Registers a parsed assembly loaded at runtime (e.g., via `Assembly.Load(byte[])`). /// /// Returns the index of the newly registered assembly, which can be stored - /// in [`ThreadCallFrame::assembly_index`] to associate frames with their + /// in [`ThreadCallFrame::assembly_index`](crate::emulation::ThreadCallFrame::assembly_index) + /// to associate frames with their /// originating assembly. /// /// # Arguments diff --git a/dotscope/src/emulation/runtime/bcl/runtime.rs b/dotscope/src/emulation/runtime/bcl/runtime.rs index a332a23b..e67a2085 100644 --- a/dotscope/src/emulation/runtime/bcl/runtime.rs +++ b/dotscope/src/emulation/runtime/bcl/runtime.rs @@ -918,6 +918,7 @@ fn unsafe_sizeof_pre(ctx: &HookContext<'_>, _thread: &mut EmulationThread) -> Pr let size = match ctx.pointer_size { PointerSize::Bit64 => 8, PointerSize::Bit32 => 4, + PointerSize::Bit8 | PointerSize::Bit16 | PointerSize::Bit128 => 0, }; PreHookResult::Bypass(Some(EmValue::I32(size))) } diff --git a/dotscope/src/emulation/value/emvalue.rs b/dotscope/src/emulation/value/emvalue.rs index e8032453..e7eda6ff 100644 --- a/dotscope/src/emulation/value/emvalue.rs +++ b/dotscope/src/emulation/value/emvalue.rs @@ -1533,10 +1533,16 @@ impl EmValue { (EmValue::NativeInt(v), _) => match ptr_size { PointerSize::Bit32 => LeBytes::from_4((*v as i32).to_le_bytes()), PointerSize::Bit64 => LeBytes::from_8(v.to_le_bytes()), + PointerSize::Bit8 => LeBytes::from_byte((*v as i8).to_le_bytes()[0]), + PointerSize::Bit16 => LeBytes::from_2((*v as i16).to_le_bytes()), + PointerSize::Bit128 => LeBytes::zeroed(8), }, (EmValue::NativeUInt(v), _) => match ptr_size { PointerSize::Bit32 => LeBytes::from_4((*v as u32).to_le_bytes()), PointerSize::Bit64 => LeBytes::from_8(v.to_le_bytes()), + PointerSize::Bit8 => LeBytes::from_byte((*v as u8).to_le_bytes()[0]), + PointerSize::Bit16 => LeBytes::from_2((*v as u16).to_le_bytes()), + PointerSize::Bit128 => LeBytes::zeroed(8), }, (EmValue::Bool(v), _) => LeBytes::from_byte(u8::from(*v)), (EmValue::Char(v), _) => LeBytes::from_2((*v as u16).to_le_bytes()), @@ -1626,6 +1632,36 @@ impl EmValue { EmValue::NativeUInt(u64::from_le_bytes(arr)) } } + PointerSize::Bit8 => { + let byte = bytes.first().copied().unwrap_or(0); + if matches!(flavor, CilFlavor::I) { + EmValue::NativeInt(i64::from(byte as i8)) + } else { + EmValue::NativeUInt(u64::from(byte)) + } + } + PointerSize::Bit16 => { + let arr: [u8; 2] = bytes + .get(..2.min(bytes.len())) + .and_then(|s| s.try_into().ok()) + .unwrap_or([0, 0]); + if matches!(flavor, CilFlavor::I) { + EmValue::NativeInt(i64::from(i16::from_le_bytes(arr))) + } else { + EmValue::NativeUInt(u64::from(u16::from_le_bytes(arr))) + } + } + PointerSize::Bit128 => { + let arr: [u8; 16] = bytes + .get(..16.min(bytes.len())) + .and_then(|s| s.try_into().ok()) + .unwrap_or([0; 16]); + if matches!(flavor, CilFlavor::I) { + EmValue::NativeInt(i128::from_le_bytes(arr) as i64) + } else { + EmValue::NativeUInt(u128::from_le_bytes(arr) as u64) + } + } }, _ => EmValue::I32(0), } diff --git a/dotscope/src/emulation/value/ops/binary.rs b/dotscope/src/emulation/value/ops/binary.rs index e8ba3800..2fddb58c 100644 --- a/dotscope/src/emulation/value/ops/binary.rs +++ b/dotscope/src/emulation/value/ops/binary.rs @@ -920,7 +920,7 @@ mod tests { use crate::{ emulation::{ engine::EmulationError, - value::{BinaryOp, EmValue}, + value::{BinaryOp, EmValue, HeapRef}, }, metadata::typesystem::PointerSize, Error, @@ -1211,7 +1211,6 @@ mod tests { #[test] fn test_bitwise_and_objectref_i32() { - use crate::emulation::value::HeapRef; let obj = EmValue::ObjectRef(HeapRef::new(0xDEAD_BEEF)); let mask = EmValue::I32(0x0000_FFFF_u32 as i32); let result = obj @@ -1230,7 +1229,6 @@ mod tests { #[test] fn test_bitwise_xor_objectref_i32() { - use crate::emulation::value::HeapRef; let obj = EmValue::ObjectRef(HeapRef::new(42)); let key = EmValue::I32(99); let result = obj diff --git a/dotscope/src/error.rs b/dotscope/src/error.rs index 31bba1a0..87642776 100644 --- a/dotscope/src/error.rs +++ b/dotscope/src/error.rs @@ -93,6 +93,8 @@ //! propagation in concurrent parsing and analysis operations. //! +use std::io; + use thiserror::Error; #[cfg(feature = "emulation")] @@ -283,7 +285,7 @@ pub enum Error { /// Wraps standard I/O errors that can occur during file operations /// such as reading from disk, permission issues, or filesystem errors. #[error("{0}")] - Io(#[from] std::io::Error), + Io(#[from] io::Error), /// Other errors that don't fit specific categories. /// @@ -695,3 +697,15 @@ impl From for Error { Error::Emulation(Box::new(err)) } } + +impl From for Error { + fn from(err: analyssa::GraphError) -> Self { + Error::GraphError(err.0) + } +} + +impl From for analyssa::Error { + fn from(err: Error) -> Self { + analyssa::Error::new(err.to_string()) + } +} diff --git a/dotscope/src/formatting/mod.rs b/dotscope/src/formatting/mod.rs index 34df224a..f3dd5bed 100644 --- a/dotscope/src/formatting/mod.rs +++ b/dotscope/src/formatting/mod.rs @@ -1,8 +1,9 @@ //! ILDasm-compatible CIL disassembly formatter. //! //! This module provides a comprehensive formatter that produces ILDasm/ILAsm-compatible -//! text output from .NET assembly metadata. The [`IlFormatter`] is the main entry point, -//! with [`FormatterOptions`] controlling the level of detail in the output. +//! text output from .NET assembly metadata. [`crate::formatting::IlFormatter`] is the main +//! entry point, with [`crate::formatting::FormatterOptions`] controlling the level of detail +//! in the output. //! //! # Usage //! diff --git a/dotscope/src/lib.rs b/dotscope/src/lib.rs index 932c9ef6..4595c9c8 100644 --- a/dotscope/src/lib.rs +++ b/dotscope/src/lib.rs @@ -15,16 +15,12 @@ // SPDX-License-Identifier: Apache-2.0 #![doc(html_no_source)] -// This crate is used for malware analysis: every input byte is -// adversarial and must not be allowed to panic the parser. -#![deny( - missing_docs, - clippy::unwrap_used, - clippy::expect_used, - clippy::panic, - clippy::arithmetic_side_effects, - clippy::indexing_slicing -)] +// The `missing_docs`, `clippy::unwrap_used`, `clippy::expect_used`, +// `clippy::panic`, `clippy::arithmetic_side_effects`, and +// `clippy::indexing_slicing` lints are declared in `Cargo.toml` under +// `[lints]` so they enforce on every build regardless of the consuming +// workspace. dotscope is used in malware-analysis pipelines where every +// input byte is adversarial and the parser must not panic. #![cfg_attr( test, allow( diff --git a/dotscope/src/metadata/cilobject.rs b/dotscope/src/metadata/cilobject.rs index a5cc1fbc..f70c7546 100644 --- a/dotscope/src/metadata/cilobject.rs +++ b/dotscope/src/metadata/cilobject.rs @@ -1489,7 +1489,8 @@ impl CilObject { /// Returns the method definition for the given token, if it exists. /// /// This is a convenience accessor that looks up a method by its metadata token - /// and returns a cloned reference-counted pointer to the [`Method`] object. It + /// and returns a cloned reference-counted pointer to the + /// [`Method`](crate::metadata::method::Method) object. It /// eliminates the need to call [`methods()`](Self::methods), unwrap the `Entry` /// guard, and clone the value manually. /// @@ -1499,7 +1500,8 @@ impl CilObject { /// /// # Returns /// - /// A reference-counted [`Method`] if a method with the given token exists, `None` otherwise. + /// A reference-counted [`Method`](crate::metadata::method::Method) if a method with the + /// given token exists, `None` otherwise. /// /// # Examples /// diff --git a/dotscope/src/metadata/dependencies/graph.rs b/dotscope/src/metadata/dependencies/graph.rs index cf1fd2b9..2e5518b8 100644 --- a/dotscope/src/metadata/dependencies/graph.rs +++ b/dotscope/src/metadata/dependencies/graph.rs @@ -11,9 +11,10 @@ use std::sync::{ use dashmap::DashMap; +use analyssa::graph::{algorithms, DirectedGraph, IndexedGraph, NodeId}; + use crate::{ metadata::{dependencies::AssemblyDependency, identity::AssemblyIdentity}, - utils::graph::{algorithms, DirectedGraph, IndexedGraph, NodeId}, Error, Result, }; @@ -751,11 +752,13 @@ impl Default for AssemblyDependencyGraph { #[cfg(test)] mod tests { use super::*; + + use std::{collections::HashMap, thread}; + use crate::{ metadata::dependencies::DependencyType, test::helpers::dependencies::{create_test_dependency, create_test_identity}, }; - use std::thread; #[test] fn test_dependency_graph_creation() { @@ -980,7 +983,7 @@ mod tests { assert_eq!(order.len(), 6); // Get positions - let positions: std::collections::HashMap = order + let positions: HashMap = order .iter() .enumerate() .map(|(i, id)| (id.name.clone(), i)) diff --git a/dotscope/src/metadata/loader/graph.rs b/dotscope/src/metadata/loader/graph.rs index 0502882a..ecf7025f 100644 --- a/dotscope/src/metadata/loader/graph.rs +++ b/dotscope/src/metadata/loader/graph.rs @@ -21,12 +21,15 @@ //! - Construction: Single-threaded only //! - Generated plans: Thread-safe for parallel execution //! -use std::collections::{HashMap, HashSet}; -use std::fmt::Write; +use std::{ + collections::{HashMap, HashSet}, + fmt::Write, +}; + +use analyssa::graph::IndexedGraph; use crate::{ metadata::{loader::MetadataLoader, tables::TableId}, - utils::graph::IndexedGraph, Error::GraphError, Result, }; @@ -134,12 +137,8 @@ impl<'a> LoaderGraph<'a> { /// Returns `GraphError` if a loader depends on a table without a registered loader, /// or if circular dependencies are detected (debug builds). pub fn build_relationships(&mut self) -> Result<()> { - self.dependencies - .values_mut() - .for_each(std::collections::HashSet::clear); - self.dependents - .values_mut() - .for_each(std::collections::HashSet::clear); + self.dependencies.values_mut().for_each(HashSet::clear); + self.dependents.values_mut().for_each(HashSet::clear); for (loader_key, loader) in &self.loaders { for dep_table_id in loader.dependencies() { diff --git a/dotscope/src/metadata/resolver.rs b/dotscope/src/metadata/resolver.rs index 9758e0a5..65afa55c 100644 --- a/dotscope/src/metadata/resolver.rs +++ b/dotscope/src/metadata/resolver.rs @@ -1,7 +1,8 @@ //! Cross-table token resolver for .NET metadata. //! -//! Provides [`TokenResolver`], a lightweight borrowing wrapper over [`CilObject`] that -//! normalizes token references across metadata tables. This includes: +//! Provides [`TokenResolver`](crate::metadata::resolver::TokenResolver), a lightweight +//! borrowing wrapper over [`CilObject`](crate::CilObject) that normalizes token references +//! across metadata tables. This includes: //! //! - **TypeRef → TypeDef**: Resolves external type references to their local definitions //! - **MemberRef → MethodDef**: Resolves member references to locally defined methods @@ -17,7 +18,7 @@ //! //! # Usage //! -//! Obtained via [`CilObject::resolver()`]: +//! Obtained via [`CilObject::resolver()`](crate::CilObject::resolver): //! //! ```rust,no_run //! use dotscope::CilObject; diff --git a/dotscope/src/metadata/typesystem/base.rs b/dotscope/src/metadata/typesystem/base.rs index 08b48cc9..1ec1779b 100644 --- a/dotscope/src/metadata/typesystem/base.rs +++ b/dotscope/src/metadata/typesystem/base.rs @@ -83,10 +83,9 @@ //! - [`crate::metadata::typesystem`] - Higher-level type system operations use std::{ - fmt, + fmt::{self, Debug, Formatter}, hash::{Hash, Hasher}, - sync::Arc, - sync::Weak, + sync::{Arc, Weak}, }; use crate::{ @@ -684,8 +683,8 @@ pub enum CilTypeReference { None, } -impl std::fmt::Debug for CilTypeReference { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { +impl Debug for CilTypeReference { + fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result { match self { CilTypeReference::TypeRef(_) => write!(f, "CilTypeReference::TypeRef(...)"), CilTypeReference::TypeDef(_) => write!(f, "CilTypeReference::TypeDef(...)"), @@ -1765,76 +1764,11 @@ impl From<&TypeSignature> for CilFlavor { /// Target pointer width for native int/uint types. /// -/// Per ECMA-335, `native int` and `native uint` (`System.IntPtr` / `System.UIntPtr`) -/// are pointer-sized: 4 bytes on PE32 (32-bit) targets, 8 bytes on PE32+ (64-bit) targets. +/// Re-exported from `analyssa::PointerSize`. The implementation moved to analyssa +/// so the IR core's generic arithmetic methods (which need pointer-width +/// masking) don't have to depend on `dotscope::metadata`. /// -/// Derived from the PE header: PE32 → `Bit32`, PE32+ → `Bit64`. -#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] -pub enum PointerSize { - /// 32-bit target (4-byte pointers) - Bit32, - /// 64-bit target (8-byte pointers) - Bit64, -} - -impl PointerSize { - /// Creates a `PointerSize` from the PE header's bitness flag. - /// - /// PE32 binaries are 32-bit (`Bit32`), PE32+ binaries are 64-bit (`Bit64`). - /// - /// # Arguments - /// - /// * `is_64bit` - `true` for PE32+ (64-bit), `false` for PE32 (32-bit) - #[must_use] - pub fn from_pe(is_64bit: bool) -> Self { - if is_64bit { - Self::Bit64 - } else { - Self::Bit32 - } - } - - /// Returns the pointer size in bytes. - #[must_use] - pub fn bytes(self) -> usize { - match self { - Self::Bit32 => 4, - Self::Bit64 => 8, - } - } - - /// Returns the pointer size in bits. - #[must_use] - pub fn bits(self) -> u32 { - match self { - Self::Bit32 => 32, - Self::Bit64 => 64, - } - } - - /// Masks and sign-extends a signed value to the target pointer width. - #[must_use] - pub fn mask_signed(self, value: i64) -> i64 { - match self { - Self::Bit32 => { - #[allow(clippy::cast_possible_truncation)] - let truncated = value as i32; - i64::from(truncated) - } - Self::Bit64 => value, - } - } - - /// Masks and zero-extends an unsigned value to the target pointer width. - #[must_use] - pub fn mask_unsigned(self, value: u64) -> u64 { - match self { - Self::Bit32 => { - #[allow(clippy::cast_possible_truncation)] - let truncated = value as u32; - u64::from(truncated) - } - Self::Bit64 => value, - } - } -} +/// Per ECMA-335, `native int` and `native uint` (`System.IntPtr` / +/// `System.UIntPtr`) are pointer-sized: 4 bytes on PE32, 8 bytes on PE32+. +/// Use [`PointerSize::from_pe`] to derive from the PE header. +pub use analyssa::PointerSize; diff --git a/dotscope/src/metadata/vtfixup.rs b/dotscope/src/metadata/vtfixup.rs index d125eccd..7505d057 100644 --- a/dotscope/src/metadata/vtfixup.rs +++ b/dotscope/src/metadata/vtfixup.rs @@ -15,9 +15,9 @@ //! //! # Key Components //! -//! - [`VtFixupEntry`] - A single parsed VTableFixup directory entry with RVA, flags, and tokens -//! - [`VtFixupContext`] - Pre-computed context containing all entries plus method-to-slot and export maps -//! - [`parse`] - Entry point that reads and correlates the VTableFixup directory +//! - [`VtFixupEntry`](crate::metadata::vtfixup::VtFixupEntry) - A single parsed VTableFixup directory entry with RVA, flags, and tokens +//! - [`VtFixupContext`](crate::metadata::vtfixup::VtFixupContext) - Pre-computed context containing all entries plus method-to-slot and export maps +//! - [`parse`](crate::metadata::vtfixup::parse) - Entry point that reads and correlates the VTableFixup directory //! //! # Usage Examples //! @@ -39,15 +39,17 @@ //! //! # Thread Safety //! -//! All types in this module are [`Send`] and [`Sync`]. The parsed [`VtFixupContext`] -//! is immutable after construction and safe to share across threads. +//! All types in this module are [`Send`] and [`Sync`]. The parsed +//! [`VtFixupContext`](crate::metadata::vtfixup::VtFixupContext) is immutable after +//! construction and safe to share across threads. //! //! # Integration //! //! This module integrates with: //! - [`crate::metadata::cor20header`] - Source of the VTableFixup directory RVA and size //! - [`crate::metadata::exports`] - Native PE export table for correlating exports with vtable slots -//! - [`crate::formatting::vtfixup`] - Rendering of parsed data as ILAsm directives +//! - [`crate::formatting::IlFormatter`] - Rendering of parsed data as ILAsm directives +//! (the actual `formatting::vtfixup` helper module is private; reach it via `IlFormatter`). use std::collections::HashMap; diff --git a/dotscope/src/test/analysis/templates.rs b/dotscope/src/test/analysis/templates.rs index 50169cbf..b2bfe44c 100644 --- a/dotscope/src/test/analysis/templates.rs +++ b/dotscope/src/test/analysis/templates.rs @@ -563,7 +563,7 @@ public class Program /// All defined test cases with their expected properties. pub static ANALYSIS_TEST_CASES: &[AnalysisTestCase] = &[ - // ========== CFG TESTS: SEQUENTIAL ========== + // CFG tests: sequential. AnalysisTestCase { name: "cfg_sequential_add", method_name: "Add", @@ -630,7 +630,7 @@ pub static ANALYSIS_TEST_CASES: &[AnalysisTestCase] = &[ callgraph: None, dataflow: None, }, - // ========== CFG TESTS: CONDITIONAL ========== + // CFG tests: conditional. AnalysisTestCase { name: "cfg_if_then", method_name: "IfThen", @@ -741,7 +741,7 @@ pub static ANALYSIS_TEST_CASES: &[AnalysisTestCase] = &[ callgraph: None, dataflow: None, }, - // ========== CFG TESTS: LOOPS ========== + // CFG tests: loops. AnalysisTestCase { name: "cfg_while_loop", method_name: "WhileLoop", @@ -874,7 +874,7 @@ pub static ANALYSIS_TEST_CASES: &[AnalysisTestCase] = &[ callgraph: None, dataflow: None, }, - // ========== CFG TESTS: SWITCH ========== + // CFG tests: switch. AnalysisTestCase { name: "cfg_simple_switch", method_name: "SimpleSwitch", @@ -919,7 +919,7 @@ pub static ANALYSIS_TEST_CASES: &[AnalysisTestCase] = &[ callgraph: None, dataflow: None, }, - // ========== SSA TESTS ========== + // SSA tests. AnalysisTestCase { name: "ssa_phi_required", method_name: "PhiRequired", @@ -1052,7 +1052,7 @@ pub static ANALYSIS_TEST_CASES: &[AnalysisTestCase] = &[ callgraph: None, dataflow: None, }, - // ========== CALL GRAPH TESTS ========== + // Call graph tests. AnalysisTestCase { name: "callgraph_leaf", method_name: "LeafMethod", @@ -1215,7 +1215,7 @@ pub static ANALYSIS_TEST_CASES: &[AnalysisTestCase] = &[ }), dataflow: None, }, - // ========== DATA FLOW TESTS ========== + // Data flow tests. AnalysisTestCase { name: "dataflow_constant_prop", method_name: "ConstantProp", diff --git a/dotscope/src/utils/bitset.rs b/dotscope/src/utils/bitset.rs deleted file mode 100644 index acc11218..00000000 --- a/dotscope/src/utils/bitset.rs +++ /dev/null @@ -1,384 +0,0 @@ -//! A bit vector for efficient set operations. -//! -//! This module provides a compact bit set implementation optimized for -//! set operations commonly used in data flow analysis and other algorithms -//! that track sets of entities identified by small integers. -//! -//! # Features -//! -//! - Efficient storage: 64 elements per word -//! - Set operations: union, intersection, difference -//! - Iteration over set elements -//! - Clone-on-write friendly design -//! -//! # Example -//! -//! ```rust,ignore -//! use dotscope::utils::BitSet; -//! -//! let mut set = BitSet::new(100); -//! set.insert(0); -//! set.insert(50); -//! set.insert(99); -//! -//! assert!(set.contains(50)); -//! assert_eq!(set.count(), 3); -//! -//! for idx in set.iter() { -//! println!("Set contains: {}", idx); -//! } -//! ``` - -/// A bit vector for efficient set operations. -/// -/// This is commonly used for analyses that track sets of definitions, -/// variables, or other entities identified by small integers. -#[derive(Clone, Default, PartialEq, Eq, Hash)] -pub struct BitSet { - /// The bits, stored as a vector of words. - words: Vec, - /// The number of bits in the set. - len: usize, -} - -impl BitSet { - /// Creates a new empty bit set with the given capacity. - #[must_use] - pub fn new(capacity: usize) -> Self { - let num_words = capacity.div_ceil(64); - Self { - words: vec![0; num_words], - len: capacity, - } - } - - /// Creates a new bit set with all bits set. - #[must_use] - pub fn full(capacity: usize) -> Self { - let num_words = capacity.div_ceil(64); - let mut words = vec![u64::MAX; num_words]; - - // Clear the excess bits in the last word - if !capacity.is_multiple_of(64) { - if let Some(last) = words.last_mut() { - *last = (1u64 << (capacity % 64)).saturating_sub(1); - } - } - - Self { - words, - len: capacity, - } - } - - /// Returns the capacity of this bit set. - #[must_use] - pub const fn len(&self) -> usize { - self.len - } - - /// Returns `true` if the bit set has no bits set. - #[must_use] - pub fn is_empty(&self) -> bool { - self.words.iter().all(|&w| w == 0) - } - - /// Sets the bit at the given index. - /// - /// Returns `true` if the bit was newly set (was previously unset). - /// - /// # Panics - /// - /// Panics if `index >= self.len()`. - pub fn insert(&mut self, index: usize) -> bool { - assert!(index < self.len, "index out of bounds"); - let word = index / 64; - let bit = index % 64; - let mask = 1u64 << bit; - let Some(slot) = self.words.get_mut(word) else { - return false; - }; - let was_set = *slot & mask != 0; - *slot |= mask; - !was_set - } - - /// Clears the bit at the given index. - /// - /// # Panics - /// - /// Panics if `index >= self.len()`. - pub fn remove(&mut self, index: usize) { - assert!(index < self.len, "index out of bounds"); - let word = index / 64; - let bit = index % 64; - if let Some(slot) = self.words.get_mut(word) { - *slot &= !(1u64 << bit); - } - } - - /// Returns `true` if the bit at the given index is set. - /// - /// # Panics - /// - /// Panics if `index >= self.len()`. - #[must_use] - pub fn contains(&self, index: usize) -> bool { - assert!(index < self.len, "index out of bounds"); - let word = index / 64; - let bit = index % 64; - self.words - .get(word) - .is_some_and(|w| (w & (1u64 << bit)) != 0) - } - - /// Returns the number of bits set. - #[must_use] - pub fn count(&self) -> usize { - self.words.iter().map(|w| w.count_ones() as usize).sum() - } - - /// Clears all bits. - pub fn clear(&mut self) { - for word in &mut self.words { - *word = 0; - } - } - - /// Sets all bits. - pub fn fill(&mut self) { - for word in &mut self.words { - *word = u64::MAX; - } - // Clear excess bits in last word - if !self.len.is_multiple_of(64) { - if let Some(last) = self.words.last_mut() { - *last = (1u64 << (self.len % 64)).saturating_sub(1); - } - } - } - - /// Computes the union with another bit set (in place). - /// - /// Returns `true` if `self` changed. - pub fn union_with(&mut self, other: &Self) -> bool { - assert_eq!(self.len, other.len, "bit sets must have same length"); - let mut changed = false; - for (a, b) in self.words.iter_mut().zip(other.words.iter()) { - let old = *a; - *a |= *b; - changed |= old != *a; - } - changed - } - - /// Computes the intersection with another bit set (in place). - /// - /// Returns `true` if `self` changed. - pub fn intersect_with(&mut self, other: &Self) -> bool { - assert_eq!(self.len, other.len, "bit sets must have same length"); - let mut changed = false; - for (a, b) in self.words.iter_mut().zip(other.words.iter()) { - let old = *a; - *a &= *b; - changed |= old != *a; - } - changed - } - - /// Computes the difference with another bit set (in place). - /// - /// Removes all bits that are set in `other` from `self`. - /// Returns `true` if `self` changed. - pub fn difference_with(&mut self, other: &Self) -> bool { - assert_eq!(self.len, other.len, "bit sets must have same length"); - let mut changed = false; - for (a, b) in self.words.iter_mut().zip(other.words.iter()) { - let old = *a; - *a &= !*b; - changed |= old != *a; - } - changed - } - - /// Returns an iterator over the indices of set bits. - pub fn iter(&self) -> BitSetIter<'_> { - BitSetIter { - set: self, - word_idx: 0, - bit_idx: 0, - } - } -} - -impl std::fmt::Debug for BitSet { - fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{{")?; - let mut first = true; - for i in self.iter() { - if !first { - write!(f, ", ")?; - } - write!(f, "{i}")?; - first = false; - } - write!(f, "}}") - } -} - -/// Iterator over the set bits in a `BitSet`. -pub struct BitSetIter<'a> { - set: &'a BitSet, - word_idx: usize, - bit_idx: usize, -} - -impl Iterator for BitSetIter<'_> { - type Item = usize; - - fn next(&mut self) -> Option { - while self.word_idx < self.set.words.len() { - let word = *self.set.words.get(self.word_idx)?; - while self.bit_idx < 64 { - let idx = self - .word_idx - .checked_mul(64) - .and_then(|v| v.checked_add(self.bit_idx))?; - if idx >= self.set.len { - return None; - } - let bit = self.bit_idx; - self.bit_idx = self.bit_idx.saturating_add(1); - if (word & (1u64 << bit)) != 0 { - return Some(idx); - } - } - self.word_idx = self.word_idx.saturating_add(1); - self.bit_idx = 0; - } - None - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_bitset_basic() { - let mut bs = BitSet::new(100); - assert!(bs.is_empty()); - assert_eq!(bs.count(), 0); - - bs.insert(0); - bs.insert(50); - bs.insert(99); - - assert!(!bs.is_empty()); - assert_eq!(bs.count(), 3); - assert!(bs.contains(0)); - assert!(bs.contains(50)); - assert!(bs.contains(99)); - assert!(!bs.contains(1)); - } - - #[test] - fn test_bitset_remove() { - let mut bs = BitSet::new(100); - bs.insert(42); - assert!(bs.contains(42)); - - bs.remove(42); - assert!(!bs.contains(42)); - } - - #[test] - fn test_bitset_full() { - let bs = BitSet::full(100); - assert_eq!(bs.count(), 100); - for i in 0..100 { - assert!(bs.contains(i), "bit {i} should be set"); - } - } - - #[test] - fn test_bitset_union() { - let mut a = BitSet::new(100); - let mut b = BitSet::new(100); - - a.insert(0); - a.insert(1); - b.insert(1); - b.insert(2); - - let changed = a.union_with(&b); - assert!(changed); - assert!(a.contains(0)); - assert!(a.contains(1)); - assert!(a.contains(2)); - assert_eq!(a.count(), 3); - } - - #[test] - fn test_bitset_intersect() { - let mut a = BitSet::new(100); - let mut b = BitSet::new(100); - - a.insert(0); - a.insert(1); - a.insert(2); - b.insert(1); - b.insert(2); - b.insert(3); - - let changed = a.intersect_with(&b); - assert!(changed); - assert!(!a.contains(0)); - assert!(a.contains(1)); - assert!(a.contains(2)); - assert!(!a.contains(3)); - assert_eq!(a.count(), 2); - } - - #[test] - fn test_bitset_difference() { - let mut a = BitSet::new(100); - let mut b = BitSet::new(100); - - a.insert(0); - a.insert(1); - a.insert(2); - b.insert(1); - - let changed = a.difference_with(&b); - assert!(changed); - assert!(a.contains(0)); - assert!(!a.contains(1)); - assert!(a.contains(2)); - assert_eq!(a.count(), 2); - } - - #[test] - fn test_bitset_iter() { - let mut bs = BitSet::new(100); - bs.insert(5); - bs.insert(42); - bs.insert(99); - - let bits: Vec<_> = bs.iter().collect(); - assert_eq!(bits, vec![5, 42, 99]); - } - - #[test] - fn test_bitset_clear_fill() { - let mut bs = BitSet::new(100); - bs.insert(50); - assert_eq!(bs.count(), 1); - - bs.clear(); - assert!(bs.is_empty()); - - bs.fill(); - assert_eq!(bs.count(), 100); - } -} diff --git a/dotscope/src/utils/graph/algorithms/cycles.rs b/dotscope/src/utils/graph/algorithms/cycles.rs deleted file mode 100644 index d64d9ad4..00000000 --- a/dotscope/src/utils/graph/algorithms/cycles.rs +++ /dev/null @@ -1,418 +0,0 @@ -//! Cycle detection algorithms for directed graphs. -//! -//! This module provides algorithms to detect and find cycles in directed graphs. -//! Cycle detection is essential for: -//! -//! - Validating that dependency graphs are acyclic (DAGs) -//! - Detecting recursive call patterns in call graphs -//! - Identifying loops in control flow graphs - -use crate::utils::graph::{NodeId, Successors}; - -/// Checks if a directed graph contains any cycles reachable from the start node. -/// -/// This function uses depth-first search with a recursion stack to detect -/// back edges, which indicate cycles. It only considers nodes reachable -/// from the start node. -/// -/// # Arguments -/// -/// * `graph` - The graph to check for cycles -/// * `start` - The starting node for the search -/// -/// # Returns -/// -/// `true` if a cycle is found, `false` otherwise. -/// -/// # Complexity -/// -/// - Time: O(V + E) where V is the number of vertices and E is the number of edges -/// - Space: O(V) for the visited and recursion stack sets -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::graph::{DirectedGraph, NodeId, algorithms::has_cycle}; -/// -/// // Acyclic graph: A -> B -> C -/// let mut dag: DirectedGraph<(), ()> = DirectedGraph::new(); -/// let a = dag.add_node(()); -/// let b = dag.add_node(()); -/// let c = dag.add_node(()); -/// dag.add_edge(a, b, ()); -/// dag.add_edge(b, c, ()); -/// -/// assert!(!has_cycle(&dag, a)); -/// -/// // Cyclic graph: A -> B -> C -> A -/// let mut cyclic: DirectedGraph<(), ()> = DirectedGraph::new(); -/// let x = cyclic.add_node(()); -/// let y = cyclic.add_node(()); -/// let z = cyclic.add_node(()); -/// cyclic.add_edge(x, y, ()); -/// cyclic.add_edge(y, z, ()); -/// cyclic.add_edge(z, x, ()); -/// -/// assert!(has_cycle(&cyclic, x)); -/// ``` -pub fn has_cycle(graph: &G, start: NodeId) -> bool { - let node_count = graph.node_count(); - if start.index() >= node_count { - return false; - } - - let mut visited = vec![false; node_count]; - let mut in_stack = vec![false; node_count]; - - has_cycle_dfs(graph, start, &mut visited, &mut in_stack) -} - -/// Recursive helper for cycle detection. -fn has_cycle_dfs( - graph: &G, - node: NodeId, - visited: &mut [bool], - in_stack: &mut [bool], -) -> bool { - let idx = node.index(); - - if in_stack.get(idx).copied().unwrap_or(false) { - // Found a back edge - cycle detected - return true; - } - - if visited.get(idx).copied().unwrap_or(false) { - // Already processed this node in a different path, no cycle here - return false; - } - - if let Some(slot) = visited.get_mut(idx) { - *slot = true; - } - if let Some(slot) = in_stack.get_mut(idx) { - *slot = true; - } - - for successor in graph.successors(node) { - if has_cycle_dfs(graph, successor, visited, in_stack) { - return true; - } - } - - if let Some(slot) = in_stack.get_mut(idx) { - *slot = false; - } - false -} - -/// Finds a cycle in a directed graph if one exists, starting from the given node. -/// -/// If a cycle is found, returns a vector of nodes forming the cycle (starting -/// and ending with the same node). If no cycle is found, returns `None`. -/// -/// # Arguments -/// -/// * `graph` - The graph to search for cycles -/// * `start` - The starting node for the search -/// -/// # Returns -/// -/// `Some(Vec)` containing the cycle path if found, `None` otherwise. -/// The cycle path starts and ends with the same node. -/// -/// # Complexity -/// -/// - Time: O(V + E) -/// - Space: O(V) -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::graph::{DirectedGraph, NodeId, algorithms::find_cycle}; -/// -/// // Cyclic graph: A -> B -> C -> A -/// let mut graph: DirectedGraph = DirectedGraph::new(); -/// let a = graph.add_node('A'); -/// let b = graph.add_node('B'); -/// let c = graph.add_node('C'); -/// graph.add_edge(a, b, ()); -/// graph.add_edge(b, c, ()); -/// graph.add_edge(c, a, ()); -/// -/// let cycle = find_cycle(&graph, a); -/// assert!(cycle.is_some()); -/// -/// let cycle_nodes = cycle.unwrap(); -/// assert!(cycle_nodes.len() >= 3); // At least 3 nodes in the cycle -/// assert_eq!(cycle_nodes.first(), cycle_nodes.last()); // Forms a cycle -/// ``` -pub fn find_cycle(graph: &G, start: NodeId) -> Option> { - let node_count = graph.node_count(); - if start.index() >= node_count { - return None; - } - - let mut visited = vec![false; node_count]; - let mut in_stack = vec![false; node_count]; - let mut path = Vec::new(); - - find_cycle_dfs(graph, start, &mut visited, &mut in_stack, &mut path) -} - -/// Recursive helper for finding a cycle. -fn find_cycle_dfs( - graph: &G, - node: NodeId, - visited: &mut [bool], - in_stack: &mut [bool], - path: &mut Vec, -) -> Option> { - let idx = node.index(); - - if in_stack.get(idx).copied().unwrap_or(false) { - // Found a back edge - extract the cycle - let cycle_start_pos = path.iter().position(|&n| n == node)?; - let mut cycle: Vec = path.get(cycle_start_pos..)?.to_vec(); - cycle.push(node); // Close the cycle - return Some(cycle); - } - - if visited.get(idx).copied().unwrap_or(false) { - return None; - } - - if let Some(slot) = visited.get_mut(idx) { - *slot = true; - } - if let Some(slot) = in_stack.get_mut(idx) { - *slot = true; - } - path.push(node); - - for successor in graph.successors(node) { - if let Some(cycle) = find_cycle_dfs(graph, successor, visited, in_stack, path) { - return Some(cycle); - } - } - - path.pop(); - if let Some(slot) = in_stack.get_mut(idx) { - *slot = false; - } - None -} - -#[cfg(test)] -mod tests { - use crate::utils::graph::{ - algorithms::cycles::{find_cycle, has_cycle}, - DirectedGraph, NodeId, - }; - - fn create_linear_graph() -> DirectedGraph<'static, &'static str, ()> { - let mut graph = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - let c = graph.add_node("C"); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, c, ()).unwrap(); - graph - } - - fn create_diamond_graph() -> DirectedGraph<'static, &'static str, ()> { - let mut graph = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - let c = graph.add_node("C"); - let d = graph.add_node("D"); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(a, c, ()).unwrap(); - graph.add_edge(b, d, ()).unwrap(); - graph.add_edge(c, d, ()).unwrap(); - graph - } - - fn create_simple_cycle() -> DirectedGraph<'static, &'static str, ()> { - let mut graph = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - let c = graph.add_node("C"); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, c, ()).unwrap(); - graph.add_edge(c, a, ()).unwrap(); - graph - } - - fn create_self_loop() -> DirectedGraph<'static, &'static str, ()> { - let mut graph = DirectedGraph::new(); - let a = graph.add_node("A"); - graph.add_edge(a, a, ()).unwrap(); - graph - } - - fn create_complex_with_cycle() -> DirectedGraph<'static, &'static str, ()> { - // A -> B -> C -> D - // ^ | - // +-------+ - let mut graph = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - let c = graph.add_node("C"); - let d = graph.add_node("D"); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, c, ()).unwrap(); - graph.add_edge(c, d, ()).unwrap(); - graph.add_edge(d, b, ()).unwrap(); - graph - } - - #[test] - fn test_has_cycle_linear() { - let graph = create_linear_graph(); - assert!(!has_cycle(&graph, NodeId::new(0))); - } - - #[test] - fn test_has_cycle_diamond() { - let graph = create_diamond_graph(); - assert!(!has_cycle(&graph, NodeId::new(0))); - } - - #[test] - fn test_has_cycle_simple_cycle() { - let graph = create_simple_cycle(); - assert!(has_cycle(&graph, NodeId::new(0))); - } - - #[test] - fn test_has_cycle_self_loop() { - let graph = create_self_loop(); - assert!(has_cycle(&graph, NodeId::new(0))); - } - - #[test] - fn test_has_cycle_complex() { - let graph = create_complex_with_cycle(); - assert!(has_cycle(&graph, NodeId::new(0))); - } - - #[test] - fn test_has_cycle_single_node_no_loop() { - let mut graph: DirectedGraph<(), ()> = DirectedGraph::new(); - let a = graph.add_node(()); - assert!(!has_cycle(&graph, a)); - } - - #[test] - fn test_has_cycle_two_separate_cycles() { - // Two separate cycles not connected - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - let c = graph.add_node("C"); - let d = graph.add_node("D"); - - // Cycle 1: A <-> B - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, a, ()).unwrap(); - - // Cycle 2: C <-> D (disconnected from A, B) - graph.add_edge(c, d, ()).unwrap(); - graph.add_edge(d, c, ()).unwrap(); - - // Starting from A should find cycle in first component - assert!(has_cycle(&graph, a)); - - // Starting from C should find cycle in second component - assert!(has_cycle(&graph, c)); - } - - #[test] - fn test_find_cycle_linear() { - let graph = create_linear_graph(); - assert!(find_cycle(&graph, NodeId::new(0)).is_none()); - } - - #[test] - fn test_find_cycle_diamond() { - let graph = create_diamond_graph(); - assert!(find_cycle(&graph, NodeId::new(0)).is_none()); - } - - #[test] - fn test_find_cycle_simple_cycle() { - let graph = create_simple_cycle(); - let cycle = find_cycle(&graph, NodeId::new(0)); - - assert!(cycle.is_some()); - let cycle = cycle.unwrap(); - - // Cycle should form a loop (first == last) - assert_eq!(cycle.first(), cycle.last()); - - // Should have at least 3 nodes in a triangle cycle plus the closing node - assert!(cycle.len() >= 3); - } - - #[test] - fn test_find_cycle_self_loop() { - let graph = create_self_loop(); - let cycle = find_cycle(&graph, NodeId::new(0)); - - assert!(cycle.is_some()); - let cycle = cycle.unwrap(); - - // Self loop: [A, A] - assert_eq!(cycle.len(), 2); - assert_eq!(cycle[0], cycle[1]); - } - - #[test] - fn test_find_cycle_complex() { - let graph = create_complex_with_cycle(); - let cycle = find_cycle(&graph, NodeId::new(0)); - - assert!(cycle.is_some()); - let cycle = cycle.unwrap(); - - // Cycle B -> C -> D -> B - assert_eq!(cycle.first(), cycle.last()); - } - - #[test] - fn test_find_cycle_returns_valid_path() { - let graph = create_simple_cycle(); - let cycle = find_cycle(&graph, NodeId::new(0)).unwrap(); - - // Verify the path is valid: each node connects to the next - for i in 0..cycle.len() - 1 { - let current = cycle[i]; - let next = cycle[i + 1]; - let successors: Vec = graph.successors(current).collect(); - assert!( - successors.contains(&next), - "Invalid cycle path: no edge from {:?} to {:?}", - current, - next - ); - } - } - - #[test] - fn test_find_cycle_disconnected_cycle() { - // Entry point not in the cycle - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let entry = graph.add_node("Entry"); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - let c = graph.add_node("C"); - - graph.add_edge(entry, a, ()).unwrap(); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, c, ()).unwrap(); - graph.add_edge(c, a, ()).unwrap(); // Cycle: A -> B -> C -> A - - let cycle = find_cycle(&graph, entry); - assert!(cycle.is_some()); - } -} diff --git a/dotscope/src/utils/graph/algorithms/dominators.rs b/dotscope/src/utils/graph/algorithms/dominators.rs deleted file mode 100644 index 619f0bf7..00000000 --- a/dotscope/src/utils/graph/algorithms/dominators.rs +++ /dev/null @@ -1,1095 +0,0 @@ -//! Dominator tree computation using the Lengauer-Tarjan algorithm. -//! -//! This module provides efficient dominator tree computation for rooted directed -//! graphs. The dominator tree is a fundamental data structure for: -//! -//! - SSA (Static Single Assignment) construction -//! - Loop detection and analysis -//! - Compiler optimizations -//! - Control flow analysis -//! -//! # Theory -//! -//! A node `d` **dominates** a node `n` if every path from the entry node to `n` -//! must pass through `d`. The **immediate dominator** of `n` (idom(n)) is the -//! unique node that strictly dominates `n` but does not strictly dominate any -//! other dominator of `n`. -//! -//! The dominator tree is formed by making each node's immediate dominator its -//! parent. The entry node is the root (it has no dominator). -//! -//! # Algorithm -//! -//! This implementation uses the Lengauer-Tarjan algorithm with path compression, -//! achieving O(V α(V)) time complexity where α is the inverse Ackermann function -//! (effectively constant for all practical inputs). - -use crate::utils::{ - graph::{NodeId, RootedGraph, Successors}, - BitSet, -}; - -/// Result of dominator tree computation. -/// -/// The dominator tree represents the dominance relationships in a control flow -/// graph. Each node (except the entry) has exactly one immediate dominator. -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::graph::{DirectedGraph, NodeId, algorithms::compute_dominators}; -/// -/// // Simple CFG: entry -> a -> b -> exit -/// let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); -/// let entry = graph.add_node("entry"); -/// let a = graph.add_node("a"); -/// let b = graph.add_node("b"); -/// let exit = graph.add_node("exit"); -/// -/// graph.add_edge(entry, a, ()); -/// graph.add_edge(a, b, ()); -/// graph.add_edge(b, exit, ()); -/// -/// let dom_tree = compute_dominators(&graph, entry); -/// -/// // entry dominates everything -/// assert!(dom_tree.dominates(entry, exit)); -/// // a is the immediate dominator of b -/// assert_eq!(dom_tree.immediate_dominator(b), Some(a)); -/// ``` -#[derive(Debug, Clone)] -pub struct DominatorTree { - /// The entry (root) node of the dominator tree - entry: NodeId, - /// Immediate dominator for each node (indexed by node ID) - /// Entry node maps to itself (or we could use Option, but this simplifies queries) - idom: Vec, - /// Pre-computed children for each node (indexed by node ID) - /// children[i] = list of nodes whose immediate dominator is node i - children: Vec>, - /// Number of nodes in the graph - node_count: usize, -} - -impl DominatorTree { - /// Returns the entry (root) node of the dominator tree. - #[inline] - pub fn entry(&self) -> NodeId { - self.entry - } - - /// Returns the immediate dominator of a node, or `None` for the entry node - /// or for nodes whose index is out of bounds. - /// - /// The immediate dominator is the closest strict dominator of the node. - #[inline] - pub fn immediate_dominator(&self, node: NodeId) -> Option { - if node == self.entry { - None - } else { - self.idom.get(node.index()).copied() - } - } - - /// Checks if node `a` dominates node `b`. - /// - /// A node dominates itself. The entry node dominates all reachable nodes. - /// Returns `false` if `b` is unreachable from the entry node. - /// - /// # Complexity - /// - /// O(depth) where depth is the depth of `b` in the dominator tree. - pub fn dominates(&self, a: NodeId, b: NodeId) -> bool { - if a == b { - return true; - } - - // Handle out-of-bounds node indices - if b.index() >= self.node_count { - return false; - } - - let mut current = b; - while current != self.entry { - // Check for unreachable nodes (sentinel value) or out-of-bounds - let Some(&idom) = self.idom.get(current.index()) else { - return false; - }; - if idom == a { - return true; - } - // Detect infinite loop (unreachable node pointing to sentinel) - if idom == current { - return false; - } - current = idom; - } - - // Only the entry can dominate the entry - a == self.entry - } - - /// Checks if node `a` strictly dominates node `b`. - /// - /// Strict dominance excludes self-dominance: a strictly dominates b iff - /// a dominates b and a ≠ b. - #[inline] - pub fn strictly_dominates(&self, a: NodeId, b: NodeId) -> bool { - a != b && self.dominates(a, b) - } - - /// Returns an iterator over all dominators of a node, from the node itself - /// up to (and including) the entry node. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::{DirectedGraph, NodeId, algorithms::compute_dominators}; - /// - /// let mut graph: DirectedGraph<(), ()> = DirectedGraph::new(); - /// let entry = graph.add_node(()); - /// let a = graph.add_node(()); - /// let b = graph.add_node(()); - /// graph.add_edge(entry, a, ()); - /// graph.add_edge(a, b, ()); - /// - /// let dom_tree = compute_dominators(&graph, entry); - /// let dominators: Vec = dom_tree.dominators(b).collect(); - /// // b is dominated by b, a, and entry - /// assert_eq!(dominators, vec![b, a, entry]); - /// ``` - pub fn dominators(&self, node: NodeId) -> DominatorIterator<'_> { - DominatorIterator { - tree: self, - current: Some(node), - } - } - - /// Returns the depth of a node in the dominator tree. - /// - /// The entry node has depth 0. Returns 0 for nodes whose index is out of - /// bounds or that are unreachable from the entry. - pub fn depth(&self, node: NodeId) -> usize { - let mut depth: usize = 0; - let mut current = node; - while current != self.entry { - let Some(&idom) = self.idom.get(current.index()) else { - return depth; - }; - // Sentinel idom (== current) means unreachable: stop walking. - if idom == current { - return depth; - } - current = idom; - depth = depth.saturating_add(1); - } - depth - } - - /// Returns all children of a node in the dominator tree. - /// - /// Children are nodes whose immediate dominator is the given node. - /// - /// # Complexity - /// - /// O(1) — children are pre-computed during dominator tree construction. - pub fn children(&self, node: NodeId) -> &[NodeId] { - self.children.get(node.index()).map_or(&[], Vec::as_slice) - } - - /// Returns the number of nodes in the dominator tree. - #[inline] - pub fn node_count(&self) -> usize { - self.node_count - } -} - -/// Iterator over dominators of a node, from the node up to the entry. -pub struct DominatorIterator<'a> { - tree: &'a DominatorTree, - current: Option, -} - -impl Iterator for DominatorIterator<'_> { - type Item = NodeId; - - fn next(&mut self) -> Option { - let current = self.current?; - - if current == self.tree.entry { - self.current = None; - Some(current) - } else { - self.current = self.tree.idom.get(current.index()).copied(); - Some(current) - } - } -} - -/// Computes the dominator tree for a rooted graph using the Lengauer-Tarjan algorithm. -/// -/// This algorithm efficiently computes the immediate dominator for every node -/// reachable from the entry node. -/// -/// # Arguments -/// -/// * `graph` - The graph to analyze (must implement `RootedGraph`) -/// -/// # Returns -/// -/// A `DominatorTree` containing the dominator relationships. -/// -/// # Complexity -/// -/// - Time: O(V α(V)) where α is the inverse Ackermann function -/// - Space: O(V) -/// -/// # Algorithm Overview -/// -/// The Lengauer-Tarjan algorithm works in several phases: -/// -/// 1. **DFS numbering**: Assign DFS numbers to nodes and compute the DFS tree -/// 2. **Semidominators**: Compute semidominators using the Semidominator Theorem -/// 3. **Implicit idom**: Compute implicit immediate dominators -/// 4. **Explicit idom**: Convert implicit to explicit immediate dominators -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::graph::{DirectedGraph, NodeId, algorithms::compute_dominators}; -/// -/// // Diamond CFG: -/// // entry -/// // / \ -/// // a b -/// // \ / -/// // exit -/// let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); -/// let entry = graph.add_node("entry"); -/// let a = graph.add_node("a"); -/// let b = graph.add_node("b"); -/// let exit = graph.add_node("exit"); -/// -/// graph.add_edge(entry, a, ()); -/// graph.add_edge(entry, b, ()); -/// graph.add_edge(a, exit, ()); -/// graph.add_edge(b, exit, ()); -/// -/// let dom_tree = compute_dominators(&graph, entry); -/// -/// // Entry dominates all nodes -/// assert!(dom_tree.dominates(entry, a)); -/// assert!(dom_tree.dominates(entry, b)); -/// assert!(dom_tree.dominates(entry, exit)); -/// -/// // a and b don't dominate exit (there are alternative paths) -/// assert!(!dom_tree.strictly_dominates(a, exit)); -/// assert!(!dom_tree.strictly_dominates(b, exit)); -/// -/// // exit's immediate dominator is entry (not a or b) -/// assert_eq!(dom_tree.immediate_dominator(exit), Some(entry)); -/// ``` -pub fn compute_dominators(graph: &G, entry: NodeId) -> DominatorTree -where - G: Successors, -{ - let node_count = graph.node_count(); - - if node_count == 0 { - return DominatorTree { - entry, - idom: Vec::new(), - children: Vec::new(), - node_count: 0, - }; - } - - // Pre-compute predecessors in a single O(V+E) pass for the Lengauer-Tarjan algorithm - let predecessors = precompute_predecessors(graph); - - // Lengauer-Tarjan algorithm implementation - let mut lt = LengauerTarjan::new(node_count, entry); - lt.compute(graph, &predecessors); - - // Build children list in a single O(V) pass from the idom array - let mut children: Vec> = vec![Vec::new(); node_count]; - for i in 0..node_count { - let node = NodeId::new(i); - if node == entry { - continue; - } - let Some(parent) = lt.idom.get(i).copied() else { - continue; - }; - if let Some(slot) = children.get_mut(parent.index()) { - slot.push(node); - } - } - - DominatorTree { - entry, - idom: lt.idom, - children, - node_count, - } -} - -/// Convenience function to compute dominators for a [`RootedGraph`]. -/// -/// This is equivalent to calling `compute_dominators(graph, graph.entry())`. -pub fn compute_dominators_rooted(graph: &G) -> DominatorTree -where - G: RootedGraph, -{ - compute_dominators(graph, graph.entry()) -} - -/// Pre-computes the predecessor list for all nodes in a single O(V+E) pass. -/// -/// Returns a vector where `result[i]` contains all predecessors of node `i`. -fn precompute_predecessors(graph: &G) -> Vec> { - let n = graph.node_count(); - let mut preds: Vec> = vec![Vec::new(); n]; - for i in 0..n { - let v = NodeId::new(i); - for succ in graph.successors(v) { - if let Some(slot) = preds.get_mut(succ.index()) { - slot.push(v); - } - } - } - preds -} - -/// Internal state for the Lengauer-Tarjan algorithm. -struct LengauerTarjan { - /// Number of nodes - n: usize, - /// Entry node - entry: NodeId, - /// DFS number for each node (0 = not visited) - dfnum: Vec, - /// Node with each DFS number (inverse of dfnum) - vertex: Vec, - /// Parent in DFS tree - parent: Vec, - /// Semidominator (by DFS number, stored as node ID) - semi: Vec, - /// Immediate dominator (final result) - idom: Vec, - /// Ancestor in the forest for link-eval - ancestor: Vec, - /// Best node on path to ancestor (for path compression) - best: Vec, - /// Bucket for each node (nodes whose semidominator is this node) - bucket: Vec>, - /// Current DFS counter - dfs_counter: usize, -} - -impl LengauerTarjan { - /// Sentinel NodeId representing "uninitialized" or "out-of-graph" values. - #[inline] - fn sentinel() -> NodeId { - NodeId::new(usize::MAX) - } - - #[inline] - fn get_node(slice: &[NodeId], i: usize) -> NodeId { - slice.get(i).copied().unwrap_or(Self::sentinel()) - } - - #[inline] - fn get_dfnum(&self, n: NodeId) -> usize { - self.dfnum.get(n.index()).copied().unwrap_or(0) - } - - #[inline] - fn set_node(slice: &mut [NodeId], i: usize, value: NodeId) { - if let Some(slot) = slice.get_mut(i) { - *slot = value; - } - } - - fn new(n: usize, entry: NodeId) -> Self { - let sentinel = NodeId::new(usize::MAX); - Self { - n, - entry, - dfnum: vec![0; n], - vertex: vec![sentinel; n], - parent: vec![sentinel; n], - semi: (0..n).map(NodeId::new).collect(), - idom: vec![sentinel; n], - ancestor: vec![sentinel; n], - best: (0..n).map(NodeId::new).collect(), - bucket: vec![Vec::new(); n], - dfs_counter: 0, - } - } - - fn compute(&mut self, graph: &G, predecessors: &[Vec]) { - // Phase 1: DFS numbering - self.dfs(graph, self.entry); - - // Process nodes in reverse DFS order (excluding entry) - for i in (1..self.dfs_counter).rev() { - let w = Self::get_node(&self.vertex, i); - let parent_w = Self::get_node(&self.parent, w.index()); - - // Phase 2: Compute semidominators - // semi(w) = min { v : v -> w is a CFG edge and dfnum(v) < dfnum(w) } ∪ - // { semi(u) : u -> w via tree edges where dfnum(u) > dfnum(w) } - let preds_w: &[NodeId] = predecessors.get(w.index()).map_or(&[], Vec::as_slice); - for v in preds_w { - let v = *v; - if self.get_dfnum(v) == 0 { - // v is unreachable from entry, skip - continue; - } - let u = self.eval(v); - let semi_u = Self::get_node(&self.semi, u.index()); - let semi_w = Self::get_node(&self.semi, w.index()); - if self.get_dfnum(semi_u) < self.get_dfnum(semi_w) { - Self::set_node(&mut self.semi, w.index(), semi_u); - } - } - - // Add w to bucket of its semidominator - let semi_w = Self::get_node(&self.semi, w.index()); - if let Some(bucket) = self.bucket.get_mut(semi_w.index()) { - bucket.push(w); - } - - // Link w into the forest - self.link(parent_w, w); - - // Phase 3: Implicitly compute immediate dominators - // Process bucket of parent(w) - let bucket = self - .bucket - .get_mut(parent_w.index()) - .map_or_else(Vec::new, std::mem::take); - for v in bucket { - let u = self.eval(v); - let semi_u = Self::get_node(&self.semi, u.index()); - let semi_v = Self::get_node(&self.semi, v.index()); - if semi_u == semi_v { - // idom(v) = semi(v) = parent(w) - Self::set_node(&mut self.idom, v.index(), parent_w); - } else { - // idom(v) = idom(u) (will be computed later) - Self::set_node(&mut self.idom, v.index(), u); - } - } - } - - // Phase 4: Explicitly compute immediate dominators - for i in 1..self.dfs_counter { - let w = Self::get_node(&self.vertex, i); - let idom_w = Self::get_node(&self.idom, w.index()); - let semi_w = Self::get_node(&self.semi, w.index()); - if idom_w != semi_w { - let idom_idom = Self::get_node(&self.idom, idom_w.index()); - Self::set_node(&mut self.idom, w.index(), idom_idom); - } - } - - // Entry node dominates itself - Self::set_node(&mut self.idom, self.entry.index(), self.entry); - } - - /// DFS traversal to assign DFS numbers and build DFS tree. - fn dfs(&mut self, graph: &G, start: NodeId) { - let mut stack = vec![(start, false)]; - - while let Some((node, processed)) = stack.pop() { - let idx = node.index(); - - if processed { - continue; - } - - if self.dfnum.get(idx).copied().unwrap_or(0) != 0 { - continue; - } - - self.dfs_counter = self.dfs_counter.saturating_add(1); - if let Some(slot) = self.dfnum.get_mut(idx) { - *slot = self.dfs_counter; - } - // dfs_counter is at least 1 here, so subtracting 1 is safe. - let vertex_idx = self.dfs_counter.saturating_sub(1); - Self::set_node(&mut self.vertex, vertex_idx, node); - - for succ in graph.successors(node) { - if self.get_dfnum(succ) == 0 { - Self::set_node(&mut self.parent, succ.index(), node); - stack.push((succ, false)); - } - } - } - } - - /// Link v as a child of w in the spanning forest. - fn link(&mut self, w: NodeId, v: NodeId) { - Self::set_node(&mut self.ancestor, v.index(), w); - } - - /// Evaluate: find the node with minimum semidominator on the path to the root. - fn eval(&mut self, v: NodeId) -> NodeId { - let sentinel = Self::sentinel(); - if Self::get_node(&self.ancestor, v.index()) == sentinel { - return v; - } - - self.compress(v); - Self::get_node(&self.best, v.index()) - } - - /// Path compression for the forest (iterative). - /// - /// Collects the ancestor chain from `v` up to the forest root, then walks - /// back down to update `best` and `ancestor` for every node on the path. - /// This avoids O(V) recursion depth that can overflow the stack on large - /// CFF-obfuscated CFGs (500+ blocks). - fn compress(&mut self, v: NodeId) { - let sentinel = Self::sentinel(); - - // Phase 1: collect the path from v upward until we reach a node - // whose ancestor is the forest root (ancestor == sentinel). - let mut path = Vec::new(); - let mut u = v; - loop { - let anc_u = Self::get_node(&self.ancestor, u.index()); - let anc_anc_u = Self::get_node(&self.ancestor, anc_u.index()); - if anc_anc_u == sentinel { - break; - } - path.push(u); - u = anc_u; - } - - // Phase 2: walk the path in reverse (top-down) to propagate best - // values and flatten ancestor pointers — same semantics as the - // recursive version's post-order updates. - for &node in path.iter().rev() { - let ancestor_node = Self::get_node(&self.ancestor, node.index()); - let best_ancestor = Self::get_node(&self.best, ancestor_node.index()); - let best_node = Self::get_node(&self.best, node.index()); - - let semi_ba = Self::get_node(&self.semi, best_ancestor.index()); - let semi_bn = Self::get_node(&self.semi, best_node.index()); - if self.get_dfnum(semi_ba) < self.get_dfnum(semi_bn) { - Self::set_node(&mut self.best, node.index(), best_ancestor); - } - - let new_anc = Self::get_node(&self.ancestor, ancestor_node.index()); - Self::set_node(&mut self.ancestor, node.index(), new_anc); - } - } -} - -/// Computes dominance frontiers for all nodes. -/// -/// The dominance frontier of a node `n` is the set of all nodes `m` such that: -/// - `n` dominates a predecessor of `m`, but -/// - `n` does not strictly dominate `m` -/// -/// Dominance frontiers are essential for placing φ-functions in SSA construction. -/// -/// # Arguments -/// -/// * `graph` - The control flow graph -/// * `dom_tree` - The precomputed dominator tree -/// -/// # Returns -/// -/// A vector where `result[i]` contains the dominance frontier of node `i`. -/// -/// # Complexity -/// -/// - Time: O(V + E) -/// - Space: O(V²) worst case for the frontiers -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::graph::{DirectedGraph, NodeId, algorithms::{compute_dominators, compute_dominance_frontiers}}; -/// -/// // Diamond CFG with join point -/// let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); -/// let entry = graph.add_node("entry"); -/// let left = graph.add_node("left"); -/// let right = graph.add_node("right"); -/// let join = graph.add_node("join"); -/// -/// graph.add_edge(entry, left, ()); -/// graph.add_edge(entry, right, ()); -/// graph.add_edge(left, join, ()); -/// graph.add_edge(right, join, ()); -/// -/// let dom_tree = compute_dominators(&graph, entry); -/// let frontiers = compute_dominance_frontiers(&graph, &dom_tree); -/// -/// // The dominance frontier of 'left' includes 'join' (where paths merge) -/// assert!(frontiers[left.index()].contains(join.index())); -/// // The dominance frontier of 'right' includes 'join' -/// assert!(frontiers[right.index()].contains(join.index())); -/// ``` -pub fn compute_dominance_frontiers(graph: &G, dom_tree: &DominatorTree) -> Vec -where - G: Successors, -{ - let n = graph.node_count(); - let mut frontiers: Vec = vec![BitSet::new(n); n]; - - // Pre-compute predecessors in a single O(V+E) pass - let all_preds = precompute_predecessors(graph); - - // For each node, check if it's a join point (has multiple predecessors) - // For each join point, walk up the dominator tree from each predecessor - for (node_idx, preds) in all_preds.iter().enumerate() { - let node = NodeId::new(node_idx); - - if preds.len() < 2 { - continue; // Not a join point - } - - // For each predecessor, walk up its dominators until we reach idom(node) - let idom_node = dom_tree.immediate_dominator(node); - - for &pred in preds { - let mut runner = pred; - // Guard against unreachable nodes (their index may be invalid/sentinel) - while Some(runner) != idom_node && runner != dom_tree.entry() && runner.index() < n { - if let Some(slot) = frontiers.get_mut(runner.index()) { - slot.insert(node.index()); - } - if let Some(idom) = dom_tree.immediate_dominator(runner) { - // Check for sentinel value (unreachable node) - if idom.index() >= n { - break; - } - runner = idom; - } else { - break; - } - } - // Also check entry if needed (guard against invalid index) - if Some(runner) != idom_node && runner == dom_tree.entry() && runner.index() < n { - if let Some(slot) = frontiers.get_mut(runner.index()) { - slot.insert(node.index()); - } - } - } - } - - frontiers -} - -#[cfg(test)] -mod tests { - use crate::utils::graph::{ - algorithms::dominators::{compute_dominance_frontiers, compute_dominators}, - DirectedGraph, NodeId, - }; - - #[test] - fn test_dominator_empty_graph() { - let graph: DirectedGraph<(), ()> = DirectedGraph::new(); - // With empty graph, we need a valid entry - this is a degenerate case - let entry = NodeId::new(0); - let dom_tree = compute_dominators(&graph, entry); - assert_eq!(dom_tree.node_count(), 0); - } - - #[test] - fn test_dominator_single_node() { - let mut graph: DirectedGraph<(), ()> = DirectedGraph::new(); - let entry = graph.add_node(()); - - let dom_tree = compute_dominators(&graph, entry); - - assert_eq!(dom_tree.entry(), entry); - assert_eq!(dom_tree.immediate_dominator(entry), None); - assert!(dom_tree.dominates(entry, entry)); - assert_eq!(dom_tree.depth(entry), 0); - } - - #[test] - fn test_dominator_linear_chain() { - // entry -> a -> b -> c - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let entry = graph.add_node("entry"); - let a = graph.add_node("a"); - let b = graph.add_node("b"); - let c = graph.add_node("c"); - - graph.add_edge(entry, a, ()).unwrap(); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, c, ()).unwrap(); - - let dom_tree = compute_dominators(&graph, entry); - - // Check immediate dominators - assert_eq!(dom_tree.immediate_dominator(entry), None); - assert_eq!(dom_tree.immediate_dominator(a), Some(entry)); - assert_eq!(dom_tree.immediate_dominator(b), Some(a)); - assert_eq!(dom_tree.immediate_dominator(c), Some(b)); - - // Check dominance relationships - assert!(dom_tree.dominates(entry, a)); - assert!(dom_tree.dominates(entry, b)); - assert!(dom_tree.dominates(entry, c)); - assert!(dom_tree.dominates(a, b)); - assert!(dom_tree.dominates(a, c)); - assert!(dom_tree.dominates(b, c)); - - // Check non-dominance - assert!(!dom_tree.dominates(c, b)); - assert!(!dom_tree.dominates(b, a)); - - // Check depths - assert_eq!(dom_tree.depth(entry), 0); - assert_eq!(dom_tree.depth(a), 1); - assert_eq!(dom_tree.depth(b), 2); - assert_eq!(dom_tree.depth(c), 3); - } - - #[test] - fn test_dominator_diamond() { - // Diamond CFG: - // entry - // / \ - // a b - // \ / - // exit - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let entry = graph.add_node("entry"); - let a = graph.add_node("a"); - let b = graph.add_node("b"); - let exit = graph.add_node("exit"); - - graph.add_edge(entry, a, ()).unwrap(); - graph.add_edge(entry, b, ()).unwrap(); - graph.add_edge(a, exit, ()).unwrap(); - graph.add_edge(b, exit, ()).unwrap(); - - let dom_tree = compute_dominators(&graph, entry); - - // entry is immediate dominator of a, b, and exit - assert_eq!(dom_tree.immediate_dominator(a), Some(entry)); - assert_eq!(dom_tree.immediate_dominator(b), Some(entry)); - assert_eq!(dom_tree.immediate_dominator(exit), Some(entry)); - - // a and b don't dominate exit (alternative paths exist) - assert!(!dom_tree.strictly_dominates(a, exit)); - assert!(!dom_tree.strictly_dominates(b, exit)); - - // entry dominates all - assert!(dom_tree.dominates(entry, a)); - assert!(dom_tree.dominates(entry, b)); - assert!(dom_tree.dominates(entry, exit)); - } - - #[test] - fn test_dominator_if_then_else() { - // if-then-else: - // entry - // | - // cond - // / \ - // then else - // \ / - // merge - // | - // exit - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let entry = graph.add_node("entry"); - let cond = graph.add_node("cond"); - let then_b = graph.add_node("then"); - let else_b = graph.add_node("else"); - let merge = graph.add_node("merge"); - let exit = graph.add_node("exit"); - - graph.add_edge(entry, cond, ()).unwrap(); - graph.add_edge(cond, then_b, ()).unwrap(); - graph.add_edge(cond, else_b, ()).unwrap(); - graph.add_edge(then_b, merge, ()).unwrap(); - graph.add_edge(else_b, merge, ()).unwrap(); - graph.add_edge(merge, exit, ()).unwrap(); - - let dom_tree = compute_dominators(&graph, entry); - - // Check dominator chain - assert_eq!(dom_tree.immediate_dominator(cond), Some(entry)); - assert_eq!(dom_tree.immediate_dominator(then_b), Some(cond)); - assert_eq!(dom_tree.immediate_dominator(else_b), Some(cond)); - assert_eq!(dom_tree.immediate_dominator(merge), Some(cond)); - assert_eq!(dom_tree.immediate_dominator(exit), Some(merge)); - - // cond dominates merge and exit - assert!(dom_tree.dominates(cond, merge)); - assert!(dom_tree.dominates(cond, exit)); - - // then/else don't dominate merge - assert!(!dom_tree.strictly_dominates(then_b, merge)); - assert!(!dom_tree.strictly_dominates(else_b, merge)); - } - - #[test] - fn test_dominator_loop() { - // Simple loop: - // entry - // | - // v - // +-> header - // | | - // | v - // +-- body - // | - // v - // exit - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let entry = graph.add_node("entry"); - let header = graph.add_node("header"); - let body = graph.add_node("body"); - let exit = graph.add_node("exit"); - - graph.add_edge(entry, header, ()).unwrap(); - graph.add_edge(header, body, ()).unwrap(); - graph.add_edge(body, header, ()).unwrap(); // back edge - graph.add_edge(body, exit, ()).unwrap(); - - let dom_tree = compute_dominators(&graph, entry); - - // header dominates body and exit - assert!(dom_tree.dominates(header, body)); - // body does not dominate header (despite the back edge) - assert!(!dom_tree.strictly_dominates(body, header)); - } - - #[test] - fn test_dominator_iterator() { - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let entry = graph.add_node("entry"); - let a = graph.add_node("a"); - let b = graph.add_node("b"); - let c = graph.add_node("c"); - - graph.add_edge(entry, a, ()).unwrap(); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, c, ()).unwrap(); - - let dom_tree = compute_dominators(&graph, entry); - - // Iterate dominators of c - let dominators: Vec = dom_tree.dominators(c).collect(); - assert_eq!(dominators, vec![c, b, a, entry]); - - // Iterate dominators of entry - let dominators: Vec = dom_tree.dominators(entry).collect(); - assert_eq!(dominators, vec![entry]); - } - - #[test] - fn test_dominator_children() { - // Diamond CFG - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let entry = graph.add_node("entry"); - let a = graph.add_node("a"); - let b = graph.add_node("b"); - let exit = graph.add_node("exit"); - - graph.add_edge(entry, a, ()).unwrap(); - graph.add_edge(entry, b, ()).unwrap(); - graph.add_edge(a, exit, ()).unwrap(); - graph.add_edge(b, exit, ()).unwrap(); - - let dom_tree = compute_dominators(&graph, entry); - - // entry has children: a, b, exit - let mut children = dom_tree.children(entry).to_vec(); - children.sort_by_key(|n| n.index()); - assert_eq!(children, vec![a, b, exit]); - - // a, b, exit have no children - assert!(dom_tree.children(a).is_empty()); - assert!(dom_tree.children(b).is_empty()); - assert!(dom_tree.children(exit).is_empty()); - } - - #[test] - fn test_dominance_frontier_diamond() { - // Diamond CFG - classic case for dominance frontiers - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let entry = graph.add_node("entry"); - let left = graph.add_node("left"); - let right = graph.add_node("right"); - let join = graph.add_node("join"); - - graph.add_edge(entry, left, ()).unwrap(); - graph.add_edge(entry, right, ()).unwrap(); - graph.add_edge(left, join, ()).unwrap(); - graph.add_edge(right, join, ()).unwrap(); - - let dom_tree = compute_dominators(&graph, entry); - let frontiers = compute_dominance_frontiers(&graph, &dom_tree); - - // entry has no dominance frontier - assert!(frontiers[entry.index()].is_empty()); - - // left's dominance frontier is {join} - assert!(frontiers[left.index()].contains(join.index())); - assert_eq!(frontiers[left.index()].count(), 1); - - // right's dominance frontier is {join} - assert!(frontiers[right.index()].contains(join.index())); - assert_eq!(frontiers[right.index()].count(), 1); - - // join has no dominance frontier (no successors) - assert!(frontiers[join.index()].is_empty()); - } - - #[test] - fn test_dominance_frontier_loop() { - // Loop with header - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let entry = graph.add_node("entry"); - let header = graph.add_node("header"); - let body = graph.add_node("body"); - let exit = graph.add_node("exit"); - - graph.add_edge(entry, header, ()).unwrap(); - graph.add_edge(header, body, ()).unwrap(); - graph.add_edge(body, header, ()).unwrap(); // back edge - graph.add_edge(header, exit, ()).unwrap(); - - let dom_tree = compute_dominators(&graph, entry); - let frontiers = compute_dominance_frontiers(&graph, &dom_tree); - - // body's dominance frontier includes header (the loop header) - assert!(frontiers[body.index()].contains(header.index())); - } - - #[test] - fn test_dominance_frontier_nested_if() { - // Nested if structure: - // entry - // | - // if1 - // / \ - // a b - // / \ \ - // c d e - // \ / / - // join1 / - // \ / - // join2 - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let entry = graph.add_node("entry"); - let if1 = graph.add_node("if1"); - let a = graph.add_node("a"); - let b = graph.add_node("b"); - let c = graph.add_node("c"); - let d = graph.add_node("d"); - let e = graph.add_node("e"); - let join1 = graph.add_node("join1"); - let join2 = graph.add_node("join2"); - - graph.add_edge(entry, if1, ()).unwrap(); - graph.add_edge(if1, a, ()).unwrap(); - graph.add_edge(if1, b, ()).unwrap(); - graph.add_edge(a, c, ()).unwrap(); - graph.add_edge(a, d, ()).unwrap(); - graph.add_edge(b, e, ()).unwrap(); - graph.add_edge(c, join1, ()).unwrap(); - graph.add_edge(d, join1, ()).unwrap(); - graph.add_edge(e, join2, ()).unwrap(); - graph.add_edge(join1, join2, ()).unwrap(); - - let dom_tree = compute_dominators(&graph, entry); - let frontiers = compute_dominance_frontiers(&graph, &dom_tree); - - // c and d have join1 in their dominance frontier - assert!(frontiers[c.index()].contains(join1.index())); - assert!(frontiers[d.index()].contains(join1.index())); - - // join1 and e have join2 in their dominance frontier - assert!(frontiers[join1.index()].contains(join2.index())); - assert!(frontiers[e.index()].contains(join2.index())); - } - - #[test] - fn test_strictly_dominates() { - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let entry = graph.add_node("entry"); - let a = graph.add_node("a"); - - graph.add_edge(entry, a, ()).unwrap(); - - let dom_tree = compute_dominators(&graph, entry); - - // entry dominates itself but doesn't strictly dominate itself - assert!(dom_tree.dominates(entry, entry)); - assert!(!dom_tree.strictly_dominates(entry, entry)); - - // entry strictly dominates a - assert!(dom_tree.strictly_dominates(entry, a)); - } - - #[test] - fn test_dominator_complex_cfg() { - // More complex CFG with multiple paths and joins - // - // entry - // | - // a - // / \ - // b c - // | | - // d e - // \ / \ - // f g - // | - // h - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let entry = graph.add_node("entry"); - let a = graph.add_node("a"); - let b = graph.add_node("b"); - let c = graph.add_node("c"); - let d = graph.add_node("d"); - let e = graph.add_node("e"); - let f = graph.add_node("f"); - let g = graph.add_node("g"); - let h = graph.add_node("h"); - - graph.add_edge(entry, a, ()).unwrap(); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(a, c, ()).unwrap(); - graph.add_edge(b, d, ()).unwrap(); - graph.add_edge(c, e, ()).unwrap(); - graph.add_edge(d, f, ()).unwrap(); - graph.add_edge(e, f, ()).unwrap(); - graph.add_edge(e, g, ()).unwrap(); - graph.add_edge(f, h, ()).unwrap(); - - let dom_tree = compute_dominators(&graph, entry); - - // a dominates everything below it - assert!(dom_tree.dominates(a, b)); - assert!(dom_tree.dominates(a, c)); - assert!(dom_tree.dominates(a, d)); - assert!(dom_tree.dominates(a, e)); - assert!(dom_tree.dominates(a, f)); - assert!(dom_tree.dominates(a, g)); - assert!(dom_tree.dominates(a, h)); - - // f's immediate dominator is a (not d or e, since there are multiple paths) - assert_eq!(dom_tree.immediate_dominator(f), Some(a)); - - // g's immediate dominator is e (only one path to g) - assert_eq!(dom_tree.immediate_dominator(g), Some(e)); - } -} diff --git a/dotscope/src/utils/graph/algorithms/mod.rs b/dotscope/src/utils/graph/algorithms/mod.rs deleted file mode 100644 index fca0314f..00000000 --- a/dotscope/src/utils/graph/algorithms/mod.rs +++ /dev/null @@ -1,99 +0,0 @@ -//! Graph algorithms for program analysis. -//! -//! This module provides standard graph algorithms optimized for program analysis -//! tasks such as control flow analysis, dominator computation, and dependency -//! resolution. -//! -//! # Available Algorithms -//! -//! ## Traversal -//! -//! - [`dfs`] - Depth-first search traversal -//! - [`bfs`] - Breadth-first search traversal -//! - [`reverse_postorder`] - Reverse postorder traversal (useful for data flow) -//! - [`postorder`] - Postorder traversal -//! -//! ## Cycle Detection -//! -//! - [`has_cycle`] - Check if a graph contains any cycles -//! - [`find_cycle`] - Find a cycle if one exists -//! -//! ## Topological Ordering -//! -//! - [`topological_sort`] - Compute a topological ordering of nodes -//! -//! ## Dominator Analysis -//! -//! - [`compute_dominators`] - Compute the dominator tree using Lengauer-Tarjan -//! - [`compute_dominance_frontiers`] - Compute dominance frontiers for SSA -//! - [`DominatorTree`] - Result of dominator computation -//! -//! ## Strongly Connected Components -//! -//! - [`strongly_connected_components`] - Tarjan's SCC algorithm -//! -//! # Algorithm Selection -//! -//! | Algorithm | Time Complexity | Use Case | -//! |-----------|-----------------|----------| -//! | DFS/BFS | O(V + E) | General traversal | -//! | Topological Sort | O(V + E) | Dependency ordering | -//! | Dominators | O(V α(V)) | SSA construction, loop analysis | -//! | SCC | O(V + E) | Recursion detection, call graph analysis | -//! -//! # Examples -//! -//! ## Traversal -//! -//! ```rust,ignore -//! use dotscope::graph::{DirectedGraph, NodeId, algorithms}; -//! -//! let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); -//! let a = graph.add_node("A"); -//! let b = graph.add_node("B"); -//! let c = graph.add_node("C"); -//! graph.add_edge(a, b, ()); -//! graph.add_edge(b, c, ()); -//! -//! // DFS traversal -//! let order: Vec = algorithms::dfs(&graph, a).collect(); -//! assert_eq!(order, vec![a, b, c]); -//! ``` -//! -//! ## Cycle Detection -//! -//! ```rust,ignore -//! use dotscope::graph::{DirectedGraph, NodeId, algorithms}; -//! -//! // Acyclic graph -//! let mut dag: DirectedGraph<(), ()> = DirectedGraph::new(); -//! let a = dag.add_node(()); -//! let b = dag.add_node(()); -//! dag.add_edge(a, b, ()); -//! -//! assert!(!algorithms::has_cycle(&dag, a)); -//! -//! // Cyclic graph -//! let mut cyclic: DirectedGraph<(), ()> = DirectedGraph::new(); -//! let x = cyclic.add_node(()); -//! let y = cyclic.add_node(()); -//! cyclic.add_edge(x, y, ()); -//! cyclic.add_edge(y, x, ()); -//! -//! assert!(algorithms::has_cycle(&cyclic, x)); -//! ``` - -mod cycles; -mod dominators; -mod scc; -mod topological; -mod traversal; - -// Re-export all public items -pub use cycles::{find_cycle, has_cycle}; -#[allow(unused_imports)] -pub use dominators::{compute_dominance_frontiers, compute_dominators, DominatorTree}; -pub use scc::{condensation, strongly_connected_components}; -pub use topological::topological_sort; -#[allow(unused_imports)] -pub use traversal::{bfs, dfs, postorder, reverse_postorder}; diff --git a/dotscope/src/utils/graph/algorithms/scc.rs b/dotscope/src/utils/graph/algorithms/scc.rs deleted file mode 100644 index ce3965fc..00000000 --- a/dotscope/src/utils/graph/algorithms/scc.rs +++ /dev/null @@ -1,664 +0,0 @@ -//! Strongly Connected Components (SCC) using Tarjan's algorithm. -//! -//! This module provides Tarjan's algorithm for finding strongly connected -//! components in a directed graph. A strongly connected component is a maximal -//! set of vertices such that there is a path from every vertex to every other -//! vertex in the set. -//! -//! # Use Cases -//! -//! - **Recursion detection**: Methods that can call each other form an SCC -//! - **Call graph analysis**: Finding mutually recursive function groups -//! - **Dependency analysis**: Detecting circular dependencies -//! - **Dead code elimination**: Unreachable code forms trivial SCCs - -use crate::utils::graph::{NodeId, Successors}; - -/// Computes the strongly connected components of a directed graph. -/// -/// Uses Tarjan's algorithm with a single DFS pass. The algorithm maintains -/// a stack of vertices and assigns each vertex an index and "lowlink" value. -/// When a vertex's lowlink equals its index, it's the root of an SCC. -/// -/// # Arguments -/// -/// * `graph` - The directed graph to analyze -/// -/// # Returns -/// -/// A vector of SCCs, where each SCC is a vector of `NodeId`s. The SCCs are -/// returned in **reverse topological order** (i.e., if there's an edge from -/// SCC A to SCC B, then A appears after B in the result). -/// -/// # Complexity -/// -/// - Time: O(V + E) -/// - Space: O(V) -/// -/// # Algorithm -/// -/// 1. Perform DFS, assigning each node an index in discovery order -/// 2. Compute lowlink values (minimum index reachable via DFS subtree + back edges) -/// 3. When lowlink[v] == index[v], v is root of an SCC; pop stack until v -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::graph::{DirectedGraph, NodeId, algorithms::strongly_connected_components}; -/// -/// // Simple cycle: A -> B -> C -> A -/// let mut graph: DirectedGraph = DirectedGraph::new(); -/// let a = graph.add_node('A'); -/// let b = graph.add_node('B'); -/// let c = graph.add_node('C'); -/// graph.add_edge(a, b, ()); -/// graph.add_edge(b, c, ()); -/// graph.add_edge(c, a, ()); -/// -/// let sccs = strongly_connected_components(&graph); -/// // All three nodes form a single SCC -/// assert_eq!(sccs.len(), 1); -/// assert_eq!(sccs[0].len(), 3); -/// ``` -/// -/// # Acyclic Graph Example -/// -/// ```rust,ignore -/// use dotscope::graph::{DirectedGraph, NodeId, algorithms::strongly_connected_components}; -/// -/// // DAG: A -> B -> C -/// let mut graph: DirectedGraph = DirectedGraph::new(); -/// let a = graph.add_node('A'); -/// let b = graph.add_node('B'); -/// let c = graph.add_node('C'); -/// graph.add_edge(a, b, ()); -/// graph.add_edge(b, c, ()); -/// -/// let sccs = strongly_connected_components(&graph); -/// // Each node is its own SCC (no cycles) -/// assert_eq!(sccs.len(), 3); -/// for scc in &sccs { -/// assert_eq!(scc.len(), 1); -/// } -/// ``` -pub fn strongly_connected_components(graph: &G) -> Vec> -where - G: Successors, -{ - let node_count = graph.node_count(); - if node_count == 0 { - return Vec::new(); - } - - let mut state = TarjanState::new(node_count); - - // Run Tarjan's algorithm from each unvisited node - for i in 0..node_count { - let node = NodeId::new(i); - if state.index.get(i).copied().flatten().is_none() { - state.strongconnect(graph, node); - } - } - - state.sccs -} - -/// Internal state for Tarjan's algorithm. -struct TarjanState { - /// Discovery index for each node (None if not yet visited) - index: Vec>, - /// Lowlink value for each node - lowlink: Vec, - /// Whether a node is currently on the stack - on_stack: Vec, - /// The DFS stack - stack: Vec, - /// Current index counter - current_index: usize, - /// Collected SCCs - sccs: Vec>, -} - -impl TarjanState { - fn new(n: usize) -> Self { - Self { - index: vec![None; n], - lowlink: vec![0; n], - on_stack: vec![false; n], - stack: Vec::new(), - current_index: 0, - sccs: Vec::new(), - } - } - - fn strongconnect(&mut self, graph: &G, v: NodeId) { - let v_idx = v.index(); - - // Set the depth index for v - if let Some(slot) = self.index.get_mut(v_idx) { - *slot = Some(self.current_index); - } - if let Some(slot) = self.lowlink.get_mut(v_idx) { - *slot = self.current_index; - } - self.current_index = self.current_index.saturating_add(1); - self.stack.push(v); - if let Some(slot) = self.on_stack.get_mut(v_idx) { - *slot = true; - } - - // Consider successors of v - for w in graph.successors(v) { - let w_idx = w.index(); - - let w_index_visited = self.index.get(w_idx).copied().flatten(); - if w_index_visited.is_none() { - // Successor w has not yet been visited; recurse - self.strongconnect(graph, w); - let lw = self.lowlink.get(w_idx).copied().unwrap_or(usize::MAX); - if let Some(slot) = self.lowlink.get_mut(v_idx) { - *slot = (*slot).min(lw); - } - } else if self.on_stack.get(w_idx).copied().unwrap_or(false) { - // Successor w is on stack and hence in the current SCC - if let Some(idx) = w_index_visited { - if let Some(slot) = self.lowlink.get_mut(v_idx) { - *slot = (*slot).min(idx); - } - } - } - } - - // If v is a root node, pop the stack and generate an SCC - let v_index = self.index.get(v_idx).copied().flatten(); - if let Some(idx) = v_index { - if self.lowlink.get(v_idx).copied().unwrap_or(usize::MAX) == idx { - let mut scc = Vec::new(); - while let Some(w) = self.stack.pop() { - if let Some(slot) = self.on_stack.get_mut(w.index()) { - *slot = false; - } - scc.push(w); - if w == v { - break; - } - } - self.sccs.push(scc); - } - } - } -} - -/// Returns the condensation graph: a DAG where each SCC is collapsed to a single node. -/// -/// The condensation graph has one node per SCC, with edges representing -/// connections between different SCCs. This is always a DAG (directed acyclic -/// graph) since edges within SCCs are collapsed. -/// -/// # Arguments -/// -/// * `graph` - The original graph -/// * `sccs` - The SCCs as returned by `strongly_connected_components` -/// -/// # Returns -/// -/// A tuple containing: -/// - A vector mapping each original node to its SCC index -/// - A vector of edges `(from_scc, to_scc)` in the condensation graph -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::graph::{DirectedGraph, NodeId, algorithms::{strongly_connected_components, condensation}}; -/// -/// let mut graph: DirectedGraph = DirectedGraph::new(); -/// let a = graph.add_node('A'); -/// let b = graph.add_node('B'); -/// let c = graph.add_node('C'); -/// let d = graph.add_node('D'); -/// -/// // Cycle A -> B -> A, plus C -> D (separate) -/// graph.add_edge(a, b, ()); -/// graph.add_edge(b, a, ()); -/// graph.add_edge(a, c, ()); -/// graph.add_edge(c, d, ()); -/// -/// let sccs = strongly_connected_components(&graph); -/// let (node_to_scc, edges) = condensation(&graph, &sccs); -/// -/// // A and B are in the same SCC -/// assert_eq!(node_to_scc[a.index()], node_to_scc[b.index()]); -/// ``` -pub fn condensation(graph: &G, sccs: &[Vec]) -> (Vec, Vec<(usize, usize)>) -where - G: Successors, -{ - let node_count = graph.node_count(); - - // Build mapping from node to SCC index - let mut node_to_scc = vec![0usize; node_count]; - for (scc_idx, scc) in sccs.iter().enumerate() { - for &node in scc { - if let Some(slot) = node_to_scc.get_mut(node.index()) { - *slot = scc_idx; - } - } - } - - // Find edges between different SCCs - let mut edges = Vec::new(); - let mut seen_edges = std::collections::HashSet::new(); - - for i in 0..node_count { - let from_node = NodeId::new(i); - let from_scc = node_to_scc.get(i).copied().unwrap_or(0); - - for to_node in graph.successors(from_node) { - let to_scc = node_to_scc.get(to_node.index()).copied().unwrap_or(0); - - if from_scc != to_scc && seen_edges.insert((from_scc, to_scc)) { - edges.push((from_scc, to_scc)); - } - } - } - - (node_to_scc, edges) -} - -#[cfg(test)] -mod tests { - use std::collections::HashSet; - - use crate::utils::graph::{ - algorithms::scc::{condensation, strongly_connected_components}, - DirectedGraph, NodeId, - }; - - #[test] - fn test_scc_empty_graph() { - let graph: DirectedGraph<(), ()> = DirectedGraph::new(); - let sccs = strongly_connected_components(&graph); - assert!(sccs.is_empty()); - } - - #[test] - fn test_scc_single_node() { - let mut graph: DirectedGraph<(), ()> = DirectedGraph::new(); - let a = graph.add_node(()); - - let sccs = strongly_connected_components(&graph); - assert_eq!(sccs.len(), 1); - assert_eq!(sccs[0], vec![a]); - } - - #[test] - fn test_scc_single_node_self_loop() { - let mut graph: DirectedGraph<(), ()> = DirectedGraph::new(); - let a = graph.add_node(()); - graph.add_edge(a, a, ()).unwrap(); - - let sccs = strongly_connected_components(&graph); - assert_eq!(sccs.len(), 1); - assert_eq!(sccs[0], vec![a]); - } - - #[test] - fn test_scc_linear_chain() { - // A -> B -> C (no cycles) - let mut graph: DirectedGraph = DirectedGraph::new(); - let a = graph.add_node('A'); - let b = graph.add_node('B'); - let c = graph.add_node('C'); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, c, ()).unwrap(); - - let sccs = strongly_connected_components(&graph); - - // Each node is its own SCC - assert_eq!(sccs.len(), 3); - for scc in &sccs { - assert_eq!(scc.len(), 1); - } - - // SCCs are in reverse topological order: C, B, A - let scc_nodes: Vec = sccs.iter().map(|scc| scc[0]).collect(); - assert_eq!(scc_nodes, vec![c, b, a]); - } - - #[test] - fn test_scc_simple_cycle() { - // A -> B -> C -> A - let mut graph: DirectedGraph = DirectedGraph::new(); - let a = graph.add_node('A'); - let b = graph.add_node('B'); - let c = graph.add_node('C'); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, c, ()).unwrap(); - graph.add_edge(c, a, ()).unwrap(); - - let sccs = strongly_connected_components(&graph); - - // All three nodes form one SCC - assert_eq!(sccs.len(), 1); - assert_eq!(sccs[0].len(), 3); - - let scc_set: HashSet = sccs[0].iter().copied().collect(); - assert!(scc_set.contains(&a)); - assert!(scc_set.contains(&b)); - assert!(scc_set.contains(&c)); - } - - #[test] - fn test_scc_two_nodes_cycle() { - // A <-> B - let mut graph: DirectedGraph = DirectedGraph::new(); - let a = graph.add_node('A'); - let b = graph.add_node('B'); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, a, ()).unwrap(); - - let sccs = strongly_connected_components(&graph); - - assert_eq!(sccs.len(), 1); - assert_eq!(sccs[0].len(), 2); - } - - #[test] - fn test_scc_multiple_components() { - // Two separate cycles: A <-> B and C <-> D - let mut graph: DirectedGraph = DirectedGraph::new(); - let a = graph.add_node('A'); - let b = graph.add_node('B'); - let c = graph.add_node('C'); - let d = graph.add_node('D'); - - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, a, ()).unwrap(); - graph.add_edge(c, d, ()).unwrap(); - graph.add_edge(d, c, ()).unwrap(); - - let sccs = strongly_connected_components(&graph); - - assert_eq!(sccs.len(), 2); - - // Each SCC has 2 nodes - for scc in &sccs { - assert_eq!(scc.len(), 2); - } - } - - #[test] - fn test_scc_connected_cycles() { - // Two cycles connected by an edge: - // A <-> B -> C <-> D - let mut graph: DirectedGraph = DirectedGraph::new(); - let a = graph.add_node('A'); - let b = graph.add_node('B'); - let c = graph.add_node('C'); - let d = graph.add_node('D'); - - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, a, ()).unwrap(); - graph.add_edge(b, c, ()).unwrap(); - graph.add_edge(c, d, ()).unwrap(); - graph.add_edge(d, c, ()).unwrap(); - - let sccs = strongly_connected_components(&graph); - - assert_eq!(sccs.len(), 2); - - // One SCC contains {A, B}, another contains {C, D} - let mut found_ab = false; - let mut found_cd = false; - - for scc in &sccs { - let scc_set: HashSet = scc.iter().copied().collect(); - if scc_set.contains(&a) && scc_set.contains(&b) && scc.len() == 2 { - found_ab = true; - } - if scc_set.contains(&c) && scc_set.contains(&d) && scc.len() == 2 { - found_cd = true; - } - } - - assert!(found_ab); - assert!(found_cd); - } - - #[test] - fn test_scc_diamond_no_cycle() { - // Diamond: A -> B -> D, A -> C -> D (no cycles) - let mut graph: DirectedGraph = DirectedGraph::new(); - let a = graph.add_node('A'); - let b = graph.add_node('B'); - let c = graph.add_node('C'); - let d = graph.add_node('D'); - - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(a, c, ()).unwrap(); - graph.add_edge(b, d, ()).unwrap(); - graph.add_edge(c, d, ()).unwrap(); - - let sccs = strongly_connected_components(&graph); - - // Each node is its own SCC - assert_eq!(sccs.len(), 4); - for scc in &sccs { - assert_eq!(scc.len(), 1); - } - } - - #[test] - fn test_scc_figure_eight() { - // Figure-8 pattern: A <-> B, B -> C, C <-> D - // This creates two SCCs connected through B and C - let mut graph: DirectedGraph = DirectedGraph::new(); - let a = graph.add_node('A'); - let b = graph.add_node('B'); - let c = graph.add_node('C'); - let d = graph.add_node('D'); - - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, a, ()).unwrap(); - graph.add_edge(b, c, ()).unwrap(); - graph.add_edge(c, d, ()).unwrap(); - graph.add_edge(d, c, ()).unwrap(); - - let sccs = strongly_connected_components(&graph); - - // Two SCCs: {A, B} and {C, D} - assert_eq!(sccs.len(), 2); - } - - #[test] - fn test_scc_reverse_topological_order() { - // Chain with cycles: (A <-> B) -> (C <-> D) -> E - let mut graph: DirectedGraph = DirectedGraph::new(); - let a = graph.add_node('A'); - let b = graph.add_node('B'); - let c = graph.add_node('C'); - let d = graph.add_node('D'); - let e = graph.add_node('E'); - - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, a, ()).unwrap(); - graph.add_edge(b, c, ()).unwrap(); - graph.add_edge(c, d, ()).unwrap(); - graph.add_edge(d, c, ()).unwrap(); - graph.add_edge(d, e, ()).unwrap(); - - let sccs = strongly_connected_components(&graph); - - assert_eq!(sccs.len(), 3); - - // Find which SCC each node belongs to - let find_scc = - |node: NodeId| -> usize { sccs.iter().position(|scc| scc.contains(&node)).unwrap() }; - - let scc_ab = find_scc(a); - let scc_cd = find_scc(c); - let scc_e = find_scc(e); - - // In reverse topological order: E comes first, then CD, then AB - assert!(scc_e < scc_cd); - assert!(scc_cd < scc_ab); - } - - #[test] - fn test_scc_large_cycle() { - // Large cycle: 0 -> 1 -> 2 -> ... -> 99 -> 0 - let mut graph: DirectedGraph = DirectedGraph::new(); - let nodes: Vec = (0..100).map(|i| graph.add_node(i)).collect(); - - for i in 0..100 { - graph.add_edge(nodes[i], nodes[(i + 1) % 100], ()).unwrap(); - } - - let sccs = strongly_connected_components(&graph); - - // All 100 nodes form one SCC - assert_eq!(sccs.len(), 1); - assert_eq!(sccs[0].len(), 100); - } - - #[test] - fn test_condensation_basic() { - // A <-> B -> C (single edge to C) - let mut graph: DirectedGraph = DirectedGraph::new(); - let a = graph.add_node('A'); - let b = graph.add_node('B'); - let c = graph.add_node('C'); - - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, a, ()).unwrap(); - graph.add_edge(b, c, ()).unwrap(); - - let sccs = strongly_connected_components(&graph); - let (node_to_scc, edges) = condensation(&graph, &sccs); - - // A and B are in the same SCC - assert_eq!(node_to_scc[a.index()], node_to_scc[b.index()]); - - // C is in a different SCC - assert_ne!(node_to_scc[a.index()], node_to_scc[c.index()]); - - // There's one edge in the condensation graph (from {A,B} SCC to {C} SCC) - assert_eq!(edges.len(), 1); - } - - #[test] - fn test_condensation_no_edges() { - // Two disconnected cycles - let mut graph: DirectedGraph = DirectedGraph::new(); - let a = graph.add_node('A'); - let b = graph.add_node('B'); - let c = graph.add_node('C'); - let d = graph.add_node('D'); - - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, a, ()).unwrap(); - graph.add_edge(c, d, ()).unwrap(); - graph.add_edge(d, c, ()).unwrap(); - - let sccs = strongly_connected_components(&graph); - let (_, edges) = condensation(&graph, &sccs); - - // No edges between SCCs - assert!(edges.is_empty()); - } - - #[test] - fn test_condensation_chain() { - // Chain of SCCs: (A<->B) -> (C<->D) -> E - let mut graph: DirectedGraph = DirectedGraph::new(); - let a = graph.add_node('A'); - let b = graph.add_node('B'); - let c = graph.add_node('C'); - let d = graph.add_node('D'); - let e = graph.add_node('E'); - - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, a, ()).unwrap(); - graph.add_edge(b, c, ()).unwrap(); - graph.add_edge(c, d, ()).unwrap(); - graph.add_edge(d, c, ()).unwrap(); - graph.add_edge(d, e, ()).unwrap(); - - let sccs = strongly_connected_components(&graph); - let (node_to_scc, edges) = condensation(&graph, &sccs); - - // Verify SCC assignments - assert_eq!(node_to_scc[a.index()], node_to_scc[b.index()]); - assert_eq!(node_to_scc[c.index()], node_to_scc[d.index()]); - assert_ne!(node_to_scc[a.index()], node_to_scc[c.index()]); - assert_ne!(node_to_scc[c.index()], node_to_scc[e.index()]); - - // Two edges in condensation: AB->CD, CD->E - assert_eq!(edges.len(), 2); - } - - #[test] - fn test_scc_disconnected_graph() { - // Completely disconnected nodes - let mut graph: DirectedGraph = DirectedGraph::new(); - let _a = graph.add_node('A'); - let _b = graph.add_node('B'); - let _c = graph.add_node('C'); - - let sccs = strongly_connected_components(&graph); - - // Each node is its own SCC - assert_eq!(sccs.len(), 3); - for scc in &sccs { - assert_eq!(scc.len(), 1); - } - } - - #[test] - fn test_scc_complex_structure() { - // Complex graph with multiple SCCs - // - // +---+ - // v | - // A-->B-->C - // | | - // v v - // D<->E-->F - // | - // v - // G - let mut graph: DirectedGraph = DirectedGraph::new(); - let a = graph.add_node('A'); - let b = graph.add_node('B'); - let c = graph.add_node('C'); - let d = graph.add_node('D'); - let e = graph.add_node('E'); - let f = graph.add_node('F'); - let g = graph.add_node('G'); - - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, c, ()).unwrap(); - graph.add_edge(c, b, ()).unwrap(); // B <-> C cycle - graph.add_edge(a, d, ()).unwrap(); - graph.add_edge(b, e, ()).unwrap(); - graph.add_edge(d, e, ()).unwrap(); - graph.add_edge(e, d, ()).unwrap(); // D <-> E cycle - graph.add_edge(e, f, ()).unwrap(); - graph.add_edge(f, g, ()).unwrap(); - - let sccs = strongly_connected_components(&graph); - - // SCCs: {B, C}, {D, E}, {A}, {F}, {G} - assert_eq!(sccs.len(), 5); - - // Count SCC sizes - let mut size_counts = std::collections::HashMap::new(); - for scc in &sccs { - *size_counts.entry(scc.len()).or_insert(0) += 1; - } - - // Two SCCs of size 2, three of size 1 - assert_eq!(size_counts.get(&2), Some(&2)); - assert_eq!(size_counts.get(&1), Some(&3)); - } -} diff --git a/dotscope/src/utils/graph/algorithms/topological.rs b/dotscope/src/utils/graph/algorithms/topological.rs deleted file mode 100644 index 1e6cab3d..00000000 --- a/dotscope/src/utils/graph/algorithms/topological.rs +++ /dev/null @@ -1,336 +0,0 @@ -//! Topological sorting for directed acyclic graphs (DAGs). -//! -//! This module provides Kahn's algorithm for computing a topological ordering -//! of nodes in a directed acyclic graph. A topological ordering is a linear -//! ordering of vertices such that for every directed edge (u, v), vertex u -//! comes before v in the ordering. -//! -//! # Use Cases -//! -//! - Dependency resolution (build systems, package managers) -//! - Task scheduling with precedence constraints -//! - Ordering metadata loader execution -//! - Data flow analysis iteration ordering - -use std::collections::VecDeque; - -use crate::utils::graph::{GraphBase, NodeId, Predecessors, Successors}; - -/// Computes a topological ordering of nodes reachable from any entry node. -/// -/// Uses Kahn's algorithm which processes nodes with no incoming edges first, -/// then removes those nodes and repeats. This produces a valid topological -/// ordering if and only if the graph is acyclic. -/// -/// # Arguments -/// -/// * `graph` - The graph to sort topologically -/// -/// # Returns -/// -/// `Some(Vec)` containing nodes in topological order if the graph is -/// acyclic (a DAG), `None` if the graph contains a cycle. -/// -/// # Complexity -/// -/// - Time: O(V + E) where V is the number of vertices and E is the number of edges -/// - Space: O(V) for the in-degree counts and queue -/// -/// # Algorithm -/// -/// 1. Compute in-degree for all nodes -/// 2. Initialize queue with all nodes having in-degree 0 -/// 3. While queue is not empty: -/// - Remove a node from the queue and add to result -/// - For each successor, decrement its in-degree -/// - If successor's in-degree becomes 0, add to queue -/// 4. If result contains all nodes, return it; otherwise graph has a cycle -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::graph::{DirectedGraph, NodeId, algorithms::topological_sort}; -/// -/// // A simple DAG: A -> B -> D, A -> C -> D -/// let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); -/// let a = graph.add_node("A"); -/// let b = graph.add_node("B"); -/// let c = graph.add_node("C"); -/// let d = graph.add_node("D"); -/// -/// graph.add_edge(a, b, ()); -/// graph.add_edge(a, c, ()); -/// graph.add_edge(b, d, ()); -/// graph.add_edge(c, d, ()); -/// -/// let order = topological_sort(&graph); -/// assert!(order.is_some()); -/// -/// let order = order.unwrap(); -/// // A must come before B, C; B and C must come before D -/// let a_pos = order.iter().position(|&n| n == a).unwrap(); -/// let d_pos = order.iter().position(|&n| n == d).unwrap(); -/// assert!(a_pos < d_pos); -/// ``` -/// -/// # Cyclic Graph Example -/// -/// ```rust,ignore -/// use dotscope::graph::{DirectedGraph, NodeId, algorithms::topological_sort}; -/// -/// // A graph with a cycle: A -> B -> C -> A -/// let mut graph: DirectedGraph<(), ()> = DirectedGraph::new(); -/// let a = graph.add_node(()); -/// let b = graph.add_node(()); -/// let c = graph.add_node(()); -/// -/// graph.add_edge(a, b, ()); -/// graph.add_edge(b, c, ()); -/// graph.add_edge(c, a, ()); -/// -/// // Cannot topologically sort a graph with cycles -/// assert!(topological_sort(&graph).is_none()); -/// ``` -pub fn topological_sort(graph: &G) -> Option> -where - G: GraphBase + Successors + Predecessors, -{ - let node_count = graph.node_count(); - if node_count == 0 { - return Some(Vec::new()); - } - - // Compute in-degrees - let mut in_degree: Vec = vec![0; node_count]; - for node in graph.node_ids() { - for _ in graph.predecessors(node) { - let slot = in_degree.get_mut(node.index())?; - *slot = slot.saturating_add(1); - } - } - - // Initialize queue with nodes having in-degree 0 - let mut queue: VecDeque = VecDeque::new(); - for node in graph.node_ids() { - if *in_degree.get(node.index())? == 0 { - queue.push_back(node); - } - } - - let mut result = Vec::with_capacity(node_count); - - while let Some(node) = queue.pop_front() { - result.push(node); - - for successor in graph.successors(node) { - let slot = in_degree.get_mut(successor.index())?; - *slot = slot.saturating_sub(1); - if *slot == 0 { - queue.push_back(successor); - } - } - } - - // If we didn't process all nodes, there must be a cycle - if result.len() == node_count { - Some(result) - } else { - None - } -} - -#[cfg(test)] -mod tests { - use crate::utils::graph::{algorithms::topological::topological_sort, DirectedGraph, NodeId}; - - #[test] - fn test_topological_sort_empty_graph() { - let graph: DirectedGraph<(), ()> = DirectedGraph::new(); - let result = topological_sort(&graph); - assert_eq!(result, Some(Vec::new())); - } - - #[test] - fn test_topological_sort_single_node() { - let mut graph: DirectedGraph<(), ()> = DirectedGraph::new(); - let a = graph.add_node(()); - let result = topological_sort(&graph); - assert_eq!(result, Some(vec![a])); - } - - #[test] - fn test_topological_sort_linear() { - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - let c = graph.add_node("C"); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, c, ()).unwrap(); - - let result = topological_sort(&graph); - assert!(result.is_some()); - assert_eq!(result.unwrap(), vec![a, b, c]); - } - - #[test] - fn test_topological_sort_diamond() { - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - let c = graph.add_node("C"); - let d = graph.add_node("D"); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(a, c, ()).unwrap(); - graph.add_edge(b, d, ()).unwrap(); - graph.add_edge(c, d, ()).unwrap(); - - let result = topological_sort(&graph); - assert!(result.is_some()); - - let order = result.unwrap(); - assert_eq!(order.len(), 4); - - // Verify ordering constraints - let pos = |n: NodeId| order.iter().position(|&x| x == n).unwrap(); - assert!(pos(a) < pos(b)); - assert!(pos(a) < pos(c)); - assert!(pos(b) < pos(d)); - assert!(pos(c) < pos(d)); - } - - #[test] - fn test_topological_sort_simple_cycle() { - let mut graph: DirectedGraph<(), ()> = DirectedGraph::new(); - let a = graph.add_node(()); - let b = graph.add_node(()); - let c = graph.add_node(()); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, c, ()).unwrap(); - graph.add_edge(c, a, ()).unwrap(); - - assert!(topological_sort(&graph).is_none()); - } - - #[test] - fn test_topological_sort_self_loop() { - let mut graph: DirectedGraph<(), ()> = DirectedGraph::new(); - let a = graph.add_node(()); - graph.add_edge(a, a, ()).unwrap(); - - assert!(topological_sort(&graph).is_none()); - } - - #[test] - fn test_topological_sort_disconnected_components() { - // Two separate chains: A -> B and C -> D - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - let c = graph.add_node("C"); - let d = graph.add_node("D"); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(c, d, ()).unwrap(); - - let result = topological_sort(&graph); - assert!(result.is_some()); - - let order = result.unwrap(); - assert_eq!(order.len(), 4); - - // Verify ordering within each chain - let pos = |n: NodeId| order.iter().position(|&x| x == n).unwrap(); - assert!(pos(a) < pos(b)); - assert!(pos(c) < pos(d)); - } - - #[test] - fn test_topological_sort_partial_cycle() { - // A -> B -> C -> D - // ^ | - // +-------+ (cycle B-C-D-B, but A is before the cycle) - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - let c = graph.add_node("C"); - let d = graph.add_node("D"); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, c, ()).unwrap(); - graph.add_edge(c, d, ()).unwrap(); - graph.add_edge(d, b, ()).unwrap(); - - // Has a cycle, so should fail - assert!(topological_sort(&graph).is_none()); - } - - #[test] - fn test_topological_sort_multiple_valid_orderings() { - // A -> C, B -> C (A and B have no ordering constraint) - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - let c = graph.add_node("C"); - graph.add_edge(a, c, ()).unwrap(); - graph.add_edge(b, c, ()).unwrap(); - - let result = topological_sort(&graph); - assert!(result.is_some()); - - let order = result.unwrap(); - assert_eq!(order.len(), 3); - - // C must be last, but A and B can be in either order - let pos = |n: NodeId| order.iter().position(|&x| x == n).unwrap(); - assert!(pos(a) < pos(c)); - assert!(pos(b) < pos(c)); - } - - #[test] - fn test_topological_sort_wide_dag() { - // Root -> [A, B, C, D, E] (many independent children) - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let root = graph.add_node("Root"); - let children: Vec = (0..5) - .map(|_| { - let child = graph.add_node("Child"); - graph.add_edge(root, child, ()).unwrap(); - child - }) - .collect(); - - let result = topological_sort(&graph); - assert!(result.is_some()); - - let order = result.unwrap(); - assert_eq!(order.len(), 6); - - // Root must come first - assert_eq!(order[0], root); - - // All children must come after root - for child in children { - assert!(order.contains(&child)); - } - } - - #[test] - fn test_topological_sort_deep_dag() { - // A chain: 0 -> 1 -> 2 -> ... -> 99 - let mut graph: DirectedGraph = DirectedGraph::new(); - let nodes: Vec = (0..100).map(|i| graph.add_node(i)).collect(); - - for i in 0..99 { - graph.add_edge(nodes[i], nodes[i + 1], ()).unwrap(); - } - - let result = topological_sort(&graph); - assert!(result.is_some()); - - let order = result.unwrap(); - assert_eq!(order.len(), 100); - - // Must be in exact order - for i in 0..100 { - assert_eq!(order[i], nodes[i]); - } - } -} diff --git a/dotscope/src/utils/graph/algorithms/traversal.rs b/dotscope/src/utils/graph/algorithms/traversal.rs deleted file mode 100644 index d74915de..00000000 --- a/dotscope/src/utils/graph/algorithms/traversal.rs +++ /dev/null @@ -1,748 +0,0 @@ -//! Graph traversal algorithms. -//! -//! This module provides depth-first and breadth-first traversal algorithms -//! for directed graphs. These are fundamental building blocks for more complex -//! graph algorithms and program analysis. -//! -//! # Algorithms -//! -//! - [`dfs`] - Iterative depth-first search (pre-order) -//! - [`bfs`] - Breadth-first search -//! - [`postorder`] - Depth-first search with post-order visitation -//! - [`reverse_postorder`] - Reverse post-order (useful for forward data flow) -//! -//! # Iteration vs Collection -//! -//! The [`dfs`] and [`bfs`] functions return iterators for lazy evaluation, -//! avoiding unnecessary allocations when only partial traversal is needed. -//! The [`postorder`] and [`reverse_postorder`] functions return collected -//! vectors since the order requires full traversal anyway. - -use std::collections::VecDeque; - -use crate::utils::graph::{NodeId, Successors}; - -/// Depth-first search iterator over graph nodes. -/// -/// This iterator performs an iterative (non-recursive) depth-first traversal -/// starting from a given node. It visits each reachable node exactly once -/// in pre-order (visiting a node before its descendants). -/// -/// # Type Parameters -/// -/// * `'g` - Lifetime of the graph reference -/// * `G` - Graph type implementing [`Successors`] -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::graph::{DirectedGraph, NodeId, algorithms::dfs}; -/// -/// let mut graph: DirectedGraph = DirectedGraph::new(); -/// let a = graph.add_node('A'); -/// let b = graph.add_node('B'); -/// let c = graph.add_node('C'); -/// graph.add_edge(a, b, ()); -/// graph.add_edge(a, c, ()); -/// -/// let visited: Vec = dfs(&graph, a).collect(); -/// assert_eq!(visited.len(), 3); -/// assert_eq!(visited[0], a); // A is visited first -/// ``` -pub struct DfsIterator<'g, G: Successors> { - graph: &'g G, - stack: Vec, - visited: Vec, -} - -impl<'g, G: Successors> DfsIterator<'g, G> { - fn new(graph: &'g G, start: NodeId) -> Self { - let node_count = graph.node_count(); - if start.index() >= node_count { - return DfsIterator { - graph, - stack: Vec::new(), - visited: Vec::new(), - }; - } - - let mut visited = vec![false; node_count]; - if let Some(slot) = visited.get_mut(start.index()) { - *slot = true; - } - - DfsIterator { - graph, - stack: vec![start], - visited, - } - } -} - -impl Iterator for DfsIterator<'_, G> { - type Item = NodeId; - - fn next(&mut self) -> Option { - let node = self.stack.pop()?; - if self.visited.is_empty() { - return None; - } - - // Push unvisited successors onto the stack in reverse order - // so that they are visited in the original order - let successors: Vec = self.graph.successors(node).collect(); - for &succ in successors.iter().rev() { - if let Some(slot) = self.visited.get_mut(succ.index()) { - if !*slot { - *slot = true; - self.stack.push(succ); - } - } - } - - Some(node) - } -} - -/// Returns a depth-first search iterator starting from the given node. -/// -/// The iterator visits each reachable node exactly once in pre-order -/// (visiting a node before its descendants). Nodes not reachable from -/// the start node are not visited. -/// -/// # Arguments -/// -/// * `graph` - The graph to traverse -/// * `start` - The starting node for traversal -/// -/// # Returns -/// -/// An iterator yielding `NodeId` in DFS pre-order. -/// -/// # Complexity -/// -/// - Time: O(V + E) where V is the number of vertices and E is the number of edges -/// - Space: O(V) for the visited set and stack -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::graph::{DirectedGraph, NodeId, algorithms::dfs}; -/// -/// let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); -/// let a = graph.add_node("A"); -/// let b = graph.add_node("B"); -/// let c = graph.add_node("C"); -/// graph.add_edge(a, b, ()); -/// graph.add_edge(b, c, ()); -/// -/// // Visit nodes in DFS order -/// for node in dfs(&graph, a) { -/// println!("Visiting {:?}", node); -/// } -/// -/// // Collect all reachable nodes -/// let reachable: Vec = dfs(&graph, a).collect(); -/// assert_eq!(reachable.len(), 3); -/// ``` -pub fn dfs(graph: &G, start: NodeId) -> DfsIterator<'_, G> { - DfsIterator::new(graph, start) -} - -/// Breadth-first search iterator over graph nodes. -/// -/// This iterator performs a breadth-first traversal starting from a given node. -/// It visits each reachable node exactly once, exploring all nodes at distance d -/// before visiting any node at distance d+1. -/// -/// # Type Parameters -/// -/// * `'g` - Lifetime of the graph reference -/// * `G` - Graph type implementing [`Successors`] -pub struct BfsIterator<'g, G: Successors> { - graph: &'g G, - queue: VecDeque, - visited: Vec, -} - -impl<'g, G: Successors> BfsIterator<'g, G> { - fn new(graph: &'g G, start: NodeId) -> Self { - let node_count = graph.node_count(); - if start.index() >= node_count { - return BfsIterator { - graph, - queue: VecDeque::new(), - visited: Vec::new(), - }; - } - - let mut visited = vec![false; node_count]; - if let Some(slot) = visited.get_mut(start.index()) { - *slot = true; - } - - let mut queue = VecDeque::new(); - queue.push_back(start); - - BfsIterator { - graph, - queue, - visited, - } - } -} - -impl Iterator for BfsIterator<'_, G> { - type Item = NodeId; - - fn next(&mut self) -> Option { - let node = self.queue.pop_front()?; - if self.visited.is_empty() { - return None; - } - - // Enqueue unvisited successors - for succ in self.graph.successors(node) { - if let Some(slot) = self.visited.get_mut(succ.index()) { - if !*slot { - *slot = true; - self.queue.push_back(succ); - } - } - } - - Some(node) - } -} - -/// Returns a breadth-first search iterator starting from the given node. -/// -/// The iterator visits each reachable node exactly once, exploring nodes -/// in order of increasing distance from the start. This is useful for -/// finding shortest paths in unweighted graphs. -/// -/// # Arguments -/// -/// * `graph` - The graph to traverse -/// * `start` - The starting node for traversal -/// -/// # Returns -/// -/// An iterator yielding `NodeId` in BFS order. -/// -/// # Complexity -/// -/// - Time: O(V + E) where V is the number of vertices and E is the number of edges -/// - Space: O(V) for the visited set and queue -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::graph::{DirectedGraph, NodeId, algorithms::bfs}; -/// -/// let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); -/// let a = graph.add_node("A"); -/// let b = graph.add_node("B"); -/// let c = graph.add_node("C"); -/// let d = graph.add_node("D"); -/// graph.add_edge(a, b, ()); -/// graph.add_edge(a, c, ()); -/// graph.add_edge(b, d, ()); -/// graph.add_edge(c, d, ()); -/// -/// // BFS visits by distance from start -/// let order: Vec = bfs(&graph, a).collect(); -/// assert_eq!(order[0], a); // Distance 0 -/// // B and C are at distance 1 (order may vary) -/// // D is at distance 2 -/// assert_eq!(order[3], d); -/// ``` -pub fn bfs(graph: &G, start: NodeId) -> BfsIterator<'_, G> { - BfsIterator::new(graph, start) -} - -/// Computes the postorder traversal of nodes reachable from the start. -/// -/// In postorder, a node is visited after all its descendants have been visited. -/// This is useful for algorithms that need to process children before parents. -/// -/// # Arguments -/// -/// * `graph` - The graph to traverse -/// * `start` - The starting node for traversal -/// -/// # Returns -/// -/// A vector of `NodeId` in postorder. -/// -/// # Complexity -/// -/// - Time: O(V + E) -/// - Space: O(V) -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::graph::{DirectedGraph, NodeId, algorithms::postorder}; -/// -/// let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); -/// let a = graph.add_node("A"); -/// let b = graph.add_node("B"); -/// let c = graph.add_node("C"); -/// graph.add_edge(a, b, ()); -/// graph.add_edge(b, c, ()); -/// -/// let order = postorder(&graph, a); -/// // C comes before B, B comes before A -/// assert_eq!(order, vec![c, b, a]); -/// ``` -#[allow(clippy::items_after_statements)] -pub fn postorder(graph: &G, start: NodeId) -> Vec { - let node_count = graph.node_count(); - - // Validate start node - return empty vec if invalid - if start.index() >= node_count { - return Vec::new(); - } - - let mut visited = vec![false; node_count]; - let mut result = Vec::with_capacity(node_count); - - // Iterative postorder using explicit stack with state - #[derive(Clone, Copy)] - enum State { - Enter, - Exit, - } - - let mut stack = vec![(start, State::Enter)]; - - while let Some((node, state)) = stack.pop() { - match state { - State::Enter => { - let Some(slot) = visited.get_mut(node.index()) else { - continue; - }; - if *slot { - continue; - } - *slot = true; - - // Push exit state for this node (will be processed after children) - stack.push((node, State::Exit)); - - // Push children in reverse order so they're processed in order - let successors: Vec = graph.successors(node).collect(); - for &succ in successors.iter().rev() { - if let Some(false) = visited.get(succ.index()).copied() { - stack.push((succ, State::Enter)); - } - } - } - State::Exit => { - result.push(node); - } - } - } - - result -} - -/// Computes the reverse postorder traversal of nodes reachable from the start. -/// -/// Reverse postorder (RPO) is the reverse of postorder: nodes are visited -/// such that a node comes before any of its successors (in a DAG). This is -/// the preferred iteration order for forward data flow analysis. -/// -/// # Arguments -/// -/// * `graph` - The graph to traverse -/// * `start` - The starting node for traversal -/// -/// # Returns -/// -/// A vector of `NodeId` in reverse postorder. -/// -/// # Complexity -/// -/// - Time: O(V + E) -/// - Space: O(V) -/// -/// # Use Cases -/// -/// - **Forward data flow analysis**: Iterating in RPO ensures that when analyzing -/// a node, all its predecessors (in a DAG) have already been analyzed -/// - **Dominance frontier computation**: RPO ensures correct order for iterative algorithms -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::graph::{DirectedGraph, NodeId, algorithms::reverse_postorder}; -/// -/// let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); -/// let a = graph.add_node("A"); -/// let b = graph.add_node("B"); -/// let c = graph.add_node("C"); -/// graph.add_edge(a, b, ()); -/// graph.add_edge(b, c, ()); -/// -/// let order = reverse_postorder(&graph, a); -/// // A comes before B, B comes before C -/// assert_eq!(order, vec![a, b, c]); -/// ``` -pub fn reverse_postorder(graph: &G, start: NodeId) -> Vec { - let mut result = postorder(graph, start); - result.reverse(); - result -} - -#[cfg(test)] -mod tests { - use crate::utils::graph::{ - algorithms::traversal::{bfs, dfs, postorder, reverse_postorder}, - DirectedGraph, NodeId, - }; - - fn create_linear_graph() -> DirectedGraph<'static, &'static str, ()> { - let mut graph = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - let c = graph.add_node("C"); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, c, ()).unwrap(); - graph - } - - fn create_diamond_graph() -> DirectedGraph<'static, &'static str, ()> { - let mut graph = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - let c = graph.add_node("C"); - let d = graph.add_node("D"); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(a, c, ()).unwrap(); - graph.add_edge(b, d, ()).unwrap(); - graph.add_edge(c, d, ()).unwrap(); - graph - } - - fn create_cycle_graph() -> DirectedGraph<'static, &'static str, ()> { - let mut graph = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - let c = graph.add_node("C"); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, c, ()).unwrap(); - graph.add_edge(c, a, ()).unwrap(); - graph - } - - fn create_tree_graph() -> DirectedGraph<'static, &'static str, ()> { - // A - // / \ - // B C - // / \ \ - // D E F - let mut graph = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - let c = graph.add_node("C"); - let d = graph.add_node("D"); - let e = graph.add_node("E"); - let f = graph.add_node("F"); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(a, c, ()).unwrap(); - graph.add_edge(b, d, ()).unwrap(); - graph.add_edge(b, e, ()).unwrap(); - graph.add_edge(c, f, ()).unwrap(); - graph - } - - #[test] - fn test_dfs_linear() { - let graph = create_linear_graph(); - let order: Vec = dfs(&graph, NodeId::new(0)).collect(); - assert_eq!(order, vec![NodeId::new(0), NodeId::new(1), NodeId::new(2)]); - } - - #[test] - fn test_dfs_diamond() { - let graph = create_diamond_graph(); - let order: Vec = dfs(&graph, NodeId::new(0)).collect(); - - // Should visit all 4 nodes - assert_eq!(order.len(), 4); - - // A should be first - assert_eq!(order[0], NodeId::new(0)); - - // D should be visited after both B and C are on the path - // The exact order depends on implementation, but D should be reachable - assert!(order.contains(&NodeId::new(3))); - } - - #[test] - fn test_dfs_cycle() { - let graph = create_cycle_graph(); - let order: Vec = dfs(&graph, NodeId::new(0)).collect(); - - // Should visit each node exactly once despite the cycle - assert_eq!(order.len(), 3); - assert_eq!(order[0], NodeId::new(0)); - } - - #[test] - fn test_dfs_tree() { - let graph = create_tree_graph(); - let order: Vec = dfs(&graph, NodeId::new(0)).collect(); - - // Should visit all 6 nodes - assert_eq!(order.len(), 6); - - // A should be first - assert_eq!(order[0], NodeId::new(0)); - } - - #[test] - fn test_dfs_single_node() { - let mut graph: DirectedGraph<(), ()> = DirectedGraph::new(); - let a = graph.add_node(()); - - let order: Vec = dfs(&graph, a).collect(); - assert_eq!(order, vec![a]); - } - - #[test] - fn test_dfs_disconnected() { - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - let _c = graph.add_node("C"); // Disconnected - - graph.add_edge(a, b, ()).unwrap(); - - let order: Vec = dfs(&graph, a).collect(); - - // Should only visit A and B, not C - assert_eq!(order.len(), 2); - assert!(!order.contains(&NodeId::new(2))); - } - - #[test] - fn test_bfs_linear() { - let graph = create_linear_graph(); - let order: Vec = bfs(&graph, NodeId::new(0)).collect(); - assert_eq!(order, vec![NodeId::new(0), NodeId::new(1), NodeId::new(2)]); - } - - #[test] - fn test_bfs_diamond() { - let graph = create_diamond_graph(); - let order: Vec = bfs(&graph, NodeId::new(0)).collect(); - - // Should visit all 4 nodes - assert_eq!(order.len(), 4); - - // A should be first (distance 0) - assert_eq!(order[0], NodeId::new(0)); - - // B and C should be next (distance 1) - assert!(order[1] == NodeId::new(1) || order[1] == NodeId::new(2)); - assert!(order[2] == NodeId::new(1) || order[2] == NodeId::new(2)); - - // D should be last (distance 2) - assert_eq!(order[3], NodeId::new(3)); - } - - #[test] - fn test_bfs_tree() { - let graph = create_tree_graph(); - let order: Vec = bfs(&graph, NodeId::new(0)).collect(); - - // Should visit all 6 nodes - assert_eq!(order.len(), 6); - - // A should be first (level 0) - assert_eq!(order[0], NodeId::new(0)); - - // B and C should be next (level 1) - let level_1: Vec = order[1..3].to_vec(); - assert!(level_1.contains(&NodeId::new(1))); - assert!(level_1.contains(&NodeId::new(2))); - - // D, E, F should be last (level 2) - let level_2: Vec = order[3..6].to_vec(); - assert!(level_2.contains(&NodeId::new(3))); - assert!(level_2.contains(&NodeId::new(4))); - assert!(level_2.contains(&NodeId::new(5))); - } - - #[test] - fn test_bfs_single_node() { - let mut graph: DirectedGraph<(), ()> = DirectedGraph::new(); - let a = graph.add_node(()); - - let order: Vec = bfs(&graph, a).collect(); - assert_eq!(order, vec![a]); - } - - #[test] - fn test_bfs_disconnected() { - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - let _c = graph.add_node("C"); // Disconnected - - graph.add_edge(a, b, ()).unwrap(); - - let order: Vec = bfs(&graph, a).collect(); - - // Should only visit A and B, not C - assert_eq!(order.len(), 2); - } - - #[test] - fn test_postorder_linear() { - let graph = create_linear_graph(); - let order = postorder(&graph, NodeId::new(0)); - - // Postorder: children before parents - // C, B, A - assert_eq!(order, vec![NodeId::new(2), NodeId::new(1), NodeId::new(0)]); - } - - #[test] - fn test_postorder_diamond() { - let graph = create_diamond_graph(); - let order = postorder(&graph, NodeId::new(0)); - - // All 4 nodes should be visited - assert_eq!(order.len(), 4); - - // A should be last (root) - assert_eq!(*order.last().unwrap(), NodeId::new(0)); - - // D should appear before both B and C (since it's their child) - let d_pos = order.iter().position(|&n| n == NodeId::new(3)).unwrap(); - let b_pos = order.iter().position(|&n| n == NodeId::new(1)).unwrap(); - let c_pos = order.iter().position(|&n| n == NodeId::new(2)).unwrap(); - - assert!(d_pos < b_pos || d_pos < c_pos); - } - - #[test] - fn test_postorder_tree() { - let graph = create_tree_graph(); - let order = postorder(&graph, NodeId::new(0)); - - // All 6 nodes - assert_eq!(order.len(), 6); - - // A should be last - assert_eq!(*order.last().unwrap(), NodeId::new(0)); - - // Leaves should come before their parents - // D and E should come before B - let d_pos = order.iter().position(|&n| n == NodeId::new(3)).unwrap(); - let e_pos = order.iter().position(|&n| n == NodeId::new(4)).unwrap(); - let b_pos = order.iter().position(|&n| n == NodeId::new(1)).unwrap(); - - assert!(d_pos < b_pos); - assert!(e_pos < b_pos); - } - - #[test] - fn test_reverse_postorder_linear() { - let graph = create_linear_graph(); - let order = reverse_postorder(&graph, NodeId::new(0)); - - // Reverse postorder: parents before children - // A, B, C - assert_eq!(order, vec![NodeId::new(0), NodeId::new(1), NodeId::new(2)]); - } - - #[test] - fn test_reverse_postorder_diamond() { - let graph = create_diamond_graph(); - let order = reverse_postorder(&graph, NodeId::new(0)); - - // All 4 nodes - assert_eq!(order.len(), 4); - - // A should be first - assert_eq!(order[0], NodeId::new(0)); - - // D should be last - assert_eq!(*order.last().unwrap(), NodeId::new(3)); - } - - #[test] - fn test_reverse_postorder_tree() { - let graph = create_tree_graph(); - let order = reverse_postorder(&graph, NodeId::new(0)); - - // All 6 nodes - assert_eq!(order.len(), 6); - - // A should be first - assert_eq!(order[0], NodeId::new(0)); - - // Parents should come before children - let a_pos = order.iter().position(|&n| n == NodeId::new(0)).unwrap(); - let b_pos = order.iter().position(|&n| n == NodeId::new(1)).unwrap(); - let d_pos = order.iter().position(|&n| n == NodeId::new(3)).unwrap(); - - assert!(a_pos < b_pos); - assert!(b_pos < d_pos); - } - - #[test] - fn test_reverse_postorder_with_cycle() { - let graph = create_cycle_graph(); - let order = reverse_postorder(&graph, NodeId::new(0)); - - // Should still visit all nodes exactly once - assert_eq!(order.len(), 3); - } - - #[test] - fn test_self_loop() { - let mut graph: DirectedGraph<(), ()> = DirectedGraph::new(); - let a = graph.add_node(()); - graph.add_edge(a, a, ()).unwrap(); - - // DFS should visit the node exactly once - let dfs_order: Vec = dfs(&graph, a).collect(); - assert_eq!(dfs_order, vec![a]); - - // BFS should visit the node exactly once - let bfs_order: Vec = bfs(&graph, a).collect(); - assert_eq!(bfs_order, vec![a]); - - // Postorder should have the node once - let post_order = postorder(&graph, a); - assert_eq!(post_order, vec![a]); - } - - #[test] - fn test_multiple_edges() { - let mut graph: DirectedGraph<(), ()> = DirectedGraph::new(); - let a = graph.add_node(()); - let b = graph.add_node(()); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(a, b, ()).unwrap(); // Duplicate edge - - // Should still visit B only once - let order: Vec = dfs(&graph, a).collect(); - assert_eq!(order, vec![a, b]); - } - - #[test] - fn test_iterator_early_termination() { - let graph = create_tree_graph(); - - // Take only first 3 nodes - let partial: Vec = dfs(&graph, NodeId::new(0)).take(3).collect(); - assert_eq!(partial.len(), 3); - } -} diff --git a/dotscope/src/utils/graph/directed.rs b/dotscope/src/utils/graph/directed.rs deleted file mode 100644 index 60b93175..00000000 --- a/dotscope/src/utils/graph/directed.rs +++ /dev/null @@ -1,1381 +0,0 @@ -//! Core directed graph implementation. -//! -//! This module provides [`DirectedGraph`], the primary graph data structure used -//! throughout the analysis infrastructure. The implementation uses adjacency lists -//! for efficient traversal while maintaining full edge data access. -//! -//! The graph supports both owned and borrowed node data through [`Cow`], enabling -//! zero-copy graph construction when nodes are borrowed from external storage. - -use std::borrow::Cow; - -use crate::{ - utils::graph::{ - edge::EdgeId, - node::NodeId, - traits::{GraphBase, Predecessors, Successors}, - }, - Error, Result, -}; - -/// Internal storage for edge data and endpoints. -#[derive(Debug, Clone)] -struct EdgeData { - /// Source node of the edge - source: NodeId, - /// Target node of the edge - target: NodeId, - /// User-provided edge data - data: E, -} - -/// A directed graph with typed node and edge data. -/// -/// `DirectedGraph` provides a flexible, efficient graph implementation suitable for -/// program analysis tasks. It supports: -/// -/// - Generic node data (`N`) - Store any data associated with each node -/// - Generic edge data (`E`) - Store any data associated with each edge -/// - Efficient adjacency queries via adjacency lists -/// - Both forward (successors) and backward (predecessors) traversal -/// - Borrowed or owned node storage via [`Cow`] -/// -/// # Memory Layout -/// -/// The graph uses separate storage for nodes and edges: -/// -/// - Nodes are stored in a [`Cow`] slice, allowing borrowed or owned data -/// - Edges are stored in a contiguous vector indexed by `EdgeId` -/// - Adjacency lists (outgoing/incoming) store `EdgeId` references -/// -/// This design provides O(1) node/edge access and efficient iteration. -/// -/// # Lifetime Parameter -/// -/// The `'a` lifetime parameter represents the lifetime of borrowed node data: -/// - Use `DirectedGraph<'static, N, E>` for owned graphs (nodes are `Cow::Owned`) -/// - Use `DirectedGraph<'a, N, E>` when borrowing nodes from external storage -/// -/// # Thread Safety -/// -/// `DirectedGraph` is [`Send`] and [`Sync`] when both `N` and `E` are, -/// enabling safe concurrent read access after construction. The graph does not -/// support concurrent modification; build the graph single-threaded, then use -/// it immutably from multiple threads. -/// -/// # Examples -/// -/// ## Creating a Simple Graph -/// -/// ```rust,ignore -/// use dotscope::graph::{DirectedGraph, NodeId, EdgeId}; -/// -/// let mut graph: DirectedGraph<&str, i32> = DirectedGraph::new(); -/// -/// // Add nodes -/// let a = graph.add_node("A"); -/// let b = graph.add_node("B"); -/// let c = graph.add_node("C"); -/// -/// // Add edges with weights -/// graph.add_edge(a, b, 10); -/// graph.add_edge(b, c, 20); -/// graph.add_edge(a, c, 30); -/// -/// assert_eq!(graph.node_count(), 3); -/// assert_eq!(graph.edge_count(), 3); -/// ``` -/// -/// ## Traversing the Graph -/// -/// ```rust,ignore -/// use dotscope::graph::{DirectedGraph, NodeId, Successors, Predecessors}; -/// -/// let mut graph: DirectedGraph = DirectedGraph::new(); -/// let a = graph.add_node('A'); -/// let b = graph.add_node('B'); -/// let c = graph.add_node('C'); -/// -/// graph.add_edge(a, b, ()); -/// graph.add_edge(a, c, ()); -/// -/// // Forward traversal: get successors of A -/// let successors: Vec<_> = graph.successors(a).collect(); -/// assert_eq!(successors.len(), 2); -/// -/// // Backward traversal: get predecessors of B -/// let predecessors: Vec<_> = graph.predecessors(b).collect(); -/// assert_eq!(predecessors, vec![a]); -/// ``` -#[derive(Debug, Clone)] -pub struct DirectedGraph<'a, N: Clone, E> { - /// Node data storage (borrowed or owned) - nodes: Cow<'a, [N]>, - /// Edge data storage - edges: Vec>, - /// Outgoing edges per node (adjacency list for successors) - outgoing: Vec>, - /// Incoming edges per node (adjacency list for predecessors) - incoming: Vec>, -} - -impl Default for DirectedGraph<'static, N, E> { - fn default() -> Self { - Self::new() - } -} - -impl DirectedGraph<'static, N, E> { - /// Creates a new empty directed graph with owned storage. - /// - /// The graph starts with no nodes or edges. Use [`add_node`](Self::add_node) - /// and [`add_edge`](Self::add_edge) to build up the graph structure. - /// - /// # Returns - /// - /// A new empty `DirectedGraph`. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::DirectedGraph; - /// - /// let graph: DirectedGraph = DirectedGraph::new(); - /// assert!(graph.is_empty()); - /// ``` - #[must_use] - pub fn new() -> Self { - DirectedGraph { - nodes: Cow::Owned(Vec::new()), - edges: Vec::new(), - outgoing: Vec::new(), - incoming: Vec::new(), - } - } - - /// Creates a new directed graph with pre-allocated capacity. - /// - /// Pre-allocating capacity can improve performance when the approximate - /// size of the graph is known in advance, by avoiding reallocations - /// during construction. - /// - /// # Arguments - /// - /// * `node_capacity` - Expected number of nodes - /// * `edge_capacity` - Expected number of edges - /// - /// # Returns - /// - /// A new empty `DirectedGraph` with pre-allocated storage. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::DirectedGraph; - /// - /// // Pre-allocate for a graph with ~100 nodes and ~300 edges - /// let graph: DirectedGraph = DirectedGraph::with_capacity(100, 300); - /// assert!(graph.is_empty()); - /// ``` - #[must_use] - pub fn with_capacity(node_capacity: usize, edge_capacity: usize) -> Self { - DirectedGraph { - nodes: Cow::Owned(Vec::with_capacity(node_capacity)), - edges: Vec::with_capacity(edge_capacity), - outgoing: Vec::with_capacity(node_capacity), - incoming: Vec::with_capacity(node_capacity), - } - } - - /// Adds a new node with the given data to the graph. - /// - /// The node is assigned the next sequential `NodeId`, starting from 0. - /// The returned `NodeId` can be used to reference this node when adding - /// edges or querying the graph. - /// - /// # Arguments - /// - /// * `data` - The data to associate with this node - /// - /// # Returns - /// - /// The `NodeId` assigned to the new node. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::{DirectedGraph, NodeId}; - /// - /// let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - /// - /// let first = graph.add_node("first"); - /// let second = graph.add_node("second"); - /// - /// assert_eq!(first, NodeId::new(0)); - /// assert_eq!(second, NodeId::new(1)); - /// assert_eq!(graph.node_count(), 2); - /// ``` - pub fn add_node(&mut self, data: N) -> NodeId { - let id = NodeId::new(self.nodes.len()); - self.nodes.to_mut().push(data); - self.outgoing.push(Vec::new()); - self.incoming.push(Vec::new()); - id - } - - /// Returns a mutable reference to the data associated with the given node. - /// - /// # Arguments - /// - /// * `node` - The node to look up - /// - /// # Returns - /// - /// `Some(&mut N)` if the node exists, `None` otherwise. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::DirectedGraph; - /// - /// let mut graph: DirectedGraph = DirectedGraph::new(); - /// let node = graph.add_node(String::from("hello")); - /// - /// if let Some(data) = graph.node_mut(node) { - /// data.push_str(" world"); - /// } - /// - /// assert_eq!(graph.node(node), Some(&String::from("hello world"))); - /// ``` - pub fn node_mut(&mut self, node: NodeId) -> Option<&mut N> { - self.nodes.to_mut().get_mut(node.index()) - } -} - -/// Methods for creating graphs with borrowed node storage. -impl<'a, N: Clone, E> DirectedGraph<'a, N, E> { - /// Creates a new directed graph borrowing nodes from an external slice. - /// - /// This enables zero-copy graph construction when nodes already exist - /// in external storage (e.g., basic blocks from a method). - /// - /// The returned graph has borrowed node storage. Edges can still be added - /// normally as they are always owned. To get an owned graph, use - /// [`into_owned`](Self::into_owned). - /// - /// # Arguments - /// - /// * `nodes` - A slice of nodes to borrow - /// - /// # Returns - /// - /// A new `DirectedGraph` with borrowed nodes and empty adjacency lists. - /// The caller must add edges separately. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::DirectedGraph; - /// - /// let nodes = vec!["A", "B", "C"]; - /// let mut graph: DirectedGraph<&str, ()> = DirectedGraph::from_nodes_borrowed(&nodes); - /// - /// // Add edges - /// graph.add_edge(NodeId::new(0), NodeId::new(1), ())?; - /// ``` - #[must_use] - pub fn from_nodes_borrowed(nodes: &'a [N]) -> Self { - let node_count = nodes.len(); - DirectedGraph { - nodes: Cow::Borrowed(nodes), - edges: Vec::new(), - outgoing: vec![Vec::new(); node_count], - incoming: vec![Vec::new(); node_count], - } - } - - /// Converts this graph into an owned graph with `'static` lifetime. - /// - /// If the nodes are already owned, this is efficient. If borrowed, - /// this clones the node data. - /// - /// # Returns - /// - /// An owned `DirectedGraph<'static, N, E>`. - #[must_use] - pub fn into_owned(self) -> DirectedGraph<'static, N, E> { - DirectedGraph { - nodes: Cow::Owned(self.nodes.into_owned()), - edges: self.edges, - outgoing: self.outgoing, - incoming: self.incoming, - } - } - - /// Returns `true` if the graph owns its node data. - /// - /// Returns `false` if nodes are borrowed from external storage. - #[must_use] - pub fn is_owned(&self) -> bool { - matches!(self.nodes, Cow::Owned(_)) - } - - /// Returns a reference to the data associated with the given node. - /// - /// # Arguments - /// - /// * `node` - The node to look up - /// - /// # Returns - /// - /// `Some(&N)` if the node exists, `None` otherwise. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::DirectedGraph; - /// - /// let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - /// let node = graph.add_node("hello"); - /// - /// assert_eq!(graph.node(node), Some(&"hello")); - /// ``` - #[must_use] - pub fn node(&self, node: NodeId) -> Option<&N> { - self.nodes.get(node.index()) - } - - /// Returns the number of nodes in the graph. - /// - /// # Returns - /// - /// The total node count. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::DirectedGraph; - /// - /// let mut graph: DirectedGraph = DirectedGraph::new(); - /// assert_eq!(graph.node_count(), 0); - /// - /// graph.add_node(1); - /// graph.add_node(2); - /// assert_eq!(graph.node_count(), 2); - /// ``` - #[must_use] - pub fn node_count(&self) -> usize { - self.nodes.len() - } - - /// Returns an iterator over all node identifiers in the graph. - /// - /// Nodes are yielded in the order they were added (ascending `NodeId`). - /// - /// # Returns - /// - /// An iterator yielding each `NodeId` in the graph. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::{DirectedGraph, NodeId}; - /// - /// let mut graph: DirectedGraph = DirectedGraph::new(); - /// graph.add_node('A'); - /// graph.add_node('B'); - /// graph.add_node('C'); - /// - /// let ids: Vec = graph.node_ids().collect(); - /// assert_eq!(ids, vec![NodeId::new(0), NodeId::new(1), NodeId::new(2)]); - /// ``` - pub fn node_ids(&self) -> impl Iterator + '_ { - (0..self.nodes.len()).map(NodeId::new) - } - - /// Returns an iterator over all nodes with their identifiers. - /// - /// This is useful when you need both the node data and its identifier. - /// - /// # Returns - /// - /// An iterator yielding `(NodeId, &N)` tuples. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::DirectedGraph; - /// - /// let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - /// graph.add_node("first"); - /// graph.add_node("second"); - /// - /// for (id, data) in graph.nodes() { - /// println!("{}: {}", id, data); - /// } - /// ``` - pub fn nodes(&self) -> impl Iterator + '_ { - self.nodes - .iter() - .enumerate() - .map(|(i, data)| (NodeId::new(i), data)) - } - - /// Adds a directed edge from `source` to `target` with the given data. - /// - /// The edge is assigned the next sequential `EdgeId`, starting from 0. - /// Multiple edges between the same pair of nodes are allowed (multigraph). - /// - /// # Arguments - /// - /// * `source` - The source node of the edge - /// * `target` - The target node of the edge - /// * `data` - The data to associate with this edge - /// - /// # Returns - /// - /// The `EdgeId` assigned to the new edge. - /// - /// # Panics - /// - /// Panics if either `source` or `target` is not a valid node in the graph. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::{DirectedGraph, EdgeId}; - /// - /// let mut graph: DirectedGraph<&str, &str> = DirectedGraph::new(); - /// let a = graph.add_node("A"); - /// let b = graph.add_node("B"); - /// - /// let edge = graph.add_edge(a, b, "A->B")?; - /// assert_eq!(edge, EdgeId::new(0)); - /// assert_eq!(graph.edge_count(), 1); - /// # Ok::<(), dotscope::Error>(()) - /// ``` - /// - /// # Errors - /// - /// Returns [`crate::Error::GraphError`] if either `source` or `target` node does not exist - /// in the graph. - pub fn add_edge(&mut self, source: NodeId, target: NodeId, data: E) -> Result { - if source.index() >= self.nodes.len() { - return Err(Error::GraphError(format!( - "source node {} does not exist in graph with {} nodes", - source, - self.nodes.len() - ))); - } - if target.index() >= self.nodes.len() { - return Err(Error::GraphError(format!( - "target node {} does not exist in graph with {} nodes", - target, - self.nodes.len() - ))); - } - - let id = EdgeId::new(self.edges.len()); - self.edges.push(EdgeData { - source, - target, - data, - }); - - self.outgoing - .get_mut(source.index()) - .ok_or_else(|| { - Error::GraphError(format!( - "outgoing adjacency missing for source node {source}" - )) - })? - .push(id); - self.incoming - .get_mut(target.index()) - .ok_or_else(|| { - Error::GraphError(format!( - "incoming adjacency missing for target node {target}" - )) - })? - .push(id); - - Ok(id) - } - - /// Returns a reference to the data associated with the given edge. - /// - /// # Arguments - /// - /// * `edge` - The edge to look up - /// - /// # Returns - /// - /// `Some(&E)` if the edge exists, `None` otherwise. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::DirectedGraph; - /// - /// let mut graph: DirectedGraph<(), &str> = DirectedGraph::new(); - /// let a = graph.add_node(()); - /// let b = graph.add_node(()); - /// let edge = graph.add_edge(a, b, "label"); - /// - /// assert_eq!(graph.edge(edge), Some(&"label")); - /// ``` - #[must_use] - pub fn edge(&self, edge: EdgeId) -> Option<&E> { - self.edges.get(edge.index()).map(|e| &e.data) - } - - /// Returns a mutable reference to the data associated with the given edge. - /// - /// # Arguments - /// - /// * `edge` - The edge to look up - /// - /// # Returns - /// - /// `Some(&mut E)` if the edge exists, `None` otherwise. - pub fn edge_mut(&mut self, edge: EdgeId) -> Option<&mut E> { - self.edges.get_mut(edge.index()).map(|e| &mut e.data) - } - - /// Returns the source and target nodes of the given edge. - /// - /// # Arguments - /// - /// * `edge` - The edge to look up - /// - /// # Returns - /// - /// `Some((source, target))` if the edge exists, `None` otherwise. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::DirectedGraph; - /// - /// let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - /// let a = graph.add_node("A"); - /// let b = graph.add_node("B"); - /// let edge = graph.add_edge(a, b, ()); - /// - /// assert_eq!(graph.edge_endpoints(edge), Some((a, b))); - /// ``` - #[must_use] - pub fn edge_endpoints(&self, edge: EdgeId) -> Option<(NodeId, NodeId)> { - self.edges.get(edge.index()).map(|e| (e.source, e.target)) - } - - /// Returns the number of edges in the graph. - /// - /// # Returns - /// - /// The total edge count. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::DirectedGraph; - /// - /// let mut graph: DirectedGraph<(), ()> = DirectedGraph::new(); - /// let a = graph.add_node(()); - /// let b = graph.add_node(()); - /// - /// assert_eq!(graph.edge_count(), 0); - /// - /// graph.add_edge(a, b, ()); - /// assert_eq!(graph.edge_count(), 1); - /// ``` - #[must_use] - pub fn edge_count(&self) -> usize { - self.edges.len() - } - - /// Returns an iterator over all edge identifiers in the graph. - /// - /// Edges are yielded in the order they were added (ascending `EdgeId`). - /// - /// # Returns - /// - /// An iterator yielding each `EdgeId` in the graph. - pub fn edge_ids(&self) -> impl Iterator + '_ { - (0..self.edges.len()).map(EdgeId::new) - } - - /// Returns an iterator over all edges with their identifiers. - /// - /// # Returns - /// - /// An iterator yielding `(EdgeId, &E)` tuples. - pub fn edges(&self) -> impl Iterator + '_ { - self.edges - .iter() - .enumerate() - .map(|(i, e)| (EdgeId::new(i), &e.data)) - } - - /// Returns an iterator over the successors of the given node. - /// - /// Successors are nodes that are targets of edges originating from this node. - /// - /// # Arguments - /// - /// * `node` - The node whose successors to iterate - /// - /// # Returns - /// - /// An iterator yielding the `NodeId` of each successor. - /// - /// # Panics - /// - /// Panics if `node` is not a valid node in the graph. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::{DirectedGraph, NodeId}; - /// - /// let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - /// let a = graph.add_node("A"); - /// let b = graph.add_node("B"); - /// let c = graph.add_node("C"); - /// - /// graph.add_edge(a, b, ()); - /// graph.add_edge(a, c, ()); - /// - /// let successors: Vec = graph.successors(a).collect(); - /// assert_eq!(successors.len(), 2); - /// ``` - pub fn successors(&self, node: NodeId) -> impl Iterator + '_ { - self.outgoing - .get(node.index()) - .into_iter() - .flatten() - .filter_map(|&edge_id| self.edges.get(edge_id.index()).map(|e| e.target)) - } - - /// Returns an iterator over the predecessors of the given node. - /// - /// Predecessors are nodes that are sources of edges targeting this node. - /// - /// # Arguments - /// - /// * `node` - The node whose predecessors to iterate - /// - /// # Returns - /// - /// An iterator yielding the `NodeId` of each predecessor. - /// - /// # Panics - /// - /// Panics if `node` is not a valid node in the graph. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::{DirectedGraph, NodeId}; - /// - /// let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - /// let a = graph.add_node("A"); - /// let b = graph.add_node("B"); - /// let c = graph.add_node("C"); - /// - /// graph.add_edge(a, c, ()); - /// graph.add_edge(b, c, ()); - /// - /// let predecessors: Vec = graph.predecessors(c).collect(); - /// assert_eq!(predecessors.len(), 2); - /// ``` - pub fn predecessors(&self, node: NodeId) -> impl Iterator + '_ { - self.incoming - .get(node.index()) - .into_iter() - .flatten() - .filter_map(|&edge_id| self.edges.get(edge_id.index()).map(|e| e.source)) - } - - /// Returns an iterator over outgoing edges from the given node. - /// - /// This provides access to both the edge ID and edge data for more detailed - /// edge inspection than [`successors`](Self::successors). - /// - /// # Arguments - /// - /// * `node` - The node whose outgoing edges to iterate - /// - /// # Returns - /// - /// An iterator yielding `(EdgeId, &E)` tuples for each outgoing edge. - /// - /// # Panics - /// - /// Panics if `node` is not a valid node in the graph. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::DirectedGraph; - /// - /// let mut graph: DirectedGraph<&str, i32> = DirectedGraph::new(); - /// let a = graph.add_node("A"); - /// let b = graph.add_node("B"); - /// - /// graph.add_edge(a, b, 42); - /// - /// for (edge_id, weight) in graph.outgoing_edges(a) { - /// println!("Edge {} has weight {}", edge_id, weight); - /// } - /// ``` - pub fn outgoing_edges(&self, node: NodeId) -> impl Iterator + '_ { - self.outgoing - .get(node.index()) - .into_iter() - .flatten() - .filter_map(|&edge_id| self.edges.get(edge_id.index()).map(|e| (edge_id, &e.data))) - } - - /// Returns an iterator over incoming edges to the given node. - /// - /// This provides access to both the edge ID and edge data for more detailed - /// edge inspection than [`predecessors`](Self::predecessors). - /// - /// # Arguments - /// - /// * `node` - The node whose incoming edges to iterate - /// - /// # Returns - /// - /// An iterator yielding `(EdgeId, &E)` tuples for each incoming edge. - /// - /// # Panics - /// - /// Panics if `node` is not a valid node in the graph. - pub fn incoming_edges(&self, node: NodeId) -> impl Iterator + '_ { - self.incoming - .get(node.index()) - .into_iter() - .flatten() - .filter_map(|&edge_id| self.edges.get(edge_id.index()).map(|e| (edge_id, &e.data))) - } - - /// Returns the out-degree (number of outgoing edges) of a node. - /// - /// # Arguments - /// - /// * `node` - The node to query - /// - /// # Returns - /// - /// The number of outgoing edges from this node. - /// - /// # Panics - /// - /// Panics if `node` is not a valid node in the graph. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::DirectedGraph; - /// - /// let mut graph: DirectedGraph<(), ()> = DirectedGraph::new(); - /// let a = graph.add_node(()); - /// let b = graph.add_node(()); - /// let c = graph.add_node(()); - /// - /// graph.add_edge(a, b, ()); - /// graph.add_edge(a, c, ()); - /// - /// assert_eq!(graph.out_degree(a), 2); - /// assert_eq!(graph.out_degree(b), 0); - /// ``` - #[must_use] - pub fn out_degree(&self, node: NodeId) -> usize { - self.outgoing.get(node.index()).map_or(0, Vec::len) - } - - /// Returns the in-degree (number of incoming edges) of a node. - /// - /// # Arguments - /// - /// * `node` - The node to query - /// - /// # Returns - /// - /// The number of incoming edges to this node. - /// - /// # Panics - /// - /// Panics if `node` is not a valid node in the graph. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::DirectedGraph; - /// - /// let mut graph: DirectedGraph<(), ()> = DirectedGraph::new(); - /// let a = graph.add_node(()); - /// let b = graph.add_node(()); - /// let c = graph.add_node(()); - /// - /// graph.add_edge(a, c, ()); - /// graph.add_edge(b, c, ()); - /// - /// assert_eq!(graph.in_degree(c), 2); - /// assert_eq!(graph.in_degree(a), 0); - /// ``` - #[must_use] - pub fn in_degree(&self, node: NodeId) -> usize { - self.incoming.get(node.index()).map_or(0, Vec::len) - } - - /// Returns `true` if the graph contains no nodes. - /// - /// # Returns - /// - /// `true` if the graph has zero nodes, `false` otherwise. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::DirectedGraph; - /// - /// let mut graph: DirectedGraph<(), ()> = DirectedGraph::new(); - /// assert!(graph.is_empty()); - /// - /// graph.add_node(()); - /// assert!(!graph.is_empty()); - /// ``` - #[must_use] - pub fn is_empty(&self) -> bool { - self.nodes.is_empty() - } - - /// Returns an iterator over entry nodes (nodes with no incoming edges). - /// - /// Entry nodes have in-degree of zero and are potential starting points - /// for graph traversal. - /// - /// # Returns - /// - /// An iterator yielding the `NodeId` of each entry node. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::{DirectedGraph, NodeId}; - /// - /// let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - /// let a = graph.add_node("A"); - /// let b = graph.add_node("B"); - /// let c = graph.add_node("C"); - /// - /// graph.add_edge(a, b, ()); - /// graph.add_edge(a, c, ()); - /// - /// let entries: Vec = graph.entry_nodes().collect(); - /// assert_eq!(entries, vec![a]); - /// ``` - pub fn entry_nodes(&self) -> impl Iterator + '_ { - self.node_ids().filter(|&node| self.in_degree(node) == 0) - } - - /// Returns an iterator over exit nodes (nodes with no outgoing edges). - /// - /// Exit nodes have out-degree of zero and represent terminal points - /// in the graph. - /// - /// # Returns - /// - /// An iterator yielding the `NodeId` of each exit node. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::{DirectedGraph, NodeId}; - /// - /// let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - /// let a = graph.add_node("A"); - /// let b = graph.add_node("B"); - /// let c = graph.add_node("C"); - /// - /// graph.add_edge(a, b, ()); - /// graph.add_edge(a, c, ()); - /// - /// let exits: Vec = graph.exit_nodes().collect(); - /// assert_eq!(exits.len(), 2); - /// assert!(exits.contains(&b)); - /// assert!(exits.contains(&c)); - /// ``` - pub fn exit_nodes(&self) -> impl Iterator + '_ { - self.node_ids().filter(|&node| self.out_degree(node) == 0) - } - - /// Checks if the given node ID is valid for this graph. - /// - /// # Arguments - /// - /// * `node` - The node ID to check - /// - /// # Returns - /// - /// `true` if the node exists in the graph, `false` otherwise. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::{DirectedGraph, NodeId}; - /// - /// let mut graph: DirectedGraph<(), ()> = DirectedGraph::new(); - /// let a = graph.add_node(()); - /// - /// assert!(graph.contains_node(a)); - /// assert!(!graph.contains_node(NodeId::new(999))); - /// ``` - #[must_use] - pub fn contains_node(&self, node: NodeId) -> bool { - node.index() < self.nodes.len() - } - - /// Checks if the given edge ID is valid for this graph. - /// - /// # Arguments - /// - /// * `edge` - The edge ID to check - /// - /// # Returns - /// - /// `true` if the edge exists in the graph, `false` otherwise. - #[must_use] - pub fn contains_edge(&self, edge: EdgeId) -> bool { - edge.index() < self.edges.len() - } -} - -// Implement the GraphBase trait -impl GraphBase for DirectedGraph<'_, N, E> { - fn node_count(&self) -> usize { - self.nodes.len() - } - - fn node_ids(&self) -> impl Iterator { - (0..self.nodes.len()).map(NodeId::new) - } -} - -// Implement the Successors trait -impl Successors for DirectedGraph<'_, N, E> { - fn successors(&self, node: NodeId) -> impl Iterator { - self.outgoing - .get(node.index()) - .into_iter() - .flatten() - .filter_map(|&edge_id| self.edges.get(edge_id.index()).map(|e| e.target)) - } -} - -// Implement the Predecessors trait -impl Predecessors for DirectedGraph<'_, N, E> { - fn predecessors(&self, node: NodeId) -> impl Iterator { - self.incoming - .get(node.index()) - .into_iter() - .flatten() - .filter_map(|&edge_id| self.edges.get(edge_id.index()).map(|e| e.source)) - } -} - -#[cfg(test)] -mod tests { - use crate::utils::graph::{ - directed::DirectedGraph, - edge::EdgeId, - node::NodeId, - traits::{GraphBase, Predecessors, Successors}, - }; - - /// Creates a simple linear graph: A -> B -> C - fn create_linear_graph() -> DirectedGraph<'static, &'static str, ()> { - let mut graph = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - let c = graph.add_node("C"); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, c, ()).unwrap(); - graph - } - - /// Creates a diamond graph: A -> B, A -> C, B -> D, C -> D - fn create_diamond_graph() -> DirectedGraph<'static, &'static str, ()> { - let mut graph = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - let c = graph.add_node("C"); - let d = graph.add_node("D"); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(a, c, ()).unwrap(); - graph.add_edge(b, d, ()).unwrap(); - graph.add_edge(c, d, ()).unwrap(); - graph - } - - /// Creates a graph with a cycle: A -> B -> C -> A - fn create_cycle_graph() -> DirectedGraph<'static, &'static str, ()> { - let mut graph = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - let c = graph.add_node("C"); - graph.add_edge(a, b, ()).unwrap(); - graph.add_edge(b, c, ()).unwrap(); - graph.add_edge(c, a, ()).unwrap(); - graph - } - - #[test] - fn test_new_graph_is_empty() { - let graph: DirectedGraph<(), ()> = DirectedGraph::new(); - assert!(graph.is_empty()); - assert_eq!(graph.node_count(), 0); - assert_eq!(graph.edge_count(), 0); - } - - #[test] - fn test_with_capacity() { - let graph: DirectedGraph = DirectedGraph::with_capacity(100, 200); - assert!(graph.is_empty()); - // Capacity is internal; just verify it works - } - - #[test] - fn test_default() { - let graph: DirectedGraph<(), ()> = DirectedGraph::default(); - assert!(graph.is_empty()); - } - - #[test] - fn test_add_node() { - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - - let a = graph.add_node("A"); - assert_eq!(a, NodeId::new(0)); - assert_eq!(graph.node_count(), 1); - - let b = graph.add_node("B"); - assert_eq!(b, NodeId::new(1)); - assert_eq!(graph.node_count(), 2); - } - - #[test] - fn test_node_access() { - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let a = graph.add_node("hello"); - - assert_eq!(graph.node(a), Some(&"hello")); - assert_eq!(graph.node(NodeId::new(999)), None); - } - - #[test] - fn test_node_mut() { - let mut graph: DirectedGraph = DirectedGraph::new(); - let a = graph.add_node(String::from("hello")); - - if let Some(data) = graph.node_mut(a) { - data.push_str(" world"); - } - - assert_eq!(graph.node(a), Some(&String::from("hello world"))); - } - - #[test] - fn test_node_ids_iterator() { - let mut graph: DirectedGraph = DirectedGraph::new(); - graph.add_node('A'); - graph.add_node('B'); - graph.add_node('C'); - - let ids: Vec = graph.node_ids().collect(); - assert_eq!(ids, vec![NodeId::new(0), NodeId::new(1), NodeId::new(2)]); - } - - #[test] - fn test_nodes_iterator() { - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - graph.add_node("A"); - graph.add_node("B"); - - let nodes: Vec<(NodeId, &&str)> = graph.nodes().collect(); - assert_eq!(nodes.len(), 2); - assert_eq!(nodes[0], (NodeId::new(0), &"A")); - assert_eq!(nodes[1], (NodeId::new(1), &"B")); - } - - #[test] - fn test_add_edge() { - let mut graph: DirectedGraph<&str, &str> = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - - let edge = graph.add_edge(a, b, "A->B").unwrap(); - assert_eq!(edge, EdgeId::new(0)); - assert_eq!(graph.edge_count(), 1); - } - - #[test] - fn test_edge_access() { - let mut graph: DirectedGraph<(), &str> = DirectedGraph::new(); - let a = graph.add_node(()); - let b = graph.add_node(()); - let edge = graph.add_edge(a, b, "label").unwrap(); - - assert_eq!(graph.edge(edge), Some(&"label")); - assert_eq!(graph.edge(EdgeId::new(999)), None); - } - - #[test] - fn test_edge_endpoints() { - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - let edge = graph.add_edge(a, b, ()).unwrap(); - - assert_eq!(graph.edge_endpoints(edge), Some((a, b))); - assert_eq!(graph.edge_endpoints(EdgeId::new(999)), None); - } - - #[test] - fn test_multiple_edges() { - let mut graph: DirectedGraph<&str, i32> = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - - // Allow multiple edges between same nodes (multigraph) - let e1 = graph.add_edge(a, b, 1).unwrap(); - let e2 = graph.add_edge(a, b, 2).unwrap(); - - assert_eq!(graph.edge_count(), 2); - assert_eq!(graph.edge(e1), Some(&1)); - assert_eq!(graph.edge(e2), Some(&2)); - } - - #[test] - fn test_self_loop() { - let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); - let a = graph.add_node("A"); - - let edge = graph.add_edge(a, a, ()).unwrap(); - assert_eq!(graph.edge_endpoints(edge), Some((a, a))); - assert_eq!(graph.out_degree(a), 1); - assert_eq!(graph.in_degree(a), 1); - } - - #[test] - fn test_add_edge_invalid_source() { - let mut graph: DirectedGraph<(), ()> = DirectedGraph::new(); - let a = graph.add_node(()); - let result = graph.add_edge(NodeId::new(999), a, ()); - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("source node")); - } - - #[test] - fn test_add_edge_invalid_target() { - let mut graph: DirectedGraph<(), ()> = DirectedGraph::new(); - let a = graph.add_node(()); - let result = graph.add_edge(a, NodeId::new(999), ()); - assert!(result.is_err()); - assert!(result.unwrap_err().to_string().contains("target node")); - } - - #[test] - fn test_successors() { - let graph = create_diamond_graph(); - let a = NodeId::new(0); - - let successors: Vec = graph.successors(a).collect(); - assert_eq!(successors.len(), 2); - assert!(successors.contains(&NodeId::new(1))); // B - assert!(successors.contains(&NodeId::new(2))); // C - } - - #[test] - fn test_predecessors() { - let graph = create_diamond_graph(); - let d = NodeId::new(3); - - let predecessors: Vec = graph.predecessors(d).collect(); - assert_eq!(predecessors.len(), 2); - assert!(predecessors.contains(&NodeId::new(1))); // B - assert!(predecessors.contains(&NodeId::new(2))); // C - } - - #[test] - fn test_outgoing_edges() { - let mut graph: DirectedGraph<&str, i32> = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - let c = graph.add_node("C"); - - graph.add_edge(a, b, 10).unwrap(); - graph.add_edge(a, c, 20).unwrap(); - - let outgoing: Vec<(EdgeId, &i32)> = graph.outgoing_edges(a).collect(); - assert_eq!(outgoing.len(), 2); - - let weights: Vec = outgoing.iter().map(|(_, &w)| w).collect(); - assert!(weights.contains(&10)); - assert!(weights.contains(&20)); - } - - #[test] - fn test_incoming_edges() { - let mut graph: DirectedGraph<&str, i32> = DirectedGraph::new(); - let a = graph.add_node("A"); - let b = graph.add_node("B"); - let c = graph.add_node("C"); - - graph.add_edge(a, c, 10).unwrap(); - graph.add_edge(b, c, 20).unwrap(); - - let incoming: Vec<(EdgeId, &i32)> = graph.incoming_edges(c).collect(); - assert_eq!(incoming.len(), 2); - } - - #[test] - fn test_out_degree() { - let graph = create_diamond_graph(); - - assert_eq!(graph.out_degree(NodeId::new(0)), 2); // A has 2 outgoing - assert_eq!(graph.out_degree(NodeId::new(1)), 1); // B has 1 outgoing - assert_eq!(graph.out_degree(NodeId::new(3)), 0); // D has 0 outgoing - } - - #[test] - fn test_in_degree() { - let graph = create_diamond_graph(); - - assert_eq!(graph.in_degree(NodeId::new(0)), 0); // A has 0 incoming - assert_eq!(graph.in_degree(NodeId::new(1)), 1); // B has 1 incoming - assert_eq!(graph.in_degree(NodeId::new(3)), 2); // D has 2 incoming - } - - #[test] - fn test_entry_nodes() { - let graph = create_diamond_graph(); - let entries: Vec = graph.entry_nodes().collect(); - - assert_eq!(entries.len(), 1); - assert_eq!(entries[0], NodeId::new(0)); // Only A is entry - } - - #[test] - fn test_exit_nodes() { - let graph = create_diamond_graph(); - let exits: Vec = graph.exit_nodes().collect(); - - assert_eq!(exits.len(), 1); - assert_eq!(exits[0], NodeId::new(3)); // Only D is exit - } - - #[test] - fn test_entry_nodes_with_cycle() { - let graph = create_cycle_graph(); - let entries: Vec = graph.entry_nodes().collect(); - - // No entry nodes in a pure cycle - assert!(entries.is_empty()); - } - - #[test] - fn test_contains_node() { - let mut graph: DirectedGraph<(), ()> = DirectedGraph::new(); - let a = graph.add_node(()); - - assert!(graph.contains_node(a)); - assert!(!graph.contains_node(NodeId::new(999))); - } - - #[test] - fn test_contains_edge() { - let mut graph: DirectedGraph<(), ()> = DirectedGraph::new(); - let a = graph.add_node(()); - let b = graph.add_node(()); - let edge = graph.add_edge(a, b, ()).unwrap(); - - assert!(graph.contains_edge(edge)); - assert!(!graph.contains_edge(EdgeId::new(999))); - } - - #[test] - fn test_graph_clone() { - let original = create_diamond_graph(); - let cloned = original.clone(); - - assert_eq!(original.node_count(), cloned.node_count()); - assert_eq!(original.edge_count(), cloned.edge_count()); - - // Verify data is independent - for node_id in original.node_ids() { - assert_eq!(original.node(node_id), cloned.node(node_id)); - } - } - - #[test] - fn test_graph_base_trait() { - fn use_graph_base(g: &G) -> usize { - g.node_count() - } - - let graph = create_linear_graph(); - assert_eq!(use_graph_base(&graph), 3); - } - - #[test] - fn test_successors_trait() { - fn use_successors(g: &G, node: NodeId) -> Vec { - g.successors(node).collect() - } - - let graph = create_linear_graph(); - let successors = use_successors(&graph, NodeId::new(0)); - assert_eq!(successors, vec![NodeId::new(1)]); - } - - #[test] - fn test_predecessors_trait() { - fn use_predecessors(g: &G, node: NodeId) -> Vec { - g.predecessors(node).collect() - } - - let graph = create_linear_graph(); - let predecessors = use_predecessors(&graph, NodeId::new(2)); - assert_eq!(predecessors, vec![NodeId::new(1)]); - } - - #[test] - fn test_large_graph() { - let mut graph: DirectedGraph = DirectedGraph::with_capacity(1000, 2000); - - // Create 1000 nodes - for i in 0..1000 { - graph.add_node(i); - } - - // Create edges: each node points to next - for i in 0..999 { - graph - .add_edge(NodeId::new(i), NodeId::new(i + 1), ()) - .unwrap(); - } - - assert_eq!(graph.node_count(), 1000); - assert_eq!(graph.edge_count(), 999); - - // Check first and last - assert_eq!(graph.out_degree(NodeId::new(0)), 1); - assert_eq!(graph.out_degree(NodeId::new(999)), 0); - assert_eq!(graph.in_degree(NodeId::new(0)), 0); - assert_eq!(graph.in_degree(NodeId::new(999)), 1); - } -} diff --git a/dotscope/src/utils/graph/edge.rs b/dotscope/src/utils/graph/edge.rs deleted file mode 100644 index eeab5653..00000000 --- a/dotscope/src/utils/graph/edge.rs +++ /dev/null @@ -1,280 +0,0 @@ -//! Edge identifier implementation for directed graphs. -//! -//! This module provides the [`EdgeId`] type, a strongly-typed identifier for edges -//! within a directed graph. The newtype wrapper provides type safety and prevents -//! accidental confusion between edge indices and other integer values. - -use std::fmt; - -/// A strongly-typed identifier for edges within a directed graph. -/// -/// `EdgeId` wraps a `usize` index, providing type safety to prevent -/// accidental mixing of edge indices with other integer values or node indices. -/// Edge IDs are assigned sequentially starting from 0 when edges are added to a graph. -/// -/// # Usage -/// -/// Edge IDs are created by [`DirectedGraph::add_edge`](crate::utils::graph::DirectedGraph::add_edge) -/// and should not typically be constructed manually. They are used to: -/// -/// - Reference edges when querying edge data -/// - Look up edge endpoints (source and target nodes) -/// - Store analysis results indexed by edge -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::graph::{DirectedGraph, NodeId, EdgeId}; -/// -/// let mut graph: DirectedGraph<&str, &str> = DirectedGraph::new(); -/// let a = graph.add_node("A"); -/// let b = graph.add_node("B"); -/// let edge: EdgeId = graph.add_edge(a, b, "A->B"); -/// -/// // EdgeIds can be used to query edge information -/// assert_eq!(graph.edge(edge), Some(&"A->B")); -/// assert_eq!(graph.edge_endpoints(edge), Some((a, b))); -/// ``` -/// -/// # Thread Safety -/// -/// `EdgeId` is [`Copy`], [`Send`], and [`Sync`], enabling efficient passing between -/// threads and use in concurrent data structures. -#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct EdgeId(pub(crate) usize); - -impl EdgeId { - /// Creates a new `EdgeId` from a raw index value. - /// - /// This constructor is primarily intended for internal use and testing. - /// Normal usage should obtain `EdgeId` values from [`DirectedGraph::add_edge`](crate::utils::graph::DirectedGraph::add_edge). - /// - /// # Arguments - /// - /// * `index` - The raw edge index (0-based) - /// - /// # Returns - /// - /// A new `EdgeId` wrapping the provided index. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::EdgeId; - /// - /// let edge = EdgeId::new(0); - /// assert_eq!(edge.index(), 0); - /// ``` - #[must_use] - #[inline] - pub const fn new(index: usize) -> Self { - EdgeId(index) - } - - /// Returns the raw index value of this edge identifier. - /// - /// The index is a 0-based position that can be used to index into vectors - /// or arrays that store per-edge data. - /// - /// # Returns - /// - /// The underlying index value. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::EdgeId; - /// - /// let edge = EdgeId::new(5); - /// assert_eq!(edge.index(), 5); - /// ``` - #[must_use] - #[inline] - pub const fn index(self) -> usize { - self.0 - } -} - -impl fmt::Debug for EdgeId { - /// Formats the edge ID for debugging output. - /// - /// The format shows the type name and index value for clear identification - /// in debug output and logging. - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "EdgeId({})", self.0) - } -} - -impl fmt::Display for EdgeId { - /// Formats the edge ID for user display. - /// - /// The display format shows just the prefix and index for compact output. - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "e{}", self.0) - } -} - -impl From for EdgeId { - /// Converts a raw `usize` index into an `EdgeId`. - /// - /// This conversion is provided for convenience but should be used carefully - /// to avoid creating invalid edge IDs that don't correspond to actual edges - /// in a graph. - #[inline] - fn from(index: usize) -> Self { - EdgeId(index) - } -} - -impl From for usize { - /// Extracts the raw index from an `EdgeId`. - #[inline] - fn from(edge: EdgeId) -> Self { - edge.0 - } -} - -#[cfg(test)] -mod tests { - use super::*; - use crate::utils::graph::NodeId; - use std::collections::{HashMap, HashSet}; - - #[test] - fn test_edge_id_new() { - let edge = EdgeId::new(42); - assert_eq!(edge.index(), 42); - } - - #[test] - fn test_edge_id_index() { - let edge = EdgeId::new(100); - assert_eq!(edge.index(), 100); - } - - #[test] - fn test_edge_id_equality() { - let edge1 = EdgeId::new(5); - let edge2 = EdgeId::new(5); - let edge3 = EdgeId::new(10); - - assert_eq!(edge1, edge2); - assert_ne!(edge1, edge3); - } - - #[test] - fn test_edge_id_ordering() { - let edge1 = EdgeId::new(1); - let edge2 = EdgeId::new(2); - let edge3 = EdgeId::new(3); - - assert!(edge1 < edge2); - assert!(edge2 < edge3); - assert!(edge1 < edge3); - - let mut edges = vec![edge3, edge1, edge2]; - edges.sort(); - assert_eq!(edges, vec![edge1, edge2, edge3]); - } - - #[test] - fn test_edge_id_hash() { - let mut set: HashSet = HashSet::new(); - let edge1 = EdgeId::new(1); - let edge2 = EdgeId::new(2); - let edge1_dup = EdgeId::new(1); - - set.insert(edge1); - set.insert(edge2); - set.insert(edge1_dup); // Should not add duplicate - - assert_eq!(set.len(), 2); - assert!(set.contains(&edge1)); - assert!(set.contains(&edge2)); - } - - #[test] - fn test_edge_id_as_map_key() { - let mut map: HashMap = HashMap::new(); - let edge1 = EdgeId::new(1); - let edge2 = EdgeId::new(2); - - map.insert(edge1, "first"); - map.insert(edge2, "second"); - - assert_eq!(map.get(&edge1), Some(&"first")); - assert_eq!(map.get(&edge2), Some(&"second")); - assert_eq!(map.get(&EdgeId::new(3)), None); - } - - #[test] - fn test_edge_id_copy_semantics() { - let edge1 = EdgeId::new(42); - let edge2 = edge1; // Copy - - assert_eq!(edge1, edge2); - assert_eq!(edge1.index(), 42); - assert_eq!(edge2.index(), 42); - } - - #[test] - fn test_edge_id_from_usize() { - let edge: EdgeId = 123usize.into(); - assert_eq!(edge.index(), 123); - } - - #[test] - fn test_edge_id_into_usize() { - let edge = EdgeId::new(789); - let value: usize = edge.into(); - assert_eq!(value, 789); - } - - #[test] - fn test_edge_id_debug_format() { - let edge = EdgeId::new(42); - let debug_str = format!("{edge:?}"); - assert_eq!(debug_str, "EdgeId(42)"); - } - - #[test] - fn test_edge_id_display_format() { - let edge = EdgeId::new(42); - let display_str = format!("{edge}"); - assert_eq!(display_str, "e42"); - } - - #[test] - fn test_edge_id_boundary_values() { - // Test zero - let zero = EdgeId::new(0); - assert_eq!(zero.index(), 0); - - // Test large value - let large = EdgeId::new(1_000_000); - assert_eq!(large.index(), 1_000_000); - } - - #[test] - fn test_edge_id_array_indexing() { - let weights = [1.5, 2.5, 3.5, 4.5]; - let edge = EdgeId::new(2); - - assert_eq!(weights[edge.index()], 3.5); - } - - #[test] - fn test_edge_id_distinct_from_node_id() { - // This test demonstrates that EdgeId and NodeId are distinct types - // and cannot be accidentally mixed (verified at compile time) - let node = NodeId::new(5); - let edge = EdgeId::new(5); - - // Both have the same underlying value but are different types - assert_eq!(node.index(), edge.index()); - - // The following would not compile, demonstrating type safety: - // let _: NodeId = edge; // Error: expected NodeId, found EdgeId - // let _: EdgeId = node; // Error: expected EdgeId, found NodeId - } -} diff --git a/dotscope/src/utils/graph/indexed.rs b/dotscope/src/utils/graph/indexed.rs deleted file mode 100644 index 61d1860c..00000000 --- a/dotscope/src/utils/graph/indexed.rs +++ /dev/null @@ -1,420 +0,0 @@ -//! Indexed graph wrapper for domain-typed nodes. -//! -//! This module provides [`IndexedGraph`], a convenience wrapper around [`DirectedGraph`] -//! that automatically handles the mapping between domain types (like `AssemblyIdentity` -//! or `TableId`) and internal `NodeId` indices. -//! -//! # Motivation -//! -//! When working with graph algorithms, domain code often needs to: -//! 1. Build a graph from domain-specific types -//! 2. Run algorithms that work with `NodeId` -//! 3. Map results back to domain types -//! -//! `IndexedGraph` encapsulates this pattern, providing a cleaner API. -//! -//! # Examples -//! -//! ```rust,ignore -//! use dotscope::utils::graph::{IndexedGraph, algorithms}; -//! -//! // Create a graph with string keys -//! let mut graph: IndexedGraph<&str, ()> = IndexedGraph::new(); -//! -//! // Add nodes using domain types directly -//! graph.add_node("A"); -//! graph.add_node("B"); -//! graph.add_node("C"); -//! -//! // Add edges using domain types -//! graph.add_edge("A", "B", ()); -//! graph.add_edge("B", "C", ()); -//! graph.add_edge("C", "A", ()); // Creates a cycle -//! -//! // Run algorithms - results are automatically mapped back -//! if let Some(cycle) = graph.find_cycle_from("A") { -//! println!("Found cycle: {:?}", cycle); // ["A", "B", "C", "A"] -//! } -//! ``` - -use std::collections::HashMap; -use std::hash::Hash; - -use crate::{ - utils::graph::{algorithms, DirectedGraph, NodeId}, - Result, -}; - -/// A graph wrapper that provides automatic mapping between domain types and `NodeId`. -/// -/// `IndexedGraph` stores nodes indexed by keys of type `K` (which must be -/// `Hash + Eq + Clone`) and edges with data of type `E`. It maintains bidirectional -/// mappings for efficient lookups in both directions. -/// -/// # Type Parameters -/// -/// * `K` - The domain key type for nodes (e.g., `AssemblyIdentity`, `TableId`) -/// * `E` - The edge data type -/// -/// # Thread Safety -/// -/// `IndexedGraph` is `Send` and `Sync` when both `K` and `E` are. -#[derive(Debug, Clone)] -pub struct IndexedGraph -where - K: Hash + Eq + Clone, -{ - /// The underlying directed graph (nodes store unit type, keys are separate) - graph: DirectedGraph<'static, (), E>, - /// Map from domain key to `NodeId` - key_to_node: HashMap, - /// Map from `NodeId` to domain key - node_to_key: HashMap, -} - -impl Default for IndexedGraph -where - K: Hash + Eq + Clone, -{ - fn default() -> Self { - Self::new() - } -} - -impl IndexedGraph -where - K: Hash + Eq + Clone, -{ - /// Creates a new empty indexed graph. - #[must_use] - pub fn new() -> Self { - Self { - graph: DirectedGraph::new(), - key_to_node: HashMap::new(), - node_to_key: HashMap::new(), - } - } - - /// Creates a new indexed graph with pre-allocated capacity. - #[must_use] - pub fn with_capacity(node_capacity: usize, edge_capacity: usize) -> Self { - Self { - graph: DirectedGraph::with_capacity(node_capacity, edge_capacity), - key_to_node: HashMap::with_capacity(node_capacity), - node_to_key: HashMap::with_capacity(node_capacity), - } - } - - /// Adds a node with the given key, or returns the existing `NodeId` if already present. - /// - /// This method is idempotent - calling it multiple times with the same key - /// will always return the same `NodeId`. - /// - /// # Arguments - /// - /// * `key` - The domain key for this node - /// - /// # Returns - /// - /// The `NodeId` associated with this key. - pub fn add_node(&mut self, key: K) -> NodeId { - if let Some(&node_id) = self.key_to_node.get(&key) { - return node_id; - } - - let node_id = self.graph.add_node(()); - self.key_to_node.insert(key.clone(), node_id); - self.node_to_key.insert(node_id, key); - node_id - } - - /// Adds a directed edge between two nodes identified by their keys. - /// - /// If either node doesn't exist, it will be created automatically. - /// - /// # Arguments - /// - /// * `from` - The source node key - /// * `to` - The target node key - /// * `data` - The edge data - /// - /// # Returns - /// - /// * `Ok(true)` if a new edge was added - /// * `Ok(false)` if the edge already existed - /// * `Err(_)` if the edge could not be added - /// - /// # Errors - /// - /// Returns an error if the underlying graph operation fails. - pub fn add_edge(&mut self, from: K, to: K, data: E) -> Result - where - E: Clone, - { - let from_node = self.add_node(from); - let to_node = self.add_node(to); - - // Check if edge already exists - if self.graph.successors(from_node).any(|s| s == to_node) { - return Ok(false); - } - - self.graph.add_edge(from_node, to_node, data)?; - Ok(true) - } - - /// Returns the `NodeId` for a given key, if it exists. - #[must_use] - pub fn get_node_id(&self, key: &K) -> Option { - self.key_to_node.get(key).copied() - } - - /// Returns the key for a given `NodeId`, if it exists. - #[must_use] - pub fn get_key(&self, node_id: NodeId) -> Option<&K> { - self.node_to_key.get(&node_id) - } - - /// Returns the number of nodes in the graph. - #[must_use] - pub fn node_count(&self) -> usize { - self.graph.node_count() - } - - /// Returns the number of edges in the graph. - #[must_use] - pub fn edge_count(&self) -> usize { - self.graph.edge_count() - } - - /// Returns `true` if the graph contains no nodes. - #[must_use] - pub fn is_empty(&self) -> bool { - self.graph.is_empty() - } - - /// Returns a reference to the underlying `DirectedGraph`. - /// - /// This is useful when you need to pass the graph to algorithms that - /// work with `DirectedGraph` directly. - #[must_use] - pub fn inner(&self) -> &DirectedGraph<'static, (), E> { - &self.graph - } - - /// Returns an iterator over all keys in the graph. - pub fn keys(&self) -> impl Iterator { - self.key_to_node.keys() - } - - /// Maps a vector of `NodeId`s back to domain keys. - /// - /// Nodes that don't have a corresponding key are skipped. - #[must_use] - pub fn map_nodes_to_keys(&self, nodes: &[NodeId]) -> Vec { - nodes - .iter() - .filter_map(|node_id| self.node_to_key.get(node_id).cloned()) - .collect() - } - - /// Maps a vector of SCCs (each being a `Vec`) back to domain keys. - #[must_use] - pub fn map_sccs_to_keys(&self, sccs: &[Vec]) -> Vec> { - sccs.iter().map(|scc| self.map_nodes_to_keys(scc)).collect() - } -} - -// Algorithm convenience methods -impl IndexedGraph -where - K: Hash + Eq + Clone, -{ - /// Finds a cycle in the graph starting from the given key. - /// - /// Returns the cycle as a vector of domain keys if found, `None` otherwise. - #[must_use] - pub fn find_cycle_from(&self, start: &K) -> Option> { - let start_node = self.key_to_node.get(start)?; - let cycle_nodes = algorithms::find_cycle(&self.graph, *start_node)?; - Some(self.map_nodes_to_keys(&cycle_nodes)) - } - - /// Checks if the graph contains any cycle reachable from the given key. - #[must_use] - pub fn has_cycle_from(&self, start: &K) -> bool { - self.key_to_node - .get(start) - .is_some_and(|&start_node| algorithms::has_cycle(&self.graph, start_node)) - } - - /// Finds any cycle in the graph. - /// - /// Checks all nodes and returns the first cycle found. - #[must_use] - pub fn find_any_cycle(&self) -> Option> { - for &start_node in self.key_to_node.values() { - if let Some(cycle_nodes) = algorithms::find_cycle(&self.graph, start_node) { - return Some(self.map_nodes_to_keys(&cycle_nodes)); - } - } - None - } - - /// Computes strongly connected components. - /// - /// Returns SCCs as vectors of domain keys, in reverse topological order. - #[must_use] - pub fn strongly_connected_components(&self) -> Vec> { - let sccs = algorithms::strongly_connected_components(&self.graph); - self.map_sccs_to_keys(&sccs) - } - - /// Computes a topological ordering of the graph. - /// - /// Returns `Some(order)` if the graph is acyclic, `None` if it contains cycles. - #[must_use] - pub fn topological_sort(&self) -> Option> { - let order = algorithms::topological_sort(&self.graph)?; - Some(self.map_nodes_to_keys(&order)) - } -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn test_indexed_graph_basic() { - let mut graph: IndexedGraph<&str, ()> = IndexedGraph::new(); - - let a = graph.add_node("A"); - let b = graph.add_node("B"); - - assert_eq!(graph.node_count(), 2); - assert_eq!(graph.get_node_id(&"A"), Some(a)); - assert_eq!(graph.get_node_id(&"B"), Some(b)); - assert_eq!(graph.get_key(a), Some(&"A")); - assert_eq!(graph.get_key(b), Some(&"B")); - } - - #[test] - fn test_indexed_graph_idempotent_add() { - let mut graph: IndexedGraph<&str, ()> = IndexedGraph::new(); - - let a1 = graph.add_node("A"); - let a2 = graph.add_node("A"); // Same key - - assert_eq!(a1, a2); - assert_eq!(graph.node_count(), 1); - } - - #[test] - fn test_indexed_graph_add_edge() { - let mut graph: IndexedGraph<&str, i32> = IndexedGraph::new(); - - // Nodes created automatically - assert!(graph.add_edge("A", "B", 10).unwrap()); - assert!(graph.add_edge("B", "C", 20).unwrap()); - - assert_eq!(graph.node_count(), 3); - assert_eq!(graph.edge_count(), 2); - - // Duplicate edge not added - assert!(!graph.add_edge("A", "B", 10).unwrap()); - assert_eq!(graph.edge_count(), 2); - } - - #[test] - fn test_indexed_graph_find_cycle() { - let mut graph: IndexedGraph<&str, ()> = IndexedGraph::new(); - - graph.add_edge("A", "B", ()).unwrap(); - graph.add_edge("B", "C", ()).unwrap(); - graph.add_edge("C", "A", ()).unwrap(); // Creates cycle - - let cycle = graph.find_cycle_from(&"A"); - assert!(cycle.is_some()); - - let cycle = cycle.unwrap(); - assert!(cycle.contains(&"A")); - assert!(cycle.contains(&"B")); - assert!(cycle.contains(&"C")); - } - - #[test] - fn test_indexed_graph_no_cycle() { - let mut graph: IndexedGraph<&str, ()> = IndexedGraph::new(); - - graph.add_edge("A", "B", ()).unwrap(); - graph.add_edge("B", "C", ()).unwrap(); - // No back edge - - assert!(graph.find_cycle_from(&"A").is_none()); - assert!(!graph.has_cycle_from(&"A")); - } - - #[test] - fn test_indexed_graph_topological_sort() { - let mut graph: IndexedGraph<&str, ()> = IndexedGraph::new(); - - // A -> B -> D - // A -> C -> D - graph.add_edge("A", "B", ()).unwrap(); - graph.add_edge("A", "C", ()).unwrap(); - graph.add_edge("B", "D", ()).unwrap(); - graph.add_edge("C", "D", ()).unwrap(); - - let order = graph.topological_sort(); - assert!(order.is_some()); - - let order = order.unwrap(); - assert_eq!(order.len(), 4); - - // A must come before B, C; B and C must come before D - let pos = |k: &str| order.iter().position(|&x| x == k).unwrap(); - assert!(pos("A") < pos("B")); - assert!(pos("A") < pos("C")); - assert!(pos("B") < pos("D")); - assert!(pos("C") < pos("D")); - } - - #[test] - fn test_indexed_graph_topological_sort_with_cycle() { - let mut graph: IndexedGraph<&str, ()> = IndexedGraph::new(); - - graph.add_edge("A", "B", ()).unwrap(); - graph.add_edge("B", "A", ()).unwrap(); // Cycle - - assert!(graph.topological_sort().is_none()); - } - - #[test] - fn test_indexed_graph_scc() { - let mut graph: IndexedGraph<&str, ()> = IndexedGraph::new(); - - // Two SCCs: {A, B} and {C} - graph.add_edge("A", "B", ()).unwrap(); - graph.add_edge("B", "A", ()).unwrap(); // A <-> B cycle - graph.add_edge("B", "C", ()).unwrap(); - - let sccs = graph.strongly_connected_components(); - assert_eq!(sccs.len(), 2); - - // One SCC has 2 elements, one has 1 - let mut sizes: Vec = sccs.iter().map(|scc| scc.len()).collect(); - sizes.sort(); - assert_eq!(sizes, vec![1, 2]); - } - - #[test] - fn test_indexed_graph_with_integers() { - let mut graph: IndexedGraph = IndexedGraph::new(); - - graph.add_edge(1, 2, "one-two").unwrap(); - graph.add_edge(2, 3, "two-three").unwrap(); - - assert_eq!(graph.node_count(), 3); - assert!(graph.topological_sort().is_some()); - } -} diff --git a/dotscope/src/utils/graph/mod.rs b/dotscope/src/utils/graph/mod.rs deleted file mode 100644 index 35c2c2b9..00000000 --- a/dotscope/src/utils/graph/mod.rs +++ /dev/null @@ -1,122 +0,0 @@ -//! Generic directed graph infrastructure for program analysis. -//! -//! This module provides a reusable directed graph implementation designed for -//! program analysis tasks such as control flow graphs, call graphs, and dependency -//! analysis. The implementation prioritizes correctness, clear semantics, and -//! efficient algorithms over raw performance. -//! -//! # Architecture -//! -//! The graph module is organized into several components: -//! -//! - **Core Types**: [`NodeId`], [`EdgeId`], and [`DirectedGraph`] provide the fundamental -//! building blocks for graph representation -//! - **Algorithms**: Standard graph algorithms for traversal, dominator computation, -//! topological sorting, and cycle detection -//! - **Traits**: Abstraction traits enabling algorithms to work with different graph types -//! -//! # Design Principles -//! -//! ## Strongly-Typed Identifiers -//! -//! Node and edge identifiers use newtype wrappers to prevent accidental mixing of -//! indices and provide type safety at compile time. -//! -//! ## Immutable After Construction -//! -//! Graphs are designed to be built incrementally during construction, then treated -//! as immutable for analysis. This enables safe concurrent access without locks. -//! -//! ## Thread Safety -//! -//! All graph types are [`Send`] and [`Sync`] when their node and edge data types are, -//! enabling safe concurrent analysis across multiple threads. -//! -//! # Key Components -//! -//! - [`NodeId`] - Strongly-typed node identifier -//! - [`EdgeId`] - Strongly-typed edge identifier -//! - [`DirectedGraph`] - Core directed graph implementation with adjacency lists -//! - [`algorithms`] - Graph algorithms (traversal, dominators, SCC, etc.) -//! -//! # Usage Examples -//! -//! ## Creating a Simple Graph -//! -//! ```rust,ignore -//! use dotscope::graph::{DirectedGraph, NodeId}; -//! -//! // Create a diamond-shaped graph: A -> B, A -> C, B -> D, C -> D -//! let mut graph: DirectedGraph<&str, &str> = DirectedGraph::new(); -//! -//! let a = graph.add_node("A"); -//! let b = graph.add_node("B"); -//! let c = graph.add_node("C"); -//! let d = graph.add_node("D"); -//! -//! graph.add_edge(a, b, "A->B"); -//! graph.add_edge(a, c, "A->C"); -//! graph.add_edge(b, d, "B->D"); -//! graph.add_edge(c, d, "C->D"); -//! -//! assert_eq!(graph.node_count(), 4); -//! assert_eq!(graph.edge_count(), 4); -//! ``` -//! -//! ## Traversing a Graph -//! -//! ```rust,ignore -//! use dotscope::graph::{DirectedGraph, NodeId, algorithms}; -//! -//! let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); -//! let a = graph.add_node("A"); -//! let b = graph.add_node("B"); -//! let c = graph.add_node("C"); -//! graph.add_edge(a, b, ()); -//! graph.add_edge(b, c, ()); -//! -//! // Depth-first traversal -//! let dfs_order: Vec = algorithms::dfs(&graph, a).collect(); -//! assert_eq!(dfs_order.len(), 3); -//! ``` -//! -//! ## Computing Dominators -//! -//! ```rust,ignore -//! use dotscope::graph::{DirectedGraph, NodeId, algorithms}; -//! -//! let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); -//! let entry = graph.add_node("entry"); -//! let a = graph.add_node("A"); -//! let b = graph.add_node("B"); -//! let exit = graph.add_node("exit"); -//! -//! graph.add_edge(entry, a, ()); -//! graph.add_edge(entry, b, ()); -//! graph.add_edge(a, exit, ()); -//! graph.add_edge(b, exit, ()); -//! -//! let dominators = algorithms::compute_dominators(&graph, entry); -//! assert!(dominators.dominates(entry, exit)); // entry dominates exit -//! ``` -//! -//! # Thread Safety -//! -//! All types in this module implement [`Send`] and [`Sync`] when their generic -//! parameters do, enabling safe concurrent access for analysis operations. - -mod directed; -mod edge; -mod indexed; -mod node; -mod traits; - -pub mod algorithms; - -// Re-export core types at module level -pub use directed::DirectedGraph; -#[allow(unused_imports)] -pub use edge::EdgeId; -pub use indexed::IndexedGraph; -pub use node::NodeId; -pub use traits::{GraphBase, Predecessors, RootedGraph, Successors}; diff --git a/dotscope/src/utils/graph/node.rs b/dotscope/src/utils/graph/node.rs deleted file mode 100644 index e20c0f23..00000000 --- a/dotscope/src/utils/graph/node.rs +++ /dev/null @@ -1,273 +0,0 @@ -//! Node identifier implementation for directed graphs. -//! -//! This module provides the [`NodeId`] type, a strongly-typed identifier for nodes -//! within a directed graph. The newtype wrapper provides type safety and prevents -//! accidental confusion between node indices and other integer values. - -use std::fmt; - -/// A strongly-typed identifier for nodes within a directed graph. -/// -/// `NodeId` wraps a `usize` index, providing type safety to prevent -/// accidental mixing of node indices with other integer values. Node IDs are assigned -/// sequentially starting from 0 when nodes are added to a graph. -/// -/// # Usage -/// -/// Node IDs are created when adding nodes to a graph and should not typically -/// be constructed manually. They are used to: -/// -/// - Reference nodes when adding edges -/// - Look up node data -/// - Query adjacency relationships -/// - Store analysis results indexed by node -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::graph::{DirectedGraph, NodeId}; -/// -/// let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); -/// let node_a: NodeId = graph.add_node("A"); -/// let node_b: NodeId = graph.add_node("B"); -/// -/// // NodeIds can be compared -/// assert_ne!(node_a, node_b); -/// -/// // NodeIds can be used as keys in collections -/// use std::collections::HashMap; -/// let mut data: HashMap = HashMap::new(); -/// data.insert(node_a, 42); -/// ``` -/// -/// # Thread Safety -/// -/// `NodeId` is [`Copy`], [`Send`], and [`Sync`], enabling efficient passing between -/// threads and use in concurrent data structures. -#[derive(Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)] -pub struct NodeId(pub(crate) usize); - -impl NodeId { - /// Creates a new `NodeId` from a raw index value. - /// - /// This constructor is primarily intended for internal use and testing. - /// Normal usage should obtain `NodeId` values from graph construction methods. - /// - /// # Arguments - /// - /// * `index` - The raw node index (0-based) - /// - /// # Returns - /// - /// A new `NodeId` wrapping the provided index. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::NodeId; - /// - /// let node = NodeId::new(0); - /// assert_eq!(node.index(), 0); - /// ``` - #[must_use] - #[inline] - pub const fn new(index: usize) -> Self { - NodeId(index) - } - - /// Returns the raw index value of this node identifier. - /// - /// The index is a 0-based position that can be used to index into vectors - /// or arrays that store per-node data. - /// - /// # Returns - /// - /// The underlying index value. - /// - /// # Examples - /// - /// ```rust,ignore - /// use dotscope::graph::NodeId; - /// - /// let node = NodeId::new(5); - /// assert_eq!(node.index(), 5); - /// - /// // Can be used to index into arrays - /// let data = vec![10, 20, 30, 40, 50, 60]; - /// let value = data[node.index()]; - /// assert_eq!(value, 60); - /// ``` - #[must_use] - #[inline] - pub const fn index(self) -> usize { - self.0 - } -} - -impl fmt::Debug for NodeId { - /// Formats the node ID for debugging output. - /// - /// The format shows the type name and index value for clear identification - /// in debug output and logging. - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "NodeId({})", self.0) - } -} - -impl fmt::Display for NodeId { - /// Formats the node ID for user display. - /// - /// The display format shows just the prefix and index for compact output. - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - write!(f, "n{}", self.0) - } -} - -impl From for NodeId { - /// Converts a raw `usize` index into a `NodeId`. - /// - /// This conversion is provided for convenience but should be used carefully - /// to avoid creating invalid node IDs that don't correspond to actual nodes - /// in a graph. - #[inline] - fn from(index: usize) -> Self { - NodeId(index) - } -} - -impl From for usize { - /// Extracts the raw index from a `NodeId`. - #[inline] - fn from(node: NodeId) -> Self { - node.0 - } -} - -#[cfg(test)] -mod tests { - use super::*; - use std::collections::{HashMap, HashSet}; - - #[test] - fn test_node_id_new() { - let node = NodeId::new(42); - assert_eq!(node.index(), 42); - } - - #[test] - fn test_node_id_index() { - let node = NodeId::new(100); - assert_eq!(node.index(), 100); - } - - #[test] - fn test_node_id_equality() { - let node1 = NodeId::new(5); - let node2 = NodeId::new(5); - let node3 = NodeId::new(10); - - assert_eq!(node1, node2); - assert_ne!(node1, node3); - } - - #[test] - fn test_node_id_ordering() { - let node1 = NodeId::new(1); - let node2 = NodeId::new(2); - let node3 = NodeId::new(3); - - assert!(node1 < node2); - assert!(node2 < node3); - assert!(node1 < node3); - - let mut nodes = vec![node3, node1, node2]; - nodes.sort(); - assert_eq!(nodes, vec![node1, node2, node3]); - } - - #[test] - fn test_node_id_hash() { - let mut set: HashSet = HashSet::new(); - let node1 = NodeId::new(1); - let node2 = NodeId::new(2); - let node1_dup = NodeId::new(1); - - set.insert(node1); - set.insert(node2); - set.insert(node1_dup); // Should not add duplicate - - assert_eq!(set.len(), 2); - assert!(set.contains(&node1)); - assert!(set.contains(&node2)); - } - - #[test] - fn test_node_id_as_map_key() { - let mut map: HashMap = HashMap::new(); - let node1 = NodeId::new(1); - let node2 = NodeId::new(2); - - map.insert(node1, "first"); - map.insert(node2, "second"); - - assert_eq!(map.get(&node1), Some(&"first")); - assert_eq!(map.get(&node2), Some(&"second")); - assert_eq!(map.get(&NodeId::new(3)), None); - } - - #[test] - fn test_node_id_copy_semantics() { - let node1 = NodeId::new(42); - let node2 = node1; // Copy - - assert_eq!(node1, node2); - assert_eq!(node1.index(), 42); - assert_eq!(node2.index(), 42); - } - - #[test] - fn test_node_id_from_usize() { - let node: NodeId = 123usize.into(); - assert_eq!(node.index(), 123); - } - - #[test] - fn test_node_id_into_usize() { - let node = NodeId::new(789); - let value: usize = node.into(); - assert_eq!(value, 789); - } - - #[test] - fn test_node_id_debug_format() { - let node = NodeId::new(42); - let debug_str = format!("{node:?}"); - assert_eq!(debug_str, "NodeId(42)"); - } - - #[test] - fn test_node_id_display_format() { - let node = NodeId::new(42); - let display_str = format!("{node}"); - assert_eq!(display_str, "n42"); - } - - #[test] - fn test_node_id_boundary_values() { - // Test zero - let zero = NodeId::new(0); - assert_eq!(zero.index(), 0); - - // Test large value - let large = NodeId::new(1_000_000); - assert_eq!(large.index(), 1_000_000); - } - - #[test] - fn test_node_id_array_indexing() { - let data = ["zero", "one", "two", "three"]; - let node = NodeId::new(2); - - assert_eq!(data[node.index()], "two"); - } -} diff --git a/dotscope/src/utils/graph/traits.rs b/dotscope/src/utils/graph/traits.rs deleted file mode 100644 index b027ef99..00000000 --- a/dotscope/src/utils/graph/traits.rs +++ /dev/null @@ -1,335 +0,0 @@ -//! Trait definitions for graph abstractions. -//! -//! This module defines the core traits that enable graph algorithms to work with -//! different graph implementations. By programming against these traits, algorithms -//! can be reused across various graph types without modification. -//! -//! # Architecture -//! -//! The trait hierarchy is designed to be minimal and composable: -//! -//! - [`GraphBase`] - Core properties: node count and node iteration -//! - [`Successors`] - Forward edge traversal (outgoing edges) -//! - [`Predecessors`] - Backward edge traversal (incoming edges) -//! - [`RootedGraph`] - Graphs with a designated entry node (for dominator computation) -//! -//! # Design Principles -//! -//! ## Iterator-Based Traversal -//! -//! All adjacency queries return iterators rather than collections, enabling lazy -//! evaluation and avoiding unnecessary allocations for simple traversals. -//! -//! ## Minimal Requirements -//! -//! Each trait requires only what is necessary for its purpose, allowing different -//! graph implementations to provide only the capabilities they support. - -use crate::utils::graph::NodeId; - -/// Base trait providing core graph properties. -/// -/// This trait defines the fundamental properties that all graphs must have: -/// the number of nodes and the ability to iterate over all node identifiers. -/// -/// # Required Methods -/// -/// - [`node_count`](GraphBase::node_count) - Returns the total number of nodes -/// - [`node_ids`](GraphBase::node_ids) - Returns an iterator over all node IDs -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::graph::{DirectedGraph, GraphBase}; -/// -/// let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); -/// graph.add_node("A"); -/// graph.add_node("B"); -/// -/// assert_eq!(graph.node_count(), 2); -/// -/// let ids: Vec<_> = graph.node_ids().collect(); -/// assert_eq!(ids.len(), 2); -/// ``` -pub trait GraphBase { - /// Returns the number of nodes in the graph. - /// - /// This count includes all nodes that have been added to the graph, - /// regardless of their connectivity. - fn node_count(&self) -> usize; - - /// Returns an iterator over all node identifiers in the graph. - /// - /// The iteration order is typically the order in which nodes were added - /// to the graph (i.e., by ascending `NodeId` index). - fn node_ids(&self) -> impl Iterator; -} - -/// Trait for graphs that support forward edge traversal. -/// -/// This trait provides access to the successor nodes of any given node, -/// enabling forward graph traversal and algorithms that follow edges in -/// their natural direction. -/// -/// # Required Methods -/// -/// - [`successors`](Successors::successors) - Returns an iterator over successor nodes -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::graph::{DirectedGraph, NodeId, Successors}; -/// -/// let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); -/// let a = graph.add_node("A"); -/// let b = graph.add_node("B"); -/// let c = graph.add_node("C"); -/// -/// graph.add_edge(a, b, ()); -/// graph.add_edge(a, c, ()); -/// -/// let successors: Vec = graph.successors(a).collect(); -/// assert_eq!(successors.len(), 2); -/// assert!(successors.contains(&b)); -/// assert!(successors.contains(&c)); -/// ``` -pub trait Successors: GraphBase { - /// Returns an iterator over the successor nodes of the given node. - /// - /// Successors are nodes that are targets of edges originating from the - /// specified node. For a directed edge `(u, v)`, node `v` is a successor of `u`. - /// - /// # Arguments - /// - /// * `node` - The node whose successors to iterate - /// - /// # Returns - /// - /// An iterator yielding the `NodeId` of each successor node. - /// - /// # Panics - /// - /// May panic if `node` is not a valid node in the graph. - fn successors(&self, node: NodeId) -> impl Iterator; -} - -/// Trait for graphs that support backward edge traversal. -/// -/// This trait provides access to the predecessor nodes of any given node, -/// enabling backward graph traversal and algorithms that need to follow edges -/// in reverse. -/// -/// # Required Methods -/// -/// - [`predecessors`](Predecessors::predecessors) - Returns an iterator over predecessor nodes -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::graph::{DirectedGraph, NodeId, Predecessors}; -/// -/// let mut graph: DirectedGraph<&str, ()> = DirectedGraph::new(); -/// let a = graph.add_node("A"); -/// let b = graph.add_node("B"); -/// let c = graph.add_node("C"); -/// -/// graph.add_edge(a, c, ()); -/// graph.add_edge(b, c, ()); -/// -/// let predecessors: Vec = graph.predecessors(c).collect(); -/// assert_eq!(predecessors.len(), 2); -/// assert!(predecessors.contains(&a)); -/// assert!(predecessors.contains(&b)); -/// ``` -pub trait Predecessors: GraphBase { - /// Returns an iterator over the predecessor nodes of the given node. - /// - /// Predecessors are nodes that are sources of edges targeting the - /// specified node. For a directed edge `(u, v)`, node `u` is a predecessor of `v`. - /// - /// # Arguments - /// - /// * `node` - The node whose predecessors to iterate - /// - /// # Returns - /// - /// An iterator yielding the `NodeId` of each predecessor node. - /// - /// # Panics - /// - /// May panic if `node` is not a valid node in the graph. - fn predecessors(&self, node: NodeId) -> impl Iterator; -} - -/// Trait for graphs with a designated entry (root) node. -/// -/// This trait extends [`Successors`] and [`Predecessors`] to indicate that the -/// graph has a single distinguished entry point. This is essential for algorithms -/// like dominator computation that require a well-defined starting point. -/// -/// # Required Methods -/// -/// - [`entry`](RootedGraph::entry) - Returns the entry node of the graph -/// -/// # Use Cases -/// -/// - **Control Flow Graphs**: The entry node is the first basic block -/// - **Call Graphs**: The entry could be the main/entry point method -/// - **Dependency Graphs**: The entry represents the root dependency -/// -/// # Examples -/// -/// ```rust,ignore -/// use dotscope::graph::{DirectedGraph, NodeId, RootedGraph, Successors, Predecessors}; -/// -/// // Create a control flow graph with explicit entry -/// struct ControlFlowGraph { -/// graph: DirectedGraph<&'static str, ()>, -/// entry: NodeId, -/// } -/// -/// impl dotscope::graph::GraphBase for ControlFlowGraph { -/// fn node_count(&self) -> usize { self.graph.node_count() } -/// fn node_ids(&self) -> impl Iterator { self.graph.node_ids() } -/// } -/// -/// impl Successors for ControlFlowGraph { -/// fn successors(&self, node: NodeId) -> impl Iterator { -/// self.graph.successors(node) -/// } -/// } -/// -/// impl Predecessors for ControlFlowGraph { -/// fn predecessors(&self, node: NodeId) -> impl Iterator { -/// self.graph.predecessors(node) -/// } -/// } -/// -/// impl RootedGraph for ControlFlowGraph { -/// fn entry(&self) -> NodeId { self.entry } -/// } -/// ``` -pub trait RootedGraph: Successors + Predecessors { - /// Returns the entry (root) node of the graph. - /// - /// The entry node is the designated starting point for forward traversals - /// and the root for dominator tree computation. In a control flow graph, - /// this is typically the first basic block of a function. - /// - /// # Returns - /// - /// The `NodeId` of the entry node. - fn entry(&self) -> NodeId; -} - -#[cfg(test)] -mod tests { - use super::*; - - // A minimal test graph implementation for trait testing - struct TestGraph { - node_count: usize, - edges: Vec<(NodeId, NodeId)>, - entry: NodeId, - } - - impl TestGraph { - fn new(node_count: usize, edges: Vec<(NodeId, NodeId)>, entry: NodeId) -> Self { - TestGraph { - node_count, - edges, - entry, - } - } - } - - impl GraphBase for TestGraph { - fn node_count(&self) -> usize { - self.node_count - } - - fn node_ids(&self) -> impl Iterator { - (0..self.node_count).map(NodeId::new) - } - } - - impl Successors for TestGraph { - fn successors(&self, node: NodeId) -> impl Iterator { - self.edges - .iter() - .filter(move |(src, _)| *src == node) - .map(|(_, dst)| *dst) - } - } - - impl Predecessors for TestGraph { - fn predecessors(&self, node: NodeId) -> impl Iterator { - self.edges - .iter() - .filter(move |(_, dst)| *dst == node) - .map(|(src, _)| *src) - } - } - - impl RootedGraph for TestGraph { - fn entry(&self) -> NodeId { - self.entry - } - } - - #[test] - fn test_graph_base() { - let graph = TestGraph::new(5, vec![], NodeId::new(0)); - assert_eq!(graph.node_count(), 5); - - let ids: Vec = graph.node_ids().collect(); - assert_eq!(ids.len(), 5); - assert_eq!(ids[0], NodeId::new(0)); - assert_eq!(ids[4], NodeId::new(4)); - } - - #[test] - fn test_successors() { - let edges = vec![ - (NodeId::new(0), NodeId::new(1)), - (NodeId::new(0), NodeId::new(2)), - (NodeId::new(1), NodeId::new(3)), - ]; - let graph = TestGraph::new(4, edges, NodeId::new(0)); - - let succ: Vec = graph.successors(NodeId::new(0)).collect(); - assert_eq!(succ.len(), 2); - assert!(succ.contains(&NodeId::new(1))); - assert!(succ.contains(&NodeId::new(2))); - - let succ: Vec = graph.successors(NodeId::new(1)).collect(); - assert_eq!(succ.len(), 1); - assert!(succ.contains(&NodeId::new(3))); - - let succ: Vec = graph.successors(NodeId::new(3)).collect(); - assert!(succ.is_empty()); - } - - #[test] - fn test_predecessors() { - let edges = vec![ - (NodeId::new(0), NodeId::new(2)), - (NodeId::new(1), NodeId::new(2)), - ]; - let graph = TestGraph::new(3, edges, NodeId::new(0)); - - let pred: Vec = graph.predecessors(NodeId::new(2)).collect(); - assert_eq!(pred.len(), 2); - assert!(pred.contains(&NodeId::new(0))); - assert!(pred.contains(&NodeId::new(1))); - - let pred: Vec = graph.predecessors(NodeId::new(0)).collect(); - assert!(pred.is_empty()); - } - - #[test] - fn test_rooted_graph() { - let graph = TestGraph::new(3, vec![], NodeId::new(1)); - assert_eq!(graph.entry(), NodeId::new(1)); - } -} diff --git a/dotscope/src/utils/mod.rs b/dotscope/src/utils/mod.rs index f10a677d..6e34ea76 100644 --- a/dotscope/src/utils/mod.rs +++ b/dotscope/src/utils/mod.rs @@ -43,13 +43,11 @@ mod alignment; mod base64; -mod bitset; mod compression; mod crypto; mod decompress; mod dot; mod enums; -pub(crate) mod graph; mod hash; mod heap_calc; mod io; @@ -61,7 +59,6 @@ mod visitedmap; pub use alignment::align_to; #[cfg(feature = "emulation")] pub use base64::{base64_decode, base64_encode}; -pub use bitset::BitSet; pub use compression::compressed_uint_size; pub use dot::escape_dot; pub use enums::EnumUtils; @@ -78,8 +75,6 @@ pub use io::{ }; #[cfg(feature = "emulation")] pub use lebytes::LeBytes; -#[cfg(feature = "compiler")] -pub use math::is_power_of_two; #[cfg(feature = "emulation")] pub use math::to_i32_saturating; pub use math::to_u32; diff --git a/dotscope/tests/bitmono.rs b/dotscope/tests/bitmono.rs index b0901c47..1ad5b496 100644 --- a/dotscope/tests/bitmono.rs +++ b/dotscope/tests/bitmono.rs @@ -1,3 +1,12 @@ +#![allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::arithmetic_side_effects, + clippy::indexing_slicing, + missing_docs +)] + //! BitMono deobfuscation integration tests. //! //! This test suite verifies deobfuscation of BitMono-protected assemblies, diff --git a/dotscope/tests/common/compatibility.rs b/dotscope/tests/common/compatibility.rs index 59ac0665..34d478dd 100644 --- a/dotscope/tests/common/compatibility.rs +++ b/dotscope/tests/common/compatibility.rs @@ -1,3 +1,12 @@ +#![allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::arithmetic_side_effects, + clippy::indexing_slicing, + missing_docs +)] + //! Shared assembly compatibility testing infrastructure. //! //! Provides common loading, comparison, and analysis logic for testing diff --git a/dotscope/tests/common/framework.rs b/dotscope/tests/common/framework.rs index fef1fdd3..ca719588 100644 --- a/dotscope/tests/common/framework.rs +++ b/dotscope/tests/common/framework.rs @@ -1,3 +1,12 @@ +#![allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::arithmetic_side_effects, + clippy::indexing_slicing, + missing_docs +)] + //! Shared integration test framework for deobfuscation tests. //! //! Provides common types, harness functions, and assertion helpers used across diff --git a/dotscope/tests/common/verification.rs b/dotscope/tests/common/verification.rs index 3696393a..ed4d280e 100644 --- a/dotscope/tests/common/verification.rs +++ b/dotscope/tests/common/verification.rs @@ -1,3 +1,12 @@ +#![allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::arithmetic_side_effects, + clippy::indexing_slicing, + missing_docs +)] + //! Shared integration test verification framework. //! //! Provides SSA-based semantic verification for comparing original vs deobfuscated diff --git a/dotscope/tests/confuserex.rs b/dotscope/tests/confuserex.rs index 8f808f72..4fdd968c 100644 --- a/dotscope/tests/confuserex.rs +++ b/dotscope/tests/confuserex.rs @@ -1,3 +1,12 @@ +#![allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::arithmetic_side_effects, + clippy::indexing_slicing, + missing_docs +)] + //! ConfuserEx deobfuscation integration tests. //! //! This test suite verifies deobfuscation of ConfuserEx-protected assemblies, diff --git a/dotscope/tests/crafted_1.rs b/dotscope/tests/crafted_1.rs index e40ea946..bfad634a 100644 --- a/dotscope/tests/crafted_1.rs +++ b/dotscope/tests/crafted_1.rs @@ -1,3 +1,6 @@ +//! This is the crafted.exe source-code +//! + /* Requires .NET Core 3.0 diff --git a/dotscope/tests/deobfuscation.rs b/dotscope/tests/deobfuscation.rs index 34dc46ff..2f042500 100644 --- a/dotscope/tests/deobfuscation.rs +++ b/dotscope/tests/deobfuscation.rs @@ -1,3 +1,12 @@ +#![allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::arithmetic_side_effects, + clippy::indexing_slicing, + missing_docs +)] + //! Deobfuscation pipeline tests with exact IL verification. //! //! All tests follow the same rigorous pattern: diff --git a/dotscope/tests/fuzzer.rs b/dotscope/tests/fuzzer.rs index c6f3ec2c..7616f497 100644 --- a/dotscope/tests/fuzzer.rs +++ b/dotscope/tests/fuzzer.rs @@ -1,3 +1,7 @@ +//! Fuzzer corpus regression tests — load every file under `fuzz/corpus/` and +//! `fuzz/artifacts/` through `CilObject::from_path` and assert we don't panic. +//! Files are intentionally malformed; errors are expected, crashes are not. + use std::{fs, path::PathBuf}; use dotscope::metadata::cilobject::CilObject; diff --git a/dotscope/tests/jiejie.rs b/dotscope/tests/jiejie.rs index 0d717492..e293f233 100644 --- a/dotscope/tests/jiejie.rs +++ b/dotscope/tests/jiejie.rs @@ -1,3 +1,12 @@ +#![allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::arithmetic_side_effects, + clippy::indexing_slicing, + missing_docs +)] + //! JIEJIE.NET deobfuscation integration tests. //! //! This test suite verifies deobfuscation of JIEJIE.NET-protected assemblies, diff --git a/dotscope/tests/modify_add.rs b/dotscope/tests/modify_add.rs index c914891d..0dbdf60a 100644 --- a/dotscope/tests/modify_add.rs +++ b/dotscope/tests/modify_add.rs @@ -1,3 +1,12 @@ +#![allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::arithmetic_side_effects, + clippy::indexing_slicing, + missing_docs +)] + //! Integration tests for the write module. //! //! These tests verify the complete end-to-end functionality of writing diff --git a/dotscope/tests/modify_basic.rs b/dotscope/tests/modify_basic.rs index bd1ce249..8479c698 100644 --- a/dotscope/tests/modify_basic.rs +++ b/dotscope/tests/modify_basic.rs @@ -1,3 +1,12 @@ +#![allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::arithmetic_side_effects, + clippy::indexing_slicing, + missing_docs +)] + //! Basic write pipeline integration tests. //! //! Tests for basic assembly writing functionality, including unmodified assemblies diff --git a/dotscope/tests/modify_heaps.rs b/dotscope/tests/modify_heaps.rs index f28e7d8d..58043891 100644 --- a/dotscope/tests/modify_heaps.rs +++ b/dotscope/tests/modify_heaps.rs @@ -1,3 +1,12 @@ +#![allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::arithmetic_side_effects, + clippy::indexing_slicing, + missing_docs +)] + //! Heap modification integration tests. //! //! Tests for modifying metadata heaps (strings, blobs, GUIDs, userstrings) and verifying diff --git a/dotscope/tests/modify_impexp.rs b/dotscope/tests/modify_impexp.rs index b9ea1fd7..a80b5474 100644 --- a/dotscope/tests/modify_impexp.rs +++ b/dotscope/tests/modify_impexp.rs @@ -1,3 +1,12 @@ +#![allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::arithmetic_side_effects, + clippy::indexing_slicing, + missing_docs +)] + //! Integration tests for native import/export functionality. //! //! These tests verify the complete end-to-end functionality of adding diff --git a/dotscope/tests/modify_roundtrips_crafted2.rs b/dotscope/tests/modify_roundtrips_crafted2.rs index 796afafa..b8c2d5a8 100644 --- a/dotscope/tests/modify_roundtrips_crafted2.rs +++ b/dotscope/tests/modify_roundtrips_crafted2.rs @@ -1,3 +1,12 @@ +#![allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::arithmetic_side_effects, + clippy::indexing_slicing, + missing_docs +)] + //! Consolidated integration tests for dotscope assembly modification round-trip operations. //! //! These tests validate the complete public API by simulating real user implementations. diff --git a/dotscope/tests/modify_roundtrips_method.rs b/dotscope/tests/modify_roundtrips_method.rs index 40eda1ff..3fae4a6e 100644 --- a/dotscope/tests/modify_roundtrips_method.rs +++ b/dotscope/tests/modify_roundtrips_method.rs @@ -1,3 +1,12 @@ +#![allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::arithmetic_side_effects, + clippy::indexing_slicing, + missing_docs +)] + //! Integration test for method injection roundtrip //! //! This test verifies that: diff --git a/dotscope/tests/modify_roundtrips_wbdll.rs b/dotscope/tests/modify_roundtrips_wbdll.rs index 00dbf76e..85d764f5 100644 --- a/dotscope/tests/modify_roundtrips_wbdll.rs +++ b/dotscope/tests/modify_roundtrips_wbdll.rs @@ -1,3 +1,12 @@ +#![allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::arithmetic_side_effects, + clippy::indexing_slicing, + missing_docs +)] + //! True round-trip integration tests for assembly modification operations. //! //! These tests validate the complete write pipeline by: diff --git a/dotscope/tests/netreactor.rs b/dotscope/tests/netreactor.rs index fc9a191d..55c0dbe3 100644 --- a/dotscope/tests/netreactor.rs +++ b/dotscope/tests/netreactor.rs @@ -1,3 +1,12 @@ +#![allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::arithmetic_side_effects, + clippy::indexing_slicing, + missing_docs +)] + //! .NET Reactor deobfuscation integration tests. //! //! This test suite verifies deobfuscation of .NET Reactor-protected assemblies. diff --git a/dotscope/tests/obfuscar.rs b/dotscope/tests/obfuscar.rs index 23c39ed5..24fb3ba6 100644 --- a/dotscope/tests/obfuscar.rs +++ b/dotscope/tests/obfuscar.rs @@ -1,3 +1,12 @@ +#![allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::arithmetic_side_effects, + clippy::indexing_slicing, + missing_docs +)] + //! Obfuscar deobfuscation integration tests. //! //! This test suite verifies deobfuscation of Obfuscar-protected assemblies, diff --git a/dotscope/tests/roundtrip_asm.rs b/dotscope/tests/roundtrip_asm.rs index 69917d09..ffda5c44 100644 --- a/dotscope/tests/roundtrip_asm.rs +++ b/dotscope/tests/roundtrip_asm.rs @@ -1,3 +1,12 @@ +#![allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::arithmetic_side_effects, + clippy::indexing_slicing, + missing_docs +)] + //! Roundtrip tests for CIL instruction assembly and disassembly. //! //! These tests verify that our encoder and disassembler work perfectly together by: diff --git a/dotscope/tests/ssa.rs b/dotscope/tests/ssa.rs index c53174d1..fb4092e0 100644 --- a/dotscope/tests/ssa.rs +++ b/dotscope/tests/ssa.rs @@ -1,3 +1,12 @@ +#![allow( + clippy::unwrap_used, + clippy::expect_used, + clippy::panic, + clippy::arithmetic_side_effects, + clippy::indexing_slicing, + missing_docs +)] + //! SSA (Static Single Assignment) integration tests. //! //! These tests verify the complete SSA pipeline using the public API: @@ -15,8 +24,8 @@ use std::collections::HashMap; use common::{build_cfg, build_ssa, TestTypeProvider}; use dotscope::{ analysis::{ - ConstValue, ControlFlowGraph, SsaConverter, SsaExceptionHandler, SsaFunction, SsaOp, - SsaVarId, SymbolicEvaluator, SymbolicExpr, + ConstValue, ControlFlowGraph, SsaConverter, SsaExceptionHandler, SsaExceptionHandlerCilExt, + SsaFunction, SsaOp, SsaVarId, SymbolicEvaluator, SymbolicExpr, }, assembly::{decode_blocks, InstructionAssembler}, metadata::{ From 2d83481ca5b787a35387d511c15c610110d434ef Mon Sep 17 00:00:00 2001 From: BinFlip Date: Sat, 9 May 2026 07:15:26 -0700 Subject: [PATCH 3/6] fix: CI/CD issues from ssa migration --- dotscope/src/analysis/ssa/symbolic/solver.rs | 22 ++++++++++++++++---- dotscope/src/analysis/ssa/target.rs | 6 +++++- 2 files changed, 23 insertions(+), 5 deletions(-) diff --git a/dotscope/src/analysis/ssa/symbolic/solver.rs b/dotscope/src/analysis/ssa/symbolic/solver.rs index c25dd270..f0e83b84 100644 --- a/dotscope/src/analysis/ssa/symbolic/solver.rs +++ b/dotscope/src/analysis/ssa/symbolic/solver.rs @@ -6,8 +6,10 @@ use std::collections::HashMap; -use crate::metadata::typesystem::PointerSize; -use analyssa::analysis::symbolic::{expr::SymbolicExpr, ops::SymbolicOp}; +use crate::{analysis::ssa::CilTarget, metadata::typesystem::PointerSize}; +use analyssa::analysis::symbolic::{expr::SymbolicExpr as GenericSymbolicExpr, ops::SymbolicOp}; + +type SymbolicExpr = GenericSymbolicExpr; /// Z3-based constraint solver for symbolic expressions. /// @@ -509,8 +511,20 @@ impl Z3Solver { SymbolicOp::GeU => left_z3 .bvuge(&right_z3) .ite(&z3::ast::BV::from_i64(1, 32), &z3::ast::BV::from_i64(0, 32)), - // Unary ops shouldn't appear in binary context - SymbolicOp::Neg | SymbolicOp::Not => left_z3, + // Rotate operations. + SymbolicOp::Rol => left_z3.bvrotl(&right_z3), + SymbolicOp::Ror => left_z3.bvrotr(&right_z3), + // Rotate-through-carry have no direct bitvector mapping. + SymbolicOp::Rcl | SymbolicOp::Rcr => left_z3, + // Unary ops shouldn't appear in binary context. + SymbolicOp::Neg + | SymbolicOp::Not + | SymbolicOp::BSwap + | SymbolicOp::BRev + | SymbolicOp::BitScanForward + | SymbolicOp::BitScanReverse + | SymbolicOp::Popcount + | SymbolicOp::Parity => left_z3, } } } diff --git a/dotscope/src/analysis/ssa/target.rs b/dotscope/src/analysis/ssa/target.rs index 60e02339..7b99931e 100644 --- a/dotscope/src/analysis/ssa/target.rs +++ b/dotscope/src/analysis/ssa/target.rs @@ -11,10 +11,11 @@ use analyssa::{ir::value::ConstValue, PointerSize}; +#[cfg(feature = "compiler")] +use crate::compiler::CilCapability; use crate::{ analysis::ssa::types::{FieldRef, MethodRef, SigRef, SsaType, TypeRef}, assembly::{FlowType, Instruction, InstructionCategory, Operand, StackBehavior}, - compiler::CilCapability, metadata::{method::ExceptionHandlerFlags, signatures::SignatureLocalVariable}, }; @@ -71,7 +72,10 @@ impl Target for CilTarget { type Type = SsaType; type OriginalInstruction = Instruction; type LocalSignature = SignatureLocalVariable; + #[cfg(feature = "compiler")] type Capability = CilCapability; + #[cfg(not(feature = "compiler"))] + type Capability = (); fn ptr_bytes(&self) -> u32 { self.ptr_bytes From 75c92018b03896dedbdfbfa8b541345cd55bbac0 Mon Sep 17 00:00:00 2001 From: BinFlip Date: Wed, 3 Jun 2026 18:37:28 -0700 Subject: [PATCH 4/6] feat: updated analyssa to 0.2.0 and fixed several discovered bugs as well as improvements to error handling --- CHANGELOG.md | 15 + Cargo.lock | 316 ++++++------- dotscope-cli/Cargo.toml | 8 +- dotscope-cli/src/commands/attrs.rs | 2 +- dotscope-cli/src/commands/resolution.rs | 2 +- dotscope/Cargo.toml | 18 +- dotscope/fuzz/Cargo.lock | 344 +++++++------- dotscope/src/analysis/ssa/converter.rs | 162 ++++++- dotscope/src/analysis/ssa/mod.rs | 2 +- dotscope/src/analysis/ssa/target.rs | 2 + dotscope/src/analysis/ssa/value.rs | 4 +- dotscope/src/assembly/instruction.rs | 208 ++++++++ .../src/cilassembly/writer/heaps/streaming.rs | 99 ++-- dotscope/src/compiler/codegen/mod.rs | 107 ++++- dotscope/src/compiler/passes/constants/mod.rs | 25 +- .../src/compiler/passes/constants/tests.rs | 14 +- dotscope/src/compiler/scheduler.rs | 20 +- dotscope/src/deobfuscation/engine/analysis.rs | 2 +- .../deobfuscation/passes/bitmono/strings.rs | 2 +- .../deobfuscation/passes/bitmono/unmanaged.rs | 2 +- .../src/deobfuscation/passes/decryption.rs | 9 +- .../src/deobfuscation/passes/delegates.rs | 2 +- .../src/deobfuscation/passes/reflection.rs | 6 +- .../src/deobfuscation/passes/staticfields.rs | 4 +- .../passes/unflattening/detection.rs | 33 +- .../passes/unflattening/dispatcher.rs | 35 +- .../passes/unflattening/tracer/helpers.rs | 2 +- dotscope/src/deobfuscation/renamer/cascade.rs | 11 +- .../src/deobfuscation/renamer/features.rs | 13 +- .../deobfuscation/renamer/providers/local.rs | 19 +- .../deobfuscation/techniques/bitmono/hooks.rs | 4 +- .../techniques/bitmono/strings.rs | 2 +- .../techniques/bitmono/unmanaged.rs | 2 +- .../techniques/confuserex/constants.rs | 4 +- .../techniques/confuserex/resources.rs | 2 +- .../techniques/confuserex/statemachine.rs | 10 +- .../techniques/generic/strings.rs | 6 +- .../techniques/jiejienet/arrays.rs | 10 +- .../techniques/jiejienet/resources.rs | 14 +- .../techniques/jiejienet/typeofs.rs | 6 +- .../techniques/netreactor/antitamp.rs | 1 + .../techniques/netreactor/helpers.rs | 13 +- .../techniques/netreactor/necrobit.rs | 3 +- .../techniques/netreactor/resources.rs | 4 +- .../techniques/obfuscar/strings.rs | 2 +- dotscope/src/deobfuscation/utils.rs | 3 +- dotscope/src/emulation/engine/generics.rs | 60 +-- .../runtime/bcl/reflection/members.rs | 1 + .../runtime/bcl/reflection/methods.rs | 4 +- dotscope/src/emulation/value/emvalue.rs | 1 + dotscope/src/error.rs | 447 +++++++++++++++--- dotscope/src/file/mod.rs | 96 ++-- dotscope/src/file/parser.rs | 200 ++++---- dotscope/src/file/pe.rs | 143 +++--- dotscope/src/formatting/tokens.rs | 2 +- dotscope/src/lib.rs | 6 +- dotscope/src/metadata/cilobject.rs | 97 ++-- .../src/metadata/customattributes/parser.rs | 6 +- .../src/metadata/identity/cryptographic.rs | 8 +- dotscope/src/metadata/marshalling/parser.rs | 12 +- dotscope/src/metadata/method/exceptions.rs | 116 +++++ dotscope/src/metadata/method/mod.rs | 191 ++++++++ dotscope/src/metadata/method/types.rs | 184 +++++++ dotscope/src/metadata/resolver.rs | 4 +- dotscope/src/metadata/resources/parser.rs | 9 +- dotscope/src/metadata/root.rs | 288 ++++++----- .../src/metadata/security/permissionset.rs | 9 +- dotscope/src/metadata/signatures/parser.rs | 6 +- dotscope/src/metadata/streams/blob.rs | 22 +- dotscope/src/metadata/streams/guid.rs | 24 +- dotscope/src/metadata/streams/streamheader.rs | 80 ++-- dotscope/src/metadata/streams/strings.rs | 24 +- dotscope/src/metadata/streams/tablesheader.rs | 53 ++- dotscope/src/metadata/streams/userstrings.rs | 39 +- .../src/metadata/tables/genericparam/mod.rs | 141 +++++- .../src/metadata/typesystem/primitives.rs | 143 +++++- dotscope/src/metadata/validation/config.rs | 132 +++++- dotscope/src/metadata/validation/scanner.rs | 24 +- .../src/metadata/validation/shared/mod.rs | 12 + .../src/metadata/validation/shared/schema.rs | 8 +- .../validators/raw/constraints/generic.rs | 23 +- .../validators/raw/constraints/layout.rs | 27 +- .../validators/raw/modification/integrity.rs | 2 +- .../validators/raw/modification/operation.rs | 2 +- .../validators/raw/structure/heap.rs | 6 +- .../validators/raw/structure/table.rs | 19 +- dotscope/src/prelude.rs | 7 + dotscope/src/test/analysis/runner.rs | 4 +- dotscope/src/test/builders/methods.rs | 5 + .../validation/raw_constraints_generic.rs | 6 +- .../validation/raw_constraints_layout.rs | 12 +- .../validation/raw_structure_table.rs | 8 +- dotscope/src/utils/enums.rs | 12 +- dotscope/src/utils/io.rs | 92 ++-- dotscope/src/utils/math.rs | 12 +- dotscope/tests/bitmono.rs | 8 +- dotscope/tests/common/verification.rs | 4 +- dotscope/tests/modify_roundtrips_method.rs | 14 +- 98 files changed, 3162 insertions(+), 1267 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index de35eeaa..f525849d 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -5,6 +5,21 @@ All notable changes to this project will be documented in this file. The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). +## [0.8.0] - 2026-06-03 + +### Changed + +- **SSA functionality extracted into the standalone [`analyssa`](https://crates.io/crates/analyssa) crate**: the target-agnostic SSA IR, analyses, and optimization/deobfuscation pass framework now live in `analyssa`, with dotscope providing the CIL-specific host (lifting, type system, codegen). dotscope builds on `analyssa` 0.2.0, which adds the native SSA substrate (SIMD/vector ops, native atomics, wide arithmetic, boolean ops), a fluent SSA builder and verifier-checked editor, and ~3–4× smaller core IR types. The CIL pass scheduler is built on `analyssa`'s `PassScheduler::empty` so the deobfuscation pipeline keeps full control over which passes run +- **Structured parse errors** (**breaking**): parse failures are now reported through `Error::Parse(ParseFailure)` with a `ParseStage`, replacing the stringly-typed `Error::Malformed` / `Error::OutOfBounds` / `Error::HeapBoundsError` variants at parse sites. The `malformed_error!` / out-of-bounds helper macros and `#[non_exhaustive] ParseFailure` give callers categorizable, source-located errors. Code that matched the removed variants must migrate to `Error::Parse(..)` +- **Fallible metadata lookups** (**breaking**): metadata lookup APIs such as `CilObject::method()` now return `Result<_, Error>` instead of `Option<_>`, so a missing or unresolvable token reports a typed error rather than a bare `None`. Callers using `if let Some(..)` / `?`-on-`Option` must switch to the `Result` forms +- **Hardening against malformed/adversarial input**: enabled strict crate lints (`unwrap_used`, `expect_used`, `panic`, `arithmetic_side_effects`, `indexing_slicing` set to `deny`) and reworked the metadata parsers (custom attributes, marshalling, resources, signatures), the validation layer (scanner, schema, raw/owned constraints), and the deobfuscation passes to use fallible, bounds-checked, overflow-safe access. The parser no longer panics on crafted inputs + +### Fixed + +- **SSA construction — operand corruption**: the placeholder→final variable rename in `SsaConverter` applied cascading by-value replacements, which under `analyssa` 0.2.0's variable-id encoding could collapse a binary operation's operands (e.g. `a - b` becoming `b - b`). The rename is now applied atomically/position-wise, so it is correct regardless of id numbering (handles operand aliasing, swaps, repeated operands, and the reserved placeholder id). This only affected method bodies regenerated from SSA (deobfuscation output) +- **CFF unflattening — expression-obfuscated dispatchers**: state-variable backward tracing (`trace_to_phi`) now follows `neg`/`not`/`conv` wrappers, so ConfuserEx "expression" control-flow flattening (`-(!!state)`-style transforms) no longer hides the dispatcher state phi. `Dispatcher::refresh` reuses the same tracer instead of a shallow, fixed-depth walk +- **CFF unflattening — nested/exception-handler dispatchers**: when a dispatcher's state-setup block is also one of its own switch case targets (common for handler-region CFF), the initial state was not recovered, so the tracer could not seed the state machine and explored every path until hitting its limits — leaving a residual switch and dropping code. The initial state is now recovered from the state phi's constant operand; ConfuserEx control-flow + expression samples fully unflatten again + ## [0.7.0] - 2026-05-03 ### Added diff --git a/Cargo.lock b/Cargo.lock index 82876d96..7bb863a3 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -20,9 +20,9 @@ checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" [[package]] name = "aes" -version = "0.9.0" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66bd29a732b644c0431c6140f370d097879203d79b80c94a6747ba0872adaef8" +checksum = "f1fc76eaeac4c9164506c466d4ffdd8ec9d0c5bf57ee97177c4d8eceb3a0e138" dependencies = [ "cipher", "cpubits", @@ -99,9 +99,7 @@ checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" [[package]] name = "analyssa" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dd57908e45c301501605d7a6cd7b6449f7d0020a95c57ec2c29180be8689c78d" +version = "0.2.0" dependencies = [ "boxcar", "dashmap", @@ -127,9 +125,9 @@ checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299" [[package]] name = "annotate-snippets" -version = "0.12.15" +version = "0.12.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92570a3f9c98e7e84df84b71d0965ac99b1871fcd75a3773a3bd1bad13f64cf7" +checksum = "f211a51805bc641f3ad5b7664c77d2547af685cc33b4cd8d31964027a46f13f1" dependencies = [ "anstyle", "memchr", @@ -276,9 +274,9 @@ checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" [[package]] name = "autocfg" -version = "1.5.0" +version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8" +checksum = "f2032f911046de80f0a198e0901378627c33f59ea0ac00e363d481118bd70a53" [[package]] name = "av-scenechange" @@ -316,18 +314,18 @@ dependencies = [ [[package]] name = "avif-serialize" -version = "0.8.8" +version = "0.8.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "375082f007bd67184fb9c0374614b29f9aaa604ec301635f72338bb65386a53d" +checksum = "e7178fe5f7d460b13895ebb9dcb28a3a6216d2df2574a0806cb51b555d297f38" dependencies = [ "arrayvec", ] [[package]] name = "aws-lc-rs" -version = "1.16.3" +version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ec6fb3fe69024a75fa7e1bfb48aa6cf59706a101658ea01bfd33b2b248a038f" +checksum = "5ec2f1fc3ec205783a5da9a7e6c1509cc69dedf09a1949e412c1e18469326d00" dependencies = [ "aws-lc-sys", "zeroize", @@ -335,9 +333,9 @@ dependencies = [ [[package]] name = "aws-lc-sys" -version = "0.40.0" +version = "0.41.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f50037ee5e1e41e7b8f9d161680a725bd1626cb6f8c7e901f91f942850852fe7" +checksum = "1a2f9779ce85b93ab6170dd940ad0169b5766ff848247aff13bb788b832fe3f4" dependencies = [ "cc", "cmake", @@ -401,9 +399,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.11.1" +version = "2.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4512299f36f043ab09a583e57bceb5a5aab7a73db1805848e8fef3c9e8c78b3" +checksum = "84d7ced0ae9557296835c32bf1b1e02b44c746701f898460fb000d7eaa84f00a" [[package]] name = "bitmaps" @@ -482,6 +480,15 @@ version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "36f64beae40a84da1b4b26ff2761a5b895c12adc41dc25aaee1c4f2bbfe97a6e" +[[package]] +name = "bs58" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf88ba1141d185c399bee5288d850d63b8369520c1eafc32a0430b5b6c287bf4" +dependencies = [ + "tinyvec", +] + [[package]] name = "bstr" version = "1.12.1" @@ -495,15 +502,15 @@ dependencies = [ [[package]] name = "built" -version = "0.8.0" +version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f4ad8f11f288f48ca24471bbd51ac257aaeaaa07adae295591266b792902ae64" +checksum = "5c0e531d93d39c34eef561e929e8a7f86d77a5af08aac4f6d6e39976c51858e9" [[package]] name = "bumpalo" -version = "3.20.2" +version = "3.20.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d20789868f4b01b2f2caec9f5c4e0213b41e3e5702a50157d699ae31ced2fcb" +checksum = "72f5acc6cb2ba439de613abc23857ec3d78374d8ed5ac84e9d11336e87da8649" [[package]] name = "bytemuck" @@ -664,18 +671,18 @@ dependencies = [ [[package]] name = "cbc" -version = "0.2.0" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98db6aeaef0eeef2c1e3ce9a27b739218825dae116076352ac3777076aa22225" +checksum = "ce2dc9ee5f88d11e0beb842c88b33c8a5cf0d1329c4b19494af42b07dbfe8896" dependencies = [ "cipher", ] [[package]] name = "cc" -version = "1.2.61" +version = "1.2.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d16d90359e986641506914ba71350897565610e87ce0ad9e6f28569db3dd5c6d" +checksum = "556e016178bb5662a08681bbe0f00f8e17631781a4dfc8c45e466e4b185ec27f" dependencies = [ "find-msvc-tools", "jobserver", @@ -752,11 +759,11 @@ dependencies = [ [[package]] name = "cipher" -version = "0.5.1" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e34d8227fe1ba289043aeb13792056ff80fd6de1a9f49137a5f499de8e8c78ea" +checksum = "e8cf2a2c93cd704877c0858356ed03480ff301ee950b43f1cbe4573b088bfa6c" dependencies = [ - "crypto-common 0.2.1", + "crypto-common 0.2.2", "inout", ] @@ -812,9 +819,9 @@ dependencies = [ [[package]] name = "cmov" -version = "0.5.3" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f88a43d011fc4a6876cb7344703e297c71dda42494fee094d5f7c76bf13f746" +checksum = "0c9ea0ac24bc397ab3c98583a3c9ba74fa56b09a4449bbe172b9b1ddb016027a" [[package]] name = "color_quant" @@ -851,9 +858,9 @@ dependencies = [ [[package]] name = "compact_str" -version = "0.9.0" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fdb1325a1cece981e8a296ab8f0f9b63ae357bd0784a9faaf548cc7b480707a" +checksum = "9dfdd1c2274d9aa354115b09dc9a901d6c5576818cdf70d14cae2bdb47df00ab" dependencies = [ "castaway", "cfg-if", @@ -1082,7 +1089,7 @@ version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d8b9f2e4c67f833b660cdb0a3523065869fb35570177239812ed4c905aeff87b" dependencies = [ - "bitflags 2.11.1", + "bitflags 2.12.1", "crossterm_winapi", "document-features", "parking_lot", @@ -1117,9 +1124,9 @@ dependencies = [ [[package]] name = "crypto-common" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77727bb15fa921304124b128af125e7e3b968275d1b108b379190264f4423710" +checksum = "ce6e4c961d6cd6c9a86db418387425e8bdeaf05b3c8bc1411e6dca4c252f1453" dependencies = [ "hybrid-array", ] @@ -1303,9 +1310,9 @@ dependencies = [ [[package]] name = "dashmap" -version = "6.1.0" +version = "6.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +checksum = "e6361d5c062261c78a176addb82d4c821ae42bed6089de0e12603cd25de2059c" dependencies = [ "cfg-if", "crossbeam-utils", @@ -1447,7 +1454,7 @@ checksum = "f1dd6dbb5841937940781866fa1281a1ff7bd3bf827091440879f9994983d5c2" dependencies = [ "block-buffer 0.12.0", "const-oid", - "crypto-common 0.2.1", + "crypto-common 0.2.2", "ctutils", ] @@ -1478,7 +1485,7 @@ version = "0.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e0e367e4e7da84520dedcac1901e4da967309406d1e51017ae1abfb97adbd38" dependencies = [ - "bitflags 2.11.1", + "bitflags 2.12.1", "block2", "libc", "objc2", @@ -1486,9 +1493,9 @@ dependencies = [ [[package]] name = "displaydoc" -version = "0.2.5" +version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" +checksum = "1ac70aa55017e108007fbaf5aa0f54b021c98f92ff8af59d42eda9da96e3dd4f" dependencies = [ "proc-macro2", "quote", @@ -1512,7 +1519,7 @@ dependencies = [ [[package]] name = "dotscope" -version = "0.7.0" +version = "0.8.0" dependencies = [ "aes", "analyssa", @@ -1557,7 +1564,7 @@ dependencies = [ [[package]] name = "dotscope-cli" -version = "0.7.0" +version = "0.8.0" dependencies = [ "anyhow", "clap", @@ -1631,9 +1638,9 @@ checksum = "b2972feb8dffe7bc8c5463b1dacda1b0dfbed3710e50f977d965429692d74cd8" [[package]] name = "either" -version = "1.15.0" +version = "1.16.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "48c757948c5ede0e46177b7add2e67155f70e33c07fea8284df6576da70b3719" +checksum = "91622ff5e7162018101f2fea40d6ebf4a78bbe5a49736a2020649edf9693679e" dependencies = [ "serde", ] @@ -2323,9 +2330,9 @@ dependencies = [ [[package]] name = "goblin" -version = "0.10.5" +version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "983a6aafb3b12d4c41ea78d39e189af4298ce747353945ff5105b54a056e5cd9" +checksum = "17582616a7718cca54cec18e534a76c7c4aec11a8b9a85695712f262fd15a4c8" dependencies = [ "log", "plain", @@ -2404,9 +2411,9 @@ dependencies = [ [[package]] name = "hashbrown" -version = "0.17.0" +version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f467dd6dccf739c208452f8014c75c18bb8301b050ad1cfb27153803edb0f51" +checksum = "ed5909b6e89a2db4456e54cd5f673791d7eca6732202bbf2a9cc504fe2f9b84a" [[package]] name = "heck" @@ -2504,9 +2511,9 @@ dependencies = [ [[package]] name = "http" -version = "1.4.0" +version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3ba2a386d7f85a81f119ad7498ebe444d2e22c2af0b86b069416ace48b3311a" +checksum = "8be7462df143984c4598a256ef469b251d7d7f9e271135073e78fc535414f3d0" dependencies = [ "bytes", "itoa", @@ -2543,18 +2550,18 @@ checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" [[package]] name = "hybrid-array" -version = "0.4.11" +version = "0.4.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "08d46837a0ed51fe95bd3b05de33cd64a1ee88fc797477ca48446872504507c5" +checksum = "9155a582abd142abc056962c29e3ce5ff2ad5469f4246b537ed42c5deba857da" dependencies = [ "typenum", ] [[package]] name = "hyper" -version = "1.9.0" +version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6299f016b246a94207e63da54dbe807655bf9e00044f73ded42c3ac5305fbcca" +checksum = "55281c53a1894c864990125767da440a4e630446785086f52523b20033b74498" dependencies = [ "atomic-waker", "bytes", @@ -2842,7 +2849,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d466e9454f08e4a911e14806c24e16fba1b4c121d1ea474396f396069cf949d9" dependencies = [ "equivalent", - "hashbrown 0.17.0", + "hashbrown 0.17.1", "serde", "serde_core", ] @@ -2914,16 +2921,6 @@ version = "2.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d98f6fed1fde3f8c21bc40a1abb88dd75e67924f9cffc3ef95607bad8017f8e2" -[[package]] -name = "iri-string" -version = "0.7.12" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "25e659a4bb38e810ebc252e53b5814ff908a8c58c2a9ce2fae1bbec24cbf4e20" -dependencies = [ - "memchr", - "serde", -] - [[package]] name = "is_terminal_polyfill" version = "1.70.2" @@ -2956,9 +2953,9 @@ checksum = "8f42a60cbdf9a97f5d2305f08a87dc4e09308d1276d28c869c684d7777685682" [[package]] name = "jiff" -version = "0.2.24" +version = "0.2.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f00b5dbd620d61dfdcb6007c9c1f6054ebd75319f163d886a9055cec1155073d" +checksum = "4603d3033e49e2b0e31229fcab20a5d40089c607d975cd9c80551dc69eed9102" dependencies = [ "jiff-static", "log", @@ -2969,9 +2966,9 @@ dependencies = [ [[package]] name = "jiff-static" -version = "0.2.24" +version = "0.2.28" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e000de030ff8022ea1da3f466fbb0f3a809f5e51ed31f6dd931c35181ad8e6d7" +checksum = "782d32378dddf207193ac91cefb848ad41abb58195c95168e1291227a0832b47" dependencies = [ "proc-macro2", "quote", @@ -3039,9 +3036,9 @@ dependencies = [ [[package]] name = "js-sys" -version = "0.3.97" +version = "0.3.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1840c94c045fbcf8ba2812c95db44499f7c64910a912551aaaa541decebcacf" +checksum = "142bc4740e452c1e57ade0cbc129f139c9093e354346f0872ef985f4f5cf5f11" dependencies = [ "cfg-if", "futures-util", @@ -3101,9 +3098,9 @@ checksum = "b6d2cec3eae94f9f509c767b45932f1ada8350c4bdb85af2fcab4a3c14807981" [[package]] name = "libredox" -version = "0.1.16" +version = "0.1.17" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e02f3bb43d335493c96bf3fd3a321600bf6bd07ed34bc64118e9293bdffea46c" +checksum = "f02ab6bace2054fb888a3c16f990117b579d14a3088e472d63c6011fa185c9d3" dependencies = [ "libc", ] @@ -3152,9 +3149,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.29" +version = "0.4.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +checksum = "113b30b4cd05f7c06868fdb2854f66a7b9fece9a48425351cd532e810d74024f" [[package]] name = "loop9" @@ -3288,9 +3285,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.8.0" +version = "2.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8ca58f447f06ed17d5fc4043ce1b10dd205e060fb3ce5b979b8ed8e59ff3f79" +checksum = "6b947ae49db0d222b1dbc6b113ce7248a3fc3a6ca21b696717bfc000ba4484d8" [[package]] name = "memmap2" @@ -3314,7 +3311,7 @@ version = "0.29.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7ecfd3296f8c56b7c1f6fbac3c71cefa9d78ce009850c45000015f206dc7fa21" dependencies = [ - "bitflags 2.11.1", + "bitflags 2.12.1", "block", "core-graphics-types", "foreign-types", @@ -3341,9 +3338,9 @@ dependencies = [ [[package]] name = "minijinja" -version = "2.19.0" +version = "2.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "805bfd7352166bae857ee569628b52bcd85a1cecf7810861ebceb1686b72b75d" +checksum = "2929e494b2280e1e18959bb2e121da03347ae896896fdfaceaab43c88a02803f" dependencies = [ "memo-map", "serde", @@ -3352,9 +3349,9 @@ dependencies = [ [[package]] name = "minijinja-contrib" -version = "2.19.0" +version = "2.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "45092d80391870622fcf3bd82f5d2af18f99533ea60debb4bc9db0c76f0e809a" +checksum = "99df5123c54391e2a228014c1dbbd85a3dab08a25e776c810526f2f47542b3de" dependencies = [ "minijinja", "serde", @@ -3390,9 +3387,9 @@ dependencies = [ [[package]] name = "mio" -version = "1.2.0" +version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "50b7e5b27aa02a74bac8c3f23f448f8d87ff11f92d3aac1a6ed369ee08cc56c1" +checksum = "02bd0af71c67b473010cbbc60715ee815645a4dc942899111f494b4b737d6fda" dependencies = [ "libc", "wasi", @@ -3416,7 +3413,7 @@ dependencies = [ "mistralrs-core", "mistralrs-macros", "rand 0.9.4", - "reqwest 0.13.3", + "reqwest 0.13.4", "schemars 1.2.1", "serde", "serde_json", @@ -3503,7 +3500,7 @@ dependencies = [ "rayon", "regex", "regex-automata", - "reqwest 0.13.3", + "reqwest 0.13.4", "rubato", "rust-mcp-schema", "rustc-hash 2.1.2", @@ -3532,7 +3529,7 @@ dependencies = [ "tracing", "tracing-subscriber", "urlencoding", - "uuid 1.23.1", + "uuid 1.23.2", "variantly", "vob", ] @@ -3559,7 +3556,7 @@ dependencies = [ "async-trait", "futures-util", "http", - "reqwest 0.13.3", + "reqwest 0.13.4", "rust-mcp-schema", "serde", "serde_json", @@ -3567,7 +3564,7 @@ dependencies = [ "tokio-tungstenite", "tracing", "utoipa", - "uuid 1.23.1", + "uuid 1.23.2", ] [[package]] @@ -3694,11 +3691,11 @@ dependencies = [ [[package]] name = "nix" -version = "0.31.2" +version = "0.31.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d6d0705320c1e6ba1d912b5e37cf18071b6c2e9b7fa8215a1e8a7651966f5d3" +checksum = "cf20d2fde8ff38632c426f1165ed7436270b44f199fc55284c38276f9db47c3d" dependencies = [ - "bitflags 2.11.1", + "bitflags 2.12.1", "cfg-if", "cfg_aliases", "libc", @@ -3798,9 +3795,9 @@ dependencies = [ [[package]] name = "num-conv" -version = "0.2.1" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c6673768db2d862beb9b39a78fdcb1a69439615d5794a1be50caa9bc92c81967" +checksum = "521739c6d2bac4aa25192232afe6841231376b2b26d4d9fae5ecf8ca5772e441" [[package]] name = "num-derive" @@ -3894,7 +3891,7 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2a180dd8642fa45cdb7dd721cd4c11b1cadd4929ce112ebd8b9f5803cc79d536" dependencies = [ - "bitflags 2.11.1", + "bitflags 2.12.1", "dispatch2", "objc2", ] @@ -3911,7 +3908,7 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3e0adef53c21f888deb4fa59fc59f7eb17404926ee8a6f59f5df0fd7f9f3272" dependencies = [ - "bitflags 2.11.1", + "bitflags 2.12.1", "block2", "libc", "objc2", @@ -3934,7 +3931,7 @@ version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a0125f776a10d00af4152d74616409f0d4a2053a6f57fa5b7d6aa2854ac04794" dependencies = [ - "bitflags 2.11.1", + "bitflags 2.12.1", "block2", "dispatch2", "objc2", @@ -3960,7 +3957,7 @@ version = "6.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0cc3cbf698f9438986c11a880c90a6d04b9de27575afd28bbf45b154b6c709e2" dependencies = [ - "bitflags 2.11.1", + "bitflags 2.12.1", "libc", "once_cell", "onig_sys", @@ -4226,7 +4223,7 @@ version = "0.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "60769b8b31b2a9f263dae2776c37b1b28ae246943cf719eb6946a1db05128a61" dependencies = [ - "bitflags 2.11.1", + "bitflags 2.12.1", "crc32fast", "fdeflate", "flate2", @@ -4398,9 +4395,9 @@ checksum = "a993555f31e5a609f617c12db6250dedcac1b0a85076912c436e6fc9b2c8e6a3" [[package]] name = "quick-xml" -version = "0.39.3" +version = "0.40.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "721da970c312655cde9b4ffe0547f20a8494866a4af5ff51f18b7c633d0c870b" +checksum = "2474bd2e5029e7ccb6abb2ba48cf2383a333851dedf495901544281590c7da7f" dependencies = [ "memchr", ] @@ -4645,7 +4642,7 @@ version = "11.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "498cd0dc59d73224351ee52a95fee0f1a617a2eae0e7d9d720cc622c73a54186" dependencies = [ - "bitflags 2.11.1", + "bitflags 2.12.1", ] [[package]] @@ -4712,7 +4709,7 @@ version = "0.5.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d" dependencies = [ - "bitflags 2.11.1", + "bitflags 2.12.1", ] [[package]] @@ -4823,9 +4820,9 @@ dependencies = [ [[package]] name = "reqwest" -version = "0.13.3" +version = "0.13.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "62e0021ea2c22aed41653bc7e1419abb2c97e038ff2c33d0e1309e49a97deec0" +checksum = "219c5811de6525e5416c7d5d53bb656d3afdbc6c5af816e0802bcfa42dbdc1c3" dependencies = [ "base64 0.22.1", "bytes", @@ -4958,7 +4955,7 @@ version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" dependencies = [ - "bitflags 2.11.1", + "bitflags 2.12.1", "errno", "libc", "linux-raw-sys", @@ -4983,9 +4980,9 @@ dependencies = [ [[package]] name = "rustls-native-certs" -version = "0.8.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "612460d5f7bea540c490b2b6395d8e34a953e52b491accd6c86c8164c5932a63" +checksum = "dab5152771c58876a2146916e53e35057e1a4dfa2b9df0f0305b07f611fdea4d" dependencies = [ "openssl-probe", "rustls-pki-types", @@ -5197,7 +5194,7 @@ version = "3.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b7f4bc775c73d9a02cde8bf7b2ec4c9d12743edf609006c7facc23998404cd1d" dependencies = [ - "bitflags 2.11.1", + "bitflags 2.12.1", "core-foundation 0.10.1", "core-foundation-sys", "libc", @@ -5220,7 +5217,7 @@ version = "0.33.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "feef350c36147532e1b79ea5c1f3791373e61cbd9a6a2615413b3807bb164fb7" dependencies = [ - "bitflags 2.11.1", + "bitflags 2.12.1", "cssparser", "derive_more", "log", @@ -5317,9 +5314,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.149" +version = "1.0.150" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83fc039473c5595ace860d8c4fafa220ff474b3fc6bfdb4293327f1a37e94d86" +checksum = "e8014e44b4736ed0538adeecded0fce2a272f22dc9578a7eb6b2d9993c74cfb9" dependencies = [ "indexmap 2.14.0", "itoa", @@ -5361,11 +5358,12 @@ dependencies = [ [[package]] name = "serde_with" -version = "3.19.0" +version = "3.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f05839ce67618e14a09b286535c0d9c94e85ef25469b0e13cb4f844e5593eb19" +checksum = "e72c1c2cb7b223fafb600a619537a871c2818583d619401b785e7c0b746ccde2" dependencies = [ "base64 0.22.1", + "bs58", "chrono", "hex", "indexmap 1.9.3", @@ -5380,9 +5378,9 @@ dependencies = [ [[package]] name = "serde_with_macros" -version = "3.19.0" +version = "3.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf2ebbe86054f9b45bc3881e865683ccfaccce97b9b4cb53f3039d67f355a334" +checksum = "b90c488738ecb4fb0262f41f43bc40efc5868d9fb744319ddf5f5317f417bfac" dependencies = [ "darling 0.23.0", "proc-macro2", @@ -5454,9 +5452,9 @@ dependencies = [ [[package]] name = "shlex" -version = "1.3.0" +version = "2.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +checksum = "f8fadd59c855ef2080decdef8ff161eb6661b86933c9d82e5ba29dc602a55aba" [[package]] name = "signal-hook" @@ -5559,9 +5557,9 @@ checksum = "ef784004ca8777809dcdad6ac37629f0a97caee4c685fcea805278d81dd8b857" [[package]] name = "socket2" -version = "0.6.3" +version = "0.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3a766e1110788c36f4fa1c2b71b387a7815aa65f88ce0229841826633d93723e" +checksum = "52d1cfed4120b4d927bf7c0f86d2087a4a7d6027c906d9f9d525a80573b9be51" dependencies = [ "libc", "windows-sys 0.61.2", @@ -5907,7 +5905,7 @@ version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "01198a2debb237c62b6826ec7081082d951f46dbb64b0e8c7649a452230d1dfc" dependencies = [ - "bitflags 2.11.1", + "bitflags 2.12.1", "byteorder", "enum-as-inner", "libc", @@ -5935,7 +5933,7 @@ version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a13f3d0daba03132c0aa9767f98351b3488edc2c100cda2d2ec2b04f3d8d3c8b" dependencies = [ - "bitflags 2.11.1", + "bitflags 2.12.1", "core-foundation 0.9.4", "system-configuration-sys", ] @@ -6191,13 +6189,13 @@ dependencies = [ [[package]] name = "tokio" -version = "1.52.2" +version = "1.52.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "110a78583f19d5cdb2c5ccf321d1290344e71313c6c37d43520d386027d18386" +checksum = "8fc7f01b389ac15039e4dc9531aa973a135d7a4135281b12d7c1bc79fd57fffe" dependencies = [ "bytes", "libc", - "mio 1.2.0", + "mio 1.2.1", "parking_lot", "pin-project-lite", "signal-hook-registry", @@ -6319,7 +6317,7 @@ version = "1.1.2+spec-1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2abe9b86193656635d2411dc43050282ca48aa31c2451210f4202550afb7526" dependencies = [ - "winnow 1.0.2", + "winnow 1.0.3", ] [[package]] @@ -6345,20 +6343,20 @@ dependencies = [ [[package]] name = "tower-http" -version = "0.6.8" +version = "0.6.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d4e6559d53cc268e5031cd8429d05415bc4cb4aefc4aa5d6cc35fbf5b924a1f8" +checksum = "4cfcf7e2740e6fc6d4d688b4ef00650406bb94adf4731e43c096c3a19fe40840" dependencies = [ - "bitflags 2.11.1", + "bitflags 2.12.1", "bytes", "futures-util", "http", "http-body", - "iri-string", "pin-project-lite", "tower", "tower-layer", "tower-service", + "url", ] [[package]] @@ -6486,9 +6484,9 @@ checksum = "8e28f89b80c87b8fb0cf04ab448d5dd0dd0ade2f8891bae878de66a75a28600e" [[package]] name = "typenum" -version = "1.20.0" +version = "1.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "40ce102ab67701b8526c123c1bab5cbe42d7040ccfd0f64af1a385808d2f43de" +checksum = "b6f5e870be6c3b371b77fe0ee0bafb859fa4964b4404c27de1d380043c4dda20" [[package]] name = "ug" @@ -6560,9 +6558,9 @@ dependencies = [ [[package]] name = "unicode-segmentation" -version = "1.13.2" +version = "1.13.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9629274872b2bfaf8d66f5f15725007f635594914870f65218920345aa11aa8c" +checksum = "c6f5d3c3b1bf09027a88a6bc961fc00497d651009560b5463668dc81b0fa87a8" [[package]] name = "unicode-width" @@ -6683,9 +6681,9 @@ dependencies = [ [[package]] name = "uuid" -version = "1.23.1" +version = "1.23.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddd74a9687298c6858e9b88ec8935ec45d22e8fd5e6394fa1bd4e99a87789c76" +checksum = "d258b83ceec21034727ecee8c382cfa6c3e133699b0742c64571814fb420c9f7" dependencies = [ "getrandom 0.4.2", "js-sys", @@ -6784,9 +6782,9 @@ dependencies = [ [[package]] name = "wasm-bindgen" -version = "0.2.120" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df52b6d9b87e0c74c9edfa1eb2d9bf85e5d63515474513aa50fa181b3c4f5db1" +checksum = "3ed04576f974d2b2fba0f38c51dbc5518011e38c36bf1143164be765528fd409" dependencies = [ "cfg-if", "once_cell", @@ -6797,9 +6795,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-futures" -version = "0.4.70" +version = "0.4.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af934872acec734c2d80e6617bbb5ff4f12b052dd8e6332b0817bce889516084" +checksum = "9473dbd2991ae90b6291c3c32c30c6187ac49aa32f9905d1cce280ec1e110b0f" dependencies = [ "js-sys", "wasm-bindgen", @@ -6807,9 +6805,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro" -version = "0.2.120" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78b1041f495fb322e64aca85f5756b2172e35cd459376e67f2a6c9dffcedb103" +checksum = "916151b09da36bd82f6615cbf3a419e2f0ba23a03c6160e8e92eb6bd4aa1dec6" dependencies = [ "quote", "wasm-bindgen-macro-support", @@ -6817,9 +6815,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-macro-support" -version = "0.2.120" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9dcd0ff20416988a18ac686d4d4d0f6aae9ebf08a389ff5d29012b05af2a1b41" +checksum = "299047362ccbfce148b67ab7e73349f77748e00c8296f9542adfad2ad82c5c5e" dependencies = [ "bumpalo", "proc-macro2", @@ -6830,9 +6828,9 @@ dependencies = [ [[package]] name = "wasm-bindgen-shared" -version = "0.2.120" +version = "0.2.122" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "49757b3c82ebf16c57d69365a142940b384176c24df52a087fb748e2085359ea" +checksum = "9a929b2c61f11ba3e9bc35b50c1f25cb38e0e892c0c231ae2b8cf78d5dad4437" dependencies = [ "unicode-ident", ] @@ -6891,7 +6889,7 @@ version = "0.244.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "47b807c72e1bac69382b3a6fb3dbe8ea4c0ed87ff5629b8685ae6b9a611028fe" dependencies = [ - "bitflags 2.11.1", + "bitflags 2.12.1", "hashbrown 0.15.5", "indexmap 2.14.0", "semver", @@ -6899,9 +6897,9 @@ dependencies = [ [[package]] name = "web-sys" -version = "0.3.97" +version = "0.3.99" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2eadbac71025cd7b0834f20d1fe8472e8495821b4e9801eb0a60bd1f19827602" +checksum = "6d621441cfc37b84979402712047321980c178f299193a3589d05b99e8763436" dependencies = [ "js-sys", "wasm-bindgen", @@ -7407,9 +7405,9 @@ checksum = "df79d97927682d2fd8adb29682d1140b343be4ac0f08fd68b7765d9c059d3945" [[package]] name = "winnow" -version = "1.0.2" +version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2ee1708bef14716a11bae175f579062d4554d95be2c6829f518df847b7b3fdd0" +checksum = "0592e1c9d151f854e6fd382574c3a0855250e1d9b2f99d9281c6e6391af352f1" [[package]] name = "wit-bindgen" @@ -7475,7 +7473,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9d66ea20e9553b30172b5e831994e35fbde2d165325bec84fc43dbf6f4eb9cb2" dependencies = [ "anyhow", - "bitflags 2.11.1", + "bitflags 2.12.1", "indexmap 2.14.0", "log", "serde", @@ -7591,18 +7589,18 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.8.48" +version = "0.8.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eed437bf9d6692032087e337407a86f04cd8d6a16a37199ed57949d415bd68e9" +checksum = "3b065d4f0e55f82fae73202e189638116a87c55ab6b8e6c2721e13dd9d854ad1" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.8.48" +version = "0.8.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "70e3cd084b1788766f53af483dd21f93881ff30d7320490ec3ef7526d203bad4" +checksum = "0b631b19d36a892ab55420c92dbc83ccd79274f25be714855d3074aa71cab639" dependencies = [ "proc-macro2", "quote", @@ -7611,9 +7609,9 @@ dependencies = [ [[package]] name = "zerofrom" -version = "0.1.7" +version = "0.1.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "69faa1f2a1ea75661980b013019ed6687ed0e83d069bc1114e2cc74c6c04c4df" +checksum = "0ec05a11813ea801ff6d75110ad09cd0824ddba17dfe17128ea0d5f68e6c5272" dependencies = [ "zerofrom-derive", ] diff --git a/dotscope-cli/Cargo.toml b/dotscope-cli/Cargo.toml index 07cc1e45..7fd92481 100644 --- a/dotscope-cli/Cargo.toml +++ b/dotscope-cli/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "dotscope-cli" -version = "0.7.0" +version = "0.8.0" authors = ["Johann Kempter "] edition.workspace = true license.workspace = true @@ -18,12 +18,12 @@ smart-rename = ["dotscope/smart-rename"] [dependencies] dotscope = { path = "../dotscope", features = ["deobfuscation"] } -clap = { version = "4.6.0", features = ["derive", "env", "wrap_help"] } +clap = { version = "4.6.1", features = ["derive", "env", "wrap_help"] } serde = { version = "1.0.228", features = ["derive"] } -serde_json = "1.0.149" +serde_json = "1.0.150" anyhow = "1.0.102" comfy-table = "7.2.2" widestring = "1.2.1" ctrlc = "3.5.2" env_logger = "0.11.10" -log = "0.4.29" +log = "0.4.31" diff --git a/dotscope-cli/src/commands/attrs.rs b/dotscope-cli/src/commands/attrs.rs index ab60888f..9ecfa63c 100644 --- a/dotscope-cli/src/commands/attrs.rs +++ b/dotscope-cli/src/commands/attrs.rs @@ -273,7 +273,7 @@ fn resolve_constructor_type( } TableId::MethodDef => { // Look up the method and get its declaring type - if let Some(method) = assembly.method(&constructor.token) { + if let Ok(method) = assembly.method(&constructor.token) { if let Some(name) = method.declaring_type_fullname() { return name; } diff --git a/dotscope-cli/src/commands/resolution.rs b/dotscope-cli/src/commands/resolution.rs index ae703cf6..4481e419 100644 --- a/dotscope-cli/src/commands/resolution.rs +++ b/dotscope-cli/src/commands/resolution.rs @@ -20,7 +20,7 @@ pub fn parse_token_filter(filter: &str) -> Option { pub fn resolve_methods(assembly: &CilObject, filter: &str) -> anyhow::Result>> { // Try token first if let Some(token) = parse_token_filter(filter) { - if let Some(method) = assembly.method(&token) { + if let Ok(method) = assembly.method(&token) { return Ok(vec![method]); } bail!("no method with token {filter} found"); diff --git a/dotscope/Cargo.toml b/dotscope/Cargo.toml index fac1f8bc..d5c1fb1f 100644 --- a/dotscope/Cargo.toml +++ b/dotscope/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "dotscope" -version = "0.7.0" +version = "0.8.0" authors = ["Johann Kempter "] edition.workspace = true description = "A high-performance, cross-platform framework for analyzing and reverse engineering .NET PE executables" @@ -47,7 +47,7 @@ strum = { version = "0.28.0", features = ["derive"]} cowfile = "0.2.2" memmap2 = "0.9.10" tempfile = "3.27.0" -goblin = { version = "0.10.5", default-features = false, features = ["pe32", "pe64", "std"] } +goblin = { version = "0.10.7", default-features = false, features = ["pe32", "pe64", "std"] } ouroboros = "0.18.5" sha1 = { version = "0.11.0", optional = true, features = ["oid"] } sha2 = { version = "0.11.0", features = ["oid"] } @@ -55,25 +55,25 @@ md-5 = { version = "0.11.0", optional = true } hmac = "0.13.0" imbl = { version = "7.0.0", optional = true } pbkdf2 = "0.13.0" -aes = "0.9.0" +aes = "0.9.1" des = { version = "0.9.0", optional = true } -cbc = "0.2.0" +cbc = "0.2.1" ecb = "0.2.0" -dashmap = "6.1.0" +dashmap = "6.2.1" crossbeam-skiplist = "0.1.3" rayon = "1.12.0" rustc-hash = "2.1.2" boxcar = "0.2.14" -quick-xml = "0.39.2" +quick-xml = "0.40.1" hex = "0.4.3" num-bigint = { version = "0.4.6", optional = true } -log = "0.4.29" +log = "0.4.31" flate2 = "1.1.9" -analyssa = "0.1.0" +analyssa = { path = "../../analyssa" } lzma-rs = "0.3.0" z3 = { version = "0.20.0", optional = true } iced-x86 = { version = "1.21.0", default-features = false, features = ["std", "decoder", "instr_info"], optional = true } -tokio = { version = "1.52.1", optional = true, features = ["rt-multi-thread"] } +tokio = { version = "1.52.3", optional = true, features = ["rt-multi-thread"] } # Metal GPU acceleration for LLM inference on macOS (Apple Silicon / AMD GPU). [target.'cfg(target_os = "macos")'.dependencies] diff --git a/dotscope/fuzz/Cargo.lock b/dotscope/fuzz/Cargo.lock index 9d6c2458..c334cd02 100644 --- a/dotscope/fuzz/Cargo.lock +++ b/dotscope/fuzz/Cargo.lock @@ -10,12 +10,12 @@ checksum = "320119579fcad9c21884f5c4861d16174d0e06250625266f50fe6898340abefa" [[package]] name = "aes" -version = "0.8.4" +version = "0.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +checksum = "f1fc76eaeac4c9164506c466d4ffdd8ec9d0c5bf57ee97177c4d8eceb3a0e138" dependencies = [ - "cfg-if", "cipher", + "cpubits", "cpufeatures", ] @@ -25,6 +25,17 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "250f629c0161ad8107cf89319e990051fae62832fd343083bea452d93e2205fd" +[[package]] +name = "analyssa" +version = "0.2.0" +dependencies = [ + "boxcar", + "dashmap", + "log", + "rayon", + "thiserror", +] + [[package]] name = "anyhow" version = "1.0.101" @@ -63,20 +74,20 @@ checksum = "a1d084b0137aaa901caf9f1e8b21daa6aa24d41cd806e111335541eff9683bd6" [[package]] name = "block-buffer" -version = "0.10.4" +version = "0.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" +checksum = "cdd35008169921d80bc60d3d0ab416eecb028c4cd653352907921d95084790be" dependencies = [ - "generic-array", + "hybrid-array", ] [[package]] name = "block-padding" -version = "0.3.3" +version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a8894febbff9f758034a5b8e12d87918f56dfc64a8e1fe757d65e29041538d93" +checksum = "710f1dd022ef4e93f8a438b4ba958de7f64308434fa6a87104481645cc30068b" dependencies = [ - "generic-array", + "hybrid-array", ] [[package]] @@ -99,9 +110,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "cbc" -version = "0.1.2" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26b52a9543ae338f279b96b0b9fed9c8093744685043739079ce85cd58f289a6" +checksum = "ce2dc9ee5f88d11e0beb842c88b33c8a5cf0d1329c4b19494af42b07dbfe8896" dependencies = [ "cipher", ] @@ -119,25 +130,53 @@ dependencies = [ [[package]] name = "cfg-if" -version = "1.0.0" +version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" [[package]] name = "cipher" -version = "0.4.4" +version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +checksum = "e8cf2a2c93cd704877c0858356ed03480ff301ee950b43f1cbe4573b088bfa6c" dependencies = [ "crypto-common", "inout", ] +[[package]] +name = "cmov" +version = "0.5.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0c9ea0ac24bc397ab3c98583a3c9ba74fa56b09a4449bbe172b9b1ddb016027a" + +[[package]] +name = "const-oid" +version = "0.10.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a6ef517f0926dd24a1582492c791b6a4818a4d94e789a334894aa15b0d12f55c" + +[[package]] +name = "cowfile" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6b0c9a99dda5d60063c8daece5d8c7dc1a7b1cd8a3695fb0f4be1df2193ed138" +dependencies = [ + "memmap2", + "thiserror", +] + +[[package]] +name = "cpubits" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "15b85f9c39137c3a891689859392b1bd49812121d0d61c9caf00d46ed5ce06ae" + [[package]] name = "cpufeatures" -version = "0.2.17" +version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "59ed5838eebb26a2bb2e58f6d5b5316989ae9d08bab10e0e6d103e656d1b0280" +checksum = "8b2a41393f66f16b0823bb79094d54ac5fbd34ab292ddafb9a0456ac9f87d201" dependencies = [ "libc", ] @@ -203,19 +242,27 @@ checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" [[package]] name = "crypto-common" -version = "0.1.6" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" +checksum = "ce6e4c961d6cd6c9a86db418387425e8bdeaf05b3c8bc1411e6dca4c252f1453" dependencies = [ - "generic-array", - "typenum", + "hybrid-array", +] + +[[package]] +name = "ctutils" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d5515a3834141de9eafb9717ad39eea8247b5674e6066c404e8c4b365d2a29e" +dependencies = [ + "cmov", ] [[package]] name = "dashmap" -version = "6.1.0" +version = "6.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5041cc499144891f3790297212f32a74fb938e5136a14943f338ef9e0ae276cf" +checksum = "e6361d5c062261c78a176addb82d4c821ae42bed6089de0e12603cd25de2059c" dependencies = [ "cfg-if", "crossbeam-utils", @@ -227,35 +274,38 @@ dependencies = [ [[package]] name = "des" -version = "0.8.1" +version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ffdd80ce8ce993de27e9f063a444a4d53ce8e8db4c1f00cc03af5ad5a9867a1e" +checksum = "916a94e407b54f9034d71dd748234cd1e516ced6284009906ae246f177eafe5a" dependencies = [ "cipher", ] [[package]] name = "digest" -version = "0.10.7" +version = "0.11.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" +checksum = "f1dd6dbb5841937940781866fa1281a1ff7bd3bf827091440879f9994983d5c2" dependencies = [ "block-buffer", + "const-oid", "crypto-common", - "subtle", + "ctutils", ] [[package]] name = "dotscope" -version = "0.6.0" +version = "0.8.0" dependencies = [ "aes", - "bitflags", + "analyssa", "boxcar", "cbc", + "cowfile", "crossbeam-skiplist", "dashmap", "des", + "ecb", "flate2", "goblin", "hex", @@ -266,7 +316,7 @@ dependencies = [ "lzma-rs", "md-5", "memmap2", - "num_cpus", + "num-bigint", "ouroboros", "pbkdf2", "quick-xml", @@ -289,6 +339,15 @@ dependencies = [ "libfuzzer-sys", ] +[[package]] +name = "ecb" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbfbf3db731928d6912bc3beda911d55b834cee3df6131ba79b337a1298a3fa9" +dependencies = [ + "cipher", +] + [[package]] name = "either" version = "1.15.0" @@ -333,16 +392,6 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9c4f5dac5e15c24eb999c26181a6ca40b39fe946cbe4c263c7209467bc83af2" -[[package]] -name = "generic-array" -version = "0.14.7" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" -dependencies = [ - "typenum", - "version_check", -] - [[package]] name = "getrandom" version = "0.4.1" @@ -358,9 +407,9 @@ dependencies = [ [[package]] name = "goblin" -version = "0.10.5" +version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "983a6aafb3b12d4c41ea78d39e189af4298ce747353945ff5105b54a056e5cd9" +checksum = "17582616a7718cca54cec18e534a76c7c4aec11a8b9a85695712f262fd15a4c8" dependencies = [ "log", "plain", @@ -400,12 +449,6 @@ version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" -[[package]] -name = "hermit-abi" -version = "0.5.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fc0fef456e4baa96da950455cd02c081ca953b141298e41db3fc7e36b1da849c" - [[package]] name = "hex" version = "0.4.3" @@ -414,13 +457,22 @@ checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70" [[package]] name = "hmac" -version = "0.12.1" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" +checksum = "6303bc9732ae41b04cb554b844a762b4115a61bfaa81e3e83050991eeb56863f" dependencies = [ "digest", ] +[[package]] +name = "hybrid-array" +version = "0.4.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9155a582abd142abc056962c29e3ce5ff2ad5469f4246b537ed42c5deba857da" +dependencies = [ + "typenum", +] + [[package]] name = "iced-x86" version = "1.21.0" @@ -474,12 +526,12 @@ dependencies = [ [[package]] name = "inout" -version = "0.1.4" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "879f10e63c20629ecabbb64a8010319738c66a5cd0c29b02d63d272b03751d01" +checksum = "4250ce6452e92010fdf7268ccc5d14faa80bb12fc741938534c58f16804e03c7" dependencies = [ "block-padding", - "generic-array", + "hybrid-array", ] [[package]] @@ -511,9 +563,9 @@ checksum = "09edd9e8b54e49e587e4f6295a7d29c3ea94d469cb40ab8ca70b288248a81db2" [[package]] name = "libc" -version = "0.2.181" +version = "0.2.186" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "459427e2af2b9c839b132acb702a1c654d95e10f8c326bfc2ad11310e458b1c5" +checksum = "68ab91017fe16c622486840e4c83c9a37afeff978bd239b5293d61ece587de66" [[package]] name = "libfuzzer-sys" @@ -527,25 +579,24 @@ dependencies = [ [[package]] name = "linux-raw-sys" -version = "0.11.0" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df1d3c3b53da64cf5760482273a98e575c651a67eec7f77df96b5b642de8f039" +checksum = "32a66949e030da00e8c7d4434b251670a91556f4144941d37452769c25d58a53" [[package]] name = "lock_api" -version = "0.4.12" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07af8b9cdd281b7915f413fa73f29ebd5d55d0d3f0155584dade1ff18cea1b17" +checksum = "224399e74b87b5f3557511d98dff8b14089b3dadafcab6bb93eab67d3aace965" dependencies = [ - "autocfg", "scopeguard", ] [[package]] name = "log" -version = "0.4.29" +version = "0.4.31" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e5032e24019045c762d3c0f28f5b6b8bbf38563a65908389bf7978758920897" +checksum = "113b30b4cd05f7c06868fdb2854f66a7b9fece9a48425351cd532e810d74024f" [[package]] name = "lzma-rs" @@ -559,9 +610,9 @@ dependencies = [ [[package]] name = "md-5" -version = "0.10.6" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf" +checksum = "69b6441f590336821bb897fb28fc622898ccceb1d6cea3fde5ea86b090c4de98" dependencies = [ "cfg-if", "digest", @@ -575,9 +626,9 @@ checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "memmap2" -version = "0.9.9" +version = "0.9.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "744133e4a0e0a658e1374cf3bf8e415c4052a15a111acd372764c55b4177d490" +checksum = "714098028fe011992e1c3962653c96b2d578c4b4bce9036e15ff220319b1e0e3" dependencies = [ "libc", ] @@ -593,20 +644,38 @@ dependencies = [ ] [[package]] -name = "num_cpus" -version = "1.17.0" +name = "num-bigint" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "91df4bbde75afed763b708b7eee1e8e7651e02d97f6d5dd763e89367e957b23b" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" dependencies = [ - "hermit-abi", - "libc", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-integer" +version = "0.1.46" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" +dependencies = [ + "num-traits", +] + +[[package]] +name = "num-traits" +version = "0.2.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" +dependencies = [ + "autocfg", ] [[package]] name = "once_cell" -version = "1.21.3" +version = "1.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" [[package]] name = "ouroboros" @@ -634,22 +703,22 @@ dependencies = [ [[package]] name = "parking_lot_core" -version = "0.9.10" +version = "0.9.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" +checksum = "2621685985a2ebf1c516881c026032ac7deafcda1a2c9b7850dc81e3dfcb64c1" dependencies = [ "cfg-if", "libc", "redox_syscall", "smallvec", - "windows-targets", + "windows-link", ] [[package]] name = "pbkdf2" -version = "0.12.2" +version = "0.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f8ed6a7761f76e3b9f92dfb0a60a6a6477c61024b775147ff0973a02653abaf2" +checksum = "112d82ceb8c5bf524d9af484d4e4970c9fd5a0cc15ba14ad93dccd28873b0629" dependencies = [ "digest", "hmac", @@ -695,9 +764,9 @@ dependencies = [ [[package]] name = "quick-xml" -version = "0.39.0" +version = "0.40.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f2e3bf4aa9d243beeb01a7b3bc30b77cfe2c44e24ec02d751a7104a53c2c49a1" +checksum = "2474bd2e5029e7ccb6abb2ba48cf2383a333851dedf495901544281590c7da7f" dependencies = [ "memchr", ] @@ -734,9 +803,9 @@ dependencies = [ [[package]] name = "rayon" -version = "1.11.0" +version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "368f01d005bf8fd9b1206fb6fa653e6c4a81ceb1466406b81792d87c5677a58f" +checksum = "fb39b166781f92d482534ef4b4b1b2568f42613b53e5b6c160e24cfbfa30926d" dependencies = [ "either", "rayon-core", @@ -763,15 +832,15 @@ dependencies = [ [[package]] name = "rustc-hash" -version = "2.1.1" +version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "357703d41365b4b27c590e3ed91eabb1b663f07c4c084095e60cbed4362dff0d" +checksum = "94300abf3f1ae2e2b8ffb7b58043de3d399c73fa6f4b73826402a5c457614dbe" [[package]] name = "rustix" -version = "1.1.3" +version = "1.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "146c9e247ccc180c1f61615433868c99f3de3ae256a30a43b49f67c2d9171f34" +checksum = "b6fe4565b9518b83ef4f91bb47ce29620ca828bd32cb7e408f0062e9930ba190" dependencies = [ "bitflags", "errno", @@ -780,12 +849,6 @@ dependencies = [ "windows-sys", ] -[[package]] -name = "rustversion" -version = "1.0.20" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eded382c5f5f786b989652c49544c4877d9f015cc22e145a5ea8ea66c2921cd2" - [[package]] name = "safe_arch" version = "0.7.4" @@ -871,9 +934,9 @@ dependencies = [ [[package]] name = "sha1" -version = "0.10.6" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +checksum = "aacc4cc499359472b4abe1bf11d0b12e688af9a805fa5e3016f9a386dc2d0214" dependencies = [ "cfg-if", "cpufeatures", @@ -882,9 +945,9 @@ dependencies = [ [[package]] name = "sha2" -version = "0.10.9" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7507d819769d01a365ab707794a4084392c824f54a7a6a7862f8c3d0892b283" +checksum = "446ba717509524cb3f22f17ecc096f10f4822d76ab5c0b9822c5f9c284e825f4" dependencies = [ "cfg-if", "cpufeatures", @@ -917,32 +980,25 @@ checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" [[package]] name = "strum" -version = "0.27.2" +version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "af23d6f6c1a224baef9d3f61e287d2761385a5b88fdab4eb4c6f11aeb54c4bcf" +checksum = "9628de9b8791db39ceda2b119bbe13134770b56c138ec1d3af810d045c04f9bd" dependencies = [ "strum_macros", ] [[package]] name = "strum_macros" -version = "0.27.1" +version = "0.28.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c77a8c5abcaf0f9ce05d62342b7d298c346515365c36b673df4ebe3ced01fde8" +checksum = "ab85eea0270ee17587ed4156089e10b9e6880ee688791d45a905f5b1ca36f664" dependencies = [ "heck 0.5.0", "proc-macro2", "quote", - "rustversion", "syn", ] -[[package]] -name = "subtle" -version = "2.6.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" - [[package]] name = "syn" version = "2.0.115" @@ -956,9 +1012,9 @@ dependencies = [ [[package]] name = "tempfile" -version = "3.25.0" +version = "3.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0136791f7c95b1f6dd99f9cc786b91bb81c3800b639b3478e561ddb7be95e5f1" +checksum = "32497e9a4c7b38532efcdebeef879707aa9f794296a4f0244f6f69e9bc8574bd" dependencies = [ "fastrand", "getrandom", @@ -989,9 +1045,9 @@ dependencies = [ [[package]] name = "typenum" -version = "1.18.0" +version = "1.20.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1dccffe3ce07af9386bfd29e80c0ab1a8205a2fc34e4bcd40364df902cfa8f3f" +checksum = "b6f5e870be6c3b371b77fe0ee0bafb859fa4964b4404c27de1d380043c4dda20" [[package]] name = "uguid" @@ -1100,70 +1156,6 @@ dependencies = [ "windows-link", ] -[[package]] -name = "windows-targets" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" -dependencies = [ - "windows_aarch64_gnullvm", - "windows_aarch64_msvc", - "windows_i686_gnu", - "windows_i686_gnullvm", - "windows_i686_msvc", - "windows_x86_64_gnu", - "windows_x86_64_gnullvm", - "windows_x86_64_msvc", -] - -[[package]] -name = "windows_aarch64_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" - -[[package]] -name = "windows_aarch64_msvc" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" - -[[package]] -name = "windows_i686_gnu" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" - -[[package]] -name = "windows_i686_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" - -[[package]] -name = "windows_i686_msvc" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" - -[[package]] -name = "windows_x86_64_gnu" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" - -[[package]] -name = "windows_x86_64_gnullvm" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" - -[[package]] -name = "windows_x86_64_msvc" -version = "0.52.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" - [[package]] name = "wit-bindgen" version = "0.51.0" diff --git a/dotscope/src/analysis/ssa/converter.rs b/dotscope/src/analysis/ssa/converter.rs index 8d6b4fe7..f2eb5093 100644 --- a/dotscope/src/analysis/ssa/converter.rs +++ b/dotscope/src/analysis/ssa/converter.rs @@ -41,9 +41,9 @@ use crate::{ cfg::ControlFlowGraph, ssa::{ decompose::decompose_instruction, liveness, place_pruned_phis, - resolve_corelib_valuetype, ConstValue, DefSite, PhiNode, SimulationResult, SsaBlock, - SsaFunction, SsaInstruction, SsaOp, SsaType, SsaVarId, StackSimulator, StackSlot, - StackSlotSource, TypeProvider, UseSite, VariableOrigin, + resolve_corelib_valuetype, ConstValue, DefSite, PhiNode, PhiPlacementConfig, + SimulationResult, SsaBlock, SsaFunction, SsaInstruction, SsaOp, SsaType, SsaVarId, + StackSimulator, StackSlot, StackSlotSource, TypeProvider, UseSite, VariableOrigin, }, }, assembly::{opcodes, Immediate, Instruction, Operand}, @@ -293,6 +293,8 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { ConstValue::F64(_) => SsaType::F64, ConstValue::String(_) | ConstValue::DecryptedString(_) => SsaType::String, ConstValue::DecryptedArray { .. } => SsaType::Object, + // SIMD vector constant from the native substrate — no CIL type. + ConstValue::Vector(_) => SsaType::Unknown, ConstValue::Null => SsaType::Null, ConstValue::True | ConstValue::False => SsaType::Bool, // Runtime handle types (`ldtoken` results). Each handle @@ -322,8 +324,15 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { }), }, - // Comparison results are always bool (represented as I32 on stack) - SsaOp::Ceq { .. } | SsaOp::Clt { .. } | SsaOp::Cgt { .. } => SsaType::Bool, + // Comparison results are always bool (represented as I32 on stack). + // Native boolean ops from the analyssa substrate likewise yield bool. + SsaOp::Ceq { .. } + | SsaOp::Clt { .. } + | SsaOp::Cgt { .. } + | SsaOp::BoolAnd { .. } + | SsaOp::BoolOr { .. } + | SsaOp::BoolXor { .. } + | SsaOp::BoolNot { .. } => SsaType::Bool, // Conversion - use the target type SsaOp::Conv { target, .. } => target.clone(), @@ -516,7 +525,59 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { | SsaOp::Fence { .. } | SsaOp::InterruptReturn | SsaOp::Unreachable - | SsaOp::Readonly => SsaType::Unknown, + | SsaOp::Readonly + // Native SSA substrate operations (wide arithmetic, SIMD/vector, + // native atomics, bitcast, opaque, indirect branch). These never + // appear in CIL-lifted SSA — CIL type inference does not model + // them — so they resolve to Unknown. Enumerated explicitly (no + // wildcard) so future substrate additions still trip this + // exhaustiveness check. + | SsaOp::WideMul { .. } + | SsaOp::WideDiv { .. } + | SsaOp::FloatCompareFlags { .. } + | SsaOp::Bitcast { .. } + | SsaOp::IndirectBranch { .. } + | SsaOp::NativeOpaque(_) + | SsaOp::VectorUnary { .. } + | SsaOp::VectorBinary { .. } + | SsaOp::VectorTernary { .. } + | SsaOp::VectorPredicatedUnary { .. } + | SsaOp::VectorPredicatedBinary { .. } + | SsaOp::VectorPredicatedTernary { .. } + | SsaOp::VectorCompare { .. } + | SsaOp::VectorLoad { .. } + | SsaOp::VectorStore { .. } + | SsaOp::VectorMaskedLoad { .. } + | SsaOp::VectorMaskedStore { .. } + | SsaOp::VectorBroadcastLoad { .. } + | SsaOp::VectorGather { .. } + | SsaOp::VectorFaultingLoad { .. } + | SsaOp::VectorSegmentLoad { .. } + | SsaOp::VectorScatter { .. } + | SsaOp::VectorSegmentStore { .. } + | SsaOp::VectorExtract { .. } + | SsaOp::VectorInsert { .. } + | SsaOp::VectorSplat { .. } + | SsaOp::VectorShuffle { .. } + | SsaOp::VectorCast { .. } + | SsaOp::VectorReinterpret { .. } + | SsaOp::VectorPack { .. } + | SsaOp::VectorPackLoad { .. } + | SsaOp::VectorPackStore { .. } + | SsaOp::VectorZeroUpper { .. } + | SsaOp::VectorMaskUnary { .. } + | SsaOp::VectorMaskBinary { .. } + | SsaOp::VectorReduce { .. } + | SsaOp::VectorBitmask { .. } + | SsaOp::AtomicLoad { .. } + | SsaOp::AtomicStore { .. } + | SsaOp::AtomicStoreConditional { .. } + | SsaOp::AtomicPairLoad { .. } + | SsaOp::AtomicPairStoreConditional { .. } + | SsaOp::AtomicExchange { .. } + | SsaOp::AtomicLockRmw { .. } + | SsaOp::AtomicCmpXchg { .. } + | SsaOp::AtomicPairCmpXchg { .. } => SsaType::Unknown, } } @@ -1357,7 +1418,7 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { match token.table() { // MethodDef (0x06) - method defined in this assembly 0x06 => { - let method = assembly.method(&token)?; + let method = assembly.method(&token).ok()?; let param_count = method.signature.params.len(); let has_this = !method.is_static(); let has_return = !matches!(method.signature.return_type.base, TypeSignature::Void); @@ -1377,7 +1438,7 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { // MethodSpec (0x2B) - generic method instantiation 0x2B => { - let method_spec = assembly.method_spec(&token)?; + let method_spec = assembly.method_spec(&token).ok()?; // Get the underlying method token from the CilTypeReference let underlying_token = match &method_spec.method { CilTypeReference::MethodDef(method_ref) => { @@ -1531,23 +1592,25 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { let num_locals = self.num_locals; let _ = place_pruned_phis( self.function.blocks_mut(), - &self.defs, - &live_in, - dominance_frontiers, - None, // All blocks reachable during initial construction - &|_| true, // Process all groups - &|group| { - group_origins.get(&group).copied().unwrap_or_else(|| { - if (group as usize) < num_args { - VariableOrigin::Argument(group as u16) - } else if (group as usize) < num_args.saturating_add(num_locals) { - VariableOrigin::Local((group as usize).saturating_sub(num_args) as u16) - } else { - VariableOrigin::Phi - } - }) + PhiPlacementConfig { + defs: &self.defs, + live_in: &live_in, + dominance_frontiers, + reachable: None, // All blocks reachable during initial construction + group_filter: &|_| true, // Process all groups + group_to_origin: &|group| { + group_origins.get(&group).copied().unwrap_or_else(|| { + if (group as usize) < num_args { + VariableOrigin::Argument(group as u16) + } else if (group as usize) < num_args.saturating_add(num_locals) { + VariableOrigin::Local((group as usize).saturating_sub(num_args) as u16) + } else { + VariableOrigin::Phi + } + }) + }, + leave_target_fn: None, // No Leave target handling during initial construction }, - None, // No Leave target handling during initial construction ); // Second, place PHI nodes for stack positions at merge points @@ -2620,9 +2683,56 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { if let Some(block) = self.function.block_mut(block_idx) { if let Some(instr) = block.instruction_mut(instr_idx) { let op = instr.op_mut(); - for (old_use, &new_use) in uses.iter().zip(renamed_uses.iter()) { + // Apply the placeholder→final renaming atomically. + // + // `replace_uses(old, new)` rewrites EVERY operand + // equal to `old`, so a naive sequential loop + // miscompiles whenever a freshly-written value + // aliases a not-yet-processed operand's old value + // (e.g. uses=[a,b] renamed=[b,c] collapses both + // operands to `c`). SSA var-id numbering can make a + // final id equal a sibling operand's placeholder id, + // so this is not hypothetical. + // + // Move each changing operand to a unique temporary id + // first, then from the temporary to its final value. + // This is position-wise exact and handles aliasing, + // cycles (swaps), and repeated operands. + // + // Temps only have to be distinct from this op's own + // operands (and each other) — `replace_uses` touches + // no other instruction — so the lowest unused ids + // suffice. They must NOT be derived from `max id + 1`: + // operand ids include the reserved placeholder + // (`u32::MAX - 1`), which would overflow the index + // space and reintroduce collisions. + let mut temps: Vec = Vec::with_capacity(uses.len()); + let mut candidate = 0usize; + for _ in 0..uses.len() { + loop { + let id = SsaVarId::from_index(candidate); + candidate = candidate.saturating_add(1); + if !uses.contains(&id) + && !renamed_uses.contains(&id) + && !temps.contains(&id) + { + temps.push(id); + break; + } + } + } + for (&temp, (old_use, &new_use)) in + temps.iter().zip(uses.iter().zip(renamed_uses.iter())) + { + if *old_use != new_use { + op.replace_uses(*old_use, temp); + } + } + for (&temp, (old_use, &new_use)) in + temps.iter().zip(uses.iter().zip(renamed_uses.iter())) + { if *old_use != new_use { - op.replace_uses(*old_use, new_use); + op.replace_uses(temp, new_use); } } } diff --git a/dotscope/src/analysis/ssa/mod.rs b/dotscope/src/analysis/ssa/mod.rs index c58100eb..aa74faca 100644 --- a/dotscope/src/analysis/ssa/mod.rs +++ b/dotscope/src/analysis/ssa/mod.rs @@ -128,7 +128,7 @@ pub use analyssa::Target; #[allow(unused_imports)] pub use analyssa::analysis::consts::evaluate_const_op; pub use analyssa::analysis::evaluator::ControlFlow; -pub use analyssa::analysis::phis::{place_pruned_phis, PhiAnalyzer}; +pub use analyssa::analysis::phis::{place_pruned_phis, PhiAnalyzer, PhiPlacementConfig}; /// CIL-defaulted alias of [`analyssa::ir::block::SsaBlock`]. pub type SsaBlock = analyssa::ir::block::SsaBlock; diff --git a/dotscope/src/analysis/ssa/target.rs b/dotscope/src/analysis/ssa/target.rs index 7b99931e..60bf7cfc 100644 --- a/dotscope/src/analysis/ssa/target.rs +++ b/dotscope/src/analysis/ssa/target.rs @@ -170,6 +170,8 @@ impl Target for CilTarget { ConstValue::F64(_) => SsaType::F64, ConstValue::String(_) | ConstValue::DecryptedString(_) => SsaType::String, ConstValue::DecryptedArray { .. } => SsaType::Object, + // SIMD vector constant from the native substrate — no CIL type. + ConstValue::Vector(_) => return None, ConstValue::Null => SsaType::Null, ConstValue::True | ConstValue::False => SsaType::Bool, ConstValue::Type(_) | ConstValue::MethodHandle(_) | ConstValue::FieldHandle(_) => { diff --git a/dotscope/src/analysis/ssa/value.rs b/dotscope/src/analysis/ssa/value.rs index 173c6a05..6513f33f 100644 --- a/dotscope/src/analysis/ssa/value.rs +++ b/dotscope/src/analysis/ssa/value.rs @@ -68,13 +68,14 @@ impl ConstValueCilExt for AnalyssaConstValue { AnalyssaConstValue::Null | AnalyssaConstValue::String(_) | AnalyssaConstValue::DecryptedString(_) + | AnalyssaConstValue::Vector(_) | AnalyssaConstValue::DecryptedArray { .. } => SsaType::Object, } } fn as_string_content(&self, assembly: &CilObject) -> Option { match self { - AnalyssaConstValue::DecryptedString(s) => Some(s.clone()), + AnalyssaConstValue::DecryptedString(s) => Some(s.to_string()), AnalyssaConstValue::String(idx) => assembly .userstrings() .and_then(|us| us.get(*idx as usize).ok()) @@ -116,6 +117,7 @@ impl TryFrom<&AnalyssaConstValue> for Immediate { AnalyssaConstValue::String(_) | AnalyssaConstValue::DecryptedString(_) + | AnalyssaConstValue::Vector(_) | AnalyssaConstValue::DecryptedArray { .. } | AnalyssaConstValue::Null | AnalyssaConstValue::Type(_) diff --git a/dotscope/src/assembly/instruction.rs b/dotscope/src/assembly/instruction.rs index 901c6e75..58303bb3 100644 --- a/dotscope/src/assembly/instruction.rs +++ b/dotscope/src/assembly/instruction.rs @@ -134,6 +134,38 @@ impl OperandType { OperandType::Switch => None, // Variable size: 4 + (count * 4) } } + + /// Returns a stable `&'static str` identifier for this operand type. + /// + /// The strings are part of the stable public API and safe to persist + /// (file, database, log line). They use lowercase ECMA-335-style names + /// (`"none"`, `"int8"`, `"uint8"`, `"int16"`, `"uint16"`, `"int32"`, + /// `"uint32"`, `"int64"`, `"uint64"`, `"float32"`, `"float64"`, + /// `"token"`, `"switch"`). + #[must_use] + pub const fn as_str(&self) -> &'static str { + match self { + OperandType::None => "none", + OperandType::Int8 => "int8", + OperandType::UInt8 => "uint8", + OperandType::Int16 => "int16", + OperandType::UInt16 => "uint16", + OperandType::Int32 => "int32", + OperandType::UInt32 => "uint32", + OperandType::Int64 => "int64", + OperandType::UInt64 => "uint64", + OperandType::Float32 => "float32", + OperandType::Float64 => "float64", + OperandType::Token => "token", + OperandType::Switch => "switch", + } + } +} + +impl Display for OperandType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } } /// Represents an immediate value type embedded in CIL instructions. @@ -322,6 +354,36 @@ impl Operand { } } +impl Display for Operand { + /// Stable, human-readable rendering of the operand. + /// + /// Mirrors [`Operand::as_string`], but returns the string `"none"` for + /// [`Operand::None`] instead of nothing so the [`Display`] impl is total. + /// The format of each variant is part of the stable public API and safe + /// to persist or parse: + /// + /// - `Operand::None` → `"none"` + /// - `Operand::Immediate(v)` → `Debug` of the immediate (`"Int32(42)"`, `"UInt8(7)"`, …) + /// - `Operand::Target(addr)` → `"0x{addr:08X}"` + /// - `Operand::Token(tok)` → `"0x{token:08X}"` + /// - `Operand::Local(idx)` → `"V_{idx}"` + /// - `Operand::Argument(idx)` → `"A_{idx}"` + /// - `Operand::Switch(targs)` → `"switch({len})"` (target count, not the + /// targets themselves; render those explicitly via the `Switch` variant + /// payload if needed). + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Operand::None => f.write_str("none"), + Operand::Immediate(imm) => write!(f, "{imm:?}"), + Operand::Target(t) => write!(f, "0x{t:08X}"), + Operand::Token(t) => write!(f, "0x{:08X}", t.value()), + Operand::Local(l) => write!(f, "V_{l}"), + Operand::Argument(a) => write!(f, "A_{a}"), + Operand::Switch(targets) => write!(f, "switch({})", targets.len()), + } + } +} + /// How an instruction affects control flow. /// /// This enum categorizes instructions based on their control flow behavior, @@ -366,6 +428,41 @@ pub enum FlowType { Leave, } +impl FlowType { + /// Returns a stable `&'static str` identifier for this flow type. + /// + /// The strings are part of the stable public API and safe to persist + /// (file, database, log line). Variants use `snake_case` so the value + /// can be parsed without quoting: + /// + /// `"sequential"`, `"conditional_branch"`, `"unconditional_branch"`, + /// `"call"`, `"return"`, `"switch"`, `"throw"`, `"end_finally"`, + /// `"leave"`. + /// + /// Prefer this accessor over `format!("{:?}", flow_type)` — the `Debug` + /// representation is **not** part of the stable API and can change. + #[must_use] + pub const fn as_str(&self) -> &'static str { + match self { + FlowType::Sequential => "sequential", + FlowType::ConditionalBranch => "conditional_branch", + FlowType::UnconditionalBranch => "unconditional_branch", + FlowType::Call => "call", + FlowType::Return => "return", + FlowType::Switch => "switch", + FlowType::Throw => "throw", + FlowType::EndFinally => "end_finally", + FlowType::Leave => "leave", + } + } +} + +impl Display for FlowType { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + /// Stack effect of an instruction. /// /// Describes how an instruction modifies the evaluation stack. This information @@ -449,6 +546,37 @@ pub enum InstructionCategory { Misc, } +impl InstructionCategory { + /// Returns a stable `&'static str` identifier for this category. + /// + /// The strings are part of the stable public API and safe to persist. + /// Variants use `snake_case` so the value can be parsed without quoting: + /// + /// `"arithmetic"`, `"bitwise_logical"`, `"comparison"`, `"control_flow"`, + /// `"conversion"`, `"load_store"`, `"object_model"`, `"prefix"`, + /// `"misc"`. + #[must_use] + pub const fn as_str(&self) -> &'static str { + match self { + InstructionCategory::Arithmetic => "arithmetic", + InstructionCategory::BitwiseLogical => "bitwise_logical", + InstructionCategory::Comparison => "comparison", + InstructionCategory::ControlFlow => "control_flow", + InstructionCategory::Conversion => "conversion", + InstructionCategory::LoadStore => "load_store", + InstructionCategory::ObjectModel => "object_model", + InstructionCategory::Prefix => "prefix", + InstructionCategory::Misc => "misc", + } + } +} + +impl Display for InstructionCategory { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } +} + /// A decoded CIL instruction with all metadata needed for analysis and emulation. /// /// This struct represents a fully decoded .NET CIL instruction, including its location, @@ -1833,4 +1961,84 @@ mod tests { assert_eq!(stack_behavior.net_effect, (*pushes as i8) - (*pops as i8)); } } + + #[test] + fn test_flow_type_stable_strings() { + // Strings are part of the stable public API. Changing them is a + // breaking change for downstream consumers that persist these values + // (e.g. visus stores the result in a database). + let cases = [ + (FlowType::Sequential, "sequential"), + (FlowType::ConditionalBranch, "conditional_branch"), + (FlowType::UnconditionalBranch, "unconditional_branch"), + (FlowType::Call, "call"), + (FlowType::Return, "return"), + (FlowType::Switch, "switch"), + (FlowType::Throw, "throw"), + (FlowType::EndFinally, "end_finally"), + (FlowType::Leave, "leave"), + ]; + for (variant, expected) in cases { + assert_eq!(variant.as_str(), expected); + assert_eq!(format!("{variant}"), expected); + } + } + + #[test] + fn test_operand_type_stable_strings() { + let cases = [ + (OperandType::None, "none"), + (OperandType::Int8, "int8"), + (OperandType::UInt8, "uint8"), + (OperandType::Int16, "int16"), + (OperandType::UInt16, "uint16"), + (OperandType::Int32, "int32"), + (OperandType::UInt32, "uint32"), + (OperandType::Int64, "int64"), + (OperandType::UInt64, "uint64"), + (OperandType::Float32, "float32"), + (OperandType::Float64, "float64"), + (OperandType::Token, "token"), + (OperandType::Switch, "switch"), + ]; + for (variant, expected) in cases { + assert_eq!(variant.as_str(), expected); + assert_eq!(format!("{variant}"), expected); + } + } + + #[test] + fn test_instruction_category_stable_strings() { + let cases = [ + (InstructionCategory::Arithmetic, "arithmetic"), + (InstructionCategory::BitwiseLogical, "bitwise_logical"), + (InstructionCategory::Comparison, "comparison"), + (InstructionCategory::ControlFlow, "control_flow"), + (InstructionCategory::Conversion, "conversion"), + (InstructionCategory::LoadStore, "load_store"), + (InstructionCategory::ObjectModel, "object_model"), + (InstructionCategory::Prefix, "prefix"), + (InstructionCategory::Misc, "misc"), + ]; + for (variant, expected) in cases { + assert_eq!(variant.as_str(), expected); + assert_eq!(format!("{variant}"), expected); + } + } + + #[test] + fn test_operand_display_stable_format() { + // None now renders as "none" (Display is total, unlike as_string). + assert_eq!(format!("{}", Operand::None), "none"); + + // Numeric / target / token / variable indices match the as_string contract. + assert_eq!(format!("{}", Operand::Target(0x1000)), "0x00001000"); + assert_eq!(format!("{}", Operand::Local(5)), "V_5"); + assert_eq!(format!("{}", Operand::Argument(3)), "A_3"); + assert_eq!(format!("{}", Operand::Switch(vec![1, 2, 3])), "switch(3)"); + + // Token formatting matches as_string. + let tok = Token::new(0x06000001); + assert_eq!(format!("{}", Operand::Token(tok)), "0x06000001"); + } } diff --git a/dotscope/src/cilassembly/writer/heaps/streaming.rs b/dotscope/src/cilassembly/writer/heaps/streaming.rs index e8687d99..c44fba98 100644 --- a/dotscope/src/cilassembly/writer/heaps/streaming.rs +++ b/dotscope/src/cilassembly/writer/heaps/streaming.rs @@ -28,9 +28,18 @@ use crate::{ }, metadata::streams::{Blob, Guid, Strings, UserStrings}, utils::{compressed_uint_size, hash_blob, hash_string, to_u32, write_compressed_uint}, - Error, Result, + Error, ParseFailure, ParseStage, Result, }; +#[inline] +fn writer_overflow(field: &'static str) -> Error { + Error::Parse(ParseFailure::InvalidField { + stage: ParseStage::AssemblyWriter, + field, + reason: "arithmetic overflow".into(), + }) +} + /// Result of streaming a heap to output. #[derive(Debug)] pub struct StreamResult { @@ -155,7 +164,7 @@ fn emit_orphaned_substrings( let delta = ref_offset .checked_sub(old_offset_u32) - .ok_or_else(|| malformed_error!("Substring delta underflow"))? + .ok_or_else(|| writer_overflow("substring_delta"))? as usize; if delta >= original_bytes.len() { continue; @@ -181,10 +190,10 @@ fn emit_orphaned_substrings( let sub_pos = start_offset .checked_add(*pos) - .ok_or_else(|| malformed_error!("Heap write offset overflow"))?; + .ok_or_else(|| writer_overflow("heap_write_offset"))?; let sub_end = sub_pos .checked_add(sub_bytes.len() as u64) - .ok_or_else(|| malformed_error!("Heap write offset overflow"))?; + .ok_or_else(|| writer_overflow("heap_write_offset"))?; if let Some(out) = output.as_mut() { out.write_at(sub_pos, sub_bytes)?; out.write_at(sub_end, &[0u8])?; @@ -193,7 +202,7 @@ fn emit_orphaned_substrings( *pos = pos .checked_add(sub_bytes.len() as u64) .and_then(|p| p.checked_add(1)) - .ok_or_else(|| malformed_error!("Heap position overflow"))?; + .ok_or_else(|| writer_overflow("heap_position"))?; dedup_map.insert(sub_hash, new_sub_offset); result.remapping.insert(ref_offset, new_sub_offset); } @@ -256,7 +265,7 @@ fn process_strings_heap( let original_end = old_offset_u32 .checked_add(to_u32(original_bytes.len())?) .and_then(|v| v.checked_add(1)) - .ok_or_else(|| malformed_error!("String range exceeds u32"))?; // +1 for null + .ok_or_else(|| writer_overflow("string_range"))?; // +1 for null if changes.is_removed(old_offset_u32) { // Even though this entry is removed, emit any referenced substrings @@ -306,15 +315,15 @@ fn process_strings_heap( let str_bytes = final_str.as_bytes(); let entry_size = (str_bytes.len() as u64) .checked_add(1) - .ok_or_else(|| malformed_error!("Heap entry size overflow"))?; // +1 for null terminator + .ok_or_else(|| writer_overflow("heap_entry_size"))?; // +1 for null terminator // Write if in write mode let write_pos = start_offset .checked_add(pos) - .ok_or_else(|| malformed_error!("Heap write offset overflow"))?; + .ok_or_else(|| writer_overflow("heap_write_offset"))?; let null_pos = write_pos .checked_add(str_bytes.len() as u64) - .ok_or_else(|| malformed_error!("Heap write offset overflow"))?; + .ok_or_else(|| writer_overflow("heap_write_offset"))?; if let Some(out) = output.as_mut() { out.write_at(write_pos, str_bytes)?; out.write_at(null_pos, &[0u8])?; @@ -322,7 +331,7 @@ fn process_strings_heap( pos = pos .checked_add(entry_size) - .ok_or_else(|| malformed_error!("Heap position overflow"))?; + .ok_or_else(|| writer_overflow("heap_position"))?; dedup_map.insert(content_hash, new_offset); // Add primary offset remapping if changed @@ -351,10 +360,10 @@ fn process_strings_heap( if ref_offset > old_offset_u32 && ref_offset < original_end { let substring_delta = ref_offset .checked_sub(old_offset_u32) - .ok_or_else(|| malformed_error!("Substring delta underflow"))?; + .ok_or_else(|| writer_overflow("substring_delta"))?; let new_substring_offset = new_offset .checked_add(substring_delta) - .ok_or_else(|| malformed_error!("Substring offset overflow"))?; + .ok_or_else(|| writer_overflow("substring_offset"))?; result.remapping.insert(ref_offset, new_substring_offset); } } @@ -389,7 +398,7 @@ fn process_strings_heap( while result.remapping.contains_key(&(pos as u32)) { pos = pos .checked_add(1) - .ok_or_else(|| malformed_error!("Heap position overflow"))?; + .ok_or_else(|| writer_overflow("heap_position"))?; } let new_offset = u32::try_from(pos) @@ -397,14 +406,14 @@ fn process_strings_heap( let str_bytes = final_str.as_bytes(); let entry_size = (str_bytes.len() as u64) .checked_add(1) - .ok_or_else(|| malformed_error!("Heap entry size overflow"))?; + .ok_or_else(|| writer_overflow("heap_entry_size"))?; let write_pos = start_offset .checked_add(pos) - .ok_or_else(|| malformed_error!("Heap write offset overflow"))?; + .ok_or_else(|| writer_overflow("heap_write_offset"))?; let null_pos = write_pos .checked_add(str_bytes.len() as u64) - .ok_or_else(|| malformed_error!("Heap write offset overflow"))?; + .ok_or_else(|| writer_overflow("heap_write_offset"))?; if let Some(out) = output.as_mut() { out.write_at(write_pos, str_bytes)?; out.write_at(null_pos, &[0u8])?; @@ -412,7 +421,7 @@ fn process_strings_heap( pos = pos .checked_add(entry_size) - .ok_or_else(|| malformed_error!("Heap position overflow"))?; + .ok_or_else(|| writer_overflow("heap_position"))?; dedup_map.insert(content_hash, new_offset); change_ref.resolve_to_offset(new_offset); } @@ -558,14 +567,14 @@ fn process_blob_heap( let write_pos = start_offset .checked_add(pos) - .ok_or_else(|| malformed_error!("Heap write offset overflow"))?; + .ok_or_else(|| writer_overflow("heap_write_offset"))?; if let Some(out) = output.as_mut() { // Write compressed length of 0 out.write_at(write_pos, &[0u8])?; } pos = pos .checked_add(1) - .ok_or_else(|| malformed_error!("Heap position overflow"))?; // Empty blob is just 1 byte (length 0) + .ok_or_else(|| writer_overflow("heap_position"))?; // Empty blob is just 1 byte (length 0) // Only add to remapping if the offset actually changed if old_offset_u32 != new_offset { @@ -587,11 +596,11 @@ fn process_blob_heap( let len_size = compressed_uint_size(final_blob.len()); let entry_size = len_size .checked_add(final_blob.len() as u64) - .ok_or_else(|| malformed_error!("Blob entry size overflow"))?; + .ok_or_else(|| writer_overflow("blob_entry_size"))?; let write_pos = start_offset .checked_add(pos) - .ok_or_else(|| malformed_error!("Heap write offset overflow"))?; + .ok_or_else(|| writer_overflow("heap_write_offset"))?; if let Some(out) = output.as_mut() { let blob_len_u32 = u32::try_from(final_blob.len()).map_err(|_| { Error::LayoutFailed(format!( @@ -603,14 +612,14 @@ fn process_blob_heap( write_compressed_uint(blob_len_u32, &mut len_bytes); let data_pos = write_pos .checked_add(len_bytes.len() as u64) - .ok_or_else(|| malformed_error!("Heap write offset overflow"))?; + .ok_or_else(|| writer_overflow("heap_write_offset"))?; out.write_at(write_pos, &len_bytes)?; out.write_at(data_pos, final_blob)?; } pos = pos .checked_add(entry_size) - .ok_or_else(|| malformed_error!("Heap position overflow"))?; + .ok_or_else(|| writer_overflow("heap_position"))?; dedup_map.insert(content_hash, new_offset); // Only add to remapping if the offset actually changed if old_offset_u32 != new_offset { @@ -663,7 +672,7 @@ fn process_blob_heap( while result.remapping.contains_key(&(pos as u32)) { pos = pos .checked_add(1) - .ok_or_else(|| malformed_error!("Heap position overflow"))?; + .ok_or_else(|| writer_overflow("heap_position"))?; } let new_offset = u32::try_from(pos) @@ -671,11 +680,11 @@ fn process_blob_heap( let len_size = compressed_uint_size(final_blob.len()); let entry_size = len_size .checked_add(final_blob.len() as u64) - .ok_or_else(|| malformed_error!("Blob entry size overflow"))?; + .ok_or_else(|| writer_overflow("blob_entry_size"))?; let write_pos = start_offset .checked_add(pos) - .ok_or_else(|| malformed_error!("Heap write offset overflow"))?; + .ok_or_else(|| writer_overflow("heap_write_offset"))?; if let Some(out) = output.as_mut() { let blob_len_u32 = u32::try_from(final_blob.len()).map_err(|_| { Error::LayoutFailed(format!( @@ -687,14 +696,14 @@ fn process_blob_heap( write_compressed_uint(blob_len_u32, &mut len_bytes); let data_pos = write_pos .checked_add(len_bytes.len() as u64) - .ok_or_else(|| malformed_error!("Heap write offset overflow"))?; + .ok_or_else(|| writer_overflow("heap_write_offset"))?; out.write_at(write_pos, &len_bytes)?; out.write_at(data_pos, final_blob)?; } pos = pos .checked_add(entry_size) - .ok_or_else(|| malformed_error!("Heap position overflow"))?; + .ok_or_else(|| writer_overflow("heap_position"))?; dedup_map.insert(content_hash, new_offset); change_ref.resolve_to_offset(new_offset); } @@ -796,7 +805,7 @@ fn process_guid_heap( let byte_offset = old_index_u32 .saturating_sub(1) .checked_mul(16) - .ok_or_else(|| malformed_error!("GUID byte offset overflow"))?; + .ok_or_else(|| writer_overflow("guid_byte_offset"))?; // Check if deleted if changes.is_removed(byte_offset) { @@ -821,13 +830,13 @@ fn process_guid_heap( // Write if in write mode let write_pos = start_offset .checked_add(pos) - .ok_or_else(|| malformed_error!("GUID heap write offset overflow"))?; + .ok_or_else(|| writer_overflow("guid_heap_write_offset"))?; if let Some(out) = output.as_mut() { out.write_at(write_pos, &final_guid)?; } pos = pos .checked_add(16) - .ok_or_else(|| malformed_error!("GUID heap position overflow"))?; + .ok_or_else(|| writer_overflow("guid_heap_position"))?; dedup_map.insert(final_guid, current_index); // Only add to remapping if the index actually changed @@ -836,7 +845,7 @@ fn process_guid_heap( } current_index = current_index .checked_add(1) - .ok_or_else(|| malformed_error!("GUID index overflow"))?; + .ok_or_else(|| writer_overflow("guid_index"))?; } } @@ -857,19 +866,19 @@ fn process_guid_heap( let write_pos = start_offset .checked_add(pos) - .ok_or_else(|| malformed_error!("GUID heap write offset overflow"))?; + .ok_or_else(|| writer_overflow("guid_heap_write_offset"))?; if let Some(out) = output.as_mut() { out.write_at(write_pos, final_guid)?; } pos = pos .checked_add(16) - .ok_or_else(|| malformed_error!("GUID heap position overflow"))?; + .ok_or_else(|| writer_overflow("guid_heap_position"))?; dedup_map.insert(*final_guid, current_index); change_ref.resolve_to_offset(current_index); current_index = current_index .checked_add(1) - .ok_or_else(|| malformed_error!("GUID index overflow"))?; + .ok_or_else(|| writer_overflow("guid_index"))?; } result.bytes_written = pos; @@ -999,14 +1008,14 @@ fn process_userstring_heap( // Write if in write mode let write_pos = start_offset .checked_add(pos) - .ok_or_else(|| malformed_error!("UserString write offset overflow"))?; + .ok_or_else(|| writer_overflow("userstring_write_offset"))?; if let Some(out) = output.as_mut() { write_userstring_entry(out, write_pos, final_str)?; } pos = pos .checked_add(entry_size) - .ok_or_else(|| malformed_error!("UserString heap position overflow"))?; + .ok_or_else(|| writer_overflow("userstring_heap_position"))?; dedup_map.insert(content_hash, new_offset); // Only add to remapping if the offset actually changed if old_offset_u32 != new_offset { @@ -1038,14 +1047,14 @@ fn process_userstring_heap( let write_pos = start_offset .checked_add(pos) - .ok_or_else(|| malformed_error!("UserString write offset overflow"))?; + .ok_or_else(|| writer_overflow("userstring_write_offset"))?; if let Some(out) = output.as_mut() { write_userstring_entry(out, write_pos, final_str)?; } pos = pos .checked_add(entry_size) - .ok_or_else(|| malformed_error!("UserString heap position overflow"))?; + .ok_or_else(|| writer_overflow("userstring_heap_position"))?; dedup_map.insert(content_hash, new_offset); change_ref.resolve_to_offset(new_offset); } @@ -1060,13 +1069,13 @@ fn userstring_entry_size(s: &str) -> Result { .encode_utf16() .count() .checked_mul(2) - .ok_or_else(|| malformed_error!("UserString UTF-16 length overflow"))?; + .ok_or_else(|| writer_overflow("userstring_utf16_length"))?; let total_len = utf16_len .checked_add(1) - .ok_or_else(|| malformed_error!("UserString total length overflow"))?; // +1 for terminal byte + .ok_or_else(|| writer_overflow("userstring_total_length"))?; // +1 for terminal byte compressed_uint_size(total_len) .checked_add(total_len as u64) - .ok_or_else(|| malformed_error!("UserString entry size overflow")) + .ok_or_else(|| writer_overflow("userstring_entry_size")) } /// Writes a single user string entry to output. @@ -1077,7 +1086,7 @@ fn write_userstring_entry(output: &mut Output, pos: u64, s: &str) -> Result<()> let total_len = utf16_bytes .len() .checked_add(1) - .ok_or_else(|| malformed_error!("UserString total length overflow"))?; + .ok_or_else(|| writer_overflow("userstring_total_length"))?; // Write compressed length let total_len_u32 = u32::try_from(total_len).map_err(|_| { @@ -1090,14 +1099,14 @@ fn write_userstring_entry(output: &mut Output, pos: u64, s: &str) -> Result<()> // Write UTF-16LE bytes let utf16_pos = pos .checked_add(len_bytes.len() as u64) - .ok_or_else(|| malformed_error!("UserString write offset overflow"))?; + .ok_or_else(|| writer_overflow("userstring_write_offset"))?; output.write_at(utf16_pos, &utf16_bytes)?; // Write terminal byte (0x01 if any byte has high bit set, 0x00 otherwise) let terminal = u8::from(utf16_bytes.iter().any(|&b| b & 0x80 != 0)); let terminal_pos = utf16_pos .checked_add(utf16_bytes.len() as u64) - .ok_or_else(|| malformed_error!("UserString write offset overflow"))?; + .ok_or_else(|| writer_overflow("userstring_write_offset"))?; output.write_at(terminal_pos, &[terminal])?; Ok(()) diff --git a/dotscope/src/compiler/codegen/mod.rs b/dotscope/src/compiler/codegen/mod.rs index fe7d5d98..c35bf40f 100644 --- a/dotscope/src/compiler/codegen/mod.rs +++ b/dotscope/src/compiler/codegen/mod.rs @@ -844,27 +844,22 @@ impl SsaCodeGenerator { SsaOp::Const { value: ConstValue::DecryptedString(s), .. - } if !self.interned_strings.contains_key(s) => { + } if !self.interned_strings.contains_key(s.as_ref()) => { let change_ref = assembly.userstring_add(s)?; self.interned_strings - .insert(s.clone(), change_ref.placeholder()); + .insert(s.to_string(), change_ref.placeholder()); } SsaOp::Const { - value: - ConstValue::DecryptedArray { - data, - element_type_ref, - element_size, - }, + value: ConstValue::DecryptedArray(arr), .. - } if !self.interned_arrays.contains_key(data) => { + } if !self.interned_arrays.contains_key(&arr.data) => { if let Some(info) = self.intern_array_data( - data, - element_type_ref.token(), - *element_size, + &arr.data, + arr.element_type_ref.token(), + arr.element_size, assembly, )? { - self.interned_arrays.insert(data.clone(), info); + self.interned_arrays.insert(arr.data.clone(), info); } } _ => {} @@ -3515,7 +3510,63 @@ impl SsaCodeGenerator { | SsaOp::Select { .. } | SsaOp::ReadFlags { .. } | SsaOp::CmpXchg { .. } - | SsaOp::AtomicRmw { .. } => { + | SsaOp::AtomicRmw { .. } + // Native SSA substrate operations (wide arithmetic, native + // boolean/float-flag ops, SIMD/vector, native atomics, bitcast, + // opaque, indirect branch). These never appear in CIL-lifted SSA + // and have no direct CIL encoding; enumerated explicitly (no + // wildcard) so future substrate additions still trip this + // exhaustiveness check. + | SsaOp::WideMul { .. } + | SsaOp::WideDiv { .. } + | SsaOp::FloatCompareFlags { .. } + | SsaOp::BoolAnd { .. } + | SsaOp::BoolOr { .. } + | SsaOp::BoolXor { .. } + | SsaOp::BoolNot { .. } + | SsaOp::Bitcast { .. } + | SsaOp::IndirectBranch { .. } + | SsaOp::NativeOpaque(_) + | SsaOp::VectorUnary { .. } + | SsaOp::VectorBinary { .. } + | SsaOp::VectorTernary { .. } + | SsaOp::VectorPredicatedUnary { .. } + | SsaOp::VectorPredicatedBinary { .. } + | SsaOp::VectorPredicatedTernary { .. } + | SsaOp::VectorCompare { .. } + | SsaOp::VectorLoad { .. } + | SsaOp::VectorStore { .. } + | SsaOp::VectorMaskedLoad { .. } + | SsaOp::VectorMaskedStore { .. } + | SsaOp::VectorBroadcastLoad { .. } + | SsaOp::VectorGather { .. } + | SsaOp::VectorFaultingLoad { .. } + | SsaOp::VectorSegmentLoad { .. } + | SsaOp::VectorScatter { .. } + | SsaOp::VectorSegmentStore { .. } + | SsaOp::VectorExtract { .. } + | SsaOp::VectorInsert { .. } + | SsaOp::VectorSplat { .. } + | SsaOp::VectorShuffle { .. } + | SsaOp::VectorCast { .. } + | SsaOp::VectorReinterpret { .. } + | SsaOp::VectorPack { .. } + | SsaOp::VectorPackLoad { .. } + | SsaOp::VectorPackStore { .. } + | SsaOp::VectorZeroUpper { .. } + | SsaOp::VectorMaskUnary { .. } + | SsaOp::VectorMaskBinary { .. } + | SsaOp::VectorReduce { .. } + | SsaOp::VectorBitmask { .. } + | SsaOp::AtomicLoad { .. } + | SsaOp::AtomicStore { .. } + | SsaOp::AtomicStoreConditional { .. } + | SsaOp::AtomicPairLoad { .. } + | SsaOp::AtomicPairStoreConditional { .. } + | SsaOp::AtomicExchange { .. } + | SsaOp::AtomicLockRmw { .. } + | SsaOp::AtomicCmpXchg { .. } + | SsaOp::AtomicPairCmpXchg { .. } => { // These operations may appear in the shared SSA but are not // directly expressible in CIL; they should have been lowered // before code generation. @@ -4426,7 +4477,7 @@ impl SsaCodeGenerator { // Decrypted strings look up pre-interned index ConstValue::DecryptedString(s) => { - if let Some(&idx) = self.interned_strings.get(s) { + if let Some(&idx) = self.interned_strings.get(s.as_ref()) { let token = Token::new(0x7000_0000 | idx); encoder.emit_instruction("ldstr", Some(Operand::Token(token)))?; } else { @@ -4470,20 +4521,18 @@ impl SsaCodeGenerator { // dup ; push copy (+2) // ldtoken ; push handle (+3) // call InitializeArray ; pop 2 (+1) — net: 1 value on stack - ConstValue::DecryptedArray { - data, - element_type_ref, - element_size, - } => { - let elem_size = element_size.max(&1); + ConstValue::DecryptedArray(arr) => { + let elem_size = arr.element_size.max(1); #[allow(clippy::cast_possible_truncation)] - let num_elements = data.len().checked_div(*elem_size).unwrap_or(0); + let num_elements = arr.data.len().checked_div(elem_size).unwrap_or(0); emitter::emit_ldc_i4(encoder, num_elements as i32)?; - encoder - .emit_instruction("newarr", Some(Operand::Token(element_type_ref.token())))?; + encoder.emit_instruction( + "newarr", + Some(Operand::Token(arr.element_type_ref.token())), + )?; - if let Some(info) = self.interned_arrays.get(data) { + if let Some(info) = self.interned_arrays.get(&arr.data) { // Compact: dup + ldtoken + call InitializeArray // Stack: [array] → [array, array] → [array, array, handle] → [array] encoder.emit_instruction("dup", None)?; @@ -4500,6 +4549,14 @@ impl SsaCodeGenerator { // If interning failed, array is left uninitialized (zeroed). // This is still valid IL — the values will just be default(T). } + + // SIMD vector constants come from the native SSA substrate and have + // no direct CIL `ldc`-style encoding. + ConstValue::Vector(_) => { + return Err(Error::CodegenFailed( + "Vector constants are not supported in CIL code generation".to_string(), + )); + } } Ok(()) } diff --git a/dotscope/src/compiler/passes/constants/mod.rs b/dotscope/src/compiler/passes/constants/mod.rs index ace94bd9..acd2861a 100644 --- a/dotscope/src/compiler/passes/constants/mod.rs +++ b/dotscope/src/compiler/passes/constants/mod.rs @@ -60,6 +60,7 @@ fn is_method_on_type(assembly: &CilObject, token: Token, type_name: &str) -> boo match token.table() { 0x06 => assembly .method(&token) + .ok() .and_then(|m| m.declaring_type_rc()) .is_some_and(|ty| ty.name.contains(type_name)), 0x0A => assembly @@ -1052,7 +1053,7 @@ impl ConstantPropagationPass { args: &[ConstValue], ptr_size: PointerSize, ) -> Option { - let method = assembly.method(&callee_token)?; + let method = assembly.method(&callee_token).ok()?; let callee_ssa = method.ssa(assembly).ok()?; let mut eval = SsaEvaluator::new(&callee_ssa, ptr_size); @@ -1223,7 +1224,7 @@ impl ConstantPropagationPass { }) .collect(); let result = strings?.concat(); - Some((dest, ConstValue::DecryptedString(result))) + Some((dest, ConstValue::DecryptedString(result.into()))) } StringFoldOp::SubstringFrom => { let this_str = constants.get(args.first()?)?.as_string_content(assembly)?; @@ -1233,7 +1234,7 @@ impl ConstantPropagationPass { } let start = constants.get(args.get(1)?)?.as_i32()? as usize; let tail = this_str.get(start..)?; - Some((dest, ConstValue::DecryptedString(tail.to_string()))) + Some((dest, ConstValue::DecryptedString(tail.into()))) } StringFoldOp::SubstringRange => { let this_str = constants.get(args.first()?)?.as_string_content(assembly)?; @@ -1244,7 +1245,7 @@ impl ConstantPropagationPass { let len = constants.get(args.get(2)?)?.as_i32()? as usize; let end = start.checked_add(len)?; let slice = this_str.get(start..end)?; - Some((dest, ConstValue::DecryptedString(slice.to_string()))) + Some((dest, ConstValue::DecryptedString(slice.into()))) } StringFoldOp::Replace => { let this_str = constants.get(args.first()?)?.as_string_content(assembly)?; @@ -1252,16 +1253,22 @@ impl ConstantPropagationPass { let new = constants.get(args.get(2)?)?.as_string_content(assembly)?; Some(( dest, - ConstValue::DecryptedString(this_str.replace(&old, &new)), + ConstValue::DecryptedString(this_str.replace(&old, &new).into()), )) } StringFoldOp::ToLower => { let this_str = constants.get(args.first()?)?.as_string_content(assembly)?; - Some((dest, ConstValue::DecryptedString(this_str.to_lowercase()))) + Some(( + dest, + ConstValue::DecryptedString(this_str.to_lowercase().into()), + )) } StringFoldOp::ToUpper => { let this_str = constants.get(args.first()?)?.as_string_content(assembly)?; - Some((dest, ConstValue::DecryptedString(this_str.to_uppercase()))) + Some(( + dest, + ConstValue::DecryptedString(this_str.to_uppercase().into()), + )) } } } @@ -1489,7 +1496,7 @@ impl ConstantPropagationPass { // Even chain: all cancel out, result = innermost_operand ssa.replace_uses_including_phis(t.outermost_dest, t.innermost_operand); for &(b, i) in &t.instructions_to_nop { - ssa.remove_instruction(b, i); + ssa.replace_instruction_op(b, i, SsaOp::Nop); } changes .record(EventKind::ConstantFolded) @@ -1521,7 +1528,7 @@ impl ConstantPropagationPass { // Nop all except the outermost if let Some(rest) = t.instructions_to_nop.get(1..) { for &(b, i) in rest { - ssa.remove_instruction(b, i); + ssa.replace_instruction_op(b, i, SsaOp::Nop); } } changes diff --git a/dotscope/src/compiler/passes/constants/tests.rs b/dotscope/src/compiler/passes/constants/tests.rs index a3283010..ca224478 100644 --- a/dotscope/src/compiler/passes/constants/tests.rs +++ b/dotscope/src/compiler/passes/constants/tests.rs @@ -1169,8 +1169,8 @@ fn test_try_fold_concat2_decrypted_strings() { let v1 = SsaVarId::from_index(1); let mut constants = BTreeMap::new(); - constants.insert(v0, ConstValue::DecryptedString("Hello".to_string())); - constants.insert(v1, ConstValue::DecryptedString(", World".to_string())); + constants.insert(v0, ConstValue::DecryptedString("Hello".into())); + constants.insert(v1, ConstValue::DecryptedString(", World".into())); // Use a fake token — identify_string_op will fail to resolve it through the // assembly, so we test the folding logic indirectly through fold_string_operations. @@ -1194,7 +1194,7 @@ fn test_try_fold_concat_with_non_constant_arg() { let v1 = SsaVarId::from_index(1); let mut constants = BTreeMap::new(); - constants.insert(v0, ConstValue::DecryptedString("Hello".to_string())); + constants.insert(v0, ConstValue::DecryptedString("Hello".into())); // v1 is NOT in constants — simulates a non-constant argument let strings: Option> = [v0, v1] @@ -1225,11 +1225,11 @@ fn test_fold_string_operations_with_decrypted_concat() { f.block(0, |b| { b.op(SsaOp::Const { dest: v0, - value: ConstValue::DecryptedString("Sys".to_string()), + value: ConstValue::DecryptedString("Sys".into()), }); b.op(SsaOp::Const { dest: v1, - value: ConstValue::DecryptedString("tem".to_string()), + value: ConstValue::DecryptedString("tem".into()), }); b.op(SsaOp::Call { dest: Some(v2), @@ -1242,8 +1242,8 @@ fn test_fold_string_operations_with_decrypted_concat() { .unwrap(); let mut constants = BTreeMap::new(); - constants.insert(v0, ConstValue::DecryptedString("Sys".to_string())); - constants.insert(v1, ConstValue::DecryptedString("tem".to_string())); + constants.insert(v0, ConstValue::DecryptedString("Sys".into())); + constants.insert(v1, ConstValue::DecryptedString("tem".into())); let mut changes = EventLog::new(); diff --git a/dotscope/src/compiler/scheduler.rs b/dotscope/src/compiler/scheduler.rs index 228e3672..780b2dae 100644 --- a/dotscope/src/compiler/scheduler.rs +++ b/dotscope/src/compiler/scheduler.rs @@ -15,7 +15,7 @@ use std::sync::Arc; -use analyssa::scheduling::PassScheduler as AnalyssaPassScheduler; +use analyssa::scheduling::{PassScheduler as AnalyssaPassScheduler, PipelineConfig}; use crate::{ analysis::CilTarget, @@ -69,12 +69,20 @@ impl PassScheduler { stable_iterations: usize, max_phase_iterations: usize, ) -> Self { + // dotscope populates the scheduler with its own CIL passes (gated by + // engine config), so start from an empty analyssa scheduler rather + // than the default built-in pipeline. Only the iteration limits and + // `verify_hard` flag are consumed by `empty`; the remaining + // `PipelineConfig` fields tune analyssa's built-in passes, which we + // do not register here. + let config = PipelineConfig { + max_iterations, + stable_iterations, + max_phase_iterations, + ..PipelineConfig::default() + }; Self { - inner: AnalyssaPassScheduler::new( - max_iterations, - stable_iterations, - max_phase_iterations, - ), + inner: AnalyssaPassScheduler::empty(config), } } diff --git a/dotscope/src/deobfuscation/engine/analysis.rs b/dotscope/src/deobfuscation/engine/analysis.rs index d567d9d5..30f918bc 100644 --- a/dotscope/src/deobfuscation/engine/analysis.rs +++ b/dotscope/src/deobfuscation/engine/analysis.rs @@ -147,7 +147,7 @@ impl DeobfuscationEngine { let errors: Vec<(Token, Error)> = method_tokens .par_iter() .filter_map(|&method_token| { - let method = assembly.method(&method_token)?; + let method = assembly.method(&method_token).ok()?; match method.ssa(assembly) { Ok(ssa) => { ctx.set_ssa(method_token, ssa); diff --git a/dotscope/src/deobfuscation/passes/bitmono/strings.rs b/dotscope/src/deobfuscation/passes/bitmono/strings.rs index 208e3294..9454270c 100644 --- a/dotscope/src/deobfuscation/passes/bitmono/strings.rs +++ b/dotscope/src/deobfuscation/passes/bitmono/strings.rs @@ -255,7 +255,7 @@ impl SsaPass for StringDecryptionPass { if let Some(instr) = block.instruction_mut(*call_idx) { instr.set_op(SsaOp::Const { dest: *call_dest, - value: ConstValue::DecryptedString(decrypted.clone()), + value: ConstValue::DecryptedString(decrypted.clone().into()), }); } } diff --git a/dotscope/src/deobfuscation/passes/bitmono/unmanaged.rs b/dotscope/src/deobfuscation/passes/bitmono/unmanaged.rs index ab2c3f97..77146cdf 100644 --- a/dotscope/src/deobfuscation/passes/bitmono/unmanaged.rs +++ b/dotscope/src/deobfuscation/passes/bitmono/unmanaged.rs @@ -94,7 +94,7 @@ impl SsaPass for UnmanagedStringReversalPass { if let Some(instr) = block.instruction_mut(site.newobj_idx) { instr.set_op(SsaOp::Const { dest: site.newobj_dest, - value: ConstValue::DecryptedString(site.decrypted.clone()), + value: ConstValue::DecryptedString(site.decrypted.clone().into()), }); } diff --git a/dotscope/src/deobfuscation/passes/decryption.rs b/dotscope/src/deobfuscation/passes/decryption.rs index 9812de0f..0288a6fe 100644 --- a/dotscope/src/deobfuscation/passes/decryption.rs +++ b/dotscope/src/deobfuscation/passes/decryption.rs @@ -91,6 +91,7 @@ use analyssa::graph::{ algorithms::{compute_dominators, DominatorTree}, GraphBase, NodeId, RootedGraph, }; +use analyssa::ir::value::DecryptedArrayData; /// Decryption pass for obfuscated constants and strings. /// @@ -430,7 +431,7 @@ impl DecryptionPass { // Handle ObjectRef specially - try string, unbox, or array from heap if let EmValue::ObjectRef(href) = em_value { if let Ok(s) = thread.heap().get_string(*href) { - return Some(ConstValue::DecryptedString(s.to_string())); + return Some(ConstValue::DecryptedString(s.to_string().into_boxed_str())); } // Try to unbox a boxed primitive value from the heap if let Ok(unboxed) = thread.heap().unbox(*href) { @@ -444,11 +445,11 @@ impl DecryptionPass { let elem_size = elem_flavor.byte_size(PointerSize::Bit32).unwrap_or(1); // Resolve the element CilFlavor to a real TypeRef token from the assembly if let Some(token) = Self::resolve_flavor_to_typeref(&elem_flavor, thread) { - return Some(ConstValue::DecryptedArray { + return Some(ConstValue::DecryptedArray(Box::new(DecryptedArrayData { data: bytes, element_type_ref: TypeRef::new(token), element_size: elem_size, - }); + }))); } } } @@ -461,7 +462,7 @@ impl DecryptionPass { if let Some(first_field) = fields.first().filter(|_| fields.len() == 1) { if let EmValue::ObjectRef(href) = first_field { if let Ok(s) = thread.heap().get_string(*href) { - return Some(ConstValue::DecryptedString(s.to_string())); + return Some(ConstValue::DecryptedString(s.to_string().into_boxed_str())); } } // Try primitive conversion on the single field (but not if it's Null) diff --git a/dotscope/src/deobfuscation/passes/delegates.rs b/dotscope/src/deobfuscation/passes/delegates.rs index 9f6893bc..b3f360fe 100644 --- a/dotscope/src/deobfuscation/passes/delegates.rs +++ b/dotscope/src/deobfuscation/passes/delegates.rs @@ -265,7 +265,7 @@ impl DelegateProxyResolutionPass { let is_virtual = synthetic_is_virtual.unwrap_or_else(|| { assembly .as_ref() - .and_then(|asm| asm.method(&resolved_token)) + .and_then(|asm| asm.method(&resolved_token).ok()) .map(|m| !m.is_static()) .unwrap_or(false) }); diff --git a/dotscope/src/deobfuscation/passes/reflection.rs b/dotscope/src/deobfuscation/passes/reflection.rs index ddd7fb7e..75bd3d6e 100644 --- a/dotscope/src/deobfuscation/passes/reflection.rs +++ b/dotscope/src/deobfuscation/passes/reflection.rs @@ -333,7 +333,7 @@ impl<'a> ChainTracer<'a> { fn is_method_from_type(&self, token: Token, type_name: &str) -> bool { let table = token.table(); if table == 0x06 { - if let Some(method) = self.assembly.method(&token) { + if let Ok(method) = self.assembly.method(&token) { if let Some(ty) = method.declaring_type_rc() { return ty.name.contains(type_name); } @@ -412,7 +412,7 @@ impl<'a> ChainTracer<'a> { return None; }; match value { - ConstValue::DecryptedString(s) => Some(s.clone()), + ConstValue::DecryptedString(s) => Some(s.to_string()), _ => None, } } @@ -1481,7 +1481,7 @@ fn resolve_method_by_name(assembly: &CilObject, type_token: Token, name: &str) - fn find_parameterless_ctor(assembly: &CilObject, type_token: Token) -> Option { let ty = assembly.types().get(&type_token)?; let ctor_token = ty.ctor()?; - let method = assembly.method(&ctor_token)?; + let method = assembly.method(&ctor_token).ok()?; if method.params.is_empty() { Some(ctor_token) } else { diff --git a/dotscope/src/deobfuscation/passes/staticfields.rs b/dotscope/src/deobfuscation/passes/staticfields.rs index fde150e1..aeee857e 100644 --- a/dotscope/src/deobfuscation/passes/staticfields.rs +++ b/dotscope/src/deobfuscation/passes/staticfields.rs @@ -213,7 +213,7 @@ impl FieldValueExtractor for StringExtractor { ) -> Option { match value { EmValue::ObjectRef(heap_ref) => match process.address_space().get_string(*heap_ref) { - Ok(s) => Some(ConstValue::DecryptedString(s.to_string())), + Ok(s) => Some(ConstValue::DecryptedString(s.to_string().into_boxed_str())), Err(e) => { debug!( "StringField: field 0x{:08X} has ObjectRef but get_string failed: {}", @@ -225,7 +225,7 @@ impl FieldValueExtractor for StringExtractor { }, EmValue::Null => { // Null string — field initialized to null (rare but valid) - Some(ConstValue::DecryptedString(String::new())) + Some(ConstValue::DecryptedString(Box::from(""))) } other => { debug!( diff --git a/dotscope/src/deobfuscation/passes/unflattening/detection.rs b/dotscope/src/deobfuscation/passes/unflattening/detection.rs index 21b5f754..a23c09ab 100644 --- a/dotscope/src/deobfuscation/passes/unflattening/detection.rs +++ b/dotscope/src/deobfuscation/passes/unflattening/detection.rs @@ -693,8 +693,12 @@ impl<'a> CffDetector<'a> { .collect(); if entry_blocks.is_empty() { - // No separate entry blocks - dispatcher is the entry - let entry = EntryPoint::new(dispatcher_block); + let mut entry = EntryPoint::new(dispatcher_block); + if let Some(dispatcher_var) = state_var.and_then(|sv| sv.dispatcher_var) { + if let Some(initial) = self.initial_state_from_state_phi(dispatcher_var) { + entry.initial_state = Some(initial); + } + } entries.push(entry); return entries; } @@ -799,6 +803,31 @@ impl<'a> CffDetector<'a> { None } + /// Recovers the initial state from the dispatcher's state phi when the + /// region has no distinct entry block. + /// + /// In nested CFF (e.g. ConfuserEx handler dispatchers), the setup block + /// that assigns the initial state is itself one of the dispatcher's switch + /// case targets, so it is filtered out as a "case block" and there is no + /// separate entry to read the initial value from. The state phi at the + /// dispatcher still distinguishes the two roles: back-edge operands trace + /// to computed state updates (`mul`/`xor`/…), while the setup operand + /// traces to a constant. Return that constant. + fn initial_state_from_state_phi(&self, dispatcher_var: SsaVarId) -> Option { + for block in self.ssa.blocks() { + for phi in block.phi_nodes() { + if phi.result() == dispatcher_var { + for op in phi.operands() { + if let Some(val) = self.trace_to_constant(op.value(), 20) { + return Some(val); + } + } + } + } + } + None + } + /// Traces a variable backward through SSA definitions to find a constant value. /// /// Handles Const, Copy, and PHI definitions. For PHI nodes, tries each diff --git a/dotscope/src/deobfuscation/passes/unflattening/dispatcher.rs b/dotscope/src/deobfuscation/passes/unflattening/dispatcher.rs index 27a877dd..85d6e292 100644 --- a/dotscope/src/deobfuscation/passes/unflattening/dispatcher.rs +++ b/dotscope/src/deobfuscation/passes/unflattening/dispatcher.rs @@ -153,35 +153,12 @@ impl Dispatcher { })?; // Find the state phi: the phi whose result feeds into the switch value - // (directly or through a short chain of copies/arithmetic in the block) - let state_phi = block - .phi_nodes() - .iter() - .find(|phi| { - if phi.result() == switch_var { - return true; - } - // Trace switch_var backwards through block-local definitions - let mut current = switch_var; - for _ in 0..5 { - if let Some(def_instr) = block - .instructions() - .iter() - .find(|i| i.op().dest().is_some_and(|d| d == current)) - { - if let Some(&src) = def_instr.op().uses().first() { - if src == phi.result() { - return true; - } - current = src; - continue; - } - } - break; - } - false - }) - .map(|phi| phi.result()); + // through the dispatcher's state-transform chain. Reuse the IR's + // backward tracer (which follows arithmetic/bitwise/unary state + // transforms) rather than an ad-hoc, fixed-depth, first-operand walk — + // the latter capped at 5 steps and ignored `neg`/`not`, missing the + // state phi behind ConfuserEx "expression" wrappers like `-(!!state)`. + let state_phi = ssa.trace_to_phi(switch_var, Some(self.block)); let mut refreshed = Self::new( self.block, diff --git a/dotscope/src/deobfuscation/passes/unflattening/tracer/helpers.rs b/dotscope/src/deobfuscation/passes/unflattening/tracer/helpers.rs index 0368fb46..9c5ecb8d 100644 --- a/dotscope/src/deobfuscation/passes/unflattening/tracer/helpers.rs +++ b/dotscope/src/deobfuscation/passes/unflattening/tracer/helpers.rs @@ -161,7 +161,7 @@ pub fn resolve_call_result( pointer_size: PointerSize, ) -> Option { // Look up the method - let method = assembly.method(&method_token)?; + let method = assembly.method(&method_token).ok()?; // Build SSA for the callee let callee_ssa = method.ssa(assembly).ok()?; diff --git a/dotscope/src/deobfuscation/renamer/cascade.rs b/dotscope/src/deobfuscation/renamer/cascade.rs index 84c7be1a..977cc1d0 100644 --- a/dotscope/src/deobfuscation/renamer/cascade.rs +++ b/dotscope/src/deobfuscation/renamer/cascade.rs @@ -228,6 +228,7 @@ impl<'a> CascadeRenamer<'a> { .or_else(|| { self.assembly .method(caller_token) + .ok() .filter(|m| !is_obfuscated_name(&m.name)) .map(|m| m.name.clone()) }) @@ -619,7 +620,7 @@ impl<'a> CascadeRenamer<'a> { }; // Get method metadata - if let Some(method) = self.assembly.method(&method_token) { + if let Ok(method) = self.assembly.method(&method_token) { // Return type context.dotnet_type = Some(method.signature.return_type.to_string()); @@ -823,14 +824,14 @@ impl<'a> CascadeRenamer<'a> { // Parent method name (committed or original) if let Some(name) = self.committed.get(&method_token) { context.parent_type = Some(name.clone()); - } else if let Some(method) = self.assembly.method(&method_token) { + } else if let Ok(method) = self.assembly.method(&method_token) { if !is_obfuscated_name(&method.name) { context.parent_type = Some(method.name.clone()); } } // Parameter type from method signature - if let Some(method) = self.assembly.method(&method_token) { + if let Ok(method) = self.assembly.method(&method_token) { // param.sequence is 1-based (0 = return type), so index = sequence - 1 let sig_index = (param_sequence as usize).saturating_sub(1); if let Some(param) = method.signature.params.get(sig_index) { @@ -2048,6 +2049,7 @@ mod tests { let method_token = Token::new(0x0600_0000 | rid); let has_cfg = assembly .method(&method_token) + .ok() .map(|m| m.cfg().is_some()) .unwrap_or(false); eprintln!( @@ -2115,7 +2117,7 @@ mod tests { .get(rid) .and_then(|md| strings.get(md.name as usize).ok()) .unwrap_or("?"); - let method = assembly.method(&method_token); + let method = assembly.method(&method_token).ok(); let has_cfg = method.as_ref().map(|m| m.cfg().is_some()).unwrap_or(false); let has_body = method .as_ref() @@ -2446,6 +2448,7 @@ mod tests { for (i, token) in topo.iter().enumerate() { let name = assembly .method(token) + .ok() .map(|m| m.name.clone()) .or_else(|| assembly.resolve_method_name(*token)) .unwrap_or_else(|| format!("0x{:08X}", token.value())); diff --git a/dotscope/src/deobfuscation/renamer/features.rs b/dotscope/src/deobfuscation/renamer/features.rs index 571266ee..a912e97f 100644 --- a/dotscope/src/deobfuscation/renamer/features.rs +++ b/dotscope/src/deobfuscation/renamer/features.rs @@ -78,7 +78,7 @@ pub fn collect_string_literals(ssa: &SsaFunction, assembly: &CilObject) -> Vec { - strings.push(s.clone()); + strings.push(s.to_string()); } ConstValue::String(idx) => { // Resolve from UserStrings heap @@ -319,10 +319,11 @@ pub fn collect_call_site_context( for (_, _, nearby_instr) in window { if let SsaOp::Const { value, .. } = nearby_instr.op() { match value { - ConstValue::DecryptedString(s) - if !s.is_empty() && !nearby_strings.contains(s) => - { - nearby_strings.push(s.clone()); + ConstValue::DecryptedString(s) if !s.is_empty() => { + let s = s.to_string(); + if !nearby_strings.contains(&s) { + nearby_strings.push(s); + } } ConstValue::String(us_idx) => { if let Some(us) = assembly.userstrings() { @@ -385,7 +386,7 @@ fn resolve_qualified_method_name(assembly: &CilObject, token: Token) -> Option { - let method = assembly.method(&token)?; + let method = assembly.method(&token).ok()?; if let Some(type_name) = method.declaring_type_fullname() { Some(format!("{type_name}::{}", method.name)) } else { diff --git a/dotscope/src/deobfuscation/renamer/providers/local.rs b/dotscope/src/deobfuscation/renamer/providers/local.rs index e6393b29..c4c15dfe 100644 --- a/dotscope/src/deobfuscation/renamer/providers/local.rs +++ b/dotscope/src/deobfuscation/renamer/providers/local.rs @@ -246,7 +246,11 @@ impl RenameProvider for LocalProvider { self.config.model_path.display() ); - *self.state.lock().unwrap() = Some(InferenceState { model, runtime }); + let mut guard = self + .state + .lock() + .map_err(|e| Error::Deobfuscation(format!("Smart rename state mutex poisoned: {e}")))?; + *guard = Some(InferenceState { model, runtime }); Ok(()) } @@ -275,7 +279,10 @@ impl RenameProvider for LocalProvider { None => return Ok(None), }; - let guard = self.state.lock().unwrap(); + let guard = self + .state + .lock() + .map_err(|e| Error::Deobfuscation(format!("Smart rename state mutex poisoned: {e}")))?; let Some(ref state) = *guard else { return Ok(None); }; @@ -300,9 +307,13 @@ impl RenameProvider for LocalProvider { /// /// # Errors /// - /// This method currently does not fail. + /// Returns an error if the internal state mutex has been poisoned. fn shutdown(&mut self) -> Result<()> { - *self.state.lock().unwrap() = None; + let mut guard = self + .state + .lock() + .map_err(|e| Error::Deobfuscation(format!("Smart rename state mutex poisoned: {e}")))?; + *guard = None; log::info!("Smart rename model unloaded"); Ok(()) } diff --git a/dotscope/src/deobfuscation/techniques/bitmono/hooks.rs b/dotscope/src/deobfuscation/techniques/bitmono/hooks.rs index 0ccbf42a..cfd1f93d 100644 --- a/dotscope/src/deobfuscation/techniques/bitmono/hooks.rs +++ b/dotscope/src/deobfuscation/techniques/bitmono/hooks.rs @@ -793,7 +793,7 @@ fn compute_target_offset( for (stale_dummy, final_dummy_val) in sorted_entries { let final_dummy_token = Token::new(*final_dummy_val); - let Some(dummy_method) = assembly.method(&final_dummy_token) else { + let Ok(dummy_method) = assembly.method(&final_dummy_token) else { continue; }; @@ -899,7 +899,7 @@ fn is_redirect_stub_memberref( }; if memberref.class.tag == TableId::TypeDef { - if let Some(stub_method) = assembly.method(&redirect_stub_token) { + if let Ok(stub_method) = assembly.method(&redirect_stub_token) { if let Some(stub_type) = stub_method.declaring_type_rc() { return memberref.class.row == stub_type.token.row(); } diff --git a/dotscope/src/deobfuscation/techniques/bitmono/strings.rs b/dotscope/src/deobfuscation/techniques/bitmono/strings.rs index 82f58539..43fd095c 100644 --- a/dotscope/src/deobfuscation/techniques/bitmono/strings.rs +++ b/dotscope/src/deobfuscation/techniques/bitmono/strings.rs @@ -148,7 +148,7 @@ impl Technique for BitMonoStrings { if has_crypto_ops(ssa, assembly) { decryptor_tokens.push(method_token); if decryptor_type.is_none() { - if let Some(method) = assembly.method(&method_token) { + if let Ok(method) = assembly.method(&method_token) { if let Some(decl_type) = method.declaring_type_rc() { decryptor_type = Some(decl_type.token); } diff --git a/dotscope/src/deobfuscation/techniques/bitmono/unmanaged.rs b/dotscope/src/deobfuscation/techniques/bitmono/unmanaged.rs index 07c1c641..2bf35fe2 100644 --- a/dotscope/src/deobfuscation/techniques/bitmono/unmanaged.rs +++ b/dotscope/src/deobfuscation/techniques/bitmono/unmanaged.rs @@ -214,7 +214,7 @@ fn extract_native_string( native_token: Token, is_64bit: bool, ) -> Option { - let method = assembly.method(&native_token)?; + let method = assembly.method(&native_token).ok()?; let rva = method.rva.filter(|&r| r > 0)?; let file = assembly.file(); diff --git a/dotscope/src/deobfuscation/techniques/confuserex/constants.rs b/dotscope/src/deobfuscation/techniques/confuserex/constants.rs index 6f3eb71e..c150fcad 100644 --- a/dotscope/src/deobfuscation/techniques/confuserex/constants.rs +++ b/dotscope/src/deobfuscation/techniques/confuserex/constants.rs @@ -703,7 +703,7 @@ fn collect_module_state_fields( .collect(); for method_token in &method_tokens { - let Some(method) = assembly.method(method_token) else { + let Ok(method) = assembly.method(method_token) else { continue; }; for instr in method.instructions() { @@ -746,7 +746,7 @@ fn resolve_memberref_to_decryptor( let memberref = assembly.member_ref(&memberref_token)?; for decryptor_token in decryptor_set { - if let Some(method) = assembly.method(decryptor_token) { + if let Ok(method) = assembly.method(decryptor_token) { if method.name == memberref.name { return Some(*decryptor_token); } diff --git a/dotscope/src/deobfuscation/techniques/confuserex/resources.rs b/dotscope/src/deobfuscation/techniques/confuserex/resources.rs index b2779972..6e448ab7 100644 --- a/dotscope/src/deobfuscation/techniques/confuserex/resources.rs +++ b/dotscope/src/deobfuscation/techniques/confuserex/resources.rs @@ -406,7 +406,7 @@ fn resolve_call_name(assembly: &CilObject, token: Token) -> Option { match table_id { // MethodDef 0x06 => { - let method = assembly.method(&token)?; + let method = assembly.method(&token).ok()?; Some(method.name.clone()) } // MemberRef diff --git a/dotscope/src/deobfuscation/techniques/confuserex/statemachine.rs b/dotscope/src/deobfuscation/techniques/confuserex/statemachine.rs index 6f46bff0..275b6c71 100644 --- a/dotscope/src/deobfuscation/techniques/confuserex/statemachine.rs +++ b/dotscope/src/deobfuscation/techniques/confuserex/statemachine.rs @@ -662,7 +662,7 @@ fn extract_slot_operations( next_method: Token, field_tokens: &[Token], ) -> Option> { - let method = assembly.method(&next_method)?; + let method = assembly.method(&next_method).ok()?; // Build SSA for the method - this gives us proper data flow analysis let ssa = method.ssa(assembly).ok()?; @@ -773,7 +773,7 @@ pub fn find_constants_initializer(assembly: &CilObject) -> Option { // Get .cctor to find what methods it calls first let cctor_token = assembly.types().module_cctor()?; - let cctor = assembly.method(&cctor_token)?; + let cctor = assembly.method(&cctor_token).ok()?; // Look at ALL call instructions in .cctor // ConfuserEx injects multiple calls: anti-tamper, constants, anti-debug, etc. @@ -792,7 +792,7 @@ pub fn find_constants_initializer(assembly: &CilObject) -> Option { // Now check each candidate to see if it matches Initialize() pattern for candidate in init_candidates { - let Some(method) = assembly.method(&candidate) else { + let Ok(method) = assembly.method(&candidate) else { continue; }; @@ -858,7 +858,7 @@ pub fn find_constants_initializer(assembly: &CilObject) -> Option { if instr.flow_type == FlowType::Call { if let Operand::Token(call_target) = &instr.operand { // Check if calling a method with "Decompress" in name - if let Some(callee) = assembly.method(call_target) { + if let Ok(callee) = assembly.method(call_target) { if callee.name.contains("Decompress") || callee.name.contains("LZMA") { return Some(method.token); } @@ -967,7 +967,7 @@ fn resolve_method_spec_to_def(assembly: &CilObject, token: Token) -> Option) -> Vec { let mut declaring_types: HashSet = HashSet::new(); for &decryptor in decryptors { - if let Some(method) = assembly.method(&decryptor) { + if let Ok(method) = assembly.method(&decryptor) { if let Some(parent) = method.declaring_type_rc() { declaring_types.insert(parent.token); } diff --git a/dotscope/src/deobfuscation/techniques/jiejienet/arrays.rs b/dotscope/src/deobfuscation/techniques/jiejienet/arrays.rs index 20dfd8aa..36135900 100644 --- a/dotscope/src/deobfuscation/techniques/jiejienet/arrays.rs +++ b/dotscope/src/deobfuscation/techniques/jiejienet/arrays.rs @@ -154,7 +154,7 @@ impl Technique for JiejieNetArrays { // Count ldtoken instructions in .cctor let handle_count = cctor_token - .and_then(|t| assembly.method(&t)) + .and_then(|t| assembly.method(&t).ok()) .map(|m| m.instructions().filter(|i| i.mnemonic == "ldtoken").count()) .unwrap_or(0); @@ -343,7 +343,7 @@ impl Technique for JiejieNetArrays { /// `ldtoken ` instructions. The order of `ldtoken` instructions corresponds /// to array indices 0, 1, 2, ... fn extract_cctor_field_tokens(assembly: &CilObject, cctor_token: Token) -> Vec { - let Some(method) = assembly.method(&cctor_token) else { + let Ok(method) = assembly.method(&cctor_token) else { return Vec::new(); }; @@ -413,7 +413,7 @@ fn find_my_initialize_array(assembly: &CilObject) -> (Option, Option HashMap { fn emulate_delta_chain_cctor(assembly: &CilObject, cctor_token: Token) -> HashMap { let mut values = HashMap::new(); - let Some(method) = assembly.method(&cctor_token) else { + let Ok(method) = assembly.method(&cctor_token) else { return values; }; diff --git a/dotscope/src/deobfuscation/techniques/jiejienet/resources.rs b/dotscope/src/deobfuscation/techniques/jiejienet/resources.rs index e959ed8f..d895b7b6 100644 --- a/dotscope/src/deobfuscation/techniques/jiejienet/resources.rs +++ b/dotscope/src/deobfuscation/techniques/jiejienet/resources.rs @@ -153,6 +153,7 @@ impl Technique for JiejieNetResources { .filter_map(|entry| { assembly .method(&entry.data_method_token) + .ok() .and_then(|method| { method .declaring_type @@ -376,7 +377,7 @@ fn find_original_bcl_call( assembly: &CilObject, interception_token: Token, ) -> Option { - let method = assembly.method(&interception_token)?; + let method = assembly.method(&interception_token).ok()?; let instructions: Vec<_> = method.instructions().collect(); // Look for callvirt instructions that call Assembly methods @@ -667,7 +668,7 @@ fn extract_resource_entries_ssa(ssa: &SsaFunction, assembly: &CilObject) -> Vec< _ => continue, }; - if let Some(called_method) = assembly.method(&method_token) { + if let Ok(called_method) = assembly.method(&method_token) { if matches!( called_method.signature.return_type.base, TypeSignature::SzArray(_) @@ -725,7 +726,7 @@ fn trace_to_string_const(ssa: &SsaFunction, var: SsaVarId, assembly: &CilObject) SsaOp::Const { value: ConstValue::DecryptedString(s), .. - } => Some(s.clone()), + } => Some(s.to_string()), SsaOp::Copy { src, .. } => trace_impl(ssa, *src, assembly, depth.saturating_add(1)), _ => None, } @@ -740,12 +741,7 @@ fn extract_and_decrypt_resource( entry: &ResourceEntry, xor_key: u8, ) -> Result> { - let method = assembly.method(&entry.data_method_token).ok_or_else(|| { - Error::Deobfuscation(format!( - "Data method 0x{:08X} not found", - entry.data_method_token.value() - )) - })?; + let method = assembly.method(&entry.data_method_token)?; let mut field_token: Option = None; let mut array_size: Option = None; diff --git a/dotscope/src/deobfuscation/techniques/jiejienet/typeofs.rs b/dotscope/src/deobfuscation/techniques/jiejienet/typeofs.rs index 73ba5945..b3e5346e 100644 --- a/dotscope/src/deobfuscation/techniques/jiejienet/typeofs.rs +++ b/dotscope/src/deobfuscation/techniques/jiejienet/typeofs.rs @@ -127,7 +127,7 @@ impl Technique for JiejieNetTypeOf { // Extract the GetTypeFromHandle MemberRef token from the accessor body. // The accessor's IL: ldsfld handles → ldarg.0 → ldelem → call GetTypeFromHandle → ret - let get_type_from_handle_token = assembly.method(&accessor).and_then(|method| { + let get_type_from_handle_token = assembly.method(&accessor).ok().and_then(|method| { method.instructions().find_map(|instr| { if instr.mnemonic == "call" { if let Operand::Token(token) = &instr.operand { @@ -144,7 +144,7 @@ impl Technique for JiejieNetTypeOf { // Count ldtoken instructions in .cctor to determine handle count let handle_count = cctor_token - .and_then(|t| assembly.method(&t)) + .and_then(|t| assembly.method(&t).ok()) .map(|m| m.instructions().filter(|i| i.mnemonic == "ldtoken").count()) .unwrap_or(0); @@ -238,7 +238,7 @@ impl Technique for JiejieNetTypeOf { /// `ldtoken ` instructions. The order of `ldtoken` instructions corresponds /// to array indices 0, 1, 2, ... fn extract_cctor_type_tokens(assembly: &CilObject, cctor_token: Token) -> Vec { - let Some(method) = assembly.method(&cctor_token) else { + let Ok(method) = assembly.method(&cctor_token) else { return Vec::new(); }; diff --git a/dotscope/src/deobfuscation/techniques/netreactor/antitamp.rs b/dotscope/src/deobfuscation/techniques/netreactor/antitamp.rs index 25983b89..98892c55 100644 --- a/dotscope/src/deobfuscation/techniques/netreactor/antitamp.rs +++ b/dotscope/src/deobfuscation/techniques/netreactor/antitamp.rs @@ -133,6 +133,7 @@ impl Technique for NetReactorAntiTamp { let init_method_token = fan_in.target_token; let runtime_type_token = assembly .method(&init_method_token) + .ok() .and_then(|m| m.declaring_type_rc()) .map(|t| t.token); diff --git a/dotscope/src/deobfuscation/techniques/netreactor/helpers.rs b/dotscope/src/deobfuscation/techniques/netreactor/helpers.rs index bd52b1e9..6bda4fd6 100644 --- a/dotscope/src/deobfuscation/techniques/netreactor/helpers.rs +++ b/dotscope/src/deobfuscation/techniques/netreactor/helpers.rs @@ -136,7 +136,7 @@ pub fn find_cctor_fan_in_target(assembly: &CilObject) -> Option Option Vec { continue; } - let Some(method) = assembly.method(method_token) else { + let Ok(method) = assembly.method(method_token) else { continue; }; @@ -468,7 +469,7 @@ pub fn find_nr_private_impl_containers(assembly: &CilObject) -> Vec`-trial NR-context gate it is highly specific to NR's /// license-verification stub. pub fn has_single_shot_bool_guard(assembly: &CilObject, method_token: Token) -> bool { - let Some(method) = assembly.method(&method_token) else { + let Ok(method) = assembly.method(&method_token) else { return false; }; @@ -614,7 +615,7 @@ enum AccessorKind { /// Body shape (4 instructions): /// `ldsflda ` → `ldarg.0` → `call instance ModuleHandle::Get*FromMetadataToken` → `ret`. fn classify_token_accessor_body(assembly: &CilObject, method_token: Token) -> Option { - let method = assembly.method(&method_token)?; + let method = assembly.method(&method_token).ok()?; let instrs: Vec<_> = method.instructions().collect(); if instrs.len() != 4 { return None; @@ -675,7 +676,7 @@ pub fn find_resources_referenced_by_methods( let mut results = Vec::new(); for &method_token in method_tokens { - let Some(method) = assembly.method(&method_token) else { + let Ok(method) = assembly.method(&method_token) else { continue; }; for instr in method.instructions() { @@ -727,7 +728,7 @@ pub fn classify_injected_cctors( let mut modified = Vec::new(); for &cctor_token in calling_cctors { - let Some(method) = assembly.method(&cctor_token) else { + let Ok(method) = assembly.method(&cctor_token) else { continue; }; diff --git a/dotscope/src/deobfuscation/techniques/netreactor/necrobit.rs b/dotscope/src/deobfuscation/techniques/netreactor/necrobit.rs index ecacce54..a3d60d8f 100644 --- a/dotscope/src/deobfuscation/techniques/netreactor/necrobit.rs +++ b/dotscope/src/deobfuscation/techniques/netreactor/necrobit.rs @@ -164,6 +164,7 @@ impl Technique for NetReactorNecroBit { let runtime_type_token = init_method_token.and_then(|init_token| { assembly .method(&init_token) + .ok() .and_then(|m| m.declaring_type_rc()) .map(|t| t.token) }); @@ -923,7 +924,7 @@ fn extract_bodies_from_image( let mut still_stubs = 0usize; for &token in stub_tokens { - let Some(method) = assembly.method(&token) else { + let Ok(method) = assembly.method(&token) else { continue; }; let Some(rva) = method.rva.filter(|&r| r > 0) else { diff --git a/dotscope/src/deobfuscation/techniques/netreactor/resources.rs b/dotscope/src/deobfuscation/techniques/netreactor/resources.rs index 5af304ec..a4cd0753 100644 --- a/dotscope/src/deobfuscation/techniques/netreactor/resources.rs +++ b/dotscope/src/deobfuscation/techniques/netreactor/resources.rs @@ -675,7 +675,7 @@ fn find_resolve_handler_registration( /// handler always lives on the resolver type itself; legitimate handlers /// almost never do. fn handler_lives_on_type(assembly: &CilObject, handler_token: Token, cil_type: &CilTypeRc) -> bool { - let Some(handler) = assembly.method(&handler_token) else { + let Ok(handler) = assembly.method(&handler_token) else { return false; }; let Some(declaring) = handler.declaring_type_rc() else { @@ -758,7 +758,7 @@ fn is_lazy_init_body(assembly: &CilObject, method: &MethodRc, cil_type: &CilType return false; }; // The newobj must target a .ctor on the resolver type. - let Some(ctor) = assembly.method(ctor_token) else { + let Ok(ctor) = assembly.method(ctor_token) else { return false; }; let Some(declaring) = ctor.declaring_type_rc() else { diff --git a/dotscope/src/deobfuscation/techniques/obfuscar/strings.rs b/dotscope/src/deobfuscation/techniques/obfuscar/strings.rs index 4dd6ee1b..2c70ba36 100644 --- a/dotscope/src/deobfuscation/techniques/obfuscar/strings.rs +++ b/dotscope/src/deobfuscation/techniques/obfuscar/strings.rs @@ -283,7 +283,7 @@ fn is_obfuscar_helper_type(namespace: &str) -> bool { /// Scans for the characteristic double-XOR pattern: `xor`, `ldc.i4 `, `xor`. /// The constant between the two `xor` instructions is the encryption key byte. fn extract_xor_key_from_cctor(assembly: &CilObject, cctor_token: Token) -> Option { - let method = assembly.method(&cctor_token)?; + let method = assembly.method(&cctor_token).ok()?; let instructions: Vec<_> = method.instructions().collect(); for window in instructions.windows(3) { diff --git a/dotscope/src/deobfuscation/utils.rs b/dotscope/src/deobfuscation/utils.rs index 1da9336b..e076eed9 100644 --- a/dotscope/src/deobfuscation/utils.rs +++ b/dotscope/src/deobfuscation/utils.rs @@ -496,7 +496,7 @@ pub(crate) fn exclude_cross_calling_candidates( candidates .iter() .filter(|token| { - let Some(method) = assembly.method(token) else { + let Ok(method) = assembly.method(token) else { return true; }; let calls_other = method.instructions().any(|instr| { @@ -694,6 +694,7 @@ pub(crate) fn is_method_on_type(assembly: &CilObject, token: Token, type_name: & match token.table() { 0x06 => assembly .method(&token) + .ok() .and_then(|m| m.declaring_type_rc()) .is_some_and(|ty| ty.name.contains(type_name)), 0x0A => assembly diff --git a/dotscope/src/emulation/engine/generics.rs b/dotscope/src/emulation/engine/generics.rs index 64c0cb50..afcfe9bb 100644 --- a/dotscope/src/emulation/engine/generics.rs +++ b/dotscope/src/emulation/engine/generics.rs @@ -47,7 +47,10 @@ use log::warn; use crate::{ emulation::tokens, - metadata::{tables::GenericParamAttributes, token::Token}, + metadata::{ + tables::{GenericParamAttributes, GenericParamVariance}, + token::Token, + }, }; /// Tracks generic type and method instantiations during emulation. @@ -269,39 +272,40 @@ where continue; }; - let variance = flags.bits() & GenericParamAttributes::VARIANCE_MASK.bits(); - - if variance == GenericParamAttributes::COVARIANT.bits() { - // Covariant (out): source must be assignable to target (derived → base) - if !is_assignable(*src, *tgt) { + match flags.variance() { + GenericParamVariance::Covariant => { + // Covariant (out): source must be assignable to target (derived → base) + if !is_assignable(*src, *tgt) { + warn!( + "Generic variance mismatch at position {i}: covariant parameter \ + requires 0x{:08X} assignable to 0x{:08X}", + src.value(), + tgt.value() + ); + return false; + } + } + GenericParamVariance::Contravariant => { + // Contravariant (in): target must be assignable to source (base → derived) + if !is_assignable(*tgt, *src) { + warn!( + "Generic variance mismatch at position {i}: contravariant parameter \ + requires 0x{:08X} assignable to 0x{:08X}", + tgt.value(), + src.value() + ); + return false; + } + } + GenericParamVariance::Invariant => { warn!( - "Generic variance mismatch at position {i}: covariant parameter \ - requires 0x{:08X} assignable to 0x{:08X}", + "Generic variance mismatch at position {i}: invariant parameter \ + requires exact match but got 0x{:08X} vs 0x{:08X}", src.value(), tgt.value() ); return false; } - } else if variance == GenericParamAttributes::CONTRAVARIANT.bits() { - // Contravariant (in): target must be assignable to source (base → derived) - if !is_assignable(*tgt, *src) { - warn!( - "Generic variance mismatch at position {i}: contravariant parameter \ - requires 0x{:08X} assignable to 0x{:08X}", - tgt.value(), - src.value() - ); - return false; - } - } else { - // Invariant: must match exactly - warn!( - "Generic variance mismatch at position {i}: invariant parameter \ - requires exact match but got 0x{:08X} vs 0x{:08X}", - src.value(), - tgt.value() - ); - return false; } } diff --git a/dotscope/src/emulation/runtime/bcl/reflection/members.rs b/dotscope/src/emulation/runtime/bcl/reflection/members.rs index 7a4d87dd..695d15f7 100644 --- a/dotscope/src/emulation/runtime/bcl/reflection/members.rs +++ b/dotscope/src/emulation/runtime/bcl/reflection/members.rs @@ -804,6 +804,7 @@ fn extract_member_custom_attrs( match thread.heap().get(href) { Ok(HeapObject::ReflectionMethod { method_token, .. }) => asm .method(&method_token) + .ok() .map(|m| m.custom_attributes.clone()), Ok(HeapObject::ReflectionField { field_token, diff --git a/dotscope/src/emulation/runtime/bcl/reflection/methods.rs b/dotscope/src/emulation/runtime/bcl/reflection/methods.rs index 1fe18866..82063cbf 100644 --- a/dotscope/src/emulation/runtime/bcl/reflection/methods.rs +++ b/dotscope/src/emulation/runtime/bcl/reflection/methods.rs @@ -256,7 +256,7 @@ fn method_invoke_pre(ctx: &HookContext<'_>, thread: &mut EmulationThread) -> Pre // Look up method signature to detect ByRef parameters let sig_params: Option> = thread .assembly() - .and_then(|asm| asm.method(&method_token)) + .and_then(|asm| asm.method(&method_token).ok()) .map(|m| m.signature.params.iter().map(|p| p.by_ref).collect()); // Extract the method arguments from the object[] array (second argument). @@ -689,7 +689,7 @@ fn method_get_method_body_pre( thread.heap().get(*method_ref) { if let Some(asm) = thread.assembly().cloned() { - if let Some(method) = asm.method(&method_token) { + if let Ok(method) = asm.method(&method_token) { if method.body.get().is_some() { match thread .heap_mut() diff --git a/dotscope/src/emulation/value/emvalue.rs b/dotscope/src/emulation/value/emvalue.rs index e7eda6ff..402ad251 100644 --- a/dotscope/src/emulation/value/emvalue.rs +++ b/dotscope/src/emulation/value/emvalue.rs @@ -774,6 +774,7 @@ impl From<&ConstValue> for EmValue { ConstValue::Null | ConstValue::DecryptedString(_) | ConstValue::DecryptedArray { .. } + | ConstValue::Vector(_) | ConstValue::Type(_) | ConstValue::MethodHandle(_) | ConstValue::FieldHandle(_) => EmValue::Null, diff --git a/dotscope/src/error.rs b/dotscope/src/error.rs index 87642776..0bf92017 100644 --- a/dotscope/src/error.rs +++ b/dotscope/src/error.rs @@ -60,8 +60,8 @@ //! Err(Error::NotSupported) => { //! eprintln!("File format is not supported"); //! } -//! Err(Error::Malformed { message, file, line }) => { -//! eprintln!("Malformed file: {} ({}:{})", message, file, line); +//! Err(Error::Parse(parse_err)) => { +//! eprintln!("Malformed file: {}", parse_err); //! } //! Err(Error::Io(io_err)) => { //! eprintln!("I/O error: {}", io_err); @@ -147,39 +147,39 @@ impl std::fmt::Display for EmulationError { macro_rules! malformed_error { // Single string version ($msg:expr) => { - $crate::Error::Malformed { - message: $msg.to_string(), - file: file!(), - line: line!(), - } + $crate::Error::Parse($crate::ParseFailure::Other { + stage: $crate::ParseStage::Generic, + message: format!("{} ({}:{})", $msg, file!(), line!()), + }) }; // Format string with arguments version ($fmt:expr, $($arg:tt)*) => { - $crate::Error::Malformed { - message: format!($fmt, $($arg)*), - file: file!(), - line: line!(), - } + $crate::Error::Parse($crate::ParseFailure::Other { + stage: $crate::ParseStage::Generic, + message: format!("{} ({}:{})", format!($fmt, $($arg)*), file!(), line!()), + }) }; } -/// Helper macro for creating out-of-bounds errors with source location information. +/// Helper macro for creating [`crate::ParseFailure::OutOfBounds`] errors. /// -/// This macro simplifies the creation of [`crate::Error::OutOfBounds`] errors by automatically -/// capturing the current file and line number where the out-of-bounds access was detected. +/// Convenience constructor for the structured out-of-bounds parse failure. +/// The expanded form is `Error::Parse(ParseFailure::OutOfBounds { stage: +/// ParseStage::Generic })` — call sites that can supply a more specific stage +/// should construct the variant directly instead of using this macro. /// -/// # Returns -/// -/// Returns a [`crate::Error::OutOfBounds`] variant with automatically captured source -/// location information for debugging purposes. +/// Source-location capture (`file!()`/`line!()`) was previously embedded in +/// the rendered message; it's no longer included because the structured +/// variant is intended to be matched on, not stringified for debugging. If +/// you need source-location info, use a structured `Error::Parse(...)` +/// expression and panic/`tracing::error!` with `Location::caller()` at the +/// call site. /// /// # Examples /// /// ```rust,ignore /// # use dotscope::out_of_bounds_error; -/// // Replace: Err(Error::OutOfBounds) -/// // With: Err(out_of_bounds_error!()) /// if index >= data.len() { /// return Err(out_of_bounds_error!()); /// } @@ -187,10 +187,9 @@ macro_rules! malformed_error { #[macro_export] macro_rules! out_of_bounds_error { () => { - $crate::Error::OutOfBounds { - file: file!(), - line: line!(), - } + $crate::Error::Parse($crate::ParseFailure::OutOfBounds { + stage: $crate::ParseStage::Generic, + }) }; } @@ -233,46 +232,6 @@ macro_rules! out_of_bounds_error { #[derive(Error, Debug)] pub enum Error { // File parsing Errors - /// The file is damaged and could not be parsed. - /// - /// This error indicates that the file structure is corrupted or doesn't - /// conform to the expected .NET PE format. The error includes the source - /// location where the malformation was detected for debugging purposes. - /// - /// # Fields - /// - /// * `message` - Detailed description of what was malformed - /// * `file` - Source file where the error was detected - /// * `line` - Source line where the error was detected - #[error("Malformed - {file}:{line}: {message}")] - Malformed { - /// The message to be printed for the Malformed error - message: String, - /// The source file in which this error occured - file: &'static str, - /// The source line in which this error occured - line: u32, - }, - - /// An out of bound access was attempted while parsing the file. - /// - /// This error occurs when trying to read data beyond the end of the file - /// or stream. It's a safety check to prevent buffer overruns during parsing. - /// The error includes the source location where the out-of-bounds access - /// was detected for debugging purposes. - /// - /// # Fields - /// - /// * `file` - Source file where the error was detected - /// * `line` - Source line where the error was detected - #[error("Out of Bounds - {file}:{line}")] - OutOfBounds { - /// The source file in which this error occurred - file: &'static str, - /// The source line in which this error occurred - line: u32, - }, - /// This file type is not supported. /// /// Indicates that the input file is not a supported .NET PE executable, @@ -312,6 +271,29 @@ pub enum Error { #[error("Failed to find type in TypeSystem - {0}")] TypeNotFound(Token), + /// Method or method-spec lookup failure. + /// + /// Returned by [`crate::CilObject::method`] and + /// [`crate::CilObject::method_spec`] when the supplied token does not + /// resolve to a row in the corresponding metadata table. See + /// [`MethodLookupError`] for the variant set. + #[error(transparent)] + LookupMethod(#[from] MethodLookupError), + + /// Structured parse-pipeline failure. + /// + /// Wraps a [`ParseFailure`] so consumers can categorize parse failures + /// (truncated headers, bad magic, unsupported schemas, heap corruption, + /// invalid fields) without parsing string messages. Returned by every + /// parse-pipeline error site in [`crate::file`], [`crate::metadata::root`], + /// and [`crate::metadata::streams`]. + /// + /// Match on `Error::Parse(_)` to recover the structured failure, or on a + /// specific variant of [`ParseFailure`] to react to a particular failure + /// class (e.g. `ParseFailure::Truncated { .. }`). + #[error(transparent)] + Parse(#[from] ParseFailure), + /// General error during `TypeSystem` usage. /// /// Covers various type system operations that can fail, such as @@ -430,18 +412,6 @@ pub enum Error { #[error("Cross-reference error: {0}")] CrossReferenceError(String), - /// Heap bounds validation failed. - /// - /// This error occurs when metadata heap indices are out of bounds - /// for the target heap. - #[error("Heap bounds error: {heap} index {index}")] - HeapBoundsError { - /// The type of heap (strings, blobs, etc.) - heap: String, - /// The out-of-bounds index - index: u32, - }, - /// Conflict resolution failed. /// /// This error occurs when the conflict resolution system cannot @@ -648,6 +618,321 @@ pub enum Error { TracingError(String), } +/// Failure modes for method-by-token lookups. +/// +/// Returned by [`crate::CilObject::method`] and +/// [`crate::CilObject::method_spec`]. Propagates into [`Error`] via the +/// [`Error::LookupMethod`] variant — call sites that already use +/// `Result<_, Error>` can propagate with `?` without manual conversion. +/// +/// # Stability +/// +/// `#[non_exhaustive]` so additional failure modes (e.g. partial resolution, +/// stale weak references) can be added without a breaking change. Consumers +/// must include a wildcard arm when matching. +#[derive(Debug, Clone, PartialEq, Eq, Error)] +#[non_exhaustive] +pub enum MethodLookupError { + /// The supplied token does not appear in the [`MethodDef`] table. + /// + /// This is distinct from "method exists but has no body" — for an abstract + /// or P/Invoke method the call returns `Ok(method)` and the caller can + /// inspect [`crate::metadata::method::Method::rva_kind`] to see why no IL + /// is present. + /// + /// [`MethodDef`]: crate::metadata::tables::MethodDef + #[error("MethodDef token {0} not found")] + NotFound(Token), + + /// The supplied token does not appear in the [`MethodSpec`] table. + /// + /// [`MethodSpec`]: crate::metadata::tables::MethodSpec + #[error("MethodSpec token {0} not found")] + SpecNotFound(Token), +} + +/// Pipeline stage at which a parse error originated. +/// +/// Returned as part of [`ParseFailure`]. Lets consumers act on the error +/// category (e.g. retry an upgrade only for `Cor20Header`/`MetadataRoot` +/// failures) without parsing string messages. +/// +/// `#[non_exhaustive]` — additional stages may be added as the parser +/// surface grows. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[non_exhaustive] +pub enum ParseStage { + /// DOS / MZ header at the very start of the PE file. + DosHeader, + /// PE signature ("PE\0\0") immediately following the DOS stub. + PeSignature, + /// COFF file header following the PE signature. + CoffHeader, + /// Optional header (PE32 / PE32+). + OptionalHeader, + /// Section table immediately after the optional header. + SectionTable, + /// CLI / COR20 runtime header located via the data-directory entry. + Cor20Header, + /// Metadata root header (signature, versions, stream count). + MetadataRoot, + /// Per-stream header inside the metadata root. + StreamHeader, + /// Tilde (`#~`) stream containing metadata table rows. + TildeStream, + /// Per-table row layout/decoding within the tilde stream. + TableRow, + /// One of the heap streams (`#Strings`, `#US`, `#Blob`, `#GUID`). + Heap, + /// Generic data-directory entry traversal. + DataDirectory, + /// Resource directory (.NET embedded resources). + Resources, + /// VTableFixup directory. + VTableFixup, + /// Strong-name signature blob. + StrongName, + /// Signature parsing within blobs (method/field/etc.). + Signature, + /// Method body header / code / EH-clause parsing. + MethodBody, + /// Custom-attribute decoding from the blob heap. + CustomAttribute, + /// DeclSecurity permission-set decoding. + PermissionSet, + /// CIL instruction decoding (disassembly). + InstructionDecoder, + /// CIL instruction encoding (assembler/serializer write paths). + InstructionEncoder, + /// Assembly / metadata writer paths (`crate::cilassembly::writer`). + AssemblyWriter, + /// Imports/exports table decoding. + ImportsExports, + /// Type system construction post-parse (`crate::metadata::tables`, + /// owned-type construction). + TypeSystem, + /// Validation passes (`crate::metadata::validation`). + Validation, + /// Emulation/runtime loader (`crate::emulation::loader`). + EmulationLoader, + /// Generic byte-parser primitives — used by helpers that have no + /// inherent stage (e.g. `crate::file::parser::Parser`). + Generic, +} + +impl std::fmt::Display for ParseStage { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + let s = match self { + ParseStage::DosHeader => "DOS header", + ParseStage::PeSignature => "PE signature", + ParseStage::CoffHeader => "COFF header", + ParseStage::OptionalHeader => "PE optional header", + ParseStage::SectionTable => "PE section table", + ParseStage::Cor20Header => "COR20 header", + ParseStage::MetadataRoot => "metadata root", + ParseStage::StreamHeader => "stream header", + ParseStage::TildeStream => "tilde (#~) stream", + ParseStage::TableRow => "metadata table row", + ParseStage::Heap => "heap", + ParseStage::DataDirectory => "data directory", + ParseStage::Resources => "resources", + ParseStage::VTableFixup => "VTable fixup", + ParseStage::StrongName => "strong-name signature", + ParseStage::Signature => "signature blob", + ParseStage::MethodBody => "method body", + ParseStage::CustomAttribute => "custom attribute", + ParseStage::PermissionSet => "permission set", + ParseStage::InstructionDecoder => "CIL decoder", + ParseStage::InstructionEncoder => "CIL encoder", + ParseStage::AssemblyWriter => "assembly writer", + ParseStage::ImportsExports => "imports/exports", + ParseStage::TypeSystem => "type system", + ParseStage::Validation => "validation", + ParseStage::EmulationLoader => "emulation loader", + ParseStage::Generic => "generic byte parser", + }; + f.write_str(s) + } +} + +/// Heap stream kind. +/// +/// Carried by [`ParseFailure::HeapOutOfBounds`] / +/// [`ParseFailure::HeapCorrupt`] so consumers can pinpoint which heap +/// failed without parsing a message string. `#[non_exhaustive]`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[non_exhaustive] +pub enum HeapKind { + /// `#Strings` heap — null-terminated UTF-8 identifier names. + Strings, + /// `#US` heap — length-prefixed UTF-16 user strings. + UserStrings, + /// `#Blob` heap — length-prefixed binary blobs. + Blob, + /// `#GUID` heap — packed 16-byte GUIDs. + Guid, +} + +impl std::fmt::Display for HeapKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(match self { + HeapKind::Strings => "#Strings", + HeapKind::UserStrings => "#US", + HeapKind::Blob => "#Blob", + HeapKind::Guid => "#GUID", + }) + } +} + +/// Stream kind inside the metadata root. +/// +/// Carried by [`ParseFailure::TruncatedStream`] and similar variants. +/// `#[non_exhaustive]`. +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +#[non_exhaustive] +pub enum StreamKind { + /// Compressed metadata-table stream (`#~`). + Tilde, + /// Uncompressed metadata-table stream (`#-`). + TildeUncompressed, + /// `#Strings` heap. + Strings, + /// `#US` heap. + UserStrings, + /// `#Blob` heap. + Blob, + /// `#GUID` heap. + Guid, + /// Portable PDB (`#Pdb`) stream. + Pdb, +} + +impl std::fmt::Display for StreamKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(match self { + StreamKind::Tilde => "#~", + StreamKind::TildeUncompressed => "#-", + StreamKind::Strings => "#Strings", + StreamKind::UserStrings => "#US", + StreamKind::Blob => "#Blob", + StreamKind::Guid => "#GUID", + StreamKind::Pdb => "#Pdb", + }) + } +} + +/// Structured parse-pipeline failure. +/// +/// Reported through [`Error::Parse`] for every error site in +/// [`crate::file`], [`crate::metadata::root`], and +/// [`crate::metadata::streams`], plus per-table parse paths that read raw +/// PE/metadata bytes. Replaces the stringly-typed [`Error::Malformed`] / +/// [`Error::OutOfBounds`] / [`Error::HeapBoundsError`] variants for parse +/// sites — those remain valid for non-parse code (validation, lookups, +/// emulation), but new parse code must use [`ParseFailure`]. +/// +/// # Stability +/// +/// `#[non_exhaustive]` — additional well-known failure classes can be added +/// without a breaking change. Consumers must include a wildcard arm when +/// matching. +#[derive(Debug, Clone, PartialEq, Eq, Error)] +#[non_exhaustive] +pub enum ParseFailure { + /// Header / structure bytes were not present at the expected offset. + /// + /// `expected` and `found` are byte counts when the gap is over a known + /// header. For partial reads where exact byte counts are unknown, both + /// fields may be 0. + #[error("truncated {stage} (expected {expected} bytes, found {found})")] + Truncated { + /// Pipeline stage in which truncation was detected. + stage: ParseStage, + /// Bytes the parser needed. + expected: usize, + /// Bytes that were actually available. + found: usize, + }, + + /// A magic number or signature did not match the expected value. + #[error("bad {stage} magic: expected 0x{expected:08X}, found 0x{found:08X}")] + BadMagic { + /// Pipeline stage that performed the magic check. + stage: ParseStage, + /// Expected magic value. + expected: u32, + /// Magic value actually read. + found: u32, + }, + + /// A header field was outside the acceptable range or otherwise invalid. + #[error("invalid {stage} field `{field}`: {reason}")] + InvalidField { + /// Pipeline stage in which the field appears. + stage: ParseStage, + /// Field name (static — derived from struct definition). + field: &'static str, + /// Human-readable description of why the value was rejected. + reason: String, + }, + + /// A schema version is recognized but not supported by this parser. + #[error("unsupported {stage} schema version `{version}`")] + UnsupportedSchema { + /// Pipeline stage that surfaced the schema mismatch. + stage: ParseStage, + /// Schema version string (e.g. ECMA-335 §II.24.2.1 form). + version: String, + }, + + /// A read crossed the end of the parse buffer. + #[error("read past end of buffer in {stage}")] + OutOfBounds { + /// Pipeline stage in which the OOB read was attempted. + stage: ParseStage, + }, + + /// A stream's declared offset/size exceeds the surrounding buffer. + #[error("truncated stream `{stream}` at offset {offset}")] + TruncatedStream { + /// Stream that could not be fully read. + stream: StreamKind, + /// File offset at which the read ran out of bytes. + offset: u32, + }, + + /// A heap reference was outside the bounds of its stream. + #[error("heap `{heap}` index {index} out of bounds")] + HeapOutOfBounds { + /// Heap whose bounds were violated. + heap: HeapKind, + /// Index that exceeded the heap. + index: u32, + }, + + /// A heap's byte content did not satisfy its format invariant. + #[error("heap `{heap}` corrupt: {reason}")] + HeapCorrupt { + /// Heap whose content was rejected. + heap: HeapKind, + /// Human-readable description of the corruption. + reason: String, + }, + + /// Free-form parse failure used as a migration fallback for sites whose + /// failure shape does not yet fit a structured variant. + /// + /// New code should prefer one of the structured variants; this one is + /// retained so partial migrations do not block on unusual call sites. + #[error("{stage}: {message}")] + Other { + /// Pipeline stage in which the error occurred. + stage: ParseStage, + /// Free-form description of the failure. + message: String, + }, +} + impl Clone for Error { fn clone(&self) -> Self { match self { @@ -675,6 +960,10 @@ impl Clone for Error { Error::X86Error(s) => Error::X86Error(s.clone()), // Tracing errors are cloneable Error::TracingError(s) => Error::TracingError(s.clone()), + // Method-lookup errors are pure data and Clone-derived. + Error::LookupMethod(e) => Error::LookupMethod(e.clone()), + // Parse failures are pure data and Clone-derived. + Error::Parse(e) => Error::Parse(e.clone()), // For all other variants, convert to their string representation and use Other other => Error::Other(other.to_string()), } @@ -685,7 +974,9 @@ impl From for Error { fn from(err: cowfile::Error) -> Self { match err { cowfile::Error::Io(io_err) => Error::Io(io_err), - cowfile::Error::OutOfBounds { .. } => Error::Other(err.to_string()), + cowfile::Error::OutOfBounds { .. } => Error::Parse(ParseFailure::OutOfBounds { + stage: ParseStage::Generic, + }), cowfile::Error::LockPoisoned(msg) => Error::LockError(msg), } } diff --git a/dotscope/src/file/mod.rs b/dotscope/src/file/mod.rs index e98df9f3..5ac476a3 100644 --- a/dotscope/src/file/mod.rs +++ b/dotscope/src/file/mod.rs @@ -133,7 +133,7 @@ use crate::{ }, utils::align_to, Error::{self, Goblin, LayoutFailed, Other}, - Result, + ParseFailure, ParseStage, Result, }; use goblin::pe::PE; use pe::{DataDirectory, DataDirectoryType, Pe}; @@ -878,8 +878,11 @@ impl File { /// ``` pub fn data_slice(&self, offset: usize, len: usize) -> Result<&[u8]> { let base = self.data.data(); - let end = offset.checked_add(len).ok_or(out_of_bounds_error!())?; - base.get(offset..end).ok_or(out_of_bounds_error!()) + let oob = || ParseFailure::OutOfBounds { + stage: ParseStage::DataDirectory, + }; + let end = offset.checked_add(len).ok_or_else(|| Error::from(oob()))?; + base.get(offset..end).ok_or_else(|| Error::from(oob())) } /// Converts a virtual address (VA) to a file offset. @@ -915,9 +918,16 @@ impl File { /// ``` pub fn va_to_offset(&self, va: usize) -> Result { let ib = self.imagebase(); - let rva_u64 = (va as u64).checked_sub(ib).ok_or(out_of_bounds_error!())?; - let rva = usize::try_from(rva_u64) - .map_err(|_| malformed_error!("RVA too large to fit in usize: {}", rva_u64))?; + let rva_u64 = (va as u64).checked_sub(ib).ok_or_else(|| { + Error::from(ParseFailure::OutOfBounds { + stage: ParseStage::DataDirectory, + }) + })?; + let rva = usize::try_from(rva_u64).map_err(|_| ParseFailure::InvalidField { + stage: ParseStage::DataDirectory, + field: "rva", + reason: format!("RVA too large to fit in usize: {rva_u64}"), + })?; self.rva_to_offset(rva) } @@ -954,32 +964,41 @@ impl File { /// # Ok::<(), dotscope::Error>(()) /// ``` pub fn rva_to_offset(&self, rva: usize) -> Result { + let invalid = |field: &'static str, reason: String| ParseFailure::InvalidField { + stage: ParseStage::SectionTable, + field, + reason, + }; for section in &self.pe.sections { let Some(section_max) = section.virtual_address.checked_add(section.virtual_size) else { - return Err(malformed_error!( - "Section malformed, causing integer overflow - {} + {}", - section.virtual_address, - section.virtual_size - )); + return Err(invalid( + "section_extent", + format!( + "section malformed, causing integer overflow - {} + {}", + section.virtual_address, section.virtual_size + ), + ) + .into()); }; let rva_u32 = u32::try_from(rva) - .map_err(|_| malformed_error!("RVA too large to fit in u32: {}", rva))?; + .map_err(|_| invalid("rva", format!("RVA too large to fit in u32: {rva}")))?; if section.virtual_address <= rva_u32 && section_max > rva_u32 { let delta = rva .checked_sub(section.virtual_address as usize) - .ok_or_else(|| malformed_error!("RVA underflow vs section base"))?; + .ok_or_else(|| invalid("rva", "RVA underflow vs section base".into()))?; return delta .checked_add(section.pointer_to_raw_data as usize) - .ok_or_else(|| malformed_error!("RVA-to-offset overflow")); + .ok_or_else(|| Error::from(invalid("rva", "RVA-to-offset overflow".into()))); } } - Err(malformed_error!( - "RVA could not be converted to offset - {}", - rva - )) + Err(invalid( + "rva", + format!("RVA could not be converted to offset - {rva}"), + ) + .into()) } /// Converts a file offset to a relative virtual address (RVA). @@ -1014,34 +1033,49 @@ impl File { /// # Ok::<(), dotscope::Error>(()) /// ``` pub fn offset_to_rva(&self, offset: usize) -> Result { + let invalid = |field: &'static str, reason: String| ParseFailure::InvalidField { + stage: ParseStage::SectionTable, + field, + reason, + }; for section in &self.pe.sections { let Some(section_max) = section .pointer_to_raw_data .checked_add(section.size_of_raw_data) else { - return Err(malformed_error!( - "Section malformed, causing integer overflow - {} + {}", - section.pointer_to_raw_data, - section.size_of_raw_data - )); + return Err(invalid( + "section_extent", + format!( + "section malformed, causing integer overflow - {} + {}", + section.pointer_to_raw_data, section.size_of_raw_data + ), + ) + .into()); }; - let offset_u32 = u32::try_from(offset) - .map_err(|_| malformed_error!("Offset too large to fit in u32: {}", offset))?; + let offset_u32 = u32::try_from(offset).map_err(|_| { + invalid( + "offset", + format!("offset too large to fit in u32: {offset}"), + ) + })?; if section.pointer_to_raw_data <= offset_u32 && section_max > offset_u32 { let delta = offset .checked_sub(section.pointer_to_raw_data as usize) - .ok_or_else(|| malformed_error!("Offset underflow vs section base"))?; + .ok_or_else(|| invalid("offset", "offset underflow vs section base".into()))?; return delta .checked_add(section.virtual_address as usize) - .ok_or_else(|| malformed_error!("Offset-to-RVA overflow")); + .ok_or_else(|| { + Error::from(invalid("offset", "offset-to-RVA overflow".into())) + }); } } - Err(malformed_error!( - "Offset could not be converted to RVA - {}", - offset - )) + Err(invalid( + "offset", + format!("offset could not be converted to RVA - {offset}"), + ) + .into()) } /// Determines if a section contains .NET metadata by checking the actual metadata RVA. diff --git a/dotscope/src/file/parser.rs b/dotscope/src/file/parser.rs index e31949af..04fbf643 100644 --- a/dotscope/src/file/parser.rs +++ b/dotscope/src/file/parser.rs @@ -97,9 +97,30 @@ use crate::{ metadata::token::Token, utils::{read_be_at, read_le_at, CilIO}, - Result, + Error, ParseFailure, ParseStage, Result, }; +/// File-local helpers for converting parser failures into the crate's +/// [`Error`] without per-site verbosity. The parser is a generic +/// building block, so all errors carry [`ParseStage::Generic`]. +#[inline] +fn oob_err() -> Error { + ParseFailure::OutOfBounds { + stage: ParseStage::Generic, + } + .into() +} + +#[inline] +fn invalid_err(field: &'static str, reason: String) -> Error { + ParseFailure::InvalidField { + stage: ParseStage::Generic, + field, + reason, + } + .into() +} + /// A generic binary data parser for reading .NET metadata structures. /// /// `Parser` provides a cursor-based interface for reading binary data in both @@ -249,7 +270,7 @@ impl<'a> Parser<'a> { /// ``` pub fn seek(&mut self, pos: usize) -> Result<()> { if pos >= self.data.len() { - return Err(out_of_bounds_error!()); + return Err(oob_err()); } self.position = pos; @@ -298,12 +319,9 @@ impl<'a> Parser<'a> { /// # Ok::<(), dotscope::Error>(()) /// ``` pub fn advance_by(&mut self, step: usize) -> Result<()> { - let new_pos = self - .position - .checked_add(step) - .ok_or(out_of_bounds_error!())?; + let new_pos = self.position.checked_add(step).ok_or(oob_err())?; if new_pos > self.data.len() { - return Err(out_of_bounds_error!()); + return Err(oob_err()); } self.position = new_pos; Ok(()) @@ -363,10 +381,7 @@ impl<'a> Parser<'a> { /// # Ok::<(), dotscope::Error>(()) /// ``` pub fn peek_byte(&self) -> Result { - self.data - .get(self.position) - .copied() - .ok_or(out_of_bounds_error!()) + self.data.get(self.position).copied().ok_or(oob_err()) } /// Peek at a value of type `T` in little-endian format without advancing the position. @@ -475,20 +490,14 @@ impl<'a> Parser<'a> { /// # Ok::<(), dotscope::Error>(()) /// ``` pub fn align(&mut self, alignment: usize) -> Result<()> { - let rem = self - .position - .checked_rem(alignment) - .ok_or(out_of_bounds_error!())?; + let rem = self.position.checked_rem(alignment).ok_or(oob_err())?; let padding = alignment .wrapping_sub(rem) .checked_rem(alignment) - .ok_or(out_of_bounds_error!())?; - let new_pos = self - .position - .checked_add(padding) - .ok_or(out_of_bounds_error!())?; + .ok_or(oob_err())?; + let new_pos = self.position.checked_add(padding).ok_or(oob_err())?; if new_pos > self.data.len() { - return Err(out_of_bounds_error!()); + return Err(oob_err()); } self.position = new_pos; Ok(()) @@ -587,7 +596,10 @@ impl<'a> Parser<'a> { return Ok(value); } - Err(malformed_error!("Invalid compressed uint - {}", first_byte)) + Err(invalid_err( + "compressed_uint", + format!("invalid compressed uint - {first_byte}"), + )) } /// Read a compressed signed integer as defined in ECMA-335 II.23.2. @@ -674,19 +686,18 @@ impl<'a> Parser<'a> { 0x1 => 0x0100_0000, // TypeRef 0x2 => 0x1B00_0000, // TypeSpec _ => { - return Err(malformed_error!( - "Invalid compressed token - {}", - compressed_token + return Err(invalid_err( + "compressed_token", + format!("invalid compressed token - {compressed_token}"), )) } }; let table_index = compressed_token >> 2; let token = table.checked_add(table_index).ok_or_else(|| { - malformed_error!( - "Compressed token index overflows table base: {} + {}", - table, - table_index + invalid_err( + "compressed_token", + format!("token index overflows table base: {table} + {table_index}"), ) })?; @@ -724,13 +735,13 @@ impl<'a> Parser<'a> { let mut shift: u32 = 0; loop { - let byte = *self.data.get(self.position).ok_or(out_of_bounds_error!())?; - self.position = self.position.checked_add(1).ok_or(out_of_bounds_error!())?; + let byte = *self.data.get(self.position).ok_or(oob_err())?; + self.position = self.position.checked_add(1).ok_or(oob_err())?; value |= u32::from(byte & 0x7F) << shift; shift = shift .checked_add(7) - .ok_or_else(|| malformed_error!("7-bit encoded integer overflow"))?; + .ok_or_else(|| invalid_err("varint", "7-bit encoded integer overflow".into()))?; if (byte & 0x80) == 0 { break; @@ -739,9 +750,11 @@ impl<'a> Parser<'a> { // A u32 can hold at most 32 bits; after 4 bytes we've read 28 bits. // A 5th continuation byte would push past 32 bits, causing overflow. if shift >= 32 { - return Err(malformed_error!( - "7-bit encoded integer overflow: value exceeds u32 capacity after {} bits", - shift + return Err(invalid_err( + "varint", + format!( + "7-bit encoded integer overflow: value exceeds u32 capacity after {shift} bits", + ), )); } } @@ -781,26 +794,27 @@ impl<'a> Parser<'a> { if b == 0 { break; } - end = end.checked_add(1).ok_or(out_of_bounds_error!())?; + end = end.checked_add(1).ok_or(oob_err())?; } // Handle two cases: // 1. Found null terminator (end < data.len()): normal null-terminated string // 2. Reached end of data (end == data.len()): string without null terminator (valid case) - let string_data = self.data.get(start..end).ok_or(out_of_bounds_error!())?; + let string_data = self.data.get(start..end).ok_or(oob_err())?; if end < self.data.len() { - self.position = end.checked_add(1).ok_or(out_of_bounds_error!())?; + self.position = end.checked_add(1).ok_or(oob_err())?; } else { self.position = end; } String::from_utf8(string_data.to_vec()).map_err(|e| { - malformed_error!( - "Invalid UTF-8 string at offset {}-{}: {}", - start, - end, - e.utf8_error() + invalid_err( + "utf8_string", + format!( + "invalid UTF-8 string at offset {start}-{end}: {}", + e.utf8_error() + ), ) }) } @@ -829,21 +843,19 @@ impl<'a> Parser<'a> { /// ``` pub fn read_prefixed_string_utf8(&mut self) -> Result { let length = self.read_7bit_encoded_int()? as usize; - let end = self - .position - .checked_add(length) - .ok_or(out_of_bounds_error!())?; + let end = self.position.checked_add(length).ok_or(oob_err())?; let start = self.position; - let string_data = self.data.get(start..end).ok_or(out_of_bounds_error!())?; + let string_data = self.data.get(start..end).ok_or(oob_err())?; self.position = end; String::from_utf8(string_data.to_vec()).map_err(|e| { - malformed_error!( - "Invalid UTF-8 string at offset {}-{}: {}", - start, - end, - e.utf8_error() + invalid_err( + "utf8_string", + format!( + "invalid UTF-8 string at offset {start}-{end}: {}", + e.utf8_error() + ), ) }) } @@ -885,17 +897,15 @@ impl<'a> Parser<'a> { pub fn read_prefixed_string_utf8_ref(&mut self) -> Result<&'a str> { let length = self.read_7bit_encoded_int()? as usize; let start = self.position; - let end = start.checked_add(length).ok_or(out_of_bounds_error!())?; + let end = start.checked_add(length).ok_or(oob_err())?; - let string_data = self.data.get(start..end).ok_or(out_of_bounds_error!())?; + let string_data = self.data.get(start..end).ok_or(oob_err())?; self.position = end; std::str::from_utf8(string_data).map_err(|_| { - malformed_error!( - "Invalid UTF-8 string at position {} - {} - {:?}", - start, - end, - string_data + invalid_err( + "utf8_string", + format!("invalid UTF-8 string at position {start} - {end} - {string_data:?}"), ) }) } @@ -926,17 +936,18 @@ impl<'a> Parser<'a> { pub fn read_compressed_string_utf8(&mut self) -> Result { let length = self.read_compressed_uint()? as usize; let start = self.position; - let end = start.checked_add(length).ok_or(out_of_bounds_error!())?; + let end = start.checked_add(length).ok_or(oob_err())?; - let string_data = self.data.get(start..end).ok_or(out_of_bounds_error!())?; + let string_data = self.data.get(start..end).ok_or(oob_err())?; self.position = end; String::from_utf8(string_data.to_vec()).map_err(|e| { - malformed_error!( - "Invalid UTF-8 compressed string at offset {}-{}: {}", - start, - end, - e.utf8_error() + invalid_err( + "utf8_compressed_string", + format!( + "invalid UTF-8 compressed string at offset {start}-{end}: {}", + e.utf8_error() + ), ) }) } @@ -990,7 +1001,7 @@ impl<'a> Parser<'a> { /// ``` pub fn ensure_remaining(&self, needed: usize) -> Result<()> { if self.remaining() < needed { - return Err(out_of_bounds_error!()); + return Err(oob_err()); } Ok(()) } @@ -1023,13 +1034,10 @@ impl<'a> Parser<'a> { /// # Ok::<(), dotscope::Error>(()) /// ``` pub fn calc_end_position(&self, length: usize) -> Result { - let end = self - .position - .checked_add(length) - .ok_or(out_of_bounds_error!())?; + let end = self.position.checked_add(length).ok_or(oob_err())?; if end > self.data.len() { - return Err(out_of_bounds_error!()); + return Err(oob_err()); } Ok(end) @@ -1060,10 +1068,7 @@ impl<'a> Parser<'a> { /// ``` pub fn read_bytes(&mut self, length: usize) -> Result<&'a [u8]> { let end = self.calc_end_position(length)?; - let bytes = self - .data - .get(self.position..end) - .ok_or(out_of_bounds_error!())?; + let bytes = self.data.get(self.position..end).ok_or(oob_err())?; self.position = end; Ok(bytes) } @@ -1093,16 +1098,16 @@ impl<'a> Parser<'a> { /// ``` pub fn read_prefixed_string_utf16(&mut self) -> Result { let length = self.read_7bit_encoded_int()? as usize; - let end = self - .position - .checked_add(length) - .ok_or(out_of_bounds_error!())?; + let end = self.position.checked_add(length).ok_or(oob_err())?; if end > self.data.len() { - return Err(out_of_bounds_error!()); + return Err(oob_err()); } if !length.is_multiple_of(2) || length < 2 { - return Err(malformed_error!("Invalid UTF-16 length - {}", length)); + return Err(invalid_err( + "utf16_length", + format!("invalid UTF-16 length - {length}"), + )); } let char_count = length / 2; @@ -1114,11 +1119,12 @@ impl<'a> Parser<'a> { match String::from_utf16(&utf16_chars) { Ok(s) => Ok(s), - Err(_) => Err(malformed_error!( - "Invalid UTF-16 str - {} - {} - {:?}", - self.position, - length, - utf16_chars + Err(_) => Err(invalid_err( + "utf16_string", + format!( + "invalid UTF-16 string at position {} (length {}): {utf16_chars:?}", + self.position, length + ), )), } } @@ -1127,7 +1133,7 @@ impl<'a> Parser<'a> { #[cfg(test)] mod tests { use super::*; - use crate::Error; + use crate::{Error, ParseFailure}; #[test] fn test_read_compressed_uint() { @@ -1150,7 +1156,7 @@ mod tests { let mut parser = Parser::new(&[]); assert!(matches!( parser.read_compressed_uint(), - Err(Error::OutOfBounds { .. }) + Err(Error::Parse(ParseFailure::OutOfBounds { .. })) )); } @@ -1191,7 +1197,7 @@ mod tests { assert!(matches!(parser.read_compressed_uint(), Ok(8))); assert!(matches!( parser.read_compressed_uint(), - Err(Error::OutOfBounds { .. }) + Err(Error::Parse(ParseFailure::OutOfBounds { .. })) )); } @@ -1334,7 +1340,7 @@ mod tests { let mut parser = Parser::new(&data); assert!(matches!( parser.read_prefixed_string_utf8_ref(), - Err(Error::OutOfBounds { .. }) + Err(Error::Parse(ParseFailure::OutOfBounds { .. })) )); } @@ -1368,13 +1374,19 @@ mod tests { let data = [0x01]; let parser = Parser::new(&data); let result: Result = parser.peek_le(); - assert!(matches!(result, Err(Error::OutOfBounds { .. }))); + assert!(matches!( + result, + Err(Error::Parse(ParseFailure::OutOfBounds { .. })) + )); // Test peek_le at end of data let mut parser = Parser::new(&data); parser.advance().unwrap(); let result: Result = parser.peek_le(); - assert!(matches!(result, Err(Error::OutOfBounds { .. }))); + assert!(matches!( + result, + Err(Error::Parse(ParseFailure::OutOfBounds { .. })) + )); } #[test] diff --git a/dotscope/src/file/pe.rs b/dotscope/src/file/pe.rs index dc485ff2..2ee870ba 100644 --- a/dotscope/src/file/pe.rs +++ b/dotscope/src/file/pe.rs @@ -40,8 +40,18 @@ use crate::{ utils::{read_le_at, write_le_at}, - Error, Result, + Error, ParseFailure, ParseStage, Result, }; + +#[inline] +fn pe_invalid(field: &'static str, reason: String) -> Error { + ParseFailure::InvalidField { + stage: ParseStage::OptionalHeader, + field, + reason, + } + .into() +} use std::collections::HashMap; use std::fmt; use std::io::Write; @@ -550,11 +560,17 @@ impl Pe { .transpose()?; if optional_header.is_none() { - return Err(malformed_error!("File does not have an OptionalHeader")); + return Err(pe_invalid( + "optional_header", + "PE file does not have an OptionalHeader".into(), + )); } if goblin_pe.image_base == 0 { - return Err(malformed_error!("PE has invalid zero image base")); + return Err(pe_invalid( + "image_base", + "PE has invalid zero image base".into(), + )); } let sections = goblin_pe @@ -812,8 +828,9 @@ impl Pe { ); Ok(()) } else { - Err(malformed_error!( - "Cannot update CLR data directory: PE has no optional header" + Err(pe_invalid( + "optional_header", + "cannot update CLR data directory: PE has no optional header".into(), )) } } @@ -850,8 +867,9 @@ impl Pe { ); Ok(()) } else { - Err(malformed_error!( - "Cannot update data directory: PE has no optional header" + Err(pe_invalid( + "optional_header", + "cannot update data directory: PE has no optional header".into(), )) } } @@ -873,8 +891,9 @@ impl Pe { optional_header.windows_fields.size_of_image = new_size; Ok(()) } else { - Err(malformed_error!( - "Cannot update SizeOfImage: PE has no optional header" + Err(pe_invalid( + "optional_header", + "cannot update SizeOfImage: PE has no optional header".into(), )) } } @@ -896,8 +915,9 @@ impl Pe { optional_header.windows_fields.size_of_headers = new_size; Ok(()) } else { - Err(malformed_error!( - "Cannot update SizeOfHeaders: PE has no optional header" + Err(pe_invalid( + "optional_header", + "cannot update SizeOfHeaders: PE has no optional header".into(), )) } } @@ -1198,10 +1218,12 @@ impl OptionalHeader { 0x10b => false, // PE32 0x20b => true, // PE32+ magic => { - return Err(malformed_error!( - "Invalid PE optional header magic: 0x{:x} (expected 0x10b or 0x20b)", - magic - )) + return Err(ParseFailure::BadMagic { + stage: ParseStage::OptionalHeader, + expected: 0x10b, + found: u32::from(magic), + } + .into()) } }; @@ -1259,14 +1281,14 @@ impl StandardFields { major_linker_version: goblin_sf.major_linker_version, minor_linker_version: goblin_sf.minor_linker_version, size_of_code: u32::try_from(goblin_sf.size_of_code) - .map_err(|_| malformed_error!("PE size_of_code value too large"))?, + .map_err(|_| pe_invalid("size_of_code", "value too large".into()))?, size_of_initialized_data: u32::try_from(goblin_sf.size_of_initialized_data) - .map_err(|_| malformed_error!("PE size_of_initialized_data value too large"))?, + .map_err(|_| pe_invalid("size_of_initialized_data", "value too large".into()))?, size_of_uninitialized_data: u32::try_from(goblin_sf.size_of_uninitialized_data) - .map_err(|_| malformed_error!("PE size_of_uninitialized_data value too large"))?, + .map_err(|_| pe_invalid("size_of_uninitialized_data", "value too large".into()))?, address_of_entry_point: goblin_sf.address_of_entry_point, base_of_code: u32::try_from(goblin_sf.base_of_code) - .map_err(|_| malformed_error!("PE base_of_code value too large"))?, + .map_err(|_| pe_invalid("base_of_code", "value too large".into()))?, base_of_data: if goblin_sf.magic == 0x10b { Some(goblin_sf.base_of_data) } else { @@ -1299,11 +1321,10 @@ impl StandardFields { if let Some(base_of_data) = self.base_of_data { writer.write_all(&base_of_data.to_le_bytes())?; } else { - return Err(Error::Malformed { - message: "PE32 file missing base_of_data field".to_string(), - file: file!(), - line: line!(), - }); + return Err(pe_invalid( + "base_of_data", + "PE32 file missing base_of_data field".into(), + )); } } @@ -1362,7 +1383,7 @@ impl WindowsFields { } else { writer.write_all( &u32::try_from(self.image_base) - .map_err(|_| malformed_error!("Image base exceeds u32 range"))? + .map_err(|_| pe_invalid("image_base", "exceeds u32 range".into()))? .to_le_bytes(), )?; } @@ -1393,22 +1414,22 @@ impl WindowsFields { // PE32: 4-byte fields writer.write_all( &u32::try_from(self.size_of_stack_reserve) - .map_err(|_| malformed_error!("Stack reserve size exceeds u32 range"))? + .map_err(|_| pe_invalid("stack_reserve_size", "exceeds u32 range".into()))? .to_le_bytes(), )?; writer.write_all( &u32::try_from(self.size_of_stack_commit) - .map_err(|_| malformed_error!("Stack commit size exceeds u32 range"))? + .map_err(|_| pe_invalid("stack_commit_size", "exceeds u32 range".into()))? .to_le_bytes(), )?; writer.write_all( &u32::try_from(self.size_of_heap_reserve) - .map_err(|_| malformed_error!("Heap reserve size exceeds u32 range"))? + .map_err(|_| pe_invalid("heap_reserve_size", "exceeds u32 range".into()))? .to_le_bytes(), )?; writer.write_all( &u32::try_from(self.size_of_heap_commit) - .map_err(|_| malformed_error!("Heap commit size exceeds u32 range"))? + .map_err(|_| pe_invalid("heap_commit_size", "exceeds u32 range".into()))? .to_le_bytes(), )?; } @@ -1592,10 +1613,10 @@ impl SectionTable { fn from_goblin(goblin_section: &goblin::pe::section_table::SectionTable) -> Result { let name = std::str::from_utf8(&goblin_section.name) - .map_err(|_| Error::Malformed { - message: "Invalid section name".to_string(), - file: file!(), - line: line!(), + .map_err(|_| ParseFailure::InvalidField { + stage: ParseStage::SectionTable, + field: "section_name", + reason: "invalid UTF-8 in section name".into(), })? .trim_end_matches('\0') .to_string(); @@ -1653,9 +1674,9 @@ impl SectionTable { characteristics: u32, ) -> Result { let size_of_raw_data = u32::try_from(file_size) - .map_err(|_| malformed_error!("File size exceeds u32 range: {}", file_size))?; + .map_err(|_| pe_invalid("file_size", format!("exceeds u32 range: {file_size}")))?; let pointer_to_raw_data = u32::try_from(file_offset) - .map_err(|_| malformed_error!("File offset exceeds u32 range: {}", file_offset))?; + .map_err(|_| pe_invalid("file_offset", format!("exceeds u32 range: {file_offset}")))?; Ok(Self { name, @@ -1691,9 +1712,9 @@ impl SectionTable { /// Returns an error if the file offset or size exceed u32 range pub fn update_file_location(&mut self, file_offset: u64, file_size: u64) -> Result<()> { self.pointer_to_raw_data = u32::try_from(file_offset) - .map_err(|_| malformed_error!("File offset exceeds u32 range: {}", file_offset))?; + .map_err(|_| pe_invalid("file_offset", format!("exceeds u32 range: {file_offset}")))?; self.size_of_raw_data = u32::try_from(file_size) - .map_err(|_| malformed_error!("File size exceeds u32 range: {}", file_size))?; + .map_err(|_| pe_invalid("file_size", format!("exceeds u32 range: {file_size}")))?; Ok(()) } @@ -1717,11 +1738,15 @@ impl SectionTable { /// Returns an error if the name exceeds 8 bytes. pub fn set_name(&mut self, name: String) -> Result<()> { if name.len() > 8 { - return Err(malformed_error!( - "Section name '{}' exceeds 8-byte PE limit ({} bytes)", - name, - name.len() - )); + return Err(ParseFailure::InvalidField { + stage: ParseStage::SectionTable, + field: "section_name", + reason: format!( + "section name '{name}' exceeds 8-byte PE limit ({} bytes)", + name.len() + ), + } + .into()); } self.name = name; Ok(()) @@ -1779,10 +1804,10 @@ impl Import { None }, rva: u32::try_from(goblin_import.rva) - .map_err(|_| malformed_error!("PE import RVA value too large"))?, + .map_err(|_| pe_invalid("import_rva", "value too large".into()))?, hint: 0, // Not available from goblin ilt_value: u64::try_from(goblin_import.offset) - .map_err(|_| malformed_error!("PE import offset value too large"))?, + .map_err(|_| pe_invalid("import_offset", "value too large".into()))?, }) } @@ -1804,12 +1829,12 @@ impl Export { Ok(Self { name: goblin_export.name.map(ToString::to_string), rva: u32::try_from(goblin_export.rva) - .map_err(|_| malformed_error!("PE export RVA value too large"))?, + .map_err(|_| pe_invalid("export_rva", "value too large".into()))?, offset: goblin_export .offset .map(|o| { u32::try_from(o) - .map_err(|_| malformed_error!("PE export offset value too large")) + .map_err(|_| pe_invalid("export_offset", "value too large".into())) }) .transpose()?, }) @@ -2018,7 +2043,11 @@ pub fn relocate_resource_section(data: &mut [u8], old_rva: u32, new_rva: u32) -> // satisfies clippy's arithmetic_side_effects lint. let delta = i64::from(new_rva) .checked_sub(i64::from(old_rva)) - .ok_or_else(|| malformed_error!("Resource RVA delta overflow"))?; + .ok_or_else(|| ParseFailure::InvalidField { + stage: ParseStage::Resources, + field: "rva_delta", + reason: "resource RVA delta overflow".into(), + })?; // Process the root directory at offset 0 relocate_resource_directory(data, 0, delta) @@ -2028,18 +2057,23 @@ pub fn relocate_resource_section(data: &mut [u8], old_rva: u32, new_rva: u32) -> fn relocate_resource_directory(data: &mut [u8], offset: usize, delta: i64) -> Result<()> { // Read the directory header let dir = ImageResourceDirectory::read_from(data, offset)?; + let res_invalid = |field: &'static str, reason: String| ParseFailure::InvalidField { + stage: ParseStage::Resources, + field, + reason, + }; let entries_offset = offset .checked_add(IMAGE_RESOURCE_DIRECTORY_SIZE) - .ok_or_else(|| malformed_error!("Resource directory entries offset overflow"))?; + .ok_or_else(|| res_invalid("entries_offset", "directory entries offset overflow".into()))?; // Process each entry for i in 0..dir.entry_count() { let scaled = i .checked_mul(RESOURCE_ENTRY_SIZE) - .ok_or_else(|| malformed_error!("Resource entry index overflow"))?; + .ok_or_else(|| res_invalid("entry_index", "entry index overflow".into()))?; let entry_offset = entries_offset .checked_add(scaled) - .ok_or_else(|| malformed_error!("Resource entry offset overflow"))?; + .ok_or_else(|| res_invalid("entry_offset", "entry offset overflow".into()))?; let entry = ResourceEntry::read_from(data, entry_offset)?; if entry.is_directory() { @@ -2053,10 +2087,11 @@ fn relocate_resource_directory(data: &mut [u8], offset: usize, delta: i64) -> Re let old_data_rva: u32 = read_le_at(data, &mut pos)?; let new_data_rva = u32::try_from(i64::from(old_data_rva).saturating_add(delta)) .map_err(|_| { - malformed_error!( - "Resource RVA relocation overflow: old_rva={:#x}, delta={}", - old_data_rva, - delta + res_invalid( + "rva", + format!( + "RVA relocation overflow: old_rva={old_data_rva:#x}, delta={delta}", + ), ) })?; let mut pos = data_entry_offset; diff --git a/dotscope/src/formatting/tokens.rs b/dotscope/src/formatting/tokens.rs index 5e1fe8f3..fdda821c 100644 --- a/dotscope/src/formatting/tokens.rs +++ b/dotscope/src/formatting/tokens.rs @@ -168,7 +168,7 @@ fn format_memberref(mref: &MemberRef, asm: &CilObject) -> String { /// Resolves the underlying method (MethodDef or MemberRef) and appends /// the generic type arguments: `instance void Type::Method(params)` fn format_methodspec(assembly: &CilObject, token: &Token) -> Option { - let spec = assembly.method_spec(token)?; + let spec = assembly.method_spec(token).ok()?; // Format the underlying method reference let base_ref = match &spec.method { diff --git a/dotscope/src/lib.rs b/dotscope/src/lib.rs index 4595c9c8..a765d5c3 100644 --- a/dotscope/src/lib.rs +++ b/dotscope/src/lib.rs @@ -247,7 +247,7 @@ //! match CilObject::from_path(std::path::Path::new("tests/samples/crafted_2.exe")) { //! Ok(assembly) => println!("Successfully loaded assembly"), //! Err(Error::NotSupported) => println!("File format not supported"), -//! Err(Error::Malformed { message, .. }) => println!("Malformed file: {}", message), +//! Err(Error::Parse(parse_err)) => println!("Malformed file: {}", parse_err), //! Err(e) => println!("Other error: {}", e), //! } //! ``` @@ -907,11 +907,11 @@ pub type Result = std::result::Result; /// match CilObject::from_path(std::path::Path::new("tests/samples/crafted_2.exe")) { /// Ok(assembly) => println!("Loaded successfully"), /// Err(Error::NotSupported) => println!("File format not supported"), -/// Err(Error::Malformed { message, .. }) => println!("Malformed: {}", message), +/// Err(Error::Parse(parse_err)) => println!("Malformed: {}", parse_err), /// Err(e) => println!("Error: {}", e), /// } /// ``` -pub use error::Error; +pub use error::{Error, HeapKind, MethodLookupError, ParseFailure, ParseStage, StreamKind}; /// Raw assembly view for editing and modification operations. /// diff --git a/dotscope/src/metadata/cilobject.rs b/dotscope/src/metadata/cilobject.rs index f70c7546..26d7066f 100644 --- a/dotscope/src/metadata/cilobject.rs +++ b/dotscope/src/metadata/cilobject.rs @@ -196,7 +196,7 @@ use crate::{ validation::{ValidationConfig, ValidationEngine}, }, project::ProjectContext, - Error, Result, + Error, MethodLookupError, Result, }; /// A fully parsed and loaded .NET assembly representation. @@ -1486,13 +1486,12 @@ impl CilObject { &self.data.method_specs } - /// Returns the method definition for the given token, if it exists. + /// Returns the method definition for the given token. /// - /// This is a convenience accessor that looks up a method by its metadata token - /// and returns a cloned reference-counted pointer to the - /// [`Method`](crate::metadata::method::Method) object. It - /// eliminates the need to call [`methods()`](Self::methods), unwrap the `Entry` - /// guard, and clone the value manually. + /// Convenience accessor over [`methods()`](Self::methods) — looks up a + /// method by its metadata token and returns a cloned reference-counted + /// pointer to the [`Method`](crate::metadata::method::Method) object, + /// eliminating the boilerplate of unwrapping the `Entry` guard manually. /// /// # Arguments /// @@ -1500,8 +1499,19 @@ impl CilObject { /// /// # Returns /// - /// A reference-counted [`Method`](crate::metadata::method::Method) if a method with the - /// given token exists, `None` otherwise. + /// `Ok(`[`Method`](crate::metadata::method::Method)`)` reference-counted + /// if the token resolves; `Err(`[`MethodLookupError::NotFound`]`)` if the + /// token is not present in the `MethodDef` table. + /// + /// # Errors + /// + /// Returns [`Error::LookupMethod`]`(`[`MethodLookupError::NotFound`]`)` + /// when `token` does not match any row in the `MethodDef` table. **This + /// is distinct from RVA = 0.** A method that exists but has no IL body + /// (abstract, P/Invoke, runtime-managed) still resolves to `Ok(method)`; + /// inspect [`Method::rva_kind`](crate::metadata::method::Method::rva_kind) + /// to see why the body is absent. Match on `Error::LookupMethod(_)` to + /// recover the structured [`MethodLookupError`] variant. /// /// # Examples /// @@ -1512,13 +1522,20 @@ impl CilObject { /// let assembly = CilObject::from_path("tests/samples/WindowsBase.dll")?; /// let token = Token::new(0x06000001); /// - /// if let Some(method) = assembly.method(&token) { - /// println!("Method: {} (static: {})", method.name, method.is_static()); + /// match assembly.method(&token) { + /// Ok(method) => { + /// println!("Method: {} (static: {})", method.name, method.is_static()); + /// } + /// Err(e) => eprintln!("lookup failed: {e}"), /// } /// # Ok::<(), dotscope::Error>(()) /// ``` - pub fn method(&self, token: &Token) -> Option { - self.data.methods.get(token).map(|e| e.value().clone()) + pub fn method(&self, token: &Token) -> Result { + self.data + .methods + .get(token) + .map(|e| e.value().clone()) + .ok_or_else(|| MethodLookupError::NotFound(*token).into()) } /// Returns the member reference for the given token, if it exists. @@ -1533,8 +1550,11 @@ impl CilObject { /// /// # Returns /// - /// A reference-counted [`MemberRef`](crate::metadata::tables::MemberRef) if a member - /// reference with the given token exists, `None` otherwise. + /// - `Some(`[`MemberRef`](crate::metadata::tables::MemberRef)`)` reference-counted + /// if `token` resolves to a row in the `MemberRef` table. + /// - `None` only when `token` is not present in the `MemberRef` table. + /// This is a lookup miss, not a malformed-input signal — the assembly + /// may simply not reference the requested member. /// /// # Examples /// @@ -1554,11 +1574,12 @@ impl CilObject { self.data.refs_member.get(token).map(|e| e.value().clone()) } - /// Returns the method specification for the given token, if it exists. + /// Returns the method specification for the given token. /// - /// A method specification represents a generic method instantiation with concrete - /// type arguments. This is a convenience accessor that looks up a `MethodSpec` by - /// its metadata token and returns a cloned reference-counted pointer. + /// A method specification represents a generic method instantiation with + /// concrete type arguments. Convenience accessor that looks up a + /// `MethodSpec` by its metadata token and returns a cloned + /// reference-counted pointer. /// /// # Arguments /// @@ -1566,8 +1587,16 @@ impl CilObject { /// /// # Returns /// - /// A reference-counted [`MethodSpec`](crate::metadata::tables::MethodSpec) if a method - /// specification with the given token exists, `None` otherwise. + /// `Ok(`[`MethodSpec`](crate::metadata::tables::MethodSpec)`)` reference-counted + /// if the token resolves; `Err(`[`MethodLookupError::SpecNotFound`]`)` if the + /// token is not present in the `MethodSpec` table. + /// + /// # Errors + /// + /// Returns [`Error::LookupMethod`]`(`[`MethodLookupError::SpecNotFound`]`)` + /// when `token` does not match any row in the `MethodSpec` table. Match + /// on `Error::LookupMethod(_)` to recover the structured + /// [`MethodLookupError`] variant. /// /// # Examples /// @@ -1578,14 +1607,18 @@ impl CilObject { /// let assembly = CilObject::from_path("tests/samples/WindowsBase.dll")?; /// let token = Token::new(0x2B000001); /// - /// if let Some(spec) = assembly.method_spec(&token) { + /// if let Ok(spec) = assembly.method_spec(&token) { /// println!("MethodSpec: {:?} with {} generic args", /// spec.method.token(), spec.generic_args.count()); /// } /// # Ok::<(), dotscope::Error>(()) /// ``` - pub fn method_spec(&self, token: &Token) -> Option { - self.data.method_specs.get(token).map(|e| e.value().clone()) + pub fn method_spec(&self, token: &Token) -> Result { + self.data + .method_specs + .get(token) + .map(|e| e.value().clone()) + .ok_or_else(|| MethodLookupError::SpecNotFound(*token).into()) } /// Resolves a method-like token to its simple method name. @@ -1608,9 +1641,19 @@ impl CilObject { /// /// # Returns /// - /// The method name as a `String` if the token refers to a known method, `None` - /// if the token table is not one of the three supported types or if the entry - /// does not exist. + /// - `Some(name)` if `token` resolves to a row in the `MethodDef`, + /// `MemberRef`, or `MethodSpec` table. + /// - `None` in two distinct cases that this accessor does **not** + /// distinguish: + /// - `token` belongs to one of the three supported tables but the row + /// is missing (lookup miss). + /// - `token`'s table is not one of the three supported types (caller + /// bug — the helper is intended for call-site operands only). + /// + /// If you need to distinguish those cases, dispatch on + /// [`Token::table`](crate::metadata::token::Token::table) yourself and + /// call [`method`](Self::method) / [`member_ref`](Self::member_ref) / + /// [`method_spec`](Self::method_spec) directly. /// /// # Examples /// diff --git a/dotscope/src/metadata/customattributes/parser.rs b/dotscope/src/metadata/customattributes/parser.rs index 83962918..729dfd38 100644 --- a/dotscope/src/metadata/customattributes/parser.rs +++ b/dotscope/src/metadata/customattributes/parser.rs @@ -1956,13 +1956,15 @@ mod tests { let result = parse_custom_attribute_data(blob_data, &method.params); assert!(result.is_err()); let error_msg = result.unwrap_err().to_string(); - // Be more flexible with error message matching - accept "Out of Bound" messages too + // Be more flexible with error message matching - accept "Out of Bound" + // and "read past end of buffer" (ParseFailure::OutOfBounds Display) too. assert!( error_msg.contains("data") || error_msg.contains("I4") || error_msg.contains("enough") || error_msg.contains("Out of Bound") - || error_msg.contains("bound"), + || error_msg.contains("bound") + || error_msg.contains("read past end"), "Error should mention data, I4, or bound issue: {error_msg}" ); diff --git a/dotscope/src/metadata/identity/cryptographic.rs b/dotscope/src/metadata/identity/cryptographic.rs index 6e0e2434..99730b5e 100644 --- a/dotscope/src/metadata/identity/cryptographic.rs +++ b/dotscope/src/metadata/identity/cryptographic.rs @@ -75,7 +75,7 @@ use crate::utils::{compute_md5, compute_sha1}; use crate::{ metadata::tables::AssemblyHashAlgorithm, utils::{compute_sha256, compute_sha384, compute_sha512, read_le}, - Result, + ParseFailure, ParseStage, Result, }; /// Assembly identity representation for .NET CIL assemblies. @@ -348,8 +348,10 @@ impl Identity { }; // Token is the last 8 bytes of the hash as little-endian u64 let start = hash.len().saturating_sub(8); - read_le::(hash.get(start..).ok_or_else(|| { - malformed_error!("Hash output is too short to extract public key token") + read_le::(hash.get(start..).ok_or(ParseFailure::Truncated { + stage: ParseStage::Generic, + expected: 8, + found: hash.len(), })?) } } diff --git a/dotscope/src/metadata/marshalling/parser.rs b/dotscope/src/metadata/marshalling/parser.rs index 33fbac79..4846f18e 100644 --- a/dotscope/src/metadata/marshalling/parser.rs +++ b/dotscope/src/metadata/marshalling/parser.rs @@ -404,7 +404,7 @@ impl<'a> MarshallingParser<'a> { #[cfg(test)] mod tests { use super::*; - use crate::Error; + use crate::{Error, ParseFailure}; #[test] fn test_parse_simple_types() { @@ -592,7 +592,10 @@ mod tests { let input: Vec = vec![]; let result = parse_marshalling_descriptor(&input); assert!(result.is_err()); - assert!(matches!(result.unwrap_err(), Error::OutOfBounds { .. })); + assert!(matches!( + result.unwrap_err(), + Error::Parse(ParseFailure::OutOfBounds { .. }) + )); // Test unknown native type let input = vec![0xFF]; @@ -603,7 +606,10 @@ mod tests { let input = vec![NATIVE_TYPE::LPSTR, 0xC0]; // 4-byte format but only one byte available let result = parse_marshalling_descriptor(&input); assert!(result.is_err()); - assert!(matches!(result.unwrap_err(), Error::OutOfBounds { .. })); + assert!(matches!( + result.unwrap_err(), + Error::Parse(ParseFailure::OutOfBounds { .. }) + )); } #[test] diff --git a/dotscope/src/metadata/method/exceptions.rs b/dotscope/src/metadata/method/exceptions.rs index 8363f26d..c618c0af 100644 --- a/dotscope/src/metadata/method/exceptions.rs +++ b/dotscope/src/metadata/method/exceptions.rs @@ -300,6 +300,87 @@ impl ExceptionHandlerFlags { /// - Less common than finally handlers in typical .NET code /// - The `class_token`/`filter_offset` field is unused for fault handlers pub const FAULT: Self = Self(0x0004); + + /// Returns the discrete handler kind for this flag value. + /// + /// Performs the bitwise classification described by ECMA-335 §II.25.4.6 and + /// returns a typed enum, removing the need for callers to match against the + /// individual `EXCEPTION`/`FILTER`/`FINALLY`/`FAULT` constants. Unrecognized + /// flag values fall back to [`ExceptionHandlerKind::Catch`] because + /// `EXCEPTION` is the zero pattern. + /// + /// # Examples + /// + /// ```rust + /// use dotscope::metadata::method::{ExceptionHandlerFlags, ExceptionHandlerKind}; + /// + /// assert_eq!(ExceptionHandlerFlags::EXCEPTION.kind(), ExceptionHandlerKind::Catch); + /// assert_eq!(ExceptionHandlerFlags::FILTER.kind(), ExceptionHandlerKind::Filter); + /// assert_eq!(ExceptionHandlerFlags::FINALLY.kind(), ExceptionHandlerKind::Finally); + /// assert_eq!(ExceptionHandlerFlags::FAULT.kind(), ExceptionHandlerKind::Fault); + /// ``` + #[must_use] + pub const fn kind(&self) -> ExceptionHandlerKind { + // ECMA-335 §II.25.4.6: handler-kind bits are mutually exclusive. + // FILTER/FINALLY/FAULT are 0x0001/0x0002/0x0004; EXCEPTION (catch) is 0x0000. + if self.0 & ExceptionHandlerFlags::FILTER.0 != 0 { + ExceptionHandlerKind::Filter + } else if self.0 & ExceptionHandlerFlags::FINALLY.0 != 0 { + ExceptionHandlerKind::Finally + } else if self.0 & ExceptionHandlerFlags::FAULT.0 != 0 { + ExceptionHandlerKind::Fault + } else { + ExceptionHandlerKind::Catch + } + } +} + +/// Discrete classification of an exception handler clause. +/// +/// Returned by [`ExceptionHandlerFlags::kind`]. Each variant corresponds to one +/// of the four handler kinds defined by ECMA-335 §II.25.4.6. +/// +/// # Stability +/// +/// The string returned by [`ExceptionHandlerKind::as_str`] and by the [`Display`] +/// impl is part of the stable public API. It is safe to persist (file, database, +/// log line) and to parse. +/// +/// [`Display`]: std::fmt::Display +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum ExceptionHandlerKind { + /// Typed exception clause — the protected region's exceptions are matched + /// against the type referenced by the handler's class token. + Catch, + /// Filtered exception clause — IL at `filter_offset` decides whether the + /// handler runs. + Filter, + /// Finally clause — runs unconditionally on protected-region exit. + Finally, + /// Fault clause — runs only when the protected region exits via exception. + Fault, +} + +impl ExceptionHandlerKind { + /// Returns a stable `&'static str` identifier for this kind. + /// + /// Identifiers are lowercase (`"catch"`, `"filter"`, `"finally"`, `"fault"`). + /// They are part of the stable public API and safe to persist. + #[must_use] + pub const fn as_str(&self) -> &'static str { + match self { + ExceptionHandlerKind::Catch => "catch", + ExceptionHandlerKind::Filter => "filter", + ExceptionHandlerKind::Finally => "finally", + ExceptionHandlerKind::Fault => "fault", + } + } +} + +impl std::fmt::Display for ExceptionHandlerKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } } /// Represents a single exception handler within a .NET method body. @@ -1058,4 +1139,39 @@ mod tests { assert!(!filter_handler.is_fault()); assert!(filter_handler.is_filter()); } + + #[test] + fn test_exception_handler_flags_kind() { + assert_eq!( + ExceptionHandlerFlags::EXCEPTION.kind(), + ExceptionHandlerKind::Catch + ); + assert_eq!( + ExceptionHandlerFlags::FILTER.kind(), + ExceptionHandlerKind::Filter + ); + assert_eq!( + ExceptionHandlerFlags::FINALLY.kind(), + ExceptionHandlerKind::Finally + ); + assert_eq!( + ExceptionHandlerFlags::FAULT.kind(), + ExceptionHandlerKind::Fault + ); + } + + #[test] + fn test_exception_handler_kind_stable_strings() { + // The strings here are part of the stable public API. Changing them is + // a breaking change for downstream consumers that persist these values. + assert_eq!(ExceptionHandlerKind::Catch.as_str(), "catch"); + assert_eq!(ExceptionHandlerKind::Filter.as_str(), "filter"); + assert_eq!(ExceptionHandlerKind::Finally.as_str(), "finally"); + assert_eq!(ExceptionHandlerKind::Fault.as_str(), "fault"); + + assert_eq!(format!("{}", ExceptionHandlerKind::Catch), "catch"); + assert_eq!(format!("{}", ExceptionHandlerKind::Filter), "filter"); + assert_eq!(format!("{}", ExceptionHandlerKind::Finally), "finally"); + assert_eq!(format!("{}", ExceptionHandlerKind::Fault), "fault"); + } } diff --git a/dotscope/src/metadata/method/mod.rs b/dotscope/src/metadata/method/mod.rs index 0aab95ea..af747300 100644 --- a/dotscope/src/metadata/method/mod.rs +++ b/dotscope/src/metadata/method/mod.rs @@ -444,6 +444,69 @@ impl From for MethodRef { } } +/// Disambiguated reason for a [`Method`]'s RVA value. +/// +/// Returned by [`Method::rva_kind`]. Lets callers distinguish between a method +/// with a real IL body (`Resolved(addr)`), one of the legitimate "no IL" +/// cases (`Abstract`, `PInvoke`, `Runtime`), and a malformed-input case +/// (`UnresolvedZero` — RVA is `None` with no flag explaining why). +/// +/// `Method.rva` carried `Option` does not surface this distinction; +/// `rva_kind()` is the recommended accessor for downstream code that needs +/// to log a meaningful reason or take a code path based on the RVA's +/// provenance. +/// +/// # Stability +/// +/// `#[non_exhaustive]` so additional well-known classifications (e.g. an +/// unmanaged-export variant) can be added later without a breaking change. +/// Consumers must include a wildcard arm when matching. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[non_exhaustive] +pub enum MethodRvaKind { + /// IL code lives at the given RVA. Equivalent to `Method.rva == Some(addr)`. + Resolved(u32), + /// Method is declared abstract — body is supplied by an overriding type + /// in this or another assembly. + Abstract, + /// Method is implemented via Platform Invoke — body lives in an external + /// native module referenced by `MethodModifiers::PINVOKE_IMPL`. + PInvoke, + /// Method is runtime-managed — implemented by the CLR rather than by IL + /// in this assembly. Identified by `MethodImplCodeType::RUNTIME`. + Runtime, + /// `Method.rva` is `None` but no flag explains why. Indicates malformed + /// metadata or an unsupported method-implementation form. + UnresolvedZero, +} + +impl MethodRvaKind { + /// Returns a stable `&'static str` identifier for this RVA kind. + /// + /// Variants render as `"resolved"`, `"abstract"`, `"pinvoke"`, + /// `"runtime"`, `"unresolved_zero"`. The strings are part of the stable + /// public API and safe to persist. + #[must_use] + pub const fn as_str(&self) -> &'static str { + match self { + MethodRvaKind::Resolved(_) => "resolved", + MethodRvaKind::Abstract => "abstract", + MethodRvaKind::PInvoke => "pinvoke", + MethodRvaKind::Runtime => "runtime", + MethodRvaKind::UnresolvedZero => "unresolved_zero", + } + } +} + +impl std::fmt::Display for MethodRvaKind { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + match self { + MethodRvaKind::Resolved(addr) => write!(f, "resolved(0x{addr:08X})"), + other => f.write_str(other.as_str()), + } + } +} + /// Represents all the information about a CIL method. /// /// The `Method` struct contains all metadata, code, and analysis results for a single .NET method. @@ -1015,6 +1078,72 @@ impl Method { self.flags_modifiers.contains(MethodModifiers::ABSTRACT) } + /// Returns the disambiguated reason for this method's RVA value. + /// + /// `Method.rva` is `Some(addr)` for methods with IL bodies and `None` + /// for methods that legitimately have no IL (abstract, P/Invoke, + /// runtime-managed) — but the bare `Option` cannot tell consumers *why* + /// a method has no body. This accessor classifies by examining the + /// method's `MethodModifiers` and `MethodImplCodeType` flags so callers + /// can log a meaningful reason instead of treating "RVA = 0" as + /// generic missing-data. + /// + /// The classification follows ECMA-335 §II.23.1.10: `ABSTRACT` and + /// `PINVOKE_IMPL` are method-attribute bits (`MethodModifiers`), and the + /// `RUNTIME` impl-code-type bits live in `MethodImplCodeType`. Variants + /// are tested in priority order — `Abstract` first, then `PInvoke`, then + /// `Runtime` — so a method tagged with multiple flags reports the most + /// informative reason. + /// + /// # Examples + /// + /// ```rust,no_run + /// use dotscope::CilObject; + /// use dotscope::metadata::method::MethodRvaKind; + /// use dotscope::metadata::token::Token; + /// + /// let assembly = CilObject::from_path("tests/samples/WindowsBase.dll")?; + /// let token = Token::new(0x06000001); + /// + /// if let Ok(method) = assembly.method(&token) { + /// match method.rva_kind() { + /// MethodRvaKind::Resolved(addr) => { + /// println!("IL at 0x{addr:08X}"); + /// } + /// MethodRvaKind::Abstract => { + /// println!("abstract — no body in this assembly"); + /// } + /// MethodRvaKind::PInvoke => { + /// println!("P/Invoke — body lives in a native module"); + /// } + /// MethodRvaKind::Runtime => { + /// println!("runtime-managed — implemented by the CLR"); + /// } + /// MethodRvaKind::UnresolvedZero => { + /// println!("RVA = 0 with no flag explanation; suspicious metadata"); + /// } + /// _ => {} + /// } + /// } + /// # Ok::<(), dotscope::Error>(()) + /// ``` + #[must_use] + pub fn rva_kind(&self) -> MethodRvaKind { + if let Some(addr) = self.rva { + return MethodRvaKind::Resolved(addr); + } + if self.flags_modifiers.contains(MethodModifiers::ABSTRACT) { + return MethodRvaKind::Abstract; + } + if self.flags_modifiers.contains(MethodModifiers::PINVOKE_IMPL) { + return MethodRvaKind::PInvoke; + } + if self.impl_code_type.contains(MethodImplCodeType::RUNTIME) { + return MethodRvaKind::Runtime; + } + MethodRvaKind::UnresolvedZero + } + /// Returns true if the method has public access. #[must_use] pub fn is_public(&self) -> bool { @@ -2144,4 +2273,66 @@ mod tests { // Default method should not have forwarded pinvoke flag assert!(!method.is_forwarded_pinvoke()); } + + #[test] + fn test_method_rva_kind() { + // Default: rva = Some(0x1000) → Resolved. + let resolved = MethodBuilder::new().with_rva(0x1234).build(); + assert_eq!(resolved.rva_kind(), MethodRvaKind::Resolved(0x1234)); + + // Abstract: no rva, ABSTRACT modifier set. + let abstract_m = MethodBuilder::new() + .without_rva() + .with_modifiers(MethodModifiers::ABSTRACT) + .build(); + assert_eq!(abstract_m.rva_kind(), MethodRvaKind::Abstract); + + // P/Invoke: no rva, PINVOKE_IMPL modifier set. + let pinvoke = MethodBuilder::new() + .without_rva() + .with_modifiers(MethodModifiers::PINVOKE_IMPL) + .build(); + assert_eq!(pinvoke.rva_kind(), MethodRvaKind::PInvoke); + + // Runtime: no rva, MethodImplCodeType::RUNTIME. + let runtime = MethodBuilder::new() + .without_rva() + .with_impl_code_type(MethodImplCodeType::RUNTIME) + .build(); + assert_eq!(runtime.rva_kind(), MethodRvaKind::Runtime); + + // UnresolvedZero: no rva, no flag explanation — malformed metadata. + let unresolved = MethodBuilder::new().without_rva().build(); + assert_eq!(unresolved.rva_kind(), MethodRvaKind::UnresolvedZero); + + // Priority: when both ABSTRACT and a runtime-managed flag are set, + // Abstract takes precedence (it's the most informative reason). + let abstract_runtime = MethodBuilder::new() + .without_rva() + .with_modifiers(MethodModifiers::ABSTRACT) + .with_impl_code_type(MethodImplCodeType::RUNTIME) + .build(); + assert_eq!(abstract_runtime.rva_kind(), MethodRvaKind::Abstract); + } + + #[test] + fn test_method_rva_kind_stable_strings() { + // Strings are part of the stable public API. + assert_eq!(MethodRvaKind::Resolved(0).as_str(), "resolved"); + assert_eq!(MethodRvaKind::Abstract.as_str(), "abstract"); + assert_eq!(MethodRvaKind::PInvoke.as_str(), "pinvoke"); + assert_eq!(MethodRvaKind::Runtime.as_str(), "runtime"); + assert_eq!(MethodRvaKind::UnresolvedZero.as_str(), "unresolved_zero"); + + // Display renders Resolved with the address; other variants delegate to as_str. + assert_eq!( + format!("{}", MethodRvaKind::Resolved(0x1234)), + "resolved(0x00001234)" + ); + assert_eq!(format!("{}", MethodRvaKind::Abstract), "abstract"); + assert_eq!( + format!("{}", MethodRvaKind::UnresolvedZero), + "unresolved_zero" + ); + } } diff --git a/dotscope/src/metadata/method/types.rs b/dotscope/src/metadata/method/types.rs index 771e9ed1..8724bd2a 100644 --- a/dotscope/src/metadata/method/types.rs +++ b/dotscope/src/metadata/method/types.rs @@ -622,6 +622,115 @@ impl MethodAccessFlags { } } } + + /// Returns the discrete access level as a typed enum. + /// + /// This is the ergonomic counterpart to matching against the + /// `COMPILER_CONTROLLED`/`PRIVATE`/…/`PUBLIC` constants. Downstream code that + /// only needs to distinguish access levels (without preserving the raw bit + /// pattern) should prefer this accessor — pattern matching on the returned + /// [`MethodAccessLevel`] is exhaustive and cannot silently miss a variant. + /// + /// Unrecognized bit patterns (which should not appear in valid metadata) + /// fall back to [`MethodAccessLevel::CompilerControlled`]. + /// + /// # Examples + /// + /// ```rust + /// use dotscope::metadata::method::{MethodAccessFlags, MethodAccessLevel}; + /// + /// assert_eq!( + /// MethodAccessFlags::PUBLIC.access_level(), + /// MethodAccessLevel::Public + /// ); + /// assert_eq!( + /// MethodAccessFlags::from_method_flags(0x0091).access_level(), + /// MethodAccessLevel::Private + /// ); + /// ``` + #[must_use] + pub const fn access_level(self) -> MethodAccessLevel { + match self.0 { + 0x0001 => MethodAccessLevel::Private, + 0x0002 => MethodAccessLevel::FamilyAndAssembly, + 0x0003 => MethodAccessLevel::Assembly, + 0x0004 => MethodAccessLevel::Family, + 0x0005 => MethodAccessLevel::FamilyOrAssembly, + 0x0006 => MethodAccessLevel::Public, + // 0x0000 and any unrecognized value — keep the most restrictive interpretation. + _ => MethodAccessLevel::CompilerControlled, + } + } +} + +/// Discrete method accessibility level per ECMA-335 §II.23.1.10. +/// +/// Returned by [`MethodAccessFlags::access_level`]. The seven variants are +/// fixed by the standard and are mutually exclusive — exactly one applies to +/// any given method. +/// +/// The hierarchy from most restrictive to least restrictive is: +/// [`CompilerControlled`](Self::CompilerControlled) → [`Private`](Self::Private) +/// → [`FamilyAndAssembly`](Self::FamilyAndAssembly) → [`Assembly`](Self::Assembly) +/// → [`Family`](Self::Family) → [`FamilyOrAssembly`](Self::FamilyOrAssembly) +/// → [`Public`](Self::Public). +/// +/// # Stability +/// +/// The string returned by [`MethodAccessLevel::as_str`] and by the [`Display`] +/// impl is part of the stable public API. The strings match the existing +/// [`Display`] impl on [`MethodAccessFlags`] (C# keyword form: `"private"`, +/// `"public"`, `"internal"`, `"protected"`, `"protected internal"`, +/// `"private protected"`, `"compilercontrolled"`). +/// +/// [`Display`]: std::fmt::Display +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum MethodAccessLevel { + /// Member is not referenceable by external code (ECMA-335 value `0x0000`). + CompilerControlled, + /// Accessible only by the parent type (ECMA-335 value `0x0001`). + Private, + /// Accessible by sub-types only within this assembly — C# `private protected` + /// (ECMA-335 value `0x0002`). + FamilyAndAssembly, + /// Accessible by anyone in the assembly — C# `internal` (ECMA-335 value `0x0003`). + Assembly, + /// Accessible by type and sub-types — C# `protected` (ECMA-335 value `0x0004`). + Family, + /// Accessible by sub-types anywhere, plus anyone in the assembly — C# `protected internal` + /// (ECMA-335 value `0x0005`). + FamilyOrAssembly, + /// Accessible by anyone with visibility to the declaring scope — C# `public` + /// (ECMA-335 value `0x0006`). + Public, +} + +impl MethodAccessLevel { + /// Returns a stable `&'static str` identifier for this access level. + /// + /// Strings match the existing [`Display`] impl on [`MethodAccessFlags`] + /// (C# keyword form). They are part of the stable public API and safe + /// to persist. + /// + /// [`Display`]: std::fmt::Display + #[must_use] + pub const fn as_str(&self) -> &'static str { + match self { + MethodAccessLevel::CompilerControlled => "compilercontrolled", + MethodAccessLevel::Private => "private", + MethodAccessLevel::FamilyAndAssembly => "private protected", + MethodAccessLevel::Assembly => "internal", + MethodAccessLevel::Family => "protected", + MethodAccessLevel::FamilyOrAssembly => "protected internal", + MethodAccessLevel::Public => "public", + } + } +} + +impl std::fmt::Display for MethodAccessLevel { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } } impl PartialOrd for MethodAccessFlags { @@ -695,6 +804,81 @@ mod tests { assert_eq!(MethodAccessFlags::PUBLIC, MethodAccessFlags::PUBLIC); assert!(MethodAccessFlags::PUBLIC >= MethodAccessFlags::PUBLIC); } + + #[test] + fn test_method_access_flags_access_level() { + assert_eq!( + MethodAccessFlags::COMPILER_CONTROLLED.access_level(), + MethodAccessLevel::CompilerControlled + ); + assert_eq!( + MethodAccessFlags::PRIVATE.access_level(), + MethodAccessLevel::Private + ); + assert_eq!( + MethodAccessFlags::FAMILY_AND_ASSEMBLY.access_level(), + MethodAccessLevel::FamilyAndAssembly + ); + assert_eq!( + MethodAccessFlags::ASSEMBLY.access_level(), + MethodAccessLevel::Assembly + ); + assert_eq!( + MethodAccessFlags::FAMILY.access_level(), + MethodAccessLevel::Family + ); + assert_eq!( + MethodAccessFlags::FAMILY_OR_ASSEMBLY.access_level(), + MethodAccessLevel::FamilyOrAssembly + ); + assert_eq!( + MethodAccessFlags::PUBLIC.access_level(), + MethodAccessLevel::Public + ); + + // Accessor must agree with from_method_flags's mask handling. + assert_eq!( + MethodAccessFlags::from_method_flags(0x0091).access_level(), + MethodAccessLevel::Private + ); + } + + #[test] + fn test_method_access_level_stable_strings() { + // The strings here are part of the stable public API. They must match + // the existing Display impl on MethodAccessFlags so consumers can swap + // accessors without string drift. + assert_eq!( + MethodAccessLevel::CompilerControlled.as_str(), + "compilercontrolled" + ); + assert_eq!(MethodAccessLevel::Private.as_str(), "private"); + assert_eq!( + MethodAccessLevel::FamilyAndAssembly.as_str(), + "private protected" + ); + assert_eq!(MethodAccessLevel::Assembly.as_str(), "internal"); + assert_eq!(MethodAccessLevel::Family.as_str(), "protected"); + assert_eq!( + MethodAccessLevel::FamilyOrAssembly.as_str(), + "protected internal" + ); + assert_eq!(MethodAccessLevel::Public.as_str(), "public"); + + // Display delegates to as_str — verify by formatting one variant. + assert_eq!(format!("{}", MethodAccessLevel::Public), "public"); + + // Strings must match MethodAccessFlags's Display so visus and other + // consumers see no string drift between the two. + assert_eq!( + format!("{}", MethodAccessLevel::Public), + format!("{}", MethodAccessFlags::PUBLIC) + ); + assert_eq!( + format!("{}", MethodAccessLevel::FamilyAndAssembly), + format!("{}", MethodAccessFlags::FAMILY_AND_ASSEMBLY) + ); + } } metadata_flags! { diff --git a/dotscope/src/metadata/resolver.rs b/dotscope/src/metadata/resolver.rs index 65afa55c..f40cf5a6 100644 --- a/dotscope/src/metadata/resolver.rs +++ b/dotscope/src/metadata/resolver.rs @@ -211,7 +211,7 @@ impl<'a> TokenResolver<'a> { }) } 0x2B => { - let method_spec = self.assembly.method_spec(&token)?; + let method_spec = self.assembly.method_spec(&token).ok()?; let underlying = Self::extract_methodspec_token(&method_spec.method)?; self.resolve_method(underlying) } @@ -474,7 +474,7 @@ impl<'a> TokenResolver<'a> { None } 0x2B => { - let method_spec = self.assembly.method_spec(&method_token)?; + let method_spec = self.assembly.method_spec(&method_token).ok()?; let underlying = Self::extract_methodspec_token(&method_spec.method)?; self.declaring_type(underlying) } diff --git a/dotscope/src/metadata/resources/parser.rs b/dotscope/src/metadata/resources/parser.rs index ae828a05..45e8ed1a 100644 --- a/dotscope/src/metadata/resources/parser.rs +++ b/dotscope/src/metadata/resources/parser.rs @@ -82,7 +82,7 @@ use crate::{ metadata::resources::{ ResourceEntry, ResourceEntryRef, ResourceType, ResourceTypeRef, RESOURCE_MAGIC, }, - Result, + ParseFailure, ParseStage, Result, }; /// Maximum number of resource types allowed in a resource file. @@ -454,7 +454,12 @@ impl Resource { /// - **Array Bounds**: Ensures hash and position arrays match resource count pub fn parse(data: &[u8]) -> Result { if data.len() < 12 { - return Err(malformed_error!("Resource data too small")); + return Err(ParseFailure::Truncated { + stage: ParseStage::Resources, + expected: 12, + found: data.len(), + } + .into()); } let mut parser = Parser::new(data); diff --git a/dotscope/src/metadata/root.rs b/dotscope/src/metadata/root.rs index f902c6ed..b2f99e8a 100644 --- a/dotscope/src/metadata/root.rs +++ b/dotscope/src/metadata/root.rs @@ -41,7 +41,7 @@ use crate::{ metadata::streams::StreamHeader, utils::{read_le, read_le_at}, - Result, + Error, ParseFailure, ParseStage, Result, }; use std::io::Write; @@ -351,34 +351,52 @@ impl Root { /// This method is thread-safe and can be called concurrently from multiple threads /// as it performs no mutations and uses only stack-allocated temporary variables. pub fn read(data: &[u8]) -> Result { + let invalid = |field: &'static str, reason: String| ParseFailure::InvalidField { + stage: ParseStage::MetadataRoot, + field, + reason, + }; + let oob = || ParseFailure::OutOfBounds { + stage: ParseStage::MetadataRoot, + }; + if data.len() < MIN_ROOT_HEADER_SIZE { - return Err(out_of_bounds_error!()); + return Err(ParseFailure::Truncated { + stage: ParseStage::MetadataRoot, + expected: MIN_ROOT_HEADER_SIZE, + found: data.len(), + } + .into()); } let signature = read_le::(data)?; if signature != CIL_HEADER_MAGIC { - return Err(malformed_error!( - "Root: invalid signature 0x{:08X}, expected 0x{:08X} [ECMA-335 §II.24.2.1]", - signature, - CIL_HEADER_MAGIC - )); + return Err(ParseFailure::BadMagic { + stage: ParseStage::MetadataRoot, + expected: CIL_HEADER_MAGIC, + found: signature, + } + .into()); } let version_string_length = read_le_at::(data, &mut { VERSION_LENGTH_OFFSET })?; match version_string_length.checked_add(u32::from(VERSION_STRING_OFFSET)) { Some(str_end) => { let data_len = u32::try_from(data.len()).map_err(|_| { - malformed_error!("Root: data length too large [ECMA-335 §II.24.2.1]") + invalid("data_length", "metadata root data length too large".into()) })?; if str_end > data_len { - return Err(out_of_bounds_error!()); + return Err(oob().into()); } } None => { - return Err(malformed_error!( - "Root: version string length {} causes integer overflow [ECMA-335 §II.24.2.1]", - version_string_length - )) + return Err(invalid( + "version_string_length", + format!( + "{version_string_length} causes integer overflow [ECMA-335 §II.24.2.1]", + ), + ) + .into()) } } @@ -386,32 +404,39 @@ impl Root { for counter in 0..version_string_length { let mut pos = usize::from(VERSION_STRING_OFFSET) .checked_add(counter as usize) - .ok_or_else(|| malformed_error!("version string offset overflow"))?; + .ok_or_else(|| invalid("version_string_offset", "offset overflow".into()))?; version_string.push(char::from(read_le_at::(data, &mut pos)?)); } // Validate version string format and content if version_string.is_empty() { - return Err(malformed_error!( - "Root: version string cannot be empty [ECMA-335 §II.24.2.1]" - )); + return Err(invalid( + "version_string", + "version string cannot be empty [ECMA-335 §II.24.2.1]".into(), + ) + .into()); } // Check for common malformed version strings if !version_string.starts_with('v') { - return Err(malformed_error!( - "Root: version string '{}' must start with 'v' [ECMA-335 §II.24.2.1]", - version_string - )); + return Err(invalid( + "version_string", + format!("'{version_string}' must start with 'v' [ECMA-335 §II.24.2.1]",), + ) + .into()); } // Validate version string contains reasonable content if version_string.len() > MAX_VERSION_STRING_LENGTH { - return Err(malformed_error!( - "Root: version string length {} exceeds reasonable limit ({}) [ECMA-335 §II.24.2.1]", - version_string.len(), - MAX_VERSION_STRING_LENGTH - )); + return Err(invalid( + "version_string", + format!( + "length {} exceeds reasonable limit ({}) [ECMA-335 §II.24.2.1]", + version_string.len(), + MAX_VERSION_STRING_LENGTH + ), + ) + .into()); } // Stream count is located after: version_string + FLAGS_FIELD_SIZE @@ -419,19 +444,19 @@ impl Root { .len() .checked_add(usize::from(VERSION_STRING_OFFSET)) .and_then(|v| v.checked_add(FLAGS_FIELD_SIZE)) - .ok_or_else(|| malformed_error!("stream count offset overflow"))?; + .ok_or_else(|| invalid("stream_count_offset", "offset overflow".into()))?; let stream_count = read_le_at::(data, &mut stream_count_offset)?; // Validate stream count: must have at least one stream, no more than MAX_STREAM_COUNT let stream_count_size = (stream_count as usize) .checked_mul(MIN_STREAM_HEADER_SIZE) - .ok_or_else(|| malformed_error!("stream count size overflow"))?; + .ok_or_else(|| invalid("stream_count", "size overflow".into()))?; if stream_count == 0 || stream_count > MAX_STREAM_COUNT || stream_count_size > data.len() { - return Err(malformed_error!( - "Root: invalid stream count {} (must be 1-{}) [ECMA-335 §II.24.2.1]", - stream_count, - MAX_STREAM_COUNT - )); + return Err(invalid( + "stream_count", + format!("{stream_count} (must be 1-{MAX_STREAM_COUNT}) [ECMA-335 §II.24.2.1]",), + ) + .into()); } let mut streams = Vec::with_capacity(stream_count as usize); @@ -441,36 +466,40 @@ impl Root { .checked_add(usize::from(VERSION_STRING_OFFSET)) .and_then(|v| v.checked_add(FLAGS_FIELD_SIZE)) .and_then(|v| v.checked_add(STREAM_COUNT_FIELD_SIZE)) - .ok_or_else(|| malformed_error!("stream directory offset overflow"))?; + .ok_or_else(|| invalid("stream_directory_offset", "offset overflow".into()))?; let mut streams_seen = [false; MAX_STREAM_COUNT as usize]; for _i in 0..stream_count { if stream_offset > data.len() { - return Err(out_of_bounds_error!()); + return Err(oob().into()); } - let stream_data = data.get(stream_offset..).ok_or(out_of_bounds_error!())?; + let stream_data = data + .get(stream_offset..) + .ok_or_else(|| Error::from(oob()))?; let new_stream = StreamHeader::from(stream_data)?; if new_stream.offset as usize > data.len() || new_stream.size as usize > data.len() || new_stream.name.len() > MAX_STREAM_NAME_LENGTH { - return Err(out_of_bounds_error!()); + return Err(oob().into()); } match u32::checked_add(new_stream.offset, new_stream.size) { Some(range) => { if range as usize > data.len() { - return Err(out_of_bounds_error!()); + return Err(oob().into()); } } None => { - return Err(malformed_error!( - "Root: stream '{}' offset {} + size {} causes integer overflow [ECMA-335 §II.24.2.2]", - new_stream.name, - new_stream.offset, - new_stream.size - )) + return Err(invalid( + "stream_extent", + format!( + "stream '{}' offset {} + size {} causes integer overflow [ECMA-335 §II.24.2.2]", + new_stream.name, new_stream.offset, new_stream.size + ), + ) + .into()) } } @@ -482,71 +511,88 @@ impl Root { "#~" => 4, "#-" => 5, _ => { - return Err(malformed_error!( - "Root: unrecognized stream name '{}' [ECMA-335 §II.24.2.2]", - new_stream.name - )) + return Err(invalid( + "stream_name", + format!( + "unrecognized stream '{}' [ECMA-335 §II.24.2.2]", + new_stream.name + ), + ) + .into()) } }; if *streams_seen .get(stream_index) - .ok_or(out_of_bounds_error!())? + .ok_or_else(|| Error::from(oob()))? { - return Err(malformed_error!( - "Root: duplicate stream '{}' found [ECMA-335 §II.24.2.2]", - new_stream.name - )); + return Err(invalid( + "stream_name", + format!( + "duplicate stream '{}' [ECMA-335 §II.24.2.2]", + new_stream.name + ), + ) + .into()); } *streams_seen .get_mut(stream_index) - .ok_or(out_of_bounds_error!())? = true; + .ok_or_else(|| Error::from(oob()))? = true; let name_aligned = new_stream .name .len() .checked_add(1) .and_then(|v| v.checked_add(3)) - .ok_or_else(|| malformed_error!("stream name alignment overflow"))? + .ok_or_else(|| invalid("stream_name", "alignment overflow".into()))? & !3usize; stream_offset = stream_offset .checked_add(STREAM_HEADER_FIXED_SIZE) .and_then(|v| v.checked_add(name_aligned)) - .ok_or_else(|| malformed_error!("stream offset overflow"))?; + .ok_or_else(|| invalid("stream_offset", "offset overflow".into()))?; streams.push(new_stream); } if streams.is_empty() { - return Err(malformed_error!( - "Root: no valid streams found [ECMA-335 §II.24.2.1]" - )); + return Err(invalid( + "streams", + "no valid streams found [ECMA-335 §II.24.2.1]".into(), + ) + .into()); } let flags_offset = usize::from(VERSION_STRING_OFFSET) .checked_add(version_string.len()) - .ok_or_else(|| malformed_error!("flags offset overflow"))?; + .ok_or_else(|| invalid("flags_offset", "offset overflow".into()))?; Ok(Root { signature, major_version: read_le::( data.get(FIELD_OFFSET_MAJOR_VERSION..) - .ok_or(out_of_bounds_error!())?, + .ok_or_else(|| Error::from(oob()))?, )?, minor_version: read_le::( data.get(FIELD_OFFSET_MINOR_VERSION..) - .ok_or(out_of_bounds_error!())?, + .ok_or_else(|| Error::from(oob()))?, )?, reserved: read_le::( data.get(FIELD_OFFSET_RESERVED..) - .ok_or(out_of_bounds_error!())?, + .ok_or_else(|| Error::from(oob()))?, )?, length: u32::try_from(version_string.len()).map_err(|_| { - malformed_error!("Root: version string length too large [ECMA-335 §II.24.2.1]") + invalid( + "version_string_length", + "string length too large [ECMA-335 §II.24.2.1]".into(), + ) + })?, + flags: read_le::(data.get(flags_offset..).ok_or_else(|| Error::from(oob()))?)?, + stream_number: u16::try_from(streams.len()).map_err(|_| { + invalid( + "stream_count", + "too many streams [ECMA-335 §II.24.2.1]".into(), + ) })?, - flags: read_le::(data.get(flags_offset..).ok_or(out_of_bounds_error!())?)?, - stream_number: u16::try_from(streams.len()) - .map_err(|_| malformed_error!("Root: too many streams [ECMA-335 §II.24.2.1]"))?, stream_headers: streams, version: version_string, }) @@ -569,16 +615,20 @@ impl Root { meta_root_offset: usize, total_metadata_size: u32, ) -> Result<()> { + let invalid = |field: &'static str, reason: String| ParseFailure::InvalidField { + stage: ParseStage::MetadataRoot, + field, + reason, + }; let mut stream_ranges: Vec<(u32, u32, &str)> = Vec::new(); // Validate stream doesn't exceed metadata bounds let metadata_end = meta_root_offset .checked_add(total_metadata_size as usize) .ok_or_else(|| { - malformed_error!( - "Metadata size causes overflow: {} + {}", - meta_root_offset, - total_metadata_size + invalid( + "metadata_size", + format!("size causes overflow: {meta_root_offset} + {total_metadata_size}",), ) })?; @@ -587,47 +637,55 @@ impl Root { let absolute_start = meta_root_offset .checked_add(stream.offset as usize) .ok_or_else(|| { - malformed_error!( - "Stream '{}' offset causes overflow: {} + {}", - stream.name, - meta_root_offset, - stream.offset + invalid( + "stream_offset", + format!( + "stream '{}' offset causes overflow: {} + {}", + stream.name, meta_root_offset, stream.offset + ), ) })?; let absolute_end = absolute_start .checked_add(stream.size as usize) .ok_or_else(|| { - malformed_error!( - "Stream '{}' size causes overflow: {} + {}", - stream.name, - absolute_start, - stream.size + invalid( + "stream_size", + format!( + "stream '{}' size causes overflow: {} + {}", + stream.name, absolute_start, stream.size + ), ) })?; if absolute_end > metadata_end { - return Err(malformed_error!( - "Stream '{}' extends beyond metadata bounds (end {} > metadata end {})", - stream.name, - absolute_end, - metadata_end - )); + return Err(invalid( + "stream_extent", + format!( + "stream '{}' extends beyond metadata bounds (end {} > metadata end {})", + stream.name, absolute_end, metadata_end + ), + ) + .into()); } stream_ranges.push(( u32::try_from(absolute_start).map_err(|_| { - malformed_error!( - "Stream '{}' start position {} exceeds u32 range", - stream.name, - absolute_start + invalid( + "stream_start", + format!( + "stream '{}' start position {} exceeds u32 range", + stream.name, absolute_start + ), ) })?, u32::try_from(absolute_end).map_err(|_| { - malformed_error!( - "Stream '{}' end position {} exceeds u32 range", - stream.name, - absolute_end + invalid( + "stream_end", + format!( + "stream '{}' end position {} exceeds u32 range", + stream.name, absolute_end + ), ) })?, &stream.name, @@ -638,15 +696,13 @@ impl Root { let skip = i.saturating_add(1); for &(start2, end2, name2) in stream_ranges.iter().skip(skip) { if start1 < end2 && start2 < end1 { - return Err(malformed_error!( - "Stream '{}' ({}..{}) overlaps with stream '{}' ({}..{})", - name1, - start1, - end1, - name2, - start2, - end2 - )); + return Err(invalid( + "stream_overlap", + format!( + "stream '{name1}' ({start1}..{end1}) overlaps with stream '{name2}' ({start2}..{end2})", + ), + ) + .into()); } } } @@ -726,16 +782,20 @@ impl Root { // Version string length (padded to 4-byte boundary) let version_bytes = self.version.as_bytes(); - let padded_len = version_bytes - .len() - .checked_add(3) - .ok_or_else(|| malformed_error!("Version string padded length overflow"))? - & !3usize; - let padded_len_u32 = u32::try_from(padded_len).map_err(|_| { - malformed_error!( - "Version string padded length {} exceeds u32 range", - padded_len - ) + let padded_len = + version_bytes + .len() + .checked_add(3) + .ok_or_else(|| ParseFailure::InvalidField { + stage: ParseStage::MetadataRoot, + field: "version_string_length", + reason: "padded length overflow".into(), + })? + & !3usize; + let padded_len_u32 = u32::try_from(padded_len).map_err(|_| ParseFailure::InvalidField { + stage: ParseStage::MetadataRoot, + field: "version_string_length", + reason: format!("padded length {padded_len} exceeds u32 range"), })?; writer.write_all(&padded_len_u32.to_le_bytes())?; diff --git a/dotscope/src/metadata/security/permissionset.rs b/dotscope/src/metadata/security/permissionset.rs index e4b8a5e4..b751dc20 100644 --- a/dotscope/src/metadata/security/permissionset.rs +++ b/dotscope/src/metadata/security/permissionset.rs @@ -277,7 +277,7 @@ use crate::{ PermissionSetFormat, SecurityPermissionFlags, }, utils::EnumUtils, - Result, + ParseFailure, ParseStage, Result, }; use quick_xml::{ events::{attributes::Attributes, Event}, @@ -670,7 +670,12 @@ impl PermissionSet { /// analysis scenarios, attribute-based parsing is sufficient. fn parse_xml_format(data: &[u8]) -> Result<(PermissionSetFormat, Vec)> { if data.len() < 5 { - return Err(malformed_error!("XML data too short")); + return Err(ParseFailure::Truncated { + stage: ParseStage::PermissionSet, + expected: 5, + found: data.len(), + } + .into()); } let xml_start = b" SignatureParser<'a> { #[cfg(test)] mod tests { - use crate::prelude::Token; - use crate::Error; - use super::*; + use crate::{prelude::Token, Error, ParseFailure}; #[test] fn test_parse_primitive_types() { @@ -2379,7 +2377,7 @@ mod tests { let mut parser = SignatureParser::new(&[0xFF, 0x01]); assert!(matches!( parser.parse_method_signature(), - Err(Error::OutOfBounds { .. }) + Err(Error::Parse(ParseFailure::OutOfBounds { .. })) )); // Test invalid field signature format diff --git a/dotscope/src/metadata/streams/blob.rs b/dotscope/src/metadata/streams/blob.rs index e2f387aa..d3051771 100644 --- a/dotscope/src/metadata/streams/blob.rs +++ b/dotscope/src/metadata/streams/blob.rs @@ -146,7 +146,7 @@ //! - **ECMA-335 II.24.2.4**: `#Blob` heap specification //! - **ECMA-335 II.23.2**: Signature encoding formats stored in blobs -use crate::{file::parser::Parser, Result}; +use crate::{file::parser::Parser, Error, HeapKind, ParseFailure, Result}; /// ECMA-335 binary blob heap providing indexed access to variable-length data. /// @@ -354,7 +354,11 @@ impl<'a> Blob<'a> { pub fn from(data: &'a [u8]) -> Result> { match data.first() { Some(0) => Ok(Blob { data }), - _ => Err(malformed_error!("Invalid memory for #Blob heap")), + _ => Err(ParseFailure::HeapCorrupt { + heap: HeapKind::Blob, + reason: "first byte must be 0 (empty blob sentinel)".into(), + } + .into()), } } @@ -436,25 +440,29 @@ impl<'a> Blob<'a> { /// - [`crate::file::parser::Parser`]: For compressed integer parsing /// - [ECMA-335 II.23.2](https://ecma-international.org/wp-content/uploads/ECMA-335_6th_edition_june_2012.pdf): Compressed integer format pub fn get(&self, index: usize) -> Result<&'a [u8]> { + let oob = || ParseFailure::HeapOutOfBounds { + heap: HeapKind::Blob, + index: u32::try_from(index).unwrap_or(u32::MAX), + }; if index >= self.data.len() { - return Err(out_of_bounds_error!()); + return Err(oob().into()); } - let mut parser = Parser::new(self.data.get(index..).ok_or(out_of_bounds_error!())?); + let mut parser = Parser::new(self.data.get(index..).ok_or_else(oob)?); let len = parser.read_compressed_uint()? as usize; let skip = parser.pos(); let Some(data_start) = index.checked_add(skip) else { - return Err(out_of_bounds_error!()); + return Err(oob().into()); }; let Some(data_end) = data_start.checked_add(len) else { - return Err(out_of_bounds_error!()); + return Err(oob().into()); }; self.data .get(data_start..data_end) - .ok_or(out_of_bounds_error!()) + .ok_or_else(|| Error::from(oob())) } /// Returns an iterator over all blobs in the heap. diff --git a/dotscope/src/metadata/streams/guid.rs b/dotscope/src/metadata/streams/guid.rs index c625e21e..f93079c5 100644 --- a/dotscope/src/metadata/streams/guid.rs +++ b/dotscope/src/metadata/streams/guid.rs @@ -144,7 +144,7 @@ //! - **ECMA-335 II.24.2.5**: `#GUID` heap specification //! - **RFC 4122**: UUID/GUID format and generation standards -use crate::Result; +use crate::{HeapKind, ParseFailure, ParseStage, Result}; /// ECMA-335 GUID heap providing indexed access to 128-bit globally unique identifiers. /// @@ -379,7 +379,12 @@ impl<'a> Guid<'a> { /// - [ECMA-335 II.24.2.5](https://ecma-international.org/wp-content/uploads/ECMA-335_6th_edition_june_2012.pdf): GUID heap specification pub fn from(data: &'a [u8]) -> Result> { if data.len() < 16 { - return Err(malformed_error!("Data for #Guid heap is too small")); + return Err(ParseFailure::Truncated { + stage: ParseStage::Heap, + expected: 16, + found: data.len(), + } + .into()); } Ok(Guid { data }) @@ -501,19 +506,20 @@ impl<'a> Guid<'a> { /// - [`uguid::Guid`]: The returned GUID type with formatting and comparison methods /// - [ECMA-335 II.24.2.5](https://ecma-international.org/wp-content/uploads/ECMA-335_6th_edition_june_2012.pdf): GUID heap specification pub fn get(&self, index: usize) -> Result { + let oob = || ParseFailure::HeapOutOfBounds { + heap: HeapKind::Guid, + index: u32::try_from(index).unwrap_or(u32::MAX), + }; if index < 1 { - return Err(out_of_bounds_error!()); + return Err(oob().into()); } let offset_start = index .checked_sub(1) .and_then(|i| i.checked_mul(16)) - .ok_or(out_of_bounds_error!())?; - let offset_end = offset_start.checked_add(16).ok_or(out_of_bounds_error!())?; + .ok_or_else(oob)?; + let offset_end = offset_start.checked_add(16).ok_or_else(oob)?; - let bytes = self - .data - .get(offset_start..offset_end) - .ok_or(out_of_bounds_error!())?; + let bytes = self.data.get(offset_start..offset_end).ok_or_else(oob)?; let mut buffer = [0u8; 16]; buffer.copy_from_slice(bytes); diff --git a/dotscope/src/metadata/streams/streamheader.rs b/dotscope/src/metadata/streams/streamheader.rs index cd3ad89e..bc791ad7 100644 --- a/dotscope/src/metadata/streams/streamheader.rs +++ b/dotscope/src/metadata/streams/streamheader.rs @@ -159,7 +159,7 @@ //! - **ECMA-335 II.24.2.2**: Stream header format and directory structure //! - **ECMA-335 II.24.2**: Complete metadata stream architecture overview -use crate::{utils::read_le_at, Result}; +use crate::{utils::read_le_at, ParseFailure, ParseStage, Result}; use std::io::Write; /// ECMA-335 compliant stream header providing metadata stream location and identification. @@ -506,7 +506,12 @@ impl StreamHeader { /// - [ECMA-335 II.24.2.2](https://ecma-international.org/wp-content/uploads/ECMA-335_6th_edition_june_2012.pdf): Official stream header specification pub fn from(data: &[u8]) -> Result { if data.len() < 9 { - return Err(out_of_bounds_error!()); + return Err(ParseFailure::Truncated { + stage: ParseStage::StreamHeader, + expected: 9, + found: data.len(), + } + .into()); } let mut cursor = 0_usize; @@ -520,18 +525,22 @@ impl StreamHeader { // Validate offset bounds - offset must be reasonable if offset > 0x7FFF_FFFF { - return Err(malformed_error!( - "Stream offset {} exceeds maximum allowed value (0x7FFFFFFF)", - offset - )); + return Err(ParseFailure::InvalidField { + stage: ParseStage::StreamHeader, + field: "offset", + reason: format!("{offset} exceeds maximum 0x7FFFFFFF"), + } + .into()); } // Validate size bounds - prevent integer overflow and unreasonable sizes if size > 0x7FFF_FFFF { - return Err(malformed_error!( - "Stream size {} exceeds maximum allowed value (0x7FFFFFFF)", - size - )); + return Err(ParseFailure::InvalidField { + stage: ParseStage::StreamHeader, + field: "size", + reason: format!("{size} exceeds maximum 0x7FFFFFFF"), + } + .into()); } // After the 8-byte header, parse the name (max 32 chars or until end of data). @@ -552,7 +561,12 @@ impl StreamHeader { .iter() .any(|valid_name| name == *valid_name) { - return Err(malformed_error!("Invalid stream header name - {}", name)); + return Err(ParseFailure::InvalidField { + stage: ParseStage::StreamHeader, + field: "name", + reason: format!("unknown stream name `{name}`"), + } + .into()); } Ok(StreamHeader { offset, size, name }) @@ -616,22 +630,34 @@ impl StreamHeader { // Pad to 4-byte boundary // Name length + 1 (null terminator), padded to multiple of 4 - let name_with_null = self.name.len().checked_add(1).ok_or_else(|| { - malformed_error!("StreamHeader name length overflow: {}", self.name.len()) - })?; - let padded_len = name_with_null.checked_add(3).ok_or_else(|| { - malformed_error!( - "StreamHeader padded length overflow: {} + 3", - name_with_null - ) - })? & !3; - let padding = padded_len.checked_sub(name_with_null).ok_or_else(|| { - malformed_error!( - "StreamHeader padding underflow: padded={} name_with_null={}", - padded_len, - name_with_null - ) - })?; + let name_with_null = + self.name + .len() + .checked_add(1) + .ok_or_else(|| ParseFailure::InvalidField { + stage: ParseStage::StreamHeader, + field: "name", + reason: format!("name length overflow: {}", self.name.len()), + })?; + let padded_len = + name_with_null + .checked_add(3) + .ok_or_else(|| ParseFailure::InvalidField { + stage: ParseStage::StreamHeader, + field: "name", + reason: format!("padded length overflow: {name_with_null} + 3"), + })? + & !3; + let padding = + padded_len + .checked_sub(name_with_null) + .ok_or_else(|| ParseFailure::InvalidField { + stage: ParseStage::StreamHeader, + field: "name", + reason: format!( + "padding underflow: padded={padded_len} name_with_null={name_with_null}", + ), + })?; if padding > 0 { writer.write_all(&vec![0u8; padding])?; } diff --git a/dotscope/src/metadata/streams/strings.rs b/dotscope/src/metadata/streams/strings.rs index a9546f59..fa51b668 100644 --- a/dotscope/src/metadata/streams/strings.rs +++ b/dotscope/src/metadata/streams/strings.rs @@ -202,7 +202,7 @@ use std::{ffi::CStr, str}; -use crate::Result; +use crate::{Error, HeapKind, ParseFailure, Result}; /// ECMA-335 compliant `#Strings` heap providing UTF-8 identifier string access. /// @@ -533,7 +533,11 @@ impl<'a> Strings<'a> { /// - [ECMA-335 II.24.2.3](https://ecma-international.org/wp-content/uploads/ECMA-335_6th_edition_june_2012.pdf): Strings heap format specification pub fn from(data: &[u8]) -> Result> { if data.first() != Some(&0) { - return Err(malformed_error!("Provided #String heap is empty")); + return Err(ParseFailure::HeapCorrupt { + heap: HeapKind::Strings, + reason: "first byte must be 0 (empty string sentinel)".into(), + } + .into()); } Ok(Strings { data }) @@ -756,8 +760,12 @@ impl<'a> Strings<'a> { /// - [`crate::metadata::tables`]: Metadata tables containing string references /// - [ECMA-335 II.24.2.3](https://ecma-international.org/wp-content/uploads/ECMA-335_6th_edition_june_2012.pdf): Strings heap specification pub fn get(&self, index: usize) -> Result<&'a str> { + let oob = || ParseFailure::HeapOutOfBounds { + heap: HeapKind::Strings, + index: u32::try_from(index).unwrap_or(u32::MAX), + }; if index >= self.data.len() { - return Err(out_of_bounds_error!()); + return Err(oob().into()); } // Performance note: Each call performs O(n) operations: @@ -773,13 +781,17 @@ impl<'a> Strings<'a> { // If profiling shows this is a bottleneck, callers should: // 1. Cache results at the call site for repeated access // 2. Use the iterator for bulk processing (avoids repeated bounds checks) - let slice = self.data.get(index..).ok_or(out_of_bounds_error!())?; + let slice = self.data.get(index..).ok_or_else(|| Error::from(oob()))?; + let corrupt = || ParseFailure::HeapCorrupt { + heap: HeapKind::Strings, + reason: format!("invalid UTF-8 string at index {index}"), + }; match CStr::from_bytes_until_nul(slice) { Ok(result) => match result.to_str() { Ok(result) => Ok(result), - Err(_) => Err(malformed_error!("Invalid string at index - {}", index)), + Err(_) => Err(corrupt().into()), }, - Err(_) => Err(malformed_error!("Invalid string at index - {}", index)), + Err(_) => Err(corrupt().into()), } } diff --git a/dotscope/src/metadata/streams/tablesheader.rs b/dotscope/src/metadata/streams/tablesheader.rs index 2839e5fa..114df7b8 100644 --- a/dotscope/src/metadata/streams/tablesheader.rs +++ b/dotscope/src/metadata/streams/tablesheader.rs @@ -312,7 +312,7 @@ use crate::{ TableInfo, TableInfoRef, TypeDefRaw, TypeRefRaw, TypeSpecRaw, }, utils::read_le, - Result, + Error, ParseFailure, ParseStage, Result, }; /// ECMA-335 compliant metadata tables header providing efficient access to .NET assembly metadata. @@ -1008,25 +1008,51 @@ impl<'a> TablesHeader<'a> { /// - [`crate::metadata::tables::TableInfo`]: Table metadata and row count information /// - [ECMA-335 II.24.2.6](https://ecma-international.org/wp-content/uploads/ECMA-335_6th_edition_june_2012.pdf): Tables header specification pub fn from(data: &'a [u8]) -> Result> { + let truncated = |expected: usize, found: usize| ParseFailure::Truncated { + stage: ParseStage::TildeStream, + expected, + found, + }; if data.len() < 24 { - return Err(out_of_bounds_error!()); + return Err(truncated(24, data.len()).into()); } - let valid_bitvec = read_le::(data.get(8..).ok_or(out_of_bounds_error!())?)?; + let valid_bitvec = read_le::( + data.get(8..) + .ok_or_else(|| Error::from(truncated(16, data.len())))?, + )?; if valid_bitvec == 0 { - return Err(malformed_error!("No valid rows in any of the tables")); + return Err(ParseFailure::InvalidField { + stage: ParseStage::TildeStream, + field: "valid", + reason: "no valid rows in any of the tables".into(), + } + .into()); } let tables_offset = 24usize .checked_add((valid_bitvec.count_ones() as usize).saturating_mul(4)) - .ok_or_else(|| malformed_error!("Tables header offset overflow"))?; + .ok_or_else(|| ParseFailure::InvalidField { + stage: ParseStage::TildeStream, + field: "valid", + reason: "tables-header offset overflow".into(), + })?; let table_capacity = (TableId::CustomDebugInformation as usize).saturating_add(1); let mut tables_header = TablesHeader { - major_version: read_le::(data.get(4..).ok_or(out_of_bounds_error!())?)?, - minor_version: read_le::(data.get(5..).ok_or(out_of_bounds_error!())?)?, + major_version: read_le::( + data.get(4..) + .ok_or_else(|| Error::from(truncated(5, data.len())))?, + )?, + minor_version: read_le::( + data.get(5..) + .ok_or_else(|| Error::from(truncated(6, data.len())))?, + )?, valid: valid_bitvec, - sorted: read_le::(data.get(16..).ok_or(out_of_bounds_error!())?)?, + sorted: read_le::( + data.get(16..) + .ok_or_else(|| Error::from(truncated(24, data.len())))?, + )?, info: Arc::new(TableInfo::new(data, valid_bitvec)?), tables_offset, tables: Vec::with_capacity(table_capacity), @@ -1039,11 +1065,18 @@ impl<'a> TablesHeader<'a> { let mut current_offset = tables_header.tables_offset; for table_id in TableId::iter() { if current_offset > data.len() { - return Err(out_of_bounds_error!()); + return Err(ParseFailure::OutOfBounds { + stage: ParseStage::TildeStream, + } + .into()); } tables_header.add_table( - data.get(current_offset..).ok_or(out_of_bounds_error!())?, + data.get(current_offset..).ok_or_else(|| { + Error::from(ParseFailure::OutOfBounds { + stage: ParseStage::TildeStream, + }) + })?, table_id, &mut current_offset, )?; diff --git a/dotscope/src/metadata/streams/userstrings.rs b/dotscope/src/metadata/streams/userstrings.rs index 30363240..a03acb15 100644 --- a/dotscope/src/metadata/streams/userstrings.rs +++ b/dotscope/src/metadata/streams/userstrings.rs @@ -42,7 +42,7 @@ use crate::{ utils::{read_compressed_int, read_compressed_int_at}, - Result, + Error, HeapKind, ParseFailure, Result, }; use widestring::U16Str; @@ -125,7 +125,11 @@ impl<'a> UserStrings<'a> { /// ``` pub fn from(data: &'a [u8]) -> Result> { if data.first().copied() != Some(0) { - return Err(out_of_bounds_error!()); + return Err(ParseFailure::HeapCorrupt { + heap: HeapKind::UserStrings, + reason: "first byte must be 0 (empty string sentinel)".into(), + } + .into()); } Ok(UserStrings { data }) @@ -169,20 +173,25 @@ impl<'a> UserStrings<'a> { /// used on a raw pointer conversion that is guaranteed to succeed when the input slice /// is valid. pub fn get(&self, index: usize) -> Result<&'a U16Str> { + let oob = || ParseFailure::HeapOutOfBounds { + heap: HeapKind::UserStrings, + index: u32::try_from(index).unwrap_or(u32::MAX), + }; + let corrupt = |reason: String| ParseFailure::HeapCorrupt { + heap: HeapKind::UserStrings, + reason, + }; if index >= self.data.len() { - return Err(out_of_bounds_error!()); + return Err(oob().into()); } let (total_bytes, compressed_length_size) = read_compressed_int_at(self.data, index)?; let data_start = index .checked_add(compressed_length_size) - .ok_or_else(|| malformed_error!("User string offset overflow at index {}", index))?; + .ok_or_else(|| corrupt(format!("offset overflow at index {index}")))?; if total_bytes == 0 { - return Err(malformed_error!( - "Invalid zero-length string at index {}", - index - )); + return Err(corrupt(format!("zero-length entry at index {index}")).into()); } if total_bytes == 1 { @@ -194,26 +203,26 @@ impl<'a> UserStrings<'a> { // So actual UTF-16 data is total_bytes - 1 let utf16_length = total_bytes .checked_sub(1) - .ok_or_else(|| malformed_error!("User string length underflow at index {}", index))?; + .ok_or_else(|| corrupt(format!("length underflow at index {index}")))?; let total_data_end = data_start .checked_add(total_bytes) - .ok_or_else(|| malformed_error!("User string end overflow at index {}", index))?; + .ok_or_else(|| corrupt(format!("end overflow at index {index}")))?; if total_data_end > self.data.len() { - return Err(out_of_bounds_error!()); + return Err(oob().into()); } if utf16_length % 2 != 0 { - return Err(malformed_error!("Invalid UTF-16 length at index {}", index)); + return Err(corrupt(format!("odd UTF-16 byte count at index {index}")).into()); } let utf16_data_end = data_start .checked_add(utf16_length) - .ok_or_else(|| malformed_error!("User string data end overflow at index {}", index))?; + .ok_or_else(|| corrupt(format!("data end overflow at index {index}")))?; let utf16_data = self .data .get(data_start..utf16_data_end) - .ok_or(out_of_bounds_error!())?; + .ok_or_else(|| Error::from(oob()))?; // Convert byte slice to u16 slice for UTF-16 string construction. // @@ -240,7 +249,7 @@ impl<'a> UserStrings<'a> { #[allow(clippy::cast_ptr_alignment)] core::ptr::slice_from_raw_parts(ptr.cast::(), utf16_data.len() / 2) .as_ref() - .ok_or_else(|| malformed_error!("null pointer in user string slice conversion"))? + .ok_or_else(|| corrupt("null pointer in user string slice conversion".into()))? }; Ok(U16Str::from_slice(str_slice)) diff --git a/dotscope/src/metadata/tables/genericparam/mod.rs b/dotscope/src/metadata/tables/genericparam/mod.rs index 985be414..96dd4517 100644 --- a/dotscope/src/metadata/tables/genericparam/mod.rs +++ b/dotscope/src/metadata/tables/genericparam/mod.rs @@ -179,15 +179,26 @@ impl GenericParamAttributes { /// specification. These bits should be zero in valid metadata. pub const RESERVED_MASK: Self = Self(0xFFC0); - /// Extract the variance bits from the flags. + /// Returns the variance classification of this generic parameter. /// - /// Returns the variance portion of the flags by masking with [`VARIANCE_MASK`](Self::VARIANCE_MASK). - /// The result can be compared with [`COVARIANT`](Self::COVARIANT) or - /// [`CONTRAVARIANT`](Self::CONTRAVARIANT). + /// Decodes the variance bits per ECMA-335 §II.22.20 and returns a typed + /// [`GenericParamVariance`] enum so callers can pattern-match exhaustively + /// instead of comparing against the [`COVARIANT`](Self::COVARIANT) / + /// [`CONTRAVARIANT`](Self::CONTRAVARIANT) constants by hand. + /// + /// The mapping is: `0x00` → [`Invariant`](GenericParamVariance::Invariant), + /// `0x01` → [`Covariant`](GenericParamVariance::Covariant), + /// `0x02` → [`Contravariant`](GenericParamVariance::Contravariant). Bit + /// pattern `0x03` is reserved/undefined; it is treated as `Invariant`. #[inline] #[must_use] - pub const fn variance(self) -> Self { - Self(self.0 & Self::VARIANCE_MASK.0) + pub const fn variance(self) -> GenericParamVariance { + match self.0 & Self::VARIANCE_MASK.0 { + 0x0001 => GenericParamVariance::Covariant, + 0x0002 => GenericParamVariance::Contravariant, + // 0x0000 (invariant) and 0x0003 (reserved/undefined) — both default to Invariant. + _ => GenericParamVariance::Invariant, + } } /// Extract the special constraint bits from the flags. @@ -208,9 +219,9 @@ impl GenericParamAttributes { #[must_use] pub fn variance_keyword(self) -> &'static str { match self.variance() { - Self::COVARIANT => "+", - Self::CONTRAVARIANT => "-", - _ => "", + GenericParamVariance::Covariant => "+", + GenericParamVariance::Contravariant => "-", + GenericParamVariance::Invariant => "", } } @@ -238,3 +249,115 @@ impl GenericParamAttributes { parts.join(" ") } } + +/// Variance classification of a generic parameter (ECMA-335 §II.22.20). +/// +/// Returned by [`GenericParamAttributes::variance`]. The three variants are +/// fixed by the standard and mutually exclusive — exactly one applies to any +/// given generic parameter. +/// +/// # Stability +/// +/// The string returned by [`GenericParamVariance::as_str`] and by the +/// [`Display`] impl is part of the stable public API. It is safe to persist +/// (file, database, log line) and to parse. +/// +/// [`Display`]: std::fmt::Display +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)] +pub enum GenericParamVariance { + /// No variance — exact type match required. + /// + /// Corresponds to ECMA-335 variance bits `0x00` (and the reserved/undefined + /// pattern `0x03`). + Invariant, + /// Covariant — type argument may be a more derived type. + /// + /// Corresponds to ECMA-335 variance bit `0x01` and the C# `out` keyword + /// (e.g. `IEnumerable`). + Covariant, + /// Contravariant — type argument may be a less derived type. + /// + /// Corresponds to ECMA-335 variance bit `0x02` and the C# `in` keyword + /// (e.g. `Action`). + Contravariant, +} + +impl GenericParamVariance { + /// Returns a stable `&'static str` identifier for this variance. + /// + /// Identifiers are lowercase (`"invariant"`, `"covariant"`, `"contravariant"`). + /// They are part of the stable public API and safe to persist. + #[must_use] + pub const fn as_str(&self) -> &'static str { + match self { + GenericParamVariance::Invariant => "invariant", + GenericParamVariance::Covariant => "covariant", + GenericParamVariance::Contravariant => "contravariant", + } + } +} + +impl std::fmt::Display for GenericParamVariance { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str(self.as_str()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_generic_param_attributes_variance() { + assert_eq!( + GenericParamAttributes::new(0).variance(), + GenericParamVariance::Invariant + ); + assert_eq!( + GenericParamAttributes::COVARIANT.variance(), + GenericParamVariance::Covariant + ); + assert_eq!( + GenericParamAttributes::CONTRAVARIANT.variance(), + GenericParamVariance::Contravariant + ); + // Reserved/undefined bit pattern 0x03 — defensively maps to Invariant. + assert_eq!( + GenericParamAttributes::new(0x0003).variance(), + GenericParamVariance::Invariant + ); + // Constraint bits in higher positions must not affect variance decoding. + let with_constraints = GenericParamAttributes::COVARIANT + | GenericParamAttributes::REFERENCE_TYPE_CONSTRAINT + | GenericParamAttributes::DEFAULT_CONSTRUCTOR_CONSTRAINT; + assert_eq!(with_constraints.variance(), GenericParamVariance::Covariant); + } + + #[test] + fn test_generic_param_variance_stable_strings() { + assert_eq!(GenericParamVariance::Invariant.as_str(), "invariant"); + assert_eq!(GenericParamVariance::Covariant.as_str(), "covariant"); + assert_eq!( + GenericParamVariance::Contravariant.as_str(), + "contravariant" + ); + + assert_eq!(format!("{}", GenericParamVariance::Invariant), "invariant"); + assert_eq!(format!("{}", GenericParamVariance::Covariant), "covariant"); + assert_eq!( + format!("{}", GenericParamVariance::Contravariant), + "contravariant" + ); + } + + #[test] + fn test_variance_keyword_after_enum_migration() { + // Pre-existing keyword behavior must survive the variance() return-type change. + assert_eq!(GenericParamAttributes::new(0).variance_keyword(), ""); + assert_eq!(GenericParamAttributes::COVARIANT.variance_keyword(), "+"); + assert_eq!( + GenericParamAttributes::CONTRAVARIANT.variance_keyword(), + "-" + ); + } +} diff --git a/dotscope/src/metadata/typesystem/primitives.rs b/dotscope/src/metadata/typesystem/primitives.rs index 5aac1417..e7bbdd07 100644 --- a/dotscope/src/metadata/typesystem/primitives.rs +++ b/dotscope/src/metadata/typesystem/primitives.rs @@ -99,9 +99,17 @@ use crate::{ }, utils::read_le, Error::{self, TypeConversionInvalid, TypeNotPrimitive}, - Result, + ParseFailure, ParseStage, Result, }; +#[inline] +fn primitives_oob() -> Error { + ParseFailure::OutOfBounds { + stage: ParseStage::Generic, + } + .into() +} + /// Type-safe storage for constant primitive values. /// /// `CilPrimitiveData` provides a unified storage mechanism for all .NET primitive constant @@ -402,14 +410,14 @@ impl CilPrimitiveData { pub fn from_bytes(type_byte: u8, data: &[u8]) -> Result { match type_byte { ELEMENT_TYPE::BOOLEAN => { - let b = data.first().ok_or(out_of_bounds_error!())?; + let b = data.first().ok_or(primitives_oob())?; Ok(CilPrimitiveData::Boolean(*b != 0)) } ELEMENT_TYPE::CHAR => { - let bytes = data.get(0..2).ok_or(out_of_bounds_error!())?; + let bytes = data.get(0..2).ok_or(primitives_oob())?; let code = u16::from_le_bytes([ - *bytes.first().ok_or(out_of_bounds_error!())?, - *bytes.get(1).ok_or(out_of_bounds_error!())?, + *bytes.first().ok_or(primitives_oob())?, + *bytes.get(1).ok_or(primitives_oob())?, ]); // .NET System.Char is a UTF-16 code unit, so any u16 value is valid Ok(CilPrimitiveData::Char(code)) @@ -432,24 +440,32 @@ impl CilPrimitiveData { } if !data.len().is_multiple_of(2) { - return Err(malformed_error!( - "Invalid UTF-16 string length: {} (must be even)", - data.len() - )); + return Err(ParseFailure::InvalidField { + stage: ParseStage::Generic, + field: "utf16_length", + reason: format!( + "invalid UTF-16 string length: {} (must be even)", + data.len() + ), + } + .into()); } let mut utf16_chars: Vec = Vec::with_capacity(data.len() / 2); for chunk in data.chunks_exact(2) { - let b0 = *chunk.first().ok_or(out_of_bounds_error!())?; - let b1 = *chunk.get(1).ok_or(out_of_bounds_error!())?; + let b0 = *chunk.first().ok_or(primitives_oob())?; + let b1 = *chunk.get(1).ok_or(primitives_oob())?; utf16_chars.push(u16::from_le_bytes([b0, b1])); } match String::from_utf16(&utf16_chars) { Ok(utf_string) => Ok(CilPrimitiveData::String(utf_string)), - Err(_) => Err(malformed_error!( - "Invalid UTF-16 sequence in primitive string" - )), + Err(_) => Err(ParseFailure::InvalidField { + stage: ParseStage::Generic, + field: "utf16_string", + reason: "invalid UTF-16 sequence in primitive string".into(), + } + .into()), } } ELEMENT_TYPE::CLASS => { @@ -756,6 +772,53 @@ impl CilPrimitiveKind { _ => Err(TypeNotPrimitive), } } + + /// Returns a stable `&'static str` identifier for this primitive kind. + /// + /// The strings follow ILAsm / ECMA-335 §I.8.2.2 conventions and are part + /// of the stable public API — safe to persist (file, database, log line) + /// and to parse. Pointer-sized integers and the structural kinds use the + /// short ILAsm names rather than the C# / `System.*` aliases: + /// + /// `"void"`, `"bool"`, `"char"`, `"int8"`, `"uint8"`, `"int16"`, + /// `"uint16"`, `"int32"`, `"uint32"`, `"int64"`, `"uint64"`, + /// `"float32"`, `"float64"`, `"native int"`, `"native uint"`, + /// `"object"`, `"string"`, `"null"`, `"typedref"`, `"valuetype"`, + /// `"var"`, `"mvar"`, `"class"`. + #[must_use] + pub const fn as_str(&self) -> &'static str { + match self { + CilPrimitiveKind::Void => "void", + CilPrimitiveKind::Boolean => "bool", + CilPrimitiveKind::Char => "char", + CilPrimitiveKind::I1 => "int8", + CilPrimitiveKind::U1 => "uint8", + CilPrimitiveKind::I2 => "int16", + CilPrimitiveKind::U2 => "uint16", + CilPrimitiveKind::I4 => "int32", + CilPrimitiveKind::U4 => "uint32", + CilPrimitiveKind::I8 => "int64", + CilPrimitiveKind::U8 => "uint64", + CilPrimitiveKind::R4 => "float32", + CilPrimitiveKind::R8 => "float64", + CilPrimitiveKind::I => "native int", + CilPrimitiveKind::U => "native uint", + CilPrimitiveKind::Object => "object", + CilPrimitiveKind::String => "string", + CilPrimitiveKind::Null => "null", + CilPrimitiveKind::TypedReference => "typedref", + CilPrimitiveKind::ValueType => "valuetype", + CilPrimitiveKind::Var => "var", + CilPrimitiveKind::MVar => "mvar", + CilPrimitiveKind::Class => "class", + } + } +} + +impl fmt::Display for CilPrimitiveKind { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + f.write_str(self.as_str()) + } } impl CilPrimitive { @@ -1527,6 +1590,7 @@ impl TryFrom for CilPrimitive { #[cfg(test)] mod tests { use super::*; + use crate::ParseFailure; use std::convert::TryFrom; #[test] @@ -2360,15 +2424,24 @@ mod tests { fn test_from_blob_error_cases() { let result = CilPrimitiveData::from_bytes(ELEMENT_TYPE::BOOLEAN, &[]); assert!(result.is_err()); - assert!(matches!(result, Err(Error::OutOfBounds { .. }))); + assert!(matches!( + result, + Err(Error::Parse(ParseFailure::OutOfBounds { .. })) + )); let result = CilPrimitiveData::from_bytes(ELEMENT_TYPE::CHAR, &[]); assert!(result.is_err()); - assert!(matches!(result, Err(Error::OutOfBounds { .. }))); + assert!(matches!( + result, + Err(Error::Parse(ParseFailure::OutOfBounds { .. })) + )); let result = CilPrimitiveData::from_bytes(ELEMENT_TYPE::I4, &[1, 2]); assert!(result.is_err()); - assert!(matches!(result, Err(Error::OutOfBounds { .. }))); + assert!(matches!( + result, + Err(Error::Parse(ParseFailure::OutOfBounds { .. })) + )); let result = CilPrimitiveData::from_bytes(ELEMENT_TYPE::STRING, &[]); assert!(result.is_ok()); @@ -2677,4 +2750,40 @@ mod tests { panic!("Expected R4 data"); } } + + #[test] + fn test_cil_primitive_kind_stable_strings() { + // The strings here are part of the stable public API. They follow + // ILAsm / ECMA-335 §I.8.2.2 conventions. Changing them is a breaking + // change for downstream consumers that persist these values. + let cases = [ + (CilPrimitiveKind::Void, "void"), + (CilPrimitiveKind::Boolean, "bool"), + (CilPrimitiveKind::Char, "char"), + (CilPrimitiveKind::I1, "int8"), + (CilPrimitiveKind::U1, "uint8"), + (CilPrimitiveKind::I2, "int16"), + (CilPrimitiveKind::U2, "uint16"), + (CilPrimitiveKind::I4, "int32"), + (CilPrimitiveKind::U4, "uint32"), + (CilPrimitiveKind::I8, "int64"), + (CilPrimitiveKind::U8, "uint64"), + (CilPrimitiveKind::R4, "float32"), + (CilPrimitiveKind::R8, "float64"), + (CilPrimitiveKind::I, "native int"), + (CilPrimitiveKind::U, "native uint"), + (CilPrimitiveKind::Object, "object"), + (CilPrimitiveKind::String, "string"), + (CilPrimitiveKind::Null, "null"), + (CilPrimitiveKind::TypedReference, "typedref"), + (CilPrimitiveKind::ValueType, "valuetype"), + (CilPrimitiveKind::Var, "var"), + (CilPrimitiveKind::MVar, "mvar"), + (CilPrimitiveKind::Class, "class"), + ]; + for (variant, expected) in cases { + assert_eq!(variant.as_str(), expected); + assert_eq!(format!("{variant}"), expected); + } + } } diff --git a/dotscope/src/metadata/validation/config.rs b/dotscope/src/metadata/validation/config.rs index 7143367b..4b2531e2 100644 --- a/dotscope/src/metadata/validation/config.rs +++ b/dotscope/src/metadata/validation/config.rs @@ -21,6 +21,34 @@ //! 1. **Raw Validation**: Validates raw assembly data during [`crate::metadata::cilassemblyview::CilAssemblyView`] loading //! 2. **Owned Validation**: Validates resolved data structures during [`crate::metadata::cilobject::CilObject`] creation //! +//! # Field-to-Validator Map +//! +//! Each `enable_*` field on [`ValidationConfig`] gates a specific group of +//! validators. The table below summarizes what each field controls, what kind +//! of malformed input it catches, and roughly how expensive it is. Use this as +//! ground truth when picking a preset or building a custom config. +//! +//! | Field | Stage | Gates | Catches | Cost class | +//! |------------------------------------|-----------------|----------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------|-----------------------------------------| +//! | `enable_raw_validation` | meta (stage 1) | enables/disables the entire raw pipeline | n/a — gates the structural/token/constraint validators below | n/a | +//! | `enable_owned_validation` | meta (stage 2) | enables/disables the entire owned pipeline | n/a — gates the cross-table/semantic/method validators below | n/a | +//! | `enable_structural_validation` | raw (1) | `RawTokenValidator`, `RawTableValidator`, `RawHeapValidator` | malformed token format, RID overflow, table-structure breakage, heap offsets out of bounds | cheap (linear scan) | +//! | `enable_token_validation` | raw (1) | `RawSignatureValidator` | malformed signature blobs, invalid coded-index tag values | cheap (linear scan) | +//! | `enable_constraint_validation` | raw (1) + owned (2) | `RawGenericConstraintValidator`, `RawLayoutConstraintValidator` | invalid constraint targets, circular constraints, layout overlap violations on `[StructLayout(LayoutKind.Explicit)]` | cheap–moderate (table walk) | +//! | `enable_cross_table_validation` | owned (2) | `OwnedCircularityValidator`, `OwnedDependencyValidator`, `OwnedOwnershipValidator` | broken cross-references, circular type hierarchies, orphaned metadata | moderate (table walk + graph analysis) | +//! | `enable_semantic_validation` | owned (2) | `OwnedFieldValidator`, `OwnedAccessibilityValidator`, `OwnedTypeDefinitionValidator`, `OwnedTypeCircularityValidator`, `OwnedTypeDependencyValidator`, `OwnedTypeOwnershipValidator`, `OwnedAttributeValidator`, `OwnedSecurityValidator`, `OwnedAssemblyValidator` | ECMA-335 semantic rules: access-modifier breaches, SpecialName violations, abstract/sealed conflicts, naming convention breaches, duplicate fields | moderate–expensive (multi-pass type-system walk) | +//! | `enable_method_validation` | owned (2) | `OwnedMethodValidator`, `OwnedSignatureValidator` | concrete types declaring abstract methods, final-override attempts, signature incompatibilities, invalid constructors | moderate (signature resolution + inheritance walk) | +//! | `enable_type_system_validation` | owned (2) | reserved — currently subsumed by `enable_semantic_validation` | n/a until wired | n/a | +//! | `enable_field_layout_validation` | owned (2) | reserved — currently subsumed by `enable_owned_validation` | n/a until wired | n/a | +//! | `max_nesting_depth` | owned (2) | nested-type depth ceiling | over-deep nested-class chains (default `64`; set `0` to disable the check) | cheap (counter) | +//! | `lenient` | both stages | error handling mode | n/a — when `true`, errors become diagnostics instead of aborting load | n/a | +//! +//! Two of the fields above (`enable_type_system_validation`, +//! `enable_field_layout_validation`) are accepted for forward compatibility but +//! not yet observed by any validator's `should_run()` check. Setting them is +//! safe; the validators they would gate are currently controlled by +//! `enable_semantic_validation` / `enable_owned_validation`. +//! //! # Key Components //! //! - [`crate::metadata::validation::config::ValidationConfig`] - Main configuration struct with predefined presets @@ -124,12 +152,24 @@ pub struct ValidationConfig { /// Validates semantic consistency across metadata tables pub enable_cross_table_validation: bool, - /// Enable field layout validation (overlap detection, offset validation) - /// Only useful for types with explicit layout; detects problematic overlaps + /// Enable field layout validation (overlap detection, offset validation). + /// + /// Only useful for types with explicit layout; detects problematic overlaps. + /// + /// **Reserved for forward compatibility.** No validator currently checks this + /// flag in `should_run()`. Field-layout validation is presently controlled by + /// `enable_owned_validation`. Setting this flag is safe but has no effect + /// until the dedicated layout validators are wired through it. pub enable_field_layout_validation: bool, - /// Enable type system validation (inheritance chains, generic constraints) - /// Validates logical consistency of type hierarchies and generic constraints + /// Enable type system validation (inheritance chains, generic constraints). + /// + /// Validates logical consistency of type hierarchies and generic constraints. + /// + /// **Reserved for forward compatibility.** No validator currently checks this + /// flag in `should_run()`. Type-system validation is presently controlled by + /// `enable_semantic_validation`. Setting this flag is safe but has no effect + /// until the dedicated type-system validators are wired through it. pub enable_type_system_validation: bool, /// Enable semantic validation (method consistency, access modifiers, abstract/concrete rules) @@ -208,6 +248,12 @@ impl ValidationConfig { /// - Potential for crashes on invalid data /// - Silent acceptance of ECMA-335 violations /// + /// # Field values + /// + /// All validation fields are `false`, including the stage gates. Every + /// `enable_*` flag is off, `lenient = false`, `max_nesting_depth = 0`. + /// The two pipelines never run. + /// /// # Examples /// /// ```rust,no_run @@ -259,6 +305,16 @@ impl ValidationConfig { /// - Semantic rule enforcement /// - Method signature validation /// + /// # Field values + /// + /// - **Stage gates:** `enable_raw_validation = true`, `enable_owned_validation = false` + /// - **Raw validators on:** `enable_structural_validation` + /// - **Raw validators off:** `enable_token_validation`, `enable_constraint_validation` + /// - **Owned validators off:** `enable_cross_table_validation`, + /// `enable_field_layout_validation`, `enable_type_system_validation`, + /// `enable_semantic_validation`, `enable_method_validation` + /// - **Other:** `lenient = false`, `max_nesting_depth = 64` + /// /// # Examples /// /// ```rust,no_run @@ -298,6 +354,17 @@ impl ValidationConfig { /// /// Returns a [`crate::metadata::validation::config::ValidationConfig`] with all validation enabled. /// + /// # Field values + /// + /// - **Stage gates:** `enable_raw_validation = true`, `enable_owned_validation = true` + /// - **All `enable_*` validators on** + /// - **Other:** `lenient = false`, `max_nesting_depth = 64` + /// + /// Identical in field values to [`production`](Self::production) and + /// [`strict`](Self::strict); the three exist as named entry points so + /// downstream code can document intent. Pick whichever name best matches + /// the calling context. + /// /// # Examples /// /// ```rust,no_run @@ -336,6 +403,16 @@ impl ValidationConfig { /// - Token: Runtime validates token references for security /// - Constraint: Runtime validates generic and layout constraints /// + /// # Field values + /// + /// - **Stage gates:** `enable_raw_validation = true`, `enable_owned_validation = true` + /// - **All `enable_*` validators on** + /// - **Other:** `lenient = false`, `max_nesting_depth = 64` + /// + /// Identical in field values to [`comprehensive`](Self::comprehensive) and + /// [`strict`](Self::strict). Use `production` when you specifically want + /// to communicate that you are matching .NET runtime validation behavior. + /// /// # Examples /// /// ```rust,no_run @@ -375,6 +452,17 @@ impl ValidationConfig { /// /// Returns a [`crate::metadata::validation::config::ValidationConfig`] with strict validation enabled. /// + /// # Field values + /// + /// - **Stage gates:** `enable_raw_validation = true`, `enable_owned_validation = true` + /// - **All `enable_*` validators on** + /// - **Other:** `lenient = false`, `max_nesting_depth = 64` + /// + /// Identical in field values to [`comprehensive`](Self::comprehensive) and + /// [`production`](Self::production). The strict alias exists to call out + /// that field-layout validation can flag legitimate overlapping fields + /// (see note below) — pick this name when that risk is acceptable. + /// /// # Examples /// /// ```rust,no_run @@ -436,6 +524,14 @@ impl ValidationConfig { /// /// Returns a [`crate::metadata::validation::config::ValidationConfig`] configured for raw validation only. /// + /// # Field values + /// + /// - **Stage gates:** `enable_raw_validation = true`, `enable_owned_validation = false` + /// - **Raw validators on:** `enable_structural_validation` + /// - **Raw validators off:** `enable_token_validation`, `enable_constraint_validation` + /// - **Owned validators all off** + /// - **Other:** `lenient = false`, `max_nesting_depth = 64` + /// /// # Examples /// /// ```rust,no_run @@ -473,6 +569,22 @@ impl ValidationConfig { /// /// Returns a [`crate::metadata::validation::config::ValidationConfig`] configured for owned validation only. /// + /// # Field values + /// + /// - **Stage gates:** `enable_raw_validation = false`, `enable_owned_validation = true` + /// - **Raw validators all off** (`enable_structural_validation`, + /// `enable_token_validation` are skipped along with the stage) + /// - **Owned validators on:** `enable_cross_table_validation`, + /// `enable_field_layout_validation`, `enable_type_system_validation`, + /// `enable_semantic_validation`, `enable_method_validation`, + /// `enable_constraint_validation` + /// - **Other:** `lenient = false`, `max_nesting_depth = 64` + /// + /// Caller is responsible for guaranteeing the raw stage already passed + /// (or is being deliberately skipped). Loading malformed metadata with + /// the raw stage off may surface garbage as validation errors at + /// stage 2 instead of clean parse failures. + /// /// # Examples /// /// ```rust,no_run @@ -520,6 +632,18 @@ impl ValidationConfig { /// - **Comprehensive validation**: ALL validation checks enabled to collect maximum diagnostic info /// - **Complete error collection**: Continues through all checks to build full diagnostic report /// + /// # Field values + /// + /// - **Stage gates:** `enable_raw_validation = true`, `enable_owned_validation = true` + /// - **All `enable_*` validators on** (same as [`comprehensive`](Self::comprehensive)) + /// - **Other:** `lenient = true` (the load-bearing differentiator), + /// `max_nesting_depth = 64` + /// + /// The single distinguishing field versus `comprehensive`/`production`/`strict` + /// is `lenient = true`: errors from the loader and the validation engine + /// flow into [`crate::metadata::diagnostics::Diagnostics`] instead of + /// short-circuiting the load. + /// /// # Use Cases /// /// - Analyzing obfuscated assemblies (ConfuserEx, etc.) diff --git a/dotscope/src/metadata/validation/scanner.rs b/dotscope/src/metadata/validation/scanner.rs index 5e6fe058..a9b27312 100644 --- a/dotscope/src/metadata/validation/scanner.rs +++ b/dotscope/src/metadata/validation/scanner.rs @@ -67,7 +67,7 @@ use crate::{ }, token::Token, }, - Blob, Error, Guid, Result, Strings, UserStrings, + Blob, Error, Guid, HeapKind, ParseFailure, Result, Strings, UserStrings, }; use rustc_hash::{FxHashMap, FxHashSet}; @@ -816,19 +816,23 @@ impl ReferenceScanner { "blobs" => self.heap_sizes.blobs, "guids" => self.heap_sizes.guids, "userstrings" => self.heap_sizes.userstrings, - _ => { - return Err(Error::HeapBoundsError { - heap: heap_type.to_string(), - index, - }) + other => { + return Err(Error::Parse(ParseFailure::Other { + stage: crate::ParseStage::Validation, + message: format!("unknown heap '{other}' for index {index}"), + })) } }; if index >= max_size { - return Err(Error::HeapBoundsError { - heap: heap_type.to_string(), - index, - }); + let heap = match heap_type { + "strings" => HeapKind::Strings, + "blobs" => HeapKind::Blob, + "guids" => HeapKind::Guid, + "userstrings" => HeapKind::UserStrings, + _ => unreachable!("heap_type already validated above"), + }; + return Err(Error::Parse(ParseFailure::HeapOutOfBounds { heap, index })); } Ok(()) diff --git a/dotscope/src/metadata/validation/shared/mod.rs b/dotscope/src/metadata/validation/shared/mod.rs index a1064732..8fa45bf9 100644 --- a/dotscope/src/metadata/validation/shared/mod.rs +++ b/dotscope/src/metadata/validation/shared/mod.rs @@ -66,3 +66,15 @@ mod tokens; pub use references::ReferenceValidator; pub use schema::SchemaValidator; pub use tokens::TokenValidator; + +/// Returns the canonical [`crate::Error::Parse`] error used by raw validators +/// when the assembly view does not contain the expected `#~` metadata-tables +/// stream. Centralized so all 14+ raw-validator call sites surface the same +/// structured failure. +pub(crate) fn err_no_metadata_tables() -> crate::Error { + crate::Error::Parse(crate::ParseFailure::InvalidField { + stage: crate::ParseStage::Validation, + field: "metadata_tables", + reason: "assembly view does not contain metadata tables".into(), + }) +} diff --git a/dotscope/src/metadata/validation/shared/schema.rs b/dotscope/src/metadata/validation/shared/schema.rs index 9d6f7690..8665adb9 100644 --- a/dotscope/src/metadata/validation/shared/schema.rs +++ b/dotscope/src/metadata/validation/shared/schema.rs @@ -72,7 +72,7 @@ use crate::{ ScannerStatistics, }, }, - Error, Result, + Error, HeapKind, ParseFailure, Result, }; /// Shared schema validation utilities. @@ -425,10 +425,10 @@ impl<'a> SchemaValidator<'a> { let max_index = guid_heap_size / 16; // Each GUID is 16 bytes if index > max_index { - return Err(Error::HeapBoundsError { - heap: "guids".to_string(), + return Err(Error::Parse(ParseFailure::HeapOutOfBounds { + heap: HeapKind::Guid, index, - }); + })); } Ok(()) diff --git a/dotscope/src/metadata/validation/validators/raw/constraints/generic.rs b/dotscope/src/metadata/validation/validators/raw/constraints/generic.rs index e54a11cc..fbb7dbd6 100644 --- a/dotscope/src/metadata/validation/validators/raw/constraints/generic.rs +++ b/dotscope/src/metadata/validation/validators/raw/constraints/generic.rs @@ -81,6 +81,7 @@ use crate::{ }, validation::{ context::{RawValidationContext, ValidationContext}, + shared::err_no_metadata_tables, traits::RawValidator, }, }, @@ -145,9 +146,7 @@ impl RawGenericConstraintValidator { /// - Owner coded index references are null (row = 0) /// - Name references are null (name = 0) fn validate_generic_parameters(assembly_view: &CilAssemblyView) -> Result<()> { - let tables = assembly_view - .tables() - .ok_or_else(|| malformed_error!("Assembly view does not contain metadata tables"))?; + let tables = assembly_view.tables().ok_or_else(err_no_metadata_tables)?; if let Some(generic_param_table) = tables.table::() { for generic_param in generic_param_table { @@ -200,9 +199,7 @@ impl RawGenericConstraintValidator { /// - Constraint coded index references are null (constraint.row = 0) /// - Owner references exceed GenericParam table row count fn validate_parameter_constraints(assembly_view: &CilAssemblyView) -> Result<()> { - let tables = assembly_view - .tables() - .ok_or_else(|| malformed_error!("Assembly view does not contain metadata tables"))?; + let tables = assembly_view.tables().ok_or_else(err_no_metadata_tables)?; if let Some(constraint_table) = tables.table::() { let generic_param_table = tables.table::(); @@ -259,9 +256,7 @@ impl RawGenericConstraintValidator { /// - Constraint owners reference non-existent GenericParam RIDs /// - Cross-table references are inconsistent between GenericParamConstraint and GenericParam tables fn validate_constraint_inheritance(assembly_view: &CilAssemblyView) -> Result<()> { - let tables = assembly_view - .tables() - .ok_or_else(|| malformed_error!("Assembly view does not contain metadata tables"))?; + let tables = assembly_view.tables().ok_or_else(err_no_metadata_tables)?; if let (Some(generic_param_table), Some(constraint_table)) = ( tables.table::(), @@ -308,9 +303,7 @@ impl RawGenericConstraintValidator { /// - TypeRef constraint references exceed TypeRef table bounds /// - TypeSpec constraint references exceed TypeSpec table bounds fn validate_constraint_types(assembly_view: &CilAssemblyView) -> Result<()> { - let tables = assembly_view - .tables() - .ok_or_else(|| malformed_error!("Assembly view does not contain metadata tables"))?; + let tables = assembly_view.tables().ok_or_else(err_no_metadata_tables)?; if let Some(constraint_table) = tables.table::() { for constraint in constraint_table { @@ -408,9 +401,7 @@ impl RawGenericConstraintValidator { /// - Reserved flag bits are set /// - Variance flags used inappropriately (method vs type parameters) fn validate_parameter_flags(assembly_view: &CilAssemblyView) -> Result<()> { - let tables = assembly_view - .tables() - .ok_or_else(|| malformed_error!("Assembly view does not contain metadata tables"))?; + let tables = assembly_view.tables().ok_or_else(err_no_metadata_tables)?; if let Some(generic_param_table) = tables.table::() { for generic_param in generic_param_table { @@ -535,7 +526,7 @@ mod tests { validator_test( raw_generic_constraint_validator_file_factory, "RawGenericConstraintValidator", - "Malformed", + "Parse", config, |context| validator.validate_raw(context), ) diff --git a/dotscope/src/metadata/validation/validators/raw/constraints/layout.rs b/dotscope/src/metadata/validation/validators/raw/constraints/layout.rs index f18a1d7c..0e061ddb 100644 --- a/dotscope/src/metadata/validation/validators/raw/constraints/layout.rs +++ b/dotscope/src/metadata/validation/validators/raw/constraints/layout.rs @@ -77,6 +77,7 @@ use crate::{ tables::{ClassLayoutRaw, FieldLayoutRaw, FieldRaw, TypeDefRaw}, validation::{ context::{RawValidationContext, ValidationContext}, + shared::err_no_metadata_tables, traits::RawValidator, }, }, @@ -144,9 +145,7 @@ impl RawLayoutConstraintValidator { /// - Field references are invalid or null (zero field reference) /// - Field references exceed Field table row count fn validate_field_layouts(assembly_view: &CilAssemblyView) -> Result<()> { - let tables = assembly_view - .tables() - .ok_or_else(|| malformed_error!("Assembly view does not contain metadata tables"))?; + let tables = assembly_view.tables().ok_or_else(err_no_metadata_tables)?; if let Some(field_layout_table) = tables.table::() { let mut field_offsets: FxHashMap> = FxHashMap::default(); @@ -225,9 +224,7 @@ impl RawLayoutConstraintValidator { /// - Parent type references are invalid (null or exceed TypeDef table row count) /// - Layout constraints are malformed fn validate_class_layouts(assembly_view: &CilAssemblyView) -> Result<()> { - let tables = assembly_view - .tables() - .ok_or_else(|| malformed_error!("Assembly view does not contain metadata tables"))?; + let tables = assembly_view.tables().ok_or_else(err_no_metadata_tables)?; if let Some(class_layout_table) = tables.table::() { let typedef_table = tables.table::(); @@ -304,9 +301,7 @@ impl RawLayoutConstraintValidator { /// - Field layouts exceed reasonable offset bounds (>1MB suggesting corruption) /// - ClassLayout parent references point to non-existent TypeDef entries fn validate_layout_consistency(assembly_view: &CilAssemblyView) -> Result<()> { - let tables = assembly_view - .tables() - .ok_or_else(|| malformed_error!("Assembly view does not contain metadata tables"))?; + let tables = assembly_view.tables().ok_or_else(err_no_metadata_tables)?; if let (Some(class_layout_table), Some(field_layout_table), Some(typedef_table)) = ( tables.table::(), @@ -419,9 +414,7 @@ impl RawLayoutConstraintValidator { /// - Field layouts violate natural alignment requirements /// - Explicit layout fields have unreasonable spacing fn validate_field_alignment(assembly_view: &CilAssemblyView) -> Result<()> { - let tables = assembly_view - .tables() - .ok_or_else(|| malformed_error!("Assembly view does not contain metadata tables"))?; + let tables = assembly_view.tables().ok_or_else(err_no_metadata_tables)?; if let (Some(field_layout_table), Some(_field_table)) = (tables.table::(), tables.table::()) @@ -480,9 +473,7 @@ impl RawLayoutConstraintValidator { /// - Value type packing constraints are inappropriate /// - Value type field layouts create alignment issues fn validate_value_type_layouts(assembly_view: &CilAssemblyView) -> Result<()> { - let tables = assembly_view - .tables() - .ok_or_else(|| malformed_error!("Assembly view does not contain metadata tables"))?; + let tables = assembly_view.tables().ok_or_else(err_no_metadata_tables)?; if let (Some(class_layout_table), Some(typedef_table)) = ( tables.table::(), @@ -540,9 +531,7 @@ impl RawLayoutConstraintValidator { /// * `Ok(())` - All sequential layouts are valid /// * `Err(`[`crate::Error::ValidationRawFailed`]`)` - Sequential layout violations found fn validate_sequential_layout(assembly_view: &CilAssemblyView) -> Result<()> { - let tables = assembly_view - .tables() - .ok_or_else(|| malformed_error!("Assembly view does not contain metadata tables"))?; + let tables = assembly_view.tables().ok_or_else(err_no_metadata_tables)?; if let Some(field_layout_table) = tables.table::() { let field_layouts: Vec<_> = field_layout_table.iter().collect(); @@ -679,7 +668,7 @@ mod tests { validator_test( raw_layout_constraint_validator_file_factory, "RawLayoutConstraintValidator", - "Malformed", + "Parse", config, |context| validator.validate_raw(context), ) diff --git a/dotscope/src/metadata/validation/validators/raw/modification/integrity.rs b/dotscope/src/metadata/validation/validators/raw/modification/integrity.rs index f939a604..580ca752 100644 --- a/dotscope/src/metadata/validation/validators/raw/modification/integrity.rs +++ b/dotscope/src/metadata/validation/validators/raw/modification/integrity.rs @@ -1069,7 +1069,7 @@ mod tests { validator_test( raw_change_integrity_validator_file_factory, "RawChangeIntegrityValidator", - "Malformed", + "Parse", config, |context| validator.validate_raw(context), ) diff --git a/dotscope/src/metadata/validation/validators/raw/modification/operation.rs b/dotscope/src/metadata/validation/validators/raw/modification/operation.rs index 92c9df58..f376ae81 100644 --- a/dotscope/src/metadata/validation/validators/raw/modification/operation.rs +++ b/dotscope/src/metadata/validation/validators/raw/modification/operation.rs @@ -632,7 +632,7 @@ mod tests { validator_test( raw_operation_validator_file_factory, "RawOperationValidator", - "Malformed", + "Parse", config, |context| validator.validate_raw(context), ) diff --git a/dotscope/src/metadata/validation/validators/raw/structure/heap.rs b/dotscope/src/metadata/validation/validators/raw/structure/heap.rs index 0e9ed73d..0be72934 100644 --- a/dotscope/src/metadata/validation/validators/raw/structure/heap.rs +++ b/dotscope/src/metadata/validation/validators/raw/structure/heap.rs @@ -647,7 +647,7 @@ mod tests { validator_test( raw_heap_validator_file_factory, "RawHeapValidator", - "Malformed", + "Parse", config, |context| validator.validate_raw(context), ) @@ -668,7 +668,7 @@ mod tests { let result_disabled = validator_test( clean_only_factory, "RawHeapValidator", - "Malformed", + "Parse", ValidationConfig { enable_structural_validation: false, ..Default::default() @@ -691,7 +691,7 @@ mod tests { let result_enabled = validator_test( clean_only_factory, "RawHeapValidator", - "Malformed", + "Parse", ValidationConfig { enable_structural_validation: true, ..Default::default() diff --git a/dotscope/src/metadata/validation/validators/raw/structure/table.rs b/dotscope/src/metadata/validation/validators/raw/structure/table.rs index d291acfe..0c3cef41 100644 --- a/dotscope/src/metadata/validation/validators/raw/structure/table.rs +++ b/dotscope/src/metadata/validation/validators/raw/structure/table.rs @@ -79,6 +79,7 @@ use crate::{ tables::{AssemblyRaw, FieldRaw, MethodDefRaw, ModuleRaw, TableId, TypeDefRaw}, validation::{ context::{RawValidationContext, ValidationContext}, + shared::err_no_metadata_tables, traits::RawValidator, }, }, @@ -144,9 +145,7 @@ impl RawTableValidator { /// - Module table is present but contains zero rows (at least one required) /// - Assembly table contains more than one row (ECMA-335 limit violation) fn validate_required_tables(assembly_view: &CilAssemblyView) -> Result<()> { - let tables = assembly_view - .tables() - .ok_or_else(|| malformed_error!("Assembly view does not contain metadata tables"))?; + let tables = assembly_view.tables().ok_or_else(err_no_metadata_tables)?; let module_table = tables.table::().ok_or_else(|| { malformed_error!("Module table is required but not present in assembly") @@ -192,9 +191,7 @@ impl RawTableValidator { /// - RID values within table rows are inconsistent with expected sequential numbering /// - Internal table structure inconsistencies are detected during iteration fn validate_table_structures(assembly_view: &CilAssemblyView) -> Result<()> { - let tables = assembly_view - .tables() - .ok_or_else(|| malformed_error!("Assembly view does not contain metadata tables"))?; + let tables = assembly_view.tables().ok_or_else(err_no_metadata_tables)?; for table_id in TableId::iter() { dispatch_table_type!(table_id, |RawType| { @@ -238,9 +235,7 @@ impl RawTableValidator { /// - TypeDef method list references exceed MethodDef table row count /// - List-based cross-table references are out of bounds fn validate_table_dependencies(assembly_view: &CilAssemblyView) -> Result<()> { - let tables = assembly_view - .tables() - .ok_or_else(|| malformed_error!("Assembly view does not contain metadata tables"))?; + let tables = assembly_view.tables().ok_or_else(err_no_metadata_tables)?; if let (Some(typedef_table), Some(field_table)) = (tables.table::(), tables.table::()) @@ -397,7 +392,7 @@ mod tests { validator_test( raw_table_validator_file_factory, "RawTableValidator", - "Malformed", + "Parse", config, |context| validator.validate_raw(context), ) @@ -421,7 +416,7 @@ mod tests { let result_disabled = validator_test( clean_only_factory, "RawTableValidator", - "Malformed", + "Parse", ValidationConfig { enable_structural_validation: false, ..Default::default() @@ -444,7 +439,7 @@ mod tests { let result_enabled = validator_test( clean_only_factory, "RawTableValidator", - "Malformed", + "Parse", ValidationConfig { enable_structural_validation: true, ..Default::default() diff --git a/dotscope/src/prelude.rs b/dotscope/src/prelude.rs index 9c01f51f..3437ccd4 100644 --- a/dotscope/src/prelude.rs +++ b/dotscope/src/prelude.rs @@ -186,6 +186,13 @@ /// Provides detailed error context for debugging and user-friendly error messages. pub use crate::Error; +/// Structured parse-pipeline failure carried inside [`Error::Parse`]. +/// +/// Lets consumers categorize parse failures (truncated headers, bad magic, +/// unsupported schemas, heap corruption, invalid fields) without parsing +/// string messages. +pub use crate::{HeapKind, ParseFailure, ParseStage, StreamKind}; + /// The result type used throughout dotscope APIs. /// /// Standard `Result` type alias for consistent error handling across the library. diff --git a/dotscope/src/test/analysis/runner.rs b/dotscope/src/test/analysis/runner.rs index f1c7ea6d..bcd2bb4c 100644 --- a/dotscope/src/test/analysis/runner.rs +++ b/dotscope/src/test/analysis/runner.rs @@ -305,8 +305,8 @@ impl AnalysisTestRunner { // Get the method let method = match assembly.method(&token) { - Some(method) => method, - None => { + Ok(method) => method, + Err(_) => { return AnalysisTestResult::run_failed( test_case.name, format!("Method with token {:?} not in method table", token), diff --git a/dotscope/src/test/builders/methods.rs b/dotscope/src/test/builders/methods.rs index 819f67e4..ac82d773 100644 --- a/dotscope/src/test/builders/methods.rs +++ b/dotscope/src/test/builders/methods.rs @@ -160,6 +160,11 @@ impl MethodBuilder { self } + pub fn with_impl_code_type(mut self, impl_code_type: MethodImplCodeType) -> Self { + self.impl_code_type = impl_code_type; + self + } + /// Add a typed input parameter by name and type flavor pub fn add_typed_param(mut self, name: &str, type_flavor: CilFlavor) -> Self { let sequence = self.param_builders.len() as u32 + 1; diff --git a/dotscope/src/test/factories/validation/raw_constraints_generic.rs b/dotscope/src/test/factories/validation/raw_constraints_generic.rs index 34e32d13..0bd42f31 100644 --- a/dotscope/src/test/factories/validation/raw_constraints_generic.rs +++ b/dotscope/src/test/factories/validation/raw_constraints_generic.rs @@ -56,7 +56,7 @@ pub fn raw_generic_constraint_validator_file_factory() -> Result Result { - create_test_assembly_with_error(get_testfile_wb, "Malformed", |assembly| { + create_test_assembly_with_error(get_testfile_wb, "Parse", |assembly| { // Create a valid generic parameter first let typedef_token = TypeDefBuilder::new() .name("GenericType") @@ -104,7 +104,7 @@ pub fn create_assembly_with_null_constraint_owner() -> Result { /// /// Originally from: `src/metadata/validation/validators/raw/constraints/generic.rs` pub fn create_assembly_with_constraint_owner_exceeding_bounds() -> Result { - create_test_assembly_with_error(get_testfile_wb, "Malformed", |assembly| { + create_test_assembly_with_error(get_testfile_wb, "Parse", |assembly| { // Create a valid generic parameter first let typedef_token = TypeDefBuilder::new() .name("GenericType") @@ -152,7 +152,7 @@ pub fn create_assembly_with_constraint_owner_exceeding_bounds() -> Result Result { - create_test_assembly_with_error(get_testfile_wb, "Malformed", |assembly| { + create_test_assembly_with_error(get_testfile_wb, "Parse", |assembly| { let typedef_builder = TypeDefBuilder::new() .name("GenericType") .namespace("Test") diff --git a/dotscope/src/test/factories/validation/raw_constraints_layout.rs b/dotscope/src/test/factories/validation/raw_constraints_layout.rs index 5a6fb09d..8cb94692 100644 --- a/dotscope/src/test/factories/validation/raw_constraints_layout.rs +++ b/dotscope/src/test/factories/validation/raw_constraints_layout.rs @@ -59,7 +59,7 @@ pub fn raw_layout_constraint_validator_file_factory() -> Result Result { - create_test_assembly_with_error(get_testfile_wb, "Malformed", |assembly| { + create_test_assembly_with_error(get_testfile_wb, "Parse", |assembly| { // Create a basic type first let _typedef_token = TypeDefBuilder::new() .name("OverlappingFieldsType") @@ -103,7 +103,7 @@ pub fn create_assembly_with_overlapping_fields() -> Result { /// /// Originally from: `src/metadata/validation/validators/raw/constraints/layout.rs` pub fn create_assembly_with_invalid_packing_size() -> Result { - create_test_assembly_with_error(get_testfile_wb, "Malformed", |assembly| { + create_test_assembly_with_error(get_testfile_wb, "Parse", |assembly| { let typedef_token = TypeDefBuilder::new() .name("InvalidPackingType") .namespace("Test") @@ -134,7 +134,7 @@ pub fn create_assembly_with_invalid_packing_size() -> Result { /// /// Originally from: `src/metadata/validation/validators/raw/constraints/layout.rs` pub fn create_assembly_with_excessive_class_size() -> Result { - create_test_assembly_with_error(get_testfile_wb, "Malformed", |assembly| { + create_test_assembly_with_error(get_testfile_wb, "Parse", |assembly| { let typedef_token = TypeDefBuilder::new() .name("ExcessiveSizeType") .namespace("Test") @@ -165,7 +165,7 @@ pub fn create_assembly_with_excessive_class_size() -> Result { /// /// Originally from: `src/metadata/validation/validators/raw/constraints/layout.rs` pub fn create_assembly_with_invalid_field_offset() -> Result { - create_test_assembly_with_error(get_testfile_wb, "Malformed", |assembly| { + create_test_assembly_with_error(get_testfile_wb, "Parse", |assembly| { let _typedef_token = TypeDefBuilder::new() .name("InvalidOffsetType") .namespace("Test") @@ -205,7 +205,7 @@ pub fn create_assembly_with_invalid_field_offset() -> Result { /// /// Originally from: `src/metadata/validation/validators/raw/constraints/layout.rs` pub fn create_assembly_with_null_field_reference() -> Result { - create_test_assembly_with_error(get_testfile_wb, "Malformed", |assembly| { + create_test_assembly_with_error(get_testfile_wb, "Parse", |assembly| { let _typedef_token = TypeDefBuilder::new() .name("NullFieldRefType") .namespace("Test") @@ -239,7 +239,7 @@ pub fn create_assembly_with_null_field_reference() -> Result { /// /// Originally from: `src/metadata/validation/validators/raw/constraints/layout.rs` pub fn create_assembly_with_boundary_field_offset() -> Result { - create_test_assembly_with_error(get_testfile_wb, "Malformed", |assembly| { + create_test_assembly_with_error(get_testfile_wb, "Parse", |assembly| { let _typedef_token = TypeDefBuilder::new() .name("BoundaryOffsetType") .namespace("Test") diff --git a/dotscope/src/test/factories/validation/raw_structure_table.rs b/dotscope/src/test/factories/validation/raw_structure_table.rs index f25c07c8..0e3eb505 100644 --- a/dotscope/src/test/factories/validation/raw_structure_table.rs +++ b/dotscope/src/test/factories/validation/raw_structure_table.rs @@ -52,7 +52,7 @@ pub fn raw_table_validator_file_factory() -> Result> { /// /// Originally from: `src/metadata/validation/validators/raw/structure/table.rs` pub fn create_assembly_with_empty_module_table() -> Result { - create_test_assembly_with_error(get_testfile_mscorlib, "Malformed", |assembly| { + create_test_assembly_with_error(get_testfile_mscorlib, "Parse", |assembly| { // Delete the Module table row entirely - this will reduce row_count to 0 match assembly.table_row_remove(TableId::Module, 1) { Ok(()) => { @@ -77,7 +77,7 @@ pub fn create_assembly_with_empty_module_table() -> Result { /// /// Originally from: `src/metadata/validation/validators/raw/structure/table.rs` pub fn create_assembly_with_multiple_assembly_rows() -> Result { - create_test_assembly_with_error(get_testfile_mscorlib, "Malformed", |assembly| { + create_test_assembly_with_error(get_testfile_mscorlib, "Parse", |assembly| { // Create a second Assembly row which violates ECMA-335 "at most 1 row" constraint // Use add_table_row to actually add a second row (increasing row_count to 2) let duplicate_assembly = AssemblyRaw { @@ -112,7 +112,7 @@ pub fn create_assembly_with_multiple_assembly_rows() -> Result { /// /// Originally from: `src/metadata/validation/validators/raw/structure/table.rs` pub fn create_assembly_with_field_list_violation() -> Result { - create_test_assembly_with_error(get_testfile_mscorlib, "Malformed", |assembly| { + create_test_assembly_with_error(get_testfile_mscorlib, "Parse", |assembly| { // Create a TypeDef with field_list pointing beyond Field table bounds let invalid_typedef = TypeDefRaw { rid: 1, @@ -143,7 +143,7 @@ pub fn create_assembly_with_field_list_violation() -> Result { /// /// Originally from: `src/metadata/validation/validators/raw/structure/table.rs` pub fn create_assembly_with_method_list_violation() -> Result { - create_test_assembly_with_error(get_testfile_mscorlib, "Malformed", |assembly| { + create_test_assembly_with_error(get_testfile_mscorlib, "Parse", |assembly| { // Create a TypeDef with method_list pointing beyond MethodDef table bounds let invalid_typedef = TypeDefRaw { rid: 1, diff --git a/dotscope/src/utils/enums.rs b/dotscope/src/utils/enums.rs index 20d2002a..0ccf86b3 100644 --- a/dotscope/src/utils/enums.rs +++ b/dotscope/src/utils/enums.rs @@ -10,7 +10,7 @@ use crate::{ signatures::TypeSignature, typesystem::{CilTypeRc, TypeRegistry}, }, - Result, + ParseFailure, ParseStage, Result, }; use std::sync::Arc; @@ -141,10 +141,12 @@ impl EnumUtils { 2 => Ok(i64::from(parser.read_le::()?)), 4 => Ok(i64::from(parser.read_le::()?)), 8 => parser.read_le::(), - _ => Err(malformed_error!( - "Invalid enum underlying type size: {} bytes", - size_bytes - )), + _ => Err(ParseFailure::InvalidField { + stage: ParseStage::Generic, + field: "enum_underlying_type_size", + reason: format!("invalid size: {size_bytes} bytes"), + } + .into()), } } diff --git a/dotscope/src/utils/io.rs b/dotscope/src/utils/io.rs index 081f4a1f..af5d7399 100644 --- a/dotscope/src/utils/io.rs +++ b/dotscope/src/utils/io.rs @@ -156,7 +156,15 @@ //! making them safe to call concurrently from multiple threads. //! -use crate::Result; +use crate::{Error, ParseFailure, ParseStage, Result}; + +#[inline] +fn io_oob() -> Error { + ParseFailure::OutOfBounds { + stage: ParseStage::Generic, + } + .into() +} /// Trait for implementing type-specific safe binary data reading operations. /// @@ -536,10 +544,10 @@ pub fn read_le(data: &[u8]) -> Result { /// Note that the offset parameter is modified, so each thread should use its own offset variable. pub fn read_le_at(data: &[u8], offset: &mut usize) -> Result { let type_len = std::mem::size_of::(); - let end = offset.checked_add(type_len).ok_or(out_of_bounds_error!())?; - let slice = data.get(*offset..end).ok_or(out_of_bounds_error!())?; + let end = offset.checked_add(type_len).ok_or(io_oob())?; + let slice = data.get(*offset..end).ok_or(io_oob())?; let Ok(read) = slice.try_into() else { - return Err(out_of_bounds_error!()); + return Err(io_oob()); }; *offset = end; @@ -669,10 +677,10 @@ pub fn read_be(data: &[u8]) -> Result { /// Note that the offset parameter is modified, so each thread should use its own offset variable. pub fn read_be_at(data: &[u8], offset: &mut usize) -> Result { let type_len = std::mem::size_of::(); - let end = offset.checked_add(type_len).ok_or(out_of_bounds_error!())?; - let slice = data.get(*offset..end).ok_or(out_of_bounds_error!())?; + let end = offset.checked_add(type_len).ok_or(io_oob())?; + let slice = data.get(*offset..end).ok_or(io_oob())?; let Ok(read) = slice.try_into() else { - return Err(out_of_bounds_error!()); + return Err(io_oob()); }; *offset = end; @@ -805,9 +813,9 @@ pub fn write_le(data: &mut [u8], value: T) -> Result<()> { /// Note that the offset parameter is modified, so each thread should use its own offset variable. pub fn write_le_at(data: &mut [u8], offset: &mut usize, value: T) -> Result<()> { let type_len = std::mem::size_of::(); - let end = offset.checked_add(type_len).ok_or(out_of_bounds_error!())?; + let end = offset.checked_add(type_len).ok_or(io_oob())?; let bytes = value.to_le_bytes(); - let dst = data.get_mut(*offset..end).ok_or(out_of_bounds_error!())?; + let dst = data.get_mut(*offset..end).ok_or(io_oob())?; dst.copy_from_slice(bytes.as_ref()); *offset = end; @@ -946,9 +954,9 @@ pub fn write_be(data: &mut [u8], value: T) -> Result<()> { /// Note that the offset parameter is modified, so each thread should use its own offset variable. pub fn write_be_at(data: &mut [u8], offset: &mut usize, value: T) -> Result<()> { let type_len = std::mem::size_of::(); - let end = offset.checked_add(type_len).ok_or(out_of_bounds_error!())?; + let end = offset.checked_add(type_len).ok_or(io_oob())?; let bytes = value.to_be_bytes(); - let dst = data.get_mut(*offset..end).ok_or(out_of_bounds_error!())?; + let dst = data.get_mut(*offset..end).ok_or(io_oob())?; dst.copy_from_slice(bytes.as_ref()); *offset = end; @@ -1271,17 +1279,13 @@ pub fn decode_utf16le(bytes: &[u8]) -> Option { /// Note that the offset parameter is modified, so each thread should use its own offset variable. pub fn write_string_at(data: &mut [u8], offset: &mut usize, value: &str) -> Result<()> { let string_bytes = value.as_bytes(); - let after_str = offset - .checked_add(string_bytes.len()) - .ok_or(out_of_bounds_error!())?; - let after_null = after_str.checked_add(1).ok_or(out_of_bounds_error!())?; - - let dst = data - .get_mut(*offset..after_str) - .ok_or(out_of_bounds_error!())?; + let after_str = offset.checked_add(string_bytes.len()).ok_or(io_oob())?; + let after_null = after_str.checked_add(1).ok_or(io_oob())?; + + let dst = data.get_mut(*offset..after_str).ok_or(io_oob())?; dst.copy_from_slice(string_bytes); - *data.get_mut(after_str).ok_or(out_of_bounds_error!())? = 0; + *data.get_mut(after_str).ok_or(io_oob())? = 0; *offset = after_null; @@ -1328,32 +1332,32 @@ pub fn write_string_at(data: &mut [u8], offset: &mut usize, value: &str) -> Resu /// # Ok::<(), dotscope::Error>(()) /// ``` pub fn read_compressed_int(data: &[u8], offset: &mut usize) -> Result<(usize, usize)> { - let first_byte = *data.get(*offset).ok_or(out_of_bounds_error!())?; + let first_byte = *data.get(*offset).ok_or(io_oob())?; if first_byte & 0x80 == 0 { // Single byte: 0xxxxxxx - *offset = offset.checked_add(1).ok_or(out_of_bounds_error!())?; + *offset = offset.checked_add(1).ok_or(io_oob())?; Ok((first_byte as usize, 1)) } else if first_byte & 0xC0 == 0x80 { // Two bytes: 10xxxxxx xxxxxxxx - let next = offset.checked_add(1).ok_or(out_of_bounds_error!())?; - let second_byte = *data.get(next).ok_or(out_of_bounds_error!())?; + let next = offset.checked_add(1).ok_or(io_oob())?; + let second_byte = *data.get(next).ok_or(io_oob())?; let value = (((first_byte & 0x3F) as usize) << 8) | (second_byte as usize); - *offset = offset.checked_add(2).ok_or(out_of_bounds_error!())?; + *offset = offset.checked_add(2).ok_or(io_oob())?; Ok((value, 2)) } else { // Four bytes: 110xxxxx xxxxxxxx xxxxxxxx xxxxxxxx - let o1 = offset.checked_add(1).ok_or(out_of_bounds_error!())?; - let o2 = offset.checked_add(2).ok_or(out_of_bounds_error!())?; - let o3 = offset.checked_add(3).ok_or(out_of_bounds_error!())?; - let b1 = *data.get(o1).ok_or(out_of_bounds_error!())?; - let b2 = *data.get(o2).ok_or(out_of_bounds_error!())?; - let b3 = *data.get(o3).ok_or(out_of_bounds_error!())?; + let o1 = offset.checked_add(1).ok_or(io_oob())?; + let o2 = offset.checked_add(2).ok_or(io_oob())?; + let o3 = offset.checked_add(3).ok_or(io_oob())?; + let b1 = *data.get(o1).ok_or(io_oob())?; + let b2 = *data.get(o2).ok_or(io_oob())?; + let b3 = *data.get(o3).ok_or(io_oob())?; let mut value = ((first_byte & 0x1F) as usize) << 24; value |= (b1 as usize) << 16; value |= (b2 as usize) << 8; value |= b3 as usize; - *offset = offset.checked_add(4).ok_or(out_of_bounds_error!())?; + *offset = offset.checked_add(4).ok_or(io_oob())?; Ok((value, 4)) } } @@ -1439,7 +1443,7 @@ pub fn read_packed_len(data: &[u8]) -> Option<(usize, usize)> { /// ``` pub fn read_compressed_uint(data: &[u8], offset: &mut usize) -> Result { let (value, _consumed) = read_compressed_int(data, offset)?; - u32::try_from(value).map_err(|_| out_of_bounds_error!()) + u32::try_from(value).map_err(|_| io_oob()) } /// Reads a compressed unsigned integer from a specific offset without advancing a mutable offset. @@ -1473,7 +1477,7 @@ pub fn read_compressed_uint_at(data: &[u8], offset: usize) -> Result { #[cfg(test)] mod tests { use super::*; - use crate::Error; + use crate::{Error, ParseFailure}; const TEST_BUFFER: [u8; 8] = [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]; @@ -1640,10 +1644,16 @@ mod tests { let buffer = [0xFF, 0xFF, 0xFF, 0xFF]; let result = read_le::(&buffer); - assert!(matches!(result, Err(Error::OutOfBounds { .. }))); + assert!(matches!( + result, + Err(Error::Parse(ParseFailure::OutOfBounds { .. })) + )); let result = read_le::(&buffer); - assert!(matches!(result, Err(Error::OutOfBounds { .. }))); + assert!(matches!( + result, + Err(Error::Parse(ParseFailure::OutOfBounds { .. })) + )); } #[test] @@ -1933,10 +1943,16 @@ mod tests { // Try to write u32 (4 bytes) into 2-byte buffer let result = write_le(&mut buffer, 0x12345678u32); - assert!(matches!(result, Err(Error::OutOfBounds { .. }))); + assert!(matches!( + result, + Err(Error::Parse(ParseFailure::OutOfBounds { .. })) + )); let result = write_be(&mut buffer, 0x12345678u32); - assert!(matches!(result, Err(Error::OutOfBounds { .. }))); + assert!(matches!( + result, + Err(Error::Parse(ParseFailure::OutOfBounds { .. })) + )); } #[test] diff --git a/dotscope/src/utils/math.rs b/dotscope/src/utils/math.rs index 53cc771b..e621a84c 100644 --- a/dotscope/src/utils/math.rs +++ b/dotscope/src/utils/math.rs @@ -1,6 +1,6 @@ //! Mathematical utility functions. -use crate::Result; +use crate::{ParseFailure, ParseStage, Result}; /// Converts a `usize` to `u32` for PE serialization, returning an error if the value /// exceeds `u32::MAX`. All .NET metadata structures are bounded well below this limit. @@ -9,8 +9,14 @@ use crate::Result; /// /// Returns an error if `value` exceeds `u32::MAX`. pub fn to_u32(value: usize) -> Result { - u32::try_from(value) - .map_err(|_| malformed_error!("PE serialization value {value} exceeds u32::MAX")) + u32::try_from(value).map_err(|_| { + ParseFailure::InvalidField { + stage: ParseStage::Generic, + field: "u32_value", + reason: format!("PE serialization value {value} exceeds u32::MAX"), + } + .into() + }) } /// Converts a `usize` to `i32` with saturation at `i32::MAX`. diff --git a/dotscope/tests/bitmono.rs b/dotscope/tests/bitmono.rs index 1ad5b496..c17f6aaa 100644 --- a/dotscope/tests/bitmono.rs +++ b/dotscope/tests/bitmono.rs @@ -1029,7 +1029,7 @@ fn test_dotnethook_call_targets_for_sample(sample: &str) { .unwrap_or_default(); format!("{}.{}", tdt, tm.name) }) - .unwrap_or_else(|| format!("ext:0x{:08X}", t.value())); + .unwrap_or_else(|_| format!("ext:0x{:08X}", t.value())); format!("{} 0x{:08X}({})", i.mnemonic, t.value(), target_name) }) }) @@ -1150,7 +1150,7 @@ fn dump_call_targets(assembly: &CilObject, method_name: &str, label: &str) { let target_name = assembly .method(&target) .map(|m| m.name.clone()) - .unwrap_or_else(|| format!("(unresolved 0x{:08X})", target.value())); + .unwrap_or_else(|_| format!("(unresolved 0x{:08X})", target.value())); eprintln!( " {} 0x{:08X} -> {}", instr.mnemonic, @@ -1176,7 +1176,7 @@ fn get_call_target_names(assembly: &CilObject, method_name: &str) -> Vec let name = assembly .method(&target) .map(|m| m.name.clone()) - .unwrap_or_else(|| format!("(external 0x{:08X})", target.value())); + .unwrap_or_else(|_| format!("(external 0x{:08X})", target.value())); names.push(name); } } @@ -1239,7 +1239,7 @@ fn test_dotnethook_offset_diagnostic() { .unwrap_or_default(); format!("{}.{}", tdt, tm.name) }) - .unwrap_or_else(|| format!("ext:0x{:08X}", t.value())); + .unwrap_or_else(|_| format!("ext:0x{:08X}", t.value())); format!("{} 0x{:08X}({})", i.mnemonic, t.value(), target_name) }) }) diff --git a/dotscope/tests/common/verification.rs b/dotscope/tests/common/verification.rs index ed4d280e..48958e15 100644 --- a/dotscope/tests/common/verification.rs +++ b/dotscope/tests/common/verification.rs @@ -295,7 +295,7 @@ impl MethodSemantics { } } ConstValue::DecryptedString(content) => { - semantics.strings.insert(content.clone()); + semantics.strings.insert(content.to_string()); } ConstValue::I8(v) => { semantics.integer_constants.insert(i64::from(*v)); @@ -989,7 +989,7 @@ pub fn resolve_method_name(assembly: &CilObject, token: Token) -> Option match table_id { 0x06 => { - let method = assembly.method(&token)?; + let method = assembly.method(&token).ok()?; for entry in assembly.types().iter() { let cil_type = entry.value(); diff --git a/dotscope/tests/modify_roundtrips_method.rs b/dotscope/tests/modify_roundtrips_method.rs index 3fae4a6e..8e12d570 100644 --- a/dotscope/tests/modify_roundtrips_method.rs +++ b/dotscope/tests/modify_roundtrips_method.rs @@ -184,14 +184,12 @@ fn verify_injected_method(assembly: &CilObject) -> Result<()> { found_injected_method = true; // Verify it has a method body - let body = method - .body - .get() - .ok_or_else(|| dotscope::Error::Malformed { - message: "Injected method should have a body".to_string(), - file: file!(), - line: line!(), - })?; + let body = method.body.get().ok_or_else(|| { + Error::Parse(ParseFailure::Other { + stage: ParseStage::MethodBody, + message: "injected method should have a body".into(), + }) + })?; // Verify the body has reasonable size (our method should be small) assert!( From 356d3fecda27be185c955fadffb745dd61ddd7d9 Mon Sep 17 00:00:00 2001 From: BinFlip Date: Wed, 3 Jun 2026 19:04:10 -0700 Subject: [PATCH 5/6] fix: analyssa crate path --- Cargo.lock | 2 ++ dotscope/Cargo.toml | 2 +- 2 files changed, 3 insertions(+), 1 deletion(-) diff --git a/Cargo.lock b/Cargo.lock index 7bb863a3..42b45820 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -100,6 +100,8 @@ checksum = "683d7910e743518b0e34f1186f92494becacb047c7b6bf616c96772180fef923" [[package]] name = "analyssa" version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "345118481d5b39980833d7dcf9ef72f21d0251b090b5233489e36487d3af2bc9" dependencies = [ "boxcar", "dashmap", diff --git a/dotscope/Cargo.toml b/dotscope/Cargo.toml index d5c1fb1f..97b99aff 100644 --- a/dotscope/Cargo.toml +++ b/dotscope/Cargo.toml @@ -69,7 +69,7 @@ hex = "0.4.3" num-bigint = { version = "0.4.6", optional = true } log = "0.4.31" flate2 = "1.1.9" -analyssa = { path = "../../analyssa" } +analyssa = "0.2.0" lzma-rs = "0.3.0" z3 = { version = "0.20.0", optional = true } iced-x86 = { version = "1.21.0", default-features = false, features = ["std", "decoder", "instr_info"], optional = true } From 5b030782bf99c59b0f386781379800b06fa86b7b Mon Sep 17 00:00:00 2001 From: BinFlip Date: Wed, 3 Jun 2026 19:46:47 -0700 Subject: [PATCH 6/6] fix: clippy warnings for rust 1.96 --- dotscope/src/analysis/ssa/converter.rs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/dotscope/src/analysis/ssa/converter.rs b/dotscope/src/analysis/ssa/converter.rs index f2eb5093..df84e11f 100644 --- a/dotscope/src/analysis/ssa/converter.rs +++ b/dotscope/src/analysis/ssa/converter.rs @@ -905,7 +905,7 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { // (CALLI has VarPop/VarPush so net_effect=0 without resolution) let net_effect = match instr.opcode { opcodes::CALL | opcodes::CALLVIRT | opcodes::NEWOBJ => assembly - .and_then(|asm| Self::extract_token(&instr.operand).map(|t| (asm, t))) + .zip(Self::extract_token(&instr.operand)) .and_then(|(asm, token)| { Self::resolve_call_info(asm, token).map( |(param_count, has_this, has_return)| { @@ -933,7 +933,7 @@ impl<'a, 'cfg> SsaConverter<'a, 'cfg> { }) .unwrap_or(instr.stack_behavior.net_effect), opcodes::CALLI => assembly - .and_then(|asm| Self::extract_token(&instr.operand).map(|t| (asm, t))) + .zip(Self::extract_token(&instr.operand)) .and_then(|(asm, token)| { Self::resolve_calli_info(asm, token).map( |(param_count, has_this, has_return)| {