using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using BayesSharp.Combiners;
using BayesSharp.Tokenizers;
using Newtonsoft.Json;

namespace BayesSharp
{
    public class BayesClassifier<TTagType, TTokenType> where TTagType : IComparable
    {
        private TagDictionary<TTagType, TTokenType> _tags = new TagDictionary<TTagType, TTokenType>();
        private TagDictionary<TTagType, TTokenType> _cache;

        private readonly ITokenizer<TTokenType> _tokenizer;
        private readonly ICombiner _combiner;

        private bool _mustRecache;
        private const double Tolerance = 0.0001;
        private const double Threshold = 0.1;

        public BayesClassifier(ITokenizer<TTokenType> tokenizer)
            : this(tokenizer, new RobinsonCombiner())
        {
        }

        public BayesClassifier(ITokenizer<TTokenType> tokenizer, ICombiner combiner)
        {
            if (tokenizer == null) throw new ArgumentNullException(nameof(tokenizer));
            if (combiner == null) throw new ArgumentNullException(nameof(combiner));
            _tokenizer = tokenizer;
            _combiner = combiner;
            _tags.SystemTag = new TagData<TTokenType>();
            _mustRecache = true;
        }

        /// <summary>
        /// Create a new tag, without actually doing any training.
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        public void AddTag(TTagType tagId)
        {
            GetAndAddIfNotFound(_tags.Items, tagId);
            _mustRecache = true;
        }

        /// <summary>
        /// Remove a tag.
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        public void RemoveTag(TTagType tagId)
        {
            _tags.Items.Remove(tagId);
            _mustRecache = true;
        }

        /// <summary>
        /// Change the Id of a tag.
        /// </summary>
        /// <param name="oldTagId">Old Tag Id</param>
        /// <param name="newTagId">New Tag Id</param>
        public void ChangeTagId(TTagType oldTagId, TTagType newTagId)
        {
            _tags.Items[newTagId] = _tags.Items[oldTagId];
            RemoveTag(oldTagId);
            _mustRecache = true;
        }

        /// <summary>
        /// Merge an existing tag into another.
        /// </summary>
        /// <param name="sourceTagId">Tag to be merged into destTagId and removed</param>
        /// <param name="destTagId">Destination tag Id</param>
        public void MergeTags(TTagType sourceTagId, TTagType destTagId)
        {
            var sourceTag = _tags.Items[sourceTagId];
            var destTag = _tags.Items[destTagId];
            foreach (var tok in sourceTag.Items)
            {
                // Add the source tag's token frequency (tok.Value). The original
                // added a loop counter instead, which corrupted merged counts.
                if (destTag.Items.ContainsKey(tok.Key))
                {
                    destTag.Items[tok.Key] += tok.Value;
                }
                else
                {
                    destTag.Items[tok.Key] = tok.Value;
                    destTag.TokenCount += 1;
                }
            }
            RemoveTag(sourceTagId);
            _mustRecache = true;
        }

        /// <summary>
        /// Return the TagData object for the given Tag Id, or null if it does not exist.
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        public TagData<TTokenType> GetTagById(TTagType tagId)
        {
            return _tags.Items.ContainsKey(tagId) ? _tags.Items[tagId] : null;
        }

        /// <summary>
        /// Save the Bayes Text Classifier into a file.
        /// </summary>
        /// <param name="path">The file to write to</param>
        public void Save(string path)
        {
            using (var streamWriter = new StreamWriter(path, false, Encoding.UTF8))
            {
                JsonSerializer.Create().Serialize(streamWriter, _tags);
            }
        }

        /// <summary>
        /// Load a Bayes Text Classifier from a file.
        /// </summary>
        /// <param name="path">The file to open for reading</param>
        public void Load(string path)
        {
            using (var streamReader = new StreamReader(path, Encoding.UTF8))
            {
                using (var jsonTextReader = new JsonTextReader(streamReader))
                {
                    _tags = JsonSerializer.Create().Deserialize<TagDictionary<TTagType, TTokenType>>(jsonTextReader);
                }
            }
            _mustRecache = true;
        }

        /// <summary>
        /// Import a Bayes Text Classifier from a json string.
        /// </summary>
        /// <param name="json">The json content to be loaded</param>
        public void ImportJsonData(string json)
        {
            var result = JsonConvert.DeserializeObject<TagDictionary<TTagType, TTokenType>>(json);
            if (result != null)
            {
                _tags = result;
                _mustRecache = true;
            }
            else
            {
                _tags = new TagDictionary<TTagType, TTokenType>();
            }
        }

        /// <summary>
        /// Export the Bayes Text Classifier to a json string.
        /// </summary>
        public string ExportJsonData()
        {
            return _tags?.Items != null && _tags.Items.Any()
                ? JsonConvert.SerializeObject(_tags)
                : string.Empty;
        }

        /// <summary>
        /// Return a sorted list of Tag Ids.
        /// </summary>
        public IEnumerable<TTagType> TagIds()
        {
            return _tags.Items.Keys.OrderBy(p => p);
        }
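
        // The persistence members above (Save/Load/ImportJsonData/ExportJsonData) round-trip
        // the whole TagDictionary through Newtonsoft.Json. A minimal usage sketch; the
        // concrete tokenizer type name (SimpleTextTokenizer) is an assumption, standing in
        // for any ITokenizer<string> implementation:
        //
        //     var classifier = new BayesClassifier<string, string>(new SimpleTextTokenizer());
        //     classifier.Train("spam", "cheap pills fast delivery");
        //     classifier.Save("model.json");            // serialize trained state to disk
        //
        //     var restored = new BayesClassifier<string, string>(new SimpleTextTokenizer());
        //     restored.Load("model.json");              // restore state; forces a recache
        //     var json = restored.ExportJsonData();     // or keep the model as a string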
        /// <summary>
        /// Train Bayes by telling him that input belongs in tag.
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        /// <param name="input">Input to be trained</param>
        public void Train(TTagType tagId, string input)
        {
            var tokens = _tokenizer.Tokenize(input);
            var tag = GetAndAddIfNotFound(_tags.Items, tagId);
            _train(tag, tokens);
            _tags.SystemTag.TrainCount += 1;
            tag.TrainCount += 1;
            _mustRecache = true;
        }

        /// <summary>
        /// Untrain Bayes by telling him that input no longer belongs in tag.
        /// </summary>
        /// <param name="tagId">Tag Id</param>
        /// <param name="input">Input to be untrained</param>
        public void Untrain(TTagType tagId, string input)
        {
            var tokens = _tokenizer.Tokenize(input);
            // TryGetValue instead of the indexer: the indexer throws on a missing
            // key, so the original null check could never be reached.
            TagData<TTokenType> tag;
            if (!_tags.Items.TryGetValue(tagId, out tag))
            {
                return;
            }
            _untrain(tag, tokens);
            // Untraining reverses a training pass, so the train counters are
            // decremented; the original incremented them here, which was a bug.
            _tags.SystemTag.TrainCount -= 1;
            tag.TrainCount -= 1;
            _mustRecache = true;
        }

        /// <summary>
        /// Return the score of the provided input for each tag.
        /// </summary>
        /// <param name="input">Input to be classified</param>
        public Dictionary<TTagType, double> Classify(string input)
        {
            var tokens = _tokenizer.Tokenize(input).ToList();
            var tags = CreateCacheAndGetTags();

            var stats = new Dictionary<TTagType, double>();

            foreach (var tag in tags.Items)
            {
                var probs = GetProbabilities(tag.Value, tokens).ToList();
                if (probs.Count != 0)
                {
                    stats[tag.Key] = _combiner.Combine(probs);
                }
            }
            return stats.OrderByDescending(s => s.Value).ToDictionary(s => s.Key, pair => pair.Value);
        }

        #region Private Methods

        private void _train(TagData<TTokenType> tag, IEnumerable<TTokenType> tokens)
        {
            var tokenCount = 0;
            foreach (var token in tokens)
            {
                var count = tag.Get(token, 0);
                tag.Items[token] = count + 1;
                count = _tags.SystemTag.Get(token, 0);
                _tags.SystemTag.Items[token] = count + 1;
                tokenCount += 1;
            }
            tag.TokenCount += tokenCount;
            _tags.SystemTag.TokenCount += tokenCount;
        }

        private void _untrain(TagData<TTokenType> tag, IEnumerable<TTokenType> tokens)
        {
            foreach (var token in tokens)
            {
                var count = tag.Get(token, 0);
                if (count > 0)
                {
                    if (Math.Abs(count - 1) < Tolerance)
                    {
                        tag.Items.Remove(token);
                    }
                    else
                    {
                        tag.Items[token] = count - 1;
                    }
                    tag.TokenCount -= 1;
                }
                count = _tags.SystemTag.Get(token, 0);
                if (count > 0)
                {
                    if (Math.Abs(count - 1) < Tolerance)
                    {
                        _tags.SystemTag.Items.Remove(token);
                    }
                    else
                    {
                        _tags.SystemTag.Items[token] = count - 1;
                    }
                    _tags.SystemTag.TokenCount -= 1;
                }
            }
        }

        private static TagData<TTokenType> GetAndAddIfNotFound(IDictionary<TTagType, TagData<TTokenType>> dic, TTagType key)
        {
            if (dic.ContainsKey(key))
            {
                return dic[key];
            }
            dic[key] = new TagData<TTokenType>();
            return dic[key];
        }

        private TagDictionary<TTagType, TTokenType> CreateCacheAndGetTags()
        {
            if (!_mustRecache) return _cache;

            _cache = new TagDictionary<TTagType, TTokenType> { SystemTag = _tags.SystemTag };
            foreach (var tag in _tags.Items)
            {
                var thisTagTokenCount = tag.Value.TokenCount;
                var otherTagsTokenCount = Math.Max(_tags.SystemTag.TokenCount - thisTagTokenCount, 1);
                var cachedTag = GetAndAddIfNotFound(_cache.Items, tag.Key);

                foreach (var systemTagItem in _tags.SystemTag.Items)
                {
                    var thisTagTokenFreq = tag.Value.Get(systemTagItem.Key, 0.0);
                    if (Math.Abs(thisTagTokenFreq) < Tolerance)
                    {
                        continue;
                    }
                    var otherTagsTokenFreq = systemTagItem.Value - thisTagTokenFreq;

                    var goodMetric = thisTagTokenCount == 0 ? 1.0 : Math.Min(1.0, otherTagsTokenFreq / thisTagTokenCount);
                    var badMetric = Math.Min(1.0, thisTagTokenFreq / otherTagsTokenCount);
                    var f = badMetric / (goodMetric + badMetric);

                    // Only cache tokens whose probability is far enough from the
                    // 0.5 "no signal" midpoint, clamped away from exactly 0 and 1.
                    if (Math.Abs(f - 0.5) >= Threshold)
                    {
                        cachedTag.Items[systemTagItem.Key] = Math.Max(Tolerance, Math.Min(1 - Tolerance, f));
                    }
                }
            }
            _mustRecache = false;
            return _cache;
        }

        private static IEnumerable<double> GetProbabilities(TagData<TTokenType> tag, IEnumerable<TTokenType> tokens)
        {
            var probs = tokens.Where(tag.Items.ContainsKey).Select(t => tag.Items[t]);
            return probs.OrderByDescending(p => p).Take(2048);
        }

        #endregion Private Methods
    }
}
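
// A minimal end-to-end sketch of the classifier, assuming an ITokenizer<string>
// implementation (the SimpleTextTokenizer name below is an assumption) and the
// default RobinsonCombiner chosen by the single-argument constructor:
//
//     var classifier = new BayesClassifier<string, string>(new SimpleTextTokenizer());
//     classifier.Train("spam", "viagra cheap meds online pharmacy");
//     classifier.Train("ham", "meeting agenda for the quarterly review");
//
//     // Scores come back sorted descending; each value is the combined
//     // (Robinson) probability over the input's cached token probabilities.
//     Dictionary<string, double> scores = classifier.Classify("cheap meds for sale");
//     foreach (var pair in scores)
//     {
//         Console.WriteLine("{0}: {1:F4}", pair.Key, pair.Value);
//     }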