From ef71809c9d1265f5e6bc259eb21aad9530b95768 Mon Sep 17 00:00:00 2001 From: mrAdiletDev Date: Sat, 19 Aug 2023 23:29:53 +0600 Subject: [PATCH] =?UTF-8?q?add=20views=20merging=20=F0=9F=94=A5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- apps/products/views.py | 55 +++++++++---------- hogwarts/magic_urls/gen_urls.py | 2 +- hogwarts/magic_views/gen_imports.py | 24 +++++++- hogwarts/magic_views/gen_views.py | 47 ++++++++++++++-- hogwarts/management/commands/genviews.py | 14 ++++- hogwarts/tests/url_tests/genurls_tests.py | 3 +- hogwarts/tests/view_tests/gen_import_tests.py | 38 +++++++++++++ hogwarts/utils.py | 31 +++++++++++ 8 files changed, 174 insertions(+), 40 deletions(-) create mode 100644 hogwarts/tests/view_tests/gen_import_tests.py diff --git a/apps/products/views.py b/apps/products/views.py index 9aca725..66600f2 100644 --- a/apps/products/views.py +++ b/apps/products/views.py @@ -1,8 +1,33 @@ -from .models import Category from django.contrib.auth.mixins import LoginRequiredMixin, UserPassesTestMixin from django.shortcuts import reverse from django.views.generic import DetailView, ListView, CreateView, UpdateView -from .models import Product +from .models import Category, Product + + +class CategoryDetailView(DetailView): + model = Category + context_object_name = "category" + template_name = "categories/category_detail.html" + + +class CategoryListView(ListView): + model = Category + context_object_name = "categories" + template_name = "categories/category_list.html" + + +class CategoryCreateView(LoginRequiredMixin, CreateView): + model = Category + fields = ["id", "name"] + template_name = "categories/category_create.html" + success_url = "/" + + +class CategoryUpdateView(LoginRequiredMixin, UpdateView): + model = Category + fields = ["id", "name"] + template_name = "categories/category_update.html" + success_url = "/" class ProductDetailView(DetailView): @@ -40,29 +65,3 @@ def test_func(self): def get_success_url(self): return reverse("products:detail", args=[self.get_object().id]) - - -class CategoryDetailView(DetailView): - model = Category - context_object_name = "category" - template_name = "categories/category_detail.html" - - -class CategoryListView(ListView): - model = Category - context_object_name = "categories" - template_name = "categories/category_list.html" - - -class CategoryCreateView(LoginRequiredMixin, CreateView): - model = Category - fields = ["id", "name"] - template_name = "categories/category_create.html" - success_url = "/" - - -class CategoryUpdateView(UserPassesTestMixin, UpdateView): - model = Category - fields = ["id", "name"] - template_name = "categories/category_update.html" - success_url = "/" diff --git a/hogwarts/magic_urls/gen_urls.py b/hogwarts/magic_urls/gen_urls.py index fdfb153..6384750 100644 --- a/hogwarts/magic_urls/gen_urls.py +++ b/hogwarts/magic_urls/gen_urls.py @@ -95,7 +95,7 @@ def merge_urlpatterns(self, urlpatterns, views): return urlpatterns -def gen_path(view, app_name, from_view_file: bool) -> str: +def gen_path(view, app_name, from_view_file=False) -> str: decorator = PathDecorator(view) if decorator.exists(): path_name = decorator.get_path_name() diff --git a/hogwarts/magic_views/gen_imports.py b/hogwarts/magic_views/gen_imports.py index 8dbe030..c83c311 100644 --- a/hogwarts/magic_views/gen_imports.py +++ b/hogwarts/magic_views/gen_imports.py @@ -1,3 +1,4 @@ +import ast from typing import Tuple Imports = list[Tuple[str, str]] @@ -21,8 +22,11 @@ def add_bulk(self, module, objs: list[str]): def gen(self): merged_imports = self.get_merge_imports() result = "" - for module, obj in merged_imports.items(): - result += f"from {module} import {', '.join(obj)}\n" + for module, objs in merged_imports.items(): + if module is None: + result += f"import {', '.join(objs)}\n" + else: + result += f"from {module} import {', '.join(objs)}\n" return result @@ -35,6 +39,22 @@ def get_merge_imports(self): merged_imports[module].append(obj) return merged_imports + def parse_imports(self, code): + imports = [] + tree = ast.parse(code) + + for node in ast.walk(tree): + if isinstance(node, ast.Import): + for alias in node.names: + self.add(None, alias.name) + elif isinstance(node, ast.ImportFrom): + module = node.module + for alias in node.names: + if module: + self.add(f"{'.' * node.level}{module}", alias.name) + + return imports + @property def imported_classes(self): for _import in self.imports: diff --git a/hogwarts/magic_views/gen_views.py b/hogwarts/magic_views/gen_views.py index 9758ffa..2826ff3 100644 --- a/hogwarts/magic_views/gen_views.py +++ b/hogwarts/magic_views/gen_views.py @@ -1,9 +1,17 @@ -from ..utils import to_plural, code_strip, remove_empty_lines +from typing import Optional + +from ..utils import to_plural, code_strip, remove_empty_lines, parse_class_names, remove_imports from .gen_imports import ViewImportsGenerator class ViewGenerator: - def __init__(self, model, smart_mode=False, model_is_namespace=False): + def __init__( + self, + model, + smart_mode=False, + model_is_namespace=False, + code: Optional[str] = None + ): self.smart_mode = smart_mode self.model_is_namespace = model_is_namespace @@ -11,6 +19,12 @@ def __init__(self, model, smart_mode=False, model_is_namespace=False): self.name = model.__name__.lower() self.fields = [field.name for field in model._meta.fields if field.editable] self.creator_field = None + self.code = code + self.merge = code is not None + self.existing_class_names = [] + + if self.code: + self.existing_class_names = parse_class_names(self.code) for field in self.fields: if field in ["user", "author", "owner", "creator"]: @@ -28,18 +42,28 @@ def gen(self): update = self.update() result = code_strip(self.gen_imports()) + if self.merge: + result += "\n" + result += remove_imports(self.code) for view in [detail, _list, create, update]: - result += f"\n\n{code_strip(view)}" + if view is not None: + result += f"\n\n{code_strip(view)}" return result def gen_imports(self): + if self.merge: + self.imports_generator.parse_imports(self.code) + self.imports_generator.add_bulk("django.views.generic", self.generic_views) self.imports_generator.add(".models", self.model_name) return self.imports_generator.gen() def create(self): + if self.merge and f"{self.model_name}CreateView" in self.existing_class_names: + return None + self.generic_views.append("CreateView") builder = self.get_builder("create") @@ -71,21 +95,26 @@ def get_success_url(self): return builder.gen() def update(self): + if self.merge and f"{self.model_name}UpdateView" in self.existing_class_names: + return None + self.generic_views.append("UpdateView") builder = self.get_builder("update") builder.set_fields(self.fields) self.set_template(builder, "update") if self.smart_mode: - self.imports_generator.add_user_test() - builder.set_class("update", ["UserPassesTestMixin"]) - if self.creator_field: + self.imports_generator.add_user_test() + builder.set_class("update", ["UserPassesTestMixin"]) function = """ def test_func(self): return self.get_object() == self.request.user """ builder.set_extra_code(code_strip(function)) + else: + self.imports_generator.add_login_required() + builder.set_class("update", ["LoginRequiredMixin"]) if self.model_is_namespace: self.imports_generator.add_reverse() @@ -100,6 +129,9 @@ def get_success_url(self): return builder.gen() def detail(self): + if self.merge and f"{self.model_name}DetailView" in self.existing_class_names: + return None + self.generic_views.append("DetailView") builder = self.get_builder("detail") builder.set_context_object_name(self.name) @@ -108,6 +140,9 @@ def detail(self): return builder.gen() def list(self): + if self.merge and f"{self.model_name}ListView" in self.existing_class_names: + return None + self.generic_views.append("ListView") builder = self.get_builder("list") builder.set_context_object_name(to_plural(self.name)) diff --git a/hogwarts/management/commands/genviews.py b/hogwarts/management/commands/genviews.py index 86f13b8..6b2ff41 100644 --- a/hogwarts/management/commands/genviews.py +++ b/hogwarts/management/commands/genviews.py @@ -4,6 +4,7 @@ from .base import get_app_config from hogwarts.magic_views import ViewGenerator +from ...utils import parse_class_names class Command(BaseCommand): @@ -42,9 +43,18 @@ def handle(self, *args, **options): if model_is_namespace or model_name.lower() in app_name: namespace_model = True - code = ViewGenerator(model, smart_mode, namespace_model).gen() - path = os.path.join(app_config.path, "views.py") + + generator = ViewGenerator(model, smart_mode, namespace_model) + + with open(path, "r") as file: + existing_code = file.read() + is_empty = len(parse_class_names(existing_code)) == 0 + if not is_empty: + generator = ViewGenerator(model, smart_mode, namespace_model, code=existing_code) + + code = generator.gen() + with open(path, 'w') as file: file.write(code) diff --git a/hogwarts/tests/url_tests/genurls_tests.py b/hogwarts/tests/url_tests/genurls_tests.py index f904559..2912d1c 100644 --- a/hogwarts/tests/url_tests/genurls_tests.py +++ b/hogwarts/tests/url_tests/genurls_tests.py @@ -70,6 +70,7 @@ def test_it_generates_imports(generator): expected = """ from django.urls import path - from .views import MyListView, MyFormView, get_view""" + from .views import MyListView, MyFormView, get_view + """ assert code_strip(result) == code_strip(expected) diff --git a/hogwarts/tests/view_tests/gen_import_tests.py b/hogwarts/tests/view_tests/gen_import_tests.py new file mode 100644 index 0000000..8a15877 --- /dev/null +++ b/hogwarts/tests/view_tests/gen_import_tests.py @@ -0,0 +1,38 @@ +from hogwarts.magic_views.gen_imports import ImportsGenerator +from hogwarts.utils import code_strip + +code = """ +import ast +import datetime + +from django.views.generic import DetailView, ListView +from .forms import MyForm +""" + + +def test_it_generates_imports_from_code(): + gen = ImportsGenerator() + expected = { + None: ["ast", "datetime"], + "django.views.generic": ["DetailView", "ListView"], + ".forms": ["MyForm"], + } + gen.parse_imports(code_strip(code)) + + assert gen.get_merge_imports() == expected + + +def test_it_merges_import_with_code(): + gen = ImportsGenerator() + + expected_code = """ + import ast, datetime + from django.views.generic import DetailView, ListView, UpdateView + from .forms import MyForm, YouForm + """ + + gen.parse_imports(code_strip(code)) + gen.add("django.views.generic", "UpdateView") + gen.add(".forms", "YouForm") + + assert gen.gen() == code_strip(expected_code) diff --git a/hogwarts/utils.py b/hogwarts/utils.py index 9aa94ce..8d5ba9b 100644 --- a/hogwarts/utils.py +++ b/hogwarts/utils.py @@ -63,3 +63,34 @@ def remove_empty_lines(text: str): lines = text.splitlines() non_empty_lines = [line for line in lines if line.strip()] # DON'T TOUCH THIS LINE return '\n'.join(non_empty_lines) + + +def parse_class_names(code): + class_names = [] + + tree = ast.parse(code) + + for node in ast.walk(tree): + if isinstance(node, ast.ClassDef): + class_names.append(node.name) + + return class_names + + +def remove_imports(code): + lines = code.split('\n') + new_lines = [] + inside_import_block = False + + for line in lines: + if line.strip().startswith("import ") or line.strip().startswith("from "): + inside_import_block = True + elif inside_import_block and not line.strip(): + inside_import_block = False + continue + + if not inside_import_block: + new_lines.append(line) + + new_code = '\n'.join(new_lines) + return new_code