initial commit with sqlite support! hooray
This commit is contained in:
parent
9b2758b7fd
commit
1eb2ef5ba6
|
@ -0,0 +1,3 @@
|
|||
[submodule "tgbarebot"]
|
||||
path = tgbarebot
|
||||
url = git@git.dekedin.me:raphy/tgbarebot.git
|
|
@ -0,0 +1,118 @@
|
|||
import logging
|
||||
import configparser
|
||||
import sys
|
||||
|
||||
# from tgbarebot.bot import Bot
|
||||
from tgbarebot.bot import Bot
|
||||
from src.markov import MarkovLite, MarkovPost, Markov
|
||||
|
||||
|
||||
# logging.basicConfig(filename='debug.log', level=logging.DEBUG)
|
||||
logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
|
||||
|
||||
|
||||
def hi(bot, text, chat):
|
||||
return "hi."
|
||||
|
||||
|
||||
def debug(bot, text, chat):
|
||||
return str(chat)
|
||||
|
||||
|
||||
def makestore(config, markov: Markov):
|
||||
try:
|
||||
groups = config["AccessControl"]["groups"].split(",")
|
||||
|
||||
def groupcheckstore(bot, text, chat):
|
||||
|
||||
grp = chat["chat"]["id"]
|
||||
|
||||
if str(grp) in groups:
|
||||
markov.store(text, grp)
|
||||
|
||||
return ""
|
||||
|
||||
return groupcheckstore
|
||||
|
||||
except KeyError:
|
||||
|
||||
def groupnocheck(bot, text, chat):
|
||||
|
||||
markov.store(text, 0) # set group as 0 as a global value
|
||||
|
||||
return ""
|
||||
|
||||
return groupnocheck
|
||||
|
||||
|
||||
def makegen(config, markov: Markov):
|
||||
|
||||
try:
|
||||
groups = config["AccessControl"]["groups"].split(",")
|
||||
|
||||
def groupcheckgen(bot, text, chat):
|
||||
|
||||
grp = chat["chat"]["id"]
|
||||
|
||||
if str(grp) in groups:
|
||||
|
||||
try: # FUCK IT, UNDOCUMENTED FEATURE
|
||||
num = int(text.split(" ")[1])
|
||||
|
||||
res = ""
|
||||
for _ in range(num):
|
||||
res += " " + markov.generate(grp)
|
||||
return str(res)
|
||||
|
||||
except:
|
||||
return markov.generate(grp)
|
||||
else:
|
||||
return ""
|
||||
|
||||
return groupcheckgen
|
||||
|
||||
except KeyError:
|
||||
|
||||
def groupnocheck(bot, text, chat):
|
||||
|
||||
return markov.generate(0) # set group as 0 as a global value
|
||||
|
||||
return groupnocheck
|
||||
|
||||
|
||||
def main():
|
||||
print("starting bot lmao")
|
||||
|
||||
# read config
|
||||
config = configparser.ConfigParser()
|
||||
try:
|
||||
config.read("conf.ini")
|
||||
except:
|
||||
print("Couldn't read config file. Exiting")
|
||||
exit()
|
||||
|
||||
# Get token
|
||||
try:
|
||||
token = config["Telegram"]["token"]
|
||||
except KeyError:
|
||||
print("No TOKEN in conf. Exiting.")
|
||||
exit()
|
||||
|
||||
# Get working mode. Defaults to SQLITE3
|
||||
try:
|
||||
if config["Database"]["backend"] == "PostgreSQL":
|
||||
pass # TODO
|
||||
markov = MarkovLite()
|
||||
except KeyError:
|
||||
markov = MarkovLite(inmemory=False) # starting temporarily in memory
|
||||
|
||||
store = makestore(config, markov)
|
||||
gen = makegen(config, markov)
|
||||
|
||||
commands = {"/start": hi, "/debug": debug, "/gen": gen}
|
||||
|
||||
b = Bot(token, commands, fallback=store)
|
||||
b.poll()
|
||||
|
||||
|
||||
main()
|
|
@ -0,0 +1,205 @@
|
|||
import sqlite3
|
||||
|
||||
from sqlite3 import Error, Connection
|
||||
from sqlite3.dbapi2 import Cursor
|
||||
|
||||
import logging as log
|
||||
|
||||
# import re
|
||||
import random
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Tuple
|
||||
|
||||
|
||||
def processwords(text: str) -> List[Tuple[str, str]]:
|
||||
|
||||
# words = re.findall(r"\w+|[^\w\s]", text, re.UNICODE)
|
||||
|
||||
words: List[Any] = text.split(" ")
|
||||
# Add a NONE at the end of the list, to mark the end of a message AND the beginning of one.
|
||||
words.append(None)
|
||||
words.insert(0, None)
|
||||
|
||||
return list(zip(words, words[1:]))
|
||||
|
||||
|
||||
def makemarkov(worddict: Dict[str | None, List[Tuple[str, int]]]) -> str:
|
||||
|
||||
# Initial value is None.
|
||||
|
||||
# curword = random.choice(list(worddict.keys()))
|
||||
|
||||
curword = None
|
||||
|
||||
res = ""
|
||||
|
||||
# implement basic memoization for frequency dict
|
||||
|
||||
mem = {}
|
||||
|
||||
while True: # We just break when curword is NONE again.
|
||||
# append first, then pick next word
|
||||
|
||||
|
||||
# build frequency dict if not memoize
|
||||
if curword not in mem:
|
||||
wordlist = worddict[curword]
|
||||
l = len(wordlist)
|
||||
candidates = []
|
||||
weights_list = []
|
||||
sum = 0
|
||||
|
||||
for word, weight in wordlist:
|
||||
candidates.append(word)
|
||||
weights_list.append(weight)
|
||||
sum += weight
|
||||
|
||||
weights = list(map(lambda x: x / l, weights_list))
|
||||
|
||||
# memoize candidates and weights
|
||||
|
||||
mem[curword] = (candidates, weights)
|
||||
|
||||
# if they're memoize, pick random
|
||||
candidates, weights = mem[curword]
|
||||
|
||||
pick = random.choices(candidates, weights)
|
||||
|
||||
curword = pick[0]
|
||||
|
||||
if curword is None:
|
||||
break
|
||||
else:
|
||||
res += " " + curword
|
||||
|
||||
return res
|
||||
|
||||
|
||||
class Markov(ABC):
|
||||
def __init__(self):
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def store(self, text: str, grp: int):
|
||||
"""Store a message in a database. The telegram grpup id is used to identify the group."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def generate(self, grp: int) -> str:
|
||||
"""Generate a random message from said group."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def status(self) -> str:
|
||||
"""Print overall status of the database."""
|
||||
pass
|
||||
|
||||
|
||||
class MarkovLite(Markov):
|
||||
def __init__(self, db_file="markov.db", inmemory=False):
|
||||
|
||||
log.info(sqlite3.version)
|
||||
|
||||
conn = self.create_connection(db_file, inmemory)
|
||||
|
||||
if conn == None:
|
||||
log.error("Could not connect to database.")
|
||||
raise ValueError
|
||||
|
||||
self.conn: Connection = conn
|
||||
|
||||
self.create_tables()
|
||||
|
||||
def store(self, text: str, grp: int):
|
||||
log.debug(f"storing {text} from group {grp}")
|
||||
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
for word_pair in processwords(text):
|
||||
self.insert_words(cursor, word_pair, grp)
|
||||
|
||||
log.debug("committing changes...")
|
||||
|
||||
self.conn.commit()
|
||||
|
||||
def generate(self, grp: int) -> str:
|
||||
log.debug(f"generated a string from group : {grp}")
|
||||
|
||||
worddict = self.get_worddict(grp)
|
||||
|
||||
return makemarkov(worddict)
|
||||
|
||||
def status(self):
|
||||
return "All ok! im good man"
|
||||
|
||||
def create_connection(self, db_file: str, inmemory: bool) -> Connection | None:
|
||||
"""create a database connection to a SQLite database"""
|
||||
conn = None
|
||||
try:
|
||||
if inmemory:
|
||||
conn = sqlite3.connect(":memory:")
|
||||
else:
|
||||
conn = sqlite3.connect(db_file)
|
||||
except Error as e:
|
||||
log.error(e)
|
||||
finally:
|
||||
return conn
|
||||
|
||||
def create_tables(self):
|
||||
"""create a words table"""
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
words_table = f"""CREATE TABLE IF NOT EXISTS words (
|
||||
group_id integer NOT NULL,
|
||||
word text,
|
||||
next_word text,
|
||||
count integer NOT NULL,
|
||||
WITHOUT ROWID,
|
||||
PRIMARY KEY (group_id, word, next_word)
|
||||
);
|
||||
"""
|
||||
|
||||
try:
|
||||
|
||||
log.info("Creating words tables into databse")
|
||||
|
||||
cursor.execute(words_table)
|
||||
|
||||
log.debug("committing changes...")
|
||||
self.conn.commit()
|
||||
except Error as e:
|
||||
log.error("Could not create table:")
|
||||
log.error(e)
|
||||
|
||||
def insert_words(self, cursor: Cursor, word_pair: Tuple[str, str], grp: int):
|
||||
|
||||
log.debug(f"Inserting words {word_pair} for group with id {grp}")
|
||||
|
||||
sql = """INSERT INTO words(group_id, word, next_word, count)
|
||||
VALUES(?,?,?,?) ON CONFLICT(group_id, word, next_word) DO UPDATE SET count = count + 1
|
||||
"""
|
||||
|
||||
parameters = (grp, word_pair[0], word_pair[1], 1)
|
||||
|
||||
cursor.execute(sql, parameters)
|
||||
|
||||
def get_worddict(self, grp: int):
|
||||
|
||||
cursor = self.conn.cursor()
|
||||
|
||||
cursor.execute(f"SELECT * FROM words WHERE group_id={grp}")
|
||||
|
||||
worddict = {}
|
||||
for _, word, nextword, count, _ in cursor.fetchall():
|
||||
if not word in worddict:
|
||||
worddict[word] = [(nextword, count)]
|
||||
else:
|
||||
worddict[word] += [(nextword, count)]
|
||||
|
||||
return worddict
|
||||
|
||||
|
||||
class MarkovPost(Markov):
|
||||
def __init__(self):
|
||||
pass
|
|
@ -0,0 +1 @@
|
|||
Subproject commit 977491a3f60abf489dcab67d4aaa77e9d246683e
|
Loading…
Reference in New Issue