Source code for carate.models.base_model
import torch
from abc import ABC, abstractmethod
[docs]
class Model(torch.nn.Module, ABC):
    """Abstract base class for the network models.

    Stores the hyper-parameters shared by every model and declares the
    ``forward`` interface that concrete subclasses implement.

    The class derives from ``ABC`` in addition to ``torch.nn.Module`` so that
    the ``@abstractmethod`` markers below are actually enforced: without an
    ``ABCMeta`` metaclass the decorators are inert and the incomplete base
    class could be instantiated by accident.
    """

    @abstractmethod
    def __init__(self, dim: int, num_classes: int, num_features: int) -> None:
        """Initialise the module and record the shared hyper-parameters.

        Subclasses must override this method; they can call
        ``super().__init__(dim, num_classes, num_features)`` to reuse the
        attribute assignments below.

        Args:
            dim: Stored as ``self.dim``.
            num_classes: Stored as ``self.num_classes``.
            num_features: Stored as ``self.num_features``.
        """
        super().__init__()
        self.num_classes = num_classes
        self.num_features = num_features
        self.dim = dim

    @abstractmethod
    def forward(
        self,
        x: torch.Tensor,
        edge_index: torch.Tensor,
        batch: torch.Tensor,
        edge_weight=None,
    ) -> torch.Tensor:
        """Compute the model output; must be implemented by subclasses.

        NOTE(review): the argument names look like the PyTorch Geometric
        calling convention (node features, edge indices, batch assignment
        vector, optional edge weights) -- confirm against the concrete
        subclasses.

        Args:
            x: Input tensor.
            edge_index: Tensor describing the graph edges.
            batch: Tensor accompanying ``x`` in a batched input.
            edge_weight: Optional edge weights; defaults to ``None``.

        Returns:
            The output tensor produced by the concrete model.
        """
        pass  # pragma: no cover