/*****************************************************************
 * C/C++ Headers
 *****************************************************************/

#include <stdio.h>
#include <ctype.h>
#include <stdlib.h>
#include <time.h>

#include <iostream>
#include <sstream>
#include <vector>
#include <string>
#include <map>

using std::vector;
using std::string;
using std::cout;
using std::endl;
using std::map;
using std::ostream;
using std::ostringstream;

/*****************************************************************
 * Local Types
 *****************************************************************/

// -------------------------
// WordPair
// -------------------------

// Two consecutive tokens (words or single characters) taken from the
// input. Pairs are used as the keys of the WordPairMap, so the struct
// is a plain value type: default constructible, copyable, assignable.
struct WordPair
{
   // Builds an empty pair: both tokens are the empty string.
   WordPair();

   // Builds the pair (first, second).
   WordPair(const string &first, const string &second);

   // Returns a printable form of 's': a string that consists of exactly
   // one carriage return, newline or tab is replaced by its two
   // character escape sequence, any other string is returned unchanged.
   static string fmt(const string &s);

   // Returns the pair rendered as "[first|second]", each token
   // passed through fmt().
   string toString() const;

   // Writes toString() followed by a newline to standard output.
   void   dump()     const;

   string mFirst;
   string mSecond;
};

WordPair::WordPair()
{
}

WordPair::WordPair(const string &first, const string &second) :
   mFirst(first), mSecond(second)
{
}

string WordPair::fmt(const string &s)
{
   // Only single character strings are ever rewritten

   if(s.size() != 1)
      return(s);

   switch(s[0])
   {
      case '\r': return("\\r");
      case '\n': return("\\n");
      case '\t': return("\\t");
      default  : break;
   }

   return(s);
}

string WordPair::toString() const
{
   return("[" + fmt(mFirst) + "|" + fmt(mSecond) + "]");
}

void WordPair::dump() const
{
   const string text = toString();

   cout << text << endl;
}

// -------------------------
// WordPairLessThan
// -------------------------

// Comparison functor that orders WordPair objects lexicographically:
// by mFirst, and by mSecond when the first tokens are equal. It is the
// ordering used for the keys of WordPairMap.
struct WordPairLessThan
{
   bool operator()(const WordPair &one, const WordPair &two) const
   {
      const int byFirst = one.mFirst.compare(two.mFirst);

      // The first tokens differ, they alone decide the order

      if(byFirst != 0)
         return(byFirst < 0);

      // Same first token, fall back to the second one

      return(one.mSecond.compare(two.mSecond) < 0);
   }
};

/*****************************************************************
 * Local Typedefs
 *****************************************************************/

// The tokens read from the input file, in the order they appeared
// (filled in by wordsFromFile)
typedef vector<string>   Words;

// Maps a token to the number of times it was observed
typedef map<string, int> WordToFrequency;

// Maps a pair of consecutive tokens to the tokens that were seen
// immediately after that pair, with their counts (filled in by makeMap)
typedef map<WordPair, WordToFrequency, WordPairLessThan> WordPairMap;

/*****************************************************************
 * Local Functions
 *****************************************************************/

static void dumpWordToFrequency(
   const WordToFrequency &theMap, string indent = ""
)
{
   WordToFrequency::const_iterator i = theMap.begin();

   while(i != theMap.end())
   {
      cout << indent << WordPair::fmt(i->first) << ": " << i->second << endl;
      i++;
   }
}

static void dumpWordPairMap(const WordPairMap &theMap)
{
   WordPairMap::const_iterator i = theMap.begin();

   while(i != theMap.end())
   {
      i->first.dump();
      dumpWordToFrequency(i->second, "   ");
      i++;
   }
}

// Populates 'theMap' from the token sequence 'words'.
//
// For every three consecutive tokens (a, b, c) the count stored under
// theMap[(a, b)][c] is incremented, so that afterwards each pair maps
// to the tokens that followed it and how often they did. Entries that
// are already in 'theMap' are kept and added to. Fewer than three
// tokens leave 'theMap' untouched.
static void makeMap(WordPairMap &theMap, Words &words)
{
   const Words::size_type n = words.size();

   // 'i + 2 < n' rather than 'i < n - 2' so that the unsigned
   // arithmetic cannot wrap around when n is less than 2

   for(Words::size_type i = 0; i + 2 < n; i ++)
   {
      const WordPair thePair(words[i], words[i + 1]);

      // map::operator[] inserts a value-initialized element when the
      // key is missing (an empty WordToFrequency, a zero count), so a
      // single chained lookup both creates the entries on first sight
      // and records that words[i + 2] followed this pair

      theMap[thePair][words[i + 2]]++;
   }
}

// Reads 'fileName' and appends its tokens to 'words'.
//
//    doCharacters == true : every character read is a token of its own,
//                           whitespace included
//    doCharacters == false: tokens are the maximal runs of characters
//                           for which isspace() is false
//
// Returns true on success. If the file cannot be opened, 'errMsg' is
// set to a description, 'words' is left untouched and false is returned.
static bool wordsFromFile(
   const string &fileName, bool doCharacters, Words &words, string &errMsg
)
{
   FILE *in = fopen(fileName.c_str(), "r");

   if(in == NULL)
   {
      errMsg = "Could not open: [" + fileName + "]";
      return(false);
   }

   int ch = 0;

   if(doCharacters)
   {
      while((ch = fgetc(in)) != EOF)
         words.push_back(string(1, (char)ch));
   }
   else
   {
      // A non-empty 'token' means we are in the middle of a word

      string token;

      while((ch = fgetc(in)) != EOF)
      {
         if(!isspace(ch))
         {
            token += (char)ch;
         }
         else if(!token.empty())
         {
            // Whitespace ends the word being collected

            words.push_back(token);
            token.clear();
         }
      }

      // The file may end without whitespace after the last word

      if(!token.empty())
         words.push_back(token);
   }

   fclose(in);

   return(true);
}

// Returns a pseudo-random integer in the range 0 .. tooLarge - 1.
//
// The generator is seeded from the current time on the first call.
// A 'tooLarge' of zero or less describes an empty range; 0 is returned
// in that case instead of evaluating 'x % 0', which is undefined.
static int getRandom(int tooLarge)
{
   static bool initialized = false;

   if(!initialized)
   {
      srand(static_cast<unsigned int>(time(NULL)));
      initialized = true;
   }

   if(tooLarge <= 0)
      return(0);

   // rand() yields a value between 0 and RAND_MAX,
   // so no sign correction is needed before reducing it

   return(rand() % tooLarge);
}

// Returns a separator line made of 70 '-' characters.
static string dash()
{
   // Built once, on the first call, and reused afterwards

   static const string rule(70, '-');

   return(rule);
}

// Prints the command line help for 'program' to standard output
// and terminates the process with exit status 1. Does not return.
static void usage(const string &program)
{
   static const char *const optionHelp[] =
   {
      "   nOutputWords: integer > 0, amount of output to generate",
      "   -c: operate on characters instead of words",
      "   -x: exit on dead end instead of choosing a random word",
      "   -r: random starting pair instead of first pair",
      "   -d: debug output"
   };

   const int nLines = sizeof(optionHelp) / sizeof(optionHelp[0]);

   cout << "Usage: " << program << " fileName { arguments }" << endl;

   for(int line = 0; line < nLines; line ++)
      cout << optionHelp[line] << endl;

   exit(1);
}

int main(int argc, char **argv)
{
   string program = argv[0];

   if(argc < 2)
      usage(program);

   string fileName        = argv[1];
   int    nOutputWords    = 100;
   bool   doCharacters    = false;
   bool   exitOnDeadEnd   = false;
   bool   randomStartPair = false;
   bool   debug           = false;
   int    i               = 0;

   for(i = 2; i < argc; i ++)
   {
      if(argv[i][0] == '-')
      {
         if(strlen(argv[i]) < 2)
            usage(program);

         switch(argv[i][1])
         {
            case 'c': case 'C': doCharacters    = true; break;
            case 'x': case 'X': exitOnDeadEnd   = true; break;
            case 'r': case 'R': randomStartPair = true; break;
            case 'd': case 'D': debug           = true; break;
            default : usage(program);
         }
      }
      else
      {
         char *ptr           = NULL;
         unsigned long value = strtoul(argv[i], &ptr, 10);

         if(*ptr || !value)
            usage(program);

         nOutputWords = value;
      }
   }

   Words  theWords;
   string errMsg;

   if(!wordsFromFile(fileName, doCharacters, theWords, errMsg))
   {
      cout << errMsg << endl;
      return(1);
   }

   int nWords = theWords.size();

   if(nWords < 3)
   {
      cout << "Too few words in: [" << fileName << "]" << endl;
      return(1);
   }

   if(debug)
   {
      cout << dash() << endl;
      cout << "debug: Words - From: [" << fileName << "]" << endl;
      cout << dash() << endl;

      for(int j = 0; j < nWords; j ++)
      {
         cout <<
            "word[" << j << "] : " <<
            "{" << WordPair::fmt(theWords[j]) << "}" << endl;
      }
   }

   WordPairMap theMap;

   makeMap(theMap, theWords);

   if(debug)
   {
      cout << dash() << endl;
      cout << "debug: WordPairMap" << endl;
      cout << dash() << endl;

      dumpWordPairMap(theMap);
   }

   int startIndex = 0;

   if(randomStartPair)
      startIndex = getRandom(nWords - 1);

   WordPair thePair(theWords[startIndex], theWords[startIndex + 1]);

   string separator = (doCharacters ? "" : " ");

   cout << thePair.mFirst << separator << thePair.mSecond;

   for(i = 0; i < nOutputWords; i ++)
   {
      // Each pair is mapped to a 'WordToFrequency' object
      // that records the words that followed this pair and
      // how often they followed, something like ...
      //
      //    w["the"]  => 2
      //    w["and"]  => 1
      //    w["some"] => 6
      //
      // If only one word ever followed this pair, we'll
      // just use that word but if there were multiple
      // words, we select one at random.

      WordToFrequency w = theMap[thePair];
      string nextWord   = "";

      if(w.size() == 0)
      {
         if(exitOnDeadEnd)
            break;

         // No words found following this pair,
         // use a random word

         nextWord = theWords[getRandom(nWords)];
      }
      else
      {
         if(w.size() < 2)
         {
            nextWord  = w.begin()->first;
         }
         else
         {
            WordToFrequency::const_iterator it = w.begin();

            for(int total = 0; it != w.end(); it ++)
            {
               // If this is the first pass through, use
               // this word for 'nextWord'

               if(total == 0)
               {
                  nextWord = it->first;
               }
               else
               {
                  // We'll either leave 'nextWord' alone or update it to
                  // it->second, the odds that we'll update it will be:
                  //
                  //    it->second / (total + it->second)
                  //
                  // So, we get a random value between 0 and
                  // total + it->second - 1 and if that value is
                  // less than it->second we do the update

                  int theValue = getRandom(total + it->second);

                  if(theValue < it->second)
                     nextWord = it->first;
               }

               total += it->second;
            }
         }
      }

      cout << separator << nextWord;

      thePair = WordPair(thePair.mSecond, nextWord);
   }

   return(0);
}