/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.lucene.analysis.hunspell;

import static org.apache.lucene.analysis.hunspell.Dictionary.AFFIX_APPEND;
import static org.apache.lucene.analysis.hunspell.Dictionary.AFFIX_FLAG;
import static org.apache.lucene.analysis.hunspell.Dictionary.FLAG_UNSET;
import static org.apache.lucene.analysis.hunspell.Dictionary.toSortedCharArray;

import java.io.IOException;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.Set;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.lucene.analysis.hunspell.AffixedWord.Affix;
import org.apache.lucene.internal.hppc.CharHashSet;
import org.apache.lucene.internal.hppc.CharObjectHashMap;
import org.apache.lucene.util.IntsRef;
import org.apache.lucene.util.fst.FST;
import org.apache.lucene.util.fst.IntsRefFSTEnum;

/**
 * A utility class used for generating possible word forms by adding affixes to stems ({@link
 * #getAllWordForms(String, String, Runnable)}), and suggesting stems and flags to generate the
 * given set of words ({@link #compress(List, Set, Runnable)}).
 */
public class WordFormGenerator {
  // The dictionary whose affix rules and word entries drive generation and compression.
  private final Dictionary dictionary;
  // Affix entries read from the dictionary's prefix and suffix FSTs, grouped by their affix flag.
  private final CharObjectHashMap<List<AffixEntry>> affixes = new CharObjectHashMap<>();
  // Used to strip affixes again, both to validate derived forms and to find stem candidates.
  private final Stemmer stemmer;

  /**
   * Creates a generator for the given dictionary, indexing all of its prefix and suffix entries by
   * affix flag.
   *
   * @param dictionary the dictionary providing the affixes and the word entries
   */
  public WordFormGenerator(Dictionary dictionary) {
    this.dictionary = dictionary;
    this.stemmer = new Stemmer(dictionary);

    // prefixes are registered before suffixes, so prefix entries come first under a shared flag
    fillAffixMap(dictionary.prefixes, AffixKind.PREFIX);
    fillAffixMap(dictionary.suffixes, AffixKind.SUFFIX);
  }

  /**
   * Walks every input/output pair of the given affix FST and registers one {@code AffixEntry} per
   * affix id found in the outputs, grouping the entries by their affix flag.
   *
   * @param fst the prefix or suffix FST of the dictionary; nothing is done when it is {@code null}
   * @param kind whether the FST holds prefixes or suffixes
   * @throws UncheckedIOException if enumerating the FST fails with an {@link IOException}
   */
  private void fillAffixMap(FST<IntsRef> fst, AffixKind kind) {
    if (fst == null) {
      return;
    }

    IntsRefFSTEnum<IntsRef> pairs = new IntsRefFSTEnum<>(fst);
    try {
      for (var pair = pairs.next(); pair != null; pair = pairs.next()) {
        IntsRef ids = pair.output;
        int end = ids.offset + ids.length;
        for (int pos = ids.offset; pos < end; pos++) {
          int affixId = ids.ints[pos];
          char flag = dictionary.affixData(affixId, AFFIX_FLAG);
          AffixEntry created =
              new AffixEntry(
                  affixId, flag, kind, toString(kind, pair.input), strip(affixId), condition(affixId));

          // a negative slot means the flag is not in the map yet
          int slot = affixes.indexOf(flag);
          List<AffixEntry> sameFlag;
          if (slot >= 0) {
            sameFlag = affixes.indexGet(slot);
          } else {
            sameFlag = new ArrayList<>();
            affixes.indexInsert(slot, flag, sameFlag);
          }
          sameFlag.add(created);
        }
      }
    } catch (IOException e) {
      throw new UncheckedIOException(e);
    }
  }

  /**
   * Converts an FST input key to the affix text: prefix keys are copied as is, keys of any other
   * kind are copied in reverse order.
   */
  private String toString(AffixKind kind, IntsRef input) {
    int length = input.length;
    boolean forward = kind == AffixKind.PREFIX;
    char[] text = new char[length];
    for (int src = 0; src < length; src++) {
      int dst = forward ? src : length - 1 - src;
      text[dst] = (char) input.ints[input.offset + src];
    }
    return new String(text);
  }

  /**
   * Returns the condition of the given affix; a condition index of 0 stands for {@link
   * AffixCondition#ALWAYS_TRUE}, any other index is looked up in the dictionary patterns.
   */
  private AffixCondition condition(int affixId) {
    int patternIndex = dictionary.getAffixCondition(affixId);
    if (patternIndex == 0) {
      return AffixCondition.ALWAYS_TRUE;
    }
    return dictionary.patterns.get(patternIndex);
  }

  /**
   * Returns the strip string of the given affix, read from the dictionary's strip data between the
   * two consecutive offsets stored for the affix's strip ordinal.
   */
  private String strip(int affixId) {
    int ord = dictionary.affixData(affixId, Dictionary.AFFIX_STRIP_ORD);
    var offsets = dictionary.stripOffsets;
    int from = offsets[ord];
    int length = offsets[ord + 1] - from;
    return new String(dictionary.stripData, from, length);
  }

  /**
   * Generate all word forms for all dictionary entries with the given root word. The result order
   * is stable but not specified. This is equivalent to "unmunch" from the "hunspell-tools" package.
   *
   * @param root the root word whose dictionary entries are expanded
   * @param checkCanceled an object that's periodically called, allowing to interrupt the generation
   *     by throwing an exception
   * @return the forms of all entries found for the root; empty if the dictionary has no such entry
   */
  public List<AffixedWord> getAllWordForms(String root, Runnable checkCanceled) {
    List<AffixedWord> forms = new ArrayList<>();
    DictEntries found = dictionary.lookupEntries(root);
    if (found == null) {
      return forms;
    }

    for (DictEntry each : found) {
      List<AffixedWord> formsOfEntry = getAllWordForms(root, each.getFlags(), checkCanceled);
      forms.addAll(formsOfEntry);
    }
    return forms;
  }

  /**
   * Generate all word forms for the given root pretending it has the given flags (in the same
   * format as the dictionary uses). The result order is stable but not specified. This is
   * equivalent to "unmunch" from the "hunspell-tools" package.
   *
   * @param stem the root word to expand
   * @param flags the flags to assume for the root, as they would be written in the dictionary
   * @param checkCanceled an object that's periodically called, allowing to interrupt the generation
   *     by throwing an exception
   * @return the generated forms; empty if the flags contain a compounding, forbidden-word or
   *     only-in-compound flag
   */
  public List<AffixedWord> getAllWordForms(String stem, String flags, Runnable checkCanceled) {
    char[] parsed = dictionary.flagParsingStrategy.parseUtfFlags(flags);
    if (shouldConsiderAtAll(parsed)) {
      DictEntry pretended = DictEntry.create(stem, flags);
      return getAllWordForms(pretended, parsed, checkCanceled);
    }
    return List.of();
  }

  /**
   * Expands the given entry with the given flags. The affix-less form itself is included unless
   * the flags contain the dictionary's NEEDAFFIX flag. Note that the passed flag array is sorted in
   * place.
   */
  private List<AffixedWord> getAllWordForms(
      DictEntry entry, char[] encodedFlags, Runnable checkCanceled) {
    char[] flags = sortAndDeduplicate(encodedFlags);
    AffixedWord bare = new AffixedWord(entry.getStem(), entry, List.of(), List.of());
    checkCanceled.run();

    boolean needsAffix =
        FlagEnumerator.hasFlagInSortedArray(dictionary.needaffix, flags, 0, flags.length);

    List<AffixedWord> forms = new ArrayList<>();
    if (!needsAffix) {
      forms.add(bare);
    }
    forms.addAll(expand(bare, flags, checkCanceled));
    return forms;
  }

  /**
   * Sorts the given array in place and returns it if it has no repeated values; otherwise returns
   * the result of {@code deduplicate} for the (already sorted) array.
   */
  private static char[] sortAndDeduplicate(char[] flags) {
    Arrays.sort(flags);

    // after sorting, equal values are adjacent
    boolean hasRepeats = false;
    for (int i = 1; i < flags.length && !hasRepeats; i++) {
      hasRepeats = flags[i - 1] == flags[i];
    }
    return hasRepeats ? deduplicate(flags) : flags;
  }

  /** Returns the distinct values of the given array as a sorted array. */
  private static char[] deduplicate(char[] flags) {
    CharHashSet distinct = CharHashSet.from(flags);
    return toSortedCharArray(distinct);
  }

  /**
   * A sanity-check that the word form generated by affixation in {@link #getAllWordForms(String,
   * String, Runnable)} is indeed accepted by the spell-checker and analyzed to be the form of the
   * original dictionary entry. This can be overridden for cases where such check is unnecessary or
   * can be done more efficiently.
   *
   * @param derived the generated form to verify
   * @return {@code true} if the form isn't a forbidden word, removing its affixes yields the stem
   *     of its dictionary entry, and no forbidden stem candidate was encountered on the way
   */
  protected boolean canStemToOriginal(AffixedWord derived) {
    char[] wordChars = derived.getWord().toCharArray();
    if (isForbiddenWord(wordChars, 0, wordChars.length)) {
      return false;
    }

    String expectedStem = derived.getDictEntry().getStem();
    var checker =
        new Stemmer.StemCandidateProcessor(WordContext.SIMPLE_WORD) {
          boolean foundStem;
          boolean foundForbidden;

          @Override
          boolean processStemCandidate(
              char[] chars,
              int offset,
              int length,
              int lastAffix,
              int outerPrefix,
              int innerPrefix,
              int outerSuffix,
              int innerSuffix) {
            if (isForbiddenWord(chars, offset, length)) {
              foundForbidden = true;
              return false;
            }

            // compare lengths first to avoid creating a string for obviously different candidates
            if (length == expectedStem.length()
                && expectedStem.equals(new String(chars, offset, length))) {
              foundStem = true;
            }
            // NOTE(review): returning false presumably stops the candidate enumeration - confirm
            return !foundStem;
          }
        };
    stemmer.removeAffixes(wordChars, 0, wordChars.length, true, -1, -1, -1, checker);
    return checker.foundStem && !checker.foundForbidden;
  }

  /**
   * Checks whether the dictionary defines a forbidden-word flag and has an entry for the given
   * character range that carries it.
   */
  private boolean isForbiddenWord(char[] chars, int offset, int length) {
    if (dictionary.forbiddenword == FLAG_UNSET) {
      return false;
    }

    IntsRef forms = dictionary.lookupWord(chars, offset, length);
    if (forms == null) {
      return false;
    }

    int pos = 0;
    while (pos < forms.length) {
      int entryId = forms.ints[forms.offset + pos];
      if (dictionary.hasFlag(entryId, dictionary.forbiddenword)) {
        return true;
      }
      pos += dictionary.formStep();
    }
    return false;
  }

  /**
   * Applies every affix reachable through the given flags to the given word and, for
   * cross-product affixes, recursively continues with the derived word and the updated flags.
   *
   * @param stem the word (possibly already affixed) to attach further affixes to
   * @param flags the flags currently available for the word
   * @param checkCanceled called before each affix application
   * @return the derived forms that pass {@link #canStemToOriginal(AffixedWord)}, including the
   *     ones produced recursively
   */
  private List<AffixedWord> expand(AffixedWord stem, char[] flags, Runnable checkCanceled) {
    List<AffixedWord> derivedForms = new ArrayList<>();
    for (char flag : flags) {
      List<AffixEntry> sameFlag = affixes.get(flag);
      if (sameFlag == null) {
        continue;
      }

      // the kind of the first entry decides the compatibility check for the whole flag
      AffixKind kind = sameFlag.get(0).kind();
      if (!isCompatibleWithPreviousAffixes(stem, kind, flag)) {
        continue;
      }

      for (AffixEntry candidate : sameFlag) {
        checkCanceled.run();
        AffixedWord derived = candidate.apply(stem, dictionary);
        if (derived == null) {
          continue;
        }

        char[] appended = appendFlags(candidate);
        if (!shouldConsiderAtAll(appended)) {
          continue;
        }

        if (canStemToOriginal(derived)) {
          derivedForms.add(derived);
        }
        if (dictionary.isCrossProduct(candidate.id())) {
          char[] nextFlags = updateFlags(flags, flag, appended);
          derivedForms.addAll(expand(derived, nextFlags, checkCanceled));
        }
      }
    }
    return derivedForms;
  }

  /**
   * Returns {@code false} if any of the given flags equals one of the dictionary's compounding
   * flags (begin/middle/end), its forbidden-word flag or its only-in-compound flag.
   */
  private boolean shouldConsiderAtAll(char[] flags) {
    for (char flag : flags) {
      boolean compounding =
          flag == dictionary.compoundBegin
              || flag == dictionary.compoundMiddle
              || flag == dictionary.compoundEnd;
      boolean excluded = flag == dictionary.forbiddenword || flag == dictionary.onlyincompound;
      if (compounding || excluded) {
        return false;
      }
    }
    return true;
  }

  /**
   * Computes the flags available after applying an affix: the current flags without the flag of
   * the applied affix and without the dictionary's NEEDAFFIX flag, plus the flags appended by the
   * affix. The result is sorted and free of duplicates.
   *
   * @param flags the flags available before the affix was applied
   * @param toRemove the flag of the applied affix
   * @param toAppend the flags contributed by the applied affix
   */
  private char[] updateFlags(char[] flags, char toRemove, char[] toAppend) {
    // Size for the worst case and trim afterwards: the number of dropped flags is not always 1
    // (NEEDAFFIX may be dropped as well), and an unused slot would otherwise remain '\0' and be
    // passed on as a spurious flag.
    char[] merged = new char[flags.length + toAppend.length];
    int size = 0;
    for (char flag : flags) {
      if (flag != toRemove && flag != dictionary.needaffix) {
        merged[size++] = flag;
      }
    }
    for (char flag : toAppend) {
      merged[size++] = flag;
    }
    return sortAndDeduplicate(Arrays.copyOf(merged, size));
  }

  /**
   * Returns the flags that the given affix appends to a word it is applied to; an empty array if
   * the affix's append id is 0.
   */
  private char[] appendFlags(AffixEntry affix) {
    char appendId = dictionary.affixData(affix.id(), AFFIX_APPEND);
    if (appendId == 0) {
      return new char[0];
    }
    return dictionary.flagLookup.getFlags(appendId);
  }

  /**
   * Traverse the whole dictionary and derive all word forms via affixation (as in {@link
   * #getAllWordForms(String, String, Runnable)}) for each of the entries. The iteration order is
   * undefined. Only "simple" words are returned, no compounding flags are processed. Upper- and
   * title-case variations are not returned, even if the spellchecker accepts them.
   *
   * @param consumer the object that receives each derived word form
   * @param checkCanceled an object that's periodically called, allowing to interrupt the traversal
   *     and generation by throwing an exception
   */
  public void generateAllSimpleWords(Consumer<AffixedWord> consumer, Runnable checkCanceled) {
    dictionary.words.processAllWords(
        1,
        Integer.MAX_VALUE,
        false,
        rootWithForms -> {
          String root = rootWithForms.root().toString();
          IntsRef forms = rootWithForms.forms();
          for (int pos = 0; pos < forms.length; pos += dictionary.formStep()) {
            int entryId = forms.ints[forms.offset + pos];
            char[] flags = dictionary.flagLookup.getFlags(entryId);
            if (!shouldConsiderAtAll(flags)) {
              continue;
            }

            // print the flags before expanding: the expansion sorts the flag array in place
            String printedFlags = dictionary.flagParsingStrategy.printFlags(flags);
            DictEntry entry = DictEntry.create(root, printedFlags);
            getAllWordForms(entry, flags, checkCanceled).forEach(consumer);
          }
        });
  }

  /**
   * Given a list of words, try to produce a smaller set of dictionary entries (with some flags)
   * that would generate these words. This is equivalent to "munch" from the "hunspell-tools"
   * package. The algorithm tries to minimize the number of the dictionary entries to add or change,
   * the number of flags involved, and the number of non-requested additionally generated words. All
   * the mentioned words are in the dictionary format and case: no ICONV/OCONV/IGNORE conversions
   * are applied.
   *
   * @param words the list of words to generate
   * @param forbidden the set of words to avoid generating
   * @param checkCanceled an object that's periodically called, allowing to interrupt the generation
   *     by throwing an exception
   * @return the information about suggested dictionary entries and overgenerated words, or {@code
   *     null} if the algorithm couldn't generate anything
   * @throws IllegalArgumentException if any of the words is also contained in {@code forbidden}
   */
  public EntrySuggestion compress(
      List<String> words, Set<String> forbidden, Runnable checkCanceled) {
    if (words.isEmpty()) {
      return null;
    }

    for (String word : words) {
      if (forbidden.contains(word)) {
        throw new IllegalArgumentException("'words' and 'forbidden' shouldn't intersect");
      }
    }

    WordCompressor compressor = new WordCompressor(words, forbidden, checkCanceled);
    return compressor.compress();
  }

  /**
   * A single affix rule of the dictionary.
   *
   * @param id the affix id in the dictionary
   * @param flag the flag that enables this affix
   * @param kind whether this is a prefix or a suffix
   * @param affix the text attached to the word
   * @param strip the text removed from the word before attaching the affix
   * @param condition the condition the stripped word has to satisfy
   */
  private record AffixEntry(
      int id, char flag, AffixKind kind, String affix, String strip, AffixCondition condition) {

    /**
     * Applies this affix to the given word.
     *
     * @return the derived word with this affix put first in the list of its kind, or {@code null}
     *     if the word doesn't start/end with the strip text or the stripped word is rejected by
     *     the condition
     */
    AffixedWord apply(AffixedWord stem, Dictionary dictionary) {
      String word = stem.getWord();

      if (kind == AffixKind.PREFIX) {
        if (!word.startsWith(strip)) {
          return null;
        }
        String remainder = word.substring(strip.length());
        if (!condition.acceptsStem(remainder)) {
          return null;
        }

        // copy before modifying: the list of the source word must stay untouched
        List<Affix> prefixes = new ArrayList<>(stem.getPrefixes());
        prefixes.add(0, new Affix(dictionary, id));
        return new AffixedWord(
            affix + remainder, stem.getDictEntry(), prefixes, stem.getSuffixes());
      }

      if (!word.endsWith(strip)) {
        return null;
      }
      String remainder = word.substring(0, word.length() - strip.length());
      if (!condition.acceptsStem(remainder)) {
        return null;
      }

      List<Affix> suffixes = new ArrayList<>(stem.getSuffixes());
      suffixes.add(0, new Affix(dictionary, id));
      return new AffixedWord(remainder + affix, stem.getDictEntry(), stem.getPrefixes(), suffixes);
    }
  }

  /**
   * Decides whether an affix of the given kind and flag may be attached to a word that possibly
   * has affixes already.
   *
   * <p>It's rejected if the word has two affixes of that kind, if it's a second prefix and the
   * dictionary has no complex prefixes, if it's a suffix for a word that already has a prefix, or
   * if the single affix of that kind applied so far doesn't append the given flag.
   */
  private boolean isCompatibleWithPreviousAffixes(AffixedWord stem, AffixKind kind, char flag) {
    boolean prefix = kind == AffixKind.PREFIX;
    List<Affix> applied = prefix ? stem.getPrefixes() : stem.getSuffixes();
    int count = applied.size();
    if (count == 2) {
      return false;
    }

    if (prefix) {
      if (count == 1 && !dictionary.complexPrefixes) {
        return false;
      }
    } else if (!stem.getPrefixes().isEmpty()) {
      return false;
    }

    return count != 1 || dictionary.isFlagAppendedByAffix(applied.get(0).affixId, flag);
  }

  private class WordCompressor {
    // Orders search states: higher potential coverage first, then fewer stems, then fewer
    // requested words left ungenerated, then fewer non-requested words generated.
    private final Comparator<State> solutionFitness =
        Comparator.comparingInt((State s) -> -s.potentialCoverage)
            .thenComparingInt(s -> s.stemToFlags.size())
            .thenComparingInt(s -> s.underGenerated)
            .thenComparingInt(s -> s.overGenerated);
    // Words that the suggested entries must not generate.
    private final Set<String> forbidden;
    private final Runnable checkCanceled;
    // The requested words, as a set for fast lookup.
    private final Set<String> wordSet;
    // The stem candidates that have entries in the dictionary.
    private final Set<String> existingStems;
    // For each stem candidate, the flag sets that the stemmer reported for it.
    private final Map<String, Set<FlagSet>> stemToPossibleFlags = new HashMap<>();
    // For each stem candidate, the requested words it was registered for.
    private final Map<String, Set<String>> stemsToForms = new LinkedHashMap<>();

    /**
     * Collects, for every requested word, the stem candidates reported by the stemmer together
     * with the flags of the affixes involved, skipping candidates whose expansion with those flags
     * would produce a forbidden word.
     *
     * @param words the words to generate
     * @param forbidden the words that must not be generated
     * @param checkCanceled called for each requested word and during expansions, allowing to abort
     *     by throwing an exception
     */
    WordCompressor(List<String> words, Set<String> forbidden, Runnable checkCanceled) {
      this.forbidden = forbidden;
      this.checkCanceled = checkCanceled;
      wordSet = new HashSet<>(words);

      for (String word : words) {
        checkCanceled.run();
        // every requested word is a stem candidate itself, initially without any flag sets
        stemToPossibleFlags.computeIfAbsent(word, __ -> new LinkedHashSet<>());
        var processor =
            new Stemmer.StemCandidateProcessor(WordContext.SIMPLE_WORD) {
              @Override
              boolean processStemCandidate(
                  char[] chars,
                  int offset,
                  int length,
                  int lastAffix,
                  int outerPrefix,
                  int innerPrefix,
                  int outerSuffix,
                  int innerSuffix) {
                String candidate = new String(chars, offset, length);
                // gather the flags of the affixes that were removed; negative ids mean "absent"
                CharHashSet flags = new CharHashSet();
                if (outerPrefix >= 0) flags.add(dictionary.affixData(outerPrefix, AFFIX_FLAG));
                if (innerPrefix >= 0) flags.add(dictionary.affixData(innerPrefix, AFFIX_FLAG));
                if (outerSuffix >= 0) flags.add(dictionary.affixData(outerSuffix, AFFIX_FLAG));
                if (innerSuffix >= 0) flags.add(dictionary.affixData(innerSuffix, AFFIX_FLAG));
                FlagSet flagSet = new FlagSet(flags, dictionary);
                StemWithFlags swf = new StemWithFlags(candidate, Set.of(flagSet));
                // only accept the candidate if expanding it with these flags yields nothing
                // forbidden; the expansion is skipped entirely when nothing is forbidden
                if (forbidden.isEmpty()
                    || allGenerated(swf).stream().noneMatch(forbidden::contains)) {
                  registerStem(candidate);
                  stemToPossibleFlags
                      .computeIfAbsent(candidate, __ -> new LinkedHashSet<>())
                      .add(flagSet);
                }
                // NOTE(review): presumably tells the stemmer to keep reporting candidates - confirm
                return true;
              }

              // remembers that the word currently processed can be produced from the given stem
              void registerStem(String stem) {
                stemsToForms.computeIfAbsent(stem, __ -> new LinkedHashSet<>()).add(word);
              }
            };
        // the word is trivially a form of itself
        processor.registerStem(word);
        stemmer.removeAffixes(word.toCharArray(), 0, word.length(), true, -1, -1, -1, processor);
      }

      existingStems =
          stemsToForms.keySet().stream()
              .filter(stem -> dictionary.lookupEntries(stem) != null)
              .collect(Collectors.toSet());
    }

    /**
     * Searches for a set of stems and flags generating all requested words. States are explored
     * best-first according to {@code solutionFitness}; a new state is made by adding either a stem
     * (without flags) or one more flag set to a stem already present.
     *
     * @return the suggestion built from the first polled state that leaves no requested word
     *     ungenerated, or {@code null} if the queue is exhausted before such a state is found
     */
    EntrySuggestion compress() {
      // stems already present in the dictionary go first, then the ones registered for more forms
      Comparator<String> stemSorter =
          Comparator.comparing((String s) -> existingStems.contains(s))
              .thenComparing(s -> stemsToForms.get(s).size())
              .reversed();
      List<String> sortedStems = stemsToForms.keySet().stream().sorted(stemSorter).toList();
      PriorityQueue<State> queue = new PriorityQueue<>(solutionFitness);
      // stem-to-flags maps that have been built already, to avoid evaluating them again
      Set<Map<String, Set<FlagSet>>> visited = new HashSet<>();
      // the initial state: no entries, so every requested word is ungenerated
      queue.offer(new State(Map.of(), wordSet.size(), 0, 0));
      State result = null;
      while (!queue.isEmpty()) {
        State state = queue.poll();
        if (state.underGenerated == 0) {
          result = state;
          break;
        }

        // try adding each stem that isn't in the state yet; keep the new state only if it
        // generates more requested words or can potentially cover more of them
        for (String stem : sortedStems) {
          if (!state.stemToFlags.containsKey(stem)) {
            var withStem = addStem(state, stem);
            if (visited.add(withStem)) {
              var next = newState(withStem);
              if (next != null
                  && (state.underGenerated > next.underGenerated
                      || next.potentialCoverage > state.potentialCoverage)) {
                queue.offer(next);
              }
            }
          }
        }

        if (state.potentialCoverage < wordSet.size()) {
          // don't add flags until the suggested entries can potentially cover all requested forms
          continue;
        }

        // try adding each not yet used flag set to each stem of the state; keep the new state
        // only if it generates more requested words
        for (Map.Entry<String, Set<FlagSet>> entry : state.stemToFlags.entrySet()) {
          for (FlagSet flags : stemToPossibleFlags.get(entry.getKey())) {
            if (!entry.getValue().contains(flags)) {
              var withFlags = addFlags(state, entry.getKey(), flags);
              if (visited.add(withFlags)) {
                var next = newState(withFlags);
                if (next != null && state.underGenerated > next.underGenerated) {
                  queue.offer(next);
                }
              }
            }
          }
        }
      }
      return result == null ? null : toSuggestion(result);
    }

    /**
     * Converts the given state into a suggestion: one entry per stem with the union of its flag
     * sets, plus the sorted distinct generated words that were neither requested nor are existing
     * stems. Such a word becomes a forbidden-flagged entry instead if it's in the forbidden set
     * and the dictionary defines a forbidden-word flag.
     */
    EntrySuggestion toSuggestion(State state) {
      List<DictEntry> toEdit = new ArrayList<>();
      List<DictEntry> toAdd = new ArrayList<>();
      state
          .stemToFlags()
          .forEach((stem, flagSets) -> addEntry(toEdit, toAdd, stem, FlagSet.flatten(flagSets)));

      List<String> generated = allGenerated(state.stemToFlags()).distinct().sorted().toList();
      List<String> extraGenerated = new ArrayList<>();
      for (String word : generated) {
        boolean expected = wordSet.contains(word) || existingStems.contains(word);
        if (expected) {
          continue;
        }

        boolean markForbidden = forbidden.contains(word) && dictionary.forbiddenword != FLAG_UNSET;
        if (markForbidden) {
          addEntry(toEdit, toAdd, word, CharHashSet.from(dictionary.forbiddenword));
        } else {
          extraGenerated.add(word);
        }
      }

      return new EntrySuggestion(toEdit, toAdd, extraGenerated);
    }

    /**
     * Creates an entry for the given stem and flags and puts it into {@code toEdit} if the stem is
     * among the existing stems, and into {@code toAdd} otherwise.
     */
    private void addEntry(
        List<DictEntry> toEdit, List<DictEntry> toAdd, String stem, CharHashSet flags) {
      DictEntry entry = DictEntry.create(stem, toFlagString(flags));
      if (existingStems.contains(stem)) {
        toEdit.add(entry);
      } else {
        toAdd.add(entry);
      }
    }

    /** Returns a copy of the state's stem-to-flags map with the given stem added without flags. */
    private Map<String, Set<FlagSet>> addStem(State state, String stem) {
      Map<String, Set<FlagSet>> extended = new LinkedHashMap<>(state.stemToFlags());
      Set<FlagSet> noFlags = Set.of();
      extended.put(stem, noFlags);
      return extended;
    }

    /**
     * Returns a copy of the state's stem-to-flags map where the given flag set is added to the
     * flag sets of the given stem. The state's own map and sets are left untouched.
     */
    private Map<String, Set<FlagSet>> addFlags(State state, String stem, FlagSet flags) {
      Map<String, Set<FlagSet>> extended = new LinkedHashMap<>(state.stemToFlags());
      Set<FlagSet> widened = new LinkedHashSet<>(extended.get(stem));
      widened.add(flags);
      extended.put(stem, widened);
      return extended;
    }

    /**
     * Evaluates the given stem-to-flags map.
     *
     * @return the state with its metrics computed, or {@code null} if the entries would generate a
     *     forbidden word
     */
    private State newState(Map<String, Set<FlagSet>> stemToFlags) {
      Set<String> generated = allGenerated(stemToFlags).collect(Collectors.toSet());

      int notRequested = 0;
      for (String form : generated) {
        if (forbidden.contains(form)) {
          return null;
        }
        if (!wordSet.contains(form)) {
          notRequested++;
        }
      }

      // the distinct requested forms registered for any of the stems
      Set<String> coverable = new HashSet<>();
      for (String stem : stemToFlags.keySet()) {
        coverable.addAll(stemsToForms.get(stem));
      }

      int missing = 0;
      for (String requested : wordSet) {
        if (!generated.contains(requested)) {
          missing++;
        }
      }

      return new State(stemToFlags, missing, notRequested, coverable.size());
    }

    // Caches the words generated for a stem with given flag sets, as the same combinations are
    // expanded repeatedly while building the candidates and evaluating the search states.
    private final Map<StemWithFlags, List<String>> expansionCache = new HashMap<>();

    // The key of the expansion cache: a stem together with the flag sets assigned to it.
    private record StemWithFlags(String stem, Set<FlagSet> flags) {}

    /**
     * Returns the words generated from the given stem with the union of the given flag sets. The
     * result is computed at most once per distinct argument and cached afterwards.
     */
    private List<String> allGenerated(StemWithFlags swc) {
      return expansionCache.computeIfAbsent(
          swc,
          key ->
              expand(key.stem(), FlagSet.flatten(key.flags())).stream()
                  .map(AffixedWord::getWord)
                  .toList());
    }

    /**
     * Returns a lazy stream of the words generated by all entries of the given map; the stream may
     * contain duplicates.
     */
    private Stream<String> allGenerated(Map<String, Set<FlagSet>> stemToFlags) {
      Stream<StemWithFlags> entries =
          stemToFlags.entrySet().stream().map(e -> new StemWithFlags(e.getKey(), e.getValue()));
      return entries.flatMap(stemWithFlags -> allGenerated(stemWithFlags).stream());
    }

    /** Generates all word forms of the given stem, pretending it has the given flags. */
    private List<AffixedWord> expand(String stem, CharHashSet flagSet) {
      String printedFlags = toFlagString(flagSet);
      return getAllWordForms(stem, printedFlags, checkCanceled);
    }

    /** Prints the given flags, sorted, using the dictionary's flag parsing strategy. */
    private String toFlagString(CharHashSet flagSet) {
      char[] sortedFlags = toSortedCharArray(flagSet);
      return dictionary.flagParsingStrategy.printFlags(sortedFlags);
    }
  }

  /**
   * A set of affix flags, together with the dictionary needed to print them.
   *
   * @param flags the flags
   * @param dictionary the dictionary whose flag parsing strategy is used by {@link #toString()}
   */
  private record FlagSet(CharHashSet flags, Dictionary dictionary) {
    /** Returns a new set containing the flags of all the given flag sets. */
    static CharHashSet flatten(Set<FlagSet> flagSets) {
      CharHashSet union = new CharHashSet(flagSets.size() * 2);
      for (FlagSet each : flagSets) {
        union.addAll(each.flags);
      }
      return union;
    }

    /** Prints the flags, sorted, in the format of the dictionary's flag parsing strategy. */
    @Override
    public String toString() {
      char[] sortedFlags = toSortedCharArray(flags);
      return dictionary.flagParsingStrategy.printFlags(sortedFlags);
    }
  }

  /**
   * A candidate solution of the compression search together with the metrics used to rank it.
   *
   * @param stemToFlags the suggested stems, each with the flag sets chosen for it so far
   * @param underGenerated the number of requested words not generated by these entries
   * @param overGenerated the number of generated words that were not requested
   * @param potentialCoverage the number of distinct requested forms registered for the stems
   */
  private record State(
      Map<String, Set<FlagSet>> stemToFlags,
      int underGenerated,
      int overGenerated,

      // The maximum number of requested forms possibly generated by adding only flags to this state
      int potentialCoverage) {}
}
