Add files via upload

This commit is contained in:
NodeMixaholic 2023-01-14 17:56:09 -08:00 committed by GitHub
parent 2cae5e6c12
commit 3586a17c2c
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23

View file

@ -1,16 +1,81 @@
import tensorflow as tf
import tensorflow_hub as hub
import numpy as np
import pyautogui import pyautogui
import win32api, win32con, win32gui import win32api, win32con, win32gui
import cv2 import cv2
import math import math
import time import time
import argparse
import os
import json
import numpy as np
from threading import Lock
# printing only warnings and error messages
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "1"
try:
import tensorflow as tf
from PIL import Image
except ImportError:
raise ImportError("ERROR: Failed to import libraries. Please refer to READEME.md file\n")
EXPORT_MODEL_VERSION = 1
class TFModel:
def __init__(self, dir_path) -> None:
# Assume model is in the parent directory for this file
self.model_dir = os.path.dirname(dir_path)
# make sure our exported SavedModel folder exists
with open(os.path.join(self.model_dir, "signature.json"), "r") as f:
self.signature = json.load(f)
self.model_file = os.path.join(self.model_dir, self.signature.get("filename"))
if not os.path.isfile(self.model_file):
raise FileNotFoundError(f"Model file does not exist")
self.inputs = self.signature.get("inputs")
self.outputs = self.signature.get("outputs")
self.lock = Lock()
# loading the saved model
self.model = tf.saved_model.load(tags=self.signature.get("tags"), export_dir=self.model_dir)
self.predict_fn = self.model.signatures["serving_default"]
# Look for the version in signature file.
# If it's not found or the doesn't match expected, print a message
version = self.signature.get("export_model_version")
if version is None or version != EXPORT_MODEL_VERSION:
print(
f"There has been a change to the model format. Please use a model with a signature 'export_model_version' that matches {EXPORT_MODEL_VERSION}."
)
def predict(self, image: Image.Image) -> dict:
with self.lock:
# create the feed dictionary that is the input to the model
feed_dict = {}
# first, add our image to the dictionary (comes from our signature.json file)
feed_dict[list(self.inputs.keys())[0]] = tf.convert_to_tensor(image)
# run the model!
outputs = self.predict_fn(**feed_dict)
# return the processed output
return self.process_output(outputs)
def process_output(self, outputs) -> dict:
# do a bit of postprocessing
out_keys = ["label", "confidence"]
results = {}
# since we actually ran on a batch of size 1, index out the items from the returned numpy arrays
for key, tf_val in outputs.items():
val = tf_val.numpy().tolist()[0]
if isinstance(val, bytes):
val = val.decode()
results[key] = val
confs = results["Confidences"]
labels = self.signature.get("classes").get("Label")
output = [dict(zip(out_keys, group)) for group in zip(labels, confs)]
sorted_output = {"predictions": sorted(output, key=lambda k: k["confidence"], reverse=True)}
return sorted_output
dirname = os.path.dirname(__file__)
detector = tf.saved_model.load(dirname)
size_scale = 3 size_scale = 3
model = TFModel(dir_path=os.path.dirname(__file__))
while True: while True:
# Get rect of Window # Get rect of Window
hwnd = win32gui.FindWindow(None, 'Roblox') hwnd = win32gui.FindWindow(None, 'Roblox')
@ -25,7 +90,7 @@ while True:
img_w, img_h = image.shape[2], image.shape[1] img_w, img_h = image.shape[2], image.shape[1]
# Detection # Detection
result = detector(image) outputs = model.predict(image)
result = {key:value.numpy() for key,value in result.items()} result = {key:value.numpy() for key,value in result.items()}
boxes = result['detection_boxes'][0] boxes = result['detection_boxes'][0]
scores = result['detection_scores'][0] scores = result['detection_scores'][0]
@ -75,8 +140,4 @@ while True:
time.sleep(0.1) time.sleep(0.1)
win32api.mouse_event(win32con.MOUSEEVENTF_LEFTUP, x, y, 0, 0) win32api.mouse_event(win32con.MOUSEEVENTF_LEFTUP, x, y, 0, 0)
#ori_img = cv2.cvtColor(ori_img, cv2.COLOR_BGR2RGB)
#cv2.imshow("ori_img", ori_img)
#cv2.waitKey(1)
time.sleep(0.1) time.sleep(0.1)