Skip to content

Commit

Permalink
add views merging 🔥
Browse files Browse the repository at this point in the history
  • Loading branch information
adiletto64 committed Aug 19, 2023
1 parent 91cb291 commit ef71809
Show file tree
Hide file tree
Showing 8 changed files with 174 additions and 40 deletions.
55 changes: 27 additions & 28 deletions apps/products/views.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,33 @@
from .models import Category
from django.contrib.auth.mixins import LoginRequiredMixin, UserPassesTestMixin
from django.shortcuts import reverse
from django.views.generic import DetailView, ListView, CreateView, UpdateView
from .models import Product
from .models import Category, Product


class CategoryDetailView(DetailView):
model = Category
context_object_name = "category"
template_name = "categories/category_detail.html"


class CategoryListView(ListView):
model = Category
context_object_name = "categories"
template_name = "categories/category_list.html"


class CategoryCreateView(LoginRequiredMixin, CreateView):
model = Category
fields = ["id", "name"]
template_name = "categories/category_create.html"
success_url = "/"


class CategoryUpdateView(LoginRequiredMixin, UpdateView):
model = Category
fields = ["id", "name"]
template_name = "categories/category_update.html"
success_url = "/"


class ProductDetailView(DetailView):
Expand Down Expand Up @@ -40,29 +65,3 @@ def test_func(self):

def get_success_url(self):
return reverse("products:detail", args=[self.get_object().id])


class CategoryDetailView(DetailView):
model = Category
context_object_name = "category"
template_name = "categories/category_detail.html"


class CategoryListView(ListView):
model = Category
context_object_name = "categories"
template_name = "categories/category_list.html"


class CategoryCreateView(LoginRequiredMixin, CreateView):
model = Category
fields = ["id", "name"]
template_name = "categories/category_create.html"
success_url = "/"


class CategoryUpdateView(UserPassesTestMixin, UpdateView):
model = Category
fields = ["id", "name"]
template_name = "categories/category_update.html"
success_url = "/"
2 changes: 1 addition & 1 deletion hogwarts/magic_urls/gen_urls.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def merge_urlpatterns(self, urlpatterns, views):
return urlpatterns


def gen_path(view, app_name, from_view_file: bool) -> str:
def gen_path(view, app_name, from_view_file=False) -> str:
decorator = PathDecorator(view)
if decorator.exists():
path_name = decorator.get_path_name()
Expand Down
24 changes: 22 additions & 2 deletions hogwarts/magic_views/gen_imports.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import ast
from typing import Tuple

Imports = list[Tuple[str, str]]
Expand All @@ -21,8 +22,11 @@ def add_bulk(self, module, objs: list[str]):
def gen(self):
merged_imports = self.get_merge_imports()
result = ""
for module, obj in merged_imports.items():
result += f"from {module} import {', '.join(obj)}\n"
for module, objs in merged_imports.items():
if module is None:
result += f"import {', '.join(objs)}\n"
else:
result += f"from {module} import {', '.join(objs)}\n"

return result

Expand All @@ -35,6 +39,22 @@ def get_merge_imports(self):
merged_imports[module].append(obj)
return merged_imports

def parse_imports(self, code):
imports = []
tree = ast.parse(code)

for node in ast.walk(tree):
if isinstance(node, ast.Import):
for alias in node.names:
self.add(None, alias.name)
elif isinstance(node, ast.ImportFrom):
module = node.module
for alias in node.names:
if module:
self.add(f"{'.' * node.level}{module}", alias.name)

return imports

@property
def imported_classes(self):
for _import in self.imports:
Expand Down
47 changes: 41 additions & 6 deletions hogwarts/magic_views/gen_views.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,30 @@
from ..utils import to_plural, code_strip, remove_empty_lines
from typing import Optional

from ..utils import to_plural, code_strip, remove_empty_lines, parse_class_names, remove_imports
from .gen_imports import ViewImportsGenerator


class ViewGenerator:
def __init__(self, model, smart_mode=False, model_is_namespace=False):
def __init__(
self,
model,
smart_mode=False,
model_is_namespace=False,
code: Optional[str] = None
):
self.smart_mode = smart_mode
self.model_is_namespace = model_is_namespace

self.model_name = model.__name__
self.name = model.__name__.lower()
self.fields = [field.name for field in model._meta.fields if field.editable]
self.creator_field = None
self.code = code
self.merge = code is not None
self.existing_class_names = []

if self.code:
self.existing_class_names = parse_class_names(self.code)

for field in self.fields:
if field in ["user", "author", "owner", "creator"]:
Expand All @@ -28,18 +42,28 @@ def gen(self):
update = self.update()

result = code_strip(self.gen_imports())
if self.merge:
result += "\n"
result += remove_imports(self.code)

for view in [detail, _list, create, update]:
result += f"\n\n{code_strip(view)}"
if view is not None:
result += f"\n\n{code_strip(view)}"
return result

def gen_imports(self):
if self.merge:
self.imports_generator.parse_imports(self.code)

self.imports_generator.add_bulk("django.views.generic", self.generic_views)
self.imports_generator.add(".models", self.model_name)

return self.imports_generator.gen()

def create(self):
if self.merge and f"{self.model_name}CreateView" in self.existing_class_names:
return None

self.generic_views.append("CreateView")
builder = self.get_builder("create")

Expand Down Expand Up @@ -71,21 +95,26 @@ def get_success_url(self):
return builder.gen()

def update(self):
if self.merge and f"{self.model_name}UpdateView" in self.existing_class_names:
return None

self.generic_views.append("UpdateView")
builder = self.get_builder("update")
builder.set_fields(self.fields)
self.set_template(builder, "update")

if self.smart_mode:
self.imports_generator.add_user_test()
builder.set_class("update", ["UserPassesTestMixin"])

if self.creator_field:
self.imports_generator.add_user_test()
builder.set_class("update", ["UserPassesTestMixin"])
function = """
def test_func(self):
return self.get_object() == self.request.user
"""
builder.set_extra_code(code_strip(function))
else:
self.imports_generator.add_login_required()
builder.set_class("update", ["LoginRequiredMixin"])

if self.model_is_namespace:
self.imports_generator.add_reverse()
Expand All @@ -100,6 +129,9 @@ def get_success_url(self):
return builder.gen()

def detail(self):
if self.merge and f"{self.model_name}DetailView" in self.existing_class_names:
return None

self.generic_views.append("DetailView")
builder = self.get_builder("detail")
builder.set_context_object_name(self.name)
Expand All @@ -108,6 +140,9 @@ def detail(self):
return builder.gen()

def list(self):
if self.merge and f"{self.model_name}ListView" in self.existing_class_names:
return None

self.generic_views.append("ListView")
builder = self.get_builder("list")
builder.set_context_object_name(to_plural(self.name))
Expand Down
14 changes: 12 additions & 2 deletions hogwarts/management/commands/genviews.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

from .base import get_app_config
from hogwarts.magic_views import ViewGenerator
from ...utils import parse_class_names


class Command(BaseCommand):
Expand Down Expand Up @@ -42,9 +43,18 @@ def handle(self, *args, **options):
if model_is_namespace or model_name.lower() in app_name:
namespace_model = True

code = ViewGenerator(model, smart_mode, namespace_model).gen()

path = os.path.join(app_config.path, "views.py")

generator = ViewGenerator(model, smart_mode, namespace_model)

with open(path, "r") as file:
existing_code = file.read()
is_empty = len(parse_class_names(existing_code)) == 0
if not is_empty:
generator = ViewGenerator(model, smart_mode, namespace_model, code=existing_code)

code = generator.gen()

with open(path, 'w') as file:
file.write(code)

Expand Down
3 changes: 2 additions & 1 deletion hogwarts/tests/url_tests/genurls_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ def test_it_generates_imports(generator):
expected = """
from django.urls import path
from .views import MyListView, MyFormView, get_view"""
from .views import MyListView, MyFormView, get_view
"""

assert code_strip(result) == code_strip(expected)
38 changes: 38 additions & 0 deletions hogwarts/tests/view_tests/gen_import_tests.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from hogwarts.magic_views.gen_imports import ImportsGenerator
from hogwarts.utils import code_strip

code = """
import ast
import datetime
from django.views.generic import DetailView, ListView
from .forms import MyForm
"""


def test_it_generates_imports_from_code():
gen = ImportsGenerator()
expected = {
None: ["ast", "datetime"],
"django.views.generic": ["DetailView", "ListView"],
".forms": ["MyForm"],
}
gen.parse_imports(code_strip(code))

assert gen.get_merge_imports() == expected


def test_it_merges_import_with_code():
gen = ImportsGenerator()

expected_code = """
import ast, datetime
from django.views.generic import DetailView, ListView, UpdateView
from .forms import MyForm, YouForm
"""

gen.parse_imports(code_strip(code))
gen.add("django.views.generic", "UpdateView")
gen.add(".forms", "YouForm")

assert gen.gen() == code_strip(expected_code)
31 changes: 31 additions & 0 deletions hogwarts/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,3 +63,34 @@ def remove_empty_lines(text: str):
lines = text.splitlines()
non_empty_lines = [line for line in lines if line.strip()] # DON'T TOUCH THIS LINE
return '\n'.join(non_empty_lines)


def parse_class_names(code):
class_names = []

tree = ast.parse(code)

for node in ast.walk(tree):
if isinstance(node, ast.ClassDef):
class_names.append(node.name)

return class_names


def remove_imports(code):
lines = code.split('\n')
new_lines = []
inside_import_block = False

for line in lines:
if line.strip().startswith("import ") or line.strip().startswith("from "):
inside_import_block = True
elif inside_import_block and not line.strip():
inside_import_block = False
continue

if not inside_import_block:
new_lines.append(line)

new_code = '\n'.join(new_lines)
return new_code

0 comments on commit ef71809

Please sign in to comment.