Module sktmls.models.contrib.sample_pytorch_model

Classes

class SamplePyTorchModel (model, model_name: str, model_version: str, features: List[str])

MLS 모델 레지스트리에 등록되는 PyTorch 기반 샘플 클래스입니다.

PyTorch로 학습한 모델을 torch.jit.ScriptMoudle형태로 변환 후 MLS 모델로 저장합니다.

Args

  • model: PyTorch로 학습 후 torch.jit.ScriptMoudle형태로 변환한 객체 (torch.jit.ScriptMoudle로 변환 전 기존 torch 모델 device를 cpu로 변환해야 합니다.)
  • model_name: (str) 모델 이름
  • model_version: (str) 모델 버전
  • features: (list(str)) 학습에 사용된 피쳐 리스트

Example

model           # 학습이 완료된 PyTorch 모델 (상위 클래스 : `torch.nn.Module`)
tensor_sample   # `model`의 input shape에 맞는 sample용 `torch.Tensor`

script_model = torch.jit.trace(model.cpu(), tensor_sample)

my_mls_torch_model = PytorchSampleModel(
    model=script_model,
    model_name="my_model",
    model_version="v1",
    features=["feature1", "feature2", "feature3", "feature4"],
)

result = my_mls_torch_model.predict(["value_1", "value_2", "value_3", "value_4"])

Ancestors

Methods

def predict(self, x: List[Any], **kwargs) ‑> Dict[str, Any]

Inherited members