aho_corasick/
dfa.rs

1use std::mem::size_of;
2
3use ahocorasick::MatchKind;
4use automaton::Automaton;
5use classes::ByteClasses;
6use error::Result;
7use nfa::{PatternID, PatternLength, NFA};
8use prefilter::{Prefilter, PrefilterObj, PrefilterState};
9use state_id::{dead_id, fail_id, premultiply_overflow_error, StateID};
10use Match;
11
12#[derive(Clone, Debug)]
13pub enum DFA<S> {
14    Standard(Standard<S>),
15    ByteClass(ByteClass<S>),
16    Premultiplied(Premultiplied<S>),
17    PremultipliedByteClass(PremultipliedByteClass<S>),
18}
19
20impl<S: StateID> DFA<S> {
21    fn repr(&self) -> &Repr<S> {
22        match *self {
23            DFA::Standard(ref dfa) => dfa.repr(),
24            DFA::ByteClass(ref dfa) => dfa.repr(),
25            DFA::Premultiplied(ref dfa) => dfa.repr(),
26            DFA::PremultipliedByteClass(ref dfa) => dfa.repr(),
27        }
28    }
29
30    pub fn match_kind(&self) -> &MatchKind {
31        &self.repr().match_kind
32    }
33
34    pub fn heap_bytes(&self) -> usize {
35        self.repr().heap_bytes
36    }
37
38    pub fn max_pattern_len(&self) -> usize {
39        self.repr().max_pattern_len
40    }
41
42    pub fn pattern_count(&self) -> usize {
43        self.repr().pattern_count
44    }
45
46    pub fn start_state(&self) -> S {
47        self.repr().start_id
48    }
49
50    #[inline(always)]
51    pub fn overlapping_find_at(
52        &self,
53        prestate: &mut PrefilterState,
54        haystack: &[u8],
55        at: usize,
56        state_id: &mut S,
57        match_index: &mut usize,
58    ) -> Option<Match> {
59        match *self {
60            DFA::Standard(ref dfa) => dfa.overlapping_find_at(
61                prestate,
62                haystack,
63                at,
64                state_id,
65                match_index,
66            ),
67            DFA::ByteClass(ref dfa) => dfa.overlapping_find_at(
68                prestate,
69                haystack,
70                at,
71                state_id,
72                match_index,
73            ),
74            DFA::Premultiplied(ref dfa) => dfa.overlapping_find_at(
75                prestate,
76                haystack,
77                at,
78                state_id,
79                match_index,
80            ),
81            DFA::PremultipliedByteClass(ref dfa) => dfa.overlapping_find_at(
82                prestate,
83                haystack,
84                at,
85                state_id,
86                match_index,
87            ),
88        }
89    }
90
91    #[inline(always)]
92    pub fn earliest_find_at(
93        &self,
94        prestate: &mut PrefilterState,
95        haystack: &[u8],
96        at: usize,
97        state_id: &mut S,
98    ) -> Option<Match> {
99        match *self {
100            DFA::Standard(ref dfa) => {
101                dfa.earliest_find_at(prestate, haystack, at, state_id)
102            }
103            DFA::ByteClass(ref dfa) => {
104                dfa.earliest_find_at(prestate, haystack, at, state_id)
105            }
106            DFA::Premultiplied(ref dfa) => {
107                dfa.earliest_find_at(prestate, haystack, at, state_id)
108            }
109            DFA::PremultipliedByteClass(ref dfa) => {
110                dfa.earliest_find_at(prestate, haystack, at, state_id)
111            }
112        }
113    }
114
115    #[inline(always)]
116    pub fn find_at_no_state(
117        &self,
118        prestate: &mut PrefilterState,
119        haystack: &[u8],
120        at: usize,
121    ) -> Option<Match> {
122        match *self {
123            DFA::Standard(ref dfa) => {
124                dfa.find_at_no_state(prestate, haystack, at)
125            }
126            DFA::ByteClass(ref dfa) => {
127                dfa.find_at_no_state(prestate, haystack, at)
128            }
129            DFA::Premultiplied(ref dfa) => {
130                dfa.find_at_no_state(prestate, haystack, at)
131            }
132            DFA::PremultipliedByteClass(ref dfa) => {
133                dfa.find_at_no_state(prestate, haystack, at)
134            }
135        }
136    }
137}
138
139#[derive(Clone, Debug)]
140pub struct Standard<S>(Repr<S>);
141
142impl<S: StateID> Standard<S> {
143    fn repr(&self) -> &Repr<S> {
144        &self.0
145    }
146}
147
148impl<S: StateID> Automaton for Standard<S> {
149    type ID = S;
150
151    fn match_kind(&self) -> &MatchKind {
152        &self.repr().match_kind
153    }
154
155    fn anchored(&self) -> bool {
156        self.repr().anchored
157    }
158
159    fn prefilter(&self) -> Option<&dyn Prefilter> {
160        self.repr().prefilter.as_ref().map(|p| p.as_ref())
161    }
162
163    fn start_state(&self) -> S {
164        self.repr().start_id
165    }
166
167    fn is_valid(&self, id: S) -> bool {
168        id.to_usize() < self.repr().state_count
169    }
170
171    fn is_match_state(&self, id: S) -> bool {
172        self.repr().is_match_state(id)
173    }
174
175    fn is_match_or_dead_state(&self, id: S) -> bool {
176        self.repr().is_match_or_dead_state(id)
177    }
178
179    fn get_match(
180        &self,
181        id: S,
182        match_index: usize,
183        end: usize,
184    ) -> Option<Match> {
185        self.repr().get_match(id, match_index, end)
186    }
187
188    fn match_count(&self, id: S) -> usize {
189        self.repr().match_count(id)
190    }
191
192    unsafe fn next_state_unchecked(&self, current: S, input: u8) -> S {
193        let o = current.to_usize() * 256 + input as usize;
194        *self.repr().trans.get_unchecked(o)
195    }
196}
197
198#[derive(Clone, Debug)]
199pub struct ByteClass<S>(Repr<S>);
200
201impl<S: StateID> ByteClass<S> {
202    fn repr(&self) -> &Repr<S> {
203        &self.0
204    }
205}
206
207impl<S: StateID> Automaton for ByteClass<S> {
208    type ID = S;
209
210    fn match_kind(&self) -> &MatchKind {
211        &self.repr().match_kind
212    }
213
214    fn anchored(&self) -> bool {
215        self.repr().anchored
216    }
217
218    fn prefilter(&self) -> Option<&dyn Prefilter> {
219        self.repr().prefilter.as_ref().map(|p| p.as_ref())
220    }
221
222    fn start_state(&self) -> S {
223        self.repr().start_id
224    }
225
226    fn is_valid(&self, id: S) -> bool {
227        id.to_usize() < self.repr().state_count
228    }
229
230    fn is_match_state(&self, id: S) -> bool {
231        self.repr().is_match_state(id)
232    }
233
234    fn is_match_or_dead_state(&self, id: S) -> bool {
235        self.repr().is_match_or_dead_state(id)
236    }
237
238    fn get_match(
239        &self,
240        id: S,
241        match_index: usize,
242        end: usize,
243    ) -> Option<Match> {
244        self.repr().get_match(id, match_index, end)
245    }
246
247    fn match_count(&self, id: S) -> usize {
248        self.repr().match_count(id)
249    }
250
251    unsafe fn next_state_unchecked(&self, current: S, input: u8) -> S {
252        let alphabet_len = self.repr().byte_classes.alphabet_len();
253        let input = self.repr().byte_classes.get(input);
254        let o = current.to_usize() * alphabet_len + input as usize;
255        *self.repr().trans.get_unchecked(o)
256    }
257}
258
259#[derive(Clone, Debug)]
260pub struct Premultiplied<S>(Repr<S>);
261
262impl<S: StateID> Premultiplied<S> {
263    fn repr(&self) -> &Repr<S> {
264        &self.0
265    }
266}
267
268impl<S: StateID> Automaton for Premultiplied<S> {
269    type ID = S;
270
271    fn match_kind(&self) -> &MatchKind {
272        &self.repr().match_kind
273    }
274
275    fn anchored(&self) -> bool {
276        self.repr().anchored
277    }
278
279    fn prefilter(&self) -> Option<&dyn Prefilter> {
280        self.repr().prefilter.as_ref().map(|p| p.as_ref())
281    }
282
283    fn start_state(&self) -> S {
284        self.repr().start_id
285    }
286
287    fn is_valid(&self, id: S) -> bool {
288        (id.to_usize() / 256) < self.repr().state_count
289    }
290
291    fn is_match_state(&self, id: S) -> bool {
292        self.repr().is_match_state(id)
293    }
294
295    fn is_match_or_dead_state(&self, id: S) -> bool {
296        self.repr().is_match_or_dead_state(id)
297    }
298
299    fn get_match(
300        &self,
301        id: S,
302        match_index: usize,
303        end: usize,
304    ) -> Option<Match> {
305        if id > self.repr().max_match {
306            return None;
307        }
308        self.repr()
309            .matches
310            .get(id.to_usize() / 256)
311            .and_then(|m| m.get(match_index))
312            .map(|&(id, len)| Match { pattern: id, len, end })
313    }
314
315    fn match_count(&self, id: S) -> usize {
316        let o = id.to_usize() / 256;
317        self.repr().matches[o].len()
318    }
319
320    unsafe fn next_state_unchecked(&self, current: S, input: u8) -> S {
321        let o = current.to_usize() + input as usize;
322        *self.repr().trans.get_unchecked(o)
323    }
324}
325
326#[derive(Clone, Debug)]
327pub struct PremultipliedByteClass<S>(Repr<S>);
328
329impl<S: StateID> PremultipliedByteClass<S> {
330    fn repr(&self) -> &Repr<S> {
331        &self.0
332    }
333}
334
335impl<S: StateID> Automaton for PremultipliedByteClass<S> {
336    type ID = S;
337
338    fn match_kind(&self) -> &MatchKind {
339        &self.repr().match_kind
340    }
341
342    fn anchored(&self) -> bool {
343        self.repr().anchored
344    }
345
346    fn prefilter(&self) -> Option<&dyn Prefilter> {
347        self.repr().prefilter.as_ref().map(|p| p.as_ref())
348    }
349
350    fn start_state(&self) -> S {
351        self.repr().start_id
352    }
353
354    fn is_valid(&self, id: S) -> bool {
355        (id.to_usize() / self.repr().alphabet_len()) < self.repr().state_count
356    }
357
358    fn is_match_state(&self, id: S) -> bool {
359        self.repr().is_match_state(id)
360    }
361
362    fn is_match_or_dead_state(&self, id: S) -> bool {
363        self.repr().is_match_or_dead_state(id)
364    }
365
366    fn get_match(
367        &self,
368        id: S,
369        match_index: usize,
370        end: usize,
371    ) -> Option<Match> {
372        if id > self.repr().max_match {
373            return None;
374        }
375        self.repr()
376            .matches
377            .get(id.to_usize() / self.repr().alphabet_len())
378            .and_then(|m| m.get(match_index))
379            .map(|&(id, len)| Match { pattern: id, len, end })
380    }
381
382    fn match_count(&self, id: S) -> usize {
383        let o = id.to_usize() / self.repr().alphabet_len();
384        self.repr().matches[o].len()
385    }
386
387    unsafe fn next_state_unchecked(&self, current: S, input: u8) -> S {
388        let input = self.repr().byte_classes.get(input);
389        let o = current.to_usize() + input as usize;
390        *self.repr().trans.get_unchecked(o)
391    }
392}
393
394#[derive(Clone, Debug)]
395pub struct Repr<S> {
396    match_kind: MatchKind,
397    anchored: bool,
398    premultiplied: bool,
399    start_id: S,
400    /// The length, in bytes, of the longest pattern in this automaton. This
401    /// information is useful for keeping correct buffer sizes when searching
402    /// on streams.
403    max_pattern_len: usize,
404    /// The total number of patterns added to this automaton. This includes
405    /// patterns that may never match.
406    pattern_count: usize,
407    state_count: usize,
408    max_match: S,
409    /// The number of bytes of heap used by this NFA's transition table.
410    heap_bytes: usize,
411    /// A prefilter for quickly detecting candidate matchs, if pertinent.
412    prefilter: Option<PrefilterObj>,
413    byte_classes: ByteClasses,
414    trans: Vec<S>,
415    matches: Vec<Vec<(PatternID, PatternLength)>>,
416}
417
418impl<S: StateID> Repr<S> {
419    /// Returns the total alphabet size for this DFA.
420    ///
421    /// If byte classes are enabled, then this corresponds to the number of
422    /// equivalence classes. If they are disabled, then this is always 256.
423    fn alphabet_len(&self) -> usize {
424        self.byte_classes.alphabet_len()
425    }
426
427    /// Returns true only if the given state is a match state.
428    fn is_match_state(&self, id: S) -> bool {
429        id <= self.max_match && id > dead_id()
430    }
431
432    /// Returns true only if the given state is either a dead state or a match
433    /// state.
434    fn is_match_or_dead_state(&self, id: S) -> bool {
435        id <= self.max_match
436    }
437
438    /// Get the ith match for the given state, where the end position of a
439    /// match was found at `end`.
440    ///
441    /// # Panics
442    ///
443    /// The caller must ensure that the given state identifier is valid,
444    /// otherwise this may panic. The `match_index` need not be valid. That is,
445    /// if the given state has no matches then this returns `None`.
446    fn get_match(
447        &self,
448        id: S,
449        match_index: usize,
450        end: usize,
451    ) -> Option<Match> {
452        if id > self.max_match {
453            return None;
454        }
455        self.matches
456            .get(id.to_usize())
457            .and_then(|m| m.get(match_index))
458            .map(|&(id, len)| Match { pattern: id, len, end })
459    }
460
461    /// Return the total number of matches for the given state.
462    ///
463    /// # Panics
464    ///
465    /// The caller must ensure that the given identifier is valid, or else
466    /// this panics.
467    fn match_count(&self, id: S) -> usize {
468        self.matches[id.to_usize()].len()
469    }
470
471    /// Get the next state given `from` as the current state and `byte` as the
472    /// current input byte.
473    fn next_state(&self, from: S, byte: u8) -> S {
474        let alphabet_len = self.alphabet_len();
475        let byte = self.byte_classes.get(byte);
476        self.trans[from.to_usize() * alphabet_len + byte as usize]
477    }
478
479    /// Set the `byte` transition for the `from` state to point to `to`.
480    fn set_next_state(&mut self, from: S, byte: u8, to: S) {
481        let alphabet_len = self.alphabet_len();
482        let byte = self.byte_classes.get(byte);
483        self.trans[from.to_usize() * alphabet_len + byte as usize] = to;
484    }
485
486    /// Swap the given states in place.
487    fn swap_states(&mut self, id1: S, id2: S) {
488        assert!(!self.premultiplied, "can't swap states in premultiplied DFA");
489
490        let o1 = id1.to_usize() * self.alphabet_len();
491        let o2 = id2.to_usize() * self.alphabet_len();
492        for b in 0..self.alphabet_len() {
493            self.trans.swap(o1 + b, o2 + b);
494        }
495        self.matches.swap(id1.to_usize(), id2.to_usize());
496    }
497
498    /// This routine shuffles all match states in this DFA to the beginning
499    /// of the DFA such that every non-match state appears after every match
500    /// state. (With one exception: the special fail and dead states remain as
501    /// the first two states.)
502    ///
503    /// The purpose of doing this shuffling is to avoid an extra conditional
504    /// in the search loop, and in particular, detecting whether a state is a
505    /// match or not does not need to access any memory.
506    ///
507    /// This updates `self.max_match` to point to the last matching state as
508    /// well as `self.start` if the starting state was moved.
509    fn shuffle_match_states(&mut self) {
510        assert!(
511            !self.premultiplied,
512            "cannot shuffle match states of premultiplied DFA"
513        );
514
515        if self.state_count <= 1 {
516            return;
517        }
518
519        let mut first_non_match = self.start_id.to_usize();
520        while first_non_match < self.state_count
521            && self.matches[first_non_match].len() > 0
522        {
523            first_non_match += 1;
524        }
525
526        let mut swaps: Vec<S> = vec![fail_id(); self.state_count];
527        let mut cur = self.state_count - 1;
528        while cur > first_non_match {
529            if self.matches[cur].len() > 0 {
530                self.swap_states(
531                    S::from_usize(cur),
532                    S::from_usize(first_non_match),
533                );
534                swaps[cur] = S::from_usize(first_non_match);
535                swaps[first_non_match] = S::from_usize(cur);
536
537                first_non_match += 1;
538                while first_non_match < cur
539                    && self.matches[first_non_match].len() > 0
540                {
541                    first_non_match += 1;
542                }
543            }
544            cur -= 1;
545        }
546        for id in (0..self.state_count).map(S::from_usize) {
547            let alphabet_len = self.alphabet_len();
548            let offset = id.to_usize() * alphabet_len;
549            for next in &mut self.trans[offset..offset + alphabet_len] {
550                if swaps[next.to_usize()] != fail_id() {
551                    *next = swaps[next.to_usize()];
552                }
553            }
554        }
555        if swaps[self.start_id.to_usize()] != fail_id() {
556            self.start_id = swaps[self.start_id.to_usize()];
557        }
558        self.max_match = S::from_usize(first_non_match - 1);
559    }
560
561    fn premultiply(&mut self) -> Result<()> {
562        if self.premultiplied || self.state_count <= 1 {
563            return Ok(());
564        }
565
566        let alpha_len = self.alphabet_len();
567        premultiply_overflow_error(
568            S::from_usize(self.state_count - 1),
569            alpha_len,
570        )?;
571
572        for id in (2..self.state_count).map(S::from_usize) {
573            let offset = id.to_usize() * alpha_len;
574            for next in &mut self.trans[offset..offset + alpha_len] {
575                if *next == dead_id() {
576                    continue;
577                }
578                *next = S::from_usize(next.to_usize() * alpha_len);
579            }
580        }
581        self.premultiplied = true;
582        self.start_id = S::from_usize(self.start_id.to_usize() * alpha_len);
583        self.max_match = S::from_usize(self.max_match.to_usize() * alpha_len);
584        Ok(())
585    }
586
587    /// Computes the total amount of heap used by this NFA in bytes.
588    fn calculate_size(&mut self) {
589        let mut size = (self.trans.len() * size_of::<S>())
590            + (self.matches.len()
591                * size_of::<Vec<(PatternID, PatternLength)>>());
592        for state_matches in &self.matches {
593            size +=
594                state_matches.len() * size_of::<(PatternID, PatternLength)>();
595        }
596        size += self.prefilter.as_ref().map_or(0, |p| p.as_ref().heap_bytes());
597        self.heap_bytes = size;
598    }
599}
600
601/// A builder for configuring the determinization of an NFA into a DFA.
602#[derive(Clone, Debug)]
603pub struct Builder {
604    premultiply: bool,
605    byte_classes: bool,
606}
607
608impl Builder {
609    /// Create a new builder for a DFA.
610    pub fn new() -> Builder {
611        Builder { premultiply: true, byte_classes: true }
612    }
613
614    /// Build a DFA from the given NFA.
615    ///
616    /// This returns an error if the state identifiers exceed their
617    /// representation size. This can only happen when state ids are
618    /// premultiplied (which is enabled by default).
619    pub fn build<S: StateID>(&self, nfa: &NFA<S>) -> Result<DFA<S>> {
620        let byte_classes = if self.byte_classes {
621            nfa.byte_classes().clone()
622        } else {
623            ByteClasses::singletons()
624        };
625        let alphabet_len = byte_classes.alphabet_len();
626        let trans = vec![fail_id(); alphabet_len * nfa.state_len()];
627        let matches = vec![vec![]; nfa.state_len()];
628        let mut repr = Repr {
629            match_kind: nfa.match_kind().clone(),
630            anchored: nfa.anchored(),
631            premultiplied: false,
632            start_id: nfa.start_state(),
633            max_pattern_len: nfa.max_pattern_len(),
634            pattern_count: nfa.pattern_count(),
635            state_count: nfa.state_len(),
636            max_match: fail_id(),
637            heap_bytes: 0,
638            prefilter: nfa.prefilter_obj().map(|p| p.clone()),
639            byte_classes: byte_classes.clone(),
640            trans: trans,
641            matches: matches,
642        };
643        for id in (0..nfa.state_len()).map(S::from_usize) {
644            repr.matches[id.to_usize()].extend_from_slice(nfa.matches(id));
645
646            let fail = nfa.failure_transition(id);
647            nfa.iter_all_transitions(&byte_classes, id, |b, mut next| {
648                if next == fail_id() {
649                    next = nfa_next_state_memoized(nfa, &repr, id, fail, b);
650                }
651                repr.set_next_state(id, b, next);
652            });
653        }
654        repr.shuffle_match_states();
655        repr.calculate_size();
656        if self.premultiply {
657            repr.premultiply()?;
658            if byte_classes.is_singleton() {
659                Ok(DFA::Premultiplied(Premultiplied(repr)))
660            } else {
661                Ok(DFA::PremultipliedByteClass(PremultipliedByteClass(repr)))
662            }
663        } else {
664            if byte_classes.is_singleton() {
665                Ok(DFA::Standard(Standard(repr)))
666            } else {
667                Ok(DFA::ByteClass(ByteClass(repr)))
668            }
669        }
670    }
671
672    /// Whether to use byte classes or in the DFA.
673    pub fn byte_classes(&mut self, yes: bool) -> &mut Builder {
674        self.byte_classes = yes;
675        self
676    }
677
678    /// Whether to premultiply state identifier in the DFA.
679    pub fn premultiply(&mut self, yes: bool) -> &mut Builder {
680        self.premultiply = yes;
681        self
682    }
683}
684
685/// This returns the next NFA transition (including resolving failure
686/// transitions), except once it sees a state id less than the id of the DFA
687/// state that is currently being populated, then we no longer need to follow
688/// failure transitions and can instead query the pre-computed state id from
689/// the DFA itself.
690///
691/// In general, this should only be called when a failure transition is seen.
692fn nfa_next_state_memoized<S: StateID>(
693    nfa: &NFA<S>,
694    dfa: &Repr<S>,
695    populating: S,
696    mut current: S,
697    input: u8,
698) -> S {
699    loop {
700        if current < populating {
701            return dfa.next_state(current, input);
702        }
703        let next = nfa.next_state(current, input);
704        if next != fail_id() {
705            return next;
706        }
707        current = nfa.failure_transition(current);
708    }
709}