Skip to content

Commit

Permalink
Update per PR comments #1228
Browse files Browse the repository at this point in the history
Reference: #1228

Signed-off-by: John M. Horan <johnmhoran@gmail.com>
  • Loading branch information
johnmhoran committed Nov 22, 2023
1 parent f1a0530 commit 9ec2a6a
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 21 deletions.
26 changes: 12 additions & 14 deletions vulnerabilities/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,17 +53,21 @@ def get_affected_vulnerabilities(self, package):
affected_vulnerabilities = []

for vuln in parent_affected_vulnerabilities:
self.get_vulnerability(vuln, affected_vulnerabilities)
# self.get_vulnerability(vuln, affected_vulnerabilities)
# affected_vulnerabilities.append(self.get_vulnerability(vuln, affected_vulnerabilities))
affected_vulnerabilities.append(self.get_vulnerability(vuln))

return affected_vulnerabilities

def get_vulnerability(self, vuln, affected_vulnerabilities):
# def get_vulnerability(self, vuln, affected_vulnerabilities):
def get_vulnerability(self, vuln):
affected_vulnerability = {}

vulnerability = vuln.get("vulnerability")
if vulnerability:
affected_vulnerability["vulnerability"] = vulnerability.vulnerability_id
affected_vulnerabilities.append(affected_vulnerability)
# affected_vulnerabilities.append(affected_vulnerability)
return affected_vulnerability

affected_by_vulnerabilities = serializers.SerializerMethodField("get_affected_vulnerabilities")

Expand Down Expand Up @@ -150,17 +154,13 @@ def get_next_non_vulnerable(self, package):
next_non_vulnerable = package.fixed_package_details.get("next_non_vulnerable", None)
if next_non_vulnerable:
return next_non_vulnerable.version
else:
return None

latest_non_vulnerable_version = serializers.SerializerMethodField("get_latest_non_vulnerable")

def get_latest_non_vulnerable(self, package):
latest_non_vulnerable = package.fixed_package_details.get("latest_non_vulnerable", None)
if latest_non_vulnerable:
return latest_non_vulnerable.version
else:
return None

purl = serializers.CharField(source="package_url")

Expand Down Expand Up @@ -215,21 +215,19 @@ def get_affected_vulnerabilities(self, package) -> dict:
fix each vulnerability and whose version is greater than the `package` version).
"""
excluded_purls = []
filtered_vuln_serializer = self.get_vulnerabilities_for_a_package(
package=package, fix=False
)
package_vulnerabilities = self.get_vulnerabilities_for_a_package(package=package, fix=False)

for vuln in filtered_vuln_serializer:
for vuln in package_vulnerabilities:
for pkg in vuln["fixed_packages"]:
real_PURL = PackageURL.from_string(pkg["purl"])
if package.version_class(real_PURL.version) <= package.current_version:
real_purl = PackageURL.from_string(pkg["purl"])
if package.version_class(real_purl.version) <= package.current_version:
excluded_purls.append(pkg)

vuln["fixed_packages"] = [
pkg for pkg in vuln["fixed_packages"] if pkg not in excluded_purls
]

return filtered_vuln_serializer
return package_vulnerabilities

class Meta:
model = Package
Expand Down
71 changes: 65 additions & 6 deletions vulnerabilities/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,12 @@ def paginated(self, per_page=5000):


class VulnerabilityQuerySet(BaseQuerySet):
def affected_packages(self):
"""
Return only packages affected by a vulnerability.
"""
return self.filter(packagerelatedvulnerability__fix=False)

def with_cpes(self):
"""
Return a queryset of Vulnerability that have one or more NVD CPE references.
Expand Down Expand Up @@ -644,7 +650,6 @@ def get_absolute_url(self):
"""
return reverse("package_details", args=[self.purl])

# TODO: There are other methods, variables etc. in models.py that need similar renaming.
def get_fixed_by_package_versions(self, fix=True):
"""
Return a queryset of all the package versions of this `package` that fix any vulnerability.
Expand All @@ -670,10 +675,14 @@ def sort_by_version(self, packages):
if not packages:
return []

version_class = RANGE_CLASS_BY_SCHEMES[packages[0].type].version_class
# version_class = RANGE_CLASS_BY_SCHEMES[packages[0].type].version_class
# return sorted(
# packages,
# key=lambda x: version_class(x.version),
# )
return sorted(
packages,
key=lambda x: version_class(x.version),
key=lambda x: self.version_class(x.version),
)

@property
Expand Down Expand Up @@ -739,9 +748,59 @@ def get_affecting_vulnerabilities(self):

fixed_by_packages = self.get_fixed_by_package_versions(fix=True)

package_vulnerabilities = self.vulnerabilities.filter(
packagerelatedvulnerability__fix=False
).prefetch_related(
# This is my original code and it works:

# package_vulnerabilities = self.vulnerabilities.filter(
# packagerelatedvulnerability__fix=False
# ).prefetch_related(
# Prefetch(
# "packages",
# queryset=fixed_by_packages,
# to_attr="fixed_packages",
# )
# )

# package_vulnerabilities = Package.objects.affected()
# package_vulnerabilities.prefetch_related(
# Prefetch(
# "packages",
# queryset=fixed_by_packages,
# to_attr="fixed_packages",
# )
# )

# package_vulnerabilities = Package.objects.affected().prefetch_related(
# Prefetch(
# "package",
# queryset=fixed_by_packages,
# to_attr="fixed_packages",
# )
# )

# package_vulnerabilities = self.vulnerabilities.affected().prefetch_related(
# # package_vulnerabilities = self.vulnerabilities.filter(
# # packagerelatedvulnerability__fix=False
# # ).prefetch_related(
# Prefetch(
# "packages",
# queryset=fixed_by_packages,
# to_attr="fixed_packages",
# )
# )

# package_vulnerabilities = self.vulnerabilities.affected().prefetch_related(
# Prefetch(
# "packages",
# queryset=fixed_by_packages,
# to_attr="fixed_packages",
# )
# )

# package_vulnerabilities = self.vulnerabilities.affected()

# 2023-11-22 Wednesday 08:54:04. Try my new vuln queryset manager method. This works, no failed tests!!!

package_vulnerabilities = self.vulnerabilities.affected_packages().prefetch_related(
Prefetch(
"packages",
queryset=fixed_by_packages,
Expand Down
3 changes: 2 additions & 1 deletion vulnerabilities/tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -399,7 +399,8 @@ def test_api_with_package_with_no_vulnerabilities(self):
}

package_with_no_vulnerabilities = MinimalPackageSerializer.get_vulnerability(
self, vuln, affected_vulnerabilities
self,
vuln,
)

assert package_with_no_vulnerabilities is None
Expand Down

0 comments on commit 9ec2a6a

Please sign in to comment.