diff --git a/aitrain.py b/aitrain.py index c4c0ca7..6867bda 100644 --- a/aitrain.py +++ b/aitrain.py @@ -2,17 +2,23 @@ from textgenrnn import textgenrnn import discord from discord.ext import commands import regex as re +import functools textgen = textgenrnn(name="insert3") client = commands.Bot(command_prefix='.') -boton = 0 +boton = False + +def add_to_train(clean_content): + print(clean_content) + with open("train.txt", "a") as train: + train.write(f"{clean_content} \n") @client.event async def on_message(message): await client.process_commands(message) urls = re.findall('http[s]?://(?:[a-zA-Z]|[0-9]|[$-_@.&+]|[!*(),]|(?:%[0-9a-fA-F][0-9a-fA-F]))+',message.content.lower()) if not message.attachments: - if boton == 1: + if boton: if message.content.lower().startswith("$"): pass elif urls: @@ -24,23 +30,25 @@ async def on_message(message): elif message.content.lower().startswith("—"): pass else: - print(message.clean_content) - train = open("train.txt", "a") - train.write(f"{message.clean_content} \n") - train.close() + async with message.channel.typing: + writing_function = functools.partial(add_to_train, message.clean_content) + await bot.loop.run_in_executor(None, writing_function) +@commands.is_owner() @client.command() async def train(ctx): - if ctx.author.id == 666378959184855042: - global boton - boton = 1 - await ctx.send("I am now training what to say based on your messages") + global boton + boton = True + await ctx.send("I am now training what to say based on your messages") +@commands.is_owner() @client.command() async def stopbot(ctx): - if ctx.author.id == 666378959184855042: - await ctx.send("data collection done, I will now log of discord and build an a.i") - textgen.train_from_file('train.txt', num_epochs=11) + async with ctx.typing: + training_function = functools.partial(textgen.train_from_file, 'train.txt', num_epochs=11) + await bot.loop.run_in_executor(None, training_function) + await ctx.send("data collection done, I will now log of discord and build an a.i") + await bot.logout() client.run('BOTTOKEN')